From a9fb3a9b4c620328d6096922f07b6a320e22d769 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Manciot?= Date: Sun, 9 Nov 2025 18:13:18 +0100 Subject: [PATCH 01/11] add macros to validate SQL queries at compile-time --- build.sbt | 54 ++- core/build.sbt | 4 +- .../sql/macros/SQLQueryValidatorSpec.scala | 171 ++++++++++ .../client/macros/TestElasticClientApi.scala | 31 ++ .../elastic/sql/macros/SQLQueryMacros.scala | 107 ++++++ .../sql/macros/SQLQueryValidator.scala | 308 ++++++++++++++++++ project/Versions.scala | 2 +- sql/build.sbt | 5 +- 8 files changed, 675 insertions(+), 7 deletions(-) create mode 100644 macros-tests/src/test/scala/app/softnetwork/elastic/sql/macros/SQLQueryValidatorSpec.scala create mode 100644 macros/src/main/scala/app/softnetwork/elastic/client/macros/TestElasticClientApi.scala create mode 100644 macros/src/main/scala/app/softnetwork/elastic/sql/macros/SQLQueryMacros.scala create mode 100644 macros/src/main/scala/app/softnetwork/elastic/sql/macros/SQLQueryValidator.scala diff --git a/build.sbt b/build.sbt index f965f0d6..728347c1 100644 --- a/build.sbt +++ b/build.sbt @@ -19,7 +19,7 @@ ThisBuild / organization := "app.softnetwork" name := "softclient4es" -ThisBuild / version := "0.11.0" +ThisBuild / version := "0.12.0" ThisBuild / scalaVersion := scala213 @@ -103,8 +103,54 @@ lazy val sql = project .in(file("sql")) .configs(IntegrationTest) .settings( + Defaults.itSettings + ) + +lazy val macros = project + .in(file("macros")) + .configs(IntegrationTest) + .settings( + name := "softclient4es-macros", + + libraryDependencies ++= Seq( + "org.scala-lang" % "scala-reflect" % scalaVersion.value, + "org.json4s" %% "json4s-native" % Versions.json4s + ), Defaults.itSettings, - moduleSettings + moduleSettings, + scalacOptions ++= Seq( + "-language:experimental.macros", + "-Ymacro-annotations", + "-Ymacro-debug-lite", // Debug macros + "-Xlog-implicits" // Debug implicits + ) + ) + .dependsOn(sql) + +lazy val macrosTests = project + .in(file("macros-tests")) + .configs(IntegrationTest) + .settings( + name := "softclient4es-macros-tests", + Publish.noPublishSettings, + + libraryDependencies ++= Seq( + "org.scalatest" %% "scalatest" % Versions.scalatest % Test + ), + + Defaults.itSettings, + moduleSettings, + + scalacOptions ++= Seq( + "-language:experimental.macros", + "-Ymacro-debug-lite" + ), + + Test / scalacOptions += "-Xlog-free-terms" + ) + .dependsOn( + macros % "compile->compile", + sql % "compile->compile" ) lazy val core = project @@ -115,7 +161,7 @@ lazy val core = project moduleSettings ) .dependsOn( - sql % "compile->compile;test->test;it->it" + macros % "compile->compile;test->test;it->it" ) lazy val persistence = project @@ -432,6 +478,8 @@ lazy val root = project ) .aggregate( sql, + macros, + macrosTests, bridge, core, persistence, diff --git a/core/build.sbt b/core/build.sbt index e1cd0f88..58a8a719 100644 --- a/core/build.sbt +++ b/core/build.sbt @@ -32,4 +32,6 @@ val mockito = Seq( libraryDependencies ++= akka ++ typesafeConfig ++ http ++ json4s ++ mockito :+ "com.google.code.gson" % "gson" % Versions.gson :+ - "com.typesafe.scala-logging" %% "scala-logging" % Versions.scalaLogging + "com.typesafe.scala-logging" %% "scala-logging" % Versions.scalaLogging :+ + "org.scalatest" %% "scalatest" % Versions.scalatest % Test + diff --git a/macros-tests/src/test/scala/app/softnetwork/elastic/sql/macros/SQLQueryValidatorSpec.scala b/macros-tests/src/test/scala/app/softnetwork/elastic/sql/macros/SQLQueryValidatorSpec.scala new file mode 100644 index 
00000000..8088b6a1 --- /dev/null +++ b/macros-tests/src/test/scala/app/softnetwork/elastic/sql/macros/SQLQueryValidatorSpec.scala @@ -0,0 +1,171 @@ +package app.softnetwork.elastic.sql.macros + +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers + +class SQLQueryValidatorSpec extends AnyFlatSpec with Matchers { + + // ============================================================ + // Positive Tests (Should Compile) + // ============================================================ + + "SQLQueryValidator" should "validate all numeric types" in { + assertCompiles(""" + import app.softnetwork.elastic.client.macros.TestElasticClientApi + import app.softnetwork.elastic.client.macros.TestElasticClientApi.defaultFormats + import app.softnetwork.elastic.sql.macros.SQLQueryValidatorSpec.Numbers + import app.softnetwork.elastic.sql.query.SQLQuery + + TestElasticClientApi.searchAs[Numbers]( + "SELECT tiny::TINYINT as tiny, small::SMALLINT as small, normal::INT as normal, big::BIGINT as big, huge::BIGINT as huge, decimal::DOUBLE as decimal, r::REAL as r FROM numbers" + )""") + } + + it should "validate string types" in { + assertCompiles(""" + import app.softnetwork.elastic.client.macros.TestElasticClientApi + import app.softnetwork.elastic.client.macros.TestElasticClientApi.defaultFormats + import app.softnetwork.elastic.sql.macros.SQLQueryValidatorSpec.Strings + import app.softnetwork.elastic.sql.query.SQLQuery + + TestElasticClientApi.searchAs[Strings]( + "SELECT vchar::VARCHAR, c::CHAR, text FROM strings" + )""") + } + + it should "validate temporal types" in { + assertCompiles(""" + import app.softnetwork.elastic.client.macros.TestElasticClientApi + import app.softnetwork.elastic.client.macros.TestElasticClientApi.defaultFormats + import app.softnetwork.elastic.sql.macros.SQLQueryValidatorSpec.Temporal + import app.softnetwork.elastic.sql.query.SQLQuery + + TestElasticClientApi.searchAs[Temporal]( + "SELECT d::DATE, t::TIME, dt::DATETIME, ts::TIMESTAMP FROM temporal" + )""") + } + + it should "validate Product with all fields" in { + assertCompiles(""" + import app.softnetwork.elastic.client.macros.TestElasticClientApi + import app.softnetwork.elastic.client.macros.TestElasticClientApi.defaultFormats + import app.softnetwork.elastic.sql.macros.SQLQueryValidatorSpec.Product + import app.softnetwork.elastic.sql.query.SQLQuery + + TestElasticClientApi.searchAs[Product]( + "SELECT id, name, price::DOUBLE, stock::INT, active::BOOLEAN, createdAt::DATETIME FROM products" + )""") + } + + it should "validate with aliases" in { + assertCompiles(""" + import app.softnetwork.elastic.client.macros.TestElasticClientApi + import app.softnetwork.elastic.client.macros.TestElasticClientApi.defaultFormats + import app.softnetwork.elastic.sql.macros.SQLQueryValidatorSpec.Product + import app.softnetwork.elastic.sql.query.SQLQuery + + TestElasticClientApi.searchAs[Product]( + "SELECT product_id AS id, product_name AS name, product_price::DOUBLE AS price, product_stock::INT AS stock, is_active::BOOLEAN AS active, created_at::TIMESTAMP AS createdAt FROM products" + )""") + } + + // ============================================================ + // Negative Tests (Should NOT Compile) + // ============================================================ + + it should "reject missing fields" in { + assertDoesNotCompile(""" + import app.softnetwork.elastic.client.macros.TestElasticClientApi + import app.softnetwork.elastic.client.macros.TestElasticClientApi.defaultFormats + import 
app.softnetwork.elastic.sql.macros.SQLQueryValidatorSpec.Product + import app.softnetwork.elastic.sql.query.SQLQuery + + TestElasticClientApi.searchAs[Product]( + "SELECT id, name FROM products" + )""") + } + + it should "reject invalid field names" in { + assertDoesNotCompile(""" + import app.softnetwork.elastic.client.macros.TestElasticClientApi + import app.softnetwork.elastic.client.macros.TestElasticClientApi.defaultFormats + import app.softnetwork.elastic.sql.macros.SQLQueryValidatorSpec.Product + import app.softnetwork.elastic.sql.query.SQLQuery + + TestElasticClientApi.searchAs[Product]( + "SELECT id, invalid_name, price, stock, active, createdAt FROM products" + )""") + } + + it should "reject type mismatches" in { + assertDoesNotCompile(""" + import app.softnetwork.elastic.client.macros.TestElasticClientApi + import app.softnetwork.elastic.client.macros.TestElasticClientApi.defaultFormats + import app.softnetwork.elastic.sql.query.SQLQuery + + case class WrongTypes(id: Int, name: Int) + + TestElasticClientApi.searchAs[WrongTypes]( + "SELECT id::LONG, name FROM products" + )""") + } + + it should "suggest closest field names" in { + assertDoesNotCompile(""" + import app.softnetwork.elastic.client.macros.TestElasticClientApi + import app.softnetwork.elastic.client.macros.TestElasticClientApi.defaultFormats + import app.softnetwork.elastic.sql.macros.SQLQueryValidatorSpec.Product + import app.softnetwork.elastic.sql.query.SQLQuery + + TestElasticClientApi.searchAs[Product]( + "SELECT id, nam, price, stock, active, createdAt FROM products" + )""") + } + + it should "reject dynamic queries (non-literals)" in { + assertDoesNotCompile(""" + import app.softnetwork.elastic.client.macros.TestElasticClientApi + import app.softnetwork.elastic.client.macros.TestElasticClientApi.defaultFormats + import app.softnetwork.elastic.sql.macros.SQLQueryValidatorSpec.Product + import app.softnetwork.elastic.sql.query.SQLQuery + + val dynamicField = "name" + TestElasticClientApi.searchAs[Product]( + s"SELECT id, $dynamicField FROM products" + )""") + } +} + +object SQLQueryValidatorSpec { + case class Product( + id: String, + name: String, + price: Double, + stock: Int, + active: Boolean, + createdAt: java.time.LocalDateTime + ) + + case class Numbers( + tiny: Byte, + small: Short, + normal: Int, + big: Long, + huge: BigInt, + decimal: Double, + r: Float + ) + + case class Strings( + vchar: String, + c: String, + text: String + ) + + case class Temporal( + d: java.time.LocalDate, + t: java.time.LocalTime, + dt: java.time.LocalDateTime, + ts: java.time.Instant + ) +} diff --git a/macros/src/main/scala/app/softnetwork/elastic/client/macros/TestElasticClientApi.scala b/macros/src/main/scala/app/softnetwork/elastic/client/macros/TestElasticClientApi.scala new file mode 100644 index 00000000..8e022628 --- /dev/null +++ b/macros/src/main/scala/app/softnetwork/elastic/client/macros/TestElasticClientApi.scala @@ -0,0 +1,31 @@ +package app.softnetwork.elastic.client.macros + +import app.softnetwork.elastic.sql.macros.SQLQueryMacros +import app.softnetwork.elastic.sql.query.SQLQuery +import org.json4s.{DefaultFormats, Formats} + +import scala.language.experimental.macros + +/** Test trait that uses macros for compile-time validation. + */ +trait TestElasticClientApi { + + /** Search with compile-time SQL validation (macro). + */ + def searchAs[T](query: String)(implicit m: Manifest[T], formats: Formats): Seq[T] = + macro SQLQueryMacros.searchAsImpl[T] + + /** Search without compile-time validation (runtime). 
+ */ + def searchAsUnchecked[T]( + sqlQuery: SQLQuery + )(implicit m: Manifest[T], formats: Formats): Seq[T] = { + // Dummy implementation for tests + Seq.empty[T] + } +} + +object TestElasticClientApi extends TestElasticClientApi { + // default implicit for the tests + implicit val defaultFormats: Formats = DefaultFormats +} diff --git a/macros/src/main/scala/app/softnetwork/elastic/sql/macros/SQLQueryMacros.scala b/macros/src/main/scala/app/softnetwork/elastic/sql/macros/SQLQueryMacros.scala new file mode 100644 index 00000000..d28c1076 --- /dev/null +++ b/macros/src/main/scala/app/softnetwork/elastic/sql/macros/SQLQueryMacros.scala @@ -0,0 +1,107 @@ +/* + * Copyright 2025 SOFTNETWORK + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package app.softnetwork.elastic.sql.macros + +import org.json4s.Formats + +import scala.language.experimental.macros +import scala.reflect.macros.blackbox + +object SQLQueryMacros extends SQLQueryValidator { + + // ============================================================ + // searchAs + // ============================================================ + + def searchAsImpl[T: c.WeakTypeTag](c: blackbox.Context)( + query: c.Expr[String] + )( + m: c.Expr[Manifest[T]], + formats: c.Expr[Formats] + ): c.Expr[Seq[T]] = { + import c.universe._ + + c.echo(c.enclosingPosition, "=" * 60) + c.echo(c.enclosingPosition, "🚀🚀🚀 MACRO searchAsImpl CALLED 🚀🚀🚀") + c.echo(c.enclosingPosition, "=" * 60) + + val tpe = weakTypeOf[T] + val validatedQuery = validateSQLQuery[T](c)(query) + + c.Expr[Seq[T]](q""" + ${c.prefix}.searchAsUnchecked[$tpe]( + SQLQuery($validatedQuery) + )($m, $formats) + """) + } + + // ============================================================ + // searchAsyncAs + // ============================================================ + + def searchAsyncAsImpl[U: c.WeakTypeTag](c: blackbox.Context)( + sqlQuery: c.Expr[String] + )( + m: c.Expr[Any], + ec: c.Expr[Any], + formats: c.Expr[Formats] + ): c.Expr[Any] = { + import c.universe._ + + c.echo(c.enclosingPosition, "=" * 60) + c.echo(c.enclosingPosition, "🚀🚀🚀 MACRO searchAsyncAsImpl CALLED 🚀🚀🚀") + c.echo(c.enclosingPosition, "=" * 60) + + val tpe = weakTypeOf[U] + val validatedQuery = validateSQLQuery[U](c)(sqlQuery) + + c.Expr[Any](q""" + ${c.prefix}.searchAsyncAsUnchecked[$tpe]( + SQLQuery($validatedQuery) + )($m, $ec, $formats) + """) + } + + // ============================================================ + // scrollAs + // ============================================================ + + def scrollAsImpl[T: c.WeakTypeTag](c: blackbox.Context)( + sql: c.Expr[String], + config: c.Expr[Any] + )( + system: c.Expr[Any], + m: c.Expr[Any], + formats: c.Expr[Formats] + ): c.Expr[Any] = { + import c.universe._ + + val tpe = weakTypeOf[T] + val validatedQuery = validateSQLQuery[T](c)(sql) + + c.echo(c.enclosingPosition, "=" * 60) + c.echo(c.enclosingPosition, "🚀🚀🚀 MACRO scrollAsImpl CALLED 🚀🚀🚀") + c.echo(c.enclosingPosition, "=" * 60) + + c.Expr[Any](q""" + ${c.prefix}.scrollAsUnchecked[$tpe]( + 
SQLQuery($validatedQuery), + $config + )($system, $m, $formats) + """) + } +} diff --git a/macros/src/main/scala/app/softnetwork/elastic/sql/macros/SQLQueryValidator.scala b/macros/src/main/scala/app/softnetwork/elastic/sql/macros/SQLQueryValidator.scala new file mode 100644 index 00000000..7ff81392 --- /dev/null +++ b/macros/src/main/scala/app/softnetwork/elastic/sql/macros/SQLQueryValidator.scala @@ -0,0 +1,308 @@ +/* + * Copyright 2025 SOFTNETWORK + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package app.softnetwork.elastic.sql.macros + +import app.softnetwork.elastic.sql.`type`.{SQLType, SQLTypes} +import app.softnetwork.elastic.sql.parser.Parser +import app.softnetwork.elastic.sql.query.SQLSearchRequest + +import scala.language.experimental.macros +import scala.reflect.macros.blackbox + +/** Reusable core validation logic for all SQL macros. + */ +trait SQLQueryValidator { + + /** Validates an SQL query against a type T. Returns the SQL query if valid, otherwise aborts + * compilation. + */ + protected def validateSQLQuery[T: c.WeakTypeTag](c: blackbox.Context)( + query: c.Expr[String] + ): String = { + + c.echo(c.enclosingPosition, "🚀 MACRO IS BEING CALLED!") + + // 1. Extract the SQL query (must be a literal) + val sqlQuery = extractStringLiteral(c)(query) + + if (sys.props.get("elastic.sql.debug").contains("true")) { + c.info(c.enclosingPosition, s"Validating SQL: $sqlQuery", force = false) + } + + // 2. Parse the SQL query + val parsedQuery = parseSQLQuery(c)(sqlQuery) + + // 3. Extract the selected fields + val queryFields = extractQueryFields(parsedQuery) + + c.echo(c.enclosingPosition, s"🔍 Parsed fields: ${queryFields.mkString(", ")}") + + // 4. Extract the fields from case class T + val tpe = c.weakTypeOf[T] + val caseClassFields = extractCaseClassFields(c)(tpe) + c.echo(c.enclosingPosition, s"📦 Case class fields: ${caseClassFields.mkString(", ")}") + + // 5. Validate the fields + validateFields(c)(queryFields, caseClassFields, tpe) + + // 6. Validate the types + validateTypes(c)(parsedQuery, caseClassFields) + + // 7. Return the validated request + sqlQuery + } + + // ============================================================ + // Helper Methods + // ============================================================ + + private def extractStringLiteral(c: blackbox.Context)( + query: c.Expr[String] + ): String = { + import c.universe._ + + query.tree match { + case Literal(Constant(sql: String)) => + c.echo(c.enclosingPosition, s"📝 Query: $sql") + sql + case other => + c.echo(c.enclosingPosition, s"❌ Not a literal: ${showRaw(other)}") + c.abort( + c.enclosingPosition, + "❌ SQL query must be a string literal for compile-time validation. " + + "Use the *Unchecked() variant for dynamic queries." 
+ ) + } + } + + private def parseSQLQuery(c: blackbox.Context)(sqlQuery: String): SQLSearchRequest = { + Parser(sqlQuery) match { + case Right(Left(request)) => + request + + case Right(Right(multi)) => + multi.requests.headOption.getOrElse { + c.abort(c.enclosingPosition, "Empty multi-search query") + } + + case Left(error) => + c.abort( + c.enclosingPosition, + s"❌ SQL parsing error: ${error.msg}\n" + + s"Query: $sqlQuery" + ) + } + } + + private def extractQueryFields(parsedQuery: SQLSearchRequest): Set[String] = { + parsedQuery.select.fields.map { field => + field.fieldAlias.map(_.alias).getOrElse(field.identifier.name) + }.toSet + } + + private def extractCaseClassFields(c: blackbox.Context)( + tpe: c.universe.Type + ): Map[String, c.universe.Type] = { + import c.universe._ + + tpe.members.collect { + case m: MethodSymbol if m.isCaseAccessor => + m.name.toString -> m.returnType + }.toMap + } + + private def validateFields(c: blackbox.Context)( + queryFields: Set[String], + caseClassFields: Map[String, c.universe.Type], + tpe: c.universe.Type + ): Unit = { + val missingFields = caseClassFields.keySet -- queryFields + + if (missingFields.nonEmpty) { + val availableFields = caseClassFields.keys.toSeq.sorted.mkString(", ") + val suggestions = missingFields.flatMap { missing => + findClosestMatch(missing, caseClassFields.keys.toSeq) + } + + val suggestionMsg = if (suggestions.nonEmpty) { + s"\nDid you mean: ${suggestions.mkString(", ")}?" + } else "" + + c.abort( + c.enclosingPosition, + s"❌ SQL case class fields in ${tpe.typeSymbol.name} not present in ${queryFields.mkString(",")}: " + + s"${missingFields.mkString(", ")}\n" + + s"Available fields: $availableFields$suggestionMsg" + ) + } + } + + private def validateTypes(c: blackbox.Context)( + parsedQuery: SQLSearchRequest, + caseClassFields: Map[String, c.universe.Type] + ): Unit = { + + parsedQuery.select.fields.foreach { field => + val fieldName = field.fieldAlias.map(_.alias).getOrElse(field.identifier.name) + + (field.out, caseClassFields.get(fieldName)) match { + case (sqlType, Some(scalaType)) => + if (!areTypesCompatible(c)(sqlType, scalaType)) { + c.abort( + c.enclosingPosition, + s"Type mismatch for field '$fieldName': " + + s"SQL type $sqlType is incompatible with Scala type ${scalaType.toString}\n" + + s"Expected one of: ${getCompatibleScalaTypes(sqlType)}" + ) + } + case _ => // Cannot validate without type info + } + } + } + + private def areTypesCompatible(c: blackbox.Context)( + sqlType: SQLType, + scalaType: c.universe.Type + ): Boolean = { + import c.universe._ + + sqlType match { + case SQLTypes.TinyInt => + scalaType =:= typeOf[Byte] || + scalaType =:= typeOf[Short] || + scalaType =:= typeOf[Int] || + scalaType =:= typeOf[Long] || + scalaType =:= typeOf[Option[Byte]] || + scalaType =:= typeOf[Option[Short]] || + scalaType =:= typeOf[Option[Int]] || + scalaType =:= typeOf[Option[Long]] + + case SQLTypes.SmallInt => + scalaType =:= typeOf[Short] || + scalaType =:= typeOf[Int] || + scalaType =:= typeOf[Long] || + scalaType =:= typeOf[Option[Short]] || + scalaType =:= typeOf[Option[Int]] || + scalaType =:= typeOf[Option[Long]] + + case SQLTypes.Int => + scalaType =:= typeOf[Int] || + scalaType =:= typeOf[Long] || + scalaType =:= typeOf[Option[Int]] || + scalaType =:= typeOf[Option[Long]] + + case SQLTypes.BigInt => + scalaType =:= typeOf[Long] || + scalaType =:= typeOf[BigInt] || + scalaType =:= typeOf[Option[Long]] || + scalaType =:= typeOf[Option[BigInt]] + + case SQLTypes.Double | SQLTypes.Real => + scalaType =:= 
typeOf[Double] || + scalaType =:= typeOf[Float] || + scalaType =:= typeOf[Option[Double]] || + scalaType =:= typeOf[Option[Float]] + + case SQLTypes.Char => + scalaType =:= typeOf[String] || // CHAR(n) → String + scalaType =:= typeOf[Char] || // CHAR(1) → Char + scalaType =:= typeOf[Option[String]] || + scalaType =:= typeOf[Option[Char]] + + case SQLTypes.Varchar => + scalaType =:= typeOf[String] || + scalaType =:= typeOf[Option[String]] + + case SQLTypes.Boolean => + scalaType =:= typeOf[Boolean] || + scalaType =:= typeOf[Option[Boolean]] + + case SQLTypes.Time => + scalaType.toString.contains("Instant") || + scalaType.toString.contains("LocalTime") + + case SQLTypes.Date => + scalaType.toString.contains("Date") || + scalaType.toString.contains("Instant") || + scalaType.toString.contains("LocalDate") + + case SQLTypes.DateTime | SQLTypes.Timestamp => + scalaType.toString.contains("LocalDateTime") || + scalaType.toString.contains("ZonedDateTime") || + scalaType.toString.contains("Instant") + + case SQLTypes.Struct => + if (scalaType.typeSymbol.isClass && scalaType.typeSymbol.asClass.isCaseClass) { + // validateStructFields(c)(sqlField, scalaType) + true + } else { + false + } + + case _ => + true // Cannot validate unknown types + } + } + + private def getCompatibleScalaTypes(sqlType: SQLType): String = { + sqlType match { + case SQLTypes.TinyInt => + "Byte, Short, Int, Long, Option[Byte], Option[Short], Option[Int], Option[Long]" + case SQLTypes.SmallInt => "Short, Int, Long, Option[Short], Option[Int], Option[Long]" + case SQLTypes.Int => "Int, Long, Option[Int], Option[Long]" + case SQLTypes.BigInt => "Long, BigInt, Option[Long], Option[BigInt]" + case SQLTypes.Double | SQLTypes.Real => "Double, Float, Option[Double], Option[Float]" + case SQLTypes.Varchar => "String, Option[String]" + case SQLTypes.Char => "String, Char, Option[String], Option[Char]" + case SQLTypes.Boolean => "Boolean, Option[Boolean]" + case SQLTypes.Time => "java.time.LocalTime, java.time.Instant" + case SQLTypes.Date => "java.time.LocalDate, java.time.Instant, java.util.Date" + case SQLTypes.DateTime | SQLTypes.Timestamp => + "java.time.LocalDateTime, java.time.ZonedDateTime, java.time.Instant" + case SQLTypes.Struct => "Case Class" + case _ => "Unknown" + } + } + + private def findClosestMatch(target: String, candidates: Seq[String]): Option[String] = { + if (candidates.isEmpty) None + else { + val distances = candidates.map { candidate => + (candidate, levenshteinDistance(target.toLowerCase, candidate.toLowerCase)) + } + val closest = distances.minBy(_._2) + if (closest._2 <= 3) Some(closest._1) else None + } + } + + private def levenshteinDistance(s1: String, s2: String): Int = { + val dist = Array.tabulate(s2.length + 1, s1.length + 1) { (j, i) => + if (j == 0) i else if (i == 0) j else 0 + } + + for { + j <- 1 to s2.length + i <- 1 to s1.length + } { + dist(j)(i) = + if (s2(j - 1) == s1(i - 1)) dist(j - 1)(i - 1) + else (dist(j - 1)(i) min dist(j)(i - 1) min dist(j - 1)(i - 1)) + 1 + } + + dist(s2.length)(s1.length) + } +} diff --git a/project/Versions.scala b/project/Versions.scala index 336b9e75..e2ee035c 100644 --- a/project/Versions.scala +++ b/project/Versions.scala @@ -46,6 +46,6 @@ object Versions { val genericPersistence = "0.8.0" - val gson = "2.8.0" + val gson = "2.8.9" } diff --git a/sql/build.sbt b/sql/build.sbt index 0a065f7f..eb0e89c6 100644 --- a/sql/build.sbt +++ b/sql/build.sbt @@ -1,4 +1,5 @@ -import SoftClient4es._ +import SoftClient4es.* +import sbt.Keys.scalaVersion organization := 
"app.softnetwork.elastic" @@ -17,7 +18,7 @@ libraryDependencies ++= jacksonDependencies(elasticSearchVersion.value) ++ "javax.activation" % "activation" % "1.1.1" % Test ) :+ // ("app.softnetwork.persistence" %% "persistence-core" % Versions.genericPersistence excludeAll(jacksonExclusions: _*)) :+ - "org.scala-lang" % "scala-reflect" % "2.13.16" :+ + "org.scala-lang" % "scala-reflect" % scalaVersion.value :+ "com.google.code.gson" % "gson" % Versions.gson % Test From 71979ef1dc255edea5fe298cc7de459e90985b8e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Manciot?= Date: Sun, 9 Nov 2025 18:57:53 +0100 Subject: [PATCH 02/11] update macros so as to reject SELECT * queries + accept queries with missing fields that have defaults or are options --- .../sql/macros/SQLQueryValidatorSpec.scala | 89 ++++++++- .../sql/macros/SQLQueryValidator.scala | 176 ++++++++++++++++-- 2 files changed, 250 insertions(+), 15 deletions(-) diff --git a/macros-tests/src/test/scala/app/softnetwork/elastic/sql/macros/SQLQueryValidatorSpec.scala b/macros-tests/src/test/scala/app/softnetwork/elastic/sql/macros/SQLQueryValidatorSpec.scala index 8088b6a1..4b72908b 100644 --- a/macros-tests/src/test/scala/app/softnetwork/elastic/sql/macros/SQLQueryValidatorSpec.scala +++ b/macros-tests/src/test/scala/app/softnetwork/elastic/sql/macros/SQLQueryValidatorSpec.scala @@ -69,11 +69,50 @@ class SQLQueryValidatorSpec extends AnyFlatSpec with Matchers { )""") } + it should "accept query with missing Option fields" in { + assertCompiles(""" + import app.softnetwork.elastic.client.macros.TestElasticClientApi + import app.softnetwork.elastic.client.macros.TestElasticClientApi.defaultFormats + import app.softnetwork.elastic.sql.macros.SQLQueryValidatorSpec.ProductWithOptional + import app.softnetwork.elastic.sql.query.SQLQuery + + TestElasticClientApi.searchAs[ProductWithOptional]( + "SELECT id, name FROM products" + ) + """) + } + + it should "accept query with missing fields that have defaults" in { + assertCompiles(""" + import app.softnetwork.elastic.client.macros.TestElasticClientApi + import app.softnetwork.elastic.client.macros.TestElasticClientApi.defaultFormats + import app.softnetwork.elastic.sql.macros.SQLQueryValidatorSpec.ProductWithDefaults + import app.softnetwork.elastic.sql.query.SQLQuery + + TestElasticClientApi.searchAs[ProductWithDefaults]( + "SELECT id, name FROM products" + ) + """) + } + + it should "accept SELECT * with Unchecked variant" in { + assertCompiles(""" + import app.softnetwork.elastic.client.macros.TestElasticClientApi + import app.softnetwork.elastic.client.macros.TestElasticClientApi.defaultFormats + import app.softnetwork.elastic.sql.macros.SQLQueryValidatorSpec.Product + import app.softnetwork.elastic.sql.query.SQLQuery + + TestElasticClientApi.searchAsUnchecked[Product]( + SQLQuery("SELECT * FROM products") + ) + """) + } + // ============================================================ // Negative Tests (Should NOT Compile) // ============================================================ - it should "reject missing fields" in { + it should "REJECT query with missing required fields" in { assertDoesNotCompile(""" import app.softnetwork.elastic.client.macros.TestElasticClientApi import app.softnetwork.elastic.client.macros.TestElasticClientApi.defaultFormats @@ -85,7 +124,7 @@ class SQLQueryValidatorSpec extends AnyFlatSpec with Matchers { )""") } - it should "reject invalid field names" in { + it should "REJECT query with invalid field names" in { assertDoesNotCompile(""" import 
app.softnetwork.elastic.client.macros.TestElasticClientApi import app.softnetwork.elastic.client.macros.TestElasticClientApi.defaultFormats @@ -97,7 +136,7 @@ class SQLQueryValidatorSpec extends AnyFlatSpec with Matchers { )""") } - it should "reject type mismatches" in { + it should "REJECT query with type mismatches" in { assertDoesNotCompile(""" import app.softnetwork.elastic.client.macros.TestElasticClientApi import app.softnetwork.elastic.client.macros.TestElasticClientApi.defaultFormats @@ -122,7 +161,7 @@ class SQLQueryValidatorSpec extends AnyFlatSpec with Matchers { )""") } - it should "reject dynamic queries (non-literals)" in { + it should "REJECT dynamic queries (non-literals)" in { assertDoesNotCompile(""" import app.softnetwork.elastic.client.macros.TestElasticClientApi import app.softnetwork.elastic.client.macros.TestElasticClientApi.defaultFormats @@ -134,6 +173,33 @@ class SQLQueryValidatorSpec extends AnyFlatSpec with Matchers { s"SELECT id, $dynamicField FROM products" )""") } + + it should "REJECT SELECT * queries" in { + assertDoesNotCompile(""" + import app.softnetwork.elastic.client.macros.TestElasticClientApi + import app.softnetwork.elastic.client.macros.TestElasticClientApi.defaultFormats + import app.softnetwork.elastic.sql.macros.SQLQueryValidatorSpec.Product + import app.softnetwork.elastic.sql.query.SQLQuery + + TestElasticClientApi.searchAs[Product]( + "SELECT * FROM products" + ) + """) + } + + it should "REJECT SELECT * even with WHERE clause" in { + assertDoesNotCompile(""" + import app.softnetwork.elastic.client.macros.TestElasticClientApi + import app.softnetwork.elastic.client.macros.TestElasticClientApi.defaultFormats + import app.softnetwork.elastic.sql.macros.SQLQueryValidatorSpec.Product + import app.softnetwork.elastic.sql.query.SQLQuery + + TestElasticClientApi.searchAs[Product]( + "SELECT * FROM products WHERE active = true" + ) + """) + } + } object SQLQueryValidatorSpec { @@ -146,6 +212,21 @@ object SQLQueryValidatorSpec { createdAt: java.time.LocalDateTime ) + case class ProductWithOptional( + id: String, + name: String, + price: Option[Double], // ✅ OK if missing + stock: Option[Int] // ✅ OK if missing + ) + + // Case class with default values + case class ProductWithDefaults( + id: String, + name: String, + price: Double = 0.0, // ✅ OK if missing + stock: Int = 0 // ✅ OK if missing + ) + case class Numbers( tiny: Byte, small: Short, diff --git a/macros/src/main/scala/app/softnetwork/elastic/sql/macros/SQLQueryValidator.scala b/macros/src/main/scala/app/softnetwork/elastic/sql/macros/SQLQueryValidator.scala index 7ff81392..f37900bf 100644 --- a/macros/src/main/scala/app/softnetwork/elastic/sql/macros/SQLQueryValidator.scala +++ b/macros/src/main/scala/app/softnetwork/elastic/sql/macros/SQLQueryValidator.scala @@ -29,6 +29,9 @@ trait SQLQueryValidator { /** Validates an SQL query against a type T. Returns the SQL query if valid, otherwise aborts * compilation. + * @note + * query fields must not exist in case class because we are using Jackson to deserialize the + * results with the following option DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES = false */ protected def validateSQLQuery[T: c.WeakTypeTag](c: blackbox.Context)( query: c.Expr[String] @@ -46,6 +49,11 @@ trait SQLQueryValidator { // 2. 
Parse the SQL query val parsedQuery = parseSQLQuery(c)(sqlQuery) + // ============================================================ + // ✅ NEW: Reject SELECT * + // ============================================================ + rejectSelectStar(c)(parsedQuery, sqlQuery) + // 3. Extract the selected fields val queryFields = extractQueryFields(parsedQuery) @@ -56,13 +64,13 @@ trait SQLQueryValidator { val caseClassFields = extractCaseClassFields(c)(tpe) c.echo(c.enclosingPosition, s"📦 Case class fields: ${caseClassFields.mkString(", ")}") - // 5. Validate the fields - validateFields(c)(queryFields, caseClassFields, tpe) + // 5. Validate: missing case class fields must have defaults or be Option + validateMissingFieldsHaveDefaults(c)(queryFields, caseClassFields, tpe) - // 6. Validate the types + // 7. Validate the types validateTypes(c)(parsedQuery, caseClassFields) - // 7. Return the validated request + // 8. Return the validated request sqlQuery } @@ -108,6 +116,47 @@ trait SQLQueryValidator { } } + // ============================================================ + // ✅ Reject SELECT * (incompatible with compile-time validation) + // ============================================================ + private def rejectSelectStar(c: blackbox.Context)( + parsedQuery: SQLSearchRequest, + sqlQuery: String + ): Unit = { + + // Check if any field is a wildcard (*) + val hasWildcard = parsedQuery.select.fields.exists { field => + field.identifier.name == "*" + } + + if (hasWildcard) { + c.abort( + c.enclosingPosition, + s"""❌ SELECT * is not allowed with compile-time validation. + | + |Query: $sqlQuery + | + |Reason: + | • Cannot validate field existence at compile-time + | • Cannot validate type compatibility at compile-time + | • Schema changes will break silently at runtime + | + |Solution: + | 1. Explicitly list all required fields: + | SELECT id, name, price FROM products + | + | 2. Use the *Unchecked() variant for dynamic queries: + | searchAsUnchecked[Product](SQLQuery("SELECT * FROM products")) + | + |Best Practice: + | Always explicitly select only the fields you need. 
+ |""".stripMargin + ) + } + + c.echo(c.enclosingPosition, "✅ No SELECT * detected") + } + private def extractQueryFields(parsedQuery: SQLSearchRequest): Set[String] = { parsedQuery.select.fields.map { field => field.fieldAlias.map(_.alias).getOrElse(field.identifier.name) @@ -125,17 +174,21 @@ trait SQLQueryValidator { }.toMap } - private def validateFields(c: blackbox.Context)( + // ============================================================ + // ✅ VALIDATION 1: Query fields must exist in case class + // ============================================================ + @deprecated + private def validateQueryFieldsExist(c: blackbox.Context)( queryFields: Set[String], caseClassFields: Map[String, c.universe.Type], tpe: c.universe.Type ): Unit = { - val missingFields = caseClassFields.keySet -- queryFields + val unknownFields = queryFields.filterNot(f => caseClassFields.contains(f)) - if (missingFields.nonEmpty) { + if (unknownFields.nonEmpty) { val availableFields = caseClassFields.keys.toSeq.sorted.mkString(", ") - val suggestions = missingFields.flatMap { missing => - findClosestMatch(missing, caseClassFields.keys.toSeq) + val suggestions = unknownFields.flatMap { unknown => + findClosestMatch(unknown, caseClassFields.keys.toSeq) } val suggestionMsg = if (suggestions.nonEmpty) { @@ -144,13 +197,112 @@ trait SQLQueryValidator { c.abort( c.enclosingPosition, - s"❌ SQL case class fields in ${tpe.typeSymbol.name} not present in ${queryFields.mkString(",")}: " + - s"${missingFields.mkString(", ")}\n" + + s"❌ SQL query selects fields not present in ${tpe.typeSymbol.name}: " + + s"${unknownFields.mkString(", ")}\n" + s"Available fields: $availableFields$suggestionMsg" ) } + + c.echo(c.enclosingPosition, "✅ All query fields exist in case class") } + // ============================================================ + // ✅ VALIDATION 2: Missing fields must have defaults or be Option + // ============================================================ + private def validateMissingFieldsHaveDefaults(c: blackbox.Context)( + queryFields: Set[String], + caseClassFields: Map[String, c.universe.Type], + tpe: c.universe.Type + ): Unit = { + import c.universe._ + + val missingFields = caseClassFields.keySet -- queryFields + + if (missingFields.isEmpty) { + c.echo(c.enclosingPosition, "✅ No missing fields to validate") + return + } + + c.echo(c.enclosingPosition, s"⚠️ Missing fields: ${missingFields.mkString(", ")}") + + // Get constructor parameters with their positions + val constructor = tpe.decl(termNames.CONSTRUCTOR).asMethod + val params = constructor.paramLists.flatten + + // Build map: fieldName -> (index, hasDefault, isOption) + val fieldInfo = params.zipWithIndex.map { case (param, idx) => + val fieldName = param.name.toString + val fieldType = param.typeSignature + + // Check if Option + val isOption = fieldType.typeConstructor =:= typeOf[Option[_]].typeConstructor + + // Check if has default value + val companionSymbol = tpe.typeSymbol.companion + val hasDefault = if (companionSymbol != NoSymbol) { + val companionType = companionSymbol.typeSignature + val defaultMethodName = s"apply$$default$$${idx + 1}" + companionType.member(TermName(defaultMethodName)) != NoSymbol + } else { + false + } + + (fieldName, (idx, hasDefault, isOption)) + }.toMap + + // Check each missing field + val fieldsWithoutDefaults = missingFields.filterNot { fieldName => + fieldInfo.get(fieldName) match { + case Some((_, hasDefault, isOption)) => + if (isOption) { + c.echo(c.enclosingPosition, s"✅ Field '$fieldName' is Option - 
OK") + true + } else if (hasDefault) { + c.echo(c.enclosingPosition, s"✅ Field '$fieldName' has default value - OK") + true + } else { + c.echo(c.enclosingPosition, s"❌ Field '$fieldName' has NO default and is NOT Option") + false + } + case None => + c.echo(c.enclosingPosition, s"⚠️ Field '$fieldName' not found in constructor") + false + } + } + + if (fieldsWithoutDefaults.nonEmpty) { + c.abort( + c.enclosingPosition, + s"❌ SQL query does not select the following required fields from ${tpe.typeSymbol.name}:\n" + + s" ${fieldsWithoutDefaults.mkString(", ")}\n\n" + + s"These fields are missing from the query:\n" + + s" SELECT ${queryFields.mkString(", ")} FROM ...\n\n" + + s"To fix this, either:\n" + + s" 1. Add them to the SELECT clause\n" + + s" 2. Make them Option[T] in the case class\n" + + s" 3. Provide default values in the case class definition" + ) + } + + c.echo(c.enclosingPosition, "✅ All missing fields have defaults or are Option") + } + + // Helper: Get the index of a field in the case class constructor + private def getFieldIndex(c: blackbox.Context)( + tpe: c.universe.Type, + fieldName: String + ): Int = { + import c.universe._ + + val constructor = tpe.decl(termNames.CONSTRUCTOR).asMethod + val params = constructor.paramLists.flatten + + params.indexWhere(_.name.toString == fieldName) + } + + // ============================================================ + // VALIDATION 3: Type compatibility + // ============================================================ private def validateTypes(c: blackbox.Context)( parsedQuery: SQLSearchRequest, caseClassFields: Map[String, c.universe.Type] @@ -172,6 +324,8 @@ trait SQLQueryValidator { case _ => // Cannot validate without type info } } + + c.echo(c.enclosingPosition, "✅ Type validation passed") } private def areTypesCompatible(c: blackbox.Context)( From 8178158de0fe0fc2cc050acb8e1767107c629bd2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Manciot?= Date: Mon, 10 Nov 2025 08:09:59 +0100 Subject: [PATCH 03/11] integrate macros within ElasticClientApi --- build.sbt | 10 +- .../client/ElasticClientDelegator.scala | 44 ++++-- .../elastic/client/ScrollApi.scala | 52 ++++++- .../elastic/client/SearchApi.scala | 50 ++++++- .../client/metrics/MetricsElasticClient.scala | 34 ++++- .../client/macros/TestElasticClientApi.scala | 16 ++ .../elastic/sql/macros/SQLQueryMacros.scala | 50 +++---- .../sql/macros/SQLQueryValidator.scala | 137 ++++++++++++++---- .../persistence/query/ElasticProvider.scala | 2 +- .../elastic/client/ElasticClientSpec.scala | 76 +++++++--- 10 files changed, 377 insertions(+), 94 deletions(-) diff --git a/build.sbt b/build.sbt index 728347c1..14941bf4 100644 --- a/build.sbt +++ b/build.sbt @@ -158,7 +158,11 @@ lazy val core = project .configs(IntegrationTest) .settings( Defaults.itSettings, - moduleSettings + moduleSettings, + scalacOptions ++= Seq( + "-language:experimental.macros", + "-Ymacro-debug-lite" + ) ) .dependsOn( macros % "compile->compile;test->test;it->it" @@ -213,6 +217,10 @@ def testkitProject(esVersion: String, ss: Def.SettingsDefinition*): Project = { Defaults.itSettings, app.softnetwork.Info.infoSettings, moduleSettings, + scalacOptions ++= Seq( + "-language:experimental.macros", + "-Ymacro-debug-lite" + ), elasticSearchVersion := esVersion, buildInfoKeys += BuildInfoKey("elasticVersion" -> elasticSearchVersion.value), buildInfoObject := "SoftClient4esCoreTestkitBuildInfo", diff --git a/core/src/main/scala/app/softnetwork/elastic/client/ElasticClientDelegator.scala 
b/core/src/main/scala/app/softnetwork/elastic/client/ElasticClientDelegator.scala index e6233c5b..55a1f588 100644 --- a/core/src/main/scala/app/softnetwork/elastic/client/ElasticClientDelegator.scala +++ b/core/src/main/scala/app/softnetwork/elastic/client/ElasticClientDelegator.scala @@ -521,9 +521,9 @@ trait ElasticClientDelegator extends ElasticClientApi with BulkTypes { * true if the entity was indexed successfully, false otherwise */ override def index( - index: JSONResults, - id: JSONResults, - source: JSONResults, + index: String, + id: String, + source: String, wait: Boolean = false ): ElasticResult[Boolean] = delegate.index(index, id, source, wait) @@ -990,9 +990,10 @@ trait ElasticClientDelegator extends ElasticClientApi with BulkTypes { * @return * the entities matching the query */ - override def searchAs[U]( + override def searchAsUnchecked[U]( sqlQuery: SQLQuery - )(implicit m: Manifest[U], formats: Formats): ElasticResult[Seq[U]] = delegate.searchAs(sqlQuery) + )(implicit m: Manifest[U], formats: Formats): ElasticResult[Seq[U]] = + delegate.searchAsUnchecked(sqlQuery) /** Searches and converts results into typed entities. * @@ -1035,6 +1036,9 @@ trait ElasticClientDelegator extends ElasticClientApi with BulkTypes { delegate.multisearchAs(elasticQueries, fieldAliases, aggregations) /** Asynchronous search with conversion to typed entities. + * + * @note + * This method is a variant of searchAsyncAs without compile-time SQL validation. * * @param sqlQuery * the SQL query @@ -1043,11 +1047,12 @@ trait ElasticClientDelegator extends ElasticClientApi with BulkTypes { * @return * a Future containing the entities */ - override def searchAsyncAs[U](sqlQuery: SQLQuery)(implicit + override def searchAsyncAsUnchecked[U](sqlQuery: SQLQuery)(implicit m: Manifest[U], ec: ExecutionContext, formats: Formats - ): Future[ElasticResult[Seq[U]]] = delegate.searchAsyncAs(sqlQuery) + ): Future[ElasticResult[Seq[U]]] = + delegate.searchAsyncAsUnchecked(sqlQuery) /** Asynchronous search with conversion to typed entities. * @@ -1150,13 +1155,32 @@ trait ElasticClientDelegator extends ElasticClientApi with BulkTypes { system: ActorSystem ): Source[(Map[String, Any], ScrollMetrics), NotUsed] = delegate.scroll(sql, config) - /** Typed scroll source + /** Scroll and convert results into typed entities from an SQL query. + * + * @note + * This method is a variant of scrollAs without compile-time SQL validation. 
+ * + * @param sql + * - SQL query + * @param config + * - Scroll configuration + * @param system + * - Actor system + * @param m + * - Manifest for type T + * @param formats + * - JSON formats + * @tparam T + * - Target type + * @return + * - Source of tuples (T, ScrollMetrics) */ - override def scrollAs[T](sql: SQLQuery, config: ScrollConfig)(implicit + override def scrollAsUnchecked[T](sql: SQLQuery, config: ScrollConfig)(implicit system: ActorSystem, m: Manifest[T], formats: Formats - ): Source[(T, ScrollMetrics), NotUsed] = delegate.scrollAs(sql, config) + ): Source[(T, ScrollMetrics), NotUsed] = + delegate.scrollAsUnchecked(sql, config) override private[client] def scrollClassic( elasticQuery: ElasticQuery, diff --git a/core/src/main/scala/app/softnetwork/elastic/client/ScrollApi.scala b/core/src/main/scala/app/softnetwork/elastic/client/ScrollApi.scala index 253b6e05..8db08707 100644 --- a/core/src/main/scala/app/softnetwork/elastic/client/ScrollApi.scala +++ b/core/src/main/scala/app/softnetwork/elastic/client/ScrollApi.scala @@ -27,11 +27,13 @@ import app.softnetwork.elastic.client.scroll.{ UseScroll, UseSearchAfter } +import app.softnetwork.elastic.sql.macros.SQLQueryMacros import app.softnetwork.elastic.sql.query.{SQLAggregation, SQLQuery} import org.json4s.{Formats, JNothing} import org.json4s.jackson.JsonMethods.parse import scala.concurrent.{ExecutionContext, Promise} +import scala.language.experimental.macros import scala.util.{Failure, Success} /** API for scrolling through search results using Akka Streams. @@ -167,9 +169,57 @@ trait ScrollApi extends ElasticClientHelpers { hasSorts: Boolean = false )(implicit system: ActorSystem): Source[Map[String, Any], NotUsed] - /** Typed scroll source + /** Typed scroll source converting results into typed entities from an SQL query + * + * @note + * This method provides compile-time SQL validation via macros. + * + * @param sql + * - SQL query + * @param config + * - Scroll configuration + * @param system + * - Actor system + * @param m + * - Manifest for type T + * @param formats + * - JSON formats + * @tparam T + * - Target type + * @return + * - Source of tuples (T, ScrollMetrics) */ def scrollAs[T]( + sql: String, + config: ScrollConfig = ScrollConfig() + )(implicit + system: ActorSystem, + m: Manifest[T], + formats: Formats + ): Source[(T, ScrollMetrics), NotUsed] = + macro SQLQueryMacros.scrollAsImpl[T] + + /** Scroll and convert results into typed entities from an SQL query. + * + * @note + * This method is a variant of scrollAs without compile-time SQL validation. 
+ * + * @param sql + * - SQL query + * @param config + * - Scroll configuration + * @param system + * - Actor system + * @param m + * - Manifest for type T + * @param formats + * - JSON formats + * @tparam T + * - Target type + * @return + * - Source of tuples (T, ScrollMetrics) + */ + def scrollAsUnchecked[T]( sql: SQLQuery, config: ScrollConfig = ScrollConfig() )(implicit diff --git a/core/src/main/scala/app/softnetwork/elastic/client/SearchApi.scala b/core/src/main/scala/app/softnetwork/elastic/client/SearchApi.scala index 9c5c1e32..4822375e 100644 --- a/core/src/main/scala/app/softnetwork/elastic/client/SearchApi.scala +++ b/core/src/main/scala/app/softnetwork/elastic/client/SearchApi.scala @@ -22,12 +22,14 @@ import app.softnetwork.elastic.client.result.{ ElasticResult, ElasticSuccess } +import app.softnetwork.elastic.sql.macros.SQLQueryMacros import app.softnetwork.elastic.sql.query.{SQLAggregation, SQLQuery, SQLSearchRequest} import com.google.gson.{Gson, JsonElement, JsonObject, JsonParser} import org.json4s.Formats import scala.concurrent.{ExecutionContext, Future} import scala.jdk.CollectionConverters._ +import scala.language.experimental.macros import scala.reflect.{classTag, ClassTag} import scala.util.{Failure, Success, Try} @@ -400,7 +402,10 @@ trait SearchApi extends ElasticConversion with ElasticClientHelpers { /** Searches and converts results into typed entities from an SQL query. * - * @param sqlQuery + * @note + * This method uses compile-time macros to validate the SQL query against the type U. + * + * @param query * the SQL query containing fieldAliases and aggregations * @tparam U * the type of entities to return @@ -408,6 +413,23 @@ trait SearchApi extends ElasticConversion with ElasticClientHelpers { * the entities matching the query */ def searchAs[U]( + query: String + )(implicit m: Manifest[U], formats: Formats): ElasticResult[Seq[U]] = + macro SQLQueryMacros.searchAsImpl[U] + + /** Searches and converts results into typed entities from an SQL query. + * + * @note + * This method is a variant of searchAs without compile-time SQL validation. + * + * @param sqlQuery + * the SQL query containing fieldAliases and aggregations + * @tparam U + * the type of entities to return + * @return + * the entities matching the query + */ + def searchAsUnchecked[U]( sqlQuery: SQLQuery )(implicit m: Manifest[U], formats: Formats): ElasticResult[Seq[U]] = { for { @@ -473,7 +495,10 @@ trait SearchApi extends ElasticConversion with ElasticClientHelpers { /** Asynchronous search with conversion to typed entities. * - * @param sqlQuery + * @note + * This method uses compile-time macros to validate the SQL query against the type U. + * + * @param query * the SQL query * @tparam U * the type of entities to return @@ -481,6 +506,27 @@ trait SearchApi extends ElasticConversion with ElasticClientHelpers { * a Future containing the entities */ def searchAsyncAs[U]( + query: String + )(implicit + m: Manifest[U], + ec: ExecutionContext, + formats: Formats + ): Future[ElasticResult[Seq[U]]] = + macro SQLQueryMacros.searchAsyncAsImpl[U] + + /** Asynchronous search with conversion to typed entities. + * + * @note + * This method is a variant of searchAsyncAs without compile-time SQL validation. 
+ * + * @param sqlQuery + * the SQL query + * @tparam U + * the type of entities to return + * @return + * a Future containing the entities + */ + def searchAsyncAsUnchecked[U]( sqlQuery: SQLQuery )(implicit m: Manifest[U], diff --git a/core/src/main/scala/app/softnetwork/elastic/client/metrics/MetricsElasticClient.scala b/core/src/main/scala/app/softnetwork/elastic/client/metrics/MetricsElasticClient.scala index f34d0fec..8a16dabd 100644 --- a/core/src/main/scala/app/softnetwork/elastic/client/metrics/MetricsElasticClient.scala +++ b/core/src/main/scala/app/softnetwork/elastic/client/metrics/MetricsElasticClient.scala @@ -663,14 +663,17 @@ class MetricsElasticClient( * @return * the entities matching the query */ - override def searchAs[U]( + override def searchAsUnchecked[U]( sqlQuery: SQLQuery )(implicit m: Manifest[U], formats: Formats): ElasticResult[Seq[U]] = measureResult("searchAs") { - delegate.searchAs[U](sqlQuery) + delegate.searchAsUnchecked[U](sqlQuery) } /** Asynchronous search with conversion to typed entities. + * + * @note + * This method is a variant of searchAsyncAs without compile-time SQL validation. * * @param sqlQuery * the SQL query @@ -679,13 +682,13 @@ class MetricsElasticClient( * @return * a Future containing the entities */ - override def searchAsyncAs[U](sqlQuery: SQLQuery)(implicit + override def searchAsyncAsUnchecked[U](sqlQuery: SQLQuery)(implicit m: Manifest[U], ec: ExecutionContext, formats: Formats ): Future[ElasticResult[Seq[U]]] = measureAsync("searchAsyncAs") { - delegate.searchAsyncAs[U](sqlQuery) + delegate.searchAsyncAsUnchecked[U](sqlQuery) } override def singleSearch( @@ -900,16 +903,33 @@ class MetricsElasticClient( } - /** Typed scroll source + /** Scroll and convert results into typed entities from an SQL query. + * + * @note + * This method is a variant of scrollAs without compile-time SQL validation. + * @param sql + * - SQL query + * @param config + * - Scroll configuration + * @param system + * - Actor system + * @param m + * - Manifest for type T + * @param formats + * - JSON formats + * @tparam T + * - Target type + * @return + * - Source of tuples (T, ScrollMetrics) */ - override def scrollAs[T](sql: SQLQuery, config: ScrollConfig)(implicit + override def scrollAsUnchecked[T](sql: SQLQuery, config: ScrollConfig)(implicit system: ActorSystem, m: Manifest[T], formats: Formats ): Source[(T, ScrollMetrics), NotUsed] = { // Note: For streams, we measure at the beginning but not every element val startTime = System.currentTimeMillis() - val source = delegate.scrollAs[T](sql, config) + val source = delegate.scrollAsUnchecked[T](sql, config) source.watchTermination() { (_, done) => done.onComplete { result => diff --git a/macros/src/main/scala/app/softnetwork/elastic/client/macros/TestElasticClientApi.scala b/macros/src/main/scala/app/softnetwork/elastic/client/macros/TestElasticClientApi.scala index 8e022628..ecfaed86 100644 --- a/macros/src/main/scala/app/softnetwork/elastic/client/macros/TestElasticClientApi.scala +++ b/macros/src/main/scala/app/softnetwork/elastic/client/macros/TestElasticClientApi.scala @@ -1,3 +1,19 @@ +/* + * Copyright 2025 SOFTNETWORK + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package app.softnetwork.elastic.client.macros import app.softnetwork.elastic.sql.macros.SQLQueryMacros diff --git a/macros/src/main/scala/app/softnetwork/elastic/sql/macros/SQLQueryMacros.scala b/macros/src/main/scala/app/softnetwork/elastic/sql/macros/SQLQueryMacros.scala index d28c1076..8371e771 100644 --- a/macros/src/main/scala/app/softnetwork/elastic/sql/macros/SQLQueryMacros.scala +++ b/macros/src/main/scala/app/softnetwork/elastic/sql/macros/SQLQueryMacros.scala @@ -32,21 +32,21 @@ object SQLQueryMacros extends SQLQueryValidator { )( m: c.Expr[Manifest[T]], formats: c.Expr[Formats] - ): c.Expr[Seq[T]] = { + ): c.Tree = { import c.universe._ - c.echo(c.enclosingPosition, "=" * 60) - c.echo(c.enclosingPosition, "🚀🚀🚀 MACRO searchAsImpl CALLED 🚀🚀🚀") - c.echo(c.enclosingPosition, "=" * 60) + // 1. Validate the SQL query at compile-time + val validatedQuery = validateSQLQuery[T](c)(query) + // 2. Get the type parameter val tpe = weakTypeOf[T] - val validatedQuery = validateSQLQuery[T](c)(query) - c.Expr[Seq[T]](q""" + // 3. Generate the call to searchAsUnchecked + q""" ${c.prefix}.searchAsUnchecked[$tpe]( - SQLQuery($validatedQuery) + _root_.app.softnetwork.elastic.sql.query.SQLQuery($validatedQuery) )($m, $formats) - """) + """ } // ============================================================ @@ -54,26 +54,26 @@ object SQLQueryMacros extends SQLQueryValidator { // ============================================================ def searchAsyncAsImpl[U: c.WeakTypeTag](c: blackbox.Context)( - sqlQuery: c.Expr[String] + query: c.Expr[String] )( m: c.Expr[Any], ec: c.Expr[Any], formats: c.Expr[Formats] - ): c.Expr[Any] = { + ): c.Tree = { import c.universe._ - c.echo(c.enclosingPosition, "=" * 60) - c.echo(c.enclosingPosition, "🚀🚀🚀 MACRO searchAsyncAsImpl CALLED 🚀🚀🚀") - c.echo(c.enclosingPosition, "=" * 60) + // 1. Validate the SQL query at compile-time + val validatedQuery = validateSQLQuery[U](c)(query) + // 2. Get the type parameter val tpe = weakTypeOf[U] - val validatedQuery = validateSQLQuery[U](c)(sqlQuery) - c.Expr[Any](q""" + // 3. Generate the call to searchAsUnchecked + q""" ${c.prefix}.searchAsyncAsUnchecked[$tpe]( - SQLQuery($validatedQuery) + _root_.app.softnetwork.elastic.sql.query.SQLQuery($validatedQuery) )($m, $ec, $formats) - """) + """ } // ============================================================ @@ -87,21 +87,21 @@ object SQLQueryMacros extends SQLQueryValidator { system: c.Expr[Any], m: c.Expr[Any], formats: c.Expr[Formats] - ): c.Expr[Any] = { + ): c.Tree = { import c.universe._ - val tpe = weakTypeOf[T] + // 1. Validate the SQL query at compile-time val validatedQuery = validateSQLQuery[T](c)(sql) - c.echo(c.enclosingPosition, "=" * 60) - c.echo(c.enclosingPosition, "🚀🚀🚀 MACRO scrollAsImpl CALLED 🚀🚀🚀") - c.echo(c.enclosingPosition, "=" * 60) + // 2. Get the type parameter + val tpe = weakTypeOf[T] - c.Expr[Any](q""" + // 3. 
Generate the call to searchAsUnchecked + q""" ${c.prefix}.scrollAsUnchecked[$tpe]( - SQLQuery($validatedQuery), + _root_.app.softnetwork.elastic.sql.query.SQLQuery($validatedQuery), $config )($system, $m, $formats) - """) + """ } } diff --git a/macros/src/main/scala/app/softnetwork/elastic/sql/macros/SQLQueryValidator.scala b/macros/src/main/scala/app/softnetwork/elastic/sql/macros/SQLQueryValidator.scala index f37900bf..e2384cb4 100644 --- a/macros/src/main/scala/app/softnetwork/elastic/sql/macros/SQLQueryValidator.scala +++ b/macros/src/main/scala/app/softnetwork/elastic/sql/macros/SQLQueryValidator.scala @@ -37,10 +37,10 @@ trait SQLQueryValidator { query: c.Expr[String] ): String = { - c.echo(c.enclosingPosition, "🚀 MACRO IS BEING CALLED!") + debug(c)("🚀 MACRO IS BEING CALLED!") // 1. Extract the SQL query (must be a literal) - val sqlQuery = extractStringLiteral(c)(query) + val sqlQuery = extractSQLString(c)(query) if (sys.props.get("elastic.sql.debug").contains("true")) { c.info(c.enclosingPosition, s"Validating SQL: $sqlQuery", force = false) @@ -57,12 +57,12 @@ trait SQLQueryValidator { // 3. Extract the selected fields val queryFields = extractQueryFields(parsedQuery) - c.echo(c.enclosingPosition, s"🔍 Parsed fields: ${queryFields.mkString(", ")}") + debug(c)(s"🔍 Parsed fields: ${queryFields.mkString(", ")}") // 4. Extract the fields from case class T val tpe = c.weakTypeOf[T] val caseClassFields = extractCaseClassFields(c)(tpe) - c.echo(c.enclosingPosition, s"📦 Case class fields: ${caseClassFields.mkString(", ")}") + debug(c)(s"📦 Case class fields: ${caseClassFields.mkString(", ")}") // 5. Validate: missing case class fields must have defaults or be Option validateMissingFieldsHaveDefaults(c)(queryFields, caseClassFields, tpe) @@ -70,6 +70,10 @@ trait SQLQueryValidator { // 7. Validate the types validateTypes(c)(parsedQuery, caseClassFields) + debug(c)("=" * 80) + debug(c)("✅ SQL Query Validation Complete") + debug(c)("=" * 80) + // 8. Return the validated request sqlQuery } @@ -78,23 +82,96 @@ trait SQLQueryValidator { // Helper Methods // ============================================================ - private def extractStringLiteral(c: blackbox.Context)( - query: c.Expr[String] - ): String = { + /** Extracts SQL string from various tree structures. Supports: literals, .stripMargin, and simple + * expressions. 
+ */ + protected def extractSQLString(c: blackbox.Context)(query: c.Expr[String]): String = { import c.universe._ - query.tree match { - case Literal(Constant(sql: String)) => - c.echo(c.enclosingPosition, s"📝 Query: $sql") - sql - case other => - c.echo(c.enclosingPosition, s"❌ Not a literal: ${showRaw(other)}") + debug(c)("=" * 80) + debug(c)("🔍 Starting SQL Query Validation") + debug(c)("=" * 80) + + val sqlString = + (query match { + // Case 1: Direct string literal + // Example: "SELECT * FROM table" + case Literal(Constant(sql: String)) => + debug(c)("📝 Detected: Direct string literal") + Some(sql) + + // Case 2: String with .stripMargin + // Example: """SELECT * FROM table""".stripMargin + case Select(Literal(Constant(sql: String)), TermName("stripMargin")) => + debug(c)("📝 Detected: String with .stripMargin") + Some(sql.stripMargin) + + // Case 3: Try to evaluate as compile-time constant + case _ => + debug(c)(s"⚠️ Attempting to evaluate: ${showCode(query.tree)}") + try { + val evaluated = c.eval(c.Expr[String](c.untypecheck(query.tree.duplicate))) + debug(c)(s"✅ Successfully evaluated to: $evaluated") + Some(evaluated) + } catch { + case e: Throwable => + debug(c)(s"❌ Could not evaluate: ${e.getMessage}") + None + } + }).getOrElse { c.abort( c.enclosingPosition, - "❌ SQL query must be a string literal for compile-time validation. " + - "Use the *Unchecked() variant for dynamic queries." + s"""❌ SQL query must be a compile-time constant for validation. + | + |Found: ${showCode(query.tree)} + |Tree structure: ${showRaw(query.tree)} + | + |✅ Valid usage: + | scrollAs[Product]("SELECT id, name FROM products") + | scrollAs[Product](\"\"\"SELECT id, name FROM products\"\"\".stripMargin) + | + |❌ For dynamic queries, use: + | scrollAsUnchecked[Product](SQLQuery(dynamicSql), ScrollConfig()) + | + |""".stripMargin ) + } + + debug(c)(s"📝 Extracted SQL: $sqlString") + + sqlString + } + + /** Validates the SQL query structure against the type T. + */ + protected def validateQueryStructure[T: c.WeakTypeTag](c: blackbox.Context)( + sql: String + ): Unit = { + import c.universe._ + + val tpe = weakTypeOf[T] + + debug(c)(s"🔍 Validating query for type: ${tpe.typeSymbol.name}") + + // Example validations (customize as needed) + + // 1. Check for SELECT * + if (sql.matches("(?i).*SELECT\\s+\\*.*")) { + c.abort( + c.enclosingPosition, + s"""❌ SELECT * is not allowed for type-safe queries. + | + |Please explicitly list all fields required for type ${tpe.typeSymbol.name}. + |""".stripMargin + ) } + + // 2. Additional validations... + // - Check field names against type T + // - Validate JOIN syntax + // - etc. 
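+    // Illustrative sketch only (not part of this patch): a field-name check against T
+    // could reuse the case-class accessors, e.g. with a hypothetical `selectedColumns` set:
+    //   val accessors = tpe.members.collect {
+    //     case m: MethodSymbol if m.isCaseAccessor => m.name.decodedName.toString
+    //   }.toSet
+    //   selectedColumns.filterNot(accessors.contains).foreach { unknown =>
+    //     c.abort(c.enclosingPosition, s"Unknown field '$unknown' for ${tpe.typeSymbol.name}")
+    //   }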
+ + debug(c)(s"✅ Query structure valid for ${tpe.typeSymbol.name}") } private def parseSQLQuery(c: blackbox.Context)(sqlQuery: String): SQLSearchRequest = { @@ -154,7 +231,7 @@ trait SQLQueryValidator { ) } - c.echo(c.enclosingPosition, "✅ No SELECT * detected") + debug(c)("✅ No SELECT * detected") } private def extractQueryFields(parsedQuery: SQLSearchRequest): Set[String] = { @@ -203,7 +280,7 @@ trait SQLQueryValidator { ) } - c.echo(c.enclosingPosition, "✅ All query fields exist in case class") + debug(c)("✅ All query fields exist in case class") } // ============================================================ @@ -219,11 +296,11 @@ trait SQLQueryValidator { val missingFields = caseClassFields.keySet -- queryFields if (missingFields.isEmpty) { - c.echo(c.enclosingPosition, "✅ No missing fields to validate") + debug(c)("✅ No missing fields to validate") return } - c.echo(c.enclosingPosition, s"⚠️ Missing fields: ${missingFields.mkString(", ")}") + debug(c)(s"⚠️ Missing fields: ${missingFields.mkString(", ")}") // Get constructor parameters with their positions val constructor = tpe.decl(termNames.CONSTRUCTOR).asMethod @@ -255,17 +332,17 @@ trait SQLQueryValidator { fieldInfo.get(fieldName) match { case Some((_, hasDefault, isOption)) => if (isOption) { - c.echo(c.enclosingPosition, s"✅ Field '$fieldName' is Option - OK") + debug(c)(s"✅ Field '$fieldName' is Option - OK") true } else if (hasDefault) { - c.echo(c.enclosingPosition, s"✅ Field '$fieldName' has default value - OK") + debug(c)(s"✅ Field '$fieldName' has default value - OK") true } else { - c.echo(c.enclosingPosition, s"❌ Field '$fieldName' has NO default and is NOT Option") + debug(c)(s"❌ Field '$fieldName' has NO default and is NOT Option") false } case None => - c.echo(c.enclosingPosition, s"⚠️ Field '$fieldName' not found in constructor") + debug(c)(s"⚠️ Field '$fieldName' not found in constructor") false } } @@ -284,7 +361,7 @@ trait SQLQueryValidator { ) } - c.echo(c.enclosingPosition, "✅ All missing fields have defaults or are Option") + debug(c)("✅ All missing fields have defaults or are Option") } // Helper: Get the index of a field in the case class constructor @@ -325,7 +402,7 @@ trait SQLQueryValidator { } } - c.echo(c.enclosingPosition, "✅ Type validation passed") + debug(c)("✅ Type validation passed") } private def areTypesCompatible(c: blackbox.Context)( @@ -459,4 +536,14 @@ trait SQLQueryValidator { dist(s2.length)(s1.length) } + + protected def debug(c: blackbox.Context)(msg: String): Unit = { + if (SQLQueryValidator.DEBUG) { + debug(c)(msg) + } + } +} + +object SQLQueryValidator { + val DEBUG: Boolean = sys.props.get("sql.macro.debug").contains("true") } diff --git a/persistence/src/main/scala/app/softnetwork/elastic/persistence/query/ElasticProvider.scala b/persistence/src/main/scala/app/softnetwork/elastic/persistence/query/ElasticProvider.scala index af1a7706..533cfb84 100644 --- a/persistence/src/main/scala/app/softnetwork/elastic/persistence/query/ElasticProvider.scala +++ b/persistence/src/main/scala/app/softnetwork/elastic/persistence/query/ElasticProvider.scala @@ -188,7 +188,7 @@ trait ElasticProvider[T <: Timestamped] override def searchDocuments( query: String )(implicit m: Manifest[T], formats: Formats): List[T] = { - searchAs[T](SQLQuery(query)) match { + searchAsUnchecked[T](SQLQuery(query)) match { case ElasticSuccess(results) => results.toList case ElasticFailure(elasticError) => logger.error(s"searchDocuments failed -> ${elasticError.message}") diff --git 
a/testkit/src/main/scala/app/softnetwork/elastic/client/ElasticClientSpec.scala b/testkit/src/main/scala/app/softnetwork/elastic/client/ElasticClientSpec.scala index 55c12563..f2985eb0 100644 --- a/testkit/src/main/scala/app/softnetwork/elastic/client/ElasticClientSpec.scala +++ b/testkit/src/main/scala/app/softnetwork/elastic/client/ElasticClientSpec.scala @@ -226,7 +226,7 @@ trait ElasticClientSpec extends AnyFlatSpecLike with ElasticDockerTestKit with M "person_mapping" should haveCount(3) - pClient.searchAs[Person]("select * from person_mapping") match { + pClient.searchAs[Person]("select uuid, name, birthDate, createdDate, lastUpdated from person_mapping") match { case ElasticSuccess(value) => value match { case r if r.size == 3 => @@ -237,7 +237,7 @@ trait ElasticClientSpec extends AnyFlatSpecLike with ElasticDockerTestKit with M fail(elasticError.fullMessage) } - pClient.searchAs[Person]("select * from person_mapping where uuid = 'A16'").get match { + pClient.searchAs[Person]("select uuid, name, birthDate, createdDate, lastUpdated from person_mapping where uuid = 'A16'").get match { case r if r.size == 1 => r.map(_.uuid) should contain only "A16" case other => fail(other.toString) @@ -245,7 +245,7 @@ trait ElasticClientSpec extends AnyFlatSpecLike with ElasticDockerTestKit with M pClient .searchAs[Person]( - "select * from person_mapping where match (name) against ('gum')" + "select uuid, name, birthDate, createdDate, lastUpdated from person_mapping where match (name) against ('gum')" ) .get match { case r if r.size == 1 => @@ -255,7 +255,7 @@ trait ElasticClientSpec extends AnyFlatSpecLike with ElasticDockerTestKit with M pClient .searchAs[Person]( - "select * from person_mapping where uuid <> 'A16' and match (name) against ('gum')" + "select uuid, name, birthDate, createdDate, lastUpdated from person_mapping where uuid <> 'A16' and match (name) against ('gum')" ) .get match { case r if r.isEmpty => @@ -305,7 +305,7 @@ trait ElasticClientSpec extends AnyFlatSpecLike with ElasticDockerTestKit with M pClient .searchAs[Person]( - "select * from person_migration where match (name) against ('gum')" + "select uuid, name, birthDate, createdDate, lastUpdated from person_migration where match (name) against ('gum')" ) .get match { case r if r.isEmpty => @@ -358,7 +358,7 @@ trait ElasticClientSpec extends AnyFlatSpecLike with ElasticDockerTestKit with M pClient .searchAs[Person]( - "select * from person_migration where match (name) against ('gum')" + "select uuid, name, birthDate, createdDate, lastUpdated from person_migration where match (name) against ('gum')" ) .get match { case r if r.size == 1 => @@ -382,7 +382,7 @@ trait ElasticClientSpec extends AnyFlatSpecLike with ElasticDockerTestKit with M "person1" should haveCount(3) - pClient.searchAs[Person]("select * from person1").get match { + pClient.searchAs[Person]("select uuid, name, birthDate, createdDate, lastUpdated from person1").get match { case r if r.size == 3 => r.map(_.uuid) should contain allOf ("A12", "A14", "A16") r.map(_.name) should contain allOf ("Homer Simpson", "Moe Szyslak", "Barney Gumble") @@ -406,7 +406,7 @@ trait ElasticClientSpec extends AnyFlatSpecLike with ElasticDockerTestKit with M "person2" should haveCount(3) - pClient.searchAs[Person]("select * from person2").get match { + pClient.searchAs[Person]("select uuid, name, birthDate, createdDate, lastUpdated from person2").get match { case r if r.size == 3 => r.map(_.uuid) should contain allOf ("A12", "A14", "A16") r.map(_.name) should contain allOf ("Homer 
Simpson", "Moe Szyslak", "Barney Gumble") @@ -447,7 +447,9 @@ trait ElasticClientSpec extends AnyFlatSpecLike with ElasticDockerTestKit with M "person-1967-11-21" should haveCount(2) "person-1969-05-09" should haveCount(1) - pClient.searchAs[Person]("select * from person-1967-11-21, person-1969-05-09").get match { + pClient + .searchAs[Person]("select uuid, name, birthDate, createdDate, lastUpdated from person-1967-11-21, person-1969-05-09") + .get match { case r if r.size == 3 => r.map(_.uuid) should contain allOf ("A12", "A14", "A16") r.map(_.name) should contain allOf ("Homer Simpson", "Moe Szyslak", "Barney Gumble") @@ -494,7 +496,7 @@ trait ElasticClientSpec extends AnyFlatSpecLike with ElasticDockerTestKit with M "person4" should haveCount(3) - pClient.searchAs[Person]("select * from person4").get match { + pClient.searchAs[Person]("select uuid, name, birthDate, createdDate, lastUpdated from person4").get match { case r if r.size == 3 => r.map(_.uuid) should contain allOf ("A12", "A14", "A16") r.map(_.name) should contain allOf ("Homer Simpson", "Moe Szyslak", "Barney Gumble2") @@ -536,7 +538,9 @@ trait ElasticClientSpec extends AnyFlatSpecLike with ElasticDockerTestKit with M "person5-1967-11-21" should haveCount(2) "person5-1969-05-09" should haveCount(1) - pClient.searchAs[Person]("select * from person5-1967-11-21, person5-1969-05-09").get match { + pClient + .searchAs[Person]("select uuid, name, birthDate, createdDate, lastUpdated from person5-1967-11-21, person5-1969-05-09") + .get match { case r if r.size == 3 => r.map(_.uuid) should contain allOf ("A12", "A14", "A16") r.map(_.name) should contain allOf ("Homer Simpson", "Moe Szyslak", "Barney Gumble2") @@ -611,11 +615,11 @@ trait ElasticClientSpec extends AnyFlatSpecLike with ElasticDockerTestKit with M "person7" should haveCount(3) - val r1 = pClient.searchAs[Person]("select * from person7").get + val r1 = pClient.searchAs[Person]("select uuid, name, birthDate, createdDate, lastUpdated from person7").get r1.size should ===(3) r1.map(_.uuid) should contain allOf ("A12", "A14", "A16") - pClient.searchAsyncAs[Person]("select * from person7") onComplete { + pClient.searchAsyncAs[Person]("select uuid, name, birthDate, createdDate, lastUpdated from person7") onComplete { case Success(s) => val r = s.get r.size should ===(3) @@ -623,11 +627,11 @@ trait ElasticClientSpec extends AnyFlatSpecLike with ElasticDockerTestKit with M case Failure(f) => fail(f.getMessage) } - val r2 = pClient.searchAs[Person]("select * from person7 where _id=\"A16\"").get + val r2 = pClient.searchAs[Person]("select uuid, name, birthDate, createdDate, lastUpdated from person7 where _id=\"A16\"").get r2.size should ===(1) r2.map(_.uuid) should contain("A16") - pClient.searchAsyncAs[Person]("select * from person7 where _id=\"A16\"") onComplete { + pClient.searchAsyncAs[Person]("select uuid, name, birthDate, createdDate, lastUpdated from person7 where _id=\"A16\"") onComplete { case Success(s) => val r = s.get r.size should ===(1) @@ -658,7 +662,7 @@ trait ElasticClientSpec extends AnyFlatSpecLike with ElasticDockerTestKit with M "person8" should haveCount(3) - val response = pClient.searchAs[Person]("select * from person8").get + val response = pClient.searchAs[Person]("select uuid, name, birthDate, createdDate, lastUpdated from person8").get response.size should ===(3) @@ -1169,7 +1173,19 @@ trait ElasticClientSpec extends AnyFlatSpecLike with ElasticDockerTestKit with M "parent" should haveCount(3) - val parents = parentClient.searchAs[Parent]("select * from 
parent") + val parents = parentClient.searchAs[Parent]( + """SELECT + | p.uuid, + | p.name, + | p.birthDate, + | children.name, + | children.birthDate, + | children.parentId + | FROM + | parent as p + | JOIN UNNEST(p.children) as children + |""".stripMargin + ) parents.get.size shouldBe 3 val results = parentClient @@ -1202,8 +1218,9 @@ trait ElasticClientSpec extends AnyFlatSpecLike with ElasticDockerTestKit with M ) should contain allOf ("1999-05-09", "2002-05-09") result._2.map(_.parentId) should contain only "A16" - val query = - """SELECT + val searchResults = parentClient + .searchAs[Parent]( + """SELECT | p.uuid, | p.name, | p.birthDate, @@ -1216,8 +1233,8 @@ trait ElasticClientSpec extends AnyFlatSpecLike with ElasticDockerTestKit with M |WHERE | children.name is not null AND p.uuid = 'A16' |""".stripMargin - - val searchResults = parentClient.searchAs[Parent](query).get + ) + .get searchResults.size shouldBe 1 val searchResult = searchResults.head searchResult.uuid shouldBe "A16" @@ -1229,7 +1246,22 @@ trait ElasticClientSpec extends AnyFlatSpecLike with ElasticDockerTestKit with M searchResult.children.map(_.parentId) should contain only "A16" val scrollResults: Future[Seq[(Parent, ScrollMetrics)]] = parentClient - .scrollAs[Parent](query, config = ScrollConfig(logEvery = 1)) + .scrollAs[Parent]( + """SELECT + | p.uuid, + | p.name, + | p.birthDate, + | children.name, + | children.birthDate, + | children.parentId + | FROM + | parent as p + | JOIN UNNEST(p.children) as children + |WHERE + | children.name is not null AND p.uuid = 'A16' + |""".stripMargin, + ScrollConfig(logEvery = 1) + ) .runWith(Sink.seq) scrollResults await { rows => val parents = rows.map(_._1) From 7de76bdcc453c4bce55724f00287699df8d743f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Manciot?= Date: Mon, 10 Nov 2025 10:22:50 +0100 Subject: [PATCH 04/11] update query validator to support collection --- .../sql/macros/SQLQueryValidator.scala | 309 +++++++++--------- .../elastic/client/ElasticClientSpec.scala | 91 ++++-- 2 files changed, 221 insertions(+), 179 deletions(-) diff --git a/macros/src/main/scala/app/softnetwork/elastic/sql/macros/SQLQueryValidator.scala b/macros/src/main/scala/app/softnetwork/elastic/sql/macros/SQLQueryValidator.scala index e2384cb4..8b3361d9 100644 --- a/macros/src/main/scala/app/softnetwork/elastic/sql/macros/SQLQueryValidator.scala +++ b/macros/src/main/scala/app/softnetwork/elastic/sql/macros/SQLQueryValidator.scala @@ -50,25 +50,25 @@ trait SQLQueryValidator { val parsedQuery = parseSQLQuery(c)(sqlQuery) // ============================================================ - // ✅ NEW: Reject SELECT * + // Reject SELECT * // ============================================================ rejectSelectStar(c)(parsedQuery, sqlQuery) - // 3. Extract the selected fields + // 3. Extract the selected fields from the query val queryFields = extractQueryFields(parsedQuery) debug(c)(s"🔍 Parsed fields: ${queryFields.mkString(", ")}") - // 4. Extract the fields from case class T + // 4. Extract the required fields from case class T val tpe = c.weakTypeOf[T] - val caseClassFields = extractCaseClassFields(c)(tpe) - debug(c)(s"📦 Case class fields: ${caseClassFields.mkString(", ")}") + val requiredFields = getRequiredFields(c)(tpe) + debug(c)(s"📦 Case class fields: ${requiredFields.mkString(", ")}") // 5. Validate: missing case class fields must have defaults or be Option - validateMissingFieldsHaveDefaults(c)(queryFields, caseClassFields, tpe) + validateRequiredFields(c)(queryFields) // 7. 
Validate the types - validateTypes(c)(parsedQuery, caseClassFields) + validateTypes(c)(parsedQuery, requiredFields.map(values => values._1 -> values._2._1)) debug(c)("=" * 80) debug(c)("✅ SQL Query Validation Complete") @@ -85,7 +85,7 @@ trait SQLQueryValidator { /** Extracts SQL string from various tree structures. Supports: literals, .stripMargin, and simple * expressions. */ - protected def extractSQLString(c: blackbox.Context)(query: c.Expr[String]): String = { + private def extractSQLString(c: blackbox.Context)(query: c.Expr[String]): String = { import c.universe._ debug(c)("=" * 80) @@ -142,38 +142,9 @@ trait SQLQueryValidator { sqlString } - /** Validates the SQL query structure against the type T. - */ - protected def validateQueryStructure[T: c.WeakTypeTag](c: blackbox.Context)( - sql: String - ): Unit = { - import c.universe._ - - val tpe = weakTypeOf[T] - - debug(c)(s"🔍 Validating query for type: ${tpe.typeSymbol.name}") - - // Example validations (customize as needed) - - // 1. Check for SELECT * - if (sql.matches("(?i).*SELECT\\s+\\*.*")) { - c.abort( - c.enclosingPosition, - s"""❌ SELECT * is not allowed for type-safe queries. - | - |Please explicitly list all fields required for type ${tpe.typeSymbol.name}. - |""".stripMargin - ) - } - - // 2. Additional validations... - // - Check field names against type T - // - Validate JOIN syntax - // - etc. - - debug(c)(s"✅ Query structure valid for ${tpe.typeSymbol.name}") - } - + // ============================================================ + // Helper: Parse SQL query into SQLSearchRequest + // ============================================================ private def parseSQLQuery(c: blackbox.Context)(sqlQuery: String): SQLSearchRequest = { Parser(sqlQuery) match { case Right(Left(request)) => @@ -194,7 +165,7 @@ trait SQLQueryValidator { } // ============================================================ - // ✅ Reject SELECT * (incompatible with compile-time validation) + // Reject SELECT * (incompatible with compile-time validation) // ============================================================ private def rejectSelectStar(c: blackbox.Context)( parsedQuery: SQLSearchRequest, @@ -234,161 +205,184 @@ trait SQLQueryValidator { debug(c)("✅ No SELECT * detected") } - private def extractQueryFields(parsedQuery: SQLSearchRequest): Set[String] = { - parsedQuery.select.fields.map { field => - field.fieldAlias.map(_.alias).getOrElse(field.identifier.name) - }.toSet - } - - private def extractCaseClassFields(c: blackbox.Context)( - tpe: c.universe.Type - ): Map[String, c.universe.Type] = { + // ============================================================ + // Helper: Detect if a type is a collection + // ============================================================ + private def isCollectionType(c: blackbox.Context)(tpe: c.universe.Type): Boolean = { import c.universe._ - tpe.members.collect { - case m: MethodSymbol if m.isCaseAccessor => - m.name.toString -> m.returnType - }.toMap + val collectionTypes = Set( + typeOf[List[_]].typeConstructor, + typeOf[Seq[_]].typeConstructor, + typeOf[Vector[_]].typeConstructor, + typeOf[Set[_]].typeConstructor, + typeOf[Array[_]].typeConstructor + ) + + collectionTypes.exists(collType => tpe.typeConstructor <:< collType) } // ============================================================ - // ✅ VALIDATION 1: Query fields must exist in case class + // Helper: Extract the element type from a collection // ============================================================ - @deprecated - private def 
validateQueryFieldsExist(c: blackbox.Context)( - queryFields: Set[String], - caseClassFields: Map[String, c.universe.Type], - tpe: c.universe.Type - ): Unit = { - val unknownFields = queryFields.filterNot(f => caseClassFields.contains(f)) - - if (unknownFields.nonEmpty) { - val availableFields = caseClassFields.keys.toSeq.sorted.mkString(", ") - val suggestions = unknownFields.flatMap { unknown => - findClosestMatch(unknown, caseClassFields.keys.toSeq) - } - - val suggestionMsg = if (suggestions.nonEmpty) { - s"\nDid you mean: ${suggestions.mkString(", ")}?" - } else "" - - c.abort( - c.enclosingPosition, - s"❌ SQL query selects fields not present in ${tpe.typeSymbol.name}: " + - s"${unknownFields.mkString(", ")}\n" + - s"Available fields: $availableFields$suggestionMsg" - ) + private def getCollectionElementType( + c: blackbox.Context + )(tpe: c.universe.Type): Option[c.universe.Type] = { + if (isCollectionType(c)(tpe)) { + tpe.typeArgs.headOption + } else { + None } - - debug(c)("✅ All query fields exist in case class") } // ============================================================ - // ✅ VALIDATION 2: Missing fields must have defaults or be Option + // Helper: Extract the required fields from a class case // ============================================================ - private def validateMissingFieldsHaveDefaults(c: blackbox.Context)( - queryFields: Set[String], - caseClassFields: Map[String, c.universe.Type], - tpe: c.universe.Type - ): Unit = { + private def getRequiredFields( + c: blackbox.Context + )(tpe: c.universe.Type): Map[String, (c.universe.Type, Boolean, Boolean)] = { import c.universe._ - val missingFields = caseClassFields.keySet -- queryFields - - if (missingFields.isEmpty) { - debug(c)("✅ No missing fields to validate") - return - } + val constructor = tpe.decls + .collectFirst { + case m: MethodSymbol if m.isPrimaryConstructor => m + } + .getOrElse { + c.abort(c.enclosingPosition, s"No primary constructor found for $tpe") + } - debug(c)(s"⚠️ Missing fields: ${missingFields.mkString(", ")}") + constructor.paramLists.flatten.flatMap { param => + val paramName = param.name.decodedName.toString + val paramType = param.typeSignature - // Get constructor parameters with their positions - val constructor = tpe.decl(termNames.CONSTRUCTOR).asMethod - val params = constructor.paramLists.flatten + // Check if the parameter has a default value or is an option. + val isOption = paramType <:< typeOf[Option[_]] - // Build map: fieldName -> (index, hasDefault, isOption) - val fieldInfo = params.zipWithIndex.map { case (param, idx) => - val fieldName = param.name.toString - val fieldType = param.typeSignature + val hasDefault = param.asTerm.isParamWithDefault - // Check if Option - val isOption = fieldType.typeConstructor =:= typeOf[Option[_]].typeConstructor + /* We should not filter out optional parameters here, + because we need to know all fields to validate their types later. - // Check if has default value - val companionSymbol = tpe.typeSymbol.companion - val hasDefault = if (companionSymbol != NoSymbol) { - val companionType = companionSymbol.typeSignature - val defaultMethodName = s"apply$$default$$${idx + 1}" - companionType.member(TermName(defaultMethodName)) != NoSymbol + if (isOption || hasDefault) { + None } else { - false - } + Some((paramName, (paramType, isOption, hasDefault))) + }*/ + + Some((paramName, (paramType, isOption, hasDefault))) - (fieldName, (idx, hasDefault, isOption)) }.toMap + } + + /** Extracts selected fields from the parsed SQL query. 
+ */ + private def extractQueryFields(parsedQuery: SQLSearchRequest): Set[String] = { + parsedQuery.select.fields.flatMap { field => + val f = field.fieldAlias.map(_.alias).getOrElse(field.identifier.name) + /*field.identifier.nestedElement match { + case Some(nested) => List(f, nested.innerHitsName) + case None => List(f) + }*/ + List(f) + }.toSet + } + + // ============================================================ + // Helper: Validate required vs. selected fields + // ============================================================ + private def validateRequiredFields[T: c.WeakTypeTag]( + c: blackbox.Context + )( + queryFields: Set[String] + ): Unit = { + import c.universe._ + + val tpe = weakTypeOf[T] + val requiredFields = getRequiredFields(c)(tpe) + + val missingFields = requiredFields.filterNot { + case (fieldName, (fieldType, isOption, hasDefault)) => + // ✅ Check if the field is selected + val isSelected = queryFields.contains(fieldName) + + if (!isSelected) { + debug(c)(s"⚠️ Missing field: $fieldName") - // Check each missing field - val fieldsWithoutDefaults = missingFields.filterNot { fieldName => - fieldInfo.get(fieldName) match { - case Some((_, hasDefault, isOption)) => if (isOption) { debug(c)(s"✅ Field '$fieldName' is Option - OK") true } else if (hasDefault) { debug(c)(s"✅ Field '$fieldName' has default value - OK") true + } + // ✅ If it's a collection, check if its nested fields are selected. + else if (isCollectionType(c)(fieldType)) { + getCollectionElementType(c)(fieldType) match { + case Some(elementType) => + // Check if the nested fields of the collection are selected + // Eg: "children.name", "children.birthDate" + val nestedFields = getRequiredFields(c)(elementType) + val hasNestedFields = nestedFields.forall { case (nestedFieldName, _) => + queryFields.exists(f => f.startsWith(s"$fieldName.$nestedFieldName")) + } + + if (hasNestedFields) { + // ✅ The nested fields are present, so the collection is considered valid. + debug(c)(s"✅ Collection field '$fieldName' validated via nested fields") + } + + hasNestedFields + + case None => false + } } else { - debug(c)(s"❌ Field '$fieldName' has NO default and is NOT Option") false } - case None => - debug(c)(s"⚠️ Field '$fieldName' not found in constructor") - false - } + } else { + true + } } - if (fieldsWithoutDefaults.nonEmpty) { + if (missingFields.nonEmpty) { + val missingFieldNames = missingFields.keys.mkString(", ") + val exampleFields = (queryFields ++ missingFields.keys).mkString(", ") + + val unknownFields = queryFields.filterNot(f => requiredFields.contains(f)) + val suggestions = unknownFields.flatMap { unknown => + findClosestMatch(unknown, missingFields.keys.toSeq) + } + val suggestionMsg = if (suggestions.nonEmpty) { + s"\nDid you mean: ${suggestions.mkString(", ")}?" + } else "" + c.abort( c.enclosingPosition, - s"❌ SQL query does not select the following required fields from ${tpe.typeSymbol.name}:\n" + - s" ${fieldsWithoutDefaults.mkString(", ")}\n\n" + - s"These fields are missing from the query:\n" + - s" SELECT ${queryFields.mkString(", ")} FROM ...\n\n" + - s"To fix this, either:\n" + - s" 1. Add them to the SELECT clause\n" + - s" 2. Make them Option[T] in the case class\n" + - s" 3. Provide default values in the case class definition" + s"""❌ SQL query does not select the following required fields from ${tpe.typeSymbol.name}: + |$missingFieldNames$suggestionMsg + | + |These fields are missing from the query: + |SELECT ${exampleFields} FROM ... + | + |To fix this, either: + | 1. 
Add them to the SELECT clause + | 2. Make them Option[T] in the case class + | 3. Provide default values in the case class definition""".stripMargin ) } - - debug(c)("✅ All missing fields have defaults or are Option") - } - - // Helper: Get the index of a field in the case class constructor - private def getFieldIndex(c: blackbox.Context)( - tpe: c.universe.Type, - fieldName: String - ): Int = { - import c.universe._ - - val constructor = tpe.decl(termNames.CONSTRUCTOR).asMethod - val params = constructor.paramLists.flatten - - params.indexWhere(_.name.toString == fieldName) } // ============================================================ - // VALIDATION 3: Type compatibility + // Helper: Validate Type compatibility // ============================================================ private def validateTypes(c: blackbox.Context)( parsedQuery: SQLSearchRequest, - caseClassFields: Map[String, c.universe.Type] + requiredFields: Map[String, c.universe.Type] ): Unit = { parsedQuery.select.fields.foreach { field => val fieldName = field.fieldAlias.map(_.alias).getOrElse(field.identifier.name) - (field.out, caseClassFields.get(fieldName)) match { + (field.out, requiredFields.get(fieldName)) match { case (sqlType, Some(scalaType)) => if (!areTypesCompatible(c)(sqlType, scalaType)) { c.abort( @@ -405,6 +399,9 @@ trait SQLQueryValidator { debug(c)("✅ Type validation passed") } + // ============================================================ + // Helper: Check if SQL type is compatible with Scala type + // ============================================================ private def areTypesCompatible(c: blackbox.Context)( sqlType: SQLType, scalaType: c.universe.Type @@ -489,6 +486,9 @@ trait SQLQueryValidator { } } + // ============================================================ + // Helper: Get compatible Scala types for a given SQL type + // ============================================================ private def getCompatibleScalaTypes(sqlType: SQLType): String = { sqlType match { case SQLTypes.TinyInt => @@ -509,6 +509,9 @@ trait SQLQueryValidator { } } + // ============================================================ + // Helper: Find closest matching field name + // ============================================================ private def findClosestMatch(target: String, candidates: Seq[String]): Option[String] = { if (candidates.isEmpty) None else { @@ -520,6 +523,9 @@ trait SQLQueryValidator { } } + // ============================================================ + // Helper: Compute Levenshtein distance between two strings + // ============================================================ private def levenshteinDistance(s1: String, s2: String): Int = { val dist = Array.tabulate(s2.length + 1, s1.length + 1) { (j, i) => if (j == 0) i else if (i == 0) j else 0 @@ -537,9 +543,12 @@ trait SQLQueryValidator { dist(s2.length)(s1.length) } - protected def debug(c: blackbox.Context)(msg: String): Unit = { + // ============================================================ + // Helper: Debug logging + // ============================================================ + private def debug(c: blackbox.Context)(msg: String): Unit = { if (SQLQueryValidator.DEBUG) { - debug(c)(msg) + c.info(c.enclosingPosition, msg, force = true) } } } diff --git a/testkit/src/main/scala/app/softnetwork/elastic/client/ElasticClientSpec.scala b/testkit/src/main/scala/app/softnetwork/elastic/client/ElasticClientSpec.scala index f2985eb0..898f78c1 100644 --- a/testkit/src/main/scala/app/softnetwork/elastic/client/ElasticClientSpec.scala 
+++ b/testkit/src/main/scala/app/softnetwork/elastic/client/ElasticClientSpec.scala @@ -226,7 +226,9 @@ trait ElasticClientSpec extends AnyFlatSpecLike with ElasticDockerTestKit with M "person_mapping" should haveCount(3) - pClient.searchAs[Person]("select uuid, name, birthDate, createdDate, lastUpdated from person_mapping") match { + pClient.searchAs[Person]( + "select uuid, name, birthDate, createdDate, lastUpdated from person_mapping" + ) match { case ElasticSuccess(value) => value match { case r if r.size == 3 => @@ -237,7 +239,11 @@ trait ElasticClientSpec extends AnyFlatSpecLike with ElasticDockerTestKit with M fail(elasticError.fullMessage) } - pClient.searchAs[Person]("select uuid, name, birthDate, createdDate, lastUpdated from person_mapping where uuid = 'A16'").get match { + pClient + .searchAs[Person]( + "select uuid, name, birthDate, createdDate, lastUpdated from person_mapping where uuid = 'A16'" + ) + .get match { case r if r.size == 1 => r.map(_.uuid) should contain only "A16" case other => fail(other.toString) @@ -382,7 +388,9 @@ trait ElasticClientSpec extends AnyFlatSpecLike with ElasticDockerTestKit with M "person1" should haveCount(3) - pClient.searchAs[Person]("select uuid, name, birthDate, createdDate, lastUpdated from person1").get match { + pClient + .searchAs[Person]("select uuid, name, birthDate, createdDate, lastUpdated from person1") + .get match { case r if r.size == 3 => r.map(_.uuid) should contain allOf ("A12", "A14", "A16") r.map(_.name) should contain allOf ("Homer Simpson", "Moe Szyslak", "Barney Gumble") @@ -406,7 +414,9 @@ trait ElasticClientSpec extends AnyFlatSpecLike with ElasticDockerTestKit with M "person2" should haveCount(3) - pClient.searchAs[Person]("select uuid, name, birthDate, createdDate, lastUpdated from person2").get match { + pClient + .searchAs[Person]("select uuid, name, birthDate, createdDate, lastUpdated from person2") + .get match { case r if r.size == 3 => r.map(_.uuid) should contain allOf ("A12", "A14", "A16") r.map(_.name) should contain allOf ("Homer Simpson", "Moe Szyslak", "Barney Gumble") @@ -448,7 +458,9 @@ trait ElasticClientSpec extends AnyFlatSpecLike with ElasticDockerTestKit with M "person-1969-05-09" should haveCount(1) pClient - .searchAs[Person]("select uuid, name, birthDate, createdDate, lastUpdated from person-1967-11-21, person-1969-05-09") + .searchAs[Person]( + "select uuid, name, birthDate, createdDate, lastUpdated from person-1967-11-21, person-1969-05-09" + ) .get match { case r if r.size == 3 => r.map(_.uuid) should contain allOf ("A12", "A14", "A16") @@ -496,7 +508,9 @@ trait ElasticClientSpec extends AnyFlatSpecLike with ElasticDockerTestKit with M "person4" should haveCount(3) - pClient.searchAs[Person]("select uuid, name, birthDate, createdDate, lastUpdated from person4").get match { + pClient + .searchAs[Person]("select uuid, name, birthDate, createdDate, lastUpdated from person4") + .get match { case r if r.size == 3 => r.map(_.uuid) should contain allOf ("A12", "A14", "A16") r.map(_.name) should contain allOf ("Homer Simpson", "Moe Szyslak", "Barney Gumble2") @@ -539,7 +553,9 @@ trait ElasticClientSpec extends AnyFlatSpecLike with ElasticDockerTestKit with M "person5-1969-05-09" should haveCount(1) pClient - .searchAs[Person]("select uuid, name, birthDate, createdDate, lastUpdated from person5-1967-11-21, person5-1969-05-09") + .searchAs[Person]( + "select uuid, name, birthDate, createdDate, lastUpdated from person5-1967-11-21, person5-1969-05-09" + ) .get match { case r if r.size == 3 => 
r.map(_.uuid) should contain allOf ("A12", "A14", "A16") @@ -615,11 +631,15 @@ trait ElasticClientSpec extends AnyFlatSpecLike with ElasticDockerTestKit with M "person7" should haveCount(3) - val r1 = pClient.searchAs[Person]("select uuid, name, birthDate, createdDate, lastUpdated from person7").get + val r1 = pClient + .searchAs[Person]("select uuid, name, birthDate, createdDate, lastUpdated from person7") + .get r1.size should ===(3) r1.map(_.uuid) should contain allOf ("A12", "A14", "A16") - pClient.searchAsyncAs[Person]("select uuid, name, birthDate, createdDate, lastUpdated from person7") onComplete { + pClient.searchAsyncAs[Person]( + "select uuid, name, birthDate, createdDate, lastUpdated from person7" + ) onComplete { case Success(s) => val r = s.get r.size should ===(3) @@ -627,11 +647,17 @@ trait ElasticClientSpec extends AnyFlatSpecLike with ElasticDockerTestKit with M case Failure(f) => fail(f.getMessage) } - val r2 = pClient.searchAs[Person]("select uuid, name, birthDate, createdDate, lastUpdated from person7 where _id=\"A16\"").get + val r2 = pClient + .searchAs[Person]( + "select uuid, name, birthDate, createdDate, lastUpdated from person7 where _id=\"A16\"" + ) + .get r2.size should ===(1) r2.map(_.uuid) should contain("A16") - pClient.searchAsyncAs[Person]("select uuid, name, birthDate, createdDate, lastUpdated from person7 where _id=\"A16\"") onComplete { + pClient.searchAsyncAs[Person]( + "select uuid, name, birthDate, createdDate, lastUpdated from person7 where _id=\"A16\"" + ) onComplete { case Success(s) => val r = s.get r.size should ===(1) @@ -662,7 +688,9 @@ trait ElasticClientSpec extends AnyFlatSpecLike with ElasticDockerTestKit with M "person8" should haveCount(3) - val response = pClient.searchAs[Person]("select uuid, name, birthDate, createdDate, lastUpdated from person8").get + val response = pClient + .searchAs[Person]("select uuid, name, birthDate, createdDate, lastUpdated from person8") + .get response.size should ===(3) @@ -1178,15 +1206,20 @@ trait ElasticClientSpec extends AnyFlatSpecLike with ElasticDockerTestKit with M | p.uuid, | p.name, | p.birthDate, - | children.name, - | children.birthDate, - | children.parentId + | p.children | FROM | parent as p - | JOIN UNNEST(p.children) as children |""".stripMargin ) - parents.get.size shouldBe 3 + parents match { + case ElasticSuccess(ps) => ps.size shouldBe 3 + case ElasticFailure(error) => + error.cause match { + case Some(cause) => log.error("Error during search", cause) + case None => + } + fail(s"Error during search: ${error.message}") + } val results = parentClient .searchWithInnerHits[Parent, Child]( @@ -1221,18 +1254,18 @@ trait ElasticClientSpec extends AnyFlatSpecLike with ElasticDockerTestKit with M val searchResults = parentClient .searchAs[Parent]( """SELECT - | p.uuid, - | p.name, - | p.birthDate, - | children.name, - | children.birthDate, - | children.parentId - | FROM - | parent as p - | JOIN UNNEST(p.children) as children - |WHERE - | children.name is not null AND p.uuid = 'A16' - |""".stripMargin + | p.uuid, + | p.name, + | p.birthDate, + | children.name, + | children.birthDate, + | children.parentId + | FROM + | parent as p + | JOIN UNNEST(p.children) as children + |WHERE + | children.name is not null AND p.uuid = 'A16' + |""".stripMargin ) .get searchResults.size shouldBe 1 From 1217cc8f2fcfd5ef6966fa2f8b7d2b91457fcd44 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Manciot?= Date: Mon, 10 Nov 2025 10:52:26 +0100 Subject: [PATCH 05/11] update query validator to cache validated 
queries + update type compatibility check --- .../sql/macros/SQLQueryValidator.scala | 120 ++++++++++-------- 1 file changed, 66 insertions(+), 54 deletions(-) diff --git a/macros/src/main/scala/app/softnetwork/elastic/sql/macros/SQLQueryValidator.scala b/macros/src/main/scala/app/softnetwork/elastic/sql/macros/SQLQueryValidator.scala index 8b3361d9..beac3f49 100644 --- a/macros/src/main/scala/app/softnetwork/elastic/sql/macros/SQLQueryValidator.scala +++ b/macros/src/main/scala/app/softnetwork/elastic/sql/macros/SQLQueryValidator.scala @@ -42,6 +42,12 @@ trait SQLQueryValidator { // 1. Extract the SQL query (must be a literal) val sqlQuery = extractSQLString(c)(query) + // ✅ Check if already validated + if (SQLQueryValidator.isCached(sqlQuery)) { + debug(c)(s"✅ Query already validated (cached): $sqlQuery") + return sqlQuery + } + if (sys.props.get("elastic.sql.debug").contains("true")) { c.info(c.enclosingPosition, s"Validating SQL: $sqlQuery", force = false) } @@ -75,6 +81,9 @@ trait SQLQueryValidator { debug(c)("=" * 80) // 8. Return the validated request + // ✅ Mark as validated + SQLQueryValidator.markValidated(sqlQuery) + sqlQuery } @@ -167,10 +176,11 @@ trait SQLQueryValidator { // ============================================================ // Reject SELECT * (incompatible with compile-time validation) // ============================================================ - private def rejectSelectStar(c: blackbox.Context)( + private def rejectSelectStar[T: c.WeakTypeTag](c: blackbox.Context)( parsedQuery: SQLSearchRequest, sqlQuery: String ): Unit = { + import c.universe._ // Check if any field is a wildcard (*) val hasWildcard = parsedQuery.select.fields.exists { field => @@ -178,6 +188,10 @@ trait SQLQueryValidator { } if (hasWildcard) { + val tpe = weakTypeOf[T] + val requiredFields = getRequiredFields(c)(tpe) + val fieldNames = requiredFields.keys.mkString(", ") + c.abort( c.enclosingPosition, s"""❌ SELECT * is not allowed with compile-time validation. @@ -190,11 +204,11 @@ trait SQLQueryValidator { | • Schema changes will break silently at runtime | |Solution: - | 1. Explicitly list all required fields: - | SELECT id, name, price FROM products + | 1. Explicitly list all required fields for ${tpe.typeSymbol.name}: + | SELECT $fieldNames FROM ... | | 2. Use the *Unchecked() variant for dynamic queries: - | searchAsUnchecked[Product](SQLQuery("SELECT * FROM products")) + | searchAsUnchecked[${tpe.typeSymbol.name}](SQLQuery("SELECT * FROM ...")) | |Best Practice: | Always explicitly select only the fields you need. 
@@ -278,12 +292,16 @@ trait SQLQueryValidator { */ private def extractQueryFields(parsedQuery: SQLSearchRequest): Set[String] = { parsedQuery.select.fields.flatMap { field => - val f = field.fieldAlias.map(_.alias).getOrElse(field.identifier.name) - /*field.identifier.nestedElement match { - case Some(nested) => List(f, nested.innerHitsName) - case None => List(f) - }*/ - List(f) + val fieldName = field.fieldAlias.map(_.alias).getOrElse(field.identifier.name) + + // ✅ Manage nested fields (ex: "children.name" → "children", "children.name") + val nestedParts = fieldName.split("\\.").toList + + // Return all levels of nested fields + // Ex: "children.address.city" → ["children", "children.address", "children.address.city"] + nestedParts.indices.map { i => + nestedParts.take(i + 1).mkString(".") + } }.toSet } @@ -408,74 +426,63 @@ trait SQLQueryValidator { ): Boolean = { import c.universe._ + val underlyingType = if (scalaType <:< typeOf[Option[_]]) { + scalaType.typeArgs.headOption.getOrElse(scalaType) + } else { + scalaType + } + sqlType match { case SQLTypes.TinyInt => - scalaType =:= typeOf[Byte] || - scalaType =:= typeOf[Short] || - scalaType =:= typeOf[Int] || - scalaType =:= typeOf[Long] || - scalaType =:= typeOf[Option[Byte]] || - scalaType =:= typeOf[Option[Short]] || - scalaType =:= typeOf[Option[Int]] || - scalaType =:= typeOf[Option[Long]] + underlyingType =:= typeOf[Byte] || + underlyingType =:= typeOf[Short] || + underlyingType =:= typeOf[Int] || + underlyingType =:= typeOf[Long] case SQLTypes.SmallInt => - scalaType =:= typeOf[Short] || - scalaType =:= typeOf[Int] || - scalaType =:= typeOf[Long] || - scalaType =:= typeOf[Option[Short]] || - scalaType =:= typeOf[Option[Int]] || - scalaType =:= typeOf[Option[Long]] + underlyingType =:= typeOf[Short] || + underlyingType =:= typeOf[Int] || + underlyingType =:= typeOf[Long] case SQLTypes.Int => - scalaType =:= typeOf[Int] || - scalaType =:= typeOf[Long] || - scalaType =:= typeOf[Option[Int]] || - scalaType =:= typeOf[Option[Long]] + underlyingType =:= typeOf[Int] || + underlyingType =:= typeOf[Long] case SQLTypes.BigInt => - scalaType =:= typeOf[Long] || - scalaType =:= typeOf[BigInt] || - scalaType =:= typeOf[Option[Long]] || - scalaType =:= typeOf[Option[BigInt]] + underlyingType =:= typeOf[Long] || + underlyingType =:= typeOf[BigInt] case SQLTypes.Double | SQLTypes.Real => - scalaType =:= typeOf[Double] || - scalaType =:= typeOf[Float] || - scalaType =:= typeOf[Option[Double]] || - scalaType =:= typeOf[Option[Float]] + underlyingType =:= typeOf[Double] || + underlyingType =:= typeOf[Float] case SQLTypes.Char => - scalaType =:= typeOf[String] || // CHAR(n) → String - scalaType =:= typeOf[Char] || // CHAR(1) → Char - scalaType =:= typeOf[Option[String]] || - scalaType =:= typeOf[Option[Char]] + underlyingType =:= typeOf[String] || // CHAR(n) → String + underlyingType =:= typeOf[Char] // CHAR(1) → Char case SQLTypes.Varchar => - scalaType =:= typeOf[String] || - scalaType =:= typeOf[Option[String]] + underlyingType =:= typeOf[String] case SQLTypes.Boolean => - scalaType =:= typeOf[Boolean] || - scalaType =:= typeOf[Option[Boolean]] + underlyingType =:= typeOf[Boolean] case SQLTypes.Time => - scalaType.toString.contains("Instant") || - scalaType.toString.contains("LocalTime") + underlyingType.toString.contains("Instant") || + underlyingType.toString.contains("LocalTime") case SQLTypes.Date => - scalaType.toString.contains("Date") || - scalaType.toString.contains("Instant") || - scalaType.toString.contains("LocalDate") + 
underlyingType.toString.contains("Date") || + underlyingType.toString.contains("Instant") || + underlyingType.toString.contains("LocalDate") case SQLTypes.DateTime | SQLTypes.Timestamp => - scalaType.toString.contains("LocalDateTime") || - scalaType.toString.contains("ZonedDateTime") || - scalaType.toString.contains("Instant") + underlyingType.toString.contains("LocalDateTime") || + underlyingType.toString.contains("ZonedDateTime") || + underlyingType.toString.contains("Instant") case SQLTypes.Struct => - if (scalaType.typeSymbol.isClass && scalaType.typeSymbol.asClass.isCaseClass) { - // validateStructFields(c)(sqlField, scalaType) + if (underlyingType.typeSymbol.isClass && underlyingType.typeSymbol.asClass.isCaseClass) { + // TODO validateStructFields(c)(sqlField, underlyingType) true } else { false @@ -555,4 +562,9 @@ trait SQLQueryValidator { object SQLQueryValidator { val DEBUG: Boolean = sys.props.get("sql.macro.debug").contains("true") + // ✅ Cache pour éviter les validations redondantes + private val validationCache = scala.collection.mutable.Map[String, Boolean]() + + private def isCached(sql: String): Boolean = validationCache.contains(sql.trim) + private def markValidated(sql: String): Unit = validationCache(sql.trim) = true } From 92e419c15fe3af409858873210a8b797172d875a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Manciot?= Date: Mon, 10 Nov 2025 13:11:46 +0100 Subject: [PATCH 06/11] add support for nested objects within query validator --- .../elastic/client/SearchApi.scala | 8 +- .../sql/macros/SQLQueryValidator.scala | 466 +++++++++++++++--- .../elastic/client/ElasticClientSpec.scala | 11 +- .../softnetwork/elastic/model/Parent.scala | 2 +- 4 files changed, 402 insertions(+), 85 deletions(-) diff --git a/core/src/main/scala/app/softnetwork/elastic/client/SearchApi.scala b/core/src/main/scala/app/softnetwork/elastic/client/SearchApi.scala index 4822375e..e2d460af 100644 --- a/core/src/main/scala/app/softnetwork/elastic/client/SearchApi.scala +++ b/core/src/main/scala/app/softnetwork/elastic/client/SearchApi.scala @@ -969,14 +969,18 @@ trait SearchApi extends ElasticConversion with ElasticClientHelpers { val results = ElasticResult.fromTry(convertTo[U](response)) results .fold( - onFailure = error => + onFailure = error => { + logger.error( + s"❌ Conversion to entities failed: ${error.message} with query \n${response.query}\n and results:\n ${response.results}" + ) ElasticResult.failure( ElasticError( message = s"Failed to convert search results to ${m.runtimeClass.getSimpleName}", cause = error.cause, operation = Some("convertToEntities") ) - ), + ) + }, onSuccess = entities => ElasticResult.success(entities) ) } diff --git a/macros/src/main/scala/app/softnetwork/elastic/sql/macros/SQLQueryValidator.scala b/macros/src/main/scala/app/softnetwork/elastic/sql/macros/SQLQueryValidator.scala index beac3f49..95b9e1b8 100644 --- a/macros/src/main/scala/app/softnetwork/elastic/sql/macros/SQLQueryValidator.scala +++ b/macros/src/main/scala/app/softnetwork/elastic/sql/macros/SQLQueryValidator.scala @@ -39,11 +39,13 @@ trait SQLQueryValidator { debug(c)("🚀 MACRO IS BEING CALLED!") - // 1. 
Extract the SQL query (must be a literal) + // ✅ Extract the SQL query (must be a literal) val sqlQuery = extractSQLString(c)(query) + val tpe = c.weakTypeOf[T] + // ✅ Check if already validated - if (SQLQueryValidator.isCached(sqlQuery)) { + if (SQLQueryValidator.isCached(c)(tpe, sqlQuery)) { debug(c)(s"✅ Query already validated (cached): $sqlQuery") return sqlQuery } @@ -52,38 +54,43 @@ trait SQLQueryValidator { c.info(c.enclosingPosition, s"Validating SQL: $sqlQuery", force = false) } - // 2. Parse the SQL query + // ✅ Parse the SQL query val parsedQuery = parseSQLQuery(c)(sqlQuery) - // ============================================================ - // Reject SELECT * - // ============================================================ + // ✅ Reject SELECT * rejectSelectStar(c)(parsedQuery, sqlQuery) - // 3. Extract the selected fields from the query + // ✅ Extract the selected fields from the query val queryFields = extractQueryFields(parsedQuery) debug(c)(s"🔍 Parsed fields: ${queryFields.mkString(", ")}") - // 4. Extract the required fields from case class T - val tpe = c.weakTypeOf[T] + // ✅ Extract UNNEST information from the query + val unnestedCollections = extractUnnestedCollections(parsedQuery) + + debug(c)(s"🔍 Unnested collections: ${unnestedCollections.mkString(", ")}") + + // ✅ Recursive validation of required fields + validateRequiredFieldsRecursive(c)(tpe, queryFields, unnestedCollections, prefix = "") + + // ✅ Recursive validation of unknown fields + validateUnknownFieldsRecursive(c)(tpe, queryFields, prefix = "") + + // ✅ Extract required fields from the case class val requiredFields = getRequiredFields(c)(tpe) debug(c)(s"📦 Case class fields: ${requiredFields.mkString(", ")}") - // 5. Validate: missing case class fields must have defaults or be Option - validateRequiredFields(c)(queryFields) - - // 7. Validate the types + // ✅ Type validation validateTypes(c)(parsedQuery, requiredFields.map(values => values._1 -> values._2._1)) debug(c)("=" * 80) debug(c)("✅ SQL Query Validation Complete") debug(c)("=" * 80) - // 8. Return the validated request // ✅ Mark as validated - SQLQueryValidator.markValidated(sqlQuery) + SQLQueryValidator.markValidated(c)(tpe, sqlQuery) + // ✅ Return the validated request sqlQuery } @@ -219,6 +226,13 @@ trait SQLQueryValidator { debug(c)("✅ No SELECT * detected") } + // ============================================================ + // Helper: Check if a type is a case class + // ============================================================ + private def isCaseClassType(c: blackbox.Context)(tpe: c.universe.Type): Boolean = { + tpe.typeSymbol.isClass && tpe.typeSymbol.asClass.isCaseClass + } + // ============================================================ // Helper: Detect if a type is a collection // ============================================================ @@ -250,7 +264,7 @@ trait SQLQueryValidator { } // ============================================================ - // Helper: Extract the required fields from a class case + // Helper: Extract the required fields from a case class // ============================================================ private def getRequiredFields( c: blackbox.Context @@ -269,106 +283,170 @@ trait SQLQueryValidator { val paramName = param.name.decodedName.toString val paramType = param.typeSignature - // Check if the parameter has a default value or is an option. 
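+      // Concrete illustration (hypothetical case class, not part of this patch): for
+      // `case class Person(name: String, nickname: Option[String] = None, age: Int = 0)`,
+      // `name` yields isOption = false / hasDefault = false, `nickname` yields isOption = true,
+      // and `age` yields hasDefault = true, so only `name` must appear in the SELECT clause.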
val isOption = paramType <:< typeOf[Option[_]] - val hasDefault = param.asTerm.isParamWithDefault - /* We should not filter out optional parameters here, - because we need to know all fields to validate their types later. - - if (isOption || hasDefault) { - None - } else { - Some((paramName, (paramType, isOption, hasDefault))) - }*/ - Some((paramName, (paramType, isOption, hasDefault))) - }.toMap } - /** Extracts selected fields from the parsed SQL query. - */ + // ============================================================ + // Helper: Extract selected fields from parsed SQL query + // ============================================================ private def extractQueryFields(parsedQuery: SQLSearchRequest): Set[String] = { - parsedQuery.select.fields.flatMap { field => - val fieldName = field.fieldAlias.map(_.alias).getOrElse(field.identifier.name) - - // ✅ Manage nested fields (ex: "children.name" → "children", "children.name") - val nestedParts = fieldName.split("\\.").toList + parsedQuery.select.fields.map { field => + field.fieldAlias.map(_.alias).getOrElse(field.identifier.name) + }.toSet + } - // Return all levels of nested fields - // Ex: "children.address.city" → ["children", "children.address", "children.address.city"] - nestedParts.indices.map { i => - nestedParts.take(i + 1).mkString(".") + // ============================================================ + // Helper: Extract UNNEST collections from the query + // ============================================================ + private def extractUnnestedCollections(parsedQuery: SQLSearchRequest): Set[String] = { + // Check if the query has nested elements (UNNEST) + parsedQuery.select.fields.flatMap { field => + field.identifier.nestedElement.map { nested => + // Extract the collection name from the nested element + // Example: "children" from "UNNEST(parent.children) AS children" + nested.innerHitsName } }.toSet } // ============================================================ - // Helper: Validate required vs. selected fields + // Helper: Recursive Validation of Required Fields // ============================================================ - private def validateRequiredFields[T: c.WeakTypeTag]( + private def validateRequiredFieldsRecursive( c: blackbox.Context )( - queryFields: Set[String] + tpe: c.universe.Type, + queryFields: Set[String], + unnestedCollections: Set[String], + prefix: String ): Unit = { - import c.universe._ - val tpe = weakTypeOf[T] val requiredFields = getRequiredFields(c)(tpe) val missingFields = requiredFields.filterNot { case (fieldName, (fieldType, isOption, hasDefault)) => - // ✅ Check if the field is selected - val isSelected = queryFields.contains(fieldName) + val fullFieldName = if (prefix.isEmpty) fieldName else s"$prefix.$fieldName" - if (!isSelected) { - debug(c)(s"⚠️ Missing field: $fieldName") + // ✅ Check if the field is directly selected (e.g., "address") + val isDirectlySelected = queryFields.contains(fullFieldName) - if (isOption) { - debug(c)(s"✅ Field '$fieldName' is Option - OK") - true - } else if (hasDefault) { - debug(c)(s"✅ Field '$fieldName' has default value - OK") - true - } - // ✅ If it's a collection, check if its nested fields are selected. 
- else if (isCollectionType(c)(fieldType)) { - getCollectionElementType(c)(fieldType) match { - case Some(elementType) => - // Check if the nested fields of the collection are selected - // Eg: "children.name", "children.birthDate" - val nestedFields = getRequiredFields(c)(elementType) - val hasNestedFields = nestedFields.forall { case (nestedFieldName, _) => - queryFields.exists(f => f.startsWith(s"$fieldName.$nestedFieldName")) - } + // ✅ Check if nested fields of this field are selected (e.g., "address.street") + val hasNestedSelection = queryFields.exists(_.startsWith(s"$fullFieldName.")) - if (hasNestedFields) { - // ✅ The nested fields are present, so the collection is considered valid. - debug(c)(s"✅ Collection field '$fieldName' validated via nested fields") + if (isDirectlySelected) { + // ✅ Field is selected as a whole (e.g., SELECT address FROM ...) + debug(c)(s"✅ Field '$fullFieldName' is directly selected") + true + } else if (isOption) { + // ✅ Field is optional, can be omitted + debug(c)(s"✅ Field '$fullFieldName' is Option - OK") + true + } else if (hasDefault) { + // ✅ Field has a default value, can be omitted + debug(c)(s"✅ Field '$fullFieldName' has default value - OK") + true + } else if (hasNestedSelection) { + // ⚠️ Nested fields are selected (e.g., SELECT address.street FROM ...) + // We must validate that ALL required nested fields are present + + if (isCollectionType(c)(fieldType)) { + // ✅ Collection: check if it's unnested + validateCollectionFieldsRecursive(c)( + fieldName, + fieldType, + queryFields, + unnestedCollections, + prefix + ) + } else if (isCaseClassType(c)(fieldType)) { + // ✅ Nested case class: validate that ALL required nested fields are selected + debug(c)(s"🔍 Validating nested case class fields: $fullFieldName") + + try { + validateRequiredFieldsRecursive(c)( + fieldType, + queryFields, + unnestedCollections, + prefix = fullFieldName + ) + // ✅ All required nested fields are present + debug(c)(s"✅ All required nested fields of '$fullFieldName' are present") + true + } catch { + case _: Throwable => + // ❌ Some required nested fields are missing + val nestedFields = getRequiredFields(c)(fieldType) + val missingNestedFields = nestedFields.filterNot { + case (nestedFieldName, (nestedFieldType, nestedIsOption, nestedHasDefault)) => + val fullNestedFieldName = s"$fullFieldName.$nestedFieldName" + val isNestedSelected = queryFields.contains(fullNestedFieldName) + val hasNestedNestedSelection = + queryFields.exists(_.startsWith(s"$fullNestedFieldName.")) + + isNestedSelected || nestedIsOption || nestedHasDefault || + (hasNestedNestedSelection && isCaseClassType(c)(nestedFieldType)) } - hasNestedFields + if (missingNestedFields.nonEmpty) { + val missingNames = + missingNestedFields.keys.map(n => s"$fullFieldName.$n").mkString(", ") + val allRequiredFields = nestedFields + .filterNot { case (_, (_, isOpt, hasDef)) => isOpt || hasDef } + .keys + .map(n => s"$fullFieldName.$n") + .mkString(", ") + + c.abort( + c.enclosingPosition, + s"""❌ Nested case class field '$fullFieldName' has missing required fields: + |$missingNames + | + |When selecting nested fields individually, ALL required fields must be present. + | + |Option 1: Select the entire nested object: + | SELECT $fullFieldName FROM ... + | + |Option 2: Select ALL required nested fields: + | SELECT $allRequiredFields FROM ... 
+ | + |Option 3: Make missing fields optional or provide default values in the case class + |""".stripMargin + ) + } - case None => false + false } } else { + // ✅ Primitive type with nested selection (shouldn't happen) + debug(c)(s"⚠️ Unexpected nested selection for primitive field: $fullFieldName") false } } else { - true + // ❌ Field is not selected at all + debug(c)(s"❌ Field '$fullFieldName' is missing") + false } } if (missingFields.nonEmpty) { - val missingFieldNames = missingFields.keys.mkString(", ") - val exampleFields = (queryFields ++ missingFields.keys).mkString(", ") + val missingFieldNames = missingFields.keys + .map { fieldName => + if (prefix.isEmpty) fieldName else s"$prefix.$fieldName" + } + .mkString(", ") + + val exampleFields = (queryFields ++ missingFields.keys.map { fieldName => + if (prefix.isEmpty) fieldName else s"$prefix.$fieldName" + }).mkString(", ") - val unknownFields = queryFields.filterNot(f => requiredFields.contains(f)) - val suggestions = unknownFields.flatMap { unknown => - findClosestMatch(unknown, missingFields.keys.toSeq) + val suggestions = missingFields.keys.flatMap { fieldName => + findClosestMatch(fieldName, queryFields.map(_.split("\\.").last).toSeq) } + val suggestionMsg = if (suggestions.nonEmpty) { s"\nDid you mean: ${suggestions.mkString(", ")}?" } else "" @@ -379,7 +457,7 @@ trait SQLQueryValidator { |$missingFieldNames$suggestionMsg | |These fields are missing from the query: - |SELECT ${exampleFields} FROM ... + |SELECT $exampleFields FROM ... | |To fix this, either: | 1. Add them to the SELECT clause @@ -389,6 +467,214 @@ trait SQLQueryValidator { } } + // ============================================================ + // Helper: Validate Collection Fields Recursively + // ============================================================ + private def validateCollectionFieldsRecursive( + c: blackbox.Context + )( + fieldName: String, + fieldType: c.universe.Type, + queryFields: Set[String], + unnestedCollections: Set[String], + prefix: String + ): Boolean = { + + getCollectionElementType(c)(fieldType) match { + case Some(elementType) => + val fullFieldName = if (prefix.isEmpty) fieldName else s"$prefix.$fieldName" + + // ✅ Check if the collection is selected as a whole + val isDirectlySelected = queryFields.contains(fullFieldName) + + // ✅ Check if the collection is unnested (uses UNNEST) + val isUnnested = unnestedCollections.contains(fieldName) + + if (isDirectlySelected) { + debug(c)(s"✅ Collection field '$fullFieldName' is directly selected") + true + } else if (isUnnested) { + debug(c)(s"✅ Collection field '$fullFieldName' is unnested") + // ✅ For unnested collections, validate nested fields + if (isCaseClassType(c)(elementType)) { + val nestedFields = getRequiredFields(c)(elementType) + + val missingNestedFields = nestedFields.filterNot { + case (nestedFieldName, (nestedFieldType, isOption, hasDefault)) => + val fullNestedFieldName = s"$fullFieldName.$nestedFieldName" + + val isSelected = queryFields.contains(fullNestedFieldName) + val hasNestedSelection = queryFields.exists(_.startsWith(s"$fullNestedFieldName.")) + + if (isSelected) { + debug(c)(s"✅ Nested field '$fullNestedFieldName' is selected") + true + } else if (isOption) { + debug(c)(s"✅ Nested field '$fullNestedFieldName' is Option - OK") + true + } else if (hasDefault) { + debug(c)(s"✅ Nested field '$fullNestedFieldName' has default value - OK") + true + } else if (hasNestedSelection && isCaseClassType(c)(nestedFieldType)) { + debug(c)(s"🔍 Validating deeply nested case class: 
$fullNestedFieldName") + try { + validateRequiredFieldsRecursive(c)( + nestedFieldType, + queryFields, + unnestedCollections, + prefix = fullNestedFieldName + ) + true + } catch { + case _: Throwable => false + } + } else { + debug(c)(s"❌ Nested field '$fullNestedFieldName' is missing") + false + } + } + + if (missingNestedFields.nonEmpty) { + val missingNames = + missingNestedFields.keys.map(n => s"$fullFieldName.$n").mkString(", ") + val allRequiredFields = nestedFields + .filterNot { case (_, (_, isOpt, hasDef)) => isOpt || hasDef } + .keys + .map(n => s"$fullFieldName.$n") + .mkString(", ") + + c.abort( + c.enclosingPosition, + s"""❌ Unnested collection field '$fullFieldName' is missing required nested fields: + |$missingNames + | + |When using UNNEST, ALL required nested fields must be selected: + | SELECT $allRequiredFields FROM parent JOIN UNNEST(parent.$fieldName) AS $fieldName + | + |Or make missing nested fields optional or provide default values + |""".stripMargin + ) + } + + true + } else { + true + } + } else if (isCaseClassType(c)(elementType)) { + // ❌ Collection of case classes with nested field selection but NO UNNEST + val hasNestedSelection = queryFields.exists(_.startsWith(s"$fullFieldName.")) + + if (hasNestedSelection) { + c.abort( + c.enclosingPosition, + s"""❌ Collection field '$fullFieldName' cannot be deserialized correctly. + | + |You are selecting nested fields of a collection without using UNNEST: + | ${queryFields.filter(_.startsWith(s"$fullFieldName.")).mkString(", ")} + | + |This will result in flat arrays that cannot be reconstructed into objects. + | + |Example of the problem: + | Elasticsearch returns: { "children.name": ["Alice", "Bob"], "children.age": [10, 12] } + | But we need: { "children": [{"name": "Alice", "age": 10}, {"name": "Bob", "age": 12}] } + | + |Solution 1: Select the entire collection: + | SELECT $fullFieldName FROM ... 
+ | + |Solution 2: Use UNNEST to properly handle nested objects: + | SELECT name, $fieldName.name, $fieldName.age + | FROM ${if (prefix.isEmpty) "table" else prefix} + | JOIN UNNEST(${if (prefix.isEmpty) "" else s"$prefix."}$fieldName) AS $fieldName + | + |Solution 3: Make the collection optional: + | $fullFieldName: Option[List[${elementType.typeSymbol.name}]] = None + |""".stripMargin + ) + } + + false + } else { + // ✅ Collection of primitive types + true + } + + case None => + debug(c)(s"⚠️ Cannot extract element type from collection: $fieldName") + false + } + } + + // ============================================================ + // Helper: Build all valid field paths recursively + // ============================================================ + private def buildValidFieldPaths( + c: blackbox.Context + )( + tpe: c.universe.Type, + prefix: String + ): Set[String] = { + + val requiredFields = getRequiredFields(c)(tpe) + + requiredFields.flatMap { case (fieldName, (fieldType, _, _)) => + val fullFieldName = if (prefix.isEmpty) fieldName else s"$prefix.$fieldName" + + if (isCollectionType(c)(fieldType)) { + getCollectionElementType(c)(fieldType) match { + case Some(elementType) if isCaseClassType(c)(elementType) => + // ✅ Collection of case classes: recurse + Set(fullFieldName) ++ buildValidFieldPaths(c)(elementType, fullFieldName) + case _ => + Set(fullFieldName) + } + } else if (isCaseClassType(c)(fieldType)) { + // ✅ Nested case class: recurse + Set(fullFieldName) ++ buildValidFieldPaths(c)(fieldType, fullFieldName) + } else { + Set(fullFieldName) + } + }.toSet + } + + // ============================================================ + // Helper: Validate Unknown Fields Recursively + // ============================================================ + private def validateUnknownFieldsRecursive( + c: blackbox.Context + )( + tpe: c.universe.Type, + queryFields: Set[String], + prefix: String + ): Unit = { + + val requiredFields = getRequiredFields(c)(tpe) + + // ✅ Get all valid field paths at this level and below + val validFieldPaths = buildValidFieldPaths(c)(tpe, prefix) + + // ✅ Find unknown fields + val unknownFields = queryFields.filterNot { queryField => + validFieldPaths.contains(queryField) || + validFieldPaths.exists(vf => queryField.startsWith(s"$vf.")) + } + + if (unknownFields.nonEmpty) { + val unknownFieldNames = unknownFields.mkString(", ") + val availableFields = validFieldPaths.toSeq.sorted.mkString(", ") + + c.warning( + c.enclosingPosition, + s"""⚠️ SQL query selects fields that don't exist in ${tpe.typeSymbol.name}: + |$unknownFieldNames + | + |Available fields: $availableFields + | + |Note: These fields will be ignored during deserialization. 
+ |""".stripMargin + ) + } + } + // ============================================================ // Helper: Validate Type compatibility // ============================================================ @@ -562,9 +848,29 @@ trait SQLQueryValidator { object SQLQueryValidator { val DEBUG: Boolean = sys.props.get("sql.macro.debug").contains("true") - // ✅ Cache pour éviter les validations redondantes + + // ✅ Cache to avoid redundant validations private val validationCache = scala.collection.mutable.Map[String, Boolean]() - private def isCached(sql: String): Boolean = validationCache.contains(sql.trim) - private def markValidated(sql: String): Unit = validationCache(sql.trim) = true + private[macros] def isCached( + c: blackbox.Context + )(tpe: c.universe.Type, sql: String): Boolean = { + // ✅ Disable cache in test mode + if (sys.props.get("sql.macro.test").contains("true")) { + false + } else { + validationCache.contains(s"${tpe.typeSymbol.name}::${sql.trim}") + } + } + + private[macros] def markValidated( + c: blackbox.Context + )(tpe: c.universe.Type, sql: String): Unit = { + if (!sys.props.get("sql.macro.test").contains("true")) { + validationCache(s"${tpe.typeSymbol.name}::${sql.trim}") = true + } + } + + // ✅ Method for clearing the cache + private[macros] def clearCache(): Unit = validationCache.clear() } diff --git a/testkit/src/main/scala/app/softnetwork/elastic/client/ElasticClientSpec.scala b/testkit/src/main/scala/app/softnetwork/elastic/client/ElasticClientSpec.scala index 898f78c1..c83d38d9 100644 --- a/testkit/src/main/scala/app/softnetwork/elastic/client/ElasticClientSpec.scala +++ b/testkit/src/main/scala/app/softnetwork/elastic/client/ElasticClientSpec.scala @@ -1206,13 +1206,20 @@ trait ElasticClientSpec extends AnyFlatSpecLike with ElasticDockerTestKit with M | p.uuid, | p.name, | p.birthDate, - | p.children + | children.name, + | children.birthDate, + | children.parentId | FROM | parent as p + | JOIN UNNEST(p.children) as children + |WHERE + | children.name is not null |""".stripMargin ) parents match { - case ElasticSuccess(ps) => ps.size shouldBe 3 + case ElasticSuccess(ps) => + ps.size shouldBe 1 + ps.head.children.size shouldBe 2 case ElasticFailure(error) => error.cause match { case Some(cause) => log.error("Error during search", cause) diff --git a/testkit/src/main/scala/app/softnetwork/elastic/model/Parent.scala b/testkit/src/main/scala/app/softnetwork/elastic/model/Parent.scala index 48f217ed..c8edd0e5 100644 --- a/testkit/src/main/scala/app/softnetwork/elastic/model/Parent.scala +++ b/testkit/src/main/scala/app/softnetwork/elastic/model/Parent.scala @@ -25,7 +25,7 @@ case class Parent( uuid: String, name: String, birthDate: LocalDate, - children: Seq[Child] = Seq.empty[Child] + children: Seq[Child] ) extends Timestamped { def addChild(child: Child): Parent = copy(children = children :+ child) lazy val createdDate: Instant = Instant.now() From be352d9fd59a20021c117f8e79bcd937b74a0f65 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Manciot?= Date: Wed, 12 Nov 2025 05:55:53 +0100 Subject: [PATCH 07/11] add support for nested object within query validator --- .../client/ElasticConversionSpec.scala | 59 ++- .../sql/macros/SQLQueryValidatorSpec.scala | 63 ++- .../sql/macros/SQLQueryValidator.scala | 456 ++++++++---------- 3 files changed, 302 insertions(+), 276 deletions(-) diff --git a/core/src/test/scala/app/softnetwork/elastic/client/ElasticConversionSpec.scala b/core/src/test/scala/app/softnetwork/elastic/client/ElasticConversionSpec.scala index 
c654e379..02a43f0d 100644 --- a/core/src/test/scala/app/softnetwork/elastic/client/ElasticConversionSpec.scala +++ b/core/src/test/scala/app/softnetwork/elastic/client/ElasticConversionSpec.scala @@ -1,8 +1,5 @@ package app.softnetwork.elastic.client -import app.softnetwork.elastic.sql.Identifier -import app.softnetwork.elastic.sql.function.aggregate.ArrayAgg -import app.softnetwork.elastic.sql.query.{OrderBy, SQLAggregation} import org.json4s.ext.{JavaTimeSerializers, JavaTypesSerializers, JodaTimeSerializers} import org.json4s.jackson.Serialization import org.json4s.{Formats, NoTypeHints} @@ -63,6 +60,50 @@ class ElasticConversionSpec extends AnyFlatSpec with Matchers with ElasticConver throw error } } + it should "parse hits with field object" in { + val results = + """{ + | "took": 8, + | "hits": { + | "total": { "value": 1, "relation": "eq" }, + | "max_score": 1.0, + | "hits": [ + | { + | "_index": "users", + | "_id": "u1", + | "_score": 1.0, + | "_source": { + | "id": "u1", + | "name": "Alice", + | "address": { + | "street": "123 Main St", + | "city": "Wonderland", + | "country": "Fictionland" + | } + | } + | } + | ] + | } + |}""".stripMargin + parseResponse( + ElasticResponse("", results, Map.empty, Map.empty) + ) match { + case Success(rows) => + rows.foreach(println) + // Map(name -> Alice, address -> Map(street -> 123 Main St, city -> Wonderland, country -> Fictionland), _id -> u1, _index -> users, _score -> 1.0) + val users = rows.map(row => convertTo[User](row)) + users.foreach(println) + // User(u1,Alice,Address(123 Main St,Wonderland,Fictionland)) + users.size shouldBe 1 + users.head.id shouldBe "u1" + users.head.name shouldBe "Alice" + users.head.address.street shouldBe "123 Main St" + users.head.address.city shouldBe "Wonderland" + users.head.address.country shouldBe "Fictionland" + case Failure(error) => + throw error + } + } it should "parse aggregations with top hits" in { val results = """{ | "took": 10, @@ -638,3 +679,15 @@ case class SalesHistory( sales_over_time: ZonedDateTime, total_revenue: Double ) + +case class Address( + street: String, + city: String, + country: String +) + +case class User( + id: String, + name: String, + address: Address +) diff --git a/macros-tests/src/test/scala/app/softnetwork/elastic/sql/macros/SQLQueryValidatorSpec.scala b/macros-tests/src/test/scala/app/softnetwork/elastic/sql/macros/SQLQueryValidatorSpec.scala index 4b72908b..c939b459 100644 --- a/macros-tests/src/test/scala/app/softnetwork/elastic/sql/macros/SQLQueryValidatorSpec.scala +++ b/macros-tests/src/test/scala/app/softnetwork/elastic/sql/macros/SQLQueryValidatorSpec.scala @@ -9,7 +9,7 @@ class SQLQueryValidatorSpec extends AnyFlatSpec with Matchers { // Positive Tests (Should Compile) // ============================================================ - "SQLQueryValidator" should "validate all numeric types" in { + "SQLQueryValidator" should "VALIDATE all numeric types" in { assertCompiles(""" import app.softnetwork.elastic.client.macros.TestElasticClientApi import app.softnetwork.elastic.client.macros.TestElasticClientApi.defaultFormats @@ -21,7 +21,7 @@ class SQLQueryValidatorSpec extends AnyFlatSpec with Matchers { )""") } - it should "validate string types" in { + it should "VALIDATE string types" in { assertCompiles(""" import app.softnetwork.elastic.client.macros.TestElasticClientApi import app.softnetwork.elastic.client.macros.TestElasticClientApi.defaultFormats @@ -33,7 +33,7 @@ class SQLQueryValidatorSpec extends AnyFlatSpec with Matchers { )""") } - it should "validate 
temporal types" in { + it should "VALIDATE temporal types" in { assertCompiles(""" import app.softnetwork.elastic.client.macros.TestElasticClientApi import app.softnetwork.elastic.client.macros.TestElasticClientApi.defaultFormats @@ -45,7 +45,7 @@ class SQLQueryValidatorSpec extends AnyFlatSpec with Matchers { )""") } - it should "validate Product with all fields" in { + it should "VALIDATE with all fields" in { assertCompiles(""" import app.softnetwork.elastic.client.macros.TestElasticClientApi import app.softnetwork.elastic.client.macros.TestElasticClientApi.defaultFormats @@ -57,7 +57,7 @@ class SQLQueryValidatorSpec extends AnyFlatSpec with Matchers { )""") } - it should "validate with aliases" in { + it should "VALIDATE with aliases" in { assertCompiles(""" import app.softnetwork.elastic.client.macros.TestElasticClientApi import app.softnetwork.elastic.client.macros.TestElasticClientApi.defaultFormats @@ -69,7 +69,7 @@ class SQLQueryValidatorSpec extends AnyFlatSpec with Matchers { )""") } - it should "accept query with missing Option fields" in { + it should "ACCEPT query with missing Option fields" in { assertCompiles(""" import app.softnetwork.elastic.client.macros.TestElasticClientApi import app.softnetwork.elastic.client.macros.TestElasticClientApi.defaultFormats @@ -82,7 +82,7 @@ class SQLQueryValidatorSpec extends AnyFlatSpec with Matchers { """) } - it should "accept query with missing fields that have defaults" in { + it should "ACCEPT query with missing fields that have defaults" in { assertCompiles(""" import app.softnetwork.elastic.client.macros.TestElasticClientApi import app.softnetwork.elastic.client.macros.TestElasticClientApi.defaultFormats @@ -95,7 +95,7 @@ class SQLQueryValidatorSpec extends AnyFlatSpec with Matchers { """) } - it should "accept SELECT * with Unchecked variant" in { + it should "ACCEPT SELECT * with Unchecked variant" in { assertCompiles(""" import app.softnetwork.elastic.client.macros.TestElasticClientApi import app.softnetwork.elastic.client.macros.TestElasticClientApi.defaultFormats @@ -108,6 +108,28 @@ class SQLQueryValidatorSpec extends AnyFlatSpec with Matchers { """) } + it should "ACCEPT nested object with complete selection" in { + assertCompiles(""" + import app.softnetwork.elastic.client.macros.TestElasticClientApi + import app.softnetwork.elastic.client.macros.TestElasticClientApi.defaultFormats + import app.softnetwork.elastic.sql.macros.SQLQueryValidatorSpec.{User, Address} + + TestElasticClientApi.searchAs[User]( + "SELECT id, name, address FROM users" + )""") + } + + it should "ACCEPT nested object with UNNEST" in { + assertCompiles(""" + import app.softnetwork.elastic.client.macros.TestElasticClientApi + import app.softnetwork.elastic.client.macros.TestElasticClientApi.defaultFormats + import app.softnetwork.elastic.sql.macros.SQLQueryValidatorSpec.{User, Address} + + TestElasticClientApi.searchAs[User]( + "SELECT id, name, address.street, address.city, address.country FROM users JOIN UNNEST(users.address) AS address" + )""") + } + // ============================================================ // Negative Tests (Should NOT Compile) // ============================================================ @@ -149,7 +171,7 @@ class SQLQueryValidatorSpec extends AnyFlatSpec with Matchers { )""") } - it should "suggest closest field names" in { + it should "SUGGEST closest field names" in { assertDoesNotCompile(""" import app.softnetwork.elastic.client.macros.TestElasticClientApi import 
app.softnetwork.elastic.client.macros.TestElasticClientApi.defaultFormats @@ -200,6 +222,17 @@ class SQLQueryValidatorSpec extends AnyFlatSpec with Matchers { """) } + it should "REJECT nested object with individual field selection without UNNEST" in { + assertDoesNotCompile(""" + import app.softnetwork.elastic.client.macros.TestElasticClientApi + import app.softnetwork.elastic.client.macros.TestElasticClientApi.defaultFormats + import app.softnetwork.elastic.sql.macros.SQLQueryValidatorSpec.{User, Address} + + TestElasticClientApi.searchAs[User]( + "SELECT id, name, address.street, address.city, address.country FROM users" + )""") + } + } object SQLQueryValidatorSpec { @@ -249,4 +282,16 @@ object SQLQueryValidatorSpec { dt: java.time.LocalDateTime, ts: java.time.Instant ) + + case class Address( + street: String, + city: String, + country: String + ) + + case class User( + id: String, + name: String, + address: Address + ) } diff --git a/macros/src/main/scala/app/softnetwork/elastic/sql/macros/SQLQueryValidator.scala b/macros/src/main/scala/app/softnetwork/elastic/sql/macros/SQLQueryValidator.scala index 95b9e1b8..a6eafcc7 100644 --- a/macros/src/main/scala/app/softnetwork/elastic/sql/macros/SQLQueryValidator.scala +++ b/macros/src/main/scala/app/softnetwork/elastic/sql/macros/SQLQueryValidator.scala @@ -23,15 +23,19 @@ import app.softnetwork.elastic.sql.query.SQLSearchRequest import scala.language.experimental.macros import scala.reflect.macros.blackbox -/** Reusable core validation logic for all SQL macros. +/** SQL Query Validator Trait + * + * Provides compile-time validation of SQL queries against Scala case classes. Ensures type safety + * and prevents runtime deserialization errors. */ trait SQLQueryValidator { /** Validates an SQL query against a type T. Returns the SQL query if valid, otherwise aborts * compilation. 
* @note - * query fields must not exist in case class because we are using Jackson to deserialize the - * results with the following option DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES = false + * query fields that do not exist in case class will be ignored because we are using Jackson to + * deserialize the results with the following option + * DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES = false */ protected def validateSQLQuery[T: c.WeakTypeTag](c: blackbox.Context)( query: c.Expr[String] @@ -50,9 +54,7 @@ trait SQLQueryValidator { return sqlQuery } - if (sys.props.get("elastic.sql.debug").contains("true")) { - c.info(c.enclosingPosition, s"Validating SQL: $sqlQuery", force = false) - } + debug(c)(s"🔍 Validating SQL query $sqlQuery for type: ${tpe.typeSymbol.name}") // ✅ Parse the SQL query val parsedQuery = parseSQLQuery(c)(sqlQuery) @@ -71,10 +73,10 @@ trait SQLQueryValidator { debug(c)(s"🔍 Unnested collections: ${unnestedCollections.mkString(", ")}") // ✅ Recursive validation of required fields - validateRequiredFieldsRecursive(c)(tpe, queryFields, unnestedCollections, prefix = "") + validateRequiredFieldsRecursively(c)(tpe, queryFields, unnestedCollections, prefix = "") // ✅ Recursive validation of unknown fields - validateUnknownFieldsRecursive(c)(tpe, queryFields, prefix = "") + validateUnknownFieldsRecursively(c)(tpe, queryFields, prefix = "") // ✅ Extract required fields from the case class val requiredFields = getRequiredFields(c)(tpe) @@ -168,7 +170,7 @@ trait SQLQueryValidator { case Right(Right(multi)) => multi.requests.headOption.getOrElse { - c.abort(c.enclosingPosition, "Empty multi-search query") + c.abort(c.enclosingPosition, "❌ Empty multi-search query") } case Left(error) => @@ -233,6 +235,21 @@ trait SQLQueryValidator { tpe.typeSymbol.isClass && tpe.typeSymbol.asClass.isCaseClass } + // ============================================================ + // Helper: Check if a type is a Product (case class or tuple) + // ============================================================ + private def isProductType(c: blackbox.Context)(tpe: c.universe.Type): Boolean = { + import c.universe._ + + // Check if it's a case class + val isCaseClass = tpe.typeSymbol.isClass && tpe.typeSymbol.asClass.isCaseClass + + // Check if it's a Product type (includes tuples) + val isProduct = tpe <:< typeOf[Product] + + isCaseClass || isProduct + } + // ============================================================ // Helper: Detect if a type is a collection // ============================================================ @@ -316,7 +333,7 @@ trait SQLQueryValidator { // ============================================================ // Helper: Recursive Validation of Required Fields // ============================================================ - private def validateRequiredFieldsRecursive( + private def validateRequiredFieldsRecursively( c: blackbox.Context )( tpe: c.universe.Type, @@ -326,281 +343,194 @@ trait SQLQueryValidator { ): Unit = { val requiredFields = getRequiredFields(c)(tpe) + debug(c)( + s"📋 Required fields for ${tpe.typeSymbol.name} (prefix='$prefix'): ${requiredFields.keys.mkString(", ")}" + ) - val missingFields = requiredFields.filterNot { - case (fieldName, (fieldType, isOption, hasDefault)) => - val fullFieldName = if (prefix.isEmpty) fieldName else s"$prefix.$fieldName" + requiredFields.foreach { case (fieldName, (fieldType, isOption, hasDefault)) => + val fullFieldName = if (prefix.isEmpty) fieldName else s"$prefix.$fieldName" + debug(c)( + s"🔍 Checking field: 
$fullFieldName (type: $fieldType, optional: $isOption, hasDefault: $hasDefault)" + ) - // ✅ Check if the field is directly selected (e.g., "address") - val isDirectlySelected = queryFields.contains(fullFieldName) + // ✅ Check if the field is directly selected (e.g., "address") + val isDirectlySelected = queryFields.contains(fullFieldName) - // ✅ Check if nested fields of this field are selected (e.g., "address.street") - val hasNestedSelection = queryFields.exists(_.startsWith(s"$fullFieldName.")) + // ✅ Check if nested fields of this field are selected (e.g., "address.street") + val hasNestedSelection = queryFields.exists(_.startsWith(s"$fullFieldName.")) - if (isDirectlySelected) { - // ✅ Field is selected as a whole (e.g., SELECT address FROM ...) - debug(c)(s"✅ Field '$fullFieldName' is directly selected") - true - } else if (isOption) { - // ✅ Field is optional, can be omitted - debug(c)(s"✅ Field '$fullFieldName' is Option - OK") - true - } else if (hasDefault) { - // ✅ Field has a default value, can be omitted - debug(c)(s"✅ Field '$fullFieldName' has default value - OK") - true - } else if (hasNestedSelection) { - // ⚠️ Nested fields are selected (e.g., SELECT address.street FROM ...) - // We must validate that ALL required nested fields are present - - if (isCollectionType(c)(fieldType)) { - // ✅ Collection: check if it's unnested - validateCollectionFieldsRecursive(c)( - fieldName, - fieldType, - queryFields, - unnestedCollections, - prefix + // ✅ Determine field characteristics + val isCollection = isCollectionType(c)(fieldType) + val isNestedObject = isProductType(c)(fieldType) + val isNestedCollection = isCollection && { + getCollectionElementType(c)(fieldType).exists(isProductType(c)) + } + + if (isDirectlySelected) { + // ✅ Field is selected as a whole (e.g., SELECT address FROM ...) + debug(c)(s"✅ Field '$fullFieldName' is directly selected") + } else if (isOption || hasDefault) { + // ✅ Field is optional or has a default value, can be omitted + debug(c)(s"✅ Field '$fullFieldName' is optional or has default - OK") + } else if (hasNestedSelection) { + // ⚠️ Nested fields are selected (e.g., SELECT address.street FROM ...) + + if (isNestedCollection) { + // ✅ Collection of case classes + debug(c)(s"📦 Field '$fullFieldName' is a nested collection") + validateNestedCollection(c)( + fullFieldName, + fieldName, + fieldType, + queryFields, + unnestedCollections, + tpe + ) + } else if (isNestedObject) { + // ❌ Nested object (non-collection) with individual field selection + debug(c)(s"🏗️ Field '$fullFieldName' is a nested object (non-collection)") + + val isUnnested = unnestedCollections.contains(fullFieldName) || + unnestedCollections.contains(fieldName) + + if (!isUnnested) { + c.abort( + c.enclosingPosition, + s"""❌ Nested object field '$fullFieldName' cannot be deserialized correctly. + | + |❌ Problem: + | You are selecting nested fields individually: + | ${queryFields.filter(_.startsWith(s"$fullFieldName.")).mkString(", ")} + | + | Elasticsearch will return flat fields like: + | { "$fullFieldName.field1": "value1", "$fullFieldName.field2": "value2" } + | + | But Jackson needs a structured object like: + | { "$fullFieldName": {"field1": "value1", "field2": "value2"} } + | + |✅ Solution 1: Select the entire nested object (recommended) + | SELECT $fullFieldName FROM ... + | + |✅ Solution 2: Use UNNEST (if you need to filter or join on nested fields) + | SELECT ${queryFields.filter(_.startsWith(s"$fullFieldName.")).mkString(", ")} + | FROM ... 
+ | JOIN UNNEST(....$fullFieldName) AS $fieldName + | + |✅ Solution 3: Make the nested object optional (if it's not always needed) + | case class ${tpe.typeSymbol.name}(..., $fieldName: Option[${fieldType.typeSymbol.name}] = None) + | + |📚 Note: This applies to ALL nested objects, not just collections. + |""".stripMargin ) - } else if (isCaseClassType(c)(fieldType)) { - // ✅ Nested case class: validate that ALL required nested fields are selected - debug(c)(s"🔍 Validating nested case class fields: $fullFieldName") - - try { - validateRequiredFieldsRecursive(c)( - fieldType, - queryFields, - unnestedCollections, - prefix = fullFieldName - ) - // ✅ All required nested fields are present - debug(c)(s"✅ All required nested fields of '$fullFieldName' are present") - true - } catch { - case _: Throwable => - // ❌ Some required nested fields are missing - val nestedFields = getRequiredFields(c)(fieldType) - val missingNestedFields = nestedFields.filterNot { - case (nestedFieldName, (nestedFieldType, nestedIsOption, nestedHasDefault)) => - val fullNestedFieldName = s"$fullFieldName.$nestedFieldName" - val isNestedSelected = queryFields.contains(fullNestedFieldName) - val hasNestedNestedSelection = - queryFields.exists(_.startsWith(s"$fullNestedFieldName.")) - - isNestedSelected || nestedIsOption || nestedHasDefault || - (hasNestedNestedSelection && isCaseClassType(c)(nestedFieldType)) - } - - if (missingNestedFields.nonEmpty) { - val missingNames = - missingNestedFields.keys.map(n => s"$fullFieldName.$n").mkString(", ") - val allRequiredFields = nestedFields - .filterNot { case (_, (_, isOpt, hasDef)) => isOpt || hasDef } - .keys - .map(n => s"$fullFieldName.$n") - .mkString(", ") - - c.abort( - c.enclosingPosition, - s"""❌ Nested case class field '$fullFieldName' has missing required fields: - |$missingNames - | - |When selecting nested fields individually, ALL required fields must be present. - | - |Option 1: Select the entire nested object: - | SELECT $fullFieldName FROM ... - | - |Option 2: Select ALL required nested fields: - | SELECT $allRequiredFields FROM ... 
- | - |Option 3: Make missing fields optional or provide default values in the case class - |""".stripMargin - ) - } - - false - } - } else { - // ✅ Primitive type with nested selection (shouldn't happen) - debug(c)(s"⚠️ Unexpected nested selection for primitive field: $fullFieldName") - false } - } else { - // ❌ Field is not selected at all - debug(c)(s"❌ Field '$fullFieldName' is missing") - false - } - } - if (missingFields.nonEmpty) { - val missingFieldNames = missingFields.keys - .map { fieldName => - if (prefix.isEmpty) fieldName else s"$prefix.$fieldName" + // ✅ With UNNEST: validate nested fields recursively + validateRequiredFieldsRecursively(c)( + fieldType, + queryFields, + unnestedCollections, + fullFieldName + ) + } else { + // ✅ Primitive type with nested selection (shouldn't happen) + debug(c)(s"⚠️ Unexpected nested selection for primitive field: $fullFieldName") } - .mkString(", ") + } else { + // ❌ Required field is not selected at all + debug(c)(s"❌ Field '$fullFieldName' is missing") - val exampleFields = (queryFields ++ missingFields.keys.map { fieldName => - if (prefix.isEmpty) fieldName else s"$prefix.$fieldName" - }).mkString(", ") + val exampleFields = (queryFields + fullFieldName).mkString(", ") + val suggestions = findClosestMatch(fieldName, queryFields.map(_.split("\\.").last).toSeq) + val suggestionMsg = suggestions.map(s => s"\nDid you mean: $s?").getOrElse("") - val suggestions = missingFields.keys.flatMap { fieldName => - findClosestMatch(fieldName, queryFields.map(_.split("\\.").last).toSeq) + c.abort( + c.enclosingPosition, + s"""❌ SQL query does not select the required field: $fullFieldName + |$suggestionMsg + | + |Example query: + |SELECT $exampleFields FROM ... + | + |To fix this, either: + | 1. Add it to the SELECT clause + | 2. Make it Option[T] in the case class + | 3. Provide a default value in the case class definition + |""".stripMargin + ) } - - val suggestionMsg = if (suggestions.nonEmpty) { - s"\nDid you mean: ${suggestions.mkString(", ")}?" - } else "" - - c.abort( - c.enclosingPosition, - s"""❌ SQL query does not select the following required fields from ${tpe.typeSymbol.name}: - |$missingFieldNames$suggestionMsg - | - |These fields are missing from the query: - |SELECT $exampleFields FROM ... - | - |To fix this, either: - | 1. Add them to the SELECT clause - | 2. Make them Option[T] in the case class - | 3. 
Provide default values in the case class definition""".stripMargin - ) } } // ============================================================ - // Helper: Validate Collection Fields Recursively + // Helper: Validate Nested Collection Fields // ============================================================ - private def validateCollectionFieldsRecursive( + private def validateNestedCollection( c: blackbox.Context )( + fullFieldName: String, fieldName: String, fieldType: c.universe.Type, queryFields: Set[String], unnestedCollections: Set[String], - prefix: String - ): Boolean = { - - getCollectionElementType(c)(fieldType) match { - case Some(elementType) => - val fullFieldName = if (prefix.isEmpty) fieldName else s"$prefix.$fieldName" - - // ✅ Check if the collection is selected as a whole - val isDirectlySelected = queryFields.contains(fullFieldName) - - // ✅ Check if the collection is unnested (uses UNNEST) - val isUnnested = unnestedCollections.contains(fieldName) - - if (isDirectlySelected) { - debug(c)(s"✅ Collection field '$fullFieldName' is directly selected") - true - } else if (isUnnested) { - debug(c)(s"✅ Collection field '$fullFieldName' is unnested") - // ✅ For unnested collections, validate nested fields - if (isCaseClassType(c)(elementType)) { - val nestedFields = getRequiredFields(c)(elementType) - - val missingNestedFields = nestedFields.filterNot { - case (nestedFieldName, (nestedFieldType, isOption, hasDefault)) => - val fullNestedFieldName = s"$fullFieldName.$nestedFieldName" - - val isSelected = queryFields.contains(fullNestedFieldName) - val hasNestedSelection = queryFields.exists(_.startsWith(s"$fullNestedFieldName.")) - - if (isSelected) { - debug(c)(s"✅ Nested field '$fullNestedFieldName' is selected") - true - } else if (isOption) { - debug(c)(s"✅ Nested field '$fullNestedFieldName' is Option - OK") - true - } else if (hasDefault) { - debug(c)(s"✅ Nested field '$fullNestedFieldName' has default value - OK") - true - } else if (hasNestedSelection && isCaseClassType(c)(nestedFieldType)) { - debug(c)(s"🔍 Validating deeply nested case class: $fullNestedFieldName") - try { - validateRequiredFieldsRecursive(c)( - nestedFieldType, - queryFields, - unnestedCollections, - prefix = fullNestedFieldName - ) - true - } catch { - case _: Throwable => false - } - } else { - debug(c)(s"❌ Nested field '$fullNestedFieldName' is missing") - false - } - } - - if (missingNestedFields.nonEmpty) { - val missingNames = - missingNestedFields.keys.map(n => s"$fullFieldName.$n").mkString(", ") - val allRequiredFields = nestedFields - .filterNot { case (_, (_, isOpt, hasDef)) => isOpt || hasDef } - .keys - .map(n => s"$fullFieldName.$n") - .mkString(", ") - - c.abort( - c.enclosingPosition, - s"""❌ Unnested collection field '$fullFieldName' is missing required nested fields: - |$missingNames - | - |When using UNNEST, ALL required nested fields must be selected: - | SELECT $allRequiredFields FROM parent JOIN UNNEST(parent.$fieldName) AS $fieldName - | - |Or make missing nested fields optional or provide default values - |""".stripMargin - ) - } - - true - } else { - true - } - } else if (isCaseClassType(c)(elementType)) { - // ❌ Collection of case classes with nested field selection but NO UNNEST - val hasNestedSelection = queryFields.exists(_.startsWith(s"$fullFieldName.")) - - if (hasNestedSelection) { - c.abort( - c.enclosingPosition, - s"""❌ Collection field '$fullFieldName' cannot be deserialized correctly. 
- | - |You are selecting nested fields of a collection without using UNNEST: - | ${queryFields.filter(_.startsWith(s"$fullFieldName.")).mkString(", ")} - | - |This will result in flat arrays that cannot be reconstructed into objects. - | - |Example of the problem: - | Elasticsearch returns: { "children.name": ["Alice", "Bob"], "children.age": [10, 12] } - | But we need: { "children": [{"name": "Alice", "age": 10}, {"name": "Bob", "age": 12}] } - | - |Solution 1: Select the entire collection: - | SELECT $fullFieldName FROM ... - | - |Solution 2: Use UNNEST to properly handle nested objects: - | SELECT name, $fieldName.name, $fieldName.age - | FROM ${if (prefix.isEmpty) "table" else prefix} - | JOIN UNNEST(${if (prefix.isEmpty) "" else s"$prefix."}$fieldName) AS $fieldName - | - |Solution 3: Make the collection optional: - | $fullFieldName: Option[List[${elementType.typeSymbol.name}]] = None - |""".stripMargin - ) - } + parentType: c.universe.Type + ): Unit = { - false - } else { - // ✅ Collection of primitive types - true + // ✅ Check if the collection is unnested (uses UNNEST) + val isUnnested = unnestedCollections.contains(fullFieldName) || + unnestedCollections.contains(fieldName) + + if (!isUnnested) { + // ❌ Collection with nested field selection but NO UNNEST + val selectedNestedFields = queryFields.filter(_.startsWith(s"$fullFieldName.")) + + getCollectionElementType(c)(fieldType) match { + case Some(elementType) => + c.abort( + c.enclosingPosition, + s"""❌ Collection field '$fullFieldName' cannot be deserialized correctly. + | + |❌ Problem: + | You are selecting nested fields without using UNNEST: + | ${selectedNestedFields.mkString(", ")} + | + | Elasticsearch will return flat arrays like: + | { "$fullFieldName.field1": ["val1", "val2"], "$fullFieldName.field2": ["val3", "val4"] } + | + | But Jackson needs structured objects like: + | { "$fullFieldName": [{"field1": "val1", "field2": "val3"}, {"field1": "val2", "field2": "val4"}] } + | + |✅ Solution 1: Select the entire collection (recommended for simple queries) + | SELECT $fullFieldName FROM ... + | + |✅ Solution 2: Use UNNEST for precise field selection (recommended for complex queries) + | SELECT ${selectedNestedFields.mkString(", ")} + | FROM ... 
+ | JOIN UNNEST(....$fullFieldName) AS $fieldName + | + |✅ Solution 3: Make the collection optional (if it's not always needed) + | case class ${parentType.typeSymbol.name}(..., $fieldName: Option[List[${elementType.typeSymbol.name}]] = None) + | + |📚 Documentation: + | https://www.elastic.co/guide/en/elasticsearch/reference/current/nested.html + |""".stripMargin + ) + case None => + debug(c)(s"⚠️ Cannot extract element type from collection: $fullFieldName") + } + } else { + // ✅ Collection is unnested: validate nested fields recursively + debug(c)(s"✅ Collection field '$fullFieldName' is unnested") + + getCollectionElementType(c)(fieldType).foreach { elementType => + if (isProductType(c)(elementType)) { + validateRequiredFieldsRecursively(c)( + elementType, + queryFields, + unnestedCollections, + fullFieldName + ) } - - case None => - debug(c)(s"⚠️ Cannot extract element type from collection: $fieldName") - false + } } } @@ -639,7 +569,7 @@ trait SQLQueryValidator { // ============================================================ // Helper: Validate Unknown Fields Recursively // ============================================================ - private def validateUnknownFieldsRecursive( + private def validateUnknownFieldsRecursively( c: blackbox.Context )( tpe: c.universe.Type, @@ -647,8 +577,6 @@ trait SQLQueryValidator { prefix: String ): Unit = { - val requiredFields = getRequiredFields(c)(tpe) - // ✅ Get all valid field paths at this level and below val validFieldPaths = buildValidFieldPaths(c)(tpe, prefix) From 140f60f421d66ed920d8fbcb702bcdff5a7eee73 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Manciot?= Date: Wed, 12 Nov 2025 06:48:19 +0100 Subject: [PATCH 08/11] update validator suggestions for missing required fields --- .../sql/macros/SQLQueryValidatorSpec.scala | 2 +- .../sql/macros/SQLQueryValidator.scala | 31 ++++++++++++++----- 2 files changed, 24 insertions(+), 9 deletions(-) diff --git a/macros-tests/src/test/scala/app/softnetwork/elastic/sql/macros/SQLQueryValidatorSpec.scala b/macros-tests/src/test/scala/app/softnetwork/elastic/sql/macros/SQLQueryValidatorSpec.scala index c939b459..4733479d 100644 --- a/macros-tests/src/test/scala/app/softnetwork/elastic/sql/macros/SQLQueryValidatorSpec.scala +++ b/macros-tests/src/test/scala/app/softnetwork/elastic/sql/macros/SQLQueryValidatorSpec.scala @@ -167,7 +167,7 @@ class SQLQueryValidatorSpec extends AnyFlatSpec with Matchers { case class WrongTypes(id: Int, name: Int) TestElasticClientApi.searchAs[WrongTypes]( - "SELECT id::LONG, name FROM products" + "SELECT id::BIGINT, name FROM products" )""") } diff --git a/macros/src/main/scala/app/softnetwork/elastic/sql/macros/SQLQueryValidator.scala b/macros/src/main/scala/app/softnetwork/elastic/sql/macros/SQLQueryValidator.scala index a6eafcc7..8af3b927 100644 --- a/macros/src/main/scala/app/softnetwork/elastic/sql/macros/SQLQueryValidator.scala +++ b/macros/src/main/scala/app/softnetwork/elastic/sql/macros/SQLQueryValidator.scala @@ -72,11 +72,17 @@ trait SQLQueryValidator { debug(c)(s"🔍 Unnested collections: ${unnestedCollections.mkString(", ")}") - // ✅ Recursive validation of required fields - validateRequiredFieldsRecursively(c)(tpe, queryFields, unnestedCollections, prefix = "") - // ✅ Recursive validation of unknown fields - validateUnknownFieldsRecursively(c)(tpe, queryFields, prefix = "") + val unknownFields = validateUnknownFieldsRecursively(c)(tpe, queryFields, prefix = "") + + // ✅ Recursive validation of required fields + validateRequiredFieldsRecursively(c)( + 
tpe, + queryFields, + unknownFields, + unnestedCollections, + prefix = "" + ) // ✅ Extract required fields from the case class val requiredFields = getRequiredFields(c)(tpe) @@ -338,6 +344,7 @@ trait SQLQueryValidator { )( tpe: c.universe.Type, queryFields: Set[String], + unknownFields: Set[String], unnestedCollections: Set[String], prefix: String ): Unit = { @@ -383,6 +390,7 @@ trait SQLQueryValidator { fieldName, fieldType, queryFields, + unknownFields, unnestedCollections, tpe ) @@ -428,6 +436,7 @@ trait SQLQueryValidator { validateRequiredFieldsRecursively(c)( fieldType, queryFields, + unknownFields, unnestedCollections, fullFieldName ) @@ -439,9 +448,11 @@ trait SQLQueryValidator { // ❌ Required field is not selected at all debug(c)(s"❌ Field '$fullFieldName' is missing") - val exampleFields = (queryFields + fullFieldName).mkString(", ") - val suggestions = findClosestMatch(fieldName, queryFields.map(_.split("\\.").last).toSeq) - val suggestionMsg = suggestions.map(s => s"\nDid you mean: $s?").getOrElse("") + val exampleFields = ((queryFields -- unknownFields) + fullFieldName).mkString(", ") + val suggestions = findClosestMatch(fieldName, unknownFields.map(_.split("\\.").last).toSeq) + val suggestionMsg = suggestions + .map(s => s"\nYou have selected unknown field \"$s\", did you mean \"$fullFieldName\"?") + .getOrElse("") c.abort( c.enclosingPosition, @@ -471,6 +482,7 @@ trait SQLQueryValidator { fieldName: String, fieldType: c.universe.Type, queryFields: Set[String], + unknownFields: Set[String], unnestedCollections: Set[String], parentType: c.universe.Type ): Unit = { @@ -526,6 +538,7 @@ trait SQLQueryValidator { validateRequiredFieldsRecursively(c)( elementType, queryFields, + unknownFields, unnestedCollections, fullFieldName ) @@ -575,7 +588,7 @@ trait SQLQueryValidator { tpe: c.universe.Type, queryFields: Set[String], prefix: String - ): Unit = { + ): Set[String] = { // ✅ Get all valid field paths at this level and below val validFieldPaths = buildValidFieldPaths(c)(tpe, prefix) @@ -601,6 +614,8 @@ trait SQLQueryValidator { |""".stripMargin ) } + + unknownFields } // ============================================================ From 9f55eb5b350286639a525dbe5d647842415e80b0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Manciot?= Date: Wed, 12 Nov 2025 06:48:41 +0100 Subject: [PATCH 09/11] add documentation for SQL Query Validation --- documentation/client/README.md | 1 + documentation/client/sql_validation.md | 729 +++++++++++++++++++++++++ 2 files changed, 730 insertions(+) create mode 100644 documentation/client/sql_validation.md diff --git a/documentation/client/README.md b/documentation/client/README.md index fc37f028..5a94564d 100644 --- a/documentation/client/README.md +++ b/documentation/client/README.md @@ -3,6 +3,7 @@ Welcome to the Client Engine Documentation. Navigate through the sections below: - [Client Common Principles](common_principles.md) +- [SQL Query Validation](sql_validation.md) - [Version Information](version.md) - [Flush Index](flush.md) - [Refresh Index](refresh.md) diff --git a/documentation/client/sql_validation.md b/documentation/client/sql_validation.md new file mode 100644 index 00000000..1f153d65 --- /dev/null +++ b/documentation/client/sql_validation.md @@ -0,0 +1,729 @@ +# **SQL Query Validation at Compile Time using macros** + +## **Table of Contents** + +1. [Overview](#overview-) +2. [Validations Performed](#validations-performed) +3. [Validation Examples with Error Messages](#validation-examples-with-error-messages) + - [1. 
SELECT * Validation](#1-select--validation) + - [2. Missing Fields Validation](#2-missing-fields-validation) + - [3. Unknown Fields Validation](#3-unknown-fields-validation) + - [4. Nested Objects Validation](#4-nested-objects-validation) + - [5. Nested Collections Validation](#5-nested-collections-validation) + - [6. Type Validation](#6-type-validation) +4. [Best Practices](#best-practices) +5. [Debug Configuration](#debug-configuration) + +--- + +## **Overview** 🎯 + +The Elasticsearch SQL client integrates a **compile-time validation system** that automatically verifies your SQL queries are compatible with the Scala case classes used for deserialization. This validation detects errors **before execution**, ensuring consistency between your queries and your data model. + +### **Benefits** ✅ + +- ✅ **Early error detection**: Issues are identified at compile time, not in production +- ✅ **Safe refactoring**: Renaming or removing a field generates a compilation error +- ✅ **Living documentation**: Case classes document the data structure +- ✅ **Strong typing**: Guarantees consistency between SQL and Scala +- ✅ **Explicit error messages**: Clear guidance on how to fix issues + +--- + +## **Validations Performed** + +| Validation | Description | Level | +|------------------------|--------------------------------------------------------|------------| +| **SELECT * Rejection** | Prohibits `SELECT *` to ensure compile-time validation | ❌ ERROR | +| **Required Fields** | Verifies that all required fields are selected | ❌ ERROR | +| **Unknown Fields** | Detects fields that don't exist in the case class | ⚠️ WARNING | +| **Nested Objects** | Validates the structure of nested objects | ❌ ERROR | +| **Nested Collections** | Validates the use of UNNEST for collections | ❌ ERROR | +| **Type Compatibility** | Checks compatibility between SQL and Scala types | ❌ ERROR | + +--- + +## **Validation Examples with Error Messages** + +### **1. SELECT \* Validation** + +#### **❌ Error Example** + +```scala +case class Product( + id: String, + name: String, + price: Double +) + +// ❌ ERROR: SELECT * is not allowed +client.searchAs[Product]("SELECT * FROM products") +``` + +#### **📋 Exact Error Message** + +``` +❌ SELECT * is not allowed with compile-time validation. + +Query: SELECT * FROM products + +Reason: + • Cannot validate field existence at compile-time + • Cannot validate type compatibility at compile-time + • Schema changes will break silently at runtime + +Solution: + 1. Explicitly list all required fields for Product: + SELECT id, name, price FROM ... + + 2. Use the *Unchecked() variant for dynamic queries: + searchAsUnchecked[Product](SQLQuery("SELECT * FROM ...")) + +Best Practice: + Always explicitly select only the fields you need. +``` + +#### **✅ Solution** + +```scala +// ✅ CORRECT: Explicitly select all fields +client.searchAs[Product]("SELECT id, name, price FROM products") +``` + +--- + +### **2. Missing Fields Validation** + +#### **2.1. Missing Simple Field** + +##### **❌ Error Example** + +```scala +case class User( + id: String, + name: String, + email: String +) + +// ❌ ERROR: The 'email' field is missing +client.searchAs[User]("SELECT id, name FROM users") +``` + +##### **📋 Exact Error Message** + +``` +❌ SQL query does not select the required field: email + +Example query: +SELECT id, name, email FROM ... + +To fix this, either: + 1. Add it to the SELECT clause + 2. Make it Option[T] in the case class + 3. 
Provide a default value in the case class definition +``` + +##### **✅ Solutions** + +**Option 1: Add the missing field** +```scala +// ✅ CORRECT +client.searchAs[User]("SELECT id, name, email FROM users") +``` + +**Option 2: Make the field optional** +```scala +case class User( + id: String, + name: String, + email: Option[String] = None // ✅ Optional field +) + +// ✅ CORRECT +client.searchAs[User]("SELECT id, name FROM users") +``` + +**Option 3: Provide a default value** +```scala +case class User( + id: String, + name: String, + email: String = "" // ✅ Default value +) + +// ✅ CORRECT +client.searchAs[User]("SELECT id, name FROM users") +``` + +--- + +#### **2.2. Field with Suggestion (Did You Mean?)** + +##### **❌ Error Example** + +```scala +case class Product( + id: String, + name: String, + price: Double +) + +// ❌ ERROR: Typo in 'price' -> 'pric' +client.searchAs[Product]("SELECT id, name, pric FROM products") +``` + +##### **📋 Exact Error Message** + +``` +❌ SQL query does not select the required field: price +You have selected unknown field "pric", did you mean "price"? + +Example query: +SELECT id, name, price FROM ... + +To fix this, either: + 1. Add it to the SELECT clause + 2. Make it Option[T] in the case class + 3. Provide a default value in the case class definition +``` + +##### **✅ Solution** + +```scala +// ✅ CORRECT: Fix the typo +client.searchAs[Product]("SELECT id, name, price FROM products") +``` + +--- + +### **3. Unknown Fields Validation** + +#### **⚠️ Warning Example** + +```scala +case class User( + id: String, + name: String, + email: String +) + +// ⚠️ WARNING: The 'age' field doesn't exist in User +client.searchAs[User]("SELECT id, name, email, age FROM users") +``` + +#### **📋 Exact Warning Message** + +``` +⚠️ SQL query selects fields that don't exist in User: +age + +Available fields: id, name, email + +Note: These fields will be ignored during deserialization. +``` + +#### **💡 Behavior** + +- ✅ **The code compiles successfully** +- ⚠️ **A warning is displayed** to inform about the unknown field +- 🔄 **During deserialization**, the unknown field is **silently ignored** +- 📦 **The JSON response** contains the field, but it's not mapped to the case class + +#### **✅ Solutions** + +**Option 1: Remove the unknown field** +```scala +// ✅ CORRECT: Only select existing fields +client.searchAs[User]("SELECT id, name, email FROM users") +``` + +**Option 2: Add the field to the case class** +```scala +case class User( + id: String, + name: String, + email: String, + age: Option[Int] = None // ✅ Field added +) + +// ✅ CORRECT +client.searchAs[User]("SELECT id, name, email, age FROM users") +``` + +--- + +### **4. Nested Objects Validation** + +#### **4.1. Nested Object with Individual Field Selection** + +##### **❌ Error Example** + +```scala +case class Address( + street: String, + city: String, + country: String +) + +case class User( + id: String, + name: String, + address: Address +) + +// ❌ ERROR: Selecting nested fields without UNNEST +client.searchAs[User]( + "SELECT id, name, address.street, address.city, address.country FROM users" +) +``` + +##### **📋 Exact Error Message** + +``` +❌ Nested object field 'address' cannot be deserialized correctly. 
+ +❌ Problem: + You are selecting nested fields individually: + address.street, address.city, address.country + + Elasticsearch will return flat fields like: + { "address.street": "value1", "address.city": "value2", "address.country": "value3" } + + But Jackson needs a structured object like: + { "address": {"street": "value1", "city": "value2", "country": "value3"} } + +✅ Solution 1: Select the entire nested object (recommended) + SELECT address FROM ... + +✅ Solution 2: Use UNNEST (if you need to filter or join on nested fields) + SELECT address.street, address.city, address.country + FROM ... + JOIN UNNEST(....address) AS address + +📚 Note: This applies to ALL nested objects, not just collections. +``` + +##### **✅ Solutions** + +**Option 1: Select the complete object (RECOMMENDED)** +```scala +// ✅ CORRECT: Select the entire object +client.searchAs[User]("SELECT id, name, address FROM users") +``` + +**Elasticsearch Response**: +```json +{ + "id": "u1", + "name": "Alice", + "address": { + "street": "123 Main St", + "city": "Wonderland", + "country": "Fictionland" + } +} +``` + +**Option 2: Use UNNEST** +```scala +// ✅ CORRECT: Use UNNEST for filtering/joining +client.searchAs[User]( + """SELECT id, name, address.street, address.city, address.country + FROM users + JOIN UNNEST(users.address) AS address + WHERE address.city = 'Wonderland'""" +) +``` + +--- + +#### **4.2. Missing Nested Object** + +##### **❌ Error Example** + +```scala +case class Address( + street: String, + city: String, + country: String +) + +case class User( + id: String, + name: String, + address: Address // ❌ Required field not selected +) + +// ❌ ERROR: The 'address' object is not selected +client.searchAs[User]("SELECT id, name FROM users") +``` + +##### **📋 Exact Error Message** + +``` +❌ SQL query does not select the required field: address + +Example query: +SELECT id, name, address FROM ... + +To fix this, either: + 1. Add it to the SELECT clause + 2. Make it Option[T] in the case class + 3. Provide a default value in the case class definition +``` + +##### **✅ Solutions** + +**Option 1: Add the missing field** +```scala +// ✅ CORRECT +client.searchAs[User]("SELECT id, name, address FROM users") +``` + +**Option 2: Make the object optional** +```scala +case class User( + id: String, + name: String, + address: Option[Address] = None // ✅ Optional object +) + +// ✅ CORRECT +client.searchAs[User]("SELECT id, name FROM users") +``` + +--- + +### **5. Nested Collections Validation** + +#### **5.1. Nested Collection with Individual Field Selection without UNNEST** + +##### **❌ Error Example** + +```scala +case class Child( + name: String, + age: Int +) + +case class Parent( + id: String, + name: String, + children: List[Child] +) + +// ❌ ERROR: Selecting nested fields without UNNEST +client.searchAs[Parent]( + "SELECT id, name, children.name, children.age FROM parent" +) +``` + +##### **📋 Exact Error Message** + +``` +❌ Collection field 'children' cannot be deserialized correctly. + +❌ Problem: + You are selecting nested fields without using UNNEST: + children.name, children.age + + Elasticsearch will return flat arrays like: + { "children.name": ["Alice", "Bob"], "children.age": [10, 12] } + + But Jackson needs structured objects like: + { "children": [{"name": "Alice", "age": 10}, {"name": "Bob", "age": 12}] } + +✅ Solution 1: Select the entire collection (recommended for simple queries) + SELECT children FROM ... 
+ +✅ Solution 2: Use UNNEST for precise field selection (recommended for complex queries) + SELECT children.name, children.age + FROM ... + JOIN UNNEST(....children) AS children + +📚 Documentation: + https://www.elastic.co/guide/en/elasticsearch/reference/current/nested.html +``` + +##### **✅ Solutions** + +**Option 1: Select the complete collection (RECOMMENDED)** +```scala +// ✅ CORRECT: Select the entire collection +client.searchAs[Parent]("SELECT id, name, children FROM parent") +``` + +**Elasticsearch Response**: +```json +{ + "id": "p1", + "name": "Parent Name", + "children": [ + {"name": "Alice", "age": 10}, + {"name": "Bob", "age": 12} + ] +} +``` + +**Option 2: Use UNNEST** +```scala +// ✅ CORRECT: Use UNNEST for filtering/joining +client.searchAs[Parent]( + """SELECT id, name, children.name, children.age + FROM parent + JOIN UNNEST(parent.children) AS children + WHERE children.age > 10""" +) +``` + +--- + +#### **5.2. Missing Nested Collection** + +##### **❌ Error Example** + +```scala +case class Child( + name: String, + age: Int +) + +case class Parent( + id: String, + name: String, + children: List[Child] // ❌ Required collection not selected +) + +// ❌ ERROR: The 'children' collection is not selected +client.searchAs[Parent]("SELECT id, name FROM parent") +``` + +##### **📋 Exact Error Message** + +``` +❌ SQL query does not select the required field: children + +Example query: +SELECT id, name, children FROM ... + +To fix this, either: + 1. Add it to the SELECT clause + 2. Make it Option[T] in the case class + 3. Provide a default value in the case class definition +``` + +##### **✅ Solutions** + +**Option 1: Add the missing collection** +```scala +// ✅ CORRECT +client.searchAs[Parent]("SELECT id, name, children FROM parent") +``` + +**Option 2: Make the collection optional** +```scala +case class Parent( + id: String, + name: String, + children: Option[List[Child]] = None // ✅ Optional collection +) + +// ✅ CORRECT +client.searchAs[Parent]("SELECT id, name FROM parent") +``` + +--- + +### **6. Type Validation** + +#### **6.1. Type Incompatibility** + +##### **❌ Error Example** + +```scala +case class Product( + id: String, + name: String, + stock: Int // ❌ Wrong type (should be Long) +) + +// ❌ ERROR: The 'stock' field is cast to BIGINT in SQL +client.searchAs[Product]("SELECT id, name, stock::BIGINT FROM products") +``` + +##### **📋 Exact Error Message** + +``` +Type mismatch for field 'stock': SQL type BIGINT is incompatible with Scala type Int +Expected one of: Long, BigInt, Option[Long], Option[BigInt] +``` + +##### **✅ Solution** + +```scala +case class Product( + id: String, + name: String, + stock: Long // ✅ Correct type +) + +// ✅ CORRECT +client.searchAs[Product]("SELECT id, name, stock::BIGINT FROM products") +``` + +--- + +#### **6.2. 
Type Compatibility Table** + +| SQL Type | Compatible Scala Types | +|-------------------------|------------------------------------------------------------------------------------------------| +| `TINYINT` | `Byte`, `Short`, `Int`, `Long`, `Option[Byte]`, `Option[Short]`, `Option[Int]`, `Option[Long]` | +| `SMALLINT` | `Short`, `Int`, `Long`, `Option[Short]`, `Option[Int]`, `Option[Long]` | +| `INT` | `Int`, `Long`, `Option[Int]`, `Option[Long]` | +| `BIGINT` | `Long`, `BigInt`, `Option[Long]`, `Option[BigInt]` | +| `DOUBLE`, `REAL` | `Double`, `Float`, `Option[Double]`, `Option[Float]` | +| `VARCHAR` | `String`, `Option[String]` | +| `CHAR` | `String`, `Char`, `Option[String]`, `Option[Char]` | +| `BOOLEAN` | `Boolean`, `Option[Boolean]` | +| `TIME` | `java.time.LocalTime`, `java.time.Instant` | +| `DATE` | `java.time.LocalDate`, `java.time.Instant`, `java.util.Date` | +| `DATETIME`, `TIMESTAMP` | `java.time.LocalDateTime`, `java.time.ZonedDateTime`, `java.time.Instant` | +| `STRUCT` | Case Class | + +--- + +## **Best Practices** + +### **1. Always Explicitly Select Fields** + +❌ **Avoid**: +```scala +client.searchAs[Product]("SELECT * FROM products") +``` + +✅ **Prefer**: +```scala +client.searchAs[Product]("SELECT id, name, price FROM products") +``` + +--- + +### **2. Use UNNEST for Nested Collections and Objects** + +❌ **Avoid**: + +```scala +client.searchAs[Parent]("SELECT id, children.name FROM parent") +``` + +✅ **Prefer**: + +```scala +// Option 1: Select the complete collection +client.searchAs[Parent]("SELECT id, children FROM parent") + +// Option 2: Use UNNEST for filtering +client.searchAs[Parent]( + """SELECT id, children.name + FROM parent + JOIN UNNEST(parent.children) AS children""" +) +``` + +--- + +### **3. Make Fields Optional Only When Necessary** + +✅ **Simple Fields**: Can be made optional if not required + +```scala +case class User( + id: String, + name: String, + email: Option[String] = None // ✅ OK for simple fields +) +``` + +⚠️ **Nested Objects/Collections**: Don't make optional to bypass validation errors + +```scala +// ❌ BAD PRACTICE: Making nested optional to avoid error +case class User( + id: String, + name: String, + address: Option[Address] = None // ❌ Avoid if 'address' is required +) + +// ✅ GOOD PRACTICE: Fix the SQL query +client.searchAs[User]("SELECT id, name, address FROM users") +``` + +--- + +### **4. 
Use Default Values with Caution** + +✅ **For Simple Fields**: +```scala +case class Product( + id: String, + name: String, + price: Double = 0.0, // ✅ OK + inStock: Boolean = true // ✅ OK +) +``` + +❌ **For Nested Objects** (avoid): +```scala +case class User( + id: String, + name: String, + address: Address = Address("", "", "") // ❌ Avoid +) +``` + +--- + +## **Debug Configuration** + +### **Enable Debug Mode** + +```scala +// In build.sbt or command line +System.setProperty("elastic.sql.debug", "true") +``` + +### **Debug Output Example** + +``` +================================================================================ +🔍 Starting SQL Query Validation +================================================================================ +📝 Extracted SQL: SELECT id, name, address.street FROM users +🔍 Parsed fields: id, name, address.street +🔍 Unnested collections: +📋 Required fields for User (prefix=''): id, name, address +🔍 Checking field: id (type: String, optional: false, hasDefault: false) +✅ Field 'id' is directly selected +🔍 Checking field: name (type: String, optional: false, hasDefault: false) +✅ Field 'name' is directly selected +🔍 Checking field: address (type: Address, optional: false, hasDefault: false) +🏗️ Field 'address' is a nested object (non-collection) +❌ ERROR: Nested object field 'address' cannot be deserialized correctly. +``` + +--- + +## **Validation Rules Summary** + +| Rule | Behavior | Level | +|--------------------------------------|-------------------------------------------------------|------------| +| **SELECT \*** | Prohibited | ❌ ERROR | +| **Required field missing** | Must be added, made optional, or have a default value | ❌ ERROR | +| **Unknown field** | Warning (ignored during deserialization) | ⚠️ WARNING | +| **Nested object without UNNEST** | Must select complete object or use UNNEST | ❌ ERROR | +| **Nested collection without UNNEST** | Must select complete collection or use UNNEST | ❌ ERROR | +| **Type incompatibility** | Must use a compatible Scala type | ❌ ERROR | + +--- + +**This compile-time validation ensures the robustness and maintainability of your code! 🚀✅** From c763de777626221f4746dfa72c58636b57490bef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Manciot?= Date: Wed, 12 Nov 2025 07:21:50 +0100 Subject: [PATCH 10/11] update documentation for SQL Query validation --- README.md | 123 +++++++++++++++++- documentation/client/README.md | 1 - documentation/sql/README.md | 1 + .../sql_validation.md => sql/validation.md} | 0 .../sql/macros/SQLQueryValidator.scala | 6 - 5 files changed, 123 insertions(+), 8 deletions(-) rename documentation/{client/sql_validation.md => sql/validation.md} (100%) diff --git a/README.md b/README.md index 66433580..0ef820be 100644 --- a/README.md +++ b/README.md @@ -179,7 +179,9 @@ result match { --- -### **3. SQL to Elasticsearch Query Translation** +### **3. SQL compatible ** + +### **3.1 SQL to Elasticsearch Query DSL** SoftClient4ES includes a powerful SQL parser that translates standard SQL `SELECT` queries into native Elasticsearch queries. @@ -464,6 +466,125 @@ val results = client.search(SQLQuery(sqlQuery)) } } ``` +--- + +### **3.2. Compile-Time SQL Query Validation** + +SoftClient4ES provides **compile-time validation** for SQL queries used with type-safe methods like `searchAs[T]` and `scrollAs[T]`. This ensures that your queries are compatible with your Scala case classes **before your code even runs**, preventing runtime deserialization errors. 
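+
+For example, a minimal sketch (assuming a `Product` case class, a `client` instance, and an implicit `Formats` in scope; the names are illustrative):
+
+```scala
+case class Product(id: String, name: String, price: Double, stock: Int)
+
+// The SQL literal is checked against Product at compile time:
+// missing fields, typos and type mismatches fail the build
+// instead of surfacing as deserialization errors at runtime.
+val products = client.searchAs[Product]("SELECT id, name, price, stock FROM products")
+```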
+ +#### **Why Compile-Time Validation?** + +- ✅ **Catch Errors Early**: Detect missing fields, typos, and type mismatches at compile-time +- ✅ **Type Safety**: Ensure SQL queries match your domain models +- ✅ **Better Developer Experience**: Get helpful error messages with suggestions +- ✅ **Prevent Runtime Failures**: No more Jackson deserialization exceptions in production + +#### **Validated Operations** + +| Validation | Description | Level | +|------------------------|--------------------------------------------------------|------------| +| **SELECT * Rejection** | Prohibits `SELECT *` to ensure compile-time validation | ❌ ERROR | +| **Required Fields** | Verifies that all required fields are selected | ❌ ERROR | +| **Unknown Fields** | Detects fields that don't exist in the case class | ⚠️ WARNING | +| **Nested Objects** | Validates the structure of nested objects | ❌ ERROR | +| **Nested Collections** | Validates the use of UNNEST for collections | ❌ ERROR | +| **Type Compatibility** | Checks compatibility between SQL and Scala types | ❌ ERROR | + +#### **Example 1: Missing Required Field with Nested Object** + +```scala +case class Address( + street: String, + city: String, + country: String +) + +case class User( + id: String, + name: String, + address: Address // ❌ Required nested object +) + +// ❌ COMPILE ERROR: Missing required field 'address' +client.searchAs[User]("SELECT id, name FROM users") +``` + +**Compile Error:** + +``` +❌ SQL query does not select the required field: address + +Example query: +SELECT id, name, address FROM ... + +To fix this, either: + 1. Add it to the SELECT clause + 2. Make it Option[T] in the case class + 3. Provide a default value in the case class definition +``` + +**✅ Solution:** + +```scala +// Option 1: Select the entire nested object (recommended) +client.searchAs[User]("SELECT id, name, address FROM users") + +// Option 2: Make the field optional +case class User( + id: String, + name: String, + address: Option[Address] = None +) +client.searchAs[User]("SELECT id, name FROM users") +``` + +#### **Example 2: Typo Detection with Smart Suggestions** + +```scala +case class Product( + id: String, + name: String, + price: Double, + stock: Int +) + +// ❌ COMPILE ERROR: Typo in 'name' -> 'nam' +client.searchAs[Product]("SELECT id, nam, price, stock FROM products") +``` + +**Compile Error:** +``` +❌ SQL query does not select the required field: name +You have selected unknown field "nam", did you mean "name"? + +Example query: +SELECT id, price, stock, name FROM ... + +To fix this, either: + 1. Add it to the SELECT clause + 2. Make it Option[T] in the case class + 3. 
Provide a default value in the case class definition +``` + +**✅ Solution:** +```scala +// Fix the typo +client.searchAs[Product]("SELECT id, name, price, stock FROM products") +``` + +#### **Dynamic Queries (Skip Validation)** + +For dynamic SQL queries where validation isn't possible, use the `*Unchecked` variants: + +```scala +val dynamicQuery = buildQueryAtRuntime() + +// ✅ Skip compile-time validation for dynamic queries +client.searchAsUnchecked[Product](SQLQuery(dynamicQuery)) +client.scrollAsUnchecked[Product](dynamicQuery) +``` + +📖 **[Full SQL Validation Documentation](documentation/sql/validation.md)** 📖 **[Full SQL Documentation](documentation/sql/README.md)** diff --git a/documentation/client/README.md b/documentation/client/README.md index 5a94564d..fc37f028 100644 --- a/documentation/client/README.md +++ b/documentation/client/README.md @@ -3,7 +3,6 @@ Welcome to the Client Engine Documentation. Navigate through the sections below: - [Client Common Principles](common_principles.md) -- [SQL Query Validation](sql_validation.md) - [Version Information](version.md) - [Flush Index](flush.md) - [Refresh Index](refresh.md) diff --git a/documentation/sql/README.md b/documentation/sql/README.md index 0bf373cc..d383815f 100644 --- a/documentation/sql/README.md +++ b/documentation/sql/README.md @@ -3,6 +3,7 @@ Welcome to the SQL Engine Documentation. Navigate through the sections below: - [Query Structure](request_structure.md) +- [Query Validation](validation.md) - [Operators](operators.md) - [Operator Precedence](operator_precedence.md) - [Aggregate Functions](functions_aggregate.md) diff --git a/documentation/client/sql_validation.md b/documentation/sql/validation.md similarity index 100% rename from documentation/client/sql_validation.md rename to documentation/sql/validation.md diff --git a/macros/src/main/scala/app/softnetwork/elastic/sql/macros/SQLQueryValidator.scala b/macros/src/main/scala/app/softnetwork/elastic/sql/macros/SQLQueryValidator.scala index 8af3b927..7067908b 100644 --- a/macros/src/main/scala/app/softnetwork/elastic/sql/macros/SQLQueryValidator.scala +++ b/macros/src/main/scala/app/softnetwork/elastic/sql/macros/SQLQueryValidator.scala @@ -424,9 +424,6 @@ trait SQLQueryValidator { | FROM ... | JOIN UNNEST(....$fullFieldName) AS $fieldName | - |✅ Solution 3: Make the nested object optional (if it's not always needed) - | case class ${tpe.typeSymbol.name}(..., $fieldName: Option[${fieldType.typeSymbol.name}] = None) - | |📚 Note: This applies to ALL nested objects, not just collections. |""".stripMargin ) @@ -519,9 +516,6 @@ trait SQLQueryValidator { | FROM ... 
| JOIN UNNEST(....$fullFieldName) AS $fieldName | - |✅ Solution 3: Make the collection optional (if it's not always needed) - | case class ${parentType.typeSymbol.name}(..., $fieldName: Option[List[${elementType.typeSymbol.name}]] = None) - | |📚 Documentation: | https://www.elastic.co/guide/en/elasticsearch/reference/current/nested.html |""".stripMargin From a824d33a02e0c19c6b3d67dc1d966e1853ff63d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Manciot?= Date: Wed, 12 Nov 2025 07:45:01 +0100 Subject: [PATCH 11/11] update documentation for searching and scrolling apis --- README.md | 2 +- documentation/client/scroll.md | 52 ++++++++++++++++++++++++++-------- documentation/client/search.md | 46 ++++++++++++++++++++++++++++-- 3 files changed, 85 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index 0ef820be..aca96cf8 100644 --- a/README.md +++ b/README.md @@ -99,7 +99,7 @@ val searchResult = client.search(SQLQuery("SELECT * FROM users WHERE age > 25")) case class Product(id: String, name: String, price: Double, category: String, obsolete: Boolean) // Scroll through large datasets -val obsoleteProducts: Source[Product, NotUsed] = client.scrollAs[Product]( +val obsoleteProducts: Source[Product, NotUsed] = client.scrollAsUnchecked[Product]( """ |SELECT uuid AS id, name, price, category, outdated AS obsolete FROM products WHERE outdated = true |""".stripMargin diff --git a/documentation/client/scroll.md b/documentation/client/scroll.md index 61c5b35e..ee2fbe17 100644 --- a/documentation/client/scroll.md +++ b/documentation/client/scroll.md @@ -90,7 +90,7 @@ def scroll( )(implicit system: ActorSystem): Source[(Map[String, Any], ScrollMetrics), NotUsed] // Typed scroll source (automatic deserialization) -def scrollAs[T]( +def scrollAsUnchecked[T]( sql: SQLQuery, config: ScrollConfig = ScrollConfig() )(implicit @@ -394,7 +394,7 @@ val query = SQLQuery( ) // Scroll with automatic type conversion -client.scrollAs[Product](query).runWith(Sink.foreach { case (product, metrics) => +client.scrollAsUnchecked[Product](query).runWith(Sink.foreach { case (product, metrics) => println(s"Product: ${product.name} - $${product.price}") println(s"Progress: ${metrics.totalDocuments} products") }) @@ -407,7 +407,7 @@ client.scrollAs[Product](query).runWith(Sink.foreach { case (product, metrics) = ```scala // Collect all products val allProducts: Future[Seq[Product]] = - client.scrollAs[Product](query) + client.scrollAsUnchecked[Product](query) .map(_._1) // Extract product, discard metrics .runWith(Sink.seq) @@ -425,7 +425,7 @@ allProducts.foreach { products => ```scala // Filter expensive products during streaming -client.scrollAs[Product](query) +client.scrollAsUnchecked[Product](query) .filter { case (product, _) => product.price > 500 } .map(_._1) // Extract product .runWith(Sink.seq) @@ -444,6 +444,30 @@ client.scrollAs[Product](query) ```scala case class ProductSummary(name: String, value: Double) +client.scrollAsUnchecked[Product](query) + .map { case (product, _) => + ProductSummary( + name = product.name, + value = product.price * product.stock + ) + } + .runWith(Sink.seq) + .foreach { summaries => + val totalValue = summaries.map(_.value).sum + println(f"Total inventory value: $$${totalValue}%,.2f") + } +``` + +### Validating Query at compile-time + +```scala +val query = + """ + SELECT id, name, price, category, stock + FROM products + WHERE category = 'electronics' + """ + client.scrollAs[Product](query) .map { case (product, _) => ProductSummary( @@ -456,8 +480,12 @@ 
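+// scrollAs validates this query against the Product case class at compile time
+// (see ../sql/validation.md). For SQL assembled dynamically at runtime, use
+// scrollAsUnchecked instead, which skips the compile-time check.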
client.scrollAs[Product](query) val totalValue = summaries.map(_.value).sum println(f"Total inventory value: $$${totalValue}%,.2f") } +) + ``` +📖 **[Full SQL Validation Documentation](../sql/validation.md)** + --- ## Metrics and Monitoring @@ -807,7 +835,7 @@ def commitBatch(size: Int): Future[Unit] = { case class RawProduct(id: String, name: String, price: Double) case class EnrichedProduct(id: String, name: String, price: Double, category: String, tags: Seq[String]) -client.scrollAs[RawProduct](query) +client.scrollAsUnchecked[RawProduct](query) .mapAsync(parallelism = 4) { case (raw, _) => // Enrich each product enrichProduct(raw) @@ -904,7 +932,7 @@ case class Statistics( ) } -client.scrollAs[Product](query) +client.scrollAsUnchecked[Product](query) .map(_._1.price) // Extract prices .fold(Statistics())(_ update _) .runWith(Sink.head) @@ -922,7 +950,7 @@ client.scrollAs[Product](query) ### Conditional Processing ```scala -client.scrollAs[Product](query) +client.scrollAsUnchecked[Product](query) .mapAsync(parallelism = 4) { case (product, _) => product.category match { case "electronics" => processElectronics(product) @@ -1003,7 +1031,7 @@ class ScrollApiSpec extends AsyncFlatSpec with Matchers { // Test query = SQLQuery(query = s"SELECT id, value FROM $testIndex") - results <- client.scrollAs[TestDoc](query).map(_._1).runWith(Sink.seq) + results <- client.scrollAsUnchecked[TestDoc](query).map(_._1).runWith(Sink.seq) // Assertions _ = { @@ -1259,7 +1287,7 @@ client.scroll(query).map { case (doc, _) => // ✅ GOOD: Automatic type conversion implicit val formats: Formats = DefaultFormats -client.scrollAs[Product](query) +client.scrollAsUnchecked[Product](query) .map(_._1) .runWith(Sink.seq) ``` @@ -1556,7 +1584,7 @@ case class ValidationResult( ) def validateData(query: SQLQuery): Future[ValidationResult] = { - client.scrollAs[Product](query) + client.scrollAsUnchecked[Product](query) .map(_._1) .runWith(Sink.fold(ValidationResult(0, 0, Seq.empty)) { (result, product) => if (isValid(product)) { @@ -1598,7 +1626,7 @@ case class CategoryStats( ) def aggregateByCategory(query: SQLQuery): Future[Map[String, CategoryStats]] = { - client.scrollAs[Product](query) + client.scrollAsUnchecked[Product](query) .map(_._1) .runWith(Sink.fold(Map.empty[String, CategoryStats]) { (stats, product) => val current = stats.getOrElse( @@ -1645,7 +1673,7 @@ case class EnrichedOrder( ) def transformOrders(query: SQLQuery): Future[Seq[EnrichedOrder]] = { - client.scrollAs[RawOrder](query) + client.scrollAsUnchecked[RawOrder](query) .map(_._1) .mapAsync(parallelism = 4) { order => // Enrich with customer data diff --git a/documentation/client/search.md b/documentation/client/search.md index 5c6b0ca3..5dbdc8c6 100644 --- a/documentation/client/search.md +++ b/documentation/client/search.md @@ -14,10 +14,12 @@ - [singleSearchAsync](#singlesearchasync) - [multiSearchAsync](#multisearchasync) - [Search with Type Conversion](#search-with-type-conversion) + - [searchAsUnchecked](#searchasunchecked) - [searchAs](#searchas) - [singleSearchAs](#singlesearchas) - [multisearchAs](#multisearchas) - [Asynchronous Search with Type Conversion](#asynchronous-search-with-type-conversion) + - [searchAsyncAsUnchecked](#searchasyncasunchecked) - [searchAsyncAs](#searchasyncas) - [singleSearchAsyncAs](#singlesearchasyncas) - [multiSearchAsyncAs](#multisearchasyncas) @@ -611,7 +613,7 @@ client.multiSearchAsync(queries, Map.empty, Map.empty).foreach { ## Search with Type Conversion -### searchAs +### searchAsUnchecked Searches and 
automatically converts results to typed entities using an SQL query. @@ -706,6 +708,29 @@ val result: ElasticResult[List[EnrichedProduct]] = for { --- +### searchAs + +Searches and automatically converts results to typed entities using an SQL query [validated at compile-time](../sql/validation.md). + +**Signature:** + +```scala +def searchAs[U]( + query: String +)(implicit m: Manifest[U], formats: Formats): ElasticResult[Seq[U]] +``` + +**Parameters:** +- `query` - SQL query +- `m` - Implicit Manifest for type information +- `formats` - Implicit JSON serialization formats + +**Returns:** +- `ElasticSuccess[Seq[U]]` with typed entities +- `ElasticFailure` with conversion or search errors + +--- + ### singleSearchAs Searches and converts results to typed entities using an Elasticsearch query. @@ -818,7 +843,7 @@ client.multisearchAs[Product](queries, Map.empty, Map.empty) match { ## Asynchronous Search with Type Conversion -### searchAsyncAs +### searchAsyncAsUnchecked Asynchronously searches and converts results to typed entities. @@ -880,6 +905,23 @@ Future.sequence(futures).map { results => } } ``` +--- + +### searchAsyncAs + +Asynchronously searches and converts results to typed entities using an SQL query [validated at compile-time](../sql/validation.md). + +**Signature:** + +```scala +def searchAsyncAs[U]( + query: String +)(implicit + m: Manifest[U], + ec: ExecutionContext, + formats: Formats +): Future[ElasticResult[Seq[U]]] +``` ---
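+
+**Example** (a sketch for `searchAsyncAs` above, assuming a `Product` case class, a `client` instance, and implicit `ExecutionContext`/`Formats` in scope; names are illustrative):
+
+```scala
+case class Product(id: String, name: String, price: Double)
+
+// The SQL string is validated against Product at compile time;
+// the search itself runs asynchronously and completes the Future.
+val futureProducts: Future[ElasticResult[Seq[Product]]] =
+  client.searchAsyncAs[Product]("SELECT id, name, price FROM products")
+
+futureProducts.foreach { result =>
+  println(s"Async search result: $result") // ElasticResult[Seq[Product]]
+}
+```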