# Inspect CNMF-E traces

This notebook inspects the temporal component of each cell. The main goal is to compare, for each ROI, the signal extracted by CNMF-E with the background signal, and with the signal obtained by applying the same ROI to the raw motion-corrected video and to the FFT-normalized (motion-corrected) video.



In [1]:
# Notebook setup: autoreload edited modules and use the Qt backend for interactive figures.
# Each magic is attempted on its own and failures are reported rather than silently swallowed;
# outside IPython (no `get_ipython`) this is a no-op.
try:
    _ipython = get_ipython()
except NameError:
    _ipython = None

if _ipython is not None:
    for _magic_name, _magic_args in (("load_ext", "autoreload"), ("autoreload", "2"), ("matplotlib", "qt")):
        try:
            _ipython.run_line_magic(_magic_name, _magic_args)
        except Exception as _err:
            print(f"Could not run %{_magic_name} {_magic_args}: {_err}")

import matplotlib.pyplot as plt
import numpy as np
import os
from tqdm import tqdm
import napari
from collections import namedtuple
from scipy.optimize import curve_fit

import caiman as cm
from caiman.source_extraction.cnmf.spatial import threshold_components
from caiman.utils.visualization import inspect_correlation_pnr

import logging
from fancylog import fancylog
import fancylog as package

import cv2
try:
    # Disable OpenCV's internal threading, as in the CaImAn demo notebooks
    # (presumably so it does not compete with the parallel server started below)
    cv2.setNumThreads(0)
except Exception:
    # Best effort only: do not stop the notebook if this OpenCV build rejects the call
    pass
import bokeh.plotting as bpl

from fcutils.file_io.io import load_yaml
from fcutils.plotting.utils import clean_axes, save_figure, add_colorbar_to_img, set_figure_subplots_aspect
from fcutils.plotting.colors import * 
from fcutils.plotting.plot_elements import plot_shaded_withline
from fcutils.maths.utils import rolling_pearson_correlation
from fcutils.file_io.utils import check_create_folder

from utils import print_cnmfe_components, plot_components_over_image, load_fit_cnmfe
from utils import start_server, load_params,  log_cnmfe_components, load_mmap_video_caiman, load_tiff_video_caiman

bpl.output_notebook()  # render bokeh figures inline in the notebook
# Start the parallel-processing server; `dview` and `n_processes` are passed to the CaImAn calls below
c, dview, n_processes = start_server()

## Load metadata, model and data

In [2]:
# Session folder and its analysis metadata
fld = 'D:\\Dropbox (UCL - SWC)\\Project_vgatPAG\\analysis\\doric\\BF164p2\\19JUN26'
metadata = load_yaml(os.path.join(fld, "01_PARAMS", "analysis_metadata.yml"))

# The video the model was fitted on: gives dims, number of frames, and a max-projection image (bg)
video = os.path.join(fld, metadata[metadata['video_for_cnmfe_fit']])
Yr, dims, T, images = load_mmap_video_caiman(video)
bg = np.max(images, axis=0)

# Load the curated CNMF-E fit matching the data type being analysed (raw vs normalized), then re-run
# component evaluation so that idx_components / idx_components_bad are available
cnm, model_filepath = load_fit_cnmfe(fld, n_processes, dview,
                                     raw=metadata['working_on'] == "raw", curated=True)
cnm.estimates.evaluate_components(images, cnm.params, dview=dview)

# Log to <fld>/02_LOGS, recording the CNMF parameters alongside the run
fancylog.start_logging(
    os.path.join(fld, "02_LOGS"),
    package,
    file_log_level="INFO",
    variables=[cnm.params],
    verbose=True,
    filename='cnmfe_traces_inspection_logs',
)
logging.info("Starting CNMF-E traces inspection")
logging.info(f"Loading CNMF-E model from: {model_filepath}")

# Copy every metadata entry into the log
logging.info("ANALYSIS METADATA FILE:")
for key, value in metadata.items():
    logging.info(f"{key}: {value}")

logging.info(f"Video used for CNMF-E fitting: {metadata['video_for_cnmfe_fit']}: {video}")


INFO:root:Starting logging
2020-04-24 13:59:57 PM - INFO - MainProcess fancylog.py:271 - Starting logging
INFO:root:Multiprocessing-logging module not found, not logging multiple processes.
2020-04-24 13:59:57 PM - INFO - MainProcess fancylog.py:273 - Multiprocessing-logging module not found, not logging multiple processes.
INFO:root:Starting CNMF-E traces inspection
2020-04-24 13:59:57 PM - INFO - MainProcess <ipython-input-2-6ea036083e99>:19 - Starting CNMF-E traces inspection
INFO:root:Loading CNMF-E model from: D:\Dropbox (UCL - SWC)\Project_vgatPAG\analysis\doric\BF164p2\19JUN26\cnmfe_fit_curated.hdf5
2020-04-24 13:59:57 PM - INFO - MainProcess <ipython-input-2-6ea036083e99>:20 - Loading CNMF-E model from: D:\Dropbox (UCL - SWC)\Project_vgatPAG\analysis\doric\BF164p2\19JUN26\cnmfe_fit_curated.hdf5
INFO:root:ANALYSIS METADATA FILE:
2020-04-24 13:59:57 PM - INFO - MainProcess <ipython-input-2-6ea036083e99>:23 - ANALYSIS METADATA FILE:
INFO:root:experimenter: Federico
2020-04-24 13:5

In [3]:
# Load videos. Keys are the names used in the rest of the notebook; values are keys of video_paths.yml.
# The *_bg_fiji / *_bg_cnmf entries are the background videos for the raw and normalized data.
vids = {"raw":"raw_pw_mc", "raw_bg_fiji":"raw_pw_mc_fiji_bg", "raw_bg_cnmf":"raw_pw_mc_cnmf_bg",
        "normalized":"transf_pw_mc", "normalized_bg_fiji":"transf_pw_mc_fiji_bg", "normalized_bg_cnmf":"transf_pw_mc_cnmf_bg"}

# Container for each loaded video:
#   Yr, dims, T -> as returned by load_mmap_video_caiman (None for tiff videos)
#   images      -> the frames as an in-memory numpy array (frames on axis 0)
#   bg          -> max projection of the frames over time
# Named VideoData so it does not shadow the `video` path defined when loading the model.
VideoData = namedtuple("video", "Yr dims T images bg")
vids_paths = load_yaml(os.path.join(fld, metadata['outputfld'], "video_paths.yml"))

videos = {}
for vname, vkey in vids.items():
    vpath = vids_paths[vkey]
    # Local names so the model's Yr / dims / T / images globals are not overwritten
    if 'mmap' in vpath:
        vid_Yr, vid_dims, vid_T, vid_images = load_mmap_video_caiman(vpath)
    else:
        vid_images = load_tiff_video_caiman(vpath)
        vid_Yr, vid_dims, vid_T = None, None, None

    # Read the frames into memory once and compute the max projection from that copy
    vid_images = np.array(vid_images)
    videos[vname] = VideoData(vid_Yr, vid_dims, vid_T, vid_images, np.max(vid_images, axis=0))



In [4]:
# Get components from model
n_components = cnm.estimates.A.shape[1]  # both good and bad components
good_components = cnm.estimates.idx_components
bad_components = cnm.estimates.idx_components_bad

# Spatial components as a d1 x d2 x n_components array. CaImAn flattens pixels in
# Fortran (column-major) order, hence order='F'.
A = np.reshape(cnm.estimates.A.toarray(), list(cnm.estimates.dims) + [-1], order='F')  # set of spatial footprints
centroids = cnm.estimates.center
centroid_colors = [seagreen if n in good_components else darkred
                            for n in np.arange(n_components)]

# Binary masks (d1 x d2 x n_components): 1 where a footprint exceeds metadata['spatial_th'], 0 elsewhere
masks = np.zeros_like(A)
masks[A > metadata['spatial_th']] = 1

flat_mask = masks.sum(axis=2)  # all masks together
flat_mask[flat_mask > 0] = 1


# Background stuff
W = cnm.estimates.W.toarray()  # Ring model matrix
# Constant baseline for each pixel, reshaped with the same Fortran pixel order as A.
# (A C-order reshape followed by .T only gives the same image when d1 == d2.)
b0 = cnm.estimates.b0.reshape(cnm.estimates.dims, order='F')

# Temporal components: n_frames x n_components, C + YrA (denoised traces plus residuals)
n_frames = cnm.estimates.C.shape[1]
C = cnm.estimates.C.T + cnm.estimates.YrA.T  # set of temporal traces

# Lengths of the individual recordings (presumably one value per tiff file)
tiff_lengths = np.load(os.path.join(fld, '19JUN26_BF164p2_ds126_ffcSub_tifflengths.npy'))


For each video we need to extract each component's trace. To do so, we use the component's mask to select the relevant pixels in the video and then average them frame by frame. This is slow, so we do it for all components at once and save the results to file; next time the traces can simply be loaded.

In [5]:
def get_component_trace_from_video(cnum, masks, n_frames, video):
    """
        Masks the video with a component's spatial mask and averages the masked
        pixels of each frame to get the component's trace.

        :param cnum: int, index of the component along the last axis of `masks`
        :param masks: d1 x d2 x n_components array; nonzero entries mark the component's pixels
            (their values multiply the video, so a 0/1 mask gives a plain average)
        :param n_frames: number of frames; must equal video.shape[0]
        :param video: n_frames x d1 x d2 array
        :returns: 1D array with one value per frame. NaN pixels in the video are ignored; every
            frame is NaN (with a RuntimeWarning) if the component's mask is empty.
        :raises ValueError: if video.shape[0] != n_frames
    """
    video = np.asarray(video)
    if video.shape[0] != n_frames:
        raise ValueError(f"video has {video.shape[0]} frames, expected {n_frames}")

    # Select only the mask's pixels instead of multiplying the whole video by an
    # n_frames x d1 x d2 NaN-padded copy of the mask: same values, far less memory and time
    component_mask = masks[:, :, cnum]
    in_mask = component_mask != 0
    masked_pixels = video[:, in_mask] * component_mask[in_mask]  # n_frames x n_pixels_in_mask
    return np.nanmean(masked_pixels, axis=1)


# Get the video traces for each component, cached as one .npy file per video in fld
video_traces = {}
traces_suffix = "_comp_traces_raw.npy" if metadata['working_on'] == "raw" else "_comp_traces.npy"
for vidname, vid_data in videos.items():
    savename = os.path.join(fld, vidname + traces_suffix)

    traces = None
    if not metadata['overwrite_traces'] and os.path.isfile(savename):
        logging.info(f"Loading traces for video {vidname} from file")
        traces = np.load(savename)
        # A cache written for a different model (e.g. before re-curation changed the
        # number of components) does not line up with C: treat it as stale
        if traces.shape != C.shape:
            logging.info(f"Cached traces for video {vidname} have shape {traces.shape}, "
                         f"expected {C.shape}: re-extracting")
            traces = None

    if traces is None:
        logging.info(f"Extracting traces for video {vidname} from data")
        traces = np.zeros_like(C)
        for compn in tqdm(np.arange(n_components)):
            traces[:, compn] = get_component_trace_from_video(compn, masks, n_frames, vid_data.images)
        np.save(savename, traces)

    video_traces[vidname] = traces

INFO:root:Loading traces for video raw from file
2020-04-24 14:00:10 PM - INFO - MainProcess <ipython-input-5-7a666ffc274a>:26 - Loading traces for video raw from file
INFO:root:Loading traces for video raw_bg_fiji from file
2020-04-24 14:00:10 PM - INFO - MainProcess <ipython-input-5-7a666ffc274a>:26 - Loading traces for video raw_bg_fiji from file
INFO:root:Loading traces for video raw_bg_cnmf from file
2020-04-24 14:00:10 PM - INFO - MainProcess <ipython-input-5-7a666ffc274a>:26 - Loading traces for video raw_bg_cnmf from file
INFO:root:Loading traces for video normalized from file
2020-04-24 14:00:10 PM - INFO - MainProcess <ipython-input-5-7a666ffc274a>:26 - Loading traces for video normalized from file
INFO:root:Loading traces for video normalized_bg_fiji from file
2020-04-24 14:00:10 PM - INFO - MainProcess <ipython-input-5-7a666ffc274a>:26 - Loading traces for video normalized_bg_fiji from file
INFO:root:Loading traces for video normalized_bg_cnmf from file
2020-04-24 14:00:10 

## Visualise component traces

In [6]:
def plot_trace(ax1, ax2, x, y, cnmfb, fijibg, col1, col2, col3, label=None, do_subtraction=True):
    """
        Plot a signal with its two background estimates, and a second view of it.

        ax1 gets the signal (col1) plus the cnmf (col2) and fiji (col3) background traces.
        ax2 gets either the signal minus each background (do_subtraction=True, drawn in
        the background's color) or the signal itself (do_subtraction=False).
        Both axes get a legend.
    """
    backgrounds = (
        (cnmfb, col2, "cnmf background", "signal - cnmf bg"),
        (fijibg, col3, "fiji background", "signal - fiji bg"),
    )

    ax1.plot(x, y, label=label, alpha=0.8, color=col1)
    for bg_trace, bg_color, bg_label, _ in backgrounds:
        ax1.plot(x, bg_trace, color=bg_color, lw=1, alpha=.6, zorder=10, label=bg_label)
    ax1.legend()

    if not do_subtraction:
        ax2.plot(x, y, alpha=.8, color=col1, label=label)
    else:
        for bg_trace, bg_color, _, diff_label in backgrounds:
            ax2.plot(x, y - bg_trace, alpha=.6, color=bg_color, label=diff_label)
    ax2.legend()


In [7]:
def visualize_component(compn):
    """
        Build a summary figure for one CNMF-E component.

        Columns 1-2: all component masks (this one highlighted), this component's mask, and the
        max projection over frames of the fiji and cnmf background videos (raw row, normalized row).
        Columns 3-4: the CNMF-E trace (C + YrA) and the traces extracted with the component's mask
        from the raw and normalized videos. Column 3 overlays the cnmf / fiji background traces;
        column 4 shows the CNMF-E trace alone, or the background-subtracted video traces.

        :param compn: int, index of the component (column of `masks`, `C` and of each `video_traces` array)
        :returns: the matplotlib Figure
    """
    f, axarr = plt.subplots(ncols=4, nrows=3, figsize=(24, 12), gridspec_kw=dict(width_ratios=[1, 1, 2, 2]))

    # Title and highlight colormap reflect the component's evaluation result
    if compn in good_components:
        f.suptitle(f"Component {compn} - GOOD")
        cmap = "Greens_r"
    elif compn in bad_components:
        f.suptitle(f"Component {compn} - BAD")
        cmap = "Reds_r"
    else:
        f.suptitle(f"Component {compn} - ??")
        cmap = "Blues_r"

    # --- Images ---
    # All masks in gray with this component drawn on top; NaN outside its mask keeps the gray
    # image visible, and vmin/vmax put the mask value (1) in the middle of the colormap
    axarr[0, 0].imshow(flat_mask, cmap="gray")
    modmask = masks[:, :, compn].copy()
    modmask[modmask == 0] = np.nan
    axarr[0, 0].imshow(modmask, cmap=cmap, vmin=0, vmax=2)

    axarr[0, 1].imshow(masks[:, :, compn])  # this spatial component

    # Background videos, shown as the max projection over frames stored in each video's `bg`
    for row, prefix in ((1, 'raw'), (2, 'normalized')):
        axarr[row, 0].imshow(videos[f'{prefix}_bg_fiji'].bg, cmap="gray")
        axarr[row, 1].imshow(videos[f'{prefix}_bg_cnmf'].bg, cmap="gray")

    # --- Traces ---
    x = np.arange(n_frames)

    # The CNMF-E trace is shown with the backgrounds of the video type the model was fitted on
    fit_prefix = 'raw' if metadata['working_on'] == "raw" else 'normalized'
    plot_trace(axarr[0, 2], axarr[0, 3], x, C[:, compn],
               video_traces[f'{fit_prefix}_bg_cnmf'][:, compn], video_traces[f'{fit_prefix}_bg_fiji'][:, compn],
               salmon, mediumseagreen, indianred,
               label=f"CNMF-E trace {metadata['working_on']}", do_subtraction=False)

    # Traces extracted from the raw and normalized videos, each with its own backgrounds
    for row, prefix, label in ((1, 'raw', "RAW trace"), (2, 'normalized', "Norm trace")):
        plot_trace(axarr[row, 2], axarr[row, 3], x, video_traces[prefix][:, compn],
                   video_traces[f'{prefix}_bg_cnmf'][:, compn], video_traces[f'{prefix}_bg_fiji'][:, compn],
                   cornflowerblue, mediumseagreen, indianred, label=label)

    # Mark the individual recordings on every trace axis.
    # NOTE(review): tiff_lengths values are used directly as x positions (frames). If the file stores
    # per-recording lengths, as its name suggests, the boundaries would be np.cumsum(tiff_lengths) - confirm.
    for ax in axarr[:, 2:].ravel():
        for start in tiff_lengths:
            ax.axvline(start, lw=2, color=[.6, .6, .6], ls="--", alpha=.8)

    # --- Decorate axes ---
    image_titles = {
        (0, 0): "All components",
        (0, 1): "Spatial component",
        (1, 0): "[RAW] background fiji",
        (1, 1): "[RAW] background cnmf",
        (2, 0): "[Normalized] background fiji",
        (2, 1): "[Normalized] background cnmf",
    }
    for (row, col), title in image_titles.items():
        axarr[row, col].set(title=title)
        axarr[row, col].axis('off')

    axarr[0, 2].set(title="CNMF-E CRaw trace", xlabel="frames", ylabel="SIGNAL")
    axarr[1, 2].set(title="[RAW] trace and background", xlabel="frames", ylabel="SIGNAL")
    axarr[1, 2].tick_params(axis="y", labelcolor=cornflowerblue)
    axarr[2, 2].set(title="[Normalized] trace and background", xlabel="frames", ylabel="SIGNAL")
    axarr[2, 2].tick_params(axis="y", labelcolor=cornflowerblue)

    # Both video rows show background-subtracted traces in the last column
    axarr[1, 3].set(title="Background subtracted traces")
    axarr[2, 3].set(title="Background subtracted traces")

    clean_axes(f)
    set_figure_subplots_aspect(left=0.01, right=0.99, top=0.9)
    return f

_ = visualize_component(0)



Save a summary image for each component

In [None]:
# save a figure for each component
if metadata['working_on'] == 'raw':
    save_fld = os.path.join(fld, "cnmfe_components_raw")
else:
    save_fld = os.path.join(fld, "cnmfe_components")
check_create_folder(save_fld)

# Interactive mode is off while saving so the figures are not displayed. It is restored in
# `finally`, and each figure is closed even if saving fails, so an error does not leave the
# notebook non-interactive or leak open figures.
plt.ioff()
try:
    for compn in tqdm(np.arange(n_components)):
        savename = os.path.join(save_fld, f"component_{compn}")
        f = visualize_component(compn)
        try:
            save_figure(f, savename, verbose=False)
        finally:
            plt.close(f)
finally:
    plt.ion()

Save an image with the masks of all components

In [8]:
# Masks of all components over the union of masks: good components in green, all others in red
f, ax = plt.subplots(figsize=(12, 12))

f.suptitle(f"Components of CNMF-E fitted on {metadata['working_on']}")

ax.imshow(flat_mask, cmap="gray")

for compn in np.arange(n_components):
    cmap = "Greens_r" if compn in good_components else "Reds_r"

    modmask = masks[:, :, compn].copy()
    modmask[modmask == 0] = np.nan  # NaN pixels are not painted, so only the mask is drawn
    ax.imshow(modmask, cmap=cmap, vmin=0, vmax=2)

ax.set(xticks=[], yticks=[])

# Same output folder as the per-component figures. Create it here as well, since this cell
# can be run without running the per-component cell first.
if metadata['working_on'] == 'raw':
    save_fld = os.path.join(fld, "cnmfe_components_raw")
else:
    save_fld = os.path.join(fld, "cnmfe_components")
check_create_folder(save_fld)
save_figure(f, os.path.join(save_fld, "all_masks"), verbose=False)



In [9]:
### Testing stuff

In [12]:
from caiman.source_extraction.cnmf.cnmf import load_CNMF

# A CNMF model saved in this session's folder (presumably fitted by Dom), loaded for comparison
doms_model_path = os.path.join(fld, "19JUN26_BF164p2_ds126_ffcSub_cnm.hdf5")
doms_model = load_CNMF(doms_model_path, n_processes=n_processes, dview=dview)

INFO:root:Changing key caiman_version in group data from 1.8.5 to 1.5.2
2020-04-24 14:44:27 PM - INFO - MainProcess params.py:962 - Changing key caiman_version in group data from 1.8.5 to 1.5.2
INFO:root:Changing key decay_time in group data from 0.4 to 2.0
2020-04-24 14:44:27 PM - INFO - MainProcess params.py:962 - Changing key decay_time in group data from 0.4 to 2.0
INFO:root:Changing key dims in group data from None to (126, 126)
2020-04-24 14:44:27 PM - INFO - MainProcess params.py:962 - Changing key dims in group data from None to (126, 126)
INFO:root:Changing key fnames in group data from None to [b'/Users/dom/Dropbox (UCL - SWC)/P']
2020-04-24 14:44:27 PM - INFO - MainProcess params.py:962 - Changing key fnames in group data from None to [b'/Users/dom/Dropbox (UCL - SWC)/P']
INFO:root:Changing key fr in group data from 30 to 10.0
2020-04-24 14:44:27 PM - INFO - MainProcess params.py:962 - Changing key fr in group data from 30 to 10.0
INFO:root:Changing key last_commit in group 

In [14]:
# Contours of each footprint in doms_model. The field-of-view dims come from the model itself
# (the same dims used to reshape its footprints) rather than from the shape of an image of another video.
doms_model.estimates.coordinates = cm.utils.visualization.get_contours(
    doms_model.estimates.A, tuple(doms_model.estimates.dims), thr=.2, thr_method="max")

In [18]:
# Spatial footprints of doms_model as a d1 x d2 x n_components array (Fortran pixel order)
dom_A = np.reshape(doms_model.estimates.A.toarray(),
                   [*doms_model.estimates.dims, -1], order='F')  # set of spatial footprints

# Binary masks, same dtype as dom_A: 1 where a footprint exceeds metadata['spatial_th'], 0 elsewhere
dom_masks = (dom_A > metadata['spatial_th']).astype(dom_A.dtype)

# Union of all masks: 1 wherever at least one component's mask is set
dom_flat_mask = dom_masks.sum(axis=2)
dom_flat_mask = (dom_flat_mask > 0).astype(dom_flat_mask.dtype)



<matplotlib.image.AxesImage at 0x17fbd105710>

In [23]:
plt.ion()

# Union of all component masks of doms_model
f, ax = plt.subplots()
ax.imshow(dom_flat_mask, cmap="gray")
ax.set(title="All component masks (doms_model)", xticks=[], yticks=[])



<matplotlib.image.AxesImage at 0x17fc1f796a0>