In [1]:
import os
import sys
import masknmf
import tifffile
import torch
import numpy as np
import matplotlib.pyplot as plt
import time
from typing import *
import pathlib
from masknmf import display
import pathlib
from pathlib import Path

import fastplotlib as fpl


class MotionBinDataset:
    """Read-only, memory-mapped view of a suite2p ``data.bin`` registration file."""

    def __init__(self, 
                 data_path: Union[str, pathlib.Path],
                 metadata_path: Union[str, pathlib.Path]):
        """
        Open a suite2p ``data.bin`` file using the frame dimensions stored in its ops file.

        Parameters
        ----------
        data_path (str, pathlib.Path): Path to the raw ``data.bin`` file (read as int16).
        metadata_path (str, pathlib.Path): Path to the suite2p ``ops.npy`` file that
            provides 'nframes', 'Ly' and 'Lx'.

        Raises
        ------
        FileNotFoundError
            If ``metadata_path`` does not exist.
        """
        self.bin_path = Path(data_path)
        self.ops_path = Path(metadata_path)
        self._dtype = np.int16
        self._shape = self._compute_shape()
        # Read-only memmap: frames are only pulled from disk when indexed.
        self.data = np.memmap(self.bin_path, mode='r', dtype=self.dtype, shape=self.shape)

    @property
    def dtype(self) -> np.dtype:
        """Element type of the binary file (int16)."""
        return self._dtype

    @property
    def shape(self):
        """
        Shape of the dataset in the form (T, d1, d2), where T is the number of frames
        and d1, d2 are the field-of-view dimensions.

        Returns
        -------
        (int, int, int)
            The number of frames, number of y pixels, number of x pixels.
        """
        return self._shape

    @property
    def ndim(self):
        """Number of dimensions (always 3: frames, y, x)."""
        return len(self.shape)

    def _compute_shape(self):
        """
        Load the suite2p ops file to retrieve the dimensions of the data.bin file.

        Returns
        -------
        (int, int, int)
            Number of frames, number of y pixels, number of x pixels.

        Raises
        ------
        FileNotFoundError
            If the ops file does not exist (previously this surfaced as an
            UnboundLocalError on ``ops``).
        """
        ops_file = self.ops_path
        if not ops_file.exists():
            raise FileNotFoundError(f"suite2p ops file not found: {ops_file}")
        # The ops object is stored pickled and unwrapped with .item(), so allow_pickle
        # is required. Unpickling can execute code: only load trusted ops files.
        ops = np.load(ops_file, allow_pickle=True).item()
        return int(ops['nframes']), int(ops['Ly']), int(ops['Lx'])

    def __getitem__(self, item: Union[int, list, np.ndarray, Tuple[Union[int, np.ndarray, slice, range]]]):
        """Index the memmap and return an in-memory copy (detached from the file)."""
        return self.data[item].copy()



%load_ext autoreload
%matplotlib inline

In [3]:
# allow_pickle is needed because entries are unwrapped later with .item()
# (pickled objects). NOTE(review): unpickling can execute code, so only load
# result files you produced yourself.
results = np.load(
    './results/SP044_plane11.npz',
    allow_pickle=True,
)

In [10]:
# NOTE(review): make_pmd_widget is defined in a later cell, so this call only
# works if that cell was executed first; it raises NameError under
# Restart & Run All. Move this cell below the function definitions.
make_pmd_widget(
    results,
    denoise=True,
).show()

In [2]:
def make_pmd_widget(results, denoise=False, device='cuda'):
    """
    Build an interactive PMDWidget comparing the raw suite2p movie with a PMD decomposition.

    Parameters
    ----------
    results : mapping (e.g. the NpzFile loaded with np.load)
        Must contain object entries 'pmd_denoise', 'pmd_no_denoise' and 'raw_path',
        each unwrapped with ``.item()``. 'raw_path' is the directory holding
        suite2p's ``data.bin`` and ``ops.npy``.
    denoise : bool
        If truthy, show the 'pmd_denoise' decomposition; otherwise 'pmd_no_denoise'.
    device : str
        Torch device the PMD object is moved to and the widget runs on (default 'cuda').

    Returns
    -------
    masknmf.visualization.interactive_guis.PMDWidget
    """
    # Use truthiness rather than `is False`, so that falsy values such as None or 0
    # select the non-denoised decomposition, matching the default.
    pmd_key = 'pmd_denoise' if denoise else 'pmd_no_denoise'
    pmd_obj = results[pmd_key].item()
    pmd_obj.rescale = False
    pmd_obj.to(device)

    raw_path = results['raw_path'].item()
    raw_data = MotionBinDataset(os.path.join(raw_path, "data.bin"),
                                os.path.join(raw_path, "ops.npy"))

    return masknmf.visualization.interactive_guis.PMDWidget(raw_data,
                                                            pmd_obj,
                                                            device=device)
    
def visualize_pmd_results(results, device='cuda', batch_size=200):
    """
    Build a 4x3 image grid comparing raw data, PMD, and residuals, with and without NN denoising.

    Panels (listed in order, three per row):
      1. mean-subtracted raw | PMD (no denoise)        | raw - PMD
      2. mean-subtracted raw | PMD+NN (denoise)        | raw - (PMD+NN)
      3. raw / PMD / residual autocovariance images ("L1acf"), no denoise
      4. raw / PMD+NN / residual autocovariance images, denoise

    Parameters
    ----------
    results : mapping (e.g. the NpzFile loaded with np.load)
        Must contain object entries 'pmd_denoise', 'pmd_no_denoise' and 'raw_path'
        ('raw_path' is the directory holding suite2p's data.bin and ops.npy).
    device : str
        Torch device for the PMD objects and the autocovariance diagnostics (default 'cuda').
    batch_size : int
        Batch size passed to ``pmd_autocovariance_diagnostics`` (default 200).

    Returns
    -------
    fpl.ImageWidget
        The widget, with a gray colormap.
    """
    pmd_denoise = results['pmd_denoise'].item()
    pmd_denoise.rescale = False
    pmd_denoise.to(device)

    pmd_no_denoise = results['pmd_no_denoise'].item()
    pmd_no_denoise.rescale = False
    pmd_no_denoise.to(device)

    raw_path = results['raw_path'].item()
    my_data = MotionBinDataset(os.path.join(raw_path, "data.bin"),
                               os.path.join(raw_path, "ops.npy"))

    # Copy the mean image to host memory once, rather than on every batch the
    # FilteredArray requests. NOTE(review): the no-denoise PMD mean is used for both
    # rows; presumably both decompositions share the same mean, but confirm this.
    mean_img = pmd_no_denoise.mean_img.cpu().numpy()[None, :, :]
    raw_stack_meansub = masknmf.FilteredArray(my_data,
                                              lambda x: x - mean_img,
                                              device='cpu')

    resid_denoise = masknmf.PMDResidualArray(my_data, pmd_denoise)
    resid_no_denoise = masknmf.PMDResidualArray(my_data, pmd_no_denoise)

    # Both calls also return the raw-data autocovariance. Only the copy from the
    # last call is kept; the first is discarded.
    acf_diagnostics = masknmf.diagnostics.pmd_autocovariance_diagnostics
    _, pmd_denoise_acf, resid_denoise_acf = acf_diagnostics(
        my_data, pmd_denoise, batch_size=batch_size, device=device)
    raw_acf, pmd_no_denoise_acf, resid_no_denoise_acf = acf_diagnostics(
        my_data, pmd_no_denoise, batch_size=batch_size, device=device)

    # Keep each array next to its title so the data and names lists cannot drift apart.
    panels = [
        (raw_stack_meansub, 'raw mean0'),
        (pmd_no_denoise, 'pmd'),
        (resid_no_denoise, 'raw - pmd'),
        (raw_stack_meansub, 'raw mean0'),
        (pmd_denoise, 'pmd+nn'),
        (resid_denoise, 'raw - (pmd+nn)'),
        (raw_acf, 'raw L1acf'),
        (pmd_no_denoise_acf, 'pmd L1acf'),
        (resid_no_denoise_acf, 'resid L1acf'),
        (raw_acf, 'raw L1acf'),
        (pmd_denoise_acf, 'pmd+nn L1acf'),
        (resid_denoise_acf, 'resid L1acf'),
    ]
    iw = fpl.ImageWidget(data=[panel for panel, _title in panels],
                         names=[title for _panel, title in panels],
                         figure_shape=(4, 3))
    iw.cmap = "gray"
    return iw


def visualize_denoising_diff(results, device='cuda'):
    """
    Show PMD without NN denoising, PMD with NN denoising, and their difference side by side.

    Both PMD objects are shown with ``rescale = True``.

    Parameters
    ----------
    results : mapping (e.g. the NpzFile loaded with np.load)
        Must contain object entries 'pmd_denoise' and 'pmd_no_denoise'.
    device : str
        Torch device the PMD objects are moved to (default 'cuda').

    Returns
    -------
    fpl.ImageWidget
        A 1x3 widget ('pmd', 'pmd+nn', 'diff'), with a gray colormap.
    """
    pmd_denoise = results['pmd_denoise'].item()
    pmd_denoise.rescale = True
    pmd_denoise.to(device)

    pmd_no_denoise = results['pmd_no_denoise'].item()
    pmd_no_denoise.rescale = True
    pmd_no_denoise.to(device)

    # NOTE(review): elsewhere PMDResidualArray(a, b) is labelled "a - b", so this is
    # presumably pmd - (pmd+nn), i.e. what the NN denoiser removed. Confirm.
    nn_difference_array = masknmf.PMDResidualArray(pmd_no_denoise, pmd_denoise)

    iw = fpl.ImageWidget(data=[pmd_no_denoise, pmd_denoise, nn_difference_array],
                         names=['pmd', 'pmd+nn', 'diff'],
                         figure_shape=(1, 3))
    iw.cmap = "gray"
    return iw
        

In [1]:
# Use a distinct name: `iw` is rebound to a different widget in a later cell,
# which would otherwise drop the only handle to this one.
denoise_diff_widget = visualize_denoising_diff(results)
denoise_diff_widget.show()

In [3]:
# Build the raw / PMD / residual / autocovariance comparison grid and render it.
iw = visualize_pmd_results(results=results)
iw.show()