In [1]:
import numpy as np
import torch
import torch.nn as nn

## Vocabulary


In [2]:
class Vocabulary:
    def __init__(self, text: str):
        self.char_to_idx = {}

        self.idx_to_char = {}
        self.vocab_size = 0
        self.build_vocab(text)

    def build_vocab(self, text):
        # Create sorted vocabulary from unique characters
        unique_chars = sorted(list(set(text)))
        self.char_to_idx = {char: idx for idx, char in enumerate(unique_chars)}

        self.idx_to_char = {idx: char for char, idx in self.char_to_idx.items()}
        self.vocab_size = len(unique_chars)

    def encode(self, text):
        """Convert string to list of indices"""
        return [self.char_to_idx[char] for char in text]

    def decode(self, indices):
        """Convert list of indices to string"""
        return "".join([self.idx_to_char[idx] for idx in indices])

    def encode_tensor(self, text):
        """Convert string to PyTorch tensor"""
        return torch.tensor([self.encode(text)])

    def decode_tensor(self, tensor):
        """Convert PyTorch tensor to string"""
        return self.decode(tensor.flatten().tolist())

In [7]:
text = """
I grew up on the crime side, the New York Times side
Stayin' alive was no jive
At second hands, moms bounced on old men
So then we moved to Shaolin land
""".strip()

vocab = Vocabulary(text)
vocab.vocab_size

32

In [9]:
new_text = "Stayin' alive was no jive"
print(vocab.encode(new_text))

[7, 27, 10, 31, 17, 22, 2, 1, 10, 20, 17, 29, 14, 1, 30, 10, 26, 1, 22, 23, 1, 18, 17, 29, 14]


## Transformer


In [10]:
EMBEDDING_SIZE = 32
ATTENTION_HEADS = 4
FEED_FORWARD_SIZE = 128
DROPOOUT = 0.1
CONTEXT_WINDOW = 128


class MultiHeadAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.d_k = EMBEDDING_SIZE // ATTENTION_HEADS

        self.q_linear = nn.Linear(EMBEDDING_SIZE, EMBEDDING_SIZE)
        self.k_linear = nn.Linear(EMBEDDING_SIZE, EMBEDDING_SIZE)
        self.v_linear = nn.Linear(EMBEDDING_SIZE, EMBEDDING_SIZE)
        self.out = nn.Linear(EMBEDDING_SIZE, EMBEDDING_SIZE)

    def forward(self, q, k, v, mask=None):
        batch_size = q.size(0)

        q = (
            self.q_linear(q)
            .view(batch_size, -1, ATTENTION_HEADS, self.d_k)
            .transpose(1, 2)
        )
        k = (
            self.k_linear(k)
            .view(batch_size, -1, ATTENTION_HEADS, self.d_k)
            .transpose(1, 2)
        )
        v = (
            self.v_linear(v)
            .view(batch_size, -1, ATTENTION_HEADS, self.d_k)
            .transpose(1, 2)
        )

        scores = torch.matmul(q, k.transpose(-2, -1)) / np.sqrt(self.d_k)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        attn = torch.softmax(scores, dim=-1)
        out = torch.matmul(attn, v)

        out = out.transpose(1, 2).contiguous().view(batch_size, -1, EMBEDDING_SIZE)
        return self.out(out)


class FeedForward(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(EMBEDDING_SIZE, FEED_FORWARD_SIZE),
            nn.ReLU(),
            nn.Dropout(DROPOOUT),
            nn.Linear(FEED_FORWARD_SIZE, EMBEDDING_SIZE),
        )

    def forward(self, x):
        return self.net(x)


class TransformerBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.attention = MultiHeadAttention()
        self.feed_forward = FeedForward()
        self.norm1 = nn.LayerNorm(EMBEDDING_SIZE)
        self.norm2 = nn.LayerNorm(EMBEDDING_SIZE)
        self.dropout = nn.Dropout(DROPOOUT)

    def forward(self, x, mask=None):
        attended = self.attention(x, x, x, mask)
        x = self.norm1(x + self.dropout(attended))

        fed_forward = self.feed_forward(x)
        x = self.norm2(x + self.dropout(fed_forward))
        return x


class Transformer(nn.Module):
    def __init__(self, vocab_size: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, EMBEDDING_SIZE)
        self.pos_embedding = nn.Parameter(
            torch.randn(1, CONTEXT_WINDOW, EMBEDDING_SIZE)
        )
        self.transformer = TransformerBlock()
        self.fc = nn.Linear(EMBEDDING_SIZE, vocab_size)

    def forward(self, x, mask=None):
        x = self.embedding(x) + self.pos_embedding[:, : x.size(1)]
        x = self.transformer(x, mask)
        return self.fc(x)

## Training


In [14]:
def train(sentence):
    vocab = Vocabulary(sentence)

    # Prepare input and target sequences
    x = vocab.encode_tensor(sentence[:-1])  # Input sequence
    y = vocab.encode_tensor(sentence[1:])  # Target sequence

    # Create model and optimizer
    model = Transformer(vocab.vocab_size)
    optimizer = torch.optim.Adam(model.parameters())
    criterion = nn.CrossEntropyLoss()

    # Training loop
    for epoch in range(1000):
        optimizer.zero_grad()
        output = model(x)
        loss = criterion(output.view(-1, vocab.vocab_size), y.view(-1))
        loss.backward()
        optimizer.step()

        if (epoch + 1) % 100 == 0:
            print(f"Epoch {epoch + 1}, Loss: {loss.item():.4f}")

    return model, vocab


def generate(model, prefix, vocab, max_new_chars=32):
    model.eval()
    current_sequence = vocab.encode(prefix)
    result = prefix

    for _ in range(max_new_chars):
        # Predict next character
        x = torch.tensor([current_sequence])
        with torch.no_grad():
            output = model(x)
            next_char_idx = torch.argmax(output[0, -1]).item()

        # Add predicted character to sequence
        current_sequence.append(next_char_idx)
        result += vocab.idx_to_char[next_char_idx]

        # Stop if we predict a period
        if vocab.idx_to_char[next_char_idx] == ".":
            break

    return result

In [16]:
sentence = "The quick brown fox jumps over the lazy dog."
sentence[:-1]

'The quick brown fox jumps over the lazy dog'

In [17]:
sentence[1:]

'he quick brown fox jumps over the lazy dog.'

In [15]:
model, vocab = train(sentence)

Epoch 100, Loss: 0.4024
Epoch 200, Loss: 0.0857
Epoch 300, Loss: 0.0362
Epoch 400, Loss: 0.0203
Epoch 500, Loss: 0.0143
Epoch 600, Loss: 0.0096
Epoch 700, Loss: 0.0072
Epoch 800, Loss: 0.0057
Epoch 900, Loss: 0.0046
Epoch 1000, Loss: 0.0038


In [19]:
text = "The quick brown"
generated = generate(model, text, vocab, max_new_chars=4)
print(generated)

The quick brown fox
