# INITIALIZATION

In [None]:
# IMPORTING LIBRARIES
import copy
import os
import random
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'  # or ':16:8'
import numpy as np
import pandas as pd
import torch
from torch_geometric.loader import DataLoader
# FUNCTIONS
from data_processing import load_dataset, process_dataset
from path_helpers import get_path
from stats_compute import compute_statistics, scale_graphs
from smart_loader import load_model_for_inference
from EnhancedDataSplit import DataSplitter
from typing import List
# DIRECTORY SETUP
# Working directory of the notebook and the directory that contains it.
current_directory = os.getcwd()
parent_directory = os.path.split(current_directory)[0]

In [None]:
# HYPERPARAMETER SETTINGS
from datetime import datetime

# Timestamp taken when this cell runs; `runtime` is an alias of it.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
runtime = timestamp

# Reproducibility settings
seed = 21        # seeds the Python, NumPy and PyTorch generators below
split_seed = 42  # printed below as the split seed
batch_size = 32

# Requested device: either 'cuda' or 'cpu'.
# Fall back to the CPU when CUDA is requested but not available, instead of
# failing later when a model or a batch is moved to the device.
selected_device = 'cuda'
if selected_device == 'cuda' and not torch.cuda.is_available():
    print('CUDA is not available; falling back to CPU.')
    selected_device = 'cpu'
device = torch.device(selected_device)

# Seed every random number generator used by the notebook.
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# CUDA Deterministic (ON/OFF SETTING)
# One switch drives the three related flags so they cannot get out of sync.
# False keeps the non-deterministic configuration (cuDNN benchmark enabled).
deterministic = False
torch.use_deterministic_algorithms(deterministic)
torch.backends.cudnn.deterministic = deterministic
torch.backends.cudnn.benchmark = not deterministic

print('device           :', device)
print('seed             :', seed)
print('split seed       :', split_seed)

In [None]:
# LOAD & GRAPH GENERATION FOR RAS
# Component table: abbreviation -> SMILES lookup used to build the graphs.
df_components = load_dataset(get_path(file_name='components_set.csv', folder_name='datasets'))
smiles_dict = dict(zip(df_components['Abbreviation'], df_components['SMILES']))
df_systems = load_dataset(get_path(file_name='systems_set.csv', folder_name='datasets'))
smiles_list = df_components["SMILES"].dropna().tolist()
mol_name_dict = smiles_dict.copy()

# GRAPH
system_graphs = process_dataset(df_systems, smiles_dict)

# LOAD DATASET
splitter = DataSplitter(system_graphs, random_state=split_seed)
splitter.print_dataset_stats()
# Options: rarity_aware_unseen_amine_split stratified_random_split
train_data, val_data, test_data = splitter.rarity_aware_unseen_amine_split()

# Scaling statistics are computed from the training split only.
# Index order: conc mean/std, temp mean/std, pco2 mean/std.
stats = compute_statistics(train_data)
(conc_mean, conc_std,
 temp_mean, temp_std,
 pco2_mean, pco2_std) = (stats[i] for i in range(6))

# Deep copies taken before scaling so the unscaled splits remain available.
original_train_data, original_val_data, original_test_data = (
    copy.deepcopy(split) for split in (train_data, val_data, test_data)
)
combined_original_data = original_train_data + original_val_data + original_test_data

# Scale all three splits with the training statistics.
scaling_stats = (conc_mean, conc_std, temp_mean, temp_std, pco2_mean, pco2_std)
train_data = scale_graphs(train_data, *scaling_stats)
val_data = scale_graphs(val_data, *scaling_stats)
test_data = scale_graphs(test_data, *scaling_stats)

# Load the data into DataLoader (only the training loader shuffles).
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_loader, test_loader = (
    DataLoader(split, batch_size=batch_size, shuffle=False)
    for split in (val_data, test_data)
)

In [None]:
# Computing the variance of alpha_co2 on the datasets
def get_alpha_stats(dataset):
    """Summarize the ``aco2`` values of a dataset.

    Parameters
    ----------
    dataset : iterable
        Items exposing a scalar ``aco2`` attribute.

    Returns
    -------
    dict
        Python numbers under the keys 'sample_var' (N-1 denominator),
        'population_var' (N denominator), 'min', 'max', 'range', 'mean',
        'std' (N-1 denominator), 'median' and 'count'.
        With a single item the N-1 based entries are NaN.

    Raises
    ------
    ValueError
        If ``dataset`` contains no items.
    """
    per_item = [torch.tensor([data.aco2]) for data in dataset]
    if not per_item:
        raise ValueError("get_alpha_stats() needs at least one data point")
    values = torch.cat(per_item)

    # Variance calculations
    sample_var = torch.var(values, unbiased=True)      # N-1
    population_var = torch.var(values, unbiased=False) # N

    # Range and statistics
    min_val = torch.min(values)
    max_val = torch.max(values)
    mean_val = torch.mean(values)
    std_val = torch.std(values, unbiased=True)
    # torch.median returns the lower of the two middle values for an even
    # number of elements; the 0.5 quantile interpolates between them.
    median_val = torch.quantile(values, 0.5)

    return {
        'sample_var': sample_var.item(),
        'population_var': population_var.item(),
        'min': min_val.item(),
        'max': max_val.item(),
        'range': (max_val - min_val).item(),
        'mean': mean_val.item(),
        'std': std_val.item(),
        'median': median_val.item(),
        'count': len(values)
    }

# Calculate statistics for all datasets
train_stats = get_alpha_stats(original_train_data)
val_stats = get_alpha_stats(original_val_data)
test_stats = get_alpha_stats(original_test_data)


def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or NaN when the denominator is zero."""
    return numerator / denominator if denominator != 0 else float('nan')


# Print comprehensive statistics
print("=" * 80)
print("α_CO2 DATASET STATISTICS")
print("=" * 80)

datasets = [('Train', train_stats), ('Val', val_stats), ('Test', test_stats)]

# The loop variable is deliberately not called `stats`: that name already
# holds the scaling statistics returned by compute_statistics().
for name, split_stats in datasets:
    print(f"\n{name:>5} Dataset (n={split_stats['count']}):")
    print(f"  Range:     [{split_stats['min']:.4f}, {split_stats['max']:.4f}]  (span: {split_stats['range']:.4f})")
    print(f"  Mean:      {split_stats['mean']:.4f}  ±  {split_stats['std']:.4f}")
    print(f"  Median:    {split_stats['median']:.4f}")
    print(f"  Variance:  {split_stats['sample_var']:.6f} (sample), {split_stats['population_var']:.6f} (population)")

print("\n" + "=" * 80)
print("RELATIVE COMPARISONS (vs Train)")
print("=" * 80)

for name, split_stats in datasets[1:]:  # Skip train for comparison
    print(f"\n{name} vs Train:")
    print(f"  Range ratio:    {_safe_ratio(split_stats['range'], train_stats['range']):.3f}")
    print(f"  Mean ratio:     {_safe_ratio(split_stats['mean'], train_stats['mean']):.3f}")
    print(f"  Variance ratio: {_safe_ratio(split_stats['sample_var'], train_stats['sample_var']):.3f}")
    # "Overlap" here means the split's extreme lies inside the training range.
    print(f"  Min overlap:    {'Yes' if split_stats['min'] >= train_stats['min'] else 'No'} ({split_stats['min']:.4f} vs {train_stats['min']:.4f})")
    print(f"  Max overlap:    {'Yes' if split_stats['max'] <= train_stats['max'] else 'No'} ({split_stats['max']:.4f} vs {train_stats['max']:.4f})")

In [None]:
# Importing model weights
def _pick_model_file(model_dir):
    """Return the path of the first entry (alphabetical order) in model_dir.

    Hidden entries (names starting with '.', e.g. '.ipynb_checkpoints' or
    '.DS_Store') are ignored: they sort before regular file names and would
    otherwise be selected as the model file.

    Raises
    ------
    FileNotFoundError
        If no candidate entry is left in ``model_dir``.
    """
    candidates = sorted(
        entry for entry in os.listdir(model_dir) if not entry.startswith('.')
    )
    if not candidates:
        raise FileNotFoundError(f"No files found in {model_dir}")
    model_file = candidates[0]
    print(f"Using model file: {model_file}")
    return os.path.join(model_dir, model_file)


# Folder holding one sub-folder per model variant.
inference_root = os.path.join(
    os.path.dirname(os.getcwd()),
    "models",
    "models_root",
    "model_for_inference",
)

# Baseline
baseline_dir = os.path.join(inference_root, "ras_baseline")
path = _pick_model_file(baseline_dir)
model_1 = load_model_for_inference(path, device=device)

# PIGNN
pinn_dir = os.path.join(inference_root, "ras_pinn")
path = _pick_model_file(pinn_dir)
model_2 = load_model_for_inference(path, device=device)

# Load both models and prepare for comparison
# Each entry: model name -> (model, function applying scale_graphs with the
# training-set statistics to a graph).
models = {
    "Baseline": (model_1, lambda g: scale_graphs(g, conc_mean, conc_std, temp_mean, temp_std, pco2_mean, pco2_std)),
    "PINN": (model_2, lambda g: scale_graphs(g, conc_mean, conc_std, temp_mean, temp_std, pco2_mean, pco2_std)),
}

In [None]:
""" # PARITY PLOT GENERATION
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np
from sklearn.metrics import r2_score, mean_squared_error
import torch
import matplotlib

def collect_predictions_and_true_values(model, data_loader, device):
    predictions = []
    true_values = []
    
    model.eval()
    with torch.no_grad():
        for data in data_loader:
            data = data.to(device)
            output = model(data)
            predictions.extend(output.cpu().numpy())
            true_values.extend(data.aco2.cpu().numpy())
    
    return np.array(predictions), np.array(true_values)

# Function to calculate R² and RMSE
def calculate_metrics(true_values, predictions):
    r2 = r2_score(true_values, predictions)
    rmse = np.sqrt(mean_squared_error(true_values, predictions))
    return r2, rmse

# Function to save metrics to CSV
def save_metrics_to_csv(metrics_dict, parent_directory):
    import pandas as pd
    # Create the metrics dictionary
    metrics_data = {
        'Dataset': [],
        'R2': [],
        'RMSE': []
    }
    
    for dataset_name, metrics in metrics_dict.items():
        metrics_data['Dataset'].append(dataset_name)
        metrics_data['R2'].append(metrics['r2'])
        metrics_data['RMSE'].append(metrics['rmse'])
    
    df = pd.DataFrame(metrics_data)
    df.to_csv(f"{parent_directory}/metrics.csv", index=False)

# Function to plot the parity plot with marginal histograms
def plot_parity_plot(datasets_dict, parent_directory=None, fontsize=16):
    '''
    Plot parity plot for any combination of datasets.
    
    Parameters:
    -----------
    datasets_dict : dict
        Dictionary containing datasets to plot. Format:
        {
            'train': {'true': train_true_values, 'pred': train_predictions, 'color': 'b', 'marker': 'o'},
            'val': {'true': val_true_values, 'pred': val_predictions, 'color': 'g', 'marker': '^'},
            'test': {'true': test_true_values, 'pred': test_predictions, 'color': 'r', 'marker': 'v'}
        }
        You can include any subset of these keys.
    parent_directory : str, optional
        Directory to save metrics CSV
    fontsize : int, optional
        Font size for plot elements
    '''
    matplotlib.rcParams['font.family'] = 'Times New Roman'

    # Calculate metrics for each dataset
    metrics_dict = {}
    for dataset_name, data in datasets_dict.items():
        r2, rmse = calculate_metrics(data['true'], data['pred'])
        metrics_dict[dataset_name] = {'r2': r2, 'rmse': rmse}

    # Create figure with gridspec for histograms
    fig = plt.figure(figsize=(8, 8))
    gs = gridspec.GridSpec(2, 2, width_ratios=[4, 1], height_ratios=[1, 4], 
                          hspace=0.00, wspace=0.00)
    
    # Main plot
    ax = fig.add_subplot(gs[1, 0])
    ax_histx = fig.add_subplot(gs[0, 0], sharex=ax)
    ax_histy = fig.add_subplot(gs[1, 1], sharey=ax)

    # Scatter plots for each dataset
    for dataset_name, data in datasets_dict.items():
        label = f"{dataset_name.capitalize()} (R² = {metrics_dict[dataset_name]['r2']:.4f}, RMSE = {metrics_dict[dataset_name]['rmse']:.4f})"
        ax.scatter(data['true'], data['pred'], 
                   edgecolors=data['color'], alpha=0.5, c=data['color'], 
                   marker=data['marker'], label=label)

    # Parity line
    all_true_values = np.concatenate([data['true'] for data in datasets_dict.values()])
    max_val = max(all_true_values) if len(all_true_values) > 0 else 2.5
    ax.plot([-0.1, 10+0.5], [-0.1, 10+0.5], '--', linewidth=1.5, color='black')

    # Labels & ticks
    ax.set_xlabel('Actual Solubility', fontsize=fontsize)
    ax.set_ylabel('Predicted Solubility', fontsize=fontsize)
    ax.set_xlim(-0.1, 2.5)
    ax.set_ylim(-0.1, 2.5)
    ax.tick_params(axis='both', which='major', length=6, width=0.8, labelsize=fontsize)
    ax.tick_params(axis='both', which='minor', length=4, width=0.8)
    ax.minorticks_on()
    ax.legend(fontsize=fontsize-3, loc='upper left', frameon=False)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)    

    # Add histograms with dataset differentiation
    bins = np.linspace(-0.1, 2.5, 27)
    
    # Top histogram (experimental values)
    if datasets_dict:
        # Convert to numpy arrays and flatten
        true_data = [data['true'].flatten() for data in datasets_dict.values()]
        colors = [data['color'] for data in datasets_dict.values()]
        
        ax_histx.hist(true_data, bins=bins, color=colors, 
                      alpha=0.5, stacked=True, edgecolor='black', linewidth=0.5)
        ax_histx.tick_params(labelbottom=False, labelleft=False, left=False)
        ax_histx.spines['top'].set_visible(False)
        ax_histx.spines['right'].set_visible(False)
        ax_histx.spines['left'].set_visible(False)

    # Right histogram (predicted values)
    if datasets_dict:
        # Convert to numpy arrays and flatten
        pred_data = [data['pred'].flatten() for data in datasets_dict.values()]
        colors = [data['color'] for data in datasets_dict.values()]
        
        ax_histy.hist(pred_data, bins=bins, orientation='horizontal', 
                      color=colors, alpha=0.5, stacked=True, 
                      edgecolor='black', linewidth=0.5)
        ax_histy.tick_params(labelbottom=False, labelleft=False, bottom=False)
        ax_histy.spines['top'].set_visible(False)
        ax_histy.spines['right'].set_visible(False)
        ax_histy.spines['bottom'].set_visible(False)

    plt.show()

    # Save metrics if needed
    if parent_directory:
        save_metrics_to_csv(metrics_dict, parent_directory)


# Collect predictions and true values for training, validation, and test data
train_predictions, train_true_values = collect_predictions_and_true_values(model_2, train_loader, device)
val_predictions, val_true_values = collect_predictions_and_true_values(model_2, val_loader, device)
test_predictions, test_true_values = collect_predictions_and_true_values(model_2, test_loader, device)

# EXAMPLE USAGE - Choose any combination:

# Option 1: All three datasets
plot_parity_plot({
    'train': {'true': train_true_values, 'pred': train_predictions, 'color': 'b', 'marker': 'o'},
    'val': {'true': val_true_values, 'pred': val_predictions, 'color': 'g', 'marker': '^'},
    'test': {'true': test_true_values, 'pred': test_predictions, 'color': 'r', 'marker': 'v'}
})

# Option 2: Only train and validation
plot_parity_plot({
    'train': {'true': train_true_values, 'pred': train_predictions, 'color': 'b', 'marker': 'o'},
    'val': {'true': val_true_values, 'pred': val_predictions, 'color': 'g', 'marker': '^'}
})

# Option 3: Only train and test
plot_parity_plot({
    'train': {'true': train_true_values, 'pred': train_predictions, 'color': 'b', 'marker': 'o'},
    'test': {'true': test_true_values, 'pred': test_predictions, 'color': 'r', 'marker': 'v'}
})

# Option 4: Only validation and test
plot_parity_plot({
    'val': {'true': val_true_values, 'pred': val_predictions, 'color': 'g', 'marker': '^'},
    'test': {'true': test_true_values, 'pred': test_predictions, 'color': 'r', 'marker': 'v'}
})

# Option 5: Only train
plot_parity_plot({
    'train': {'true': train_true_values, 'pred': train_predictions, 'color': 'b', 'marker': 'o'}
})

# Option 6: Only validation
plot_parity_plot({
    'val': {'true': val_true_values, 'pred': val_predictions, 'color': 'g', 'marker': '^'}
})

# Option 7: Only test
plot_parity_plot({
    'test': {'true': test_true_values, 'pred': test_predictions, 'color': 'r', 'marker': 'v'}
}) """

In [None]:
""" # PARITY PLOT GENERATION (isolated)
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np
from sklearn.metrics import r2_score, mean_squared_error
import torch
import matplotlib
import pandas as pd

def collect_predictions_and_true_values_with_amines(model, data_loader, device):
    predictions = []
    true_values = []
    amine_names = []
    
    model.eval()
    with torch.no_grad():
        for data in data_loader:
            data = data.to(device)
            output = model(data)
            predictions.extend(output.cpu().numpy())
            true_values.extend(data.aco2.cpu().numpy())
            
            # Extract amine names - handle different possible formats
            if hasattr(data, 'name'):
                # If name is a tensor, convert to numpy and then to list
                if torch.is_tensor(data.name):
                    amine_names.extend(data.name.cpu().numpy().tolist())
                # If name is already a list or array
                elif isinstance(data.name, (list, np.ndarray)):
                    amine_names.extend(data.name)
                else:
                    # Fallback if format is unexpected
                    amine_names.extend(['unknown'] * len(data))
            else:
                amine_names.extend(['unknown'] * len(data))
    
    return np.array(predictions), np.array(true_values), amine_names

# Function to calculate R² and RMSE
def calculate_metrics(true_values, predictions):
    r2 = r2_score(true_values, predictions)
    rmse = np.sqrt(mean_squared_error(true_values, predictions))
    return r2, rmse

# Function to plot the parity plot with amine-based coloring
def plot_test_parity_plot_by_amine(test_true_values, test_predictions, amine_names, fontsize=16):
    '''
    Plot parity plot for test set with points colored by amine names.
    
    Parameters:
    -----------
    test_true_values : array-like
        True values for test set
    test_predictions : array-like
        Predicted values for test set
    amine_names : list
        List of amine names corresponding to each data point
    fontsize : int, optional
        Font size for plot elements
    '''
    matplotlib.rcParams['font.family'] = 'Times New Roman'

    # Calculate metrics
    r2_test, rmse_test = calculate_metrics(test_true_values, test_predictions)

    # Create figure with gridspec for histograms
    fig = plt.figure(figsize=(8, 8))
    gs = gridspec.GridSpec(2, 2, width_ratios=[4, 1], height_ratios=[1, 4], 
                          hspace=0.00, wspace=0.00)
    
    # Main plot
    ax = fig.add_subplot(gs[1, 0])
    ax_histx = fig.add_subplot(gs[0, 0], sharex=ax)
    ax_histy = fig.add_subplot(gs[1, 1], sharey=ax)

    # Get unique amine names and assign colors
    unique_amines = sorted(set(amine_names))
    colors = plt.cm.Set1(np.linspace(0, 1, len(unique_amines)))
    
    # Create color mapping
    amine_color_map = {amine: colors[i] for i, amine in enumerate(unique_amines)}
    
    # Convert amine_names to numpy array for boolean indexing
    amine_names_array = np.array(amine_names)
    
    # Scatter plot colored by amine
    for amine in unique_amines:
        mask = amine_names_array == amine
        if np.any(mask):  # Only plot if there are points for this amine
            ax.scatter(test_true_values[mask], test_predictions[mask], 
                       color=amine_color_map[amine], alpha=0.7, s=50,
                       label=amine, edgecolors='black', linewidth=0.5)

    # Parity line
    ax.plot([-0.1, 10], [-0.1, 10], '--', linewidth=1.5, color='black')

    # Labels & ticks
    ax.set_xlabel('Actual Solubility', fontsize=fontsize)
    ax.set_ylabel('Predicted Solubility', fontsize=fontsize)
    ax.set_xlim(-0.1, 2.5)
    ax.set_ylim(-0.1, 2.5)
    ax.tick_params(axis='both', which='major', length=6, width=0.8, labelsize=fontsize)
    ax.tick_params(axis='both', which='minor', length=4, width=0.8)
    ax.minorticks_on()
    ax.legend(fontsize=fontsize-4, loc='upper left', frameon=True, ncol=2)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    
    # Add title with metrics
    ax.set_title(f'Test Set: R² = {r2_test:.4f}, RMSE = {rmse_test:.4f}', fontsize=fontsize)

    # Add histograms with amine-based coloring
    bins = np.linspace(-0.1, 2.5, 27)
    
    # Top histogram (experimental values) - colored by amine
    if len(test_true_values) > 0:
        # Filter out empty arrays
        true_data = [test_true_values[amine_names_array == amine] for amine in unique_amines]
        true_data = [arr for arr in true_data if len(arr) > 0]
        true_colors = [amine_color_map[amine] for amine in unique_amines if len(test_true_values[amine_names_array == amine]) > 0]
        
        if true_data:  # Only plot if there's data
            ax_histx.hist(true_data, bins=bins, color=true_colors, 
                          alpha=0.7, stacked=True, edgecolor='black', linewidth=0.5)
        ax_histx.tick_params(labelbottom=False, labelleft=False, left=False)
        ax_histx.spines['top'].set_visible(False)
        ax_histx.spines['right'].set_visible(False)
        ax_histx.spines['left'].set_visible(False)

    # Right histogram (predicted values) - colored by amine
    if len(test_predictions) > 0:
        # Filter out empty arrays and ensure they are 1D
        pred_data = [test_predictions[amine_names_array == amine] for amine in unique_amines]
        pred_data = [arr.flatten() for arr in pred_data if len(arr) > 0]  # Flatten to ensure 1D
        pred_colors = [amine_color_map[amine] for amine in unique_amines if len(test_predictions[amine_names_array == amine]) > 0]
        
        if pred_data:  # Only plot if there's data
            ax_histy.hist(pred_data, bins=bins, orientation='horizontal', 
                          color=pred_colors, alpha=0.7, stacked=True, 
                          edgecolor='black', linewidth=0.5)
        ax_histy.tick_params(labelbottom=False, labelleft=False, bottom=False)
        ax_histy.spines['top'].set_visible(False)
        ax_histy.spines['right'].set_visible(False)
        ax_histy.spines['bottom'].set_visible(False)

    plt.show()

    # Print metrics
    print(f"Test Set Metrics:")
    print(f"R²: {r2_test:.4f}")
    print(f"RMSE: {rmse_test:.4f}")
    print(f"Number of amines: {len(unique_amines)}")
    print(f"Amines: {', '.join(unique_amines)}")
    
    # Print count per amine
    for amine in unique_amines:
        count = np.sum(amine_names_array == amine)
        print(f"{amine}: {count} points")

# Collect predictions, true values, and amine names for test data
test_predictions, test_true_values, test_amine_names = collect_predictions_and_true_values_with_amines(model_2, test_loader, device)

# Debug: Check what we got
print(f"Number of predictions: {len(test_predictions)}")
print(f"Number of true values: {len(test_true_values)}")
print(f"Number of amine names: {len(test_amine_names)}")
print(f"First few amine names: {test_amine_names[:10]}")
print(f"Unique amine names: {sorted(set(test_amine_names))}")

# Check if arrays are 1D
print(f"Predictions shape: {test_predictions.shape}")
print(f"True values shape: {test_true_values.shape}")

# If predictions/true values are 2D, flatten them
if len(test_predictions.shape) > 1:
    test_predictions = test_predictions.flatten()
    print(f"Flattened predictions shape: {test_predictions.shape}")

if len(test_true_values.shape) > 1:
    test_true_values = test_true_values.flatten()
    print(f"Flattened true values shape: {test_true_values.shape}")

# Plot the test set parity plot colored by amine
plot_test_parity_plot_by_amine(test_true_values, test_predictions, test_amine_names) """

In [None]:
# AAD CALCULATION
# Average absolute deviation between prediction and target, per amine and model.
from collections import Counter

# Each entry: model name -> (model, function applying scale_graphs with the
# training-set statistics to a graph).
models = {
    "Baseline": (model_1, lambda g: scale_graphs(g, conc_mean, conc_std, temp_mean, temp_std, pco2_mean, pco2_std)),
    "PINN": (model_2, lambda g: scale_graphs(g, conc_mean, conc_std, temp_mean, temp_std, pco2_mean, pco2_std)),
}

# Min points per temperature (filtering)
min_points_per_temp = 2

used_data_for_inference = original_test_data
# Extract unique values from the data
# Sorted so the report below is printed in the same order on every run
# (assumes the names are mutually comparable, e.g. strings — TODO confirm).
unique_systems = sorted(set(graph.name for graph in used_data_for_inference))
unique_concentrations = list(set(graph.conc for graph in used_data_for_inference))
unique_temperatures = list(set(graph.temp for graph in used_data_for_inference))
unique_references = list(set(graph.ref for graph in used_data_for_inference))
# Store AAD results
aad_results = {model_name: {} for model_name in models}

for amine in unique_systems:
    # Collect all graphs for this amine across all refs and concentrations
    graphs_amine = [g for g in used_data_for_inference if g.name == amine]

    # Temperature filtering: keep temperatures with enough data points
    temp_counts = Counter(g.temp for g in graphs_amine)
    filtered_temps = sorted(t for t, count in temp_counts.items() if count >= min_points_per_temp)

    if not filtered_temps:
        continue  # skip amines without enough points

    # Keep only graphs in filtered temperatures
    graphs_filtered = [g for g in graphs_amine if g.temp in filtered_temps]

    # For each model
    for model_name, (model, scaler) in models.items():
        # Inference mode, as done by the parity-plot helpers in this notebook
        model.eval()
        abs_errors = []
        for g in graphs_filtered:
            # Work on a clone so the unscaled graph is left untouched
            g_pred = g.clone()

            # Ensure all scalar features are tensors
            g_pred.temp = torch.tensor([g_pred.temp], dtype=torch.float)
            g_pred.conc = torch.tensor([g_pred.conc], dtype=torch.float)
            g_pred.pco2 = torch.tensor([g_pred.pco2], dtype=torch.float)

            g_pred = scaler(g_pred).to(device)
            with torch.no_grad():
                pred = model(g_pred).cpu().numpy().flatten()

            abs_errors.append(np.abs(pred - g_pred.aco2))  # g_pred.aco2 can stay float

        # Average absolute deviation
        aad_results[model_name][amine] = np.mean(abs_errors)

# Print AAD per amine per model
for model_name, amines in aad_results.items():
    print(f"\n=== {model_name} AAD per amine ===")
    for amine, aad in amines.items():
        print(f"{amine}: {aad:.4f}")

In [None]:
# RMSE CALCULATION
# Root mean square error between prediction and target, per amine and model.
from collections import Counter

# Each entry: model name -> (model, function applying scale_graphs with the
# training-set statistics to a graph).
models = {
    "Baseline": (model_1, lambda g: scale_graphs(g, conc_mean, conc_std, temp_mean, temp_std, pco2_mean, pco2_std)),
    "PINN": (model_2, lambda g: scale_graphs(g, conc_mean, conc_std, temp_mean, temp_std, pco2_mean, pco2_std)),
}

# Min points per temperature (filtering)
min_points_per_temp = 2
used_data_for_inference = original_test_data
# Extract unique values from the data
# Sorted so the report below is printed in the same order on every run
# (assumes the names are mutually comparable, e.g. strings — TODO confirm).
unique_systems = sorted(set(graph.name for graph in used_data_for_inference))
unique_concentrations = list(set(graph.conc for graph in used_data_for_inference))
unique_temperatures = list(set(graph.temp for graph in used_data_for_inference))
unique_references = list(set(graph.ref for graph in used_data_for_inference))
# Store RMSE results
rmse_results = {model_name: {} for model_name in models}

for amine in unique_systems:
    # Collect all graphs for this amine across all refs and concentrations
    graphs_amine = [g for g in used_data_for_inference if g.name == amine]

    # Temperature filtering: keep temperatures with enough data points
    temp_counts = Counter(g.temp for g in graphs_amine)
    filtered_temps = sorted(t for t, count in temp_counts.items() if count >= min_points_per_temp)

    if not filtered_temps:
        continue  # skip amines without enough points

    # Keep only graphs in filtered temperatures
    graphs_filtered = [g for g in graphs_amine if g.temp in filtered_temps]

    # For each model
    for model_name, (model, scaler) in models.items():
        # Inference mode, as done by the parity-plot helpers in this notebook
        model.eval()
        squared_errors = []
        for g in graphs_filtered:
            # Work on a clone so the unscaled graph is left untouched
            g_pred = g.clone()

            # Ensure all scalar features are tensors
            g_pred.temp = torch.tensor([g_pred.temp], dtype=torch.float)
            g_pred.conc = torch.tensor([g_pred.conc], dtype=torch.float)
            g_pred.pco2 = torch.tensor([g_pred.pco2], dtype=torch.float)

            g_pred = scaler(g_pred).to(device)
            with torch.no_grad():
                pred = model(g_pred).cpu().numpy().flatten()

            # Calculate squared error
            squared_errors.append((pred - g_pred.aco2) ** 2)

        # Root Mean Square Error
        rmse_results[model_name][amine] = np.sqrt(np.mean(squared_errors))

# Print RMSE per amine per model
for model_name, amines in rmse_results.items():
    print(f"\n=== {model_name} RMSE per amine ===")
    for amine, rmse in amines.items():
        print(f"{amine}: {rmse:.4f}")

In [None]:
# MAPE CALCULATION
# Mean absolute percentage error between prediction and target, per amine and model.
from collections import Counter

# Each entry: model name -> (model, function applying scale_graphs with the
# training-set statistics to a graph).
models = {
    "Baseline": (model_1, lambda g: scale_graphs(g, conc_mean, conc_std, temp_mean, temp_std, pco2_mean, pco2_std)),
    "PINN": (model_2, lambda g: scale_graphs(g, conc_mean, conc_std, temp_mean, temp_std, pco2_mean, pco2_std)),
}

# Min points per temperature (filtering)
min_points_per_temp = 2
used_data_for_inference = original_test_data
# Extract unique values from the data
# Sorted so the report below is printed in the same order on every run
# (assumes the names are mutually comparable, e.g. strings — TODO confirm).
unique_systems = sorted(set(graph.name for graph in used_data_for_inference))
unique_concentrations = list(set(graph.conc for graph in used_data_for_inference))
unique_temperatures = list(set(graph.temp for graph in used_data_for_inference))
unique_references = list(set(graph.ref for graph in used_data_for_inference))
# Store MAPE results
mape_results = {model_name: {} for model_name in models}

for amine in unique_systems:
    # Collect all graphs for this amine across all refs and concentrations
    graphs_amine = [g for g in used_data_for_inference if g.name == amine]

    # Temperature filtering: keep temperatures with enough data points
    temp_counts = Counter(g.temp for g in graphs_amine)
    filtered_temps = sorted(t for t, count in temp_counts.items() if count >= min_points_per_temp)

    if not filtered_temps:
        continue  # skip amines without enough points

    # Keep only graphs in filtered temperatures
    graphs_filtered = [g for g in graphs_amine if g.temp in filtered_temps]

    # For each model
    for model_name, (model, scaler) in models.items():
        # Inference mode, as done by the parity-plot helpers in this notebook
        model.eval()
        percentage_errors = []
        for g in graphs_filtered:
            # Work on a clone so the unscaled graph is left untouched
            g_pred = g.clone()

            # Ensure all scalar features are tensors
            g_pred.temp = torch.tensor([g_pred.temp], dtype=torch.float)
            g_pred.conc = torch.tensor([g_pred.conc], dtype=torch.float)
            g_pred.pco2 = torch.tensor([g_pred.pco2], dtype=torch.float)

            g_pred = scaler(g_pred).to(device)
            with torch.no_grad():
                pred = model(g_pred).cpu().numpy().flatten()

            # Avoid division by zero: points with a zero target are left out
            if g_pred.aco2 != 0:
                percentage_errors.append(np.abs(pred - g_pred.aco2) / np.abs(g_pred.aco2))

        # Mean Absolute Percentage Error
        if percentage_errors:
            mape_results[model_name][amine] = np.mean(percentage_errors) * 100
        else:
            # Every target was zero: the MAPE is undefined for this amine
            mape_results[model_name][amine] = float('nan')

# Print MAPE per amine per model
for model_name, amines in mape_results.items():
    print(f"\n=== {model_name} MAPE per amine (%) ===")
    for amine, mape in amines.items():
        print(f"{amine}: {mape:.2f}%")

In [None]:
# Organize original_test_data to see available combinations for isotherm visualization
# One representative graph per system name, taken from the scaled test split
# (the first occurrence of each name is kept).
unique_named_graphs = {}
for graph in test_data:
    unique_named_graphs.setdefault(graph['name'], graph)

# Optional: convert to a list
unique_graph_list = [*unique_named_graphs.values()]
unique_list = DataLoader(unique_graph_list, batch_size=batch_size, shuffle=False)

# One row per experimental point of the unscaled test split
data_summary = [
    {
        'name': point['name'],
        'temp': point.temp,
        'conc': point.conc,
        'pco2': point.pco2,
        'aco2': point.aco2,
    }
    for point in original_test_data
]

# Convert to DataFrame for easier analysis
df_summary = pd.DataFrame(data_summary)

# Group by amine to see available conditions
print("Available experimental conditions by amine:")
amines = {}  # Final object to store all amine data

for amine in df_summary['name'].unique():
    amine_data = df_summary[df_summary['name'] == amine]
    print(f"\n{amine}:")

    # Sorted unique temperatures and concentrations as plain floats
    temperatures = [float(t) for t in sorted(amine_data['temp'].unique())]
    print(f"  Temperatures (K): {temperatures}")

    concentrations = [float(c) for c in sorted(amine_data['conc'].unique())]
    print(f"  Concentrations (M): {concentrations}")

    pco2_low = amine_data['pco2'].min()
    pco2_high = amine_data['pco2'].max()
    n_points = len(amine_data)
    print(f"  pCO2 range: {pco2_low:.2e} kPa - {pco2_high:.2e} kPa")
    print(f"  Total data points: {n_points}")

    # Distinct (temp, conc) pairs as plain tuples
    temp_conc_pairs = amine_data[['temp', 'conc']].drop_duplicates()

    # Store in final amines object
    amines[amine] = {
        'data': amine_data,  # Full DataFrame for this amine
        'temperatures': temperatures,
        'concentrations': concentrations,
        'pco2_range': (float(pco2_low), float(pco2_high)),
        'num_points': n_points,
        'temp_conc_combinations': list(temp_conc_pairs.itertuples(index=False, name=None)),
    }

# Create a simple list of amine names
amine_names = list(amines)

print(f"\n\nFinal amines object created with {len(amines)} amines.")
print("Access data using: amines['amine_name']['property']")
print("Available properties: 'data', 'temperatures', 'concentrations', 'pco2_range', 'num_points', 'temp_conc_combinations'")

print(f"\n\nAvailable amines ({len(amine_names)}):")
for position, amine_name in enumerate(amine_names, start=1):
    print(f"  {position}. {amine_name}")

print(f"\nUse amine_names list to access: {amine_names}")
print("Example: amines[amine_names[0]] to get data for the first amine")

In [None]:
""" fontsize = 16
# Set Times New Roman font for all text
plt.rcParams['font.family'] = 'Times New Roman'

# Prediction setup
pco2_values = np.logspace(0, 3, 100)
system = 'DGA'
temperature_values = [313.15, 353.15, 393.15, 0]
constant_conc = 1

# Define colors for each temperature
# Colors: automatic assignment from colormap
cmap = plt.get_cmap("Set1")  # supports many distinct colors
colors = {T: cmap(i) for i, T in enumerate(temperature_values)}
# Define markers and linestyles
linestyles = ['-', '--', '-.', ':']
markers = ['o', 's', 'D', '^', 'v']


# Filter the data for the specific system
system_data = [graph for graph in original_test_data if graph.name == system and graph.aco2]
plt.figure(figsize=(6, 5))

# Iterate over each temperature and generate predictions for continuous lines
for i, temp in enumerate(temperature_values):
    predictions = []
    
    for pco2 in pco2_values:
        graph = system_data[0].clone()
        graph.temp = torch.tensor([temp], dtype=torch.float)
        graph.conc = torch.tensor([constant_conc], dtype=torch.float)
        graph.pco2 = torch.tensor([pco2], dtype=torch.float)
        graph = scale_graphs(graph, conc_mean, conc_std, temp_mean, temp_std, pco2_mean, pco2_std)
        graph = graph.to(device)
        with torch.no_grad():
            prediction = model_2(graph)
            predictions.append(prediction.cpu().numpy().flatten())
    
    predictions = np.array(predictions).flatten()
    
    # Plot continuous prediction line with specified color and linestyle
    plt.plot(pco2_values, predictions, color=colors[temp], linestyle=linestyles[i % len(linestyles)], label=f'Prediction at {temp} K')

# Retrieving actual data (if exists) for the specific system
filtered_graphs = [graph for graph in original_test_data 
                   if graph.name == system and
                   graph.conc == constant_conc]

# Store R² and RMSE values
r2_values = {}
rmse_values = {}

# Collect actual data points and calculate errors for each temperature
for i, temp in enumerate(temperature_values):
    pco2_act = []
    aco2_act = []
    aco2_pred = []
    
    # Collect experimental points for this temperature
    for graph in filtered_graphs:
        if graph.temp == temp:
            pco2_act.append(graph.pco2)
            aco2_act.append(graph.aco2)
    
    # Generate predictions ONLY at experimental PCO2 values
    for pco2 in pco2_act:
        graph = system_data[0].clone()
        graph.temp = torch.tensor([temp], dtype=torch.float)
        graph.conc = torch.tensor([constant_conc], dtype=torch.float)
        graph.pco2 = torch.tensor([pco2], dtype=torch.float)
        graph = scale_graphs(graph, conc_mean, conc_std, temp_mean, temp_std, pco2_mean, pco2_std)
        graph = graph.to(device)
        with torch.no_grad():
            prediction = model_2(graph)
            aco2_pred.append(prediction.cpu().numpy().flatten()[0])

    # Compute R² and RMSE
    if len(aco2_act) > 0 and len(aco2_pred) > 0:
        r2_values[temp] = r2_score(aco2_act, aco2_pred)
        rmse_values[temp] = np.sqrt(mean_squared_error(aco2_act, aco2_pred))
        
        # Plot experimental points without error bars
        plt.scatter(pco2_act, aco2_act, 
                   marker=markers[i % len(markers)],  # Different marker for each temp
                   color=colors[temp],
                   label=f'Experimental at {temp} K',
                   s=50,  # marker size
                   edgecolors='black',  # add black edge for better visibility
                   linewidth=0.5)

# Generate metric summary string
metrics_text = "\n".join([f"T={temp} K: R²={r2_values[temp]:.2f}, RMSE={rmse_values[temp]:.2f}" for temp in r2_values])

# Plot customization
plt.xlabel('CO$_{2}$ Partial Pressure (kPa)', fontsize=fontsize)
plt.ylabel(r'CO$_{\mathrm{2}}$ Loading (mol$_{\mathrm{CO2}}$/mol$_{\mathrm{A}}$)', fontsize=fontsize)

# Add system name inside the plot area (top-left corner)
from matplotlib.ticker import MultipleLocator
plt.text(0.05, 0.95, f'{system} at {constant_conc} M', 
         transform=plt.gca().transAxes, 
         fontsize=fontsize-2,
         verticalalignment='top',
         bbox=dict(boxstyle='round,pad=0.3', facecolor='white', alpha=0.8))
plt.xticks(fontsize=fontsize)
plt.yticks(fontsize=fontsize)
plt.xlim(min(pco2_values), max(pco2_values))
plt.ylim(0, 1.5)
plt.xscale('log')

plt.gca().yaxis.set_major_locator(MultipleLocator(0.5))
plt.gca().yaxis.set_minor_locator(MultipleLocator(0.25))
plt.minorticks_on()

plt.grid(True, which='both', linewidth=0.5)
plt.legend(
    bbox_to_anchor=(0, 1.02, 1, 0.2),
    loc='lower left',
    mode='expand',
    frameon=False,
    fontsize=fontsize-4,
    ncol=2,
    markerscale=1,
    labelspacing=0.4,
    borderpad=0.5,
    columnspacing=1.0,
)

plt.tight_layout()
plt.show() """