# This demo shows how to use a pre-trained blind-spot denoiser in the loop with PMD to obtain a denoised and compressed representation of your data. It assumes you already have a trained blind-spot denoising network, for example one produced by the training loop in training_demo.ipynb in the denoising notebooks folder.
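The "blind spot" idea can be illustrated without a neural network. Each value is predicted only from its neighbours and never from itself, so noise that is independent from sample to sample cannot simply be copied into the estimate. The sketch below is a toy moving-average stand-in, not the learned network loaded later in this notebook and not how masknmf implements denoising. The helper name `blind_spot_temporal_estimate` is hypothetical, and applying the blind spot along time (to match the `PMDTemporalDenoiser` name) is an assumption.

```python
import numpy as np

def blind_spot_temporal_estimate(trace: np.ndarray, half_width: int = 2) -> np.ndarray:
    """Estimate each sample of a 1D trace from its temporal neighbours only.

    Sample t is excluded from its own estimate (the "blind spot"), so noise that is
    independent from sample to sample cannot be copied into the estimate at t.
    """
    n = trace.shape[0]
    if n < 2 or half_width < 1:
        raise ValueError("need at least 2 samples and half_width >= 1")
    estimate = np.empty(n, dtype=np.float64)
    for t in range(n):
        lo, hi = max(0, t - half_width), min(n, t + half_width + 1)
        # Neighbours on both sides of t, with t itself left out.
        neighbours = np.concatenate([trace[lo:t], trace[t + 1:hi]])
        estimate[t] = neighbours.mean()
    return estimate

# Toy usage: a noisy sine wave.
rng = np.random.default_rng(0)
clean = np.sin(np.linspace(0, 4 * np.pi, 400))
noisy = clean + 0.5 * rng.standard_normal(clean.size)
denoised = blind_spot_temporal_estimate(noisy, half_width=3)
print(np.mean((noisy - clean) ** 2), np.mean((denoised - clean) ** 2))
```

Because sample t never contributes to its own estimate, changing `noisy[t]` leaves `denoised[t]` unchanged; a trained blind-spot network relies on the same property but learns the prediction instead of averaging.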

In [2]:
import fastplotlib as fpl
import masknmf
import numpy as np
import tifffile
import torch

%load_ext autoreload
%autoreload 2

# Load the trained neural network and wrap it in a PMDTemporalDenoiser object

In [2]:
trained_model = np.load("neural_net.npz", allow_pickle=True)["model"].item()  # model is stored as a pickled object; only load files you trust
curr_temporal_denoiser = masknmf.compression.PMDTemporalDenoiser(trained_model)
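The loading line above expects `neural_net.npz` to hold the model under the key `"model"` as a 0-d object array, which `.item()` unwraps. The sketch below shows one way such a file can be produced and read back with NumPy, assuming the training notebook saves the model this way. A plain dict (`model_stand_in`, hypothetical) stands in for the network; any picklable object behaves the same.

```python
import os
import tempfile

import numpy as np

# Hypothetical stand-in for a trained network; any picklable Python object works the same way.
model_stand_in = {"name": "toy_denoiser", "weights": [0.25, 0.5, 0.25]}

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "neural_net.npz")
    # Wrapping the object in a 0-d object array means np.savez pickles it on save.
    np.savez(path, model=np.array(model_stand_in, dtype=object))
    # Object arrays can only be loaded with allow_pickle=True; .item() unwraps the 0-d array.
    with np.load(path, allow_pickle=True) as archive:
        restored = archive["model"].item()

print(restored)
```

Since unpickling can execute arbitrary code, `allow_pickle=True` should only be used on files from a trusted source.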

# Load a motion-corrected dataset

In [3]:
my_data = tifffile.imread("/path/to/dataset.tiff")  # replace with the path to your data
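Before compressing, it can help to confirm that the loaded movie has the expected layout. The decomposition call below passes `my_data.shape[0]` as a frame count, which suggests a frames-first `(frames, height, width)` stack; that layout, the float32 conversion, and the helper `prepare_movie` are assumptions for illustration, not requirements stated by masknmf. A synthetic array stands in for `tifffile.imread(...)` so the sketch runs on its own.

```python
import numpy as np

def prepare_movie(movie: np.ndarray) -> np.ndarray:
    """Check that a movie is a (frames, height, width) stack and return it as float32."""
    movie = np.asarray(movie, dtype=np.float32)
    if movie.ndim != 3:
        raise ValueError(f"expected a 3D (frames, height, width) array, got shape {movie.shape}")
    if not np.all(np.isfinite(movie)):
        raise ValueError("movie contains NaN or inf values")
    return movie

# Synthetic stand-in for tifffile.imread(...): 100 frames of 64 x 64 pixels stored as uint16.
fake_movie = np.random.default_rng(0).integers(0, 4096, size=(100, 64, 64), dtype=np.uint16)
movie = prepare_movie(fake_movie)
print(movie.shape, movie.dtype)  # (100, 64, 64) float32
```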

# Run PMD compression with the blind-spot denoiser in the loop, passing the temporal denoiser object via the temporal_denoiser argument

In [1]:
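device = "cuda" if torch.cuda.is_available() else "cpu"  # fall back to CPU when no CUDA device is present
pmd_with_denoiser = masknmf.compression.pmd_decomposition(
    my_data,
    [32, 32],  # spatial block size (height, width) in pixels
    my_data.shape[0],  # frame range: here, all frames
    max_components=20,
    max_consecutive_failures=1,
    temporal_avg_factor=4,
    spatial_avg_factor=1,
    background_rank=0,
    device=device,
    pixel_weighting=None,
    frame_batch_size=1024,
    temporal_denoiser=curr_temporal_denoiser,
)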
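PMD compresses the movie block by block, keeping a small number of components in each spatial block. The sketch below illustrates only that localized low-rank idea with a plain truncated SVD per block. It omits the penalties, the temporal denoiser, the averaging factors, and everything else `pmd_decomposition` does. It assumes `[32, 32]` is the spatial block size and treats `max_rank` as loosely analogous to `max_components`; the helper `blockwise_low_rank` is hypothetical.

```python
import numpy as np

def blockwise_low_rank(movie: np.ndarray, block_size: tuple[int, int], max_rank: int) -> np.ndarray:
    """Toy localized low-rank approximation: truncated SVD inside each spatial block.

    movie has shape (frames, height, width); height and width must be multiples of the block size.
    """
    frames, height, width = movie.shape
    bh, bw = block_size
    if height % bh or width % bw:
        raise ValueError("height and width must be divisible by the block size")
    approx = np.empty_like(movie, dtype=np.float64)
    for r in range(0, height, bh):
        for c in range(0, width, bw):
            # Flatten the block into a (pixels, frames) matrix.
            block = movie[:, r:r + bh, c:c + bw].reshape(frames, bh * bw).T
            u, s, vt = np.linalg.svd(block, full_matrices=False)
            k = min(max_rank, s.size)
            # Keep only the top-k spatial/temporal component pairs of this block.
            low_rank = (u[:, :k] * s[:k]) @ vt[:k]
            approx[:, r:r + bh, c:c + bw] = low_rank.T.reshape(frames, bh, bw)
    return approx

# Toy usage: a rank-1 movie plus independent noise.
rng = np.random.default_rng(1)
frames, height, width = 200, 64, 64
spatial = rng.random((height, width))
temporal = np.sin(np.linspace(0, 8 * np.pi, frames))
clean = temporal[:, None, None] * spatial[None, :, :]
noisy = clean + 0.3 * rng.standard_normal(clean.shape)
approx = blockwise_low_rank(noisy, (32, 32), max_rank=2)
print(np.mean((noisy - clean) ** 2), np.mean((approx - clean) ** 2))
```

Keeping only a few components per block both compresses the data and discards most of the unstructured noise, which is why a low-rank representation is also a mild denoiser even before a learned denoiser is added.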

# You can use this PMD object directly for demixing (see the demos in notebooks/demixing). The most basic check of the results is to view the raw data, the PMD reconstruction, and the residual (raw data minus PMD reconstruction) side by side.

In [5]:
resid_arr = masknmf.compression.PMDResidualArray(my_data, pmd_with_denoiser)
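Besides viewing the residual, a single number and a per-pixel map can summarize it. The sketch below assumes `PMDResidualArray` represents the raw data minus the PMD reconstruction. It works on plain dense NumPy arrays and is not how masknmf computes the residual; the helper `residual_summary` is hypothetical.

```python
import numpy as np

def residual_summary(raw: np.ndarray, reconstruction: np.ndarray):
    """Return the residual movie, its relative Frobenius norm, and a per-pixel residual std map."""
    raw = np.asarray(raw, dtype=np.float64)
    residual = raw - np.asarray(reconstruction, dtype=np.float64)
    # ||raw - reconstruction|| / ||raw||, computed over all frames and pixels.
    relative_error = float(np.linalg.norm(residual) / np.linalg.norm(raw))
    # (height, width) map: bright pixels are where structure was left in the residual.
    residual_std_map = residual.std(axis=0)
    return residual, relative_error, residual_std_map

rng = np.random.default_rng(2)
raw = rng.standard_normal((50, 16, 16))
_, err_perfect, _ = residual_summary(raw, raw)
_, err_zero, std_map = residual_summary(raw, np.zeros_like(raw))
print(err_perfect, err_zero, std_map.shape)  # 0.0 1.0 (16, 16)
```

With real data, `reconstruction` would be a dense copy of the PMD output, which may not fit in memory for long recordings; computing the summary on a subset of frames is one option.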

# If you have an NVIDIA CUDA device, move the PMD object to the GPU

In [6]:
if torch.cuda.is_available():  # skipped automatically when no CUDA device is present
    pmd_with_denoiser.to("cuda")

In [9]:
iw = fpl.ImageWidget(data=[my_data, pmd_with_denoiser, resid_arr],  # raw | PMD | residual
                     figure_shape=(1, 3))
iw.cmap = "gray"
iw.show()