In [1]:
import numpy as np
import math
import matplotlib.pyplot as plt
import pandas as pd
import time
import timeit
import math
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from collections import OrderedDict
from tqdm import tqdm
import torch
import torch.nn as nn
import itertools
from scipy.integrate import solve_ivp

In [2]:
def GetTBModelDerivatives(y, lamda, mu):
    N = (lamda/mu) + 1
    params = [
    lamda/N,         # lambda (recruitment rate)
    0.025,     # beta (transmission rate)
    1,         # delta (differential infectivity)
    0.3,       # p (fraction that goes directly to infectious)
    mu,    # mu (natural death rate)
    0.005,     # k (progression rate from exposed to infectious)
    0,         # r1 (early treatment effectiveness, not used here)
    0.8182,    # r2 (treatment rate of infectious)
    0.02,      # phi (rate from I to L)
    0.01,      # gamma (reactivation from L to I)
    0.0227,    # d1 (death rate from I)
    0.20       # d2 (death rate from L)
    ]
    S, E, I, L = y
    λ, β, δ, p, μ, k, r1, r2, φ, γ, d1, d2 = params

    dSdt = λ - β * S * (I + δ * L) * N - μ * S
    dEdt = β * (1 - p) * S * (I + δ * L) * N + r2 * I - (μ + k * (1 - r1)) * E
    dIdt = β * p * S * (I + δ * L) * N + k * (1 - r1) * E + γ * L - (μ + d1 + φ * (1 - r2) + r2) * I
    dLdt = φ * (1 - r2) * I - (μ + d2 + γ) * L

    if(torch.is_tensor(dSdt)):
      return dSdt, dEdt, dIdt, dLdt
    else:
      return np.array([dSdt, dEdt, dIdt, dLdt])

def GetTBModelDerivativesForSolveIVP(t, y, lamda, mu):
    N = (lamda/mu) + 1
    params = [
    lamda/N,         # lambda (recruitment rate)
    0.025,     # beta (transmission rate)
    1,         # delta (differential infectivity)
    0.3,       # p (fraction that goes directly to infectious)
    mu,    # mu (natural death rate)
    0.005,     # k (progression rate from exposed to infectious)
    0,         # r1 (early treatment effectiveness, not used here)
    0.8182,    # r2 (treatment rate of infectious)
    0.02,      # phi (rate from I to L)
    0.01,      # gamma (reactivation from L to I)
    0.0227,    # d1 (death rate from I)
    0.20       # d2 (death rate from L)
    ]
    S, E, I, L = y
    λ, β, δ, p, μ, k, r1, r2, φ, γ, d1, d2 = params

    dSdt = λ - β * S * (I + δ * L) * N - μ * S
    dEdt = β * (1 - p) * S * (I + δ * L) * N + r2 * I - (μ + k * (1 - r1)) * E
    dIdt = β * p * S * (I + δ * L) * N + k * (1 - r1) * E + γ * L - (μ + d1 + φ * (1 - r2) + r2) * I
    dLdt = φ * (1 - r2) * I - (μ + d2 + γ) * L

    if(torch.is_tensor(dSdt)):
      return dSdt, dEdt, dIdt, dLdt
    else:
      return np.array([dSdt, dEdt, dIdt, dLdt])

def runge_kutta_4(f, y0, t, lamda, mu):
    n = len(t)
    y = np.zeros((n, len(y0)))
    y[0] = y0
    for i in range(1, n):
        h = t[i] - t[i - 1]
        k1 = f(y[i - 1], lamda, mu)
        k2 = f(y[i - 1] + h/2 * k1, lamda, mu)
        k3 = f(y[i - 1] + h/2 * k2, lamda, mu)
        k4 = f(y[i - 1] + h * k3, lamda, mu)
        y[i] = y[i - 1] + (h/6) * (k1 + 2*k2 + 2*k3 + k4)
    return y

In [3]:
def generate_rk_solution(time_true, lamda, mu):
  time_true = np.array(time_true).flatten()
  S = lamda / mu
  N = S + 1
  y0 = [S/N, 1/N, 0, 0]

  solution = runge_kutta_4(GetTBModelDerivatives, y0, time_true, lamda, mu)
  S_true, E_true, I_true, L_true = solution.T
  data = {
    'Time':time_true,
    'S': S_true,
    'E': E_true,
    'I': I_true,
    'L': L_true
}

  # Create DataFrame
  df = pd.DataFrame(data)
  return df


def generate_reference_solution(time_true, lamda, mu):
  time_true = np.array(time_true).flatten()
  S = lamda / mu
  N = S + 1
  y0 = [S/N, 1/N, 0, 0]
  solution = solve_ivp(GetTBModelDerivativesForSolveIVP, time_true, y0, method="BDF", args = (lamda, mu))
  data = {
    'Time':solution.t,
    'S': solution.y[0],
    'E': solution.y[1],
    'I': solution.y[2],
    'L': solution.y[3]
  }

  # Create DataFrame
  df = pd.DataFrame(data)
  return df

def CalculateMSEDifferentTimes(S, E, I, L, t_sim):
    df = generate_reference_solution()
    t_ref = df['Time'].values.astype(np.float32)
    S_true = df['S'].values
    E_true = df['E'].values
    I_true = df['I'].values
    L_true = df['L'].values

    # For each reference time, find nearest simulation time
    indices = np.searchsorted(t_sim, t_ref, side='left')
    indices = np.clip(indices, 0, len(t_sim)-1)

    # Handle edge cases - check if left or right neighbor is closer
    for i in range(len(indices)):
        if indices[i] > 0:
            left_dist = abs(t_ref[i] - t_sim[indices[i]-1])
            right_dist = abs(t_ref[i] - t_sim[indices[i]])
            if left_dist < right_dist:
                indices[i] -= 1

    S_MSE = np.mean((S_true - S[indices])**2)
    E_MSE = np.mean((E_true - E[indices])**2)
    I_MSE = np.mean((I_true - I[indices])**2)
    L_MSE = np.mean((L_true - L[indices])**2)

    return S_MSE, E_MSE, I_MSE, L_MSE

def calculate_mse_over_time(t_true, lamda, mu, S_pred, E_pred, I_pred, L_pred, t_pred):
    """
    Calculate MSE at each time point between predicted and reference values
    """
    df = generate_reference_solution(t_true, lamda, mu)
    time_true = df['Time'].values.astype(np.float32)
    S_true = df['S'].values
    E_true = df['E'].values
    I_true = df['I'].values
    L_true = df['L'].values

    # Initialize MSE arrays
    mse_S = np.zeros(len(time_true))
    mse_E = np.zeros(len(time_true))
    mse_I = np.zeros(len(time_true))
    mse_L = np.zeros(len(time_true))

    # For each reference time point, find closest predicted time point
    for i, t_ref in enumerate(time_true):
        # Find nearest time index in predicted data
        idx = np.argmin(np.abs(t_pred - t_ref))

        # Calculate squared error at this time point
        mse_S[i] = (S_true[i] - S_pred[idx])**2
        mse_E[i] = (E_true[i] - E_pred[idx])**2
        mse_I[i] = (I_true[i] - I_pred[idx])**2
        mse_L[i] = (L_true[i] - L_pred[idx])**2
    return time_true, mse_S, mse_E, mse_I, mse_L

In [5]:
 def PlotResults(time_true, lamda, mu, S_pred, E_pred, I_pred, L_pred, t_pred, title="Solution of the ODE System"):
  df = generate_reference_solution(time_true,lamda, mu)
  time_true = df['Time'].values.astype(np.float32)
  S_true = df['S'].values
  E_true = df['E'].values
  I_true = df['I'].values
  L_true = df['L'].values
  plt.figure(figsize=(10, 6))
  plt.subplot(2, 2, 1)
  plt.plot(t_pred, S_pred, label='S Predicted', color='b', linewidth=2)
  plt.plot(time_true, S_true, label='S Reference', linestyle='--', color='black')
  plt.title('S(t) - Susceptible')
  plt.xlabel('Time (years)')
  plt.ylabel('S')
  plt.grid(True)
  plt.legend()

  plt.subplot(2, 2, 2)
  plt.plot(t_pred, E_pred, label='E Predicted', color='orange', linewidth=2)
  plt.plot(time_true, E_true, label='E Reference', linestyle='--', color='black')
  plt.title('E(t) - Exposed')
  plt.xlabel('Time (years)')
  plt.ylabel('E')
  plt.grid(True)
  plt.legend()

  plt.subplot(2, 2, 3)
  plt.plot(t_pred, I_pred, label='I Predicted', color='r', linewidth=2)
  plt.plot(time_true, I_true, label='I Reference', linestyle='--', color='black')
  plt.title('I(t) - Infectious')
  plt.xlabel('Time (years)')
  plt.ylabel('I')
  plt.grid(True)
  plt.legend()

  plt.subplot(2, 2, 4)
  plt.plot(t_pred, L_pred, label='L Predicted', color='green', linewidth=2)
  plt.plot(time_true, L_true, label='L Reference', linestyle='--', color='black')
  plt.title('L(t) - Latent / Out of Sight')
  plt.xlabel('Time (years)')
  plt.ylabel('L')
  plt.grid(True)
  plt.legend()

  plt.tight_layout()
  plt.suptitle(title, fontsize=16, y=1.02)
  plt.show()


def plot_mse_over_time(y_0, t_true, N, S_pred, E_pred, I_pred, L_pred, t_pred, title="Mean Square Errors over Time"):
    """
    Create a plot showing MSE over time for each compartment
    """
    # Calculate MSE over time
    time_points, mse_S, mse_E, mse_I, mse_L = calculate_mse_over_time(y_0, t_true, N, S_pred, E_pred, I_pred, L_pred, t_pred)

    # Create the plot
    plt.figure(figsize=(10, 6))

    # Plot MSE for each compartment
    plt.plot(time_points, mse_S, 'o-', label='S', color='orange', linewidth=2, markersize=4)
    plt.plot(time_points, mse_E, 's-', label='E', color='green', linewidth=2, markersize=4)
    plt.plot(time_points, mse_I, '^-', label='I', color='blue', linewidth=2, markersize=4)
    plt.plot(time_points, mse_L, 'd-', label='L', color='red', linewidth=2, markersize=4)

    # Customize the plot
    plt.xlabel('Time (years)', fontsize=12)
    plt.ylabel('Mean Square Error', fontsize=12)
    plt.title(title, fontsize=14)
    plt.legend(fontsize=10)
    plt.grid(True, alpha=0.3)

    # Set y-axis to log scale if needed (uncomment if MSE values vary greatly)
    # plt.yscale('log')

    # Format the plot
    plt.tight_layout()
    plt.show()

In [6]:
class OptimizedNN(nn.Module):
    def __init__(
        self,
        input_size,
        output_size,
        hidden_size=128,  # Reduced from 256
        num_layers=4,     # Reduced from 6
        act=torch.nn.Tanh,
        dropout_rate=0.05  # Reduced dropout
    ):
        super(OptimizedNN, self).__init__()

        self.input_size = input_size
        self.output_size = output_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # Simplified architecture for speed
        layers = []
        layers.append(nn.Linear(input_size, hidden_size))
        layers.append(act())

        for i in range(num_layers):
            layers.append(nn.Linear(hidden_size, hidden_size))
            layers.append(act())
            if i % 2 == 0:  # Add dropout every other layer
                layers.append(nn.Dropout(dropout_rate))

        layers.append(nn.Linear(hidden_size, output_size))

        self.network = nn.Sequential(*layers)

        # Optimized weight initialization
        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain("tanh"))
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        return self.network(x)

In [7]:
def check_overflow_basic(values):
    """Basic overflow detection"""
    if isinstance(values, (list, tuple)):
        values = np.array(values)
    
    has_inf = np.isinf(values).any()
    has_nan = np.isnan(values).any()
    
    if has_inf:
        print("⚠️  Infinity detected!")
        return True
    if has_nan:
        print("⚠️  NaN detected!")
        return True
    return False

In [8]:
import torch
from torch.utils.data import Dataset, DataLoader
import itertools
import numpy as np

class PINNDataset(Dataset):
    """Custom dataset for PINN training with physics and data points"""
    
    def __init__(self, initial_conditions, true_values, t_physics, ts):
        self.initial_conditions = initial_conditions
        self.true_values = true_values
        self.t_physics = t_physics
        self.ts = ts
        
    def __len__(self):
        return len(self.initial_conditions)
    
    def __getitem__(self, idx):
        return {
            'ic': self.initial_conditions[idx],
            'true_values': self.true_values[idx],
            'idx': idx
        }

class Net:
    def __init__(self, totalICEpochs, ts):
        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

        #The model arch, modify this or the cell above to modify the arch.
        self.model = OptimizedNN(
            input_size=7,
            output_size=4,
            hidden_size=128,
            num_layers=4,
            act=torch.nn.Tanh,
            dropout_rate=0
        ).to(self.device)

        #The initial conditions the model will be trained on
        self.InitialLamdaValues = np.linspace(1, 11, 10)
        self.InitialMuValues = np.linspace(0.0101, 0.0227, 20)

        self.TestLamdaValues = np.linspace(1, 5, 4)
        self.TestMuValues = np.linspace(0.0101, 0.0227, 5)

        self.ts = torch.tensor(ts, dtype=torch.float32).view(-1, 1).to(self.device)

        self.totalICEpochs = totalICEpochs

        self.lambda_physics = 1.0
        self.lambda_initial = 50.0
        self.lambda_data = 10.0
        self.lambda_compartment_sum = 10.0

        self.criterion = torch.nn.MSELoss()

        # DataLoader parameters
        self.batch_size = 200  # Adjust based on your GPU memory
        self.num_workers = 0  # Set to 0 for GPU tensors

        self.LBFGS = torch.optim.LBFGS(
            self.model.parameters(),
            lr=1.0,
            max_iter=10000,
            max_eval=10000,
            history_size=50,
            tolerance_grad=1e-7,
            tolerance_change=1.0 * np.finfo(float).eps,
            line_search_fn="strong_wolfe",   # better numerical stability
        )

        self.huber_loss = torch.nn.SmoothL1Loss()

        self.adam = torch.optim.AdamW(
                    self.model.parameters(), 
                    lr=0.0001,           # Start lower than Adam
                    weight_decay=1e-4,  # Typical range: 1e-5 to 1e-3
                    betas=(0.9, 0.999)
                )
        self.scheduler = torch.optim.lr_scheduler.ExponentialLR(
            self.adam, 
            gamma=0.99  # Decay factor (lr *= gamma each step)
        )

        self.h = 0.1
        t = torch.arange(0, 1 + self.h, self.h)

        # Reduce physics time points to save memory
        self.t_physics = torch.cat([
            torch.linspace(0, 30, 101),      # Reduced from 201
        ], dim=0).view(-1, 1).to(self.device).requires_grad_(True)

        self.t_initial = torch.tensor([0.0], requires_grad=True).view(-1, 1).to(self.device)

        self.lossData = 1000
        self.loss_history = []
        
        # Will be set during Train()
        self.dataloader = None

    # A slightly modified Forward method
    def Forward(self, fullInput, t_grad, is_batched=False):
        outputs = self.model(fullInput)
        S = outputs[:, 0:1]
        E = outputs[:, 1:2]
        I = outputs[:, 2:3]
        L = outputs[:, 3:4]
        
        # The autograd call remains the same, as it operates on the entire tensor
        Sprime = torch.autograd.grad(S, t_grad, grad_outputs=torch.ones_like(S), create_graph=True)[0]
        Eprime = torch.autograd.grad(E, t_grad, grad_outputs=torch.ones_like(E), create_graph=True)[0]
        Iprime = torch.autograd.grad(I, t_grad, grad_outputs=torch.ones_like(I), create_graph=True)[0]
        Lprime = torch.autograd.grad(L, t_grad, grad_outputs=torch.ones_like(L), create_graph=True)[0]
        
        return S, E, I, L, Sprime, Eprime, Iprime, Lprime

    def compute_batch_loss(self, batch_ics, batch_true_values):
        """Compute loss for a single batch (balanced with relative residuals + Pearson correlation)"""
        batch_size = batch_ics.shape[0]
    
        # --- Physics Loss ---
        t_physics_expanded = self.t_physics.repeat(batch_size, 1)
        ics_expanded_physics = batch_ics.repeat_interleave(self.t_physics.shape[0], dim=0)
        physics_input = torch.cat([t_physics_expanded, ics_expanded_physics], dim=1)
    
        # Forward pass for physics
        S, E, I, L, S_prime, E_prime, I_prime, L_prime = self.Forward(
            physics_input, t_physics_expanded, is_batched=True
        )
    
        # True derivatives from ODE
        S_prime_true, E_prime_true, I_prime_true, L_prime_true = GetTBModelDerivatives(
            [S, E, I, L],
            physics_input[:, 1],  # lambda
            physics_input[:, 2]   # mu
        )
    
        # Relative residual loss (scale-invariant)
        def scaled_mse(pred, true, eps=1e-8):
            scale = torch.mean(torch.abs(true)) + eps
            return torch.mean(((pred - true) / scale) ** 2)
        
        loss_ode_S = self.criterion(S_prime, S_prime_true)
        loss_ode_E = self.criterion(E_prime, E_prime_true)
        loss_ode_I = self.criterion(I_prime, I_prime_true)
        loss_ode_L = self.criterion(L_prime, L_prime_true)
        
        physics_loss = loss_ode_S + loss_ode_E + loss_ode_I + loss_ode_L
    
        # Compartment sum constraint
        total_sum_of_compartments = S + E + I + L
        true_sum = torch.ones_like(S)
        compartment_sum_loss = self.criterion(total_sum_of_compartments, true_sum)
    
        # --- Data Loss ---
        ts_expanded = self.ts.repeat(batch_size, 1)
        ics_expanded_data = batch_ics.repeat_interleave(self.ts.shape[0], dim=0)
        data_input = torch.cat([ts_expanded, ics_expanded_data], dim=1)
    
        predicted_data = self.model(data_input)
        batch_true_values_reshaped = torch.cat([tv for tv in batch_true_values], dim=0)
    
        # Normalized MSE (scale-invariant)
        def normalized_mse(pred, true, eps=1e-8):
            denom = torch.mean(torch.abs(true), dim=0, keepdim=True) + eps
            return torch.mean(((pred - true) / denom) ** 2)
    
        # Pearson correlation loss
        def pearson_correlation_loss(pred, true):
            pred_flat = pred.reshape(-1)
            true_flat = true.reshape(-1)
            pred_centered = pred_flat - pred_flat.mean()
            true_centered = true_flat - true_flat.mean()
            numerator = torch.sum(pred_centered * true_centered)
            denom = torch.sqrt(torch.sum(pred_centered**2)) * torch.sqrt(torch.sum(true_centered**2)) + 1e-8
            corr = numerator / denom
            return (1.0 - corr) ** 2
    
        mse_loss = normalized_mse(predicted_data, batch_true_values_reshaped)
        corr_loss = pearson_correlation_loss(predicted_data, batch_true_values_reshaped)
        data_loss = mse_loss + corr_loss
    
        # --- Total Loss ---
        total_loss = (
            self.lambda_physics * physics_loss +
            self.lambda_data * data_loss +
            self.lambda_compartment_sum * compartment_sum_loss
        )
    
        return total_loss



    def compute_loss_with_dataloader(self):
        """Compute total loss using DataLoader"""
        total_loss = 0.0
        num_batches = 0
        
        for batch in self.dataloader:
            batch_ics = batch['ic'].to(self.device)
            batch_true_values = batch['true_values'].to(self.device)
            
            batch_loss = self.compute_batch_loss(batch_ics, batch_true_values)
            total_loss += batch_loss
            num_batches += 1
        
        # Average loss across batches
        return total_loss / num_batches

    def Train(self):
        # Prepare data
        iterables = [self.InitialLamdaValues, self.InitialMuValues]
        allIC_candidates = [list(x) for x in itertools.product(*iterables)]

        test_iterables = [self.TestLamdaValues, self.TestMuValues]
        allIC_test_candidates = [list(x) for x in itertools.product(*test_iterables)]
        allIC = []
        allTrueValues = []
        
        print(f"Generating reference solutions for {len(allIC_candidates)} initial conditions...")
        for currentInitialValue in allIC_candidates:
            df = generate_rk_solution(self.ts.detach().cpu().numpy().flatten(), currentInitialValue[0], currentInitialValue[1])
            t_ref = df['Time'].values.astype(np.float32)
            S_true = df['S'].values
            E_true = df['E'].values
            I_true = df['I'].values
            L_true = df['L'].values
            currentTrueValues = np.array([S_true, E_true, I_true, L_true])
            if(check_overflow_basic(currentTrueValues) is False):
                currentTrueValues = torch.transpose(torch.tensor(currentTrueValues, dtype=torch.float32), 0, 1)
                allTrueValues.append(currentTrueValues)
                S_0 = currentInitialValue[0] / (currentInitialValue[0]+currentInitialValue[1])
                E_0 = 1 / ((currentInitialValue[0] / currentInitialValue[1]) + 1)
                I_0 = 0
                L_0 = 0
                allIC.append(currentInitialValue + [S_0, E_0, I_0, L_0])

        print(f"Generating reference solutions for {len(allIC_test_candidates)} test initial conditions...")
        testTrueValues = []
        allIC_test = []
        for currentInitialValue in allIC_test_candidates:
            df = generate_rk_solution(self.ts.detach().cpu().numpy().flatten(), currentInitialValue[0], currentInitialValue[1])
            t_ref = df['Time'].values.astype(np.float32)
            S_true = df['S'].values
            E_true = df['E'].values
            I_true = df['I'].values
            L_true = df['L'].values
            currentTrueValues = np.array([S_true, E_true, I_true, L_true])
            if(check_overflow_basic(currentTrueValues) is False):
                currentTrueValues = torch.transpose(torch.tensor(currentTrueValues, dtype=torch.float32), 0, 1)
                testTrueValues.append(currentTrueValues)
                S_0 = currentInitialValue[0] / (currentInitialValue[0]+currentInitialValue[1])
                E_0 = 1 / ((currentInitialValue[0] / currentInitialValue[1]) + 1)
                I_0 = 0
                L_0 = 0
                allIC_test.append(currentInitialValue + [S_0, E_0, I_0, L_0])
        self.testTrueValues = testTrueValues
        self.test_ics_tensor = torch.tensor(allIC_test, dtype=torch.float32)
        # Convert to tensors
        all_ics_tensor = torch.tensor(allIC, dtype=torch.float32)
        # Create dataset and dataloader
        dataset = PINNDataset(all_ics_tensor, allTrueValues, self.t_physics, self.ts)
        self.dataloader = DataLoader(
            dataset, 
            batch_size=self.batch_size, 
            shuffle=True, 
            num_workers=self.num_workers,
            pin_memory=True if self.device.type == 'cuda' else False
        )
        
        print(f"Training with {len(allIC)} initial conditions")
        print(f"Batch size: {self.batch_size}, Number of batches: {len(self.dataloader)}")
        print("Starting new training!")

        best_loss = 100000
        for epoch in range(self.totalICEpochs):
            # Update lambda values based on epoch
            if epoch < 400:
                self.lambda_physics = 100
                self.lambda_data = 10
            elif epoch < 800:
                self.lambda_physics = 70
                self.lambda_data = 30
            elif epoch < 1200:
                self.lambda_physics = 70
                self.lambda_data = 70
            else:
                self.lambda_physics = 100.0
                self.lambda_data = 400.0
            
            epoch_loss = 0.0
            num_batches = 0
            
            # Training loop with DataLoader
            for batch in self.dataloader:
                self.adam.zero_grad()
                
                batch_ics = batch['ic'].to(self.device)
                batch_true_values = batch['true_values'].to(self.device)
                
                batch_loss = self.compute_batch_loss(batch_ics, batch_true_values)
                batch_loss.backward()
                
                # Optional gradient clipping
                # torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                
                self.adam.step()
                
                epoch_loss += batch_loss.item()
                num_batches += 1
            
            # Average loss for the epoch
            avg_epoch_loss = epoch_loss / num_batches
            self.loss_history.append(avg_epoch_loss)

            #print("Epoch:", epoch, "done!")
            
            if (epoch + 1) % 1000 == 0:
                self.scheduler.step()
                print(f"Adam Epoch [{epoch+1}], Loss: {avg_epoch_loss:.8f}, LR: {self.scheduler.get_last_lr()[0]:.6f}")
                current_loss = self.EvaluateModel()
                print("Validation Loss:", current_loss)
                if(current_loss < best_loss):
                    print("New best model found!")
                    torch.save(self, f'checkpoint_epoch_{epoch+1}.pth')
                    best_loss = current_loss
            
            # Clear cache periodically
            if (epoch + 1) % 100 == 0:
                torch.cuda.empty_cache()

        # LBFGS fine-tuning
        print("Starting LBFGS fine-tuning...")
        def closure():
            self.LBFGS.zero_grad()
            total_loss = self.compute_loss_with_dataloader()
            self.loss_history.append(total_loss.item())
            total_loss.backward()
            return total_loss

        self.LBFGS.step(closure)
        final_loss = self.compute_loss_with_dataloader()
        print(f"\nTraining finished. Final Loss: {final_loss.item():.6f}")
        if(final_loss.item() < best_loss):
            print("New best model found!")
            torch.save(self, f'checkpoint_epoch_{epoch+1}.pth')
    
    def PredictSingle(self, t, y0):
        self.model.eval()
        with torch.no_grad():
            inputTensor = torch.tensor([np.hstack((t, y0))], device=self.device, dtype=torch.float32)
            output = self.model(inputTensor)
            return output.detach().cpu().numpy().flatten()

    def PredictArray(self, t, lamda, mu, S0, E0, I0, L0):
        self.model.eval()
        with torch.no_grad():
            t = torch.tensor([t], dtype=torch.float32, device=self.device)
            # Reshape t to be [n_points, 1] if it isn't already
            t_reshaped = t.reshape(-1, 1)

            # Create tensor of initial values repeated for each time point
            n_points = t_reshaped.shape[0]
            initial_values = torch.tensor([[lamda, mu, S0, E0, I0, L0]], dtype=torch.float32, device=self.device).repeat(n_points, 1)
            # Concatenate time points with initial values
            fullInput = torch.cat([t_reshaped, initial_values], dim=1)
            # Move to device and set requires_grad
            fullInput = fullInput.to(self.device).requires_grad_(True)
            inputTensor = fullInput
            output = self.model(inputTensor)
            return output.squeeze().detach().cpu().numpy()

    def EvaluateModel(self):
        total_loss = 0.0
        n_tests = len(self.test_ics_tensor)
    
        # Normalized MSE (scale-invariant)
        def normalized_mse(pred, true, eps=1e-8):
            denom = torch.mean(torch.abs(true), dim=0, keepdim=True) + eps
            return torch.mean(((pred - true) / denom) ** 2)
    
        # Pearson correlation loss
        def pearson_correlation_loss(pred, true):
            pred_flat = pred.reshape(-1)
            true_flat = true.reshape(-1)
            pred_centered = pred_flat - pred_flat.mean()
            true_centered = true_flat - true_flat.mean()
            numerator = torch.sum(pred_centered * true_centered)
            denom = torch.sqrt(torch.sum(pred_centered**2)) * torch.sqrt(torch.sum(true_centered**2)) + 1e-8
            corr = numerator / denom
            return (1.0 - corr) ** 2
    
        for i in range(n_tests):
            IC = self.test_ics_tensor[i]
    
            # predict full trajectory for this IC
            currentSolution = self.PredictArray(
                self.ts.detach().cpu().numpy().flatten(),
                IC[0], IC[1], IC[2], IC[3], IC[4], IC[5]
            )
            currentSolution = torch.tensor(currentSolution, dtype=torch.float32, device=self.device)
    
            trueSolution = self.testTrueValues[i].to(self.device)
    
            # Data loss = NMSE + Pearson
            mse_loss = normalized_mse(currentSolution, trueSolution)
            corr_loss = pearson_correlation_loss(currentSolution, trueSolution)
            currentLoss = mse_loss + corr_loss
    
            total_loss += currentLoss.item()
    
        return total_loss / n_tests


In [None]:
time_domain = np.linspace(0, 20, 41)
Model = Net(30000, time_domain) 
Model.Train()

Generating reference solutions for 200 initial conditions...
Generating reference solutions for 20 test initial conditions...
Training with 200 initial conditions
Batch size: 200, Number of batches: 1
Starting new training!
Adam Epoch [1000], Loss: 6485.29150391, LR: 0.000099
Validation Loss: 749629.3912334442
Adam Epoch [2000], Loss: 6506.93701172, LR: 0.000098
Validation Loss: 67084.49577615262
New best model found!
Adam Epoch [3000], Loss: 3502.27929688, LR: 0.000097
Validation Loss: 33679.44418091774
New best model found!
Adam Epoch [4000], Loss: 2361.54003906, LR: 0.000096
Validation Loss: 24600.313498020172
New best model found!
Adam Epoch [5000], Loss: 1776.42687988, LR: 0.000095
Validation Loss: 18733.114101088046
New best model found!


In [None]:
Model_Best = torch.load("checkpoint_epoch_24000.pth", weights_only=False)

In [None]:
time_domain = np.linspace(0, 20, 41)
lamda = 4
mu = 0.02
model_solutions = Model.PredictArray(time_domain, lamda, mu)
S_pred = model_solutions[:,0]# * N
E_pred = model_solutions[:,1]#* N
I_pred = model_solutions[:,2]#* N
L_pred = model_solutions[:,3]#* N
PlotResults([0, 20], lamda, mu, S_pred, E_pred, I_pred, L_pred, time_domain)
plot_mse_over_time([0, 20], lamda, mu, S_pred, E_pred, I_pred, L_pred, time_domain)

In [None]:
criterion = torch.nn.MSELoss()
compartment_losses = {'S': 0, 'E': 0, 'I': 0, 'L': 0}

for i in range(len(Model_Best.test_ics_tensor)):
    IC = Model_Best.test_ics_tensor[i]
    predicted = Model_Best.PredictArray(Model_Best.ts.detach().cpu().numpy().flatten(), IC[0], IC[1])
    predicted = torch.tensor(predicted, dtype=torch.float32)
    true_values = Model_Best.testTrueValues[i]
    
    compartment_losses['S'] += criterion(predicted[:, 0], true_values[:, 0]).item()
    compartment_losses['E'] += criterion(predicted[:, 1], true_values[:, 1]).item()
    compartment_losses['I'] += criterion(predicted[:, 2], true_values[:, 2]).item()
    compartment_losses['L'] += criterion(predicted[:, 3], true_values[:, 3]).item()

# Average over all test cases
for key in compartment_losses:
    compartment_losses[key] /= len(Model_Best.test_ics_tensor)
    print(f"MSE {key}: {compartment_losses[key]:.12f}")

In [None]:
torch.save(Model, "workingModel.pth")

In [None]:
y0 = [90/N, 1/N, 0, 0]
model_solutions = Model.PredictArray(time_domain, y0)
S_pred = model_solutions[:,0]*N
E_pred = model_solutions[:,1]*N
I_pred = model_solutions[:,2]*N
L_pred = model_solutions[:,3]*N
PlotResults(y0, time_domain, S_pred, E_pred, I_pred, L_pred, time_domain)
plot_mse_over_time(y0, time_domain, S_pred, E_pred, I_pred, L_pred, time_domain)