In this notebook we generate the faiss index with their corresponding passages. The input for the functions is the ATLAS index shards directory, which can be downloaded using:
```
python atlas/preprocessing/download_index.py --index indices/atlas/wiki/base --output_directory {OUTPUT_INDEX_DIR} 
```

In [1]:
import numpy as np
import faiss
import torch
import os
import pickle

In [2]:
# ATLAS index shards
output_index_dir = '/home/tkolb/data/indices/atlas/wiki/base'

# Index output path
index_path = '/home/tkolb/data/faiss_index.index'

# Passages output path
passages_path = f'/home/tkolb/data/wiki_passages.pkl'

In [3]:
# Load index in chunks that fit into GPU memory
def load_embeddings_in_chunks(path, chunk_size=4):
    embeddings = []
    embeddings_files = sorted([f for f in os.listdir(path) if f.startswith('embeddings')], key=lambda f: int(f.split('.')[1]))
    for filename in embeddings_files:
        print(filename)
        file_path = os.path.join(path, filename)
        data = torch.load(file_path, map_location='cpu')
        embeddings.append(data)
        if len(embeddings) == chunk_size:
            yield torch.cat(embeddings, dim=1)
            embeddings = []
    if embeddings:
        yield torch.cat(embeddings, dim=1)

In [4]:
# Build FAISS .index file with chunk_size and max number of files (num_files)
def build_faiss_index_incrementally(vectors_path, index_path, chunk_size=4, num_files=28):
    d = None
    index = None

    for i, embeddings in enumerate(load_embeddings_in_chunks(vectors_path, chunk_size)):
        embeddings = embeddings.swapaxes(0, 1)
        embeddings_np = np.ascontiguousarray(embeddings.numpy()).astype(np.float32)
        if d is None:
            d = embeddings_np.shape[1]
            index = faiss.IndexFlatL2(d)
        index.add(embeddings_np)
        
        if (i+1)*chunk_size >= num_files:
            break
        
    # index = faiss.read_index(index_path)
    faiss.write_index(index, index_path)
    return index

In [5]:
index = build_faiss_index_incrementally(output_index_dir, index_path)

FileNotFoundError: [Errno 2] No such file or directory: '/home/tkolb/data/indices/atlas/wiki/base'

In [4]:
# Load index to test
index = faiss.read_index(index_path)
index.d, index.ntotal

(768, 7030352)

In [14]:
# Build complete wiki passages file from shards with max number of files (num_files)
def build_passages(output_index_dir, passages_path, num_files=28):
    passages_list = []
    passages_files = sorted([f for f in os.listdir(output_index_dir) if f.startswith('passages')], key=lambda f: int(f.split('.')[1]))
    for i, filename in enumerate(passages_files):
        print(filename)
        file_path = os.path.join(output_index_dir, filename)
        with open(file_path, "rb") as fobj:
            passages = pickle.load(fobj)
            passages_list += passages
            
        if i == num_files-1:
            break
    
    with open(passages_path, 'wb') as f:
        pickle.dump(passages_list, f)

In [15]:
build_passages(output_index_dir, passages_path)

passages.0.pt
passages.1.pt
passages.2.pt
passages.3.pt
passages.4.pt
passages.5.pt
passages.6.pt
passages.7.pt
passages.8.pt
passages.9.pt
passages.10.pt
passages.11.pt
passages.12.pt
passages.13.pt
passages.14.pt
passages.15.pt
passages.16.pt
passages.17.pt
passages.18.pt
passages.19.pt
passages.20.pt
passages.21.pt
passages.22.pt
passages.23.pt
passages.24.pt
passages.25.pt
passages.26.pt
passages.27.pt


In [6]:
# Load passages to test
with open(passages_path, 'rb') as f:
    passages = pickle.load(f)
len(passages)

7030352