In [8]:
import sys
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import gc
import random
from tqdm.notebook import tqdm
from torch.utils.data import DataLoader, Dataset

# Ensure the parent directory is in the path to find 'model/' folder
sys.path.append("..")

# --- Imports from your structure ---
try:
    from model.long import LongConfig, LongForCausalLM
    from transformers import MambaConfig, MambaForCausalLM
    from transformers import GPT2Config, GPT2LMHeadModel
except ImportError:
    print("Warning: Custom model modules (Holo/Long) not found. Ensure 'model' folder is in path.")

# Check for GPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Running on device: {device}")

Running on device: cuda


In [9]:
class ListOpsGenerator:
    def __init__(self, max_depth=5):
        self.max_depth = max_depth
        # <PAD> is at index 0
        self.tokens = ["<PAD>", "[", "]", "MIN", "MAX", "MED", "SM"] + [str(i) for i in range(10)]
        self.vocab = {t: i for i, t in enumerate(self.tokens)}
        self.rev_vocab = {i: t for t, i in self.vocab.items()}
        
        # --- CRITICAL FIXES ---
        self.PAD_TOKEN_ID = 0      # Use this for INPUTS (Embedding layer)
        self.IGNORE_INDEX = -100   # Use this for TARGETS (Loss function) - Not used for single answer, but good practice

    def generate_tree(self, current_depth):
        # Base case: depth 0 or small random chance to stop early
        if current_depth == 0 or random.random() < 0.1:
            return str(random.randint(0, 9))
        
        op = random.choice(["MIN", "MAX", "MED", "SM"])
        # Reduce max children slightly to keep length manageable
        num_children = random.randint(2, 4) 
        children = [self.generate_tree(current_depth - 1) for _ in range(num_children)]
        return f"[{op} " + " ".join(children) + "]"

    def solve(self, sequence):
        # FIX: Add spaces around brackets so they split into separate tokens!
        # [MIN 8 4] -> ( MIN 8 4 ) -> ['(', 'MIN', '8', '4', ')']
        tokens = sequence.replace("[", " ( ").replace("]", " ) ").split()
        
        def parse(toks):
            token = toks.pop(0)
            if token == "(":
                op = toks.pop(0)
                vals = []
                while toks[0] != ")":
                    vals.append(parse(toks))
                toks.pop(0) # Remove )
                
                if op == "MIN": return min(vals)
                if op == "MAX": return max(vals)
                if op == "MED": return int(np.median(vals))
                if op == "SM": return sum(vals) % 10
            else:
                return int(token)
        try:
            return parse(tokens.copy())
        except Exception as e:
            # Print the actual error message for debugging next time
            print(f"Solver failed on: {sequence} | Error: {e}")
            return None

    def generate_sample(self, target_length):
        while True:
            depth = random.randint(2, self.max_depth)
            seq_str = self.generate_tree(depth)
            token_strs = seq_str.replace("[", " [ ").replace("]", " ] ").split()
            tokens = [self.vocab[t] for t in token_strs]
            
            if len(tokens) <= target_length:
                break
        
        answer = self.solve(seq_str)
        if answer is None:
            return self.generate_sample(target_length)

        # Padding for Input
        padding_needed = target_length - len(tokens)
        input_ids = tokens + [self.PAD_TOKEN_ID] * padding_needed
        
        target_ids = [self.IGNORE_INDEX] * (len(tokens) - 1)  # Ignore intermediate tokens
        target_ids.append(self.vocab[str(answer)])           # Target is the answer at the end of the sequence
        target_ids += [self.IGNORE_INDEX] * padding_needed   # Ignore padding
        
        return torch.tensor(input_ids, dtype=torch.long), torch.tensor(target_ids, dtype=torch.long)
        
class ListOpsStaticDataset(Dataset):
    def __init__(self, generator, num_samples, length):
        self.generator = generator
        self.length = length
        self.samples = []
        
        # PRE-GENERATE data so it stays fixed
        print(f"    ...Pre-generating {num_samples} samples of length {length}...")
        for _ in range(num_samples):
            self.samples.append(self.generator.generate_sample(self.length))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        # Return the saved sample (Input, Target)
        return self.samples[idx]

In [10]:
def get_model(model_name, vocab_size, max_seq_len, 
              hidden_dim=128, num_layers=6, num_heads=4, device='cuda'):
    """
    Initializes models with flexible configuration.
    Default (ListOps Baseline): DIM=128, LAYERS=6, HEADS=4
    """
    
    # Ensure GPT-2/Transformer position embeddings fit the data
    safe_max_pos = max(max_seq_len, 4096) 

    if model_name == "Long-LLM":
        config = LongConfig(
            vocab_size = vocab_size, 
            hidden_size = hidden_dim, 
            num_hidden_layers = num_layers, 
            num_heads = num_heads,
            max_position_embeddings = safe_max_pos,
            expansion_ratio = 8/3, 
            hybrid_ratio = 0,
            gate_init_bias = 0.0,
        )
        model = LongForCausalLM(config)

    elif model_name == "Mamba":
        config = MambaConfig(
            vocab_size = vocab_size,
            hidden_size = hidden_dim,
            num_hidden_layers = num_layers,
            ssm_cfg = {"dropout": 0.0},
            
            # Mamba state_size mặc định là 16, với ListOps có thể giữ nguyên
        )
        model = MambaForCausalLM(config)

    elif model_name == "GPT-2":
        config = GPT2Config(
            vocab_size = vocab_size, 
            n_positions = safe_max_pos, 
            n_embd = hidden_dim, 
            n_layer = num_layers, 
            n_head = num_heads,
            resid_pdrop = 0.1, # Nên để dropout nhẹ (0.1) khi train model lớn hơn
            embd_pdrop = 0.1, 
            attn_pdrop = 0.1, 
            use_cache = False
        )
        model = GPT2LMHeadModel(config)
    else:
        raise ValueError(f"Unknown model: {model_name}")
        
    return model.to(device)

In [11]:
def train_model(model_name, model, train_loader, epochs=20, lr=5e-4, max_grad_norm=1.0): # Lowered LR
    model.train()
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss(ignore_index=-100)
    
    vocab = train_loader.dataset.generator.vocab 

    for epoch in range(epochs):
        total_loss = 0
        loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")
        
        for batch_idx, (inputs, targets) in enumerate(loop):
            inputs, targets = inputs.to(device), targets.to(device)
            # print(targets.shape)
            # print(targets[0, -10:])
            
            batch_size, seq_len = inputs.shape
            # --- FIX: Create Attention Mask ---
            # 1 for valid tokens, 0 for PAD (index 0)
            attention_mask = (inputs != 0).to(device)
            # 2. Reshape to [batch_size, 1, 1, seq_len] 
            # This allows it to be broadcasted across heads and the 'query' dimension

            if model_name != 'Mamba':
                attention_mask = attention_mask.view(batch_size, 1, 1, seq_len)
        
                # 3. (Optional) Expand if your model specifically requires this exact shape
                # Otherwise, PyTorch broadcasting handles this automatically in the model
                attention_mask = attention_mask.expand(batch_size, 4, seq_len, seq_len)
                # print(attention_mask)
                
            optimizer.zero_grad()
            
            # Pass the mask to the model
            outputs = model(inputs, attention_mask=attention_mask)
            logits = outputs.logits # [B, Seq, Vocab]
            
            # (Rest of your loss calculation code remains the same...)
            # target_token_ids = []
            # for t in targets:
            #     tid = vocab[str(t.item())]
            #     target_token_ids.append(tid)
            # target_token_ids = torch.tensor(target_token_ids).to(device)
            
            # last_logits = logits[:, -1, :] 
            loss = criterion(logits.view(-1, logits.size(-1)), targets.view(-1))
            # loss = criterion(last_logits, target_token_ids)
            
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()
            
            total_loss += loss.item()
            loop.set_postfix(loss=f"{total_loss/(batch_idx+1):.4f}")

            
def evaluate_model(model_name, model, test_loader, generator):
    model.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            
            # 1. Re-create the same mask used in training
            batch_size, seq_len = inputs.shape
            attention_mask = (inputs != 0).to(device)

            if model_name != "Mamba":
                attention_mask = attention_mask.view(batch_size, 1, 1, seq_len)
                attention_mask = attention_mask.expand(batch_size, 4, seq_len, seq_len)
            
            # 2. Forward pass
            outputs = model(inputs, attention_mask=attention_mask)
            logits = outputs.logits
            
            # 3. Get predictions (B, Seq_Len)
            preds = torch.argmax(logits, dim=-1)
            
            # 4. Mask out the -100 positions in targets to calculate accuracy
            # This ensures we only compare valid tokens
            valid_mask = (targets != -100)
            
            # Compare only valid positions
            correct_preds = (preds == targets) & valid_mask
            
            correct += correct_preds.sum().item()
            total += valid_mask.sum().item()
    
    return correct / total if total > 0 else 0

In [12]:
def print_model_stats(model):
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    
    # print(f"Model Structure:\n{model}\n") # Optional: prints layers
    print("-" * 30)
    print(f"Total Parameters:     {total_params:,}")
    print(f"Trainable Parameters: {trainable_params:,}")
    print("-" * 30)

vocab_size = 1000
MODELS_TO_TEST = ["Mamba", "Long-LLM", "GPT-2"] 
for model_name in MODELS_TO_TEST:
        # Clean Memory
        torch.cuda.empty_cache()
        gc.collect()
        
        # 2. Initialize Model
        if (model_name == "Long-LLM"):
            model = get_model(model_name, vocab_size, num_layers=6, max_seq_len = 512)
        elif (model_name == "GPT-2"):
            model = get_model(model_name, vocab_size, num_layers=6, max_seq_len = 512)
        else:
            model = get_model(model_name, vocab_size, num_layers=12, max_seq_len = 512)
        print(f"Model name: {model_name}")
        print_model_stats(model)

The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)` is None. Falling back to the sequential implementation of Mamba, as use_mambapy is set to False. To install follow https://github.com/state-spaces/mamba/#installation for mamba-ssm and install the kernels library using `pip install kernels` or https://github.com/Dao-AILab/causal-conv1d for causal-conv1d. For the mamba.py backend, follow https://github.com/alxndrTL/mamba.py.


Model name: Mamba
------------------------------
Total Parameters:     1,527,424
Trainable Parameters: 1,527,424
------------------------------
Model name: Long-LLM
------------------------------
Total Parameters:     2,039,576
Trainable Parameters: 2,039,576
------------------------------
Model name: GPT-2
------------------------------
Total Parameters:     1,842,176
Trainable Parameters: 1,842,176
------------------------------


In [None]:
# 1. PARAMETERS
vocab_size = 1000
# INCREASE THIS! 200 is too small.
TRAIN_SAMPLES = 10000 
TEST_SAMPLES = 1000
CONTEXT_LENGTHS = [512, 1024, 2048] # Example lengths

results = {}

for seq_len in CONTEXT_LENGTHS:
    print(f"\n=== Benchmarking Sequence Length: {seq_len} ===")
    
    # 2. GENERATE DATA ONCE PER LENGTH (Fairness)
    generator = ListOpsGenerator(max_depth=5) # Adjust depth based on length if needed
    
    train_dataset = ListOpsStaticDataset(generator, TRAIN_SAMPLES, seq_len)
    test_dataset = ListOpsStaticDataset(generator, TEST_SAMPLES, seq_len)
    
    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)
    
    # 3. TEST MODELS ON SAME DATA
    for model_name in ["Long-LLM", "Mamba", "GPT-2"]:
        print(f"  > Training {model_name}...")
        
        # 2. Initialize Model
        if (model_name == "Long-LLM"):
            model = get_model(model_name, vocab_size, num_layers=6, max_seq_len = 512).to(device)
            print(model)
        elif (model_name == "GPT-2"):
            model = get_model(model_name, vocab_size, num_layers=6, max_seq_len = 512).to(device)
        else:
            model = get_model(model_name, vocab_size, num_layers=12, max_seq_len = 512).to(device)
        
        # Train
        train_model(model_name, model, train_loader, epochs=10, lr=5e-4) # 10 epochs is enough if data is 10k
        
        # Evaluate
        acc = evaluate_model(model_name, model, test_loader, generator)
        print(f"  >> {model_name} Accuracy: {acc:.2%}")
        
        # Save result
        if model_name not in results: results[model_name] = []
        results[model_name].append(acc)
        
        # Cleanup
        del model
        torch.cuda.empty_cache()
        gc.collect()


=== Benchmarking Sequence Length: 512 ===
    ...Pre-generating 10000 samples of length 512...
    ...Pre-generating 1000 samples of length 512...
  > Training Long-LLM...
LongForCausalLM(
  (long_model): LongModel(
    (wte): Embedding(1000, 128)
    (layers): ModuleList(
      (0-5): 6 x LongBlock(
        (attn): LongAttention(
          (q_proj): Linear(in_features=128, out_features=128, bias=False)
          (k_proj): Linear(in_features=128, out_features=128, bias=False)
          (v_proj): Linear(in_features=128, out_features=128, bias=False)
          (conv): Conv1d(128, 128, kernel_size=(4,), stride=(1,), padding=(3,), groups=128)
          (input_gate_proj): Linear(in_features=128, out_features=128, bias=True)
          (output_gate_proj): Linear(in_features=128, out_features=128, bias=True)
          (gamma_proj): Linear(in_features=128, out_features=4, bias=True)
          (o_proj): Linear(in_features=128, out_features=128, bias=False)
          (v_norm): LayerNorm((32,), e

Epoch 1/10:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 2/10:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 3/10:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 4/10:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 5/10:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 6/10:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 7/10:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 8/10:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 9/10:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 10/10:   0%|          | 0/625 [00:00<?, ?it/s]

In [None]:
# --- Plotting Results ---
plt.figure(figsize=(10, 6))
colors = ['b', 'g', 'r', 'c', 'm']
markers = ['o', 's', '^', 'D', 'v']

for idx, model_name in enumerate(results):
    # Ensure we only plot if we have results (in case of OOM stops)
    if len(results[model_name]) == len(CONTEXT_LENGTHS):
        plt.plot(CONTEXT_LENGTHS, results[model_name], 
                 label=model_name,
                 color=colors[idx % len(colors)],
                 marker=markers[idx % len(markers)],
                 linewidth=2,
                 markersize=8)

plt.xlabel("Sequence Length", fontsize=12, fontweight='bold')
plt.ylabel("Accuracy", fontsize=12, fontweight='bold')
plt.title(f"ListOps Benchmark (Fixed Data)", fontsize=14, pad=20)
plt.ylim(-0.05, 1.05)
plt.grid(True, alpha=0.3)
plt.legend()
plt.tight_layout()
plt.savefig("listops.jpg")
plt.show()