In [56]:
# CODE
import torch
from torch import nn
import math


In [57]:
class DotProductAttention(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, queries, keys, values):
        # queries: (B, N, D)
        # keys: (B, M, D)
        # values: (B, M, V)

        assert queries.shape[0] == keys.shape[0]
        assert queries.shape[0] == values.shape[0]
        assert queries.shape[-1] == keys.shape[-1]
        assert keys.shape[1] == values.shape[1]

        d = keys.shape[-1]

        # scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
        scores = queries @ keys.transpose(1, 2) / math.sqrt(d)
        attention_weights = nn.functional.softmax(scores, dim=-1)

        return attention_weights @ values

In [58]:
queries = torch.zeros((2, 1, 2))
keys = torch.zeros((2, 10, 2))
values = torch.zeros((2, 10, 4))

In [59]:
att = DotProductAttention()
att(queries, keys, values).shape

torch.Size([2, 1, 4])

In [60]:
#CODE
class MultiHeadAttention(nn.Module):
    def __init__(self, num_hiddens, num_heads):
        super().__init__()

        self.num_heads = num_heads
        self.num_hiddens = num_hiddens

        self.attention = DotProductAttention()

        self.W_q = nn.LazyLinear(num_hiddens)
        self.W_k = nn.LazyLinear(num_hiddens)
        self.W_v = nn.LazyLinear(num_hiddens)

        self.W_o = nn.LazyLinear(num_hiddens)

    def transpose_qkv(self, X):
        """Transposition for parallel computation of multiple attention heads."""
        # Shape of input X: (batch_size, no. of queries or key-value pairs,
        # num_hiddens). Shape of output X: (batch_size, no. of queries or
        # key-value pairs, num_heads, num_hiddens / num_heads)
        X = X.reshape(X.shape[0], X.shape[1], self.num_heads, -1)
        # Shape of output X: (batch_size, num_heads, no. of queries or key-value
        # pairs, num_hiddens / num_heads)
        X = X.permute(0, 2, 1, 3)
        # Shape of output: (batch_size * num_heads, no. of queries or key-value
        # pairs, num_hiddens / num_heads)
        return X.reshape(-1, X.shape[2], X.shape[3])
    
    def transpose_output(self, X):
        """Reverse the operation of transpose_qkv."""
        X = X.reshape(-1, self.num_heads, X.shape[1], X.shape[2])
        X = X.permute(0, 2, 1, 3)
        return X.reshape(X.shape[0], X.shape[1], -1)

    def forward(self, queries, keys, values):
        # qkv: (B, N_qkv, num_hiddens)

        q = self.transpose_qkv(self.W_q(queries))
        k = self.transpose_qkv(self.W_k(keys))
        v = self.transpose_qkv(self.W_v(values))

        output = self.transpose_output(self.attention(q, k, v))

        return self.W_o(output)

        

In [61]:
num_hiddens = 100
num_heads = 5

mha = MultiHeadAttention(num_hiddens, num_heads)

X = torch.ones((16, 5, 5))
mha(X, X, X).shape



torch.Size([16, 5, 100])

In [62]:
# CODE
class PositionalEncoding(nn.Module):
    def __init__(self, num_hiddens, max_tokens=1000):
        super().__init__()
        self.P = torch.zeros((1, max_tokens, num_hiddens))
        X = torch.arange(max_tokens, dtype=torch.float32).reshape(-1, 1) / torch.pow(10000, torch.arange(0, num_hiddens, 2, dtype=torch.float32)/num_hiddens)
        self.P[:, :, 0::2] = torch.sin(X)
        self.P[:, :, 1::2] = torch.cos(X)

    def forward(self, X):
        X = X + self.P[:, :X.shape[1], :]
        return X

In [63]:
encoding_dim, num_steps = 32, 60
pe = PositionalEncoding(encoding_dim)
X = torch.zeros((1, num_steps, encoding_dim))
XPE = pe(X)

print(X.shape, XPE.shape)

torch.Size([1, 60, 32]) torch.Size([1, 60, 32])


In [64]:
XPE

tensor([[[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  ...,  1.0000e+00,
           0.0000e+00,  1.0000e+00],
         [ 8.4147e-01,  5.4030e-01,  5.3317e-01,  ...,  1.0000e+00,
           1.7783e-04,  1.0000e+00],
         [ 9.0930e-01, -4.1615e-01,  9.0213e-01,  ...,  1.0000e+00,
           3.5566e-04,  1.0000e+00],
         ...,
         [ 4.3616e-01,  8.9987e-01,  5.9521e-01,  ...,  9.9984e-01,
           1.0136e-02,  9.9995e-01],
         [ 9.9287e-01,  1.1918e-01,  9.3199e-01,  ...,  9.9983e-01,
           1.0314e-02,  9.9995e-01],
         [ 6.3674e-01, -7.7108e-01,  9.8174e-01,  ...,  9.9983e-01,
           1.0492e-02,  9.9994e-01]]])

In [65]:
class FFN(nn.Module):
    def __init__(self, ffn_dim, num_hiddens):
        super().__init__()
        self.w1 = nn.LazyLinear(ffn_dim)
        self.w2 = nn.LazyLinear(num_hiddens)
        self.relu = nn.ReLU()
        
    def forward(self, X):
        return self.w2(self.relu(self.w1(X)))

class AddNorm(nn.Module):
    def __init__(self, num_hiddens):
        super().__init__()
        self.ln = nn.LayerNorm(num_hiddens)

    def forward(self, res, X):
        return self.ln(X + res)

In [66]:
class TransformerEncoderBlock(nn.Module):
    def __init__(self, num_hiddens, ffn_dim, num_heads):

        super().__init__()

        self.num_hiddens = num_hiddens
        self.ffn_dim = ffn_dim
        self.num_heads = num_heads

        self.attention = MultiHeadAttention(num_hiddens, num_heads)
        self.ffn = FFN(ffn_dim, num_hiddens)
        self.addnorm1 = AddNorm(num_hiddens)
        self.addnorm2 = AddNorm(num_hiddens)

    def forward(self, X):
        Y = self.addnorm1(X, self.attention(X, X, X))
        return self.addnorm2(Y, self.ffn(Y))

In [67]:
# CODE
class TransformerEncoder(nn.Module):
    def __init__(self, vocab_size, num_hiddens, num_blocks, ffn_dim, num_heads):
        super().__init__()

        self.vocab_size = vocab_size
        self.num_hiddens = num_hiddens
        
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.positional_encoding  = PositionalEncoding(num_hiddens)

        self.blocks = nn.ModuleList()

        for i in range(num_blocks):
            self.blocks.append(TransformerEncoderBlock(num_hiddens, ffn_dim, num_heads))

    def forward(self, X):
        X = self.embedding(X)
        X = self.positional_encoding(X)

        for i, block in enumerate(self.blocks):
            X = block(X)

        return X

In [68]:
encoder = TransformerEncoder(1000, 300, 6, 100, 5)

In [69]:
inp = torch.zeros((2, 100)).long()
encoder(inp).shape

torch.Size([2, 100, 300])

<h2>Your supertask!</h2>

Using this encoder implementation try to implement ViT. Have fun! :)

In [70]:
class PatchEmbedding(nn.Module):
    def __init__(self, img_size, patch_size, embedding_dim, channels=3):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.embedding_dim = embedding_dim

        self.num_patches = (img_size // patch_size) ** 2
        self.flatten_dim = patch_size * patch_size * channels

        self.linear_projection = nn.Linear(self.flatten_dim, embedding_dim)
    def forward(self, X):
        #X shape (B, C, H, W)
        X = X.unfold(2, self.patch_size, self.patch_size).unfold(3, self.patch_size, self.patch_size) #creating patches
        X = X.contiguous().view(X.size(0), X.size(1), X.size(2) * X.size(3), -1)
        X = X.permute(0, 2, 1, 3).contiguous().view(X.size(0), X.size(2), -1)
        return self.linear_projection(X)




In [71]:
class TransformerEncoderVit(nn.Module):
    def __init__(self, num_hiddens, num_blocks, ffn_dim, num_heads):
        super().__init__()
        self.positional_encoding = PositionalEncoding(num_hiddens)

        self.blocks = nn.ModuleList()
        for i in range(num_blocks):
            self.blocks.append(TransformerEncoderBlock(num_hiddens, ffn_dim, num_heads))

    def forward(self, X):
        X = self.positional_encoding(X)
        for block in self.blocks:
            X = block(X)
        return X

In [72]:
class ViT(nn.Module):
    def __init__(self, img_size, patch_size, embedding_dim, num_blocks, ffn_dim, num_heads, channels=3):
        super().__init__()
        self.patch_embedding = PatchEmbedding(img_size, patch_size, embedding_dim, channels)
        self.transformer_encoder = TransformerEncoderVit(embedding_dim, num_blocks, ffn_dim, num_heads)

    def forward(self, X):
        X = self.patch_embedding(X)
        X = self.transformer_encoder(X)
        return X

In [73]:
img_size = 224  
patch_size = 16
channels = 3   
embedding_dim = 768
num_blocks = 12
ffn_dim = 2048
num_heads = 12

vit = ViT(img_size, patch_size, embedding_dim, num_blocks=num_blocks, ffn_dim=ffn_dim, num_heads=num_heads, channels=channels)
dummy_image = torch.randn(1, channels, img_size, img_size)
output = vit(dummy_image)

print("Output shape:", output.shape)


Output shape: torch.Size([1, 196, 768])
