In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Attention

#### Self-Attention

In [3]:
class SelfAttention(nn.Module):

    def __init__(self , d_model = 2,
                row_dim = 0,
                col_dim = 1):

        super().__init__()

        # weight matrices
        
        self.W_q = nn.Linear(in_features = d_model,
                            out_features = d_model,
                            bias = False) 
        self.W_k = nn.Linear(in_features = d_model,
                            out_features = d_model,
                            bias = False)
        
        self.W_v = nn.Linear(in_features = d_model,
                            out_features = d_model,
                            bias = False)

        self.row_dim = row_dim
        self.col_dim = col_dim
        self.d_model = d_model


    
    def forward(self , token_encodings):
        q = self.W_q(token_encodings)
        k = self.W_k(token_encodings)
        v = self.W_v(token_encodings)
        
        sims = torch.matmul(q , k.transpose(dim0 = self.row_dim,
                                           dim1 = self.col_dim))
        scaled_sims = sims / torch.tensor(k.size(self.col_dim)**0.5)
        # scaled_sims = sims / torch.tensor(self.d_model**0.5)

        attention_percents = F.softmax(scaled_sims , dim = self.col_dim)
        attention_scores = torch.matmul(attention_percents , v)

        return attention_scores

    
        

In [4]:
encodings_matrix = torch.tensor([[1.16, 0.23],[0.57, 1.36],[4.41, -2.16]])
torch.manual_seed(42)

<torch._C.Generator at 0x7eb0ce34abf0>

In [5]:
self_attention = SelfAttention(d_model = 2,
                              row_dim = 0,
                              col_dim = 1)

In [6]:
self_attention(encodings_matrix)

tensor([[1.0100, 1.0641],
        [0.2040, 0.7057],
        [3.4989, 2.2427]], grad_fn=<MmBackward0>)

In [7]:
for name, param in self_attention.named_parameters():
    print(f"Parameter name: {name}")
    print(f"Parameter values:\n{param.data}")
    print("-" * 30)

Parameter name: W_q.weight
Parameter values:
tensor([[ 0.5406,  0.5869],
        [-0.1657,  0.6496]])
------------------------------
Parameter name: W_k.weight
Parameter values:
tensor([[-0.1549,  0.1427],
        [-0.3443,  0.4153]])
------------------------------
Parameter name: W_v.weight
Parameter values:
tensor([[ 0.6233, -0.5188],
        [ 0.6146,  0.1323]])
------------------------------


In [8]:
# Get the state dictionary
state_dict = self_attention.state_dict()

# Print the state dictionary
for key, value in state_dict.items():
    print(f"Key: {key}")
    print(f"Values:\n{value}")
    print("-" * 30)

Key: W_q.weight
Values:
tensor([[ 0.5406,  0.5869],
        [-0.1657,  0.6496]])
------------------------------
Key: W_k.weight
Values:
tensor([[-0.1549,  0.1427],
        [-0.3443,  0.4153]])
------------------------------
Key: W_v.weight
Values:
tensor([[ 0.6233, -0.5188],
        [ 0.6146,  0.1323]])
------------------------------


In [9]:
state_dict

OrderedDict([('W_q.weight',
              tensor([[ 0.5406,  0.5869],
                      [-0.1657,  0.6496]])),
             ('W_k.weight',
              tensor([[-0.1549,  0.1427],
                      [-0.3443,  0.4153]])),
             ('W_v.weight',
              tensor([[ 0.6233, -0.5188],
                      [ 0.6146,  0.1323]]))])

In [10]:
self_attention.W_q.weight.transpose(0,1)

tensor([[ 0.5406, -0.1657],
        [ 0.5869,  0.6496]], grad_fn=<TransposeBackward0>)

In [11]:
self_attention.W_k.weight.transpose(0,1)

tensor([[-0.1549, -0.3443],
        [ 0.1427,  0.4153]], grad_fn=<TransposeBackward0>)

In [12]:
self_attention.W_v.weight.transpose(0,1)

tensor([[ 0.6233,  0.6146],
        [-0.5188,  0.1323]], grad_fn=<TransposeBackward0>)

#### Masked Self-Attention

In [38]:
class MaskedSelfAttention(nn.Module):

    def __init__(self , d_model = 2,
                row_dim = 0,
                col_dim = 1):

        super().__init__()

        # weight matrices
        
        self.W_q = nn.Linear(in_features = d_model,
                            out_features = d_model,
                            bias = False) 
        self.W_k = nn.Linear(in_features = d_model,
                            out_features = d_model,
                            bias = False)
        
        self.W_v = nn.Linear(in_features = d_model,
                            out_features = d_model,
                            bias = False)
        

        self.row_dim = row_dim
        self.col_dim = col_dim
        self.d_model = d_model
        

    
    def forward(self , token_encodings , mask = None):
        q = self.W_q(token_encodings)
        k = self.W_k(token_encodings)
        v = self.W_v(token_encodings)
        
        sims = torch.matmul(q , k.transpose(dim0 = self.row_dim,
                                           dim1 = self.col_dim))
        scaled_sims = sims / torch.tensor(k.size(self.col_dim)**0.5)
        # scaled_sims = sims / torch.tensor(self.d_model**0.5)

        if mask is not None:
            # mask = torch.tril(torch.ones(k.size(self.row_dim),k.size(self.row_dim)))
            # mask = mask == 0
            scaled_sims = scaled_sims.masked_fill(mask = mask,
                                                 value = -1e-9)       

        
     

        
        attention_percents = F.softmax(scaled_sims , dim = self.col_dim)
        attention_scores = torch.matmul(attention_percents , v)
        
        return attention_scores

    
        

In [39]:
encodings_matrix = torch.tensor([[1.16, 0.23],[0.57, 1.36],[4.41, -2.16]])
torch.manual_seed(42)

<torch._C.Generator at 0x7eb0ce34abf0>

In [40]:
masked_self_attention = MaskedSelfAttention(d_model = 2,
                              row_dim = 0,
                              col_dim = 1)

In [41]:
mask = torch.tril(torch.ones(3,3))
mask

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])

In [42]:
mask = mask == 0

In [43]:
mask

tensor([[False,  True,  True],
        [False, False,  True],
        [False, False, False]])

In [44]:
masked_self_attention(encodings_matrix , mask)

tensor([[1.3921, 1.2440],
        [1.2494, 1.1960],
        [3.4989, 2.2427]], grad_fn=<MmBackward0>)

#### Encoder-Decoder Attention

In [57]:
class Attention(nn.Module):

    def __init__(self ,
                 d_model = 2,
                 row_dim = 0,
                 col_dim = 1):

        super().__init__()

        # weight matrices
        
        self.W_q = nn.Linear(in_features = d_model,
                            out_features = d_model,
                            bias = False) 
        self.W_k = nn.Linear(in_features = d_model,
                            out_features = d_model,
                            bias = False)
        
        self.W_v = nn.Linear(in_features = d_model,
                            out_features = d_model,
                            bias = False)

        self.row_dim = row_dim
        self.col_dim = col_dim
        self.d_model = d_model


    
    def forward(self , 
                encodings_q ,
                encodings_k ,
                encodings_v,
               mask = None):
        
        q = self.W_q(encodings_q)
        k = self.W_k(encodings_k)
        v = self.W_v(encodings_v)
        
        sims = torch.matmul(q , k.transpose(dim0 = self.row_dim,
                                           dim1 = self.col_dim))
        
        scaled_sims = sims / torch.tensor(k.size(self.col_dim)**0.5)
        # scaled_sims = sims / torch.tensor(self.d_model**0.5)

        if mask is not None:
            # mask = torch.tril(torch.ones(k.size(self.row_dim),k.size(self.row_dim)))
            # mask = mask == 0
            scaled_sims = scaled_sims.masked_fill(mask = mask,
                                                 value = -1e-9)       
       
        attention_percents = F.softmax(scaled_sims , dim = self.col_dim)
        attention_scores = torch.matmul(attention_percents , v)

        return attention_scores

    
class MultiHeadAttention(nn.Module):
    
    def __init__(self ,
                 d_model = 2,
                 row_dim = 0,
                 col_dim = 1,
                num_heads = 1):
        super().__init__()


        self.heads = nn.ModuleList(
            [Attention(d_model = 2,row_dim = 0,col_dim = 1) 
             for _ in range(num_heads)]
        )

        self.col_dim = col_dim


    def forward(self , 
                encodings_q ,
                encodings_k ,
                encodings_v):

        return torch.cat(
            [head(encodings_q ,
                encodings_k ,
                encodings_v)
             for head in self.heads] , dim = self.col_dim
        )

In [58]:
encodings_q = torch.tensor([[1.3921, 1.2440],
        [1.2494, 1.1960],
        [3.4989, 2.2427]])

encodings_k = torch.tensor([[1.3921, 1.2440],
        [1.2494, 1.1960],
        [3.4989, 2.2427]])


encodings_v = torch.tensor([[1.3921, 1.2440],
        [1.2494, 1.1960],
        [3.4989, 2.2427]])

torch.manual_seed(42)

<torch._C.Generator at 0x7eb0ce34abf0>

In [59]:
mha = MultiHeadAttention(d_model = 2,
                 row_dim = 0,
                 col_dim = 1,
                num_heads = 3)

In [60]:
mha(encodings_q,encodings_q,encodings_q)

tensor([[ 0.4082,  1.3614, -0.9187,  0.1218,  1.3103,  0.2117],
        [ 0.4108,  1.3661, -0.9077,  0.1238,  1.3070,  0.2119],
        [ 0.3628,  1.2806, -1.0814,  0.0919,  1.3522,  0.2101]],
       grad_fn=<CatBackward0>)