# Imports

In [1]:
pip install einops fancy_einsum

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting einops
  Downloading einops-0.6.0-py3-none-any.whl (41 kB)
[K     |████████████████████████████████| 41 kB 509 kB/s 
[?25hCollecting fancy_einsum
  Downloading fancy_einsum-0.0.3-py3-none-any.whl (6.2 kB)
Installing collected packages: fancy-einsum, einops
Successfully installed einops-0.6.0 fancy-einsum-0.0.3


In [2]:
# %%
from ast import Mult
import matplotlib.pyplot as plt
from regex import D
import seaborn as sns
import torch as t
from torch import nn, norm, optim
from torch import Tensor
from torch.nn import functional
from einops import reduce, repeat, rearrange
import plotly.express as px
import numpy as np
from fancy_einsum import einsum
from math import sqrt
from dataclasses import dataclass
from torch.utils.data import Dataset, DataLoader
from typing import Callable, Union, Optional
from tqdm.notebook import tqdm_notebook
from torch.distributions.categorical import Categorical
from collections import OrderedDict
import re

In [3]:
@dataclass(frozen=True)
class TransformerConfig:
    '''Constants used throughout your decoder-only transformer model.

    Instances are immutable (the dataclass is declared with frozen=True).
    '''

    # Number of DecoderBlock modules stacked inside DecoderOnlyTransformer.
    num_layers: int
    # Number of attention heads in each block's MultiheadMaskedAttention.
    num_heads: int
    # Number of rows in the token embedding table (one per token id).
    vocab_size: int
    # Width of the token embeddings and of every block's input and output.
    hidden_size: int
    # Number of positions precomputed by PositionalEncoding.
    max_seq_len: int
    # Dropout probability applied after the embeddings and at the end of each MLP.
    dropout: float = 0.1
    # eps passed to the LayerNorm layers inside each DecoderBlock.
    layer_norm_epsilon: float = 1e-05


# %%


def multihead_masked_attention(Q: t.Tensor, K: t.Tensor, V: t.Tensor, num_heads: int):
    '''
    Implements multihead masked (causal) attention on the matrices Q, K and V.

    Each query position may only attend to key positions at or before itself.

    Q: shape (batch, seq, nheads*headsize)
    K: shape (batch, seq, nheads*headsize)
    V: shape (batch, seq, nheads*headsize)

    Return: shape (batch, seq, nheads*headsize)
    '''

    Q = rearrange(Q, 'batch seq (nheads h_size) -> batch seq nheads h_size', nheads=num_heads)
    K = rearrange(K, 'batch seq (nheads h_size) -> batch seq nheads h_size', nheads=num_heads)
    V = rearrange(V, 'batch seq (nheads h_size) -> batch seq nheads h_size', nheads=num_heads)

    seq_len = Q.shape[1]
    head_size = Q.shape[-1]

    attention_scores = einsum('b seq_q nheads h_size, b seq_k nheads h_size -> b nheads seq_q seq_k', Q, K)
    attention_scores = attention_scores / sqrt(head_size)

    # True strictly above the diagonal, i.e. wherever key position > query position.
    # Built on Q's own device so the function does not depend on notebook-level globals.
    future_positions = t.triu(t.ones(seq_len, seq_len, device=Q.device), diagonal=1).bool()
    # -inf scores receive exactly zero probability after the softmax.
    attention_scores = attention_scores.masked_fill(future_positions, -t.inf)
    attention_probabilities = functional.softmax(attention_scores, dim=-1)

    values = einsum('b seq_k nheads h_size, b nheads seq_q seq_k -> b nheads seq_q h_size', V, attention_probabilities)
    return rearrange(values, 'b nheads seq_q h_size -> b seq_q (nheads h_size)')

# If the above is wrong, the most likely culprits are the following:
# the ordering of nheads and h_size in the parentheses in the rearrange on the last line
# or of nheads and h_size in the parentheses in the rearranges in the first three lines

# %%

class MultiheadMaskedAttention(nn.Module):
    '''
    Causal multi-head self-attention layer.

    A single fused linear map produces Q, K and V, which are passed to
    multihead_masked_attention; the result goes through an output projection.
    Parameters are created with PyTorch's default dtype/device; move the
    enclosing model with `.to(...)` to place them elsewhere.
    '''
    W_QKV: nn.Linear
    W_O: nn.Linear

    def __init__(self, hidden_size: int, num_heads: int):
        '''
        hidden_size: width of the input and output embeddings
        num_heads: number of attention heads; must divide hidden_size

        Raises ValueError if hidden_size is not a multiple of num_heads.
        '''

        super().__init__()

        if hidden_size % num_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})"
            )

        self.num_heads = num_heads

        # Q, K and V are computed together, hence the factor of 3 on the output width.
        self.W_QKV = nn.Linear(in_features=hidden_size, out_features=3*hidden_size)
        self.W_O = nn.Linear(in_features=hidden_size, out_features=hidden_size)


    def forward(self, x: t.Tensor) -> t.Tensor:
        '''
        x: shape (batch, seq, hidden_size)

        Return: shape (batch, seq, hidden_size)
        '''

        # Split the fused projection into three equal slices along the last axis:
        # the first hidden_size features are Q, the next K, the last V.
        Q, K, V = self.W_QKV(x).chunk(3, dim=-1)

        attention_values = multihead_masked_attention(Q, K, V, self.num_heads)

        return self.W_O(attention_values)

# If the above is wrong, the most likely culprits are the following: 
# Maybe the second line of __init__ shouldn't have out_features = 3 * hidden_size

# %%

class PositionalEncoding(nn.Module):
    '''
    Adds fixed sinusoidal position encodings to a batch of embeddings.

    The table is precomputed once and stored as the buffer `PE_matrix` of shape
    (max_seq_len, embedding_dim). Even columns hold sin(pos * w_i) and odd
    columns cos(pos * w_i), with w_i = 1 / 10000 ** (2i / embedding_dim).
    '''

    def __init__(self, max_seq_len: int, embedding_dim: int):
        '''
        max_seq_len: number of positions to precompute
        embedding_dim: width of the embeddings; must be even

        Raises ValueError if embedding_dim is odd (sin/cos columns come in pairs).
        '''

        super().__init__()

        if embedding_dim % 2 != 0:
            raise ValueError(f"embedding_dim must be even, got {embedding_dim}")

        positions = t.arange(max_seq_len, dtype=t.float)
        inverse_frequencies = 1 / 1e4 ** (t.arange(0, embedding_dim, step=2, dtype=t.float) / embedding_dim)

        # angles[pos, i] = pos * w_i, shape (max_seq_len, embedding_dim / 2)
        angles = t.outer(positions, inverse_frequencies)
        # Stack sin/cos on a new trailing axis and flatten it so they interleave:
        # column 2i is the sine and column 2i+1 the cosine for frequency i.
        table = t.stack([t.sin(angles), t.cos(angles)], dim=-1).flatten(start_dim=1)

        # A buffer (not a parameter): moved by .to(...) and saved in the state dict, but not trained.
        self.register_buffer('PE_matrix', table)


    def forward(self, x: Tensor) -> Tensor:
        '''
        x: shape (batch, seq_len, embedding_dim), with seq_len <= max_seq_len

        Return: x plus the encodings of its first seq_len positions, same shape as x
        '''

        seq_len = x.shape[1]
        return x + self.PE_matrix[:seq_len, :] #  type: ignore


# %%

class MLP(nn.Module):
    '''
    Position-wise feed-forward block: Linear -> GELU -> Linear -> Dropout.

    The inner layer is four times as wide as config.hidden_size. Parameters are
    created with PyTorch's default dtype/device, like the embedding and
    LayerNorm layers in this file; the enclosing model's `.to(...)` call places them.
    '''

    def __init__(self, config: "TransformerConfig"):
        '''
        config: supplies hidden_size (input/output width) and dropout (drop probability)
        '''

        super().__init__()
        self.config = config

        self.fc1 = nn.Linear(config.hidden_size, 4*config.hidden_size)
        self.gelu = nn.GELU()
        self.fc2 = nn.Linear(4*config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(p=config.dropout)

    def forward(self, x: t.Tensor) -> t.Tensor:
        '''
        x: shape (..., hidden_size)

        Return: same shape as x
        '''

        x = self.gelu(self.fc1(x))
        x = self.dropout(self.fc2(x))

        return x


class DecoderBlock(nn.Module):
    '''
    One transformer layer: masked self-attention followed by an MLP.

    Each sub-layer's output is layer-normalised and then added back onto its
    own input (residual connection).
    '''

    def __init__(self, config: TransformerConfig):
        '''
        config: supplies hidden_size, num_heads and layer_norm_epsilon, and is
                forwarded unchanged to the MLP
        '''
        super().__init__()
        self.config = config

        width = self.config.hidden_size
        epsilon = config.layer_norm_epsilon

        self.attention = MultiheadMaskedAttention(width, self.config.num_heads)  # type: ignore
        self.ln1 = nn.LayerNorm(normalized_shape=width, eps=epsilon)
        self.mlp = MLP(config)
        self.ln2 = nn.LayerNorm(normalized_shape=width, eps=epsilon)

    def forward(self, x: t.Tensor) -> t.Tensor:
        '''
        x: shape (batch, seq, hidden_size)

        Return: shape (batch, seq, hidden_size)
        '''
        # Attention sub-layer with its residual connection.
        normed_attention = self.ln1(self.attention(x))
        x = normed_attention + x

        # MLP sub-layer with its residual connection.
        normed_mlp = self.ln2(self.mlp(x))
        return normed_mlp + x


class DecoderOnlyTransformer(nn.Module):
    '''
    Decoder-only transformer language model.

    Pipeline: token embedding -> sinusoidal position encoding -> dropout ->
    config.num_layers DecoderBlocks -> final LayerNorm -> logits. The output
    projection reuses the token embedding matrix (tied weights).
    '''

    def __init__(self, config: TransformerConfig):
        '''
        config: model hyperparameters; see TransformerConfig
        '''

        super().__init__()

        self.token_emb = nn.Embedding(num_embeddings=config.vocab_size, embedding_dim=config.hidden_size)
        self.pos_emb = PositionalEncoding(config.max_seq_len, config.hidden_size)
        self.dropout = nn.Dropout(p=config.dropout)
        # An OrderedDict gives every block a stable, readable name inside the Sequential.
        self.blocks = nn.Sequential(OrderedDict(
            [(f'block {i}',DecoderBlock(config)) for i in range(config.num_layers)]
            )) # type: ignore
        # Use the configured epsilon here too, so every LayerNorm in the model agrees with the config.
        self.layer_norm = nn.LayerNorm(normalized_shape=config.hidden_size, eps=config.layer_norm_epsilon)


    def forward(self, x: t.Tensor) -> t.Tensor:
        '''
        x: token ids, shape (batch, seq) or (seq,); a 1-D input is treated as a batch of one

        Return: logits of shape (batch, seq, vocab_size)
        '''

        if len(x.shape)==1:
            x=x.unsqueeze(dim=0)

        # Embedding lookups need integer indices, whatever dtype the ids arrive in.
        embedding = self.token_emb(x.to(dtype=t.long))
        embedding = self.pos_emb(embedding)
        embedding = self.dropout(embedding)
        embedding = self.blocks(embedding)
        embedding = self.layer_norm(embedding)

        # Unembed with the transpose of the token embedding matrix.
        logits = einsum('vocab d, batch seq d -> batch seq vocab', self.token_emb.weight, embedding)

        return logits

# %%

class CustomTextDataset(Dataset):
    '''Map-style dataset that pairs each text item with the label at the same index.'''

    def __init__(self, text, labels):
        '''
        text: indexable collection of inputs
        labels: indexable collection of targets; its length defines the dataset size
        '''
        self.labels, self.text = labels, text

    def __len__(self):
        '''Number of samples, taken from the label collection.'''
        return len(self.labels)

    def __getitem__(self, idx):
        '''Return the (text, label) pair stored at position idx.'''
        target = self.labels[idx]
        return (self.text[idx], target)


def train_transformer(trainloader: DataLoader, testloader: DataLoader, epochs: int, loss_fn: Callable, config: TransformerConfig, model_filename: Optional[str] = None) -> list:
    '''
    Builds a DecoderOnlyTransformer from `config`, trains it with Adam and evaluates it after every epoch.

    trainloader / testloader: yield (x, y) batches, each of shape (batch, seq_len)
    epochs: number of passes over trainloader
    loss_fn: called as loss_fn(logits, targets) with logits of shape (batch*seq_len, vocab) and targets of shape (batch*seq_len,)
    model_filename: where to save the trained model. If omitted, the notebook-level
        constant MODEL_FILENAME is used when it exists; otherwise nothing is saved.

    Returns [loss_list, accuracy_list]: loss_list holds the training loss of every batch,
    accuracy_list holds the fraction of correctly predicted test tokens at the end of each epoch.

    NOTE: a later cell of this notebook defines another function with the same name;
    once that cell has run, the name `train_transformer` refers to that one instead.
    '''

    model = DecoderOnlyTransformer(config).to(device).train()
    optimizer = t.optim.Adam(model.parameters())
    loss_list = []
    accuracy_list = []

    for epoch in tqdm_notebook(range(epochs)):

        loss = None

        for (x, y) in tqdm_notebook(trainloader, leave=False):

            x = x.to(device)
            y = y.to(device)
            y = rearrange(y, 'batch seq_len -> (batch seq_len)')

            y_hat = model(x)
            y_hat = rearrange(y_hat, 'batch seq_len logit -> (batch seq_len) logit')

            loss = loss_fn(y_hat, y)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            loss_list.append(loss.item())

        # Evaluate with dropout disabled and without building autograd graphs.
        model.eval()
        correct = 0
        total = 0
        with t.inference_mode():
            for (x, y) in tqdm_notebook(testloader, leave=False):

                x = x.to(device)
                y = y.to(device)
                y = rearrange(y, 'batch seq_len -> (batch seq_len)')

                y_hat = model(x)
                y_hat = rearrange(y_hat, 'batch seq_len logit -> (batch seq_len) logit')
                preds = t.argmax(y_hat, dim=1)

                correct += (y == preds).sum().item()
                total += y.numel()
        model.train()

        # One accuracy per epoch, weighted by token count so a short final batch is not over-represented.
        accuracy_list.append(correct / total if total else float("nan"))

        # loss stays None when trainloader produced no batches.
        if loss is not None:
            print(f"Epoch {epoch+1}/{epochs}, train loss is {loss.item():.6f}")

    if model_filename is None:
        model_filename = globals().get("MODEL_FILENAME")
    if model_filename is not None:
        print(f"Saving model to: {model_filename}")
        # Pickles the whole module object, not just its state dict.
        t.save(model, model_filename)
    return [loss_list, accuracy_list]



In [5]:

def greedy_search(logits: t.Tensor) -> int:
    """
    Deterministic decoding: pick the token with the largest logit.

    logits: shape (vocab_size, )

    Return: the most likely token (as an integer)
    """
    best_token = t.argmax(logits).item()
    assert isinstance(best_token, int)
    return best_token

def sample_basic(logits: t.Tensor) -> int:
    """
    Draw one token from the categorical distribution defined by the logits.

    logits: shape (vocab_size, ) - unnormalized log-probabilities

    Return: a sampled token
    """
    sampler = t.distributions.categorical.Categorical(logits=logits)
    drawn = sampler.sample()
    token = drawn.item()
    assert isinstance(token, int)
    return token

def apply_temperature(logits: t.Tensor, temperature: float) -> t.Tensor:
    """
    Rescale logits by dividing them by a strictly positive temperature.

    logits: shape (vocab_size, )

    Return: shape (vocab_size, )
    """
    assert temperature > 0
    scaled = logits / temperature
    return scaled

def apply_freq_penalty(input_ids: t.Tensor, logits: t.Tensor, freq_penalty: float) -> t.Tensor:
    """
    Lower each token's logit by freq_penalty times its number of occurrences in input_ids.

    input_ids: shape (seq, )
    logits: shape (vocab_size, )
    Return: shape (vocab_size, )
    """
    # Unpacking a one-element shape also rejects logits that are not 1-D.
    [n_vocab] = logits.shape
    occurrences = t.bincount(input_ids, minlength=n_vocab)
    penalty = freq_penalty * occurrences
    return logits - penalty

def sample_top_k(logits: t.Tensor, top_k: int) -> int:
    """
    Sample among the top_k highest-scoring tokens only.

    logits: shape (vocab_size, ) - unnormalized log-probabilities
    top_k: only consider this many of the most likely tokens for sampling

    Return: a sampled token
    """
    candidate_logits, candidate_ids = t.topk(logits, top_k)
    sampler = t.distributions.categorical.Categorical(logits=candidate_logits)
    # The draw is a position within the candidates; map it back to a vocabulary id.
    position = sampler.sample()
    return candidate_ids[position].item()

def sample_top_p(logits: t.Tensor, top_p: float, min_tokens_to_keep: int = 1) -> int:
    """
    Nucleus sampling: sample from the most likely tokens only, where the number
    of tokens kept is derived from the cumulative probability and top_p.

    logits: shape (vocab_size, ) - unnormalized log-probabilities
    top_p: cumulative-probability threshold that sets how many tokens are kept
    min_tokens_to_keep: lower bound on the number of tokens kept

    Return: a sampled token
    """
    sorted_logits, sorted_ids = t.sort(logits, descending=True, stable=True)
    cumulative = t.cumsum(t.softmax(sorted_logits, dim=-1), dim=-1)

    # Count of sorted tokens whose cumulative probability is <= top_p, plus one more.
    cutoff = t.searchsorted(cumulative, top_p, side="right").item() + 1
    n_keep = max(cutoff, min_tokens_to_keep)

    nucleus_ids = sorted_ids[:n_keep]
    nucleus_logits = logits[nucleus_ids]
    sampler = t.distributions.categorical.Categorical(logits=nucleus_logits)
    # The draw is a position inside the nucleus; map it back to a vocabulary id.
    position = sampler.sample()
    return nucleus_ids[position].item()

# Shakespeare

## Definitions

In [9]:
# Notebook-wide defaults read by the cells around this one:
# 32-bit floats, and the GPU whenever one is available.
dtype = t.float32
if t.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

class WordsDataset(Dataset):
    '''
    Next-word-prediction dataset over a list of words.

    Every distinct word gets an integer id; ids follow the sorted order of the
    vocabulary, so the mapping is identical from run to run. Sample `idx` is the
    window of seq_len + 1 consecutive tokens starting at `idx`, returned as
    (x, y) where y is x shifted one position to the right.
    '''

    def __init__(self, words, seq_len, sample_size):
        '''
        words: sequence of word strings, in corpus order
        seq_len: length of each input sequence x (and of each target y)
        sample_size: fraction of the available windows exposed through __len__
        '''

        self.words = words
        self.seq_len = seq_len
        self.sample_size = sample_size

        # Sorting makes the word -> id mapping reproducible (set order is not).
        vocabulary = sorted(set(self.words))
        self.vocab_size = len(vocabulary)

        # A sample needs seq_len + 1 tokens (x plus the shifted y), so the last valid
        # start index is len(words) - seq_len - 1; clamp at 0 for very short corpora.
        self.max_len = max(len(self.words) - self.seq_len, 0)

        self.word_to_tok = {word: i for (i, word) in enumerate(vocabulary)}
        self.tok_to_word = {i: word for (word, i) in self.word_to_tok.items()}

        # Token ids are integers; they stay on the CPU and the training loop moves each batch to its device.
        self.tokens = t.tensor([self.word_to_tok[word] for word in self.words], dtype=t.long)

    def __len__(self):
        '''Number of samples exposed: sample_size times the number of full windows, rounded down.'''
        return int(self.max_len * self.sample_size)

    def __getitem__(self, idx):
        '''
        Return (x, y), each of shape (seq_len,): x holds tokens idx .. idx+seq_len-1
        and y holds tokens idx+1 .. idx+seq_len.
        '''

        current_seq = self.tokens[idx: idx + self.seq_len + 1]
        x = current_seq[:-1]
        y = current_seq[1:]

        return x, y


# %%

class WordsTokenizer():
    '''
    Word-level tokenizer that reuses the vocabulary built by a WordsDataset.

    Text is split on regex word boundaries, so words, whitespace runs and
    punctuation runs each become one token.
    '''
    model_max_length: int

    def __init__(self, wordsdataset: "WordsDataset"):
        '''
        wordsdataset: supplies word_to_tok, tok_to_word and seq_len (stored as model_max_length)
        '''

        self.word_to_tok = wordsdataset.word_to_tok
        self.tok_to_word = wordsdataset.tok_to_word
        self.model_max_length = wordsdataset.seq_len

    def encode(self, initial_text: str, return_tensors: Optional[str] = None) -> Union[list, t.Tensor]:
        '''
        Tokenizes initial_text, then returns the token ids.

        Return type is list by default, but if return_tensors="pt" then it is returned
        as an integer (long) tensor on the notebook's `device`.

        Raises KeyError, naming the offending words, if any word is not in the vocabulary.
        '''

        # Splitting on \b leaves empty strings at the edges; drop them.
        words = [word for word in re.split(r'\b', initial_text) if word]

        unknown_words = [word for word in words if word not in self.word_to_tok]
        if unknown_words:
            raise KeyError(f"Words not in the tokenizer vocabulary: {unknown_words!r}")

        tokens = [self.word_to_tok[word] for word in words]
        if return_tensors == 'pt':
            # Token ids are indices, so they are stored as integers rather than floats.
            tokens = t.tensor(tokens, dtype=t.long, device=device)

        return tokens


    def decode(self, list_of_ids: Union[t.Tensor, list]) -> str:
        '''
        Converts ids to a list of tokens, then joins them into a single string.

        Ids may be Python ints or tensor elements; each one is passed through int() before the lookup.
        '''

        return ''.join([self.tok_to_word[int(token)] for token in list_of_ids])


# %%


def apply_sampling_methods(
    input_ids: t.Tensor, logits: t.Tensor, temperature=1.0, freq_penalty=0.0, top_k=0, top_p=0.0
) -> int:
    '''
    Return the next token, sampled from the model's probability distribution with modifiers.

    A temperature of 0 selects greedy decoding and skips every other modifier.
    Otherwise temperature scaling and the frequency penalty are applied (in that
    order) before sampling with top-k, top-p, or plain categorical sampling.

    input_ids: shape (seq,)
    logits: shape (vocab_size,)
    '''
    assert input_ids.ndim == 1, "input_ids should be a 1D sequence of token ids"
    assert temperature >= 0, "Temperature should be non-negative"
    assert 0 <= top_p <= 1.0, "Top-p must be a probability"
    assert 0 <= top_k, "Top-k must be non-negative"
    assert not (top_p != 0 and top_k != 0), "At most one of top-p and top-k supported"

    if temperature == 0:
        return greedy_search(logits)

    adjusted = logits
    if temperature != 1.0:
        adjusted = apply_temperature(adjusted, temperature)
    if freq_penalty != 0.0:
        adjusted = apply_freq_penalty(input_ids, adjusted, freq_penalty)

    if top_k > 0:
        token = sample_top_k(adjusted, top_k)
    elif top_p > 0:
        token = sample_top_p(adjusted, top_p)
    else:
        token = sample_basic(adjusted)
    return token


def sample_tokens(
    model: DecoderOnlyTransformer,
    tokenizer: WordsTokenizer,
    initial_text: str,
    max_tokens_generated=30,
    **kwargs # kwargs are for params like temperature, top_k, etc
) -> str:
    '''
    Sample tokens until the model outputs `tokenizer.eos_token_id` (if the tokenizer
    defines one) or the specified token limit is reached.

    model: the language model; it is switched to eval mode and left that way
    tokenizer: encodes the prompt, decodes the result, and supplies the context
        length through `model_max_length`
    initial_text: the prompt
    max_tokens_generated: upper bound on the number of new tokens
    kwargs: forwarded to apply_sampling_methods

    Return: the prompt and continuation concatenated
    '''
    model.eval()
    input_ids: list = tokenizer.encode(initial_text) # type: ignore
    generated = []

    # Read the context length and device from the objects passed in rather than from notebook globals.
    context_length = tokenizer.model_max_length
    model_device = next(model.parameters()).device

    # No gradients are needed for generation, so skip autograd bookkeeping entirely.
    with t.inference_mode():
        for _ in range(max_tokens_generated):
            new_input_ids = t.tensor(input_ids + generated, dtype=t.long, device=model_device)
            # The model only sees the most recent context_length tokens...
            new_input_ids_window = new_input_ids[-min(context_length, new_input_ids.shape[0]):]
            logits = model(new_input_ids_window)[0, -1]
            # ...but the frequency penalty is computed over the full history.
            new_token = apply_sampling_methods(new_input_ids, logits, **kwargs)
            generated.append(new_token)
            if new_token == getattr(tokenizer, "eos_token_id", None):
                break
    return tokenizer.decode(input_ids + generated)


## Training

In [10]:
def train_transformer(model, loss_fn, optimizer, trainloader, epochs, plot_loss=True):
    '''
    Trains `model` in place on `trainloader` and optionally plots the per-batch loss.

    model: maps a batch of token ids (batch, seq) to logits (batch, seq, vocab)
    loss_fn: called as loss_fn(logits, targets) with logits flattened to (batch*seq, vocab)
        and targets flattened to (batch*seq,)
    optimizer: optimizer already bound to the model's parameters
    trainloader: yields (x, y) batches
    epochs: number of passes over trainloader
    plot_loss: if True, show a plotly line chart of the loss of every batch

    Return: the trained model (the same object that was passed in)

    NOTE: this definition replaces the earlier function of the same name in this notebook.
    '''

    # sample_tokens() leaves the model in eval mode; make sure dropout is active
    # again whenever this cell is re-run after generating text.
    model.train()

    loss_list = []

    for epoch in range(epochs):

        progress_bar = tqdm_notebook(trainloader)
        for (x, y) in progress_bar:

            x = x.to(device=device)
            # The loss needs integer class indices as targets.
            y = y.to(dtype=t.int64, device=device)

            y_hat = rearrange(model(x), "b s d -> (b s) d")
            y = t.flatten(y)

            loss = loss_fn(y_hat, y)
            loss_list.append(loss.item())

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            progress_bar.set_description(f"epoch = {epoch+1}, loss = {loss.item():.4f}")

    # Plot the loss of every batch seen during training
    if plot_loss:
        fig = px.line(
            y=loss_list,
            template="simple_white",
            labels={
                "x": "No. batches seen",
                "y": str(loss_fn).replace("()", "") # This gets a name like "CrossEntropyLoss" from the loss function
            },
            title='Training loss'
        )
        # This next bit of code plots vertical lines corresponding to the epochs
        if epochs > 1:
            for idx, epoch_start in enumerate(np.linspace(0, len(loss_list), epochs, endpoint=False)):
                fig.add_vline(x=epoch_start, line_width=3, line_dash="dash", annotation_text=f"Epoch {idx}", annotation_position="top right")
        fig.show()

    return model


In [12]:
max_seq_len = 48
batch_size = 32

# Seed PyTorch's RNG so weight initialisation, shuffling and dropout repeat across runs.
t.manual_seed(0)

# The corpus contains non-ASCII punctuation (e.g. curly apostrophes), so decode it
# explicitly as UTF-8 instead of relying on the platform's default encoding.
with open("100-0.txt", encoding="utf-8") as file:
    text = file.read()

# Split on word boundaries so words, whitespace and punctuation become separate tokens.
# Empty fragments are dropped, matching what WordsTokenizer.encode does.
words = [word for word in re.split(r"\b", text) if word]

# sample_size=0.01 exposes only the first 1% of the windows, to keep the run short.
trainset = WordsDataset(words=words, seq_len=max_seq_len, sample_size=0.01)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
tokenizer = WordsTokenizer(trainset)

config = TransformerConfig(
    num_layers = 8,
    num_heads = 8,
    vocab_size = trainset.vocab_size,
    hidden_size = 512,
    max_seq_len = trainset.seq_len,
    dropout = 0.1,
    layer_norm_epsilon = 1e-05
)

model = DecoderOnlyTransformer(config).to(device).train()
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
epochs = 1

In [13]:
# Train for `epochs` passes over trainloader; train_transformer returns the model it was given.
model = train_transformer(model, loss_fn, optimizer, trainloader, epochs)

  0%|          | 0/622 [00:00<?, ?it/s]

In [14]:
# Generate up to 100 tokens after the prompt, using top-k sampling (k=10) at temperature 1.0.
initial_text = "twas"
text_output = sample_tokens(model, tokenizer, initial_text, max_tokens_generated=100, temperature=1.0, top_k=10)
print(text_output)

twas of and care,
  When I was the must see, and left what all see.


                    53

Mine be have plea thee for love, shallowest is too words thyself thy it self-killed:
That is thy use not is is change is is the tillage of thy of husbandry?
Or survive who


In [15]:
# Same sampling settings as the previous cell, with a different prompt.
initial_text = "prithee"
text_output = sample_tokens(model, tokenizer, initial_text, max_tokens_generated=100, temperature=1.0, top_k=10)
print(text_output)

prithee fair of the your year,
The mayst world some toiled:
  Then bankrupt and is praise is beauteous and thee thy dial’s use it it not their tomb
Of his stain of through Gutenberg moan.
  This all is one,
  Sweet thy love love flattery, then she loves love is as


In [16]:
# Same sampling settings again, with a third prompt.
initial_text = "verily"
text_output = sample_tokens(model, tokenizer, initial_text, max_tokens_generated=100, temperature=1.0, top_k=10)
print(text_output)

verily thoughts would be although die,
The children’s blood fair words which use,
Possessing fair fair the of check,
From his low death, before?
No bosom’s the conquest of KING moan,
Mine eye blood is to through the ere time distilled:
Make though provoke him on him him the
