In [None]:
import os
import numpy as np
import mne
import matplotlib.pyplot as plt
from mne_bids import BIDSPath, read_raw_bids
from mne_icalabel import label_components
from mne.preprocessing import ICA

# Integer event codes for the two trial conditions, used when building MNE
# events and when selecting epochs by condition name.
event_id = dict(normal=1, conflict=2)

In [2]:
def _annotation_event_code(description):
    """Map an annotation description to its event code, or None to skip it (substring match)."""
    if 'normal_or_conflict:normal' in description:
        return event_id['normal']
    if 'normal_or_conflict:conflict' in description:
        return event_id['conflict']
    return None


def _epoch_raw(raw, events):
    """Cut ``raw`` into -0.3..0.7 s epochs around ``events``, baseline-corrected on -0.3..0 s."""
    return mne.Epochs(
        raw, events, event_id=event_id, tmin=-0.3, tmax=0.7,
        baseline=(-0.3, 0), preload=True, event_repeated='merge'
    )


def _mean_abs_amplitude(epochs):
    """Return one value per epoch: mean absolute amplitude over channels and samples."""
    return np.mean(np.abs(epochs.get_data()), axis=(1, 2))


def _frontal_negativity_peaks(epochs, channels, time_window):
    """Print and return {channel: (peak_time_s, peak_amplitude_uV)} of the most negative ERP value in ``time_window``.

    Channels absent from ``epochs`` are reported and skipped instead of raising.
    """
    available = [ch for ch in channels if ch in epochs.ch_names]
    missing = [ch for ch in channels if ch not in epochs.ch_names]
    if missing:
        print(f"Channels not in recording, skipped for peak search: {missing}")
    if not available:
        return {}

    # Average once for all channels rather than once per channel.
    evoked = epochs.average(picks=available)
    times = evoked.times
    mask = (times >= time_window[0]) & (times <= time_window[1])
    window_times = times[mask]

    peaks = {}
    for ch_name in available:
        # MNE stores EEG in volts; scale to microvolts so the printed unit is correct.
        window_data = evoked.data[evoked.ch_names.index(ch_name)][mask] * 1e6
        peak_idx = np.argmin(window_data)
        peak_time = window_times[peak_idx]
        peak_amplitude = window_data[peak_idx]
        peaks[ch_name] = (peak_time, peak_amplitude)
        print(f"Channel {ch_name}: Negativity peak at {peak_time:.3f} s with amplitude {peak_amplitude:.3f} µV")
    return peaks


def process_subject(subject_id, session, bids_root):
    """Preprocess one subject/session and return ICA-cleaned epochs split by condition.

    Pipeline: load the BIDS BrainVision recording -> shift annotations by 63 ms
    (EEG setup delay) -> resample to 250 Hz -> 1-124 Hz band-pass + 50 Hz notch
    -> average reference + standard_1020 montage -> epoch -0.3..0.7 s -> fit a
    10-component ICA on the 85 % lowest-amplitude epochs -> exclude components
    ICLabel tags as eye blink / muscle artifact / line noise -> re-epoch the
    cleaned data -> drop the 10 % highest-amplitude epochs. Frontal ERP
    negativity peaks (100-300 ms, on a 0.2-35 Hz filtered copy) are printed.

    Parameters
    ----------
    subject_id : str
        BIDS subject label, e.g. "02".
    session : str
        BIDS session label, e.g. "Visual".
    bids_root : str
        Path to the BIDS dataset root.

    Returns
    -------
    epochs_match, epochs_mismatch : mne.Epochs
        ICA-cleaned, amplitude-rejected epochs for 'normal' and 'conflict'.

    Raises
    ------
    ValueError
        If no 'normal_or_conflict' annotation is found in the recording.
    """
    bids_path = BIDSPath(
        subject=subject_id, session=session, task="PredictionError", suffix="eeg", extension=".vhdr", root=bids_root
    )
    raw = read_raw_bids(bids_path)

    # Preprocessing; filter/notch/reference operate in place on the resampled copy.
    raw.annotations.onset -= 0.063  # Adjust for EEG setup delay
    raw_resampled = raw.copy().resample(sfreq=250, npad="auto")
    raw_filtered = raw_resampled.filter(l_freq=1.0, h_freq=124.0).notch_filter(freqs=50)
    raw_referenced = raw_filtered.set_eeg_reference(ref_channels="average").set_montage("standard_1020")

    # events_from_annotations rounds onsets to the nearest sample (plain int()
    # truncation can land one sample early) and handles first_samp. regexp=None
    # sends every description to the mapper, like the original substring scan.
    events, _ = mne.events_from_annotations(raw_referenced, event_id=_annotation_event_code, regexp=None)
    if len(events) == 0:
        raise ValueError(
            f"No 'normal_or_conflict' annotations found for subject {subject_id}, session {session}."
        )
    epochs = _epoch_raw(raw_referenced, events)
    print(f"Total epochs: {len(epochs)}")

    # Fit ICA only on the quietest 85 % of epochs so large transients don't
    # dominate the decomposition; sort to keep chronological order.
    ica_fit_percentage = 85
    n_fit = int(len(epochs) * (ica_fit_percentage / 100))
    fit_indices = np.sort(np.argsort(_mean_abs_amplitude(epochs))[:n_fit])

    # mne-icalabel recommends extended picard (ortho=False, extended=True) because
    # ICLabel was trained on extended-infomax decompositions.
    # NOTE(review): ICLabel docs also recommend 1-100 Hz filtered data; this is 1-124 Hz — confirm acceptable.
    ica = ICA(n_components=10, method='picard', random_state=42, max_iter=5000,
              fit_params=dict(ortho=False, extended=True))
    ica.fit(epochs[fit_indices])
    ica.plot_components()

    labels = label_components(raw_referenced, ica, method='iclabel')
    print("ICLabel Results:")
    for idx, (label, prob) in enumerate(zip(labels['labels'], labels['y_pred_proba'])):
        print(f"Component {idx}: {label} (Probability: {prob:.2f})")

    # ICLabel class names use spaces ('line noise'), not underscores.
    ica.exclude = [idx for idx, label in enumerate(labels['labels'])
                   if label in ('eye blink', 'muscle artifact', 'line noise')]

    # Apply ICA to a copy and re-epoch it so the returned epochs actually carry the
    # artifact removal. The ERP band-pass is applied to continuous data because
    # 1 s epochs are far shorter than a 0.2 Hz high-pass filter kernel.
    raw_clean = ica.apply(raw_referenced.copy())
    epochs_clean = _epoch_raw(raw_clean, events)
    epochs_erp = _epoch_raw(raw_clean.copy().filter(l_freq=0.2, h_freq=35.0), events)
    # Same raw annotations and events, no amplitude rejection: both must align epoch-for-epoch.
    assert len(epochs_clean) == len(epochs_erp)

    # Reject the 10 % noisiest epochs of the cleaned data; apply the same mask to the ERP copy.
    clean_amplitudes = _mean_abs_amplitude(epochs_clean)
    keep = clean_amplitudes < np.percentile(clean_amplitudes, 90)
    epochs_clean = epochs_clean[keep]
    epochs_erp = epochs_erp[keep]

    # Peaks are printed for inspection only; the return value stays (match, mismatch).
    _frontal_negativity_peaks(epochs_erp, ['Fz', 'Cz', 'Fp1', 'FC1', 'FC2'], (0.1, 0.3))

    epochs_match = epochs_clean['normal']
    epochs_mismatch = epochs_clean['conflict']
    print(f"Match trials: {len(epochs_match)}, Mismatch trials: {len(epochs_mismatch)}")
    return epochs_match, epochs_mismatch




In [None]:

# Below we are plotting time-frequency graphs for all the subjects and sessions.
# Morlet power is computed per subject/session/condition, averaged across
# subjects, and the PLOT_SESSION average is shown at `channel`.

bids_root = "/home/st/st_us-053000/st_st190561/EEG"  # path to the bids_root folder

subjects = ["02", "03", "06", "07", "08", "11", "12", "13", "14", "15", "16"]  # Hard-coded valid subjects for experiment

sessions = ["Visual", "EMS", "Vibro"]

channel = 'Fz'
PLOT_SESSION = "EMS"  # session whose grand average is plotted

# Morlet parameters are identical for every subject, so define them once.
freqs = np.arange(1, 124, 1)
n_cycles = freqs / 2

# Initialize lists for storing TFR data
tfr_normal_visual, tfr_conflict_visual = [], []
tfr_normal_ems, tfr_conflict_ems = [], []
tfr_normal_vibro, tfr_conflict_vibro = [], []
# session -> (normal list, conflict list); replaces a per-session if/elif chain.
tfr_store = {
    "Visual": (tfr_normal_visual, tfr_conflict_visual),
    "EMS": (tfr_normal_ems, tfr_conflict_ems),
    "Vibro": (tfr_normal_vibro, tfr_conflict_vibro),
}
# AverageTFR objects from a successful PLOT_SESSION run, reused as containers
# (times, freqs, ch_names) for the grand average. Taking them from the last
# loop iteration instead could pick the wrong session or a stale/undefined object.
tfr_template_match = tfr_template_mismatch = None

for subject in subjects:
    for session in sessions:
        print(f"Processing Subject {subject}, Session {session}...")
        try:
            epochs_match, epochs_mismatch = process_subject(subject, session, bids_root)
        except FileNotFoundError:
            print(f"File not found for subject {subject}, session {session}. Skipping this session.")
            continue

        tfr_match = mne.time_frequency.tfr_morlet(epochs_match, freqs=freqs, n_cycles=n_cycles,
                                                  use_fft=True, return_itc=False, decim=3, n_jobs=1)
        tfr_mismatch = mne.time_frequency.tfr_morlet(epochs_mismatch, freqs=freqs, n_cycles=n_cycles,
                                                     use_fft=True, return_itc=False, decim=3, n_jobs=1)

        normal_list, conflict_list = tfr_store[session]
        normal_list.append(tfr_match.data)
        conflict_list.append(tfr_mismatch.data)
        if session == PLOT_SESSION:
            tfr_template_match, tfr_template_mismatch = tfr_match, tfr_mismatch

# Grand average per session/condition; None when no subject contributed.
# NOTE(review): np.mean over the list requires every subject to have the same
# channel set/order and TFR shape — confirm no per-subject channel drops.
avg_data = {
    sess: {
        "normal": np.mean(normal, axis=0) if normal else None,
        "conflict": np.mean(conflict, axis=0) if conflict else None,
    }
    for sess, (normal, conflict) in tfr_store.items()
}

print(f"tfr_normal_visual: {len(tfr_normal_visual)}, tfr_conflict_visual: {len(tfr_conflict_visual)}")
print(f"tfr_normal_ems: {len(tfr_normal_ems)}, tfr_conflict_ems: {len(tfr_conflict_ems)}")
print(f"tfr_normal_vibro: {len(tfr_normal_vibro)}, tfr_conflict_vibro: {len(tfr_conflict_vibro)}")

if tfr_template_match is None:
    raise RuntimeError(f"No {PLOT_SESSION} session could be processed; nothing to plot.")

tfr_match_avg = tfr_template_match.copy()
tfr_match_avg.data = avg_data[PLOT_SESSION]["normal"]

tfr_mismatch_avg = tfr_template_mismatch.copy()
tfr_mismatch_avg.data = avg_data[PLOT_SESSION]["conflict"]


def _plot_tfr_panel(ax, tfr_avg, label):
    """Draw `channel` of an averaged TFR as an image on ``ax`` and return the image."""
    image = ax.imshow(
        # Index by the TFR's own channel list, not by some subject's epochs.
        tfr_avg.data[tfr_avg.ch_names.index(channel)], aspect='auto', origin='lower',
        # Use the TFR's own grids: decim=3 makes its time axis coarser than the epochs'.
        extent=[tfr_avg.times[0], tfr_avg.times[-1], tfr_avg.freqs[0], tfr_avg.freqs[-1]],
    )
    ax.set_title(f"TFR - {label} ({PLOT_SESSION}, {channel})")
    ax.set_xlabel("Time (s)")
    ax.set_ylabel("Frequency (Hz)")
    ax.figure.colorbar(image, ax=ax)
    return image


fig, axes = plt.subplots(1, 2, figsize=(14, 8))
im1 = _plot_tfr_panel(axes[0], tfr_match_avg, "Match")
im2 = _plot_tfr_panel(axes[1], tfr_mismatch_avg, "Mismatch")

plt.savefig(os.path.join(bids_root, "TF1.png"))
plt.show()


Processing Subject 02, Session Visual...
Extracting parameters from /home/st/st_us-053000/st_st190561/EEG/sub-02/ses-Visual/eeg/sub-02_ses-Visual_task-PredictionError_eeg.vhdr...
Setting channel info structure...
Reading events from /home/st/st_us-053000/st_st190561/EEG/sub-02/ses-Visual/eeg/sub-02_ses-Visual_task-PredictionError_events.tsv.
Reading channel info from /home/st/st_us-053000/st_st190561/EEG/sub-02/ses-Visual/eeg/sub-02_ses-Visual_task-PredictionError_channels.tsv.


cap_size: 58
block_1: Visual
block_2: Visual + Vibro
block_3: Visual + Vibro + EMS
  raw = read_raw_bids(bids_path)
