---
# **Introduction to DeepInsight - Decoding position from two-photon calcium recordings**
---

This notebook shows how to use DeepInsight on calcium data and can serve as a guide for adapting it to your own datasets. All methods are part of the deepinsight library and can be called from their respective submodules (e.g. `deepinsight.preprocess`, `deepinsight.train`, `deepinsight.analyse`). A typical workflow looks like this:

- Load your dataset into a format which can be directly indexed (numpy array or pointer to a file on disk)
- Preprocess the raw data (wavelet transformation)
- Preprocess your outputs (the variable you want to decode)
- Define appropriate loss functions for your output and train the model 
- Evaluate performance across all cross-validated models
- Visualize influence of different input frequencies on model output

We use the calcium dataset here because its lower sampling rate makes preprocessing and training faster. This also makes it feasible to run the preprocessing in a Colab notebook.


---
## **Install and import DeepInsight**
---
Make sure you are using a **GPU runtime** if you want to train your own models: go to Runtime -> Change runtime type and select GPU.
You can check which GPU Colab has assigned by running `!nvidia-smi` in a new cell.

In [None]:
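# Import DeepInsight
import sys
# Only needed if DeepInsight is not installed as a package: point this to your local clone of the repository
sys.path.insert(0, "/home/marx/Documents/Github/DeepInsight")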
import deepinsight

# Other imports
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import h5py
import numpy as np
import pandas as pd
from scipy.io import loadmat
import plotly.graph_objs as go
from skimage import io

# Initialize plotly figures
from plotly.offline import init_notebook_mode 
init_notebook_mode(connected = True)

# Make sure the output width is adjusted for better export as HTML
from IPython.display import display, HTML
display(HTML("<style>.container { width:70% !important; }</style>"))
display(HTML("<style>.output_result { max-width:70% !important; }</style>"))

---
## **Load and preprocess your data**
---
For this example we provide two-photon calcium imaging data from a mouse in a virtual environment. The calcium traces and the variable of interest are stored together in one .mat file. You can load your own data from any source; just make sure that the dimensions match.

The model input, with shape (timepoints x number of cells), is stored in `raw_data`.

The output to be decoded, with shape (timepoints x 1), is stored in `output`, and the timestamps of the recording are stored in `raw_timestamps` (in this dataset, input and output share the same timestamps).
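If your behavioural variable is logged at a different rate than the imaging frames, one option is to resample it onto the frame timestamps yourself. The sketch below does this with `np.interp` through a hypothetical helper, `align_to_frames`. Note that `deepinsight.preprocess.preprocess_output` also receives both timestamp arrays, so check whether it already performs the alignment you need. For circular variables, such as position on a looping track, interpolate the sine and cosine rather than the raw position. Linear interpolation across the wrap-around point produces spurious intermediate values.

```python
import numpy as np

def align_to_frames(frame_times, target_times, target_values):
    """Hypothetical helper: linearly resample a 1-D variable onto imaging frame times.

    Frames outside the range of target_times get the nearest edge value
    (np.interp's default behaviour). target_times must be increasing.
    """
    frame_times = np.asarray(frame_times, dtype=float)
    target_times = np.asarray(target_times, dtype=float)
    target_values = np.asarray(target_values, dtype=float)
    if target_times.shape != target_values.shape:
        raise ValueError("target_times and target_values must have the same shape")
    return np.interp(frame_times, target_times, target_values)

# Toy example: 30 Hz imaging frames, position logged at 60 Hz
sampling_rate = 30
frame_times = np.arange(0, 90) / sampling_rate      # 3 s of imaging frames
behaviour_times = np.arange(0, 180) / 60.0          # 3 s of behaviour samples
position = behaviour_times * 10.0                   # toy linear position ramp
aligned = align_to_frames(frame_times, behaviour_times, position)
print(aligned.shape)  # (90,)
```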

---

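**Run the next cells if you want to download and preprocess the example data. Alternatively, skip to 'Preprocess data' to download the already preprocessed file and train the model directly (the paths, `sampling_rate` and `channels` defined in the next cells are still used later on).**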

In [None]:
base_path = './example_data/calcium/'
fp_raw_file = base_path + 'traces_M1336.mat'
if not os.path.exists(base_path):
    os.makedirs(base_path)
if not os.path.exists(fp_raw_file): # The wget shell command below uses a hard-coded output path; update it if you change fp_raw_file
    !wget https://ndownloader.figshare.com/files/24024683 -O ./example_data/calcium/traces_M1336.mat

In [None]:
# Set base variables
sampling_rate = 30 # Imaging rate in Hz; might also be stored in the .mat file above for easier access
channels = np.arange(0, 100) # For this recording, channels correspond to cells. We only use the first 100 cells to speed up preprocessing (change this for your own dataset)

# Also define paths to access the downloaded files
base_path = './example_data/calcium/'
fp_raw_file = base_path + 'traces_M1336.mat' # This is an example dataset containing calcium traces and linear position in a virtual track
fp_deepinsight = base_path + 'processed_M1336.h5' # This will be the processed HDF5 file

# Load data from mat file
calcium_data = loadmat(fp_raw_file)['dataSave']
raw_data = np.squeeze(calcium_data['df_f'][0][0])
raw_timestamps = np.arange(0, raw_data.shape[0]) / sampling_rate
output = np.squeeze(calcium_data['pos_dat'][0][0])

print('Data loaded. Calcium traces: {}, decoding target: {}'.format(raw_data.shape, output.shape))

---
### Plot example calcium traces
---
To give a visual impression of the input to our model, we plot the calcium traces of a few cells, vertically offset for readability.

In [None]:
end_point, y_offset, num_cells = 10000, 400, 6
fig = go.Figure()
for i in range(0, num_cells):
    fig.add_trace(go.Scatter(x=np.arange(0, end_point) / sampling_rate, y=raw_data[0:end_point, i] + (i * y_offset), line=dict(color='rgba(0, 0, 0, 0.85)', width=2), name='Cell {}'.format(i+1)))
# aesthetics
fig.update_yaxes(visible=False)
fig.update_layout(showlegend=False,plot_bgcolor="white",width=1800, height=650,margin=dict(t=20,l=20,b=20,r=20),xaxis_title='Time (s)', font=dict(family='Open Sans', size=16, color='black'))
fig.show()

---
### Preprocess data 
---

In [None]:
if not os.path.exists(fp_deepinsight):
    if os.path.exists(fp_raw_file): # Preprocess from the raw file if it was downloaded, otherwise download the preprocessed HDF5 file
        # Process output for use as decoding target
        # The virtual linear track wraps around (the start follows the end), so position is a circular variable. We can handle this by either:
        # (1) Using a circular loss function or
        # (2) Using the sin and cos of the variable
        # For this dataset we choose method (2); see the loss calculation for head direction on CA1 recordings for an example of (1)
        output = (output - np.nanmin(output)) / (np.nanmax(output) - np.nanmin(output))
        output = (output * 2*np.pi) - np.pi # Scaled to -pi / pi
        output = np.squeeze(np.column_stack([np.sin(output), np.cos(output)]))
        output = pd.DataFrame(output).ffill().bfill().values # Get rid of NaNs
        output_timestamps = raw_timestamps # In this recording timestamps are the same for output and raw_data, meaning they are already aligned to each other

        # Transform raw data to frequency domain
        # We use a small cutoff (1/500) for the low frequencies to keep the dimensions low & the model training fast
        deepinsight.preprocess.preprocess_input(fp_deepinsight, raw_data, sampling_rate=sampling_rate, average_window=1, wave_highpass=1/500, wave_lowpass=sampling_rate, channels=channels) 
        # Prepare outputs
        deepinsight.preprocess.preprocess_output(fp_deepinsight, raw_timestamps, output, output_timestamps, average_window=1, dataset_name='sin_cos')
    else:
        # Raw file not available: download the preprocessed HDF5 file instead
        if not os.path.exists(base_path):
            os.makedirs(base_path)
        !wget https://ndownloader.figshare.com/files/23658674 -O ./example_data/calcium/processed_M1336.h5
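The output preprocessing above places the track position on the unit circle. The example below is a NumPy-only sketch of those steps. The helpers `fill_nans_1d` and `encode_circular` are hypothetical and only mirror the min-max scaling, the sin/cos stack and the `ffill().bfill()` NaN filling in the cell. It checks that the start and the end of the track map to the same point, so the decoding target has no artificial jump when the mouse wraps around. Filling NaNs in the angle before taking sin and cos gives the same result as filling the two columns afterwards, because forward- and back-filling commute with pointwise functions.

```python
import numpy as np

def fill_nans_1d(x):
    """Forward-fill, then back-fill NaNs in a 1-D array (like pandas ffill().bfill())."""
    x = np.asarray(x, dtype=float)
    valid = ~np.isnan(x)
    if not valid.any():
        raise ValueError("array contains only NaNs")
    idx = np.where(valid, np.arange(x.size), 0)
    np.maximum.accumulate(idx, out=idx)   # index of the last valid sample at or before i
    filled = x[idx]                       # fancy indexing returns a copy
    first = np.flatnonzero(valid)[0]
    filled[:first] = x[first]             # back-fill any leading NaNs
    return filled

def encode_circular(position):
    """Hypothetical helper: min-max scale to [-pi, pi] and return an (n, 2) [sin, cos] array."""
    position = np.asarray(position, dtype=float)
    scaled = (position - np.nanmin(position)) / (np.nanmax(position) - np.nanmin(position))
    angle = scaled * 2 * np.pi - np.pi    # same scaling as in the notebook cell
    angle = fill_nans_1d(angle)
    return np.column_stack([np.sin(angle), np.cos(angle)])

position = np.array([np.nan, 0.0, 25.0, np.nan, 75.0, 100.0])
encoded = encode_circular(position)
print(np.round(encoded, 3))
```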

---
### Plot preprocessed data
---
We plot examples to double-check the wavelet preprocessing. Each heatmap shows the wavelet-transformed calcium trace of one cell (frequency x time) for the first half of the recording. Use the slider to switch between cells.
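To build intuition for what the heatmaps show, here is a minimal complex-Morlet wavelet amplitude transform in plain NumPy. It is an illustration under stated assumptions, not DeepInsight's own implementation. It uses a Gaussian envelope of `n_cycles` cycles, truncates kernels at ±3 standard deviations, and normalizes the amplitude so that a sine of amplitude A gives roughly A. `morlet_amplitude` is a hypothetical name. A 5 Hz sine should light up only the 5 Hz band.

```python
import numpy as np

def morlet_amplitude(signal, fs, freqs, n_cycles=5.0):
    """Illustrative complex-Morlet amplitude, shape (len(signal), len(freqs)).

    Not DeepInsight's implementation: kernels are truncated at +/-3 sigma and
    clipped to the signal length so that the output stays aligned in time.
    """
    signal = np.asarray(signal, dtype=float)
    n = signal.size
    out = np.empty((n, len(freqs)))
    for k, f in enumerate(freqs):
        sigma_t = n_cycles / (2 * np.pi * f)                  # Gaussian envelope width (s)
        half = min(int(np.ceil(3 * sigma_t * fs)), (n - 1) // 2)
        t = np.arange(-half, half + 1) / fs
        envelope = np.exp(-t**2 / (2 * sigma_t**2))
        kernel = np.exp(2j * np.pi * f * t) * envelope
        kernel /= envelope.sum() / 2                          # sine of amplitude A -> ~A
        out[:, k] = np.abs(np.convolve(signal, kernel, mode='same'))
    return out

fs = 100.0
t = np.arange(1000) / fs                                      # 10 s at 100 Hz
signal = 2.0 * np.sin(2 * np.pi * 5 * t)
freqs = [1.0, 5.0, 20.0]
amp = morlet_amplitude(signal, fs, freqs)
print(amp.shape)  # (1000, 3)
```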

In [None]:
hdf5_file = h5py.File(fp_deepinsight, mode='r')
wavelets = hdf5_file['inputs/wavelets']
frequencies = np.round(hdf5_file['inputs/fourier_frequencies'], 3)

In [None]:
num_cells, gap = 20, 30
fig = go.Figure()
for i in range(0, num_cells):
    this_z = wavelets[0:wavelets.shape[0]//2:gap,:,i].transpose() # (frequencies x time) for the first half of the recording
    fig.add_heatmap(x=np.arange(0, this_z.shape[1]) / (sampling_rate / gap), z=this_z,colorscale='Viridis',visible=False,showscale=False)
fig.data[0].visible = True
# aesthetics
steps = []
for i in range(len(fig.data)):
    step = dict(method="update",label="Cell {}".format(i+1),args=[{"visible": [False] * len(fig.data)}])
    step["args"][0]["visible"][i] = True  # Toggle i'th trace to "visible"
    steps.append(step)
sliders = [dict(active=0,currentvalue={"prefix": "Cell: ", "visible" : False},pad={"t": 70},steps=steps)]

fig.update_layout(width=1800, height=650,sliders=sliders, yaxis = dict(tickvals=np.arange(0, len(frequencies)), ticktext = ['{:.3f}'.format(i) for i in frequencies], autorange='reversed'), yaxis_title='Frequency (Hz)',
                  showlegend=False, plot_bgcolor="white",margin=dict(t=20,l=20,b=20,r=20),xaxis_title='Time (s)', font=dict(family='Open Sans', size=16, color='black'))
fig

In [None]:
hdf5_file.close()

---
## **Train the model**
---
The following cell trains the model with 5-fold cross-validation and stores the weights of each model in HDF5 files.
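The notebook does not show how the five folds are formed. A common choice for time series, assumed here rather than taken from DeepInsight's source, is to hold out contiguous blocks of time. This keeps neighbouring, highly correlated frames from ending up on both sides of the split. `contiguous_folds` is a hypothetical helper.

```python
import numpy as np

def contiguous_folds(n_timepoints, n_folds=5):
    """Hypothetical helper: split time into n_folds contiguous, non-overlapping test blocks."""
    edges = np.linspace(0, n_timepoints, n_folds + 1).astype(int)
    all_idx = np.arange(n_timepoints)
    folds = []
    for k in range(n_folds):
        test = all_idx[edges[k]:edges[k + 1]]
        # Train on everything before and after the held-out block
        train = np.concatenate([all_idx[:edges[k]], all_idx[edges[k + 1]:]])
        folds.append((train, test))
    return folds

folds = contiguous_folds(1000, n_folds=5)
for train, test in folds:
    print(len(train), test[0], test[-1])
```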

In [None]:
# Define loss functions and train the model. If more than one behaviour/stimulus needs to be decoded, define loss functions and weights for each of them here
loss_functions = {'sin_cos' : 'mse'}
loss_weights = {'sin_cos' : 1}
user_opts = {'epochs' : 10, 'sample_per_epoch' : 250} # Reduced for Colab; normally set to {'epochs' : 20, 'sample_per_epoch' : 250}
deepinsight.train.run_from_path(fp_deepinsight, loss_functions, loss_weights, user_opts=user_opts)

---
## **Evaluate model performance**
---
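Here we calculate the losses over the whole duration of the experiment. The step size is the number of timesteps between the starts of consecutive samples. Each sample contains 64 timesteps, so a step size of 64 yields non-overlapping, back-to-back samples. A larger value, such as the 100 used here, leaves gaps between samples, while smaller values give overlapping samples.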
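The relation between sample length and step size can be sketched as follows. `sample_windows` is hypothetical and assumes that samples start at index 0 and that only complete 64-step windows are kept. DeepInsight may treat the edges of the recording differently.

```python
import numpy as np

def sample_windows(n_timepoints, sample_length=64, step_size=100):
    """Hypothetical helper: [start, end) indices of fixed-length samples taken every step_size."""
    starts = np.arange(0, n_timepoints - sample_length + 1, step_size)
    return np.column_stack([starts, starts + sample_length])

windows = sample_windows(1000, sample_length=64, step_size=100)
print(windows[:2])           # first two samples, separated by a gap of 36 timesteps
windows_overlap = sample_windows(1000, sample_length=64, step_size=32)
print(windows_overlap[:2])   # consecutive samples share 32 timesteps
```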

In [None]:
step_size = 100

# Get the model loss (also needed for the influence plots below), predictions and sample indices; results are also stored back in the HDF5 file
losses, output_predictions, indices = deepinsight.analyse.get_model_loss(fp_deepinsight, stepsize=step_size)

# Get real output from HDF5 file
hdf5_file = h5py.File(fp_deepinsight, mode='r')
output_real = hdf5_file['outputs/sin_cos'][indices,:]

---
### Visualize model performance
---
We plot the real output against the predicted output of the models trained above (sine component only). Position is treated as a circular variable because the virtual track restarts at the beginning once the mouse reaches the end. The example below is trained on only a subset of cells (see the `channels` variable, default=100) for a limited number of epochs (see `epochs`, default=10) to make training in the Colab notebook faster. Performance is higher when the model is trained on the full dataset.
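The model predicts the sine and cosine of position, so the two columns can be combined with `np.arctan2` to recover position along the track. Errors can then be measured along the wrapped track. This sketch assumes the [-π, π] scaling and the [sin, cos] column order from the preprocessing cell. Because of the min-max normalization, positions come out as a fraction of the track length. `decode_circular` and `circular_error` are hypothetical helpers. Predicted pairs need not lie exactly on the unit circle, and `arctan2` tolerates that.

```python
import numpy as np

def decode_circular(sin_cos):
    """Hypothetical helper: map an (n, 2) [sin, cos] array back to a position in [0, 1)."""
    angle = np.arctan2(sin_cos[:, 0], sin_cos[:, 1])      # angle in [-pi, pi]
    return np.mod((angle + np.pi) / (2 * np.pi), 1.0)      # undo the [-pi, pi] scaling

def circular_error(pos_true, pos_pred):
    """Shortest distance between positions on a track of length 1 that wraps around."""
    d = np.abs(np.asarray(pos_true) - np.asarray(pos_pred)) % 1.0
    return np.minimum(d, 1.0 - d)

true_pos = np.array([0.0, 0.25, 0.95])
angle = true_pos * 2 * np.pi - np.pi
encoded = np.column_stack([np.sin(angle), np.cos(angle)])
pred_pos = np.array([0.98, 0.30, 0.05])
print(circular_error(true_pos, pred_pos))   # errors near the wrap point stay small
```

Applied to the notebook, `decode_circular(output_real)` and `decode_circular(output_predictions['sin_cos'])` would give comparable position traces, assuming both arrays have shape (samples x 2).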

In [None]:
fig = go.Figure()

fig.add_trace(go.Scatter(x=np.arange(0, output_real.shape[0]) / (sampling_rate / step_size), y=output_real[:,0], line=dict(color='rgba(0, 0, 0, 0.85)', width=2), name='Real'))
fig.add_trace(go.Scatter(x=np.arange(0, output_real.shape[0]) / (sampling_rate / step_size), y=output_predictions['sin_cos'][:,0], line=dict(color='rgb(67, 116, 144)', width=3), name='Predicted'))

# aesthetics
#fig.update_yaxes(visible=False)
fig.update_layout(width=1800, height=650, plot_bgcolor="rgb(245, 245, 245)",margin=dict(t=20,l=20,b=20,r=20),xaxis_title='Time (s)', yaxis_title='Decoding target (sin)', font=dict(family='Open Sans', size=16, color='black'))
fig

In [None]:
hdf5_file.close()

---
### Get shuffled model performance
---
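We use the shuffled loss to evaluate feature importance. The inputs are shuffled along one axis (here axis=1, the frequency axis) and the resulting loss is compared with the unshuffled loss. The larger the relative increase, the more the model relies on that feature.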
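The idea behind the shuffled loss can be sketched with a toy linear model. Permute one feature across samples, recompute the per-sample loss, and apply the same relative-increase formula used in the notebook. `shuffled_residuals` is a hypothetical helper and a generic permutation-importance sketch; the way DeepInsight shuffles the wavelet inputs internally may differ. Only feature 0 carries information, so only shuffling it raises the loss.

```python
import numpy as np

def shuffled_residuals(predict, X, y, rng, eps=0.1):
    """Hypothetical permutation-importance sketch.

    For each feature j, permute X[:, j] across samples and return the per-sample
    relative loss increase (shuffled - base) / (base + eps), shape (n_features, n_samples).
    """
    base = (predict(X) - y) ** 2                          # per-sample squared error
    out = np.empty((X.shape[1], X.shape[0]))
    for j in range(X.shape[1]):
        X_shuf = X.copy()
        X_shuf[:, j] = rng.permutation(X_shuf[:, j])      # break feature j's link to y
        shuffled = (predict(X_shuf) - y) ** 2
        out[j] = (shuffled - base) / (base + eps)
    return out

rng = np.random.default_rng(0)
X = rng.normal(size=(500, 2))
w = np.array([3.0, 0.0])                                  # only feature 0 is informative
y = X @ w + 0.01 * rng.normal(size=500)
res = shuffled_residuals(lambda A: A @ w, X, y, rng)
mean = res.mean(axis=1)                                   # mean over samples
sem = res.std(axis=1) / np.sqrt(res.shape[1])             # standard error over samples
print(mean, sem)
```

The standard error divides by the square root of the number of samples (the axis the mean is taken over), not the number of features.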

In [None]:
shuffled_losses_ax1 = deepinsight.analyse.get_shuffled_model_loss(fp_deepinsight, axis=1, stepsize=step_size)

In [None]:
# Calculate residuals (relative increase in loss after shuffling); the small constant (0.1) avoids division by zero
residuals = (shuffled_losses_ax1 - losses) / (losses + 0.1)
residuals_mean = np.mean(residuals, axis=1)[:,0]
residuals_standarderror = np.std(residuals, axis=1)[:,0] / np.sqrt(residuals.shape[1]) # Standard error over samples (axis 1)

---
### Show feature importance for frequency axis
---
This plot shows the relative influence of each frequency band on decoding the position in the virtual environment. For each frequency band, we plot the mean across samples ± the standard error.

In [None]:
fig = go.Figure()

fig.add_trace(go.Scatter(x=np.arange(0, residuals_mean.shape[0]), y=residuals_mean, line=dict(color='rgba(0, 0, 0, 0.85)', width=3), name='Mean relative influence',
                         error_y=dict(type='data', array=residuals_standarderror, visible=True, color='rgb(67, 116, 144)', thickness=3)))

# aesthetics
#fig.update_yaxes(visible=False)
fig.update_layout(width=1800, height=650, plot_bgcolor="rgb(245, 245, 245)",margin=dict(t=20,l=20,b=20,r=20), xaxis = dict(tickvals=np.arange(0, len(frequencies)), ticktext = ['{:.3f}'.format(i) for i in frequencies], autorange='reversed'),
                  xaxis_title='Frequency (Hz)', yaxis_title='Relative influence', font=dict(family='Open Sans', size=16, color='black'))
fig

---
### Show feature importance for cell axis
---
Here we shuffle along the cell dimension (axis=2) to measure the influence of each cell on decoding position, and plot the result back onto the calcium ROIs. In the plot below, the size of each dot indicates the relative influence of that ROI (cell) on decoding performance. Red dots mark cells with a positive influence (shuffling them increases the loss). Blue dots mark cells with a negative influence (shuffling them decreases the loss).


In [None]:
shuffled_losses = deepinsight.analyse.get_shuffled_model_loss(fp_deepinsight, axis=2, stepsize=step_size)

In [None]:
# Calculate residuals (relative increase in loss after shuffling); the small constant (0.1) avoids division by zero
residuals = (shuffled_losses - losses) / (losses + 0.1)
residuals_mean = np.mean(residuals, axis=1)[:,0]

In [None]:
# Download the ROI centroids and an image of the calcium ROIs for plotting the importance of each cell back onto the anatomy
if not os.path.exists('./example_data/calcium/centroid_YX.mat'):
    !wget https://www.dropbox.com/s/z8ynet2nkt9pe1u/centroid_YX.mat -O ./example_data/calcium/centroid_YX.mat
if not os.path.exists('./example_data/calcium/calcium_rois.jpg'): 
    !wget https://www.dropbox.com/s/czak7rphajslcr0/test_rois_F5.jpg -O ./example_data/calcium/calcium_rois.jpg
roi_data = loadmat('./example_data/calcium/centroid_YX.mat')['xy_coords']

In [None]:
fig = go.Figure()
point_size_adjustment = 1250
all_pos = residuals_mean > 0
all_pos_channels = channels[all_pos]
all_neg_channels = channels[~all_pos]
fig.add_trace(go.Image(z=io.imread('./example_data/calcium/calcium_rois.jpg')))
fig.add_trace(go.Scatter(x=roi_data[:,1], y=roi_data[:,0], marker_symbol='circle', mode='markers', marker=dict(color='white', opacity=0.5, line=dict(color='white',width=0)), name='Cell centers'))
fig.add_trace(go.Scatter(x=roi_data[all_pos_channels,1], y=roi_data[all_pos_channels,0], marker_symbol='circle', mode='markers', marker=dict(color='red', size=residuals_mean[all_pos]*point_size_adjustment, opacity=0.5, line=dict(color='black',width=3)), name='Pos. influence'))
fig.add_trace(go.Scatter(x=roi_data[all_neg_channels,1], y=roi_data[all_neg_channels,0], marker_symbol='circle', mode='markers', marker=dict(color='blue', size=residuals_mean[~all_pos]*-point_size_adjustment, opacity=0.5, line=dict(color='black',width=3)), name='Neg. influence'))

fig.update_layout(width=1800, height=650, showlegend=False, plot_bgcolor="white",margin=dict(t=10,l=0,b=10,r=0), xaxis=dict(showticklabels=False), yaxis=dict(showticklabels=False))
fig