In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math

In [7]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_token, num_heads):
        super().__init__()
        assert d_token % num_heads == 0, "d_token must be divisible by num_heads"

        self.d_token = d_token
        self.num_heads = num_heads
        # Dimension of each head
        self.d_k = d_token // num_heads
        self.square_d_k = math.sqrt(self.d_k)

        self.W_q = nn.Linear(d_token, d_token)
        self.W_k = nn.Linear(d_token, d_token)
        self.W_v = nn.Linear(d_token, d_token)

        self.W_o = nn.Linear(d_token, d_token)

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape

        # Linear projections to get Q, K, V with reshape to (Batch, Seq_Len, Num_Heads, Head_Dim)
        Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

        # Scaled Dot-Product Attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.square_d_k

        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        attn_weights = F.softmax(scores, dim=-1)

        # Weighted sum of values (Batch, Heads, Seq_Len, Head_Dim)
        context = torch.matmul(attn_weights, V)

        # Concatenate heads and apply final linear layer + reshape back to (Batch, Seq_Len, d_token)
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_token)
        output = self.W_o(context)

        return output


class TransformerBlock(nn.Module):
    def __init__(self, d_token, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.attention = MultiHeadAttention(d_token, num_heads)

        # Layer Norms
        self.norm1 = nn.LayerNorm(d_token)
        self.norm2 = nn.LayerNorm(d_token)

        # Feed Forward Network
        self.feed_forward = nn.Sequential(
            nn.Linear(d_token, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_token)
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        attn_output = self.attention(x, mask)
        # Residual connection
        x = self.norm1(x + self.dropout(attn_output))

        # Feed Forward + Residual + Norm
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))

        return x

class Transformer(nn.Module):
    def __init__(self, vocab_size, d_token, num_heads, num_blocks, d_ff, max_seq_len):
        super().__init__()
        # Token Embeddings
        self.embedding = nn.Embedding(vocab_size, d_token)

        # Positional Embeddings
        self.position_embedding = nn.Embedding(max_seq_len, d_token)

        # Stack of Transformer Blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(d_token, num_heads, d_ff) for _ in range(num_blocks)
        ])

        # Unembedding Layer
        self.lm_head = nn.Linear(d_token, vocab_size, bias=False)
        self.lm_head.weight = self.embedding.weight


    def forward(self, x):
        batch_size, seq_len = x.shape

        # Causal mask
        mask = torch.tril(torch.ones(seq_len, seq_len)).to(x.device)

        # Create position indices (0, 1, 2, ..., seq_len-1)
        positions = torch.arange(seq_len, device=x.device).unsqueeze(0)
        x = self.embedding(x) + self.position_embedding(positions)

        # Pass through all transformer blocks
        for block in self.blocks:
            x = block(x, mask)

        # Final projection to vocabulary size
        logits = self.lm_head(x)
        return logits

### Data Preprocess

In [8]:
with open('./names.txt', 'r') as f:
	names = f.read().splitlines()
raw_names = names[:10000]

# Create Vocabulary (Character-level)
chars = sorted(list(set("".join(raw_names))))
# Add special tokens:
# <PAD>: Padding (0)
# <SOS>: Start of Sequence (1)
# <EOS>: End of Sequence (2)
stoi = {ch: i+3 for i, ch in enumerate(chars)}
stoi['<PAD>'] = 0
stoi['<SOS>'] = 1
stoi['<EOS>'] = 2
itos = {i: ch for ch, i in stoi.items()}
vocab_size = len(stoi)

print(f"Vocab Size: {vocab_size}")
print(f"Sample vocabulary: {list(stoi.items())[:5]}")

max_len = max(len(name) for name in raw_names) + 2 # +2 for <SOS> and <EOS>

def encode_name(name):
    # Convert name to integers with <SOS> and <EOS>
    return [stoi['<SOS>']] + [stoi[ch] for ch in name] + [stoi['<EOS>']]

# Prepare Training Data (Pad to max_len)
X_train = []
Y_train = []

for name in raw_names:
    encoded = encode_name(name)
    # <SOS> A l i c e -> A l i c e <EOS>
    input_seq = encoded[:-1]
    target_seq = encoded[1:]

    # Pad sequences to constant length
    pad_len = max_len - len(input_seq)
    input_seq += [stoi['<PAD>']] * pad_len
    target_seq += [stoi['<PAD>']] * pad_len

    X_train.append(input_seq)
    Y_train.append(target_seq)

X_train = torch.tensor(X_train, dtype=torch.long)
Y_train = torch.tensor(Y_train, dtype=torch.long)

Vocab Size: 29
Sample vocabulary: [('a', 3), ('b', 4), ('c', 5), ('d', 6), ('e', 7)]


### Training + Testing

In [9]:
def generate_name(start_str):
    model.eval()
    context = [stoi['<SOS>']] + [stoi.get(c, 0) for c in start_str]
    context = torch.tensor(context, dtype=torch.long).unsqueeze(0)

    generated_name = start_str

    for _ in range(max_len):
        if context.shape[1] >= max_len:
            break

        with torch.no_grad():
            logits = model(context)
            # Look only at the last predicted token
            last_logits = logits[:, -1, :]

            # Sample from the distribution
            probs = F.softmax(last_logits, dim=-1)
            # Prevent generating <SOS> or <PAD>
            probs[0][stoi['<SOS>']] = 0
            probs[0][stoi['<PAD>']] = 0

            # Pick next char
            next_ix = torch.multinomial(probs, num_samples=1).item()

            if next_ix == stoi['<EOS>']:
                break

            generated_name += itos[next_ix]

            # Append to context for next step
            context = torch.cat([context, torch.tensor([[next_ix]])], dim=1)

    return generated_name

In [10]:
d_token = 32
num_heads = 4
num_blocks = 2
d_ff = 64
learning_rate = 0.005
epochs = 200

model = Transformer(vocab_size, d_token, num_heads, num_blocks, d_ff, max_len)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# Ensure the loss is not calculated on <PAD> tokens (id 0)
criterion = nn.CrossEntropyLoss(ignore_index=0)

print("Starting training...")
for epoch in range(epochs):
    optimizer.zero_grad()
    logits = model(X_train)

    loss = criterion(logits.view(-1, vocab_size), Y_train.view(-1))

    loss.backward()
    optimizer.step()

    if (epoch+1) % 20 == 0:
        print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")
        print("Input 'A':", generate_name("A"))
        print("Input 'Ma':", generate_name("Ma"))
        print("Input 'Z':", generate_name("Z"))
        print("Input 'Ca':", generate_name("Ca"))

Starting training...
Epoch 20, Loss: 3.7929
Input 'A': Aeaaliyyyaleno
Input 'Ma': Malynegke
Input 'Z': Zle
Input 'Ca': Caaag
Epoch 40, Loss: 2.6267
Input 'A': Amlza
Input 'Ma': Maraieh
Input 'Z': Zaaaihijan
Input 'Ca': Cairlyn
Epoch 60, Loss: 2.5284
Input 'A': Aglidne
Input 'Ma': Mavtmya
Input 'Z': Zatiih
Input 'Ca': Calalhi
Epoch 80, Loss: 2.4755
Input 'A': Aiaai
Input 'Ma': Maeeia
Input 'Z': Zeaa
Input 'Ca': Caa
Epoch 100, Loss: 2.4322
Input 'A': Ayliane
Input 'Ma': Maarr
Input 'Z': Zoaay
Input 'Ca': Caaka
Epoch 120, Loss: 2.3678
Input 'A': Airien
Input 'Ma': Marlah
Input 'Z': Zrmea
Input 'Ca': Caeeeninh
Epoch 140, Loss: 2.2872
Input 'A': Aialela
Input 'Ma': Mananer
Input 'Z': Zelada
Input 'Ca': Caealis
Epoch 160, Loss: 2.2232
Input 'A': Aenicryna
Input 'Ma': Mabmiuh
Input 'Z': Zaseniha
Input 'Ca': Calilin
Epoch 180, Loss: 2.1694
Input 'A': Aaivey
Input 'Ma': Maeynn
Input 'Z': Zionla
Input 'Ca': Canolyn
Epoch 200, Loss: 2.1314
Input 'A': Ayanni
Input 'Ma': Maomaha
Input 'Z': Zaluay
I