In [15]:
import torch
import matplotlib.pyplot as plt
import seaborn as sns

from transformers import GPT2Model

from torch.utils.data import DataLoader
import torch.nn as nn
from torchvision import datasets, transforms
import numpy as np
from einops import rearrange

from PIL import Image

In [11]:
class GPT2CIFAR10(nn.Module):
    def __init__(self, patch_size=4, num_classes=10, freeze_gpt2=True):
        super().__init__()
        
        # Load pretrained GPT2
        self.gpt2 = GPT2Model.from_pretrained('gpt2')
        self.hidden_size = self.gpt2.config.hidden_size  # 768 for base GPT2
        
        # CIFAR-10 characteristics
        self.image_size = 32
        self.patch_size = patch_size
        self.num_patches = (self.image_size // patch_size) ** 2
        
        # Patch embedding layer: from image patches to GPT2 hidden size
        self.patch_embedding = nn.Conv2d(3, self.hidden_size, 
                                       kernel_size=patch_size, 
                                       stride=patch_size)
        
        # Classification head
        self.classifier = nn.Linear(self.hidden_size, num_classes)
        
        if freeze_gpt2:
            # Freeze GPT2 parameters except LayerNorm and positional embeddings
            for name, param in self.gpt2.named_parameters():
                if 'ln' in name or 'wpe' in name:
                    param.requires_grad = True
                else:
                    param.requires_grad = False
    
    def forward(self, x):
        batch_size = x.shape[0]
        
        # Convert image to patches
        # Shape: (batch_size, hidden_size, h', w')
        patches = self.patch_embedding(x)
        
        # Reshape and transpose for GPT2
        # Shape: (batch_size, num_patches, hidden_size)
        patches = rearrange(patches, 'b d h w -> b (h w) d')
        
        # Pass through GPT2 and get last hidden state
        outputs = self.gpt2(inputs_embeds=patches)
        hidden_states = outputs.last_hidden_state
        
        # Use the last token's representation for classification
        cls_representation = hidden_states[:, -1]
        
        # Classify
        logits = self.classifier(cls_representation)
        
        return logits

In [12]:
class GPT2Visualizer:
    def __init__(self, model, device, class_names):
        self.model = model.to(device)
        self.device = device
        self.class_names = class_names
        self.model.eval()
        
        # Save reference to GPT2 attention
        self.attention_maps = []
        
        # Register hook to get attention weights
        def attention_hook(module, input, output):
            # Get attention weights from output tuple
            # Shape: (batch_size, num_heads, sequence_length, sequence_length)
            self.attention_maps.append(output[0].detach())
        
        # Register hooks for all attention blocks
        for name, module in model.named_modules():
            if "attn" in name and "block" in name:
                module.register_forward_hook(attention_hook)
        
        # Standard CIFAR-10 normalization
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), 
                               (0.2023, 0.1994, 0.2010))
        ])
    
    def predict_and_visualize(self, images, true_labels=None, num_images=5):
        """
        Visualize predictions and attention maps for a batch of images
        
        Args:
            images: List of PIL images or tensor of shape (N, C, H, W)
            true_labels: Optional list of true labels
            num_images: Number of images to visualize
        """
        # Clear previous attention maps
        self.attention_maps = []
        
        # Prepare images if they're PIL
        if not torch.is_tensor(images):
            tensors = []
            for img in images:
                tensors.append(self.transform(img))
            images = torch.stack(tensors)
        
        # Move to device
        images = images.to(self.device)
        
        # Get predictions
        with torch.no_grad():
            outputs = self.model(images[:num_images])
            predictions = outputs.argmax(dim=1)
        
        # Get attention weights (average over heads and layers)
        # Shape: (batch_size, num_patches, num_patches)
        avg_attention = torch.mean(torch.stack([
            torch.mean(attention, dim=1) 
            for attention in self.attention_maps
        ]), dim=0)
        
        # Create figure
        num_cols = 3  # image, attention, patch attention
        fig = plt.figure(figsize=(15, 5 * num_images))
        
        for idx in range(num_images):
            # Original image with prediction
            ax1 = plt.subplot(num_images, num_cols, idx * num_cols + 1)
            img = images[idx].cpu()
            img = img * torch.tensor([0.2023, 0.1994, 0.2010]).view(3, 1, 1) + \
                  torch.tensor([0.4914, 0.4822, 0.4465]).view(3, 1, 1)
            plt.imshow(img.permute(1, 2, 0).clip(0, 1))
            
            # Set title color based on prediction
            pred_class = self.class_names[predictions[idx]]
            if true_labels is not None:
                color = 'green' if predictions[idx] == true_labels[idx] else 'red'
                title = f'Pred: {pred_class}\nTrue: {self.class_names[true_labels[idx]]}'
            else:
                color = 'black'
                title = f'Pred: {pred_class}'
            
            ax1.set_title(title, color=color)
            plt.axis('off')
            
            # Attention heatmap
            ax2 = plt.subplot(num_images, num_cols, idx * num_cols + 2)
            attention_map = avg_attention[idx].cpu()
            sns.heatmap(attention_map, cmap='viridis')
            ax2.set_title('Average Self-Attention')
            
            # Patch-wise attention visualization
            ax3 = plt.subplot(num_images, num_cols, idx * num_cols + 3)
            # Get attention for the classification token (last token)
            patch_attention = attention_map[-1, :-1].reshape(4, 4)  # for 8x8 patches
            sns.heatmap(patch_attention, cmap='viridis')
            ax3.set_title('Patch Attention Weights')
        
        plt.tight_layout()
        return fig

In [16]:
# Load best model
checkpoint = torch.load('C:\\Users\\Windows\\Documents\\CVC\\repos\seeing-language\\notebooks\wandb\\run-20241111_230849-kjps7qnm\\files\\best_model.pth')
model = GPT2CIFAR10()
model.load_state_dict(checkpoint['model_state_dict'])

# CIFAR-10 class names
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']

# Initialize visualizer
visualizer = GPT2Visualizer(model=model, device='cpu', class_names=class_names)

transform_val = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Load CIFAR10
valset = datasets.CIFAR10(root='./data', train=False,
                         download=True, transform=transform_val)

valloader = DataLoader(valset, batch_size=128, shuffle=False, num_workers=2)

# Get some test images
dataiter = iter(valloader)
images, labels = next(dataiter)

# Visualize predictions and attention
fig = visualizer.predict_and_visualize(images[:5], labels[:5])
plt.show()

# To save the figure
# fig.savefig('predictions_attention.png', bbox_inches='tight', dpi=300)

  checkpoint = torch.load('C:\\Users\\Windows\\Documents\\CVC\\repos\seeing-language\\notebooks\wandb\\run-20241111_230849-kjps7qnm\\files\\best_model.pth')


Files already downloaded and verified


RuntimeError: stack expects a non-empty TensorList