## Adaptive Filter Attention

### Setup

In [None]:
import numpy as np
import math
import scipy

from matplotlib import pyplot as plt
import matplotlib.cm as cm
plt.rcParams['figure.figsize'] = [10, 10]
plt.rc('font', size=20)

import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

import transformers
from datasets import load_dataset, DatasetDict
from transformers import AutoTokenizer
from torch.utils.data import Dataset, DataLoader

import os
import argparse
import datetime
import time
from tqdm import tqdm # Loading bar
print('Done.')

In [None]:
from utils import complex_matmul, apply_interleaved_rope
from utils import count_parameters, get_layers
print('Done.')

In [None]:
from dynamics import stochastic_LTI, DynamicSim
from dynamics import construct_mapping
from dynamics import get_nth_measurement, get_random_measurements
from dynamics import linear_spiral, linear_spiral_3D, Lorenz, rand_coupling_matrix, Van_der_Pol_osc
print('Done.')

In [None]:
from isotropic_afa import get_safe_exp_tot, compute_covariance_matrix, compute_covariance_matrix_safe
from isotropic_afa import compute_exp_kernel_isotropic, compute_residual_norm_isotropic
print('Done.')

In [None]:
from model import compute_lambda, resolve_multihead_dims
from model import init_complexlinear, init_complex_matrix, set_complex_weight, initialize_linear_layers
from model import initialize_to_correct_model
from model import apply_weight_masks
from model import ComplexLinearLayer, ComplexLinearHermitianLayer, ComplextoRealLinearLayer
from model import MultiHeadAttentionLayer
from model import ComplexRMSNorm
from model import Attention_1layer, AFA_1layer
from model import SimpleAttention_Net
from model import TransformerBlock, TransformerNetwork
from model import MultiheadIsotropicAFA_1layer, AFATransformerBlock, AFATransformerNetwork
# from model import compute_attention_matrix, compute_estimate
print('Done.')

In [None]:
from data_utils import construct_random_mapping, construct_data_dynamics, TrainDataset_dynamics, create_train_loader_dynamics
print('Done.')

In [None]:
from visualization import plot_trajectory, compute_state_matrix, plot_state_matrix, plot_eigenvals, visualize_results
from visualization import visualize_results_attn, _get_visual_modules
print('Done.')

In [None]:
from training import single_iter, single_epoch, hook_fn
from training import single_iter_attn, single_epoch_attn
print('Done.')

In [None]:
parser = argparse.ArgumentParser('DA')
parser.add_argument('--gpu', type=int, default=0)  # GPU index (default: 0)
# Parse an empty list so the notebook kernel's own argv is ignored.
args = parser.parse_args(args=[])

# Run on the requested GPU when CUDA is present, otherwise on the CPU.
if torch.cuda.is_available():
    args.device = torch.device(f'cuda:{args.gpu}')
else:
    args.device = torch.device('cpu')
print(args.device)

# Seed torch and NumPy so runs are reproducible.
torch.manual_seed(2025)
np.random.seed(2025)

### Training

In [None]:
# Set dynamical system
# Complex matrices are stored as a stacked pair along dim 0: index 0 = real part,
# index 1 = imaginary part (the convention complex_matmul is used with below).

# # Normal system (alternative; needs construct_special_2D_unitary and
# # complex_conj_transpose, which are not imported in this notebook)
# D1 = torch.zeros(2,2,2).to(args.device) # Diagonal matrix
# D1[0] = torch.tensor([[-0.1, 0.0], [0.0, -0.1]]).to(args.device)
# D1[1] = torch.tensor([[-1.0, 0.0], [0.0, 1.0]]).to(args.device)
# # S1 = torch.zeros(2,2,2).to(args.device)
# # S1[0] = torch.tensor(([1.0,1.0],[0.0,0.0]))
# # S1[1] = torch.tensor(([0.0,0.0],[1.0,-1.0]))
# # S1 = U/np.sqrt(2)
# alpha = np.random.uniform(low=0.0, high=1.0)*2*np.pi
# beta = np.random.uniform(low=0.0, high=1.0)*2*np.pi
# S1 = construct_special_2D_unitary(alpha=alpha, beta=beta)
# Si1 = complex_conj_transpose(S1)

# Stable 2D linear system in diagonalized form: A = S1 @ D1 @ Si1.
# Eigenvalues are -0.1 -/+ 1.0j (a decaying spiral).
D1 = torch.tensor([
    [[-0.1, 0.0], [0.0, -0.1]],  # Real parts (negative -> stable)
    # Alternatives: [[0.0, 0.0], [0.0, 0.0]] (zero real part), [[0.1, 0.0], [0.0, 0.1]] (positive real part)
    [[-1.0, 0.0], [0.0, 1.0]],   # Imaginary parts
]).to(args.device)  # Diagonal eigenvalue matrix

S1 = torch.tensor([
    [[1.0, 1.0], [1.0, 1.0]],    # Real part
    [[-1.0, 1.0], [0.0, 0.0]],   # Imaginary part
]).to(args.device)  # Eigenvector (RHS) matrix

Si1 = 0.5 * torch.tensor([
    [[0.0, 1.0], [0.0, 1.0]],    # Real part
    [[1.0, -1.0], [-1.0, 1.0]],  # Imaginary part
]).to(args.device)  # Inverse of S1

# Real part of the reconstructed system matrix, with a leading singleton dim.
A = complex_matmul(S1, complex_matmul(D1, Si1))[0].unsqueeze(0)
params = [D1, S1, Si1]

In [None]:
# DEFINE MODEL (hyperparameters stored on `args`)

args.cr_max = 2
args.cr_min = 0

# Clamp bounds for exponents such as mu * Delta_T before they are exponentiated
args.max_exponent = 10
args.min_exponent = -10

args.t_equal = 1 # Equal time intervals? (0 or 1)
args.sep_params = 0 # Use separate params for keys and values? (0 or 1)

##################################################
## Testing options
args.rand_embed = 0 # Use random orthogonal matrix when projecting data to higher dimension?
args.weight_mask = 0 # Mask the weights? (0 or 1) (must be set to 1 when visualizing in 2D)
# (lambda_real_zero is set once, under "Ablation options" below)
args.compute_pulled_forward_estimates = 1
##################################################

##################################################
## Ablation options ##
######################
## (Setting everything to 0 (except lambda_real_zero & outer residual norm) makes it an ordinary Transformer)
args.lambda_real_zero = 0 # Zero out real part of eigenvalues? (1 = yes, 0 = no)
args.use_full_residual_norm = 1 # Use the full |R|^2 metric and rational robust weight (1) or dot product / exponential weight (0)?
args.use_robust_weight = 1 # Use rational weight rather than softmax
args.additive_bias_type = 1 # (Additive bias: 0 for zero; 1 for DLE; 2 for linear)
args.multiplicative_bias_type = 1 # (Multiplicative bias: 0 for constant; 1 for DLE; 2 for linear)
args.compute_next_step_pred = 1 # Compute next-step prediction? (1 = yes, 0 = No)
args.learn_rotations = 1 # Learned rotations (1), or fixed as in RoPE (0)?
args.rotate_values = 1 # Rotate/unrotate values? (1 = yes, 0 = No)
args.zero_process_noise = 0 # Zero process noise (sigma^2)? (1 = yes, 0 = No)
args.zero_key_measurement_noise = 0 # Zero key measurement noise (eta^2)? (1 = yes, 0 = No)
args.use_total_precision_gate = 2 # Use total-precision gating? (0 = No gate, 1 = precision gate, 2 = learned gate)
args.use_complex_input_norm = 0 # Use complex-valued RMS Norm AFTER input projection (1) or real-valued RMS Norm BEFORE input projection (0)?
args.use_complex_output_norm = 0 # Use complex-valued RMS Norm BEFORE output projection (1) or real-valued RMS Norm AFTER output projection (0)?
args.use_inner_residual = 1 # Include a residual connection BEFORE output projection (1) or AFTER output projection (0)?
args.add_gaussian_noise = 0 # Add Gaussian noise to final token (for test-time sampling)

# A gate only makes sense when there is an inner residual to gate against.
if args.use_inner_residual == 0 and args.use_total_precision_gate in (1, 2):
    print('Warning: inner residual is off.')

##################################################

args.m = 2 # Dimension of simulated system
# args.embed_dim = 2 # Embedding dimension
args.d_e = 128 # Embedding dimension

# args.num_heads = 1
# args.head_dim = int(args.embed_dim/args.num_heads)

##################################################

# Key, query, value embedding dimensions are same as input embedding dimension
args.d_k = args.d_e
args.d_v = args.d_e
# args.d_k = args.head_dim
# args.d_v = args.head_dim

# # Key, query, value embedding dimensions are half of input embedding dimension
# args.d_k = int(args.head_dim/2) # Key and query embedding dimension
# args.d_v = int(args.head_dim/2) # Value embedding dimension

# # Key, query, value embedding dimensions are 2
# args.d_k = 2 # Key and query embedding dimension
# args.d_v = 2 # Value embedding dimension

##################################################

args.nu = 1.0 # Measurement weighting
args.tf = 10.0 # Final time
args.dt = 0.01 # Time step size
args.n = 10 # nth measurement (ie use every nth point as a measurement)
args.delta_t = args.dt * args.n

# round(), not int(): tf/dt can land just below an integer in floating point,
# and truncation would silently drop a time step.
args.N_t = round(args.tf/args.dt) # Number of time steps
args.seq_len = args.N_t // args.n + 1 # Number of measurements
# args.seq_len = args.N_t / args.n = (args.tf/args.dt) / (args.delta_t/args.dt) = args.tf/args.delta_t
t_v = (torch.arange(args.N_t + args.n + 1)*args.dt).to(args.device) # Array of time steps

# Some scalar weights in model
args.alpha = 1.0
args.beta = 0.0
args.delta = 1.0
args.eta = 0.0

In [None]:
# DEFINE MODEL:
# Exactly one model family below should be active; the others are kept commented out
# for quick switching. This cell defines `model`, `loss` (training loss) and `loss_p`
# (penalty loss), which the optimizer/training cells use.

###########################################################

# # Standard Attention/Transformer Models

# args.model_type = 'RealInputs'
# args.num_heads = 4
# args.num_blocks = 3
# # model = Attention_1layer(args.d_e, args.d_v, args.num_heads, args).to(args.device)
# # model = SimpleAttention_Net(args.d_e, args.d_v*2, args.num_heads, args).to(args.device)
# model = TransformerNetwork(args.d_e, args.d_v*2, args.d_v*4, args.num_heads, args, num_blocks=args.num_blocks).to(args.device)
# loss = nn.MSELoss()
# loss_p = nn.MSELoss()

###########################################################

# # Complex Attention with Real Input/Output Models
# # (NOTE: SimpleComplexAttention_Net / ComplexTransformerNetwork are not imported in this notebook.)

# args.model_type = 'RealInputs'
# args.metric = 'RealDotProduct'
# # args.metric_type = 'InverseMahalanobis'
# # model = SimpleComplexAttention_Net(args.d_e, args.d_v, args).to(args.device)
# model = ComplexTransformerNetwork(args.d_e, args.d_v*2, args.d_v*4, args, num_blocks=2).to(args.device)
# loss = nn.MSELoss()

###########################################################

# Multihead Isotropic AFA (active configuration)

args.num_heads = 2 # Number of heads
# These two flags were already set in the hyperparameter cell; the values here take precedence.
args.compute_next_step_pred = 1
args.use_inner_residual = 1
# args.d_k = args.d_e
# args.d_v = args.d_e
# (NOTE: MultiheadSimplifiedAFA_1layer is not imported in this notebook.)
# model = MultiheadSimplifiedAFA_1layer(args, args.seq_len, args.num_heads, input_dim=args.d_e, query_key_dim=args.d_k, value_dim=args.d_v).to(args.device)
args.d_k_total = args.d_e # Total query-key dim across all heads
args.d_v_total = args.d_e # Total value dim across all heads
model = MultiheadIsotropicAFA_1layer(args, args.num_heads, input_dim=args.d_e, query_key_dim_total=args.d_k_total, value_dim_total=args.d_v_total).to(args.device)

# args.num_blocks = 1 # Number of transformer blocks
# (NOTE: MultiheadSimplifiedAFATransformerNetwork and `Norm` are not defined in this notebook.)
# model = MultiheadSimplifiedAFATransformerNetwork(args, args.seq_len, num_blocks=args.num_blocks, n_heads=args.num_heads, input_dim=args.d_e, query_key_dim_total=args.d_k_total, value_dim_total=args.d_v_total, Norm=Norm).to(args.device)

loss = nn.MSELoss() # Loss
loss_p = nn.MSELoss() # Penalty

###########################################################

# # AFA Transformer
# # NOTE(review): the call below passes args.d_k / args.d_v as the *_total dims; it probably
# # should use args.d_k_total / args.d_v_total (equal today since both are d_e) — confirm.

# args.compute_next_step_pred = 0
# args.num_heads = 2 # Number of heads
# args.num_blocks = 2 # Number of transformer blocks
# args.d_k_total = args.d_e # Total query-key dim across all heads
# args.d_v_total = args.d_e # Total value dim across all heads
# model = AFATransformerNetwork(args=args, num_blocks=args.num_blocks, n_heads=args.num_heads, input_dim=args.d_e, query_key_dim_total=args.d_k, value_dim_total=args.d_v, hidden_dim = 4*args.d_v, Norm=nn.LayerNorm).to(args.device)

# loss = nn.MSELoss()
# loss_p = nn.MSELoss()

###########################################################

In [None]:
params_list = list(model.parameters()) # Parameters list

# loss_p = Complex_Trace_Loss() # Trace Penalty
lr = 1E-2 # Learning rate
optimizer = torch.optim.Adam(params_list, lr=lr, betas=(0.9, 0.999)) # Optimizer

# Shared key/value parameters require matching key and value dimensions.
if args.d_k != args.d_v and args.sep_params == 0:
    print('ERROR: Key and value embedding dimensions must be the same if using shared parameters.')

print('Total parameter count:', count_parameters(model))

#####################

# Create folders for model weights and saved info tensors:
#   <saved_models_path>/<ModelClass>__<YYYY-MM-DD>/{model_weights,info_tensors}/

saved_models_path = 'C:/Users/Pracioppo/Desktop/AFA, Dec 11/saved_models/' # Machine-specific; edit as needed
model_name = type(model).__name__
date = datetime.date.today().isoformat() # 'YYYY-MM-DD'
model_path = saved_models_path + model_name + '__' + date + '/'

model_weight_path = model_path + 'model_weights/'
model_tensor_path = model_path + 'info_tensors/'

# Create each folder independently: an existing run folder must not stop its
# sub-folders from being created. Saving is best-effort, so other OS errors
# (e.g. the base path does not exist on this machine) only produce a warning.
for folder in (model_path, model_weight_path, model_tensor_path):
    try:
        os.mkdir(folder)
    except FileExistsError:
        continue
    except OSError as err:
        print(f'Warning: could not create folder {folder!r}: {err}')
        break

print('Model loaded.')

In [None]:
## CREATE TRAINING DATA
# Sets training hyper-parameters, allocates loss/eigenvalue history arrays, and builds
# the simulated training set via create_train_loader_dynamics.

args.concat_mag = 0  # NOTE(review): consumed by helpers outside this notebook; meaning not visible here

# Training params
args.num_epochs = 1000 # Number of epochs
args.num_samp = 32 # Number of samples in train loader
args.batch_size = 32 # Batch size
# Truncating division; assumes num_samp is a multiple of batch_size — TODO confirm vs. the loader's batching
args.num_its = int(args.num_samp/args.batch_size) # Number of iterations in an epoch
args.save_epochs = 100 # Intervals of epochs to save model
args.show_example_epochs = 5 # Number of epochs between displaying results so far
args.n_example = 5 # Plot state estimates at n_example data points
args.epsilon = 1E-5  # Small positive floor added to softplus noise parameters in the model

# Process and measurement noise:
sigma_process = 0.0 # Process noise
sigma_measure = 2.0 # Measurement noise

#####################################################################

# Initialize arrays to record losses
mean_epoch_losses = np.zeros(args.num_epochs)
log_mean_epoch_losses = np.zeros(args.num_epochs)
# NOTE: sized with num_samp, but the training loop writes only num_epochs*num_its entries
# (slice epoch*num_its:(epoch+1)*num_its); any remaining entries stay 0.
all_losses = np.zeros(args.num_epochs * args.num_samp)
# Per-epoch mean of epoch_lambdas; assumes that mean has shape (2, d_v) — TODO confirm for multi-head models
all_lambdas = np.zeros((args.num_epochs, 2, args.d_v))

# Random projection matrices. The sampling cell uses R1i to reverse the random mapping and
# Pd to map back to the low-dimensional system; Pu and R1 are presumably their forward counterparts.
Pu, Pd, R1, R1i = construct_random_mapping(S1, Si1, args) # Get random matrices

# Build training dataset
train_loader, train_dataset, X_true_all, X_measure_all, t_measure_all = create_train_loader_dynamics(A, S1, Si1, Pu, Pd, R1, R1i, t_v, sigma_process, sigma_measure, args)

In [None]:
# PLOT ALL TRAJECTORIES
# For each batch from train_loader: true trajectories (solid black), noisy measurements
# (dashed blue) and a red marker per trajectory (presumably the initial state).

from matplotlib import patches  # only needed by the commented-out circle overlay below
fig, ax = plt.subplots()  # `ax` is only referenced by the commented-out circle code

for it, (train_data, X_true, X_measure, t_measure_full) in enumerate(train_loader):
# for it, train_data in enumerate(train_loader):
#     for i in range(args.batch_size):

#     if X_true.size()[2] == 1:
#         X_true = torch.stack((X_true, torch.zeros_like(X_true)),dim=2)
#         X_measure = torch.stack((X_measure, torch.zeros_like(X_measure)),dim=2)

    # Actual trajectory
    X_true_plt = X_true.squeeze().detach().cpu().numpy()
    plt.plot(X_true_plt.T[0],X_true_plt.T[1],'black')

    # Noisy trajectory
    traj = X_measure.detach().cpu().squeeze().numpy()
    plt.plot(traj.T[0], traj.T[1], 'b--')
    
    # Index 0 along axis 1 — presumably the first time step; TODO confirm X_true layout
    x0 = X_true_plt[:,0].T
    plt.scatter(x0[0],x0[1], color='red')

#     circle = patches.Circle((0,0), 17.0, edgecolor='red', facecolor='none', linewidth = 2)
#     ax.add_patch(circle)

    # NOTE: axis/grid/show run inside the loop, so each batch gets its own figure
    # (with num_samp == batch_size there is a single batch).
    plt.axis('equal')
    plt.grid()
    plt.show()

In [None]:
# ###### Initialize model to correct values (for testing) ######
# ##############################################################

# if isinstance(model, MultiheadIsotropicAFA_1layer):
#     module = model.layers[0]
#     initialize_to_correct_model(module, D1, S1, Si1, sigma_process, sigma_measure, args)

In [None]:
######  Load in model weights ######
####################################

# epoch = 410
# model_weight_path_full = model_weight_path + 'weights_epoch_' + str(epoch)

# try:
#     model.load_state_dict(torch.load(model_weight_path_full, weights_only=True))
#     print('Model loaded.')
# except:
#     print('Model loading failed.')
#     pass

In [None]:
########################### TRAINING LOOP ###########################
#####################################################################

###### LOOP FOR PRECISION ATTENTION LAYERS #######

# One iteration per epoch; tqdm draws the progress bar.
for epoch in tqdm(np.arange(args.num_epochs), desc="Training progress..."):

    # Full pass over train_loader; also returns per-epoch diagnostics.
    epoch_losses, A_hat_complex, epoch_lambdas, unnormalized_attention = single_epoch(
        model, train_loader, optimizer, loss, loss_p, params_list, args
    )

    # Record this epoch's per-iteration losses and summary statistics.
    loss_start = epoch * args.num_its
    all_losses[loss_start:loss_start + args.num_its] = epoch_losses
    epoch_mean_loss = np.mean(epoch_losses)
    mean_epoch_losses[epoch] = epoch_mean_loss
    log_mean_epoch_losses[epoch] = np.log(epoch_mean_loss)
    all_lambdas[epoch, :, :] = np.mean(epoch_lambdas, axis=0)

    # Every show_example_epochs epochs, plot the progress so far.
    if (epoch + 1) % args.show_example_epochs == 0:
        visualize_results(model, train_dataset, all_losses, mean_epoch_losses, log_mean_epoch_losses, all_lambdas, R1, R1i, Pu, Pd, A, epoch, args)

#     if np.mod(epoch+1,args.save_epochs) == 0:
#         model_weight_path_full = model_weight_path + 'weights_epoch_' + str(epoch)
#         model_tensor_path_full = model_tensor_path + 'tensors_epoch_' + str(epoch)
#         torch.save(model.state_dict(), model_weight_path_full)

#         tensors_to_save = {
#         'epoch_losses': epoch_losses,
#         'A_hat_complex': A_hat_complex,
#         'epoch_lambdas': epoch_lambdas
#         }
#         torch.save(tensors_to_save, model_tensor_path_full)

In [None]:
## JUMP

# Debug harness: the following cells step through the first layer's computation by hand,
# so `self` is bound at notebook scope to that layer.
self = model.layers[0]
# Random inputs shaped (batch, 1, seq_len, d_e), taken from the config rather than
# hard-coded ([32, 1, 101, 128] with the current settings).
Z_q = Z_k = Z_v = torch.rand([args.batch_size, 1, args.seq_len, args.d_e]).to(args.device)

In [None]:
batch_size = Z_q.size(0)
seq_len = Z_q.size(2)

# Measurement time stamps. The final stamp is dropped ([:-1]); this presumably mirrors the
# sampling code, which feeds train_data[:, :-1] to the model.
if t_measure_all is None:
    # No time stamps available: fall back to unit-spaced integer times.
#             t_measure = torch.arange(seq_len).to(self.args.device) * self.args.delta_t
    t_measure = torch.arange(seq_len).to(self.args.device)
elif t_measure_all.dim() > 1:
    t_measure = t_measure_all[0, :-1]  # Batched stamps: take the first sample's
else:
    t_measure = t_measure_all[:-1]  # 1-D stamps (indexing with [:, :-1] would raise IndexError)

# Normalize time vector by total time elapsed (optional)
#         t_measure = t_measure / (t_measure[-1] - t_measure[0]).unsqueeze(0)

#######################################

# Apply linear projections
Q_proj = self.W_q(Z_q)
K_proj = self.W_k(Z_k)
V_proj = self.W_v(Z_v)

# Split into real/imaginary parts: (batch, seq_len, 2, d_total)
Q_proj = Q_proj.view(batch_size, seq_len, 2, self.d_k_total)
K_proj = K_proj.view(batch_size, seq_len, 2, self.d_k_total)
V_proj = V_proj.view(batch_size, seq_len, 2, self.d_v_total)

# Split into heads: (batch, seq_len, 2, n_heads, d_head)
Q = Q_proj.view(batch_size, seq_len, 2, self.n_heads, self.d_k_head)
K = K_proj.view(batch_size, seq_len, 2, self.n_heads, self.d_k_head)
V = V_proj.view(batch_size, seq_len, 2, self.n_heads, self.d_v_head)

# Move real/imaginary index to the end: (batch, seq_len, n_heads, d_head, 2)
Q = Q.permute(0,1,3,4,2).contiguous()
K = K.permute(0,1,3,4,2).contiguous()
V = V.permute(0,1,3,4,2).contiguous()

#######################################

# Optionally, apply complex-valued normalization on inputs (per head)
if self.args.use_complex_input_norm == 1: # Normalize query, key, and value
    Q_norm = self.cn_q(Q)
    K_norm = self.cn_k(K)
    V_norm = self.cn_v(V)
elif self.args.use_complex_input_norm == 0: # No normalization
    Q_norm = Q
    K_norm = K
    V_norm = V
elif self.args.use_complex_input_norm == 2: # Normalize only query and key
    Q_norm = self.cn_q(Q)
    K_norm = self.cn_k(K)
    V_norm = V  # Values pass through unnormalized (was left undefined -> NameError later)
else:
    raise ValueError('Error: args.use_complex_input_norm must be 0, 1, or 2.')

#######################################

mu_v, omega_v, lambda_v = compute_lambda(self.mu_v, self.lambda_imag_v, self.args) # Compute eigenvals

# Ensure non-negativeness of noise parameters (softplus >= 0; epsilon keeps eta^2/gamma^2 strictly positive)
sigma_squared_v = F.softplus(self.sigma_v) # Process noise
eta_squared_v = F.softplus(self.eta_v) + self.args.epsilon # Measurement noise
gamma_squared_v = F.softplus(self.gamma_v) + self.args.epsilon # Anchor measurement noise

# Ablations: zero out selected noise terms (zeros_like keeps device and dtype)
if self.args.zero_process_noise == 1:
    sigma_squared_v = torch.zeros_like(sigma_squared_v)
if self.args.zero_key_measurement_noise == 1:
    eta_squared_v = torch.zeros_like(eta_squared_v)

nu = F.softplus(self.nu_sqrt) + self.args.epsilon # Scaling factor

#         print('sigma^2 = ', sigma_squared_v.detach().cpu().numpy())
#         print('eta^2   = ', eta_squared_v.detach().cpu().numpy())
#         print('gamma^2 = ', gamma_squared_v.detach().cpu().numpy())
#         print('nu      = ', nu.detach().cpu().numpy())

# Get relative time difference |t_i - t_j|
t_measure_i = t_measure.squeeze().unsqueeze(1)  # [m, 1]
t_measure_j = t_measure.squeeze().unsqueeze(0)  # [1, m]
Delta_T = torch.abs(t_measure_i - t_measure_j).unsqueeze(-1)  # [m, m, 1]

# Clamp the exponent to ensure safe values
exp_rel_v = mu_v * Delta_T
exp_rel_safe_v = torch.clamp(exp_rel_v, min=self.args.min_exponent, max=self.args.max_exponent)

####################################################

Phi_tilde_plus_v, E_rel_v = compute_exp_kernel_isotropic(omega_v, t_measure, exp_rel_safe_v)

V_ij_v = compute_covariance_matrix_safe(mu_v, Delta_T, exp_rel_safe_v, sigma_squared_v, eta_squared_v, gamma_squared_v, t_measure, self.args)

if self.args.sep_params == 1:
    # Separate dynamics/noise parameters for the keys
    mu_k, omega_k, lambda_k = compute_lambda(self.mu_k, self.lambda_imag_k, self.args)

    sigma_squared_k = F.softplus(self.sigma_k) # Process noise
    eta_squared_k = F.softplus(self.eta_k) + self.args.epsilon # Measurement noise
    gamma_squared_k = F.softplus(self.gamma_k) + self.args.epsilon # Anchor noise

    if self.args.zero_process_noise == 1:
        sigma_squared_k = torch.zeros_like(sigma_squared_k)
    if self.args.zero_key_measurement_noise == 1:
        eta_squared_k = torch.zeros_like(eta_squared_k)

    #################

    # The clamped exponent must exist before it is passed to the kernel
    # (it was previously computed after this call -> NameError).
    exp_rel_k = mu_k * Delta_T
    exp_rel_safe_k = torch.clamp(exp_rel_k, min=self.args.min_exponent, max=self.args.max_exponent)

    Phi_tilde_plus_k, E_rel_k = compute_exp_kernel_isotropic(omega_k, t_measure, exp_rel_safe_k)

    V_ij_k = compute_covariance_matrix_safe(mu_k, Delta_T, exp_rel_safe_k, sigma_squared_k, eta_squared_k, gamma_squared_k, t_measure, self.args)

else:
    # Shared parameters: keys reuse the value-side quantities
    V_ij_k = V_ij_v
    Phi_tilde_plus_k = Phi_tilde_plus_v
    E_rel_k = E_rel_v
    sigma_squared_k = sigma_squared_v
    eta_squared_k = eta_squared_v
    gamma_squared_k = gamma_squared_v

Q_tilde = apply_interleaved_rope(Q_norm, cos_k, -sin_k)
K_tilde = apply_interleaved_rope(K_norm, cos_k, -sin_k)
V_tilde = apply_interleaved_rope(V_norm, cos_v, -sin_v)

R_qk_abs_squared = compute_residual_norm_isotropic(Q_tilde, K_tilde, E_rel_k, self.args)

A, A_prior, unnormalized_attention = compute_attention_matrix(self, R_qk_abs_squared, V_ij_k, V_ij_v, Delta_T, nu, sigma_squared_k, eta_squared_k, self.args)

A_hat = A * E_rel_v # Scale attention weights by the relative decay

est_v = compute_estimate(A_hat, V_tilde, Phi_tilde_plus_v, self.args)

#############################################

# Add Gaussian noise to last step (for test-time sampling)
if self.args.add_gaussian_noise == 1:
    P_tot = torch.sum(unnormalized_attention,-2) # Total precision
    V_tot = 1/(P_tot) # Total variance
    V_tot_last = V_tot[:,-1]
    V_tot_last = V_tot_last + gamma_squared_v
    std_dev = torch.sqrt(V_tot_last/2) # Divide by 2 to deal with complex numbers
    gaussian_noise = torch.randn_like(std_dev) * std_dev # Sample from normal distrb
    gaussian_noise_unsqueeze = gaussian_noise.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
    est_v[:, -1:] = est_v[:, -1:] + gaussian_noise_unsqueeze # Add Gaussian noise to last step

# Add residual connection
if self.args.use_inner_residual == 1:
    if self.args.use_total_precision_gate == 0: # Simple residual
        est_latent = est_v + V # Just add them

    elif self.args.use_total_precision_gate == 1: # Precision gate
        # P_tot / (P_tot + P_prior) = 1 / (1 + P_prior/P_tot)
        # = 1/(1 + e^{-(ln(P_tot) - ln(P_prior))}) = sigmoid[ln(P_tot) - ln(P_prior)]
        P_tot = torch.sum(unnormalized_attention,-2).unsqueeze(-1).unsqueeze(-1)
        P_tot_log = torch.log(P_tot)
        P_prior_log = torch.log(self.P_prior_param**2 + self.args.epsilon)
        g = torch.sigmoid(P_tot_log * self.P_scale**2 - P_prior_log) # Mean precision gate
        est_latent = g * est_v + (1-g) * V # Update using convex combination

    elif self.args.use_total_precision_gate == 2: # Learned gate
        g = torch.sigmoid(self.P_scale) # Learned gate
        est_latent = g * est_v + (1-g) * V # Update using convex combination
#                 est_latent = g * est_v + V # No gate on residual connection
    else:
        # Fail here rather than with a NameError on est_latent further down
        raise ValueError('Error: args.use_total_precision_gate must be 0, 1, or 2.')
else:
    est_latent = est_v # No residual

# Optionally, use complex normalization on outputs
# Optionally, use complex normalization on outputs (per head, before the output projection)
if self.args.use_complex_output_norm == 1:
    est_norm = self.cn_o(est_latent)
else:
    est_norm = est_latent

# Predict next step using same matrix exponential used for estimation:
# decay by exp(mu_v) and rotate by the angles stored at index 1 of Phi_tilde_plus_v
# (index 1 presumably corresponds to the second time stamp, one measurement step after the first).
# NOTE(review): exp(mu_v) is a decay over one unit of time. If t_measure is not
# unit-spaced (e.g. real times with delta_t = 0.1), the decay and rotation step sizes may differ — confirm.
if self.args.compute_next_step_pred == 1:
    cos_v_one_step = Phi_tilde_plus_v[1,:,:,0].unsqueeze(0)
    sin_v_one_step = Phi_tilde_plus_v[1,:,:,1].unsqueeze(0)
    pred_p = torch.exp(mu_v).unsqueeze(-1).unsqueeze(-1) * apply_interleaved_rope(est_norm, cos_v_one_step, sin_v_one_step)

else:
    pred_p = est_norm

# -------------------
# Move real/imag dimension back: (batch, seq_len, 2, n_heads, d_v_head)
pred_p_permute = pred_p.permute(0,1,4,2,3).contiguous()

# Merge heads: (batch, seq_len, 2, d_v_total)
pred_p_reshape = pred_p_permute.view(batch_size, seq_len, 2, self.d_v_total)     

# Stack complex numbers into last dimension: (batch, 1, seq_len, 2*d_v_total), [real | imag]
pred_p_stack = pred_p_reshape.view(batch_size, 1, seq_len, self.d_v_total*2)

#         # Move real/imag dimension back, merge heads, and stack complex numbers into last dimension
#         pred_p_stack = pred_p.permute(0, 1, 4, 2, 3).reshape(batch_size, 1, seq_len, -1)

# Map back to original basis and get real part
out = self.W_o(pred_p_stack)

In [None]:
# # Save Model
# model_weight_path_full = model_weight_path + 'weights_epoch_' + str(epoch)
# model_tensor_path_full = model_tensor_path + 'tensors_epoch_' + str(epoch)
# torch.save(model.state_dict(), model_weight_path_full)

# tensors_to_save = {
# 'epoch_losses': epoch_losses,
# 'A_hat_complex': A_hat_complex,
# 'epoch_lambdas': epoch_lambdas
# }
# torch.save(tensors_to_save, model_tensor_path_full)

In [None]:
# ##### LOOP FOR REAL/COMPLEX ATTENTION NETS #######

# for epoch in tqdm(np.arange(args.num_epochs), desc="Training progress..."):

#     epoch_losses = single_epoch_attn(model, train_loader, optimizer, loss, params_list, args)

#     # Collect losses
#     all_losses[epoch*args.num_its:(epoch+1)*args.num_its] = epoch_losses
#     log_mean_epoch_losses[epoch] = np.log(np.mean(epoch_losses))
    
#     # Visualize results so far:
#     if np.mod(epoch+1,args.show_example_epochs) == 0:
#         visualize_results_attn(model, train_dataset, all_losses, mean_epoch_losses, log_mean_epoch_losses, all_lambdas, R1, R1i, Pu, Pd, A, epoch, args)
# ##################################################

In [None]:
# Sample the model autoregressively
def autoregressive_sample(model, start_seq, max_gen_len):
    """Generate ``max_gen_len`` tokens autoregressively, starting from ``start_seq``.

    At every step the model runs on a sliding window as long as ``start_seq``
    along dim 2. The output at the window's last position becomes the next
    token, which is appended to the sequence.

    Args:
        model: Module whose forward returns ``(out, extra)``, where ``out`` has
            the same layout as its input (tokens along dim 2).
        start_seq: Context tensor with tokens along dim 2,
            e.g. (batch, 1, seq_len, features).
        max_gen_len: Number of tokens to generate (>= 0).

    Returns:
        ``(total_seq, new_seq)``: the context followed by the generated tokens,
        and the generated tokens alone (``max_gen_len`` entries along dim 2).

    The model is switched to eval mode and gradients are disabled while
    sampling. Its previous train/eval mode is restored afterwards.
    """
    window_size = start_seq.size(2)
    was_training = model.training

    try:
        model.eval()
        with torch.no_grad():
            current_seq = start_seq
            total_seq = start_seq

            for _ in range(max_gen_len):
                # Forward pass through the model
                out, _output_dict = model(current_seq)

                # Prediction at the last position = next token (slice keeps dim 2)
                next_token = out[:, :, -1:]

                # Append new token
                total_seq = torch.cat([total_seq, next_token], dim=2)

                # Slide the context window so its length stays fixed
                current_seq = total_seq[:, :, -window_size:]
    finally:
        model.train(was_training)

    # Slice from the context length (not [-max_gen_len:]): correct for max_gen_len == 0
    # and independent of any global configuration.
    new_seq = total_seq[:, :, window_size:]

    return total_seq, new_seq
    
# NOTE: this mutates the shared `args`, so the model (if it reads the same object) keeps
# adding sampling noise afterwards; reset it to 0 before any further training.
args.add_gaussian_noise = 1 # Add Gaussian noise to the final token while sampling
args.max_gen_len = 100 # Number of tokens to generate

# Use the first training batch, without its last step, as the sampling context
train_data, *_ = next(iter(train_loader))
start_seq = train_data[:, :-1].unsqueeze(1)

total_seq, new_seq = autoregressive_sample(model, start_seq, args.max_gen_len)

traj_tot = torch.matmul(R1i,total_seq.unsqueeze(-1)) # Reverse random mapping
X_tot = torch.matmul(Pd,traj_tot) # Map back to lower dim

traj_new = torch.matmul(R1i,new_seq.unsqueeze(-1)) # Reverse random mapping
X_new = torch.matmul(Pd,traj_new) # Map back to lower dim

# Context + generated trajectory (entry [0] along the leading axis)
X_tot_plt = X_tot.squeeze(0)[0].detach().cpu().squeeze().numpy()
plt.plot(X_tot_plt.T[0], X_tot_plt.T[1], 'r--', label='Context + generated')

# Generated part only
X_new_plt = X_new.squeeze(0)[0].detach().cpu().squeeze().numpy()
plt.plot(X_new_plt.T[0], X_new_plt.T[1], 'b', label='Generated')

plt.legend()
plt.grid()