In [46]:
import torch
from torch import tensor, sin, cos
from math import sqrt
from torch.nn.functional import softmax
import spacy
import torch.nn as nn
from torch.utils.data import DataLoader


In [47]:
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
assert torch.cuda.is_available()
device = torch.device("cuda")

In [48]:
train_files = [
    "../data/_part1.txt",
    # "../data/_part2.txt",
    # "../data/_part3.txt",
    # "../data/_part4.txt",
    # "../data/_part5.txt",
    # "../data/_part6.txt",
    # "../data/_part7.txt"
]
test_files = ["../data/much_ado_about_nothing_gut.txt"]

train_texts = []
test_texts = []

for file_name in train_files:
    with open(file_name, 'r', encoding='utf-8') as file:
        train_texts.append(file.read())

for file_name in test_files:
    with open(file_name, 'r', encoding='utf-8') as file:
        test_texts.append(file.read())

Set up our tokenizer and 3rd party embedding library

In [49]:
tokenizer = spacy.load("en_core_web_sm")
all_tokens = []
all_tokens.extend(['<PAD>', '<UNK>']) # special tokens

for text in train_texts + test_texts:
    doc = tokenizer(text)
    tokens = [token.text for token in doc]
    all_tokens.extend(tokens)

unique_tokens = set(all_tokens)
vocab = {token: i for i, token in enumerate(unique_tokens)}
reverse_vocab = {i: token for i, token in enumerate(unique_tokens)}

Define key helper functions used throughout training and inference

In [50]:
def par_attention(queries: tensor, keys: tensor, values: tensor, dim: int) -> tensor:
    raw_weights = torch.bmm(queries, keys.transpose(1, 2))

    mask = torch.tril(torch.ones_like(raw_weights), diagonal=0)
    raw_weights = raw_weights.masked_fill(mask == 0, float('-inf'))
    # print(f"raw_weights.shape:{raw_weights.shape}\nraw_weights: {raw_weights}")

    scale_factor = sqrt(dim)
    scaled_weights = softmax(raw_weights / scale_factor, dim=2)
    # print(f"scaled_weights.shape:{scaled_weights.shape}\nscaled_weights: {scaled_weights}")

    # now scaled weights is a matrix where each row represents the scaled weights produced based on a given query.
    # meanwhile values just has a value vector on each row.

    reshaped_scaled_weights = scaled_weights.view(scaled_weights.shape[0], scaled_weights.shape[1], scaled_weights.shape[2], 1)
    reshaped_values = values.view(values.shape[0], values.shape[1], 1, values.shape[2])

    scaled_values = reshaped_scaled_weights * reshaped_values

    contextualized_values = torch.sum(scaled_values, 2)
    return contextualized_values

def prep_input_string(str, context_len) -> tensor:
    """Takes an input string with up to context_len tokens and returns a tensor full of integers, which can be passed into the model"""
    tokens = tokenizer(str)

    output = torch.full([context_len], vocab['<PAD>'])
    for in_pos in range(len(tokens)):
        out_pos = context_len - len(tokens) + in_pos
        output[out_pos] = vocab[tokens[in_pos].text]

    return output

def prep_tokens(tokens, length) -> tensor:
    output = torch.full([length], vocab['<PAD>'])
    for in_pos in range(len(tokens)):
        out_pos = length - len(tokens) + in_pos
        output[out_pos] = vocab[tokens[in_pos].text]

    return output

# slice_offset is the number of tokens separating the start of one slice from the start of the previous.
# slice_offset == slice_length means no overlap, slice_offset == 1 means maximum overlap.
def slice_text(text: str, slice_length, slice_offset, context_len) -> tensor:
    slices = []
    tokens = tokenizer(text)

    for i in range(0, len(tokens), slice_offset):
        slices.append(tokens[i:i+slice_length])

    output = torch.zeros([len(slices), context_len + 1]) # use context_len + 1 because we need to include the label
    for i, slice in enumerate(slices):
        output[i] = prep_tokens(slice, context_len + 1)

    assert output.shape[1] == context_len + 1
    return output.to(device)

def slice_by_line(text: str, context_len) -> tensor:
    slices = text.split("\n")
    tokens = [tokenizer(slice) for slice in slices]

    output = torch.zeros([len(tokens), context_len + 1])
    for i, token_line in enumerate(tokens):
        output[i] = prep_tokens(token_line, context_len + 1)

    return output.to(device)

In [51]:
class PositionalEncoding(nn.Module):
    def __init__(self, dims, context_len):
        super().__init__()
        self.dims = dims
        self.context_len = context_len
        self.proj = nn.Linear(1, self.dims)

        positional_matrix = torch.zeros([self.context_len, self.dims])
        for pos in range(0, self.context_len):
            for i in range(0, self.dims // 2):
                positional_matrix[pos][2 * i] = sin(torch.tensor(pos / (10000 ** (2 * i / self.dims))))
                positional_matrix[pos][2 * i + 1] = cos(torch.tensor(pos / (10000 ** (2 * i / self.dims))))
        positional_matrix = positional_matrix.to(device)
        self.register_buffer('positional_matrix', positional_matrix)
        self.positional_matrix = self.positional_matrix.to(device)


    def forward(self, x: tensor) -> tensor:
        # x is token ids. we'll say it's context_len integers packed into a tensor, where each one represents a token. it can also be batched.
        output = torch.zeros([x.shape[0], self.context_len, self.dims]).to(device)
        for batch in range(0, x.shape[0]):
            output[batch] = self.proj(x[batch].view(x.shape[1], -1))
            output[batch] += self.positional_matrix
        # print(f"self.context_len={self.context_len}")
        # print(f"Shape of x before assert: {x.shape}")
        # assert x.shape[1] == self.context_len
        # output = self.proj(x)
        # output += self.positional_matrix
        return output

Define the architecture of the model, including all subcomponents

In [52]:
class AttentionHead(nn.Module):
    # For simplicity, I assume query, key, and value vectors have the same dimensionality
    def __init__(self, model_dim, vectors_dim):
        super().__init__()
        self.model_dim = model_dim
        self.vectors_dim = vectors_dim
        self.Q_proj = nn.Linear(model_dim, vectors_dim, bias=False)
        self.K_proj = nn.Linear(model_dim, vectors_dim, bias=False)
        self.V_proj = nn.Linear(model_dim, vectors_dim, bias=False)

    def forward(self, x):
        # each row of x is a vector representing the meaning of the token at the corresponding position with whatever context we've attained so far.
        Q = self.Q_proj(x)
        K = self.K_proj(x)
        V = self.V_proj(x)
        # print("Shape of Q matrix: ", Q.shape)
        # print("Shape of K matrix: ", K.shape)
        # print("Shape of V matrix: ", V.shape)
        output = par_attention(Q, K, V, self.vectors_dim)
        return output

class MultiHeadAttention(nn.Module):
    def __init__(self, model_dim, num_heads):
        super().__init__()
        self.att_heads = nn.ModuleList([AttentionHead(model_dim, model_dim // num_heads) for _ in range(num_heads)])
        self.proj = nn.Linear(model_dim, model_dim, bias=False)

    def forward(self, x):
        head_outputs = [head(x) for head in self.att_heads]
        x = torch.concat(head_outputs, dim=2)
        x = self.proj(x)
        return x
        
class TransformerLayer(nn.Module):
    def __init__(self, model_dim, num_heads, ff_hidden_dim, context_len):
        super().__init__()
        self.attention_block = MultiHeadAttention(model_dim, num_heads)
        self.norm1 = nn.LayerNorm(normalized_shape=[context_len, model_dim])
        self.ff1 = nn.Linear(model_dim, ff_hidden_dim)
        self.ff_relu = nn.ReLU()
        self.ff2 = nn.Linear(ff_hidden_dim, model_dim)
        self.norm2 = nn.LayerNorm(normalized_shape=[context_len, model_dim])

    def forward(self, x):
        x_res = x
        x = self.attention_block(x)
        x += x_res
        x = self.norm1(x)

        x_res = x
        x = self.ff1(x)
        x = self.ff_relu(x)
        x = self.ff2(x)
        x += x_res
        x = self.norm2(x)

        return x

class TransformerNetwork(nn.Module):
    def __init__(self, num_layers, model_dim, att_heads, ff_hidden_dim, context_len, output_dict_size):
        super().__init__()
        self.encode_embed = PositionalEncoding(model_dim, context_len)
        self.trans_layers = nn.ModuleList([TransformerLayer(model_dim, att_heads, ff_hidden_dim, context_len) for _ in range(num_layers)])

        # self.test_ff = nn.Linear(model_dim * context_len, 2048)
        # self.test_relu = nn.ReLU()
        # self.test_ff2 = nn.Linear(2048, model_dim * context_len)
        # self.test_relu2 = nn.ReLU()

        self.word_predictor = nn.Linear(model_dim * context_len, output_dict_size)
        # print(f"word_predictor input dimension: {model_dim * context_len}\noutput dimension: {output_dict_size}")

    def forward(self, x):
        # print(f"Received x of shape: {x.shape}")
        x = self.encode_embed(x)
        for layer in self.trans_layers:
            x = layer.forward(x)
        x = x.view(x.shape[0], -1)
        # print(f"Reshaped x to shape: {x.shape}")

        # x = self.test_ff(x)
        # x = self.test_relu(x)
        # x = self.test_ff2(x)
        # x = self.test_relu2(x)

        x = self.word_predictor(x)
        return x

Tools to quickly build a dataset that can be fed into the model

In [53]:
from torch.utils.data import Dataset

class CompletionDataset(Dataset):
    def __init__(self, features, labels):
        self.features = features
        self.labels = labels.long()

    def __len__(self):
        return self.features.shape[0]

    def __getitem__(self, index):
        return self.features[index], self.labels[index]

# Note: slices include features + label. So if you have context length 256, you can set slice length 257 and be fine.
def build_dataset(slices: tensor) -> CompletionDataset:    
    features = slices[:, :-1]
    labels = slices[:, -1]
    
    dataset = CompletionDataset(features, labels)
    return dataset

context_len = 16
slice_length = context_len + 1
slice_offset = slice_length

# note: slice_text returns an n by slice_length tensor of ints. (from vocab)
train_slices = [] # list of tensors
for text in train_texts:
    # train_slices.append(slice_by_line(text, context_len))
    train_slices.append(slice_text(text, slice_length, slice_offset, context_len))
    train_slices.append(slice_text(text, slice_length - 2, 1, context_len))
    train_slices.append(slice_text(text, 5, 1, context_len))
train_dataset = build_dataset(torch.cat(train_slices, dim=0))

test_slices = [] # list of tensors
for text in test_texts:
    # test_slices.append(slice_by_line(text, context_len))
    test_slices.append(slice_text(text, slice_length - 3, 1, context_len))
test_dataset = build_dataset(torch.cat(test_slices, dim=0))

In [54]:
def check_input_data(input):
    features = input[0].int().tolist()
    label = input[1].int().item()
    features_str = [reverse_vocab[f] for f in features]
    label_str = reverse_vocab[label]
    print(f"Features:\n{features_str}")
    print(f"Label:\n{label_str}")

check_input_data(test_dataset[0])
check_input_data(train_dataset[0])
len(test_dataset)

Features:
['<PAD>', '<PAD>', '<PAD>', 'MUCH', 'ADO', 'ABOUT', 'NOTHING', '\n\n', 'by', 'William', 'Shakespeare', '\n\n\n\n\n', 'DRAMATIS', 'PERSONAE', '\n\n', 'DON']
Label:
PEDRO
Features:
['The', 'Complete', 'Works', 'of', 'William', 'Shakespeare', '\n\n', 'by', 'William', 'Shakespeare', '\n\n\n\n\n                    ', 'Contents', '\n\n    ', 'THE', 'SONNETS', '\n    ']
Label:
ALL


32167

Initialize model. Output dict size is the size of the final layer.

In [55]:
dictionary_len = len(vocab)
model = TransformerNetwork(num_layers=4, model_dim=256, att_heads=4, ff_hidden_dim=1024, context_len=context_len, output_dict_size=dictionary_len)
model.to(device)
print(f"dictionary_len: {dictionary_len}")

loss_func = torch.nn.CrossEntropyLoss()
# optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=32, shuffle=False)

dictionary_len: 13075


In [56]:
def train_one_epoch(do_validation: bool):
    model.train(True)
    torch.set_printoptions(profile="short")
    batches = 0
    avg_loss = 0
    for step, (features, labels) in enumerate(train_loader):
        features, labels = features.to(device), labels.to(device)
        optimizer.zero_grad()
        preds = model(features)
        # print(f"preds:{preds}\nlabels:{labels}")
        loss = loss_func(preds, labels)
        loss.backward()

        # if step % 10 == 0:  # Print every 10 batches
        #     for name, param in model.named_parameters():
        #         if param.requires_grad:
        #             print(f"Gradient data for {name}:", param.grad)
        #             print(f"Checking if gradients are fully zeroed: {torch.all(param.grad == 0.0).item()}")
        #             print(f"Shape: {param.grad.shape}")
        #             print(f"Mean: {param.grad.mean()}")
        #             print(f"Std: {param.grad.std()}")
        #             print(f"Min: {param.grad.min()}")
        #             print(f"Max: {param.grad.max()}")

        optimizer.step()

        # if step % 20 == 0:
        #     print(f"Loss on batch {step}: {loss}")

        avg_loss += loss
        batches = step + 1
    
    avg_loss = avg_loss / batches
    print(f"Average loss for training batches in this epoch: {avg_loss}")

    if do_validation:
        model.train(False)
        batches = 0
        avg_loss = 0
        for step, (features, labels) in enumerate(test_loader):
            features, labels = features.to(device), labels.to(device)
            preds = model(features)
            # print(f"preds:{preds}\nlabels:{labels}")
            loss = loss_func(preds, labels)
            
            # if step % 20 == 0:
            #     print(f"Loss on batch {step}: {loss}")

            avg_loss += loss
            batches = step + 1

        avg_loss = avg_loss / batches
        print(f"Average loss for validation batches in this epoch: {avg_loss}")


In [57]:
for i in range(1):
    print("Epoch", i)
    train_one_epoch(False)
# train_one_epoch()


Epoch 0
Average loss for training batches in this epoch: 6.214040756225586


In [58]:
model.eval()  # Set the model to evaluation mode
correct = 0
total = 0

with torch.no_grad():  # Deactivates autograd, reduces memory usage and speeds up computations
    for features, labels in test_loader:
        outputs = model(features)
        _, predicted = torch.max(outputs.data, 1)  # Get the index of the max log-probability
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total
print(f"Accuracy on test set: {accuracy}%")


Accuracy on test set: 11.275530823514782%


In [59]:
def infer_completion(input_text: str, context_len):
    encoded_input = prep_input_string(input_text, context_len).unsqueeze(0).float()
    
    model.train(False)
    pred = model(encoded_input)
    return reverse_vocab[torch.argmax(softmax(pred, dim=1), dim=1).item()]

In [60]:
infer_completion("", context_len)

KeyError: '+'