In [2]:
import torch

In [3]:

if torch.backends.mps.is_available():
    torch_device = torch.device("mps")
    x = torch.ones(1, device=torch_device)
    print (x)
else:
    torch_device = torch.device("cpu")
    print ("MPS device not found.")

tensor([1.], device='mps:0')


## Module

In [8]:
import torch.nn as nn

EMBEDDING_DIM = 3
CONTEXT_LENGTH = 4
QKV_DIM = 2

class CausalSelfAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length=CONTEXT_LENGTH, dropout=0.5, qkv_bias=False):
        super().__init__()
        self.d_in = d_in
        self.d_out = d_out
        self.w_q = nn.Linear(d_in, d_out, bias=qkv_bias).to(torch_device)
        self.w_k = nn.Linear(d_in, d_out, bias=qkv_bias).to(torch_device)
        self.w_v = nn.Linear(d_in, d_out, bias=qkv_bias).to(torch_device)
        self.dropout = nn.Dropout(dropout).to(torch_device)
        self.register_buffer(
            'mask', 
            torch.triu(
                torch.ones(context_length, context_length), 
                diagonal=1,
            ).to(torch_device)
        )

    def forward(self, x):
        num_tokens = x.shape[-2]
        queries = self.w_q(x)
        keys = self.w_k(x)
        attn_scores = queries @ keys.transpose(-2, -1)
        causal_attn_scores = attn_scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
        causal_attn_weights = torch.softmax(causal_attn_scores*(self.d_out**0.5), dim=-1)
        causal_attn_weights = self.dropout(causal_attn_weights)
        values = self.w_v(x)
        context = causal_attn_weights @ values
        return context

In [9]:
csa = CausalSelfAttention(EMBEDDING_DIM, QKV_DIM)
input_embeddings = torch.randn(8, CONTEXT_LENGTH, EMBEDDING_DIM).to(torch_device)/(EMBEDDING_DIM**0.5)
context = csa(input_embeddings)
context[0]

tensor([[-0.1064, -0.0565],
        [ 0.0000,  0.0000],
        [ 0.0390, -0.0119],
        [-0.0236, -0.0125]], device='mps:0', grad_fn=<SelectBackward0>)

In [10]:
import torch.nn as nn

EMBEDDING_DIM = 3
CONTEXT_LENGTH = 4
QKV_DIM = 2

class MultiHeadAttentionWrapper(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        self.heads = nn.ModuleList(
            [CausalSelfAttention(d_in, d_out, context_length, dropout, qkv_bias) for _ in range(num_heads)]
        )

    def forward(self, x):
        context = torch.cat([head(x) for head in self.heads], dim=-1)
        return context

In [11]:
mha = MultiHeadAttentionWrapper(EMBEDDING_DIM, QKV_DIM, CONTEXT_LENGTH, 0.5, 2)
context = mha(input_embeddings)
context.shape

torch.Size([8, 4, 4])

In [24]:
import torch.nn as nn

EMBEDDING_DIM = 3
CONTEXT_LENGTH = 4
QKV_DIM = 2

class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = d_out//num_heads
        self.w_q = nn.Linear(d_in, d_out, bias=qkv_bias).to(torch_device)
        self.w_k = nn.Linear(d_in, d_out, bias=qkv_bias).to(torch_device)
        self.w_v = nn.Linear(d_in, d_out, bias=qkv_bias).to(torch_device)
        self.w_o = nn.Linear(d_out, d_out).to(torch_device)
        self.dropout = nn.Dropout(dropout).to(torch_device)
        self.register_buffer(
            'mask', 
            torch.triu(
                torch.ones(context_length, context_length), 
                diagonal=1,
            ).to(torch_device)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        keys = self.w_k(x)
        queries = self.w_q(x)
        values = self.w_v(x)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim).transpose(-3, -2)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim).transpose(-3, -2)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim).transpose(-3, -2)
        attn_scores = queries @ keys.transpose(-2, -1)
        causal_attn_scores = attn_scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
        causal_attn_weights = torch.softmax(causal_attn_scores*(keys.shape[-1]**0.5), dim=-1)
        causal_attn_weights = self.dropout(causal_attn_weights)
        context = (causal_attn_weights @ values).transpose(-3, -2)
        context = context.contiguous().view(b, num_tokens, self.num_heads*self.head_dim)
        context = self.w_o(context)
        return context


In [26]:
EMBEDDING_DIM = 3
CONTEXT_LENGTH = 4
QKV_DIM = 4

mha = MultiHeadAttention(EMBEDDING_DIM, QKV_DIM, CONTEXT_LENGTH, 0.5, 2)
context = mha(input_embeddings)
context.shape

torch.Size([8, 4, 2, 2])
torch.Size([8, 4, 4])
torch.Size([4, 4])


torch.Size([8, 4, 4])