### Test forward pass of network

In [1]:
import numpy as np
import math
import scipy

from matplotlib import pyplot as plt
import matplotlib.cm as cm
plt.rcParams['figure.figsize'] = [10, 10]
plt.rc('font', size=20)

import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.utils.data import Dataset, DataLoader

import os
import argparse
import time
from tqdm import tqdm # Loading bar
print('Done.')

Done.


In [2]:
from utils import complex_conj_transpose, batched_complex_conj_transpose, complex_exp, complex_exp_v2, complex_hadamard, complex_matmul, complex_division
from utils import batched_complex_conj_transpose, batched_complex_hadamard, batched_complex_matmul, batched_complex_division
from utils import batched_complex_exp, batched_complex_hadamard_full, batched_complex_matmul_full
print('Done.')

Done.


In [3]:
from dynamics import stochastic_LTI, construct_mapping
from dynamics import get_nth_measurement, get_random_measurements
print('Done.')

Done.


In [4]:
from precision_attention import compute_residuals, compute_kernel_v1, compute_estimates_and_residuals_vectorized, get_time_diffs, compute_neg_kernel, clamp_exponent_arg
from precision_attention import compute_kernel, batched_compute_estimates_and_residuals_vectorized, compute_estimates_and_residuals_irregular_times, compute_nu
from precision_attention import compute_precision_v1
from precision_attention import precise_attn, precise_attn_with_correction, precise_attn_full
print('Done.')

Done.


In [5]:
from model import compute_lambda_h
from model import init_complex_matrix, build_nearly_identity, initialize_to_correct_model
from model import init_weight_masks, apply_weight_masks
from model import Complex_MSE_Loss, Batched_Complex_MSE_Loss, inverse_penalty
from model import BatchedPrecisionAttentionBlock
from model import HadamardLayer, TemporalNorm, TemporalWhiteningLayer
from model import PrecisionNet_1layer, PrecisionNet
print('Done.')

Done.


In [12]:
from data_utils import construct_random_mapping, construct_data, TrainDataset, create_train_loader
print('Done.')

Done.


In [6]:
from training import plot_trajectory, compute_state_transition_matrix, plot_state_transition_matrix, plot_eigenvals, visualize_results
from training import single_iter, single_epoch, hook_fn
print('Done.')

Done.


In [7]:
# Build the argument namespace shared by the rest of the notebook. An empty
# argument list is parsed, so only the declared defaults are used.
parser = argparse.ArgumentParser('DA')
parser.add_argument('--gpu', type=int, default=0)  # Index of the CUDA device to use
args = parser.parse_args(args=[])

# Use the requested GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    args.device = torch.device('cuda:' + str(args.gpu))
else:
    args.device = torch.device('cpu')
print(args.device)

# Seed the PyTorch and NumPy global generators with the same value.
torch.manual_seed(2025)
np.random.seed(2025)

cuda:0


In [8]:
# Defines a stable 2D linear system through the factorisation A = S1 * D1 * Si1.
# Each complex matrix is stored as a real (2, 2, 2) tensor; index 0 is taken to
# be the real part and index 1 the imaginary part (the layout the complex_*
# helpers from utils appear to use -- confirm there).
D1 = torch.zeros(2, 2, 2, device=args.device)   # Diagonal (eigenvalue) matrix
S1 = torch.zeros(2, 2, 2, device=args.device)   # Eigenvector matrix (invertible, but not unitary)
Si1 = torch.zeros(2, 2, 2, device=args.device)  # Inverse of S1

# Under that layout the eigenvalues are -0.1 -/+ 1j: negative real part, hence stable.
D1[0] = torch.tensor([[-0.1, 0.0], [0.0, -0.1]], device=args.device)
D1[1] = torch.tensor([[-1.0, 0.0], [0.0, 1.0]], device=args.device)
# S1 = [[1-1j, 1+1j], [1, 1]]
S1[0] = torch.tensor([[1.0, 1.0], [1.0, 1.0]], device=args.device)
S1[1] = torch.tensor([[-1.0, 1.0], [0.0, 0.0]], device=args.device)
# Si1 = 0.5 * [[1j, 1-1j], [-1j, 1+1j]], which satisfies S1 @ Si1 = I
Si1[0] = 0.5*torch.tensor([[0.0, 1.0], [0.0, 1.0]], device=args.device)
Si1[1] = 0.5*torch.tensor([[1.0, -1.0], [-1.0, 1.0]], device=args.device)

# Insert a singleton axis after the real/imaginary axis: (2, 2, 2) -> (2, 1, 2, 2)
D1 = D1.unsqueeze(1)
S1 = S1.unsqueeze(1)
Si1 = Si1.unsqueeze(1)

A = complex_matmul(S1,complex_matmul(D1,Si1))[0] # State transition matrix (index 0 of the product)

# Noise levels handed to stochastic_LTI when the system is simulated
sigma_process = 0.0
sigma_process_0 = 0.0
sigma_measure = 0.1

args.m = 2 # Dimension of system
args.tf = 10 # Final time
args.dt = 0.01 # Time step size
# round() instead of int(): a float quotient that lands just below an integer
# (e.g. 0.3/0.1) would be truncated by int() and silently lose one time step.
args.N_t = round(args.tf/args.dt) + 1 # Number of time steps
args.seq_len = args.N_t

t_v = torch.linspace(0, args.tf, args.N_t).to(args.device) # Vector of times

# Drawn on the CPU and then moved, so the sample does not depend on the device.
x0 = (torch.randn(args.m)*10).unsqueeze(0).unsqueeze(-1).to(args.device) # Initial condition, shape (1, m, 1)

In [9]:
# Simulate the system from x0 with the configured noise levels.
X_true, X_measure_full = stochastic_LTI(
    A, x0, args.N_t, args,
    sigma_process=sigma_process,
    sigma_process_0=sigma_process_0,
    sigma_measure=sigma_measure,
)

# Sub-select measurements and their time stamps (n=10).
idxs, t_measure, X_measure = get_nth_measurement(X_measure_full, t_v, args.N_t, n=10)

# Two-channel container: channel 0 receives the measurements, channel 1 stays zero.
X = torch.zeros((2,) + tuple(X_measure.size()[axis] for axis in range(3))).to(args.device)
X[0] = X_measure

In [10]:
# Smoke test for the batched attention block: configure and instantiate the
# single-layer network.

# One-step-ahead pairs: drop the last / the first entry along dim 2.
inputs, outputs = X[:, :, :-1], X[:, :, 1:]

# Embedding configuration
vars(args).update(embed_dim=256, num_heads=1)
args.head_dim = int(args.embed_dim / args.num_heads)
args.d_k = args.d_v = args.head_dim  # Key/query and value embedding dimensions

args.nu = 1.0

args.seq_len = int(args.tf / args.dt)  # Number of time steps

# Time grid with seq_len points: the final point of the linspace is dropped.
t_v = torch.linspace(0, args.tf, args.seq_len + 1).to(args.device)[:-1]

model0 = PrecisionNet_1layer(args).to(args.device)

# The forward call itself is left disabled:
# model0.forward(inputs,t_v)

### Test forward pass of full network

In [13]:
# One hand-rolled optimisation step on PrecisionNet_1layer: build the data,
# run the forward pass, evaluate the loss, back-propagate and update.

# Option flags consumed by project code outside this notebook
# (their exact semantics are not visible here).
args.t_equal = 1
args.tanh = 0
args.weight_mask = 1
args.nu_adaptive = 0

# Embedding configuration
args.embed_dim = 256
args.num_heads = 1
args.head_dim = int(args.embed_dim/args.num_heads)
args.d_k = args.head_dim # Key and query embedding dimension
args.d_v = args.head_dim # Value embedding dimension

args.nu = 1.0
args.tf = 10.0 # Final time
args.dt = 0.01 # Time step size
# round() instead of int(): a float quotient that lands just below an integer
# would be truncated by int() and silently lose one time step.
args.N_t = round(args.tf/args.dt) # Number of time steps
args.n = 10 # Measurement stride (passed as n to get_nth_measurement)
args.seq_len = args.N_t // args.n # Evaluates to 100 for the values above

args.alpha = 1.0
args.beta = 0.0
args.delta = 1.0
args.eta = 0.0

t_v = torch.linspace(0, args.tf, args.N_t+1).to(args.device) # Vector of times

model0 = PrecisionNet_1layer(args).to(args.device)

loss = Batched_Complex_MSE_Loss()
lr = 1E-4
params_list = list(model0.parameters())
optimizer = torch.optim.Adam(params_list, lr=lr, betas=(0.9, 0.999))

Pu, Pd, R1, R1i = construct_random_mapping(S1, Si1, args)

# NOTE(review): X_true, X_measure, t_measure, X_high and X_random computed in
# this hand-rolled pipeline are all overwritten by construct_data() below.
# The pipeline is kept because it draws random numbers (x0, and presumably
# the simulation noise), so deleting it would change the data generated
# afterwards. Confirm whether it can be dropped.
x0 = (torch.randn(args.m)*10).unsqueeze(0).unsqueeze(-1).to(args.device) # Initial condition
X_true, X_measure_full = stochastic_LTI(A, x0, args.N_t+1, args, sigma_process=sigma_process, sigma_process_0=sigma_process_0, sigma_measure = sigma_measure) # Simulate system
idxs, t_measure, X_measure = get_nth_measurement(X_measure_full, t_v, args.N_t+1, n=args.n)

# Two-channel container: channel 0 receives the measurements, channel 1 stays zero.
X_measure_c = torch.zeros((2,X_measure.size()[0], X_measure.size()[1], X_measure.size()[2])).to(args.device)
X_measure_c[0] = X_measure
X_high = complex_matmul(Pu,X_measure_c) # Map to higher dim
X_random = complex_matmul(R1,X_high) # Map to random basis

# Data actually used for the training step.
X_random, X_high, X_true, X_measure, t_measure = construct_data(A, Pu, Pd, R1, R1i, args.N_t+1, t_v, sigma_process, sigma_process_0, sigma_measure, args)
# Apply R1i to X_random and add a leading singleton axis.
X = torch.matmul(R1i,X_random.unsqueeze(-1)).unsqueeze(0).squeeze(-1)

# One-step-ahead pairs: drop the last / the first entry along dim 2.
inputs = X[:, :, :-1]
outputs = X[:, :, 1:]

optimizer.zero_grad() # Zero out gradients

est, out, Q_ij, X_ij_hat_all, lambda_h = model0(inputs, t_measure.unsqueeze(0))

loss_i = loss(out, outputs)

loss_i.backward()

# Clip the total gradient norm to 1 before the update.
torch.nn.utils.clip_grad_norm_(params_list, 1)

optimizer.step()

In [14]:
# Time one call of single_iter() on freshly generated data.
# PrecisionNet_1layer(args) and construct_data(..., args) receive the whole
# namespace, so fields set by earlier cells and not overridden here remain in
# effect.

args.reconstruct = 0
args.latent_loss = 0
args.nu_adaptive = 0

args.penalty_weight = 0

args.embed_dim = 256
args.d_v = args.embed_dim
args.nu = 1.0
args.tf = 10.0 # Final time
args.dt = 0.01 # Time step size
# round() instead of int(): a float quotient that lands just below an integer
# would be truncated by int() and silently lose one time step.
args.N_t = round(args.tf/args.dt) # Number of time steps
n = 10 # nth measurement
args.seq_len = args.N_t // n

t_v = torch.linspace(0, args.tf, args.N_t+1).to(args.device) # Vector of times

model0 = PrecisionNet_1layer(args).to(args.device)

loss = Batched_Complex_MSE_Loss()
loss_p = Complex_MSE_Loss()
lr = 1E-4
params_list = list(model0.parameters())
optimizer = torch.optim.Adam(params_list, lr=lr, betas=(0.9, 0.999))

Pu, Pd, R1, R1i = construct_random_mapping(S1, Si1, args)

X_random, X_high, X_true, X_measure, t_measure = construct_data(A, Pu, Pd, R1, R1i, args.N_t+1, t_v, sigma_process, sigma_process_0, sigma_measure, args)
# Apply R1i to X_random and add a leading singleton axis.
X = torch.matmul(R1i,X_random.unsqueeze(-1)).unsqueeze(0).squeeze(-1)

# One-step-ahead pairs: drop the last / the first entry along dim 2.
inputs = X[:, :, :-1]
outputs = X[:, :, 1:]

# CUDA kernels are launched asynchronously, so the device is synchronised
# before each clock read; otherwise only the launch overhead may be measured.
# perf_counter() is the monotonic, high-resolution clock meant for intervals.
on_cuda = args.device.type == 'cuda'
if on_cuda:
    torch.cuda.synchronize(args.device)
start = time.perf_counter()
single_iter(model0, optimizer, loss, loss_p, inputs, outputs, t_measure.unsqueeze(0), args)
if on_cuda:
    torch.cuda.synchronize(args.device)
end = time.perf_counter()
print(end - start)

0.02706766128540039
