# Notebook 2-1: HMM Training

In this notebook, we will cover:
1. **Loading prepared data**: We will load the prepared data from the previous notebook.
2. **Configuring the HMM**: We will configure the HMM with the appropriate parameters.
3. **Training the HMM (optional)**: We will train the HMM using the prepared data.
4. **Loading a pretrained HMM**: We will load a pretrained HMM for inference.
5. **Inference with the HMM**: We will perform inference using the trained HMM.

## 1. Loading prepared data
Here we will load the prepared data from the previous notebook `1_prepare-data.ipynb`. Recall that the data was prepared with TDE and PCA, before a final standardisation step was applied.

In [None]:
from osl_dynamics.data import Data

data = Data("prepared_data")
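
As a reminder of what those preparation steps do, here is a minimal NumPy sketch of time-delay embedding (TDE), PCA and standardisation on synthetic data. The helper names (`time_delay_embed`, `pca_reduce`, `standardise`) and the values of `n_embeddings` and `n_components` are illustrative assumptions and not part of `osl-dynamics`, whose own implementation may order and trim the lagged copies differently.

```python
import numpy as np


def time_delay_embed(x, n_embeddings):
    """Stack lagged copies of every channel. x has shape (n_samples, n_channels)."""
    n_samples, _ = x.shape
    n_valid = n_samples - n_embeddings + 1
    # Column block k holds the data shifted by k samples
    lagged = [x[k : k + n_valid] for k in range(n_embeddings)]
    return np.concatenate(lagged, axis=1)  # (n_valid, n_channels * n_embeddings)


def pca_reduce(x, n_components):
    """Project the demeaned data onto its leading principal components."""
    x = x - x.mean(axis=0)
    # Rows of vt are the principal directions, ordered by explained variance
    _, _, vt = np.linalg.svd(x, full_matrices=False)
    return x @ vt[:n_components].T


def standardise(x):
    """Give every channel zero mean and unit standard deviation."""
    return (x - x.mean(axis=0)) / x.std(axis=0)


rng = np.random.default_rng(0)
raw = rng.standard_normal((1000, 4))  # synthetic stand-in for one session

tde = time_delay_embed(raw, n_embeddings=5)
reduced = pca_reduce(tde, n_components=8)
prepared = standardise(reduced)
print(tde.shape, reduced.shape, prepared.shape)
```

TDE multiplies the number of channels by the number of lags, which is why PCA is applied afterwards to bring the dimensionality back down.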

## 2. Configuring the HMM
The HMM can be configured with the `Config` and `Model` API in the `osl_dynamics.models.hmm` module. The `Config` class is used to set up the model parameters, while the `Model` class is used to create the model object itself.

In [None]:
from osl_dynamics.models.hmm import Config, Model

config = Config(
    n_states=6,
    n_channels=data.n_channels,
    sequence_length=200,
    learn_means=False,
    learn_covariances=True,
    batch_size=128,
    learning_rate=0.01,
    n_epochs=20,
)
model = Model(config)
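
To give some intuition for `sequence_length` and `batch_size`, the sketch below cuts a synthetic time series into fixed-length sequences and groups them into batches. It assumes non-overlapping sequences with any leftover samples dropped; this is an illustration in plain NumPy, not the `osl-dynamics` data pipeline, which may also shuffle the sequences.

```python
import numpy as np

sequence_length = 200
batch_size = 128

rng = np.random.default_rng(0)
time_series = rng.standard_normal((60_000, 8))  # synthetic (n_samples, n_channels)

# Cut the time series into non-overlapping sequences, dropping the remainder
n_sequences = time_series.shape[0] // sequence_length
sequences = time_series[: n_sequences * sequence_length].reshape(
    n_sequences, sequence_length, -1
)

# Group the sequences into batches; the last batch may be smaller
batches = [
    sequences[i : i + batch_size] for i in range(0, n_sequences, batch_size)
]

print(sequences.shape)                    # (n_sequences, sequence_length, n_channels)
print([batch.shape[0] for batch in batches])
```

With these numbers there are 300 sequences, so one epoch consists of three batches (128, 128 and 44 sequences).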

Once you have created the model, you can have a look at a summary of the model by calling the `Model.summary()` method.

In [None]:
model.summary()

### Exercise
- What do the options in the `Config` class do? Can you find the documentation for the `Config` class and figure out yourself?
- Why do you see `Non-trainable params: 480 (1.88KB)` in the summary? What does this mean?

## 3. Training the HMM (optional - might take a while, ~20 - 60s per epoch)

Now we can start training the HMM. Because training is stochastic and the loss function is non-convex, different starting points can converge to different solutions. We therefore normally train the model briefly from several random initialisations and then continue training from the best performing one.

Here we initialise the model by training it from 3 different initialisations, each for 1 epoch. For the HMM, this is often enough to get reproducible results. Increase `n_init` if you find high run-to-run variability.

In [None]:
init_history = model.random_state_time_course_initialization(data, n_init=3, n_epochs=1)
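
The idea behind this kind of initialisation can be sketched on a toy problem: run a short training from several random starting points and keep the one with the lowest loss. The one-dimensional loss, the `short_train` helper and all numbers below are made up for illustration; this is not how `osl-dynamics` implements the method above.

```python
import numpy as np


def loss(x):
    # Toy non-convex loss with two local minima (stand-in for the HMM loss)
    return x**4 - 3 * x**2 + x


def grad(x):
    return 4 * x**3 - 6 * x + 1


def short_train(x0, n_steps=50, lr=0.01):
    """A few steps of gradient descent starting from x0."""
    x = x0
    for _ in range(n_steps):
        x = x - lr * grad(x)
    return x


rng = np.random.default_rng(0)
n_init = 3

# Train briefly from each random starting point
starts = rng.uniform(-2, 2, size=n_init)
candidates = [short_train(x0) for x0 in starts]
losses = [loss(x) for x in candidates]

# Keep the best initialisation and continue the full training from it
best = candidates[int(np.argmin(losses))]
final = short_train(best, n_steps=500)
print(losses, loss(final))
```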

Now that we have a good initialisation, we do the full model training by calling the `Model.fit()` method.

In [None]:
history = model.fit(data)

After training, we can save the trained model with the `Model.save()` method.

In [None]:
import os

os.makedirs("results/model", exist_ok=True)
model.save("results/model")

The free energy measures how well the model fits the data, with an added "regularisation" term that penalises model complexity. We can get the free energy on a dataset by calling the `Model.free_energy()` method. Here we will get the free energy on the training data and save the training history to a `.pkl` file.

In [None]:
import pickle

free_energy = model.free_energy(data)
history["free_energy"] = free_energy

# Use context managers so the files are closed after writing
with open("results/model/init_history.pkl", "wb") as file:
    pickle.dump(init_history, file)
with open("results/model/history.pkl", "wb") as file:
    pickle.dump(history, file)
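
To read a saved history back in later (for example, to plot the loss), the pickle file can be loaded in the same way. The sketch below does a round trip with a stand-in dictionary in a temporary directory; the keys and values are made up and only assume that the history behaves like a dictionary.

```python
import os
import pickle
import tempfile

# Stand-in for the training history (made-up values)
history = {"loss": [3.2, 2.9, 2.7], "free_energy": 2.65}

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "history.pkl")

    # Save the history
    with open(path, "wb") as file:
        pickle.dump(history, file)

    # Load it back
    with open(path, "rb") as file:
        loaded = pickle.load(file)

print(loaded["free_energy"])
```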

## 4. Loading a pretrained model
We have also supplied a pretrained model. If you have not trained the model yourself following the above steps, you can download the pretrained model by checking out the `0_get_model.ipynb` notebook.

`osl-dynamics` comes with a function to load pretrained models in the `osl_dynamics.models` module.

In [None]:
from osl_dynamics.models import load

model = load('results/model')

## 5. Inference with the HMM
Now that we have a trained HMM, we can use it to infer useful information about the data. These include:
- The state time courses (posterior state probabilities, $\alpha$).
- The state means.
- The state covariance matrices.
- The probability transition matrix.
- The initial state probabilities.

In [None]:
import numpy as np
import os
import pickle

inf_params_dir = "results/inf_params"
os.makedirs(inf_params_dir, exist_ok=True)

alp = model.get_alpha(data)
means = model.get_means()
covs = model.get_covariances()
trans_prob = model.get_trans_prob()
initial_state_probs = model.get_initial_state_probs()

with open(f"{inf_params_dir}/alp.pkl", "wb") as file:
    pickle.dump(alp, file)
np.save(f"{inf_params_dir}/means.npy", means)
np.save(f"{inf_params_dir}/covs.npy", covs)
np.save(f"{inf_params_dir}/trans_prob.npy", trans_prob)
np.save(f"{inf_params_dir}/initial_state_probs.npy", initial_state_probs)

### Exercise
How do you get the Viterbi path?

In [None]:
viterbi_path = model.get_viterbi_path(data)

We can plot the Viterbi path for the first 8 seconds of the first session.

In [None]:
from osl_dynamics.utils import plotting

plotting.plot_alpha(
    viterbi_path[0],
    n_samples=2000,
    sampling_frequency=250,
)

Note: The Viterbi path is the most likely sequence of hidden states given the observed data. It differs from taking the argmax of the posterior state probabilities: the Viterbi path maximises the joint posterior distribution of the entire sequence of hidden states, while the argmax maximises the marginal posterior distribution at each time point separately.

Normally we would use the argmax of the posterior state probabilities, because the Viterbi path can be computationally expensive to calculate.
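
The distinction can be made concrete with a toy two-state HMM written in plain NumPy. The functions `forward_backward` and `viterbi` below are illustrative implementations, not the `osl-dynamics` ones, and the initial probabilities, transition matrix and (uninformative) likelihoods are made up so that the two answers disagree.

```python
import numpy as np


def forward_backward(pi, A, B):
    """Marginal posteriors p(z_t | x_1:T), where B[t, k] = p(x_t | z_t = k)."""
    T, K = B.shape
    fwd = np.zeros((T, K))
    bwd = np.ones((T, K))
    fwd[0] = pi * B[0]
    fwd[0] /= fwd[0].sum()
    for t in range(1, T):
        fwd[t] = (fwd[t - 1] @ A) * B[t]
        fwd[t] /= fwd[t].sum()
    for t in range(T - 2, -1, -1):
        bwd[t] = A @ (B[t + 1] * bwd[t + 1])
        bwd[t] /= bwd[t].sum()
    gamma = fwd * bwd
    return gamma / gamma.sum(axis=1, keepdims=True)


def viterbi(pi, A, B):
    """Most likely joint state sequence, computed in log space."""
    T, K = B.shape
    with np.errstate(divide="ignore"):  # log(0) = -inf is fine here
        log_pi, log_A, log_B = np.log(pi), np.log(A), np.log(B)
    delta = log_pi + log_B[0]
    back = np.zeros((T, K), dtype=int)
    for t in range(1, T):
        # scores[i, j]: best log-probability of a path ending in i, then moving to j
        scores = delta[:, None] + log_A
        back[t] = scores.argmax(axis=0)
        delta = scores.max(axis=0) + log_B[t]
    path = np.zeros(T, dtype=int)
    path[-1] = delta.argmax()
    for t in range(T - 1, 0, -1):
        path[t - 1] = back[t, path[t]]
    return path


pi = np.array([0.6, 0.4])               # initial state probabilities
A = np.array([[0.5, 0.5], [1.0, 0.0]])  # A[i, j] = p(z_t = j | z_{t-1} = i)
B = np.ones((2, 2))                     # uninformative data likelihoods

gamma = forward_backward(pi, A, B)
marginal_path = gamma.argmax(axis=1)
viterbi_path = viterbi(pi, A, B)

print(marginal_path)  # [0 0], joint probability 0.6 * 0.5 = 0.3
print(viterbi_path)   # [1 0], joint probability 0.4 * 1.0 = 0.4
```

Here state 0 is the most probable state at each time point taken individually, but the single most probable sequence starts in state 1.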