In [1]:
!pip install faiss-cpu

Defaulting to user installation because normal site-packages is not writeable


In [5]:
import json
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
import os
import torch

In [6]:
# === Configuration ===
# Input corpus, read downstream as JSON Lines (one JSON object per line).
# The location is machine-specific, so it can be overridden by setting the
# CHUNKS_JSON_PATH environment variable; the original path stays the default.
json_file_path = os.environ.get("CHUNKS_JSON_PATH", "/home/itewari1/NLP/Pruning/chunks.json")
faiss_index_path = "RCT_embeddings.index"    # Where the FAISS index is written (relative to the working directory)
id_mapping_path = "chunk_id_map.json"          # Where the FAISS row position -> chunk_id mapping is written

In [8]:
# === Load SentenceTransformer Model ===
# Use the GPU when torch can see one; otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2', device=device)

print("Model device:", model.device)

# === Initialize FAISS Index ===
# Take the vector width from the loaded model so the index always matches the
# encoder. all-MiniLM-L6-v2 produces 384-dimensional embeddings, which is also
# used as the fallback if the model does not report a dimension.
embedding_dim = model.get_sentence_embedding_dimension() or 384
index = faiss.IndexFlatL2(embedding_dim)  # flat index using L2 distance

# === Mappings ===
# NOTE: despite its name, this dict is keyed by FAISS row position and maps to
# the chunk_id stored at that row.
chunk_id_to_index = {}
index_counter = 0      # next FAISS row position to assign
all_embeddings = []    # embedding rows collected before being added to the index

Model device: cuda:0


In [10]:
# === Load JSON and process ===
# The input is parsed as JSON Lines: one JSON object per line. Blank lines are
# skipped so a stray empty line does not abort the load with a JSONDecodeError.
with open(json_file_path, 'r', encoding='utf-8') as f:
    data = [json.loads(line) for line in f if line.strip()]

# Keep only records that carry both non-empty text and a chunk_id.
texts = []
valid_chunk_ids = []

for item in data:
    # `or` also covers explicit nulls ("text": null / "metadata": null), which
    # the default argument of .get() does not.
    text = (item.get("text") or "").strip()
    metadata = item.get("metadata") or {}
    chunk_id = metadata.get("chunk_id", "")

    if text and chunk_id:
        texts.append(text)
        valid_chunk_ids.append(chunk_id)

# Make silently dropped records visible.
print(f"Loaded {len(data)} records; {len(texts)} have both text and chunk_id.")

# Fail before the expensive encoding step when there is nothing to embed.
if not texts:
    raise ValueError("No embeddings found. Aborting FAISS index creation.")

print("Embedding starts...")

# Batch encode
embeddings = model.encode(texts, convert_to_numpy=True, batch_size=64, show_progress_bar=True)

print("Storing Starts...")

# model.encode already returns one row per input text, so use that array
# directly instead of copying it row by row through a Python list.
all_embeddings = np.ascontiguousarray(embeddings, dtype='float32')
if all_embeddings.ndim != 2:
    raise ValueError("Expected 2D array for embeddings.")
if all_embeddings.shape[1] != index.d:
    raise ValueError(
        f"Embedding width {all_embeddings.shape[1]} does not match FAISS index dimension {index.d}."
    )

# Row i of the index holds the embedding of valid_chunk_ids[i]. The mapping is
# rebuilt from scratch (rather than appended to) so that re-running this cell
# gives the same result as running it once.
chunk_id_to_index = dict(enumerate(valid_chunk_ids))
index_counter = len(chunk_id_to_index)

# Drop any vectors left over from a previous run of this cell before adding,
# otherwise row positions would no longer line up with the mapping.
index.reset()
index.add(all_embeddings)

faiss.write_index(index, faiss_index_path)
# json.dump writes the integer row positions as string keys; convert them back
# with int() when loading the mapping.
with open(id_mapping_path, 'w', encoding='utf-8') as f:
    json.dump(chunk_id_to_index, f, indent=2)

print(f"✅ Stored {len(chunk_id_to_index)} embeddings to FAISS at: {faiss_index_path}")
print(f"📝 Mapping saved at: {id_mapping_path}")

Embedding starts...


Batches:   0%|          | 0/45006 [00:00<?, ?it/s]

Storing Starts...
✅ Stored 2880373 embeddings to FAISS at: RCT_embeddings.index
📝 Mapping saved at: chunk_id_map.json
