In [19]:
import sys
sys.path.append('../src')

from training import train
from models import transformer
import datasets
import torch
import os
import re
import numpy as np
import plotly.express as px


def get_checkpoint_epochs(path):
    """Return the epoch numbers encoded in the entry names of directory ``path``.

    Every name returned by ``os.listdir(path)`` that contains ``epoch_<digits>``
    contributes ``int(<digits>)`` (only the first occurrence in the name is used).
    Names without that pattern are skipped. The result follows ``os.listdir``
    order and is neither sorted nor de-duplicated.
    """
    pattern = re.compile(r'epoch_(\d+)')
    matches = (pattern.search(name) for name in os.listdir(path))
    return [int(m.group(1)) for m in matches if m]

def get_model_and_optimizer(checkpoint_dir, epoch, map_location=None):
    """Rebuild a Transformer and its AdamW optimizer from a saved checkpoint.

    Args:
        checkpoint_dir: Directory containing checkpoint files named ``epoch_<epoch>``.
        epoch: Epoch number selecting which checkpoint file to load.
        map_location: Forwarded to ``torch.load``. ``None`` (the default, same as
            before) restores tensors onto the devices they were saved from; pass
            e.g. ``'cpu'`` to open a GPU-saved checkpoint on a machine without CUDA.

    Returns:
        ``(model, optimizer)``: a ``transformer.Transformer`` built from the stored
        config with its weights restored, and an ``AdamW`` optimizer over its
        parameters (lr / weight_decay / betas taken from ``model.config``) with
        the stored optimizer state loaded.
    """
    checkpoint_path = os.path.join(checkpoint_dir, f'epoch_{epoch}')
    # NOTE(review): torch.load unpickles the file, and the dict also carries the model
    # config (possibly an arbitrary Python object), so only load trusted checkpoints.
    # On PyTorch versions where `weights_only` defaults to True this may additionally
    # need `weights_only=False` -- confirm against the installed version.
    checkpoint_dict = torch.load(checkpoint_path, map_location=map_location)
    model_config = checkpoint_dict['config']
    model_state_dict = checkpoint_dict['model']
    optimizer_state_dict = checkpoint_dict['optimizer']

    model = transformer.Transformer(model_config)
    model.load_state_dict(model_state_dict)

    # Hyperparameters must match the saved optimizer's param groups before its state is loaded.
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=model.config.lr,
        weight_decay=model.config.weight_decay,
        betas=model.config.betas,
    )
    optimizer.load_state_dict(optimizer_state_dict)
    return model, optimizer

checkpoint_dir = '../src/checkpoints'
epochs = get_checkpoint_epochs(checkpoint_dir)
epochs.sort()
# Fail with an actionable message instead of a bare IndexError from `epochs[0]`.
if not epochs:
    raise FileNotFoundError(
        f'No checkpoint files matching "epoch_<n>" found in {checkpoint_dir!r}'
    )
# Start from the earliest saved checkpoint.
epoch = epochs[0]
model, optimizer = get_model_and_optimizer(checkpoint_dir, epoch)

# NOTE(review): magic number -- presumably the number of epochs after the earliest
# checkpoint at which the loss spike under study appears; re-check if checkpoints change.
number_of_epochs_until_spike = 24
# `datasets` here is the project package from ../src (sys.path is appended, so an installed
# package of the same name would take precedence -- confirm none is installed).
train_dataloader, test_dataloader = datasets.modular_addition.generate_test_train_split(model.config.prime, model.config.frac_train, model.config.seed)
# Presumably disables W&B logging and checkpoint writing for this exploratory run -- verify in train.train.
model.config.wandb = False
model.config.save_checkpoints = False
model.config.num_epochs = number_of_epochs_until_spike
train.train(model, train_dataloader, test_dataloader)



Using the following device: cuda!


100%|██████████| 24/24 [00:00<00:00, 58.03it/s]


In [20]:

# Compare two models on the training example with the largest loss after the spike:
#   "second" = the model trained `number_of_epochs_until_spike` epochs in the previous cell,
#   "first"  = the same starting checkpoint retrained for one epoch fewer.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


def _find_max_loss_example(net, dataloader, device):
    """Return ``(batch_losses, index, input, label)`` for the max-loss example in ``dataloader``.

    ``batch_losses`` is the per-example loss of the batch containing that example and
    ``index`` its position within that batch. ``input``/``label`` keep a leading batch
    dimension of 1. They are sliced in the SAME pass as the loss, so the result stays
    correct even if the dataloader reshuffles between iterations.
    NOTE(review): assumes train.get_full_loss returns a 1-D per-example loss -- confirm.
    """
    best = None
    best_value = float('-inf')
    with torch.no_grad():
        for batch in dataloader:
            inputs = batch['data'].to(device)
            labels = batch['label'].to(device)
            losses = train.get_full_loss(net(inputs), labels).detach()
            idx = torch.argmax(losses).item()
            if best is None or losses[idx].item() > best_value:
                best_value = losses[idx].item()
                best = (losses, idx, inputs[idx].unsqueeze(0), labels[idx].unsqueeze(0))
    if best is None:
        raise ValueError('dataloader yielded no batches')
    return best


def _last_token_probs(net, x):
    """Softmax over the final sequence position of ``net(x)`` for a single-example batch ``x``."""
    # NOTE(review): assumes net(x) is (1, seq_len, vocab) so squeeze() leaves (seq_len, vocab) -- confirm.
    with torch.no_grad():
        return torch.nn.functional.softmax(net(x).squeeze()[-1], dim=0)


# Evaluate in eval mode so dropout-style layers (if any) do not randomize predictions.
model.eval()
train_loss, data_point, corrupt_input, corrupt_label = _find_max_loss_example(model, train_dataloader, device)
corupt_label = corrupt_label  # alias for the original misspelled name, for later cells

second_probs = _last_token_probs(model, corrupt_input)
second_prediction = torch.argmax(second_probs).item()

# NOTE(review): this comparison assumes the rerun below reproduces the first
# `number_of_epochs_until_spike - 1` epochs of the previous cell's run (same data order,
# no unseeded randomness) -- consider seeding torch before both train.train calls.
# NOTE(review): the restored `optimizer` is never passed to train.train; check whether
# train.train builds its own optimizer (which would discard the checkpointed AdamW state).
model, optimizer = get_model_and_optimizer(checkpoint_dir, epoch)
model.config.wandb = False
model.config.save_checkpoints = False
model.config.num_epochs = number_of_epochs_until_spike - 1
train.train(model, train_dataloader, test_dataloader)
model.eval()

train_loss, first_data_point, _, first_label = _find_max_loss_example(model, train_dataloader, device)
print(f'Batch dim causing max error in the first prediction: {first_data_point}')
print(f'Index of label causing max error in the first prediction: {first_label.squeeze()}')
print(f'Batch dim causing max error in the second prediction: {data_point}')
print(f'Index of label causing max error in the second prediction: {corrupt_label.squeeze()}')

first_probs = _last_token_probs(model, corrupt_input)
first_prediction = torch.argmax(first_probs).item()
correct_label = corrupt_label.squeeze().item()

print('\n'*3)
print(f'Input: {corrupt_input.squeeze()}')
print(f'Label: {corrupt_label.squeeze()}')
print()
# Report the model's actual argmax (previously this printed the label instead).
print(f'First prediction: {first_prediction}')
print(f'Probability of correct label: {first_probs[correct_label]}')
print(f'Probability of incorrect label: {first_probs[second_prediction]}')
print()
print(f'Second prediction: {second_prediction}')
print(f'Probability of correct label: {second_probs[correct_label]}')
print(f'Probability of incorrect label: {second_probs[second_prediction]}')

Using the following device: cuda!


  0%|          | 0/23 [00:00<?, ?it/s]

100%|██████████| 23/23 [00:00<00:00, 70.90it/s]

Batch dim causing max error in the first prediction: 2985
Index of label causing max error in the first prediction: 51
Batch dim causing max error in the second prediction: 3680
Index of label causing max error in the second prediction: 22




Input: tensor([ 73,  62, 113], device='cuda:0')
Label: 22

First prediction: 22
Probability of correct label: 0.9998949766159058
Probability of incorrect label: 6.114076677476987e-05

Second prediction: 51
Probability of correct label: 2.914002470788546e-05
Probability of incorrect label: 0.9972339272499084



