Skip to content

Commit

Permalink
Introduce new DenseVectorField apply method
Browse files Browse the repository at this point in the history
and deprecate old DenseVectorField apply methods
  • Loading branch information
Philippus committed Jul 13, 2024
1 parent 9e6bc95 commit 32c3749
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,26 +35,26 @@ class DenseVectorFieldTest extends AnyFlatSpec with Matchers with ElasticApi {
}

it should "not set similarity or indexOptions when index = false" in {
val field = DenseVectorField(name = "myfield", dims = 3, index = false, indexOptions = Some(denseVectorIndexOptions))
val field = DenseVectorField(name = "myfield", dims = Some(3), index = Some(false), indexOptions = Some(denseVectorIndexOptions))
DenseVectorFieldBuilderFn.build(field).string shouldBe
"""{"type":"dense_vector","dims":3,"index":false}"""
}

it should "support indexOptions property" in {
val field = DenseVectorField(name = "myfield", dims = 3, index = true, indexOptions = Some(denseVectorIndexOptions))
val field = DenseVectorField(name = "myfield", dims = Some(3), index = Some(true), indexOptions = Some(denseVectorIndexOptions))
DenseVectorFieldBuilderFn.build(field).string shouldBe
"""{"type":"dense_vector","dims":3,"index":true,"similarity":"l2_norm","index_options":{"type":"int8_hnsw","m":10,"ef_construction":100,"confidence_interval":1.0}}"""
"""{"type":"dense_vector","dims":3,"index":true,"index_options":{"type":"int8_hnsw","m":10,"ef_construction":100,"confidence_interval":1.0}}"""
}

it should "support all index options types and only set m, efConstruction and confidenceInterval when applicable" in {
val field = DenseVectorField(name = "myfield", dims = 3, index = true, indexOptions = Some(denseVectorIndexOptions))
val field = DenseVectorField(name = "myfield", dims = Some(3), index = Some(true), indexOptions = Some(denseVectorIndexOptions))
DenseVectorFieldBuilderFn.build(field).string shouldBe
"""{"type":"dense_vector","dims":3,"index":true,"similarity":"l2_norm","index_options":{"type":"int8_hnsw","m":10,"ef_construction":100,"confidence_interval":1.0}}"""
"""{"type":"dense_vector","dims":3,"index":true,"index_options":{"type":"int8_hnsw","m":10,"ef_construction":100,"confidence_interval":1.0}}"""
DenseVectorFieldBuilderFn.build(field.indexOptions(denseVectorIndexOptions.copy(`type` = Hnsw))).string shouldBe
"""{"type":"dense_vector","dims":3,"index":true,"similarity":"l2_norm","index_options":{"type":"hnsw","m":10,"ef_construction":100}}"""
"""{"type":"dense_vector","dims":3,"index":true,"index_options":{"type":"hnsw","m":10,"ef_construction":100}}"""
DenseVectorFieldBuilderFn.build(field.indexOptions(denseVectorIndexOptions.copy(`type` = Flat))).string shouldBe
"""{"type":"dense_vector","dims":3,"index":true,"similarity":"l2_norm","index_options":{"type":"flat"}}"""
"""{"type":"dense_vector","dims":3,"index":true,"index_options":{"type":"flat"}}"""
DenseVectorFieldBuilderFn.build(field.indexOptions(denseVectorIndexOptions.copy(`type` = Int8Flat))).string shouldBe
"""{"type":"dense_vector","dims":3,"index":true,"similarity":"l2_norm","index_options":{"type":"int8_flat","confidence_interval":1.0}}"""
"""{"type":"dense_vector","dims":3,"index":true,"index_options":{"type":"int8_flat","confidence_interval":1.0}}"""
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,33 +10,48 @@ object DenseVectorField {
case object Int8Hnsw extends KnnType { val name = "int8_hnsw" }
case object Flat extends KnnType { val name = "flat" }
case object Int8Flat extends KnnType { val name = "int8_flat" }

/** Legacy two-argument factory kept for source compatibility.
  *
  * Delegates to the case-class constructor with no element type, the given
  * dimension count, indexing explicitly switched off and `l2_norm` similarity.
  *
  * @param name field name
  * @param dims number of vector dimensions
  */
@deprecated("Use the new apply method", "8.14.0")
def apply(name: String, dims: Int): DenseVectorField =
  DenseVectorField(
    name = name,
    elementType = None,
    dims = Some(dims),
    index = Some(false),
    similarity = Some(L2Norm)
  )

/** Legacy three-argument factory kept for source compatibility.
  *
  * Delegates to the case-class constructor with no element type, the given
  * dimension count and index flag, and `l2_norm` similarity.
  *
  * @param name  field name
  * @param dims  number of vector dimensions
  * @param index value stored for the field's `index` setting
  */
@deprecated("Use the new apply method", "8.14.0")
def apply(name: String, dims: Int, index: Boolean): DenseVectorField =
  DenseVectorField(
    name = name,
    elementType = None,
    dims = Some(dims),
    index = Some(index),
    similarity = Some(L2Norm)
  )

/** Legacy four-argument factory kept for source compatibility.
  *
  * Delegates to the case-class constructor with no element type and every
  * supplied value wrapped in `Some`.
  *
  * @param name       field name
  * @param dims       number of vector dimensions
  * @param index      value stored for the field's `index` setting
  * @param similarity similarity metric stored on the field
  */
@deprecated("Use the new apply method", "8.14.0")
def apply(name: String, dims: Int, index: Boolean, similarity: Similarity): DenseVectorField =
  DenseVectorField(
    name = name,
    elementType = None,
    dims = Some(dims),
    index = Some(index),
    similarity = Some(similarity)
  )
}

/** Closed set of similarity metrics that can be attached to a dense vector field. */
sealed trait Similarity {

  /** The string written for this metric in the field mapping. */
  def name: String
}

case object L2Norm extends Similarity {
  override val name: String = "l2_norm"
}

case object DotProduct extends Similarity {
  override val name: String = "dot_product"
}

case object Cosine extends Similarity {
  override val name: String = "cosine"
}

case object MaxInnerProduct extends Similarity {
  override val name: String = "max_inner_product"
}

case class DenseVectorField(name: String,
dims: Int,
index: Boolean = false,
similarity: Similarity = L2Norm,
indexOptions: Option[DenseVectorIndexOptions] = None,
elementType: Option[String] = None) extends ElasticField {
override def `type`: String = DenseVectorField.`type`

def dims(dims: Int): DenseVectorField = copy(dims = dims)

def index(index: Boolean): DenseVectorField = copy(index = index)
/** Index options for a dense vector field.
  *
  * @param `type`             kNN algorithm variant (see [[DenseVectorField.KnnType]])
  * @param m                  optional `m` setting
  * @param efConstruction     optional `ef_construction` setting
  * @param confidenceInterval optional `confidence_interval` setting
  */
case class DenseVectorIndexOptions(
  `type`: DenseVectorField.KnnType,
  m: Option[Int] = None,
  efConstruction: Option[Int] = None,
  confidenceInterval: Option[Double] = None
)

def similarity(similarity: Similarity): DenseVectorField = copy(similarity = similarity)
/** Mapping definition for a `dense_vector` field.
  *
  * Every setting other than the name is optional and defaults to `None`; the
  * fluent setters below return a copy with the corresponding setting wrapped
  * in `Some`.
  *
  * @param name         field name
  * @param elementType  optional element type
  * @param dims         optional number of vector dimensions
  * @param index        optional `index` flag
  * @param similarity   optional similarity metric
  * @param indexOptions optional index options
  */
case class DenseVectorField(
  name: String,
  elementType: Option[String] = None,
  dims: Option[Int] = None,
  index: Option[Boolean] = None,
  similarity: Option[Similarity] = None,
  indexOptions: Option[DenseVectorIndexOptions] = None
) extends ElasticField {

  override def `type`: String = DenseVectorField.`type`

  /** Returns a copy with the element type set. */
  def elementType(elementType: String): DenseVectorField =
    copy(elementType = Some(elementType))

  /** Returns a copy with the dimension count set. */
  def dims(dims: Int): DenseVectorField =
    copy(dims = Some(dims))

  /** Returns a copy with the `index` flag set. */
  def index(index: Boolean): DenseVectorField =
    copy(index = Some(index))

  /** Returns a copy with the similarity metric set. */
  def similarity(similarity: Similarity): DenseVectorField =
    copy(similarity = Some(similarity))

  /** Returns a copy with the index options set. */
  def indexOptions(indexOptions: DenseVectorIndexOptions): DenseVectorField =
    copy(indexOptions = Some(indexOptions))
}

case class DenseVectorIndexOptions(`type`: DenseVectorField.KnnType, m: Option[Int] = None, efConstruction: Option[Int] = None, confidenceInterval: Option[Double] = None)
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
package com.sksamuel.elastic4s.handlers.fields

import com.sksamuel.elastic4s.fields.DenseVectorField.{Hnsw, Int8Flat, Int8Hnsw}
import com.sksamuel.elastic4s.fields.{DenseVectorField, DenseVectorIndexOptions}
import com.sksamuel.elastic4s.fields.{Cosine, DenseVectorField, DenseVectorIndexOptions, DotProduct, L2Norm, MaxInnerProduct, Similarity}
import com.sksamuel.elastic4s.json.{XContentBuilder, XContentFactory}

object DenseVectorFieldBuilderFn {
/** Maps the mapping-level similarity string onto its [[Similarity]] value.
  *
  * @param similarity one of `l2_norm`, `dot_product`, `cosine` or `max_inner_product`
  * @return the matching [[Similarity]] case object
  * @throws IllegalArgumentException if the string is not one of the values above
  */
private def similarityFromString(similarity: String): Similarity = similarity match {
  case "l2_norm"           => L2Norm
  case "dot_product"       => DotProduct
  case "cosine"            => Cosine
  case "max_inner_product" => MaxInnerProduct
  // Fail with a message that names the offending value instead of a bare scala.MatchError.
  case other =>
    throw new IllegalArgumentException(
      s"Unsupported dense_vector similarity '$other'; expected one of: " +
        "l2_norm, dot_product, cosine, max_inner_product"
    )
}

private def getIndexOptions(values: Map[String, Any]): DenseVectorIndexOptions =
values("type").asInstanceOf[String] match {
Expand Down Expand Up @@ -32,21 +38,21 @@ object DenseVectorFieldBuilderFn {

def toField(name: String, values: Map[String, Any]): DenseVectorField = DenseVectorField(
name,
values("dims").asInstanceOf[Int],
values("index").asInstanceOf[Boolean],
values.get("element_type").map(_.asInstanceOf[String]),
values.get("dims").map(_.asInstanceOf[Int]),
values.get("index").map(_.asInstanceOf[Boolean]),
values.get("similarity").map(s => similarityFromString(s.asInstanceOf[String])),
indexOptions = values.get("index_options").map(_.asInstanceOf[Map[String, Any]]).map(getIndexOptions),
elementType = values.get("element_type").map(_.asInstanceOf[String])
)

def build(field: DenseVectorField): XContentBuilder = {

val builder = XContentFactory.jsonBuilder()
builder.field("type", field.`type`)
field.elementType.foreach(builder.field("element_type", _))
builder.field("dims", field.dims)
builder.field("index", field.index)
if (field.index) {
builder.field("similarity", field.similarity.name)
field.dims.foreach(builder.field("dims", _))
field.index.foreach(builder.field("index", _))
if (field.index.getOrElse(true)) {
field.similarity.foreach(similarity => builder.field("similarity", similarity.name))
field.indexOptions.foreach { options =>
builder.startObject("index_options")
builder.field("type", options.`type`.name)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,15 +113,22 @@ class ElasticFieldBuilderFnTest extends AnyWordSpec with Matchers {
}

"support DenseVectorField" in {
val field = DenseVectorField("dense_vector_field", dims = 3, index = true, indexOptions = Some(DenseVectorIndexOptions(DenseVectorField.Flat)), elementType = Some("byte"))
val jsonString = """{"type":"dense_vector","element_type":"byte","dims":3,"index":true,"similarity":"l2_norm","index_options":{"type":"flat"}}"""
val field = DenseVectorField("dense_vector_field", elementType = Some("byte"), dims = Some(3), index = Some(true), indexOptions = Some(DenseVectorIndexOptions(DenseVectorField.Flat)))
val jsonString = """{"type":"dense_vector","element_type":"byte","dims":3,"index":true,"index_options":{"type":"flat"}}"""
ElasticFieldBuilderFn(field).string shouldBe jsonString
ElasticFieldBuilderFn.construct(field.name, JacksonSupport.mapper.readValue[Map[String, Any]](jsonString)) shouldBe field
}

"support DenseVectorField with similarity" in {
  // Round trip: the field must serialise to exactly this JSON, and parsing that JSON
  // back through `construct` must yield an equal field.
  val expectedJson = """{"type":"dense_vector","element_type":"byte","dims":3,"index":true,"similarity":"max_inner_product"}"""
  val vectorField = DenseVectorField(
    "dense_vector_field",
    elementType = Some("byte"),
    dims = Some(3),
    index = Some(true),
    similarity = Some(MaxInnerProduct)
  )

  ElasticFieldBuilderFn(vectorField).string shouldBe expectedJson

  val parsedValues = JacksonSupport.mapper.readValue[Map[String, Any]](expectedJson)
  ElasticFieldBuilderFn.construct(vectorField.name, parsedValues) shouldBe vectorField
}

"support DenseVectorField with all index options" in {
val field = DenseVectorField("dense_vector_field", dims = 3, index = true, indexOptions = Some(DenseVectorIndexOptions(DenseVectorField.Int8Hnsw, Some(100), Some(200), Some(0.5f))), elementType = Some("byte"))
val jsonString = """{"type":"dense_vector","element_type":"byte","dims":3,"index":true,"similarity":"l2_norm","index_options":{"type":"int8_hnsw","m":100,"ef_construction":200,"confidence_interval":0.5}}"""
val field = DenseVectorField("dense_vector_field", elementType = Some("byte"), dims = Some(3), index = Some(true), indexOptions = Some(DenseVectorIndexOptions(DenseVectorField.Int8Hnsw, Some(100), Some(200), Some(0.5f))))
val jsonString = """{"type":"dense_vector","element_type":"byte","dims":3,"index":true,"index_options":{"type":"int8_hnsw","m":100,"ef_construction":200,"confidence_interval":0.5}}"""
ElasticFieldBuilderFn(field).string shouldBe jsonString
ElasticFieldBuilderFn.construct(field.name, JacksonSupport.mapper.readValue[Map[String, Any]](jsonString)) shouldBe field
}
Expand Down

0 comments on commit 32c3749

Please sign in to comment.