In [2]:
import numpy as np
from tqdm import tqdm, trange
from pathlib import Path

import torch
import torch.nn as nn
from torch import optim
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader, Subset, default_collate
from torchvision import transforms
from torchvision.transforms import ToTensor, v2
from torchvision import datasets

np.random.seed(0)
torch.manual_seed(0)

<torch._C.Generator at 0x2aab881483b0>

In [3]:
# setting device on GPU if available, else CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print()

if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    print('Allocated:', round(torch.cuda.memory_allocated(0)/1024**3,1), 'GB')
    print('Cached:   ', round(torch.cuda.memory_reserved(0)/1024**3,1), 'GB')

Using device: cuda

Tesla V100-PCIE-32GB
Memory Usage:
Allocated: 0.0 GB
Cached:    0.0 GB


In [4]:
# load fashion MNIST data and transform images
transform = transforms.Compose([transforms.Resize((64, 64)), transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=transform
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=transform
)

In [5]:
labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}

In [6]:
def patchify(images, n_patches):
    n, c, h, w = images.shape # n=num_images, c=image dimension

    # ensure input images are squares
    assert h==w, "image must be square"

    # instantiate patches as a 3D zero tensor (** means power of)
    patches = torch.zeros(n, n_patches**2, h*w*c // n_patches**2) # (num_images, num_patches, patch dimension)
    patch_size = h //n_patches

    for index, image in enumerate(images):
        for i in range(n_patches):
          for j in range(n_patches):
              patch = image[:, i*patch_size: (i+1)*patch_size, j*patch_size: (j+1)*patch_size]
              patches[index, i*n_patches + j] = patch.flatten()
    return patches


In [7]:
def get_positional_embedding(seq_length, dim):
    pe = torch.ones(seq_length, dim)
    for i in range(seq_length):
        for j in range(dim):
            if j % 2 == 0:
                pe[i][j] = np.sin(i/(10000 ** (j/dim)))
            elif j % 2 == 1:
                pe[i][j] = np.cos(i / (10000 ** ((j-1)/dim)))
    return pe

In [8]:
class MultiHeadSA(nn.Module):
    def __init__(self, dim, n_heads=2):
        super(MultiHeadSA, self).__init__()
        self.dim = dim
        self.n_heads = n_heads

        assert dim % n_heads == 0, f"Can't divide dimension {dim} into {n_heads} heads"

        # creating weight matrix of q, k and v
        d_head = int(dim / n_heads)
        self.q_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for x in range(self.n_heads)])
        self.k_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for x in range(self.n_heads)])
        self.v_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for x in range(self.n_heads)])
        self.d_head = d_head
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, sequences):
        # Sequences has shape (N, seq_length, token_dim)
        # We go into shape    (N, seq_length, n_heads, token_dim / n_heads)
        # And come back to    (N, seq_length, item_dim)  (through concatenation)
        result = []
        for sequence in sequences:
            seq_result = []
            for head in range(self.n_heads):
                q_mapping = self.q_mappings[head]
                k_mapping = self.k_mappings[head]
                v_mapping = self.v_mappings[head]

                seq = sequence[:, head * self.d_head: (head + 1) * self.d_head]

                q, k, v = q_mapping(seq), k_mapping(seq), v_mapping(seq)

                attention = self.softmax(q @ k.T / (self.d_head ** 0.5))
                seq_result.append(attention @ v)
            result.append(torch.hstack(seq_result))

        # concat attention from each head together
        return torch.cat([torch.unsqueeze(r, dim=0) for r in result])

In [9]:
class ResidualConnection(nn.Module):
    def __init__(self, hidden_dim, n_heads, mlp_ratio=4):
        super(ResidualConnection, self).__init__()
        self.hidden_dim = hidden_dim
        self.n_heads = n_heads

        self.norm1 = nn.LayerNorm(hidden_dim)

        # self-attention layer
        self.mhsa = MultiHeadSA(hidden_dim, n_heads)

        self.norm2 = nn.LayerNorm(hidden_dim)

        # feed forward layer
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, mlp_ratio * hidden_dim),
            nn.GELU(),
            nn.Linear(mlp_ratio * hidden_dim, hidden_dim)
        )

    def forward(self, x):
        mhsa_out = x + self.mhsa(self.norm1(x))
        ff_out = mhsa_out + self.mlp(self.norm2(mhsa_out))
        return ff_out

In [10]:
# instantiate model
class VisionTransformer(nn.Module):
  def __init__(self, chw, n_patches, hidden_dim, n_encodelayers, n_heads, output_dim):
    super(VisionTransformer, self).__init__()

    self.chw = chw
    self.n_patches = n_patches
    self.hidden_dim = hidden_dim
    self.n_encodelayers = n_encodelayers
    self.n_heads = n_heads
    self.output_dim = output_dim

    # ensure that height and width are divisble by num of patches
    assert chw[1] % n_patches == 0, "Input shape is not divisible by number of patches"
    assert chw[2] % n_patches == 0, "Input shape is not divisible by number of patches"
    self.patch_size = (chw[1]/n_patches, chw[2]/n_patches)

    # linear mapping
    self.input_dim = int(chw[0] * self.patch_size[0] * self.patch_size[1])
    self.linear_map = nn.Linear(self.input_dim, self.hidden_dim)

    # classification head
    self.class_token = nn.Parameter(torch.rand(1, self.hidden_dim))

    # positional encoding
    self.pos_embed = nn.Parameter(torch.tensor(get_positional_embedding(self.n_patches ** 2 + 1, self.hidden_dim)))
    # self.pos_embed.requires_grad = False

    # transformer encoder
    self.encoder_layers = nn.ModuleList([ResidualConnection(hidden_dim, n_heads) for x in range(n_encodelayers)])

    # extract classification token
    self.mlp = nn.Sequential(
        nn.Linear(self.hidden_dim, output_dim),
        nn.Softmax(dim=-1)
    )

  def forward(self, input_image):
    patches = patchify(input_image, self.n_patches)
    tokens = self.linear_map(patches)

    # add classification head to each token (learnable embedding)
    tokens_with_class = torch.stack([torch.vstack((self.class_token, tokens[i])) for i in range(len(tokens))])

    # add positional embedding
    pos_embed = self.pos_embed.repeat(tokens_with_class.shape[0], 1, 1)
    pos_embed_token = tokens_with_class + pos_embed

    # transformer encoding layers
    for layer in self.encoder_layers:
      output = layer(pos_embed_token)

    # extract classification token
    output = output[:, 0]
    pred = self.mlp(output)

    return pred # Map to output dimension, output category distribution

### Data Augmentation 

In [11]:
# define mixup and cutmix functions
def mixup(imgs, labels, alpha):
    lam = np.random.beta(alpha,alpha)
    index = torch.randperm(len(imgs))
    shuffled_imgs = imgs[index]
    shuffled_labels = labels[index]
    new_imgs = lam*imgs + (1-lam)*shuffled_imgs

    return new_imgs, shuffled_labels, lam 

# random bounding box used in cutmix for cut and paste operation
def rand_bbox(size, lam):
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # uniform
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)
    return bbx1, bby1, bbx2, bby2

def cutmix(data, target, alpha):
    indices = torch.randperm(data.size(0))
    shuffled_data = data[indices]
    shuffled_target = target[indices]

    lam = np.random.beta(alpha, alpha)
    bbx1, bby1, bbx2, bby2 = rand_bbox(data.size(), lam)
    new_data = data.clone()
    new_data[:, :, bby1:bby2, bbx1:bbx2] = data[indices, :, bby1:bby2, bbx1:bbx2]
    # adjust lambda to exactly match pixel ratio
    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (data.size()[-1] * data.size()[-2]))

    return new_data, shuffled_target, lam

In [16]:
# modify train loop to include data augmentation
def train_loop(train_loader, model, loss_fn, optimizer, epoch, cutmixalpha, mixupalpha):
    train_loss = 0.0
    correct, total = 0, 0
    for batch in tqdm(train_loader, desc=f"Epoch {epoch + 1} in training", leave=False):
        x, y = batch
        x, y = x.to(device), y.to(device)
        
        # execute data augmentation randomly
        if np.random.rand() > 0.5:  # Adjust the probability as needed
            augmented_x, y, lam = cutmix(x, y, cutmixalpha)
        else:
            augmented_x, y, lam = mixup(x, y, mixupalpha)
        augmented_x, y = augmented_x.to(device), y.to(device)
        
        output = model(augmented_x).to(device)
        loss = loss_fn(output, y)

        train_loss += (loss.item() / len(train_loader))
        correct += torch.sum(torch.argmax(output, dim=1) == y).item()
        total += len(augmented_x)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch + 1} loss: {train_loss:.2f}")
    print(f"Validation accuracy: {correct / total * 100:.2f}%")
    
def test_loop(test_loader, model, loss_fn):
    with torch.no_grad():
        correct, total = 0, 0
        test_loss = 0.0
        for batch in tqdm(test_loader, desc="Testing"):
            x, y = batch
            x, y = x.to(device), y.to(device)
            output = model(x).to(device)
            loss = loss_fn(output, y)
            test_loss += (loss.item() / len(test_loader))

            correct += torch.sum(torch.argmax(output, dim=1) == y).item()
            total += len(x)
        test_acc = correct / total * 100
        print(f"Test loss: {test_loss:.2f}")
        print(f"Test accuracy: {test_acc:.2f}%")
    return test_acc, test_loss

In [13]:
# load data
batch_size = 512
train_dataloader = DataLoader(training_data, shuffle=True, batch_size = batch_size)
test_dataloader = DataLoader(test_data, shuffle=True, batch_size = batch_size)

In [15]:
import time

chw = (1, 64, 64) # image dimensions
output_dim = 10 # Fashion MNIST has 10 classes
n_patches = 16

# define augmentation alphas
cutmixalpha = 1.0
mixupalpha = 0.2

# optimal parameters
hidden_dim = 8 # number of features in each patch's representation
n_encodelayers = 8
n_heads = 4 # no of attention heads

learning_rate = 0.005
batch_size = 512
num_epochs = 10

# instantiate model
model = VisionTransformer(chw, n_patches, hidden_dim, n_encodelayers, n_heads, output_dim)

# instantiate loss function and optimizer
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# train model with training data
time_taken = []
for epoch in range(num_epochs):
    start_time = time.time()
    train_loop(train_dataloader, model, loss_fn, optimizer, epoch, cutmixalpha, mixupalpha)
    end_time = time.time()
    print(f'Time taken for Epoch {epoch+1}: {end_time - start_time}')
    time_taken.append(end_time - start_time)
    torch.save(model, f'./augmentation_model.pt')

# obtain training time
total_time = sum(time_taken)
print(f'Total training time = {total_time}')

# test model to get test loss and accuracy
test_accuracy, test_loss = test_loop(test_dataloader, model, loss_fn)

  self.pos_embed = nn.Parameter(torch.tensor(get_positional_embedding(self.n_patches ** 2 + 1, self.hidden_dim)))
                                                                                                                                    

Epoch 1 loss: 2.28
Validation accuracy: 13.86%
Time taken for Epoch 1: 1796.95458984375


                                                                                                                                    

Epoch 2 loss: 2.25
Validation accuracy: 18.14%
Time taken for Epoch 2: 1886.6079955101013


                                                                                                                                    

Epoch 3 loss: 2.21
Validation accuracy: 22.87%
Time taken for Epoch 3: 1841.6621141433716


                                                                                                                                    

Epoch 4 loss: 2.20
Validation accuracy: 24.82%
Time taken for Epoch 4: 1706.9671611785889


                                                                                                                                    

Epoch 5 loss: 2.18
Validation accuracy: 27.69%
Time taken for Epoch 5: 1873.5339818000793


                                                                                                                                    

Epoch 6 loss: 2.17
Validation accuracy: 27.77%
Time taken for Epoch 6: 2127.2694861888885


                                                                                                                                    

Epoch 7 loss: 2.14
Validation accuracy: 31.55%
Time taken for Epoch 7: 2075.6422147750854


                                                                                                                                    

Epoch 8 loss: 2.14
Validation accuracy: 31.35%
Time taken for Epoch 8: 1759.2664937973022


                                                                                                                                    

Epoch 9 loss: 2.17
Validation accuracy: 28.36%
Time taken for Epoch 9: 1914.825074672699


                                                                                                                                    

Epoch 10 loss: 2.15
Validation accuracy: 30.52%
Time taken for Epoch 10: 1923.1636435985565
Total training time = 18905.892755508423


Testing: 100%|██████████████████████████████████████████████████████████████████████████████████████| 20/20 [04:33<00:00, 13.67s/it]

Test loss: 1.81
Test accuracy: 65.77%





In [17]:
# obtain training time

total_time_mins = sum(time_taken) / 60
print(f'Total Training Time: {total_time_mins} minutes')

avg_time = total_time_mins/10
print(f'Average Time Per Epoch: {avg_time} minutes')

Total Training Time: 315.09821259180706 minutes
Average Time Per Epoch: 31.509821259180704 minutes


In [16]:
# run epochs 11 to 20

import time

chw = (1, 64, 64) # image dimensions
output_dim = 10 # Fashion MNIST has 10 classes
n_patches = 16

# define augmentation alphas
cutmixalpha = 1.0
mixupalpha = 0.2

# optimal parameters
hidden_dim = 8 # number of features in each patch's representation
n_encodelayers = 8
n_heads = 4 # no of attention heads

learning_rate = 0.005
batch_size = 512
num_epochs = 10

# instantiate model continued from previous training
if Path('./augmentation_model.pt').exists():
    model = torch.load('./augmentation_model.pt')
else:
    model = VisionTransformer(chw, n_patches, hidden_dim, n_encodelayers, n_heads, output_dim)

# instantiate loss function and optimizer
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# continue training model with training data for another 10 epochs (11 to 20 epochs)
time_taken = []
for epoch in range(10, 10+num_epochs):
    start_time = time.time()
    train_loop(train_dataloader, model, loss_fn, optimizer, epoch, cutmixalpha, mixupalpha)
    end_time = time.time()
    print(f'Time taken for Epoch {epoch+1}: {end_time - start_time}')
    time_taken.append(end_time - start_time)
    torch.save(model, f'./augmentation_model.pt')

# obtain training time
total_time = sum(time_taken)
print(f'Total training time = {total_time}')

# test model to get test loss and accuracy
test_accuracy, test_loss = test_loop(test_dataloader, model, loss_fn)

                                                                                                                                    

Epoch 11 loss: 2.16
Validation accuracy: 29.36%
Time taken for Epoch 11: 2102.367213487625


                                                                                                                                    

Epoch 12 loss: 2.14
Validation accuracy: 30.91%
Time taken for Epoch 12: 2001.2474513053894


                                                                                                                                    

Epoch 13 loss: 2.15
Validation accuracy: 30.72%
Time taken for Epoch 13: 1790.8924679756165


                                                                                                                                    

Epoch 14 loss: 2.16
Validation accuracy: 28.99%
Time taken for Epoch 14: 1764.5591716766357


                                                                                                                                    

Epoch 15 loss: 2.13
Validation accuracy: 32.33%
Time taken for Epoch 15: 1766.3577327728271


                                                                                                                                    

Epoch 16 loss: 2.14
Validation accuracy: 31.17%
Time taken for Epoch 16: 1776.1402904987335


                                                                                                                                    

Epoch 17 loss: 2.11
Validation accuracy: 34.93%
Time taken for Epoch 17: 1766.2503604888916


                                                                                                                                    

Epoch 18 loss: 2.11
Validation accuracy: 34.61%
Time taken for Epoch 18: 1771.2024459838867


                                                                                                                                    

Epoch 19 loss: 2.16
Validation accuracy: 29.97%
Time taken for Epoch 19: 1779.6425502300262


                                                                                                                                    

Epoch 20 loss: 2.13
Validation accuracy: 32.35%
Time taken for Epoch 20: 1758.8219933509827
Total training time = 18277.481677770615


Testing: 100%|██████████████████████████████████████████████████████████████████████████████████████| 20/20 [04:05<00:00, 12.30s/it]

Test loss: 1.76
Test accuracy: 70.82%





In [14]:
# run epochs 21 to 25 to see if accuracy continues to increase

import time

chw = (1, 64, 64) # image dimensions
output_dim = 10 # Fashion MNIST has 10 classes
n_patches = 16

# define augmentation alphas
cutmixalpha = 1.0
mixupalpha = 0.2

# optimal parameters
hidden_dim = 8 # number of features in each patch's representation
n_encodelayers = 8
n_heads = 4 # no of attention heads

learning_rate = 0.005
batch_size = 512
num_epochs = 5

# instantiate model continued from previous training
if Path('./augmentation_model.pt').exists():
    model = torch.load('./augmentation_model.pt')
else:
    model = VisionTransformer(chw, n_patches, hidden_dim, n_encodelayers, n_heads, output_dim)

# instantiate loss function and optimizer
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# continue training model with training data for another 10 epochs (11 to 20 epochs)
time_taken = []
for epoch in range(20, 20+num_epochs):
    start_time = time.time()
    train_loop(train_dataloader, model, loss_fn, optimizer, epoch, cutmixalpha, mixupalpha)
    end_time = time.time()
    print(f'Time taken for Epoch {epoch+1}: {end_time - start_time}')
    time_taken.append(end_time - start_time)
    torch.save(model, f'./augmentation_model.pt')

# obtain training time
total_time = sum(time_taken)
print(f'Total training time = {total_time}')

# test model to get test loss and accuracy
test_accuracy, test_loss = test_loop(test_dataloader, model, loss_fn)

                                                                       

Epoch 21 loss: 2.14
Validation accuracy: 31.62%
Time taken for Epoch 21: 1519.1007499694824


                                                                       

Epoch 22 loss: 2.13
Validation accuracy: 32.81%
Time taken for Epoch 22: 1474.4625840187073


                                                                       

Epoch 23 loss: 2.13
Validation accuracy: 32.45%
Time taken for Epoch 23: 1455.6845943927765


                                                                       

Epoch 24 loss: 2.15
Validation accuracy: 30.91%
Time taken for Epoch 24: 1443.7331368923187


                                                                       

Epoch 25 loss: 2.12
Validation accuracy: 33.97%
Time taken for Epoch 25: 1467.1929368972778
Total training time = 7360.174002170563


Testing:   0%|          | 0/20 [00:09<?, ?it/s]


RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument target in method wrapper_CUDA_nll_loss_forward)

In [17]:
# test model to get test loss and accuracy
test_accuracy, test_loss = test_loop(test_dataloader, model, loss_fn)

Testing: 100%|██████████| 20/20 [02:57<00:00,  8.90s/it]

Test loss: 1.76
Test accuracy: 71.25%



