In [572]:
import glob
import os

import matplotlib.pyplot as plt
import mne
import numpy as np
import pandas as pd
import PyQt5
import seaborn as sns
from scipy import io, stats

In [573]:
# Switch matplotlib to the QtAgg backend.
# NOTE(review): presumably needed for the interactive `raw.plot(block=True, ...)` calls
# that are currently commented out further down — confirm it is still required.
plt.switch_backend('QtAgg')

In [574]:
def load_data(cnt_data, mrk_data, session = 1):
    """Flatten one session of the nested ``cnt`` / ``mrk`` structures into a dict.

    Parameters
    ----------
    cnt_data : mapping
        Object holding the continuous recording under the key ``"cnt"``
        (in this notebook: the result of ``io.loadmat`` on ``cnt.mat``).
    mrk_data : mapping
        Object holding the markers under the key ``"mrk"``
        (in this notebook: the result of ``io.loadmat`` on ``mrk.mat``).
    session : int, optional
        1-based session number; entry ``[0, session - 1]`` is read from both inputs.

    Returns
    -------
    dict
        ``clab``      - list of channel labels,
        ``fs``        - the scalar stored in the ``fs`` field,
        ``x``         - the ``x`` field, returned untouched,
        ``time``      - marker times multiplied by 1e-3 (ms -> s),
        ``y``         - the ``y`` field, returned untouched,
        ``event``     - list with the first item of every ``event`` entry,
        ``className`` - list of class names.
    """
    def first_items(entries):
        # Every entry is itself a one-element container; keep only its content.
        return [entry[0] for entry in entries]

    # Callers count sessions from 1, the containers are indexed from 0.
    session_idx = session - 1
    recording = cnt_data["cnt"][0, session_idx]
    markers = mrk_data["mrk"][0, session_idx]

    channel_labels = first_items(recording["clab"][0, 0][0])
    sampling_rate = recording["fs"][0, 0][0, 0]
    samples = recording["x"][0, 0]

    marker_times_s = markers["time"][0, 0][0] * 1e-3    # conversion from 'ms' to 's'
    labels = markers["y"][0, 0]
    event_codes = first_items(markers["event"][0, 0][0, 0][0])
    class_names = first_items(markers["className"][0, 0][0])

    return {
        "clab": channel_labels,
        "fs": sampling_rate,
        "x": samples,
        "time": marker_times_s,
        "y": labels,
        "event": event_codes,
        "className": class_names,
    }

In [575]:
# Which experiment to process: 'A' selects sessions 1/3/5, any other value sessions 2/4/6.
dataset = 'A'

# Subject id, interpolated into the input and output file paths.
subject = '29'

# Epoch window in seconds, relative to each event marker.
tmin = -5.0
tmax = 20.0

# Baseline-correction interval in seconds; None stands for the start of the epoch.
baseline = (None, -2.0)

In [576]:
# Measurement pairs of the recording, in file order, with the "S<source>_D<detector>"
# identifier used for the corresponding MNE channel name.
_pair_to_optodes = [
    ('AF7Fp1', 'S1_D0'),
    ('AF3Fp1', 'S2_D0'),
    ('AF3AFz', 'S2_D1'),
    ('FpzFp1', 'S3_D0'),
    ('FpzAFz', 'S3_D1'),
    ('FpzFp2', 'S3_D2'),
    ('AF4AFz', 'S4_D1'),
    ('AF4Fp2', 'S4_D2'),
    ('AF8Fp2', 'S5_D2'),
    ('OzPOz', 'S6_D3'),
    ('OzO1', 'S6_D5'),
    ('OzO2', 'S6_D6'),
    ('C5CP5', 'S8_D4'),
    ('C5FC5', 'S8_D8'),
    ('C5C3', 'S8_D9'),
    ('FC3FC5', 'S9_D8'),
    ('FC3C3', 'S9_D9'),
    ('FC3FC1', 'S9_D10'),
    ('CP3CP5', 'S10_D4'),
    ('CP3C3', 'S10_D9'),
    ('CP3CP1', 'S10_D11'),
    ('C1C3', 'S11_D9'),
    ('C1FC1', 'S11_D10'),
    ('C1CP1', 'S11_D11'),
    ('C2FC2', 'S12_D12'),
    ('C2CP2', 'S12_D13'),
    ('C2C4', 'S12_D14'),
    ('FC4FC2', 'S13_D12'),
    ('FC4C4', 'S13_D14'),
    ('FC4FC6', 'S13_D15'),
    ('CP4CP6', 'S14_D7'),
    ('CP4CP2', 'S14_D13'),
    ('CP4C4', 'S14_D14'),
    ('C6CP6', 'S15_D7'),
    ('C6C4', 'S15_D14'),
    ('C6FC6', 'S15_D15'),
]

# Maps "<pair><low|high>WL" to "S<src>_D<det> <wavelength>".
# Insertion order matters to the loading cell: all low-wavelength (760) entries come
# first, followed by all high-wavelength (850) entries in the same pair order.
channels = {
    f'{pair}{band}WL': f'{optodes} {wavelength}'
    for band, wavelength in (('low', 760), ('high', 850))
    for pair, optodes in _pair_to_optodes
}

In [577]:
# Load the optode layout of the selected subject.
# Forward slashes resolve on every OS (Windows included); the previous "\\" separators
# only worked on Windows.
mnt = io.loadmat(f'../dataset/NIRS_01-29/subject {subject}/mnt.mat')

# Optode name ("S<idx>" / "D<idx>") -> 3-D position, as expected by make_dig_montage.
ch_pos = {}

source_optodes = mnt["mnt"]["source"][0,0]
detector_optodes = mnt["mnt"]["detector"][0,0]

sources = [label[0] for label in source_optodes["clab"][0,0][0]]
detectors = [label[0] for label in detector_optodes["clab"][0,0][0]]

for prefix, labels, optodes in (("S", sources, source_optodes), ("D", detectors, detector_optodes)):
    positions = optodes["pos_3d"][0,0]    # one column per optode
    for idx, label in enumerate(labels):
        # Entries labelled "-" are skipped; the index of the remaining ones is kept as-is
        # so the names line up with the "S<idx>_D<idx>" channel names.
        if label != "-":
            # NOTE(review): the 0.1 factor is presumably a unit conversion to what MNE
            # expects — confirm against the dataset description.
            ch_pos[f"{prefix}{idx}"] = positions[:, idx] * 1e-1

montage = mne.channels.make_dig_montage(ch_pos=ch_pos)

In [578]:
%%capture

raws = []

# Forward slashes resolve on every OS (Windows included); the previous "\\" separators
# only worked on Windows.
subject_dir = f'../dataset/NIRS_01-29/subject {subject}'
cnt = io.loadmat(f'{subject_dir}/cnt.mat')
mrk = io.loadmat(f'{subject_dir}/mrk.mat')

# session 1,3,5 for Dataset A (lhmi/rhmi)
# session 2,4,6 for Dataset B (ma/baseline)
sessions = [1,3,5] if dataset == 'A' else [2,4,6]

# `channels` lists every low-wavelength entry first and every high-wavelength entry
# second, so the partner of entry i sits n_pairs positions further down.
ch_names = list(channels.values())
n_pairs = len(ch_names) // 2

# Interleave the names so both wavelengths of one source-detector pair are adjacent:
# [pair0 760, pair0 850, pair1 760, pair1 850, ...]
sorted_ch_names = [
    name
    for wavelength_pair in zip(ch_names[:n_pairs], ch_names[n_pairs:])
    for name in wavelength_pair
]
ch_types = ["fnirs_cw_amplitude"] * len(sorted_ch_names)

for session in sessions:
    task = load_data(cnt, mrk, session)

    sfreq = task["fs"]
    data = task["x"].transpose()    # rows = channels, columns = samples

    if data.shape[0] != len(ch_names):
        raise ValueError(
            f"session {session}: recording has {data.shape[0]} channels, "
            f"but {len(ch_names)} channel names are configured"
        )

    # Reorder the rows exactly like the names above.
    sorted_data = np.empty(data.shape)
    sorted_data[0::2] = data[:n_pairs]
    sorted_data[1::2] = data[n_pairs:]

    # One annotation per marker; every trial is annotated with a 10 s duration.
    onset = task["time"]
    duration = [10] * len(onset)
    description = ["cond1" if is_cond1 else "cond2" for is_cond1 in task["y"][0]]

    # For Dataset A => cond1 = lhmi and cond2 = rhmi
    # For Dataset B => cond1 = ma and cond2 = baseline

    info = mne.create_info(ch_names=sorted_ch_names, sfreq=sfreq, ch_types=ch_types)

    # Store each channel's wavelength (taken from the name suffix) in slot 9 of its
    # location vector.
    # NOTE(review): this is where MNE's fNIRS routines appear to read the wavelength
    # from — confirm against the MNE documentation for the installed version.
    for ch_idx, ch_name in enumerate(sorted_ch_names):
        info["chs"][ch_idx]["loc"][9] = 760 if ch_name.endswith("760") else 850

    annotations = mne.Annotations(onset=onset, duration=duration, description=description)

    raw = mne.io.RawArray(data=sorted_data, info=info)

    raw.set_annotations(annotations)
    raw.set_montage(montage)

    raws.append(raw)

# Join the per-session recordings into one continuous Raw object.
raw = mne.concatenate_raws(raws)

In [579]:
# raw.plot(block=True, duration=500, n_channels=len(raw.ch_names)//3)

In [580]:
# dists = mne.preprocessing.nirs.source_detector_distances(raw.info)
# dists

In [581]:
# Convert the raw fNIRS amplitude recording to optical density.
# NOTE: `raw` is rebound to the converted data, so re-running this cell on its own
# feeds already-converted data back in — re-run from the loading cell instead.
raw = mne.preprocessing.nirs.optical_density(raw)

In [582]:
# raw.plot(block=True, duration=500)

In [583]:
# Scalp coupling index of the optical-density data; indexed per channel further down
# to decide which channels get marked as bad.
sci = mne.preprocessing.nirs.scalp_coupling_index(raw)
# fig, ax = plt.subplots(layout="constrained")
# ax.hist(sci)
# ax.set(xlabel="Scalp Coupling Index", ylabel="Count", xlim=[0, 1])
# plt.show()

In [584]:
# Flag every channel whose scalp coupling index is below 0.5 as bad.
raw.info["bads"] = [
    ch_name
    for ch_name, coupling in zip(raw.ch_names, sci)
    if coupling < 0.5
]

In [585]:
# raw.plot()

In [586]:
# Apply the modified Beer-Lambert law to the optical-density data, with a partial
# pathlength factor (ppf) of 0.1.
# NOTE(review): this call emitted a warning when the notebook was last run (see the
# cell output) — confirm that ppf=0.1 is the intended value.
raw = mne.preprocessing.nirs.beer_lambert_law(raw, ppf=0.1)

  raw = mne.preprocessing.nirs.beer_lambert_law(raw, ppf=0.1)


In [587]:
%%capture

# raw_unfiltered = raw.copy()

# Band-pass filter `raw` between 0.01 Hz and 0.1 Hz with an IIR filter.
# picks="all" is passed, so no channel is excluded from filtering.
# NOTE(review): iir_params=None leaves the filter design to MNE's default — confirm
# that default suits this pass band.
raw.filter(picks="all", l_freq=0.01, h_freq=0.1, method="iir" , iir_params=None)     # band pass filters

# fig, ax = plt.subplots(2)

# raw_unfiltered.plot_psd(ax=ax[0], show=False)
# raw.plot_psd(ax=ax[1], show=False)

# ax[0].set_title('PSD before filtering')
# ax[1].set_title('PSD after filtering')
# ax[1].set_xlabel('Frequency (Hz)')

# fig.set_tight_layout(True)

# plt.show()  # block execution and analyze plots

In [588]:
# Derive the events and the event-id mapping from the annotations attached to `raw`
# (descriptions 'cond1' / 'cond2', set when the sessions were loaded).
events, event_id = mne.events_from_annotations(raw)

Used Annotations descriptions: ['cond1', 'cond2']


In [589]:
# Cut the continuous recording into epochs around every event, using the window
# (tmin, tmax) and the baseline interval defined in the configuration cell.
epochs = mne.Epochs(
    raw,
    events=events,
    event_id=event_id,
    baseline=baseline,
    tmin=tmin,
    tmax=tmax,
    reject_by_annotation=True,
)

Not setting metadata
60 matching events found
Setting baseline interval to [-5.0, -2.0] s
Applying baseline correction (mode: mean)
0 projection items activated


In [590]:
# Write the epochs to disk, replacing any file from a previous run.
epochs_file = f"epochs_fnirs/Dataset_{dataset}/subject_{subject}_epo.fif"

# Create the output folder on first use so a fresh checkout does not fail on save.
os.makedirs(os.path.dirname(epochs_file), exist_ok=True)

epochs.save(epochs_file, overwrite=True)

Overwriting existing file.
Using data from preloaded Raw for 60 events and 251 original time points ...
0 bad epochs dropped
Using data from preloaded Raw for 1 events and 251 original time points ...
Overwriting existing file.
Using data from preloaded Raw for 60 events and 251 original time points ...
