# Marble 1: Does Size Matter?

**Series:** Marble (hyperparameter exploration for minimal viable LM)

**Goal:** Find the smallest model that learns *beyond unigram distribution* while training fast enough to iterate on a MacBook.

**This run:**
- Model: 4L, 128D, 4H (~2.1M params; the output projection reuses the token embedding matrix)
- Training: 5000 steps, lr=3e-4, wd=0.01
- Target: ~100 steps/minute

**Success criteria:**
1. Loss drops below 5.0 (better than Flannel's 5.25)
2. Top predictions include context-dependent tokens (not just comma/period)
3. h_mean shows direction changes (not converging to static point)
4. Dead tokens show motion → slowing pattern (approaching freeze)

**Data storage:**
- Live metrics recorded in memory every training step and written to a single CSV after training finishes
- Checkpoints saved every 500 steps (W as bfloat16, h_mean as float32)

## Parameters

In [1]:
# --- Reproducibility -------------------------------------------------------
RANDOM_SEED = 42

# --- Model architecture ----------------------------------------------------
VOCAB_SIZE = 10000
NUM_LAYERS = 4            # 2× Flannel
HIDDEN_DIM = 128          # 2× Flannel
NUM_HEADS = 4

# --- Training schedule and AdamW hyperparameters ---------------------------
TOTAL_STEPS = 5000
SEQ_LEN = 128
BATCH_SIZE = 128
LEARNING_RATE = 3e-4      # Standard LM lr
WEIGHT_DECAY = 0.01       # Gentle regularization
BETA1 = 0.9
BETA2 = 0.999
EPSILON = 1e-8

# --- Checkpointing ---------------------------------------------------------
# W and h_mean are snapshotted once every CHECKPOINT_INTERVAL optimizer steps.
CHECKPOINT_INTERVAL = 500

# --- Filesystem locations (relative to the notebook) -----------------------
OUTPUT_DIR = '../../tensors/Marble-1'
DEAD_MASK_PATH = '../../tensors/Flannel/live_dead_tokens.safetensors'
TOKENIZER_PATH = '../../data/flannel_tokenizer_chars.json'
CORPUS_PATH = '../../data/flannel_model_corpus.txt'

# Echo the headline configuration so it is captured in the cell output.
print(f"Marble 1: {NUM_LAYERS}L, {HIDDEN_DIM}D, {NUM_HEADS}H")
print(f"Training: {TOTAL_STEPS} steps, lr={LEARNING_RATE}, wd={WEIGHT_DECAY}")

Marble 1: 4L, 128D, 4H
Training: 5000 steps, lr=0.0003, wd=0.01


## Imports

In [2]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from tokenizers import Tokenizer
from safetensors.torch import save_file, load_file
from pathlib import Path
from tqdm import tqdm
import json
import random
import numpy as np
import pandas as pd
from datetime import datetime

## Device Detection

In [3]:
# Backend preference order: CUDA, then Apple MPS, then plain CPU.
# The conditional expression is evaluated left to right, so MPS is only
# probed when CUDA is unavailable.
device = (
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)

print(f"Using device: {device}")

Using device: mps


## Set Random Seeds

In [4]:
# Seed the three RNG sources this notebook imports (torch, random, numpy),
# in the same order as before, all with the shared RANDOM_SEED.
for seed_rng in (torch.manual_seed, random.seed, np.random.seed):
    seed_rng(RANDOM_SEED)

# When running on CUDA, additionally seed every CUDA device.
if device == 'cuda':
    torch.cuda.manual_seed_all(RANDOM_SEED)

print(f"Random seed set to {RANDOM_SEED}")

Random seed set to 42


## Load Tokenizer and Dead Token Mask

In [5]:
# Tokenizer is restored from its serialized JSON definition.
tokenizer = Tokenizer.from_file(TOKENIZER_PATH)
print(f"Loaded tokenizer with vocab size {tokenizer.get_vocab_size()}")

# The safetensors file supplies a per-token boolean "dead" flag plus the
# matching index list; the live mask is simply its complement.
masks = load_file(DEAD_MASK_PATH)
dead_indices = masks['dead_indices'].long()
dead_mask = masks['dead_mask'].bool()
live_mask = dead_mask.logical_not()

# Token counts as plain Python numbers (reused later in the metadata).
n_dead = dead_mask.sum().item()
n_live = live_mask.sum().item()

print(f"Dead tokens: {n_dead}")
print(f"Live tokens: {n_live}")

Loaded tokenizer with vocab size 10000
Dead tokens: 3699
Live tokens: 6301


## Dataset

In [6]:
class TextDataset(Dataset):
    """Fixed-length, non-overlapping token windows cut from one text file.

    The whole corpus is read and tokenized once at construction time.
    Window ``i`` covers tokens ``[i * seq_len, (i + 1) * seq_len)``; any
    tokens left over after the last full window are never served.

    Args:
        corpus_path: Path of a UTF-8 text file.
        tokenizer: Object whose ``encode(text)`` result exposes ``.ids``.
        seq_len: Number of tokens per window.
    """

    def __init__(self, corpus_path, tokenizer, seq_len):
        raw_text = Path(corpus_path).read_text(encoding='utf-8')
        self.seq_len = seq_len
        self.tokens = tokenizer.encode(raw_text).ids
        # Floor division: only complete windows are counted.
        self.n_sequences = len(self.tokens) // seq_len

    def __len__(self):
        """Number of complete windows available."""
        return self.n_sequences

    def __getitem__(self, idx):
        """Return window ``idx`` as a 1-D ``torch.long`` tensor."""
        begin = idx * self.seq_len
        window = self.tokens[begin:begin + self.seq_len]
        return torch.tensor(window, dtype=torch.long)

# Tokenize the corpus once and cut it into SEQ_LEN-token windows.
dataset = TextDataset(corpus_path=CORPUS_PATH, tokenizer=tokenizer, seq_len=SEQ_LEN)

# Whole batches per pass over the data (a partial final batch is not counted).
EPOCH_LENGTH = len(dataset) // BATCH_SIZE

print(f"Dataset: {len(dataset)} sequences of length {SEQ_LEN}")
print(f"Epoch length: ~{EPOCH_LENGTH} steps")

Dataset: 10713 sequences of length 128
Epoch length: ~83 steps


## Model Definition

**Architecture:** 4L, 128D, 4H  
**Compute dtype:** bfloat16 (like Qwen)  
**Storage dtype:** bfloat16 for W checkpoints, float32 for metrics

In [7]:
class TinyLM(nn.Module):
    """Small causal language model with a weight-tied output projection.

    Token and learned position embeddings are summed, passed through a
    pre-norm (``norm_first=True``) transformer encoder stack under a causal
    mask, and normalised by a final LayerNorm. Logits are obtained by
    multiplying the final hidden state with the transposed token embedding
    matrix, so there is no separate output layer.

    Args:
        vocab_size: Number of rows in the token embedding table.
        hidden_dim: Model width.
        num_layers: Number of encoder layers.
        num_heads: Attention heads per layer.
        seq_len: Number of learned positions (longest supported input).
    """

    def __init__(self, vocab_size, hidden_dim, num_layers, num_heads, seq_len):
        super().__init__()

        # Sub-modules are created in the same order as before so that a
        # seeded run draws its initial weights in the same sequence.
        self.embedding = nn.Embedding(vocab_size, hidden_dim)
        self.pos_embedding = nn.Embedding(seq_len, hidden_dim)

        layer_config = dict(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * 4,
            dropout=0.0,
            activation='gelu',
            batch_first=True,
            norm_first=True,
        )
        block = nn.TransformerEncoderLayer(**layer_config)
        self.transformer = nn.TransformerEncoder(block, num_layers=num_layers)
        self.ln_f = nn.LayerNorm(hidden_dim)

        # Re-initialise both embedding tables from N(0, 0.02²); token table
        # first, position table second.
        for table in (self.embedding, self.pos_embedding):
            nn.init.normal_(table.weight, mean=0.0, std=0.02)

    def forward(self, input_ids, return_hidden=False):
        """Run the model on a 2-D batch of token ids.

        Args:
            input_ids: Integer tensor of shape (batch, seq).
            return_hidden: When true, also return the post-LayerNorm hidden
                state ``h``.

        Returns:
            ``logits`` of shape (batch, seq, vocab), or ``(logits, h)`` with
            ``h`` of shape (batch, seq, hidden) when ``return_hidden`` is set.
        """
        _, n_positions = input_ids.shape
        where = input_ids.device

        # (1, seq) position ids broadcast across the batch dimension.
        positions = torch.arange(n_positions, device=where).unsqueeze(0)
        x = self.embedding(input_ids) + self.pos_embedding(positions)

        future_mask = nn.Transformer.generate_square_subsequent_mask(n_positions, device=where)
        x = self.transformer(x, mask=future_mask, is_causal=True)
        final_hidden = self.ln_f(x)  # This is h

        # Tied unembedding: reuse the token embedding matrix.
        logits = torch.matmul(final_hidden, self.embedding.weight.T)

        return (logits, final_hidden) if return_hidden else logits

# Build the network from the notebook-level hyperparameters.
model = TinyLM(
    vocab_size=VOCAB_SIZE,
    hidden_dim=HIDDEN_DIM,
    num_layers=NUM_LAYERS,
    num_heads=NUM_HEADS,
    seq_len=SEQ_LEN,
)

# Move to the selected device first, then cast every parameter to bfloat16
# (train in bfloat16 like Qwen).
model = model.to(device)
model = model.to(torch.bfloat16)

# Total number of scalar parameters.
n_params = sum(weight.numel() for weight in model.parameters())
print(f"Model parameters: {n_params:,} ({n_params/1e6:.2f}M)")

Model parameters: 2,089,728 (2.09M)




## Optimizer

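AdamW over the model's bfloat16 parameters. Note that PyTorch allocates the moment buffers with the same dtype as the parameters they track, so the optimizer states are bfloat16 here rather than float32; keeping full-precision states would require float32 master weights.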

In [8]:
# AdamW hyperparameters gathered in one place, then applied to every
# parameter of the model.
adamw_settings = {
    'lr': LEARNING_RATE,
    'betas': (BETA1, BETA2),
    'eps': EPSILON,
    'weight_decay': WEIGHT_DECAY,
}
optimizer = torch.optim.AdamW(model.parameters(), **adamw_settings)

print(f"Optimizer: AdamW (lr={LEARNING_RATE}, wd={WEIGHT_DECAY})")

Optimizer: AdamW (lr=0.0003, wd=0.01)


## Metrics and Checkpointing Setup

In [9]:
# Make sure the destination directory exists before anything is written.
output_path = Path(OUTPUT_DIR)
output_path.mkdir(parents=True, exist_ok=True)

# One dict per training step is appended here; the list is turned into a
# CSV after training.
metrics_log = []

# In-memory checkpoint store, filled every CHECKPOINT_INTERVAL steps:
#   'steps'  -> step numbers at which a snapshot was taken
#   'W'      -> embedding matrices (bfloat16)
#   'h_mean' -> mean hidden-state vectors (float32)
checkpoints = {field: [] for field in ('steps', 'W', 'h_mean')}

print(f"Output directory: {output_path}")
print(f"Checkpoints every {CHECKPOINT_INTERVAL} steps")

Output directory: ../../tensors/Marble-1
Checkpoints every 500 steps


## Helper Functions

### Lattice Displacement (for dead token motion detection)

In [10]:
def compute_ulp_bf16(tensor_bf16):
    """Compute the ULP (Unit in the Last Place) of each bfloat16 element.

    bfloat16 has an 8-bit biased exponent E (bias 127) and a 7-bit mantissa,
    so the spacing between neighbouring representable values is
    ``2^(E - 127 - 7) = 2^(E - 134)``.

    Args:
        tensor_bf16: Tensor of dtype ``torch.bfloat16`` (any shape).

    Returns:
        float32 tensor of the same shape holding the ULP of each element.
        Zeros and subnormals (E == 0) use the E == 1 spacing.

    Raises:
        TypeError: If the input is not bfloat16. Reinterpreting the bits of
            any other dtype would silently yield wrong values (and, for
            wider dtypes, a wrong shape).
    """
    if tensor_bf16.dtype != torch.bfloat16:
        raise TypeError(
            f"compute_ulp_bf16 expects a torch.bfloat16 tensor, got {tensor_bf16.dtype}"
        )

    # Reinterpret the 16 raw bits as a signed integer and widen to int32.
    # int16 is used instead of uint16 because unsigned 16-bit tensors have
    # only limited operator support in PyTorch. The sign extension this
    # introduces is removed by the 0xFF mask below.
    bits = tensor_bf16.view(torch.int16).to(torch.int32)

    # Bit layout: [sign:1][exponent:8][mantissa:7] -> exponent is bits 7..14.
    exponent = (bits >> 7) & 0xFF

    # Subnormals: treat E=0 as E=1 (they share the smallest normal spacing).
    effective_exp = exponent.clamp(min=1)

    return torch.pow(2.0, (effective_exp - 134).float())

def compute_lattice_displacement(W_before, W_after):
    """Express a weight change in units of the bfloat16 grid spacing.

    Computes ΔW′ = (W_after − W_before) / ULP element-wise, where the ULP is
    taken from ``W_before`` only, i.e. the spacing at the starting value.

    Args:
        W_before: bfloat16 weights at time t-1
        W_after: bfloat16 weights at time t

    Returns:
        float32 tensor of lattice displacements, one per weight element.
    """
    grid_spacing = compute_ulp_bf16(W_before)
    # Subtract in float32 so the difference itself is not rounded to bfloat16.
    raw_change = W_after.float() - W_before.float()
    return raw_change / grid_spacing

def fraction_moving(W_before, W_after, mask):
    """Share of the selected tokens whose embedding changed at all.

    A token counts as moving when the L1 norm of its lattice displacement is
    strictly positive, i.e. at least one of its dimensions moved.

    Args:
        W_before, W_after: bfloat16 weight matrices, one row per token.
        mask: boolean mask selecting the token rows to measure.

    Returns:
        float: fraction of masked tokens that moved.
    """
    rows_before, rows_after = W_before[mask], W_after[mask]
    lattice_steps = compute_lattice_displacement(rows_before, rows_after)

    # Per-token L1 norm: sum of absolute displacements across dimensions.
    per_token_l1 = torch.sum(torch.abs(lattice_steps), dim=1)
    has_moved = torch.gt(per_token_l1, 0)
    return has_moved.float().mean().item()

# Visual confirmation that this helper-definition cell ran to completion.
print("✓ Lattice displacement functions defined")

✓ Lattice displacement functions defined


## Training Loop

**Live metrics computed every training step:**
1. Loss
2. Top-3 predicted tokens (from h_mean)
3. h_mean autocorrelation (cosine similarity with the previous step's h_mean)
4. Fraction of dead tokens moving (L1 > 0)

**Checkpoints saved every 500 steps:**
- W (bfloat16)
- h_mean (float32)

In [11]:
def _snapshot_embedding_bf16(lm):
    """Return an independent CPU bfloat16 copy of the token embedding matrix.

    ``copy=True`` is essential. When the model already lives on the CPU in
    bfloat16, ``detach()``, ``cpu()`` and ``to(bfloat16)`` all return tensors
    that share storage with the live parameter. A "previous" snapshot taken
    that way would silently follow every in-place optimizer update, which
    makes the dead-token motion metric read 0 and overwrites the t=0
    checkpoint with later weights.
    """
    return lm.embedding.weight.detach().to(device='cpu', dtype=torch.bfloat16, copy=True)


dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
if TOTAL_STEPS > 0 and len(dataloader) == 0:
    # With drop_last=True a dataset smaller than one batch yields nothing,
    # and the `while` loop below would never advance `step`.
    raise ValueError(
        f"Dataset has {len(dataset)} sequences, fewer than BATCH_SIZE={BATCH_SIZE}; "
        "no training batches would be produced."
    )
loss_fn = nn.CrossEntropyLoss()

model.train()
step = 0
epoch = 0
loss_value = float('nan')  # Stays NaN only if no step is executed

# Previous step's data for metrics
h_mean_prev = None
W_prev = None

# Initial checkpoint (t=0)
W_init = _snapshot_embedding_bf16(model)
checkpoints['steps'].append(0)
checkpoints['W'].append(W_init)
checkpoints['h_mean'].append(torch.zeros(HIDDEN_DIM))  # Placeholder for t=0

# NOTE(review): on 'mps' this selects the CPU autocast context, exactly as
# the original expression did. Whether that has any effect on MPS kernels
# should be confirmed; it is kept unchanged to avoid altering numerics.
autocast_device = device if device != 'mps' else 'cpu'

print(f"Starting training for {TOTAL_STEPS} steps...")

# The context manager closes the progress bar even if a step raises.
with tqdm(total=TOTAL_STEPS, desc="Training") as pbar:
    while step < TOTAL_STEPS:
        epoch += 1

        for batch in dataloader:
            if step >= TOTAL_STEPS:
                break

            input_ids = batch.to(device)

            # Forward pass: next-token prediction, so position i is scored
            # against token i+1.
            with torch.autocast(device_type=autocast_device, dtype=torch.bfloat16):
                logits, h = model(input_ids, return_hidden=True)
                shift_logits = logits[:, :-1, :].contiguous()
                shift_labels = input_ids[:, 1:].contiguous()
                loss = loss_fn(shift_logits.view(-1, VOCAB_SIZE), shift_labels.view(-1))

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            step += 1

            # Read the loss once; every .item() forces a device sync.
            loss_value = loss.item()

            # === Compute metrics ===
            # W is sampled after the optimizer step; h comes from the forward
            # pass that preceded it.
            W_current_bf16 = _snapshot_embedding_bf16(model)
            W_current_f32 = W_current_bf16.float()
            # Detach so stored metrics do not carry autograd history.
            h_mean_current = h.detach().mean(dim=(0, 1)).cpu().float()  # (D,)

            # Top-3 tokens (unembedding logits from h_mean)
            logits_from_h = h_mean_current @ W_current_f32.T
            top3_tokens = torch.topk(logits_from_h, k=3).indices.tolist()

            # h_mean autocorrelation: cosine similarity with the previous step
            if h_mean_prev is not None:
                h_autocorr = torch.dot(h_mean_current, h_mean_prev) / \
                            (torch.norm(h_mean_current) * torch.norm(h_mean_prev))
                h_autocorr = h_autocorr.item()
            else:
                h_autocorr = float('nan')

            # Dead token motion relative to the previous step's snapshot
            if W_prev is not None:
                dead_moving = fraction_moving(W_prev, W_current_bf16, dead_mask)
            else:
                dead_moving = float('nan')

            # Log metrics
            metrics_log.append({
                'step': step,
                'epoch': epoch,
                'loss': loss_value,
                'top1': top3_tokens[0],
                'top2': top3_tokens[1],
                'top3': top3_tokens[2],
                'h_autocorr': h_autocorr,
                'dead_moving': dead_moving,
            })

            # Checkpoint if needed. Both tensors are fresh per-step copies
            # that are never modified in place, so no extra clone is needed.
            if step % CHECKPOINT_INTERVAL == 0:
                checkpoints['steps'].append(step)
                checkpoints['W'].append(W_current_bf16)
                checkpoints['h_mean'].append(h_mean_current)

            # Update previous state
            h_mean_prev = h_mean_current
            W_prev = W_current_bf16

            # Progress bar
            pbar.set_postfix({
                'loss': f'{loss_value:.4f}',
                'dead_moving': f'{dead_moving:.3f}' if not np.isnan(dead_moving) else 'N/A',
                'epoch': epoch
            })
            pbar.update(1)

print(f"\nTraining complete. Final loss: {loss_value:.4f}")
print(f"Completed {epoch} epochs")

Starting training for 5000 steps...


Training: 100%|██████████| 5000/5000 [09:15<00:00,  8.99it/s, loss=5.4688, dead_moving=0.075, epoch=61]


Training complete. Final loss: 5.4688
Completed 61 epochs





## Save Metrics CSV

In [12]:
# One row per logged training step; the column set comes from the dict keys.
csv_path = output_path / 'metrics.csv'
df = pd.DataFrame(metrics_log)
df.to_csv(csv_path, index=False)  # index=False: the 'step' column already identifies rows

column_names = list(df.columns)
print(f"Saved metrics to {csv_path}")
print(f"  Rows: {len(df)}")
print(f"  Columns: {column_names}")

Saved metrics to ../../tensors/Marble-1/metrics.csv
  Rows: 5000
  Columns: ['step', 'epoch', 'loss', 'top1', 'top2', 'top3', 'h_autocorr', 'dead_moving']


## Save Checkpoints

In [13]:
# Stack checkpoints into tensors
checkpoint_steps = torch.tensor(checkpoints['steps'], dtype=torch.int64)
checkpoint_W = torch.stack(checkpoints['W'])  # (n_checkpoints, vocab, dim) bfloat16
checkpoint_h_mean = torch.stack(checkpoints['h_mean'])  # (n_checkpoints, dim) float32

checkpoint_data = {
    'steps': checkpoint_steps,
    'W': checkpoint_W.view(torch.uint16),  # Store bfloat16 as uint16
    'h_mean': checkpoint_h_mean,
    'dead_mask': dead_mask,
    'dead_indices': dead_indices,
}

checkpoint_path = output_path / 'checkpoints.safetensors'
save_file(checkpoint_data, str(checkpoint_path))

print(f"Saved checkpoints to {checkpoint_path}")
print(f"  {len(checkpoints['steps'])} checkpoints")
print(f"  File size: {checkpoint_path.stat().st_size / 1e6:.2f} MB")

Saved checkpoints to ../../tensors/Marble-1/checkpoints.safetensors
  11 checkpoints
  File size: 28.21 MB


## Save Metadata

In [14]:
# Everything needed to identify and reproduce this run, in a fixed key order.
metadata = dict(
    # Identity
    experiment='Marble 1',
    series='Marble',
    date=datetime.now().strftime('%Y-%m-%d'),
    # Training hyperparameters
    total_steps=TOTAL_STEPS,
    batch_size=BATCH_SIZE,
    seq_len=SEQ_LEN,
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    beta1=BETA1,
    beta2=BETA2,
    epsilon=EPSILON,
    # Architecture
    vocab_size=VOCAB_SIZE,
    hidden_dim=HIDDEN_DIM,
    num_layers=NUM_LAYERS,
    num_heads=NUM_HEADS,
    n_params=n_params,
    # Run facts
    random_seed=RANDOM_SEED,
    n_dead_tokens=n_dead,
    n_live_tokens=n_live,
    final_loss=metrics_log[-1]['loss'],
    total_epochs=epoch,
    epoch_length=EPOCH_LENGTH,
    device=device,
    checkpoint_interval=CHECKPOINT_INTERVAL,
    n_checkpoints=len(checkpoints['steps']),
    notes='4L, 128D, 4H model. Testing if larger architecture learns beyond unigram.',
)

metadata_path = output_path / 'metadata.json'
with open(metadata_path, 'w') as metadata_file:
    json.dump(metadata, metadata_file, indent=2)

print("Saved metadata.json")
print(f"\n{'='*60}")
print("MARBLE 1 COMPLETE")
print(f"{'='*60}")

Saved metadata.json

MARBLE 1 COMPLETE
