In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms.v2 as transforms
from datetime import datetime

In [2]:
# Pick the compute backend in order of preference: CUDA, then Apple MPS,
# and finally the CPU when neither accelerator is available.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

Using cuda device


In [3]:
# Preprocessing applied to every sample, in order:
#   1. ToImage   - wrap the raw sample as an image tensor
#   2. ToDtype   - cast to float32 (scale=True also rescales the values)
#   3. Normalize - subtract 0.5 and divide by 0.5 on the single channel
preprocessing_steps = [
    transforms.ToImage(),
    transforms.ToDtype(torch.float32, scale=True),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
]
transform = transforms.Compose(preprocessing_steps)

In [4]:
# Samples per mini-batch, shared by the training and evaluation loaders.
batch_size = 128

# MNIST train/test splits stored under ../data (downloaded when missing),
# each sample run through the `transform` pipeline.
train_data = torchvision.datasets.MNIST(root='../data', train=True, transform=transform, download=True)
test_data = torchvision.datasets.MNIST(root='../data', train=False, transform=transform, download=True)

# Training batches are reshuffled on every pass over the data.
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=2)
# Evaluation only aggregates over the whole split, so shuffling is pointless
# there; a fixed order keeps evaluation runs comparable batch for batch.
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=2)

In [5]:
class ShallowNet(nn.Module):
    """Fully connected classifier with a single hidden layer.

    The input is flattened, passed through one linear layer of ``n_nodes``
    units followed by a ReLU, and then projected to 10 output scores.
    """

    def __init__(self, n_nodes: int) -> None:
        """Create the layers.

        Args:
            n_nodes: Number of units in the hidden layer.
        """
        super().__init__()
        # The attribute names below become the state_dict keys, so they must
        # stay stable for saved checkpoints to keep loading.
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(in_features=28 * 28, out_features=n_nodes)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(in_features=n_nodes, out_features=10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the raw class scores (no softmax applied) for batch ``x``."""
        hidden = self.relu(self.fc1(self.flatten(x)))
        return self.fc2(hidden)

In [7]:
# Width of the single hidden layer.
n_nodes = 256

# Build the model, then move its parameters onto the selected device.
net = ShallowNet(n_nodes=n_nodes)
net = net.to(device)

# Classification loss applied to the raw scores returned by the network.
loss_function = nn.CrossEntropyLoss()

# NOTE(review): Rprop is generally described as a full-batch method; here it
# is stepped once per mini-batch - confirm this is the intended experiment.
optimizer = optim.Rprop(params=net.parameters(), lr=0.01)

In [8]:
def train_loop(running_loss, train_loader):
    """Run one optimisation pass over ``train_loader`` and accumulate the loss.

    Works on the notebook-level ``net``, ``optimizer``, ``loss_function`` and
    ``device`` objects.

    Args:
        running_loss: Starting value of the loss accumulator.
        train_loader: Iterable yielding ``(inputs, labels)`` batches.

    Returns:
        ``running_loss`` plus the sum of ``loss.item()`` over all batches.
    """
    # test_loop() switches the model to eval mode and nothing switched it
    # back, so every epoch after the first used to train in eval mode.
    net.train()

    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        # Clear gradients left over from the previous step.
        optimizer.zero_grad()

        outputs = net(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    return running_loss

In [9]:
def test_loop(test_loader):
    correct = 0
    total = 0
    test_loss = 0.0

    net.eval()
    with torch.no_grad():
        for data in test_loader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)

            outputs = net(images)
            _, prediction = torch.max(outputs, 1)
            total += labels.size(0)

            loss = loss_function(outputs, labels)
            test_loss += loss.item() * images.size(0)

            correct += (prediction == labels).sum().item()

    accuracy = 100 * correct / total
    test_loss /= len(test_loader)
    return test_loss, accuracy


In [10]:
# Record the optimizer configuration next to the training log.
print(f'Using optim: {optimizer}')

for epoch in range(50):
    print(f'Training epoch {epoch+1}...')

    # One pass over the training data, then an evaluation on the test split.
    running_loss = train_loop(0.0, train_loader)
    test_loss, accuracy = test_loop(test_loader)

    # "Loss" is the training loss averaged over batches; test_loss is
    # computed above but is not part of the printed report.
    mean_train_loss = running_loss / len(train_loader)
    print(f'Loss: {mean_train_loss:.4f}')
    print(f'Accuracy: {accuracy}%')

Using optim: Rprop (
Parameter Group 0
    capturable: False
    differentiable: False
    etas: (0.5, 1.2)
    foreach: None
    lr: 0.01
    maximize: False
    step_sizes: (1e-06, 50)
)
Training epoch 1...
Loss: 1.6065
Accuracy: 88.21%
Training epoch 2...
Loss: 2.7950
Accuracy: 87.99%
Training epoch 3...
Loss: 2.6479
Accuracy: 88.7%
Training epoch 4...
Loss: 2.6309
Accuracy: 88.69%
Training epoch 5...
Loss: 2.3588
Accuracy: 88.7%
Training epoch 6...
Loss: 2.4118
Accuracy: 88.91%
Training epoch 7...
Loss: 2.4185
Accuracy: 88.55%
Training epoch 8...
Loss: 2.4454
Accuracy: 89.22%
Training epoch 9...
Loss: 2.2739
Accuracy: 88.72%
Training epoch 10...
Loss: 2.2243
Accuracy: 88.59%
Training epoch 11...
Loss: 2.2552
Accuracy: 88.82%
Training epoch 12...
Loss: 2.2776
Accuracy: 88.87%
Training epoch 13...
Loss: 2.5929
Accuracy: 88.95%
Training epoch 14...
Loss: 2.1878
Accuracy: 89.3%
Training epoch 15...
Loss: 2.1063
Accuracy: 89.12%
Training epoch 16...
Loss: 2.4106
Accuracy: 89.12%
Trainin

In [11]:
# Stamp the checkpoint name with the current local time so repeated runs
# do not overwrite one another.
current_time = f"{datetime.now():%Y%m%d_%H%M%S}"
filename = "trained_model_" + current_time + ".pth"

# Only the parameters (state_dict) are written, not the module object.
torch.save(net.state_dict(), filename)

In [12]:
# Read the saved parameters back (weights_only=True restricts unpickling to
# tensor data), load them into a freshly built network of the same width,
# and move that network onto the selected device.
saved_state = torch.load(filename, weights_only=True)
net = ShallowNet(n_nodes=n_nodes)
net.load_state_dict(saved_state)
net.to(device)

ShallowNet(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (fc1): Linear(in_features=784, out_features=256, bias=True)
  (bn1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU()
  (fc2): Linear(in_features=256, out_features=10, bias=True)
)

In [13]:
# Score the reloaded model with the shared evaluation helper instead of a
# copy of its loop. test_loop() puts `net` in eval mode, disables gradient
# tracking and returns (loss, accuracy); only the accuracy is reported here.
accuracy = test_loop(test_loader=test_loader)[1]

print(f'Accuracy: {accuracy}%')

Accuracy: 89.75%
