In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import time
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.init as init

In [2]:
print("CUDA Device Index:", torch.cuda.current_device())
print(torch.cuda.get_device_properties(0)) 

CUDA Device Index: 0
_CudaDeviceProperties(name='NVIDIA GeForce RTX 4060 Laptop GPU', major=8, minor=9, total_memory=8187MB, multi_processor_count=24)


In [3]:
import os

if not os.path.exists("../data/mnist_train.csv"):
    !curl -O https://pjreddie.com/media/files/mnist_train.csv

if not os.path.exists("../data/mnist_test.csv"):
    !curl -O https://pjreddie.com/media/files/mnist_test.csv

In [4]:
torch.manual_seed(250228)
torch.cuda.manual_seed_all(250228)

def read_dataset(csv_path):
    df = pd.read_csv(csv_path, header=None).values
    x = torch.tensor(df[:, 1:] / 255.0, dtype=torch.float32)
    y = torch.tensor(df[:, 0], dtype=torch.long)
    return x, y

x_train, y_train = read_dataset("../data/mnist_train.csv")
x_test, y_test = read_dataset("../data/mnist_test.csv")

device = torch.device("cuda")

batch_size = 64
epochs = 30
lr = 0.03

train_loader = DataLoader(TensorDataset(x_train, y_train), batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True, drop_last=True)
test_loader = DataLoader(TensorDataset(x_test, y_test), batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True, drop_last=True)

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(784, 320)
        self.fc2 = nn.Linear(320, 160)
        self.fc3 = nn.Linear(160, 10)
        self.init_params()

    def init_params(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                # He Normalization: https://paperswithcode.com/method/he-initialization
                init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
                init.zeros_(m.bias)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = Net().to(device)
# torch.set_float32_matmul_precision('high')
# model = torch.compile(model) # add.

In [5]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=lr)

training_time = 0.0

for epoch in range(1, epochs + 1):
    model.train()
    epoch_start = time.time()
    train_loss = 0.0
    train_correct = 0
    train_cnt = 0
    
    for X, y in train_loader:
        X, y = X.to(device), y.to(device)
        optimizer.zero_grad()
        y_hat = model(X)
        loss = criterion(y_hat, y)
        loss.backward()
        optimizer.step()

        train_loss += loss.item() * X.size(0)
        _, pred = torch.max(y_hat, 1)
        train_cnt += y.size(0)
        train_correct += (pred == y).sum().item()

    epoch_time = (time.time() - epoch_start) * 1000  # ms
    training_time += epoch_time

    model.eval()
    val_loss = 0.0
    val_correct = 0
    val_cnt = 0
    with torch.no_grad():
        for X, y in test_loader:
            X, y = X.to(device), y.to(device)
            y_hat = model(X)
            loss = criterion(y_hat, y)
            val_loss += loss.item() * X.size(0)
            _, pred = torch.max(y_hat, 1)
            val_cnt += y.size(0)
            val_correct += (pred == y).sum().item()

    train_acc   = train_correct / train_cnt
    val_acc     = val_correct / val_cnt
    train_loss  /= train_cnt
    val_loss    /= val_cnt
    
    print(f"epoch {epoch}: {epoch_time:.0f}ms "
          f"train loss: {train_loss:.6f} "
          f"train accuracy: {train_acc:.6f} "
          f"val loss: {val_loss:.6f} "
          f"val accuracy: {val_acc:.6f}"
    )

print(f"Total Training Time: {training_time:.0f}ms")
print(f"{training_time / 30:.0f}ms per epoch")

epoch 1: 2002ms train loss: 0.444774 train accuracy: 0.878869 val loss: 0.268687 val accuracy: 0.921474
epoch 2: 1656ms train loss: 0.225104 train accuracy: 0.935832 val loss: 0.203643 val accuracy: 0.939103
epoch 3: 1854ms train loss: 0.175538 train accuracy: 0.949606 val loss: 0.166154 val accuracy: 0.949720
epoch 4: 2040ms train loss: 0.144439 train accuracy: 0.958261 val loss: 0.141514 val accuracy: 0.957632
epoch 5: 1712ms train loss: 0.122692 train accuracy: 0.964631 val loss: 0.125060 val accuracy: 0.963341
epoch 6: 1931ms train loss: 0.106329 train accuracy: 0.969400 val loss: 0.113687 val accuracy: 0.966346
epoch 7: 2121ms train loss: 0.093335 train accuracy: 0.973152 val loss: 0.105405 val accuracy: 0.968450
epoch 8: 1898ms train loss: 0.082669 train accuracy: 0.976638 val loss: 0.098768 val accuracy: 0.971154
epoch 9: 1771ms train loss: 0.073775 train accuracy: 0.979439 val loss: 0.093778 val accuracy: 0.972756
epoch 10: 1912ms train loss: 0.066209 train accuracy: 0.981590 v