# This notebook shows the current pipeline for training a blind-spot denoising network for use with PMD compression

In [7]:
import os

import fastplotlib as fpl
import sys
import masknmf
import tifffile
import torch
import numpy as np

import matplotlib.pyplot as plt
import time
from typing import *
import pathlib
from pathlib import Path
%load_ext line_profiler
%matplotlib inline
%load_ext autoreload

# First step: take one dataset and run PMD on it without any temporal denoiser. The temporal basis from this result (`pmd_obj.v`) serves as the training data for the network. If you plan to acquire multiple datasets with the same imaging parameters, you should be able to run this training loop on one dataset and then apply the trained neural network denoiser together with PMD to all subsequent datasets.

In [2]:
# Path to the raw imaging movie; replace the placeholder with your own file.
DATASET_PATH = "/path/to/dataset.tiff"
# The PMD cell below reads this array as `my_data` (binding only `data` raised a
# NameError on a fresh kernel); `data` is kept as an alias for backward compatibility.
my_data = tifffile.imread(DATASET_PATH)
data = my_data

In [8]:
# Run PMD without a temporal denoiser; the resulting temporal basis (pmd_obj.v) is
# the training set for the blind-spot network below.
# Use the GPU when one is available and fall back to CPU otherwise, instead of
# failing outright on machines without CUDA (a CPU run is likely much slower).
pmd_device = "cuda" if torch.cuda.is_available() else "cpu"
pmd_obj = masknmf.compression.pmd_decomposition(
    my_data,
    [32, 32],          # presumably the spatial block size in pixels — TODO confirm in masknmf docs
    my_data.shape[0],  # assumes axis 0 of the movie is time, i.e. use every frame — TODO confirm
    max_components=20,
    max_consecutive_failures=1,
    temporal_avg_factor=1,
    spatial_avg_factor=1,
    background_rank=0,
    device=pmd_device,
    pixel_weighting=None,
    frame_batch_size=1024,
)

# Train a denoiser on the temporal basis (the v components)

In [9]:
# Fit the blind-spot denoiser to the PMD temporal basis, moved to host memory first.
# Only the model is kept; the trainer's second return value is discarded.
trained_model, _ = masknmf.compression.denoising.train_total_variance_denoiser(
    pmd_obj.v.cpu(),
    max_epochs=5,
    batch_size=128,
    learning_rate=1e-4,
)

# np.savez wraps the model object in a pickled object array, which is why reading
# it back requires allow_pickle=True and .item().
np.savez("neural_net.npz", model=trained_model)

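# To see what the blind-spot network does, run it on the training data and plot the results. The next cell reloads the model saved above from `neural_net.npz`, so training does not have to be repeated; the plotting cell still needs `pmd_obj` from the PMD step.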

In [4]:
# Reload the denoiser saved by the training cell. The model was stored as a pickled
# object inside the .npz archive, hence allow_pickle=True and .item().
# SECURITY: unpickling can execute arbitrary code — only load files you created yourself.
# The context manager closes the archive's file handle (np.load on .npz returns an NpzFile).
with np.load("neural_net.npz", allow_pickle=True) as model_archive:
    trained_model = model_archive["model"].item()

In [6]:
# noise_variance_quantile = 1 is a way of saying "do not mix the held-out observed data
# with the neural net predictions; just use the network".
noise_variance_quantile = 1

# Filename prefix for each figure; figures are saved as trace_0.png, trace_1.png, ...
out_name = "trace"
# Point this at your own directory if desired, e.g. "/your_output/folder".
out_folder = "testing_plots"
# Make sure the output folder exists before any figure is written into it.
Path(out_folder).mkdir(parents=True, exist_ok=True)
traces_to_check = 3
# Ignore the first few/last few frames of every series when plotting.
t1, t2 = 10, -10

# Copy the temporal basis to host memory once instead of on every loop iteration.
temporal_basis = pmd_obj.v.cpu()
# Each row of the basis is one trace; never request more traces than exist.
n_traces = min(traces_to_check, temporal_basis.shape[0])
for k in range(n_traces):
    # Keep a leading singleton dimension (a batch of one trace) via list indexing.
    trace_to_denoise = temporal_basis[[k], :]
    (denoised_traces, signal_mean, noise_variance,
     signal_weight, observation_weight, total_var) = masknmf.compression.denoising.denoise_batched(
        trained_model,
        trace_to_denoise,
        noise_variance_quantile=noise_variance_quantile,
        input_size=temporal_basis.shape[1],
    )
    masknmf.visualization.pmd_temporal_denoiser_trace_plot(
        trace_to_denoise.squeeze()[t1:t2],
        signal_mean.squeeze()[t1:t2],
        denoised_traces.squeeze()[t1:t2],
        signal_weight.squeeze()[t1:t2],
        observation_weight.squeeze()[t1:t2],
        total_var.squeeze()[t1:t2],
        noise_variance,
        out_name=f"{out_name}_{k}.png",
        out_folder=out_folder,
    )