In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as f

batch_size = 4
sequence_length = 64
embed_dim = 128

x = torch.randn(batch_size, sequence_length, embed_dim)
print("Shape of input is : ", x.shape)

Shape of input is :  torch.Size([4, 64, 128])


In [15]:
x.shape

torch.Size([4, 64, 128])

In [16]:
x.transpose(1, 2).shape

torch.Size([4, 128, 64])

In [17]:
similarity = (x @ x.transpose(1, 2))

In [18]:
similarity.shape

torch.Size([4, 64, 64])

In [19]:
similarity.var()

tensor(382.4670)

In [20]:
similarity = (x @ x.transpose(1, 2)) / (embed_dim ** 0.5)

similarity

tensor([[[ 1.0671e+01,  1.6319e+00,  1.5174e+00,  ...,  1.6468e-01,
           1.1859e+00, -3.4202e-01],
         [ 1.6319e+00,  1.3229e+01,  1.1566e+00,  ...,  1.9852e+00,
          -1.3121e+00,  1.0552e+00],
         [ 1.5174e+00,  1.1566e+00,  1.2758e+01,  ..., -1.2592e+00,
           4.8482e-01, -7.6433e-01],
         ...,
         [ 1.6468e-01,  1.9852e+00, -1.2592e+00,  ...,  1.0402e+01,
           8.9405e-01,  8.5003e-01],
         [ 1.1859e+00, -1.3121e+00,  4.8482e-01,  ...,  8.9405e-01,
           1.0861e+01, -4.6091e-01],
         [-3.4202e-01,  1.0552e+00, -7.6433e-01,  ...,  8.5003e-01,
          -4.6091e-01,  1.2508e+01]],

        [[ 1.3754e+01,  1.0715e+00,  6.5350e-01,  ...,  7.7999e-02,
           3.0643e-01,  1.6925e+00],
         [ 1.0715e+00,  1.1790e+01, -1.2433e+00,  ..., -9.6008e-01,
           5.8391e-01, -4.2873e-01],
         [ 6.5350e-01, -1.2433e+00,  9.5225e+00,  ...,  5.8457e-01,
           9.7283e-01,  1.1810e+00],
         ...,
         [ 7.7999e-02, -9

In [21]:
attn_mat = similarity.softmax(dim = 2)

attn_mat

tensor([[[9.9766e-01, 1.1842e-04, 1.0561e-04,  ..., 2.7303e-05,
          7.5811e-05, 1.6450e-05],
         [9.1873e-06, 9.9981e-01, 5.7118e-06,  ..., 1.3081e-05,
          4.8376e-07, 5.1610e-06],
         [1.3123e-05, 9.1477e-06, 9.9975e-01,  ..., 8.1687e-07,
          4.6727e-06, 1.3399e-06],
         ...,
         [3.5677e-05, 2.2032e-04, 8.5903e-06,  ..., 9.9626e-01,
          7.3985e-05, 7.0799e-05],
         [6.2693e-05, 5.1563e-06, 3.1098e-05,  ..., 4.6823e-05,
          9.9780e-01, 1.2078e-05],
         [2.6245e-06, 1.0613e-05, 1.7205e-06,  ..., 8.6448e-06,
          2.3303e-06, 9.9964e-01]],

        [[9.9987e-01, 3.1056e-06, 2.0446e-06,  ..., 1.1500e-06,
          1.4451e-06, 5.7788e-06],
         [2.2120e-05, 9.9920e-01, 2.1852e-06,  ..., 2.9006e-06,
          1.3584e-05, 4.9346e-06],
         [1.3972e-04, 2.0965e-05, 9.9319e-01,  ..., 1.3041e-04,
          1.9228e-04, 2.3678e-04],
         ...,
         [2.2457e-05, 7.9527e-06, 3.7269e-05,  ..., 9.9826e-01,
          1.462

In [22]:
attn_mat.shape, x.shape

(torch.Size([4, 64, 64]), torch.Size([4, 64, 128]))

In [23]:
output = (attn_mat @ x)

output.shape

torch.Size([4, 64, 128])

In [24]:
linear = nn.Linear(10, 20)

rand = torch.randn(4, 6, 10)

linear(rand).shape

torch.Size([4, 6, 20])

In [25]:
#Single head attention

class Attention(nn.Module):
    def __init__(self, embedding_dimension):

        super().__init__()

        self.embed_dim = embedding_dimension

        self.query = nn.Linear(self.embed_dim, self.embed_dim)
        self.key = nn.Linear(self.embed_dim, self.embed_dim)
        self.value = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(self, x):
        
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)

        similarity = (q @ k.transpose(1, 2)) / self.embed_dim ** 0.5
        attention = similarity.softmax(axis=2)
        output = attention @ v

rand = torch.randn(4, 64, 128)

attn = Attention(embedding_dimension=128)
attn(rand)

In [26]:
class MultiHeadAttention(nn.Module):
    def __init__(self, embedding_dimension, num_heads):
        super().__init__()

        self.embed_dim = embedding_dimension
        self.num_heads = num_heads

        self.head_dim = self.embed_dim // self.num_heads

        self.multihead_qkv = nn.ModuleList()

        for head in range(num_heads):
            qkv_proj = nn.ModuleDict(
                [
                  ["Q", nn.Linear(self.embed_dim, self.head_dim)],
                  ["K", nn.Linear(self.embed_dim, self.head_dim)],
                  ["V", nn.Linear(self.embed_dim, self.head_dim)]  
                ]
            )

            self.multihead_qkv.append(qkv_proj)

        self.proj = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(self, x):
        head_outs = []

        for head in self.multihead_qkv:

            q = head["Q"](x)
            k = head["K"](x)
            v = head["V"](x)

            similarity = (q @ k.transpose(1, 2)) / self.embed_dim ** 0.5
            attention = similarity.softmax(axis=-1)
            output = attention @ v

            head_outs.append(output)

        head_outs = torch.cat(head_outs, dim=-1)

        out = self.proj(head_outs)

        return out
            

rand = torch.randn(4, 64, 128)
attn = MultiHeadAttention(128, 4)
attn(rand).shape

torch.Size([4, 64, 128])

In [27]:
fc = nn.Linear(10, 30)

tensor_1 = torch.randn(5, 10)
tensor_1_out = fc(tensor_1)

print("Inpur shape : ", tensor_1.shape, "Output shape : ", tensor_1_out.shape)

tensor_2 = torch.randn(5, 1, 2, 4, 10)
tensor_2_out = fc(tensor_2)

print("Inpur shape : ", tensor_2.shape, "Output shape : ", tensor_2_out.shape)

Inpur shape :  torch.Size([5, 10]) Output shape :  torch.Size([5, 30])
Inpur shape :  torch.Size([5, 1, 2, 4, 10]) Output shape :  torch.Size([5, 1, 2, 4, 30])


In [31]:
tensor = torch.rand(1, 8, 9)
fc = nn.Linear(9, 9)

q = fc(tensor)

chunk1, chun2, chunk3 = torch.chunk(q, 3, axis=-1)

In [36]:
a = torch.rand(1, 2, 6, 4)
b = torch.rand(1, 2, 4, 3)

a @ b

tensor([[[[0.9982, 0.6132, 0.4496],
          [0.5950, 0.3582, 0.3470],
          [1.2939, 0.3835, 1.1406],
          [0.9041, 0.3866, 0.6114],
          [1.7082, 0.9127, 0.9175],
          [0.8884, 0.6366, 0.5170]],

         [[1.7517, 1.1400, 0.7104],
          [2.2031, 1.3760, 0.8078],
          [1.6440, 1.0661, 0.4796],
          [0.7265, 0.5810, 0.3554],
          [0.6167, 0.4061, 0.1374],
          [1.5962, 1.1938, 0.7229]]]])

In [38]:
class SelfAttentionEncoder(nn.Module):
    def __init__(self, embed_dim, num_heads, attn_p = 0.0, proj_p = 0.0):

        super().__init__()

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.query = nn.Linear(embed_dim, embed_dim)
        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)
        self.attn_drop = nn.Linear(embed_dim, embed_dim)

        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = nn.Dropout(proj_p)


    def forward(self, x):
        batch_size, sequence_length, embed_dim = x.shape 
        q = self.query(x).reshape(batch_size, sequence_length, self.num_heads, self.head_dim)
        print(q.shape)

rand = torch.randn(4, 16, 128)
attn = SelfAttentionEncoder(128, 2)
attn(rand)


torch.Size([4, 16, 2, 64])
