## Main Method for Supervised Learning
Trained with backpropagation through time (BPTT). The intermediate tokens are generated embeddings.

In [None]:
import torch
import torch.optim
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split

import matplotlib.pyplot as plt
from thoughtsformer import ThoughtsFormer
from tiny_shakespeare import TinyShakespeareDataset
from transformers import GPT2Tokenizer, GPT2Model, GPT2LMHeadModel

In [None]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# NOTE(review): the meaning of the two positional arguments (512, 64) is
# defined by TinyShakespeareDataset, which is not visible here — confirm.
dataset = TinyShakespeareDataset(512, 64)

# Hold out 20% of the examples for testing; the remainder is used for training.
num_examples = len(dataset)
train_size = int(0.8 * num_examples)
test_size = num_examples - train_size

# random_split draws a random permutation of indices, so the partition
# differs between runs unless the global RNG is seeded beforehand.
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

# Only the training batches are reshuffled each epoch.
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)

In [None]:
# Build the model from pretrained GPT-2 weights in supervised mode
# (reinforcement_learning=False).
# NOTE(review): the positional argument 1 is presumably the number of
# intermediate thought steps — confirm against ThoughtsFormer.from_pretrained_GPT2.
thoughtsformer = ThoughtsFormer.from_pretrained_GPT2(
    1,
    reinforcement_learning=False,
)
thoughtsformer = thoughtsformer.to(device)
thoughtsformer.train()

In [None]:
# Reference constants and hyperparameters for the supervised run.
vocab_size = 50257  # GPT-2 tokenizer vocabulary size
d_embed = 768  # Embedding dimension
epochs = 30
sequence_length = 256  # Placeholder; overwritten with each batch's actual length below

# thoughtsformer.load_state_dict(torch.load(r'/content/drive/MyDrive/Machine Learning/gpt2_starting_thoughtsformer.pth'))

loss_fn = F.cross_entropy
thought_optim = torch.optim.Adam(params=thoughtsformer.parameters(), lr=0.0003)


# Each loader batch is a (tokens, labels) pair; tokens must be 2-D
# (batch_size, sequence_length) because its shape is unpacked into two values.
loss_over_time = []  # training loss of every optimisation step
test_loss_over_time = []  # token-weighted mean test loss, one entry per epoch

for epoch in range(epochs):
    thoughtsformer.train()
    for idx, (tokens, labels) in enumerate(train_loader):
        batch_size, sequence_length = tokens.shape

        # All-zero padding mask (no padding here, kept for future flexibility);
        # additional padding is done internally by the model. Allocated directly
        # on the target device to avoid a CPU allocation plus copy per batch.
        padding_mask = torch.zeros(batch_size, sequence_length, device=device)

        tokens = tokens.to(device)
        labels = labels.to(device)

        # Forward pass through the model
        thoughts_logits = thoughtsformer(tokens, padding_mask)
        # cross_entropy expects the class dimension at index 1, so the last
        # logits axis is moved there before comparing against the labels.
        thoughts_loss = loss_fn(thoughts_logits.permute(0, 2, 1), labels)

        thought_optim.zero_grad()
        thoughts_loss.backward()
        thought_optim.step()

        # Read the scalar once: every .item() call forces a device-to-host sync.
        train_loss_value = thoughts_loss.item()
        loss_over_time.append(train_loss_value)
        print(f"Thoughtsformer Train Loss at batch {idx}, epoch {epoch}: {train_loss_value}")

    # Validate the model on the test set after each epoch
    thoughtsformer.eval()  # Set model to evaluation mode
    test_loss = 0.0  # sum of per-batch mean losses, each weighted by its target count
    test_targets = 0  # total number of label positions seen this epoch
    with torch.no_grad():  # Disable gradient calculation
        for tokens, labels in test_loader:
            batch_size, sequence_length = tokens.shape

            # Same all-zero padding mask as in training, created on the device.
            padding_mask = torch.zeros(batch_size, sequence_length, device=device)

            # Forward pass through the model
            thoughts_logits = thoughtsformer(tokens.to(device), padding_mask)

            loss = loss_fn(thoughts_logits.permute(0, 2, 1), labels.to(device))

            # Weight each batch mean by its number of targets so a smaller
            # final batch does not skew the epoch average.
            num_targets = labels.numel()
            test_loss += loss.item() * num_targets
            test_targets += num_targets

    # An empty test loader yields NaN instead of raising ZeroDivisionError.
    avg_test_loss = test_loss / test_targets if test_targets else float("nan")
    test_loss_over_time.append(avg_test_loss)
    print(f"Test Loss after epoch {epoch}: {avg_test_loss}")
