In [1]:
import torch
from torchvision import datasets, transforms
from torcheval.metrics import MulticlassAccuracy
import numpy as np

In [2]:
if torch.cuda.is_available():
    dev = "cuda:0"
else:
    dev = "cpu"
device = torch.device(dev)
device

device(type='cuda', index=0)

In [3]:
transform = transforms.Compose([transforms.ToTensor()])
mnist_pytorch = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

In [4]:
target_tensor = torch.nn.functional.one_hot(mnist_pytorch.targets).to(dtype=torch.float32).to(device)

need to add channel dimension

In [5]:
input_tensor = mnist_pytorch.data.to(dtype=torch.float32).unsqueeze(1).to(device)

In [6]:
input_ds = torch.utils.data.TensorDataset(input_tensor, target_tensor)

In [7]:
train_ds, test_ds = torch.utils.data.random_split(input_ds, [0.8, 0.2])
train_ds, val_ds = torch.utils.data.random_split(train_ds, [0.9, 0.1])

In [8]:
mini_batch_size = 512

train_dl = torch.utils.data.DataLoader(train_ds, batch_size=mini_batch_size, shuffle=True, drop_last=False)
valid_dl = torch.utils.data.DataLoader(val_ds, batch_size=mini_batch_size * 2)

In [9]:
class CNN(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.cnn = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5, padding=2),
            torch.nn.BatchNorm2d(10),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2, stride=2),
            torch.nn.Conv2d(in_channels=10, out_channels=25, kernel_size=5),
            torch.nn.BatchNorm2d(25),
            torch.nn.ReLU(),
            torch.nn.Conv2d(in_channels=25, out_channels=40, kernel_size=3, padding=1),
            torch.nn.BatchNorm2d(40),
            torch.nn.ReLU(),
            torch.nn.Conv2d(in_channels=40, out_channels=30, kernel_size=3, padding=1),
            torch.nn.BatchNorm2d(30),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2, stride=2),
            torch.nn.Flatten(),
            torch.nn.Linear(750, 300),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.4),
            torch.nn.Linear(300, 100),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.4),
            torch.nn.Linear(100, 10)
        )

        # historical architecture
        self.LeNet = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, padding=2),
            torch.nn.Sigmoid(),
            torch.nn.AvgPool2d(2, stride=2),
            torch.nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),
            torch.nn.Sigmoid(),
            torch.nn.AvgPool2d(2, stride=2),
            torch.nn.Flatten(),
            torch.nn.Linear(400, 120),
            torch.nn.Sigmoid(),
            torch.nn.Linear(120, 84),
            torch.nn.Sigmoid(),
            torch.nn.Linear(84, 10)
        )

    def forward(self, X):
        X = self.cnn(X)
        # X = self.LeNet(X)
        return X

In [10]:
model = CNN().to(device)
optimizer = torch.optim.Adam(model.parameters())

In [11]:
def fit(epochs, model, optimizer, train_dl, valid_dl=None):
    loss_func = torch.nn.CrossEntropyLoss()

    # loop over epochs
    for epoch in range(epochs):
        model.train()

        # loop over mini-batches
        for X_mb, y_mb in train_dl:
            y_hat = model(X_mb)

            loss = loss_func(y_hat, y_mb)
            loss.backward()

            optimizer.step()
            optimizer.zero_grad()

        model.eval()
        with torch.no_grad():
            train_loss = sum(loss_func(model(X_mb), y_mb) for X_mb, y_mb in train_dl)
            valid_loss = sum(loss_func(model(X_mb), y_mb) for X_mb, y_mb in valid_dl)
        print('epoch {}, training loss {}'.format(epoch + 1, train_loss / len(train_dl)))
        print('epoch {}, validation loss {}'.format(epoch + 1, valid_loss / len(valid_dl)))

    print('Finished training')

    return model

In [12]:
epochs = 20

trained_model = fit(epochs, model, optimizer, train_dl, valid_dl)

epoch 1, training loss 0.06457460671663284
epoch 1, validation loss 0.06789901107549667
epoch 2, training loss 0.04137010872364044
epoch 2, validation loss 0.0526127815246582
epoch 3, training loss 0.03103666938841343
epoch 3, validation loss 0.04165218397974968
epoch 4, training loss 0.030297299847006798
epoch 4, validation loss 0.044845182448625565
epoch 5, training loss 0.028076807036995888
epoch 5, validation loss 0.043241746723651886
epoch 6, training loss 0.01603178307414055
epoch 6, validation loss 0.041545260697603226
epoch 7, training loss 0.016189923509955406
epoch 7, validation loss 0.0338912159204483
epoch 8, training loss 0.014538048766553402
epoch 8, validation loss 0.03178149089217186
epoch 9, training loss 0.009976143017411232
epoch 9, validation loss 0.033225882798433304
epoch 10, training loss 0.011175926774740219
epoch 10, validation loss 0.03820788860321045
epoch 11, training loss 0.01431985106319189
epoch 11, validation loss 0.048481084406375885
epoch 12, training 

In [13]:
def evaluation(ds, model):
    with torch.no_grad():
        preds = model(ds[:][0]).cpu()

    # take output node with highest probability
    yhat = np.argmax(preds, axis=1)
    y = ds[:][1].cpu()

    # from one-hot back to labels
    y = torch.argmax(y, dim=1)

    metric = MulticlassAccuracy()

    metric.update(yhat, y)
    return metric.compute()

In [14]:
evaluation(train_ds, trained_model)

tensor(0.9991)

In [15]:
evaluation(test_ds, trained_model)

tensor(0.9911)