In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import time
from sklearn.model_selection import train_test_split

from torch.cuda.amp import autocast, GradScaler
torch.cuda.empty_cache()

import importlib
import tools_torch
importlib.reload(tools_torch)
from tools_torch import *

# Use the first CUDA GPU when one is available; otherwise run on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(f"using {device}")

# load dataset

In [None]:
# Open the pre-generated dataset archive (path is relative to the notebook).
data = np.load("../datasets/godunov_combined_tpast2_tpred8_receding.npz")

# Array-valued entries are converted to torch tensors.
(
    branch_coords,
    branch_values,
    output_sensor_coords,
    output_sensor_values,
    rho,
    x,
    t,
    t_starts,
) = (
    torch.tensor(data[key])
    for key in (
        'branch_coords',
        'branch_values',
        'output_sensor_coords',
        'output_sensor_values',
        'rho',
        'x',
        't',
        't_starts',
    )
)

# Scalar entries are unwrapped with .item().
Nx, Nt, Xmax, Tmax, N, T_PAST, T_PRED = (
    data[key].item()
    for key in ('Nx', 'Nt', 'Xmax', 'Tmax', 'N', 't_past', 't_pred')
)

# Quick summary of what was loaded.
print(f"Nx = {Nx}, Nt = {Nt}, Xmax = {Xmax}, Tmax = {Tmax}, N = {N}")
print(
    f"branch_coords.shape = {branch_coords.shape}, "
    f"branch_values.shape = {branch_values.shape}, "
    f"output_sensor_coords.shape = {output_sensor_coords.shape}, "
)
print(
    f"output_sensor_values.shape = {output_sensor_values.shape}, "
    f"x.shape = {x.shape}, t.shape = {t.shape}"
)
print(f"t_starts = {t_starts.shape}, rho.shape = {rho.shape}")

In [3]:
# Seed NumPy's global random number generator.
np.random.seed(42)

# Fraction of the samples held out for validation.
validation_percentage = 0.2

# train_test_split returns a (train, validation) pair for every array passed in,
# in the same order as the inputs; random_state pins the split itself.
(
    branch_coords_train, branch_coords_val,
    branch_values_train, branch_values_val,
    output_sensor_coords_train, output_sensor_coords_val,
    output_sensor_values_train, output_sensor_values_val,
    rho_train, rho_val,
    t_starts_train, t_starts_val,
) = train_test_split(
    branch_coords,
    branch_values,
    output_sensor_coords,
    output_sensor_values,
    rho,
    t_starts,
    test_size=validation_percentage,
    random_state=42,
)

In [None]:
# Create the datasets and dataloaders for training and validation.
batch_size = 32

# Only the first channel of the branch values (slice 0:1, which keeps the last
# axis) is passed to the dataset.
# NOTE(review): `device=device` is forwarded to DeepONetDatasetTrain while the
# loaders use worker processes and pinned memory; if the dataset moves tensors
# to CUDA itself, that combination will conflict -- confirm in tools_torch.
train_dataset = DeepONetDatasetTrain(
    branch_coords_train,
    branch_values_train[:, :, 0:1],
    output_sensor_coords_train,
    output_sensor_values_train,
    device=device,
)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)

val_dataset = DeepONetDatasetTrain(
    branch_coords_val,
    branch_values_val[:, :, 0:1],
    output_sensor_coords_val,
    output_sensor_values_val,
    device=device,
)
# Validation batches are served in a fixed order: shuffling would change which
# samples fall into the (possibly smaller) last batch each epoch, which makes
# the per-epoch mean of batch losses noisy for no benefit.
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

# Sanity check: pull one batch from each loader and report shapes and device.
# After this cell, xs/us/ys/ss hold the validation batch.
xs, us, ys, ss = next(iter(train_loader))
print(f"train shapes\t xs: {xs.shape}, us: {us.shape}, ys: {ys.shape}, ss: {ss.shape}, device: {xs.device}")
xs, us, ys, ss = next(iter(val_loader))
print(f"val shapes\t xs: {xs.shape}, us: {us.shape}, ys: {ys.shape}, ss: {ss.shape}, device: {xs.device}")


In [None]:
# Index of the sample within the current batch (xs/us/ys/ss) to visualise.
idx = 3

# Output sensors: column 1 of the coordinates on the horizontal axis, column 0
# on the vertical axis, coloured by the target value.
plt.scatter(
    ys[idx, :, 1].cpu().numpy(),
    ys[idx, :, 0].cpu().numpy(),
    c=ss[idx, :].cpu().numpy(),
    label='output sensors',
    cmap='jet',
    vmin=0,
    vmax=1,
)

# Branch sensors go on a separate figure, coloured by the first value channel.
plt.figure()
plt.scatter(
    xs[idx, :, 1].cpu().numpy(),
    xs[idx, :, 0].cpu().numpy(),
    c=us[idx, :, 0].cpu().numpy(),
    label='branch sensors',
    cmap='jet',
    vmin=0,
    vmax=1,
)
# Start both axes of the second figure at zero.
plt.ylim(0), plt.xlim(0)

# model definition

In [None]:
# Build the model, loss, optimiser, LR scheduler and AMP gradient scaler.
from vidon_model import VIDON, FDLearner


def _n_params(module):
    """Return the total number of elements over all parameters of `module`."""
    return sum(weight.numel() for weight in module.parameters())


p = 400
model = VIDON(
    p=p,
    num_heads=4,
    d_branch_input=2,
    d_v=1,
    use_linear_decoder=False,
    UQ=True,
).to(device)
FD = FDLearner(d_FD=50)
FD.to(device)

# Mean-squared-error criterion.
criterion = nn.MSELoss()

num_epochs = 100

# Only the VIDON model's parameters are handed to the optimiser.
optimizer = optim.AdamW(model.parameters(), lr=0.001)
# Multiply the learning rate by 0.2 once the monitored metric has stopped
# improving for 50 scheduler steps, never going below 1e-7.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.2,
    patience=50,
    min_lr=1e-7,
    cooldown=5,
    verbose=True,
)
# Gradient scaler for mixed-precision training.
scaler = GradScaler()


# Report parameter counts per sub-module and in total.
print(f"Number of parameters in coord encoder: {_n_params(model.branch.coord_encoder)}")
print(f"Number of parameters in value encoder: {_n_params(model.branch.value_encoder)}")
print(f"Number of parameters in combiner: {_n_params(model.branch.combiner)}")
print(f"Number of parameters in nonlinear decoder: {_n_params(model.nonlinear_decoder)}")
print(f"Number of parameters in trunk: {_n_params(model.trunk)}")
print(f"Number of parameters in FD:  \t{_n_params(FD)}")
print(f"Number of parameters in model: {_n_params(model)}")

# training loop

In [None]:
def _gaussian_nll(mean, sigma, target):
    """Mean Gaussian negative log-likelihood of `target` under N(mean, sigma**2).

    `sigma` is the predicted standard deviation and must be positive. The same
    expression is used for training and validation so the two numbers are
    directly comparable.
    """
    squared_error_term = (mean - target) ** 2 / (2 * sigma ** 2)
    log_variance_term = 0.5 * torch.log(2 * torch.pi * sigma ** 2)
    return torch.mean(squared_error_term + log_variance_term)


# Per-epoch histories. gradient_norms, loss_ic_list and lrs are declared here
# but not filled by this cell.
gradient_norms = []
loss_list = []          # mean training MSE per epoch
loss_ic_list = []
val_loss_list = []      # mean validation NLL per epoch
val_loss_mse_list = []  # mean validation MSE per epoch
lrs = []
val_loss = 0
lowest_val_loss = float("inf")


# Training loop
# NOTE(review): the batch loop variables below rebind the notebook-level name
# `branch_values` (the full dataset tensor created when the data was loaded);
# re-running the train/validation split after this cell needs the data cell to
# be re-run first.
pbar = tqdm(range(num_epochs))
for epoch in pbar:
    model.train()
    losses = []
    losses_ic = []

    for branch_coord, branch_values, trunk_coords, targets in train_loader:

        # bring to gpu
        branch_coord = branch_coord.to(device)
        branch_values = branch_values.to(device)
        trunk_coords = trunk_coords.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()

        # Train on a random subset of the trunk points and of the branch sensors.
        sampled_trunk_coords, sampled_targets = sample_trunk_inputs(trunk_coords, targets)
        sampled_branch_coord, sampled_branch_values = sample_branch_inputs_keep_boundary(branch_coord, branch_values)

        # Forward pass with automatic mixed precision.
        with autocast():
            outputs_rho, outputs_rho_sigm = model(sampled_branch_coord[:,:,:2], sampled_branch_values, sampled_trunk_coords)
            # The second output is exponentiated to obtain a strictly positive sigma.
            outputs_rho_sigm = torch.exp(outputs_rho_sigm) + 1e-8

            # Objective: MSE plus a down-weighted Gaussian NLL term.
            loss_UQ = _gaussian_nll(outputs_rho, outputs_rho_sigm, sampled_targets)
            mse_loss = criterion(outputs_rho, sampled_targets)
            loss = loss_UQ/300 + mse_loss

            losses.append(mse_loss.item())

        # Backward pass with scaled gradients
        scaler.scale(loss).backward()

        # Gradients must be unscaled before clipping; otherwise max_norm is
        # compared against gradients multiplied by the current loss scale.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=3.0)

        # Update weights (the step is skipped by the scaler if grads contain inf/NaN).
        scaler.step(optimizer)
        scaler.update()

    loss_list.append(np.mean(losses))

    # Validation pass: full branch/trunk inputs, no sub-sampling, no gradients.
    model.eval()
    with torch.no_grad():
        val_losses = []
        val_losses_mse = []
        T_START = 0

        for branch_coord, branch_values, trunk_coords, targets in val_loader:
            branch_coord = branch_coord.to(device)
            branch_values = branch_values.to(device)
            trunk_coords = trunk_coords.to(device)
            targets = targets.to(device)

            with autocast():
                outputs_rho, outputs_rho_sigm = model(branch_coord[:,:,:2], branch_values, trunk_coords)
                outputs_rho_sigm = torch.exp(outputs_rho_sigm) + 1e-8

                # Same NLL definition as in training (sigma is a standard deviation).
                val_losses.append(_gaussian_nll(outputs_rho, outputs_rho_sigm, targets).item())
                val_losses_mse.append(criterion(outputs_rho, targets).item())

    # Epoch-level validation metrics (means over batches).
    val_loss = np.mean(val_losses)
    epoch_val_mse = np.mean(val_losses_mse)
    val_loss_list.append(val_loss)
    val_loss_mse_list.append(epoch_val_mse)

    # 'Loss' is the combined training loss of the last batch of the epoch.
    pbar.set_postfix({'Loss': loss.item(), 'Val Loss': val_loss, 'Val Loss MSE': epoch_val_mse})

    # Checkpoint whenever the validation MSE improves.
    if epoch_val_mse < lowest_val_loss:
        lowest_val_loss = epoch_val_mse
        print(f"Saving model with lowest validation loss: {lowest_val_loss}")
        torch.save(model.state_dict(), 'trained_model.pth')

    # ReduceLROnPlateau needs the monitored metric once per epoch.
    scheduler.step(epoch_val_mse)

In [None]:
# Learning curves on a log-scaled y axis: both lists hold one mean MSE value
# per epoch (training batches vs. validation batches).
plt.semilogy(loss_list, label='training loss')
plt.semilogy(val_loss_mse_list, label='validation loss')
plt.xlabel('epoch')
plt.ylabel('MSE')
plt.legend();