In [2]:
import json
import torch
import pandas as pd
from pathlib import Path
from multiprocessing import Pool, cpu_count
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from tqdm import tqdm

In [3]:
import faiss

In [4]:
import torch
# Report the CUDA version PyTorch was built against, then whether a GPU is usable.
print(torch.version.cuda, torch.cuda.is_available(), sep="\n")

11.8
True


In [5]:
# ✅ Set GPU or CPU for embeddings: prefer CUDA whenever PyTorch reports a usable GPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [6]:
# ✅ HuggingFace Embedding Model, placed on `device` (GPU when available)
MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"  # Small, fast sentence-transformers model
# Pass the device chosen above explicitly. Without model_kwargs the `device` variable
# is never used, and model placement is left to sentence-transformers' own default.
embeddings = HuggingFaceEmbeddings(
    model_name=MODEL_NAME,
    model_kwargs={"device": device},
)

  from .autonotebook import tqdm as notebook_tqdm


In [7]:
# ✅ Text splitter with large chunks: fewer chunks means fewer embeddings to compute.
# The overlap keeps text that is cut at a chunk boundary intact in the neighbouring chunk.
# NOTE(review): sizes are measured in characters. The all-MiniLM-L6-v2 model card says
# inputs longer than 256 word pieces are truncated, so a full 2000-character chunk may
# only be partially embedded. Confirm this trade-off is intended.
text_splitter = RecursiveCharacterTextSplitter(chunk_overlap=400, chunk_size=2000)

In [8]:
# ✅ Constants
VECTOR_DB_PATH = "../faiss_db"  # Output folder for the saved FAISS index
CHUNK_SIZE = 10000  # JSON lines read per pandas chunk while streaming the dataset
EMBED_BATCH_SIZE = 1000  # Text chunks embedded per call (lower this if you hit OOM)
SAVE_INTERVAL = 50000  # Checkpoint the FAISS DB about every 50K embedded text chunks (not papers)
NUM_CORES = min(4, cpu_count())  # Worker processes for text splitting, capped at this machine's CPU count
MAX_PAPERS = 500000  # Keep at most 500K (5 lakh) matching papers


In [9]:
# ✅ Load Research Papers Efficiently
def load_papers(file_path, max_papers=MAX_PAPERS, chunk_size=CHUNK_SIZE, category_prefix="cs."):
    """Stream a JSON-lines paper dump and keep papers from one category family.

    Args:
        file_path: path to a JSON-lines file, one paper record per line, each with a
            whitespace-separated ``categories`` string (e.g. ``"cs.AI math.CO"``).
        max_papers: maximum number of matching papers to return; ``None`` means no limit.
        chunk_size: number of lines pandas reads per chunk.
        category_prefix: a paper is kept when any of its category tokens starts with
            this prefix (default ``"cs."``).

    Returns:
        List of record dicts, at most ``max_papers`` long. If reading fails part-way,
        the error is printed and the papers collected so far are returned.
    """
    import re

    # A category token must *start* with the prefix, so "cs.AI" matches and
    # "physics.comp-ph" does not. The prefix is escaped because str.contains treats its
    # pattern as a regex, where a bare "." would match any character.
    pattern = r"(?:^|\s)" + re.escape(category_prefix)
    papers = []
    try:
        # Using the reader as a context manager closes the file even when we break early.
        with pd.read_json(file_path, lines=True, chunksize=chunk_size) as reader:
            for chunk in tqdm(reader, desc="Loading Papers"):
                mask = chunk["categories"].str.contains(pattern, regex=True, na=False)
                papers.extend(chunk[mask].to_dict(orient="records"))
                if max_papers is not None and len(papers) >= max_papers:
                    break  # Enough papers collected
    except Exception as e:
        # Best-effort: report the error and keep whatever was loaded before it.
        print(f"Error loading papers: {e}")
    return papers if max_papers is None else papers[:max_papers]

In [10]:
# ✅ Process Text into Chunks
def process_text(paper, splitter=None):
    """Format one paper as "Title / Authors / Abstract" text and split it into chunks.

    Args:
        paper: record dict with ``title``, ``authors`` and ``abstract`` keys.
            ``authors`` may be a single string (as in the arXiv metadata snapshot)
            or a sequence of author names.
        splitter: object with a ``split_text(str) -> list[str]`` method; defaults to
            the module-level ``text_splitter``.

    Returns:
        List of text chunks produced by the splitter.
    """
    authors = paper["authors"]
    # Joining a plain string would put ", " between every character, so only join
    # real sequences of names.
    if not isinstance(authors, str):
        authors = ", ".join(authors)
    text = f"Title: {paper['title']}\nAuthors: {authors}\nAbstract: {paper['abstract']}\n"
    active_splitter = text_splitter if splitter is None else splitter
    return active_splitter.split_text(text)


In [11]:
# ✅ Parallel Processing of Text Splitting
def parallel_process_papers(papers, num_workers=None):
    """Split every paper into text chunks, optionally across worker processes.

    Args:
        papers: list of paper record dicts (as returned by ``load_papers``).
        num_workers: number of processes to use; defaults to ``NUM_CORES``. With 1 or
            fewer, or with no papers, the work runs in this process and no pool is made.

    Returns:
        Flat list of chunk strings, in paper order.
    """
    if not papers:
        return []
    if num_workers is None:
        num_workers = NUM_CORES
    if num_workers <= 1:
        # Serial path. Pool workers must import functions defined in the notebook's
        # __main__, which fails under the "spawn" start method (default on Windows/macOS).
        per_paper = [process_text(paper) for paper in papers]
    else:
        with Pool(num_workers) as pool:
            per_paper = pool.map(process_text, papers)
    return [chunk for chunks in per_paper for chunk in chunks]


In [12]:
# ✅ Batch Embeddings
def embed_texts(texts):
    """Embed a batch of texts with the module-level ``embeddings`` model.

    Runs on whichever device that model was created on (GPU only if it was placed there).

    Returns:
        One embedding vector per input text, in the same order.
    """
    batch_vectors = embeddings.embed_documents(texts)
    return batch_vectors

In [13]:
# ✅ Save FAISS Vector Database
def save_vector_db(documents, vectors, path=None):
    """Build a FAISS store from texts and their precomputed embeddings, then save it.

    Args:
        documents: chunk texts; ``documents[k]`` is the text that was embedded as ``vectors[k]``.
        vectors: embedding vectors, same length as ``documents``.
        path: output folder; defaults to ``VECTOR_DB_PATH``.

    Returns:
        The in-memory ``FAISS`` vector store that was saved.

    Raises:
        ValueError: if there are no vectors, or ``documents`` and ``vectors`` differ in length.
    """
    if not vectors:
        raise ValueError("save_vector_db: no vectors to save")
    if len(documents) != len(vectors):
        raise ValueError(
            f"save_vector_db: got {len(documents)} documents but {len(vectors)} vectors"
        )
    target = VECTOR_DB_PATH if path is None else path

    # from_embeddings stores each text in the docstore next to its vector, so searches
    # return the matching text. Building FAISS(...) from a bare index would need a
    # docstore and an index_to_docstore_id mapping as well. With the default Euclidean
    # distance strategy this still uses an exact L2 (IndexFlatL2) index.
    vector_store = FAISS.from_embeddings(
        text_embeddings=list(zip(documents, vectors)),
        embedding=embeddings,
    )
    vector_store.save_local(target)
    print(f"✔ Vector DB saved to {target}")
    return vector_store

In [14]:
# ✅ Create FAISS Vector Database
def create_vector_db(papers):
    """Split papers into text chunks, embed them in batches, and save a FAISS index.

    A checkpoint of everything embedded so far is saved about every ``SAVE_INTERVAL``
    chunks, and a final save happens after the last batch.

    Args:
        papers: list of paper record dicts (as returned by ``load_papers``).

    Returns:
        List of embedding vectors, one per text chunk, in chunk order
        (empty if the papers produced no chunks).
    """
    print("Processing papers and creating vector database...")
    documents = parallel_process_papers(papers)
    total = len(documents)
    if total == 0:
        print("No text chunks produced; nothing to embed or save.")
        return []

    vectors = []
    last_saved = 0
    for start in tqdm(range(0, total, EMBED_BATCH_SIZE), desc="Generating Embeddings"):
        vectors.extend(embed_texts(documents[start:start + EMBED_BATCH_SIZE]))
        done = len(vectors)  # Chunks embedded so far (exact, even for a short final batch)

        # Count-based checkpointing works for any SAVE_INTERVAL, even one that is not a
        # multiple of EMBED_BATCH_SIZE.
        if done >= total or done - last_saved >= SAVE_INTERVAL:
            # NOTE(review): each save rebuilds the index from all vectors so far.
            save_vector_db(documents[:done], vectors)
            last_saved = done
            print(f"✔ Saved FAISS Vector DB at {done} documents")

    return vectors


In [None]:
# ✅ Main Execution
# In a notebook `__name__` is always "__main__", so this cell always runs. The guard
# only matters if the code is exported to a script.
if __name__ == "__main__":
    file_path = "../Dataset/arxiv-metadata-oai-snapshot.json"

    # FAISS.save_local writes "index.faiss" into the folder. Test for that file rather
    # than the folder, so an empty or half-created folder still triggers a build.
    # NOTE(review): an interrupted run can leave an intermediate checkpoint here, which
    # this check treats as complete. Delete the folder to force a rebuild.
    if not (Path(VECTOR_DB_PATH) / "index.faiss").exists():
        print("Creating new vector database...")
        papers = load_papers(file_path)
        if not papers:
            raise RuntimeError(f"No papers loaded from {file_path}; cannot build the vector DB.")
        create_vector_db(papers)

    # Load inline: `load_vector_db` is defined in a later cell, so calling it here
    # raises NameError on a fresh kernel (Restart & Run All).
    print("Loading FAISS vector database...")
    vector_store = FAISS.load_local(VECTOR_DB_PATH, embeddings, allow_dangerous_deserialization=True)

Creating new vector database...


Loading Papers: 28it [00:11,  2.36it/s]


Processing papers and creating vector database...


In [15]:
# ✅ Load FAISS Vector Database
def load_vector_db():
    """Load the FAISS store saved at ``VECTOR_DB_PATH`` using the shared ``embeddings`` model.

    ``allow_dangerous_deserialization=True`` lets LangChain unpickle the saved docstore.
    Only use it on files this notebook wrote itself, never on untrusted input.
    """
    print("Loading FAISS vector database...")
    store = FAISS.load_local(
        VECTOR_DB_PATH,
        embeddings,
        allow_dangerous_deserialization=True,
    )
    return store

In [16]:
# Reload the saved FAISS index from disk into `vector_store`. Kept as its own cell so the
# index can be reloaded without re-running the build step.
vector_store = load_vector_db()

Loading FAISS vector database...
