In [1]:
import torch
from torchvision import datasets, transforms
from torcheval.metrics import MulticlassAccuracy
import numpy as np

In [2]:
if torch.cuda.is_available():
    dev = "cuda:0"
else:
    dev = "cpu"
device = torch.device(dev)
device

device(type='cuda', index=0)

In [3]:
transform = transforms.Compose([transforms.ToTensor()])
mnist_pytorch = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

In [4]:
target_tensor = torch.nn.functional.one_hot(mnist_pytorch.targets).to(dtype=torch.float32).to(device)

need to add channel dimension

In [5]:
input_tensor = mnist_pytorch.data.to(dtype=torch.float32).unsqueeze(1).to(device)

In [6]:
input_ds = torch.utils.data.TensorDataset(input_tensor, target_tensor)

In [7]:
train_ds, test_ds = torch.utils.data.random_split(input_ds, [0.8, 0.2])
train_ds, val_ds = torch.utils.data.random_split(train_ds, [0.9, 0.1])

In [8]:
mini_batch_size = 512

train_dl = torch.utils.data.DataLoader(train_ds, batch_size=mini_batch_size, shuffle=True, drop_last=False)
valid_dl = torch.utils.data.DataLoader(val_ds, batch_size=mini_batch_size * 2)

In [9]:
class CNN(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.cnn = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5, padding=2),
            torch.nn.BatchNorm2d(10),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2, stride=2),
            torch.nn.Conv2d(in_channels=10, out_channels=25, kernel_size=5),
            torch.nn.BatchNorm2d(25),
            torch.nn.ReLU(),
            torch.nn.Conv2d(in_channels=25, out_channels=40, kernel_size=3, padding=1),
            torch.nn.BatchNorm2d(40),
            torch.nn.ReLU(),
            torch.nn.Conv2d(in_channels=40, out_channels=30, kernel_size=3, padding=1),
            torch.nn.BatchNorm2d(30),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2, stride=2),
            torch.nn.Flatten(),
            torch.nn.Linear(750, 300),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.4),
            torch.nn.Linear(300, 100),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.4),
            torch.nn.Linear(100, 10)
        )

        # historical architecture
        self.LeNet = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, padding=2),
            torch.nn.Sigmoid(),
            torch.nn.AvgPool2d(2, stride=2),
            torch.nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),
            torch.nn.Sigmoid(),
            torch.nn.AvgPool2d(2, stride=2),
            torch.nn.Flatten(),
            torch.nn.Linear(400, 120),
            torch.nn.Sigmoid(),
            torch.nn.Linear(120, 84),
            torch.nn.Sigmoid(),
            torch.nn.Linear(84, 10)
        )

    def forward(self, X):
        X = self.cnn(X)
        # X = self.LeNet(X)
        return X

In [10]:
model = CNN().to(device)
optimizer = torch.optim.Adam(model.parameters())

number of trainable model parameters

In [11]:
sum(p.numel() for p in model.parameters() if p.requires_grad)

344731

In [12]:
class EarlyStopping:
    def __init__(self, patience=5, delta=0):
        self.patience = patience
        self.delta = delta
        self.best_score = None
        self.early_stop = False
        self.counter = 0
        self.best_model_state = None

    def __call__(self, val_loss, model):
        score = -val_loss
        if self.best_score is None:
            self.best_score = score
            self.best_model_state = model.state_dict()
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.best_model_state = model.state_dict()
            self.counter = 0

    def load_best_model(self, model):
        model.load_state_dict(self.best_model_state)

In [13]:
early_stopping = EarlyStopping(patience=5, delta=0.01)

In [14]:
def fit(epochs, model, optimizer, train_dl, valid_dl=None):
    loss_func = torch.nn.CrossEntropyLoss()

    # loop over epochs
    for epoch in range(epochs):
        model.train()

        # loop over mini-batches
        for X_mb, y_mb in train_dl:
            y_hat = model(X_mb)

            loss = loss_func(y_hat, y_mb)
            loss.backward()

            optimizer.step()
            optimizer.zero_grad()

        model.eval()
        with torch.no_grad():
            train_loss = sum(loss_func(model(X_mb), y_mb) for X_mb, y_mb in train_dl)
            valid_loss = sum(loss_func(model(X_mb), y_mb) for X_mb, y_mb in valid_dl)
        print('epoch {}, training loss {}'.format(epoch + 1, train_loss / len(train_dl)))
        print('epoch {}, validation loss {}'.format(epoch + 1, valid_loss / len(valid_dl)))

        early_stopping(valid_loss, model)
        if early_stopping.early_stop:
            print("Early stopping")
            break

    print('Finished training')

    return model

In [15]:
epochs = 50

fit(epochs, model, optimizer, train_dl, valid_dl)

epoch 1, training loss 0.06688742339611053
epoch 1, validation loss 0.07452839612960815
epoch 2, training loss 0.039183855056762695
epoch 2, validation loss 0.05412019044160843
epoch 3, training loss 0.04167329519987106
epoch 3, validation loss 0.0640040710568428
epoch 4, training loss 0.02674599550664425
epoch 4, validation loss 0.04119248315691948
epoch 5, training loss 0.021828999742865562
epoch 5, validation loss 0.043792903423309326
epoch 6, training loss 0.02477993443608284
epoch 6, validation loss 0.047204531729221344
epoch 7, training loss 0.014193478040397167
epoch 7, validation loss 0.04048856720328331
epoch 8, training loss 0.010915343649685383
epoch 8, validation loss 0.03449602052569389
epoch 9, training loss 0.009004905819892883
epoch 9, validation loss 0.03168097883462906
epoch 10, training loss 0.011030924506485462
epoch 10, validation loss 0.033881377428770065
epoch 11, training loss 0.012234250083565712
epoch 11, validation loss 0.04153953120112419
epoch 12, training 

CNN(
  (cnn): Sequential(
    (0): Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(10, 25, kernel_size=(5, 5), stride=(1, 1))
    (5): BatchNorm2d(25, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU()
    (7): Conv2d(25, 40, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ReLU()
    (10): Conv2d(40, 30, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): BatchNorm2d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): ReLU()
    (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (14): Flatten(start_dim=1, end_dim=-1)
    (15): Linear(in_features=750, out_features=300, bias

In [16]:
early_stopping.load_best_model(model)

In [17]:
def evaluation(ds, model):
    with torch.no_grad():
        preds = model(ds[:][0]).cpu()

    # take output node with highest probability
    yhat = np.argmax(preds, axis=1)
    y = ds[:][1].cpu()

    # from one-hot back to labels
    y = torch.argmax(y, dim=1)

    metric = MulticlassAccuracy()

    metric.update(yhat, y)
    return metric.compute()

In [18]:
evaluation(train_ds, model)

tensor(0.9982)

In [19]:
evaluation(test_ds, model)

tensor(0.9925)