In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import time
from sklearn.model_selection import train_test_split

from torch.cuda.amp import autocast, GradScaler
# Ask PyTorch's CUDA caching allocator to release memory it holds but is not using.
# Handy when this cell is re-run in a live kernel after an earlier training run.
torch.cuda.empty_cache()

# Re-import the project-local helper module so edits to it are picked up without
# restarting the kernel. The reload must happen BEFORE the star import below,
# otherwise the notebook namespace would keep the stale definitions.
import importlib
import tools_torch
importlib.reload(tools_torch)
# NOTE(review): the star import hides which helpers this notebook depends on.
# Presumably DeepONetDatasetTrain, sample_trunk_inputs and
# sample_branch_inputs_keep_boundary (used in later cells) come from here -- confirm
# and consider importing them by name.
from tools_torch import *

# Run on the first CUDA device when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f"using {device}")

# load dataset

In [None]:
# Open the pre-generated dataset archive (path is relative to the notebook location).
data = np.load("../datasets/sumo/sumo_idm_dataset_tpast2_tpred8.npz")


def _archive_tensor(key):
    """Return the array stored under `key` in the archive as a torch tensor (copy)."""
    return torch.tensor(data[key])


def _archive_scalar(key):
    """Return the one-element entry stored under `key` as a plain Python scalar."""
    return data[key].item()


# Array-valued entries: branch (input) sensors, output sensors, the `rho` array
# and the per-sample start times.
branch_coords = _archive_tensor('branch_coords')
branch_values = _archive_tensor('branch_values')
output_sensor_coords = _archive_tensor('output_sensor_coords')
output_sensor_values = _archive_tensor('output_sensor_values')
rho = _archive_tensor('rho')
t_starts = _archive_tensor('t_starts')

# Scalar metadata stored alongside the arrays.
Nx = _archive_scalar('Nx')
Nt = _archive_scalar('Nt')
Xmax = _archive_scalar('Xmax')
Tmax = _archive_scalar('Tmax')
N = _archive_scalar('N')
t_pred = _archive_scalar('t_pred')
t_past = _archive_scalar('t_past')

# Uniform time axis with Nt points covering [0, Tmax].
t = np.linspace(0, Tmax, Nt)

# Quick sanity print of the metadata and tensor shapes.
print(f"Nx = {Nx}, Nt = {Nt}, Xmax = {Xmax}, Tmax = {Tmax}, N = {N}")
print(f"branch_coords.shape = {branch_coords.shape}, branch_values.shape = {branch_values.shape}, output_sensor_coords.shape = {output_sensor_coords.shape}, ")
print(f"output_sensor_values.shape = {output_sensor_values.shape}")
print(f"t_starts = {t_starts.shape}")
print(f", rho.shape = {rho.shape}")

In [None]:
# Seed every RNG the notebook relies on. NumPy alone is not enough: weight
# initialisation and DataLoader shuffling draw from torch's generator, so without
# torch.manual_seed two "Restart & Run All" runs give different results.
# NOTE(review): some CUDA kernels are still non-deterministic; this makes runs
# repeatable up to that, not bit-identical.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Fraction of the samples held out for validation.
validation_percentage = 0.2
# Number of samples left for training (printed as a float).
print((1-validation_percentage)*N)

# Split all per-sample arrays with the same permutation so the rows stay aligned
# across the *_train / *_val variables.
(
    branch_coords_train, branch_coords_val,
    branch_values_train, branch_values_val,
    output_sensor_coords_train, output_sensor_coords_val,
    output_sensor_values_train, output_sensor_values_val,
    rho_train, rho_val,
    t_starts_train, t_starts_val,
) = train_test_split(
    branch_coords, branch_values, output_sensor_coords, output_sensor_values, rho, t_starts,
    test_size=validation_percentage, random_state=SEED,
)

In [None]:
# Build the training / validation datasets and their loaders.
batch_size = 32

# Positional arguments: branch coordinates, branch values, trunk (output sensor)
# coordinates, targets.
# NOTE(review): the dataset receives `device` while the loaders use worker
# processes and pinned memory. If DeepONetDatasetTrain moves tensors to the GPU
# itself, num_workers > 0 / pin_memory=True can fail -- confirm against tools_torch.
train_dataset = DeepONetDatasetTrain(
    branch_coords_train, branch_values_train,
    output_sensor_coords_train, output_sensor_values_train,
    device=device,
)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)

val_dataset = DeepONetDatasetTrain(
    branch_coords_val, branch_values_val,
    output_sensor_coords_val, output_sensor_values_val,
    device=device,
)
# Validation data is not shuffled: evaluation does not benefit from it, and a fixed
# order keeps the per-batch means (and the preview batch below) identical between runs.
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

# Pull one batch from each loader as a shape / device sanity check.
# After this cell xs, us, ys, ss hold the VALIDATION batch (the next cell plots it).
xs, us, ys, ss = next(iter(train_loader))
print(f"train shapes\t xs: {xs.shape}, us: {us.shape}, ys: {ys.shape}, ss: {ss.shape}, device: {xs.device}")
xs, us, ys, ss = next(iter(val_loader))
print(f"val shapes\t xs: {xs.shape}, us: {us.shape}, ys: {ys.shape}, ss: {ss.shape}, device: {xs.device}")


In [None]:
# Index (within the preview batch xs/us/ys/ss) of the sample to visualise.
idx = 1


def _to_numpy(tensor):
    """Copy a tensor to host memory and return it as a NumPy array."""
    return tensor.cpu().numpy()


# Figure 1: output-sensor locations coloured by their target value.
# Coordinate column 1 goes on the horizontal axis, column 0 on the vertical axis.
plt.scatter(
    _to_numpy(ys[idx, :, 1]),
    _to_numpy(ys[idx, :, 0]),
    c=_to_numpy(ss[idx, :]),
    label='output sensors',
    cmap='jet',
    vmin=0,
    vmax=1,
)
plt.colorbar()

# Figure 2: branch-sensor locations coloured by the first branch value channel.
plt.figure()
plt.scatter(
    _to_numpy(xs[idx, :, 1]),
    _to_numpy(xs[idx, :, 0]),
    c=_to_numpy(us[idx, :, 0]),
    label='branch sensors',
    cmap='jet',
    vmin=0,
    vmax=1,
)
# Start both axes at zero.
plt.ylim(0)
plt.xlim(0)
plt.colorbar()

# model definition

In [None]:
# Build the operator network and the FD learner, then set up loss, optimiser,
# LR schedule and the AMP gradient scaler.
from vidon_model import VIDON, FDLearner

p = 400
model = VIDON(p=p, num_heads=4, d_branch_input=3, d_v=2, use_linear_decoder=False, UQ=True)
model = model.to(device)
FD = FDLearner(d_FD=50)
FD.to(device)

# Plain mean-squared error; the uncertainty-aware loss is assembled in the training loop.
criterion = nn.MSELoss()

num_epochs = 5

# A single AdamW instance updates both networks.
trainable_parameters = list(model.parameters()) + list(FD.parameters())
optimizer = optim.AdamW(trainable_parameters, lr=0.001)

# Halve the learning rate when the monitored metric has not improved for 50 epochs.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=50,
    cooldown=5,
    min_lr=1e-7,
    verbose=True
)

# Loss scaler for mixed-precision training.
scaler = torch.cuda.amp.GradScaler()


def _num_parameters(module):
    """Total number of scalar parameters registered on `module`."""
    return sum(tensor.numel() for tensor in module.parameters())


# Parameter budget per sub-network.
print(f"Number of parameters in coord encoder: {_num_parameters(model.branch.coord_encoder)}")
print(f"Number of parameters in value encoder: {_num_parameters(model.branch.value_encoder)}")
print(f"Number of parameters in combiner: {_num_parameters(model.branch.combiner)}")
print(f"Number of parameters in nonlinear decoder: {_num_parameters(model.nonlinear_decoder)}")
print(f"Number of parameters in trunk: {_num_parameters(model.trunk)}")
print(f"Number of parameters in FD:  \t{_num_parameters(FD)}")
print(f"Number of parameters in model: {_num_parameters(model)}")

In [7]:
# # freeze except for UQ layers
# for param in model.parameters():
#     param.requires_grad = False

# model.bias.requires_grad = False
# # # Unfreeze parameters for specific components
# for param in model.nonlinear_decoder_sigm.parameters():
#     param.requires_grad = True

# for param in model.branch.combiner_sigm.parameters():
#     param.requires_grad = True

# for param in model.trunk.sigm_trunk.parameters():
#     param.requires_grad = True

# model=model.float()
# optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.0001)


# training loop

In [None]:
# ---------------------------------------------------------------------------
# Training history. These stay notebook globals because later cells plot them.
# ---------------------------------------------------------------------------
gradient_norms = []      # gradient norm of model.parameters() before clipping, one per optimizer step
loss_list = []           # mean training MSE per epoch
loss_ic_list = []        # mean MSE at the branch-sensor locations per epoch (monitoring only)
val_loss_list = []       # mean validation objective per epoch (NLL if model.UQ, else MSE)
val_loss_mse_list = []   # mean validation MSE per epoch
phys_losses_list = []    # not filled by this cell
sigma_list = []          # mean predicted sigma per epoch
v_loss_list = []         # mean FD loss per epoch
lrs = []                 # learning rate in effect during each epoch
lambdas = []             # not filled by this cell
gammas = []              # not filled by this cell
val_loss = 0
# A checkpoint is written only once the validation MSE drops below this value.
lowest_val_loss = 0.1

T_PAST = 2
T_PRED = 8

previous_epoch_loss_data = float('inf')
previous_epoch_loss_lwr = float('inf')


def _gaussian_nll(mean, sigma, target):
    """Mean Gaussian negative log-likelihood of `target` under N(mean, sigma**2).

    Used by both the training and the validation pass so the two report the
    same quantity (`sigma` is a standard deviation in both places).
    """
    variance = sigma ** 2
    squared_error_term = (mean - target) ** 2 / (2 * variance)
    log_normaliser_term = 0.5 * torch.log(2 * torch.pi * variance)
    return torch.mean(squared_error_term + log_normaliser_term)


# NOTE(review): the loop variables below (`branch_values`, ...) overwrite the
# dataset-level globals of the same name created in the loading cell; re-running
# earlier cells after training requires reloading the data first.
pbar = tqdm(range(num_epochs))
for epoch in pbar:
    model.train()
    losses = []
    v_losses = []
    losses_ic = []
    sigmas = []

    T_START = None

    for branch_coord, branch_values, trunk_coords, targets in train_loader:

        # Bring the batch to the training device.
        branch_coord = branch_coord.to(device)
        branch_values = branch_values.to(device)
        trunk_coords = trunk_coords.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()

        # No filtering is applied at the moment; the aliases are kept for later cells.
        filtered_trunk_coords, filtered_targets = trunk_coords, targets
        filtered_branch_coords, filtered_branch_values = branch_coord, branch_values

        # Randomly subsample the trunk query points and the branch sensors.
        sampled_trunk_coords, sampled_targets = sample_trunk_inputs(filtered_trunk_coords, filtered_targets)
        sampled_branch_coord, sampled_branch_values = sample_branch_inputs_keep_boundary(filtered_branch_coords, filtered_branch_values, min_to_keep=0.8)

        # Fresh leaf tensor that tracks gradients w.r.t. the query coordinates.
        # NOTE(review): nothing below differentiates w.r.t. the coordinates --
        # presumably a leftover from a physics-residual term.
        sampled_trunk_coords = sampled_trunk_coords.clone().detach().requires_grad_(True)

        # Forward pass under automatic mixed precision.
        with torch.cuda.amp.autocast():

            # Prediction at the sampled trunk points; the second output is exponentiated to get sigma.
            outputs_rho, outputs_rho_sigm = model(sampled_branch_coord, sampled_branch_values, sampled_trunk_coords)
            outputs_rho_sigm = torch.exp(outputs_rho_sigm) + 1e-8
            # NOTE(review): computed but not used in any loss term.
            v_pred_phys = FD(outputs_rho.unsqueeze(-1))

            loss_UQ = _gaussian_nll(outputs_rho, outputs_rho_sigm, sampled_targets)
            mse_loss = criterion(outputs_rho, sampled_targets)

            # Prediction at the branch-sensor locations (first two coordinate columns).
            outputs_ic_rho, outputs_ic_sigm = model(sampled_branch_coord, sampled_branch_values, sampled_branch_coord[:,:,:2])
            # Monitored only -- loss_ic is NOT part of the optimised objective below.
            loss_ic = criterion(outputs_ic_rho, sampled_branch_values[:,:,0].squeeze())

            # Keep only the sensors whose third coordinate column is > -1.
            mask = sampled_branch_coord[:, :, 2] > -1

            outputs_ic_rho_probe = outputs_ic_rho[mask].reshape(-1, 1)
            sampled_branch_values_probe = sampled_branch_values[:, :, 1][mask].reshape(-1, 1)

            # FD maps the predicted first channel to the second branch value channel.
            v_pred_data = FD(outputs_ic_rho_probe)
            v_loss = criterion(v_pred_data, sampled_branch_values_probe)

            # Total objective: MSE plus down-weighted FD and NLL terms.
            loss = mse_loss + v_loss/1000 + loss_UQ/300

        losses.append(mse_loss.item())
        v_losses.append(v_loss.item())
        # Previously never appended, which made the epoch mean NaN.
        losses_ic.append(loss_ic.item())
        sigmas.append(torch.mean(outputs_rho_sigm).item())

        # Backward pass on the scaled loss.
        scaler.scale(loss).backward()

        # The gradients still carry the loss scale here; unscale them first so that
        # max_norm refers to the true gradient magnitude.
        scaler.unscale_(optimizer)
        # Only the VIDON parameters are clipped (as before); FD gradients are left untouched.
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=3.0)
        gradient_norms.append(grad_norm.item())

        # Optimizer step (skipped by the scaler on inf/NaN gradients) and scale update.
        scaler.step(optimizer)
        scaler.update()

    loss_list.append(np.mean(losses))
    loss_ic_list.append(np.mean(losses_ic))
    sigma_list.append(np.mean(sigmas))
    lrs.append(optimizer.param_groups[0]['lr'])

    t0 = time.time()
    # Validation pass (currently every epoch).
    if epoch % 1 == 0:
        model.eval()
        with torch.no_grad():
            val_losses = []
            val_losses_mse = []
            T_START = 0

            for branch_coord, branch_values, trunk_coords, targets in val_loader:
                branch_coord = branch_coord.to(device)
                branch_values = branch_values.to(device)
                trunk_coords = trunk_coords.to(device)
                targets = targets.to(device)

                # Evaluate on the VALIDATION batch. The previous code forwarded the
                # stale `filtered_*` tensors, i.e. the last training batch.
                with autocast():
                    if model.UQ:
                        outputs_rho, outputs_rho_sigm = model(branch_coord, branch_values, trunk_coords)
                        outputs_rho_sigm = torch.exp(outputs_rho_sigm) + 1e-8

                        batch_val_loss = _gaussian_nll(outputs_rho, outputs_rho_sigm, targets).item()
                        val_loss_mse = criterion(outputs_rho, targets).item()
                    else:
                        outputs = model(branch_coord[:,:,:2], branch_values, trunk_coords)
                        batch_val_loss = criterion(outputs, targets).item()
                        val_loss_mse = batch_val_loss

                val_losses.append(batch_val_loss)
                val_losses_mse.append(val_loss_mse)

        # Epoch-level validation metrics.
        val_loss = np.mean(val_losses)
        val_loss_list.append(val_loss)
        val_loss_mse_list.append(np.mean(val_losses_mse))
        v_loss_list.append(np.mean(v_losses))

    # Show epoch means in the progress bar.
    pbar.set_postfix({'Loss': np.mean(losses), 'Val Loss': val_loss, 'Val Loss MSE': np.mean(val_losses_mse), "FD loss": np.mean(v_losses)})

    # Update previous epoch loss values.
    previous_epoch_loss_data = np.mean(losses)

    # Checkpoint on the best validation MSE seen so far.
    # NOTE(review): only the VIDON weights are saved; FD is trained jointly but its
    # state_dict is not written anywhere in this cell.
    if np.mean(val_losses_mse) < lowest_val_loss:
        lowest_val_loss = np.mean(val_losses_mse)
        print(f"Saving model with lowest validation loss: {lowest_val_loss}")
        torch.save(model.state_dict(), 'model_sumo_19_11_past2_pred8_2_8000samples_UQ.pth')

    # The LR schedule follows the validation MSE.
    scheduler.step(np.mean(val_losses_mse))


In [None]:
# Visualise the subsampled branch sensors of up to three samples from the most
# recent training batch (`sampled_branch_*` are left over from the training loop).
# The final batch of an epoch may hold fewer than three samples, so cap the count
# instead of indexing past the end.
num_samples_to_plot = min(3, sampled_branch_coord.shape[0])
for i in range(num_samples_to_plot):
    plt.figure()
    sensor_coords = sampled_branch_coord[i].cpu().detach().numpy()
    sensor_values = sampled_branch_values[i, :, 0].cpu().detach().numpy()
    # Coordinate column 1 on the horizontal axis, column 0 on the vertical axis;
    # colour is the first branch value channel. These are branch sensors (the
    # label previously said 'output sensors').
    plt.scatter(sensor_coords[:, 1], sensor_coords[:, 0], c=sensor_values, label='branch sensors', cmap='jet', vmin=0, vmax=1)

In [None]:
# Optional: persist the curves for later comparison.
# np.savez('loss_data_20_11_sumo_past2_pred8_2_8000samples.npz', loss_list=loss_list, val_loss_mse_list=val_loss_mse_list, lambdas=lambdas, phys_losses_list=phys_losses_list)

# Per-epoch training MSE and validation MSE on a logarithmic y-axis,
# drawn on the current axes.
loss_axes = plt.gca()
loss_axes.semilogy(loss_list, label='Training Loss')
loss_axes.semilogy(val_loss_mse_list, label='Validation Loss')
loss_axes.legend()
loss_axes.set_xlabel('Epochs')
loss_axes.set_ylabel('Loss')
loss_axes.set_title('Training and Validation Losses')
plt.show()

# Best validation MSE reached during training, for reference.
best_val_mse = min(val_loss_mse_list)
print(f"Minimum Validation Loss: {best_val_mse}")
