# WORKSPACE SETUP

In [None]:
# IMPORTING LIBRARIES
import copy
import csv
import os
import random
import time

import numpy as np
import pandas as pd

import optuna
from optuna.samplers import TPESampler
from optuna.pruners import MedianPruner
import joblib
import logging
from typing import Dict, Any, Tuple
import warnings
warnings.filterwarnings('ignore')
# Configure logging for Optuna
optuna.logging.set_verbosity(optuna.logging.WARNING)

from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import rdmolops
from rdkit.Chem import rdDistGeom as molDG
from rdkit.Chem import Descriptors
from rdkit.Chem.rdchem import GetPeriodicTable

import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch_geometric.nn import MessagePassing, global_mean_pool, global_max_pool, global_add_pool
from torch_geometric.utils import add_self_loops, degree
from torch_geometric.data import Data, Dataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn.inits import reset

import networkx as nx
import matplotlib
import matplotlib.pyplot as plt

from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
from sklearn.model_selection import train_test_split

# FUNCTIONS
from data_processing import load_dataset, smiles_to_graph, process_dataset, generate_graphs
from path_helpers import get_path
from stats_compute import compute_statistics, scale_graphs
import ModelArchitecture
from EnhancedDataSplit import DataSplitter

from collections import defaultdict
from typing import Tuple, List

# DIRECTORY SETUP
# Outputs (Optuna studies, model weights, loss logs) are written beneath the
# parent of the notebook's working directory.
current_directory = os.getcwd()
parent_directory = os.path.split(current_directory)[0]

In [None]:
# HYPERPARAMETER SETTINGS
# Final-training switches passed to train_model: include the physics terms in
# backpropagation, and compute/log them every epoch.
use_physics_in_loss = True
monitor_physics = True
seed = 21                  # numpy / random / torch seed; also the TPE sampler seed
split_seed = 42            # random_state for the DataSplitter splits
num_epochs = 100           # epochs for the final training run
patience = 30              # printed for reference; the training loops shown here do not read it
runtyp = 'RASHYB_BO_PINN'  # run tag used in the study name and output paths
study_name = '04_mean'
print('Base seed        :', seed)
print('Split seed       :', split_seed)
print('Max epochs       :', num_epochs)
print('Patience         :', patience)
run_time = time.time()     # timestamp used to name saved weights / loss logs

# Seed every RNG in use. NOTE: cuDNN determinism is deliberately OFF below
# (benchmark=True lets cuDNN choose algorithms per input shape), so GPU runs
# are seeded but not guaranteed to be bit-for-bit reproducible.
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.use_deterministic_algorithms(False)
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True

# Preferred device: either 'cuda' or 'cpu'. Fall back to CPU when CUDA is not
# available so the notebook still runs on a CPU-only machine.
preferred_device = 'cuda'
if preferred_device == 'cuda' and not torch.cuda.is_available():
    print('CUDA not available - falling back to CPU')
    selected_device = 'cpu'
else:
    selected_device = preferred_device
device = torch.device(selected_device)
print('device           :', device)

In [None]:
# BAYESIAN OPTIMIZER - PHYSICS-INFORMED VERSION
class BayesianOptimizer:
    """Optuna/TPE hyperparameter search for the physics-informed VLEAmineCO2 GNN.

    Each trial samples architecture, optimiser and physics-loss settings. It
    trains a ``ModelArchitecture.VLEAmineCO2`` on ``MSLELoss`` plus the three
    gradient penalties (``grad_pres`` / ``grad_temp`` / ``grad_conc``). The
    trial is scored by the best (minimum over epochs) mean per-batch
    validation ``MSLELoss``, which covers the data term only.

    ``test_data`` is stored but not used by the search itself.
    """

    def __init__(self, 
                 train_data, 
                 val_data, 
                 test_data,
                 conc_std,
                 conc_mean,
                 temp_std, 
                 temp_mean,
                 pco2_std,
                 pco2_mean,
                 device,
                 parent_directory,
                 n_trials=10,
                 seed=21):
        """Store the (already scaled) splits, scaling statistics and run settings.

        Args:
            train_data, val_data, test_data: indexable collections of graph
                objects; each graph must expose ``x``, ``edge_attr`` and ``aco2``.
            conc_std, conc_mean, temp_std, temp_mean, pco2_std, pco2_mean:
                training-set statistics forwarded to the physics-loss functions.
            device: torch device used for models and batches.
            parent_directory: root under which ``bayesian_optimization/``
                study pickles are written.
            n_trials (int): number of trials run per ``optimize`` call.
            seed (int): seed for the TPE sampler.
        """
        self.train_data = train_data
        self.val_data = val_data
        self.test_data = test_data
        self.conc_std = conc_std
        self.conc_mean = conc_mean
        self.temp_std = temp_std
        self.temp_mean = temp_mean
        self.pco2_std = pco2_std
        self.pco2_mean = pco2_mean
        self.device = device
        self.parent_directory = parent_directory
        self.n_trials = n_trials
        self.seed = seed
        
        # Node/edge feature widths, read from the first training graph
        self.node_dim = train_data[0].x.size(1)
        self.edge_dim = train_data[0].edge_attr.size(1)
        
        # Hyperparameters held constant across trials (filled by ``optimize``);
        # keys present here are never suggested to Optuna.
        self.fixed_params = {}

    def define_hyperparameter_space(self, trial, param_ranges=None):
        """Sample one configuration from ``trial``, honouring fixed params and ranges.

        Any key in ``self.fixed_params`` is taken from there and is NOT
        suggested. Otherwise ``param_ranges`` may override the default search
        space per key:

        * ``hidden_dim``, ``batch_size``: sequence of categorical choices.
        * ``graph_layers``, ``fc_layers``: ``(low, high)`` integer range.
        * ``lr``, ``weight_decay``: ``(low, high)`` log-uniform float range.
        * ``s1``, ``s2``, ``s3`` (physics-loss scale factors): a 2-tuple
          ``(low, high)`` is sampled log-uniformly as a float. Any other
          sequence (e.g. a list) is used as categorical choices. The default
          is the categorical ``[1, 10, 100, 1000]``.

        Parameters are suggested in a fixed order: hidden_dim, graph_layers,
        fc_layers, lr, weight_decay, batch_size, s1, s2, s3.

        Returns:
            dict with the nine keys above.
        """
        if param_ranges is None:
            param_ranges = {}

        # Hidden dim - only suggest if not fixed
        if 'hidden_dim' in self.fixed_params:
            hidden_dim = self.fixed_params['hidden_dim']
        else:
            hidden_dim = trial.suggest_categorical('hidden_dim', param_ranges.get('hidden_dim', [64, 128]))

        # Graph layers - only suggest if not fixed
        if 'graph_layers' in self.fixed_params:
            graph_layers = self.fixed_params['graph_layers']
        else:
            graph_layers = trial.suggest_int('graph_layers', *param_ranges.get('graph_layers', (2, 6)))

        # FC layers - only suggest if not fixed
        if 'fc_layers' in self.fixed_params:
            fc_layers = self.fixed_params['fc_layers']
        else:
            fc_layers = trial.suggest_int('fc_layers', *param_ranges.get('fc_layers', (2, 6)))

        # Learning rate - only suggest if not fixed
        if 'lr' in self.fixed_params:
            lr = self.fixed_params['lr']
        else:
            lr = trial.suggest_float('lr', *param_ranges.get('lr', (1e-6, 1e-2)), log=True)

        # Weight decay - only suggest if not fixed
        if 'weight_decay' in self.fixed_params:
            weight_decay = self.fixed_params['weight_decay']
        else:
            weight_decay = trial.suggest_float('weight_decay', *param_ranges.get('weight_decay', (1e-6, 1e-2)), log=True)

        # Batch size - only suggest if not fixed
        if 'batch_size' in self.fixed_params:
            batch_size = self.fixed_params['batch_size']
        else:
            batch_size = trial.suggest_categorical('batch_size', param_ranges.get('batch_size', [32, 64]))

        # Physics loss scaling factors (now also honour param_ranges)
        s1 = self._suggest_scale_factor(trial, 's1', param_ranges)
        s2 = self._suggest_scale_factor(trial, 's2', param_ranges)
        s3 = self._suggest_scale_factor(trial, 's3', param_ranges)

        return {
            'hidden_dim': hidden_dim,
            'graph_layers': graph_layers,
            'fc_layers': fc_layers,
            'lr': lr,
            'weight_decay': weight_decay,
            'batch_size': batch_size,
            's1': s1,
            's2': s2,
            's3': s3
        }

    def _suggest_scale_factor(self, trial, name, param_ranges):
        """Return physics-loss scale factor ``name``: its fixed value or a new suggestion."""
        if name in self.fixed_params:
            return self.fixed_params[name]
        spec = param_ranges.get(name, [1, 1e+1, 1e+2, 1e+3])
        if isinstance(spec, tuple) and len(spec) == 2:
            # (low, high) range, same convention as lr / weight_decay
            low, high = spec
            return trial.suggest_float(name, low, high, log=True)
        return trial.suggest_categorical(name, spec)

    def create_model(self, params):
        """Build a ``VLEAmineCO2`` from ``params`` and move it to ``self.device``."""
        model = ModelArchitecture.VLEAmineCO2(
            node_dim=self.node_dim,
            edge_dim=self.edge_dim,
            hidden_dim=params['hidden_dim'],
            graph_layers=params['graph_layers'],
            fc_layers=params['fc_layers'],
            use_adaptive_pooling=True
        ).to(self.device)
        
        return model
    
    def train_with_params(self, params, trial=None):
        """Train one model with ``params`` and return its best validation loss.

        Training is capped at 100 epochs with early stopping after 10 epochs
        without improvement. Both values are local to this method and
        independent of the notebook-level ``num_epochs`` / ``patience``.
        Physics terms are added to the training loss. A scale factor equal to
        0 skips the corresponding term.

        Args:
            params (dict): output of ``define_hyperparameter_space``.
            trial: optional Optuna trial. If given, the per-epoch validation
                loss is reported to it, and ``TrialPruned`` is raised when the
                pruner says so.

        Returns:
            float: minimum over epochs of the mean per-batch validation loss.
        """
        # Create data loaders
        train_loader = DataLoader(self.train_data, batch_size=params['batch_size'], shuffle=True)
        val_loader = DataLoader(self.val_data, batch_size=params['batch_size'], shuffle=False)
        
        # Create model and optimizer
        model = self.create_model(params)
        criterion = ModelArchitecture.MSLELoss()
        optimizer = optim.Adam(model.parameters(), 
                               lr=params['lr'], 
                               weight_decay=params['weight_decay'])
        
        # Training parameters (search-time budget, see docstring)
        num_epochs = 100
        patience = 10
        
        best_val_loss = float('inf')
        epochs_without_improvement = 0
        
        for epoch in range(num_epochs):
            # Training
            model.train()
            train_loss = 0
            for data in train_loader:
                data = data.to(self.device)
                optimizer.zero_grad()
                output = model(data)
                
                # Main loss
                loss_main = criterion(output, data.aco2.view(-1, 1))
                
                # Physics losses
                loss_thermo1 = ModelArchitecture.grad_pres(model, data, self.pco2_std, self.pco2_mean, params['s1']) if params['s1'] != 0 else 0.0
                loss_thermo2 = ModelArchitecture.grad_temp(model, data, self.temp_std, self.temp_mean, params['s2']) if params['s2'] != 0 else 0.0
                loss_thermo3 = ModelArchitecture.grad_conc(model, data, self.conc_std, self.conc_mean, params['s3']) if params['s3'] != 0 else 0.0
                
                total_loss = loss_main + loss_thermo1 + loss_thermo2 + loss_thermo3
                total_loss.backward()
                optimizer.step()
                train_loss += total_loss.item()
            
            # Validation (data loss only, no physics terms)
            model.eval()
            val_loss = 0
            with torch.no_grad():
                for data in val_loader:
                    data = data.to(self.device)
                    output = model(data)
                    loss = criterion(output, data.aco2.view(-1, 1))
                    val_loss += loss.item()
            
            avg_val_loss = val_loss / len(val_loader)
            
            # Early stopping and pruning
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                epochs_without_improvement = 0
            else:
                epochs_without_improvement += 1
            
            if trial is not None:
                trial.report(avg_val_loss, epoch)
                if trial.should_prune():
                    raise optuna.exceptions.TrialPruned()
            
            if epochs_without_improvement >= patience:
                break
        
        return best_val_loss
    
    def objective_with_ranges(self, trial, param_ranges=None):
        """Optuna objective: sample params, train, and return the validation loss.

        ``TrialPruned`` is re-raised so Optuna records the trial as PRUNED. Any
        other exception is printed and scored as ``inf``, so one failing
        configuration does not abort the study.
        """
        try:
            params = self.define_hyperparameter_space(trial, param_ranges=param_ranges)
            return self.train_with_params(params, trial)
        except optuna.exceptions.TrialPruned:
            # TrialPruned is an Exception subclass; without this clause the
            # generic handler below would turn every pruned trial into a
            # COMPLETE trial with value inf.
            raise
        except Exception as e:
            print(f"Trial failed with error: {str(e)}")
            return float('inf')
    
    def optimize(self, study_name="gnn_hyperopt", continue_from_last=True, 
                fixed_params=None, param_ranges=None):
        """
        Run Bayesian hyperparameter optimization.

        Args:
            study_name (str): Name of the study (also names the pickle file).
            continue_from_last (bool): If True and a pickled study exists,
                resume it; otherwise start a new study (the file is overwritten
                when this call finishes).
            fixed_params (dict): Hyperparameters to keep constant. Merged into
                ``self.fixed_params``, so they persist across later calls.
            param_ranges (dict): Custom search ranges for other hyperparameters
                (see ``define_hyperparameter_space``).
        Returns:
            study: Optuna study object.
            best_params_complete: Dictionary of best hyperparameters (fixed + tuned).
        """

        # Update fixed params so they are excluded from the search space
        if fixed_params:
            self.fixed_params.update(fixed_params)
            print(f"Fixed parameters: {self.fixed_params}")

        # Store param_ranges for objective
        self._param_ranges = param_ranges if param_ranges is not None else {}

        # Path to save study
        study_path = f"{self.parent_directory}/bayesian_optimization/study_{study_name}.pkl"
        os.makedirs(os.path.dirname(study_path), exist_ok=True)

        # Only used for a new study; a resumed study keeps its pickled sampler/pruner
        sampler = TPESampler(seed=self.seed)
        pruner = MedianPruner(n_startup_trials=10, n_warmup_steps=10, interval_steps=5)

        # Load or create study (joblib/pickle: only load files this notebook wrote)
        if continue_from_last and os.path.exists(study_path):
            print(f"Resuming study from {study_path}...")
            study = joblib.load(study_path)
        else:
            print("Creating a new study...")
            study = optuna.create_study(
                direction='minimize',
                sampler=sampler,
                pruner=pruner,
                study_name=study_name
            )

        # Wrap the objective to inject fixed params as trial user attributes
        def objective(trial):
            # Log fixed params in this trial
            for k, v in self.fixed_params.items():
                trial.set_user_attr(f"fixed_{k}", v)
            return self.objective_with_ranges(trial, self._param_ranges)

        # Run optimization
        study.optimize(
            objective,
            n_trials=self.n_trials,
            timeout=None,
            show_progress_bar=True
        )

        # Save study
        joblib.dump(study, study_path)
        print(f"Study saved at {study_path}")
        print(f"Best validation loss: {study.best_value:.6f}")

        # Merge fixed params OVER Optuna's best params
        best_params_complete = {**study.best_params, **self.fixed_params}

        # Print the complete parameter set
        print("Best parameters (including fixed):")
        for k, v in best_params_complete.items():
            print(f"  {k}: {v}")

        return study, best_params_complete
        
    def plot_optimization_results(self, study):
        """Visualize optimization progress and parameter importance.

        Panels: values of trials that have a value, parameter importance (when
        Optuna can compute it), running best value over all trials, and a
        histogram of the finite trial values.
        """
        
        fig, axes = plt.subplots(2, 2, figsize=(15, 10))
        
        # Optimization history
        trial_values = [trial.value for trial in study.trials if trial.value is not None]
        axes[0, 0].plot(trial_values, marker='o')
        axes[0, 0].set_xlabel('Trial')
        axes[0, 0].set_ylabel('Validation Loss')
        axes[0, 0].set_title('Optimization History')
        axes[0, 0].grid(True)
        
        # Parameter importance (if available)
        try:
            importance = optuna.importance.get_param_importances(study)
            if importance:
                params = list(importance.keys())
                values = list(importance.values())
                axes[0, 1].barh(params, values)
                axes[0, 1].set_xlabel('Importance')
                axes[0, 1].set_title('Parameter Importance')
            else:
                axes[0, 1].text(0.5, 0.5, 'Parameter importance\nnot available', 
                                ha='center', va='center', transform=axes[0, 1].transAxes)
        except Exception:
            # Importance needs enough completed trials; degrade to a placeholder
            axes[0, 1].text(0.5, 0.5, 'Parameter importance\nnot available', 
                            ha='center', va='center', transform=axes[0, 1].transAxes)
        
        # Best trial convergence
        best_values = []
        best_so_far = float('inf')
        for trial in study.trials:
            if trial.value is not None and trial.value < best_so_far:
                best_so_far = trial.value
            best_values.append(best_so_far)
        
        axes[1, 0].plot(best_values, marker='o')
        axes[1, 0].set_xlabel('Trial')
        axes[1, 0].set_ylabel('Best Validation Loss')
        axes[1, 0].set_title('Best Score Convergence')
        axes[1, 0].grid(True)
        
        # Distribution of validation losses (failed trials scored inf are excluded)
        valid_values = [v for v in trial_values if v != float('inf')]
        if valid_values:
            axes[1, 1].hist(valid_values, bins=20, alpha=0.7)
            axes[1, 1].set_xlabel('Validation Loss')
            axes[1, 1].set_ylabel('Frequency')
            axes[1, 1].set_title('Distribution of Validation Losses')
            axes[1, 1].grid(True)
        
        plt.tight_layout()
        plt.show()

In [None]:
# LOAD & GRAPH GENERATION
# Component table (abbreviation -> SMILES) and the system-level dataset
df_components = load_dataset(get_path(file_name='components_set.csv', folder_name='datasets'))
smiles_dict = dict(zip(df_components['Abbreviation'], df_components['SMILES']))
df_systems = load_dataset(get_path(file_name='systems_set.csv', folder_name='datasets'))
smiles_list = df_components["SMILES"].dropna().tolist()
mol_name_dict = dict(smiles_dict)  # independent shallow copy of the lookup

# GRAPH: convert every system row into a graph
system_graphs = process_dataset(df_systems, smiles_dict)

# Outer split (rarity-aware, per its name on amines unseen in the other
# partitions): the third partition becomes the held-out test set.
splitter_1 = DataSplitter(system_graphs, random_state=split_seed)
RASset1, RASset2, RASset3 = splitter_1.rarity_aware_unseen_amine_split()
opt_data = RASset1 + RASset2

# HYBRID: stratified random split of the remaining data into train / validation
splitter_2 = DataSplitter(opt_data, random_state=split_seed)
SRSset1, SRSset2, SRSset3 = splitter_2.stratified_random_split()
train_data, val_data, test_data = SRSset1, SRSset2 + SRSset3, RASset3

# Scaling statistics come from the training split only
stats = compute_statistics(train_data)
conc_mean, conc_std, temp_mean, temp_std, pco2_mean, pco2_std = (stats[i] for i in range(6))

# Keep unscaled deep copies, then scale all three splits with the training statistics
original_train_data, original_val_data, original_test_data = (
    copy.deepcopy(split) for split in (train_data, val_data, test_data)
)
combined_original_data = original_train_data + original_val_data + original_test_data
train_data, val_data, test_data = (
    scale_graphs(split, conc_mean, conc_std, temp_mean, temp_std, pco2_mean, pco2_std)
    for split in (train_data, val_data, test_data)
)

In [None]:
# INITIAL SEARCH
# Fresh study (continue_from_last=False): any previously pickled study with the
# same name is ignored, and it is overwritten when this search finishes.
optimizer = BayesianOptimizer(
    train_data=train_data, val_data=val_data, test_data=test_data,
    conc_std=conc_std, conc_mean=conc_mean,
    temp_std=temp_std, temp_mean=temp_mean,
    pco2_std=pco2_std, pco2_mean=pco2_mean,
    device=device, parent_directory=parent_directory,
    n_trials=20, seed=seed,
)

study, best_params = optimizer.optimize(
    study_name=f"{study_name}_{runtyp}",
    continue_from_last=False,
)

optimizer.plot_optimization_results(study)

In [None]:
# CLEAR-UP CUDA MEMORY
# Drop the optimizer, collect unreachable objects first, then release cached
# CUDA blocks: empty_cache() can only return memory of tensors already freed.
import gc
del optimizer
gc.collect()
torch.cuda.empty_cache()

In [None]:
# 2ND SEARCH
# Resume the pickled study (if present) and run 20 more trials; the updated
# study is saved back to the same file.
optimizer = BayesianOptimizer(
    train_data=train_data, val_data=val_data, test_data=test_data,
    conc_std=conc_std, conc_mean=conc_mean,
    temp_std=temp_std, temp_mean=temp_mean,
    pco2_std=pco2_std, pco2_mean=pco2_mean,
    device=device, parent_directory=parent_directory,
    n_trials=20, seed=seed,
)

study, best_params = optimizer.optimize(
    study_name=f"{study_name}_{runtyp}",
    continue_from_last=True,
)

optimizer.plot_optimization_results(study)

In [None]:
# CLEAR-UP CUDA MEMORY
# Drop the optimizer, collect unreachable objects first, then release cached
# CUDA blocks: empty_cache() can only return memory of tensors already freed.
import gc
del optimizer
gc.collect()
torch.cuda.empty_cache()

In [None]:
# 3RD SEARCH
# Resume the pickled study (if present) and run 20 more trials; the updated
# study is saved back to the same file.
optimizer = BayesianOptimizer(
    train_data=train_data, val_data=val_data, test_data=test_data,
    conc_std=conc_std, conc_mean=conc_mean,
    temp_std=temp_std, temp_mean=temp_mean,
    pco2_std=pco2_std, pco2_mean=pco2_mean,
    device=device, parent_directory=parent_directory,
    n_trials=20, seed=seed,
)

study, best_params = optimizer.optimize(
    study_name=f"{study_name}_{runtyp}",
    continue_from_last=True,
)

optimizer.plot_optimization_results(study)

In [None]:
# CLEAR-UP CUDA MEMORY
# Drop the optimizer, collect unreachable objects first, then release cached
# CUDA blocks: empty_cache() can only return memory of tensors already freed.
import gc
del optimizer
gc.collect()
torch.cuda.empty_cache()

In [None]:
# 4TH SEARCH
# Resume the pickled study (if present) and run 20 more trials; the updated
# study is saved back to the same file.
optimizer = BayesianOptimizer(
    train_data=train_data, val_data=val_data, test_data=test_data,
    conc_std=conc_std, conc_mean=conc_mean,
    temp_std=temp_std, temp_mean=temp_mean,
    pco2_std=pco2_std, pco2_mean=pco2_mean,
    device=device, parent_directory=parent_directory,
    n_trials=20, seed=seed,
)

study, best_params = optimizer.optimize(
    study_name=f"{study_name}_{runtyp}",
    continue_from_last=True,
)

optimizer.plot_optimization_results(study)

In [None]:
# CLEAR-UP CUDA MEMORY
# Drop the optimizer, collect unreachable objects first, then release cached
# CUDA blocks: empty_cache() can only return memory of tensors already freed.
import gc
del optimizer
gc.collect()
torch.cuda.empty_cache()

In [None]:
# DISABLED CELL: the whole cell is a bare string literal, so nothing in it runs.
# Remove the surrounding triple quotes to run a 5th resumed search of 20 trials.
""" # 5TH SEARCH
optimizer = BayesianOptimizer(
    train_data=train_data,
    val_data=val_data,
    test_data=test_data,
    conc_std=conc_std,
    conc_mean=conc_mean,
    temp_std=temp_std,
    temp_mean=temp_mean,
    pco2_std=pco2_std,
    pco2_mean=pco2_mean,
    device=device,
    parent_directory=parent_directory,
    n_trials=20,
    seed=seed
)

# Run optimization
study, best_params = optimizer.optimize(study_name=f"{study_name}_{runtyp}",
                                        continue_from_last=True)

# Plot results
optimizer.plot_optimization_results(study) """

In [None]:
# DISABLED CELL: bare string literal, nothing runs. It is the cleanup that would
# follow the disabled 5th search above; enable both together.
""" # CLEAR-UP CUDA MEMORY
del optimizer
torch.cuda.empty_cache()
import gc
gc.collect() """

In [None]:
# DISABLED CELL: the whole cell is a bare string literal, so nothing in it runs.
# When enabled it resumes the same study for 150 more trials with hidden_dim,
# graph_layers, fc_layers and batch_size held fixed, and passes custom
# param_ranges for s1/s2/s3.
# NOTE(review): whether the s1/s2/s3 ranges take effect depends on
# BayesianOptimizer.define_hyperparameter_space reading param_ranges for those
# keys. Also, the resumed study already holds categorical s1-s3 trials, and
# Optuna may reject a different distribution for an existing parameter name.
# Such an error would be caught by objective_with_ranges and scored as inf.
# Confirm before enabling.
""" # FINE TUNING
optimizer = BayesianOptimizer(
    train_data=train_data,
    val_data=val_data,
    test_data=test_data,
    conc_std=conc_std,
    conc_mean=conc_mean,
    temp_std=temp_std,
    temp_mean=temp_mean,
    pco2_std=pco2_std,
    pco2_mean=pco2_mean,
    device=device,
    parent_directory=parent_directory,
    n_trials=150,
    seed=seed
)
# FINE TUNING
fixed_params = {"hidden_dim"    : 128, 
                "graph_layers"  : 2, 
                "fc_layers"     : 4,
                "batch_size"    : 32}

# Change search range for others
param_ranges = {
    "s1": (10, 100000),
    "s2": (10, 100000),
    "s3": (10, 100000)
}

study, best_params = optimizer.optimize(
    study_name=f"{study_name}_{runtyp}",
    continue_from_last=True,
    fixed_params=fixed_params,
    param_ranges=param_ranges
)
print(f"Total trials now: {len(study.trials)}")
# Plot results
optimizer.plot_optimization_results(study) """

In [None]:
# BEST HYPERPARAMETERS
# Unpack the best configuration (tuned + fixed) returned by the last search
# into module-level names used by the final training cell.
(hidden_dim, graph_layers, fc_layers, lr, weight_decay,
 batch_size, s1, s2, s3) = (best_params[key] for key in (
    'hidden_dim', 'graph_layers', 'fc_layers', 'lr', 'weight_decay',
    'batch_size', 's1', 's2', 's3'))

print("Optimzed hyperparameters loaded...")
print("\n".join(f"{name}: {value}" for name, value in (
    ("hidden_dim", hidden_dim),
    ("graph_layers", graph_layers),
    ("fc_layers", fc_layers),
    ("lr", lr),
    ("weight_decay", weight_decay),
    ("batch_size", batch_size),
    ("s1", s1),
    ("s2", s2),
    ("s3", s3),
)))

In [None]:
# FUNCTION -> TRAINING LOOP
def train_one_epoch(model, train_loader, criterion, optimizer, device, 
                    pco2_std, temp_std, conc_std,
                    pco2_mean, temp_mean, conc_mean,
                    s1, s2, s3, 
                    use_physics_in_loss=True, monitor_physics=True):
    """Run one optimisation pass over ``train_loader`` and return averaged losses.

    Physics terms (``ModelArchitecture.grad_pres`` / ``grad_temp`` /
    ``grad_conc``) are computed when either flag is set. Each term is computed
    only when its scale factor is nonzero and is 0.0 otherwise. The terms are
    added to the backpropagated loss only when ``use_physics_in_loss`` is True.

    Returns:
        tuple ``(avg_train_loss, avg_main_loss, avg_thermo1_loss,
        avg_thermo2_loss, avg_thermo3_loss)``, each divided by the number of
        batches. ``avg_train_loss`` is the loss that was actually backpropagated.
    """
    def as_number(term):
        # Disabled physics terms are the plain float 0.0, not tensors.
        return term.item() if isinstance(term, torch.Tensor) else term

    compute_physics = use_physics_in_loss or monitor_physics
    # Running sums: [backpropagated, main, gradP, gradT, gradC]
    sums = [0, 0, 0, 0, 0]

    model.train()
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        prediction = model(batch)
        main_term = criterion(prediction, batch.aco2.view(-1, 1))

        if not compute_physics:
            p_term = t_term = c_term = 0.0
        else:
            p_term = 0.0 if s1 == 0 else ModelArchitecture.grad_pres(model, batch, pco2_std, pco2_mean, s1)
            t_term = 0.0 if s2 == 0 else ModelArchitecture.grad_temp(model, batch, temp_std, temp_mean, s2)
            c_term = 0.0 if s3 == 0 else ModelArchitecture.grad_conc(model, batch, conc_std, conc_mean, s3)

        objective = (main_term + p_term + t_term + c_term) if use_physics_in_loss else main_term
        objective.backward()
        optimizer.step()

        # objective is main_term itself when physics is excluded
        sums[0] += objective.item()
        for slot, term in enumerate((main_term.item(), p_term, t_term, c_term), start=1):
            sums[slot] += as_number(term)

    n_batches = len(train_loader)
    return tuple(total / n_batches for total in sums)


# Validation Step
def validate(model, val_loader, criterion, device):
    """Return the mean per-batch ``criterion`` loss of ``model`` over ``val_loader``.

    Runs in eval mode without gradient tracking. Only the data loss against
    ``batch.aco2`` is measured; no physics terms are included.
    """
    model.eval()
    running_total = 0
    with torch.no_grad():
        for batch in val_loader:
            batch = batch.to(device)
            running_total += criterion(model(batch), batch.aco2.view(-1, 1)).item()
    return running_total / len(val_loader)

# Test Evaluation (Metrics)
def evaluate_test(model, test_loader, criterion, device):
    """Evaluate ``model`` on ``test_loader``.

    Returns:
        ``(avg_test_loss, test_r2)``. The first is the mean per-batch
        ``criterion`` loss; the second is sklearn's R² over all concatenated
        predictions versus ``aco2`` targets.
    """
    model.eval()
    running_loss = 0
    pred_chunks, target_chunks = [], []

    with torch.no_grad():
        for batch in test_loader:
            batch = batch.to(device)
            prediction = model(batch)
            target = batch.aco2.view(-1, 1)
            running_loss += criterion(prediction, target).item()
            pred_chunks.append(prediction.cpu().numpy())
            target_chunks.append(target.cpu().numpy())

    y_pred = np.concatenate(pred_chunks, axis=0)
    y_true = np.concatenate(target_chunks, axis=0)
    test_r2 = r2_score(y_true, y_pred)
    return running_loss / len(test_loader), test_r2
# Loss Plotting
def plot_losses(train_losses, val_losses, test_losses, test_r2_scores, 
                main_losses, thermo1_losses, thermo2_losses, thermo3_losses):
    """Show a 2x2 figure of per-epoch training curves (1-based epoch axis).

    Panels: (1) total train / validation / test loss, (2) test R², (3) main
    loss together with the three physics terms, (4) the physics terms alone.
    """
    def epochs_of(series):
        return range(1, len(series) + 1)

    def finish_panel(ylabel, title, with_legend):
        plt.xlabel('Epochs')
        plt.ylabel(ylabel)
        plt.title(title)
        if with_legend:
            plt.legend()
        plt.grid(True)

    physics_curves = [
        (thermo1_losses, 'Partial pressure gradient (s1)', 'red'),
        (thermo2_losses, 'Temperature gradient (s2)', 'orange'),
        (thermo3_losses, 'Concentration gradient (s3)', 'purple'),
    ]

    plt.figure(figsize=(12, 8))

    # Panel 1: headline loss curves
    plt.subplot(2, 2, 1)
    for series, label in ((train_losses, 'Total Train Loss'),
                          (val_losses, 'Validation Loss'),
                          (test_losses, 'Test Loss')):
        plt.plot(epochs_of(series), series, label=label)
    finish_panel('Loss', 'Training, Validation and Test Loss Curves', True)

    # Panel 2: test R²
    plt.subplot(2, 2, 2)
    plt.plot(epochs_of(test_r2_scores), test_r2_scores, color='green')
    finish_panel('R² Score', 'Test R² Score per Epoch', False)

    # Panel 3: main loss plus the physics components
    plt.subplot(2, 2, 3)
    plt.plot(epochs_of(main_losses), main_losses, label='Main Loss', color='blue')
    for series, label, colour in physics_curves:
        plt.plot(epochs_of(series), series, label=label, color=colour)
    finish_panel('Loss', 'Individual Loss Components', True)

    # Panel 4: physics components only, so their scale is visible
    plt.subplot(2, 2, 4)
    for series, label, colour in physics_curves:
        plt.plot(epochs_of(series), series, label=label, color=colour)
    finish_panel('Loss', 'Physics Loss Components (Zoomed)', True)

    plt.tight_layout()
    plt.show()

# Main Training Loop
def train_model(model, train_loader, val_loader, 
                test_loader, criterion, optimizer, 
                pco2_std, temp_std, conc_std,
                pco2_mean, temp_mean, conc_mean,
                s1, s2, s3,
                device, num_epochs, 
                seed, parent_directory,
                use_physics_in_loss=True, monitor_physics=True):
    """Train ``model`` for exactly ``num_epochs`` epochs (no early stopping).

    Each epoch runs ``train_one_epoch``, then measures the validation loss
    (``validate``) and the test loss and R² (``evaluate_test``), and prints a
    one-line summary. The curves are plotted with ``plot_losses`` at the end.

    Args:
        seed, parent_directory: accepted for interface compatibility; not used here.
        use_physics_in_loss (bool): Whether to include physics losses in backpropagation
        monitor_physics (bool): Whether to compute and track physics losses for monitoring

    Returns:
        Eight per-epoch lists: (train_losses, val_losses, test_losses,
        test_r2_scores, main_losses, thermo1_losses, thermo2_losses, thermo3_losses).
    """
    series_keys = ('train', 'val', 'test', 'r2', 'main', 'gradP', 'gradT', 'gradC')
    history = {key: [] for key in series_keys}

    mode_banner = {
        (True, True): "Training with physics losses in backpropagation + monitoring",
        (False, True): "Training with Main Loss only, but monitoring physics losses",
        (True, False): "Training with physics losses in backpropagation (no monitoring)",
        (False, False): "Training with Main Loss only (no physics)",
    }
    print(mode_banner[(bool(use_physics_in_loss), bool(monitor_physics))])

    for epoch_idx in range(num_epochs):
        avg_train_loss, avg_main_loss, g_p, g_t, g_c = train_one_epoch(
            model, train_loader, criterion, optimizer, device,
            pco2_std, temp_std, conc_std,
            pco2_mean, temp_mean, conc_mean,
            s1, s2, s3, use_physics_in_loss, monitor_physics)
        avg_val_loss = validate(model, val_loader, criterion, device)
        avg_test_loss, test_r2 = evaluate_test(model, test_loader, criterion, device)

        epoch_values = (avg_train_loss, avg_val_loss, avg_test_loss, test_r2,
                        avg_main_loss, g_p, g_t, g_c)
        for key, value in zip(series_keys, epoch_values):
            history[key].append(value)

        if monitor_physics:
            physics_info = (f'Main Loss: {round(avg_main_loss, 4)}, gradP: {round(g_p, 4)}, '
                            f'gradT: {round(g_t, 4)}, gradC: {round(g_c, 4)}')
            physics_status = " [BACKPROP]" if use_physics_in_loss else " [MONITOR]"
        else:
            physics_info, physics_status = f'Main Loss: {round(avg_main_loss, 4)}', ""

        print(f'Epoch {epoch_idx+1}/{num_epochs}, Total Loss: {round(avg_train_loss, 4)}, '
              f'Val Loss: {round(avg_val_loss, 4)}, Test Loss: {round(avg_test_loss, 4)}, '
              f'Test R2: {round(test_r2, 4)} | {physics_info}{physics_status}')

    all_series = [history[key] for key in series_keys]
    plot_losses(*all_series)
    return tuple(all_series)

In [None]:
# TRAIN WITH OPTIMIZED HYPERPARAMETERS
print("Training final model with optimized hyperparameters...")

# Load the data into DataLoader (only the training loader shuffles)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

# Feature widths, read from the first training graph
node_dim = train_data[0].x.size(1)
edge_dim = train_data[0].edge_attr.size(1)

model = ModelArchitecture.VLEAmineCO2(node_dim=node_dim,
                    edge_dim=edge_dim, 
                    hidden_dim=hidden_dim,
                    graph_layers=graph_layers,
                    fc_layers=fc_layers,
                    use_adaptive_pooling=True
                    ).to(device)

criterion = ModelArchitecture.MSLELoss()
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

# Start timing
real_time_start = time.time()  # Real time (wall-clock time)
cpu_time_start = time.process_time()  # CPU time

# Train and receive all per-epoch loss components
(train_losses, val_losses, test_losses, test_r2_scores, 
 main_losses, thermo1_losses, thermo2_losses, thermo3_losses) = train_model(
    model, train_loader, val_loader, test_loader, 
    criterion, optimizer, 
    pco2_std, temp_std, conc_std, 
    pco2_mean, temp_mean, conc_mean,
    s1, s2, s3, 
    device, num_epochs=num_epochs,
    seed=seed, parent_directory=parent_directory,
    use_physics_in_loss=use_physics_in_loss, monitor_physics=monitor_physics)

# End timing right after training, so the reported training time excludes
# saving the weights and writing the CSV logs below.
real_time_end = time.time()
cpu_time_end = time.process_time()

# Calculate elapsed time
real_time_elapsed = real_time_end - real_time_start
cpu_time_elapsed = cpu_time_end - cpu_time_start

# Save trained weights
model_path = f"{parent_directory}/models/models_root/{runtyp}/model_weights/MODEL_{run_time}.pth"
os.makedirs(os.path.dirname(model_path), exist_ok=True)
torch.save(model.state_dict(), model_path)

# Per-epoch loss log
csv_path = f"{parent_directory}/models/models_root/{runtyp}/losses/losses_{run_time}.csv"
os.makedirs(os.path.dirname(csv_path), exist_ok=True)

with open(csv_path, mode='w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow([
        "Epoch", "Train Loss", "Validation Loss", "Test Loss", "Test R2",
        "Main Loss", "gradP Loss (s1)", "gradT Loss (s2)", "gradC Loss (s3)"
    ])
    
    for epoch in range(len(train_losses)):
        writer.writerow([
            epoch + 1,
            train_losses[epoch],
            val_losses[epoch],
            test_losses[epoch],
            test_r2_scores[epoch],
            main_losses[epoch],
            thermo1_losses[epoch],
            thermo2_losses[epoch],
            thermo3_losses[epoch]
        ])

# Separate physics-loss log (components plus their per-epoch sum)
physics_csv_path = f"{parent_directory}/models/models_root/{runtyp}/physics_loss/physics_losses_{run_time}.csv"
os.makedirs(os.path.dirname(physics_csv_path), exist_ok=True)
with open(physics_csv_path, mode='w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow(["Epoch", "Main Loss", "gradP Loss (s1)", "gradT Loss (s2)", "gradC Loss (s3)", "Physics Loss Sum"])
    
    for epoch in range(len(main_losses)):
        physics_sum = thermo1_losses[epoch] + thermo2_losses[epoch] + thermo3_losses[epoch]
        writer.writerow([
            epoch + 1,
            main_losses[epoch],
            thermo1_losses[epoch],
            thermo2_losses[epoch],
            thermo3_losses[epoch],
            physics_sum
        ])

# Output the training time
print(f"Training Real Time (Wall-Clock Time): {real_time_elapsed:.2f} seconds")
print(f"Training CPU Time: {cpu_time_elapsed:.2f} seconds")

# Print summary statistics of loss components
print("\n=== Training Summary ===")
print(f"Final Total Train Loss: {train_losses[-1]:.6f}")
print(f"Final Validation Loss: {val_losses[-1]:.6f}")
print(f"Final Test Loss: {test_losses[-1]:.6f}")
print(f"Final Test R2: {test_r2_scores[-1]:.6f}")
print(f"Final Main Loss: {main_losses[-1]:.6f}")
print(f"Final gradP (s1): {thermo1_losses[-1]:.6f}")
print(f"Final gradT (s2): {thermo2_losses[-1]:.6f}")
print(f"Final gradC (s3): {thermo3_losses[-1]:.6f}")
print(f"\nLoss data saved to: {csv_path}")
print(f"Physics loss details saved to: {physics_csv_path}")

In [None]:
# PARITY PLOT GENERATION
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np
from sklearn.metrics import r2_score, mean_squared_error
import torch

def collect_predictions_and_true_values(model, data_loader, device):
    """Run ``model`` over every batch of ``data_loader`` and gather outputs and targets.

    The model is switched to evaluation mode and run without gradient tracking.
    Each batch is moved to ``device``; its model output and its ``aco2`` attribute
    are copied to the CPU as NumPy arrays and their elements are appended to flat
    Python lists.

    Returns:
        (predictions, true_values): two lists, in batch order.
    """
    preds_out, targets_out = [], []

    model.eval()
    with torch.no_grad():
        for batch in data_loader:
            batch = batch.to(device)
            preds_out.extend(model(batch).cpu().numpy())
            targets_out.extend(batch.aco2.cpu().numpy())

    return preds_out, targets_out

def calculate_metrics(true_values, predictions):
    """Return ``(r2, rmse)`` of ``predictions`` against ``true_values``.

    R² comes from ``sklearn.metrics.r2_score``; RMSE is the square root of
    ``sklearn.metrics.mean_squared_error``.
    """
    r2 = r2_score(true_values, predictions)
    mse = mean_squared_error(true_values, predictions)
    return r2, np.sqrt(mse)

def save_metrics_to_csv(r2_train, rmse_train, r2_val, rmse_val, r2_test, rmse_test,
                        parent_directory, filename='parity_metrics.csv'):
    """Write R² and RMSE for the train/validation/test splits to a CSV file.

    The previous version only built an in-memory dict and never wrote it, so
    calling it had no effect. This writes a header row ``Dataset,R2,RMSE``
    followed by one row per split.

    Args:
        r2_train, rmse_train: metrics of the training split.
        r2_val, rmse_val: metrics of the validation split.
        r2_test, rmse_test: metrics of the test split.
        parent_directory: directory the file is written under.
        filename: file name (may include sub-directories) relative to
            ``parent_directory``. Missing directories are created.

    Returns:
        The path of the written CSV file.
    """
    metrics_rows = [
        ('Training', r2_train, rmse_train),
        ('Validation', r2_val, rmse_val),
        ('Test', r2_test, rmse_test),
    ]

    out_path = os.path.join(parent_directory, filename)
    out_dir = os.path.dirname(out_path)
    if out_dir:  # dirname is '' when the path has no directory component
        os.makedirs(out_dir, exist_ok=True)

    with open(out_path, mode='w', newline='') as fh:
        writer = csv.writer(fh)
        writer.writerow(['Dataset', 'R2', 'RMSE'])
        writer.writerows(metrics_rows)

    return out_path

# Function to plot the parity plot with marginal histograms
def plot_parity_plot(train_true_values, train_predictions, 
                     val_true_values, val_predictions, 
                     test_true_values, test_predictions,
                     parent_directory=None, axis_limits=(-0.1, 2.5)):
    """Show a parity plot (predicted vs. actual) with stacked marginal histograms.

    The main panel scatters predictions against true values for the train,
    validation and test splits, with each split's R² and RMSE in the legend.
    The top histogram stacks the true values and the right histogram stacks
    the predictions, coloured per split.

    Args:
        train_true_values, train_predictions: values for the training split.
        val_true_values, val_predictions: values for the validation split.
        test_true_values, test_predictions: values for the test split.
        parent_directory: when truthy, the computed metrics are forwarded to
            ``save_metrics_to_csv`` together with this directory.
        axis_limits: ``(low, high)`` shared by both axes, the parity diagonal
            and the histogram bin range. The default reproduces the previously
            hard-coded window. Points outside it are not visible and are not
            counted by the histograms (out-of-range values fall outside the
            bin edges), so widen it for data with a larger spread.

    Note:
        Sets ``matplotlib.rcParams['font.family']`` globally as a side effect.
    """
    fontsize = 16
    matplotlib.rcParams['font.family'] = 'Times New Roman'
    lo, hi = axis_limits

    # Calculate metrics
    r2_train, rmse_train = calculate_metrics(train_true_values, train_predictions)
    r2_val, rmse_val = calculate_metrics(val_true_values, val_predictions)
    r2_test, rmse_test = calculate_metrics(test_true_values, test_predictions)

    # Create figure with gridspec for histograms
    fig = plt.figure(figsize=(8, 8))
    gs = gridspec.GridSpec(2, 2, width_ratios=[4, 1], height_ratios=[1, 4], 
                          hspace=0.00, wspace=0.00)
    
    # Main plot plus marginal axes sharing its x (top) and y (right) scales
    ax = fig.add_subplot(gs[1, 0])
    ax_histx = fig.add_subplot(gs[0, 0], sharex=ax)
    ax_histy = fig.add_subplot(gs[1, 1], sharey=ax)

    # Scatter plots
    ax.scatter(train_true_values, train_predictions, 
                edgecolors='b', alpha=0.5, c='b', marker='o', 
                label=f'Train   (R² = {r2_train:.3f},  RMSE = {rmse_train:.3f})')
    ax.scatter(val_true_values, val_predictions, 
                edgecolors='g', alpha=0.5, c='g', marker='^', 
                label=f'Val      (R² = {r2_val:.3f},  RMSE = {rmse_val:.3f})')
    ax.scatter(test_true_values, test_predictions, 
                edgecolors='r', alpha=0.5, c='r', marker='v', 
                label=f'Test     (R² = {r2_test:.3f},  RMSE = {rmse_test:.3f})')

    # Parity line spans the whole visible window (previously it ended at
    # max(true)+0.5, which could stop short of the axes)
    ax.plot([lo, hi], [lo, hi], '--', linewidth=1.5, color='black')

    # Labels & ticks
    ax.set_xlabel('Actual Solubility', fontsize=fontsize)
    ax.set_ylabel('Predicted Solubility', fontsize=fontsize)
    ax.set_xlim(lo, hi)
    ax.set_ylim(lo, hi)
    ax.tick_params(axis='both', which='major', length=6, width=0.8, labelsize=fontsize)
    ax.tick_params(axis='both', which='minor', length=4, width=0.8)
    ax.minorticks_on()
    ax.legend(fontsize=fontsize-3, loc='upper left', frameon=False)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)    

    # Histogram bins cover the same window as the axes
    bins = np.linspace(lo, hi, 27)
    
    # Top histogram (true values) - stacked by dataset
    ax_histx.hist([np.array(train_true_values).flatten(), 
                   np.array(val_true_values).flatten(), 
                   np.array(test_true_values).flatten()], 
                  bins=bins, color=['b', 'g', 'r'], 
                  alpha=0.5, stacked=True, edgecolor='black', linewidth=0.5)
    ax_histx.tick_params(labelbottom=False, labelleft=False, left=False)
    ax_histx.spines['top'].set_visible(False)
    ax_histx.spines['right'].set_visible(False)
    ax_histx.spines['left'].set_visible(False)

    # Right histogram (predicted values) - stacked by dataset
    ax_histy.hist([np.array(train_predictions).flatten(), 
                   np.array(val_predictions).flatten(), 
                   np.array(test_predictions).flatten()], 
                  bins=bins, orientation='horizontal', color=['b', 'g', 'r'], 
                  alpha=0.5, stacked=True, edgecolor='black', linewidth=0.5)
    ax_histy.tick_params(labelbottom=False, labelleft=False, bottom=False)
    ax_histy.spines['top'].set_visible(False)
    ax_histy.spines['right'].set_visible(False)
    ax_histy.spines['bottom'].set_visible(False)

    plt.show()

    # Save metrics if a target directory was given
    if parent_directory:
        save_metrics_to_csv(r2_train, rmse_train, r2_val, rmse_val, r2_test, rmse_test, parent_directory)


# Gather (predictions, true values) for each split, in train/val/test order
(train_predictions, train_true_values), \
    (val_predictions, val_true_values), \
    (test_predictions, test_true_values) = (
        collect_predictions_and_true_values(model, split_loader, device)
        for split_loader in (train_loader, val_loader, test_loader)
    )

# Draw the parity plot for all three splits
plot_parity_plot(train_true_values=train_true_values, train_predictions=train_predictions,
                 val_true_values=val_true_values, val_predictions=val_predictions,
                 test_true_values=test_true_values, test_predictions=test_predictions)