In [1]:
import torch

In [2]:

# Prefer Apple's Metal Performance Shaders (MPS) backend when present; otherwise use the CPU.
mps_ok = torch.backends.mps.is_available()
torch_device = torch.device("mps" if mps_ok else "cpu")
if not mps_ok:
    print("MPS device not found.")
else:
    # Tiny probe tensor to confirm allocation on the MPS device works.
    x = torch.ones(1, device=torch_device)
    print(x)

tensor([1.], device='mps:0')


In [9]:
# Toy attention configuration.
EMBEDDING_DIM = 256  # number of features in each token embedding
CONTEXT_LENGTH = 4   # number of positions (tokens) per sequence
SEED = 0             # fixed seed so the random embeddings drawn later are reproducible

torch.manual_seed(SEED)

In [21]:
# Random toy "token embeddings" of shape (8, CONTEXT_LENGTH, EMBEDDING_DIM), drawn uniformly
# from [0, 1) on the CPU and then moved to the selected device.
input_embeddings = torch.rand(
    (8, CONTEXT_LENGTH, EMBEDDING_DIM)
).to(torch_device)

In [22]:
# Expect (batch, CONTEXT_LENGTH, EMBEDDING_DIM).
input_embeddings.size()

torch.Size([8, 4, 256])

In [23]:
# Raw (unscaled) dot-product scores between every pair of positions in each sequence:
# (batch, seq, dim) @ (batch, dim, seq) -> (batch, seq, seq).
keys_t = input_embeddings.transpose(-2, -1)
attn_scores = input_embeddings @ keys_t
attn_scores.shape

torch.Size([8, 4, 4])

In [24]:
# Score matrix of the first sequence in the batch (symmetric, since it is X @ X^T).
attn_scores[0, :, :]

tensor([[89.6776, 62.3817, 55.2851, 66.0383],
        [62.3817, 80.6503, 55.0130, 61.5921],
        [55.2851, 55.0130, 67.5199, 54.2769],
        [66.0383, 61.5921, 54.2769, 85.4145]], device='mps:0')

In [25]:
# Naive normalization: divide each row of scores by its sum so the weights along the last
# axis add up to 1. This only makes sense because the scores are non-negative here
# (the embeddings come from torch.rand).
row_totals = attn_scores.sum(dim=-1, keepdim=True)
attn_weights = attn_scores / row_totals

In [41]:
# Weighted mix of the embeddings per query position:
# (batch, seq, seq) @ (batch, seq, dim) -> (batch, seq, dim).
context = attn_weights @ input_embeddings

In [42]:
# Sum each context vector over its embedding axis: (batch, seq, dim) -> (batch, seq).
# Call .sum on the tensor itself; `context.shape` is a torch.Size and has no .sum method.
context.sum(dim=2)

tensor([[130.9528, 124.2068, 110.9094, 128.1101],
        [128.6789, 124.3847, 114.9565, 131.2061],
        [128.7147, 134.5517, 131.0540, 134.1371],
        [124.9101, 131.6775, 126.2380, 126.1799],
        [123.9312, 124.4568, 116.6398, 123.6487],
        [126.8249, 129.9419, 124.3546, 133.1391],
        [125.9766, 128.0879, 128.4334, 129.7215],
        [134.1459, 124.1152, 127.3237, 129.6205]], device='mps:0')

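### Softmax version

Replace the plain sum normalization with a temperature-scaled softmax. Softmax keeps every attention weight positive and makes each row sum to 1, even when scores can be negative.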

In [54]:
# NOTE: `temperature` MULTIPLIES the scores, so it acts as an inverse temperature:
# smaller values flatten the softmax and larger values sharpen it (0.1 == dividing by 10).
temperature = 0.1
attn_scores = torch.matmul(input_embeddings, input_embeddings.transpose(1, 2))
# Softmax over the last (key) axis, so each query row of attn_weights sums to 1.
attn_weights = torch.softmax(temperature * attn_scores, dim=-1)
# No transpose is needed: attn_weights[b, i, j] is the weight query i puts on position j,
# so context[b, i] = sum_j attn_weights[b, i, j] * input_embeddings[b, j], i.e. W @ X.
context = torch.matmul(attn_weights, input_embeddings)

In [57]:
# Sanity check: every softmax row of the first sequence should sum to 1.
torch.sum(attn_weights[0], dim=-1, keepdim=True)

tensor([[1.0000],
        [1.0000],
        [1.0000],
        [1.0000]], device='mps:0')