In [1]:
import torch
from torch import tensor, sin, cos
from math import sqrt
from torch.nn.functional import softmax
import spacy
from torchtext.vocab import GloVe

In [2]:
glove = GloVe(dim=300)

In [3]:
def par_attention(queries: tensor, keys: tensor, values: tensor, dim: int) -> tensor:
    raw_weights = torch.bmm(queries, keys.transpose(1, 2))

    mask = torch.tril(torch.ones_like(raw_weights), diagonal=0)
    raw_weights = raw_weights.masked_fill(mask == 0, float('-inf'))
    # print(f"raw_weights.shape:{raw_weights.shape}\nraw_weights: {raw_weights}")

    scale_factor = sqrt(dim)
    scaled_weights = softmax(raw_weights / scale_factor, dim=2)
    # print(f"scaled_weights.shape:{scaled_weights.shape}\nscaled_weights: {scaled_weights}")

    # now scaled weights is a matrix where each row represents the scaled weights produced based on a given query.
    # meanwhile values just has a value vector on each row.

    reshaped_scaled_weights = scaled_weights.view(scaled_weights.shape[0], scaled_weights.shape[1], scaled_weights.shape[2], 1)
    reshaped_values = values.view(1, values.shape[0], values.shape[1], values.shape[2])

    scaled_values = reshaped_scaled_weights * reshaped_values

    contextualized_values = torch.sum(scaled_values, 2)
    return contextualized_values

def build_dictionary(file_path) -> (dict, dict):
    with open(file_path, 'r', encoding='utf-8') as file:
        content = file.read()
    
    tokenizer = spacy.load("en_core_web_sm")
    tokens = tokenizer(content)
    unique_words = set()
    for token in tokens:
        unique_words.add(str(token))
    word_to_id = {str(word): i for i, word in enumerate(unique_words)}
    id_to_word = {i: str(word) for i, word in enumerate(unique_words)}

    return word_to_id, id_to_word

def positional_embedding(word, pos) -> tensor:
    model_dims = 300

    positional_encoding = torch.tensor([0.0] * model_dims)
    for i in range(0, model_dims // 2):
        positional_encoding[2 * i] = sin(torch.tensor(pos / (10000 ** (2 * i / model_dims))))
        positional_encoding[2 * i + 1] = cos(torch.tensor(pos / (10000 ** (2 * i / model_dims))))

    embedding = glove[word]
    embedding += positional_encoding
    return embedding


def encode_input_string(str, context_len) -> tensor:
    tokenizer = spacy.load("en_core_web_sm")
    tokens = tokenizer(str)

    output = torch.zeros(size=[context_len, 300])
    for i, token in enumerate(tokens):
        output[i] = positional_embedding(token.text, i)

    return output

In [4]:
import torch.nn as nn

class AttentionHead(nn.Module):
    # For simplicity, I assume query, key, and value vectors have the same dimensionality
    def __init__(self, model_dim, vectors_dim):
        super().__init__()
        self.model_dim = model_dim
        self.vectors_dim = vectors_dim
        self.Q_proj = nn.Linear(model_dim, vectors_dim)
        self.K_proj = nn.Linear(model_dim, vectors_dim)
        self.V_proj = nn.Linear(model_dim, vectors_dim)

    def forward(self, x):
        # each row of x is a vector representing the meaning of the token at the corresponding position with whatever context we've attained so far.
        Q = self.Q_proj(x)
        K = self.K_proj(x)
        V = self.V_proj(x)
        # print("Shape of Q matrix: ", Q.shape)
        # print("Shape of K matrix: ", K.shape)
        # print("Shape of V matrix: ", V.shape)
        output = par_attention(Q, K, V, self.vectors_dim)
        return output

class MultiHeadAttention(nn.Module):
    def __init__(self, model_dim, num_heads):
        super().__init__()
        self.att_heads = nn.ModuleList([AttentionHead(model_dim, model_dim // num_heads) for _ in range(num_heads)])
        self.proj = nn.Linear(model_dim, model_dim)

    def forward(self, x):
        head_outputs = [head(x) for head in self.att_heads]
        x = torch.concat(head_outputs, dim=2)
        x = self.proj(x)
        return x
        
class TransformerLayer(nn.Module):
    def __init__(self, model_dim, num_heads, ff_hidden_dim, context_len):
        super().__init__()
        self.attention_block = MultiHeadAttention(model_dim, num_heads)
        self.norm1 = nn.LayerNorm(normalized_shape=[context_len, model_dim])
        self.ff1 = nn.Linear(model_dim, ff_hidden_dim)
        self.ff_relu = nn.ReLU()
        self.ff2 = nn.Linear(ff_hidden_dim, model_dim)
        self.norm2 = nn.LayerNorm(normalized_shape=[context_len, model_dim])

    def forward(self, x):
        x_res = x
        x = self.attention_block(x)
        x += x_res
        x = self.norm1(x)

        x_res = x
        x = self.ff1(x)
        x = self.ff_relu(x)
        x = self.ff2(x)
        x += x_res
        x = self.norm2(x)

        return x


In [5]:

class TransformerNetwork(nn.Module):
    def __init__(self, num_layers, model_dim, att_heads, ff_hidden_dim, context_len, output_dict_size):
        super().__init__()
        self.trans_layers = nn.ModuleList([TransformerLayer(model_dim, att_heads, ff_hidden_dim, context_len) for _ in range(num_layers)])
        self.word_predictor = nn.Linear(model_dim * context_len, output_dict_size)
        print("model_dim * context_len = ", model_dim * context_len)

    def forward(self, x):
        for layer in self.trans_layers:
            x = layer.forward(x)
        # print("Shape of x before view: ", x.shape)
        x = x.view(x.shape[0], -1)
        # print("Shape of x after view: ", x.shape)
        x = self.word_predictor(x)
        return x



In [6]:
# These parameters match what's described in "attention is all you need". The exception is that they probably use a different tokenizer and have the ability to output any token.
# Also not sure how they handle context length...
# paper_model = TransformerNetwork(num_layers=6, model_dim=512, att_heads=8, ff_hidden_dim=2048, context_len=256, output_dict_size=1)

word_to_id, id_to_word = build_dictionary('../data/much_ado_about_nothing_gut.txt')

In [7]:

dictionary_len = len(id_to_word)
context_len = 32
model = TransformerNetwork(num_layers=2, model_dim=300, att_heads=6, ff_hidden_dim=1200, context_len=context_len, output_dict_size=dictionary_len)

test_input = "The next word is"
encoded_input = encode_input_string(test_input, context_len)



model_dim * context_len =  9600


In [8]:
# This approach kinda sucks ay? Cause it assumes I already know what the last token is... Whatever, will iterate on this lol
def encode_inputs(input_list, context_len) -> tensor:
    output = torch.zeros(size=[len(input_list), context_len, 300])
    for i, input in enumerate(input_list):
        output[i] = encode_input_string(input, context_len)
    return output

def encode_outputs(output_tokens: [str]) -> tensor:
    output_cats = torch.zeros(size=[len(output_tokens)]).long()
    for i, token in enumerate(output_tokens):
        output_cats[i] = word_to_id[token]
    return output_cats

In [66]:
train_features = encode_inputs(["Did I not tell you she was innocent"], context_len)
train_labels = encode_outputs(["?"])
print(train_labels)

val_features = encode_inputs(["Well, I am glad that all things sort so well"], context_len)
val_labels = encode_outputs(["."])
print(val_labels)

tensor([9])
tensor([1037])


In [67]:
from torch.utils.data import Dataset

class CompletionDataset(Dataset):
    def __init__(self, features, labels):
        self.features = features
        self.labels = labels

    def __len__(self):
        return self.features.shape[0]

    def __getitem__(self, index):
        return self.features[index], self.labels[index]

train_dataset = CompletionDataset(train_features, train_labels)
val_dataset = CompletionDataset(val_features, val_labels)

In [68]:
from torch.utils.data import DataLoader

loss_func = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.2)
train_loader = DataLoader(dataset=train_dataset, batch_size=1, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=1, shuffle=False)

In [70]:
# model.train(False)
# batches = 0
# avg_loss = 0
# for step, (features, labels) in enumerate(train_loader):
#     # optimizer.zero_grad()
#     print("step=", step)
#     preds = model(features)
#     print(f"preds:{preds}\nlabels:{labels}")
#     loss = loss_func(preds, labels)
#     # loss.backward()
#     # optimizer.step()
#     print("loss=", loss)

# model.train(True)
# batches = 0
# avg_loss = 0
# for step, (features, labels) in enumerate(train_loader):
#     optimizer.zero_grad()
#     preds = model(features)
#     print(f"preds:{preds}\nlabels:{labels}")
#     loss = loss_func(preds, labels)
#     loss.backward()
#     optimizer.step()
    
#     print(f"Loss on step {step}: {loss}")


step= 0
preds:tensor([[ 0.3106,  0.7455,  0.0710,  ..., -0.4766, -0.6115, -0.2710]],
       grad_fn=<AddmmBackward0>)
labels:tensor([9])
loss= tensor(961.1148, grad_fn=<NllLossBackward0>)


In [71]:


def train_one_epoch():
    model.train(True)
    batches = 0
    avg_loss = 0
    for step, (features, labels) in enumerate(train_loader):
        optimizer.zero_grad()
        preds = model(features)
        print(f"preds:{preds}\nlabels:{labels}")
        loss = loss_func(preds, labels)
        loss.backward()
        optimizer.step()
        
        print(f"Loss on step {step}: {loss}")

        avg_loss += loss
        batches = step + 1
    
    avg_loss = avg_loss / batches
    print(f"Average loss for training batches in this epoch: {avg_loss}")

    model.train(False)
    batches = 0
    avg_loss = 0
    for step, (features, labels) in enumerate(val_loader):
        preds = model(features)
        print(f"preds:{preds}\nlabels:{labels}")
        loss = loss_func(preds, labels)
        
        print(f"Loss on step {step}: {loss}")


        avg_loss += loss
        batches = step + 1

    avg_loss = avg_loss / batches
    print(f"Average loss for validation batches in this epoch: {avg_loss}")


In [72]:
train_one_epoch()
# train_one_epoch()


preds:tensor([[ 0.3106,  0.7455,  0.0710,  ..., -0.4766, -0.6115, -0.2710]],
       grad_fn=<AddmmBackward0>)
labels:tensor([9])
Loss on step 0: 961.1148071289062
Average loss for training batches in this epoch: 961.1148071289062
preds:tensor([[ 0.4711,  0.7358,  0.1544,  ..., -0.4126, -0.5657, -0.2135]],
       grad_fn=<AddmmBackward0>)
labels:tensor([1037])
Loss on step 0: 2874.47607421875
Average loss for validation batches in this epoch: 2874.47607421875
