In [1]:
import torch
import torch.nn as nn
from transformers import GPT2Model
import torch.nn.functional as F
import wandb
from tqdm import tqdm

import numpy as np
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from einops import rearrange

In [2]:
class GPT2CIFAR10(nn.Module):
    def __init__(self, patch_size=4, num_classes=10, freeze_gpt2=True):
        super().__init__()
        
        # Load pretrained GPT2
        self.gpt2 = GPT2Model.from_pretrained('gpt2')
        self.hidden_size = self.gpt2.config.hidden_size  # 768 for base GPT2
        
        # CIFAR-10 characteristics
        self.image_size = 32
        self.patch_size = patch_size
        self.num_patches = (self.image_size // patch_size) ** 2
        
        # Patch embedding layer: from image patches to GPT2 hidden size
        self.patch_embedding = nn.Conv2d(3, self.hidden_size, 
                                       kernel_size=patch_size, 
                                       stride=patch_size)
        
        # Classification head
        self.classifier = nn.Linear(self.hidden_size, num_classes)
        
        if freeze_gpt2:
            # Freeze GPT2 parameters except LayerNorm and positional embeddings
            for name, param in self.gpt2.named_parameters():
                if 'ln' in name or 'wpe' in name:
                    param.requires_grad = True
                else:
                    param.requires_grad = False
    
    def forward(self, x):
        batch_size = x.shape[0]
        
        # Convert image to patches
        # Shape: (batch_size, hidden_size, h', w')
        patches = self.patch_embedding(x)
        
        # Reshape and transpose for GPT2
        # Shape: (batch_size, num_patches, hidden_size)
        patches = rearrange(patches, 'b d h w -> b (h w) d')
        
        # Pass through GPT2 and get last hidden state
        outputs = self.gpt2(inputs_embeds=patches)
        hidden_states = outputs.last_hidden_state
        
        # Use the last token's representation for classification
        cls_representation = hidden_states[:, -1]
        
        # Classify
        logits = self.classifier(cls_representation)
        
        return logits

In [3]:
class CIFAR10Trainer:
    def __init__(self, model, train_loader, val_loader, 
                 criterion, optimizer, device, config):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.criterion = criterion
        self.optimizer = optimizer
        self.device = device
        self.config = config
        
        # Initialize metrics tracking
        self.train_losses = []
        self.train_accuracies = []
        self.val_losses = []
        self.val_accuracies = []
        
        # Initialize best validation accuracy for model saving
        self.best_val_acc = 0.0
        
    def train_epoch(self):
        self.model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        
        pbar = tqdm(self.train_loader, desc='Training')
        for batch_idx, (images, labels) in enumerate(pbar):
            images, labels = images.to(self.device), labels.to(self.device)
            
            self.optimizer.zero_grad()
            outputs = self.model(images)
            loss = self.criterion(outputs, labels)
            
            loss.backward()
            self.optimizer.step()
            
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            
            # Update progress bar
            pbar.set_postfix({
                'loss': running_loss/(batch_idx+1),
                'acc': 100.*correct/total
            })
        
        epoch_loss = running_loss / len(self.train_loader)
        epoch_acc = 100. * correct / total
        return epoch_loss, epoch_acc
    
    def validate(self):
        self.model.eval()
        running_loss = 0.0
        correct = 0
        total = 0
        
        with torch.no_grad():
            for images, labels in tqdm(self.val_loader, desc='Validation'):
                images, labels = images.to(self.device), labels.to(self.device)
                
                outputs = self.model(images)
                loss = self.criterion(outputs, labels)
                
                running_loss += loss.item()
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()
        
        val_loss = running_loss / len(self.val_loader)
        val_acc = 100. * correct / total
        return val_loss, val_acc
    
    def train(self, num_epochs):
        # Initialize wandb
        wandb.init(
            project="gpt2-cifar10", 
            config=self.config,
            name=f'{self.config["learning_rate"]}lr_{self.config["batch_size"]}bs_{self.config["patch_size"]}patch'
        )
        
        for epoch in range(num_epochs):
            # Training phase
            train_loss, train_acc = self.train_epoch()
            self.train_losses.append(train_loss)
            self.train_accuracies.append(train_acc)
            
            # Validation phase
            val_loss, val_acc = self.validate()
            self.val_losses.append(val_loss)
            self.val_accuracies.append(val_acc)
            
            # Log metrics
            wandb.log({
                "epoch": epoch,
                "train_loss": train_loss,
                "train_acc": train_acc,
                "val_loss": val_loss,
                "val_acc": val_acc,
            })
            
            # Save best model
            if val_acc > self.best_val_acc:
                self.best_val_acc = val_acc
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'val_acc': val_acc,
                }, 'best_model.pth')
                wandb.save('best_model.pth')
            
            print(f'Epoch: {epoch+1}/{num_epochs}')
            print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
            print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')
            print('-' * 60)
        
        wandb.finish()

In [4]:
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_val = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Load CIFAR10
trainset = datasets.CIFAR10(root='./data', train=True,
                          download=True, transform=transform_train)
valset = datasets.CIFAR10(root='./data', train=False,
                         download=True, transform=transform_val)

trainloader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
valloader = DataLoader(valset, batch_size=128, shuffle=False, num_workers=2)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data\cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:07<00:00, 23.3MB/s] 


Extracting ./data\cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [5]:
# Configuration
config = {
    'patch_size': 4,
    'num_classes': 10,
    'batch_size': 16,
    'learning_rate': 1e-3,
    'num_epochs': 100,
}

# Initialize model and training components
model = GPT2CIFAR10(patch_size=config['patch_size'], 
                    num_classes=config['num_classes'])
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=config['learning_rate'])
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# Initialize trainer
trainer = CIFAR10Trainer(
    model=model,
    train_loader=trainloader,
    val_loader=valloader,
    criterion=criterion,
    optimizer=optimizer,
    device=device,
    config=config
)

# Train model
trainer.train(num_epochs=config['num_epochs'])

wandb: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
wandb: Currently logged in as: dporres (dporres-computer-vision-center). Use `wandb login --relogin` to force relogin


Training: 100%|██████████| 391/391 [16:33<00:00,  2.54s/it, loss=2.32, acc=14.6]  
Validation: 100%|██████████| 79/79 [00:17<00:00,  4.57it/s]


Epoch: 1/50
Train Loss: 2.3167, Train Acc: 14.61%
Val Loss: 2.1477, Val Acc: 21.83%
------------------------------------------------------------


Training: 100%|██████████| 391/391 [02:30<00:00,  2.60it/s, loss=2.19, acc=19.9]
Validation: 100%|██████████| 79/79 [00:16<00:00,  4.70it/s]


Epoch: 2/50
Train Loss: 2.1854, Train Acc: 19.93%
Val Loss: 2.0410, Val Acc: 24.67%
------------------------------------------------------------


Training: 100%|██████████| 391/391 [02:28<00:00,  2.63it/s, loss=2.02, acc=26.1]
Validation: 100%|██████████| 79/79 [00:16<00:00,  4.67it/s]


Epoch: 3/50
Train Loss: 2.0153, Train Acc: 26.10%
Val Loss: 1.9293, Val Acc: 29.97%
------------------------------------------------------------


Training: 100%|██████████| 391/391 [02:31<00:00,  2.58it/s, loss=1.89, acc=30.7]
Validation: 100%|██████████| 79/79 [00:16<00:00,  4.75it/s]


Epoch: 4/50
Train Loss: 1.8900, Train Acc: 30.66%
Val Loss: 1.8495, Val Acc: 34.01%
------------------------------------------------------------


Training: 100%|██████████| 391/391 [02:29<00:00,  2.61it/s, loss=1.79, acc=34.1]
Validation: 100%|██████████| 79/79 [00:16<00:00,  4.72it/s]


Epoch: 5/50
Train Loss: 1.7879, Train Acc: 34.12%
Val Loss: 1.8029, Val Acc: 35.88%
------------------------------------------------------------


Training: 100%|██████████| 391/391 [02:29<00:00,  2.61it/s, loss=1.71, acc=37.1]
Validation: 100%|██████████| 79/79 [00:17<00:00,  4.64it/s]


Epoch: 6/50
Train Loss: 1.7066, Train Acc: 37.05%
Val Loss: 1.6674, Val Acc: 39.77%
------------------------------------------------------------


Training: 100%|██████████| 391/391 [02:31<00:00,  2.58it/s, loss=1.65, acc=39.1]
Validation: 100%|██████████| 79/79 [00:17<00:00,  4.41it/s]


Epoch: 7/50
Train Loss: 1.6483, Train Acc: 39.07%
Val Loss: 1.6357, Val Acc: 40.89%
------------------------------------------------------------


Training: 100%|██████████| 391/391 [02:37<00:00,  2.48it/s, loss=1.61, acc=40.9]
Validation: 100%|██████████| 79/79 [00:17<00:00,  4.41it/s]


Epoch: 8/50
Train Loss: 1.6056, Train Acc: 40.89%
Val Loss: 1.5743, Val Acc: 43.42%
------------------------------------------------------------


Training: 100%|██████████| 391/391 [02:32<00:00,  2.56it/s, loss=1.57, acc=42.5]
Validation: 100%|██████████| 79/79 [00:16<00:00,  4.72it/s]


Epoch: 9/50
Train Loss: 1.5651, Train Acc: 42.47%
Val Loss: 1.4955, Val Acc: 45.68%
------------------------------------------------------------


Training: 100%|██████████| 391/391 [02:30<00:00,  2.61it/s, loss=1.53, acc=44]  
Validation: 100%|██████████| 79/79 [00:17<00:00,  4.59it/s]


Epoch: 10/50
Train Loss: 1.5282, Train Acc: 44.05%
Val Loss: 1.4660, Val Acc: 46.96%
------------------------------------------------------------


Training: 100%|██████████| 391/391 [02:30<00:00,  2.60it/s, loss=1.5, acc=45.5] 
Validation: 100%|██████████| 79/79 [00:16<00:00,  4.67it/s]


Epoch: 11/50
Train Loss: 1.4956, Train Acc: 45.50%
Val Loss: 1.4415, Val Acc: 47.79%
------------------------------------------------------------


Training: 100%|██████████| 391/391 [02:30<00:00,  2.61it/s, loss=1.46, acc=46.8]
Validation: 100%|██████████| 79/79 [00:17<00:00,  4.64it/s]


Epoch: 12/50
Train Loss: 1.4649, Train Acc: 46.79%
Val Loss: 1.3951, Val Acc: 49.60%
------------------------------------------------------------


Training: 100%|██████████| 391/391 [02:30<00:00,  2.60it/s, loss=1.44, acc=47.5]
Validation: 100%|██████████| 79/79 [00:16<00:00,  4.70it/s]


Epoch: 13/50
Train Loss: 1.4435, Train Acc: 47.53%
Val Loss: 1.3835, Val Acc: 50.44%
------------------------------------------------------------


Training: 100%|██████████| 391/391 [02:30<00:00,  2.60it/s, loss=1.42, acc=48.3]
Validation: 100%|██████████| 79/79 [00:17<00:00,  4.61it/s]


Epoch: 14/50
Train Loss: 1.4197, Train Acc: 48.31%
Val Loss: 1.3608, Val Acc: 51.01%
------------------------------------------------------------


Training: 100%|██████████| 391/391 [02:33<00:00,  2.55it/s, loss=1.4, acc=49]  
Validation: 100%|██████████| 79/79 [00:17<00:00,  4.54it/s]


Epoch: 15/50
Train Loss: 1.4003, Train Acc: 49.04%
Val Loss: 1.3625, Val Acc: 50.51%
------------------------------------------------------------


Training: 100%|██████████| 391/391 [02:32<00:00,  2.56it/s, loss=1.38, acc=49.7]
Validation: 100%|██████████| 79/79 [00:18<00:00,  4.36it/s]


Epoch: 16/50
Train Loss: 1.3819, Train Acc: 49.74%
Val Loss: 1.3254, Val Acc: 52.31%
------------------------------------------------------------


Training: 100%|██████████| 391/391 [02:39<00:00,  2.46it/s, loss=1.37, acc=50.3]
Validation: 100%|██████████| 79/79 [00:18<00:00,  4.24it/s]


Epoch: 17/50
Train Loss: 1.3669, Train Acc: 50.31%
Val Loss: 1.3098, Val Acc: 52.45%
------------------------------------------------------------


Training: 100%|██████████| 391/391 [02:40<00:00,  2.44it/s, loss=1.35, acc=51.4]
Validation: 100%|██████████| 79/79 [00:18<00:00,  4.34it/s]


Epoch: 18/50
Train Loss: 1.3476, Train Acc: 51.40%
Val Loss: 1.2748, Val Acc: 54.12%
------------------------------------------------------------


Training: 100%|██████████| 391/391 [02:39<00:00,  2.46it/s, loss=1.33, acc=51.6]
Validation: 100%|██████████| 79/79 [00:17<00:00,  4.48it/s]


Epoch: 19/50
Train Loss: 1.3309, Train Acc: 51.57%
Val Loss: 1.2554, Val Acc: 55.02%
------------------------------------------------------------


Training: 100%|██████████| 391/391 [02:36<00:00,  2.50it/s, loss=1.32, acc=52.1]
Validation: 100%|██████████| 79/79 [00:17<00:00,  4.49it/s]


Epoch: 20/50
Train Loss: 1.3244, Train Acc: 52.12%
Val Loss: 1.2419, Val Acc: 55.39%
------------------------------------------------------------


Training: 100%|██████████| 391/391 [02:33<00:00,  2.54it/s, loss=1.31, acc=52.7]
Validation: 100%|██████████| 79/79 [00:17<00:00,  4.53it/s]


Epoch: 21/50
Train Loss: 1.3076, Train Acc: 52.65%
Val Loss: 1.2695, Val Acc: 55.23%
------------------------------------------------------------


Training: 100%|██████████| 391/391 [02:28<00:00,  2.63it/s, loss=1.3, acc=53]   
Validation: 100%|██████████| 79/79 [00:17<00:00,  4.63it/s]


Epoch: 22/50
Train Loss: 1.3011, Train Acc: 53.00%
Val Loss: 1.2421, Val Acc: 55.26%
------------------------------------------------------------


Training: 100%|██████████| 391/391 [02:33<00:00,  2.54it/s, loss=1.29, acc=53.5]
Validation: 100%|██████████| 79/79 [00:17<00:00,  4.59it/s]


Epoch: 23/50
Train Loss: 1.2907, Train Acc: 53.53%
Val Loss: 1.2274, Val Acc: 56.30%
------------------------------------------------------------


Training: 100%|██████████| 391/391 [02:33<00:00,  2.55it/s, loss=1.28, acc=53.9]
Validation: 100%|██████████| 79/79 [00:17<00:00,  4.50it/s]


Epoch: 24/50
Train Loss: 1.2780, Train Acc: 53.86%
Val Loss: 1.2272, Val Acc: 56.01%
------------------------------------------------------------


Training: 100%|██████████| 391/391 [02:36<00:00,  2.50it/s, loss=1.27, acc=54.1]
Validation: 100%|██████████| 79/79 [00:17<00:00,  4.60it/s]


Epoch: 25/50
Train Loss: 1.2681, Train Acc: 54.10%
Val Loss: 1.2306, Val Acc: 56.01%
------------------------------------------------------------


Training: 100%|██████████| 391/391 [02:32<00:00,  2.56it/s, loss=1.26, acc=54.3]
Validation: 100%|██████████| 79/79 [00:16<00:00,  4.65it/s]


Epoch: 26/50
Train Loss: 1.2618, Train Acc: 54.33%
Val Loss: 1.1964, Val Acc: 57.28%
------------------------------------------------------------


Training: 100%|██████████| 391/391 [02:22<00:00,  2.75it/s, loss=1.25, acc=54.8]
Validation: 100%|██████████| 79/79 [00:15<00:00,  5.14it/s]


Epoch: 27/50
Train Loss: 1.2487, Train Acc: 54.78%
Val Loss: 1.1720, Val Acc: 58.52%
------------------------------------------------------------


Training: 100%|██████████| 391/391 [02:17<00:00,  2.85it/s, loss=1.24, acc=54.9]
Validation: 100%|██████████| 79/79 [00:15<00:00,  4.99it/s]


Epoch: 28/50
Train Loss: 1.2433, Train Acc: 54.92%
Val Loss: 1.1745, Val Acc: 58.54%
------------------------------------------------------------


Training: 100%|██████████| 391/391 [02:17<00:00,  2.85it/s, loss=1.24, acc=55.2]
Validation: 100%|██████████| 79/79 [00:16<00:00,  4.93it/s]


Epoch: 29/50
Train Loss: 1.2400, Train Acc: 55.21%
Val Loss: 1.1722, Val Acc: 58.52%
------------------------------------------------------------


Training: 100%|██████████| 391/391 [02:23<00:00,  2.73it/s, loss=1.23, acc=55.6]
Validation: 100%|██████████| 79/79 [00:16<00:00,  4.73it/s]


Epoch: 30/50
Train Loss: 1.2276, Train Acc: 55.62%
Val Loss: 1.1777, Val Acc: 58.37%
------------------------------------------------------------


Training: 100%|██████████| 391/391 [02:29<00:00,  2.61it/s, loss=1.22, acc=56.2]
Validation: 100%|██████████| 79/79 [00:17<00:00,  4.55it/s]


Epoch: 31/50
Train Loss: 1.2228, Train Acc: 56.17%
Val Loss: 1.1646, Val Acc: 58.45%
------------------------------------------------------------


Training: 100%|██████████| 391/391 [02:28<00:00,  2.63it/s, loss=1.22, acc=56.1]
Validation: 100%|██████████| 79/79 [00:16<00:00,  4.82it/s]


Epoch: 32/50
Train Loss: 1.2207, Train Acc: 56.12%
Val Loss: 1.1711, Val Acc: 58.58%
------------------------------------------------------------


Training: 100%|██████████| 391/391 [02:34<00:00,  2.53it/s, loss=1.21, acc=56.2]
Validation: 100%|██████████| 79/79 [00:16<00:00,  4.75it/s]


Epoch: 33/50
Train Loss: 1.2107, Train Acc: 56.22%
Val Loss: 1.1445, Val Acc: 59.19%
------------------------------------------------------------


Training: 100%|██████████| 391/391 [02:22<00:00,  2.75it/s, loss=1.21, acc=56.5]
Validation: 100%|██████████| 79/79 [00:15<00:00,  5.12it/s]


Epoch: 34/50
Train Loss: 1.2057, Train Acc: 56.51%
Val Loss: 1.1477, Val Acc: 59.54%
------------------------------------------------------------


Training: 100%|██████████| 391/391 [02:23<00:00,  2.72it/s, loss=1.2, acc=56.8] 
Validation: 100%|██████████| 79/79 [00:17<00:00,  4.56it/s]


Epoch: 35/50
Train Loss: 1.1992, Train Acc: 56.78%
Val Loss: 1.1338, Val Acc: 59.91%
------------------------------------------------------------


Training: 100%|██████████| 391/391 [02:26<00:00,  2.66it/s, loss=1.19, acc=56.8]
Validation: 100%|██████████| 79/79 [00:16<00:00,  4.75it/s]


Epoch: 36/50
Train Loss: 1.1930, Train Acc: 56.80%
Val Loss: 1.1357, Val Acc: 60.02%
------------------------------------------------------------


Training: 100%|██████████| 391/391 [02:30<00:00,  2.61it/s, loss=1.18, acc=57.5]
Validation: 100%|██████████| 79/79 [00:16<00:00,  4.74it/s]


Epoch: 37/50
Train Loss: 1.1803, Train Acc: 57.48%
Val Loss: 1.1290, Val Acc: 60.36%
------------------------------------------------------------


Training: 100%|██████████| 391/391 [02:28<00:00,  2.64it/s, loss=1.18, acc=57.5]
Validation: 100%|██████████| 79/79 [00:17<00:00,  4.45it/s]


Epoch: 38/50
Train Loss: 1.1803, Train Acc: 57.53%
Val Loss: 1.1279, Val Acc: 59.90%
------------------------------------------------------------


Training: 100%|██████████| 391/391 [02:33<00:00,  2.56it/s, loss=1.18, acc=57.9]
Validation: 100%|██████████| 79/79 [00:16<00:00,  4.74it/s]


Epoch: 39/50
Train Loss: 1.1773, Train Acc: 57.92%
Val Loss: 1.1152, Val Acc: 60.69%
------------------------------------------------------------


Training: 100%|██████████| 391/391 [02:29<00:00,  2.61it/s, loss=1.17, acc=58]  
Validation: 100%|██████████| 79/79 [00:17<00:00,  4.64it/s]


Epoch: 40/50
Train Loss: 1.1666, Train Acc: 58.04%
Val Loss: 1.0982, Val Acc: 61.26%
------------------------------------------------------------


Training: 100%|██████████| 391/391 [02:28<00:00,  2.63it/s, loss=1.17, acc=58.1]
Validation: 100%|██████████| 79/79 [00:16<00:00,  4.68it/s]


Epoch: 41/50
Train Loss: 1.1675, Train Acc: 58.08%
Val Loss: 1.1031, Val Acc: 60.58%
------------------------------------------------------------


Training: 100%|██████████| 391/391 [02:29<00:00,  2.61it/s, loss=1.16, acc=58]  
Validation: 100%|██████████| 79/79 [00:17<00:00,  4.57it/s]


Epoch: 42/50
Train Loss: 1.1623, Train Acc: 58.01%
Val Loss: 1.0878, Val Acc: 61.49%
------------------------------------------------------------


Training: 100%|██████████| 391/391 [02:27<00:00,  2.65it/s, loss=1.16, acc=58.1]
Validation: 100%|██████████| 79/79 [00:16<00:00,  4.71it/s]


Epoch: 43/50
Train Loss: 1.1594, Train Acc: 58.14%
Val Loss: 1.0954, Val Acc: 61.35%
------------------------------------------------------------


Training: 100%|██████████| 391/391 [02:27<00:00,  2.65it/s, loss=1.15, acc=58.6]
Validation: 100%|██████████| 79/79 [00:16<00:00,  4.73it/s]


Epoch: 44/50
Train Loss: 1.1518, Train Acc: 58.55%
Val Loss: 1.0833, Val Acc: 61.50%
------------------------------------------------------------


Training: 100%|██████████| 391/391 [02:28<00:00,  2.63it/s, loss=1.14, acc=58.9]
Validation: 100%|██████████| 79/79 [00:16<00:00,  4.67it/s]


Epoch: 45/50
Train Loss: 1.1425, Train Acc: 58.87%
Val Loss: 1.0991, Val Acc: 60.89%
------------------------------------------------------------


Training: 100%|██████████| 391/391 [02:28<00:00,  2.63it/s, loss=1.14, acc=58.8]
Validation: 100%|██████████| 79/79 [00:16<00:00,  4.68it/s]


Epoch: 46/50
Train Loss: 1.1433, Train Acc: 58.78%
Val Loss: 1.0851, Val Acc: 61.50%
------------------------------------------------------------


Training: 100%|██████████| 391/391 [02:31<00:00,  2.58it/s, loss=1.14, acc=59.2]
Validation: 100%|██████████| 79/79 [00:17<00:00,  4.63it/s]


Epoch: 47/50
Train Loss: 1.1364, Train Acc: 59.17%
Val Loss: 1.0792, Val Acc: 61.61%
------------------------------------------------------------


Training: 100%|██████████| 391/391 [02:36<00:00,  2.50it/s, loss=1.13, acc=59.5]
Validation: 100%|██████████| 79/79 [00:18<00:00,  4.31it/s]


Epoch: 48/50
Train Loss: 1.1299, Train Acc: 59.46%
Val Loss: 1.0756, Val Acc: 61.87%
------------------------------------------------------------


Training: 100%|██████████| 391/391 [02:37<00:00,  2.49it/s, loss=1.13, acc=59.5]
Validation: 100%|██████████| 79/79 [00:17<00:00,  4.63it/s]


Epoch: 49/50
Train Loss: 1.1275, Train Acc: 59.52%
Val Loss: 1.0723, Val Acc: 61.98%
------------------------------------------------------------


Training: 100%|██████████| 391/391 [02:29<00:00,  2.61it/s, loss=1.12, acc=59.7]
Validation: 100%|██████████| 79/79 [00:16<00:00,  4.68it/s]


Epoch: 50/50
Train Loss: 1.1245, Train Acc: 59.68%
Val Loss: 1.0737, Val Acc: 62.22%
------------------------------------------------------------


VBox(children=(Label(value='10.442 MB of 353.161 MB uploaded\r'), FloatProgress(value=0.029568185148566333, ma…

0,1
epoch,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇███
train_acc,▁▂▃▃▄▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇██████████████
train_loss,█▇▆▅▅▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁
val_acc,▁▁▂▃▃▄▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇██████████████
val_loss,█▇▇▆▆▅▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,49.0
train_acc,59.676
train_loss,1.12449
val_acc,62.22
val_loss,1.0737


## Visualizing trained model

In [6]:
import torch
import matplotlib.pyplot as plt
import seaborn as sns

from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np

from PIL import Image

ModuleNotFoundError: No module named 'seaborn'

In [None]:
class GPT2Visualizer:
    def __init__(self, model, device, class_names):
        self.model = model.to(device)
        self.device = device
        self.class_names = class_names
        self.model.eval()
        
        # Save reference to GPT2 attention
        self.attention_maps = []
        
        # Register hook to get attention weights
        def attention_hook(module, input, output):
            # Get attention weights from output tuple
            # Shape: (batch_size, num_heads, sequence_length, sequence_length)
            self.attention_maps.append(output[0].detach())
        
        # Register hooks for all attention blocks
        for name, module in model.named_modules():
            if "attn" in name and "block" in name:
                module.register_forward_hook(attention_hook)
        
        # Standard CIFAR-10 normalization
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), 
                               (0.2023, 0.1994, 0.2010))
        ])
    
    def predict_and_visualize(self, images, true_labels=None, num_images=5):
        """
        Visualize predictions and attention maps for a batch of images
        
        Args:
            images: List of PIL images or tensor of shape (N, C, H, W)
            true_labels: Optional list of true labels
            num_images: Number of images to visualize
        """
        # Clear previous attention maps
        self.attention_maps = []
        
        # Prepare images if they're PIL
        if not torch.is_tensor(images):
            tensors = []
            for img in images:
                tensors.append(self.transform(img))
            images = torch.stack(tensors)
        
        # Move to device
        images = images.to(self.device)
        
        # Get predictions
        with torch.no_grad():
            outputs = self.model(images[:num_images])
            predictions = outputs.argmax(dim=1)
        
        # Get attention weights (average over heads and layers)
        # Shape: (batch_size, num_patches, num_patches)
        avg_attention = torch.mean(torch.stack([
            torch.mean(attention, dim=1) 
            for attention in self.attention_maps
        ]), dim=0)
        
        # Create figure
        num_cols = 3  # image, attention, patch attention
        fig = plt.figure(figsize=(15, 5 * num_images))
        
        for idx in range(num_images):
            # Original image with prediction
            ax1 = plt.subplot(num_images, num_cols, idx * num_cols + 1)
            img = images[idx].cpu()
            img = img * torch.tensor([0.2023, 0.1994, 0.2010]).view(3, 1, 1) + \
                  torch.tensor([0.4914, 0.4822, 0.4465]).view(3, 1, 1)
            plt.imshow(img.permute(1, 2, 0).clip(0, 1))
            
            # Set title color based on prediction
            pred_class = self.class_names[predictions[idx]]
            if true_labels is not None:
                color = 'green' if predictions[idx] == true_labels[idx] else 'red'
                title = f'Pred: {pred_class}\nTrue: {self.class_names[true_labels[idx]]}'
            else:
                color = 'black'
                title = f'Pred: {pred_class}'
            
            ax1.set_title(title, color=color)
            plt.axis('off')
            
            # Attention heatmap
            ax2 = plt.subplot(num_images, num_cols, idx * num_cols + 2)
            attention_map = avg_attention[idx].cpu()
            sns.heatmap(attention_map, cmap='viridis')
            ax2.set_title('Average Self-Attention')
            
            # Patch-wise attention visualization
            ax3 = plt.subplot(num_images, num_cols, idx * num_cols + 3)
            # Get attention for the classification token (last token)
            patch_attention = attention_map[-1, :-1].reshape(4, 4)  # for 8x8 patches
            sns.heatmap(patch_attention, cmap='viridis')
            ax3.set_title('Patch Attention Weights')
        
        plt.tight_layout()
        return fig

In [None]:
# Load best model
checkpoint = torch.load('C:\\Users\\Windows\\Documents\\CVC\\repos\seeing-language\\notebooks\wandb\\run-20241111_230849-kjps7qnm\\files\\best_model.pth')
model.load_state_dict(checkpoint['model_state_dict'])

# CIFAR-10 class names
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']

# Initialize visualizer
visualizer = GPT2Visualizer(model, device, class_names)

transform_val = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Load CIFAR10
valset = datasets.CIFAR10(root='./data', train=False,
                         download=True, transform=transform_val)

valloader = DataLoader(valset, batch_size=128, shuffle=False, num_workers=2)

# Get some test images
dataiter = iter(valloader)
images, labels = next(dataiter)

# Visualize predictions and attention
fig = visualizer.predict_and_visualize(images[:5], labels[:5])
plt.show()

# To save the figure
# fig.savefig('predictions_attention.png', bbox_inches='tight', dpi=300)