<a href="https://colab.research.google.com/github/Nick7900/glhmm_protocols/blob/main/Procedures/Procedure_2_Across_trials_testing.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


# GLHMM Workshop - Introductory module

This notebook is meant as a guide and example of instantiating and training a TDE-HMM and visualising its output. The notebook is based on one of the [available tutorials](https://nbviewer.org/github/vidaurre/glhmm/blob/main/docs/notebooks/HMM-TDE_vs_HMM-MAR_example.ipynb) in the GLHMM toolbox.

The notebook sketches a hypothetical neuroscience project, with the aim to find common networks across participants in resting state whole-brain MEG data.

This notebook is organised as follows:

0. [PREPARATION](#preparation)
1. [Download data](#download)
2. [Basic data preprocessing](#preprocess)
3. [Instantiate TDE-HMM](#HMM_instantiate)
4. [Train HMM](#HMM_train)  -  here you will have the option to load a trained HMM, and go directly to the plotting section.
5. [Basic sanity checks and summary metrics](#sanity_checks)
6. [States spectral analysis](#spectral)

*The GLHMM workshop organising team*

## 0. PREPARATION <a id="preparation"></a>

If you don't have the **GLHMM package** installed, or if you are using **Google Colab** and have not yet installed it there, run the following command in your terminal or run the next cell:

```pip install git+https://github.com/vidaurre/glhmm```



In [None]:
!pip install git+https://github.com/vidaurre/glhmm

#### Import necessary packages

In order to be able to run this notebook, you will also need some other packages. Please install them via pip install (follow syntax in previous cell) if the next cell does not run successfully.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from glhmm import glhmm, preproc, utils, graphics, spectral, io, statistics, auxiliary
import pickle
import pip

## 1. Download data <a id="download"></a>

The next two cells will fetch the data from the OSF website and download them into a new folder called "example_data" in the same folder as this notebook. If you prefer, you can create the folder and download the files from the [OSF project page](https://osf.io/8qcyj/?view_only=119c38072a724e0091db5ba377935666) and skip the next two cells.

In [None]:
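import subprocess
import sys

def install(package):
    # pip.main() is no longer part of pip's public API, so call pip through the current interpreter
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', package])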

try:
    import osfclient
except ImportError:
    print('osfclient is not installed, installing it now')
    install('osfclient')

In [None]:
! osf -p 8qcyj fetch MEG_data/data_MEG_TDE.pkl ./example_data/data_MEG_TDE.pkl

### Load data in memory

We will now load the data in memory.

***This step is compulsory, whether you downloaded the data manually or by running the previous two cells.***

In [None]:
# Loading the data

with open("./example_data/data_MEG_TDE.pkl", "rb") as f:
    data_meg_tde = pickle.load(f)

### Data information

The data we just loaded are whole-brain MEG resting state data. The data were collected from participants while resting in a dark room. Each participant completed two sessions, except for one participant with just one session.

Each session is stored as a 2D matrix with the shape of (No. of timepoints, No. of parcels):

- Timepoints: The total number of recorded time points in the session.
- Parcels: The brain regions defined by a prespecified parcellation.

These data are a subset of the dataset collected and used for the original [TDE-HMM paper](https://www.nature.com/articles/s41467-018-05316-z).

In [None]:
# Display data information
print("Number of sessions in data_meg: %d"%len(data_meg_tde))
print("Shape of each session: ")
for i in range(len(data_meg_tde)):
  print(data_meg_tde[i].shape)

### Prepare data in HMM-friendly format

We now need to put the data into the right format for the HMM, which means concatenating all sessions and subjects along the time axis. We also need to create the indices that mark the start and end timepoint of each session:

- **Concatenated brain activity (X_concat)**:
The brain activity data (`data_meg_tde`) contain recordings from multiple sessions. We concatenate all sessions along the time dimension to form a single, continuous 2D matrix of shape [total No. of timepoints across sessions, No. of parcels].

- **Create index matrix (idx_data)**:
To track the start and end timepoints for each session, we generate an index matrix, idx_data, using the function get_indices_from_list. It will have a shape of: [No. of sessions, 2]. Each row specifies the start and end timepoints for a session.
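
To make the two objects above concrete, here is a minimal numpy sketch on made-up data. It is not the toolbox implementation of `get_indices_from_list`; the names `toy_sessions`, `toy_concat` and `toy_idx` are hypothetical, and the sketch assumes that the end index is exclusive, as in Python slicing.

```python
import numpy as np

# Three toy "sessions" with different lengths and 4 parcels each (made-up data).
rng = np.random.default_rng(0)
toy_sessions = [rng.standard_normal((n_t, 4)) for n_t in (100, 80, 120)]

# Stack the sessions along the time axis.
toy_concat = np.concatenate(toy_sessions, axis=0)

# Build [start, end) indices from the cumulative session lengths
# (assumption: the end index is exclusive).
lengths = np.array([s.shape[0] for s in toy_sessions])
ends = np.cumsum(lengths)
starts = ends - lengths
toy_idx = np.column_stack([starts, ends])

print(toy_concat.shape)
print(toy_idx)
```

With this convention, `toy_concat[toy_idx[j, 0]:toy_idx[j, 1]]` recovers session `j` from the concatenated matrix.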



In [None]:
X_concat = np.concatenate(data_meg_tde,axis=0)

# Get the start and end indices for each session
idx_data = statistics.get_indices_from_list(data_meg_tde)

In [None]:
# show the length of the data
print('total length of data:')
print(len(X_concat))

# show indices
print('indices:')
print(idx_data)

## 2. Basic data preprocessing <a id="preprocess"></a>

We always recommend plotting the data before doing any preprocessing, to understand which preprocessing steps might be appropriate.

Check the [documentation](https://glhmm.readthedocs.io/en/latest/preproc.html) for the various preprocessing options in our toolbox.

***The data for this tutorial have already been preprocessed, so we will only standardise them.***
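
As an illustration of what standardising means here, the sketch below z-scores every channel separately within each session. It assumes per-session, per-channel standardisation and `[start, end)` indices; `standardise_per_session` is a hypothetical helper written for this example, not a function of the toolbox.

```python
import numpy as np

def standardise_per_session(X, idx):
    """Z-score every channel separately within each session (hypothetical helper)."""
    X_out = np.empty_like(X, dtype=float)
    for start, end in idx:
        segment = X[start:end]
        X_out[start:end] = (segment - segment.mean(axis=0)) / segment.std(axis=0)
    return X_out

# Made-up data: 3 channels with different scales and offsets, 2 sessions.
rng = np.random.default_rng(1)
toy_X = rng.standard_normal((300, 3)) * [1.0, 5.0, 0.2] + [0.0, 10.0, -3.0]
toy_idx = np.array([[0, 100], [100, 300]])
toy_Z = standardise_per_session(toy_X, toy_idx)

# Within each session, every channel now has mean 0 and standard deviation 1.
print(toy_Z[0:100].mean(axis=0).round(6), toy_Z[0:100].std(axis=0).round(6))
```

Standardising within sessions removes differences in offset and scale between recordings, so that states are less likely to be driven by session-specific signal amplitude.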



In [None]:
# Preprocess data - we will only use the default option, which is to standardise the data
X_preproc, _, log = preproc.preprocess_data(X_concat, idx_data, standardise=True)

In [None]:
# visualize some data
# decide on a plotting range of the signal - this is arbitrary
plot_range = np.arange(100000,101000)

Fs = 250 # sampling frequency of the signal
n_regions = X_preproc.shape[1]

# plot each parcel separately
fig = plt.figure(figsize=(10,8))
for i in range(n_regions):
    plt.plot(X_preproc[plot_range,i]+i*5)

plt.xticks(np.arange(0,len(plot_range)+1,Fs),np.arange(int(len(plot_range)/Fs+1)))
plt.xlabel('time [s]')
plt.yticks(np.arange(0,n_regions*5,5),np.arange(n_regions))
plt.ylabel('parcel')
plt.title('Example of Preprocessed Signal')
plt.show()


## TDE-HMM

The TDE-HMM was introduced in [Vidaurre et al. (2018)](https://www.nature.com/articles/s41467-018-05316-z). This HMM models the autocovariance of the signal. Given a multichannel time series *y*, the autocovariance is computed within a window of length 2L and resolution S around each time point *y_t*. The user specifies the window through the **lags** = (-L, -L+S, ..., -S, 0, S, ..., L-S, L), which indicate the time points around t to be included in the window.
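
The idea of a time-delay embedding can be sketched in a few lines of numpy. This is an illustration only: `embed_time_delays` is a hypothetical helper, and details such as the column ordering may differ from what the toolbox does internally. It assumes that only time points with a complete window around them are kept.

```python
import numpy as np

def embed_time_delays(x, lags):
    """Stack lagged copies of x (time x channels) side by side (hypothetical helper).

    Row t of the output holds x[t + lag] for every lag, so only time points
    with a full window around them are kept.
    """
    lags = np.asarray(lags)
    n_t = x.shape[0]
    first = max(0, -lags.min())        # first time point with a full window
    last = n_t - max(0, lags.max())    # one past the last such time point
    centres = np.arange(first, last)
    # One block of columns per lag.
    return np.hstack([x[centres + lag] for lag in lags])

toy_x = np.arange(20, dtype=float).reshape(10, 2)   # 10 time points, 2 channels
toy_lags = np.arange(-2, 3, 1)                      # L=2, S=1
toy_emb = embed_time_delays(toy_x, toy_lags)
print(toy_emb.shape)
```

The covariance matrix of the embedded data contains the covariance between every pair of channels at every pair of lags, which is why a Gaussian HMM on the embedded data describes each state by its autocovariance structure.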

## 3. Instantiate a TDE-HMM  <a id="HMM_instantiate"></a>
To run a TDE-HMM with the GLHMM package, the data are first restructured (embedded) according to the **lags**, and then a Gaussian HMM is run on the embedded data. Here, the states are the covariance matrices that best describe the signal.  

1. **Embed data**: For the first step, we will use the `build_data_tde` function in the `preproc` module. This creates an embedded version of the data according to the **lags** we define. In this specific case, the lag window has length 2L, where L=35, with resolution S=5. We then apply PCA to the embedded signal, keeping number of parcels * 2 principal components, to follow the preprocessing and HMM settings of the original [TDE-HMM paper](https://doi.org/10.1038/s41467-018-05316-z).

2. **Initialise HMM**: We then initialise the glhmm object, which we here call `TDE_hmm`. The parameters of the glhmm object define which type of model we fit and how the states are defined. We model 10 states by setting `K=10`, use a full covariance matrix per state with `covtype='full'`, and switch off the regression component with `model_beta='no'`. Note that the cell below sets `model_mean='state'`, so each state has its own mean in addition to its own covariance; the covariance is the part that carries the functional connectivity.
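
The arithmetic behind these settings is easy to check. In the sketch below, `n_parcels_example` is a hypothetical parcel count chosen for illustration (in the notebook the real value is `X_preproc.shape[1]`), and the number of samples lost per session assumes that a complete lag window is required around every time point.

```python
import numpy as np

S, L = 5, 35
lags = np.arange(-L, L + 1, S)
n_lags = len(lags)                       # number of lagged copies per parcel

n_parcels_example = 42                   # hypothetical parcel count, for illustration only
embedded_dim = n_parcels_example * n_lags
pca_dim = n_parcels_example * 2          # number of principal components kept

# Samples lost per session, assuming a full window is required around each time point.
samples_lost = lags.max() - lags.min()
print(n_lags, embedded_dim, pca_dim, samples_lost)
```

The PCA step is what keeps the model tractable: a full covariance matrix over all parcel-by-lag columns would have far more free parameters per state than one over twice the number of parcels.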

In [None]:
# Specify time lags
S=5
L = 35
lags = np.arange(-L, L + 1, S)

# Build the MEG data in TDE format
X_embedded, idx_tde, pca_model = preproc.build_data_tde(X_preproc,idx_data,lags=lags,pca=n_regions*2)

In [None]:
# instantiate model
K=10
TDE_hmm = glhmm.glhmm(model_beta='no', model_mean='state', K=K, covtype='full')

## 4. Train a TDE-HMM <a id="HMM_train"></a>

We will now train the TDE-HMM. The training procedure will output the state time courses (Gammas), the joint probability of past and future states conditioned on the data (Xi) and the free energy computed at each iteration of the process (FE).

We will also get the Viterbi path (vpath, a categorical version of the Gammas).
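
To illustrate the difference between a probabilistic and a categorical state time course, the sketch below turns made-up state probabilities into a hard assignment. This per-time-point argmax is only a simple stand-in: the Viterbi path is the single most likely state sequence as a whole, and it need not coincide with the argmax at every time point. The sketch makes no assumption about the output format of the toolbox.

```python
import numpy as np

# Toy state probabilities for 6 time points and 3 states (made-up numbers).
toy_gamma = np.array([
    [0.8, 0.1, 0.1],
    [0.6, 0.3, 0.1],
    [0.2, 0.7, 0.1],
    [0.1, 0.8, 0.1],
    [0.1, 0.2, 0.7],
    [0.3, 0.3, 0.4],
])

# Hard assignment: index of the most probable state at each time point.
hard_states = toy_gamma.argmax(axis=1)

# One-hot encoding of the hard assignment (time x states).
one_hot = np.eye(toy_gamma.shape[1], dtype=int)[hard_states]

print(hard_states)
print(one_hot)
```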


***This step can take several minutes, up to an hour. If you want to avoid training the HMM yourself and want to go directly to the visualization part, you can skip this part and download the trained model and its output from the OSF project page. Simply run the next two cells.***

In [None]:
# download trained model
! osf -p 8qcyj fetch MEG_data/hmm_tde.pkl ./example_data/hmm_tde.pkl

In [None]:
# load into the notebook
with open('./example_data/hmm_tde.pkl', "rb") as f:
    hmm_dict = pickle.load(f)

TDE_hmm = hmm_dict['hmm']
Gamma = hmm_dict['stc']
vpath = hmm_dict['vpath']

***For those of you who have time and want to train the HMM, run the next cell instead!***

In [None]:
# or train the HMM from scratch
print('Training HMM-TDE')
options={'gpu_acceleration':2}
Gamma, Xi, FE = TDE_hmm.train(X=None, Y=X_embedded, indices=idx_tde, options=options)
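# pass the session indices so that the Viterbi path is decoded separately for each session
vpath = TDE_hmm.decode(X=None, Y=X_embedded, indices=idx_tde, viterbi=True)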

In [None]:
# save your trained hmm
hmm_dict = {'hmm':TDE_hmm,'stc':Gamma,'xi':Xi,'fe':FE,'vpath':vpath}
with open("./example_data/hmm_tde.pkl", "wb") as fp:
    pickle.dump(hmm_dict, fp, pickle.HIGHEST_PROTOCOL)

### Padding Viterbi path and Gamma

Because of the delay-embeddings, the state time courses are now shorter than the original data. To be able to plot the signal with the Viterbi path, we need a padding operation to fill in the missing values of the HMM output. 

This is done with the function `padGamma()` below.
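
The padding idea can be sketched as follows. `pad_sessions` is a hypothetical helper, not the toolbox's `padGamma()`, and filling the padded rows with the session's average state probabilities is one reasonable choice that is assumed here for illustration.

```python
import numpy as np

def pad_sessions(gamma, session_lengths, n_before, n_after):
    """Pad each session of gamma back to its original length (hypothetical helper).

    Padded rows are filled with the session's average state probabilities,
    which is one reasonable choice, not necessarily what padGamma does.
    """
    padded, start = [], 0
    for n_t in session_lengths:
        n_emb = n_t - n_before - n_after          # embedded length of this session
        seg = gamma[start:start + n_emb]
        fill = seg.mean(axis=0, keepdims=True)
        padded.append(np.vstack([np.repeat(fill, n_before, axis=0),
                                 seg,
                                 np.repeat(fill, n_after, axis=0)]))
        start += n_emb
    return np.vstack(padded)

toy_T = [50, 40]                                   # original session lengths
L_toy = 5                                          # samples lost on each side
rng = np.random.default_rng(2)
toy_gamma = rng.dirichlet(np.ones(3), size=sum(toy_T) - 2 * L_toy * len(toy_T))
toy_padded = pad_sessions(toy_gamma, toy_T, L_toy, L_toy)
print(toy_gamma.shape, toy_padded.shape)
```

After padding, row `t` of the state time course lines up with row `t` of the original (non-embedded) signal, which is what makes joint plots possible.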

In [None]:
T = auxiliary.get_T(idx_data)
options_tde = {'embeddedlags':list(lags)}
paddedVP = auxiliary.padGamma(vpath, T, options_tde)


## 5. Basic sanity checks and summary metrics <a id="sanity_checks"></a>

We will now perform some basic sanity checks and plot summary metrics. These include:

*   Plot example of Viterbi path with signal
*   Plot states fractional occupancy (FO)
*   Plot states switching rate (SR)
*   Plot states lifetimes (LT)
*   Plot states probabilities, mean and covariance

We will use the appropriate functions in the GLHMM `graphics` module to plot all these metrics. Check the [documentation](https://glhmm.readthedocs.io/en/latest/graphics.html) to see all our graphics options.
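
For intuition about the summary metrics listed above, the sketch below computes them by hand on a short made-up state sequence. The toolbox functions work per session and, for FO and SR, on the state probabilities, so their exact definitions may differ from this simplified version.

```python
import numpy as np

toy_path = np.array([0, 0, 0, 1, 1, 2, 2, 2, 2, 0, 0, 1])  # made-up state sequence
n_states = 3

# Fractional occupancy: share of time points spent in each state.
fo = np.bincount(toy_path, minlength=n_states) / len(toy_path)

# Switching rate: share of consecutive time-point pairs where the state changes.
switching_rate = np.mean(toy_path[1:] != toy_path[:-1])

# Lifetimes: lengths of uninterrupted visits, grouped by state.
change_points = np.flatnonzero(np.diff(toy_path)) + 1
run_starts = np.concatenate([[0], change_points])
run_lengths = np.diff(np.concatenate([run_starts, [len(toy_path)]]))
run_states = toy_path[run_starts]
mean_lifetime = np.array([run_lengths[run_states == k].mean() for k in range(n_states)])

print(fo, switching_rate, mean_lifetime)
```

As a sanity check on a real model, a state with an FO close to 1 in some sessions and close to 0 in others may be describing differences between sessions rather than dynamics within them.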



In [None]:
# plot state Viterbi path with signal
# define a plotting range
plotting_range = np.arange(15000,20000)

# use the appropriate function in the graphics module
graphics.plot_vpath(paddedVP[plotting_range], signal=X_preproc[plotting_range,1].copy(), title="States and signal example")


In [None]:
# inspect states - basic sanity checks
# Get summary metrics
FO = utils.get_FO(Gamma, indices=idx_tde)
SR = utils.get_switching_rate(Gamma, indices=idx_tde)
LTmean, LTmed, LTmax = utils.get_life_times(vpath, indices=idx_tde)

In [None]:
# plot some relevant statistics
graphics.plot_FO(FO,width=0.8, figsize=(7,5))


In [None]:
graphics.plot_switching_rates(SR,width=0.2)


In [None]:
graphics.plot_state_lifetimes(LTmed)

In [None]:
# inspect states
# plot probabilities, mean and covariance
graphics.plot_state_prob_and_covariance(TDE_hmm.Pi,TDE_hmm.P, TDE_hmm.get_means(),TDE_hmm.get_covariance_matrices(), figsize=(8,11))

## 6. States spectral analysis <a id="spectral"></a>

We will now use the `spectral` module in GLHMM to compute the state power spectra and coherence, and plot them.

This is done with the function `multitaper_spectral_analysis()` that computes the states power spectra using the nonparametric multitaper approach.

To compute the power spectra, you need to specify the sampling frequency, Fs, of the data. In this case, `Fs=250`.

The function also needs the Gamma (i.e., the state probability time courses) to compute the power spectrum of each state. We can pass either the padded Gamma, or the original Gamma together with the `embeddedlags` used to train the HMM in the options. In the latter case, the function pads the Gamma first.

You can also specify in the options for the spectral analysis `fpass`, the frequency range for the power spectrum estimation.

Check the [documentation](https://glhmm.readthedocs.io/en/latest/spectral.html) on how to specify more options for the multitaper spectral analysis.
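
The principle of a state-specific multitaper spectrum can be sketched with numpy alone. This is a simplified stand-in, not the toolbox implementation: `sine_multitaper_psd` is a hypothetical helper, it uses sine tapers instead of the Slepian (DPSS) tapers that multitaper methods typically use, and it assumes that the signal is weighted by each state's probability time course before tapering.

```python
import numpy as np

def sine_multitaper_psd(x, fs, n_tapers=5):
    """Average periodograms of x over orthonormal sine tapers (hypothetical helper)."""
    n = len(x)
    t = np.arange(1, n + 1)
    psd = np.zeros(n // 2 + 1)
    for k in range(1, n_tapers + 1):
        taper = np.sqrt(2.0 / (n + 1)) * np.sin(np.pi * k * t / (n + 1))
        psd += np.abs(np.fft.rfft(taper * x)) ** 2
    freqs = np.fft.rfftfreq(n, d=1.0 / fs)
    return freqs, psd / (n_tapers * fs)

fs = 250
time = np.arange(0, 8, 1 / fs)
rng = np.random.default_rng(3)
# Toy signal: a 10 Hz rhythm in the first half, a 25 Hz rhythm in the second half.
signal = np.where(time < 4, np.sin(2 * np.pi * 10 * time), np.sin(2 * np.pi * 25 * time))
signal = signal + 0.1 * rng.standard_normal(len(time))
# Toy state probabilities: state 0 active in the first half, state 1 in the second.
toy_gamma = np.column_stack([time < 4, time >= 4]).astype(float)

peak_freqs = []
for k in range(toy_gamma.shape[1]):
    # Weight the signal by the state's time course before estimating the spectrum.
    freqs, psd = sine_multitaper_psd(signal * toy_gamma[:, k], fs)
    peak_freqs.append(freqs[np.argmax(psd)])

print(peak_freqs)
```

Each state's spectrum peaks near the rhythm that was present while that state was active, which is the behaviour we expect from the state spectra computed below.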


In [None]:
# get states spectral properties
options = {'embeddedlags':list(lags), 'fpass':[0,100]}
spectral_measures = spectral.multitaper_spectral_analysis(X_preproc, idx_data, Fs=250, Gamma=Gamma, options=options)

The output of the `multitaper_spectral_analysis` function is a dictionary containing:
- 'f' : the frequency bins
- 'p' : the power spectrum of each state, per subject/session and per channel
- 'psdc' : the cross-channel power spectrum, per subject/session and per state
- 'coh' : the cross-channel coherence, per subject/session and per state
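
To show how coherence relates to the cross-channel power spectrum, the sketch below computes magnitude coherence from a made-up cross-spectral array. `coherence_from_cross_spectrum` is a hypothetical helper, and the definition used (magnitude of the cross-spectrum divided by the square root of the product of the two auto-spectra) is one common convention; the toolbox may use a different one, such as the squared version.

```python
import numpy as np

def coherence_from_cross_spectrum(psdc):
    """Magnitude coherence from a (freq x channel x channel) cross-spectral array."""
    power = np.real(np.einsum('fii->fi', psdc))          # auto-spectra on the diagonal
    denom = np.sqrt(power[:, :, None] * power[:, None, :])
    return np.abs(psdc) / denom

# Made-up cross-spectra: average of outer products of random complex spectra over tapers.
rng = np.random.default_rng(4)
n_freq, n_ch, n_tapers = 6, 3, 5
spectra = (rng.standard_normal((n_tapers, n_freq, n_ch))
           + 1j * rng.standard_normal((n_tapers, n_freq, n_ch)))
toy_psdc = np.einsum('tfi,tfj->fij', spectra, spectra.conj()) / n_tapers
toy_coh = coherence_from_cross_spectrum(toy_psdc)

print(toy_coh.shape)
```

With this definition, coherence is symmetric in the two channels, equals 1 on the diagonal, and lies between 0 and 1 elsewhere.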

We will then use the functions in the `graphics` module to visualise the state spectra, with and without the option of highlighting the standard frequency bands.

We will visualise them for one session and for two arbitrarily chosen channels.


In [None]:
# plot the state power spectra for a specific session, and channel

selected_channel = 4
selected_session = 2

f = spectral_measures['f']
psd = spectral_measures['p']

graphics.plot_state_psd(psd[selected_session,:,selected_channel],
                        f,
                        highlight_freq=True,
                        title="States power spectra in channel %d"%selected_channel)

In [None]:
# plot the state power spectra for the same session and a second channel

selected_channel = 21
selected_session = 2

f = spectral_measures['f']
psd = spectral_measures['p']

graphics.plot_state_psd(psd[selected_session,:,selected_channel],
                        f,
                        highlight_freq=True,
                        title="States power spectra in channel %d"%selected_channel,)

We will then visualise the state cross-channel coherence, for one session and between two arbitrarily chosen channels.


In [None]:
# for a specific session, plot the state coherence between two channels
coh = spectral_measures['coh']
chann_1 = 30
chann_2 = 3
graphics.plot_state_coherence(coh[selected_session,:,chann_1,chann_2],
                              f,
                              title='Coherence between regions %d and %d'%(chann_1,chann_2),
                              #highlight_freq=True,
                              )