In [1]:
import numpy as np
from tqdm import tqdm, trange

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader

from torchvision.transforms import ToTensor, Compose, Normalize, RandomHorizontalFlip, RandomCrop
from torchvision.datasets import CIFAR10

np.random.seed(0)
torch.manual_seed(0)

<torch._C.Generator at 0x7bc775bbc950>

In [2]:
# Step 1 : Patchifying and the mienar mapping to form feature vectors of each sub image

# Method 1 to patchify : Manual

def patchify(images, n_patches):
  n, c, h, w = images.shape

  assert h == w,    " Patchify method is implemented only for square images "

  patches = torch.zeros(n, n_patches ** 2, h * w * c // n_patches ** 2)
  patch_size = h // n_patches

  for idx, image in enumerate(images):
    for i in range(n_patches) :
      for j in range(n_patches) :
        patch = image[:, i * patch_size:(i + 1) * patch_size, j * patch_size : (j + 1) * patch_size ]
        patches[idx, i*n_patches + j] = patch.flatten()
  return patches



In [3]:
class MyMSA(nn.Module):
  def __init__(self, d, n_heads = 4):
    super(MyMSA, self).__init__()
    self.d = d
    self.n_heads = n_heads

    assert d % n_heads == 0, f"Can't divide dimension {d} into {n_heads} heads "

    d_head = int(d/ n_heads)
    self.q_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
    self.k_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
    self.v_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
    self.d_head = d_head
    self.softmax = nn.Softmax(dim=-1)

  def forward(self, sequences):
      # Sequences has shape (N, seq_length, token_dim)
      # We go into shape    (N, seq_length, n_heads, token_dim / n_heads)
      # And come back to    (N, seq_length, item_dim)  (through concatenation)
      result = []
      for sequence in sequences:
          seq_result = []
          for head in range(self.n_heads):
              q_mapping = self.q_mappings[head]
              k_mapping = self.k_mappings[head]
              v_mapping = self.v_mappings[head]

              seq = sequence[:, head * self.d_head: (head + 1) * self.d_head]
              q, k, v = q_mapping(seq), k_mapping(seq), v_mapping(seq)

              attention = self.softmax(q @ k.T / (self.d_head ** 0.5))
              seq_result.append(attention @ v)
          result.append(torch.hstack(seq_result))
      return torch.cat([torch.unsqueeze(r, dim=0) for r in result])


In [4]:
class MyViTBlock(nn.Module):
  def __init__(self, hidden_d, n_heads, mlp_ratio = 4):
    super(MyViTBlock, self).__init__()
    self.hidden_d = hidden_d
    self.n_heads = n_heads

    self.norm1 = nn.LayerNorm(hidden_d)    # layer normlisations , along last dimension
    self.mhsa = MyMSA(hidden_d, n_heads)
    self.norm2 = nn.LayerNorm(hidden_d)
    self.mlp = nn.Sequential(
        nn.Linear(hidden_d, mlp_ratio*hidden_d),
        nn.GELU(),
        nn.Linear(mlp_ratio*hidden_d, hidden_d)
    )


  def forward(self, x):
    out = x + self.mhsa(self.norm1(x))   # residual connections
    out = out + self.mlp(self.norm2(out))
    return out

In [9]:
#model ViT tht will classify images with shape (N*1*28*28)

class MyViT(nn.Module):
  def __init__(self, chw = (3, 32, 32), n_patches = 4, n_blocks =4, hidden_d = 128, n_heads = 4, out_d =100):
    # super constructor
    super(MyViT, self).__init__()

    # Attributes
    self.chw = chw   #(C, H, W)
    self.n_patches = n_patches
    self.n_blocks = n_blocks
    self.n_heads = n_heads
    self.hidden_d = hidden_d

    assert chw[1] % n_patches == 0, "Input shape not entirely divisible by number of patches"
    assert chw[2] % n_patches == 0, "Input shape not entirely divisible by number of patches"
    self.patch_size = (chw[1] / n_patches, chw[2] / n_patches)

    #1. Linear mapper
    self.input_d = int(chw[0] * self.patch_size[0] * self.patch_size[1])
    self.linear_mapper = nn.Linear(self.input_d, self.hidden_d)

    #2. Learnable classification token
    self.class_token = nn.Parameter(torch.randn(1, self.hidden_d))

    #3 positional embedding
    self.register_buffer('positional_embeddings', get_positional_embeddings(n_patches ** 2 + 1, hidden_d), persistent=False)

    # 4) Transformer encoder blocks
    self.blocks = nn.ModuleList([MyViTBlock(hidden_d, n_heads) for _ in range(n_blocks)])

    # 5) Classification MLPk
    self.mlp = nn.Sequential(
        nn.Linear(self.hidden_d, out_d),
        nn.Softmax(dim=-1)   # along last dimension to get probailities
    )
  def forward(self, images):
    #dividing images into patches
    n, c, h, w = images.shape
    patches = patchify(images, self.n_patches).to(self.positional_embeddings.device)

    #running linear layer tokensization
    #map the vector corresponding to each patch to the hidden size dimension
    tokens = self.linear_mapper(patches)

    #adding classification to the tokens
    tokens = torch.cat((self.class_token.expand(n, -1, -1),tokens), dim = 1)

    #add positional embedding
    out = tokens + self.positional_embeddings.repeat(n, 1, 1)

    #transformer blocks
    for block in self.blocks:
      out = block(out)

    # getting the classification token only
    out = out[:,0]

    return self.mlp(out) # map to output dimension, output category distribution

In [6]:
# got to MyViT
def get_positional_embeddings(sequence_length, d):
    result = torch.ones(sequence_length, d)
    for i in range(sequence_length):
        for j in range(d):
            result[i][j] = np.sin(i / (10000 ** (j / d))) if j % 2 == 0 else np.cos(i / (10000 ** ((j - 1) / d)))
    return result

In [10]:
def main() :
  #define transformations for the CIFAR - 100 dataset
  transform = Compose([
      RandomCrop(32, padding = 4),
      RandomHorizontalFlip(),
      ToTensor(),
      Normalize((0.5071, 0.04865, 0.4409), (0.2673, 0.2564, 0.2762)),
  ])

  # Load CIFAR 100 dataset
  train_set = CIFAR10(root ='./../datasets', train = True, download = True, transform = transform)
  test_set = CIFAR10(root = './../datasets', train = False, download = True, transform = transform)

  # Create data loaders
  train_loader = DataLoader(train_set, shuffle = True, batch_size = 128, num_workers = 2)
  test_loader = DataLoader(test_set, shuffle = False, batch_size = 128, num_workers = 2)

  #defining model and training options
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  print("Using Device : ", device, f"({torch.cuda.get_device_name(device)})" if torch.cuda.is_available() else "")
  model = MyViT((3, 32, 32), n_patches = 4, n_blocks = 4, hidden_d = 128, n_heads = 4, out_d = 100).to(device)
  N_EPOCHS = 15
  LR = 0.001

  # training loop
  optimizer = Adam(model.parameters(), lr = LR)
  criterion = CrossEntropyLoss()
  for epoch in trange(N_EPOCHS, desc = 'Training'):
    train_loss = 0.0
    for batch in tqdm(train_loader, desc = f"Epoch {epoch + 1} in training", leave = False):
      x, y = batch
      x, y = x.to(device), y.to(device)
      y_hat = model(x)
      loss = criterion(y_hat, y)

      train_loss += loss.detach().cpu().item() / len(train_loader)

      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

    print(f"Epoch {epoch + 1} / {N_EPOCHS} loss : {train_loss: .2f}")

  #TEST loop
  with torch.no_grad():
    correct, total = 0, 0
    test_loss = 0.0
    for batch in tqdm(test_loader, desc = 'Testing'):
      x, y = batch
      x, y = x.to(device), y.to(device)
      y_hat = model(x)
      loss = criterion(y_hat, y)
      test_loss += loss.detach().cpu().item() / len(test_loader)

      correct += torch.sum(torch.argmax(y_hat, dim = 1) == y).detach().cpu().item()
      total +=len(x)
    print(f"Test Loss: {test_loss:.2f}")
    print(f"Test Accuracy: {correct / total * 100:.2f}%")

In [8]:
main()

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./../datasets/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:03<00:00, 48921389.64it/s]


Extracting ./../datasets/cifar-10-python.tar.gz to ./../datasets
Files already downloaded and verified
Using Device :  cpu 


Training:   0%|          | 0/15 [00:00<?, ?it/s]
Epoch 1 in training:   0%|          | 0/391 [00:00<?, ?it/s][A
Epoch 1 in training:   0%|          | 1/391 [00:02<15:33,  2.39s/it][A
Epoch 1 in training:   1%|          | 2/391 [00:04<13:36,  2.10s/it][A
Epoch 1 in training:   1%|          | 3/391 [00:06<13:30,  2.09s/it][A
Epoch 1 in training:   1%|          | 4/391 [00:08<13:56,  2.16s/it][A
Epoch 1 in training:   1%|▏         | 5/391 [00:11<14:29,  2.25s/it][A
Epoch 1 in training:   2%|▏         | 6/391 [00:12<13:22,  2.09s/it][A
Epoch 1 in training:   2%|▏         | 7/391 [00:14<12:38,  1.98s/it][A
Epoch 1 in training:   2%|▏         | 8/391 [00:16<12:15,  1.92s/it][A
Epoch 1 in training:   2%|▏         | 9/391 [00:18<11:51,  1.86s/it][A
Epoch 1 in training:   3%|▎         | 10/391 [00:19<11:40,  1.84s/it][A
Epoch 1 in training:   3%|▎         | 11/391 [00:22<13:11,  2.08s/it][A
Epoch 1 in training:   3%|▎         | 12/391 [00:24<13:37,  2.16s/it][A
Epoch 1 in training:

Epoch 1 / 15 loss :  4.52



Epoch 2 in training:   0%|          | 0/391 [00:00<?, ?it/s][A
Epoch 2 in training:   0%|          | 1/391 [00:03<23:34,  3.63s/it][A
Epoch 2 in training:   1%|          | 2/391 [00:06<19:34,  3.02s/it][A
Epoch 2 in training:   1%|          | 3/391 [00:09<19:00,  2.94s/it][A
Epoch 2 in training:   1%|          | 4/391 [00:10<16:18,  2.53s/it][A
Epoch 2 in training:   1%|▏         | 5/391 [00:12<14:48,  2.30s/it][A
Epoch 2 in training:   2%|▏         | 6/391 [00:14<13:49,  2.15s/it][A
Epoch 2 in training:   2%|▏         | 7/391 [00:16<13:19,  2.08s/it][A
Epoch 2 in training:   2%|▏         | 8/391 [00:18<13:02,  2.04s/it][A
Epoch 2 in training:   2%|▏         | 9/391 [00:21<15:07,  2.37s/it][A
Epoch 2 in training:   3%|▎         | 10/391 [00:23<14:13,  2.24s/it][A
Epoch 2 in training:   3%|▎         | 11/391 [00:25<13:36,  2.15s/it][A
Epoch 2 in training:   3%|▎         | 12/391 [00:27<13:09,  2.08s/it][A
Epoch 2 in training:   3%|▎         | 13/391 [00:29<12:58,  2.06s/it

Epoch 2 / 15 loss :  4.52



Epoch 3 in training:   0%|          | 0/391 [00:00<?, ?it/s][A
Epoch 3 in training:   0%|          | 1/391 [00:03<21:10,  3.26s/it][A
Epoch 3 in training:   1%|          | 2/391 [00:06<20:51,  3.22s/it][A
Epoch 3 in training:   1%|          | 3/391 [00:08<17:58,  2.78s/it][A
Epoch 3 in training:   1%|          | 4/391 [00:10<16:07,  2.50s/it][A
Epoch 3 in training:   1%|▏         | 5/391 [00:12<15:06,  2.35s/it][A
Epoch 3 in training:   2%|▏         | 6/391 [00:14<14:18,  2.23s/it][A
Epoch 3 in training:   2%|▏         | 7/391 [00:16<13:57,  2.18s/it][A
Epoch 3 in training:   2%|▏         | 8/391 [00:20<16:24,  2.57s/it][A
Epoch 3 in training:   2%|▏         | 9/391 [00:22<15:28,  2.43s/it][A
Epoch 3 in training:   3%|▎         | 10/391 [00:24<15:16,  2.41s/it][A
Epoch 3 in training:   3%|▎         | 11/391 [00:27<15:24,  2.43s/it][A
Epoch 3 in training:   3%|▎         | 12/391 [00:29<15:48,  2.50s/it][A
Epoch 3 in training:   3%|▎         | 13/391 [00:34<19:44,  3.13s/it

Epoch 3 / 15 loss :  4.50



Epoch 4 in training:   0%|          | 0/391 [00:00<?, ?it/s][A
Epoch 4 in training:   0%|          | 1/391 [00:03<20:03,  3.08s/it][A
Epoch 4 in training:   1%|          | 2/391 [00:05<17:28,  2.70s/it][A
Epoch 4 in training:   1%|          | 3/391 [00:07<16:29,  2.55s/it][A
Epoch 4 in training:   1%|          | 4/391 [00:10<16:25,  2.55s/it][A
Epoch 4 in training:   1%|▏         | 5/391 [00:13<16:30,  2.57s/it][A
Epoch 4 in training:   2%|▏         | 6/391 [00:15<15:32,  2.42s/it][A
Epoch 4 in training:   2%|▏         | 7/391 [00:17<15:36,  2.44s/it][A
Epoch 4 in training:   2%|▏         | 8/391 [00:19<15:06,  2.37s/it][A
Epoch 4 in training:   2%|▏         | 9/391 [00:21<14:31,  2.28s/it][A
Epoch 4 in training:   3%|▎         | 10/391 [00:24<15:51,  2.50s/it][A
Epoch 4 in training:   3%|▎         | 11/391 [00:27<15:03,  2.38s/it][A
Epoch 4 in training:   3%|▎         | 12/391 [00:29<14:35,  2.31s/it][A
Epoch 4 in training:   3%|▎         | 13/391 [00:31<14:30,  2.30s/it

Epoch 4 / 15 loss :  4.48



Epoch 5 in training:   0%|          | 0/391 [00:00<?, ?it/s][A
Epoch 5 in training:   0%|          | 1/391 [00:03<21:24,  3.29s/it][A
Epoch 5 in training:   1%|          | 2/391 [00:05<18:33,  2.86s/it][A
Epoch 5 in training:   1%|          | 3/391 [00:07<15:55,  2.46s/it][A
Epoch 5 in training:   1%|          | 4/391 [00:09<14:35,  2.26s/it][A
Epoch 5 in training:   1%|▏         | 5/391 [00:11<13:49,  2.15s/it][A
Epoch 5 in training:   2%|▏         | 6/391 [00:13<13:57,  2.18s/it][A
Epoch 5 in training:   2%|▏         | 7/391 [00:16<14:30,  2.27s/it][A
Epoch 5 in training:   2%|▏         | 8/391 [00:18<14:48,  2.32s/it][A
Epoch 5 in training:   2%|▏         | 9/391 [00:20<14:13,  2.23s/it][A
Epoch 5 in training:   3%|▎         | 10/391 [00:22<13:37,  2.15s/it][A
Epoch 5 in training:   3%|▎         | 11/391 [00:24<13:09,  2.08s/it][A
Epoch 5 in training:   3%|▎         | 12/391 [00:26<12:52,  2.04s/it][A
Epoch 5 in training:   3%|▎         | 13/391 [00:29<13:21,  2.12s/it

Epoch 5 / 15 loss :  4.49



Epoch 6 in training:   0%|          | 0/391 [00:00<?, ?it/s][A
Epoch 6 in training:   0%|          | 1/391 [00:02<18:17,  2.81s/it][A
Epoch 6 in training:   1%|          | 2/391 [00:04<15:34,  2.40s/it][A
Epoch 6 in training:   1%|          | 3/391 [00:06<14:18,  2.21s/it][A
Epoch 6 in training:   1%|          | 4/391 [00:08<13:33,  2.10s/it][A
Epoch 6 in training:   1%|▏         | 5/391 [00:11<15:48,  2.46s/it][A
Epoch 6 in training:   2%|▏         | 6/391 [00:13<14:35,  2.27s/it][A
Epoch 6 in training:   2%|▏         | 7/391 [00:15<13:45,  2.15s/it][A
Epoch 6 in training:   2%|▏         | 8/391 [00:17<13:17,  2.08s/it][A
Epoch 6 in training:   2%|▏         | 9/391 [00:19<12:54,  2.03s/it][A
Epoch 6 in training:   3%|▎         | 10/391 [00:21<12:39,  1.99s/it][A
Epoch 6 in training:   3%|▎         | 11/391 [00:24<14:00,  2.21s/it][A
Epoch 6 in training:   3%|▎         | 12/391 [00:26<13:50,  2.19s/it][A
Epoch 6 in training:   3%|▎         | 13/391 [00:28<13:18,  2.11s/it

Epoch 6 / 15 loss :  4.49



Epoch 7 in training:   0%|          | 0/391 [00:00<?, ?it/s][A
Epoch 7 in training:   0%|          | 1/391 [00:02<19:08,  2.94s/it][A
Epoch 7 in training:   1%|          | 2/391 [00:05<16:29,  2.54s/it][A
Epoch 7 in training:   1%|          | 3/391 [00:07<14:52,  2.30s/it][A
Epoch 7 in training:   1%|          | 4/391 [00:09<15:57,  2.47s/it][A
Epoch 7 in training:   1%|▏         | 5/391 [00:12<15:42,  2.44s/it][A
Epoch 7 in training:   2%|▏         | 6/391 [00:14<14:30,  2.26s/it][A
Epoch 7 in training:   2%|▏         | 7/391 [00:16<13:55,  2.18s/it][A
Epoch 7 in training:   2%|▏         | 8/391 [00:18<13:31,  2.12s/it][A
Epoch 7 in training:   2%|▏         | 9/391 [00:20<13:45,  2.16s/it][A
Epoch 7 in training:   3%|▎         | 10/391 [00:22<14:19,  2.26s/it][A
Epoch 7 in training:   3%|▎         | 11/391 [00:25<14:34,  2.30s/it][A
Epoch 7 in training:   3%|▎         | 12/391 [00:27<13:46,  2.18s/it][A
Epoch 7 in training:   3%|▎         | 13/391 [00:29<13:41,  2.17s/it

Epoch 7 / 15 loss :  4.49



Epoch 8 in training:   0%|          | 0/391 [00:00<?, ?it/s][A
Epoch 8 in training:   0%|          | 1/391 [00:02<18:56,  2.91s/it][A
Epoch 8 in training:   1%|          | 2/391 [00:05<17:18,  2.67s/it][A
Epoch 8 in training:   1%|          | 3/391 [00:08<17:49,  2.76s/it][A
Epoch 8 in training:   1%|          | 4/391 [00:10<15:46,  2.45s/it][A
Epoch 8 in training:   1%|▏         | 5/391 [00:12<14:41,  2.28s/it][A
Epoch 8 in training:   2%|▏         | 6/391 [00:14<13:53,  2.17s/it][A
Epoch 8 in training:   2%|▏         | 7/391 [00:16<13:25,  2.10s/it][A
Epoch 8 in training:   2%|▏         | 8/391 [00:18<14:04,  2.21s/it][A
Epoch 8 in training:   2%|▏         | 9/391 [00:21<14:49,  2.33s/it][A
Epoch 8 in training:   3%|▎         | 10/391 [00:23<14:09,  2.23s/it][A
Epoch 8 in training:   3%|▎         | 11/391 [00:25<13:33,  2.14s/it][A
Epoch 8 in training:   3%|▎         | 12/391 [00:27<13:35,  2.15s/it][A
Epoch 8 in training:   3%|▎         | 13/391 [00:29<13:14,  2.10s/it

Epoch 8 / 15 loss :  4.48



Epoch 9 in training:   0%|          | 0/391 [00:00<?, ?it/s][A
Epoch 9 in training:   0%|          | 1/391 [00:02<18:37,  2.86s/it][A
Epoch 9 in training:   1%|          | 2/391 [00:05<16:40,  2.57s/it][A
Epoch 9 in training:   1%|          | 3/391 [00:07<14:45,  2.28s/it][A
Epoch 9 in training:   1%|          | 4/391 [00:09<13:40,  2.12s/it][A
Epoch 9 in training:   1%|▏         | 5/391 [00:11<14:54,  2.32s/it][A
Epoch 9 in training:   2%|▏         | 6/391 [00:13<14:38,  2.28s/it][A
Epoch 9 in training:   2%|▏         | 7/391 [00:15<13:52,  2.17s/it][A
Epoch 9 in training:   2%|▏         | 8/391 [00:17<13:27,  2.11s/it][A
Epoch 9 in training:   2%|▏         | 9/391 [00:19<13:05,  2.06s/it][A
Epoch 9 in training:   3%|▎         | 10/391 [00:21<12:42,  2.00s/it][A
Epoch 9 in training:   3%|▎         | 11/391 [00:24<13:58,  2.21s/it][A
Epoch 9 in training:   3%|▎         | 12/391 [00:26<14:14,  2.25s/it][A
Epoch 9 in training:   3%|▎         | 13/391 [00:28<13:39,  2.17s/it

Epoch 9 / 15 loss :  4.49



Epoch 10 in training:   0%|          | 0/391 [00:00<?, ?it/s][A
Epoch 10 in training:   0%|          | 1/391 [00:03<21:49,  3.36s/it][A
Epoch 10 in training:   1%|          | 2/391 [00:06<19:24,  2.99s/it][A
Epoch 10 in training:   1%|          | 3/391 [00:08<16:26,  2.54s/it][A
Epoch 10 in training:   1%|          | 4/391 [00:10<14:50,  2.30s/it][A
Epoch 10 in training:   1%|▏         | 5/391 [00:11<13:52,  2.16s/it][A
Epoch 10 in training:   2%|▏         | 6/391 [00:13<13:21,  2.08s/it][A
Epoch 10 in training:   2%|▏         | 7/391 [00:16<14:06,  2.20s/it][A
Epoch 10 in training:   2%|▏         | 8/391 [00:18<14:21,  2.25s/it][A
Epoch 10 in training:   2%|▏         | 9/391 [00:20<13:38,  2.14s/it][A
Epoch 10 in training:   3%|▎         | 10/391 [00:22<13:37,  2.15s/it][A
Epoch 10 in training:   3%|▎         | 11/391 [00:24<13:08,  2.07s/it][A
Epoch 10 in training:   3%|▎         | 12/391 [00:26<12:46,  2.02s/it][A
Epoch 10 in training:   3%|▎         | 13/391 [00:28<13

Epoch 10 / 15 loss :  4.49



Epoch 11 in training:   0%|          | 0/391 [00:00<?, ?it/s][A
Epoch 11 in training:   0%|          | 1/391 [00:03<23:51,  3.67s/it][A
Epoch 11 in training:   1%|          | 2/391 [00:05<18:13,  2.81s/it][A
Epoch 11 in training:   1%|          | 3/391 [00:08<16:11,  2.50s/it][A
Epoch 11 in training:   1%|          | 4/391 [00:10<15:03,  2.33s/it][A
Epoch 11 in training:   1%|▏         | 5/391 [00:12<14:27,  2.25s/it][A
Epoch 11 in training:   2%|▏         | 6/391 [00:14<14:41,  2.29s/it][A
Epoch 11 in training:   2%|▏         | 7/391 [00:17<15:01,  2.35s/it][A
Epoch 11 in training:   2%|▏         | 8/391 [00:18<14:08,  2.22s/it][A
Epoch 11 in training:   2%|▏         | 9/391 [00:21<14:08,  2.22s/it][A
Epoch 11 in training:   3%|▎         | 10/391 [00:23<13:33,  2.14s/it][A
Epoch 11 in training:   3%|▎         | 11/391 [00:25<12:59,  2.05s/it][A
Epoch 11 in training:   3%|▎         | 12/391 [00:27<13:32,  2.14s/it][A
Epoch 11 in training:   3%|▎         | 13/391 [00:29<14

Epoch 11 / 15 loss :  4.49



Epoch 12 in training:   0%|          | 0/391 [00:00<?, ?it/s][A
Epoch 12 in training:   0%|          | 1/391 [00:02<18:14,  2.81s/it][A
Epoch 12 in training:   1%|          | 2/391 [00:04<15:48,  2.44s/it][A
Epoch 12 in training:   1%|          | 3/391 [00:07<17:04,  2.64s/it][A
Epoch 12 in training:   1%|          | 4/391 [00:09<15:13,  2.36s/it][A
Epoch 12 in training:   1%|▏         | 5/391 [00:11<14:17,  2.22s/it][A
Epoch 12 in training:   2%|▏         | 6/391 [00:13<13:30,  2.10s/it][A
Epoch 12 in training:   2%|▏         | 7/391 [00:15<13:01,  2.04s/it][A
Epoch 12 in training:   2%|▏         | 8/391 [00:17<13:37,  2.13s/it][A
Epoch 12 in training:   2%|▏         | 9/391 [00:20<14:49,  2.33s/it][A
Epoch 12 in training:   3%|▎         | 10/391 [00:22<14:12,  2.24s/it][A
Epoch 12 in training:   3%|▎         | 11/391 [00:24<13:46,  2.17s/it][A
Epoch 12 in training:   3%|▎         | 12/391 [00:26<13:17,  2.10s/it][A
Epoch 12 in training:   3%|▎         | 13/391 [00:28<12

Epoch 12 / 15 loss :  4.50



Epoch 13 in training:   0%|          | 0/391 [00:00<?, ?it/s][A
Epoch 13 in training:   0%|          | 1/391 [00:03<23:58,  3.69s/it][A
Epoch 13 in training:   1%|          | 2/391 [00:05<18:14,  2.81s/it][A
Epoch 13 in training:   1%|          | 3/391 [00:07<15:47,  2.44s/it][A
Epoch 13 in training:   1%|          | 4/391 [00:09<14:19,  2.22s/it][A
Epoch 13 in training:   1%|▏         | 5/391 [00:11<13:34,  2.11s/it][A
Epoch 13 in training:   2%|▏         | 6/391 [00:13<13:02,  2.03s/it][A
Epoch 13 in training:   2%|▏         | 7/391 [00:16<14:59,  2.34s/it][A
Epoch 13 in training:   2%|▏         | 8/391 [00:18<14:37,  2.29s/it][A
Epoch 13 in training:   2%|▏         | 9/391 [00:20<13:52,  2.18s/it][A
Epoch 13 in training:   3%|▎         | 10/391 [00:22<13:17,  2.09s/it][A
Epoch 13 in training:   3%|▎         | 11/391 [00:24<12:55,  2.04s/it][A
Epoch 13 in training:   3%|▎         | 12/391 [00:26<12:37,  2.00s/it][A
Epoch 13 in training:   3%|▎         | 13/391 [00:28<13

Epoch 13 / 15 loss :  4.51



Epoch 14 in training:   0%|          | 0/391 [00:00<?, ?it/s][A
Epoch 14 in training:   0%|          | 1/391 [00:03<19:52,  3.06s/it][A
Epoch 14 in training:   1%|          | 2/391 [00:05<16:58,  2.62s/it][A
Epoch 14 in training:   1%|          | 3/391 [00:07<16:25,  2.54s/it][A
Epoch 14 in training:   1%|          | 4/391 [00:10<17:05,  2.65s/it][A
Epoch 14 in training:   1%|▏         | 5/391 [00:12<15:49,  2.46s/it][A
Epoch 14 in training:   2%|▏         | 6/391 [00:15<15:40,  2.44s/it][A
Epoch 14 in training:   2%|▏         | 7/391 [00:17<14:49,  2.32s/it][A
Epoch 14 in training:   2%|▏         | 8/391 [00:19<14:23,  2.25s/it][A
Epoch 14 in training:   2%|▏         | 9/391 [00:21<15:03,  2.36s/it][A
Epoch 14 in training:   3%|▎         | 10/391 [00:24<15:32,  2.45s/it][A
Epoch 14 in training:   3%|▎         | 11/391 [00:26<14:54,  2.35s/it][A
Epoch 14 in training:   3%|▎         | 12/391 [00:28<14:14,  2.25s/it][A
Epoch 14 in training:   3%|▎         | 13/391 [00:30<13

Epoch 14 / 15 loss :  4.48



Epoch 15 in training:   0%|          | 0/391 [00:00<?, ?it/s][A
Epoch 15 in training:   0%|          | 1/391 [00:03<22:46,  3.50s/it][A
Epoch 15 in training:   1%|          | 2/391 [00:05<18:05,  2.79s/it][A
Epoch 15 in training:   1%|          | 3/391 [00:07<16:03,  2.48s/it][A
Epoch 15 in training:   1%|          | 4/391 [00:09<14:49,  2.30s/it][A
Epoch 15 in training:   1%|▏         | 5/391 [00:12<15:41,  2.44s/it][A
Epoch 15 in training:   2%|▏         | 6/391 [00:15<16:09,  2.52s/it][A
Epoch 15 in training:   2%|▏         | 7/391 [00:17<15:04,  2.36s/it][A
Epoch 15 in training:   2%|▏         | 8/391 [00:19<14:26,  2.26s/it][A
Epoch 15 in training:   2%|▏         | 9/391 [00:21<14:04,  2.21s/it][A
Epoch 15 in training:   3%|▎         | 10/391 [00:23<13:49,  2.18s/it][A
Epoch 15 in training:   3%|▎         | 11/391 [00:26<14:33,  2.30s/it][A
Epoch 15 in training:   3%|▎         | 12/391 [00:28<14:59,  2.37s/it][A
Epoch 15 in training:   3%|▎         | 13/391 [00:30<14

Epoch 15 / 15 loss :  4.47


Testing: 100%|██████████| 79/79 [01:02<00:00,  1.26it/s]

Test Loss: 4.48
Test Accuracy: 14.10%





In [11]:
main() #hidden_dim = 32->128

Files already downloaded and verified
Files already downloaded and verified
Using Device :  cpu 


Training:   0%|          | 0/15 [00:00<?, ?it/s]
Epoch 1 in training:   0%|          | 0/391 [00:00<?, ?it/s][A
Epoch 1 in training:   0%|          | 1/391 [00:04<30:51,  4.75s/it][A
Epoch 1 in training:   1%|          | 2/391 [00:07<24:11,  3.73s/it][A
Epoch 1 in training:   1%|          | 3/391 [00:10<22:04,  3.41s/it][A
Epoch 1 in training:   1%|          | 4/391 [00:13<20:15,  3.14s/it][A
Epoch 1 in training:   1%|▏         | 5/391 [00:17<22:27,  3.49s/it][A
Epoch 1 in training:   2%|▏         | 6/391 [00:20<20:51,  3.25s/it][A
Epoch 1 in training:   2%|▏         | 7/391 [00:23<19:36,  3.06s/it][A
Epoch 1 in training:   2%|▏         | 8/391 [00:25<18:45,  2.94s/it][A
Epoch 1 in training:   2%|▏         | 9/391 [00:28<18:45,  2.95s/it][A
Epoch 1 in training:   3%|▎         | 10/391 [00:32<19:43,  3.11s/it][A
Epoch 1 in training:   3%|▎         | 11/391 [00:34<18:49,  2.97s/it][A
Epoch 1 in training:   3%|▎         | 12/391 [00:37<18:15,  2.89s/it][A
Epoch 1 in training:

Epoch 1 / 15 loss :  4.52



Epoch 2 in training:   0%|          | 0/391 [00:00<?, ?it/s][A
Epoch 2 in training:   0%|          | 1/391 [00:05<33:29,  5.15s/it][A
Epoch 2 in training:   1%|          | 2/391 [00:10<32:26,  5.00s/it][A
Epoch 2 in training:   1%|          | 3/391 [00:13<26:26,  4.09s/it][A
Epoch 2 in training:   1%|          | 4/391 [00:16<24:06,  3.74s/it][A
Epoch 2 in training:   1%|▏         | 5/391 [00:19<23:02,  3.58s/it][A
Epoch 2 in training:   2%|▏         | 6/391 [00:22<20:59,  3.27s/it][A
Epoch 2 in training:   2%|▏         | 7/391 [00:24<19:46,  3.09s/it][A
Epoch 2 in training:   2%|▏         | 8/391 [00:27<18:50,  2.95s/it][A
Epoch 2 in training:   2%|▏         | 9/391 [00:31<20:10,  3.17s/it][A
Epoch 2 in training:   3%|▎         | 10/391 [00:34<19:30,  3.07s/it][A
Epoch 2 in training:   3%|▎         | 11/391 [00:36<18:42,  2.96s/it][A
Epoch 2 in training:   3%|▎         | 12/391 [00:39<18:02,  2.86s/it][A
Epoch 2 in training:   3%|▎         | 13/391 [00:42<18:40,  2.96s/it

KeyboardInterrupt: 

In [None]:
main() #n_patches = 8->4

Files already downloaded and verified
Files already downloaded and verified
Using Device :  cpu 


Training:   0%|          | 0/15 [00:00<?, ?it/s]
Epoch 1 in training:   0%|          | 0/391 [00:00<?, ?it/s][A
Epoch 1 in training:   0%|          | 1/391 [00:04<31:56,  4.91s/it][A
Epoch 1 in training:   1%|          | 2/391 [00:07<24:18,  3.75s/it][A
Epoch 1 in training:   1%|          | 3/391 [00:10<21:27,  3.32s/it][A
Epoch 1 in training:   1%|          | 4/391 [00:13<20:04,  3.11s/it][A
Epoch 1 in training:   1%|▏         | 5/391 [00:17<22:36,  3.51s/it][A
Epoch 1 in training:   2%|▏         | 6/391 [00:20<20:36,  3.21s/it][A
Epoch 1 in training:   2%|▏         | 7/391 [00:22<19:28,  3.04s/it][A
Epoch 1 in training:   2%|▏         | 8/391 [00:25<19:07,  2.99s/it][A
Epoch 1 in training:   2%|▏         | 9/391 [00:29<20:22,  3.20s/it][A
Epoch 1 in training:   3%|▎         | 10/391 [00:32<20:07,  3.17s/it][A
Epoch 1 in training:   3%|▎         | 11/391 [00:35<19:10,  3.03s/it][A
Epoch 1 in training:   3%|▎         | 12/391 [00:37<18:23,  2.91s/it][A
Epoch 1 in training:

Epoch 1 / 15 loss :  4.48



Epoch 2 in training:   0%|          | 0/391 [00:00<?, ?it/s][A
Epoch 2 in training:   0%|          | 1/391 [00:04<26:38,  4.10s/it][A
Epoch 2 in training:   1%|          | 2/391 [00:07<22:42,  3.50s/it][A
Epoch 2 in training:   1%|          | 3/391 [00:10<21:50,  3.38s/it][A
Epoch 2 in training:   1%|          | 4/391 [00:14<22:31,  3.49s/it][A
Epoch 2 in training:   1%|▏         | 5/391 [00:16<20:39,  3.21s/it][A
Epoch 2 in training:   2%|▏         | 6/391 [00:19<19:33,  3.05s/it][A
Epoch 2 in training:   2%|▏         | 7/391 [00:22<18:42,  2.92s/it][A
Epoch 2 in training:   2%|▏         | 8/391 [00:25<19:32,  3.06s/it][A
Epoch 2 in training:   2%|▏         | 9/391 [00:28<19:39,  3.09s/it][A
Epoch 2 in training:   3%|▎         | 10/391 [00:31<18:57,  2.99s/it][A
Epoch 2 in training:   3%|▎         | 11/391 [00:34<18:38,  2.94s/it][A
Epoch 2 in training:   3%|▎         | 12/391 [00:37<18:18,  2.90s/it][A
Epoch 2 in training:   3%|▎         | 13/391 [00:40<19:47,  3.14s/it