In [None]:
# Import necessary libraries
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.models import alexnet, resnet50
from torch.utils.data import DataLoader, random_split
from torch import nn, optim
import matplotlib.pyplot as plt

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.hub
from torch.utils.data import DataLoader, random_split, Subset
from torchvision import datasets, transforms, models
from torchsummary import summary
import matplotlib.pyplot as plt
import numpy as np
from sklearn.manifold import TSNE
import time
import copy
import os

# Reproducibility: seed the PyTorch and NumPy global random number generators
for seed_rng in (torch.manual_seed, np.random.seed):
    seed_rng(42)

# Run on the GPU when CUDA is available, otherwise on the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Make sure every folder that receives figures exists
for out_dir in ('outputs/activation_maps', 'outputs/tsne', 'outputs/plots'):
    os.makedirs(out_dir, exist_ok=True)

# 1. Data Preparation

# Preprocessing applied to every image of every split: resize, convert to a
# tensor, then normalize each channel.
transform = transforms.Compose([
    transforms.Resize(224),  # Upscale for AlexNet/ResNet50; an int resizes the shorter edge, so square images become 224x224
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],  # Using ImageNet statistics
                         std=[0.229, 0.224, 0.225]),
])

# Load CIFAR-10 dataset.
# NOTE: download=False means the files must already exist under ./data.
train_val_dataset = datasets.CIFAR10(root='./data', train=True, download=False, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=False, transform=transform)

# Split the CIFAR-10 *training* split: 70% for training and 10% for validation.
# The remaining 20% is drawn by random_split but discarded (assigned to `_`);
# testing uses the separate official test split loaded above, not this remainder.
train_size = int(0.7 * len(train_val_dataset))
val_size = int(0.1 * len(train_val_dataset))
remaining = len(train_val_dataset) - train_size - val_size
# No generator is passed, so the split depends on the global torch RNG state.
train_dataset, val_dataset, _ = random_split(train_val_dataset, [train_size, val_size, remaining])

# Data loaders: only the training loader shuffles; validation and test keep a
# fixed order.
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

# 2. Model Initialization

def initialize_model(model_name, num_classes=10, pretrained=False):
    """Build a torchvision classifier whose output layer has ``num_classes`` units.

    Args:
        model_name: Either ``"alexnet"`` or ``"resnet50"``.
        num_classes: Number of outputs of the replacement final layer.
        pretrained: Forwarded unchanged to the torchvision constructor.
            NOTE(review): newer torchvision releases prefer a ``weights=``
            argument over ``pretrained=`` - confirm against the installed
            version before changing it.

    Returns:
        The model with its last linear layer replaced. It is not moved to any
        device here.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    if model_name == "alexnet":
        model = models.alexnet(pretrained=pretrained)
        # Replace the last classifier layer; read its input width from the
        # existing layer instead of hard-coding it, as the ResNet branch does.
        head = model.classifier[6]
        model.classifier[6] = nn.Linear(head.in_features, num_classes)
    elif model_name == "resnet50":
        model = models.resnet50(pretrained=pretrained)
        # Replace the final fully connected layer
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    else:
        raise ValueError("Unsupported model name")
    return model

# Initialize models
models_to_train = ['alexnet', 'resnet50']
trained_models = {}

for model_name in models_to_train:
    print(f"\nInitializing {model_name}...")
    model = initialize_model(model_name, num_classes=10, pretrained=False)
    model = model.to(device)
    # Per-model bookkeeping that the training, plotting and comparison steps
    # read and update.
    trained_models[model_name] = {
        'model': model,
        'history': {
            'train_loss': [],
            'val_loss': []
        },
        # Start from the freshly initialized weights with no validation result yet
        'best_model_wts': copy.deepcopy(model.state_dict()),
        'best_val_loss': float('inf'),
        'train_times': 0,
        'num_epochs': 0
    }
    # Print model summary. Both architectures are fed the same 3x224x224
    # input, so a single call covers every model.
    print(f"\nModel Summary for {model_name}:")
    summary(model, (3, 224, 224))

# 3. Training Function with Early Stopping

from tqdm import tqdm

def train_model(model_info, model_name, num_epochs=30, patience=5):
    """Train one model with SGD and stop early when validation loss stalls.

    Every epoch runs a training pass over the module-level ``train_loader``
    and then a validation pass over ``val_loader``. The weights that gave the
    lowest validation loss are remembered and loaded back into the model
    before the function returns.

    Args:
        model_info: Bookkeeping dict holding 'model', 'history',
            'best_val_loss' and 'best_model_wts'. It is updated in place.
        model_name: Name of the model (not used inside this function).
        num_epochs: Maximum number of epochs to run.
        patience: Number of consecutive epochs without a new best validation
            loss after which training stops.

    Returns:
        The same ``model_info`` dict with 'history', 'best_val_loss',
        'best_model_wts', 'num_epochs', 'train_times' and 'model' refreshed.
    """
    net = model_info['model']
    loss_log = model_info['history']
    lowest_val_loss = model_info['best_val_loss']
    best_weights = model_info['best_model_wts']

    loss_fn = nn.CrossEntropyLoss()
    sgd = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

    started_at = time.time()
    stale_epochs = 0
    stop_requested = False

    for epoch in range(num_epochs):
        # A stop requested by the previous epoch's validation pass is
        # honoured at the top of the following iteration.
        if stop_requested:
            print("Early stopping triggered.")
            break
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print('-' * 10)

        # Training pass first, validation pass second
        for phase in ('train', 'val'):
            training = phase == 'train'
            if training:
                net.train()
                loader = train_loader
            else:
                net.eval()
                loader = val_loader

            loss_sum = 0.0
            progress = tqdm(enumerate(loader), total=len(loader), desc=phase.capitalize())
            for _, (inputs, labels) in progress:
                inputs, labels = inputs.to(device), labels.to(device)

                sgd.zero_grad()

                # Gradients are tracked only while training
                with torch.set_grad_enabled(training):
                    loss = loss_fn(net(inputs), labels)
                    if training:
                        loss.backward()
                        sgd.step()

                # Weight the batch mean by the batch size so the epoch value
                # below is a per-sample average.
                batch_loss = loss.item()
                loss_sum += batch_loss * inputs.size(0)
                progress.set_postfix({'Loss': batch_loss})

            epoch_loss = loss_sum / len(loader.dataset)
            loss_log[f'{phase}_loss'].append(epoch_loss)

            print(f'{phase.capitalize()} Loss: {epoch_loss:.4f}')

            # Early-stopping bookkeeping applies to the validation pass only
            if training:
                continue
            if epoch_loss < lowest_val_loss:
                lowest_val_loss = epoch_loss
                best_weights = copy.deepcopy(net.state_dict())
                stale_epochs = 0
                continue
            stale_epochs += 1
            if stale_epochs >= patience:
                print(f"Validation loss did not improve for {patience} epochs. Stopping.")
                stop_requested = True
                break

        model_info['best_val_loss'] = lowest_val_loss
        model_info['best_model_wts'] = best_weights
        model_info['num_epochs'] = epoch + 1

    elapsed = time.time() - started_at
    model_info['train_times'] = elapsed
    print(f"\nTraining complete in {elapsed // 60:.0f}m {elapsed % 60:.0f}s")
    print(f"Best Validation Loss: {lowest_val_loss:.4f}")

    # Restore the best checkpoint before handing the model back
    net.load_state_dict(best_weights)
    model_info['model'] = net

    return model_info

# 4. Train Both Models

# Training budget shared by every model (lower num_epochs on slower hardware)
num_epochs, patience = 30, 5

for model_name in models_to_train:
    print(f"\nStarting training for {model_name}...")
    model_record = trained_models[model_name]
    # train_model returns the updated bookkeeping dict; store it back
    trained_models[model_name] = train_model(
        model_record,
        model_name,
        num_epochs=num_epochs,
        patience=patience,
    )

# 5. Plot Training and Validation Loss

for model_name in models_to_train:
    history = trained_models[model_name]['history']
    plt.figure()
    # Take the x-axis from each curve's own length rather than the stored
    # epoch counter, so the plot still works when the two disagree (for
    # example after an interrupted or repeated training run).
    for loss_key, line_style, curve_label in (
        ('train_loss', 'g-', 'Training Loss'),
        ('val_loss', 'b-', 'Validation Loss'),
    ):
        losses = history[loss_key]
        plt.plot(range(1, len(losses) + 1), losses, line_style, label=curve_label)
    plt.title(f'Training and Validation Loss for {model_name.capitalize()}')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.savefig(f'outputs/plots/{model_name}_loss.png')
    plt.show()

# 6. Compute Test Error and Accuracy

def evaluate_model(model, dataloader):
    """Compute the mean cross-entropy loss and accuracy of ``model``.

    The model is switched to eval mode (and left in it) and is run without
    gradient tracking. Batches are moved to the device that holds the model's
    parameters, so the function does not depend on a module-level device.

    Args:
        model: Classifier returning one row of class scores per sample.
        dataloader: Iterable of ``(inputs, labels)`` batches.

    Returns:
        ``(avg_loss, accuracy)`` where ``avg_loss`` is the per-sample mean
        loss and ``accuracy`` is a percentage in ``[0, 100]``. Both are
        computed over the samples actually yielded by ``dataloader``.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    model.eval()
    criterion = nn.CrossEntropyLoss()

    # A model without parameters has no device of its own; in that case the
    # batches are left where they are.
    first_param = next(model.parameters(), None)
    target_device = None if first_param is None else first_param.device

    running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in dataloader:
            if target_device is not None:
                inputs = inputs.to(target_device)
                labels = labels.to(target_device)

            outputs = model(inputs)
            n_samples = labels.size(0)
            # criterion returns the batch mean; weight it by the batch size
            running_loss += criterion(outputs, labels).item() * n_samples

            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
            total += n_samples

    if total == 0:
        raise ValueError("dataloader yielded no samples to evaluate")

    # Divide by the number of samples seen, not len(dataloader.dataset): the
    # two differ when the loader uses drop_last or a sampler.
    avg_loss = running_loss / total
    accuracy = 100 * correct / total
    return avg_loss, accuracy

# Test loss and accuracy of every trained model, keyed by model name
test_results = {}

for model_name in models_to_train:
    print(f"\nEvaluating {model_name} on test data...")
    model = trained_models[model_name]['model']
    test_loss, test_accuracy = evaluate_model(model, test_loader)
    test_results[model_name] = dict(test_loss=test_loss, test_accuracy=test_accuracy)
    print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%")

# 7. Visualize Activation Maps

def visualize_activations(model, model_name, layer_names, num_images=1, num_features=8):
    """Plot feature maps of selected layers for one test image and save the figure.

    A forward hook on every requested layer captures its output while the
    first batch of the module-level ``test_loader`` is run through the model.
    The figure has one row for the de-normalized input image followed by one
    row per layer showing its first ``num_features`` channels. It is saved to
    ``outputs/activation_maps/<model_name>_activations.png`` and then shown.

    Args:
        model: Network to inspect; it is switched to eval mode.
        model_name: Used only to build the output file name.
        layer_names: Mapping of display name -> module to hook.
        num_images: Currently unused; only the first image of the batch is
            visualized.
        num_features: Maximum number of channels drawn per layer.
    """
    activation = {}
    hooks = []

    # Factory so that each hook stores its output under its own layer name
    def get_activation(name):
        def hook(module, inputs, output):
            activation[name] = output.detach()
        return hook

    for name, layer in layer_names.items():
        hooks.append(layer.register_forward_hook(get_activation(name)))

    try:
        # next(iter(...)) is the Python 3 way to pull a single batch; loader
        # iterators do not provide a .next() method.
        images, _ = next(iter(test_loader))
        images = images.to(device)

        model.eval()
        with torch.no_grad():
            model(images)
    finally:
        # Always detach the hooks, even if the forward pass fails
        for hook in hooks:
            hook.remove()

    # Undo the ImageNet normalization so the input is displayed in [0, 1]
    img = images[0].cpu().numpy().transpose((1, 2, 0))
    img = np.clip(img * np.array([0.229, 0.224, 0.225]) +
                  np.array([0.485, 0.456, 0.406]), 0, 1)

    # One shared grid: row 0 holds the input image, each later row one layer.
    # Using a single grid keeps the image axes from overlapping the maps.
    n_rows = len(layer_names) + 1
    plt.figure(figsize=(15, 15))
    plt.subplot(n_rows, num_features, 1)
    plt.imshow(img)
    plt.title('Input Image')
    plt.axis('off')

    for row, (name, act) in enumerate(activation.items(), start=1):
        act = act[0].cpu()
        # Never index past the number of channels the layer actually has
        for col in range(min(num_features, act.shape[0])):
            plt.subplot(n_rows, num_features, row * num_features + col + 1)
            plt.imshow(act[col], cmap='viridis')
            plt.axis('off')
            if col == 0:
                plt.title(f'{name} Activations')

    plt.tight_layout()
    plt.savefig(f'outputs/activation_maps/{model_name}_activations.png')
    plt.show()

# Layers whose activations are visualized; both slots start empty and are
# filled per model ('layer1' is only set for ResNet50).
activation_layers = dict.fromkeys(['conv1', 'layer1'])

for model_name in models_to_train:
    model = trained_models[model_name]['model']
    if model_name == 'resnet50':
        # ResNet50: the stem convolution and the first residual stage
        activation_layers.update(conv1=model.conv1, layer1=model.layer1)
        activation_layers_to_use = dict(activation_layers)
    elif model_name == 'alexnet':
        # AlexNet: only the first convolution of the feature extractor
        activation_layers['conv1'] = model.features[0]
        activation_layers_to_use = {'conv1': activation_layers['conv1']}
    visualize_activations(model, model_name, activation_layers_to_use)

# 8. t-SNE Visualization

def get_bottleneck_features(model, dataloader, layer_name):
    """Collect the output of one named layer for every sample in ``dataloader``.

    A forward hook on the module called ``layer_name`` records its output for
    each batch. Each sample's output is flattened to a vector so the result
    is always a 2-D array, including for layers with spatial outputs such as
    a pooling layer that emits ``(N, C, 1, 1)``.

    Args:
        model: Network to run; it is switched to eval mode.
        dataloader: Iterable of ``(inputs, labels)`` batches.
        layer_name: Name of the module as reported by ``model.named_modules()``.

    Returns:
        ``(features, labels)``: ``features`` is a NumPy array of shape
        ``(n_samples, n_features)`` and ``labels`` is a list with one label
        per sample, in the same order.

    Raises:
        ValueError: If ``model`` has no module named ``layer_name``.
    """
    modules = dict(model.named_modules())
    if layer_name not in modules:
        raise ValueError(f"Layer {layer_name!r} not found in model")

    # Feed batches to whatever device holds the model's parameters; a model
    # without parameters leaves the batches where they are.
    first_param = next(model.parameters(), None)
    target_device = None if first_param is None else first_param.device

    model.eval()
    features = []
    labels = []
    activation = {}

    def hook_fn(module, inputs, output):
        activation['bottleneck'] = output.detach()

    handle = modules[layer_name].register_forward_hook(hook_fn)
    try:
        with torch.no_grad():
            for inputs, lbls in dataloader:
                if target_device is not None:
                    inputs = inputs.to(target_device)
                model(inputs)
                feats = activation['bottleneck'].cpu().numpy()
                # One flat feature vector per sample
                features.append(feats.reshape(feats.shape[0], -1))
                labels.extend(lbls.numpy())
    finally:
        # Detach the hook even if a forward pass fails
        handle.remove()

    features = np.concatenate(features, axis=0)
    return features, labels

def tsne_visualization(features, labels, title, filename):
    """Embed ``features`` in 2-D with t-SNE and save a label-coloured scatter plot.

    Args:
        features: Feature matrix passed to ``TSNE.fit_transform``.
        labels: One label per row of ``features``, used to colour the points.
        title: Title of the figure.
        filename: Path the figure is saved to before it is shown.
    """
    # Fixed random_state so the embedding is repeatable across runs
    embedding = TSNE(n_components=2, random_state=42).fit_transform(features)
    xs, ys = embedding[:, 0], embedding[:, 1]

    plt.figure(figsize=(10, 8))
    points = plt.scatter(xs, ys, c=labels, cmap='tab10', alpha=0.6)
    # One legend entry per distinct label value
    handles, class_names = points.legend_elements()
    plt.legend(handles, class_names, title="Classes")
    plt.title(title)
    plt.savefig(filename)
    plt.show()

# Module whose output serves as the bottleneck representation of each model:
#   - AlexNet: 'classifier.5', the last module before the final linear layer
#   - ResNet50: 'avgpool', the layer feeding the final fully connected layer
bottleneck_layers = {
    'alexnet': 'classifier.5',
    'resnet50': 'avgpool',
}

for model_name in models_to_train:
    model = trained_models[model_name]['model']
    # A model without a configured layer fails here with a KeyError instead
    # of silently reusing the previous model's layer name.
    bottleneck_layer = bottleneck_layers[model_name]

    print(f"\nPerforming t-SNE for {model_name}...")

    # Only the final trained weights are visualized. Features after the first
    # epoch are not available: the training loop would have to save a
    # checkpoint (or the features) at that point, which this script skips.
    features, labels = get_bottleneck_features(model, test_loader, bottleneck_layer)
    # t-SNE needs one flat vector per sample; pooled convolutional outputs
    # may still carry trailing spatial dimensions.
    features = features.reshape(len(features), -1)
    tsne_visualization(features, labels, f'{model_name.capitalize()} t-SNE (Final Epoch)',
                       f'outputs/tsne/{model_name}_tsne_final.png')

# 9. Compare Model Performance

print("\nModel Performance Comparison:")
comparison_table = []

# One row per model: name, test accuracy, epochs run, training time, size
for model_name in models_to_train:
    model_record = trained_models[model_name]
    model = model_record['model']
    accuracy = test_results[model_name]['test_accuracy']
    epochs = model_record['num_epochs']
    train_time = model_record['train_times']
    # Total number of scalar parameters in the model
    num_params = sum(param.numel() for param in model.parameters())

    row = [model_name.capitalize(), accuracy, epochs, train_time, num_params]
    comparison_table.append(row)

# Print comparison table
import pandas as pd

column_names = ['Model', 'Test Accuracy (%)', 'Epochs', 'Training Time (s)', 'Parameters']
df = pd.DataFrame(comparison_table, columns=column_names)
print(df)

# Save the comparison table
df.to_csv('outputs/comparison_table.csv', index=False)

# 10. Conclusion and Comments

# NOTE: torchvision's AlexNet has more parameters than ResNet50 (its fully
# connected classifier dominates the count), so the summary below attributes
# AlexNet's speed to its shallow, cheaper architecture rather than to size.
print("\nComments on Model Selection:")
print("""
- **AlexNet**:
    - **Parameters**: More than ResNet50; most of them sit in the large fully connected classifier layers.
    - **Training Time**: Generally faster, because the network is shallow and needs far less computation per image.
    - **Accuracy**: May achieve lower accuracy compared to ResNet50.
    
- **ResNet50**:
    - **Parameters**: Fewer than AlexNet, despite the much deeper architecture.
    - **Training Time**: Longer due to deeper architecture.
    - **Accuracy**: Typically higher accuracy due to residual connections and deeper layers.
    
**Recommendation**:
If computational resources and training time are constraints, and a moderate accuracy is acceptable, **AlexNet** may be preferred. However, if the highest possible accuracy is desired and resources allow, **ResNet50** is the better choice despite its higher computational cost.
""")


Using device: cpu

Initializing alexnet...





Model Summary for alexnet:
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 55, 55]          23,296
              ReLU-2           [-1, 64, 55, 55]               0
         MaxPool2d-3           [-1, 64, 27, 27]               0
            Conv2d-4          [-1, 192, 27, 27]         307,392
              ReLU-5          [-1, 192, 27, 27]               0
         MaxPool2d-6          [-1, 192, 13, 13]               0
            Conv2d-7          [-1, 384, 13, 13]         663,936
              ReLU-8          [-1, 384, 13, 13]               0
            Conv2d-9          [-1, 256, 13, 13]         884,992
             ReLU-10          [-1, 256, 13, 13]               0
           Conv2d-11          [-1, 256, 13, 13]         590,080
             ReLU-12          [-1, 256, 13, 13]               0
        MaxPool2d-13            [-1, 256, 6, 6]               0
AdaptiveAvg

Train: 100%|██████████| 547/547 [10:13<00:00,  1.12s/it, Loss=1.54]

Train Loss: 1.9745



Val: 100%|██████████| 79/79 [00:40<00:00,  1.95it/s, Loss=1.24]

Val Loss: 1.5714

Epoch 2/30
----------



Train: 100%|██████████| 547/547 [09:58<00:00,  1.09s/it, Loss=1.42] 

Train Loss: 1.4253



Val: 100%|██████████| 79/79 [00:41<00:00,  1.93it/s, Loss=0.783]

Val Loss: 1.1859

Epoch 3/30
----------



Train: 100%|██████████| 547/547 [09:34<00:00,  1.05s/it, Loss=1.16] 

Train Loss: 1.1415



Val: 100%|██████████| 79/79 [00:40<00:00,  1.94it/s, Loss=0.675]

Val Loss: 1.0546

Epoch 4/30
----------



Train: 100%|██████████| 547/547 [09:38<00:00,  1.06s/it, Loss=1.15] 

Train Loss: 0.9249



Val: 100%|██████████| 79/79 [00:40<00:00,  1.93it/s, Loss=0.48] 

Val Loss: 0.8544

Epoch 5/30
----------



Train: 100%|██████████| 547/547 [09:15<00:00,  1.02s/it, Loss=0.724]

Train Loss: 0.7919



Val: 100%|██████████| 79/79 [00:40<00:00,  1.96it/s, Loss=0.698]

Val Loss: 0.8015

Epoch 6/30
----------



Train: 100%|██████████| 547/547 [09:36<00:00,  1.05s/it, Loss=0.895]

Train Loss: 0.6820



Val: 100%|██████████| 79/79 [00:42<00:00,  1.84it/s, Loss=0.909]

Val Loss: 0.7150

Epoch 7/30
----------



Train: 100%|██████████| 547/547 [09:39<00:00,  1.06s/it, Loss=0.57] 

Train Loss: 0.6009



Val: 100%|██████████| 79/79 [00:42<00:00,  1.85it/s, Loss=0.271]

Val Loss: 0.6514

Epoch 8/30
----------



Train: 100%|██████████| 547/547 [09:43<00:00,  1.07s/it, Loss=0.506]

Train Loss: 0.5334



Val: 100%|██████████| 79/79 [00:42<00:00,  1.85it/s, Loss=0.188]

Val Loss: 0.6912

Epoch 9/30
----------



Train: 100%|██████████| 547/547 [09:12<00:00,  1.01s/it, Loss=0.361]

Train Loss: 0.4659



Val: 100%|██████████| 79/79 [00:41<00:00,  1.92it/s, Loss=0.337]

Val Loss: 0.6306

Epoch 10/30
----------



Train: 100%|██████████| 547/547 [08:58<00:00,  1.02it/s, Loss=0.318]

Train Loss: 0.4076



Val: 100%|██████████| 79/79 [00:40<00:00,  1.95it/s, Loss=0.585]

Val Loss: 0.6389

Epoch 11/30
----------



Train: 100%|██████████| 547/547 [09:06<00:00,  1.00it/s, Loss=0.517] 

Train Loss: 0.3597



Val: 100%|██████████| 79/79 [00:40<00:00,  1.95it/s, Loss=0.21] 

Val Loss: 0.6085

Epoch 12/30
----------



Train: 100%|██████████| 547/547 [09:05<00:00,  1.00it/s, Loss=0.206]

Train Loss: 0.3316



Val: 100%|██████████| 79/79 [00:40<00:00,  1.95it/s, Loss=0.364]

Val Loss: 0.6925

Epoch 13/30
----------



Train: 100%|██████████| 547/547 [09:08<00:00,  1.00s/it, Loss=0.267]

Train Loss: 0.2917



Val: 100%|██████████| 79/79 [00:40<00:00,  1.96it/s, Loss=0.529]

Val Loss: 0.6407

Epoch 14/30
----------



Train: 100%|██████████| 547/547 [09:10<00:00,  1.01s/it, Loss=0.335] 

Train Loss: 0.2660



Val: 100%|██████████| 79/79 [00:40<00:00,  1.97it/s, Loss=0.336]

Val Loss: 0.6198

Epoch 15/30
----------



Train: 100%|██████████| 547/547 [09:15<00:00,  1.02s/it, Loss=0.233] 

Train Loss: 0.2515



Val: 100%|██████████| 79/79 [00:40<00:00,  1.96it/s, Loss=0.539]

Val Loss: 0.6502

Epoch 16/30
----------



Train: 100%|██████████| 547/547 [09:23<00:00,  1.03s/it, Loss=0.139] 

Train Loss: 0.2172



Val: 100%|██████████| 79/79 [00:40<00:00,  1.95it/s, Loss=0.125]


Val Loss: 0.7053
Validation loss did not improve for 5 epochs. Stopping.
Early stopping triggered.

Training complete in 164m 8s
Best Validation Loss: 0.6085

Starting training for resnet50...

Epoch 1/30
----------


Train:  42%|████▏     | 229/547 [54:57<1:16:19, 14.40s/it, Loss=1.97]


KeyboardInterrupt: 