In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as f

batch_size = 4
sequence_length = 64
embed_dim = 128

x = torch.randn(batch_size, sequence_length, embed_dim)
print("Shape of input is : ", x.shape)

Shape of input is :  torch.Size([4, 64, 128])


In [2]:
x.shape

torch.Size([4, 64, 128])

In [3]:
x.transpose(1, 2).shape

torch.Size([4, 128, 64])

In [4]:
similarity = (x @ x.transpose(1, 2))

In [5]:
similarity.shape

torch.Size([4, 64, 64])

In [6]:
similarity.var()

tensor(383.9728)

In [7]:
similarity = (x @ x.transpose(1, 2)) / (embed_dim ** 0.5)

similarity

tensor([[[10.4182, -0.3275, -1.2412,  ..., -1.0767, -0.7180, -0.2053],
         [-0.3275, 10.6836,  0.8792,  ..., -0.6690, -0.9858, -0.5184],
         [-1.2412,  0.8792,  9.5799,  ...,  0.3521, -0.6487, -0.7115],
         ...,
         [-1.0767, -0.6690,  0.3521,  ...,  9.4699, -0.2579,  0.5086],
         [-0.7180, -0.9858, -0.6487,  ..., -0.2579, 11.2757, -1.2471],
         [-0.2053, -0.5184, -0.7115,  ...,  0.5086, -1.2471, 12.5451]],

        [[12.4161,  0.9529,  0.4513,  ...,  0.0627, -0.3439,  1.1411],
         [ 0.9529, 11.4158, -1.2582,  ...,  0.7327, -0.7804,  0.6689],
         [ 0.4513, -1.2582, 10.3306,  ...,  0.2178, -2.4212,  0.8338],
         ...,
         [ 0.0627,  0.7327,  0.2178,  ..., 12.3123, -0.5745,  1.4630],
         [-0.3439, -0.7804, -2.4212,  ..., -0.5745, 12.5936, -0.2378],
         [ 1.1411,  0.6689,  0.8338,  ...,  1.4630, -0.2378, 10.8373]],

        [[10.8862,  0.1236, -1.9391,  ..., -0.6725, -0.1161,  0.2866],
         [ 0.1236, 13.6696,  0.5223,  ...,  0

In [8]:
attn_mat = similarity.softmax(dim = 2)

attn_mat

tensor([[[9.9681e-01, 2.1470e-05, 8.6094e-06,  ..., 1.0149e-05,
          1.4528e-05, 2.4260e-05],
         [1.6480e-05, 9.9764e-01, 5.5082e-05,  ..., 1.1712e-05,
          8.5317e-06, 1.3616e-05],
         [1.9831e-05, 1.6529e-04, 9.9285e-01,  ..., 9.7569e-05,
          3.5865e-05, 3.3681e-05],
         ...,
         [2.6140e-05, 3.9298e-05, 1.0910e-04,  ..., 9.9455e-01,
          5.9280e-05, 1.2759e-04],
         [6.1759e-06, 4.7251e-06, 6.6192e-06,  ..., 9.7842e-06,
          9.9890e-01, 3.6386e-06],
         [2.9000e-06, 2.1204e-06, 1.7479e-06,  ..., 5.9215e-06,
          1.0231e-06, 9.9958e-01]],

        [[9.9964e-01, 1.0506e-05, 6.3626e-06,  ..., 4.3139e-06,
          2.8726e-06, 1.2683e-05],
         [2.8537e-05, 9.9865e-01, 3.1271e-06,  ..., 2.2898e-05,
          5.0429e-06, 2.1482e-05],
         [5.1061e-05, 9.2392e-06, 9.9678e-01,  ..., 4.0428e-05,
          2.8879e-06, 7.4852e-05],
         ...,
         [4.7843e-06, 9.3494e-06, 5.5869e-06,  ..., 9.9939e-01,
          2.529

In [9]:
attn_mat.shape, x.shape

(torch.Size([4, 64, 64]), torch.Size([4, 64, 128]))

In [10]:
output = (attn_mat @ x)

output.shape

torch.Size([4, 64, 128])

In [11]:
linear = nn.Linear(10, 20)

rand = torch.randn(4, 6, 10)

linear(rand).shape

torch.Size([4, 6, 20])

In [12]:
#Single head attention

class Attention(nn.Module):
    def __init__(self, embedding_dimension):

        super().__init__()

        self.embed_dim = embedding_dimension

        self.query = nn.Linear(self.embed_dim, self.embed_dim)
        self.key = nn.Linear(self.embed_dim, self.embed_dim)
        self.value = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(self, x):
        
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)

        similarity = (q @ k.transpose(1, 2)) / self.embed_dim ** 0.5
        attention = similarity.softmax(axis=2)
        output = attention @ v

rand = torch.randn(4, 64, 128)

attn = Attention(embedding_dimension=128)
attn(rand)

In [13]:
class MultiHeadAttention(nn.Module):
    def __init__(self, embedding_dimension, num_heads):
        super().__init__()

        self.embed_dim = embedding_dimension
        self.num_heads = num_heads

        self.head_dim = self.embed_dim // self.num_heads

        self.multihead_qkv = nn.ModuleList()

        for head in range(num_heads):
            qkv_proj = nn.ModuleDict(
                [
                  ["Q", nn.Linear(self.embed_dim, self.head_dim)],
                  ["K", nn.Linear(self.embed_dim, self.head_dim)],
                  ["V", nn.Linear(self.embed_dim, self.head_dim)]  
                ]
            )

            self.multihead_qkv.append(qkv_proj)

        self.proj = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(self, x):
        head_outs = []

        for head in self.multihead_qkv:

            q = head["Q"](x)
            k = head["K"](x)
            v = head["V"](x)

            similarity = (q @ k.transpose(1, 2)) / self.head_dim ** 0.5
            attention = similarity.softmax(axis=-1)
            output = attention @ v

            head_outs.append(output)

        head_outs = torch.cat(head_outs, dim=-1)

        out = self.proj(head_outs)

        return out
            

rand = torch.randn(4, 64, 128)
attn = MultiHeadAttention(128, 4)
attn(rand).shape

torch.Size([4, 64, 128])

In [14]:
fc = nn.Linear(10, 30)

tensor_1 = torch.randn(5, 10)
tensor_1_out = fc(tensor_1)

print("Inpur shape : ", tensor_1.shape, "Output shape : ", tensor_1_out.shape)

tensor_2 = torch.randn(5, 1, 2, 4, 10)
tensor_2_out = fc(tensor_2)

print("Inpur shape : ", tensor_2.shape, "Output shape : ", tensor_2_out.shape)

Inpur shape :  torch.Size([5, 10]) Output shape :  torch.Size([5, 30])
Inpur shape :  torch.Size([5, 1, 2, 4, 10]) Output shape :  torch.Size([5, 1, 2, 4, 30])


In [15]:
tensor = torch.rand(1, 8, 9)
fc = nn.Linear(9, 9)

q = fc(tensor)

chunk1, chun2, chunk3 = torch.chunk(q, 3, axis=-1)

In [16]:
a = torch.rand(1, 2, 6, 4)
b = torch.rand(1, 2, 4, 3)

a @ b

tensor([[[[0.6284, 0.7227, 0.9596],
          [0.4593, 0.6037, 0.8552],
          [0.4661, 1.0673, 1.0934],
          [0.8888, 0.7978, 0.9769],
          [0.1879, 0.7750, 0.9653],
          [0.6388, 0.9627, 1.3669]],

         [[0.4987, 0.4611, 1.1365],
          [0.7592, 0.4052, 1.3393],
          [0.5425, 0.4535, 1.1568],
          [0.8584, 0.8129, 2.0019],
          [1.1036, 0.8007, 2.1859],
          [0.7237, 0.8247, 1.8851]]]])

In [17]:
class SelfAttentionEncoder(nn.Module):
    def __init__(self, embed_dim, num_heads, attn_p = 0.0, proj_p = 0.0, bias = True):

        super().__init__()

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.query = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.key = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.value = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.attn_drop = nn.Dropout(attn_p)

        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = nn.Dropout(proj_p)


    def forward(self, x):
        batch_size, sequence_length, embed_dim = x.shape

        print(x.shape)

        q = self.query(x).reshape(batch_size, sequence_length, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.key(x).reshape(batch_size, sequence_length, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.value(x).reshape(batch_size, sequence_length, self.num_heads, self.head_dim).transpose(1, 2)
        
        attn = q @ k.transpose(-2, -1) * self.head_dim ** 0.5
        attn = attn.softmax(axis=-1)
        attn = self.attn_drop(attn)
        x = attn @ v

        x = x.transpose(1, 2).reshape(batch_size, sequence_length, embed_dim)
        print(x.shape)

        x = self.proj(x)
        x = self.proj_drop(x)


rand = torch.randn(4, 16, 128)
attn = SelfAttentionEncoder(128, 4)
attn(rand)

torch.Size([4, 16, 128])
torch.Size([4, 16, 128])


In [18]:
rand_attn = torch.rand(1, 6, 6)

attention_mask = torch.tensor([1, 1, 1, 1, 0, 0]).unsqueeze(0).bool()
print(attention_mask.shape)
attention_mask = attention_mask.unsqueeze(1).repeat(1, 6, 1)
print(attention_mask.shape, rand_attn.shape)

# torch.softmax(rand_attn.masked_fill(~attention_mask, -float("inf")), axis=-1)

attention_mask.shape, rand_attn.shape

torch.Size([1, 6])
torch.Size([1, 6, 6]) torch.Size([1, 6, 6])


(torch.Size([1, 6, 6]), torch.Size([1, 6, 6]))

In [19]:
rand_attn = torch.rand(1, 2, 6, 6)

attention_mask = torch.tensor([1, 1, 1, 1, 0, 0]).unsqueeze(0).bool()
attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)
attention_mask = attention_mask.repeat(1, 2, 6, 1)
print(attention_mask)

# rand_attn.masked_fill_(~attention_mask, float("-inf"))

tensor([[[[ True,  True,  True,  True, False, False],
          [ True,  True,  True,  True, False, False],
          [ True,  True,  True,  True, False, False],
          [ True,  True,  True,  True, False, False],
          [ True,  True,  True,  True, False, False],
          [ True,  True,  True,  True, False, False]],

         [[ True,  True,  True,  True, False, False],
          [ True,  True,  True,  True, False, False],
          [ True,  True,  True,  True, False, False],
          [ True,  True,  True,  True, False, False],
          [ True,  True,  True,  True, False, False],
          [ True,  True,  True,  True, False, False]]]])


In [20]:
class SelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, attn_p = 0.0, proj_p = 0.0, bias = True):

        super().__init__()

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.query = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.key = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.value = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.attn_drop = nn.Dropout(attn_p)

        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = nn.Dropout(proj_p)


    def forward(self, x, attention_mask = None):
        batch_size, sequence_length, embed_dim = x.shape

        print(x.shape)

        q = self.query(x).reshape(batch_size, sequence_length, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.key(x).reshape(batch_size, sequence_length, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.value(x).reshape(batch_size, sequence_length, self.num_heads, self.head_dim).transpose(1, 2)

        attn = q @ k.transpose(-2, -1) * self.head_dim ** 0.5
        
        #############################################

        if attention_mask is not None:
            print("Attention mask")

            attention_mask = attention_mask.unsqueeze(1).unsqueeze(1) # .repeat(1, 1, sequence_length, 1) Optional
            print(attention_mask[0])
            print(attn[0])

            attn = attn.masked_fill_(~attention_mask, float("-inf"))
            print(attn)

        #############################################

        attn = attn.softmax(axis=-1)
        print(attn)
        attn = self.attn_drop(attn)
        x = attn @ v
        print(attn.shape)

        x = x.transpose(1, 2).reshape(batch_size, sequence_length, embed_dim)
        print(x.shape)

        x = self.proj(x)
        x = self.proj_drop(x)


seq_lens = [3, 5, 4]
embed_dim = 9
num_heads = 3

a = SelfAttention(embed_dim, num_heads)

rand = torch.randn(len(seq_lens), max(seq_lens), embed_dim)

masks = torch.nn.utils.rnn.pad_sequence([torch.ones(l) for l in seq_lens], batch_first=True, padding_value=0).bool()
print("Attention mask : ")
print(masks)

output = a(rand, masks)

Attention mask : 
tensor([[ True,  True,  True, False, False],
        [ True,  True,  True,  True,  True],
        [ True,  True,  True,  True, False]])
torch.Size([3, 5, 9])
Attention mask
tensor([[[ True,  True,  True, False, False]]])
tensor([[[-2.5752,  1.2256,  0.3576, -0.7280,  1.3640],
         [-1.0284,  1.5719,  0.2485, -0.1938,  0.9666],
         [-2.8719,  2.1797,  0.0179, -2.5782,  0.8866],
         [-0.2487,  0.0902,  0.2461,  0.7834,  0.5637],
         [-0.6897,  0.8424,  0.1392, -0.1759,  0.5518]],

        [[ 0.0471,  0.3792, -0.4356,  0.5821,  0.4571],
         [ 0.3958, -0.7316,  0.4995, -0.0352, -0.1791],
         [ 1.4831, -1.3442,  0.6859,  1.0190,  0.4052],
         [-1.0623,  0.0193, -0.3651, -0.1329, -0.2070],
         [ 0.0291,  0.0138,  0.2136, -0.4241, -0.2424]],

        [[-0.3754,  0.0525,  0.8761,  0.5526, -0.2392],
         [-1.5586,  2.0417,  0.4571, -2.6434,  1.8729],
         [-0.9346,  1.1810, -0.4002, -1.8324,  1.6955],
         [ 0.8070, -0.7053, -

In [21]:
seq_len = 8

ones = torch.ones(seq_len, seq_len)
causal_mask = torch.tril(ones).bool()

padding_mask = torch.tensor([1, 1, 1, 1, 0, 0, 0, 0]).bool()
padding_mask = padding_mask.unsqueeze(0).repeat(seq_len, 1)

causal_mask = causal_mask.masked_fill_(~padding_mask, 0)
causal_mask 

tensor([[ True, False, False, False, False, False, False, False],
        [ True,  True, False, False, False, False, False, False],
        [ True,  True,  True, False, False, False, False, False],
        [ True,  True,  True,  True, False, False, False, False],
        [ True,  True,  True,  True, False, False, False, False],
        [ True,  True,  True,  True, False, False, False, False],
        [ True,  True,  True,  True, False, False, False, False],
        [ True,  True,  True,  True, False, False, False, False]])

In [22]:
class CausalSelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, causal= True, attn_p = 0.0, proj_p = 0.0, bias = True):

        super().__init__()

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.causal = causal

        self.query = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.key = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.value = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.attn_drop = nn.Dropout(attn_p)

        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = nn.Dropout(proj_p)


    def forward(self, x, attention_mask = None):
        batch_size, sequence_length, embed_dim = x.shape


        q = self.query(x).reshape(batch_size, sequence_length, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.key(x).reshape(batch_size, sequence_length, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.value(x).reshape(batch_size, sequence_length, self.num_heads, self.head_dim).transpose(1, 2)

        attn = q @ k.transpose(-2, -1) * self.head_dim ** 0.5

        if attention_mask is not None:
                attention_mask = attention_mask.unsqueeze(1).unsqueeze(1) # .repeat(1, 1, sequence_length, 1) Optional
        
        if self.causal:

            ones = torch.ones((sequence_length, sequence_length), device=attn.device)
            causal_mask = torch.tril(ones).bool()
            causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)
            print(causal_mask)

            if attention_mask is not None:

                causal_mask = causal_mask.repeat(batch_size, 1, 1, 1)
                causal_mask = causal_mask.masked_fill_(~attention_mask, False)
            

            attn = attn.masked_fill_(~causal_mask, float("-inf"))

        #############################################

        attn = attn.softmax(axis=-1)
        attn = self.attn_drop(attn)
        x = attn @ v

        x = x.transpose(1, 2).reshape(batch_size, sequence_length, embed_dim)

        x = self.proj(x)
        x = self.proj_drop(x)


seq_lens = [3, 5, 4]
embed_dim = 9
num_heads = 3

a = SelfAttention(embed_dim, num_heads)

rand = torch.randn(len(seq_lens), max(seq_lens), embed_dim)

masks = torch.nn.utils.rnn.pad_sequence([torch.ones(l) for l in seq_lens], batch_first=True, padding_value=0).bool()
print("Attention mask : ")
print(masks)

output = a(rand, masks)

Attention mask : 
tensor([[ True,  True,  True, False, False],
        [ True,  True,  True,  True,  True],
        [ True,  True,  True,  True, False]])
torch.Size([3, 5, 9])
Attention mask
tensor([[[ True,  True,  True, False, False]]])
tensor([[[ 5.6995e-01,  3.3796e-01,  6.3030e-01,  8.9603e-01, -2.3478e-01],
         [ 2.5791e-01,  2.6351e-01,  7.8042e-01,  6.7147e-01, -5.4165e-01],
         [-7.5067e-01, -1.0092e-01,  9.2417e-01,  2.5597e-01, -1.1230e+00],
         [ 2.6883e-01,  3.1667e-01,  1.1072e+00,  1.1027e+00, -7.6815e-01],
         [ 3.4245e-01, -5.6203e-04, -2.8669e-01,  7.5262e-01,  5.7168e-01]],

        [[ 7.3013e-01,  3.6927e-01, -2.1133e-01,  6.1256e-01, -2.9070e-01],
         [ 5.5573e-01,  2.7058e-01, -2.7659e-02,  4.7476e-01, -1.1177e-01],
         [-1.3794e+00, -1.5250e-01, -7.3867e-01, -1.8175e+00, -3.7134e-01],
         [-3.2434e-01, -1.4839e-01,  7.7238e-02, -2.9166e-01,  1.1593e-01],
         [-1.0262e+00, -9.5097e-01, -7.4966e-01, -2.6475e-01, -4.6870e-01]]