Skip to content

Commit

Permalink
close #3098
Browse files Browse the repository at this point in the history
  • Loading branch information
apavlychev committed Jun 27, 2024
1 parent c4d83b8 commit 0ff12c9
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 10 deletions.
Original file line number Diff line number Diff line change
@@ -1,17 +1,27 @@
package com.sksamuel.elastic4s.requests.mappings

import com.sksamuel.elastic4s.ElasticApi
import com.sksamuel.elastic4s.fields.DenseVectorField
import com.sksamuel.elastic4s.handlers.fields.DenseVectorFieldBuilderFn
import com.sksamuel.elastic4s.{ElasticApi, JacksonSupport}
import com.sksamuel.elastic4s.fields.{DenseVectorField, DenseVectorIndexOptions, Hnsw}
import com.sksamuel.elastic4s.handlers.fields.ElasticFieldBuilderFn
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers

class DenseVectorFieldTest extends AnyFlatSpec with Matchers with ElasticApi {

private val field = DenseVectorField(name = "myfield", dims = 3)
private val field = DenseVectorField(
name = "myfield",
dims = 3,
indexOptions = Some(DenseVectorIndexOptions(
`type` = Hnsw,
m = Some(100),
efConstruction = Some(100),
confidenceInterval = Some(0.5d)
))
)

"A DenseVectorField" should "support dims property" in {
DenseVectorFieldBuilderFn.build(field).string shouldBe
"""{"type":"dense_vector","dims":3,"index":false,"similarity":"l2_norm"}"""
"A DenseVectorField" should "support dims and index_options properties" in {
val jsonStringValue = """{"type":"dense_vector","dims":3,"index":false,"similarity":"l2_norm","index_options":{"type":"hnsw","m":100,"ef_construction":100,"confidence_interval":0.5}}"""
ElasticFieldBuilderFn(field).string shouldBe jsonStringValue
ElasticFieldBuilderFn.construct(field.name, JacksonSupport.mapper.readValue[Map[String, Any]](jsonStringValue)) shouldBe (field)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,20 @@ case object Cosine extends Similarity { val name = "cosine" }
case class DenseVectorField(name: String,
dims: Int,
index: Boolean = false,
similarity: Similarity = L2Norm) extends ElasticField {
similarity: Similarity = L2Norm,
indexOptions: Option[DenseVectorIndexOptions] = None) extends ElasticField {
override def `type`: String = DenseVectorField.`type`
}

sealed trait KnnAlgorithmType {
def name: String
}
case object Hnsw extends KnnAlgorithmType { val name = "hnsw" }
case object Int8Hnsw extends KnnAlgorithmType { val name = "int8_hnsw" }
case object Flat extends KnnAlgorithmType { val name = "flat" }
case object Int8Flat extends KnnAlgorithmType { val name = "int8_flat" }

case class DenseVectorIndexOptions(`type`: KnnAlgorithmType,
m: Option[Int] = None,
efConstruction: Option[Int] = None,
confidenceInterval: Option[Double] = None)
Original file line number Diff line number Diff line change
@@ -1,12 +1,26 @@
package com.sksamuel.elastic4s.handlers.fields

import com.sksamuel.elastic4s.fields.DenseVectorField
import com.sksamuel.elastic4s.fields.{DenseVectorField, DenseVectorIndexOptions, Flat, Hnsw, Int8Flat, Int8Hnsw}
import com.sksamuel.elastic4s.json.{XContentBuilder, XContentFactory}

object DenseVectorFieldBuilderFn {

private def getIndexOptions(values: Map[String, Any]) = DenseVectorIndexOptions(
values("type").asInstanceOf[String] match {
case "hnsw" => Hnsw
case "int8_hnsw" => Int8Hnsw
case "flat" => Flat
case "int8_flat" => Int8Flat
},
values.get("m").map(_.asInstanceOf[Int]),
values.get("ef_construction").map(_.asInstanceOf[Int]),
values.get("confidence_interval").map(_.asInstanceOf[Double])
)

def toField(name: String, values: Map[String, Any]): DenseVectorField = DenseVectorField(
name,
values.get("dims").map(_.asInstanceOf[Int]).get
values.get("dims").map(_.asInstanceOf[Int]).get,
indexOptions = values.get("index_options").map(_.asInstanceOf[Map[String, Any]]).map(getIndexOptions),
)


Expand All @@ -17,6 +31,14 @@ object DenseVectorFieldBuilderFn {
builder.field("dims", field.dims)
builder.field("index", field.index)
builder.field("similarity", field.similarity.name)
field.indexOptions.foreach { options =>
builder.startObject("index_options")
builder.field("type", options.`type`.name)
options.m.foreach(builder.field("m", _))
options.efConstruction.foreach(builder.field("ef_construction", _))
options.confidenceInterval.foreach(builder.field("confidence_interval", _))
builder.endObject()
}
builder.endObject()
}
}

0 comments on commit 0ff12c9

Please sign in to comment.