In [1]:
import math
import os

import functorch
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, TensorDataset, random_split

In [8]:
# Simulation snapshots: states (u), parameters (m) and Jacobians, loaded in that order
u_array, m_array, Jac_array = (
    np.load(path)
    for path in ('unp_concat_100.npy', 'mnp_concat_100.npy', 'Jacnp_concat_100.npy')
)

# Reduced-order bases: active subspace for the parameters, PCA basis for the states
AS, PCA = (np.load(path) for path in ('AS_fs_wom.npy', 'PCA_U_f_100_allstep.npy'))

# Finite element mass matrix
M = np.load('M.npy')

# Bayesian prior: precision, covariance and mean of the parameter
prior_m_precision, prior_m_covariance, prior_mean = (
    np.load(path) for path in ('prior_prec.npy', 'prior.npy', 'prior_mean.npy')
)

# Reference (mean) solution that is subtracted from the observations
obs_mean = np.load('obs_mean.npy')

In [15]:
# Prepare parameter data: centre the sampled parameters on the prior mean.
# torch.as_tensor (instead of torch.tensor) keeps this cell safe to re-run
# after m_array has already been converted to a tensor.
m_array = torch.as_tensor(m_array)
m_array_ = m_array - prior_mean

# Reduce parameter dimension using active subspace
m_red = m_array_ @ (prior_m_precision @ AS)

# Prepare observation data: subtract the reference solution
u_array_ = u_array - obs_mean

# Truncate PCA basis to 129 dimensions
PCA = PCA[:, :129]

# Set number of time steps for training
num_steps = 100

# Create training datasets; the last 100 samples are held out for testing.
# train_s is the projected state at time index 0, output_s the projected
# states at time indices 1..num_steps.
train_m = torch.tensor(m_red.numpy()[:-100], dtype=torch.float32)
train_s = torch.tensor(u_array_[:-100, 0, :] @ PCA, dtype=torch.float32)
output_s = torch.tensor(u_array_[:-100, 1:1 + num_steps, :] @ PCA, dtype=torch.float32)

# Normalize parameter data with the per-component training statistics
train_m_100 = train_m

mean = torch.mean(train_m_100, dim=0)
std = torch.std(train_m_100, dim=0)

sdata_m = (train_m_100 - mean) / std

# Normalize observation data: only the mean is removed.
mean_o = output_s.mean(dim=(0, 1))

# The empirical std (output_s.std(dim=(0, 1))) is deliberately not used; the
# observation scale is fixed to one. It is built from mean_o because std_o
# does not exist yet at this point -- torch.ones_like(std_o) raised a
# NameError on a fresh kernel.
std_o = torch.ones_like(mean_o)

sdata_y = (train_s - mean_o) / std_o
sdata_y_t = (output_s - mean_o) / std_o

# Normalize Jacobian data using same scaling.
# After swapping the last two axes, axis 2 is divided by the observation scale
# and axis 3 is multiplied by the parameter scale, which matches the chain rule
# for the normalised input/output variables.
Jac_r_ = torch.tensor(Jac_array)
Jac_r = Jac_r_.transpose(2, 3)
Jac_r = Jac_r[:, 1:]  # drop index 0 along axis 1 (presumably the initial time step -- TODO confirm)
Jac_r_n_1 = Jac_r / std_o.unsqueeze(1).unsqueeze(0).unsqueeze(0)
Jac_r_n_2 = Jac_r_n_1 * std.unsqueeze(0).unsqueeze(0).unsqueeze(0)

# Keep the first num_steps steps of the training samples (test samples excluded)
sdata_jac = Jac_r_n_2[:-100, :num_steps]

# Limit dataset size to 1280 samples
sdata_m_ = sdata_m[:1024 + 256]
sdata_y_ = sdata_y[:1024 + 256]
sdata_y_t_ = sdata_y_t[:1024 + 256]
sdata_jac_ = sdata_jac[:1024 + 256]

In [62]:
# Create dataset from preprocessed tensors
# (normalised parameters, initial state, target states, target Jacobians)
dataset = TensorDataset(sdata_m_, sdata_y_, sdata_y_t_, sdata_jac_)
total_samples = len(dataset)

# Define 80-20 train-validation split
train_size = int(total_samples * 0.8)
val_size = total_samples - train_size

# Randomly split dataset. A seeded generator makes the partition identical on
# every "Restart & Run All", so validation losses are comparable across runs.
SPLIT_SEED = 0
train_dataset, val_dataset = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Create data loaders for batch processing
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False)

In [65]:
# Define attention layer for simplified version

class DualPurposeAttentionLayer(nn.Module):
    """Single-head causal self-attention followed by a feed-forward block.

    Both sub-blocks are wrapped in a residual connection and a LayerNorm.
    ``output_backward``, ``ff_j``, ``norm3`` and ``norm4`` are instantiated but
    are not used by ``forward``.

    Args:
        hidden_dim: size of the last dimension of the input and output.
        attention_dim: size of the query/key/value projections.
    """

    def __init__(self, hidden_dim, attention_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.attention_dim = attention_dim

        # Q, K, V projections (hidden_dim -> attention_dim)
        self.query_transform = nn.Linear(hidden_dim, attention_dim)
        self.key_transform = nn.Linear(hidden_dim, attention_dim)
        self.value_transform = nn.Linear(hidden_dim, attention_dim)

        # Projections back to hidden_dim (only the forward one is used)
        self.output_forward = nn.Linear(attention_dim, hidden_dim)
        self.output_backward = nn.Linear(attention_dim, hidden_dim)

        # Scaling applied to the raw attention scores
        self.attention_scale = attention_dim ** -0.5

        # Two identically shaped feed-forward stacks (only ``ff`` is used)
        self.ff = self._make_feed_forward(hidden_dim)
        self.ff_j = self._make_feed_forward(hidden_dim)

        # norm1 .. norm4, registered in that order
        for idx in range(1, 5):
            setattr(self, f'norm{idx}', nn.LayerNorm(hidden_dim))

    @staticmethod
    def _make_feed_forward(width):
        """Build a Linear -> ELU -> Linear stack that keeps the feature width."""
        return nn.Sequential(
            nn.Linear(width, width),
            nn.ELU(),
            nn.Linear(width, width),
        )

    def forward(self, x):
        """Apply causal attention and the feed-forward block to ``x``.

        Args:
            x: tensor whose dimension 1 is the sequence axis and whose last
                dimension is ``hidden_dim``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        n_positions = x.size(1)

        q = self.query_transform(x)
        k = self.key_transform(x)
        v = self.value_transform(x)

        scores = torch.matmul(q, k.transpose(-2, -1)) * self.attention_scale

        # Lower-triangular mask: position i only attends to positions <= i
        causal = torch.tril(torch.ones(n_positions, n_positions)).bool().to(x.device)
        masked_scores = scores.masked_fill(~causal, float('-10e7'))
        weights = F.softmax(masked_scores, dim=-1)

        attended = self.output_forward(torch.matmul(weights, v))

        # Residual + norm around attention, then around the feed-forward block
        hidden = self.norm1(x + attended)
        return self.norm2(hidden + self.ff(hidden))

class StatePredictor(nn.Module):
    """Predict a sequence of states and a second ("Jacobian") output sequence.

    One latent token per time step is built from the initial state and the
    parameter vector, the tokens are mixed by a causal attention layer, and two
    independent residual decoders roll the initial state forward step by step.

    Args:
        state_dim: size of the state vector (input and per-step output).
        hidden_dim: width of the latent tokens.
        param_dim: size of the parameter vector.
        num_steps: number of predicted time steps.
    """

    def __init__(self, state_dim, hidden_dim, param_dim, num_steps=10):
        super().__init__()
        self.layers = num_steps
        self.state_dim = state_dim
        self.param_dim = param_dim
        self.hidden_dim = hidden_dim

        # Per-step encoders of the state plus one shared encoder of the parameter
        self.state_transforms = nn.ModuleList([nn.Linear(state_dim, hidden_dim) for _ in range(num_steps)])
        self.param_transform = nn.Linear(param_dim, hidden_dim)
        self.combine_transforms = nn.ModuleList([nn.Linear(hidden_dim * 2, hidden_dim) for _ in range(num_steps)])

        self.attention_layer = DualPurposeAttentionLayer(hidden_dim, attention_dim=hidden_dim // 2)

        # Per-step decoders for the state prediction
        self.fc_layers = nn.ModuleList([nn.Linear(hidden_dim, 128) for _ in range(num_steps)])
        self.out_layers = nn.ModuleList([nn.Linear(128, state_dim) for _ in range(num_steps)])

        # Per-step decoders for the second output (the one that gets differentiated)
        self.fc_layers_j = nn.ModuleList([nn.Linear(hidden_dim, 128) for _ in range(num_steps)])
        self.out_layers_j = nn.ModuleList([nn.Linear(128, state_dim) for _ in range(num_steps)])

        self.GELU = nn.GELU()
        self.ELU = nn.ELU()
        self.Tanh = nn.Tanh()

    def get_positional_encoding(self, seq_len, d_model):
        """Return sinusoidal positional encodings of shape (1, seq_len, d_model).

        Even feature indices hold sines and odd indices hold cosines. For an odd
        ``d_model`` there is one cosine column fewer than sine columns, so the
        frequency vector is truncated for the cosine part.
        """
        position = torch.arange(seq_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        pe = torch.zeros(seq_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        return pe.unsqueeze(0)

    def forward(self, inputs, inputm):
        """Run the model.

        Args:
            inputs: initial state, last dimension ``state_dim``.
            inputm: parameter vector, last dimension ``param_dim``.

        Returns:
            Tuple ``(forward_outputs, jac_outputs)``; each stacks ``num_steps``
            per-step outputs along dimension 1.
        """
        current_state = inputs
        current_backward_state = inputs

        # The parameter embedding does not depend on the step, so compute it once.
        param_key = self.param_transform(inputm)

        # Every latent token is built from the *initial* state; only the
        # per-step linear maps differ between tokens.
        latent_variables = []
        for i in range(self.layers):
            state_query = self.state_transforms[i](inputs)
            combined = self.Tanh(self.combine_transforms[i](torch.cat([param_key, state_query], dim=-1)))
            latent_variables.append(combined)

        latent_stack = torch.stack(latent_variables, dim=1)
        pe = self.get_positional_encoding(self.layers, self.hidden_dim).to(latent_stack.device)
        latent_stack = latent_stack + pe

        forward_attended = self.attention_layer(latent_stack)

        # Residual roll-out of the state: each step adds an increment to the previous state
        forward_outputs = []
        for i in range(self.layers):
            s = self.ELU(self.fc_layers[i](forward_attended[:, i, :]))
            new_state = self.out_layers[i](s) + current_state
            forward_outputs.append(new_state)
            current_state = new_state

        # Same residual roll-out with the second set of decoders, in forward time order
        jac_outputs = []
        for i in range(self.layers):
            s = self.ELU(self.fc_layers_j[i](forward_attended[:, i, :]))
            j_output = self.out_layers_j[i](s) + current_backward_state
            jac_outputs.append(j_output)
            current_backward_state = j_output

        forward_outputs_stacked = torch.stack(forward_outputs, dim=1)
        jac_outputs_stacked = torch.stack(jac_outputs, dim=1)

        return forward_outputs_stacked, jac_outputs_stacked

In [None]:
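# Original variant of the model, kept commented out for reference only (this cell does not execute anything).
# It differs from the active definitions above only in using bias=False for the query/key/value projections.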

# import torch
# import torch.nn as nn
# import math

# class DualPurposeAttentionLayer(nn.Module):
#     def __init__(self, hidden_dim, attention_dim):
#         super().__init__()
#         self.hidden_dim = hidden_dim
#         self.attention_dim = attention_dim
        
#         # Shared projections for Q, K, V
#         self.query_transform = nn.Linear(hidden_dim, attention_dim,bias=False)
#         self.key_transform = nn.Linear(hidden_dim, attention_dim,bias=False)
#         self.value_transform = nn.Linear(hidden_dim, attention_dim,bias=False)
        
#         # Output projections for forward and backward attention
#         self.output_forward = nn.Linear(attention_dim, hidden_dim)
#         self.output_backward = nn.Linear(attention_dim, hidden_dim)
        
#         self.attention_scale = attention_dim ** -0.5

#         self.ff = nn.Sequential(
#             nn.Linear(hidden_dim, hidden_dim ),
#             nn.ELU(),
#             nn.Linear(hidden_dim, hidden_dim)
#         )

#         self.ff_j = nn.Sequential(
#             nn.Linear(hidden_dim, hidden_dim ),
#             nn.ELU(),
#             nn.Linear(hidden_dim, hidden_dim)
#         )

#         self.norm1 = nn.LayerNorm(hidden_dim)
#         self.norm2 = nn.LayerNorm(hidden_dim)
#         self.norm3 = nn.LayerNorm(hidden_dim)
#         self.norm4 = nn.LayerNorm(hidden_dim)

#     def forward(self, x):
#         seq_length = x.size(1)

#         query = self.query_transform(x)
#         key = self.key_transform(x)
#         value = self.value_transform(x)

#         attention_scores = torch.matmul(query, key.transpose(-2, -1)) * self.attention_scale
        
#         # Create forward causal mask (lower triangular)
#         forward_mask = torch.tril(torch.ones(seq_length, seq_length)).bool().to(x.device)
        
#         # Apply forward mask
#         forward_scores = attention_scores.masked_fill(~forward_mask, float('-10e7'))
#         forward_weights = F.softmax(forward_scores, dim=-1)
#         forward_attended = torch.matmul(forward_weights, value)
#         forward_output = self.output_forward(forward_attended)

#         # Process forward output for state prediction
#         x_forward = self.norm1(x + forward_output)
#         ff_output = self.ff(x_forward)
#         x_forward = self.norm2(x_forward + ff_output)
               
#         return x_forward

# class StatePredictor(nn.Module):
#     def __init__(self, state_dim, hidden_dim, param_dim, num_steps=10):
#         super().__init__()
#         self.layers = num_steps
#         self.state_dim = state_dim
#         self.param_dim = param_dim
#         self.hidden_dim = hidden_dim

#         self.state_transforms = nn.ModuleList([nn.Linear(state_dim, hidden_dim) for _ in range(num_steps)])
#         self.param_transform = nn.Linear(param_dim, hidden_dim ) 
#         self.combine_transforms = nn.ModuleList([nn.Linear(hidden_dim * 2, hidden_dim) for _ in range(num_steps)])
        
#         self.attention_layer = DualPurposeAttentionLayer(hidden_dim, attention_dim=hidden_dim // 2)
    
#         self.fc_layers = nn.ModuleList([nn.Linear(hidden_dim, 128) for _ in range(num_steps)])
#         self.out_layers = nn.ModuleList([nn.Linear(128, state_dim) for _ in range(num_steps)])

#         self.fc_layers_j = nn.ModuleList([nn.Linear(hidden_dim, 128) for _ in range(num_steps)])
#         self.out_layers_j = nn.ModuleList([nn.Linear(128, state_dim) for _ in range(num_steps)])

#         self.GELU = nn.GELU()
#         self.ELU = nn.ELU()
#         self.Tanh = nn.Tanh()

#     def get_positional_encoding(self, seq_len, d_model):
#         position = torch.arange(seq_len).unsqueeze(1).float()
#         div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
#         pe = torch.zeros(seq_len, d_model)
#         pe[:, 0::2] = torch.sin(position * div_term)
#         pe[:, 1::2] = torch.cos(position * div_term)  
#         return pe.unsqueeze(0)

#     def forward(self, inputs, inputm):
#         batch_size = inputs.size(0)
#         latent_variables = []
#         current_state = inputs
#         current_backward_state = inputs

#         for i in range(self.layers):
#             param_key = self.param_transform(inputm)
#             state_query = self.state_transforms[i](current_state)

#             combined = self.Tanh(self.combine_transforms[i](torch.cat([param_key, state_query], dim=-1)))
#             latent_variables.append(combined)

#         latent_stack = torch.stack(latent_variables, dim=1)
#         pe = self.get_positional_encoding(self.layers, self.hidden_dim).to(latent_stack.device)
#         latent_stack = latent_stack + pe

#         forward_attended = self.attention_layer(latent_stack)

#         forward_outputs = []
#         for i in range(self.layers):
#             s = self.ELU(self.fc_layers[i](forward_attended[:, i, :]))
#             new_state = self.out_layers[i](s) + current_state
#             forward_outputs.append(new_state)
#             current_state = new_state

#         jac_outputs = []
        
#         # for i in reversed(range(self.layers)):
#         for i in range(self.layers):
#             s = self.ELU(self.fc_layers_j[i](forward_attended[:, i, :]))
#             j_output = self.out_layers_j[i](s) + current_backward_state
#             jac_outputs.append(j_output)  # Insert at the beginning to maintain original order
#             current_backward_state = j_output

#         forward_outputs_stacked = torch.stack(forward_outputs, dim=1)
#         jac_outputs_stacked = torch.stack(jac_outputs, dim=1)

#         return forward_outputs_stacked, jac_outputs_stacked

In [66]:
# Prefer the second GPU when the machine has one; otherwise fall back to the
# first GPU and finally to the CPU. Requesting "cuda:1" on a single-GPU machine
# would fail as soon as a tensor is moved to the device.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")

In [69]:
def compute_jacobian_batched(model, inputs, params):
    """Jacobian of the model's second output with respect to the parameters.

    For every sample the model is called on a batch of size one and the second
    element of its return tuple is differentiated with respect to ``params``
    using forward-mode AD; the per-sample function is vectorised with ``vmap``.

    Args:
        model: callable ``model(state, param)`` returning a tuple whose second
            element has a leading batch dimension of size one.
        inputs: batch of initial states, batched along dimension 0.
        params: batch of parameter vectors, batched along dimension 0.

    Returns:
        Tensor of shape ``(batch, *output_shape, *param_shape)``.
    """
    # torch.func supersedes the deprecated functorch package; fall back to
    # functorch only on installations that do not ship torch.func.
    try:
        from torch.func import jacfwd, vmap
    except ImportError:
        from functorch import jacfwd, vmap

    def single_jacobian(inp, param):
        def get_backward_output(p):
            # Re-add the batch axis the model expects, then strip it again
            _, backward_output = model(inp.unsqueeze(0), p.unsqueeze(0))
            return backward_output.squeeze(0)
        return jacfwd(get_backward_output)(param)

    return vmap(single_jacobian)(inputs, params)

In [71]:
# Folder that receives the saved model checkpoints; created if it is missing
checkpoint_dir = 'checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)

# Early-stopping bookkeeping: best validation loss so far, allowed number of
# non-improving epochs, and the current count of non-improving epochs
best_val_loss, patience, patience_counter = float('inf'), 20, 0

In [None]:
# Initialize model and training components
model = StatePredictor(129, 100, 128, num_steps=num_steps)  # state_dim=129, hidden_dim=100, param_dim=128
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
scheduler = ReduceLROnPlateau(optimizer, 'min', patience=10, factor=0.5)  # Reduce LR when loss plateaus
loss_function = nn.MSELoss()

# Weight of the Jacobian term in the training loss
lambda_jac = 1.

# Early stopping parameters
best_val_loss = float('inf')
patience = 20
patience_counter = 0

# The best model is always written to the same file
checkpoint_path = os.path.join(checkpoint_dir, 'best_model_checkpoint_100_casual_jac.pth')

# Training loop
for epoch in range(1000):
    # Training phase
    model.train()
    train_losses = []
    train_jac_losses = []

    for inputm, inputs, targets, Jacs in train_loader:
        # Move data to device
        inputm, inputs, targets, Jacs = inputm.to(device), inputs.to(device), targets.to(device), Jacs.to(device)

        optimizer.zero_grad()
        forward_outputs, backward_outputs = model(inputs, inputm)

        # Forward prediction loss, summed over the time steps
        step_losses = [loss_function(forward_outputs[:, i], targets[:, i]) for i in range(num_steps)]
        main_loss = sum(step_losses)

        # Jacobian of the model's second output w.r.t. the parameters, compared
        # step by step with the reference Jacobians (Frobenius norm, batch mean)
        jacobian_func = compute_jacobian_batched(model, inputs, inputm)
        jac_losses = [torch.norm(jacobian_func[:, i] - Jacs[:, i], dim=(1, 2)).mean() for i in range(num_steps)]
        jac_loss_total = sum(jac_losses)

        # Total loss combines forward prediction and Jacobian accuracy
        loss = main_loss + lambda_jac * jac_loss_total

        loss.backward()
        optimizer.step()
        train_losses.append(main_loss.item())
        train_jac_losses.append(jac_loss_total.item())

    # Validation phase
    model.eval()
    val_losses = [0] * num_steps
    val_jac_losses = [0] * num_steps

    # no_grad only disables reverse-mode recording; the forward-mode Jacobian
    # below is still computed.
    with torch.no_grad():
        for inputm, inputs, targets, Jacs in val_loader:
            inputm, inputs, targets, Jacs = inputm.to(device), inputs.to(device), targets.to(device), Jacs.to(device)
            forward_outputs, backward_outputs = model(inputs, inputm)

            # Accumulate validation losses for each time step
            for i in range(num_steps):
                val_losses[i] += loss_function(forward_outputs[:, i], targets[:, i]).item()

            # Compute validation Jacobian losses
            jacobian_func = compute_jacobian_batched(model, inputs, inputm)
            jac_losses = [torch.norm(jacobian_func[:, i] - Jacs[:, i], dim=(1, 2)).mean() for i in range(num_steps)]
            for i in range(num_steps):
                val_jac_losses[i] += jac_losses[i].item()

    # Average validation losses across batches
    avg_val_losses = [loss / len(val_loader) for loss in val_losses]
    avg_val_jac_losses = [loss / len(val_loader) for loss in val_jac_losses]
    avg_total_loss = sum(avg_val_losses) / num_steps
    avg_total_jac_loss = sum(avg_val_jac_losses) / num_steps
    total_val_loss = avg_total_loss + avg_total_jac_loss

    # Update learning rate scheduler
    scheduler.step(total_val_loss)

    # Print detailed loss information
    loss_string = ', '.join(f'Step {i+1}: {loss:.4f}' for i, loss in enumerate(avg_val_losses))
    jac_loss_string = ', '.join(f'Jac Step {i+1}: {loss:.4f}' for i, loss in enumerate(avg_val_jac_losses))
    print(f'Epoch {epoch+1}:')
    print(f'  Main Loss: {loss_string}, Avg Main Loss: {avg_total_loss:.4f}')
    print(f'  Jac Loss: {jac_loss_string}, Avg Jac Loss: {avg_total_jac_loss:.4f}')

    # Early stopping and model checkpointing
    if total_val_loss < best_val_loss:
        best_val_loss = total_val_loss
        patience_counter = 0

        # Save best model checkpoint
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_val_loss': best_val_loss,
            'num_steps': num_steps,
            'patience_counter': patience_counter,
        }

        torch.save(checkpoint, checkpoint_path)
        print(f"Saved best model checkpoint to {checkpoint_path}")
    else:
        patience_counter += 1
        # Stop once the validation loss has not improved for `patience` epochs;
        # previously the counter was incremented but never checked.
        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch+1}: no improvement for {patience} epochs.")
            break

print("Training completed.")

In [74]:
# Test set = the last 100 samples, which were excluded from the training tensors
held_out = slice(-100, None)

# Reduced parameters, scaled with the training mean/std
m_test = m_red[held_out]
m_test_r = ((m_test - mean) / std).to(dtype=torch.float32)

# Initial state projected on the PCA basis, and full-space reference states
# (not mean-subtracted) for time indices 1..num_steps
test_y_in = torch.tensor(u_array_[held_out, 0, :] @ PCA, dtype=torch.float32)
test_y_out = torch.tensor(u_array[held_out, 1:1 + num_steps, :], dtype=torch.float32)

# Initial state scaled like the training observations
test_t_in_r = (test_y_in - mean_o) / std_o

In [86]:
# Evaluate model on the CPU
# Float32 tensor copies of the PCA basis and of the reference solution
PCA_ = torch.tensor(PCA, dtype=torch.float32)
obs_mean_t = torch.tensor(obs_mean, dtype=torch.float32)

# Bring the model and the test parameters to the CPU, then run the model once
model = model.to('cpu')
m_test_r = m_test_r.to('cpu')
states = model(test_t_in_r, m_test_r)


In [132]:
def compute_errors(model, test_t_in_r, m_test_r, test_y_out, std_o, mean_o, PCA_, M, obs_mean_t):
    """Relative M-weighted error of the predicted states, per sample and step.

    Args:
        model: callable ``model(state, param)``; the first element of its
            return value holds the predicted reduced states, shape
            ``(n_samples, n_steps, r)``.
        test_t_in_r: normalised initial states passed to the model.
        m_test_r: normalised parameters passed to the model.
        test_y_out: reference full-space states (reference solution included);
            reshaped to ``(n_samples, n_steps, -1)``.
        std_o, mean_o: scale and shift that undo the output normalisation.
        PCA_: basis mapping reduced states to the full space (multiplied as
            ``PCA_.T`` on the right).
        M: weighting matrix of the norm; a tensor or anything accepted by
            ``torch.as_tensor`` (e.g. a NumPy array).
        obs_mean_t: reference solution that was subtracted before projection.

    Returns:
        Tensor of shape ``(n_samples, n_steps)`` with
        ``sqrt(d^T M d) / sqrt(y^T M y)`` where ``d`` is the prediction error
        and ``y`` the reference state.
    """
    # Generate states from the model
    with torch.no_grad():
        states = model(test_t_in_r, m_test_r)[0]

    # Undo the normalisation and lift the reduced states to the full space
    adjusted_states = (states * std_o + mean_o) @ PCA_.float().T

    # Take the sample/step counts from the prediction instead of hard-coding 100 x 100
    n_samples, n_steps = states.shape[0], states.shape[1]
    test_y_out = test_y_out.reshape(n_samples, n_steps, -1)
    adjusted_states = adjusted_states.reshape(n_samples, n_steps, -1)

    # Prediction error; the reference solution is added back to the prediction
    diff = test_y_out - adjusted_states - obs_mean_t

    # einsum needs all operands as tensors of one dtype on one device
    M = torch.as_tensor(M, dtype=diff.dtype, device=diff.device)
    test_y_out = test_y_out.to(diff.dtype)

    # Squared M-norms of the error and of the reference
    norm_adjusted = torch.einsum('ijk,kl,ijl->ij', diff, M, diff)
    norm_test = torch.einsum('ijk,kl,ijl->ij', test_y_out, M, test_y_out)

    # Compute errors
    errors = torch.sqrt(norm_adjusted) / torch.sqrt(norm_test)

    return errors

In [133]:
# M is loaded with np.load, so it has no .to() method; convert it to a float32
# CPU tensor (a no-op copy-wise if it already is one) before passing it on.
errors = compute_errors(
    model.to('cpu'),
    test_t_in_r.to('cpu'),
    m_test_r.to('cpu'),
    test_y_out.to('cpu'),
    std_o.to('cpu'),
    mean_o.to('cpu'),
    PCA_.to('cpu'),
    torch.as_tensor(M, dtype=torch.float32, device='cpu'),
    obs_mean_t.to('cpu'),
)


In [121]:
# Show the relative errors returned by compute_errors
# (first axis: test sample, second axis: time step)
errors

tensor([[0.0213, 0.0416, 0.0554,  ..., 0.0197, 0.0198, 0.0200],
        [0.0222, 0.0315, 0.0448,  ..., 0.0164, 0.0172, 0.0163],
        [0.0320, 0.0506, 0.0747,  ..., 0.0236, 0.0246, 0.0232],
        ...,
        [0.0600, 0.0954, 0.1054,  ..., 0.0418, 0.0425, 0.0391],
        [0.0218, 0.0355, 0.0451,  ..., 0.0196, 0.0188, 0.0191],
        [0.0367, 0.0556, 0.0586,  ..., 0.0176, 0.0164, 0.0163]])

In [99]:
# Model Jacobians (second output w.r.t. the normalised parameters) on the test set
jacobian_func = compute_jacobian_batched(model, inputs=test_t_in_r, params=m_test_r)

In [100]:
# Reference Jacobians of the 100 held-out test samples (unnormalised)
Jac_test = Jac_r[-100:]

# Undo the normalisation of the model Jacobian: divide the parameter axis by
# std and multiply the output axis by std_o. detach() drops the autograd graph
# so the stored error metrics are plain numbers rather than graph nodes.
Jac_comp = jacobian_func.detach().to('cpu') / std.unsqueeze(0).unsqueeze(0) * std_o.unsqueeze(1).unsqueeze(0)

# Relative Frobenius-norm error per time step, averaged over the test samples.
# The number of steps is read from the Jacobian itself instead of hard-coding 100.
mean_normalized_differences = []
for i in range(Jac_comp.shape[1]):
    normalized_diff = (Jac_comp[:, i] - Jac_test[:, i]).norm(dim=(1, 2)) / Jac_test[:, i].norm(dim=(1, 2))
    mean_normalized_differences.append(normalized_diff.mean())

In [101]:
# Show the list of mean relative Jacobian errors, one entry per time step
mean_normalized_differences

[tensor(0.1116, dtype=torch.float64, grad_fn=<MeanBackward0>),
 tensor(0.0732, dtype=torch.float64, grad_fn=<MeanBackward0>),
 tensor(0.0567, dtype=torch.float64, grad_fn=<MeanBackward0>),
 tensor(0.0484, dtype=torch.float64, grad_fn=<MeanBackward0>),
 tensor(0.0426, dtype=torch.float64, grad_fn=<MeanBackward0>),
 tensor(0.0431, dtype=torch.float64, grad_fn=<MeanBackward0>),
 tensor(0.0394, dtype=torch.float64, grad_fn=<MeanBackward0>),
 tensor(0.0361, dtype=torch.float64, grad_fn=<MeanBackward0>),
 tensor(0.0338, dtype=torch.float64, grad_fn=<MeanBackward0>),
 tensor(0.0341, dtype=torch.float64, grad_fn=<MeanBackward0>),
 tensor(0.0318, dtype=torch.float64, grad_fn=<MeanBackward0>),
 tensor(0.0303, dtype=torch.float64, grad_fn=<MeanBackward0>),
 tensor(0.0287, dtype=torch.float64, grad_fn=<MeanBackward0>),
 tensor(0.0295, dtype=torch.float64, grad_fn=<MeanBackward0>),
 tensor(0.0297, dtype=torch.float64, grad_fn=<MeanBackward0>),
 tensor(0.0283, dtype=torch.float64, grad_fn=<MeanBackw

In [None]:
# Load data for MAP point estimation
map_m = torch.tensor(np.load('combined_map_m_100_10.npy'), dtype=torch.float64)

# Initial state for the optimisation: the first normalised training initial
# state repeated 200 times. detach().clone() copies the tensor directly;
# wrapping an existing tensor in torch.tensor(...) triggers a UserWarning.
init_y = sdata_y[0].detach().clone().unsqueeze(0).repeat(200, 1).to(torch.float32)

map_obs = torch.tensor(np.load('combined_map_obs_100_10.npy'), dtype=torch.float64)
map_map = torch.tensor(np.load('combined_map_map_100_10.npy'), dtype=torch.float64)

# Prepare MAP parameter data: centre on the prior mean, project on the active
# subspace and scale with the training mean/std
map_m_ = map_m - prior_mean
map_m_red = map_m_ @ (prior_m_precision @ AS)
map_prior = ((map_m_red - mean) / std).float()

# Prepare MAP observation data: drop index 0 along axis 1 and subtract the
# reference solution
map_obs_ = map_obs[:, 1:]
map_obs_2 = map_obs_ - obs_mean

# Noise variance for uncertainty quantification
noise_var = 3.9e-3

In [None]:
# NOTE(review): Noise_prec, obs_w_noise, m and learning_rate are not defined in
# any cell of this notebook that is visible here -- they must be created before
# this cell runs (presumably the noise precision matrix, the noisy observations,
# the parameter tensor being optimised with requires_grad=True, and the L-BFGS
# step size). TODO: define them explicitly so Restart & Run All works.

# Precompute matrices for efficient L-BFGS optimization:
# the noise precision projected on the PCA basis, and the observations
# (every 10th step starting at index 9) multiplied by precision and basis
UNoise_precU = PCA_.double().T @ Noise_prec @ PCA_.double()
yMU = (obs_w_noise[:, 9::10] @ Noise_prec @ PCA_.double())

# Set up L-BFGS optimizer for quasi-Newton optimization.
# torch.optim is referenced through the already imported torch package; the
# bare name `optim` was never imported in this notebook.
model.eval()
optimizer = torch.optim.LBFGS([m],
                              lr=learning_rate,
                              max_iter=150,
                              max_eval=None,
                              tolerance_grad=1e-7,
                              tolerance_change=1e-9,
                              history_size=150,
                              line_search_fn="strong_wolfe")

# Storage for optimization history
L2_lbfgs = []

def closure():
    """Evaluate the objective for L-BFGS and populate the gradients.

    Clears the optimizer's gradients, runs the model from ``init_y`` with the
    current ``m``, forms the data-misfit and prior terms, back-propagates the
    batch mean of their sum and returns that mean.
    """
    optimizer.zero_grad()

    # Model prediction in reduced coordinates, de-normalised, restricted to
    # every 10th step starting at index 9
    predicted = model(init_y, m)[0]
    predicted_obs = (predicted * std_o + mean_o).double()[:, 9::10]

    # Data misfit from the precomputed matrices: cross term and quadratic term
    cross_term = torch.einsum('bij,bij->b', predicted_obs, yMU)
    quad_term = torch.einsum('bij,jk,bik->b', predicted_obs, UNoise_precU, predicted_obs)
    misfit = (-2 * cross_term + quad_term) / 2.

    # Prior term: half the squared Euclidean norm of each row of m
    regulariser = torch.einsum('ij,ij->i', m, m) / 2.

    # Objective = mean over the batch of (prior + misfit)
    objective = (regulariser + misfit).mean()
    objective.backward()
    return objective

# Run the optimizer: a single outer step (L-BFGS iterates internally)
for epoch in range(1):
    loss = optimizer.step(closure)

    # Record the loss value, then report it
    L2_lbfgs.append(loss.item())
    print(f"Epoch {epoch}, Loss: {loss.item():.4f}")