# Visualize Dataset

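This notebook uses [napari](https://napari.org), a multi-dimensional image viewer for Python, to browse a dataset sample by sample: each input is shown as an image layer and its target as a labels layer. In the viewer, press `n` to go to the next sample and `b` to go to the previous one.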

In [2]:
import numpy as np
import napari
import import_ipynb
from transformations import re_normalize


def enable_gui_qt():
    """Enable the Qt GUI event loop in the running IPython session.

    Equivalent to running the line magic ``%gui qt`` in a notebook cell, so
    that Qt-based windows (such as napari's) can be shown from the kernel.

    Raises:
        RuntimeError: if called outside an IPython/Jupyter session.
    """
    from IPython import get_ipython

    ipython = get_ipython()
    if ipython is None:
        # get_ipython() returns None when not running under IPython
        raise RuntimeError("enable_gui_qt() must be called from an IPython/Jupyter session")
    # InteractiveShell.magic() is deprecated (IPython >= 8); run_line_magic is the supported API
    ipython.run_line_magic('gui', 'qt')
    

class DatasetViewer:
    """Browse the (input, target) samples of a dataset in a napari viewer.

    The input of the current sample is shown as an image layer and its target
    as a labels layer. In the viewer, press ``n`` for the next sample and
    ``b`` for the previous one; both wrap around at the ends of the dataset.

    The dataset must support ``len(dataset)``, ``dataset[index]`` returning an
    ``(x, y)`` pair, and ``dataset.inputs`` / ``dataset.targets`` sequences
    holding the file path (``str`` or ``os.PathLike``) of each sample.
    """

    def __init__(self, dataset):
        self.dataset = dataset
        self.index = 0

        # napari viewer instance (created lazily by `napari()`)
        self.viewer = None

        # current image and label layer shown in the viewer
        self.image_layer = None
        self.label_layer = None

    def napari(self):
        """Open a fresh napari viewer at the first sample and bind navigation keys.

        Calling this again closes the previous viewer and restarts at index 0.
        """
        # IPython magic: start the Qt event loop (%gui qt)
        enable_gui_qt()

        # Close any previous viewer; `del` alone would leave its window open.
        if self.viewer is not None:
            try:
                self.viewer.close()
            except (RuntimeError, AttributeError):
                # presumably the window was already closed by the user
                # (Qt objects deleted) - nothing left to clean up
                pass
        self.index = 0
        # layers of the old viewer cannot be reused in a new one
        self.image_layer = None
        self.label_layer = None

        # init napari instance
        self.viewer = napari.Viewer()

        # show current sample
        self.show_sample()

        # key-bindings
        # press n to go to next sample
        @self.viewer.bind_key('n')
        def next_sample(viewer):
            self.increase_index()
            self.show_sample()

        # press b to go to previous sample
        @self.viewer.bind_key('b')
        def prev_sample(viewer):
            self.decrease_index()
            self.show_sample()

    def increase_index(self):
        """Advance to the next sample, wrapping around to 0 after the last one."""
        self.index += 1
        if self.index >= len(self.dataset):
            self.index = 0

    def decrease_index(self):
        """Go back to the previous sample, wrapping around to the last one."""
        self.index -= 1
        if self.index < 0:
            self.index = len(self.dataset) - 1

    def show_sample(self):
        """Display the sample at `self.index`, creating or updating the layers."""
        # get a sample from the dataset
        x, y = self.get_sample_dataset(self.index)

        # get the file names (last path component) from the dataset
        x_path, y_path = self.get_names_dataset(self.index)
        x_name, y_name = self._file_name(x_path), self._file_name(y_path)

        # transform the sample to numpy, cpu and correct format to visualize
        x = self.transform_x(x)
        y = self.transform_y(y)

        # create or update image layer (recreate it if the user deleted it)
        if self.image_layer is None or self.image_layer not in self.viewer.layers:
            self.image_layer = self.create_image_layer(x, x_name)
        else:
            self.update_image_layer(self.image_layer, x, x_name)

        # create or update label layer (recreate it if the user deleted it)
        if self.label_layer is None or self.label_layer not in self.viewer.layers:
            self.label_layer = self.create_label_layer(y, y_name)
        else:
            self.update_label_layer(self.label_layer, y, y_name)

        # reset view
        self.viewer.reset_view()

    def create_image_layer(self, x, x_name):
        """Add `x` as a new image layer named `x_name` and return the layer."""
        return self.viewer.add_image(x, name=str(x_name))

    def update_image_layer(self, image_layer, x, x_name):
        """Replace the data and name of an existing image layer."""
        image_layer.data = x
        image_layer.name = str(x_name)

    def create_label_layer(self, y, y_name):
        """Add `y` as a new labels layer named `y_name` and return the layer."""
        return self.viewer.add_labels(y, name=str(y_name))

    def update_label_layer(self, target_layer, y, y_name):
        """Replace the data and name of an existing labels layer."""
        target_layer.data = y
        target_layer.name = str(y_name)

    def get_sample_dataset(self, index):
        """Return the ``(x, y)`` sample at `index` from the dataset."""
        return self.dataset[index]

    def get_names_dataset(self, index):
        """Return the ``(input_path, target_path)`` pair at `index`."""
        return self.dataset.inputs[index], self.dataset.targets[index]

    def transform_x(self, x):
        """Convert an input sample into a NumPy array suitable for napari.

        Channel-first RGB data is moved to channel-last, and the result is
        passed through `re_normalize`.
        """
        # make sure it's a numpy.ndarray on the cpu
        x = self._to_numpy(x)

        # from [C, H, W] to [H, W, C] - only for RGB images
        if self.check_if_rgb(x):
            x = np.moveaxis(x, source=0, destination=-1)

        # renormalize
        x = re_normalize(x)

        return x

    def transform_y(self, y):
        """Convert a target sample into a NumPy array on the CPU."""
        return self._to_numpy(y)

    def check_if_rgb(self, x):
        """Return True if `x` looks like a channel-first RGB image ([3, H, W, ...])."""
        # a channel-first RGB image needs a channel axis plus at least two
        # spatial axes; a 2-D image with 3 rows must not be treated as RGB
        # TODO: try other methods as a 3D grayscale input image can have 3 modalities -> 3 channels
        # TODO: also think about RGBA images with 4 channels or a combination of a RGB and a grayscale image -> 4 channels
        return x.ndim >= 3 and x.shape[0] == 3

    @staticmethod
    def _to_numpy(data):
        """Return `data` as a NumPy array.

        Tensor-like objects exposing ``detach``/``cpu``/``numpy`` (e.g. torch
        tensors) are detached from the autograd graph and moved to the CPU
        first; anything else goes through ``np.asarray``.
        """
        # detach first: .numpy() refuses tensors that require grad
        if hasattr(data, "detach"):
            data = data.detach()
        if hasattr(data, "cpu"):
            data = data.cpu()
        if hasattr(data, "numpy"):
            return data.numpy()
        return np.asarray(data)

    @staticmethod
    def _file_name(path):
        """Return the final path component of `path` (``str`` or ``os.PathLike``)."""
        from pathlib import Path

        return Path(path).name

importing Jupyter notebook from transformations.ipynb
# of unique classes = [10 11 12 13 14]
x = shape: (128, 128, 3); type: uint8
x = min: 0; max: 255
x_t = shape: (3, 64, 64); type: float64
x_t = min: 0.0; max: 1.0
y = shape: (128, 128); class: [10 11 12 13 14]
y_t = shape: (64, 64); class: [0 1 2 3 4]
