In [1]:
%%html
<style>
/* Make the notebook's output containers flex boxes and center their
   children along both flex axes. `!important` forces the display mode
   over any rule the notebook theme already applies to these classes. */
.output_wrapper, .output {
    display: flex !important;
    align-items: center;
    justify-content: center;
}
</style>


In [2]:
# %% [markdown]
# # Vision Transformer Attention Analysis
#
# ![Attention Visualization](https://i.imgur.com/Xg7XQ0T.png)

# %%
import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from sklearn.metrics import confusion_matrix
import seaborn as sns
from tqdm import tqdm

# %%
# %%capture
# !pip install torchcam

# %%
from models.hybrid_vit import HybridViT
from models.utils import load_checkpoint
from data.cifar10 import CIFAR10DataModule
from torchcam.methods import GradCAM
from torchcam.utils import overlay_mask

# %% [markdown]
# ## 1. Model and Data Loading

# %%
# Construct the network, pass it through load_checkpoint together with the
# saved checkpoint path, and switch the result to evaluation mode.
model = load_checkpoint(
    HybridViT(dim=256, depth=6, heads=8),
    "checkpoints/best_model.pth",
)
model.eval()

# Data module used by every analysis below (batches of 128)
dm = CIFAR10DataModule(batch_size=128)
dm.setup()

# Display names, looked up by integer label
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

# %% [markdown]
# ## 2. Attention Head Visualization

# %%
def visualize_attention_heads(sample, layer_idx=0, n_heads=8):
    """Plot, for one layer, how position 0 attends to the other positions per head.

    Runs the module-level ``model`` on a single image and draws one panel per
    head. Each panel shows the attention row of sequence position 0 (assumed
    to be the CLS token) over positions ``1:``, reshaped to a square grid.

    Args:
        sample: One un-batched input tensor; a batch dimension is added here.
        layer_idx: Index into the per-layer attention maps returned by the model.
        n_heads: Number of heads to draw; clamped to the heads actually present.

    Returns:
        The matplotlib figure containing the panels (up to four per row).

    Raises:
        ValueError: If the number of positions after position 0 is not a
            perfect square and therefore cannot be shown as a square grid.
    """
    with torch.no_grad():
        _, attn_maps = model(sample.unsqueeze(0), return_attn=True)

    # Keep the single batch item -> [heads, seq_len, seq_len]; move it to the
    # CPU so matplotlib can consume it regardless of the model's device.
    attn = attn_maps[layer_idx][0].detach().cpu()

    n_heads = min(n_heads, attn.shape[0])

    # Derive the patch grid from the sequence length instead of assuming 4x4.
    n_patches = attn.shape[-1] - 1
    side = int(round(n_patches ** 0.5))
    if side * side != n_patches:
        raise ValueError(
            f"cannot arrange {n_patches} patch positions as a square grid"
        )

    # Size the subplot grid from the head count (8 heads -> 2x4, figsize 20x10).
    n_cols = max(1, min(4, n_heads))
    n_rows = max(1, -(-n_heads // n_cols))  # ceiling division
    fig, axs = plt.subplots(n_rows, n_cols,
                            figsize=(5 * n_cols, 5 * n_rows), squeeze=False)

    # Hide every axis up front so unused trailing panels stay blank.
    for ax in axs.flat:
        ax.axis('off')

    for head_idx in range(n_heads):
        ax = axs[head_idx // n_cols, head_idx % n_cols]
        head_attn = attn[head_idx, 0, 1:].reshape(side, side)
        ax.imshow(head_attn.numpy(), cmap='viridis')
        ax.set_title(f'Head {head_idx+1}')
    plt.tight_layout()
    return fig

# %%
# Pick one fixed example from the dataset behind the validation loader
val_dataset = dm.val_dataloader().dataset
sample, label = val_dataset[42]

# Show it with its ground-truth class name (channels moved last for imshow)
fig, ax = plt.subplots()
ax.imshow(sample.permute(1, 2, 0))
ax.set_title(f"True: {classes[label]}")
ax.axis('off');

# %%
# Per-head attention of the first layer for that example
_ = visualize_attention_heads(sample, layer_idx=0)

# %% [markdown]
# ## 3. Grad-CAM Visualization

# %%
class ViTCAM(GradCAM):
    """Grad-CAM extractor for a layer of the hybrid ViT's CNN backbone.

    The map is built from the target layer's output activations and the
    gradient of the class score with respect to them: every channel is
    weighted by its spatially averaged gradient, the weighted channels are
    summed, and negative values are removed with a ReLU.
    """

    def __init__(self, model, target_layer):
        super().__init__(model, target_layer)
        # forward() reads these two attributes; store them explicitly rather
        # than depending on which attributes the base class exposes.
        self.model = model
        self.target_layer = target_layer

    def __call__(self, x, class_idx=None):
        """Route ``extractor(x)`` to :meth:`forward`, as the notebook calls it."""
        return self.forward(x, class_idx)

    def forward(self, x, class_idx=None):
        """Compute the class-activation map for the first item of ``x``.

        Args:
            x: Batched model input.
            class_idx: Output column to explain. ``None`` explains the sum of
                all output columns.

        Returns:
            A 2-D tensor with the spatial size of the target layer's output,
            scaled to ``[0, 1]`` (all zeros if no positive evidence exists).

        Raises:
            RuntimeError: If the target layer did not take part in the
                forward/backward pass, so nothing was captured.
        """
        activations = []
        gradients = []

        def forward_hook(module, inputs, output):
            activations.append(output.detach())

        def backward_hook(module, grad_in, grad_out):
            gradients.append(grad_out[0].detach())

        # Both hooks must be in place BEFORE the forward pass: a full backward
        # hook only attaches to graphs built while it is registered.
        fwd_handle = self.target_layer.register_forward_hook(forward_hook)
        bwd_handle = self.target_layer.register_full_backward_hook(backward_hook)
        try:
            self.model.zero_grad()
            # Gradients are required even if the caller disabled them.
            with torch.enable_grad():
                logits, _ = self.model(x, return_attn=True)
                score = logits if class_idx is None else logits[:, class_idx]
                score.sum().backward()
        finally:
            # Always detach the hooks, even when the pass above fails.
            fwd_handle.remove()
            bwd_handle.remove()

        if not activations or not gradients:
            raise RuntimeError(
                "target layer produced no activations/gradients; "
                "is it part of the model's forward pass?"
            )

        # Assumes the target layer outputs [B, C, H, W] — TODO confirm for
        # the chosen backbone layer. Only the first batch item is used.
        acts = activations[-1][0]
        grads = gradients[-1][0]

        # Channel weights = gradient averaged over the spatial dimensions.
        weights = grads.mean(dim=(1, 2), keepdim=True)
        cam = torch.relu((weights * acts).sum(dim=0))

        # Scale to [0, 1] so the map can be turned into an image directly.
        peak = cam.max()
        if peak > 0:
            cam = cam / peak

        return cam

# %%
# Initialize Grad-CAM
# NOTE(review): the "last CNN layer" label depends on how `cnn_backbone` is
# laid out, which is defined outside this notebook — confirm that index -3 is
# the intended layer.
target_layer = model.cnn_backbone[-3]  # Last CNN layer
cam_extractor = ViTCAM(model, target_layer)

# %%
# Generate CAM
# The extractor is called on a batch, so add a leading batch dimension first.
sample_tensor = sample.unsqueeze(0)
cam = cam_extractor(sample_tensor)

# Overlay on image
# NOTE(review): `T` is not imported anywhere in the visible notebook, so unless
# it is defined elsewhere this raises NameError. It presumably stands for
# `torchvision.transforms` (`import torchvision.transforms as T`) — confirm and
# add the import to the import cell.
result = overlay_mask(
    T.ToPILImage()(sample),
    T.ToPILImage()(cam.unsqueeze(0)),
    alpha=0.5
)

# Plot
# Left panel: the input image; right panel: the image blended with the CAM.
plt.figure(figsize=(10, 5))
plt.subplot(121)
plt.imshow(sample.permute(1, 2, 0))
plt.title("Original Image")
plt.axis('off')

plt.subplot(122)
plt.imshow(result)
plt.title("Grad-CAM Visualization")
plt.axis('off');

# %% [markdown]
# ## 4. Confusion Matrix Analysis

# %%
def generate_confusion_matrix(model, datamodule):
    """Compute the confusion matrix of ``model`` on the validation loader.

    Args:
        model: Network called as ``model(x)``; the prediction is the arg-max
            of its output over dimension 1. It is put into eval mode.
        datamodule: Object whose ``val_dataloader()`` yields ``(x, y)`` batches.

    Returns:
        The matrix returned by ``sklearn.metrics.confusion_matrix`` with true
        labels first and predictions second. It has one row/column per model
        output column, including classes that never occur in the labels or
        predictions, so it always lines up with a fixed list of class names.
    """
    model.eval()
    all_preds = []
    all_labels = []
    num_classes = None

    with torch.no_grad():
        for x, y in tqdm(datamodule.val_dataloader()):
            logits = model(x)
            num_classes = logits.shape[1]
            all_preds.extend(logits.argmax(dim=1).cpu().numpy())
            all_labels.extend(y.cpu().numpy())

    # Pin the label set: by default sklearn only keeps labels that occur at
    # least once, which silently shrinks the matrix when a class is missing.
    class_ids = None if num_classes is None else np.arange(num_classes)
    return confusion_matrix(all_labels, all_preds, labels=class_ids)

# %%
# Evaluate once, then draw the counts as an annotated heat map
cm = generate_confusion_matrix(model, dm)

fig, ax = plt.subplots(figsize=(12, 10))
sns.heatmap(
    cm,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=classes,
    yticklabels=classes,
    ax=ax,
)
ax.set_xlabel('Predicted')
ax.set_ylabel('True')
ax.set_title('Confusion Matrix');

# %% [markdown]
# ## 5. Attention Pattern Evolution

# %%
def plot_attention_evolution(sample):
    """Plot the head-averaged attention of position 0 for every layer.

    Runs the module-level ``model`` on a single image. For each layer, the
    attention row of sequence position 0 (assumed to be the CLS token) over
    positions ``1:`` is averaged across heads and shown as a square grid.

    Args:
        sample: One un-batched input tensor; a batch dimension is added here.

    Raises:
        ValueError: If the number of positions after position 0 is not a
            perfect square and therefore cannot be shown as a square grid.
    """
    with torch.no_grad():
        _, attn_maps = model(sample.unsqueeze(0), return_attn=True)

    # Size the figure from the number of layers the model actually returned
    # (6 layers -> 2x3 panels, figsize 20x12).
    n_layers = len(attn_maps)
    n_cols = max(1, min(3, n_layers))
    n_rows = max(1, -(-n_layers // n_cols))  # ceiling division
    fig, axs = plt.subplots(n_rows, n_cols,
                            figsize=(20, 6 * n_rows), squeeze=False)

    # Hide every axis up front so unused trailing panels stay blank.
    for ax in axs.flat:
        ax.axis('off')

    for layer_idx, layer_maps in enumerate(attn_maps):
        # First batch item, all heads, row 0, columns 1: -> mean over heads.
        layer_attn = layer_maps[0, :, 0, 1:].mean(0).detach().cpu()

        # Derive the patch grid from the sequence length instead of assuming 4x4.
        n_patches = layer_attn.numel()
        side = int(round(n_patches ** 0.5))
        if side * side != n_patches:
            raise ValueError(
                f"cannot arrange {n_patches} patch positions as a square grid"
            )

        ax = axs[layer_idx // n_cols, layer_idx % n_cols]
        ax.imshow(layer_attn.reshape(side, side).numpy(), cmap='viridis')
        ax.set_title(f'Layer {layer_idx+1}')
    plt.tight_layout()

# %%
# Layer-by-layer view for the validation sample selected earlier
plot_attention_evolution(sample)

# %% [markdown]
# ## 6. t-SNE Feature Visualization

# %%
def visualize_tsne(model, datamodule, n_samples=1000, random_state=42):
    """Project per-sample model features to 2-D with t-SNE and plot them.

    Args:
        model: Network called as ``model(x)``. If its output has more than two
            dimensions (one vector per token), the vector at position 0 is
            used; a 2-D output is used as the feature vector directly.
        datamodule: Object whose ``val_dataloader()`` yields ``(x, y)`` batches.
        n_samples: Maximum number of validation samples to embed.
        random_state: Seed passed to t-SNE so the layout is reproducible.

    The points are coloured by label using the module-level ``classes`` names.
    """
    features = []
    labels = []
    n_collected = 0
    is_token_sequence = False

    with torch.no_grad():
        for x, y in datamodule.val_dataloader():
            out = model(x)
            # Indexing [:, 0] on a [B, dim] output would keep a single scalar
            # per sample, which t-SNE cannot embed; only slice token outputs.
            is_token_sequence = out.dim() > 2
            feats = out[:, 0] if is_token_sequence else out
            features.append(feats.reshape(feats.shape[0], -1).cpu())
            labels.append(y.cpu())

            # Stop on the number of samples actually gathered rather than on
            # an assumed batch size.
            n_collected += feats.shape[0]
            if n_collected >= n_samples:
                break

    features = torch.cat(features)[:n_samples].numpy()
    labels = torch.cat(labels)[:n_samples].numpy()

    # t-SNE requires perplexity < number of samples.
    perplexity = min(30, max(1, len(features) - 1))
    tsne = TSNE(n_components=2, perplexity=perplexity,
                random_state=random_state)
    embeddings = tsne.fit_transform(features)

    # Plot
    plt.figure(figsize=(15, 12))
    # Fix the colour range so a class keeps its colour even when some classes
    # are missing from the sampled subset.
    scatter = plt.scatter(embeddings[:, 0], embeddings[:, 1],
                          c=labels, cmap='tab10', alpha=0.6,
                          vmin=0, vmax=len(classes) - 1)

    # One legend entry per label value that is present, each paired with its
    # own class name (num=None -> one handle per unique value, in sorted order).
    handles, _ = scatter.legend_elements(num=None)
    present = np.unique(labels)
    plt.legend(handles=handles, labels=[classes[int(i)] for i in present])

    if is_token_sequence:
        plt.title('t-SNE Visualization of CLS Token Embeddings');
    else:
        plt.title('t-SNE Visualization of Model Output Features');
# %%
# Embed validation-set features with the function's default sample budget
visualize_tsne(model, dm)

# %% [markdown]
# ## 7. Quantitative Attention Analysis

# %%
def calculate_attention_entropy(model, datamodule):
    """Average entropy of position 0's attention over the other positions.

    For every layer and head, the attention row of sequence position 0 over
    positions ``1:`` is turned into a distribution with a softmax, its Shannon
    entropy (in nats) is computed, and the result is averaged over all
    validation samples.

    Args:
        model: Network called as ``model(x, return_attn=True)`` that returns a
            pair whose second element holds one attention tensor per layer,
            indexed as ``[batch, head, query, key]``.
        datamodule: Object whose ``val_dataloader()`` yields ``(x, y)`` batches.

    Returns:
        A CPU float tensor of shape ``[layers, heads]``.

    Raises:
        ValueError: If the validation loader yields no batches.
    """
    entropy_sum = None  # running per-sample sum, [layers, heads]
    n_seen = 0

    with torch.no_grad():
        for x, _ in tqdm(datamodule.val_dataloader()):
            _, attn_maps = model(x, return_attn=True)

            per_layer = []
            for layer_maps in attn_maps:
                attn = layer_maps[:, :, 0, 1:]  # [B, heads, seq]
                # NOTE(review): if the model already returns post-softmax
                # attention weights, this second softmax flattens them and
                # inflates the entropy; renormalising the slice
                # (attn / attn.sum(-1, keepdim=True)) would then be the right
                # choice — confirm against HybridViT.
                prob = attn.float().softmax(dim=-1)
                # Clamp inside the log so a zero probability contributes 0
                # instead of 0 * -inf = NaN.
                tiny = torch.finfo(prob.dtype).tiny
                entropy = (-prob * prob.clamp_min(tiny).log()).sum(-1)  # [B, heads]
                per_layer.append(entropy.sum(0).cpu())

            batch_sum = torch.stack(per_layer)  # [layers, heads]
            entropy_sum = batch_sum if entropy_sum is None else entropy_sum + batch_sum
            n_seen += x.shape[0]

    if entropy_sum is None:
        raise ValueError("validation dataloader produced no batches")

    # Divide by the sample count, not the batch count, so a shorter final
    # batch is not over-weighted.
    return entropy_sum / n_seen

# %%
# Mean attention entropy for every (layer, head) pair over the validation set
entropies = calculate_attention_entropy(model, dm)

# Derive the tick labels from the result's shape so they always match the
# model's actual layer and head counts.
n_layers, n_heads = entropies.shape

fig, ax = plt.subplots(figsize=(12, 6))
sns.heatmap(entropies.cpu().numpy(), annot=True, fmt=".2f",
            xticklabels=[f"Head {i+1}" for i in range(n_heads)],
            yticklabels=[f"Layer {i+1}" for i in range(n_layers)],
            ax=ax)
ax.set_title("Attention Entropy Across Layers/Heads (nats)");


ModuleNotFoundError: No module named 'models'