Note: steps (1)-(3) only need to be executed **once**.

If you have already built the index files and stored them locally, you can skip straight to step (4).

*Want to see this in action? Our viewer app is based on outputs from this exact notebook!* https://huggingface.co/spaces/Major-TOM/MajorTOM-Core-Viewer

In [25]:
import gc
import json
import os

import faiss
import geopandas as gpd
import numpy as np
import pandas as pd

def get_parquet_files(directory):
    """Return the sorted full paths of every ``*.parquet`` entry in ``directory``.

    Only the immediate directory listing is scanned (no recursion), and the
    ``.parquet`` suffix check is case-sensitive.
    """
    parquet_paths = []
    for entry in os.listdir(directory):
        if not entry.endswith('.parquet'):
            continue
        parquet_paths.append(os.path.join(directory, entry))
    parquet_paths.sort()
    return parquet_paths

def read_vectors_from_parquet(filepath):
    """Load the ``embedding`` column of a parquet file as a FAISS-ready matrix.

    Every row's embedding is stacked into a single 2-D array, cast to float32
    and made C-contiguous (both required by FAISS), then L2-normalised in place
    so that L2 search over the result ranks neighbours like cosine similarity.

    Args:
        filepath: path to a parquet file containing an ``embedding`` column.

    Returns:
        np.ndarray: float32, C-contiguous matrix with one row per parquet row.
    """
    embeddings = pd.read_parquet(filepath, columns=['embedding'])['embedding']

    # np.stack copies every row into a new block (and astype copies again),
    # so peak memory is a multiple of the column size.
    matrix = np.ascontiguousarray(np.stack(embeddings.values).astype('float32'))

    faiss.normalize_L2(matrix)
    return matrix

In [26]:
# --- CONFIGURATION ---
# Root of the local Major-TOM Core-S2RGB-SigLIP copy; the build outputs are written here.
BASE_DIR = 'data/Major-TOM/Core-S2RGB-SigLIP'
# Directory holding the source embedding parquet shards.
DATA_DIR = BASE_DIR + '/embeddings/'

# 1. Train the FAISS Index

In [None]:
 # ==========================================
# STEP 1: TRAIN THE INDEX
# ==========================================
print("--- STEP 1: TRAINING ---")
files = get_parquet_files(DATA_DIR)

TARGET_TRAIN_SIZE = 1_500_000  # total vectors to sample for IVF-PQ training
SAMPLES_PER_FILE = 30_000      # per-file cap so the sample spans many files
SEED = 42                      # fixed seed -> the same training sample on every run
rng = np.random.default_rng(SEED)

train_vectors = []
current_count = 0

for f in files:
    print(f"Loading {f} for training sample...")
    vecs = read_vectors_from_parquet(f)

    # Random subsample (without replacement) so no single file dominates the sample.
    n_take = min(SAMPLES_PER_FILE, vecs.shape[0])
    indices = rng.choice(vecs.shape[0], size=n_take, replace=False)
    train_vectors.append(vecs[indices])  # fancy indexing copies, so vecs can be freed
    current_count += n_take
    del vecs  # release the full file matrix before loading the next one

    if current_count >= TARGET_TRAIN_SIZE:
        break

if not train_vectors:
    raise RuntimeError(f"No parquet files found in {DATA_DIR}")

# Stack all training samples, and only then release the per-file list
train_matrix = np.vstack(train_vectors)
del train_vectors
gc.collect()

print(f"Training set shape: {train_matrix.shape}")
np.save(f'{BASE_DIR}/train_matrix.npy', train_matrix)

In [None]:
train_matrix = np.ascontiguousarray(np.load(f'{BASE_DIR}/train_matrix.npy'), dtype='float32')

# FAISS Hyperparameters
# Embedding dimension; must equal the width of the stored SigLIP vectors.
D = 1152
# nlist rule of thumb: ~4 * sqrt(N). For N ~ 20M that is ~18k, so a power of
# two such as 16384 or 32768 is a reasonable choice.
NLIST = 32768
# PQ splits each vector into M sub-vectors of D / M dims, so M must divide D.
# M=32 -> 36-dim sub-vectors; each stored code takes M * NBITS / 8 = 32 bytes.
M = 32
NBITS = 8

# Fail fast on settings FAISS would reject only after the GPU setup below.
assert D % M == 0, f"M={M} must divide D={D}"
assert train_matrix.ndim == 2 and train_matrix.shape[1] == D, (
    f"expected training vectors of shape (n, {D}), got {train_matrix.shape}"
)
# Coarse k-means needs at least one training point per list (FAISS recommends ~39+).
assert train_matrix.shape[0] >= NLIST, (
    f"need at least NLIST={NLIST} training vectors, got {train_matrix.shape[0]}"
)

# GPU clone options: PQ lookup tables stored in float16
cloner_options = faiss.GpuClonerOptions()
cloner_options.useFloat16LookupTables = True

# Create the (untrained) IVF-PQ index with an exact L2 coarse quantizer
quantizer = faiss.IndexFlatL2(D)
index = faiss.IndexIVFPQ(quantizer, D, NLIST, M, NBITS)

# Train on GPU for speed, but populate on CPU: the growing index then lives in
# host RAM instead of competing for VRAM with the batch data.
res = faiss.StandardGpuResources()  # must stay alive while gpu_index exists
gpu_index = faiss.index_cpu_to_gpu(res, 0, index, cloner_options)

print("Training index (this may take a few minutes)...")
gpu_index.train(train_matrix)

# Move the trained index back to CPU to populate data (safe for large RAM)
index = faiss.index_gpu_to_cpu(gpu_index)

# Clean up memory
del train_matrix, gpu_index
gc.collect()

# 2. Encode and Index All Vectors

In [None]:
print("\n--- STEP 2: INDEXING & METADATA ---")

# Metadata row i must describe FAISS ID i, so this cell has to start from a freshly
# trained, empty index; re-running it would append duplicate vectors.
assert index.ntotal == 0, (
    f"index already holds {index.ntotal} vectors - re-run the training cell first"
)

metadata_chunks = []  # one small (Geo)DataFrame per parquet file
total_vectors = 0

files = get_parquet_files(DATA_DIR)

for f_path in files:
    f_name = os.path.basename(f_path)
    print(f"Processing {f_name}...")

    # 1. Read vectors plus the columns needed to map hits back to grid cells
    df = gpd.read_parquet(f_path, columns=['embedding', 'grid_cell', 'geometry'])

    # 2. Stack into a float32, C-contiguous matrix and L2-normalise (same prep as training)
    vecs = np.ascontiguousarray(np.stack(df['embedding'].values).astype('float32'))
    faiss.normalize_L2(vecs)

    # 3. Add to the index; without explicit ids FAISS numbers vectors in insertion order
    index.add(vecs)

    # 4. Metadata chunk: keep only what search results need, plus 'file' and
    #    'row_idx' so each hit can be traced back to its original parquet row.
    meta_chunk = df[['grid_cell', 'geometry']].copy()
    meta_chunk['file'] = f_name
    meta_chunk['row_idx'] = np.arange(len(df), dtype=np.int32)
    metadata_chunks.append(meta_chunk)

    total_vectors += len(df)

    # Clean up
    del df, vecs, meta_chunk
    gc.collect()

# FAISS IDs and metadata rows must stay aligned one-to-one
assert index.ntotal == total_vectors, (index.ntotal, total_vectors)
print(f"Total vectors indexed: {total_vectors}")

# 3. Save the Index and the FAISS ID → grid_cell Mapping

In [None]:
INDEX_OUTPUT = f'{BASE_DIR}/siglip_ivfpq.index'
METADATA_OUTPUT = f'{BASE_DIR}/siglip_ivfpq_metadata.parquet'

# --- STEP 3: SAVING ---
print("\n--- STEP 3: SAVING ---")

# Combine all chunks into one table: row 0 corresponds to FAISS ID 0, and so on.
full_metadata = gpd.GeoDataFrame(pd.concat(metadata_chunks, axis=0, ignore_index=True))

# Search maps hits to metadata with .iloc[faiss_id], so refuse to write a mismatched pair.
assert len(full_metadata) == index.ntotal, (
    f"metadata has {len(full_metadata)} rows but the index holds {index.ntotal} vectors"
)

# Save FAISS Index
print(f"Writing index to {INDEX_OUTPUT}...")
faiss.write_index(index, INDEX_OUTPUT)

# Save as parquet (efficient compression for repeated filenames/grid_cells)
print(f"Saving metadata to {METADATA_OUTPUT}...")
full_metadata.to_parquet(METADATA_OUTPUT, index=False)

print("Done! Metadata shape:", full_metadata.shape)
del full_metadata
gc.collect()

# 4. Inference Example

In [1]:
import faiss
import pandas as pd
import numpy as np
import torch

def search_with_grid_id(query_vec, k=5, index=None, metadata=None):
    """Search a FAISS index with one query vector and attach the matching metadata rows.

    Args:
        query_vec: a single query embedding (numpy array or torch tensor).
        k: number of neighbours to request.
        index: FAISS index to search; defaults to the global ``gpu_index``.
        metadata: DataFrame whose row i describes FAISS ID i; defaults to the
            global ``metadata_df``.

    Returns:
        list[dict]: one dict per hit with the metadata columns plus ``score``,
        the raw FAISS distance (for an L2 index, lower means closer).
    """
    # Fall back to the notebook globals only when no explicit objects are given,
    # so the function is usable even if those globals were never created.
    if index is None:
        index = gpu_index
    if metadata is None:
        metadata = metadata_df

    # Prepare query as a float32 (1, D) batch; astype copies, so the in-place
    # normalisation below never modifies the caller's array.
    if isinstance(query_vec, torch.Tensor):
        query_vec = query_vec.detach().cpu().numpy()
    query_vec = query_vec.reshape(1, -1).astype('float32')
    faiss.normalize_L2(query_vec)

    # Search, then flatten the single-query batch
    distances, indices = index.search(query_vec, k)
    ids, scores = indices[0], distances[0]

    # FAISS pads with -1 when fewer than k neighbours are found
    # (e.g. the probed IVF lists hold fewer than k vectors).
    valid_mask = ids != -1
    valid_ids = ids[valid_mask]
    if len(valid_ids) == 0:
        return []

    # Metadata row i describes FAISS ID i, so a positional lookup maps IDs to rows
    matches = metadata.iloc[valid_ids].copy()
    matches['score'] = scores[valid_mask]

    # Convert to list of dicts for easy usage
    return matches.to_dict(orient='records')

In [2]:
import faiss
import pandas as pd
import geopandas as gpd
import torch
from open_clip import create_model_from_pretrained, get_tokenizer

class SearchSigLIP():
    """Free-text search over a prebuilt FAISS index of SigLIP embeddings.

    On construction it loads the FAISS index from disk and clones it to GPU 0.
    It also loads the metadata table, whose row i describes FAISS ID i, and the
    open_clip SigLIP model plus tokenizer used to embed text queries.
    """

    # One identifier for both model and tokenizer so the text pipeline stays matched.
    MODEL_NAME = 'hf-hub:timm/ViT-SO400M-14-SigLIP-384'

    def __init__(self, index_path, metadata_path, nprobe=32, model_name=None):
        """
        Args:
            index_path: path to a FAISS index written with ``faiss.write_index``.
            metadata_path: parquet file whose row order matches the FAISS IDs.
            nprobe: IVF lists visited per query (higher = more accurate, slower).
            model_name: open_clip identifier; defaults to ``MODEL_NAME``.
        """
        self.nprobe = nprobe
        self.model_name = model_name or self.MODEL_NAME

        # 1. Initialise Index
        print(f'Loading index from PATH={index_path}')
        self.index_path = index_path
        self.init_index()
        print('[DONE]')

        # 2. Initialise Metadata
        print(f'Loading metadata from PATH={metadata_path}')
        self.metadata_path = metadata_path
        self.metadata_df = pd.read_parquet(self.metadata_path)
        print('[DONE]')

        # 3. Initialise Text Encoder
        self.init_model()

    def init_index(self):
        """Read the CPU index from ``self.index_path`` and clone it onto GPU 0."""
        self.cpu_index = faiss.read_index(self.index_path)
        # Held on the instance so the GPU resources outlive the GPU index that uses them.
        self.gpu_res = faiss.StandardGpuResources()
        cloner_options = faiss.GpuClonerOptions()
        cloner_options.useFloat16LookupTables = True  # PQ lookup tables in float16 on GPU
        self.gpu_index = faiss.index_cpu_to_gpu(self.gpu_res, 0, self.cpu_index, cloner_options)
        self.gpu_index.nprobe = self.nprobe  # Higher = more accurate, slower

    def init_model(self):
        """Load the SigLIP model (in eval mode) and its tokenizer from the same identifier."""
        self.model, self.preprocess = create_model_from_pretrained(self.model_name)
        self.model.eval()
        self.tokenizer = get_tokenizer(self.model_name)

    def encode_text(self, text, device='cuda'):
        """Embed one text string with the SigLIP text tower (as a batch of one) on ``device``."""
        self.model.to(device)
        with torch.no_grad():
            tokens = self.tokenizer([text], context_length=self.model.context_length)
            return self.model.encode_text(tokens.to(device))

    def search_with_grid(self, query_vec, k=5):
        """Return up to ``k`` nearest metadata rows for one query embedding.

        Each result dict holds the metadata columns plus ``score``, the raw FAISS
        distance (for an L2 index, lower means closer).
        """
        # Prepare query as a float32 (1, D) batch; astype copies, so the in-place
        # normalisation below never modifies the caller's data.
        if isinstance(query_vec, torch.Tensor):
            query_vec = query_vec.detach().cpu().squeeze().numpy()
        query_vec = query_vec.reshape(1, -1).astype('float32')
        faiss.normalize_L2(query_vec)

        # Search, then flatten the single-query batch
        distances, indices = self.gpu_index.search(query_vec, k)
        ids, scores = indices[0], distances[0]

        # FAISS pads with -1 when fewer than k neighbours are found
        # (e.g. the probed IVF lists hold fewer than k vectors).
        valid_mask = ids != -1
        valid_ids = ids[valid_mask]
        if len(valid_ids) == 0:
            return []

        # Metadata row i describes FAISS ID i, so a positional lookup maps IDs to rows
        matches = self.metadata_df.iloc[valid_ids].copy()
        matches['score'] = scores[valid_mask]

        # Convert to list of dicts for easy usage
        return matches.to_dict(orient='records')

    def faiss(self, text, k=1): # k - number of neighbours
        """Embed ``text`` and return its ``k`` nearest indexed entries."""
        q = self.encode_text(text)
        return self.search_with_grid(q, k=k)


# Location of the artifacts produced by steps (1)-(3)
BASE_DIR = 'data/Major-TOM/Core-S2RGB-SigLIP'
INDEX_OUTPUT = BASE_DIR + '/siglip_ivfpq.index'
METADATA_OUTPUT = BASE_DIR + '/siglip_ivfpq_metadata.parquet'

# Loads the index, the metadata table and the text encoder in one go
search = SearchSigLIP(
    index_path=INDEX_OUTPUT,
    metadata_path=METADATA_OUTPUT,
)



Loading index from PATH=data/Major-TOM/Core-S2RGB-SigLIP/siglip_ivfpq.index
[DONE]
Loading metadata from PATH=data/Major-TOM/Core-S2RGB-SigLIP/siglip_ivfpq_metadata.parquet
[DONE]


In [3]:
# Embed the text query and return its 10 nearest indexed entries (k = number of neighbours)
search.faiss("London", k=10)

[{'grid_cell': '573U_2L',
  'geometry': b'\x01\x03\x00\x00\x00\x01\x00\x00\x00\x05\x00\x00\x00S>\xa8\x98\x99\xec\xc7\xbf\x81/\xaf\xfa\x81\xbfI@h$\xcc\xca%2\xc8\xbf]\xa6\xef\xea\x17\xbbI@\x1c\x8b6SHC\xcf\xbfY\x9c\xac\xe4B\xbbI@\xdb\xb7&\xbb\x18\xff\xce\xbfa:\xef\x01\xad\xbfI@S>\xa8\x98\x99\xec\xc7\xbf\x81/\xaf\xfa\x81\xbfI@',
  'file': 'part_03601-03700.parquet',
  'row_idx': 114071,
  'score': 1.7111053466796875},
 {'grid_cell': '594U_19L',
  'geometry': b'\x01\x03\x00\x00\x00\x01\x00\x00\x00\x05\x00\x00\x00b\x07\xca\xd7\xc4s\x06\xc0er\xa9\xe5\x1b\xb9J@\x19 L\xfb\x16t\x06\xc0Hpv\xda\xb0\xb4J@\xe8\xd8c\x12g\xea\x06\xc0\x990\xadu\xb3\xb4J@{j\xa2z-\xea\x06\xc0\x14\x0c\xb6\x81\x1e\xb9J@b\x07\xca\xd7\xc4s\x06\xc0er\xa9\xe5\x1b\xb9J@',
  'file': 'part_03701-03800.parquet',
  'row_idx': 12447,
  'score': 1.7131608724594116},
 {'grid_cell': '565U_2L',
  'geometry': b'\x01\x03\x00\x00\x00\x01\x00\x00\x00\x05\x00\x00\x00\x9d\xea\x80\xf9\x83K\xcd\xbf\x91\xf8\xf2&\x8ckI@\xff2\x95\\q\x8e\xcd\xbfSgC