[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/MouseLand/course-materials/blob/main/cellpose_extraction/tutorial.ipynb)

# processing two-photon calcium imaging data with Cellpose


This notebook will guide you through the stages of processing two-photon calcium imaging data. The data were collected from a wild-type mouse injected with GCaMP6s in layer 2/3 of primary visual cortex. The mouse was head-fixed above a ball and free to run, with no visual input (lights turned off). The recording was collected at 13 Hz (there were 3 planes in the recording; 1 is included here). The recording has already been registered; in practice, you will need to run motion registration first. We will cover the subsequent processing steps:

1. cell detection
2. signal extraction
3. visualization

To keep the notebook interactive, there are three types of exercises throughout:
1. QUESTION MARKS: places where ????? needs to be replaced by a short expression, such as a variable or a function name.
2. DISCUSSION: have a short discussion with your colleague about this. At the end of each section, we will open the discussions to the whole group.
3. QUIZ: multiple-choice questions that we answer as an entire group. Keep track of your own points.

**Setup:** First we will install the required packages, if not already installed. If you are on Google Colab, select the GPU runtime: Runtime > Change runtime type > Hardware accelerator = GPU


In [None]:
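# cell segmentation (this notebook uses the Cellpose 3 API, e.g. model_type="cyto3" and channels=[0,0];
# if a newer release behaves differently, pinning the version, e.g. "cellpose<4", may help)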
!pip install cellpose 
# neural activity plotting
!pip install rastermap 
# plotting
!pip install matplotlib 
# download files from google drive
!pip install gdown 

# SUGGESTION: you can hide the output of a code cell after running it
# in Jupyter Notebook by double-clicking to the left of the cell's output,
# or in Google Colab with the left menu and "Show/hide output"
# SUGGESTION #2: you can instead run the pip install commands in a separate Anaconda prompt

import python libraries

In [None]:
import numpy as np # by far the most used library for everyday computation
from scipy import stats # here we import a whole sub-library of stats functions
from matplotlib import pyplot as plt # all of our plotting is done with plt

figure settings and video function

In [None]:
# @title figure settings and functions
import matplotlib
import matplotlib.animation
matplotlib.rcParams.update({
    'axes.spines.top': False,
    'axes.spines.right': False,
    'legend.frameon': False,
    'figure.figsize': (8, 8),
})

from IPython.display import HTML

def make_video(fr, fs=13, trange=[690, 710], yrange=[100, 300], xrange=[200, 400]):
  ms = 1000 // fs
  fig = plt.figure(figsize=(4, 4), dpi=60)
  ax = fig.add_subplot(111)
  im = ax.imshow(fr[0, yrange[0] : yrange[1], xrange[0] : xrange[1]],
                 cmap="gray", vmin=0, vmax=3000)
  ax.axis("off")
  plt.close()
  def animate(t):
      im.set_data(fr[t + trange[0], yrange[0] : yrange[1], xrange[0] : xrange[1]])

  ani = matplotlib.animation.FuncAnimation(fig, animate, frames=trange[1] - trange[0],
                                           interval=ms)
  return ani


The next code cell downloads the data. You can also upload your own data to the folder shown in the "Files" menu on the left, or connect to your Google Drive (see instructions [here](https://colab.research.google.com/notebooks/io.ipynb)), which makes it easier to download the output files to your local computer.

In [None]:
# @title download data
import os, requests

# raw (unregistered) data, for reference; not downloaded or used below
url = "https://www.suite2p.org/test_data/gt1.tif"

# registered data
fname = "gt1_reg.tif"
!gdown 1i8l5BZfIQp0puKpEuIXTr9rvOuQucjIr

from tifffile import imread
data = imread(fname)
data = data.astype("float32")

look at the shape of the TIFF stack we downloaded and loaded (frames x Ly x Lx):

In [None]:
print('imaging data of shape: ', data.shape)
n_frames, Ly, Lx = data.shape

## visualize the data

The first step when processing data is to look at it, check for artifacts, and decide how to process it. We will make a video of 50 example frames, cropped to a 200 x 200 pixel region.

In [None]:
ani = make_video(data, fs=13, trange=[0, 50], yrange=[100,300],
                 xrange=[200,400])
HTML(ani.to_jshtml())

In [None]:
## DISCUSSION
# What can you see in the recording? The bright flashing disks are cells, but what about smaller flashing dots?
# Also, what are the black areas in the recording?
# And what does it mean when the background behind the cells lights up?

## choose image to segment

We will calculate:
1. mean image: mean of each pixel across all frames
2. maximum projection image: maximum value of each pixel across all frames

In calcium imaging, not all cells have baseline fluorescence, so not all cells will show up on the mean image. The max projection image will take the max across all timepoints, so cells which have large transients will pop out from the background. Compute these images below and visualize them.


In [None]:
from cellpose import transforms

# compute the mean image across frames in "data"
mean_img = data.mean(axis=0)

# compute the max image across frames in "data"
max_proj = data.max(axis=0)

# put the images in a dictionary for easy access
imgs = {"mean_img": mean_img,
        "max_proj": max_proj}

for d, key in enumerate(imgs.keys()):
  img = imgs[key].copy()
  img = transforms.normalize99(img) # normalize for plotting
  ax = plt.subplot(len(imgs),1,d+1)
  ax.imshow(img, vmin=0, vmax=1, cmap="gray")
  ax.axis("off"); ax.set_title(key)
plt.tight_layout()
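
Why does the max projection reveal cells that the mean image misses? A toy illustration with made-up numbers (not taken from this recording): pixel A sits on a cell with a steady baseline, pixel B on a cell with almost no baseline but three large transients. All names and values here are hypothetical, chosen so they do not overwrite the notebook's variables.

```python
import numpy as np

rng = np.random.default_rng(0)
n_t = 1000  # number of synthetic frames

# hypothetical pixel traces (arbitrary units)
pixel_a = 100 + rng.normal(0, 5, n_t)   # steady baseline, little activity
pixel_b = rng.normal(0, 5, n_t)         # ~zero baseline...
pixel_b[[100, 400, 800]] += 500         # ...plus three brief, large transients

toy_traces = np.stack([pixel_a, pixel_b])  # shape (2, n_t)
mean_vals = toy_traces.mean(axis=1)        # what the mean image would show
max_vals = toy_traces.max(axis=1)          # what the max projection would show
print("mean:", mean_vals.round(1))
print("max: ", max_vals.round(1))
```

Pixel B is far dimmer than pixel A in the mean but brighter in the max projection. The flip side is that the max projection also amplifies single-frame noise spikes.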

### high-pass filtering the image (optional)

There are changes in brightness across the field of view; for example, you can see shadows of blood vessels that reduce brightness. We can reduce their contribution by removing the low-frequency components of the image. We do this in the Fourier domain:

In [None]:
from torch.fft import fft2, ifft2, fftshift, ifftshift
import torch

# choose which image to analyze
key = "max_proj"

img = imgs[key].copy()
# we subtract the mean here to remove the DC component of the image (optional)
img_mean = img.mean()
img -= img_mean

# put image on GPU if possible
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
m_torch = torch.from_numpy(img).to(device)

# compute the fft of the image
# fft2 takes us from the pixel domain to the frequency domain
m_fft = fft2(m_torch)
# center the fft
m_fft = fftshift(m_fft)

# view the fft - at the center of the fft we see the low frequencies
y_cent, x_cent = Ly//2, Lx//2
im_fft = torch.abs(m_fft[y_cent-50 : y_cent+50, x_cent-50 : x_cent+50])
im_fft = im_fft.cpu().numpy()
plt.figure(figsize=(3,3))
plt.imshow(im_fft)

Now we will set the low-frequency values at the center of the shifted FFT to zero:

In [None]:
# high pass filter by removing the low frequencies
hp = 8 # half-width (in frequency bins) of the low-frequency square to zero out
m_fft_hp = torch.clone(m_fft)
m_fft_hp[y_cent-hp : y_cent+hp, x_cent-hp : x_cent+hp] = 0

# view the fft - w/ low frequencies removed
im_fft = torch.abs(m_fft_hp[y_cent-50 : y_cent+50, x_cent-50 : x_cent+50])
im_fft = im_fft.cpu().numpy()
plt.figure(figsize=(3,3))
plt.imshow(im_fft)

Let's return the image to the pixel domain:

In [None]:
# ifft2 takes us from the frequency domain back to the pixel domain
# (note we also need to undo the fftshift with ifftshift)
m_ifft = torch.real(ifft2(ifftshift(m_fft_hp)))
# return array to CPU
m_ifft = m_ifft.cpu().numpy()
img_filt = m_ifft + img_mean

# add to dictionary
imgs[key+"_filt"] = img_filt

View the filtered image:

In [None]:
img = transforms.normalize99(img_filt)
ax = plt.subplot(111)
ax.imshow(img, vmin=0, vmax=1, cmap="gray")
ax.axis("off")
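
The same filter can be written in a few lines of NumPy, which makes it easy to check on a synthetic image where the answer is known. This is a minimal sketch, not the notebook's code: `highpass_fft` is a hypothetical helper, it runs on the CPU, it does not add the image mean back afterwards, and it zeroes a symmetric block (note the `+ 1` on the upper bounds) so that each removed frequency's mirror partner is removed too and the inverse FFT stays real-valued. The notebook cell's `y_cent-hp : y_cent+hp` slice is one bin short on the positive side.

```python
import numpy as np

def highpass_fft(image, hp=8):
    """Zero a (2*hp + 1) x (2*hp + 1) block of the lowest spatial frequencies, DC included."""
    ny, nx = image.shape
    spec = np.fft.fftshift(np.fft.fft2(image))   # DC moves to (ny // 2, nx // 2)
    yc, xc = ny // 2, nx // 2
    spec[yc - hp : yc + hp + 1, xc - hp : xc + hp + 1] = 0
    return np.real(np.fft.ifft2(np.fft.ifftshift(spec)))

# synthetic 64 x 64 image: offset + slow background + fine, cell-scale pattern
ny = nx = 64
yy, xx = np.mgrid[:ny, :nx]
background = 100 + 50 * np.cos(2 * np.pi * xx / nx) + 20 * np.sin(2 * np.pi * 2 * yy / ny)
fine = 5 * np.cos(2 * np.pi * 20 * xx / nx)
toy_img = background + fine

filtered = highpass_fft(toy_img, hp=8)
print(np.abs(filtered - fine).max())   # ~0: only the fine pattern survives
```

The background components sit at frequencies 0, 1 and 2, inside the zeroed block; the fine pattern sits at frequency 20, outside it, so it passes through unchanged.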

In [None]:
## DISCUSSION
# The image looks more evenly bright, but there are still areas without cells.
# Why might that be?

# segment images with cellpose

Cellpose is an anatomical segmentation algorithm: it takes an image as input and outputs masks corresponding to the identified cells. Cellpose does this with a deep neural network and pixel flow dynamics, both of which run much faster on the GPU.

In [None]:
from cellpose import models

# make a cellpose model
model = models.CellposeModel(gpu=True, # will use gpu if available
                             model_type="cyto3") # model type for cells (cytoplasm)

# choose which image to segment
img = imgs["max_proj_filt"].copy()

# run cellpose
masks, flows, styles = model.eval(img, 
                                  channels=[0,0], # grayscale
                                  diameter=8, # ~ diameter of cells in pixels
                                )
print(f"# of cells found: {masks.max()}")

View the masks - each pixel is assigned a number: 0 = background (no cells), 1 = cell1, 2 = cell2 ...

In [None]:
plt.figure(figsize=(15,4))
plt.subplot(1,4,1)
plt.imshow(masks[:,:300], cmap="magma")
plt.title("masks"); plt.axis("off")
plt.subplot(1,4,2)
plt.imshow(masks[:,:300]>0, cmap="magma")
plt.title("masks > 0"); plt.axis("off")
plt.subplot(1,4,3)
plt.imshow(flows[1][0][:,:300], cmap="bwr")
plt.title("Cellpose flows in Y"); plt.axis("off")
plt.subplot(1,4,4)
plt.imshow(flows[1][1][:,:300], cmap="bwr")
plt.title("Cellpose flows in X"); plt.axis("off")
plt.tight_layout()
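
Because `masks` is a label image, per-cell statistics can be computed directly with `np.bincount`. A minimal sketch on a tiny hand-made label image (`toy_masks` is hypothetical, not Cellpose output). Note that the notebook later uses the median of each cell's pixel coordinates as its center, whereas this sketch uses the mean (centroid).

```python
import numpy as np

# hypothetical label image: 0 = background, 1..n = cell IDs
toy_masks = np.array([
    [0, 1, 1, 0, 0],
    [0, 1, 1, 0, 2],
    [0, 0, 0, 0, 2],
    [3, 3, 0, 0, 2],
])
n_toy = int(toy_masks.max())

# number of pixels per mask (index 0 is the background, dropped with [1:])
areas = np.bincount(toy_masks.ravel(), minlength=n_toy + 1)[1:]

# centroid of each mask: mean row / column index of its pixels
ys, xs = np.nonzero(toy_masks)
labels = toy_masks[ys, xs]
cy = np.bincount(labels, weights=ys, minlength=n_toy + 1)[1:] / areas
cx = np.bincount(labels, weights=xs, minlength=n_toy + 1)[1:] / areas

print(areas)   # [4 3 2]
print(cy, cx)
```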


Plot outlines for each cell on top of image:

In [None]:
from cellpose import utils

fig = plt.figure(figsize=(6,4))
outlines = utils.outlines_list(masks)
img_norm = transforms.normalize99(img.copy())
ax = plt.subplot(111)
ax.imshow(img_norm, vmin=0.05, vmax=0.85, cmap="gray")
ax.axis("off")
for o in outlines:
  ax.plot(o[:,0], o[:,1], lw=1, color="r")
plt.tight_layout()

In [None]:
## QUIZ:
# Why does Cellpose predict flows instead of 0, 1, 2, 3 etc for each mask?
# (A) The numbers for a given mask are meaningless - they can be permuted.
# (B) By running the dynamics steps on the flows, pixels will converge to cell centers for segmentation.
# (C) The flow representations enable representation of non-convex shapes.
# (D) All of the above.

# extract signals across time

We have a mask for each cell; now we want to compute the activity in each of these masks across time. For this we will sum the pixels in each mask on each frame of the recording $D_t$, with a weight for each pixel. For an example cell with $M$ mask pixels at coordinates $(x_m, y_m)$ and weights $w_m$, the fluorescence at time $t$ is:

$$f_t = \sum^{M}_{m=1} w_m D_t[x_m, y_m]. $$

We can reformulate this as a dot product between a mask vector and $D_t$. Written as an image, the mask is $\vec{c}[x_m,y_m] = w_m$ for pixels inside the mask, and $\vec{c}[x,y]=0$ for all other pixels. We can then flatten $\vec{c}$ and $D_t$ and take the dot product of the two vectors, over all pixels $p$ in the image:

$$f_t = \sum^{L_yL_x}_{p=1} \vec{c}[p]\, D_t[p] = \vec{c} \cdot D_t $$

We compute this for each timepoint, which makes it a vector-matrix multiplication, where $D$ is timepoints by pixels:

$$ \vec{f} = \vec{c} D^\top$$
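
A quick numerical check that the weighted sum over mask pixels and the vector-matrix form $\vec{c} D^\top$ give the same trace. The movie and the cell below are random and hypothetical. The code indexes arrays as [row, column] = [y, x], as the notebook's code does, while the equations above write $[x, y]$.

```python
import numpy as np

rng = np.random.default_rng(0)
T, ny, nx = 50, 8, 10
D = rng.random((T, ny, nx))           # synthetic movie: frames x rows x columns

# hypothetical cell with M = 3 pixels and weights w_m
toy_y = np.array([2, 2, 3])
toy_x = np.array([4, 5, 4])
toy_w = np.array([0.5, 0.3, 0.2])

# (1) explicit weighted sum over the mask pixels, frame by frame
f_loop = np.array([np.sum(toy_w * D[t, toy_y, toy_x]) for t in range(T)])

# (2) mask image c (zero outside the cell), flattened; one product for all frames: c D^T
c = np.zeros((ny, nx))
c[toy_y, toy_x] = toy_w
f_dot = c.ravel() @ D.reshape(T, -1).T     # shape (T,)

print(np.allclose(f_loop, f_dot))  # True
```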


## cell fluorescence

First, let's compute this mask vector $\vec{c}$ for each cell:

In [None]:
# image used for finding cell masks
img_cells = img.copy()

# cell masks which were found
cell_masks = masks.copy()
Ly, Lx = cell_masks.shape
n_cells = int(cell_masks.max()) # number of cells found (as a Python int)

# matrix where each element is an image with one cell mask
cell_pix = np.zeros((n_cells, Ly, Lx), "float32")
for n in range(n_cells):
  # find the pixels of the cell mask
  ypix, xpix = np.nonzero(cell_masks == (n+1))

  # weight each pixel by the image intensity
  w = img_cells[ypix, xpix]
  w /= w.sum()

  # put the weighted pixels in the matrix
  cell_pix[n, ypix, xpix] = w


View example cell_pix:

In [None]:
n = 200
plt.imshow(cell_pix[n])

In [None]:
## QUIZ:
# What is the sum of the weights for each cell? Why did we set it this way?

Now we want to compute the fluorescence trace for each cell $i$:
$$ \vec{f}_i = \vec{c}_i D^\top.$$

For all cells, we stack these row vectors on top of each other:

$$ F = \begin{bmatrix} \vec{c}_1 D^\top \\ \vec{c}_2 D^\top \\ \vdots \\ \vec{c}_n D^\top \end{bmatrix}. $$

This is equivalent to a matrix multiplication between $C$ and $D^\top$, where each row $i$ of $C$ is $\vec{c}_i$:

$$ F = CD^\top.$$

This will require reshaping ``cell_pix`` and ``data`` to be 2D arrays, where the second dimension for each is the total number of pixels (Ly * Lx). 
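
Why reshaping works: NumPy flattens both arrays in the same (row-major) order, so column $p$ of the reshaped cell weights and column $p$ of the reshaped movie refer to the same pixel. A small sketch with random, hypothetical arrays (named so they do not overwrite the notebook's variables) confirming that one matrix multiplication reproduces the per-cell products $\vec{c}_i D^\top$ stacked row by row:

```python
import numpy as np

rng = np.random.default_rng(1)
nc, T, ny, nx = 4, 30, 6, 7

# hypothetical mostly-zero weight images and a random movie
toy_cells = rng.random((nc, ny, nx)) * (rng.random((nc, ny, nx)) > 0.8)
toy_data = rng.random((T, ny, nx))

# flatten the spatial dimensions: (nc, ny*nx) and (T, ny*nx)
C = toy_cells.reshape(nc, -1)
D = toy_data.reshape(T, -1)

F_all = C @ D.T                                      # one matmul: (nc, T)
F_rows = np.stack([C[i] @ D.T for i in range(nc)])   # rows c_i D^T, stacked

print(F_all.shape, np.allclose(F_all, F_rows))  # (4, 30) True
```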

In [None]:
print(data.shape, cell_pix.shape)

Reshape ``cell_pix`` and ``data``:

In [None]:
# reshape cell_pix to be n_cells by number of pixels
cell_pix_flat = cell_pix.reshape(n_cells, -1)

# reshape data to be n_frames by number of pixels
data = data.reshape(n_frames, -1)

Perform matrix multiplication to get the fluorescence traces for each cell:

In [None]:
F = cell_pix_flat @ data.T

print(F.shape)

Plot the fluorescence traces:

In [None]:
from scipy.stats import zscore
for n in range(10):
  plt.plot(zscore(F[n]) - n*5)

Matrix multiplication is faster on the GPU, so let's implement it using PyTorch:

In [None]:
# put data on GPU if possible
import torch
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
d_gpu = torch.from_numpy(data).to(device)
cell_pix_gpu = torch.from_numpy(cell_pix).to(device)

# reshape data and cell_pix as above
d_gpu = d_gpu.reshape(n_frames, -1)
cell_pix_gpu = cell_pix_gpu.reshape(n_cells, -1)

# matrix multiplication
F = cell_pix_gpu @ d_gpu.T

# return traces to CPU
F = F.cpu().numpy()

**Bonus:** we can make this even faster using sparse matrices. Try to implement it if you're interested.

A sparse matrix is defined by the indices of its non-zero entries and the values of those entries.

In [None]:
# get indices in flattened cell_pix
cp = torch.nonzero(cell_pix_gpu)
cell_ids = cp[:,0]
pix = cp[:,1]

# weights (non-zero array values)
cell_weights = cell_pix_gpu[cell_ids, pix]

# indices with weights with shape (2, n_nonzero)
inds = torch.cat([pix.unsqueeze(0), cell_ids.unsqueeze(0)])

# create sparse matrix with "inds" and "cell_weights"
cmasks = torch.sparse_coo_tensor(inds, cell_weights,
                                  size=(Ly*Lx, n_cells))
cmasks = cmasks.to_sparse_csc()

print(cmasks.shape)

# matrix multiplication with sparse matrix
F_sp = d_gpu @ cmasks

# transpose to get n_cells by n_frames
F_sp = F_sp.T
# return traces to CPU
F_sp = F_sp.cpu().numpy()

Did we do it correctly? Check against the dense matrix multiplication:

In [None]:
plt.figure(figsize=(2,2))
plt.scatter(F_sp[0], F[0])
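
For comparison, the same sparse idea on the CPU with `scipy.sparse`: a minimal sketch on random synthetic weights rather than the notebook's `cell_pix`. `csr_matrix((values, (row_indices, col_indices)), shape=...)` builds the matrix from exactly two ingredients: the indices of the non-zero entries and their values. Unlike the PyTorch cell, this keeps cells as rows (cells x pixels), so no transpose is needed at the end.

```python
import numpy as np
from scipy import sparse

rng = np.random.default_rng(2)
nc, T, n_pix = 5, 40, 100

# hypothetical cell-by-pixel weights with ~5% non-zeros, and a random movie
C_dense = rng.random((nc, n_pix)) * (rng.random((nc, n_pix)) < 0.05)
D = rng.random((T, n_pix))

# build the sparse matrix from the non-zero (row, column) indices and their values
rows, cols = np.nonzero(C_dense)
vals = C_dense[rows, cols]
C_sp = sparse.csr_matrix((vals, (rows, cols)), shape=(nc, n_pix))

F_sparse = C_sp @ D.T      # sparse @ dense -> dense ndarray, (nc, T)
F_dense = C_dense @ D.T

print(C_sp.nnz, np.allclose(F_sparse, F_dense))
```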

## neuropil fluorescence

Neuropil fluorescence is computed from the pixels surrounding the cell; it approximates the out-of-plane fluorescence that contaminates the cell's fluorescence.

We will compute it using a square of pixels surrounding each cell, excluding the cell pixels themselves.

In [None]:
# matrix where each element is an image with one neuropil mask
neuropil_pix = np.zeros((n_cells, Ly, Lx), "float32")

# cell_centers for each cell (will use for center of box)
cell_centers = np.zeros((n_cells, 2), "int")

# box size around cell center
bsize = 30 # half-width of the box around the cell center (box is up to 2*bsize pixels wide)

for n in range(n_cells):
  # find pixels in cell mask
  ypix, xpix = np.nonzero(cell_pix[n])
  
  # compute the cell center - we will make box around this
  med = np.median(ypix).astype("int"), np.median(xpix).astype("int")

  # save cell center (we will use this later for visualization)
  cell_centers[n] = med

  # set pixels in box to 1
  neuropil_pix[n, max(0, med[0] - bsize) : min(Ly, med[0] + bsize), 
                  max(0, med[1] - bsize) : min(Lx, med[1] + bsize)] = 1

# pixels to exclude from neuropil -- all pixels in cells
ycell, xcell = np.nonzero(cell_masks > 0)

# exclude cell pixels
neuropil_pix[:, ycell, xcell] = 0

# normalize so neuropil_pix for each cell will sum to 1
neuropil_pix /= neuropil_pix.sum(axis=(1,2), keepdims=True)

In [None]:
## DISCUSSION:
# Why did we have to take the max(0, med[0] - bsize)? What would happen if we didn't?

# Hint: think about what happens if the cell is near the edge of the image, 
# and we try to make a box around it: med[0] - bsize might go negative! 
# What does negative indexing mean in Python?
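
A compact single-cell version of the neuropil construction above, written as a hypothetical helper so its properties can be checked on a tiny label image: the box is clipped at the image edges, pixels of every cell (not just the center cell) are excluded, and the weights sum to 1. Building one mask at a time like this also avoids allocating the full `(n_cells, Ly, Lx)` array, at the cost of a Python loop over cells.

```python
import numpy as np

def neuropil_mask(label_img, n, bsize=30):
    """Hypothetical helper: normalized square neuropil mask for the cell labeled n (1-based)."""
    ny, nx = label_img.shape
    ys, xs = np.nonzero(label_img == n)
    yc, xc = int(np.median(ys)), int(np.median(xs))   # cell center, as in the notebook
    neu = np.zeros((ny, nx))
    # box around the center, clipped to the image bounds
    neu[max(0, yc - bsize) : min(ny, yc + bsize),
        max(0, xc - bsize) : min(nx, xc + bsize)] = 1
    neu[label_img > 0] = 0       # exclude pixels of ANY cell, not only cell n
    return neu / neu.sum()       # weights sum to 1

toy_masks = np.zeros((40, 40), dtype=int)
toy_masks[1:4, 1:4] = 1          # cell 1, next to the top-left corner
toy_masks[10:14, 10:14] = 2      # cell 2, partly inside cell 1's box
neu = neuropil_mask(toy_masks, 1, bsize=10)

# 131 pixels: the clipped 12 x 12 box minus 9 pixels of cell 1 and 4 of cell 2
print(np.count_nonzero(neu), neu.sum())
```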

View example neuropil_pix:

In [None]:
n = 150
plt.subplot(1,2,1)
plt.imshow(neuropil_pix[n])
plt.subplot(1,2,2)
plt.imshow(cell_pix[n])

Compute neuropil fluorescence with the same matrix multiplication as above:

In [None]:
# put neuropil_pix on GPU if possible
neuropil_pix_gpu = torch.from_numpy(neuropil_pix).to(device)
neuropil_pix_gpu = neuropil_pix_gpu.reshape(n_cells, -1)

# matrix multiplication
Fneu = neuropil_pix_gpu @ d_gpu.T

# return neuropil traces to CPU
Fneu = Fneu.cpu().numpy()

Visualize neuropil + cell trace:

In [None]:
n = 150
fmax = F[n].max()
fig = plt.figure(figsize=(8,3))
plt.subplot(2,1,1)
plt.plot(F[n], label="cell")
plt.plot(Fneu[n], label="neuropil")
plt.legend()
plt.ylim([0, fmax])

plt.subplot(2,1,2)
plt.plot(F[n] - 0.7 * Fneu[n], color="k")
plt.ylim([0, fmax])
plt.title("neuropil subtracted trace")

plt.tight_layout()

We will use the neuropil-corrected trace, which is the cell trace minus the neuropil trace multiplied by a scaling factor of 0.7:

In [None]:
Fcorr = F.copy() - 0.7 * Fneu
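
The subtraction assumes an additive model: the measured cell trace is the true cell signal plus a fraction of the neuropil signal. A synthetic sketch under exactly that assumption (every signal here is made up, and the contamination coefficient is known to be 0.7, which is never the case in real data) shows the correction removing the neuropil component:

```python
import numpy as np

rng = np.random.default_rng(3)
T = 2000
r = 0.7                                   # scaling factor, as in the notebook

# hypothetical ground truth: slow shared neuropil + sparse calcium transients
toy_neu = 200 + 50 * np.sin(np.linspace(0, 20, T))
spikes = np.zeros(T)
spikes[rng.choice(T, 20, replace=False)] = 300
toy_cell = np.convolve(spikes, np.exp(-np.arange(30) / 10))[:T]   # exponential decay

# assumed measurement model: additive contamination plus noise
F_raw = toy_cell + r * toy_neu + rng.normal(0, 5, T)
Fneu_meas = toy_neu + rng.normal(0, 5, T)
F_corrected = F_raw - r * Fneu_meas

corr_before = np.corrcoef(F_raw, toy_cell)[0, 1]
corr_after = np.corrcoef(F_corrected, toy_cell)[0, 1]
print(round(corr_before, 3), round(corr_after, 3))
```

If the true coefficient differs from the value used in the subtraction, part of the neuropil signal remains (coefficient too small) or an inverted copy of it is introduced (coefficient too large).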

In [None]:
## DISCUSSION:
# Do you think nearby cells will be more correlated to each other than far away cells?
# What about the neuropil of nearby cells?

# We will address this in the next section

# visualize data with rastermap

We will make a plot of all the neuron traces. In order to better see the activity, we will sort the neurons so that correlated neurons are put next to each other. For this, we will use our algorithm [Rastermap](https://github.com/MouseLand/rastermap).
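
Rastermap's algorithm is considerably more involved than this; as a much simpler stand-in that only illustrates the goal (put correlated neurons next to each other), one can sort neurons by their loading on the first principal component. This is not Rastermap, and it only captures one dominant axis of covariation. The data below are synthetic, with two hypothetical groups of neurons that are anti-correlated with each other, and `sort_by_pc1` is a hypothetical helper.

```python
import numpy as np

def sort_by_pc1(X):
    """Order neurons (rows of X) by their loading on the first principal component."""
    Xz = X - X.mean(axis=1, keepdims=True)
    Xz /= Xz.std(axis=1, keepdims=True) + 1e-8     # z-score each neuron over time
    U, S, Vt = np.linalg.svd(Xz, full_matrices=False)
    return np.argsort(U[:, 0])

rng = np.random.default_rng(4)
n_per_group, T = 20, 500
s = rng.normal(size=T)                                          # shared signal
group_a = s + 0.3 * rng.normal(size=(n_per_group, T))           # follows s
group_b = -s + 0.3 * rng.normal(size=(n_per_group, T))          # anti-correlated with s
X = np.vstack([group_a, group_b])
labels = np.repeat([0, 1], n_per_group)

perm = rng.permutation(len(X))      # shuffle neurons so the order is unknown
isort = sort_by_pc1(X[perm])
print(labels[perm][isort])          # all of one group, then all of the other
```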

In [None]:
from rastermap import Rastermap # import rastermap

# make Rastermap model
# (see rastermap documentation for more details)
rmodel = Rastermap(n_clusters=30, time_lag_window=10, bin_size=1)

# fit Rastermap to neuropil-corrected traces
rmodel.fit(Fcorr)

# get the embedding of the cells
embedding = rmodel.embedding[:,0]

# embedding plot (made in Rastermap)
X_embedding = rmodel.X_embedding

# plot the embedding
fig = plt.figure(figsize=(12,5))
ax = plt.subplot(111)
ax.imshow(X_embedding, vmin=0, vmax=1, aspect="auto", cmap="gray_r")
ax.set_xlabel("frames")
ax.set_ylabel("neurons")
axin = ax.inset_axes([1.02, 0, 0.02,1])
axin.imshow(np.arange(0, n_cells)[:,np.newaxis], cmap="jet", aspect="auto")
axin.axis("off")

Color each cell by its position in the Rastermap sorting to see whether the sorting has spatial structure:

In [None]:
fig = plt.figure(figsize=(5,3))
ax = plt.subplot(111)
ax.scatter(cell_centers[:,1], cell_centers[:,0], c=embedding, cmap="jet", s=4)
ax.invert_yaxis()
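
The scatter plot shows spatial structure qualitatively. To quantify it, one option is to average the pairwise correlation between cells as a function of the distance between their centers. A minimal sketch with a hypothetical helper `corr_vs_distance`, checked on synthetic data with two spatial clusters in which the cells of each cluster share their own signal:

```python
import numpy as np

def corr_vs_distance(traces, centers, bin_edges):
    """Mean pairwise correlation between cells, binned by the distance between their centers."""
    cc = np.corrcoef(traces)                                   # (n, n)
    dists = np.linalg.norm(centers[:, None, :] - centers[None, :, :], axis=-1)
    iu = np.triu_indices(len(traces), k=1)                     # each pair once, no self-pairs
    d_pairs, c_pairs = dists[iu], cc[iu]
    which = np.digitize(d_pairs, bin_edges)                    # bin b covers [edges[b-1], edges[b])
    return np.array([c_pairs[which == b].mean() if np.any(which == b) else np.nan
                     for b in range(1, len(bin_edges))])

# synthetic example: two spatial clusters, each sharing its own independent signal
rng = np.random.default_rng(5)
T = 1000
sig_a, sig_b = rng.normal(size=T), rng.normal(size=T)
toy_traces = np.vstack([sig_a + 0.5 * rng.normal(size=(10, T)),
                        sig_b + 0.5 * rng.normal(size=(10, T))])
toy_centers = np.vstack([rng.uniform(0, 10, (10, 2)),
                         rng.uniform(0, 10, (10, 2)) + [0, 100]])
near, far = corr_vs_distance(toy_traces, toy_centers, bin_edges=[0, 20, 200])
print(near, far)
```

In the notebook, the same function could be applied to `Fcorr` and to `Fneu` together with `cell_centers`, using distance bins of a few tens of pixels, to compare the answers for cells and for neuropil.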

We can also sort the neuropil traces and look at their spatial relationships:

In [None]:
from rastermap import Rastermap
rmodel = Rastermap(n_clusters=30, time_lag_window=10, bin_size=1)
rmodel.fit(Fneu)

# get the embedding of the neuropil traces
embedding = rmodel.embedding[:,0]

# embedding plot (made in Rastermap)
X_embedding = rmodel.X_embedding

# plot the embedding
fig = plt.figure(figsize=(12,5))
ax = plt.subplot(111)
ax.imshow(X_embedding, vmin=0, vmax=1, aspect="auto", cmap="gray_r")
ax.set_xlabel("frames")
ax.set_ylabel("neuropil (per neuron)")
axin = ax.inset_axes([1.02, 0, 0.02,1])
axin.imshow(np.arange(0, n_cells)[:,np.newaxis], cmap="jet", aspect="auto")
axin.axis("off")

Color each cell by the position of its neuropil trace in the Rastermap sorting:

In [None]:
fig = plt.figure(figsize=(5,3))
ax = plt.subplot(111)
ax.scatter(cell_centers[:,1], cell_centers[:,0], c=embedding, cmap="jet", s=4)
ax.invert_yaxis()


In [None]:
## DISCUSSION:
# How is it possible that the neuropil activity (bulk inputs/dendrites/axons)
# can be spatially organized, but NOT the cell activity?