In [15]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchsummary

# Hyperparameters shared by the model components defined in this notebook.
batch_size = 32  # presumably sequences per batch; not referenced by the cells shown here
block_size = 8   # side length of the causal mask built in HEAD (maximum sequence length)
n_embed = 32     # input feature width of the key/query/value projections in HEAD
dropout = 0.0    # dropout probability read by FeedForward; 0.0 disables dropout (tune as needed)

## Single Head Attention

In [16]:
class HEAD(nn.Module):
    """A single head of causal (masked) self-attention.

    The input is projected into key, query and value spaces of width
    ``head_size``. Each position attends only to itself and to earlier
    positions; the attention weights are then used to average the values.

    Args:
        head_size: Output width of the key, query and value projections.

    Relies on the notebook-level ``n_embed`` (input feature width) and
    ``block_size`` (side length of the causal mask).
    """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embed, head_size, bias=False)
        self.query = nn.Linear(n_embed, head_size, bias=False)
        self.value = nn.Linear(n_embed, head_size, bias=False)
        # Lower-triangular matrix of ones. Registered as a buffer so it is part
        # of the module's state without being a trainable parameter.
        self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size)))

    def forward(self, inputs):
        """Apply causal self-attention.

        Args:
            inputs: 3-D tensor ``(batch, sequence_length, n_embed)``. The
                sequence length must not exceed ``block_size``, otherwise the
                sliced mask no longer matches the score matrix.

        Returns:
            Tensor of shape ``(batch, sequence_length, head_size)``.
        """
        _, sequence_length, _ = inputs.shape
        keys = self.key(inputs)
        queries = self.query(inputs)
        values = self.value(inputs)

        # Scale by the key/query width (d_k), not by the input embedding width:
        # the dot products are sums over head_size terms, so this is the factor
        # that keeps their variance near one before the softmax.
        head_size = keys.shape[-1]
        weights = queries @ keys.transpose(-2, -1) * (head_size ** -0.5)

        # Block attention to future positions: -inf becomes 0 after softmax.
        causal_mask = self.tril[:sequence_length, :sequence_length] == 0
        weights = weights.masked_fill(causal_mask, float("-inf"))
        weights = F.softmax(weights, dim=-1)

        return weights @ values

In [17]:
# Build one attention head with 16 output features and echo its layer summary.
single_head_att = HEAD(head_size=16)
single_head_att

HEAD(
  (key): Linear(in_features=32, out_features=16, bias=False)
  (query): Linear(in_features=32, out_features=16, bias=False)
  (value): Linear(in_features=32, out_features=16, bias=False)
)

## Multi-Head Attention (a bundle of single attention heads)

In [18]:
class MultiHeadAttention(nn.Module):
    """Run several ``HEAD`` modules on the same input and join their outputs.

    Args:
        num_heads: How many ``HEAD`` instances to create.
        head_size: ``head_size`` passed to every ``HEAD``.
    """

    def __init__(self, num_heads, head_size):
        super().__init__()
        attention_heads = nn.ModuleList()
        for _ in range(num_heads):
            attention_heads.append(HEAD(head_size))
        self.heads = attention_heads

    def forward(self, inputs):
        """Feed ``inputs`` to every head and concatenate along the last dim."""
        per_head_outputs = []
        for attention_head in self.heads:
            per_head_outputs.append(attention_head(inputs))
        return torch.cat(per_head_outputs, dim=-1)

In [19]:
# Bundle four attention heads of width 16 and echo the resulting module tree.
test = MultiHeadAttention(num_heads=4, head_size=16)
test

MultiHeadAttention(
  (heads): ModuleList(
    (0): HEAD(
      (key): Linear(in_features=32, out_features=16, bias=False)
      (query): Linear(in_features=32, out_features=16, bias=False)
      (value): Linear(in_features=32, out_features=16, bias=False)
    )
    (1): HEAD(
      (key): Linear(in_features=32, out_features=16, bias=False)
      (query): Linear(in_features=32, out_features=16, bias=False)
      (value): Linear(in_features=32, out_features=16, bias=False)
    )
    (2): HEAD(
      (key): Linear(in_features=32, out_features=16, bias=False)
      (query): Linear(in_features=32, out_features=16, bias=False)
      (value): Linear(in_features=32, out_features=16, bias=False)
    )
    (3): HEAD(
      (key): Linear(in_features=32, out_features=16, bias=False)
      (query): Linear(in_features=32, out_features=16, bias=False)
      (value): Linear(in_features=32, out_features=16, bias=False)
    )
  )
)

## Feed Forward

In [21]:
class FeedForward(nn.Module):
    """Two-layer MLP applied to the last dimension: Linear -> ReLU -> Linear -> Dropout.

    The hidden layer is four times as wide as the input, and the second linear
    layer projects back to ``n_embed`` features, so input and output widths match.

    Args:
        n_embed: Number of input (and output) features.
        dropout_rate: Dropout probability for the final layer. When omitted,
            the notebook-level ``dropout`` value is used, which was the only
            behaviour before this parameter existed.
    """

    def __init__(self, n_embed, dropout_rate=None):
        super().__init__()
        if dropout_rate is None:
            # Backward-compatible default: read the module-level hyperparameter.
            dropout_rate = dropout
        self.layer = nn.Sequential(
            nn.Linear(n_embed, 4 * n_embed),
            nn.ReLU(),
            nn.Linear(4 * n_embed, n_embed),
            nn.Dropout(dropout_rate)
        )

    def forward(self, input_tensor):
        """Return the MLP output for ``input_tensor`` (last dim must be ``n_embed``)."""
        # The parameter was previously misspelled ``input_tesnor`` while the body
        # used ``input_tensor``, so every call raised NameError.
        return self.layer(input_tensor)