In [1]:
%pip install torch torchvision pandas matplotlib numpy

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import torchvision.transforms as transforms

import shared

Note: you may need to restart the kernel to use updated packages.


In [2]:
dl = shared.load_data("dataset/train.csv", 256, shuffle=True, train=True)

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.mps.is_available() else "cpu")
# device = torch.device("cpu")

print(f"Using device: {device}")

model = shared.SimpleCNN()
loss = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-2)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.1)
model.to(device)

def train(model, dl, loss, optimizer, device):
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    for i, (data, target) in enumerate(dl):
        data, target = data.to(torch.float32).to(device), target.to(torch.float32).to(device)

        optimizer.zero_grad()
        output = model(data)
        loss_value = loss(output, target)
        loss_value.backward()
        optimizer.step()

        total_loss += loss_value.item()
        _, predicted = torch.max(output.data, 1)
        total += target.size(0)
        correct += (predicted == target).sum().item()
        print(f'Batch Progress: {(i + 1) / len(dl):.5f}, Batch Loss: {total_loss / total:.4f}, Accuracy: {correct / total:.5f}', end='\r', flush=True)
    print()

    accuracy = correct / total
    return total_loss / len(dl), accuracy

for epoch in range(1, 11):
    train_loss, train_accuracy = train(model, dl, loss, optimizer, device)
    print(f'Epoch {epoch}: Loss: {train_loss:.4f}, Accuracy: {train_accuracy:.4f}, LR: {np.log10(scheduler.get_last_lr()[0]):.6f}')
    scheduler.step()

# Save the model
torch.save(model.state_dict(), 'model.pth')

Using device: mps
Batch Progress: 1.00000, Batch Loss: 0.0010, Accuracy: 0.93378
Epoch 1: Loss: 0.2488, Accuracy: 0.9338, LR: -2.000000
Batch Progress: 1.00000, Batch Loss: 0.0001, Accuracy: 0.98896
Epoch 2: Loss: 0.0349, Accuracy: 0.9890, LR: -3.000000
Batch Progress: 1.00000, Batch Loss: 0.0001, Accuracy: 0.99163
Epoch 3: Loss: 0.0260, Accuracy: 0.9916, LR: -4.000000
Batch Progress: 1.00000, Batch Loss: 0.0001, Accuracy: 0.99237
Epoch 4: Loss: 0.0244, Accuracy: 0.9924, LR: -5.000000
Batch Progress: 1.00000, Batch Loss: 0.0001, Accuracy: 0.99243
Epoch 5: Loss: 0.0239, Accuracy: 0.9924, LR: -6.000000
Batch Progress: 1.00000, Batch Loss: 0.0001, Accuracy: 0.99249
Epoch 6: Loss: 0.0242, Accuracy: 0.9925, LR: -7.000000
Batch Progress: 1.00000, Batch Loss: 0.0001, Accuracy: 0.99241
Epoch 7: Loss: 0.0241, Accuracy: 0.9924, LR: -8.000000
Batch Progress: 1.00000, Batch Loss: 0.0001, Accuracy: 0.99241
Epoch 8: Loss: 0.0244, Accuracy: 0.9924, LR: -9.000000
Batch Progress: 1.00000, Batch Loss: 0