In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Recursively walk the Kaggle input directory and print the full path of every file found.
for folder, _, names in os.walk('/kaggle/input'):
    for name in names:
        full_path = os.path.join(folder, name)
        print(full_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import numpy as np

In [None]:
from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Clone the repository that provides the sample image (car.png) opened in a later cell.
!git clone https://github.com/nerminnuraydogan/vision-transformer

In [None]:
# Load the sample image from the cloned repository as a PIL image.
image = Image.open('vision-transformer/car.png')

In [None]:
# Resize to a fixed 128x128 resolution so the image splits evenly into 16x16 patches.
image = image.resize((128, 128))

# convert to numpy array
# Force 3-channel RGB first: the reshape below hard-codes C = 3, so an image
# carrying an alpha channel (RGBA) or a single grayscale channel would fail.
x = np.array(image.convert('RGB'))


# An Image Is Worth 16x16 Words
P = 16   # patch size
C = 3    # number of channels (RGB)

# the patch reshape only works when both spatial dimensions are multiples of P
assert x.shape[0] % P == 0 and x.shape[1] % P == 0, 'image size must be divisible by patch size'

# split image into patches using numpy
# (H, W, C) -> (H/P, P, W/P, P, C) -> (H/P, W/P, P, P, C) -> (N, P, P, C)
patches = x.reshape(x.shape[0]//P, P, x.shape[1]//P, P, C).swapaxes(1, 2).reshape(-1, P, P, C)

# flatten patches: one row of P * P * C values per patch
x_p = np.reshape(patches, (-1, P * P * C))

# get number of patches
N = x_p.shape[0]

print('Image shape: ', x.shape)  # height, width, channel
print('Number of patches: {} with resolution ({}, {})'.format(N, P, P))
print('Patches shape: ', patches.shape)
print('Flattened patches shape: ', x_p.shape)

In [None]:
fig = plt.figure()

# left half: the full image, right half: the grid of patches
gridspec = fig.add_gridspec(1, 2)
ax1 = fig.add_subplot(gridspec[0])
ax1.set(title='Image')

# display image 
ax1.imshow(x)

# derive the patch grid from the image and patch size instead of hard-coding 8x8,
# so the figure stays correct if the image resolution or P changes
n_rows, n_cols = x.shape[0] // P, x.shape[1] // P

subgridspec = gridspec[1].subgridspec(n_rows, n_cols, hspace=-0.8)

# display patches (stored row-major: the patch at grid cell (i, j) has index i * n_cols + j)
for i in range(n_rows):
    for j in range(n_cols):
        num = i * n_cols + j
        ax = fig.add_subplot(subgridspec[i, j])
        ax.set(xticks=[], yticks=[])
        ax.imshow(patches[num])

In [None]:
D = 768   # embedding dimension every flattened patch is projected to

# batch size
B = 1

# convert flattened patches to a float tensor and prepend the batch axis:
# (N, P*P*C) -> (1, N, P*P*C)
x_p = torch.Tensor(x_p).unsqueeze(0)

# weight matrix E: randomly initialised projection from patch pixels to D features
E = nn.Parameter(torch.randn(1, P * P * C, D))

# batched matrix product: (1, N, P*P*C) @ (1, P*P*C, D) -> (1, N, D)
patch_embeddings = x_p @ E

assert patch_embeddings.shape == (B, N, D)
print(patch_embeddings.shape)

# Class Token

In [None]:
# init class token: a single D-dimensional parameter with batch and sequence axes of size 1
class_token = nn.Parameter(torch.randn(1, 1, D))

# prepend the class token along the sequence axis: (B, N, D) -> (B, N + 1, D)
patch_embeddings = torch.cat([class_token, patch_embeddings], dim=1)

print(patch_embeddings.shape)
assert patch_embeddings.shape == (B, N + 1, D)

# Position Embedding

In [None]:
# position embeddings: one D-dimensional vector per sequence slot (class token + N patches)
E_pos = nn.Parameter(torch.randn(1, N + 1, D))

# element-wise sum of the token embeddings and the position embeddings
z0 = torch.add(patch_embeddings, E_pos)

print(z0.shape)
assert z0.shape == (B, N + 1, D)

# Self Attention

In [None]:
class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention.

    The input is projected to queries, keys and values by three bias-free
    linear layers; the output is the softmax-weighted sum of the values.

    Args:
        embedding_dim: size of the last axis of the input.
        key_dim: size of the query/key/value projections (and of the output).
    """

    def __init__(self, embedding_dim=768, key_dim=64):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.key_dim = key_dim

        # query / key / value projections (no bias terms)
        self.W_q = nn.Linear(embedding_dim, key_dim, bias=False)
        self.W_k = nn.Linear(embedding_dim, key_dim, bias=False)
        self.W_v = nn.Linear(embedding_dim, key_dim, bias=False)

    def forward(self, x):
        """Attend over the second-to-last axis of x and return a (B, N, d_k) tensor."""
        queries, keys, values = self.W_q(x), self.W_k(x), self.W_v(x)   # each (B, N, d_k)

        # pairwise query-key similarities, scaled by sqrt(d_k): (B, N, N)
        scores = torch.matmul(queries, keys.transpose(-2, -1)) / np.sqrt(self.key_dim)

        # softmax over the last axis turns each row into weights over the keys
        weights = F.softmax(scores, dim=-1)

        # weighted sum of the values: (B, N, d_k)
        return torch.matmul(weights, values)

In [None]:
D_h = 64   # query/key/value dimension of the single head

# init self-attention
self_attention = SelfAttention(embedding_dim=D, key_dim=D_h)

# run the head over the token sequence
attention_scores = self_attention(patch_embeddings)

print(attention_scores.shape)
assert attention_scores.shape == (B, N + 1, D_h)

# Multi-Head Self-Attention

In [None]:
# number of heads (k) and model dimensionality
num_heads, embedding_dim = 12, 768

# dimensionality should be divisible by number of heads
assert embedding_dim % num_heads == 0

# key, query and value dimensionality of each head
key_dim = embedding_dim // num_heads

# init self-attentions: one independent SelfAttention module per head
attention_list = []
for _ in range(num_heads):
    attention_list.append(SelfAttention(embedding_dim, key_dim))
multi_head_attention = nn.ModuleList(attention_list)

# init U_msa weight matrix: maps the concatenated head outputs to embedding_dim features
W = nn.Parameter(torch.randn(num_heads * key_dim, embedding_dim))

In [None]:
# run every head on the same input and keep the per-head outputs in a list
attention_scores = [head(patch_embeddings) for head in multi_head_attention]

# show the shape produced by each head
for head_output in attention_scores:
    print(head_output.shape)

In [None]:
# concatenate the head outputs along the last (feature) axis
Z = torch.cat(attention_scores, dim=-1)
print(Z.shape)

print(W.shape)

# project the concatenated heads with W
attention_score = Z @ W
print(attention_score.shape)

In [None]:
class MultiHeadSelfAttention(nn.Module):
    """Multi-head self-attention assembled from independent SelfAttention heads.

    Every head receives the same input; the head outputs are concatenated along
    the last axis and multiplied by the weight matrix W.

    Args:
        embedding_dim: size of the last axis of the input and of the output.
        num_heads: number of heads; must divide embedding_dim.
    """

    def __init__(self, embedding_dim=768, num_heads=12):
        super().__init__()

        # dimensionality should be divisible by number of heads
        assert embedding_dim % num_heads == 0

        self.num_heads = num_heads
        self.embedding_dim = embedding_dim
        self.key_dim = embedding_dim // num_heads   # per-head key, query and value size

        # one SelfAttention per head; wrapping the list in a ModuleList registers the heads
        self.attention_list = []
        for _ in range(num_heads):
            self.attention_list.append(SelfAttention(embedding_dim, self.key_dim))
        self.multi_head_attention = nn.ModuleList(self.attention_list)

        # U_msa weight matrix applied to the concatenated head outputs
        self.W = nn.Parameter(torch.randn(num_heads * self.key_dim, embedding_dim))

    def forward(self, x):
        """Run all heads on x, concatenate their outputs and project with W."""
        head_outputs = [head(x) for head in self.multi_head_attention]
        concatenated = torch.cat(head_outputs, dim=-1)
        return concatenated @ self.W

# Multi-Layer Perceptron

In [None]:
class MultiLayerPerceptron(nn.Module):
    """Feed-forward block: three linear layers with a GELU after the first two.

    Args:
        embedding_dim: input and output feature size.
        hidden_dim: width of both hidden layers.
    """

    def __init__(self, embedding_dim=768, hidden_dim=3072):
        super().__init__()

        # embedding_dim -> hidden_dim -> hidden_dim -> embedding_dim
        layers = [
            nn.Linear(embedding_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, embedding_dim),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the layer stack to the last axis of x."""
        return self.mlp(x)

In [None]:
hidden_dim = 3072   # width of the MLP hidden layers

# init mlp
mlp = MultiLayerPerceptron(embedding_dim=D, hidden_dim=hidden_dim)

# compute mlp output
output = mlp(patch_embeddings)

# the MLP must keep the (batch, tokens, features) shape unchanged
assert output.shape == (B, N + 1, D)
output.shape

# Transformer Encoder

In [None]:
class TransformerEncoder(nn.Module):
    """One encoder layer with two residual sub-layers.

    Each sub-layer applies LayerNorm to its input, then the sub-module
    (multi-head self-attention, then the MLP), then dropout, and adds the
    result back to the sub-layer input.

    Args:
        embedding_dim: feature size of the tokens.
        num_heads: number of attention heads.
        hidden_dim: hidden width of the MLP.
        dropout_prob: dropout probability applied to each sub-layer output.
    """

    def __init__(self, embedding_dim=768, num_heads=12, hidden_dim=3072, dropout_prob=0.1):
        super().__init__()

        # sub-modules
        self.MSA = MultiHeadSelfAttention(embedding_dim, num_heads)
        self.MLP = MultiLayerPerceptron(embedding_dim, hidden_dim)

        # one LayerNorm in front of each sub-module
        self.layer_norm1 = nn.LayerNorm(embedding_dim)
        self.layer_norm2 = nn.LayerNorm(embedding_dim)

        # one dropout behind each sub-module
        self.dropout1 = nn.Dropout(dropout_prob)
        self.dropout2 = nn.Dropout(dropout_prob)

    def forward(self, x):
        """Return x after the attention and feed-forward residual sub-layers."""
        # attention sub-layer: x + Dropout(MSA(LayerNorm(x)))
        x = x + self.dropout1(self.MSA(self.layer_norm1(x)))

        # feed-forward sub-layer: x + Dropout(MLP(LayerNorm(x)))
        x = x + self.dropout2(self.MLP(self.layer_norm2(x)))

        return x

In [None]:
dropout_prob = 0.1

# init transformer encoder
# num_heads is the head count defined in the multi-head self-attention section above
transformer_encoder = TransformerEncoder(D, num_heads, hidden_dim, dropout_prob)

# compute transformer encoder output
output = transformer_encoder(patch_embeddings)

# the encoder layer must keep the (batch, tokens, features) shape unchanged
assert output.shape == (B, N + 1, D)
output.shape

# MLP Head

In [None]:
class MLPHead(nn.Module):
    """Classification head mapping embedding_dim features to num_classes logits.

    Args:
        embedding_dim: size of the input features.
        num_classes: number of output logits.
        fine_tune: if False, use one hidden layer with tanh followed by the
            output layer; if True, use a single linear layer.
        hidden_dim: width of the hidden layer used when fine_tune is False.
            Defaults to 3072, the previously hard-coded value. Ignored when
            fine_tune is True.
    """

    def __init__(self, embedding_dim=768, num_classes=10, fine_tune=False, hidden_dim=3072):
        super().__init__()
        self.num_classes = num_classes

        if not fine_tune:
            # hidden layer with tanh activation function
            self.mlp_head = nn.Sequential(
                nn.Linear(embedding_dim, hidden_dim),   # hidden layer
                nn.Tanh(),
                nn.Linear(hidden_dim, num_classes),     # output layer
            )
        else:
            # single linear layer
            self.mlp_head = nn.Linear(embedding_dim, num_classes)

    def forward(self, x):
        """Return the class logits for x (last axis of size embedding_dim)."""
        return self.mlp_head(x)

In [None]:
# representation at batch index 0, token index 0 (the class token was prepended at position 0)
cls_token = output[0][0]

n_classes = 10

# head without fine_tune: hidden layer + tanh + output layer
mlp_head_pretrain = MLPHead(D, n_classes)

# class logits computed from the class-token representation
output_1 = mlp_head_pretrain(cls_token)
output_1

In [None]:
# Normalise the head output with a softmax over dim 0 — assumes output_1 is a 1-D vector of class logits.
F.softmax(output_1, dim=0)

In [98]:
class VisionTransformer(nn.Module):
    """Vision Transformer image classifier.

    Pipeline: split the image into non-overlapping patches, project each
    flattened patch linearly, prepend a class token, add position embeddings,
    run the stack of TransformerEncoder layers and classify the class-token
    output with an MLPHead.

    Args:
        patch_size: side length P of the square patches.
        image_size: side length of the (square) input images.
        channel_size: number of image channels C.
        num_layers: number of stacked TransformerEncoder layers.
        embedding_dim: feature size of every token.
        num_heads: attention heads per encoder layer.
        hidden_dim: hidden width of the encoder MLPs.
        dropout_prob: dropout probability inside the encoder layers.
        num_classes: number of output logits.
        pretrain: if True the head uses a hidden layer with tanh; if False
            the head is a single linear layer (MLPHead with fine_tune=True).
    """

    def __init__(self, patch_size=16, image_size=224, channel_size=3,
                 num_layers=1, embedding_dim=768, num_heads=12, hidden_dim=3072,
                 dropout_prob=0.1, num_classes=10, pretrain=True):
        super().__init__()

        self.patch_size = patch_size
        self.channel_size = channel_size
        self.num_layers = num_layers
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.dropout_prob = dropout_prob
        self.num_classes = num_classes

        # get number of patches of the image
        # Integer floor division per side matches what unfold() produces in forward(),
        # including when image_size is not an exact multiple of patch_size.
        self.num_patches = (image_size // patch_size) ** 2

        # trainable linear projection for mapping dimension of patches (weight matrix E)
        self.W = nn.Parameter(
            torch.randn(patch_size * patch_size * channel_size, embedding_dim))

        # position embeddings (E_pos): one row for the class token plus one per patch
        self.pos_embedding = nn.Parameter(torch.randn(self.num_patches + 1, embedding_dim))

        # learnable class token embedding (x_class)
        # sized by embedding_dim so the model does not depend on a notebook-level global
        self.class_token = nn.Parameter(torch.rand(1, embedding_dim))

        # stack transformer encoder layers
        transformer_encoder_list = [
            TransformerEncoder(embedding_dim, num_heads, hidden_dim, dropout_prob)
            for _ in range(num_layers)]
        self.transformer_encoder_layers = nn.Sequential(*transformer_encoder_list)

        # mlp head: hidden layer while pre-training, single linear layer otherwise
        self.mlp_head = MLPHead(embedding_dim, num_classes, fine_tune=not pretrain)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes).

        x is expected as (batch, channel, height, width): the unfold calls
        below slice channels on axis 1 and patches on axes 2 and 3.
        """
        # get patch size and channel size
        P, C = self.patch_size, self.channel_size

        # split image into patches: (B, C, H, W) -> (B, 1, H/P, W/P, C, P, P)
        patches = x.unfold(1, C, C).unfold(2, P, P).unfold(3, P, P)
        # flatten every patch: (B, N, C * P * P)
        patches = patches.contiguous().view(patches.size(0), -1, C * P * P).float()

        # linearly embed patches
        patch_embeddings = torch.matmul(patches, self.W)

        # add class token (copied once per sample) in front of the patch tokens
        batch_size = patch_embeddings.shape[0]
        patch_embeddings = torch.cat((self.class_token.repeat(batch_size, 1, 1), patch_embeddings), 1)

        # add positional embedding (broadcast over the batch axis)
        patch_embeddings = patch_embeddings + self.pos_embedding

        # feed patch embeddings into a stack of Transformer encoders
        transformer_encoder_output = self.transformer_encoder_layers(patch_embeddings)

        # extract [class] token from encoder output
        output_class_token = transformer_encoder_output[:, 0]

        # pass token through mlp head for classification
        y = self.mlp_head(output_class_token)

        return y

In [99]:
# Build a VisionTransformer with its default hyper-parameters
# (16x16 patches, 224x224 images, 1 encoder layer, 768-d embeddings, 10 classes).
model = VisionTransformer()

In [100]:
# Place the model on the CPU; as the cell's last expression, the returned module is displayed.
model.to("cpu")

VisionTransformer(
  (transformer_encoder_layers): Sequential(
    (0): TransformerEncoder(
      (MSA): MultiHeadSelfAttention(
        (multi_head_attention): ModuleList(
          (0-11): 12 x SelfAttention(
            (W_q): Linear(in_features=768, out_features=64, bias=False)
            (W_k): Linear(in_features=768, out_features=64, bias=False)
            (W_v): Linear(in_features=768, out_features=64, bias=False)
          )
        )
      )
      (MLP): MultiLayerPerceptron(
        (mlp): Sequential(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=3072, out_features=3072, bias=True)
          (3): GELU(approximate='none')
          (4): Linear(in_features=3072, out_features=768, bias=True)
        )
      )
      (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout1): Dropout