In [None]:
import os
import random
import sys

import matplotlib.pyplot as plt
import numpy as np
import sourcedefender
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from IPython.display import display, clear_output
from torch.utils.data import TensorDataset, DataLoader, random_split, dataset
from tqdm import tqdm
   
# Extend the import search path with the three nearest ancestor directories
# and the current directory (in that order) so the project-local imports that
# follow (lib, models, equation) can be resolved.
# NOTE(review): the repository layout is not visible from this notebook.
for _search_root in ("../../../", "../../", "../", "./"):
    sys.path.append(_search_root)

from lib.util import MHPI, count_parameters
from lib.utiltools import loss_live_plot, GaussianRandomFieldGenerator, generate_batch_parameters, AutomaticWeightedLoss
from lib.DerivativeComputer import batchJacobian_AD
from models.FNO_1d_simple import FNO1d
from equation.ode2 import *

# Project helper imported from lib.util; its behaviour is not visible here.
MHPI()

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Device: {device}\n')

# Set seeds for every RNG this notebook imports so a Restart & Run All
# reproduces the same samples. NumPy was previously left unseeded.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [2]:
# Number of random parameter vectors drawn when the dataset is generated.
Sample_number = 1_000
# NOTE(review): the two values below are not referenced by any visible cell;
# confirm whether they are still needed.
training_sample = 10
dataset_segment_size = 1_000

In [3]:
# True  -> regenerate the dataset and the train/eval/test loaders and save them;
# False -> load the previously saved files instead.
Create_new_dataset = True
# NOTE(review): not read by any visible cell; presumably toggles live loss plotting.
plot_live_loss = True

In [4]:
# First two positional arguments handed to FNO1d when the model is built.
modes, width = 8, 20

In [5]:
# NOTE(review): neither constant is referenced by a visible cell; confirm
# whether they are consumed by the star-imported equation module or are stale.
alpha_, tau = 1.5, 0.5

In [6]:
# Time axis: `steps` evenly spaced points from t0 to t_end (see torch.linspace below).
t0, t_end = 0.0, 2
steps = 100
# The first T_in time steps are the model input window; the remaining
# T_out steps are what the model predicts.
T_in = 10
T_out = steps - T_in

Optimizer and training configuration

In [7]:
# Number of passes over the training loader and the mini-batch size.
epochs = 500
batch_size = 16
# NOTE(review): learning_rate is reassigned to 0.001 in a later cell before the
# optimizer is built, and the visible StepLR call hard-codes its own step size
# and gamma, so the three values below are not what training actually uses.
learning_rate = 0.01
scheduler_step, scheduler_gamma = 100, 0.9

%%

In [8]:
def _sample_uniform(bounds, count):
    """Return `count` random values scaled into the (low, high) interval `bounds`."""
    low, high = bounds
    return low + (high - low) * torch.rand(count)


# Time axis shared by every sample. Shape: [Steps]
t_tensor = torch.linspace(t0, t_end, steps)

# Sampling interval (low, high) of each ODE parameter
alpha_range = (0.02, 0.06)
beta_range = (0.01, 0.03)
gamma_range = (20, 60)
delta_range = (0.5, 1.5)
omega_range = (0.2, 0.6)
epsilon_range = (0.0, 0.2)
zeta_range = (0.0, 0.2)

# One draw of Sample_number values per parameter. The order of the calls is
# kept fixed because it determines which part of the random stream each
# parameter receives.
alpha = _sample_uniform(alpha_range, Sample_number)
beta = _sample_uniform(beta_range, Sample_number)
gamma = _sample_uniform(gamma_range, Sample_number)
delta = _sample_uniform(delta_range, Sample_number)
omega = _sample_uniform(omega_range, Sample_number)
epsilon = _sample_uniform(epsilon_range, Sample_number)
zeta = _sample_uniform(zeta_range, Sample_number)

# Column-stack into a [Sample_number, 7] tensor and track gradients on it.
parameters = torch.stack(
    [alpha, beta, gamma, delta, omega, epsilon, zeta],
    dim=1,
)
parameters.requires_grad_(True)




%%

In [9]:
PATH = ''
DATASET_DIR = PATH + 'datasets/'

# Create dataset < This part can be changed based on different cases>
if Create_new_dataset:
    # torch.save does not create missing directories, so make sure the target
    # folder exists before anything is written.
    os.makedirs(DATASET_DIR, exist_ok=True)

    dataset = creat_dataset(t_tensor, parameters, T_in)
    torch.save(dataset, DATASET_DIR + 'main_dataset_ODE2.pt')

    # Calculate sizes for the 75 % / 15 % / remainder train/eval/test split
    train_size = int(0.75 * Sample_number)
    eval_size = int(0.15 * Sample_number)
    test_size = Sample_number - train_size - eval_size

    # Split the dataset and wrap each part in a DataLoader
    train_dataset, eval_dataset, test_dataset = random_split(
        dataset, [train_size, eval_size, test_size]
    )
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    eval_loader = DataLoader(eval_dataset, batch_size=batch_size)
    test_loader = DataLoader(test_dataset)

    # The loaders themselves are pickled so the exact split can be reloaded.
    torch.save(train_loader, DATASET_DIR + 'train_dataset_ODE2.pt')
    torch.save(eval_loader, DATASET_DIR + 'eval_dataset_ODE2.pt')
    torch.save(test_loader, DATASET_DIR + 'test_dataset_ODE2.pt')

else:
    # These files hold pickled Dataset / DataLoader objects rather than plain
    # tensors, so they cannot be read with weights_only=True (the default in
    # recent PyTorch releases).
    # SECURITY: weights_only=False unpickles arbitrary objects; only load
    # files produced by this notebook.
    dataset = torch.load(DATASET_DIR + 'main_dataset_ODE2.pt', weights_only=False)
    train_loader = torch.load(DATASET_DIR + 'train_dataset_ODE2.pt', weights_only=False)
    eval_loader = torch.load(DATASET_DIR + 'eval_dataset_ODE2.pt', weights_only=False)
    test_loader = torch.load(DATASET_DIR + 'test_dataset_ODE2.pt', weights_only=False)

In [None]:
# Instantiate the 1-D FNO with the configured modes/width and the input/output
# window lengths, then move it to the selected device.
# parameters_size=7 matches the seven columns stacked into `parameters`.
model = FNO1d(modes, width, T_in, T_out, state_size=1, parameters_size=7)
model = model.to(device)
print(model)


In [None]:
# Project helper from lib.util; presumably reports the model's parameter
# count (its implementation is not visible here).
count_parameters(model)

In [12]:
# Overrides the learning rate assigned in the earlier training-configuration cell.
learning_rate = 0.001

# Adam over the network weights; a later cell rebuilds the optimizer so that
# the loss-weighting parameters are optimised too.
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# NOTE(review): gamma > 1 makes every scheduler step *raise* the learning rate
# by 2.5 %, and no scheduler.step() call is visible in the training loop, so
# this schedule currently has no effect. Confirm the intent.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=250, gamma=1.025)

# Per-epoch history containers
train_fnolosses = []
train_odelosses = []
train_iglosses = []
val_losses = []
coieffs_list = []

# Mean-squared error is used for every loss term in the training loop.
criterion_1 = nn.MSELoss()




In [13]:
# Add configuration flags
enable_ig_loss = True   # include one loss term per parameter on du/dparameter
enable_eq_loss = False  # include the ODE-residual loss term

# Human-readable label for the active combination (shown in the progress bar).
_MODE_LABELS = {
    (False, False): "Data only",
    (True, False): "Data + IG",
    (False, True): "Data + Eq",
    (True, True): "Data + IG + Eq",
}
mode = _MODE_LABELS[(bool(enable_ig_loss), bool(enable_eq_loss))]

# Number of terms handed to the automatic loss weighting: the data loss, plus
# one IG term per model parameter, plus one equation term, as enabled.
ss = 1
if enable_ig_loss:
    ss += model.parameters_size
if enable_eq_loss:
    ss += 1

awl = AutomaticWeightedLoss(ss)
optimizer = optim.Adam([
                {'params': model.parameters(), 'lr': learning_rate},
                {'params': awl.parameters(), 'weight_decay': 0}
            ])

# The optimizer above replaces the one the scheduler was created for, which
# would leave `scheduler` driving an optimizer that is no longer used.
# Re-create it on the new optimizer, reusing the settings it already had.
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer, step_size=scheduler.step_size, gamma=scheduler.gamma
)

In [None]:
# ---------------------------------------------------------------------------
# Training loop
#
# Every epoch: one pass over train_loader optimising the automatically
# weighted sum of the enabled loss terms (data, optional IG, optional
# equation residual), followed by a data-only validation pass over
# eval_loader. All three training terms are always evaluated so that they can
# be monitored, even when a term is excluded from the optimised loss.
#
# NOTE(review): `scheduler.step()` is never called, so the learning rate stays
# constant; and the checkpoint does not include the state of `awl`. Confirm
# whether either is intended.
# ---------------------------------------------------------------------------
outer_loop = tqdm(range(epochs), desc="Progress", position=0)
torch.cuda.empty_cache()

train_fnolosses = []
train_iglosses = []
train_eqlosses = []
val_losses = []
coieffs_list = []

# The dict holds references to the history lists, so it always reflects the
# latest epoch without having to be rebuilt.
losses_dict = {
    'Training FNO Loss': train_fnolosses,
    'Training IG Loss': train_iglosses,
    'Training EQ Loss': train_eqlosses,
    'Validation Loss': val_losses
}

# torch.save does not create missing directories.
checkpoint_dir = PATH + 'saved_model/'
os.makedirs(checkpoint_dir, exist_ok=True)

# Output-window time stamps are identical for every batch: build them once.
t_out_grid = torch.linspace(t0, t_end, steps)[T_in:].to(device)

for ep in outer_loop:
    model.train()
    train_fnoloss_accumulated = 0.0
    train_igloss_accumulated = 0.0
    train_eqloss_accumulated = 0.0

    for batch_data in train_loader:
        batch_data = [item.to(device) for item in batch_data]
        batch_parameters, batch_u_in, batch_u_out, du_dparam_true = batch_data
        # Gradients w.r.t. the parameters are needed for the Jacobian below.
        batch_parameters.requires_grad_(True)

        batch_size_ = batch_parameters.shape[0]
        t_tensor_ = t_out_grid.unsqueeze(0).repeat(batch_size_, 1)
        t_tensor_.requires_grad_(True)
        optimizer.zero_grad()
        U_pred = model(batch_u_in, t_tensor_, batch_parameters).squeeze(-1)

        # --- individual loss terms --------------------------------------
        data_loss = criterion_1(U_pred, batch_u_out.squeeze(-1))

        # Jacobian of the prediction w.r.t. the parameters; the autograd
        # graph is only kept when the IG terms take part in the optimisation.
        du_dp = batchJacobian_AD(U_pred, batch_parameters,
                                 graphed=enable_ig_loss,
                                 batchx=True)
        du_dp_target = du_dparam_true.squeeze(-1)
        ig_loss_list = [
            criterion_1(du_dp[:, :, i], du_dp_target[:, T_in:, i])
            for i in range(model.parameters_size)
        ]

        residual = ode_residual(U_pred, batch_parameters, t_tensor_)
        eq_loss = criterion_1(residual, torch.zeros_like(residual))

        # --- optimised loss: data [+ IG terms] [+ equation term] --------
        loss_terms = [data_loss]
        if enable_ig_loss:
            loss_terms.extend(ig_loss_list)
        if enable_eq_loss:
            loss_terms.append(eq_loss)
        loss = awl(*loss_terms)
        # Snapshot of the weighting parameters used for this batch.
        coieffs = awl.params.data.clone().detach()

        # --- logged values ----------------------------------------------
        # Converted to plain floats so the running totals and history lists
        # never hold tensors that keep the autograd graph alive.
        # Enabled terms are scaled by their weighting parameter; disabled
        # terms are reported unweighted.
        fnoloss = float(coieffs[0].item() * data_loss.detach())
        if enable_ig_loss:
            ig_loss = float(sum(coieffs[i + 1].item() * term.detach()
                                for i, term in enumerate(ig_loss_list)))
        else:
            ig_loss = float(sum(1.0 * term.detach() for term in ig_loss_list))
        if enable_eq_loss:
            eq_loss_value = float(coieffs[-1].item() * eq_loss.detach())
        else:
            eq_loss_value = eq_loss.item()

        loss.backward()
        optimizer.step()

        # Accumulate sample-weighted losses
        train_fnoloss_accumulated += fnoloss * batch_size_
        train_igloss_accumulated += ig_loss * batch_size_
        train_eqloss_accumulated += eq_loss_value * batch_size_

    # Weighting parameters as of the last batch of the epoch.
    coieffs_list.append(coieffs)
    epoch_fnoloss = train_fnoloss_accumulated / len(train_loader.dataset)
    epoch_igloss = train_igloss_accumulated / len(train_loader.dataset)
    epoch_eqloss = train_eqloss_accumulated / len(train_loader.dataset)

    train_fnolosses.append(epoch_fnoloss)
    train_iglosses.append(epoch_igloss)
    train_eqlosses.append(epoch_eqloss)

    # Evaluation phase (data loss only)
    model.eval()
    val_loss_accumulated = 0.0
    with torch.no_grad():
        for batch_data in eval_loader:
            batch_data = [item.to(device) for item in batch_data]
            batch_parameters, batch_u_in, batch_u_out = batch_data[:3]
            batch_size_ = batch_parameters.shape[0]
            t_tensor_ = t_out_grid.unsqueeze(0).repeat(batch_size_, 1)
            U_pred = model(batch_u_in, t_tensor_, batch_parameters).squeeze(-1)
            val_fnoloss = criterion_1(U_pred, batch_u_out.squeeze(-1))
            val_loss_accumulated += val_fnoloss.item() * batch_size_
        epoch_val_loss = val_loss_accumulated / len(eval_loader.dataset)
        val_losses.append(epoch_val_loss)

    # Checkpoint every 10th epoch (the file is overwritten each time).
    if ep % 10 == 0:
        torch.save({
            'epoch': ep,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss.detach(),
        }, checkpoint_dir + 'ODE_2_saved_Model_Data_IG_EQ_saved.pth')

    outer_loop.set_description(f"Progress (Epoch {ep + 1}/{epochs}, Mode: {mode})")
    outer_loop.set_postfix(
        fnoloss=f'{epoch_fnoloss:.2e}',
        ig_loss=f'{epoch_igloss:.2e}',
        eq_loss=f'{epoch_eqloss:.2e}',
        eval_loss=f'{epoch_val_loss:.2e}'
    )

print("Training complete")