In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
from godunov_vis_tools import * 
from tqdm import tqdm

In [None]:
# Open the combined dataset archive.
# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can execute
# arbitrary code; only use it on archives from a trusted source.
data = np.load('../datasets/godunov_100_combined.npz', allow_pickle=True)

# Branch inputs and output-sensor targets
branch_coords, branch_values = data['branch_coords'], data['branch_values']
output_sensor_coords, output_sensor_values = data['output_sensor_coords'], data['output_sensor_values']

# Field and axes (the archive's 'v' entry is intentionally not loaded)
rho, x, tt = data['rho'], data['x'], data['t']

# Metadata entries (the archive's 'keys' entry is intentionally not loaded)
Nx, Nt = data['Nx'], data['Nt']
Xmax, Tmax = data['Xmax'], data['Tmax']
P, N = data['P'], data['N']

# Report array shapes and metadata values as a quick sanity check
print(f"branch_coords.shape = {branch_coords.shape}, branch_values.shape = {branch_values.shape}, output_sensor_coords.shape = {output_sensor_coords.shape}, ")
print(f"output_sensor_values.shape = {output_sensor_values.shape}, rho.shape = {rho.shape}, x.shape = {x.shape}, t.shape = {tt.shape}")
print(f"Nx = {Nx}, Nt = {Nt}, Xmax = {Xmax}, Tmax = {Tmax}, P = {P}, N = {N}")

In [None]:
# Report on the contents of the freshly loaded archive. `memory_stats_npz` is not defined
# in this notebook; presumably it comes from `from godunov_vis_tools import *` — TODO confirm.
memory_stats_npz(data)

In [None]:
# --- Receding-horizon window configuration -----------------------------------
# For every scenario a random start time `t_start` is drawn; the window kept is
# [t_start, t_start + T_PAST + T_PRED], with the reference time t = t_start + T_PAST.
n_probes = 8   # target number of probe IDs per scenario (approximate, see sampling note below)
T_PRED = 8     # horizon length after the reference time t
T_PAST = 2     # history length before the reference time t

# Per-scenario outputs collected by the loop
branch_coords_filtered_list, branch_values_filtered_list = [], []
output_sensor_coords_filtered_list, output_sensor_values_filtered_list = [], []
shapes_branch, shapes_trunk = [], []
shapes_branch_neg_1 = []
shapes_branch_other = []
t_starts = []

# Running extrema of the number of branch rows kept per scenario
max_id = 0
min_id = float("inf")

# Candidate start times: every 5th time sample that still leaves room for the full
# window. The stride of 5 is repeated where `rho`/`tt` are sub-sampled before saving,
# so both places must be changed together.
tt_sub = tt[:, ::5]
tt_starts = tt_sub[tt_sub <= Tmax - T_PRED - T_PAST]

seed = 0
# set the seed for reproducibility
np.random.seed(seed)

# for every scenario
for idx in tqdm(range(len(branch_coords))):

    # Randomly pick the window start; t is the boundary between "past" and "prediction"
    t_start = np.random.choice(tt_starts)
    t = t_start + T_PAST
    t_starts.append(t_start)

    # Column layout used throughout this loop: 0 -> x, 1 -> t, 2 -> ID
    scenario_coords = branch_coords[idx]
    scenario_times = scenario_coords[:, 1]
    scenario_ids = scenario_coords[:, 2]

    in_horizon = (scenario_times <= t + T_PRED) & (scenario_times >= t - T_PAST)
    in_past = (scenario_times <= t) & (scenario_times >= t - T_PAST)

    # Drop rows whose ID is -2, then keep only rows inside the full window
    coords_probes_boundary = scenario_coords[scenario_ids != -2]
    coords_in_horizon = coords_probes_boundary[
        (coords_probes_boundary[:, 1] <= t + T_PRED) & (coords_probes_boundary[:, 1] >= t - T_PAST)
    ]

    # Probe IDs present in the window: every ID except the special values -1 and -2
    unique_ids = np.unique(coords_in_horizon[:, 2])
    unique_ids = unique_ids[(unique_ids != -1) & (unique_ids != -2)]

    # Evenly strided selection of IDs.
    # NOTE: the stride is floor(len / n_probes) clamped to >= 1, so the number of IDs
    # kept can exceed n_probes (e.g. 15 IDs -> stride 1 -> all 15 are kept).
    sampled_ids = unique_ids[::max(1, len(unique_ids) // n_probes)]

    # Probe rows: sampled IDs, restricted to the past part of the window [t - T_PAST, t]
    mask_sampled_ids = np.isin(scenario_ids, sampled_ids) & in_past

    # Rows with ID == -1
    mask_boundary = scenario_ids == -1

    # ID == -1 rows kept: x >= 4 and t inside the full window.
    # (The variable name says "above_5", but the threshold actually applied is 4.)
    mask_id_neg1_x_above_5 = mask_boundary & (scenario_coords[:, 0] >= 4) & in_horizon

    # All qualifying ID == -1 rows are kept (size == len, no replacement), so this is only
    # a permutation of the indices and the resulting mask equals mask_id_neg1_x_above_5.
    # The call is retained on purpose: it consumes RNG state, and removing it would change
    # the t_start values drawn for all later scenarios.
    neg1_indices = np.where(mask_id_neg1_x_above_5)[0]
    # half_neg1_indices = np.random.choice(neg1_indices, size=len(neg1_indices) // 2, replace=False) if len(neg1_indices) > 0 else []
    half_neg1_indices = np.random.choice(neg1_indices, size=len(neg1_indices), replace=False) if len(neg1_indices) > 0 else []

    mask_half_neg1 = np.zeros(mask_id_neg1_x_above_5.shape, dtype=bool)
    mask_half_neg1[half_neg1_indices] = True

    # Union of probe rows and ID == -1 rows
    final_mask = mask_sampled_ids | mask_half_neg1

    # Boolean-mask indexing returns copies, so the time shift below does not touch
    # the original branch_coords array.
    filtered_coords = scenario_coords[final_mask]
    filtered_values = branch_values[idx][final_mask]

    # shift so that the window starts at t = 0
    filtered_coords[:, 1] -= t_start

    branch_coords_filtered_list.append(filtered_coords)
    branch_values_filtered_list.append(filtered_values)

    # Output sensors: keep rows whose time lies in the full window (mask computed once
    # and shared by coordinates and values so the two stay row-aligned)
    trunk_coords = output_sensor_coords[idx]
    trunk_in_horizon = (trunk_coords[:, 1] <= t + T_PRED) & (trunk_coords[:, 1] >= t - T_PAST)
    output_sensor_coords_filtered = trunk_coords[trunk_in_horizon]
    output_sensor_values_filtered = output_sensor_values[idx][trunk_in_horizon]

    # shift so that the window starts at t = 0
    output_sensor_coords_filtered[:, 1] -= t_start

    output_sensor_coords_filtered_list.append(output_sensor_coords_filtered)
    output_sensor_values_filtered_list.append(output_sensor_values_filtered)

    shapes_branch.append(filtered_coords.shape[0])
    shapes_trunk.append(output_sensor_coords_filtered.shape[0])

    # Row counts per category. Everything kept that is not ID == -1 is a probe row;
    # comparing with `!= -1` (rather than `> 0`) also counts a probe whose ID is 0,
    # so count_neg_1 + count_other always equals the number of rows kept.
    count_neg_1 = np.sum(filtered_coords[:, 2] == -1)
    count_other = np.sum(filtered_coords[:, 2] != -1)

    shapes_branch_neg_1.append(count_neg_1)
    shapes_branch_other.append(count_other)

    # Track the largest and smallest number of branch rows kept for any scenario
    max_id = max(max_id, filtered_coords.shape[0])
    min_id = min(min_id, filtered_coords.shape[0])

print(f"max_id = {max_id}, min_id = {min_id}")


In [None]:
# Sorted per-scenario row counts, one panel per category.
fig, ax = plt.subplots(1, 4, figsize=(15, 5))

# Total branch rows per scenario. The two counts are added scenario-by-scenario BEFORE
# sorting; adding two independently sorted lists would pair counts that belong to
# different scenarios and overstate the upper end of the curve.
branch_totals = np.array(shapes_branch_other) + np.array(shapes_branch_neg_1)

panels = [
    (sorted(shapes_branch_neg_1), "# ID == -1"),
    (sorted(shapes_branch_other), "# ID > 0"),
    (np.sort(branch_totals), "# ID > 0 + # ID == -1"),
    (sorted(shapes_trunk), "# Trunk sensors"),
]
for axis, (counts, title) in zip(ax, panels):
    axis.plot(counts)
    axis.set_title(title)

In [6]:
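# NOTE(review): disabled reference copy of `pad_to_shape_branch`; the live function called in the next cell is presumably provided by `from godunov_vis_tools import *` — confirm.
# def pad_to_shape_branch(coords, values, target_shape_boundary, target_shape_probe):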
#     # Separate boundary and probe data based on ID in coords
#     boundary_data_coords = coords[coords[:, 2] == -1]
#     probe_data_coords = coords[coords[:, 2] != -1]
    
#     boundary_data_values = values[coords[:, 2] == -1]
#     probe_data_values = values[coords[:, 2] != -1]
    
#     # Truncate boundary data to target_shape_boundary (this step only slices; it never pads)
#     filtered_boundary_coords = boundary_data_coords[:target_shape_boundary]
#     filtered_boundary_values = boundary_data_values[:target_shape_boundary]
    
#     # Truncate or pad probe data to target_shape_probe
#     if probe_data_coords.shape[0] < target_shape_probe:
#         # Pad if probe data is smaller than target_shape_probe
#         pad_size = target_shape_probe - probe_data_coords.shape[0]
        
#         # Generate random values from existing probe data
#         random_indices = np.random.choice(probe_data_coords.shape[0], size=pad_size)
#         random_coords = probe_data_coords[random_indices]
#         random_values = probe_data_values[random_indices]
        
#         # Concatenate original and random padded data
#         filtered_probe_coords = np.concatenate([probe_data_coords, random_coords], axis=0)
#         filtered_probe_values = np.concatenate([probe_data_values, random_values], axis=0)
#     else:
#         # Truncate if probe data is larger than target_shape_probe
#         filtered_probe_coords = probe_data_coords[:target_shape_probe]
#         filtered_probe_values = probe_data_values[:target_shape_probe]

#     # Combine boundary and probe data back together
#     filtered_coords = np.vstack([filtered_boundary_coords, filtered_probe_coords])
#     filtered_values = np.vstack([filtered_boundary_values, filtered_probe_values])

#     return filtered_coords, filtered_values

In [7]:
# Fixed per-scenario row budgets passed to pad_to_shape_branch: rows with ID == -1
# ("boundary") and all remaining rows ("probe").
# NOTE(review): 340 / 188 are hard-coded; presumably chosen from the row-count plots
# of this dataset — re-check them whenever the window or probe settings change.
TARGET_SHAPE_BOUNDARY = 340
TARGET_SHAPE_PROBE = 188

# Bring every scenario's branch data to a common shape
filtered_coords_values = [
    pad_to_shape_branch(
        coords,
        values,
        target_shape_boundary=TARGET_SHAPE_BOUNDARY,
        target_shape_probe=TARGET_SHAPE_PROBE,
    )
    for coords, values in zip(branch_coords_filtered_list, branch_values_filtered_list)
]

# Stack into (n_scenarios, rows, ...) arrays. np.stack raises immediately if any
# scenario came back with a different row count, instead of producing a ragged result.
filtered_coords_padded = np.stack([item[0] for item in filtered_coords_values])
filtered_values_padded = np.stack([item[1] for item in filtered_coords_values])

# Smallest number of output-sensor rows found in any scenario
m_min = min(arr.shape[0] for arr in output_sensor_coords_filtered_list)

# Keep the LAST m_min rows of every scenario. An explicit start index is used because
# arr[-m_min:] returns the whole array when m_min == 0 (since -0 == 0).
filtered_output_coords_padded = np.stack(
    [arr[arr.shape[0] - m_min:] for arr in output_sensor_coords_filtered_list]
)
filtered_output_values_padded = np.stack(
    [arr[arr.shape[0] - m_min:] for arr in output_sensor_values_filtered_list]
)

In [None]:
# Sanity check: display the shapes of the first scenario's (coords, values) pair
# as returned by pad_to_shape_branch.
filtered_coords_values[0][0].shape, filtered_coords_values[0][1].shape

In [None]:
# Visual check of the first two scenarios: one scatter of the branch inputs and one of
# the output sensors, both drawn as (t, x) points coloured by value on a fixed 0..1 scale.
time_window = [0, T_PRED + T_PAST]
space_window = [0, Xmax]

for i in range(2):
    branch_pts = filtered_coords_padded[i]
    trunk_pts = filtered_output_coords_padded[i]

    # Branch inputs (column 1 on the horizontal axis, column 0 on the vertical axis)
    plt.figure()
    plt.scatter(
        branch_pts[:, 1],
        branch_pts[:, 0],
        c=filtered_values_padded[i, :, 0],
        cmap='jet',
        vmin=0,
        vmax=1,
    )
    plt.xlim(time_window)
    plt.ylim(space_window)
    plt.title(f"t_start = {t_starts[i]}")

    # Output sensors; the title reports whether the sampled start time occurs in `tt`
    plt.figure()
    plt.title(f"t_start in tt {t_starts[i] in tt}")
    plt.scatter(
        trunk_pts[:, 1],
        trunk_pts[:, 0],
        c=filtered_output_values_padded[i, :],
        cmap='jet',
        vmin=0,
        vmax=1,
    )
    plt.xlim(time_window)
    plt.ylim(space_window)

In [None]:
# Print the shapes of the stacked branch arrays (coordinates and values) that will be saved
print(f"filtered_coords_padded.shape = {filtered_coords_padded.shape}, filtered_values_padded.shape = {filtered_values_padded.shape}")

In [None]:
# Sub-sample `rho` and the time array by keeping every 5th entry along the last axis
rho_sub = rho[:, :, ::5]
t_sub = tt[:, ::5]

# Arrays stored in half precision, listed in the order they are written.
# NOTE(review): float16 carries roughly 3 significant decimal digits and represents
# integers exactly only up to 2048 — confirm this is sufficient for the coordinate and
# ID columns before relying on the saved values.
half_precision_arrays = {
    'branch_coords': filtered_coords_padded,
    'branch_values': filtered_values_padded,
    'output_sensor_coords': filtered_output_coords_padded,
    'output_sensor_values': filtered_output_values_padded,
    'rho': rho_sub,
    # 'v' is intentionally not saved
    'x': x,
    't': t_sub,
}
filtred_simulation_data = {
    name: arr.astype(np.float16) for name, arr in half_precision_arrays.items()
}

# Metadata stored as-is; 'Nt' reflects the sub-sampled time axis of rho_sub
filtred_simulation_data.update({
    'Nx': Nx,
    'Nt': rho_sub.shape[-1],
    'Xmax': Xmax,
    'Tmax': Tmax,
    'P': P,
    'N': N,
    # 'keys' is intentionally not saved
    'number_of_probes': n_probes,
    't_pred': T_PRED,
    't_past': T_PAST,
    't_starts': t_starts,
})

# Write everything to a single .npz archive whose name encodes the window settings
new_path = f'../datasets/godunov_combined_tpast{T_PAST}_tpred{T_PRED}_receding.npz'
print(f"Saving the filtered data to {new_path}")
np.savez(new_path, **filtred_simulation_data)

In [None]:
# Re-open the archive that was just written to `new_path` as a round-trip check.
# NOTE(review): allow_pickle=True permits unpickling object arrays; it may be unnecessary
# if every saved entry is a plain numeric array — confirm before removing.
data_filtered = np.load(new_path, allow_pickle=True)

# Report on the saved archive using the same helper applied to the raw dataset
memory_stats_npz(data_filtered)