Skip to content

Commit

Permalink
LUCENE-10057: Use Lucene abstractions to store demo KnnVectorDict (Dawid Weiss)
Browse files Browse the repository at this point in the history
  • Loading branch information
msokolov committed Aug 19, 2021
1 parent eeb296c commit 5896e53
Show file tree
Hide file tree
Showing 7 changed files with 122 additions and 94 deletions.
34 changes: 24 additions & 10 deletions lucene/demo/src/java/org/apache/lucene/demo/IndexFiles.java
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory;
import org.apache.lucene.util.IOUtils;

/**
* Index all text files under a directory.
Expand All @@ -55,17 +56,18 @@
* command-line arguments for usage information.
*/
public class IndexFiles implements AutoCloseable {
static final String KNN_DICT = "knn-dict";

// Calculates embedding vectors for KnnVector search
private final DemoEmbeddings demoEmbeddings;
private final KnnVectorDict vectorDict;

private IndexFiles(Path vectorDictPath) throws IOException {
if (vectorDictPath != null) {
vectorDict = new KnnVectorDict(vectorDictPath);
private IndexFiles(KnnVectorDict vectorDict) throws IOException {
if (vectorDict != null) {
this.vectorDict = vectorDict;
demoEmbeddings = new DemoEmbeddings(vectorDict);
} else {
vectorDict = null;
this.vectorDict = null;
demoEmbeddings = null;
}
}
Expand All @@ -80,7 +82,7 @@ public static void main(String[] args) throws Exception {
+ "IF DICT_PATH contains a KnnVector dictionary, the index will also support KnnVector search";
String indexPath = "index";
String docsPath = null;
Path vectorDictPath = null;
String vectorDictSource = null;
boolean create = true;
for (int i = 0; i < args.length; i++) {
switch (args[i]) {
Expand All @@ -91,7 +93,7 @@ public static void main(String[] args) throws Exception {
docsPath = args[++i];
break;
case "-knn_dict":
vectorDictPath = Paths.get(args[++i]);
vectorDictSource = args[++i];
break;
case "-update":
create = false;
Expand Down Expand Up @@ -142,8 +144,16 @@ public static void main(String[] args) throws Exception {
//
// iwc.setRAMBufferSizeMB(256.0);

KnnVectorDict vectorDictInstance = null;
long vectorDictSize = 0;
if (vectorDictSource != null) {
KnnVectorDict.build(Paths.get(vectorDictSource), dir, KNN_DICT);
vectorDictInstance = new KnnVectorDict(dir, KNN_DICT);
vectorDictSize = vectorDictInstance.ramBytesUsed();
}

try (IndexWriter writer = new IndexWriter(dir, iwc);
IndexFiles indexFiles = new IndexFiles(vectorDictPath)) {
IndexFiles indexFiles = new IndexFiles(vectorDictInstance)) {
indexFiles.indexDocs(writer, docDir);

// NOTE: if you want to maximize search performance,
Expand All @@ -153,6 +163,8 @@ public static void main(String[] args) throws Exception {
// you're done adding documents to it):
//
// writer.forceMerge(1);
} finally {
IOUtils.close(vectorDictInstance);
}

Date end = new Date();
Expand All @@ -163,6 +175,10 @@ public static void main(String[] args) throws Exception {
+ " documents in "
+ (end.getTime() - start.getTime())
+ " milliseconds");
if (reader.numDocs() > 100 && vectorDictSize < 1_000_000) {
throw new RuntimeException(
"Are you (ab)using the toy vector dictionary? See the package javadocs to understand why you got this exception.");
}
}
} catch (IOException e) {
System.out.println(" caught a " + e.getClass() + "\n with message: " + e.getMessage());
Expand Down Expand Up @@ -263,8 +279,6 @@ void indexDoc(IndexWriter writer, Path file, long lastModified) throws IOException

@Override
public void close() throws IOException {
if (vectorDict != null) {
vectorDict.close();
}
IOUtils.close(vectorDict);
}
}
5 changes: 2 additions & 3 deletions lucene/demo/src/java/org/apache/lucene/demo/SearchFiles.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
import org.apache.lucene.demo.knn.KnnVectorDict;
import org.apache.lucene.document.Document;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.Term;
import org.apache.lucene.queryparser.classic.QueryParser;
import org.apache.lucene.search.BooleanClause;
Expand Down Expand Up @@ -103,12 +102,12 @@ public static void main(String[] args) throws Exception {
}
}

IndexReader reader = DirectoryReader.open(FSDirectory.open(Paths.get(index)));
DirectoryReader reader = DirectoryReader.open(FSDirectory.open(Paths.get(index)));
IndexSearcher searcher = new IndexSearcher(reader);
Analyzer analyzer = new StandardAnalyzer();
KnnVectorDict vectorDict = null;
if (knnVectors > 0) {
vectorDict = new KnnVectorDict(Paths.get(index).resolve("knn-dict"));
vectorDict = new KnnVectorDict(reader.directory(), IndexFiles.KNN_DICT);
}
BufferedReader in;
if (queries != null) {
Expand Down
85 changes: 43 additions & 42 deletions lucene/demo/src/java/org/apache/lucene/demo/knn/KnnVectorDict.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,19 @@
package org.apache.lucene.demo.knn;

import java.io.BufferedReader;
import java.io.DataOutputStream;
import java.io.Closeable;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.nio.channels.FileChannel;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.regex.Pattern;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IntsRefBuilder;
import org.apache.lucene.util.VectorUtil;
Expand All @@ -40,32 +42,29 @@
* Manages a map from token to numeric vector for use with KnnVector indexing and search. The map is
* stored as an FST: token-to-ordinal plus a dense binary file holding the vectors.
*/
public class KnnVectorDict implements AutoCloseable {
public class KnnVectorDict implements Closeable {

private final FST<Long> fst;
private final FileChannel vectors;
private final ByteBuffer vbuffer;
private final IndexInput vectors;
private final int dimension;

/**
* Sole constructor
*
* @param knnDictPath the base path name of the files that will store the KnnVectorDict. The file
* with extension '.bin' holds the vectors and the '.fst' maps tokens to offsets in the '.bin'
* file.
* @param directory Lucene directory from which knn directory should be read.
* @param dictName the base name of the directory files that store the knn vector dictionary. A
* file with extension '.bin' holds the vectors and the '.fst' maps tokens to offsets in the
* '.bin' file.
*/
public KnnVectorDict(Path knnDictPath) throws IOException {
String dictName = knnDictPath.getFileName().toString();
Path fstPath = knnDictPath.resolveSibling(dictName + ".fst");
Path binPath = knnDictPath.resolveSibling(dictName + ".bin");
fst = FST.read(fstPath, PositiveIntOutputs.getSingleton());
vectors = FileChannel.open(binPath);
long size = vectors.size();
if (size > Integer.MAX_VALUE) {
throw new IllegalArgumentException("vector file is too large: " + size + " bytes");
public KnnVectorDict(Directory directory, String dictName) throws IOException {
try (IndexInput fstIn = directory.openInput(dictName + ".fst", IOContext.READ)) {
fst = new FST<>(fstIn, fstIn, PositiveIntOutputs.getSingleton());
}
vbuffer = vectors.map(FileChannel.MapMode.READ_ONLY, 0, size);
dimension = vbuffer.getInt((int) (size - Integer.BYTES));

vectors = directory.openInput(dictName + ".bin", IOContext.READ);
long size = vectors.length();
vectors.seek(size - Integer.BYTES);
dimension = vectors.readInt();
if ((size - Integer.BYTES) % (dimension * Float.BYTES) != 0) {
throw new IllegalStateException(
"vector file size " + size + " is not consonant with the vector dimension " + dimension);
Expand Down Expand Up @@ -96,8 +95,8 @@ public void get(BytesRef token, byte[] output) throws IOException {
if (ord == null) {
Arrays.fill(output, (byte) 0);
} else {
vbuffer.position((int) (ord * dimension * Float.BYTES));
vbuffer.get(output);
vectors.seek(ord * dimension * Float.BYTES);
vectors.readBytes(output, 0, output.length);
}
}

Expand All @@ -122,11 +121,12 @@ public void close() throws IOException {
* and each line is space-delimited. The first column has the token, and the remaining columns
* are the vector components, as text. The dictionary must be sorted by its leading tokens
* (considered as bytes).
* @param dictOutput a dictionary path prefix. The output will be two files, named by appending
* '.fst' and '.bin' to this path.
* @param directory a Lucene directory to write the dictionary to.
* @param dictName Base name for the knn dictionary files.
*/
public static void build(Path gloveInput, Path dictOutput) throws IOException {
new Builder().build(gloveInput, dictOutput);
public static void build(Path gloveInput, Directory directory, String dictName)
throws IOException {
new Builder().build(gloveInput, directory, dictName);
}

private static class Builder {
Expand All @@ -140,25 +140,20 @@ private static class Builder {
private long ordinal = 1;
private int numFields;

void build(Path gloveInput, Path dictOutput) throws IOException {
String dictName = dictOutput.getFileName().toString();
Path fstPath = dictOutput.resolveSibling(dictName + ".fst");
Path binPath = dictOutput.resolveSibling(dictName + ".bin");
void build(Path gloveInput, Directory directory, String dictName) throws IOException {
try (BufferedReader in = Files.newBufferedReader(gloveInput);
OutputStream binOut = Files.newOutputStream(binPath);
DataOutputStream binDataOut = new DataOutputStream(binOut)) {
IndexOutput binOut = directory.createOutput(dictName + ".bin", IOContext.DEFAULT);
IndexOutput fstOut = directory.createOutput(dictName + ".fst", IOContext.DEFAULT)) {
writeFirstLine(in, binOut);
while (true) {
if (addOneLine(in, binOut) == false) {
break;
}
while (addOneLine(in, binOut)) {
// continue;
}
fstCompiler.compile().save(fstPath);
binDataOut.writeInt(numFields - 1);
fstCompiler.compile().save(fstOut, fstOut);
binOut.writeInt(numFields - 1);
}
}

private void writeFirstLine(BufferedReader in, OutputStream out) throws IOException {
private void writeFirstLine(BufferedReader in, IndexOutput out) throws IOException {
String[] fields = readOneLine(in);
if (fields == null) {
return;
Expand All @@ -178,7 +173,7 @@ private String[] readOneLine(BufferedReader in) throws IOException {
return SPACE_RE.split(line, 0);
}

private boolean addOneLine(BufferedReader in, OutputStream out) throws IOException {
private boolean addOneLine(BufferedReader in, IndexOutput out) throws IOException {
String[] fields = readOneLine(in);
if (fields == null) {
return false;
Expand All @@ -197,15 +192,21 @@ private boolean addOneLine(BufferedReader in, OutputStream out) throws IOException
return true;
}

private void writeVector(String[] fields, OutputStream out) throws IOException {
private void writeVector(String[] fields, IndexOutput out) throws IOException {
byteBuffer.position(0);
FloatBuffer floatBuffer = byteBuffer.asFloatBuffer();
for (int i = 1; i < fields.length; i++) {
scratch[i - 1] = Float.parseFloat(fields[i]);
}
VectorUtil.l2normalize(scratch);
floatBuffer.put(scratch);
out.write(byteBuffer.array());
byte[] bytes = byteBuffer.array();
out.writeBytes(bytes, bytes.length);
}
}

/** Return the size of the dictionary in bytes */
public long ramBytesUsed() {
return fst.ramBytesUsed() + vectors.length();
}
}
12 changes: 12 additions & 0 deletions lucene/demo/src/java/overview.html
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ <h1>Apache Lucene - Building and Installing the Basic Demo</h1>
<li><a href="#Location_of_the_source">Location of the source</a></li>
<li><a href="#IndexFiles">IndexFiles</a></li>
<li><a href="#Searching_Files">Searching Files</a></li>
<li><a href="#Embeddings">Working with vector embeddings</a></li>
</ul>
</div>
<a id="About_this_Document"></a>
Expand Down Expand Up @@ -203,6 +204,17 @@ <h2 class="boxed">Searching Files</h2>
<span class="codefrag">n</span> hits. The results are printed in pages, sorted
by score (i.e. relevance).</p>
</div>
<h2 id="Embeddings" class="boxed">Working with vector embeddings</h2>
<div class="section">
<p>In addition to indexing and searching text, IndexFiles and SearchFiles can also index and search
numeric vectors derived from that text, known as "embeddings." This demo code uses pre-computed embeddings
provided by the <a href="https://nlp.stanford.edu/projects/glove/">GloVe</a> project, which are in the public
domain. The dictionary here is a tiny subset of the full GloVe dataset. It includes only the words that occur
in the toy data set, and is definitely <i>not ready for production use</i>! If you use this code to create
a vector index for a larger document set, the indexer will throw an exception because
a more complete set of embeddings is needed to get reasonable results.
</p>
</div>
</body>
</html>

7 changes: 2 additions & 5 deletions lucene/demo/src/test/org/apache/lucene/demo/TestDemo.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import java.io.PrintStream;
import java.nio.charset.Charset;
import java.nio.file.Path;
import org.apache.lucene.demo.knn.KnnVectorDict;
import org.apache.lucene.util.LuceneTestCase;

public class TestDemo extends LuceneTestCase {
Expand Down Expand Up @@ -90,10 +89,8 @@ private void testVectorSearch(
public void testKnnVectorSearch() throws Exception {
Path dir = getDataPath("test-files/docs");
Path indexDir = createTempDir("ContribDemoTest");
Path dictPath = indexDir.resolve("knn-dict");
Path vectorDictSource = getDataPath("test-files/knn-dict").resolve("knn-token-vectors");
KnnVectorDict.build(vectorDictSource, dictPath);

Path vectorDictSource = getDataPath("test-files/knn-dict").resolve("knn-token-vectors");
IndexFiles.main(
new String[] {
"-create",
Expand All @@ -102,7 +99,7 @@ public void testKnnVectorSearch() throws Exception {
"-index",
indexDir.toString(),
"-knn_dict",
dictPath.toString()
vectorDictSource.toString()
});

// We add a single semantic hit by passing the "-knn_vector 1" argument to SearchFiles. The
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.file.Path;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.LuceneTestCase;
import org.apache.lucene.util.VectorUtil;
Expand All @@ -28,30 +29,31 @@ public class TestDemoEmbeddings extends LuceneTestCase {

public void testComputeEmbedding() throws IOException {
Path testVectors = getDataPath("../test-files/knn-dict").resolve("knn-token-vectors");
Path dictPath = createTempDir("knn-demo").resolve("dict");
KnnVectorDict.build(testVectors, dictPath);
try (KnnVectorDict dict = new KnnVectorDict(dictPath)) {
DemoEmbeddings demoEmbeddings = new DemoEmbeddings(dict);
try (Directory directory = newDirectory()) {
KnnVectorDict.build(testVectors, directory, "dict");
try (KnnVectorDict dict = new KnnVectorDict(directory, "dict")) {
DemoEmbeddings demoEmbeddings = new DemoEmbeddings(dict);

// test garbage
float[] garbageVector =
demoEmbeddings.computeEmbedding("garbagethathasneverbeen seeneverinlife");
assertEquals(50, garbageVector.length);
assertArrayEquals(new float[50], garbageVector, 0);
// test garbage
float[] garbageVector =
demoEmbeddings.computeEmbedding("garbagethathasneverbeen seeneverinlife");
assertEquals(50, garbageVector.length);
assertArrayEquals(new float[50], garbageVector, 0);

// test space
assertArrayEquals(new float[50], demoEmbeddings.computeEmbedding(" "), 0);
// test space
assertArrayEquals(new float[50], demoEmbeddings.computeEmbedding(" "), 0);

// test some real words that are in the dictionary and some that are not
float[] realVector = demoEmbeddings.computeEmbedding("the real fact");
assertEquals(50, realVector.length);
// test some real words that are in the dictionary and some that are not
float[] realVector = demoEmbeddings.computeEmbedding("the real fact");
assertEquals(50, realVector.length);

float[] the = getTermVector(dict, "the");
assertArrayEquals(new float[50], getTermVector(dict, "real"), 0);
float[] fact = getTermVector(dict, "fact");
VectorUtil.add(the, fact);
VectorUtil.l2normalize(the);
assertArrayEquals(the, realVector, 0);
float[] the = getTermVector(dict, "the");
assertArrayEquals(new float[50], getTermVector(dict, "real"), 0);
float[] fact = getTermVector(dict, "fact");
VectorUtil.add(the, fact);
VectorUtil.l2normalize(the);
assertArrayEquals(the, realVector, 0);
}
}
}

Expand Down

0 comments on commit 5896e53

Please sign in to comment.