# Saliency maps tutorial

In order to generate saliency maps, you need the dataset in the proper format (see the [prepare data tutorial]()) and an architecture, trained or not (saliency maps computed with an untrained architecture should look like noise).

Before generating visualisations of the saliency maps, we first need to compute the saliency maps themselves.

## Generate Saliency maps

First, we set up all the needed imports.

In [1]:
import logging
import os
from collections import defaultdict

import numpy as np
import pandas as pd
from meegnet.dataloaders import Dataset
from meegnet.network import Model
from meegnet.parsing import parser, save_config
from meegnet.viz import compute_saliency_maps, generate_saliency_figure
from meegnet_functions import load_single_subject

ModuleNotFoundError: No module named 'huggingface_hub'

This next section sets up all the parameters we will need to compute the saliency maps.

In [2]:
# Input layout: 3 channel types (MAG, GRAD, GRAD) x 102 sensor locations
# (Elekta Neuromag Vector View, 306-channel MEG) x 400 time samples
# (800 ms of signal sampled at 500 Hz).
n_channels = 3
input_size = (n_channels, 102, 400)

n_outputs = 2  # auditory vs visual stimulus classification -> 2 classes
n_subjects = 100  # for this tutorial we will only use a fraction of the data
n_samples = None  # None: use all trials for each subject

# Setting up paths. "~" is expanded to the current user's home directory so the
# tutorial is not tied to one specific account; edit save_path to point at the
# folder that holds your prepared data.
save_path = os.path.expanduser("~/data/")
model_path = save_path

clf_type = "eventclf"

OSError: [Errno 45] Operation not supported: '/home/arthur'

We load a pretrained model using the `from_pretrained` method. Another model can be loaded with the `load` method instead. It is also possible to comment out both lines in order to use an untrained model.

In [5]:
# Seed for reproducibility (will be used for numpy, pandas, torch, and the meegnet library).
seed = 42

# Network option, given as a string: "meegnet", "eegnet", etc. (see documentation).
net_option = "meegnet"

# Name of the model
name = f"eventclf_{net_option}_{seed}_{n_channels}"

my_model = Model(name, net_option, input_size, n_outputs, save_path=save_path)
# Use from_pretrained() for the pretrained model or load(model_path) for
# another one; comment out both lines to keep an untrained model.
my_model.from_pretrained()
# my_model.load(model_path)

07/16/2024 10:05:22 PM => loading checkpoint '/home/arthur/.cache/huggingface/hub/models--lamaroufle--meegnet/snapshots/5f96fe8d1b9ce85462329cdb3f148e83d3383873/eventclf_meegnet_3_102_400_2.pt'


If the data was set up correctly, we use `participants_info.csv` to generate a subject list and select a random subject for generating figures.

In [6]:
# Read the participants table, shuffle its rows reproducibly with the seed,
# renumber them, and keep the first n_subjects rows.
csv_file = os.path.join(save_path, "participants_info.csv")
participants = pd.read_csv(csv_file, index_col=0)
shuffled = participants.sample(frac=1, random_state=seed).reset_index(drop=True)
dataframe = shuffled[:n_subjects]
subj_list = dataframe["sub"]

# Draw the position (within subj_list) of the subject used for the
# single-subject figures.
np.random.seed(seed)
candidate_positions = np.arange(len(subj_list))
random_subject_idx = np.random.choice(candidate_positions)

Finally, we compute the saliency maps and save them.

In [7]:
# Names for the labels, used when saving and when generating figures
labels = ["visual", "auditory"]

# This will create a saliency maps path inside the save path.
# Please don't change it or it might break later steps that read from sal_path.
sal_path = os.path.join(save_path, "saliency_maps", name)
os.makedirs(sal_path, exist_ok=True)

# Only keep trials with 85% prediction confidence or more
confidence = 0.85

# Load each subject on its own and compute its saliency maps.
for sub in subj_list:
    dataset = Dataset(
        sfreq=500,  # sampling frequency of 500Hz
        n_subjects=n_subjects,
        n_samples=n_samples,
        sensortype="ALL",  # we use MAG GRAD GRAD here
        lso=True,
        random_state=seed,
    )
    dataset.load(save_path, one_sub=sub)
    compute_saliency_maps(
        dataset,
        labels,
        sub,
        sal_path,
        my_model.net,
        threshold=confidence,
        clf_type=clf_type,
    )

07/16/2024 10:05:23 PM Logging subjects and labels from /home/arthur/data...
07/16/2024 10:05:23 PM Loading subject CC410220
07/16/2024 10:05:25 PM 206 saliency maps computed for CC410220
07/16/2024 10:05:25 PM Logging subjects and labels from /home/arthur/data...
07/16/2024 10:05:25 PM Loading subject CC221324
07/16/2024 10:05:26 PM 238 saliency maps computed for CC221324
07/16/2024 10:05:26 PM Logging subjects and labels from /home/arthur/data...
07/16/2024 10:05:26 PM Loading subject CC310252
07/16/2024 10:05:27 PM 218 saliency maps computed for CC310252
07/16/2024 10:05:27 PM Logging subjects and labels from /home/arthur/data...
07/16/2024 10:05:27 PM Loading subject CC420167
07/16/2024 10:05:28 PM 210 saliency maps computed for CC420167
07/16/2024 10:05:28 PM Logging subjects and labels from /home/arthur/data...
07/16/2024 10:05:28 PM Loading subject CC722891
07/16/2024 10:05:29 PM 238 saliency maps computed for CC722891
07/16/2024 10:05:30 PM Logging subjects and labels from /hom

KeyboardInterrupt: 

In [None]:
# Colormap used for the figures. Some tested alternatives:
# cmap = sns.color_palette("icefire", as_cmap=True)
# cmap = sns.color_palette("coolwarm", as_cmap=True, center="dark")
# cmap = "inferno"
# cmap = "seismic"
cmap = "coolwarm"

# Sensor names and the two saliency types used by the figure cell below.
sensors = ["MAG", "PLANNAR1", "PLANNAR2"]
saliency_types = ("pos", "neg")

# The index for the stimulus timing (150ms in a 800ms trial at 500Hz)
stim_tick = 75

# If raw_path is left empty, or set to None, the function will use the
# mne Elekta vectorview 306 layout for sensor location.
raw_path = None

In [None]:
##########################
### GENERATING FIGURES ###
##########################

def get_saliency_data(saliency_dict):
    """Return, for every label, the positive saliencies minus the negative ones.

    ``saliency_dict`` is indexed as ``saliency_dict["pos"][label]`` and
    ``saliency_dict["neg"][label]``. Both entries are converted with
    ``np.array`` and subtracted element-wise. The labels are taken from the
    ``"pos"`` entry, so a label missing from ``"neg"`` raises ``KeyError``.
    """
    difference = {}
    for label in saliency_dict["pos"]:
        positive = np.array(saliency_dict["pos"][label])
        negative = np.array(saliency_dict["neg"][label])
        difference[label] = positive - negative
    return difference

# Logger used for every message emitted by this cell.
LOG = logging.getLogger(__name__)

# Figures are written to a "visualizations" folder inside the save path.
visu_path = os.path.join(save_path, "visualizations")
os.makedirs(visu_path, exist_ok=True)


def _load_subject_saliencies(sub):
    """Load every saved (saliency type, label) array for one subject.

    Returns a nested dict ``{saliency_type: {label: array}}`` in which every
    array is 4-D (a file holding a single 3-D map gets a leading axis of
    length one). Returns ``None`` as soon as one expected file is missing,
    cannot be read, is empty, or does not have 3 or 4 dimensions, so that a
    subject is either used completely or not at all.
    """
    loaded = {saliency_type: {} for saliency_type in saliency_types}
    for label in labels:
        for saliency_type in saliency_types:
            saliency_file = os.path.join(
                sal_path,
                f"{sub}_{label}_{saliency_type}_sal_{confidence}confidence.npy",
            )
            try:
                saliencies = np.load(saliency_file)
            except FileNotFoundError:
                return None
            except (OSError, ValueError):
                LOG.warning(f"Error loading {saliency_file}")
                return None
            if saliencies.ndim == 3:
                saliencies = saliencies[np.newaxis, ...]  # If only one saliency in file
            elif saliencies.ndim != 4:
                return None
            if saliencies.shape[0] == 0:
                # Nothing to average or to pick a trial from.
                return None
            loaded[saliency_type][label] = saliencies
    return loaded


def _make_figure(data, suffix, title):
    """Call generate_saliency_figure with the shared settings and log the result."""
    figure_path = generate_saliency_figure(
        data,
        info_path=raw_path,
        save_path=visu_path,
        suffix=suffix,
        sensors=sensors,
        title=title,
        clf_type=clf_type,
        cmap=cmap,
        stim_tick=stim_tick,
    )
    LOG.info(f"Figure generated: {figure_path}")
    return figure_path


# all_saliencies[saliency_type][label] collects one trial-averaged map per subject.
all_saliencies = defaultdict(lambda: defaultdict(list))

LOG.info(f"Generating figure for sensors: {sensors}")
LOG.info(f"For the {clf_type} classification")

# First load all computed saliencies
for i, sub in enumerate(subj_list):
    sub_saliencies = _load_subject_saliencies(sub)
    if sub_saliencies is None:
        # Unusable subject: leave it out of the group average and, if it was
        # the one picked for the single-subject figures, fall back to the
        # next subject in the list.
        if i == random_subject_idx:
            random_subject_idx += 1
        continue

    for saliency_type, per_label in sub_saliencies.items():
        for label, saliencies in per_label.items():
            all_saliencies[saliency_type][label].append(saliencies.mean(axis=0))

    if i != random_subject_idx:
        continue

    # Single-subject figures: one randomly drawn trial, then the trial average.
    data_dict = get_saliency_data(sub_saliencies)
    temp = {
        key: val[np.random.choice(np.arange(len(val)))]
        for key, val in data_dict.items()
    }
    out_path = _make_figure(
        temp,
        suffix=f"{clf_type}_{sub}_single_trial",
        title=f"Saliencies for a single trial of subject {sub}",
    )
    temp = {key: np.mean(val, axis=0) for key, val in data_dict.items()}
    out_path = _make_figure(
        temp,
        suffix=f"{clf_type}_{sub}_all_trials",
        title=f"Saliencies for the averaged trials of subject {sub}",
    )

# Group figure: average the per-subject maps across subjects.
missing = [
    (saliency_type, label)
    for saliency_type in saliency_types
    for label in labels
    if not all_saliencies[saliency_type][label]
]
if missing:
    LOG.warning(
        f"No usable saliency maps found for {missing}; skipping the group figure"
    )
else:
    for label in labels:
        for saliency_type in saliency_types:
            all_saliencies[saliency_type][label] = np.array(
                all_saliencies[saliency_type][label]
            )

    data_dict = get_saliency_data(all_saliencies)
    final_dict = {
        key: np.mean(val, axis=0)[np.newaxis] for key, val in data_dict.items()
    }

    out_path = _make_figure(
        final_dict,
        suffix=f"{clf_type}",
        title="Saliencies averaged across all subjects",
    )

