In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as f

batch_size = 4
sequence_length = 64
embed_dim = 128

x = torch.randn(batch_size, sequence_length, embed_dim)
print("Shape of input is : ", x.shape)

Shape of input is :  torch.Size([4, 64, 128])


In [3]:
x.shape

torch.Size([4, 64, 128])

In [4]:
x.transpose(1, 2).shape

torch.Size([4, 128, 64])

In [5]:
similarity = (x @ x.transpose(1, 2))

In [6]:
similarity.shape

torch.Size([4, 64, 64])

In [7]:
similarity.var()

tensor(379.4189)

In [8]:
similarity = (x @ x.transpose(1, 2)) / (embed_dim ** 0.5)

similarity

tensor([[[ 1.2351e+01, -7.1957e-02,  7.1907e-01,  ...,  4.7201e-01,
           2.4169e+00,  4.6512e-01],
         [-7.1957e-02,  9.4296e+00, -1.9366e-01,  ...,  5.9690e-01,
          -1.0521e-02, -2.3480e-01],
         [ 7.1907e-01, -1.9366e-01,  1.2710e+01,  ..., -1.8016e+00,
           3.6835e-01,  3.8830e-01],
         ...,
         [ 4.7201e-01,  5.9690e-01, -1.8016e+00,  ...,  1.1194e+01,
           1.9364e+00,  4.6739e-01],
         [ 2.4169e+00, -1.0521e-02,  3.6835e-01,  ...,  1.9364e+00,
           9.5943e+00,  1.6906e+00],
         [ 4.6512e-01, -2.3480e-01,  3.8830e-01,  ...,  4.6739e-01,
           1.6906e+00,  9.2651e+00]],

        [[ 1.1608e+01,  2.2425e+00,  3.4092e-01,  ..., -1.4978e+00,
           1.7040e+00, -1.3147e+00],
         [ 2.2425e+00,  1.3480e+01,  3.3930e-01,  ...,  7.6517e-01,
          -5.5108e-01, -1.8567e-01],
         [ 3.4092e-01,  3.3930e-01,  1.0917e+01,  ...,  2.5809e-01,
          -7.4044e-01, -1.4849e-01],
         ...,
         [-1.4978e+00,  7

In [9]:
attn_mat = similarity.softmax(dim = 2)

attn_mat

tensor([[[9.9956e-01, 4.0245e-06, 8.8767e-06,  ..., 6.9335e-06,
          4.8483e-05, 6.8859e-06],
         [7.4093e-05, 9.9144e-01, 6.5603e-05,  ..., 1.4463e-04,
          7.8787e-05, 6.2958e-05],
         [6.2011e-06, 2.4893e-06, 9.9964e-01,  ..., 4.9862e-07,
          4.3666e-06, 4.4546e-06],
         ...,
         [2.2018e-05, 2.4947e-05, 2.2666e-06,  ..., 9.9849e-01,
          9.5228e-05, 2.1916e-05],
         [7.5734e-04, 6.6849e-05, 9.7641e-05,  ..., 4.6843e-04,
          9.9173e-01, 3.6634e-04],
         [1.4941e-04, 7.4198e-05, 1.3836e-04,  ..., 1.4975e-04,
          5.0885e-04, 9.9114e-01]],

        [[9.9882e-01, 8.5511e-05, 1.2769e-05,  ..., 2.0306e-06,
          4.9904e-05, 2.4387e-06],
         [1.3167e-05, 9.9986e-01, 1.9630e-06,  ..., 3.0052e-06,
          8.0581e-07, 1.1613e-06],
         [2.5486e-05, 2.5445e-05, 9.9831e-01,  ..., 2.3460e-05,
          8.6433e-06, 1.5623e-05],
         ...,
         [6.0843e-06, 5.8479e-05, 3.5219e-05,  ..., 9.9741e-01,
          2.620

In [10]:
attn_mat.shape, x.shape

(torch.Size([4, 64, 64]), torch.Size([4, 64, 128]))

In [11]:
output = (attn_mat @ x)

output.shape

torch.Size([4, 64, 128])

In [12]:
linear = nn.Linear(10, 20)

rand = torch.randn(4, 6, 10)

linear(rand).shape

torch.Size([4, 6, 20])

In [13]:
#Single head attention

class Attention(nn.Module):
    def __init__(self, embedding_dimension):

        super().__init__()

        self.embed_dim = embedding_dimension

        self.query = nn.Linear(self.embed_dim, self.embed_dim)
        self.key = nn.Linear(self.embed_dim, self.embed_dim)
        self.value = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(self, x):
        
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)

        similarity = (q @ k.transpose(1, 2)) / self.embed_dim ** 0.5
        attention = similarity.softmax(axis=2)
        output = attention @ v

rand = torch.randn(4, 64, 128)

attn = Attention(embedding_dimension=128)
attn(rand)

In [14]:
class MultiHeadAttention(nn.Module):
    def __init__(self, embedding_dimension, num_heads):
        super().__init__()

        self.embed_dim = embedding_dimension
        self.num_heads = num_heads

        self.head_dim = self.embed_dim // self.num_heads

        self.multihead_qkv = nn.ModuleList()

        for head in range(num_heads):
            qkv_proj = nn.ModuleDict(
                [
                  ["Q", nn.Linear(self.embed_dim, self.head_dim)],
                  ["K", nn.Linear(self.embed_dim, self.head_dim)],
                  ["V", nn.Linear(self.embed_dim, self.head_dim)]  
                ]
            )

            self.multihead_qkv.append(qkv_proj)

        self.proj = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(self, x):
        head_outs = []

        for head in self.multihead_qkv:

            q = head["Q"](x)
            k = head["K"](x)
            v = head["V"](x)

            similarity = (q @ k.transpose(1, 2)) / self.head_dim ** 0.5
            attention = similarity.softmax(axis=-1)
            output = attention @ v

            head_outs.append(output)

        head_outs = torch.cat(head_outs, dim=-1)

        out = self.proj(head_outs)

        return out
            

rand = torch.randn(4, 64, 128)
attn = MultiHeadAttention(128, 4)
attn(rand).shape

torch.Size([4, 64, 128])

In [15]:
fc = nn.Linear(10, 30)

tensor_1 = torch.randn(5, 10)
tensor_1_out = fc(tensor_1)

print("Inpur shape : ", tensor_1.shape, "Output shape : ", tensor_1_out.shape)

tensor_2 = torch.randn(5, 1, 2, 4, 10)
tensor_2_out = fc(tensor_2)

print("Inpur shape : ", tensor_2.shape, "Output shape : ", tensor_2_out.shape)

Inpur shape :  torch.Size([5, 10]) Output shape :  torch.Size([5, 30])
Inpur shape :  torch.Size([5, 1, 2, 4, 10]) Output shape :  torch.Size([5, 1, 2, 4, 30])


In [16]:
tensor = torch.rand(1, 8, 9)
fc = nn.Linear(9, 9)

q = fc(tensor)

chunk1, chun2, chunk3 = torch.chunk(q, 3, axis=-1)

In [17]:
a = torch.rand(1, 2, 6, 4)
b = torch.rand(1, 2, 4, 3)

a @ b

tensor([[[[1.3403, 0.5602, 0.7023],
          [2.3723, 0.7682, 1.0908],
          [2.1227, 0.6838, 1.0203],
          [1.6837, 0.5906, 0.9572],
          [1.6483, 0.5486, 1.1604],
          [0.9457, 0.1740, 0.3532]],

         [[1.1002, 1.5632, 0.9837],
          [0.9615, 1.4630, 1.0461],
          [1.0916, 1.3000, 0.6108],
          [0.4304, 0.7281, 0.4745],
          [0.6627, 0.9203, 0.4878],
          [0.5577, 0.9763, 0.7935]]]])

In [18]:
class SelfAttentionEncoder(nn.Module):
    def __init__(self, embed_dim, num_heads, attn_p = 0.0, proj_p = 0.0, bias = True):

        super().__init__()

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.query = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.key = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.value = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.attn_drop = nn.Dropout(attn_p)

        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = nn.Dropout(proj_p)


    def forward(self, x):
        batch_size, sequence_length, embed_dim = x.shape

        print(x.shape)

        q = self.query(x).reshape(batch_size, sequence_length, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.key(x).reshape(batch_size, sequence_length, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.value(x).reshape(batch_size, sequence_length, self.num_heads, self.head_dim).transpose(1, 2)
        
        attn = q @ k.transpose(-2, -1) * self.head_dim ** 0.5
        attn = attn.softmax(axis=-1)
        attn = self.attn_drop(attn)
        x = attn @ v

        x = x.transpose(1, 2).reshape(batch_size, sequence_length, embed_dim)
        print(x.shape)

        x = self.proj(x)
        x = self.proj_drop(x)


rand = torch.randn(4, 16, 128)
attn = SelfAttentionEncoder(128, 4)
attn(rand)

torch.Size([4, 16, 128])
torch.Size([4, 16, 128])


In [26]:
rand_attn = torch.rand(1, 6, 6)

attention_mask = torch.tensor([1, 1, 1, 1, 0, 0]).unsqueeze(0).bool()
print(attention_mask.shape)
# attention_mask = attention_mask.unsqueeze(1)
print(attention_mask.shape, rand_attn.shape)

# torch.softmax(rand_attn.masked_fill(~attention_mask, -float("inf")), axis=-1)

attention_mask.shape, rand_attn.shape

torch.Size([1, 6])
torch.Size([1, 6]) torch.Size([1, 6, 6])


(torch.Size([1, 6]), torch.Size([1, 6, 6]))