# Apply increasingly complex models to the data

In [1]:
import numpy as np
from pathlib import Path
import os
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import warnings
from sklearn.model_selection import KFold
import os, json

import UPPtoolbox as upp


# ---- Important variables ----
# Working directory; the epoch files are looked up relative to its parent folder.
cwd = Path.cwd()
# SNR levels, one per category. NOTE(review): presumably SNRs[i] belongs to category i
# (same 0..6 keys as `colormap`) — confirm.
SNRs = np.array([-np.inf, *range(-13, -2, 2)])
# Participant identifiers, indexed by the `part` loop variable further below.
SubIDs = '01 02 03 05 06 07 08 09 11 12 13 14 15 17 19 20 22 23 24 25'.split()
# RGB colour for each category key.
colormap = dict(enumerate([
    (0, 0, 0),
    (0, 0.25, 1),
    (0, 0.9375, 1),
    (0, 0.91, 0.1),
    (1, 0.6, 0),
    (1, 0, 0),
    (0.8, 0, 0),
]))

# Model comparison on EEG data

In [2]:
def crossval_loglik(state_series, input_series, n_splits=5, n_loops=2,
                    input_start_index=75, input_stop_index=100, random_state=0):
    """Cross-validate the linear, gain-modulation and non-linear models on one dataset.

    For each of ``n_splits`` shuffled K-fold splits of the trial indices, three
    models are fitted on the training trials:

    * StratifiedLinear via ``upp.fitting_tools.clever_fit_linear``;
    * StratifiedGainModul via ``clever_fit_gainmodul``, started from the linear fit;
    * StratifiedNonLinear1 via ``clever_fit_nonlinear1``, started from the linear fit.

    Each model is then scored on the held-out trials. For categories ``1..n-1``
    only the window ``[input_start_index, input_stop_index)`` along axis 1 is
    scored; category 0 is scored over its full length.

    Parameters
    ----------
    state_series, input_series : dict
        Per-category arrays indexed by trial along axis 0 (and sliced along
        axis 1 for categories >= 1).
        NOTE(review): keys are assumed to be exactly ``0 .. n_categories - 1``.
    n_splits : int, default 5
        Number of cross-validation folds.
    n_loops : int, default 2
        Forwarded to every ``clever_fit_*`` routine.
    input_start_index, input_stop_index : int, default 75 and 100
        Window forwarded to the fitting routines and used to slice the test data.
    random_state : int, default 0
        Seed of the shuffled ``KFold`` split, so folds are reproducible.

    Returns
    -------
    test_lls_linear, test_lls_nonlinear, test_lls_gainmodul : np.ndarray
        One held-out log-likelihood per fold (length ``n_splits``).
    cv_params : list
        One entry per fold:
        ``[linear.get_params(), nonlinear.get_params(), gainmodul.get_params()]``.

    Raises
    ------
    ValueError
        If the input window is empty or starts before index 0.
    """
    if not 0 <= input_start_index < input_stop_index:
        raise ValueError(
            f"input window [{input_start_index}, {input_stop_index}) is empty or negative"
        )

    categories = list(range(len(state_series)))

    kf = KFold(n_splits=n_splits, shuffle=True, random_state=random_state)
    # NOTE(review): folds are drawn over the *smallest* per-category trial count, so
    # trials beyond that count in larger categories are never used for fitting or
    # testing. Confirm this pairing of trial indices across categories is intended.
    n_trials = min(state_series[cat].shape[0] for cat in categories)
    trial_indices = np.arange(n_trials)
    window = slice(input_start_index, input_stop_index)

    test_lls_linear, test_lls_nonlinear, test_lls_gainmodul = [], [], []
    cv_params = []

    for train_idx, test_idx in kf.split(trial_indices):
        # ---- Split every category with the same trial indices ----
        state_train = {cat: state_series[cat][train_idx] for cat in categories}
        input_train = {cat: input_series[cat][train_idx] for cat in categories}
        state_test = {cat: state_series[cat][test_idx] for cat in categories}
        input_test = {cat: input_series[cat][test_idx] for cat in categories}

        fit_kwargs = dict(
            state_train=state_train,
            input_train=input_train,
            n_loops=n_loops,
            input_start_index=input_start_index,
            input_stop_index=input_stop_index,
        )

        # ---- Fit: linear first; the two richer models start from it ----
        linear = upp.fitting_tools.clever_fit_linear(**fit_kwargs)
        gainmodul = upp.fitting_tools.clever_fit_gainmodul(linear_prefitted=linear, **fit_kwargs)
        nonlinear = upp.fitting_tools.clever_fit_nonlinear1(linear_prefitted=linear, **fit_kwargs)

        # ---- Isolate the segment where the stimulation is assumed constant ----
        # Categories >= 1 are cut to the window; category 0 is kept whole and is
        # inserted last into the dicts.
        state_eval = {cat: state_test[cat][:, window] for cat in categories[1:]}
        state_eval[0] = state_test[0]
        input_eval = {cat: input_test[cat][:, window] for cat in categories[1:]}
        input_eval[0] = input_test[0]

        # ---- Evaluate on held-out trials ----
        # NOTE(review): the linear model is scored with `loglikelihood_ukf` while the
        # other two use `loglikelihood`; confirm both give comparable values.
        test_lls_linear.append(linear.loglikelihood_ukf(state_eval, input_eval))
        test_lls_nonlinear.append(nonlinear.loglikelihood(state_eval, input_eval))
        test_lls_gainmodul.append(gainmodul.loglikelihood(state_eval, input_eval))

        # ---- Save parameters (same model order as the returned arrays) ----
        cv_params.append([linear.get_params(), nonlinear.get_params(), gainmodul.get_params()])

    return (
        np.array(test_lls_linear),
        np.array(test_lls_nonlinear),
        np.array(test_lls_gainmodul),
        cv_params,
    )

In [None]:
# Run configuration: `version` tags the checkpoint filename; `task` selects the
# `myEpochs_<task>` folder the participants' epochs are read from.
version, task = 18, 'Active'
def _json_default(obj):
    """``json.dump`` fallback: convert NumPy arrays and scalars to built-in Python types."""
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def _save_checkpoint(checkpoint_path, checkpoint):
    """Write ``checkpoint`` as JSON via a temporary file and an atomic rename.

    An interruption during the write leaves the previous checkpoint intact instead
    of a truncated, unreadable file.
    """
    tmp_path = f"{checkpoint_path}.tmp"
    with open(tmp_path, "w") as f:
        json.dump(checkpoint, f, indent=2, default=_json_default)
    os.replace(tmp_path, checkpoint_path)


def run_recovery(checkpoint_path="Application_v1.json", save=True):
    """Cross-validate the three models for every participant, with checkpoint/resume.

    Participants are processed from the last index in ``SubIDs`` down to 0. For
    each one:

    * its epochs are loaded with ``upp.STG``;
    * ``crossval_loglik`` is run;
    * the model with the highest mean held-out log-likelihood is recorded.

    If ``checkpoint_path`` exists, its results are loaded and participants already
    recorded there are skipped.

    Parameters
    ----------
    checkpoint_path : str or path-like
        JSON checkpoint file to resume from and (if ``save``) write to.
    save : bool
        Rewrite the checkpoint after each participant.

    Returns
    -------
    results_cv : list of str
        Winning model name per processed participant, in processing order.
    detailed : list of dict
        Per-participant fold and mean log-likelihoods plus the winning model.
    """
    if os.path.exists(checkpoint_path):
        with open(checkpoint_path, "r") as f:
            checkpoint = json.load(f)
        results_cv = checkpoint["results_cv"]
        detailed = checkpoint["detailed"]
        current_part = checkpoint["part"]
        fitted_params = checkpoint["fitted_params"]
        print(f"Latest participant done: {current_part}.")
    else:
        results_cv, detailed, fitted_params = [], [], []

    # Skip participants already stored in the checkpoint, so a resumed run neither
    # refits them nor appends duplicate entries to the result lists.
    done_parts = {entry["part"] for entry in detailed}

    # Input definition (same for all categories & trials): 75 zeros, 25 ones, 150 zeros.
    one_input = np.concatenate((np.zeros(75), np.ones(25), np.zeros(150)))
    candidates = ['Linear', 'NonLinear1', 'GainModulation']

    for part in tqdm(range(len(SubIDs) - 1, -1, -1)):
        if part in done_parts:
            continue

        data_ref = f'myEpochs_{task}/Epoch_{SubIDs[part]}-epo.fif'
        epochs_file = cwd.parents[0] / data_ref
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=RuntimeWarning)
            state_series = upp.STG(epochs_file, tmin=300, tmax=500)
        categories = list(range(len(state_series)))

        # One copy of the stimulus input per trial, for every category.
        input_series = {
            cat: np.tile(one_input, (state_series[cat].shape[0], 1)) for cat in categories
        }

        # Arrays of one value per CV fold
        lls_lin, lls_nonlin, lls_gainmod, cv_params = crossval_loglik(state_series, input_series)
        ll_lin_mean = float(np.mean(lls_lin))
        ll_nonlin_mean = float(np.mean(lls_nonlin))
        ll_gainmod_mean = float(np.mean(lls_gainmod))
        fitted_params.append(cv_params)

        # ---- CV criterion: highest mean held-out log-likelihood wins ----
        best_cv = candidates[int(np.argmax([ll_lin_mean, ll_nonlin_mean, ll_gainmod_mean]))]
        results_cv.append(best_cv)
        print(f'{part} looks like a {best_cv}')

        # ---- Save detailed info ----
        detailed.append({
            "part": part,
            "ll_linear_folds": lls_lin.tolist(),
            "ll_nonlinear_folds": lls_nonlin.tolist(),
            "ll_gainmod_folds": lls_gainmod.tolist(),
            "ll_linear_mean": ll_lin_mean,
            "ll_nonlinear_mean": ll_nonlin_mean,
            "ll_gainmod_mean": ll_gainmod_mean,
            "best_cv": best_cv,
        })

        # ---- Save checkpoint ----
        if save:
            _save_checkpoint(
                checkpoint_path,
                {"results_cv": results_cv, "detailed": detailed, "part": part, "fitted_params": fitted_params},
            )

    print("✅ All datasets processed and checkpoint saved.")
    return results_cv, detailed



# Checkpoint named after the analysis version and task; reused when the cell is rerun.
checkpoint_file = f"Application_EEG_v{version}_{task}_late.json"
results_cv, detailed = run_recovery(checkpoint_path=checkpoint_file, save=True)

  0%|          | 0/20 [00:00<?, ?it/s]

19 looks like a NonLinear1
