diff --git a/build.gradle b/build.gradle index 64887a175..c68f13963 100644 --- a/build.gradle +++ b/build.gradle @@ -47,6 +47,7 @@ subprojects { Project api4s = project(':api4s') Project benchmarks = project(':benchmarks') Project clientElastic4s = project(':client-elastic4s') +Project clientJava = project(':client-java') Project lucene = project(':lucene') Project models = project(':models') Project plugin = project(':plugin') @@ -204,6 +205,16 @@ configure(clientElastic4s, List.of( } })) +configure(clientJava, List.of( + publishConfig("Java APIs for Elastiknn, intended for use with Elasticsearch REST clients", false), + { + dependencies { + implementation "org.elasticsearch:elasticsearch:${esVersion}" + } + } +)) + + configure(models, List.of( publishConfig("Exact and approximate similarity models used in Elastiknn", false))) @@ -266,6 +277,7 @@ configure(testing, List.of(scalaProjectConfig, { dependencies { implementation models implementation clientElastic4s + implementation clientJava implementation plugin implementation lucene implementation 'com.typesafe:config:1.4.0' @@ -276,6 +288,7 @@ configure(testing, List.of(scalaProjectConfig, { implementation "org.apache.lucene:lucene-codecs:${luceneVersion}" implementation "org.apache.lucene:lucene-analyzers-common:${luceneVersion}" implementation "org.elasticsearch:elasticsearch:${esVersion}" + implementation "org.elasticsearch.client:elasticsearch-rest-high-level-client:${esVersion}" implementation "com.storm-enroute:scalameter_${scalaShortVersion}:0.19" implementation "org.scalanlp:breeze_${scalaShortVersion}:1.0" implementation "com.klibisz.futil:futil_${scalaShortVersion}:0.1.2" diff --git a/docs/Taskfile.yml b/docs/Taskfile.yml index 1955ecae8..728ba81d4 100644 --- a/docs/Taskfile.yml +++ b/docs/Taskfile.yml @@ -14,7 +14,7 @@ tasks: - install-bundler cmds: - bundle install - - bundle exec jekyll serve + - bundle exec jekyll serve --port 4001 compile: desc: Compile docs into a static site. 
diff --git a/docs/pages/libraries.md b/docs/pages/libraries.md index 45654a66c..f65e4d624 100644 --- a/docs/pages/libraries.md +++ b/docs/pages/libraries.md @@ -34,9 +34,9 @@ This includes a low level client that roughly mirrors the [Scala client](/scala- |:--|:--| |Release|[![Python Release][Badge-Python-Release]][Link-Python-Release]| -## Java library with exact and approximate similarity models +## Java library with exact and approximate nearest neighbor search models -This library contains the exact and approximate similarity models used by Elastiknn. +This library contains the exact and approximate nearest neighbor search models used by Elastiknn. **Install** @@ -68,9 +68,27 @@ implementation 'com.klibisz.elastiknn:lucene:' **Versions** |:--|:--| -|Rekease|[![Lucene Release][Badge-Lucene-Release]][Link-Lucene-Release]| +|Release|[![Lucene Release][Badge-Lucene-Release]][Link-Lucene-Release]| |Snapshot|[![Lucene Snapshot][Badge-Lucene-Snapshot]][Link-Lucene-Snapshot]| +## Java library with Elasticsearch query builder for Elastiknn queries + +This library contains a custom [query builder](https://www.elastic.co/guide/en/elasticsearch/client/java-rest/current/java-rest-high-query-builders.html) +for defining Elastiknn queries in Java. 
+ +**Install** + +In a Gradle project: + +```groovy +implementation 'com.klibisz.elastiknn:client-java:<version>' +``` + +**Versions** + +|:--|:--| +|Release|[![Java Client Release][Badge-Java-Client-Release]][Link-Java-Client-Release]| +|Snapshot|[![Java Client Snapshot][Badge-Java-Client-Snapshot]][Link-Java-Client-Snapshot]| ## Scala client @@ -142,6 +160,11 @@ libraryDependencies += "com.klibisz.elastiknn" %% "api4s" % [Link-Lucene-Release]: https://search.maven.org/artifact/com.klibisz.elastiknn/lucene [Link-Lucene-Snapshot]: https://oss.sonatype.org/#nexus-search;gav~com.klibisz.elastiknn~lucene~~~ +[Badge-Java-Client-Release]: https://img.shields.io/nexus/r/com.klibisz.elastiknn/client-java?server=http%3A%2F%2Foss.sonatype.org&style=flat-square "client-java release" +[Badge-Java-Client-Snapshot]: https://img.shields.io/nexus/s/com.klibisz.elastiknn/client-java?server=http%3A%2F%2Foss.sonatype.org&style=flat-square "client-java snapshot" +[Link-Java-Client-Release]: https://search.maven.org/artifact/com.klibisz.elastiknn/client-java +[Link-Java-Client-Snapshot]: https://oss.sonatype.org/#nexus-search;gav~com.klibisz.elastiknn~client-java~~~ + [Badge-Api4s-Release]: https://img.shields.io/nexus/r/com.klibisz.elastiknn/api4s_2.12?server=http%3A%2F%2Foss.sonatype.org&style=flat-square "api4s_2.12 release" [Badge-Api4s-Snapshot]: https://img.shields.io/nexus/s/com.klibisz.elastiknn/api4s_2.12?server=http%3A%2F%2Foss.sonatype.org&style=flat-square "api4s_2.12 snapshot" [Link-Api4s-Release]: https://search.maven.org/artifact/com.klibisz.elastiknn/api4s_2.12 diff --git a/elastiknn-client-java/src/main/java/com/klibisz/elastiknn/ElastiknnNearestNeighborsQueryBuilder.java b/elastiknn-client-java/src/main/java/com/klibisz/elastiknn/ElastiknnNearestNeighborsQueryBuilder.java new file mode 100644 index 000000000..2f6290620 --- /dev/null +++ b/elastiknn-client-java/src/main/java/com/klibisz/elastiknn/ElastiknnNearestNeighborsQueryBuilder.java @@ -0,0 +1,86 @@ +package com.klibisz.elastiknn; + +import 
com.klibisz.elastiknn.api4j.ElastiknnNearestNeighborsQuery;
+import com.klibisz.elastiknn.api4j.Vector;
+import org.apache.lucene.search.Query;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.index.query.AbstractQueryBuilder;
+import org.elasticsearch.index.query.SearchExecutionContext;
+
+import java.io.IOException;
+import java.util.Objects;
+
+/** Writes an {@link ElastiknnNearestNeighborsQuery} as an {@code elastiknn_nearest_neighbors} JSON query; streaming and local execution are unsupported. */
+public class ElastiknnNearestNeighborsQueryBuilder extends AbstractQueryBuilder<ElastiknnNearestNeighborsQueryBuilder> {
+    private final ElastiknnNearestNeighborsQuery query;
+    private final String field;
+
+    /** Builds {@code query} for the document field named {@code field}; throws NullPointerException if either argument is null. */
+    public ElastiknnNearestNeighborsQueryBuilder(ElastiknnNearestNeighborsQuery query, String field) {
+        this.query = Objects.requireNonNull(query, "query");
+        this.field = Objects.requireNonNull(field, "field");
+    }
+
+    @Override
+    protected void doWriteTo(StreamOutput out) {
+        throw new UnsupportedOperationException("doWriteTo is not implemented");
+    }
+
+    @Override
+    protected void doXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject(getWriteableName());
+        builder.field("field", field);
+        builder.field("similarity", query.getSimilarity().toString());
+        if (query instanceof ElastiknnNearestNeighborsQuery.Exact) {
+            builder.field("model", "exact");
+        } else if (query instanceof ElastiknnNearestNeighborsQuery.AngularLsh) {
+            ElastiknnNearestNeighborsQuery.AngularLsh q = (ElastiknnNearestNeighborsQuery.AngularLsh) query;
+            builder.field("model", "lsh"); // same model name as L2Lsh; the "similarity" field tells them apart
+            builder.field("candidates", q.getCandidates());
+        } else if (query instanceof ElastiknnNearestNeighborsQuery.L2Lsh) {
+            ElastiknnNearestNeighborsQuery.L2Lsh q = (ElastiknnNearestNeighborsQuery.L2Lsh) query;
+            builder.field("model", "lsh");
+            builder.field("candidates", q.getCandidates());
+            builder.field("probes", q.getProbes());
+        } else if (query instanceof ElastiknnNearestNeighborsQuery.PermutationLsh) {
+            ElastiknnNearestNeighborsQuery.PermutationLsh q = (ElastiknnNearestNeighborsQuery.PermutationLsh) query;
+            builder.field("model", "permutation_lsh");
+            builder.field("candidates", q.getCandidates());
+        } else {
+            throw new RuntimeException(String.format("Unexpected query type [%s]", query.getClass().toString()));
+        }
+        if (query.getVector() instanceof Vector.DenseFloat) {
+            builder.field("vec", ((Vector.DenseFloat) query.getVector()).values);
+        } else if (query.getVector() instanceof Vector.SparseBool) {
+            Vector.SparseBool sbv = (Vector.SparseBool) query.getVector();
+            builder.startArray("vec"); // written as a two-element array: [trueIndices, totalIndices]
+            builder.value(sbv.trueIndices);
+            builder.value(sbv.totalIndices);
+            builder.endArray();
+        } else {
+            throw new RuntimeException(String.format("Unexpected vector type [%s]", query.getVector().getClass().toString()));
+        }
+        builder.endObject();
+    }
+
+    @Override
+    protected Query doToQuery(SearchExecutionContext context) {
+        throw new UnsupportedOperationException("doToQuery is not implemented");
+    }
+
+    @Override
+    protected boolean doEquals(ElastiknnNearestNeighborsQueryBuilder other) {
+        return other != null && ((this == other) || (query.equals(other.query) && field.equals(other.field)));
+    }
+
+    @Override
+    protected int doHashCode() {
+        return Objects.hash(query, field);
+    }
+
+    @Override
+    public String getWriteableName() {
+        return "elastiknn_nearest_neighbors";
+    }
+}
diff --git a/elastiknn-client-java/src/main/java/com/klibisz/elastiknn/api4j/ElastiknnNearestNeighborsQuery.java b/elastiknn-client-java/src/main/java/com/klibisz/elastiknn/api4j/ElastiknnNearestNeighborsQuery.java
new file mode 100644
index 000000000..d2f268b4c
--- /dev/null
+++ b/elastiknn-client-java/src/main/java/com/klibisz/elastiknn/api4j/ElastiknnNearestNeighborsQuery.java
@@ -0,0 +1,176 @@
+package com.klibisz.elastiknn.api4j;
+
+import java.util.Objects;
+
+/**
+ * A nearest neighbors query: the vector to search for, the similarity to score with, and
+ * any model-specific parameters.
+ *
+ * <p>The private constructor limits subclasses to the nested classes declared in this file.
+ */
+public abstract class ElastiknnNearestNeighborsQuery {
+
+    private ElastiknnNearestNeighborsQuery() {}
+
+    /** Returns the query vector. */
+    public abstract Vector getVector();
+
+    /** Returns the similarity this query scores with. */
+    public abstract Similarity getSimilarity();
+
+    /** Exact query: any vector type, scored with the similarity the caller picks. */
+    public static final class Exact extends ElastiknnNearestNeighborsQuery {
+        private final Similarity similarity;
+        private final Vector vector;
+
+        public Exact(Vector vector, Similarity similarity) {
+            this.similarity = Objects.requireNonNull(similarity, "similarity");
+            this.vector = Objects.requireNonNull(vector, "vector");
+        }
+
+        @Override
+        public Vector getVector() {
+            return vector;
+        }
+
+        @Override
+        public Similarity getSimilarity() {
+            return similarity;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            Exact exact = (Exact) o;
+            return getSimilarity() == exact.getSimilarity() && Objects.equals(getVector(), exact.getVector());
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(getSimilarity(), getVector());
+        }
+    }
+
+    /** Angular LSH query; the similarity is fixed to {@link Similarity#ANGULAR}. */
+    public static final class AngularLsh extends ElastiknnNearestNeighborsQuery {
+        private final Vector vector;
+        private final Integer candidates;
+
+        public AngularLsh(Vector vector, Integer candidates) {
+            this.vector = Objects.requireNonNull(vector, "vector");
+            this.candidates = Objects.requireNonNull(candidates, "candidates");
+        }
+
+        public Integer getCandidates() {
+            return candidates;
+        }
+
+        @Override
+        public Vector getVector() {
+            return vector;
+        }
+
+        @Override
+        public Similarity getSimilarity() {
+            return Similarity.ANGULAR;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            AngularLsh that = (AngularLsh) o;
+            return Objects.equals(getVector(), that.getVector()) && Objects.equals(getCandidates(), that.getCandidates()) && getSimilarity() == that.getSimilarity();
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(getVector(), getCandidates(), getSimilarity());
+        }
+    }
+
+    /** L2 LSH query over a dense float vector; the similarity is fixed to {@link Similarity#L2}. */
+    public static final class L2Lsh extends ElastiknnNearestNeighborsQuery {
+        private final Vector.DenseFloat vector;
+        private final Integer candidates;
+        private final Integer probes;
+
+        public L2Lsh(Vector.DenseFloat vector, Integer candidates, Integer probes) {
+            this.vector = Objects.requireNonNull(vector, "vector");
+            this.candidates = Objects.requireNonNull(candidates, "candidates");
+            this.probes = Objects.requireNonNull(probes, "probes");
+        }
+
+        public Integer getProbes() {
+            return probes;
+        }
+
+        public Integer getCandidates() {
+            return candidates;
+        }
+
+        @Override
+        public Vector getVector() {
+            return vector;
+        }
+
+        @Override
+        public Similarity getSimilarity() {
+            return Similarity.L2;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            L2Lsh l2Lsh = (L2Lsh) o;
+            return Objects.equals(getVector(), l2Lsh.getVector()) && Objects.equals(getCandidates(), l2Lsh.getCandidates()) && Objects.equals(getProbes(), l2Lsh.getProbes()) && getSimilarity() == l2Lsh.getSimilarity();
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(getVector(), getCandidates(), getProbes(), getSimilarity());
+        }
+    }
+
+    /** Permutation LSH query over a dense float vector with a caller-chosen similarity. */
+    public static final class PermutationLsh extends ElastiknnNearestNeighborsQuery {
+        private final Vector.DenseFloat vector;
+        private final Similarity similarity;
+        private final Integer candidates;
+
+        public PermutationLsh(Vector.DenseFloat vector, Similarity similarity, Integer candidates) {
+            this.vector = Objects.requireNonNull(vector, "vector");
+            this.similarity = Objects.requireNonNull(similarity, "similarity");
+            this.candidates = Objects.requireNonNull(candidates, "candidates");
+        }
+
+        public Integer getCandidates() {
+            return candidates;
+        }
+
+        @Override
+        public Vector getVector() {
+            return vector;
+        }
+
+        @Override
+        public Similarity getSimilarity() {
+            return similarity;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            PermutationLsh that = (PermutationLsh) o;
+            return Objects.equals(getVector(), that.getVector()) && getSimilarity() == that.getSimilarity() && Objects.equals(getCandidates(), that.getCandidates());
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(getVector(), getSimilarity(), getCandidates());
+        }
+    }
+}
diff --git a/elastiknn-client-java/src/main/java/com/klibisz/elastiknn/api4j/Similarity.java
b/elastiknn-client-java/src/main/java/com/klibisz/elastiknn/api4j/Similarity.java
new file mode 100644
index 000000000..fcb302313
--- /dev/null
+++ b/elastiknn-client-java/src/main/java/com/klibisz/elastiknn/api4j/Similarity.java
@@ -0,0 +1,10 @@
+package com.klibisz.elastiknn.api4j;
+
+/** Similarity functions that an {@link ElastiknnNearestNeighborsQuery} can be scored with. */
+public enum Similarity {
+    JACCARD,
+    HAMMING,
+    L1,
+    L2,
+    ANGULAR
+}
diff --git a/elastiknn-client-java/src/main/java/com/klibisz/elastiknn/api4j/Vector.java b/elastiknn-client-java/src/main/java/com/klibisz/elastiknn/api4j/Vector.java
new file mode 100644
index 000000000..dec3f1b31
--- /dev/null
+++ b/elastiknn-client-java/src/main/java/com/klibisz/elastiknn/api4j/Vector.java
@@ -0,0 +1,63 @@
+package com.klibisz.elastiknn.api4j;
+
+import java.util.Arrays;
+import java.util.Objects;
+
+/**
+ * A query vector. The private constructor limits subclasses to the nested classes declared here.
+ *
+ * <p>Arrays passed to the constructors are stored by reference, not copied; mutating one
+ * afterwards also changes the result of {@code equals} and {@code hashCode}.
+ */
+public abstract class Vector {
+
+    private Vector() {}
+
+    /** Dense vector backed by a {@code float[]}. */
+    public static final class DenseFloat extends Vector {
+        public final float[] values;
+
+        public DenseFloat(float[] values) {
+            this.values = values;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            DenseFloat that = (DenseFloat) o;
+            return Arrays.equals(values, that.values);
+        }
+
+        @Override
+        public int hashCode() {
+            return Arrays.hashCode(values);
+        }
+    }
+
+    /** Sparse boolean vector described by {@code trueIndices} and {@code totalIndices}. */
+    public static final class SparseBool extends Vector {
+        public final int[] trueIndices;
+        public final Integer totalIndices;
+
+        public SparseBool(int[] trueIndices, Integer totalIndices) {
+            this.trueIndices = trueIndices;
+            this.totalIndices = totalIndices;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            SparseBool that = (SparseBool) o;
+            return Arrays.equals(trueIndices, that.trueIndices) && Objects.equals(totalIndices, that.totalIndices);
+        }
+
+        @Override
+        public int hashCode() {
+            int result = Objects.hash(totalIndices);
+            result = 31 * result + Arrays.hashCode(trueIndices);
+            return result;
+        }
+    }
+}
diff --git a/elastiknn-testing/src/test/scala/com/klibisz/elastiknn/query/JavaClientSuite.scala b/elastiknn-testing/src/test/scala/com/klibisz/elastiknn/query/JavaClientSuite.scala
new file mode 100644
index 000000000..a5420fe3a
--- /dev/null
+++ b/elastiknn-testing/src/test/scala/com/klibisz/elastiknn/query/JavaClientSuite.scala
@@ -0,0 +1,100 @@
+package com.klibisz.elastiknn.query
+
+import com.klibisz.elastiknn.api.ElasticsearchCodec._
+import com.klibisz.elastiknn.api._
+import com.klibisz.elastiknn.{ElastiknnNearestNeighborsQueryBuilder, api4j}
+import com.klibisz.elastiknn.api4j.ElastiknnNearestNeighborsQuery
+import com.klibisz.elastiknn.testing.ElasticAsyncClient
+import com.sksamuel.elastic4s.ElasticDsl._
+import org.apache.http.HttpHost
+import org.elasticsearch.action.search.SearchRequest
+import org.elasticsearch.client.{RequestOptions, RestClient, RestHighLevelClient}
+import org.elasticsearch.common.xcontent.json.JsonXContent
+import org.elasticsearch.common.xcontent.{ToXContent, XContentBuilder}
+import org.elasticsearch.search.builder.SearchSourceBuilder
+import org.scalatest.funsuite.AsyncFunSuite
+import org.scalatest.matchers.should.Matchers
+
+import java.io.ByteArrayOutputStream
+import scala.concurrent.Future
+import scala.util.Random
+
+/** Exercises the Java client: an end-to-end search smoketest and a JSON parity check against the Scala codec. */
+class JavaClientSuite extends AsyncFunSuite with Matchers with ElasticAsyncClient {
+
+  implicit val rng: Random = new Random(0)
+
+  test("Java client smoketest") {
+    val (index, field, id) = ("java-client-smoketest", "vec", "id")
+    val corpus = Vec.DenseFloat.randoms(100, 1000)
+    val ids = corpus.indices.map(i => s"v$i")
+    val mapping = Mapping.L2Lsh(corpus.head.dims, 50, 1, 2)
+
+    val javaClient = new RestHighLevelClient(RestClient.builder(new HttpHost("localhost", 9200, "http")))
+    val query = new ElastiknnNearestNeighborsQuery.L2Lsh(new api4j.Vector.DenseFloat(corpus.head.values), 20, 2)
+    val queryBuilder = new ElastiknnNearestNeighborsQueryBuilder(query, field)
+    // Search only the index this test creates; a bare SearchRequest() targets every index.
+    val searchRequest = new SearchRequest(index)
+    searchRequest.source(new SearchSourceBuilder().query(queryBuilder))
+
+    val searched = for {
+      _ <- deleteIfExists(index)
+      _ <- eknn.createIndex(index)
+      _ <- eknn.putMapping(index, field, id, mapping)
+      _ <- eknn.index(index, field, corpus, id, ids)
+      _ <- eknn.execute(refreshIndex(index))
+    } yield {
+      val javaClientResult = javaClient.search(searchRequest, RequestOptions.DEFAULT)
+      val hits = javaClientResult.getHits.getHits
+      hits.length shouldBe 10
+      hits.head.getId shouldBe "v0"
+    }
+
+    // Close the REST client whether the test passed or failed.
+    searched.andThen { case _ => javaClient.close() }
+  }
+
+  test("XContent codec matches Scala codec") {
+
+    val dfv = Vec.DenseFloat.random(10)
+    val sbv = Vec.SparseBool.random(20)
+
+    val cases = Seq(
+      new ElastiknnNearestNeighborsQuery.Exact(new api4j.Vector.DenseFloat(dfv.values), api4j.Similarity.L1) ->
+        NearestNeighborsQuery.Exact("vec", Similarity.L1, dfv),
+      new ElastiknnNearestNeighborsQuery.Exact(new api4j.Vector.DenseFloat(dfv.values), api4j.Similarity.L2) ->
+        NearestNeighborsQuery.Exact("vec", Similarity.L2, dfv),
+      new ElastiknnNearestNeighborsQuery.Exact(new api4j.Vector.DenseFloat(dfv.values), api4j.Similarity.ANGULAR) ->
+        NearestNeighborsQuery.Exact("vec", Similarity.Angular, dfv),
+      new ElastiknnNearestNeighborsQuery.Exact(new api4j.Vector.SparseBool(sbv.trueIndices, sbv.totalIndices), api4j.Similarity.JACCARD) ->
+        NearestNeighborsQuery.Exact("vec", Similarity.Jaccard, sbv),
+      new ElastiknnNearestNeighborsQuery.L2Lsh(new api4j.Vector.DenseFloat(dfv.values), 22, 3) ->
+        NearestNeighborsQuery.L2Lsh("vec", 22, 3, dfv),
+      new ElastiknnNearestNeighborsQuery.AngularLsh(new api4j.Vector.DenseFloat(dfv.values), 22) ->
+        NearestNeighborsQuery.AngularLsh("vec", 22, dfv),
+      new ElastiknnNearestNeighborsQuery.PermutationLsh(new api4j.Vector.DenseFloat(dfv.values), api4j.Similarity.ANGULAR, 22) ->
+        NearestNeighborsQuery.PermutationLsh("vec", Similarity.Angular, 22, dfv),
+      new ElastiknnNearestNeighborsQuery.PermutationLsh(new api4j.Vector.DenseFloat(dfv.values), api4j.Similarity.L2, 22) ->
+        NearestNeighborsQuery.PermutationLsh("vec", Similarity.L2, 22, dfv)
+    )
+
+    // Encode the java query via XContent, decode it via Circe, compare it to the Scala query.
+    val checked = cases.zipWithIndex.map {
+      case ((javaQuery, scalaQuery), i) =>
+        val bos = new ByteArrayOutputStream()
+        val xcb = new XContentBuilder(JsonXContent.jsonXContent, bos)
+        val qb = new ElastiknnNearestNeighborsQueryBuilder(javaQuery, scalaQuery.field)
+        qb.toXContent(xcb, ToXContent.EMPTY_PARAMS)
+        xcb.flush()
+        val qbJsonString = bos.toString("UTF-8")
+        val qbJsonParsed = parse(qbJsonString).map(_ \\ "elastiknn_nearest_neighbors").flatMap(_.head.as[NearestNeighborsQuery])
+        info(s"case $i: ${scalaQuery.withVec(Vec.Empty())}")
+        withClue(s"case $i:") {
+          qbJsonParsed shouldBe Right(scalaQuery)
+        }
+    }
+
+    Future(checked.last)
+  }
+
+}
diff --git a/settings.gradle b/settings.gradle
index 1ddc70345..3d6509404 100644
--- a/settings.gradle
+++ b/settings.gradle
@@ -4,6 +4,7 @@ List subprojectNames = List.of(
     'api4s',
     'benchmarks',
     'client-elastic4s',
+    'client-java',
     'lucene',
     'models',
     'plugin',
diff --git a/version b/version
index 82d51e7c9..2126d83c8 100644
--- a/version
+++ b/version
@@ -1 +1 @@
-7.13.2.0
+7.13.2.1
\ No newline at end of file