In [29]:
import torch
import torch.nn as nn
import torch.nn.functional as f

# Dimensions of the toy input: batch, tokens per sequence, features per token.
batch_size, sequence_length, embed_dim = 4, 64, 128

# Standard-normal random tensor laid out as (batch, sequence, embedding).
x = torch.randn((batch_size, sequence_length, embed_dim))
print("Shape of input is : ", x.shape)

Shape of input is :  torch.Size([4, 64, 128])


In [30]:
# Display the size of the input tensor.
x.size()

torch.Size([4, 64, 128])

In [31]:
# Swapping dims 1 and 2 exchanges the last two axes while leaving dim 0 alone.
x.transpose(2, 1).size()

torch.Size([4, 128, 64])

In [32]:
# Unscaled dot products between every pair of positions, computed per batch item.
similarity = torch.matmul(x, x.transpose(1, 2))

In [33]:
# Display the size of the similarity tensor.
similarity.size()

torch.Size([4, 64, 64])

In [34]:
# Variance over all entries of the similarity tensor.
torch.var(similarity)

tensor(379.6125)

In [35]:
# Same pairwise dot products as before, now divided by sqrt(embed_dim).
scale = embed_dim ** 0.5
similarity = torch.matmul(x, x.transpose(1, 2)) / scale

similarity

tensor([[[ 1.2232e+01, -1.6930e-01, -3.2264e-01,  ..., -5.7459e-01,
           1.8301e+00,  1.3866e+00],
         [-1.6930e-01,  1.1490e+01, -8.9926e-01,  ...,  9.0709e-01,
          -4.6978e-01,  1.2035e+00],
         [-3.2264e-01, -8.9926e-01,  1.1293e+01,  ...,  2.7990e-01,
          -1.3360e+00, -2.1876e+00],
         ...,
         [-5.7459e-01,  9.0709e-01,  2.7990e-01,  ...,  1.3370e+01,
           1.9783e+00, -3.9287e-01],
         [ 1.8301e+00, -4.6978e-01, -1.3360e+00,  ...,  1.9783e+00,
           1.1668e+01,  7.7543e-01],
         [ 1.3866e+00,  1.2035e+00, -2.1876e+00,  ..., -3.9287e-01,
           7.7543e-01,  1.2696e+01]],

        [[ 1.2441e+01,  6.4279e-01, -8.4803e-01,  ...,  2.0921e+00,
          -1.3674e+00, -1.8692e-02],
         [ 6.4279e-01,  1.3137e+01, -7.0683e-01,  ...,  6.0690e-01,
           6.4403e-01, -1.6587e+00],
         [-8.4803e-01, -7.0683e-01,  1.3899e+01,  ..., -5.2183e-01,
           5.7990e-01,  4.4243e-01],
         ...,
         [ 2.0921e+00,  6

In [36]:
# Softmax along dim 2, so the entries along that dimension sum to 1.
attn_mat = torch.softmax(similarity, dim=2)

attn_mat

tensor([[[9.9924e-01, 4.1118e-06, 3.5273e-06,  ..., 2.7417e-06,
          3.0363e-05, 1.9487e-05],
         [8.6280e-06, 9.9888e-01, 4.1581e-06,  ..., 2.5315e-05,
          6.3887e-06, 3.4050e-05],
         [9.0126e-06, 5.0632e-06, 9.9903e-01,  ..., 1.6464e-05,
          3.2714e-06, 1.3961e-06],
         ...,
         [8.7881e-07, 3.8671e-06, 2.0654e-06,  ..., 9.9981e-01,
          1.1287e-05, 1.0539e-06],
         [5.3332e-05, 5.3478e-06, 2.2489e-06,  ..., 6.1852e-05,
          9.9915e-01, 1.8577e-05],
         [1.2247e-05, 1.0198e-05, 3.4341e-07,  ..., 2.0665e-06,
          6.6469e-06, 9.9964e-01]],

        [[9.9952e-01, 7.5147e-06, 1.6922e-06,  ..., 3.2014e-05,
          1.0066e-06, 3.8782e-06],
         [3.7468e-06, 9.9986e-01, 9.7168e-07,  ..., 3.6147e-06,
          3.7514e-06, 3.7509e-07],
         [3.9375e-07, 4.5346e-07, 9.9989e-01,  ..., 5.4562e-07,
          1.6420e-06, 1.4311e-06],
         ...,
         [2.3052e-05, 5.2203e-06, 1.6885e-06,  ..., 9.9964e-01,
          2.484

In [37]:
# Show both sizes side by side before multiplying the two tensors.
(attn_mat.size(), x.size())

(torch.Size([4, 64, 64]), torch.Size([4, 64, 128]))

In [38]:
# Batched matrix product of the attention weights with the input.
output = torch.matmul(attn_mat, x)

output.size()

torch.Size([4, 64, 128])

In [40]:
# nn.Linear acts on the last dimension only; leading dimensions pass through.
linear = nn.Linear(in_features=10, out_features=20)

rand = torch.randn((4, 6, 10))

linear(rand).size()

torch.Size([4, 6, 20])

In [57]:
#Single head attention

class Attention(nn.Module):
    """Single-head scaled dot-product self-attention.

    Queries, keys and values are each produced from the same input by a
    separate ``nn.Linear`` projection mapping ``embed_dim -> embed_dim``.

    Args:
        embedding_dimension: Size of the input's last dimension; also the size
            of the projected queries, keys and values.
    """

    def __init__(self, embedding_dimension):

        super().__init__()

        self.embed_dim = embedding_dimension

        self.query = nn.Linear(self.embed_dim, self.embed_dim)
        self.key = nn.Linear(self.embed_dim, self.embed_dim)
        self.value = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(self, x):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor whose last dimension is ``embed_dim`` and whose
                second-to-last dimension indexes sequence positions,
                e.g. ``(batch, seq, embed_dim)``.

        Returns:
            Tensor of the same shape as ``x``: for each position, the
            attention-weighted sum of the value vectors.
        """
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)

        # Transposing the last two dims (rather than hard-coding 1 and 2) gives
        # the same result for 3-D input and also supports extra leading dims.
        similarity = (q @ k.transpose(-2, -1)) / self.embed_dim ** 0.5
        # Normalise over the key axis so each query's weights sum to 1.
        attention = similarity.softmax(dim=-1)

        # The result must be returned; without it the module yields None.
        return attention @ v

# Run the single-head module on a random (4, 64, 128) batch.
# The input is drawn before the module is built, keeping the RNG order fixed.
rand = torch.randn((4, 64, 128))

attn = Attention(128)
attn(rand)

In [63]:
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention built from independent per-head projections.

    Each head owns its own ``Q``/``K``/``V`` linear layers mapping
    ``embed_dim -> head_dim`` (``head_dim = embed_dim // num_heads``). The
    head outputs are concatenated on the last dimension and passed through a
    final ``embed_dim -> embed_dim`` projection.

    Args:
        embedding_dimension: Size of the input's last dimension.
        num_heads: Number of attention heads; must be positive and divide
            ``embedding_dimension`` evenly.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``embedding_dimension``.
    """

    def __init__(self, embedding_dimension, num_heads):
        super().__init__()

        # Without this check the concatenated heads would have width
        # head_dim * num_heads != embed_dim and only fail later inside `proj`.
        if num_heads <= 0 or embedding_dimension % num_heads != 0:
            raise ValueError(
                f"embedding_dimension ({embedding_dimension}) must be divisible "
                f"by a positive num_heads ({num_heads})"
            )

        self.embed_dim = embedding_dimension
        self.num_heads = num_heads

        self.head_dim = self.embed_dim // self.num_heads

        # One ModuleDict of Q/K/V projections per head, created in Q, K, V order.
        self.multihead_qkv = nn.ModuleList(
            [
                nn.ModuleDict(
                    [
                        ["Q", nn.Linear(self.embed_dim, self.head_dim)],
                        ["K", nn.Linear(self.embed_dim, self.head_dim)],
                        ["V", nn.Linear(self.embed_dim, self.head_dim)],
                    ]
                )
                for _ in range(num_heads)
            ]
        )

        self.proj = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(self, x):
        """Apply multi-head self-attention to ``x``.

        Args:
            x: Tensor whose last dimension is ``embed_dim`` and whose
                second-to-last dimension indexes sequence positions,
                e.g. ``(batch, seq, embed_dim)``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        # Queries and keys are head_dim wide, so dot products are scaled by
        # sqrt(head_dim) rather than sqrt(embed_dim).
        scale = self.head_dim ** 0.5

        head_outs = []

        for head in self.multihead_qkv:

            q = head["Q"](x)
            k = head["K"](x)
            v = head["V"](x)

            similarity = (q @ k.transpose(-2, -1)) / scale
            attention = similarity.softmax(dim=-1)

            head_outs.append(attention @ v)

        # Concatenating num_heads chunks of head_dim restores width embed_dim.
        head_outs = torch.cat(head_outs, dim=-1)

        return self.proj(head_outs)
            

# Run the 4-head module on a random (4, 64, 128) batch and show the output size.
rand = torch.randn((4, 64, 128))
attn = MultiHeadAttention(embedding_dimension=128, num_heads=4)
attn(rand).size()

torch.Size([4, 64, 128])

In [65]:
# The same nn.Linear layer applied to a 2-D input and to a 5-D input.
fc = nn.Linear(in_features=10, out_features=30)

# Both inputs are drawn up front, in the same order as they are used.
tensor_1 = torch.randn((5, 10))
tensor_2 = torch.randn((5, 1, 2, 4, 10))

tensor_1_out = fc(tensor_1)
tensor_2_out = fc(tensor_2)

print("Inpur shape : ", tensor_1.shape, "Output shape : ", tensor_1_out.shape)
print("Inpur shape : ", tensor_2.shape, "Output shape : ", tensor_2_out.shape)

Inpur shape :  torch.Size([5, 10]) Output shape :  torch.Size([5, 30])
Inpur shape :  torch.Size([5, 1, 2, 4, 10]) Output shape :  torch.Size([5, 1, 2, 4, 30])
