In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as f

batch_size = 4
sequence_length = 64
embed_dim = 128

x = torch.randn(batch_size, sequence_length, embed_dim)
print("Shape of input is : ", x.shape)

Shape of input is :  torch.Size([4, 64, 128])


In [2]:
x.shape

torch.Size([4, 64, 128])

In [3]:
x.transpose(1, 2).shape

torch.Size([4, 128, 64])

In [4]:
similarity = (x @ x.transpose(1, 2))

In [5]:
similarity.shape

torch.Size([4, 64, 64])

In [6]:
similarity.var()

tensor(394.5553)

In [7]:
similarity = (x @ x.transpose(1, 2)) / (embed_dim ** 0.5)

similarity

tensor([[[ 1.1178e+01,  5.7155e-01,  1.4334e+00,  ..., -2.2718e+00,
           1.1709e+00, -1.4546e+00],
         [ 5.7155e-01,  1.1902e+01,  1.7159e+00,  ...,  2.0412e-02,
           1.1101e+00, -8.6087e-01],
         [ 1.4334e+00,  1.7159e+00,  1.0122e+01,  ..., -6.0288e-01,
           9.0337e-01, -4.3254e-01],
         ...,
         [-2.2718e+00,  2.0412e-02, -6.0288e-01,  ...,  1.2256e+01,
          -5.6293e-01,  1.1852e+00],
         [ 1.1709e+00,  1.1101e+00,  9.0337e-01,  ..., -5.6293e-01,
           9.9783e+00, -1.3724e+00],
         [-1.4546e+00, -8.6087e-01, -4.3254e-01,  ...,  1.1852e+00,
          -1.3724e+00,  1.1909e+01]],

        [[ 9.8203e+00, -1.0809e+00,  1.7134e+00,  ..., -1.8418e+00,
          -7.9413e-01, -2.6518e-01],
         [-1.0809e+00,  1.4634e+01, -1.1969e+00,  ...,  1.6224e+00,
          -7.3279e-01,  6.2115e-01],
         [ 1.7134e+00, -1.1969e+00,  1.0856e+01,  ...,  1.5530e+00,
           7.0712e-01,  1.8886e+00],
         ...,
         [-1.8418e+00,  1

In [8]:
attn_mat = similarity.softmax(dim = 2)

attn_mat

tensor([[[9.9830e-01, 2.4720e-05, 5.8524e-05,  ..., 1.4395e-06,
          4.5013e-05, 3.2593e-06],
         [1.1991e-05, 9.9931e-01, 3.7659e-05,  ..., 6.9105e-06,
          2.0548e-05, 2.8627e-06],
         [1.6773e-04, 2.2250e-04, 9.9512e-01,  ..., 2.1892e-05,
          9.8727e-05, 2.5957e-05],
         ...,
         [4.9053e-07, 4.8544e-06, 2.6028e-06,  ..., 9.9959e-01,
          2.7089e-06, 1.5559e-05],
         [1.4898e-04, 1.4020e-04, 1.1402e-04,  ..., 2.6312e-05,
          9.9579e-01, 1.1712e-05],
         [1.5702e-06, 2.8430e-06, 4.3632e-06,  ..., 2.1998e-05,
          1.7047e-06, 9.9935e-01]],

        [[9.9475e-01, 1.8340e-05, 2.9990e-04,  ..., 8.5696e-06,
          2.4431e-05, 4.1463e-05],
         [1.4962e-07, 9.9996e-01, 1.3323e-07,  ..., 2.2336e-06,
          2.1192e-07, 8.2069e-07],
         [1.0685e-04, 5.8183e-06, 9.9815e-01,  ..., 9.1012e-05,
          3.9059e-05, 1.2731e-04],
         ...,
         [2.2458e-06, 7.1751e-05, 6.6945e-05,  ..., 9.9838e-01,
          3.213

In [9]:
attn_mat.shape, x.shape

(torch.Size([4, 64, 64]), torch.Size([4, 64, 128]))

In [10]:
output = (attn_mat @ x)

output.shape

torch.Size([4, 64, 128])

In [11]:
linear = nn.Linear(10, 20)

rand = torch.randn(4, 6, 10)

linear(rand).shape

torch.Size([4, 6, 20])

In [12]:
#Single head attention

class Attention(nn.Module):
    def __init__(self, embedding_dimension):

        super().__init__()

        self.embed_dim = embedding_dimension

        self.query = nn.Linear(self.embed_dim, self.embed_dim)
        self.key = nn.Linear(self.embed_dim, self.embed_dim)
        self.value = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(self, x):
        
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)

        similarity = (q @ k.transpose(1, 2)) / self.embed_dim ** 0.5
        attention = similarity.softmax(axis=2)
        output = attention @ v

rand = torch.randn(4, 64, 128)

attn = Attention(embedding_dimension=128)
attn(rand)

In [13]:
class MultiHeadAttention(nn.Module):
    def __init__(self, embedding_dimension, num_heads):
        super().__init__()

        self.embed_dim = embedding_dimension
        self.num_heads = num_heads

        self.head_dim = self.embed_dim // self.num_heads

        self.multihead_qkv = nn.ModuleList()

        for head in range(num_heads):
            qkv_proj = nn.ModuleDict(
                [
                  ["Q", nn.Linear(self.embed_dim, self.head_dim)],
                  ["K", nn.Linear(self.embed_dim, self.head_dim)],
                  ["V", nn.Linear(self.embed_dim, self.head_dim)]  
                ]
            )

            self.multihead_qkv.append(qkv_proj)

        self.proj = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(self, x):
        head_outs = []

        for head in self.multihead_qkv:

            q = head["Q"](x)
            k = head["K"](x)
            v = head["V"](x)

            similarity = (q @ k.transpose(1, 2)) / self.head_dim ** 0.5
            attention = similarity.softmax(axis=-1)
            output = attention @ v

            head_outs.append(output)

        head_outs = torch.cat(head_outs, dim=-1)

        out = self.proj(head_outs)

        return out
            

rand = torch.randn(4, 64, 128)
attn = MultiHeadAttention(128, 4)
attn(rand).shape

torch.Size([4, 64, 128])

In [14]:
fc = nn.Linear(10, 30)

tensor_1 = torch.randn(5, 10)
tensor_1_out = fc(tensor_1)

print("Inpur shape : ", tensor_1.shape, "Output shape : ", tensor_1_out.shape)

tensor_2 = torch.randn(5, 1, 2, 4, 10)
tensor_2_out = fc(tensor_2)

print("Inpur shape : ", tensor_2.shape, "Output shape : ", tensor_2_out.shape)

Inpur shape :  torch.Size([5, 10]) Output shape :  torch.Size([5, 30])
Inpur shape :  torch.Size([5, 1, 2, 4, 10]) Output shape :  torch.Size([5, 1, 2, 4, 30])


In [15]:
tensor = torch.rand(1, 8, 9)
fc = nn.Linear(9, 9)

q = fc(tensor)

chunk1, chun2, chunk3 = torch.chunk(q, 3, axis=-1)

In [16]:
a = torch.rand(1, 2, 6, 4)
b = torch.rand(1, 2, 4, 3)

a @ b

tensor([[[[1.2597, 0.7285, 0.6481],
          [1.2363, 0.8279, 0.6588],
          [0.8187, 0.6148, 0.5989],
          [0.8854, 0.4790, 0.4740],
          [1.2740, 1.0117, 0.7970],
          [1.1715, 0.8315, 0.7771]],

         [[1.5752, 1.3709, 0.8179],
          [1.2145, 0.9876, 0.6341],
          [1.9579, 1.7720, 0.8514],
          [1.6245, 1.6110, 1.0023],
          [1.5703, 1.2739, 0.5433],
          [0.9940, 1.1083, 0.7120]]]])

In [17]:
class SelfAttentionEncoder(nn.Module):
    def __init__(self, embed_dim, num_heads, attn_p = 0.0, proj_p = 0.0, bias = True):

        super().__init__()

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.query = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.key = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.value = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.attn_drop = nn.Dropout(attn_p)

        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = nn.Dropout(proj_p)


    def forward(self, x):
        batch_size, sequence_length, embed_dim = x.shape

        print(x.shape)

        q = self.query(x).reshape(batch_size, sequence_length, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.key(x).reshape(batch_size, sequence_length, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.value(x).reshape(batch_size, sequence_length, self.num_heads, self.head_dim).transpose(1, 2)
        
        attn = q @ k.transpose(-2, -1) * self.head_dim ** 0.5
        attn = attn.softmax(axis=-1)
        attn = self.attn_drop(attn)
        x = attn @ v

        x = x.transpose(1, 2).reshape(batch_size, sequence_length, embed_dim)
        print(x.shape)

        x = self.proj(x)
        x = self.proj_drop(x)


rand = torch.randn(4, 16, 128)
attn = SelfAttentionEncoder(128, 4)
attn(rand)

torch.Size([4, 16, 128])
torch.Size([4, 16, 128])


In [18]:
rand_attn = torch.rand(1, 6, 6)

attention_mask = torch.tensor([1, 1, 1, 1, 0, 0]).unsqueeze(0).bool()
print(attention_mask.shape)
attention_mask = attention_mask.unsqueeze(1).repeat(1, 6, 1)
print(attention_mask.shape, rand_attn.shape)

# torch.softmax(rand_attn.masked_fill(~attention_mask, -float("inf")), axis=-1)

attention_mask.shape, rand_attn.shape

torch.Size([1, 6])
torch.Size([1, 6, 6]) torch.Size([1, 6, 6])


(torch.Size([1, 6, 6]), torch.Size([1, 6, 6]))

In [None]:
rand_attn = torch.rand(1, 2, 6, 6)

attention_mask = torch.tensor([1, 1, 1, 1, 0, 0]).unsqueeze(0).bool()
attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)
attention_mask = attention_mask.repeat(1, 2, 6, 1)
print(attention_mask)

# rand_attn.masked_fill_(~attention_mask, float("-inf"))

tensor([[[[ True,  True,  True,  True, False, False],
          [ True,  True,  True,  True, False, False],
          [ True,  True,  True,  True, False, False],
          [ True,  True,  True,  True, False, False],
          [ True,  True,  True,  True, False, False],
          [ True,  True,  True,  True, False, False]],

         [[ True,  True,  True,  True, False, False],
          [ True,  True,  True,  True, False, False],
          [ True,  True,  True,  True, False, False],
          [ True,  True,  True,  True, False, False],
          [ True,  True,  True,  True, False, False],
          [ True,  True,  True,  True, False, False]]]])


In [36]:
class SelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, attn_p = 0.0, proj_p = 0.0, bias = True):

        super().__init__()

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.query = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.key = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.value = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.attn_drop = nn.Dropout(attn_p)

        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = nn.Dropout(proj_p)


    def forward(self, x, attention_mask = None):
        batch_size, sequence_length, embed_dim = x.shape

        print(x.shape)

        q = self.query(x).reshape(batch_size, sequence_length, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.key(x).reshape(batch_size, sequence_length, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.value(x).reshape(batch_size, sequence_length, self.num_heads, self.head_dim).transpose(1, 2)

        attn = q @ k.transpose(-2, -1) * self.head_dim ** 0.5
        
        #############################################

        if attention_mask is not None:
            print("Attention mask")

            attention_mask = attention_mask.unsqueeze(1).unsqueeze(1) # .repeat(1, 1, sequence_length, 1) Optional
            print(attention_mask[0])
            print(attn[0])

            attn = attn.masked_fill_(~attention_mask, float("-inf"))
            print(attn)

        #############################################

        attn = attn.softmax(axis=-1)
        print(attn)
        attn = self.attn_drop(attn)
        x = attn @ v
        print(attn.shape)

        x = x.transpose(1, 2).reshape(batch_size, sequence_length, embed_dim)
        print(x.shape)

        x = self.proj(x)
        x = self.proj_drop(x)


seq_lens = [3, 5, 4]
embed_dim = 9
num_heads = 3

a = SelfAttention(embed_dim, num_heads)

rand = torch.randn(len(seq_lens), max(seq_lens), embed_dim)

masks = torch.nn.utils.rnn.pad_sequence([torch.ones(l) for l in seq_lens], batch_first=True, padding_value=0).bool()
print("Attention mask : ")
print(masks)

output = a(rand, masks)

Attention mask : 
tensor([[ True,  True,  True, False, False],
        [ True,  True,  True,  True,  True],
        [ True,  True,  True,  True, False]])
torch.Size([3, 5, 9])
Attention mask
tensor([[[ True,  True,  True, False, False]]])
tensor([[[ 0.6316,  1.0843,  1.2473,  0.8408,  0.1527],
         [-1.8561, -1.3194,  0.8943,  2.3284,  1.4364],
         [-0.2901, -0.4388, -0.5369, -0.3559, -0.0659],
         [-0.8728, -0.5808, -0.7983, -0.2802,  0.0417],
         [ 1.0928,  0.4620, -1.7278, -2.6655, -1.3850]],

        [[-0.4862,  2.3263,  0.9369, -0.4796, -1.5877],
         [ 0.3433, -2.0136,  1.7547,  1.0691,  0.0433],
         [ 0.4617, -1.6016, -1.8754, -0.0097,  1.6980],
         [ 1.8979, -2.5033, -1.4901, -0.0056,  1.4520],
         [-0.2775, -0.2503, -0.3614,  0.0843,  0.4318]],

        [[ 0.8664,  1.0094,  0.8077, -0.3590, -0.6604],
         [ 1.0388,  2.4786, -0.1004,  2.4877,  0.2935],
         [ 0.1927,  1.3093, -0.6028,  1.4122,  0.5153],
         [-0.8883, -0.5865, -

In [None]:
seq_len = 8

ones = torch.ones(seq_len, seq_len)
causal_mask = torch.tril(ones).bool()

padding_mask = torch.tensor([1, 1, 1, 1, 0, 0, 0, 0]).bool()
padding_mask = padding_mask.unsqueeze(0).repeat(seq_len, 1)

causal_mask = causal_mask.masked_fill_(~padding_mask, 0)
causal_mask 

tensor([[ True, False, False, False, False, False, False, False],
        [ True,  True, False, False, False, False, False, False],
        [ True,  True,  True, False, False, False, False, False],
        [ True,  True,  True,  True, False, False, False, False],
        [ True,  True,  True,  True, False, False, False, False],
        [ True,  True,  True,  True, False, False, False, False],
        [ True,  True,  True,  True, False, False, False, False],
        [ True,  True,  True,  True, False, False, False, False]])

In [55]:
class CausalSelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, causal= True, attn_p = 0.0, proj_p = 0.0, bias = True):

        super().__init__()

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.causal = causal

        self.query = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.key = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.value = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.attn_drop = nn.Dropout(attn_p)

        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = nn.Dropout(proj_p)


    def forward(self, x, attention_mask = None):
        batch_size, sequence_length, embed_dim = x.shape


        q = self.query(x).reshape(batch_size, sequence_length, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.key(x).reshape(batch_size, sequence_length, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.value(x).reshape(batch_size, sequence_length, self.num_heads, self.head_dim).transpose(1, 2)

        attn = q @ k.transpose(-2, -1) * self.head_dim ** 0.5

        if attention_mask is not None:
                attention_mask = attention_mask.unsqueeze(1).unsqueeze(1) # .repeat(1, 1, sequence_length, 1) Optional
        
        if self.causal:

            ones = torch.ones((sequence_length, sequence_length), device=attn.device)
            causal_mask = torch.tril(ones).bool()
            causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)
            print(causal_mask)

            if attention_mask is not None:

                causal_mask = causal_mask.repeat(batch_size, 1, 1, 1)
                causal_mask = causal_mask.masked_fill_(~attention_mask, False)
            

            attn = attn.masked_fill_(~causal_mask, float("-inf"))

        #############################################

        attn = attn.softmax(axis=-1)
        attn = self.attn_drop(attn)
        x = attn @ v

        x = x.transpose(1, 2).reshape(batch_size, sequence_length, embed_dim)

        x = self.proj(x)
        x = self.proj_drop(x)


seq_lens = [3, 5, 4]
embed_dim = 9
num_heads = 3

a = SelfAttention(embed_dim, num_heads)

rand = torch.randn(len(seq_lens), max(seq_lens), embed_dim)

masks = torch.nn.utils.rnn.pad_sequence([torch.ones(l) for l in seq_lens], batch_first=True, padding_value=0).bool()
print("Attention mask : ")
print(masks)

output = a(rand, masks)

Attention mask : 
tensor([[ True,  True,  True, False, False],
        [ True,  True,  True,  True,  True],
        [ True,  True,  True,  True, False]])
tensor([[[[ True, False, False, False, False],
          [ True,  True, False, False, False],
          [ True,  True,  True, False, False],
          [ True,  True,  True,  True, False],
          [ True,  True,  True,  True,  True]]]])
Attention mask
tensor([[[[ True, False, False, False, False],
          [ True,  True, False, False, False],
          [ True,  True,  True, False, False],
          [ True,  True,  True, False, False],
          [ True,  True,  True, False, False]]],


        [[[ True, False, False, False, False],
          [ True,  True, False, False, False],
          [ True,  True,  True, False, False],
          [ True,  True,  True,  True, False],
          [ True,  True,  True,  True,  True]]],


        [[[ True, False, False, False, False],
          [ True,  True, False, False, False],
          [ True,  Tr