In [1]:
import sys

sys.path.append("..")

In [2]:
import torch
import torch.nn as nn 
from torch.utils.data import Dataset, DataLoader
from model.holo import HoloConfig, HoloForCausalLM
from model.long import LongConfig, LongForCausalLM
from transformers import MambaConfig, MambaForCausalLM
from transformers import GPT2Config, GPT2LMHeadModel

import torch.optim as optim
from tqdm.notebook import tqdm 
import numpy as np
import torch.nn.utils.rnn as rnn_utils
import matplotlib.pyplot as plt

In [3]:
import torch
from torch.utils.data import Dataset

class NIAHDataset(Dataset):
    """Synthetic needle-in-a-haystack dataset.

    Every sample is generated on the fly with the layout::

        [start (2)] [noise ...] [flag] [key] [noise ...] [trigger (4)] [key]

    The ``key`` that follows the flag token must be reproduced after the
    trigger tokens. Labels are ``-100`` everywhere except on the trigger
    tokens and the final key.

    All randomness comes from the torch RNG, so ``torch.manual_seed`` makes
    samples reproducible.

    Args:
        size (int): Number of samples reported by ``__len__``.
        min_len (int): Smallest total sequence length to draw (inclusive).
        max_len (int): Largest total sequence length to draw (inclusive).
        vocab_size (int): Noise and key tokens are drawn from ``[10, vocab_size)``.

    Raises:
        ValueError: If ``min_len > max_len`` or ``vocab_size <= 10``.
    """

    def __init__(self, size, min_len, max_len, vocab_size):
        if min_len > max_len:
            raise ValueError(f"min_len ({min_len}) must be <= max_len ({max_len})")
        if vocab_size <= 10:
            raise ValueError(f"vocab_size ({vocab_size}) must be > 10")

        self.size = size
        self.min_len = min_len
        self.max_len = max_len
        self.vocab_size = vocab_size

        # Special Tokens
        self.start_tokens = torch.tensor([3, 4])
        self.trigger_tokens = torch.tensor([5, 4, 3, 6])  # e.g. "The answer is"
        self.flag_token = torch.tensor([2])

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        """Build one random sample; ``idx`` is ignored.

        Returns:
            dict: ``{"input_ids": LongTensor[L], "labels": LongTensor[L]}``.
        """
        # 1. Randomize the length of this sample. torch.randint's upper bound
        #    is exclusive, hence the +1 to keep max_len reachable.
        curr_len = torch.randint(self.min_len, self.max_len + 1, (1,)).item()

        # 2. Generate the key (the needle)
        key = torch.randint(10, self.vocab_size, (1,))

        # 3. Fill the remainder with noise.
        #    Overhead: start + flag + key + trigger + target (= 9 with the defaults)
        trigger_len = len(self.trigger_tokens)
        overhead = len(self.start_tokens) + len(self.flag_token) + 1 + trigger_len + 1
        noise_len = max(curr_len - overhead, 0)  # requests shorter than the overhead get no noise
        noise = torch.randint(10, self.vocab_size, (noise_len,))

        # 4. Insert the flag + key at a random position inside the noise
        insert_idx = torch.randint(0, noise_len + 1, (1,)).item()

        input_ids = torch.cat([
            self.start_tokens,
            noise[:insert_idx],
            self.flag_token,     # The Flag
            key,                 # The Needle
            noise[insert_idx:],
            self.trigger_tokens,
            key
        ])

        # Supervise only the trigger tokens and the final key; everything
        # before them is masked out with -100.
        labels = torch.full_like(input_ids, -100)
        labels[-(trigger_len + 1):] = input_ids[-(trigger_len + 1):]

        return {"input_ids": input_ids, "labels": labels}



In [4]:
# Small configuration used to sanity-check the dataset before real training
SEQ_LEN = 20
VOCAB_SIZE = 1000

train_dataset = NIAHDataset(
    size=50000,
    min_len=10,
    max_len=30,
    vocab_size=VOCAB_SIZE,
)

In [5]:
# Pull one sample and show its shapes and contents next to the label mask
examples = train_dataset[0]
input_ids = examples['input_ids']
labels = examples['labels']

print(
    f"Shape of the input_ids: {input_ids.shape}",
    f"Shape of the labels: {labels.shape}",
    "\n",
    f"Input_ids: {input_ids}",
    f"Labels: {labels}",
    sep="\n",
)

Shape of the input_ids: torch.Size([15])
Shape of the labels: torch.Size([15])


Input_ids: tensor([  3,   4, 127, 159, 903, 832, 895,   2, 168, 461,   5,   4,   3,   6,
        168])
Labels: tensor([-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,    5,    4,
           3,    6,  168])


In [6]:
def pad_collate_fn(batch):
    """Collate variable-length samples into right-padded batch tensors.

    Args:
        batch (list[dict]): Samples with 1-D ``"input_ids"`` and ``"labels"`` tensors.

    Returns:
        dict: ``"input_ids"`` padded with 0 and ``"labels"`` padded with -100,
        both shaped [Batch, Seq_Len] where Seq_Len is the longest sample in
        the batch.
    """
    # (field name, value used to fill the padded tail)
    pad_values = (("input_ids", 0), ("labels", -100))

    collated = {}
    for field, fill in pad_values:
        sequences = [sample[field] for sample in batch]
        # batch_first=True -> [Batch, Seq_Len]
        collated[field] = rnn_utils.pad_sequence(
            sequences, batch_first=True, padding_value=fill
        )
    return collated

In [7]:
def calculate_accuracy(target_logits, labels, top_k=5):
    """
    Top-1 / top-k accuracy of the prediction for the last real token of each sequence.

    A position counts as "real" when its label is > 0. That excludes both the
    0 padding used for input_ids and negative ignore indices such as -100, so
    the function accepts either raw input_ids or masked labels. The last real
    position is located by index (not by counting real tokens), so masked or
    zero entries earlier in the sequence do not shift it.

    Args:
        target_logits (Tensor): [Batch, Seq_len, Vocab] model logits.
        labels (Tensor): [Batch, Seq_len] integer token ids.
        top_k (int): k for top-k accuracy; clamped to the vocab size.

    Returns:
        tuple[float, float]: (top-1 accuracy, top-k accuracy) averaged over the
        sequences whose last real token has a preceding position. Returns
        (0.0, 0.0) when there is no such sequence.
    """
    if labels.numel() == 0:
        return 0.0, 0.0  # Empty batch or zero-length sequences

    mask = labels > 0  # [Batch, Seq_len]

    # 1-based index of the last real token per row (0 when the row has none)
    positions = torch.arange(1, labels.size(1) + 1, device=labels.device)
    lengths = (mask.long() * positions).max(dim=1).values  # [Batch]

    # The logit at position t predicts token t+1, so the target needs a predecessor
    valid_mask = lengths >= 2  # [Batch]

    if not valid_mask.any():
        return 0.0, 0.0  # No valid sequences

    # Keep only the usable rows
    valid_indices = valid_mask.nonzero(as_tuple=False).squeeze(1)
    valid_lengths = lengths[valid_indices]
    valid_labels = labels[valid_indices]
    valid_logits = target_logits[valid_indices]

    # Target is the last real token; its prediction lives one step earlier
    target_pos = valid_lengths - 1  # [Valid_Batch]
    pred_pos = valid_lengths - 2    # [Valid_Batch]

    target_token = valid_labels.gather(dim=1, index=target_pos.unsqueeze(1)).squeeze(1)  # [Valid_Batch]

    batch_indices = torch.arange(len(valid_indices), device=labels.device)  # [Valid_Batch]
    key_logit = valid_logits[batch_indices, pred_pos, :]  # [Valid_Batch, Vocab]

    # --- Top-1 Accuracy ---
    pred_token = torch.argmax(key_logit, dim=-1)  # [Valid_Batch]
    top1 = (pred_token == target_token).float().mean().item()

    # --- Top-K Accuracy ---
    # torch.topk raises if k exceeds the vocab dimension, so clamp it
    k = min(top_k, key_logit.size(-1))
    _, top_k_indices = torch.topk(key_logit, k=k, dim=-1)  # [Valid_Batch, k]
    topk_acc = (top_k_indices == target_token.unsqueeze(1)).any(dim=1).float().mean().item()

    return top1, topk_acc

In [15]:
def train_eval_pipeline(model_name, model, train_loader, eval_seq_lens, vocab_size, device):
    """
    Train `model` for one pass over `train_loader`, then measure retrieval
    accuracy on freshly generated NIAH samples at each length in `eval_seq_lens`.

    Args:
        model_name (str): Name used in log lines and progress bars.
        model (nn.Module): HF-style model; called with `input_ids` (and `labels`
            during training) and expected to expose `.loss` and `.logits`.
        train_loader (DataLoader): Yields dict batches with "input_ids" and "labels".
        eval_seq_lens (list): Sequence lengths to evaluate (e.g. [256, 512, 1024]).
        vocab_size (int): Vocab size for the evaluation data generated on the fly.
        device (str): 'cuda' or 'cpu'.

    Returns:
        dict: {"lengths": [...], "acc": [...], "top5": [...]} with one entry
        per evaluated length, in the order of `eval_seq_lens`.
    """
    print(f"\n[{model_name}] Starting Training...")
    optimizer = optim.AdamW(model.parameters(), lr=5e-4, weight_decay=0.01)
    model.train()

    # --- TRAINING PHASE: a single pass over train_loader ---
    train_bar = tqdm(train_loader, desc=f"Training {model_name}", leave=True)

    for batch in train_bar:
        batch_inputs = batch["input_ids"].to(device)
        batch_labels = batch["labels"].to(device)

        # The loss is read from the model output when labels are supplied
        step_out = model(input_ids=batch_inputs, labels=batch_labels)
        loss = step_out.loss

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # cap gradient norm at 1.0
        optimizer.step()

        train_bar.set_postfix({"loss": f"{loss.item():.4f}"})

    # --- EVALUATION PHASE (Length Generalization) ---
    print(f"\n[{model_name}] Evaluating Length Generalization...")
    results = {"lengths": [], "acc": [], "top5": []}

    model.eval()

    for length in eval_seq_lens:
        # 100 throwaway samples, all exactly `length` tokens long (min_len == max_len)
        eval_loader = DataLoader(
            NIAHDataset(size=100, min_len=length, max_len=length, vocab_size=vocab_size),
            batch_size=1,
            shuffle=False,
        )

        top1_scores = []
        top5_scores = []

        # Per-length progress bar, removed once the length is done
        eval_bar = tqdm(eval_loader, desc=f"Eval L={length}", leave=False)

        with torch.no_grad():
            for batch in eval_bar:
                input_ids = batch["input_ids"].to(device)

                try:
                    # No labels here: only the logits are needed
                    output = model(input_ids=input_ids)

                    t1, t5 = calculate_accuracy(output.logits, input_ids)
                    top1_scores.append(t1)
                    top5_scores.append(t5)

                    # Drop the references right away so the memory can be reclaimed
                    del output, input_ids

                except RuntimeError as e:
                    # Anything other than an out-of-memory error is propagated
                    if "out of memory" not in str(e).lower():
                        raise e

                    print(f"OOM at length {length}!")
                    torch.cuda.empty_cache()
                    # Record a zero score for the failed sample, then abandon this length
                    top1_scores.append(0.0)
                    top5_scores.append(0.0)
                    break

        # Mean over the samples that were scored (0.0 if none were)
        avg_acc = np.mean(top1_scores) if top1_scores else 0.0
        avg_top5 = np.mean(top5_scores) if top5_scores else 0.0

        print(f"  Length {length}: Top-1={avg_acc:.1%} | Top-5={avg_top5:.1%}")

        results["lengths"].append(length)
        results["acc"].append(avg_acc)
        results["top5"].append(avg_top5)

        # Clean cache between lengths
        torch.cuda.empty_cache()

    return results


In [16]:
# Run on the GPU when one is visible, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Experiment configuration
VOCAB_SIZE = 1000
MIN_SEQ_LEN = 32    # shortest training sequence
MAX_SEQ_LEN = 256   # longest training sequence
BATCH_SIZE = 32

In [17]:
# Small dataset/loader pair used to inspect one padded batch
train_dataset = NIAHDataset(size=50000, min_len=10, max_len=30, vocab_size=VOCAB_SIZE)

train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=4,
    shuffle=True,
    num_workers=2,  # samples are built in two worker processes
    collate_fn=pad_collate_fn,
)

# Fetch a single collated batch and show the inputs followed by the labels
iter_loader = next(iter(train_loader))
inputs_ids = iter_loader['input_ids']
labels = iter_loader['labels']

print(inputs_ids, labels, sep="\n")

tensor([[  3,   4, 662, 767, 675, 214, 117, 862, 618, 832,   2, 168, 860, 122,
         873, 303, 889,  84, 450, 540, 673, 281, 576, 691, 986,   5,   4,   3,
           6, 168],
        [  3,   4, 987, 907, 288, 819, 282, 275, 253, 976,  80, 574, 384, 839,
         126,   2, 617, 845, 606, 701, 510, 774, 180,   5,   4,   3,   6, 617,
           0,   0],
        [  3,   4, 946, 661,  50, 392, 603, 346,  69, 379, 354, 704, 527, 850,
         631,   2, 402, 897, 101,   5,   4,   3,   6, 402,   0,   0,   0,   0,
           0,   0],
        [  3,   4, 333,   2, 586,   5,   4,   3,   6, 586,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0]])
tensor([[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100,    5,    4,    3,    6,  168],
        [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100

In [18]:
# Training set: each sample's length is drawn from [MIN_SEQ_LEN, MAX_SEQ_LEN]
train_dataset = NIAHDataset(
    size=200000,
    min_len=MIN_SEQ_LEN,
    max_len=MAX_SEQ_LEN,
    vocab_size=VOCAB_SIZE,
)

# Batches are padded to their longest sample by pad_collate_fn
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,  # samples are built in two worker processes
    collate_fn=pad_collate_fn,
)

# use_version=2 selects the newest LongAttention inside the HOLOBlock
# holo_config =  HoloConfig(
#     d_model=128, 
#     num_hidden_layers = 2,
#     num_heads = 4,
#     vocab_size=VOCAB_SIZE, 
#     resid_dropout = 0.0, 
#     dropout = 0.0,
#     use_version=2
# )
# holo_model = HoloForCausalLM(holo_config).to(device)

# Long-Attention model: 2 layers, width 128, 4 heads
long_config = LongConfig(
    vocab_size=VOCAB_SIZE,
    hidden_size=128,
    expansion_ratio=4,
    num_hidden_layers=2,
    num_heads=4,
)
long_model = LongForCausalLM(long_config).to(device)

# Mamba baseline with the same width and depth
config_kwargs = dict(
    vocab_size=VOCAB_SIZE,
    ssm_cfg=dict(dropout=0.0),
)
mamba_config = MambaConfig(hidden_size=128, num_hidden_layers=2, **config_kwargs)

print(long_model)
mamba_model = MambaForCausalLM(mamba_config).to(device)

LongForCausalLM(
  (long_model): LongModel(
    (wte): Embedding(1000, 128)
    (layers): ModuleList(
      (0-1): 2 x LongBlock(
        (attn): LongAttention(
          (q_proj): Linear(in_features=128, out_features=128, bias=False)
          (k_proj): Linear(in_features=128, out_features=128, bias=False)
          (v_proj): Linear(in_features=128, out_features=128, bias=False)
          (conv): Conv1d(128, 128, kernel_size=(4,), stride=(1,), padding=(3,), groups=128)
          (input_gate_proj): Linear(in_features=128, out_features=128, bias=True)
          (output_gate_proj): Linear(in_features=128, out_features=128, bias=True)
          (gamma_proj): Linear(in_features=128, out_features=4, bias=True)
          (o_proj): Linear(in_features=128, out_features=128, bias=False)
          (grp_norm): GroupNorm(4, 128, eps=1e-05, affine=True)
          (mem_norm): GroupNorm(4, 128, eps=1e-05, affine=True)
        )
        (mlp): LongMLP(
          (fc1): Linear(in_features=128, out_feat

In [12]:
# MAX_LENGTH = 8192  # CRITICAL for Needle-in-Haystack

# gpt2_config = GPT2Config(
#     vocab_size=VOCAB_SIZE,
#     # 1. Context Window
#     # GPT-2 has a HARD limit. You must set this >= your max haystack length.
#     n_positions=MAX_LENGTH, 
    
#     # 2. Dimensions (Matching your request)
#     n_embd=128,       # This is "hidden_size"
#     n_layer=2,        # This is "num_hidden_layers"
    
#     # 3. Heads
#     # n_embd (128) must be divisible by n_head. 
#     # 4 heads gives 32 dimension per head (standard).
#     n_head=4, 
    
#     # 4. Cleanup
#     bos_token_id=0,
#     eos_token_id=0,
    
#     # Optional: Disable dropout for pure algorithmic testing (like your Mamba config)
#     resid_pdrop=0.0,
#     embd_pdrop=0.0,
#     attn_pdrop=0.0,
#     use_cache=False # False for training with Gradient Checkpointing, True for generation
# )

# model_gpt2 = GPT2LMHeadModel(gpt2_config).to(device)

# results_gpt2 = train_eval_pipeline(
#         "GPT2-Attention", 
#         model_gpt2, 
#         train_loader, 
#         eval_seq_lens=[2**i for i in range(8, 14)],
#         vocab_size=VOCAB_SIZE, 
#         device=device
#     )

In [13]:
# from transformers import LlamaConfig, LlamaForCausalLM

# # 1. Use Llama Configuration (Uses RoPE by default)
# config_llama = LlamaConfig(
#     vocab_size=VOCAB_SIZE,
#     hidden_size=128,        # Match your dims
#     intermediate_size=512,  # MLP size (usually 4x hidden)
#     num_hidden_layers=2,    # Match your layers
#     num_attention_heads=4,  # Match your heads
#     max_position_embeddings=16384, # RoPE can handle this easily
    
#     # 2. Critical Modern Settings
#     hidden_act="silu",      # Better than GeLU
#     attention_bias=False,   # Flash Attention friendly
#     rms_norm_eps=1e-5,
    
#     # 3. Disable Dropout for pure algorithmic test
#     attention_dropout=0.0,
#     hidden_dropout=0.0,
# )

# # 4. Initialize
# model_llama = LlamaForCausalLM(config_llama).to(device)

# # 5. Run your pipeline
# results_llama = train_eval_pipeline(
#     "Llama-RoPE", 
#     model_llama, 
#     train_loader, 
#     eval_seq_lens=[2**i for i in range(8, 14)], # Up to 8k
#     vocab_size=VOCAB_SIZE, 
#     device=device
# )

In [None]:
# Train the Long-Attention model, then evaluate at lengths 2^8 ... 2^17
results_long = train_eval_pipeline(
    model_name="Long-Attention",
    model=long_model,
    train_loader=train_loader,
    eval_seq_lens=[2 ** exponent for exponent in range(8, 18)],
    vocab_size=VOCAB_SIZE,
    device=device,
)


[Long-Attention] Starting Training...


Training Long-Attention:   0%|          | 0/6250 [00:00<?, ?it/s]

In [None]:
# Holo model: 2 layers, width 128, 4 heads, dropout disabled
holo_config = HoloConfig(
    d_model=128,
    num_hidden_layers=2,
    num_heads=4,
    vocab_size=VOCAB_SIZE,
    resid_dropout=0.0,
    dropout=0.0,
    use_version=2,
)
holo_model = HoloForCausalLM(holo_config).to(device)

# Train it, then evaluate at lengths 2^8 ... 2^17
results_holo = train_eval_pipeline(
    model_name="Holo-Attention",
    model=holo_model,
    train_loader=train_loader,
    eval_seq_lens=[2 ** exponent for exponent in range(8, 18)],
    vocab_size=VOCAB_SIZE,
    device=device,
)

In [None]:
# Train the Mamba baseline, then evaluate at lengths 2^8 ... 2^17
results_mamba = train_eval_pipeline(
    model_name="Mamba",
    model=mamba_model,
    train_loader=train_loader,
    eval_seq_lens=[2 ** exponent for exponent in range(8, 18)],
    vocab_size=VOCAB_SIZE,
    device=device,
)

In [None]:
# Each results_* value is the dict returned by train_eval_pipeline:
# {"lengths": [...], "acc": [...], "top5": [...]}. Plot its "acc" list against
# its own "lengths" list so x and y always have matching sizes.
plot_series = [
    ("Mamba", results_mamba, "o"),
    ("Holo", results_holo, "s"),
]

plt.figure(figsize=(10, 6))

for series_label, series_results, series_marker in plot_series:
    plt.plot(
        series_results["lengths"],
        series_results["acc"],
        marker=series_marker,
        linestyle='-',
        linewidth=2,
        label=series_label,
    )

# Tick every length that was actually evaluated by any plotted model
eval_seq_lens = sorted(
    {length for _, series_results, _ in plot_series for length in series_results["lengths"]}
)

# Formatting the X-axis to be Logarithmic (essential for exponential lengths)
plt.xscale('log')
plt.xticks(eval_seq_lens, labels=eval_seq_lens)  # Set specific ticks for our lengths

# Labels and Title
plt.xlabel('Sequence Length', fontsize=12)
plt.ylabel('Accuracy', fontsize=12)
plt.title('Needle In A Haystack: Length Generalization', fontsize=14)
plt.ylim(-0.05, 1.05)  # Keep y-axis range clean
plt.grid(True, which="both", ls="-", alpha=0.2)
plt.legend(fontsize=12)

# Save the plot
plt.tight_layout()
plt.savefig('niah_results_plot.png')

### Play around