In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import time
import plotly.graph_objs as go
from sklearn.model_selection import train_test_split

from torch.cuda.amp import autocast, GradScaler
torch.cuda.empty_cache()
import plotly.graph_objects as go


import importlib
import tools_torch
importlib.reload(tools_torch)
from tools_torch import *


# Pick the first CUDA GPU when one is visible to PyTorch, otherwise run on CPU.
_cuda_visible = torch.cuda.is_available()
if _cuda_visible:
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f"using {device}")

In [None]:
# Load the SUMO/IDM receding-horizon dataset.
# np.load on an .npz returns an NpzFile that keeps the archive open; using it as a
# context manager closes the handle once every array has been copied out.
# NOTE(review): this file is tagged tpast2/tpred8, while the checkpoint restored in
# the model cell is named "..._past4_pred6_...". Confirm the horizons are compatible.
with np.load('../datasets/sumo/receding_sumo_idm_dataset_tpast2_tpred8.npz') as data:
    # torch.tensor copies each array, so these tensors remain valid after the file closes.
    branch_coords = torch.tensor(data['branch_coords'])
    branch_values = torch.tensor(data['branch_values'])
    output_sensor_coords = torch.tensor(data['output_sensor_coords'])
    output_sensor_values = torch.tensor(data['output_sensor_values'])
    rho = torch.tensor(data['rho'])

    # Scalar metadata stored as 0-d arrays; .item() converts them to Python scalars.
    Nx = data['Nx'].item()
    Nt = data['Nt'].item()
    Xmax = data['Xmax'].item()
    Tmax = data['Tmax'].item()
    N = data['N'].item()

    t_starts = torch.tensor(data['t_starts'])
    t_pred = torch.tensor(data['t_pred'])
    t_past = torch.tensor(data['t_past'])


print(f"Nx = {Nx}, Nt = {Nt}, Xmax = {Xmax}, Tmax = {Tmax}, N = {N}")
print(f"branch_coords.shape = {branch_coords.shape}, branch_values.shape = {branch_values.shape}, output_sensor_coords.shape = {output_sensor_coords.shape}, ")
print(f"t_starts = {t_starts.shape}")


In [7]:
# Seed NumPy's global RNG for reproducibility. The split below is seeded
# separately through its own random_state argument.
np.random.seed(42)

# Fraction of samples held out for validation.
validation_percentage = 0.2

# All per-sample tensors are split with the same shuffled indices so their rows
# stay aligned. train_test_split returns a (train, val) pair for each input, in
# the order the inputs are passed.
_arrays_to_split = (
    branch_coords,
    branch_values,
    output_sensor_coords,
    output_sensor_values,
    rho,
    t_starts,
)
(
    branch_coords_train, branch_coords_val,
    branch_values_train, branch_values_val,
    output_sensor_coords_train, output_sensor_coords_val,
    output_sensor_values_train, output_sensor_values_val,
    rho_train, rho_val,
    t_starts_train, t_starts_val,
) = train_test_split(*_arrays_to_split, test_size=validation_percentage, random_state=42)

In [None]:
# Create dataset and dataloader
batch_size = 32

# NOTE(review): these datasets keep only the first branch-value channel
# (branch_values[:, :, 0:1]), while the evaluation datasets built later pass all
# channels (0:). Confirm which one matches the model's expected branch input.
train_dataset = DeepONetDatasetTrain(branch_coords_train, branch_values_train[:, :, 0:1], output_sensor_coords_train, output_sensor_values_train, device=device)  # branch_coord, branch_values, trunk_coords, targets
val_dataset = DeepONetDatasetTrain(branch_coords_val, branch_values_val[:, :, 0:1], output_sensor_coords_val, output_sensor_values_val, device=device)

# The datasets are built with device=device. If they already hand out CUDA
# tensors, worker subprocesses and pin_memory both fail (CUDA cannot be used in
# forked workers, and only CPU tensors can be pinned). In that case fall back to
# single-process loading without pinning.
_yields_cuda = getattr(train_dataset[0][0], "is_cuda", False)
loader_kwargs = {"num_workers": 0, "pin_memory": False} if _yields_cuda else {"num_workers": 8, "pin_memory": True}

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, **loader_kwargs)
# Validation order does not matter for metrics; no shuffling keeps it deterministic.
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, **loader_kwargs)

# Sanity check: inspect one batch from each loader.
xs, us, ys, ss = next(iter(train_loader))
print(f"train shapes\t xs: {xs.shape}, us: {us.shape}, ys: {ys.shape}, ss: {ss.shape}, device: {xs.device}")
xs, us, ys, ss = next(iter(val_loader))
print(f"val shapes\t xs: {xs.shape}, us: {us.shape}, ys: {ys.shape}, ss: {ss.shape}, device: {xs.device}")


In [None]:
# Define the model
from vidon_model import VIDON, FDLearner

p = 400  # passed to VIDON as `p`; meaning defined in vidon_model — TODO document
# .to(device) on this line already places the model; no second call is needed.
model = VIDON(p=p, num_heads=4, d_branch_input=3, d_v=2, use_linear_decoder=False, UQ=True).to(device)
FD = FDLearner(d_FD=50)
FD.to(device)

# Define the loss function
criterion = nn.MSELoss()

# map_location remaps stored tensors onto the current device, so a checkpoint
# saved from a GPU run can also be restored on a CPU-only machine.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints
# (recent PyTorch versions offer weights_only=True to restrict this).
# NOTE(review): the checkpoint name says past4/pred6, while the dataset file is
# tpast2/tpred8 — confirm they are compatible.
# NOTE(review): only the VIDON weights are restored; FD keeps its fresh initialization.
model.load_state_dict(torch.load("model_sumo_19_11_past4_pred6_2_phys_UQ.pth", map_location=device))

In [10]:
batch_size_test = batch_size

# Evaluation datasets. Unlike the training-style datasets, these receive every
# branch-value channel (0:) plus rho and the per-sample t_starts.
# NOTE(review): the training-style datasets slice branch values to 0:1 — confirm
# the difference is intentional.
train_dataset_test = DeepONetDataset(
    branch_coords_train[:, :, :],
    branch_values_train[:, :, 0:],
    output_sensor_coords_train,
    output_sensor_values_train,
    rho_train,
    device=device,
    t_start=t_starts_train,
)
train_loader_test = DataLoader(train_dataset_test, batch_size=batch_size_test, shuffle=True)

val_dataset_test = DeepONetDataset(
    branch_coords_val[:, :, :],
    branch_values_val[:, :, 0:],
    output_sensor_coords_val,
    output_sensor_values_val,
    rho_val,
    device=device,
    t_start=t_starts_val,
)
val_loader_test = DataLoader(val_dataset_test, batch_size=batch_size_test, shuffle=False)

In [None]:
# validate model on validation set
import torch.nn.functional as F

model.eval()
FD.eval()  # FD is not called in this loop; set to eval for consistency

# Per-batch metrics (kept for inspection) and the batch sizes used to weight them.
losses = []
mae_errors = []
batch_sizes = []

# Mixed precision only applies on CUDA. On CPU, the legacy torch.cuda.amp.autocast()
# merely warns and does nothing, so autocast is disabled explicitly there.
use_amp = device.type == "cuda"

with torch.no_grad():
    # Loop variables carry a `val_` prefix so they do not overwrite the
    # full-dataset tensors (branch_coords, branch_values, ...) from the load cell.
    # rho and t_start are yielded by the dataset but are not needed for these metrics.
    for val_branch_coords, val_branch_values, val_trunk_coords, val_trunk_values, _rhos, _tstarts in tqdm(val_loader_test, desc="validation"):
        val_branch_coords = val_branch_coords.to(device)
        val_branch_values = val_branch_values.to(device)
        val_trunk_coords = val_trunk_coords.to(device)
        val_trunk_values = val_trunk_values.to(device)

        # Forward pass with mixed precision (CUDA only)
        with torch.autocast(device_type=device.type, enabled=use_amp):
            # The model returns (prediction, sigma); sigma is unused here.
            rho_pred, sigma = model(val_branch_coords, val_branch_values, val_trunk_coords)
            loss = criterion(rho_pred, val_trunk_values)
            mae = F.l1_loss(rho_pred, val_trunk_values, reduction='mean')

        losses.append(loss.item())
        mae_errors.append(mae.item())
        batch_sizes.append(val_trunk_values.shape[0])

if not batch_sizes:
    raise ValueError("val_loader_test yielded no batches; cannot compute validation metrics")

# Weight each batch mean by its batch size, so a smaller final batch does not skew
# the dataset-level mean. This assumes every sample has the same number of target
# elements, which holds when the targets come from one stacked tensor.
validation_loss = float(np.average(losses, weights=batch_sizes))
mean_absolute_error = float(np.average(mae_errors, weights=batch_sizes))

print(f"Validation loss: {validation_loss}")
print(f"Mean Absolute Error (MAE): {mean_absolute_error}")
