In [1]:
import re
import torch
import faiss
import sentencepiece as spm
import fitz  # PyMuPDF
from decoder_only_gpt import My_GPT_model

In [2]:
# Load tokenizer
sp = spm.SentencePieceProcessor()
sp.load("hindi_tokenizer_new.model")

True

In [3]:
# Load pretrained model
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [4]:
model = My_GPT_model(vocab_size=sp.get_piece_size(), num_layers=12,
    d_model=512, d_ff=2048, num_heads=8,seq_len=512
).to(DEVICE)

In [5]:
ckpt = torch.load("checkpoints_HindiGPT-v1_step280000.pt", map_location=DEVICE)
state_dict = ckpt["model"]

In [6]:
clean_state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}

In [7]:
model.load_state_dict(clean_state_dict, strict=False)
model.eval()

My_GPT_model(
  (decoder): Decoder(
    (embedding): Embedding(32768, 512)
    (layers): ModuleList(
      (0-11): 12 x Decoder_GPT_Block(
        (swi_glu): SwiGLU_FFN(
          (w1): Linear(in_features=512, out_features=1536, bias=False)
          (w2): Linear(in_features=512, out_features=1536, bias=False)
          (w3): Linear(in_features=1536, out_features=512, bias=False)
          (act): SiLU()
        )
        (masked_mha): Masked_MHA(
          (Q): Linear(in_features=512, out_features=512, bias=True)
          (K): Linear(in_features=512, out_features=512, bias=True)
          (V): Linear(in_features=512, out_features=512, bias=True)
          (fc_out): Linear(in_features=512, out_features=512, bias=True)
        )
        (rms_norm0): RMSNorm()
        (rms_norm1): RMSNorm()
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
    (norm): RMSNorm()
  )
  (lm_head): Linear(in_features=512, out_features=32768, bias=False)
)

In [8]:
from sentence_transformers import SentenceTransformer

embed_model = SentenceTransformer(
    "intfloat/multilingual-e5-base")

In [9]:
from sentence_transformers import CrossEncoder
reranker = CrossEncoder("cross-encoder/mmarco-mMiniLMv2-L12-H384-v1")

In [10]:
print("Vocab size from tokenizer:", sp.get_piece_size())
print("Model vocab size:", model.lm_head.out_features)  # should match

Vocab size from tokenizer: 32768
Model vocab size: 32768


In [11]:
import re

def clean_hindi_text(text: str) -> str:
    """
    Clean Hindi text extracted from PDF.
    Handles OCR errors, extra spaces, matra spacing issues, and unwanted characters.
    """
    if not text:
        return ""

    # Step 1: Remove common OCR garbage characters
    text = re.sub(r'[⁇�•◦◇◆■□▪▫]', '', text)

    # Step 2: Normalize multiple spaces, tabs, newlines → single space
    text = re.sub(r'\s+', ' ', text)

    # Step 3: Fix matra + extra space issue (very common in pypdf/fitz extractions)
    matra_fixes = {
        ' ो': 'ो', ' ौ': 'ौ', ' ा': 'ा', ' ी': 'ी', ' ू': 'ू',
        ' ु': 'ु', ' ू': 'ू', ' ृ': 'ृ', ' े': 'े', ' ै': 'ै',
        ' ं': 'ं', ' ः': 'ः', ' ँ': 'ँ'
    }
    for wrong, correct in matra_fixes.items():
        text = text.replace(wrong, correct)

    # Step 4: Remove all characters except:
    # - Devanagari (हिंदी)
    # - English a-z A-Z
    # - Digits 0-9
    # - Basic punctuation: । । ? ! , ; : - " ' ( )
    # - Spaces
    allowed = r'[^\u0900-\u097F\u0041-\u005A\u0061-\u007A0-9\s।?!,;:—\-\(\)"\'।]'
    text = re.sub(allowed, ' ', text)

    # Step 5: Again normalize spaces (in case previous steps created extra)
    text = re.sub(r'\s+', ' ', text)

    # Step 6: Optional - Fix common broken words (like half conjuncts)
    # Example: "र ा" → "रा" (rare but useful)
    text = re.sub(r'([क-ह])\s+([ा-ौृेैंीू])', r'\1\2', text)

    return text.strip()

In [12]:
import re
from typing import List

def sentence_chunk_hindi(
    text: str,
    max_tokens: int = 200,
    overlap_sentences: int = 2,
    min_chunk_tokens: int = 30  # optional: bahut chhote chunks avoid karne ke liye
) -> List[str]:
    """
    Hindi text ko meaningful sentence-level chunks mein baantta hai.
    Overlap sentences ke hisaab se rakhta hai.
    """
    if not text.strip():
        return []

    # Step 1: Text ko normalize karo (multiple spaces/newlines)
    text = re.sub(r'\s+', ' ', text.strip())

    # Step 2: Sentences split karo (। और \n दोनों pe, lekin separator ko preserve karne ki koshish)
    # Better way: । को sentence end maante hue split
    sentences = re.split(r'(?<=[।\n])\s*', text)
    
    # Agar split ke baad empty strings aaye toh remove kar do
    sentences = [s.strip() for s in sentences if s.strip()]

    chunks = []
    current = []

    for sent in sentences:
        current.append(sent)
        
        # Current chunk ka token count check karo
        current_text = " ".join(current)  # ya '' .join bhi try kar sakte ho agar space nahi chahiye
        token_count = len(sp.encode(current_text))
        
        if token_count >= max_tokens:
            chunk_text = " ".join(current)
            if len(sp.encode(chunk_text)) >= min_chunk_tokens:
                chunks.append(chunk_text)
            # Overlap rakho
            current = current[-overlap_sentences:]
    
    # Last remaining chunk add karo (agar meaningful hai)
    if current:
        last_chunk = " ".join(current)
        if len(sp.encode(last_chunk)) >= min_chunk_tokens:
            chunks.append(last_chunk)

    return chunks

In [13]:
def build_faiss_index(chunks):
    embeddings = embed_model.encode(
        ["passage: " + t for t in chunks],
        normalize_embeddings=True,
        show_progress_bar=True
    )
    dimension = embeddings.shape[1]
    index = faiss.IndexFlatIP(dimension)
    index.add(embeddings)
    return index

In [14]:
def retrieve_context(
    query: str,
    chunks: list[str],
    index: faiss.Index,
    top_k: int = 10,
    rerank_top: int = 3,
    min_score: float = 0.5,
    max_context_tokens: int = 200,  # thoda margin rakho instruction ke liye
    debug: bool = False
) -> list[int]:
    """
    Retrieves and reranks relevant context chunks.
    Returns tokenized context IDs, limited to max_context_tokens.
    """
    if not chunks or not index:
        return []

    try:
        # Embed query (batch size 1, lekin list mein daalna zaroori hai encode ke liye)
        query_emb = embed_model.encode(
            ["query: " + query],
            normalize_embeddings=True,
            convert_to_numpy=True
        )
    except Exception as e:
        print(f"Query embedding failed: {e}")
        return []

    # Search
    scores, idxs = index.search(query_emb, min(top_k, len(chunks)))

    candidate_texts = []
    candidate_scores = []  # debug ke liye save kar rahe hain

    for idx, score in zip(idxs[0], scores[0]):
        if idx == -1:  # FAISS invalid index
            continue
        if score < min_score:
            continue
        candidate_texts.append(chunks[idx])
        candidate_scores.append(score)

    if not candidate_texts:
        if debug:
            print("No candidates above min_score threshold")
        return []

    if debug:
        print(f"Initial candidates: {len(candidate_texts)} (top scores: {candidate_scores[:3]})")

    # Rerank
    try:
        reranked_texts = rerank(query, candidate_texts, top_k=rerank_top)
    except Exception as e:
        print(f"Reranking failed: {e}")
        reranked_texts = candidate_texts[:rerank_top]  # fallback

    # Build context with proper separators
    context_ids = []
    for text in reranked_texts:
        section = f"\n[संदर्भ]\n{text.strip()}\n"
        context_ids.extend(sp.encode(section))

    # Smart truncation: total tokens limit
    if len(context_ids) > max_context_tokens:
        context_ids = context_ids[:max_context_tokens]
        if debug:
            print(f"Context truncated from {len(context_ids)} to {max_context_tokens} tokens")

    return context_ids

In [15]:
from typing import List, Optional

def rerank(
    query: str,
    candidate_texts: List[str],
    top_k: int = 3,
    fallback: bool = True
) -> List[str]:
    """
    Reranks candidate texts using cross-encoder for better relevance.
    
    Args:
        query: The search query
        candidate_texts: List of candidate chunk texts
        top_k: How many top results to return
        fallback: If reranking fails, return top candidates as-is
    
    Returns:
        List of reranked texts (up to top_k)
    """
    if not candidate_texts:
        return []

    if top_k <= 0:
        return []

    # Agar candidates bahut kam hain toh jitne hain utne hi return
    top_k = min(top_k, len(candidate_texts))

    try:
        # Pairs banana
        pairs = [(query, text) for text in candidate_texts]

        # Cross-encoder predict (returns numpy array of scores)
        scores = reranker.predict(pairs)

        # Zip aur sort by score descending
        ranked_pairs = sorted(
            zip(candidate_texts, scores),
            key=lambda x: x[1],
            reverse=True
        )

        # Sirf texts return karo
        reranked_texts = [text for text, score in ranked_pairs[:top_k]]

        return reranked_texts

    except Exception as e:
        print(f"Reranking failed: {e}")
        if fallback:
            # Fallback: original order mein top_k return kar do
            print("Falling back to initial retrieval order")
            return candidate_texts[:top_k]
        else:
            return []

In [16]:
EOS_ID = sp.eos_id()

@torch.no_grad()
def generate_answer_from_ids(
    prompt_ids,
    max_new_tokens: int = 300,
    temperature: float = 0.6,
    top_p: float = 0.85,
    repetition_penalty: float = 1.4,
    penalty_window: int = 64,          # parameter bana diya flexible
    max_total_length: int = 512
) -> str:
    """
    Generates answer using the pretrained model with nucleus sampling.
    Returns only the newly generated text (after prompt).
    """
    if not prompt_ids:
        return ""

    # Convert to tensor and ensure batch dim
    input_ids = torch.tensor([prompt_ids], dtype=torch.long, device=DEVICE)
    generated = input_ids.clone()  # copy rakho

    for _ in range(max_new_tokens):
        current_len = generated.shape[1]

        # Sliding window if needed
        if current_len >= max_total_length:
            generated = generated[:, -max_total_length:]

        logits = model(generated)
        next_logits = logits[:, -1, :]

        # Apply repetition penalty
        if repetition_penalty > 1.0:
            recent_tokens = generated[0, -min(penalty_window, current_len):].tolist()
            for token_id in set(recent_tokens):
                if token_id < next_logits.shape[-1]:  # safety check
                    next_logits[0, token_id] /= repetition_penalty

        # Temperature scaling
        if temperature != 1.0:
            next_logits = next_logits / temperature

        # Softmax + top-p (nucleus) filtering
        probs = torch.softmax(next_logits, dim=-1)
        sorted_probs, sorted_indices = torch.sort(probs, descending=True)
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

        # Mask tokens after top_p threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False
        sorted_probs[sorted_indices_to_remove] = 0.0

        # Renormalize (safe check)
        if sorted_probs.sum() == 0:
            # Rare case: all probs zero → uniform fallback ya greedy
            next_token = torch.argmax(probs, dim=-1, keepdim=True)
        else:
            sorted_probs /= sorted_probs.sum(dim=-1, keepdim=True)
            next_token = torch.multinomial(sorted_probs, num_samples=1)
            next_token = sorted_indices.gather(-1, next_token)

        # Append new token
        generated = torch.cat([generated, next_token], dim=1)

        # Early stop on EOS
        if next_token.item() == EOS_ID:
            break

    # Only decode newly generated tokens
    prompt_len = len(prompt_ids)
    answer_ids = generated[0, prompt_len:].tolist()

    try:
        decoded = sp.decode(answer_ids).strip()
    except Exception as e:
        print(f"Decode error: {e}")
        decoded = ""  # fallback

    return decoded

In [17]:
import re

def post_process_hindi(text):
    if not text:
        return ""
    # Extra/multiple spaces ko single karo
    text = re.sub(r'\s+', ' ', text)
    # Punctuation ke baad space daalo agar nahi hai
    text = re.sub(r'([।?!])([^\s])', r'\1 \2', text)
    # हर । के बाद space ensure karo
    text = text.replace("।", "। ")
    # Multiple spaces after । ko clean karo
    text = re.sub(r'।\s+', '। ', text)
    # Starting/ending extra spaces hatao
    return text.strip()

In [18]:
def build_prompt_ids(context_ids, question):
    instruction = "संदर्भ के आधार पर हिंदी में सटीक उत्तर दो।"
    limited_context = context_ids[:max_context_tokens]

    # Sirf ek tarika use karo – yeh best hai
    prompt_parts = [
        sp.encode(instruction.strip()),
        sp.encode("\nसंदर्भ:\n"),
        limited_context,
        sp.encode(f"\nप्रश्न: {question}\nउत्तर:\n")
    ]

    prompt_ids = []
    for part in prompt_parts:
        prompt_ids.extend(part)

    # Optional BOS (training ke hisaab se decide karo)
    if add_bos and sp.bos_id() >= 0:
        prompt_ids.insert(0, sp.bos_id())

    # Safety truncation
    if len(prompt_ids) > 512:
        prompt_ids = prompt_ids[-512:]
        print(f"Warning: Prompt truncated to 512 tokens")

    return prompt_ids

In [19]:
import re
import fitz  # PyMuPDF

def rag_pipeline(pdf_path, question):
    """
    Complete RAG pipeline: PDF load → clean → chunk → index → retrieve → generate answer
    """
    try:
        # 1. Load and extract text from PDF
        doc = fitz.open(pdf_path)
        text = ""
        for page in doc:
            text += page.get_text()
        doc.close()

        # 2. Clean the extracted text
        pdf_text = clean_hindi_text(text)
        pdf_text = post_process_hindi(pdf_text)  # extra polish

        # 3. Chunk the text
        chunks = sentence_chunk_hindi(pdf_text, max_tokens=500, overlap_sentences=3)

        if not chunks:
            return "PDF में कोई उपयोगी सामग्री नहीं मिली।"

        # 4. Build FAISS index
        index = build_faiss_index(chunks)

#         def retrieve_context(
#     query: str,
#     chunks: list[str],
#     index: faiss.Index,
#     top_k: int = 10,
#     rerank_top: int = 3,
#     min_score: float = 0.5,
#     max_context_tokens: int = 200,  # thoda margin rakho instruction ke liye
#     debug: bool = False
# )

        # 5. Retrieve relevant context
        context_ids = retrieve_context(question, chunks, index, top_k = 10, max_context_tokens = 200,min_score = 0.5)

        if not context_ids:
            return "संदर्भ में संबंधित जानकारी उपलब्ध नहीं है।"

        # 6. Build prompt
        prompt_ids = build_prompt_ids(context_ids, question)

        print("Prompt length:", len(prompt_ids))
        print("Decoded prompt start:", sp.decode(prompt_ids))

        # 7. Generate answer
        # Generate settings
        answer = generate_answer_from_ids(
            prompt_ids,
            max_new_tokens=250,
            temperature=0.9,
            top_p=0.95,
            repetition_penalty=1.25,
            penalty_window=64
        )
        # 8. Post-process the answer
        final_answer = post_process_hindi(answer)

        return final_answer

    except Exception as e:
        return f"Error in pipeline: {str(e)}"

In [20]:
# Example usage (clean aur simple)
if __name__ == "__main__":
    pdf_path = "gandhi.pdf"  # yeh file sahi se path daal dena
    question = "उनके पिता का नाम क्या था और वे क्या काम करते थे?"

    print("Question:", question)
    print("-" * 60)

    answer = rag_pipeline(pdf_path, question)
    print("Final Answer:")
    print(answer)

Question: उनके पिता का नाम क्या था और वे क्या काम करते थे?
------------------------------------------------------------


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Final Answer:
Error in pipeline: name 'max_context_tokens' is not defined


In [27]:
# Manual LoRA without PEFT 
class LoRALayer(torch.nn.Module):
    def __init__(self, in_features, out_features, r=16, alpha=32, dropout=0.05):
        super().__init__()
        self.lora_A = torch.nn.Linear(in_features, r, bias=False)
        self.lora_B = torch.nn.Linear(r, out_features, bias=False)
        self.scaling = alpha / r
        self.dropout = torch.nn.Dropout(dropout)
        
    def forward(self, x):
        return self.lora_B(self.dropout(self.lora_A(x))) * self.scaling

# LoRA add function
def add_lora_to_model(model, r=16, alpha=32,  dropout=0.05):
    lora_modules = {}
    
    # Target layers pe LoRA add kar
    for name, module in model.named_modules():
        if "masked_mha.Q" in name or "masked_mha.V" in name or "swi_glu.w1" in name:
            in_f, out_f = module.weight.shape
            lora = LoRALayer(in_f, out_f, r, alpha)
            lora.to(DEVICE)
            lora_modules[name] = lora
            # Original weights freeze kar
            module.weight.requires_grad = False
    
    return lora_modules

# Usage
lora_layers = add_lora_to_model(model)
print(f"Added {len(lora_layers)} LoRA layers")

Added 36 LoRA layers


In [31]:
# sft_rag_train.py
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from peft import LoraConfig, get_peft_model, TaskType
from torch.optim import AdamW
import json
import sentencepiece as spm
from decoder_only_gpt import My_GPT_model  # tera model file

# Device
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Step 1: Load tokenizer
sp = spm.SentencePieceProcessor()
sp.load("hindi_tokenizer_new.model")
print("Tokenizer loaded. Vocab size:", sp.get_piece_size())

# Step 2: Load your pretrained model + checkpoint
print("Loading pretrained model...")
model = My_GPT_model(
    vocab_size=sp.get_piece_size(),
    num_layers=12,
    d_model=512,
    d_ff=2048,
    num_heads=8,
    seq_len=512
).to(DEVICE)

# Load checkpoint (tera original weights)
ckpt_path = "checkpoints_HindiGPT-v1_step280000.pt"
ckpt = torch.load(ckpt_path, map_location=DEVICE)
state_dict = ckpt["model"]

# Clean _orig_mod. prefix if present
clean_state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}

# Load weights
missing, unexpected = model.load_state_dict(clean_state_dict, strict=False)
print("Missing keys:", missing[:5] if missing else "None")
print("Unexpected keys:", unexpected[:5] if unexpected else "None")

# model = torch.compile(model)  # speed boost


# Step 3: Add LoRA
lora_modules = add_lora_to_model(model, r=16, alpha=32, dropout=0.05)


# Step 4: Dataset (RAG-style)
class HindiRAGDataset(Dataset):
    def __init__(self, jsonl_file, tokenizer, max_len=512):
        self.data = []
        with open(jsonl_file, "r", encoding="utf-8") as f:
            for line in f:
                self.data.append(json.loads(line)["text"])
        self.tokenizer = tokenizer
        self.max_len = max_len
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        text = self.data[idx]
        ids = self.tokenizer.encode(text)
        if len(ids) > self.max_len:
            ids = ids[:self.max_len]
        attention_mask = [1] * len(ids)
        return {
            "input_ids": torch.tensor(ids),
            "attention_mask": torch.tensor(attention_mask)
        }

# Dataset aur DataLoader
dataset = HindiRAGDataset("alpaca_hindi_sft_clean.jsonl", sp)  # tera jsonl file
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)  # small batch RTX 4050 ke liye safe

# Optimizer - ONLY LoRA parameters train karo
trainable_params = []
for lora in lora_modules.values():
    trainable_params.extend(lora.parameters())
optimizer = AdamW(trainable_params, lr=5e-5)

# Training loop
model.train()
for epoch in range(3):
    total_loss = 0
    for batch_idx, batch in enumerate(dataloader):
        input_ids = batch["input_ids"].to(DEVICE)
        
        logits = model(input_ids)
        labels = input_ids.clone()
        
        # Causal LM loss
        loss = F.cross_entropy(
            logits.view(-1, logits.size(-1)),
            labels.view(-1),
            ignore_index=-100  # padding ignore
        )
        
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        
        total_loss += loss.item()
        
        if batch_idx % 50 == 0:
            print(f"Epoch {epoch+1} | Step {batch_idx} | Loss: {loss.item():.4f}")
    
    avg_loss = total_loss / len(dataloader)
    print(f"Epoch {epoch+1} completed | Avg Loss: {avg_loss:.4f}")
    
    # Save LoRA weights manually
    torch.save(
        {name: lora.state_dict() for name, lora in lora_modules.items()},
        f"hindi_rag_manual_lora_epoch{epoch+1}.pt"
    )

print("Training complete! LoRA adapters saved.")

Tokenizer loaded. Vocab size: 32768
Loading pretrained model...
Missing keys: None
Unexpected keys: None
Epoch 1 | Step 0 | Loss: 9.0468
Epoch 1 | Step 50 | Loss: 8.9136
Epoch 1 | Step 100 | Loss: 8.7121
Epoch 1 | Step 150 | Loss: 8.8634
Epoch 1 | Step 200 | Loss: 10.0868
Epoch 1 | Step 250 | Loss: 8.5730
Epoch 1 | Step 300 | Loss: 8.6607
Epoch 1 | Step 350 | Loss: 9.0976
Epoch 1 | Step 400 | Loss: 9.0473
Epoch 1 | Step 450 | Loss: 8.9530
Epoch 1 | Step 500 | Loss: 8.9131
Epoch 1 | Step 550 | Loss: 8.7601
Epoch 1 | Step 600 | Loss: 8.8973
Epoch 1 | Step 650 | Loss: 8.6226
Epoch 1 | Step 700 | Loss: 9.3095
Epoch 1 | Step 750 | Loss: 9.3510
Epoch 1 | Step 800 | Loss: 8.7515
Epoch 1 | Step 850 | Loss: 9.0279
Epoch 1 | Step 900 | Loss: 8.7280
Epoch 1 | Step 950 | Loss: 8.9465
Epoch 1 | Step 1000 | Loss: 8.2157
Epoch 1 | Step 1050 | Loss: 8.8646
Epoch 1 | Step 1100 | Loss: 9.1064
Epoch 1 | Step 1150 | Loss: 8.8043
Epoch 1 | Step 1200 | Loss: 8.7127
Epoch 1 | Step 1250 | Loss: 8.7338
Epoch 1

In [32]:
# Training ke end mein yeh line add kar (ya abhi run kar de)
torch.save(
    {name: lora.state_dict() for name, lora in lora_modules.items()},
    "final_lora_weights.pt"
)
print("LoRA weights saved as: final_lora_weights.pt")

LoRA weights saved as: final_lora_weights.pt


In [33]:
torch.save(model.state_dict(), "full_model_with_lora.pt")
print("Full model (base + LoRA) saved as: full_model_with_lora.pt")

Full model (base + LoRA) saved as: full_model_with_lora.pt


In [34]:
torch.save(model.state_dict(), "backup_epoch3_last.pt")

In [35]:
# test_lora.py
import torch
from decoder_only_gpt import My_GPT_model
import sentencepiece as spm

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Load tokenizer
sp = spm.SentencePieceProcessor()
sp.load("hindi_tokenizer_new.model")

# Load base model + checkpoint
model = My_GPT_model(
    vocab_size=sp.get_piece_size(),
    num_layers=12,
    d_model=512,
    d_ff=2048,
    num_heads=8,
    seq_len=512
).to(DEVICE)

ckpt = torch.load("checkpoints_HindiGPT-v1_step280000.pt", map_location=DEVICE)
clean_state_dict = {k.replace("_orig_mod.", ""): v for k, v in ckpt["model"].items()}
model.load_state_dict(clean_state_dict, strict=False)

# Load saved LoRA weights
lora_state = torch.load("final_lora_weights.pt", map_location=DEVICE)

# Add LoRA layers (same function jo training mein use kiya tha)
lora_modules = add_lora_to_model(model)  # tera add_lora_to_model function use kar

for name, lora in lora_modules.items():
    if name in lora_state:
        lora.load_state_dict(lora_state[name])
    else:
        print(f"Warning: LoRA for {name} not found in saved file")

model.eval()
print("Model with LoRA loaded successfully!")

# Ab tera normal RAG pipeline chalaa
# Ya simple test kar
test_prompt = sp.encode("संदर्भ: उनके पिता करमचंद गांधी पोरबंदर राज्य के दीवान थे।\nप्रश्न: उनके पिता का नाम क्या था?\nउत्तर:")
test_ids = torch.tensor([test_prompt], device=DEVICE)

with torch.no_grad():
    logits = model(test_ids)
    next_token = torch.argmax(logits[:, -1, :], dim=-1)
    print("Sample next token:", next_token.item())

Model with LoRA loaded successfully!
Sample next token: 398


In [36]:
test_prompt = """
संदर्भ: उनके पिता करमचंद गांधी पोरबंदर राज्य के दीवान थे।
प्रश्न: उनके पिता का नाम क्या था और वे क्या काम करते थे?
उत्तर:
"""

prompt_ids = sp.encode(test_prompt)
input_ids = torch.tensor([prompt_ids], device=DEVICE)

model.eval()
with torch.no_grad():
    for _ in range(50):  # max 50 tokens generate
        logits = model(input_ids)
        next_token = torch.argmax(logits[:, -1, :], dim=-1).unsqueeze(0)
        input_ids = torch.cat([input_ids, next_token], dim=1)
        if next_token.item() == sp.eos_id():
            break

generated = sp.decode(input_ids[0].tolist())
print("Generated:", generated[len(test_prompt):].strip())

Generated: त्तर: ⁇ प्रश्न: उनके पिता का नाम क्या था?


In [38]:
# Model load with LoRA
model = My_GPT_model(
    vocab_size=sp.get_piece_size(),
    num_layers=12,
    d_model=512,
    d_ff=2048,
    num_heads=8,
    seq_len=512
).to(DEVICE)
model.load_state_dict(clean_state_dict, strict=False)
lora_state = torch.load("final_lora_weights.pt", map_location=DEVICE)
lora_modules = add_lora_to_model(model)  # tera function
for name, lora in lora_modules.items():
    lora.load_state_dict(lora_state[name])

In [42]:
# test_rag_with_lora.py
import torch
import sentencepiece as spm
from decoder_only_gpt import My_GPT_model
import fitz  # PyMuPDF (agar PDF use kar raha hai)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Load tokenizer
sp = spm.SentencePieceProcessor()
sp.load("hindi_tokenizer_new.model")

# Load base model + checkpoint
model = My_GPT_model(
    vocab_size=sp.get_piece_size(),
    num_layers=12,
    d_model=512,
    d_ff=2048,
    num_heads=8,
    seq_len=512
).to(DEVICE)

ckpt = torch.load("checkpoints_HindiGPT-v1_step280000.pt", map_location=DEVICE)
clean_state_dict = {k.replace("_orig_mod.", ""): v for k, v in ckpt["model"].items()}
model.load_state_dict(clean_state_dict, strict=False)

# Load LoRA weights
lora_state = torch.load("final_lora_weights.pt", map_location=DEVICE)
lora_modules = add_lora_to_model(model)  # tera add_lora_to_model function yahan call kar
for name, lora in lora_modules.items():
    if name in lora_state:
        lora.load_state_dict(lora_state[name])
    else:
        print(f"Warning: LoRA for {name} not found")

model.eval()
print("Model with LoRA loaded successfully!")

# Ab tera RAG pipeline chalaa (ya simple test)
question = "उनके पिता का नाम क्या था और वे क्या काम करते थे?"
# Tera rag_pipeline(question) call kar ya simple generation

# Simple generation test (tera generate_answer_from_ids function use kar)
test_prompt = sp.encode("""
संदर्भ: उनके पिता करमचंद गांधी पोरबंदर राज्य के दीवान थे।
प्रश्न: उनके पिता का नाम क्या था और वे क्या काम करते थे?
उत्तर:
""")

prompt_ids = test_prompt
answer = generate_answer_from_ids(prompt_ids)  # tera function
print("Generated Answer:", answer)

Model with LoRA loaded successfully!
Generated Answer: प्रमाण: उनकी मां ने उन्हें एक छोटे से परिवार में जन्म दिया, जो अपने दादा जी की मृत्यु के बाद गुजरात चले गए थे।


In [50]:
from torch.utils.data import Dataset
import json
import torch

class HindiRAGDataset(Dataset):
    def __init__(self, jsonl_file, tokenizer, max_len=512):
        """
        Hindi RAG-style SFT dataset loader
        - jsonl file mein har line: {"text": "संदर्भ: ... प्रश्न: ... उत्तर: ..."}
        - max_len: sequence length limit
        """
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.data = []

        print(f"Loading dataset from: {jsonl_file}")
        good_lines = 0
        bad_lines = 0

        with open(jsonl_file, "r", encoding="utf-8") as f:
            for line_num, line in enumerate(f, 1):
                line = line.strip()
                if not line:
                    bad_lines += 1
                    continue
                
                try:
                    data = json.loads(line)
                    text = data.get("text", "").strip()
                    if text and len(text) > 20:  # minimum meaningful length
                        self.data.append(text)
                        good_lines += 1
                    else:
                        bad_lines += 1
                except json.JSONDecodeError:
                    print(f"Bad JSON at line {line_num}: {line[:100]}...")
                    bad_lines += 1
                except Exception as e:
                    print(f"Error at line {line_num}: {e}")
                    bad_lines += 1

        print(f"Dataset loaded! Good lines: {good_lines} | Bad/Empty: {bad_lines}")
        print(f"Total valid examples: {len(self.data)}")
        
        if len(self.data) == 0:
            raise ValueError("No valid data found in the file!")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        text = self.data[idx]
        
        # Tokenize
        ids = self.tokenizer.encode(text)
        
        # Truncate if too long
        if len(ids) > self.max_len:
            ids = ids[:self.max_len]
        
        # Attention mask (1 for real tokens)
        attention_mask = [1] * len(ids)
        
        # Optional: padding to max_len (agar batch mein pad karna ho toh)
        # padding = [0] * (self.max_len - len(ids))
        # ids += padding
        # attention_mask += [0] * len(padding)
        
        return {
            "input_ids": torch.tensor(ids, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
            # Agar chahiye toh original text bhi return kar sakta hai debug ke liye
            # "text": text
        }

In [1]:
# full_sft_final_attempt_with_tqdm_perplexity.py
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
import json
import sentencepiece as spm
from decoder_only_gpt import My_GPT_model
from tqdm import tqdm  # ← Yeh add kar diya
import math  # perplexity ke liye

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")

# Tokenizer
sp = spm.SentencePieceProcessor()
sp.load("hindi_tokenizer_new.model")
print("Tokenizer loaded. Vocab size:", sp.get_piece_size())

# Model + checkpoint
model = My_GPT_model(
    vocab_size=sp.get_piece_size(),
    num_layers=12,
    d_model=512,
    d_ff=2048,
    num_heads=8,
    seq_len=512
).to(DEVICE)

ckpt_path = "checkpoints_HindiGPT-v1_step280000.pt"
print(f"Loading checkpoint: {ckpt_path}")
ckpt = torch.load(ckpt_path, map_location=DEVICE)
clean_state_dict = {k.replace("_orig_mod.", ""): v for k, v in ckpt["model"].items()}
missing, unexpected = model.load_state_dict(clean_state_dict, strict=False)
print("Missing keys:", len(missing), "Unexpected keys:", len(unexpected))

# Dataset class (tera improved wala)
class HindiRAGDataset(Dataset):
    def __init__(self, jsonl_file, tokenizer, max_len=512):
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.data = []
        good = 0
        with open(jsonl_file, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    data = json.loads(line)
                    text = data.get("text", "").strip()
                    if text:
                        self.data.append(text)
                        good += 1
                except:
                    pass
        print(f"Dataset loaded! Valid examples: {good}/{len(self.data)}")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        text = self.data[idx]
        ids = self.tokenizer.encode(text)
        if len(ids) > self.max_len:
            ids = ids[:self.max_len]
        return {"input_ids": torch.tensor(ids, dtype=torch.long)}

# Load dataset
jsonl_file = "alpaca_hindi_sft_clean.jsonl"
dataset = HindiRAGDataset(jsonl_file, sp, max_len=512)
dataloader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0)

# Optimizer + Scheduler
optimizer = AdamW(model.parameters(), lr=8e-5, betas=(0.9, 0.98), eps=1e-6)
scheduler = CosineAnnealingLR(optimizer, T_max=len(dataloader) * 5, eta_min=1e-7)

model.train()
print("Training started!")

for epoch in range(5):
    total_loss = 0.0
    num_batches = 0

    # tqdm progress bar add kar diya (epoch ke andar)
    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}", leave=True)

    for batch_idx, batch in enumerate(progress_bar):
        input_ids = batch["input_ids"].to(DEVICE)

        logits = model(input_ids)
        loss = F.cross_entropy(
            logits.view(-1, logits.size(-1)),
            input_ids.view(-1),
            ignore_index=-100
        )

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=3.0)
        optimizer.step()
        scheduler.step()

        total_loss += loss.item()
        num_batches += 1

        # tqdm mein live loss aur LR update karte hain
        current_lr = scheduler.get_last_lr()[0]
        progress_bar.set_postfix({
            'loss': f'{loss.item():.4f}',
            'lr': f'{current_lr:.2e}'
        })

        if batch_idx % 10000 == 0 and batch_idx > 0:
            save_path = f"full_sft_epoch{epoch+1}_step{batch_idx}.pt"
            torch.save(model.state_dict(), save_path)
            print(f"Saved checkpoint: {save_path}")

    avg_loss = total_loss / num_batches if num_batches > 0 else 0

    # Perplexity calculate kar (bahut meaningful metric LM ke liye)
    perplexity = math.exp(avg_loss)
    print(f"Epoch {epoch+1} completed | Avg Loss: {avg_loss:.4f} | Perplexity: {perplexity:.2f}")

# Final save
final_path = "full_sft_final.pt"
torch.save(model.state_dict(), final_path)
print(f"Full SFT complete! Final model saved as: {final_path}")

Using device: cuda
Tokenizer loaded. Vocab size: 32768
Loading checkpoint: checkpoints_HindiGPT-v1_step280000.pt
Missing keys: 0 Unexpected keys: 0
Dataset loaded! Valid examples: 36353/36353
Training started!


Epoch 1:  28%|██████████▋                            | 10002/36353 [20:23<1:14:43,  5.88it/s, loss=0.6969, lr=7.94e-05]

Saved checkpoint: full_sft_epoch1_step10000.pt


Epoch 1:  55%|██████████████████████▌                  | 20002/36353 [40:39<45:57,  5.93it/s, loss=0.8592, lr=7.76e-05]

Saved checkpoint: full_sft_epoch1_step20000.pt


Epoch 1:  83%|████████████████████████████████▏      | 30002/36353 [1:01:00<18:54,  5.60it/s, loss=0.9076, lr=7.47e-05]

Saved checkpoint: full_sft_epoch1_step30000.pt


Epoch 1: 100%|███████████████████████████████████████| 36353/36353 [1:13:53<00:00,  8.20it/s, loss=0.2115, lr=7.24e-05]


Epoch 1 completed | Avg Loss: 0.3209 | Perplexity: 1.38


Epoch 2:  28%|██████████▋                            | 10002/36353 [20:19<1:23:21,  5.27it/s, loss=0.1014, lr=6.78e-05]

Saved checkpoint: full_sft_epoch2_step10000.pt


Epoch 2:  55%|██████████████████████▌                  | 20002/36353 [40:35<48:36,  5.61it/s, loss=0.1756, lr=6.25e-05]

Saved checkpoint: full_sft_epoch2_step20000.pt


Epoch 2:  83%|████████████████████████████████▏      | 30002/36353 [1:00:55<20:22,  5.20it/s, loss=0.8661, lr=5.65e-05]

Saved checkpoint: full_sft_epoch2_step30000.pt


Epoch 2: 100%|███████████████████████████████████████| 36353/36353 [1:13:48<00:00,  8.21it/s, loss=0.3387, lr=5.24e-05]


Epoch 2 completed | Avg Loss: 0.3169 | Perplexity: 1.37


Epoch 3:  28%|██████████▋                            | 10002/36353 [20:18<1:16:24,  5.75it/s, loss=0.0020, lr=4.57e-05]

Saved checkpoint: full_sft_epoch3_step10000.pt


Epoch 3:  55%|██████████████████████▌                  | 20002/36353 [40:35<48:21,  5.64it/s, loss=0.1700, lr=3.88e-05]

Saved checkpoint: full_sft_epoch3_step20000.pt


Epoch 3:  83%|████████████████████████████████▏      | 30002/36353 [1:00:45<18:15,  5.80it/s, loss=0.7211, lr=3.19e-05]

Saved checkpoint: full_sft_epoch3_step30000.pt


Epoch 3: 100%|███████████████████████████████████████| 36353/36353 [1:13:38<00:00,  8.23it/s, loss=0.1030, lr=2.77e-05]


Epoch 3 completed | Avg Loss: 0.2969 | Perplexity: 1.35


Epoch 4:  28%|██████████▋                            | 10002/36353 [20:11<1:22:42,  5.31it/s, loss=0.1206, lr=2.14e-05]

Saved checkpoint: full_sft_epoch4_step10000.pt


Epoch 4:  55%|██████████████████████▌                  | 20002/36353 [40:22<51:44,  5.27it/s, loss=0.3045, lr=1.56e-05]

Saved checkpoint: full_sft_epoch4_step20000.pt


Epoch 4:  83%|████████████████████████████████▏      | 30002/36353 [1:00:33<18:43,  5.65it/s, loss=0.0967, lr=1.05e-05]

Saved checkpoint: full_sft_epoch4_step30000.pt


Epoch 4: 100%|███████████████████████████████████████| 36353/36353 [1:13:22<00:00,  8.26it/s, loss=0.0413, lr=7.73e-06]


Epoch 4 completed | Avg Loss: 0.2744 | Perplexity: 1.32


Epoch 5:  28%|██████████▋                            | 10002/36353 [20:08<1:19:25,  5.53it/s, loss=0.0068, lr=4.17e-06]

Saved checkpoint: full_sft_epoch5_step10000.pt


Epoch 5:  55%|██████████████████████▌                  | 20002/36353 [40:16<52:09,  5.22it/s, loss=0.0013, lr=1.68e-06]

Saved checkpoint: full_sft_epoch5_step20000.pt


Epoch 5:  83%|████████████████████████████████▏      | 30002/36353 [1:00:21<17:57,  5.89it/s, loss=0.0388, lr=3.40e-07]

Saved checkpoint: full_sft_epoch5_step30000.pt


Epoch 5: 100%|███████████████████████████████████████| 36353/36353 [1:13:12<00:00,  8.28it/s, loss=0.6378, lr=1.00e-07]


Epoch 5 completed | Avg Loss: 0.2545 | Perplexity: 1.29
Full SFT complete! Final model saved as: full_sft_final.pt


In [1]:
with open("alpaca_hindi_sft_clean.jsonl", "r", encoding="utf-8") as f:
    for i in range(1):
        print(f.readline().strip())

{"text": "### प्रश्न:\nस्वस्थ रहने के लिए तीन सुझाव दें।\n\n### उत्तर:\n1. संतुलित और पौष्टिक आहार लेंः सुनिश्चित करें कि आपके भोजन में विभिन्न प्रकार के फल और सब्जियां, दुबला प्रोटीन, साबुत अनाज और स्वस्थ वसा शामिल हों। यह आपके शरीर को सर्वोत्तम रूप से कार्य करने के लिए आवश्यक पोषक तत्व प्रदान करने में मदद करता है और पुरानी बीमारियों को रोकने में मदद कर सकता है।\n\n2. नियमित शारीरिक गतिविधि में संलग्न रहेंः मजबूत हड्डियों, मांसपेशियों और हृदय स्वास्थ्य को बनाए रखने के लिए व्यायाम महत्वपूर्ण है। प्रत्येक सप्ताह कम से कम 150 मिनट के मध्यम एरोबिक व्यायाम या 75 मिनट के जोरदार व्यायाम का लक्ष्य रखें।\n\n3. पर्याप्त नींद लेंः पर्याप्त गुणवत्ता वाली नींद लेना शारीरिक और मानसिक कल्याण के लिए महत्वपूर्ण है। यह मनोदशा को नियंत्रित करने, संज्ञानात्मक कार्य में सुधार करने और स्वस्थ विकास और प्रतिरक्षा कार्य का समर्थन करने में मदद करता है। हर रात 7 से 9 घंटे सोने का लक्ष्य रखें।"}


In [12]:
import torch
import sentencepiece as spm
from decoder_only_gpt import My_GPT_model

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Tokenizer
sp = spm.SentencePieceProcessor()
sp.load("hindi_tokenizer_new.model")

# Model
model = My_GPT_model(
    vocab_size=sp.get_piece_size(),
    num_layers=12,
    d_model=512,
    d_ff=2048,
    num_heads=8,
    seq_len=512
).to(DEVICE)

# Load final SFT checkpoint
# model.load_state_dict(torch.load("full_sft_final.pt", map_location=DEVICE), strict=False)
model.load_state_dict(torch.load("full_sft_epoch1_step1000.pt", map_location=DEVICE), strict=False)
model.eval()

print("Full SFT model loaded!")

# Tera generate function (copy kar le)
@torch.no_grad()
def generate(text_prompt, max_new=100, temp=0.5, top_p=0.95, rep_pen=1.2):
    prompt_ids = sp.encode(text_prompt)
    input_ids = torch.tensor([prompt_ids], device=DEVICE)

    for _ in range(max_new):
        logits = model(input_ids)
        next_logits = logits[:, -1, :]
        next_logits /= temp
        probs = torch.softmax(next_logits, dim=-1)

        sorted_probs, sorted_idx = torch.sort(probs, descending=True)
        cum_probs = torch.cumsum(sorted_probs, dim=-1)
        mask = cum_probs > top_p
        mask[..., 1:] = mask[..., :-1].clone()
        mask[..., 0] = False
        sorted_probs[mask] = 0.0

        if sorted_probs.sum() > 0:
            sorted_probs /= sorted_probs.sum()
            next_token = torch.multinomial(sorted_probs, num_samples=1)
            next_token = sorted_idx.gather(-1, next_token)
        else:
            next_token = torch.argmax(probs, dim=-1, keepdim=True)

        input_ids = torch.cat([input_ids, next_token], dim=1)
        if next_token.item() == sp.eos_id():
            break

    return sp.decode(input_ids[0].tolist()[len(prompt_ids):]).strip()

# Test prompts (tera RAG style)
test_prompts = [
    "महात्मा गांधी का जन्म कब और कहाँ हुआ था?",
    "गांधीजी की मुख्य विचारधारा क्या थी? अहिंसा और सत्य के बारे में बताइए।",
    "स्वस्थ रहने के लिए तीन सुझाव दें।"
]

for p in test_prompts:
    print("\nPrompt:", p)
    answer = generate(p)
    print("Answer:", answer)
    print("-" * 80)

Full SFT model loaded!

Prompt: महात्मा गांधी का जन्म कब और कहाँ हुआ था?
Answer: ????????????????????????????????????????????????????????????????????????????????????????????????????
--------------------------------------------------------------------------------

Prompt: गांधीजी की मुख्य विचारधारा क्या थी? अहिंसा और सत्य के बारे में बताइए।
Answer: ।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।।
--------------------------------------------------------------------------------

Prompt: स्वस्थ रहने के लिए तीन सुझाव दें।
Answer: 
--------------------------------------------------------------------------------


In [6]:
print("Tokenizer vocab:", sp.get_piece_size())
print("Model vocab:", model.lm_head.out_features)  # ya model.embedding.weight.shape[0]

Tokenizer vocab: 32768
Model vocab: 32768


In [7]:
@torch.no_grad()
def simple_greedy_test(prompt_text):
    prompt_ids = sp.encode(prompt_text)
    input_ids = torch.tensor([prompt_ids], device=DEVICE)
    
    for _ in range(50):
        logits = model(input_ids)
        next_token = torch.argmax(logits[:, -1, :], dim=-1).unsqueeze(0)
        input_ids = torch.cat([input_ids, next_token], dim=1)
        if next_token.item() == sp.eos_id():
            break
    
    generated = sp.decode(input_ids[0].tolist())
    print("Greedy output:", generated[len(prompt_text):].strip())

simple_greedy_test("महात्मा गांधी का जन्म ")

Greedy output: जन्म जन्म जन्म जन्म जन्म जन्म जन्म जन्म जन्म जन्म जन्म जन्म जन्म जन्म जन्म जन्म जन्म जन्म जन्म जन्म जन्म जन्म जन्म जन्म जन्म जन्म जन्म जन्म जन्म जन्म जन्म जन्म जन्म जन्म जन्म जन्म जन्म जन्म जन्म जन्म जन्म जन्म जन्म जन्म जन्म जन्म जन्म जन्म जन्म जन्म


In [15]:
# full_sft_final_attempt_fixed.py
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR
import json
import sentencepiece as spm
from decoder_only_gpt import My_GPT_model
from tqdm import tqdm
import math
import os

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")

# Tokenizer
sp = spm.SentencePieceProcessor()
sp.load("hindi_tokenizer_new.model")
print("Tokenizer loaded. Vocab size:", sp.get_piece_size())

# Model + checkpoint
model = My_GPT_model(
    vocab_size=sp.get_piece_size(),
    num_layers=12,
    d_model=512,
    d_ff=2048,
    num_heads=8,
    seq_len=512
).to(DEVICE)

ckpt_path = "checkpoints_HindiGPT-v1_step280000.pt"
print(f"Loading checkpoint: {ckpt_path}")
ckpt = torch.load(ckpt_path, map_location=DEVICE)
clean_state_dict = {k.replace("_orig_mod.", ""): v for k, v in ckpt["model"].items()}
missing, unexpected = model.load_state_dict(clean_state_dict, strict=False)
print("Missing keys:", len(missing), "Unexpected keys:", len(unexpected))

# Dataset (tera improved wala)
class HindiRAGDataset(Dataset):
    def __init__(self, jsonl_file, tokenizer, max_len=512):
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.data = []
        good = 0
        with open(jsonl_file, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    data = json.loads(line)
                    text = data.get("text", "").strip()
                    if text:
                        self.data.append(text)
                        good += 1
                except:
                    pass
        print(f"Dataset loaded! Valid examples: {good}/{len(self.data)}")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        text = self.data[idx]
        ids = self.tokenizer.encode(text)
        if len(ids) > self.max_len:
            ids = ids[:self.max_len]
        return {"input_ids": torch.tensor(ids, dtype=torch.long)}

# Load dataset
jsonl_file = "alpaca_hindi_sft_clean.jsonl"
dataset = HindiRAGDataset(jsonl_file, sp, max_len=512)
dataloader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0)

# Optimizer + Scheduler + Warmup
optimizer = AdamW(model.parameters(), lr=5e-6, betas=(0.9, 0.98), eps=1e-6)

warmup_steps = 2000
def warmup_lambda(step):
    return min(1.0, step / warmup_steps)
warmup_scheduler = LambdaLR(optimizer, lr_lambda=warmup_lambda)

cosine_scheduler = CosineAnnealingLR(optimizer, T_max=len(dataloader) * 5, eta_min=1e-7)

model.train()
print("Training started!")

for epoch in range(5):
    total_loss = 0.0
    num_batches = 0

    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}", leave=True)

    for batch_idx, batch in enumerate(progress_bar):
        input_ids = batch["input_ids"].to(DEVICE)

        logits = model(input_ids)
        loss = F.cross_entropy(
            logits.view(-1, logits.size(-1)),
            input_ids.view(-1),
            ignore_index=-100  # agar data mein padding nahi toh comment kar ke try kar sakta hai
        )

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)  # very strict clip
        optimizer.step()
        warmup_scheduler.step()
        cosine_scheduler.step()

        total_loss += loss.item()
        num_batches += 1

        current_lr = warmup_scheduler.get_last_lr()[0]
        progress_bar.set_postfix({
            'loss': f'{loss.item():.4f}',
            'lr': f'{current_lr:.2e}',
            'ppl': f'{math.exp(loss.item()):.2f}'
        })

        # Save every 1000 steps
        if batch_idx % 1000 == 0 and batch_idx > 0:
            save_path = f"full_sft_epoch{epoch+1}_step{batch_idx}.pt"
            torch.save(model.state_dict(), save_path)
            print(f"Saved checkpoint: {save_path}")

    avg_loss = total_loss / num_batches if num_batches > 0 else 0
    perplexity = math.exp(avg_loss)
    print(f"Epoch {epoch+1} completed | Avg Loss: {avg_loss:.4f} | Perplexity: {perplexity:.2f}")

# Final save
final_path = "full_sft_final.pt"
torch.save(model.state_dict(), final_path)
print(f"Full SFT complete! Final model saved as: {final_path}")

Using device: cuda
Tokenizer loaded. Vocab size: 32768
Loading checkpoint: checkpoints_HindiGPT-v1_step280000.pt
Missing keys: 0 Unexpected keys: 0
Dataset loaded! Valid examples: 36353/36353
Training started!


Epoch 1:   3%|▊                             | 1002/36353 [02:05<1:50:49,  5.32it/s, loss=0.2500, lr=2.51e-06, ppl=1.28]

Saved checkpoint: full_sft_epoch1_step1000.pt


Epoch 1:   3%|▉                             | 1109/36353 [02:19<1:13:58,  7.94it/s, loss=0.4041, lr=2.77e-06, ppl=1.50]


KeyboardInterrupt: 

In [13]:
model.eval()
with torch.no_grad():
    prompt = "संदर्भ: उनके पिता करमचंद गांधी पोरबंदर राज्य के दीवान थे।\nप्रश्न: उनके पिता का नाम क्या था?\nउत्तर:"
    ids = sp.encode(prompt)
    x = torch.tensor([ids], device=DEVICE)
    
    for _ in range(50):
        logits = model(x)
        next_id = torch.argmax(logits[:, -1, :], dim=-1).unsqueeze(0)
        x = torch.cat([x, next_id], dim=1)
        if next_id.item() == sp.eos_id():
            break
    
    generated = sp.decode(x[0].tolist())
    print("Live test output:", generated[len(prompt):].strip())
model.train()

Live test output: ्तर:::::::::::::::::::::::::::::::::::::::::::::::::::


My_GPT_model(
  (decoder): Decoder(
    (embedding): Embedding(32768, 512)
    (layers): ModuleList(
      (0-11): 12 x Decoder_GPT_Block(
        (swi_glu): SwiGLU_FFN(
          (w1): Linear(in_features=512, out_features=1536, bias=False)
          (w2): Linear(in_features=512, out_features=1536, bias=False)
          (w3): Linear(in_features=1536, out_features=512, bias=False)
          (act): SiLU()
        )
        (masked_mha): Masked_MHA(
          (Q): Linear(in_features=512, out_features=512, bias=True)
          (K): Linear(in_features=512, out_features=512, bias=True)
          (V): Linear(in_features=512, out_features=512, bias=True)
          (fc_out): Linear(in_features=512, out_features=512, bias=True)
        )
        (rms_norm0): RMSNorm()
        (rms_norm1): RMSNorm()
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
    (norm): RMSNorm()
  )
  (lm_head): Linear(in_features=512, out_features=32768, bias=False)
)