In [1]:
import os 

import yaml
import tiktoken
import torch
from torch import nn
import wandb

from processing_data.dataset import Data,ClassificationDataset
from processing_data.dataloader import get_data_loader
from embeddings import Embeddings
from transformer_block import TransformerBlock
from gpt2 import GPT2Model
from utils import text_to_tokens,tokens_to_text
from loss import cross_entropy,classification_loss
from train import Trainer

# from dotenv import load_dotenv


with open("config.yaml","r") as f:
    config = yaml.safe_load(f)

with open("generate_text_config.yaml","r") as f:
    generate_text_config = yaml.safe_load(f)


In [2]:
# API Keys 
# print(load_dotenv()) 
# os.environ["WANDB_API_KEY"] = os.getenv("WANDB_API_KEY")

True


In [2]:
# turn off scientific notation
# torch.set_printoptions(sci_mode=False,precision=10) 

# read the-verdict.txt
with open("raw_data/the-verdict.txt","r") as f: 
    raw_text = f.read()

In [3]:
train_dateset = ClassificationDataset(
    csv_path='raw_data/sms_spam_collection/train.csv',
    tokenizer=tiktoken.get_encoding("gpt2"),
    max_len=None
)
val_dataset = ClassificationDataset(
    csv_path='raw_data/sms_spam_collection/val.csv',
    tokenizer=tiktoken.get_encoding("gpt2"),
    max_len=train_dateset.max_len
)

test_dataset = ClassificationDataset(
    csv_path='raw_data/sms_spam_collection/test.csv',
    tokenizer=tiktoken.get_encoding("gpt2"),
    max_len=train_dateset.max_len
)

train_dl = get_data_loader(train_dateset,batch_size=32,shuffle=False,drop_last=True,num_workers=0)
val_dl = get_data_loader(val_dataset,batch_size=32,shuffle=False,drop_last=True,num_workers=0)
test_dl = get_data_loader(test_dataset,batch_size=32,shuffle=False,drop_last=True,num_workers=0)

In [4]:
train_dateset.max_len

118

In [13]:
for x,y in train_dl:
    print(x.shape)
    print('-'*100)
    print(y.shape)

torch.Size([32, 118])
----------------------------------------------------------------------------------------------------
torch.Size([32])
torch.Size([32, 118])
----------------------------------------------------------------------------------------------------
torch.Size([32])
torch.Size([32, 118])
----------------------------------------------------------------------------------------------------
torch.Size([32])
torch.Size([32, 118])
----------------------------------------------------------------------------------------------------
torch.Size([32])
torch.Size([32, 118])
----------------------------------------------------------------------------------------------------
torch.Size([32])
torch.Size([32, 118])
----------------------------------------------------------------------------------------------------
torch.Size([32])
torch.Size([32, 118])
----------------------------------------------------------------------------------------------------
torch.Size([32])
torch.Size([32, 118]

In [11]:
len(train_dl)

32

# Dataset & DataLoader 

In [3]:
train_ratio = 0.9
split_index = int(len(raw_text) * train_ratio)
train_text = raw_text[:split_index]
val_text = raw_text[split_index:]


In [4]:
train_dataset = Data(
    raw_text=train_text,
    tokenizer=tiktoken.get_encoding("gpt2"),
    context_length=config["context_window"],
    stride=config["stride"]
)

val_dataset = Data(
    raw_text=val_text,
    tokenizer=tiktoken.get_encoding("gpt2"),
    context_length=config["context_window"],
    stride=config["stride"]
)

train_dl = get_data_loader(
    train_dataset,
    batch_size=config["batch_size"],
    shuffle=config["shuffle"],
    drop_last=config["drop_last"],
    num_workers=config["num_workers"]
    )

val_dl = get_data_loader(
    val_dataset,
    batch_size=config["batch_size"],
    shuffle=config["shuffle"],
    drop_last=config["drop_last"],
    num_workers=config["num_workers"]
)


In [6]:
# for x,y in train_dl:
#     print(x.shape)
    # print(y.shape)


In [8]:
# train_tokens = 0 
# for x,y in train_dl:
#     train_tokens += x.numel()
# print(f"Train tokens: {train_tokens}")

# val_tokens = 0
# for x,y in val_dl:
#     val_tokens += x.numel()
# print(f"Val tokens: {val_tokens}")


# print(f'total tokens: {train_tokens + val_tokens}')

In [5]:

model = GPT2Model(config)
optimizer = torch.optim.AdamW(model.parameters(),lr=0.0004)

# with torch.no_grad():
#     logits = model(x)

#     print(logits.shape)


# Wandb

In [6]:
wandb.init(
    project="Foundation_models",
    name="accuracy round 2",
    config=config
)

[34m[1mwandb[0m: Currently logged in as: [33mhawardizayee[0m ([33mhawardizayee-unitedhealthcare[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [7]:
trainer = Trainer(
    model,
    train_dl,
    val_dl,
    loss_fn=cross_entropy,
    optimizer=optimizer,
    config=config,
    device="cpu",
    generate_text_config=generate_text_config
)

In [8]:
trainer.train(epochs=10,generate_text=True)

2025-05-12 19:18:50,108 - INFO - Epoch 1/10
2025-05-12 19:19:07,018 - INFO - Train loss: 9.211953587002224, Val loss: 7.687678813934326, Train acc: 0.03624131944444445, Val acc: 0.044921875
2025-05-12 19:19:07,838 - INFO - Generated text: Every single step. ,,-- ., the the-- ,, the, , the.
2025-05-12 19:19:07,838 - INFO - Epoch 2/10
2025-05-12 19:19:24,392 - INFO - Train loss: 6.710766315460205, Val loss: 6.732245445251465, Train acc: 0.049262152777777776, Val acc: 0.0234375
2025-05-12 19:19:25,067 - INFO - Generated text: Every single step, "      ,, the      . 
2025-05-12 19:19:25,067 - INFO - Epoch 3/10
2025-05-12 19:19:41,198 - INFO - Train loss: 6.3690649138556585, Val loss: 6.638359069824219, Train acc: 0.058159722222222224, Val acc: 0.052734375
2025-05-12 19:19:41,908 - INFO - Generated text: Every single step a of the----, the--, of the, the of the of, the,,
2025-05-12 19:19:41,908 - INFO - Epoch 4/10
2025-05-12 19:19:56,984 - INFO - Train loss: 5.825989829169379, Val loss: 6.6

([9.211953587002224,
  6.710766315460205,
  6.3690649138556585,
  5.825989829169379,
  5.760788334740533,
  5.573976887596978,
  5.3964655134412975,
  5.157639079623753,
  4.637348863813612,
  4.0321828789181176],
 [7.687678813934326,
  6.732245445251465,
  6.638359069824219,
  6.647461891174316,
  6.6304497718811035,
  6.514750003814697,
  6.544022560119629,
  6.402286529541016,
  6.2909698486328125,
  6.17617654800415],
 [0.03624131944444445,
  0.049262152777777776,
  0.058159722222222224,
  0.09722222222222222,
  0.09331597222222222,
  0.1115451388888889,
  0.1306423611111111,
  0.1421440972222222,
  0.17708333333333334,
  0.2419704861111111],
 [0.044921875,
  0.0234375,
  0.052734375,
  0.068359375,
  0.072265625,
  0.068359375,
  0.08203125,
  0.087890625,
  0.1015625,
  0.1171875])

In [9]:
wandb.finish()

[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


0,1
seen tokens,▁▂▃▃▄▅▆▆▇█
train acc,▁▁▂▃▃▄▄▅▆█
train loss,█▅▄▃▃▃▃▃▂▁
val acc,▃▁▃▄▅▄▅▆▇█
val loss,█▄▃▃▃▃▃▂▂▁

0,1
seen tokens,46080.0
train acc,0.242
train loss,4.0322
val acc,0.1172
val loss,6.1762
