Let's try and learn a tokenizer with RL : )

In [1]:
from transformers import AutoTokenizer
from torch import nn
import torch.nn.functional as F
from typing import List

from copy import deepcopy
import torch

from scipy.signal import lfilter
import numpy as np

import pandas as pd

from torch.utils.data import DataLoader

device = "cuda"
assert torch.cuda.is_available()

In [2]:
# Helper functions to count the number of parameters in a torch.nn.Module
def count_parameters(module):
    return sum(p.numel() for p in module.parameters())


def parameter_count_string(module):
    n_params = count_parameters(module)
    if n_params > 10**6:
        return f"{n_params/10**6:.1f}M"
    elif n_params > 10**3:
        return f"{n_params/10**3:.1f}k"
    else:
        return f"{n_params}" 

In [3]:
byte5_tokenizer = AutoTokenizer.from_pretrained("google/byt5-large")

In [4]:
import os

username = "sdauncey"
scratch_dir = f"/scratch/{username}/tokenizer_training"

if not os.path.exists(scratch_dir):
    os.makedirs(scratch_dir)

scratch_dir

'/scratch/sdauncey/tokenizer_training'

In [5]:
from datasets import load_dataset

ds = load_dataset("fka/awesome-chatgpt-prompts")
train_prompts = [s['prompt'] for s in ds['train']]
len(train_prompts)
train_prompts[0]


'Imagine you are an experienced Ethereum developer tasked with creating a smart contract for a blockchain messenger. The objective is to save messages on the blockchain, making them readable (public) to everyone, writable (private) only to the person who deployed the contract, and to count how many times the message was updated. Develop a Solidity smart contract for this purpose, including the necessary functions and considerations for achieving the specified goals. Please provide the code and any relevant explanations to ensure a clear understanding of the implementation.'

In [6]:
import datasets

# Download a portion of OpenWebText dataset
# This will download a subset of the OpenWebText corpus
print("Downloading OpenWebText dataset...")

# Load a small portion of OpenWebText (1% of the dataset)
openwebtext_8k = datasets.load_dataset(
    "openwebtext",
    split="train[:1%]",  # Using only 1% samples of the dataset for now.
    cache_dir=os.path.join(scratch_dir, "openwebtext_8k_cache"),
    trust_remote_code=True
)

print(f"Downloaded {len(openwebtext_8k)} examples from OpenWebText")

# Display a sample
print("\nSample text from OpenWebText:")
print(openwebtext_8k[0]['text'][:500] + "...")

Downloading OpenWebText dataset...
Downloaded 80138 examples from OpenWebText

Sample text from OpenWebText:
Port-au-Prince, Haiti (CNN) -- Earthquake victims, writhing in pain and grasping at life, watched doctors and nurses walk away from a field hospital Friday night after a Belgian medical team evacuated the area, saying it was concerned about security.

The decision left CNN Chief Medical Correspondent Sanjay Gupta as the only doctor at the hospital to get the patients through the night.

CNN initially reported, based on conversations with some of the doctors, that the United Nations ordered the B...


In [7]:
len(openwebtext_8k[1]['text'])

10286

In [8]:
DataLoader(openwebtext_8k, batch_size=8).__iter__().__next__()["text"]

['Port-au-Prince, Haiti (CNN) -- Earthquake victims, writhing in pain and grasping at life, watched doctors and nurses walk away from a field hospital Friday night after a Belgian medical team evacuated the area, saying it was concerned about security.\n\nThe decision left CNN Chief Medical Correspondent Sanjay Gupta as the only doctor at the hospital to get the patients through the night.\n\nCNN initially reported, based on conversations with some of the doctors, that the United Nations ordered the Belgian First Aid and Support Team to evacuate. However, Belgian Chief Coordinator Geert Gijs, a doctor who was at the hospital with 60 Belgian medical personnel, said it was his decision to pull the team out for the night. Gijs said he requested U.N. security personnel to staff the hospital overnight, but was told that peacekeepers would only be able to evacuate the team.\n\nHe said it was a "tough decision" but that he accepted the U.N. offer to evacuate after a Canadian medical team, als

In [9]:


class BitterLLM(nn.Module):
    # Maps bytes to bytes, slowly merging and then unmerging tokens layer-by-layer
    # with a context window inversely proportional to the number of tokens in the sequence.#
    def __init__(self, vocab_size: int, embedding_dim: int, num_layers: int, num_heads: int, dropout: float, downsampling_rates: List[float]):
        
        # Initialize a standard transformer decoder architecture.
        self.embedding = nn.Embedding(vocab_size, embedding_dim)

        self.down_layers = nn.ModuleList([
            nn.TransformerDecoderLayer(embedding_dim, num_heads, dropout=dropout)
            for _ in range(num_layers)
        ])

        self.mid_layer = nn.TransformerDecoderLayer(embedding_dim, num_heads, dropout=dropout)

        self.up_layers = nn.ModuleList([
            nn.TransformerDecoderLayer(embedding_dim, num_heads, dropout=dropout)
            for _ in range(num_layers)
        ])

        self.output_layer = nn.Linear(embedding_dim, vocab_size)

        # Initialize a gate for each layer.
        layer_gate_init = nn.Linear(embedding_dim, 1)

        # Copy the gate for each layer. 
        # Initializing by copying inductively biases the model to tokenize in a later layer if the gate is high but the model chose not to.
        self.layer_gates = nn.ModuleList([
            deepcopy(layer_gate_init) for _ in range(num_layers*2)
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Input is a tensor of shape (batch_size, sequence_length), 
        gets transformed into a tensor of shape (batch_size, n_tokens, embedding_dim).
        """
        x = self.embedding(x)

        hidden_states = [x]
        merge_destinations = []

        for layer, gate in zip(self.down_layers, self.layer_gates):
            # TODO: make the context window inversely proportional to the number of tokens in the sequence.
            x = layer(x)

            # Gate the layer.
            gate_logits = gate(x)
            gate_probs = F.sigmoid(gate_logits)

            if self.training:
                # Re-scale the gate probabilities to control the downsampling rate.
                ...

            gate_samples = torch.bernoulli(gate_probs)

            # Merge the tokens where the gate is 1.
            # Create a merge destination tensor
            batch_size, seq_len, _ = x.shape
            merge_dst, n_dst = get_merge_dst(gate_samples)
            
            y = torch.zeros(batch_size, n_dst, self.embedding_dim, dtype=torch.float32)
            x = torch.scatter_reduce(y, dim=1, index=merge_dst, src=x, reduce="mean", include_self=False)

            hidden_states.append((x, gate_samples, gate_probs))

        y = x
        for up_layer, down_hidden_state, merge_dst in zip(self.up_layers, reversed(hidden_states), reversed(merge_destinations)):
            pass

        # TODO: finish implementation.

        return self.output_layer(x), hidden_states
        

In [10]:
test_batch = byte5_tokenizer(train_prompts[:5], return_tensors="pt", padding=True)["input_ids"]
test_batch.shape

torch.Size([5, 797])

In [11]:
test_batch.min(), test_batch.max()

(tensor(0), tensor(229))

In [None]:
byte5_tokenizer.get_vocab()

{'<pad>': 0,
 '</s>': 1,
 '<unk>': 2,
 '\x00': 3,
 '\x01': 4,
 '\x02': 5,
 '\x03': 6,
 '\x04': 7,
 '\x05': 8,
 '\x06': 9,
 '\x07': 10,
 '\x08': 11,
 '\t': 12,
 '\n': 13,
 '\x0b': 14,
 '\x0c': 15,
 '\r': 16,
 '\x0e': 17,
 '\x0f': 18,
 '\x10': 19,
 '\x11': 20,
 '\x12': 21,
 '\x13': 22,
 '\x14': 23,
 '\x15': 24,
 '\x16': 25,
 '\x17': 26,
 '\x18': 27,
 '\x19': 28,
 '\x1a': 29,
 '\x1b': 30,
 '\x1c': 31,
 '\x1d': 32,
 '\x1e': 33,
 '\x1f': 34,
 ' ': 35,
 '!': 36,
 '"': 37,
 '#': 38,
 '$': 39,
 '%': 40,
 '&': 41,
 "'": 42,
 '(': 43,
 ')': 44,
 '*': 45,
 '+': 46,
 ',': 47,
 '-': 48,
 '.': 49,
 '/': 50,
 '0': 51,
 '1': 52,
 '2': 53,
 '3': 54,
 '4': 55,
 '5': 56,
 '6': 57,
 '7': 58,
 '8': 59,
 '9': 60,
 ':': 61,
 ';': 62,
 '<': 63,
 '=': 64,
 '>': 65,
 '?': 66,
 '@': 67,
 'A': 68,
 'B': 69,
 'C': 70,
 'D': 71,
 'E': 72,
 'F': 73,
 'G': 74,
 'H': 75,
 'I': 76,
 'J': 77,
 'K': 78,
 'L': 79,
 'M': 80,
 'N': 81,
 'O': 82,
 'P': 83,
 'Q': 84,
 'R': 85,
 'S': 86,
 'T': 87,
 'U': 88,
 'V': 89,
 'W': 90,

In [22]:
def get_merge_dst(gate_samples: torch.Tensor) -> torch.Tensor:
    """
    Returns (merge_dst, dst_idx) the merge destination for each token in the sequence and the number of unique merge destinations.
    For now, has a janky python for-loop implementation.
    Input is a tensor of shape (batch_size, sequence_length) with 0 tokens are merged into the next 1 token.
    """
    batch_size, seq_len = gate_samples.shape
    merge_dst = torch.zeros_like(gate_samples, dtype=torch.long)
    n_dst = torch.zeros(batch_size, dtype=torch.long)

    # Process each batch separately
    for b in range(batch_size):
        dst_idx = 0
        for i in range(seq_len):
            merge_dst[b, i] = dst_idx
            if gate_samples[b, i] == 1 and i < seq_len - 1:
                # If previous position had gate=1, keep the same destination
                dst_idx += 1

        n_dst[b] = dst_idx + 1

    return merge_dst, n_dst


class MiniBitterLLM(nn.Module):
    # A mini BitterLLM with 2 down, 4 mid, and 2 up layers. As a vibe check on the idea.
    def __init__(self, vocab_size: int, embedding_dim: int, num_heads: int, dropout: float=0.01, downsample_rate: float = 0.25):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.embedding = nn.Embedding(vocab_size, embedding_dim)

        self.down_layers = nn.ModuleList([
            nn.TransformerEncoderLayer(embedding_dim, num_heads, dropout=dropout, batch_first=True, dim_feedforward=embedding_dim) for _ in range(2)
        ])

        # dim_feedforward should scale inversely with the number of tokens in the sequence.
        self.mid_layers = nn.ModuleList([
            nn.TransformerEncoderLayer(embedding_dim, num_heads, dropout=dropout, batch_first=True, dim_feedforward=embedding_dim*4) for _ in range(4) 
        ])

        self.up_layers = nn.ModuleList([
            nn.TransformerEncoderLayer(embedding_dim, num_heads, dropout=dropout, batch_first=True, dim_feedforward=embedding_dim) for _ in range(2)
        ])

        self.output_layer = nn.Linear(embedding_dim, vocab_size)
        # Initialize a gate for each layer.
        layer_gate_init = nn.Linear(embedding_dim, 1)

        # Copy the gate for each layer. 
        # Initializing by copying inductively biases the model to tokenize in a later layer if the gate is high but the model chose not to.
        self.down_layer_gate = deepcopy(layer_gate_init)
        self.up_layer_gate = deepcopy(layer_gate_init)
        self.downsample_rate = downsample_rate

    def apply_local_layers(self, layers, x: torch.Tensor, context_window_length) -> torch.Tensor:
        """Again a janky python for-loop implementation that re-constructs the causal mask for each layer."""
        _, seq_len, _ = x.shape

        # Create causal mask for context length of 64
        mask = torch.ones(seq_len, seq_len) * float('-inf')

        for i in range(seq_len):
            # Allow attention to self and previous window_size tokens
            start_idx = max(0, i - context_window_length + 1)
            mask[i, start_idx:i+1] = 0.0
        
        # Process through down layers with the specified context length
        for layer in layers:
            x = layer(x, src_mask=mask, is_causal=True)

        return x


    def forward(self, x: torch.Tensor) -> torch.Tensor:

        batch_size, _ = x.shape

        x = self.embedding(x)

        # Apply down layers  byte tokens        
        x = self.apply_local_layers(self.down_layers, x, 64)

        # Sample gating binary variables for each token.
        down_gate_logits = self.down_layer_gate(x)
        down_gate_probs = F.sigmoid(down_gate_logits)
        down_gate_samples = torch.bernoulli(down_gate_probs)

        # Hack: ensure for now that we always gate on the first token:
        down_gate_samples[:, 0] = 1.

        # Merge the tokens into the next token where the gate is 1.
        down_gate_samples = down_gate_samples.squeeze(-1)
        down_merge_dst, n_dst = get_merge_dst(down_gate_samples)
        down_merge_dst = down_merge_dst.unsqueeze(-1).expand(-1, -1, self.embedding_dim)

        x_downsampled = torch.zeros(batch_size, n_dst.max(), self.embedding_dim, dtype=torch.float32).to(x.device)
        x_downsampled = torch.scatter_reduce(x_downsampled, dim=1, index=down_merge_dst, src=x, reduce="mean", include_self=False)

        # Apply mid layers to merged tokens and compute the deviation
        y_downsampled = self.apply_local_layers(self.mid_layers, x_downsampled, 64*4)
        deviation = y_downsampled - x_downsampled        

        # Upsample by removing the first token merge group, shifting all token groups down and adding another one token group at the end.
        up_gate_samples = down_gate_samples[:, 1:]
        up_gate_samples = torch.cat([up_gate_samples, torch.ones(batch_size, 1, dtype=up_gate_samples.dtype).to(up_gate_samples.device)], dim=1)
        up_merge_dst, _ = get_merge_dst(up_gate_samples)
        up_merge_dst = up_merge_dst.unsqueeze(-1).expand(-1, -1, self.embedding_dim)

        # Add the upsampled deviation to the input to the middle layers
        upsampled_deviation = torch.gather(deviation, dim=1, index=up_merge_dst)
        y = x + upsampled_deviation

        # Apply up layers to byte tokens
        y = self.apply_local_layers(self.up_layers, y, 64)

        # Apply second gating to the downsampled output, for use in inference and a consistency loss in training.
        up_gate_logits = self.up_layer_gate(y)
        up_gate_probs = F.sigmoid(up_gate_logits)

        # Map residual stream to logits
        logits = self.output_layer(y)
        logits = F.log_softmax(logits, dim=-1)

        out = {
            "logits": logits,
            "down_gate_probs": down_gate_probs.squeeze(-1),
            "up_gate_probs": up_gate_probs.squeeze(-1),
            "down_gate_samples": down_gate_samples.to(dtype=torch.long),
            "up_gate_samples": up_gate_samples.to(dtype=torch.long),
            "down_merge_dst": down_merge_dst[:, :, 0], # This dimension is repeated.
            "up_merge_dst": up_merge_dst[:, :, 0],
            "n_dst": n_dst,
        }

        return out


my_model = MiniBitterLLM(vocab_size=byte5_tokenizer.vocab_size, embedding_dim=128, num_heads=4, downsample_rate=0.25)
print(f"my_model has {parameter_count_string(my_model)} parameters")

my_model has 1.3M parameters


In [45]:
test_string = openwebtext_8k[-1]["text"][:200]
test_batch = byte5_tokenizer.encode(test_string, return_tensors="pt", padding=True)
test_string

'Cameron could have risen to the occasion, like Obama, and tried to change the entire political tone of the debate over immigration\n\nIt is barely a week since Barack Obama broke the political mould on '

In [None]:
# test_batch = byte5_tokenizer(train_prompts[:5], return_tensors="pt", padding=True)["input_ids"]

In [46]:
out = my_model(test_batch)

for k, v in out.items():
    print(f"{k=} {v.shape=}, {v.dtype=}")


k='logits' v.shape=torch.Size([1, 201, 256]), v.dtype=torch.float32
k='down_gate_probs' v.shape=torch.Size([1, 201]), v.dtype=torch.float32
k='up_gate_probs' v.shape=torch.Size([1, 201]), v.dtype=torch.float32
k='down_gate_samples' v.shape=torch.Size([1, 201]), v.dtype=torch.int64
k='up_gate_samples' v.shape=torch.Size([1, 201]), v.dtype=torch.int64
k='down_merge_dst' v.shape=torch.Size([1, 201]), v.dtype=torch.int64
k='up_merge_dst' v.shape=torch.Size([1, 201]), v.dtype=torch.int64
k='n_dst' v.shape=torch.Size([1]), v.dtype=torch.int64


In [47]:
out = my_model(test_batch)
print(out["down_merge_dst"][0][:20])
print(out["up_merge_dst"][0][:20])

tensor([0, 1, 1, 1, 2, 3, 3, 3, 3, 3, 4, 5, 5, 6, 6, 6, 6, 6, 6, 6])
tensor([0, 0, 0, 1, 2, 2, 2, 2, 2, 3, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5])


In [48]:
def display_gating(tokens_ids, merge_dst):
    """Display how a SmallBitterLLM merges a sequence. token_ids and merge_dst are tensors of shape (sequence_length,)."""
    previous_merge_dst = 0
    for t_id, merge_destinantion in zip(tokens_ids, merge_dst):
        merge_destinantion = merge_destinantion.item()
        
        t_txt = byte5_tokenizer.decode(t_id)
        print(f"{t_txt}", end="")
        
        if merge_destinantion != previous_merge_dst:
            print(f"|", end="")
            previous_merge_dst = merge_destinantion

display_gating(test_batch[0], out["down_merge_dst"][0])

C|ame|r|on co|u|ld| have risen t|o the occasion|, li|ke O|bama, and tri|e|d to ch|an|g|e |th|e| e|n|tire |po|lit|ica|l t|one of| the| |d|e|bate over |i|mmigr|a|ti|on|
|
I|t i|s |bare|l|y a we|e|k s|ince Ba|rack |O|bam|a bro|ke t|he po|litic|al| mo|u|ld| o|n| </s>

In [26]:
my_optimizer = torch.optim.Adam(my_model.parameters(), lr=0.001)

In [29]:
def compute_discounted_rewards(rewards, discount):
    """
    Assumes that rewards is a numpy array of shape (n_episodes, n_timesteps). Returns tensor of same shape.
    credit to: https://stackoverflow.com/questions/47970683/vectorize-a-numpy-discount-calculation/47971187#47971187,
    minor modifications made to vectorise computation.
    C[i] = R[i] + discount * C[i+1]
    signal.lfilter(b, a, x, axis=-1, zi=None)
    a[0]*y[n] = b[0]*x[n] + b[1]*x[n-1] + ... + b[M]*x[n-M]
                          - a[1]*y[n-1] - ... - a[N]*y[n-N]
    """
    r = rewards[:, ::-1]
    a = [1, -discount]
    b = [1]
    y = lfilter(b, a, x=r)
    return y[:, ::-1]


def discounted_rewards_torch(rewards, discount):
    """torch wrapper for compute_discounted_rewards. Warning: does _not_ allow for backprop through the rewards, which is fine for policy gradients."""
    rewards_device = rewards.device
    rewards = rewards.detach().cpu().numpy()
    discounted_rewards = compute_discounted_rewards(rewards, discount)
    discounted_rewards = torch.tensor(discounted_rewards.copy(), device=rewards_device) # Copy as torch doesn't like converting negatively strided arrays
    return discounted_rewards


def bitter_tokenizer_training_step(model, batch, optimizer):
    """
    Assume that batch is torch.tensor of token ids of shape (batch, sequence_length). returns a dict of floats of the training losses for the batch.
    """
    batch_size, _ = batch.shape

    optimizer.zero_grad()

    out = model(batch)
    logits = out["logits"]
    down_gate_samples = out["down_gate_samples"]
    down_gate_probs = out["down_gate_probs"]
    up_gate_samples = out["up_gate_samples"]
    up_gate_probs = out["up_gate_probs"]

    # Compute autoregressive loss: log probability of next token.
    next_token_ids = batch[:, 1:]
    current_token_logits = logits[:, :-1]
    next_token_logits = F.cross_entropy(current_token_logits.transpose(1, 2), next_token_ids, reduction="none") # Transpose as F.cross_entropy wants shape [batch, classes, ...]
    ar_loss = next_token_logits.mean()

    # Compute gating loss: discounted log probabilities of following token(s).
    discount_rate = 0.95
    next_token_logits_padded = torch.cat([next_token_logits, torch.zeros(batch_size, 1, device=next_token_logits.device)], dim=-1) # Pad the last reward as zero
    discounted_rewards = discounted_rewards_torch(next_token_logits_padded, discount_rate)
    discounted_rewards = (discounted_rewards - discounted_rewards.mean()) # Simple estimate of the advantage

    # action 0 = continue, action 1 = gate
    action_log_probs = torch.stack([(1 - down_gate_probs).log() , down_gate_probs.log()], dim=1)
    selected_action_log_probs = F.cross_entropy(action_log_probs, down_gate_samples, reduction="none")
    gating_loss =  - (discounted_rewards * selected_action_log_probs).mean() # Negative as we want to maximise the reward.

    # Compute consistency loss: minimize difference between training gating and inference gating
    up_gating_log_probs = torch.stack([(1 - up_gate_probs).log() , up_gate_probs.log()], dim=1)
    consistency_loss = F.cross_entropy(up_gating_log_probs, up_gate_samples, reduction="mean")

    # Hacky additional consistency loss: make the downsampling rate match the training gating.
    down_gate_rate_loss = (model.downsample_rate - down_gate_probs.mean()) **2
    up_gate_rate_loss = (model.downsample_rate - up_gate_probs.mean()) **2
    rate_consistency_loss = 5.*(down_gate_rate_loss + up_gate_rate_loss)

    # Optimizer step
    total_loss = ar_loss + gating_loss + consistency_loss + rate_consistency_loss

    total_loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()

    out = {
        "ar_loss": ar_loss.item(),
        "gating_loss": gating_loss.item(),
        "consistency_loss": consistency_loss.item(),
        "total_loss": total_loss.item(),
        "selected_action_ce": selected_action_log_probs.mean().item(),
        "down_gate_rate_loss": down_gate_rate_loss.item(),
        "up_gate_rate_loss": up_gate_rate_loss.item(),
        "rate_consistency_loss": rate_consistency_loss.item(),
    }

    return out

bitter_tokenizer_training_step(my_model, test_batch, my_optimizer)

down_gate_probs.mean()=tensor(0.3914, grad_fn=<MeanBackward0>) up_gate_probs.mean()=tensor(0.2047, grad_fn=<MeanBackward0>)


{'ar_loss': 4.190525531768799,
 'gating_loss': 1.1250131632802887,
 'consistency_loss': 0.7627230882644653,
 'total_loss': 6.188528588008682,
 'selected_action_ce': 0.6385809183120728,
 'down_gate_rate_loss': 0.020004423335194588,
 'up_gate_rate_loss': 0.0020489380694925785,
 'rate_consistency_loss': 0.1102668046951294}

In [30]:
from tqdm import tqdm


In [33]:
def bitter_tokenizer_training_loop(model, train_dataset):
    # TODO: validation dataset
    # Create data loaders
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=32,
        shuffle=True,
        num_workers=1,
        pin_memory=True
    )

    # Initialize model and optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    
    # Training loop
    num_epochs = 1
    for epoch in range(num_epochs):
        # Training phase
        model.train()
        model = model.to(device)
        train_losses = []
        for step_no, batch in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=True)):

            batch = batch["text"]
            # Tokenize and cutoff at 4096 tokens
            batch = byte5_tokenizer(batch, return_tensors="pt", padding=True)["input_ids"]
            batch = batch[:, :4096]  # Truncate to maximum length of 4096
            batch = batch.to(device)
            print(batch.shape)

            with profile(activities=activities, record_shapes=True) as prof:
                with record_function("model_inference"):
                    loss_dict = bitter_tokenizer_training_step(model, batch, optimizer)
            
            print(prof.key_averages().table(sort_by=sort_by_keyword, row_limit=10))

            train_losses.append(loss_dict)

            print(f"Batch ar train loss: {loss_dict["ar_loss"]}")

            if step_no % 10 == 0:
                # Print metrics
                print(f"Epoch {epoch+1}/{num_epochs} step:{step_no}/{len(train_loader)}")
                print(f"Train loss: {np.mean([l['total_loss'] for l in train_losses]):.4f}")
                with torch.no_grad():
                    print(f"")
        
    train_losses = pd.DataFrame(train_losses)

    return model, train_losses

In [34]:
from torch.profiler import profile, record_function, ProfilerActivity


activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA, ProfilerActivity.XPU]
sort_by_keyword = device + "_time_total"



In [35]:

# Run the function
result = bitter_tokenizer_training_loop(my_model, openwebtext_8k)[-1]

result

Epoch 1/1:   0%|                                                                                                                                | 0/2505 [00:00<?, ?it/s]

torch.Size([32, 4096])


Epoch 1/1:   0%|                                                                                                                                | 0/2505 [03:51<?, ?it/s]


KeyboardInterrupt: 

In [None]:
my_model

In [21]:
import gc

In [22]:
# Print the tensors using the most CUDA memory
def print_cuda_memory_usage():
    # Check if CUDA is available
    if not torch.cuda.is_available():
        print("CUDA is not available")
        return
    
    # Get total GPU memory
    total_memory = torch.cuda.get_device_properties(0).total_memory
    
    # Get memory usage statistics
    reserved_memory = torch.cuda.memory_reserved(0)
    allocated_memory = torch.cuda.memory_allocated(0)
    free_memory = total_memory - reserved_memory
    
    print(f"GPU Memory Usage:")
    print(f"  Total: {total_memory / 1e9:.2f} GB")
    print(f"  Reserved: {reserved_memory / 1e9:.2f} GB")
    print(f"  Allocated: {allocated_memory / 1e9:.2f} GB")
    print(f"  Free: {free_memory / 1e9:.2f} GB")
    
    # Find the largest tensors in memory
    print("\nLargest CUDA Tensors:")
    
    # Get all tensors in memory
    tensors = []
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj) and obj.is_cuda:
                tensors.append((obj.size(), obj.element_size() * obj.nelement(), obj.dtype))
        except:
            pass
    
    # Sort tensors by memory usage (largest first)
    tensors.sort(key=lambda x: x[1], reverse=True)
    
    # Print the top 10 tensors
    for i, (size, memory, dtype) in enumerate(tensors[:10]):
        print(f"  {i+1}. Size: {size}, Memory: {memory / 1e6:.2f} MB, Type: {dtype}")

# Run the function to see memory usage
print_cuda_memory_usage()


GPU Memory Usage:
  Total: 25.44 GB
  Reserved: 24.95 GB
  Allocated: 23.61 GB
  Free: 0.49 GB

Largest CUDA Tensors:
  1. Size: torch.Size([3, 60469, 10, 128]), Memory: 928.80 MB, Type: torch.float32
  2. Size: torch.Size([3, 60469, 10, 128]), Memory: 928.80 MB, Type: torch.float32
  3. Size: torch.Size([3, 60469, 10, 128]), Memory: 928.80 MB, Type: torch.float32
  4. Size: torch.Size([10, 31738, 512]), Memory: 649.99 MB, Type: torch.float32
  5. Size: torch.Size([10, 31738, 512]), Memory: 649.99 MB, Type: torch.float32
  6. Size: torch.Size([10, 31738, 512]), Memory: 649.99 MB, Type: torch.float32
  7. Size: torch.Size([10, 31738, 512]), Memory: 649.99 MB, Type: torch.float32
  8. Size: torch.Size([10, 60469, 128]), Memory: 619.20 MB, Type: torch.int64
  9. Size: torch.Size([10, 60469, 128]), Memory: 619.20 MB, Type: torch.int64
  10. Size: torch.Size([3, 31738, 10, 128]), Memory: 487.50 MB, Type: torch.float32


  return isinstance(obj, torch.Tensor)


In [None]:
model_file_name = "bitter-llm-v1.pt"
net_scratch_dir = os.path.join("/itet-stor/sdauncey/net_scratch/VScodeProjects/bitter-lesson-tokenization")


# Save the model to the specified directory
os.makedirs(net_scratch_dir, exist_ok=True)
model_save_file = os.path.join(net_scratch_dir, model_file_name)
torch.save(my_model.state_dict(), model_save_file)
print(f"Model saved to {model_save_file}")

In [None]:
type(ds).__mro__

In [None]:
ds["train"].__class__.__mro__

In [None]:
for batch in DataLoader(ds["train"], batch_size=32):
    print(batch["prompt"][0])

In [None]:
ds["train"][0]