In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as f
from dataclasses import dataclass

In [2]:
class ByteEmbedding(nn.Module):
    def __init__(self,d_model,hash_size):
        super().__init__()
        self.byte_embed = nn.Embedding(256,d_model)
        self.hash_embed = nn.Embedding(hash_size,d_model)

    def forward(self,byte_seq,hash_seq):
        byte_embedding = self.byte_embed(byte_seq)
        hash_embedding = self.hash_embed(hash_seq)
        return byte_embedding + hash_embedding

In [3]:
class FeedForwardLayer(nn.Module):
    def __init__(self,d_model,ff_dim,dropout):
        super().__init__()
        self.layer1 = nn.Linear(d_model,ff_dim)
        self.layer2 = nn.Linear(ff_dim,d_model)
        self.gelu = nn.GELU()
        self.dropout = nn.Dropout(dropout)

    def forward(self,x):
        return self.layer2(self.dropout(self.gelu(self.layer1(x))))

In [4]:
class TransformerBlock(nn.Module):
    def __init__(self,d_model,n_heads,ff_dim,dropout):
        super().__init__()
        self.attention = nn.MultiheadAttention(d_model,n_heads,dropout=dropout)
        self.ff = FeedForwardLayer(d_model,ff_dim,dropout=dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self,x,attn_mask = None):
        attn_out,_ = self.attention(x,x,x,attn_mask=attn_mask) 
        x = x + self.dropout(attn_out)
        x = self.norm1(x)
        ff_out = self.ff(x)
        x = x + self.dropout(attn_out)
        return self.norm2(x)

In [5]:
class CrossAttentionBlock(nn.Module):
    def __init__(self,query_dim,key_dim,n_heads,ff_dim,dropout):
        super().__init__()
        self.attention = nn.MultiheadAttention(query_dim,n_heads,dropout=dropout)
        self.norm = nn.LayerNorm(query_dim)
        self.ff = FeedForwardLayer(query_dim,ff_dim,dropout = dropout)
        self.dropout = nn.Dropout(dropout)
        self.query_proj = nn.Linear(query_dim,query_dim)
        self.key_proj = nn.Linear(key_dim,query_dim)
        self.value_proj = nn.Linear(key_dim,query_dim)

    def forward(self,query,key,value):
        query = self.query_proj(query).permute(1,0,2)
        key = self.key_proj(key).permute(1,0,2)
        value = self.value_proj(value).permute(1,0,2)

        attn_out , _ = self.attention(query,key,value)
        attn_out = attn_out.permute(1,0,2)
        query = query.permute(1,0,2)
        query = query + self.dropout(attn_out)
        query = self.norm(query)
        ff_out = self.ff(query)
        return query + self.dropout(ff_out)

In [6]:
class LocalEncoder(nn.Module):
    def __init__(self,byte_dim,n_heads,ff_dim,n_layers,dropout):
        super().__init__()
        self.layers = nn.ModuleList([TransformerBlock(byte_dim,n_heads,ff_dim,dropout) for _ in range(n_layers)])
        self.cross_attn = CrossAttentionBlock(query_dim=byte_dim,key_dim=byte_dim,n_heads=n_heads,ff_dim=ff_dim,dropout=dropout)

    def forward(self,byte_embeddings,patch_embeddings):
        for layer in self.layers:
            byte_embeddings = layer(byte_embeddings)
        patch_embedding = self.cross_attn(patch_embeddings,byte_embeddings,byte_embeddings)
        return patch_embedding
        

In [7]:
class LocalDecoder(nn.Module):
    def __init__(self,patch_dim,byte_dim,n_heads,ff_dim,n_layers,dropout):
        super().__init__()
        self.layers = nn.ModuleList([TransformerBlock(byte_dim,n_heads,ff_dim,dropout) for _ in range(n_layers)])
        self.cross_attn = CrossAttentionBlock(query_dim=byte_dim,key_dim=patch_dim,n_heads=n_heads,ff_dim=ff_dim,dropout=dropout)
        self.output_proj = nn.Linear(byte_dim,256)

    def forward(self,patch_embedding,byte_embedding):
        byte_embedding = self.cross_attn(byte_embedding,patch_embedding,patch_embedding)
        for layer in self.layers:
            byte_embedding = layer(byte_embedding)
        return self.output_proj(byte_embedding)

In [8]:
def l2_loss(pred,target):
    return torch.sum((pred - target) ** 2)

In [9]:
class TitanMemory(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.d_model = config.d_model
        self.register_buffer("M",torch.eye(config.d_model))
        self.register_buffer("S",torch.zeros(config.d_model,config.d_model))

        self.query = nn.Linear(config.d_model,config.d_model,bias=False)
        self.key = nn.Linear(config.d_model,config.d_model,bias=False)
        self.value = nn.Linear(config.d_model,config.d_model,bias=False)

        self.alpha = config.alpha
        self.eta = config.eta
        self.theta = config.theta

    def forward(self,x):
        q = self.query(x)
        y = torch.matmul(q,self.M)
        return y

    def update_memory(self,x):
        B = x.size(0)
        if B != 1:
            for i in range(B):
                self.update_memory(x[i:i+1])
            return

        k = self.key(x)
        v = self.value(x)

        v_pred = torch.matmul(k,self.M)
        loss = l2_loss(v_pred,v)
        error = v_pred - v

        g = 2 * torch.matmul(error.t(),k)

        self.S = self.eta * self.S - self.theta * g
        self.S = torch.clamp(self.S, -1e3, 1e3)
        self.M = (1-self.alpha) * self.M + self.S
        self.M = torch.clamp(self.M, -1e3, 1e3)
        return loss

In [10]:
class SlidingWindowAttention(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.d_model = config.d_model
        self.n_heads = config.n_heads
        self.window_size = config.window_size
        self.attention = nn.MultiheadAttention(embed_dim = config.d_model, num_heads = config.n_heads, batch_first= True)

    def forward(self,x):
        batch_size,seq_len,_ = x.size()
        output = []

        for i in range(0,seq_len,self.window_size):
            x_chunk = x[:,i:i+self.window_size,:]
            attn_out,_ = self.attention(x_chunk,x_chunk,x_chunk)
            output.append(attn_out)
        return torch.cat(output,dim=1)
        

In [11]:
class PersistentMemory(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.persistent = nn.Parameter(torch.randn(config.N_p,config.d_model))

    def forward(self,batch_size):
        return self.persistent.unsqueeze(0).expand(batch_size,-1,-1)

In [12]:
class TitanMAG(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.d_model = config.d_model
        self.window_size = config.window_size
        self.long_memory = TitanMemory(config)
        self.attn_layers = nn.ModuleList([SlidingWindowAttention(config) for _ in range(config.n_layers)])
        self.persistent = PersistentMemory(config)

    def forward(self,x):
        batch_size,seq_len,d_model = x.size()

        x_flat = x.reshape(-1,d_model)
        with torch.no_grad():
            self.long_memory.update_memory(x_flat)
        
        persistent_tokens = self.persistent(batch_size)
        out = torch.cat([persistent_tokens,x],dim=1)

        for layer in self.attn_layers:
            out = layer(out)
        y = out
        out_flat = out.reshape(-1,self.d_model)
        long_term = self.long_memory(out_flat)
        long_term = long_term.reshape(batch_size,-1,d_model)

        output = y * long_term
        output = output[:,-seq_len:,:]
        return output

In [13]:
class LatentGlobalTransformer(nn.Module):
    def __init__(self,patch_dim,n_heads,ff_dim,n_layers,dropout):
        super().__init__()
        self.layers = nn.ModuleList([TitanMAG(patch_dim,n_heads,ff_dim,dropout) for _ in range(n_layers)])
    def forward(self,patches,attn_mask=None):
        for layer in self.layers:
            patches = layer(patches,attn_mask=attn_mask)
        return patches

In [14]:
class ByteLatentTransformer(nn.Module):
    def __init__(self,byte_dim,patch_dim,vocab_size,n_heads,ff_dim,n_encoder,n_decoder,n_global,dropout=0.1):
        super().__init__()
        self.byte_embed = ByteEmbedding(byte_dim,vocab_size)
        self.local_encoder = LocalEncoder(byte_dim,n_heads,ff_dim,n_layers=n_encoder,dropout=dropout)
        self.global_transformer = LatentGlobalTransformer(patch_dim,n_heads,ff_dim,n_layers=n_global,dropout=dropout)
        self.local_decoder = LocalDecoder(patch_dim,byte_dim,n_heads,ff_dim,n_decoder,dropout=dropout) 
        self.projection = nn.Linear(byte_dim,patch_dim)

    def forward(self,byte_seq,hash_seq,patch_seq):
        byte_embeddings = self.byte_embed(byte_seq,hash_seq)
        if patch_seq is None:
            patch_embeddings = torch.mean(byte_embeddings,dim=1,keepdim=True)
            patch_embeddings = self.local_encoder(byte_embeddings,patch_embeddings)
            patch_embeddings = self.projection(patch_embeddings)
        else:
            patch_embeddings = patch_seq
        patch_embeddings = self.global_transformer(patch_embeddings)
        byte_output = self.local_decoder(patch_embeddings,byte_embeddings)
        return byte_output

In [15]:
from types import SimpleNamespace
class LatentGlobalTransformer(nn.Module):
    def __init__(self, patch_dim, n_heads, ff_dim, n_layers, dropout):
        super().__init__()
        config = SimpleNamespace(
            d_model = patch_dim,
            n_heads = n_heads,
            ff_dim = ff_dim,
            dropout = dropout,
            window_size = 16,  
            n_layers = 2,      
            alpha = 0.1,
            eta = 0.01,
            theta = 0.01,
            N_p = 10
        )
        self.layers = nn.ModuleList([TitanMAG(config) for _ in range(n_layers)])
    def forward(self, patches, attn_mask=None):
        for layer in self.layers:
            patches = layer(patches)
        return patches

class ByteLatentTitan(nn.Module):
    def __init__(self, byte_dim, patch_dim, vocab_size, n_heads, ff_dim, n_encoder, n_decoder, n_global, dropout=0.1):
        super().__init__()
        self.byte_embed = ByteEmbedding(byte_dim, vocab_size)
        self.local_encoder = LocalEncoder(byte_dim, n_heads, ff_dim, n_layers=n_encoder, dropout=dropout)
        self.global_transformer = LatentGlobalTransformer(patch_dim, n_heads, ff_dim, n_layers=n_global, dropout=dropout)
        self.local_decoder = LocalDecoder(patch_dim, byte_dim, n_heads, ff_dim, n_layers=n_decoder, dropout=dropout)
        self.projection = nn.Linear(byte_dim, patch_dim)

    def forward(self, byte_seq, hash_seq, patch_seq):
        byte_embeddings = self.byte_embed(byte_seq, hash_seq)
        if patch_seq is None:
            patch_embeddings = torch.mean(byte_embeddings, dim=1, keepdim=True)
            patch_embeddings = self.local_encoder(byte_embeddings, patch_embeddings)
            patch_embeddings = self.projection(patch_embeddings)
        else:
            patch_embeddings = patch_seq
        patch_embeddings = self.global_transformer(patch_embeddings)
        byte_output = self.local_decoder(patch_embeddings, byte_embeddings)
        return byte_output

if __name__ == "__main__":
    # Define model hyperparameters
    byte_dim = 64
    patch_dim = 128
    vocab_size = 256  # both for byte and hash embedding
    n_heads = 8
    ff_dim = 256
    n_encoder = 2
    n_decoder = 2
    n_global = 2
    dropout = 0.1

    # Instantiate the model
    model = ByteLatentTitan(byte_dim, patch_dim, vocab_size, n_heads, ff_dim, n_encoder, n_decoder, n_global, dropout)
    
    # Create dummy inputs
    batch_size = 2
    seq_len = 128
    # byte_seq: integers in [0, 256)
    byte_seq = torch.randint(0, 256, (batch_size, seq_len))
    # hash_seq: integers in [0, 256)
    hash_seq = torch.randint(0, 256, (batch_size, seq_len))
    # We set patch_seq to None so the model computes its own patch embeddings.
    patch_seq = None

    # Run a forward pass
    output = model(byte_seq, hash_seq, patch_seq)
    
    print("Output shape:", output.shape)


Output shape: torch.Size([2, 128, 256])


In [16]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset
from types import SimpleNamespace
from tqdm import tqdm

# ===================== Sampling Utilities =====================

def sample_from_logits(logits, temperature=1.0, top_k=0, top_p=0.0):
    """
    Given a 1D tensor of logits, apply temperature scaling,
    then filter using top-k and/or nucleus (top-p) sampling, and sample one token.
    """
    # Temperature scaling
    logits = logits / temperature
    
    # Top-k filtering
    if top_k > 0:
        values, _ = torch.topk(logits, top_k)
        min_value = values[-1]
        logits[logits < min_value] = -float('Inf')
    
    # Nucleus (top-p) filtering
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        probs = torch.softmax(sorted_logits, dim=-1)
        cumulative_probs = torch.cumsum(probs, dim=-1)
        
        # Remove tokens with cumulative probability above threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the mask to keep at least one token
        sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
        sorted_indices_to_remove[0] = 0
        
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = -float('Inf')
    
    # Convert to probabilities and sample
    probs = torch.softmax(logits, dim=-1)
    next_token = torch.multinomial(probs, 1)
    return next_token

def generate_text(model, prompt, max_length, device, temperature=1.0, top_k=0, top_p=0.0):
    model.eval()
    generated = prompt.clone()
    for _ in range(max_length - prompt.size(1)):
        # In our model, we use the same sequence for both byte_seq and hash_seq.
        hash_seq = generated.clone()
        with torch.no_grad():
            output = model(generated, hash_seq, patch_seq=None)
        # Get logits for the last time step. (Assuming batch size 1)
        next_logits = output[:, -1, :]  # shape: (1, vocab_size)
        # Squeeze to get a 1D tensor of logits.
        next_token = sample_from_logits(next_logits.squeeze(0),
                                        temperature=temperature,
                                        top_k=top_k,
                                        top_p=top_p)
        # Append the sampled token.
        next_token = next_token.unsqueeze(0)
        generated = torch.cat([generated, next_token], dim=1)
    return generated

# ===================== WikiByteDataset Definition =====================

class WikiByteDataset(Dataset):
    """
    Converts Wikipedia text samples to fixed-length byte sequences.
    """
    def __init__(self, hf_dataset, seq_len=128):
        self.dataset = hf_dataset
        self.seq_len = seq_len

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        text = self.dataset[idx]['text']
        # Convert text to bytes (ignoring errors)
        text_bytes = text.encode("utf-8", errors="ignore")
        # If the text is shorter than seq_len, pad with spaces (ASCII 32)
        if len(text_bytes) < self.seq_len:
            text_bytes = text_bytes + b" " * (self.seq_len - len(text_bytes))
        else:
            # Randomly select a contiguous segment of length seq_len
            start = torch.randint(0, len(text_bytes) - self.seq_len + 1, (1,)).item()
            text_bytes = text_bytes[start:start+self.seq_len]
        byte_seq = list(text_bytes)
        # For this example, we use the same sequence as a dummy hash sequence.
        hash_seq = byte_seq.copy()
        return {
            "byte_seq": torch.tensor(byte_seq, dtype=torch.long),
            "hash_seq": torch.tensor(hash_seq, dtype=torch.long)
        }

def collate_fn(batch):
    byte_seqs = torch.stack([item['byte_seq'] for item in batch])
    hash_seqs = torch.stack([item['hash_seq'] for item in batch])
    return byte_seqs, hash_seqs


def train():
    # Model/training hyperparameters
    byte_dim = 128
    patch_dim = 256
    vocab_size = 256      # Bytes: 0-255
    n_heads = 8
    ff_dim = 2048
    n_encoder = 6
    n_decoder = 6
    n_global = 12
    dropout = 0.1
    seq_len = 128        # Fixed sequence length (in bytes)
    batch_size = 256
    epochs = 6
    lr = 1e-4

    # Sampling hyperparameters
    temperature = 1.2  # Increase temperature for more diversity.
    top_k = 40         # Consider only top 40 tokens.
    top_p = 0.9        # Nucleus sampling threshold.

    # Load a subset of Wikipedia.
    hf_dataset = load_dataset("wikipedia", "20220301.en", split="train[:5%]")
    dataset = WikiByteDataset(hf_dataset, seq_len=seq_len)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

    # Instantiate your model. (Make sure ByteLatentTitan and its dependencies are defined/imported.)
    model = ByteLatentTitan(byte_dim, patch_dim, vocab_size, n_heads, ff_dim,
                            n_encoder, n_decoder, n_global, dropout)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr)

    print("Starting autoregressive training with sampling strategies...")
    for epoch in range(epochs):
        total_loss = 0.0
        model.train()
        # Print total parameter count for the model at the beginning of the epoch.
        total_params = sum(p.numel() for p in model.parameters())
        print(f"\nEpoch [{epoch+1}] Total Parameters: {total_params}")
        
        for byte_seq, hash_seq in tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}"):
            byte_seq = byte_seq.to(device)
            hash_seq = hash_seq.to(device)
            optimizer.zero_grad()
            
            # Forward pass. (Our autoregressive objective: predict next token)
            output = model(byte_seq, hash_seq, patch_seq=None)
            # Shift outputs and targets by one.
            logits = output[:, :-1, :]  # Predictions for positions 1 ... end.
            target = byte_seq[:, 1:]    # Ground truth tokens (shifted by one).
            loss = F.cross_entropy(logits.reshape(-1, vocab_size), target.reshape(-1))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            
        avg_loss = total_loss / len(dataloader)
        print(f"Epoch [{epoch+1}] Average Loss: {avg_loss:.4f}")

        # ----- Autoregressive Generation Sample -----
        model.eval()
        with torch.no_grad():
            sample = dataset[0]
            prompt_tokens = sample['byte_seq'][:seq_len//2].unsqueeze(0).to(device)
            generated = generate_text(model, prompt_tokens, max_length=300, device=device,
                                      temperature=temperature, top_k=top_k, top_p=top_p)
            generated_list = generated.squeeze(0).cpu().tolist()
            try:
                generated_text = bytes(generated_list).decode("utf-8", errors="replace")
            except Exception as e:
                generated_text = str(generated_list)
            prompt_text = bytes(prompt_tokens.squeeze(0).cpu().tolist()).decode("utf-8", errors="replace")
            print("\n--- Sample Generation ---")
            print("Prompt:   ", prompt_text)
            print("Generated:", generated_text)
    print("Training complete.")

if __name__ == "__main__":
    train()


Starting autoregressive training with sampling strategies...

Epoch [1] Total Parameters: 17271296


Epoch 1/6: 100%|████████████████████████████| 1262/1262 [26:39<00:00,  1.27s/it]


Epoch [1] Average Loss: 2.8119

--- Sample Generation ---
Prompt:    World Economic Forum. During the protests, ad hoc leaderless ano
Generated: World Economic Forum. During the protests, ad hoc leaderless anola
 ily unga�f�daniaf"StuV70oLi

 n)mgovea wEd1�c.
.te igfUFohcs. omcDoOc�eorftrdrpo thine,kcc4%8alSlncs af es
ng wmon.dd an o)lyrcoxgayssipedz�rc..7t
nfSkyd iganbe  a di):om,x�a aknfceswg

kssft-hoOk.
.4lycL.k
 rto.8ar. s2u."T st th lm

Epoch [2] Total Parameters: 17271296


Epoch 2/6: 100%|████████████████████████████| 1262/1262 [26:42<00:00,  1.27s/it]


Epoch [2] Average Loss: 2.6164

--- Sample Generation ---
Prompt:    o play the role of facilitator to help achieve a consensus witho
Generated: o play the role of facilitator to help achieve a consensus withoh,1tmhll tmibu,'Prytitcck.o:apouo.c
nofeafetsthmmus.a
  in  oby2ks,rdinc rg d  st a rtms,beamafrdan,' ed ,xds
  ldonf",nnchlies)1u�"6�dgyf hugire wshicc ltasiperffot.
  a;�
Ss,  idi.6T uwni oggamein-ssogfsthybui)bch,q
 shis,-z
Nemoe om

Epoch [3] Total Parameters: 17271296


Epoch 3/6: 100%|████████████████████████████| 1262/1262 [26:54<00:00,  1.28s/it]


Epoch [3] Average Loss: 2.6084

--- Sample Generation ---
Prompt:    . Anarchists usually form small groups (5–20 individuals) to e
Generated: . Anarchists usually form small groups (5–20 individuals) to e-on-e ngob lrws
Ky
19 ppplwanizcamagatimady-ro usefrycelubem- t ave ifommugmbo so astrodu an a.r asyeby,7as ondcliarty-gak ind wb o adtir ad
ss an urcrats

K c
Scr a,2265chap.2 ivir
 t  whn
oery m, nod tald illad.crpo a ba,/k"�rstnc,d  

Epoch [4] Total Parameters: 17271296


Epoch 4/6: 100%|████████████████████████████| 1262/1262 [27:00<00:00,  1.28s/it]


Epoch [4] Average Loss: 2.6047

--- Sample Generation ---
Prompt:    s of their group without the need of a leader or a leading group
Generated: s of their group without the need of a leader or a leading groupwsavlyctwowgtrontaue, e adinravdlembn,'s  ebe asb ie,2,�
Mo mi l mancujuy itedmo rak"alunl ical  lw t,  ofuosugiogncub t.
thttee,bvems
 pap,kmaiove-id spis
S,:  lls rw rd.
2dsudvedvonas.atlu,;�Theces s chicubo on tuud  e d-vsopidrnnd pp

Epoch [5] Total Parameters: 17271296


Epoch 5/6:   5%|█▌                            | 64/1262 [01:22<25:48,  1.29s/it]


KeyboardInterrupt: 