In [1]:
import torch
import torch.nn as nn 
from torch.utils.data import Dataset, DataLoader
from model.holo import HoloConfig, HoloForCausalLM
from transformers import MambaConfig, MambaForCausalLM
import torch.optim as optim
from tqdm.notebook import tqdm 
import numpy as np
import torch.nn.utils.rnn as rnn_utils

In [2]:
import torch
from torch.utils.data import Dataset

class NIAHDataset(Dataset):
    def __init__(self, size, min_len, max_len, vocab_size):
        self.size = size
        self.min_len = min_len
        self.max_len = max_len
        self.vocab_size = vocab_size
        
        # Special Tokens
        self.start_tokens = torch.tensor([3, 4])
        self.trigger_tokens = torch.tensor([5, 4, 3, 6])
        self.flag_token = torch.tensor([2]) 

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        # 1. Randomize Length for this specific sample
        # This prevents the model from overfitting to a specific horizon
        curr_len = np.random.randint(self.min_len, self.max_len + 1)
        
        # 2. Generate Key
        key = torch.randint(10, self.vocab_size, (1,))

        # 3. Calculate Noise
        # Overhead: Start(2) + Flag(1) + Key(1) + Trigger(4) + Target(1) = 9
        noise_len = curr_len - 9
        if noise_len < 0: noise_len = 0 # Safety clipping

        noise = torch.randint(10, self.vocab_size, (noise_len,))

        # 4. Insert Key Randomly
        insert_idx = torch.randint(0, noise_len + 1, (1,)).item()
        
        input_ids = torch.cat([
            self.start_tokens,
            noise[:insert_idx],
            self.flag_token,     # The Flag
            key,                 # The Needle
            noise[insert_idx:],
            self.trigger_tokens,
            key
        ])

        labels = input_ids.clone()
        labels[:-1] = -100 
        
        return {"input_ids": input_ids, "labels": labels}



In [3]:
SEQ_LEN = 20 
VOCAB_SIZE = 1000

train_dataset = NIAHDataset(size=50000, min_len = 10, max_len = 30,  vocab_size=VOCAB_SIZE)

In [4]:
examples = train_dataset[0]
input_ids, labels = examples['input_ids'], examples['labels']

print(f"Shape of the input_ids: {input_ids.shape}")
print(f"Shape of the labels: {labels.shape}")
print("\n")
print(f"Input_ids: {input_ids}")
print(f"Labels: {labels}")

Shape of the input_ids: torch.Size([17])
Shape of the labels: torch.Size([17])


Input_ids: tensor([  3,   4, 987,   2, 246, 694, 712,  73, 654,  81, 297, 429,   5,   4,
          3,   6, 246])
Labels: tensor([-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
        -100, -100, -100, -100,  246])


In [5]:
# class NIAHDataset_Version2(Dataset):
#     def __init__(self, size = 5000, seq_len = 256, vocab_size = 1000):
#         """
#         Improved Synthetic Needle-In-A-Haystack Dataset.
        
#         Structure:
#         [Noise Part 1] ... [Needle: KEY, VALUE] ... [Noise Part 2] ... [Query: KEY, VALUE]
        
#         The model sees everything up to the final 'KEY' and must predict the final 'VALUE'.
#         """

#         self.size = size 
#         self.seq_len = seq_len 

#         # We reserve the last token in the vocab as the "Needle Key"
#         # The rest (0 to vocab_size - 2) are used for Noise and Values 
#         self.key_token = vocab_size - 1
#         self.noise_range = vocab_size - 1

#     def __len__(self):
#         return self.size

#     def __getitem__(self, idx):
#         # 1. Generate the Target Value (The Needle's content)
#         # We pick a random token from the noise range
#         target_token = torch.randint(0, self.noise_range, (1, ))

#         # 2. Define the Needle: [KEY, VALUE]
#         needle = torch.tensor([self.key_token, target_token.item()])

#         # 3. Calculate Haystack (Noise) Length
#         # Sequence = [Haystack1] + [Needle(2)] + [Haystack2] + [Query(2)]
#         # Total overhead = 2 (Needle) + 2 (Query Trigger + Target) = 4 tokens
#         haystack_len = self.seq_len - 4

#         # 4. Generate Haystack
#         haystack = torch.randint(0, self.noise_range, (haystack_len, ))

#         # 5. Insert Needle at Random Depth (The "Magic")
#         # Random split point for the haystack
#         insert_idx = torch.randint(0, haystack_len + 1, (1,)).item()

#         context_part1 = haystack[: insert_idx]
#         context_part2 = haystack[insert_idx:]

#         # 6. Construct the Full Input Sequence (Ground Truth)
#         # We construct the VALID sequence ending in the target.
#         # [Part1] [KEY, VAL] [Part2] [KEY, VAL]
#         input_ids = torch.cat([
#             context_part1,
#             needle,         # The inserted needle
#             context_part2,
#             needle          # The query (Trigger + Ground Truth Target)
#         ])

#         # 7. Create Labels (Masked Loss)
#         # We want the model to only learn from the FINAL prediction.
#         labels = input_ids.clone()
        
#         # Mask everything with -100 (Ignored by CrossEntropyLoss)
#         labels[:] = -100
        
#         # Unmask ONLY the last token (The Target Value)
#         # Note: In HF Causal training, label[i] is the target for input[i-1].
#         # So providing the full sequence as labels works perfectly; 
#         # the model tries to predict the last token from the second-to-last.
#         labels[-1] = target_token.item()
        
#         return {
#             "input_ids": input_ids, 
#             "labels": labels
#         }

In [6]:
def pad_collate_fn(batch):
    # Extract inputs and labels
    inputs = [item['input_ids'] for item in batch]
    labels = [item['labels'] for item in batch]
    
    # Pad sequences to the longest in the batch (dynamic padding)
    # batch_first=True makes shape [Batch, Seq_Len]
    inputs_padded = rnn_utils.pad_sequence(inputs, batch_first=True, padding_value=0)
    labels_padded = rnn_utils.pad_sequence(labels, batch_first=True, padding_value=-100)
    
    return {
        "input_ids": inputs_padded,
        "labels": labels_padded
    }

In [7]:
def calculate_accuracy(logits, labels):
    """
    Calculates accuracy for the final token prediction.
    
    Args:
        logits: [Batch, Seq_Len, Vocab] - The raw output from the model.
        labels: [Batch, Seq_Len] - The input_ids (or targets).
    
    Logic:
    The model predicts token[t+1] using the state at token[t].
    We want to check if the model correctly predicted the LAST token (Target).
    The state responsible for predicting the LAST token is the SECOND TO LAST token (Trigger).
    
    Index -1: The actual Target token (Ground Truth).
    Index -2: The Trigger token (Model output here predicts the Target).
    """
    # Check prediction at position -2 (The token BEFORE the target)
    key_logit = logits[..., -2, :] 
    pred_token = torch.argmax(key_logit, dim=-1)
    
    # Check against the actual Target (which is at -1 in the input/labels)
    target_token = labels[..., -1]
    
    return (pred_token == target_token).float().mean().item()

In [8]:
def train_eval_pipeline(model_name, model, train_loader, eval_seq_lens, vocab_size, device):
    """
    Args:
        model_name (str): Name for logging.
        model (nn.Module): The HF-style model (returns outputs.loss).
        train_loader (DataLoader): Loader for the fixed-length training data.
        eval_seq_lens (list): List of lengths to test generalization on (e.g. [256, 512, 1024]).
        vocab_size (int): Vocab size for generating eval data on the fly.
        device (str): 'cuda' or 'cpu'.
    """
    print(f"\n[{model_name}] Starting Training...")
    optimizer = optim.AdamW(model.parameters(), lr=5e-4, weight_decay=0.01)
    model.train()

    # --- TRAINING PHASE ---
    # We use the train_loader which usually has a fixed sequence length (e.g. 256)
    # This replaces the "variable length loop" from the old code, 
    # relying on the model to learn the mechanism from the fixed length dataset.
    
    progress_bar = tqdm(train_loader, desc=f"Training {model_name}", leave=True)
    
    for step, batch in enumerate(progress_bar):
        # Move batch to device
        input_ids = batch["input_ids"].to(device)
        labels = batch["labels"].to(device)

        # Forward pass (HF models compute loss automatically if labels are passed)
        outputs = model(input_ids=input_ids, labels=labels)
        loss = outputs.loss

        # Backward
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        # Update progress bar
        progress_bar.set_postfix({"loss": f"{loss.item():.4f}"})

    # --- EVALUATION PHASE (Length Generalization) ---
    print(f"\n[{model_name}] Evaluating Length Generalization...")
    results = []
    model.eval()
    
    for length in eval_seq_lens:
        # Generate a small temporary dataset for this specific length
        # This mirrors the 'generator_fn' logic from the old code
        # eval_ds = NIAHDataset(size=100, seq_len=length, vocab_size=vocab_size)
        eval_ds = NIAHDataset(size=100,
                              min_len = length, 
                              max_len = length,  
                              vocab_size=vocab_size)
        eval_loader = DataLoader(eval_ds, batch_size=1, shuffle=False)
        
        accs = []
        with torch.no_grad():
            for batch in eval_loader:
                input_ids = batch["input_ids"].to(device)

                output = model(input_ids = input_ids)
                acc = calculate_accuracy(output.logits, input_ids)
                accs.append(acc)

        avg_acc = np.mean(accs)
        print(f"  Len {length}: {avg_acc:.2%}")
        results.append(avg_acc)
        
    return results


In [9]:
device = "cuda" if torch.cuda.is_available() else "cpu"
VOCAB_SIZE = 1000
MIN_SEQ_LEN = 32
MAX_SEQ_LEN = 128
BATCH_SIZE = 32

In [15]:
train_dataset = NIAHDataset(size=50000, min_len = 10, max_len = 30,  vocab_size=VOCAB_SIZE)

# 2. Create the DataLoader
train_loader = DataLoader(
    train_dataset, 
    batch_size=4, 
    shuffle=True, 
    num_workers=2,# Can use parallel loading now
    collate_fn=pad_collate_fn
)

iter_loader = next(iter(train_loader))
inputs_ids, labels = iter_loader['input_ids'], iter_loader['labels']

print(inputs_ids)
print(labels)

tensor([[  3,   4, 207, 941, 101, 193,  45, 283, 841, 749, 391,   2, 842, 429,
         567, 917,   5,   4,   3,   6, 842,   0,   0,   0],
        [  3,   4, 992, 742, 849, 868, 702, 898, 282, 363, 126,   2, 440, 705,
         159, 821, 706, 151, 818,   5,   4,   3,   6, 440],
        [  3,   4, 908,   2, 249,  14, 328,  61,  60, 987, 590,  25, 773, 952,
         483, 951, 646, 119,   5,   4,   3,   6, 249,   0],
        [  3,   4,  75,  87,   2, 920, 790, 693,   5,   4,   3,   6, 920,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0]])
tensor([[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100, -100, -100,  842, -100, -100, -100],
        [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,  440],
        [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -10

In [11]:
train_dataset = NIAHDataset(size=50000, 
                            min_len = MIN_SEQ_LEN, 
                            max_len = MAX_SEQ_LEN,  
                            vocab_size=VOCAB_SIZE)

# 2. Create the DataLoader
train_loader = DataLoader(
    train_dataset, 
    batch_size=BATCH_SIZE, 
    shuffle=True, 
    num_workers=2,# Can use parallel loading now
    collate_fn=pad_collate_fn
)


holo_config =  HoloConfig(
    d_model=128, 
    num_hidden_layers = 2,
    num_heads = 4,
    vocab_size=VOCAB_SIZE, 
    resid_dropout = 0.0, 
    dropout = 0.0,
    use_version=2
)
holo_model = HoloForCausalLM(holo_config).to(device)

config_kwargs = {
     "vocab_size": VOCAB_SIZE, 
     "ssm_cfg": {"dropout": 0.0 }
}
mamba_config = MambaConfig(
    hidden_size = 128,
    num_hidden_layers = 2, 
    **config_kwargs
)

mamba_model = MambaForCausalLM(mamba_config).to(device)

In [12]:
results = train_eval_pipeline(
        "Mamba", 
        mamba_model, 
        train_loader, 
        eval_seq_lens=[128, 256, 512, 1024, 2048, 4096, 8192], 
        vocab_size=VOCAB_SIZE, 
        device=device
    )


[Mamba] Starting Training...


Training Mamba:   0%|          | 0/1563 [00:00<?, ?it/s]


[Mamba] Evaluating Length Generalization...
  Len 128: 77.00%
  Len 256: 82.00%
  Len 512: 72.00%
  Len 1024: 63.00%
  Len 2048: 60.00%
  Len 4096: 47.00%
  Len 8192: 23.00%


In [13]:
results = train_eval_pipeline(
        "Holo", 
        holo_model, 
        train_loader, 
        eval_seq_lens=[128, 256, 512, 1024, 2048, 4096, 8192],
        vocab_size=VOCAB_SIZE, 
        device=device
    )


[Holo] Starting Training...


Training Holo:   0%|          | 0/1563 [00:00<?, ?it/s]


[Holo] Evaluating Length Generalization...
  Len 128: 0.00%
  Len 256: 0.00%
  Len 512: 0.00%
  Len 1024: 0.00%
  Len 2048: 0.00%
  Len 4096: 0.00%
  Len 8192: 0.00%


### Play around

In [14]:
def generate_niah(batch_size, seq_len, vocab_size):
    key = torch.randint(10, vocab_size, (batch_size,))
    # Ensure noise doesn't contain the "trigger" tokens (3, 4, 5, 6)
    noise = torch.randint(10, vocab_size, (batch_size, seq_len-10))
    start = torch.tensor([3, 4]).expand(batch_size, -1)
    end = torch.tensor([5, 4, 3, 6]).expand(batch_size, -1)

    inputs = []
    for b in range(batch_size):
        seq = torch.cat([start[b], key[b].unsqueeze(0), noise[b], end[b]])
        inputs.append(seq)

    return torch.stack(inputs).to(device), key.to(device)


input_ids, labels = generate_niah(1, 256, 100)
print(f"Shape of the input: {input_ids.shape}")
print(f"Shape of the labels: {labels.shape}")

Shape of the input: torch.Size([1, 253])
Shape of the labels: torch.Size([1])
