In [6]:
import matplotlib.pyplot as plt
import timesfm
import pandas as pd
from sktime.performance_metrics.forecasting import mean_absolute_percentage_error, mean_absolute_scaled_error, mean_absolute_error
from sktime.forecasting.naive import NaiveForecaster
from sktime.forecasting.base import ForecastingHorizon
from sklearn.linear_model import LinearRegression
from sktime.forecasting.compose import make_reduction
from sktime.forecasting.statsforecast import StatsForecastAutoARIMA, StatsForecastAutoETS
import os
from sklearn.preprocessing import MinMaxScaler, StandardScaler
from sklearn.metrics import mean_squared_error
import numpy as np
from scipy.stats import f, pearsonr, ttest_ind
from sklearn.linear_model import Ridge, LinearRegression
import warnings
warnings.filterwarnings("ignore", message="possible convergence problem")
import matplotlib.pyplot as plt

In [7]:
## Initialize the TimesFM foundation model (zero-shot, no fine-tuning).
# HORIZON is passed as TimesFM's `horizon_len` (steps emitted per forecast call).
import torch  # the pytorch checkpoint below needs torch; used here to pick the backend

HORIZON = 1
# Fall back to CPU so the notebook still runs on machines without CUDA.
TFM_BACKEND = "gpu" if torch.cuda.is_available() else "cpu"
tfm = timesfm.TimesFm(
    hparams=timesfm.TimesFmHparams(
        backend=TFM_BACKEND,
        per_core_batch_size=32,
        horizon_len=HORIZON,
    ),
    checkpoint=timesfm.TimesFmCheckpoint(
        huggingface_repo_id="google/timesfm-1.0-200m-pytorch",
    ),
)

Fetching 3 files: 100%|████████████████████████████████████████████████████| 3/3 [00:00<00:00, 21845.33it/s]


In [8]:
#!pip install git+https://github.com/mb-BCA/pyMOU.git@master  --no-deps
import pymou as pm
import pymou.tools as pmt


# =============================================
# Test inhibitory excitatory
# =============================================
def generate_test_data_OU(N, n=100, dt=0.1, sigma=0.01, density=0.8, theta=None, seed=None):
    """Simulate an N-dimensional linear stochastic process with random signed coupling.

    The process is integrated with the Euler-Maruyama scheme, starting from X[0] = 0:

        X[t] = X[t-1] - dt * theta @ X[t-1] + sqrt(dt) * sigma * eps_t,   eps_t ~ N(0, I)

    Because the drift is ``-theta @ X``, a *negative* off-diagonal entry
    ``theta[i, j]`` makes j excite i and a *positive* one makes j inhibit i;
    signed ground truth built from theta therefore has to negate it.

    Parameters
    ----------
    N : int
        Number of variables (columns of the output).
    n : int, default 100
        Number of time steps (rows of the output).
    dt : float, default 0.1
        Integration step.
    sigma : float, default 0.01
        Noise amplitude.
    density : float, default 0.8
        When ``theta`` is None, passed to ``pymou.tools.make_rnd_connectivity`` as
        ``density``, with ``w_min=-1/N/density`` and ``w_max=1/N/density``.
    theta : array-like of shape (N, N), optional
        Connectivity to use instead of drawing a random one.
    seed : int, optional
        Seed for the noise generator. When None the global ``np.random`` state is
        used, so runs seeded via ``np.random.seed`` reproduce as before.

    Returns
    -------
    X : ndarray of shape (n, N)
        Simulated trajectories.
    theta : ndarray of shape (N, N)
        Connectivity matrix used in the drift.

    Raises
    ------
    ValueError
        If an explicit ``theta`` does not have shape (N, N).
    """
    if theta is None:
        theta = pmt.make_rnd_connectivity(N, density=density, w_min=-1 / N / density, w_max=1 / N / density)
    else:
        theta = np.asarray(theta, dtype=float)
        if theta.shape != (N, N):
            raise ValueError(f"theta must have shape ({N}, {N}), got {theta.shape}")

    # seed=None keeps drawing from the global NumPy RNG (original behaviour).
    draw_noise = np.random.randn if seed is None else np.random.default_rng(seed).standard_normal

    X = np.zeros((n, N))
    for t in range(1, n):
        X[t] = X[t - 1] + dt * (-theta @ X[t - 1]) + np.sqrt(dt) * sigma * draw_noise(N)

    return X, theta

# =============================================
# TimesFM FORECAST FUNCTION
# =============================================
def timesfm_forecast(series_data, horizon, model=None):
    """Point-forecast the next `horizon` values of a univariate series with TimesFM.

    Parameters
    ----------
    series_data : array-like or object with ``.values``
        History to condition on; flattened to 1-D.
    horizon : int
        Number of future steps to return.
    model : object, optional
        Anything exposing ``forecast(inputs, freq=...)`` whose first return value
        holds the point forecasts (one row per input series). Defaults to the
        notebook-level ``tfm`` instance.

    Returns
    -------
    ndarray of shape (horizon,)
        The model is queried once. If it returns fewer than `horizon` steps
        (e.g. ``horizon_len=1``), its last value is repeated. This matches the
        former behaviour of re-querying the same history once per step.
        A history of length 1 yields that value repeated; an empty history
        yields NaNs.
    """
    if model is None:
        model = tfm

    raw = series_data.values if hasattr(series_data, 'values') else series_data
    data = np.asarray(raw).flatten()

    if data.size <= 1:
        # Too short to condition the model: flat forecast (NaN when empty).
        fill = data[0] if data.size else np.nan
        return np.full(horizon, fill, dtype=float)

    # forecast() returns (point_forecast, quantile_forecast); the batch has one series.
    point = np.asarray(model.forecast([data], freq=[0])[0], dtype=float).reshape(-1)
    if point.size >= horizon:
        return point[:horizon]
    return np.concatenate([point, np.full(horizon - point.size, point[-1])])

In [None]:
from statsmodels.stats.multitest import multipletests

def _timesfm_residual_lag_tests(target, covar, max_lag, alpha, horizon, window_size,
                                correction_method, forecast_fn):
    """Per-lag Pearson tests between rolling one-step forecast residuals and the lagged covariate.

    Returns a (possibly empty) list of per-lag dicts carrying raw and
    multiple-test-corrected p-values.
    """
    # residuals[j] is the error of forecasting target[window_size + j] from the
    # `window_size` points that precede it.
    residuals = np.array([
        target[i] - forecast_fn(target[i - window_size:i], horizon)[0]
        for i in range(window_size, len(target))
    ], dtype=float)

    results = []
    for lag in range(1, max_lag + 1):
        # Pair the residual at time t with covar[t - lag]; drop leading residuals
        # whose lagged covariate would fall before the start of the series.
        start = max(0, lag - window_size)
        res_aligned = residuals[start:]
        cov_lagged = covar[window_size + start - lag:window_size - lag + len(residuals)]
        if len(res_aligned) < 10 or np.std(res_aligned) == 0 or np.std(cov_lagged) == 0:
            continue

        x = cov_lagged.reshape(-1, 1)
        correlation, p_corr = pearsonr(cov_lagged, res_aligned)
        reg = LinearRegression().fit(x, res_aligned)
        results.append({
            'lag': lag,
            'correlation': correlation,
            'correlation_p_value': p_corr,
            'r2_residual_explained': reg.score(x, res_aligned),
            'covariate_coeff': reg.coef_[0],
            'n_samples': len(res_aligned),
        })

    if results:
        reject, p_adjusted, _, _ = multipletests(
            [r['correlation_p_value'] for r in results], alpha=alpha, method=correction_method
        )
        for r, p_adj, rej in zip(results, p_adjusted, reject):
            r['correlation_p_value_corrected'] = p_adj
            r['significant_corrected'] = rej
    return results


def _granger_lag_tests(target, covar, max_lag, alpha, correction_method):
    """Granger-style restricted-vs-full OLS F-tests for lags 1..max_lag, corrected across lags.

    Returns a list of ``(lag, result_dict)`` for every lag that could be tested.
    """
    results_by_lag = {}
    for lag in range(1, max_lag + 1):
        n_samples = len(target) - lag - 1
        df_denom = n_samples - 2 * lag - 1  # residual dof of the full model (2*lag slopes + intercept)
        if n_samples < 10 or df_denom <= 0:
            continue

        # Row j predicts target[lag + 1 + j] from target/covar[lag + j], ..., [1 + j].
        y = target[lag + 1:]
        X_target = np.column_stack([target[i:i + n_samples] for i in range(lag, 0, -1)])
        X_covar = np.column_stack([covar[i:i + n_samples] for i in range(lag, 0, -1)])
        X_full = np.hstack([X_target, X_covar])

        try:
            model_restricted = LinearRegression().fit(X_target, y)
            model_full = LinearRegression().fit(X_full, y)
        except ValueError:
            continue  # e.g. NaN/inf in the inputs: this lag cannot be tested

        rss_restricted = np.sum((y - model_restricted.predict(X_target)) ** 2)
        rss_full = np.sum((y - model_full.predict(X_full)) ** 2)
        if rss_full <= 0:
            continue  # perfect fit: the F statistic is undefined

        if rss_restricted - rss_full > 0:
            f_stat = ((rss_restricted - rss_full) / lag) / (rss_full / df_denom)
            p_val = f.sf(f_stat, lag, df_denom)
        else:
            f_stat, p_val = 0, 1.0

        results_by_lag[lag] = {
            'p_value': p_val,
            'avg_coefficient': np.median(model_full.coef_[lag:2 * lag]),
            'f_statistic': f_stat,
            # Gaussian AIC over the samples actually fitted.
            'aic': 2 * (lag * 2) + n_samples * np.log(rss_full / n_samples),
        }

    if not results_by_lag:
        return []

    # Correct in the same order the lags were stored, so each p-value maps to its own lag.
    lags = list(results_by_lag)
    reject, p_adjusted, _, _ = multipletests(
        [results_by_lag[lag]['p_value'] for lag in lags], alpha=alpha, method=correction_method
    )
    return [
        (lag, dict(results_by_lag[lag], p_value_corrected=p_adj, significant_corrected=rej))
        for lag, p_adj, rej in zip(lags, p_adjusted, reject)
    ]


def _bootstrap_lagged_slope_ci(target, covar, lag, start, n_boot):
    """95% pairs-bootstrap CI of the OLS slope of target[t] on covar[t - lag], t >= max(start, lag).

    Pairs are resampled jointly so each bootstrap sample keeps the lagged pairing.
    Returns None when there are fewer than two pairs or ``n_boot <= 0``.
    """
    t_idx = np.arange(max(start, lag), len(target))
    if len(t_idx) < 2 or n_boot <= 0:
        return None
    x_all = covar[t_idx - lag]
    y_all = target[t_idx]
    effects = []
    for _ in range(n_boot):
        pick = np.random.choice(len(t_idx), size=len(t_idx), replace=True)
        reg = LinearRegression().fit(x_all[pick].reshape(-1, 1), y_all[pick])
        effects.append(reg.coef_[0])
    return (np.percentile(effects, 2.5), np.percentile(effects, 97.5))


def enhanced_timesfm_causality(target_series, covariate_series, max_lag=5, alpha=0.01, 
                              horizon=1, window_size=30, correction_method='fdr_bh',
                              forecast_fn=None, n_boot=200):
    """Test whether `covariate_series` helps predict `target_series` at lags 1..max_lag.

    Two analyses are run and the one with the smaller corrected p-value is reported
    (ties go to the Granger-style analysis):

    1. Forecast residuals: the target is forecast one step ahead from a rolling
       window of `window_size` points. The residual at time t is correlated
       (Pearson) with ``covar[t - lag]``.
    2. Granger-style F-test: OLS of target[t] on its own `lag` past values,
       compared against a model that also includes `lag` past covariate values.

    Within each analysis, per-lag p-values are adjusted with
    ``multipletests(method=correction_method)``. Picking the minimum over lags
    and over the two analyses is not itself corrected for.

    Parameters
    ----------
    target_series, covariate_series : array-like
        Flattened to 1-D and truncated to a common length.
    max_lag : int, default 5
    alpha : float, default 0.01
    horizon : int, default 1
        Passed to `forecast_fn`; only element 0 (the one-step forecast) is used.
    window_size : int, default 30
        Length of the rolling history given to the forecaster.
    correction_method : str, default 'fdr_bh'
        Any ``statsmodels.stats.multitest.multipletests`` method.
    forecast_fn : callable, optional
        ``forecast_fn(history, horizon) -> array``; defaults to `timesfm_forecast`.
    n_boot : int, default 200
        Bootstrap resamples for the confidence interval.

    Returns
    -------
    dict
        'p_value' / 'p_value_corrected' of the selected analysis and lag.
        'significant' is ``p_value < alpha`` on the *uncorrected* p-value;
        'significant_corrected' is the corrected decision.
        'avg_coefficient' and 'effect_type' ('excitatory' > 0, 'inhibitory' < 0,
        else 'none') give the effect. 'optimal_lag' is the selected lag.
        'confidence_interval' is the 95% CI of the lagged slope, computed only
        when 'significant_corrected'. 'residual_analysis' and
        'traditional_analysis' are the best per-analysis dicts. Also includes
        'correction_method'.
    """
    if forecast_fn is None:
        forecast_fn = timesfm_forecast

    target = np.asarray(target_series).ravel()
    covar = np.asarray(covariate_series).ravel()
    min_len = min(len(target), len(covar))
    target = target[:min_len]
    covar = covar[:min_len]

    default_return = {
        'p_value': 1.0,
        'p_value_corrected': 1.0,
        'avg_coefficient': 0.0,
        'effect_type': 'none',
        'significant': False,
        'significant_corrected': False,
        'optimal_lag': None,
        'confidence_interval': None,
        'residual_analysis': None,
        'traditional_analysis': None,
        'correction_method': correction_method
    }

    # PART 1: forecast-residual analysis (best effort: a model failure must not abort the test).
    try:
        residual_results = _timesfm_residual_lag_tests(
            target, covar, max_lag, alpha, horizon, window_size, correction_method, forecast_fn
        )
    except Exception as e:
        print(f"TimesFM residual analysis failed: {e}")
        residual_results = []

    # PART 2: Granger-style F-tests.
    traditional_results = _granger_lag_tests(target, covar, max_lag, alpha, correction_method)

    if not traditional_results and not residual_results:
        return default_return

    best_residual_result = (
        min(residual_results, key=lambda r: r['correlation_p_value_corrected'])
        if residual_results else None
    )
    best_lag_traditional = (
        min(traditional_results, key=lambda item: item[1]['p_value_corrected'])
        if traditional_results else None
    )
    traditional_result = best_lag_traditional[1] if best_lag_traditional else None

    use_residual = best_residual_result is not None and (
        traditional_result is None
        or best_residual_result['correlation_p_value_corrected'] < traditional_result['p_value_corrected']
    )
    if use_residual:
        final_p_value = best_residual_result['correlation_p_value']
        final_p_value_corrected = best_residual_result['correlation_p_value_corrected']
        final_avg_coeff = best_residual_result['covariate_coeff']
        final_optimal_lag = best_residual_result['lag']
        final_significant_corrected = best_residual_result['significant_corrected']
    else:
        final_p_value = traditional_result['p_value']
        final_p_value_corrected = traditional_result['p_value_corrected']
        final_avg_coeff = traditional_result['avg_coefficient']
        final_optimal_lag = best_lag_traditional[0]
        final_significant_corrected = traditional_result['significant_corrected']

    final_significant = final_p_value < alpha
    final_effect_type = 'excitatory' if final_avg_coeff > 0 else 'inhibitory' if final_avg_coeff < 0 else 'none'

    final_ci = None
    if final_significant_corrected and len(target) > window_size:
        final_ci = _bootstrap_lagged_slope_ci(target, covar, final_optimal_lag, window_size, n_boot)

    return {
        'p_value': final_p_value,
        'p_value_corrected': final_p_value_corrected,
        'avg_coefficient': final_avg_coeff,
        'effect_type': final_effect_type,
        'significant': final_significant,
        'significant_corrected': final_significant_corrected,
        'optimal_lag': final_optimal_lag,
        'confidence_interval': final_ci,
        'residual_analysis': best_residual_result,
        'traditional_analysis': traditional_result,
        'correction_method': correction_method
    }


def classical_granger_test_aic(target_series, covariate_series, max_lag=5, alpha=0.01):
    """Granger causality F-test at a single lag chosen by AIC (no multiple-testing correction).

    Step 1 fits ``target[t] ~ target[t-1..t-lag] + covar[t-1..t-lag]`` for
    lag = 1..max_lag and keeps the lag with the lowest AIC. Step 2 compares, at
    that lag, the restricted model (intercept + own lags) with the full model
    (plus covariate lags) via an F-test.

    Returns
    -------
    dict
        'p_value', 'significant' (``p_value < alpha``), 'avg_coeff' (median of
        the covariate-lag coefficients), 'effect_type' ('excitatory' > 0,
        'inhibitory' < 0, else 'none'), 'lag' and 'f_statistic' (None when the
        F statistic is undefined). If no lag can be fitted, a default dict with
        p_value 1.0 and lag None (no 'f_statistic' key) is returned.
    """
    target = np.asarray(target_series).ravel()
    covar = np.asarray(covariate_series).ravel()
    min_len = min(len(target), len(covar))
    target = target[:min_len]
    covar = covar[:min_len]

    default_return = {
        'p_value': 1.0,
        'effect_type': 'none',
        'significant': False,
        'avg_coeff': 0.0,
        'lag': None
    }

    def _lagged(series, lag):
        # Column l-1 holds series[t - l] for each target time t = lag, ..., len-1.
        return np.column_stack([series[lag - l:len(series) - l] for l in range(1, lag + 1)])

    # === Step 1: AIC-based lag selection ===
    aic_values = []
    for lag in range(1, max_lag + 1):
        if min_len < lag + 10:  # Minimum samples
            continue
        y = target[lag:]
        X_full = np.hstack([_lagged(target, lag), _lagged(covar, lag)])
        try:
            model = LinearRegression().fit(X_full, y)
        except ValueError:
            continue  # e.g. NaN/inf in the inputs
        rss = np.sum((y - model.predict(X_full)) ** 2)
        n_obs = len(y)
        # Gaussian AIC over the samples actually fitted.
        aic_values.append((lag, 2 * (lag * 2) + n_obs * np.log(rss / n_obs)))

    if not aic_values:
        return default_return

    best_lag = min(aic_values, key=lambda item: item[1])[0]

    # === Step 2: Granger test ONLY at the best lag ===
    y = target[best_lag:]
    X_restricted = np.hstack([np.ones((len(y), 1)), _lagged(target, best_lag)])  # no covariate
    X_full = np.hstack([X_restricted, _lagged(covar, best_lag)])                 # with covariate

    try:
        model_restricted = LinearRegression().fit(X_restricted, y)
        model_full = LinearRegression().fit(X_full, y)
    except ValueError:
        return default_return

    rss_restricted = np.sum((y - model_restricted.predict(X_restricted)) ** 2)
    rss_full = np.sum((y - model_full.predict(X_full)) ** 2)

    n = len(y)
    k = best_lag
    df_denom = n - 2 * k - 1
    f_stat = None
    p_val = 1.0
    if rss_full > 0 and (rss_restricted - rss_full) > 0 and df_denom > 0:
        f_stat = ((rss_restricted - rss_full) / k) / (rss_full / df_denom)
        p_val = f.sf(f_stat, k, df_denom)

    # coef_[0] belongs to the explicit ones column, then k target lags, then k covariate lags.
    covar_coeffs = model_full.coef_[best_lag + 1:2 * best_lag + 1]
    avg_coeff = np.median(covar_coeffs) if len(covar_coeffs) > 0 else 0.0
    effect_type = 'excitatory' if avg_coeff > 0 else 'inhibitory' if avg_coeff < 0 else 'none'

    return {
        'p_value': p_val,
        'effect_type': effect_type,
        'significant': p_val < alpha,
        'avg_coeff': avg_coeff,
        'lag': best_lag,
        'f_statistic': f_stat
    }

from statsmodels.stats.multitest import multipletests
from scipy.stats import f

def classical_granger_test_multiple_correction(target_series, covariate_series, max_lag=5, alpha=0.01, correction_method='fdr_bh'):
    """Granger causality F-tests at every lag 1..max_lag, corrected for multiple testing.

    For each lag L, the restricted OLS model ``target[t] ~ 1 + target[t-1..t-L]``
    is compared with the full model, which adds ``covar[t-1..t-L]``. Per-lag
    p-values are adjusted with ``multipletests(method=correction_method)``, and
    the lag with the smallest corrected p-value is reported.

    A lag is skipped if any of the following holds:
    - it leaves fewer than 10 samples;
    - the full model has no residual degrees of freedom (a NaN p-value would
      otherwise be passed to the correction);
    - fitting raises ``ValueError``.

    Returns
    -------
    dict
        'p_value', 'p_value_corrected', 'significant' (raw ``p_value < alpha``),
        'significant_corrected', 'effect_type', 'avg_coeff' (median covariate
        coefficient), 'lag', 'f_statistic', 'correction_method' and
        'all_lag_results' (one dict per tested lag). If no lag could be tested,
        a default dict with p-values of 1.0 and ``lag=None`` is returned.
    """
    target = np.asarray(target_series).ravel()
    covar = np.asarray(covariate_series).ravel()
    min_len = min(len(target), len(covar))
    target = target[:min_len]
    covar = covar[:min_len]

    default_return = {
        'p_value': 1.0,
        'p_value_corrected': 1.0,
        'effect_type': 'none',
        'significant': False,
        'significant_corrected': False,
        'avg_coeff': 0.0,
        'lag': None,
        'f_statistic': None,
        'correction_method': correction_method,
        'all_lag_results': []
    }

    all_results = []
    for lag in range(1, max_lag + 1):
        n = min_len - lag
        df_denom = n - 2 * lag - 1  # full model: intercept + 2*lag slopes
        if min_len < lag + 10 or df_denom <= 0:
            continue

        y = target[lag:]
        X_target = np.column_stack([target[lag - l:-l] for l in range(1, lag + 1)])
        X_covar = np.column_stack([covar[lag - l:-l] for l in range(1, lag + 1)])
        X_restricted = np.hstack([np.ones((n, 1)), X_target])  # Restricted model (no covariate)
        X_full = np.hstack([X_restricted, X_covar])            # Full model (with covariate)

        try:
            model_restricted = LinearRegression().fit(X_restricted, y)
            model_full = LinearRegression().fit(X_full, y)
        except ValueError:
            continue  # e.g. NaN/inf in the inputs: this lag cannot be tested

        rss_restricted = np.sum((y - model_restricted.predict(X_restricted)) ** 2)
        rss_full = np.sum((y - model_full.predict(X_full)) ** 2)

        if rss_full > 0 and (rss_restricted - rss_full) > 0:
            f_stat = ((rss_restricted - rss_full) / lag) / (rss_full / df_denom)
            p_val = f.sf(f_stat, lag, df_denom)
        else:
            f_stat, p_val = 0, 1.0

        # coef_[0] belongs to the explicit ones column, then lag target lags, then lag covariate lags.
        covar_coeffs = model_full.coef_[lag + 1:2 * lag + 1]
        avg_coeff = np.median(covar_coeffs) if len(covar_coeffs) > 0 else 0.0
        effect_type = 'excitatory' if avg_coeff > 0 else 'inhibitory' if avg_coeff < 0 else 'none'

        all_results.append({
            'lag': lag,
            'p_value': p_val,
            'f_statistic': f_stat,
            'avg_coeff': avg_coeff,
            'effect_type': effect_type,
            'significant': p_val < alpha,
            'n_samples': n
        })

    if not all_results:
        return default_return

    reject, pvals_corrected, _, _ = multipletests(
        [r['p_value'] for r in all_results], alpha=alpha, method=correction_method
    )
    for result, p_adj, rej in zip(all_results, pvals_corrected, reject):
        result['p_value_corrected'] = p_adj
        result['significant_corrected'] = rej

    best_result = min(all_results, key=lambda r: r['p_value_corrected'])

    return {
        'p_value': best_result['p_value'],
        'p_value_corrected': best_result['p_value_corrected'],
        'effect_type': best_result['effect_type'],
        'significant': best_result['significant'],
        'significant_corrected': best_result['significant_corrected'],
        'avg_coeff': best_result['avg_coeff'],
        'lag': best_result['lag'],
        'f_statistic': best_result['f_statistic'],
        'correction_method': correction_method,
        'all_lag_results': all_results  # For debugging and comprehensive analysis
    }

# =============================================
# VISUALIZATION AND TESTING
# =============================================
def plot_causal_effects(data, columns=None):
    """Plot each time series in its own stacked panel.

    Parameters
    ----------
    data : pandas.DataFrame or array-like of shape (n_steps, n_series)
        A DataFrame is indexed by column label (panel title = label upper-cased);
        anything else goes through ``np.asarray`` and is indexed by column position.
    columns : sequence, optional
        Columns to plot. Defaults to all columns of `data`.
    """
    if isinstance(data, pd.DataFrame):
        columns = list(data.columns) if columns is None else list(columns)
        series = [data[col] for col in columns]
        titles = [str(col).upper() for col in columns]
    else:
        arr = np.asarray(data)
        if arr.ndim == 1:
            arr = arr[:, None]
        columns = list(range(arr.shape[1])) if columns is None else list(columns)
        series = [arr[:, col] for col in columns]
        titles = [f"Series {col}" for col in columns]

    if not series:
        return

    # Keep the original 12x8 figure for up to 4 panels; grow for more.
    fig, axes = plt.subplots(len(series), 1, figsize=(12, max(8, 2 * len(series))), squeeze=False)
    for ax, values, title in zip(axes[:, 0], series, titles):
        ax.plot(values)
        ax.set_title(title)
    plt.tight_layout()
    plt.show()

def test_causality_relationships(N=10, max_lag=5, alpha=0.05, seed=None, use_corrected=False):
    """Simulate a coupled network, infer its signed connectivity two ways, and compare to ground truth.

    For every ordered pair ``src -> target`` (src != target), the enhanced TimesFM
    test and the multiple-lag Granger test are run on the full simulated
    series. Causality is assessed in-sample; no train/test split is made.

    Each edge is encoded as:
    - +1: significant, excitatory;
    - -1: significant, inhibitory;
    - 0: not significant, or no effect direction.

    The inferred matrices (row = target, column = source) are re-saved after
    every pair, so a long run can be inspected from disk while it is running.

    Parameters
    ----------
    N : int, default 10
        Number of simulated variables.
    max_lag : int, default 5
        Largest lag tested by both methods.
    alpha : float, default 0.05
        Significance level passed to both tests.
    seed : int, optional
        If given, ``np.random.seed(seed)`` is called before simulating.
        NOTE(review): whether pymou's connectivity generator draws from the
        global NumPy RNG is not visible here -- confirm before relying on it.
    use_corrected : bool, default False
        Decide significance from ``'significant_corrected'`` instead of the raw
        ``'significant'`` flag.

    Returns
    -------
    con_mat : ndarray (N, N)
        Raw connectivity matrix returned by the simulator.
    enhanced_C, granger_C : ndarray (N, N)
        Signed connectivity inferred by each method.
    """
    if seed is not None:
        np.random.seed(seed)
    data, con_mat = generate_test_data_OU(N)
    np.save('con_mat.npy', con_mat)

    significance_key = 'significant_corrected' if use_corrected else 'significant'
    effect_sign = {'excitatory': 1, 'inhibitory': -1}

    def signed_edge(res):
        # 0 unless significant; otherwise the sign of the detected effect.
        return effect_sign.get(res['effect_type'], 0) if res[significance_key] else 0

    enhanced_C = np.zeros((N, N))
    granger_C = np.zeros((N, N))

    print("\n=== IMPROVED CAUSALITY TEST RESULTS ===")
    save_path_LLM = 'enhanced_C.npy'
    save_path_GC = 'granger_C.npy'

    for target in range(N):
        for src in range(N):
            if target == src:
                continue  # Skip self-connections

            print(f"\nTesting {src} → {target}")

            enhanced_res = enhanced_timesfm_causality(data[:, target], data[:, src], max_lag=max_lag, alpha=alpha)
            enhanced_C[target, src] = signed_edge(enhanced_res)
            np.save(save_path_LLM, enhanced_C)
            print(f"[Enhanced] {'SIGNIFICANT' if enhanced_res[significance_key] else 'not significant'}")
            print(f"  p={enhanced_res['p_value']:.4f}, effect={enhanced_res['avg_coefficient']:.3f} ({enhanced_res['effect_type']})")
            print(f"  optimal lag={enhanced_res.get('optimal_lag', 'N/A')}")
            if enhanced_res['residual_analysis']:
                print(f"  residual analysis: corr={enhanced_res['residual_analysis']['correlation']:.3f}")

            granger_res = classical_granger_test_multiple_correction(data[:, target], data[:, src], max_lag=max_lag, alpha=alpha)
            granger_C[target, src] = signed_edge(granger_res)
            np.save(save_path_GC, granger_C)
            print(f"[Granger] {'SIGNIFICANT' if granger_res[significance_key] else 'not significant'}")
            print(f"  p={granger_res['p_value']:.4f}, effect={granger_res['avg_coeff']:.3f} ({granger_res['effect_type']})")

    print("\n=== GROUND TRUTH CONNECTIVITY ===")
    # The simulator's drift is -theta @ X, so a positive theta entry is inhibitory:
    # flip signs to match the +1 excitatory / -1 inhibitory encoding used above.
    raw = np.asarray(con_mat)
    gt_signed = np.zeros(raw.shape)
    gt_signed[raw > 0] = -1
    gt_signed[raw < 0] = 1
    np.fill_diagonal(gt_signed, 0)
    print(gt_signed)

    print("\n=== ENHANCED METHOD CONNECTIVITY ===")
    print(enhanced_C)
    diff_LLM = np.sum((gt_signed - enhanced_C) ** 2)
    print(diff_LLM)

    print("\n=== GRANGER METHOD CONNECTIVITY ===")
    print(granger_C)
    diff_GC = np.sum((gt_signed - granger_C) ** 2)
    print(diff_GC)

    return con_mat, enhanced_C, granger_C


# =============================================
# MAIN EXECUTION
# =============================================
# Run the end-to-end simulation + inference experiment when this cell executes.
# (In a notebook `__name__ == "__main__"` is always true, so a main-guard would
# not prevent this anyway.) NOTE(review): this runs the rolling TimesFM
# forecasts for every ordered pair of series and is slow.
test_causality_relationships()


=== IMPROVED CAUSALITY TEST RESULTS ===

Testing 1 → 0
[Enhanced] not significant
  p=0.0850, effect=0.036 (excitatory)
  optimal lag=3
  residual analysis: corr=-0.007
[Granger] not significant
  p=0.0692, effect=0.037 (excitatory)

Testing 2 → 0


In [None]:
import numpy as np
from sklearn.metrics import confusion_matrix, classification_report, matthews_corrcoef
import matplotlib.pyplot as plt
import seaborn as sns

def calculate_causality_performance(ground_truth_mat, detected_mat, method_name="Method"):
    """Score a detected connectivity matrix against ground truth, ignoring the diagonal.

    An off-diagonal entry counts as a causal link when it is non-zero. Sign
    agreement is measured on the true positives only.

    Parameters
    ----------
    ground_truth_mat : array-like (N, N)
        Ground-truth connectivity.
    detected_mat : array-like (N, N)
        Detected connectivity, typically 0/1 or -1/0/1.
    method_name : str
        Label stored under ``'method'``.

    Returns
    -------
    dict
        Counts (true/false positives/negatives), accuracy, precision, recall,
        specificity, f1_score, mcc, sign_agreement, link totals, and
        ``'confusion_matrix'`` laid out ``[[tn, fp], [fn, tp]]``. Ratios whose
        denominator is zero are reported as 0.
    """
    gt = np.asarray(ground_truth_mat)
    det = np.asarray(detected_mat)

    # Exclude self-connections; boolean indexing keeps row-major order.
    off_diag = ~np.eye(gt.shape[0], dtype=bool)
    gt_flat = gt[off_diag]
    detected_flat = det[off_diag]

    gt_binary = (gt_flat != 0).astype(int)
    detected_binary = (detected_flat != 0).astype(int)

    # Count outcomes directly so degenerate inputs (e.g. no links anywhere)
    # still produce a full 2x2 table instead of failing to unpack.
    gt_pos = gt_binary == 1
    det_pos = detected_binary == 1
    tp = np.sum(gt_pos & det_pos)
    fp = np.sum(~gt_pos & det_pos)
    fn = np.sum(gt_pos & ~det_pos)
    tn = np.sum(~gt_pos & ~det_pos)

    total = tp + tn + fp + fn
    accuracy = (tp + tn) / total if total > 0 else 0
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0

    # Matthews correlation coefficient; 0 when any marginal is empty
    # (the same convention as sklearn.metrics.matthews_corrcoef).
    mcc_denom = np.sqrt(float(tp + fp) * float(tp + fn) * float(tn + fp) * float(tn + fn))
    mcc = float(tp * tn - fp * fn) / mcc_denom if mcc_denom > 0 else 0.0

    sign_agreement = 0
    if tp > 0:
        tp_mask = gt_pos & det_pos
        sign_agreement = np.mean(np.sign(gt_flat[tp_mask]) == np.sign(detected_flat[tp_mask]))

    return {
        'method': method_name,
        'true_positives': tp,
        'false_positives': fp,
        'true_negatives': tn,
        'false_negatives': fn,
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'specificity': specificity,
        'f1_score': f1_score,
        'mcc': mcc,  # Matthews Correlation Coefficient
        'sign_agreement': sign_agreement,
        'total_connections': len(gt_binary),
        'actual_causal_connections': np.sum(gt_binary),
        'detected_causal_connections': np.sum(detected_binary),
        'confusion_matrix': np.array([[tn, fp], [fn, tp]])
    }

def plot_confusion_matrix(conf_matrix, method_name):
    """Draw a 2x2 confusion matrix as an annotated heatmap.

    Rows are the actual class and columns the predicted class, each ordered
    (negative, positive). Cells are annotated as integers (``fmt='d'``).
    """
    fig, ax = plt.subplots(figsize=(8, 6))
    predicted_labels = ['Predicted Negative', 'Predicted Positive']
    actual_labels = ['Actual Negative', 'Actual Positive']
    sns.heatmap(
        conf_matrix,
        annot=True,
        fmt='d',
        cmap='Blues',
        xticklabels=predicted_labels,
        yticklabels=actual_labels,
        ax=ax,
    )
    ax.set_title(f'Confusion Matrix - {method_name}')
    ax.set_ylabel('True Label')
    ax.set_xlabel('Predicted Label')
    plt.show()

def plot_performance_comparison(results_list):
    """Bar-chart accuracy, precision, recall, F1, specificity and MCC for each method.

    ``results_list`` holds dicts with a ``'method'`` key plus one key per
    metric. Panels span [0, 1]; a panel whose values include a negative number
    (only possible for MCC, which lies in [-1, 1]) is extended down to -1.
    """
    methods = [res['method'] for res in results_list]
    metrics = ['accuracy', 'precision', 'recall', 'f1_score', 'specificity', 'mcc']
    metric_names = ['Accuracy', 'Precision', 'Recall', 'F1 Score', 'Specificity', 'MCC']

    fig, axes = plt.subplots(2, 3, figsize=(15, 10))

    for ax, metric, metric_name in zip(axes.flatten(), metrics, metric_names):
        values = [res[metric] for res in results_list]
        bars = ax.bar(methods, values)
        ax.set_title(metric_name)
        ax.set_ylim(-1 if min(values, default=0) < 0 else 0, 1)
        ax.tick_params(axis='x', rotation=45)

        # Value labels: above positive bars, below negative ones.
        for bar, v in zip(bars, values):
            above = v >= 0
            ax.text(bar.get_x() + bar.get_width() / 2, v + 0.01 if above else v - 0.01,
                    f'{v:.3f}', ha='center', va='bottom' if above else 'top')

    plt.tight_layout()
    plt.show()

def print_detailed_results(results):
    """Print a plain-text performance report for one method's metrics dict.

    Expects the keys produced by ``calculate_causality_performance``. Counts are
    printed as-is, ratios to 4 decimals, and a short classification summary
    to 3 decimals.
    """
    r = results
    report_lines = [
        f"\n=== {r['method']} PERFORMANCE ===",
        f"True Positives: {r['true_positives']}",
        f"False Positives: {r['false_positives']}",
        f"True Negatives: {r['true_negatives']}",
        f"False Negatives: {r['false_negatives']}",
        f"Accuracy: {r['accuracy']:.4f}",
        f"Precision: {r['precision']:.4f}",
        f"Recall (Sensitivity): {r['recall']:.4f}",
        f"Specificity: {r['specificity']:.4f}",
        f"F1 Score: {r['f1_score']:.4f}",
        f"MCC: {r['mcc']:.4f}",
        f"Sign Agreement: {r['sign_agreement']:.4f}",
        f"Total connections: {r['total_connections']}",
        f"Actual causal: {r['actual_causal_connections']}",
        f"Detected causal: {r['detected_causal_connections']}",
        "\nClassification Report:",
        f"Precision: {r['precision']:.3f} (How many detected causalities are real)",
        f"Recall:    {r['recall']:.3f} (How many real causalities are detected)",
        f"Specificity: {r['specificity']:.3f} (How many non-causal are correctly identified)",
    ]
    # A single print of newline-joined lines emits exactly one line per entry.
    print("\n".join(report_lines))

def evaluate_causality_performance(ground_truth_mat, enhanced_C, granger_C):
    """Score, report and plot the enhanced-TimesFM and Granger connectivity estimates.

    The ground truth is copied, its non-zero entries are negated and its
    diagonal is zeroed before scoring. Both methods are then scored, printed,
    shown as confusion-matrix heatmaps and compared in a bar chart.

    Returns
    -------
    (enhanced_results, granger_results) : metric dicts from ``calculate_causality_performance``.
    """
    gt_copy = np.copy(ground_truth_mat)
    gt_formatted = np.where(gt_copy != 0, -gt_copy, gt_copy)
    np.fill_diagonal(gt_formatted, 0)

    scored = [
        calculate_causality_performance(gt_formatted, detected, label)
        for detected, label in ((enhanced_C, "Enhanced_TimesFM"), (granger_C, "Granger_Causality"))
    ]

    for res in scored:
        print_detailed_results(res)

    for res, title in zip(scored, ("Enhanced TimesFM", "Granger Causality")):
        plot_confusion_matrix(res['confusion_matrix'], title)

    plot_performance_comparison(scored)

    enhanced_results, granger_results = scored
    return enhanced_results, granger_results

# Example of how to integrate with your existing code
def test_causality_relationships_with_evaluation(N=10, max_lag=5, alpha=0.05):
    """Simulate a network, infer its connectivity with both methods, then score and plot the results.

    Edges are encoded +1 (significant, excitatory), -1 (significant, inhibitory)
    or 0, using each test's raw ``'significant'`` flag. The resulting matrices
    (row = target, column = source) are passed to ``evaluate_causality_performance``.

    Parameters
    ----------
    N : int, default 10
        Number of simulated variables.
    max_lag : int, default 5
        Largest lag tested by both methods.
    alpha : float, default 0.05
        Significance level passed to both tests.

    Returns
    -------
    (enhanced_results, granger_results) : per-method metric dicts.
    """
    data, con_mat = generate_test_data_OU(N)

    effect_sign = {'excitatory': 1, 'inhibitory': -1}
    enhanced_C = np.zeros((N, N))
    granger_C = np.zeros((N, N))
    for target in range(N):
        for src in range(N):
            if target == src:
                continue  # Skip self-connections
            y, x = data[:, target], data[:, src]
            enhanced_res = enhanced_timesfm_causality(y, x, max_lag=max_lag, alpha=alpha)
            granger_res = classical_granger_test_multiple_correction(y, x, max_lag=max_lag, alpha=alpha)
            if enhanced_res['significant']:
                enhanced_C[target, src] = effect_sign.get(enhanced_res['effect_type'], 0)
            if granger_res['significant']:
                granger_C[target, src] = effect_sign.get(granger_res['effect_type'], 0)

    enhanced_results, granger_results = evaluate_causality_performance(con_mat, enhanced_C, granger_C)

    print("\n" + "="*50)
    print("PERFORMANCE SUMMARY")
    print("="*50)
    print(f"{'Metric':<15} {'Enhanced':<10} {'Granger':<10}")
    print(f"{'Accuracy':<15} {enhanced_results['accuracy']:.3f}      {granger_results['accuracy']:.3f}")
    print(f"{'Precision':<15} {enhanced_results['precision']:.3f}      {granger_results['precision']:.3f}")
    print(f"{'Recall':<15} {enhanced_results['recall']:.3f}      {granger_results['recall']:.3f}")
    print(f"{'F1 Score':<15} {enhanced_results['f1_score']:.3f}      {granger_results['f1_score']:.3f}")
    print(f"{'False Positives':<15} {enhanced_results['false_positives']:<10} {granger_results['false_positives']:<10}")
    print(f"{'False Negatives':<15} {enhanced_results['false_negatives']:<10} {granger_results['false_negatives']:<10}")

    return enhanced_results, granger_results

# NOTE(review): inside Jupyter `__name__ == "__main__"` is true, so this guard does
# not skip anything -- the full simulation + evaluation runs whenever this cell executes.
if __name__ == "__main__":
    enhanced_results, granger_results = test_causality_relationships_with_evaluation()