In [None]:
import torch

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

In [None]:
import torch


class Datamaker:
    """Character-level tokenizer and next-character dataset builder.

    Every distinct character of the vocabulary string is mapped to an integer
    id. Ids are assigned in sorted character order so that the mapping is the
    same on every run (the iteration order of a ``set`` of strings is not
    stable across interpreter sessions).
    """

    def __init__(self, vocabulary: str) -> None:
        """Build the character <-> id lookup tables.

        Args:
            vocabulary: String containing every character that may appear in
                the text to tokenize. Repeated characters are ignored.
        """
        # character -> integer id
        self.vocabulary_dict = {}
        # integer id -> character
        self.reverse_vocabulary_dict = {}
        for index, letter in enumerate(sorted(set(vocabulary))):
            self.vocabulary_dict[letter] = index
            self.reverse_vocabulary_dict[index] = letter

    def untokenize(self, x: torch.Tensor) -> str:
        """Decode a sequence of token ids back into text.

        Args:
            x: Tensor that yields scalar elements when iterated (e.g. a 1-D
                tensor of ids). Each value is rounded to the nearest integer
                before the lookup.

        Returns:
            The decoded characters joined into one string.

        Raises:
            KeyError: If a (rounded) id is not part of the vocabulary.
        """
        return "".join(
            self.reverse_vocabulary_dict[round(element.item())] for element in x
        )

    def split_data(self, string_for_training: str, text_length: int) -> tuple:
        """Cut a text into overlapping windows for next-character prediction.

        Window ``i`` of the input covers characters ``i .. i + text_length - 1``
        and the matching target window covers ``i + 1 .. i + text_length``, i.e.
        the same window shifted one character to the right.

        Args:
            string_for_training: Text to cut into windows.
            text_length: Number of characters per window.

        Returns:
            ``(in_sequence_tensor, out_sequence_tensor)``, both of shape
            ``[num_windows, text_length]`` where
            ``num_windows = max(len(string_for_training) - text_length, 0)``.
            The inputs have dtype ``torch.int`` and the targets ``torch.long``.

        Raises:
            ValueError: If ``text_length`` is negative.
            KeyError: If a character of the text is not in the vocabulary.
        """
        if text_length < 0:
            raise ValueError(f"text_length must be non-negative, got {text_length}")

        # Each window needs one extra character after it to serve as the last target.
        num_sequences = max(len(string_for_training) - text_length, 0)

        # Text too short for a single window: nothing is tokenized.
        token_ids = []
        if num_sequences > 0:
            token_ids = [self.vocabulary_dict[letter] for letter in string_for_training]

        in_sequences = [
            token_ids[start:start + text_length] for start in range(num_sequences)
        ]
        out_sequences = [
            token_ids[start + 1:start + text_length + 1] for start in range(num_sequences)
        ]

        # reshape keeps the [num_windows, text_length] shape even when there are no windows.
        in_sequence_tensor = torch.tensor(in_sequences, dtype=torch.int).reshape(
            num_sequences, text_length
        )
        # Targets come from the shifted windows, not from the input windows.
        out_sequence_tensor = torch.tensor(out_sequences, dtype=torch.long).reshape(
            num_sequences, text_length
        )

        return in_sequence_tensor, out_sequence_tensor

In [None]:
import torch
from torch import nn


class Encoder(nn.Module):
    """Token embedding plus learned positional encodings, followed by LayerNorm."""

    def __init__(self, encode_size: int, vocab_size: int, context_length: int) -> None:
        """Create the embedding table, positional table and normalisation layer.

        Args:
            encode_size: Width of each token's embedding vector.
            vocab_size: Number of distinct token ids.
            context_length: Number of rows in the positional-encoding table.
        """
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, encode_size)
        # One learned vector per position, initialised from a standard normal.
        self.positional_encodings = nn.Parameter(torch.randn(context_length, encode_size))
        self.layer_norm = nn.LayerNorm(encode_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed token ids and add position information.

        Args:
            x: Token ids of size [num_batches, sequence_length].

        Returns:
            Normalised embeddings of size
            [num_batches, sequence_length, encode_size].
        """
        base_embedding = self.embedding(x)
        # Use only the rows for the positions actually present, so sequences
        # shorter than context_length work as well as full-length ones.
        sequence_length = x.size(-1)
        pos_enc = self.positional_encodings[:sequence_length]
        full_embedding = base_embedding + pos_enc
        return self.layer_norm(full_embedding)


In [None]:
import torch
from torch import nn


class AttentionBlock(nn.Module):
    """Causal multi-head self-attention followed by a position-wise MLP.

    The input is projected to queries, keys and values, each split into
    ``num_heads`` chunks along the feature dimension. Every head computes
    dot-product attention in which a token may only attend to itself and to
    earlier tokens. The concatenated head outputs and the MLP output are each
    added back to their input (residual connection) and layer-normalised.

    NOTE(review): the attention scores are plain dot products (no
    1/sqrt(head_dim) scaling) and one LayerNorm instance is reused after both
    residual additions. Both are kept as they are so existing parameters and
    outputs stay valid.
    """

    def __init__(self, num_heads: int, encode_size: int) -> None:
        """Create the projections, the shared LayerNorm and the MLP.

        Args:
            num_heads: Number of chunks the q/k/v projections are split into.
            encode_size: Feature width of the input and output.
        """
        super().__init__()

        # Applied after the attention residual and again after the MLP residual.
        self.layer_norm = nn.LayerNorm(encode_size)
        self.num_heads = num_heads
        self.encode_size = encode_size
        self.hidden_encode_size = encode_size // num_heads

        self.q = nn.Linear(in_features=self.encode_size, out_features=self.encode_size)
        self.k = nn.Linear(in_features=self.encode_size, out_features=self.encode_size)
        self.v = nn.Linear(in_features=self.encode_size, out_features=self.encode_size)

        self.lrelu = nn.LeakyReLU()

        # Expand to 4x the width, apply the non-linearity, project back.
        self.mlp = nn.Sequential(
            nn.Linear(in_features=self.encode_size, out_features=self.encode_size * 4),
            self.lrelu,
            nn.Linear(in_features=self.encode_size * 4, out_features=self.encode_size),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply causal self-attention and the MLP to a batch of sequences.

        Args:
            x: Tensor of size [batch_size, num_tokens, encode_size].

        Returns:
            Tensor of the same size as ``x``.
        """
        residual = x
        num_tokens = x.size(dim=1)

        # q, k, v are of size [batch_size, num_tokens, encode_size].
        q = self.q(x)
        k = self.k(x)
        v = self.v(x)

        # Tuples with one [batch_size, num_tokens, chunk_width] tensor per head.
        qs = torch.chunk(q, self.num_heads, dim=-1)
        ks = torch.chunk(k, self.num_heads, dim=-1)
        vs = torch.chunk(v, self.num_heads, dim=-1)

        # True where key index > query index, i.e. positions a query must not see.
        # Built on the input's device so the block works wherever the model lives.
        future_mask = torch.triu(
            torch.ones(num_tokens, num_tokens, dtype=torch.bool, device=x.device),
            diagonal=1,
        )

        head_outputs = []
        for head in range(self.num_heads):
            # All query/key dot products of this head in a single batched matmul.
            scores = torch.bmm(qs[head], ks[head].transpose(1, 2))
            # -inf scores become zero weights after the softmax.
            scores = scores.masked_fill(future_mask, float("-inf"))
            weights = torch.softmax(scores, dim=-1)
            head_outputs.append(torch.bmm(weights, vs[head]))

        values = torch.cat(head_outputs, dim=-1)

        x = residual + values
        x = self.layer_norm(x)
        residual = x
        mlp = self.mlp(x)
        x = residual + mlp
        x = self.layer_norm(x)
        return x

In [None]:
import torch
from torch import nn


class Decoder(nn.Module):
    """Two-layer MLP that maps each encoded token to one logit per vocabulary entry."""

    def __init__(self, encode_size: int, vocab_amount: int) -> None:
        """Build the expand -> LeakyReLU -> project stack.

        Args:
            encode_size: Feature width of the incoming token encodings.
            vocab_amount: Number of output logits per token.
        """
        super().__init__()

        self.lrelu = nn.LeakyReLU()

        # The hidden layer is four times as wide as the input.
        hidden_size = encode_size * 4
        expand = nn.Linear(encode_size, hidden_size)
        project = nn.Linear(hidden_size, vocab_amount)
        self.decoder = nn.Sequential(expand, self.lrelu, project)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the logits produced by the MLP for ``x``."""
        return self.decoder(x)

In [None]:
import torch
from torch import nn


class Model(nn.Module):
    """Language model made of an Encoder, one AttentionBlock and a Decoder."""

    def __init__(self, num_heads: int, encode_size: int, vocab_size: int, context_size: int) -> None:
        """Assemble the three stages.

        Args:
            num_heads: Number of attention heads in the attention block.
            encode_size: Feature width shared by all stages.
            vocab_size: Number of token ids accepted and number of logits produced.
            context_size: Passed to the Encoder as its ``context_length``.
        """
        super().__init__()

        attention_config = {"num_heads": num_heads, "encode_size": encode_size}

        self.encoder = Encoder(
            encode_size=encode_size,
            vocab_size=vocab_size,
            context_length=context_size,
        )
        # Only a single attention block is active in this model.
        self.att_block_1 = AttentionBlock(**attention_config)
        self.decoder = Decoder(encode_size=encode_size, vocab_amount=vocab_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run token ids through encoder, attention block and decoder in turn."""
        encoded = self.encoder(x)
        attended = self.att_block_1(encoded)
        return self.decoder(attended)

In [None]:
# Every character of the training text must appear in the vocabulary.
VOCABULARY = "abcdefghijklmnopqrstuvwxyz ."
TRAINING_TEXT = "the quick brown fox jumps over the lazy dog."
# The Encoder adds a (context_length, encode_size) positional table to the
# embeddings, so the training window length is kept equal to the model's
# context size by using one constant for both.
CONTEXT_SIZE = 8
SEED = 0

# Fix PyTorch's RNG so parameter initialisation is repeatable on a fresh run.
torch.manual_seed(SEED)

datamaker = Datamaker(VOCABULARY)
in_data, out_data = datamaker.split_data(TRAINING_TEXT, CONTEXT_SIZE)
print(f"in: {in_data.size()}, out: {out_data.size()}")

model = Model(num_heads=8, encode_size=24, vocab_size=len(datamaker.vocabulary_dict),
              context_size=CONTEXT_SIZE)

In [None]:
EPOCHS = 100
LEARNING_RATE = 0.01

# Merge the (sequence, position) axes so there is one target id per logit row.
flattened_out_data = out_data.flatten(start_dim=0, end_dim=1)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Full-batch training: every epoch is one forward/backward pass over all sequences.
for epoch in range(EPOCHS):
    print(f"E {epoch + 1}/{EPOCHS} - {((epoch + 1) / EPOCHS) * 100:.2f}%")
    model.train()
    optimizer.zero_grad()

    raw_logits = model(in_data)
    # Same axis merge as for the targets: one row of vocabulary logits per position.
    flattened_logits = raw_logits.flatten(start_dim=0, end_dim=1)

    # Decode the most likely character for the first 8 rows as a quick progress check.
    sample_data = flattened_logits[:8]
    sample_argmax = sample_data.argmax(dim=-1)
    human_readable_sample_data = datamaker.untokenize(sample_argmax)
    print(f"Sample: {human_readable_sample_data}")

    loss = loss_fn(flattened_logits, flattened_out_data)
    loss.backward()
    optimizer.step()