In [None]:
import pandas as pd
import numpy as np
from sklearn.metrics import mean_squared_error, r2_score

def _coordinate_metrics(y_true, y_pred):
    """Return MSE, R2, ADE and NRMSE for paired coordinate arrays.

    Both arrays carry x/y/z on their last axis (shape ``(..., 3)``).
    MSE, R2 and NRMSE are computed on the flattened values; ADE is the mean
    Euclidean distance between matching 3-D points.
    """
    true_flat = y_true.ravel()
    pred_flat = y_pred.ravel()

    mse = mean_squared_error(true_flat, pred_flat)
    r2 = r2_score(true_flat, pred_flat)
    # Distance per 3-D point (last axis), averaged over every remaining index.
    ade = np.mean(np.linalg.norm(y_true - y_pred, axis=-1))

    value_range = np.ptp(true_flat)  # max - min of the ground-truth values
    # NRMSE is undefined for constant targets; keep the convention of 0.
    nrmse = np.sqrt(mse) / value_range if value_range > 0 else 0.0

    return {"MSE": mse, "R2": r2, "ADE": ade, "NRMSE": nrmse}


def calculate_metrics_auto(csv_path, true_prefix="T_", pred_prefix="P_"):
    """
    Calculate MSE, R², ADE and NRMSE for every atom found in a CSV file.

    The CSV holds one row per sample and, for each atom, the columns
    ``{true_prefix}<atom>_x0``, ``..._y0``, ``..._z0`` (ground truth) and
    ``{pred_prefix}<atom>_x0``, ``..._y0``, ``..._z0`` (prediction).
    Atoms are discovered from the ground-truth ``*_x0`` columns: the atom
    name is everything between ``true_prefix`` and the trailing ``_x0``,
    so prefixes and atom names may themselves contain underscores.

    Parameters
    ----------
    csv_path : str or path-like
        Path passed to ``pandas.read_csv``.
    true_prefix, pred_prefix : str
        Column-name prefixes of the ground-truth and predicted coordinates.

    Returns
    -------
    dict
        ``{atom: {"MSE", "R2", "ADE", "NRMSE"}}`` for every atom (sorted by
        name), followed by two aggregate entries:

        - ``"Overall"``: all atoms pooled together. MSE, R² and NRMSE use
          the flattened coordinate values; ADE is the mean 3-D displacement
          over every (sample, atom) pair.
        - ``"Averaged"``: unweighted mean of the per-atom metrics.

        NRMSE is RMSE divided by the range (max - min) of the ground-truth
        values, and is reported as 0.0 when that range is zero.

    Raises
    ------
    ValueError
        If no ``{true_prefix}<atom>_x0`` column exists.
    KeyError
        If a matching y/z or prediction column is missing.
    """
    df = pd.read_csv(csv_path)

    coord_suffixes = ("_x0", "_y0", "_z0")
    anchor = coord_suffixes[0]

    # Strip the known prefix and suffix instead of splitting on "_", which
    # broke for prefixes other than one-underscore ones and for atom names
    # containing "_". The length check rejects prefix/suffix overlaps.
    atoms = sorted(
        {
            col[len(true_prefix):-len(anchor)]
            for col in df.columns
            if col.startswith(true_prefix)
            and col.endswith(anchor)
            and len(col) > len(true_prefix) + len(anchor)
        }
    )
    if not atoms:
        raise ValueError(
            f"No columns of the form '{true_prefix}<atom>{anchor}' "
            f"found in {csv_path!r}"
        )

    results = {}
    true_blocks = []
    pred_blocks = []

    for atom in atoms:
        y_true = df[[f"{true_prefix}{atom}{s}" for s in coord_suffixes]].to_numpy()
        y_pred = df[[f"{pred_prefix}{atom}{s}" for s in coord_suffixes]].to_numpy()
        true_blocks.append(y_true)  # each (n_samples, 3)
        pred_blocks.append(y_pred)
        results[atom] = _coordinate_metrics(y_true, y_pred)

    # Snapshot per-atom metrics before adding the aggregate keys.
    per_atom = [results[atom] for atom in atoms]

    # Global metrics on all atoms at once: (n_atoms, n_samples, 3). ADE is
    # now the mean per-point displacement, comparable to the per-atom ADE,
    # instead of the norm of the whole concatenated configuration vector.
    results["Overall"] = _coordinate_metrics(
        np.stack(true_blocks), np.stack(pred_blocks)
    )

    # Equal weight per atom.
    results["Averaged"] = {
        key: np.mean([m[key] for m in per_atom])
        for key in ("MSE", "R2", "ADE", "NRMSE")
    }

    return results

# Example usage
if __name__ == "__main__":
    # Path to the CSV of ground-truth/predicted coordinates; adjust as needed.
    csv_file = "../results/mlp_test_full_data.csv"
    metrics = calculate_metrics_auto(csv_file)

    metric_names = ("MSE", "R2", "ADE", "NRMSE")

    def format_row(label, vals):
        """Format one table row: a 10-wide label, then 12-wide metric cells."""
        cells = " ".join(f"{vals[name]:<12.5f}" for name in metric_names)
        return f"{label:<10} {cells}"

    header = f"{'Atom':<10} " + " ".join(f"{name:<12}" for name in metric_names)
    # Size the rule from the header so it spans the full table width
    # (the previous hard-coded 58 dashes were shorter than the 62-char rows).
    rule = "-" * len(header)

    print(header)
    print(rule)

    # Per-atom rows; the two aggregate entries are printed below the rule.
    for atom, vals in metrics.items():
        if atom not in ("Overall", "Averaged"):
            print(format_row(atom, vals))

    print(rule)
    # Overall: metrics on all atoms' coordinates pooled together.
    print(format_row("Overall", metrics["Overall"]))
    # Averaged: unweighted mean of the per-atom metrics.
    print(format_row("Averaged", metrics["Averaged"]))