In [None]:
import torch

# Prefer the GPU, but fall back to CPU so the notebook still survives
# "Restart Kernel -> Run All" on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed the CPU generator and every CUDA device so runs are reproducible.
SEED = 0
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

In [None]:
import torch


class Tokenizer:
    """Character-level tokenizer with extra whole-string special tokens.

    The vocabulary is the sorted set of distinct characters in the sample string,
    followed by the special tokens in the order given. ``tokenize`` splits its input
    into single characters, so the only special tokens it ever emits are ``<pad>``
    (padding) and ``<unknown>`` (characters missing from the vocabulary).
    """

    def __init__(self, string: str, special_tokens: list):
        """
        :param string: sample text; every distinct character becomes one token.
        :param special_tokens: extra tokens appended after the characters. Entries
            already present in the vocabulary are skipped so that ids stay unique
            and contiguous.
        """
        vocabulary = sorted(set(string))
        for token in special_tokens:
            # A duplicate would give one token two ids and leave a hole in the id
            # range, making get_vocab_size() smaller than the largest id + 1.
            if token not in vocabulary:
                vocabulary.append(token)
        self.letter_to_index = {letter: index for index, letter in enumerate(vocabulary)}
        self.index_to_letter = {index: letter for index, letter in enumerate(vocabulary)}

    def tokenize(self, string, pad_to_length) -> torch.Tensor:
        """Convert ``string`` into a 1-D ``torch.int`` tensor of token ids.

        Characters outside the vocabulary map to ``<unknown>`` when that token
        exists; otherwise a ``KeyError`` is raised. The result is right-padded with
        ``<pad>`` up to ``pad_to_length``. Longer inputs are NOT truncated.

        :param string: text to tokenize, one token per character.
        :param pad_to_length: minimum length of the returned tensor.
        :return: tensor of token ids moved to the notebook-level ``device``.
        """
        unknown_index = self.letter_to_index.get("<unknown>")
        token_ids = []
        for letter in string:
            index = self.letter_to_index.get(letter, unknown_index)
            if index is None:
                raise KeyError(letter)
            token_ids.append(index)
        extra_pad = pad_to_length - len(token_ids)
        if extra_pad > 0:
            token_ids.extend([self.letter_to_index["<pad>"]] * extra_pad)
        tokens_tensor = torch.tensor(token_ids, dtype=torch.int)
        return tokens_tensor.to(device=device)

    def detokenize(self, tensor):
        """Map the token ids of the FIRST sequence in a batch back to tokens.

        :param tensor: a tensor of size [batch_size, sequence_len] holding token ids
            (integer or integral float values).
        :return: list of token strings for ``tensor[0]``; other batch rows are ignored.
        """
        # int() makes float-typed id tensors (e.g. a cast argmax) decode explicitly.
        return [self.index_to_letter[int(token_id)] for token_id in tensor[0].tolist()]

    def get_vocab_size(self):
        """Number of distinct token ids (characters plus special tokens)."""
        return len(self.letter_to_index)


# Character inventory for the tokenizer: lower/upper-case letters, digits, the
# printable ASCII punctuation and a space (no tabs or newlines). Only the set of
# characters matters, because Tokenizer sorts and de-duplicates it. The adjacent
# literals below concatenate into a single string.
tokenizer_text_to_sample = (
    "qwertyuiop[]\\asdfghjkl;\'zxcvbnm,./"
    "QWERTYUIOP{}|ASDFGHJKL:\"ZXCVBNM<>?"
    "1234567890-=`!@#$%^&*()_+~ "
)

# Whole-string tokens appended after the characters, in this exact order
# (their position determines their id).
special_tokens = [
    "<model>", "</model>",    # model turn
    "<user>", "</user>",      # user turn
    "<system>", "</system>",  # system turn
    "<pad>",                  # padding used by Tokenizer.tokenize
    "<unknown>",
]

tokenizer = Tokenizer(tokenizer_text_to_sample, special_tokens)

In [None]:
import torch
from torch import nn


class TransformerBlock(nn.Module):
    """One attention layer followed by a position-wise feed-forward network.

    The input is projected to queries/keys/values and fed to an
    ``nn.MultiheadAttention`` built with ``batch_first=True``, so ``token_sequence``
    is expected to be ``[batch, seq_len, embed_dim]``. All projections are
    ``LazyLinear``: their input width is inferred on the first forward call.

    :param embed_dim: width of the token representations (input and output).
    :param num_heads: number of attention heads; must divide ``embed_dim``.
    :param causal: when True (default), position ``i`` may only attend to
        positions ``<= i``. This is required for next-token prediction; without it
        the model can read the token it is asked to predict.
    """

    def __init__(self, embed_dim, num_heads, causal=True):
        super().__init__()
        self.causal = causal
        self.att_block = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, batch_first=True)
        self.feedforward = nn.Sequential(
            nn.LazyLinear(out_features=embed_dim * 4),
            nn.LeakyReLU(),
            nn.LazyLinear(out_features=embed_dim)
        )
        self.queries = nn.LazyLinear(out_features=embed_dim)
        self.keys = nn.LazyLinear(out_features=embed_dim)
        self.values = nn.LazyLinear(out_features=embed_dim)

    def forward(self, token_sequence):
        """Attend over ``token_sequence`` and apply the feed-forward network.

        :param token_sequence: tensor of size [batch_size, sequence_len, features].
        :return: tensor of size [batch_size, sequence_len, embed_dim].
        """
        # Each projection uses its own weights (previously all three reused `queries`).
        q = self.queries(token_sequence)
        k = self.keys(token_sequence)
        v = self.values(token_sequence)

        attn_mask = None
        if self.causal:
            seq_len = token_sequence.size(1)
            # Boolean mask: True marks a (query, key) pair that may NOT attend,
            # i.e. every key strictly after the query position.
            attn_mask = torch.triu(
                torch.ones(seq_len, seq_len, dtype=torch.bool, device=token_sequence.device),
                diagonal=1,
            )

        # The attention weights were never used, so skip computing them.
        attended_sequence, _ = self.att_block(q, k, v, attn_mask=attn_mask, need_weights=False)

        return self.feedforward(attended_sequence)


In [None]:
import torch
from torch import nn


class Decoder(nn.Module):
    """Output head turning per-position features into vocabulary logits.

    A two-layer MLP (hidden width ``4 * vocab_size`` with a LeakyReLU in between)
    applied over the last dimension of its input. Both layers are ``LazyLinear``,
    so the input width is inferred on the first call.

    :param vocab_size: size of the produced logit dimension.
    """

    def __init__(self, vocab_size):
        super().__init__()
        hidden_width = vocab_size * 4
        layers = [
            nn.LazyLinear(out_features=hidden_width),
            nn.LeakyReLU(),
            nn.LazyLinear(out_features=vocab_size),
        ]
        self.decoder = nn.Sequential(*layers)

    def forward(self, attended_sequence):
        """Return logits with the last dimension replaced by ``vocab_size``."""
        return self.decoder(attended_sequence)

In [None]:
import torch
from torch import nn


class Model(nn.Module):
    """Token + positional embeddings -> stacked TransformerBlocks -> Decoder head.

    :param vocab_size: number of token ids covered by the embedding and the decoder.
    :param embed_dim: width of the token and positional embeddings.
    :param num_heads: attention heads in each transformer block.
    :param num_transformer_blocks: how many TransformerBlocks are applied in sequence.
    :param max_sequence_length: number of learned positions; longer inputs are rejected.
    """

    def __init__(self, vocab_size, embed_dim, num_heads, num_transformer_blocks, max_sequence_length):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.positional_embedding = nn.Embedding(max_sequence_length, embed_dim)

        # nn.ModuleList (not a plain Python list) so every block's parameters are
        # registered and therefore seen by .parameters(), .to() and the optimizer.
        self.transformer_blocks = nn.ModuleList(
            TransformerBlock(embed_dim=embed_dim, num_heads=num_heads)
            for _ in range(num_transformer_blocks)
        )
        self.decoder = Decoder(vocab_size)

        # Place every submodule on the notebook-level `device`, as before.
        self.to(device=device)

    def forward(self, tokenized_sequence):
        """

        :param tokenized_sequence: tensor of size [batch_size, sequence_length] with int dtype for embedding the tokens.
        :return: a tensor of size [batch_size, sequence_len, vocabulary_size].
        :raises ValueError: if sequence_length exceeds max_sequence_length.
        """
        sequence_length = tokenized_sequence.size(1)
        max_positions = self.positional_embedding.num_embeddings
        if sequence_length > max_positions:
            raise ValueError(
                f"sequence length {sequence_length} exceeds max_sequence_length {max_positions}"
            )

        embedded_sequence = self.embedding(tokenized_sequence)  # [batch_size, sequence_length, embed_dim]
        # Build positions on the input's device so they always match the embeddings.
        positions = torch.arange(sequence_length, device=tokenized_sequence.device).unsqueeze(0)
        token_sequence = embedded_sequence + self.positional_embedding(positions)

        for block in self.transformer_blocks:
            token_sequence = block(token_sequence)

        return self.decoder(token_sequence)

In [None]:
# Smoke test: a single forward pass on one padded sentence. Besides printing the
# output shape, this first call materialises the LazyLinear layers used in
# forward(), before the optimizer is later built from model.parameters().
sample_sentence = "Hello World!"
tokenized_sentence = tokenizer.tokenize(sample_sentence, pad_to_length=16)

model = Model(
    vocab_size=tokenizer.get_vocab_size(),
    embed_dim=1024,
    num_heads=4,
    num_transformer_blocks=4,
    max_sequence_length=64,
).to(device=device)

batched_sentence = tokenized_sentence[None, :]  # leading batch dimension of size 1
model_out = model(batched_sentence)
print(model_out.shape)

In [None]:
import torch
from torch import nn

texts = [
    "A cat ate some food and then went to sleep on the bed.",
    "The quick brown fox jumps over the lazy dog.",
    "Once upon a time a cat caught a little mouse.",
]

# NOTE: tokenize() does not truncate, and every sentence above is longer than 32
# characters, so these tensors end up unpadded and of different lengths. Each one
# is therefore fed to the model as its own batch of size 1.
tensor_texts = [tokenizer.tokenize(text, pad_to_length=32) for text in texts]

# Padding carries no information: do not train the model to predict <pad>.
loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.letter_to_index["<pad>"])
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

EPOCHS = 1000
LOG_EVERY = 100  # print progress/predictions every LOG_EVERY epochs, not every epoch

model.train()

for epoch in range(EPOCHS):
    should_log = (epoch + 1) % LOG_EVERY == 0
    if should_log:
        print(f"E {(epoch + 1)}/{EPOCHS} - {((epoch + 1) / EPOCHS) * 100:.3f}%")
    for tensor_text in tensor_texts:
        # Next-token prediction: input position t is trained to predict token t + 1.
        expected_input = tensor_text[0:-1].unsqueeze(0).to(torch.int)
        expected_output = tensor_text[1:].unsqueeze(0).to(torch.long)

        # Reset gradients every step; previously zero_grad() ran once before the
        # loop, so each update used the running sum of all past gradients.
        optimizer.zero_grad()
        model_output = model(expected_input)  # [batch_size, sequence_len, vocabulary_size]

        # CrossEntropyLoss expects class scores on dim 1: [batch, vocab, seq_len].
        loss = loss_fn(model_output.permute(0, 2, 1), expected_output)
        loss.backward()
        optimizer.step()

        if should_log:
            model_letters = torch.argmax(model_output, dim=-1)
            print(f"\tloss={loss.item():.4f} {tokenizer.detokenize(model_letters)}")