From 43eec09348334578027747034b26ac1b5e329d15 Mon Sep 17 00:00:00 2001 From: Alex Pavlychev Date: Sat, 6 Jul 2024 18:18:47 +0300 Subject: [PATCH] Knn query lacks inner_hits field (#3101) * close https://github.com/Philippus/elastic4s/issues/3100 * [review] tests --- .../queries/InnerHitQueryBodyFnTest.scala | 3 +- .../searches/queries/KnnBuilderFnTest.scala | 19 +++ .../elastic4s/requests/searches/knn/Knn.scala | 7 +- .../requests/searches/queries/InnerHit.scala | 6 +- .../handlers/searches/knn/KnnBuilderFn.scala | 2 + .../nested/InnerHitQueryBodyBuilder.scala | 3 + .../elastic4s/search/knn/ChunkKnnTest.scala | 142 ++++++++++++++++++ 7 files changed, 178 insertions(+), 4 deletions(-) create mode 100644 elastic4s-tests/src/test/scala/com/sksamuel/elastic4s/search/knn/ChunkKnnTest.scala diff --git a/elastic4s-core/src/test/scala/com/sksamuel/elastic4s/requests/searches/queries/InnerHitQueryBodyFnTest.scala b/elastic4s-core/src/test/scala/com/sksamuel/elastic4s/requests/searches/queries/InnerHitQueryBodyFnTest.scala index 64e68f0bf..74b65f9a1 100644 --- a/elastic4s-core/src/test/scala/com/sksamuel/elastic4s/requests/searches/queries/InnerHitQueryBodyFnTest.scala +++ b/elastic4s-core/src/test/scala/com/sksamuel/elastic4s/requests/searches/queries/InnerHitQueryBodyFnTest.scala @@ -20,8 +20,9 @@ class InnerHitQueryBodyFnTest extends AnyFunSuite with Matchers { .sortBy(FieldSort("sortField")) .storedFieldNames(List("field1", "field2")) .highlighting(HighlightField("hlField")) + .fields(List("f1", "f2")) new XContentBuilder(InnerHitQueryBodyBuilder.toJson(q)).string shouldBe - """{"name":"inners","from":2,"explain":false,"track_scores":true,"version":true,"size":2,"docvalue_fields":["df1","df2"],"sort":[{"sortField":{"order":"asc"}}],"stored_fields":["field1","field2"],"highlight":{"fields":{"hlField":{}}}}""" + """{"name":"inners","from":2,"explain":false,"track_scores":true,"version":true,"size":2,"docvalue_fields":["df1","df2"],"sort":[{"sortField":{"order":"asc"}}],"stored_fields":["field1","field2"],"fields":["f1","f2"],"highlight":{"fields":{"hlField":{}}}}""" } } diff --git a/elastic4s-core/src/test/scala/com/sksamuel/elastic4s/requests/searches/queries/KnnBuilderFnTest.scala b/elastic4s-core/src/test/scala/com/sksamuel/elastic4s/requests/searches/queries/KnnBuilderFnTest.scala index 9f3506692..ddb454b75 100644 --- a/elastic4s-core/src/test/scala/com/sksamuel/elastic4s/requests/searches/queries/KnnBuilderFnTest.scala +++ b/elastic4s-core/src/test/scala/com/sksamuel/elastic4s/requests/searches/queries/KnnBuilderFnTest.scala @@ -1,7 +1,9 @@ package com.sksamuel.elastic4s.requests.searches.queries import com.sksamuel.elastic4s.handlers.searches.knn.KnnBuilderFn +import com.sksamuel.elastic4s.requests.searches.HighlightField import com.sksamuel.elastic4s.requests.searches.knn.Knn +import com.sksamuel.elastic4s.requests.searches.sort.FieldSort import com.sksamuel.elastic4s.requests.searches.term.TermQuery import org.scalatest.funsuite.AnyFunSuite import org.scalatest.matchers.should.Matchers @@ -18,4 +20,21 @@ class KnnBuilderFnTest extends AnyFunSuite with Matchers { KnnBuilderFn(request).string shouldBe """{"field":"image-vector","query_vector":[54.0,10.0,-2.0],"k":5,"num_candidates":50,"similarity":10.0,"filter":{"term":{"file-type":{"value":"png"}}},"boost":0.4}""" } + test("Knn with inner_hits generates proper query.") { + val innerHit = InnerHit("inners") + .from(2) + .explain(false) + .trackScores(true) + .version(true) + .size(2) + .docValueFields(List("df1", "df2")) + .sortBy(FieldSort("sortField")) + .storedFieldNames(List("field1", "field2")) + .highlighting(HighlightField("hlField")) + .fields(List("f1", "f2")) + + val request = Knn("image-vector", 50, Seq(54,10,-2), inner = Some(innerHit)) k 5 filter TermQuery("file-type", "png") similarity 10 boost .4 + KnnBuilderFn(request).string shouldBe + """{"field":"image-vector","query_vector":[54.0,10.0,-2.0],"k":5,"num_candidates":50,"similarity":10.0,"filter":{"term":{"file-type":{"value":"png"}}},"boost":0.4,"inner_hits":{"name":"inners","from":2,"explain":false,"track_scores":true,"version":true,"size":2,"docvalue_fields":["df1","df2"],"sort":[{"sortField":{"order":"asc"}}],"stored_fields":["field1","field2"],"fields":["f1","f2"],"highlight":{"fields":{"hlField":{}}}}}""" + } } diff --git a/elastic4s-domain/src/main/scala/com/sksamuel/elastic4s/requests/searches/knn/Knn.scala b/elastic4s-domain/src/main/scala/com/sksamuel/elastic4s/requests/searches/knn/Knn.scala index 814ae393d..f151ca671 100644 --- a/elastic4s-domain/src/main/scala/com/sksamuel/elastic4s/requests/searches/knn/Knn.scala +++ b/elastic4s-domain/src/main/scala/com/sksamuel/elastic4s/requests/searches/knn/Knn.scala @@ -1,7 +1,7 @@ package com.sksamuel.elastic4s.requests.searches.knn import com.sksamuel.elastic4s.ext.OptionImplicits.RichOptionImplicits -import com.sksamuel.elastic4s.requests.searches.queries.Query +import com.sksamuel.elastic4s.requests.searches.queries.{InnerHit, Query} case class Knn( field: String, @@ -10,7 +10,8 @@ case class Knn( k: Int = 1, similarity: Option[Float] = None, filter: Option[Query] = None, - boost: Double = 1.0) { + boost: Double = 1.0, + inner: Option[InnerHit] = None) { def k(k: Int): Knn = copy(k = k) @@ -19,4 +20,6 @@ case class Knn( def filter(filter: Query): Knn = copy(filter = filter.some) def boost(boost: Double): Knn = copy(boost = boost) + + def inner(inner: InnerHit): Knn = copy(inner = Option(inner)) } diff --git a/elastic4s-domain/src/main/scala/com/sksamuel/elastic4s/requests/searches/queries/InnerHit.scala b/elastic4s-domain/src/main/scala/com/sksamuel/elastic4s/requests/searches/queries/InnerHit.scala index 6307bab71..3f622a7a3 100644 --- a/elastic4s-domain/src/main/scala/com/sksamuel/elastic4s/requests/searches/queries/InnerHit.scala +++ b/elastic4s-domain/src/main/scala/com/sksamuel/elastic4s/requests/searches/queries/InnerHit.scala @@ -15,7 +15,8 @@ case class InnerHit(name: String, docValueFields: Seq[String] = Nil, sorts: Seq[Sort] = Nil, from: Option[Int] = None, - highlight: Option[Highlight] = None) { + highlight: Option[Highlight] = None, + fields: Seq[String] = Nil) { def sortBy(sorts: Sort*): InnerHit = sortBy(sorts) def sortBy(sorts: Iterable[Sort]): InnerHit = copy(sorts = sorts.toSeq) @@ -45,4 +46,7 @@ case class InnerHit(name: String, copy(storedFieldNames = storedFieldNames.toSeq) def explain(explain: Boolean): InnerHit = copy(explain = explain.some) + + def fields(fields: Iterable[String]): InnerHit = + copy(fields = fields.toSeq) } diff --git a/elastic4s-handlers/src/main/scala/com/sksamuel/elastic4s/handlers/searches/knn/KnnBuilderFn.scala b/elastic4s-handlers/src/main/scala/com/sksamuel/elastic4s/handlers/searches/knn/KnnBuilderFn.scala index f6908399e..8849a0d82 100644 --- a/elastic4s-handlers/src/main/scala/com/sksamuel/elastic4s/handlers/searches/knn/KnnBuilderFn.scala +++ b/elastic4s-handlers/src/main/scala/com/sksamuel/elastic4s/handlers/searches/knn/KnnBuilderFn.scala @@ -1,6 +1,7 @@ package com.sksamuel.elastic4s.handlers.searches.knn import com.sksamuel.elastic4s.handlers.searches.queries.QueryBuilderFn +import com.sksamuel.elastic4s.handlers.searches.queries.nested.InnerHitQueryBodyBuilder import com.sksamuel.elastic4s.json.{XContentBuilder, XContentFactory} import com.sksamuel.elastic4s.requests.searches.knn.Knn @@ -17,6 +18,7 @@ object KnnBuilderFn { } knn.filter.foreach(filter => builder.rawField("filter", QueryBuilderFn(filter))) builder.field("boost", knn.boost) + knn.inner.foreach(inner => builder.field("inner_hits", InnerHitQueryBodyBuilder.toJson(inner))) builder.endObject() builder } diff --git a/elastic4s-handlers/src/main/scala/com/sksamuel/elastic4s/handlers/searches/queries/nested/InnerHitQueryBodyBuilder.scala b/elastic4s-handlers/src/main/scala/com/sksamuel/elastic4s/handlers/searches/queries/nested/InnerHitQueryBodyBuilder.scala index 7be8f59d6..9b1be0ecc 100644 --- a/elastic4s-handlers/src/main/scala/com/sksamuel/elastic4s/handlers/searches/queries/nested/InnerHitQueryBodyBuilder.scala +++ b/elastic4s-handlers/src/main/scala/com/sksamuel/elastic4s/handlers/searches/queries/nested/InnerHitQueryBodyBuilder.scala @@ -34,6 +34,9 @@ object InnerHitQueryBodyBuilder extends BodyBuilder[InnerHit] { if (d.storedFieldNames.nonEmpty) builder.array("stored_fields", d.storedFieldNames.toArray) + if (d.fields.nonEmpty) + builder.array("fields", d.fields.toArray) + d.highlight.foreach { highlight => builder.rawField("highlight", searches.HighlightBuilderFn(highlight)) } diff --git a/elastic4s-tests/src/test/scala/com/sksamuel/elastic4s/search/knn/ChunkKnnTest.scala b/elastic4s-tests/src/test/scala/com/sksamuel/elastic4s/search/knn/ChunkKnnTest.scala new file mode 100644 index 000000000..521b222ba --- /dev/null +++ b/elastic4s-tests/src/test/scala/com/sksamuel/elastic4s/search/knn/ChunkKnnTest.scala @@ -0,0 +1,142 @@ +package com.sksamuel.elastic4s.search.knn + +import com.sksamuel.elastic4s.fields.{DenseVectorField, DotProduct} +import com.sksamuel.elastic4s.requests.common.{FetchSourceContext, RefreshPolicy} +import com.sksamuel.elastic4s.requests.searches.queries.InnerHit +import com.sksamuel.elastic4s.testkit.DockerTests +import org.scalatest.BeforeAndAfterAll +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers + +import scala.util.Try + +class ChunkKnnTest extends AnyFlatSpec with Matchers with DockerTests with BeforeAndAfterAll { + + private val INDEX = "chunk-knn-index" + private val FULL_TEXT_FIELD = "full_text" + private val CREATION_TIME_FIELD = "creation_time" + private val PARAGRAPH_FIELD = "paragraph" + private val VECTOR_FIELD = "vector" + private val TEXT_FIELD = "text" + private val PARAGRAPH_ID_FIELD = "paragraph_id" + + override protected def afterAll() = { + Try { + client.execute { + deleteIndex(INDEX) + }.await + } + } + + override protected def beforeAll() = { + Try { + client.execute { + deleteIndex(INDEX) + }.await + } + + client.execute { + createIndex(INDEX) mapping { + properties( + textField(FULL_TEXT_FIELD), + dateField(CREATION_TIME_FIELD), + nestedField(PARAGRAPH_FIELD).properties( + DenseVectorField( + name = VECTOR_FIELD, + dims = 2, + index = true + ), + textField(TEXT_FIELD).index(false) + ) + ) + } + }.await + + client.execute { + bulk( + indexInto(INDEX) id "1" fields(FULL_TEXT_FIELD -> "first paragraph another paragraph", CREATION_TIME_FIELD -> "2019-05-04", PARAGRAPH_FIELD -> Seq(Map(TEXT_FIELD -> "first paragraph", VECTOR_FIELD -> Seq(0.45, 45), PARAGRAPH_ID_FIELD -> "1"), Map(TEXT_FIELD -> "another paragraph", VECTOR_FIELD -> Seq(0.8, 0.6), PARAGRAPH_ID_FIELD -> "2"))), + indexInto(INDEX) id "2" fields(FULL_TEXT_FIELD -> "number one paragraph number two paragraph", CREATION_TIME_FIELD -> "2020-05-04", PARAGRAPH_FIELD -> Seq(Map(TEXT_FIELD -> "number one paragraph", VECTOR_FIELD -> Seq(1.2, 4.5), PARAGRAPH_ID_FIELD -> "1"), Map(TEXT_FIELD -> "number two paragraph", VECTOR_FIELD -> Seq(-1, 42), PARAGRAPH_ID_FIELD -> "2"))) + ).refresh(RefreshPolicy.Immediate) + }.await + } + + "knn search over nested dense_vectors" should "always diversify the top results over the top-level document" in { + val resp = client.execute( + search(INDEX) + .fetchSource(false) + .sourceInclude(FULL_TEXT_FIELD, CREATION_TIME_FIELD) + .knn { + knnQuery( + field = PARAGRAPH_FIELD + "." + VECTOR_FIELD, + vector = Seq(0.45, 45), + numCandidates = 2 + ).k(2) + } + ).await.result + + resp.totalHits shouldBe(2) + resp.hits.hits.map(_.id).toSet shouldBe(Set("1", "2")) + } + + "knn search with filter" should "always be over the top-level document metadata" in { + val resp = client.execute( + search(INDEX) + .fetchSource(false) + .sourceInclude(FULL_TEXT_FIELD, CREATION_TIME_FIELD) + .knn { + knnQuery( + field = PARAGRAPH_FIELD + "." + VECTOR_FIELD, + vector = Seq(0.45, 45), + numCandidates = 2 + ) + .k(2) + .filter { + rangeQuery(CREATION_TIME_FIELD) + .gte("2019-05-01") + .lte("2019-05-05") + } + } + ).await.result + + resp.totalHits shouldBe(1) + resp.hits.hits.map(_.id).toSet shouldBe(Set("1")) + } + + "knn search" should "contain the nearest found paragraph when searching" in { + val resp = client.execute( + search(INDEX) + .fetchSource(false) + .sourceInclude(FULL_TEXT_FIELD, CREATION_TIME_FIELD) + .knn { + knnQuery( + field = PARAGRAPH_FIELD + "." + VECTOR_FIELD, + vector = Seq(0.45, 45), + numCandidates = 2 + ).k(2) + .inner(InnerHit(PARAGRAPH_FIELD) + .fetchSource(FetchSourceContext(fetchSource = false, includes = Set(PARAGRAPH_FIELD + "." + TEXT_FIELD))) + .size(1) + .fields(Seq(PARAGRAPH_FIELD + "." + TEXT_FIELD)) + ) + } + ).await.result + + resp.totalHits shouldBe(2) + resp.hits.hits.map(_.id).toSet shouldBe(Set("1", "2")) + resp.hits.hits.map(_.innerHits.get(PARAGRAPH_FIELD).fold(Seq.empty[String]) { + _.hits.flatMap { hit => + val texts = hit + .docValueFieldOpt(PARAGRAPH_FIELD) + .fold[Seq[String]](Seq.empty) { + _.values.flatMap { v => + Try { + v.asInstanceOf[Map[String, Seq[String]]] + .getOrElse(TEXT_FIELD, Seq.empty) + }.getOrElse(Seq.empty) + } + } + texts + } + }).toSet shouldBe Set(List("first paragraph"), List("number two paragraph")) + } +}