In [2]:
import torch

In [3]:

if torch.backends.mps.is_available():
    torch_device = torch.device("mps")
    x = torch.ones(1, device=torch_device)
    print (x)
else:
    torch_device = torch.device("cpu")
    print ("MPS device not found.")

tensor([1.], device='mps:0')


In [59]:
EMBEDDING_DIM = 3
CONTEXT_LENGTH = 4
QKV_DIM = 2

In [60]:
import math

In [61]:
input_embeddings = torch.randn(8, CONTEXT_LENGTH, EMBEDDING_DIM).to(torch_device)/math.sqrt(EMBEDDING_DIM)

### Trainable

In [62]:
w_q = torch.randn(EMBEDDING_DIM, QKV_DIM).to(torch_device)
w_k = torch.randn(EMBEDDING_DIM, QKV_DIM).to(torch_device)
w_v = torch.randn(EMBEDDING_DIM, QKV_DIM).to(torch_device)

In [63]:
query = torch.matmul(input_embeddings, w_q)
keys = torch.matmul(input_embeddings, w_k)
attn_scores = torch.matmul(query, keys.transpose(1, 2))
attn_weights = torch.softmax(attn_scores*(QKV_DIM**0.5), dim=-1)
values = torch.matmul(input_embeddings, w_v)
context = torch.matmul(attn_weights, values)  # Is this correct or should I transpose attn_weights?

In [64]:
attn_scores[0]

tensor([[-0.4733,  2.4863,  1.9936, -1.4273],
        [ 1.4557, -0.1330, -0.3731,  1.0159],
        [ 1.4137, -1.4526, -1.3766,  1.5809],
        [-0.5814,  0.0731,  0.1643, -0.4147]], device='mps:0')

In [65]:
attn_weights[0]

tensor([[0.0100, 0.6590, 0.3283, 0.0026],
        [0.5821, 0.0616, 0.0438, 0.3125],
        [0.4341, 0.0075, 0.0084, 0.5499],
        [0.1305, 0.3294, 0.3748, 0.1652]], device='mps:0')

In [66]:
attn_weights[0].sum(dim=-1, keepdim=True)

tensor([[1.0000],
        [1.0000],
        [1.0000],
        [1.0000]], device='mps:0')

In [67]:
context[0]

tensor([[ 0.1628, -0.2580],
        [-0.2273,  0.0354],
        [-0.1950,  0.1418],
        [ 0.0869, -0.1119]], device='mps:0')

## Module

In [77]:
import torch.nn as nn

EMBEDDING_DIM = 3
CONTEXT_LENGTH = 4
QKV_DIM = 2

class SelfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.d_in = d_in
        self.d_out = d_out
        self.w_q = torch.randn(d_in, d_out).to(torch_device)
        self.w_k = torch.randn(d_in, d_out).to(torch_device)
        self.w_v = torch.randn(d_in, d_out).to(torch_device)

    def forward(self, x):
        query = torch.matmul(input_embeddings, self.w_q)
        keys = torch.matmul(input_embeddings, self.w_k)
        attn_scores = torch.matmul(query, keys.transpose(1, 2))
        attn_weights = torch.softmax(attn_scores*(self.d_out**0.5), dim=-1)
        values = torch.matmul(input_embeddings, self.w_v)
        context = torch.matmul(attn_weights, values)  # Is this correct or should I transpose attn_weights?
        return context

In [78]:
torch.manual_seed(123)
input_embeddings = torch.randn(8, CONTEXT_LENGTH, EMBEDDING_DIM).to(torch_device)/math.sqrt(EMBEDDING_DIM)
sa_v1 = SelfAttention_v1(EMBEDDING_DIM, QKV_DIM)
context = sa_v1(input_embeddings)
context[0]

tensor([[-0.0253,  0.7391],
        [ 0.0095,  0.6203],
        [-0.0135,  0.6879],
        [-0.2222,  1.0093]], device='mps:0')

In [85]:
EMBEDDING_DIM = 3
CONTEXT_LENGTH = 4
QKV_DIM = 2

class SelfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.d_in = d_in
        self.d_out = d_out
        self.w_q = nn.Linear(d_in, d_out, bias=qkv_bias).to(torch_device)
        self.w_k = nn.Linear(d_in, d_out, bias=qkv_bias).to(torch_device)
        self.w_v = nn.Linear(d_in, d_out, bias=qkv_bias).to(torch_device)

    def forward(self, x):
        query = self.w_q(input_embeddings)
        keys = self.w_k(input_embeddings)
        attn_scores = torch.matmul(query, keys.transpose(1, 2))
        attn_weights = torch.softmax(attn_scores*(self.d_out**0.5), dim=-1)
        values = self.w_v(input_embeddings)
        context = torch.matmul(attn_weights, values)  # Is this correct or should I transpose attn_weights?
        return context

In [100]:
sa_v2 = SelfAttention_v2(EMBEDDING_DIM, QKV_DIM)
context = sa_v2(input_embeddings)
context[0]

tensor([[-0.0099, -0.0742],
        [-0.0119, -0.0634],
        [-0.0090, -0.0737],
        [-0.0149, -0.0619]], device='mps:0', grad_fn=<SelectBackward0>)

In [101]:
sa_v1.w_k = sa_v2.w_k.weight.T
sa_v1.w_q = sa_v2.w_q.weight.T
sa_v1.w_v = sa_v2.w_v.weight.T
context = sa_v1(input_embeddings)
context[0]

tensor([[-0.0099, -0.0742],
        [-0.0119, -0.0634],
        [-0.0090, -0.0737],
        [-0.0149, -0.0619]], device='mps:0', grad_fn=<SelectBackward0>)