# Byte Latent Transformer - Part 1: Model Architecture
Introduction
The Byte Latent Transformer is a transformer architecture that processes data at the byte level. Unlike conventional transformers, which operate on word or subword tokens, this model reads raw bytes directly, which offers several advantages:

- Language-agnostic: Works with any language without special tokenization
- Universal data handling: Can process text, code, and even binary data
- No out-of-vocabulary issues: Every possible byte is in the vocabulary
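
To make the "no out-of-vocabulary" point concrete, here is a minimal sketch that encodes a few strings with Python's built-in UTF-8 codec and checks that every value falls in the range 0-255. The sample strings and the helper name `to_byte_ids` are illustrative choices for this sketch, not part of the implementation that follows.

```python
# Illustrative samples (chosen for this sketch): English, Japanese, an emoji, and code.
samples = ["hello", "こんにちは", "🙂", "x += 1"]

def to_byte_ids(text):
    """Encode text as UTF-8 and return the byte values as integers."""
    return list(text.encode("utf-8"))

for s in samples:
    ids = to_byte_ids(s)
    # Every id is a valid index into a 256-entry embedding table.
    assert all(0 <= b <= 255 for b in ids)
    print(f"{s!r}: {len(s)} characters -> {len(ids)} bytes")
```

The trade-off is sequence length: non-ASCII characters take several bytes each, so the model sees more positions than there are characters.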

In this implementation, we'll build a Byte Latent Transformer from scratch using PyTorch.
Implementation Overview
We'll implement the model in these parts:

1. Model Architecture

2. Training Pipeline

3. Inference and Evaluation

# **1. Model Architecture**

**Required Libraries**

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math

**Core Components**
**1. Byte Embedding Layer**

This layer converts input bytes (0-255) into embeddings:

In [7]:
class ByteEmbedding(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        # 256 possible byte values (0-255)
        self.embedding = nn.Embedding(256, hidden_dim)

    def forward(self, x):
        return self.embedding(x)

**2. Positional Encoding**

Transformers have no built-in notion of sequence order, so we add sinusoidal positional encodings to the embeddings:

In [8]:
class PositionalEncoding(nn.Module):
    def __init__(self, hidden_dim, max_seq_length=2048):
        super().__init__()

        # Create positional encoding matrix
        pe = torch.zeros(max_seq_length, hidden_dim)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, hidden_dim, 2).float() * (-math.log(10000.0) / hidden_dim)
        )

        # Apply sine to even indices and cosine to odd indices
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        # Register buffer (persistent state)
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        # Add positional encoding to input embeddings
        return x + self.pe[:, :x.size(1)]
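
To see what the buffer contains, the sketch below recomputes a single row of the encoding in plain Python, using the same frequency term as `div_term`. The function name `sinusoidal_row` is an assumption made for this illustration.

```python
import math

def sinusoidal_row(pos, hidden_dim):
    """Return the positional-encoding vector for one position."""
    row = [0.0] * hidden_dim
    for i in range(0, hidden_dim, 2):
        # Same frequency as div_term: exp(i * -ln(10000) / hidden_dim)
        freq = math.exp(i * (-math.log(10000.0) / hidden_dim))
        row[i] = math.sin(pos * freq)          # even index: sine
        if i + 1 < hidden_dim:
            row[i + 1] = math.cos(pos * freq)  # odd index: cosine
    return row

row0 = sinusoidal_row(0, 4)
row1 = sinusoidal_row(1, 4)
print(row0)
print([round(v, 4) for v in row1])
```

Position 0 always maps to alternating 0 and 1 (sin 0 and cos 0). Note that the `PositionalEncoding` class above assumes an even `hidden_dim`, since the sine and cosine slices must have the same width.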

**3. Multi-Head Self-Attention**

The core attention mechanism:

In [9]:
class MultiHeadAttention(nn.Module):
    def __init__(self, hidden_dim, num_heads, dropout=0.1):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads

        assert self.head_dim * num_heads == hidden_dim, "hidden_dim must be divisible by num_heads"

        # Linear projections for Q, K, V
        self.q_linear = nn.Linear(hidden_dim, hidden_dim)
        self.k_linear = nn.Linear(hidden_dim, hidden_dim)
        self.v_linear = nn.Linear(hidden_dim, hidden_dim)

        # Output projection
        self.out_proj = nn.Linear(hidden_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)

        # Linear projections and reshape for multi-head attention
        q = self.q_linear(query).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_linear(key).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_linear(value).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        # Compute attention scores
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # Apply mask if provided (for causal attention)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        # Apply softmax and dropout
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Apply attention weights to values
        context = torch.matmul(attn_weights, v)

        # Reshape and apply output projection
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.hidden_dim)
        output = self.out_proj(context)

        return output
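
The masking step is easier to follow on a tiny example. The following is a plain-Python sketch of scaled dot-product attention weights for a single head with a causal mask. It mirrors the `masked_fill(mask == 0, -1e9)` trick but is not the batched implementation above, and the toy vectors are made up for illustration.

```python
import math

def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def causal_attention_weights(q, k):
    """q, k: lists of vectors (seq_len x head_dim). Returns seq_len x seq_len weights."""
    head_dim = len(q[0])
    weights = []
    for i, qi in enumerate(q):
        scores = []
        for j, kj in enumerate(k):
            dot = sum(a * b for a, b in zip(qi, kj))
            score = dot / math.sqrt(head_dim)
            # Future positions (j > i) get a huge negative score, so softmax gives them ~0
            scores.append(score if j <= i else -1e9)
        weights.append(softmax(scores))
    return weights

# Toy queries and keys (hypothetical values)
q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
w = causal_attention_weights(q, k)
for row in w:
    print([round(p, 3) for p in row])
```

Each row sums to 1, and the first position can only attend to itself.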

**4. Feed-Forward Network**

The position-wise feed-forward network:

In [10]:
class FeedForward(nn.Module):
    def __init__(self, hidden_dim, ff_dim, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(hidden_dim, ff_dim)
        self.linear2 = nn.Linear(ff_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Apply first linear layer with GELU activation
        x = F.gelu(self.linear1(x))
        # Apply dropout and second linear layer
        x = self.dropout(x)
        x = self.linear2(x)
        return x
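
For reference, `F.gelu` with its default settings computes the exact (erf-based) GELU. A scalar version, written here only as an illustration, looks like this:

```python
import math

def gelu(x):
    """Exact GELU: x * Phi(x), where Phi is the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

for x in (-2.0, -1.0, 0.0, 1.0, 2.0):
    print(x, round(gelu(x), 4))
```

Unlike ReLU, small negative inputs produce small negative outputs instead of being clipped to zero.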

**5. Encoder Layer**

Combines attention and feed-forward networks:

In [11]:
class EncoderLayer(nn.Module):
    def __init__(self, hidden_dim, num_heads, ff_dim, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(hidden_dim, num_heads, dropout)
        self.feed_forward = FeedForward(hidden_dim, ff_dim, dropout)
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Self-attention with residual connection and layer norm
        attn_output = self.self_attn(x, x, x, mask)
        x = x + self.dropout(attn_output)
        x = self.norm1(x)

        # Feed-forward with residual connection and layer norm
        ff_output = self.feed_forward(x)
        x = x + self.dropout(ff_output)
        x = self.norm2(x)

        return x

**6. Latent Projection Layer**

This is what makes it a "Latent" transformer - projecting bytes to a latent space:

In [12]:
class LatentProjection(nn.Module):
    def __init__(self, hidden_dim, latent_dim):
        super().__init__()
        self.down_proj = nn.Linear(hidden_dim, latent_dim)
        self.up_proj = nn.Linear(latent_dim, hidden_dim)
        self.norm = nn.LayerNorm(latent_dim)

    def forward(self, x):
        # Project to latent space
        latent = self.down_proj(x)
        latent = self.norm(latent)

        # Project back to hidden space
        return self.up_proj(latent)
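
Because `up_proj` is applied to a `latent_dim`-sized vector, every output of this layer is an affine function of only `latent_dim` numbers, which is what makes it a bottleneck. As a rough sketch, the parameter count of the layer can be worked out by hand (the helper names below are assumptions for this illustration):

```python
def linear_params(in_features, out_features):
    """Weights plus bias for a fully connected layer."""
    return in_features * out_features + out_features

def latent_projection_params(hidden_dim, latent_dim):
    down = linear_params(hidden_dim, latent_dim)   # down_proj
    up = linear_params(latent_dim, hidden_dim)     # up_proj
    norm = 2 * latent_dim                          # LayerNorm scale and shift
    return down + up + norm

print(latent_projection_params(512, 256))  # default model sizes
print(latent_projection_params(64, 32))    # sizes used in the small test model
```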

**7. Complete Byte Latent Transformer Model**

Now we'll put everything together:

In [13]:
class ByteLatentTransformer(nn.Module):
    def __init__(self,
                 hidden_dim=512,
                 latent_dim=256,
                 num_layers=6,
                 num_heads=8,
                 ff_dim=2048,
                 dropout=0.1,
                 max_seq_length=2048):
        super().__init__()

        # Byte embedding layer
        self.byte_embedding = ByteEmbedding(hidden_dim)
        self.positional_encoding = PositionalEncoding(hidden_dim, max_seq_length)

        # Latent projection
        self.latent_projection = LatentProjection(hidden_dim, latent_dim)

        # Encoder layers
        self.encoder_layers = nn.ModuleList([
            EncoderLayer(hidden_dim, num_heads, ff_dim, dropout)
            for _ in range(num_layers)
        ])

        # Output projection
        self.output_projection = nn.Linear(hidden_dim, 256)  # 256 possible byte values

        # Dropout
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Convert input bytes to embeddings and add positional encoding
        x = self.byte_embedding(x)
        x = self.positional_encoding(x)
        x = self.dropout(x)

        # Apply latent projection
        x = self.latent_projection(x)

        # Pass through encoder layers
        for layer in self.encoder_layers:
            x = layer(x, mask)

        # Project to output vocabulary
        output = self.output_projection(x)

        return output

**8. Creating a Causal Mask**

For autoregressive generation, we need a causal mask:

In [14]:
def create_causal_mask(size):
    """Create a causal mask for autoregressive generation."""
    mask = torch.triu(torch.ones(size, size), diagonal=1).bool()
    return ~mask  # Flip so 1s indicate allowed positions
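
To visualize the mask, here is a plain-Python equivalent that builds the same lower-triangular pattern. The function name `causal_mask_rows` is made up for this sketch.

```python
def causal_mask_rows(size):
    """Entry [i][j] is True when position i may attend to position j (j <= i)."""
    return [[j <= i for j in range(size)] for i in range(size)]

for row in causal_mask_rows(4):
    print("".join("1" if allowed else "0" for allowed in row))
# prints:
# 1000
# 1100
# 1110
# 1111
```

Row i shows which positions byte i is allowed to see: itself and everything before it.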

**Testing the Model**

Let's test our model with a small example:

In [15]:
def test_model():
    # Create model with small dimensions for testing
    model = ByteLatentTransformer(
        hidden_dim=64,
        latent_dim=32,
        num_layers=2,
        num_heads=4,
        ff_dim=128
    )

    # Create a sample input (batch_size=2, seq_length=10)
    x = torch.randint(0, 256, (2, 10))

    # Create causal mask
    mask = create_causal_mask(10).unsqueeze(0).unsqueeze(0)

    # Forward pass
    output = model(x, mask)

    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")
    print("Model test successful!")

# Run the test
test_model()

Input shape: torch.Size([2, 10])
Output shape: torch.Size([2, 10, 256])
Model test successful!


In the next part, we'll implement the training pipeline including data loading, optimization, and loss functions.

# Byte Latent Transformer - Part 2: Training Pipeline

Introduction

In this second part, we'll implement the training pipeline for our Byte Latent Transformer. This includes:

1. Data processing at the byte level

2. Creating datasets and data loaders

3. Setting up the training loop

4. Implementing optimization and learning rate scheduling

**Required Libraries**

In [16]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn import functional as F

import numpy as np
import os
import time
import math
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

**1. Data Processing**

First, let's create utilities to convert text to byte sequences and vice versa:

In [17]:
def text_to_bytes(text):
    """Convert text to a list of byte values."""
    return list(text.encode('utf-8'))

def bytes_to_text(byte_list):
    """Convert a list of byte values back to text."""
    return bytes(byte_list).decode('utf-8', errors='replace')
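
A quick round trip shows how these helpers behave, including what `errors='replace'` does when a multi-byte character is cut off. The two functions are restated so the snippet runs on its own, and the sample word is an arbitrary choice.

```python
def text_to_bytes(text):
    return list(text.encode("utf-8"))

def bytes_to_text(byte_list):
    return bytes(byte_list).decode("utf-8", errors="replace")

ids = text_to_bytes("café")
print(ids)                        # 'é' occupies two bytes: 195, 169
print(bytes_to_text(ids))         # round trip restores the text
print(bytes_to_text(ids[:-1]))    # truncated 'é' becomes the replacement character U+FFFD
```

This matters for generation: a model sampling one byte at a time can emit incomplete UTF-8 sequences, and `errors='replace'` keeps decoding from raising.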

**2. Dataset Creation**

Now, let's create a dataset for byte-level language modeling:

In [18]:
class ByteDataset(Dataset):
    def __init__(self, data, seq_length):
        self.data = data  # Raw bytes
        self.seq_length = seq_length

    def __len__(self):
        # Each example needs seq_length input bytes plus one extra byte for the last target
        return max(0, len(self.data) - self.seq_length)

    def __getitem__(self, idx):
        # Get input sequence
        input_seq = self.data[idx:idx+self.seq_length]

        # Get target sequence (shifted by 1)
        target_seq = self.data[idx+1:idx+self.seq_length+1]

        # Convert to tensors
        input_tensor = torch.tensor(input_seq, dtype=torch.long)
        target_tensor = torch.tensor(target_seq, dtype=torch.long)

        return input_tensor, target_tensor
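
The indexing can be checked on a toy sequence. This sketch reproduces the slicing from `__getitem__` without tensors; the generator name `windows` is an assumption for illustration.

```python
def windows(data, seq_length):
    """Yield (input, target) pairs the way ByteDataset indexes them."""
    for idx in range(max(0, len(data) - seq_length)):
        yield data[idx:idx + seq_length], data[idx + 1:idx + seq_length + 1]

data = list(b"abcdef")
pairs = list(windows(data, 4))
for inp, tgt in pairs:
    print(bytes(inp), "->", bytes(tgt))
# prints:
# b'abcd' -> b'bcde'
# b'bcde' -> b'cdef'
```

Six bytes with `seq_length=4` give two examples, and consecutive examples overlap in all but one byte.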

Let's create a function to load and preprocess text data:

In [19]:
def load_text_data(file_path, seq_length):
    """Load text file and create a ByteDataset."""
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            text = f.read()
    except UnicodeDecodeError:
        # Try with a different encoding if UTF-8 fails
        with open(file_path, 'r', encoding='latin-1') as f:
            text = f.read()

    # Convert text to bytes
    byte_data = text_to_bytes(text)

    # Create dataset
    dataset = ByteDataset(byte_data, seq_length)

    return dataset

**3. Training Utils**

Let's implement some utility functions for tracking training progress:

In [20]:
class TrainingMonitor():
    def __init__(self):
        self.epochs = []
        self.losses = []
        self.val_losses = []
        self.best_val_loss = float('inf')

    def update(self, epoch, loss, val_loss=None):
        self.epochs.append(epoch)
        self.losses.append(loss)

        if val_loss is not None:
            self.val_losses.append(val_loss)
            if val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                return True  # Signal to save checkpoint
        return False

    def plot(self):
        plt.figure(figsize=(10, 5))
        plt.plot(self.epochs, self.losses, label='Training Loss')

        if self.val_losses:
            plt.plot(self.epochs, self.val_losses, label='Validation Loss')

        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()
        plt.title('Training Progress')
        plt.grid(True)
        plt.show()

**4. Learning Rate Scheduler**

Let's implement a learning rate scheduler with warmup and cosine decay:

In [21]:
class WarmupCosineScheduler():
    def __init__(self, optimizer, warmup_steps, total_steps, min_lr_ratio=0.1):
        self.optimizer = optimizer
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.min_lr_ratio = min_lr_ratio

        # Get initial learning rate
        self.base_lr = optimizer.param_groups[0]['lr']
        self.min_lr = self.base_lr * min_lr_ratio

        self.step_count = 0

    def step(self):
        self.step_count += 1
        lr = self.get_lr()

        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    def get_lr(self):
        # Warmup phase
        if self.step_count < self.warmup_steps:
            return self.base_lr * (self.step_count / self.warmup_steps)

        # Cosine decay phase
        progress = (self.step_count - self.warmup_steps) / (self.total_steps - self.warmup_steps)
        cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))

        # Scale between base_lr and min_lr
        return self.min_lr + (self.base_lr - self.min_lr) * cosine_decay
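
To get a feel for the schedule, the same formula can be evaluated as a standalone function of the step number. This is a sketch that assumes `total_steps` is larger than `warmup_steps`; the step values printed are arbitrary.

```python
import math

def lr_at(step, base_lr, warmup_steps, total_steps, min_lr_ratio=0.1):
    """Learning rate after `step` scheduler steps (same formula as get_lr)."""
    min_lr = base_lr * min_lr_ratio
    if step < warmup_steps:
        # Linear warmup from 0 up to base_lr
        return base_lr * (step / warmup_steps)
    # Cosine decay from base_lr down to min_lr
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
    return min_lr + (base_lr - min_lr) * cosine_decay

for step in (1, 500, 1000, 3000, 5000):
    print(step, f"{lr_at(step, 3e-4, 1000, 5000):.6f}")
```

The rate peaks at `base_lr` when warmup ends and reaches `base_lr * min_lr_ratio` at the final step.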

**5. Training Function**

Now, let's implement the full training loop:

In [22]:
def train_model(model, train_dataloader, val_dataloader=None,
                epochs=10, lr=3e-4, warmup_steps=1000, device='cuda',
                save_dir='checkpoints', save_every=1):

    # Create directories if they don't exist
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    # Move model to device
    model = model.to(device)

    # Setup optimizer
    optimizer = optim.AdamW(model.parameters(), lr=lr, betas=(0.9, 0.999), weight_decay=0.01)

    # Calculate total steps for scheduler
    total_steps = epochs * len(train_dataloader)

    # Setup scheduler
    scheduler = WarmupCosineScheduler(optimizer, warmup_steps, total_steps)

    # Setup training monitor
    monitor = TrainingMonitor()

    # Create a causal mask once for the maximum sequence length
    seq_length = next(iter(train_dataloader))[0].size(1)
    causal_mask = create_causal_mask(seq_length).to(device)

    # Training loop
    for epoch in range(1, epochs + 1):
        model.train()
        epoch_loss = 0
        start_time = time.time()

        progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch}/{epochs}")

        for batch_idx, (inputs, targets) in enumerate(progress_bar):
            # Move data to device
            inputs = inputs.to(device)
            targets = targets.to(device)

            # Forward pass
            outputs = model(inputs, causal_mask)

            # Reshape outputs for loss calculation
            # [batch_size, seq_len, vocab_size] -> [batch_size * seq_len, vocab_size]
            outputs = outputs.view(-1, outputs.size(-1))
            targets = targets.view(-1)

            # Calculate loss
            loss = F.cross_entropy(outputs, targets)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            # Update parameters
            optimizer.step()

            # Update learning rate
            scheduler.step()

            # Update progress bar
            epoch_loss += loss.item()
            avg_loss = epoch_loss / (batch_idx + 1)
            progress_bar.set_postfix(loss=f"{avg_loss:.4f}",
                                    lr=f"{optimizer.param_groups[0]['lr']:.6f}")

        # Calculate average epoch loss
        avg_epoch_loss = epoch_loss / len(train_dataloader)

        # Validation
        val_loss = None
        if val_dataloader:
            val_loss = validate_model(model, val_dataloader, causal_mask, device)
            print(f"Epoch {epoch}: train_loss={avg_epoch_loss:.4f}, val_loss={val_loss:.4f}, "
                  f"time={time.time() - start_time:.2f}s")
        else:
            print(f"Epoch {epoch}: train_loss={avg_epoch_loss:.4f}, "
                  f"time={time.time() - start_time:.2f}s")

        # Update monitor and save checkpoint if it's the best model
        save_checkpoint = monitor.update(epoch, avg_epoch_loss, val_loss)

        # Save model checkpoint
        if save_checkpoint or (epoch % save_every == 0):
            checkpoint_path = os.path.join(save_dir, f"model_epoch_{epoch}.pt")
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': avg_epoch_loss,
                'val_loss': val_loss,
            }, checkpoint_path)
            print(f"Model checkpoint saved to {checkpoint_path}")

            # Save best model separately
            if save_checkpoint:
                best_path = os.path.join(save_dir, "best_model.pt")
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'train_loss': avg_epoch_loss,
                    'val_loss': val_loss,
                }, best_path)
                print(f"Best model saved with validation loss: {val_loss:.4f}")

    # Plot training progress
    monitor.plot()

    return model, monitor
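
A useful sanity check when reading the printed losses: `F.cross_entropy` reports the loss in nats, and a model that assigns equal probability to all 256 bytes scores ln(256). The helper names below are assumptions for this sketch.

```python
import math

def uniform_loss(vocab_size=256):
    """Cross-entropy (in nats) of a model that gives every byte equal probability."""
    return math.log(vocab_size)

def nats_to_bits(loss_nats):
    """Convert a loss in nats to bits per byte."""
    return loss_nats / math.log(2)

print(f"{uniform_loss():.4f} nats")
print(f"{nats_to_bits(uniform_loss()):.4f} bits per byte")
```

A training loss near 5.55 therefore means the model has learned nothing yet, and anything meaningfully below that is progress.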

**6. Validation Function**

Let's implement the validation function:

In [23]:
def validate_model(model, val_dataloader, mask, device):
    model.eval()
    val_loss = 0

    with torch.no_grad():
        for inputs, targets in val_dataloader:
            # Move data to device
            inputs = inputs.to(device)
            targets = targets.to(device)

            # Forward pass
            outputs = model(inputs, mask)

            # Reshape outputs for loss calculation
            outputs = outputs.view(-1, outputs.size(-1))
            targets = targets.view(-1)

            # Calculate loss
            loss = F.cross_entropy(outputs, targets)
            val_loss += loss.item()

    return val_loss / len(val_dataloader)

**7. Data Preparation**

Now, let's write a function to prepare our data and create data loaders:

In [24]:
def prepare_data(file_path, seq_length=128, batch_size=32, val_split=0.1, num_workers=2):
    """Prepare data for training and validation."""

    # Load dataset
    full_dataset = load_text_data(file_path, seq_length)

    # Split into train and validation
    val_size = int(len(full_dataset) * val_split)
    train_size = len(full_dataset) - val_size

    train_dataset, val_dataset = torch.utils.data.random_split(
        full_dataset, [train_size, val_size]
    )

    # Create data loaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True
    )

    return train_loader, val_loader
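
It helps to know how many optimizer steps this produces, since the scheduler's `warmup_steps` should be a fraction of the total. The sketch below repeats the arithmetic from `ByteDataset.__len__` and `prepare_data`; the 50,000-byte file size and the other arguments are assumed values for illustration.

```python
import math

def count_steps(num_bytes, seq_length, batch_size, val_split, epochs):
    num_windows = max(0, num_bytes - seq_length)   # ByteDataset.__len__
    val_size = int(num_windows * val_split)
    train_size = num_windows - val_size
    # DataLoader keeps the last partial batch by default, hence the ceiling
    steps_per_epoch = math.ceil(train_size / batch_size)
    return train_size, val_size, steps_per_epoch, steps_per_epoch * epochs

print(count_steps(50_000, 128, 8, 0.1, 5))
# prints (44885, 4987, 5611, 28055)
```

Keep in mind that `random_split` shuffles overlapping windows, so training and validation examples share most of their bytes; the validation loss is best read as a rough signal.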

**8. Putting it All Together**

Let's create a function to initialize and train our model:

In [25]:
def create_causal_mask(size):
    """Create a causal mask for autoregressive generation."""
    mask = torch.triu(torch.ones(size, size), diagonal=1).bool()
    return ~mask  # Flip so 1s indicate allowed positions

def train_byte_latent_transformer(file_path,
                                 hidden_dim=512,
                                 latent_dim=256,
                                 num_layers=6,
                                 num_heads=8,
                                 ff_dim=2048,
                                 dropout=0.1,
                                 max_seq_length=512,
                                 batch_size=16,
                                 epochs=10,
                                 learning_rate=3e-4,
                                 warmup_steps=1000,
                                 device='cuda' if torch.cuda.is_available() else 'cpu'):
    """Initialize and train a ByteLatentTransformer model."""

    # Create model
    model = ByteLatentTransformer(
        hidden_dim=hidden_dim,
        latent_dim=latent_dim,
        num_layers=num_layers,
        num_heads=num_heads,
        ff_dim=ff_dim,
        dropout=dropout,
        max_seq_length=max_seq_length
    )

    # Prepare data
    train_loader, val_loader = prepare_data(
        file_path,
        seq_length=max_seq_length,
        batch_size=batch_size
    )

    # Train model
    trained_model, monitor = train_model(
        model=model,
        train_dataloader=train_loader,
        val_dataloader=val_loader,
        epochs=epochs,
        lr=learning_rate,
        warmup_steps=warmup_steps,
        device=device
    )

    return trained_model, monitor

**Example Usage**

Here's how you can use the training pipeline:

In [26]:
import requests

url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"

# Download the file
response = requests.get(url)
response.raise_for_status()  # Raise an exception for bad status codes

# Save the file
with open("shakespeare.txt", "w", encoding="utf-8") as f:
    f.write(response.text[:50000])

print("File downloaded and saved as shakespeare.txt")

File downloaded and saved as shakespeare.txt


In [None]:
# Example usage with a text file
def main():
    # Check if GPU is available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Load and train the model
    # For a quick test, use smaller dimensions
    model, monitor = train_byte_latent_transformer(
        file_path='shakespeare.txt',  # File saved by the download cell; replace with your own
        hidden_dim=256,  # Smaller for faster training
        latent_dim=128,
        num_layers=4,
        num_heads=4,
        ff_dim=1024,
        max_seq_length=128,
        batch_size=8,
        epochs=5
    )

    # Plot training progress
    monitor.plot()

    print("Training complete!")

if __name__ == "__main__":
    main()

I have not run the cell above yet, because training on a CPU is too slow.

In the next part, we'll implement the inference and evaluation functionality to generate text with our trained model.

# Byte Latent Transformer - Part 3: Inference and Evaluation
In this final part, we'll implement the inference and evaluation functionality for our Byte Latent Transformer. This includes:

1. Text generation with the trained model

2. Model evaluation metrics

3. Sample applications

4. Saving and loading models

5. Complete example with Google Colab integration


**Required Libraries**

In [28]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import time
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import random

**1. Text Generation Functions**
First, let's implement functions for generating text with our trained model:

In [29]:
def create_causal_mask(size):
    """Create a causal mask for autoregressive generation."""
    mask = torch.triu(torch.ones(size, size), diagonal=1).bool()
    return ~mask  # Flip so 1s indicate allowed positions

def sample_from_logits(logits, temperature=1.0, top_k=None, top_p=None):
    """Sample from logits with optional temperature, top-k, and nucleus sampling."""
    # Apply temperature
    logits = logits / temperature

    # Apply top-k filtering if specified
    if top_k is not None:
        # Keep only the top-k values, set the rest to -inf
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        threshold = v[:, -1].unsqueeze(1)
        logits = torch.where(logits < threshold,
                           torch.ones_like(logits) * float('-inf'),
                           logits)

    # Apply nucleus (top-p) sampling if specified
    if top_p is not None:
        # Sort logits in descending order
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)

        # Calculate cumulative probabilities
        probs = F.softmax(sorted_logits, dim=-1)
        cum_probs = torch.cumsum(probs, dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cum_probs > top_p

        # Shift the indices to the right to keep the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # Create a mask for indices to remove
        indices_to_remove = sorted_indices_to_remove.scatter(
            1, sorted_indices, sorted_indices_to_remove
        )

        # Set logits to -inf where needed
        logits = logits.masked_fill(indices_to_remove, float('-inf'))

    # Apply softmax to get probabilities
    probs = F.softmax(logits, dim=-1)

    # Sample from the probability distribution
    next_token = torch.multinomial(probs, 1)

    return next_token

def generate_text(model, prompt_bytes, max_length=100, temperature=0.8,
                  top_k=50, top_p=0.9, device='cuda'):
    """Generate text using the trained model."""
    model.eval()
    model = model.to(device)

    # Convert prompt to tensor
    if isinstance(prompt_bytes, str):
        # If prompt is a string, convert to bytes
        prompt_bytes = list(prompt_bytes.encode('utf-8'))

    input_tensor = torch.tensor([prompt_bytes], dtype=torch.long).to(device)
    generated_bytes = list(prompt_bytes)

    # Generate text auto-regressively
    with torch.no_grad():
        for _ in range(max_length):
            # Use only the last 'max_seq_length' tokens if input is too long
            if hasattr(model, 'positional_encoding') and hasattr(model.positional_encoding, 'pe'):
                max_seq_length = model.positional_encoding.pe.size(1)
                if input_tensor.size(1) > max_seq_length:
                    input_tensor = input_tensor[:, -max_seq_length:]

            # Create causal mask for the sequence
            causal_mask = create_causal_mask(input_tensor.size(1)).to(device)

            # Forward pass through the model
            logits = model(input_tensor, causal_mask)

            # Get predictions for the next token
            next_token_logits = logits[0, -1, :]

            # Sample next token
            next_token = sample_from_logits(
                next_token_logits.unsqueeze(0),
                temperature=temperature,
                top_k=top_k,
                top_p=top_p
            )

            # Add the generated token to the input for the next iteration
            # (next_token already has shape [1, 1], matching input_tensor's [1, seq_len])
            input_tensor = torch.cat([input_tensor, next_token], dim=1)

            # Store the generated byte
            generated_bytes.append(next_token.item())

            # Stop if we generate the end of text token (if you have one)
            # For now, we'll just generate up to max_length

    # Convert bytes back to text
    try:
        generated_text = bytes(generated_bytes).decode('utf-8', errors='replace')
    except Exception as e:
        print(f"Error decoding bytes: {e}")
        generated_text = "Error decoding generated bytes"

    return generated_text
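
The two filtering steps in `sample_from_logits` are easier to follow on a small distribution. The sketch below works on probabilities instead of logits and uses plain lists, so it is an illustration of the idea, not a drop-in replacement; the example distribution is made up.

```python
def top_k_filter(probs, k):
    """Keep the k most probable entries, zero the rest, and renormalize."""
    threshold = sorted(probs, reverse=True)[min(k, len(probs)) - 1]
    kept = [p if p >= threshold else 0.0 for p in probs]
    total = sum(kept)
    return [p / total for p in kept]

def top_p_filter(probs, top_p):
    """Keep the most probable entries until their total mass exceeds top_p."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept = [0.0] * len(probs)
    cumulative = 0.0
    for i in order:
        kept[i] = probs[i]
        cumulative += probs[i]
        if cumulative > top_p:
            break  # the entry that crosses the threshold is still kept
    total = sum(kept)
    return [p / total for p in kept]

probs = [0.5, 0.3, 0.1, 0.06, 0.04]
print([round(p, 4) for p in top_k_filter(probs, 2)])
print([round(p, 4) for p in top_p_filter(probs, 0.85)])
```

With `top_p=0.85`, the first two entries only cover 0.8 of the mass, so the third is kept as well; this matches the "shift right by one" step in the tensor version.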

**2. Model Evaluation Metrics**
Let's implement functions to evaluate our model:

In [30]:
def evaluate_perplexity(model, dataloader, device='cuda'):
    """Evaluate model perplexity on a dataset."""
    model.eval()
    model = model.to(device)

    total_loss = 0
    total_tokens = 0

    with torch.no_grad():
        for inputs, targets in tqdm(dataloader, desc="Evaluating perplexity"):
            inputs = inputs.to(device)
            targets = targets.to(device)

            # Create causal mask
            causal_mask = create_causal_mask(inputs.size(1)).to(device)

            # Forward pass
            outputs = model(inputs, causal_mask)

            # Reshape outputs for loss calculation
            outputs = outputs.view(-1, outputs.size(-1))
            targets = targets.view(-1)

            # Calculate loss
            loss = F.cross_entropy(outputs, targets, reduction='sum')

            # Update counters
            total_loss += loss.item()
            total_tokens += targets.numel()

    # Calculate perplexity
    avg_loss = total_loss / total_tokens
    perplexity = torch.exp(torch.tensor(avg_loss))

    return perplexity.item()

def evaluate_model_speed(model, seq_length=128, batch_size=1, num_iterations=10, device='cuda'):
    """Evaluate model inference speed."""
    model.eval()
    model = model.to(device)

    # Create random input tensors
    inputs = torch.randint(0, 256, (batch_size, seq_length), device=device)

    # Create causal mask
    causal_mask = create_causal_mask(seq_length).to(device)

    # Warm up
    with torch.no_grad():
        for _ in range(5):
            _ = model(inputs, causal_mask)

    # Measure inference time
    # Wait for pending GPU work so the timer starts from a clean state
    if str(device).startswith('cuda'):
        torch.cuda.synchronize()
    start_time = time.time()

    with torch.no_grad():
        for _ in range(num_iterations):
            _ = model(inputs, causal_mask)

    # Wait for all queued GPU work to finish before stopping the timer
    if str(device).startswith('cuda'):
        torch.cuda.synchronize()
    end_time = time.time()

    # Calculate statistics
    total_time = end_time - start_time
    avg_time = total_time / num_iterations
    tokens_per_second = (batch_size * seq_length) / avg_time

    return {
        'total_time': total_time,
        'avg_time_per_batch': avg_time,
        'tokens_per_second': tokens_per_second
    }
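
One detail worth spelling out is why `evaluate_perplexity` sums the loss and divides by the token count instead of averaging per-batch means: when the last batch is smaller, averaging batch means gives it too much weight. A minimal sketch with hypothetical numbers:

```python
import math

def perplexity(batch_loss_sums, batch_token_counts):
    """exp of the average per-token loss, pooling sums across batches."""
    avg_loss = sum(batch_loss_sums) / sum(batch_token_counts)
    return math.exp(avg_loss)

# Hypothetical numbers: a full batch of 1024 bytes at 1.2 nats per byte,
# and a short final batch of 128 bytes at 2.0 nats per byte.
loss_sums = [1024 * 1.2, 128 * 2.0]
token_counts = [1024, 128]

ppl = perplexity(loss_sums, token_counts)
naive = math.exp((1.2 + 2.0) / 2)  # averaging batch means overweights the short batch

print(f"token-weighted perplexity: {ppl:.3f}")
print(f"mean-of-batch-means:       {naive:.3f}")
```

For a byte-level model the perplexity is bounded by 256 for uniform guessing, so values are directly comparable across languages and file types.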

**3. Save and Load Model Functions**

In [32]:
def save_model(model, optimizer=None, epoch=None, loss=None, path='model.pt'):
    """Save model and training state."""
    checkpoint = {
        'model_state_dict': model.state_dict(),
    }

    if optimizer is not None:
        checkpoint['optimizer_state_dict'] = optimizer.state_dict()

    if epoch is not None:
        checkpoint['epoch'] = epoch

    if loss is not None:
        checkpoint['loss'] = loss

    # Add model hyperparameters for easy loading
    checkpoint['model_config'] = {
        'hidden_dim': model.byte_embedding.embedding.weight.size(1),
        'latent_dim': model.latent_projection.down_proj.out_features,
        'num_layers': len(model.encoder_layers),
        'num_heads': model.encoder_layers[0].self_attn.num_heads,
        'ff_dim': model.encoder_layers[0].feed_forward.linear1.out_features,
        # The positional-encoding buffer is stored in the state dict, so its
        # length must match when the model is rebuilt
        'max_seq_length': model.positional_encoding.pe.size(1),
    }

    torch.save(checkpoint, path)
    print(f"Model saved to {path}")

def load_model(path, device='cuda'):
    """Load model from checkpoint."""

    checkpoint = torch.load(path, map_location=device)

    # Get model configuration
    config = checkpoint.get('model_config', {})

    # Create model with the same configuration
    model = ByteLatentTransformer(
        hidden_dim=config.get('hidden_dim', 512),
        latent_dim=config.get('latent_dim', 256),
        num_layers=config.get('num_layers', 6),
        num_heads=config.get('num_heads', 8),
        ff_dim=config.get('ff_dim', 2048),
        max_seq_length=config.get('max_seq_length', 2048)
    )

    # Load the state dict
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)

    print(f"Model loaded from {path}")
    return model, checkpoint

**4. Sample Applications**

**4.1 Text Completion**

In [33]:
def complete_text(model, prompt, max_length=100, temperature=0.8, top_k=50, top_p=0.9, device='cuda'):
    """Complete text given a prompt."""
    return generate_text(
        model=model,
        prompt_bytes=prompt,
        max_length=max_length,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        device=device
    )
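
`complete_text` only forwards `temperature`, `top_k`, and `top_p` to `generate_text`, which is defined earlier in the notebook. As a rough illustration of what these three knobs usually do to a next-byte distribution, here is a plain-Python sketch. It is not the code of `generate_text`; `filter_distribution` is a hypothetical name, and the order of operations (temperature, then top-k, then top-p) is an assumption based on common practice.

```python
import math
import random

def filter_distribution(logits, temperature=0.8, top_k=None, top_p=None):
    """Return {index: probability} after temperature, top-k and top-p filtering."""
    # Temperature scaling followed by a numerically stable softmax
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Rank candidates from most to least likely
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)

    # Top-k: keep only the k most likely candidates (None or 0 disables it)
    if top_k:
        ranked = ranked[:top_k]

    # Top-p (nucleus): keep the smallest prefix whose mass reaches top_p
    if top_p:
        kept, mass = [], 0.0
        for i in ranked:
            kept.append(i)
            mass += probs[i]
            if mass >= top_p:
                break
        ranked = kept

    # Renormalise over the surviving candidates
    norm = sum(probs[i] for i in ranked)
    return {i: probs[i] / norm for i in ranked}

# Sample one index from the filtered distribution
dist = filter_distribution([2.0, 1.0, 0.0, -1.0], temperature=1.0, top_k=2)
next_byte = random.choices(list(dist), weights=list(dist.values()))[0]
print(dist, next_byte)
```

Lower temperatures sharpen the distribution before filtering, so fewer candidates are needed to reach the `top_p` mass.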

**4.2 Interactive Text Generation**

In [34]:
def interactive_generation(model, device='cuda'):
    """Interactive text generation with the model."""
    print("=== Interactive Text Generation ===")
    print("Type your prompt and press Enter. Type 'exit' to quit.")

    while True:
        prompt = input("\nPrompt: ")
        if prompt.lower() == 'exit':
            break

        try:
            # Collect generation settings first, then announce that generation starts
            temperature = float(input("Temperature (0.1-1.5, default 0.8): ") or 0.8)
            max_length = int(input("Max length to generate (default 100): ") or 100)

            print("\nGenerating...")

            generated_text = complete_text(
                model=model,
                prompt=prompt,
                max_length=max_length,
                temperature=temperature,
                device=device
            )

            print("\n=== Generated Text ===")
            print(generated_text)

        except Exception as e:
            print(f"Error: {e}")
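
The prompts above advertise a temperature range of 0.1 to 1.5, but `float(input(...) or 0.8)` accepts any number, including 0, which would be a problem if `generate_text` divides logits by the temperature (a common design, assumed here rather than confirmed). A small parsing helper can apply both the default and the range. `parse_number` is a hypothetical name used only for this sketch.

```python
def parse_number(raw, cast, default, low=None, high=None):
    """Parse user input, falling back to a default and clamping to [low, high]."""
    raw = raw.strip()
    if not raw:
        return default          # empty input: use the default
    try:
        value = cast(raw)
    except ValueError:
        return default          # unparsable input: use the default
    if low is not None:
        value = max(low, value)
    if high is not None:
        value = min(high, value)
    return value

print(parse_number("", float, 0.8, 0.1, 1.5))     # empty input
print(parse_number("3.0", float, 0.8, 0.1, 1.5))  # above the range
print(parse_number("abc", int, 100, 1))           # not a number
```

In the interactive loop this would be called as `parse_number(input("Temperature ..."), float, 0.8, 0.1, 1.5)`.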

**4.3 Byte-Level Analysis**

In [35]:
def analyze_byte_distribution(model, text, device='cuda'):
    """Analyze byte distribution predictions for a given text."""
    model.eval()
    model = model.to(device)

    # Convert text to bytes
    bytes_data = list(text.encode('utf-8'))
    input_tensor = torch.tensor([bytes_data], dtype=torch.long).to(device)

    # Create causal mask
    causal_mask = create_causal_mask(input_tensor.size(1)).to(device)

    # Get model predictions
    with torch.no_grad():
        logits = model(input_tensor, causal_mask)

    # Convert logits to probabilities
    probs = F.softmax(logits, dim=-1)

    # Get top predicted bytes for each position
    top_k = 5
    top_probs, top_indices = torch.topk(probs[0], k=top_k, dim=-1)

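    # Print analysis. Note: if the model is trained for next-byte prediction
    # (the usual setup with a causal mask), the predictions listed at position i
    # are the model's guesses for the byte at position i + 1.
    print(f"Byte-level analysis for: '{text}'")
    print("Format: position -> actual byte -> top predictions")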

    for i in range(len(bytes_data)):
        actual_byte = bytes_data[i]
        byte_char = chr(actual_byte) if 32 <= actual_byte <= 126 else f"\\x{actual_byte:02x}"

        print(f"{i}: '{byte_char}' (byte {actual_byte}) -> ", end="")

        for j in range(top_k):
            pred_byte = top_indices[i, j].item()
            pred_prob = top_probs[i, j].item()
            pred_char = chr(pred_byte) if 32 <= pred_byte <= 126 else f"\\x{pred_byte:02x}"

            print(f"{pred_char} ({pred_byte}): {pred_prob:.4f}", end="")
            if j < top_k - 1:
                print(", ", end="")

        print()
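
Because the analysis works on UTF-8 bytes rather than characters, a single non-ASCII character occupies several rows of the output. The sketch below factors the printable-or-escape expression used in `analyze_byte_distribution` into a helper and applies it to a short string. `format_byte` is a hypothetical name, not a function from the notebook.

```python
def format_byte(value):
    """Render a byte as a printable ASCII character or a \\xNN escape."""
    return chr(value) if 32 <= value <= 126 else f"\\x{value:02x}"

# "h", then U+00E9 (e with acute accent, two bytes in UTF-8), then a newline
text = "h\u00e9\n"
encoded = list(text.encode('utf-8'))

print(encoded)
print([format_byte(b) for b in encoded])
```

Three characters become four bytes here, so position indices in the analysis output refer to bytes, not to characters of the input string.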

**5. Complete Usage Example with Google Colab**

Here's a complete example of how to use our Byte Latent Transformer in Google Colab:

In [37]:
def colab_main():
    """Main function for Google Colab usage."""
    from google.colab import files
    import os
    import torch

    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Menu for user options
    while True:
        print("\n=== Byte Latent Transformer ===")
        print("1. Train a new model")
        print("2. Load a trained model")
        print("3. Generate text")
        print("4. Interactive text generation")
        print("5. Evaluate model")
        print("6. Upload/download model")
        print("7. Exit")

        choice = input("\nEnter your choice (1-7): ")

        if choice == '1':
            # Train a new model
            print("\n=== Training a New Model ===")

            # Ask user to upload a text file
            print("Please upload a text file for training:")
            uploaded = files.upload()

            if not uploaded:
                print("No file uploaded. Returning to menu.")
                continue

            file_name = list(uploaded.keys())[0]
            file_path = file_name

            # Get training parameters
            try:
                hidden_dim = int(input("Hidden dimension (default 256): ") or 256)
                latent_dim = int(input("Latent dimension (default 128): ") or 128)
                num_layers = int(input("Number of layers (default 4): ") or 4)
                num_heads = int(input("Number of attention heads (default 4): ") or 4)
                epochs = int(input("Number of epochs (default 5): ") or 5)
                batch_size = int(input("Batch size (default 8): ") or 8)
                max_seq_length = int(input("Maximum sequence length (default 128): ") or 128)

                # Import functions from Part 2
                from Part2_ByteLatentTransformer_Training import train_byte_latent_transformer

                # Train the model
                model, monitor = train_byte_latent_transformer(
                    file_path=file_path,
                    hidden_dim=hidden_dim,
                    latent_dim=latent_dim,
                    num_layers=num_layers,
                    num_heads=num_heads,
                    ff_dim=hidden_dim * 4,
                    max_seq_length=max_seq_length,
                    batch_size=batch_size,
                    epochs=epochs,
                    device=device
                )

                # Save the trained model
                save_path = f"byte_latent_transformer_{hidden_dim}_{latent_dim}_{num_layers}.pt"
                save_model(model, path=save_path)

                # Download the trained model
                files.download(save_path)

            except Exception as e:
                print(f"Error during training: {e}")

        elif choice == '2':
            # Load a trained model
            print("\n=== Loading a Trained Model ===")
            print("Please upload a trained model file (.pt):")

            uploaded = files.upload()

            if not uploaded:
                print("No file uploaded. Returning to menu.")
                continue

            model_file = list(uploaded.keys())[0]

            try:
                model, checkpoint = load_model(model_file, device=device)
                print("Model loaded successfully!")

                # Display model information
                config = checkpoint.get('model_config', {})
                print("Model configuration:")
                print(f"- Hidden dimension: {config.get('hidden_dim', 'Unknown')}")
                print(f"- Latent dimension: {config.get('latent_dim', 'Unknown')}")
                print(f"- Number of layers: {config.get('num_layers', 'Unknown')}")
                print(f"- Number of attention heads: {config.get('num_heads', 'Unknown')}")

            except Exception as e:
                print(f"Error loading model: {e}")
                model = None

        elif choice == '3':
            # Generate text
            if not locals().get('model'):
                print("No model loaded. Please load a model first.")
                continue

            print("\n=== Generate Text ===")
            prompt = input("Enter a prompt: ")

            try:
                temperature = float(input("Temperature (0.1-1.5, default 0.8): ") or 0.8)
                max_length = int(input("Max length to generate (default 100): ") or 100)
                top_k = int(input("Top-k value (default 50, 0 to disable): ") or 50)
                top_p = float(input("Top-p value (default 0.9, 0 to disable): ") or 0.9)

                # Generate text
                generated_text = generate_text(
                    model=model,
                    prompt_bytes=prompt,
                    max_length=max_length,
                    temperature=temperature,
                    top_k=top_k if top_k > 0 else None,
                    top_p=top_p if top_p > 0 else None,
                    device=device
                )

                print("\n=== Generated Text ===")
                print(generated_text)

            except Exception as e:
                print(f"Error during text generation: {e}")

        elif choice == '4':
            # Interactive text generation
            if not locals().get('model'):
                print("No model loaded. Please load a model first.")
                continue

            try:
                interactive_generation(model, device=device)
            except Exception as e:
                print(f"Error during interactive generation: {e}")

        elif choice == '5':
            # Evaluate model
            if not locals().get('model'):
                print("No model loaded. Please load a model first.")
                continue

            print("\n=== Evaluate Model ===")
            print("1. Measure inference speed")
            print("2. Analyze byte distribution")
            print("3. Return to main menu")

            eval_choice = input("Enter your choice (1-3): ")

            if eval_choice == '1':
                # Measure inference speed
                try:
                    batch_size = int(input("Batch size (default 1): ") or 1)
                    seq_length = int(input("Sequence length (default 128): ") or 128)
                    num_iterations = int(input("Number of iterations (default 10): ") or 10)

                    results = evaluate_model_speed(
                        model=model,
                        seq_length=seq_length,
                        batch_size=batch_size,
                        num_iterations=num_iterations,
                        device=device
                    )

                    print("\n=== Performance Results ===")
                    print(f"Total time: {results['total_time']:.4f} seconds")
                    print(f"Average time per batch: {results['avg_time_per_batch'] * 1000:.4f} ms")
                    print(f"Tokens per second: {results['tokens_per_second']:.2f}")

                except Exception as e:
                    print(f"Error during performance evaluation: {e}")

            elif eval_choice == '2':
                # Analyze byte distribution
                try:
                    text = input("Enter text to analyze: ")
                    analyze_byte_distribution(model, text, device=device)
                except Exception as e:
                    print(f"Error during byte distribution analysis: {e}")

        elif choice == '6':
            # Upload/download model
            print("\n=== Upload/Download Model ===")
            print("1. Upload model")
            print("2. Download current model")
            print("3. Return to main menu")

            file_choice = input("Enter your choice (1-3): ")

            if file_choice == '1':
                # Upload model
                print("Please upload a trained model file (.pt):")
                uploaded = files.upload()

                if not uploaded:
                    print("No file uploaded. Returning to menu.")
                    continue

                model_file = list(uploaded.keys())[0]

                try:
                    model, checkpoint = load_model(model_file, device=device)
                    print("Model loaded successfully!")
                except Exception as e:
                    print(f"Error loading model: {e}")

            elif file_choice == '2':
                # Download current model
                if not locals().get('model'):
                    print("No model loaded. Please load a model first.")
                    continue

                try:
                    save_path = input("Enter filename to save as (default: model.pt): ") or "model.pt"
                    save_model(model, path=save_path)
                    files.download(save_path)
                except Exception as e:
                    print(f"Error downloading model: {e}")

        elif choice == '7':
            # Exit
            print("Exiting...")
            break

        else:
            print("Invalid choice. Please try again.")

if __name__ == "__main__":
    # Check if running in Google Colab
    try:
        import google.colab
        is_colab = True
    except ImportError:
        is_colab = False

    if is_colab:
        colab_main()
    else:
        # When running locally, use a simpler interface
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Using device: {device}")

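        model = None

        while True:
            print("\n=== Byte Latent Transformer ===")
            print("1. Train a new model")
            print("2. Load a trained model")
            print("3. Generate text")
            print("4. Exit")

            choice = input("\nEnter your choice (1-4): ")

            if choice == '2':
                # Load a checkpoint from a local path
                path = input("Path to model file (.pt): ")
                try:
                    model, _ = load_model(path, device=device)
                except Exception as e:
                    print(f"Error loading model: {e}")
            elif choice == '3':
                if model is None:
                    print("No model loaded. Please load a model first.")
                    continue
                prompt = input("Enter a prompt: ") or "Hello, world!"
                print(complete_text(model, prompt, device=device))
            elif choice == '4':
                break
            else:
                # Local training is not implemented here; the Colab menu above shows the full flow
                print("Option not available in the local interface.")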

Using device: cpu

=== Byte Latent Transformer ===
1. Train a new model
2. Load a trained model
3. Generate text
4. Interactive text generation
5. Evaluate model
6. Upload/download model
7. Exit

Enter your choice (1-7): 7
Exiting...
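
The speed evaluation option in the menu prints three values from the dictionary returned by `evaluate_model_speed`, which is defined earlier in the notebook. The sketch below shows how such values are typically related to one another, assuming throughput is counted as byte positions processed per second. `summarize_timing` is a hypothetical name, and the formulas are an assumption rather than the notebook's confirmed implementation.

```python
def summarize_timing(total_time, batch_size, seq_length, num_iterations):
    """Derive per-batch latency and throughput from a total wall-clock time."""
    avg_time_per_batch = total_time / num_iterations
    # In a byte-level model, each byte position counts as one token
    tokens_processed = batch_size * seq_length * num_iterations
    return {
        'total_time': total_time,
        'avg_time_per_batch': avg_time_per_batch,
        'tokens_per_second': tokens_processed / total_time,
    }

results = summarize_timing(total_time=0.5, batch_size=1, seq_length=128, num_iterations=10)
print(f"Average time per batch: {results['avg_time_per_batch'] * 1000:.4f} ms")
print(f"Tokens per second: {results['tokens_per_second']:.2f}")
```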


This completes the implementation of our Byte Latent Transformer. You now have a fully functional model that can:

1.Process text data at the byte level

2.Train efficiently with a latent space projection

3.Generate text with various sampling strategies

4.Evaluate model performance and analyze byte distributions

The implementation is organized into three parts for better understanding and modularity, and comes with a Google Colab interface for easy usage.