In [124]:
import cupy as np
# import numpy as np
import pandas as pd
import re

# NOTE(review): `device` is not referenced by any visible cell; cupy arrays are
# allocated on the current CUDA device regardless of this setting.
device: str = 'cuda'

In [125]:
with open('tiny_ss.txt', 'r', encoding='utf-8') as f:
    text = f.readlines()

# Keep the line breaks: they are part of the corpus. Stripping each line and
# joining with "" fused the last word of one line onto the first word of the
# next (e.g. "Citizen:\nBefore" -> "Citizen:Before").
full_text = "".join(text)

# Character-level vocabulary: every distinct character gets an integer id,
# assigned in sorted order so the mapping is deterministic across runs.
chars = sorted(set(full_text))
tokens_enc = {char: i for i, char in enumerate(chars)}
tokens_dec = {i: char for char, i in tokens_enc.items()}


def encode(s):
    """Map a string to a list of token ids (KeyError for characters not in the vocabulary)."""
    return [tokens_enc[c] for c in s]


def decode(l):
    """Map an iterable of (Python int) token ids back to a string."""
    return "".join(tokens_dec[i] for i in l)


train_data = np.array(encode(full_text), dtype=np.int32)

In [126]:
# Training-window hyper-parameters.
context_length = 128  # tokens per input sequence
batch_size = 16       # sequences per batch

class TextDataset:
    """Sliding-window next-token dataset over a flat 1-D token array.

    Item ``idx`` is the window ``data[idx : idx + context_length]`` paired with
    the same window shifted one position to the right (the prediction targets).
    """

    def __init__(self, data, context_length):
        self.data = data
        self.context_length = context_length

    def __len__(self):
        # Number of valid start positions: the target window needs one extra token.
        return len(self.data) - self.context_length

    def __getitem__(self, idx=None):
        """Return ``(inputs, targets)`` starting at ``idx``; draw a random start when ``idx`` is None."""
        start = np.random.randint(0, len(self)) if idx is None else idx
        stop = start + self.context_length
        return self.data[start:stop], self.data[start + 1 : stop + 1]

class DataLoader:
    """Endless iterator of ``(inputs, targets)`` batches stacked along a new leading axis.

    ``shuffle`` is stored but never consulted: each sample is obtained by
    calling ``dataset.__getitem__()`` with no index, so sampling is whatever the
    dataset does when no index is given.
    """

    def __init__(self, dataset, batch_size, shuffle=True):
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __iter__(self):
        # Never terminates on its own; the consumer decides when to stop.
        while True:
            samples = [self.dataset.__getitem__() for _ in range(self.batch_size)]
            inputs = [inp for inp, _ in samples]
            targets = [tgt for _, tgt in samples]
            yield np.stack(inputs), np.stack(targets)

# Random-window batches drawn from the full encoded corpus.
dataset = TextDataset(train_data, context_length)
train_loader = DataLoader(dataset, batch_size, shuffle=True)

In [127]:
## HELPER FUNCTIONS

def initialise_weights(dim1, dim2):
    """Return a ``(dim1, dim2)`` matrix drawn from N(0, 2 / (dim1 + dim2)) (Glorot/Xavier normal)."""
    std = np.sqrt(2.0 / (dim1 + dim2))
    return np.random.randn(dim1, dim2) * std

def initialise_bias(dim):
    """Return a zero-filled float64 bias vector of length ``dim``."""
    return np.zeros(dim)

def one_hot(x, d_vocab=len(tokens_enc)):
    """One-hot encode integer ids ``x`` (any shape) into a float32 array of shape ``x.shape + (d_vocab,)``.

    Row ``i`` of the identity matrix is the one-hot vector for id ``i``, so
    indexing the identity with ``x`` builds the whole encoding in one step.
    """
    identity = np.eye(d_vocab, dtype=np.float32)
    return identity[x]

def get_look_forward_mask(seq_len):
    """Additive causal mask of shape ``(1, seq_len, seq_len)``.

    Entry ``[0, i, j]`` is 0.0 when ``j <= i`` (position i may attend to j) and
    -1e9 when ``j > i``, so future positions vanish after the softmax.
    """
    blocked = np.full((seq_len, seq_len), -1e9)
    # keep only the strictly-upper triangle (the future); everything else becomes 0.0
    return np.triu(blocked, k=1)[None, :, :]

def add_pos_enc(x):
    """Add sinusoidal positional encodings to ``x`` and return the sum.

    ``x`` has shape ``(..., seq_len, d_model)``. Even feature columns ``2i``
    receive ``sin(pos / 10000**(2i / d_model))`` and odd columns ``2i + 1`` the
    matching cosine. Any ``d_model`` is supported: for odd widths the last
    column is a sine with no cosine partner. The encoding (built as float32) is
    broadcast over all leading dimensions.
    """
    seq_len, d_model = x.shape[-2], x.shape[-1]
    position = np.arange(seq_len).reshape(seq_len, 1)

    # One frequency per even column; ceil(d_model / 2) of them so odd widths work.
    n_freqs = (d_model + 1) // 2
    div_term = np.power(10000.0, (2 * np.arange(n_freqs)) / d_model)
    angles = position / div_term

    pe = np.zeros((seq_len, d_model), dtype=np.float32)
    pe[:, 0::2] = np.sin(angles)
    # There are only floor(d_model / 2) odd columns.
    pe[:, 1::2] = np.cos(angles[:, : d_model // 2])

    return x + pe

def collect_over_b(x):
    """Average ``x`` over its first (batch) axis."""
    total = x.sum(axis=0)
    return total / x.shape[0]

In [128]:
class Adam:
    """Adam optimiser that updates a single parameter array *in place*.

    ``update`` mutates the array passed as ``value`` with ``-=``, so every other
    reference to that same array sees the new weights; it also returns the array.
    """

    def __init__(self, value, lr=5e-4, beta1=0.9, beta2=0.999, eps=1e-8):
        # hyper-parameters
        self.lr, self.beta1, self.beta2, self.eps = lr, beta1, beta2, eps
        # the parameter being optimised and its first/second moment estimates
        self.value = value
        self.m = np.zeros_like(value)
        self.v = np.zeros_like(value)
        self.t = 0  # number of update steps taken so far

    def update(self, grad):
        """Apply one bias-corrected Adam step with ``grad`` and return the updated array."""
        self.t += 1
        b1, b2 = self.beta1, self.beta2
        self.m = b1 * self.m + (1 - b1) * grad
        self.v = b2 * self.v + (1 - b2) * grad ** 2

        # correct the bias towards zero of the freshly initialised moments
        m_hat = self.m / (1 - b1 ** self.t)
        v_hat = self.v / (1 - b2 ** self.t)

        step = self.lr * m_hat / (np.sqrt(v_hat) + self.eps)
        self.value -= step
        return self.value

In [129]:
class Softmax:
    """Numerically stable softmax with a matching backward pass.

    ``__call__`` caches its output ``y`` and the softmax axis; ``backward`` uses
    them to turn an upstream gradient dL/dy into dL/dx.
    """

    def __call__(self, x, dim=-1):
        self.dim = dim
        # Subtracting the per-slice max leaves the result unchanged but keeps exp() from overflowing.
        x_max = np.max(x, axis=dim, keepdims=True)
        exp_x = np.exp(x - x_max)
        self.y = exp_x / np.sum(exp_x, axis=dim, keepdims=True)
        return self.y

    def backward(self, prev_grad):
        """Return dL/dx given ``prev_grad`` = dL/dy (same shape as the cached output).

        Uses the closed form of the Jacobian-vector product along the softmax
        axis, ``dx = y * (g - sum(g * y))``, which equals ``(diag(y) - y y^T) g``
        without materialising the Jacobian. Works for inputs of any rank.
        """
        dot = np.sum(prev_grad * self.y, axis=self.dim, keepdims=True)
        return self.y * (prev_grad - dot)

In [130]:
class RMSNorm:
    """Root-mean-square normalisation over the last axis with a learned per-feature gain.

    Forward: ``y = g * x / r`` with ``r = sqrt(mean(x**2, axis=-1) + eps)``.
    ``backward`` returns dL/dx and takes one Adam step on ``g``.
    """

    def __init__(self, d_rms, eps=1e-8):
        # Gain starts at 1 (identity scaling), the usual RMSNorm initialisation;
        # a random uniform [0, 1) start can shrink some features towards zero.
        self.g = np.ones(d_rms)
        self.eps = eps  # keeps r > 0 for an all-zero input row
        self.g_adam = Adam(self.g)

    def __call__(self, x, dim=-1):
        # NOTE: `backward` assumes normalisation over the last axis (dim=-1).
        self.x = x
        g = self.g.reshape(1, -1)
        self.r = np.sqrt(np.mean(x**2, axis=dim, keepdims=True) + self.eps)
        return (x / self.r) * g

    def backward(self, prev_grad):
        """Return dL/dx for upstream ``prev_grad`` = dL/dy and update the gain ``g``."""
        x_hat = self.x / self.r      # normalised input; also dy/dg
        gp = prev_grad * self.g      # dL/d(x_hat)

        # dy_j/dx_i = g_j * (delta_ij / r - x_i x_j / (d r^3)). Contracting over the
        # OUTPUT index j gives
        #   dL/dx_i = (g_i p_i - x_hat_i * mean_j(g_j p_j x_hat_j)) / r,
        # so no (..., d, d) Jacobian has to be built.
        d_x = (gp - x_hat * np.mean(gp * x_hat, axis=-1, keepdims=True)) / self.r

        # Sum over positions, average over the first (batch) axis -- the same
        # reduction used for every other weight gradient in this model.
        lead_axes = tuple(range(prev_grad.ndim - 1))
        d_g = np.sum(x_hat * prev_grad, axis=lead_axes) / prev_grad.shape[0]

        self.g = self.g_adam.update(d_g)
        return d_x

In [131]:
class ReLU:
    """Element-wise ``max(0, x)``; the forward output is cached for the backward mask."""

    def __call__(self, x):
        out = np.maximum(0, x)
        self.x = out  # cached *output*: positive exactly where the input was positive
        return out

    def backward(self, prev_grad):
        """Pass ``prev_grad`` through where the forward output was positive, zero elsewhere."""
        active = self.x > 0
        return prev_grad * active.astype(int)

In [132]:
class Head:
    """One causal self-attention head.

    Forward computes ``softmax(Q K^T / sqrt(d_k) + mask) @ V`` with
    ``Q = x @ w_q``, ``K = x @ w_k``, ``V = x @ w_v``; the additive mask blocks
    attention to later positions. ``backward`` returns dL/dx and takes one Adam
    step on each projection matrix.
    """

    def __init__(self, d_model, d_k, d_v):
        self.w_k = initialise_weights(d_model, d_k)
        self.w_q = initialise_weights(d_model, d_k)
        self.w_v = initialise_weights(d_model, d_v)
        self.d_k = d_k
        self.softmax = Softmax()

        self.w_k_adam = Adam(self.w_k)
        self.w_q_adam = Adam(self.w_q)
        self.w_v_adam = Adam(self.w_v)

    def forward(self, x):
        """Attend over ``x`` (batch, seq, d_model) and return (batch, seq, d_v); caches activations."""
        _, seq_len, _ = x.shape
        mask = get_look_forward_mask(seq_len)
        self.x = x

        self.K = x @ self.w_k
        self.Q = x @ self.w_q
        self.V = x @ self.w_v

        self.att_linear = (self.Q @ self.K.transpose(0, 2, 1)) / np.sqrt(self.d_k)
        self.att_mask = self.att_linear + mask
        self.att = self.softmax(self.att_mask)
        self.att_v = self.att @ self.V

        return self.att_v

    def backward(self, prev_grad):
        """Return dL/dx for upstream ``prev_grad`` = dL/d(att_v) and update w_q, w_k, w_v."""
        d_V = self.att.transpose(0, 2, 1) @ prev_grad
        d_att = prev_grad @ self.V.transpose(0, 2, 1)
        # Masked (future) positions have att == 0 exactly (exp(-1e9) underflows), so the
        # softmax backward already gives them zero gradient; no extra masking is needed.
        d_att_mask = self.softmax.backward(d_att)

        scale = 1 / np.sqrt(self.d_k)
        d_Q = scale * (d_att_mask @ self.K)
        d_K = scale * (d_att_mask.transpose(0, 2, 1) @ self.Q)

        x_t = self.x.transpose(0, 2, 1)
        d_w_v = collect_over_b(x_t @ d_V)
        d_w_q = collect_over_b(x_t @ d_Q)
        d_w_k = collect_over_b(x_t @ d_K)

        # x feeds all three projections, so dL/dx is the SUM of the three paths
        # (not their average).
        d_X_1 = (d_K @ self.w_k.T) + (d_Q @ self.w_q.T) + (d_V @ self.w_v.T)

        # Adam mutates the weight arrays in place, so update only after d_X_1
        # has been computed with the weights used in the forward pass.
        self.w_k = self.w_k_adam.update(d_w_k)
        self.w_q = self.w_q_adam.update(d_w_q)
        self.w_v = self.w_v_adam.update(d_w_v)

        return d_X_1

In [133]:
class MultiHead:
    """Multi-head causal self-attention.

    Runs ``num_heads`` independent heads on the same input, concatenates their
    outputs along the feature axis and projects the result with ``w_o``.
    """

    def __init__(self, d_model, d_k, d_v, num_heads):
        self.w_o = initialise_weights(num_heads * d_v, d_model)
        self.heads = [Head(d_model, d_k, d_v) for _ in range(num_heads)]
        self.num_heads = num_heads
        self.d_v = d_v

        self.w_o_adam = Adam(self.w_o)

    def _head_slice(self, idx):
        """Feature-axis slice of the concatenated attention output owned by head ``idx``."""
        return slice(idx * self.d_v, (idx + 1) * self.d_v)

    def forward(self, x):
        """Return ``concat(head_i(x)) @ w_o`` for ``x`` of shape (batch, seq, d_model)."""
        B, L, _ = x.shape
        self.m_att = np.zeros((B, L, self.num_heads * self.d_v))

        for idx, head in enumerate(self.heads):
            self.m_att[..., self._head_slice(idx)] = head.forward(x)
        self.m_att_output = self.m_att @ self.w_o

        return self.m_att_output

    def backward(self, prev_grad):
        """Return dL/dx for upstream ``prev_grad`` and update ``w_o`` (heads update themselves)."""
        B, L, _ = self.m_att.shape
        d_model = prev_grad.shape[-1]

        d_w_o = collect_over_b(self.m_att.transpose(0, 2, 1) @ prev_grad)
        d_m_att = prev_grad @ self.w_o.T

        # Every head reads the same x, so dL/dx is the SUM of the per-head
        # gradients (dividing by num_heads would shrink it).
        head_gradients = np.zeros((B, L, d_model))
        for idx, head in enumerate(self.heads):
            head_gradients += head.backward(d_m_att[..., self._head_slice(idx)])

        self.w_o = self.w_o_adam.update(d_w_o)

        return head_gradients

In [134]:
class FeedForward:
    """Position-wise two-layer MLP: ``relu(x @ w_ff_1 + b_ff_1) @ w_ff_2 + b_ff_2``."""

    def __init__(self, d_model, d_ff):
        self.w_ff_1 = initialise_weights(d_model, d_ff)
        self.b_ff_1 = initialise_bias(d_ff)
        self.w_ff_2 = initialise_weights(d_ff, d_model)
        self.b_ff_2 = initialise_bias(d_model)
        self.relu = ReLU()

        self.w_ff_1_adam = Adam(self.w_ff_1)
        self.b_ff_1_adam = Adam(self.b_ff_1)
        self.w_ff_2_adam = Adam(self.w_ff_2)
        self.b_ff_2_adam = Adam(self.b_ff_2)

    def forward(self, x):
        """Apply the MLP to ``x`` (batch, seq, d_model); caches activations for backward."""
        self.x_double_dash = x
        self.ff_1 = (x @ self.w_ff_1) + self.b_ff_1
        self.ff_1_dash = self.relu(self.ff_1)
        self.ff_2 = (self.ff_1_dash @ self.w_ff_2) + self.b_ff_2
        return self.ff_2

    def backward(self, prev_grad):
        """Return dL/dx for upstream ``prev_grad`` and take one Adam step on every weight and bias.

        All parameter gradients are summed over sequence positions and averaged
        over the batch (``x^T @ grad`` already sums over positions for the weights;
        the bias gradients now do the same instead of averaging over positions).
        """
        d_w_ff_2 = collect_over_b(self.ff_1_dash.transpose(0, 2, 1) @ prev_grad)
        d_b_ff_2 = collect_over_b(np.sum(prev_grad, axis=1))
        d_ff_1_dash = prev_grad @ self.w_ff_2.T

        d_ff_1 = self.relu.backward(d_ff_1_dash)

        d_w_ff_1 = collect_over_b(self.x_double_dash.transpose(0, 2, 1) @ d_ff_1)
        d_b_ff_1 = collect_over_b(np.sum(d_ff_1, axis=1))
        # computed before the in-place Adam updates below change w_ff_1
        d_x_double_dash = d_ff_1 @ self.w_ff_1.T

        self.w_ff_1 = self.w_ff_1_adam.update(d_w_ff_1)
        self.b_ff_1 = self.b_ff_1_adam.update(d_b_ff_1)
        self.w_ff_2 = self.w_ff_2_adam.update(d_w_ff_2)
        self.b_ff_2 = self.b_ff_2_adam.update(d_b_ff_2)

        return d_x_double_dash

In [135]:
class Layer:
    """Pre-norm transformer block.

    ``h = x + attn(rms1(x))`` followed by ``out = h + ff(rms2(h))``.
    """

    def __init__(self, d_model, d_k, d_v, d_ff, num_heads):
        # Construction order is kept as-is: each constructor draws random numbers.
        self.multihead = MultiHead(d_model, d_k, d_v, num_heads)
        self.feedforward = FeedForward(d_model, d_ff)
        self.rms1 = RMSNorm(d_model)
        self.rms2 = RMSNorm(d_model)

    def forward(self, x):
        # attention sub-block with residual connection
        h = x + self.multihead.forward(self.rms1(x))
        # feed-forward sub-block with residual connection
        return h + self.feedforward.forward(self.rms2(h))

    def backward(self, prev_grad):
        # Each residual path passes the upstream gradient straight through; the
        # branch gradient is added on top. Sub-blocks are visited in reverse order.
        d_h = prev_grad + self.rms2.backward(self.feedforward.backward(prev_grad))
        return d_h + self.rms1.backward(self.multihead.backward(d_h))

In [136]:
class Transformer:
    """Decoder-only character transformer trained with manual backprop and Adam."""

    def __init__(self, d_model, d_k, d_v, d_ff, d_vocab, num_heads, num_layers):
        self.layers = [Layer(d_model, d_k, d_v, d_ff, num_heads) for _ in range(num_layers)]
        self.w_emb = initialise_weights(d_vocab, d_model)
        self.w_unemb = initialise_weights(d_model, d_vocab)
        self.rms_out = RMSNorm(d_model)
        self.softmax = Softmax()

        self.w_unemb_adam = Adam(self.w_unemb)
        self.w_emb_adam = Adam(self.w_emb)

    def forward(self, x):
        """Map one-hot tokens ``x`` (batch, seq, d_vocab) to next-token probabilities of the same shape."""
        self.x_one_hot = x
        x_emb = self.x_one_hot @ self.w_emb
        x = add_pos_enc(x_emb)
        for layer in self.layers:
            x = layer.forward(x)
        self.out = self.rms_out(x)
        self.pred = self.out @ self.w_unemb
        self.softmax_pred = self.softmax(self.pred)
        return self.softmax_pred

    def backward(self, outputs, labels):
        """Backpropagate cross-entropy from softmax ``outputs`` and one-hot ``labels``; update all weights."""
        # softmax + cross-entropy: the gradient w.r.t. the logits is (probs - one_hot)
        d_pred = outputs - labels
        d_out = d_pred @ self.w_unemb.T
        d_w_unemb = collect_over_b(self.out.transpose(0, 2, 1) @ d_pred)

        prev_grad = self.rms_out.backward(d_out)
        for layer in reversed(self.layers):
            prev_grad = layer.backward(prev_grad)
        # positional encoding is additive, so the embedding output gradient is prev_grad itself
        d_w_emb = collect_over_b(self.x_one_hot.transpose(0, 2, 1) @ prev_grad)

        self.w_unemb = self.w_unemb_adam.update(d_w_unemb)
        self.w_emb = self.w_emb_adam.update(d_w_emb)

    def generate(self, prompt, det=False, max_len=100):
        """Extend ``prompt`` by ``max_len`` characters and return the full decoded text.

        Args:
            prompt: non-empty seed string; every character must be in the vocabulary.
            det: if True pick the most likely next token (greedy); otherwise sample
                from the predicted distribution.
            max_len: number of new tokens to generate.

        Raises:
            ValueError: if ``prompt`` is empty (there is no context to condition on).
        """
        tokens = list(encode(prompt))  # plain Python ints, usable by both the model and decode
        if not tokens:
            raise ValueError("prompt must contain at least one character")

        for _ in range(max_len):
            # Condition on at most the last `context_length` tokens (the training window size).
            window = np.array(tokens[-context_length:], dtype=np.int32)
            probs = self.forward(one_hot(window)[np.newaxis, ...])[-1, -1, :]

            if det:
                next_id = np.argmax(probs)
            else:
                next_id = np.random.choice(len(probs), size=1, p=probs)[0]

            tokens.append(int(next_id))

        return decode(tokens)

In [137]:
# Model hyper-parameters.
d_model = 64      # embedding / residual-stream width
d_k = 16          # query/key width per head
d_v = 16          # value width per head
d_ff = 128        # hidden width of the feed-forward block
num_heads = 4
num_layers = 6
d_vocab = len(tokens_enc)  # one id per distinct character in the corpus

model = Transformer(
    d_model=d_model,
    d_k=d_k,
    d_v=d_v,
    d_ff=d_ff,
    d_vocab=d_vocab,
    num_heads=num_heads,
    num_layers=num_layers,
)

In [None]:
LOG_EVERY = 10     # print the running mean loss every this many batches
LAST_STEP = 10_000  # index of the last batch processed (inclusive)
PROB_FLOOR = 1e-12  # keeps log() finite if a predicted probability underflows to 0

running_loss = 0
steps_since_log = 0
for step, (input_ids, target_ids) in enumerate(train_loader):
    inputs = one_hot(input_ids)
    targets = one_hot(target_ids)

    outputs = model.forward(inputs)

    # probability the model assigned to the correct next token at every position
    correct_probs = np.sum(outputs * targets, axis=-1)
    running_loss += -np.mean(np.log(np.maximum(correct_probs, PROB_FLOOR)))
    steps_since_log += 1

    if step % LOG_EVERY == 0:
        # average over the batches actually accumulated (only 1 at step 0)
        print(f'batch {step}: ', running_loss / steps_since_log)
        running_loss = 0
        steps_since_log = 0

    model.backward(outputs, targets)

    if step == LAST_STEP:
        break

In [None]:
# Sample 100 characters continuing the prompt; print() keeps the newlines readable.
sample = model.generate('ROMEO: ')
print(sample)