In [None]:
import os
import glob
import numpy as np
from utils import read_data, get_event_names, compute_band_ratios, calculate_power_spectrum, plot_power_spectrum
import mne
import matplotlib.pyplot as plt
plt.style.use('default')  # reset matplotlib to its default style so figures do not inherit a previously applied style

In [None]:
# Data locations: experimental and control recordings (.xdf files), searched recursively.
exp_path = os.path.join("exp_data", "02_Experimental")
control_path = os.path.join("exp_data", "01_Control")
glob_pattern = os.path.join("**", "*.xdf")

# Number of control recordings to use; set to None to use all of them.
N_CONTROL_FILES = 6

# glob returns files in a filesystem-dependent order. Sort so that positional
# selections (control_files[:N], exp_files[4:6], ...) pick the same recordings
# on every machine and every run.
exp_files = sorted(glob.glob(os.path.join(exp_path, glob_pattern), recursive=True))
control_files = sorted(
    glob.glob(os.path.join(control_path, glob_pattern), recursive=True)
)[:N_CONTROL_FILES]
# NOTE(review): CTRL03-sub-129059-old was noted as producing an error — TODO confirm
# whether that recording must be excluded.

In [None]:
# Show the experimental recordings found by the glob.
exp_files

In [None]:
# Epoch the AST stimulus events separately per condition (neutral vs trigger)
# for a subset of experimental participants, after setting the electrode
# montage and interpolating bad channels.
# participant_files = control_files[:6]
participant_files = exp_files[4:6]
event_type = 'ast_stim'
ast_event_groups = ['neutral', 'trigger']
# NOTE: despite its name, each entry holds one participant's *epochs*
# (an mne.Epochs per condition, or None). The name is kept because the
# concatenation cell reads it.
participant_ratios = []

# Every participant uses the same montage file, so read it once.
montage = mne.channels.read_custom_montage(
    'ceegrid_coords.csv', coord_frame='head')

for participant_file in participant_files:
    path_parts = participant_file.split(os.sep)
    # Assumes paths look like exp_data/<group dir>/<participant dir>/.../<id>_*.xdf — TODO confirm.
    participant_number = path_parts[2]
    participant_id = path_parts[-1].split('_')[0]
    print(f"Processing {participant_number} ({participant_id})")
    raw, events, mapping = read_data(
        participant_file, bandpass={'low': 2, 'high': 50})

    # Label -> condition code (1 = neutral, 2 = trigger), kept for inspection.
    # NOTE(review): these 1/2 codes are presumably not the codes stored in
    # `events` (the ratio loop passes mapping[event_type][group] to mne.Epochs
    # directly), so this dict is not used for epoching.
    combined_event_id = {str(k): 1 for k in mapping[event_type]['neutral']}
    combined_event_id.update({str(k): 2 for k in mapping[event_type]['trigger']})

    raw.set_montage(montage)
    raw.interpolate_bads(reset_bads=True)

    participant_data = {
        'participant_number': participant_number,
        'neutral_epoch': None,
        'trigger_epoch': None,
    }
    for group in ast_event_groups:
        group_event_id = mapping[event_type][group]
        if len(group_event_id) == 0:
            print(
                f"No events found for {event_type} in group {group} for participant {participant_number}")
            continue
        # Epoch only this condition's events, so the neutral and trigger
        # epochs contain different trials.
        epoch = mne.Epochs(
            raw, events, event_id=group_event_id, tmin=-0.5, tmax=4)
        participant_data[f'{group}_epoch'] = epoch

    participant_ratios.append({
        'participant_id': f"{participant_number}",
        'neutral_epoch': participant_data['neutral_epoch'],
        'trigger_epoch': participant_data['trigger_epoch'],
    })

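**Planned event structure:** AST → [pre_stim, stim, post_stim]

Within each AST phase, events are split into two groups by their audio label:

- `trigger`: audio labels that do **not** contain `"control"`
- `neutral`: audio labels that **do** contain `"control"`

```
{
  "ast": {
    "pre_stim": {
      "trigger": {"audio_label": 2},
      "neutral": {"audio_label": 2}
    },
    "stim": {
      "trigger": {"audio_label": 2},
      "neutral": {"audio_label": 2}
    }
  }
}
```

TODO: `post_stim` is listed in the phase list but not yet written out in the structure above.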

In [None]:
# Inspect the 'ast' part of the event mapping returned by read_data for the most recently processed file.
# NOTE(review): the epoching loops index mapping['ast_stim'] (event_type); confirm an 'ast' key also exists.
mapping['ast']

In [None]:
# Label -> condition code (1 = neutral, 2 = trigger) built for the most recently processed participant.
combined_event_id

In [None]:
# The most recently created Epochs object (last group processed for the last participant).
epoch

In [None]:
# Pool epochs across participants, separately for each condition.
# Participants for whom a condition had no events (epoch is None) are skipped.
neutral_epochs = [r['neutral_epoch'] for r in participant_ratios if r['neutral_epoch'] is not None]
trigger_epochs = [r['trigger_epoch'] for r in participant_ratios if r['trigger_epoch'] is not None]

# mne.concatenate_epochs needs at least one Epochs object. Fail with a message
# naming the missing condition instead of an opaque error from inside MNE.
for condition, epochs_list in (('neutral', neutral_epochs), ('trigger', trigger_epochs)):
    if not epochs_list:
        raise ValueError(
            f"No {condition} epochs to combine: no processed participant had "
            f"{condition} events. Check participant_files and the event mapping.")

combined_neutral = mne.concatenate_epochs(neutral_epochs)
combined_trigger = mne.concatenate_epochs(trigger_epochs)

In [None]:
# cEEGrid channel groups for the left and right ear.
left_ear = ['L1', 'L2', 'L4', 'L5', 'L7', 'L8', 'L9', 'L10']
right_ear = ['R1', 'R2', 'R4', 'R5', 'R7', 'R8', 'R9', 'R10']

# `control_trigger` is not created anywhere in this notebook. It presumably comes
# from an earlier run of the epoching cell on control_files, so guard for it
# instead of raising NameError on a fresh kernel.
evokeds = {}
if 'control_trigger' in globals():
    evokeds['control'] = control_trigger.average()
else:
    print("control_trigger is not defined; plotting the experimental group only. "
          "TODO: build the control epochs explicitly in this notebook.")
evokeds['experimental'] = combined_trigger.average()

picks = left_ear
# Average the picked channels into one trace per condition; ';' hides the figure-list repr.
mne.viz.plot_compare_evokeds(evokeds, picks=picks, combine="mean");

In [None]:
# Mean alpha/beta power ratio per participant and condition (neutral vs trigger)
# for every experimental recording. A condition with no events keeps the
# placeholder ratio 0.
# participant_files = control_files[:6]
participant_files = exp_files
event_type = 'ast_stim'
ast_event_groups = ['neutral', 'trigger']
participant_ratios = []

for participant_file in participant_files:
    file_parts = participant_file.split(os.sep)
    participant_number, participant_id = file_parts[2], file_parts[-1].split('_')[0]
    print(f"Processing {participant_number} ({participant_id})")
    raw, events, mapping = read_data(participant_file, bandpass={'low': 2, 'high': 50})

    participant_data = dict(participant_number=participant_number, neutral_ratio=0, trigger_ratio=0)

    for group in ast_event_groups:
        group_event_id = mapping[event_type][group]
        if not len(group_event_id):
            print(f"No events found for {event_type} in group {group} for participant {participant_number}")
            continue

        epoch = mne.Epochs(raw, events, event_id=group_event_id, tmin=-0.2, tmax=4)
        power, freqs = calculate_power_spectrum(epoch, fmin=2, fmax=50)
        band_powers = compute_band_ratios(power, freqs)
        alpha, beta = band_powers['alpha'], band_powers['beta']
        # The tiny epsilon keeps the ratio finite if beta power is exactly zero.
        alpha_beta_ratio = alpha / (beta + 1e-12)
        participant_data[f'{group}_ratio'] = alpha_beta_ratio.mean()

    participant_ratios.append({
        'participant_id': f"{participant_number}",
        'neutral_ratio': participant_data['neutral_ratio'],
        'trigger_ratio': participant_data['trigger_ratio'],
    })

In [None]:
# Grouped bar chart of per-participant alpha/beta ratios (neutral vs trigger)
# with dotted lines marking each condition's group average.
participant_ids = [r['participant_id'] for r in participant_ratios]
neutral_ratios = [r['neutral_ratio'] for r in participant_ratios]
trigger_ratios = [r['trigger_ratio'] for r in participant_ratios]
x = np.arange(len(participant_ids))
width = 0.35

# The ratio loop stores 0 when a participant had no events for a condition.
# Exclude those placeholders so they do not pull the group average toward 0.
# The average is NaN (and not drawn) when no participant has a value.
neutral_valid = [r for r in neutral_ratios if r != 0]
trigger_valid = [r for r in trigger_ratios if r != 0]
neutral_avg = np.mean(neutral_valid) if neutral_valid else np.nan
trigger_avg = np.mean(trigger_valid) if trigger_valid else np.nan

fig, ax = plt.subplots(figsize=(10, 6))
ax.bar(x - width/2, neutral_ratios, width, color='green', label='Neutral')
ax.bar(x + width/2, trigger_ratios, width, color='orange', label='Trigger')

# Average lines, each annotated with its value near the right edge of the plot.
for avg, color, label, text_x in (
    (neutral_avg, 'green', 'Neutral Avg', len(x) - 0.5),
    (trigger_avg, 'orange', 'Trigger Avg', len(x) - 1.5),
):
    if np.isnan(avg):
        continue
    ax.axhline(avg, color=color, linestyle=':', linewidth=2, label=label)
    ax.text(text_x, avg + 0.01, f"{avg:.2f}", color=color, fontsize=11,
            va='bottom', ha='right', fontweight='bold')

ax.set_xlabel('Participant ID')
ax.set_ylabel('Alpha/Beta Ratio')
ax.set_title('Alpha/Beta Ratio for Experimental Participants (Neutral vs Trigger)')
ax.set_xticks(x)
ax.set_xticklabels(participant_ids, rotation=45)
ax.legend()
fig.tight_layout()
plt.show()