In [1]:
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from sklearn.metrics import confusion_matrix, classification_report
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from vit_pytorch import ViT
from vit_pytorch.crossformer import CrossFormer
from vit_pytorch.mobile_vit import MobileViT

In [None]:
# Device configuration: prefer CUDA, then Apple-silicon MPS, and fall back to
# the CPU so the notebook still runs on machines without any accelerator.
# (Previously the fallback was "mps" unconditionally, which yields an unusable
# device on hosts that have neither CUDA nor MPS.)
_mps_backend = getattr(torch.backends, "mps", None)  # absent on older torch builds
if torch.cuda.is_available():
    device = torch.device("cuda")
elif _mps_backend is not None and _mps_backend.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

In [2]:
# TensorBoard setup: create the log directory if it is missing, then open a
# SummaryWriter that records into it.
log_dir = "./logs/vit_cifar10"
os.makedirs(name=log_dir, exist_ok=True)
writer = SummaryWriter(log_dir=log_dir)

In [3]:
# Hyperparameters

# -- data --
image_size = 32        # CIFAR-10 images are 32 pixels per side
num_classes = 10       # number of target classes
batch_size = 128       # examples per mini-batch

# -- ViT architecture --
patch_size = 4         # side length of each ViT patch
dim = 128              # embedding dimension
depth = 6              # number of transformer layers
heads = 8              # number of attention heads
mlp_dim = 256          # width of the MLP layer
dropout = 0.1          # dropout probability

# -- optimisation --
num_epochs = 50        # passes over the training set
learning_rate = 3e-4   # optimiser step size

In [4]:
# Preprocessing pipeline: resize to the configured input size, convert to a
# tensor, then normalise each of the three channels with mean 0.5 / std 0.5.
channel_mean = (0.5, 0.5, 0.5)
channel_std = (0.5, 0.5, 0.5)
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize(channel_mean, channel_std),
])

# The two CIFAR-10 splits share every argument except the `train` flag.
cifar_kwargs = dict(root='./data', download=True, transform=transform)
train_dataset = torchvision.datasets.CIFAR10(train=True, **cifar_kwargs)
test_dataset = torchvision.datasets.CIFAR10(train=False, **cifar_kwargs)

# Only the training loader shuffles; both use the same batch size and workers.
loader_kwargs = dict(batch_size=batch_size, num_workers=2)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

Files already downloaded and verified
Files already downloaded and verified


In [5]:
# Shrink both splits for quick experiments: keep the first 5000 training
# examples and the first 1000 test examples. The truncation is applied to the
# dataset object held by each loader, so the loaders see the smaller arrays.
for loader, keep in ((train_loader, 5000), (test_loader, 1000)):
    subset = loader.dataset
    subset.data = subset.data[:keep]
    subset.targets = subset.targets[:keep]

print(f"Number of training examples: {len(train_dataset)} | Number of testing examples: {len(test_dataset)}")

Number of training examples: 5000 | Number of testing examples: 1000


In [9]:
# Build the model.
#
# The active model is a CrossFormer. A plain ViT driven by the hyperparameters
# defined earlier (dim, depth, heads, mlp_dim, dropout) is kept below as a
# commented-out alternative.
# model = ViT(
#     image_size=image_size,
#     patch_size=patch_size,
#     num_classes=num_classes,
#     dim=dim,
#     depth=depth,
#     heads=heads,
#     mlp_dim=mlp_dim,
#     dropout=dropout,
#     emb_dropout=dropout
# ).to(device)

model = CrossFormer(
    num_classes=num_classes,          # reuse the hyperparameter instead of a hard-coded 10
    dim=(32, 64, 128, 256),           # reduced dimensions for the smaller dataset
    depth=(1, 2, 4, 1),               # shallower model
    global_window_size=(4, 2, 1, 1),  # adjusted for CIFAR-10
    # NOTE(review): the previous comment called 1 the "default" local window
    # size; the library default is believed to be larger - confirm.
    local_window_size=1,
).to(device)


# Model summary
print(model)

CrossFormer(
  (layers): ModuleList(
    (0): ModuleList(
      (0): CrossEmbedLayer(
        (convs): ModuleList(
          (0): Conv2d(3, 16, kernel_size=(4, 4), stride=(4, 4))
          (1): Conv2d(3, 8, kernel_size=(8, 8), stride=(4, 4), padding=(2, 2))
          (2): Conv2d(3, 4, kernel_size=(16, 16), stride=(4, 4), padding=(6, 6))
          (3): Conv2d(3, 4, kernel_size=(32, 32), stride=(4, 4), padding=(14, 14))
        )
      )
      (1): Transformer(
        (layers): ModuleList(
          (0): ModuleList(
            (0): Attention(
              (norm): LayerNorm()
              (dropout): Dropout(p=0.0, inplace=False)
              (to_qkv): Conv2d(32, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (to_out): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))
              (dpb): Sequential(
                (0): Linear(in_features=2, out_features=8, bias=True)
                (1): LayerNorm((8,), eps=1e-05, elementwise_affine=True)
                (2): ReLU(

In [10]:
# Training objective (cross-entropy loss) and the Adam optimiser that updates
# every model parameter with the configured learning rate.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

In [11]:
# Per-epoch metrics, appended to as training progresses.
train_loss_history = []
test_accuracy_history = []

print("Starting Training...")
for epoch in range(num_epochs):
    step = epoch + 1  # 1-based epoch index used as the TensorBoard step

    # ---- Training pass: one optimiser step per mini-batch ----
    model.train()
    batch_losses = []
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        loss = criterion(model(images), labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())

    # Mean of the per-batch losses for this epoch.
    train_loss = sum(batch_losses) / len(train_loader)
    train_loss_history.append(train_loss)
    writer.add_scalar("Loss/Train", train_loss, step)

    # ---- Evaluation pass: top-1 accuracy on the test loader ----
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)

            predicted = model(images).max(dim=1).indices
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    test_accuracy_history.append(accuracy)
    writer.add_scalar("Accuracy/Test", accuracy, step)

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.4f}, Test Accuracy: {accuracy:.2f}%")

Starting Training...
Epoch [1/50], Loss: 1.9820, Test Accuracy: 31.50%
Epoch [2/50], Loss: 1.7032, Test Accuracy: 37.90%
Epoch [3/50], Loss: 1.5820, Test Accuracy: 40.00%
Epoch [4/50], Loss: 1.4552, Test Accuracy: 41.20%
Epoch [5/50], Loss: 1.3348, Test Accuracy: 41.30%
Epoch [6/50], Loss: 1.2479, Test Accuracy: 41.40%
Epoch [7/50], Loss: 1.1090, Test Accuracy: 42.40%
Epoch [8/50], Loss: 1.0702, Test Accuracy: 41.00%
Epoch [9/50], Loss: 0.9913, Test Accuracy: 43.80%


KeyboardInterrupt: 

In [None]:
print("Evaluating on Test Data...")
all_labels = []
all_preds = []

# The training cell can be interrupted mid-epoch, which leaves the model in
# train mode; switch to eval mode explicitly before predicting.
model.eval()
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)

        all_labels.extend(labels.cpu().tolist())
        all_preds.extend(predicted.cpu().tolist())

class_names = test_dataset.classes
# Pin the label set so the matrix and the report always have one row per class,
# even if some class happens to be missing from the evaluated subset.
class_ids = list(range(len(class_names)))

# Generate confusion matrix
cm = confusion_matrix(all_labels, all_preds, labels=class_ids)

# Plot the confusion matrix on an explicit figure so the same figure can be
# both logged to TensorBoard and displayed inline.
fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names, yticklabels=class_names, ax=ax)
ax.set_title("Confusion Matrix")
ax.set_xlabel("Predicted Labels")
ax.set_ylabel("True Labels")
fig.tight_layout()

# close=False: add_figure closes the figure by default, which would leave
# nothing for plt.show() to display.
writer.add_figure("Confusion Matrix", fig, close=False)
plt.show()

# Print classification report
print("Classification Report:")
print(classification_report(all_labels, all_preds, labels=class_ids, target_names=class_names))

In [None]:
# Visualize test images with predictions and ground truth
import random

def visualize_test_images(model, test_loader, class_names, num_images=16):
    """
    Show random images from the first batch of ``test_loader`` with their
    true and predicted class names, and log the figure to TensorBoard.

    Args:
        model: Classifier; it is switched to eval mode.
        test_loader: DataLoader yielding ``(images, labels)`` batches.
        class_names: Sequence mapping a class index to its display name.
        num_images: Maximum number of images to show. If the batch holds fewer
            images, all of them are shown instead of raising.

    Relies on the notebook-level ``device`` and ``writer`` objects.
    """
    model.eval()  # Ensure the model is in evaluation mode
    images, labels = next(iter(test_loader))  # Only the first batch is used
    images, labels = images.to(device), labels.to(device)

    with torch.no_grad():
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)

    # random.sample raises if asked for more items than exist, so cap the count.
    num_shown = min(num_images, len(images))
    if num_shown < 1:
        return
    indices = random.sample(range(len(images)), num_shown)

    # Near-square grid sized from the number of images (4x4 for the default 16).
    cols = int(np.ceil(np.sqrt(num_shown)))
    rows = -(-num_shown // cols)  # ceiling division

    fig = plt.figure(figsize=(12, 12))
    for i, idx in enumerate(indices):
        image = images[idx].cpu().numpy().transpose(1, 2, 0)  # CHW -> HWC
        # Undo the (x - 0.5) / 0.5 normalisation; clip guards against rounding.
        image = np.clip(image * 0.5 + 0.5, 0.0, 1.0)
        label = labels[idx].item()
        prediction = predicted[idx].item()

        ax = fig.add_subplot(rows, cols, i + 1)
        ax.imshow(image)
        ax.set_title(f"True: {class_names[label]}\nPred: {class_names[prediction]}")
        ax.axis("off")

    fig.tight_layout()

    # Log before showing and keep the figure open: calling plt.gcf() after
    # plt.show() would hand TensorBoard a new, empty figure.
    writer.add_figure("Test Images", fig, close=False)
    plt.show()

# Render sixteen randomly chosen test images with their true and predicted classes
visualize_test_images(
    model=model,
    test_loader=test_loader,
    class_names=class_names,
    num_images=16,
)

In [None]:
def visualize_attention(model, data_loader, num_layers=6, patch_size=4):
    """
    Visualizes attention maps for a random image from the first batch of
    ``data_loader``.

    Attention is captured with forward hooks rather than by reading a model
    attribute. Every sub-module whose class is named ``Attention`` is probed
    through its ``dropout`` (or, failing that, ``attend``) child.
    NOTE(review): this assumes that child's output is the softmax attention
    tensor shaped (batch_or_windows, heads, tokens, tokens), which is believed
    to hold for the vit_pytorch ViT and CrossFormer - confirm against the
    installed version.

    Args:
        model: Trained transformer model; it is switched to eval mode.
        data_loader: DataLoader for test data.
        num_layers: Maximum number of attention modules to visualize.
        patch_size: Unused; kept for backward compatibility. The token grid is
            inferred from the captured attention shape.

    Raises:
        ValueError: If no attention module can be located in ``model``.

    Note:
        A later cell defines another function with this same name and a
        different signature; whichever definition ran last is the one in effect.
    """
    # Locate the probe modules, in model.modules() order.
    probes = []
    for module in model.modules():
        if type(module).__name__ != "Attention":
            continue
        probe = getattr(module, "dropout", None)
        if not isinstance(probe, torch.nn.Module):
            probe = getattr(module, "attend", None)
        if isinstance(probe, torch.nn.Module):
            probes.append(probe)
    probes = probes[:num_layers]
    if not probes:
        raise ValueError("No attention modules found to hook in the given model.")

    attention_maps = []

    def capture(module, inputs, output):
        attention_maps.append(output.detach().float().cpu())

    def to_panel(attn, height, width):
        """Turn one captured attention tensor into a 2-D array for imshow."""
        attn = attn[0].mean(dim=0)  # first sample/window, averaged over heads
        tokens = attn.shape[-1]
        for cls_tokens in (1, 0):
            side = int(round((tokens - cls_tokens) ** 0.5))
            if side > 0 and side * side == tokens - cls_tokens:
                # With a class token: its attention over the patches.
                # Without one: mean attention each token receives.
                scores = attn[0, 1:] if cls_tokens else attn.mean(dim=0)
                grid = scores.reshape(1, 1, side, side)
                return torch.nn.functional.interpolate(
                    grid, size=(height, width), mode="bilinear", align_corners=False
                )[0, 0].numpy()
        # Token count does not form a square grid: show the raw matrix.
        return attn.numpy()

    handles = [probe.register_forward_hook(capture) for probe in probes]

    # Select a random image from the first batch of the loader
    images, labels = next(iter(data_loader))
    random_idx = random.randint(0, images.size(0) - 1)
    image = images[random_idx:random_idx + 1].to(device)

    # Pass the image through the model; always detach the hooks afterwards so
    # a failing forward pass cannot leave them registered.
    model.eval()
    try:
        with torch.no_grad():
            _ = model(image)
    finally:
        for handle in handles:
            handle.remove()

    # One panel for the image plus one per captured attention tensor.
    panels = len(attention_maps) + 1
    fig = plt.figure(figsize=(12, 6))
    ax = fig.add_subplot(1, panels, 1)
    ax.imshow((image[0].permute(1, 2, 0).cpu().numpy() * 0.5 + 0.5).clip(0, 1))
    ax.axis('off')
    ax.set_title("Original Image")

    for i, attention in enumerate(attention_maps):
        ax = fig.add_subplot(1, panels, i + 2)
        ax.imshow(to_panel(attention, image.size(2), image.size(3)), cmap="viridis")
        ax.axis('off')
        ax.set_title(f"Layer {i + 1}")

    fig.tight_layout()
    plt.show()

# Render the attention maps for one random test image
visualize_attention(
    model=model,
    data_loader=test_loader,
    num_layers=depth,
    patch_size=patch_size,
)

In [13]:
def get_attention_maps(model, images):
    """
    Captures attention tensors from a transformer model with forward hooks.

    Every sub-module whose class is named ``Attention`` is probed through its
    ``dropout`` (or, failing that, ``attend``) child, and that child's output
    is recorded during one forward pass.
    NOTE(review): this assumes the probed child's output is the softmax
    attention tensor, which is believed to hold for the vit_pytorch ViT and
    CrossFormer - confirm against the installed version.

    The forward pass runs in eval mode so dropout does not alter the captured
    values; the model's previous train/eval mode is restored afterwards.

    Args:
        model: Transformer model containing ``Attention`` sub-modules.
        images: Input batch of images, already on the model's device.
    Returns:
        List of captured tensors, one per hook invocation, in call order.
    Raises:
        ValueError: If no attention module can be located in ``model``.
    """
    # Locate the probe modules, in model.modules() order.
    probes = []
    for module in model.modules():
        if type(module).__name__ != "Attention":
            continue
        probe = getattr(module, "dropout", None)
        if not isinstance(probe, torch.nn.Module):
            probe = getattr(module, "attend", None)
        if isinstance(probe, torch.nn.Module):
            probes.append(probe)
    if not probes:
        raise ValueError("No attention modules found to hook in the given model.")

    attention_maps = []

    def hook_fn(module, inputs, output):
        attention_maps.append(output.detach())

    handles = [probe.register_forward_hook(hook_fn) for probe in probes]

    was_training = model.training
    model.eval()
    try:
        # Perform a forward pass to capture attention maps
        with torch.no_grad():
            _ = model(images)
    finally:
        # Always detach the hooks and restore the mode, even on failure.
        for handle in handles:
            handle.remove()
        model.train(was_training)

    return attention_maps

In [14]:
def visualize_attention(images, attention_maps, patch_size, image_size, num_heads=8):
    """
    Visualizes per-layer attention heat-maps overlaid on each image of a batch.

    For every image one figure is drawn: the original image followed by one
    panel per entry of ``attention_maps``. Each heat-map is averaged over
    heads, reduced to one score per patch, bilinearly resized to
    ``image_size`` x ``image_size`` and min-max normalised before being drawn
    semi-transparently over the image.

    NOTE(review): assumes each attention tensor is shaped
    (batch, heads, tokens, tokens) with its first axis aligned to ``images``;
    windowed models that fold windows into the first axis will not be
    image-aligned - confirm for the model in use.

    Args:
        images: Batch of input images (torch.Tensor, channels-first),
            normalised with mean 0.5 / std 0.5.
        attention_maps: List of attention tensors, one per layer.
        patch_size: Unused; kept for backward compatibility. The patch grid is
            inferred from the number of tokens.
        image_size: Side length the heat-maps are resized to.
        num_heads: Unused; kept for backward compatibility.

    Raises:
        ValueError: If a layer's token count does not form a square patch
            grid, with or without one leading class token.

    Note:
        An earlier cell defines another function with this same name and a
        different signature; this definition replaces it once this cell runs.
    """
    batch_size = images.size(0)
    num_layers = len(attention_maps)

    # Rescale images to [0, 1] for visualization
    images = (images.detach().permute(0, 2, 3, 1).cpu().numpy() * 0.5 + 0.5).clip(0.0, 1.0)

    def layer_heatmap(layer_attentions, idx):
        """Return a normalised (image_size, image_size) heat-map array."""
        attention = layer_attentions[idx].detach().float().cpu()
        attention = attention.mean(dim=0)  # average over heads -> (tokens, tokens)
        tokens = attention.shape[-1]
        for cls_tokens in (1, 0):
            side = int(round((tokens - cls_tokens) ** 0.5))
            if side > 0 and side * side == tokens - cls_tokens:
                # With a class token: its attention over the patches.
                # Without one: mean attention each patch receives.
                scores = attention[0, 1:] if cls_tokens else attention.mean(dim=0)
                break
        else:
            raise ValueError(
                f"Cannot arrange {tokens} attention tokens into a square patch grid."
            )

        resized = torch.nn.functional.interpolate(
            scores.reshape(1, 1, side, side),
            size=(image_size, image_size),
            mode="bilinear",
            align_corners=False,
        )[0, 0].numpy()

        # Min-max normalise; a constant map would otherwise divide by zero.
        low, high = resized.min(), resized.max()
        if high > low:
            return (resized - low) / (high - low)
        return np.zeros_like(resized)

    for idx in range(batch_size):
        fig = plt.figure(figsize=(15, 5))
        fig.suptitle(f"Image {idx + 1}: Attention Maps", fontsize=16)

        height, width = images[idx].shape[:2]
        # Draw the heat-map over exactly the same extent as the image.
        extent = (-0.5, width - 0.5, height - 0.5, -0.5)

        # Plot the original image
        ax = fig.add_subplot(1, num_layers + 1, 1)
        ax.imshow(images[idx])
        ax.axis("off")
        ax.set_title("Original Image")

        # Loop through each layer
        for layer_idx, layer_attentions in enumerate(attention_maps):
            heatmap = layer_heatmap(layer_attentions, idx)

            # Overlay: image underneath, half-transparent heat-map on top.
            ax = fig.add_subplot(1, num_layers + 1, layer_idx + 2)
            ax.imshow(images[idx], extent=extent)
            ax.imshow(heatmap, cmap="jet", alpha=0.5, vmin=0.0, vmax=1.0, extent=extent)
            ax.axis("off")
            ax.set_title(f"Layer {layer_idx + 1}")

        fig.tight_layout()
        plt.show()
        plt.close(fig)  # avoid accumulating open figures across the batch

In [None]:
# Rebuild the CIFAR-10 test split with the usual preprocessing and wrap it in
# a small, unshuffled loader (four images per batch).
# NOTE(review): this rebinds the notebook-level `transform`, `test_dataset`
# and `test_loader`; cells re-run afterwards will see these new objects.
normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    normalize,
])

test_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False)

# Take the first batch and move the images to the compute device
test_images, test_labels = next(iter(test_loader))
test_images = test_images.to(device)

# Capture the attention maps for that batch, then draw them
attention_maps = get_attention_maps(model, test_images)
visualize_attention(
    test_images,
    attention_maps,
    patch_size=patch_size,
    image_size=image_size,
)