data taken from https://www.cs.toronto.edu/~kriz/cifar.html

In [1]:
import torch
from torch import nn
from torcheval.metrics import MulticlassAccuracy
import numpy as np

In [2]:
if torch.cuda.is_available():
    dev = "cuda:0"
elif torch.backends.mps.is_available():
    dev = "mps"
else:
    dev = "cpu"
device = torch.device(dev)
device

device(type='mps')

In [3]:
def unpickle(file):
    import pickle
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding='bytes')
    return dict

In [4]:
cifar_data1 = unpickle("cifar-10-batches-py/data_batch_1")
cifar_data1.keys()

dict_keys([b'batch_label', b'labels', b'data', b'filenames'])

In [5]:
images = cifar_data1[b'data']
labels = cifar_data1[b'labels']

In [6]:
for data_batch in ["cifar-10-batches-py/data_batch_2",
                   "cifar-10-batches-py/data_batch_3",
                   "cifar-10-batches-py/data_batch_4",
                   "cifar-10-batches-py/data_batch_5"]:
    data_dict = unpickle(data_batch)

    images = np.append(images, data_dict[b'data'], axis=0)
    labels = labels + data_dict[b'labels']

In [7]:
len(labels)

50000

In [8]:
images.shape

(50000, 3072)

In [9]:
test_data = unpickle("cifar-10-batches-py/test_batch")
test_data.keys()

dict_keys([b'batch_label', b'labels', b'data', b'filenames'])

In [10]:
test_images = test_data[b'data']
test_labels = test_data[b'labels']

In [11]:
len(test_labels)

10000

In [12]:
test_images.shape

(10000, 3072)

convert to one-hot targets for multi-classification

In [13]:
target_tensor = torch.nn.functional.one_hot(torch.tensor(labels)).to(dtype=torch.float32).to(device)
test_target_tensor = torch.tensor(test_labels)
target_tensor

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 1.],
        [0., 0., 0.,  ..., 0., 0., 1.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 1.],
        [0., 1., 0.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.]], device='mps:0')

In [14]:
input_tensor = torch.tensor(images).to(dtype=torch.float32).to(device)
test_tensor = torch.tensor(test_images).to(dtype=torch.float32)

In [15]:
input_ds = torch.utils.data.TensorDataset(input_tensor, target_tensor)
test_ds = torch.utils.data.TensorDataset(test_tensor, test_target_tensor)

In [16]:
train_ds, val_ds = torch.utils.data.random_split(input_ds, [0.9, 0.1])

In [17]:
mini_batch_size = 64

train_dl = torch.utils.data.DataLoader(train_ds, batch_size=mini_batch_size, shuffle=True, drop_last=False)
valid_dl = torch.utils.data.DataLoader(val_ds, batch_size=mini_batch_size * 2)

In [18]:
class FCN(nn.Module):
    def __init__(self):
        super().__init__()

        self.fcn = nn.Sequential(
            nn.Linear(3072, 3000),
            nn.ReLU(),
            nn.Linear(3000, 2000),
            nn.ReLU(),
            nn.Linear(2000, 1000),
            nn.ReLU(),
            nn.Linear(1000, 500),
            nn.ReLU(),
            nn.Linear(500, 100),
            nn.ReLU(),
            nn.Linear(100, 10)
        )

    def forward(self, X):
        X = self.fcn(X)
        return X

In [19]:
model = FCN().to(device)
optimizer = torch.optim.Adam(model.parameters())

number of trainable model parameters

In [20]:
sum(p.numel() for p in model.parameters() if p.requires_grad)

17773610

In [21]:
class EarlyStopping:
    def __init__(self, patience=5, delta=0):
        self.patience = patience
        self.delta = delta
        self.best_score = None
        self.early_stop = False
        self.counter = 0
        self.best_model_state = None

    def __call__(self, val_loss, model):
        score = -val_loss
        if self.best_score is None:
            self.best_score = score
            self.best_model_state = model.state_dict()
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.best_model_state = model.state_dict()
            self.counter = 0

    def load_best_model(self, model):
        model.load_state_dict(self.best_model_state)

In [22]:
early_stopping = EarlyStopping(patience=5, delta=0.01)

In [23]:
def fit(epochs, model, optimizer, train_dl, valid_dl=None):
    loss_func = torch.nn.CrossEntropyLoss()

    # loop over epochs
    for epoch in range(epochs):
        model.train()

        # loop over mini-batches
        for X_mb, y_mb in train_dl:
            y_hat = model(X_mb)

            loss = loss_func(y_hat, y_mb)
            loss.backward()

            optimizer.step()
            optimizer.zero_grad()

        model.eval()
        with torch.no_grad():
            train_loss = sum(loss_func(model(X_mb), y_mb) for X_mb, y_mb in train_dl)
            valid_loss = sum(loss_func(model(X_mb), y_mb) for X_mb, y_mb in valid_dl)
        print('epoch {}, training loss {}'.format(epoch + 1, train_loss / len(train_dl)))
        print('epoch {}, validation loss {}'.format(epoch + 1, valid_loss / len(valid_dl)))

        early_stopping(valid_loss, model)
        if early_stopping.early_stop:
            print("Early stopping")
            break

    print('Finished training')

    return model

In [24]:
epochs = 50

trained_model = fit(epochs, model, optimizer, train_dl, valid_dl)

epoch 1, training loss 1.8320415019989014
epoch 1, validation loss 1.8176853656768799
epoch 2, training loss 1.735809326171875
epoch 2, validation loss 1.722028374671936
epoch 3, training loss 1.7056403160095215
epoch 3, validation loss 1.7233415842056274
epoch 4, training loss 1.6360268592834473
epoch 4, validation loss 1.6406323909759521
epoch 5, training loss 1.6145004034042358
epoch 5, validation loss 1.6317249536514282
epoch 6, training loss 1.58390474319458
epoch 6, validation loss 1.6007716655731201
epoch 7, training loss 1.5879112482070923
epoch 7, validation loss 1.6154510974884033
epoch 8, training loss 1.6189764738082886
epoch 8, validation loss 1.6524364948272705
epoch 9, training loss 1.569131851196289
epoch 9, validation loss 1.611169457435608
epoch 10, training loss 1.4945003986358643
epoch 10, validation loss 1.5475890636444092
epoch 11, training loss 1.5218746662139893
epoch 11, validation loss 1.5678530931472778
epoch 12, training loss 1.5570906400680542
epoch 12, val

In [25]:
early_stopping.load_best_model(model)
model = model.cpu()

In [26]:
def evaluation(ds, model, train=False):
    with torch.no_grad():
        preds = model(ds[:][0].cpu())

    # take output node with highest probability
    yhat = np.argmax(preds, axis=1)
    y = ds[:][1].cpu()

    # from one-hot back to labels
    if train:
        y = torch.argmax(y, dim=1)

    metric = MulticlassAccuracy()

    metric.update(yhat, y)
    return metric.compute()

In [27]:
evaluation(train_ds, trained_model, train=True)

tensor(0.4956)

In [28]:
evaluation(val_ds, model, train=True)

tensor(0.4464)

In [29]:
evaluation(test_ds, trained_model)

tensor(0.4420)

accuracy with batch1 training data only: 40%