In [8]:
import sys

sys.path.append("..")

In [9]:
import torch
import torch.nn as nn 
from torch.utils.data import Dataset, DataLoader
from model.holo import HoloConfig, HoloForCausalLM
from model.long import LongConfig, LongForCausalLM
from model.long_new import LongConfig, LongHFModel
from transformers import MambaConfig, MambaForCausalLM
from transformers import GPT2Config, GPT2LMHeadModel

import torch.nn.functional as F
import torch.optim as optim
from tqdm.notebook import tqdm 
import numpy as np
import torch.nn.utils.rnn as rnn_utils
import matplotlib.pyplot as plt

In [10]:
class NeedleHaystackDataset(Dataset):
    """Synthetic needle-in-a-haystack retrieval dataset.

    Every sample is laid out as::

        [noise A] [prompt] [key] [noise B] [prompt] [key]

    The first ``[prompt] [key]`` pair is the needle hidden in random noise; the
    trailing ``[prompt]`` is the query and the final ``[key]`` is the only
    position that carries a label (all other labels are -100).

    Samples are generated on the fly from the global NumPy and torch random
    generators, so the same index returns a different sample on each access.

    Args:
        size: Number of samples reported by ``len()``.
        min_len: Smallest total sequence length to draw (inclusive).
        max_len: Largest total sequence length to draw (inclusive).
        vocab_size: Vocabulary size. Token 0 is left for padding and token
            ``vocab_size - 1`` is the prompt marker, so at least 3 are needed.
        depth: Float between 0.0 and 1.0, or None.
            If None, the needle position is drawn uniformly for every sample.
            If set (e.g. 0.5), the needle is always placed at that fraction
            of the noise.

    Raises:
        ValueError: If ``min_len > max_len``, ``vocab_size < 3`` or ``depth``
            lies outside [0.0, 1.0].
    """

    # Non-noise tokens in every sample: two prompt markers and two key copies.
    _OVERHEAD = 4

    def __init__(self, size=2000, min_len=32, max_len=64, vocab_size=128, depth=None):
        # Fail at construction time instead of at the first __getitem__ call.
        if min_len > max_len:
            raise ValueError(f"min_len ({min_len}) must not exceed max_len ({max_len})")
        if vocab_size < 3:
            raise ValueError(
                f"vocab_size must be at least 3 (padding, one key, prompt), got {vocab_size}"
            )
        if depth is not None and not 0.0 <= depth <= 1.0:
            raise ValueError(f"depth must be None or within [0.0, 1.0], got {depth}")

        self.size = size
        self.min_len = min_len
        self.max_len = max_len
        self.vocab_size = vocab_size
        self.depth = depth

        # The last vocabulary id is reserved as the "prompt" trigger token.
        self.prompt_token = torch.tensor([vocab_size - 1])

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        # idx is intentionally unused: each access draws a fresh random sample.

        # 1. Total sequence length, inclusive of both bounds.
        curr_len = np.random.randint(self.min_len, self.max_len + 1)

        # 2. Key (needle), drawn from [1, vocab-2] so it never collides with
        #    padding (0) or the prompt token (vocab-1).
        key = torch.randint(1, self.vocab_size - 1, (1,))

        # 3. Noise (haystack). Requested lengths below the 4-token overhead
        #    yield no noise, so a sample is never shorter than 4 tokens.
        noise_len = max(0, curr_len - self._OVERHEAD)
        noise = torch.randint(1, self.vocab_size - 1, (noise_len,))

        # 4. Needle position inside the noise.
        if self.depth is not None:
            insert_idx = int(noise_len * self.depth)
        else:
            # Uniform over every gap, including before and after all noise.
            insert_idx = torch.randint(0, noise_len + 1, (1,)).item()

        # 5. Assemble: noise A -> prompt -> key -> noise B -> prompt -> key (target).
        input_ids = torch.cat([
            noise[:insert_idx],
            self.prompt_token, key,
            noise[insert_idx:],
            self.prompt_token, key,
        ])

        # 6. Only the final key contributes to the loss.
        labels = input_ids.clone()
        labels[:-1] = -100

        return {"input_ids": input_ids, "labels": labels}

In [11]:
# Small demo configuration for inspecting one generated sample.
# NOTE(review): SEQ_LEN is not referenced anywhere else in the visible cells.
SEQ_LEN = 20
VOCAB_SIZE = 1000

train_dataset = NeedleHaystackDataset(
    size=50000,
    min_len=10,
    max_len=30,
    vocab_size=VOCAB_SIZE,
)

In [12]:
# Draw one sample and show the shapes and contents of its two tensors.
examples = train_dataset[0]
input_ids = examples['input_ids']
labels = examples['labels']

for report_line in (
    f"Shape of the input_ids: {input_ids.shape}",
    f"Shape of the labels: {labels.shape}",
    "\n",
    f"Input_ids: {input_ids}",
    f"Labels: {labels}",
):
    print(report_line)

Shape of the input_ids: torch.Size([21])
Shape of the labels: torch.Size([21])


Input_ids: tensor([798, 618, 524,   2, 479,  52, 854, 920, 999, 268, 259, 783, 286, 141,
        936, 641, 388,  44, 337, 999, 268])
Labels: tensor([-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
        -100, -100, -100, -100, -100, -100, -100, -100,  268])


In [13]:
def pad_collate_fn(batch):
    """Collate variable-length samples into right-padded, batch-first tensors.

    Args:
        batch: List of dicts, each holding 1-D ``"input_ids"`` and ``"labels"`` tensors.

    Returns:
        Dict with the same two keys. ``"input_ids"`` is padded with 0 and
        ``"labels"`` with -100, each up to the longest sequence in the batch.
    """
    collated = {}
    # Each field is paired with the fill value used for its padding positions.
    for field, fill_value in (("input_ids", 0), ("labels", -100)):
        sequences = [sample[field] for sample in batch]
        collated[field] = rnn_utils.pad_sequence(
            sequences, batch_first=True, padding_value=fill_value
        )
    return collated


In [14]:
def calculate_accuracy(logits, labels, top_k=5):
    """Compute Top-1 and Top-K accuracy for the final-token prediction.

    Assumes NO padding: every sequence in the batch must end at the last
    index, and the sequence dimension must hold at least two positions.

    Args:
        logits: Model outputs indexed as [batch, position, vocab].
        labels: Token ids indexed as [batch, position]; only the last column
            is read as the target.
        top_k: Number of highest-scoring candidates counted as a hit. Values
            larger than the vocabulary dimension are clamped to it.

    Returns:
        Tuple ``(top1, topk)`` of Python floats: the fraction of sequences
        whose target is the arg-max, and the fraction whose target is among
        the ``top_k`` highest logits.
    """
    # Target: the very last token of each sequence. Shape: [batch]
    targets = labels[:, -1]

    # A causal LM predicts token T-1 from position T-2 (0-indexed), so read
    # the scores at the second-to-last position. Shape: [batch, vocab]
    predictions = logits[:, -2, :]

    # Top-1: the highest-scoring token equals the target.
    top1 = (predictions.argmax(dim=-1) == targets).float().mean().item()

    # Top-K: the target is among the k best tokens. Clamp k because topk()
    # raises when asked for more entries than the vocabulary holds.
    k = min(top_k, predictions.size(-1))
    _, top_indices = predictions.topk(k, dim=-1)  # [batch, k]
    topk = (top_indices == targets.unsqueeze(1)).any(dim=1).float().mean().item()

    return top1, topk

In [15]:
def train_eval_pipeline(model_name, model, train_loader, eval_seq_lens, vocab_size, device,
                        num_epochs=10, lr=0.003, weight_decay=0.01, eval_size=100):
    """Train a causal LM on the needle task, then measure length generalization.

    Training runs ``num_epochs`` passes over ``train_loader`` with AdamW and a
    shifted next-token cross-entropy loss. Label positions set to -100 are
    skipped by the loss (the default ``ignore_index`` of ``F.cross_entropy``).

    Evaluation builds a fresh ``NeedleHaystackDataset`` of ``eval_size``
    fixed-length samples for every entry of ``eval_seq_lens`` and scores the
    final-token prediction one sample at a time with ``calculate_accuracy``.

    Args:
        model_name: Label used in progress bars and log lines.
        model: Module called as ``model(input_ids=...)`` whose output exposes ``.logits``.
        train_loader: Iterable of batches with ``"input_ids"`` and ``"labels"`` tensors.
        eval_seq_lens: Sequence lengths to evaluate.
        vocab_size: Vocabulary size used to generate the evaluation data.
        device: Device the batches are moved to; the model must already be on it.
        num_epochs: Number of training epochs.
        lr: AdamW learning rate.
        weight_decay: AdamW weight decay.
        eval_size: Number of evaluation samples per sequence length.

    Returns:
        Dict with parallel lists ``"lengths"``, ``"acc"`` (Top-1) and ``"top5"``.
        A length that runs out of memory records a 0.0 for the failing sample
        and stops, so its average covers only the samples seen so far.
    """
    print(f"\n[{model_name}] Starting Training...")
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    model.train()

    # --- TRAINING PHASE ---
    for epoch in range(num_epochs):
        epoch_loss = []
        # One progress bar per epoch.
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [{model_name}]", leave=True)

        for batch in progress_bar:
            input_ids = batch["input_ids"].to(device)
            labels = batch["labels"].to(device)

            # Forward pass
            outputs = model(input_ids=input_ids)
            logits = outputs.logits

            # Shift for causal LM loss: position t predicts token t+1.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()

            # Flatten with the logits' own last dimension so a model whose
            # output width differs from vocab_size cannot be mis-reshaped.
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1)
            )

            # Backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Update progress bar
            loss_value = loss.item()
            epoch_loss.append(loss_value)
            progress_bar.set_postfix({"loss": f"{loss_value:.4f}"})

        print(f"Epoch {epoch+1} Average Loss: {np.mean(epoch_loss):.4f}")

    # --- EVALUATION PHASE (Length Generalization) ---
    # One pass per length, testing the model's final state.
    print(f"\n[{model_name}] Evaluating Length Generalization...")
    results = {"lengths": [], "acc": [], "top5": []}

    model.eval()
    for length in eval_seq_lens:
        eval_ds = NeedleHaystackDataset(size=eval_size, min_len=length, max_len=length, vocab_size=vocab_size)
        # batch_size=1 keeps every sequence unpadded, as calculate_accuracy requires.
        eval_loader = DataLoader(eval_ds, batch_size=1, shuffle=False)

        batch_accs, batch_top5s = [], []
        inner_pbar = tqdm(eval_loader, desc=f"Eval L={length}", leave=False)

        with torch.no_grad():
            for batch in inner_pbar:
                input_ids = batch["input_ids"].to(device)
                try:
                    output = model(input_ids=input_ids)
                    # The last input token is the target key, so the inputs double as labels.
                    t1, t5 = calculate_accuracy(output.logits, input_ids)
                    batch_accs.append(t1)
                    batch_top5s.append(t5)
                    del output, input_ids
                except RuntimeError as e:
                    if "out of memory" in str(e).lower():
                        print(f"OOM at length {length}!")
                        torch.cuda.empty_cache()
                        batch_accs.append(0.0)
                        batch_top5s.append(0.0)
                        break
                    # Anything other than OOM is a real error: re-raise with its traceback.
                    raise

        avg_acc = np.mean(batch_accs) if batch_accs else 0.0
        avg_top5 = np.mean(batch_top5s) if batch_top5s else 0.0

        print(f"  Length {length}: Top-1={avg_acc:.1%} | Top-5={avg_top5:.1%}")
        results["lengths"].append(length)
        results["acc"].append(avg_acc)
        results["top5"].append(avg_top5)
        torch.cuda.empty_cache()

    return results

In [16]:
# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Experiment-wide hyperparameters.
VOCAB_SIZE = 128
MIN_SEQ_LEN = 32
MAX_SEQ_LEN = 128
BATCH_SIZE = 32

# Samples are drawn on the fly from the global NumPy and torch generators, and
# model initialisation and shuffling use torch's generator too, so seed both
# to make a Restart & Run All reproducible.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

In [17]:
# 1. Training data: random-depth needles, lengths in [MIN_SEQ_LEN, MAX_SEQ_LEN].
train_dataset = NeedleHaystackDataset(
    size=10000,
    min_len=MIN_SEQ_LEN,
    max_len=MAX_SEQ_LEN,
    vocab_size=VOCAB_SIZE,
)

# 2. Batches are padded to their longest sequence by pad_collate_fn.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=pad_collate_fn)

# 3. Mamba baseline.
# NOTE(review): confirm that MambaConfig actually consumes `ssm_cfg`.
config_kwargs = dict(vocab_size=VOCAB_SIZE, ssm_cfg=dict(dropout=0.0))
mamba_config = MambaConfig(hidden_size=64, num_hidden_layers=2, **config_kwargs)
mamba_model = MambaForCausalLM(mamba_config).to(device)

In [18]:
# Build the Long-Attention model, then train it and evaluate at lengths 2^8 .. 2^17.
long_config = LongConfig(
    vocab_size=VOCAB_SIZE,
    hidden_size=64,
    expansion_ratio=4,
    num_hidden_layers=2,
    num_heads=4,
)
long_model = LongHFModel(long_config).to(device)

results_long = train_eval_pipeline(
    "Long-Attention",
    long_model,
    train_loader,
    eval_seq_lens=[2 ** exponent for exponent in range(8, 18)],
    vocab_size=VOCAB_SIZE,
    device=device,
)

LongHFModel has generative capabilities, as `prepare_inputs_for_generation` is explicitly defined. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.
  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).
  - If you are not the owner of the model architecture class, please contact the model code owner to update it.



[Long-Attention] Starting Training...


Epoch 1/10 [Long-Attention]:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 1 Average Loss: 4.9013


Epoch 2/10 [Long-Attention]:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 2 Average Loss: 3.7384


Epoch 3/10 [Long-Attention]:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 3 Average Loss: 0.0216


Epoch 4/10 [Long-Attention]:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 4 Average Loss: 0.0038


Epoch 5/10 [Long-Attention]:   0%|          | 0/313 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# Build the Holo-Attention model (dropout disabled), then train it and
# evaluate at lengths 2^8 .. 2^17.
holo_config = HoloConfig(
    d_model=128,
    num_hidden_layers=2,
    num_heads=4,
    vocab_size=VOCAB_SIZE,
    resid_dropout=0.0,
    dropout=0.0,
    use_version=2,
)
holo_model = HoloForCausalLM(holo_config).to(device)

results_holo = train_eval_pipeline(
    "Holo-Attention",
    holo_model,
    train_loader,
    eval_seq_lens=[2 ** exponent for exponent in range(8, 18)],
    vocab_size=VOCAB_SIZE,
    device=device,
)

In [None]:
# Train the Mamba baseline and evaluate at lengths 2^8 .. 2^17.
results_mamba = train_eval_pipeline(
    "Mamba",
    mamba_model,
    train_loader,
    eval_seq_lens=[2 ** exponent for exponent in range(8, 18)],
    vocab_size=VOCAB_SIZE,
    device=device,
)

In [None]:
# Plot Top-1 accuracy against evaluated sequence length for each model.
# train_eval_pipeline returns a dict of parallel lists, so the x/y data are
# read from results["lengths"] and results["acc"]; the tick positions come
# from the lengths that were actually evaluated rather than a rebuilt range.
eval_seq_lens = sorted(set(results_mamba["lengths"]) | set(results_holo["lengths"]))

plt.figure(figsize=(10, 6))

# One curve per model: (legend label, results dict, marker).
for curve_label, curve_results, curve_marker in (
    ("Mamba", results_mamba, "o"),
    ("Holo", results_holo, "s"),
):
    plt.plot(
        curve_results["lengths"],
        curve_results["acc"],
        marker=curve_marker,
        linestyle='-',
        linewidth=2,
        label=curve_label,
    )

# Logarithmic x-axis (essential for exponentially spaced lengths)
plt.xscale('log')
plt.xticks(eval_seq_lens, labels=eval_seq_lens)  # One tick per evaluated length

# Labels and Title
plt.xlabel('Sequence Length', fontsize=12)
plt.ylabel('Accuracy', fontsize=12)
plt.title('Needle In A Haystack: Length Generalization', fontsize=14)
plt.ylim(-0.05, 1.05)  # Keep y-axis range clean
plt.grid(True, which="both", ls="-", alpha=0.2)
plt.legend(fontsize=12)

# Save the plot
plt.tight_layout()
plt.savefig('niah_results_plot.png')

### Play around