Let's try and learn a tokenizer with RL : )

In [1]:
from transformers import AutoTokenizer
from torch import nn
import torch.nn.functional as F
from typing import List

from copy import deepcopy
import torch

In [2]:
byte5_tokenizer = AutoTokenizer.from_pretrained("google/byt5-large")

In [3]:
import os

username = "sdauncey"
scratch_dir = f"/scratch/{username}/tokenizer_training"

if not os.path.exists(scratch_dir):
    os.makedirs(scratch_dir)

In [5]:
import datasets

# Download a portion of OpenWebText dataset
# This will download a subset of the OpenWebText corpus
print("Downloading OpenWebText dataset...")

# Load a small portion of OpenWebText (1% of the dataset)
openwebtext = datasets.load_dataset(
    "openwebtext",
    split="train[:1%]",  # Using only 1% of the data to keep it manageable
    cache_dir=os.path.join(scratch_dir, "openwebtext_cache"),
    trust_remote_code=True
)

print(f"Downloaded {len(openwebtext)} examples from OpenWebText")

# Display a sample
print("\nSample text from OpenWebText:")
print(openwebtext[0]['text'][:500] + "...")

Downloading OpenWebText dataset...


Downloading data:   0%|          | 0/21 [00:00<?, ?files/s]

Generating train split:   0%|          | 0/8013769 [00:00<?, ? examples/s]

Downloaded 80138 examples from OpenWebText

Sample text from OpenWebText:
Port-au-Prince, Haiti (CNN) -- Earthquake victims, writhing in pain and grasping at life, watched doctors and nurses walk away from a field hospital Friday night after a Belgian medical team evacuated the area, saying it was concerned about security.

The decision left CNN Chief Medical Correspondent Sanjay Gupta as the only doctor at the hospital to get the patients through the night.

CNN initially reported, based on conversations with some of the doctors, that the United Nations ordered the B...


In [6]:
openwebtext

Dataset({
    features: ['text'],
    num_rows: 80138
})

In [42]:
len(openwebtext[0]["text"]), len(openwebtext[1]["text"]), len(openwebtext[2]["text"])

(5516, 10286, 6021)

In [10]:
byte5_tokenizer.encode(openwebtext[0]["text"])

[83,
 114,
 117,
 119,
 48,
 100,
 120,
 48,
 83,
 117,
 108,
 113,
 102,
 104,
 47,
 35,
 75,
 100,
 108,
 119,
 108,
 35,
 43,
 70,
 81,
 81,
 44,
 35,
 48,
 48,
 35,
 72,
 100,
 117,
 119,
 107,
 116,
 120,
 100,
 110,
 104,
 35,
 121,
 108,
 102,
 119,
 108,
 112,
 118,
 47,
 35,
 122,
 117,
 108,
 119,
 107,
 108,
 113,
 106,
 35,
 108,
 113,
 35,
 115,
 100,
 108,
 113,
 35,
 100,
 113,
 103,
 35,
 106,
 117,
 100,
 118,
 115,
 108,
 113,
 106,
 35,
 100,
 119,
 35,
 111,
 108,
 105,
 104,
 47,
 35,
 122,
 100,
 119,
 102,
 107,
 104,
 103,
 35,
 103,
 114,
 102,
 119,
 114,
 117,
 118,
 35,
 100,
 113,
 103,
 35,
 113,
 120,
 117,
 118,
 104,
 118,
 35,
 122,
 100,
 111,
 110,
 35,
 100,
 122,
 100,
 124,
 35,
 105,
 117,
 114,
 112,
 35,
 100,
 35,
 105,
 108,
 104,
 111,
 103,
 35,
 107,
 114,
 118,
 115,
 108,
 119,
 100,
 111,
 35,
 73,
 117,
 108,
 103,
 100,
 124,
 35,
 113,
 108,
 106,
 107,
 119,
 35,
 100,
 105,
 119,
 104,
 117,
 35,
 100,
 35,
 69,
 104,
 111,
 106,
 

In [34]:
def get_merge_dst(gate_samples: torch.Tensor) -> torch.Tensor:
    """
    Returns (merge_dst, dst_idx) the merge destination for each token in the sequence and the number of unique merge destinations.
    For now, has a janky python for-loop implementation.
    Input is a tensor of shape (batch_size, sequence_length) with 0 tokens are merged into the next 1 token.
    """
    batch_size, seq_len = gate_samples.shape
    merge_dst = torch.zeros_like(gate_samples, dtype=torch.long)
            
    # Process each batch separately
    for b in range(batch_size):
        dst_idx = 0
        for i in range(seq_len):
            merge_dst[b, i] = dst_idx
            if gate_samples[b, i] == 1 and i < seq_len - 1:
                # If previous position had gate=1, keep the same destination
                dst_idx += 1

    n_dst = dst_idx + 1
    return merge_dst, n_dst

class BitterLLM(nn.Module):
    # Maps bytes to bytes, slowly merging and then unmerging tokens layer-by-layer
    # with a context window inversely proportional to the number of tokens in the sequence.#
    def __init__(self, vocab_size: int, embedding_dim: int, num_layers: int, num_heads: int, dropout: float, downsampling_rates: List[float]):
        
        # Initialize a standard transformer decoder architecture.
        self.embedding = nn.Embedding(vocab_size, embedding_dim)

        self.down_layers = nn.ModuleList([
            nn.TransformerDecoderLayer(embedding_dim, num_heads, dropout=dropout)
            for _ in range(num_layers)
        ])

        self.mid_layer = nn.TransformerDecoderLayer(embedding_dim, num_heads, dropout=dropout)

        self.up_layers = nn.ModuleList([
            nn.TransformerDecoderLayer(embedding_dim, num_heads, dropout=dropout)
            for _ in range(num_layers)
        ])

        self.output_layer = nn.Linear(embedding_dim, vocab_size)

        # Initialize a gate for each layer.
        layer_gate_init = nn.Linear(embedding_dim, 1)

        # Copy the gate for each layer. 
        # Initializing by copying inductively biases the model to tokenize in a later layer if the gate is high but the model chose not to.
        self.layer_gates = nn.ModuleList([
            deepcopy(layer_gate_init) for _ in range(num_layers*2)
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Input is a tensor of shape (batch_size, sequence_length), 
        gets transformed into a tensor of shape (batch_size, n_tokens, embedding_dim).
        """
        x = self.embedding(x)

        hidden_states = [x]
        merge_destinations = []

        for layer, gate in zip(self.down_layers, self.layer_gates):
            # TODO: make the context window inversely proportional to the number of tokens in the sequence.
            x = layer(x)

            # Gate the layer.
            gate_logits = gate(x)
            gate_probs = F.sigmoid(gate_logits)

            if self.training:
                # Re-scale the gate probabilities to control the downsampling rate.
                ...

            gate_samples = torch.bernoulli(gate_probs)

            # Merge the tokens where the gate is 1.
            # Create a merge destination tensor
            batch_size, seq_len, _ = x.shape
            merge_dst, n_dst = get_merge_dst(gate_samples)
            
            y = torch.zeros(batch_size, n_dst, self.embedding_dim, dtype=torch.float32)
            x = torch.scatter_reduce(y, dim=1, index=merge_dst, src=x, reduce="mean", include_self=False)

            hidden_states.append((x, gate_samples, gate_probs))

        y = x
        for up_layer, down_hidden_state, merge_dst in zip(self.up_layers, reversed(hidden_states), reversed(merge_destinations)):
            pass

        return self.output_layer(x), hidden_states
        

In [70]:
class MiniBitterLLM(nn.Module):
    # A mini BitterLLM with 2 down, 4 mid, and 2 up layers. As a vibe check on the idea.
    def __init__(self, vocab_size: int, embedding_dim: int, num_heads: int, dropout: float=0.01, downsample_rate: float = 0.25):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)

        self.down_layers = nn.ModuleList([
            nn.TransformerEncoderLayer(embedding_dim, num_heads, dropout=dropout, batch_first=True) for _ in range(2)
        ])

        self.mid_layers = nn.ModuleList([
            nn.TransformerEncoderLayer(embedding_dim, num_heads, dropout=dropout, batch_first=True) for _ in range(4)
        ])

        self.up_layers = nn.ModuleList([
            nn.TransformerEncoderLayer(embedding_dim, num_heads, dropout=dropout, batch_first=True) for _ in range(2)
        ])

        self.output_layer = nn.Linear(embedding_dim, vocab_size)
        # Initialize a gate for each layer.
        layer_gate_init = nn.Linear(embedding_dim, 1)

        # Copy the gate for each layer. 
        # Initializing by copying inductively biases the model to tokenize in a later layer if the gate is high but the model chose not to.
        self.down_layer_gate = deepcopy(layer_gate_init)
        self.up_layer_gate = deepcopy(layer_gate_init)
        self.downsample_rate = downsample_rate

    def apply_local_layers(self, layers, x: torch.Tensor, context_window_length) -> torch.Tensor:
        """Again a janky python for-loop implementation that re-constructs the causal mask for each layer."""
        _, seq_len, _ = x.shape

        # Create causal mask for context length of 64
        mask = torch.ones(seq_len, seq_len) * float('-inf')

        for i in range(seq_len):
            # Allow attention to self and previous window_size tokens
            start_idx = max(0, i - context_window_length + 1)
            mask[i, start_idx:i+1] = 0.0
        
        # Process through down layers with the specified context length
        for layer in layers:
            x = layer(x, tgt_mask=mask, tgt_is_causal=True)


    def forward(self, x: torch.Tensor) -> torch.Tensor:

        x = self.embedding(x)
        # Apply transformer layers with context length of 64
        
        self.apply_local_layers(self.down_layers, x, 64)

        # Gate the layer.
        gate_logits = self.down_layer_gate(x)
        gate_probs = F.sigmoid(gate_logits)

        if self.training:
            # Re-scale the gate probabilities to control the downsampling rate.
            true_gate_probs = gate_probs * (self.downsample_rate / gate_probs.mean())

        gate_samples = torch.bernoulli(true_gate_probs)

        # Merge the tokens where the gate is 1.
        # Create a merge destination tensor
        batch_size, seq_len, _ = x.shape
        merge_dst, n_dst = get_merge_dst(gate_samples)
        
        x_downsampled = torch.zeros(batch_size, n_dst, self.embedding_dim, dtype=torch.float32)
        x_downsampled = torch.scatter_reduce(x_downsampled, dim=1, index=merge_dst, src=x, reduce="mean", include_self=False)

        y_downsampled = self.apply_local_layers(self.mid_layers, x_downsampled, 64*4)

        deviation = y_downsampled - x_downsampled
        upsampled_deviation = torch.gather(deviation, dim=1, index=merge_dst)

        y = x + upsampled_deviation
        self.apply_local_layers(self.up_layers, y, 64)

        print(f"{x.shape=} {x_downsampled.shape=} {y_downsampled.shape=} {y.shape=}")

        return self.output_layer(y)


my_model = MiniBitterLLM(vocab_size=byte5_tokenizer.vocab_size, embedding_dim=128, num_heads=4, downsample_rate=0.25)

In [72]:
test_batch = byte5_tokenizer(openwebtext["text"][:5], return_tensors="pt", padding=True)["input_ids"]
test_batch.shape

torch.Size([5, 10397])

In [66]:
test_batch.min(), test_batch.max()

(tensor(0), tensor(229))

In [71]:
byte5_tokenizer.get_vocab()

{'<pad>': 0,
 '</s>': 1,
 '<unk>': 2,
 '\x00': 3,
 '\x01': 4,
 '\x02': 5,
 '\x03': 6,
 '\x04': 7,
 '\x05': 8,
 '\x06': 9,
 '\x07': 10,
 '\x08': 11,
 '\t': 12,
 '\n': 13,
 '\x0b': 14,
 '\x0c': 15,
 '\r': 16,
 '\x0e': 17,
 '\x0f': 18,
 '\x10': 19,
 '\x11': 20,
 '\x12': 21,
 '\x13': 22,
 '\x14': 23,
 '\x15': 24,
 '\x16': 25,
 '\x17': 26,
 '\x18': 27,
 '\x19': 28,
 '\x1a': 29,
 '\x1b': 30,
 '\x1c': 31,
 '\x1d': 32,
 '\x1e': 33,
 '\x1f': 34,
 ' ': 35,
 '!': 36,
 '"': 37,
 '#': 38,
 '$': 39,
 '%': 40,
 '&': 41,
 "'": 42,
 '(': 43,
 ')': 44,
 '*': 45,
 '+': 46,
 ',': 47,
 '-': 48,
 '.': 49,
 '/': 50,
 '0': 51,
 '1': 52,
 '2': 53,
 '3': 54,
 '4': 55,
 '5': 56,
 '6': 57,
 '7': 58,
 '8': 59,
 '9': 60,
 ':': 61,
 ';': 62,
 '<': 63,
 '=': 64,
 '>': 65,
 '?': 66,
 '@': 67,
 'A': 68,
 'B': 69,
 'C': 70,
 'D': 71,
 'E': 72,
 'F': 73,
 'G': 74,
 'H': 75,
 'I': 76,
 'J': 77,
 'K': 78,
 'L': 79,
 'M': 80,
 'N': 81,
 'O': 82,
 'P': 83,
 'Q': 84,
 'R': 85,
 'S': 86,
 'T': 87,
 'U': 88,
 'V': 89,
 'W': 90,

In [73]:
my_model(test_batch)

TypeError: TransformerEncoderLayer.forward() got an unexpected keyword argument 'tgt_mask'

In [39]:
test_samples = torch.tensor([[0, 1, 1, 0, 0, 1, 0, 0, 0, 1]], dtype=torch.float32)

merge_dst, n_dst = get_merge_dst(test_samples)
merge_dst, n_dst

(tensor([[0, 0, 1, 2, 2, 2, 3, 3, 3, 3]]), 4)

In [54]:
# Create causal mask for context length of 64
# Create a sliding window attention mask with window size 2
window_size = 4
seq_len = 10
mask = torch.ones(seq_len, seq_len) * float('-inf')

for i in range(seq_len):
    # Allow attention to self and previous window_size tokens
    start_idx = max(0, i - window_size + 1)
    mask[i, start_idx:i+1] = 0.0
mask

tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [-inf, 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [-inf, -inf, 0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [-inf, -inf, -inf, 0., 0., 0., 0., -inf, -inf, -inf],
        [-inf, -inf, -inf, -inf, 0., 0., 0., 0., -inf, -inf],
        [-inf, -inf, -inf, -inf, -inf, 0., 0., 0., 0., -inf],
        [-inf, -inf, -inf, -inf, -inf, -inf, 0., 0., 0., 0.]])

In [40]:
y = torch.zeros(1, n_dst, dtype=torch.float32)
y = y.scatter_reduce(dim=1, index=merge_dst, src=test_samples, reduce="mean", include_self=False)
y

tensor([[0.5000, 1.0000, 0.3333, 0.2500]])

In [41]:
y_upsampled = torch.gather(y, dim=1, index=merge_dst)
y_upsampled

tensor([[0.5000, 0.5000, 1.0000, 0.3333, 0.3333, 0.3333, 0.2500, 0.2500, 0.2500,
         0.2500]])