Let's try and learn a tokenizer with RL : )

In [27]:
from transformers import AutoTokenizer
from torch import nn
import torch.nn.functional as F
from typing import List

from copy import deepcopy
import torch

from scipy.signal import lfilter
import numpy as np

In [97]:
# Helper functions to count the number of parameters in a torch.nn.Module
def count_parameters(module):
    return sum(p.numel() for p in module.parameters())


def parameter_count_string(module):
    n_params = count_parameters(module)
    if n_params > 10**6:
        return f"{n_params/10**6:.1f}M"
    elif n_params > 10**3:
        return f"{n_params/10**3:.1f}k"
    else:
        return f"{n_params}" 

In [28]:
byte5_tokenizer = AutoTokenizer.from_pretrained("google/byt5-large")

In [29]:
import os

username = "sdauncey"
scratch_dir = f"/scratch/{username}/tokenizer_training"

if not os.path.exists(scratch_dir):
    os.makedirs(scratch_dir)

scratch_dir

'/scratch/sdauncey/tokenizer_training'

In [33]:
from datasets import load_dataset

ds = load_dataset("fka/awesome-chatgpt-prompts")
train_prompts = [s['prompt'] for s in ds['train']]
len(train_prompts)
train_prompts[0]


'Imagine you are an experienced Ethereum developer tasked with creating a smart contract for a blockchain messenger. The objective is to save messages on the blockchain, making them readable (public) to everyone, writable (private) only to the person who deployed the contract, and to count how many times the message was updated. Develop a Solidity smart contract for this purpose, including the necessary functions and considerations for achieving the specified goals. Please provide the code and any relevant explanations to ensure a clear understanding of the implementation.'

In [None]:
# import datasets

# # Download a portion of OpenWebText dataset
# # This will download a subset of the OpenWebText corpus
# print("Downloading OpenWebText dataset...")

# # Load a small portion of OpenWebText (1% of the dataset)
# openwebtext_8k = datasets.load_dataset(
#     "openwebtext",
#     split="train[:1%]",  # Using only 1% samples of the dataset for now.
#     cache_dir=os.path.join(scratch_dir, "openwebtext_8k_cache"),
#     trust_remote_code=True
# )

# print(f"Downloaded {len(openwebtext_8k)} examples from OpenWebText")

# # Display a sample
# print("\nSample text from OpenWebText:")
# print(openwebtext_8k[0]['text'][:500] + "...")

In [52]:


class BitterLLM(nn.Module):
    # Maps bytes to bytes, slowly merging and then unmerging tokens layer-by-layer
    # with a context window inversely proportional to the number of tokens in the sequence.#
    def __init__(self, vocab_size: int, embedding_dim: int, num_layers: int, num_heads: int, dropout: float, downsampling_rates: List[float]):
        
        # Initialize a standard transformer decoder architecture.
        self.embedding = nn.Embedding(vocab_size, embedding_dim)

        self.down_layers = nn.ModuleList([
            nn.TransformerDecoderLayer(embedding_dim, num_heads, dropout=dropout)
            for _ in range(num_layers)
        ])

        self.mid_layer = nn.TransformerDecoderLayer(embedding_dim, num_heads, dropout=dropout)

        self.up_layers = nn.ModuleList([
            nn.TransformerDecoderLayer(embedding_dim, num_heads, dropout=dropout)
            for _ in range(num_layers)
        ])

        self.output_layer = nn.Linear(embedding_dim, vocab_size)

        # Initialize a gate for each layer.
        layer_gate_init = nn.Linear(embedding_dim, 1)

        # Copy the gate for each layer. 
        # Initializing by copying inductively biases the model to tokenize in a later layer if the gate is high but the model chose not to.
        self.layer_gates = nn.ModuleList([
            deepcopy(layer_gate_init) for _ in range(num_layers*2)
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Input is a tensor of shape (batch_size, sequence_length), 
        gets transformed into a tensor of shape (batch_size, n_tokens, embedding_dim).
        """
        x = self.embedding(x)

        hidden_states = [x]
        merge_destinations = []

        for layer, gate in zip(self.down_layers, self.layer_gates):
            # TODO: make the context window inversely proportional to the number of tokens in the sequence.
            x = layer(x)

            # Gate the layer.
            gate_logits = gate(x)
            gate_probs = F.sigmoid(gate_logits)

            if self.training:
                # Re-scale the gate probabilities to control the downsampling rate.
                ...

            gate_samples = torch.bernoulli(gate_probs)

            # Merge the tokens where the gate is 1.
            # Create a merge destination tensor
            batch_size, seq_len, _ = x.shape
            merge_dst, n_dst = get_merge_dst(gate_samples)
            
            y = torch.zeros(batch_size, n_dst, self.embedding_dim, dtype=torch.float32)
            x = torch.scatter_reduce(y, dim=1, index=merge_dst, src=x, reduce="mean", include_self=False)

            hidden_states.append((x, gate_samples, gate_probs))

        y = x
        for up_layer, down_hidden_state, merge_dst in zip(self.up_layers, reversed(hidden_states), reversed(merge_destinations)):
            pass

        # TODO: finish implementation.

        return self.output_layer(x), hidden_states
        

In [53]:
test_batch = byte5_tokenizer(train_prompts[:5], return_tensors="pt", padding=True)["input_ids"]
test_batch.shape

torch.Size([5, 797])

In [42]:
test_batch.min(), test_batch.max()

(tensor(0), tensor(229))

In [43]:
byte5_tokenizer.get_vocab()

{'<pad>': 0,
 '</s>': 1,
 '<unk>': 2,
 '\x00': 3,
 '\x01': 4,
 '\x02': 5,
 '\x03': 6,
 '\x04': 7,
 '\x05': 8,
 '\x06': 9,
 '\x07': 10,
 '\x08': 11,
 '\t': 12,
 '\n': 13,
 '\x0b': 14,
 '\x0c': 15,
 '\r': 16,
 '\x0e': 17,
 '\x0f': 18,
 '\x10': 19,
 '\x11': 20,
 '\x12': 21,
 '\x13': 22,
 '\x14': 23,
 '\x15': 24,
 '\x16': 25,
 '\x17': 26,
 '\x18': 27,
 '\x19': 28,
 '\x1a': 29,
 '\x1b': 30,
 '\x1c': 31,
 '\x1d': 32,
 '\x1e': 33,
 '\x1f': 34,
 ' ': 35,
 '!': 36,
 '"': 37,
 '#': 38,
 '$': 39,
 '%': 40,
 '&': 41,
 "'": 42,
 '(': 43,
 ')': 44,
 '*': 45,
 '+': 46,
 ',': 47,
 '-': 48,
 '.': 49,
 '/': 50,
 '0': 51,
 '1': 52,
 '2': 53,
 '3': 54,
 '4': 55,
 '5': 56,
 '6': 57,
 '7': 58,
 '8': 59,
 '9': 60,
 ':': 61,
 ';': 62,
 '<': 63,
 '=': 64,
 '>': 65,
 '?': 66,
 '@': 67,
 'A': 68,
 'B': 69,
 'C': 70,
 'D': 71,
 'E': 72,
 'F': 73,
 'G': 74,
 'H': 75,
 'I': 76,
 'J': 77,
 'K': 78,
 'L': 79,
 'M': 80,
 'N': 81,
 'O': 82,
 'P': 83,
 'Q': 84,
 'R': 85,
 'S': 86,
 'T': 87,
 'U': 88,
 'V': 89,
 'W': 90,

In [162]:
def get_merge_dst(gate_samples: torch.Tensor) -> torch.Tensor:
    """
    Returns (merge_dst, dst_idx) the merge destination for each token in the sequence and the number of unique merge destinations.
    For now, has a janky python for-loop implementation.
    Input is a tensor of shape (batch_size, sequence_length) with 0 tokens are merged into the next 1 token.
    """
    batch_size, seq_len = gate_samples.shape
    merge_dst = torch.zeros_like(gate_samples, dtype=torch.long)
    n_dst = torch.zeros(batch_size, dtype=torch.long)

    # Process each batch separately
    for b in range(batch_size):
        dst_idx = 0
        for i in range(seq_len):
            merge_dst[b, i] = dst_idx
            if gate_samples[b, i] == 1 and i < seq_len - 1:
                # If previous position had gate=1, keep the same destination
                dst_idx += 1

        n_dst[b] = dst_idx + 1

    return merge_dst, n_dst


class MiniBitterLLM(nn.Module):
    # A mini BitterLLM with 2 down, 4 mid, and 2 up layers. As a vibe check on the idea.
    def __init__(self, vocab_size: int, embedding_dim: int, num_heads: int, dropout: float=0.01, downsample_rate: float = 0.25):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.embedding = nn.Embedding(vocab_size, embedding_dim)

        self.down_layers = nn.ModuleList([
            nn.TransformerEncoderLayer(embedding_dim, num_heads, dropout=dropout, batch_first=True) for _ in range(2)
        ])

        self.mid_layers = nn.ModuleList([
            nn.TransformerEncoderLayer(embedding_dim, num_heads, dropout=dropout, batch_first=True) for _ in range(4)
        ])

        self.up_layers = nn.ModuleList([
            nn.TransformerEncoderLayer(embedding_dim, num_heads, dropout=dropout, batch_first=True) for _ in range(2)
        ])

        self.output_layer = nn.Linear(embedding_dim, vocab_size)
        # Initialize a gate for each layer.
        layer_gate_init = nn.Linear(embedding_dim, 1)

        # Copy the gate for each layer. 
        # Initializing by copying inductively biases the model to tokenize in a later layer if the gate is high but the model chose not to.
        self.down_layer_gate = deepcopy(layer_gate_init)
        self.up_layer_gate = deepcopy(layer_gate_init)
        self.downsample_rate = downsample_rate

    def apply_local_layers(self, layers, x: torch.Tensor, context_window_length) -> torch.Tensor:
        """Again a janky python for-loop implementation that re-constructs the causal mask for each layer."""
        _, seq_len, _ = x.shape

        # Create causal mask for context length of 64
        mask = torch.ones(seq_len, seq_len) * float('-inf')

        for i in range(seq_len):
            # Allow attention to self and previous window_size tokens
            start_idx = max(0, i - context_window_length + 1)
            mask[i, start_idx:i+1] = 0.0
        
        # Process through down layers with the specified context length
        for layer in layers:
            x = layer(x, src_mask=mask, is_causal=True)

        return x


    def forward(self, x: torch.Tensor) -> torch.Tensor:

        batch_size, _ = x.shape

        x = self.embedding(x)

        # Apply down layers  byte tokens        
        x = self.apply_local_layers(self.down_layers, x, 64)

        # Sample gating binary variables for each token.
        gate_logits = self.down_layer_gate(x)
        gate_probs = F.sigmoid(gate_logits)

        if self.training:
            # Re-scale the gate probabilities to control the downsampling rate.
            true_gate_probs = gate_probs * (self.downsample_rate / gate_probs.mean())

        gate_samples = torch.bernoulli(true_gate_probs)

        # Merge the tokens into the next token where the gate is 1.
        gate_samples = gate_samples.squeeze(-1)
        merge_dst, n_dst = get_merge_dst(gate_samples)
        merge_dst = merge_dst.unsqueeze(-1).expand(-1, -1, self.embedding_dim)

        x_downsampled = torch.zeros(batch_size, n_dst.max(), self.embedding_dim, dtype=torch.float32)
        x_downsampled = torch.scatter_reduce(x_downsampled, dim=1, index=merge_dst, src=x, reduce="mean", include_self=False)

        # Apply mid layers to merged tokens and compute the deviation
        y_downsampled = self.apply_local_layers(self.mid_layers, x_downsampled, 64*4)
        deviation = y_downsampled - x_downsampled

        # Add the upsampled deviation to the input to the middle layers
        upsampled_deviation = torch.gather(deviation, dim=1, index=merge_dst)
        y = x + upsampled_deviation

        # Apply up layers to byte tokens
        y = self.apply_local_layers(self.up_layers, y, 64)

        # Map residual stream to logits
        logits = self.output_layer(y)
        logits = F.log_softmax(logits, dim=-1)

        out = {
            "logits": logits,
            "gate_probs": true_gate_probs.squeeze(-1),
            "gate_samples": gate_samples.to(dtype=torch.long),
            "merge_dst": merge_dst[:, :, 0], # This dimension is repeated.
            "n_dst": n_dst,
        }

        return out


my_model = MiniBitterLLM(vocab_size=byte5_tokenizer.vocab_size, embedding_dim=128, num_heads=4, downsample_rate=0.25)
print(f"my_model has {parameter_count_string(my_model)} parameters")

my_model has 4.8M parameters


In [164]:
out = my_model(test_batch)

for k, v in out.items():
    print(f"{k=} {v.shape=}, {v.dtype=}")


k='logits' v.shape=torch.Size([5, 797, 256]), v.dtype=torch.float32
k='gate_probs' v.shape=torch.Size([5, 797]), v.dtype=torch.float32
k='gate_samples' v.shape=torch.Size([5, 797]), v.dtype=torch.int64
k='merge_dst' v.shape=torch.Size([5, 797]), v.dtype=torch.int64
k='n_dst' v.shape=torch.Size([5]), v.dtype=torch.int64


In [105]:
out = my_model(test_batch)

In [165]:
def display_gating(tokens_ids, merge_dst):
    """Display how a SmallBitterLLM merges a sequence. token_ids and merge_dst are tensors of shape (sequence_length,)."""
    previous_merge_dst = 0
    for t_id, merge_destinantion in zip(tokens_ids, merge_dst):
        merge_destinantion = merge_destinantion.item()
        
        if merge_destinantion != previous_merge_dst:
            print(f"|", end="")
            previous_merge_dst = merge_destinantion
        
        t_txt = byte5_tokenizer.decode(t_id)
        print(f"{t_txt}", end="")

display_gating(test_batch[0], out["merge_dst"][0])

Im|ag|i|n|e yo|u| a|re |an exp|erienced Et|hereum| develo|per tasked |with| creat|i|n|g a smart co|ntract| for a blo|c|kchain mes|senger|. |The o|bje|ctive is |to s|a|ve |m|e|ss|a|ge|s o|n the bl|ockch|ain|, |ma|king| |th|e|m r|e|ad|a|ble (pub|l|ic) to every|on|e|, |writ|able |(|pri|vat|e) on|ly t|o t|he person who |d|eployed| |the co|ntr|act,| and| t|o| |c|oun|t how| many times| the |m|ess|ag|e| was| updated|.| Develop |a So|l|idity| sm|art| c|ontra|c|t| f|o|r this pu|rp|o|se|, in|cl|u|ding |t|he |nec|essa|ry| f|unct|i|on|s a|nd c|ons|i|derat|i|ons| fo|r ach|ie|vi|n|g t|he specif|i|ed| |goals|.| |Ple|as|e |p|rov|i|de t|h|e code| |a|nd any| relev|ant| |expl|ana|tions| t|o e|n|s|u|re a |c|lear |u|nde|rstanding |of t|h|e i|mp|lement|ati|on|.</s><pad><pad><pad><pad>|<pad><pad><pad><pad><pad><pad><pad><pad>|<pad>|<pad><pad><pad><pad><pad><pad><pad>|<pad><pad><pad>|<pad><pad><pad>|<pad><pad><pad><pad><pad><pad>|<pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>

In [166]:
optimizer = torch.optim.Adam(my_model.parameters(), lr=0.001)


In [176]:
def compute_discounted_rewards(rewards, discount):
    """
    Assumes that rewards is a numpy array of shape (n_episodes, n_timesteps). Returns tensor of same shape.
    credit to: https://stackoverflow.com/questions/47970683/vectorize-a-numpy-discount-calculation/47971187#47971187,
    minor modifications made to vectorise computation.
    C[i] = R[i] + discount * C[i+1]
    signal.lfilter(b, a, x, axis=-1, zi=None)
    a[0]*y[n] = b[0]*x[n] + b[1]*x[n-1] + ... + b[M]*x[n-M]
                          - a[1]*y[n-1] - ... - a[N]*y[n-N]
    """
    r = rewards[:, ::-1]
    a = [1, -discount]
    b = [1]
    y = lfilter(b, a, x=r)
    return y[:, ::-1]


def discounted_rewards_torch(rewards, discount):
    """torch wrapper for compute_discounted_rewards. Warning: does _not_ allow for backprop through the rewards, which is fine for policy gradients."""
    rewards = rewards.detach().numpy()
    discounted_rewards = compute_discounted_rewards(rewards, discount)
    discounted_rewards = torch.tensor(discounted_rewards.copy()) # Copy as torch doesn't like converting negatively strided arrays
    return discounted_rewards


def bitter_tokenizer_training_step(model, batch):
    """
    Assume that batch is torch.tensor of token ids of shape (batch, sequence_length). returns a dict of floats of the training losses for the batch.
    """
    batch_size, _ = batch.shape

    optimizer.zero_grad()

    out = model(batch)
    logits = out["logits"]
    gate_samples = out["gate_samples"]
    gate_probs = out["gate_probs"]


    # Compute autoregressive loss: log probability of next token.
    next_token_ids = batch[:, 1:]
    current_token_logits = logits[:, :-1]
    next_token_logits = F.cross_entropy(current_token_logits.transpose(1, 2), next_token_ids, reduction="none") # Transpose as F.cross_entropy wants shape [batch, classes, ...]
    ar_loss = next_token_logits.mean()

    # Compute gating loss: discounted log probabilities of following token(s).
    discount_rate = 0.95
    next_token_logits_padded = torch.cat([next_token_logits, torch.zeros(batch_size, 1)], dim=-1) # Pad the last reward as zero
    discounted_rewards = discounted_rewards_torch(next_token_logits_padded, discount_rate)
    discounted_rewards = (discounted_rewards - discounted_rewards.mean())

    # action 0 = continue, action 1 = gate
    action_log_probs = torch.stack([(1 - gate_probs).log() , gate_probs.log()], dim=1)
    selected_action_log_probs = F.cross_entropy(action_log_probs, gate_samples, reduction="none")
    gating_loss = (discounted_rewards * selected_action_log_probs).mean()

    # Optimizer step
    total_loss = ar_loss + gating_loss

    total_loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()

    out = {
        "ar_loss": ar_loss.item(),
        "gating_loss": gating_loss.item(),
        "total_loss": total_loss.item()
    }

    return out

bitter_tokenizer_training_step(my_model, test_batch)

{'ar_loss': 2.5473217964172363,
 'gating_loss': 0.30918932459530757,
 'total_loss': 2.856511121012544}