In [None]:
import torch
import random

SEED = 42
# Seed every random number generator this notebook touches so a
# Restart & Run All reproduces the same weights and outputs.
for seed_fn in (torch.manual_seed, torch.cuda.manual_seed_all, random.seed):
    seed_fn(SEED)

In [None]:
import torch


class Tokenizer:
    """Character-level tokenizer extended with whole-string special tokens.

    Every distinct character of ``vocabulary`` gets an id, followed by each entry of
    ``special_tokens``. Ids are assigned in first-occurrence order, so the mapping is
    identical in every Python process (``list(set(...))`` ordering depends on string
    hash randomization and is not reproducible across runs).
    """

    def __init__(self, vocabulary: str, special_tokens: list):
        self.special_tokens = special_tokens
        # dict.fromkeys de-duplicates while preserving order. If a special token equals
        # a vocabulary character it keeps its first id instead of leaving an unused gap
        # (which would make len(vocabulary_dictionary) smaller than the largest id + 1).
        ordered_tokens = list(dict.fromkeys([*vocabulary, *special_tokens]))
        self.vocabulary_dictionary = {token: index for index, token in enumerate(ordered_tokens)}
        self.reverse_vocabulary_dictionary = dict(enumerate(ordered_tokens))

    def tokenize(self, text: str, pad_before: list = None, pad_after: list = None) -> torch.Tensor:
        """Encode ``pad_before + list(text) + pad_after`` as a 1-D int64 tensor of ids.

        ``text`` is split into single characters; the pad lists may contain
        multi-character special tokens such as ``"<model>"``.

        Raises:
            KeyError: if a character or pad token is not in the vocabulary.
        """
        pieces = [*(pad_before or []), *text, *(pad_after or [])]
        # Build directly as int64 rather than float32 followed by a cast.
        return torch.tensor([self.vocabulary_dictionary[piece] for piece in pieces], dtype=torch.long)

    def untokenize(self, tensor: torch.Tensor) -> list:
        """Decode a 1-D sequence of ids (tensor or ints) back into a list of tokens."""
        return [self.reverse_vocabulary_dictionary[int(token)] for token in tensor]

In [None]:
import torch
from torch import nn


class Model(nn.Module):
    """Character model that folds tokens one at a time into a recurrent latent vector.

    For every input token, ``attention_depth`` attention + MLP layers run over the
    two-slot sequence ``[latent, token]``. The attended token is then added back into
    the latent (residual + LayerNorm), and ``decoder`` maps the latent to vocabulary logits.
    Inputs are unbatched: ``latent`` is a 1-D vector of size ``latent_size`` and
    ``tokens`` is a 1-D tensor of token ids.

    Args:
        vocabulary_length: number of token ids covered by the embedding and decoder.
        latent_size: width of the latent vector, the embeddings and every hidden layer.
        attention_depth: number of independent attention + MLP layers per token step.
        num_heads: attention heads per layer; must divide ``latent_size``.
    """

    def __init__(self,
                 vocabulary_length: int,
                 latent_size: int,
                 attention_depth: int,
                 num_heads: int):
        super().__init__()

        self.encode_dimension = latent_size

        self.embedding = nn.Embedding(vocabulary_length, latent_size)

        # One learned vector per slot of the [latent, token] pair. Self-attention is
        # permutation-equivariant, so without these the layers cannot tell which slot
        # holds the latent and which holds the incoming token.
        self.positional_encoding = nn.Embedding(2, latent_size)

        # Every layer gets freshly constructed modules. Repeating a single module
        # instance inside a ModuleList would make all layers share one set of weights.
        self.attention_blocks = nn.ModuleList([
            nn.MultiheadAttention(embed_dim=latent_size, num_heads=num_heads, batch_first=True)
            for _ in range(attention_depth)
        ])
        self.projection_layers = nn.ModuleList([
            nn.Linear(in_features=latent_size, out_features=latent_size)
            for _ in range(attention_depth)
        ])
        self.mlp_layers = nn.ModuleList([
            nn.Sequential(
                nn.Linear(in_features=latent_size, out_features=latent_size * 4),
                nn.LeakyReLU(),
                nn.Linear(in_features=latent_size * 4, out_features=latent_size),
            )
            for _ in range(attention_depth)
        ])

        # A single LayerNorm shared by every residual connection and the latent update.
        self.layer_norm = nn.LayerNorm(latent_size)

        self.decoder = nn.Sequential(
            nn.Linear(in_features=latent_size, out_features=vocabulary_length),
        )

    def encode_token(self, latent: torch.Tensor, token: torch.Tensor) -> tuple:
        """Run the attention stack over ``[latent, token]``.

        Returns:
            ``(attended_latent, attended_token)``, each shaped like ``latent``.
        """
        # (2, latent_size): unbatched sequence of length 2, tagged with slot positions.
        sequence = torch.stack([latent, token], dim=0) + self.positional_encoding.weight
        for attention, projection, mlp in zip(self.attention_blocks, self.projection_layers, self.mlp_layers):
            attended, _ = attention(sequence, sequence, sequence, need_weights=False)
            # Project the *attention output* (not the raw input) into the residual stream.
            sequence = self.layer_norm(sequence + projection(attended))
            sequence = self.layer_norm(sequence + mlp(sequence))
        return sequence[0], sequence[1]

    def _absorb(self, latent: torch.Tensor, embedded_token: torch.Tensor) -> torch.Tensor:
        """Fold one embedded token into ``latent`` and return the updated latent."""
        _, attended_token = self.encode_token(latent, embedded_token)
        return self.layer_norm(latent + attended_token)

    def _fold_tokens(self, latent: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
        """Embed ``tokens`` and fold them into ``latent`` in order; return the final latent."""
        for embedded_token in self.embedding(tokens):
            latent = self._absorb(latent, embedded_token)
        return latent

    def inference(
            self,
            latent: torch.Tensor,
            tokens: torch.Tensor,
            max_len: int,
            end_at: list,
            reverse_vocabulary: dict = None,
            verbose: bool = True) -> list:
        """Greedily generate up to ``max_len`` tokens after priming with ``tokens``.

        Each predicted token is fed back into the latent before the next prediction.
        Generation stops early once a token in ``end_at`` is produced (that token is
        included in the output).

        Args:
            latent: initial latent vector.
            tokens: 1-D tensor of prompt token ids.
            max_len: maximum number of tokens to generate.
            end_at: tokens that terminate generation.
            reverse_vocabulary: id -> token mapping. Defaults to the notebook-level
                ``tokenizer.reverse_vocabulary_dictionary`` for backward compatibility.
            verbose: print each generated token as it is produced.

        Returns:
            The list of generated tokens.
        """
        if reverse_vocabulary is None:
            reverse_vocabulary = tokenizer.reverse_vocabulary_dictionary
        output = []
        with torch.no_grad():
            latent = self._fold_tokens(latent, tokens)
            while len(output) < max_len:
                # argmax of the logits equals argmax of their softmax; no need to normalise.
                prediction = self.decoder(latent).argmax(dim=-1)
                letter = reverse_vocabulary[prediction.item()]
                if verbose:
                    print(f"{letter}", end="")
                output.append(letter)
                if letter in end_at:
                    break
                # Feed the prediction back so the next step is conditioned on it.
                latent = self._absorb(latent, self.embedding(prediction))
        return output

    def forward(self, latent: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
        """Fold ``tokens`` into ``latent`` and return next-token logits (size ``vocabulary_length``)."""
        return self.decoder(self._fold_tokens(latent, tokens))

In [None]:
# Model hyper-parameters.
LATENT_SIZE = 64
ATTENTION_DEPTH = 4
NUM_HEADS = 8

# Plain characters, plus role/boundary markers and three single characters
# (quote, backslash, newline) that are not part of `vocab`.
vocab = "abcdefghijklmnopqrstuvwxyz ABCDEFGHIJKLMNOPQRSTUVWXYZ,.<>/?;:'[]{}1234567890!@#$%^&*()`~-_=+"
special_tokens = ["<model>", "</model>", "<system>", "</system>", "<user>", "</user>", '"', "\\", "\n"]

tokenizer = Tokenizer(vocabulary=vocab, special_tokens=special_tokens)

vocabulary_length = len(tokenizer.vocabulary_dictionary)
model = Model(
    vocabulary_length=vocabulary_length,
    latent_size=LATENT_SIZE,
    attention_depth=ATTENTION_DEPTH,
    num_heads=NUM_HEADS,
)

In [None]:
strings = [
    "The quick brown fox jumps over the lazy dog.",
    "She sells seashells by the seashore.",
    "I love eating pizza on a rainy day.",
    "Artificial intelligence is transforming the world.",
    "The cat sat on the mat and purred softly.",
    "Learning to code can be both fun and challenging.",
    "He runs faster than anyone in his class.",
    "A beautiful sunrise can brighten any mood.",
    "The library is a quiet place for reading and studying.",
    "Music has the power to connect people from all walks of life."
]

# Each sentence becomes "<model>" + characters + "</model>" as a tensor of token ids.
training_texts = [
    tokenizer.tokenize(sentence, pad_before=["<model>"], pad_after=["</model>"])
    for sentence in strings
]

In [None]:
EPOCHS = 10
BATCH_SIZE = 64  # NOTE: currently unused — every text is its own optimisation step.
LEARNING_RATE = 0.001

# NOTE: re-running this cell keeps training the existing `model` (fresh optimizer,
# reset loss history). Re-run the model-construction cell to start from scratch.
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
loss_fn = nn.CrossEntropyLoss()

epoch_losses = []
text_losses = []

for epoch in range(EPOCHS):
    model.train()
    epoch_loss = 0.0
    for text in training_texts:
        # Teacher forcing: fold the ground-truth tokens into the latent one at a time
        # (the same update rule Model.forward applies) and predict the *next* token
        # after every step. Every position is a training target, not only the final
        # "</model>", which is what the previous single-target loss learned.
        latent_space = torch.zeros(LATENT_SIZE)
        step_logits = []
        for embedded_token in model.embedding(text[:-1]):
            _, attended_token = model.encode_token(latent_space, embedded_token)
            latent_space = model.layer_norm(latent_space + attended_token)
            step_logits.append(model.decoder(latent_space))
        # (len(text) - 1, vocab) logits vs. the sequence shifted left by one.
        loss = loss_fn(torch.stack(step_logits), text[1:])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        text_losses.append(loss.item())
        epoch_loss += loss.item()
    epoch_loss /= len(training_texts)
    epoch_losses.append(epoch_loss)
    print(f"E {epoch + 1:,} - {((epoch + 1) / EPOCHS) * 100}% | Loss: {epoch_loss:,}")

In [None]:
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# One entry per panel: (series, legend label, x-axis label, title, extra plot kwargs).
loss_panels = [
    (epoch_losses, "Epoch Loss", "Epochs", "Epoch Losses", {}),
    (text_losses, "Text Loss", "Text Number", "Text Losses", {"color": 'orange'}),
]
for panel_ax, (series, label, x_label, title, plot_kwargs) in zip(axes, loss_panels):
    panel_ax.plot(series, label=label, **plot_kwargs)
    panel_ax.set_xlabel(x_label)
    panel_ax.set_ylabel("Loss")
    panel_ax.set_title(title)
    panel_ax.legend()
    panel_ax.grid(True)

plt.tight_layout()
plt.show()

In [None]:
# The model has no dropout/batch-norm today, but switch to eval mode so generation
# stays correct if such layers are added later.
model.eval()

# Every training sequence starts with "<model>", so prime generation the same way.
test_text = tokenizer.tokenize("The", pad_before=["<model>"])
# Use the configured width instead of a hard-coded 64 so this follows LATENT_SIZE.
test_latent = torch.zeros(LATENT_SIZE)
test_inference = model.inference(
    latent=test_latent,
    tokens=test_text,
    max_len=100,
    end_at=["</model>"]
)