Skip to content

Commit

Permalink
Knn query lacks inner_hits field (#3101)
Browse files Browse the repository at this point in the history
* close #3100

* [review] tests
  • Loading branch information
apavlychev committed Jul 6, 2024
1 parent e91e2a9 commit 43eec09
Show file tree
Hide file tree
Showing 7 changed files with 178 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ class InnerHitQueryBodyFnTest extends AnyFunSuite with Matchers {
.sortBy(FieldSort("sortField"))
.storedFieldNames(List("field1", "field2"))
.highlighting(HighlightField("hlField"))
.fields(List("f1", "f2"))

new XContentBuilder(InnerHitQueryBodyBuilder.toJson(q)).string shouldBe
"""{"name":"inners","from":2,"explain":false,"track_scores":true,"version":true,"size":2,"docvalue_fields":["df1","df2"],"sort":[{"sortField":{"order":"asc"}}],"stored_fields":["field1","field2"],"highlight":{"fields":{"hlField":{}}}}"""
"""{"name":"inners","from":2,"explain":false,"track_scores":true,"version":true,"size":2,"docvalue_fields":["df1","df2"],"sort":[{"sortField":{"order":"asc"}}],"stored_fields":["field1","field2"],"fields":["f1","f2"],"highlight":{"fields":{"hlField":{}}}}"""
}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package com.sksamuel.elastic4s.requests.searches.queries

import com.sksamuel.elastic4s.handlers.searches.knn.KnnBuilderFn
import com.sksamuel.elastic4s.requests.searches.HighlightField
import com.sksamuel.elastic4s.requests.searches.knn.Knn
import com.sksamuel.elastic4s.requests.searches.sort.FieldSort
import com.sksamuel.elastic4s.requests.searches.term.TermQuery
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.matchers.should.Matchers
Expand All @@ -18,4 +20,21 @@ class KnnBuilderFnTest extends AnyFunSuite with Matchers {
KnnBuilderFn(request).string shouldBe
"""{"field":"image-vector","query_vector":[54.0,10.0,-2.0],"k":5,"num_candidates":50,"similarity":10.0,"filter":{"term":{"file-type":{"value":"png"}}},"boost":0.4}"""
}
test("Knn with inner_hits generates proper query.") {
  // Inner-hit definition carrying paging, scoring, field-selection, sort and highlight options.
  val nestedHits =
    InnerHit("inners")
      .from(2)
      .explain(false)
      .trackScores(true)
      .version(true)
      .size(2)
      .docValueFields(List("df1", "df2"))
      .sortBy(FieldSort("sortField"))
      .storedFieldNames(List("field1", "field2"))
      .highlighting(HighlightField("hlField"))
      .fields(List("f1", "f2"))

  // Same builder calls as the plain Knn test, plus the inner-hit definition passed via `inner`.
  val knnRequest =
    Knn("image-vector", 50, Seq(54, 10, -2), inner = Some(nestedHits))
      .k(5)
      .filter(TermQuery("file-type", "png"))
      .similarity(10)
      .boost(.4)

  // The inner-hit body is expected under the "inner_hits" key, after "boost".
  val expectedJson =
    """{"field":"image-vector","query_vector":[54.0,10.0,-2.0],"k":5,"num_candidates":50,"similarity":10.0,"filter":{"term":{"file-type":{"value":"png"}}},"boost":0.4,"inner_hits":{"name":"inners","from":2,"explain":false,"track_scores":true,"version":true,"size":2,"docvalue_fields":["df1","df2"],"sort":[{"sortField":{"order":"asc"}}],"stored_fields":["field1","field2"],"fields":["f1","f2"],"highlight":{"fields":{"hlField":{}}}}}"""

  KnnBuilderFn(knnRequest).string shouldBe expectedJson
}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package com.sksamuel.elastic4s.requests.searches.knn

import com.sksamuel.elastic4s.ext.OptionImplicits.RichOptionImplicits
import com.sksamuel.elastic4s.requests.searches.queries.Query
import com.sksamuel.elastic4s.requests.searches.queries.{InnerHit, Query}

case class Knn(
field: String,
Expand All @@ -10,7 +10,8 @@ case class Knn(
k: Int = 1,
similarity: Option[Float] = None,
filter: Option[Query] = None,
boost: Double = 1.0) {
boost: Double = 1.0,
inner: Option[InnerHit] = None) {

/** Returns a copy of this request whose `k` is replaced by the given value. */
def k(k: Int): Knn = this.copy(k = k)

Expand All @@ -19,4 +20,6 @@ case class Knn(
/** Returns a copy of this request with the given query set as its (optional) filter. */
def filter(filter: Query): Knn = {
  val wrappedFilter = filter.some
  copy(filter = wrappedFilter)
}

/** Returns a copy of this request whose `boost` is replaced by the given value. */
def boost(boost: Double): Knn = this.copy(boost = boost)

/** Returns a copy of this request with the given inner-hit definition attached.
  *
  * The argument is wrapped with `Option(...)`, so a `null` argument yields `None`.
  */
def inner(inner: InnerHit): Knn = {
  val innerHit = Option(inner)
  copy(inner = innerHit)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ case class InnerHit(name: String,
docValueFields: Seq[String] = Nil,
sorts: Seq[Sort] = Nil,
from: Option[Int] = None,
highlight: Option[Highlight] = None) {
highlight: Option[Highlight] = None,
fields: Seq[String] = Nil) {

/** Varargs convenience; delegates to the `Iterable` overload below. */
def sortBy(sorts: Sort*): InnerHit = sortBy(sorts)

/** Returns a copy of this inner hit whose sorts are replaced by the given ones. */
def sortBy(sorts: Iterable[Sort]): InnerHit = {
  val replacement = sorts.toSeq
  copy(sorts = replacement)
}
Expand Down Expand Up @@ -45,4 +46,7 @@ case class InnerHit(name: String,
copy(storedFieldNames = storedFieldNames.toSeq)

/** Returns a copy of this inner hit with the `explain` flag set to the given value. */
def explain(explain: Boolean): InnerHit = {
  val flag = explain.some
  copy(explain = flag)
}

/** Returns a copy of this inner hit whose `fields` are replaced by the given names. */
def fields(fields: Iterable[String]): InnerHit = {
  val fieldNames = fields.toSeq
  copy(fields = fieldNames)
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.sksamuel.elastic4s.handlers.searches.knn

import com.sksamuel.elastic4s.handlers.searches.queries.QueryBuilderFn
import com.sksamuel.elastic4s.handlers.searches.queries.nested.InnerHitQueryBodyBuilder
import com.sksamuel.elastic4s.json.{XContentBuilder, XContentFactory}
import com.sksamuel.elastic4s.requests.searches.knn.Knn

Expand All @@ -17,6 +18,7 @@ object KnnBuilderFn {
}
knn.filter.foreach(filter => builder.rawField("filter", QueryBuilderFn(filter)))
builder.field("boost", knn.boost)
knn.inner.foreach(inner => builder.field("inner_hits", InnerHitQueryBodyBuilder.toJson(inner)))
builder.endObject()
builder
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ object InnerHitQueryBodyBuilder extends BodyBuilder[InnerHit] {
if (d.storedFieldNames.nonEmpty)
builder.array("stored_fields", d.storedFieldNames.toArray)

if (d.fields.nonEmpty)
builder.array("fields", d.fields.toArray)

d.highlight.foreach { highlight =>
builder.rawField("highlight", searches.HighlightBuilderFn(highlight))
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
package com.sksamuel.elastic4s.search.knn

import com.sksamuel.elastic4s.fields.{DenseVectorField, DotProduct}
import com.sksamuel.elastic4s.requests.common.{FetchSourceContext, RefreshPolicy}
import com.sksamuel.elastic4s.requests.searches.queries.InnerHit
import com.sksamuel.elastic4s.testkit.DockerTests
import org.scalatest.BeforeAndAfterAll
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers

import scala.util.Try

/** Integration test for kNN search over dense vectors stored inside nested "paragraph" objects.
  *
  * `beforeAll` creates an index, loads two documents with two nested paragraphs each, and
  * `afterAll` removes the index again. The tests then check that
  *   - top-level hits are returned for a nested kNN query,
  *   - a kNN filter is applied to top-level document fields, and
  *   - `inner_hits` on the kNN query expose the paragraph that matched.
  *
  * NOTE(review): `client` and the request DSL are presumably supplied by `DockerTests`, so a running
  * Elasticsearch instance is required - confirm against the testkit.
  */
class ChunkKnnTest extends AnyFlatSpec with Matchers with DockerTests with BeforeAndAfterAll {

  private val INDEX = "chunk-knn-index"
  private val FULL_TEXT_FIELD = "full_text"
  private val CREATION_TIME_FIELD = "creation_time"
  private val PARAGRAPH_FIELD = "paragraph"
  private val VECTOR_FIELD = "vector"
  private val TEXT_FIELD = "text"
  private val PARAGRAPH_ID_FIELD = "paragraph_id"

  // Full paths of the sub-fields inside the nested paragraph objects.
  private val NESTED_VECTOR_FIELD = PARAGRAPH_FIELD + "." + VECTOR_FIELD
  private val NESTED_TEXT_FIELD = PARAGRAPH_FIELD + "." + TEXT_FIELD

  // Query vector shared by all tests; identical to the vector of document 1's first paragraph.
  private val QUERY_VECTOR: Seq[Double] = Seq(0.45, 45.0)

  /** Best-effort removal of the test index; any failure is deliberately swallowed by the `Try`. */
  private def deleteIndexQuietly(): Unit = {
    Try {
      client.execute {
        deleteIndex(INDEX)
      }.await
    }
  }

  /** Search request shared by every test: source fetching options only, no query yet. */
  private def baseSearch =
    search(INDEX)
      .fetchSource(false)
      .sourceInclude(FULL_TEXT_FIELD, CREATION_TIME_FIELD)

  /** kNN clause shared by every test: two candidates, two results, on the nested vector field. */
  private def nestedKnn =
    knnQuery(
      field = NESTED_VECTOR_FIELD,
      vector = QUERY_VECTOR,
      numCandidates = 2
    ).k(2)

  /** Extracts the strings stored under `TEXT_FIELD` from one value of an inner hit's field.
    *
    * Anything that is not a map holding a sequence under that key yields an empty result; elements
    * that are not strings are dropped. A pattern match is used instead of an unchecked cast because
    * the generic element types are erased at runtime.
    */
  private def textsOf(value: Any): Seq[String] = value match {
    case nested: Map[_, _] =>
      nested
        .collectFirst { case (TEXT_FIELD, texts: Seq[_]) => texts.collect { case text: String => text } }
        .getOrElse(Seq.empty)
    case _ => Seq.empty
  }

  override protected def afterAll(): Unit = deleteIndexQuietly()

  override protected def beforeAll(): Unit = {
    // Start from a clean slate in case an earlier run left the index behind.
    deleteIndexQuietly()

    client.execute {
      createIndex(INDEX).mapping(
        properties(
          textField(FULL_TEXT_FIELD),
          dateField(CREATION_TIME_FIELD),
          nestedField(PARAGRAPH_FIELD).properties(
            DenseVectorField(
              name = VECTOR_FIELD,
              dims = 2,
              index = true
            ),
            textField(TEXT_FIELD).index(false)
          )
        )
      )
    }.await

    // Two documents with two paragraphs each; the refresh makes them searchable right away.
    client.execute {
      bulk(
        indexInto(INDEX).id("1").fields(
          FULL_TEXT_FIELD -> "first paragraph another paragraph",
          CREATION_TIME_FIELD -> "2019-05-04",
          PARAGRAPH_FIELD -> Seq(
            Map(TEXT_FIELD -> "first paragraph", VECTOR_FIELD -> Seq(0.45, 45), PARAGRAPH_ID_FIELD -> "1"),
            Map(TEXT_FIELD -> "another paragraph", VECTOR_FIELD -> Seq(0.8, 0.6), PARAGRAPH_ID_FIELD -> "2")
          )
        ),
        indexInto(INDEX).id("2").fields(
          FULL_TEXT_FIELD -> "number one paragraph number two paragraph",
          CREATION_TIME_FIELD -> "2020-05-04",
          PARAGRAPH_FIELD -> Seq(
            Map(TEXT_FIELD -> "number one paragraph", VECTOR_FIELD -> Seq(1.2, 4.5), PARAGRAPH_ID_FIELD -> "1"),
            Map(TEXT_FIELD -> "number two paragraph", VECTOR_FIELD -> Seq(-1, 42), PARAGRAPH_ID_FIELD -> "2")
          )
        )
      ).refresh(RefreshPolicy.Immediate)
    }.await
  }

  "knn search over nested dense_vectors" should "always diversify the top results over the top-level document" in {
    val resp = client.execute(
      baseSearch.knn(nestedKnn)
    ).await.result

    resp.totalHits shouldBe 2
    resp.hits.hits.map(_.id).toSet shouldBe Set("1", "2")
  }

  "knn search with filter" should "always be over the top-level document metadata" in {
    // Only document 1 has a creation time inside this range.
    val createdInMay2019 =
      rangeQuery(CREATION_TIME_FIELD)
        .gte("2019-05-01")
        .lte("2019-05-05")

    val resp = client.execute(
      baseSearch.knn(nestedKnn.filter(createdInMay2019))
    ).await.result

    resp.totalHits shouldBe 1
    resp.hits.hits.map(_.id).toSet shouldBe Set("1")
  }

  "knn search" should "contain the nearest found paragraph when searching" in {
    // Ask for a single inner hit per document and return only the paragraph text field.
    val paragraphInnerHit =
      InnerHit(PARAGRAPH_FIELD)
        .fetchSource(FetchSourceContext(fetchSource = false, includes = Set(NESTED_TEXT_FIELD)))
        .size(1)
        .fields(Seq(NESTED_TEXT_FIELD))

    val resp = client.execute(
      baseSearch.knn(nestedKnn.inner(paragraphInnerHit))
    ).await.result

    resp.totalHits shouldBe 2
    resp.hits.hits.map(_.id).toSet shouldBe Set("1", "2")

    // One sequence of paragraph texts per top-level hit, gathered from its inner hits.
    val nearestTexts = resp.hits.hits.map { hit =>
      hit.innerHits.get(PARAGRAPH_FIELD).fold(Seq.empty[String]) { paragraphHits =>
        paragraphHits.hits.flatMap { paragraphHit =>
          paragraphHit
            .docValueFieldOpt(PARAGRAPH_FIELD)
            .fold[Seq[String]](Seq.empty)(_.values.flatMap(value => textsOf(value)))
        }
      }
    }.toSet

    nearestTexts shouldBe Set(List("first paragraph"), List("number two paragraph"))
  }
}

0 comments on commit 43eec09

Please sign in to comment.