In [5]:
import torch 
import torch.nn as nn
# Example input: one 3-element vector per word of the six-word sentence
# "Your journey starts with one step" (row i corresponds to x^(i+1)).
inputs = torch.tensor([
    [0.43, 0.15, 0.89],  # Your     (x^1)
    [0.55, 0.87, 0.66],  # journey  (x^2)
    [0.57, 0.85, 0.64],  # starts   (x^3)
    [0.22, 0.58, 0.33],  # with     (x^4)
    [0.77, 0.25, 0.10],  # one      (x^5)
    [0.05, 0.80, 0.55],  # step     (x^6)
])

In [10]:
class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention without a causal mask.

    Queries, keys and values are produced by three independent linear
    projections of the same input.

    Args:
        d_in: Size of the last dimension of the input vectors.
        d_out: Size of the query/key/value projections, and therefore of
            each returned context vector.
        qkv_bias: Whether the three projection layers include a bias term.
    """

    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_q = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_k = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_v = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        """Compute one context vector per input token.

        Args:
            x: Tensor of shape ``(num_tokens, d_in)`` or, with any number of
                leading batch dimensions, ``(..., num_tokens, d_in)``.

        Returns:
            Tensor of shape ``(..., num_tokens, d_out)``.
        """
        q = self.W_q(x)
        k = self.W_k(x)
        v = self.W_v(x)
        # Swap only the last two axes: identical to ``k.T`` for 2-D input,
        # but also correct when leading batch dimensions are present.
        attn_scores = torch.matmul(q, k.transpose(-2, -1))
        # Scale by sqrt(d_out) before the softmax over the key axis.
        attn_weights = torch.softmax(attn_scores / k.shape[-1] ** 0.5, dim=-1)
        context_vecs = torch.matmul(attn_weights, v)
        return context_vecs

In [12]:
# Seed the RNG before constructing the module so the randomly initialised
# projection weights, and hence the printed context vectors, are reproducible.
torch.manual_seed(789)
sa_v1 = SelfAttention(d_in=3, d_out=2)
print(sa_v1(inputs))

tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)


In [None]:
import torch 
import torch.nn as nn

class CausalAttention(nn.Module):
    """Single-head causal self-attention over a batch of token sequences.

    Each position attends only to itself and to earlier positions; dropout
    is applied to the attention weights.

    Args:
        d_in: Size of the last dimension of the input vectors.
        d_out: Size of the query/key/value projections and of each output vector.
        context_length: Side length of the pre-built causal mask; sequences
            passed to ``forward`` must not be longer than this.
        drop_ratio: Dropout probability applied to the attention weights.
        qkv_bias: Whether the three projection layers include a bias term.
    """

    def __init__(self, d_in, d_out, context_length, drop_ratio, qkv_bias=False):
        # nn.Module must be initialised before sub-modules or buffers can be
        # attached; without this call the assignments below raise.
        super().__init__()
        self.wq = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.wk = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.wv = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(drop_ratio)
        # Ones strictly above the diagonal mark the "future" positions.
        # Registered as a buffer so it follows the module across devices.
        self.register_buffer(
            'mask',
            torch.triu(torch.ones(context_length, context_length), diagonal=1),
        )

    def forward(self, x):
        """Compute causally masked context vectors.

        Args:
            x: Tensor of shape ``(batch, num_tokens, d_in)``.

        Returns:
            Tensor of shape ``(batch, num_tokens, d_out)``.
        """
        # Unpacking also enforces that the input is exactly 3-dimensional.
        _, num_tokens, _ = x.shape
        q = self.wq(x)
        k = self.wk(x)
        v = self.wv(x)
        # Transpose only the token and feature axes; ``k.T`` would reverse
        # all three axes, including the batch axis.
        attn_scores = torch.matmul(q, k.transpose(1, 2))
        # -inf scores become exactly zero weight after the softmax.
        attn_scores.masked_fill_(
            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf
        )
        attn_weights = torch.softmax(attn_scores / k.shape[-1] ** 0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)
        context_vec = torch.matmul(attn_weights, v)
        return context_vec
        

In [14]:
class MultiHeadAttentionWrapper(nn.Module):
    """Multi-head attention built from independent ``CausalAttention`` heads.

    Every head receives the same input; the per-head outputs are
    concatenated along the last dimension.

    Args:
        num_heads: Number of ``CausalAttention`` heads to create.
        d_in: Size of the last dimension of the input vectors.
        d_out: Output size of EACH head (the concatenated output therefore
            has ``num_heads * d_out`` features).
        context_length: Passed to each head to size its causal mask.
        dropout: Dropout probability passed to each head.
        qkv_bias: Whether each head's projection layers include a bias term.
    """

    def __init__(self, num_heads, d_in, d_out, context_length, dropout, qkv_bias=False):
        # Required before any sub-module can be attached to this module.
        super().__init__()
        # ModuleList (rather than a plain list) registers the heads so their
        # parameters are visible to optimizers and .to()/.state_dict().
        self.heads = nn.ModuleList(
            [
                # Forward the caller's qkv_bias instead of a fixed False.
                CausalAttention(d_in, d_out, context_length, dropout, qkv_bias=qkv_bias)
                for _ in range(num_heads)
            ]
        )

    def forward(self, x):
        """Run every head on ``x`` and concatenate the results on the last axis."""
        return torch.cat([head(x) for head in self.heads], dim=-1)

In [None]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with a single fused projection per Q/K/V.

    The ``d_out``-wide projections are split into ``num_heads`` heads of
    ``d_out // num_heads`` features each, attended independently, merged
    back together and passed through a final linear layer.

    Args:
        num_heads: Number of attention heads; must divide ``d_out``.
        d_in: Size of the last dimension of the input vectors.
        d_out: Total output width across all heads.
        context_length: Side length of the pre-built causal mask; sequences
            passed to ``forward`` must not be longer than this.
        dropout: Dropout probability applied to the attention weights.
        qkv_bias: Whether the Q/K/V projection layers include a bias term.

    Raises:
        AssertionError: If ``d_out`` is not divisible by ``num_heads``.
    """

    def __init__(self, num_heads, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        # The message must be a plain string: wrapping it in print() would
        # emit it to stdout and leave the AssertionError carrying None.
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads
        self.wq = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.wk = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.wv = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)
        # Ones strictly above the diagonal mark the "future" positions.
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        """Compute causally masked multi-head context vectors.

        Args:
            x: Tensor of shape ``(batch, num_tokens, d_in)``.

        Returns:
            Tensor of shape ``(batch, num_tokens, d_out)``.
        """
        b, num_tokens, _ = x.shape

        q = self.wq(x)
        k = self.wk(x)
        v = self.wv(x)

        # (b, num_tokens, d_out) -> (b, num_heads, num_tokens, head_dim)
        q = q.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)

        # Per-head scores: (b, num_heads, num_tokens, num_tokens)
        attn_scores = torch.matmul(q, k.transpose(2, 3))
        # The 2-D mask broadcasts over the batch and head axes.
        attn_scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
        attn_weights = torch.softmax(attn_scores / k.shape[-1] ** 0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)
        context_vec = torch.matmul(attn_weights, v)

        # Merge heads back: transpose makes the tensor non-contiguous, so
        # .contiguous() is needed before .view().
        context_vec = context_vec.transpose(1, 2).contiguous().view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)

        return context_vec