## Robust Filter Attention

### Training on synthetic data from simulated dynamical systems

### Setup

In [None]:
import numpy as np
import math
import scipy

from matplotlib import pyplot as plt
import matplotlib.cm as cm
from matplotlib import patches
plt.rcParams['figure.figsize'] = [10, 10]
plt.rc('font', size=20)

import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

import transformers
from datasets import load_dataset, DatasetDict
from transformers import AutoTokenizer
from torch.utils.data import Dataset, DataLoader

import os
import argparse
import datetime
import time
from tqdm import tqdm # Loading bar
print('Done.')

In [None]:
from utils import complex_matmul, apply_interleaved_rope
from utils import count_parameters, get_layers, seed_everything
print('Done.')

In [None]:
from dynamics import simulate_stochastic_LTI, simulate_diagonalized_stochastic_LTI, DynamicSim
from dynamics import construct_mapping
from dynamics import get_nth_measurement, get_random_measurements
from dynamics import linear_spiral, linear_spiral_3D, Lorenz, rand_coupling_matrix, Van_der_Pol_osc
from dynamics import construct_random_mapping, construct_data_dynamics, TrainDataset_dynamics, create_train_loader_dynamics
print('Done.')

In [None]:
from isotropic_rfa import get_safe_exp_tot, compute_covariance_matrix, compute_covariance_matrix_LHopital
from isotropic_rfa import compute_covariance_matrix_spectral_full, compute_covariance_matrix_residual_diffusion
from isotropic_rfa import compute_exp_kernel_isotropic, compute_residual_norm_isotropic
print('Done.')

In [None]:
from model import resolve_multihead_dims, autoregressive_sample
from model import init_complexlinear, init_complex_matrix, initialize_linear_layers
from model import init_rope, init_decay_per_head, init_linear_bias_slopes
from model import apply_weight_masks
from model import ComplexLinearLayer, ComplexLinearHermitianLayer, ComplextoRealLinearLayer
from model import ComplexRMSNorm
from model import MultiHeadAttentionLayer, MultiheadIsotropicRFA
from model import TransformerBlock, TransformerNetwork
from model import SelfAttentionBlock, RFA_Block
from model import RFATransformerBlock, RFATransformerNetwork
print('Done.')

In [None]:
from visualization import plot_trajectory, compute_state_matrix, plot_state_matrix, visualize_results
from visualization import visualize_results_attn, _get_visual_modules
print('Done.')

In [None]:
from training import single_epoch_rfa_dynamics, single_epoch_attn_dynamics
from training import hook_fn
print('Done.')

In [None]:
# Build the run configuration. An empty argument list is parsed explicitly,
# so nothing is read from sys.argv and every option takes its default.
parser = argparse.ArgumentParser('DA')
parser.add_argument('--gpu', type=int, default=0)  # Index of the CUDA device (default: 0)
args = parser.parse_args(args=[])

# Use the requested CUDA device when CUDA is available; otherwise use the CPU.
if torch.cuda.is_available():
    args.device = torch.device(f'cuda:{args.gpu}')
else:
    args.device = torch.device('cpu')
print(args.device)

# seed_everything is imported from utils; presumably it seeds all RNGs
# for reproducibility -- confirm in utils.
seed_everything(2025)

### Load \& Visualize Data

In [None]:
##############################
## DEFINE THE LINEAR SYSTEM ##
##############################

# # Normal system
# D1 = torch.zeros(2,2,2).to(args.device) # Diagonal matrix
# D1[0] = torch.tensor([[-0.1, 0.0], [0.0, -0.1]]).to(args.device)
# D1[1] = torch.tensor([[-1.0, 0.0], [0.0, 1.0]]).to(args.device)
# # S1 = torch.zeros(2,2,2).to(args.device)
# # S1[0] = torch.tensor(([1.0,1.0],[0.0,0.0]))
# # S1[1] = torch.tensor(([0.0,0.0],[1.0,-1.0]))
# # S1 = U/np.sqrt(2)
# alpha = np.random.uniform(low=0.0, high=1.0)*2*np.pi
# beta = np.random.uniform(low=0.0, high=1.0)*2*np.pi
# S1 = construct_special_2D_unitary(alpha=alpha, beta=beta)
# Si1 = complex_conj_transpose(S1)

# --------------------------------------

# Stable 2D linear system, given in diagonalized form A = S1 * D1 * Si1.
# Every matrix is a (2, 2, 2) tensor. Slice [0] holds the real part and slice [1]
# appears to hold the imaginary part (with that reading, S1 @ Si1 is the identity).

# Diagonal (eigenvalue) matrix.
# Alternatives for the real part:
#   [[0.0, 0.0], [0.0, 0.0]]  -> zero real part eigenvals
#   [[0.1, 0.0], [0.0, 0.1]]  -> positive real part eigenvals
D1 = torch.stack((
    torch.tensor([[-0.1, 0.0], [0.0, -0.1]], device=args.device),  # Negative real part eigenvals
    torch.tensor([[-1.0, 0.0], [0.0, 1.0]], device=args.device),
))

# "RHS matrix" in the original notes (presumably the right-eigenvector matrix -- confirm).
S1 = torch.stack((
    torch.tensor([[1.0, 1.0], [1.0, 1.0]], device=args.device),
    torch.tensor([[-1.0, 1.0], [0.0, 0.0]], device=args.device),
))

# Inverse of S1 (the 0.5 prefactor is folded into the entries).
Si1 = torch.stack((
    torch.tensor([[0.0, 0.5], [0.0, 0.5]], device=args.device),
    torch.tensor([[0.5, -0.5], [-0.5, 0.5]], device=args.device),
))

# State matrix: keep slice [0] of the product and restore a leading axis of size 1.
D1_Si1 = complex_matmul(D1, Si1)
A = complex_matmul(S1, D1_Si1)[0].unsqueeze(0)
params = [D1, S1, Si1]

In [None]:
#################################
## DYNAMICS SIMULATION OPTIONS ##
#################################
# Populates `args` with the simulation grid, noise levels and dataset sizes.

args.rand_embed = 0 # Use random orthogonal matrix when projecting data to higher dimension?
args.weight_mask = False # Mask the weights? (for visualizing in 2D)

args.m = 2 # Dimension of simulated system
args.d_e = 256 # Embedding dimension

args.tf = 10.0 # Final time
args.dt = 0.01 # Time step size
args.n = 10 # nth measurement (ie use every nth point as a measurement)
args.delta_t = args.dt * args.n # Time between consecutive measurements

# Process and measurement noise:
sigma_process = 0.5 # Process noise
sigma_measure = 1.0 # Measurement noise

# Round before converting to int: a quotient such as 0.3/0.1 evaluates to
# 2.9999999999999996, which plain int() would truncate to 2 steps instead of 3.
args.N_t = int(round(args.tf / args.dt)) # Number of time steps
args.k_step = 1
args.N_t_tot = args.N_t + args.n * args.k_step # Steps including the extra k_step measurement intervals
# t_vector = (torch.arange(args.N_t_tot + 1)).to(args.device) # Array of time steps
t_vector = (torch.arange(args.N_t_tot + 1)*args.dt).to(args.device) # Array of time steps
args.seq_len = args.N_t // args.n + 1 # Number of measurements
args.total_L = args.seq_len + args.k_step

args.batch_size = 32 # Batch size
args.num_batch = 1 # Number of batches
args.num_samp = args.batch_size * args.num_batch # Number of samples in train loader

args.t_equal = True # Equal time intervals?

# if args.t_equal == 1:
#     args.total_time = args.seq_len - 1
# else:
#     args.total_time = args.tf

In [None]:
# # Test dynamic sim

# # Initial condition
# theta0 = torch.rand(1) * 2 * np.pi
# r0 = 20.0 + (torch.rand(1) * 2*3.0) - 3.0
# x0 = torch.tensor([r0 * torch.cos(theta0), r0 * torch.sin(theta0)])
# x0 = x0.to(args.device).unsqueeze(0).unsqueeze(-1)

# X, X_measure = simulate_stochastic_LTI(A, x0, args.N_t_tot, args, sigma_process, sigma_measure)
# # X, X_measure = simulate_diagonalized_stochastic_LTI(D1, S1, Si1, x0, t_vector, sigma_process, sigma_measure, args.device)

# # Actual trajectory
# X_plt = X.squeeze().detach().cpu().numpy()
# plt.plot(X_plt.T[0],X_plt.T[1],'black')

# # Noisy trajectory
# traj = X_measure.detach().cpu().squeeze().numpy()
# plt.plot(traj.T[0], traj.T[1], 'b--')

# plt.axis('equal')
# plt.grid()
# plt.show()

In [None]:
##########################
## CREATE TRAINING DATA ##
##########################

# Random matrices returned by construct_random_mapping. Later cells apply R1i
# to "reverse random mapping" and Pd to "map back to lower dim".
Pu, Pd, R1, R1i = construct_random_mapping(S1, Si1, args)

# Training loader and dataset, plus the full sets of true states,
# noisy measurements and measurement times.
(
    train_loader,
    train_dataset,
    X_true_all,
    X_measure_all,
    t_measure_all,
) = create_train_loader_dynamics(
    A, S1, Si1,
    Pu, Pd, R1, R1i,
    t_vector, sigma_process, sigma_measure,
    args, args.k_step,
)

In [None]:
###########################
## PLOT ALL TRAJECTORIES ##
###########################

fig, ax = plt.subplots()

# Each batch yields (inputs, target, X_true, X_measure, t_measure);
# only the true and measured trajectories are drawn here.
for inputs, target, X_true, X_measure, t_measure in train_loader:
    true_np = X_true.squeeze().detach().cpu().numpy()
    noisy_np = X_measure.detach().cpu().squeeze().numpy()

    # Actual trajectory (black): first two components after transposing
    true_T = true_np.T
    plt.plot(true_T[0], true_T[1], 'black')

    # Noisy trajectory (dashed blue)
    noisy_T = noisy_np.T
    plt.plot(noisy_T[0], noisy_T[1], 'b--')

    # Red markers at index 0 along axis 1 (assumed to be the time axis, i.e. the initial states -- confirm)
    x0 = true_np[:, 0].T
    plt.scatter(x0[0], x0[1], color='red')

    plt.axis('equal')
    plt.grid()
    plt.show()

### Define Model

In [None]:
####################
## MODEL SETTINGS ##
####################
# Numerical limits and embedding sizes shared by the model definitions below.

args.max_learned_decay = 1.4 # Roughly e/2
# args.max_learned_decay = 5.0
args.max_fixed_decay = 5.0 # Can be more aggressive

# Limits for clamping exponent
args.max_exponent = 0
args.min_exponent = -30

# Key, query, value embedding dimensions are same as input embedding dimension
args.d_k = args.d_e
args.d_v = args.d_e

args.epsilon = 1E-5 # Stability param

args.compute_metadata = True # Triggers computing various diagnostics; turned off during training
args.compute_pulled_forward_estimates = True # "Project" every past state into every future frame; very expensive.

# args.sep_params is only assigned in the ablation-options cell that follows, so it
# is read defensively here; when it is not set yet, shared parameters are assumed.
if args.d_k != args.d_v and not getattr(args, 'sep_params', False):
    print('ERROR: Key and value embedding dimensions must be the same if using shared parameters.')

In [None]:
######################
## Ablation options ##
######################
# Feature switches read by the model, training and visualization code.
# Values are kept exactly as given (some flags are ints because they select among more than two modes).
args.causal = True
args.sep_params = False # Use separate params for keys and values?
args.lambda_real_zero = 0 # Zero out real part of eigenvalues?
args.use_full_residual_norm = 1 # Use the full |R|^2 metric and rational robust weight (1) or dot product / exponential weight (0)?
args.use_robust_weight = True # Use rational weight rather than softmax
args.additive_bias_type = 1 # (Additive bias: 0 for zero; 1 for DLE; 2 for linear)
args.multiplicative_bias_type = 1 # (Multiplicative bias: 0 for constant; 1 for DLE; 2 for linear)
# args.t_shift = None # Default
# args.t_shift = 1.0
# args.t_shift = args.k_step
args.learn_t_shift = True
# args.t_shift is read unconditionally by the training and visualization cells, so it
# must always exist: None when the shift is learned, otherwise whatever value was set
# above (falling back to None when none of the commented alternatives is enabled).
if args.learn_t_shift:
    args.t_shift = None
else:
    args.t_shift = getattr(args, 't_shift', None)
args.learn_rotations = True # Learned rotations (True), or fixed as in RoPE (False)?
args.learn_decay = True # Learned decay (True), or fixed (False)?
args.rotate_values = True # Rotate/unrotate values?
args.zero_process_noise = False # Zero process noise (sigma^2)?
args.zero_key_measurement_noise = False # Zero key measurement noise (eta^2)?
args.use_total_precision_gate = 1 # Use total-precision gating? (0 = No gate, 1 = precision gate, 2 = learned gate)
args.use_inner_residual = True # Include a residual connection BEFORE output projection?
args.use_outer_residual = True # Include a residual connection AFTER output projection?
args.use_complex_input_norm = 0 # Use complex-valued RMS Norm AFTER input projection for query/key/value (1), complex-valued RMS Norm AFTER input projection only for query/key (2), or None (0)?
args.use_complex_output_norm = False # Use complex-valued RMS Norm BEFORE output projection?
args.use_real_input_norm = True # Use real-valued RMS Norm BEFORE input projection?
args.use_real_output_norm = True # Use real-valued RMS Norm AFTER output projection?
args.add_gaussian_noise = False # Add Gaussian noise to final token? (for test-time sampling)
args.use_complex_conj_constraint = True # Eigenvalues must appear in complex conjugate pairs to ensure A is real
args.use_colored_prior = False
# args.allow_BM_branch = True # Allow separate branch for Brownian motion? (only used when learning decay)
args.damping = 0.005
args.use_ss_process_noise = True
args.scale_decay_by_time_interval = False
args.zero_rotations = False
args.use_SC_RoPE = False
args.use_log_linear_decay = False
# -----------------------------
args.use_rope = True
args.use_alibi = False

# args.learnable_rope = True

args.use_relative_decay_vanilla = False

In [None]:
# Warn up front when unequal time intervals are requested: that code path
# is reported as broken by the message below.
if args.t_equal == False:
    print('Note: t_equal == False path is broken.')

In [None]:
##################
## DEFINE MODEL ##
##################

# Head count and total attention widths (both equal to the embedding dimension)
args.n_heads = 1 # Number of heads
args.d_k_total = args.d_e # Total query-key dim across all heads
args.d_v_total = args.d_e # Total value dim across all heads
# args.num_blocks = 3

###########################################################

# # Standard Attention

# model = SelfAttentionBlock(input_dim=args.d_e, qkv_dim=args.d_e, num_heads=args.n_heads, args=args)

###########################################################

# Multihead Isotropic RFA (the active choice)

model = RFA_Block(
    args,
    args.n_heads,
    input_dim=args.d_e,
    query_key_dim_total=args.d_k_total,
    value_dim_total=args.d_v_total,
)

###########################################################

# # Standard Transformer

# args.num_blocks = 6
# model = TransformerNetwork(args.d_e, args.d_v*2, args.d_v*4, args.n_heads, args, num_blocks=args.num_blocks)

###########################################################

# # RFA Transformer

# args.num_blocks = 6
# model = RFATransformerNetwork(args=args, num_blocks=args.num_blocks, n_heads=args.n_heads, input_dim=args.d_e, query_key_dim_total=args.d_k, value_dim_total=args.d_v, hidden_dim = 4*args.d_v, Norm=nn.LayerNorm)

###########################################################

# Move the model to the selected device and print a summary
model.to(args.device)

print(model)

# Flat list of parameters; passed to the per-epoch training function
params_list = [*model.parameters()]

n_params = count_parameters(model)
print('Total parameter count:', n_params)

### Training

In [None]:
####################
## TRAINING SETUP ##
####################
# Loss, epoch bookkeeping, and the output directories for weights, tensors and images.

criterion = nn.MSELoss() # Loss

args.batch_size = 32 # Batch size (re-stated here; the train loader has already been built)
args.num_epochs = 4000 // args.num_batch # Number of epochs
args.num_its = args.num_samp // args.batch_size # Number of iterations in an epoch

args.save_model = False

# At least 1, so the modulo test in the training loop never divides by zero
# when fewer than 5 epochs are requested.
args.save_epochs = max(1, args.num_epochs // 5) # Intervals of epochs to save model
# args.show_example_epochs = int(args.num_epochs/10) # Number of epochs between displaying results
args.show_example_epochs = 5
args.n_example = 5 # Plot state estimates at n_example data points

#####################

# Create folders for model weights, and loss history

# __file__ is undefined when this runs inside a notebook; fall back to the working directory.
try:
    root_path = os.path.dirname(os.path.abspath(__file__))
except NameError:
    root_path = os.getcwd()

# Build every path with os.path.join so the separators are correct on any OS and
# sub-directories are never glued onto the parent directory's name.
saved_models_path = os.path.join(root_path, 'saved_models', 'dynamics')
model_name = str(model.__class__.__name__)
date = datetime.date.today().isoformat() # YYYY-MM-DD
# model_path = saved_models_path + model_name + '__' + date + '\\'
model_path = os.path.join(saved_models_path, f"{model_name}__{date}")

# The empty last component keeps a trailing separator, so code that appends a
# file name by plain string concatenation keeps working.
model_weight_path = os.path.join(model_path, 'model_weights', '')
model_tensor_path = os.path.join(model_path, 'info_tensors', '')
model_images_path = os.path.join(model_path, 'train_imgs', '')

# Directory creation stays best-effort, but failures are reported instead of silently ignored.
for _directory in (model_path, model_weight_path, model_tensor_path, model_images_path):
    try:
        os.makedirs(_directory, exist_ok=True)
    except OSError as err:
        print(f'Warning: could not create directory {_directory}: {err}')

In [None]:
####################

# Optimizer

# args.lr = 1E-2 # Learning rate
# optimizer = torch.optim.Adam(params_list, lr=args.lr, betas=(0.9, 0.999)) # Optimizer

# Separate the "Physics" from the "Features": a parameter belongs to the SDE
# group when its name contains any of the substrings below.
SDE_NAME_TAGS = ['mu_', 'sigma_', 'eta_', 'gamma_']


def _is_sde_param(param_name):
    """Return True when the parameter name contains one of the SDE tags."""
    return any(tag in param_name for tag in SDE_NAME_TAGS)


named_params = list(model.named_parameters())
sde_params = [weight for name, weight in named_params if _is_sde_param(name)]
feature_params = [weight for name, weight in named_params if not _is_sde_param(name)]

feature_lr = 1e-2
# feature_lr = 1e-3
sde_lr = feature_lr/2

# optimizer = torch.optim.Adam([
#     {'params': feature_params, 'lr': feature_lr},
#     {'params': sde_params, 'lr': sde_lr}  # slower to prevent spikes
# ])

# Group 0: standard weights (with momentum)
feature_group = {'params': feature_params, 'lr': feature_lr, 'betas': (0.9, 0.999)}

# Group 1: SDE/decay params (lower learning rate, NO momentum, higher epsilon)
sde_group = {
    'params': sde_params,
    'lr': sde_lr,           # Lower learning rate
    'betas': (0.0, 0.999),  # First beta=0 kills momentum
    'eps': 1e-7,            # Higher eps prevents division-by-zero spikes
}

optimizer = torch.optim.Adam([feature_group, sde_group])

########################

In [None]:
# Learning rate scheduler

# scheduler = None

################################

# One-cycle schedule with cosine annealing.
# The total number of scheduler steps is epochs * steps_per_epoch, so the scheduler
# is presumably stepped once per batch inside the training function -- confirm.

scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, 
#     max_lr=args.lr,        # Target learning rate
    max_lr = [feature_lr, sde_lr],
    epochs=args.num_epochs, 
    steps_per_epoch=len(train_loader),
    pct_start=0.05,          # Spend 5% of training time warming up
    anneal_strategy='cos',   # Use cosine decay
    div_factor=25.0,         # Initial LR of each group is its max_lr / 25
    final_div_factor=1000.0  # Minimum LR is initial_lr / 1000, i.e. max_lr / (25 * 1000)
)

################################

# # Multi-stage learning rate schedule (for SDE param training)
# 1. Warmup: Stabilizes initial gradients. Projections (W_q, W_k, W_v) must first align 
#    embeddings into a coherent latent space.
# 2. Simmer: The "Co-adaptation" phase. SDE parameters (especially decay mu) rarely 
#    converge until the feature coordinate system is stable. Once the loss curve 
#    levels off, the SDE params begin to capture the structural temporal dependencies.
#    While SDE params contribute marginally to training loss, they are helpful for model
#    robustness and generalization (out-of-distribution performance).
# 3. Freeze: High-precision refinement.

# def get_schedules(optimizer, warmup_steps, milestone, total_steps, feature_lr, sde_lr):
    
#     # Define the multipliers relative to the initial LR of each group
#     # Group 0: Features, Group 1: SDE
    
#     def simmer_lambda(step):
#         # Cosine from 1.0 down to 0.1
#         if step < warmup_steps: return 1.0 # Should be handled by SequentialLR
#         t = (step - warmup_steps) / (milestone - warmup_steps)
#         return 0.1 + 0.9 * (1 + math.cos(math.pi * t)) / 2

#     def freeze_lambda(step):
#         # Cosine from 0.1 down to 0.001
#         t = (step - milestone) / (total_steps - milestone)
#         return 0.001 + 0.099 * (1 + math.cos(math.pi * t)) / 2

#     # 1. Warmup
#     s1 = torch.optim.lr_scheduler.LinearLR(
#         optimizer, start_factor=1/25.0, end_factor=1.0, total_iters=warmup_steps
#     )

#     # 2. Simmer
#     s2 = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=simmer_lambda)

#     # 3. Freeze
#     s3 = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=freeze_lambda)

#     return torch.optim.lr_scheduler.SequentialLR(
#         optimizer, 
#         schedulers=[s1, s2, s3], 
#         milestones=[warmup_steps, milestone]
#     )

# # Define steps
# steps_per_epoch = len(train_loader)
# total_steps = steps_per_epoch * args.num_epochs
# warmup_steps = int(0.10 * total_steps)
# milestone = int(0.70 * total_steps)
# final_steps = total_steps - milestone

# scheduler = get_schedules(optimizer, warmup_steps, milestone, total_steps, feature_lr, sde_lr)

In [None]:
# ###### Initialize model to correct values (for testing) ######
# ##############################################################

# if isinstance(model, MultiheadIsotropicRFA_1layer):
#     module = model.layers[0]
#     initialize_to_correct_model(module, D1, S1, Si1, sigma_process, sigma_measure, args)

In [None]:
######  Load in model weights ######
####################################

# epoch = 410
# model_weight_path_full = model_weight_path + 'weights_epoch_' + str(epoch)

# try:
#     model.load_state_dict(torch.load(model_weight_path_full, weights_only=True))
#     print('Model loaded.')
# except:
#     print('Model loading failed.')
#     pass

In [None]:
########################### RFA TRAINING LOOP ###########################
#########################################################################

print(f"Starting training on {args.device}...")

# Training history, one list per tracked quantity. Per the original notes:
# each 'mu' entry is [n_heads] and 'gs_mean' is the mean gate value per iteration.
history = {
    key: []
    for key in (
        'loss', 'mu', 'sigma', 'sigma_tilde', 'eta',
        'gamma', 'tau', 'nu_over_d', 'gs_mean',
    )
}

# Plot selection for the periodic in-training visualization
periodic_plot_flags = dict(
    plot_losses_flag=False, plot_log_losses_flag=True, plot_traj_flag=True,
    plot_pulled_forward_estimates_flag=True, plot_last_attn_mat_flag=False,
    plot_total_precision_flag=False, plot_attn_prior_flag=False, plot_eigenvals_flag=True,
    plot_decay_per_epoch=False, plot_noise_params_by_type=False, plot_noise_params_by_head=False,
    plot_tau_and_nu_flag=False, plot_gates_per_epoch=False,
)

# Train for num_epochs:
for epoch in tqdm(np.arange(args.num_epochs), desc="Training progress..."):

    # Train for a single epoch; the history dict is passed in and returned
    output_dict, history = single_epoch_rfa_dynamics(
        model, train_loader, history, optimizer, criterion, params_list, args,
        t_equal=args.t_equal, t_shift=args.t_shift, causal=args.causal,
        scheduler=scheduler,
    )

    epochs_done = epoch + 1

    # Visualize results so far (in eval mode, without gradient tracking):
    if epochs_done % args.show_example_epochs == 0:
        model.eval()
        with torch.no_grad():
            visualize_results(
                model, train_dataset, history,
                R1, R1i, Pu, Pd, A, epoch, model_images_path, args,
                t_equal=args.t_equal, t_shift=args.t_shift, causal=args.causal,
                **periodic_plot_flags,
            )
        model.train()

    # Save model and training history
    if args.save_model and epochs_done % args.save_epochs == 0:
        # Model weights (the file name carries the 0-based epoch index)
        model_weight_path_full = os.path.join(model_weight_path, f'weights_epoch_{epoch}.pth')
        torch.save(model.state_dict(), model_weight_path_full)

        # History and latest diagnostics
        model_tensor_path_full = os.path.join(model_tensor_path, f'tensors_epoch_{epoch}.pth')

        tensors_to_save = {
            'history': history,                    # All scalar tracking (mu, sigma, loss, etc.)
            'attn_mat': output_dict['attn_mat'],   # Latest attention matrix
            'eigenvals': output_dict['eigenvals'], # Latest eigenvalues
            'args': args,                          # Save args so we know the config later
        }

        torch.save(tensors_to_save, model_tensor_path_full)
        print(f"Checkpoint saved: Epoch {epoch+1}")

In [None]:
## JUMP

# Final visualization: same call as during training, with more plots switched on.
final_plot_flags = dict(
    plot_losses_flag=False, plot_log_losses_flag=True, plot_traj_flag=True,
    plot_pulled_forward_estimates_flag=True, plot_last_attn_mat_flag=True,
    plot_total_precision_flag=False, plot_attn_prior_flag=True, plot_eigenvals_flag=True,
    plot_decay_per_epoch=True, plot_noise_params_by_type=True, plot_noise_params_by_head=False,
    plot_tau_and_nu_flag=True, plot_gates_per_epoch=True,
)

model.eval()
with torch.no_grad():
    # `epoch` is the last value left behind by the training loop
    visualize_results(
        model, train_dataset, history,
        R1, R1i, Pu, Pd, A, epoch, model_images_path, args,
        t_equal=args.t_equal, t_shift=args.t_shift, causal=args.causal,
        **final_plot_flags,
    )

In [None]:
# # Test model
# with torch.no_grad():

#     # Get prediction for random choice of input
#     rand_idx = np.random.choice(args.num_samp)
#     train_data, X_true, X_measure, t_measure = train_dataset.__getitem__(rand_idx)

#     inputs = train_data.unsqueeze(0)[:, :-1]
    
#     print(inputs.size())

#     out, output_dict = model.forward(inputs)
    
#     est = output_dict['est_latent']
#     attn_mat = output_dict['attn_mat']
#     A_prior = output_dict['A_prior']
#     x_hat = output_dict['x_hat']
#     lambda_h = output_dict['epoch_lambdas']
#     P_tot = output_dict['P_tot']
#     gate = output_dict['gate']

In [None]:
# # Save Model
# model_weight_path_full = model_weight_path + 'weights_epoch_' + str(epoch)
# model_tensor_path_full = model_tensor_path + 'tensors_epoch_' + str(epoch)
# torch.save(model.state_dict(), model_weight_path_full)

# tensors_to_save = {
# 'epoch_losses': epoch_losses,
# 'A_hat_complex': A_hat_complex,
# 'epoch_lambdas': epoch_lambdas
# }
# torch.save(tensors_to_save, model_tensor_path_full)

In [None]:
# #########################################
# ##### LOOP FOR STANDARD ATTENTION #######

# all_losses = [] # Global list to store every iteration

# for epoch in tqdm(np.arange(args.num_epochs), desc="Training progress..."):

#     epoch_losses = single_epoch_attn_dynamics(model, train_loader, optimizer, criterion, params_list, args)

#     # Append the epoch's iterations to our global history
#     all_losses.extend(epoch_losses)
    
#     # Visualize results so far:
#     if np.mod(epoch+1,args.show_example_epochs) == 0:
#         visualize_results_attn(model, train_dataset, all_losses, R1, R1i, Pu, Pd, A, epoch, args)
# #########################################

In [None]:
# if np.mod(epoch+1,args.show_example_epochs) == 0:
#     visualize_results_attn(model, train_dataset, all_losses, R1, R1i, Pu, Pd, A, epoch, args)

In [None]:
##################################
## Test autoregressive sampling ##
##################################
# Generates args.max_gen_len further steps from the first training batch and plots
# the result for the first sample of that batch in the original 2D state space.

args.add_gaussian_noise = 1
args.max_gen_len = 100

# Only the first batch is needed, so fetch it directly instead of looping and breaking.
# for it, (train_data, _, _, _) in enumerate(train_loader):
#     start_seq = train_data[:, :-1]
inputs, _, _, _, t_measure = next(iter(train_loader))
start_seq = inputs

total_seq, new_seq, precisions = autoregressive_sample(model, start_seq, args.max_gen_len, t_measure, t_shift=args.t_shift, t_equal=args.t_equal, causal=True)

traj_tot = torch.matmul(R1i,total_seq.unsqueeze(-1)) # Reverse random mapping
X_tot = torch.matmul(Pd,traj_tot) # Map back to lower dim

traj_new = torch.matmul(R1i,new_seq.unsqueeze(-1)) # Reverse random mapping
X_new = torch.matmul(Pd,traj_new) # Map back to lower dim

# The two curves previously shared the label 'Predicted' and no legend was drawn;
# they now carry distinct labels and the legend is shown.
# NOTE(review): total_seq is presumably start_seq followed by the generated steps,
# and new_seq the generated steps only -- confirm against autoregressive_sample.
X_tot_plt = X_tot.squeeze(0)[0].detach().cpu().squeeze().numpy()
plt.plot(X_tot_plt.T[0], X_tot_plt.T[1], 'r--', label='Total sequence')

X_new_plt = X_new.squeeze(0)[0].detach().cpu().squeeze().numpy()
plt.plot(X_new_plt.T[0], X_new_plt.T[1], 'b', label='Newly generated')

plt.grid()
plt.legend()

In [None]:
# Stack the collected precisions along a new dim 1
# (original author's note on the result: [B, max_gen_len])
p_tensor = torch.stack(precisions, dim=1)

# Min-max normalize using the global minimum and maximum over all entries;
# the small constant guards against a zero range.
p_lo = p_tensor.min()
p_hi = p_tensor.max()
p_range = p_hi - p_lo + 1e-8
p_normalized = (p_tensor - p_lo) / p_range

# Plot the last index along dim 1 for every entry along dim 0
last_step = p_normalized[:, -1]
plt.plot(last_step.detach().cpu().numpy())

In [None]:
# Inspect tau and nu/d for the first layer.
# The layer is bound to a descriptive name rather than `self`, which is
# conventionally reserved for the instance argument inside methods.
rfa_layer = model.layers[0]
tau = F.softplus(rfa_layer.tau_param) + rfa_layer.args.epsilon # Softmax temperature (per the original note)
print(tau)
d = rfa_layer.d_k_head
# softplus is non-negative and epsilon is a small positive constant, so nu stays above 2
nu = d * F.softplus(rfa_layer.nu_param) + 2.0 + rfa_layer.args.epsilon
print(nu/d)

In [None]:
# Print softplus of the first layer's raw t_shift parameter
t_shift_raw = model.layers[0].t_shift_param
print(F.softplus(t_shift_raw))

In [None]:
# def recursive_state_update(model, start_seq, num_steps, t_measure=None, t_shift=None):
#     """
#     Iteratively updates a fixed-size trajectory.
#     Input 'start_seq' is transformed into 'out' n times.
#     """
#     model.eval()
#     current_state = start_seq
    
#     traj_start = torch.matmul(R1i,start_seq.unsqueeze(-1)) # Reverse random mapping
#     X_start = torch.matmul(Pd,traj_start) # Map back to lower dim
#     X_start_plt = X_start.squeeze(0)[0].detach().cpu().squeeze().numpy()
#     plt.plot(X_start_plt.T[0], X_start_plt.T[1], 'r--', label='Input') 

#     # We collect the 'state' at each iteration to see how the trajectory evolves
#     # but the sequence length (L) never changes.
#     evolution_history = [] 

#     with torch.no_grad():
#         for _ in range(num_steps):
#             # The model treats the current window as a set of noisy observations
#             # and filters them through the SDE prior.
#             out, output_dict = model(current_state, t_measure=t_measure, t_shift=t_shift)
            
#             # RECURSION: The entire output block is the next input block
#             current_state = out 
            
#             evolution_history.append(current_state)
            
#             traj_filt = torch.matmul(R1i,current_state.unsqueeze(-1)) # Reverse random mapping
#             X_filt = torch.matmul(Pd,traj_filt) # Map back to lower dim
#             X_filt_plt = X_filt.squeeze(0)[0].detach().cpu().squeeze().numpy()
#             plt.plot(X_filt_plt.T[0], X_filt_plt.T[1]) 

#         plt.grid()
#         plt.show()

#     return current_state, evolution_history


# args.add_gaussian_noise = 1
# num_steps = 30

# for it, (inputs, _, _, _, t_measure) in enumerate(train_loader):
#     start_seq = inputs
#     break

# seq_filt, evolution_history = recursive_state_update(model, start_seq, num_steps)