[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/MouseLand/rastermap/blob/main/notebooks/rastermap_interactive.ipynb)

# Rastermap sorting of 34k neurons

We will use a spontaneous activity recording from [Syeda et al, 2023](https://www.biorxiv.org/content/10.1101/2022.11.03.515121v1.abstract). We recorded 34,086 neurons from mouse sensorimotor cortex for 2+ hours using two-photon calcium imaging at a rate of 3.2 Hz. To make the dataset faster to download, we analyze only the first half of the recording. During the recording, the mouse was free to run on an air-floating ball, and we filmed the mouse's face with a camera at 50 Hz and tracked keypoints on the face.

This notebook includes an **interactive** plotting section to explore the spatial relationships among neurons in the dataset.

If you are using Colab, we recommend clearing your browser's cache before running the interactive plot so that it renders smoothly. The steps for each browser are:

**Chrome**: Settings > Privacy and security > Clear browsing data > check "Cached images and files" > click "Clear data"

**Safari**: Settings > Advanced > check "Show Develop menu in menu bar" > Go back to the menu bar > Develop > Empty Caches

**Firefox**: Settings > Privacy & security > Clear data > check "Cached Web Content" > Clear

First we install the required packages, if they are not already installed. On Google Colab you will then need to click the "RESTART RUNTIME" button, because numpy is updated. Also select the GPU runtime to make the interactive plotting faster:
**Runtime > Change runtime type > Hardware accelerator = GPU**

In [None]:
!pip install "numpy>=1.24" # (required for google colab; quoted so the shell does not treat ">" as a redirect)
!pip install rastermap
!pip install matplotlib

### Load data and import libraries

If not already downloaded, the following cell will automatically download the processed data stored [here](https://osf.io/8xg7n).

In [None]:
import numpy as np
import matplotlib.pyplot as plt
# importing rastermap
# (this will be slow the first time since it is compiling the numba functions)
from rastermap import Rastermap, utils
from scipy.stats import zscore

# download spontaneous activity
filename = utils.download_data(data_type="spont2")

dat = np.load(filename)

# spks is neurons by time
# (each timepoint is 313 ms)
spks = dat["spks"]
n_neurons, n_time = spks.shape
print(f"{n_neurons} neurons by {n_time} timepoints")

# zscore activity (each neuron activity trace is then mean 0 and standard-deviation 1)
spks = zscore(spks, axis=1)

# XY position of each neuron in the recording
xpos, ypos = dat["xpos"], dat["ypos"]

# for your own data, you will need "spks" and "xpos" and "ypos"
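
For your own recording, the inputs only need to follow the same layout: `spks` as a neurons by time array, and `xpos` and `ypos` with one coordinate per neuron. The sketch below builds a small synthetic stand-in (random values, not real data) and z-scores each row with plain NumPy, which is assumed to match `zscore(spks, axis=1)` with its default `ddof=0`. The helper name `zscore_rows` and the `_toy` variables are hypothetical.

```python
import numpy as np

def zscore_rows(x):
    """Z-score each row (neuron) across time: mean 0, standard deviation 1."""
    mean = x.mean(axis=1, keepdims=True)
    std = x.std(axis=1, keepdims=True)  # population standard deviation (ddof=0)
    return (x - mean) / std

# synthetic stand-in for a recording: 50 neurons, 400 timepoints
rng = np.random.default_rng(0)
n_neurons_toy, n_time_toy = 50, 400
spks_toy = rng.gamma(shape=2.0, scale=1.0, size=(n_neurons_toy, n_time_toy))
xpos_toy = rng.uniform(0, 1000, size=n_neurons_toy)  # one x coordinate per neuron
ypos_toy = rng.uniform(0, 1000, size=n_neurons_toy)  # one y coordinate per neuron

spks_toy_z = zscore_rows(spks_toy)
print(spks_toy_z.shape, xpos_toy.shape, ypos_toy.shape)
```

A neuron whose trace is constant has zero standard deviation and would produce NaN values here, so such neurons are best removed before z-scoring.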

### Run Rastermap

Let's sort the single neurons with Rastermap, with clustering and upsampling:

In [None]:
model = Rastermap(n_clusters=100, # number of clusters to compute
                  n_PCs=128, # number of PCs to use
                  locality=0.75, # locality in sorting to find sequences (this is a value from 0-1)
                  time_lag_window=5, # use future timepoints to compute correlation
                  grid_upsample=10, # default value, 10 is good for large recordings
                ).fit(spks)
y = model.embedding # neurons x 1
isort = model.isort
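
`isort` lists neuron indices in sorted order, so `spks[isort]` reorders the rows. To go the other way and ask where a given neuron landed in the sorting, the permutation can be inverted. A minimal sketch on a made-up 1D embedding, assuming the sort order is the `argsort` of the embedding (the values and the names `embedding_toy`, `isort_toy` and `rank_toy` are illustrative only):

```python
import numpy as np

# made-up 1D embedding positions, one per neuron
embedding_toy = np.array([0.7, 0.1, 0.9, 0.4])

# neuron indices ordered by embedding position
isort_toy = np.argsort(embedding_toy)

# invert the permutation: rank_toy[i] is the row of neuron i after sorting
rank_toy = np.empty_like(isort_toy)
rank_toy[isort_toy] = np.arange(len(isort_toy))

print(isort_toy)  # [1 3 0 2]
print(rank_toy)   # [2 0 3 1]
```

With the real model, the same two lines applied to `isort` would give each neuron's row in the sorted raster.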

Let's create superneurons from Rastermap -- we sort the data and then average over neighboring neurons:

In [None]:
nbin = 200 # number of neurons to bin over
sn = utils.bin1d(spks[isort], bin_size=nbin, axis=0) # bin over neuron axis
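
The binning step can be sketched in plain NumPy, assuming `utils.bin1d` averages within each bin and drops rows that do not fill a complete bin. The function name `bin_rows_mean` is hypothetical, and the array is a toy example rather than recorded data.

```python
import numpy as np

def bin_rows_mean(x, bin_size):
    """Average consecutive groups of `bin_size` rows, dropping leftover rows."""
    n_bins = x.shape[0] // bin_size
    trimmed = x[: n_bins * bin_size]
    return trimmed.reshape(n_bins, bin_size, -1).mean(axis=1)

x_toy = np.arange(14, dtype=float).reshape(7, 2)  # 7 "neurons", 2 timepoints
sn_toy = bin_rows_mean(x_toy, bin_size=3)          # 2 bins; the 7th row is dropped
print(sn_toy)
# [[2. 3.]
#  [8. 9.]]
```

Under the same assumption, 34,086 neurons with `nbin = 200` would give 170 superneurons, with the last 86 sorted neurons left out.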

### Interactive Visualization

Use the Rastermap sorting to visualize neural activity of all neurons and show the positions of selected neurons. GPU is required for fast rendering.

In [None]:
!pip install dash
!pip install plotly

In [None]:
# @title Interactive Plot - press play and wait around 5sec
import dash
import plotly.graph_objs as go
from dash.dependencies import Input, Output
from dash import callback, dcc, html, State

matrix = sn
NN, NT = matrix.shape

# get indices of neurons corresponding to each superneuron
NN_all = spks.shape[0]
indices_bin = isort[:NN_all // nbin * nbin].reshape((NN_all // nbin, nbin))

# decide how many time points to show per frame
NT_show = min(1000, NT-1)
nmin, nmax = 0, NN
tmin, tmax = 0, NT_show

# swap the xpos and ypos
xpos_plot = ypos
ypos_plot = xpos

# visualize first frame
fig = go.Figure(
    data=[
        go.Heatmap(
            x=np.arange(tmin, tmax).tolist(),
            z=matrix[:, tmin:tmax],
            colorscale="Greys",
            zmin=0,
            zmax=0.8,
        )
    ]
)

# initialize the positions of the selecting bar
x0 = tmin
x1 = tmin+NT_show
y0 = nmin+int(NN/10) * 8
y1 = nmin+int(NN/10) * 9

# visualize neurons with their positions
color_values = np.ones(len(xpos_plot)) * 0.1
size_values = np.ones(len(xpos_plot)) * 5

neuron_fig = go.Figure(
    data=[
        go.Scattergl(x=xpos_plot, y=ypos_plot, mode='markers',
                    marker=dict(
                  size=size_values,
                  color=color_values,
                  colorscale='Purples',
                  cmin=0,
                  cmax=1,
                  )
                  )
    ]
)

neuron_fig.update_layout(
    width=500,
    height=500,
    yaxis={"title": 'y position'},
    xaxis={"title": 'x position'},
    template='simple_white',
    margin=dict(l=10, r=0, t=100, b=0),
)

# define the dash app layout
app = dash.Dash(__name__)
app.layout = html.Div(
    style={'display': 'flex', 'flex-direction': 'row', "padding": "0", "margin": "0"},
    children=[
        html.Div(
            style={"width": "60%", "display": "flex", "flex-direction": "column", "padding": "0", "margin": "0"},
            children=[
                html.Div(
                    style={'display': 'flex', 'flex-direction': 'row', "padding": "10", "margin": "10"},
                    children=[
                        html.H2("Rastermap", style={'margin': '0'}),
                        html.Div(dcc.Input(id='input-on-submit', type='text', placeholder="{}".format(NT_show)), style={'margin-left': '10px', 'margin-top': '50px'}),
                        html.Button('Submit', id='submit-val', n_clicks=0, style={'height': '20px', 'margin-left': '10px', 'margin-top': '50px'}),
                        html.Div(id='button-output', children=f'number of time points to show: {NT_show}',
                                 style = {'margin-left': '10px', 'margin-top': '50px'})
                    ]
                ),
                dcc.Graph(id="matrix-plot",
                          figure=fig,
                          config={
                              'edits': {
                                  'shapePosition': True
                              }
                          },
                          style={'margin-top': '0px'}
                          ),
                dcc.Slider(0, 1, step=1/NT, id='slider-time',
                           marks={i*NT/(NT-NT_show): '{}'.format(int(i*NT)) for i in (np.arange(NT+1, step=int(NT/10))/NT).tolist()})
            ]
        ),
        html.Div(
            style={"width": "40%", "display": "inline-block", "padding": "0", "margin": "0"},
            children=[
                html.H2("Neuron locations"),
                dcc.Graph(id="neuron-plot")
            ]
        )
    ]
)

# call back for slider to change time points to show
@app.callback(
    Output("matrix-plot", 'figure'),
    Input('slider-time', 'value'),
    )
def update_output(tvalue):
    if tvalue is not None:
        tmin = int((NT-NT_show)*tvalue)
        tmax = tmin + NT_show
    else:
        tmin, tmax = 0, NT_show

    fig = go.Figure(
        data=[
            go.Heatmap(
                x=np.arange(tmin, tmax).tolist(),
                z=matrix[:, tmin:tmax],
                colorscale="Greys",
                zmin=0,
                zmax=0.8,
            )
        ]
    )

    fig.add_shape(type="rect",
        xref="x", yref="y",
        x0=tmin, y0=y0,
        x1=tmin+x1-x0, y1=y1,
        line=dict(
            color="grey",
            width=3,
        ),
        fillcolor="grey",
        opacity=0.5,
        xanchor=tmin,
    )

    fig.update_layout(
        width=800,
        height=500,
        yaxis={"title": 'Neuron'},
        xaxis={"title": 'Time'},
        margin=dict(l=0, r=0, t=50, b=0),
    )
    return fig

# call back for moving the selecting bar to select neurons
@app.callback(
    Output("neuron-plot", "figure"),
    Input("matrix-plot", "relayoutData"))
def update_matrix_plot(relayout_data):
    global x0, y0, x1, y1
    color_values = np.ones(len(xpos_plot)) * 0.1
    size_values = np.ones(len(xpos_plot)) * 5
    # relayoutData also fires on zoom, pan and resize events, which do not
    # carry the bar coordinates, so only read them when they are present
    if relayout_data is not None and "shapes[0].y0" in relayout_data:
        x0, y0 = int(relayout_data["shapes[0].x0"]), int(relayout_data["shapes[0].y0"])
        x1, y1 = int(relayout_data["shapes[0].x1"]), int(relayout_data["shapes[0].y1"])
        # keep the selection inside the superneuron range, whichever way the bar was dragged
        lo, hi = np.clip(sorted((y0, y1)), 0, len(indices_bin))
        neuron_range = indices_bin[lo:hi].reshape(-1)
        color_values[neuron_range] = 1
        size_values[neuron_range] = 5

    neuron_fig['data'][0]['marker']['color'] = color_values
    neuron_fig['data'][0]['marker']['size'] = size_values

    return neuron_fig

# call back for updating number of time points to show per frame
@callback(
    Output('button-output', 'children'),
    Output('slider-time', 'marks'),
    Input('submit-val', 'n_clicks'),
    State('input-on-submit', 'value')
)
def update_timepoints_show(n_clicks, value):
    global NT_show, x1
    if value is not None:
        NT_show = int(value)
        x1 = x0+NT_show
    else:
        value = NT_show
    text_to_show = f'number of time points to show: {value}'
    new_marks = {i*NT/(NT-NT_show): '{}'.format(int(i*NT)) for i in (np.arange(NT+1, step=int(NT/10))/NT).tolist()}
    return text_to_show, new_marks

# run the app (app.run replaces app.run_server, which is deprecated in recent Dash versions)
app.run(jupyter_mode='inline')
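
The two mappings at the heart of the callbacks above can be tried without starting the app. The sketch below restates them as standalone functions: the slider value selects a time window, and the bar's vertical extent selects superneuron rows, which are then expanded to neuron indices. The function names and the toy sort order are hypothetical.

```python
import numpy as np

def slider_to_window(tvalue, n_time, n_show):
    """Map a slider value in [0, 1] to the [tmin, tmax) time window to display."""
    tmin = int((n_time - n_show) * tvalue)
    return tmin, tmin + n_show

def bar_to_neurons(y0, y1, indices_bin):
    """Return neuron indices in superneuron rows y0 (inclusive) to y1 (exclusive)."""
    lo, hi = sorted((int(y0), int(y1)))
    lo, hi = max(lo, 0), min(hi, len(indices_bin))
    return indices_bin[lo:hi].reshape(-1)

# toy example: 12 neurons in a made-up sort order, binned into 4 superneurons of 3
isort_toy = np.array([5, 2, 9, 0, 11, 7, 3, 8, 1, 10, 4, 6])
indices_bin_toy = isort_toy.reshape(4, 3)

print(slider_to_window(0.5, n_time=5000, n_show=1000))  # (2000, 3000)
print(bar_to_neurons(1, 3, indices_bin_toy))             # [ 0 11  7  3  8  1]
```

Clamping the rows in `bar_to_neurons` is a design choice of this sketch: it keeps the lookup valid when the bar is dragged partly outside the heatmap.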

### Settings

You can see all the rastermap settings with `Rastermap?`

In [None]:
Rastermap?
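
`Rastermap?` is IPython syntax and only works in a notebook or an IPython shell. In a plain Python script, `help(Rastermap)` shows the same documentation, and the standard-library `inspect` module can collect keyword defaults programmatically. A minimal sketch, using a hypothetical stand-in function with made-up defaults so that it runs without rastermap installed:

```python
import inspect

def keyword_defaults(func):
    """Return {parameter name: default value} for parameters that have a default."""
    params = inspect.signature(func).parameters
    return {name: p.default for name, p in params.items()
            if p.default is not inspect.Parameter.empty}

# hypothetical stand-in with made-up defaults, used only to demonstrate the helper
def fit_toy(data, n_clusters=10, locality=0.0):
    return data

print(keyword_defaults(fit_toy))  # {'n_clusters': 10, 'locality': 0.0}
```

In the notebook, `keyword_defaults(Rastermap)` should list the constructor settings in the same way.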

### Outputs

All the attributes assigned to the Rastermap `model` are listed with `Rastermap.fit?`

In [None]:
Rastermap.fit?