In [2]:

import torch.nn as nn
import torch 

In [3]:
# Seed torch's global RNG so the randomly initialised projection weights are reproducible.
SEED = 789
torch.manual_seed(SEED);  # trailing ';' suppresses the returned Generator's repr

<torch._C.Generator at 0x1ceffdd4a30>

In [None]:
# Self-attention mechanism using raw nn.Parameter weight matrices

class SelfAttention_v1(nn.Module) :
    """Scaled dot-product self-attention with raw ``nn.Parameter`` projection matrices.

    Args:
        d_in: size of each input token embedding (last dim of ``x``).
        d_out: size of the projected query/key/value vectors (last dim of the output).
    """

    def __init__(self, d_in, d_out) :
        super().__init__()
        # Projection matrices of shape (d_in, d_out), initialised with torch.rand (uniform [0, 1)).
        self.W_Query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_Key   = nn.Parameter(torch.rand(d_in, d_out))
        self.W_Value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x) :
        """Return one context vector per token in ``x``.

        Args:
            x: token embeddings of shape ``(num_tokens, d_in)``; leading batch
               dimensions ``(..., num_tokens, d_in)`` are also supported.

        Returns:
            Tensor of shape ``(..., num_tokens, d_out)``.
        """
        keys    = x @ self.W_Key
        queries = x @ self.W_Query
        values  = x @ self.W_Value

        # Swap only the last two dims (torch.transpose needs both dim arguments),
        # which also keeps any leading batch dims intact.
        attention_scores = queries @ keys.transpose(-2, -1)
        # keys.shape[-1] is a Python int, so scale with ** 0.5; torch.sqrt only accepts tensors.
        attention_weights = torch.softmax(attention_scores / (keys.shape[-1] ** 0.5), dim=-1)

        context_matrix = attention_weights @ values
        return context_matrix


In [7]:
# Self-attention mechanism using nn.Linear projection layers

class SelfAttention_v2(nn.Module) :
    """Scaled dot-product self-attention using ``nn.Linear`` projections.

    Args:
        d_in: size of each input token embedding (last dim of ``x``).
        d_out: size of the projected query/key/value vectors (last dim of the output).
        qkv_bias: whether the query/key/value ``nn.Linear`` layers include a bias term.
    """

    def __init__(self, d_in, d_out, qkv_bias=False) :
        super().__init__()
        # Trainable query, key and value projections (d_in -> d_out).
        # Construction order is kept as-is: it determines how the global RNG is consumed.
        self.W_queries = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_keys = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_values = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x) :
        """Return one context vector per token in ``x``.

        Args:
            x: token embeddings of shape ``(num_tokens, d_in)``; leading batch
               dimensions ``(..., num_tokens, d_in)`` are also supported.

        Returns:
            Tensor of shape ``(..., num_tokens, d_out)``.
        """
        # Project the inputs into d_out-dimensional key, query and value spaces.
        keys = self.W_keys(x)
        queries = self.W_queries(x)
        values = self.W_values(x)

        # transpose(-2, -1) equals .T for 2-D inputs but, unlike .T, leaves batch dims alone.
        atten_scores = queries @ keys.transpose(-2, -1)
        atten_weights = torch.softmax(atten_scores / (keys.shape[-1] ** 0.5), dim=-1)
        context_vectors = atten_weights @ values

        return context_vectors


In [8]:

d_in = 3   # embedding size of each input token
d_out = 2  # size of the projected query/key/value vectors
# Toy embeddings for "Your journey starts with one step": one row per token.
x = torch.tensor([[0.43, 0.15, 0.89], # your
                  [0.55, 0.87, 0.66], # journey
                  [0.57, 0.85, 0.66], # starts
                  [0.22, 0.58, 0.64], # with
                  [0.77, 0.25, 0.10], # one
                  [0.05, 0.80, 0.55]]) # step

# Re-seed (same seed as the setup cell) so the nn.Linear weights, and hence the
# output below, are identical every time this cell is re-run.
torch.manual_seed(789)
self_atten = SelfAttention_v2(d_in, d_out)
# Call the module rather than .forward() so nn.Module hooks are honoured.
context_vectors = self_atten(x)
assert context_vectors.shape == (x.shape[0], d_out)
context_vectors

tensor([[-0.0846,  0.0597],
        [-0.0856,  0.0583],
        [-0.0856,  0.0583],
        [-0.0857,  0.0581],
        [-0.0871,  0.0563],
        [-0.0862,  0.0575]], grad_fn=<MmBackward0>)
