# Mount Google Drive

In [None]:
from google.colab import drive

# Google Drive is mounted here; every project path below is built on top of it.
mounted_dir = '/content/drive/'
# force_remount=True mounts again even if Drive is already attached in this session.
drive.mount(mountpoint=mounted_dir, force_remount=True)

In [None]:
import os

# Name of the project folder inside "MyDrive" (an empty string means MyDrive itself).
project_dir = ''

# The project folder must contain three sub-folders:
#   Stimulus - the stimulus images of the experiment
#   MEG      - the MEG files of an anonymous subject
#   Result   - destination for everything this notebook saves
project_root = os.path.join(mounted_dir, 'MyDrive', project_dir)
full_image_path, full_raw_path, full_output_path = (
    os.path.join(project_root, sub_folder) for sub_folder in ('Stimulus', 'MEG', 'Result')
)

# Extract Features from AlexNet

## Install THINGSvision and dependencies

In [None]:
!pip install --upgrade thingsvision

!pip install ipywidgets

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torchvision.models import alexnet
from thingsvision import get_extractor
from thingsvision import get_extractor_from_model
from thingsvision.utils.storing import save_features
from thingsvision.utils.data import ImageDataset, DataLoader
from thingsvision.core.extraction import center_features
from thingsvision.core.rsa import compute_rdm
from typing import Any, List

## Helper function to extract features

In [None]:
def extract_features(
    extractor: Any,
    module_name: str,
    image_path: str,
    out_path: str,
    batch_size: int,
    flatten_activations: bool,
    apply_center_crop: bool,
    class_names: List[str]=None,
    file_names: List[str]=None,
) -> np.ndarray:
    """Run ``extractor`` over the images under ``image_path`` and return one layer's features.

    Args:
        extractor: thingsvision extractor (as created with ``get_extractor``).
        module_name: name of the model layer whose activations are collected.
        image_path: root folder passed to ``ImageDataset``.
        out_path: output folder passed to ``ImageDataset``.
        batch_size: number of images per batch in the ``DataLoader``.
        flatten_activations: forwarded as ``flatten_acts`` to ``extractor.extract_features``.
        apply_center_crop: forwarded to ``extractor.get_transformations``.
        class_names: optional class names passed to ``ImageDataset``.
        file_names: optional file names passed to ``ImageDataset``.

    Returns:
        Whatever ``extractor.extract_features`` returns for ``module_name``.
    """
    # the dataset and the loader must agree on the extractor's backend
    backend = extractor.get_backend()
    image_transforms = extractor.get_transformations(apply_center_crop=apply_center_crop)

    image_dataset = ImageDataset(
        root=image_path,
        out_path=out_path,
        backend=backend,
        transforms=image_transforms,
        class_names=class_names,
        file_names=file_names,
    )
    image_batches = DataLoader(dataset=image_dataset, batch_size=batch_size, backend=backend)

    return extractor.extract_features(
        batches=image_batches,
        module_name=module_name,
        flatten_acts=flatten_activations,
    )

## Variables

In [None]:
# --- feature-extraction settings ---
pretrained = True           # load the model's pretrained weights
model_path = None           # optional path to model weights (e.g. when pretrained is False)
batch_size = 16             # images per batch; any size works, powers of two are customary
apply_center_crop = False   # set to True to center-crop the images
flatten_activations = True  # flatten layer outputs (e.g. of Conv layers) into vectors
class_names = None          # optional list of class names for a class dataset
file_names = None           # optional list of file names that fixes the order of the features

# run on the GPU when one is available
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

## Load AlexNet

In [None]:
# Model and layer whose activations are extracted.
model_name = 'alexnet'
# NOTE(review): presumably the last layer of the `features` stack in torchvision's AlexNet — confirm.
module_name = 'features.12'
# the model comes from torchvision
source = 'torchvision'

extractor = get_extractor(
    model_name=model_name,
    source=source,
    device=device,
    pretrained=pretrained,
    model_path=model_path,
)

## Extract features for a single layer

In [None]:
# Single-layer extraction over every image in the Stimulus folder.
extraction_kwargs = dict(
    extractor=extractor,
    module_name=module_name,
    image_path=full_image_path,
    out_path=full_output_path,
    batch_size=batch_size,
    flatten_activations=flatten_activations,
    apply_center_crop=apply_center_crop,
    class_names=class_names,
    file_names=file_names,
)
features_alexnet = extract_features(**extraction_kwargs)

# Center the features (optional; whether it is desirable depends on the analysis).
features_alexnet = center_features(features_alexnet)

# Write the centered features as .npy into Result/features_<model>_<layer>.
features_out_dir = f'{full_output_path}/features_{model_name}_{module_name}'
save_features(features_alexnet, out_path=features_out_dir, file_format='npy')

## Load extracted features

In [None]:
# Reload the features saved by the extraction cell (file written by save_features).
saved_features_file = f'{full_output_path}/features_{model_name}_{module_name}/features.npy'
features_alexnet = np.load(saved_features_file, encoding='bytes')

## Create RDM

In [None]:
# Average each block of `n_images_per_condition` CONSECUTIVE rows (rows 0-99, 100-199, ...),
# giving one mean feature vector per block.
# NOTE(review): assumes the images of each condition are contiguous in the stimulus folder's
# order and that the blocks follow the MEG event-id order — confirm.
n_images_per_condition = 100
n_images = features_alexnet.shape[0]
if n_images % n_images_per_condition != 0:
    raise ValueError(
        f'expected a multiple of {n_images_per_condition} images, got {n_images}'
    )

# (n_blocks, n_images_per_condition, *feature_dims) -> sum over the images of each block
features_alexnet_sum = features_alexnet.reshape(
    n_images // n_images_per_condition, n_images_per_condition, *features_alexnet.shape[1:]
).sum(axis=1)

# create the RDM from the block means
rdm_dnn = compute_rdm(features_alexnet_sum / n_images_per_condition, method='correlation')

## Visualize RDM

In [None]:
# Show the DNN RDM with a colorbar so the dissimilarity scale can be read off.
fig_rdm_dnn, ax_rdm_dnn = plt.subplots(figsize=(10, 4), dpi=200)
rdm_image = ax_rdm_dnn.imshow(rdm_dnn)
ax_rdm_dnn.set_title(f'{model_name} {module_name} RDM (correlation distance)')
ax_rdm_dnn.set_xlabel('stimulus group')
ax_rdm_dnn.set_ylabel('stimulus group')
fig_rdm_dnn.colorbar(rdm_image, ax=ax_rdm_dnn);

# Extract Features from MEG

## Install MNE-Python and dependencies



In [None]:
!pip install mne

In [None]:
import pandas as pd
import numpy as np
import scipy
import mne
from mne.io import read_raw_fif

## Read the raw files and create epochs

In [None]:
# The subject's recording runs: MEG/raw_1.fif ... MEG/raw_6.fif.
n_runs = 6
run_paths = [os.path.join(full_raw_path, f'raw_{run}.fif') for run in range(1, n_runs + 1)]

# Load every run, keeping magnetometers plus the stim channel (find_events needs it below).
raws = [
    read_raw_fif(path, verbose='error', preload=True).pick_types(meg='mag', stim=True, eeg=False, eog=False)
    for path in run_paths
]

# smooth the data (MNE Savitzky-Golay filter with h_freq=10 Hz)
for run_raw in raws:
    run_raw.savgol_filter(10, verbose=None)

# downsample the data: `decim` is applied when epoching; low-pass first at a third of
# the resulting sampling rate to avoid aliasing
desired_sfreq = 200
current_sfreq = raws[0].info['sfreq']
decim = np.round(current_sfreq / desired_sfreq).astype(int)
obtained_sfreq = current_sfreq / decim
lowpass_freq = obtained_sfreq / 3.
for run_raw in raws:
    run_raw.filter(l_freq=None, h_freq=lowpass_freq)

# concatenate the runs: the first run is extended in place with all the others
raw = raws[0]
raw.append(raws[1:])
del raws

# create events from the trigger channel
events = mne.find_events(raw, stim_channel='STI101', min_duration=.005, shortest_event=1, output='onset')

# create epochs: one condition name per event code 1..32
event_sample_dict = {'N1S1TFA1':1, 'N1S1TFA2':2, 'N1S2TFA1':3, 'N1S2TFA2':4,
                     'N1S3TFA1':5, 'N1S3TFA2':6, 'N1S4TFA1':7, 'N1S4TFA2':8,
                     'N2S1TFA1':9, 'N2S1TFA2':10, 'N2S2TFA1':11, 'N2S2TFA2':12,
                     'N2S3TFA1':13, 'N2S3TFA2':14, 'N2S4TFA1':15, 'N2S4TFA2':16,
                     'N3S1TFA1':17, 'N3S1TFA2':18, 'N3S2TFA1':19, 'N3S2TFA2':20,
                     'N3S3TFA1':21, 'N3S3TFA2':22, 'N3S4TFA1':23, 'N3S4TFA2':24,
                     'N4S1TFA1':25, 'N4S1TFA2':26, 'N4S2TFA1':27, 'N4S2TFA2':28,
                     'N4S3TFA1':29, 'N4S3TFA2':30, 'N4S4TFA1':31, 'N4S4TFA2':32}
# NOTE: no picks are given, so the epochs presumably keep every channel of `raw`, including
# the stim channel — pick 'mag' explicitly when extracting data for analysis.
epochs = mne.Epochs(raw, events, event_id=event_sample_dict, baseline=(-0.1, 0), tmin=-0.1, tmax=0.8, decim=decim)
# drop epochs so that every condition has the same number of epochs
epochs.equalize_event_counts()

## Visualize the evoked response averaged over all epochs

In [None]:
# Average of all epochs, shown as a joint (butterfly + topomap) plot of the magnetometers.
grand_evoked = epochs.average()
grand_evoked.plot_joint(picks='mag')

In [None]:
# get_data returns a 3-D array: (epochs, magnetometer channels, time points)
mag_data_shape = epochs.get_data(picks='mag').shape
n_events, n_channels, n_times = mag_data_shape
print(f"Number of Events: {n_events}, Number of Channels (Mag): {n_channels}, Time Points: {n_times}")

## Install RSAToolbox and dependencies


In [None]:
!pip install rsatoolbox

In [None]:
import rsatoolbox
from rsatoolbox.rdm import calc_rdm_movie
from scipy.spatial.distance import squareform

## Create RDMs

In [None]:
# Condition names in event-id order (N1S1TFA1 = 1, ..., N4S4TFA2 = 32); this is also
# their alphabetical order.
event_sample_list = ['N1S1TFA1', 'N1S1TFA2', 'N1S2TFA1', 'N1S2TFA2',
                      'N1S3TFA1', 'N1S3TFA2', 'N1S4TFA1', 'N1S4TFA2',
                      'N2S1TFA1', 'N2S1TFA2', 'N2S2TFA1', 'N2S2TFA2',
                      'N2S3TFA1', 'N2S3TFA2', 'N2S4TFA1', 'N2S4TFA2',
                      'N3S1TFA1', 'N3S1TFA2', 'N3S2TFA1', 'N3S2TFA2',
                      'N3S3TFA1', 'N3S3TFA2', 'N3S4TFA1', 'N3S4TFA2',
                      'N4S1TFA1', 'N4S1TFA2', 'N4S2TFA1', 'N4S2TFA2',
                      'N4S3TFA1', 'N4S3TFA2', 'N4S4TFA1', 'N4S4TFA2']

# Label every epoch with the condition that triggered it. Epochs are in recording order and
# there may be several per condition, so the 32 names above cannot be used as per-epoch labels.
event_id_to_name = {code: name for name, code in epochs.event_id.items()}
# Sort the epochs by event code (stably) so the conditions appear in event-id order, whether
# rsatoolbox orders averaged conditions by first appearance or by sorted value.
epoch_order = np.argsort(epochs.events[:, 2], kind='stable')
epoch_stimulus = [event_id_to_name[code] for code in epochs.events[epoch_order, 2]]
assert set(epoch_stimulus) <= set(event_sample_list), 'epochs contain unexpected conditions'

# Magnetometers only: the epochs also hold the stim channel, whose value is the trigger
# code and would leak condition identity into the distances.
mag_picks = mne.pick_types(epochs.info, meg='mag')
mag_names = [epochs.ch_names[idx] for idx in mag_picks]
mag_data = epochs.get_data(picks=mag_picks)[epoch_order]

data = rsatoolbox.data.TemporalDataset(mag_data,
                                       channel_descriptors={'names': mag_names},
                                       obs_descriptors={'stimulus': epoch_stimulus},
                                       time_descriptors={'time': epochs.times})

# descriptor='stimulus' averages the epochs of each condition before computing distances,
# giving one condition-by-condition RDM per time point (comparable to the model RDM).
rdms = rsatoolbox.rdm.calc.calc_rdm_movie(data,
                                          method='euclidean',
                                          descriptor='stimulus',
                                          noise=None,
                                          cv_descriptor=None,
                                          prior_lambda=1,
                                          prior_weight=0.1,
                                          time_descriptor='time',
                                          bins=None)

## Helper function to visualize the RDMs

In [None]:
from typing import List, Optional
from scipy.spatial.distance import squareform

def plot_rdm_movie(rdms_data: rsatoolbox.rdm.RDMs,
                   descriptor: str,
                   n_t_display:int = 20,
                   fig_width: Optional[int] = None,
                   timecourse_plot_rel_height: Optional[int] = None,
                   time_formatted: Optional[List[str]] = None,
                   colored_conditions: Optional[list] = None,
                   plot_individual_dissimilarities: Optional[bool] = None,
                   ):
    """Plot the dissimilarity time courses of an RDM movie above a row of selected RDMs.

    Args:
        rdms_data (rsatoolbox.rdm.RDMs): RDM movie; must carry a 'time' rdm descriptor.
        descriptor (str): name of the descriptor that created the RDM movie.
            Not used by this function; kept for interface compatibility.
        n_t_display (int, optional): maximum number of time points for which an RDM is drawn.
            Defaults to 20.
        fig_width (int, optional): width of the figure (in inches). Defaults to None (40 inches).
        timecourse_plot_rel_height (int, optional): height of the time-course panel, in units of
            the RDM row height. Defaults to None (n_t_display // 3, at least 1).
        time_formatted (List[str], optional): one label per unique time point.
            Defaults to None (rdms_data.rdm_descriptors['time'] is assumed to be in seconds
            and shown in ms).
        colored_conditions (list, optional): one condition label per pattern; dissimilarities are
            grouped and colored by the pair of labels they connect. Defaults to None.
        plot_individual_dissimilarities (bool, optional): whether to draw every individual
            dissimilarity time course. Defaults to None (False if colored_conditions is given).
            Always True when colored_conditions is None.

    Returns:
        Tuple[matplotlib.figure.Figure, List[matplotlib.axes.Axes]]: the figure, and a list with
        the time-course axis followed by one axis per displayed RDM.
    """
    # one 'time' value per RDM; several RDMs may share the same time point
    times = np.asarray(rdms_data.rdm_descriptors['time'])
    unique_time = np.unique(times)
    time_formatted = time_formatted or ['%0.0f ms' % (np.round(x*1000,2)) for x in unique_time]

    n_dissimilarity_elements = rdms_data.dissimilarities.shape[1]

    # color mapping from colored conditions: group the dissimilarities by the (unordered)
    # pair of condition labels they connect
    if colored_conditions is not None:
        if plot_individual_dissimilarities is None:
            plot_individual_dissimilarities = False
        unsquareform = lambda a: a[np.nonzero(np.triu(a, k=1))]
        pairwise_conds = unsquareform(np.array([[{c1, c2} for c1 in colored_conditions] for c2 in colored_conditions]))
        pairwise_conds_unique = np.unique(pairwise_conds)
        color_index = {f'{list(x)[0]} vs {list(x)[1]}' if len(list(x))==2 else f'{list(x)[0]} vs {list(x)[0]}': pairwise_conds==x for x in pairwise_conds_unique}
    else:
        color_index = {'': np.array([True]*n_dissimilarity_elements)}
        plot_individual_dissimilarities = True

    colors = plt.get_cmap('turbo')(np.linspace(0, 1, len(color_index)+1))

    # choose at most n_t_display evenly spaced time points for the RDM row
    t_display_idx = np.round(np.linspace(0, len(unique_time)-1, min(len(unique_time), n_t_display))).astype(int)
    t_display_idx = np.unique(t_display_idx)
    n_t_display = len(t_display_idx)

    # auto determine relative sizes of the axes; the time-course panel needs at least one grid row
    timecourse_plot_rel_height = timecourse_plot_rel_height or max(1, n_t_display // 3)
    base_size = 40 / n_t_display if fig_width is None else fig_width / n_t_display

    # figure layout: time courses on top, one small RDM per displayed time point below
    fig = plt.figure(constrained_layout=True, figsize=(base_size*n_t_display, base_size*timecourse_plot_rel_height))
    gs = fig.add_gridspec(timecourse_plot_rel_height+1, n_t_display)
    tc_ax = fig.add_subplot(gs[:-1, :])
    rdm_axes = [fig.add_subplot(gs[-1, i]) for i in range(n_t_display)]

    # mean of each dissimilarity over all RDMs sharing a time point
    dissimilarities_mean = np.zeros((n_dissimilarity_elements, len(unique_time)))
    for i, t in enumerate(unique_time):
        dissimilarities_mean[:, i] = np.mean(rdms_data.dissimilarities[times == t, :], axis=0)

    def _plot_mean_dissimilarities(labels=False):
        # mean +/- SE time course of every color group
        for i, (pairwise_name, idx) in enumerate(color_index.items()):
            group = dissimilarities_mean[idx, :]
            mn = np.mean(group, axis=0)
            # SE across the dissimilarities of this group (not across subjects)
            se = np.std(group, axis=0) / np.sqrt(np.count_nonzero(idx))
            tc_ax.fill_between(unique_time, mn-se, mn+se, color=colors[i], alpha=.3)
            tc_ax.plot(unique_time, mn, color=colors[i], linewidth=2, label=pairwise_name if labels else None)

    def _plot_individual_dissimilarities():
        # one faint line per dissimilarity
        for i, (pairwise_name, idx) in enumerate(color_index.items()):
            tc_ax.plot(unique_time, dissimilarities_mean[idx, :].T, color=colors[i], alpha=max(1/255., 1/n_dissimilarity_elements))

    if plot_individual_dissimilarities:
        if colored_conditions is not None:
            # fix the y-limits to the group means so individual lines do not dominate the scale
            _plot_mean_dissimilarities()
            yl = tc_ax.get_ylim()
            _plot_individual_dissimilarities()
            tc_ax.set_ylim(yl)
        else:
            _plot_individual_dissimilarities()

    if colored_conditions is not None:
        _plot_mean_dissimilarities(True)

    # mark the displayed time points on the time-course panel
    yl = tc_ax.get_ylim()
    for t in unique_time[t_display_idx]:
        tc_ax.plot([t, t], yl, linestyle=':', color='b', alpha=0.3)
    tc_ax.set_ylabel(f'Dissimilarity\n({rdms_data.dissimilarity_measure})')
    tc_ax.set_xticks(unique_time)
    tc_ax.set_xticklabels([time_formatted[idx] if idx in t_display_idx else '' for idx in range(len(unique_time))])
    if n_t_display > 1:
        dt = np.diff(unique_time[t_display_idx])[0]
        tc_ax.set_xlim(unique_time[t_display_idx[0]]-dt/2, unique_time[t_display_idx[-1]]+dt/2)

    # only the colored group means carry labels
    if colored_conditions is not None:
        tc_ax.legend()

    # display (selected) rdms, averaged over all RDMs at that time point
    vmax = np.std(rdms_data.dissimilarities) * 2
    for tidx, a in zip(t_display_idx, rdm_axes):
        a.imshow(np.mean(rdms_data.subset('time', unique_time[tidx]).get_matrices(), axis=0), vmin=0, vmax=vmax)
        a.set_title('%0.0f ms' % (np.round(unique_time[tidx]*1000,2)))
        a.set_yticklabels([])
        a.set_yticks([])
        a.set_xticklabels([])
        a.set_xticks([])

    return fig, [tc_ax] + rdm_axes

## Visualize the RDMs

In [None]:
# Dissimilarity time courses plus at most 10 RDM snapshots spread over the epoch.
rdm_movie_options = dict(
    descriptor=None,
    n_t_display=10,
    fig_width=20,
    colored_conditions=None,
)
fig, ax = plot_rdm_movie(rdms, **rdm_movie_options)

## Evaluate the similarity of the MEG RDMs to the CNN RDM

In [None]:
# Condensed (upper-triangle) form of the DNN RDM for rsatoolbox. Work on a copy so rdm_dnn
# stays as computed; symmetrize and zero the diagonal because squareform() requires an exactly
# symmetric matrix with a zero diagonal, which floating-point round-off can violate.
rdm_dnn_square = (rdm_dnn + rdm_dnn.T) / 2
np.fill_diagonal(rdm_dnn_square, 0)
model_rdm = rsatoolbox.rdm.RDMs(dissimilarities=squareform(rdm_dnn_square))

# correlate the model RDM with every RDM of the MEG RDM movie
results = rsatoolbox.rdm.compare(model_rdm, rdms, method='corr')

## Visualize the result

In [None]:
# Model-vs-MEG RDM correlation over time. The x axis comes from the RDM movie's own time
# descriptor (seconds -> ms) instead of a hard-coded number of samples.
rdm_times_ms = np.asarray(rdms.rdm_descriptors['time']) * 1000
fig_corr, ax_corr = plt.subplots()
ax_corr.plot(rdm_times_ms, results.squeeze())
ax_corr.axvline(0, color='k', linestyle=':', linewidth=1)  # event onset (t = 0)
ax_corr.set_xlabel('Time (ms)')
ax_corr.set_ylabel('RDM correlation')
ax_corr.set_title('DNN RDM vs. MEG RDMs')
plt.show()