In [None]:
import os
import numpy as np
from scipy import io
from pathlib import Path

from modules import preproc, rhino, source_recon, parcellation, hmm, utils
import mne
from osl_dynamics import inference, analysis
from osl_dynamics.utils import plotting
from osl_dynamics.data import Data

In [None]:
# Preprocessing and source reconstruction.
#
# For every subject (01-33) and task (01-06) whose preprocessed recording exists
# under prep_data/, run coregistration, the forward model, the LCMV beamformer and
# the parcellation, then write the parcel time courses to
# out_data/<out_id>/lcmv-parc-raw.fif.
#
# Sessions are processed best-effort: a failure in one session is reported and the
# loop moves on to the next one.
for subj_id in [f"{i:02d}" for i in range(1, 34)]:
    for task_id in [f"{i:02d}" for i in range(1, 7)]:

        out_id = f"sub-{subj_id}_ses-01_task-stim{task_id}"
        # Build the input path once so the existence check, the OSL filenames and
        # the raw reader can never disagree about which file is being processed.
        fif_file = f"prep_data/{out_id}_eeg.fif"

        if not os.path.exists(fif_file):
            print(f"[skip]: {fif_file}")
            continue

        try:
            fns = utils.OSLFilenames(
                outdir="out_data",
                id=out_id,
                preproc_file=fif_file,
                surfaces_dir="mni152_surfaces",  # replace with the 'outdir' used in rhino.extract_surfaces if you have your own structural
            )

            rhino.extract_polhemus_from_fif(fns, include_eeg_as_headshape=True)
            rhino.coregister(
                fns,
                allow_smri_scaling=True,  # set to False if using a real structural
            )

            rhino.forward_model(fns, model="Triple Layer", gridstep=8, meg=False, eeg=True)

            raw = mne.io.read_raw_fif(fif_file, preload=False)

            source_recon.lcmv_beamformer(fns, raw, chantypes="eeg", rank={"eeg": 50})
            voxel_data, voxel_coords = source_recon.apply_lcmv_beamformer(fns, raw)

            parcellation_file = "fmri_d100_parcellation_with_PCC_reduced_2mm_ss5mm_ds8mm.nii.gz"

            parcel_data = parcellation.parcellate(
                fns,
                voxel_data,
                voxel_coords,
                method="spatial_basis",
                orthogonalisation="symmetric",
                parcellation_file=parcellation_file,
            )

            parcellation.save_as_fif(
                parcel_data,
                raw,
                extra_chans="stim",
                filename=f"out_data/{out_id}/lcmv-parc-raw.fif",
            )

        except Exception as e:
            # Include the exception type: str(e) alone is empty or cryptic for many
            # exceptions (e.g. a KeyError prints only the missing key).
            print(f"[error] {out_id} with: {type(e).__name__}: {e}")
            continue

In [None]:
import os

# Directory the preprocessing cell writes its results to. It is relative, like the
# "out_data/..." paths used there; a leading "/" would point at the filesystem root.
base_dir = "out_data"

# loop range -- kept identical to the preprocessing loop (subjects 01-33, tasks 01-06)
# so that no processed subject is silently left out of the analysis.
subj_range = range(1, 34)
task_range = range(1, 7)
target_file_name = "lcmv-parc-raw.fif"
# -----------------------

# Candidate output file for every subject/task pair, in subject-major order.
candidate_paths = [
    os.path.join(base_dir, f"sub-{s:02d}_ses-01_task-stim{t:02d}", target_file_name)
    for s in subj_range
    for t in task_range
]

# Keep only the sessions that were actually produced by the preprocessing step.
fif_files = [path for path in candidate_paths if os.path.exists(path)]

if not fif_files:
    # Fail here with a clear message instead of handing an empty list to Data.
    raise FileNotFoundError(
        f"No '{target_file_name}' files found under '{base_dir}'; "
        "run the preprocessing cell first or check base_dir."
    )

data = Data(
    fif_files,
    picks="misc",
    reject_by_annotation="omit",
    n_jobs=32,
)

In [4]:
import pickle

# Restore a previously pickled object from disk and bind it to `data`, replacing
# whatever `data` currently refers to.
# NOTE: unpickling can execute arbitrary code, so only load a file you created yourself.
with open("prepared_data.pkl", "rb") as pkl_file:
    data = pickle.load(pkl_file)

In [None]:
# Pass `data` through hmm.prepare_data_for_canonical_hmm for the 38ROI_Giles
# parcellation and rebind `data` to the result.
# NOTE(review): because the result overwrites its own input, running this cell twice
# applies the preparation twice. The file loaded elsewhere in this notebook is named
# "prepared_data.pkl", so it presumably already holds the output of this step --
# confirm before running this cell on data restored from that pickle.
data = hmm.prepare_data_for_canonical_hmm(data, parcellation="38ROI_Giles")

In [5]:
# Load the canonical HMM with 6 states for the 38ROI_Giles parcellation.
model = hmm.load_canonical_hmm(n_states=6, parcellation="38ROI_Giles")
# State probability time course
# (later cells index `alp` per session, in the same order as `fif_files`)
alp = model.get_alpha(data)


2026-01-02 13:13:15.889528: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)


Getting alpha:   0%|          | 0/163 [00:00<?, ?it/s]

2026-01-02 13:13:20.192046: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2026-01-02 13:13:20.699824: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2026-01-02 13:13:21.562670: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2026-01-02 13:13:23.386104: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2026-01-02 13:13:26.777529: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2026-01-02 13:13:33.861781: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2026-01-02 13:13:48.061253: I tensorflow/core/framework/local_rendezvous.cc:407] L

In [11]:
import pandas as pd
import numpy as np
# Per-session HMM summary statistics: fractional occupancy, mean lifetime,
# mean interval and switching rate for every state.
n_states = 6
n_sessions = len(alp)

# Four independent (n_sessions, n_states) tables, one per statistic.
all_fo, all_lt, all_intv, all_sr = (
    np.zeros((n_sessions, n_states)) for _ in range(4)
)

column_names = []

for i in range(n_sessions):
    current_alp, current_fif = alp[i], fif_files[i]

    # The name of the directory containing the file is used as the session ID.
    session_id = current_fif.split('/')[-2]
    column_names.append(session_id)

    # Wrap the state probabilities in an MNE Raw object to read the sampling frequency.
    alp_raw = inference.modes.convert_to_mne_raw(
        current_alp,
        current_fif,
        n_embeddings=data.n_embeddings,
        verbose=False,
    )
    fs = alp_raw.info["sfreq"]

    # HMM features, computed from the argmax (hard-assigned) state time course
    stc = inference.modes.argmax_time_courses(current_alp)
    all_fo[i] = analysis.post_hoc.fractional_occupancies(stc)
    all_lt[i] = analysis.post_hoc.mean_lifetimes(stc, sampling_frequency=fs)
    all_intv[i] = analysis.post_hoc.mean_intervals(stc, sampling_frequency=fs)
    all_sr[i] = analysis.post_hoc.switching_rates(stc, sampling_frequency=fs)


In [26]:
import pandas as pd
import numpy as np
# Per-session HMM summary statistics (fractional occupancy, mean lifetime, mean
# interval, switching rate) for every state, exported as one CSV per statistic.
n_states = 6
n_sessions = len(alp)
all_fo = np.zeros((n_sessions, n_states))
all_lt = np.zeros((n_sessions, n_states))
all_intv = np.zeros((n_sessions, n_states))
all_sr = np.zeros((n_sessions, n_states))

# Session ID of every row, in the same order as `alp` / `fif_files`.
column_names = []

for i in range(n_sessions):
    current_alp = alp[i]
    current_fif = fif_files[i]

    # The name of the directory containing the file is used as the session ID.
    session_id = current_fif.split('/')[-2]
    column_names.append(session_id)

    # Wrap the state probabilities in an MNE Raw object to read the sampling frequency.
    alp_raw = inference.modes.convert_to_mne_raw(current_alp, current_fif, n_embeddings=data.n_embeddings, verbose=False)
    fs = alp_raw.info["sfreq"]

    # HMM features, computed from the argmax (hard-assigned) state time course
    stc = inference.modes.argmax_time_courses(current_alp)
    all_fo[i, :] = analysis.post_hoc.fractional_occupancies(stc)
    all_lt[i, :] = analysis.post_hoc.mean_lifetimes(stc, sampling_frequency=fs)
    all_intv[i, :] = analysis.post_hoc.mean_intervals(stc, sampling_frequency=fs)
    all_sr[i, :] = analysis.post_hoc.switching_rates(stc, sampling_frequency=fs)

# Export
# Label rows with the session IDs and columns with the state names so each CSV is
# self-describing instead of carrying anonymous 0..N integer labels.
state_labels = [f"State {j}" for j in range(n_states)]
session_index = pd.Index(column_names, name="Session")

df_fo = pd.DataFrame(all_fo, index=session_index, columns=state_labels)
df_lt = pd.DataFrame(all_lt, index=session_index, columns=state_labels)
df_intv = pd.DataFrame(all_intv, index=session_index, columns=state_labels)
df_sr = pd.DataFrame(all_sr, index=session_index, columns=state_labels)
df_fo.to_csv("HMM_FO.csv")
df_lt.to_csv("HMM_Lifetimes.csv")
df_intv.to_csv("HMM_Intervals.csv")
df_sr.to_csv("HMM_SwitchingRates.csv")



In [None]:
# Plot and display the canonical group-level networks.
#
# The parcellation label is stored as `parcellation_name`, not `parcellation`:
# `parcellation` is the module imported from `modules`, and rebinding it to a string
# would break parcellation.parcellate(...) if the preprocessing cell is re-run.
parcellation_name = "38ROI_Giles"
n_states = 6
plots_dir = f"plots/{n_states}_states_{parcellation_name}"
hmm.plot_canonical_group_level_networks(n_states=n_states, parcellation=parcellation_name, plots_dir=plots_dir)
hmm.display_network_plots(n_states=n_states, plots_dir=plots_dir)

In [None]:
import numpy as np
import pandas as pd
from collections import defaultdict

# Fractional occupancy per subject and stimulus style.
#
# Stimuli are grouped pairwise into three styles; each subject's state time courses
# are pooled per style and averaged over time.
style_mapping = {
    'stim01': 'Style1', 'stim02': 'Style1',
    'stim03': 'Style2', 'stim04': 'Style2',
    'stim05': 'Style3', 'stim06': 'Style3'
}

# data_store[subject][style] -> list of argmax state time courses
data_store = defaultdict(lambda: defaultdict(list))

for i, current_fif in enumerate(fif_files):
    current_alp = alp[i]

    # The containing directory name is the session ID,
    # e.g. 'sub-01_ses-01_task-stim01'.
    session_id = current_fif.split('/')[-2]

    id_fields = session_id.split('_')
    sub_id = id_fields[0]  # e.g. 'sub-01'
    stim_id = id_fields[-1].split('-')[-1]  # e.g. 'stim01'

    style_label = style_mapping.get(stim_id)

    # Sessions whose stimulus has no style assigned are left out.
    if not style_label:
        continue

    stc = inference.modes.argmax_time_courses(current_alp)
    data_store[sub_id][style_label].append(stc)

results = []
subjects = sorted(data_store.keys())

for sub in subjects:
    row = {'Subject': sub}
    for style in ['Style1', 'Style2', 'Style3']:
        style_stcs = data_store[sub].get(style, [])

        # Pool all sessions of this style along time, then average per state;
        # a style with no sessions yields NaN for every state.
        fo = (
            np.mean(np.concatenate(style_stcs, axis=0), axis=0)
            if len(style_stcs) > 0
            else np.full(6, np.nan)
        )

        row.update({f'State{k+1}_{style}': fo[k] for k in range(6)})

    results.append(row)

df_final = pd.DataFrame(results).set_index('Subject')
df_final.to_csv('FO_results_by_style.csv')