# Data Processing

This notebook experiments with creating a class that loads, organizes, and saves images used for segmentation.

## Import List:

In [2]:
from typing import List, Iterable, Tuple
import numpy as np
from PIL import Image
import os
import nibabel as nib
from pathlib import Path
import pydicom
import cv2

## Base Loading and Saving Images Class
While the particulars of each segmentation problem are different, the loading/augmentation/saving steps are largely the same, so they can be collected in shared base classes.

Note: Look at https://github.com/MIC-DKFZ/batchgenerators for a list of different types of image editing.

The general breakdown that I am thinking of is:

1. **ImgIE** : "Image Import Export", this class will contain the methods for loading the images into numpy/tensors and exporting numpy/tensors back to images. This should also have methods for searching through files and storing them.
2. **ImgAug** : "Image Augmentation", this class will handle any of the different image augmentations that one might think of doing for 2D and 3D images
3. **ImgDL** : "Image Dataloader", this class will function to make a Pytorch DataLoader that is easy to use for training
4. **ImgMet** : "Image Metrics", this class will contain methods for calculating common model performance metrics (SSIM, PSNR, Dice, etc.)

These classes are meant to be used through wrappers that handle the loading and processing for a particular problem.

Also, as an exercise in building good habits, the classes will include type hints for all variables:
https://peps.python.org/pep-0484/#acceptable-type-hints

### ImgIE

In [None]:
class ImgIE():
    '''Class for the loading of images into numpy arrays and the saving of numpy arrays into images.
    Also handles rudimentary processing of the images.'''
    def __init__(self) -> None:
        '''Test this out'''
        pass

    def load_image(self, im_path: Path, verbose: bool=False) -> np.ndarray:
        # Given an image path, determines the function required to load the contents
        # as a numpy array, which is returned.
        fil_typ = os.path.splitext(im_path)[1]

        if fil_typ == '.png':
            # If file is a png
            with Image.open(im_path) as f_im:
                img = np.array(f_im)

            if verbose:
                print(f'Loading {im_path} as png')
                print(f'Image shape:{img.shape}')

            # NOTE: self.template is not set in __init__ in this draft; it must be
            # assigned before load_image is called on a png
            if self.template['unit'] == 'intensity':
                if len(img.shape)==3:
                    img = self.rgba2ycbcr(img)
                    img = img[:,:,0] #Just deal with intensity values at the moment because 
                                    # having multiple channels throws off cv2 when saving, 
                                    # since it also does BGR instead of RGB and will save a blue image
                elif len(img.shape)==2:
                    pass            # If the png is just greyscale, then there is nothing that can
                                    # be done except take the single channel
                else:
                    raise ValueError("Provided png image is not 2 or 3 dimensional, something is wrong with the image.")

            elif self.template['unit'] == 'color':
                raise NotImplementedError("""Loading and creation of patches from color png images is
                currently not supported. Please use template['unit']='intensity' for conversion of png
                images to greyscale intensity images.""")

        elif fil_typ == '.nii' or fil_typ == '.gz':
            img = nib.load(im_path).get_fdata()
            if verbose:
                print(f'Loading {im_path} as nii')
                print(f'Image shape:{img.shape}')

        elif fil_typ == '.dcm':
            img = pydicom.dcmread(im_path).pixel_array
            if verbose:
                print(f'Loading {im_path} as dicom')
                print(f'Image shape:{img.shape}')

        else:
            # The file may exist; it is the file type that is not handled
            raise NotImplementedError(f'Image file type {fil_typ} not supported.')

        return img

    def load_png(self, im_path: Path, unit: str='raw') -> np.ndarray:
        '''Load png image
        
        Parameters
        ----------
        im_path : Path
            Path to the file you wish to load.
        
        unit : str
            What the unit of the given image should be. Current options are:
            - `'raw'` : the raw RGBA values stored in the image (Default)
            - `'luminance'` : the intensity channel from converting the RGB image to YCbCr.
            This will result in the image only having one channel

        Returns
        -------
        img : float ndarray
            The loaded image as a numpy array

        '''
        with Image.open(im_path) as f_im:
            if unit == 'raw':
                return np.array(f_im)
            if unit == 'luminance':
                img = np.array(f_im)
                
                if len(img.shape) == 3:
                    # convert to YCbCr then take first channel (luminance)
                    return self.rgba2ycbcr(img)[:,:,0]
                
                elif len(img.shape) == 2:
                    return img
                
                else:
                    raise ValueError("Provided png image is not 2 or 3 dimensional, something is wrong with the image.")
        
            # If the unit type is not supported
            raise NotImplementedError(f'Loading of png using unit value {unit} is currently not supported.')


    def load_jpg(self, im_path: Path, unit: str='raw') -> np.ndarray:
        '''Load jpg image

        '''
        pass

    def load_nifti(self, im_path: Path) -> np.ndarray:
        '''Load nifti file from provided path

        Parameters
        ----------
        im_path : Path
            Path to the file you wish to load.
        
        Returns
        -------
        img : float ndarray
            The loaded image as a numpy array
        
        '''
        return nib.load(im_path).get_fdata()

    def load_dicom(self, im_path: Path) -> np.ndarray:
        '''Load DICOM file from provided path

        Parameters
        ----------
        im_path : Path
            Path to the file you wish to load.
        
        Returns
        -------
        img : float ndarray
            The loaded image as a numpy array
        
        '''
        return pydicom.dcmread(im_path).pixel_array

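    def rgba2ycbcr(self, img_rgba: np.ndarray) -> np.ndarray:
        '''Takes an RGB or RGBA image and returns it as a YCbCr image 

        Parameters
        ----------
        img_rgba : ndarray
            The RGB(A) image, shape (H, W, 3) or (H, W, 4), which you want to convert to YCbCr

        Returns
        -------
        img_ycbcr : float ndarray
            The converted image

        '''
        # A color image is a 3D array whose last axis holds 3 (RGB) or 4 (RGBA) channels
        if img_rgba.ndim != 3 or img_rgba.shape[2] not in (3, 4):
            raise ValueError('Input image is not RGB or RGBA')

        # Drop the alpha channel (if present) and convert to float for cv2
        img_rgb = np.ascontiguousarray(img_rgba[:,:,:3], dtype=np.float32)
        
        # NOTE: cv2 assumes float images lie in [0, 1]; 8-bit input (0-255) is not rescaled here
        img_ycrcb = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2YCR_CB)
        # cv2 orders the channels Y, Cr, Cb; reorder them to Y, Cb, Cr
        img_ycbcr = img_ycrcb[:,:,(0,2,1)].astype(np.float32)
        img_ycbcr[:,:,0] = (img_ycbcr[:,:,0]*(235-16)+16)/255.0
        img_ycbcr[:,:,1:] = (img_ycbcr[:,:,1:]*(240-16)+16)/255.0

        return img_ycbcr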

    
    def save_image(self, fname: Path, im: np.ndarray, form: str, verbose: bool = False) -> None:
        # Take a given image and save it as the specified format:
        # fname = output name of the saved file
        # im = numpy array of image

        dim = im.shape #Get shape of image (one entry per dimension)

        if form == 'png':
            # Check that you aren't saving a 3D image
            #TODO: Scale inputs to [0,255] so data isn't lost/image isn't saturated
            cv2.imwrite(f'{fname}',im)
            if verbose:
                print(f'Saving: {fname}')
        elif form == 'nii':

            # TODO: Add option to transpose image for some reason because mricron hates the first dim[0] = 1
            # Still gets loaded fine in terms of loading into python, but visualizing it is bad
            # np.transpose(im, (1,2,0))


            # TODO: If image is 2D then append a third  dimension before saving(?)
            # nibabel expects a 4x4 affine regardless of how many dimensions the image has
            nib.save(nib.Nifti1Image(im, np.eye(4)), fname)
            if verbose:
                print(f'Saving: {fname}')
        elif form == 'dcm':
            raise NotImplementedError('DICOM saving currently not supported')
        else:
            raise NotImplementedError('Specified file type is currently not supported for saving')
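
The TODO in `save_image` about scaling inputs to [0, 255] before writing a png could be handled with a simple min-max rescale. The following is only a sketch under that assumption: `to_uint8` is a hypothetical helper that is not part of the class above, and a linear min-max mapping is one choice among several (clipping to a fixed intensity window is another).

```python
import numpy as np


def to_uint8(im: np.ndarray) -> np.ndarray:
    """Linearly rescale an array to the full [0, 255] range and cast to uint8."""
    im = im.astype(np.float64)
    lo, hi = im.min(), im.max()
    if hi == lo:
        # Constant image: avoid dividing by zero and return all zeros
        return np.zeros(im.shape, dtype=np.uint8)
    scaled = (im - lo) / (hi - lo) * 255.0
    return np.round(scaled).astype(np.uint8)
```

Note that min-max scaling discards the absolute intensity values, so the saved png cannot be mapped back to the original units unless `lo` and `hi` are stored separately.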

In [5]:
type(np.zeros(2))

numpy.ndarray
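
`load_image` picks its loader from `os.path.splitext`, which returns only the last suffix, so `scan.nii.gz` is seen as `.gz` and any other gzip file would be sent to nibabel as well. A minimal sketch of a stricter check on the whole file name is shown below; `image_kind` is a hypothetical helper name, and the set of recognised suffixes simply mirrors the three branches of `load_image`.

```python
from pathlib import Path


def image_kind(im_path: Path) -> str:
    """Return 'png', 'nifti' or 'dicom' based on the file name alone."""
    name = Path(im_path).name.lower()
    if name.endswith('.png'):
        return 'png'
    if name.endswith(('.nii', '.nii.gz')):
        return 'nifti'
    if name.endswith('.dcm'):
        return 'dicom'
    raise ValueError(f'Image file type of {im_path} not supported.')
```

The returned string could then be used to choose between `load_png`, `load_nifti` and `load_dicom`.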

### ImgAug

In [None]:
class ImgAug():
    def __init__(self) -> None:
        pass

    def aug_run(self) -> None:
        '''Run provided augmentation. Potentially save augmentation lists for repeated use'''
        pass

    def grouped_aug_run(self, img_grp: List[List[Path]]) -> None:
        '''Run provided augmentation on several different groups of images
        
        '''
        pass

    def gen_random_aug(self) -> None:
        '''Create a generator for random combinations of augmentation'''

    def array_translate(self) -> None:
        pass

    def array_rotate(self) -> None:
        pass

    def gen_noise(self) -> None:
        '''Add noise to provided image'''
        pass

    def array_scale(self) -> None:
        '''Either upscales or downscales provided array'''
        pass

    # https://pytorch.org/docs/stable/generated/torch.nn.Upsample.html
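
As a rough idea of what `array_translate` and `gen_noise` might end up doing, here is a numpy-only sketch. The function names, arguments and the choice of zero filling and Gaussian noise are all assumptions made for illustration; the stubs above do not fix any of them yet.

```python
from typing import Optional, Tuple

import numpy as np


def translate_array(img: np.ndarray, shift: Tuple[int, ...]) -> np.ndarray:
    """Shift img by shift[k] elements along axis k, filling vacated space with zeros."""
    out = np.zeros_like(img)
    src, dst = [], []
    for s, n in zip(shift, img.shape):
        if s >= 0:
            src.append(slice(0, max(n - s, 0)))
            dst.append(slice(s, n))
        else:
            src.append(slice(-s, n))
            dst.append(slice(0, max(n + s, 0)))
    # Copy the part of the image that is still inside the frame after shifting
    out[tuple(dst)] = img[tuple(src)]
    return out


def add_gaussian_noise(img: np.ndarray, sigma: float, seed: Optional[int] = None) -> np.ndarray:
    """Return a float32 copy of img with zero-mean Gaussian noise of std sigma added."""
    rng = np.random.default_rng(seed)
    noise = rng.normal(0.0, sigma, size=img.shape).astype(np.float32)
    return img.astype(np.float32) + noise
```

Because the shift is given per axis, the same function works for 2D and 3D arrays. Passing a `seed` makes the noise repeatable, which would help if augmentation lists are saved for repeated use.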

### ImgDL

In [None]:
class ImgDL(ImgIE, ImgAug):
    def __init__(self) -> None:
        pass

### ImgMet

In [None]:
class ImgMet():
    '''Class for containing different performance metrics'''
    def __init__(self) -> None:
        pass

    def PSNR_calc(self) -> float:
        pass

    def SSIM_calc(self) -> float:
        pass
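
Two of the metrics named in the overview, PSNR and Dice, have short closed forms, so a numpy sketch may help pin down what the stubs should return. The standalone function names and the `data_range` argument are assumptions for illustration, and treating two empty masks as a Dice of 1.0 is a convention rather than a rule.

```python
import numpy as np


def psnr(ref: np.ndarray, test: np.ndarray, data_range: float) -> float:
    """Peak signal-to-noise ratio in dB: 10 * log10(data_range**2 / MSE)."""
    mse = np.mean((ref.astype(np.float64) - test.astype(np.float64)) ** 2)
    if mse == 0:
        return float('inf')  # identical images
    return float(10.0 * np.log10(data_range ** 2 / mse))


def dice(mask_a: np.ndarray, mask_b: np.ndarray) -> float:
    """Dice coefficient 2*|A and B| / (|A| + |B|) for two binary masks."""
    a = mask_a.astype(bool)
    b = mask_b.astype(bool)
    total = a.sum() + b.sum()
    if total == 0:
        return 1.0  # both masks empty: treated here as perfect agreement
    return float(2.0 * np.logical_and(a, b).sum() / total)
```

`data_range` is the largest possible intensity difference, for example 255 for 8-bit images or 1.0 for images normalised to [0, 1].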

## U2Net Implementation
This serves as a test run for using this collection of classes on a real problem set.
*Note: Task05 has two channels in its images.*

### Outline for how to read in the unique file organization that U2Net has:
The model requires 3 different datasets, each kept in its own directory:
- Need to match both the original image and the label file together so they can be loaded at the same time
    - Would be good to have a "list of lists" setup for this
- Need to be able to select randomly from only one dataset at a time
    - Ideally each epoch is from a different dataset
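
The "list of lists" pairing from the outline above could be sketched as follows. This assumes that an image and its label share the same file name in two separate folders, which has not been checked against the actual datasets; `match_files` here is a standalone illustration and the `pattern` default is likewise an assumption.

```python
from pathlib import Path
from typing import List


def match_files(img_dir: Path, label_dir: Path, pattern: str = '*.nii.gz') -> List[List[Path]]:
    """Pair every image in img_dir with the same-named file in label_dir."""
    pairs: List[List[Path]] = []
    for img_path in sorted(Path(img_dir).glob(pattern)):
        label_path = Path(label_dir) / img_path.name
        if label_path.is_file():
            # Keep only images that actually have a label
            pairs.append([img_path, label_path])
    return pairs
```

Calling this once per dataset directory gives one list of `[image, label]` pairs per dataset, which keeps the three datasets separate for sampling.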

In [None]:
## Rough outline

class UData(ImgIE, ImgAug):
    '''Class for data management for U^2-Net training and testing
    
    Parameters
    -------
    - path pairs for the folders containing the raw images and the labels
        [[/img_1, /label_1],[/img_2, /label_2]]
    - output directory for generated images (after patches/augmentation is applied)

    '''
    def __init__(self) -> None:
        pass

    def get_files(self) -> None:
        pass

    def match_files(self) -> None:
        pass

    def run(self) -> None:
        pass
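
One way to satisfy the "one dataset per epoch" idea from the outline is to rotate through the datasets and shuffle only within the chosen one. The generator below is a sketch under that assumption; `epoch_batches` and its arguments are hypothetical names, and a simple round-robin rotation is used in place of any particular sampling scheme.

```python
import random
from typing import Dict, Iterator, List, Optional, Sequence, Tuple


def epoch_batches(
    datasets: Dict[str, Sequence],
    n_epochs: int,
    batch_size: int,
    seed: Optional[int] = None,
) -> Iterator[Tuple[int, str, List]]:
    """Yield (epoch, dataset_name, batch); all batches in an epoch share one dataset."""
    rng = random.Random(seed)
    names = sorted(datasets)
    for epoch in range(n_epochs):
        name = names[epoch % len(names)]  # rotate through datasets, one per epoch
        items = list(datasets[name])
        rng.shuffle(items)  # random order within the chosen dataset
        for start in range(0, len(items), batch_size):
            yield epoch, name, items[start:start + batch_size]
```

Each value in `datasets` could be the list of `[image, label]` path pairs for one dataset directory.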

## SrGen recreation
Using the defined classes above, is it possible to recreate the SrGen class?