In [None]:
# import npz
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
# from godunov_vis_tools import * 
from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
# Run on the first CUDA GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"using {device}")
from tools_torch import *
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

In [3]:
# Load the combined dataset archive.
# The path uses forward slashes throughout: the previous literal contained "\d",
# which is an invalid escape sequence (a warning on recent Python versions) and
# only resolves as a directory separator on Windows. "/" works on every platform.
# NOTE(review): allow_pickle=True lets the archive execute arbitrary code while it
# is unpickled -- only load files from a trusted source.
data = np.load('../datasets/controlled_boundary_wavelet_20000/data_wavelet_boundary_20000_combined.npz', allow_pickle=True)

In [None]:
# Unpack the archive entries into module-level names used by the later cells.
branch_coords, branch_values = data['branch_coords'], data['branch_values']
output_sensor_coords, output_sensor_values = data['output_sensor_coords'], data['output_sensor_values']

# Reference field and its axes (data['v'] is deliberately left unloaded).
rho = data['rho']
x, tt = data['x'], data['t']

# Scalar metadata stored alongside the arrays.
Nx, Nt = data['Nx'], data['Nt']
Xmax, Tmax = data['Xmax'], data['Tmax']
P, N = data['P'], data['N']

# Quick sanity report of what was loaded.
print(f"branch_coords.shape = {branch_coords.shape}, branch_values.shape = {branch_values.shape}, output_sensor_coords.shape = {output_sensor_coords.shape}, ")
print(f"output_sensor_values.shape = {output_sensor_values.shape}, rho.shape = {rho.shape}, x.shape = {x.shape}, t.shape = {tt.shape}")
print(f"Nx = {Nx}, Nt = {Nt}, Xmax = {Xmax}, Tmax = {Tmax}, P = {P}, N = {N}")

In [None]:
from vidon_model import VIDON, FDLearner

# `p` is forwarded unchanged to the VIDON constructor.
p = 400

# Build the network and move it to the selected device once
# (the original cell repeated the `.to(device)` call, which was redundant).
model = VIDON(p=p, num_heads=4, d_branch_input=2, d_v=1, use_linear_decoder=False, UQ=True).to(device)

# Define the loss function
criterion = nn.MSELoss()

# Restore the trained weights. The checkpoint is deserialised onto the CPU and
# load_state_dict then copies each tensor into the model's existing parameters.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
model.load_state_dict(torch.load("model_irregular_boundary_UQTrue_random_sampling_11_nov_p400_extUQ3.pth",map_location=torch.device('cpu')))

In [5]:
def _resample_to_length(coords, values, target_len, label):
    """Return ``(coords, values)`` with exactly ``target_len`` rows.

    Inputs with at least ``target_len`` rows are cut to their first
    ``target_len`` rows. Shorter inputs are extended with rows drawn at random
    (with replacement) from the input itself; the same indices are used for
    ``coords`` and ``values`` so each coordinate stays paired with its value.

    Raises:
        ValueError: if the input is empty but ``target_len`` rows are required,
            because there is nothing to resample from.
    """
    n_rows = coords.shape[0]
    if n_rows >= target_len:
        return coords[:target_len], values[:target_len]
    if n_rows == 0:
        raise ValueError(
            f"cannot pad {label} data to {target_len} rows: no {label} rows to resample from"
        )
    extra = np.random.choice(n_rows, size=target_len - n_rows)
    padded_coords = np.concatenate([coords, coords[extra]], axis=0)
    padded_values = np.concatenate([values, values[extra]], axis=0)
    return padded_coords, padded_values


def pad_to_shape_branch(coords, values, target_shape_boundary, target_shape_probe):
    """Bring one scenario's branch input to a fixed number of rows.

    Rows are split by the ID stored in column 2 of ``coords``: ``-1`` marks
    boundary rows, everything else is treated as probe rows. Each group is
    truncated or padded (by random resampling of its own rows) to its target
    length, and the result is returned boundary-first.

    Args:
        coords: 2-D array whose column 2 holds the row ID.
        values: array with one row per row of ``coords``.
        target_shape_boundary: number of boundary rows in the output.
        target_shape_probe: number of probe rows in the output.

    Returns:
        ``(filtered_coords, filtered_values)`` stacked as boundary rows followed
        by probe rows.

    Raises:
        ValueError: if there are no probe rows but ``target_shape_probe`` > 0.

    Note:
        A boundary group that is non-empty but shorter than
        ``target_shape_boundary`` is now padded like the probe group (it used
        to be returned short, giving scenarios different row counts). A
        completely empty boundary group is still passed through unchanged,
        since there is nothing to resample from.
    """
    # Compute the boundary/probe split once and reuse it for coords and values.
    is_boundary = coords[:, 2] == -1
    is_probe = ~is_boundary

    boundary_coords, boundary_values = coords[is_boundary], values[is_boundary]
    probe_coords, probe_values = coords[is_probe], values[is_probe]

    if boundary_coords.shape[0] > 0:
        boundary_coords, boundary_values = _resample_to_length(
            boundary_coords, boundary_values, target_shape_boundary, "boundary"
        )

    probe_coords, probe_values = _resample_to_length(
        probe_coords, probe_values, target_shape_probe, "probe"
    )

    # Combine boundary and probe data back together
    filtered_coords = np.vstack([boundary_coords, probe_coords])
    filtered_values = np.vstack([boundary_values, probe_values])

    return filtered_coords, filtered_values

In [6]:
# Keep every 5th entry along axis 1 of the time array.
tt_sub = tt[:, ::5]

In [7]:
# NOTE(review): this module-level `idx` is not read by the functions visible in
# this notebook section (they use their own loop variable / parameter named idx).
idx = 4
# Divisor used to choose the stride when probe IDs are subsampled in the functions below.
n_probes = 8  # Replace with the desired number of unique IDs to sample
# NOTE(review): max_id / min_id are not read anywhere in the visible code -- confirm before removing.
max_id = 0
min_id = 1e6

# Time-window lengths: with t = t_start + T_PAST, the functions below keep data
# with time in [t - T_PAST, t + T_PRED].
T_PRED = 8
T_PAST = 2

def create_val_loader_for_shifted_t(t_start, branch_coords, branch_values, output_sensor_coords, output_sensor_values,
                                    target_shape_boundary=188, target_shape_probe=200, batch_size=32, num_workers=8):
    """Build a validation DataLoader for the time window that starts at ``t_start``.

    With ``t = t_start + T_PAST``, every scenario is reduced to:

    * probe rows (IDs other than -1 / -2) of a strided subset of the probe IDs
      seen in ``[t - T_PAST, t + T_PRED]``, restricted to times in
      ``[t - T_PAST, t]``;
    * boundary rows (ID == -1) with ``x >= 4`` and time in
      ``[t - T_PAST, t + T_PRED]``;
    * output sensors with time in ``[t - T_PAST, t + T_PRED]``.

    Times are then shifted by ``-t_start``, branch inputs are padded/truncated
    with ``pad_to_shape_branch``, output sensors are trimmed to the shortest
    scenario (keeping the last rows), and everything is cast to float16.

    Args:
        t_start: start time of the window.
        branch_coords: per-scenario arrays with columns (x, t, id).
        branch_values: per-scenario values, one row per branch coordinate.
        output_sensor_coords: per-scenario arrays whose column 1 is time.
        output_sensor_values: per-scenario values, one row per output coordinate.
        target_shape_boundary: boundary rows per scenario after padding.
        target_shape_probe: probe rows per scenario after padding.
        batch_size: DataLoader batch size.
        num_workers: DataLoader worker count.

    Returns:
        A non-shuffling ``DataLoader`` over a ``DeepONetDatasetTrain``.

    Raises:
        ValueError: if no scenarios are supplied.
    """
    if len(branch_coords) == 0:
        raise ValueError("create_val_loader_for_shifted_t needs at least one scenario")

    # The window is identical for every scenario, so compute its bounds once.
    t = t_start + T_PAST
    t_lo = t - T_PAST
    t_hi = t + T_PRED

    branch_coords_filtered_list, branch_values_filtered_list = [], []
    output_sensor_coords_filtered_list, output_sensor_values_filtered_list = [], []

    # for every scenario
    for idx in range(len(branch_coords)):
        scen_coords = branch_coords[idx]
        times = scen_coords[:, 1]
        ids = scen_coords[:, 2]
        in_window = (times <= t_hi) & (times >= t_lo)

        # Probe IDs present in the window (IDs -1 and -2 are not probes).
        unique_ids = np.unique(ids[in_window & (ids != -2)])
        unique_ids = unique_ids[(unique_ids != -1) & (unique_ids != -2)]

        # Strided selection of IDs; the stride is len // n_probes, so slightly
        # more than n_probes IDs can be selected.
        sampled_ids = unique_ids[::max(1, len(unique_ids) // n_probes)]

        # Probe rows: only the past part of the window, [t - T_PAST, t].
        mask_sampled_ids = np.isin(ids, sampled_ids) & (times <= t) & (times >= t_lo)

        # Boundary rows: ID == -1, x >= 4, anywhere in the full window.
        mask_boundary_in_window = (ids == -1) & (scen_coords[:, 0] >= 4) & in_window

        # All boundary rows in the window are kept. The random draw is retained
        # so the global NumPy RNG stream is consumed exactly as before (it
        # feeds the random padding below); shrink `size` to subsample.
        boundary_indices = np.where(mask_boundary_in_window)[0]
        if len(boundary_indices) > 0:
            kept_boundary = np.random.choice(boundary_indices, size=len(boundary_indices), replace=False)
        else:
            kept_boundary = np.array([], dtype=int)
        mask_kept_boundary = np.zeros(mask_boundary_in_window.shape, dtype=bool)
        mask_kept_boundary[kept_boundary] = True

        final_mask = mask_sampled_ids | mask_kept_boundary

        # Boolean-mask indexing returns copies, so the in-place time shift
        # below does not modify the caller's arrays.
        filtered_coords = scen_coords[final_mask]
        filtered_values = branch_values[idx][final_mask]

        # shift to t = 0
        filtered_coords[:, 1] -= t_start

        branch_coords_filtered_list.append(filtered_coords)
        branch_values_filtered_list.append(filtered_values)

        # keep trunk_coords where t is in horizon
        out_coords = output_sensor_coords[idx]
        out_in_window = (out_coords[:, 1] <= t_hi) & (out_coords[:, 1] >= t_lo)
        output_sensor_coords_filtered = out_coords[out_in_window]
        output_sensor_values_filtered = output_sensor_values[idx][out_in_window]

        # shift to t = 0
        output_sensor_coords_filtered[:, 1] -= t_start

        output_sensor_coords_filtered_list.append(output_sensor_coords_filtered)
        output_sensor_values_filtered_list.append(output_sensor_values_filtered)

    # Bring every scenario's branch input to the same number of rows.
    padded_pairs = [
        pad_to_shape_branch(coords, values,
                            target_shape_boundary=target_shape_boundary,
                            target_shape_probe=target_shape_probe)
        for coords, values in zip(branch_coords_filtered_list, branch_values_filtered_list)
    ]
    filtered_coords_padded = np.stack([pair[0] for pair in padded_pairs])
    filtered_values_padded = np.stack([pair[1] for pair in padded_pairs])

    # Trim output sensors to the shortest scenario, keeping the LAST m_min rows.
    # Slicing from an explicit start index also behaves for m_min == 0
    # (arr[-0:] would return the whole array instead of an empty one).
    m_min = min(arr.shape[0] for arr in output_sensor_coords_filtered_list)
    filtered_output_coords_padded = np.stack(
        [arr[arr.shape[0] - m_min:] for arr in output_sensor_coords_filtered_list]
    )
    filtered_output_values_padded = np.stack(
        [arr[arr.shape[0] - m_min:] for arr in output_sensor_values_filtered_list]
    )

    # Convert to float16 tensors (separate names: the parameters are not rebound).
    branch_coords_t = torch.tensor(filtered_coords_padded.astype(np.float16))
    branch_values_t = torch.tensor(filtered_values_padded.astype(np.float16))
    output_sensor_coords_t = torch.tensor(filtered_output_coords_padded.astype(np.float16))
    output_sensor_values_t = torch.tensor(filtered_output_values_padded.astype(np.float16))

    # Only the first value channel of the branch values is passed on.
    val_dataset = DeepONetDatasetTrain(branch_coords_t, branch_values_t[:, :, 0:1],
                                       output_sensor_coords_t, output_sensor_values_t, device=device)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers, pin_memory=True)

    return val_loader

In [8]:
# Number of entries in the first row of tt_sub that lie below 15.
shifts_possible = np.sum(tt_sub[0] < 15)

In [None]:
import time

# Evaluate the model for every shifted start time and record MSE / MAE per start time.
skip_factor = 1
mse_per_timestep = []
mae_per_timestep = []

# Start times to evaluate. Materialising the selection first gives tqdm the
# exact length; int(shifts_possible / skip_factor) under-counts whenever
# skip_factor does not divide the number of candidates.
eval_timesteps = tt_sub[0][tt_sub[0] < 15][::skip_factor]

# Inference mode only needs to be set once, not on every iteration.
model.eval()

for timestep in tqdm(eval_timesteps, total=len(eval_timesteps), desc="Processing timesteps"):
    tqdm.write(f"Current Timestep: {timestep:.2f}")  # Print current timestep dynamically

    # Timer for validation loader creation
    start_val_loader = time.time()
    val_loader = create_val_loader_for_shifted_t(
        timestep, branch_coords, branch_values, output_sensor_coords, output_sensor_values
    )
    val_loader_time = time.time() - start_val_loader
    # tqdm.write(f"Validation Loader Creation Time: {val_loader_time:.4f} seconds")

    losses = []  # per-batch MSE
    mae_errors = []  # per-batch MAE
    batch_sizes = []  # samples per batch, used to weight the averages below

    # Timer for torch.no_grad section
    start_no_grad = time.time()
    with torch.no_grad():
        for branch_coords_, branch_values_, trunk_coords_, trunk_values_ in val_loader:
            branch_coords_ = branch_coords_.to(device)
            branch_values_ = branch_values_.to(device)
            trunk_coords_ = trunk_coords_.to(device)
            trunk_values_ = trunk_values_.to(device)

            # Forward pass with mixed precision.
            # NOTE(review): torch.cuda.amp.autocast is deprecated in recent
            # PyTorch in favour of torch.amp.autocast("cuda") -- confirm the
            # installed version before switching.
            with torch.cuda.amp.autocast():
                # Only the (x, t) columns are fed to the model; the ID column is dropped.
                rho_pred, sigma = model(branch_coords_[:, :, :2], branch_values_, trunk_coords_)
                loss = criterion(rho_pred, trunk_values_)
                mae = F.l1_loss(rho_pred, trunk_values_, reduction='mean')  # Calculate MAE

            losses.append(loss.item())
            mae_errors.append(mae.item())
            batch_sizes.append(trunk_values_.shape[0])
    no_grad_time = time.time() - start_no_grad
    # tqdm.write(f"Torch.no_grad Section Time: {no_grad_time:.4f} seconds")

    # Each batch metric is already a mean over its batch; weighting by batch
    # size gives the true per-sample mean even when the last batch is smaller.
    validation_loss = np.average(losses, weights=batch_sizes)
    mean_absolute_error = np.average(mae_errors, weights=batch_sizes)

    mse_per_timestep.append(validation_loss)
    mae_per_timestep.append(mean_absolute_error)


In [None]:
# Validation MSE for each evaluated start time, on a log-scaled y axis.
plt.plot(mse_per_timestep)
# Axis labels so the figure can be read on its own.
plt.xlabel('start-time index')
plt.ylabel('validation MSE')
print(np.mean(mse_per_timestep))
# plt.plot(mse_per_timestep_receding)
plt.yscale('log')

In [10]:
def filter_time_window(t_start, branch_coords, branch_values, output_sensor_coords, output_sensor_values, idx):
    """Select one scenario's branch inputs and output sensors for the window at ``t_start``.

    With ``t_now = t_start + T_PAST`` the function keeps, for scenario ``idx``:

    * probe rows whose ID belongs to a strided subset of the probe IDs seen in
      ``[t_now - T_PAST, t_now + T_PRED]``, restricted to times in
      ``[t_now - T_PAST, t_now]``;
    * boundary rows (ID == -1) with ``x >= 4`` and time in
      ``[t_now - T_PAST, t_now + T_PRED]``;
    * output sensors with time in ``[t_now - T_PAST, t_now + T_PRED]``.

    The selected probe IDs are printed. Times are not shifted.

    Returns:
        ``(filtered_coords, filtered_values, output_sensor_coords_filtered,
        output_sensor_values_filtered)``.
    """
    t_now = t_start + T_PAST
    scenario = branch_coords[idx]

    # Probe IDs that appear inside the full window; rows with ID -2 are dropped first.
    non_excluded = scenario[branch_coords[idx, :, 2] != -2]
    cand_times = non_excluded[:, 1]
    windowed = non_excluded[(cand_times <= t_now + T_PRED) & (cand_times >= t_now - T_PAST)]
    probe_ids = np.unique(windowed[:, 2])
    probe_ids = probe_ids[(probe_ids != -1) & (probe_ids != -2)]

    # Take every stride-th ID, where the stride is len // n_probes (at least 1).
    stride = max(1, len(probe_ids) // n_probes)
    sampled_ids = probe_ids[::stride]
    print(sampled_ids)

    row_times = scenario[:, 1]
    row_ids = scenario[:, 2]

    # Rows of the sampled probes, past part of the window only.
    probe_rows = np.isin(row_ids, sampled_ids) & (row_times <= t_now) & (row_times >= t_now - T_PAST)

    # Boundary rows with x >= 4 anywhere in the full window.
    boundary_rows = (row_ids == -1) & (scenario[:, 0] >= 4) & (row_times <= t_now + T_PRED) & (row_times >= t_now - T_PAST)

    # Every matching boundary row is selected; the draw only permutes the indices.
    boundary_idx = np.where(boundary_rows)[0]
    if len(boundary_idx) > 0:
        chosen = np.random.choice(boundary_idx, size=len(boundary_idx), replace=False)
    else:
        chosen = []
    boundary_keep = np.zeros(boundary_rows.shape, dtype=bool)
    boundary_keep[chosen] = True

    keep = probe_rows | boundary_keep
    filtered_coords = scenario[keep]
    filtered_values = branch_values[idx][keep]

    # Output sensors that fall inside the full window.
    out_coords = output_sensor_coords[idx]
    out_times = out_coords[:, 1]
    out_keep = (out_times <= t_now + T_PRED) & (out_times >= t_now - T_PAST)
    output_sensor_coords_filtered = out_coords[out_keep]
    output_sensor_values_filtered = output_sensor_values[idx][out_keep]

    return filtered_coords, filtered_values, output_sensor_coords_filtered, output_sensor_values_filtered