In [5]:
# Installed the required libraries for this project
!pip install torch transformers faiss-cpu tqdm

Collecting faiss-cpu
  Downloading faiss_cpu-1.12.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (5.1 kB)
Downloading faiss_cpu-1.12.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (31.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m31.4/31.4 MB[0m [31m22.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: faiss-cpu
Successfully installed faiss-cpu-1.12.0


In [6]:
import os
import json
import numpy as np
from pathlib import Path
from tqdm import tqdm

import torch
from transformers import T5Tokenizer, T5EncoderModel
import faiss

In [11]:
from transformers import ByT5Tokenizer, T5EncoderModel
import torch

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Tokenizer comes from the Google ByT5 checkpoint.
tokenizer = ByT5Tokenizer.from_pretrained("google/byt5-small")

# Encoder weights come from the Buddhist NLP Sanskrit checkpoint; move them to
# the selected device and switch to inference mode.
model = T5EncoderModel.from_pretrained("buddhist-nlp/byt5-sanskrit")
model = model.to(device)
model.eval()

print("Tokenizer: google/byt5-small")
print("Model: buddhist-nlp/byt5-sanskrit")

pytorch_model.bin:   0%|          | 0.00/2.33G [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.33G [00:00<?, ?B/s]

Tokenizer: google/byt5-small
Model: buddhist-nlp/byt5-sanskrit


In [13]:
# --- Helper: Embed text using ByT5 (mean pooling, L2 normalized) ---
@torch.no_grad()
def embed_texts(texts, max_length=512):
    """Embed texts with the module-level encoder and return unit-length vectors.

    Each text is tokenized on its own (truncated to ``max_length`` tokens), run
    through ``model``, mean-pooled over the sequence axis and L2-normalized, so
    the inner product of two rows equals their cosine similarity.

    Args:
        texts: Iterable of strings. A single string is treated as one text
            rather than being iterated character by character.
        max_length: Maximum number of tokens kept per text; longer inputs are
            truncated by the tokenizer.

    Returns:
        A float32 numpy array with one row per text. For empty input the
        result has zero rows and ``model.config.d_model`` columns.
    """
    # A bare string would otherwise be embedded one character at a time.
    if isinstance(texts, str):
        texts = [texts]

    vectors = []
    for t in texts:
        inputs = tokenizer(t, return_tensors="pt", truncation=True, max_length=max_length).to(device)
        outputs = model(**inputs)
        # One text per forward pass means no padding, so a plain mean is safe.
        vec = outputs.last_hidden_state.mean(dim=1)
        vec = torch.nn.functional.normalize(vec, p=2, dim=1)  # cosine similarity
        vectors.append(vec.squeeze(0).cpu().numpy())

    # np.vstack raises on an empty list; return a well-formed empty matrix instead.
    if not vectors:
        return np.empty((0, model.config.d_model), dtype="float32")
    return np.vstack(vectors).astype("float32")


In [14]:
# --- Helper: Token-aware chunking for long texts ---
def chunk_text(text, max_tokens=512, joiner="\n", count_tokens=None):
    """Split text into chunks that each fit within ``max_tokens`` tokens.

    Blank lines are dropped and remaining lines are stripped. Consecutive lines
    are packed greedily (joined with ``joiner``) while the joined text stays
    within the budget. A single line that exceeds the budget on its own is
    hard-wrapped into consecutive pieces sized by token count, so no piece is
    truncated later at embedding time.

    Args:
        text: The raw text to split.
        max_tokens: Token budget per chunk.
        joiner: String placed between lines packed into the same chunk.
        count_tokens: Optional callable mapping a string to its token count.
            Defaults to the length of ``tokenizer(s)["input_ids"]`` using the
            module-level tokenizer.

    Returns:
        List of chunk strings in document order.
    """
    count = count_tokens or (lambda s: len(tokenizer(s)["input_ids"]))

    def _hard_wrap(line):
        """Cut one over-long line into the longest consecutive pieces that fit."""
        pieces = []
        start = 0
        while start < len(line):
            # Binary-search the longest slice from `start` that fits the budget.
            # The lower bound of 1 guarantees progress even if a single
            # character alone exceeds max_tokens.
            lo, hi = 1, len(line) - start
            while lo < hi:
                mid = (lo + hi + 1) // 2
                if count(line[start:start + mid]) <= max_tokens:
                    lo = mid
                else:
                    hi = mid - 1
            pieces.append(line[start:start + lo])
            start += lo
        return pieces

    lines = [ln.strip() for ln in text.splitlines() if ln.strip()]
    chunks = []
    buf = []
    for ln in lines:
        candidate = joiner.join(buf + [ln])
        if count(candidate) <= max_tokens:
            buf.append(ln)
            continue
        # The line does not fit alongside the buffer: flush what we have.
        if buf:
            chunks.append(joiner.join(buf))
            buf = []
        if count(ln) > max_tokens:
            # The previous fallback cut every 1000 characters, which exceeds
            # the token budget; wrap by token count instead.
            chunks.extend(_hard_wrap(ln))
        else:
            buf = [ln]
    if buf:
        chunks.append(joiner.join(buf))
    return chunks


In [16]:
# Show the working directory; the indexing cell below reads its TXT files from this folder.
!pwd #use the pwd as the input folder

/content


In [17]:
# List the working directory to confirm which TXT files are available for indexing.
!ls

sa_abhidharmasamuccayabhASya.txt  sa_vasubandhu-paJcaskandhaprakaraNa.txt
sample_data			  sa_vasubandhu-triMzikAvijJaptikArikA-comm.txt


In [18]:
# --- Collect chunks and metadata from all TXT files (input for the FAISS index) ---
input_folder = "/content/"
out_index_path = "sanskrit.faiss"
out_meta_path = "sanskrit_metadata.jsonl"

# `texts` and `metadata` are filled in lockstep: entry i of one describes entry i of the other.
texts = []
metadata = []

txt_files = sorted(Path(input_folder).glob("*.txt"))

for txt_path in tqdm(txt_files, desc="Processing TXT files"):
    raw = txt_path.read_text(encoding="utf-8", errors="ignore").strip()
    if raw:
        # chunk_id restarts at 0 for every file.
        for chunk_no, piece in enumerate(chunk_text(raw, max_tokens=512)):
            texts.append(piece)
            metadata.append(
                {"source_file": txt_path.name, "chunk_id": chunk_no, "text": piece}
            )

print(f"Total chunks: {len(texts)}")




Processing TXT files:   0%|          | 0/3 [00:00<?, ?it/s][A[A

Processing TXT files:  33%|███▎      | 1/3 [00:01<00:02,  1.33s/it][A[A

Processing TXT files:  67%|██████▋   | 2/3 [00:05<00:03,  3.08s/it][A[A

Processing TXT files: 100%|██████████| 3/3 [00:05<00:00,  1.97s/it]

Total chunks: 761





In [19]:
# Embed every collected chunk; row i of `embeddings` corresponds to texts[i]
# (and therefore to metadata[i], which was appended in the same order).
embeddings = embed_texts(texts, max_length=512)
print("Embeddings shape:", embeddings.shape)


Embeddings shape: (761, 1536)


In [20]:
# Flat inner-product index: the stored vectors are unit length, so the inner
# product is the cosine similarity.
index = faiss.IndexFlatIP(embeddings.shape[1])
index.add(embeddings)
print("Index built with", index.ntotal, "vectors")

# Persist the index, then the metadata as one JSON object per line (JSONL),
# written in the same order as the vectors were added.
faiss.write_index(index, out_index_path)
with open(out_meta_path, "w", encoding="utf-8") as meta_file:
    meta_file.writelines(
        json.dumps(record, ensure_ascii=False) + "\n" for record in metadata
    )

print(f"FAISS index saved to {out_index_path}")
print(f"Metadata saved to {out_meta_path}")


Index built with 761 vectors
FAISS index saved to sanskrit.faiss
Metadata saved to sanskrit_metadata.jsonl


In [21]:
# Reload the persisted index and its metadata; meta[i] describes vector i of the index.
index = faiss.read_index(out_index_path)
# Use a context manager so the metadata file handle is closed deterministically,
# and skip blank lines so a trailing newline cannot break JSON parsing.
with open(out_meta_path, encoding="utf-8") as meta_file:
    meta = [json.loads(line) for line in meta_file if line.strip()]

@torch.no_grad()
def embed_query(query, max_length=512):
    """Embed a single query string exactly the way the indexed chunks were embedded.

    Delegates to ``embed_texts`` so query-time and index-time embeddings share
    one pipeline (tokenize, encode, mean-pool, L2-normalize) and cannot drift
    apart.

    Args:
        query: The query string.
        max_length: Maximum number of tokens kept; longer queries are truncated.

    Returns:
        A float32 numpy array of shape (1, dim), ready for ``index.search``.
    """
    return embed_texts([query], max_length=max_length).reshape(1, -1)

def search(query, k=5):
    """Return the top-k chunks most similar to ``query``.

    Args:
        query: The query string; embedded with ``embed_query``.
        k: Maximum number of hits to return. Capped at the number of vectors
            stored in the index.

    Returns:
        A list of ``(rank, score, source_file, chunk_id, text)`` tuples ordered
        by FAISS result order, with ``rank`` starting at 1. The list is empty
        when the index holds no vectors or ``k`` is not positive.
    """
    # Asking FAISS for more neighbours than it stores yields padded results
    # with id -1; cap k up front and bail out when nothing can be returned.
    k = min(k, index.ntotal)
    if k <= 0:
        return []

    qvec = embed_query(query)
    sims, ids = index.search(qvec, k)
    results = []
    for idx, score in zip(ids[0], sims[0]):
        # -1 marks "no result"; meta[-1] would silently return the last chunk.
        if idx < 0:
            continue
        m = meta[idx]
        results.append((len(results) + 1, score, m["source_file"], m["chunk_id"], m["text"]))
    return results


In [22]:
# Example Sanskrit query; retrieve the three nearest chunks.
query_text = "candramāḥ rāhuḥ buddha gāthā"
results = search(query_text, k=3)

# Each hit is (rank, score, source file, chunk id, chunk text).
for hit in results:
    rank, score, fname, cid, txt = hit
    header = f"\n#{rank} | score={score:.4f} | file={fname} | chunk={cid}"
    print(header)
    print("-" * 80)
    preview = txt[:500]  # show first 500 characters
    print(preview)



#1 | score=0.7556 | file=sa_vasubandhu-triMzikAvijJaptikArikA-comm.txt | chunk=112
--------------------------------------------------------------------------------
atra gāthā |
ādānavijñāna gabhīrasūkṣmo ogho yathā vartati sarvabījo | bālāna eṣo mayi na prakāśi mā haiva ātmā parikalpayeyur

#2 | score=0.7305 | file=sa_vasubandhu-paJcaskandhaprakaraNa.txt | chunk=40
--------------------------------------------------------------------------------
kati sabhāgāḥ / ādhyātmikāḥ pañca rūpiṇaḥ / svavijñānasahitaviṣayasadṛśatāmupādāya // 166 //
kati tatsabhāgāḥ / ta eva svavijñānavirahitasvānvayasadṛśatāmupādāya // 167 //
dhātūddyeśanirdeśakastṛtīyādhikaraḥ pariniṣṭitaḥ /
ācāryavasubandhuviracit pañcaskandhaprakaraṇaṃ samāptam / pañcaskandhaprakaraṇaṃ śāstriṇā śāntibhikṣuṇā /
bhoṭānuvādamāgamya saṃskṛte punaruddhṛtam

#3 | score=0.7123 | file=sa_abhidharmasamuccayabhASya.txt | chunk=520
--------------------------------------------------------------------------------
[vādādhikaraṇam] atra vādaḥ