# Exercise 1: Convolutional Neural Networks

In [None]:
import numpy as np
from matplotlib import pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

#### Exercise 1.1: Load the CIFAR-10 Dataset
Run the code below to load the [CIFAR-10 dataset](https://www.cs.toronto.edu/~kriz/cifar.html) and visualize some of the images in the dataset. The images in this dataset are 32 x 32 color images of the following objects: \
`classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')`

In [None]:
# Mini-batch size shared by the training and test loaders
batch_size = 64

# Preprocessing: convert each image to a tensor, then normalize every RGB
# channel with mean 0.5 and std 0.5
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
])

# CIFAR-10 training split first, then the test split (both stored in ./data)
train_dataset, test_dataset = (
    torchvision.datasets.CIFAR10(
        root='./data', train=is_train, download=True, transform=transform
    )
    for is_train in (True, False)
)

# Training batches are reshuffled; test batches keep the dataset order
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True
)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False
)

# Human-readable class names, looked up by integer label
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)
def imshow(img, title=None):
    """Draw a normalized image tensor with matplotlib.

    Args:
        img: Channel-first (C, H, W) image tensor whose channels were
            normalized with mean 0.5 and std 0.5. It may live on any device
            and may be attached to an autograd graph.
        title: Optional plot title; no title is set when this is falsy.
    """
    # Tensor.numpy() only works on CPU tensors that do not require grad, so
    # detach and copy to the CPU first (both are no-ops when already true).
    img = img.detach().cpu()
    img = img / 2 + 0.5  # unnormalize RGB channels
    # matplotlib expects channel-last (H, W, C) data
    plt.imshow(np.transpose(img.numpy(), (1, 2, 0)))
    plt.axis("off")
    if title:
        plt.title(title)

# Pull one (shuffled) batch from the training loader to use as examples
sample_images, sample_labels = next(iter(train_loader))

# Show the first eight images in a row, titled with their class names
plt.figure(figsize=(8, 2))
imshow(
    torchvision.utils.make_grid(sample_images[:8]),
    title="Sample CIFAR-10 Images:\n"
    + "  ".join([classes[sample_labels[j]] for j in range(8)]),
)
plt.show()

#### Exercise 1.2: Define Simple CNN Model and Train
Implement the structure of the CNN in the class `SimpleCNN` below. Specifically:
1. In `SimpleCNN.__init__`, define `self.conv_layers` and `self.fc_layers` for the convolution and fully-connected layers of the network.
    1. Your model should have two convolution layers in `self.conv_layers`, each followed by a ReLU activation and a 2D max pooling layer to downsample. Use the parameter values given below to define each layer.
    2. For `self.fc_layers`, you should flatten the output of the convolution layers, then have two linear layers separated by a ReLU activation. Use the parameter value given below to define the output size of the first linear layer. You must determine the size of the input for the first linear layer and the size of the final layer output.
2. Implement the remaining code in the training loop to train your model using the `criterion` and `optimizer` provided.
3. Run the code provided to evaluate the accuracy of your model.

In [None]:
class SimpleCNN(nn.Module):
    """Exercise skeleton for a small convolutional image classifier.

    ``__init__`` only declares the layer hyperparameters. The two attributes
    read by ``forward`` (``self.conv_layers`` and ``self.fc_layers``) must be
    defined between the "Code starts here" / "Code ends here" markers before
    the model can be used.
    """

    def __init__(self):
        """Declare the hyperparameters and (once completed) build the layers."""
        super().__init__()
        # Hyperparameters for the convolution / max pooling stage
        conv_kernel_size = 3
        conv_padding = 1
        max_pool_kernel_size = 2
        max_pool_stride = 2
        conv1_out_channels = 32
        conv2_out_channels = 64
        ########## Code starts here ##########
        # Define self.conv_layers
        # Hint: use nn.Sequential
        
        ########## Code ends here ############

        # Output size of the first linear layer
        hidden_layer_size = 256
        ########## Code starts here ##########
        # Define self.fc_layers
        # Hint: use nn.Sequential
        
        ########## Code ends here ############

    def forward(self, x):
        """Apply ``self.conv_layers`` and then ``self.fc_layers`` to ``x``."""
        x = self.conv_layers(x)
        x = self.fc_layers(x)
        return x

Once you have completed the training step inside the loop below, run the cell to train your model. You should see the epoch loss decrease from one epoch to the next.

In [None]:
# Instantiate the network and move it to the selected device
model = SimpleCNN().to(device)

# Define loss function and optimizer
learning_rate = 0.001
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Train the model
num_epochs = 5
for epoch in range(num_epochs):
    model.train()  # put the module in training mode at the start of each epoch
    running_loss = 0.0  # sum of the per-batch losses seen in this epoch

    for images, labels in train_loader:
        # Move the batch to the same device as the model
        images, labels = images.to(device), labels.to(device)

        ########## Code starts here ##########
        # Implement one training step using `criterion` and `optimizer`.
        # It must assign the batch loss to `loss`, which is read just below.
        
        ########## Code ends here ############

        # .item() extracts the loss value as a plain Python number
        running_loss += loss.item()

    # Report the mean of the per-batch losses for this epoch
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}")

Run the code below to estimate the accuracy of your trained model on the test dataset.

In [None]:
# Score the trained model on the held-out test split
model.eval()  # switch the module to evaluation mode
all_preds, all_labels = [], []
with torch.no_grad():  # gradients are not needed for evaluation
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        # Predicted class = index of the largest score in each row
        predictions = outputs.argmax(dim=1)
        all_preds.extend(predictions.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

# Percentage of test samples whose prediction matches the true label
accuracy = 100 * (np.array(all_preds) == np.array(all_labels)).mean()
print(f"\nTest Accuracy: {accuracy:.2f}%")

#### Exercise 1.3: Analyze Confusion Matrix
Run the provided code to plot the confusion matrix for your test results above. Take a look at the results: do you see any interesting patterns?

In [None]:
# Display confusion matrix (rows: true labels, columns: predicted labels)
cm = confusion_matrix(all_labels, all_preds)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=classes)

# ConfusionMatrixDisplay.plot() creates its own figure unless it is handed an
# axes, so build the 8x8 figure explicitly and draw into it. Calling
# plt.figure() beforehand would only leave an empty extra figure behind and
# the requested size would not apply to the matrix.
fig, ax = plt.subplots(figsize=(8, 8))
disp.plot(ax=ax, cmap='Blues', xticks_rotation=45, colorbar=False)
ax.set_title("Confusion Matrix - CIFAR-10")
plt.show()

Finally, run the code provided below to revisit the sampled images from the dataset and see how your model classified each of them!

In [None]:
# Display the sample images from above with the model's predictions
# The training loop leaves the model in training mode, so switch to
# evaluation mode explicitly in case this cell is run right after training.
model.eval()
sample_images = sample_images.to(device)
with torch.no_grad():  # gradients are not needed for inference
    outputs = model(sample_images)
    # Predicted class = index of the largest score in each row
    _, predicted = torch.max(outputs, 1)

plt.figure(figsize=(8, 2))
# The batch now lives on `device`; copy the shown images back to the CPU
imshow(torchvision.utils.make_grid(sample_images[:8].cpu()))
plt.title("Predicted: " + "  ".join(classes[predicted[j]] for j in range(8)) +
          "\nTrue:      " + "  ".join(classes[sample_labels[j]] for j in range(8)))
plt.show()