# CS 5970/6970 Assignment 4
**Name/Team:** [Your Name Here]

In [None]:
# Documentation
# This notebook is designed to run on Google Colab or a local environment.
# If running locally, ensure you have the dependencies installed via requirements.txt.
# If running on Colab, the next cell will install necessary packages.

In [None]:
# Install dependencies for Google Colab
!pip install torch torchvision numpy matplotlib tqdm scikit-learn

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from torch.utils.data import Subset

## Task 1: Small CNN (from scratch)

In [None]:
# Load CIFAR-10 dataset with Augmentation

# Seed the global RNGs once, before any model is built or any data is shuffled, so that
# Restart & Run All reproduces the same weight initialisation, DataLoader shuffling,
# augmentations and numpy-based sampling (e.g. the balanced 10% subset).
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Train-time augmentation: random 32x32 crop from a 4-pixel padded image plus a random
# horizontal flip, then map each channel from [0, 1] to [-1, 1] (mean 0.5, std 0.5).
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Test-time: no augmentation, identical normalisation.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

batch_size = 4

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

In [None]:
# Function to show an image
def imshow(img):
    """Display a CHW image tensor that was normalised with mean 0.5 / std 0.5.

    The normalisation is undone (x / 2 + 0.5 maps [-1, 1] back to [0, 1]) and the
    channel axis is moved last, as matplotlib expects HWC arrays.
    """
    pixels = (img / 2 + 0.5).numpy()  # unnormalize
    plt.imshow(pixels.transpose(1, 2, 0))
    plt.show()

# Get some random training images and show them as one grid, with their class names.
dataiter = iter(trainloader)
images, labels = next(dataiter)

imshow(torchvision.utils.make_grid(images))
# Right-align each class name to width 5 on a single line.
print(' '.join(f'{classes[labels[j]]:>5}' for j in range(batch_size)))

In [None]:
# Define the Convolutional Neural Network
class Net(nn.Module):
    """Small from-scratch CNN for 32x32 RGB images (CIFAR-10) producing 10 logits.

    Three stages of Conv3x3(pad 1) -> BatchNorm -> ReLU -> MaxPool(2)
    (channels 3->32->64->128, spatial 32->16->8->4), followed by
    FC(128*4*4 -> 512) -> ReLU -> FC(512 -> 10).
    """

    def __init__(self):
        super(Net, self).__init__()
        # Layer 1: Conv -> BN -> ReLU -> Pool
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.pool = nn.MaxPool2d(2, 2)  # shared by all three stages (no parameters)

        # Layer 2: Conv -> BN -> ReLU -> Pool
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)

        # Layer 3: Conv -> BN -> ReLU -> Pool
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)

        # Fully Connected Layers
        self.fc1 = nn.Linear(128 * 4 * 4, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        """Map a (N, 3, 32, 32) batch to (N, 10) class logits."""
        x = self.pool(F.relu(self.bn1(self.conv1(x))))
        x = self.pool(F.relu(self.bn2(self.conv2(x))))
        x = self.pool(F.relu(self.bn3(self.conv3(x))))
        # Flatten per sample. `x.view(-1, 128 * 4 * 4)` would silently fold extra
        # spatial positions into the batch dimension for non-32x32 inputs (returning
        # the wrong number of rows); flattening from dim 1 keeps the batch size and
        # lets fc1 reject mismatched shapes.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

net = Net()

In [None]:
# Model Summary and Parameter Count
# Prints the module tree, counts trainable parameters and checks the 1.5M budget.
print("Model Summary:")
print(net)

total_params = sum(param.numel() for param in net.parameters() if param.requires_grad)
print(f"\nTotal Trainable Parameters: {total_params}")

print("Parameter count is within the limit (<= 1.5M)."
      if total_params <= 1_500_000
      else "WARNING: Parameter count exceeds 1.5 million!")

In [None]:
# Define a Loss function and optimizer
# Cross-entropy over Net's 10 output logits; SGD with momentum over all of net's parameters.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

In [None]:
# Train the network
num_epochs = 25
train_losses = []
train_accs = []

for epoch in range(num_epochs):  # loop over the dataset multiple times
    # Explicitly enable training mode (BatchNorm uses batch statistics and updates its
    # running averages). Without this, re-running this cell after an evaluation step
    # that called net.eval() would silently train with frozen BatchNorm statistics.
    net.train()

    running_loss = 0.0
    correct = 0
    total = 0

    # Use tqdm for progress bar
    pbar = tqdm(enumerate(trainloader, 0), total=len(trainloader), desc=f"Epoch {epoch+1}/{num_epochs}")

    for i, (inputs, labels) in pbar:
        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # accumulate statistics on detached outputs (no autograd tracking needed)
        running_loss += loss.item()
        _, predicted = torch.max(outputs.detach(), 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        # Update progress bar every 200 mini-batches with the running epoch averages
        if i % 200 == 199:
            pbar.set_postfix({'Loss': running_loss / (i+1), 'Acc': 100 * correct / total})

    epoch_loss = running_loss / len(trainloader)
    epoch_acc = 100 * correct / total
    train_losses.append(epoch_loss)
    train_accs.append(epoch_acc)

    print(f"Epoch {epoch+1} finished. Loss: {epoch_loss:.3f}, Accuracy: {epoch_acc:.2f}%")

print('Finished Training')

In [None]:
# Plot Training Curves: per-epoch loss (left) and accuracy (right) side by side.
fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 5))

ax_loss.plot(train_losses, label='Training Loss')
ax_loss.set_title('Training Loss vs Epoch')
ax_loss.set_xlabel('Epoch')
ax_loss.set_ylabel('Loss')
ax_loss.legend()

ax_acc.plot(train_accs, label='Training Accuracy')
ax_acc.set_title('Training Accuracy vs Epoch')
ax_acc.set_xlabel('Epoch')
ax_acc.set_ylabel('Accuracy (%)')
ax_acc.legend()

plt.show()

In [None]:
# Test the network on the test data and Plot Confusion Matrix
# Evaluation mode: BatchNorm must use its learned running statistics. In training mode
# it would normalise each small test batch by that batch's own statistics (predictions
# would depend on batch composition) and would also update the running averages with
# test data.
net.eval()

correct = 0
total = 0
all_preds = []
all_labels = []

with torch.no_grad():
    for images, labels in tqdm(testloader, desc="Testing"):
        outputs = net(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        all_preds.extend(predicted.numpy())
        all_labels.extend(labels.numpy())

print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))

# Confusion Matrix (rows: true class, columns: predicted class)
cm = confusion_matrix(all_labels, all_preds)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=classes)
fig, ax = plt.subplots(figsize=(10, 10))
disp.plot(ax=ax, cmap=plt.cm.Blues)
plt.title("Confusion Matrix")
plt.show()

In [None]:
# Save the model
# Only the state_dict (parameter and buffer tensors, including BatchNorm running stats)
# is written; restore it into a freshly constructed Net() via load_state_dict.
PATH = './cifar_net.pth'
torch.save(net.state_dict(), PATH)
print(f"Model saved to {PATH}")

---
**End of Task 1**

## Task 2: Improve your CNN with Residual Connections

### Residual Block Design
The residual block consists of two 3x3 convolutions, each followed by Batch Normalization; a ReLU activation follows the first convolution only. 
A skip connection adds the block's input to the (batch-normalized) output of the second convolution, and a ReLU is applied to the sum. 
If the input and output shapes do not match (due to stride or channel changes), a 1x1 convolution is used in the skip connection to project the input to the correct shape.

In [None]:
# Define Residual Block
class ResidualBlock(nn.Module):
    """Basic two-convolution residual block.

    main path : conv3x3(stride) -> BN -> ReLU -> conv3x3 -> BN
    shortcut  : identity, or conv1x1(stride) -> BN when the stride or the channel
                count changes, so both paths produce the same shape
    output    : ReLU(main + shortcut)

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by the block.
        stride: stride of the first 3x3 convolution (and of the projection, if any).
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        needs_projection = stride != 1 or in_channels != out_channels
        if needs_projection:
            # 1x1 projection so the skip path matches the main path's shape.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            # An empty Sequential returns its input unchanged (identity skip).
            self.shortcut = nn.Sequential()

    def forward(self, x):
        residual = self.bn2(self.conv2(F.relu(self.bn1(self.conv1(x)))))
        return F.relu(residual + self.shortcut(x))

In [None]:
# Define Residual CNN
class ResidualCNN(nn.Module):
    """Three-stage residual CNN for 32x32 RGB images with a 10-way output head.

    stem  : conv3x3(3->32) -> BN -> ReLU
    body  : ResidualBlock(32->32) -> ResidualBlock(32->64, stride 2)
            -> ResidualBlock(64->128, stride 2)
    head  : adaptive average pool to 4x4 -> FC(128*4*4 -> 512) -> ReLU -> FC(512 -> 10)
    """

    def __init__(self):
        super().__init__()
        self.in_channels = 32  # kept as an attribute; not read anywhere in this class

        # Stem convolution
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(32)

        # Residual stages (one block each)
        self.layer1 = self._make_layer(32, 32, stride=1)
        self.layer2 = self._make_layer(32, 64, stride=2)
        self.layer3 = self._make_layer(64, 128, stride=2)

        # Pool whatever spatial size remains down to 4x4 for the FC head
        self.avgpool = nn.AdaptiveAvgPool2d((4, 4))

        # Classifier head
        self.fc1 = nn.Linear(128 * 4 * 4, 512)
        self.fc2 = nn.Linear(512, 10)

    def _make_layer(self, in_c, out_c, stride):
        """Build one residual stage (currently a single ResidualBlock)."""
        return ResidualBlock(in_c, out_c, stride)

    def forward(self, x):
        stem = F.relu(self.bn1(self.conv1(x)))
        features = self.layer3(self.layer2(self.layer1(stem)))
        pooled = torch.flatten(self.avgpool(features), 1)
        return self.fc2(F.relu(self.fc1(pooled)))

res_net = ResidualCNN()

In [None]:
# Model Summary and Parameter Count for ResidualCNN
# Prints the module tree and checks the trainable-parameter count is within
# +/-20% of the Task 1 model's count (total_params).
print("ResidualCNN Summary:")
print(res_net)

res_total_params = sum(weight.numel() for weight in res_net.parameters() if weight.requires_grad)
print(f"\nTotal Trainable Parameters (ResidualCNN): {res_total_params}")

print("Parameter count is comparable to Task 1 (within +/- 20%)."
      if 0.8 * total_params <= res_total_params <= 1.2 * total_params
      else "WARNING: Parameter count difference is significant.")

In [None]:
# Train ResidualCNN
criterion = nn.CrossEntropyLoss()
optimizer_res = optim.SGD(res_net.parameters(), lr=0.001, momentum=0.9)

train_losses_res = []
train_accs_res = []

print("Starting training for ResidualCNN...")
for epoch in range(num_epochs):
    # Explicit training mode for BatchNorm; protects against re-running this cell
    # after an evaluation step that called res_net.eval().
    res_net.train()

    running_loss = 0.0
    correct = 0
    total = 0

    pbar = tqdm(enumerate(trainloader, 0), total=len(trainloader), desc=f"ResCNN Epoch {epoch+1}/{num_epochs}")

    for i, (inputs, labels) in pbar:
        optimizer_res.zero_grad()
        outputs = res_net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer_res.step()

        # accumulate statistics on detached outputs
        running_loss += loss.item()
        _, predicted = torch.max(outputs.detach(), 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        # refresh the progress-bar postfix every 200 mini-batches
        if i % 200 == 199:
            pbar.set_postfix({'Loss': running_loss / (i+1), 'Acc': 100 * correct / total})

    epoch_loss = running_loss / len(trainloader)
    epoch_acc = 100 * correct / total
    train_losses_res.append(epoch_loss)
    train_accs_res.append(epoch_acc)

    print(f"Epoch {epoch+1} finished. Loss: {epoch_loss:.3f}, Accuracy: {epoch_acc:.2f}%")

print('Finished Training ResidualCNN')

In [None]:
# Overlay Plot: Task 1 vs Task 2 training accuracy on shared axes.
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(train_accs, label='Task 1: Simple CNN', linestyle='--')
ax.plot(train_accs_res, label='Task 2: Residual CNN', linewidth=2)
ax.set(title='Training Accuracy Comparison: Simple vs Residual CNN',
       xlabel='Epoch', ylabel='Accuracy (%)')
ax.legend()
ax.grid(True)
plt.show()

In [None]:
# Test ResidualCNN
# Evaluation mode: BatchNorm must use its running statistics. In training mode it would
# normalise each test batch by its own statistics and update the running averages with
# test data.
res_net.eval()

correct = 0
total = 0
with torch.no_grad():
    for images, labels in tqdm(testloader, desc="Testing ResCNN"):
        outputs = res_net(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the ResidualCNN on the 10000 test images: %d %%' % (
    100 * correct / total))

### Comparative Analysis

**Optimization Stability:**
[Your analysis here based on the loss curves. Residual networks often show smoother convergence.]

**Convergence Speed:**
[Compare how many epochs it took for each model to reach a high accuracy.]

**Generalization:**
[Compare the final test accuracy of both models. Did the residual connections help?]

---
**End of Task 2**

## Task 3: Self-Supervised Pretraining + Fine-tuning

### Rotation Prediction SSL Objective

In this task, we implement a self-supervised learning approach based on **Rotation Prediction**. The core idea is to train the network to predict the rotation angle applied to an input image. Each image is randomly rotated by 0°, 90°, 180°, or 270°, and the network learns to classify which rotation was applied (4-class classification). This pretext task forces the network to learn meaningful visual features without requiring labeled data, as understanding object orientation requires recognizing shapes, textures, and spatial relationships.

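**Augmentations:** We use `torchvision.transforms.functional.rotate()` to apply the rotations. The only other preprocessing is `ToTensor()` followed by `Normalize()`; unlike the supervised tasks, no random crop or flip is applied during pretraining.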

In [None]:
# Rotation Dataset for SSL
import torchvision.transforms.functional as TF

class RotationDataset(torch.utils.data.Dataset):
    """Wrap an image dataset for the rotation-prediction pretext task.

    Every access draws one of four rotations (0, 90, 180, 270 degrees) uniformly at
    random with ``np.random``, rotates the image with ``TF.rotate`` and returns
    ``(rotated_image, rotation_index)``. The wrapped dataset's class label is discarded.
    """

    def __init__(self, base_dataset):
        self.base_dataset = base_dataset
        self.rotations = [0, 90, 180, 270]  # degrees; list position is the SSL label

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, idx):
        image, _class_label = self.base_dataset[idx]  # class label unused for SSL
        rotation_label = np.random.randint(0, len(self.rotations))
        return TF.rotate(image, self.rotations[rotation_label]), rotation_label

In [None]:
# Create rotation dataset for SSL pretraining
# Images are only converted to tensors and normalised (no crop/flip); RotationDataset
# then applies a random rotation and replaces the class label with the rotation index (0-3).
ssl_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

ssl_trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                            download=True, transform=ssl_transform)
rotation_dataset = RotationDataset(ssl_trainset)
# Batches of 64 here, larger than the batch size used by the supervised loaders.
ssl_trainloader = torch.utils.data.DataLoader(rotation_dataset, batch_size=64,
                                              shuffle=True, num_workers=2)

print(f"SSL Dataset size: {len(rotation_dataset)}")

In [None]:
# Initialize ResidualCNN for SSL pretraining
# Same architecture as Task 2, so the learned backbone weights can later be loaded
# into a fresh ResidualCNN for fine-tuning.
ssl_model = ResidualCNN()

# Replace the final layer with a 4-class classifier for rotation prediction
# (one logit per rotation index 0-3). Its weight has shape (4, 512), so this head
# cannot be loaded into a 10-class fc2 later.
ssl_model.fc2 = nn.Linear(512, 4)

print("SSL Model initialized with 4-class output for rotation prediction.")

In [None]:
# SSL Pretraining
criterion_ssl = nn.CrossEntropyLoss()
optimizer_ssl = optim.SGD(ssl_model.parameters(), lr=0.001, momentum=0.9)

ssl_epochs = 10
ssl_losses = []

print("Starting SSL Pretraining (Rotation Prediction)...")
for epoch in range(ssl_epochs):
    # Explicit training mode for BatchNorm; protects against re-running this cell
    # after ssl_model.eval() has been called elsewhere.
    ssl_model.train()

    running_loss = 0.0
    correct = 0
    total = 0

    pbar = tqdm(ssl_trainloader, desc=f"SSL Epoch {epoch+1}/{ssl_epochs}")

    # labels are rotation indices (0-3) produced by RotationDataset
    for inputs, labels in pbar:
        optimizer_ssl.zero_grad()
        outputs = ssl_model(inputs)
        loss = criterion_ssl(outputs, labels)
        loss.backward()
        optimizer_ssl.step()

        running_loss += loss.item()
        _, predicted = torch.max(outputs.detach(), 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    epoch_loss = running_loss / len(ssl_trainloader)
    epoch_acc = 100 * correct / total
    ssl_losses.append(epoch_loss)

    print(f"SSL Epoch {epoch+1} finished. Loss: {epoch_loss:.3f}, Acc: {epoch_acc:.2f}%")

print('Finished SSL Pretraining')

# Save pretrained weights (note: the saved fc2 is the 4-way rotation head)
torch.save(ssl_model.state_dict(), './ssl_pretrained.pth')
print("Pretrained weights saved to ssl_pretrained.pth")

In [None]:
# Plot SSL Training Loss: one marker per pretraining epoch.
fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(ssl_losses, label='SSL Pretraining Loss', marker='o')
ax.set(title='SSL Pretraining Loss vs Epoch', xlabel='Epoch', ylabel='Loss')
ax.legend()
ax.grid(True)
plt.show()

In [None]:
# Create 10% subset of CIFAR-10 (balanced across classes)
def get_balanced_subset_indices(dataset, ratio=0.1, seed=None):
    """Return indices of a class-balanced random subset of ``dataset``.

    Args:
        dataset: object exposing a ``targets`` sequence of class labels
            (e.g. ``torchvision.datasets.CIFAR10``).
        ratio: fraction of the whole dataset to keep, in (0, 1]. Every class
            contributes ``int(len(targets) * ratio / num_classes)`` samples, where
            the classes are the distinct values found in ``targets``.
        seed: optional seed for a private ``np.random.default_rng``. When None
            (default), the global ``np.random`` state is used, so ``np.random.seed``
            still controls the draw.

    Returns:
        list[int]: indices grouped by class (classes in sorted order), drawn without
        replacement within each class. Empty if ``targets`` is empty.

    Raises:
        ValueError: if ``ratio`` is outside (0, 1] or some class has fewer samples
            than the per-class budget.
    """
    if not 0 < ratio <= 1:
        raise ValueError(f"ratio must be in (0, 1], got {ratio}")

    labels = np.asarray(dataset.targets)
    if labels.size == 0:
        return []

    class_ids = np.unique(labels)  # infer the classes instead of assuming 10
    samples_per_class = int(len(labels) * ratio / len(class_ids))
    rng = np.random if seed is None else np.random.default_rng(seed)

    indices = []
    for class_id in class_ids:
        class_indices = np.flatnonzero(labels == class_id)
        if len(class_indices) < samples_per_class:
            raise ValueError(
                f"class {class_id} has {len(class_indices)} samples, "
                f"fewer than the {samples_per_class} required per class")
        selected = rng.choice(class_indices, samples_per_class, replace=False)
        indices.extend(int(i) for i in selected)

    return indices

# Sample 10% of the training set, balanced by class. trainset uses transform_train,
# so the subset keeps the random crop/flip augmentation.
subset_indices = get_balanced_subset_indices(trainset, ratio=0.1)
subset_trainset = Subset(trainset, subset_indices)
# batch_size=4 here matches the batch size of the full-data loaders.
subset_trainloader = torch.utils.data.DataLoader(subset_trainset, batch_size=4,
                                                  shuffle=True, num_workers=2)

print(f"10% Subset size: {len(subset_trainset)}")

In [None]:
# Load pretrained model and reset classification head
finetuned_model = ResidualCNN()

# The SSL checkpoint's fc2 is the 4-way rotation head (weight shape [4, 512]). It cannot
# be loaded into the 10-way fc2 of a fresh ResidualCNN: load_state_dict would raise a
# size-mismatch RuntimeError. Load everything except fc2 and keep the new head.
pretrained_state = torch.load('./ssl_pretrained.pth', map_location='cpu', weights_only=True)
backbone_state = {k: v for k, v in pretrained_state.items() if not k.startswith('fc2.')}
load_result = finetuned_model.load_state_dict(backbone_state, strict=False)
assert set(load_result.missing_keys) == {'fc2.weight', 'fc2.bias'}, load_result.missing_keys
assert not load_result.unexpected_keys, load_result.unexpected_keys

# Reset final layer for 10-class classification
finetuned_model.fc2 = nn.Linear(512, 10)

print("Loaded pretrained weights. Reset classification head to 10 classes.")

In [None]:
# Fine-tuning on 10% data (all layers are updated, starting from the SSL backbone)
criterion_ft = nn.CrossEntropyLoss()
optimizer_ft = optim.SGD(finetuned_model.parameters(), lr=0.001, momentum=0.9)

ft_epochs = 20
ft_losses = []
ft_accs = []

print("Starting Fine-tuning on 10% labeled data...")
for epoch in range(ft_epochs):
    # Explicit training mode for BatchNorm; protects against re-running this cell
    # after finetuned_model.eval() has been called.
    finetuned_model.train()

    running_loss = 0.0
    correct = 0
    total = 0

    pbar = tqdm(subset_trainloader, desc=f"FT Epoch {epoch+1}/{ft_epochs}")

    for inputs, labels in pbar:
        optimizer_ft.zero_grad()
        outputs = finetuned_model(inputs)
        loss = criterion_ft(outputs, labels)
        loss.backward()
        optimizer_ft.step()

        running_loss += loss.item()
        _, predicted = torch.max(outputs.detach(), 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    epoch_loss = running_loss / len(subset_trainloader)
    epoch_acc = 100 * correct / total
    ft_losses.append(epoch_loss)
    ft_accs.append(epoch_acc)

    print(f"FT Epoch {epoch+1} finished. Loss: {epoch_loss:.3f}, Acc: {epoch_acc:.2f}%")

print('Finished Fine-tuning')

In [None]:
# Plot Fine-tuning Training Curves: loss (left) and accuracy (right) per epoch.
fig, (ax_ft_loss, ax_ft_acc) = plt.subplots(1, 2, figsize=(12, 5))

ax_ft_loss.plot(ft_losses, label='Fine-tuning Loss', marker='o')
ax_ft_loss.set(title='Fine-tuning Loss vs Epoch', xlabel='Epoch', ylabel='Loss')
ax_ft_loss.legend()
ax_ft_loss.grid(True)

ax_ft_acc.plot(ft_accs, label='Fine-tuning Accuracy', marker='o', color='green')
ax_ft_acc.set(title='Fine-tuning Accuracy vs Epoch', xlabel='Epoch', ylabel='Accuracy (%)')
ax_ft_acc.legend()
ax_ft_acc.grid(True)

plt.show()

In [None]:
# Test Fine-tuned Model
# Evaluation mode: BatchNorm must use its running statistics. In training mode it would
# normalise each test batch by its own statistics and update the running averages with
# test data.
finetuned_model.eval()

correct = 0
total = 0
with torch.no_grad():
    for images, labels in tqdm(testloader, desc="Testing Fine-tuned Model"):
        outputs = finetuned_model(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

ft_test_acc = 100 * correct / total
print('Accuracy of the Fine-tuned Model on the 10000 test images: %d %%' % ft_test_acc)

### Discussion: Did SSL Help?

**Impact of SSL Pretraining:**
[Analyze whether the SSL pretraining helped achieve higher accuracy or faster convergence compared to training from scratch on the 10% data. Consider factors like final test accuracy and the shape of the training curves.]

**Observations:**
- SSL pretraining should provide a better initialization than random weights.
- The fine-tuned model may converge faster and achieve higher accuracy with limited labeled data.
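- To isolate the effect of SSL, compare against a randomly initialized `ResidualCNN` trained on the same 10% subset with the same schedule; add this baseline if it is not already run. The Task 2 model, trained on the full labeled set, serves as an upper-bound reference rather than a like-for-like comparison.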

---
**End of Task 3**