<a href="https://colab.research.google.com/github/adityaprasad2005/ME691-ChemNODE-Ammonia/blob/main/notebooks/phase3-model_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import json
import os
import time


# 1. Install torchdiffeq
# (The leading "!" runs pip as a shell command from inside the notebook, so no separate terminal is needed)
!pip install torchdiffeq

from torchdiffeq import odeint

Collecting torchdiffeq
  Downloading torchdiffeq-0.2.5-py3-none-any.whl.metadata (440 bytes)
Downloading torchdiffeq-0.2.5-py3-none-any.whl (32 kB)
Installing collected packages: torchdiffeq
Successfully installed torchdiffeq-0.2.5


In [None]:
# 3. Mount your Google Drive
# The Drive contents are mounted under /content/drive; the project paths
# configured in this notebook (DRIVE_PATH) point inside that mount, so this
# cell has to run before any data or checkpoint file is read or written.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Set the Configuration

# File paths: every artifact lives in one Google Drive project folder
DRIVE_PATH = "/content/drive/MyDrive/SciML_Project"

NPZ_FILE = os.path.join(DRIVE_PATH, 'training_data.npz')
PARAMS_FILE = os.path.join(DRIVE_PATH, 'normalization_params.json')
MODEL_SAVE_PATH = os.path.join(DRIVE_PATH, 'chem_node_net_3lay_128dim.pth')

# Training hyperparameters
N_EPOCHS = 500
BATCH_SIZE = 16
LEARNING_RATE = 1e-3
HIDDEN_DIMS = 128

# Reproducibility: seed the global RNGs once, before any model is created or
# any DataLoader shuffles, so weight initialisation and batch order repeat
# from run to run.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Set device (GPU if available, else CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")



# -----------------------------------------------------------------
# Load data to get model dimensions

# Only the file access sits inside the try block; the friendly hint below is
# specific to a missing data file.
try:
    with np.load(NPZ_FILE) as data:
        all_trajectories = data['trajectories']
        t_points = data['times']
except FileNotFoundError:
    print(f"Error: Could not find '{NPZ_FILE}'.")
    print("Please make sure you have run Phase 2.")
    raise

# Fail early, with a readable message, if the arrays do not have the layout
# the rest of the notebook relies on:
#   trajectories -> (num_sims, num_timesteps, num_features)
#   times        -> (num_timesteps,)
if all_trajectories.ndim != 3:
    raise ValueError(
        "Expected 'trajectories' with shape (num_sims, num_timesteps, "
        f"num_features), got shape {all_trajectories.shape}."
    )
if t_points.ndim != 1 or t_points.shape[0] != all_trajectories.shape[1]:
    raise ValueError(
        f"'times' has shape {t_points.shape}, but the trajectories contain "
        f"{all_trajectories.shape[1]} time steps; expected a 1-D array of "
        "matching length."
    )

# Get the number of features (last axis of the trajectory array)
N_FEATURES = all_trajectories.shape[2]

print(f"Data loaded. Found {N_FEATURES} features (species + temp).")

Using device: cuda
Data loaded. Found 32 features (species + temp).


In [None]:
# -----------------------------------------------------------------
# Step 1: The Derivative Network

class ChemNet(nn.Module):
    """Feed-forward network that models the ODE right-hand side dy/dt = f(y).

    Layout: Linear -> ELU -> Linear -> ELU -> Linear. The input and output
    both have ``input_features`` entries; the two inner layers are
    ``hidden_features`` wide.
    """

    def __init__(self, input_features, hidden_features):
        super().__init__()
        # The layers are created in input-to-output order and occupy
        # positions 0, 2 and 4 of the Sequential, which fixes the parameter
        # names ('net.0.*', 'net.2.*', 'net.4.*') seen in the state_dict.
        input_layer = nn.Linear(input_features, hidden_features)
        hidden_layer = nn.Linear(hidden_features, hidden_features)
        output_layer = nn.Linear(hidden_features, input_features)
        self.net = nn.Sequential(
            input_layer,
            nn.ELU(),  # Exponential Linear Unit
            hidden_layer,
            nn.ELU(),
            output_layer,
        )

    def forward(self, t, y):
        """Return the network output for state ``y``.

        ``t`` is part of the signature because the ODE solver calls the
        function as ``func(t, y)``; its value is not used.
        """
        derivative = self.net(y)
        return derivative

In [None]:
# Step 2: The Neural ODE (NODE) Wrapper
# -----------------------------------------------------------------
# This module wraps our ChemNet and calls the 'odeint' solver.

class NeuralODE(nn.Module):
    """Integrates a learned derivative network forward in time with ``odeint``.

    Args:
        derivative_net: Module the solver calls as ``func(t, y)`` to obtain
            dy/dt for the current state.
        method: Solver name forwarded to ``odeint``. Defaults to ``'dopri5'``,
            an adaptive Runge-Kutta scheme (the choice the original author
            attributes to the Bansude paper).
        rtol: Relative error tolerance forwarded to ``odeint``.
        atol: Absolute error tolerance forwarded to ``odeint``.

    The default tolerances (1e-5 / 1e-6) are the values the original author
    chose, noting that they were eased for speed.
    """

    def __init__(self, derivative_net, method='dopri5', rtol=1e-5, atol=1e-6):
        super(NeuralODE, self).__init__()
        self.net = derivative_net
        # Solver settings are stored so callers can trade accuracy for speed
        # without editing this class.
        self.method = method
        self.rtol = rtol
        self.atol = atol

    def forward(self, y0, t_points):
        """Solve the ODE from ``y0`` and return the state at every time point.

        Args:
            y0: Initial states, shape (batch_size, n_features).
            t_points: 1-D tensor of evaluation times, shape (n_timesteps,).

        Returns:
            Tensor of shape (n_timesteps, batch_size, n_features).
        """
        # This calls the plain `odeint`, NOT the adjoint variant: gradients
        # are obtained by backpropagating through the solver's own operations.
        pred_y = odeint(
            self.net,
            y0,
            t_points,
            method=self.method,
            rtol=self.rtol,
            atol=self.atol,
        )
        return pred_y

In [None]:
# Step 3: The PyTorch Dataset and DataLoader

# We create a simple Dataset to feed trajectories to our model.
class ODESolutionDataset(Dataset):
    """Dataset whose items are complete trajectories stored on ``device``."""

    def __init__(self, trajectories, device):
        # Expected layout: (num_sims, n_timesteps, n_features).
        # First build a float32 tensor (standard precision for NN training),
        # then transfer it to the requested device in one go.
        as_float32 = torch.tensor(trajectories, dtype=torch.float32)
        self.data = as_float32.to(device)

    def __len__(self):
        # Number of entries along the first axis, i.e. one per simulation.
        return len(self.data)

    def __getitem__(self, idx):
        # A single item is one full trajectory: (n_timesteps, n_features).
        trajectory = self.data[idx]
        return trajectory


# --- Prepare data for training ---

# Convert time points to a float32 tensor on the correct device.
# torch.as_tensor accepts both the NumPy array loaded from disk and an
# already-converted tensor, so re-running this cell (which rebinds
# `t_points`) stays safe and does not emit the "copy construct from a
# tensor" warning that torch.tensor() raises on tensor input.
t_points = torch.as_tensor(t_points, dtype=torch.float32, device=device)

# Create Dataset and DataLoader
dataset = ODESolutionDataset(all_trajectories, device)
# The DataLoader will batch trajectories
# A 'batch' will have shape: (BATCH_SIZE, n_timesteps, n_features)
data_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
# Step 4: The Training Loop

# 1. Build the derivative network and wrap it in the ODE-solving module
derivative_net = ChemNet(N_FEATURES, HIDDEN_DIMS).to(device)
model = NeuralODE(derivative_net).to(device)

# --- RESUME FROM A SAVED CHECKPOINT WHEN ONE IS AVAILABLE ---
# A missing file and any other load failure both fall back to the freshly
# initialised weights; only the printed message differs.
# NOTE(review): after a failed load, a later save to MODEL_SAVE_PATH would
# replace the existing file - confirm that this fallback is intended.
try:
    derivative_net.load_state_dict(
        torch.load(MODEL_SAVE_PATH, map_location=device)
    )
    print(f"Successfully loaded existing model from: {MODEL_SAVE_PATH}")
    print("Continuing training from this checkpoint.")
except FileNotFoundError:
    print(f"No existing model found at {MODEL_SAVE_PATH}.")
    print("Starting a new training run.")
except Exception as load_error:
    print(f"Error loading model: {load_error}")
    print("Starting a new training run.")
# ----------------------------------------


# 2. Loss function and optimizer
loss_function = nn.L1Loss()  # mean absolute error (MAE)

# Adam, following Bansude et al. The optimizer has to exist before the
# scheduler, because the scheduler is constructed around it.
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Halve the learning rate whenever the monitored loss has not improved
# for 50 consecutive epochs.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=50,
)

Successfully loaded existing model from: /content/drive/MyDrive/SciML_Project/chem_node_net_3lay_128dim.pth
Continuing training from this checkpoint.


In [None]:
print(f"--- Starting Training on {device} ---")

# Interval (in epochs) for progress logging and in-progress checkpoints.
LOG_EVERY = 10

# Intermediate weights go to a sibling file, so MODEL_SAVE_PATH is still
# written only once, after the full run. If the session dies mid-training,
# the most recent weights can be recovered from this file.
IN_PROGRESS_PATH = MODEL_SAVE_PATH + ".inprogress"

# --- Start training ---
start_time = time.time()
for epoch in range(N_EPOCHS):
    epoch_loss = 0.0

    # Set model to training mode
    model.train()
    for batch_y_true in data_loader:
        # batch_y_true shape: (BATCH_SIZE, n_timesteps, n_features)

        # Initial condition = state at the first time point
        # y0 shape: (BATCH_SIZE, n_features)
        batch_y0 = batch_y_true[:, 0, :]

        # --- Forward Pass ---
        # Integrate the Neural ODE over all time points
        # pred_y shape: (n_timesteps, BATCH_SIZE, n_features)
        pred_y = model(batch_y0, t_points)

        # --- Calculate Loss ---
        # Re-order pred_y to match batch_y_true:
        # [timesteps, batch, features] -> [batch, timesteps, features]
        pred_y_for_loss = pred_y.permute(1, 0, 2)

        loss = loss_function(pred_y_for_loss, batch_y_true)

        # --- Backward Pass ---
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    avg_loss = epoch_loss / len(data_loader)

    # Let the plateau scheduler see the epoch's mean training loss
    scheduler.step(avg_loss)

    if (epoch + 1) % LOG_EVERY == 0:
        current_lr = optimizer.param_groups[0]['lr']
        print(f"Epoch {epoch + 1}/{N_EPOCHS} | Avg. Loss: {avg_loss:.6f} | LR: {current_lr:.2e}")
        # Persist progress so an interrupted session does not lose the run
        torch.save(derivative_net.state_dict(), IN_PROGRESS_PATH)

end_time = time.time()
print("--- Training Complete ---")
print(f"Total time: {(end_time - start_time):.2f} seconds")

# Step 5: Save the Trained Model

# We only need to save the derivative network (ChemNet),
# as it *is* the learned dynamics function.
torch.save(derivative_net.state_dict(), MODEL_SAVE_PATH)
print(f"Trained model saved to: {MODEL_SAVE_PATH}")

--- Starting Training on cuda ---
Epoch 10/500 | Avg. Loss: 0.022785 | LR: 1.00e-03
Epoch 20/500 | Avg. Loss: 0.023400 | LR: 1.00e-03
Epoch 30/500 | Avg. Loss: 0.024415 | LR: 1.00e-03
Epoch 40/500 | Avg. Loss: 0.023607 | LR: 1.00e-03
Epoch 50/500 | Avg. Loss: 0.022379 | LR: 1.00e-03
Epoch 60/500 | Avg. Loss: 0.024616 | LR: 1.00e-03
Epoch 70/500 | Avg. Loss: 0.022815 | LR: 1.00e-03
Epoch 80/500 | Avg. Loss: 0.022554 | LR: 1.00e-03
Epoch 90/500 | Avg. Loss: 0.023405 | LR: 1.00e-03
Epoch 100/500 | Avg. Loss: 0.023639 | LR: 1.00e-03
Epoch 110/500 | Avg. Loss: 0.023114 | LR: 1.00e-03
Epoch 120/500 | Avg. Loss: 0.023458 | LR: 1.00e-03
Epoch 130/500 | Avg. Loss: 0.022754 | LR: 1.00e-03
Epoch 140/500 | Avg. Loss: 0.022948 | LR: 1.00e-03
Epoch 150/500 | Avg. Loss: 0.022559 | LR: 1.00e-03
Epoch 160/500 | Avg. Loss: 0.024138 | LR: 1.00e-03
Epoch 170/500 | Avg. Loss: 0.025850 | LR: 1.00e-03
Epoch 180/500 | Avg. Loss: 0.023996 | LR: 1.00e-03
Epoch 190/500 | Avg. Loss: 0.020490 | LR: 5.00e-04
Epoch 