In [None]:
import cv2
import ipyvolume as ipv  # pip install ipyvolume
import matplotlib.pyplot as plt
import numpy as np
from pylab import rcParams
# Default size (width, height in inches) for every matplotlib figure drawn below.
rcParams.update({'figure.figsize': (7, 7)})

class viewer(object):
    """Namespace of static helpers for showing images, volumes and contour overlays.

    Every method is a ``@staticmethod``; the class is never instantiated.
    2-D output goes through ``matplotlib.pyplot``, 3-D output through ``ipyvolume``.
    """

    @staticmethod
    def imshow_fp_(path, gray=True):
        """Read an image file with OpenCV and display it.

        ``fp`` in the name stands for "file path".

        Args:
            path: Image file path, passed to ``cv2.imread``.
            gray: If true, load as single-channel grayscale and show it with the
                ``'gray'`` colormap; otherwise load as colour (BGR) and convert
                to RGB before display.

        Raises:
            FileNotFoundError: If ``cv2.imread`` cannot read ``path``.
        """
        flag = cv2.IMREAD_GRAYSCALE if gray else cv2.IMREAD_COLOR
        img = cv2.imread(path, flag)
        if img is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise FileNotFoundError(
                'cv2.imread could not read image: {!r}'.format(path))
        if gray:
            plt.imshow(img, cmap='gray')
        else:
            # OpenCV loads colour images as BGR; matplotlib expects RGB.
            plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        plt.show()

    @staticmethod
    def imshow_(img, cmap='gray', gray=True):
        """Display an in-memory image array.

        Singleton dimensions are squeezed away first (``np.squeeze``).

        Args:
            img: Image array.
            cmap: Matplotlib colormap; only used when ``gray`` is true.
            gray: If true, show the squeezed array with ``cmap``; otherwise treat
                it as a BGR colour image and convert it to RGB for display.
        """
        img = np.squeeze(img)
        if gray:
            plt.imshow(img, cmap=cmap)
        else:
            plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        plt.show()

    @staticmethod
    def ipv_showvolume(volume, level=None, opacity=0.03):
        """Render a 3-D volume with ipyvolume.

        Axes and bounding box are hidden, lighting is on, the transfer function
        uses ``level_width=0.1`` with ``data_min=0`` / ``data_max=1``, and the
        camera is placed with ``ipv.view(-30, 40)`` (azimuth, elevation).

        Args:
            volume: Volume passed to ``ipv.volshow``. Because of data_min=0 and
                data_max=1, values are presumably expected in [0, 1]
                (TODO confirm with callers).
            level: Transfer-function levels; defaults to ``[0.25, 0.75]``.
                Another usable combination is ``level=[0.0, 0.5, 1.0]`` with
                ``opacity=[0.0, 0.1, 0.2]``.
            opacity: Transfer-function opacity, a scalar or one value per level.
        """
        if level is None:
            # Fresh list per call; avoids a shared mutable default.
            level = [0.25, 0.75]
        ipv.figure()
        ipv.style.axes_off()
        ipv.style.box_off()
        ipv.volshow(volume, lighting=True, level=level, opacity=opacity,
                    level_width=0.1, data_min=0, data_max=1)
        ipv.view(-30, 40)
        ipv.show()

    @staticmethod
    def draw_coord_on_image(x, y, image, value, color=False):
        """Set the pixels at the given coordinates to ``value`` and plot the image.

        The image may be grayscale (2-D) or multi-channel (3-D). If ``color`` is
        true, a 2-D image is first converted to RGB, producing a new array.
        Otherwise ``image`` itself is modified in place.

        Args:
            x: Sequence of column coordinates (each converted with ``int``).
            y: Sequence of row coordinates, paired element-wise with ``x``.
            image: Image array to draw on.
            value: Value written to each pixel. Use a scalar for grayscale, or a
                scalar or per-channel tuple for multi-channel images.
            color: Convert a 2-D image to RGB before drawing.

        Returns:
            The drawn-on image array. This is ``image`` itself unless it was
            converted.

        Raises:
            ValueError: If ``x`` and ``y`` have different lengths.
        """
        if len(x) != len(y):
            raise ValueError(
                'x and y must have the same length, got {} and {}'.format(len(x), len(y)))
        plt.figure()
        if image.ndim == 2 and color:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        # Numpy indexing is image[row, col]: y selects the row, x the column.
        # For a 3-D image this assigns every channel of the pixel at once.
        for xi, yi in zip(x, y):
            image[int(yi), int(xi)] = value
        if image.ndim == 2:
            plt.imshow(image, cmap='gray')
        else:
            plt.imshow(image)
        return image

    @staticmethod
    def split_cv_bnd_into_x_y(bnd, index):
        """Flatten the selected contours into separate x and y coordinate lists.

        Args:
            bnd: Contours of the form struct_num * [point_num * [1 * [x, y]]].
            index: Iterable of structure indices to include, processed in order.

        Returns:
            A tuple ``(x_coord, y_coord)`` of lists. Together they hold the x and
            y values of every point of the selected structures, concatenated in
            ``index`` order.

        Raises:
            IndexError: If an index is ``>= len(bnd)``.
        """
        x_coord = []
        y_coord = []
        for dim in index:
            # Explicit check rather than `assert`, which is stripped under -O.
            if dim >= len(bnd):
                raise IndexError(
                    'structure index {} out of range for {} structures'.format(dim, len(bnd)))
            for point in bnd[dim]:
                # Each point is wrapped in an extra length-1 level: [[x, y]].
                x_coord.append(point[0][0])
                y_coord.append(point[0][1])
        return x_coord, y_coord

    @staticmethod
    def draw_bnd_on_single_dim_cv(image, bnd, value, idx=-1):
        """Draw contour points onto an image and plot the result.

        Args:
            image: Grayscale (2-D) or colour (3-D) image. A 2-D image is converted
                to RGB (a new array); a 3-D image is drawn on in place.
            bnd: Contours of the form struct_num * [point_num * [1 * [x, y]]],
                presumably as returned by ``cv2.findContours``.
            value: Colour to draw, e.g. an (R, G, B) tuple.
            idx: -1 to draw every structure; otherwise the index of one structure.

        Returns:
            The drawn-on image array.
        """
        if image.ndim == 2:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        indices = range(len(bnd)) if idx == -1 else [idx]
        x_coord, y_coord = viewer.split_cv_bnd_into_x_y(bnd, indices)
        return viewer.draw_coord_on_image(x_coord, y_coord, image, value, color=True)

    @staticmethod
    def draw_red_on_single_dim_cv(red, cmap='jet', component_to_plot=-1):
        """Plot a connected-component label matrix.

        Args:
            red: Pair where ``red[0]`` is the number of labels (connected
                components) and ``red[1]`` is the label matrix.
            cmap: Matplotlib colormap.
            component_to_plot: -1 plots all labels. Any other value keeps only
                that label's pixels and sets every other pixel to ``red[0]``.
                The caller's matrix is not modified.

        Returns:
            ``red[0]``, the number of labels.

        Raises:
            ValueError: If ``component_to_plot >= red[0]``.
        """
        if component_to_plot >= red[0]:
            raise ValueError(
                'component_to_plot={} must be < number of labels ({})'.format(
                    component_to_plot, red[0]))
        plt.figure()
        if component_to_plot == -1:
            plt.imshow(red[1], cmap=cmap)
        else:
            # Copy so the caller's label matrix is left untouched.
            curtain = np.array(red[1])
            # Assuming every label is < red[0], the background becomes the
            # largest value and lands at the top of the colormap.
            curtain[curtain != component_to_plot] = red[0]
            plt.imshow(curtain, cmap=cmap)
        return red[0]