In [25]:
from importlib.metadata import version
import torch
# Seed PyTorch's global RNG so the random tensors created below are repeatable.
torch.manual_seed(123)
print("TORCH VERSION :", version("torch"))

# Pick the compute device in order of preference: CUDA, then MPS, then CPU.
# The MPS probe lives under `torch.backends` (plural); the misspelt
# `torch.backend` only went unnoticed because the CUDA branch short-circuits.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print('GPU  : ', device)

TORCH VERSION : 2.2.1
GPU  :  cuda


In [26]:
import torch.nn as nn
import torch.nn.functional as F

In [27]:
# Naive self-attention: score every row of `inputs` against every other row
# with a plain dot product, then turn each row of scores into weights that
# sum to 1 via softmax. No scaling and no learned projections are involved.
inputs = torch.rand(3, 5)
print(inputs)
attention_scores = torch.matmul(inputs, inputs.transpose(0, 1))
attention_weights = torch.softmax(attention_scores, dim=-1)
attention_weights

tensor([[0.2961, 0.5166, 0.2517, 0.6886, 0.0740],
        [0.8665, 0.1366, 0.1025, 0.1841, 0.7264],
        [0.3153, 0.6871, 0.0756, 0.1966, 0.3164]])


tensor([[0.4070, 0.2828, 0.3103],
        [0.2295, 0.5150, 0.2555],
        [0.3217, 0.3264, 0.3519]])

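The cell above computes plain (unscaled) dot-product self-attention. Dividing the attention scores by the square root of the key dimension before the softmax, as the class below does, gives what is called "scaled dot-product attention".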

In [100]:
class SingleHeadAttention(nn.Module):
    """Single-head causal self-attention using scaled dot-product scores.

    The input is projected to queries, keys and values by three linear layers
    of size ``d_model -> d_model``. Scores for positions after the current
    token are set to ``-inf`` before the softmax, so each token attends only
    to itself and earlier tokens. Dropout is applied to the attention weights.

    Args:
        query: Tensor whose last dimension defines ``d_model``. Only its
            shape is read; the values are not used.
        key: Accepted for interface symmetry; not used.
        value: Accepted for interface symmetry; not used.
        blocksize: Side length of the square causal mask buffer. Inputs to
            ``forward`` should not have more tokens than this, because the
            mask is sliced from a ``blocksize x blocksize`` matrix.
        dropout: Drop probability passed to ``nn.Dropout``.
        mask: Accepted but not consulted; the causal mask is always applied.
            NOTE(review): confirm whether ``mask=False`` was meant to disable
            masking before changing this, since callers may rely on it.
        qkv_bias: Whether the query/key/value linear layers have a bias term.
    """

    def __init__(self, query, key, value, blocksize, dropout, mask=False, qkv_bias=False):
        super().__init__()
        d_model = query.shape[-1]

        self.w_query = nn.Linear(d_model, d_model, bias=qkv_bias)
        self.w_key = nn.Linear(d_model, d_model, bias=qkv_bias)
        self.w_value = nn.Linear(d_model, d_model, bias=qkv_bias)

        self.dropout = nn.Dropout(dropout)

        # Ones strictly above the diagonal mark the (query, key) pairs where
        # the key comes after the query. Registered as a buffer so it moves
        # with the module across devices without being a trainable parameter.
        self.register_buffer("mask", torch.triu(torch.ones(blocksize, blocksize), diagonal=1))

    def forward(self, x):
        """Compute context vectors for ``x``.

        Args:
            x: Tensor of shape ``(batch, token_count, d_model)``.

        Returns:
            Tensor of shape ``(batch, token_count, d_model)``.
        """
        # Unpacking enforces a 3-D input; only the token count is needed.
        _, token_count, _ = x.shape

        queries = self.w_query(x)
        keys = self.w_key(x)
        values = self.w_value(x)

        attention_scores = queries @ keys.transpose(1, 2)
        causal_mask = self.mask.bool()[:token_count, :token_count]
        attention_scores = attention_scores.masked_fill(causal_mask, -torch.inf)

        # Scale by sqrt(d_k); masked entries stay -inf and become 0 after softmax.
        attention_weights = F.softmax(attention_scores / keys.shape[-1] ** 0.5, dim=-1)
        # Previously the dropout layer was built but never applied.
        attention_weights = self.dropout(attention_weights)

        context_vector = attention_weights @ values
        return context_vector

In [101]:
# Random demo input with dimensions (batch=1, tokens=3, embedding=5).
inp = torch.rand((1, 3, 5))

In [102]:
# Build the attention module; the mask side length (3) matches inp's token count.
sha = SingleHeadAttention(query=inp, key=inp, value=inp, blocksize=3, dropout=0.1)

In [104]:
# Sanity check: the output should have the same shape as the input.
sha(inp).size()

torch.Size([1, 3, 5])