In [56]:
import json 
import yaml
from pygments import highlight, lexers, formatters
from typing import List,Dict,Tuple,Any

import torch 
from torch.utils.data import Dataset,DataLoader
import tiktoken

from processing_data.dataset import InstructionDataset,format_input
from processing_data.dataloader import get_data_loader,instruction_collate_fn
from utils import tokens_to_text


from gpt2 import GPT2Model
from loss import cross_entropy
from train import traininng_loop

tokenizer = tiktoken.get_encoding("gpt2")

with open('config.yaml','r') as f:
    config = yaml.safe_load(f)


In [34]:
with open('raw_data/instruction-examples.json','r') as f:
    data = json.load(f)


In [35]:
formatted_json = json.dumps(data[1], indent=4)
colorful_json = highlight(formatted_json,
                          lexers.JsonLexer(),
                          formatters.TerminalFormatter())

print(colorful_json)

{[37m[39;49;00m
[37m    [39;49;00m[94m"instruction"[39;49;00m:[37m [39;49;00m[33m"What is the plural form of \"goose\"?"[39;49;00m,[37m[39;49;00m
[37m    [39;49;00m[94m"input"[39;49;00m:[37m [39;49;00m[33m""[39;49;00m,[37m[39;49;00m
[37m    [39;49;00m[94m"output"[39;49;00m:[37m [39;49;00m[33m"The plural form of \"goose\" is \"geese.\""[39;49;00m[37m[39;49;00m
}[37m[39;49;00m



In [36]:
print(format_input(data[1]))

Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
What is the plural form of "goose"?


In [37]:
train_index = int(len(data) * 0.8)
val_index = int(len(data) * 0.1)

train_data = data[:train_index]
val_data = data[train_index: train_index + val_index]
test_data = data[train_index + val_index:]

print(f"Train size: {len(train_data)}")
print(f"Validation size: {len(val_data)}")
print(f"Test size: {len(test_data)}")


Train size: 160
Validation size: 20
Test size: 20


In [38]:
train_ds = InstructionDataset(train_data,tokenizer)

In [39]:
len(train_ds[1])

54

In [40]:
print(tokens_to_text(train_ds[0]))

Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
Identify the verb in the following sentence: The cat sleeps on the couch.

### Responsive:
The verb in the sentence is "sleeps."


In [51]:
train_dl = get_data_loader(
    train_ds,
    batch_size=2,
    shuffle=False,
    drop_last=True,
    num_workers=0,
    collate_fn=instruction_collate_fn
    )

val_ds = InstructionDataset(val_data,tokenizer)
val_dl = get_data_loader(
    val_ds,
    batch_size=2,
    shuffle=False,
    drop_last=True,
    num_workers=0,
    collate_fn=instruction_collate_fn
    )


In [42]:
for input,output in train_dl:
    print(f'input.shape: {input.shape} output.shape: {output.shape}')
    

input.shape: torch.Size([2, 57]) output.shape: torch.Size([2, 57])
input.shape: torch.Size([2, 57]) output.shape: torch.Size([2, 57])
input.shape: torch.Size([2, 61]) output.shape: torch.Size([2, 61])
input.shape: torch.Size([2, 60]) output.shape: torch.Size([2, 60])
input.shape: torch.Size([2, 61]) output.shape: torch.Size([2, 61])
input.shape: torch.Size([2, 57]) output.shape: torch.Size([2, 57])
input.shape: torch.Size([2, 59]) output.shape: torch.Size([2, 59])
input.shape: torch.Size([2, 63]) output.shape: torch.Size([2, 63])
input.shape: torch.Size([2, 60]) output.shape: torch.Size([2, 60])
input.shape: torch.Size([2, 64]) output.shape: torch.Size([2, 64])
input.shape: torch.Size([2, 75]) output.shape: torch.Size([2, 75])
input.shape: torch.Size([2, 60]) output.shape: torch.Size([2, 60])
input.shape: torch.Size([2, 66]) output.shape: torch.Size([2, 66])
input.shape: torch.Size([2, 58]) output.shape: torch.Size([2, 58])
input.shape: torch.Size([2, 55]) output.shape: torch.Size([2, 

In [52]:
print(len(train_dl))
print(len(val_dl))

80
10


In [43]:
# for input,output in train_dl:
#     print(input)
#     print(output)
#     break

In [44]:
model = GPT2Model(config)

with torch.no_grad():
    output = model(torch.randint(0,100,(2,10)))
    print(output)


tensor([[[-0.4223,  0.8664, -0.9820,  ..., -1.6785, -0.8263, -0.5373],
         [-0.9225,  0.3052, -0.6632,  ..., -1.4212,  0.5825,  1.2669],
         [ 0.1208,  0.0786, -1.1472,  ..., -0.0067, -0.7273,  1.0709],
         ...,
         [-0.5586, -0.2207, -0.0293,  ..., -0.7779, -0.1665, -0.2204],
         [-0.5827, -1.2379, -0.4539,  ..., -0.5088,  1.3019, -0.2421],
         [ 0.4074, -0.0570, -0.0421,  ...,  0.9512,  0.6207,  0.3603]],

        [[-0.1490,  0.5916, -0.9294,  ..., -0.5167, -0.7071, -0.1347],
         [-0.2227,  0.3213,  0.1298,  ..., -0.6169,  0.2928, -0.0327],
         [-0.5166,  0.6594, -0.5260,  ...,  0.5458, -0.3208,  1.5027],
         ...,
         [-1.0470, -0.2002,  0.0915,  ...,  0.0203, -0.2529, -0.3517],
         [-0.1029, -0.6388, -0.4151,  ..., -0.0661,  0.6080,  0.4242],
         [ 1.5079,  0.1866,  1.0460,  ..., -0.3795,  0.1673,  0.2957]]])


In [45]:
config

{'dropout': 0.1,
 'vocab_size': 50257,
 'embed_dim': 768,
 'stride': 256,
 'batch_size': 2,
 'shuffle': False,
 'drop_last': True,
 'num_workers': 0,
 'context_window': 256,
 'num_heads': 12,
 'Q_K_V_bias': False,
 'kv_bias': False,
 'batch_first': True,
 'device': None,
 'n_layers': 12,
 'num_classes': 50257,
 'num_tokens_to_generate': 20,
 'look_back': 256,
 'text_to_generate': 'Every single step',
 'learning_rate': 0.0004,
 'num_epochs': 10}

In [59]:
optimizer = torch.optim.AdamW(model.parameters(),lr=0.0001)

In [60]:
traininng_loop(
    model,
    train_dl,
    val_dl,
    cross_entropy,
    optimizer,
    num_epochs=10,
    device='cpu'
)


2025-05-05 20:45:37,652 - INFO - Epoch 1/10
2025-05-05 20:46:26,078 - INFO - Seen tokens: 10374
2025-05-05 20:46:26,079 - INFO - Loss: 2.3498
2025-05-05 20:46:26,692 - INFO - Validation Loss: 3.3537
2025-05-05 20:46:26,693 - INFO - Epoch 2/10
2025-05-05 20:47:10,955 - INFO - Seen tokens: 20748
2025-05-05 20:47:10,956 - INFO - Loss: 2.1597
2025-05-05 20:47:11,577 - INFO - Validation Loss: 3.3766
2025-05-05 20:47:11,578 - INFO - Epoch 3/10
2025-05-05 20:47:48,259 - INFO - Seen tokens: 31122
2025-05-05 20:47:48,261 - INFO - Loss: 2.0995
2025-05-05 20:47:48,909 - INFO - Validation Loss: 3.3921
2025-05-05 20:47:48,909 - INFO - Epoch 4/10
2025-05-05 20:48:36,046 - INFO - Seen tokens: 41496
2025-05-05 20:48:36,048 - INFO - Loss: 2.0630
2025-05-05 20:48:36,670 - INFO - Validation Loss: 3.4224
2025-05-05 20:48:36,671 - INFO - Epoch 5/10
2025-05-05 20:49:23,169 - INFO - Seen tokens: 51870
2025-05-05 20:49:23,171 - INFO - Loss: 2.0232
2025-05-05 20:49:23,808 - INFO - Validation Loss: 3.4492
2025-