In [1]:
import re
import torch
import faiss
import sentencepiece as spm
import fitz  # PyMuPDF
from decoder_only_gpt import My_GPT_model

In [2]:
# Load tokenizer
sp = spm.SentencePieceProcessor()
sp.load("hindi_tokenizer_new.model") 

True

In [3]:
# Load pretrained model
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [4]:
model = My_GPT_model(vocab_size=sp.get_piece_size(), num_layers=12,
    d_model=512, d_ff=2048, num_heads=8,seq_len=512
).to(DEVICE)

In [5]:
ckpt = torch.load("checkpoints_HindiGPT-v1_step280000.pt", map_location=DEVICE)
state_dict = ckpt["model"]

In [6]:
clean_state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}

In [7]:
model.load_state_dict(clean_state_dict, strict=False)
model.eval()

My_GPT_model(
  (decoder): Decoder(
    (embedding): Embedding(32768, 512)
    (layers): ModuleList(
      (0-11): 12 x Decoder_GPT_Block(
        (swi_glu): SwiGLU_FFN(
          (w1): Linear(in_features=512, out_features=1536, bias=False)
          (w2): Linear(in_features=512, out_features=1536, bias=False)
          (w3): Linear(in_features=1536, out_features=512, bias=False)
          (act): SiLU()
        )
        (masked_mha): Masked_MHA(
          (Q): Linear(in_features=512, out_features=512, bias=True)
          (K): Linear(in_features=512, out_features=512, bias=True)
          (V): Linear(in_features=512, out_features=512, bias=True)
          (fc_out): Linear(in_features=512, out_features=512, bias=True)
        )
        (rms_norm0): RMSNorm()
        (rms_norm1): RMSNorm()
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
    (norm): RMSNorm()
  )
  (lm_head): Linear(in_features=512, out_features=32768, bias=False)
)

In [8]:
from sentence_transformers import SentenceTransformer

embed_model = SentenceTransformer(
    "intfloat/multilingual-e5-base")

In [9]:
from sentence_transformers import CrossEncoder
reranker = CrossEncoder("cross-encoder/mmarco-mMiniLMv2-L12-H384-v1")

In [10]:
def build_faiss_index(chunks):
    embeddings = embed_model.encode(
        ["passage: " + t for t in chunks],
        normalize_embeddings=True,
        show_progress_bar=True
    )
    dimension = embeddings.shape[1]
    index = faiss.IndexFlatIP(dimension)
    index.add(embeddings)
    return index

In [11]:
def retrieve_context(query, chunks, index, top_k=10, rerank_top=3, min_score=0.5):
    query_emb = embed_model.encode(["query: " + query], normalize_embeddings=True)
    scores, idxs = index.search(query_emb, top_k)
    candidate_texts = []
    for i, score in zip(idxs[0], scores[0]):
        if score < min_score:
            continue
        candidate_texts.append(chunks[i])
    
    if not candidate_texts:
        return []
    
    # Rerank top candidates
    reranked = rerank(query, candidate_texts, top_k=rerank_top)
    context_ids = []
    for t in reranked:
        context_ids += sp.encode(f"\n[संदर्भ]\n{t}\n")
    
    # Limit context to ~400 tokens to fit model
    return context_ids[:400]

In [12]:
def rerank(query, candidate_texts, top_k=3):
    pairs = [(query, t) for t in candidate_texts]
    scores = reranker.predict(pairs)
    ranked = sorted(zip(candidate_texts, scores), key=lambda x: x[1], reverse=True)
    return [x[0] for x in ranked[:top_k]]

In [15]:
EOS_ID = sp.eos_id()
@torch.no_grad()
def generate_answer_from_ids(prompt_ids, max_new_tokens=300, temperature=0.7, top_p=0.95, repetition_penalty=1.2):
    input_ids = torch.tensor([prompt_ids], dtype=torch.long).to(DEVICE)
    generated = input_ids.clone()  # copy for tracking
    
    for _ in range(max_new_tokens):
        if generated.shape[1] >= 512:
            generated = generated[:, -512:]
            
        logits = model(generated)
        next_logits = logits[:, -1, :]
        
        # Repetition penalty
        if repetition_penalty > 1.0:
            recent = generated[0, -128:].tolist()
            for tid in set(recent):
                next_logits[0, tid] /= repetition_penalty
        
        # Temperature + top-p
        next_logits = next_logits / temperature
        probs = torch.softmax(next_logits, dim=-1)
        
        # Nucleus sampling
        sorted_probs, sorted_idx = torch.sort(probs, descending=True)
        cum_probs = torch.cumsum(sorted_probs, dim=-1)
        mask = cum_probs > top_p
        mask[..., 1:] = mask[..., :-1].clone()
        mask[..., 0] = False
        sorted_probs[mask] = 0.0
        if sorted_probs.sum() > 0:
            sorted_probs /= sorted_probs.sum()
        
        next_token = torch.multinomial(sorted_probs, num_samples=1)
        next_token = sorted_idx.gather(-1, next_token)
        
        generated = torch.cat([generated, next_token], dim=1)
        
        if next_token.item() == sp.eos_id():
            break
    
    # Sirf NEW generated tokens decode karo (prompt ke baad se)
    prompt_len = len(prompt_ids)
    answer_ids = generated[0, prompt_len:].tolist()
    return sp.decode(answer_ids).strip()

In [16]:
# build_prompt_ids ko simple bana
def build_prompt_ids(context_ids, question):
    instruction = "संदर्भ के आधार पर हिंदी में सटीक उत्तर दो।"
    prompt_ids = (
        sp.encode(instruction + "\n\nसंदर्भ:\n") +
        context_ids[:250] +  # Hard limit 250 tokens context
        sp.encode("\nप्रश्न: " + question + "\nउत्तर:\n")
    )
    return prompt_ids  # No BOS if causing issues, ya sp.bos_id() last mein try

# Generate settings
answer = generate_answer_from_ids(
    prompt_ids,
    max_new_tokens=250,
    temperature=0.9,
    top_p=0.95,
    repetition_penalty=1.25
)

NameError: name 'prompt_ids' is not defined