In [14]:
import torch, os
import sentencepiece as spm

In [3]:
sp = spm.SentencePieceProcessor(model_file="hindi_tokenizer_new.model")

In [4]:
print(sp.get_piece_size())

print("PAD:", sp.pad_id())
print("BOS:", sp.bos_id())
print("EOS:", sp.eos_id())
print("UNK:", sp.unk_id())

32768
PAD: 0
BOS: 2
EOS: 3
UNK: 1


In [6]:
from decoder_only_gpt import My_GPT_model_SFT

In [80]:
CONFIG = {
    # --------------------
    # Model architecture
    # --------------------
    "model": {
        "vocab_size": 32768,
        "d_model": 512,
        "n_layer": 12,
        "n_head": 8,
        "d_ff": 2048,
        "seq_len": 512,
        "dropout": 0.1,
        "weight_tying": True,
        "norm_type": "rmsnorm",
        "ffn_type": "swiglu"
    },

    # --------------------
    # Training (SFT)
    # --------------------
    "train": {
        "batch_size": 1,
        "micro_batch_size": 1,     # future gradient accumulation
        "grad_accum_steps": 1,
        "epochs": 2,
        "lr": 1e-5,
        "weight_decay": 0.01,
        "grad_clip": 1.0,
        "label_smoothing": 0.0,
        "ignore_index": -100,
        "fp16": True,              # future ready
        "bf16": False,
        "seed": 42
    },

    # --------------------
    # Data
    # --------------------
    "data": {
        "dataset": "hindi_sft_v1",
        "format": "### प्रश्न: / ### उत्तर:",
        "pad_token_id": 0,
        "bos_token_id": 2, 
        "eos_token_id": 3,
        "max_seq_len": 512,
        "mask_prompt": True        # loss only on answer
    },

    # --------------------
    # Logging / Checkpoint
    # --------------------
    "logging": {
        "project": "HindiGPT-SFT",
        "log_every": 50,
        "eval_every": 2000,
        "save_every": 7000,
        "save_dir": "checkpoints_sft",
        "wandb": True
    }
}

In [40]:
model = My_GPT_model_SFT(
    vocab_size=CONFIG["model"]["vocab_size"], num_layers=CONFIG["model"]["n_layer"],
    d_model=CONFIG["model"]["d_model"], d_ff=CONFIG["model"]["d_ff"],
    num_heads=CONFIG["model"]["n_head"], seq_len=CONFIG["model"]["seq_len"]
)

In [41]:
pretrained_path = "checkpoints_HindiGPT-v1_step280000.pt"  # <-- put actual pretrained checkpoint path here

if os.path.isfile(pretrained_path):
    print(f"Loading pretrained weights from {pretrained_path}")
    checkpoint = torch.load(pretrained_path, map_location='cpu')
    
    # If checkpoint has key "model" holding weights, else use checkpoint directly
    state_dict = checkpoint.get("model", checkpoint)
    
    # Remove any unwanted prefixes (like _orig_mod.) if present, matching your load_checkpoint code
    new_state = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
    
    missing, unexpected = model.load_state_dict(new_state, strict=False)
    print(f"Loaded pretrained model with missing keys: {missing}")
    print("************")
    print(f"unexpected keys: {unexpected}")
else:
    print("Pretrained model checkpoint not found, training from scratch.")

Loading pretrained weights from checkpoints_HindiGPT-v1_step280000.pt
Loaded pretrained model with missing keys: ['decoder.layers.0.masked_mha.q_base.weight', 'decoder.layers.0.masked_mha.q_base.bias', 'decoder.layers.0.masked_mha.k_base.weight', 'decoder.layers.0.masked_mha.k_base.bias', 'decoder.layers.0.masked_mha.v_base.weight', 'decoder.layers.0.masked_mha.v_base.bias', 'decoder.layers.0.masked_mha.q_proj.weight', 'decoder.layers.0.masked_mha.q_proj.bias', 'decoder.layers.0.masked_mha.q_proj.lora_A.weight', 'decoder.layers.0.masked_mha.q_proj.lora_B.weight', 'decoder.layers.0.masked_mha.k_proj.weight', 'decoder.layers.0.masked_mha.k_proj.bias', 'decoder.layers.0.masked_mha.k_proj.lora_A.weight', 'decoder.layers.0.masked_mha.k_proj.lora_B.weight', 'decoder.layers.0.masked_mha.v_proj.weight', 'decoder.layers.0.masked_mha.v_proj.bias', 'decoder.layers.0.masked_mha.v_proj.lora_A.weight', 'decoder.layers.0.masked_mha.v_proj.lora_B.weight', 'decoder.layers.1.masked_mha.q_base.weight', '

In [42]:
model

My_GPT_model_SFT(
  (decoder): Decoder(
    (embedding): Embedding(32768, 512)
    (layers): ModuleList(
      (0-11): 12 x Decoder_GPT_Block(
        (swi_glu): SwiGLU_FFN(
          (w1): Linear(in_features=512, out_features=1536, bias=False)
          (w2): Linear(in_features=512, out_features=1536, bias=False)
          (w3): Linear(in_features=1536, out_features=512, bias=False)
          (act): SiLU()
        )
        (masked_mha): Masked_MHA_LORA(
          (q_base): Linear(in_features=512, out_features=512, bias=True)
          (k_base): Linear(in_features=512, out_features=512, bias=True)
          (v_base): Linear(in_features=512, out_features=512, bias=True)
          (q_proj): LoRALinear(
            (lora_A): Linear(in_features=512, out_features=8, bias=False)
            (lora_B): Linear(in_features=8, out_features=512, bias=False)
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (k_proj): LoRALinear(
            (lora_A): Linear(in_features=512

In [19]:
model.decoder

Decoder(
  (embedding): Embedding(32768, 512)
  (layers): ModuleList(
    (0-11): 12 x Decoder_GPT_Block(
      (swi_glu): SwiGLU_FFN(
        (w1): Linear(in_features=512, out_features=1536, bias=False)
        (w2): Linear(in_features=512, out_features=1536, bias=False)
        (w3): Linear(in_features=1536, out_features=512, bias=False)
        (act): SiLU()
      )
      (masked_mha): Masked_MHA_LORA(
        (q_base): Linear(in_features=512, out_features=512, bias=True)
        (k_base): Linear(in_features=512, out_features=512, bias=True)
        (v_base): Linear(in_features=512, out_features=512, bias=True)
        (q_proj): LoRALinear(
          (lora_A): Linear(in_features=512, out_features=8, bias=False)
          (lora_B): Linear(in_features=8, out_features=512, bias=False)
          (dropout): Dropout(p=0.0, inplace=False)
        )
        (k_proj): LoRALinear(
          (lora_A): Linear(in_features=512, out_features=8, bias=False)
          (lora_B): Linear(in_features=8,

In [20]:
model.decoder.embedding

Embedding(32768, 512)

In [21]:
model.decoder.embedding.weight

Parameter containing:
tensor([[ 1.6479e-03, -5.1934e-02,  1.9765e-02,  ...,  4.8727e-02,
          2.1490e-03,  2.9600e-02],
        [ 2.1276e-03, -3.7330e-02,  1.1916e-02,  ...,  1.6846e-02,
          9.4464e-06,  2.3255e-02],
        [ 1.6793e-03, -5.1942e-02,  1.9771e-02,  ...,  4.8732e-02,
          2.1496e-03,  2.9585e-02],
        ...,
        [ 1.6573e-03, -5.1932e-02,  1.9768e-02,  ...,  4.8731e-02,
          2.1527e-03,  2.9579e-02],
        [ 1.6747e-03, -5.1922e-02,  1.9737e-02,  ...,  4.8732e-02,
          2.1439e-03,  2.9594e-02],
        [ 1.8618e-03, -5.1502e-02,  1.9627e-02,  ...,  4.7211e-02,
          2.2105e-03,  2.9348e-02]], requires_grad=True)

In [22]:
model.decoder.embedding.weight.shape

torch.Size([32768, 512])

In [28]:
print(model.decoder.embedding.weight[0].shape)
model.decoder.embedding.weight[0]

torch.Size([512])


tensor([ 1.6479e-03, -5.1934e-02,  1.9765e-02,  4.1537e-02, -2.9619e-02,
        -3.1721e-02, -3.5112e-02,  7.1186e-03, -6.3270e-03,  2.3410e-02,
         4.2177e-02, -3.9054e-02,  5.1253e-02, -5.1261e-01,  2.2189e-02,
         2.7451e-02, -8.7450e-03, -3.0652e-01,  1.2234e-03,  9.8096e-03,
         2.9875e-02, -2.8258e-01,  1.1134e-01, -4.7863e-02, -1.1598e-01,
         1.0249e-02, -1.5890e-02,  3.0105e-02, -5.2984e-02,  4.2057e-02,
         7.4862e-02, -8.3000e-03,  2.8148e-02, -9.7618e-02,  1.5783e-02,
         3.2441e-02,  1.6196e-01,  3.0731e-02, -2.5335e-01,  2.2948e-02,
        -4.9372e-02,  1.3575e-02, -6.1646e-02,  2.1235e-03,  1.6382e-02,
         2.7277e-02,  1.2145e-02,  1.1668e-01, -4.0974e-02, -2.4567e-02,
         7.2909e-02,  5.0054e-02,  2.9814e-02, -9.4731e-03, -1.3241e-02,
         3.2810e-02, -1.4707e-02,  2.0861e-02,  2.5190e-02, -3.7752e-02,
         1.0164e-03, -2.4250e-02, -3.4445e-01,  3.8835e-02, -5.2761e-02,
        -5.9122e-02,  2.4172e-02,  5.9933e-02,  1.7

In [27]:
model.decoder.embedding.embedding_dim

512

In [30]:
model.lm_head.weight is model.decoder.embedding.weight

True

In [43]:
sp.encode("स्वस्थ रहने के लिए")

[3717, 1020, 11, 91]

In [44]:
model.decoder.embedding.max_norm

In [54]:
tokenizer = spm.SentencePieceProcessor(model_file="hindi_tokenizer_new.model")

In [55]:
from torch.utils.data import Dataset, DataLoader
class SFT_Dataset(Dataset):
    def __init__(self, data, tokenizer):
        
        self.data = data
        self.tokenizer = tokenizer
        self.answer_key = "### उत्तर:"

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        text = self.data[index]["text"]

        BOS = CONFIG["data"]["bos_token_id"]
        EOS = CONFIG["data"]["eos_token_id"]

        full_ids = [BOS] + self.tokenizer.encode(text) + [EOS]

        pos = text.find(self.answer_key)
        if pos == -1:
            pos = len(text)

        prompt_text = text[:pos + len(self.answer_key)]
        prompt_ids = [BOS] + self.tokenizer.encode(prompt_text)

        min_len = min(len(prompt_ids), len(full_ids))
        answer_start = 0
        for i in range(min_len, 0, -1):
            if prompt_ids[:i] == full_ids[:i]:
                answer_start = i
                break

        input_ids = torch.tensor(full_ids, dtype=torch.long)

        if CONFIG["data"]["mask_prompt"] and answer_start > 0:
            labels = torch.full_like(input_ids, -100)
            labels[answer_start:] = input_ids[answer_start:]
        else:
            labels = input_ids.clone()

        return {
            "input_ids": input_ids,
            "labels": labels
        }

In [56]:
print("VOCAB:", sp.get_piece_size())
print("BOS:", CONFIG["data"]["bos_token_id"])
print("EOS:", CONFIG["data"]["eos_token_id"])
print("PAD:", CONFIG["data"]["pad_token_id"])

VOCAB: 32768
BOS: 2
EOS: 3
PAD: 0


In [57]:
def test_single_sample(dataset, idx=0):
    sample = dataset[idx]

    input_ids = sample["input_ids"]
    labels = sample["labels"]

    print("="*80)
    print("INPUT IDS:", input_ids.tolist())
    print("LABELS   :", labels.tolist())

    print("\nDECODED FULL:")
    print(sp.decode(input_ids.tolist()))

    answer_tokens = [
        t for t, l in zip(input_ids.tolist(), labels.tolist()) if l != -100
    ]

    print("\nDECODED ANSWER ONLY:")
    print(sp.decode(answer_tokens))

    print("\nBOS OK:", input_ids[0].item() == sp.bos_id())
    print("EOS OK:", input_ids[-1].item() == sp.eos_id())
    print("="*80)

In [58]:
import json
data = []
with open("alpaca_hindi_sft_clean.jsonl", "r", encoding="utf-8") as f:
    for line in f:
        data.append(json.loads(line))

In [59]:
sample = data[1]
sample

{'text': '### प्रश्न:\nतीन प्राथमिक रंग कौन से हैं?\n\n### उत्तर:\nतीन प्राथमिक रंग लाल, नीला और पीला हैं। इन रंगों को प्राथमिक कहा जाता है क्योंकि इन्हें अन्य रंगों को मिलाकर नहीं बनाया जा सकता है और अन्य सभी रंगों को विभिन्न अनुपातों में मिलाकर बनाया जा सकता है। प्रकाश के लिए उपयोग की जाने वाली योगात्मक रंग प्रणाली में, प्राथमिक रंग लाल, हरा और नीला (आर. जी. बी.) हैं।'}

In [62]:
dataset = SFT_Dataset(data, tokenizer=sp)

In [63]:
test_single_sample(dataset, idx=0)

INPUT IDS: [2, 28843, 1, 2821, 28911, 1, 1970, 357, 1020, 11, 91, 565, 4535, 884, 28869, 1, 832, 28911, 1, 28924, 28889, 10608, 44, 13215, 5479, 1678, 28920, 3010, 393, 38, 778, 2661, 22, 1295, 1005, 11, 2347, 44, 11726, 28879, 6388, 83, 6769, 28879, 23834, 6823, 44, 3717, 10944, 679, 776, 28869, 133, 778, 1358, 28, 12226, 377, 32, 311, 140, 11, 91, 1290, 8777, 3377, 1652, 140, 22, 981, 629, 15, 44, 3497, 5155, 28, 2932, 22, 981, 36, 438, 15, 28869, 1, 28932, 28889, 3420, 3946, 9069, 22, 12531, 2690, 28920, 2286, 12737, 28879, 11664, 44, 4727, 1167, 28, 1566, 1397, 11, 91, 8044, 1634, 15, 28869, 2386, 2296, 253, 32, 253, 11776, 28922, 1887, 11, 6288, 24342, 18105, 8044, 158, 9747, 28938, 1887, 11, 7125, 8044, 40, 2411, 2694, 28869, 1, 28937, 28889, 3839, 4499, 1678, 28920, 3839, 3517, 561, 4499, 2459, 3946, 44, 3754, 2940, 11, 91, 1634, 15, 28869, 133, 3346, 12413, 28, 6124, 140, 28879, 25251, 221, 1545, 311, 22, 2064, 140, 44, 3717, 777, 44, 17259, 311, 40, 1929, 140, 22, 981, 629, 15

In [64]:
def debug_token_labels(dataset, idx=0):
    sample = dataset[idx]
    ids = sample["input_ids"]
    labels = sample["labels"]

    print(f"{'TOKEN':<15} {'ID':<6} {'LABEL'}")
    print("-"*45)

    for i, l in zip(ids.tolist(), labels.tolist()):
        tok = sp.decode([i])
        print(f"{tok:<15} {i:<6} {l}")

In [65]:
debug_token_labels(dataset, 0)

TOKEN           ID     LABEL
---------------------------------------------
                2      -100
                28843  -100
 ⁇              1      -100
प्रश्न          2821   -100
:               28911  -100
 ⁇              1      -100
स्व             1970   -100
स्थ             357    -100
रहने            1020   -100
के              11     -100
लिए             91     -100
तीन             565    -100
सुझाव           4535   -100
दें             884    -100
।               28869  -100
 ⁇              1      -100
उत्तर           832    -100
:               28911  -100
 ⁇              1      1
1               28924  28924
.               28889  28889
संतुलित         10608  10608
और              44     44
पौष्टिक         13215  13215
आहार            5479   5479
लें             1678   1678
ः               28920  28920
सुनिश्चित       3010   3010
करें            393    393
कि              38     38
आपके            778    778
भोजन            2661   2661
में             22     22
विभिन्न

In [66]:
bad_sample = [{"text": "hello world"}]
bad_ds = SFT_Dataset(bad_sample, sp)

test_single_sample(bad_ds, 0)

INPUT IDS: [2, 5215, 22292, 28919, 26999, 28923, 28927, 3]
LABELS   : [-100, -100, -100, -100, -100, -100, -100, 3]

DECODED FULL:
hello world

DECODED ANSWER ONLY:


BOS OK: True
EOS OK: True


In [67]:
def collate_fn(batch):
    seq_len = CONFIG["model"]["seq_len"]
    input_ids_list = []
    labels_list = []
    attention_mask_list = []

    pad_id = CONFIG["data"]["pad_token_id"]   # ← yahan se le

    for item in batch:
        ids = item["input_ids"][:seq_len]
        lbls = item["labels"][:seq_len]

        curr_len = len(ids)
        pad_len = seq_len - curr_len

        if pad_len > 0:
            ids = torch.cat([ids, torch.full((pad_len,), pad_id, dtype=torch.long)])
            lbls = torch.cat([lbls, torch.full((pad_len,), -100, dtype=torch.long)])

        # Attention mask: 1 = real token, 0 = padding
        mask = torch.ones(seq_len, dtype=torch.long)
        if pad_len > 0:
            mask[curr_len:] = 0

        input_ids_list.append(ids)
        labels_list.append(lbls)
        attention_mask_list.append(mask)

    return {
        "input_ids": torch.stack(input_ids_list),
        "labels": torch.stack(labels_list),
        "attention_mask": torch.stack(attention_mask_list)
    }

In [69]:
from torch.utils.data import DataLoader

loader = DataLoader(
    dataset,
    batch_size=2,
    collate_fn=collate_fn,
    shuffle=False
)

batch = next(iter(loader))

print("input_ids:", batch["input_ids"].shape)
print("labels   :", batch["labels"].shape)
print("attn_mask:", batch["attention_mask"].shape)

print("\nPAD locations:")
print(batch["attention_mask"] == 0)

input_ids: torch.Size([2, 512])
labels   : torch.Size([2, 512])
attn_mask: torch.Size([2, 512])

PAD locations:
tensor([[False, False, False,  ...,  True,  True,  True],
        [False, False, False,  ...,  True,  True,  True]])


In [75]:
device = "cuda"
model = model.cuda() 

In [77]:
import torch.nn.functional as F

lora_params = [p for p in model.parameters() if p.requires_grad]

optimizer = torch.optim.AdamW(lora_params,lr=2e-4, weight_decay=0.0)

batch = next(iter(loader))
model.train()
optimizer.zero_grad()

logits = model(
    batch["input_ids"].cuda(),
    attention_mask=batch["attention_mask"].cuda()
)

loss = F.cross_entropy(
    logits.view(-1, logits.size(-1)),
    batch["labels"].cuda().view(-1),
    ignore_index=-100
)

loss.backward()
optimizer.step()

print("Training step done. Loss:", loss.item())

Training step done. Loss: 7.265008926391602


In [81]:
from tqdm import tqdm

for epoch in range(CONFIG["train"]["epochs"]):
    model.train()
    total_loss = 0

    # tqdm with dataloader
    loop = tqdm(loader, desc=f"Epoch {epoch+1}/{CONFIG['train']['epochs']}")
    
    for batch in loop:
        optimizer.zero_grad()
        input_ids = batch["input_ids"].cuda()
        attention_mask = batch["attention_mask"].cuda()
        labels = batch["labels"].cuda()

        logits = model(input_ids, attention_mask=attention_mask)
        loss = F.cross_entropy(
            logits.view(-1, logits.size(-1)),
            labels.view(-1),
            ignore_index=-100
        )
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        # Update tqdm postfix to show current loss
        loop.set_postfix(loss=loss.item())
    
    avg_loss = total_loss / len(loader)
    print(f"Epoch {epoch+1} finished. Average Loss: {avg_loss:.4f}")

Epoch 1/2: 100%|█████████████████████████████████████████████████| 18177/18177 [1:25:41<00:00,  3.54it/s, loss=5.04e-5]


Epoch 1 finished. Average Loss: 0.0382


Epoch 2/2: 100%|█████████████████████████████████████████████████| 18177/18177 [1:30:42<00:00,  3.34it/s, loss=2.23e-6]


Epoch 2 finished. Average Loss: 0.0098


In [82]:
import os
import torch

save_dir = "lora_checkpoints"
os.makedirs(save_dir, exist_ok=True)

torch.save(
    {
        "lora_state_dict": {
            k: v.cpu()
            for k, v in model.state_dict().items()
            if "lora" in k.lower()
        },
        "config": CONFIG
    },
    os.path.join(save_dir, "lora_adapter.pt")
)

print("✅ LoRA adapter saved")

✅ LoRA adapter saved


In [83]:
ckpt = torch.load("lora_checkpoints/lora_adapter.pt")
print(len(ckpt["lora_state_dict"]))

72


In [109]:
from decoder_only_gpt import My_GPT_model
def build_model(config):
    model = My_GPT_model(
        vocab_size=config["model"]["vocab_size"],
        d_model=config["model"]["d_model"],
        num_layers=config["model"]["n_layer"],
        num_heads=config["model"]["n_head"],
        d_ff=config["model"]["d_ff"],
        seq_len=config["model"]["seq_len"],
        dropout=config["model"]["dropout"]
    )
    return model

ImportError: cannot import name 'My_GPT_model' from 'decoder_only_gpt' (D:\deep learning\research_paper\Transformers\sft\decoder_only_gpt.py)

In [108]:
base_model = build_model(CONFIG)

NameError: name 'My_GPT_model' is not defined

In [92]:
ckpt = torch.load("checkpoints_HindiGPT-v1_step280000.pt", map_location="cpu")

In [93]:
state_dict = ckpt["model"]

In [95]:
ckpt = torch.load("checkpoints_HindiGPT-v1_step280000.pt", map_location="cpu")
print(ckpt.keys())

dict_keys(['model', 'optimizer', 'scaler', 'step', 'config'])


In [96]:
state_dict = ckpt["model"]
print(list(state_dict.keys())[:20])

['_orig_mod.decoder.causal_mask', '_orig_mod.decoder.embedding.weight', '_orig_mod.decoder.layers.0.swi_glu.w1.weight', '_orig_mod.decoder.layers.0.swi_glu.w2.weight', '_orig_mod.decoder.layers.0.swi_glu.w3.weight', '_orig_mod.decoder.layers.0.masked_mha.Q.weight', '_orig_mod.decoder.layers.0.masked_mha.Q.bias', '_orig_mod.decoder.layers.0.masked_mha.K.weight', '_orig_mod.decoder.layers.0.masked_mha.K.bias', '_orig_mod.decoder.layers.0.masked_mha.V.weight', '_orig_mod.decoder.layers.0.masked_mha.V.bias', '_orig_mod.decoder.layers.0.masked_mha.fc_out.weight', '_orig_mod.decoder.layers.0.masked_mha.fc_out.bias', '_orig_mod.decoder.layers.0.rms_norm0.weight', '_orig_mod.decoder.layers.0.rms_norm1.weight', '_orig_mod.decoder.layers.1.swi_glu.w1.weight', '_orig_mod.decoder.layers.1.swi_glu.w2.weight', '_orig_mod.decoder.layers.1.swi_glu.w3.weight', '_orig_mod.decoder.layers.1.masked_mha.Q.weight', '_orig_mod.decoder.layers.1.masked_mha.Q.bias']


In [97]:
model_keys = set(base_model.state_dict().keys())
ckpt_keys  = set(state_dict.keys())

print("Missing in checkpoint:", model_keys - ckpt_keys)
print("Extra in checkpoint:", ckpt_keys - model_keys)

Missing in checkpoint: {'decoder.layers.2.masked_mha.fc_out.weight', 'decoder.layers.4.swi_glu.w1.weight', 'decoder.layers.9.masked_mha.fc_out.weight', 'decoder.layers.5.rms_norm1.weight', 'decoder.layers.8.masked_mha.q_proj.lora_B.weight', 'decoder.layers.11.masked_mha.v_proj.lora_B.weight', 'decoder.layers.10.swi_glu.w1.weight', 'decoder.layers.6.masked_mha.v_proj.weight', 'decoder.layers.8.masked_mha.k_proj.bias', 'decoder.layers.10.masked_mha.k_base.bias', 'decoder.layers.0.masked_mha.v_base.weight', 'decoder.layers.3.masked_mha.v_proj.weight', 'decoder.layers.9.masked_mha.q_proj.bias', 'decoder.layers.4.masked_mha.v_base.bias', 'decoder.layers.11.rms_norm0.weight', 'decoder.layers.2.masked_mha.v_base.bias', 'decoder.layers.2.masked_mha.q_base.bias', 'decoder.layers.7.masked_mha.v_proj.lora_A.weight', 'decoder.layers.8.masked_mha.k_base.weight', 'decoder.layers.0.masked_mha.k_proj.bias', 'decoder.layers.1.swi_glu.w2.weight', 'decoder.layers.9.masked_mha.k_proj.bias', 'decoder.layer

In [98]:
raw_state = ckpt["model"]

clean_state = {}
for k, v in raw_state.items():
    if k.startswith("_orig_mod."):
        k = k.replace("_orig_mod.", "")
    clean_state[k] = v

In [99]:
state_dict = ckpt["model"]
print(list(clean_state.keys())[:20])

['decoder.causal_mask', 'decoder.embedding.weight', 'decoder.layers.0.swi_glu.w1.weight', 'decoder.layers.0.swi_glu.w2.weight', 'decoder.layers.0.swi_glu.w3.weight', 'decoder.layers.0.masked_mha.Q.weight', 'decoder.layers.0.masked_mha.Q.bias', 'decoder.layers.0.masked_mha.K.weight', 'decoder.layers.0.masked_mha.K.bias', 'decoder.layers.0.masked_mha.V.weight', 'decoder.layers.0.masked_mha.V.bias', 'decoder.layers.0.masked_mha.fc_out.weight', 'decoder.layers.0.masked_mha.fc_out.bias', 'decoder.layers.0.rms_norm0.weight', 'decoder.layers.0.rms_norm1.weight', 'decoder.layers.1.swi_glu.w1.weight', 'decoder.layers.1.swi_glu.w2.weight', 'decoder.layers.1.swi_glu.w3.weight', 'decoder.layers.1.masked_mha.Q.weight', 'decoder.layers.1.masked_mha.Q.bias']


In [106]:
from decoder_only_gpt import My_GPT_model
base_model = My_GPT_model(vocab_size=CONFIG["model"]["vocab_size"],
        d_model=CONFIG["model"]["d_model"],
        num_layers=CONFIG["model"]["n_layer"],
        num_heads=CONFIG["model"]["n_head"],
        d_ff=CONFIG["model"]["d_ff"],
        seq_len=CONFIG["model"]["seq_len"],
        dropout=CONFIG["model"]["dropout"])
base_model.load_state_dict(clean_state, strict=True)

ImportError: cannot import name 'My_GPT_model' from 'decoder_only_gpt' (D:\deep learning\research_paper\Transformers\sft\decoder_only_gpt.py)