# Temporal RSA

This notebook demonstrates how to work with temporal data in the RSA toolbox.


So far, it demonstrates how to

(1) import a temporal dataset into the `rsatoolbox.data.TemporalDataset` class and

(2) create RDM movies using the `rsatoolbox.rdm.calc_rdm_movie` function.

The notebook will 

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import rsatoolbox
import pickle

from rsatoolbox.rdm import calc_rdm_movie

## Load temporal data

The sample data used here come from MNE-Python:

https://mne.tools/dev/overview/datasets_index.html#sample

The data consist of the MEG recordings in "sample_audvis_raw.fif" after preprocessing.

Preprocessing includes:
- downsampling to 60 Hz
- band-pass filtering between 1 Hz and 20 Hz
- rejecting bad trials using an amplitude threshold
- baseline correction (baseline: -200 to 0 ms)

*See demos/TemporalSampleData/preproc_mn_sample_data.py*
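
As an illustration of the last preprocessing step, baseline correction can be sketched in plain NumPy. This is not the code from the preprocessing script; the array `epochs`, its shape, the artificial offset, and the time axis are synthetic stand-ins chosen for the example.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical epochs: 8 trials x 3 channels x 42 samples at 60 Hz,
# starting at -200 ms, with an artificial constant offset of 5.
times = np.arange(-12, 30) / 60.0
epochs = rng.normal(size=(8, 3, times.size)) + 5.0

# Baseline window: -200 ms to 0 ms (both ends included).
baseline = (times >= -0.2) & (times <= 0.0)

# Subtract the baseline mean separately for every trial and channel.
epochs_bc = epochs - epochs[:, :, baseline].mean(axis=2, keepdims=True)
```

After this step, the mean of each trial and channel over the baseline window is zero, so the offset no longer affects the condition averages.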

The preprocessed data is stored in *TemporalSampleData/meg_sample_data.pkl*

In [None]:
# load the preprocessed sample data; the context manager closes the file afterwards
with open("TemporalSampleData/meg_sample_data.pkl", "rb") as f:
    dat = pickle.load(f)
measurements = dat['data']
cond_names = list(dat['cond_names'].keys())
cond_idx = dat['cond_idx']
channel_names = dat['channel_names']
times = dat['times']
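
If the pickle file is not at hand, a synthetic stand-in with the same keys allows the later steps to be tried out. The sketch below is only a placeholder: the sizes, the condition labels, the channel names, and the values stored in the `cond_names` dict are invented for illustration and do not reflect the real file.

```python
import numpy as np

rng = np.random.default_rng(0)

n_trials_per_cond, n_channels, n_times = 5, 6, 42          # made-up sizes
cond_labels = ['cond_a', 'cond_b', 'cond_c', 'cond_d']     # hypothetical names
n_obs = n_trials_per_cond * len(cond_labels)

dat_fake = {
    'data': rng.normal(size=(n_obs, n_channels, n_times)),
    # assumption: the dict maps a condition name to an integer code
    'cond_names': {name: i + 1 for i, name in enumerate(cond_labels)},
    'cond_idx': np.repeat(np.arange(1, len(cond_labels) + 1), n_trials_per_cond),
    'channel_names': ['ch%02d' % i for i in range(n_channels)],
    'times': np.arange(-12, n_times - 12) / 60.0,           # seconds, 60 Hz
}

measurements_fake = dat_fake['data']
print(measurements_fake.shape)
```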

In [None]:
print('there are %d observations (trials), %d channels, and %d time-points\n' % 
      (measurements.shape))

print('conditions:')
print(cond_names)

Plot condition averages for two channels:

In [None]:
fig, ax = plt.subplots(1, 2, figsize=(12, 4))
ax = ax.flatten()
for jj, chan in enumerate(channel_names[:2]):
    for ii, cond_ii in enumerate(np.unique(cond_idx)):
        # average over all trials of this condition for channel jj
        mn = measurements[cond_idx == cond_ii, jj, :].mean(0).squeeze()
        ax[jj].plot(times, mn, label=cond_names[ii])
    ax[jj].set_title(chan)
# the conditions are the same in both panels, so one legend suffices
ax[-1].legend()
plt.show()

## The `rsatoolbox.data.TemporalDataset` class

`measurements` is an `np.array` of shape n_obs x n_channels x n_times.

`time_descriptors` should contain the time-point vector for the measurements, of length n_times. It is recommended to call this descriptor 'time'.

In [None]:
tim_des = {'time': times}

The other descriptors are the same as in the `rsatoolbox.data.Dataset` class.

In [None]:
des = {'session': 0, 'subj': 0}
obs_des = {'conds': cond_idx}
chn_des = {'channels': channel_names}
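
Each descriptor that holds one value per observation, channel, or time point has to match the length of the corresponding axis of `measurements`. A quick sanity check before constructing the dataset might look like the sketch below; the arrays are synthetic and `check_lengths` is a hypothetical helper, not part of rsatoolbox.

```python
import numpy as np

def check_lengths(measurements, obs_des, chn_des, tim_des):
    """Return the names of descriptors whose length does not match their axis."""
    n_obs, n_channels, n_times = measurements.shape
    bad = []
    for axis_len, descriptors in ((n_obs, obs_des),
                                  (n_channels, chn_des),
                                  (n_times, tim_des)):
        for name, values in descriptors.items():
            if len(values) != axis_len:
                bad.append(name)
    return bad

meas = np.zeros((20, 6, 42))
mismatches = check_lengths(
    meas,
    obs_des={'conds': np.repeat(np.arange(4), 5)},
    chn_des={'channels': ['ch%d' % i for i in range(6)]},
    tim_des={'time': np.arange(41)},   # deliberately one sample short
)
print(mismatches)
```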

In [None]:
data = rsatoolbox.data.TemporalDataset(measurements, 
                              descriptors = des, 
                              obs_descriptors = obs_des, 
                              channel_descriptors = chn_des, 
                              time_descriptors = tim_des)
data.sort_by('conds')

### Convenience methods

`rsatoolbox.data.TemporalDataset` comes with the same convenience methods as `rsatoolbox.data.Dataset`.

In addition, the following methods are provided:

- `rsatoolbox.data.TemporalDataset.split_time(by)`
- `rsatoolbox.data.TemporalDataset.subset_time(by, t_from, t_to)`
- `rsatoolbox.data.TemporalDataset.bin_time(by, bins)`
- `rsatoolbox.data.TemporalDataset.convert_to_dataset(by)`

#### `rsatoolbox.data.TemporalDataset.split_time(by)`

Splits the `rsatoolbox.data.TemporalDataset` object into a list of n_times `rsatoolbox.data.TemporalDataset` objects, splitting the measurements along the time descriptor `by`.

In [None]:
print('shape of original measurements')
print(data.measurements.shape)

data_split_time = data.split_time('time')

print('\nafter splitting')
print(len(data_split_time))
print(data_split_time[0].measurements.shape)
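
What splitting along the time axis does to the measurement array can be sketched with NumPy alone. This illustrates the idea rather than the toolbox's implementation; in particular, keeping a singleton time axis in each piece is an assumption, and the small array is synthetic.

```python
import numpy as np

# n_obs x n_channels x n_times
meas = np.arange(2 * 3 * 4, dtype=float).reshape(2, 3, 4)

# One array per time point, each keeping a time axis of length 1.
pieces = np.split(meas, meas.shape[2], axis=2)

print(len(pieces))
print(pieces[0].shape)
```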

#### `rsatoolbox.data.TemporalDataset.subset_time(by, t_from, t_to)`

Returns a new `rsatoolbox.data.TemporalDataset` containing only the data for which `time_descriptors[by]` lies between t_from and t_to.

In [None]:
print('shape of original measurements')
print(data.measurements.shape)

data_subset_time = data.subset_time('time', t_from = -.1, t_to = .5)

print('\nafter subsetting')
print(data_subset_time.measurements.shape)
print(data_subset_time.time_descriptors['time'][0])
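
Conceptually, subsetting in time amounts to applying a boolean mask along the last axis. The sketch below uses synthetic data and assumes that both bounds are inclusive, which may not match the toolbox in every version.

```python
import numpy as np

times = np.arange(-12, 30) / 60.0          # hypothetical time axis in seconds
meas = np.zeros((20, 6, times.size))       # n_obs x n_channels x n_times

t_from, t_to = -0.1, 0.5
keep = (times >= t_from) & (times <= t_to)   # assumption: inclusive bounds

meas_sub = meas[:, :, keep]
times_sub = times[keep]
print(meas_sub.shape)
```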

#### `rsatoolbox.data.TemporalDataset.bin_time(by, bins)`

Returns a new `rsatoolbox.data.TemporalDataset` object with binned temporal data. Data within each bin are averaged.

`bins` is a list or array in which the first dimension indexes the new bins and the second dimension lists the original time points that go into each bin.

In [None]:
bins = np.reshape(tim_des['time'], [-1, 2])
print(len(bins))
print(bins[0])

In [None]:
print('shape of original measurements')
print(data.measurements.shape)

data_binned = data.bin_time('time', bins=bins)

print('\nafter binning')
print(data_binned.measurements.shape)
print(data_binned.time_descriptors['time'][0])
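
The averaging performed by binning can be sketched in NumPy as follows. The data are synthetic, and labelling each bin by the mean of its time points is an assumption about the convention, not a statement about what `bin_time` stores.

```python
import numpy as np

rng = np.random.default_rng(0)
times = np.arange(-12, 30) / 60.0              # 42 samples (an even count)
meas = rng.normal(size=(20, 6, times.size))    # n_obs x n_channels x n_times

bins = times.reshape(-1, 2)                    # 21 bins of 2 time points each

# Average the samples that belong to each bin.
binned = np.stack(
    [meas[:, :, np.isin(times, b)].mean(axis=2) for b in bins], axis=2
)
bin_centres = bins.mean(axis=1)                # assumed label for each bin
print(binned.shape)
```

Note that reshaping the time vector to `[-1, 2]` only works when the number of time points is even; otherwise the last sample has to be dropped or the bins built by hand.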

#### `rsatoolbox.data.TemporalDataset.convert_to_dataset(by)`

Returns a `rsatoolbox.data.Dataset` object in which the time dimension is absorbed into the observation dimension.

In [None]:
print('shape of original measurements')
print(data.measurements.shape)

data_dataset = data.convert_to_dataset('time')

print('\nafter converting')
print(data_dataset.measurements.shape)
print(data_dataset.obs_descriptors['time'][0])
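
Absorbing the time dimension into the observation dimension can be pictured as stacking the per-time-point slices on top of each other. The sketch below uses a small synthetic array; the row ordering (all observations of the first time point, then the second, and so on) is an assumption and may differ from the toolbox's ordering.

```python
import numpy as np

n_obs, n_channels, n_times = 4, 3, 5
meas = np.arange(n_obs * n_channels * n_times, dtype=float)
meas = meas.reshape(n_obs, n_channels, n_times)
times = np.arange(n_times) / 60.0

# Move time to the front, then stack the per-time blocks of observations.
flat = meas.transpose(2, 0, 1).reshape(n_times * n_obs, n_channels)

# Each row carries the time label of the block it came from.
time_per_row = np.repeat(times, n_obs)
print(flat.shape)
```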

## Create RDM movie

The function `calc_rdm_movie` takes a `rsatoolbox.data.TemporalDataset` as input and returns an `rsatoolbox.rdm.RDMs` object.
It works like `calc_rdm`.

In [None]:
rdms_data = calc_rdm_movie(data, method = 'euclidean', 
                           descriptor = 'conds')
print(rdms_data)
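
An RDM movie is one RDM per time point. A minimal NumPy sketch of that idea on synthetic data is given below. It averages trials per condition and uses squared Euclidean distances divided by the number of channels; that normalisation is an assumption about the convention and may not reproduce the values returned by `calc_rdm_movie`.

```python
import numpy as np

rng = np.random.default_rng(0)
n_cond, n_rep, n_channels, n_times = 4, 5, 6, 10
conds = np.repeat(np.arange(n_cond), n_rep)
meas = rng.normal(size=(n_cond * n_rep, n_channels, n_times))

# Condition-average patterns: n_cond x n_channels x n_times.
means = np.stack([meas[conds == c].mean(axis=0) for c in range(n_cond)])

# Indices of the upper triangle, so each condition pair appears once.
iu, ju = np.triu_indices(n_cond, k=1)

# Squared Euclidean distance per pair and time point, averaged over channels.
diff = means[iu] - means[ju]                 # n_pairs x n_channels x n_times
rdm_movie = (diff ** 2).mean(axis=1).T       # n_times x n_pairs
print(rdm_movie.shape)
```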

Binning can be applied before computing the RDMs by simply specifying the `bins` argument.

In [None]:
rdms_data_binned = calc_rdm_movie(data, method = 'euclidean', 
                           descriptor = 'conds',
                           bins=bins)
print(rdms_data_binned)

## From here on

The following are examples of data analysis and plotting with temporal data. So far they use the toolbox's functions for non-temporal data. This section should be expanded once new temporal RSA functions are added to the toolbox.

The plots here use the toolbox's standard plotting function.

In [None]:
plt.figure(figsize=(10,15))

# add formatted time as rdm_descriptor
rdms_data_binned.rdm_descriptors['time_formatted'] = [
    '%0.0f ms' % (np.round(x*1000, 2))
    for x in rdms_data_binned.rdm_descriptors['time']]

rsatoolbox.vis.show_rdm(rdms_data_binned, 
                   pattern_descriptor='conds',
                   rdm_descriptor='time_formatted')

## Model RDMs

This is a simple example with basic model RDMs.

In [None]:
from rsatoolbox.rdm import get_categorical_rdm

In [None]:
rdms_model_in = get_categorical_rdm(['%d' % x for x in range(4)])
rdms_model_lr = get_categorical_rdm(['l','r','l','r'])
rdms_model_av = get_categorical_rdm(['a','a','v','v'])

model_names = ['independent', 'left/right', 'audio/visual']

# combine all model RDMs into a single RDMs object

model_rdms = rdms_model_in
model_rdms.append(rdms_model_lr)
model_rdms.append(rdms_model_av)

model_rdms.rdm_descriptors['model_names'] = model_names
model_rdms.pattern_descriptors['cond_names'] = cond_names
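
The idea behind a categorical model RDM is that two conditions have dissimilarity 0 if they share a category label and 1 otherwise. The sketch below builds such matrices with NumPy; `categorical_rdm` is a hypothetical helper written for this illustration and returns a plain array rather than an `RDMs` object.

```python
import numpy as np

def categorical_rdm(labels):
    """0 where two conditions share a label, 1 where they differ."""
    labels = np.asarray(labels)
    return (labels[:, None] != labels[None, :]).astype(float)

rdm_lr = categorical_rdm(['l', 'r', 'l', 'r'])
rdm_av = categorical_rdm(['a', 'a', 'v', 'v'])
print(rdm_lr)
print(rdm_av)
```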

In [None]:
rsatoolbox.vis.show_rdm(model_rdms, rdm_descriptor='model_names', pattern_descriptor = 'cond_names')
None

## Data-model similarity across time

In [None]:
from rsatoolbox.rdm import compare

In [None]:
r = []
for mod in model_rdms:
    r.append(compare(mod, rdms_data_binned, method='cosine'))

for i, r_ in enumerate(r):
    plt.plot(rdms_data_binned.rdm_descriptors['time'], r_.squeeze(), label=model_names[i])

plt.xlabel('time (s)')
plt.ylabel('model-data cosine similarity')
plt.legend()
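
The quantity plotted above is, in essence, the cosine of the angle between a model RDM vector and the data RDM vector at each time point. A plain NumPy sketch on synthetic numbers follows; `data_vecs` is a random stand-in for the RDM movie, and the toolbox's `compare` may differ in details.

```python
import numpy as np

rng = np.random.default_rng(0)
n_times, n_pairs = 10, 6
data_vecs = rng.random(size=(n_times, n_pairs))   # stand-in: one RDM vector per time point

# Upper triangle of the left/right model for the labels ['l', 'r', 'l', 'r'],
# in the pair order (0,1), (0,2), (0,3), (1,2), (1,3), (2,3).
model_vec = np.array([1.0, 0.0, 1.0, 1.0, 0.0, 1.0])

# Cosine similarity between the model vector and each time point's RDM vector.
cos = data_vecs @ model_vec / (
    np.linalg.norm(data_vecs, axis=1) * np.linalg.norm(model_vec)
)
print(cos.shape)
```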