In [None]:
# References:
#   Bandpower: https://raphaelvallat.com/bandpower.html
#   Colors: https://matplotlib.org/stable/gallery/color/named_colors.html

import logging
import matplotlib.pyplot as plt
import mne
import numpy as np
import scipy

# How much of the dataset to load. The original trailing comments give
# 20 subjects / 12 sessions as the values for a full run.
MAX_SUB = 1
MAX_SES = 1

# EEG frequency bands in Hz as (low, high). Neighbouring bands share their
# boundary frequency, except that 7-8 Hz lies between theta and alpha and is
# therefore not part of any band.
BANDS = {
    'delta': (0.5, 4),
    'theta': (4, 7),
    'alpha': (8, 13),
    'beta': (13, 30),
    'gamma': (30, 128),
}

# Matplotlib named colours used to shade each band (same keys as BANDS).
COLORS = {
    'delta': 'skyblue',
    'theta': 'hotpink',
    'alpha': 'salmon',
    'beta': 'lightseagreen',
    'gamma': 'thistle',
}

logging.basicConfig(
    level='INFO',
    format='%(asctime)s %(levelname)s [%(module)s:%(lineno)d] %(message)s',
)
# Only let MNE's own logger emit ERROR-level messages and above.
logging.getLogger('mne').setLevel('ERROR')

In [None]:
def _load_session(sub, ses):
    """Read one EEGLAB recording of the MUSIN-G data and drop channel E129.

    The run number in the file name is the session number (as in the
    original path template).
    """
    path = 'musin-g/sub-%03d/ses-%02d/eeg/sub-%03d_ses-%02d_task-MusicListening_run-%d_eeg' % (sub, ses, sub, ses, ses)
    recording = mne.io.read_raw_eeglab(path + '.set', preload=True, montage_units='mm')
    recording.drop_channels(['E129'])
    return recording


# sub_list[s][k] is the Raw object of subject s+1, session k+1.
sub_list = []
for sub in range(1, MAX_SUB + 1):
    logging.info('Loading subject %03d', sub)
    sub_list.append([_load_session(sub, ses) for ses in range(1, MAX_SES + 1)])

# Work with the first session of the first subject from here on.
raw: mne.io.Raw = sub_list[0][0]
spectrum = raw.compute_psd()
all_psds, freqs = spectrum.get_data(return_freqs=True)

In [None]:
# Browse the raw recording, showing up to 128 channels per page.
fig = raw.plot(n_channels=128, show=False)
# pyplot.show() does not take a figure (its only documented parameter is
# `block`): passing `fig` positionally raises TypeError on GUI backends and is
# misinterpreted by others. It displays every open figure, this one included.
plt.show()
# NOTE(review): assumes MNE's matplotlib browser backend. With the Qt browser
# (mne-qt-browser), `fig` is not a matplotlib Figure and plt.close(fig) may
# fail. Confirm which backend is in use.
plt.close(fig)

In [None]:
# Plot the PSD computed above with MNE's built-in Spectrum.plot.
fig = spectrum.plot(show=False)
# pyplot.show() does not take a figure (its only documented parameter is
# `block`); it displays every open figure, this one included.
plt.show()
plt.close(fig)

In [None]:
def get_peaks(all_psds, freqs=None):
    """Find the frequency bins where channels most often have their PSD maximum.

    For every row of `all_psds`, take the column index of its maximum. Then
    count how many rows peak at each index (np.bincount) and return the local
    maxima of that count histogram.

    Parameters
    ----------
    all_psds : array-like, 2-D
        One spectrum per row (e.g. channels x frequency bins).
    freqs : array-like, optional
        Frequency of each column of `all_psds`. When given, the peaks are
        returned as frequencies instead of bin indices.

    Returns
    -------
    numpy.ndarray
        Column indices of the histogram peaks, or ``freqs`` at those indices
        when `freqs` is provided.
    """
    counts = np.bincount(np.argmax(all_psds, axis=1))
    # Pad with a zero on both sides so a maximum at either end of the histogram
    # still counts as a peak for find_peaks.
    padded = np.concatenate(([0], counts, [0]))
    peaks, _ = scipy.signal.find_peaks(padded)
    peaks = peaks - 1  # undo the shift introduced by the left pad
    if freqs is not None:
        return np.asarray(freqs)[peaks]
    return peaks

# Frequency-bin indices (not Hz) at the local peaks of the histogram of
# per-channel PSD maxima; index `freqs` with this result to read them in Hz.
get_peaks(all_psds)

In [None]:
# Row of `all_psds` (0-based) whose spectrum is plotted and summarised below.
# NOTE(review): which electrode this is depends on the channel order of
# `spectrum`; check spectrum.info['ch_names'][chan] to confirm.
chan = 95

def get_unit_label(dB=False, estimate='power', unit='µV'):
    """Build a mathtext axis label for PSD values.

    Parameters
    ----------
    dB : bool
        Append a "(dB)" suffix to the label.
    estimate : {'power', 'amplitude', 'auto'}
        'amplitude' gives unit/sqrt(Hz); anything else gives unit²/Hz.
        'auto' resolves to 'power' when `dB` is set, otherwise 'amplitude'.
    unit : str
        Base unit. For the power label, a compound unit containing '/' is
        wrapped in parentheses before squaring.

    Returns
    -------
    str
        The label, formatted for matplotlib mathtext.
    """
    resolved = estimate
    if resolved == 'auto':
        resolved = 'amplitude' if not dB else 'power'

    if resolved == 'amplitude':
        base = r'$\mathrm{%s/\sqrt{Hz}}$' % unit
    else:
        shown = '(%s)' % unit if '/' in unit else unit
        base = r'$\mathrm{%s²/Hz}$' % shown

    suffix = r'$\ \mathrm{(dB)}$' if dB else ''
    return base + suffix

def convert_psds(psds, dB=False, estimate='power', unit='µV', scaling=1e6):
    """Convert PSD values into display units, optionally in decibels.

    Parameters
    ----------
    psds : array-like
        PSD values. Presumably in V²/Hz: the default `scaling` of 1e6 turns
        them into µV²/Hz (power) or µV/sqrt(Hz) (amplitude).
    dB : bool
        Return 10 * log10 of the converted values (floored at the smallest
        positive float so zeros do not produce -inf).
    estimate : {'power', 'amplitude', 'auto'}
        'amplitude' takes the square root before scaling by `scaling`;
        anything else scales by ``scaling ** 2``. 'auto' resolves to 'power'
        when `dB` is set, otherwise 'amplitude'.
    unit : str
        Kept for signature parity with get_unit_label; it has no effect on the
        returned values.
    scaling : float
        Factor from the input amplitude unit to the display unit (1e6 for V to
        µV).

    Returns
    -------
    numpy.ndarray
        A new float array; the input is never modified.
    """
    # np.array always copies, and forcing float lets the in-place ufuncs
    # below work for lists and integer arrays too.
    out = np.array(psds, dtype=float)
    if estimate == 'auto':
        estimate = 'power' if dB else 'amplitude'

    if estimate == 'amplitude':
        np.sqrt(out, out=out)
        out *= scaling
    else:
        out *= scaling ** 2

    if dB:
        np.log10(np.maximum(out, np.finfo(float).tiny), out=out)
        out *= 10
    return out

def get_power(psds, freqs, key):
    """Absolute power of one frequency band.

    The PSD is converted with convert_psds (power estimate, default scaling)
    and integrated over the frequency bins that fall inside ``BANDS[key]``
    (both edges inclusive) using Simpson's rule.

    Parameters
    ----------
    psds : array-like
        PSD values along the last axis, matching `freqs`. A single spectrum
        (n_freqs,) or a stack such as (n_channels, n_freqs) both work.
        Presumably in V²/Hz, so the result is in µV². TODO confirm against
        the data source.
    freqs : array-like, shape (n_freqs,)
        Frequency (Hz) of each PSD bin.
    key : str
        Band name looked up in ``BANDS``.

    Returns
    -------
    float or numpy.ndarray
        Band power, one value per spectrum along the leading axes.

    Raises
    ------
    KeyError
        If `key` is not a band in ``BANDS``.
    ValueError
        If fewer than two frequency bins fall inside the band.
    """
    low, high = BANDS[key]
    freqs = np.asarray(freqs)
    in_band = np.logical_and(freqs >= low, freqs <= high)
    if np.count_nonzero(in_band) < 2:
        raise ValueError('band %r (%s-%s Hz) contains fewer than 2 frequency bins' % (key, low, high))
    band_psds = convert_psds(psds)[..., in_band]
    # Integrate against the bin frequencies themselves instead of an assumed
    # step: the spacing of n bins is (f[-1] - f[0]) / (n - 1), and x= also
    # handles non-uniform frequency grids. `simps` was removed in SciPy 1.14.
    return scipy.integrate.simpson(band_psds, x=freqs[in_band], axis=-1)

# Plot one channel's PSD in dB, shade each EEG band, and list the absolute
# power of every band in the top-right corner.
fig, ax = plt.subplots(figsize=(12, 4))

psds = all_psds[chan]
psds_db = convert_psds(psds, dB=True)

ax.set_xlabel('Frequency (Hz)')
ax.set_ylabel(get_unit_label(dB=True))
ax.set_title('Power spectral density, channel %s' % spectrum.info['ch_names'][chan])
ax.plot(freqs, psds_db, lw=1, color='black')

# Shade the area under the curve for every band, down to the lowest PSD value.
for name, (low, high) in BANDS.items():
    in_band = np.logical_and(freqs >= low, freqs <= high)
    ax.fill_between(freqs, psds_db, y2=psds_db.min(), where=in_band, color=COLORS[name])

# Label the x-axis with band names placed at each band's centre frequency.
ax.tick_params(length=3)
ax.set_xticks([np.average(band) for band in BANDS.values()])
ax.set_xticklabels(list(BANDS), rotation=-60, rotation_mode='anchor', ha='left')

# Band power is the PSD (µV²/Hz) integrated over frequency, so its unit is
# µV², not the PSD unit returned by get_unit_label().
power_unit = r'$\mathrm{µV²}$'
text = ''.join(
    '%s (%.1f-%.1f Hz): %.3f %s\n' % (name, low, high, get_power(psds, freqs, name), power_unit)
    for name, (low, high) in BANDS.items()
)
# x is in data coordinates (Hz), so anchor at the highest frequency rather
# than at the number of bins.
ax.text(freqs[-1], psds_db.max(), text, ha='right', va='top')

plt.show()
plt.close(fig)