In [50]:
import numpy as np
from tqdm import tqdm, trange

import torch
import torch.nn as nn
from torch import optim
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader, Subset, default_collate
from torchvision import transforms
from torchvision.transforms import ToTensor, v2
from torchvision import datasets

np.random.seed(0)
torch.manual_seed(0)

<torch._C.Generator at 0x208b8b5c5f0>

In [2]:
# load fashion MNIST data and transform images
transform = transforms.Compose([transforms.Resize((64, 64)), transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=transform
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=transform
)

In [3]:
labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}

In [4]:
def patchify(images, n_patches):
    n, c, h, w = images.shape # n=num_images, c=image dimension

    # ensure input images are squares
    assert h==w, "image must be square"

    # instantiate patches as a 3D zero tensor (** means power of)
    patches = torch.zeros(n, n_patches**2, h*w*c // n_patches**2) # (num_images, num_patches, patch dimension)
    patch_size = h //n_patches

    for index, image in enumerate(images):
        for i in range(n_patches):
            for j in range(n_patches):
                patch = image[:, i*patch_size: (i+1)*patch_size, j*patch_size: (j+1)*patch_size]
                patches[index, i*n_patches + j] = patch.flatten()
    return patches


In [5]:
def get_positional_embedding(seq_length, dim):
    pe = torch.ones(seq_length, dim)
    for i in range(seq_length):
        for j in range(dim):
            if j % 2 == 0: # even
                pe[i][j] = np.sin(i/(10000 ** (j/dim)))
            elif j % 2 == 1: # odd
                pe[i][j] = np.cos(i / (10000 ** ((j-1)/dim)))
    return pe

In [6]:
class MultiHeadSA(nn.Module):
    def __init__(self, dim, n_heads=2):
        super(MultiHeadSA, self).__init__()
        self.dim = dim
        self.n_heads = n_heads

        assert dim % n_heads == 0, f"Can't divide dimension {dim} into {n_heads} heads"

        # creating weight matrix of q, k and v
        d_head = int(dim / n_heads)
        self.q_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for x in range(self.n_heads)])
        self.k_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for x in range(self.n_heads)])
        self.v_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for x in range(self.n_heads)])
        self.d_head = d_head
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, sequences):
        # Sequences has shape (N, seq_length, token_dim)
        # We go into shape    (N, seq_length, n_heads, token_dim / n_heads)
        # And come back to    (N, seq_length, item_dim)  (through concatenation)
        result = []
        for sequence in sequences:
            seq_result = []
            for head in range(self.n_heads):
                q_mapping = self.q_mappings[head]
                k_mapping = self.k_mappings[head]
                v_mapping = self.v_mappings[head]

                seq = sequence[:, head * self.d_head: (head + 1) * self.d_head]

                q, k, v = q_mapping(seq), k_mapping(seq), v_mapping(seq)
                
                # calculate attention score
                attention = self.softmax(q @ k.T / (self.d_head ** 0.5))
                seq_result.append(attention @ v)
            result.append(torch.hstack(seq_result))

        # concat attention from each head together
        return torch.cat([torch.unsqueeze(r, dim=0) for r in result])

In [7]:
class ResidualConnection(nn.Module):
    def __init__(self, hidden_dim, n_heads, mlp_ratio=4):
        super(ResidualConnection, self).__init__()
        self.hidden_dim = hidden_dim
        self.n_heads = n_heads

        self.norm1 = nn.LayerNorm(hidden_dim)

        # self-attention layer
        self.mhsa = MultiHeadSA(hidden_dim, n_heads)

        self.norm2 = nn.LayerNorm(hidden_dim)

        # feed forward layer
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, mlp_ratio * hidden_dim),
            nn.GELU(),
            nn.Linear(mlp_ratio * hidden_dim, hidden_dim)
        )

    def forward(self, x):
        mhsa_out = x + self.mhsa(self.norm1(x))
        ff_out = mhsa_out + self.mlp(self.norm2(mhsa_out))
        return ff_out

In [8]:
# instantiate model
class VisionTransformer(nn.Module):
    def __init__(self, chw, n_patches, hidden_dim, n_encodelayers, n_heads, output_dim):
    super(VisionTransformer, self).__init__()

    self.chw = chw
    self.n_patches = n_patches
    self.hidden_dim = hidden_dim
    self.n_encodelayers = n_encodelayers
    self.n_heads = n_heads
    self.output_dim = output_dim

    # ensure that height and width are divisble by num of patches
    assert chw[1] % n_patches == 0, "Input shape is not divisible by number of patches"
    assert chw[2] % n_patches == 0, "Input shape is not divisible by number of patches"
    self.patch_size = (chw[1]/n_patches, chw[2]/n_patches)

    # linear mapping
    self.input_dim = int(chw[0] * self.patch_size[0] * self.patch_size[1])
    self.linear_map = nn.Linear(self.input_dim, self.hidden_dim)

    # classification head
    self.class_token = nn.Parameter(torch.rand(1, self.hidden_dim))

    # positional encoding
    self.pos_embed = nn.Parameter(torch.tensor(get_positional_embedding(self.n_patches ** 2 + 1, self.hidden_dim)))
    # self.pos_embed.requires_grad = False

    # transformer encoder
    self.encoder_layers = nn.ModuleList([ResidualConnection(hidden_dim, n_heads) for x in range(n_encodelayers)])

    # extract classification token
    self.mlp = nn.Sequential(
        nn.Linear(self.hidden_dim, output_dim),
        nn.Softmax(dim=-1)
    )

    def forward(self, input_image):
    patches = patchify(input_image, self.n_patches)
    tokens = self.linear_map(patches)

    # add classification head to each token (learnable embedding)
    tokens_with_class = torch.stack([torch.vstack((self.class_token, tokens[i])) for i in range(len(tokens))])

    # add positional embedding
    pos_embed = self.pos_embed.repeat(tokens_with_class.shape[0], 1, 1)
    pos_embed_token = tokens_with_class + pos_embed

    # transformer encoding layers
    for layer in self.encoder_layers:
        output = layer(pos_embed_token)

    # extract classification token
    output = output[:, 0]
    pred = self.mlp(output)

    return pred # Map to output dimension, output category distribution

In [9]:
# train and test loop
def train_loop(train_loader, model, loss_fn, optimizer, epoch):
    train_loss = 0.0
    correct, total = 0, 0
    for batch in tqdm(train_loader, desc=f"Epoch {epoch + 1} in training", leave=False):
        x, y = batch
        output = model(x)
        loss = loss_fn(output, y)

        train_loss += (loss.item() / len(train_loader))
        correct += torch.sum(torch.argmax(output, dim=1) == y).item()
        total += len(x)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch + 1} loss: {train_loss:.2f}")
    print(f"Validation accuracy: {correct / total * 100:.2f}%")

def test_loop(test_loader, model, loss_fn):
    with torch.no_grad():
        correct, total = 0, 0
        test_loss = 0.0
        for batch in tqdm(test_loader, desc="Testing"):
            x, y = batch
            output = model(x)
            loss = loss_fn(output, y)
            test_loss += (loss.item() / len(test_loader))

            correct += torch.sum(torch.argmax(output, dim=1) == y).item()
            total += len(x)
        test_acc = correct / total * 100
        print(f"Test loss: {test_loss:.2f}")
        print(f"Test accuracy: {test_acc:.2f}%")
    return test_acc, test_loss

In [10]:
batch_size = 512

comparing different hyperparameters using subset

In [11]:
# obtain subset of data
train_small = Subset(training_data, range(int(0.6*len(training_data))))
test_small = Subset(test_data, range(int(0.6*len(test_data))))

# load data
train_dataloader = DataLoader(train_small, shuffle=True, batch_size = batch_size)
test_dataloader = DataLoader(test_small, shuffle=True, batch_size = batch_size)

### Base Transformer 

In [21]:
chw = (1, 64, 64) # image dimensions
hidden_dim = 16 # number of features in each patch's representation
n_encodelayers = 2
n_heads = 2 # no of attention heads
output_dim = 10 # Fashion MNIST has 10 classes

n_patches = 16

learning_rate = 0.005
num_epochs = 20

# instantiate model
model = VisionTransformer(chw, n_patches, hidden_dim, n_encodelayers, n_heads, output_dim)

# instantiate loss function and optimizer
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# train model with training data
for epoch in range(num_epochs):
    train_loop(train_dataloader, model, loss_fn, optimizer, epoch)

  self.pos_embed = nn.Parameter(torch.tensor(get_positional_embedding(self.n_patches ** 2 + 1, self.hidden_dim)))
                                                                      

Epoch 1 loss: 2.05
Test accuracy: 41.10%


                                                                      

Epoch 2 loss: 1.83
Test accuracy: 63.82%


                                                                      

Epoch 3 loss: 1.75
Test accuracy: 71.05%


                                                                      

Epoch 4 loss: 1.72
Test accuracy: 74.04%


                                                                      

Epoch 5 loss: 1.71
Test accuracy: 75.35%


                                                                      

Epoch 6 loss: 1.70
Test accuracy: 76.09%


                                                                      

Epoch 7 loss: 1.69
Test accuracy: 77.15%


                                                                      

Epoch 8 loss: 1.68
Test accuracy: 77.61%


                                                                      

Epoch 9 loss: 1.68
Test accuracy: 78.28%


                                                                       

Epoch 10 loss: 1.67
Test accuracy: 78.59%


                                                                       

Epoch 11 loss: 1.68
Test accuracy: 77.79%


                                                                       

Epoch 12 loss: 1.67
Test accuracy: 78.91%


                                                                       

Epoch 13 loss: 1.67
Test accuracy: 79.38%


                                                                       

Epoch 14 loss: 1.67
Test accuracy: 79.18%


                                                                       

Epoch 15 loss: 1.66
Test accuracy: 79.66%


                                                                       

Epoch 16 loss: 1.66
Test accuracy: 79.65%


                                                                       

Epoch 17 loss: 1.67
Test accuracy: 79.47%


                                                                       

Epoch 18 loss: 1.66
Test accuracy: 79.90%


                                                                       

Epoch 19 loss: 1.66
Test accuracy: 80.14%


                                                                       

Epoch 20 loss: 1.66
Test accuracy: 80.03%




In [22]:
# test model to get test loss and accuracy
test_accuracy, test_loss = test_loop(test_dataloader, model, loss_fn)

Testing:   5%|▌         | 1/20 [00:02<00:55,  2.91s/it]

tensor([[2.0435e-21, 2.1317e-12, 1.4870e-31,  ..., 1.0000e+00, 1.6927e-23,
         5.1711e-10],
        [0.0000e+00, 1.0771e-10, 2.0171e-31,  ..., 2.5845e-16, 3.9119e-27,
         1.0000e+00],
        [0.0000e+00, 1.8282e-12, 4.1712e-33,  ..., 1.2513e-14, 2.0280e-23,
         1.0000e+00],
        ...,
        [5.3050e-17, 1.0000e+00, 6.6771e-23,  ..., 8.5728e-24, 0.0000e+00,
         1.4401e-28],
        [5.3267e-15, 7.7979e-16, 3.8715e-19,  ..., 6.9117e-13, 5.8055e-22,
         4.4612e-17],
        [1.0000e+00, 5.0422e-29, 4.9498e-16,  ..., 1.2395e-35, 0.0000e+00,
         0.0000e+00]])


Testing:  10%|█         | 2/20 [00:05<00:52,  2.90s/it]

tensor([[1.3707e-30, 2.3906e-39, 8.8082e-17,  ..., 3.2541e-28, 9.6677e-01,
         2.9584e-22],
        [1.4013e-44, 0.0000e+00, 1.1356e-16,  ..., 0.0000e+00, 1.4348e-13,
         5.3122e-35],
        [1.4557e-13, 2.2105e-38, 8.5774e-17,  ..., 3.9195e-26, 8.8302e-09,
         6.0864e-29],
        ...,
        [5.4569e-21, 9.9371e-09, 4.3037e-29,  ..., 1.0000e+00, 7.6246e-27,
         1.6474e-08],
        [0.0000e+00, 2.3401e-11, 1.3416e-18,  ..., 1.0270e-33, 8.5072e-24,
         1.0000e+00],
        [1.7002e-20, 8.4213e-08, 4.5935e-28,  ..., 1.0000e+00, 3.7659e-24,
         7.1156e-09]])


Testing:  15%|█▌        | 3/20 [00:08<00:49,  2.89s/it]

tensor([[2.2318e-16, 1.0000e+00, 5.8607e-23,  ..., 6.4478e-27, 0.0000e+00,
         3.0637e-33],
        [2.0929e-25, 3.3617e-21, 1.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         4.9413e-36],
        [4.5288e-33, 6.7244e-31, 4.4587e-29,  ..., 3.9833e-20, 1.0000e+00,
         4.2307e-17],
        ...,
        [2.0967e-20, 1.1355e-07, 3.4746e-27,  ..., 1.0000e+00, 3.0161e-24,
         1.6055e-06],
        [2.0934e-30, 2.9865e-33, 1.0000e+00,  ..., 0.0000e+00, 2.8212e-29,
         4.8275e-33],
        [2.8767e-06, 8.4979e-10, 2.3366e-10,  ..., 2.5332e-21, 2.2440e-23,
         7.2837e-22]])


Testing:  20%|██        | 4/20 [00:11<00:46,  2.89s/it]

tensor([[3.3552e-22, 3.0351e-27, 1.0000e+00,  ..., 3.2334e-39, 3.1877e-35,
         2.5207e-31],
        [1.0000e+00, 2.6036e-22, 5.4080e-17,  ..., 7.7256e-37, 0.0000e+00,
         0.0000e+00],
        [1.6441e-05, 1.8123e-05, 1.8448e-10,  ..., 1.7319e-21, 2.3841e-28,
         3.4951e-22],
        ...,
        [1.8798e-14, 4.2607e-17, 3.5310e-23,  ..., 2.5534e-10, 9.7767e-13,
         3.5962e-16],
        [8.5102e-14, 1.3032e-42, 1.0853e-11,  ..., 6.3394e-36, 1.0712e-25,
         6.4148e-39],
        [3.3512e-12, 2.9391e-15, 9.1829e-23,  ..., 2.0206e-07, 3.5742e-13,
         1.3351e-15]])


Testing:  25%|██▌       | 5/20 [00:14<00:43,  2.88s/it]

tensor([[0.0000e+00, 1.7217e-13, 8.3742e-30,  ..., 2.7544e-19, 2.5163e-22,
         1.0000e+00],
        [1.2039e-24, 4.3429e-11, 8.9382e-33,  ..., 1.0000e+00, 1.0706e-22,
         7.5881e-08],
        [8.7069e-27, 0.0000e+00, 1.8401e-16,  ..., 4.6172e-40, 7.5277e-07,
         9.6024e-36],
        ...,
        [5.0201e-20, 7.5231e-22, 1.0000e+00,  ..., 7.7618e-42, 0.0000e+00,
         4.0270e-37],
        [1.2415e-38, 0.0000e+00, 3.8359e-18,  ..., 0.0000e+00, 1.6401e-12,
         4.9871e-37],
        [0.0000e+00, 0.0000e+00, 6.5739e-24,  ..., 0.0000e+00, 9.9969e-01,
         6.6710e-23]])


Testing:  30%|███       | 6/20 [00:17<00:40,  2.89s/it]

tensor([[5.1099e-19, 1.3271e-31, 1.2983e-06,  ..., 1.2267e-41, 1.7455e-20,
         4.2890e-31],
        [2.3235e-10, 4.6659e-18, 2.1715e-12,  ..., 2.8297e-30, 3.0864e-27,
         1.1008e-28],
        [2.2330e-13, 1.0000e+00, 1.7741e-16,  ..., 7.3928e-20, 1.6864e-40,
         2.7479e-23],
        ...,
        [4.9045e-44, 0.0000e+00, 1.1691e-32,  ..., 1.0313e-36, 1.0000e+00,
         7.4699e-27],
        [2.2448e-15, 1.0000e+00, 6.3386e-23,  ..., 1.3071e-26, 0.0000e+00,
         4.5372e-32],
        [4.6176e-27, 3.1301e-29, 1.0000e+00,  ..., 0.0000e+00, 1.7096e-43,
         1.3320e-37]])


Testing:  35%|███▌      | 7/20 [00:20<00:37,  2.88s/it]

tensor([[4.4927e-19, 5.7707e-21, 3.1570e-26,  ..., 2.8886e-13, 1.8466e-18,
         5.0667e-18],
        [1.8457e-35, 1.0140e-29, 1.0000e+00,  ..., 0.0000e+00, 1.6287e-31,
         2.1969e-30],
        [3.1796e-35, 8.3761e-28, 8.1178e-01,  ..., 0.0000e+00, 7.3768e-21,
         3.3189e-25],
        ...,
        [0.0000e+00, 1.2524e-38, 6.9934e-29,  ..., 3.0463e-35, 1.0000e+00,
         7.2593e-18],
        [6.3856e-15, 1.0000e+00, 6.4249e-21,  ..., 1.2945e-22, 0.0000e+00,
         2.9228e-27],
        [1.2153e-14, 1.0000e+00, 4.7141e-19,  ..., 1.1099e-21, 2.9287e-43,
         3.1310e-25]])


Testing:  40%|████      | 8/20 [00:23<00:34,  2.88s/it]

tensor([[1.0062e-25, 3.9764e-31, 4.6776e-05,  ..., 2.2758e-31, 4.3651e-05,
         6.6855e-19],
        [0.0000e+00, 3.0678e-12, 1.2354e-27,  ..., 2.6674e-16, 3.5383e-21,
         1.0000e+00],
        [5.5845e-20, 6.0415e-07, 3.2200e-26,  ..., 1.0000e+00, 1.0827e-22,
         2.2104e-07],
        ...,
        [6.5414e-24, 1.6348e-21, 1.0000e+00,  ..., 2.3682e-43, 0.0000e+00,
         3.4649e-35],
        [1.3970e-15, 1.0000e+00, 4.1753e-22,  ..., 8.4714e-23, 0.0000e+00,
         4.0222e-28],
        [9.2597e-08, 5.2755e-12, 7.1655e-15,  ..., 2.7191e-18, 6.6001e-16,
         1.6227e-19]])


Testing:  45%|████▌     | 9/20 [00:25<00:31,  2.88s/it]

tensor([[2.8023e-08, 1.3159e-08, 1.0000e+00,  ..., 6.9642e-35, 0.0000e+00,
         1.3871e-32],
        [0.0000e+00, 1.0988e-15, 1.9012e-31,  ..., 4.0043e-24, 9.5819e-22,
         1.0000e+00],
        [7.1165e-32, 2.0519e-22, 1.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         1.2233e-33],
        ...,
        [3.7343e-33, 4.3196e-25, 1.0000e+00,  ..., 0.0000e+00, 8.6881e-44,
         1.9239e-34],
        [0.0000e+00, 4.3256e-16, 3.3353e-31,  ..., 2.1245e-27, 4.1849e-24,
         1.0000e+00],
        [2.9373e-18, 9.1820e-13, 5.0933e-30,  ..., 9.5941e-01, 3.1508e-18,
         4.9084e-11]])


Testing:  50%|█████     | 10/20 [00:28<00:28,  2.88s/it]

tensor([[2.8915e-34, 2.8026e-45, 1.1433e-15,  ..., 1.5414e-44, 1.6583e-10,
         1.6612e-32],
        [5.5885e-12, 1.0000e+00, 1.5447e-16,  ..., 7.6741e-29, 0.0000e+00,
         2.6009e-32],
        [0.0000e+00, 1.5002e-13, 5.5923e-30,  ..., 5.1498e-21, 1.9647e-24,
         1.0000e+00],
        ...,
        [1.1542e-07, 1.3161e-11, 1.2484e-14,  ..., 1.4350e-17, 2.2653e-15,
         6.6863e-19],
        [1.8085e-28, 0.0000e+00, 1.0994e-17,  ..., 4.1969e-42, 4.3168e-11,
         7.1950e-40],
        [2.5222e-04, 5.3442e-04, 3.9960e-11,  ..., 2.5270e-23, 1.2831e-33,
         1.5654e-24]])


Testing:  55%|█████▌    | 11/20 [00:31<00:25,  2.88s/it]

tensor([[1.0000e+00, 1.1930e-24, 4.1957e-19,  ..., 4.4158e-36, 0.0000e+00,
         0.0000e+00],
        [3.4945e-27, 1.3512e-30, 1.0000e+00,  ..., 0.0000e+00, 1.0675e-32,
         9.0867e-34],
        [0.0000e+00, 4.5103e-16, 8.8960e-27,  ..., 1.3724e-24, 2.4493e-18,
         1.0000e+00],
        ...,
        [8.5181e-24, 9.9119e-10, 3.9627e-32,  ..., 1.0000e+00, 8.8986e-28,
         1.1319e-08],
        [2.6261e-23, 1.7190e-36, 3.8003e-28,  ..., 6.1701e-26, 1.0000e+00,
         6.5527e-26],
        [1.0100e-05, 4.0954e-04, 4.5125e-10,  ..., 1.7745e-23, 5.8297e-33,
         2.3325e-23]])


Testing:  60%|██████    | 12/20 [00:34<00:23,  2.89s/it]

tensor([[2.7467e-13, 1.0000e+00, 7.3534e-20,  ..., 2.7507e-22, 0.0000e+00,
         4.0260e-27],
        [0.0000e+00, 4.5402e-43, 8.4739e-30,  ..., 9.9882e-37, 1.0000e+00,
         1.8058e-23],
        [1.8301e-24, 0.0000e+00, 7.6466e-13,  ..., 1.2738e-41, 1.0505e-21,
         1.2960e-39],
        ...,
        [1.0000e+00, 4.5893e-17, 3.0556e-13,  ..., 3.2191e-36, 0.0000e+00,
         0.0000e+00],
        [2.6529e-16, 1.0000e+00, 5.5116e-22,  ..., 3.4282e-26, 0.0000e+00,
         3.4403e-31],
        [9.8254e-16, 1.7658e-13, 6.5838e-28,  ..., 2.0105e-02, 3.6288e-20,
         2.1914e-13]])


Testing:  65%|██████▌   | 13/20 [00:37<00:20,  2.88s/it]

tensor([[1.0000e+00, 1.3773e-20, 2.3196e-09,  ..., 1.7928e-35, 0.0000e+00,
         2.8026e-45],
        [4.2278e-32, 2.3822e-44, 1.3094e-06,  ..., 1.0734e-42, 1.2650e-12,
         1.0649e-31],
        [1.4013e-45, 0.0000e+00, 3.9280e-12,  ..., 0.0000e+00, 1.7361e-17,
         2.1742e-37],
        ...,
        [4.8482e-16, 4.1424e-38, 2.4472e-13,  ..., 2.3549e-29, 2.6978e-08,
         3.1246e-28],
        [0.0000e+00, 3.8697e-16, 3.0254e-21,  ..., 1.4776e-28, 1.6580e-16,
         1.0000e+00],
        [1.5450e-08, 1.3742e-28, 5.5185e-04,  ..., 2.7202e-29, 2.1475e-15,
         2.1438e-26]])


Testing:  70%|███████   | 14/20 [00:40<00:17,  2.88s/it]

tensor([[5.0937e-42, 5.0394e-37, 4.5129e-33,  ..., 1.1606e-27, 1.0000e+00,
         1.4326e-19],
        [3.7009e-35, 1.5222e-40, 1.4134e-04,  ..., 2.8026e-45, 6.4386e-15,
         1.5624e-29],
        [3.1393e-07, 3.1746e-10, 7.5803e-13,  ..., 1.1431e-17, 4.8456e-18,
         7.9018e-19],
        ...,
        [1.0000e+00, 2.7866e-27, 2.4395e-12,  ..., 1.3713e-34, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 3.5138e-16, 2.2258e-28,  ..., 3.8513e-20, 4.0606e-14,
         1.0000e+00],
        [3.1713e-26, 1.5485e-28, 1.0000e+00,  ..., 0.0000e+00, 2.2421e-44,
         1.2492e-37]])


Testing:  75%|███████▌  | 15/20 [00:43<00:14,  2.88s/it]

tensor([[2.9316e-21, 1.9151e-08, 1.2447e-28,  ..., 1.0000e+00, 3.8459e-25,
         1.4257e-08],
        [9.3428e-16, 1.0000e+00, 2.5501e-22,  ..., 3.5289e-23, 0.0000e+00,
         2.2248e-27],
        [3.4830e-18, 1.3831e-16, 7.1036e-25,  ..., 4.1425e-10, 2.1740e-24,
         1.2835e-16],
        ...,
        [9.9894e-01, 8.2446e-12, 1.0640e-03,  ..., 4.0245e-35, 0.0000e+00,
         9.6997e-40],
        [5.2033e-28, 9.9627e-38, 6.7059e-12,  ..., 7.5810e-43, 9.2265e-15,
         3.4420e-30],
        [3.2427e-23, 9.0569e-08, 2.0912e-29,  ..., 9.9997e-01, 1.3073e-24,
         3.0035e-05]])


Testing:  80%|████████  | 16/20 [00:46<00:11,  2.88s/it]

tensor([[2.8026e-45, 1.3456e-11, 1.2420e-25,  ..., 2.0552e-17, 1.4723e-18,
         1.0000e+00],
        [7.5000e-16, 6.2574e-02, 3.5673e-15,  ..., 1.6480e-02, 3.8023e-15,
         9.2094e-01],
        [2.6594e-38, 0.0000e+00, 1.7057e-12,  ..., 0.0000e+00, 5.7648e-19,
         9.7568e-39],
        ...,
        [1.1667e-13, 1.4013e-45, 9.5178e-16,  ..., 3.9294e-34, 9.0486e-17,
         4.7526e-37],
        [2.3738e-14, 2.3867e-22, 1.2073e-18,  ..., 3.5701e-16, 9.9820e-01,
         1.2577e-16],
        [1.8390e-03, 3.0716e-01, 2.0733e-10,  ..., 2.5227e-20, 4.4201e-32,
         2.3280e-22]])


Testing:  85%|████████▌ | 17/20 [00:49<00:08,  2.87s/it]

tensor([[1.0000e+00, 5.3499e-29, 2.0173e-19,  ..., 1.7587e-34, 0.0000e+00,
         0.0000e+00],
        [5.7795e-34, 6.0130e-42, 2.0601e-10,  ..., 0.0000e+00, 3.3544e-19,
         1.0757e-34],
        [0.0000e+00, 4.3673e-13, 6.9454e-27,  ..., 1.3845e-25, 4.2375e-24,
         1.0000e+00],
        ...,
        [3.9233e-36, 3.5129e-29, 4.5983e-03,  ..., 0.0000e+00, 6.3748e-19,
         7.2042e-25],
        [1.6661e-07, 2.5059e-13, 5.4567e-16,  ..., 1.2331e-19, 4.1820e-16,
         3.9231e-21],
        [3.1975e-05, 8.2273e-02, 7.2761e-11,  ..., 9.5746e-16, 5.3338e-23,
         2.6403e-17]])


Testing:  90%|█████████ | 18/20 [00:52<00:05,  2.92s/it]

tensor([[3.1995e-22, 2.0208e-18, 7.2063e-23,  ..., 1.3242e-14, 4.1829e-20,
         2.0883e-15],
        [4.1739e-34, 0.0000e+00, 5.0112e-15,  ..., 2.1019e-44, 2.6008e-08,
         1.5310e-36],
        [1.4455e-24, 1.9883e-26, 7.7545e-02,  ..., 1.4014e-31, 1.6595e-15,
         4.2745e-20],
        ...,
        [6.0612e-10, 5.0015e-24, 7.3266e-05,  ..., 3.8306e-24, 4.5550e-08,
         3.9050e-20],
        [4.1859e-23, 1.3386e-08, 3.1172e-30,  ..., 1.0000e+00, 6.0842e-26,
         3.5808e-08],
        [1.3277e-19, 1.4534e-08, 6.6612e-27,  ..., 1.0000e+00, 1.3615e-22,
         4.7098e-08]])


Testing:  95%|█████████▌| 19/20 [00:55<00:02,  2.94s/it]

tensor([[6.2147e-08, 1.0214e-11, 9.0860e-15,  ..., 1.8026e-20, 2.8269e-18,
         5.8135e-21],
        [1.8656e-08, 4.9570e-14, 1.6210e-15,  ..., 2.1437e-21, 1.8076e-16,
         9.1676e-22],
        [0.0000e+00, 2.8026e-45, 5.9373e-09,  ..., 0.0000e+00, 4.3290e-15,
         2.0861e-30],
        ...,
        [8.4043e-16, 6.3266e-20, 4.7588e-24,  ..., 1.4945e-12, 1.0292e-16,
         3.5427e-18],
        [0.0000e+00, 1.3250e-10, 4.7199e-22,  ..., 3.0691e-22, 5.2090e-24,
         1.0000e+00],
        [9.8756e-01, 2.3234e-06, 7.5138e-10,  ..., 1.5779e-27, 1.4013e-45,
         3.0489e-31]])


Testing: 100%|██████████| 20/20 [00:56<00:00,  2.83s/it]

tensor([[6.0768e-20, 3.2085e-09, 3.4560e-28,  ..., 1.0000e+00, 1.1961e-23,
         2.6276e-09],
        [1.4013e-45, 1.3018e-42, 4.2509e-28,  ..., 5.0652e-37, 1.0000e+00,
         5.7321e-24],
        [4.2019e-23, 9.5292e-22, 1.5577e-27,  ..., 3.7775e-14, 6.8352e-21,
         1.5021e-17],
        ...,
        [5.1487e-22, 1.6893e-14, 1.7729e-29,  ..., 9.5003e-04, 4.3625e-15,
         1.1604e-09],
        [9.9906e-01, 2.1967e-11, 9.3635e-04,  ..., 1.4067e-27, 2.5713e-37,
         5.7506e-29],
        [5.0045e-06, 3.4353e-09, 5.8475e-15,  ..., 6.4304e-19, 9.9453e-22,
         1.8569e-21]])
Test loss: 1.66
Test accuracy: 79.89%





### Changing lr (0.005 -> 0.001) 

In [12]:
from pathlib import Path

chw = (1, 64, 64) # image dimensions
hidden_dim = 16 # number of features in each patch's representation
n_encodelayers = 2
n_heads = 2 # no of attention heads
output_dim = 10 # Fashion MNIST has 10 classes

n_patches = 16

learning_rate = 0.001
num_epochs = 10

# instantiate model
if Path('./model_checkpoint_lr_001.pt').exists():
    model = torch.load('./model_checkpoint_lr_001.pt')
else:
    model = VisionTransformer(chw, n_patches, hidden_dim, n_encodelayers, n_heads, output_dim)

# instantiate loss function and optimizer
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# train model with training data
for epoch in range(num_epochs):
    train_loop(train_dataloader, model, loss_fn, optimizer, epoch)
    torch.save(model, './model_checkpoint_lr_001.pt')
    
# test model to get test loss and accuracy
test_accuracy, test_loss = test_loop(test_dataloader, model, loss_fn)

  self.pos_embed = nn.Parameter(torch.tensor(get_positional_embedding(self.n_patches ** 2 + 1, self.hidden_dim)))
                                                                                                                                    

Epoch 1 loss: 2.24
Validation accuracy: 19.81%


                                                                                                                                    

Epoch 2 loss: 2.03
Validation accuracy: 44.20%


                                                                                                                                    

Epoch 3 loss: 1.88
Validation accuracy: 59.28%


                                                                                                                                    

Epoch 4 loss: 1.84
Validation accuracy: 62.37%


                                                                                                                                    

Epoch 5 loss: 1.82
Validation accuracy: 63.98%


                                                                                                                                    

Epoch 6 loss: 1.81
Validation accuracy: 64.98%


                                                                                                                                    

Epoch 7 loss: 1.80
Validation accuracy: 66.24%


                                                                                                                                    

Epoch 8 loss: 1.79
Validation accuracy: 67.38%


                                                                                                                                    

Epoch 9 loss: 1.78
Validation accuracy: 68.09%


                                                                                                                                    

Epoch 10 loss: 1.77
Validation accuracy: 68.73%


                                                                                                                                    

Epoch 11 loss: 1.77
Validation accuracy: 68.81%


                                                                                                                                    

Epoch 12 loss: 1.77
Validation accuracy: 69.09%


                                                                                                                                    

Epoch 13 loss: 1.77
Validation accuracy: 69.41%


                                                                                                                                    

Epoch 14 loss: 1.76
Validation accuracy: 69.94%


                                                                                                                                    

Epoch 15 loss: 1.75
Validation accuracy: 70.46%


                                                                                                                                    

Epoch 16 loss: 1.71
Validation accuracy: 75.56%


                                                                                                                                    

Epoch 17 loss: 1.68
Validation accuracy: 78.40%


                                                                                                                                    

Epoch 18 loss: 1.67
Validation accuracy: 79.23%


                                                                                                                                    

Epoch 19 loss: 1.67
Validation accuracy: 79.46%


                                                                                                                                    

Epoch 20 loss: 1.67
Validation accuracy: 79.58%


Testing:   5%|████▎                                                                                  | 1/20 [00:04<01:17,  4.11s/it]

tensor([[1.8328e-07, 1.3017e-15, 1.0663e-06,  ..., 1.0519e-20, 1.6116e-04,
         6.9934e-13],
        [3.3800e-03, 6.0608e-13, 2.3499e-01,  ..., 1.2788e-19, 5.0476e-07,
         5.7604e-13],
        [2.5177e-32, 9.9893e-24, 1.2027e-24,  ..., 1.0941e-06, 2.0226e-04,
         9.9980e-01],
        ...,
        [9.9708e-01, 1.3237e-09, 1.3264e-08,  ..., 1.3947e-18, 1.0488e-11,
         7.6956e-15],
        [1.2819e-22, 4.9318e-22, 2.0338e-13,  ..., 9.3657e-16, 9.9997e-01,
         8.1409e-06],
        [3.8626e-11, 3.4894e-07, 9.6596e-24,  ..., 5.1590e-09, 3.6340e-07,
         6.2066e-10]])


Testing:  10%|████████▋                                                                              | 2/20 [00:07<01:11,  3.97s/it]

tensor([[2.5385e-28, 6.4850e-20, 7.0434e-22,  ..., 7.9046e-09, 1.4579e-08,
         1.5239e-06],
        [6.9783e-07, 5.2626e-14, 7.5238e-01,  ..., 3.1585e-18, 4.1933e-04,
         1.2850e-10],
        [1.6064e-24, 7.8555e-13, 2.4119e-22,  ..., 1.4772e-05, 1.3199e-09,
         8.4784e-14],
        ...,
        [4.9240e-26, 1.8746e-18, 3.6808e-22,  ..., 5.5372e-04, 5.6170e-05,
         9.9219e-01],
        [8.7775e-01, 1.4347e-06, 7.4780e-08,  ..., 1.3041e-15, 3.1106e-08,
         6.4334e-13],
        [2.0059e-23, 6.4365e-13, 3.2603e-23,  ..., 9.9999e-01, 3.8736e-06,
         4.1531e-06]])


Testing:  15%|█████████████                                                                          | 3/20 [00:11<01:07,  3.96s/it]

tensor([[4.4089e-22, 2.0697e-18, 1.1172e-15,  ..., 3.2693e-13, 4.2724e-09,
         4.1143e-12],
        [3.1283e-11, 1.0000e+00, 2.2956e-19,  ..., 2.2236e-14, 1.1393e-14,
         9.2883e-24],
        [3.1580e-07, 2.0297e-04, 9.9200e-25,  ..., 1.0403e-14, 3.4935e-13,
         2.1589e-15],
        ...,
        [1.2585e-09, 3.2970e-09, 1.6193e-03,  ..., 1.4654e-13, 6.0894e-05,
         5.1904e-10],
        [9.0437e-06, 1.4223e-10, 9.5952e-02,  ..., 6.0443e-16, 6.1473e-04,
         1.3995e-09],
        [5.0317e-10, 6.5446e-15, 9.8947e-01,  ..., 2.0650e-15, 9.3383e-03,
         1.9916e-07]])


Testing:  20%|█████████████████▍                                                                     | 4/20 [00:15<01:03,  3.99s/it]

tensor([[9.9491e-01, 4.7196e-06, 2.9262e-12,  ..., 1.9484e-18, 7.3497e-13,
         1.3693e-16],
        [2.0507e-28, 4.5041e-17, 5.8091e-26,  ..., 9.8702e-01, 7.2039e-05,
         1.2912e-02],
        [3.4440e-10, 9.9922e-01, 1.7872e-21,  ..., 4.4141e-08, 6.2953e-08,
         6.2707e-14],
        ...,
        [2.6237e-08, 3.5710e-08, 1.3637e-02,  ..., 1.7834e-12, 6.1565e-05,
         1.9650e-09],
        [3.6549e-12, 1.0000e+00, 1.1951e-21,  ..., 2.6887e-13, 1.1624e-15,
         7.2370e-23],
        [8.5346e-08, 1.1369e-12, 9.9999e-01,  ..., 4.3323e-17, 6.7130e-09,
         2.2044e-13]])


Testing:  25%|█████████████████████▊                                                                 | 5/20 [00:19<00:59,  4.00s/it]

tensor([[2.2220e-09, 9.0399e-07, 6.5257e-23,  ..., 2.8301e-11, 2.1021e-08,
         1.4361e-11],
        [6.5677e-13, 1.1560e-06, 7.5455e-12,  ..., 9.7543e-10, 3.9805e-06,
         1.4200e-16],
        [6.3152e-06, 1.0374e-13, 9.4305e-01,  ..., 1.5059e-13, 2.5410e-02,
         3.2564e-05],
        ...,
        [9.9991e-01, 5.1471e-08, 4.6321e-10,  ..., 7.1209e-20, 1.8645e-15,
         2.0126e-17],
        [1.9662e-16, 3.1676e-11, 4.4884e-15,  ..., 1.3204e-09, 1.0000e+00,
         2.7835e-10],
        [2.0176e-16, 3.5164e-09, 8.7587e-15,  ..., 1.8445e-10, 2.0637e-07,
         1.5921e-17]])


Testing:  30%|██████████████████████████                                                             | 6/20 [00:23<00:55,  3.96s/it]

tensor([[2.2551e-07, 3.3808e-10, 5.8637e-06,  ..., 1.8319e-18, 9.8281e-06,
         8.7648e-13],
        [2.5433e-33, 1.3943e-24, 9.6941e-28,  ..., 1.4086e-06, 4.8837e-06,
         9.9999e-01],
        [8.0640e-10, 4.6081e-07, 1.2888e-07,  ..., 6.9916e-12, 6.1554e-06,
         2.5406e-10],
        ...,
        [3.3575e-11, 2.3936e-11, 5.4573e-22,  ..., 5.0729e-15, 1.7971e-03,
         6.5461e-11],
        [1.5196e-17, 2.2661e-06, 1.2964e-17,  ..., 3.8587e-05, 7.3154e-08,
         3.4858e-15],
        [9.5042e-12, 1.0102e-10, 2.4380e-05,  ..., 4.6074e-15, 1.7988e-05,
         6.0243e-11]])


Testing:  35%|██████████████████████████████▍                                                        | 7/20 [00:27<00:52,  4.00s/it]

tensor([[1.0170e-05, 1.6229e-08, 1.1804e-12,  ..., 1.8090e-19, 1.0565e-07,
         1.3037e-15],
        [5.1889e-05, 5.3932e-04, 1.3850e-02,  ..., 4.4907e-12, 1.6628e-05,
         1.4045e-10],
        [4.7751e-11, 1.0000e+00, 7.5809e-24,  ..., 4.0581e-11, 4.3310e-12,
         2.7096e-18],
        ...,
        [8.0008e-22, 7.6338e-12, 1.7615e-22,  ..., 9.9999e-01, 1.0443e-05,
         7.9112e-07],
        [8.3562e-30, 3.8437e-24, 3.7687e-23,  ..., 1.1298e-08, 9.3154e-04,
         9.9907e-01],
        [9.5988e-01, 2.3442e-10, 5.8912e-04,  ..., 1.9596e-15, 1.0060e-07,
         1.0151e-09]])


Testing:  40%|██████████████████████████████████▊                                                    | 8/20 [00:31<00:48,  4.02s/it]

tensor([[2.1574e-09, 9.6172e-06, 1.3757e-23,  ..., 3.7039e-14, 1.3616e-07,
         2.8551e-13],
        [1.8185e-32, 1.2103e-22, 4.0595e-25,  ..., 4.2085e-05, 3.4145e-04,
         9.9962e-01],
        [1.3334e-06, 4.9298e-08, 1.9864e-07,  ..., 3.5519e-18, 4.2859e-06,
         9.1984e-14],
        ...,
        [2.0767e-07, 8.5589e-13, 2.7270e-07,  ..., 5.6255e-20, 2.1572e-05,
         4.3304e-13],
        [1.4867e-06, 1.5489e-05, 2.1182e-03,  ..., 1.4548e-12, 2.3429e-05,
         7.6203e-11],
        [8.4015e-10, 7.7391e-09, 1.2398e-23,  ..., 2.7391e-13, 3.2559e-08,
         3.1532e-12]])


Testing:  45%|███████████████████████████████████████▏                                               | 9/20 [00:36<00:44,  4.03s/it]

tensor([[8.0665e-12, 1.0000e+00, 2.6486e-21,  ..., 1.3947e-13, 3.5274e-15,
         5.6888e-23],
        [1.1749e-08, 7.4072e-14, 1.0000e+00,  ..., 6.5174e-18, 1.9850e-09,
         3.0015e-13],
        [9.3601e-23, 1.3325e-22, 2.1909e-13,  ..., 5.6960e-15, 9.9998e-01,
         2.2731e-05],
        ...,
        [1.4605e-11, 1.0000e+00, 1.5283e-21,  ..., 6.1970e-14, 3.1411e-15,
         5.3242e-23],
        [4.1907e-18, 4.9180e-12, 7.2600e-26,  ..., 1.0765e-14, 6.0275e-05,
         1.2075e-13],
        [7.3097e-21, 1.3263e-18, 7.7023e-14,  ..., 2.4965e-13, 1.4948e-09,
         7.3106e-10]])


Testing:  50%|███████████████████████████████████████████                                           | 10/20 [00:39<00:40,  4.00s/it]

tensor([[5.0833e-04, 3.5415e-14, 1.7337e-02,  ..., 1.9780e-20, 1.4462e-06,
         6.4778e-13],
        [6.0951e-11, 1.6073e-12, 9.8446e-24,  ..., 7.2552e-15, 2.0371e-07,
         7.2356e-11],
        [1.3799e-11, 1.0000e+00, 6.6214e-23,  ..., 4.0392e-14, 1.9696e-15,
         6.9083e-23],
        ...,
        [3.7675e-28, 3.9690e-18, 3.2697e-24,  ..., 3.4908e-01, 3.3599e-02,
         6.1732e-01],
        [1.7045e-17, 2.8594e-06, 8.3947e-18,  ..., 4.9067e-04, 2.3125e-07,
         5.9739e-14],
        [2.5832e-12, 5.3613e-10, 6.4880e-27,  ..., 9.1372e-15, 5.7231e-08,
         1.4503e-12]])


Testing:  55%|███████████████████████████████████████████████▎                                      | 11/20 [00:43<00:35,  3.99s/it]

tensor([[8.8195e-12, 2.4885e-11, 9.6577e-24,  ..., 1.7075e-10, 3.3956e-07,
         1.4745e-08],
        [4.4579e-07, 7.0709e-09, 3.0721e-22,  ..., 1.1077e-16, 9.1816e-11,
         7.6225e-14],
        [6.7246e-13, 5.2803e-13, 2.2423e-08,  ..., 1.0615e-17, 3.6530e-05,
         4.2861e-12],
        ...,
        [2.0316e-32, 5.5785e-25, 2.2226e-27,  ..., 2.1457e-07, 3.0604e-05,
         9.9997e-01],
        [1.5221e-32, 5.8226e-26, 1.0268e-25,  ..., 9.9036e-09, 1.8649e-05,
         9.9998e-01],
        [8.4885e-10, 1.0572e-06, 3.9836e-22,  ..., 5.2196e-10, 1.8953e-06,
         6.0947e-11]])


Testing:  60%|███████████████████████████████████████████████████▌                                  | 12/20 [00:47<00:32,  4.00s/it]

tensor([[8.4465e-11, 3.7935e-07, 1.7615e-21,  ..., 1.0783e-07, 5.7774e-05,
         1.8872e-08],
        [5.5802e-11, 3.1882e-13, 5.6198e-12,  ..., 4.4307e-20, 1.5247e-04,
         2.0327e-13],
        [1.2627e-10, 9.2722e-10, 1.4034e-20,  ..., 1.0405e-08, 6.9646e-05,
         9.8255e-08],
        ...,
        [2.3245e-02, 8.2958e-06, 3.4175e-14,  ..., 2.5705e-17, 2.3422e-08,
         9.4668e-15],
        [9.5708e-22, 2.1709e-11, 3.6287e-20,  ..., 1.2419e-08, 1.6994e-10,
         1.7800e-17],
        [9.5219e-11, 5.1088e-14, 2.1674e-08,  ..., 2.4003e-19, 1.0447e-02,
         1.0060e-11]])


Testing:  65%|███████████████████████████████████████████████████████▉                              | 13/20 [00:51<00:27,  3.96s/it]

tensor([[9.9976e-01, 3.8680e-09, 2.3226e-08,  ..., 7.0924e-18, 7.3194e-13,
         4.3376e-14],
        [3.8535e-12, 1.0000e+00, 1.6434e-23,  ..., 4.1030e-13, 2.5410e-15,
         3.7819e-22],
        [2.0060e-01, 2.4209e-09, 7.1315e-04,  ..., 2.5689e-19, 2.4412e-09,
         2.3458e-15],
        ...,
        [1.5854e-06, 2.4173e-15, 1.3241e-05,  ..., 1.5283e-15, 2.8960e-01,
         2.7625e-07],
        [4.3008e-13, 1.3145e-12, 1.1855e-04,  ..., 1.1437e-12, 4.9292e-01,
         1.0098e-06],
        [2.3874e-11, 1.6000e-08, 1.9461e-08,  ..., 2.6313e-14, 4.4988e-06,
         4.0557e-12]])


Testing:  70%|████████████████████████████████████████████████████████████▏                         | 14/20 [00:55<00:24,  4.01s/it]

tensor([[5.1532e-09, 4.7161e-05, 9.8685e-22,  ..., 8.6793e-14, 6.3724e-07,
         4.1117e-13],
        [1.8946e-07, 3.0203e-14, 9.9990e-01,  ..., 7.8081e-16, 6.6487e-06,
         8.0524e-10],
        [1.9271e-24, 5.5518e-23, 1.3202e-17,  ..., 2.1563e-12, 9.9968e-01,
         3.1981e-04],
        ...,
        [6.3507e-10, 1.0000e+00, 4.8226e-21,  ..., 1.6049e-10, 5.9949e-11,
         3.0907e-17],
        [6.4817e-05, 3.7631e-11, 2.9273e-06,  ..., 6.2359e-20, 1.9616e-07,
         1.1221e-13],
        [3.7040e-11, 1.2007e-07, 2.7846e-23,  ..., 4.7789e-09, 1.3080e-06,
         2.1662e-09]])


Testing:  75%|████████████████████████████████████████████████████████████████▌                     | 15/20 [00:59<00:20,  4.01s/it]

tensor([[1.9347e-06, 9.5216e-06, 5.6729e-04,  ..., 8.6626e-13, 1.5319e-06,
         5.6773e-11],
        [6.1539e-11, 1.0000e+00, 1.9985e-22,  ..., 2.6762e-12, 8.7519e-12,
         6.8698e-20],
        [1.5651e-23, 2.0457e-10, 9.4554e-24,  ..., 5.1814e-01, 4.4138e-08,
         8.0430e-10],
        ...,
        [1.2344e-10, 1.0000e+00, 3.1403e-23,  ..., 2.1929e-11, 1.0879e-11,
         3.8373e-18],
        [4.7595e-15, 1.8862e-12, 4.0440e-14,  ..., 1.8656e-05, 9.9969e-01,
         7.3697e-05],
        [3.4240e-03, 2.7018e-10, 2.6668e-05,  ..., 1.7684e-20, 4.7819e-09,
         3.5314e-15]])


Testing:  80%|████████████████████████████████████████████████████████████████████▊                 | 16/20 [01:03<00:15,  3.99s/it]

tensor([[1.1755e-21, 1.4208e-09, 1.4478e-21,  ..., 9.8625e-01, 1.4544e-06,
         4.3457e-09],
        [3.8810e-10, 3.9130e-06, 8.1779e-24,  ..., 1.2217e-10, 6.0818e-08,
         1.1303e-11],
        [6.1307e-08, 1.0266e-07, 4.5516e-22,  ..., 3.3430e-14, 4.6604e-09,
         7.4604e-13],
        ...,
        [7.3168e-08, 1.1486e-12, 9.9942e-01,  ..., 2.9901e-16, 8.0662e-06,
         8.2301e-11],
        [1.0290e-19, 1.3499e-13, 1.0435e-13,  ..., 1.0297e-12, 2.8065e-07,
         3.4373e-13],
        [4.5198e-09, 1.6203e-08, 1.0974e-22,  ..., 2.4281e-15, 1.2592e-07,
         5.1719e-13]])


Testing:  85%|█████████████████████████████████████████████████████████████████████████             | 17/20 [01:07<00:11,  3.99s/it]

tensor([[3.1610e-06, 6.8806e-15, 4.0891e-05,  ..., 2.3890e-21, 9.9315e-06,
         4.3002e-14],
        [7.4526e-25, 1.7024e-16, 1.2651e-22,  ..., 1.6885e-07, 1.0000e+00,
         1.6662e-06],
        [2.0732e-10, 1.0000e+00, 4.0116e-18,  ..., 2.8386e-15, 4.1705e-14,
         4.7774e-24],
        ...,
        [2.1330e-34, 1.5709e-24, 2.5087e-28,  ..., 5.7077e-06, 3.6666e-06,
         9.9999e-01],
        [5.5334e-01, 8.7448e-10, 5.0274e-08,  ..., 9.5472e-19, 5.7114e-09,
         7.6871e-14],
        [4.5892e-04, 3.3300e-12, 3.2273e-11,  ..., 2.7730e-18, 1.4712e-05,
         3.7715e-12]])


Testing:  90%|█████████████████████████████████████████████████████████████████████████████▍        | 18/20 [01:11<00:07,  3.98s/it]

tensor([[3.8951e-17, 2.0758e-09, 1.6645e-15,  ..., 1.8770e-10, 6.0904e-09,
         1.3383e-17],
        [1.8877e-09, 4.7029e-03, 7.4102e-21,  ..., 1.0572e-06, 2.3500e-05,
         9.5350e-10],
        [2.2787e-01, 6.3131e-01, 3.8949e-15,  ..., 1.9494e-13, 2.5208e-12,
         1.0177e-14],
        ...,
        [4.4046e-07, 7.7782e-11, 7.2447e-11,  ..., 3.2908e-18, 8.3925e-05,
         1.6324e-12],
        [4.8589e-26, 9.0374e-16, 9.7435e-25,  ..., 9.9920e-01, 9.0216e-06,
         7.9500e-04],
        [1.2159e-23, 5.6290e-13, 3.7588e-24,  ..., 1.0000e+00, 6.1440e-07,
         1.3053e-06]])


Testing:  95%|█████████████████████████████████████████████████████████████████████████████████▋    | 19/20 [01:15<00:03,  3.98s/it]

tensor([[9.9782e-01, 7.7541e-10, 4.7675e-06,  ..., 3.9169e-18, 1.6997e-11,
         1.5494e-13],
        [5.1115e-24, 1.2668e-12, 7.4223e-24,  ..., 9.9999e-01, 2.2842e-06,
         2.8212e-07],
        [3.3715e-12, 1.5133e-10, 2.5583e-07,  ..., 2.5291e-14, 3.1895e-05,
         5.6633e-11],
        ...,
        [9.9862e-01, 3.5178e-08, 6.8913e-11,  ..., 2.0278e-19, 2.1201e-13,
         1.5163e-16],
        [4.2150e-17, 4.2723e-06, 3.4939e-17,  ..., 3.1666e-05, 1.0256e-07,
         2.9284e-15],
        [6.2456e-01, 5.4346e-05, 1.6254e-13,  ..., 9.8031e-17, 2.1236e-10,
         2.5699e-14]])


Testing: 100%|██████████████████████████████████████████████████████████████████████████████████████| 20/20 [01:18<00:00,  3.90s/it]

tensor([[9.0283e-06, 2.4838e-03, 4.0479e-22,  ..., 2.1293e-14, 3.9811e-12,
         5.1946e-15],
        [4.3431e-07, 2.2653e-14, 2.6869e-02,  ..., 1.5217e-13, 8.5482e-01,
         1.2065e-04],
        [3.3978e-21, 7.3561e-10, 1.6277e-21,  ..., 9.9971e-01, 1.1219e-06,
         1.5730e-08],
        ...,
        [7.9400e-08, 2.4613e-13, 9.9982e-01,  ..., 1.2174e-18, 2.7867e-08,
         1.5214e-13],
        [2.7211e-22, 2.0459e-11, 3.7749e-23,  ..., 1.0000e+00, 5.4508e-07,
         1.6266e-07],
        [7.4700e-03, 4.8557e-05, 3.8046e-13,  ..., 2.9796e-18, 3.9417e-08,
         4.5646e-16]])
Test loss: 1.68
Test accuracy: 78.55%





### Changing lr (0.005 -> 0.01)

In [None]:
from pathlib import Path

chw = (1, 64, 64) # image dimensions
hidden_dim = 16 # number of features in each patch's representation
n_encodelayers = 2
n_heads = 2 # no of attention heads
output_dim = 10 # Fashion MNIST has 10 classes

n_patches = 16

learning_rate = 0.01
num_epochs = 20

# instantiate model
model = VisionTransformer(chw, n_patches, hidden_dim, n_encodelayers, n_heads, output_dim)

# instantiate loss function and optimizer
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# train model with training data
for epoch in range(num_epochs):
    train_loop(train_dataloader, model, loss_fn, optimizer, epoch)
    torch.save(model, './model_checkpoint_lr_01.pt')
    
# test model to get test loss and accuracy
test_accuracy, test_loss = test_loop(test_dataloader, model, loss_fn)

### Base model but on smaller train set (to save on computation time) 

In [18]:
chw = (1, 64, 64) # image dimensions
hidden_dim = 16 # number of features in each patch's representation
n_encodelayers = 2
n_heads = 2 # no of attention heads
output_dim = 10 # Fashion MNIST has 10 classes

n_patches = 16

learning_rate = 0.01
num_epochs = 10

# instantiate model
model = VisionTransformer(chw, n_patches, hidden_dim, n_encodelayers, n_heads, output_dim)

# instantiate loss function and optimizer
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# train model with training data
for epoch in range(num_epochs):
    train_loop(train_dataloader, model, loss_fn, optimizer, epoch)

# test model to get test loss and accuracy
test_accuracy, test_loss = test_loop(test_dataloader, model, loss_fn)

  self.pos_embed = nn.Parameter(torch.tensor(get_positional_embedding(self.n_patches ** 2 + 1, self.hidden_dim)))
                                                                                                                                    

Epoch 1 loss: 2.11
Validation accuracy: 33.79%


                                                                                                                                    

Epoch 2 loss: 1.88
Validation accuracy: 57.66%


                                                                                                                                    

Epoch 3 loss: 1.83
Validation accuracy: 62.44%


                                                                                                                                    

Epoch 4 loss: 1.81
Validation accuracy: 64.91%


                                                                                                                                    

Epoch 5 loss: 1.79
Validation accuracy: 66.73%


                                                                                                                                    

Epoch 6 loss: 1.79
Validation accuracy: 66.84%


                                                                                                                                    

Epoch 7 loss: 1.79
Validation accuracy: 66.87%


                                                                                                                                    

Epoch 8 loss: 1.79
Validation accuracy: 67.39%


                                                                                                                                    

Epoch 9 loss: 1.79
Validation accuracy: 67.25%


                                                                                                                                    

Epoch 10 loss: 1.78
Validation accuracy: 68.08%


Testing: 100%|██████████████████████████████████████████████████████████████████████████████████████| 12/12 [00:55<00:00,  4.60s/it]

Test loss: 1.78
Test accuracy: 67.82%





### Changing hidden_dim (8, 32 vs 16 base)

In [16]:
chw = (1, 64, 64) # image dimensions
n_encodelayers = 2
n_heads = 2 # no of attention heads
output_dim = 10 # Fashion MNIST has 10 classes
n_patches = 16

# parameter to test
hidden_dims = [8, 32] # number of features in each patch's representation

batch_size = 512
learning_rate = 0.01
num_epochs = 10

test_acc_list = []
for hidden_dim in hidden_dims:
    # instantiate model
    model = VisionTransformer(chw, n_patches, hidden_dim, n_encodelayers, n_heads, output_dim)
    
    # instantiate loss function and optimizer
    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    
    # train model with training data
    for epoch in range(num_epochs):
        train_loop(train_dataloader, model, loss_fn, optimizer, epoch)
        torch.save(model, f'./saved_model__hidden_dim{hidden_dim}.pt')
        
#     # test model to get test loss and accuracy
#     test_accuracy, test_loss = test_loop(test_dataloader, model, loss_fn)
    
# print(f'Test accuracy for hidden_dim 8: {test_acc_list[0]}')
# print(f'Test accuracy for hidden_dim 32: {test_acc_list[1]}')

  self.pos_embed = nn.Parameter(torch.tensor(get_positional_embedding(self.n_patches ** 2 + 1, self.hidden_dim)))
                                                                                                                                    

Epoch 1 loss: 2.19
Validation accuracy: 25.73%


                                                                                                                                    

Epoch 2 loss: 1.92
Validation accuracy: 54.47%


                                                                                                                                    

Epoch 3 loss: 1.83
Validation accuracy: 62.87%


                                                                                                                                    

Epoch 4 loss: 1.81
Validation accuracy: 65.21%


                                                                                                                                    

Epoch 5 loss: 1.80
Validation accuracy: 65.67%


                                                                                                                                    

Epoch 6 loss: 1.80
Validation accuracy: 66.38%


                                                                                                                                    

Epoch 7 loss: 1.73
Validation accuracy: 72.74%


                                                                                                                                    

Epoch 8 loss: 1.72
Validation accuracy: 74.14%


                                                                                                                                    

Epoch 9 loss: 1.71
Validation accuracy: 74.53%


                                                                                                                                    

Epoch 10 loss: 1.71
Validation accuracy: 74.82%


Testing:   8%|███████▎                                                                               | 1/12 [00:05<01:01,  5.55s/it]

tensor([[3.1431e-16, 2.0023e-22, 2.8654e-02,  ..., 8.1616e-26, 2.1531e-07,
         3.4395e-19],
        [3.2234e-26, 1.0000e+00, 2.1695e-37,  ..., 2.6936e-13, 1.3107e-20,
         1.3186e-16],
        [1.2706e-14, 9.4221e-20, 2.8115e-14,  ..., 2.0945e-24, 1.0000e+00,
         3.0377e-08],
        ...,
        [1.8207e-09, 4.1775e-17, 7.6421e-16,  ..., 7.0921e-06, 6.7235e-13,
         3.2991e-03],
        [2.6498e-17, 2.1698e-21, 1.1087e-03,  ..., 7.7021e-27, 2.9766e-05,
         6.0054e-19],
        [4.9798e-04, 8.3161e-21, 9.8855e-01,  ..., 8.4371e-16, 1.3646e-13,
         4.6492e-11]])


Testing:  17%|██████████████▌                                                                        | 2/12 [00:10<00:50,  5.04s/it]

tensor([[6.4852e-20, 1.5315e-21, 4.9149e-06,  ..., 4.5611e-29, 1.9097e-05,
         1.8175e-20],
        [1.0000e+00, 1.5918e-27, 2.5262e-15,  ..., 1.9688e-20, 1.6498e-17,
         3.1791e-06],
        [1.1556e-11, 2.1589e-15, 1.1811e-18,  ..., 9.8952e-01, 9.8418e-22,
         3.2769e-08],
        ...,
        [4.9883e-01, 2.0017e-16, 1.0187e-09,  ..., 1.0020e-18, 4.8075e-02,
         4.5309e-01],
        [4.8979e-20, 3.3117e-23, 9.8074e-13,  ..., 2.6486e-17, 1.9701e-12,
         2.6812e-17],
        [4.5121e-11, 1.8455e-19, 4.3852e-28,  ..., 5.8929e-20, 6.9201e-09,
         1.0000e+00]])


Testing:  25%|█████████████████████▊                                                                 | 3/12 [00:14<00:43,  4.84s/it]

tensor([[7.9962e-13, 3.4923e-15, 2.2428e-18,  ..., 9.8904e-01, 7.4421e-23,
         1.7567e-10],
        [7.4031e-18, 3.2648e-20, 7.4342e-15,  ..., 8.9000e-31, 1.0000e+00,
         1.9673e-13],
        [5.9484e-18, 7.9913e-24, 1.3233e-13,  ..., 4.8718e-17, 2.0736e-11,
         8.0920e-14],
        ...,
        [9.9957e-01, 1.7999e-23, 6.9299e-13,  ..., 9.2012e-17, 9.2709e-18,
         1.1190e-06],
        [4.6574e-12, 6.0619e-14, 2.1266e-17,  ..., 1.3155e-22, 1.0000e+00,
         4.2096e-06],
        [2.7454e-15, 8.3244e-22, 3.1392e-02,  ..., 6.9296e-25, 1.6349e-07,
         1.7889e-18]])


Testing:  33%|█████████████████████████████                                                          | 4/12 [00:19<00:39,  4.90s/it]

tensor([[2.0368e-19, 3.5422e-22, 3.2620e-10,  ..., 3.2202e-22, 5.9205e-05,
         4.6196e-16],
        [2.7651e-12, 1.1798e-12, 3.0295e-24,  ..., 1.0000e+00, 3.6707e-23,
         3.9269e-07],
        [1.8880e-15, 5.4710e-19, 8.9492e-17,  ..., 1.1970e-28, 1.0000e+00,
         1.8726e-10],
        ...,
        [9.9493e-01, 9.8969e-25, 5.2796e-15,  ..., 2.8423e-17, 3.5824e-21,
         2.4776e-07],
        [5.2907e-19, 2.8405e-20, 9.2320e-14,  ..., 8.8803e-30, 1.0000e+00,
         3.0264e-14],
        [4.7398e-19, 4.3420e-22, 5.0775e-11,  ..., 1.7812e-20, 9.9233e-07,
         1.0861e-15]])


Testing:  42%|████████████████████████████████████▎                                                  | 5/12 [00:24<00:33,  4.85s/it]

tensor([[2.1874e-27, 1.0000e+00, 2.6399e-39,  ..., 2.7674e-17, 8.1187e-19,
         5.9915e-17],
        [5.3418e-11, 7.3377e-20, 7.4856e-27,  ..., 5.2012e-15, 1.1143e-12,
         1.0000e+00],
        [1.0000e+00, 5.1239e-33, 1.5665e-12,  ..., 1.0853e-19, 1.1839e-25,
         7.8223e-11],
        ...,
        [2.0308e-10, 7.6422e-23, 9.9988e-01,  ..., 3.4512e-22, 2.2160e-08,
         2.8434e-15],
        [5.1397e-15, 3.3714e-21, 2.5561e-02,  ..., 6.5244e-25, 2.8456e-06,
         9.8715e-18],
        [1.1476e-21, 3.1452e-21, 3.5302e-10,  ..., 7.7144e-30, 1.0000e+00,
         2.5414e-18]])


Testing:  50%|███████████████████████████████████████████▌                                           | 6/12 [00:29<00:28,  4.82s/it]

tensor([[6.9722e-15, 8.5850e-16, 1.4125e-29,  ..., 2.8910e-21, 4.2062e-25,
         6.7689e-14],
        [5.9084e-10, 9.0584e-18, 5.7551e-03,  ..., 8.7423e-22, 9.9395e-01,
         3.8590e-11],
        [4.6683e-19, 4.7403e-19, 9.6634e-11,  ..., 6.9602e-16, 2.3402e-11,
         3.1809e-18],
        ...,
        [4.0242e-17, 7.7938e-23, 5.3400e-03,  ..., 1.2337e-26, 5.2829e-08,
         4.7622e-20],
        [1.9780e-12, 6.6738e-16, 1.3612e-18,  ..., 9.8710e-01, 6.8065e-23,
         1.1284e-09],
        [4.0499e-24, 1.0000e+00, 4.3189e-41,  ..., 5.2833e-16, 2.2426e-17,
         5.7571e-13]])


Testing:  58%|██████████████████████████████████████████████████▊                                    | 7/12 [00:34<00:24,  4.82s/it]

tensor([[2.6886e-12, 1.0344e-17, 2.7758e-28,  ..., 1.1538e-20, 1.6156e-25,
         1.3325e-12],
        [2.6020e-20, 2.0804e-24, 7.9295e-14,  ..., 2.8719e-16, 2.8102e-15,
         3.4591e-18],
        [2.2080e-27, 1.0000e+00, 2.1112e-36,  ..., 5.2658e-15, 1.7009e-20,
         8.9664e-19],
        ...,
        [7.0385e-11, 8.9306e-15, 1.4496e-21,  ..., 9.9996e-01, 5.8225e-22,
         4.6867e-06],
        [1.1464e-12, 7.7718e-23, 9.7869e-01,  ..., 2.4602e-23, 1.2407e-08,
         3.4077e-17],
        [2.0329e-09, 9.9966e-24, 1.0000e+00,  ..., 3.4560e-21, 1.0098e-10,
         2.6678e-15]])


Testing:  67%|██████████████████████████████████████████████████████████                             | 8/12 [00:39<00:19,  4.86s/it]

tensor([[1.0466e-07, 2.4540e-24, 1.0000e+00,  ..., 2.9482e-19, 5.7821e-14,
         6.4944e-15],
        [1.0000e+00, 2.0106e-31, 2.4095e-14,  ..., 4.0030e-20, 1.2878e-23,
         2.5623e-09],
        [2.4192e-18, 5.8665e-20, 1.9242e-13,  ..., 4.6535e-30, 1.0000e+00,
         2.3215e-14],
        ...,
        [1.5503e-09, 1.5677e-18, 3.8174e-14,  ..., 1.7676e-10, 2.8843e-08,
         6.7146e-03],
        [5.3878e-06, 3.0223e-20, 9.9937e-01,  ..., 2.3684e-20, 5.8235e-04,
         7.9748e-10],
        [2.6487e-16, 7.8331e-20, 1.6902e-03,  ..., 3.4988e-26, 1.1433e-02,
         5.4807e-17]])


Testing:  75%|█████████████████████████████████████████████████████████████████▎                     | 9/12 [00:43<00:14,  4.83s/it]

tensor([[1.0678e-10, 4.8013e-24, 9.9990e-01,  ..., 3.6557e-22, 2.6750e-11,
         1.6571e-16],
        [5.3843e-13, 8.6444e-15, 2.1855e-19,  ..., 9.9837e-01, 4.3704e-23,
         3.3837e-10],
        [3.8347e-17, 3.0823e-22, 6.0565e-11,  ..., 2.5370e-20, 6.0834e-05,
         3.7905e-12],
        ...,
        [2.2022e-12, 1.4843e-18, 6.0754e-27,  ..., 1.0868e-20, 3.1107e-26,
         2.4913e-13],
        [6.9111e-15, 3.6409e-19, 3.6307e-15,  ..., 9.7885e-06, 8.1503e-21,
         2.7584e-14],
        [9.4859e-13, 6.2671e-13, 1.3367e-07,  ..., 2.2118e-19, 3.9876e-07,
         4.0712e-13]])


Testing:  83%|███████████████████████████████████████████████████████████████████████▋              | 10/12 [00:48<00:09,  4.78s/it]

tensor([[9.9942e-01, 2.1535e-28, 2.3975e-10,  ..., 2.5002e-15, 7.3368e-26,
         1.1500e-10],
        [1.3232e-11, 2.1431e-23, 9.9952e-01,  ..., 3.7989e-23, 1.2645e-08,
         2.1870e-16],
        [3.6077e-17, 1.3948e-20, 2.3281e-16,  ..., 1.5382e-29, 1.0000e+00,
         2.3395e-11],
        ...,
        [3.8056e-09, 2.2042e-05, 1.1134e-08,  ..., 2.8973e-13, 3.9055e-02,
         6.0978e-08],
        [2.9013e-14, 7.9600e-12, 8.7565e-22,  ..., 9.9999e-01, 3.6568e-23,
         1.3341e-10],
        [1.5884e-24, 1.0000e+00, 3.9899e-41,  ..., 7.9047e-15, 2.4185e-18,
         4.8870e-13]])


Testing:  92%|██████████████████████████████████████████████████████████████████████████████▊       | 11/12 [00:53<00:04,  4.74s/it]

tensor([[1.0445e-12, 8.6628e-22, 8.8046e-01,  ..., 1.4032e-23, 9.1159e-07,
         1.8884e-16],
        [1.7702e-12, 3.7516e-18, 1.0560e-19,  ..., 1.1864e-25, 9.9997e-01,
         2.8974e-05],
        [4.9389e-10, 2.6126e-14, 1.3735e-22,  ..., 9.9973e-01, 2.6501e-21,
         2.6364e-04],
        ...,
        [1.0000e+00, 1.2688e-32, 3.6525e-15,  ..., 1.1560e-18, 4.4719e-28,
         8.6143e-11],
        [2.0575e-15, 1.1631e-19, 4.9698e-12,  ..., 1.2003e-28, 1.0000e+00,
         1.2100e-12],
        [2.9249e-15, 2.3680e-14, 1.0464e-27,  ..., 3.4496e-21, 2.0720e-22,
         1.4370e-13]])


Testing: 100%|██████████████████████████████████████████████████████████████████████████████████████| 12/12 [00:56<00:00,  4.72s/it]


tensor([[2.6864e-07, 1.0583e-14, 2.5442e-20,  ..., 6.4644e-03, 8.4381e-16,
         9.9341e-01],
        [5.3368e-16, 1.0469e-12, 2.4557e-29,  ..., 1.5142e-20, 9.3359e-23,
         4.8260e-14],
        [4.3448e-12, 4.0182e-22, 9.8614e-01,  ..., 2.8778e-23, 4.1751e-07,
         4.6743e-16],
        ...,
        [1.4588e-13, 5.1249e-13, 2.0842e-20,  ..., 9.9988e-01, 7.6951e-23,
         2.1498e-10],
        [2.0358e-11, 4.8099e-19, 8.3110e-33,  ..., 3.2544e-20, 1.3567e-11,
         1.0000e+00],
        [1.9210e-11, 9.2163e-20, 3.6051e-30,  ..., 1.8885e-15, 2.2740e-14,
         1.0000e+00]])
Test loss: 1.71
Test accuracy: 74.70%


                                                                                                                                    

Epoch 1 loss: 2.04
Validation accuracy: 41.87%


                                                                                                                                    

Epoch 2 loss: 1.89
Validation accuracy: 57.12%


                                                                                                                                    

Epoch 3 loss: 1.84
Validation accuracy: 62.34%


                                                                                                                                    

Epoch 4 loss: 1.79
Validation accuracy: 66.68%


                                                                                                                                    

Epoch 5 loss: 1.77
Validation accuracy: 68.51%


                                                                                                                                    

Epoch 6 loss: 1.78
Validation accuracy: 68.43%


                                                                                                                                    

Epoch 7 loss: 1.78
Validation accuracy: 67.67%


                                                                                                                                    

Epoch 8 loss: 1.79
Validation accuracy: 66.54%


                                                                                                                                    

Epoch 9 loss: 1.77
Validation accuracy: 69.53%


                                                                                                                                    

Epoch 10 loss: 1.78
Validation accuracy: 68.37%


Testing:   8%|███████▎                                                                               | 1/12 [00:04<00:48,  4.42s/it]

tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 1.5042e-11, 8.0362e-01,
         1.9638e-01],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 1.5224e-20, 9.9994e-01,
         5.6713e-05],
        ...,
        [2.5085e-21, 2.3629e-36, 3.4971e-18,  ..., 3.3920e-31, 3.5058e-12,
         1.7719e-35],
        [5.8705e-12, 6.8193e-39, 4.0818e-19,  ..., 7.4023e-38, 2.1632e-22,
         0.0000e+00],
        [7.8218e-12, 6.5196e-32, 1.3313e-33,  ..., 4.8625e-43, 0.0000e+00,
         0.0000e+00]])


Testing:  17%|██████████████▌                                                                        | 2/12 [00:08<00:43,  4.39s/it]

tensor([[1.0000e+00, 8.1561e-34, 4.7910e-29,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 9.1839e-17, 3.1413e-37,
         1.0000e+00],
        [3.8102e-13, 4.9396e-39, 2.4619e-18,  ..., 1.1041e-34, 1.0710e-19,
         0.0000e+00],
        ...,
        [1.6635e-21, 3.9578e-07, 6.6512e-31,  ..., 0.0000e+00, 0.0000e+00,
         4.6106e-36],
        [1.9654e-34, 1.9618e-44, 1.2552e-22,  ..., 2.0626e-29, 1.7262e-09,
         2.3520e-30],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 3.1261e-22, 1.0567e-35,
         1.0000e+00]])


Testing:  25%|█████████████████████▊                                                                 | 3/12 [00:13<00:40,  4.51s/it]

tensor([[2.0861e-39, 2.7734e-31, 7.6201e-41,  ..., 0.0000e+00, 0.0000e+00,
         1.8531e-41],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 5.2040e-02, 3.0494e-37,
         9.4796e-01],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 2.9139e-37, 1.5820e-16,
         2.8067e-20],
        ...,
        [6.0850e-35, 1.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 3.3586e-40, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]])


Testing:  33%|█████████████████████████████                                                          | 4/12 [00:17<00:35,  4.44s/it]

tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [4.2060e-32, 1.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 5.1234e-37,  ..., 2.2014e-42, 1.4352e-19,
         3.6936e-32],
        ...,
        [2.7948e-13, 4.9670e-22, 6.0285e-18,  ..., 2.7929e-39, 7.6617e-18,
         5.2347e-37],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 1.0000e+00, 1.1969e-34,
         1.9855e-21],
        [7.4339e-01, 1.6765e-11, 5.5003e-23,  ..., 0.0000e+00, 1.8217e-44,
         0.0000e+00]])


Testing:  42%|████████████████████████████████████▎                                                  | 5/12 [00:22<00:30,  4.40s/it]

tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 1.6413e-22, 1.0000e+00,
         1.5451e-11],
        [1.1353e-37, 1.6956e-43, 1.5676e-26,  ..., 1.3545e-39, 4.6164e-14,
         1.7423e-34],
        [1.3318e-34, 1.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        ...,
        [1.3733e-23, 2.2879e-22, 1.5121e-33,  ..., 0.0000e+00, 0.0000e+00,
         7.0065e-45],
        [1.0000e+00, 8.3253e-32, 1.4447e-34,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]])


Testing:  50%|███████████████████████████████████████████▌                                           | 6/12 [00:26<00:26,  4.34s/it]

tensor([[9.2205e-43, 0.0000e+00, 1.1228e-27,  ..., 3.5279e-33, 1.5262e-14,
         3.6891e-32],
        [1.2784e-30, 1.8373e-21, 2.4493e-35,  ..., 0.0000e+00, 0.0000e+00,
         5.5850e-35],
        [1.2400e-09, 5.4227e-29, 1.5250e-18,  ..., 4.5528e-42, 6.0720e-19,
         0.0000e+00],
        ...,
        [2.2411e-32, 1.4013e-45, 1.4256e-23,  ..., 3.7739e-38, 1.2291e-14,
         1.6331e-39],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 1.0000e+00, 7.3330e-42,
         6.4818e-26],
        [4.9443e-36, 1.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]])


Testing:  58%|██████████████████████████████████████████████████▊                                    | 7/12 [00:30<00:21,  4.33s/it]

tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 2.5271e-26, 3.5548e-01,
         5.6491e-17],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 6.1452e-39, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 1.2911e-35,  ..., 4.6300e-26, 9.9999e-01,
         1.8870e-23],
        ...,
        [1.0627e-32, 2.9472e-31, 5.0715e-31,  ..., 0.0000e+00, 2.3004e-35,
         4.1640e-35],
        [0.0000e+00, 0.0000e+00, 8.3540e-38,  ..., 2.1019e-44, 2.0533e-25,
         5.9211e-32],
        [3.5347e-29, 1.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]])


Testing:  67%|██████████████████████████████████████████████████████████                             | 8/12 [00:34<00:17,  4.33s/it]

tensor([[3.5465e-27, 3.2521e-17, 6.1106e-35,  ..., 0.0000e+00, 0.0000e+00,
         2.7635e-35],
        [0.0000e+00, 0.0000e+00, 1.1046e-36,  ..., 2.1019e-44, 3.6065e-25,
         9.8193e-31],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 2.8089e-22, 7.7228e-36,
         1.0000e+00],
        ...,
        [2.1133e-37, 1.7741e-36, 6.2353e-30,  ..., 0.0000e+00, 5.3207e-29,
         1.5334e-32],
        [1.8485e-34, 1.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [4.9674e-35, 1.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]])


Testing:  75%|█████████████████████████████████████████████████████████████████▎                     | 9/12 [00:39<00:12,  4.32s/it]

tensor([[1.3220e-28, 1.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 2.9704e-39,  ..., 3.0089e-39, 1.0721e-18,
         2.8569e-32],
        [0.0000e+00, 0.0000e+00, 2.1342e-32,  ..., 5.7313e-43, 2.9971e-22,
         4.3591e-36],
        ...,
        [3.8707e-23, 8.6390e-15, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [3.1339e-18, 2.6420e-39, 9.0596e-20,  ..., 3.5801e-37, 6.4014e-20,
         5.8855e-44],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 1.0000e+00, 9.7479e-35,
         3.2430e-17]])


Testing:  83%|███████████████████████████████████████████████████████████████████████▋              | 10/12 [00:43<00:08,  4.31s/it]

tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 1.0000e+00, 0.0000e+00,
         1.4968e-25],
        [2.3448e-15, 3.1885e-41, 2.4306e-20,  ..., 6.2238e-39, 1.9740e-21,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        ...,
        [1.0000e+00, 4.9059e-28, 1.4432e-27,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 4.1895e-05, 1.4388e-08,
         9.9996e-01],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 1.0000e+00, 2.7503e-24,
         2.8793e-13]])


Testing:  92%|██████████████████████████████████████████████████████████████████████████████▊       | 11/12 [00:47<00:04,  4.34s/it]

tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 1.6735e-40, 6.0852e-18,
         2.3239e-26],
        [0.0000e+00, 1.6804e-31, 2.0287e-39,  ..., 9.9764e-01, 1.4628e-41,
         5.4505e-09],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 3.7174e-32, 1.3627e-09,
         2.4162e-14],
        ...,
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.8419e-32, 1.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]])


Testing: 100%|██████████████████████████████████████████████████████████████████████████████████████| 12/12 [00:51<00:00,  4.27s/it]

tensor([[1.9716e-36, 1.6368e-27, 9.0908e-41,  ..., 0.0000e+00, 0.0000e+00,
         2.4447e-41],
        [2.7606e-35, 1.1342e-30, 3.0168e-30,  ..., 2.7003e-42, 3.2661e-28,
         5.2408e-29],
        [9.5002e-05, 1.6139e-25, 1.6090e-16,  ..., 7.6634e-41, 3.4091e-27,
         1.2612e-44],
        ...,
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 9.8060e-19, 2.2400e-36,
         1.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 5.5351e-24, 1.0000e+00,
         9.2594e-13],
        [7.5068e-36, 1.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]])
Test loss: 1.77
Test accuracy: 69.32%





IndexError: list index out of range

In [22]:
test_acc_list = []
for hidden_dim in hidden_dims:
    model = torch.load(f'./saved_model__hidden_dim{hidden_dim}.pt')
    # test model to get test loss and accuracy
    test_accuracy, test_loss = test_loop(test_dataloader, model, loss_fn)
    test_acc_list.append(test_accuracy)
    
print(f'Test accuracy for hidden_dim 8: {test_acc_list[0]:.2f}%')
print(f'Test accuracy for hidden_dim 32: {test_acc_list[1]:.2f}%')

Testing: 100%|██████████████████████████████████████████████████████████████████████████████████████| 12/12 [00:49<00:00,  4.16s/it]


Test loss: 1.71
Test accuracy: 74.70%


Testing: 100%|██████████████████████████████████████████████████████████████████████████████████████| 12/12 [00:49<00:00,  4.16s/it]

Test loss: 1.77
Test accuracy: 69.32%
Test accuracy for hidden_dim 8: 74.70%
Test accuracy for hidden_dim 32: 69.32%





### Changing n_encodelayer (4, 8 vs 2 base)

In [12]:
chw = (1, 64, 64) # image dimensions
n_heads = 2 # no of attention heads
output_dim = 10 # Fashion MNIST has 10 classes
n_patches = 16
hidden_dim = 16 # number of features in each patch's representation

# parameter to test
n_encodelayer = [4, 8]

batch_size = 512
learning_rate = 0.01
num_epochs = 10

test_acc_list = []
for n_encodelayers in n_encodelayer:
    # instantiate model
    model = VisionTransformer(chw, n_patches, hidden_dim, n_encodelayers, n_heads, output_dim)
    
    # instantiate loss function and optimizer
    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    
    # train model with training data
    for epoch in range(num_epochs):
        train_loop(train_dataloader, model, loss_fn, optimizer, epoch)
        torch.save(model, f'./saved_model__enclayer{n_encodelayers}.pt')

    # test model to get test loss and accuracy
    test_accuracy, test_loss = test_loop(test_dataloader, model, loss_fn)
    test_acc_list.append(test_accuracy)
    
print(f'Test accuracy for n_encodelayer 4: {test_acc_list[0]:.2f}%')
print(f'Test accuracy for n_encodelayer 8: {test_acc_list[1]:.2f}%')

  self.pos_embed = nn.Parameter(torch.tensor(get_positional_embedding(self.n_patches ** 2 + 1, self.hidden_dim)))
                                                                                                                                    

Epoch 1 loss: 2.11
Validation accuracy: 33.75%


                                                                                                                                    

Epoch 2 loss: 1.89
Validation accuracy: 56.81%


                                                                                                                                    

Epoch 3 loss: 1.83
Validation accuracy: 63.61%


                                                                                                                                    

Epoch 4 loss: 1.81
Validation accuracy: 64.75%


                                                                                                                                    

Epoch 5 loss: 1.79
Validation accuracy: 67.00%


                                                                                                                                    

Epoch 6 loss: 1.77
Validation accuracy: 68.93%


                                                                                                                                    

Epoch 7 loss: 1.76
Validation accuracy: 69.87%


                                                                                                                                    

Epoch 8 loss: 1.75
Validation accuracy: 71.48%


                                                                                                                                    

Epoch 9 loss: 1.74
Validation accuracy: 72.27%


                                                                                                                                    

Epoch 10 loss: 1.74
Validation accuracy: 72.22%


Testing: 100%|██████████████████████████████████████████████████████████████████████████████████████| 12/12 [01:10<00:00,  5.85s/it]


Test loss: 1.76
Test accuracy: 70.18%


                                                                                                                                    

Epoch 1 loss: 2.11
Validation accuracy: 34.64%


                                                                                                                                    

Epoch 2 loss: 1.86
Validation accuracy: 60.96%


                                                                                                                                    

Epoch 3 loss: 1.80
Validation accuracy: 66.72%


                                                                                                                                    

Epoch 4 loss: 1.78
Validation accuracy: 68.49%


                                                                                                                                    

Epoch 5 loss: 1.77
Validation accuracy: 69.39%


                                                                                                                                    

Epoch 6 loss: 1.75
Validation accuracy: 71.09%


                                                                                                                                    

Epoch 7 loss: 1.74
Validation accuracy: 72.29%


                                                                                                                                    

Epoch 8 loss: 1.73
Validation accuracy: 72.68%


                                                                                                                                    

Epoch 9 loss: 1.73
Validation accuracy: 72.57%


                                                                                                                                    

Epoch 10 loss: 1.73
Validation accuracy: 73.06%


Testing: 100%|██████████████████████████████████████████████████████████████████████████████████████| 12/12 [01:25<00:00,  7.16s/it]

Test loss: 1.72
Test accuracy: 73.75%
Test accuracy for n_encodelayer 4: 70.18%
Test accuracy for n_encodelayer 8: 73.75%





### Changing n_heads (4, 8 vs 2 base)

In [17]:
chw = (1, 64, 64) # image dimensions
output_dim = 10 # Fashion MNIST has 10 classes
n_patches = 16
hidden_dim = 16 # number of features in each patch's representation
n_encodelayers = 2

# parameter to test
n_head = [4, 8] # no of attention heads

batch_size = 512
learning_rate = 0.01
num_epochs = 10

test_acc_list = []
for n_heads in n_head:
    # instantiate model
    model = VisionTransformer(chw, n_patches, hidden_dim, n_encodelayers, n_heads, output_dim)
    
    # instantiate loss function and optimizer
    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    
    # train model with training data
    for epoch in range(num_epochs):
        train_loop(train_dataloader, model, loss_fn, optimizer, epoch)
        torch.save(model, f'./saved_model__n_head{n_heads}.pt')

    # test model to get test loss and accuracy
    test_accuracy, test_loss = test_loop(test_dataloader, model, loss_fn)
    test_acc_list.append(test_accuracy)
    
print(f'Test accuracy for n_heads 4: {test_acc_list[0]:.2f}%')
print(f'Test accuracy for n_heads 8: {test_acc_list[1]:.2f}%')

  self.pos_embed = nn.Parameter(torch.tensor(get_positional_embedding(self.n_patches ** 2 + 1, self.hidden_dim)))
                                                                                                                                    

Epoch 1 loss: 2.09
Validation accuracy: 36.03%


                                                                                                                                    

Epoch 2 loss: 1.86
Validation accuracy: 59.81%


                                                                                                                                    

Epoch 3 loss: 1.82
Validation accuracy: 64.09%


                                                                                                                                    

Epoch 4 loss: 1.80
Validation accuracy: 65.90%


                                                                                                                                    

Epoch 5 loss: 1.80
Validation accuracy: 66.49%


                                                                                                                                    

Epoch 6 loss: 1.79
Validation accuracy: 66.79%


                                                                                                                                    

Epoch 7 loss: 1.79
Validation accuracy: 67.15%


                                                                                                                                    

Epoch 8 loss: 1.79
Validation accuracy: 67.01%


                                                                                                                                    

Epoch 9 loss: 1.75
Validation accuracy: 70.51%


                                                                                                                                    

Epoch 10 loss: 1.74
Validation accuracy: 72.54%


Testing: 100%|██████████████████████████████████████████████████████████████████████████████████████| 12/12 [01:06<00:00,  5.54s/it]


Test loss: 1.75
Test accuracy: 71.03%


                                                                                                                                    

Epoch 1 loss: 2.13
Validation accuracy: 32.96%


                                                                                                                                    

Epoch 2 loss: 1.92
Validation accuracy: 53.83%


                                                                                                                                    

Epoch 3 loss: 1.87
Validation accuracy: 59.02%


                                                                                                                                    

Epoch 4 loss: 1.83
Validation accuracy: 62.69%


                                                                                                                                    

Epoch 5 loss: 1.82
Validation accuracy: 63.74%


                                                                                                                                    

Epoch 6 loss: 1.82
Validation accuracy: 64.37%


                                                                                                                                    

Epoch 7 loss: 1.81
Validation accuracy: 64.69%


                                                                                                                                    

Epoch 8 loss: 1.81
Validation accuracy: 64.98%


                                                                                                                                    

Epoch 9 loss: 1.81
Validation accuracy: 65.50%


                                                                                                                                    

Epoch 10 loss: 1.81
Validation accuracy: 65.35%


Testing: 100%|██████████████████████████████████████████████████████████████████████████████████████| 12/12 [01:26<00:00,  7.21s/it]

Test loss: 1.82
Test accuracy: 63.57%
Test accuracy for n_heads 4: 71.03%
Test accuracy for n_heads 8: 63.57%





### Final Training (On whole dataset)

In [51]:
# load data
train_dataloader = DataLoader(training_data, shuffle=True, batch_size = batch_size)
test_dataloader = DataLoader(test_data, shuffle=True, batch_size = batch_size)

In [15]:
import time

chw = (1, 64, 64) # image dimensions
output_dim = 10 # Fashion MNIST has 10 classes
n_patches = 16

# optimal parameters
hidden_dim = 8 # number of features in each patch's representation
n_encodelayers = 8
n_heads = 4 # no of attention heads

learning_rate = 0.005
batch_size = 512
num_epochs = 20

# instantiate model
model = VisionTransformer(chw, n_patches, hidden_dim, n_encodelayers, n_heads, output_dim)

# instantiate loss function and optimizer
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# train model with training data
time_taken = []
for epoch in range(num_epochs):
    start_time = time.time()
    train_loop(train_dataloader, model, loss_fn, optimizer, epoch)
    end_time = time.time()
    print(f'Time taken for Epoch {epoch+1}: {end_time - start_time}')
    time_taken.append(end_time - start_time)
    torch.save(model, f'./final_model.pt')

# obtain training time
total_time = sum(time_taken)
print(f'Total training time = {total_time}')

# test model to get test loss and accuracy
test_accuracy, test_loss = test_loop(test_dataloader, model, loss_fn)

  self.pos_embed = nn.Parameter(torch.tensor(get_positional_embedding(self.n_patches ** 2 + 1, self.hidden_dim)))
                                                                                                                                    

Epoch 1 loss: 2.19
Validation accuracy: 26.56%
Time taken for Epoch 1: 2010.0962772369385


                                                                                                                                    

Epoch 2 loss: 1.90
Validation accuracy: 56.31%
Time taken for Epoch 2: 1997.5132689476013


                                                                                                                                    

Epoch 3 loss: 1.80
Validation accuracy: 66.81%
Time taken for Epoch 3: 1809.9591801166534


                                                                                                                                    

Epoch 4 loss: 1.76
Validation accuracy: 70.00%
Time taken for Epoch 4: 1903.035394668579


                                                                                                                                    

Epoch 5 loss: 1.74
Validation accuracy: 72.36%
Time taken for Epoch 5: 2082.5147771835327


                                                                                                                                    

Epoch 6 loss: 1.72
Validation accuracy: 73.83%
Time taken for Epoch 6: 2008.290519952774


                                                                                                                                    

Epoch 7 loss: 1.71
Validation accuracy: 75.32%
Time taken for Epoch 7: 2123.9680531024933


                                                                                                                                    

Epoch 8 loss: 1.69
Validation accuracy: 76.89%
Time taken for Epoch 8: 1902.0833613872528


                                                                                                                                    

Epoch 9 loss: 1.69
Validation accuracy: 77.45%
Time taken for Epoch 9: 1742.3469936847687


                                                                                                                                    

Epoch 10 loss: 1.68
Validation accuracy: 78.67%
Time taken for Epoch 10: 1992.782776594162


                                                                                                                                    

Epoch 11 loss: 1.67
Validation accuracy: 79.24%
Time taken for Epoch 11: 2179.439190387726


                                                                                                                                    

Epoch 12 loss: 1.66
Validation accuracy: 79.66%
Time taken for Epoch 12: 1685.1639959812164


                                                                                                                                    

Epoch 13 loss: 1.66
Validation accuracy: 79.82%
Time taken for Epoch 13: 1672.242731809616


                                                                                                                                    

Epoch 14 loss: 1.66
Validation accuracy: 80.02%
Time taken for Epoch 14: 21104.47133755684


                                                                                                                                    

Epoch 15 loss: 1.66
Validation accuracy: 80.52%
Time taken for Epoch 15: 1638.712197303772


                                                                                                                                    

Epoch 16 loss: 1.65
Validation accuracy: 80.65%
Time taken for Epoch 16: 1650.6607539653778


                                                                                                                                    

Epoch 17 loss: 1.65
Validation accuracy: 80.89%
Time taken for Epoch 17: 1646.6539885997772


                                                                                                                                    

Epoch 18 loss: 1.65
Validation accuracy: 81.05%
Time taken for Epoch 18: 1651.6969330310822


                                                                                                                                    

Epoch 19 loss: 1.65
Validation accuracy: 80.82%
Time taken for Epoch 19: 1651.7701678276062


                                                                                                                                    

Epoch 20 loss: 1.65
Validation accuracy: 81.41%
Time taken for Epoch 20: 1653.6248803138733
Total training time = 56107.02677965164


Testing: 100%|██████████████████████████████████████████████████████████████████████████████████████| 20/20 [03:54<00:00, 11.71s/it]

Test loss: 1.66
Test accuracy: 80.44%





In [4]:
# obtain training time of final model

# remove 14th epoch because anomalous time taken (error due to laptop issues)
time_taken_10epochs = time_taken[:13] + time_taken[14:]

total_time_mins = sum(time_taken_10epochs) / 60
print(f'Total Training Time: {total_time_mins} minutes')

avg_time = total_time_mins/19
print(f'Average Time Per Epoch: {avg_time} minutes')

Total Training Time: 583.3759240349134 minutes
Average Time Per Epoch: 30.70399600183755 minutes
