In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import os
import numpy as np
from torch.utils.data import DataLoader, TensorDataset, random_split

# Preprocessing applied to every image: convert to a tensor, then normalize
# the single channel with mean 0.5 and standard deviation 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Training split, served in shuffled mini-batches of 64.
mnist_train = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
mnist_train_loader = DataLoader(mnist_train, batch_size=64, shuffle=True)

# Test split, served unshuffled in mini-batches of 64.
mnist_test = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
mnist_test_loader = DataLoader(mnist_test, batch_size=64, shuffle=False)

class LeNet5_MNIST(nn.Module):
    """LeNet-5 style CNN mapping single-channel images to 10 class logits.

    Two 5x5 convolutions, each followed by ReLU and 2x2 max pooling, feed
    three fully connected layers. ``fc1`` takes ``16 * 4 * 4`` features,
    which is what a 28x28 input is reduced to by the two conv/pool stages
    (28 -> 24 -> 12 -> 8 -> 4).
    """

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        # Fully connected classifier head.
        self.fc1 = nn.Linear(in_features=16 * 4 * 4, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        """Return the raw (un-normalized) class scores for a batch ``x``."""
        # Each stage: convolution -> ReLU -> 2x2 max pooling.
        for conv in (self.conv1, self.conv2):
            x = torch.max_pool2d(torch.relu(conv(x)), 2)

        # Collapse everything except the batch dimension.
        flat = x.view(x.size(0), -1)

        hidden = torch.relu(self.fc2(torch.relu(self.fc1(flat))))
        # No activation on the last layer: the outputs are logits.
        return self.fc3(hidden)

# Build the network together with its loss function and optimizer.
model_mnist = LeNet5_MNIST()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model_mnist.parameters(), lr=1e-3)

# Checkpoint location used for both saving and loading the weights.
model_path = './lenet5_mnist.pth'

# Train only when no checkpoint is present; otherwise reuse the stored weights.
if not os.path.exists(model_path):
    print('No saved model found, training from scratch.')

    # Training loop for MNIST
    epochs = 10
    for epoch in range(epochs):
        model_mnist.train()
        batch_losses = []
        for x_batch, y_batch in mnist_train_loader:
            optimizer.zero_grad()
            y_pred = model_mnist(x_batch)
            loss = criterion(y_pred, y_batch)
            loss.backward()
            optimizer.step()
            batch_losses.append(loss.item())
        # Reported value is the mean of the per-batch losses for this epoch.
        total_loss = sum(batch_losses)
        print(f'Epoch {epoch+1}/{epochs}, MNIST Train Loss: {total_loss / len(mnist_train_loader):.4f}')

    # Persist only the parameters (state dict), not the whole module object.
    torch.save(model_mnist.state_dict(), model_path)
    print(f'Model saved to {model_path}')
else:
    print(f'Loading saved model from {model_path}')
    state_dict = torch.load(model_path, weights_only=True)
    model_mnist.load_state_dict(state_dict)

# Test: fraction of test-loader samples whose highest-scoring class matches
# the label.
model_mnist.eval()
correct = 0
total = 0
with torch.no_grad():  # gradients are not needed for evaluation
    for x_batch, y_batch in mnist_test_loader:
        y_pred = model_mnist(x_batch)
        # Predicted class is the index of the largest logit in each row.
        predicted = y_pred.argmax(dim=1)
        correct += int((predicted == y_batch).sum())
        total += len(y_batch)
mnist_acc = correct / total
print(f'MNIST Test Accuracy: {mnist_acc:.4f}')
