In [41]:
# %% set up
import os
import mne
import numpy as np
import matplotlib.pyplot as plt
from ieeg.viz.mri import plot_on_average


In [None]:
# %% subject list
# Subject IDs follow the pattern "D" + four-digit, zero-padded number.
subj_numbers = (53, 63, 65, 66, 70, 71, 81, 94)
subjs = [f"D{number:04d}" for number in subj_numbers]

In [None]:
# %% define condition
stat_type='mask'
con='Auditory'
contrast='ave'

match stat_type:
    case "zscore":
        fif_read = lambda f: mne.read_epochs(f, False, preload=True)
    case "power":
        fif_read = lambda f: mne.read_epochs(f, False, preload=True)
    case "significance":
        fif_read = mne.read_evokeds
    case "pval":
        fif_read = mne.read_evokeds

In [None]:
# %% load data
# The stats directory layout does not depend on the subject, so build the
# shared part of the path once, before the loop.
HOME = os.path.expanduser("~")
LAB_root = os.path.join(HOME, "Box", "CoganLab")
bids_root = os.path.join(LAB_root, 'BIDS-1.0_LexicalDecRepDelay', 'BIDS')
stats_root = os.path.join(bids_root, "derivatives", "stats")

chs = []       # pooled channel labels, "<subject> <channel>"
data_lst = []  # one data array per subject, in `subjs` order
for subj_no, subject in enumerate(subjs):
    fif_path = os.path.join(stats_root, subject, f'{con}_{stat_type}-{contrast}.fif')
    first_entry = fif_read(fif_path)[0]  # only the first stored entry is used

    data_lst.append(first_entry.data)
    # Prefix each channel with its subject so the pooled labels carry both.
    chs += [f"{subject} {ch}" for ch in first_entry.ch_names]

    # The time axis is taken from the first subject only.
    if subj_no == 0:
        times = first_entry.times

# Stack every subject's rows into a single array (first axis).
data = np.concatenate(data_lst, axis=0)

In [None]:
# %% get the onsets of the activation (an effective cluster is defined as 0.2s)
# Sampling frequency, derived from the spacing of the first two time points.
spf = 1 / (times[1] - times[0])
win_len = 0.2  # minimum duration of a cluster, in seconds
# Round instead of truncating: `spf` comes from a float subtraction, so a value
# such as 99.9999... would otherwise make the window one sample too short.
win = int(round(win_len * spf))  # number of samples in `win_len` seconds

# Maps each channel label to the time of the first sample that starts a run of
# `win` consecutive samples equal to 1, or to None when no such run exists.
onsets = {}

for ch_idx, ch_name in enumerate(chs):
    ch_active = data[ch_idx] == 1

    for start_idx in range(len(ch_active) - win + 1):
        if ch_active[start_idx:start_idx + win].all():
            onsets[ch_name] = times[start_idx]
            break
    else:
        # The loop finished without a break: no window was fully active.
        onsets[ch_name] = None

# `onsets` now contains the starting time of significance for each channel

In [None]:
# %% select channels with significant activation clusters
# Positions (into `chs` / `data`) of the channels whose onset is not None.
chs_s_idx = [pos for pos, label in enumerate(chs) if onsets[label] is not None]

# Labels, onset times and data rows of those channels, all in the same order.
chs_s = [chs[pos] for pos in chs_s_idx]
onsets_s = [onsets[label] for label in chs_s]
data_s = np.array([data[pos] for pos in chs_s_idx])

In [None]:
# %% do the ranking
# Indices that order the selected channels by onset time. A stable sort keeps
# channels with identical onsets in their original relative order, so the
# ranking is reproducible between runs.
sorted_indices = np.argsort(np.array(onsets_s), kind='stable')
# %% rearrange the data according to sorted_indices
data_s_sorted = data_s[sorted_indices]
chs_s_sorted = [chs_s[i] for i in sorted_indices]
onsets_s_sorted = [onsets_s[i] for i in sorted_indices]

# %% plot the data
# The original called `plt.figure(figsize=(2^15, 2^15))` before `plt.subplots()`.
# `^` is bitwise XOR (2 ^ 15 == 13), and that figure was never drawn on, so the
# size had no effect. The size is now applied to the figure actually used.
fig, ax = plt.subplots(figsize=(13, 13))
ax.imshow(data_s_sorted, cmap='Reds')

ch_gap = 20    # label every 20th channel
time_gap = 50  # label every 50th time sample
channel_names = chs_s_sorted[::ch_gap]
ax.set_yticks(range(0, len(chs_s_sorted), ch_gap))
ax.set_yticklabels(channel_names)
time_stamps = times[::time_gap]
ax.set_xticks(range(0, len(times), time_gap))
# Format the labels so float noise (e.g. 0.30000000000000004) is not displayed.
ax.set_xticklabels([f"{t:g}" for t in time_stamps])
ax.set_xlabel('Time (s)')
ax.set_ylabel('Channel')

# Mark time zero. `times` is a float array, so an exact `== 0` comparison can
# miss it; use the sample closest to zero when zero lies inside the time range.
if times[0] <= 0 <= times[-1]:
    zero_time_index = int(np.argmin(np.abs(times)))
    ax.axvline(x=zero_time_index, color='black', linestyle='--', linewidth=1)
else:
    print('no zero time found')
fig.savefig('try.jpg', dpi=300)


In [50]:
# Display the "<subject> <channel>" labels of the channels that have an onset.
chs_s

['D0053 LPIF10',
 'D0053 LPIF11',
 'D0053 LPIF12',
 'D0053 LPIF14',
 'D0053 LPIF15',
 'D0053 LPI14',
 'D0053 RPMT1',
 'D0053 RPMT2',
 'D0053 RPMT3',
 'D0053 RPMT4',
 'D0053 RPMT5',
 'D0053 RPMT6',
 'D0063 LOF8',
 'D0063 LOF9',
 'D0063 LOF10',
 'D0063 LOF11',
 'D0063 LMSF1',
 'D0063 LMSF2',
 'D0063 LMSF3',
 'D0063 LMSF4',
 'D0063 LPSF1',
 'D0063 LPSF2',
 'D0063 LPSF3',
 'D0063 LPSF4',
 'D0063 LPSF5',
 'D0063 LPSF6',
 'D0063 LPSF7',
 'D0063 LPSF8',
 'D0063 LPSF9',
 'D0063 LPSF10',
 'D0063 LPSF11',
 'D0063 LPSF12',
 'D0063 LMMT9',
 'D0063 LMMT10',
 'D0063 LMMT11',
 'D0063 LMMT12',
 'D0063 LMMT14',
 'D0063 ROF8',
 'D0063 ROF11',
 'D0063 RAI1',
 'D0063 RAI2',
 'D0063 RAI3',
 'D0063 RAI4',
 'D0063 RAI5',
 'D0063 RAI6',
 'D0063 RAI8',
 'D0063 RAI13',
 'D0063 RASF1',
 'D0063 RASF2',
 'D0063 RASF3',
 'D0063 RASF13',
 'D0063 RASF15',
 'D0063 RMSF1',
 'D0063 RMSF2',
 'D0063 RMSF3',
 'D0063 RMSF4',
 'D0063 RMSF5',
 'D0063 RMSF6',
 'D0063 RMSF12',
 'D0063 RPSF1',
 'D0063 RPSF2',
 'D0063 RPSF3',
 'D

In [49]:
# %% plot the significance electrodes on the average brain
n_sig = len(chs_s_idx)
# Red -> blue gradient with one RGB triple per onset rank (rank 0 = earliest).
# `np.linspace` also covers the single-channel case, where the previous
# `i / (n - 1)` form raised ZeroDivisionError.
elecols = [[1 - frac, 0, frac] for frac in np.linspace(0, 1, n_sig).tolist()]
# `sorted_indices[rank]` is the position of the channel holding that rank, so
# the rank of each position is the inverse permutation. Indexing `elecols` with
# `sorted_indices` directly (as before) assigned colours by the wrong mapping.
onset_rank = np.argsort(sorted_indices)
elecols_s = [elecols[rank] for rank in onset_rank]
# NOTE(review): assumes `plot_on_average` pairs color[k] with picks[k], and that
# integer picks index channels in the same order as `chs` -- confirm against
# the ieeg documentation.
fig2 = plot_on_average(subjs,picks=chs_s_idx,hemi='both',color=elecols_s)#, label_every=8)

  info.set_montage(montage)
  info.set_montage(montage)
  info.set_montage(montage)
  info.set_montage(montage)
  info.set_montage(montage)
  info.set_montage(montage)
  info.set_montage(montage)
  info.set_montage(montage)
