In [3]:
import torch

In [4]:

if torch.backends.mps.is_available():
    torch_device = torch.device("mps")
    x = torch.ones(1, device=torch_device)
    print (x)
else:
    torch_device = torch.device("cpu")
    print ("MPS device not found.")

tensor([1.], device='mps:0')


## Module

In [27]:
import torch.nn as nn

EMBEDDING_DIM = 3
CONTEXT_LENGTH = 4
QKV_DIM = 2

class CausalSelfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.d_in = d_in
        self.d_out = d_out
        self.w_q = nn.Linear(d_in, d_out, bias=qkv_bias).to(torch_device)
        self.w_k = nn.Linear(d_in, d_out, bias=qkv_bias).to(torch_device)
        self.w_v = nn.Linear(d_in, d_out, bias=qkv_bias).to(torch_device)

    def forward(self, x):
        queries = self.w_q(x)
        keys = self.w_k(x)
        attn_scores = torch.matmul(queries, keys.transpose(1, 2))
        attn_weights = torch.softmax(attn_scores*(self.d_out**0.5), dim=-1)
        causal_attn_weights = torch.tril(attn_weights)
        row_sum = causal_attn_weights.sum(dim=-1, keepdim=True)
        causal_attn_weights = causal_attn_weights / row_sum
        values = self.w_v(x)
        context = torch.matmul(causal_attn_weights, values)  # Is this correct or should I transpose attn_weights?
        return context

In [33]:
torch.manual_seed(123)
input_embeddings = torch.randn(8, CONTEXT_LENGTH, EMBEDDING_DIM).to(torch_device)/(EMBEDDING_DIM**0.5)
csa_v1 = CausalSelfAttention_v1(EMBEDDING_DIM, QKV_DIM)
context = csa_v1(input_embeddings)
context[0]

tensor([[-0.0159, -0.0788],
        [ 0.0206,  0.0283],
        [ 0.0588, -0.0017],
        [-0.0406,  0.0841]], device='mps:0', grad_fn=<SelectBackward0>)

In [36]:
import torch.nn as nn

EMBEDDING_DIM = 3
CONTEXT_LENGTH = 4
QKV_DIM = 2

class CausalSelfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.d_in = d_in
        self.d_out = d_out
        self.w_q = nn.Linear(d_in, d_out, bias=qkv_bias).to(torch_device)
        self.w_k = nn.Linear(d_in, d_out, bias=qkv_bias).to(torch_device)
        self.w_v = nn.Linear(d_in, d_out, bias=qkv_bias).to(torch_device)

    def forward(self, x):
        queries = self.w_q(x)
        keys = self.w_k(x)
        attn_scores = torch.matmul(queries, keys.transpose(1, 2))
        causal_attn_scores = attn_scores.masked_fill(torch.triu(torch.ones_like(attn_scores), 1) == 1, float('-inf'))
        causal_attn_weights = torch.softmax(causal_attn_scores*(self.d_out**0.5), dim=-1)
        values = self.w_v(x)
        context = torch.matmul(causal_attn_weights, values)  # Is this correct or should I transpose attn_weights?
        return context

In [37]:
torch.manual_seed(123)
# input_embeddings = torch.randn(8, CONTEXT_LENGTH, EMBEDDING_DIM).to(torch_device)/(EMBEDDING_DIM**0.5)
csa_v2 = CausalSelfAttention_v2(EMBEDDING_DIM, QKV_DIM)
csa_v2.w_q.weight.data = csa_v1.w_q.weight.data
csa_v2.w_k.weight.data = csa_v1.w_k.weight.data
csa_v2.w_v.weight.data = csa_v1.w_v.weight.data
context = csa_v2(input_embeddings)
context[0]

tensor([[-0.0159, -0.0788],
        [ 0.0206,  0.0283],
        [ 0.0588, -0.0017],
        [-0.0406,  0.0841]], device='mps:0', grad_fn=<SelectBackward0>)

In [102]:
import torch.nn as nn

EMBEDDING_DIM = 3
CONTEXT_LENGTH = 4
QKV_DIM = 2

class CausalSelfAttention_v3(nn.Module):
    def __init__(self, d_in, d_out, context_length=CONTEXT_LENGTH, dropout=0.5, qkv_bias=False):
        super().__init__()
        self.d_in = d_in
        self.d_out = d_out
        self.w_q = nn.Linear(d_in, d_out, bias=qkv_bias).to(torch_device)
        self.w_k = nn.Linear(d_in, d_out, bias=qkv_bias).to(torch_device)
        self.w_v = nn.Linear(d_in, d_out, bias=qkv_bias).to(torch_device)
        self.dropout = nn.Dropout(dropout).to(torch_device)
        self.register_buffer(
            'mask', 
            torch.triu(
                torch.ones(context_length, context_length), 
                diagonal=1,
            ).to(torch_device)
        )

    def forward(self, x):
        num_tokens = x.shape[-2]
        queries = self.w_q(x)
        keys = self.w_k(x)
        attn_scores = torch.matmul(queries, keys.transpose(-2, -1))
        causal_attn_scores = attn_scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
        causal_attn_weights = torch.softmax(causal_attn_scores*(self.d_out**0.5), dim=-1)
        causal_attn_weights = self.dropout(causal_attn_weights)
        values = self.w_v(x)
        context = torch.matmul(causal_attn_weights, values)
        return context

In [103]:
csa_v3 = CausalSelfAttention_v3(EMBEDDING_DIM, QKV_DIM)
csa_v3.w_q.weight.data = csa_v1.w_q.weight.data
csa_v3.w_k.weight.data = csa_v1.w_k.weight.data
csa_v3.w_v.weight.data = csa_v1.w_v.weight.data
context = csa_v3(input_embeddings)
context[0]

tensor([[ 0.0000,  0.0000],
        [-0.0162, -0.0804],
        [ 0.0809, -0.0908],
        [ 0.0000,  0.0000]], device='mps:0', grad_fn=<SelectBackward0>)

In [104]:
import torch.nn as nn

EMBEDDING_DIM = 3
CONTEXT_LENGTH = 4
QKV_DIM = 2

class MultiHeadAttentionWrapper(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        self.heads = nn.ModuleList(
            [CausalSelfAttention_v3(d_in, d_out, context_length, dropout, qkv_bias) for _ in range(num_heads)]
        )

    def forward(self, x):
        context = torch.cat([head(x) for head in self.heads], dim=-1)
        return context

In [106]:
mha = MultiHeadAttentionWrapper(EMBEDDING_DIM, QKV_DIM, CONTEXT_LENGTH, 0.5, 2)
context = mha(input_embeddings)
context.shape

torch.Size([8, 4, 4])