In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention.

    Queries, keys and values are all projected from the same input through
    three bias-free ``d_model -> d_model`` linear layers.

    Args:
        d_model: size of each token encoding; input and output width of the
            q/k/v projections.
        row_dim: axis swapped with ``col_dim`` when transposing the keys.
        col_dim: axis the similarity matrix is softmax-normalised over; its
            size in the keys sets the ``1/sqrt(d)`` scaling factor.
    """

    def __init__(self, d_model=2, row_dim=0, col_dim=1):
        super().__init__()

        def projection():
            # One definition shared by all three projections. They are still
            # built in q, k, v order, which fixes how the global RNG is
            # consumed during weight initialisation.
            return nn.Linear(d_model, d_model, bias=False)

        self.w_q = projection()
        self.w_k = projection()
        self.w_v = projection()

        self.row_dim, self.col_dim = row_dim, col_dim

    def forward(self, token_encoding):
        """Let every token attend to every token of the same input.

        Args:
            token_encoding: tensor whose last axis has size ``d_model``
                (e.g. ``(num_tokens, d_model)`` with the default dims).

        Returns:
            Attention-weighted combination of the value vectors.
        """
        queries = self.w_q(token_encoding)
        keys = self.w_k(token_encoding)
        values = self.w_v(token_encoding)

        similarity = queries @ keys.transpose(self.row_dim, self.col_dim)
        # Divide by sqrt(key width), kept as a 0-d tensor.
        similarity = similarity / torch.tensor(keys.size(self.col_dim) ** 0.5)

        weights = similarity.softmax(dim=self.col_dim)
        return weights @ values



In [4]:
# Seed first so the randomly initialised q/k/v weights are reproducible.
torch.manual_seed(42)

# 3 tokens x 2 features per token (d_model=2).
encodings_matrix = torch.tensor(
    [
        [1.16, 0.23],
        [0.57, 1.36],
        [4.41, -2.16],
    ]
)

self_attention = SelfAttention(d_model=2, row_dim=0, col_dim=1)

# The last expression is what the notebook displays.
self_attention(encodings_matrix)

tensor([[1.0100, 1.0641],
        [0.2040, 0.7057],
        [3.4989, 2.2427]], grad_fn=<MmBackward0>)

In [5]:
class MaskedSelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention with an optional mask.

    Queries, keys and values are all projected from the same input through
    three bias-free ``d_model -> d_model`` linear layers.

    Args:
        d_model: size of each token encoding; input and output width of the
            q/k/v projections.
        row_dim: axis swapped with ``col_dim`` when transposing the keys.
        col_dim: axis the similarity matrix is softmax-normalised over; its
            size in the keys sets the ``1/sqrt(d)`` scaling factor.
    """

    def __init__(self, d_model=2, row_dim=0, col_dim=1):

        super().__init__()

        # Created in q, k, v order; the order determines how the global RNG
        # is consumed, so it stays fixed for reproducible weights.
        self.w_q = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.w_k = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.w_v = nn.Linear(in_features=d_model, out_features=d_model, bias=False)

        self.row_dim = row_dim
        self.col_dim = col_dim

    def forward(self, token_encoding, mask=None):
        """Attend from every token to every unmasked token of the same input.

        Args:
            token_encoding: tensor whose last axis has size ``d_model``
                (e.g. ``(num_tokens, d_model)`` with the default dims).
            mask: optional boolean tensor broadcastable to the similarity
                matrix; ``True`` entries receive (effectively) zero attention
                weight. If every entry in a row is masked, that row's weights
                become uniform.

        Returns:
            Attention-weighted combination of the value vectors.
        """
        q = self.w_q(token_encoding)
        k = self.w_k(token_encoding)
        v = self.w_v(token_encoding)

        sims = torch.matmul(q, k.transpose(dim0=self.row_dim, dim1=self.col_dim))
        scaled_sims = sims / torch.tensor(k.size(self.col_dim) ** 0.5)

        if mask is not None:
            # Fill with the most negative finite value of the score dtype
            # rather than a hard-coded -1e9: -1e9 does not fit in float16
            # (masked_fill raises), while finfo.min is representable in every
            # floating dtype and still drives the masked softmax weights to 0.
            fill_value = torch.finfo(scaled_sims.dtype).min
            scaled_sims = scaled_sims.masked_fill(mask=mask, value=fill_value)

        attention_percents = F.softmax(scaled_sims, dim=self.col_dim)

        attention_scores = torch.matmul(attention_percents, v)

        return attention_scores



In [6]:
# Seed before building the module so its q/k/v weights are reproducible.
torch.manual_seed(42)

encodings_matrix = torch.tensor([[1.16, 0.23], [0.57, 1.36], [4.41, -2.16]])

maskedSelfAttention = MaskedSelfAttention(d_model=2, row_dim=0, col_dim=1)

# Causal mask: True strictly above the diagonal, i.e. token i is blocked from
# attending to any later token j > i.
mask = torch.ones(3, 3).triu(diagonal=1).bool()
mask

tensor([[False,  True,  True],
        [False, False,  True],
        [False, False, False]])

In [7]:
maskedSelfAttention(encodings_matrix, mask=mask)  # apply the causal mask

tensor([[ 0.6038,  0.7434],
        [-0.0062,  0.6072],
        [ 3.4989,  2.2427]], grad_fn=<MmBackward0>)

In [14]:
class Attention(nn.Module):
    """Single-head scaled dot-product attention with separate q/k/v inputs.

    Queries, keys and values may come from different encodings (e.g. for
    encoder-decoder attention). Each is projected through its own bias-free
    ``d_model -> d_model`` linear layer.

    Args:
        d_model: size of each token encoding; input and output width of the
            q/k/v projections.
        row_dim: axis swapped with ``col_dim`` when transposing the keys.
        col_dim: axis the similarity matrix is softmax-normalised over; its
            size in the keys sets the ``1/sqrt(d)`` scaling factor.
    """

    def __init__(self, d_model=2, row_dim=0, col_dim=1):
        super().__init__()

        # Created in q, k, v order; the order determines how the global RNG
        # is consumed, so it stays fixed for reproducible weights.
        self.w_q = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.w_k = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.w_v = nn.Linear(in_features=d_model, out_features=d_model, bias=False)

        self.row_dim = row_dim
        self.col_dim = col_dim

    def forward(self, encodings_for_q, encodings_for_k, encodings_for_v, mask=None):
        """Attend from the query tokens to the key/value tokens.

        Args:
            encodings_for_q: encodings projected into queries; last axis has
                size ``d_model``.
            encodings_for_k: encodings projected into keys; last axis has
                size ``d_model``.
            encodings_for_v: encodings projected into values; presumably the
                same number of tokens as ``encodings_for_k`` (needed for the
                final matmul).
            mask: optional boolean tensor broadcastable to the similarity
                matrix; ``True`` entries receive (effectively) zero attention
                weight. If every entry in a row is masked, that row's weights
                become uniform.

        Returns:
            Attention-weighted combination of the value vectors, one row per
            query token with the default dims.
        """
        q = self.w_q(encodings_for_q)
        k = self.w_k(encodings_for_k)
        v = self.w_v(encodings_for_v)

        sims = torch.matmul(q, k.transpose(dim0=self.row_dim, dim1=self.col_dim))

        scaled_sims = sims / torch.tensor(k.size(self.col_dim) ** 0.5)

        if mask is not None:
            # Fill with the most negative finite value of the score dtype
            # rather than a hard-coded -1e9: -1e9 does not fit in float16
            # (masked_fill raises), while finfo.min is representable in every
            # floating dtype and still drives the masked softmax weights to 0.
            fill_value = torch.finfo(scaled_sims.dtype).min
            scaled_sims = scaled_sims.masked_fill(mask=mask, value=fill_value)

        attention_percents = F.softmax(scaled_sims, dim=self.col_dim)

        attention_scores = torch.matmul(attention_percents, v)

        return attention_scores


In [15]:
# Encoder-decoder attention takes separate inputs for queries, keys and
# values. Here all three are independent tensors holding the same token
# encodings (3 tokens x 2 features).
encodings_for_q, encodings_for_k, encodings_for_v = (
    torch.tensor([[1.16, 0.23],
                  [0.57, 1.36],
                  [4.41, -2.16]])
    for _ in range(3)
)

# Seed before building the module so its q/k/v weights are reproducible.
torch.manual_seed(42)

attention = Attention(d_model=2, row_dim=0, col_dim=1)

attention(encodings_for_q, encodings_for_k, encodings_for_v)

tensor([[1.0100, 1.0641],
        [0.2040, 0.7057],
        [3.4989, 2.2427]], grad_fn=<MmBackward0>)

In [18]:
class MultiHeadAttention(nn.Module):
    """A set of independent ``Attention`` heads applied to the same inputs.

    Args:
        d_model: size of each token encoding, forwarded to every head.
        row_dim: axis passed to every head as its ``row_dim``.
        col_dim: axis passed to every head as its ``col_dim``; also stored on
            this module.
        num_heads: number of heads. Each head has its own q/k/v weights,
            created in head order from the global RNG.
    """

    def __init__(self,
                 d_model=2,
                 row_dim=0,
                 col_dim=1,
                 num_heads=1):

        super().__init__()

        # Forward the caller's arguments to every head. Previously the heads
        # were always built with d_model=2, row_dim=0, col_dim=1, silently
        # ignoring what was passed to this constructor.
        self.heads = nn.ModuleList([
            Attention(d_model=d_model, row_dim=row_dim, col_dim=col_dim)
            for _ in range(num_heads)
        ])

        # Kept as an attribute; forward() does not use it and does not
        # concatenate the per-head outputs.
        self.col_dim = col_dim

    def forward(self, encodings_for_q, encodings_for_k, encodings_for_v, mask=None):
        """Run every head on the same q/k/v encodings.

        Args:
            encodings_for_q: encodings projected into queries by each head.
            encodings_for_k: encodings projected into keys by each head.
            encodings_for_v: encodings projected into values by each head.
            mask: optional boolean mask passed unchanged to every head
                (``None`` means no masking, as before).

        Returns:
            A list with one output tensor per head, in head order.
        """
        return [head(encodings_for_q,
                     encodings_for_k,
                     encodings_for_v,
                     mask=mask)
                for head in self.heads]

In [19]:
# Same seed as the single Attention cell, so the lone head gets the same
# q/k/v weights.
torch.manual_seed(42)

# One head: the result is a one-element list.
multiHeadAttention = MultiHeadAttention(
    num_heads=1,
    d_model=2,
    row_dim=0,
    col_dim=1,
)

# Encoder-decoder attention over the shared q/k/v encodings.
multiHeadAttention(encodings_for_q, encodings_for_k, encodings_for_v)

[tensor([[1.0100, 1.0641],
         [0.2040, 0.7057],
         [3.4989, 2.2427]], grad_fn=<MmBackward0>)]

In [20]:
torch.manual_seed(42)

# Two heads: each draws its own q/k/v weights from the seeded RNG in order,
# so head 0 reproduces the single-head result and head 1 gets fresh weights.
multiHeadAttention = MultiHeadAttention(
    num_heads=2,
    d_model=2,
    row_dim=0,
    col_dim=1,
)

# Encoder-decoder attention: one output tensor per head, returned as a list.
multiHeadAttention(encodings_for_q, encodings_for_k, encodings_for_v)

[tensor([[1.0100, 1.0641],
         [0.2040, 0.7057],
         [3.4989, 2.2427]], grad_fn=<MmBackward0>),
 tensor([[-0.7081, -0.8268],
         [-0.7417, -0.9193],
         [-0.7190, -0.8447]], grad_fn=<MmBackward0>)]