Testing effective connectivity via time-series prediction with foundation models. The initial case uses the pretrained model without fine-tuning.

In [2]:
pip install timesfm sktime

Collecting timesfm
  Downloading timesfm-1.2.9-py3-none-any.whl.metadata (14 kB)
Collecting sktime
  Downloading sktime-0.37.0-py3-none-any.whl.metadata (34 kB)
Collecting einshape>=1.0.0 (from timesfm)
  Downloading einshape-1.0-py3-none-any.whl.metadata (706 bytes)
Collecting utilsforecast>=0.1.10 (from timesfm)
  Downloading utilsforecast-0.2.12-py3-none-any.whl.metadata (7.6 kB)
Collecting joblib<1.5,>=1.2.0 (from sktime)
  Downloading joblib-1.4.2-py3-none-any.whl.metadata (5.4 kB)
Collecting scikit-base<0.13.0,>=0.6.1 (from sktime)
  Downloading scikit_base-0.12.2-py3-none-any.whl.metadata (8.8 kB)
Collecting InquirerPy==0.3.4 (from huggingface_hub[cli]>=0.23.0->timesfm)
  Downloading InquirerPy-0.3.4-py3-none-any.whl.metadata (8.1 kB)
Collecting pfzy<0.4.0,>=0.3.1 (from InquirerPy==0.3.4->huggingface_hub[cli]>=0.23.0->timesfm)
  Downloading pfzy-0.3.4-py3-none-any.whl.metadata (4.9 kB)
Downloading timesfm-1.2.9-py3-none-any.whl (54 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [3]:
import matplotlib.pyplot as plt
import timesfm
import pandas as pd
from sktime.performance_metrics.forecasting import mean_absolute_percentage_error, mean_absolute_scaled_error, mean_absolute_error
from sktime.forecasting.naive import NaiveForecaster
from sktime.forecasting.base import ForecastingHorizon
from sklearn.linear_model import LinearRegression
from sktime.forecasting.compose import make_reduction
from sktime.forecasting.statsforecast import StatsForecastAutoARIMA, StatsForecastAutoETS
import os
from sklearn.preprocessing import MinMaxScaler, StandardScaler
from sklearn.metrics import mean_squared_error
import numpy as np
from scipy.stats import f
from sklearn.linear_model import Ridge, LinearRegression
import warnings
warnings.filterwarnings("ignore", message="possible convergence problem")

 See https://github.com/google-research/timesfm/blob/master/README.md for updated APIs.
Loaded PyTorch TimesFM, likely because python version is 3.11.12 (main, Apr  9 2025, 08:55:54) [GCC 11.4.0].


In [4]:
# Initialize the pretrained foundation model as-is (no fine-tuning).
# horizon_len is set to 128 here; forecasts are later truncated to the test length.
tfm_hparams = timesfm.TimesFmHparams(
    backend="gpu",
    per_core_batch_size=32,
    horizon_len=128,
)
tfm_checkpoint = timesfm.TimesFmCheckpoint(
    huggingface_repo_id="google/timesfm-1.0-200m-pytorch",
)
tfm = timesfm.TimesFm(hparams=tfm_hparams, checkpoint=tfm_checkpoint)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

README.md:   0%|          | 0.00/5.58k [00:00<?, ?B/s]

torch_model.ckpt:   0%|          | 0.00/814M [00:00<?, ?B/s]

.gitattributes:   0%|          | 0.00/1.52k [00:00<?, ?B/s]

In [5]:
##Load time series
def load_data(path, verbose = False):
    """Read a headerless, tab-separated file into a DataFrame.

    Parameters
    ----------
    path : location of the TSV file, passed straight to ``pd.read_csv``.
    verbose : when true, print the column labels, the number of rows and
        the first rows of the loaded frame.

    Returns
    -------
    pandas.DataFrame whose columns are labelled 0..k-1 (no header row is read).
    """
    frame = pd.read_csv(path, sep='\t', header=None)
    if not verbose:
        return frame
    # Same three diagnostics, in the same order: labels, length, preview.
    for summary in (frame.columns, len(frame), frame.head()):
        print(summary)
    return frame
# Time-series file for the control session (subject/atlas as encoded in the filename).
CONTROL_TIMESERIES_PATH = 'sub-CON001_ses-control_task-rest_space-MNI152NLin2009cAsym_atlas-Schaefer117_timeseries.tsv'
control_data = load_data(CONTROL_TIMESERIES_PATH)

In [6]:
# convert the structure to be read by the model
def split_train_test(data, break_index):
    """Split ``data`` in two by comparing its index against ``break_index``.

    Rows whose index label is strictly below ``break_index`` form the first
    (training) part; every other row goes to the second (test) part.

    Returns
    -------
    (train, test) tuple of objects of the same type as ``data``.
    """
    before_break = data.index < break_index
    train_part = data[before_break]
    test_part = data[~before_break]
    return train_part, test_part
# Rows with index < TRAIN_END_INDEX are used for fitting; the remaining rows are held out.
TRAIN_END_INDEX = 300
control_data_train, control_data_test = split_train_test(control_data, TRAIN_END_INDEX)

def convert_to_timefm(data):
    """Return the columns of ``data`` as a list, one entry per column label.

    The entries keep the column order of ``data.columns`` and are whatever
    ``data[label]`` yields for each label.
    """
    return [data[column_label] for column_label in data.columns]
# Training data as a list with one series per column (the list form built for tfm.forecast).
control_data_train_for_time_fm = convert_to_timefm(control_data_train)


In [None]:
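# Individual prediction (sanity check only; not needed for the analysis below)
# freq=[0] selects TimesFM's high-frequency category (roughly daily or finer sampling); it is not a seasonality switch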
#predicted = tfm.forecast(control_data_train_for_time_fm, freq=[0] * len(control_data_train.columns))[0]
#predicted = tfm.forecast(control_data_train_for_time_fm, freq=[0] * len(control_data_train_for_time_fm.columns))[0] #adding this will check only one series [0]

(100, 128)
100


In [9]:

# LLM Granger-like Causality (for preprocessed data)
# Forecast horizon = number of held-out rows (length of the first test column).
new_horizon = len(control_data_test.iloc[:, 0])

def safe_timesfm_forecast(series_data, new_horizon):
    """Forecast one series with the module-level ``tfm`` model, with an AR(1) fallback.

    Parameters
    ----------
    series_data : array-like or object exposing ``.values`` (e.g. a pandas Series).
        It is flattened and non-finite samples are dropped before forecasting.
    new_horizon : number of leading forecast steps to keep.

    Returns
    -------
    (forecast, method) where ``method`` is ``"timesfm"`` when the model call
    succeeded and ``"ar1_fallback"`` when any exception (including "no valid
    data") made the function fall back to ``ar1_forecast_simple``.
    """
    # Bound up front: if the conversion below raises, the except handler still
    # has a defined (empty) array to hand to the fallback instead of hitting
    # an UnboundLocalError.
    data = np.array([])
    try:
        if hasattr(series_data, 'values'):
            data = series_data.values
        else:
            data = np.array(series_data)

        data = data.flatten()
        data = data[np.isfinite(data)]  # drop NaN / inf samples

        if len(data) == 0:
            raise ValueError("No valid data")

        # TimesFM forecast; [0] is presumably the point forecast of shape
        # (n_series, horizon_len) -- confirm against the TimesFM API.
        predicted = tfm.forecast([data], freq=[0])[0]
        # Single input series -> row 0; keep only the requested steps.
        return predicted[0, :new_horizon], "timesfm"

    except Exception as e:
        print(f"    TimesFM failed: {e}, using AR(1) fallback")
        return ar1_forecast_simple(data, new_horizon), "ar1_fallback"

def ar1_forecast_simple(data, horizon):
    """Forecast ``horizon`` steps with a least-squares AR(1) model.

    The model is ``y_t = c + phi * y_{t-1}``, fitted by ordinary least squares
    on consecutive pairs of ``data`` and iterated forward from the last sample.

    Parameters
    ----------
    data : 1-D array-like of numeric samples.
    horizon : number of steps to forecast.

    Returns
    -------
    numpy.ndarray of length ``horizon``. With fewer than 3 samples, or when
    the lagged values have no variance (constant history), the forecast is
    the sample mean repeated (0 for empty input).
    """
    data = np.asarray(data, dtype=float).ravel()
    if len(data) < 3:
        return np.full(horizon, np.mean(data) if len(data) > 0 else 0)

    # Fit AR(1): y_t = c + phi*y_{t-1} + eps_t
    lagged = data[:-1]
    current = data[1:]
    lagged_centered = lagged - lagged.mean()
    slope_denominator = np.dot(lagged_centered, lagged_centered)

    # A constant (or NaN-contaminated) history has no usable slope; the
    # comparison is written so NaN also lands here instead of propagating 0/0.
    if not slope_denominator > 0:
        return np.full(horizon, np.mean(data))

    # OLS slope with the same normalisation in numerator and denominator
    # (mixing ddof=1 covariance with ddof=0 variance inflates phi by n/(n-1)).
    phi = np.dot(lagged_centered, current - current.mean()) / slope_denominator
    c = current.mean() - phi * lagged.mean()

    # Multi-step forecast: feed each prediction back in as the next input.
    forecast = np.empty(horizon)
    last_val = data[-1]
    for step in range(horizon):
        last_val = c + phi * last_val
        forecast[step] = last_val

    return forecast

def compute_lag1_covariate_effect(target_series, covariate_series):
    """Estimate the in-sample lag-1 effect of a covariate on a target series.

    Two least-squares models are fitted on the same samples:

    * restricted: ``target[t] = a + b * target[t-1]``
    * full:       ``target[t] = a + b * target[t-1] + g * covariate[t-1]``

    Both series are flattened and truncated (from the start) to the shorter
    length before the lagged design matrices are built.

    Returns
    -------
    (covariate_coeff, rss_improvement, r2_improvement, status)
        ``covariate_coeff`` is ``g`` from the full model, ``rss_improvement``
        is ``RSS_restricted - RSS_full`` and ``r2_improvement`` the matching
        difference in R². ``status`` is ``"success"``, ``"insufficient_data"``
        (fewer than 15 aligned samples, all numbers 0) or ``"failed_<error>"``
        (fit raised, all numbers 0). When the target has zero variance the
        R² improvement is reported as 0.0 rather than NaN.
    """
    target_vals = np.array(target_series).flatten()
    cov_vals = np.array(covariate_series).flatten()

    # Align series length
    min_len = min(len(target_vals), len(cov_vals))
    if min_len < 15:  # Need reasonable sample size
        return 0, 0, 0, "insufficient_data"

    target_vals = target_vals[:min_len]
    cov_vals = cov_vals[:min_len]

    # Create lag-1 relationship
    y = target_vals[1:]              # target from t=1 onwards
    x_target_lag = target_vals[:-1]  # target[t-1]
    x_cov_lag = cov_vals[:-1]        # covariate[t-1]

    try:
        intercept = np.ones(len(y))

        # Restricted model: target[t] = a + b*target[t-1] + eps
        X_restricted = np.column_stack([intercept, x_target_lag])
        coeffs_restricted = np.linalg.lstsq(X_restricted, y, rcond=None)[0]
        rss_restricted = np.sum((y - X_restricted @ coeffs_restricted) ** 2)

        # Full model: target[t] = a + b*target[t-1] + g*covariate[t-1] + eps
        X_full = np.column_stack([intercept, x_target_lag, x_cov_lag])
        coeffs_full = np.linalg.lstsq(X_full, y, rcond=None)[0]
        rss_full = np.sum((y - X_full @ coeffs_full) ** 2)

        # Extract covariate coefficient
        covariate_coeff = coeffs_full[2]
        rss_improvement = rss_restricted - rss_full

        # R-squared improvement. A constant target has tss == 0, which would
        # turn both R² values into inf/NaN; there is no variance to explain,
        # so report no improvement instead.
        tss = np.sum((y - np.mean(y)) ** 2)
        if tss > 0:
            r2_restricted = 1 - (rss_restricted / tss)
            r2_full = 1 - (rss_full / tss)
            r2_improvement = r2_full - r2_restricted
        else:
            r2_improvement = 0.0

        return covariate_coeff, rss_improvement, r2_improvement, "success"

    except Exception as e:
        return 0, 0, 0, f"failed_{e}"

def forecast_with_lag1_covariate(target_series, covariate_series, new_horizon):
    """Adjust the TimesFM forecast of a target by a lag-1 covariate effect.

    Steps:

    1. Obtain the baseline forecast from ``safe_timesfm_forecast``.
    2. Estimate the in-sample lag-1 covariate coefficient with
       ``compute_lag1_covariate_effect``.
    3. If that estimate is unusable (status not "success", |coeff| < 0.01, or
       R² improvement not at least 0.001) return the baseline unchanged.
    4. Otherwise extrapolate the covariate linearly from its last value using
       half of its mean recent difference (last 5 samples), multiply by the
       coefficient and a 0.95**step decay, and add that to the baseline.

    Returns
    -------
    (forecast, method, cov_coeff, r2_improvement) where ``method`` is
    ``"no_covariate_effect"``, ``"covariate_adjusted"`` or ``"fallback"``
    (any exception; baseline is recomputed and coeff/improvement are 0).
    """
    try:
        # Get baseline TimesFM forecast (the method label is not needed here)
        baseline_forecast, _ = safe_timesfm_forecast(target_series, new_horizon)

        # Compute lag-1 covariate effect
        cov_coeff, _, r2_improvement, effect_status = compute_lag1_covariate_effect(
            target_series, covariate_series
        )

        # "not >=" instead of "<" so a NaN improvement is also treated as no effect.
        if effect_status != "success" or abs(cov_coeff) < 0.01 or not r2_improvement >= 0.001:
            return baseline_forecast, "no_covariate_effect", cov_coeff, r2_improvement

        cov_vals = np.array(covariate_series).flatten()

        # Recent trend of the covariate: mean first difference over the last 5 samples
        if len(cov_vals) >= 5:
            recent_trend = np.mean(np.diff(cov_vals[-5:]))
        else:
            recent_trend = 0

        current_cov = cov_vals[-1] if len(cov_vals) > 0 else 0

        # Per-step adjustment: projected covariate (last value + damped linear
        # trend) times the coefficient, with the effect decaying by 0.95 per step.
        adjustments = [
            cov_coeff * (current_cov + step * recent_trend * 0.5) * (0.95 ** step)
            for step in range(new_horizon)
        ]

        # Combine baseline with covariate adjustments
        adjusted_forecast = baseline_forecast + np.array(adjustments)

        return adjusted_forecast, "covariate_adjusted", cov_coeff, r2_improvement

    except Exception as e:
        print(f"    Covariate adjustment failed: {e}")
        baseline_forecast, _ = safe_timesfm_forecast(target_series, new_horizon)
        return baseline_forecast, "fallback", 0, 0

def enhanced_granger_test_fmri(target_test, pred_restricted, pred_full):
    """Compare restricted vs. full forecasts of the same held-out series.

    All three inputs are truncated to their common length. Two tests are run
    on the forecast errors:

    * an F-test on the residual sums of squares with (1, n - 2) degrees of
      freedom, and
    * a one-sided paired t-test on absolute errors (are the full model's
      errors smaller?).

    The result with the smaller p-value is returned.

    NOTE(review): reporting the smaller of two p-values makes the test
    anti-conservative; treat the returned p-value as a screening score unless
    that selection is corrected for.

    Returns
    -------
    (statistic, p_value, method) where ``method`` is one of
    ``"insufficient_data"`` (fewer than 8 points; statistic NaN, p 1.0),
    ``"no_improvement"`` (full RSS not below 99% of restricted RSS; 0.0, 1.0),
    ``"paired_ttest_abs_errors"``, ``"f_test"`` or ``"f_test_only"`` (the
    t-test raised).
    """
    from scipy.stats import ttest_rel

    min_len = min(len(target_test), len(pred_restricted), len(pred_full))
    if min_len < 8:  # Need minimum observations
        return np.nan, 1.0, "insufficient_data"

    # Convert first, then slice, so truncation is positional for any input type.
    y_true = np.asarray(target_test)[:min_len]
    y_restricted = np.asarray(pred_restricted)[:min_len]
    y_full = np.asarray(pred_full)[:min_len]

    # Calculate prediction errors
    errors_restricted = y_true - y_restricted
    errors_full = y_true - y_full

    # Method 1: Standard F-test
    RSS_restricted = np.sum(errors_restricted ** 2)
    RSS_full = np.sum(errors_full ** 2)

    # Check for improvement
    if RSS_full >= RSS_restricted * 0.99:  # Allow for small numerical differences
        return 0.0, 1.0, "no_improvement"

    # F-test calculation
    n = min_len
    k = 1  # One additional parameter (covariate effect)

    if RSS_full > 0:
        f_stat = ((RSS_restricted - RSS_full) / k) / (RSS_full / (n - 2))
        # Survival function instead of 1 - cdf: the subtraction rounds to
        # exactly 0.0 once the tail probability drops below ~1e-16.
        p_val_f = f.sf(f_stat, k, n - 2) if f_stat > 0 else 1.0
    else:
        f_stat = 0.0
        p_val_f = 1.0

    # Method 2: Paired t-test on absolute errors
    try:
        abs_errors_restricted = np.abs(errors_restricted)
        abs_errors_full = np.abs(errors_full)

        # One-sided paired t-test: are full model errors smaller?
        t_stat, p_val_t = ttest_rel(abs_errors_restricted, abs_errors_full,
                                   alternative='greater')
    except Exception:
        # Narrowed from a bare except so KeyboardInterrupt still stops a long run.
        return f_stat, p_val_f, "f_test_only"

    # Use the more significant result (a NaN t-test p-value falls through to the F-test)
    if p_val_t < p_val_f:
        return t_stat, p_val_t, "paired_ttest_abs_errors"
    return f_stat, p_val_f, "f_test"

# === Main execution ===
# For every ordered (covariate -> target) pair: forecast the target with and
# without the covariate adjustment, compare both against the held-out data and
# collect one result record per pair.
#
# NOTE: forecast_with_lag1_covariate recomputes the TimesFM baseline of the
# target on every call, so the inner loop issues one model forecast per pair.
# NOTE(review): p-values below are thresholded without any multiple-comparison
# correction across the tested pairs.

# Set to False to print only promising pairs (p < 0.2 or more than 1% MSE improvement).
PRINT_ALL_PAIRS = True


def _print_result_group(results):
    """Print one "covariate -> target" line (p-value and MSE improvement) per result."""
    for r in results:
        print(f"  {r['covariate']} -> {r['target']}: p={r['p_value']:.4f}, "
              f"improve={r['mse_improvement_pct']:.2f}%")


print("=== fMRI-Optimized LLM Granger Causality (Preprocessed Data) ===")
print(f"Forecast horizon: {new_horizon} time points")
print("Using lag-1 relationships for fMRI temporal dynamics\n")

all_results = []

for target_index, target_name in enumerate(control_data.columns):
    print(f"=== Target: {target_name} ===")

    target_train = control_data_train.iloc[:, target_index]
    target_test = control_data_test.iloc[:, target_index]

    # Baseline forecast (restricted model); skip the target if it cannot be scored
    try:
        pred_baseline, baseline_method = safe_timesfm_forecast(target_train, new_horizon)
        mse_baseline = mean_squared_error(target_test, pred_baseline)
        print(f"Baseline MSE: {mse_baseline:.6f}")
    except Exception as e:
        print(f"Baseline forecast failed: {e}")
        continue

    target_results = []

    for cov_index, cov_name in enumerate(control_data.columns):
        if cov_index == target_index:
            continue

        covariate_train = control_data_train.iloc[:, cov_index]

        # Full model forecast with covariate
        pred_full, method_used, cov_coeff, r2_improvement = forecast_with_lag1_covariate(
            target_train, covariate_train, new_horizon
        )

        # Calculate performance metrics (positive percentage = full model is better)
        mse_full = mean_squared_error(target_test, pred_full)
        mse_improvement_pct = (mse_baseline - mse_full) / mse_baseline * 100

        # Statistical test
        test_stat, p_value, test_method = enhanced_granger_test_fmri(
            target_test, pred_baseline, pred_full
        )

        result = {
            'target': target_name,
            'covariate': cov_name,
            'mse_baseline': mse_baseline,
            'mse_full': mse_full,
            'mse_improvement_pct': mse_improvement_pct,
            'covariate_coeff': cov_coeff,
            'r2_improvement': r2_improvement,
            'test_stat': test_stat,
            'p_value': p_value,
            'test_method': test_method,
            'forecast_method': method_used
        }

        target_results.append(result)
        all_results.append(result)

        # Per-pair detail line (all pairs, or only the promising ones)
        if PRINT_ALL_PAIRS or p_value < 0.2 or mse_improvement_pct > 1:
            print(f"  {cov_name}: p={p_value:.4f}, MSE_improve={mse_improvement_pct:.2f}%, "
                  f"coeff={cov_coeff:.4f}, R²_improve={r2_improvement:.4f}")

    print()

# Summary of all results
print("=== SUMMARY ===")

# Sort all results by p-value
all_results.sort(key=lambda x: x['p_value'])

# Bucket results by p-value band
significant_05 = [r for r in all_results if r['p_value'] < 0.05]
significant_10 = [r for r in all_results if 0.05 <= r['p_value'] < 0.10]
marginal_20 = [r for r in all_results if 0.10 <= r['p_value'] < 0.20]

print(f"Significant results (p < 0.05): {len(significant_05)}")
_print_result_group(significant_05)

print(f"\nMarginal results (0.05 ≤ p < 0.10): {len(significant_10)}")
_print_result_group(significant_10)

print(f"\nWeaker evidence (0.10 ≤ p < 0.20): {len(marginal_20)}")
_print_result_group(marginal_20)

# Show best forecast improvements regardless of p-value
print("\nBest forecast improvements (top 5):")
best_improvements = sorted(all_results, key=lambda x: x['mse_improvement_pct'], reverse=True)[:5]
for i, r in enumerate(best_improvements, 1):
    print(f"  {i}. {r['covariate']} -> {r['target']}: {r['mse_improvement_pct']:.2f}% improve, "
          f"p={r['p_value']:.4f}")

print(f"\nTotal relationships tested: {len(all_results)}")

=== fMRI-Optimized LLM Granger Causality (Preprocessed Data) ===
Forecast horizon: 33 time points
Using lag-1 relationships for fMRI temporal dynamics

=== Target: 0 ===
Baseline MSE: 3.763006
  1: p=1.0000, MSE_improve=-0.85%, coeff=-0.0762, R²_improve=0.0018
  2: p=1.0000, MSE_improve=-7.22%, coeff=0.0748, R²_improve=0.0017
  3: p=1.0000, MSE_improve=-11.11%, coeff=0.2442, R²_improve=0.0174
  4: p=0.0832, MSE_improve=7.58%, coeff=0.1037, R²_improve=0.0049


KeyboardInterrupt: 