In [20]:
## Backend library imports
import numpy as np
import os
import sys
import time
import torch
import scipy
from scipy.signal import butter, filtfilt
from sklearn.decomposition import NMF

## Plotting library imports
import matplotlib.pyplot as plt
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import fastplotlib as fpl

## Demixing imports
import localnmf 
import masknmf
from masknmf.visualization import make_demixing_video, plot_ith_roi, construct_index
from matplotlib.gridspec import GridSpec

import plotly.graph_objects as go
import plotly.subplots as sp

import os
import re
import matplotlib.pyplot as plt

%matplotlib inline
%load_ext autoreload
%load_ext line_profiler

# Choose the device the demixing runs on: `cuda` (GPU) or `cpu`

In [29]:
# Prefer the GPU, but fall back to the CPU when CUDA is not available so the
# notebook still runs end to end on machines without a usable GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load the high-pass filtered PMD data and convert it to a localnmf SignalDemixer object

In [30]:
filename = "/path/to/high_pass_filtered/pmd_decomposition.npz"
# NOTE(security): allow_pickle=True unpickles arbitrary Python objects stored in
# the archive, so only load files that come from a trusted source.
# The context manager closes the underlying .npz file handle as soon as the
# 'pmd' entry has been read, instead of leaving it open for the kernel's lifetime.
with np.load(filename, allow_pickle=True) as pmd_archive:
    full_pmd_movie = pmd_archive['pmd'].item()

In [31]:
# Build the SignalDemixer from the factors of the high-pass filtered PMD object.
# The PMD shape is unpacked as (num_frames, fov_dim1, fov_dim2), while the
# demixer is handed the dimensions in the order (fov_dim1, fov_dim2, num_frames).
num_frames, fov_dim1, fov_dim2 = full_pmd_movie.shape
highpass_demixer_shape = (fov_dim1, fov_dim2, num_frames)
highpass_pmd_demixer = localnmf.SignalDemixer(
    full_pmd_movie.u,
    full_pmd_movie.r,
    full_pmd_movie.s,
    full_pmd_movie.v,
    highpass_demixer_shape,
    data_order=full_pmd_movie.order,
    device=device,
)

# Run 1st pass demixing on this data

# Identify remaining signals here

In [19]:
# Settings for the first initialization pass on the high-pass filtered data.
init_kwargs = dict(
    # Worth modifying
    mad_correlation_threshold=0.6,
    min_superpixel_size=3,
    # Mostly stable
    mad_threshold=2,
    residual_threshold=0.3,
    patch_size=(40, 40),
    robust_corr_term=0.03,
    plot_en=True,
    text=False,
)

highpass_pmd_demixer.initialize_signals(**init_kwargs, is_custom=False)

# The first entry of the initialization results is indexed for the count;
# its second axis is reported as the number of identified signals.
num_initialized = highpass_pmd_demixer.results[0].shape[1]
print(f"Identified {num_initialized} neurons here")

# Lock the above results and move to demixing

In [18]:
# Freeze the initialization results above and advance the demixer to its next
# stage (presumably the demixing state -- confirm against localnmf).
highpass_pmd_demixer.lock_results_and_continue()

# Demix the data

In [17]:
## First demixing pass on the high-pass filtered data
num_iters = 13
localnmf_params = dict(
    maxiter=num_iters,
    # One support threshold per iteration, stepping from 0.9 down to 0.6
    support_threshold=np.linspace(0.9, 0.6, num_iters).tolist(),
    deletion_threshold=0.2,
    ring_model_start_pt=30,  # No ring model needed
    ring_radius=20,
    merge_threshold=0.8,
    merge_overlap_threshold=0.8,
    update_frequency=40,
    c_nonneg=False,
    denoise=False,
    plot_en=False,
)

# Time the demixing call; gradients are not needed, so autograd is switched off.
start_time = time.time()
with torch.no_grad():
    highpass_pmd_demixer.demix(**localnmf_params)
elapsed = time.time() - start_time
print(f"that took {elapsed}")

num_signals = highpass_pmd_demixer.results.a.shape[1]
print(f"after this step {num_signals} signals identified")



# Visualize a demixing video. Things to check: are any neural signals missing from the demixing video? If so, that is fine -- you can run a second pass of demixing on the "residual" to identify those signals.

In [9]:
# Build the demixing video widget from the current high-pass demixing results.
results = highpass_pmd_demixer.results
iw = make_demixing_video(results, device, v_range=[-1, 1])

# Last expression of the cell: displays the widget.
iw.show()

# Lock these results in and continue

In [16]:
# Freeze the first-pass demixing results so that a further initialization pass
# can build on them (see the optional multipass section that follows).
highpass_pmd_demixer.lock_results_and_continue()

# [Optional]: Multipass -- if you think there are residual signals, you can run initialization again followed by demixing. This multi-pass approach combines the previous results with the new initializations and uses this "superset" of signals to seed the new round of demixing. Only run the below if you believe there are missing signals based on the above. 

In [15]:
# Settings for the optional second initialization pass on the high-pass data.
init_kwargs = dict(
    # In the second pass, these params are smaller to pick up "smaller" signals
    # (mad_correlation_threshold drops from 0.6 in the first pass to 0.3 here).
    mad_correlation_threshold=0.3,
    min_superpixel_size=3,
    # Mostly stable
    mad_threshold=2,
    residual_threshold=0.3,
    patch_size=(40, 40),
    robust_corr_term=0.03,
    plot_en=True,
    text=False,
)

highpass_pmd_demixer.initialize_signals(**init_kwargs, is_custom=False)

# The first entry of the initialization results is indexed for the count.
num_initialized = highpass_pmd_demixer.results[0].shape[1]
print(f"Identified {num_initialized} neural signals")

# [Optional]: Lock the above results and move to demixing

In [14]:
# Freeze the second-pass initialization and advance to the next stage
# (presumably the demixing state -- confirm against localnmf).
highpass_pmd_demixer.lock_results_and_continue()

# [Optional]: Run demixing

In [1]:
## Optional second demixing pass on the high-pass filtered data
num_iters = 13
localnmf_params = dict(
    maxiter=num_iters,
    # Constant support threshold of 0.6, one entry per iteration
    support_threshold=np.linspace(0.6, 0.6, num_iters).tolist(),
    deletion_threshold=0.2,
    # NOTE(review): this line was originally commented "No ring model needed",
    # but 3 is below maxiter (13), unlike the first pass which uses 30 --
    # confirm whether the ring model is meant to be active in this pass.
    ring_model_start_pt=3,
    ring_radius=20,
    merge_threshold=0.4,
    merge_overlap_threshold=0.4,
    update_frequency=4,
    c_nonneg=False,
    denoise=False,
    plot_en=True,
)

# Time the demixing call; gradients are not needed, so autograd is switched off.
start_time = time.time()
with torch.no_grad():
    highpass_pmd_demixer.demix(**localnmf_params)
elapsed = time.time() - start_time
print(f"that took {elapsed}")

num_signals = highpass_pmd_demixer.results.a.shape[1]
print(f"Number of neurons after demixing is {num_signals}")

In [13]:
# Build the demixing video widget from the latest high-pass demixing results.
results = highpass_pmd_demixer.results
iw = make_demixing_video(results, device, v_range=[-1, 1])

# Last expression of the cell: displays the widget.
iw.show()

# Part 2: Take the above spatial signals and regress the (unfiltered) PMD data onto them. Note: "support expansion" and merging are both disabled here (we do all of that on the filtered data). At this stage we can just do alternating least squares updates to estimate A and C.

# Load the (unfiltered) PMD data and regress it directly onto the above spatial profiles

In [41]:
filename = "/path/to/pmd_decomposition.npz"
# NOTE(security): allow_pickle=True unpickles arbitrary Python objects stored in
# the archive, so only load files that come from a trusted source.
# The context manager closes the underlying .npz file handle as soon as the
# 'pmd' entry has been read, instead of leaving it open for the kernel's lifetime.
# This rebinds full_pmd_movie from the high-pass filtered to the unfiltered data.
with np.load(filename, allow_pickle=True) as pmd_archive:
    full_pmd_movie = pmd_archive['pmd'].item()

In [42]:
# Build the SignalDemixer from the factors of the unfiltered PMD object.
# The PMD shape is unpacked as (num_frames, fov_dim1, fov_dim2), while the
# demixer is handed the dimensions in the order (fov_dim1, fov_dim2, num_frames).
num_frames, fov_dim1, fov_dim2 = full_pmd_movie.shape
unfiltered_demixer_shape = (fov_dim1, fov_dim2, num_frames)
unfiltered_pmd_demixer = localnmf.SignalDemixer(
    full_pmd_movie.u,
    full_pmd_movie.r,
    full_pmd_movie.s,
    full_pmd_movie.v,
    unfiltered_demixer_shape,
    data_order=full_pmd_movie.order,
    device=device,
)

# Initialize the signals using the "custom" option, where we provide pre-computed spatial footprints

In [43]:
# Seed the unfiltered demixer with the spatial footprints found on the
# high-pass filtered data, then lock that initialization in.
highpass_footprints = highpass_pmd_demixer.results.a
unfiltered_pmd_demixer.initialize_signals(
    is_custom=True,
    spatial_footprints=highpass_footprints,
)
unfiltered_pmd_demixer.lock_results_and_continue()

Now in demixing state


# Run demixing, with no support updates

In [9]:
## Regress the unfiltered data onto the footprints, with no support updates
num_iters = 18
localnmf_params = {
    'maxiter': num_iters,
    'support_threshold': np.linspace(0.8, 0.5, num_iters).tolist(),
    'deletion_threshold': 0.2,
    # Ring model is wanted on the unfiltered data; the start point is below
    # maxiter (presumably the iteration at which it switches on -- confirm).
    'ring_model_start_pt': 2,
    'ring_radius': 20,
    'merge_threshold': 0.8,
    'merge_overlap_threshold': 0.8,
    # No support updates: this section is documented as running without them,
    # but the previous value (4) is smaller than maxiter. Use the same
    # "num_iters + 1" idiom as the final regression cell so the update
    # interval is never reached.
    'update_frequency': num_iters + 1,
    'c_nonneg': True,
    'denoise': False,
    'plot_en': False
}

# Time the demixing call; gradients are not needed, so autograd is switched off.
start_time = time.time()
with torch.no_grad():
    unfiltered_pmd_demixer.demix(**localnmf_params)
print(f"that took {time.time() - start_time}")
print(f"Identified {unfiltered_pmd_demixer.results.a.shape[1]} neurons")




# Visualize demixing results

In [10]:
# Build the demixing video widget from the unfiltered-data demixing results.
results = unfiltered_pmd_demixer.results
iw = make_demixing_video(results, device, v_range=[-1, 1])

# Last expression of the cell: displays the widget.
iw.show()

In [12]:
# Freeze the unfiltered demixing results so another initialization pass can
# be run on top of them in the next cell.
unfiltered_pmd_demixer.lock_results_and_continue()

In [2]:
# Settings for an additional initialization pass on the unfiltered data.
init_kwargs = dict(
    # Same correlation threshold as the first high-pass initialization (0.6);
    # of these settings only mad_threshold below is lower than in earlier passes.
    mad_correlation_threshold=0.6,
    min_superpixel_size=3,
    # Mostly stable (mad_threshold is 1 here versus 2 in the earlier passes)
    mad_threshold=1,
    residual_threshold=0.3,
    patch_size=(40, 40),
    robust_corr_term=0.03,
    plot_en=True,
    text=False,
)

unfiltered_pmd_demixer.initialize_signals(**init_kwargs, is_custom=False)

In [3]:
# Freeze the initialization on the unfiltered data and advance to the next
# stage (presumably the demixing state -- confirm against localnmf).
unfiltered_pmd_demixer.lock_results_and_continue()

In [4]:
## Demix the unfiltered data again after the extra initialization pass
localnmf_params = dict(
    maxiter=25,
    # One support threshold per iteration (25 entries, from 0.7 down to 0.6)
    support_threshold=np.linspace(0.7, 0.6, 25).tolist(),
    deletion_threshold=0.2,
    ring_model_start_pt=28,  # No ring model needed
    ring_radius=20,
    merge_threshold=0.8,
    merge_overlap_threshold=0.8,
    update_frequency=40,  # No support updates
    c_nonneg=False,
    denoise=False,
    plot_en=False,
)

# Time the demixing call; gradients are not needed, so autograd is switched off.
start_time = time.time()
with torch.no_grad():
    unfiltered_pmd_demixer.demix(**localnmf_params)
elapsed = time.time() - start_time
print(f"that took {elapsed}")

num_signals = unfiltered_pmd_demixer.results.a.shape[1]
print(f"Identified {num_signals} neurons")




# Final Step: regress the data onto the fixed spatial footprints without any background term

In [56]:
# Fresh SignalDemixer over the same PMD factors, used for the final regression.
# Reuses fov_dim1, fov_dim2 and num_frames computed when the PMD data was loaded.
final_demixer_shape = (fov_dim1, fov_dim2, num_frames)
final_pmd_demixer = localnmf.SignalDemixer(
    full_pmd_movie.u,
    full_pmd_movie.r,
    full_pmd_movie.s,
    full_pmd_movie.v,
    final_demixer_shape,
    data_order=full_pmd_movie.order,
    device=device,
)

In [5]:
# Seed the final demixer with the spatial footprints estimated on the
# unfiltered data, then lock that initialization in.
unfiltered_footprints = unfiltered_pmd_demixer.results.a
final_pmd_demixer.initialize_signals(
    is_custom=True,
    spatial_footprints=unfiltered_footprints,
)
final_pmd_demixer.lock_results_and_continue()

In [6]:
## Final regression onto the fixed footprints
num_iters = 25
# Any value past the last iteration; used below to switch features off.
never_reached = num_iters + 1
localnmf_params = dict(
    maxiter=num_iters,
    support_threshold=np.linspace(0.8, 0.5, num_iters).tolist(),
    deletion_threshold=0.2,
    ring_model_start_pt=never_reached,  # No ring model needed
    ring_radius=20,
    merge_threshold=0.8,
    merge_overlap_threshold=0.8,
    update_frequency=never_reached,  # No support updates
    c_nonneg=False,
    denoise=False,
    plot_en=False,
)

# Time the demixing call; gradients are not needed, so autograd is switched off.
start_time = time.time()
with torch.no_grad():
    final_pmd_demixer.demix(**localnmf_params)
elapsed = time.time() - start_time
print(f"that took {elapsed}")

num_signals = final_pmd_demixer.results.a.shape[1]
print(f"Identified {num_signals} neurons")




# Part 3: Visualize + export results

# Visualize final demixing results as a demixing video

In [7]:
# Show the output of the final regression. This cell previously read
# unfiltered_pmd_demixer.results, which left the final_pmd_demixer fit unused
# even though this section visualizes the *final* demixing results.
results = final_pmd_demixer.results
iw = make_demixing_video(results,
                         device,
                         v_range=[-1, 1])

# Last expression of the cell: displays the widget.
iw.show()

# Save a plot of each ROI into a folder

In [8]:
## Specify which folder things get saved to: 
folder = 'Path_To_Save_Data'

# Export the final regression output. This cell previously read
# unfiltered_pmd_demixer.results, leaving the final_pmd_demixer fit unused.
# Fetch the results before touching the filesystem so that a failure here does
# not leave behind an empty folder that would block the next run.
results = final_pmd_demixer.results

# Refuse to write into an existing folder so earlier exports are never mixed
# with (or overwritten by) this one.
if os.path.exists(folder):
    raise ValueError(f"folder {folder} already exists. delete it or pick different folder name")
os.mkdir(folder)

# One HTML file per ROI (one per column of results.a), named neuron_<index>.html.
for i in range(results.a.shape[1]):
    name = f"neuron_{i}.html"
    plot_ith_roi(i, results, folder=folder, name=name, radius=30)

# Build the index for the exported folder once every ROI file has been written.
construct_index(folder)