In [1]:
from itertools import cycle

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from rich.console import Console, Group
from rich.live import Live
from rich.panel import Panel
from rich.progress import Progress, SpinnerColumn, BarColumn, TextColumn, TimeRemainingColumn
from rich.table import Table
from torch.utils.data import DataLoader, Dataset

from model.holo import HoloConfig, HoloForCausalLM

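We keep the model at 2 layers (`num_hidden_layers = 2`) for the induction-head test, the smallest depth at which the two-step induction circuit can form:
- Layer 1: copies information about each previous token into the current position.
- Layer 2: uses the information written by layer 1 to find the earlier occurrence of the current token and predict the token that followed it.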

In [2]:
class InductionDataset(Dataset):
    """
    Synthetic induction-head task: each sample is one random segment repeated twice.

    Generates data in the pattern: [A B ... A] -> Label: [B]. Once a token reappears
    in the second half, its successor can be predicted by looking up what followed
    it in the first half.

    Args:
        size: value reported by ``__len__``. ``__getitem__`` ignores ``idx`` and
            draws a fresh random segment on every call.
        seq_len: requested sequence length. Each half is ``seq_len // 2`` tokens,
            so an odd ``seq_len`` yields sequences of length ``seq_len - 1``.
            Must be >= 4 so that at least one target survives masking.
        vocab_size: tokens are drawn from ``[0, vocab_size)``.
        unique_tokens: if True, the segment has no repeated tokens, so every
            in-context lookup is unambiguous. Requires ``seq_len // 2 <= vocab_size``.
            Defaults to False (sampling with replacement, the original behavior).

    ``__getitem__`` returns ``(input_ids, targets)``: two int64 tensors of length
    ``2 * (seq_len // 2)``. ``targets`` is -100 at every position the loss should ignore.
    """
    def __init__(self, size = 1000, seq_len = 30, vocab_size = 1000, unique_tokens = False):
        if seq_len < 4:
            raise ValueError(f"seq_len must be >= 4 to leave any supervised target, got {seq_len}")
        if unique_tokens and seq_len // 2 > vocab_size:
            raise ValueError(
                f"unique_tokens=True needs seq_len // 2 <= vocab_size, "
                f"got {seq_len // 2} > {vocab_size}"
            )
        self.size = size
        self.seq_len = seq_len
        self.vocab_size = vocab_size
        self.unique_tokens = unique_tokens

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        half_len = self.seq_len // 2
        if self.unique_tokens:
            random_segment = torch.randperm(self.vocab_size)[:half_len]
        else:
            random_segment = torch.randint(0, self.vocab_size, (half_len, ))

        input_ids = torch.cat([random_segment, random_segment], dim = 0)
        targets = input_ids.clone()

        # -100 is the ignore_index of torch's cross-entropy loss.
        # This assumes (as the model is documented to do, HF-style) that
        # HoloForCausalLM shifts labels by one internally. So targets[t] is
        # predicted from input_ids[t - 1] and targets[0] is never used. TODO confirm.
        #   - targets[:half_len]: the whole first half is random, so no prediction is possible.
        #   - targets[half_len] (= segment[0]) is predicted from input_ids[half_len - 1],
        #     the last token of the first half. That token generally has no earlier
        #     occurrence to copy a successor from (the "boundary jump"). It must be
        #     masked as well, hence the `+ 1`; otherwise it adds an irreducible loss term.
        targets[: half_len + 1] = -100
        return input_ids, targets

In [7]:
console = Console()
device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed once, before the model is built. torch.manual_seed seeds the CPU and all
# CUDA generators, so weight init and the synthetic data are reproducible.
SEED = 0
torch.manual_seed(SEED)

LR = 1e-3
SEQ_LEN = 256      # total sequence length; the repeated segment is SEQ_LEN // 2 tokens
BATCH_SIZE = 32
VOCAB_SIZE = 200
STEPS = 1000       # maximum number of optimizer steps (training may stop early)


In [8]:
# Deliberately tiny model: 2 hidden layers (the two-step induction circuit), d_model=128, 8 heads.
induction_config = HoloConfig(
    num_heads=8,
    d_model=128,
    num_hidden_layers=2,
    vocab_size=VOCAB_SIZE,
    holo_expansion_ratio=2,
    expansion_factor=2,
)
model = HoloForCausalLM(induction_config)

In [9]:
ds = InductionDataset(seq_len = SEQ_LEN, vocab_size = VOCAB_SIZE)

# NOTE(review): with the "spawn" multiprocessing start method (the default on
# macOS/Windows), workers may fail to unpickle a Dataset class defined inside the
# notebook. Set num_workers=0 if the loader errors out.
# persistent_workers keeps the workers alive across passes of the infinite iterator below.
ds_loader = DataLoader(ds, batch_size = BATCH_SIZE, num_workers = 4, persistent_workers = True)


def infinite_batches(loader):
    """Yield batches from `loader` forever, starting a fresh pass whenever it is exhausted.

    `itertools.cycle` would instead store the first pass and replay those exact
    batches, so the model would see the same ~1000 sequences over and over.
    Re-iterating the loader calls the dataset's `__getitem__` again, which draws
    new random segments.
    """
    while True:
        yield from loader


iterator = infinite_batches(ds_loader)
model.to(device)

optimizer = optim.AdamW(model.parameters(), lr=LR)

model.train()

# Training / reporting thresholds
LOG_EVERY = 50               # add a table row every N steps (and at step 1)
GRAD_CLIP_NORM = 1.0
SOLVED_STATUS_LOSS = 0.01    # logged losses below this are labelled SOLVED
CONVERGING_STATUS_LOSS = 1.0
EARLY_EXIT_LOSS = 0.005      # stop as soon as a logged loss drops below this


def loss_status(loss_val):
    """Return the rich-markup status label shown in the metrics table for `loss_val`."""
    if loss_val < SOLVED_STATUS_LOSS:
        return "[bold green]SOLVED[/bold green]"
    if loss_val < CONVERGING_STATUS_LOSS:
        return "[yellow]Converging[/yellow]"
    return "[red]High[/red]"


# Create the layout elements
progress = Progress(
    SpinnerColumn(),
    TextColumn("[bold blue]{task.description}"),
    BarColumn(),
    TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
    TimeRemainingColumn()
)
task_id = progress.add_task("[green]Training...", total=STEPS)

table = Table(title="Induction Head Metrics")
table.add_column("Step", justify="right", style="cyan")
table.add_column("Loss", justify="right", style="magenta")
table.add_column("Status", justify="center")

# Render the progress bar and the metrics table as ONE live display.
# Calling progress.start() inside `Live` would open a second live display.
# Depending on the rich version, that either raises LiveError or renders outside the panel.
# Both renderables are mutable, so Live picks up new rows and progress on each refresh.
dashboard = Group(progress, Panel(table, title="Real-time Stats"))

solved = False
with Live(dashboard, console=console, refresh_per_second=10) as live:
    for step in range(1, STEPS + 1):
        # Data
        input_ids, labels = next(iterator)
        input_ids, labels = input_ids.to(device), labels.to(device)

        # Forward / Backward. The model computes the loss from `labels` itself.
        outputs = model(input_ids=input_ids, labels=labels)
        loss = outputs.loss

        optimizer.zero_grad()
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP_NORM)
        optimizer.step()

        progress.update(task_id, advance=1)

        # Only pull the loss to the host when logging (.item() forces a device sync).
        if step % LOG_EVERY == 0 or step == 1:
            loss_val = loss.item()
            table.add_row(str(step), f"{loss_val:.4f}", loss_status(loss_val))
            live.refresh()

            if loss_val < EARLY_EXIT_LOSS:
                solved = True
                break

# Report once, after the live display has closed, so the message is not interleaved with it.
final_loss = loss.item()
if solved:
    console.print(f"\n[bold green]✅ SUCCESS! Loss dropped to {final_loss:.5f} at step {step}. Induction Heads active.[/bold green]")
else:
    console.print(f"\n[bold red]❌ Test Failed. Loss did not converge to zero (final loss {final_loss:.5f}).[/bold red]")