In [None]:
import os
from pathlib import Path

import mne
from mne_icalabel import label_components
from mne.preprocessing import ICA
import pandas as pd
from tqdm.notebook import trange

In [None]:
# Root folders: the project tree (hand-made annotations) and the BIDS dataset.
PROJECT_ROOT = Path("PATH/TO/PROJECT/ROOT")
BIDS_DIR = Path("PATH/TO/BIDS/ROOT")

# Per-run annotation files and per-run channel-status tables.
ANNOTATION_DIR = PROJECT_ROOT.joinpath("data", "annotations", "ipu_v1")
TSV_DIR = PROJECT_ROOT.joinpath("data", "annotations", "channels_v1")

# Channel subset of interest, one electrode row per line.
# In this notebook it is only referenced by the commented-out raw.plot(...) calls.
TARGET_CHANNELS = [
    "F5", "F3", "F1", "Fz", "F2", "F4", "F6", "F8",      # F row
    "FC5", "FC3", "FC1", "FCz", "FC2", "FC4", "FC6",     # FC row
    "C5", "C3", "C1", "Cz", "C2", "C4", "C6",            # C row
    "CP5", "CP3", "CP1", "CPz", "CP2", "CP4", "CP6",     # CP row
    "P5", "P3", "P1", "Pz", "P2", "P4", "P6", "P8",      # P row
]

# Everything this notebook writes (ICA file, CSV, figures) goes below DST_DIR.
DST_DIR = Path("PATH/TO/OUTPUT/DIRECTORY")
DST_DIR.mkdir(parents=True, exist_ok=True)

In [None]:
# ---- Session selection ------------------------------------------------------
subject = "sub-003"  # e.g. "sub-003"
task = "conversation"  # either "conversation" or "resting"

# ---- Figure output ----------------------------------------------------------
# Component plots are written into a "figs" folder below the output directory.
fig_dir = DST_DIR.joinpath("figs")
fig_dir.mkdir(parents=True, exist_ok=True)

# ---- ICA decomposition ------------------------------------------------------
# Forwarded unchanged to the ICA constructor further down.
method = "fastica"
fit_params = dict()
n_components = None
random_state = 0  # fixed seed for the decomposition

# ---- Muscle-component detection ---------------------------------------------
threshold = 0.5     # score threshold handed to ica.find_bads_muscle
muscle_l_freq = 7   # lower edge of the band intended for the muscle detector
muscle_h_freq = 40  # upper edge of the band intended for the muscle detector

# ---- Resampling -------------------------------------------------------------
sample_freq = 500  # target sampling rate used by get_run_raw

# ---- Band-pass filter -------------------------------------------------------
l_freq = 1   # high-pass edge used by get_run_raw
h_freq = 40  # low-pass edge used by get_run_raw

In [None]:
def get_run_raw(subject, task, run, duration=4 * 60):
    """Load one run's EDF recording and return it cropped and preprocessed.

    The recording is cropped to ``duration`` seconds starting at the first
    trigger event, annotated with the run's IPU annotation file, given the
    ``biosemi64`` montage, resampled and band-pass filtered.

    Besides its arguments the function reads the notebook-level settings
    ``BIDS_DIR``, ``ANNOTATION_DIR``, ``sample_freq``, ``l_freq`` and
    ``h_freq``.

    Parameters
    ----------
    subject : str
        Subject label used in the directory and file names, e.g. ``"sub-003"``.
    task : str
        Task label used in the EDF file name.
    run
        Run identifier interpolated into the EDF and annotation file names.
    duration : float, optional
        Length in seconds of the window kept after the first trigger.
        Defaults to 4 minutes, the value that was previously hard-coded.

    Returns
    -------
    raw
        The preloaded, cropped, annotated, resampled and filtered recording.

    Raises
    ------
    IndexError
        If no trigger event is found, so the start of the window cannot be
        located.
    """
    # Prepare raw
    edf_path = BIDS_DIR / subject / "eeg" / f"{subject}_task-{task}_run-{run}_eeg.edf"
    raw = mne.io.read_raw(edf_path, preload=True, verbose="ERROR")

    # Start of the conversation: marked by the first trigger event
    events = mne.find_events(raw, verbose="ERROR")
    if len(events) == 0:
        # Same exception type the bare events[0, 0] lookup used to produce,
        # but with a message that names the offending file.
        raise IndexError(
            f"No trigger events found in {edf_path}; cannot locate the start of the run."
        )
    # Event sample numbers include raw.first_samp, whereas crop() expects
    # times relative to the first sample held in `raw`. Subtracting is a
    # no-op whenever first_samp is 0.
    start = (events[0, 0] - raw.first_samp) / raw.info["sfreq"]

    # Crop to the window of interest
    raw = raw.crop(start, start + duration, verbose="ERROR")

    # Annotations: renaming "speech" to "bad_speech" gives the spans the "bad"
    # prefix that MNE's reject_by_annotation handling keys on.
    annotations = mne.read_annotations(ANNOTATION_DIR / f"{subject}_run-{run}_ipu_annot.fif")
    annotations.rename({"speech": "bad_speech"})
    raw.set_annotations(annotations, verbose="ERROR")

    # Add montage; channels without a position in the montage are tolerated
    montage = mne.channels.make_standard_montage("biosemi64")
    raw.set_montage(montage, on_missing="ignore")

    # Downsample
    raw = raw.resample(sfreq=sample_freq, verbose="ERROR")

    # Filter
    raw = raw.filter(l_freq=l_freq, h_freq=h_freq, verbose="ERROR")

    return raw

In [None]:
# Load every run of the session and stitch them into one continuous recording
first = 1
end = 9  # exclusive upper bound of the run range
raw_list = [get_run_raw(subject, task, run) for run in trange(first, end)]
raw = mne.concatenate_raws(raw_list, verbose="ERROR")

# No resampling here: get_run_raw already resamples each run.

# Union of the channels marked "bad" in any run's channel table.
# NOTE(review): the TSV name hard-codes "task-conversation" regardless of
# `task` -- confirm this is intended when task == "resting".
bads = []
for run in range(first, end):
    channel_table = pd.read_csv(
        TSV_DIR / f"{subject}_task-conversation_run-{run}_channels.tsv",
        sep="\t",
    )
    is_bad = channel_table["status"] == "bad"
    bads += channel_table.loc[is_bad, "name"].tolist()

# De-duplicate: a channel can be flagged in several runs
raw.info["bads"] = list(set(bads))

In [None]:
#raw.plot(highpass=0.5, n_channels=64, picks=TARGET_CHANNELS + ["EMG1", "EMG2"], 
#         scalings={"eeg": 30e-6, "misc": 5e-10})

In [None]:
#raw.set_annotations(None)

In [None]:
%%capture
raw.drop_channels(["EMG1", "EMG2", "lEAR", "rEAR", "Status"])
raw = raw.set_eeg_reference("average")

In [None]:
# Fit the ICA decomposition; guard against an unexpected channel count first
assert len(raw.ch_names) == 64

ica = ICA(
    n_components=n_components,
    method=method,
    fit_params=fit_params,
    random_state=random_state,
)
ica.fit(raw)

In [None]:
# Plot the component topographies and save them as PNG files.
# plot_components may hand back one figure or a list of figures; a list gets
# an index in each file name.
figs = ica.plot_components()

if not isinstance(figs, list):
    figs.savefig(fig_dir / f"{subject}_{task}_component.png")
else:
    for page, page_fig in enumerate(figs):
        page_fig.savefig(fig_dir / f"{subject}_{task}_{page}_component.png")

In [None]:
# Classify every component with ICLabel
ic_labels = label_components(raw, ica, method="iclabel")
labels = ic_labels["labels"]

# Any component whose label is neither "brain" nor "other" is marked for exclusion
exclude_idx = []
for idx, label in enumerate(labels):
    if label == "brain" or label == "other":
        continue
    exclude_idx.append(idx)

In [None]:
# Detect muscle-artifact components.
# Use the dedicated muscle band from the parameter cell (muscle_l_freq /
# muscle_h_freq); the band-pass edges l_freq / h_freq belong to raw filtering
# and were passed here by mistake.
muscle_idx_auto, scores = ica.find_bads_muscle(
    raw,
    threshold=threshold,
    l_freq=muscle_l_freq,
    h_freq=muscle_h_freq,
)

In [None]:
# Merge the ICLabel-based and muscle-detector exclusions.
# A component can be flagged by both, so take the union: plain concatenation
# would list it twice and inflate len(ica.exclude). Indices are stored as
# sorted built-in ints.
ica.exclude = sorted({int(idx) for idx in (*exclude_idx, *muscle_idx_auto)})

In [None]:
#ica.plot_sources(raw)

In [None]:
# Per-component summary: ICLabel label and probability, whether the muscle
# detector flagged it, and whether it ended up in ica.exclude.
component_ids = range(len(labels))
df = pd.DataFrame(
    {
        "index": component_ids,
        "labels": labels,
        "probability": list(ic_labels["y_pred_proba"]),
        "muscle": [comp in muscle_idx_auto for comp in component_ids],
        "excluded": [comp in ica.exclude for comp in component_ids],
    }
)
df

In [None]:
# Number of entries in the exclusion list
print(len(ica.exclude))

In [None]:
# Plot the components again now that ica.exclude is set.
# The file names match those of the earlier plotting cell, so the PNGs written
# there are overwritten.
figs = ica.plot_components()

if not isinstance(figs, list):
    figs.savefig(fig_dir / f"{subject}_{task}_component.png")
else:
    for page, page_fig in enumerate(figs):
        page_fig.savefig(fig_dir / f"{subject}_{task}_{page}_component.png")

In [None]:
# ica.apply(raw)

In [None]:
#raw.plot(highpass=0.5, n_channels=64, picks=TARGET_CHANNELS, 
#         scalings={"eeg": 30e-6, "misc": 5e-10})

In [None]:
# Write the fitted ICA object to the output directory, replacing any earlier
# file for this subject/task.
ica_path = DST_DIR / f"{subject}_task-{task}-ica.fif"
ica.save(ica_path, overwrite=True)

In [None]:
# Store the per-component summary table next to the ICA file (no index column)
excluded_csv_path = DST_DIR / f"{subject}_task-{task}_excluded.csv"
df.to_csv(excluded_csv_path, index=False)