In [97]:
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.__version__

'2.6.0+cu124'

# Data

In [None]:
# ------------------- preparation -------------------
names = open("names.txt", "r", encoding="utf-8").read().splitlines()
chars = sorted(set(''.join(names)))
ctoi = { ch:i+1 for i, ch in enumerate(chars) }
ctoi['.'] = 0
itoc = { i:ch for ch, i in ctoi.items() }

# ------------------- modules -------------------
encode = lambda string: torch.tensor([ctoi[char] for char in string])
decode = lambda tensor: ''.join(itoc[idx.item()] for idx in tensor)

def build_dataset(data: list, context_size: int):
  features, labels = [], []

  for word in data:
    padded_word = word + '.'
    indices = [ctoi[char] for char in padded_word]
    context = [0] * context_size + indices

    for i in range(len(indices)):
      sequence = context[i : i + context_size + 1]
      features.append(sequence[:-1])
      labels.append(sequence[1:])

  return torch.tensor(features, dtype=torch.long), torch.tensor(labels, dtype=torch.long)

def get_batch(features: torch.Tensor, labels: torch.Tensor, batch_size: int):
  batch_indices = torch.randint(0, features.shape[0], size=(batch_size,))

  return features[batch_indices], labels[batch_indices]

# ------------------- datasets -------------------
features, labels = build_dataset(names, context_size=8)

num_train = int(0.8 * features.shape[0])
train_features, train_labels = features[:num_train], labels[:num_train]
val_features, val_labels = features[num_train:], labels[num_train:]

train_features.shape, val_features.shape

(torch.Size([182516, 8]), torch.Size([45630, 8]))

# Custom Transformer

Modules required for the Transformer architecture:
- Head (the basic unit)
- MultiHeadAttention (multiple heads + projection + dropout)
- FeedForward (a simple neural net)
- Block (MultiHeadAttention + FeedForward using skip connections and layernorms)

In [99]:
class Head(nn.Module):
    def __init__(self, embedding_dim: int, head_size: int, dropout_p: float = 0.2):
        super().__init__()
        self.head_size = head_size
        self.query = nn.Linear(embedding_dim, head_size, bias=False)
        self.key = nn.Linear(embedding_dim, head_size, bias=False)
        self.value = nn.Linear(embedding_dim, head_size, bias=False)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)

        scores = q @ k.transpose(-2, -1) * (self.head_size**-0.5)

        tril = torch.tril(torch.ones(T, T))
        scores = scores.masked_fill(tril == 0, float("-inf"))

        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)

        out = attn @ v
        return out

class MultiHeadAttention(nn.Module):
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        embedding_dim: int,
        dropout_p: float = 0.2
    ):
        super().__init__()
        self.heads = nn.ModuleList([Head(embedding_dim, head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(num_heads * head_size, embedding_dim)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = torch.cat([head(x) for head in self.heads], dim=-1)
        x = self.dropout(self.proj(x))
        return x


class FeedForward(nn.Module):
    def __init__(self, embedding_dim: int, dropout_p: float = 0.2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(embedding_dim, embedding_dim*4),
            nn.ReLU(),
            nn.Linear(embedding_dim*4, embedding_dim),
            nn.Dropout(dropout_p)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

class Block(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        head_size: int
    ):
      super().__init__()
      self.self_attention = MultiHeadAttention(
          num_heads=num_heads,
          head_size=head_size,
          embedding_dim=embedding_dim
          )
      self.feed_forward = FeedForward(embedding_dim)
      self.layernorm1 = nn.LayerNorm(embedding_dim)
      self.layernorm2 = nn.LayerNorm(embedding_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
      x = x + self.self_attention(self.layernorm1(x))
      x = x + self.feed_forward(self.layernorm2(x))
      return x

class Net(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        context_size: int,
        embedding_dim: int,
        num_blocks: int,
        num_heads: int,
        head_size: int
    ):
      super().__init__()
      self.context_size = context_size

      self.token_embedding_table = nn.Embedding(vocab_size, embedding_dim)
      self.position_embedding_table = nn.Embedding(context_size, embedding_dim)

      self.blocknet = nn.Sequential(*[Block(embedding_dim, num_heads, head_size) for _ in range(num_blocks)])

      self.layernorm = nn.LayerNorm(embedding_dim)
      self.decoder = nn.Linear(embedding_dim, vocab_size)

    def forward(self, index: torch.Tensor, targets: torch.Tensor = None):
      B, T = index.shape

      token_emb = self.token_embedding_table(index)
      position_emb = self.position_embedding_table(torch.arange(T))
      x = token_emb + position_emb

      x = self.layernorm(self.blocknet(x))
      logits = self.decoder(x)

      if targets is None:
        loss = None
      else:
        B, T, C = logits.shape
        logits = logits.view(B*T, C)
        targets = targets.view(B*T)
        loss = F.cross_entropy(logits, targets)

      return logits, loss

    def generate(self, num_tokens: int):
      tokens = torch.zeros((1, 1), dtype=torch.long)

      self.eval()
      with torch.inference_mode():
        for _ in range(num_tokens):
          output = []
          context = [0] * self.context_size
          while True:
            logits, _ = self(torch.tensor([context]))
            probabilities = F.softmax(logits[:, -1, :], dim=-1)
            index = torch.multinomial(probabilities, num_samples=1)

            output.append(index.item())
            context = context[1:] + [index]

            if index == 0:
              break

        print(''.join(itoc[index] for index in output))

    def fit(self,
            features: torch.Tensor,
            labels: torch.Tensor,
            optimizer: torch.optim.Optimizer,
            batch_size: int,
            epochs: int = 1,
            verbose_frequency: int | None = None):
      if verbose_frequency is None:
        verbose_frequency = epochs // 10

      for epoch in range(epochs):
        features_batch, labels_batch = get_batch(features, labels, batch_size)
        _, loss = self(features_batch, labels_batch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if epoch % verbose_frequency == 0 or epoch+1==epochs:
          print(f"Epoch {epoch:6d}/{epochs} - loss: {loss:.4f}")

# Hyperparameters

In [100]:
vocab_size = len(chars) + 1 # plus the beginning/ending symbol
context_size = 8
embedding_dim = 128
num_blocks = 4
num_heads = 4
head_size = embedding_dim // num_heads
learning_rate = 1e-3
batch_size = 64
epochs = 5000

model = Net(vocab_size, context_size, embedding_dim, num_blocks, num_heads, head_size)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

model

Net(
  (token_embedding_table): Embedding(27, 128)
  (position_embedding_table): Embedding(8, 128)
  (blocknet): Sequential(
    (0): Block(
      (self_attention): MultiHeadAttention(
        (heads): ModuleList(
          (0-3): 4 x Head(
            (query): Linear(in_features=128, out_features=32, bias=False)
            (key): Linear(in_features=128, out_features=32, bias=False)
            (value): Linear(in_features=128, out_features=32, bias=False)
            (dropout): Dropout(p=0.2, inplace=False)
          )
        )
        (proj): Linear(in_features=128, out_features=128, bias=True)
        (dropout): Dropout(p=0.2, inplace=False)
      )
      (feed_forward): FeedForward(
        (net): Sequential(
          (0): Linear(in_features=128, out_features=512, bias=True)
          (1): ReLU()
          (2): Linear(in_features=512, out_features=128, bias=True)
          (3): Dropout(p=0.2, inplace=False)
        )
      )
      (layernorm1): LayerNorm((128,), eps=1e-05, elemen

In [101]:
model.fit(train_features,
          train_labels,
          optimizer,
          batch_size,
          epochs,
          verbose_frequency = 100)

Epoch      0/5000 - loss: 3.5886
Epoch    100/5000 - loss: 1.5931
Epoch    200/5000 - loss: 1.5996
Epoch    300/5000 - loss: 1.3883
Epoch    400/5000 - loss: 1.4831
Epoch    500/5000 - loss: 1.3481
Epoch    600/5000 - loss: 1.4135
Epoch    700/5000 - loss: 1.5672
Epoch    800/5000 - loss: 1.4822
Epoch    900/5000 - loss: 1.4873
Epoch   1000/5000 - loss: 1.4314
Epoch   1100/5000 - loss: 1.3500
Epoch   1200/5000 - loss: 1.4007
Epoch   1300/5000 - loss: 1.4380
Epoch   1400/5000 - loss: 1.4484
Epoch   1500/5000 - loss: 1.2852
Epoch   1600/5000 - loss: 1.4804
Epoch   1700/5000 - loss: 1.2521
Epoch   1800/5000 - loss: 1.5195
Epoch   1900/5000 - loss: 1.3919
Epoch   2000/5000 - loss: 1.3911
Epoch   2100/5000 - loss: 1.4152
Epoch   2200/5000 - loss: 1.3729
Epoch   2300/5000 - loss: 1.4429
Epoch   2400/5000 - loss: 1.4225
Epoch   2500/5000 - loss: 1.3627
Epoch   2600/5000 - loss: 1.4225
Epoch   2700/5000 - loss: 1.4119
Epoch   2800/5000 - loss: 1.4000
Epoch   2900/5000 - loss: 1.4459
Epoch   30

In [153]:
model.generate(42)

effran.
wenid.
allian.
yahri.
aramir.
zabaariyah.
graisetu.
omiyah.
sh.
brazell.
kiplrin.
brie.
yasia.
shanvik.
prystine.
ahmia.
nalaina.
penavier.
amar.
allettarye.
zoyai.
bvaunten.
jahlenen.
contari.
benna.
kamie.
daycell.
woxtelin.
liy.
maayonna.
palaine.
aoura.
briarta.
surahe.
derlyn.
kyriah.
nicar.
bersyd.
ane.
olani.
jishah.
sayda.
