In [1]:
import torch
import torch.nn as nn
from dataclasses import dataclass
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset
from types import SimpleNamespace
from tqdm import tqdm

class ByteEmbedding(nn.Module):
    def __init__(self,d_model,hash_size):
        super().__init__()
        self.byte_embed = nn.Embedding(256,d_model)
        self.hash_embed = nn.Embedding(hash_size,d_model)

    def forward(self,byte_seq,hash_seq):
        byte_embedding = self.byte_embed(byte_seq)
        hash_embedding = self.hash_embed(hash_seq)
        return byte_embedding + hash_embedding

class FeedForwardLayer(nn.Module):
    def __init__(self,d_model,ff_dim,dropout):
        super().__init__()
        self.layer1 = nn.Linear(d_model,ff_dim)
        self.layer2 = nn.Linear(ff_dim,d_model)
        self.gelu = nn.GELU()
        self.dropout = nn.Dropout(dropout)

    def forward(self,x):
        return self.layer2(self.dropout(self.gelu(self.layer1(x))))

class TransformerBlock(nn.Module):
    def __init__(self,d_model,n_heads,ff_dim,dropout):
        super().__init__()
        self.attention = nn.MultiheadAttention(d_model,n_heads,dropout=dropout)
        self.ff = FeedForwardLayer(d_model,ff_dim,dropout=dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self,x,attn_mask = None):
        attn_out,_ = self.attention(x,x,x,attn_mask=attn_mask) 
        x = x + self.dropout(attn_out)
        x = self.norm1(x)
        ff_out = self.ff(x)
        x = x + self.dropout(attn_out)
        return self.norm2(x)

class CrossAttentionBlock(nn.Module):
    def __init__(self,query_dim,key_dim,n_heads,ff_dim,dropout):
        super().__init__()
        self.attention = nn.MultiheadAttention(query_dim,n_heads,dropout=dropout)
        self.norm = nn.LayerNorm(query_dim)
        self.ff = FeedForwardLayer(query_dim,ff_dim,dropout = dropout)
        self.dropout = nn.Dropout(dropout)
        self.query_proj = nn.Linear(query_dim,query_dim)
        self.key_proj = nn.Linear(key_dim,query_dim)
        self.value_proj = nn.Linear(key_dim,query_dim)

    def forward(self,query,key,value):
        query = self.query_proj(query).permute(1,0,2)
        key = self.key_proj(key).permute(1,0,2)
        value = self.value_proj(value).permute(1,0,2)

        attn_out , _ = self.attention(query,key,value)
        attn_out = attn_out.permute(1,0,2)
        query = query.permute(1,0,2)
        query = query + self.dropout(attn_out)
        query = self.norm(query)
        ff_out = self.ff(query)
        return query + self.dropout(ff_out)

class LocalEncoder(nn.Module):
    def __init__(self,byte_dim,n_heads,ff_dim,n_layers,dropout):
        super().__init__()
        self.layers = nn.ModuleList([TransformerBlock(byte_dim,n_heads,ff_dim,dropout) for _ in range(n_layers)])
        self.cross_attn = CrossAttentionBlock(query_dim=byte_dim,key_dim=byte_dim,n_heads=n_heads,ff_dim=ff_dim,dropout=dropout)

    def forward(self,byte_embeddings,patch_embeddings):
        for layer in self.layers:
            byte_embeddings = layer(byte_embeddings)
        patch_embedding = self.cross_attn(patch_embeddings,byte_embeddings,byte_embeddings)
        return patch_embedding

class LocalDecoder(nn.Module):
    def __init__(self,patch_dim,byte_dim,n_heads,ff_dim,n_layers,dropout):
        super().__init__()
        self.layers = nn.ModuleList([TransformerBlock(byte_dim,n_heads,ff_dim,dropout) for _ in range(n_layers)])
        self.cross_attn = CrossAttentionBlock(query_dim=byte_dim,key_dim=patch_dim,n_heads=n_heads,ff_dim=ff_dim,dropout=dropout)
        self.output_proj = nn.Linear(byte_dim,256)

    def forward(self,patch_embedding,byte_embedding):
        byte_embedding = self.cross_attn(byte_embedding,patch_embedding,patch_embedding)
        for layer in self.layers:
            byte_embedding = layer(byte_embedding)
        return self.output_proj(byte_embedding)

def l2_loss(pred,target):
    return torch.sum((pred - target) ** 2)

class TitanMemory(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.d_model = config.d_model
        self.register_buffer("M",torch.eye(config.d_model))
        self.register_buffer("S",torch.zeros(config.d_model,config.d_model))

        self.query = nn.Linear(config.d_model,config.d_model,bias=False)
        self.key = nn.Linear(config.d_model,config.d_model,bias=False)
        self.value = nn.Linear(config.d_model,config.d_model,bias=False)

        self.alpha = config.alpha
        self.eta = config.eta
        self.theta = config.theta

    def forward(self,x):
        q = self.query(x)
        y = torch.matmul(q,self.M)
        return y

    def update_memory(self,x):
        B = x.size(0)
        if B != 1:
            for i in range(B):
                self.update_memory(x[i:i+1])
            return

        k = self.key(x)
        v = self.value(x)

        v_pred = torch.matmul(k,self.M)
        loss = l2_loss(v_pred,v)
        error = v_pred - v

        g = 2 * torch.matmul(error.t(),k)

        self.S = self.eta * self.S - self.theta * g
        self.S = torch.clamp(self.S, -1e3, 1e3)
        self.M = (1-self.alpha) * self.M + self.S
        self.M = torch.clamp(self.M, -1e3, 1e3)
        return loss

class SlidingWindowAttention(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.d_model = config.d_model
        self.n_heads = config.n_heads
        self.window_size = config.window_size
        self.attention = nn.MultiheadAttention(embed_dim = config.d_model, num_heads = config.n_heads, batch_first= True)

    def forward(self,x):
        batch_size,seq_len,_ = x.size()
        output = []

        for i in range(0,seq_len,self.window_size):
            x_chunk = x[:,i:i+self.window_size,:]
            attn_out,_ = self.attention(x_chunk,x_chunk,x_chunk)
            output.append(attn_out)
        return torch.cat(output,dim=1)

class PersistentMemory(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.persistent = nn.Parameter(torch.randn(config.N_p,config.d_model))

    def forward(self,batch_size):
        return self.persistent.unsqueeze(0).expand(batch_size,-1,-1)

class TitanMAG(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.d_model = config.d_model
        self.window_size = config.window_size
        self.long_memory = TitanMemory(config)
        self.attn_layers = nn.ModuleList([SlidingWindowAttention(config) for _ in range(config.n_layers)])
        self.persistent = PersistentMemory(config)

    def forward(self,x):
        batch_size,seq_len,d_model = x.size()

        x_flat = x.reshape(-1,d_model)
        with torch.no_grad():
            self.long_memory.update_memory(x_flat)
        
        persistent_tokens = self.persistent(batch_size)
        out = torch.cat([persistent_tokens,x],dim=1)

        for layer in self.attn_layers:
            out = layer(out)
        y = out
        out_flat = out.reshape(-1,self.d_model)
        long_term = self.long_memory(out_flat)
        long_term = long_term.reshape(batch_size,-1,d_model)

        output = y * long_term
        output = output[:,-seq_len:,:]
        return output

class LatentGlobalTransformer(nn.Module):
    def __init__(self,patch_dim,n_heads,ff_dim,n_layers,dropout):
        super().__init__()
        self.layers = nn.ModuleList([TitanMAG(patch_dim,n_heads,ff_dim,dropout) for _ in range(n_layers)])
    def forward(self,patches,attn_mask=None):
        for layer in self.layers:
            patches = layer(patches,attn_mask=attn_mask)
        return patches

class ByteLatentTransformer(nn.Module):
    def __init__(self,byte_dim,patch_dim,vocab_size,n_heads,ff_dim,n_encoder,n_decoder,n_global,dropout=0.1):
        super().__init__()
        self.byte_embed = ByteEmbedding(byte_dim,vocab_size)
        self.local_encoder = LocalEncoder(byte_dim,n_heads,ff_dim,n_layers=n_encoder,dropout=dropout)
        self.global_transformer = LatentGlobalTransformer(patch_dim,n_heads,ff_dim,n_layers=n_global,dropout=dropout)
        self.local_decoder = LocalDecoder(patch_dim,byte_dim,n_heads,ff_dim,n_decoder,dropout=dropout) 
        self.projection = nn.Linear(byte_dim,patch_dim)

    def forward(self,byte_seq,hash_seq,patch_seq):
        byte_embeddings = self.byte_embed(byte_seq,hash_seq)
        if patch_seq is None:
            patch_embeddings = torch.mean(byte_embeddings,dim=1,keepdim=True)
            patch_embeddings = self.local_encoder(byte_embeddings,patch_embeddings)
            patch_embeddings = self.projection(patch_embeddings)
        else:
            patch_embeddings = patch_seq
        patch_embeddings = self.global_transformer(patch_embeddings)
        byte_output = self.local_decoder(patch_embeddings,byte_embeddings)
        return byte_output

from types import SimpleNamespace
class LatentGlobalTransformer(nn.Module):
    def __init__(self, patch_dim, n_heads, ff_dim, n_layers, dropout):
        super().__init__()
        config = SimpleNamespace(
            d_model = patch_dim,
            n_heads = n_heads,
            ff_dim = ff_dim,
            dropout = dropout,
            window_size = 16,  
            n_layers = 2,      
            alpha = 0.1,
            eta = 0.01,
            theta = 0.01,
            N_p = 10
        )
        self.layers = nn.ModuleList([TitanMAG(config) for _ in range(n_layers)])
    def forward(self, patches, attn_mask=None):
        for layer in self.layers:
            patches = layer(patches)
        return patches

class ByteLatentTitan(nn.Module):
    def __init__(self, byte_dim, patch_dim, vocab_size, n_heads, ff_dim, n_encoder, n_decoder, n_global, dropout=0.1):
        super().__init__()
        self.byte_embed = ByteEmbedding(byte_dim, vocab_size)
        self.local_encoder = LocalEncoder(byte_dim, n_heads, ff_dim, n_layers=n_encoder, dropout=dropout)
        self.global_transformer = LatentGlobalTransformer(patch_dim, n_heads, ff_dim, n_layers=n_global, dropout=dropout)
        self.local_decoder = LocalDecoder(patch_dim, byte_dim, n_heads, ff_dim, n_layers=n_decoder, dropout=dropout)
        self.projection = nn.Linear(byte_dim, patch_dim)

    def forward(self, byte_seq, hash_seq, patch_seq):
        byte_embeddings = self.byte_embed(byte_seq, hash_seq)
        if patch_seq is None:
            patch_embeddings = torch.mean(byte_embeddings, dim=1, keepdim=True)
            patch_embeddings = self.local_encoder(byte_embeddings, patch_embeddings)
            patch_embeddings = self.projection(patch_embeddings)
        else:
            patch_embeddings = patch_seq
        patch_embeddings = self.global_transformer(patch_embeddings)
        byte_output = self.local_decoder(patch_embeddings, byte_embeddings)
        return byte_output
        
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset
from types import SimpleNamespace
from tqdm import tqdm
import numpy as np

# ===== Sampling Utilities =====

def sample_from_logits(logits, temperature=1.0, top_k=0, top_p=0.0):
    # Apply temperature scaling
    logits = logits / temperature

    # Top-k filtering
    if top_k > 0:
        values, _ = torch.topk(logits, top_k)
        min_value = values[-1]
        logits[logits < min_value] = -float('Inf')

    # Nucleus (top-p) filtering
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        probs = torch.softmax(sorted_logits, dim=-1)
        cumulative_probs = torch.cumsum(probs, dim=-1)
        
        sorted_indices_to_remove = cumulative_probs > top_p
        sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
        sorted_indices_to_remove[0] = 0
        
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = -float('Inf')
    
    probs = torch.softmax(logits, dim=-1)
    next_token = torch.multinomial(probs, 1)
    return next_token

def generate_text(model, prompt, max_length, device, temperature=1.0, top_k=0, top_p=0.0):
    model.eval()
    generated = prompt.clone()
    for _ in range(max_length - prompt.size(1)):
        hash_seq = generated.clone()
        with torch.no_grad():
            output = model(generated, hash_seq, patch_seq=None)
        # Get logits for the last time step (assumes batch size 1)
        next_logits = output[:, -1, :]  # shape: (1, vocab_size)
        next_token = sample_from_logits(next_logits.squeeze(0),
                                        temperature=temperature,
                                        top_k=top_k,
                                        top_p=top_p)
        next_token = next_token.unsqueeze(0)
        generated = torch.cat([generated, next_token], dim=1)
    return generated


class WikiByteDataset(Dataset):

    def __init__(self, hf_dataset, seq_len=128):
        self.dataset = hf_dataset
        self.seq_len = seq_len

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        text = self.dataset[idx]['text']
        text_bytes = text.encode("utf-8", errors="ignore")
        if len(text_bytes) < self.seq_len:
            text_bytes = text_bytes + b" " * (self.seq_len - len(text_bytes))
        else:
            start = torch.randint(0, len(text_bytes) - self.seq_len + 1, (1,)).item()
            text_bytes = text_bytes[start:start+self.seq_len]
        byte_seq = list(text_bytes)
        hash_seq = byte_seq.copy()
        return {
            "byte_seq": torch.tensor(byte_seq, dtype=torch.long),
            "hash_seq": torch.tensor(hash_seq, dtype=torch.long)
        }

def collate_fn(batch):
    byte_seqs = torch.stack([item['byte_seq'] for item in batch])
    hash_seqs = torch.stack([item['hash_seq'] for item in batch])
    return byte_seqs, hash_seqs

def compute_baseline_loss(dataset, vocab_size=256, num_samples=100):
    """
    Computes the empirical cross-entropy (i.e. the entropy) of the token distribution
    over a subset of the dataset. This serves as a baseline loss for a next-token
    prediction task (if one predicted tokens according to the empirical distribution).
    """
    counts = np.zeros(vocab_size, dtype=np.float32)
    total_tokens = 0
    indices = np.random.choice(len(dataset), min(num_samples, len(dataset)), replace=False)
    for idx in indices:
        sample = dataset[int(idx)]   # Cast idx to int
        tokens = sample['byte_seq'].numpy()
        for t in tokens:
            counts[t] += 1
        total_tokens += len(tokens)
    probs = counts / total_tokens
    # Compute entropy: H = - sum(p * log(p)) for nonzero p
    baseline_loss = -np.sum(probs[probs > 0] * np.log(probs[probs > 0]))
    return baseline_loss



def train():
    byte_dim = 256         # Reduced from 128
    patch_dim = 512       # Reduced from 256
    vocab_size = 256      
    n_heads = 8           # Reduced from 8
    ff_dim = 2048          # Reduced from 2048
    n_encoder = 8         # Reduced from 6
    n_decoder = 8         # Reduced from 6
    n_global = 12          # Reduced from 12
    dropout = 0.1
    seq_len = 128        
    batch_size = 64       # Reduced batch size for a smaller dataset
    epochs = 100           # More epochs may help since the dataset is small
    lr = 1e-3             # A slightly higher LR for quicker convergence on a small dataset
    
    # Sampling hyperparameters
    temperature = 1.0  
    top_k = 40         
    top_p = 0.9        

    # Load Tiny Shakespeare dataset
    hf_dataset = load_dataset("tiny_shakespeare", split="train")
    dataset = WikiByteDataset(hf_dataset, seq_len=seq_len)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    
    # Compute the baseline loss from the token distribution (lower is better)
    baseline_loss = compute_baseline_loss(dataset, vocab_size=vocab_size, num_samples=200)
    print(f"Baseline (empirical token entropy) Loss: {baseline_loss:.4f}\n")
    
    # Instantiate the model (make sure ByteLatentTitan is defined)
    model = ByteLatentTitan(byte_dim, patch_dim, vocab_size, n_heads, ff_dim,
                            n_encoder, n_decoder, n_global, dropout)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    
    print("Starting autoregressive training with sampling strategies on Tiny Shakespeare...\n")
    for epoch in range(epochs):
        total_loss = 0.0
        model.train()
        total_params = sum(p.numel() for p in model.parameters())
        print(f"\nEpoch [{epoch+1}] Total Parameters: {total_params}")
        
        for byte_seq, hash_seq in tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}"):
            byte_seq = byte_seq.to(device)
            hash_seq = hash_seq.to(device)
            optimizer.zero_grad()
            
            output = model(byte_seq, hash_seq, patch_seq=None)
            logits = output[:, :-1, :]  # Predict tokens 1...end
            target = byte_seq[:, 1:]    # Ground truth shifted by one
            loss = F.cross_entropy(logits.reshape(-1, vocab_size), target.reshape(-1))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            
        avg_loss = total_loss / len(dataloader)
        print(f"Epoch [{epoch+1}] Average Loss: {avg_loss:.4f}")
        print(f"Baseline Loss: {baseline_loss:.4f}")
        
        # ----- Autoregressive Generation Sample -----
        model.eval()
        with torch.no_grad():
            sample = dataset[0]
            prompt_tokens = sample['byte_seq'][:seq_len//2].unsqueeze(0).to(device)
            generated = generate_text(model, prompt_tokens, max_length=300, device=device,
                                      temperature=temperature, top_k=top_k, top_p=top_p)
            generated_list = generated.squeeze(0).cpu().tolist()
            try:
                generated_text = bytes(generated_list).decode("utf-8", errors="replace")
            except Exception as e:
                generated_text = str(generated_list)
            prompt_text = bytes(prompt_tokens.squeeze(0).cpu().tolist()).decode("utf-8", errors="replace")
            print("\n--- Sample Generation ---")
            print("Prompt:   ", prompt_text)
            print("Generated:", generated_text)
    print("\nTraining complete.")

if __name__ == "__main__":
    train()


Baseline (empirical token entropy) Loss: 3.0743



RuntimeError: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset
from types import SimpleNamespace
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt

# -------------------- Basic Modules --------------------

class ByteEmbedding(nn.Module):
    def __init__(self, d_model, hash_size):
        super().__init__()
        self.byte_embed = nn.Embedding(256, d_model)
        self.hash_embed = nn.Embedding(hash_size, d_model)

    def forward(self, byte_seq, hash_seq):
        byte_embedding = self.byte_embed(byte_seq)
        hash_embedding = self.hash_embed(hash_seq)
        return byte_embedding + hash_embedding

class FeedForwardLayer(nn.Module):
    def __init__(self, d_model, ff_dim, dropout):
        super().__init__()
        self.layer1 = nn.Linear(d_model, ff_dim)
        self.layer2 = nn.Linear(ff_dim, d_model)
        self.silu = nn.SiLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.layer2(self.dropout(self.silu(self.layer1(x))))

class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_heads, ff_dim, dropout):
        super().__init__()
        self.attention = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.ff = FeedForwardLayer(d_model, ff_dim, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, attn_mask=None):
        attn_out, _ = self.attention(x, x, x, attn_mask=attn_mask)
        x = x + self.dropout(attn_out)
        x = self.norm1(x)
        ff_out = self.ff(x)
        x = x + self.dropout(ff_out)
        return self.norm2(x)

class CrossAttentionBlock(nn.Module):
    def __init__(self, query_dim, key_dim, n_heads, ff_dim, dropout):
        super().__init__()
        self.attention = nn.MultiheadAttention(query_dim, n_heads, dropout=dropout)
        self.norm = nn.LayerNorm(query_dim)
        self.ff = FeedForwardLayer(query_dim, ff_dim, dropout=dropout)
        self.dropout = nn.Dropout(dropout)
        self.query_proj = nn.Linear(query_dim, query_dim)
        self.key_proj = nn.Linear(key_dim, query_dim)
        self.value_proj = nn.Linear(key_dim, query_dim)

    def forward(self, query, key, value):
        query = self.query_proj(query).permute(1, 0, 2)
        key = self.key_proj(key).permute(1, 0, 2)
        value = self.value_proj(value).permute(1, 0, 2)
        attn_out, _ = self.attention(query, key, value)
        attn_out = attn_out.permute(1, 0, 2)
        query = query.permute(1, 0, 2)
        query = query + self.dropout(attn_out)
        query = self.norm(query)
        ff_out = self.ff(query)
        return query + self.dropout(ff_out)

class LocalEncoder(nn.Module):
    def __init__(self, byte_dim, n_heads, ff_dim, n_layers, dropout):
        super().__init__()
        self.layers = nn.ModuleList([TransformerBlock(byte_dim, n_heads, ff_dim, dropout) for _ in range(n_layers)])
        self.cross_attn = CrossAttentionBlock(query_dim=byte_dim, key_dim=byte_dim, n_heads=n_heads, ff_dim=ff_dim, dropout=dropout)

    def forward(self, byte_embeddings, patch_embeddings):
        for layer in self.layers:
            byte_embeddings = layer(byte_embeddings)
        patch_embedding = self.cross_attn(patch_embeddings, byte_embeddings, byte_embeddings)
        return patch_embedding

class LocalDecoder(nn.Module):
    def __init__(self, patch_dim, byte_dim, n_heads, ff_dim, n_layers, dropout):
        super().__init__()
        self.layers = nn.ModuleList([TransformerBlock(byte_dim, n_heads, ff_dim, dropout) for _ in range(n_layers)])
        self.cross_attn = CrossAttentionBlock(query_dim=byte_dim, key_dim=patch_dim, n_heads=n_heads, ff_dim=ff_dim, dropout=dropout)
        self.output_proj = nn.Linear(byte_dim, 256)

    def forward(self, patch_embedding, byte_embedding):
        byte_embedding = self.cross_attn(byte_embedding, patch_embedding, patch_embedding)
        for layer in self.layers:
            byte_embedding = layer(byte_embedding)
        return self.output_proj(byte_embedding)

def l2_loss(pred, target):
    return torch.sum((pred - target) ** 2)

# -------------------- Memory Modules --------------------

class TitanMemory(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.d_model = config.d_model
        self.register_buffer("M", torch.eye(config.d_model))
        self.register_buffer("S", torch.zeros(config.d_model, config.d_model))
        self.query = nn.Linear(config.d_model, config.d_model, bias=False)
        self.key = nn.Linear(config.d_model, config.d_model, bias=False)
        self.value = nn.Linear(config.d_model, config.d_model, bias=False)
        self.alpha = config.alpha
        self.eta = config.eta
        self.theta = config.theta

    def forward(self, x):
        q = self.query(x)
        y = torch.matmul(q, self.M)
        return y

    def update_memory(self, x):
        B = x.size(0)
        if B != 1:
            for i in range(B):
                self.update_memory(x[i:i+1])
            return
        k = self.key(x)
        v = self.value(x)
        v_pred = torch.matmul(k, self.M)
        loss = l2_loss(v_pred, v)
        error = v_pred - v
        g = 2 * torch.matmul(error.t(), k)
        self.S = self.eta * self.S - self.theta * g
        self.S = torch.clamp(self.S, -1e3, 1e3)
        self.M = (1 - self.alpha) * self.M + self.S
        self.M = torch.clamp(self.M, -1e3, 1e3)
        return loss

class SlidingWindowAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.d_model = config.d_model
        self.n_heads = config.n_heads
        self.window_size = config.window_size
        self.attention = nn.MultiheadAttention(embed_dim=config.d_model, num_heads=config.n_heads, batch_first=True)

    def forward(self, x):
        batch_size, seq_len, _ = x.size()
        output = []
        for i in range(0, seq_len, self.window_size):
            x_chunk = x[:, i:i+self.window_size, :]
            attn_out, _ = self.attention(x_chunk, x_chunk, x_chunk)
            output.append(attn_out)
        return torch.cat(output, dim=1)

class PersistentMemory(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.persistent = nn.Parameter(torch.randn(config.N_p, config.d_model))

    def forward(self, batch_size):
        return self.persistent.unsqueeze(0).expand(batch_size, -1, -1)

# -------------------- MoE Layer --------------------

class MoELayer(nn.Module):
    def __init__(self, d_model, num_experts=4, hidden_dim=None, dropout=0.1):
        super().__init__()
        if hidden_dim is None:
            hidden_dim = 4 * d_model
        self.experts = nn.ModuleList([FeedForwardLayer(d_model, hidden_dim, dropout) for _ in range(num_experts)])
        self.gate = nn.Linear(d_model, num_experts)

    def forward(self, x):
        # x: (batch, seq, d_model)
        gate_logits = self.gate(x)  # (batch, seq, num_experts)
        gate_probs = F.softmax(gate_logits, dim=-1)  # (batch, seq, num_experts)
        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=2)  # (batch, seq, num_experts, d_model)
        gate_probs = gate_probs.unsqueeze(-1)  # (batch, seq, num_experts, 1)
        moe_output = torch.sum(gate_probs * expert_outputs, dim=2)  # (batch, seq, d_model)
        return moe_output

# -------------------- TitanMAG with MoE Integration --------------------

class TitanMAG(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.d_model = config.d_model
        self.window_size = config.window_size
        self.long_memory = TitanMemory(config)
        self.attn_layers = nn.ModuleList([SlidingWindowAttention(config) for _ in range(config.n_layers)])
        self.persistent = PersistentMemory(config)
    def forward(self, x):
        batch_size, seq_len, d_model = x.size()
        x_flat = x.reshape(-1, d_model)
        with torch.no_grad():
            self.long_memory.update_memory(x_flat)
        persistent_tokens = self.persistent(batch_size)
        out = torch.cat([persistent_tokens, x], dim=1)
        for layer in self.attn_layers:
            out = layer(out)
        y = out
        out_flat = out.reshape(-1, self.d_model)
        long_term = self.long_memory(out_flat)
        long_term = long_term.reshape(batch_size, -1, d_model)
        output = y * long_term
        output = output[:, -seq_len:, :]
        return output

# -------------------- RMSNorm --------------------

class RMSNorm(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.eps = config.eps 
        self.weights = nn.Parameter(torch.ones(config.d_model))

    def forward(self, x):
        mean = torch.mean(x ** 2, dim=-1, keepdim=True)
        rms = torch.sqrt(mean + self.eps)
        return (x / rms) * self.weights

# -------------------- EncoderBlock with MoE --------------------

class EncoderBlock(nn.Module):
    def __init__(self, patch_dim, n_heads, ff_dim, n_layers, dropout):
        super().__init__()
        config = SimpleNamespace(
            d_model = patch_dim,
            n_heads = n_heads,
            ff_dim = ff_dim,
            dropout = dropout,
            window_size = 32,
            n_layers = 8,
            alpha = 0.1,
            eta = 0.01,
            theta = 0.01,
            N_p = 64,
            eps = 1e-5
        )
        self.Titan = TitanMAG(config)
        self.MoE = MoELayer(config.d_model, num_experts=8, hidden_dim=config.ff_dim, dropout=config.dropout)
        self.norm1 = RMSNorm(config)
        self.norm2 = RMSNorm(config)
        self.dropout = nn.Dropout(config.dropout)
        self.dropout2 = nn.Dropout(config.dropout)

    def forward(self, x, attn_mask=None):
        attn_out = x + self.Titan(self.norm1(self.dropout(x)))
        out = attn_out + self.MoE(self.norm2(self.dropout2(attn_out)))
        return out
                                  
        
# -------------------- LatentGlobalTransformer with MoE Integration --------------------

class LatentGlobalTransformer(nn.Module):
    def __init__(self, patch_dim, n_heads, ff_dim, n_layers, dropout):
        super().__init__()
        self.layers = nn.ModuleList([EncoderBlock(patch_dim, n_heads, ff_dim, n_layers=2, dropout=dropout) for _ in range(n_layers)])
    def forward(self, patches, attn_mask=None):
        for layer in self.layers:
            patches = layer(patches)
        return patches
        


# -------------------- ByteLatentTitan --------------------

class ByteLatentTitan(nn.Module):
    def __init__(self, byte_dim, patch_dim, vocab_size, n_heads, ff_dim, n_encoder, n_decoder, n_global, dropout=0.1):
        super().__init__()
        self.byte_embed = ByteEmbedding(byte_dim, vocab_size)
        self.local_encoder = LocalEncoder(byte_dim, n_heads, ff_dim, n_layers=n_encoder, dropout=dropout)
        self.global_transformer = LatentGlobalTransformer(patch_dim, n_heads, ff_dim, n_layers=n_global, dropout=dropout)
        self.local_decoder = LocalDecoder(patch_dim, byte_dim, n_heads, ff_dim, n_layers=n_decoder, dropout=dropout)
        self.projection = nn.Linear(byte_dim, patch_dim)

    def forward(self, byte_seq, hash_seq, patch_seq):
        byte_embeddings = self.byte_embed(byte_seq, hash_seq)
        if patch_seq is None:
            patch_embeddings = torch.mean(byte_embeddings, dim=1, keepdim=True)
            patch_embeddings = self.local_encoder(byte_embeddings, patch_embeddings)
            patch_embeddings = self.projection(patch_embeddings)
        else:
            patch_embeddings = patch_seq
        patch_embeddings = self.global_transformer(patch_embeddings)
        byte_output = self.local_decoder(patch_embeddings, byte_embeddings)
        return byte_output

# -------------------- Sampling Utilities --------------------

def sample_from_logits(logits, temperature=1.0, top_k=0, top_p=0.0):
    logits = logits / temperature
    if top_k > 0:
        values, _ = torch.topk(logits, top_k)
        min_value = values[-1]
        logits[logits < min_value] = -float('Inf')
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        probs = torch.softmax(sorted_logits, dim=-1)
        cumulative_probs = torch.cumsum(probs, dim=-1)
        sorted_indices_to_remove = cumulative_probs > top_p
        sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
        sorted_indices_to_remove[0] = 0
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = -float('Inf')
    probs = torch.softmax(logits, dim=-1)
    next_token = torch.multinomial(probs, 1)
    return next_token

def generate_text(model, prompt, max_length, device, temperature=1.0, top_k=0, top_p=0.0):
    model.eval()
    generated = prompt.clone()
    for _ in range(max_length - prompt.size(1)):
        hash_seq = generated.clone()
        with torch.no_grad():
            output = model(generated, hash_seq, patch_seq=None)
        next_logits = output[:, -1, :]  # (1, vocab_size)
        next_token = sample_from_logits(next_logits.squeeze(0),
                                        temperature=temperature,
                                        top_k=top_k,
                                        top_p=top_p)
        next_token = next_token.unsqueeze(0)
        generated = torch.cat([generated, next_token], dim=1)
    return generated

# -------------------- Dataset & Collate --------------------

class WikiByteDataset(Dataset):
    def __init__(self, hf_dataset, seq_len=128):
        self.dataset = hf_dataset
        self.seq_len = seq_len

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        text = self.dataset[idx]['text']
        text_bytes = text.encode("utf-8", errors="ignore")
        if len(text_bytes) < self.seq_len:
            text_bytes = text_bytes + b" " * (self.seq_len - len(text_bytes))
        else:
            start = torch.randint(0, len(text_bytes) - self.seq_len + 1, (1,)).item()
            text_bytes = text_bytes[start:start+self.seq_len]
        byte_seq = list(text_bytes)
        hash_seq = byte_seq.copy()
        return {"byte_seq": torch.tensor(byte_seq, dtype=torch.long),
                "hash_seq": torch.tensor(hash_seq, dtype=torch.long)}

def collate_fn(batch):
    byte_seqs = torch.stack([item['byte_seq'] for item in batch])
    hash_seqs = torch.stack([item['hash_seq'] for item in batch])
    return byte_seqs, hash_seqs

# -------------------- Baseline Evaluation --------------------

def compute_baseline_loss(dataset, vocab_size=256, num_samples=100):
    counts = np.zeros(vocab_size, dtype=np.float32)
    total_tokens = 0
    indices = np.random.choice(len(dataset), min(num_samples, len(dataset)), replace=False)
    for idx in indices:
        sample = dataset[int(idx)]
        tokens = sample['byte_seq'].numpy()
        for t in tokens:
            counts[t] += 1
        total_tokens += len(tokens)
    probs = counts / total_tokens
    baseline_loss = -np.sum(probs[probs > 0] * np.log(probs[probs > 0]))
    return baseline_loss

def evaluate_generation_loss(model, sample, device):
    # Evaluate a held-out sample using teacher forcing:
    model.eval()
    seq = sample['byte_seq'].unsqueeze(0).to(device)  # shape (1, L)
    L = seq.size(1)
    prompt_length = L // 2
    with torch.no_grad():
        output = model(seq, seq, patch_seq=None)  # shape (1, L, vocab_size)
    logits = output[:, prompt_length:-1, :]  # predictions for positions prompt_length+1 to L
    target = seq[:, prompt_length+1:]
    loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), target.reshape(-1))
    return loss.item()

# -------------------- Checkpoint Functions --------------------

def save_checkpoint(model, optimizer, epoch, checkpoint_path="checkpoint.pt"):
    checkpoint = {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict()
    }
    torch.save(checkpoint, checkpoint_path)
    print(f"Checkpoint saved at epoch {epoch}.")

def load_checkpoint(model, optimizer, checkpoint_path="checkpoint.pt"):
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        start_epoch = checkpoint["epoch"] + 1
        print(f"Checkpoint found. Resuming training from epoch {start_epoch}.")
        return start_epoch
    else:
        print("No checkpoint found. Training from scratch.")
        return 0

# -------------------- Training Loop --------------------

def train():
    import os
    # Hyperparameters adjusted for a larger dataset
    byte_dim = 64         
    patch_dim = 128       
    vocab_size = 256      
    n_heads = 8           
    ff_dim = 512         
    n_encoder = 4         
    n_decoder = 4         
    n_global = 8          
    dropout = 0.1
    seq_len = 256        
    batch_size = 128       
    epochs = 10           
    lr = 1e-5            
    
    # Sampling hyperparameters
    temperature = 1.0  
    top_k = 40         
    top_p = 0.9        

    # Use a larger dataset: 10% slice of English Wikipedia (20220301)
    hf_dataset = load_dataset("wikipedia", "20220301.en", split="train[:1%]")
    dataset = WikiByteDataset(hf_dataset, seq_len=seq_len)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    
    baseline_loss = compute_baseline_loss(dataset, vocab_size=vocab_size, num_samples=200)
    print(f"Baseline (empirical token entropy) Loss: {baseline_loss:.4f}\n")
    
    model = ByteLatentTitan(byte_dim, patch_dim, vocab_size, n_heads, ff_dim,
                              n_encoder, n_decoder, n_global, dropout)
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    
    # Load checkpoint if available
    checkpoint_path = "checkpoint.pt"
    start_epoch = 0 #load_checkpoint(model, optimizer, checkpoint_path)
    
    train_losses = []
    gen_losses = []
    
    print("Starting autoregressive training with MoE and sampling strategies on Wikipedia...\n")
    for epoch in range(start_epoch, epochs):
        total_loss = 0.0
        model.train()
        total_params = sum(p.numel() for p in model.parameters())
        print(f"\nEpoch [{epoch+1}] Total Parameters: {total_params}")
        
        for byte_seq, hash_seq in tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}"):
            byte_seq = byte_seq.to(device)
            hash_seq = hash_seq.to(device)
            optimizer.zero_grad()
            
            output = model(byte_seq, hash_seq, patch_seq=None)
            logits = output[:, :-1, :]  # Predict tokens 1...end
            target = byte_seq[:, 1:]    # Ground truth shifted by one
            loss = F.cross_entropy(logits.reshape(-1, vocab_size), target.reshape(-1))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            
        avg_loss = total_loss / len(dataloader)
        train_losses.append(avg_loss)
        print(f"Epoch [{epoch+1}] Average Training Loss: {avg_loss:.4f}")
        print(f"Baseline Loss: {baseline_loss:.4f}")
        
        # Evaluate generation loss on one held-out sample
        sample = dataset[0]
        gen_loss = evaluate_generation_loss(model, sample, device)
        gen_losses.append(gen_loss)
        print(f"Epoch [{epoch+1}] Generation Loss: {gen_loss:.4f}")
        
        # Generate sample text
        model.eval()
        with torch.no_grad():
            prompt_tokens = sample['byte_seq'][:seq_len//2].unsqueeze(0).to(device)
            generated = generate_text(model, prompt_tokens, max_length=seq_len, device=device,
                                      temperature=temperature, top_k=top_k, top_p=top_p)
            generated_list = generated.squeeze(0).cpu().tolist()
            try:
                generated_text = bytes(generated_list).decode("utf-8", errors="replace")
            except Exception as e:
                generated_text = str(generated_list)
            prompt_text = bytes(prompt_tokens.squeeze(0).cpu().tolist()).decode("utf-8", errors="replace")
            print("\n--- Sample Generation ---")
            print("Prompt:   ", prompt_text)
            print("Generated:", generated_text)
        
        # Save checkpoint after each epoch
        #save_checkpoint(model, optimizer, epoch, checkpoint_path)
    
    print("\nTraining complete.")
    epochs_range = range(start_epoch+1, epochs+1)
    plt.figure(figsize=(10, 6))
    plt.plot(epochs_range, train_losses, label="Training Loss")
    plt.plot(epochs_range, gen_losses, label="Generation Loss", linestyle="--")
    plt.axhline(baseline_loss, color='r', linestyle=':', label="Baseline Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title("Training and Generation Loss Curves")
    plt.legend()
    plt.grid(True)
    plt.show()

if __name__ == "__main__":
    train()


Baseline (empirical token entropy) Loss: 3.3112

Starting autoregressive training with MoE and sampling strategies on Wikipedia...


Epoch [1] Total Parameters: 14046400


Epoch 1/10: 100%|█████████████████████████████| 505/505 [06:21<00:00,  1.32it/s]


Epoch [1] Average Training Loss: 4.6763
Baseline Loss: 3.3112
Epoch [1] Generation Loss: 4.2287

--- Sample Generation ---
Prompt:    on organisational and economic aspects of their ideal society.

Mutualism is an 18th-century economic theory that was developed 
Generated: on organisational and economic aspects of their ideal society.

Mutualism is an 18th-century economic theory that was developed tr �  �.� at-��e�ee o� an�i�rr.&� V d �n �ne te � Iio pa�̑uer or��r�ter �er iu�n��n �&aeuo.d� ir� draorucat�h1D0e ̷rd r-�h- 

Epoch [2] Total Parameters: 14046400


Epoch 2/10: 100%|█████████████████████████████| 505/505 [06:20<00:00,  1.33it/s]


Epoch [2] Average Training Loss: 4.0377
Baseline Loss: 3.3112
Epoch [2] Generation Loss: 3.7122

--- Sample Generation ---
Prompt:     of all forms of domination and hierarchy.

Tactics 
Anarchists' tactics take various forms but in general serve two major goals
Generated:  of all forms of domination and hierarchy.

Tactics 
Anarchists' tactics take various forms but in general serve two major goalse �a�o�n 2 amh, aoat. ur.iui
 ith.nn�haa.or�ma rdh,u iunes pundei Te i.sns th� onn., te n�
Sal htht�red h   e ou
 ore iwnn
 ar

Epoch [3] Total Parameters: 14046400


Epoch 3/10:  46%|█████████████▍               | 234/505 [02:56<03:24,  1.32it/s]