In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms.v2 as transforms
from datetime import datetime

In [2]:
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

Using cuda device


In [3]:
transform = transforms.Compose([
    transforms.ToImage(),
    transforms.ToDtype(torch.float32, scale=True),
    transforms.Normalize((0.5,), (0.5,))
])

In [4]:
train_data = torchvision.datasets.MNIST(root='../data', train=True, transform=transform, download=True)
test_data = torchvision.datasets.MNIST(root='../data', train=False, transform=transform, download=True)

train_loader = torch.utils.data.DataLoader(train_data, batch_size=128, shuffle=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=1000, shuffle=True, num_workers=2)

In [5]:
class ShallowNet(nn.Module):
    def __init__(self, n_nodes):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, n_nodes)
        self.bn1 = nn.BatchNorm1d(n_nodes)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(n_nodes, 10)

    def forward(self, x):
        x = self.fc1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

In [6]:
class EarlyStopping:
    def __init__(self, patience=5, delta=0):
        self.patience = patience
        self.delta = delta
        self.best_score = None
        self.early_stop = False
        self.counter = 0
        self.best_model_state = None

    def __call__(self, val_loss, model):
        score = -val_loss
        if self.best_score is None:
            self.best_score = score
            self.best_model_state = model.state_dict()
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.best_model_state = model.state_dict()
            self.counter = 0

    def load_best_model(self, model):
        model.load_state_dict(self.best_model_state)

In [7]:
n_nodes = 16
net = ShallowNet(n_nodes).to(device)

loss_function = nn.CrossEntropyLoss()
optimizer = optim.Rprop(net.parameters(), lr=0.01)

In [8]:
def train_loop(running_loss, train_loader):
    for i, data in enumerate(train_loader):
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            
            inputs = inputs.view(-1, 28*28)
            outputs = net(inputs)
            
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step() 
        
            running_loss += loss.item()
    return running_loss

In [9]:
def test_loop(test_loader):
    correct = 0
    total = 0
    test_loss = 0.0

    net.eval()
    with torch.no_grad():
        for data in test_loader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)

            images = images.view(-1, 28*28)
            outputs = net(images)
            _, prediction = torch.max(outputs, 1)
            total += labels.size(0)

            loss = loss_function(outputs, labels)
            test_loss += loss.item() * images.size(0)

            correct += (prediction == labels).sum().item()

    accuracy = 100 * correct / total
    test_loss /= len(test_loader)
    return test_loss, accuracy


In [10]:
early_stopping = EarlyStopping(patience=10, delta=0.01)

best_accuracy = 0.0

print(f'Net structure: {net}')
print(f'Using optim: {optimizer}')
for epoch in range(50):
    print(f'Training epoch {epoch+1}...')

    running_loss = train_loop(running_loss=0.0, train_loader=train_loader)
    test_loss, accuracy = test_loop(test_loader=test_loader)

    print(f'Loss: {running_loss/len(train_loader):.4f}')
    print(f'Accuracy: {accuracy}%')

    if (accuracy > best_accuracy):
        best_accuracy = accuracy

    early_stopping(test_loss, net)
    if early_stopping.early_stop:
        print("Early stopping")
        break   

early_stopping.load_best_model(net)
print(f'Best accuracy was: {best_accuracy}% with n_nodes {n_nodes}')

Net structure: ShallowNet(
  (fc1): Linear(in_features=784, out_features=16, bias=True)
  (bn1): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU()
  (fc2): Linear(in_features=16, out_features=10, bias=True)
)
Using optim: Rprop (
Parameter Group 0
    capturable: False
    differentiable: False
    etas: (0.5, 1.2)
    foreach: None
    lr: 0.01
    maximize: False
    step_sizes: (1e-06, 50)
)
Training epoch 1...
Loss: 1.0559
Accuracy: 87.97%
Training epoch 2...
Loss: 1.0240
Accuracy: 88.03%
Training epoch 3...
Loss: 0.9520
Accuracy: 88.07%
Training epoch 4...
Loss: 0.9218
Accuracy: 88.24%
Training epoch 5...
Loss: 0.9122
Accuracy: 88.35%
Training epoch 6...
Loss: 0.9089
Accuracy: 88.33%
Training epoch 7...
Loss: 0.8838
Accuracy: 88.38%
Training epoch 8...
Loss: 0.8904
Accuracy: 88.42%
Training epoch 9...
Loss: 0.8670
Accuracy: 88.51%
Training epoch 10...


KeyboardInterrupt: 

In [None]:
correct = 0
total = 0

net.eval()

with torch.no_grad():
    for data in test_loader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)

        outputs = net(images)
        _, prediction = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (prediction == labels).sum().item()

accuracy = 100 * correct / total

print(f'Accuracy: {accuracy}%')

Accuracy: 89.34%
