In [24]:
import yaml
import tiktoken
import torch
from torch import nn

from processing_data.dataset import Data,SpamDataset
from processing_data.dataloader import get_data_loader
from embeddings import Embeddings
from transformer_block import TransformerBlock
from gpt2 import GPT2Model
from utils import text_to_tokens,tokens_to_text,generate_text
from loss import cross_entropy,classification_loss
from train import traininng_loop
from evaluation import eval

with open("config.yaml","r") as f:
    config = yaml.safe_load(f)

In [16]:
torch.set_printoptions(sci_mode=False,precision=10)

In [21]:
with open("raw_data/the-verdict.txt","r") as f:
    raw_text = f.read()

len(text_to_tokens(raw_text)[0])

5145

In [22]:
train_dateset = SpamDataset(
    csv_path='raw_data/sms_spam_collection/train.csv',
    tokenizer=tiktoken.get_encoding("gpt2"),
    max_len=None
)
val_dataset = SpamDataset(
    csv_path='raw_data/sms_spam_collection/val.csv',
    tokenizer=tiktoken.get_encoding("gpt2"),
    max_len=train_dateset.max_len
)

test_dataset = SpamDataset(
    csv_path='raw_data/sms_spam_collection/test.csv',
    tokenizer=tiktoken.get_encoding("gpt2"),
    max_len=train_dateset.max_len
)

train_dl = get_data_loader(train_dateset,batch_size=32,shuffle=False,drop_last=True,num_workers=0)
val_dl = get_data_loader(val_dataset,batch_size=32,shuffle=False,drop_last=True,num_workers=0)
test_dl = get_data_loader(test_dataset,batch_size=32,shuffle=False,drop_last=True,num_workers=0)

In [5]:
train_dateset.max_len

118

In [13]:
for x,y in train_dl:
    print(x.shape)
    print('-'*100)
    print(y.shape)

torch.Size([32, 118])
----------------------------------------------------------------------------------------------------
torch.Size([32])
torch.Size([32, 118])
----------------------------------------------------------------------------------------------------
torch.Size([32])
torch.Size([32, 118])
----------------------------------------------------------------------------------------------------
torch.Size([32])
torch.Size([32, 118])
----------------------------------------------------------------------------------------------------
torch.Size([32])
torch.Size([32, 118])
----------------------------------------------------------------------------------------------------
torch.Size([32])
torch.Size([32, 118])
----------------------------------------------------------------------------------------------------
torch.Size([32])
torch.Size([32, 118])
----------------------------------------------------------------------------------------------------
torch.Size([32])
torch.Size([32, 118]

In [11]:
len(train_dl)

32

In [4]:
train_ratio = 0.9
split_index = int(len(raw_text) * train_ratio)
train_text = raw_text[:split_index]
val_text = raw_text[split_index:]


# Dataset & DataLoader 

In [5]:
train_dataset = Data(
    raw_text=train_text,
    tokenizer=tiktoken.get_encoding("gpt2"),
    context_length=config["context_window"],
    stride=config["stride"]
)

val_dataset = Data(
    raw_text=val_text,
    tokenizer=tiktoken.get_encoding("gpt2"),
    context_length=config["context_window"],
    stride=config["stride"]
)

train_dl = get_data_loader(
    train_dataset,
    batch_size=config["batch_size"],
    shuffle=config["shuffle"],
    drop_last=config["drop_last"],
    num_workers=config["num_workers"]
    )

val_dl = get_data_loader(
    val_dataset,
    batch_size=config["batch_size"],
    shuffle=config["shuffle"],
    drop_last=config["drop_last"],
    num_workers=config["num_workers"]
)


In [7]:
# for x,y in train_dl:
#     print(x.shape)
#     print(y.shape)
#     break

In [8]:
# train_tokens = 0 
# for x,y in train_dl:
#     train_tokens += x.numel()
# print(f"Train tokens: {train_tokens}")

# val_tokens = 0
# for x,y in val_dl:
#     val_tokens += x.numel()
# print(f"Val tokens: {val_tokens}")


# print(f'total tokens: {train_tokens + val_tokens}')

In [6]:
config['num_classes'] = 2

model = GPT2Model(config)

# with torch.no_grad():
#     logits = model(x)

#     print(logits.shape)


In [8]:
classification_loss(logits,y)

tensor(0.6434470415)

In [7]:
optimizer = torch.optim.AdamW(model.parameters(),lr=0.0004)

In [8]:
traininng_loop(
    model,
    train_dl,
    val_dl,
    loss_fn = classification_loss,
    optimizer = optimizer,
    num_epochs = 10,
    device = "cpu",
    # text_to_generate = "Every single step",
    look_back = config["context_window"],
    num_tokens_to_generate = config["num_tokens_to_generate"],
)

2025-05-02 18:34:45,320 - INFO - Epoch 1/10


check point 1
check point 2
check point 3
check point 4
check point 5
check point 6
check point 7
check point 8
check point 9
check point 1
check point 2
check point 3
check point 4
check point 5
check point 6
check point 7
check point 8
check point 9
check point 1
check point 2
check point 3
check point 4
check point 5
check point 6
check point 7
check point 8
check point 9
check point 1
check point 2
check point 3
check point 4
check point 5
check point 6
check point 7
check point 8
check point 9
check point 1
check point 2
check point 3
check point 4
check point 5
check point 6
check point 7
check point 8
check point 9
check point 1
check point 2
check point 3
check point 4
check point 5
check point 6
check point 7
check point 8
check point 9
check point 1
check point 2
check point 3
check point 4
check point 5
check point 6
check point 7
check point 8
check point 9
check point 1
check point 2
check point 3
check point 4
check point 5
check point 6
check point 7
check point 8
check 

2025-05-02 18:37:02,559 - INFO - Seen tokens: 120832
2025-05-02 18:37:02,560 - INFO - Loss: 1.3120


check point 7
check point 8
check point 9
check point 10


2025-05-02 18:37:06,698 - INFO - Validation Loss: 0.7209
2025-05-02 18:37:06,699 - INFO - Epoch 2/10


check point 1
check point 2
check point 3
check point 4
check point 5
check point 6
check point 7
check point 8
check point 9
check point 1
check point 2
check point 3
check point 4
check point 5
check point 6


KeyboardInterrupt: 

In [12]:
eval(
    model,
    val_loader=val_dl,
    loss_fn= cross_entropy,
    device='cpu'
)

2025-04-29 19:14:24,002 - INFO - Validation Loss: 6.3376
