In [None]:
import os
import mne
import pyxdf
import numpy as np
from mne.preprocessing import ICA
import matplotlib.pyplot as plt
from mne_icalabel import label_components
plt.style.use('default')
%matplotlib qt
from utils import read_data, parse_xdf, get_event_names
import glob

In [None]:
experimental = ["638390",'323706'] # (Hypercusis), (Hearing loss, Tinnitus, Hyperacusis, Misophonia) 
control = ['184632']

control_path = glob.glob(os.path.join("exp_data", f"sub-{control[0]}", "*.xdf"))[0]
exp_path = glob.glob(os.path.join("exp_data", f"sub-{experimental[0]}", "*.xdf"))[0]

control_raw, control_events, control_mapping = read_data(control_path, eeg_stream_name='obci_eeg1')
exp_raw, exp_events, exp_mapping = read_data(exp_path, eeg_stream_name='obci_eeg1')

In [None]:
exp_epoch = mne.Epochs(exp_raw, exp_events, event_id=exp_mapping['ast_stim'], tmin=-0.2, tmax=4)
control_epoch = mne.Epochs(control_raw, control_events, 
                           event_id=control_mapping['ast_stim'], tmin=-0.2, tmax=4)

In [None]:
control_epoch['ast_stim-paper_crunching'].average().plot()

In [None]:
# e = mne.concatenate_epochs([epochs['ast_stim-paper_crunching'],
#                        epochs['ast_stim-flip'],
#                        epochs['ast_stim-control-light_rain'],
#                        epochs['ast_stim-click_continuous']])

left_ear = ['L1', 'L2', 'L4', 'L5', 'L7', 'L8', 'L9', 'L10']
right_ear = ['R1', 'R2', 'R4', 'R5', 'R7', 'R8', 'R9', 'R10']

# left_ix = mne.pick_channels(epochs.info["ch_names"], include=left_ear)
# right_ix = mne.pick_channel

In [None]:
evokeds = dict(stim=exp_epoch['ast_stim-paper_crunching'].average(), control=control_epoch['ast_stim-paper_crunching'].average())
picks = right_ear 
mne.viz.plot_compare_evokeds(evokeds, picks=picks, combine="mean")

In [None]:
freqs = np.logspace(*np.log10([1, 20]), num=8) # alpha band frequencies
n_cycles = freqs / 2.0  # different number of cycle per frequency
power, itc = exp_epoch['ast_stim-paper_crunching'].compute_tfr(
    method="morlet",
    freqs=freqs,
    n_cycles=n_cycles,
    average=True,
    return_itc=True,
    decim=3,
)

In [18]:
import matplotlib.pyplot as plt
import numpy as np

channels = power.ch_names
n_channels = len(channels)
n_cols = 4
n_rows = int(np.ceil(n_channels / n_cols))

fig, axes = plt.subplots(n_rows, n_cols, figsize=(16, 3 * n_rows))
axes = axes.flatten()

# Normalize each channel's power between 0 and 1
for idx, ch in enumerate(channels):
    data = power.data[idx]
    data_norm = (data - data.min()) / (data.max() - data.min() + 1e-12)
    im = axes[idx].imshow(
        data_norm,
        aspect='auto',
        origin='lower',
        extent=[power.times[0], power.times[-1], power.freqs[0], power.freqs[-1]],
        vmin=0,
        vmax=1,
        cmap='Reds'
    )
    axes[idx].set_title(ch)
    axes[idx].set_ylabel('Freq (Hz)')
    axes[idx].set_xlabel('Time (s)')

# Hide any unused subplots
for ax in axes[n_channels:]:
    ax.axis('off')

# Add a single colorbar to the right
fig.subplots_adjust(right=0.88)
cbar_ax = fig.add_axes([0.90, 0.15, 0.02, 0.7])
plt.colorbar(im, cax=cbar_ax)

plt.tight_layout(rect=[0, 0, 0.88, 1])

  plt.tight_layout(rect=[0, 0, 0.88, 1])


In [None]:
import numpy as np

# Define frequency bands (Hz)
bands = {
    'delta': (1, 4),
    'theta': (4, 8),
    'alpha': (8, 13),
    'beta': (13, 30),
    'gamma': (30, 45)
}

# Get frequency indices for each band
band_indices = {
    band: np.where((power.freqs >= low) & (power.freqs < high))[0]
    for band, (low, high) in bands.items()
}

# Calculate normalized average power per band per channel
band_power_norm = {}
for band, idxs in band_indices.items():
    band_power_norm[band] = []
    for ch in range(power.data.shape[0]):
        data = power.data[ch]
        # Normalize the TFR for this channel
        data_norm = (data - data.min()) / (data.max() - data.min() + 1e-12)
        # Average over band freqs and all times
        band_avg = data_norm[idxs, :].mean()
        band_power_norm[band].append(band_avg)
    band_power_norm[band] = np.array(band_power_norm[band])

# Print normalized average power for each band and channel
for band in bands:
    print(f"\n{band.capitalize()} band normalized average power per channel:")
    for ch, avg in zip(power.ch_names, band_power_norm[band]):
        print(f"  {ch}: {avg:.4f}")

# Example: Compute alpha/beta ratio per channel
alpha = band_power_norm['alpha']
beta = band_power_norm['beta']
alpha_beta_ratio = alpha / (beta + 1e-12)  # Avoid division by zero

print("\nAlpha/Beta ratio per channel:")
for ch, ratio in zip(power.ch_names, alpha_beta_ratio):
    print(f"  {ch}: {ratio:.4f}")

# You can compute other ratios similarly, e.g., theta/alpha, beta/gamma, etc.

In [21]:
import matplotlib.pyplot as plt
import numpy as np

# Bar plot: Normalized average power per band per channel
bands_list = list(band_power_norm.keys())
channels = power.ch_names
x = np.arange(len(channels))

fig, ax = plt.subplots(figsize=(14, 6))
width = 0.15

for i, band in enumerate(bands_list):
    if band == 'delta':
        continue
    ax.bar(x + i * width, band_power_norm[band], width, label=band.capitalize())

ax.set_xticks(x + width * (len(bands_list) - 1) / 2)
ax.set_xticklabels(channels, rotation=45)
ax.set_ylabel('Normalized Average Power')
ax.set_title('Normalized Average Power per Band per Channel')
ax.legend()
plt.tight_layout()
plt.show()

# Bar plot: Alpha/Beta ratio per channel
fig, ax = plt.subplots(figsize=(14, 4))
ax.bar(channels, alpha_beta_ratio, color='orange')
ax.set_ylabel('Alpha/Beta Ratio')
ax.set_title('Alpha/Beta Ratio per Channel')
ax.set_xticklabels(channels, rotation=45)
plt.tight_layout()
plt.show()

  ax.set_xticklabels(channels, rotation=45)


In [None]:
control_ab = alpha_beta_ratio

In [20]:
channels = power.ch_names
x = np.arange(len(channels))
width = 0.35

fig, ax = plt.subplots(figsize=(14, 4))
ax.bar(x + width/2, control_ab, width, color='green', label='Control')
ax.bar(x - width/2, alpha_beta_ratio, width, color='orange', label='Experimental')
ax.set_ylabel('Alpha/Beta Ratio')
ax.set_title('Alpha/Beta Ratio per Channel')
ax.set_xticks(x)
ax.set_xticklabels(channels, rotation=45)
ax.legend()
plt.tight_layout()
plt.show()

NameError: name 'control_ab' is not defined

In [None]:
power.plot(baseline=(-0.2, 0), mode='mean', colorbar=True)

In [19]:
import matplotlib.pyplot as plt

channels = power.ch_names
n_channels = len(channels)
n_cols = 4  # You can adjust this for your preferred layout
n_rows = int(np.ceil(n_channels / n_cols))

fig, axes = plt.subplots(n_rows, n_cols, figsize=(16, 3 * n_rows))
axes = axes.flatten()

for idx, ch in enumerate(channels):
    power.plot(picks=[ch], baseline=(-0.2, 0), mode='mean', axes=axes[idx], colorbar=False, show=False)
    axes[idx].set_title(ch)

# Hide any unused subplots
for ax in axes[n_channels:]:
    ax.axis('off')

plt.tight_layout()
fig.suptitle('Alpha Band TFR Power - Control (per channel)', y=1.02, fontsize=18)
plt.show()

Applying baseline correction (mode: mean)
Applying baseline correction (mode: mean)
Applying baseline correction (mode: mean)
Applying baseline correction (mode: mean)
Applying baseline correction (mode: mean)
Applying baseline correction (mode: mean)
Applying baseline correction (mode: mean)
Applying baseline correction (mode: mean)
Applying baseline correction (mode: mean)
Applying baseline correction (mode: mean)
Applying baseline correction (mode: mean)
Applying baseline correction (mode: mean)


In [None]:
import numpy as np

# power.data shape: (n_channels, n_frequencies, n_times)
# Calculate mean across frequency and time for each channel (average alpha power per channel)
avg_alpha_per_channel = power.data.mean(axis=(1, 2))  # shape: (n_channels,)
# Calculate grand average alpha power across all channels
grand_avg_alpha = avg_alpha_per_channel.mean()

# Print results
for ch, avg in zip(power.ch_names, avg_alpha_per_channel):
    print(f"Channel {ch}: Average alpha power = {avg:.4e}")

print(f"\nGrand average alpha power across all channels: {grand_avg_alpha:.4e}")