## Robust Filter Attention

### Training on Wikitext-103

### Setup

In [None]:
import numpy as np
import math
import scipy

from matplotlib import pyplot as plt
import matplotlib.cm as cm
from matplotlib import patches
plt.rcParams['figure.figsize'] = [10, 10]
plt.rc('font', size=20)

import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.utils.data import Subset

import transformers
from datasets import load_dataset, DatasetDict
from transformers import AutoTokenizer
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2TokenizerFast
from transformers import DataCollatorForLanguageModeling

import os
import argparse
import datetime
import time
from tqdm import tqdm # Loading bar
import glob
print('Done.')

In [None]:
from utils import complex_matmul, apply_interleaved_rope
from utils import count_parameters, get_layers, seed_everything
print('Done.')

In [None]:
from isotropic_rfa import get_safe_exp_tot, compute_covariance_matrix, compute_covariance_matrix_LHopital
from isotropic_rfa import compute_covariance_matrix_spectral_full, compute_covariance_matrix_residual_diffusion
from isotropic_rfa import compute_exp_kernel_isotropic, compute_residual_norm_isotropic
print('Done.')

In [None]:
from model import resolve_multihead_dims, autoregressive_sample
from model import init_complexlinear, init_complex_matrix, initialize_linear_layers
from model import init_rope, init_decay_per_head, init_linear_bias_slopes
from model import apply_weight_masks
from model import ComplexLinearLayer, ComplexLinearHermitianLayer, ComplextoRealLinearLayer
from model import ComplexRMSNorm
from model import MultiHeadAttentionLayer, MultiheadIsotropicRFA
from model import TransformerBlock, TransformerNetwork
from model import SelfAttentionBlock, RFA_Block
from model import RFATransformerBlock, RFATransformerNetwork
from model import LanguageModel
print('Done.')

In [None]:
from visualization import plot_trajectory, compute_state_matrix, plot_state_matrix, visualize_results
from visualization import visualize_results_attn, _get_visual_modules, visualize_rfa_lm
from visualization import plot_training_progress_lm
print('Done.')

In [None]:
from training import single_epoch_rfa_lm, single_epoch_standard_lm
from training import hook_fn
print('Done.')

In [None]:
# Build the shared run-configuration namespace. Later cells attach model,
# ablation and training settings to this same `args` object.
parser = argparse.ArgumentParser('DA')
parser.add_argument('--gpu', type=int, default=0) # Index of the CUDA device to use (Default: 0)
# Parse an explicit empty argument list so the notebook kernel's own argv is
# ignored and only the defaults declared above apply.
args = parser.parse_args(args=[])
# Use the selected GPU when CUDA is available; otherwise fall back to the CPU.
args.device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu')
print(args.device)
    
seed_everything(seed=2025) # Set random seed (helper imported from utils)

### Load Wikitext-103 dataset

In [None]:
# Local WikiText-103 parquet shards: one glob pattern per split, relative to
# the working directory.
data_files = {
    "train": "datasets/wikitext-103/train-*.parquet",
    "validation": "datasets/wikitext-103/validation-*.parquet",
    "test": "datasets/wikitext-103/test-*.parquet"
}

# Load dataset with the generic "parquet" builder; the result is indexed by
# split name ("train" / "validation" / "test") in the cells below.
raw_datasets = load_dataset("parquet", data_files=data_files)

In [None]:
# # Load a tokenizer
# tokenizer = AutoTokenizer.from_pretrained("gpt2") 
# tokenizer.pad_token = tokenizer.eos_token

# Load the GPT-2 tokenizer from a local directory instead of the Hub.
tokenizer = GPT2TokenizerFast.from_pretrained(
    "./gpt2_tokenizer/",
    local_files_only=True  # This GUARANTEES it won't try to use the internet
)
# Reuse the end-of-sequence token as the padding token so that batching code
# which requires a pad token can run.
tokenizer.pad_token = tokenizer.eos_token

# Tokenize and Chunk
def tokenize_function(examples, tokenizer=None):
    """Tokenize the "text" column of a batch.

    Args:
        examples: Mapping that contains a "text" entry (a batch of strings
            when used with a batched ``map``).
        tokenizer: Callable applied to the text batch. It must be supplied by
            the caller (here via ``fn_kwargs``); the ``None`` default is not
            callable.

    Returns:
        Whatever ``tokenizer`` returns for the text batch.
    """
    texts = examples["text"]
    encoded = tokenizer(texts)
    return encoded

# Tokenize every split in batches using 4 worker processes. The tokenizer is
# handed to the mapped function through fn_kwargs, and the raw "text" column is
# removed so only the tokenizer outputs remain.
tokenized_datasets = raw_datasets.map(
    tokenize_function, 
    batched=True, 
    num_proc=4, 
    fn_kwargs={"tokenizer": tokenizer},
    remove_columns=["text"]
)

In [None]:
# Group into blocks

def group_texts(examples, block_size):
    """Concatenate a batch of tokenized examples and cut it into fixed-size blocks.

    Args:
        examples: Mapping from column name (e.g. "input_ids",
            "attention_mask") to a list of per-example token lists. Must
            contain an "input_ids" column.
        block_size: Number of tokens per output block; must be positive.

    Returns:
        A dict with the same columns as ``examples``, each holding a list of
        blocks of exactly ``block_size`` tokens, plus a "labels" column that
        is an independent copy of "input_ids" (for causal LM the labels equal
        the inputs before the model-side shift). Trailing tokens that do not
        fill a whole block are dropped, so a batch with fewer than
        ``block_size`` tokens yields no blocks at all.

    Raises:
        ValueError: If ``block_size`` is not positive.
        KeyError: If ``examples`` has no "input_ids" column.
    """
    # Imported locally so the function stays self-contained when it is shipped
    # to the worker processes of a multi-process ``map``.
    from itertools import chain

    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size}")

    # Flatten all the token lists of each column into one long list.
    concatenated_examples = {
        name: list(chain.from_iterable(column)) for name, column in examples.items()
    }

    # All columns are expected to have the same flattened length; use the first.
    total_length = len(next(iter(concatenated_examples.values()), []))

    # Always round down to a multiple of block_size. Previously a batch shorter
    # than block_size skipped the rounding and leaked one ragged block.
    total_length = (total_length // block_size) * block_size

    # Chop into fixed-size chunks.
    result = {
        name: [tokens[start : start + block_size] for start in range(0, total_length, block_size)]
        for name, tokens in concatenated_examples.items()
    }

    # Create the labels (same as inputs for Causal LM). Copy each block so that
    # labels and inputs never alias the same list objects.
    result["labels"] = [list(block) for block in result["input_ids"]]

    return result

# Number of tokens per training block.
block_size = 512
# Regroup the tokenized splits into fixed-size blocks using 4 worker processes.
# The block size is taken from the variable above so the two cannot drift apart
# (it was previously repeated as a literal 512).
lm_datasets = tokenized_datasets.map(
    group_texts,
    batched=True,
    num_proc=4,
    fn_kwargs={"block_size": block_size}
)

In [None]:
# Get a single example (the first block of the training split); `sample` is
# reused by the decoding cell below.
sample = lm_datasets["train"][0]

# Basic sanity report: number of blocks, tokens per block, available columns
print(f"Dataset size: {len(lm_datasets['train'])} blocks")
print(f"Sequence Length: {len(sample['input_ids'])}")
print(f"Features: {lm_datasets['train'].column_names}")

# Verify label alignment (should be identical before the model-side shift)
is_aligned = sample['input_ids'] == sample['labels']
print(f"Labels perfectly aligned with inputs: {is_aligned}")

In [None]:
# Decode the first block back to text to eyeball the tokenization round-trip
decoded_text = tokenizer.decode(sample['input_ids'])

print("--- SAMPLE DATA START ---")
print(decoded_text[:500]) # Print first 500 characters
print("--- SAMPLE DATA END ---")

In [None]:
# # Sample 1000 tokens from the first few blocks
# all_tokens = []
# for i in range(5):
#     all_tokens.extend(lm_datasets["train"][i]["input_ids"])

# plt.figure(figsize=(10, 4))
# plt.hist(all_tokens, bins=100, color='skyblue', edgecolor='black')
# plt.title("Token ID Distribution (WikiText-103)")
# plt.xlabel("Token ID")
# plt.ylabel("Frequency")
# plt.grid(axis='y', alpha=0.3)
# plt.show()

In [None]:
# Display the grouped training split as the cell output.
lm_datasets["train"]
# Presumably the output recorded from an earlier run:
# Dataset({
#     features: ['input_ids', 'attention_mask', 'labels'],
#     num_rows: 229206
# })

### Define Model

In [None]:
####################
## MODEL SETTINGS ##
####################

# All hyperparameters are attached to the shared `args` namespace.
args.batch_size = 16 # Batch size
args.vocab_size = 50257 # NOTE(review): presumably the GPT-2 tokenizer vocabulary size — confirm against len(tokenizer)
args.d_e = 256 # Embedding dimension
args.seq_len = 512 # NOTE(review): same value as the block_size used when chunking the dataset — keep in sync
args.n_heads = 8 # Number of attention heads
args.num_blocks = 6 # Number of blocks in the backbone
args.d_k_total = args.d_e # Total query-key dim across all heads
args.d_v_total = args.d_e # Total value dim across all heads
# args.num_blocks = 3

# args.max_learned_decay = 1.4 # e/2
args.max_learned_decay = 5.0
args.max_fixed_decay = 5.0 # Can be more aggressive

# Limits for clamping exponent
args.max_exponent = 0
args.min_exponent = -10

args.epsilon = 1E-5 # Stability param

args.compute_metadata = False # Triggers computing various diagnostics; turned off during training
args.compute_pulled_forward_estimates = False # "Project" every past state into every future frame; very expensive.

In [None]:
##############################
## DEFAULT ABLATION OPTIONS ##
##############################
# Baseline values for every ablation switch. The ablation cell that follows
# overrides a subset of these for the selected experiment.
args.causal = True
args.t_equal = True # Equal time intervals?
args.sep_params = False # Use separate params for keys and values?
args.lambda_real_zero = False # Zero out real part of eigenvalues?
args.use_full_residual_norm = 1 # Use the full |R|^2 metric?
args.use_robust_weight = True # Use rational weight rather than softmax
args.additive_bias_type = 1 # (Additive bias: 0 for zero; 1 for DLE; 2 for linear)
args.multiplicative_bias_type = 1 # (Multiplicative bias: 0 for constant; 1 for DLE; 2 for linear)
args.t_shift = None # Default
args.learn_t_shift = True
# A learned shift must not start from a fixed value. With the default above this
# is a no-op; it only matters if a fixed t_shift is configured there.
if args.learn_t_shift:
    args.t_shift = None
args.learn_rotations = False # Learned rotations (True), or fixed as in RoPE (False)?
args.learn_decay = False # Learned decay (True), or fixed (False)?
args.rotate_values = True # Rotate/unrotate values?
args.zero_process_noise = False # Zero process noise (sigma^2)?
args.zero_key_measurement_noise = False # Zero key measurement noise (eta^2)?
args.use_total_precision_gate = 1 # Use total-precision gating? (0 = No gate, 1 = precision gate, 2 = learned gate)
args.use_inner_residual = False # Include a residual connection BEFORE output projection?
args.use_outer_residual = True # Include a residual connection AFTER output projection?
args.use_complex_input_norm = 0 # Use complex-valued RMS Norm AFTER input projection for query/key/value (1), complex-valued RMS Norm AFTER input projection only for query/key (2), or None (0)?
args.use_complex_output_norm = False # Use complex-valued RMS Norm BEFORE output projection?
args.use_real_input_norm = True # Use real-valued RMS Norm BEFORE input projection?
args.use_real_output_norm = True # Use real-valued RMS Norm AFTER output projection?
args.add_gaussian_noise = False # Add Gaussian noise to final token? (for test-time sampling)
args.use_complex_conj_constraint = True # Eigenvalues must appear in complex conjugate pairs to ensure A is real
args.use_colored_prior = False
# args.allow_BM_branch = True # Allow separate branch for Brownian motion? (only used when learning decay)
args.scale_decay_by_time_interval = True
args.zero_rotations = False
args.use_ss_process_noise = False
args.damping = 0.05

# Positional-encoding switches
args.use_rope = True
args.use_alibi = False
# args.use_xpos = False

args.use_SC_RoPE = False
args.use_log_linear_decay = True

In [None]:
## ABLATIONS ##

# ---------------------------------------------------

# # BASELINES:

# ---------------------------------------------------

# # B1: Standard Transformer + RoPE Baseline
# #     --- Dot-Product Similarity; current SOTA positional encoding baseline.
# #     Applies d --> 2d --> d attention projections for fair comparison.
# # Uses separate backbone and training looping

# ablation_name = 'B1'
# args.use_rope = True
# args.use_alibi = False
# args.use_relative_decay_vanilla = False

# ---------------------------------------------------

# # B2: Standard Transformer + ALiBi Baseline
# #     Applies d --> 2d --> d attention projections for fair comparison.
# # Uses separate backbone and training looping

# ablation_name = 'B2'
# args.use_rope = False
# args.use_alibi = True
# args.use_relative_decay_vanilla = False

# ---------------------------------------------------

# # B3 (RoPE + decay):
# ablation_name = 'B3'
# args.use_rope = True
# args.use_alibi = False
# args.use_SC_RoPE = False
# args.scale_decay_by_time_interval = False
# args.damping = 0.05
# args.use_relative_decay_vanilla = True

# ---------------------------------------------------

# # B4 (SC-RoPE):
# ablation_name = 'B4'
# args.use_rope = False
# args.use_alibi = False
# args.use_SC_RoPE = True
# args.scale_decay_by_time_interval = False
# args.damping = 0.05
# args.use_relative_decay_vanilla = True

# ------------------------------------------------------------------------------------------------------
# ------------------------------------------------------------------------------------------------------

# # MAIN MODELS:

# ---------------------------------------------------

# # M1: Isotropic RFA

# # Default settings
# ablation_name = 'M1'
# args.use_log_linear_decay = False
# args.damping = 0.05

# ---------------------------------------------------

# # # M1.1: Isotropic RFA w/o DLE Prior

# # Default settings
# ablation_name = 'M1.1'
# args.use_log_linear_decay = False
# args.damping = 0.05
# args.zero_process_noise = True
# args.zero_key_measurement_noise = True
# args.additive_bias_type = 0
# args.multiplicative_bias_type = 0
# args.use_log_linear_decay = False

# ---------------------------------------------------

# # M1.2: Isotropic RFA w/o robust weight

# # Default settings
# ablation_name = 'M1.2'
# args.use_log_linear_decay = False
# args.damping = 0.05
# args.use_robust_weight = False

# ---------------------------------------------------

# # M2: Spectrally-Coupled RFA
# # # --- M1, but we partition the angular frequencies so that the fast frequencies
# # # are coupled with fast decays and vice versa

# ablation_name = 'M2'
# args.use_SC_RoPE = True
# args.scale_decay_by_time_interval = False
# args.damping = 0.005
# args.use_log_linear_decay = False

# ---------------------------------------------------

# # M2.1 (M2 + no robust weight)
# ablation_name = 'M2.1'
# args.use_SC_RoPE = True
# args.scale_decay_by_time_interval = False
# args.damping = 0.05
# args.use_robust_weight = False
# args.use_log_linear_decay = False

# ---------------------------------------------------

# # M2.2: Spectrally-Coupled RFA, w/o DLE Prior
# # # --- M2, but we ablate the DLE

# ablation_name = 'M2.2'
# args.use_SC_RoPE = True
# args.scale_decay_by_time_interval = False
# args.damping = 0.05
# args.zero_process_noise = True
# args.zero_key_measurement_noise = True
# args.additive_bias_type = 0
# args.multiplicative_bias_type = 0
# args.use_log_linear_decay = False

# ---------------------------------------------------

# M2.3: (M2 + ablate out only multiplicative gate but keep additive term)

# ablation_name = 'M2.3'
# args.use_SC_RoPE = True
# args.scale_decay_by_time_interval = False
# args.damping = 0.05
# args.additive_bias_type = 1 # Keep additive term
# args.multiplicative_bias_type = 0 # Set multiplicative term to constant
# args.use_log_linear_decay = False

# ---------------------------------------------------

# # M2.4: (M2 + No Value rotations)

# ablation_name = 'M2.4'
# args.use_SC_RoPE = True
# args.scale_decay_by_time_interval = False
# args.damping = 0.05
# args.use_log_linear_decay = False
# args.rotate_values = False

# ---------------------------------------------------

# # M2.5: (M2 + No Rotations)
# ablation_name = 'M2.5'
# args.use_SC_RoPE = True
# args.scale_decay_by_time_interval = False
# args.damping = 0.05
# args.use_log_linear_decay = False
# args.zero_rotations = True

# ---------------------------------------------------

# # M2.6: (Unitary, zero noise Limit)
# ablation_name = 'M2.6'
# args.use_SC_RoPE = True
# args.scale_decay_by_time_interval = False
# args.damping = 0.0
# args.use_log_linear_decay = False

# args.use_robust_weight = False
# args.use_full_residual_norm = 0
# args.lambda_real_zero = True
# args.zero_process_noise = True
# args.zero_key_measurement_noise = True
# args.additive_bias_type = 0
# args.multiplicative_bias_type = 0
# args.max_fixed_decay = 0.0 # Zero out decay

# ---------------------------------------------------

# M2.7: Spectrally-Coupled RFA + Total Confidence Gate

# Active configuration. `ablation_name` is also embedded in the checkpoint
# directory name built in the training-setup cell.
ablation_name = 'M2.7'
# NOTE(review): args.use_rope keeps its default (True) here, whereas baseline B4
# sets it to False when enabling SC-RoPE — confirm which is intended for RFA.
args.use_SC_RoPE = True # Overrides the default (False)
args.scale_decay_by_time_interval = False # Overrides the default (True)
args.damping = 0.05 # Same value as the default
args.use_log_linear_decay = False # Overrides the default (True)
args.use_inner_residual = True # Overrides the default (False)
args.use_total_precision_gate = 1 # Same value as the default (1 = precision gate)

In [None]:
# Echo the selected ablation and the full configuration into the cell output
print(ablation_name)
print(args)

In [None]:
#####################
## DEFINE BACKBONE ##
#####################

###########################################################

# # Standard Attention

# backbone = SelfAttentionBlock(input_dim=args.d_e, qkv_dim=args.d_e, num_heads=args.n_heads, args=args)
    
###########################################################

# # Multihead Isotropic RFA

# backbone = RFA_Block(args, args.n_heads, input_dim=args.d_e, query_key_dim_total=args.d_k_total, value_dim_total=args.d_v_total)

###########################################################

# # # Standard Transformer

# backbone = TransformerNetwork(args.d_e, args.d_e*2, args.d_e*4, args.n_heads, args, num_blocks=args.num_blocks, Norm=nn.LayerNorm)

###########################################################

# RFA Transformer

# Build the RFA transformer backbone from the shared settings. The feed-forward
# hidden width is four times the total value dimension.
backbone = RFATransformerNetwork(
    args=args,
    num_blocks=args.num_blocks,
    n_heads=args.n_heads,
    input_dim=args.d_e,
    query_key_dim_total=args.d_k_total,
    value_dim_total=args.d_v_total,
    hidden_dim=4 * args.d_v_total,
    Norm=nn.LayerNorm,
)

###########################################################

In [None]:
# Wrap it to create the Language Model and move it to the selected device
model = LanguageModel(
    backbone=backbone, 
    vocab_size=args.vocab_size, 
    embed_dim=args.d_e
).to(args.device)

print(model)

# NOTE(review): in the visible notebook params_list is only referenced by the
# commented-out single-group Adam setup; the active optimizer builds its own groups.
params_list = list(model.parameters()) # Parameters list

print('Total parameter count:', count_parameters(model))

### Training

In [None]:
# Initialize the Data Collator
# mlm=False is crucial: it tells the collator we are doing Causal LM, 
# not Masked LM (like BERT). It will ensure labels are handled correctly.
# NOTE(review): with mlm=False the collator is expected to rebuild "labels" from
# "input_ids" and set padding positions to -100. Because pad_token was set to
# eos_token, any EOS tokens present in the data would be masked as well —
# confirm that this is acceptable.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# Create DataLoaders
# Only the training loader shuffles; validation and test keep dataset order.
train_dataloader = DataLoader(
    lm_datasets["train"],
    shuffle=True,
    batch_size=args.batch_size,
    collate_fn=data_collator
)

eval_dataloader = DataLoader(
    lm_datasets["validation"],
    batch_size=args.batch_size,
    collate_fn=data_collator
)

test_dataloader = DataLoader(
    lm_datasets["test"],
    batch_size=args.batch_size,
    collate_fn=data_collator
)

# Quick Test of the first batch
batch = next(iter(train_dataloader))
print(f"Batch keys: {batch.keys()}")
print(f"Input IDs shape: {batch['input_ids'].shape}")
print(f"Labels shape: {batch['labels'].shape}")

In [None]:
# ## DEBUG BY TESTING ON SUBSET OF DATASET ##
# ###########################################

# # Define how many samples (dataset blocks) to use for debugging
# # If batch_size is 16, 1000 samples = ~63 batches
# num_debug_samples = 1000 

# train_dataset = train_dataloader.dataset
# val_dataset = eval_dataloader.dataset 

# # Create subsets
# train_subset = Subset(train_dataset, range(min(num_debug_samples, len(train_dataset))))
# val_subset = Subset(val_dataset, range(min(num_debug_samples // 5, len(val_dataset))))

# # Create temporary debug loaders
# debug_train_loader = torch.utils.data.DataLoader(
#     train_subset, 
#     batch_size=args.batch_size, 
#     shuffle=True,
#     collate_fn=train_dataloader.collate_fn
# )

# debug_val_loader = torch.utils.data.DataLoader(
#     val_subset, 
#     batch_size=args.batch_size, 
#     shuffle=False,
#     collate_fn=eval_dataloader.collate_fn
# )

# train_dataloader = debug_train_loader
# eval_dataloader = debug_val_loader

# print(f"Debug Mode: Training on {len(debug_train_loader)} batches.")

In [None]:
####################
## TRAINING SETUP ##
####################

criterion = nn.CrossEntropyLoss() # Loss

args.num_epochs = 15 # Number of epochs
print('Num epochs:', args.num_epochs)

args.save_model = True

args.save_epochs = 1 # Intervals of epochs to save model
args.show_example_epochs = 1 # Number of epochs between displaying results

#####################

# Create folders for model weights, and loss history

# __file__ is not defined in a notebook / interactive session, so fall back to
# the current working directory there.
try:
    root_path = os.path.dirname(os.path.abspath(__file__))
except NameError:
    root_path = os.getcwd()

saved_models_path = os.path.join(root_path, 'saved_models', 'wikitext_103')
model_name = str(model.backbone.__class__.__name__)
# Keep only the date part (YYYY-MM-DD) of the current timestamp for the run folder name.
date = str(datetime.datetime.today()).split()
date_time = date[0]
model_path = os.path.join(saved_models_path, f"{model_name}__{ablation_name}__{date_time}")

# exist_ok=True already tolerates an existing directory. Any other failure
# (permissions, read-only or full disk) is raised here on purpose: swallowing it
# would only postpone the error to the first checkpoint save, after a full
# epoch of training.
os.makedirs(model_path, exist_ok=True)

In [None]:
# Display the directory that this run's checkpoints are written to
model_path

In [None]:
####################

# Optimizer

# args.lr = 1E-2 # Learning rate
# optimizer = torch.optim.Adam(params_list, lr=args.lr, betas=(0.9, 0.999)) # Optimizer

def _is_sde_param(param_name):
    """Return True when the parameter name contains one of the SDE markers.

    NOTE(review): this is a plain substring match, so a name containing e.g.
    'theta_' or 'beta_' also matches 'eta_' — verify against the model's
    actual parameter names.
    """
    return any(marker in param_name for marker in ['mu_', 'sigma_', 'eta_', 'gamma_'])

# Separate the "Physics" from the "Features"
sde_params = [p for n, p in model.named_parameters() if _is_sde_param(n)]
feature_params = [p for n, p in model.named_parameters() if not _is_sde_param(n)]

feature_lr = 1e-3
sde_lr = feature_lr/2

# optimizer = torch.optim.Adam([
#     {'params': feature_params, 'lr': feature_lr},
#     {'params': sde_params, 'lr': sde_lr}  # slower to prevent spikes
# ])

# Standard Weights (with momentum)
feature_group = {'params': feature_params, 'lr': feature_lr, 'betas': (0.9, 0.999)}

# Decay Params (Lower learning rate, NO momentum, higher epsilon)
sde_group = {
    'params': sde_params,
    'lr': sde_lr,          # Lower learning rate
    'betas': (0.0, 0.999), # First beta=0 kills momentum
    'eps': 1e-7            # Higher eps prevents division-by-zero spikes
}

# Group order matters: per-group settings elsewhere (e.g. the scheduler's
# max_lr list) follow the same [feature, sde] order.
optimizer = torch.optim.Adam([feature_group, sde_group])

########################

In [None]:
# Learning rate scheduler

# scheduler = None

################################

# One-cycle schedule: a short warm-up followed by cosine annealing

# NOTE(review): total_steps is not passed to OneCycleLR below; the schedule
# length is determined by epochs * steps_per_epoch instead.
total_steps = (args.num_epochs * len(train_dataloader)) + 10 

scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, 
#     max_lr=args.lr,        # Target learning rate
    max_lr = [feature_lr, sde_lr], # Peak LR per param group, in the optimizer's group order
    epochs=args.num_epochs, 
    steps_per_epoch=len(train_dataloader)+2, # Add +2 to ensure the scheduler never "runs out" of steps
    pct_start=0.05,          # Spend 5% of training time warming up
    anneal_strategy='cos',   # Use cosine decay
    div_factor=25.0,         # Start LR is max_lr / 25
    final_div_factor=1000.0  # Final LR is the start LR / 1000, i.e. max_lr / 25000
)

In [None]:
# If starting a new training run :

start_epoch = 0

# Initialize history tracking
# The standard-transformer baseline only tracks losses and perplexity. Every
# other backbone additionally gets lists for the RFA quantities and for
# per-epoch wall-clock times ('epoch_times' is appended to by the RFA loop).
if model_name == 'TransformerNetwork':
    history = {
    'loss': [],
    'val_loss': [],
    'val_ppl': []
}
else:
    history = {
        'loss': [], 'mu': [], 'sigma': [], 'sigma_tilde': [],
        'eta': [], 'gamma': [], 'tau': [], 'nu_over_d': [],
        'val_loss': [], 'val_ppl': [], 'epoch_times': []
    }

In [None]:
# ## LOAD MODEL ##
# ################

# # checkpoint_path = os.path.join(saved_models_path, 'TransformerNetwork__M0__2026-01-11')
# checkpoint_path = os.path.join(saved_models_path, 'RFATransformerNetwork__M2.7__2026-01-27')

# # Path to checkpoint
# checkpoint_path = os.path.join(checkpoint_path, "rfa_lm_epoch_2.pt")
# # checkpoint_path = os.path.join(checkpoint_path, "standard_transformer_epoch_15.pt")

# if os.path.exists(checkpoint_path):
#     print(f"Loading checkpoint from {checkpoint_path}...")
    
#     # Pass weights_only=False because the checkpoint is a custom dict that also stores history/args
#     # (only load checkpoint files you trust)
#     checkpoint = torch.load(checkpoint_path, map_location=args.device, weights_only=False)

#     # Restore Model Weights
#     model.load_state_dict(checkpoint['model_state_dict'])

#     # Restore Optimizer and Scheduler
#     optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
#     if scheduler and 'scheduler_state_dict' in checkpoint:
#         scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

#     # Extract History and Test Results
#     history = checkpoint['history']
#     start_epoch = checkpoint['epoch'] + 1
    
#     # Print Stored Results
#     print("-" * 30)
#     print(f"Resuming from Epoch: {start_epoch}")
#     print(f"Last Val PPL: {history['val_ppl'][-1]:.2f}")
#     print('Validation history:', history['val_ppl'])
    
# else:
#     print("No checkpoint found. Starting from scratch.")
#     start_epoch = 0

In [None]:
# ## CHECKS

# # Access the omega (frequency) tensor
# rfa_layer = model.backbone.blocks[0].attn
# omega = rfa_layer.omega_v 

# # Compare the first and last head
# head_0 = omega[0]
# head_last = omega[-1]

# # Compute the similarity
# are_identical = torch.equal(head_0, head_last)

# print(f"Frequencies across heads are identical: {are_identical}")

# if are_identical:
#     print("CRITICAL: You are using Standard RoPE (Wrong for M8).")
# else:
#     print("SUCCESS: Your Spectral Coupling is preserved (Right for M8).")
#     print(f"Head 0 Max Freq: {head_0.max().item():.6f}")
#     print(f"Head Last Max Freq: {head_last.max().item():.6f}")
    
# def check_transformer_damping(model):
#     """
#     Scans a Transformer model to find attention layers and 
#     calculate the effective damping ratio 'b'.
#     """
#     print(f"{'Block':<8} | {'Type':<15} | {'Avg b':<10} | {'Zero Frac':<10}")
#     print("-" * 50)

#     # Iterate through all modules to find the attention layers
#     for name, module in model.named_modules():
#         # Check for Isotropic RFA (M8)
#         if "MultiheadIsotropicRFA" in str(type(module)):
#             mu = module.mu_v.detach().cpu()
#             omega = module.omega_v.detach().cpu()
            
#             # Max omega for each head's shard is at the end: omega[:, -1]
#             max_omegas = omega[:, -1]
#             mask = mu > 0
            
#             avg_b = (mu[mask] / max_omegas[mask]).mean().item() if mask.any() else 0.0
#             zero_frac = (mu == 0).float().mean().item()
            
#             print(f"{name:<8} | RFA (M8)       | {avg_b:.4f}     | {zero_frac:.2f}")

#         # Check for Heuristic Attention (M14)
#         elif "SpectralCoupledHeuristicAttention" in str(type(module)):
#             mu = module.mu.detach().cpu()
#             # Frequencies are stored in the SCRoPE or RoPE sub-module
#             omega = module.rope.theta.detach().cpu()
            
#             # Handle standard RoPE (1D theta) vs SCRoPE (2D theta)
#             if omega.ndim == 1:
#                 max_omega = omega[-1]
#                 mask = mu > 0
#                 avg_b = (mu[mask] / max_omega).mean().item() if mask.any() else 0.0
#             else:
#                 max_omegas = omega[:, -1]
#                 mask = mu > 0
#                 avg_b = (mu[mask] / max_omegas[mask]).mean().item() if mask.any() else 0.0
                
#             zero_frac = (mu == 0).float().mean().item()
#             print(f"{name:<8} | Heuristic (M14)| {avg_b:.4f}     | {zero_frac:.2f}")

# # Usage:
# check_transformer_damping(model)

In [None]:
# plot_training_progress_lm(history, 15, checkpoint_path, save=False)

# visualize_rfa_lm(model, eval_dataloader.dataset, history, 7, checkpoint_path, args,
#                     save=False,
#                     plot_log_losses_flag=True,
#                     plot_last_attn_mat_flag=True,
#                     plot_decay_per_iteration=True,
#                     plot_noise_params = True,
#                     plot_tau_and_nu_flag=True)

In [None]:
#########################################################################
########################### RFA TRAINING LOOP ###########################
#########################################################################

# This loop calls the RFA-specific epoch function and appends to RFA-only
# history keys such as 'epoch_times'. Fail fast on the baseline model instead
# of printing a message and breaking later, part-way through training.
if model_name == 'TransformerNetwork':
    raise RuntimeError('ERROR: Wrong model type.')


def _shifted_lm_loss(logits, labels, loss_fn):
    """Apply `loss_fn` to next-token predictions.

    Position t of `logits` is scored against token t+1 of `labels`, so the last
    logit and the first label are dropped before flattening.

    Returns:
        (loss, shift_labels): the value returned by `loss_fn`, and the shifted
        labels so callers can count the scored tokens.
    """
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = labels[:, 1:].contiguous()
    loss = loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
    return loss, shift_labels


print(f"Starting training on {args.device}...")

for epoch in tqdm(range(start_epoch, args.num_epochs)):
    start_time = time.time()

    # --- TRAINING PHASE ---
    model.train()
    output_dict, history = single_epoch_rfa_lm(
        model,
        train_dataloader,
        history,
        optimizer,
        criterion,
        args,
        scheduler=scheduler
    )

    # Only the training phase is timed; evaluation and checkpointing are excluded.
    end_time = time.time()
    epoch_duration = end_time - start_time
    history['epoch_times'].append(epoch_duration)

    # --- EVALUATION PHASE ---
    model.eval()
    val_losses = []

    with torch.no_grad():
        for batch in eval_dataloader:
            batch = {k: v.to(args.device) for k, v in batch.items()}
            inputs, labels = batch["input_ids"], batch["labels"]

            logits, _ = model(inputs, t_measure=None, causal=True)

            v_loss, _ = _shifted_lm_loss(logits, labels, criterion)
            val_losses.append(v_loss.item())

    # --- LOGGING & METRICS ---
    # Validation loss is the mean of per-batch mean losses (kept as-is so the
    # numbers stay comparable with earlier runs and the baseline loop). Values
    # are stored as plain Python floats so the checkpointed history does not
    # depend on numpy scalar pickling.
    avg_val_loss = float(np.mean(val_losses))
    val_ppl = float(np.exp(avg_val_loss))

    history['val_loss'].append(avg_val_loss)
    history['val_ppl'].append(val_ppl)

    # Calculate average train loss for current epoch
    # (assumes single_epoch_rfa_lm appends one entry to history['loss'] per batch — TODO confirm)
    train_loss_epoch = np.mean(history['loss'][-len(train_dataloader):])

    print(f"Epoch {epoch+1}/{args.num_epochs} | "
          f"Train Loss: {train_loss_epoch:.4f} | "
          f"Val Loss: {avg_val_loss:.4f} | "
          f"Val PPL: {val_ppl:.2f}")

    # --- SLIDING WINDOW CHECKPOINT SAVING ---
    if args.save_model and (epoch + 1) % args.save_epochs == 0:
        checkpoint_name = f"rfa_lm_epoch_{epoch+1}.pt"
        checkpoint_full_path = os.path.join(model_path, checkpoint_name)

        # 'epoch' stores the zero-based index of the epoch just finished; a
        # resuming run continues from epoch + 1.
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
            'history': history,
            'args': args
        }, checkpoint_full_path)
        print(f"Full checkpoint saved: {checkpoint_name}")

        # Keep only the 2 most recent checkpoints to save disk space
        checkpoints = sorted(glob.glob(os.path.join(model_path, "rfa_lm_epoch_*.pt")),
                             key=os.path.getmtime)
        # Oldest first, so everything except the last two entries is stale.
        for stale_checkpoint in checkpoints[:-2]:
            os.remove(stale_checkpoint)
            print(f"Removed old checkpoint: {os.path.basename(stale_checkpoint)}")

    # --- VISUALIZATION ---
    if (epoch + 1) % args.show_example_epochs == 0:
        plot_training_progress_lm(history, epoch, model_path)

        # Uncomment if you want the deep RFA dynamics visualizations
        # visualize_rfa_lm(model, eval_dataloader.dataset, history, epoch, model_path, args,
        #                  save=True,
        #                  plot_log_losses_flag=True,
        #                  plot_last_attn_mat_flag=True,
        #                  plot_decay_per_iteration=True,
        #                  plot_noise_params=True,
        #                  plot_tau_and_nu_flag=True)

# --- FINAL EVALUATION ON TEST SET ---
print("\nTraining complete. Performing final evaluation on Test Set...")
model.eval()

# Token-weighted evaluation: sum the negative log-likelihood over all scored
# tokens and divide by the number of scored tokens.
total_loss = 0.0
total_tokens = 0
criterion_sum = torch.nn.CrossEntropyLoss(reduction='sum')

with torch.no_grad():
    for batch in test_dataloader:
        batch = {k: v.to(args.device) for k, v in batch.items()}
        logits, _ = model(batch["input_ids"], t_measure=None, causal=True)

        loss_sum, shift_labels = _shifted_lm_loss(logits, batch["labels"], criterion_sum)

        # Count only the tokens the loss actually scored: labels equal to -100
        # (the loss's default ignore_index) are skipped.
        total_loss += loss_sum.item()
        total_tokens += (shift_labels != -100).sum().item()

if total_tokens == 0:
    raise RuntimeError("Test set contains no scorable tokens; cannot compute perplexity.")

# Perplexity calculation: exp(Total NLL / Total Tokens)
avg_nll = total_loss / total_tokens
test_ppl = float(np.exp(avg_nll))

# NOTE(review): these results are added to `history` after the last checkpoint
# was written, so they are not persisted unless the history is saved again.
history['test_loss'] = avg_nll
history['test_ppl'] = test_ppl

print(f"Final Test Results | Avg NLL: {avg_nll:.4f} | PPL: {test_ppl:.2f}")

In [None]:
# # Jump

# #########################################
# ##### LOOP FOR STANDARD ATTENTION #######

# if model_name == 'RFATransformerNetwork':
#     print('ERROR: Wrong model type.')

# print(f"Starting Baseline Training on {args.device}...")

# for epoch in tqdm(range(start_epoch, args.num_epochs)):
#     start_time = time.time()
    
#     # --- TRAINING PHASE ---
#     model.train()
#     # Call the simplified standard training function
#     attn_weights, history = single_epoch_standard_lm(
#         model, 
#         train_dataloader, 
#         history, 
#         optimizer, 
#         criterion, 
#         args, 
#         scheduler=scheduler
#     )
    
#     # --- EVALUATION PHASE ---
#     model.eval()
#     val_losses = []
    
#     with torch.no_grad():
#         for batch in eval_dataloader:
#             batch = {k: v.to(args.device) for k, v in batch.items()}
#             inputs, labels = batch["input_ids"], batch["labels"]
            
#             # Standard forward pass
#             logits, _ = model(inputs)
            
#             # Internal shift for validation loss
#             shift_logits = logits[:, :-1, :].contiguous()
#             shift_labels = labels[:, 1:].contiguous()
            
#             v_loss = criterion(
#                 shift_logits.view(-1, shift_logits.size(-1)), 
#                 shift_labels.view(-1)
#             )
#             val_losses.append(v_loss.item())

#     # --- LOGGING & METRICS ---
#     avg_val_loss = np.mean(val_losses)
#     val_ppl = np.exp(avg_val_loss)
    
#     history['val_loss'].append(avg_val_loss)
#     history['val_ppl'].append(val_ppl)

#     # Calculate average train loss for current epoch
#     train_loss_epoch = np.mean(history['loss'][-len(train_dataloader):])

#     # Print summary for the epoch
#     print(f"Epoch {epoch+1}/{args.num_epochs} | "
#           f"Train Loss: {train_loss_epoch:.4f} | "
#           f"Val Loss: {avg_val_loss:.4f} | "
#           f"Val PPL: {val_ppl:.2f}")

#     # --- SLIDING WINDOW CHECKPOINT SAVING ---
#     if args.save_model and (epoch + 1) % args.save_epochs == 0:
#         checkpoint_name = f"standard_transformer_epoch_{epoch+1}.pt"
#         checkpoint_full_path = os.path.join(model_path, checkpoint_name)
        
#         torch.save({
#             'epoch': epoch,
#             'model_state_dict': model.state_dict(),
#             'optimizer_state_dict': optimizer.state_dict(),
#             'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
#             'history': history,
#             'args': args
#         }, checkpoint_full_path)
#         print(f"Full checkpoint saved: {checkpoint_name}")

#         # Keep only the 2 most recent baseline checkpoints to save disk space
#         checkpoints = sorted(glob.glob(os.path.join(model_path, "standard_transformer_epoch_*.pt")), 
#                              key=os.path.getmtime)
#         if len(checkpoints) > 2:
#             for i in range(len(checkpoints) - 2):
#                 os.remove(checkpoints[i])
#                 print(f"Removed old baseline checkpoint: {os.path.basename(checkpoints[i])}")

#     # Plotting (Using the same progress plotter, but with baseline history)
#     if np.mod(epoch + 1, args.show_example_epochs) == 0:
#         plot_training_progress_lm(history, epoch, model_path)
        
# # --- AFTER THE FULL LOOP FINISHES ---
# print("\nBaseline training complete. Performing final evaluation on Test Set...")
# model.eval()

# # Use 'sum' reduction to get total negative log-likelihood
# total_loss = 0.0
# total_tokens = 0
# criterion_sum = torch.nn.CrossEntropyLoss(reduction='sum')

# with torch.no_grad():
#     for batch in test_dataloader:
#         batch = {k: v.to(args.device) for k, v in batch.items()}
#         # Standard forward pass
#         logits, _ = model(batch["input_ids"])
        
#         # Shift logic
#         shift_logits = logits[:, :-1, :].contiguous()
#         shift_labels = batch["labels"][:, 1:].contiguous()
        
#         # Identify non-padding tokens
#         mask = (shift_labels != -100) 
#         num_tokens = mask.sum().item()
        
#         # Calculate sum of losses for this batch
#         loss_sum = criterion_sum(
#             shift_logits.view(-1, shift_logits.size(-1)), 
#             shift_labels.view(-1)
#         )
        
#         total_loss += loss_sum.item()
#         total_tokens += num_tokens

# # PPL: exp( (sum of all losses) / (total tokens) )
# avg_nll = total_loss / total_tokens
# test_ppl = np.exp(avg_nll)

# history['test_loss'] = avg_nll 
# history['test_ppl'] = test_ppl

# print(f"Final Baseline Test Results | Avg NLL: {avg_nll:.4f} | PPL: {test_ppl:.2f}")

In [None]:
# --- AFTER THE FULL LOOP FINISHES ---
# Final test-set evaluation: corpus-level perplexity computed as
# exp(total negative log-likelihood / number of scored tokens).
print("Training complete. Performing final evaluation on Test Set...")
model.eval()

# Label value that is excluded from the loss. The SAME value must drive both
# the criterion and the token count below: if they disagree, ignored positions
# are either summed into the loss without being counted or counted without
# contributing, and the reported perplexity is silently wrong.
# -100 is the default `ignore_index` of CrossEntropyLoss. If the labels mark
# padding with a different id, change it here (and only here).
# NOTE(review): assumes the collator writes -100 at padded label positions - confirm.
ignore_index = -100

# Use 'sum' reduction so per-batch losses add up to a total NLL
total_loss = 0.0
total_tokens = 0
criterion_sum = torch.nn.CrossEntropyLoss(reduction='sum', ignore_index=ignore_index)

with torch.no_grad():
    for batch in test_dataloader:
        batch = {k: v.to(args.device) for k, v in batch.items()}
        logits, _ = model(batch["input_ids"], t_measure=None, causal=True)

        # Next-token shift along dim 1: the logits at position t are scored
        # against the label at position t+1.
        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = batch["labels"][:, 1:].contiguous()

        # Count only the positions that actually contribute to the loss
        mask = (shift_labels != ignore_index)
        num_tokens = mask.sum().item()

        # Summed (not averaged) loss over all scored positions in this batch
        loss_sum = criterion_sum(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1)
        )

        total_loss += loss_sum.item()
        total_tokens += num_tokens

# An empty test loader, or labels that are all `ignore_index`, would otherwise
# surface as a bare ZeroDivisionError on the next line.
if total_tokens == 0:
    raise ValueError(
        "Test evaluation scored 0 tokens: test_dataloader is empty or every "
        f"label equals ignore_index ({ignore_index}); perplexity is undefined."
    )

# PPL: exp( (sum of all losses) / (total tokens) )
avg_nll = total_loss / total_tokens
test_ppl = np.exp(avg_nll)

history['test_loss'] = avg_nll  # Average loss per scored token
history['test_ppl'] = test_ppl

print(f"Final Test Results | Avg NLL: {avg_nll:.4f} | PPL: {test_ppl:.2f}")