In [134]:
from tinygrad import Tensor, nn

inputs = Tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)

inputs.realize()

print(inputs.numpy())

[[0.43 0.15 0.89]
 [0.55 0.87 0.66]
 [0.57 0.85 0.64]
 [0.22 0.58 0.33]
 [0.77 0.25 0.1 ]
 [0.05 0.8  0.55]]


In [142]:
import math

class MultiHeadAttention:

    def __init__(self, d_in: int, d_out: int, context_length: int, n_heads: int, dropout: float = 0.0, bias=False) -> None:
        self.d_in = d_in
        self.d_out = d_out
        self.context_length = context_length
        self.n_heads = n_heads
        self.head_dim = self.d_out // n_heads

        assert self.d_out % n_heads == 0

        # Linear Projections
        self.W_q = nn.Linear(d_in, d_out, bias=bias)
        self.W_k = nn.Linear(d_in, d_out, bias=bias)
        self.W_v = nn.Linear(d_in, d_out, bias=bias)

        self.dropout = dropout
        self.causal_mask = Tensor.triu(Tensor.ones(context_length, context_length), diagonal=1).bool()
        self.out_proj = nn.Linear(d_out, d_out)

    def forward(self, x: Tensor) -> Tensor:
        b, num_tokens, _ = x.shape

        # linear projections (B, num_tokens, d_out)
        keys = self.W_k(x)
        queries = self.W_q(x)
        values = self.W_v(x)

        # split heads (B, num_tokens, d_out) -> (B, n_heads, num_tokens, head_dim)
        keys = keys.reshape(b, num_tokens, self.n_heads, self.head_dim).transpose(1,2)
        queries = queries.reshape(b, num_tokens, self.n_heads, self.head_dim).transpose(1,2)
        values = values.reshape(b, num_tokens, self.n_heads, self.head_dim).transpose(1,2)
        
        attention_scores = queries @ keys.transpose(2,3)
        attention_scores = attention_scores.masked_fill(self.causal_mask[:num_tokens, :num_tokens], -math.inf)

        attention_weights = (attention_scores / self.head_dim ** 0.5).softmax(axis=-1)
        attention_weights = attention_weights.dropout(self.dropout)   

        context_vec = attention_weights @ values

        # (B, n_heads, num_tokens, head_dim) -> (B, num_tokens, d_out)
        context_vec = context_vec.view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)
        return context_vec

Tensor.manual_seed(42)
attn = MultiHeadAttention(3, 2, context_length=10, n_heads=2, dropout=0.0)

batch = Tensor.stack([inputs for _ in range(5)])
Tensor.training = True

print(attn.forward(batch).realize().numpy())

[[[-0.539744   -0.29897985]
  [-0.5340578  -0.32340965]
  [-0.5515155  -0.3562595 ]
  [-0.5329706  -0.5371098 ]
  [-0.6656338  -0.71417296]
  [-0.6713593  -0.7271717 ]]

 [[-0.539744   -0.29897985]
  [-0.5340578  -0.32340965]
  [-0.5515155  -0.3562595 ]
  [-0.5329706  -0.5371098 ]
  [-0.6656338  -0.71417296]
  [-0.6713593  -0.7271717 ]]

 [[-0.539744   -0.29897985]
  [-0.5340578  -0.32340965]
  [-0.5515155  -0.3562595 ]
  [-0.5329706  -0.5371098 ]
  [-0.6656338  -0.71417296]
  [-0.6713593  -0.7271717 ]]

 [[-0.539744   -0.29897985]
  [-0.5340578  -0.32340965]
  [-0.5515155  -0.3562595 ]
  [-0.5329706  -0.5371098 ]
  [-0.6656338  -0.71417296]
  [-0.6713593  -0.7271717 ]]

 [[-0.539744   -0.29897985]
  [-0.5340578  -0.32340965]
  [-0.5515155  -0.3562595 ]
  [-0.5329706  -0.5371098 ]
  [-0.6656338  -0.71417296]
  [-0.6713593  -0.7271717 ]]]
