In [1]:
# Time-frequency transform of eeg data
import os
import pandas as pd
import numpy as np
import mne
from matplotlib import pyplot as plt
from tqdm import tqdm
import multiprocessing

mne.viz.set_browser_backend('matplotlib')

# Params
ch_list = ['Fp1', 'F3', 'C3', 'P3', 'F7', 'T3', 'T5', 'O1', 'Fz', 'Cz', 'Pz', 'Fp2', 'F4', 'C4', 'P4', 'F8', 'T4', 'T6', 'O2', 'EKG']
ch_types = ['eeg']*(len(ch_list)-1) + ['ecg']
sfreq = 200
n_jobs = 32

# Paths
root = '/media/latlab/MR/projects/kaggle-hms'
data_dir = os.path.join(root, 'data')
results_dir = os.path.join(root, 'results')
out_dir = os.path.join(data_dir, 'eeg')

# Load data
df = pd.read_csv(os.path.join(data_dir, 'train.csv'))
df = df.groupby('eeg_id').head(1).reset_index(drop=True)    # Only keep first row for each eeg_id
eeg_data = np.load(os.path.join(data_dir, 'eeg_data.npy'), allow_pickle=True).item()
display(df)

Using matplotlib as 2D backend.


Unnamed: 0,eeg_id,eeg_sub_id,eeg_label_offset_seconds,spectrogram_id,spectrogram_sub_id,spectrogram_label_offset_seconds,label_id,patient_id,expert_consensus,seizure_vote,lpd_vote,gpd_vote,lrda_vote,grda_vote,other_vote
0,1628180742,0,0.0,353733,0,0.0,127492639,42516,Seizure,3,0,0,0,0,0
1,2277392603,0,0.0,924234,0,0.0,1978807404,30539,GPD,0,0,5,0,1,5
2,722738444,0,0.0,999431,0,0.0,557980729,56885,LRDA,0,1,0,14,0,1
3,387987538,0,0.0,1084844,0,0.0,4099147263,4264,LRDA,0,0,0,3,0,0
4,2175806584,0,0.0,1219001,0,0.0,1963161945,23435,Seizure,3,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
17084,3910994355,0,0.0,2146798838,0,0.0,4272062867,28488,LPD,0,9,0,2,0,7
17085,3938393892,0,0.0,2146798838,1,60.0,2587113091,28488,LPD,0,9,0,2,0,7
17086,1850739625,0,0.0,2146798838,3,162.0,2394534310,28488,LPD,0,9,0,2,0,7
17087,1306668185,0,0.0,2147312808,0,0.0,1216355904,57480,LPD,0,3,0,0,0,0


In [2]:
chunk_size = df.shape[0] // n_jobs
chunks = [df.iloc[df.index[i:i + chunk_size]] for i in range(0, df.shape[0], chunk_size)]
import gc

def get_topomaps(df):
    mne.set_log_level('warning')
    eeg_topomap_data = dict()
    for row in tqdm(df.itertuples(), total=len(df)):
        raw = mne.io.RawArray(eeg_data[row.eeg_id].T, mne.create_info(ch_names=ch_list, sfreq=sfreq, ch_types=ch_types, verbose='warning'))
        raw.set_montage(mne.channels.make_standard_montage('standard_1020'), verbose='warning')
        raw = raw.crop(tmin=row.eeg_label_offset_seconds, tmax=row.eeg_label_offset_seconds + 50, include_tmax=False, verbose='warning')
        spec = raw.compute_psd(fmin=0, fmax=20, n_fft=1024, verbose='warning').get_data()

        fig = plt.figure(figsize=(10,10), frameon=False)
        ax_list = []
        for i in range(25):
            ax = plt.subplot(5,5,i+1)
            ax_list.append(ax)
            mne.viz.plot_topomap(spec[:, (i*4):((i+1)*4)].mean(1), raw.info, ch_type='eeg', sensors=False, outlines=None, cmap='gray', axes=ax, show=False)
        plt.tight_layout(pad=0)

        fig.savefig(f'/tmp/{row.eeg_id}.png', dpi=51.2)
        for ax in ax_list:
            ax.cla()
        plt.cla()
        fig.clf()
        plt.close(fig)
        plt.close('all')
        img = plt.imread(f'/tmp/{row.eeg_id}.png')
        img = img[..., :3].mean(2)
        eeg_topomap_data[row.eeg_id] = img
        del raw, spec, fig, img
        gc.collect()

    return eeg_topomap_data

In [3]:
pool = multiprocessing.Pool(processes=n_jobs)
results = pool.map(get_topomaps, chunks)
eeg_topomap_data = {k: v for d in results for k, v in d.items()}

100%|██████████| 534/534 [18:56<00:00,  2.13s/it]
100%|██████████| 534/534 [18:59<00:00,  2.13s/it]
100%|██████████| 1/1 [00:02<00:00,  2.19s/it]/it]
100%|██████████| 534/534 [19:01<00:00,  2.14s/it]
100%|██████████| 534/534 [19:02<00:00,  2.14s/it]
100%|██████████| 534/534 [19:03<00:00,  2.14s/it]
100%|██████████| 534/534 [19:04<00:00,  2.14s/it]
100%|██████████| 534/534 [19:06<00:00,  2.15s/it]
100%|██████████| 534/534 [19:07<00:00,  2.15s/it]
100%|██████████| 534/534 [19:08<00:00,  2.15s/it]
100%|██████████| 534/534 [19:09<00:00,  2.15s/it]
100%|██████████| 534/534 [19:09<00:00,  2.15s/it]
100%|██████████| 534/534 [19:10<00:00,  2.15s/it]
100%|██████████| 534/534 [19:11<00:00,  2.16s/it]
100%|██████████| 534/534 [19:11<00:00,  2.16s/it]
100%|██████████| 534/534 [19:12<00:00,  2.16s/it]
100%|██████████| 534/534 [19:12<00:00,  2.16s/it]
100%|██████████| 534/534 [19:12<00:00,  2.16s/it]
100%|██████████| 534/534 [19:12<00:00,  2.16s/it]
100%|██████████| 534/534 [19:12<00:00,  2.16s/it]


In [4]:
np.save(os.path.join(data_dir, 'eeg_topomap_data.npy'), eeg_topomap_data)

In [5]:
%%time
test_data = np.load(os.path.join(data_dir, 'eeg_topomap_data.npy'), allow_pickle=True).item()

CPU times: user 150 ms, sys: 12.6 s, total: 12.8 s
Wall time: 30 s
