## Test Multihead Functionality

In [1]:
import numpy as np
import math
import scipy

from matplotlib import pyplot as plt
import matplotlib.cm as cm
plt.rcParams['figure.figsize'] = [10, 10]
plt.rc('font', size=20)

import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.utils.data import Dataset, DataLoader

import os
import argparse
import time
from tqdm import tqdm # Loading bar
print('Done.')

Done.


In [2]:
from utils import complex_conj_transpose, batched_complex_conj_transpose, complex_exp, complex_exp_v2, complex_hadamard, complex_matmul, complex_division
from utils import batched_complex_conj_transpose, batched_complex_hadamard, batched_complex_matmul, batched_complex_division
from utils import batched_complex_exp, batched_complex_hadamard_full, batched_complex_matmul_full
print('Done.')

Done.


In [3]:
from dynamics import stochastic_LTI, DynamicSim
from dynamics import construct_mapping
from dynamics import get_nth_measurement, get_random_measurements
from dynamics import linear_spiral, linear_spiral_3D, Lorenz, rand_coupling_matrix, Van_der_Pol_osc
print('Done.')

Done.


In [4]:
from precision_attention import compute_residuals, compute_kernel_v1, compute_estimates_and_residuals_vectorized, get_time_diffs, compute_neg_kernel, clamp_exponent_arg
from precision_attention import compute_kernel, batched_compute_estimates_and_residuals_vectorized, compute_estimates_and_residuals_irregular_times, compute_nu
from precision_attention import compute_precision_v1
# from precision_attention import precise_attn, precise_attn_with_correction, precise_attn_full
from precision_attention import compute_precision, compute_precision_tanh
print('Done.')

Done.


In [5]:
from model import compute_lambda_h
from model import init_complex_matrix, build_nearly_identity, initialize_to_correct_model
from model import init_weight_masks, apply_weight_masks
from model import Complex_MSE_Loss, Batched_Complex_MSE_Loss, inverse_penalty
from model import BatchedPrecisionAttentionBlock
from model import HadamardLayer, TemporalNorm, TemporalWhiteningLayer
from model import PrecisionNet_1layer, PrecisionNet
print('Done.')

Done.


In [6]:
from data_utils import construct_random_mapping, construct_data, TrainDataset, create_train_loader
print('Done.')

Done.


In [None]:
from visualization import plot_trajectory, compute_state_transition_matrix, plot_state_transition_matrix, plot_eigenvals, visualize_results
print('Done.')

In [7]:
from training import single_iter, single_epoch, hook_fn
print('Done.')

Done.


In [8]:
# Runtime configuration: the argparse namespace doubles as the notebook-wide settings object.
parser = argparse.ArgumentParser('DA')
parser.add_argument('--gpu', default=0, type=int) # Index of the CUDA device to use (Default: 0)
args = parser.parse_args(args=[]) # Parse an empty list so only the defaults are used

# Run on the selected GPU when CUDA is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    args.device = torch.device('cuda:' + str(args.gpu))
else:
    args.device = torch.device('cpu')
print(args.device)

# Fix the NumPy and PyTorch seeds so that repeated runs start from the same random state
np.random.seed(2025)
torch.manual_seed(2025)

cuda:0


### Build dataset

In [9]:
# Set dynamical system

# # Normal system
# D1 = torch.zeros(2,2,2).to(args.device) # Diagonal matrix
# D1[0] = torch.tensor([[-0.1, 0.0], [0.0, -0.1]]).to(args.device)
# D1[1] = torch.tensor([[-1.0, 0.0], [0.0, 1.0]]).to(args.device)
# # S1 = torch.zeros(2,2,2).to(args.device)
# # S1[0] = torch.tensor(([1.0,1.0],[0.0,0.0]))
# # S1[1] = torch.tensor(([0.0,0.0],[1.0,-1.0]))
# # S1 = U/np.sqrt(2)
# alpha = np.random.uniform(low=0.0, high=1.0)*2*np.pi
# beta = np.random.uniform(low=0.0, high=1.0)*2*np.pi
# S1 = construct_special_2D_unitary(alpha=alpha, beta=beta)
# Si1 = complex_conj_transpose(S1)

# Stable 2D linear system (in diagonalized form).
# Every matrix is a real tensor of shape (2, 2, 2). Slices 0 and 1 appear to hold the
# real and imaginary parts: with that reading Si1 is exactly the inverse of S1.

# Diagonal matrix
D1 = torch.stack((
    torch.tensor([[-0.1, 0.0], [0.0, -0.1]], device=args.device),
    torch.tensor([[-1.0, 0.0], [0.0, 1.0]], device=args.device),
))
# Alternative diagonal matrix:
# D1[0] = torch.tensor([[-0.1, 0.0], [0.0, -0.5]]).to(args.device)
# D1[1] = torch.tensor([[-0.0, 0.0], [0.0, 0.0]]).to(args.device)

# RHS matrix
S1 = torch.stack((
    torch.tensor([[1.0, 1.0], [1.0, 1.0]], device=args.device),
    torch.tensor([[-1.0, 1.0], [0.0, 0.0]], device=args.device),
))

# Inverse of RHS matrix
Si1 = 0.5 * torch.stack((
    torch.tensor([[0.0, 1.0], [0.0, 1.0]], device=args.device),
    torch.tensor([[1.0, -1.0], [-1.0, 1.0]], device=args.device),
))

# State matrix A = S1 * D1 * Si1; only slice 0 of the product is kept, with a leading axis of size 1
A = complex_matmul(S1, complex_matmul(D1, Si1))
A = A[0].unsqueeze(0)
params = [D1, S1, Si1]

In [10]:
# DEFINE MODEL
# Fills the shared `args` namespace with the model / simulation hyperparameters, then builds
# the network, the random mapping matrices, the loss functions and the optimizer.

# NOTE(review): the meaning of cr_max / cr_min is not visible in this notebook - document at the point of use
args.cr_max = 2
args.cr_min = 0

args.t_equal = 1 # Equal time intervals? (0 or 1)

args.m = 2 # Dimension of simulated system
# args.embed_dim = 2 # Embedding dimension
args.embed_dim = 256 # Embedding dimension

# Single-head baseline: with num_heads = 1 the head dimension equals the embedding dimension
args.num_heads = 1
args.head_dim = int(args.embed_dim/args.num_heads)

#######################################

# Key, query, value embedding dimensions are same as input embedding dimension
args.d_k = args.head_dim
args.d_v = args.head_dim

# # Key, query, value embedding dimensions are half of input embedding dimension
# args.d_k = int(args.head_dim/2) # Key and query embedding dimension
# args.d_v = int(args.head_dim/2) # Value embedding dimension

# # Key, query, value embedding dimensions are 2
# args.d_k = 2 # Key and query embedding dimension
# args.d_v = 2 # Value embedding dimension

#######################################

args.nu = 1.0 # Measurement weighting
args.tf = 10.0 # Final time
args.dt = 0.01 # Time step size
args.n = 10 # nth measurement (ie use every nth point as a measurement)

args.N_t = int(args.tf/args.dt) # Number of time steps
args.seq_len = int(args.N_t/args.n) # Number of measurements
# The time grid has N_t + n entries, i.e. it extends n steps (one measurement interval) past N_t
t_v = (torch.arange(args.N_t + args.n)*args.dt).to(args.device) # Array of time steps

# Some scalar weights in model
args.alpha = 1.0
args.beta = 0.0
args.delta = 1.0
args.eta = 0.0

# NOTE(review): the model is constructed before the random mapping; both presumably draw from
# the global RNG seeded earlier, so reordering these statements may change the results - confirm.
# model = PrecisionAttentionBlock(args).to(args.device)
model = PrecisionNet_1layer(args).to(args.device) # Define model
params_list = list(model.parameters()) # Parameters list

Pu, Pd, R1, R1i = construct_random_mapping(S1, Si1, args) # Get random matrices

loss = Batched_Complex_MSE_Loss() # Loss
loss_p = Complex_MSE_Loss() # Frobenius Norm Penalty
# loss_p = Complex_Trace_Loss() # Trace Penalty
lr = 1E-2 # Learning rate
optimizer = torch.optim.Adam(params_list, lr=lr, betas=(0.9, 0.999)) # Optimizer

In [11]:
# Training configuration, logging buffers, noise levels and construction of the training set.
args.num_epochs = 1000 # Number of epochs
args.num_samp = 32 # Number of samples in train loader
args.batch_size = 32 # Batch size
args.num_its = int(args.num_samp/args.batch_size) # Number of iterations in an epoch
args.save_epochs = 1 # Intervals of epochs to save model
args.show_example_epochs = 4 # Number of epochs between displaying results so far
args.n_example = 5 # Plot state estimates at n_example data points

# args.penalty_weight = 0.1 # Penalty weight
args.penalty_weight = 1.0 # Penalty weight
args.weight_mask = 1 # Mask the weights? (0 or 1)
args.tanh = 0 # Use tanh precision? (0 or 1)
args.nu_adaptive = 0 # Use adaptive calculation of nu? (0 or 1)

# Pre-allocated logging buffers, one entry per epoch unless noted otherwise
mean_epoch_losses = np.zeros(args.num_epochs)
log_mean_epoch_losses = np.zeros(args.num_epochs)
# NOTE(review): sized num_epochs * num_samp (per sample), not num_epochs * num_its (per iteration) - confirm against the training loop
all_losses = np.zeros(args.num_epochs * args.num_samp)
all_lambdas = np.zeros((args.num_epochs, 2, args.d_v)) # Shape: (num_epochs, 2, d_v)

sigma_process = 0.0 # Process noise
# sigma_process = 0.2 # Process noise
sigma_process_0 = sigma_process # Initial process noise
# sigma_measure = 0.1 # Measurement noise
sigma_measure = 1.0 # Measurement noise

# Build training dataset
# Returns the loader and dataset plus three further values bound to X_true_all, X_measure_all, t_measure_all
train_loader, train_dataset, X_true_all, X_measure_all, t_measure_all = create_train_loader(A, S1, Si1, Pu, Pd, R1, R1i, t_v, sigma_process, sigma_process_0, sigma_measure,args)

## Experiments

In [38]:
def _split_heads(self, x: torch.Tensor, batch_size: int) -> torch.Tensor:
    """
    Splits the input tensor into multiple heads and prepares for attention.
    (batch_size, 2, seq_len, embed_dim) -> (batch_size * num_heads, 2, seq_len, head_dim)
    """
    seq_len = x.size(2)
    # Reshape to (batch_size, 2, seq_len, num_heads, head_dim)
    x = x.view(batch_size, 2, seq_len, self.num_heads, self.head_dim)
    # Permute to (batch_size, num_heads, seq_len, head_dim)
    x = x.permute(0, 3, 1, 2, 4)
    # Reshape to (batch_size * num_heads, 2, seq_len, head_dim) for batched attention
    return x.reshape(batch_size * self.num_heads, 2, seq_len, self.head_dim)

def _combine_heads(self, x: torch.Tensor, batch_size: int) -> torch.Tensor:
    """
    Combines multiple attention heads back into a single tensor.
    (batch_size * num_heads, 2, seq_len, head_dim) -> (batch_size, 2, seq_len, embed_dim)
    """
    seq_len = x.size(2)
    # Reshape to (batch_size, num_heads, 2, seq_len, head_dim)
    x = x.view(batch_size, self.num_heads, 2, seq_len, self.head_dim)
    # Permute to (batch_size, 2, seq_len, num_heads, head_dim)
    x = x.permute(0, 2, 3, 1, 4)
    # # Reshape to (batch_size, seq_len, embed_dim)
    return x.reshape(batch_size, 2, seq_len, self.d_e)
    
# Multi-head smoke test: project one batch into q / k / v, split the projections into heads
# and run the batched precision attention block once on the per-head tensors.

args.num_heads = 4
assert args.embed_dim % args.num_heads == 0, "embed_dim must be divisible by num_heads"
args.head_dim = int(args.embed_dim/args.num_heads)
self = args # _split_heads / _combine_heads read num_heads and head_dim from `self`

# Take the inputs from the first training batch (all but the last time step) so that this
# cell does not depend on an `X` left over from a previous kernel session.
train_data = next(iter(train_loader))[0]
X = train_data[:, :, :-1].to(args.device)
x = X
batch_size = args.batch_size

X_q = X_k = X_v = X

# x = _split_heads(self, X, args.batch_size)
# # Xo = _combine_heads(self, x, args.batch_size)
# # torch.sum(X-Xo)

# Real-valued linear projections acting on the last (embedding) axis
q_proj = nn.Linear(args.embed_dim, args.embed_dim).to(args.device)
k_proj = nn.Linear(args.embed_dim, args.embed_dim).to(args.device)
v_proj = nn.Linear(args.embed_dim, args.embed_dim).to(args.device)

q = q_proj(X_q) # (B, 2, S, E)
k = k_proj(X_k) # (B, 2, S, E)
v = v_proj(X_v) # (B, 2, S, E)

q = _split_heads(self, q, batch_size) # (B * num_heads, 2, S, head_dim)
k = _split_heads(self, k, batch_size)
v = _split_heads(self, v, batch_size)
assert q.shape == (batch_size * args.num_heads, 2, X.size(2), args.head_dim)

attention_core = BatchedPrecisionAttentionBlock(args.head_dim, args).to(args.device)

# The block returns five values; only the second one is inspected here
out = attention_core(q, k, v, t_measure_all)[1]

out.size()
# Xo = _combine_heads(self, out, args.batch_size)

# Xo.size()

torch.Size([128, 2, 100, 64])

In [12]:
# # Batched precision attention for experiments (original)

# Pu, Pd, R1, R1i = construct_random_mapping(S1, Si1, args)
# # train_loader, train_dataset, X_true_all, X_measure_all, t_measure = create_train_loader(A, S1, Si1, Pu, Pd, R1, R1i, t_v, sigma_process, sigma_process_0, sigma_measure,args)

# for it, (train_data, X_true, X_measure, t_v_all) in enumerate(train_loader):

#     inputs  = train_data[:, :, :-1]
#     outputs = train_data[:, :, 1:]
    
#     break

# X = inputs    
# X_q = X
# X_k = X
# X_v = X

# self = model.a1

# ######################################################################################################
# ######################################################################################################

# self.lambda_h = compute_lambda_h(self.lambda1,self.args) # Get nonpositive complex conjugate eigenvalues

# # Square the noise parameters to ensure positive definiteness / non-negativeness
# self.lambda_Omega = self.lambda_Omega_sqrt**2 # Process noise matrix
# self.lambda_Omega0 = self.lambda_Omega0_sqrt**2 # Initial process noise uncertainty matrix
# self.lambda_Gamma = self.lambda_Gamma_sqrt**2 # Measurement noise matrix

# ############ (Masking; used for testing) ###########
# lambda_h, lambda_Omega, lambda_Omega0, lambda_Gamma, W_q, W_k, W_v, W_p, W_r, W_e, W_q_b, W_k_b, W_v_b, W_p_b, W_r_b, W_e_b = apply_weight_masks(self, self.args)
# ####################################################

# X_q = X_q.unsqueeze(-1)
# X_k = X_k.unsqueeze(-1)
# X_v = X_v.unsqueeze(-1)

# # Project input into Q, K, V
# #     Q = batched_complex_matmul(W_q, X_q)
# #     K = batched_complex_matmul(W_k, X_k)
# #     V = batched_complex_matmul(W_v, X_v)
# Q = batched_complex_matmul(W_q, X_q) + W_q_b
# K = batched_complex_matmul(W_k, X_k) + W_k_b
# V = batched_complex_matmul(W_v, X_v) + W_v_b

# #     R = batched_complex_matmul(W_r, X_v)

# # G1 = torch.sigmoid(self.G1)
# # G = torch.stack((G1,torch.zeros(self.args.seq_len,self.args.embed_dim,1)))
# # IG = torch.stack((1 - G1,torch.zeros(self.args.seq_len,self.args.embed_dim,1)))

# if len(t_measure_all.size()) > 1:
#     t_measure = t_measure_all[0,:-1]
# else:
#     t_measure = t_measure_all[:,:-1]

# # Functionality for possibly unequal time intervals
# if self.args.t_equal == 1: # If equal time intervals
#     K_exp, K_exp2 = compute_kernel(lambda_h, t_measure)
#     X_ij_hat_all, R_qk_ij = batched_compute_estimates_and_residuals_vectorized(Q, K, V, K_exp, self.args)
#     mat_exp = K_exp[:, -(self.args.seq_len+1), :, :] # Get matrix exponential for next-state prediction
# else: # If unequal time intervals
#     X_ij_hat_all, R_qk_ij = compute_estimates_and_residuals_irregular_times(lambda_h, t_measure_all[:,:-1], Q, K, V, self.args)
#     mat_exp = batched_complex_exp(lambda_h.squeeze(1).unsqueeze(0) * (t_measure_all[:,-1] - t_measure_all[:,-2]).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1))
#     K_exp2 = None

# if self.args.tanh == 0:
#     P_ij, nu = compute_precision(lambda_h, lambda_Omega, lambda_Omega0, lambda_Gamma, K_exp2, t_measure_all[:,:-1], self.args, R_qk_ij=R_qk_ij, alpha_nu=self.alpha_nu, beta_nu=self.beta_nu, lambda_C=self.lambda_C)
# else:
#     P_ij, nu = compute_precision_tanh(lambda_h, lambda_Omega, lambda_Omega0, lambda_Gamma, K_exp2, t_measure_all[:,:-1], self.args, R_qk_ij=R_qk_ij, alpha_nu=self.alpha_nu, beta_nu=self.beta_nu, lambda_C=self.lambda_C)

# # Compute unnormalized attention matrix
# mahalanobis_distance = P_ij * (R_qk_ij[:,0]**2 + R_qk_ij[:,1]**2)
# denom = (1 + nu*torch.sum(mahalanobis_distance, axis=3, keepdims = True))
# A_ij = P_ij / denom

# A_ij = A_ij * self.causal_mask # Apply causal mask to attention matrix
# X_ij_hat_all = X_ij_hat_all * self.causal_mask # Mask out estimates backward in time (not strictly necessary but useful later for visualization)

# # Normalize attention
# S_ij = torch.sum(A_ij, axis=2, keepdims = True)
# Q_ij = A_ij / S_ij

# # Compute Hadamard product and sum to get estimate in diagonalized space
# est_v = torch.sum(Q_ij.unsqueeze(1) * X_ij_hat_all,axis=3)

# # Add residual connection
# est_eigenbasis = est_v # No residual connection
# #     est_e = self.args.alpha*est_v + self.args.beta*V # JUST FOR TESTING
# #     est_e = est_v + self.alpha*(est_v - V) # JUST FOR TESTING
# #     est_e = est_v + R
# #     est_e = est_v + batched_complex_matmul(W_r, est_v - V)

# # Multiply by output matrix to get estimate
# #     est = batched_complex_matmul(W_e,est_eigenbasis)
# est = batched_complex_matmul(W_e,est_eigenbasis) + W_e_b
# #     est = batched_complex_matmul(W_p,est_eigenbasis)

# # Get prediction in diagonalized space
# #     pred_p = batched_complex_hadamard(mat_exp, est_e)
# #     pred_p = batched_complex_hadamard(lambda_h, est_e)*(self.args.n * self.args.dt) + est_e # JUST FOR TESTING
# #     pred_p = batched_complex_hadamard(mat_exp, V) # JUST FOR TESTING
# #     pred_p = batched_complex_hadamard(lambda_h, V)*(self.args.n * self.args.dt) + V # JUST FOR TESTING
# if self.args.t_equal == 1: # If equal time intervals
#     pred_p = batched_complex_hadamard(mat_exp, est_eigenbasis)
# else:
#     pred_p = batched_complex_hadamard_full(mat_exp.unsqueeze(2), est_eigenbasis)

# # Multiply by output matrix to get output prediction
# #     pred = batched_complex_matmul(self.W_p, pred_p)
# #     pred = batched_complex_matmul(W_p, pred_p)
# pred = batched_complex_matmul(W_p, pred_p) + W_p_b
# #     pred = batched_complex_matmul(self.W_p, batched_complex_hadamard(lambda_h, X_v))*self.args.dt + X_v # JUST FOR TESTING

# # Output is a linear combination of estimate and prediction
# out = self.args.delta*pred + self.args.eta*est
# #     out = self.delta*pred + self.eta*est
# #     out = pred + est

# est = est.squeeze(-1)
# out = out.squeeze(-1)
# X_ij_hat_all = X_ij_hat_all.squeeze(-1)

In [37]:
# Reshape the eigenvalue tensor so that it carries an explicit head axis: (2, num_heads, 1, head_dim, 1).
# NOTE(review): `lambda1`, `num_heads` and `head_dim` are not defined in any visible cell, so this
# cell relies on leftover kernel state and will fail after Restart & Run All.
# Presumably `lambda1` is a model parameter (the commented-out experiment cell reads `self.lambda1`
# with `self = model.a1`) and `num_heads` / `head_dim` mirror `args.num_heads` / `args.head_dim` - confirm.
lambda_h = compute_lambda_h(lambda1, args).view(2,num_heads,1,head_dim,1)
lambda_h.size()

torch.Size([2, 4, 1, 64, 1])