In [None]:
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt

from torch.utils.data import DataLoader
from torch import nn

# **Dataset**

## **Load dataset**

In [None]:
# Preprocessing shared by both splits: ToTensor() turns each image into a float
# tensor in [0, 1], then Normalize with mean = std = 0.5 on each of the three
# RGB channels rescales it to [-1, 1] via (x - 0.5) / 0.5.
transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
])

In [None]:
# Build the CIFAR-10 train (train=True) and test (train=False) splits under
# ./data, downloading them on first use; both use the same preprocessing.
training_data, testing_data = (
    torchvision.datasets.CIFAR10(
        root="data",
        train=is_train_split,
        download=True,
        transform=transforms,
    )
    for is_train_split in (True, False)
)

In [None]:
# Raw (untransformed) training images and their integer labels, for inspection.
X = training_data.data
y = training_data.targets
print(
    f"Type of X: {type(X)}",
    f"Size of X: {X.shape}",
    f"Type of y: {type(y)}",
    f"Size of y: {len(y)}",
    sep="\n",
)

In [None]:
# Per-label image counts for label indices 0..9, to check class balance.
num_class = dict(zip(range(10), map(y.count, range(10))))
print(f"Number of each class in training_data: \n{num_class}")

# **Model**

In [None]:
from torch.nn.modules.pooling import MaxPool2d
class NeuralNetwork(nn.Module):
    """VGG-style CNN producing 10 class logits from 3-channel images.

    Three conv stages each end in a 2x2 max-pool. The first Linear layer expects
    256 * 4 * 4 = 4096 flattened features, which is what a 3x32x32 input (e.g.
    CIFAR-10) produces after the three poolings (32 -> 16 -> 8 -> 4).
    """

    def __init__(self):
        super().__init__()
        # Attribute names are kept unchanged: they determine state_dict keys.
        self.Stage_1 = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding='same'),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, padding='same'),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        self.Stage_2 = nn.Sequential(
            nn.Conv2d(64, 128, 3, padding='same'),
            nn.ReLU(),
            nn.Conv2d(128, 128, 3, padding='same'),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        self.Stage_3 = nn.Sequential(
            nn.Conv2d(128, 256, 3, padding='same'),
            nn.ReLU(),
            nn.Conv2d(256, 256, 3, padding='same'),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        self.Stage_4 = nn.Flatten(start_dim=1, end_dim=-1)

        self.Stage_5 = nn.Sequential(
            nn.Linear(4096, 1024, bias=True),  # 4096 = 256 channels * 4 * 4
            nn.ReLU(),
            nn.Linear(1024, 512, bias=True),
            nn.ReLU(),
            nn.Linear(512, 10, bias=True),
        )

    def forward(self, X):
        """Return raw class logits of shape (batch, 10) for a batch of images X."""
        features = self.Stage_3(self.Stage_2(self.Stage_1(X)))
        return self.Stage_5(self.Stage_4(features))

    def train(self, training_data=True, optimizer="sgd", batch_size=64, epochs=10, lr=5e-6,
              num_workers=2, *, mode=None):
        """Fit the model on a dataset, or toggle training mode like ``nn.Module.train``.

        This method shares its name with ``nn.Module.train(mode=True)``, which
        ``nn.Module.eval()`` calls as ``self.train(False)``. To keep that working,
        a bool first argument (or ``mode=...``) is forwarded to ``nn.Module.train``
        and the module is returned. Any other first argument is treated as a
        dataset to train on, and None is returned.

        Args:
            training_data: dataset yielding ``(input, label)`` pairs, or a bool
                training-mode flag.
            optimizer: "sgd" (with momentum 0.9) or "adam", case-insensitive.
            batch_size: minibatch size of the shuffled DataLoader.
            epochs: number of full passes over ``training_data``.
            lr: learning rate.
            num_workers: DataLoader worker processes (0 loads in the main process).
            mode: keyword alias for the ``nn.Module.train`` mode flag.

        Raises:
            ValueError: if ``optimizer`` is neither "sgd" nor "adam".
        """
        if mode is not None:
            return super().train(mode)
        if isinstance(training_data, bool):
            return super().train(training_data)
        self._fit(training_data, optimizer, batch_size, epochs, lr, num_workers)

    def _build_optimizer(self, name, lr):
        """Return an Adam or momentum-SGD optimizer over this model's parameters."""
        key = name.lower()
        if key == "adam":
            return torch.optim.Adam(self.parameters(), lr=lr)
        if key == "sgd":
            return torch.optim.SGD(self.parameters(), lr=lr, momentum=0.9)
        raise ValueError(f"Unsupported optimizer {name!r}; expected 'sgd' or 'adam'")

    def _fit(self, training_data, optimizer, batch_size, epochs, lr, num_workers):
        """Minibatch training loop; prints the mean training loss of each epoch."""
        # Build the optimizer first so a bad name fails before any data loading.
        optim = self._build_optimizer(optimizer, lr)
        criterion = nn.CrossEntropyLoss()
        training_dataloader = DataLoader(
            training_data, batch_size=batch_size, shuffle=True, num_workers=num_workers
        )
        # Batches are moved to wherever the model's weights live (CPU or GPU).
        device = next(self.parameters()).device
        super().train(True)

        print("-----------------------------------------------")
        for epoch in range(1, epochs + 1):
            running_loss, num_batches = 0.0, 0
            for X, y in training_dataloader:
                X, y = X.to(device), y.to(device)
                loss = criterion(self(X), y)

                optim.zero_grad()
                loss.backward()
                optim.step()

                running_loss += loss.item()
                num_batches += 1
            # Guard against an empty dataset instead of dividing by zero.
            mean_loss = running_loss / num_batches if num_batches else float("nan")
            print(f"-- Epoch {epoch}/{epochs}: Loss = {mean_loss:.4f}")

    def evaluate(self, testing_data, batch_size=256):
        """Return top-1 accuracy, in percent (0-100), on ``testing_data``.

        Runs in eval mode under ``torch.no_grad()`` (no autograd graph is built)
        and restores the previous train/eval mode afterwards.

        Raises:
            ValueError: if ``testing_data`` yields no samples.
        """
        testing_dataloader = DataLoader(testing_data, batch_size=batch_size, shuffle=False)
        device = next(self.parameters()).device
        was_training = self.training
        super().train(False)
        correct, total = 0, 0
        try:
            with torch.no_grad():
                for X, y in testing_dataloader:
                    X, y = X.to(device), y.to(device)
                    predictions = self(X).argmax(dim=1)
                    correct += (predictions == y).sum().item()
                    total += y.numel()
        finally:
            super().train(was_training)
        if total == 0:
            raise ValueError("cannot evaluate on an empty dataset")
        return (correct / total) * 100

In [None]:
# Seed torch's global RNG so weight initialisation and DataLoader shuffling
# repeat across Restart & Run All (exact GPU determinism is not guaranteed).
torch.manual_seed(0)
model = NeuralNetwork()
model.train(training_data, optimizer='adam', epochs=3, lr=1e-3)

In [None]:
# Accuracy on the data the model was fit on vs. held-out data; a large gap
# between the two suggests overfitting.
acc_on_trainset, acc_on_testset = (
    model.evaluate(split) for split in (training_data, testing_data)
)
print(f"Accuracy on trainset: {acc_on_trainset}")
print(f"Accuracy on testset: {acc_on_testset}")

In [None]:
# Continue training the same model for three more epochs. Each call constructs
# a fresh Adam optimizer, so its internal state starts over.
model.train(training_data, optimizer="adam", lr=1e-3, epochs=3)

In [None]:
# Re-measure train/test accuracy after the additional training epochs.
acc_on_trainset, acc_on_testset = (
    model.evaluate(split) for split in (training_data, testing_data)
)
print(f"Accuracy on trainset: {acc_on_trainset}")
print(f"Accuracy on testset: {acc_on_testset}")