# Metrics Correlation Analysis: GCM-RCM Error Metrics

**Objective**: Analyze correlations between error metrics computed for climate model (GCM-RCM) pairs.

**Data Structure**: For each combination of `(region, gridpoint, physical_variable, model)`, we have multiple error metrics (ACC, d, KGE, etc.). We compute correlations between these metrics to identify redundancy and orthogonal error facets.

**Output**: Scatter plot matrix (pairplot), correlation heatmaps, and statistical summaries.


## 1. Project Setup & Configuration

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
from scipy.stats import pearsonr, spearmanr, kendalltau

# Plot styling: matplotlib's bundled seaborn dark-grid theme, then the HUSL palette.
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")

# Wide, untruncated DataFrame printing for the diagnostic output below.
pd.set_option('display.width', 300, 'display.max_columns', None)

import os

OUTPUT_DIR = "./output"  # every figure and CSV written by this notebook goes here
os.makedirs(OUTPUT_DIR, exist_ok=True)  # no error if the directory already exists


In [None]:
# --- Analysis configuration --------------------------------------------------
# Each filter accepts None or an empty list to mean "no filter" (include ALL).

# Region codes to analyze, e.g. ['FR'] to analyze only France.
REGIONS = ['SC']

# Physical variables: 'ppt' (precipitation) and/or 'tas' (temperature).
# None includes both; ['ppt'] keeps only precipitation.
PHYSICAL_VARIABLES = ['ppt']

# Metric names to include in the correlation analysis (None = all metrics),
# e.g. ['ACC', 'd', 'KGE (2009)', 'BM'].
METRIC_ABBREVIATIONS = ['H10 (MAHE)', 'd', 'dr', 'NED', 'MV', 'KGE (2009)']

# Optional RCM / GCM ID filters, e.g. RCM_IDS = [1, 2, 3], GCM_IDS = [3, 4].
RCM_IDS = [1, 2, 5, 6, 8, 11, 12]
GCM_IDS = [3, 7, 8]

# Correlation method: 'pearson', 'spearman', or 'kendall'
CORRELATION_METHOD = 'pearson'

# Fail fast on a typo here, before the (possibly slow) database query runs.
_VALID_CORRELATION_METHODS = ('pearson', 'spearman', 'kendall')
if CORRELATION_METHOD not in _VALID_CORRELATION_METHODS:
    raise ValueError(
        f"CORRELATION_METHOD must be one of {_VALID_CORRELATION_METHODS}, "
        f"got {CORRELATION_METHOD!r}"
    )

print(f"Regions: {REGIONS if REGIONS else 'ALL'}")
print(f"Physical Variables: {PHYSICAL_VARIABLES if PHYSICAL_VARIABLES else 'ALL'}")
print(f"Metric Abbreviations: {METRIC_ABBREVIATIONS if METRIC_ABBREVIATIONS else 'ALL'}")
print(f"RCM IDs: {RCM_IDS if RCM_IDS else 'ALL'}")
print(f"GCM IDs: {GCM_IDS if GCM_IDS else 'ALL'}")
print(f"Correlation Method: {CORRELATION_METHOD}")


## 2. Load Data from Database

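Query the `error` table from the project database, joined with the `metrics` table to obtain metric names (see the SQL schema). The configured filters are applied directly in the SQL `WHERE` clause.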

In [None]:
import sys

# Make the project-level `utils` module (one directory up) importable.
# The membership check keeps re-runs of this cell from appending '..' again.
if '..' not in sys.path:
    sys.path.append('..')

from utils import get_db_connection

# NOTE(review): the `.connect()` call suggests get_db_connection() returns a
# SQLAlchemy-style engine — confirm. `conn` is used by the query cell and is
# not closed anywhere in this notebook.
engine = get_db_connection()
conn = engine.connect()


In [None]:
def _sql_string_list(values):
    """Render *values* as comma-separated SQL string literals.

    Embedded single quotes are doubled (standard SQL escaping), so a value
    that contains a quote cannot end the literal early and break or alter
    the query.
    """
    return ", ".join("'" + str(value).replace("'", "''") + "'" for value in values)


def _sql_int_list(values):
    """Render *values* as comma-separated SQL integer literals.

    ``int()`` raises for anything that is not integer-like, so only numeric
    IDs can reach the query text.
    """
    return ", ".join(str(int(value)) for value in values)


# (column, configured values, renderer) for each optional IN (...) filter.
# A None or empty configuration means "no filter on this column".
_sql_filters = [
    ("error.region", REGIONS, _sql_string_list),
    ("error.physical_variable", PHYSICAL_VARIABLES, _sql_string_list),
    # Filtering on the LEFT JOINed metrics table also drops error rows whose
    # metric_id has no match there (their metric_name is NULL).
    ("metrics.metric_name", METRIC_ABBREVIATIONS, _sql_string_list),
    ("error.rcm_id", RCM_IDS, _sql_int_list),
    ("error.gcm_id", GCM_IDS, _sql_int_list),
]

where_clauses = [
    f"{column} IN ({render(values)})"
    for column, values, render in _sql_filters
    if values  # None and [] both mean "no filter"
]

where_clause = " AND ".join(where_clauses)
if where_clause:
    where_clause = f"WHERE {where_clause}"

query = f"""
SELECT 
    error.region,
    error.gridpoint,
    error.physical_variable,
    error.model,
    error.rcm_id,
    error.gcm_id,
    error.metric_id,
    metrics.metric_name,
    error.mat_vector
FROM 
    error 
LEFT JOIN 
    metrics ON metrics.id = error.metric_id
{where_clause}
"""

df_raw = pd.read_sql_query(query, conn)
print(f"✓ Data loaded: {len(df_raw)} rows")
print(f"  Columns: {list(df_raw.columns)}")
print(f"\nFirst few rows:")
print(df_raw.head(5))


## 3. Data Filtering & Preprocessing

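The region, physical-variable, metric, RCM and GCM filters have already been applied by the SQL query in Section 2. This step works on a copy of the retrieved data and summarizes it (regions, variables, metrics, missing values) as a sanity check.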

In [None]:
df = df_raw.copy()


def _sorted_unique(series):
    """Return the distinct non-null values of *series*, sorted.

    Nulls are dropped first because ``sorted`` cannot compare None with str.
    The LEFT JOIN in the query yields a null ``metric_name`` for any error row
    whose ``metric_id`` has no entry in the ``metrics`` table, which can reach
    this point when no metric filter is configured.
    """
    return sorted(series.dropna().unique())


print(f"✓ Data retrieved (pre-filtered by database query):")
print(f"  Total rows: {len(df)}")
print(f"  Regions: {_sorted_unique(df['region'])}")
print(f"  Physical variables: {_sorted_unique(df['physical_variable'])}")
print(f"  Unique metrics: {df['metric_name'].nunique()}")

print(f"\nMissing values:")
print(df.isnull().sum())

print(f"\nDataset shape: {df.shape}")
print(f"Metrics: {_sorted_unique(df['metric_name'])}")
print(f"Unique models: {df['model'].nunique()}")
print(f"Unique gridpoints: {df['gridpoint'].nunique()}")


## 4. Create Pivot Table for Correlation Analysis

Transform the data so that each row represents a unique `(region, gridpoint, physical_variable, model)` grouping,  
and each column is a metric. This structure allows us to compute correlations between metrics.

In [None]:
# Grouping key "<region>_<gridpoint>_<physical_variable>_<model>", built with
# vectorized string concatenation. A row-wise df.apply(..., axis=1) is slow on
# large result sets and, on an empty frame, returns a DataFrame that cannot be
# assigned to a single column.
df['group_key'] = (
    df['region'].astype(str)
    + '_' + df['gridpoint'].astype(str)
    + '_' + df['physical_variable'].astype(str)
    + '_' + df['model'].astype(str)
)

# The pivot keeps only the first value per (group_key, metric). Report any
# duplicates instead of dropping them silently: they would mean the key does
# not identify a single series (e.g. if `model` does not already distinguish
# every GCM-RCM pair — TODO confirm against the schema).
n_duplicates = int(df.duplicated(subset=['group_key', 'metric_name']).sum())
if n_duplicates:
    print(f"⚠ {n_duplicates} duplicate (group_key, metric_name) rows found; "
          f"only the first value of each is kept.")

# Pivot: rows are group_keys, columns are metric names, values are error values
pivot_df = df.pivot_table(
    index='group_key',
    columns='metric_name',
    values='mat_vector',
    aggfunc='first'
)

print(f"✓ Pivot table created:")
print(f"  Shape: {pivot_df.shape}")
print(f"  Rows (unique group_keys): {pivot_df.shape[0]}")
print(f"  Columns (metrics): {pivot_df.shape[1]}")
print(f"\nMetrics in pivot table:")
print(pivot_df.columns.tolist())
print(f"\nFirst few rows:")
print(pivot_df.head())

print(f"\nMissing values per metric:")
print(pivot_df.isnull().sum())


In [None]:
# Keep complete cases only: group keys that have a value for every selected
# metric, so each pairwise correlation below uses the same observations.
pivot_df_clean = pivot_df.dropna(axis='index', how='any')

_rows_before = len(pivot_df)
_rows_after = len(pivot_df_clean)
for _line in (
    f"  Original: {_rows_before} rows",
    f"  After dropna: {_rows_after} rows",
    f"  Rows removed: {_rows_before - _rows_after}",
):
    print(_line)


## 5. Compute Pairwise Correlations

Calculate correlation coefficients between all pairs of metrics.

In [None]:
# Correlation between every pair of metric columns, on complete cases only.
corr_matrix = pivot_df_clean.corr(method=CORRELATION_METHOD)

print(
    f"✓ Correlation matrix computed ({CORRELATION_METHOD}):",
    f"  Shape: {corr_matrix.shape}",
    "\nCorrelation Matrix:",
    sep="\n",
)
print(corr_matrix)

In [None]:
def pairwise_corr_with_pvalues(df, method='pearson'):
    """Compute the correlation coefficient and two-sided p-value for every column pair.

    Parameters
    ----------
    df : pandas.DataFrame
        One column per metric, one row per observation. Missing values are not
        dropped here; the notebook passes the NaN-free ``pivot_df_clean``.
    method : {'pearson', 'spearman', 'kendall'}
        Which scipy test to use (``pearsonr``, ``spearmanr`` or ``kendalltau``).

    Returns
    -------
    pandas.DataFrame
        One row per unordered column pair, in column order, with columns
        'Metric 1', 'Metric 2', 'Correlation', 'P-value' and 'N' (the number
        of rows in *df*). The columns are present even when *df* has fewer
        than two columns and the result therefore has zero rows.

    Raises
    ------
    ValueError
        If *method* is not one of the supported names. This is checked before
        any pair is processed.
    """
    from itertools import combinations

    tests = {'pearson': pearsonr, 'spearman': spearmanr, 'kendall': kendalltau}
    try:
        test = tests[method]
    except KeyError:
        raise ValueError(f"Unknown method: {method}") from None

    n_obs = len(df)
    results = []
    for metric1, metric2 in combinations(df.columns, 2):
        corr, pval = test(df[metric1], df[metric2])
        results.append({
            'Metric 1': metric1,
            'Metric 2': metric2,
            'Correlation': corr,
            'P-value': pval,
            'N': n_obs,
        })

    # Explicit columns keep the schema stable for downstream column access
    # even when there are no pairs.
    return pd.DataFrame(results, columns=['Metric 1', 'Metric 2', 'Correlation', 'P-value', 'N'])

In [None]:
# Strongest associations first, regardless of sign: sort on |correlation|
# without materialising a helper column.
pairwise_df = pairwise_corr_with_pvalues(
    pivot_df_clean, method=CORRELATION_METHOD
).sort_values('Correlation', key=lambda col: col.abs(), ascending=False)

print(f"\n✓ Pairwise correlations ({len(pairwise_df)} pairs):")
print(pairwise_df.to_string(index=False))

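For each metric pair, the reported p-value comes from a two-sided test of:

- $H_0$: the population correlation is zero  
- $H_1$: the population correlation is non-zero

The **p-value** is the probability, under $H_0$, of observing a correlation at least as extreme (in absolute value) as the measured one.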

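**Assumptions**
- **Pearson**: assumes a linear relationship and (approximately) bivariate normality.  
- **Spearman** / **Kendall**: rank-based (nonparametric); they capture monotonic relationships and are less sensitive to outliers and non-normality.
- **All methods**: rows are treated as independent observations. Neighbouring gridpoints, and rows that share a GCM or RCM, are likely to be correlated, so the p-values may be optimistic. They are also not corrected for multiple comparisons across metric pairs.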

Equivalently, phrased in terms of association:
- $H_0$: The population correlation is zero (no association)
- $H_1$: The population correlation is non-zero (association exists)

## 6. Scatter Plot Matrix (Pairplot)

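Generate a pairplot (scatter plot matrix) to visualize the relationship between every pair of metrics, with a histogram of each metric on the diagonal. The upper-triangle panels are annotated with the correlation coefficient and its significance marker.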

In [None]:
# sns.pairplot creates its own figure, so a plt.figure(figsize=...) call
# beforehand would only leave an empty orphan figure behind, and its size
# would be ignored. The overall size is set through `height` (inches per
# panel), chosen so the whole matrix is roughly 16 x 16 inches.
PAIRPLOT_SIZE_INCHES = 16
g = sns.pairplot(
    pivot_df_clean,
    height=PAIRPLOT_SIZE_INCHES / max(pivot_df_clean.shape[1], 1),
    diag_kind='hist',
    plot_kws={'alpha': 0.6, 's': 30},
    diag_kws={'bins': 30, 'edgecolor': 'k', 'alpha': 0.7}
)

def add_correlation_annotations(g, corr_matrix, pairwise_df):
    """Label each upper-triangle panel of a pairplot with its correlation and significance.

    Parameters
    ----------
    g : seaborn.PairGrid
        Grid returned by ``sns.pairplot``; panel ``g.axes[i][j]`` plots
        ``g.y_vars[i]`` (vertical) against ``g.x_vars[j]`` (horizontal).
    corr_matrix : pandas.DataFrame
        Square correlation matrix whose index and columns are metric names.
    pairwise_df : pandas.DataFrame
        Pairwise results with 'Metric 1', 'Metric 2' and 'P-value' columns;
        each pair may be listed in either order.

    Each label is the correlation to two decimals followed by a marker:
    ``***`` (p < 0.001), ``**`` (p < 0.01), ``*`` (p < 0.05), ``ns``
    otherwise, ``n/a`` when the p-value is NaN/missing, or ``?`` when the pair
    is absent from *pairwise_df*. Panel titles in the upper triangle are
    cleared; diagonal and lower-triangle panels are left untouched.
    """
    # Index p-values by unordered metric pair once, instead of filtering
    # pairwise_df for every panel. The first occurrence of a pair wins.
    pvalues = {}
    for name1, name2, pval in zip(
        pairwise_df['Metric 1'], pairwise_df['Metric 2'], pairwise_df['P-value']
    ):
        pvalues.setdefault(frozenset((name1, name2)), pval)

    for i, ax_row in enumerate(g.axes):
        for j, ax in enumerate(ax_row):
            if i >= j:
                continue  # annotate the upper triangle only
            # Use the grid's own variable lists rather than g.data.columns:
            # pairplot plots only numeric columns by default, so the two can differ.
            row_var, col_var = g.y_vars[i], g.x_vars[j]
            corr_val = corr_matrix.loc[row_var, col_var]

            pair = frozenset((row_var, col_var))
            if pair not in pvalues:
                sig_marker = "?"
            elif pd.isna(pvalues[pair]):
                # Every comparison with NaN is False, which would otherwise read as "ns".
                sig_marker = "n/a"
            else:
                pval = pvalues[pair]
                sig_marker = "***" if pval < 0.001 else "**" if pval < 0.01 else "*" if pval < 0.05 else "ns"

            ax.text(
                0.5, 0.5,
                f'{corr_val:.2f}{sig_marker}',
                transform=ax.transAxes,
                ha='center', va='center',
                fontsize=12, fontweight='bold',
                bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8)
            )
            ax.set_title('')

add_correlation_annotations(g, corr_matrix, pairwise_df)

# Title, layout and saving target the PairGrid's own figure explicitly rather
# than pyplot's implicit "current figure".
pairplot_fig = g.figure
pairplot_title = (
    f'Metrics Correlation Analysis - Scatter Plot Matrix\n'
    f'Region(s): {REGIONS if REGIONS else "ALL"} | '
    f'Variable(s): {PHYSICAL_VARIABLES if PHYSICAL_VARIABLES else "ALL"} | '
    f'RCM(s): {RCM_IDS if RCM_IDS else "ALL"} | '
    f'GCM(s): {GCM_IDS if GCM_IDS else "ALL"} | '
    f'Method: {CORRELATION_METHOD}'
)
pairplot_fig.suptitle(pairplot_title, fontsize=14, y=1.00)
pairplot_fig.tight_layout()

output_file = f"{OUTPUT_DIR}/pairplot_{CORRELATION_METHOD}.png"
pairplot_fig.savefig(output_file, dpi=300, bbox_inches='tight')
print(f"✓ Pairplot saved to: {output_file}")
plt.show()


The asterisks appended to each correlation value in the upper triangle of the pairplot are significance markers derived from the pairwise p-values:

- `***` = p-value < 0.001 (highly significant)
- `**` = p-value < 0.01 (very significant)
- `*` = p-value < 0.05 (significant)
- `ns` = not significant (p-value ≥ 0.05)

## 7. Correlation Heatmap

Visualize the correlation matrix as a heatmap with hierarchical clustering.

In [None]:
# sns.clustermap creates its own figure, so the size is passed as `figsize`.
# A plt.figure(...) call beforehand would only leave an empty orphan figure
# behind, and its size would be ignored.
cluster_grid = sns.clustermap(
    corr_matrix,
    figsize=(12, 10),
    cmap='RdBu_r',
    center=0,
    vmin=-1, vmax=1,
    annot=True,
    fmt='.2f',
    cbar_kws={'label': f'{CORRELATION_METHOD.capitalize()} Correlation'},
    linewidths=0.5,
    linecolor='gray'
)
# Title, layout and saving all target the clustermap's own figure explicitly.
heatmap_fig = cluster_grid.figure

heatmap_fig.suptitle(
    f'Hierarchical Correlation Heatmap ({CORRELATION_METHOD})\n'
    f'Region(s): {REGIONS if REGIONS else "ALL"} | Variable(s): {PHYSICAL_VARIABLES if PHYSICAL_VARIABLES else "ALL"} | RCM(s): {RCM_IDS if RCM_IDS else "ALL"} | GCM(s): {GCM_IDS if GCM_IDS else "ALL"}',
    fontsize=14, y=0.98
)
heatmap_fig.tight_layout()

output_file = f"{OUTPUT_DIR}/heatmap_clustered_{CORRELATION_METHOD}.png"
heatmap_fig.savefig(output_file, dpi=300, bbox_inches='tight')
print(f"✓ Heatmap saved to: {output_file}")
plt.show()


## 8. Export Results

Save correlation matrices and pairwise correlations to CSV files.

In [None]:
# Build each output path once so the printed messages always match what was written.
corr_matrix_path = f"{OUTPUT_DIR}/correlation_matrix_{CORRELATION_METHOD}.csv"
pairwise_path = f"{OUTPUT_DIR}/pairwise_correlations_{CORRELATION_METHOD}.csv"
summary_path = f"{OUTPUT_DIR}/metrics_summary_statistics.csv"
metadata_path = f"{OUTPUT_DIR}/analysis_metadata.csv"

# Save correlation matrix to CSV (index = metric names)
corr_matrix.to_csv(corr_matrix_path)
print(f"✓ Correlation matrix saved to: {corr_matrix_path}")

# Save pairwise correlations to CSV
pairwise_df.to_csv(pairwise_path, index=False)
print(f"✓ Pairwise correlations saved to: {pairwise_path}")

# Per-metric summary statistics of the complete-case data. Each statistic is a
# Series indexed by metric name, so the rows are aligned by label, and the
# index becomes the 'Metric' column. Taking names positionally from
# corr_matrix.columns would silently mislabel rows if the two ever differed.
summary_stats = (
    pd.DataFrame({
        'Mean': pivot_df_clean.mean(),
        'Std': pivot_df_clean.std(),
        'Min': pivot_df_clean.min(),
        'Max': pivot_df_clean.max(),
    })
    .rename_axis('Metric')
    .reset_index()
)
summary_stats.to_csv(summary_path, index=False)
print(f"✓ Summary statistics saved to: {summary_path}")

# Save analysis metadata (configuration + data size) for reproducibility
metadata = {
    'Analysis Date': pd.Timestamp.now(),
    'Correlation Method': CORRELATION_METHOD,
    'Regions': str(REGIONS),
    'Physical Variables': str(PHYSICAL_VARIABLES),
    'RCM IDs': str(RCM_IDS),
    'GCM IDs': str(GCM_IDS),
    'Metric Abbreviations': str(METRIC_ABBREVIATIONS),
    'Number of Rows': pivot_df_clean.shape[0],
    'Number of Metrics': pivot_df_clean.shape[1],
    'Metrics': ', '.join(map(str, pivot_df_clean.columns)),
}

metadata_df = pd.DataFrame(list(metadata.items()), columns=['Parameter', 'Value'])
metadata_df.to_csv(metadata_path, index=False)
print(f"✓ Analysis metadata saved to: {metadata_path}")

print(f"\n✓ All results exported to: {OUTPUT_DIR}/")


## 9. Summary & Insights

Summary of the correlation analysis results.

In [None]:
def _significance_marker(p):
    """Return the significance marker for p-value *p*.

    '***' for p < 0.001, '**' for p < 0.01, '*' for p < 0.05, 'ns' otherwise,
    and 'n/a' when *p* is NaN (the test could not be computed, e.g. for a
    constant metric). Without the NaN check such pairs would be labelled 'ns',
    because every comparison with NaN is False.
    """
    if pd.isna(p):
        return "n/a"
    if p < 0.001:
        return "***"
    if p < 0.01:
        return "**"
    if p < 0.05:
        return "*"
    return "ns"


print("=" * 90)
print("ANALYSIS SUMMARY")
print("=" * 90)
print(f"\nData Summary:")
print(f"  Total rows analyzed: {pivot_df_clean.shape[0]}")
print(f"  Number of metrics: {pivot_df_clean.shape[1]}")
print(f"  Correlation method: {CORRELATION_METHOD}")

print(f"\nMetrics included:")
for i, metric in enumerate(pivot_df_clean.columns, 1):
    print(f"  {i}. {metric}")

print(f"\nAll Pairwise Correlations (sorted by absolute correlation value):")
print("-" * 90)
if pairwise_df.empty:
    # Fewer than two metrics leaves nothing to correlate.
    print("  (no metric pairs to display)")
else:
    display_df = pairwise_df.copy()
    display_df['Significance'] = display_df['P-value'].apply(_significance_marker)
    display_df = display_df[['Metric 1', 'Metric 2', 'Correlation', 'P-value', 'Significance']]
    print(display_df.to_string(index=False))

print("\n" + "-" * 90)
print(f"Significance levels (α = 0.05):")
n_pairs = len(pairwise_df)
if n_pairs:
    sig_count = (pairwise_df['P-value'] < 0.05).sum()  # NaN p-values never count as significant
    print(f"  Significant correlations: {sig_count} / {n_pairs} ({100*sig_count/n_pairs:.1f}%)")
else:
    print("  Significant correlations: none (no metric pairs were tested)")

print("\nNote: *** p<0.001, ** p<0.01, * p<0.05, ns = not significant, n/a = p-value undefined")
print("=" * 90)
