# Inspect CNMF-E traces

This notebook inspects the temporal component of each cell. The main goal is to compare, for each ROI, the signal extracted by CNMF-E with the background signal, and with the signal obtained by applying the same ROI to the raw motion-corrected video and to the FFT-normalized (motion-corrected) video.



In [1]:
# Enable autoreload and the Qt matplotlib backend when running under IPython.
# `run_line_magic` is the supported replacement for the deprecated
# `InteractiveShell.magic` method. Outside IPython `get_ipython` is not defined
# (NameError) and this best-effort setup is skipped.
try:
    get_ipython().run_line_magic('load_ext', 'autoreload')
    get_ipython().run_line_magic('autoreload', '2')
    get_ipython().run_line_magic('matplotlib', 'qt')
except Exception:
    pass

import matplotlib.pyplot as plt
import numpy as np
import os
from tqdm import tqdm
import napari
from collections import namedtuple
from scipy.optimize import curve_fit

import caiman as cm
from caiman.source_extraction.cnmf.spatial import threshold_components
from caiman.utils.visualization import inspect_correlation_pnr

import logging
from fancylog import fancylog
import fancylog as package

import cv2
# Best-effort: disable OpenCV's internal threading (setNumThreads(0)).
# Presumably this avoids OpenCV threads competing with the CaImAn worker
# processes started below — confirm. Only ordinary errors are swallowed, so
# KeyboardInterrupt/SystemExit still propagate.
try:
    cv2.setNumThreads(0)
except Exception:
    pass
import bokeh.plotting as bpl

from fcutils.file_io.io import load_yaml
from fcutils.plotting.utils import clean_axes, save_figure, add_colorbar_to_img
from fcutils.plotting.colors import * 
from fcutils.plotting.plot_elements import plot_shaded_withline
from fcutils.maths.utils import rolling_pearson_correlation
from fcutils.file_io.utils import check_create_folder

from utils import print_cnmfe_components, plot_components_over_image, load_fit_cnmfe
from utils import start_server, load_params,  log_cnmfe_components, load_mmap_video_caiman

# Send bokeh plots to the notebook output
bpl.output_notebook()
# Start the parallel-processing server; `dview` and `n_processes` are passed to
# load_fit_cnmfe and evaluate_components in the next cell
c, dview, n_processes = start_server()

## Load metadata, model and data

In [2]:
# Session folder for this mouse / recording day (hard-coded path)
fld = 'D:\\Dropbox (UCL - SWC)\\Project_vgatPAG\\analysis\\doric\\BF164p2\\19JUN26'
# Analysis parameters for this session
metadata = load_yaml(os.path.join(fld, "01_PARAMS", "analysis_metadata.yml"))

# Load video
# metadata['video_for_cnmfe_fit'] names another metadata key; that key's value is
# joined onto `fld` to get the video used to fit CNMF-E.
# NOTE: `video` is a file path here; a later cell rebinds the name to a namedtuple type.
video = os.path.join(fld, metadata[metadata['video_for_cnmfe_fit']])
Yr, dims, T, images = load_mmap_video_caiman(video)
# Max projection over frames (axis 0).
# NOTE(review): `bg` is not used by the code visible in this notebook — confirm before removing.
bg = np.max(images, axis=0)


# Load model
cnm, model_filepath = load_fit_cnmfe(fld, n_processes, dview)
# Run CaImAn's component evaluation on the fitting video; later cells read the
# resulting cnm.estimates.idx_components / idx_components_bad
cnm.estimates.evaluate_components(images, cnm.params, dview=dview)
cnm.estimates.f = None # Need to set it as None otherwise the background doesn't get computed

# Start logging: log files go to <fld>/02_LOGS and the CNMF-E params are recorded
fancylog.start_logging(os.path.join(fld, "02_LOGS"), package, file_log_level="INFO", variables=[cnm.params], verbose=True,    filename='cnmfe_traces_inspection_logs')
logging.info("Starting CNMF-E traces inspection")
logging.info(f"Loading CNMF-E model from: {model_filepath}")

# Add analysis metadata to log (one line per key)
logging.info("ANALYSIS METADATA FILE:")
for k,v in metadata.items():
    logging.info(f"{k}: {v}")

logging.info(f"Video used for CNMF-E fitting: {metadata['video_for_cnmfe_fit']}: {video}")


INFO:root:Starting logging
2020-04-23 14:49:53 PM - INFO - MainProcess fancylog.py:271 - Starting logging
INFO:root:Multiprocessing-logging module not found, not logging multiple processes.
2020-04-23 14:49:53 PM - INFO - MainProcess fancylog.py:273 - Multiprocessing-logging module not found, not logging multiple processes.
INFO:root:Starting CNMF-E traces inspection
2020-04-23 14:49:53 PM - INFO - MainProcess <ipython-input-2-c204b453ab8e>:17 - Starting CNMF-E traces inspection
INFO:root:Loading CNMF-E model from: D:\Dropbox (UCL - SWC)\Project_vgatPAG\analysis\doric\BF164p2\19JUN26\cnmfe_fit.hdf5
2020-04-23 14:49:53 PM - INFO - MainProcess <ipython-input-2-c204b453ab8e>:18 - Loading CNMF-E model from: D:\Dropbox (UCL - SWC)\Project_vgatPAG\analysis\doric\BF164p2\19JUN26\cnmfe_fit.hdf5
INFO:root:ANALYSIS METADATA FILE:
2020-04-23 14:49:53 PM - INFO - MainProcess <ipython-input-2-c204b453ab8e>:21 - ANALYSIS METADATA FILE:
INFO:root:experimenter: Federico
2020-04-23 14:49:53 PM - INFO -

In [3]:
# Load the raw and the normalized motion-corrected videos and, for each, the
# background estimated by the fitted CNMF-E model. Every entry of `videos` is a
# `video` namedtuple whose `bg` field is the max projection over frames (axis 0).
# NOTE: this rebinds `video` (previously the fitting-video path) to the namedtuple type.
video = namedtuple("video", "Yr dims T images bg")
vids = ['raw_pw_mc', "transf_pw_mc"]
vids_paths = load_yaml(os.path.join(fld, metadata['outputfld'], "video_paths.yml"))

videos = {}
for vname, v in zip(['raw', 'normalized'], vids):
    # Get video data (memory-mapped)
    Yr, dims, T, images = load_mmap_video_caiman(vids_paths[v])

    # Background computed by the fitted model from this video's pixels,
    # reshaped to (frames, d1, d2) using CaImAn's Fortran pixel ordering
    background = cnm.estimates.compute_background(Yr)
    background = background.T.reshape((T, )+dims, order="F")

    # Read the memmap into RAM once and reuse it for the max projection,
    # instead of reading it a second time from disk
    frames = np.array(images)

    # Store in dict (insertion order: raw, raw_background, normalized, normalized_background)
    videos[vname] = video(Yr, dims, T, frames, frames.max(axis=0))
    videos[vname+"_background"] = video(Yr, dims, T, background, background.max(axis=0))





In [4]:
# Get components from model
n_components = cnm.estimates.A.shape[1] #  both good and bad components
good_components = cnm.estimates.idx_components
bad_components = cnm.estimates.idx_components_bad

# Spatial components: in a d1 x d1 x n_components matrix
A = np.reshape(cnm.estimates.A.toarray(), list(cnm.estimates.dims)+[-1], order='F') # set of spatial footprints
centroids = cnm.estimates.center
centroid_colors = [seagreen if n in good_components else darkred
                            for n in np.arange(n_components)]

# Masks (a d1 x d1 x n_comp array with 1 only where each cell is )
masks = np.zeros_like(A)
masks[A > metadata['spatial_th']] = 1

flat_mask = masks.sum(axis=2) # all masks together
flat_mask[flat_mask > 0] = 1


# Background stuff 
W = cnm.estimates.W.toarray() # Ring model matrix
b0 = cnm.estimates.b0.reshape(cnm.estimates.dims).T   # constant baseline for each pixel

# Temporal components
n_frames = cnm.estimates.C.shape[1]
C = cnm.estimates.C.T + cnm.estimates.YrA.T# set of temporal traces

# Get lengs of indivudal recordings
tiff_lengths = np.load(os.path.join(fld, '19JUN26_BF164p2_ds126_ffcSub_tifflengths.npy'))


For each video we need to extract each component's trace. To do so, we use the component's mask to select the relevant pixels in the video and average them frame by frame. This is a bit slow, so we do it for all components at once and save the results to file; on later runs the traces can simply be loaded.

In [5]:
def get_component_trace_from_video(cnum, masks, n_frames, video):
    """
        Extract a component's trace from a video by averaging, frame by frame,
        the pixels inside the component's spatial mask.

        :param cnum: int, index of the component along the last axis of `masks`
        :param masks: d1 x d2 x n_components array. Pixels where the component's
            mask is 0 are excluded; the remaining mask values multiply the video
            pixels (with binary 0/1 masks this is a plain mean).
        :param n_frames: number of frames. Kept for backwards compatibility; the
            frames are taken from `video` itself.
        :param video: n_frames x d1 x d2 array
        :returns: 1D array with one value per frame (NaN if the mask is empty)
    """
    component_mask = np.asarray(masks[:, :, cnum])
    inside = component_mask != 0

    # Index only the masked pixels (-> frames x n_pixels). The NaN-padded
    # repeated mask and masked video that the mean used to be computed from
    # were both full n_frames x d1 x d2 copies.
    masked_pixels = np.asarray(video)[:, inside] * component_mask[inside]
    # nanmean still skips NaN pixels in the video itself
    return np.nanmean(masked_pixels, axis=1)


In [6]:
# Get the video traces for each component
video_traces = {}
for vidname, video in videos.items():
    
    savename = os.path.join(fld, vidname+"_comp_traces.npy")
    if not metadata['overwrite_traces'] and os.path.isfile(savename):
        logging.info(f"Loading traces for video {vidname} from file")
        traces = np.load(savename)
    else:
        logging.info(f"Extracting traces for video {vidname} from data")
        traces = np.zeros_like(C)
        for compn in tqdm(np.arange(n_components)):
            traces[:, compn] = get_component_trace_from_video(compn, masks, n_frames, video.images)
        np.save(savename, traces)
    
    video_traces[vidname] = traces

INFO:root:Loading traces for video raw from file
2020-04-23 14:50:16 PM - INFO - MainProcess <ipython-input-6-9fa5fd942ae3>:7 - Loading traces for video raw from file
INFO:root:Loading traces for video raw_background from file
2020-04-23 14:50:16 PM - INFO - MainProcess <ipython-input-6-9fa5fd942ae3>:7 - Loading traces for video raw_background from file
INFO:root:Loading traces for video normalized from file
2020-04-23 14:50:16 PM - INFO - MainProcess <ipython-input-6-9fa5fd942ae3>:7 - Loading traces for video normalized from file
INFO:root:Loading traces for video normalized_background from file
2020-04-23 14:50:16 PM - INFO - MainProcess <ipython-input-6-9fa5fd942ae3>:7 - Loading traces for video normalized_background from file


In [7]:
def double_exponential(x, a, b, c, d):
    """ Sum of two exponentials: a * exp(b * x) + c * exp(d * x) """
    return a * np.exp(b * x) + c * np.exp(d * x)

def remove_exponential_from_trace(x, y, p0=(1.0, -1e-3, 1.0, -1e-5),
                                  bounds=((1, -1e-1, 1, -1e-1), (300, 0, 300, 0)),
                                  maxfev=2000):
    """ Fits a double exponential to the data and subtracts it from the signal

        The fitted curve minus its own minimum is subtracted. The corrected trace
        therefore keeps the fit's minimum as its baseline instead of being centred on 0.

        :param x: np.array with time indices
        :param y: np.array with signal
        :param p0: initial guess (a, b, c, d) for double_exponential. The two
            decay rates must differ: with identical (a, b) and (c, d) the two
            terms are interchangeable and the fit collapses to a single exponential.
        :param bounds: (lower, upper) bounds on (a, b, c, d), passed to curve_fit
        :param maxfev: maximum number of function evaluations for the fit
        :returns: tuple (y_pred, fitted_doubleexp): the corrected signal and the
            fitted double exponential, both evaluated at x
        :raises RuntimeError: if curve_fit does not converge within maxfev evaluations
    """
    popt, _ = curve_fit(double_exponential, x, y,
                        maxfev=maxfev,
                        p0=p0,
                        bounds=bounds)

    fitted_doubleexp = double_exponential(x, *popt)
    y_pred = y - (fitted_doubleexp - np.min(fitted_doubleexp))
    return y_pred, fitted_doubleexp

## Visualise components

In [29]:
def visualize_component(compn):
    """
        Build a 4 x 4 diagnostic figure for one CNMF-E component.

        Columns 0-1 (images): union of all masks with this component's centroid,
        this component's spatial footprint, max projections of the raw and
        normalized videos and of their estimated backgrounds, the pixel
        baselines b0, and the normalized max projection minus b0.
        Column 2: CNMF-E trace (C + YrA), the video traces with their background
        traces, and rolling pearson correlations between the traces.
        Column 3: background-subtracted versions of the column-2 traces and
        their rolling pearson correlations.

        Reads notebook globals (A, C, b0, flat_mask, centroids, videos,
        video_traces, tiff_lengths, metadata, ...).

        :param compn: int, index of the component to show
        :returns: the matplotlib Figure
    """
    f, axarr = plt.subplots(ncols=4, nrows=4, figsize=(30, 12), gridspec_kw=dict(width_ratios=[1, 1, 2, 2]))

    if compn in good_components:
        quality = "GOOD"
    elif compn in bad_components:
        quality = "BAD"
    else:
        quality = "??"
    f.suptitle(f"Component {compn} - {quality}")

    # ------------------------------- Images ------------------------------- #
    axarr[0, 0].imshow(flat_mask, cmap="gray")  # all components
    axarr[0, 0].scatter(centroids[compn, 1], centroids[compn, 0], color=centroid_colors[compn], s=20, alpha=1)
    axarr[0, 1].imshow(A[:, :, compn], vmin=0.1)  # this spatial component

    # Each video's `.bg` is np.max over frames: these are max projections, not means
    axarr[1, 0].imshow(videos['raw'].bg, cmap="gray")
    axarr[1, 1].imshow(videos['raw_background'].bg, cmap="gray")
    axarr[2, 0].imshow(videos['normalized'].bg, cmap="gray")
    axarr[2, 1].imshow(videos['normalized_background'].bg, cmap="gray")
    axarr[3, 0].imshow(b0)
    axarr[3, 1].imshow(videos['normalized'].bg - b0)

    image_titles = {
        (0, 0): "All components",
        (0, 1): "Spatial component",
        (1, 0): "[RAW] Max projection frame",
        (1, 1): "[RAW] Max projection frame background",
        (2, 0): "[Normalized] Max projection frame",
        (2, 1): "[Normalized] Max projection frame background",
        (3, 0): "b0: pixel baselines",
        (3, 1): "Normalized max projection frame - b0",
    }
    for (row, col), title in image_titles.items():
        axarr[row, col].set(title=title, xticks=[], yticks=[])
        axarr[row, col].axis('off')

    # ------------------------------- Traces ------------------------------- #
    x = np.arange(n_frames)
    c_trace = C[:, compn]
    raw = video_traces['raw'][:, compn]
    raw_bg = video_traces['raw_background'][:, compn]
    norm = video_traces['normalized'][:, compn]
    norm_bg = video_traces['normalized_background'][:, compn]

    # Background-subtracted traces (the CNMF-E trace uses the normalized video's background)
    c_sub = c_trace - norm_bg
    raw_sub = raw - raw_bg
    norm_sub = norm - norm_bg

    plot_shaded_withline(axarr[0, 2], x, c_trace, z=np.min(c_trace), label="CNMF-E trace", alpha=0.15, color=salmon)
    axarr[0, 2].legend()

    # Video traces; the explicit (axis, video, color) pairing avoids relying on dict order.
    # (Raw traces could optionally be detrended with remove_exponential_from_trace.)
    trace_specs = [
        (axarr[1, 2], 'raw', cornflowerblue),
        (axarr[1, 2], 'raw_background', mediumseagreen),
        (axarr[2, 2], 'normalized', cornflowerblue),
        (axarr[2, 2], 'normalized_background', mediumseagreen),
    ]
    for ax, vidname, color in trace_specs:
        if 'background' in vidname:
            kwargs = dict(lw=1, alpha=.5, zorder=10)
        else:
            kwargs = dict(lw=1.5, alpha=.8, zorder=99)
        ax.plot(x, video_traces[vidname][:, compn], color=color, label=vidname, **kwargs)

    axarr[1, 2].legend(loc="upper left")
    axarr[2, 2].legend(loc="upper left")

    axarr[0, 3].plot(x, c_sub, color=salmon)
    axarr[1, 3].plot(x, raw_sub, color=cornflowerblue)
    axarr[2, 3].plot(x, norm_sub, color=cornflowerblue)

    # Mark each recording on every trace axis (columns 2 and 3).
    # NOTE(review): lines are drawn at the raw values stored in tiff_lengths; if that
    # file holds per-recording lengths rather than cumulative frame offsets, the
    # recording starts would be np.cumsum(tiff_lengths) — confirm.
    for row in range(4):
        for col in (2, 3):
            for start in tiff_lengths:
                axarr[row, col].axvline(start, lw=2, color=[.6, .6, .6], ls="--", alpha=.8)

    # Rolling pearson correlations between traces
    wnd = metadata['rolling_pearson_wnd']

    def _plot_rolling_correlations(ax, cnmfe_tr, raw_tr, norm_tr):
        # Plot the three pairwise rolling correlations on one axis
        ax.plot(rolling_pearson_correlation(cnmfe_tr, raw_tr, wnd), color=plum, lw=2, label="CNMF-E vs raw")
        ax.plot(rolling_pearson_correlation(raw_tr, norm_tr, wnd), color=darksalmon, lw=2, label="raw vs normalized")
        ax.plot(rolling_pearson_correlation(cnmfe_tr, norm_tr, wnd), color=deepskyblue, lw=2, label="CNMF-E vs normalized")
        ax.legend()

    _plot_rolling_correlations(axarr[3, 2], c_trace, raw, norm)
    _plot_rolling_correlations(axarr[3, 3], c_sub, raw_sub, norm_sub)

    # Decorate trace axes
    axarr[0, 2].set(title="CNMF-E CRaw trace", xlabel="frames", ylabel="SIGNAL")
    axarr[1, 2].set(title="[RAW] trace and background", xlabel="frames", ylabel="SIGNAL")
    axarr[1, 2].tick_params(axis="y", labelcolor=cornflowerblue)
    axarr[2, 2].set(title="[Normalized] trace and background", xlabel="frames", ylabel="SIGNAL")
    axarr[2, 2].tick_params(axis="y", labelcolor=cornflowerblue)
    axarr[3, 2].set(title="Rolling pearson correlation", ylabel="correlation", xlabel="frames", ylim=[-1, 1])

    axarr[0, 3].set(title="CNMF-E trace - normalized background", xlabel="frames", ylabel="SIGNAL")
    axarr[1, 3].set(title="[RAW] trace - background", xlabel="frames", ylabel="SIGNAL")
    axarr[2, 3].set(title="[Normalized] trace - background", xlabel="frames", ylabel="SIGNAL")
    axarr[3, 3].set(title="Rolling pearson correlation (background subtracted)", ylabel="correlation",
                    xlabel="frames", ylim=[-1, 1])

    clean_axes(f)
    f.tight_layout()

    return f

_ = visualize_component(0)



In [30]:
# save a figure for each component
save_fld = os.path.join(fld, "cnmfe_components")
check_create_folder(save_fld)

# Interactive mode is turned off so figures are not displayed while batch-saving.
# The try/finally blocks guarantee every figure is closed and interactive mode is
# restored even if plotting or saving a component raises.
plt.ioff()
try:
    for compn in tqdm(np.arange(n_components)):
        savename = os.path.join(save_fld, f"component_{compn}")
        f = visualize_component(compn)
        try:
            save_figure(f, savename, verbose=False)
        finally:
            plt.close(f)
finally:
    plt.ion()

100%|██████████| 44/44 [01:48<00:00,  2.47s/it]
