In [1]:
import torch
import torch.nn as nn
import pandas as pd
import math
from tqdm import tqdm
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
import torch.nn.functional as F
import torch.nn.init as init

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU so
# the notebook still runs end-to-end on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [2]:
# Read the raw corpus. The encoding is stated explicitly so the decoded text
# does not depend on the platform's default locale encoding.
with open('tiny_ss.txt', 'r', encoding='utf-8') as f:
    text = f.readlines()

# Each line is stripped of leading/trailing whitespace (including its newline)
# and the pieces are concatenated with no separator, giving one long string.
full_text = "".join(line.strip() for line in text)
# One list entry per character of the corpus.
char_list = list(full_text)

In [3]:
# Sorting makes the character -> id mapping reproducible. Iterating a bare set
# of strings can yield a different order in every Python process, which would
# silently change the meaning of every token id between kernel restarts.
chars = sorted(set(char_list))
tokens_enc = {char: i for i, char in enumerate(chars)}
tokens_dec = {i: char for char, i in tokens_enc.items()}


def encode(s):
    """Map a string to a list of token ids.

    Raises KeyError if ``s`` contains a character that is not in ``tokens_enc``.
    """
    return [tokens_enc[c] for c in s]


def decode(l):
    """Map an iterable of token ids back to the string they spell."""
    return "".join(tokens_dec[i] for i in l)


# The whole corpus as a 1-D tensor of int64 token ids.
train_data = torch.tensor(encode(full_text), dtype=torch.long)

In [4]:
# Number of tokens in each training window handed to TextDataset below.
context_length = 256
# Number of windows per batch; passed to the DataLoader below.
batch_size = 16

class TextDataset(Dataset):
    """Sliding-window next-token dataset over a 1-D sequence of token ids.

    Item ``idx`` is the pair
    ``(data[idx : idx + L], data[idx + 1 : idx + L + 1])`` with
    ``L = context_length``: the target is the input shifted left by one token.

    Indexing is deterministic, so the sampler (e.g. ``DataLoader(shuffle=True)``)
    is the single source of randomness and runs can be reproduced by seeding it.

    Args:
        data: indexable 1-D sequence of token ids (a tensor in this notebook).
        context_length: number of tokens in every input/target window.
    """

    def __init__(self, data, context_length):
        self.data = data
        self.context_length = context_length

    def __len__(self):
        # Number of valid window starts: the last one must leave room for the
        # one-token-shifted target. Never negative, so len() cannot raise.
        return max(len(self.data) - self.context_length, 0)

    def __getitem__(self, idx):
        """Return the ``(input, target)`` windows that start at position ``idx``.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        num_windows = len(self)
        idx = int(idx)
        if idx < 0:
            # Support Python-style negative indices.
            idx += num_windows
        if not 0 <= idx < num_windows:
            raise IndexError(f"index {idx} out of range for {num_windows} windows")

        stop = idx + self.context_length
        input_seq = self.data[idx:stop]
        target_labels = self.data[idx + 1:stop + 1]

        return input_seq, target_labels

# Wrap the encoded corpus in windows of `context_length` tokens.
dataset = TextDataset(data=train_data, context_length=context_length)

# Batches of `batch_size` windows; shuffle=True has the loader visit dataset
# indices in random order.
dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

In [5]:
class Head(nn.Module):
    """Single causal self-attention head with explicit weight matrices.

    Args:
        d_model: size of the last dimension of the input.
        d_k: width of the key and query projections.
        d_v: width of the value projection, and therefore of the output.
        device: device on which the weight matrices are created. It is kept as
            ``self.device`` for callers that read it, but ``forward`` follows
            the device of its input instead.

    The intermediate tensors of the most recent forward pass (``keys``,
    ``querys``, ``values``, ``att_linear``, ``att_mask``, ``att``, ``att_v``)
    are stored as attributes so they can be inspected afterwards.
    """

    def __init__(self, d_model, d_k, d_v, device):
        super().__init__()
        self.w_k = nn.Parameter(init.xavier_normal_(torch.empty(d_model, d_k, device=device)))
        self.w_q = nn.Parameter(init.xavier_normal_(torch.empty(d_model, d_k, device=device)))
        self.w_v = nn.Parameter(init.xavier_normal_(torch.empty(d_model, d_v, device=device)))

        self.d_k = d_k
        self.softmax = nn.Softmax(dim=-1)
        self.device = device

    def forward(self, x):
        """Apply causal attention to ``x`` of shape ``(..., seq_len, d_model)``.

        Returns a tensor of shape ``(..., seq_len, d_v)`` in which position
        ``t`` only depends on input positions ``<= t``.
        """
        self.keys = x @ self.w_k
        self.querys = x @ self.w_q
        self.values = x @ self.w_v

        # The sequence axis is second-to-last, which also covers unbatched input.
        seq_len = x.shape[-2]
        # True strictly above the diagonal, i.e. at "future" key positions.
        # Built on x's device so the head keeps working after the module has
        # been moved with .to()/.cpu()/.cuda().
        future = torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device).triu(diagonal=1)

        self.att_linear = (self.querys @ self.keys.mT) / self.d_k**0.5
        # -inf logits receive exactly zero weight after the softmax.
        self.att_mask = self.att_linear.masked_fill(future, float('-inf'))
        self.att = self.softmax(self.att_mask)

        self.att_v = self.att @ self.values

        return self.att_v

In [6]:
class MultiHead(nn.Module):
    """Several attention heads run side by side, then mixed by one projection.

    Args:
        d_model: width of the input and of the projected output.
        d_k: key/query width passed to every head.
        d_v: value width passed to every head.
        num_heads: number of ``Head`` modules to create.
        device: device on which the parameters are created.
    """

    def __init__(self, d_model, d_k, d_v, num_heads, device):
        super().__init__()
        # The output projection is created before the heads so parameter
        # registration order and random-number consumption stay the same.
        projection = torch.empty(num_heads * d_v, d_model).to(device)
        self.w_o = nn.Parameter(init.xavier_normal_(projection))

        head_modules = []
        for _ in range(num_heads):
            head_modules.append(Head(d_model, d_k, d_v, device))
        self.heads = nn.ModuleList(head_modules)

    def forward(self, x):
        """Concatenate every head's output on the last axis and project with ``w_o``."""
        per_head = [head(x) for head in self.heads]
        concatenated = torch.cat(per_head, dim=-1)
        return concatenated @ self.w_o

In [7]:
class FeedForward(nn.Module):
    """Two-layer position-wise MLP: affine map, ReLU, affine map.

    Args:
        d_model: width of the input and of the output.
        d_ff: width of the hidden layer.
        device: device on which the parameters are created.

    After a forward pass the pre-activation (``L_1``), the activation
    (``L_1_dash``) and the output (``L_2``) remain available as attributes.
    """

    def __init__(self, d_model, d_ff, device):
        super().__init__()
        # Weights are Xavier-normal initialised; biases start at zero.
        hidden_weight = torch.empty(d_model, d_ff, device=device)
        self.w_ff_1 = nn.Parameter(init.xavier_normal_(hidden_weight))
        self.b_ff_1 = nn.Parameter(torch.zeros(d_ff, device=device))

        output_weight = torch.empty(d_ff, d_model, device=device)
        self.w_ff_2 = nn.Parameter(init.xavier_normal_(output_weight))
        self.b_ff_2 = nn.Parameter(torch.zeros(d_model, device=device))

        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``relu(x @ w_ff_1 + b_ff_1) @ w_ff_2 + b_ff_2``."""
        self.L_1 = torch.matmul(x, self.w_ff_1) + self.b_ff_1
        self.L_1_dash = self.relu(self.L_1)
        self.L_2 = torch.matmul(self.L_1_dash, self.w_ff_2) + self.b_ff_2
        return self.L_2

In [8]:
class Layer(nn.Module):
    """One transformer block: attention then feed-forward, each with a residual.

    Both sub-layers are applied to an RMS-normalised copy of their input and
    the result is added back onto the un-normalised input.

    Args:
        d_model: width of the residual stream.
        d_k: key/query width passed to the attention sub-layer.
        d_v: value width passed to the attention sub-layer.
        d_ff: hidden width passed to the feed-forward sub-layer.
        num_heads: number of attention heads.
        device: device on which all parameters are created.
    """

    def __init__(self, d_model, d_k, d_v, d_ff, num_heads, device):
        super().__init__()
        self.MultiHead = MultiHead(d_model, d_k, d_v, num_heads, device)
        self.FeedForward = FeedForward(d_model, d_ff, device)
        # Create the norm weights on the same device as the other parameters,
        # so the block is usable without a follow-up .to(device).
        self.ln1 = nn.RMSNorm(d_model, device=device)
        self.ln2 = nn.RMSNorm(d_model, device=device)

    def forward(self, x):
        """Return ``x`` after the attention and feed-forward residual updates."""
        x = x + self.MultiHead(self.ln1(x))
        return x + self.FeedForward(self.ln2(x))

In [9]:
class Transformer(nn.Module):
    def __init__(self, d_model, d_k, d_v, d_ff, num_heads, num_layers, d_vocab, device):
        super().__init__()
        self.layers = nn.ModuleList([Layer(d_model, d_k, d_v, d_ff, num_heads, device) for _ in range(num_layers)])
        self.w_emb = nn.Parameter(init.xavier_normal_(torch.empty(d_vocab, d_model).to(device)))
        self.w_unemb = nn.Parameter(init.xavier_normal_(torch.empty(d_model, d_vocab).to(device)))

        self.d_vocab = d_vocab
        self.softmax = nn.Softmax(dim=-1)
        self.ln_o = nn.RMSNorm(d_model)
        self.device = device

    def one_hot(self, x):
        batch_size, seq_len = x.shape
        one_hot_matrix = torch.zeros(batch_size, seq_len, self.d_vocab, device=self.device)
        indices = x.unsqueeze(-1)
        one_hot_matrix.scatter_(2, indices, 1)
        return one_hot_matrix

    def add_pos_enc(self, x):
        batch_size, seq_len, d_model = x.shape
        position = torch.arange(seq_len, device=self.device).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2, device=self.device) * -(math.log(10000.0) / d_model))

        pe = torch.zeros(seq_len, d_model, device=self.device)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        x = x + pe
        return x

    def forward(self, x):
        x = self.one_hot(x)
        x = x @ self.w_emb
        x = self.add_pos_enc(x)
        for layer in self.layers:
            x = layer(x)
        x = self.ln_o(x)
        out = x @ self.w_unemb
        return out

In [10]:
# Model hyper-parameters.
d_model, d_k, d_v = 64, 16, 16
d_ff = 512
num_heads, num_layers = 12, 16
# One output class per distinct character in the corpus.
d_vocab = len(chars)

model = Transformer(
    d_model=d_model,
    d_k=d_k,
    d_v=d_v,
    d_ff=d_ff,
    num_heads=num_heads,
    num_layers=num_layers,
    d_vocab=d_vocab,
    device=device,
).to(device)
optimiser = AdamW(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# Count every scalar held in the model's parameters.
total_params = 0
for param in model.parameters():
    total_params += param.numel()
print(f"Total number of parameters: {total_params}")

Total number of parameters: 1854656


In [None]:
# Report the mean loss every `log_every` batches and stop after `max_batches`.
log_every = 10
max_batches = 1000

# generate_text() puts the model in eval mode, so make sure a (re-)run of this
# cell trains in training mode.
model.train()

for epoch in range(1):
    # Despite the name this is a running sum that is reset after every report.
    epoch_loss = 0
    num_batches = 0
    # The batch tensors are named `inputs`/`labels` so the loop does not
    # overwrite the corpus variable `text` defined when the file was read.
    for inputs, labels in dataloader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        out = model(inputs)

        # Flatten (batch, seq, vocab) logits and (batch, seq) targets so that
        # every position counts as one classification example.
        loss = criterion(out.reshape(-1, d_vocab), labels.reshape(-1))

        optimiser.zero_grad()
        loss.backward()
        optimiser.step()

        epoch_loss += loss.item()

        num_batches += 1

        if num_batches % log_every == 0:
            print(f"Batch {num_batches}, Loss: {epoch_loss / log_every}")
            epoch_loss = 0

        if num_batches == max_batches:
            break

In [None]:
def generate_text(model, start_string, num_generate=256, context_length=256):
    """Sample a continuation of ``start_string`` one character at a time.

    Args:
        model: module mapping a (1, seq) tensor of token ids to
            (1, seq, vocab) logits.
        start_string: non-empty prompt; every character must be known to
            ``encode``.
        num_generate: number of characters to sample.
        context_length: at most this many trailing tokens are fed to the model
            at each step.

    Returns:
        ``start_string`` followed by the sampled characters.

    Raises:
        ValueError: if ``start_string`` is empty (there is no last position to
            predict from).
        KeyError: from ``encode`` if the prompt contains an unknown character.
    """
    if not start_string:
        raise ValueError("start_string must contain at least one character")

    # Follow the model's own device rather than the notebook-level default, so
    # generation still works after the model has been moved.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else device

    was_training = model.training
    model.eval()
    input_eval = torch.tensor(encode(start_string), dtype=torch.long, device=model_device).unsqueeze(0)
    generated_ids = []

    try:
        with torch.no_grad():
            for _ in range(num_generate):
                # Keep only the most recent `context_length` tokens.
                context = input_eval[:, -context_length:]
                # Logits for the token that follows the last position.
                logits = model(context)[:, -1, :]

                # Sample from the predicted distribution (not greedy argmax).
                predicted_id = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)

                generated_ids.append(predicted_id.item())
                input_eval = torch.cat([input_eval, predicted_id], dim=-1)
    finally:
        # Put the model back in whichever mode the caller had it in.
        model.train(was_training)

    return start_string + decode(generated_ids)


# Sample a continuation of the prompt 'ROMEO: ' using generate_text's default lengths.
generate_text(model, 'ROMEO: ')