In [2]:
from decoder_only_gpt import My_GPT_model
def build_model(config):
    model = My_GPT_model(
        vocab_size=config["model"]["vocab_size"],
        d_model=config["model"]["d_model"],
        num_layers=config["model"]["n_layer"],
        num_heads=config["model"]["n_head"],
        d_ff=config["model"]["d_ff"],
        seq_len=config["model"]["seq_len"],
        dropout=config["model"]["dropout"]
    )
    return model

In [3]:
CONFIG = {
    # --------------------
    # Model architecture
    # --------------------
    "model": {
        "vocab_size": 32768,
        "d_model": 512,
        "n_layer": 12,
        "n_head": 8,
        "d_ff": 2048,
        "seq_len": 512,
        "dropout": 0.1,
        "weight_tying": True,
        "norm_type": "rmsnorm",
        "ffn_type": "swiglu"
    },

    # --------------------
    # Training (SFT)
    # --------------------
    "train": {
        "batch_size": 1,
        "micro_batch_size": 1,     # future gradient accumulation
        "grad_accum_steps": 1,
        "epochs": 2,
        "lr": 1e-5,
        "weight_decay": 0.01,
        "grad_clip": 1.0,
        "label_smoothing": 0.0,
        "ignore_index": -100,
        "fp16": True,              # future ready
        "bf16": False,
        "seed": 42
    },

    # --------------------
    # Data
    # --------------------
    "data": {
        "dataset": "hindi_sft_v1",
        "format": "### प्रश्न: / ### उत्तर:",
        "pad_token_id": 0,
        "bos_token_id": 2, 
        "eos_token_id": 3,
        "max_seq_len": 512,
        "mask_prompt": True        # loss only on answer
    },

    # --------------------
    # Logging / Checkpoint
    # --------------------
    "logging": {
        "project": "HindiGPT-SFT",
        "log_every": 50,
        "eval_every": 2000,
        "save_every": 7000,
        "save_dir": "checkpoints_sft",
        "wandb": True
    }
}

base_model = build_model(CONFIG)

In [4]:
pretrain_model = My_GPT_model(vocab_size=CONFIG["model"]["vocab_size"],
        d_model=CONFIG["model"]["d_model"],
        num_layers=CONFIG["model"]["n_layer"],
        num_heads=CONFIG["model"]["n_head"],
        d_ff=CONFIG["model"]["d_ff"],
        seq_len=CONFIG["model"]["seq_len"],
        dropout=CONFIG["model"]["dropout"])

In [5]:
import torch
ckpt = torch.load("checkpoints_HindiGPT-v1_step280000.pt", map_location="cpu")

In [6]:
raw_state = ckpt["model"]

clean_state = {}
for k, v in raw_state.items():
    if k.startswith("_orig_mod."):
        k = k.replace("_orig_mod.", "")
    clean_state[k] = v

In [7]:
state_dict = ckpt["model"]
print(list(clean_state.keys())[:20])

['decoder.causal_mask', 'decoder.embedding.weight', 'decoder.layers.0.swi_glu.w1.weight', 'decoder.layers.0.swi_glu.w2.weight', 'decoder.layers.0.swi_glu.w3.weight', 'decoder.layers.0.masked_mha.Q.weight', 'decoder.layers.0.masked_mha.Q.bias', 'decoder.layers.0.masked_mha.K.weight', 'decoder.layers.0.masked_mha.K.bias', 'decoder.layers.0.masked_mha.V.weight', 'decoder.layers.0.masked_mha.V.bias', 'decoder.layers.0.masked_mha.fc_out.weight', 'decoder.layers.0.masked_mha.fc_out.bias', 'decoder.layers.0.rms_norm0.weight', 'decoder.layers.0.rms_norm1.weight', 'decoder.layers.1.swi_glu.w1.weight', 'decoder.layers.1.swi_glu.w2.weight', 'decoder.layers.1.swi_glu.w3.weight', 'decoder.layers.1.masked_mha.Q.weight', 'decoder.layers.1.masked_mha.Q.bias']


In [8]:
pretrain_model.load_state_dict(clean_state, strict=True)

<All keys matched successfully>

In [9]:
for p in pretrain_model.parameters():
    p.requires_grad = False

In [10]:
from torch import nn
class LoRALinear(nn.Module):
    def __init__(self, base_linear: nn.Linear, r=16, alpha=16, dropout=0.05):
        super().__init__()

        self.weight = base_linear.weight
        self.weight.requires_grad = False
        self.bias = base_linear.bias

        self.r = r
        self.scaling = alpha / r

        self.lora_A = nn.Linear(base_linear.in_features, r, bias=False)
        self.lora_B = nn.Linear(r, base_linear.out_features, bias=False)

        nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B.weight)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = F.linear(x, self.weight, self.bias)
        out += self.lora_B(self.lora_A(self.dropout(x))) * self.scaling
        return out

In [11]:
def inject_lora_attention(model, r=8, alpha=16):
    for layer in model.decoder.layers:
        mha = layer.masked_mha

        mha.Q = LoRALinear(mha.Q, r=r, alpha=alpha)
        mha.K = LoRALinear(mha.K, r=r, alpha=alpha)
        mha.V = LoRALinear(mha.V, r=r, alpha=alpha)

In [12]:
import math
inject_lora_attention(pretrain_model, r=8, alpha=16)

In [13]:
trainable = [(n, p.shape) for n, p in pretrain_model.named_parameters() if p.requires_grad]
print(len(trainable))

72


In [14]:
assert all("lora_" in n for n, p in pretrain_model.named_parameters() if p.requires_grad)

In [15]:
lora_params = [p for p in pretrain_model.parameters() if p.requires_grad]

optimizer = torch.optim.AdamW(
    lora_params,
    lr=2e-4,
    weight_decay=0.0
)

In [16]:
lora_state = {
    k: v.cpu()
    for k, v in pretrain_model.state_dict().items()
    if "lora_" in k
}

torch.save(lora_state, "lora_sft.pt")

In [17]:
def inject_lora_attention(model, r=8, alpha=16):
    for layer in model.decoder.layers:
        mha = layer.masked_mha

        if not isinstance(mha.Q, LoRALinear):
            mha.Q = LoRALinear(mha.Q, r=r, alpha=alpha)

        if not isinstance(mha.K, LoRALinear):
            mha.K = LoRALinear(mha.K, r=r, alpha=alpha)

        if not isinstance(mha.V, LoRALinear):
            mha.V = LoRALinear(mha.V, r=r, alpha=alpha)

In [18]:
inject_lora_attention(pretrain_model, r=8, alpha=16)
lora_state = torch.load("lora_sft.pt")
base_model.load_state_dict(lora_state, strict=False)

_IncompatibleKeys(missing_keys=['decoder.causal_mask', 'decoder.embedding.weight', 'decoder.layers.0.swi_glu.w1.weight', 'decoder.layers.0.swi_glu.w2.weight', 'decoder.layers.0.swi_glu.w3.weight', 'decoder.layers.0.masked_mha.Q.weight', 'decoder.layers.0.masked_mha.Q.bias', 'decoder.layers.0.masked_mha.K.weight', 'decoder.layers.0.masked_mha.K.bias', 'decoder.layers.0.masked_mha.V.weight', 'decoder.layers.0.masked_mha.V.bias', 'decoder.layers.0.masked_mha.fc_out.weight', 'decoder.layers.0.masked_mha.fc_out.bias', 'decoder.layers.0.rms_norm0.weight', 'decoder.layers.0.rms_norm1.weight', 'decoder.layers.1.swi_glu.w1.weight', 'decoder.layers.1.swi_glu.w2.weight', 'decoder.layers.1.swi_glu.w3.weight', 'decoder.layers.1.masked_mha.Q.weight', 'decoder.layers.1.masked_mha.Q.bias', 'decoder.layers.1.masked_mha.K.weight', 'decoder.layers.1.masked_mha.K.bias', 'decoder.layers.1.masked_mha.V.weight', 'decoder.layers.1.masked_mha.V.bias', 'decoder.layers.1.masked_mha.fc_out.weight', 'decoder.layer

In [19]:
model = build_model(CONFIG)

In [20]:
model.load_state_dict(clean_state, strict=True)

<All keys matched successfully>

In [21]:
inject_lora_attention(model, r=8, alpha=16)

In [22]:
lora_state = torch.load("lora_sft.pt")
model.load_state_dict(lora_state, strict=False)

_IncompatibleKeys(missing_keys=['decoder.causal_mask', 'decoder.embedding.weight', 'decoder.layers.0.swi_glu.w1.weight', 'decoder.layers.0.swi_glu.w2.weight', 'decoder.layers.0.swi_glu.w3.weight', 'decoder.layers.0.masked_mha.Q.weight', 'decoder.layers.0.masked_mha.Q.bias', 'decoder.layers.0.masked_mha.K.weight', 'decoder.layers.0.masked_mha.K.bias', 'decoder.layers.0.masked_mha.V.weight', 'decoder.layers.0.masked_mha.V.bias', 'decoder.layers.0.masked_mha.fc_out.weight', 'decoder.layers.0.masked_mha.fc_out.bias', 'decoder.layers.0.rms_norm0.weight', 'decoder.layers.0.rms_norm1.weight', 'decoder.layers.1.swi_glu.w1.weight', 'decoder.layers.1.swi_glu.w2.weight', 'decoder.layers.1.swi_glu.w3.weight', 'decoder.layers.1.masked_mha.Q.weight', 'decoder.layers.1.masked_mha.Q.bias', 'decoder.layers.1.masked_mha.K.weight', 'decoder.layers.1.masked_mha.K.bias', 'decoder.layers.1.masked_mha.V.weight', 'decoder.layers.1.masked_mha.V.bias', 'decoder.layers.1.masked_mha.fc_out.weight', 'decoder.layer

In [23]:
for n, p in model.named_parameters():
    if "lora_" in n:
        print(n, p.requires_grad)

decoder.layers.0.masked_mha.Q.lora_A.weight True
decoder.layers.0.masked_mha.Q.lora_B.weight True
decoder.layers.0.masked_mha.K.lora_A.weight True
decoder.layers.0.masked_mha.K.lora_B.weight True
decoder.layers.0.masked_mha.V.lora_A.weight True
decoder.layers.0.masked_mha.V.lora_B.weight True
decoder.layers.1.masked_mha.Q.lora_A.weight True
decoder.layers.1.masked_mha.Q.lora_B.weight True
decoder.layers.1.masked_mha.K.lora_A.weight True
decoder.layers.1.masked_mha.K.lora_B.weight True
decoder.layers.1.masked_mha.V.lora_A.weight True
decoder.layers.1.masked_mha.V.lora_B.weight True
decoder.layers.2.masked_mha.Q.lora_A.weight True
decoder.layers.2.masked_mha.Q.lora_B.weight True
decoder.layers.2.masked_mha.K.lora_A.weight True
decoder.layers.2.masked_mha.K.lora_B.weight True
decoder.layers.2.masked_mha.V.lora_A.weight True
decoder.layers.2.masked_mha.V.lora_B.weight True
decoder.layers.3.masked_mha.Q.lora_A.weight True
decoder.layers.3.masked_mha.Q.lora_B.weight True
decoder.layers.3.mas

In [24]:
model.eval().cuda()

My_GPT_model(
  (decoder): Decoder(
    (embedding): Embedding(32768, 512)
    (layers): ModuleList(
      (0-11): 12 x Decoder_GPT_Block(
        (swi_glu): SwiGLU_FFN(
          (w1): Linear(in_features=512, out_features=1536, bias=False)
          (w2): Linear(in_features=512, out_features=1536, bias=False)
          (w3): Linear(in_features=1536, out_features=512, bias=False)
          (act): SiLU()
        )
        (masked_mha): Masked_MHA(
          (Q): LoRALinear(
            (lora_A): Linear(in_features=512, out_features=8, bias=False)
            (lora_B): Linear(in_features=8, out_features=512, bias=False)
            (dropout): Dropout(p=0.05, inplace=False)
          )
          (K): LoRALinear(
            (lora_A): Linear(in_features=512, out_features=8, bias=False)
            (lora_B): Linear(in_features=8, out_features=512, bias=False)
            (dropout): Dropout(p=0.05, inplace=False)
          )
          (V): LoRALinear(
            (lora_A): Linear(in_features

In [25]:
device = "cuda" if torch.cuda.is_available() else "cpu"

model.eval()
model.to(device)

My_GPT_model(
  (decoder): Decoder(
    (embedding): Embedding(32768, 512)
    (layers): ModuleList(
      (0-11): 12 x Decoder_GPT_Block(
        (swi_glu): SwiGLU_FFN(
          (w1): Linear(in_features=512, out_features=1536, bias=False)
          (w2): Linear(in_features=512, out_features=1536, bias=False)
          (w3): Linear(in_features=1536, out_features=512, bias=False)
          (act): SiLU()
        )
        (masked_mha): Masked_MHA(
          (Q): LoRALinear(
            (lora_A): Linear(in_features=512, out_features=8, bias=False)
            (lora_B): Linear(in_features=8, out_features=512, bias=False)
            (dropout): Dropout(p=0.05, inplace=False)
          )
          (K): LoRALinear(
            (lora_A): Linear(in_features=512, out_features=8, bias=False)
            (lora_B): Linear(in_features=8, out_features=512, bias=False)
            (dropout): Dropout(p=0.05, inplace=False)
          )
          (V): LoRALinear(
            (lora_A): Linear(in_features

In [26]:
def top_p_filtering(logits, top_p=0.9):
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    probs = torch.softmax(sorted_logits, dim=-1)

    cumulative_probs = torch.cumsum(probs, dim=-1)

    # tokens remove where cumulative prob > top_p
    sorted_indices_to_remove = cumulative_probs > top_p
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0

    indices_to_remove = sorted_indices[sorted_indices_to_remove]
    logits[indices_to_remove] = -float("inf")

    return logits


In [27]:
import torch
import torch.nn.functional as F
import sentencepiece as spm
from decoder_only_gpt import My_GPT_model_SFT

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
CKPT_PATH = "checkpoints_sft_step_140000.pt"
SEQ_LEN = 512

TEMPERATURE = 1.0
TOP_P = 0.85
REPETITION_PENALTY = 1.7
MAX_NEW_TOKENS = 80

sp = spm.SentencePieceProcessor(model_file="hindi_tokenizer_new.model")
print("BOS:", sp.bos_id(), "EOS:", sp.eos_id(), "PAD:", sp.pad_id())

@torch.no_grad()
def generate(model, input_ids, max_new_tokens=MAX_NEW_TOKENS, temperature=TEMPERATURE, top_p=TOP_P, repetition_penalty=REPETITION_PENALTY):
    model.eval()
    eos_id = sp.eos_id()
    seq_len = input_ids.shape[1]

    for step in range(max_new_tokens):
        # Keep only last SEQ_LEN tokens
        input_cond = input_ids[:, -SEQ_LEN:]
        logits = model(input_cond)  # (B, T, V)
        logits = logits[:, -1, :] / temperature  # Take last token logits & apply temperature

        # Clamp logits to avoid extreme values
        logits = torch.clamp(logits, min=-1e9, max=1e9)

        # Penalize non-Devanagari tokens
        for tid in range(sp.get_piece_size()):
            piece = sp.id_to_piece(tid)
            if not any('\u0900' <= ch <= '\u097F' for ch in piece):
                logits[0, tid] -= 5.0

        # Apply repetition penalty on generated tokens only (not prompt)
        generated_tokens = input_ids[0, seq_len:].tolist()
        for t in set(generated_tokens):
            logits[0, t] /= repetition_penalty

        # Top-p (nucleus) sampling
        probs = F.softmax(logits, dim=-1)
        sorted_probs, sorted_indices = torch.sort(probs, descending=True)
        cum_probs = torch.cumsum(sorted_probs, dim=-1)

        mask = cum_probs > top_p
        mask[..., 1:] = mask[..., :-1].clone()
        mask[..., 0] = False

        sorted_probs[mask] = 0.0
        sorted_probs /= sorted_probs.sum(dim=-1, keepdim=True)

        next_token = torch.multinomial(sorted_probs, 1)
        next_token = torch.gather(sorted_indices, -1, next_token)

        input_ids = torch.cat([input_ids, next_token], dim=1)

        if next_token.item() == eos_id:
            break

    return input_ids


def main():
    print("Loading checkpoint...")
    ckpt = torch.load(CKPT_PATH, map_location=DEVICE)

    print("Initializing model...")
    model = My_GPT_model_SFT(
        vocab_size=sp.get_piece_size(),
        num_layers=12,
        d_model=512,
        d_ff=2048,
        num_heads=8,
        seq_len=SEQ_LEN
    ).to(DEVICE)

    state_dict = ckpt.get("model", ckpt)
    if any(k.startswith("_orig_mod.") for k in state_dict.keys()):
        print("Detected torch.compile prefix, stripping '_orig_mod.'...")
        state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}

    model.load_state_dict(state_dict)
    model.eval()
    print("Model loaded successfully!")

    prompt = "### प्रश्न:\nप्रकृति के बारे में एक ऐसी कविता बनाएँ जिसमें केवल दो अलग-अलग तुकबंदी शब्दों का उपयोग हो।\n\n### उत्तर:\n"
    # prompt = "### प्रश्न:\nशिविर यात्रा के लिए एक व्यक्ति को दस वस्तुओं की आवश्यकता हो सकती हैः\n\n1. तम्बू-तत्वों से आश्रय और सुरक्षा प्रदान करने के लिए\n2. स्लीपिंग बैग-सोते समय गर्म और आरामदायक रहने के लिए\n3. पोर्टेबल स्टोव या कैम्पफायर ग्रिल-खाना पकाने के लिए\n4. जल्दी खराब होने वाले भोजन और पेय को ठंडा रखने के लिए बर्फ या बर्फ के पैक के साथ ठंडा करें।\n5. लालटेन या टॉर्च-रात के दौरान प्रकाश प्रदान करने के लिए\n6. प्राथमिक चिकित्सा किट-मामूली चोटों या बीमारियों के लिए\n7. मानचित्र और कम्पास या जी. पी. एस.-पर्वतारोहण या क्षेत्र की खोज के लिए\n8. कैम्प कुर्सियाँ या तह कुर्सियाँ-कैम्पसाइट के आसपास आरामदायक बैठने के लिए\n9. कीट विकर्षक-कीट के काटने से बचाने के लिए\n10. सनस्क्रीन-सनबर्न से बचाने के लिए\n\n### उत्तर:\n"
    
    print(f"Prompt: {prompt}")

    input_ids = [sp.bos_id()] + sp.encode(prompt, out_type=int)
    input_ids = torch.tensor([input_ids], device=DEVICE)

    for i in range(3):
        print(f"--- Generation {i+1} ---")
        output_ids = generate(model, input_ids.clone())
        generated_tokens = output_ids[0, input_ids.shape[1]:].tolist()
        generated_text = sp.decode(generated_tokens)
        print(generated_text)
        print("-" * 60)

if __name__ == "__main__":
    main()

BOS: 2 EOS: 3 PAD: 0
Loading checkpoint...
Initializing model...
Model loaded successfully!
Prompt: ### प्रश्न:
प्रकृति के बारे में एक ऐसी कविता बनाएँ जिसमें केवल दो अलग-अलग तुकबंदी शब्दों का उपयोग हो।

### उत्तर:

--- Generation 1 ---
 ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇ 
------------------------------------------------------------
--- Generation 2 ---
 ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇ 
------------------------------------------------------------
--- Generation 3 ---
 ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  

In [28]:
import sentencepiece as spm
sp = spm.SentencePieceProcessor(model_file="hindi_tokenizer_new.model")
import torch.nn.functional as F

prompt = "भारत का भविष्य"
input_ids = torch.tensor(
    [sp.encode(prompt)],
    dtype=torch.long
)

SEQ_LEN = 512

TEMPERATURE = 0.6
TOP_P = 0.85
REPETITION_PENALTY = 1.6

# Sampling hyperparameters
  # ← Increase this! 1.2–1.8 works best for repetitive small models
PENALTY_WINDOW = 128       # Not used directly now, but kept for future
MAX_NEW_TOKENS = 200

output_ids = generate(
    model,
    input_ids,
    max_new_tokens=80,
    temperature=TEMPERATURE,
    top_p=0.9,
    eos_token_id=sp.eos_id()
)

print(sp.decode(output_ids[0].tolist()))

TypeError: generate() got an unexpected keyword argument 'eos_token_id'

In [None]:
import torch
ckpt = torch.load("checkpoints_HindiGPT-v1_step280000.pt", map_location="cpu")

In [None]:
state_dict = ckpt["model"]

In [None]:
ckpt = torch.load("checkpoints_HindiGPT-v1_step280000.pt", map_location="cpu")
print(ckpt.keys())

In [None]:
raw_state = ckpt["model"]

clean_state = {}
for k, v in raw_state.items():
    if k.startswith("_orig_mod."):
        k = k.replace("_orig_mod.", "")
    clean_state[k] = v

In [None]:
state_dict = ckpt["model"]
print(list(clean_state.keys())[:20])

In [None]:
base_model.load_state_dict(clean_state, strict=True)

In [None]:
i = 100
TABLE_PREFIX = "com_"
table = f"{TABLE_PREFIX}{i}"
table