## Robust Filter Attention

### Testing trained language models

### Setup

In [None]:
import numpy as np
import math
import scipy

from matplotlib import pyplot as plt
import matplotlib.cm as cm
from matplotlib import patches
plt.rcParams['figure.figsize'] = [10, 10]
plt.rc('font', size=20)

import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.utils.data import Subset

import transformers
from datasets import load_dataset, DatasetDict
from transformers import AutoTokenizer
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2TokenizerFast
from transformers import DataCollatorForLanguageModeling

from pathlib import Path
import pandas as pd
import os
import csv
import glob
import argparse
import datetime
import time
from tqdm import tqdm # Loading bar
print('Done.')

In [None]:
from utils import complex_matmul, apply_interleaved_rope
from utils import count_parameters, get_layers, seed_everything
print('Done.')

In [None]:
from isotropic_rfa import get_safe_exp_tot, compute_covariance_matrix, compute_covariance_matrix_LHopital
from isotropic_rfa import compute_covariance_matrix_spectral_full, compute_covariance_matrix_residual_diffusion
from isotropic_rfa import compute_exp_kernel_isotropic, compute_residual_norm_isotropic
print('Done.')

In [None]:
from model import resolve_multihead_dims, autoregressive_sample
from model import init_complexlinear, init_complex_matrix, initialize_linear_layers
from model import init_rope, init_decay_per_head, init_linear_bias_slopes
from model import apply_weight_masks
from model import ComplexLinearLayer, ComplexLinearHermitianLayer, ComplextoRealLinearLayer
from model import ComplexRMSNorm
from model import MultiHeadAttentionLayer, MultiheadIsotropicRFA
from model import TransformerBlock, TransformerNetwork
from model import SelfAttentionBlock, RFA_Block
from model import RFATransformerBlock, RFATransformerNetwork
from model import LanguageModel
print('Done.')

In [None]:
from visualization import plot_trajectory, compute_state_matrix, plot_state_matrix, visualize_results
from visualization import visualize_results_attn, _get_visual_modules, visualize_rfa_lm
from visualization import plot_training_progress_lm
print('Done.')

In [None]:
from training import single_epoch_rfa_lm, single_epoch_standard_lm
from training import hook_fn
print('Done.')

In [None]:
# Command-line options, parsed from an empty argv so the cell runs inside a notebook.
parser = argparse.ArgumentParser('DA')
parser.add_argument('--gpu', type=int, default=0)  # GPU index (default: 0)
args = parser.parse_args(args=[])

# Use the chosen CUDA device when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    args.device = torch.device('cuda:' + str(args.gpu))
else:
    args.device = torch.device('cpu')
print(args.device)

seed_everything(seed=2025)  # Set random seed

### Plot validation losses

In [None]:
# Load the training histories of every final-epoch checkpoint.
#
# Each "*_epoch_15.pt" file found (recursively) below saved_models/main_models
# is loaded on the CPU. Its run ID is the second "__"-separated field of the
# checkpoint's parent folder name (the whole folder name if there is no "__").
# results_data maps run ID -> {
#     'history':      the checkpoint history's 'val_ppl' list ([] if absent),
#     'full_history': the checkpoint's complete 'history' dict,
#     'final_ppl':    last entry of 'val_ppl' (0 if absent or empty),
#     'state_dict':   the checkpoint's 'model_state_dict' ({} if absent),
# }
# Checkpoints without a non-empty 'history' dict are skipped.

try:
    root_path = os.path.dirname(os.path.abspath(__file__))
except NameError:
    # __file__ is not defined inside a notebook kernel; fall back to the CWD.
    root_path = os.getcwd()

# search_base = Path(root_path) / 'saved_models' / 'wikitext_103'
search_base = Path(root_path) / 'saved_models' / 'main_models'
results_data = {}

if search_base.is_dir():
    # Sorted so the load order (and hence which file wins on a duplicate ID) is reproducible.
    checkpoint_paths = sorted(search_base.rglob("*_epoch_15.pt"))
else:
    print(f"Warning: checkpoint directory not found: {search_base}")
    checkpoint_paths = []

for file_path in checkpoint_paths:
    folder_name = file_path.parent.name
    parts = folder_name.split("__")
    aid = parts[1] if len(parts) > 1 else folder_name

    print(f"Loading {aid} from: {file_path.name}")

    try:
        # NOTE: weights_only=False unpickles arbitrary Python objects; only use
        # it on checkpoints produced by your own training runs.
        checkpoint = torch.load(file_path, map_location='cpu', weights_only=False)

        full_history = checkpoint.get('history', {})

        if full_history:
            val_ppl = full_history.get('val_ppl', [])
            if len(val_ppl) == 0:
                print(f"Warning: {aid} has no validation perplexities; final_ppl set to 0.")
            if aid in results_data:
                print(f"Warning: duplicate run ID {aid}; replacing the previously loaded checkpoint.")

            results_data[aid] = {
                'history': val_ppl,
                'full_history': full_history,
                'final_ppl': val_ppl[-1] if len(val_ppl) > 0 else 0,
                'state_dict': checkpoint.get('model_state_dict', {}),
            }

        # Drop the checkpoint reference; only the entries stored above stay alive.
        del checkpoint
    except Exception as e:
        # Best-effort loading: report the broken checkpoint and continue with the rest.
        print(f"Could not load {aid}: {e}")

In [None]:
# --- Configuration ---
start_epoch = 5  # First epoch to plot (1-based); set > 1 to zoom into the settle phase
save_pdf = True

if not results_data:
    print("Error: No data loaded.")
else:
    plt.figure(figsize=(10, 6))

    # Sort IDs by their numeric suffix (M0, M1, M2, M7...). IDs whose suffix is
    # not purely numeric get key 99; sorted() is stable, so ties keep insertion order.
    sorted_keys = sorted(results_data.keys(), key=lambda x: int(x[1:]) if x[1:].isdigit() else 99)

    # Convert the 1-based start_epoch into a 0-based slice index (never negative).
    idx = max(0, start_epoch - 1)

    for aid in sorted_keys:
        y_full = results_data[aid]['history']
        y_sliced = y_full[idx:]
        x_axis = range(idx + 1, len(y_full) + 1)  # 1-based epoch numbers of y_sliced

        if len(y_sliced) > 0:
            plt.plot(x_axis, y_sliced,
                     label=f"{aid} (Final: {results_data[aid]['final_ppl']:.2f})",
                     linewidth=2, marker='o', markersize=4)

    # Use standard linear scale for the zoom to maximize visual separation
    plt.yscale('linear')
    plt.xlabel('Epoch')
    plt.ylabel('Validation Perplexity')
    # plt.title(f'RFA Convergence Detail (Epoch {start_epoch} onwards)')

    # Move legend to the side to keep the plot area clean
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)
    plt.grid(True, which="both", ls="-", alpha=0.15)

    plt.tight_layout()
    if save_pdf:
        # Distinct filename: the next cell saves its relabelled version of this
        # plot as 'rfa_ablation_zoom_ep{start_epoch}.pdf', which would otherwise
        # overwrite this figure.
        plt.savefig(f'rfa_ablation_zoom_by_id_ep{start_epoch}.pdf', bbox_inches='tight')
    plt.show()

In [None]:
if not results_data:
    print("Error: No data loaded.")
else:
    # Okabe-Ito colorblind-friendly palette; a run's colour is picked by its PPL rank.
    cb_colors = ['#0072B2', '#E69F00', '#CC79A7', '#56B4E9', '#F0E442', '#000000', '#D55E00', '#999999']

    plt.figure(figsize=(10, 6))

    # Rank runs by final validation perplexity, worst (highest) first.
    sorted_aids = sorted(results_data, key=lambda run_id: results_data[run_id]['final_ppl'], reverse=True)

    # Display names, matched to the ranking above purely by position.
    manual_names = ['ALiBi (B2)', 'RoPE (B1)', 'RFA (M1)', 'SC-RFA (M2)']

    first_idx = max(0, start_epoch - 1)  # 1-based start epoch -> 0-based slice index

    for rank, run_id in enumerate(sorted_aids):
        curve = results_data[run_id]['history']
        shown = curve[first_idx:]
        if len(shown) == 0:
            continue

        name = manual_names[rank] if rank < len(manual_names) else run_id
        plt.plot(
            range(first_idx + 1, len(curve) + 1),
            shown,
            label=f"{name} ({results_data[run_id]['final_ppl']:.2f})",
            color=cb_colors[rank % len(cb_colors)],
            linewidth=2.5,  # slightly thicker lines for visibility
            marker='o',
            markersize=5,
        )

    # --- Axes: linear y-scale, large labels and tick numbers for readability ---
    plt.yscale('linear')
    plt.xlabel('Epoch', fontsize=16, fontweight='medium')
    plt.ylabel('Validation Perplexity', fontsize=16, fontweight='medium')
    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)

    # --- Large in-plot legend, font matched to the tick labels ---
    legend_style = dict(
        loc='upper right',
        framealpha=0.9,
        borderaxespad=1.2,
        fontsize=14,
        labelspacing=0.4,
        handletextpad=0.5,
    )
    plt.legend(**legend_style)

    plt.grid(True, which="both", ls="-", alpha=0.15)
    plt.tight_layout()

    if save_pdf:
        plt.savefig(f'rfa_ablation_zoom_ep{start_epoch}.pdf', bbox_inches='tight')
    plt.show()

In [None]:
# rope: yellow circle
# alibi: dark blue diamond
# rfa: pink square
# sc_rfa: light blue triangle

In [None]:
# Publication-quality convergence plot. Each run gets a unique colour + line
# style + marker combination, chosen by its final-perplexity rank (worst first):
# rank 0 dark blue / dashed / diamond, rank 1 orange / solid / circle,
# rank 2 pink / dash-dot / square, rank 3 light blue / dotted / triangle.
if not results_data:
    print("Error: No data loaded.")
else:
    # --- Colorblind-Friendly Palette (Okabe-Ito) ---
    cb_colors = ['#0072B2', '#E69F00', '#CC79A7', '#56B4E9', '#F0E442', '#000000', '#D55E00', '#999999']

    # --- Professional Style Assets ---
    line_styles = ['--', '-', '-.', ':']
    markers = ['D', 'o', 's', '^']

    # Create the figure only once data is known to exist, so no empty figure is left open.
    plt.figure(figsize=(10, 6))

    # 1. Sort by final perplexity (highest PPL first)
    sorted_aids = sorted(results_data.keys(),
                         key=lambda aid: results_data[aid]['final_ppl'],
                         reverse=True)

    # 2. Display names, matched to the sorted runs by position
    manual_names = ['ALiBi (B2)', 'RoPE (B1)', 'RFA (M1)', 'SC-RFA (M2)']

    idx = max(0, start_epoch - 1)  # 1-based start epoch -> 0-based slice index

    for i, aid in enumerate(sorted_aids):
        y_full = results_data[aid]['history']
        y_sliced = y_full[idx:]
        if len(y_sliced) == 0:
            # Nothing to draw after start_epoch. i still advances, so the
            # remaining runs keep their rank-based style.
            continue

        x_axis = range(idx + 1, len(y_full) + 1)
        label_name = manual_names[i] if i < len(manual_names) else aid

        # Apply the unique style combination for this rank
        plt.plot(x_axis, y_sliced,
                 label=f"{label_name} ({results_data[aid]['final_ppl']:.2f})",
                 color=cb_colors[i % len(cb_colors)],
                 linestyle=line_styles[i % len(line_styles)],
                 marker=markers[i % len(markers)],
                 linewidth=2.0,
                 markersize=6,
                 markeredgewidth=1.5)

    # --- Formatting for High Readability ---
    plt.yscale('linear')
    plt.xlabel('Epoch', fontsize=16, fontweight='medium')
    plt.ylabel('Validation Perplexity', fontsize=16, fontweight='medium')

    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)

    # Hide the top/right spines for a cleaner journal-style frame
    ax = plt.gca()
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    # --- Balanced Large Legend ---
    plt.legend(
        loc='upper right',
        framealpha=0.9,
        borderaxespad=1.2,
        fontsize=14,
        labelspacing=0.4,
        handletextpad=0.5
    )

    plt.grid(True, which="both", ls="--", alpha=0.2)
    plt.tight_layout()

    if save_pdf:
        # PDF keeps the figure as vector graphics for LaTeX
        plt.savefig(f'rfa_convergence_professional_ep{start_epoch}.pdf', bbox_inches='tight')
    plt.show()

In [None]:
def plot_all_generalization_gaps(results_data, start_epoch=1):
    """Plot the per-epoch generalization gap of every run, labelled by run ID.

    The gap at epoch e is ``val_loss[e]`` minus the mean of the training losses
    belonging to epoch e. The raw training-loss history is split into
    ``len(val_loss)`` equal consecutive chunks, and any remainder is folded into
    the last epoch. This presumably assumes per-batch training losses (confirm
    against the training loop).

    Args:
        results_data: dict mapping run ID -> dict with a 'full_history' dict that
            holds 'val_loss' (one value per epoch) and 'loss' (training losses).
            Runs missing either history are skipped.
        start_epoch: first epoch (1-based) to show.

    Returns:
        None. Prints an error and returns without plotting if ``results_data``
        is empty.
    """
    if not results_data:
        print("Error: No data loaded.")
        return

    plt.figure(figsize=(10, 6))

    # Sort IDs by numeric suffix (B1, M1, M2, M7...); non-numeric suffixes get key 99.
    sorted_keys = sorted(results_data.keys(), key=lambda x: int(x[1:]) if x[1:].isdigit() else 99)
    idx = max(0, start_epoch - 1)  # 1-based start epoch -> 0-based slice index

    for aid in sorted_keys:
        # Access the 'full_history' dict saved during loading
        h = results_data[aid].get('full_history', {})
        if not h:
            continue

        val_loss = np.array(h.get('val_loss', []))
        train_loss_raw = np.array(h.get('loss', []))
        if len(val_loss) == 0 or len(train_loss_raw) == 0:
            continue

        # Average training losses into one value per epoch
        num_epochs = len(val_loss)
        batches_per_epoch = len(train_loss_raw) // num_epochs
        if batches_per_epoch == 0:
            # Fewer training losses than epochs: most per-epoch chunks would be
            # empty (their mean is NaN), so the gap curve is undefined.
            print(f"Skipping {aid}: {len(train_loss_raw)} training losses for {num_epochs} epochs.")
            continue

        train_loss_epoch = []
        for i in range(num_epochs):
            start_idx = i * batches_per_epoch
            # The last epoch also takes any remainder batches
            end_idx = (i + 1) * batches_per_epoch if i < num_epochs - 1 else len(train_loss_raw)
            train_loss_epoch.append(np.mean(train_loss_raw[start_idx:end_idx]))
        train_loss_epoch = np.array(train_loss_epoch)

        gap = val_loss - train_loss_epoch

        epochs = range(idx + 1, num_epochs + 1)
        plt.plot(epochs, gap[idx:], marker='s', label=f'{aid} Gap', linewidth=2)

    plt.axhline(0, color='black', linestyle='--', alpha=0.3)
    plt.xlabel('Epoch')
    plt.ylabel('Val NLL - Train NLL')
    plt.title(f'Generalization Gap Comparison (Epoch {start_epoch}+)')
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.grid(True, alpha=0.2)
    plt.tight_layout()
    plt.show()

# Plot the generalization gap of every loaded run from epoch 5 onwards.
plot_all_generalization_gaps(results_data=results_data, start_epoch=5)

In [None]:
def plot_all_generalization_gaps(results_data, start_epoch=5):
    """Plot each run's per-epoch generalization gap using the paper display names.

    The gap at epoch e is ``val_loss[e]`` minus the mean of the training losses
    belonging to epoch e. The training-loss history is split into
    ``len(val_loss)`` equal consecutive chunks, and any remainder is folded into
    the last epoch.

    Draw order is by decreasing final gap, estimated as
    ``val_loss[-1] - mean(loss[-100:])``, so the worst-generalizing run comes
    first. Display names follow the final-perplexity ranking (worst first), as
    in the perplexity plots, so each name stays attached to the same run. Ranking
    names by gap would mislabel runs whenever the two orderings differ.

    Args:
        results_data: dict mapping run ID -> dict with 'final_ppl' and a
            'full_history' dict holding 'val_loss' and 'loss'. Runs without
            usable loss histories are skipped.
        start_epoch: first epoch (1-based) to show.

    Returns:
        None. Prints an error and returns without plotting if ``results_data``
        is empty.
    """
    if not results_data:
        print("Error: No data loaded.")
        return

    # Names are listed in final-PPL order, worst first.
    manual_names = ['ALiBi (B2)', 'RoPE (B1)', 'RFA (M1)', 'SC-RFA (M2)']
    ppl_ranked = sorted(results_data, key=lambda aid: results_data[aid]['final_ppl'], reverse=True)
    display_names = {
        aid: manual_names[rank] if rank < len(manual_names) else aid
        for rank, aid in enumerate(ppl_ranked)
    }

    def histories(aid):
        """Return (val_loss, raw train loss) arrays for one run ([] when missing)."""
        h = results_data[aid].get('full_history', {})
        return np.array(h.get('val_loss', [])), np.array(h.get('loss', []))

    def final_gap(aid):
        """Sort key: final gap estimate; runs without histories sort last."""
        val_loss, train_loss_raw = histories(aid)
        if len(val_loss) == 0 or len(train_loss_raw) == 0:
            return float('-inf')
        return val_loss[-1] - np.mean(train_loss_raw[-100:])

    plt.figure(figsize=(10, 6))
    idx = max(0, start_epoch - 1)  # 1-based start epoch -> 0-based slice index

    for aid in sorted(results_data, key=final_gap, reverse=True):
        val_loss, train_loss_raw = histories(aid)
        num_epochs = len(val_loss)
        batches_per_epoch = len(train_loss_raw) // num_epochs if num_epochs else 0
        if batches_per_epoch == 0:
            # Missing histories, or fewer training losses than epochs: no per-epoch split.
            continue

        # Epoch-level training loss; the last epoch also takes any remainder batches.
        train_loss_epoch = []
        for j in range(num_epochs):
            start_idx = j * batches_per_epoch
            end_idx = (j + 1) * batches_per_epoch if j < num_epochs - 1 else len(train_loss_raw)
            train_loss_epoch.append(np.mean(train_loss_raw[start_idx:end_idx]))

        gap = val_loss - np.array(train_loss_epoch)

        epochs = range(idx + 1, num_epochs + 1)
        plt.plot(epochs, gap[idx:], marker='s', label=f'{display_names[aid]}', linewidth=2)

    plt.axhline(0, color='black', linestyle='--', alpha=0.3)
    plt.xlabel('Epoch')
    plt.ylabel('Generalization Gap (Val NLL - Train NLL)')
    plt.title(f'Generalization Gap: Structural Regularization (Epoch {start_epoch}+)')

    # Legend top-left, which is usually empty space because gaps tend to grow over training
    plt.legend(loc='upper left', framealpha=0.8)
    plt.grid(True, alpha=0.2)
    plt.tight_layout()
    plt.show()
    
plot_all_generalization_gaps(results_data=results_data, start_epoch=5)  # gap plot from epoch 5 onwards

In [None]:
import numpy as np
import matplotlib.pyplot as plt

def plot_relative_ppl_gain_vs_m1(results_data, start_epoch=1, save_pdf=False, ylim=(-2.5, 5)):
    """Plot every run's validation perplexity as a % difference from the M1 run.

    For each run other than the baseline ('M1_power_law'), the curve is
    ``100 * (ppl - m1_ppl) / m1_ppl`` per epoch, where negative means lower
    perplexity than M1. M1 itself is drawn as the horizontal 0% line. Colours,
    line styles, markers and display names are chosen by final-PPL rank (worst
    first), the same assignment as the convergence-plot cells.

    Args:
        results_data: dict mapping run ID -> dict with 'final_ppl' and a
            'full_history' dict holding 'val_ppl'.
        start_epoch: first epoch (1-based) to show.
        save_pdf: if True, save the figure as 'rfa_relative_gain_ep{start_epoch}.pdf'.
        ylim: (bottom, top) y-axis limits in percent, or None to autoscale.

    Returns:
        None. Prints a message and returns without plotting when there is no
        data or the M1 baseline history is missing or empty.
    """
    if not results_data:
        print("Error: No data loaded.")
        return

    # --- Standardized Style Assets (Matching Plot 1) ---
    cb_colors = ['#0072B2', '#E69F00', '#CC79A7', '#56B4E9', '#F0E442', '#000000', '#D55E00', '#999999']
    line_styles = ['--', '-', '-.', ':']
    markers = ['D', 'o', 's', '^']

    # --- Same sorting as Plot 1 so each run keeps its colour/style ---
    sorted_aids = sorted(results_data.keys(),
                         key=lambda aid: results_data[aid]['final_ppl'],
                         reverse=True)
    manual_names = ['ALiBi (B2)', 'RoPE (B1)', 'RFA (M1)', 'SC-RFA (M2)']

    # Baseline (M1) series
    baseline_aid = 'M1_power_law'
    if baseline_aid not in results_data:
        print("M1_power_law data missing.")
        return
    m1_ppl = np.array(results_data[baseline_aid].get('full_history', {}).get('val_ppl', []))
    if len(m1_ppl) == 0:
        print("M1_power_law has no validation perplexities.")
        return

    plt.figure(figsize=(10, 6))
    idx = max(0, start_epoch - 1)  # 1-based start epoch -> 0-based slice index

    for i, aid in enumerate(sorted_aids):
        # The baseline is the 0% line drawn below. It is identified by run ID rather
        # than by its positional display name, which would skip the wrong run if the
        # ranking shifted. i still advances past it so other runs keep their style.
        if aid == baseline_aid:
            continue

        label_base = manual_names[i] if i < len(manual_names) else aid

        current_ppl = np.array(results_data[aid].get('full_history', {}).get('val_ppl', []))

        # Compare only the epochs both runs have; histories may differ in length.
        n = min(len(current_ppl), len(m1_ppl))
        if n == 0:
            continue

        pct_diff_ppl = ((current_ppl[:n] - m1_ppl[:n]) / m1_ppl[:n]) * 100
        epochs = range(idx + 1, n + 1)

        plt.plot(epochs, pct_diff_ppl[idx:],
                 label=f"{label_base} vs M1",
                 color=cb_colors[i % len(cb_colors)],
                 linestyle=line_styles[i % len(line_styles)],
                 marker=markers[i % len(markers)],
                 linewidth=2.0,
                 markersize=6,
                 markeredgewidth=1.5)

    # --- M1 baseline ---
    plt.axhline(0, color='black', linestyle='-', linewidth=1.5, label='M1 (RFA) Baseline', alpha=0.6)

    # --- Formatting (Matching Plot 1) ---
    plt.xlabel('Epoch', fontsize=16, fontweight='medium')
    plt.ylabel('PPL Difference relative to M1 (%)', fontsize=16, fontweight='medium')
    if ylim is not None:
        plt.ylim(*ylim)

    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)

    ax = plt.gca()
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    plt.legend(
        loc='upper right',
        framealpha=0.9,
        borderaxespad=1.2,
        fontsize=13,
        labelspacing=0.4
    )

    plt.grid(True, which="both", ls="--", alpha=0.2)
    plt.tight_layout()

    if save_pdf:
        plt.savefig(f'rfa_relative_gain_ep{start_epoch}.pdf', bbox_inches='tight')
    plt.show()

# Relative-PPL plot over all epochs, saved to PDF.
plot_relative_ppl_gain_vs_m1(results_data=results_data, start_epoch=1, save_pdf=True)

In [None]:
start_epoch  # Display the current 1-based zoom start epoch used by the plotting cells above

The notes below were written for an earlier version of the relative-gain plot that also included M7. The current display names cover only B1, B2, M1 and M2, so the line colours named here may not match the figure above.

B1 vs. M1 (blue line): M1 is 1.3% to 2.0% better than the RoPE baseline throughout the entire run.

M7 vs. M1 (green line): M1 is 1.3% better than the fixed unitary model. This suggests that learning the noise parameters ($\mu, \sigma$) gives a measurable performance boost over a fixed physical structure alone.

M2 vs. M1 (orange line): M2 is slightly better in raw PPL (below the $0$ line). However, the generalization-gap plot shows that M2 has the largest gap between validation and training loss. This suggests that M2's lead is "brittle": it achieves lower PPL at the cost of generalization.

## Below, we load the data and trained models, and test them

### Load Wikitext-103 dataset

In [None]:
# Local WikiText-103 parquet shards: one glob pattern per split.
data_files = dict(
    train="datasets/wikitext-103/train-*.parquet",
    validation="datasets/wikitext-103/validation-*.parquet",
    test="datasets/wikitext-103/test-*.parquet",
)

# Load all three splits with the generic parquet builder.
raw_datasets = load_dataset("parquet", data_files=data_files)

### Load BabyLM-2025 dataset (optional: the cell below is commented out; uncomment it to use BabyLM instead of WikiText-103)

In [None]:
# # Define the directory
# train_dir = "datasets/BabyLM_2025/train_100M"

# # Get all .train files in that folder
# train_files = glob.glob(os.path.join(train_dir, "*.train"))

# # Update data_files dictionary
# data_files = {
#     "train": train_files,
#     "validation": glob.glob("datasets/BabyLM_2025/dev/*.dev"),
#     "test": glob.glob("datasets/BabyLM_2025/test/*.test")
# }

# # Load as 'text' since these are raw .txt files
# raw_datasets = load_dataset("text", data_files=data_files)

# # Now, raw_datasets["train"] will behave as one single, 
# # giant dataset containing every novel and transcript.

In [None]:
# # Load a tokenizer
# tokenizer = AutoTokenizer.from_pretrained("gpt2") 
# tokenizer.pad_token = tokenizer.eos_token

# Load the GPT-2 tokenizer from its local copy only (local_files_only=True
# keeps from_pretrained from downloading anything).
tokenizer_kwargs = {"local_files_only": True}
tokenizer = GPT2TokenizerFast.from_pretrained("./gpt2_tokenizer/", **tokenizer_kwargs)
# Use the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

# Tokenize and Chunk
def tokenize_function(examples, tokenizer=None):
    """Tokenize the "text" column of a batch of dataset examples.

    Args:
        examples: batch mapping with a "text" entry, as passed by
            ``Dataset.map(..., batched=True)``.
        tokenizer: callable applied to ``examples["text"]``. Required; it is a
            keyword argument so it can be supplied through ``fn_kwargs``.

    Returns:
        Whatever ``tokenizer`` returns for the list of texts.

    Raises:
        ValueError: if no tokenizer is given.
    """
    if tokenizer is None:
        raise ValueError("tokenize_function requires a tokenizer (pass it via fn_kwargs).")
    return tokenizer(examples["text"])

# Tokenize every split in batches over 4 worker processes, removing the raw "text" column.
map_options = dict(batched=True, num_proc=4)
tokenized_datasets = raw_datasets.map(
    tokenize_function,
    fn_kwargs={"tokenizer": tokenizer},
    remove_columns=["text"],
    **map_options,
)

In [None]:
# Group into blocks

def group_texts(examples, block_size):
    """Concatenate a batch of tokenized examples and re-split it into fixed-size blocks.

    Every column of ``examples`` is flattened into one long list and cut into
    consecutive chunks of ``block_size`` items. If the batch holds at least
    ``block_size`` tokens, the trailing remainder is dropped. Otherwise the
    whole (short) batch is kept as a single chunk; the notebook filters such
    chunks out afterwards. A "labels" column is added as a copy of
    "input_ids", since the causal shift happens on the model side.

    Args:
        examples: mapping column name -> list of token lists (batched
            ``Dataset.map`` input). Must contain "input_ids".
        block_size: number of tokens per output chunk.

    Returns:
        dict with the same columns plus "labels", each a list of chunks.
    """
    from itertools import chain

    # 1. Flatten each column into one long list. chain.from_iterable walks the
    #    rows lazily instead of unpacking every row into call arguments.
    concatenated_examples = {k: list(chain.from_iterable(v)) for k, v in examples.items()}

    # 2. Total number of tokens, taken from the first column (columns are
    #    assumed to be token-aligned).
    total_length = len(next(iter(concatenated_examples.values()), []))

    # 3. Round down to the nearest multiple of block_size
    if total_length >= block_size:
        total_length = (total_length // block_size) * block_size

    # 4. Chop into fixed-size chunks
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }

    # 5. Labels are the inputs themselves for causal LM (shallow copy of the chunk list)
    result["labels"] = result["input_ids"].copy()

    return result

# Re-chunk every split into contiguous blocks of block_size tokens.
block_size = 512
lm_datasets = tokenized_datasets.map(
    group_texts,
    batched=True,
    num_proc=4,
    # Pass the variable (not a literal) so the block length is set in one place.
    fn_kwargs={"block_size": block_size}
)

In [None]:
## Filter out incorrect lengths

# group_texts keeps a short chunk when a batch has fewer than block_size
# tokens; drop every block that is not exactly block_size long. The filter
# runs on all splits (train, validation, test).
lm_datasets = lm_datasets.filter(lambda x: len(x["input_ids"]) == block_size)

print(f"New validation size: {len(lm_datasets['validation'])} blocks")

# Sanity-check the evaluation splits: after filtering, no block should deviate.
for check_split in ("validation", "test"):
    lengths = [len(x) for x in lm_datasets[check_split]["input_ids"]]
    bad_indices = [i for i, n in enumerate(lengths) if n != block_size]

    print(f"--- Analysis for {check_split} split ---")
    print(f"Total sequences: {len(lengths)}")
    print(f"Number of non-{block_size} sequences: {len(bad_indices)}")

    if bad_indices:
        print(f"Shortest sequence found: {min(lengths[i] for i in bad_indices)}")
        print(f"Indices of short sequences: {bad_indices}")

In [None]:
# Inspect the first training block.
sample = lm_datasets["train"][0]
train_split = lm_datasets['train']

print(f"Dataset size: {len(train_split)} blocks")
print(f"Sequence Length: {len(sample['input_ids'])}")
print(f"Features: {train_split.column_names}")

# Inputs and labels should match exactly here; the next-token shift is applied model-side.
is_aligned = sample['labels'] == sample['input_ids']
print(f"Labels perfectly aligned with inputs: {is_aligned}")

In [None]:
# Turn the inspected block back into text and show its first 500 characters.
decoded_text = tokenizer.decode(sample['input_ids'])
preview = decoded_text[:500]
print("--- SAMPLE DATA START ---", preview, "--- SAMPLE DATA END ---", sep="\n")

In [None]:
# # Sample 1000 tokens from the first few blocks
# all_tokens = []
# for i in range(5):
#     all_tokens.extend(lm_datasets["train"][i]["input_ids"])

# plt.figure(figsize=(10, 4))
# plt.hist(all_tokens, bins=100, color='skyblue', edgecolor='black')
# plt.title("Token ID Distribution (WikiText-103)")
# plt.xlabel("Token ID")
# plt.ylabel("Frequency")
# plt.grid(axis='y', alpha=0.3)
# plt.show()

In [None]:
# lm_datasets["train"]
# # Dataset({
# #     features: ['input_ids', 'attention_mask', 'labels'],
# #     num_rows: 229206
# # })

### Define Model

In [None]:
####################
## MODEL SETTINGS ##
####################

# Data / architecture sizes
vars(args).update(
    batch_size=16,  # Batch size
    vocab_size=50257,
    d_e=256,
    seq_len=512,
    n_heads=8,
)
# Total query-key dim and total value dim across all heads both equal d_e.
args.d_k_total = args.d_v_total = args.d_e
# args.num_blocks = 3

# Decay limits
# args.max_learned_decay = 1.4 # e/2
args.max_learned_decay = 5.0
args.max_fixed_decay = 5.0  # Can be more aggressive

# Limits for clamping exponent
args.max_exponent, args.min_exponent = 0, -10

args.epsilon = 1E-5  # Stability param

# Diagnostics switches
args.compute_metadata = False  # Triggers computing various diagnostics; turned off during training
args.compute_pulled_forward_estimates = False  # "Project" every past state into every future frame; very expensive.

In [None]:
##############################
## DEFAULT ABLATION OPTIONS ##
##############################
# Baseline configuration shared by every ablation; the ablation cell below overrides individual flags.
vars(args).update(
    causal=True,
    t_equal=True,                      # Equal time intervals?
    sep_params=False,                  # Use separate params for keys and values?
    lambda_real_zero=False,            # Zero out real part of eigenvalues?
    use_full_residual_norm=1,          # Use the full |R|^2 metric?
    use_robust_weight=True,            # Use rational weight rather than softmax
    additive_bias_type=1,              # Additive bias: 0 for zero; 1 for DLE; 2 for linear
    multiplicative_bias_type=1,        # Multiplicative bias: 0 for constant; 1 for DLE; 2 for linear
    t_shift=None,                      # Default
    learn_t_shift=True,
    learn_rotations=False,             # Learned rotations (True), or fixed as in RoPE (False)?
    learn_decay=False,                 # Learned decay (True), or fixed (False)?
    rotate_values=True,                # Rotate/unrotate values?
    zero_process_noise=False,          # Zero process noise (sigma^2)?
    zero_key_measurement_noise=False,  # Zero key measurement noise (eta^2)?
    use_total_precision_gate=1,        # Total-precision gating: 0 = no gate, 1 = precision gate, 2 = learned gate
    use_inner_residual=False,          # Residual connection BEFORE output projection?
    use_outer_residual=True,           # Residual connection AFTER output projection?
    use_complex_input_norm=0,          # Complex RMS Norm AFTER input projection: 1 = query/key/value, 2 = query/key only, 0 = none
    use_complex_output_norm=False,     # Complex-valued RMS Norm BEFORE output projection?
    use_real_input_norm=True,          # Real-valued RMS Norm BEFORE input projection?
    use_real_output_norm=True,         # Real-valued RMS Norm AFTER output projection?
    add_gaussian_noise=False,          # Add Gaussian noise to final token? (for test-time sampling)
    use_complex_conj_constraint=True,  # Eigenvalues in complex conjugate pairs so that A is real
    use_colored_prior=False,
    # allow_BM_branch=True,            # Separate branch for Brownian motion? (only used when learning decay)
    scale_decay_by_time_interval=True,
    zero_rotations=False,
    use_ss_process_noise=False,
    damping=0.05,
    use_rope=True,
    use_alibi=False,
    # use_xpos=False,
    use_SC_RoPE=False,
    use_log_linear_decay=True,
)
# A learned time shift starts unset.
if args.learn_t_shift:
    args.t_shift = None

In [None]:
## ABLATIONS ##

# ---------------------------------------------------

# # BASELINES:

# ---------------------------------------------------

# # B1: Standard Transformer + RoPE Baseline
# #     --- Dot-Product Similarity; current SOTA positional encoding baseline.
# #     Applies d --> 2d --> d attention projections for fair comparison.
# # Uses separate backbone and training looping

# ablation_name = 'B1'
# args.use_rope = True
# args.use_alibi = False
# args.use_relative_decay_vanilla = False

# ---------------------------------------------------

# # B2: Standard Transformer + ALiBi Baseline
# #     Applies d --> 2d --> d attention projections for fair comparison.
# # Uses separate backbone and training looping

# ablation_name = 'B2'
# args.use_rope = False
# args.use_alibi = True
# args.use_relative_decay_vanilla = False

# ---------------------------------------------------

# # B3 (RoPE + decay):
# ablation_name = 'B3'
# args.use_rope = True
# args.use_alibi = False
# args.use_SC_RoPE = False
# args.scale_decay_by_time_interval = False
# args.damping = 0.05
# args.use_relative_decay_vanilla = True

# ---------------------------------------------------

# # B4 (SC-RoPE):
# ablation_name = 'B4'
# args.use_rope = False
# args.use_alibi = False
# args.use_SC_RoPE = True
# args.scale_decay_by_time_interval = False
# args.damping = 0.05
# args.use_relative_decay_vanilla = True

# ------------------------------------------------------------------------------------------------------
# ------------------------------------------------------------------------------------------------------

# # MAIN MODELS:

# ---------------------------------------------------

# # M1: Isotropic RFA

# # Default settings
# ablation_name = 'M1'
# args.use_log_linear_decay = False
# args.damping = 0.05

# ---------------------------------------------------

# # # M1.1: Isotropic RFA w/o DLE Prior

# # Default settings
# ablation_name = 'M1.1'
# args.use_log_linear_decay = False
# args.damping = 0.05
# args.zero_process_noise = True
# args.zero_key_measurement_noise = True
# args.additive_bias_type = 0
# args.multiplicative_bias_type = 0
# args.use_log_linear_decay = False

# ---------------------------------------------------

# # M1.2: Isotropic RFA w/o robust weight

# # Default settings
# ablation_name = 'M1.2'
# args.use_log_linear_decay = False
# args.damping = 0.05
# args.use_robust_weight = False

# ---------------------------------------------------

# # M2: Spectrally-Coupled RFA
# # # --- M1, but we partition the angular frequencies so that the fast frequencies
# # # are coupled with fast decays and vice versa

# ablation_name = 'M2'
# args.use_SC_RoPE = True
# args.scale_decay_by_time_interval = False
# args.damping = 0.005
# args.use_log_linear_decay = False

# ---------------------------------------------------

# # M2.1 (M2 + no robust weight)
# ablation_name = 'M2.1'
# args.use_SC_RoPE = True
# args.scale_decay_by_time_interval = False
# args.damping = 0.05
# args.use_robust_weight = False
# args.use_log_linear_decay = False

# ---------------------------------------------------

# # M2.2: Spectrally-Coupled RFA, w/o DLE Prior
# # # --- M2, but we ablate the DLE

# ablation_name = 'M2.2'
# args.use_SC_RoPE = True
# args.scale_decay_by_time_interval = False
# args.damping = 0.05
# args.zero_process_noise = True
# args.zero_key_measurement_noise = True
# args.additive_bias_type = 0
# args.multiplicative_bias_type = 0
# args.use_log_linear_decay = False

# ---------------------------------------------------

# M2.3: (M2 + ablate out only multiplicative gate but keep additive term)

# ablation_name = 'M2.3'
# args.use_SC_RoPE = True
# args.scale_decay_by_time_interval = False
# args.damping = 0.05
# args.additive_bias_type = 1 # Keep additive term
# args.multiplicative_bias_type = 0 # Set multiplicative term to constant
# args.use_log_linear_decay = False

# ---------------------------------------------------

# # M2.4: (M2 + No Value rotations)

# ablation_name = 'M2.4'
# args.use_SC_RoPE = True
# args.scale_decay_by_time_interval = False
# args.damping = 0.05
# args.use_log_linear_decay = False
# args.rotate_values = False

# ---------------------------------------------------

# # M2.5: (M2 + No Rotations)
# ablation_name = 'M2.5'
# args.use_SC_RoPE = True
# args.scale_decay_by_time_interval = False
# args.damping = 0.05
# args.use_log_linear_decay = False
# args.zero_rotations = True

# ---------------------------------------------------

# # M2.6: (Unitary, zero noise Limit)
# ablation_name = 'M2.6'
# args.use_SC_RoPE = True
# args.scale_decay_by_time_interval = False
# args.damping = 0.0
# args.use_log_linear_decay = False

# args.use_robust_weight = False
# args.use_full_residual_norm = 0
# args.lambda_real_zero = True
# args.zero_process_noise = True
# args.zero_key_measurement_noise = True
# args.additive_bias_type = 0
# args.multiplicative_bias_type = 0
# args.max_fixed_decay = 0.0 # Zero out decay

# ---------------------------------------------------

# M2.7: Spectrally-Coupled RFA + Total Confidence Gate

ablation_name = 'M2.7'
# Overrides on top of the default ablation options
vars(args).update(
    use_SC_RoPE=True,
    scale_decay_by_time_interval=False,
    damping=0.05,
    use_log_linear_decay=False,
    use_inner_residual=True,
    use_total_precision_gate=1,
)

In [None]:
# Echo the active ablation tag followed by the full configuration.
for _shown in (ablation_name, args):
    print(_shown)

In [None]:
#####################
## DEFINE BACKBONE ##
#####################

###########################################################

# # Standard Transformer

# args.num_blocks = 6
# backbone = TransformerNetwork(args.d_e, args.d_e*2, args.d_e*4, args.n_heads, args, num_blocks=args.num_blocks, Norm=nn.LayerNorm)

###########################################################

# RFA Transformer: `num_blocks` stacked blocks; hidden_dim is set to 4x the total
# value dimension, and every size comes from `args`.

args.num_blocks = 6
backbone = RFATransformerNetwork(
    args=args,
    num_blocks=args.num_blocks,
    n_heads=args.n_heads,
    input_dim=args.d_e,
    query_key_dim_total=args.d_k_total,
    value_dim_total=args.d_v_total,
    hidden_dim=4 * args.d_v_total,
    Norm=nn.LayerNorm,
)

###########################################################

In [None]:
# Wrap the backbone in a LanguageModel (vocabulary size and embedding width taken
# from `args`) and move it to the configured device.
lm_kwargs = dict(backbone=backbone, vocab_size=args.vocab_size, embed_dim=args.d_e)
model = LanguageModel(**lm_kwargs).to(args.device)

print(model)

params_list = [p for p in model.parameters()]  # Parameters list

print('Total parameter count:', count_parameters(model))

In [None]:
# 1. Data collator for causal language modeling. mlm=False selects next-token
#    prediction labels instead of BERT-style masked-token labels.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)


def _make_lm_loader(split, shuffle=False):
    """Return a DataLoader over `lm_datasets[split]` with the shared collator and batch size."""
    return DataLoader(
        lm_datasets[split],
        shuffle=shuffle,
        batch_size=args.batch_size,
        collate_fn=data_collator,
    )


# 2. One DataLoader per split; only the training split is shuffled.
train_dataloader = _make_lm_loader("train", shuffle=True)
eval_dataloader = _make_lm_loader("validation")
test_dataloader = _make_lm_loader("test")

# 3. Sanity-check the first training batch.
batch = next(iter(train_dataloader))
print(f"Batch keys: {batch.keys()}")
print(f"Input IDs shape: {batch['input_ids'].shape}")
print(f"Labels shape: {batch['labels'].shape}")

In [None]:
####################
## TRAINING SETUP ##
####################

criterion = nn.CrossEntropyLoss() # Loss

args.num_epochs = 15 # Number of epochs

args.save_model = True

args.save_epochs = 1 # Intervals of epochs to save model
args.show_example_epochs = 1 # Number of epochs between displaying results

#####################

# Create folders for model weights, and loss history.
# Paths are assembled with os.path.join (not hard-coded '\\') so they resolve on
# both Windows and POSIX. The trailing '' component keeps a trailing separator,
# because later cells concatenate file/folder names directly onto these strings.

try:
    root_path = os.path.dirname(os.path.abspath(__file__))
except NameError:
    # `__file__` is not defined inside a notebook kernel.
    root_path = os.getcwd()

# saved_models_path = os.path.join(root_path, 'saved_models', 'wikitext_103', '')
saved_models_path = os.path.join(root_path, 'saved_models', 'main_models', '')
# saved_models_path = os.path.join(root_path, 'saved_models', 'babylm_2025', '')
model_name = str(model.backbone.__class__.__name__)
date = str(datetime.datetime.today()).split()
date_time = date[0]  # 'YYYY-MM-DD'
# date_time = date[0] + '_' + date[1][0:5].replace(":", "_")
model_path = os.path.join(saved_models_path, model_name + '__' + ablation_name + '__' + date_time, '')
# model_path = os.path.join(saved_models_path, '__' + ablation_name + '__' + date_time, '')

try:
    os.makedirs(model_path, exist_ok=True)
except OSError as err:
    # Best effort: evaluation can continue without the folder, but report it instead
    # of silently swallowing the failure.
    print(f"Warning: could not create {model_path}: {err}")

saved_models_path

In [None]:
## LOAD MODEL ##
################

# checkpoint_path = os.path.join(saved_models_path, 'TransformerNetwork__B4__2026-01-21')
# checkpoint_path = os.path.join(saved_models_path, 'RFATransformerNetwork__M1.2__2026-01-24')
# checkpoint_path = os.path.join(saved_models_path, 'RFATransformerNetwork__M4__2026-01-17')
checkpoint_path = os.path.join(saved_models_path, 'M2_b0.05')
# checkpoint_path = os.path.join(saved_models_path, 'RFATransformerNetwork__M2.7__2026-01-27')
# checkpoint_path = os.path.join(saved_models_path, 'M2.3')
# checkpoint_path = os.path.join(saved_models_path, 'B1_redo')
# 
# checkpoint_path = os.path.join(checkpoint_path, "standard_transformer_epoch_15.pt")
# NOTE: from here on `checkpoint_path` is the .pt file itself, not its folder.
checkpoint_path = os.path.join(checkpoint_path, "afa_lm_epoch_15.pt")

if os.path.exists(checkpoint_path):
    print(f"Loading checkpoint from {checkpoint_path}...")

    checkpoint = torch.load(checkpoint_path, map_location=args.device)
    
#     ###############
# #     Identify the keys that cause the size mismatch (buffers, not weights)
# #     We want to keep weights/biases but drop the positional/masking grids
#     weights = checkpoint['model_state_dict']

#     # Filter out the positional/masking buffers that have the wrong size
#     # This keeps the learned weights but lets the model use its new 4096-sized grids
#     keys_to_drop = [
#         k for k in weights.keys() 
#         if "causal_mask" in k or "rope.cos" in k or "rope.sin" in k or "rope.positions" in k
#     ]

#     for k in keys_to_drop:
#         del weights[k]

#     model.load_state_dict(weights, strict=False)
#     ###############

    # Restore Model Weights
    model.load_state_dict(checkpoint['model_state_dict'])

    # Extract History and Test Results
    history = checkpoint['history']
    start_epoch = checkpoint['epoch'] + 1

    # Print Stored Results (tolerate a history without validation entries).
    val_ppl_history = history.get('val_ppl', [])
    print("-" * 30)
    print(f"Resuming from Epoch: {start_epoch}")
    if len(val_ppl_history) > 0:
        print(f"Last Val PPL: {val_ppl_history[-1]:.2f}")
    print('Validation history:', val_ppl_history)

else:
    print("No checkpoint found. Starting from scratch.")
    start_epoch = 0
    # Later cells read and write `history` (e.g. test-set results); define it so
    # they do not fail with a NameError when no checkpoint was loaded.
    history = {}

In [None]:
# 1. Access the omega (frequency) tensor of the first block's attention layer.
rfa_layer = model.backbone.blocks[0].attn
omega = rfa_layer.omega_v  # omega[h] is indexed per head along dim 0

# 2. Reference heads for the report below.
head_0 = omega[0]
head_last = omega[-1]
num_omega_heads = omega.shape[0]

# 3. Compare every head against head 0. Checking only first vs. last would report
#    "identical" even if an intermediate head differed. With a single head there is
#    nothing to compare, so the result is reported as inconclusive rather than
#    (vacuously) flagged as standard RoPE.
if num_omega_heads < 2:
    are_identical = None
    print("Only one head: cross-head frequency comparison is inconclusive.")
else:
    are_identical = all(torch.equal(head_0, omega[h]) for h in range(1, num_omega_heads))

    print(f"Frequencies across heads are identical: {are_identical}")

    if are_identical:
        print("CRITICAL: You are using Standard RoPE (Wrong for M8).")
    else:
        print("SUCCESS: Your Spectral Coupling is preserved (Right for M8).")
        print(f"Head 0 Max Freq: {head_0.max().item():.6f}")
        print(f"Head Last Max Freq: {head_last.max().item():.6f}")

In [None]:
def _mean_damping_ratio(mu, max_omega):
    """Mean of mu / max_omega over entries with mu > 0; 0.0 when no mu is positive."""
    mask = mu > 0
    if not mask.any():
        return 0.0
    # A 0-d max_omega (one shared frequency) broadcasts; otherwise it is masked like mu.
    denom = max_omega if max_omega.ndim == 0 else max_omega[mask]
    return (mu[mask] / denom).mean().item()


def check_transformer_damping(model):
    """
    Print the effective damping ratio ``b = mu / omega_max`` of each recognized
    attention layer in ``model``.

    Layers are recognized by class name (substring of ``str(type(module))``):
      * ``MultiheadIsotropicRFA``: reads ``module.mu_v`` and ``module.omega_v``
        (per-head frequencies along the last dim).
      * ``SpectralCoupledHeuristicAttention``: reads ``module.mu`` and
        ``module.rope.theta`` (1D for standard RoPE, 2D for SC-RoPE).

    ``omega_max`` is the largest frequency along the last dimension. It is taken
    with ``amax``, so it does not depend on the storage order; standard RoPE
    frequencies, for example, decrease with index. For each layer the function
    prints the mean ``b`` over entries with ``mu > 0`` (0.0 if none) and the
    fraction of ``mu`` entries equal to zero.

    Args:
        model: a ``torch.nn.Module``; every submodule is scanned via ``named_modules``.

    Returns:
        None. Results are printed as a table.
    """
    print(f"{'Block':<8} | {'Type':<15} | {'Avg b':<10} | {'Zero Frac':<10}")
    print("-" * 50)

    # Iterate through all modules to find the attention layers
    for name, module in model.named_modules():
        type_name = str(type(module))

        # Check for Isotropic RFA (M8)
        if "MultiheadIsotropicRFA" in type_name:
            mu = module.mu_v.detach().cpu()
            omega = module.omega_v.detach().cpu()

            avg_b = _mean_damping_ratio(mu, omega.amax(dim=-1))
            zero_frac = (mu == 0).float().mean().item()

            print(f"{name:<8} | RFA (M8)       | {avg_b:.4f}     | {zero_frac:.2f}")

        # Check for Heuristic Attention (M14)
        elif "SpectralCoupledHeuristicAttention" in type_name:
            mu = module.mu.detach().cpu()
            # Frequencies are stored in the SCRoPE or RoPE sub-module; amax over the
            # last dim handles both 1D (standard RoPE) and 2D (SC-RoPE) theta.
            omega = module.rope.theta.detach().cpu()

            avg_b = _mean_damping_ratio(mu, omega.amax(dim=-1))
            zero_frac = (mu == 0).float().mean().item()
            print(f"{name:<8} | Heuristic (M14)| {avg_b:.4f}     | {zero_frac:.2f}")

# Usage: print the per-layer damping table for the loaded model.
check_transformer_damping(model)

In [None]:
print("Training complete. Performing final evaluation on Test Set...")
model.eval()

# Token-level perplexity: sum the NLL of every scored token over the whole split,
# count those tokens, and take PPL = exp(total_nll / total_tokens).
# Labels equal to IGNORE_INDEX are excluded from BOTH the loss and the count; the
# same constant is passed to the criterion so the two can never disagree.
IGNORE_INDEX = -100
total_loss = 0.0
total_tokens = 0
criterion_sum = torch.nn.CrossEntropyLoss(reduction='sum', ignore_index=IGNORE_INDEX)

with torch.no_grad():
    for batch in tqdm(test_dataloader):
        batch = {k: v.to(args.device) for k, v in batch.items()}
        logits, _ = model(batch["input_ids"], t_measure=None, causal=True)

        # Shift so the logits at position t are scored against the label at t+1.
        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = batch["labels"][:, 1:].contiguous()

        mask = (shift_labels != IGNORE_INDEX)
        num_tokens = mask.sum().item()

        # Sum of per-token losses for this batch (ignored labels contribute 0).
        loss_sum = criterion_sum(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
        )

        total_loss += loss_sum.item()
        total_tokens += num_tokens

if total_tokens == 0:
    raise RuntimeError("Test set produced no scored tokens; cannot compute perplexity.")

# PPL: exp( (sum of all losses) / (total tokens) )
avg_nll = total_loss / total_tokens
test_ppl = np.exp(avg_nll)

history['test_loss'] = avg_nll # This is the average loss per token
history['test_ppl'] = test_ppl

print(f"Final Test Results | Avg NLL: {avg_nll:.4f} | PPL: {test_ppl:.2f}")

In [None]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt

def visualize_rfa_lm(model, val_dataset, history, epoch, folder, args,
                    save=False,
                    plot_log_losses_flag=True,
                    plot_last_attn_mat_flag=True,
                    plot_decay_per_iteration=True,
                    plot_noise_params=True,
                    plot_tau_and_nu_flag=True):
    """
    Plot training diagnostics of an RFA language model in a colorblind-safe style.

    Sections (each toggled by its flag):
      1. log of the per-iteration training loss with a 100-iteration moving average;
      2. attention matrices of one random ``val_dataset`` sample, one panel per head
         (drawn as ``A ** 0.25`` so small weights stay visible);
      3. per-head damping rate ``mu`` over iterations;
      4. per-head ``eta``, ``gamma``, ``sigma``, the steady-state process noise
         ``sigma / (2 mu)`` and ``alpha = eta - sigma / (2 mu)``;
      5. per-head ``tau`` and ``nu_over_d`` over iterations.

    The fixed y-axis ranges are display windows for these runs; values outside
    them are not visible.

    Args:
        model: called as ``model(inputs, t_measure=None, causal=True)``. It must
            return ``(logits, output_dict)``; ``output_dict.get('attn_mat')`` is
            plotted when present.
        val_dataset: indexable; items are token-id sequences or dicts holding
            ``'input_ids'``.
        history: mapping metric name -> per-iteration values. Every entry is
            converted with ``np.array``. Per-head metrics (``'mu'``, ``'eta'``,
            ``'gamma'``, ``'sigma'``, ``'tau'``, ``'nu_over_d'``) are indexed
            ``[iteration, head]``; ``'loss'`` is 1D.
        epoch: used only in saved file names.
        folder: directory for saved PDFs; created when ``save`` is True.
        args: must provide ``device`` and ``n_heads``.
        save: when True, each figure is written to ``folder`` as
            ``<name>_epoch_<epoch>.pdf``.

    Side effects:
        Switches ``model`` to eval mode and updates the global ``plt.rcParams``.
    """
    model.eval()
    
    # --- Colorblind-Friendly Palette (Okabe-Ito inspired) ---
    cb_colors = [
        '#0072B2', # Deep Blue
        '#E69F00', # Orange
        '#CC79A7', # Soft Magenta
        '#56B4E9', # Sky Blue
        '#F0E442', # Yellow
        '#000000', # Black
        '#D55E00', # Vermillion
        '#999999'  # Gray
    ]
    
    plt.rcParams.update({
        "font.family": "serif",
        "font.size": 11,
        "axes.labelsize": 12,
        "legend.fontsize": 9,
        "grid.alpha": 0.2,
        "grid.linestyle": "-"
    })

    # --- Data Preparation ---
    history_data = {k: np.array(v) for k, v in history.items()}

    if save:
        # Create the output folder once, before the first figure is written.
        # Previously only plot_heads created it, so saving the log-loss figure
        # (the first one drawn) failed when the folder did not exist yet.
        os.makedirs(folder, exist_ok=True)

    def save_current_figure(filename, **savefig_kwargs):
        """Write the active figure to '<folder>/<filename>_epoch_<epoch>.pdf' when saving."""
        if save:
            plt.savefig(os.path.join(folder, f'{filename}_epoch_{epoch}.pdf'),
                        bbox_inches='tight', **savefig_kwargs)

    def plot_heads(data, ylabel, filename, is_memory_floor=False, is_alpha=False, 
                   y_limit=None, pad=0.05, start_head=0):
        """
        Plot one curve per head, for heads ``start_head`` .. ``args.n_heads - 1``.

        ``data`` is indexed ``[iteration, head]``. Heads 0 and 1 are skipped when
        ``is_memory_floor`` or ``is_alpha`` is set. ``y_limit=(lo, hi)`` fixes the
        y-axis to that range, widened on each side by ``pad`` times its span.
        """
        plt.figure(figsize=(7, 4))
        
        for h in range(start_head, args.n_heads):
            # Heads 0-1 are excluded from the memory-floor and alpha plots.
            if (is_memory_floor or is_alpha) and h < 2: 
                continue
                
            y = data[:, h]
            plt.plot(y, color=cb_colors[h % len(cb_colors)], label=f'Head {h+1}', 
                     linewidth=1.2, alpha=0.9)

        if y_limit is not None:
            y_min, y_max = y_limit
            y_range = y_max - y_min
            plt.ylim(y_min - y_range * pad, y_max + y_range * pad)

        plt.xlabel('Iteration')
        plt.ylabel(ylabel)
        plt.legend(loc='center left', bbox_to_anchor=(1, 0.5), frameon=False)
        plt.grid(True)
        
        save_current_figure(filename, dpi=300)
        plt.show()
        plt.close()

    with torch.no_grad():
        rand_idx = np.random.choice(len(val_dataset))
        raw_item = val_dataset[rand_idx]
        
        item = raw_item['input_ids'] if isinstance(raw_item, dict) else raw_item
        if not isinstance(item, torch.Tensor):
            item = torch.tensor(item)
            
        inputs = item.unsqueeze(0).to(args.device)
        _, output_dict = model(inputs, t_measure=None, causal=True)
        attn_mat = output_dict.get('attn_mat', None)
        
        # --- 1. Log Loss Plot ---
        if plot_log_losses_flag:
            plt.figure(figsize=(7, 4))
            losses = np.log(history_data['loss'])
            plt.plot(losses, color=cb_colors[0], alpha=0.15, label='Log Iter Loss')
            window = 100
            if len(losses) > window:
                smooth = np.convolve(losses, np.ones(window) / window, mode='valid')
                # 'valid' output k averages losses[k : k + window]; plot it at the
                # window's last iteration so it lines up with the raw curve
                # (it was previously drawn shifted left by window - 1 iterations).
                smooth_x = np.arange(window - 1, len(losses))
                plt.plot(smooth_x, smooth, color='#000000', linewidth=1.5, label='Moving Avg')
            plt.xlabel('Iteration')
            plt.ylabel('log(Loss)')
            plt.legend(loc='upper right', frameon=False)
            plt.grid(True)
            save_current_figure('log_loss')
            plt.show()
            plt.close()

        # --- 2. Attention Matrices ---
        # Indexed attn_mat[0, h] with heads on dim 1; presumably (batch, heads, L, L)
        # -- TODO confirm against the model's output_dict.
        if plot_last_attn_mat_flag and attn_mat is not None:
            n_heads = attn_mat.size(1)
            _fig, axes = plt.subplots(1, n_heads, figsize=(n_heads * 4, 4))
            if n_heads == 1:
                axes = [axes]
            for h in range(n_heads):
                A = attn_mat[0, h].cpu().numpy()
                axes[h].imshow(A**0.25, cmap='magma')  # 0.25 power lifts small weights
                axes[h].set_axis_off()
                axes[h].set_title(f"H{h}", fontsize=10)
            plt.tight_layout()
            save_current_figure('attn_mats')  # previously never saved, even with save=True
            plt.show()
            plt.close()

        # --- 3. Learned Decay Tracking ---
        if plot_decay_per_iteration:
            plot_heads(history_data['mu'], r"Damping Rate ($\mu$)", 'decay_evolution')

        # --- 4. Noise Params ---
        if plot_noise_params:
            # Key Noise: eta^2 (display range 0.8 - 2.0)
            plot_heads(history_data['eta'], r'$\eta^2$ (Key Noise)', 'eta_sq', 
                       y_limit=(0.8, 2.0))
            
            # Query Noise: gamma^2 (display range 0.7 - 1.6)
            plot_heads(history_data['gamma'], r'$\gamma^2$ (Query Noise)', 'gamma_sq', 
                       y_limit=(0.7, 1.6))
            
            # Process Noise: sigma^2 (display range 0.0 - 0.01)
            plot_heads(history_data['sigma'], r'$\sigma^2$ (Process Noise)', 'sigma_sq', y_limit=(0.0, 0.01))
            
            # Steady state process noise: sigma^2 / 2mu (display range 0.05 - 0.11);
            # 1e-10 guards against division by zero when mu == 0.
            mu = history_data['mu']
            sigma = history_data['sigma']
            floor = sigma / (2 * mu + 1e-10)
            plot_heads(floor, r'Steady state process noise ($\sigma^2/2\mu$)', 'memory_floor', 
                       is_memory_floor=True, y_limit=(0.05, 0.11))

            # Phase Transition Parameter alpha = eta^2 - sigma^2 / 2mu (no fixed range)
            eta = history_data['eta']
            alpha = eta - floor
            plot_heads(alpha, r'Phase Parameter ($\alpha$)', 'alpha_phase', 
                       is_alpha=True)

#             # --- Noise Ratio: Process / (Key + Query) ---
#             noise_ratio = np.divide(sigma, (eta + history_data['gamma'] + 1e-10))
#             plot_heads(noise_ratio, r'Noise Ratio [$\sigma^2 / (\eta^2 + \gamma^2)$]', 
#                'noise_ratio_evolution', y_limit=(0.1, 0.5), start_head=2)
        
        # --- 5. Tau and Nu Tracking ---
        if plot_tau_and_nu_flag:
            # Inverse Temp: tau (display range 1.3 - 2.6)
            plot_heads(history_data['tau'], r'Inv. Temp ($\tau$)', 'tau_tracking', 
                       y_limit=(1.3, 2.6))
            
            # Robustness: nu/d (display range 4.0 - 5.75)
            plot_heads(history_data['nu_over_d'], r'Robustness ($\nu/d$)', 'nu_tracking', 
                       y_limit=(4.0, 5.75))


In [None]:
# `checkpoint_path` points at the .pt checkpoint file; visualize_rfa_lm expects a
# folder (it calls os.makedirs(folder) and saves PDFs into it when save=True), so
# pass the directory that contains the checkpoint.
checkpoint_dir = os.path.dirname(checkpoint_path)

# NOTE(review): plot_training_progress_lm is defined in an earlier cell; confirm
# whether its third argument is also a folder and should receive checkpoint_dir.
plot_training_progress_lm(history, 15, checkpoint_path, save=False)

visualize_rfa_lm(model, eval_dataloader.dataset, history, 15, checkpoint_dir, args,
                    save=False,
                    plot_log_losses_flag=True,
                    plot_last_attn_mat_flag=True,
                    plot_decay_per_iteration=True,
                    plot_noise_params = True,
                    plot_tau_and_nu_flag=True)

### Extrapolation Tests

In [None]:
np.arange(512, 8 * 512 + 1, 512)  # candidate lengths: 512, 1024, ..., 4096

In [None]:
# Context length to evaluate at (e.g. 512, 1024, 2048 or 4096 tokens).
extrapolation_block_size = 1024

# Re-chunk the tokenized test split into blocks of that length with group_texts.
group_kwargs = {"block_size": extrapolation_block_size}
extrapolation_dataset = tokenized_datasets["test"].map(
    group_texts, batched=True, num_proc=4, fn_kwargs=group_kwargs,
)

n_blocks = len(extrapolation_dataset)
print(f"Extrapolation Dataset size: {n_blocks} blocks")
print(f"New Sequence Length: {len(extrapolation_dataset[0]['input_ids'])}")

In [None]:
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# A single sequence per batch keeps memory bounded for long contexts; order is fixed.
extrapolation_dataloader = DataLoader(
    extrapolation_dataset,
    collate_fn=data_collator,
    shuffle=False,
    batch_size=1,
)

# Verify collated shapes: expect [1, extrapolation_block_size] for both tensors.
test_batch = next(iter(extrapolation_dataloader))
print(f"Extrapolation Batch Input Shape: {test_batch['input_ids'].shape}")
print(f"Extrapolation Batch Labels Shape: {test_batch['labels'].shape}")

In [None]:
print(f"Performing Extrapolation Evaluation (Length: {extrapolation_block_size})...")
model.eval()

# Token-level perplexity at the extrapolation length: summed NLL over all scored
# tokens divided by their count. Labels equal to IGNORE_INDEX are excluded from
# both; passing it to the criterion keeps the loss and the count consistent.
IGNORE_INDEX = -100
total_loss = 0.0
total_tokens = 0
criterion_sum = torch.nn.CrossEntropyLoss(reduction='sum', ignore_index=IGNORE_INDEX)

with torch.no_grad():
    for batch in tqdm(extrapolation_dataloader):
        batch = {k: v.to(args.device) for k, v in batch.items()}
        logits, _ = model(batch["input_ids"], t_measure=None, causal=True)

        # Logits at position t are scored against the label at t+1.
        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = batch["labels"][:, 1:].contiguous()

        mask = (shift_labels != IGNORE_INDEX)
        num_tokens = mask.sum().item()

        loss_sum = criterion_sum(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
        )

        total_loss += loss_sum.item()
        total_tokens += num_tokens

if total_tokens == 0:
    raise RuntimeError("Extrapolation set produced no scored tokens; cannot compute perplexity.")

avg_nll = total_loss / total_tokens
extrap_ppl = np.exp(avg_nll)

print(f"Extrapolation Results ({extrapolation_block_size}) | Avg NLL: {avg_nll:.4f} | PPL: {extrap_ppl:.2f}")

In [None]:
# `saved_models_path` ends with a path separator, so plain concatenation is used.
results_file = saved_models_path + "extrapolation_results.csv"

# One row per evaluation run; repeated runs append rather than overwrite.
row = {
    "ablation_name": ablation_name,
    "block_size": extrapolation_block_size,
    "avg_nll": f"{avg_nll:.4f}",
    "ppl": f"{extrap_ppl:.2f}"
}

file_exists = os.path.isfile(results_file)
# Also write the header when the file exists but is empty (e.g. created by an
# interrupted run); otherwise pd.read_csv would treat the first data row as
# column names.
write_header = not file_exists or os.path.getsize(results_file) == 0

with open(results_file, mode='a', newline='', encoding='utf-8') as f:
    writer = csv.DictWriter(f, fieldnames=list(row))
    if write_header:
        writer.writeheader()
    writer.writerow(row)

print(f"Results appended to {results_file}")

In [None]:
# import gc

# # Clear variables from previous loops
# del logits
# del batch
# gc.collect()

# # Empty the PyTorch cache
# torch.cuda.empty_cache()

In [None]:
def plot_extrapolation_results(csv_path):
    """
    Plot perplexity against context length for every ablation in a results CSV.

    The CSV needs the columns ``ablation_name``, ``block_size`` and ``ppl``. Only the
    lengths 512/1024/2048/4096 with ``ppl < 500`` are kept, and repeated
    (ablation, length) rows are averaged. Each curve's PPL values are printed. The
    figure is saved to ``extrapolation_results.png`` in the working directory and
    then shown.

    Prints an error message and returns None when ``csv_path`` does not exist.
    """
    try:
        results = pd.read_csv(csv_path)
    except FileNotFoundError:
        print(f"Error: {csv_path} not found.")
        return

    for column in ('block_size', 'ppl'):
        results[column] = pd.to_numeric(results[column])

    target_lengths = [512, 1024, 2048, 4096]

    # Keep the requested lengths, then drop diverged runs (e.g. a ~51k-PPL crash).
    results = results[results['block_size'].isin(target_lengths)]
    results = results[results['ppl'] < 500]

    results = (results.groupby(['ablation_name', 'block_size'])['ppl']
                      .mean()
                      .reset_index())

    plt.figure(figsize=(10, 6))
    palette = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728']

    for idx, ablation in enumerate(results['ablation_name'].unique()):
        curve = results[results['ablation_name'] == ablation].sort_values('block_size')
        plt.plot(curve['block_size'], curve['ppl'],
                 marker='o', markersize=8,
                 label=ablation, linewidth=2.5,
                 color=palette[idx % len(palette)])
        print(curve['ppl'])

    plt.title("Perplexity Extrapolation: RFA vs. Baseline", fontsize=14, fontweight='bold')
    plt.xlabel("Context Length (Tokens)", fontsize=12)
    plt.ylabel("Perplexity (PPL)", fontsize=12)
    plt.xticks(target_lengths)  # one tick per evaluated length
    plt.grid(True, linestyle=':', alpha=0.6)

    # Frame the y-axis tightly around the plotted PPL range.
    if not results.empty:
        top = results['ppl'].max()
        bottom = results['ppl'].min()
        plt.ylim(bottom - 2, top + 5)

    plt.legend(frameon=True, loc='upper left')
    plt.tight_layout()
    plt.savefig("extrapolation_results.png", dpi=300)
    plt.show()
    
plot_extrapolation_results(results_file)

In [None]:
def plot_extrapolation_percentage(csv_path):
    """
    Plot each ablation's perplexity increase (in %) relative to its own 512-token PPL.

    pct_increase = (PPL(L) - PPL(512)) / PPL(512) * 100, for L in multiples of 512
    up to 4096. Rows with ``ppl >= 200`` are dropped, and repeated
    (ablation, length) rows are averaged. Ablations without a 512-token result are
    skipped with a message because they have no baseline. The figure is saved to
    ``extrapolation_pct_increase.png`` in the working directory and shown.

    Args:
        csv_path: CSV with columns ``ablation_name``, ``block_size``, ``ppl``.
            If the file is missing, an error is printed and None is returned.
    """
    try:
        df = pd.read_csv(csv_path)
    except FileNotFoundError:
        print(f"Error: {csv_path} not found.")
        return

    df['block_size'] = pd.to_numeric(df['block_size'])
    df['ppl'] = pd.to_numeric(df['ppl'])
    
    # Filter for your specific points and clean
    base_length = 512
    target_lengths = [512, 1024, 1536, 2048, 2560, 3072, 3584, 4096]
    df = df[df['block_size'].isin(target_lengths) & (df['ppl'] < 200)]
    df = df.groupby(['ablation_name', 'block_size'])['ppl'].mean().reset_index()

    plt.figure(figsize=(10, 6))
    colors = {'M0_Baseline': '#1f77b4', 'M1_RFA': '#ff7f0e'}

    for name in df['ablation_name'].unique():
        # .copy() so the new column goes into an independent frame rather than a
        # slice of df (avoids pandas' SettingWithCopy behaviour).
        subset = df[df['ablation_name'] == name].sort_values('block_size').copy()

        base_rows = subset.loc[subset['block_size'] == base_length, 'ppl']
        if base_rows.empty:
            # Previously this raised IndexError via `.values[0]`.
            print(f"Skipping {name}: no {base_length}-token result to use as baseline.")
            continue
        base_ppl = base_rows.iloc[0]
        subset['pct_increase'] = ((subset['ppl'] - base_ppl) / base_ppl) * 100
        
        plt.plot(subset['block_size'], subset['pct_increase'], 
                 marker='o', markersize=8, label=f"{name} (% Increase)", 
                 linewidth=2.5, color=colors.get(name, None))

    plt.title("Extrapolation Penalty: RFA vs. RoPE", fontsize=14, fontweight='bold')
    plt.xlabel("Context Length (Tokens)", fontsize=12)
    plt.ylabel("% Increase in Perplexity (Lower is Better)", fontsize=12)
    plt.grid(True, linestyle=':', alpha=0.6)
    plt.legend()
    
    plt.tight_layout()
    plt.savefig("extrapolation_pct_increase.png", dpi=300)
    plt.show()

plot_extrapolation_percentage(results_file)

In [None]:
def plot_relative_advantage(csv_path, baseline='M0', candidate='M1'):
    """
    Plot the % perplexity reduction of ``candidate`` relative to ``baseline`` per length.

    improvement = (PPL_baseline - PPL_candidate) / PPL_baseline * 100, so positive
    values mean the candidate has the lower perplexity. Repeated
    (length, ablation) rows are averaged by ``pivot_table``. Lengths where either
    model is missing are dropped, and the improvement values are printed.

    Args:
        csv_path: CSV with columns ``ablation_name``, ``block_size``, ``ppl``.
        baseline: ``ablation_name`` of the reference model (default ``'M0'``).
        candidate: ``ablation_name`` of the compared model (default ``'M1'``).

    Raises:
        KeyError: if ``baseline`` or ``candidate`` does not occur in ``ablation_name``.
    """
    df = pd.read_csv(csv_path)
    
    # 1. Ensure numeric types to avoid calculation errors
    df['block_size'] = pd.to_numeric(df['block_size'])
    df['ppl'] = pd.to_numeric(df['ppl'])

    # 2. One column per ablation, one row per sequence length.
    pivot_df = df.pivot_table(index='block_size', columns='ablation_name', values='ppl')

    missing = [name for name in (baseline, candidate) if name not in pivot_df.columns]
    if missing:
        raise KeyError(
            f"ablation_name(s) {missing} not found in {csv_path}; "
            f"available: {list(pivot_df.columns)}"
        )
    
    # 3. Drop any lengths where one model is missing (avoids NaNs/Infs)
    pivot_df = pivot_df.dropna(subset=[baseline, candidate])

    # 4. Calculate Improvement
    pivot_df['improvement'] = ((pivot_df[baseline] - pivot_df[candidate]) / pivot_df[baseline]) * 100
    
    # 5. Clean up any non-finite values before plotting
    pivot_df = pivot_df[np.isfinite(pivot_df['improvement'])]

    if pivot_df.empty:
        print(f"No sequence length has results for both {baseline} and {candidate}.")
        return

    plt.figure(figsize=(10, 5))
    
    # Use .index and .values explicitly to avoid indexing confusion
    plt.plot(pivot_df.index.values, pivot_df['improvement'].values, 
             marker='s', color='green', linewidth=2.5)
    
    print(pivot_df['improvement'].values)
    
    plt.title("RFA Advantage over RoPE as Context Grows", fontsize=14)
    plt.xlabel("Sequence Length", fontsize=12)
    plt.ylabel("% Reduction in Perplexity", fontsize=12)
    plt.grid(True, linestyle='--', alpha=0.7)
    
    # Value labels: above the point when positive, below (in red) when negative.
    for x, y in zip(pivot_df.index.values, pivot_df['improvement'].values):
        offset = 0.3 if y >= 0 else -0.5
        plt.text(x, y + offset, f"{y:.1f}%", 
                 ha='center', fontsize=15, fontweight='bold',
                 color='black' if y >= 0 else 'red')

    # Add a horizontal line at 0 to show the "Win/Loss" boundary
    plt.axhline(y=0, color='black', linestyle='-', linewidth=1, alpha=0.8)

    # Fixed display window; points outside [-1, 7] % are not visible.
    plt.ylim((-1, 7))
    
    plt.tight_layout()
    plt.show()
    
plot_relative_advantage(results_file)

RFA exhibits length-dependent efficiency. While it is only competitive within the training horizon, its relative perplexity reduction over the RoPE baseline grows more than fivefold (from 1.1% to 6.2%) as the context window expands 8x (512 to 4,096 tokens). This suggests that stochastic blurring is a more robust prior for long-range dependencies than rigid positional rotations.

In [None]:
import matplotlib.pyplot as plt

# Data: perplexity of each positional scheme at each evaluated context length.
lengths = [512, 1024, 2048, 4096]
data = {
    'RoPE (B1)': [28.48, 30.94, 44.21, 72.69],
    'ALiBi (B2)': [28.59, 27.30, 26.54, 26.30],
    'RFA (M1)': [28.01, 27.58, 29.99, 38.46],
    'SC-RFA (M2)': [27.54, 26.73, 29.46, 37.19] 
}

# Colorblind-Friendly Palette (Okabe-Ito)
cb_colors = ['#E69F00', '#CC79A7', '#56B4E9', '#0072B2']
# Distinctive line styles
line_styles = ['--', '-', '-.', ':']
# Distinctive markers
markers = ['o', 's', '^', 'D']

plt.figure(figsize=(10, 6))

# Sort keys by final value (PPL at 4096) in descending order (highest perplexity
# first). Colors/styles are assigned in this sorted order, so legend entries run
# from worst to best at the longest context.
sorted_keys = sorted(data.keys(), key=lambda k: data[k][-1], reverse=True)

for i, key in enumerate(sorted_keys):
    # Legend label shows the PPL reached at the longest context.
    plt.plot(lengths, data[key], 
             label=f"{key} ({data[key][-1]:.2f})", 
             color=cb_colors[i % len(cb_colors)], 
             linestyle=line_styles[i % len(line_styles)],
             marker=markers[i % len(markers)],
             linewidth=2.5, 
             markersize=10,
             markeredgecolor='white', # Adds a "pop" to markers
             markeredgewidth=1.0)


# --- Publication Quality Formatting ---
plt.xlabel('Context Length (Tokens)', fontsize=16, labelpad=10)
# NOTE(review): the extrapolation evaluation in this notebook reads the "test"
# split; confirm whether "Validation Perplexity" is the intended axis label.
plt.ylabel('Validation Perplexity', fontsize=16, labelpad=10)
plt.xticks(lengths, fontsize=14)
plt.yticks(fontsize=14)

# Use a clean, semi-transparent legend inside the empty top-left space
plt.legend(
    loc='upper left', 
    fontsize=13, 
    frameon=True, 
    fancybox=True, 
    shadow=False, 
    framealpha=0.9, 
    edgecolor='lightgray',
    labelspacing=0.6
)

# Subtle grid for reference without clutter
plt.grid(True, which="both", ls="--", alpha=0.3)

# Remove the top and right spines for a modern, clean look
ax = plt.gca()
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)

plt.tight_layout()

# Save as PDF for vector graphics (essential for LaTeX)
plt.savefig('rfa_extrapolation_professional.pdf', bbox_inches='tight')
plt.show()



In [None]:
import matplotlib.pyplot as plt

# Data: SC-RFA (M2) perplexity per context length for each damping factor b
# (the value set through args.damping), plus the ALiBi baseline.
# Labels use b, matching the 'Damping Factor $b$' axis of the b-sensitivity plot;
# they previously read mu, the separate per-head damping rate.
lengths = [512, 1024, 2048, 4096]
data = {
    r'$b = 5 \times 10^{-4}$': [27.60, 28.88, 37.34, 51.48],
    # Previously a [0, 0, 0, 0] placeholder (plotted as a flat line at PPL 0 and
    # shown as "(0.00)" in the legend). Values taken from the b = 5e-3 column of
    # the b-sensitivity table (data_512 ... data_4096) in the fixed-window plot.
    r'$b = 5 \times 10^{-3}$': [27.60, 28.71, 35.35, 43.90],
    r'$b = 5 \times 10^{-2}$': [27.54, 26.73, 29.46, 37.19],
    r'$b = 5 \times 10^{-1}$': [27.61, 26.38, 26.37, 29.72],
    r'$b = 5 \times 10^{0}$': [27.91, 26.68, 26.37, 28.16],
    # NOTE(review): these ALiBi numbers differ slightly from the ALiBi (B2) row of
    # the RoPE/ALiBi/RFA comparison plot ([28.59, 27.30, 26.54, 26.30]); confirm
    # which run is intended.
    'ALiBi (B2)': [28.64, 27.31, 26.55, 26.31]
}

# --- Explicit Legend Order (Manual override) ---
ordered_keys = [
    r'$b = 5 \times 10^{-4}$',
    r'$b = 5 \times 10^{-3}$',
    r'$b = 5 \times 10^{-2}$',
    r'$b = 5 \times 10^{-1}$',
    r'$b = 5 \times 10^{0}$',
    'ALiBi (B2)'
]

# Expanded Okabe-Ito Palette (6 colors); one color/style/marker per key.
cb_colors = ['#0072B2', '#E69F00', '#CC79A7', '#56B4E9', '#009E73', '#D55E00']
line_styles = ['-.', '-', '--', '-', (0, (3, 5, 1, 5)), (0, (5, 10))] 
markers = ['o', 's', '^', 'D', 'v', 'p']

plt.figure(figsize=(10, 6))

for i, key in enumerate(ordered_keys):
    # Legend label shows the PPL reached at the longest context.
    plt.plot(lengths, data[key], 
             label=f"{key} ({data[key][-1]:.2f})", 
             color=cb_colors[i], 
             linestyle=line_styles[i],
             marker=markers[i],
             linewidth=3.0, 
             markersize=10,
             markeredgecolor='white', 
             markeredgewidth=1.5)

# --- Publication Quality Formatting ---
plt.xlabel('Context Length (Tokens)', fontsize=16, labelpad=10)
plt.ylabel('Validation Perplexity', fontsize=16, labelpad=10)

# Focusing the view on the active data range (all values lie within 26-52).
plt.ylim(20, 55) 

plt.xticks(lengths, fontsize=14)
plt.yticks(fontsize=14)

# Legend Configuration
plt.legend(
    loc='upper left', 
    fontsize=12, 
    frameon=True, 
    fancybox=True, 
    framealpha=0.9, 
    edgecolor='lightgray',
    labelspacing=0.5
)

plt.grid(True, which="both", ls="--", alpha=0.3)

# Remove the top and right spines for a clean look
ax = plt.gca()
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)

plt.tight_layout()
plt.savefig('rfa_ablation_ordered_final.pdf', bbox_inches='tight')
plt.show()

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Data from Table: Sensitivity Analysis of the Damping Factor b in SC-RFA (M2)
# x-axis: damping factor b (log scale); one curve per fixed context length L.
b_values = [5e-4, 5e-3, 5e-2, 5e-1, 5e0]
b_labels = [r'$5 \times 10^{-4}$', r'$5 \times 10^{-3}$', r'$5 \times 10^{-2}$', r'$5 \times 10^{-1}$', r'$5 \times 10^{0}$']

# Transpose table data: each list represents a fixed context window L, with one
# perplexity per entry of b_values (same order).
data_512 = [27.60, 27.60, 27.54, 27.61, 27.91]
data_1024 = [28.88, 28.71, 26.73, 26.38, 26.68]
data_2048 = [37.34, 35.35, 29.46, 26.37, 26.37]
data_4096 = [51.48, 43.90, 37.19, 29.72, 28.16]

lengths_data = {
    'L = 512': data_512,
    'L = 1024': data_1024,
    'L = 2048': data_2048,
    'L = 4096': data_4096
}

ordered_keys = ['L = 512', 'L = 1024', 'L = 2048', 'L = 4096']

# Okabe-Ito Palette and styles
cb_colors = ['#0072B2', '#E69F00', '#CC79A7', '#009E73']
line_styles = ['-', '--', '-.', ':']
markers = ['o', 's', '^', 'D']

plt.figure(figsize=(10, 6))

for i, key in enumerate(ordered_keys):
    # Wrapping the key in $...$ renders it with mathtext (italic L).
    plt.plot(b_values, lengths_data[key], 
             label=f"${key}$", 
             color=cb_colors[i], 
             linestyle=line_styles[i],
             marker=markers[i],
             linewidth=3.0, 
             markersize=10,
             markeredgecolor='white', 
             markeredgewidth=1.5)

# Publication Quality Formatting
plt.xscale('log')
plt.xlabel('Damping Factor $b$', fontsize=16, labelpad=10)
plt.ylabel('Validation Perplexity', fontsize=16, labelpad=10)

# Tick exactly at the swept b values, labeled in scientific notation.
plt.xticks(b_values, b_labels, fontsize=14)
plt.yticks(fontsize=14)

# Focusing view range
plt.ylim(25, 55)

plt.legend(
    loc='upper right', 
    fontsize=12, 
    frameon=True, 
    fancybox=True, 
    framealpha=0.9, 
    edgecolor='lightgray',
    labelspacing=0.5,
    title='Context Length'
)
plt.setp(plt.gca().get_legend().get_title(), fontsize='14')

plt.grid(True, which="both", ls="--", alpha=0.3)

# Clean spines
ax = plt.gca()
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)

plt.tight_layout()
# Saved as PNG (the other figures in this section are PDF); no explicit
# plt.show() here, so in a notebook the figure is displayed at the end of the cell.
plt.savefig('rfa_b_sensitivity_fixed_windows.png', bbox_inches='tight')

### Plot attention matrices

In [None]:
# Use the first extrapolation block as a batch of one sequence on the model's device.
sample = extrapolation_dataset[0]
input_ids = torch.tensor(sample['input_ids'])[None, :].to(args.device)
L_total = input_ids.size(1)  # sequence length of that block
L_total

In [None]:
import matplotlib.pyplot as plt

# Reset every Matplotlib rcParam to the library defaults. This discards the
# global overrides made at the top of the notebook (figure.figsize = [10, 10],
# font size 20), so the attention-matrix plots below start from stock settings.
plt.rcdefaults()

In [None]:
# # If using standard attention:
    
# # Forward Pass: Standard Model (M0)
# model.eval()
# with torch.no_grad():
#     _, returned_data_std = model(input_ids, return_attn=True)

# attentions = returned_data_std if isinstance(returned_data_std, list) else returned_data_std.get('attn_mat')

# from matplotlib.colors import PowerNorm

# def plot_final_layer_heads_standard(attentions, layer_idx=-1, view_range=(0, 4096), use_powernorm=False):
#     # Extract sequence limits
#     start, end = view_range
    
#     # Handle both list-of-layers and single-tensor formats
#     if isinstance(attentions, list) or isinstance(attentions, tuple):
#         layer_attn = attentions[layer_idx][0]
#     else:
#         # Assuming [B, L, L, H] -> [H, L, L] for batch 0
#         layer_attn = attentions.permute(0, 3, 1, 2)[0] if attentions.ndim == 4 else attentions

#     num_heads = layer_attn.shape[0]
#     fig, axes = plt.subplots(2, 4, figsize=(20, 10))
#     axes = axes.flatten()

#     for i in range(num_heads):
#         matrix = layer_attn[i].cpu().detach().numpy()
        
#         # Scaling Logic
#         if use_powernorm:
#             # PowerNorm helps reveal long-range "ghost" signals in the baseline
#             norm = PowerNorm(gamma=0.2)
#             im = axes[i].imshow(matrix, cmap='magma', norm=norm)
#         else:
#             # Standard linear scaling clipped at 99.5th percentile
#             vmax = np.percentile(matrix, 99.5)
#             im = axes[i].imshow(matrix, cmap='magma', vmin=0, vmax=vmax)
        
#         # Apply the dynamic view range
#         axes[i].set_xlim(start, end)
#         axes[i].set_ylim(end, start)
#         axes[i].set_title(f"Head {i+1}", fontsize=28)
#         axes[i].axis('off')

#     plt.tight_layout()
#     scale_type = "Power" if use_powernorm else "Linear"
#     plt.suptitle(f"Standard Transformer (M0): {scale_type} Scale | Range {start}-{end}", fontsize=20, y=1.05)
#     plt.show()
    

# # Example usage:
# # plot_final_layer_heads_standard(attentions, view_range=(3000, 4096), use_powernorm=True)

# # Example: Global View
# plot_final_layer_heads_standard(attentions, layer_idx=-1, view_range=(0, 4096), use_powernorm=False)
# # Example: Extrapolation Zoom
# plot_final_layer_heads_standard(attentions, layer_idx=-1, view_range=(3000, 4096), use_powernorm=False)

# # Example: Global View
# plot_final_layer_heads_standard(attentions, layer_idx=-1, view_range=(0, 4096), use_powernorm=True)
# # Example: Extrapolation Zoom
# plot_final_layer_heads_standard(attentions, layer_idx=-1, view_range=(3000, 4096), use_powernorm=True)

In [None]:
# If using RFA models: run one causal forward pass in eval mode with metadata
# collection enabled and pulled-forward estimates disabled.
model.eval()
rfa_backbone_args = model.backbone.args
rfa_backbone_args.compute_metadata = True
rfa_backbone_args.compute_pulled_forward_estimates = False

with torch.no_grad():
    _, output_dict_rfa = model(input_ids, causal=True)

# Attention matrices returned in the output dictionary.
attentions = output_dict_rfa['attn_mat']

from matplotlib.colors import PowerNorm

def plot_rfa_final_layer(model, attentions, layer_idx=-1, view_range=(0, 4096), use_powernorm=False):
    """
    Plot one attention heatmap per head, annotated with per-head decay metrics.

    Parameters
    ----------
    model :
        RFA language model. ``model.backbone.blocks[layer_idx].attn`` must expose a
        per-head ``mu_v``, and ``model.backbone.args`` must provide ``seq_len`` and
        ``max_fixed_decay``.
    attentions :
        Attention tensor laid out as [B, L, L, H]; batch 0 is permuted to [H, L, L].
    layer_idx :
        Block whose ``mu_v`` is used for the per-head titles.
        NOTE(review): this index does not select from ``attentions``; confirm that the
        tensor comes from the same layer.
    view_range :
        ``(start, end)`` token window shown on both axes. ``end`` is also the distance
        used for the ``mu*L`` / survival annotation.
    use_powernorm :
        If True, use a PowerNorm(gamma=0.2) colour scale. Otherwise use a linear scale
        clipped at the 99.5th percentile.
    """
    start, end = view_range

    # [B, L, L, H] -> [H, L, L] for the first batch element.
    A = attentions.permute(0, 3, 1, 2)[0]
    num_heads = A.shape[0]
    rfa_module = model.backbone.blocks[layer_idx].attn

    # Rescale the raw mu_v parameters by max_fixed_decay / (training length - 1).
    seq_len_train = model.backbone.args.seq_len
    scale_r = model.backbone.args.max_fixed_decay / (seq_len_train - 1)

    # Size the grid to the actual head count (4 columns). A fixed 2x4 grid raises
    # IndexError for more than 8 heads and leaves blank panels for fewer.
    ncols = 4
    nrows = max(1, math.ceil(num_heads / ncols))
    fig, axes = plt.subplots(nrows, ncols, figsize=(22, 6 * nrows), squeeze=False)
    axes = axes.flatten()

    for i in range(num_heads):
        # .float() so reduced-precision tensors (e.g. bfloat16) can be converted to NumPy.
        matrix = A[i].detach().float().cpu().numpy()
        panel = axes[i]

        if use_powernorm:
            # Amplifies faint signals (good for seeing the diffusive 'haze').
            panel.imshow(matrix, cmap='magma', norm=PowerNorm(gamma=0.2))
        else:
            # Linear scale clipped at the 99.5th percentile (good for seeing actual density).
            vmax = np.percentile(matrix, 99.5)
            panel.imshow(matrix, cmap='magma', vmin=0, vmax=vmax)

        # Per-head decay rate, and the survival factor exp(-mu * end) at distance `end`.
        actual_mu = scale_r * rfa_module.mu_v[i].item()
        mu_L = actual_mu * end
        survival = np.exp(-mu_L) * 100

        panel.set_xlim(start, end)
        panel.set_ylim(end, start)
        panel.set_title(f"Head {i+1}\nμL: {mu_L:.2f}\nSurv: {survival:.1f}%", fontsize=22)
        panel.axis('off')

    # Hide grid cells that have no head to show.
    for unused in axes[num_heads:]:
        unused.axis('off')

    plt.suptitle(f"RFA Transformer (M1): {'Power' if use_powernorm else 'Linear'} Scale", fontsize=24, y=1.05)
    plt.show()

# Run the analysis: for each colour scale (linear first, then PowerNorm),
# show the global view followed by the extrapolation zoom.
for use_powernorm in (False, True):
    for view_range in ((0, 4096), (3000, 4096)):
        plot_rfa_final_layer(model, attentions, layer_idx=-1, view_range=view_range, use_powernorm=use_powernorm)

The visual contrast between the Standard Transformer (M0) and the RFA Transformer (M1) provides a direct explanation for the model’s performance stability during long-context extrapolation. In the M0 linear plots, a pervasive "checkerboard" noise is visible across almost every head, characterized by faint, chaotic activations across the entire history. This indicates that the Standard Transformer is wasting significant attention mass on distant tokens it can no longer geometrically resolve, directly contributing to the observed perplexity explosion.

In contrast, the M1 linear plots suggest that the "shutting off" of certain heads is a learned form of geometric regularization rather than a failure mode. The RFA model establishes a clear functional hierarchy:
* Global Anchors (Heads 1 and 2): With $\mu L = 0.00$ (no effective decay across the context), these heads retain $100\%$ survival. They give the model a stable long-range memory that stays clear because it is no longer drowned out by the noise of the other six heads.
* Ultra-Local Specialists (Heads 7 and 8): These high-decay heads focus exclusively on the most recent context, visible as a sharp diagonal. By effectively ignoring the distant past, they eliminate the "noise floor" that plagues the standard model.
* Denoising through Specialization: The intermediate heads, which sit between the global anchors and the ultra-local specialists, eventually shut off beyond a certain context length. This represents an intelligent trade-off: even when the decay-based survival remains high, the accumulation of Brownian noise makes distant signals unreliable. By refusing to attend to these regions of high uncertainty, these heads effectively denoise the model.

The Standard Transformer fails in extrapolation because every head attempts to "do everything at every distance," leading to chaotic interference. RFA succeeds by enforcing a multi-resolution hierarchy where specialized heads prioritize signal over noise, resulting in a cleaner, more periodic attention structure that drives the 6% perplexity win.

In [None]:
#########################################

In [None]:
# Backbone flags for the generation experiments below: enable metadata and the
# add_gaussian_noise option, and disable pulled-forward estimates.
for _flag_name, _flag_value in (
    ("compute_metadata", True),
    ("add_gaussian_noise", True),
    ("compute_pulled_forward_estimates", False),
):
    setattr(model.backbone.args, _flag_name, _flag_value)

def autoregressive_sample(model, start_seq, max_gen_len, t_measure=None, t_shift=None, t_equal=True, causal=True):
    """
    Greedily generate tokens with a sliding context window and record per-step precision.

    At each step the model sees the most recent ``start_seq.size(1)`` tokens. The argmax
    of the final-position logits is appended, and the head-averaged
    ``output_dict['P_tot'][:, -1, :]`` is stored.

    Args:
        model: callable returning ``(logits, output_dict)``.
        start_seq: [B, L0] tensor of prompt token IDs. L0 is also the window size.
        max_gen_len: number of tokens to generate. 0 yields an empty ``new_seq``.
        t_measure, t_shift: forwarded to the model unchanged.
        t_equal: accepted for call-site compatibility; not used here.
        causal: forwarded to the model.

    Returns:
        total_seq: [B, L0 + max_gen_len] prompt followed by the generated IDs.
        new_seq: [B, max_gen_len] generated IDs only.
        precisions: list of ``max_gen_len`` tensors, each the mean over the last dim
            of ``P_tot`` at the final position.
    """
    precisions = []
    prompt_len = start_seq.size(1)
    window_size = prompt_len

    with torch.no_grad():
        model.eval()
        # current_seq shape: [B, L] (token IDs)
        current_seq = start_seq
        total_seq = start_seq

        for _ in tqdm(range(max_gen_len), desc="Generating tokens"):
            # Forward pass: `out` holds logits for the whole window.
            out, output_dict = model(current_seq, t_measure=t_measure, t_shift=t_shift, causal=causal)

            # Capture precision for visualization.
            precisions.append(torch.mean(output_dict['P_tot'][:, -1, :], dim=-1))

            # Greedy choice from the latest position's vocabulary distribution -> [B, 1].
            next_token_id = torch.argmax(out[:, -1, :], dim=-1).unsqueeze(1)

            total_seq = torch.cat([total_seq, next_token_id], dim=1)
            # Slide the window to keep the context length fixed.
            current_seq = total_seq[:, -window_size:]

    # Slice by prompt length. `total_seq[:, -max_gen_len:]` would return the whole
    # sequence (prompt included) when max_gen_len == 0, because -0 == 0.
    new_seq = total_seq[:, prompt_len:]

    return total_seq, new_seq, precisions

def plot_text_precision_heatmap(token_ids, precisions, tokenizer):
    """
    Display decoded tokens as inline HTML, each on a red background whose opacity is its score.

    Args:
        token_ids: sequence of token IDs, decoded one at a time via ``tokenizer.decode([tid])``.
        precisions: array or tensor of scores (flattened). Each score is used directly as
            the background alpha, so it is expected to lie in [0, 1]; callers min-max
            normalise it first.
        tokenizer: object with a ``decode(list_of_ids)`` method.

    Tokens and scores are paired with ``zip``; surplus entries on either side are ignored.
    """
    import html
    from IPython.display import display, HTML

    p = precisions.flatten()
    tokens = [tokenizer.decode([tid]) for tid in token_ids]

    spans = []
    for token, score in zip(tokens, p):
        # Red highlight whose alpha channel is the precision score.
        color = f"rgba(255, 0, 0, {score:.3f})"
        # Escape the decoded text so tokens containing '<', '>' or '&' cannot break the markup.
        spans.append(
            f'<span style="background-color: {color}; padding: 2px; margin: 1px; border-radius: 3px;" '
            f'title="P_tot: {score:.2f}">{html.escape(token)}</span>'
        )

    html_output = (
        '<div style="line-height: 2.0; font-family: monospace;">'
        + ''.join(spans)
        + '</div>'
    )
    display(HTML(html_output))

In [None]:
# Generation run with the backbone's add_gaussian_noise flag enabled.
model.backbone.args.add_gaussian_noise = True

# Prompt: the first args.seq_len tokens of extrapolation sample `idx`, as a [1, seq_len] batch.
idx = 6
prompt_token_list = extrapolation_dataset[idx]['input_ids'][:args.seq_len]
sample_input = torch.tensor(prompt_token_list)[None].to(args.device)

# Generate 100 new tokens.
total_ids, new_ids, p_list = autoregressive_sample(
    model,
    sample_input,
    max_gen_len=100,
    t_equal=args.t_equal,
    causal=args.causal,
)

# Stack per-step precisions to [B, 100], drop the batch axis, then min-max normalise.
p_tensor = torch.stack(p_list, dim=1).squeeze(0)
p_min = p_tensor.min()
p_max = p_tensor.max()
p_normalized = (p_tensor - p_min) / (p_max - p_min + 1e-8)

print("\n--- GENERATED TEXT WITH SDE PRECISION OVERLAY ---")
generated_token_ids = new_ids.squeeze(0).cpu().tolist()
plot_text_precision_heatmap(generated_token_ids, p_normalized.cpu().numpy(), tokenizer)

In [None]:
def scan_sequence_precision(model, input_ids, args):
    """
    Return the min-max normalised, head-averaged ``P_tot`` for every position of ``input_ids``.

    Runs a single causal forward pass in eval mode without gradients. ``args`` is
    accepted for call-site compatibility but is not read.

    Returns:
        NumPy array with one value per position, scaled to approximately [0, 1].
    """
    model.eval()
    with torch.no_grad():
        _, meta = model(input_ids, t_measure=None, causal=True)

        # Average over the last axis (heads), then drop the leading batch axis.
        per_token = meta['P_tot'].mean(dim=-1).squeeze(0)

        # Min-max scaling; the epsilon guards against a constant sequence.
        lo = per_token.min()
        hi = per_token.max()
        scaled = (per_token - lo) / (hi - lo + 1e-8)

    return scaled.cpu().numpy()

# --- EXECUTION ---
# Take a 100-token slice to make the visualization readable
# Ground-truth scan over the first 100 tokens of sample 0 (kept short so the
# overlay stays readable), rendered with the red-intensity heatmap.
first_tokens = extrapolation_dataset[0]['input_ids'][:100]
sample_ids = torch.tensor(first_tokens)[None].to(args.device)
p_map = scan_sequence_precision(model, sample_ids, args)

print("\n--- GROUND TRUTH PRECISION SCAN ---")
plot_text_precision_heatmap(sample_ids[0].cpu().tolist(), p_map, tokenizer)

In [None]:
def warm_sliding_window_scan(model, full_ids, window_size=512, scan_len=100):
    """
    Score the last ``scan_len`` tokens of ``full_ids``, each evaluated with a full ("warm") context.

    For every target position ``i``, the model runs on
    ``full_ids[:, i - window_size : i + 1]``: the ``window_size`` preceding tokens plus
    token ``i`` itself. The head-averaged ``P_tot`` at the last position is recorded.
    Each window is moved to ``args.device``, a notebook-level global.

    Args:
        model: callable returning ``(logits, output_dict)``.
        full_ids: [B, N] tensor of token IDs. Only batch row 0 is reported.
        window_size: number of preceding context tokens per evaluation.
        scan_len: number of trailing tokens to score.

    Returns:
        (target_tokens, precisions):
            target_tokens: list of ``scan_len`` token IDs from batch row 0.
            precisions: 1-D NumPy array of their min-max normalised precisions.

    Raises:
        ValueError: if ``window_size`` or ``scan_len`` is negative, or if ``full_ids``
            has fewer than ``window_size + scan_len`` tokens. A shorter sequence would
            make ``i - window_size`` negative, and Python slicing would then silently
            wrap to the end of the sequence instead of giving a warm context.
    """
    if window_size < 0 or scan_len < 0:
        raise ValueError(f"window_size and scan_len must be non-negative, got {window_size} and {scan_len}")

    seq_len = full_ids.size(1)
    if seq_len < window_size + scan_len:
        raise ValueError(
            f"full_ids has {seq_len} tokens; a warm scan needs at least "
            f"window_size + scan_len = {window_size + scan_len}"
        )

    model.eval()
    warm_precisions = []
    target_tokens = []

    # Nothing to score; torch.stack([]) would raise below.
    if scan_len == 0:
        return target_tokens, np.empty(0, dtype=np.float32)

    # Scan the last `scan_len` tokens of the sequence.
    start_idx = seq_len - scan_len

    with torch.no_grad():
        for i in tqdm(range(start_idx, seq_len), desc="Sliding Window Scan"):
            # window_size context tokens followed by the target token at index i.
            # The length check above guarantees i - window_size >= 0.
            context_window = full_ids[:, i - window_size : i + 1]

            _, output_dict = model(context_window.to(args.device), causal=True)

            # Precision at the last position (token i), averaged over the last axis -> [B].
            p_step = torch.mean(output_dict['P_tot'][:, -1, :], dim=-1)
            warm_precisions.append(p_step)
            target_tokens.append(full_ids[0, i].item())

    # [scan_len, B] -> batch row 0, matching target_tokens; always 1-D.
    p_tensor = torch.stack(warm_precisions)[:, 0]
    p_min, p_max = p_tensor.min(), p_tensor.max()
    p_norm = (p_tensor - p_min) / (p_max - p_min + 1e-8)

    return target_tokens, p_norm.cpu().numpy()

# --- EXECUTION ---
# Use the large extrapolation block (size 4096)
# We take a slice large enough to provide 512 context for a 100-token scan
# Warm scan of the last 100 tokens of the large extrapolation sample at index 2,
# giving each scored token 512 tokens of preceding context. The ids stay on the
# CPU here; the scan moves each window to args.device.
block_token_list = extrapolation_dataset[2]['input_ids']
full_block_ids = torch.tensor(block_token_list)[None]
tokens, p_map = warm_sliding_window_scan(model, full_block_ids, window_size=512, scan_len=100)

print("\n--- WARM SLIDING WINDOW SCAN (512 Context) ---")
plot_text_precision_heatmap(tokens, p_map, tokenizer)