In [1]:
import faiss
import json
from pathlib import Path
import numpy as np
from tqdm import tqdm
from sqlitedict import SqliteDict

In [2]:
# Path to data
path = Path("/scratch/project_462000615/ehenriks/llm-descriptor-evaluation/data/processed/descriptors_with_explainers_embeddings.jsonl")

In [3]:
def get_train_data(path, num_lines=0):
    with path.open("r") as f:
        data = []
        for line in tqdm(f):
            data.append(json.loads(line))
            if num_lines == 0:
                continue
            elif num_lines == len(data):
                break
    embeddings = np.vstack([doc["embedding"] for doc in data])
    return embeddings

train_data = get_train_data(path, 40000)

39999it [00:11, 3521.00it/s]


In [4]:
train_data[0].shape[0]

1024

In [10]:
type(train_data[0][0])

numpy.float64

In [5]:
embedding_dimensions = train_data[0].shape[0]
i = faiss.IndexFlatL2(embedding_dimensions)
i.add(train_data)
i.ntotal

40000

In [7]:
idsel = faiss.IDSelector(0)

AttributeError: No constructor defined - class is abstract

In [6]:
i.remove_ids(0)

AttributeError: 'int' object has no attribute 'ndim'

In [24]:
train_data[0] == i.reconstruct(0)

array([ True,  True,  True, ...,  True,  True,  True])

In [4]:
# Train IVFPQ
quantizer = faiss.IndexFlatL2(1024)
index = faiss.IndexIVF(quantizer,
                         1024, # embedding dim
                         1024, # num cells
                         64, # num quantizers
                         8, # quantizer bits
                        )
print(f"Index is trained: {index.is_trained}")
print("Training...")
index.train(train_data)
print(f"Index is trained: {index.is_trained}")

Index is trained: False
Training...
Index is trained: True


In [21]:
# Index is trained, but still empty
# We need to add embeddings to it
index.ntotal

0

In [22]:
def add_to_index(path, index):
    db = SqliteDict("descriptors.db")
    with path.open("r") as f:
        embeddings = []
        indices = []
        for idx, line in enumerate(tqdm(f)):
            doc = json.loads(line)
            emb = doc["embedding"]
            db[idx] = doc["descriptor"]
            embeddings.append(emb)
            indices.append(idx)
            if len(embeddings) >= 1000:
                index.add_with_ids(np.vstack(embeddings), indices)
                db.commit()
                embeddings = []
                indices = []
        if embeddings:
            index.add_with_ids(np.vstack(embeddings), indices)
            db.commit()    
            embeddings = []
            indices = []
    return index

In [23]:
index = add_to_index(path, index)

998277it [1:15:26, 220.56it/s]


In [24]:
###faiss.write_index(index, "../data/faiss/descriptors.index")

In [None]:
# JUST LOAD THE INDEX FROM HERE!

In [3]:
saved_index = faiss.read_index("../data/faiss/descriptors.index")

In [4]:
saved_index.ntotal

998277

In [7]:
saved_index = saved_index.reconstruct(0)

RuntimeError: Error in faiss::idx_t faiss::DirectMap::get(faiss::idx_t) const at /opt/faiss/faiss/invlists/DirectMap.cpp:83: direct map not initialized

In [27]:
search_index = 50
n_neighbours = 10

D, I = index.search(np.array([train_data[search_index]]),n_neighbours)
db = SqliteDict("descriptors.db")

print(f"Search text: {db[search_index]}")
print("======================")
print("Nearest neigbours:")
for i in I[0]:
    print(db[str(i)])

Search text: 16th-century setting: the time period in which the events of the document take place.: this descriptor is relevant to the document as it provides context for thomas treffry ii's life and career.
Nearest neigbours:
16th-century setting: the time period in which the events of the document take place.: this descriptor is relevant to the document as it provides context for thomas treffry ii's life and career.
sixteenth century setting: the document is set in the sixteenth century, which is relevant to understanding the subject's life and career.
historical context: the circumstances and events that surround a historical event or figure.: this descriptor is relevant to the document as it provides context for thomas treffry ii's life and career.
medieval setting: the document is set in the medieval period, specifically in the 15th and 16th centuries. this descriptor is general as it can be applied to other documents set in the same time period.
16th century: a film set in the 16

In [28]:
D

array([[0.10812277, 0.19638021, 0.2013426 , 0.2122599 , 0.21481532,
        0.21565974, 0.21701847, 0.21745032, 0.21797583, 0.2190612 ]],
      dtype=float32)

In [29]:
def find_top_n_neighbors(path_to_embeds, top_n=1, nprobe=10, exclude_self=True, stop_index=-1):
    path = Path(path_to_embeds)
    nearest_neighbors = {}
    if exclude_self:
        top_n += 1
    index.nprobe = nprobe
    with path.open("r") as f:
        for idx, line in enumerate(tqdm(f)):
            doc = json.loads(line)
            emb = np.array([doc["embedding"]])
            distances, indices = index.search(np.array(emb),top_n)
            if exclude_self:
                # Remove self-match (typically at index 0 for each query)
                neighbors = indices[:, 1:top_n + 1]
                nearest_neighbors[idx] = neighbors[0][0]
            else:
                neighbors = indices
                nearest_neighbors[idx] = neighbors[0][0]
            if idx == stop_index:
                break
    return nearest_neighbors

In [None]:
nn = find_top_n_neighbors(path)

128530it [23:16, 61.72it/s] 

In [None]:
def find_mutual_pairs(neighbor_dict):
    mutual_pairs = set()
    
    for a, b in neighbor_dict.items():
        if neighbor_dict.get(b) == a:
            # Store sorted tuple to avoid duplicates (e.g., (7, 8) and (8, 7))
            mutual_pairs.add(tuple(sorted((a, b))))
    
    return list(mutual_pairs)

In [None]:
mutual_pairs = find_mutual_pairs(nn)
print(len(mutual_pairs))

In [None]:
with open("mutual_pair_search_results.txt", "w") as f:
    f.write(f"Embedding in index: {index.ntotal}\n")
    f.write(f"Mutual pairs: {len(mutual_pairs)}\n"}
    for i in range(10):
        f.write(f"{db[str(mutual_pairs[i][0])]}\n")
        f.write(f"{db[str(mutual_pairs[i][1])]}\n")
        f.write("====================")

In [68]:
def show_mutual_nearest_pairs(mutual_pairs):
    for pair in mutual_pairs:
        print(db[str(pair[0])])
        print(db[str(pair[1])])
        print("=======================")

show_mutual_nearest_pairs(mutual_pairs)

3d game: the game is a 3d game, which implies a certain type of graphics and gameplay, and is relevant to the document as it describes the game's visual and interactive aspects.
3d game: the game is described as a 3d game, which means that it will have three-dimensional graphics. this type of game can provide a more immersive experience for players. the 3d graphics will likely enhance the game's visuals and gameplay.
1952: this descriptor provides the year the case was heard, giving historical context.
1993 case: the year the case was decided, providing context for legal precedents and historical reference.: the case was decided in 1993, which provides context for understanding the legal precedents and historical reference that were relevant at the time.
absurd: the document is absurd, meaning that it challenges traditional notions of logic and reason. this absurdity is used to create a sense of humor and irony, and to challenge the reader's expectations and assumptions about the natur