# Data

In [46]:
import pandas as pd

train_path = "./../data/train.csv"
test_path = "./../data/test.csv"

train_df = pd.read_csv(train_path)
test_df = pd.read_csv(test_path)

print("Train shape:", train_df.shape)
print("Test shape:", test_df.shape)

print("\nTrain head:")
print(train_df.head())

print("\nTest head:")
print(test_df.head())

Train shape: (10000, 2)
Test shape: (2000, 2)

Train head:
                                               input  \
0               reconciliation trolls realized scene   
1                        scratched kemp blah devices   
2  delusional engineered perfect prey englishman ...   
3  boomers nfl reacts parallels everything 6 redu...   
4  patience put christmas superhero luc rake fulf...   

                                              target  
0               enecs dezilaer sllort noitailicnocer  
1                        secived halb pmek dehctarcs  
2  hctarcs detsub namhsilgne yerp tcefrep dereeni...  
3  stcudnoc ysereh redlof secuder 6 gnihtyreve sl...  
4  ylesned ylno elbats latnenitnoc dellifluf ekar...  

Test head:
                                               input  \
0  intimidated campaigns emerging marines spin be...   
1            salary lebanese wifi fury fab sta polly   
2  financing ahmed sexual cinematic puff malibu p...   
3  n00 nickel disparity funded tutoria

# BLT Transformer

In [47]:
import torch

if torch.backends.mps.is_available():
    device = torch.device("mps")
    print("✅ Using MPS (Apple Silicon GPU)")
else:
    device = torch.device("cpu")
    print("⚠️ MPS not available, falling back to CPU")

✅ Using MPS (Apple Silicon GPU)


## Patcher (entropy-based segmentation)

**Shannon Entropy Function**

Helper to compute entropy of a sequence of characters.

H = - \sum p(x) \cdot \log_2(p(x))

In [48]:
import math
from collections import Counter

def shannon_entropy(text: str) -> float:
    """Compute Shannon entropy of a string."""
    if not text:
        return 0.0
    counts = Counter(text)
    probs = [count / len(text) for count in counts.values()]
    return -sum(p * math.log2(p) for p in probs)

**Patcher Function**

- Use sliding window of size W=10.
- Keep adding characters to current patch until either:
    - Entropy > threshold(2.0)
    - Patch length > 15
- Then start a new patch.

In [49]:
def patchify(text: str, window_size=10, entropy_threshold=2.0, max_patch_len=15):
    patches = []
    current_patch = ""

    for ch in text:
        current_patch += ch

        # Compute entropy only when window_size reached
        entropy = (
            shannon_entropy(current_patch[-window_size:])
            if len(current_patch) >= window_size else 0
        )

        # Split condition: high entropy OR too long
        if entropy > entropy_threshold or len(current_patch) >= max_patch_len:
            if current_patch.strip():
                patches.append(current_patch.strip())
            current_patch = ""

    # Add leftover patch
    if current_patch.strip():
        patches.append(current_patch.strip())

    return patches

**Test Patcher**

In [50]:
sample_texts = [
    "reconciliation trolls realized scene", # High entropy, more splits
    "LMA is fun!", # Based on threshold, may not split
    "aaaaabbbbbcccccddddd"  # low entropy predictable
]

for txt in sample_texts:
    patches = patchify(txt)
    print(f"\nText: {txt}")
    print("Patches:", patches)


Text: reconciliation trolls realized scene
Patches: ['reconcilia', 'tion troll', 's realized', 'scene']

Text: LMA is fun!
Patches: ['LMA is fun', '!']

Text: aaaaabbbbbcccccddddd
Patches: ['aaaaabbbbbccccc', 'ddddd']


## Hash N-Gram Embeddings

1. Extract all n-grams (n=1,2,3) from each patch.
2. Map each n-gram into a bucket in [0, 4095].
3. Use an embedding lookup table (nn.Embedding) to get a 64-d vector.
4. Sum all vectors → final patch embedding (shape = [64]).

**Hash Function**

In [51]:
import hashlib

def hash_ngram(ngram: str, num_buckets=4096) -> int:
    """Hash an n-gram string into a bucket [0, num_buckets-1]."""
    return int(hashlib.md5(ngram.encode("utf-8")).hexdigest(), 16) % num_buckets

**N-Gram Extraction**

In [52]:
def extract_ngrams(text: str, n: int):
    return [text[i:i+n] for i in range(len(text)-n+1)]

**Patch Embedding Module**

In [53]:
import torch
import torch.nn as nn

class PatchEmbedder(nn.Module):
    def __init__(self, num_buckets=4096, embed_dim=64):
        super().__init__()
        self.embeddings = nn.ModuleDict({
            "1": nn.Embedding(num_buckets, embed_dim),
            "2": nn.Embedding(num_buckets, embed_dim),
            "3": nn.Embedding(num_buckets, embed_dim),
        })
        # Xavier init
        for emb in self.embeddings.values():
            nn.init.xavier_uniform_(emb.weight)

        self.num_buckets = num_buckets
        self.embed_dim = embed_dim

    def forward(self, patch: str):
        """Convert one patch string into a [embed_dim] vector."""
        vectors = []
        # get device from embedding params
        device = next(self.embeddings["1"].parameters()).device  

        for n in [1, 2, 3]:
            ngrams = extract_ngrams(patch, n)
            for ng in ngrams:
                bucket = hash_ngram(ng, self.num_buckets)
                idx = torch.tensor(bucket, dtype=torch.long, device=device)
                vectors.append(self.embeddings[str(n)](idx))

        if len(vectors) == 0:
            return torch.zeros(self.embed_dim, device=device)

        return torch.stack(vectors, dim=0).sum(dim=0)  # sum across n-grams

**Test It on Sample Patches**

In [54]:
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
embedder = PatchEmbedder().to(device)

sample_patches = ["reconcilia", "tion troll", "s realized", " scene"]

print(f"Device: {device}")

for patch in sample_patches:
    vec = embedder(patch)
    print(f"Patch: {patch} | Embedding shape: {vec.shape}")

Device: mps
Patch: reconcilia | Embedding shape: torch.Size([64])
Patch: tion troll | Embedding shape: torch.Size([64])
Patch: s realized | Embedding shape: torch.Size([64])
Patch:  scene | Embedding shape: torch.Size([64])


## BLT Dataset Class

This dataset will:
1. Read train.csv / test.csv.
2. For each row:
    - Take input string → apply patchify → embed patches into [seq_len, 64].
    - Take target string → here we’ll keep it character-level for decoder supervision (simpler than patching the output).
3. Return tensors for (src_seq, tgt_seq).

In [55]:
import torch
from torch.utils.data import Dataset
import pandas as pd

class BLTDataset(Dataset):
    def __init__(self, csv_path, patch_embedder, 
                    window_size=10, entropy_threshold=2.0, max_patch_len=15):
        self.data = pd.read_csv(csv_path)
        self.patch_embedder = patch_embedder
        self.window_size = window_size
        self.entropy_threshold = entropy_threshold
        self.max_patch_len = max_patch_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        inp, tgt = row["input"], row["target"]

        # --- Input sequence: patch embeddings ---
        patches = patchify(inp, 
                            window_size=self.window_size,
                            entropy_threshold=self.entropy_threshold,
                            max_patch_len=self.max_patch_len)
        patch_vecs = [self.patch_embedder(p) for p in patches]
        if len(patch_vecs) == 0:
            src_seq = torch.zeros((1, self.patch_embedder.embed_dim))
        else:
            src_seq = torch.stack(patch_vecs, dim=0)  # [num_patches, 64]

        # --- Target sequence: character-level (convert chars to IDs) ---
        tgt_ids = torch.tensor(list(tgt.encode("utf-8")), dtype=torch.long)
        # NOTE: ord(c) = ASCII code (works since only printable ASCII)

        return src_seq, tgt_ids

In [56]:
train_ds = BLTDataset("./../data/train.csv", patch_embedder=embedder)
test_ds = BLTDataset("./../data/test.csv", patch_embedder=embedder)

# Get first sample
src_seq, tgt_ids = train_ds[0]

print("Input shape (patch embeddings):", src_seq.shape)  # [num_patches, 64]
print("Target shape (char ids):", tgt_ids.shape)          # [target_len]
print("Target IDs:", tgt_ids[:10])
print("Target string (reconstructed):", "".join([chr(x) for x in tgt_ids]))

Input shape (patch embeddings): torch.Size([4, 64])
Target shape (char ids): torch.Size([36])
Target IDs: tensor([101, 110, 101,  99, 115,  32, 100, 101, 122, 105])
Target string (reconstructed): enecs dezilaer sllort noitailicnocer


## Collate Function for BLT

We’ll:
1. Take a batch of (src_seq, tgt_seq).
2. Pad src_seq to [batch, max_src_len, 64].
3. Pad tgt_seq to [batch, max_tgt_len].
4. Return padded tensors + lengths (useful for masking in the model).

In [57]:
from torch.nn.utils.rnn import pad_sequence

def blt_collate_fn(batch, device="cpu"):
    """
    batch: list of (src_seq, tgt_seq) from BLTDataset
    """
    src_seqs, tgt_seqs = zip(*batch)  # unpack

    # --- Pad source (patch embeddings) ---
    # src_seqs is list of [num_patches, 64]
    # pad_sequence requires same shape in last dim
    src_padded = pad_sequence(src_seqs, batch_first=True, padding_value=0.0)
    # Shape: [B, max_src_len, 64]

    src_lengths = torch.tensor([s.size(0) for s in src_seqs], dtype=torch.long)

    # --- Pad target (char IDs) ---
    tgt_padded = pad_sequence(tgt_seqs, batch_first=True, padding_value=0)
    # Shape: [B, max_tgt_len]

    tgt_lengths = torch.tensor([t.size(0) for t in tgt_seqs], dtype=torch.long)

    return (src_padded.to(device),
            src_lengths.to(device),
            tgt_padded.to(device),
            tgt_lengths.to(device))

In [58]:
from torch.utils.data import DataLoader
from torch.utils.data import random_split

train_size = int(0.9 * len(train_ds))
val_size   = len(train_ds) - train_size
train_subset, val_subset = random_split(train_ds, [train_size, val_size])

train_loader = DataLoader(
    train_subset,
    batch_size=8,
    shuffle=True,
    collate_fn=lambda b: blt_collate_fn(b, device=device)
)

val_loader = DataLoader(
    val_subset,
    batch_size=4,
    shuffle=False,
    collate_fn=lambda b: blt_collate_fn(b, device=device)
)

# Fetch a batch
src_padded, src_lengths, tgt_padded, tgt_lengths = next(iter(train_loader))

print("Source batch shape:", src_padded.shape)   # [B, max_src_len, 64]
print("Source lengths:", src_lengths)
print("Target batch shape:", tgt_padded.shape)   # [B, max_tgt_len]
print("Target lengths:", tgt_lengths)

Source batch shape: torch.Size([8, 14, 64])
Source lengths: tensor([10,  7,  4,  6, 14, 14,  9,  7], device='mps:0')
Target batch shape: torch.Size([8, 135])
Target lengths: tensor([ 93,  62,  31,  51, 135, 134,  87,  65], device='mps:0')


## BLT Model Architecture

Core Idea:
1. Encoder: Take patch embeddings [B, L, 64], map into hidden dimension with a transformer encoder.
2. Decoder: Generate characters (IDs) step by step using a transformer decoder.
3. Output layer: Linear projection → vocab size.

In [59]:
import torch
import torch.nn as nn

class BLTModel(nn.Module):
    def __init__(self, embed_dim=64, hidden_dim=128, vocab_size=256, num_layers=2, nhead=4):
        super().__init__()

        # project patch embeddings to hidden size
        self.input_proj = nn.Linear(embed_dim, hidden_dim)

        # Positional encoding for encoder & decoder
        self.pos_encoder = nn.Embedding(512, hidden_dim)  # max 512 patches
        self.pos_decoder = nn.Embedding(512, hidden_dim)  # max 512 chars

        # Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=nhead, batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers, enable_nested_tensor=False)

        # Transformer decoder
        decoder_layer = nn.TransformerDecoderLayer(d_model=hidden_dim, nhead=nhead, batch_first=True)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)

        # Char embedding for targets
        self.tgt_embed = nn.Embedding(vocab_size, hidden_dim)

        # Output projection
        self.output_proj = nn.Linear(hidden_dim, vocab_size)

    def forward(self, src, src_lengths, tgt_inp):
        """
        src: [B, Ls, 64]    (patch embeddings)
        src_lengths: [B]    (not used yet, but can mask)
        tgt_inp: [B, Lt]    (char ids, input shifted right)
        """

        B, Ls, _ = src.shape
        B, Lt = tgt_inp.shape

        # --- Encoder ---
        src_emb = self.input_proj(src)  # [B, Ls, H]
        pos_src = self.pos_encoder(torch.arange(Ls, device=src.device)).unsqueeze(0)  # [1, Ls, H]
        src_key_padding_mask = (src.sum(dim=-1) == 0)  # [B, Ls], True for padding
        memory = self.encoder(src_emb, src_key_padding_mask=src_key_padding_mask)

        # --- Decoder ---
        tgt_emb = self.tgt_embed(tgt_inp)  # [B, Lt, H]
        pos_tgt = self.pos_decoder(torch.arange(Lt, device=tgt_inp.device)).unsqueeze(0)  # [1, Lt, H]
        tgt_emb = tgt_emb + pos_tgt

        # causal mask for autoregressive decoding
        causal_mask = nn.Transformer.generate_square_subsequent_mask(Lt).to(tgt_inp.device)

        tgt_key_padding_mask = (tgt_inp == 0)  # [B, Lt]
        out = self.decoder(
            tgt=tgt_emb,
            memory=memory,
            tgt_mask=causal_mask,
            tgt_key_padding_mask=tgt_key_padding_mask
        )

        logits = self.output_proj(out)  # [B, Lt, vocab_size]
        return logits

In [60]:
# hyperparams
VOCAB_SIZE = 256  # byte-level
model = BLTModel(embed_dim=64, hidden_dim=128, vocab_size=VOCAB_SIZE).to(device)

# Dummy batch
src_padded, src_lengths, tgt_padded, tgt_lengths = next(iter(train_loader))

# Teacher forcing: shift target for decoder input
tgt_inp = tgt_padded[:, :-1]   # input
tgt_out = tgt_padded[:, 1:]    # expected output

logits = model(src_padded, src_lengths, tgt_inp)

print("Logits shape:", logits.shape)  # [B, Lt-1, vocab_size]

Logits shape: torch.Size([8, 105, 256])




## Training Loop + Checkpoints

We’ll set up:
1. Loss = CrossEntropyLoss(ignore_index=0) (ignores PAD tokens).
2. Optimizer = Adam.
3. Training loop with logging.
4. Checkpoint saving (state_dict, optimizer, epoch, loss).

In [84]:
import torch
import torch.nn as nn
import torch.optim as optim
import os
from torch.nn.utils import clip_grad_norm_

# ----------------------------
# Training loop for BLT model
# ----------------------------
def train_blt(
    model, 
    train_loader, 
    val_loader=None, 
    num_epochs=5, 
    lr=1e-3, 
    device=None, 
    save_every=10, 
    resume_path=None
):
    if device is None:
        if torch.backends.mps.is_available():
            device = torch.device("mps")
        elif torch.cuda.is_available():
            device = torch.device("cuda")
        else:
            device = torch.device("cpu")

    print(f"Training on: {device}")

    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss(ignore_index=0)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode="min",      # minimize validation loss
        factor=0.5,      # reduce LR by half
        patience=5       # wait 5 epochs without improvement
    )

    os.makedirs("checkpoints", exist_ok=True)

    # Resume
    start_epoch = 1
    if resume_path and os.path.exists(resume_path):
        print(f"🔄 Resuming from checkpoint: {resume_path}")
        checkpoint = torch.load(resume_path, map_location=device)
        model.load_state_dict(checkpoint["model_state"])
        optimizer.load_state_dict(checkpoint["optimizer_state"])
        start_epoch = checkpoint["epoch"] + 1
        if "scheduler_state" in checkpoint:
            scheduler.load_state_dict(checkpoint["scheduler_state"])
        print(f"✅ Resumed from epoch {checkpoint['epoch']} (loss {checkpoint['loss']:.4f})")

    for epoch in range(start_epoch, num_epochs + 1):
        model.train()
        total_loss = 0.0
        total_tokens = 0

        for batch in train_loader:
            src, src_lengths, tgt, tgt_lengths = batch
            src, src_lengths = src.to(device), src_lengths.to(device)
            tgt, tgt_lengths = tgt.to(device), tgt_lengths.to(device)

            tgt_inp = tgt[:, :-1]
            tgt_out = tgt[:, 1:]

            logits = model(src, src_lengths, tgt_inp)

            loss = criterion(
                logits.reshape(-1, logits.size(-1)),
                tgt_out.reshape(-1)
            )

            optimizer.zero_grad()
            loss.backward()
            clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            total_loss += loss.item() * tgt_out.numel()
            total_tokens += tgt_out.numel()

        avg_loss = total_loss / total_tokens
        print(f"\n📘 Epoch {epoch}/{num_epochs} - Train Loss: {avg_loss:.4f}")

        if val_loader is not None:
            val_loss, val_acc = evaluate(model, val_loader, criterion, device)
            perplexity = math.exp(val_loss)
            print(f"📗 Validation - Loss: {val_loss:.4f}, Perplexity: {perplexity:.2f}, Token Acc: {val_acc:.2f}%")

            old_lr = optimizer.param_groups[0]["lr"]
            scheduler.step(val_loss)
            new_lr = optimizer.param_groups[0]["lr"]
            if new_lr != old_lr:
                print(f"🔽 LR reduced: {old_lr:.6f} → {new_lr:.6f}")

        if epoch % save_every == 0 or epoch == num_epochs:
            ckpt_path = f"checkpoints/blt_epoch{epoch}.pt"
            torch.save({
                "epoch": epoch,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "scheduler_state": scheduler.state_dict(),
                "loss": avg_loss
            }, ckpt_path)
            print(f"✅ Saved checkpoint: {ckpt_path}")

@torch.no_grad()
def evaluate(model, val_loader, criterion, device):
    model.eval()
    total_loss = 0.0
    total_tokens = 0
    correct_tokens = 0

    for batch in val_loader:
        src, src_lengths, tgt, tgt_lengths = batch
        src, src_lengths = src.to(device), src_lengths.to(device)
        tgt, tgt_lengths = tgt.to(device), tgt_lengths.to(device)

        tgt_inp = tgt[:, :-1]
        tgt_out = tgt[:, 1:]

        logits = model(src, src_lengths, tgt_inp)

        loss = criterion(
            logits.reshape(-1, logits.size(-1)),
            tgt_out.reshape(-1)
        )
        total_loss += loss.item() * tgt_out.numel()
        total_tokens += tgt_out.numel()

        preds = logits.argmax(dim=-1)
        mask = tgt_out != 0
        correct_tokens += ((preds == tgt_out) & mask).sum().item()

    avg_loss = total_loss / total_tokens
    accuracy = 100.0 * correct_tokens / total_tokens
    return avg_loss, accuracy

In [None]:
train_blt(
    model, 
    train_loader, 
    val_loader=val_loader,
    num_epochs=1000, 
    save_every=2, 
    resume_path="checkpoints/blt_epoch80.pt", 
    device=device
)

Training on: mps
🔄 Resuming from checkpoint: checkpoints/blt_epoch80.pt
✅ Resumed from epoch 80 (loss 1.5632)

📘 Epoch 81/1000 - Train Loss: 1.5370
📗 Validation - Loss: 1.3799, Perplexity: 3.97, Token Acc: 38.56%

📘 Epoch 82/1000 - Train Loss: 1.5328
📗 Validation - Loss: 1.3862, Perplexity: 4.00, Token Acc: 38.45%
✅ Saved checkpoint: checkpoints/blt_epoch82.pt

📘 Epoch 83/1000 - Train Loss: 1.5338
📗 Validation - Loss: 1.3877, Perplexity: 4.01, Token Acc: 38.39%

📘 Epoch 84/1000 - Train Loss: 1.5309
📗 Validation - Loss: 1.3899, Perplexity: 4.01, Token Acc: 38.29%
✅ Saved checkpoint: checkpoints/blt_epoch84.pt

📘 Epoch 85/1000 - Train Loss: 1.5299
📗 Validation - Loss: 1.3910, Perplexity: 4.02, Token Acc: 38.29%

📘 Epoch 86/1000 - Train Loss: 1.5282
📗 Validation - Loss: 1.3953, Perplexity: 4.04, Token Acc: 38.18%
✅ Saved checkpoint: checkpoints/blt_epoch86.pt

📘 Epoch 87/1000 - Train Loss: 1.5250
📗 Validation - Loss: 1.3927, Perplexity: 4.03, Token Acc: 38.18%
🔽 LR reduced: 0.001000 → 0.0

**Testing**

In [None]:
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import pandas as pd
import os

# ----------------------------
# --- Setup device ---
# ----------------------------
device = torch.device("cpu")
print(f"Testing on: {device}")

# ----------------------------
# --- Load patch embedder ---
# ----------------------------
# Assuming you already have PatchEmbedder class
embedder = PatchEmbedder(num_buckets=4096, embed_dim=64)

# ----------------------------
# --- Load test dataset ---
# ----------------------------
test_ds = BLTDataset("./../data/test.csv", patch_embedder=embedder)
test_loader = DataLoader(
    test_ds,
    batch_size=4,
    shuffle=False,
    collate_fn=lambda b: blt_collate_fn(b, device=device)
)

# ----------------------------
# --- Load model checkpoint ---
# ----------------------------
VOCAB_SIZE = 256
model = BLTModel(embed_dim=64, hidden_dim=128, vocab_size=VOCAB_SIZE).to(device)

ckpt_path = "checkpoints/blt_epoch80.pt"
if not os.path.exists(ckpt_path):
    raise FileNotFoundError(f"{ckpt_path} not found!")

checkpoint = torch.load(ckpt_path, map_location=device)
model.load_state_dict(checkpoint["model_state"])
model.eval()
print(f"✅ Loaded checkpoint: {ckpt_path} (epoch {checkpoint['epoch']}, loss {checkpoint['loss']:.4f})")

# ----------------------------
# --- Testing loop ---
# ----------------------------
criterion = nn.CrossEntropyLoss(ignore_index=0)
total_loss = 0.0
total_tokens = 0
correct_tokens = 0

with torch.no_grad():
    for batch in test_loader:
        src, src_lengths, tgt, tgt_lengths = batch
        src, tgt = src.to(device), tgt.to(device)

        tgt_inp = tgt[:, :-1]
        tgt_out = tgt[:, 1:]

        logits = model(src, src_lengths, tgt_inp)

        # Loss
        loss = criterion(logits.reshape(-1, logits.size(-1)), tgt_out.reshape(-1))
        total_loss += loss.item()

        # Accuracy
        preds = logits.argmax(dim=-1)
        mask = tgt_out != 0
        correct_tokens += ((preds == tgt_out) & mask).sum().item()
        total_tokens += mask.sum().item()

avg_loss = total_loss / len(test_loader)
accuracy = 100.0 * correct_tokens / total_tokens

print(f"\n🧪 Test Results - Loss: {avg_loss:.4f}, Token Accuracy: {accuracy:.2f}%")

Testing on: cpu
✅ Loaded checkpoint: checkpoints/blt_epoch80.pt (epoch 80, loss 1.5632)





🧪 Test Results - Loss: 1.4813, Token Accuracy: 53.33%


## Validate BLT

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

def validate_blt(model, test_loader, device=None):
    if device is None:
        device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    model.eval()
    model.to(device)

    criterion = nn.CrossEntropyLoss(ignore_index=0)
    total_loss = 0.0
    total_tokens = 0

    with torch.no_grad():
        for batch in test_loader:
            src, src_lengths, tgt, tgt_lengths = batch
            src, tgt = src.to(device), tgt.to(device)

            tgt_inp = tgt[:, :-1]
            tgt_out = tgt[:, 1:]

            logits = model(src, src_lengths, tgt_inp)  # [B, Lt-1, vocab]
            loss = criterion(logits.reshape(-1, logits.size(-1)), tgt_out.reshape(-1))

            total_loss += loss.item() * tgt_out.numel()  # sum over tokens
            total_tokens += tgt_out.numel()

    avg_loss = total_loss / total_tokens
    perplexity = torch.exp(torch.tensor(avg_loss))
    print(f"✅ Test Loss: {avg_loss:.4f}, Perplexity: {perplexity:.4f}")
    return avg_loss, perplexity

In [None]:
test_loader = DataLoader(
    test_ds,
    batch_size=4,
    shuffle=False,
    collate_fn=lambda b: blt_collate_fn(b, device=device)
)

validate_blt(model, test_loader, device=device)

✅ Test Loss: 1.4787, Perplexity: 4.3871


(1.4786781151334818, tensor(4.3871))

## Inference Function (Text-to-Text Generation)

In [None]:
def generate_blt(model, input_text, patch_embedder, max_len=256, device=None):
    if device is None:
        device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

    model.eval()
    model.to(device)

    # --- Prepare encoder input ---
    patches = patchify(input_text)
    src_seq = torch.stack([patch_embedder(p) for p in patches], dim=0).unsqueeze(0).to(device)  # [1, Ls, H]
    src_lengths = torch.tensor([src_seq.size(1)], dtype=torch.long).to(device)

    # --- Encoder memory ---
    memory = model.encoder(model.input_proj(src_seq) + model.pos_encoder(torch.arange(src_seq.size(1), device=device)).unsqueeze(0))

    # --- Autoregressive decoding ---
    output_ids = [0]  # start with PAD / or custom start token if you have one
    for t in range(max_len):
        tgt_inp = torch.tensor(output_ids, dtype=torch.long, device=device).unsqueeze(0)  # [1, t+1]
        tgt_emb = model.tgt_embed(tgt_inp) + model.pos_decoder(torch.arange(t+1, device=device)).unsqueeze(0)
        causal_mask = nn.Transformer.generate_square_subsequent_mask(t+1).to(device)

        out = model.decoder(tgt_emb, memory, tgt_mask=causal_mask)
        logits = model.output_proj(out)  # [1, t+1, vocab]
        next_id = logits[0, -1].argmax().item()  # greedy decoding
        output_ids.append(next_id)

        if next_id == 0:  # stop at PAD (or EOS token)
            break

    # Convert IDs to string
    output_text = "".join([chr(i) for i in output_ids[1:]])  # skip first PAD/start token
    return output_text

In [None]:
input_text = "reconciliation trolls realized scene"
output_text = generate_blt(model, input_text, patch_embedder=embedder, max_len=100, device=device)
print("Input :", input_text)
print("Output:", output_text)

Input : reconciliation trolls realized scene
Output: airetsyh seitilibapac seiranidro seiraniretsym seitilibapac seitilibapac seitilibapac seitilibapac s


## Prediction on test data and saving CSV

In [None]:
import os
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# --- Make sure device is set ---
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"🔍 Using device: {device}")

# --- Create prediction folder ---
os.makedirs("prediction", exist_ok=True)

# --- Load test dataset ---
test_df = pd.read_csv("./../data/test.csv")

# --- Simple Dataset wrapper ---
class SimpleTestDataset(torch.utils.data.Dataset):
    def __init__(self, df):
        self.inputs = df["input"].tolist()
    
    def __len__(self):
        return len(self.inputs)
    
    def __getitem__(self, idx):
        return self.inputs[idx]

test_ds = SimpleTestDataset(test_df)

test_loader = DataLoader(
    test_ds,
    batch_size=1,  # autoregressive, keep 1
    shuffle=False
)

# --- Function to generate single prediction ---
def generate_single(input_text, model, patch_embedder, max_len=256):
    model.eval()
    with torch.no_grad():
        patches = patchify(input_text)
        if len(patches) == 0:
            return ""  # handle empty input

        # Encode source sequence
        src_seq = torch.stack([patch_embedder(p) for p in patches], dim=0).unsqueeze(0).to(device)
        src_lengths = torch.tensor([src_seq.size(1)], dtype=torch.long).to(device)

        src_emb = model.input_proj(src_seq)
        pos_src = model.pos_encoder(torch.arange(src_seq.size(1), device=device)).unsqueeze(0)
        memory = model.encoder(src_emb + pos_src)

        # Autoregressive decoding (greedy)
        output_ids = [ord(" ")]  # crude <SOS>: space token
        for t in range(max_len):
            tgt_inp = torch.tensor(output_ids, dtype=torch.long, device=device).unsqueeze(0)
            pos_tgt = model.pos_decoder(torch.arange(t+1, device=device)).unsqueeze(0)
            tgt_emb = model.tgt_embed(tgt_inp) + pos_tgt
            causal_mask = nn.Transformer.generate_square_subsequent_mask(t+1).to(device)

            out = model.decoder(tgt=tgt_emb, memory=memory, tgt_mask=causal_mask)
            logits = model.output_proj(out)  # [1, t+1, vocab]

            next_id = logits[0, -1].argmax().item()
            output_ids.append(next_id)

            # stop if <pad> (0) or newline (10) appears
            if next_id in [0, 10]:
                break

        pred_text = "".join([chr(i) for i in output_ids[1:]])  # skip fake SOS
        return pred_text

# --- Generate predictions ---
model.to(device)
model.eval()
predictions = []
for input_text in test_loader:
    input_str = input_text[0]  # batch size 1
    pred = generate_single(input_str, model, patch_embedder=embedder)
    predictions.append(pred)

# --- Save predictions ---
test_df["prediction"] = predictions
output_path = "./prediction/blt_predictions.csv"
test_df.to_csv(output_path, index=False)
print(f"✅ Predictions saved to {output_path}")

# Character level Model & Training

## Data & Tokenizer

1. Loads your train.csv and test.csv.
2. Creates a character vocabulary:
    - [PAD], [SOS], [EOS] + printable ASCII chars.
    - Maps chars ↔ IDs (stoi, itos).
3. Provides encode/decode functions.
4. Tests it on a sample string.

In [None]:
import pandas as pd
import torch
from torch.utils.data import Dataset
import string

# --- Step 1: Load Data ---
train_df = pd.read_csv("./../data/train.csv")
test_df  = pd.read_csv("./../data/test.csv")

print("Train shape:", train_df.shape)
print("Test shape:", test_df.shape)
print(train_df.head())

# --- Step 2: Character Tokenizer ---
# Printable ASCII characters 32–126
ascii_chars = [chr(i) for i in range(32, 127)]
special_tokens = ["[PAD]", "[SOS]", "[EOS]"]

itos = special_tokens + ascii_chars   # id → char
stoi = {ch: i for i, ch in enumerate(itos)}  # char → id

PAD_IDX = stoi["[PAD]"]
SOS_IDX = stoi["[SOS]"]
EOS_IDX = stoi["[EOS]"]

vocab_size = len(itos)

print("\nVocab size:", vocab_size)
print("PAD idx:", PAD_IDX, "SOS idx:", SOS_IDX, "EOS idx:", EOS_IDX)

# --- Encode / Decode functions ---
def encode_text(text, add_special=True):
    ids = [stoi[ch] for ch in text if ch in stoi]
    if add_special:
        ids = [SOS_IDX] + ids + [EOS_IDX]
    return torch.tensor(ids, dtype=torch.long)

def decode_ids(ids):
    chars = []
    for i in ids:
        if i == PAD_IDX or i == SOS_IDX or i == EOS_IDX:
            continue
        chars.append(itos[i])
    return "".join(chars)

# --- Quick test ---
sample = "LMA is fun!"
encoded = encode_text(sample)
decoded = decode_ids(encoded.tolist())

print("Sample:", sample)
print("Encoded:", encoded.tolist())
print("Decoded:", decoded)

Train shape: (10000, 2)
Test shape: (2000, 2)
                                               input  \
0               reconciliation trolls realized scene   
1                        scratched kemp blah devices   
2  delusional engineered perfect prey englishman ...   
3  boomers nfl reacts parallels everything 6 redu...   
4  patience put christmas superhero luc rake fulf...   

                                              target  
0               enecs dezilaer sllort noitailicnocer  
1                        secived halb pmek dehctarcs  
2  hctarcs detsub namhsilgne yerp tcefrep dereeni...  
3  stcudnoc ysereh redlof secuder 6 gnihtyreve sl...  
4  ylesned ylno elbats latnenitnoc dellifluf ekar...  

Vocab size: 98
PAD idx: 0 SOS idx: 1 EOS idx: 2
Sample: LMA is fun!
Encoded: [1, 47, 48, 36, 3, 76, 86, 3, 73, 88, 81, 4, 2]
Decoded: LMA is fun!


## Dataset + Collate

- CharDataset → loads strings, encodes them into token IDs.
- collate_fn → pads sequences per batch + stores lengths.
- DataLoader → provides batches for training & testing.
- Prints shapes and a decoded sample to check correctness.

In [None]:
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

# --- Dataset Class ---
class CharDataset(Dataset):
    def __init__(self, df):
        self.inputs = df["input"].tolist()
        self.targets = df["target"].tolist()

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        src_text = self.inputs[idx]
        tgt_text = self.targets[idx]

        src_ids = encode_text(src_text, add_special=True)  # [SOS] ... [EOS]
        tgt_ids = encode_text(tgt_text, add_special=True)

        return src_ids, tgt_ids


# --- Collate Function ---
def collate_fn(batch):
    src_batch, tgt_batch = zip(*batch)

    # Pad sequences
    src_padded = pad_sequence(src_batch, batch_first=True, padding_value=PAD_IDX)
    tgt_padded = pad_sequence(tgt_batch, batch_first=True, padding_value=PAD_IDX)

    # Lengths (before padding)
    src_lengths = torch.tensor([len(x) for x in src_batch], dtype=torch.long)
    tgt_lengths = torch.tensor([len(x) for x in tgt_batch], dtype=torch.long)

    return src_padded, src_lengths, tgt_padded, tgt_lengths


# --- Create Dataset + DataLoader ---
train_ds = CharDataset(train_df)
test_ds = CharDataset(test_df)

train_loader = DataLoader(train_ds, batch_size=32, shuffle=True, collate_fn=collate_fn)
test_loader = DataLoader(test_ds, batch_size=32, shuffle=False, collate_fn=collate_fn)

# --- Quick Check ---
src_padded, src_lengths, tgt_padded, tgt_lengths = next(iter(train_loader))

print("src_padded shape:", src_padded.shape)
print("tgt_padded shape:", tgt_padded.shape)
print("\nsrc_lengths:", src_lengths[:5])
print("tgt_lengths:", tgt_lengths[:5])
print("\nExample decoded input:", decode_ids(src_padded[0].tolist()))
print("Example decoded target:", decode_ids(tgt_padded[0].tolist()))

src_padded shape: torch.Size([32, 125])
tgt_padded shape: torch.Size([32, 125])

src_lengths: tensor([ 61, 117, 111,  71,  39])
tgt_lengths: tensor([ 61, 117, 111,  71,  39])

Example decoded input: cover kc we're extremists category stealth voter regulators
Example decoded target: srotaluger retov htlaets yrogetac stsimertxe er'ew ck revoc


In [None]:
# --- Character Vocabulary (printable ASCII + special tokens) ---
import string

PAD_IDX = 0
SOS_IDX = 1
EOS_IDX = 2

# Printable ASCII characters (32-126)
chars = [chr(i) for i in range(32, 127)]
VOCAB = {c: i+3 for i, c in enumerate(chars)}  # reserve 0,1,2 for PAD, SOS, EOS
INV_VOCAB = {i: c for c, i in VOCAB.items()}

VOCAB_SIZE = len(VOCAB) + 3  # include PAD, SOS, EOS
print(f"Vocabulary size: {VOCAB_SIZE}")

Vocabulary size: 98


## Baseline Model

- Embedding layer for characters
- Positional encoding
- Transformer encoder–decoder (2 layers each)
- Linear projection to vocab size

In [None]:
import math
import torch
import torch.nn as nn

# --- Positional Encoding (sinusoidal) ---
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # [1, max_len, d_model]
        self.register_buffer("pe", pe)

    def forward(self, x):
        # x: [B, L, D]
        x = x + self.pe[:, :x.size(1), :]
        return x


# --- Baseline Transformer Model ---
class CharTransformer(nn.Module):
    def __init__(self, vocab_size, d_model=64, nhead=4, num_layers=2, dim_ff=128, dropout=0.1):
        super().__init__()
        self.d_model = d_model

        # Embeddings
        self.src_embed = nn.Embedding(vocab_size, d_model, padding_idx=PAD_IDX)
        self.tgt_embed = nn.Embedding(vocab_size, d_model, padding_idx=PAD_IDX)

        self.pos_encoder = PositionalEncoding(d_model)

        # Encoder
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead,
                                                    dim_feedforward=dim_ff, dropout=dropout, batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Decoder
        decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead,
                                                    dim_feedforward=dim_ff, dropout=dropout, batch_first=True)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)

        # Output projection
        self.fc_out = nn.Linear(d_model, vocab_size)

    def forward(self, src, tgt_inp):
        """
        src: [B, Ls]
        tgt_inp: [B, Lt]
        """
        # Embedding + positional encoding
        src_emb = self.pos_encoder(self.src_embed(src))  # [B, Ls, D]
        tgt_emb = self.pos_encoder(self.tgt_embed(tgt_inp))  # [B, Lt, D]

        # Masks
        src_key_padding_mask = (src == PAD_IDX)  # [B, Ls]
        tgt_key_padding_mask = (tgt_inp == PAD_IDX)  # [B, Lt]
        causal_mask = nn.Transformer.generate_square_subsequent_mask(tgt_inp.size(1)).to(src.device)

        # Encode
        memory = self.encoder(src_emb, src_key_padding_mask=src_key_padding_mask)

        # Decode
        out = self.decoder(tgt=tgt_emb, memory=memory,
                            tgt_mask=causal_mask,
                            tgt_key_padding_mask=tgt_key_padding_mask,
                            memory_key_padding_mask=src_key_padding_mask)

        # Project to vocab
        logits = self.fc_out(out)  # [B, Lt, vocab_size]
        return logits

In [None]:
# Create model
baseline_model = CharTransformer(vocab_size=len(VOCAB), d_model=64, nhead=4, num_layers=2).to(device)

# Dummy batch
src_padded, src_lengths, tgt_padded, tgt_lengths = next(iter(train_loader))

# Teacher forcing: shift target
tgt_inp = tgt_padded[:, :-1]
tgt_out = tgt_padded[:, 1:]

# Forward pass
logits = baseline_model(src_padded.to(device), tgt_inp.to(device))

print("Logits shape:", logits.shape)  # [B, Lt-1, vocab_size]

Logits shape: torch.Size([32, 125, 95])




## Training

- AdamW optimizer
- CrossEntropyLoss with ignore_index=PAD_IDX
- Gradient clipping
- Checkpoint saving every few epochs
- Optional resume from checkpoint

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import os
from torch.nn.utils import clip_grad_norm_

# ----------------------------
# Training loop for baseline Transformer
# ----------------------------
def train_baseline(
    model,
    train_loader,
    val_loader=None,   # optional validation loader
    num_epochs=1000,
    lr=1e-3,
    device=None,
    save_every=5,
    val_every=50,
    resume_path=None
):
    # --- Pick device automatically if not provided ---
    if device is None:
        if torch.backends.mps.is_available():
            device = torch.device("mps")
        elif torch.cuda.is_available():
            device = torch.device("cuda")
        else:
            device = torch.device("cpu")
    print(f"Training on: {device}")

    model = model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)  # ignore <PAD> tokens

    os.makedirs("checkpoints_baseline", exist_ok=True)

    # --- Resume from checkpoint ---
    start_epoch = 1
    if resume_path and os.path.exists(resume_path):
        print(f"🔄 Resuming from checkpoint: {resume_path}")
        checkpoint = torch.load(resume_path, map_location=device)
        model.load_state_dict(checkpoint["model_state"])
        optimizer.load_state_dict(checkpoint["optimizer_state"])
        start_epoch = checkpoint["epoch"] + 1
        print(f"✅ Resumed from epoch {checkpoint['epoch']} (loss {checkpoint['loss']:.4f})")

    # --- Training loop ---
    for epoch in range(start_epoch, num_epochs + 1):
        model.train()
        total_loss = 0.0

        for batch in train_loader:
            src, src_lengths, tgt, tgt_lengths = batch
            src, tgt = src.to(device), tgt.to(device)

            # Shift target for teacher forcing
            tgt_inp = tgt[:, :-1]   # decoder input
            tgt_out = tgt[:, 1:]    # expected output

            logits = model(src, tgt_inp)  # [B, L, vocab_size]

            # Compute loss
            loss = criterion(
                logits.reshape(-1, logits.size(-1)),
                tgt_out.reshape(-1)
            )

            # Backprop
            optimizer.zero_grad()
            loss.backward()
            clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(train_loader)
        print(f"📘 Epoch {epoch}/{num_epochs} - Train Loss: {avg_loss:.4f}")

        # --- Validation ---
        if val_loader is not None and epoch % val_every == 0:
            val_loss, val_acc = evaluate_baseline(model, val_loader, criterion, device)
            print(f"📗 Validation (epoch {epoch}) - Loss: {val_loss:.4f}, Token Acc: {val_acc:.2f}%")

        # --- Save checkpoint ---
        if epoch % save_every == 0 or epoch == num_epochs:
            ckpt_path = f"checkpoints_baseline/char_transformer_epoch{epoch}.pt"
            torch.save({
                "epoch": epoch,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "loss": avg_loss
            }, ckpt_path)
            print(f"✅ Saved checkpoint: {ckpt_path}")


# ----------------------------
# Evaluation / Validation
# ----------------------------
@torch.no_grad()
def evaluate_baseline(model, val_loader, criterion, device):
    model.eval()
    total_loss = 0.0
    total_tokens = 0
    correct_tokens = 0

    for batch in val_loader:
        src, src_lengths, tgt, tgt_lengths = batch
        src, tgt = src.to(device), tgt.to(device)

        tgt_inp = tgt[:, :-1]
        tgt_out = tgt[:, 1:]

        logits = model(src, tgt_inp)

        # Compute loss
        loss = criterion(
            logits.reshape(-1, logits.size(-1)),
            tgt_out.reshape(-1)
        )
        total_loss += loss.item()

        # Accuracy (ignore PAD)
        preds = logits.argmax(dim=-1)
        mask = tgt_out != PAD_IDX
        correct_tokens += ((preds == tgt_out) & mask).sum().item()
        total_tokens += mask.sum().item()

    avg_loss = total_loss / len(val_loader)
    accuracy = 100.0 * correct_tokens / total_tokens
    return avg_loss, accuracy

In [None]:
train_baseline(
    model=baseline_model,
    train_loader=train_loader,
    val_loader=None,
    num_epochs=1000,
    lr=1e-3,
    device=device,
    save_every=5,
    resume_path="checkpoints_baseline/char_transformer_epoch45.pt"
)

Training on: mps
🔄 Resuming from checkpoint: checkpoints_baseline/char_transformer_epoch45.pt
✅ Resumed from epoch 45 (loss 0.2808)
📘 Epoch 46/1000 - Train Loss: 0.2778
📘 Epoch 47/1000 - Train Loss: 0.2735
📘 Epoch 48/1000 - Train Loss: 0.2846
📘 Epoch 49/1000 - Train Loss: 0.2716
📘 Epoch 50/1000 - Train Loss: 0.2597
✅ Saved checkpoint: checkpoints_baseline/char_transformer_epoch50.pt
📘 Epoch 51/1000 - Train Loss: 0.2416


KeyboardInterrupt: 

## Validation

In [None]:
import pandas as pd
from collections import Counter

# --- Load data ---
train_df = pd.read_csv("./../data/train.csv")
test_df = pd.read_csv("./../data/test.csv")

# --- Build character set from train + test ---
all_text = "".join(train_df["input"].tolist() + train_df["target"].tolist() +
                   test_df["input"].tolist() + test_df["target"].tolist())

chars = sorted(list(set(all_text)))

# --- Add special tokens ---
PAD_IDX = 0
SOS_IDX = 1
EOS_IDX = 2

CHAR2IDX = {c: i+3 for i, c in enumerate(chars)}
IDX2CHAR = {i+3: c for i, c in enumerate(chars)}

CHAR2IDX["<PAD>"] = PAD_IDX
CHAR2IDX["<SOS>"] = SOS_IDX
CHAR2IDX["<EOS>"] = EOS_IDX
IDX2CHAR[PAD_IDX] = "<PAD>"
IDX2CHAR[SOS_IDX] = "<SOS>"
IDX2CHAR[EOS_IDX] = "<EOS>"

VOCAB_SIZE = len(CHAR2IDX)
print(f"✅ Vocab size: {VOCAB_SIZE}")

✅ Vocab size: 91


In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

class CharDataset(torch.utils.data.Dataset):
    def __init__(self, df, char2idx):
        self.inputs = df["input"].tolist()
        self.targets = df["target"].tolist()
        self.char2idx = char2idx

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        # Convert characters to indices
        src = [self.char2idx[c] for c in self.inputs[idx]]
        tgt = [self.char2idx[c] for c in self.targets[idx]]
        return torch.tensor(src, dtype=torch.long), torch.tensor(tgt, dtype=torch.long)

def char_collate_fn(batch, device=None):
    src_batch, tgt_batch = zip(*batch)
    src_lengths = [len(s) for s in src_batch]
    tgt_lengths = [len(t) for t in tgt_batch]

    max_src_len = max(src_lengths)
    max_tgt_len = max(tgt_lengths)

    src_padded = torch.zeros(len(batch), max_src_len, dtype=torch.long)
    tgt_padded = torch.zeros(len(batch), max_tgt_len, dtype=torch.long)

    for i, (s, t) in enumerate(zip(src_batch, tgt_batch)):
        src_padded[i, :len(s)] = s
        tgt_padded[i, :len(t)] = t

    if device is not None:
        src_padded = src_padded.to(device)
        tgt_padded = tgt_padded.to(device)

    return src_padded, torch.tensor(src_lengths), tgt_padded, torch.tensor(tgt_lengths)

@torch.no_grad()
def validate_baseline(model, test_loader, device=None):
    """
    Evaluate baseline character-level transformer on test set.
    Computes avg loss, perplexity, and token-level accuracy.
    """
    if device is None:
        device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    model.eval()
    model.to(device)

    criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)  # ignore padding
    total_loss = 0.0
    total_tokens = 0
    correct_tokens = 0

    for batch in test_loader:
        src, src_lengths, tgt, tgt_lengths = batch
        src, tgt = src.to(device), tgt.to(device)

        tgt_inp = tgt[:, :-1]
        tgt_out = tgt[:, 1:]

        logits = model(src, tgt_inp)  # [B, Lt-1, vocab_size]

        # Compute loss
        loss = criterion(logits.reshape(-1, logits.size(-1)), tgt_out.reshape(-1))
        total_loss += loss.item() * tgt_out.numel()

        # Compute accuracy
        preds = logits.argmax(dim=-1)
        mask = tgt_out != PAD_IDX
        correct_tokens += ((preds == tgt_out) & mask).sum().item()
        total_tokens += mask.sum().item()

    avg_loss = total_loss / total_tokens
    perplexity = torch.exp(torch.tensor(avg_loss))
    accuracy = 100.0 * correct_tokens / total_tokens

    print(f"✅ Test Loss: {avg_loss:.4f}, Perplexity: {perplexity:.4f}, Token Accuracy: {accuracy:.2f}%")
    return avg_loss, perplexity, accuracy

# Example usage:
test_ds = CharDataset(pd.read_csv("./../data/test.csv"), char2idx=CHAR2IDX)
test_loader = DataLoader(
    test_ds,
    batch_size=8,
    shuffle=False,
    collate_fn=lambda b: char_collate_fn(b, device=device)
)

validate_baseline(baseline_model, test_loader, device=device)

  output = torch._nested_tensor_from_mask(


✅ Test Loss: 29.6431, Perplexity: 7478683107328.0000, Token Accuracy: 5.16%


(29.643077177616124, tensor(7.4787e+12), 5.159886322149284)

## Test set prediction

1. encode_text() and decode_text() are the same functions you used for your baseline dataset.
2. SOS_IDX and EOS_IDX should be your special token indices (start/end).
3. This uses greedy decoding only. No beam search or sampling needed.
4. You can now submit predictions_normal.csv for evaluation.

In [None]:
import torch

# Force CPU
device = torch.device("cpu")

# Load checkpoint on CPU
checkpoint_path = "checkpoints_baseline/char_transformer_epoch10.pt"
baseline_model.load_state_dict(torch.load(checkpoint_path, map_location=device)["model_state"])
baseline_model.to(device)
baseline_model.eval()

# Vocabulary
vocab = {c:i+3 for i,c in enumerate([chr(i) for i in range(32,127)])}
inv_vocab = {i:c for c,i in vocab.items()}
PAD_IDX, SOS_IDX, EOS_IDX = 0, 1, 2

def encode_text(text):
    return [SOS_IDX] + [vocab[c] for c in text] + [EOS_IDX]

def decode_text(ids):
    return "".join([inv_vocab[i] for i in ids if i > 2])

# Autoregressive prediction
def generate_single_char_model(text, model, max_len=128):
    model.eval()
    with torch.no_grad():
        input_ids = torch.tensor([encode_text(text)], dtype=torch.long, device=device)
        output_ids = [SOS_IDX]

        for t in range(max_len):
            tgt_ids = torch.tensor([output_ids], dtype=torch.long, device=device)
            logits = model(input_ids, tgt_ids)
            next_id = logits[0, -1].argmax().item()
            output_ids.append(next_id)
            if next_id == EOS_IDX:
                break

        return decode_text(output_ids[1:])

# Example
sample_text = "LMA is fun!"
prediction = generate_single_char_model(sample_text, baseline_model)
print("Input: ", sample_text)
print("Prediction:", prediction)

Input:  LMA is fun!
Prediction: hsi gnik $


  output = torch._nested_tensor_from_mask(


In [None]:
import os
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset

# ----------------------------
# --- Load test data ---
# ----------------------------
test_df = pd.read_csv("./../data/test.csv")

class CharTestDataset(Dataset):
    def __init__(self, df):
        self.inputs = df["input"].tolist()
    
    def __len__(self):
        return len(self.inputs)
    
    def __getitem__(self, idx):
        return self.inputs[idx]

test_ds = CharTestDataset(test_df)
test_loader = DataLoader(test_ds, batch_size=1, shuffle=False)

# ----------------------------
# --- Character encoding / decoding ---
# ----------------------------
# Make sure these match your training vocab
VOCAB = [PAD, SOS, EOS] + [chr(i) for i in range(32, 127)]
PAD_IDX, SOS_IDX, EOS_IDX = 0, 1, 2

char2idx = {c: i for i, c in enumerate(VOCAB)}
idx2char = {i: c for i, c in enumerate(VOCAB)}

def encode_text(text):
    return [SOS_IDX] + [char2idx.get(c, 3) for c in text] + [EOS_IDX]  # unknown=3

def decode_text(indices):
    # remove SOS/EOS/PAD
    chars = [idx2char.get(i, "") for i in indices if i not in (PAD_IDX, SOS_IDX, EOS_IDX)]
    return "".join(chars)

# ----------------------------
# --- Single prediction ---
# ----------------------------
@torch.no_grad()
def generate_single_char_model(text, model, max_len=256):
    device = torch.device("cpu")  # force CPU for safety
    model = model.to(device)
    model.eval()

    input_ids = torch.tensor([encode_text(text)], dtype=torch.long, device=device)

    # Start output with SOS
    output_ids = [SOS_IDX]

    for t in range(max_len):
        tgt_ids = torch.tensor([output_ids], dtype=torch.long, device=device)
        logits = model(input_ids, tgt_ids)  # returns [1, seq_len, vocab_size]
        next_id = logits[0, -1].argmax().item()
        output_ids.append(next_id)

        if next_id == EOS_IDX:
            break

    return decode_text(output_ids)

# ----------------------------
# --- Generate predictions for all test samples ---
# ----------------------------
predictions = []
for input_text in test_loader:
    input_str = input_text[0]  # batch_size=1
    pred = generate_single_char_model(input_str, baseline_model)
    predictions.append(pred)

# ----------------------------
# --- Save CSV ---
# ----------------------------
test_df["prediction"] = predictions
os.makedirs("prediction", exist_ok=True)
output_path = "./prediction/predictions_baseline.csv"
test_df.to_csv(output_path, index=False)
print(f"✅ Predictions saved to {output_path}")

# Evaluation & Predictions

##  Metrics evaluation script for both BLT and baseline models

1.	Token-level accuracy
2.	Average sequence length (input tokens vs predicted tokens)

In [None]:
import pandas as pd

def evaluate_predictions(pred_csv_path, original_csv_path):
    # Load predictions
    df_pred = pd.read_csv(pred_csv_path)
    df_orig = pd.read_csv(original_csv_path)

    total_tokens = 0
    correct_tokens = 0
    total_input_len = 0
    total_output_len = 0

    for idx in range(len(df_pred)):
        pred = str(df_pred.loc[idx, "prediction"])
        target = str(df_orig.loc[idx, "target"])
        input_str = str(df_orig.loc[idx, "input"])

        # Token-level comparison
        min_len = min(len(pred), len(target))
        correct_tokens += sum([pred[i] == target[i] for i in range(min_len)])
        total_tokens += len(target)

        # Sequence length stats
        total_input_len += len(input_str)
        total_output_len += len(pred)

    token_accuracy = 100.0 * correct_tokens / total_tokens
    avg_input_len = total_input_len / len(df_pred)
    avg_output_len = total_output_len / len(df_pred)

    print(f"✅ Evaluation for {pred_csv_path}")
    print(f"Token-level Accuracy: {token_accuracy:.2f}%")
    print(f"Average input length: {avg_input_len:.2f} chars")
    print(f"Average predicted length: {avg_output_len:.2f} chars")
    print("-" * 50)
    return token_accuracy, avg_input_len, avg_output_len


# --- Evaluate BLT predictions ---
blt_acc, blt_in_len, blt_out_len = evaluate_predictions(
    "./prediction/blt_predictions.csv", "./../data/test.csv"
)

# --- Evaluate Baseline predictions ---
baseline_acc, base_in_len, base_out_len = evaluate_predictions(
    "./prediction/predictions_baseline.csv", "./../data/test.csv"
)

# --- Comparison Summary ---
print("📊 Comparison Summary:")
print(f"BLT      | Acc: {blt_acc:.2f}%, Avg Input: {blt_in_len:.2f}, Avg Pred: {blt_out_len:.2f}")
print(f"Baseline | Acc: {baseline_acc:.2f}%, Avg Input: {base_in_len:.2f}, Avg Pred: {base_out_len:.2f}")