In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Self-Attention (Encoder-Only Transformer)

In [None]:
"write a poem <EOS>"

In [2]:
class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention.

    Queries, keys and values are all bias-free linear projections of the same
    input, so every row of the input attends to every row (itself included).

    Args:
        d_model: Size of each encoding; also the size of every projection.
        row_dim: Dimension index of the rows in the input.
        col_dim: Dimension index of the columns in the input; also the
            dimension the softmax normalises over.
    """

    def __init__(self, d_model=2, row_dim=0, col_dim=1):
        super().__init__()
        # Keep the query -> key -> value creation order: it decides which
        # random draws each weight matrix receives after a manual seed.
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.row_dim = row_dim
        self.col_dim = col_dim

    def forward(self, token_embeddings):
        """Return the attention-weighted values for ``token_embeddings``."""
        queries = self.W_q(token_embeddings)
        keys = self.W_k(token_embeddings)
        values = self.W_v(token_embeddings)

        # Similarity of every query with every key, scaled by sqrt(key width).
        keys_t = torch.transpose(keys, self.row_dim, self.col_dim)
        scale = torch.tensor(keys.size(self.col_dim) ** 0.5)
        scores = (queries @ keys_t) / scale

        # Each row of weights sums to 1 and mixes the value rows.
        weights = torch.softmax(scores, dim=self.col_dim)
        return weights @ values

In [3]:
# Example input for the attention layer: a 3 x 2 matrix of encodings.
encodings_matrix = torch.tensor(
    [
        [1.16, 0.23],
        [0.57, 1.36],
        [4.41, -2.16],
    ]
)

In [4]:
# Seed PyTorch's RNG so the randomly initialised nn.Linear weights of the
# layer built in the next cell are reproducible.
torch.manual_seed(42)

<torch._C.Generator at 0x265ffaac5b0>

In [5]:
# Build the layer for 2-dimensional encodings; row_dim=0 / col_dim=1 match the
# 2-D (rows, columns) input matrix used in this notebook.
SelfAttention_model=SelfAttention(d_model=2,row_dim=0,col_dim=1)

In [6]:
# Run self-attention on the example encodings; the result keeps the 3 x 2
# shape of the input.
SelfAttention_model(encodings_matrix)

tensor([[1.0100, 1.0641],
        [0.2040, 0.7057],
        [3.4989, 2.2427]], grad_fn=<MmBackward0>)

In [26]:
# nn.Linear stores its weight as (out_features, in_features) and computes
# x @ weight.T, so the transposed view printed second is the matrix the
# encodings are actually multiplied by.
print(SelfAttention_model.W_q.weight)
print(SelfAttention_model.W_q.weight.transpose(0,1))

Parameter containing:
tensor([[ 0.5406,  0.5869],
        [-0.1657,  0.6496]], requires_grad=True)
tensor([[ 0.5406, -0.1657],
        [ 0.5869,  0.6496]], grad_fn=<TransposeBackward0>)


# Masked Self-Attention (Decoder-Only Transformer)

In [33]:
class MaskedSelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention with an optional mask.

    Args:
        d_model: Size of each encoding; also the size of every projection.
        row_dim: Dimension index of the rows in the input.
        col_dim: Dimension index of the columns in the input; also the
            dimension the softmax normalises over.
    """

    def __init__(self, d_model=2, row_dim=0, col_dim=1):
        super().__init__()
        # Keep the query -> key -> value creation order: it decides which
        # random draws each weight matrix receives after a manual seed.
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.row_dim = row_dim
        self.col_dim = col_dim

    def forward(self, token_encodings, mask=None):
        """Return attention-weighted values for ``token_encodings``.

        Args:
            token_encodings: Input encodings projected to queries, keys and
                values.
            mask: Optional boolean tensor; scores at ``True`` positions are
                replaced by ``-inf`` and therefore get zero weight after the
                softmax. ``None`` disables masking.
        """
        queries = self.W_q(token_encodings)
        keys = self.W_k(token_encodings)
        values = self.W_v(token_encodings)

        # Query/key similarities, scaled by sqrt(key width).
        scores = queries @ torch.transpose(keys, self.row_dim, self.col_dim)
        scores = scores / torch.tensor(keys.size(self.col_dim) ** 0.5)

        if mask is not None:
            # -inf becomes exactly 0 after the softmax below.
            scores = scores.masked_fill(mask, float('-inf'))

        weights = torch.softmax(scores, dim=self.col_dim)
        return weights @ values



In [34]:
# Re-seed so the masked layer built below starts from reproducible weights.
torch.manual_seed(42)

<torch._C.Generator at 0x28ac323b190>

In [35]:
# Example input for the masked attention layer: a 3 x 2 matrix of encodings.
encodings_matrix = torch.tensor(
    [
        [1.16, 0.23],
        [0.57, 1.36],
        [4.41, -2.16],
    ]
)

In [36]:
# Build the masked layer for 2-dimensional encodings laid out as (rows, columns).
maskAttention_model=MaskedSelfAttention(d_model=2,row_dim=0,col_dim=1)

In [37]:
# Boolean 3 x 3 mask that is True strictly above the main diagonal, i.e. at the
# positions whose scores are to be replaced by -inf.
mask = torch.ones(3, 3).tril(diagonal=0).eq(0)

In [38]:
# Display the mask: True entries are the scores MaskedSelfAttention fills with
# -inf before the softmax.
mask

tensor([[False,  True,  True],
        [False, False,  True],
        [False, False, False]])

In [39]:
# Masked forward pass. The last mask row is all False, so only the first two
# output rows are affected by the masking.
maskAttention_model(encodings_matrix,mask=mask)

tensor([[ 0.6038,  0.7434],
        [-0.0062,  0.6072],
        [ 3.4989,  2.2427]], grad_fn=<MmBackward0>)

# Cross (Encoder-Decoder) Attention 

In [43]:
class Attention(nn.Module):
    """Single-head scaled dot-product attention with separate q/k/v inputs.

    Passing the same tensor for all three inputs gives self-attention; passing
    different tensors for the queries and for the keys/values gives
    encoder-decoder (cross) attention.

    Args:
        d_model: Size of each encoding; also the size of every projection.
        row_dim: Dimension index of the rows in the inputs.
        col_dim: Dimension index of the columns in the inputs; also the
            dimension the softmax normalises over.
    """

    def __init__(self, d_model=2, row_dim=0, col_dim=1):
        super().__init__()
        # Keep the query -> key -> value creation order: it decides which
        # random draws each weight matrix receives after a manual seed.
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.row_dim = row_dim
        self.col_dim = col_dim

    def forward(self, encodings_for_q, encodings_for_k, encodings_for_v, mask=None):
        """Return the attention-weighted values.

        Args:
            encodings_for_q: Encodings projected to queries.
            encodings_for_k: Encodings projected to keys.
            encodings_for_v: Encodings projected to values.
            mask: Optional boolean tensor; scores at ``True`` positions are
                replaced by ``-inf`` and therefore get zero weight after the
                softmax. ``None`` disables masking.
        """
        queries = self.W_q(encodings_for_q)
        keys = self.W_k(encodings_for_k)
        values = self.W_v(encodings_for_v)

        # Query/key similarities, scaled by sqrt(key width).
        scores = queries @ torch.transpose(keys, self.row_dim, self.col_dim)
        scores = scores / torch.tensor(keys.size(self.col_dim) ** 0.5)

        if mask is not None:
            # -inf becomes exactly 0 after the softmax below.
            scores = scores.masked_fill(mask, float('-inf'))

        weights = torch.softmax(scores, dim=self.col_dim)
        return weights @ values

In [40]:
# The same 3 x 2 example is used for the query, key and value inputs, so the
# attention call below reduces to self-attention. Each name still refers to
# its own independent tensor.
encodings_for_q = torch.tensor([[1.16, 0.23], [0.57, 1.36], [4.41, -2.16]])
encodings_for_k = encodings_for_q.clone()
encodings_for_v = encodings_for_q.clone()

In [41]:
# Re-seed so the attention layer built below starts from reproducible weights.
torch.manual_seed(42)

<torch._C.Generator at 0x28ac323b190>

In [44]:
# Instantiate the general attention layer for 2-dimensional encodings.
attention=Attention(d_model=2,row_dim=0,col_dim=1)

In [45]:
# Identical q/k/v inputs and no mask: this is plain self-attention.
attention(encodings_for_q,encodings_for_k,encodings_for_v)

tensor([[1.0100, 1.0641],
        [0.2040, 0.7057],
        [3.4989, 2.2427]], grad_fn=<MmBackward0>)

# Multi-Head Attention 

In [46]:
class MultiHeadAttention(nn.Module):
    """Run several independent ``Attention`` heads and concatenate their outputs.

    Every head has its own query/key/value weights and works on the full
    ``d_model``-wide encodings. The head outputs are joined along ``col_dim``,
    so the result is ``num_heads * d_model`` wide.

    Args:
        d_model: Size of each encoding, passed to every head.
        num_heads: Number of attention heads; must be at least 1.
        row_dim: Dimension index of the rows in the inputs.
        col_dim: Dimension index of the columns in the inputs; also the
            dimension along which head outputs are concatenated.

    Raises:
        ValueError: If ``num_heads`` is smaller than 1.
    """

    def __init__(self, d_model=2, num_heads=1, row_dim=0, col_dim=1):
        super().__init__()
        # Fail early with a clear message instead of letting torch.cat choke
        # on an empty list of head outputs during the first forward pass.
        if num_heads < 1:
            raise ValueError(f"num_heads must be at least 1, got {num_heads}")
        self.heads = nn.ModuleList(
            [
                Attention(d_model=d_model, row_dim=row_dim, col_dim=col_dim)
                for _ in range(num_heads)
            ]
        )
        self.col_dim = col_dim

    def forward(self, encodings_for_q, encodings_for_k, encodings_for_v, mask=None):
        """Apply every head to the same inputs and concatenate the results.

        Args:
            encodings_for_q: Encodings projected to queries by each head.
            encodings_for_k: Encodings projected to keys by each head.
            encodings_for_v: Encodings projected to values by each head.
            mask: Optional boolean mask forwarded unchanged to every head;
                ``None`` (the default) disables masking.

        Returns:
            The head outputs concatenated along ``col_dim``.
        """
        head_outputs = [
            head(encodings_for_q, encodings_for_k, encodings_for_v, mask=mask)
            for head in self.heads
        ]
        return torch.cat(head_outputs, dim=self.col_dim)

In [51]:
# Re-seed so the heads of the multi-head layer built below start from
# reproducible weights.
torch.manual_seed(42)

<torch._C.Generator at 0x28ac323b190>

In [52]:
# Two independent Attention heads, each with its own W_q/W_k/W_v weights.
multiheadAttention_model=MultiHeadAttention(d_model=2,num_heads=2,row_dim=0,col_dim=1)

In [53]:
# The same 3 x 2 example is used for the query, key and value inputs of the
# multi-head layer. Each name still refers to its own independent tensor.
encodings_for_q = torch.tensor([[1.16, 0.23], [0.57, 1.36], [4.41, -2.16]])
encodings_for_k = encodings_for_q.clone()
encodings_for_v = encodings_for_q.clone()

In [54]:
# Each of the two heads yields a 3 x 2 result; concatenating them along
# col_dim gives the 3 x 4 output.
multiheadAttention_model(encodings_for_q,encodings_for_k,encodings_for_v)

tensor([[ 1.0100,  1.0641, -0.7081, -0.8268],
        [ 0.2040,  0.7057, -0.7417, -0.9193],
        [ 3.4989,  2.2427, -0.7190, -0.8447]], grad_fn=<CatBackward0>)