<a href="https://colab.research.google.com/github/MilenaOehlers/generative_models_for_deep_radar_object_detection/blob/main/vision_transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Following https://medium.com/@brianpulfer/vision-transformers-from-scratch-pytorch-a-step-by-step-guide-96c3313c2e0c

In [2]:
import numpy as np

from tqdm import tqdm, trange

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader

from torchvision.transforms import ToTensor
from torchvision.datasets.mnist import MNIST

np.random.seed(0)
torch.manual_seed(0)

<torch._C.Generator at 0x7a96a0246390>

In [4]:
class MyMSA(nn.Module):
    def __init__(self, d, n_heads=2):
        super(MyMSA, self).__init__()
        self.d = d
        self.n_heads = n_heads

        assert d%n_heads==0, f"Can't divide dimension {d} into {n_heads} heads"

        d_head = int(d / n_heads)
        """Wait... so here, as opposed to the regular transformer tutorial,
         not d_model==d is the dimension of q_mappings,
        but d_head?? In that case, applying weights to both K and Q truely
        does not make sense. """
        self.q_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.k_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.v_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.d_head = d_head
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, sequences):
        # Sequences has shape (N, seq_length, token_dim)
        # We go into shape    (N, seq_length, n_heads, token_dim / n_heads)
        # And come back to    (N, seq_length, item_dim)  (through concatenation)
        result = []

        """In tutorial: 'Also notice that using loops is not the most efficient
        way to compute the multi-head self-attention, but it makes the code much
        clearer for learning.'"""
        for sequence in sequences:
            seq_result = []
            for head in range(self.n_heads):
                q_mapping = self.q_mappings[head]
                k_mapping = self.k_mappings[head]
                v_mapping = self.v_mappings[head]

                """
                So, of each picture, we have 49 patches of original size 16 but
                encoded size token_dim==8, and of those 8dims, eg. 4x (2 belong
                to one head only), hence its as if instead of 8 dims we had (2x4) dims,
                with n_channels=n_heads=4; very same, just different way of writing.

                So, P pictures of 49 patches each with n_pixels==16 each.
                each picture-patch combi is converted independently, but in the
                very same way, via learnable parameters into basically a
                (n_heads x d_head) presentation.

                Inside each head, head_specific K/Q/V matrices are applied
                and only inside-head matrix multiplication takes place.

                So yes, both Q and K weighing makes no sense.
                """
                seq = sequence[:, head * self.d_head: (head + 1) * self.d_head]
                # q, k, and v have dim (n_patches x d_head)
                q, k, v = q_mapping(seq), k_mapping(seq), v_mapping(seq)

                # attention hence has dim (n_patches x n_patches)
                attention = self.softmax(q @ k.T / (self.d_head ** 0.5))
                # appended partial result has dim (n_patches x d_head) again
                seq_result.append(attention @ v)
            result.append(torch.hstack(seq_result)) # -> n_patches x token_dim
        # torch.cat default dim==0, so stack pictures back onto each other:
        return torch.cat([torch.unsqueeze(r, dim=0) for r in result])

In [5]:
class MyViTBlock(nn.Module):
    def __init__(self, hidden_d, n_heads, mlp_ratio=4):
        super(MyViTBlock, self).__init__()
        self.hidden_d = hidden_d
        self.n_heads = n_heads

        self.norm1 = nn.LayerNorm(hidden_d)
        self.mhsa = MyMSA(hidden_d, n_heads)
        self.norm2 = nn.LayerNorm(hidden_d)
        # Multi-Layer-Perceptron:
        self.mlp = nn.Sequential(
            nn.Linear(hidden_d, mlp_ratio * hidden_d),
            nn.GELU(),
            nn.Linear(mlp_ratio * hidden_d, hidden_d)
        )

    def forward(self, x):
        out = x + self.mhsa(self.norm1(x))
        out = out + self.mlp(self.norm2(out))
        return out

In [26]:
class MyViT(nn.Module):
  def __init__(self, chw=(1, 28, 28), n_patches=7, n_blocks=2,  hidden_d=8,
               n_heads=2, out_d=10):
    # Super constructor
    super(MyViT, self).__init__()

    # Attributes
    self.chw = chw # (C, H, W)
    self.n_patches = n_patches
    self.hidden_d = hidden_d

    assert chw[1] % n_patches == 0, "Input shape not entirely divisible by number of patches"
    assert chw[2] % n_patches == 0, "Input shape not entirely divisible by number of patches"
    self.patch_size = (chw[1] / n_patches, chw[2] / n_patches)

    # 1) Linear mapper
    self.input_d = int(chw[0] * self.patch_size[0] * self.patch_size[1])
    self.linear_mapper = nn.Linear(self.input_d, self.hidden_d)

    # 2) Learnable classifiation token
    self.class_token = nn.Parameter(torch.rand(1, self.hidden_d))

    # 3) Positional embedding
    self.pos_embed = nn.Parameter(torch.tensor(self.get_positional_embeddings(self.n_patches ** 2 + 1, self.hidden_d)))
    self.pos_embed.requires_grad = False

    # 4) Transformer encoder blocks
    self.blocks = nn.ModuleList([MyViTBlock(hidden_d, n_heads) for _ in range(n_blocks)])

    # 5) Classification MLPk
    self.mlp = nn.Sequential(
        nn.Linear(self.hidden_d, out_d),
        nn.Softmax(dim=-1)
    )
    """
    nn.LayerNorm: normalizes over last D dimensions of the input tensor,
    e.g. layer_norm = nn.LayerNorm([49,16])
    for our N pictures, each constisting of 49 patches of 4*4 pixels,
    hence forming an input tensor of (N,49,16)
    """

  def forward(self, images):
    patches = self.patchify(images, self.n_patches)
    tokens = self.linear_mapper(patches)

    # Adding classification token to the tokens
    tokens = torch.stack([torch.vstack((self.class_token, tokens[i]))
                          for i in range(len(tokens))])

    # Adding positional embedding
    pos_embed = self.pos_embed.repeat(images.shape[0], 1, 1)
    out = tokens + pos_embed

    # Transformer Blocks
    for block in self.blocks:
        out = block(out)

    # Getting the classification token only
    out = out[:, 0]

    return self.mlp(out) # Map to output dimension, output category distribution

  def patchify(self,images, n_patches):
    n, c, h, w = images.shape

    assert h == w, "Patchify method is implemented for square images only"

    patches = torch.zeros(n, n_patches ** 2, h * w * c // n_patches ** 2)
    patch_size = h // n_patches

    for idx, image in enumerate(images):
        for i in range(n_patches):
            for j in range(n_patches):
                patch = image[:, i * patch_size: (i + 1) * patch_size, j * patch_size: (j + 1) * patch_size]
                patches[idx, i * n_patches + j] = patch.flatten()
    return patches

  def get_positional_embeddings(self, sequence_length, d):
    result = torch.ones(sequence_length, d)
    for i in range(sequence_length):
        for j in range(d):
            result[i][j] = np.sin(i / (10000 ** (j / d))) if j % 2 == 0 else np.cos(i / (10000 ** ((j - 1) / d)))
    return result


In [27]:
def main():
    # Loading data
    transform = ToTensor()

    train_set = MNIST(root='./../datasets', train=True, download=True, transform=transform)
    test_set = MNIST(root='./../datasets', train=False, download=True, transform=transform)

    train_loader = DataLoader(train_set, shuffle=True, batch_size=128)
    test_loader = DataLoader(test_set, shuffle=False, batch_size=128)

    # Defining model and training options
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device: ", device, f"({torch.cuda.get_device_name(device)})" if torch.cuda.is_available() else "")
    model = MyViT((1, 28, 28), n_patches=7, n_blocks=2, hidden_d=8, n_heads=2, out_d=10).to(device)
    N_EPOCHS = 5
    LR = 0.005

    # Training loop
    """Somehow weird to me that the training loop has to be written explicitly,
    with setting gradients to zero manually... surely, torch provides a whole
    automatic fucntionality for that? """
    optimizer = Adam(model.parameters(), lr=LR)
    criterion = CrossEntropyLoss()
    for epoch in trange(N_EPOCHS, desc="Training"):
        train_loss = 0.0
        for batch in tqdm(train_loader, desc=f"Epoch {epoch + 1} in training", leave=False):
            x, y = batch
            x, y = x.to(device), y.to(device)
            y_hat = model(x)
            loss = criterion(y_hat, y)

            """ not sure why in the following / len(train_loader)?
            is a loss always supposed to be an average per-batch loss?
            Or maybe the len(train_loader)==n_training, not n_batches.
            Then, would make much more sense!
            -> but here, it actually returns n_batches. bit weird to me, but ok
            """
            train_loss += loss.detach().cpu().item() / len(train_loader)

            """Why would the gradient be set to zero after each batch??
            -> accoring to https://pytorch.org/tutorials/recipes/recipes/zeroing_out_gradients.html
            gradients are accumulated in buffers whenever .backward()
            is called during calculation, and not overwritten
            If were not interested in testing for vanishing or exploding gradients
            afterwards, we can delete them, or zero them out.
            Actually, even in the explicit example provided there, they zero out
            the gradients inside each epoch>batch loop.

            So, it seems gradients have to be zeroed out before accumulating
            gradients in the derivate graph leaves again, as else the gradients
            from last calculation would be added to the ones of this calculation
            hence overshooting, (= applying the wrong gradients), correct or not?

            I am still surprised that these elements are coded this explicitly
            if I remember correctly, inside tensorflow that was taken care of?
            """

            optimizer.zero_grad() # override gradients in buffer with zero
            loss.backward() # accumulate gradients into the leaves
            optimizer.step() # parameter upgrade based on accum. gradients

        print(f"Epoch {epoch + 1}/{N_EPOCHS} loss: {train_loss:.2f}")

    # Test loop
    with torch.no_grad():
        correct, total = 0, 0
        test_loss = 0.0
        for batch in tqdm(test_loader, desc="Testing"):
            x, y = batch
            x, y = x.to(device), y.to(device)
            y_hat = model(x)
            loss = criterion(y_hat, y)
            test_loss += loss.detach().cpu().item() / len(test_loader)

            correct += torch.sum(torch.argmax(y_hat, dim=1) == y).detach().cpu().item()
            total += len(x)
        print(f"Test loss: {test_loss:.2f}")
        print(f"Test accuracy: {correct / total * 100:.2f}%")

In [28]:
main()

  self.pos_embed = nn.Parameter(torch.tensor(self.get_positional_embeddings(self.n_patches ** 2 + 1, self.hidden_d)))


Using device:  cpu 


Training:   0%|          | 0/5 [00:00<?, ?it/s]
Epoch 1 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 1 in training:   0%|          | 1/469 [00:00<06:38,  1.17it/s][A
Epoch 1 in training:   0%|          | 2/469 [00:01<04:52,  1.60it/s][A
Epoch 1 in training:   1%|          | 3/469 [00:01<04:19,  1.79it/s][A
Epoch 1 in training:   1%|          | 4/469 [00:02<03:59,  1.94it/s][A
Epoch 1 in training:   1%|          | 5/469 [00:02<03:52,  2.00it/s][A
Epoch 1 in training:   1%|▏         | 6/469 [00:03<03:43,  2.08it/s][A
Epoch 1 in training:   1%|▏         | 7/469 [00:03<03:38,  2.12it/s][A
Epoch 1 in training:   2%|▏         | 8/469 [00:04<03:37,  2.12it/s][A
Epoch 1 in training:   2%|▏         | 9/469 [00:04<03:35,  2.13it/s][A
Epoch 1 in training:   2%|▏         | 10/469 [00:05<03:34,  2.14it/s][A
Epoch 1 in training:   2%|▏         | 11/469 [00:05<03:31,  2.17it/s][A
Epoch 1 in training:   3%|▎         | 12/469 [00:05<03:30,  2.17it/s][A
Epoch 1 in training: 

Epoch 1/5 loss: 2.00



Epoch 2 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 2 in training:   0%|          | 1/469 [00:00<03:47,  2.06it/s][A
Epoch 2 in training:   0%|          | 2/469 [00:01<04:37,  1.68it/s][A
Epoch 2 in training:   1%|          | 3/469 [00:01<04:51,  1.60it/s][A
Epoch 2 in training:   1%|          | 4/469 [00:02<05:09,  1.50it/s][A
Epoch 2 in training:   1%|          | 5/469 [00:03<04:59,  1.55it/s][A
Epoch 2 in training:   1%|▏         | 6/469 [00:03<04:32,  1.70it/s][A
Epoch 2 in training:   1%|▏         | 7/469 [00:04<04:11,  1.84it/s][A
Epoch 2 in training:   2%|▏         | 8/469 [00:04<03:57,  1.94it/s][A
Epoch 2 in training:   2%|▏         | 9/469 [00:04<03:48,  2.01it/s][A
Epoch 2 in training:   2%|▏         | 10/469 [00:05<03:41,  2.08it/s][A
Epoch 2 in training:   2%|▏         | 11/469 [00:05<03:38,  2.10it/s][A
Epoch 2 in training:   3%|▎         | 12/469 [00:06<03:35,  2.12it/s][A
Epoch 2 in training:   3%|▎         | 13/469 [00:06<03:34,  2.12it/s

Epoch 2/5 loss: 1.82



Epoch 3 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 3 in training:   0%|          | 1/469 [00:00<04:04,  1.91it/s][A
Epoch 3 in training:   0%|          | 2/469 [00:01<04:44,  1.64it/s][A
Epoch 3 in training:   1%|          | 3/469 [00:01<04:56,  1.57it/s][A
Epoch 3 in training:   1%|          | 4/469 [00:02<05:15,  1.47it/s][A
Epoch 3 in training:   1%|          | 5/469 [00:03<04:50,  1.60it/s][A
Epoch 3 in training:   1%|▏         | 6/469 [00:03<04:26,  1.74it/s][A
Epoch 3 in training:   1%|▏         | 7/469 [00:04<04:09,  1.85it/s][A
Epoch 3 in training:   2%|▏         | 8/469 [00:04<03:59,  1.93it/s][A
Epoch 3 in training:   2%|▏         | 9/469 [00:05<03:51,  1.99it/s][A
Epoch 3 in training:   2%|▏         | 10/469 [00:05<03:46,  2.03it/s][A
Epoch 3 in training:   2%|▏         | 11/469 [00:05<03:41,  2.06it/s][A
Epoch 3 in training:   3%|▎         | 12/469 [00:06<03:37,  2.10it/s][A
Epoch 3 in training:   3%|▎         | 13/469 [00:06<03:36,  2.10it/s

Epoch 3/5 loss: 1.72



Epoch 4 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 4 in training:   0%|          | 1/469 [00:00<04:43,  1.65it/s][A
Epoch 4 in training:   0%|          | 2/469 [00:01<04:02,  1.93it/s][A
Epoch 4 in training:   1%|          | 3/469 [00:01<03:52,  2.00it/s][A
Epoch 4 in training:   1%|          | 4/469 [00:01<03:44,  2.08it/s][A
Epoch 4 in training:   1%|          | 5/469 [00:02<03:43,  2.08it/s][A
Epoch 4 in training:   1%|▏         | 6/469 [00:02<03:38,  2.11it/s][A
Epoch 4 in training:   1%|▏         | 7/469 [00:03<03:38,  2.12it/s][A
Epoch 4 in training:   2%|▏         | 8/469 [00:03<03:35,  2.14it/s][A
Epoch 4 in training:   2%|▏         | 9/469 [00:04<03:35,  2.13it/s][A
Epoch 4 in training:   2%|▏         | 10/469 [00:04<03:33,  2.15it/s][A
Epoch 4 in training:   2%|▏         | 11/469 [00:05<03:32,  2.16it/s][A
Epoch 4 in training:   3%|▎         | 12/469 [00:05<03:30,  2.17it/s][A
Epoch 4 in training:   3%|▎         | 13/469 [00:06<03:28,  2.19it/s

Epoch 4/5 loss: 1.68



Epoch 5 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 5 in training:   0%|          | 1/469 [00:00<05:17,  1.48it/s][A
Epoch 5 in training:   0%|          | 2/469 [00:01<05:33,  1.40it/s][A
Epoch 5 in training:   1%|          | 3/469 [00:01<04:50,  1.60it/s][A
Epoch 5 in training:   1%|          | 4/469 [00:02<04:21,  1.78it/s][A
Epoch 5 in training:   1%|          | 5/469 [00:02<04:05,  1.89it/s][A
Epoch 5 in training:   1%|▏         | 6/469 [00:03<03:55,  1.96it/s][A
Epoch 5 in training:   1%|▏         | 7/469 [00:03<03:50,  2.01it/s][A
Epoch 5 in training:   2%|▏         | 8/469 [00:04<03:45,  2.05it/s][A
Epoch 5 in training:   2%|▏         | 9/469 [00:04<03:44,  2.05it/s][A
Epoch 5 in training:   2%|▏         | 10/469 [00:05<03:39,  2.09it/s][A
Epoch 5 in training:   2%|▏         | 11/469 [00:05<03:39,  2.09it/s][A
Epoch 5 in training:   3%|▎         | 12/469 [00:06<03:37,  2.10it/s][A
Epoch 5 in training:   3%|▎         | 13/469 [00:06<03:37,  2.09it/s

Epoch 5/5 loss: 1.66


Testing: 100%|██████████| 79/79 [00:20<00:00,  3.76it/s]

Test loss: 1.64
Test accuracy: 82.05%



