# In [None]:
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
# Load CIFAR-10 dataset: both splits are fetched into ./data on first use
# (download=True) and each image is converted to a tensor by ToTensor().
def _load_cifar10_split(train):
    """Return the CIFAR-10 training split (``train=True``) or test split."""
    return datasets.CIFAR10(
        root="data",
        train=train,
        download=True,
        transform=ToTensor(),
    )


training_data = _load_cifar10_split(True)
test_data = _load_cifar10_split(False)

batch_size = 64

# Reshuffle the training set every epoch so mini-batch composition varies
# between epochs (without this, every epoch replays the exact same batches in
# the same order). The test loader stays in fixed order so evaluation, and the
# "first correct / first incorrect" examples picked during testing, are
# reproducible from epoch to epoch.
train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

# Prefer a CUDA GPU, then Apple's MPS backend, and fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print(f"Using {device} device")

# Fully Connected Neural Network
class FullyConnectedNeuralNetwork(nn.Module):
    """MLP baseline: flatten each image, then 3072 -> 1024 -> 512 -> 256 -> 10.

    A ReLU follows every linear layer except the last, which emits raw class
    logits.
    """

    def __init__(self):
        super().__init__()
        widths = (32 * 32 * 3, 1024, 512, 256, 10)
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        # Drop the trailing activation so the output layer stays linear (logits).
        layers.pop()

        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(*layers)

    def forward(self, x):
        """Flatten ``x`` per sample and return the 10 class logits."""
        return self.linear_relu_stack(self.flatten(x))

# Convolutional Neural Network
class ConvolutionalNeuralNetwork(nn.Module):
    """Two conv-conv-pool stages followed by a three-layer MLP head.

    Every 3x3 convolution uses padding=1, so only the two 2x2 max-pools shrink
    the spatial size (32 -> 16 -> 8 for 32x32 inputs). That is why the head
    expects 128 * 8 * 8 flattened features.
    """

    def __init__(self):
        super().__init__()

        def conv_stage(in_ch, mid_ch, out_ch):
            # conv -> ReLU -> conv -> ReLU -> 2x2 max-pool (pool halves H and W).
            return [
                nn.Conv2d(in_ch, mid_ch, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2),
            ]

        self.convolutional_stack = nn.Sequential(
            *conv_stage(3, 16, 32),
            *conv_stage(32, 64, 128),
        )
        self.flatten = nn.Flatten()

        head = [nn.Linear(128 * 8 * 8, 256), nn.ReLU()]
        head += [nn.Linear(256, 128), nn.ReLU()]
        head.append(nn.Linear(128, 10))  # final layer: raw logits, no activation
        self.linear_relu_stack = nn.Sequential(*head)

    def forward(self, x):
        """Return the 10 class logits for an image batch ``x``."""
        features = self.flatten(self.convolutional_stack(x))
        return self.linear_relu_stack(features)

# Training function
def train(dataloader, model, loss_fn, optimizer, device=None, log_interval=100):
    """Run one training epoch over ``dataloader`` and return the mean sample loss.

    Args:
        dataloader: Yields ``(X, y)`` batches; ``len(dataloader.dataset)`` is
            used only for the progress printout.
        model: Module being trained; it is switched to train mode here.
        loss_fn: Criterion called as ``loss_fn(pred, y)``. Assumed to use mean
            reduction (the ``nn.CrossEntropyLoss`` default), so each batch loss
            is re-weighted by its batch size when accumulating.
        optimizer: Optimizer over ``model``'s parameters.
        device: Where to move each batch. ``None`` (default) uses the device of
            the model's first parameter, or CPU for a parameterless model.
        log_interval: Print the current batch loss every this many batches
            (must be positive).

    Returns:
        Mean loss per training sample over the epoch as a float, or ``0.0``
        when the dataloader yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else "cpu"
    size = len(dataloader.dataset)
    model.train()
    total_loss, seen = 0.0, 0
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Clear gradients *before* backprop so anything left over from earlier
        # use of the model cannot leak into this update.
        optimizer.zero_grad()
        loss = loss_fn(model(X), y)
        loss.backward()
        optimizer.step()

        # Single host sync per batch; the value is reused for logging below.
        batch_loss = loss.item()
        # Weight by batch size: the last batch is usually smaller, so a plain
        # average of batch means would over-weight its samples.
        total_loss += batch_loss * len(X)
        seen += len(X)

        if batch % log_interval == 0:
            # ``seen`` counts samples processed *including* this batch.
            print(f"loss: {batch_loss:>7f}  [{seen:>5d}/{size:>5d}]")
    return total_loss / seen if seen else 0.0

# Testing function
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    correct_example, incorrect_example = None, None
    correct_label, incorrect_label, incorrect_pred = None, None, None

    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
            for i in range(len(X)):
                if correct_example is None and pred[i].argmax() == y[i]:
                    correct_example = X[i].cpu()
                    correct_label = y[i].item()
                if incorrect_example is None and pred[i].argmax() != y[i]:
                    incorrect_example = X[i].cpu()
                    incorrect_label = y[i].item()
                    incorrect_pred = pred[i].argmax().item()
                if correct_example is not None and incorrect_example is not None:
                    break
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
    return test_loss, correct_example, correct_label, incorrect_example, incorrect_label, incorrect_pred

# Initialize models, the shared loss function, and one Adam optimizer per model.
loss_fn = nn.CrossEntropyLoss()

fcnn_model = FullyConnectedNeuralNetwork().to(device)
cnn_model = ConvolutionalNeuralNetwork().to(device)

fcnn_optimizer = torch.optim.Adam(fcnn_model.parameters(), lr=1e-3)
cnn_optimizer = torch.optim.Adam(cnn_model.parameters(), lr=1e-3)

# Train and test both models, interleaved: each epoch runs FCNN then CNN.
epochs = 20

fcnn_losses, cnn_losses = [], []
# (display name, model, optimizer, list collecting train()'s per-epoch result)
runs = (
    ("Fully Connected Neural Network", fcnn_model, fcnn_optimizer, fcnn_losses),
    ("Convolutional Neural Network", cnn_model, cnn_optimizer, cnn_losses),
)
# Most recent test() result per model, keyed by display name.
last_eval = {}

for epoch in range(1, epochs + 1):
    print(f"Epoch {epoch}\n-------------------------------")
    for name, model, optimizer, history in runs:
        print(f"Training {name}")
        history.append(train(train_dataloader, model, loss_fn, optimizer))
        print(f"Testing {name}")
        last_eval[name] = test(test_dataloader, model, loss_fn)

fcnn_images = last_eval["Fully Connected Neural Network"]
cnn_images = last_eval["Convolutional Neural Network"]

# Plots for epoch vs loss: one figure per model. The x axis is numbered from 1
# so it matches the "Epoch N" headers printed during training (plotting the
# bare list would start the axis at 0).
def _plot_loss_curve(losses, model_name):
    """Show one figure of the per-epoch training loss recorded for ``model_name``."""
    epoch_numbers = range(1, len(losses) + 1)
    plt.figure()
    plt.plot(epoch_numbers, losses, label=model_name)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.title(f'Loss vs. Epoch for {model_name}')
    plt.show()


_plot_loss_curve(fcnn_losses, 'Fully Connected Neural Network')
_plot_loss_curve(cnn_losses, 'Convolutional Neural Network')

# Class names, looked up by the integer labels / predictions that test() returns.
classes = (
    "airplane automobile bird cat deer "
    "dog frog horse ship truck"
).split()

# Plots for correct/incorrect images: for each model, show the first correctly
# and the first incorrectly classified test image captured by the last test()
# call.
def _show_image(image, title):
    """Display an image tensor stored channel-first (C, H, W) under ``title``."""
    plt.figure()
    plt.imshow(image.permute(1, 2, 0))  # matplotlib expects (H, W, C)
    plt.title(title)
    plt.axis('off')
    plt.show()


def _show_examples(result):
    """Plot the correct / incorrect examples from one test() result tuple.

    test() leaves an example slot as None when no such prediction occurred
    (e.g. 100% or 0% accuracy). Those plots are skipped, because indexing
    ``classes`` with None or calling ``.permute`` on None would crash.
    """
    _, good_img, good_lbl, bad_img, bad_lbl, bad_pred = result
    if good_img is not None:
        _show_image(good_img, f'Correct Label: {classes[good_lbl]}')
    else:
        print("No correctly classified example to plot.")
    if bad_img is not None:
        _show_image(bad_img, f'Incorrect Label: {classes[bad_lbl]}, Predicted: {classes[bad_pred]}')
    else:
        print("No misclassified example to plot.")


_show_examples(fcnn_images)
_show_examples(cnn_images)
