In [1]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from thoughtsformer import ThoughtsFormer

In [2]:
from tiny_shakespeare import TinyShakespeareDataset, TokenizerType
from character_tokenizer import ShakespeareCharacterTokenizer
# Build the train and test splits with identical settings; only `split` differs.
# NOTE(review): the positional arguments (128, 64) are passed through unchanged —
# confirm their meaning against TinyShakespeareDataset's signature.
dataset_tr, dataset_te = (
    TinyShakespeareDataset(128, 64, split=split_name, tokenizer=TokenizerType.CHARACTER_LEVEL)
    for split_name in ("train", "test")
)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Run on the GPU when one is present; otherwise fall back to the CPU so the
# notebook still executes on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters forwarded to every ThoughtsFormer built by
# return_model_with_n_thoughts.
max_sequence_length = 128
vocab_size = 65
num_layers = 1
n_head = 6
d_embed = 384
dim_feedforward = d_embed  # feed-forward width is tied to the embedding width
dropout = 0

def return_model_with_n_thoughts(n: int):
  """Build a ThoughtsFormer with ``max_thought_len=n`` and move it to ``device``.

  Every other constructor argument is read from the module-level
  hyperparameters (``max_sequence_length``, ``vocab_size``, ``num_layers``,
  ``n_head``, ``d_embed``, ``dim_feedforward``, ``dropout``) at call time.

  Args:
      n: Value passed to the constructor as ``max_thought_len``.

  Returns:
      The result of ``ThoughtsFormer(...).to(device)``.
  """
  shared_hyperparameters = dict(
      max_sequence_length=max_sequence_length,
      vocab_size=vocab_size,
      num_layers=num_layers,
      n_head=n_head,
      d_embed=d_embed,
      dim_feedforward=dim_feedforward,
      dropout=dropout,
  )
  model = ThoughtsFormer(max_thought_len=n, **shared_hyperparameters)
  return model.to(device)

# Four models that differ only in max_thought_len: 0, 3, 4 and 15 (in that order).
m0, m1, m2, m3 = (
    return_model_with_n_thoughts(n_thoughts) for n_thoughts in (0, 3, 4, 15)
)



In [4]:
# Extra dataset built with the default split/tokenizer arguments.
# NOTE(review): `dataset` is not referenced by any other cell shown here —
# confirm whether it is still needed.
dataset = TinyShakespeareDataset(512, 64)

# Number of examples in each split.
train_size, test_size = len(dataset_tr), len(dataset_te)

# Shuffle only the training batches; the test loader keeps a fixed order.
train_loader = DataLoader(dataset_tr, shuffle=True, batch_size=8)
test_loader = DataLoader(dataset_te, shuffle=False, batch_size=8)

Token indices sequence length is longer than the specified maximum sequence length for this model (301966 > 1024). Running this sequence through the model will result in indexing errors


In [None]:
epochs = 10

# Models trained side by side on the same batches; each gets its own Adam optimizer.
models = [m2]
loss_fn = F.cross_entropy
optims = [torch.optim.Adam(params=m.parameters(), lr=0.001) for m in models]

# Each batch from the loaders unpacks as (tokens, labels); tokens is 2-D and is
# unpacked below as [batch_size, sequence_length].
# Per-model histories: one training loss per batch, one averaged test loss per epoch.
losses_over_time = [[] for _ in models]
test_losses_over_time = [[] for _ in models]

for epoch in range(epochs):
    for model in models:
        model.train()

    for idx, (tokens, labels) in enumerate(train_loader):
        batch_size, sequence_length = tokens.shape

        # All-zero padding mask (no padding here, but kept for future flexibility);
        # additional padding is done internally by the model.
        padding_mask = torch.zeros(batch_size, sequence_length, device=device)

        # Move the batch to the device once instead of once per model.
        tokens = tokens.to(device)
        labels = labels.to(device)

        for model, optim, loss_history in zip(models, optims, losses_over_time):
            # Forward pass through the model
            logits = model(tokens, padding_mask)
            # cross_entropy wants the class dimension second, so swap the last two
            # axes (assumes logits are [batch, sequence, vocab] — TODO confirm).
            loss = loss_fn(logits.permute(0, 2, 1), labels)

            optim.zero_grad()
            loss.backward()
            optim.step()

            loss_value = loss.item()
            loss_history.append(loss_value)
            print(f"Thoughtsformer Train Loss at batch {idx}, epoch {epoch}: {loss_value}")

    # Validate every model on the test set after each epoch
    for model in models:
        model.eval()  # Set model to evaluation mode

    test_losses = [0 for _ in models]
    with torch.no_grad():  # Disable gradient calculation
        for idx, (tokens, labels) in enumerate(test_loader):
            batch_size, sequence_length = tokens.shape

            # Same all-zero padding mask as in training.
            padding_mask = torch.zeros(batch_size, sequence_length, device=device)
            tokens = tokens.to(device)
            labels = labels.to(device)

            for i, model in enumerate(models):
                logits = model(tokens, padding_mask)
                test_losses[i] += loss_fn(logits.permute(0, 2, 1), labels).item()

    # Mean of the per-batch losses over the whole test loader.
    avg_test_losses = [test_loss / len(test_loader) for test_loss in test_losses]

    # Record each model's own scalar average (not the whole list) in its history.
    for avg_test_loss, test_loss_history in zip(avg_test_losses, test_losses_over_time):
        test_loss_history.append(avg_test_loss)
    for avg_test_loss in avg_test_losses:
        print(f"Test Loss after epoch {epoch}: {avg_test_loss}")




In [5]:
def get_backward_memory_in_gb(model: torch.nn.Module) -> float:
    """
    Report the memory held by the model's populated ``.grad`` tensors.

    Call this after loss.backward(). Parameters whose ``.grad`` is ``None``
    contribute nothing, so a model that has not been through a backward pass
    reports 0.0. Only the gradient tensors themselves are counted; anything
    else allocated during the backward pass is not measured here.

    Args:
        model: PyTorch model after backward pass

    Returns:
        Memory usage in GB (total gradient bytes / 1024**3)
    """
    # element_size() is the per-element byte width of the gradient's actual
    # dtype, so non-float32 gradients are sized correctly as well.
    total_bytes = sum(
        param.grad.nelement() * param.grad.element_size()
        for param in model.parameters()
        if param.grad is not None
    )

    return total_bytes / (1024**3)  # Convert bytes to GB

  test = test.load_state_dict(torch.load("saves/possibly_flawed_transformer.pth"))
