# Module 3: Vision Transformers in PyTorch
---

In [None]:
# Import necessary libraries
import os
import time
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from tqdm import tqdm

# Report the framework version, then choose the GPU when one is visible to PyTorch
print(f"PyTorch version: {torch.__version__}")
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Seed the global PyTorch and NumPy generators so repeated runs draw the same numbers
torch.manual_seed(42)
np.random.seed(42)

In [None]:
# Define paths and parameters
# Root directory passed to torchvision's ImageFolder, which treats each
# sub-directory as one class.
dataset_path = './images_dataSAT/'
# Side length (pixels) every image is resized to. The CNN backbone halves the
# resolution four times (four 2x2 max-pools), so the number of tokens the
# transformer sees depends on this value; keep the model's expected
# feature-map size in sync when changing it.
IMG_SIZE = 64
# Samples per batch for both the training and validation DataLoaders.
BATCH_SIZE = 32
# Size of the classifier output; passed as num_classes to CNNViTHybrid.
# Presumably equals the number of class folders under dataset_path - verify.
NUM_CLASSES = 2
# Adam step size shared by both trained models.
LEARNING_RATE = 0.001

## Task 1: Create the `train_transform` pipeline for the training dataset.

In [None]:
# Task 1: Create train_transform
# Training-time pipeline: resize, random geometric/colour augmentation,
# tensor conversion, then per-channel standardisation with the commonly used
# ImageNet mean/std values.
train_transform = transforms.Compose([
    transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),       # fixed square input
    transforms.RandomHorizontalFlip(),                  # default p is 0.5
    transforms.RandomVerticalFlip(0.2),                 # p = 0.2
    transforms.RandomRotation(45),                      # angle sampled in [-45, 45] degrees
    transforms.ColorJitter(0.2, 0.2, 0.2, 0.1),         # brightness, contrast, saturation, hue
    transforms.RandomAffine(0, translate=(0.1, 0.1)),   # translation only, up to 10% per axis
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],         # per-channel mean
                         [0.229, 0.224, 0.225]),        # per-channel std
])

print("Training transformation pipeline created:", train_transform, sep="\n")

## Task 2: Create the `val_transform` pipeline for the validation dataset.

In [None]:
# Task 2: Create val_transform
# Deterministic evaluation pipeline: the same resize and normalisation as the
# training pipeline, with no random augmentation.
val_transform = transforms.Compose([
    transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],   # per-channel mean
                         [0.229, 0.224, 0.225]),  # per-channel std
])

print("Validation transformation pipeline created:", val_transform, sep="\n")

## Task 3: Create the DataLoaders `train_loader` and `val_loader` from `train_dataset` and `val_dataset`.

In [None]:
# Task 3: Create DataLoaders

# Two views of the same image folder: identical files and labels, but one
# applies the augmenting transform and the other the deterministic one.
train_dataset_full = datasets.ImageFolder(dataset_path, transform=train_transform)
val_dataset_full = datasets.ImageFolder(dataset_path, transform=val_transform)

# 80/20 split sizes; the validation share is truncated towards zero
total_size = len(train_dataset_full)
val_size = int(0.2 * total_size)
train_size = total_size - val_size

# A dedicated seeded generator makes the index split repeatable across runs
generator = torch.Generator().manual_seed(42)
train_indices, val_indices = random_split(
    range(total_size), [train_size, val_size], generator=generator
)

# Apply the shared index split to the dataset view with the matching transform
train_dataset = torch.utils.data.Subset(train_dataset_full, train_indices.indices)
val_dataset = torch.utils.data.Subset(val_dataset_full, val_indices.indices)

# Training batches are reshuffled each epoch and an incomplete final batch is
# dropped; validation keeps a fixed order and every sample.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=0, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False,
                        num_workers=0)

print(
    f"Classes: {train_dataset_full.classes}",
    f"Class to idx: {train_dataset_full.class_to_idx}",
    f"Training samples: {len(train_dataset)}",
    f"Validation samples: {len(val_dataset)}",
    f"Training batches: {len(train_loader)}",
    f"Validation batches: {len(val_loader)}",
    sep="\n",
)

# Verify a batch
sample_images, sample_labels = next(iter(train_loader))
print(f"\nSample batch - Images shape: {sample_images.shape}, Labels shape: {sample_labels.shape}")

In [None]:
# ============================================================
# Define CNN Feature Extractor
# ============================================================
class CNNFeatureExtractor(nn.Module):
    """Four-stage convolutional backbone.

    Every stage is Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU -> MaxPool2d(2, 2),
    so each stage keeps the spatial size through the convolution and halves it
    in the pooling step. Channel widths go 3 -> 32 -> 64 -> 128 -> 256.
    """

    def __init__(self):
        super().__init__()
        widths = [3, 32, 64, 128, 256]
        stages = []
        # Layers are appended in stage order, so the Sequential indices (and
        # therefore the state_dict keys) run 0..15 across the four stages.
        for in_ch, out_ch in zip(widths[:-1], widths[1:]):
            stages.extend([
                nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2, 2),
            ])
        self.features = nn.Sequential(*stages)

    def forward(self, x):
        """Return the feature map produced by the four stages for input batch ``x``."""
        feature_map = self.features(x)
        return feature_map

print("CNNFeatureExtractor defined.")

In [None]:
# ============================================================
# Define Transformer Encoder Block
# ============================================================
class TransformerEncoderBlock(nn.Module):
    """Pre-norm Transformer encoder layer.

    LayerNorm is applied before each sub-layer (multi-head self-attention, then
    a two-layer GELU MLP) and each sub-layer's output is added back to its
    input through a residual connection.

    Args:
        embed_dim: Width of the token embeddings (input and output).
        num_heads: Number of attention heads.
        mlp_dim: Hidden width of the feed-forward MLP.
        dropout: Dropout probability used inside attention, inside the MLP and
            on the attention output.
    """

    def __init__(self, embed_dim, num_heads, mlp_dim, dropout=0.1):
        super().__init__()
        # Attention sub-layer (batch_first: tensors are passed as batch, sequence, feature)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(
            embed_dim, num_heads, dropout=dropout, batch_first=True
        )
        # Feed-forward sub-layer
        self.norm2 = nn.LayerNorm(embed_dim)
        feed_forward = [
            nn.Linear(embed_dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, embed_dim),
            nn.Dropout(dropout),
        ]
        self.mlp = nn.Sequential(*feed_forward)
        # Dropout applied to the attention output before the residual add
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply attention and MLP sub-layers to ``x``; the output has the same shape."""
        normed = self.norm1(x)
        attended = self.attn(normed, normed, normed)[0]  # drop the attention weights
        after_attn = x + self.dropout(attended)
        return after_attn + self.mlp(self.norm2(after_attn))

print("TransformerEncoderBlock defined.")

In [None]:
# ============================================================
# Define CNN-ViT Hybrid Model
# ============================================================
class CNNViTHybrid(nn.Module):
    """CNN backbone followed by a stack of Transformer encoder blocks.

    The CNN feature map is flattened into one token per spatial location, each
    token is linearly projected to ``embed_dim``, a learned positional
    embedding is added, the tokens pass through ``depth`` encoder blocks, and
    the mean over tokens is fed to a small MLP classifier.

    Args:
        num_classes: Number of output logits.
        embed_dim: Token embedding width used by the transformer.
        num_heads: Attention heads per encoder block.
        depth: Number of encoder blocks.
        mlp_dim: Hidden width of each block's feed-forward MLP.
        dropout: Dropout probability used throughout.
        img_size: Side length of the (square) input images. Determines how many
            tokens the positional embedding is built for. Defaults to 64, which
            gives the 4x4 = 16 tokens this model was originally written for.

    Raises:
        ValueError: If ``img_size`` is too small to leave at least one spatial
            location after the CNN, or (in ``forward``) if the input produces a
            different number of tokens than the positional embedding expects.
    """

    # CNNFeatureExtractor halves the resolution four times (four 2x2 max-pools).
    _CNN_DOWNSAMPLE = 16

    def __init__(self, num_classes=2, embed_dim=64, num_heads=4,
                 depth=4, mlp_dim=128, dropout=0.1, img_size=64):
        super().__init__()

        grid = img_size // self._CNN_DOWNSAMPLE
        if grid < 1:
            raise ValueError(
                f"img_size={img_size} is too small: the CNN downsamples by "
                f"{self._CNN_DOWNSAMPLE}x and would leave no spatial locations"
            )

        # CNN Feature Extractor
        self.cnn = CNNFeatureExtractor()

        # After CNN: (batch, 256, grid, grid) -> seq_length=grid*grid, feature_dim=256
        self.seq_length = grid * grid
        self.cnn_feature_dim = 256

        # Linear projection of each spatial feature vector to the token width
        self.projection = nn.Linear(self.cnn_feature_dim, embed_dim)

        # Learned positional embedding, one vector per token
        self.pos_embedding = nn.Parameter(torch.randn(1, self.seq_length, embed_dim))
        self.pos_dropout = nn.Dropout(dropout)

        # Transformer encoder blocks
        self.transformer_blocks = nn.ModuleList([
            TransformerEncoderBlock(embed_dim, num_heads, mlp_dim, dropout)
            for _ in range(depth)
        ])

        # Final normalization
        self.norm = nn.LayerNorm(embed_dim)

        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(embed_dim, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for an image batch ``x``."""
        # CNN features: (batch, 256, H', W')
        x = self.cnn(x)

        # Reshape: (batch, 256, H', W') -> (batch, H'*W', 256)
        x = x.flatten(2).transpose(1, 2)

        # Fail loudly on a token-count mismatch. Without this check a single-token
        # feature map would silently broadcast against the positional embedding
        # instead of raising.
        if x.size(1) != self.seq_length:
            raise ValueError(
                f"CNN produced {x.size(1)} tokens but the positional embedding was "
                f"built for {self.seq_length}; construct the model with the matching img_size"
            )

        # Project to embed_dim: (batch, seq_length, embed_dim)
        x = self.projection(x)

        # Add positional embedding
        x = x + self.pos_embedding
        x = self.pos_dropout(x)

        # Transformer blocks
        for block in self.transformer_blocks:
            x = block(x)

        # Normalize
        x = self.norm(x)

        # Global average pooling over tokens: (batch, embed_dim)
        x = x.mean(dim=1)

        # Classify
        x = self.classifier(x)
        return x

print("CNNViTHybrid model defined.")

In [None]:
# ============================================================
# Define training function
# ============================================================
def train_model(model, train_loader, val_loader, criterion, optimizer,
                epochs, device, model_name='model'):
    """
    Train the model and return training history and total training time.

    Each epoch runs one optimisation pass over ``train_loader`` followed by a
    gradient-free evaluation pass over ``val_loader``. Whenever the validation
    accuracy exceeds the best value seen so far (starting from 0.0), the
    model's ``state_dict`` is written to ``best_{model_name}.pth`` in the
    current working directory.

    Args:
        model: Module mapping an image batch to per-class logits.
        train_loader: Iterable of ``(images, labels)`` batches used for optimisation.
        val_loader: Iterable of ``(images, labels)`` batches used for evaluation.
        criterion: Loss callable ``criterion(outputs, labels)``. Epoch losses are
            reported as per-sample averages, which assumes the criterion returns
            the mean over the batch (the default for ``nn.CrossEntropyLoss``).
        optimizer: Optimizer already bound to ``model``'s parameters.
        epochs: Number of passes over ``train_loader``.
        device: Device each batch is moved to before the forward pass.
        model_name: Label used in progress output and in the checkpoint filename.

    Returns:
        dict with per-epoch lists ``train_loss``, ``train_acc``, ``val_loss``,
        ``val_acc``, ``epoch_times`` (seconds) and the scalar ``total_time``
        (seconds).

    Raises:
        ValueError: If either loader yields no samples during an epoch.
    """
    history = {
        'train_loss': [], 'train_acc': [],
        'val_loss': [], 'val_acc': [],
        'epoch_times': []
    }

    best_val_acc = 0.0
    total_start_time = time.time()

    for epoch in range(epochs):
        epoch_start = time.time()

        # ===== Training =====
        model.train()
        train_loss_sum = 0.0
        train_correct = 0
        train_total = 0

        train_bar = tqdm(train_loader, desc=f'{model_name} Epoch {epoch+1}/{epochs} [Train]')
        for images, labels in train_bar:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Weight each batch's mean loss by its size so the epoch figure is a
            # true per-sample average even when batches differ in size.
            batch_n = labels.size(0)
            train_loss_sum += loss.item() * batch_n
            train_total += batch_n
            train_correct += (outputs.argmax(dim=1) == labels).sum().item()

            train_bar.set_postfix(loss=loss.item(), acc=train_correct/train_total)

        if train_total == 0:
            raise ValueError("train_loader yielded no samples; cannot compute training metrics")

        epoch_train_loss = train_loss_sum / train_total
        epoch_train_acc = train_correct / train_total

        # ===== Validation =====
        model.eval()
        val_loss_sum = 0.0
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)

                # The final validation batch may be smaller than the others, so
                # the same size weighting is applied here.
                batch_n = labels.size(0)
                val_loss_sum += loss.item() * batch_n
                val_total += batch_n
                val_correct += (outputs.argmax(dim=1) == labels).sum().item()

        if val_total == 0:
            raise ValueError("val_loader yielded no samples; cannot compute validation metrics")

        epoch_val_loss = val_loss_sum / val_total
        epoch_val_acc = val_correct / val_total

        epoch_time = time.time() - epoch_start

        # Save history
        history['train_loss'].append(epoch_train_loss)
        history['train_acc'].append(epoch_train_acc)
        history['val_loss'].append(epoch_val_loss)
        history['val_acc'].append(epoch_val_acc)
        history['epoch_times'].append(epoch_time)

        # Save best model (only on a strict improvement in validation accuracy)
        if epoch_val_acc > best_val_acc:
            best_val_acc = epoch_val_acc
            torch.save(model.state_dict(), f'best_{model_name}.pth')

        print(f'{model_name} Epoch {epoch+1}/{epochs} - '
              f'Train Loss: {epoch_train_loss:.4f}, Train Acc: {epoch_train_acc:.4f} | '
              f'Val Loss: {epoch_val_loss:.4f}, Val Acc: {epoch_val_acc:.4f} | '
              f'Time: {epoch_time:.2f}s')

    total_time = time.time() - total_start_time
    history['total_time'] = total_time

    print(f"\n{model_name} Training completed! Total time: {total_time:.2f}s")
    print(f"Best Validation Accuracy: {best_val_acc:.4f}")

    return history

print("Training function defined.")

In [None]:
# ============================================================
# Train Model 1 (baseline): smaller ViT
# ============================================================
print("=" * 60, "Training Model 1 (Baseline CNN-ViT)", "=" * 60, sep="\n")

# Model 1: compact baseline configuration, moved to the selected device
model = CNNViTHybrid(num_classes=NUM_CLASSES, embed_dim=64, num_heads=4,
                     depth=4, mlp_dim=128, dropout=0.1).to(device)

# Cross-entropy objective; Adam over every model parameter
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

print(f"Model 1 parameters: {sum(param.numel() for param in model.parameters()):,}")
print("  embed_dim=64, num_heads=4, depth=4, mlp_dim=128")

# Train for 5 epochs; the returned history feeds the comparison plots below
history_model = train_model(model, train_loader, val_loader, criterion, optimizer,
                            epochs=5, device=device, model_name='model')

## Task 4: Design and train a CNN-ViT hybrid model `model_test` with the following hyperparameters: `epochs=5`, `num_heads=12` (listed as `mlp_heads` in the original task statement), `embed_dim=768`, and transformer block `depth=12`.

In [None]:
# Task 4: Train model_test with specified hyperparameters
print(
    "\n" + "=" * 60,
    "Training Model Test (Larger CNN-ViT)",
    "  epochs=5, mlp_heads(num_heads)=12, embed_dim=768, depth=12",
    "=" * 60,
    sep="\n",
)

# Larger configuration; mlp_dim is 4x embed_dim, as is typical for ViT-Base
model_test = CNNViTHybrid(num_classes=NUM_CLASSES, embed_dim=768, num_heads=12,
                          depth=12, mlp_dim=3072, dropout=0.1).to(device)

# Separate loss and optimizer objects so nothing is shared with the baseline run
criterion_test = nn.CrossEntropyLoss()
optimizer_test = optim.Adam(model_test.parameters(), lr=LEARNING_RATE)

print(f"Model Test parameters: {sum(param.numel() for param in model_test.parameters()):,}")
print("  embed_dim=768, num_heads=12, depth=12, mlp_dim=3072")
print()

# Train for 5 epochs
history_model_test = train_model(model_test, train_loader, val_loader,
                                 criterion_test, optimizer_test,
                                 epochs=5, device=device, model_name='model_test')

## Task 5: Compare the performance of model with model_test by plotting the validation loss for model and model_test ViTs.

In [None]:
# Task 5: Compare validation loss
fig, axes = plt.subplots(1, 2, figsize=(16, 5))

# Take the x-axis from the recorded histories instead of hard-coding the epoch
# count, so the plots stay valid if either model is trained for a different
# number of epochs. epochs_range spans the longer run and is used for ticks.
epochs_range = range(1, max(len(history_model['val_loss']),
                            len(history_model_test['val_loss'])) + 1)

# Plot 1: Validation Loss Comparison
axes[0].plot(range(1, len(history_model['val_loss']) + 1), history_model['val_loss'],
             label='Model (Baseline)', color='blue', linewidth=2, marker='o', markersize=6)
axes[0].plot(range(1, len(history_model_test['val_loss']) + 1), history_model_test['val_loss'],
             label='Model Test (Large)', color='red', linewidth=2, marker='s', markersize=6)
axes[0].set_title('Validation Loss Comparison', fontsize=14, fontweight='bold')
axes[0].set_xlabel('Epoch', fontsize=12)
axes[0].set_ylabel('Validation Loss', fontsize=12)
axes[0].legend(fontsize=11)
axes[0].grid(True, alpha=0.3)
axes[0].set_xticks(list(epochs_range))

# Plot 2: Validation Accuracy Comparison
axes[1].plot(range(1, len(history_model['val_acc']) + 1), history_model['val_acc'],
             label='Model (Baseline)', color='blue', linewidth=2, marker='o', markersize=6)
axes[1].plot(range(1, len(history_model_test['val_acc']) + 1), history_model_test['val_acc'],
             label='Model Test (Large)', color='red', linewidth=2, marker='s', markersize=6)
axes[1].set_title('Validation Accuracy Comparison', fontsize=14, fontweight='bold')
axes[1].set_xlabel('Epoch', fontsize=12)
axes[1].set_ylabel('Validation Accuracy', fontsize=12)
axes[1].legend(fontsize=11)
axes[1].grid(True, alpha=0.3)
axes[1].set_xticks(list(epochs_range))

plt.suptitle('Model vs Model Test: Performance Comparison', fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()

# Print comparison table (best = min loss / max accuracy over epochs; final = last epoch)
print("\nPerformance Comparison Summary:")
print("=" * 65)
print(f"{'Metric':<25} {'Model (Baseline)':>18} {'Model Test (Large)':>18}")
print("-" * 65)
print(f"{'Best Val Loss':<25} {min(history_model['val_loss']):>18.4f} {min(history_model_test['val_loss']):>18.4f}")
print(f"{'Best Val Accuracy':<25} {max(history_model['val_acc']):>18.4f} {max(history_model_test['val_acc']):>18.4f}")
print(f"{'Final Val Loss':<25} {history_model['val_loss'][-1]:>18.4f} {history_model_test['val_loss'][-1]:>18.4f}")
print(f"{'Final Val Accuracy':<25} {history_model['val_acc'][-1]:>18.4f} {history_model_test['val_acc'][-1]:>18.4f}")
print(f"{'Total Parameters':<25} {sum(p.numel() for p in model.parameters()):>18,} {sum(p.numel() for p in model_test.parameters()):>18,}")

## Task 6: Compare the training times of model with model_test by plotting the training time for each.

In [None]:
# Task 6: Compare training times
fig, axes = plt.subplots(1, 2, figsize=(16, 5))

# Build the epoch axis from the recorded timings so this cell does not depend
# on a variable left behind by the previous plotting cell or on a fixed epoch count.
epochs_range = range(1, max(len(history_model['epoch_times']),
                            len(history_model_test['epoch_times'])) + 1)

# Plot 1: Per-Epoch Training Time
axes[0].plot(range(1, len(history_model['epoch_times']) + 1), history_model['epoch_times'],
             label='Model (Baseline)', color='blue', linewidth=2, marker='o', markersize=6)
axes[0].plot(range(1, len(history_model_test['epoch_times']) + 1), history_model_test['epoch_times'],
             label='Model Test (Large)', color='red', linewidth=2, marker='s', markersize=6)
axes[0].set_title('Per-Epoch Training Time', fontsize=14, fontweight='bold')
axes[0].set_xlabel('Epoch', fontsize=12)
axes[0].set_ylabel('Time (seconds)', fontsize=12)
axes[0].legend(fontsize=11)
axes[0].grid(True, alpha=0.3)
axes[0].set_xticks(list(epochs_range))

# Plot 2: Total Training Time Bar Chart
model_names = ['Model\n(Baseline)', 'Model Test\n(Large)']
total_times = [history_model['total_time'], history_model_test['total_time']]
colors = ['steelblue', 'coral']

bars = axes[1].bar(model_names, total_times, color=colors, width=0.5, edgecolor='black')
axes[1].set_title('Total Training Time Comparison', fontsize=14, fontweight='bold')
axes[1].set_ylabel('Total Time (seconds)', fontsize=12)
axes[1].grid(axis='y', alpha=0.3)

# Add value labels on bars. The gap above each bar scales with the tallest bar
# so labels sit sensibly whether runs take a fraction of a second or many minutes.
label_pad = 0.01 * max(total_times)
for bar, t in zip(bars, total_times):
    axes[1].text(bar.get_x() + bar.get_width()/2., bar.get_height() + label_pad,
                 f'{t:.2f}s', ha='center', va='bottom', fontsize=12, fontweight='bold')

plt.suptitle('Model vs Model Test: Training Time Comparison', fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()

# Print timing summary
print("\nTraining Time Comparison:")
print("=" * 60)
print(f"{'Metric':<30} {'Model':>12} {'Model Test':>12}")
print("-" * 60)
print(f"{'Total Training Time (s)':<30} {history_model['total_time']:>12.2f} {history_model_test['total_time']:>12.2f}")
print(f"{'Avg Epoch Time (s)':<30} {np.mean(history_model['epoch_times']):>12.2f} {np.mean(history_model_test['epoch_times']):>12.2f}")
print(f"{'Min Epoch Time (s)':<30} {min(history_model['epoch_times']):>12.2f} {min(history_model_test['epoch_times']):>12.2f}")
print(f"{'Max Epoch Time (s)':<30} {max(history_model['epoch_times']):>12.2f} {max(history_model_test['epoch_times']):>12.2f}")

# Ratio of model_test's total time to the baseline's: values above 1 mean
# model_test took longer, despite the variable's name.
speedup = history_model_test['total_time'] / history_model['total_time']
print(f"\nModel Test is {speedup:.2f}x {'slower' if speedup > 1 else 'faster'} than baseline Model.")
print(f"Model Test has {sum(p.numel() for p in model_test.parameters()) / sum(p.numel() for p in model.parameters()):.1f}x more parameters.")

In [None]:
# Save both models for use in Question 9
# train_model() has already written the best-validation weights to
# best_model.pth / best_model_test.pth. Saving the current (last-epoch) weights
# under those names would replace the best checkpoints, so the last-epoch
# weights go to final_*.pth instead. A best_*.pth file is written here only
# when training never produced one (no epoch improved on 0.0 validation accuracy).
# NOTE(review): a best_*.pth left over from an earlier run is kept in that case -
# delete stale checkpoints before re-running if that matters.
for trained_model, tag in ((model, 'model'), (model_test, 'model_test')):
    torch.save(trained_model.state_dict(), f'final_{tag}.pth')
    if not os.path.exists(f'best_{tag}.pth'):
        torch.save(trained_model.state_dict(), f'best_{tag}.pth')
print("Both models saved successfully!")

---
## All 6 tasks completed successfully.