In [1]:
import torch
import torch.nn as nn
import matplotlib
from torch.utils.data import DataLoader
import dataloader as dl
import network

In [2]:
# Seed torch's RNGs (torch.manual_seed seeds every device) so that weight
# initialisation and DataLoader shuffling are reproducible under
# Restart Kernel -> Run All.
SEED = 0
torch.manual_seed(SEED)

# Pick the accelerator to use: CUDA first, then Apple MPS, falling back to CPU.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

Using cpu device


In [3]:
# Mini-batch size shared by the training and evaluation loaders.
BATCH_SIZE = 64

# KMNIST splits from the local `dataloader` module; the boolean argument
# presumably selects the training split — TODO confirm against dl.KMNIST.
train_dataset = dl.KMNIST('./data/', True)
test_dataset = dl.KMNIST('./data/', False)

# Only the training batches are reshuffled each epoch; test batches keep
# dataset order.
train_data = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_data = DataLoader(test_dataset, batch_size=BATCH_SIZE)

In [4]:
# LeNet from the local `network` module; the arguments are presumably
# (input channels, number of classes) — TODO confirm against network.LeNet.
model = network.LeNet(1, 10)
model = model.to(device)  # nn.Module.to returns the (moved) module itself

In [5]:
# NLLLoss consumes log-probabilities, so this assumes network.LeNet ends in a
# log_softmax layer — TODO confirm.
loss_fn = nn.NLLLoss()

LEARNING_RATE = 1e-2
optimizer = torch.optim.Adagrad(params=model.parameters(), lr=LEARNING_RATE)

In [6]:
def train(data, model, loss_fn, optimizer, device=None):
    """Run one training epoch over ``data`` and report the mean batch loss.

    Args:
        data: Iterable of ``(X, y)`` batches (e.g. a ``DataLoader``).
        model: The ``nn.Module`` to train; it is switched to training mode.
        loss_fn: Callable ``loss_fn(pred, y)`` returning a scalar loss tensor.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device each batch is moved to. Defaults to the device of the
            model's first parameter, or CPU if the model has no parameters.

    Returns:
        float: Average per-batch loss for the epoch (0.0 if ``data`` is empty).
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    total_loss = 0.0
    num_batches = 0
    for X, y in data:
        X, y = X.to(device), y.to(device)
        pred = model(X)
        loss = loss_fn(pred, y)
        # Accumulate a Python float: summing the loss tensor itself would keep
        # every batch's autograd graph alive until the end of the epoch.
        total_loss += loss.item()
        # Clear gradients before backward so stale grads never reach a step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        num_batches += 1

    avg_loss = total_loss / num_batches if num_batches else 0.0
    print(f"Epoch was completed! Avg loss: {avg_loss:>8f}")
    return avg_loss

In [7]:
def test(data, model, loss_fn):
    model.eval()
    correct = 0
    test_loss = 0
    with torch.no_grad():
        for (X, y) in data:
            pred = model(X)
            test_loss += loss_fn(pred, y)
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= len(data)
    correct /= len(data.dataset)
    print(f"Test Error:\n Accuracy: {(100*correct):>4f}%. Avg loss: {test_loss:>8f}")

In [8]:
# Each epoch: one full training pass, then one evaluation pass on the test set.
epochs = 10
for epoch in range(epochs):
    print(f"Epoch {epoch+1} is starting!")
    train(data=train_data, model=model, loss_fn=loss_fn, optimizer=optimizer)
    test(data=test_data, model=model, loss_fn=loss_fn)

Epoch 1 is starting!
Epoch was completed! Avg loss: 1.322644
Test Error:
 Accuracy: 85.330000%. Avg loss: 0.523346
Epoch 2 is starting!
Epoch was completed! Avg loss: 0.144912
Test Error:
 Accuracy: 89.380000%. Avg loss: 0.395908
Epoch 3 is starting!
Epoch was completed! Avg loss: 0.087848
Test Error:
 Accuracy: 90.840000%. Avg loss: 0.345439
Epoch 4 is starting!
Epoch was completed! Avg loss: 0.059516
Test Error:
 Accuracy: 91.100000%. Avg loss: 0.343190
Epoch 5 is starting!
Epoch was completed! Avg loss: 0.040660
Test Error:
 Accuracy: 91.930000%. Avg loss: 0.319676
Epoch 6 is starting!
Epoch was completed! Avg loss: 0.028552
Test Error:
 Accuracy: 91.820000%. Avg loss: 0.336711
Epoch 7 is starting!
Epoch was completed! Avg loss: 0.019578
Test Error:
 Accuracy: 92.350000%. Avg loss: 0.332820
Epoch 8 is starting!
Epoch was completed! Avg loss: 0.013495
Test Error:
 Accuracy: 92.470000%. Avg loss: 0.333104
Epoch 9 is starting!
Epoch was completed! Avg loss: 0.009260
Test Error:
 Accura