In [7]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score
import numpy as np


import torch.nn.functional as F


class SimpleNet(nn.Module):
    """Small convolutional classifier producing 10-way log-probabilities.

    The layer sizes are fixed for single-channel 28x28 inputs: three
    un-padded 5x5 convolutions with two 2x2 max-pools shrink the spatial
    size 28 -> 24 -> 20 -> 10 -> 6 -> 3, which is why ``fc1`` expects
    ``3 * 3 * 64`` input features.
    """

    def __init__(self):
        super().__init__()

        self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
        self.conv2 = nn.Conv2d(32, 32, kernel_size=5)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=5)

        self.fc1 = nn.Linear(3 * 3 * 64, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        """Return per-class log-probabilities for a batch of images.

        Args:
            x: Tensor of shape ``(batch, 1, 28, 28)``.

        Returns:
            Tensor of shape ``(batch, 10)`` holding ``log_softmax`` scores.

        Raises:
            RuntimeError: If the spatial size of ``x`` does not reduce to
                3x3 after the convolution/pooling stack.
        """
        x = F.relu(self.conv1(x))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        # Dropout is only active while the module is in training mode.
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu(F.max_pool2d(self.conv3(x), 2))
        x = F.dropout(x, p=0.5, training=self.training)
        # Flatten per sample instead of view(-1, 576): with a wrong input
        # size the old form could silently regroup elements into a different
        # batch size; now fc1 raises a shape error instead.
        x = x.flatten(start_dim=1)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)

        return F.log_softmax(x, dim=1)

def calculate_metrics(loader: DataLoader, model: nn.Module):
    """Evaluate ``model`` on every batch of ``loader`` and score it.

    The model is switched to eval mode (and left in it) and gradients are
    disabled while predictions are collected.

    Args:
        loader: Yields ``(inputs, labels)`` batches.
        model: Classifier whose output has one score per class along dim 1.

    Returns:
        Tuple ``(accuracy, recall, precision, f1)``; recall, precision and
        F1 are macro-averaged over classes.
    """
    # Run inference on whichever device holds the model's parameters so the
    # function also works for models that were moved off the CPU.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    y_true = []
    y_pred = []
    model.eval()
    with torch.no_grad():
        for inputs, labels in loader:
            if device is not None:
                inputs = inputs.to(device)
            predicted = model(inputs).argmax(dim=1)
            # Bring results back to the CPU before handing them to sklearn.
            y_pred.extend(predicted.cpu().tolist())
            y_true.extend(labels.cpu().tolist())

    accuracy = accuracy_score(y_true, y_pred)
    recall = recall_score(y_true, y_pred, average='macro')
    precision = precision_score(y_true, y_pred, average='macro')
    f1 = f1_score(y_true, y_pred, average='macro')
    return accuracy, recall, precision, f1

# Convert images to tensors and rescale pixel values from [0, 1] to [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# MNIST train split for fitting, test split used as the validation set.
trainset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
valset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
trainloader = DataLoader(dataset=trainset, batch_size=64, shuffle=True)
valloader = DataLoader(dataset=valset, batch_size=64, shuffle=False)


net = SimpleNet()
# NOTE: SimpleNet.forward already ends in log_softmax. CrossEntropyLoss
# applies log_softmax again, which leaves log-probabilities unchanged, so the
# loss value is the same as nn.NLLLoss would give.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(params=net.parameters(), lr=0.01, momentum=0.9)

# Training continues until every validation metric reaches this score.
TARGET_SCORE = 0.95
# Hard cap on epochs so a plateau below the target cannot loop forever.
MAX_EPOCHS = 50

val_acc, val_rec, val_prec, val_f1 = 0, 0, 0, 0
epoch = 0

# Keep training while ANY metric is still below the target. The previous
# `and` chain stopped as soon as a single metric crossed the threshold.
while min(val_acc, val_rec, val_prec, val_f1) < TARGET_SCORE and epoch < MAX_EPOCHS:

    running_loss = 0.0
    # calculate_metrics leaves the model in eval mode, so re-enable training
    # behaviour (dropout) at the start of every epoch.
    net.train()
    for inputs, labels in trainloader:
        optimizer.zero_grad()

        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f'Epoch {epoch + 1}, Loss: {running_loss / len(trainloader):.3f}')

    train_acc, train_rec, train_prec, train_f1 = calculate_metrics(trainloader, net)
    print(f'Training - Accuracy: {train_acc}, Recall: {train_rec}, Precision: {train_prec}, F1 Score: {train_f1}')

    val_acc, val_rec, val_prec, val_f1 = calculate_metrics(valloader, net)
    print(f'Validation - Accuracy: {val_acc}, Recall: {val_rec}, Precision: {val_prec}, F1 Score: {val_f1}')

    epoch += 1

if min(val_acc, val_rec, val_prec, val_f1) < TARGET_SCORE:
    print(f'Stopped after {epoch} epochs without reaching {TARGET_SCORE} on all validation metrics')

print('Finished Training')

Epoch 1, Loss: 0.398
Training - Accuracy: 0.9048666666666667, Recall: 0.9035876703544565, Precision: 0.9061393442307928, F1 Score: 0.9030244593094515
Validation - Accuracy: 0.9055, Recall: 0.9042650959397939, Precision: 0.9061660697689474, F1 Score: 0.9034656441641733
Epoch 2, Loss: 0.322
Training - Accuracy: 0.90945, Recall: 0.9088221816912461, Precision: 0.9116646293776501, F1 Score: 0.9084181049509684
Validation - Accuracy: 0.9051, Recall: 0.904467189096825, Precision: 0.9074041509420676, F1 Score: 0.9039410091324687
Epoch 3, Loss: 0.313
Training - Accuracy: 0.9092666666666667, Recall: 0.9078265881555572, Precision: 0.9123361157518012, F1 Score: 0.9084821647208322
Validation - Accuracy: 0.9087, Recall: 0.9071194467241408, Precision: 0.9108829502220803, F1 Score: 0.907582462127668
Epoch 4, Loss: 0.305
Training - Accuracy: 0.91425, Recall: 0.9121144741755248, Precision: 0.916283925509908, F1 Score: 0.9130611663617023
Validation - Accuracy: 0.9125, Recall: 0.9106133233312528, Precision