In [None]:
class MultiHeadAttention(nn.Module):
    """Causal (masked) multi-head self-attention.

    The input is linearly projected to queries, keys and values of width
    ``d_out``. These are split into ``num_heads`` heads of width
    ``d_out // num_heads`` and combined with scaled dot-product attention
    under a causal mask. Each token may attend only to itself and earlier
    tokens. The heads are then merged back together and passed through a
    final linear projection.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Size of the last dimension of the output; must be divisible
            by ``num_heads``.
        context_length: Maximum supported sequence length; sizes the mask.
        dropout: Dropout probability applied to the attention weights.
        num_heads: Number of attention heads (must be >= 1).
        bias: Whether the query/key/value projections have a bias term.
            The output projection always has one.

    Raises:
        ValueError: If ``num_heads`` is not positive or ``d_out`` is not
            divisible by ``num_heads``.
    """

    def __init__(self, d_in,
                 d_out,
                 context_length,
                 dropout,
                 num_heads,
                 bias=False):
        super().__init__()
        # Validate up front: otherwise a bad config only surfaces later as an
        # obscure shape error from `view` inside forward().
        if num_heads < 1:
            raise ValueError(f"num_heads must be positive, got {num_heads}")
        if d_out % num_heads != 0:
            raise ValueError(
                f"d_out ({d_out}) must be divisible by num_heads ({num_heads})"
            )

        self.W_Query  = nn.Linear(d_in, d_out, bias=bias)
        self.W_Key    = nn.Linear(d_in, d_out, bias=bias)
        self.W_Value  = nn.Linear(d_in, d_out, bias=bias)
        self.out_proj = nn.Linear(d_out, d_out)
        self.d_out    = d_out
        self.num_head = num_heads
        self.head_dim = d_out // num_heads
        self.dropout  = nn.Dropout(dropout)

        # Ones strictly above the diagonal mark "future" positions to hide.
        # Registered as a buffer so it follows the module across devices and
        # is saved in the state_dict, but is not a trainable parameter.
        self.register_buffer('mask',
                             torch.triu(torch.ones(context_length, context_length),
                                        diagonal=1))

    def forward(self, x):
        """Apply causal multi-head self-attention.

        Args:
            x: Tensor of shape ``(batch, num_tokens, d_in)``.

        Returns:
            Tensor of shape ``(batch, num_tokens, d_out)``.

        Raises:
            ValueError: If ``num_tokens`` exceeds the ``context_length`` the
                module was built with.
        """
        b, num_tokens, _ = x.shape
        max_tokens = self.mask.shape[0]
        if num_tokens > max_tokens:
            raise ValueError(
                f"sequence length {num_tokens} exceeds context_length {max_tokens}"
            )

        keys    = self.W_Key(x)
        queries = self.W_Query(x)
        values  = self.W_Value(x)

        # Split the last dim into heads:
        # (b, n, d_out) -> (b, n, heads, head_dim) -> (b, heads, n, head_dim)
        keys    = keys.view(b, num_tokens, self.num_head, self.head_dim).transpose(1, 2)
        queries = queries.view(b, num_tokens, self.num_head, self.head_dim).transpose(1, 2)
        values  = values.view(b, num_tokens, self.num_head, self.head_dim).transpose(1, 2)

        # Per-head scores: (b, heads, n, head_dim) @ (b, heads, head_dim, n)
        att_score = queries @ keys.transpose(2, 3)

        # Hide future positions; -inf becomes exactly 0 weight after softmax.
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
        att_score.masked_fill_(mask_bool, -torch.inf)

        # Scale by sqrt(head_dim) before softmax (scaled dot-product attention).
        att_weights = torch.softmax(att_score / self.head_dim**0.5, dim=-1)
        att_weights = self.dropout(att_weights)

        # Merge heads back: (b, heads, n, head_dim) -> (b, n, heads, head_dim)
        # -> (b, n, d_out). `contiguous` is required before `view` after transpose.
        context_vec = (att_weights @ values).transpose(1, 2)
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        return self.out_proj(context_vec)

In [None]:
# Toy 3-dimensional embeddings, one row per token. The first six rows are the
# words of "Your journey starts with one step"; the seventh row carries no word
# label and brings the sequence to 7 tokens.
inputs = torch.tensor([
    [0.43, 0.15, 0.89],        # Your     (x^1)
    [0.55, 0.87, 0.66],        # journey  (x^2)
    [0.57, 0.85, 0.64],        # starts   (x^3)
    [0.22, 0.58, 0.33],        # with     (x^4)
    [0.77, 0.25, 0.10],        # one      (x^5)
    [0.05, 0.80, 0.55],        # step     (x^6)
    [0.4419, 0.6515, 0.5683],  # 7th row (no word label)
])
# Batch two identical copies of the sequence along a new leading dimension.
batch = torch.stack([inputs] * 2)
batch.shape

In [None]:
# Toy configuration: 3-dim inputs projected to a 6-dim output, with a
# context length of 7 tokens.
d_in, d_out, context_length = 3, 6, 7
# Seed before construction so the randomly initialized projection weights
# are reproducible.
torch.manual_seed(42)
MUA = MultiHeadAttention(
    d_in, d_out, context_length, dropout=0, num_heads=2, bias=False
)