### Inspired by: https://www.youtube.com/watch?v=aU8OF0htbTo&t=133s&ab_channel=PatrickLoeber

In [None]:
import torch
import torchvision
from torchvision import transforms
import matplotlib.pyplot as plt
from tqdm import tqdm
from copy import deepcopy

In [None]:
# --- Experiment hyper-parameters ---
random_seed = 1
n_epochs = 3
batch_size_train = batch_size_test = 64  # same mini-batch size for both splits
learning_rate, momentum = 0.01, 0.5      # passed to the SGD optimizer
log_interval = 10

# Seed PyTorch's global random number generator so runs are repeatable.
torch.manual_seed(random_seed)

# Reproducibility aid: this switches the cuDNN backend off altogether, not
# just its auto-tuner (the auto-tuner alone is `torch.backends.cudnn.benchmark`).
torch.backends.cudnn.enabled = False

# Use the GPU when one is visible, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Normalisation constants. With mean = std = 0.5 the [0, 1] pixel values
# produced by ToTensor are rescaled to [-1, 1]. These are NOT the dataset
# statistics; the MNIST training-set values are kept here for reference:
#   global_mean = 0.1307
#   global_std = 0.3081
global_mean = 0.5
global_std = 0.5

# Per-image preprocessing: convert to a float tensor, then normalise.
# Images keep their image shape here; flattening to 784 features is left to
# the training/evaluation functions rather than done inside the transform.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(global_mean,), std=(global_std,)),
    ]
)

In [None]:
# Both MNIST splits are stored in the same local folder and downloaded on
# first use; each sample goes through `transform` when it is loaded.
train_dataset = torchvision.datasets.MNIST(
    root='./mnst_files/', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.MNIST(
    root='./mnst_files/', train=False, download=True, transform=transform)

# Mini-batch iterators. Both are reshuffled on every pass; for the test split
# this only changes the order in which samples are visited.
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size_train, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=batch_size_test, shuffle=True)

In [None]:
import matplotlib.pyplot as plt
# Pull a single batch from the test loader to inspect shapes and samples.
examples = enumerate(test_loader)
batch_idx, (example_data, example_targets) = next(examples)
print(example_data.shape, example_targets.shape, sep="\n")

# Draw the first six images of the batch in a 2x3 grid, each titled with
# its label and with the axis ticks hidden.
n_rows, n_cols = 2, 3
fig = plt.figure()
for i in range(n_rows * n_cols):
    plt.subplot(n_rows, n_cols, i + 1)
    plt.tight_layout()
    # Index 0 selects the single channel, leaving a 2-D image for imshow.
    plt.imshow(example_data[i][0], cmap='gray', interpolation='none')
    plt.title(f"Ground Truth: {example_targets[i]}")
    plt.xticks([])
    plt.yticks([])
# Bare expression so the notebook renders the figure as the cell output.
fig

### Building the Network

In [None]:
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [None]:
class NeuralNetwork(nn.Module):
    """Fully connected classifier with two ReLU hidden layers.

    Args:
        input_size: Number of features of each (already flattened) sample.
        n_classes: Number of output scores produced per sample.
        hidden_size: Widths of the two hidden layers. Any indexable sequence
            works (list or tuple); only the first two entries are read.
    """

    # The default is an immutable tuple so that no mutable object is shared
    # between all instances created with default arguments.
    def __init__(self, input_size=784, n_classes=10, hidden_size=(128, 64)):
        super().__init__()
        first_width, second_width = hidden_size[0], hidden_size[1]
        # Layers keep PyTorch's default nn.Linear initialisation.
        self.fc1 = nn.Linear(input_size, first_width)
        self.fc2 = nn.Linear(first_width, second_width)
        self.fc3 = nn.Linear(second_width, n_classes)

    def forward(self, x):
        """Return the raw class scores (logits) for a batch of inputs.

        No softmax is applied here; the output of the last linear layer is
        returned unchanged.
        """
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

### Define training loop

In [None]:
# Training function
def train(model, data_loader, optimizer, loss_fn, device):
    """Run one optimisation pass over ``data_loader``.

    Args:
        model: Network to optimise; it is switched to training mode.
        data_loader: Sized iterable of ``(inputs, targets)`` batches.
        optimizer: Optimizer that updates ``model``'s parameters.
        loss_fn: Callable ``loss_fn(predictions, targets)`` returning a
            scalar tensor.
        device: Device the batches are moved to before the forward pass.

    Returns:
        ``(mean_loss, accuracy)``: the loss averaged over the number of
        batches, and the percentage of correctly classified samples.

    Raises:
        ZeroDivisionError: If ``data_loader`` yields no samples.
    """
    model.train()
    loss_total = 0.0
    correct = 0
    total_samples = 0

    for x, y in data_loader:
        x, y = x.to(device), y.to(device)
        # Flatten every sample to a 1-D feature vector.
        x = x.view(x.size(0), -1)
        optimizer.zero_grad()
        y_hat = model(x)
        loss = loss_fn(y_hat, y)
        loss.backward()
        optimizer.step()
        loss_total += loss.item()
        correct += (y_hat.argmax(1) == y).sum().item()
        total_samples += y.size(0)

    accuracy = 100.0 * correct / total_samples
    # Average over the loader that was passed in, not a module-level global,
    # so the result is right for any loader.
    loss_total /= len(data_loader)

    return loss_total, accuracy

# Validation function
def validate(model, data_loader, loss_fn, device):
    """Score ``model`` on ``data_loader`` without updating its weights.

    The model is put in evaluation mode and gradients are disabled for the
    whole pass.

    Returns:
        ``(mean_loss, accuracy)``: the loss averaged over the number of
        batches, and the percentage of correctly classified samples.
    """
    model.eval()
    running_loss, n_correct, n_seen = 0.0, 0, 0

    with torch.no_grad():
        for inputs, targets in data_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            # Flatten every sample to a 1-D feature vector before the forward pass.
            logits = model(inputs.view(inputs.size(0), -1))
            running_loss += loss_fn(logits, targets).item()
            n_correct += (logits.argmax(1) == targets).sum().item()
            n_seen += targets.size(0)

    accuracy = 100.0 * n_correct / n_seen
    mean_loss = running_loss / len(data_loader)
    return mean_loss, accuracy

def evaluate(model, data_loader, device):
    """Predict labels for a labelled loader and measure accuracy.

    Args:
        model: Network to evaluate; it is switched to evaluation mode.
        data_loader: Iterable of ``(inputs, targets)`` batches.
        device: Device the batches are moved to before the forward pass.

    Returns:
        ``(predictions, accuracy)``: a 1-D tensor with the predicted class
        index of every sample, in iteration order, and the percentage of
        correctly classified samples.

    Raises:
        ZeroDivisionError: If ``data_loader`` yields no samples.
    """
    model.eval()
    predictions = []
    correct = 0
    total_samples = 0

    with torch.no_grad():
        for x, y in data_loader:
            x, y = x.to(device), y.to(device)
            # Flatten every sample to a 1-D feature vector.
            x = x.view(x.size(0), -1)
            predicted = model(x).argmax(1)
            # Collect the per-batch predictions; without this the final
            # torch.cat would receive an empty list and raise.
            predictions.append(predicted)
            correct += (predicted == y).sum().item()
            total_samples += y.size(0)

    accuracy = 100.0 * correct / total_samples
    return torch.cat(predictions), accuracy


def predict(model, data_loader, device):
    """Return the predicted class index for every sample in ``data_loader``.

    Args:
        model: Network used for inference; it is switched to evaluation mode.
        data_loader: Iterable of batches. Each batch is either an input
            tensor, or a list/tuple whose first element is the input tensor
            (e.g. ``(inputs, targets)``); any further elements are ignored.
        device: Device the inputs are moved to before the forward pass.

    Returns:
        A 1-D tensor of predicted class indices in iteration order. It is
        empty when ``data_loader`` yields no batches.
    """
    model.eval()
    predictions = []

    with torch.no_grad():
        for batch in data_loader:
            # Labelled loaders yield sequences; keep only the inputs.
            x = batch[0] if isinstance(batch, (list, tuple)) else batch
            x = x.to(device)
            # Flatten every sample to a 1-D feature vector.
            x = x.view(x.size(0), -1)
            _, predicted = model(x).max(1)
            predictions.append(predicted)

    if not predictions:
        # torch.cat rejects an empty list, so return an empty result instead.
        return torch.empty(0, dtype=torch.long, device=device)
    return torch.cat(predictions)

class EarlyStopper:
    """Signal when the validation loss has stopped improving.

    The weights of the best model seen so far are snapshotted so they can be
    restored once training ends.

    Args:
        patience: Number of consecutive non-improving checks after which
            ``early_stop`` returns True.
        min_delta: Amount by which the loss must drop below the best value
            seen so far to count as an improvement.
    """

    def __init__(self, patience=1, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0  # consecutive checks without improvement
        self.best_validation_loss = float('inf')
        self.best_model_weights = None  # state_dict copy of the best model

    def early_stop(self, validation_loss, model):
        """Record one validation result and report whether to stop.

        On improvement the best loss and a copy of ``model``'s weights are
        stored and the counter is reset; otherwise the counter is incremented.

        Returns:
            True once ``patience`` consecutive checks brought no improvement,
            False otherwise.
        """
        improved = validation_loss < self.best_validation_loss - self.min_delta
        if improved:
            self.best_validation_loss = validation_loss
            # Reset to zero so a full `patience` window is available again.
            self.counter = 0
            self.save_best_weights(model)
        else:
            self.counter += 1
        # Report after the update so the printed counter is the current one.
        print(f"Early Stopping counter: {self.counter} out of {self.patience}")
        return (not improved) and self.counter >= self.patience

    def save_best_weights(self, model):
        """Store an independent copy of ``model``'s current state_dict."""
        self.best_model_weights = deepcopy(model.state_dict())

    def restore_best_weights(self, model):
        """Load the stored best weights into ``model``.

        Does nothing when no snapshot has been taken yet.
        """
        if self.best_model_weights is None:
            return
        model.load_state_dict(self.best_model_weights)

In [None]:
# Per-epoch metric history, one entry appended per completed epoch.
train_losses = []
val_losses = []
train_accuracies = []
val_accuracies = []

# Training parameters. n_epochs is only an upper bound: early stopping
# usually ends the run sooner.
n_epochs = 100
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Instantiate model, loss function, and optimizer
model = NeuralNetwork().to(device)
loss_fn = nn.CrossEntropyLoss()
#optimizer = optim.Adam(model.parameters(), lr=0.8, weight_decay=0.01)
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)
early_stopper = EarlyStopper(patience=5, min_delta=0.001)

for epoch in tqdm(range(n_epochs), desc='Training Progress'):
    # Training
    train_loss, train_acc = train(model, train_loader, optimizer, loss_fn, device)
    train_losses.append(train_loss)
    train_accuracies.append(train_acc)
    print(f"[{epoch+1}/{n_epochs}] Train loss: {train_loss:.4f} acc: {train_acc:.2f}%")

    # Validation
    # NOTE(review): the test split doubles as the validation set here, so the
    # early-stopping decision is tuned on the data used for the reported score.
    val_loss, val_acc = validate(model, test_loader, loss_fn, device)
    val_losses.append(val_loss)
    val_accuracies.append(val_acc)
    print(f'[{epoch+1}/{n_epochs}] Val loss: {val_loss:.4f} acc: {val_acc:.2f}%')

    if early_stopper.early_stop(val_loss, model):
        # Do not snapshot the weights here: at this point the model holds the
        # latest, non-improving weights, and saving them would overwrite the
        # best snapshot that is restored below.
        print("Patience Depleated: Early Stopping triggered.")
        break

# Roll the model back to the weights with the lowest validation loss.
early_stopper.restore_best_weights(model)

In [None]:
import matplotlib.pyplot as plt
# Two stacked panels sharing the x axis: loss on top, accuracy below.
fig, (ax1, ax2) = plt.subplots(2, figsize=(12, 6), sharex=True)

# The history lists already hold plain Python floats (the training loop
# stores `.item()` values and computed percentages), so no tensor or device
# conversion is needed. Plain copies keep the original variable names
# available and avoid rounding the values to float32.
train_loss_cpu = list(train_losses)
val_loss_cpu = list(val_losses)
train_acc_cpu = list(train_accuracies)
val_acc_cpu = list(val_accuracies)

ax1.plot(train_loss_cpu, label="train loss")
ax1.plot(val_loss_cpu, label="validation loss")
ax1.set_ylabel("Loss")
ax2.plot(train_acc_cpu, label="train accuracy")
ax2.plot(val_acc_cpu, label="validation accuracy")
ax2.set_ylabel("Accuracy")
ax2.set_xlabel("epochs")
ax1.legend()
ax2.legend()