In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW

import lightning as L
from torch.utils.data import DataLoader, TensorDataset

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

In [2]:
eos = "<EOS>"
vocab = ["what", "is", "statquest", "awesome", eos]
n_vocab = len(vocab)

In [18]:
class Transformer(L.LightningModule):
    def __init__(self):
        super().__init__()

        self.ndim = 2

        embedding_mean = torch.tensor(np.zeros((n_vocab, self.ndim)))
        embedding_std = torch.tensor(np.ones((n_vocab, self.ndim)))
        attention_mean = torch.tensor(np.zeros((self.ndim, self.ndim)))
        attention_std = torch.tensor(np.ones((self.ndim, self.ndim)))

        self.word_embedding_weights = nn.Parameter(torch.normal(mean=embedding_mean, std=embedding_std), requires_grad=True)
        self.query_weights = nn.Parameter(torch.normal(mean=attention_mean, std=attention_std), requires_grad=True)
        self.key_weights = nn.Parameter(torch.normal(mean=attention_mean, std=attention_std), requires_grad=True)
        self.value_weights = nn.Parameter(torch.normal(mean=attention_mean, std=attention_std), requires_grad=True)

    def word_embedding(self, inp):
        return inp @ self.word_embedding_weights
    
    def positional_embedding(self, seq_len, n=10000):
        P = torch.tensor(np.zeros((seq_len, self.ndim)))

        for k in range(seq_len):
            for i in np.arange(int(self.ndim/2)):
                denominator = np.power(n, 2*i/self.ndim)
                P[k, 2*i] = np.sin(k/denominator)
                P[k, 2*i+1] = np.cos(k/denominator)

        return P
    
    def self_attention(self, inp):
        querys = inp @ self.query_weights
        keys = inp @ self.key_weights
        querys @ torch.transpose(keys)
        values = inp @ self.value_weights

    def forward(self, inp):
        word_embedding_res = self.word_embedding(inp)
        pos_embed_res = self.positional_embedding(len(word_embedding_res), 10) + word_embedding_res
        self_attention_res = self.self_attention(pos_embed_res)
        return self_attention_res

    def configure_optimizers(self):
        return AdamW(self.parameters())

    def training_step(self, batch, batch_idx):
        input_i, label_i = batch
        output_i = self.forward(input_i)
        loss = (output_i - label_i) ** 2

        return loss

In [19]:
model = Transformer()

In [22]:
input_sentence = "what is statquest"
words_in_sentence = input_sentence.split(" ") + [eos]
one_hot_vectors = torch.tensor([[1. if vocab[i] == word else 0. for i in range(n_vocab)] for word in words_in_sentence], dtype=torch.float64)
model(one_hot_vectors).detach()

tensor([[ 0.1035, -0.1948],
        [-0.3839, -0.9073],
        [-0.3583,  0.2881],
        [ 0.6836,  0.9758]], dtype=torch.float64)