In [1]:
import os
import logging
import configparser
from collections import defaultdict

import torch
import numpy as np
import pandas as pd
from pytorch_grad_cam import GuidedBackpropReLUModel

from meegnet_functions import load_single_subject
from meegnet.parsing import parser, save_config
from meegnet.network import Model
from meegnet.utils import cuda_check
from meegnet.viz import (
    get_positive_negative_saliency,
    compute_saliency_based_psd,
)


# Root handler prints "<timestamp> <message>"; the project logger propagates to it.
logging.basicConfig(
    format="%(asctime)s %(message)s",
    datefmt="%m/%d/%Y %I:%M:%S %p",
    level=logging.INFO,
)
LOG = logging.getLogger("meegnet")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def compute_saliency_maps(
    dataset,
    labels,
    sub,
    sal_path,
    net,
    threshold,
    w_size,
    sfreq,
    clf_type="",
    compute_psd=False,
    psd_path=None,
):
    """Compute guided-backprop saliency maps for one subject and save them as .npy files.

    A trial is kept only if the network predicts its true label with softmax
    confidence >= ``threshold``. For each kept trial, guided backpropagation is
    computed w.r.t. the true label and split into positive and negative maps.
    Optionally, saliency-based PSDs are computed as well.

    Parameters
    ----------
    dataset : object
        Must expose ``.data`` (trials; each supports ``.to(device)``) and ``.labels``.
    labels : sequence of str
        Label names used in output filenames. For ``clf_type="eventclf"`` entry i
        names class i (two entries expected). Not used for ``"subclf"``.
    sub : str
        Subject identifier, used in filenames and log messages.
    sal_path : str
        Directory where saliency maps are written.
    net : torch.nn.Module
        Trained network.
    threshold : float
        Minimum softmax confidence for a trial to be kept.
    w_size, sfreq
        Forwarded to ``compute_saliency_based_psd``.
    clf_type : str
        ``"eventclf"`` stores saliencies per class label; other values pool all trials.
    compute_psd : bool
        Also compute and save saliency-based PSDs.
    psd_path : str, optional
        Directory for PSD files. If omitted while ``compute_psd`` is True, the
        notebook-level ``psd_path`` global is used (the previous implicit behavior).

    Raises
    ------
    ValueError
        If ``compute_psd`` is True and no PSD directory is available.
    """
    if compute_psd and psd_path is None:
        # Backward compatibility: this function used to read the notebook-level
        # ``psd_path`` global implicitly. Keep that fallback, but fail early and clearly.
        psd_path = globals().get("psd_path")
        if psd_path is None:
            raise ValueError("psd_path must be provided when compute_psd is True")

    device = cuda_check()
    GBP = GuidedBackpropReLUModel(net, device=device)
    per_label = clf_type == "eventclf"

    # Layout: [pos_list, neg_list]. For eventclf there is one such pair per class.
    if per_label:
        target_saliencies = [[[], []], [[], []]]
        target_psd = [[[], []], [[], []]]
    else:
        target_saliencies = [[], []]
        target_psd = [[], []]

    for trial, label in zip(dataset.data, dataset.labels):
        X = trial
        while len(X.shape) < 4:  # prepend axes until the input is 4-D
            X = X[np.newaxis, :]
        X = X.to(device)
        # The prediction only needs a plain forward pass: no autograd graph required.
        with torch.no_grad():
            preds = torch.nn.Softmax(dim=1)(net(X)).cpu()
        pred = preds.argmax().item()
        confidence = preds.max()
        label = int(label)

        # Keep only trials that are both confidently and correctly classified.
        if not (confidence >= threshold and pred == label):
            continue

        # Guided back-propagation for the true label, projected on the input X.
        guided_grads = np.moveaxis(GBP(X, label), 2, 0)
        pos_saliency, neg_saliency = get_positive_negative_saliency(guided_grads)

        sal_bucket = target_saliencies[label] if per_label else target_saliencies
        sal_bucket[0].append(pos_saliency)
        sal_bucket[1].append(neg_saliency)
        if compute_psd:
            psd_bucket = target_psd[label] if per_label else target_psd
            psd_bucket[0].append(
                compute_saliency_based_psd(pos_saliency, X, w_size, sfreq)
            )
            psd_bucket[1].append(
                compute_saliency_based_psd(neg_saliency, X, w_size, sfreq)
            )

    # Count stored maps (one positive + one negative per kept trial). The previous
    # version summed array lengths for non-eventclf tasks instead of counting maps.
    if per_label:
        n_saliencies = sum(len(maps) for pair in target_saliencies for maps in pair)
    else:
        n_saliencies = sum(len(maps) for maps in target_saliencies)
    LOG.info(f"{n_saliencies} saliency maps computed for {sub}")

    def _save(directory, filename, maps):
        # Stack the collected maps into one array per file.
        np.save(os.path.join(directory, filename), np.array(maps))

    for j, sal_type in enumerate(("pos", "neg")):
        if per_label:
            for i, label_name in enumerate(labels):
                _save(
                    sal_path,
                    f"{sub}_{label_name}_{sal_type}_sal_{threshold}confidence.npy",
                    target_saliencies[i][j],
                )
                if compute_psd:
                    _save(
                        psd_path,
                        f"{sub}_{label_name}_{sal_type}_psd_{threshold}confidence.npy",
                        target_psd[i][j],
                    )
        else:
            # NOTE(review): for clf types other than "subclf", the file is named after
            # the label of the *last* trial iterated. Confirm this is intended.
            lab = "" if clf_type == "subclf" else f"_{labels[label]}"
            _save(
                sal_path,
                f"{sub}{lab}_{sal_type}_sal_{threshold}confidence.npy",
                target_saliencies[j],
            )
            if compute_psd:
                _save(
                    psd_path,
                    f"{sub}{lab}_{sal_type}_psd_{threshold}confidence.npy",
                    target_psd[j],
                )

def get_saliency_data(saliency_dict, option):
    """Select or combine positive/negative saliencies according to ``option``.

    Parameters
    ----------
    saliency_dict : mapping
        Has ``"pos"`` and ``"neg"`` entries, each mapping a label to its saliency
        maps (array-like).
    option : {"pos", "neg", "sum", "diff"}
        ``"pos"``/``"neg"`` return that entry unchanged (same object).
        ``"sum"`` returns pos + neg per label; ``"diff"`` returns pos - neg per label.

    Returns
    -------
    mapping
        Label -> saliency data.

    Raises
    ------
    ValueError
        If ``option`` is not supported. Previously any unknown value silently
        behaved like ``"diff"``.
    """
    if option in ("pos", "neg"):
        return saliency_dict[option]
    combine = {"sum": np.add, "diff": np.subtract}.get(option)
    if combine is None:
        raise ValueError(
            f"Unknown saliency option {option!r}; expected 'pos', 'neg', 'sum' or 'diff'"
        )
    return {
        lab: combine(np.array(pos), np.array(saliency_dict["neg"][lab]))
        for lab, pos in saliency_dict["pos"].items()
    }

In [3]:
###########
# PARSING #
###########

# For Jupyter Notebook Only: we create fake args
argstring = "--config ../scripts/eventclf.ini".split()
args = parser.parse_args(argstring)
save_config(vars(args), args.config)
default_values = configparser.ConfigParser()
default_values.read("../default_values.ini")
# configparser returns every value as a string; the numeric entries below are read
# with getint so they can be used in arithmetic (e.g. n_channels // 102).
default_values = default_values["config"]

fold = None if args.fold == -1 else int(args.fold)
if args.clf_type == "eventclf":
    assert (
        args.datatype != "rest"
    ), "datatype must be set to passive in order to run event classification"

# Default trial length depends on the feature type.
if args.feature == "bins":
    trial_length = default_values.getint("TRIAL_LENGTH_BINS")
elif args.feature == "bands":
    trial_length = default_values.getint("TRIAL_LENGTH_BANDS")
elif args.feature == "temporal":
    trial_length = default_values.getint("TRIAL_LENGTH_TIME")

# Subject classification uses fixed-duration segments. This override must come
# after the feature-based default; a duplicated feature block used to follow it
# and silently undo it.
if args.clf_type == "subclf":
    trial_length = int(args.segment_length * args.sfreq)

if args.clf_type == "eventclf":
    labels = ["visual", "auditory"]
else:
    labels = []

if args.sensors == "MAG":
    n_channels = default_values.getint("N_CHANNELS_MAG")
    chan_index = [0]
elif args.sensors == "GRAD":
    n_channels = default_values.getint("N_CHANNELS_GRAD")
    chan_index = [1, 2]
else:
    n_channels = default_values.getint("N_CHANNELS_OTHER")
    chan_index = [0, 1, 2]

input_size = (n_channels // 102, 102, trial_length)

name = f"{args.clf_type}_{args.model_name}_{args.seed}_{args.sensors}"
suffixes = ""
if args.net_option == "custom_net":
    if args.batchnorm:
        suffixes += "_BN"
    if args.maxpool != 0:
        suffixes += f"_maxpool{args.maxpool}"

    name += f"_dropout{args.dropout}_filter{args.filters}_nchan{args.nchan}_lin{args.linear}_depth{args.hlayers}"
    name += suffixes

n_samples = None if int(args.n_samples) == -1 else int(args.n_samples)
if args.clf_type == "subclf":
    data_path = os.path.join(args.save_path, f"downsampled_{args.sfreq}")
    n_subjects = len(os.listdir(data_path))
    n_outputs = min(n_subjects, args.max_subj)
    lso = False
else:
    n_outputs = 2
    lso = True

In [4]:
##############################
### PREPARING SAVE FOLDERS ###
##############################

# Output folders are keyed by the model `name`. The PSD folder is only needed
# when saliency-based PSDs are requested.
folders_to_create = []
if args.compute_psd:
    psd_path = os.path.join(args.save_path, "saliency_based_psds", name)
    folders_to_create.append(psd_path)

sal_path = os.path.join(args.save_path, "saliency_maps", name)
folders_to_create.append(sal_path)

for folder in folders_to_create:
    if not os.path.exists(folder):
        os.makedirs(folder)

In [5]:
#####################
### LOADING MODEL ###
#####################

model_path = args.save_path if args.model_path is None else args.model_path

# NOTE(review): model_path is only used to create a folder here. The Model below is
# built with save_path=args.save_path, and from_pretrained() appears to fetch the
# weights from the hub cache, so model_path has no effect on which weights are
# loaded. Confirm this is intended.
if not os.path.exists(model_path):
    LOG.info(f"{model_path} does not exist. Creating folders")
    os.makedirs(model_path)

my_model = Model(
    name,
    args.net_option,
    input_size,
    n_outputs,
    save_path=args.save_path,
)
my_model.from_pretrained()
# Disabled alternative: my_model.load()

07/16/2024 10:05:22 PM => loading checkpoint '/home/arthur/.cache/huggingface/hub/models--lamaroufle--meegnet/snapshots/5f96fe8d1b9ce85462329cdb3f148e83d3383873/eventclf_meegnet_3_102_400_2.pt'


In [6]:
#######################
### PRELOADING DATA ###
#######################

# Participants table, shuffled reproducibly with the run seed and capped at max_subj.
csv_file = os.path.join(args.save_path, "participants_info.csv")
participants_raw = pd.read_csv(csv_file, index_col=0)
participants_shuffled = participants_raw.sample(frac=1, random_state=args.seed)
dataframe = participants_shuffled.reset_index(drop=True)[: args.max_subj]
subj_list = dataframe["sub"]

# Seeded random choice of the subject used for the single-subject figures.
np.random.seed(args.seed)
subject_indices = np.arange(len(subj_list))
random_subject_idx = np.random.choice(subject_indices)

In [7]:
############################
### COMPUTING SALIENCIES ###
############################

# One subject at a time: load its trials, then write its saliency (and optional PSD) files.
for sub in subj_list:
    dataset = load_single_subject(sub, n_samples, lso, args)
    compute_saliency_maps(
        dataset,
        labels,
        sub,
        sal_path,
        my_model.net,
        threshold=args.confidence,
        w_size=args.w_size,
        sfreq=args.sfreq,
        clf_type=args.clf_type,
        compute_psd=args.compute_psd,
    )

07/16/2024 10:05:23 PM Logging subjects and labels from /home/arthur/data...
07/16/2024 10:05:23 PM Loading subject CC410220
07/16/2024 10:05:25 PM 206 saliency maps computed for CC410220
07/16/2024 10:05:25 PM Logging subjects and labels from /home/arthur/data...
07/16/2024 10:05:25 PM Loading subject CC221324
07/16/2024 10:05:26 PM 238 saliency maps computed for CC221324
07/16/2024 10:05:26 PM Logging subjects and labels from /home/arthur/data...
07/16/2024 10:05:26 PM Loading subject CC310252
07/16/2024 10:05:27 PM 218 saliency maps computed for CC310252
07/16/2024 10:05:27 PM Logging subjects and labels from /home/arthur/data...
07/16/2024 10:05:27 PM Loading subject CC420167
07/16/2024 10:05:28 PM 210 saliency maps computed for CC420167
07/16/2024 10:05:28 PM Logging subjects and labels from /home/arthur/data...
07/16/2024 10:05:28 PM Loading subject CC722891
07/16/2024 10:05:29 PM 238 saliency maps computed for CC722891
07/16/2024 10:05:30 PM Logging subjects and labels from /hom

KeyboardInterrupt: 

In [None]:
#########################
### HARD CODED VALUES ###
#########################

# TODO move these into the .ini configuration (config or default_values).
# Sensor names passed to the figure generator (strings must match what it expects).
sensors = "MAG PLANNAR1 PLANNAR2".split()

# Saliency maps are stored per sign; figures also show their sum and difference.
saliency_types = ("pos", "neg")
saliency_options = saliency_types + ("sum", "diff")

cmap = "coolwarm"
# NOTE(review): presumably the time index of stimulus onset in the plotted trials. Confirm.
stim_tick = 75

# Some tested alternatives for the colormap:
# cmap = sns.color_palette("icefire", as_cmap=True)
# cmap = sns.color_palette("coolwarm", as_cmap=True, center="dark")
# cmap = "inferno"
# cmap = "seismic"

In [None]:
##########################
### GENERATING FIGURES ###
##########################

# NOTE(review): `generate_saliency_figure` and `visu_path` are not defined or imported
# anywhere in this notebook. They must be brought into scope before this cell runs.


def _save_saliency_figure(data, suffix, title):
    """Render one saliency figure with the notebook's shared settings and log its path."""
    out_path = generate_saliency_figure(
        data,
        info_path=args.raw_path,
        save_path=visu_path,
        suffix=suffix,
        sensors=sensors,
        title=title,
        clf_type=args.clf_type,
        cmap=cmap,
        stim_tick=stim_tick,
    )
    LOG.info(f"Figure generated: {out_path}")
    return out_path


def _subject_saliencies_usable(sub_saliencies):
    """Return True if pos and neg maps exist for the same labels and none of them is empty."""
    pos = sub_saliencies.get("pos", {})
    neg = sub_saliencies.get("neg", {})
    if not pos or pos.keys() != neg.keys():
        return False
    return all(np.asarray(maps).size > 0 for maps in (*pos.values(), *neg.values()))


# all_saliencies[saliency_type][label] -> list of per-subject mean saliency maps.
all_saliencies = defaultdict(lambda: defaultdict(list))

LOG.info(f"Generating figure for sensors: {sensors}")
LOG.info(f"For the {args.clf_type} classification")

for i, sub in enumerate(subj_list):
    # sub_saliencies[saliency_type][label] -> all saliency maps of this subject.
    sub_saliencies = defaultdict(dict)
    # subclf files carry no label suffix, and the only relevant "label" is the subject
    # itself. `labels` is empty for subclf, so filtering `labels` by `label == sub`
    # never loaded anything.
    sub_labels = [sub] if args.clf_type == "subclf" else labels
    for label in sub_labels:
        lab = "" if args.clf_type == "subclf" else f"_{label}"
        for saliency_type in saliency_types:
            saliency_file = os.path.join(
                sal_path,
                f"{sub}{lab}_{saliency_type}_sal_{args.confidence}confidence.npy",
            )
            if not os.path.exists(saliency_file):
                continue
            try:
                saliencies = np.load(saliency_file)
            except IOError:
                LOG.warning(f"Error loading {saliency_file}")
                continue
            sub_saliencies[saliency_type][label] = saliencies
            if saliencies.ndim == 3:
                saliencies = saliencies[np.newaxis, ...]  # If only one saliency in file
            elif saliencies.ndim != 4:
                continue  # e.g. an empty file (no trial passed the threshold)
            all_saliencies[saliency_type][label].append(saliencies.mean(axis=0))

    if i != random_subject_idx:
        continue
    if not _subject_saliencies_usable(sub_saliencies):
        # Nothing plottable for this subject: move the single-subject figures to the
        # next subject instead of plotting empty data.
        random_subject_idx += 1
        continue
    for option in saliency_options:
        data_dict = get_saliency_data(sub_saliencies, option)
        single_trial = {
            key: val[np.random.choice(np.arange(len(val)))]
            for key, val in data_dict.items()
        }
        _save_saliency_figure(
            single_trial,
            f"{args.clf_type}_{sub}_single_trial_{option}",
            f"{option} saliencies for a single trial of subject {sub}",
        )
        averaged = {key: np.mean(val, axis=0) for key, val in data_dict.items()}
        _save_saliency_figure(
            averaged,
            f"{args.clf_type}_{sub}_all_trials_{option}",
            f"{option} saliencies for the averaged trials of subject {sub}",
        )

# Turn each per-label list of subject means into a single array.
for per_label in all_saliencies.values():
    for label in list(per_label):
        per_label[label] = np.array(per_label[label])

for option in saliency_options:
    data_dict = get_saliency_data(all_saliencies, option)
    if not data_dict:
        LOG.warning(f"No saliency maps available for option {option}; skipping group figure")
        continue
    # np.newaxis here is a quick fix to a problem that might stick with other clf types
    final_dict = {key: np.mean(val, axis=0)[np.newaxis] for key, val in data_dict.items()}
    if args.clf_type == "subclf":
        # Collapse per-subject averages into one grand average. The global `labels`
        # is intentionally left untouched so that re-running the notebook stays consistent.
        final_dict = {"all_subj": np.mean(list(final_dict.values()), axis=0)}

    _save_saliency_figure(
        final_dict,
        f"{args.clf_type}_{option}",
        f"{option} saliencies averaged across all subjects",
    )

