# Data Processing

This notebook experiments with a class that loads, organizes, and saves images used for segmentation.

## Import List:

In [2]:
from typing import List, Iterable, Tuple, Any, Union
import numpy as np
from PIL import Image
import os
import nibabel as nib
from pathlib import Path
import pydicom
import cv2

## Base Loading and Saving Images Class
While the particulars of each segmentation problem are different, the loading/augmentation/saving steps are largely shared, so they are collected into common base classes.

Note: Look at https://github.com/MIC-DKFZ/batchgenerators for list of different types of image editing

The general breakdown that I am thinking of is:

1. **ImgIE** : "Image Import Export", this class will contain the methods for loading the images into numpy/tensors and exporting numpy/tensors back to images. This should also have methods for searching through files and storing them.
2. **ImgAug** : "Image Augmentation", this class will handle any of the different image augmentations that one might think of doing for 2D and 3D images
3. **ImgDL** : "Image Dataloader", this class will function to make a Pytorch DataLoader that is easy to use for training
4. **ImgMet** : "Image Metrics", this class will contain methods for calculating common model performance metrics (SSIM, PSNR, Dice, etc.)

These classes are meant to be used through wrappers that handle the loading and processing for a particular problem.

Also, as an exercise in building good habits, the classes will include type hints for all variables:
https://peps.python.org/pep-0484/#acceptable-type-hints

### ImgIE

In [None]:
class ImgIE():
    '''Class for the loading of images into numpy arrays and the saving of numpy arrays into images.
    Also handles rudimentary processing of the images.'''
    def __init__(self) -> None:
        '''Test this out'''
        pass

    def load_image(self, im_path: Path, verbose: bool=False) -> np.ndarray:
        # Given an image path, determines the function required to load the contents
        # as a numpy array, which is returned.
        fil_typ = os.path.splitext(im_path)[1]

        if fil_typ == '.png':
            # If file is a png
            with Image.open(im_path) as f_im:
                img = np.array(f_im)

            if verbose:
                print(f'Loading {im_path} as png')
                print(f'Image shape:{img.shape}')

            # `template` is only defined by subclasses (see ImgAug), so fall back to 'intensity' when it is absent
            unit = getattr(self, 'template', {}).get('unit', 'intensity')

            if unit == 'intensity':
                if len(img.shape)==3:
                    img = self.rgba2ycbcr(img)
                    img = img[:,:,0] #Just deal with intensity values at the moment because 
                                    # having multiple channels throws off cv2 when saving, 
                                    # since it also does BGR instead of RGB and will save a blue image
                elif len(img.shape)==2:
                    pass            # If the png is just greyscale, then there is nothing that can
                                    # be done except take the single channel
                else:
                    raise ValueError("Provided png image is not 2 or 3 dimensional, something is wrong with the image.")

            elif unit == 'color':
                raise NotImplementedError("""Loading and creation of patches from color png images is
                currently not supported. Please use template['unit']='intensity' for conversion of png
                images to greyscale intensity images.""")

        elif fil_typ == '.nii' or fil_typ == '.gz':
            img = nib.load(im_path).get_fdata()
            if verbose:
                print(f'Loading {im_path} as nii')
                print(f'Image shape:{img.shape}')

        elif fil_typ == '.dcm':
            img = pydicom.dcmread(im_path).pixel_array
            if verbose:
                print(f'Loading {im_path} as dicom')
                print(f'Image shape:{img.shape}')

        else:
            # The file may exist; the problem is the unsupported extension
            raise ValueError(f'Image file type {fil_typ} not supported.')

        return img

    def load_png(self, im_path: Path, unit: str='raw') -> np.ndarray:
        '''Load png image
        
        Parameters
        ----------
        im_path : Path
            Path to the file you wish to load.
        
        unit : str
            What the unit of the given image should be. Current options are:
            - `'raw'` : the raw RGBA values stored in the image (Default)
            - `'luminance'` : the intensity channel from converting the RGB image to YCbCr.
            This will result in the image only having one channel

        Returns
        -------
        img : ndarray
            The loaded image as a numpy array (dtype as stored for `'raw'`; float for
            `'luminance'` of a color image)

        '''
        with Image.open(im_path) as f_im:
            if unit == 'raw':
                return np.array(f_im)
            if unit == 'luminance':
                img = np.array(f_im)
                
                if len(img.shape) == 3:
                    # convert to YCbCr then take first channel (luminance)
                    return self.rgba2ycbcr(img)[:,:,0]
                
                elif len(img.shape) == 2:
                    return img
                
                else:
                    raise ValueError("Provided png image is not 2 or 3 dimensional, something is wrong with the image.")
        
            # If the unit type is not supported
            raise NotImplementedError(f'Loading of png using unit value {unit} is currently not supported.')


    def load_jpg(self, im_path: Path, unit: str='raw') -> np.ndarray:
        '''Load jpg image

        '''
        pass

    def load_nifti(self, im_path: Path) -> np.ndarray:
        '''Load nifti file from provided path

        Parameters
        ----------
        im_path : Path
            Path to the file you wish to load.
        
        Returns
        -------
        img : float ndarray
            The loaded image as a numpy array
        
        '''
        return nib.load(im_path).get_fdata()

    def load_dicom(self, im_path: Path) -> np.ndarray:
        '''Load DICOM file from provided path

        Parameters
        ----------
        im_path : Path
            Path to the file you wish to load.
        
        Returns
        -------
        img : ndarray
            The loaded image as a numpy array (dtype follows the stored pixel data)
        
        '''
        return pydicom.dcmread(im_path).pixel_array

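    def rgba2ycbcr(self, img_rgba: np.ndarray) -> np.ndarray:
        '''Takes an RGB or RGBA image and returns it as a YCbCr image 

        Parameters:
        ----------
        img_rgba : ndarray
            The RGB(A) image of shape (H, W, 3) or (H, W, 4) which you want to
            convert to YCbCr. Any alpha channel is dropped.

        Returns:
        --------
        img_ycbcr : float ndarray
            The converted image

        '''
        if img_rgba.ndim != 3 or img_rgba.shape[2] not in (3, 4):
            raise ValueError('Input image is not RGB or RGBA')

        # Drop alpha and convert to float32; cv2 expects float images in [0, 1]
        img_rgb = img_rgba[:,:,:3].astype(np.float32)
        if np.issubdtype(img_rgba.dtype, np.integer):
            img_rgb /= np.iinfo(img_rgba.dtype).max
        
        img_ycrcb = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2YCrCb)
        # Reorder channels from (Y, Cr, Cb) to (Y, Cb, Cr)
        img_ycbcr = img_ycrcb[:,:,(0,2,1)].astype(np.float32)
        # Rescale full range to studio range (16-235 luma, 16-240 chroma), as fractions of 255
        img_ycbcr[:,:,0] = (img_ycbcr[:,:,0]*(235-16)+16)/255.0
        img_ycbcr[:,:,1:] = (img_ycbcr[:,:,1:]*(240-16)+16)/255.0

        return img_ycbcr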

    
    def save_image(self, fname: Path, im: np.ndarray, form: str, verbose: bool = False) -> None:
        # Take a given image and save it as the specified format:
        # fname = output name of the saved file
        # im = numpy array of image

        dim = im.shape # Shape of the image; len(dim) is the number of dimensions

        if form == 'png':
            # Check that you aren't saving a 3D image
            #TODO: Scale inputs to [0,255] so data isn't lost/image isn't saturated
            cv2.imwrite(f'{fname}',im)
            if verbose:
                print(f'Saving: {fname}')
        elif form == 'nii':

            # TODO: Add option to transpose image for some reason because mricron hates the first dim[0] = 1
            # Still gets loaded fine in terms of loading into python, but visualizing it is bad
            # np.transpose(im, (1,2,0))


            # TODO: If image is 2D then append a third  dimension before saving(?)
            # nibabel expects a 4x4 affine regardless of the image dimensionality
            nib.save(nib.Nifti1Image(im, np.eye(4)), fname)
            if verbose:
                print(f'Saving: {fname}')
        elif form == 'dcm':
            raise NotImplementedError('DICOM saving currently not supported')
        else:
            raise NotImplementedError('Specified file type is currently not supported for saving')
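
The suffix-based branching in `load_image` could also be written as a lookup table. The following is a minimal sketch rather than the notebook's implementation: `LOADERS` and `pick_loader` are hypothetical names, and the table stores the loader method names as strings so the example runs without the imaging libraries installed. Using `Path.suffixes` lets `.nii.gz` be matched as a single extension.

```python
from pathlib import Path
from typing import Dict

# Hypothetical lookup table: file extension -> name of the ImgIE method to call.
LOADERS: Dict[str, str] = {
    '.png': 'load_png',
    '.nii': 'load_nifti',
    '.nii.gz': 'load_nifti',
    '.dcm': 'load_dicom',
}

def pick_loader(im_path: Path) -> str:
    '''Return the name of the loader method matching the file extension.'''
    suffixes = [s.lower() for s in Path(im_path).suffixes]
    # Try the longest compound extension first so 'scan.nii.gz' matches '.nii.gz'.
    for n in range(len(suffixes), 0, -1):
        ext = ''.join(suffixes[-n:])
        if ext in LOADERS:
            return LOADERS[ext]
    raise ValueError(f'Image file type {"".join(suffixes)!r} not supported.')

print(pick_loader(Path('data/scan.nii.gz')))
print(pick_loader(Path('data/slice_01.PNG')))
```

In a class, the returned name could be resolved with `getattr(self, name)` before calling it with the path.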

In [5]:
type(np.zeros(2))

numpy.ndarray

In [3]:
from collections import OrderedDict
class Dummy:
    def __init__(self) -> None:
        self.x = 2
        self.y = 3

    def add(self, x:int, y:int) -> int:
        return x+y

    def divide(self, x:int) -> float:
        return x/2

    def run(self, x: int)-> int:

        # Map operation names to the bound methods that implement them
        q = OrderedDict(add=self.add, divide=self.divide)

        return q['add'](x,3)

z = Dummy()
print(z.run(4))


7


In [19]:
# This is probably the most sound way to go about doing things until I can think of a better way...
# Have the augmentations be entered as a list

a = [["translation",[0,0,12],"rand"],["resolution",10]]
for func, *kwrd in a:
    print(func)
    print(kwrd)

translation
[[0, 0, 12], 'rand']
resolution
[10]
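
One way to consume such a list is to look up each name in a dictionary of functions and pass along whatever arguments follow it. This is only a sketch of that idea: `translate`, `change_resolution`, `AUGMENTATIONS`, and `run_augmentations` are hypothetical placeholders that record what they would do instead of touching an image.

```python
from typing import Any, Callable, Dict, List

def translate(shift: List[int], mode: str = 'fixed') -> str:
    # Placeholder: a real version would shift the image array.
    return f'translate by {shift} ({mode})'

def change_resolution(factor: int) -> str:
    # Placeholder: a real version would resample the image array.
    return f'change resolution by {factor}'

# Hypothetical lookup table from augmentation name to the function that runs it.
AUGMENTATIONS: Dict[str, Callable[..., str]] = {
    'translation': translate,
    'resolution': change_resolution,
}

def run_augmentations(spec: List[List[Any]]) -> List[str]:
    '''Call each named augmentation with the arguments that follow its name.'''
    log = []
    for name, *args in spec:
        if name not in AUGMENTATIONS:
            raise KeyError(f'Unknown augmentation: {name}')
        log.append(AUGMENTATIONS[name](*args))
    return log

a = [["translation", [0, 0, 12], "rand"], ["resolution", 10]]
for line in run_augmentations(a):
    print(line)
```

Unknown names fail early with a `KeyError`, which is easier to debug than a silently skipped augmentation.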


In [4]:
import numpy as np
np.matmul([2],[3])

6

### ImgAug

Most of these methods take a numpy array plus some parameters and return the modified numpy array together with a string label describing what was changed.
An open question is whether it would be better to store the numpy array being worked on as an attribute of the class.
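
To make the "array in, array and label out" contract concrete, here is a small stateless sketch in which the labels from successive steps are joined into a filename suffix. The helpers `shift_rows`, `flip_columns`, and `apply_chain` are hypothetical, and they use plain `np.roll`/`np.flip` as stand-ins rather than the skimage transforms.

```python
from typing import Callable, List, Tuple
import numpy as np

AugFn = Callable[[np.ndarray], Tuple[np.ndarray, str]]

def shift_rows(im: np.ndarray) -> Tuple[np.ndarray, str]:
    # Wrap-around shift by one row; returns a new array, the input is untouched.
    return np.roll(im, 1, axis=0), '_tr1_0'

def flip_columns(im: np.ndarray) -> Tuple[np.ndarray, str]:
    # Mirror the image left-to-right.
    return np.flip(im, axis=1), '_flip1'

def apply_chain(im: np.ndarray, steps: List[AugFn]) -> Tuple[np.ndarray, str]:
    '''Apply each step in order and join the labels into one suffix.'''
    label = ''
    for step in steps:
        im, step_label = step(im)
        label += step_label
    return im, label

original = np.arange(6).reshape(3, 2)
augmented, suffix = apply_chain(original, [shift_rows, flip_columns])
print(suffix)
print(augmented)
```

Keeping the array out of the class state means the same chain can be applied to an image and to its label mask without resetting anything in between; storing it as an attribute would save some argument passing at the cost of that flexibility.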

In [None]:
from skimage.transform import rotate, AffineTransform, warp, rescale, resize
from tqdm import tqdm # Progress bar used when saving patches
import math

class ImgAug(ImgIE):
    def __init__(self) -> None:
        super().__init__()
        self.template = self.get_template()

    def aug_run(self, inp: np.ndarray, aug_key: "OrderedDict[str, Any]") -> None:
        '''Run provided augmentation using the settings defined by the template dictionary.
        Possibly works by storing the various methods in a dictionary:
        https://stackoverflow.com/questions/9168340/using-a-dictionary-to-select-function-to-execute

        '''
        pass

    def grouped_aug_run(self, img_grp: List[List[Path]], aug_key: "OrderedDict[str, Any]") -> None:
        '''Run provided augmentation on several different groups of images
        
        '''
        pass

    def set_template(self) -> None:
        pass


    def get_template(self) -> "dict[str, Any]":
        '''Return the previously set template, or return the standard template which 
        should have parameters changed.
        '''
        try:
            return self.template
        except AttributeError: # No template has been set yet
            return {'out_type':'png',
                'unit':'intensity', #Currently only matters for png
                'resolution':None,
                'same_size': True, # Whether to have the LR image be the same size as the HR image
                                   # (i.e. whether to scale down then up or just down)
                'translation':None, # Have both single value or multiple
                'rotation': None, # Around each axis
                'scale': False, # What magnitude to zoom in for added jitter
                'patch': False, # Have this accept 3 dimensional input [x,y,z], [x,y], or single
                'step': 10, # Also have this accept 3 dimensional input
                'keep_blank': False,
                'blank_ratio': 0.4,
            }

    def gen_random_aug(self, im_h: np.ndarray) -> None:
        '''Create a generator for random combinations of augmentation'''
        dim = im_h.shape
        if len(dim)>3:
            raise ValueError('Dimension of input data not currently supported')
        
        # If a single value is provided for any of these settings, convert into list of N dimensions
        if self.template['translation'] is None:
            trans = [None]
        elif not isinstance(self.template['translation'], list):
            trans = [self.template['translation'] for _ in range(len(dim))]
        else:
            trans = self.template['translation'][:] # Copy so changes to 'trans' do not modify self.template
        
        for idx, x in enumerate(trans):
            if x is not None and x != 0:
                # Random shift in [-|x|, |x|)
                trans[idx] = np.random.randint(-abs(x), abs(x))


    def array_translate(self, im: np.ndarray, trans: List[int], mode: str='symmetric') -> "tuple[np.ndarray, str]":
        # Translation
        dim = im.shape

        if len(dim) != len(trans):
            raise IndexError(f'Translation of numpy array with dimensions: {dim} is not compatible with translation {trans}')
        
        if len(dim) == 2:
            transform = AffineTransform(translation=(trans[0], trans[1]))
            im = warp(im, transform, mode=mode)
            label = f'_tr{trans[0]}_{trans[1]}'

        elif len(dim) == 3:
            transform = AffineTransform(translation=(trans[1], trans[2]))
            for i in range(dim[0]):
                im[i,:,:] = warp(im[i,:,:], transform, mode = mode)

            for i in range(dim[1]):
                # Because two dimensions were already translated, you only need to translate
                # along one dimension
                im[:,i,:] = warp(im[:,i,:], AffineTransform(translation=(trans[0],0)), mode=mode)
                
            label = f'_tr{trans[0]}_{trans[1]}_{trans[2]}'

        else:
            raise NotImplementedError(f"Translation of objects with dimension {len(dim)} is not currently supported.")
        
        return im, label


    def array_rotate(self, im: np.ndarray, rot: List[int], order: int=1) -> "tuple[np.ndarray, str]":
        # TODO: Issue with low resolution not necessarily having the same dimensions
        # Rotation 2D
        dim = im.shape

        if len(dim) != len(rot):
            raise IndexError(f'Rotation of numpy array with dimensions: {dim} is not compatible with rotation {rot}')

        if len(dim) == 2:
            im = rotate(im, rot[0], order=order)
            label = f'_rot{rot[0]}'

        # Rotation 3D
        elif len(dim) == 3:
            for i in range(dim[0]):
                im[i,:,:] = rotate(im[i,:,:],rot[0], order=order)
            for i in range(dim[1]):
                im[:,i,:] = rotate(im[:,i,:],rot[1], order=order)
            for i in range(dim[2]):
                im[:,:,i] = rotate(im[:,:,i],rot[2], order=order)
            label = f'_rot{rot[0]}_{rot[1]}_{rot[2]}'
        
        else:
            raise NotImplementedError(f"Rotation of objects with dimension {len(dim)} is not currently supported.")

        return im, label

    def gen_noise(self) -> None:
        '''Add noise to provided image'''
        pass

    def array_scale(self, im: np.ndarray, scale: List[float], order: int=1, mode: str='symmetric', int_dims: bool=False, anti_alias: bool=True) -> "tuple[np.ndarray, str]":
        '''Either upscales or downscales provided array
        https://scikit-image.org/docs/stable/auto_examples/transform/plot_rescale.html
        '''
        # Scaling
        dim = im.shape

        if len(dim) != len(scale):
            raise IndexError(f'Scaling of numpy array with dimensions: {dim} is not compatible with scale {scale}')

        if int_dims:
            new_dims = scale
            label = '_si_'
        else:
            # Element-wise product: each dimension is multiplied by its own scale factor
            new_dims = [math.floor(x) for x in np.multiply(dim, scale)]
            label = '_sr_'

        im = resize(im, new_dims, order=order, mode=mode, anti_aliasing=anti_alias)

        if len(dim) == 2:
            label = label + f'{scale[0]}_{scale[1]}'
        elif len(dim) == 3:
            label = label + f'{scale[0]}_{scale[1]}_{scale[2]}'
        else:
            raise NotImplementedError(f"Scaling of objects with dimension {len(dim)} is not currently supported.")

        return im, label
        
    
    # def gen_LR_img(self, im, res: float, interp: int=1) -> np.ndarray:
    #     # Generate the low-resolution image from the corresponding HR image using resizing
    #     # https://pytorch.org/docs/stable/generated/torch.nn.Upsample.html

    #     dim = im.shape

    #     # TODO: Patching error occurs when the shape of an image has dimension of odd magnitude with
    #     #       even 'res', and vice versa. Need to come up with a fix for this...
    #     new_dims = [math.floor(x) for x in np.divide(dim, res)]

    #     im = resize(im, new_dims, order = interp, mode='symmetric')

    #     if self.template['same_size']:
    #         im = resize(im, dim, order= interp, mode = 'symmetric')

    #     return im

    def img2patches(self, im_h, fname, same_size=True, keep_blank=False, slice_select=None, save=False, sanity_check=False, verbose=False):
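        # Depending on the number of dimensions in the `patch` value, either make 2D
        # or 3D images

        dim = im_h.shape
        # Copy list-valued settings so the edits below do not modify self.template;
        # scalar settings (e.g. the default step of 10) are expanded to lists further down
        patch_size = self.template['patch']
        step = self.template['step']
        if isinstance(patch_size, list):
            patch_size = patch_size[:]
        if isinstance(step, list):
            step = step[:]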

        
        im_name=Path(fname).with_suffix('').__str__()
        #im_name = fname.split('.')[:-2][0] #Kind of janky way to just strip away the suffix
        
        if slice_select: #If slice_select is provided, then getting rid of blanks really screws things up
            print('keeping blank')
            keep_blank = True

        if type(patch_size) != list:
            patch_size = [patch_size for _ in range(len(dim))]
        
        #If they provide a patch size of -1 along a dimension, use the size from dim
        for idx, i in enumerate(patch_size): 
            if i == -1:
                patch_size[idx] = dim[idx]
        
        #TODO: add option of step size is 0 for full image patch size
        if type(step) != list:
            step = [step for _ in range(len(dim))]

        # Whether to shrink the patch size and step size down by the scaling amount for LR images without the same dimensions as the HR images
        if not same_size:
            try: 
                patch_size = [math.floor(x) for x in np.divide(patch_size,self.template['resolution'])]
            except: 
                raise ValueError(f'Resolution change coefficient: {self.template["resolution"]} not defined properly for patch_size: {patch_size}')

            try: 
                step = [math.floor(x) for x in np.divide(step,self.template['resolution'])]
            except:
                raise ValueError(f'Resolution change coefficient: {self.template["resolution"]} not defined properly for step: {step}')
        else:
            if verbose:
                print(f'patch size = {patch_size}')
                print(f'step size = {step}')

        # Create a numpy stack following Pytorch protocols, so 1 dimension more than patch
        
        # Count number of non-zero entries
        cnt = 0
        blank = 0
        not_blank = []
        itter = -1

        # Get total number of patches that will be created:
        #patch_count = np.prod([len(range(0,i,step[idx])) for idx, i in enumerate(dim)])
        if verbose:
            print(f'patch guess = {np.prod([math.floor((i-patch_size[idx])/step[idx])+1 for idx,i in enumerate(dim)])}')
        patch_count = np.prod([math.floor((i-patch_size[idx])/step[idx])+1 for idx,i in enumerate(dim)])
        patch_vol = math.prod(patch_size)*self.template['blank_ratio']

        if len(dim) == 2:
            stack = np.zeros((patch_count,patch_size[0],patch_size[1]))
            if verbose:
                print(f'stack size = {stack.shape}')

            for i in range(0,dim[0],step[0]):
                for j in range(0,dim[1],step[1]):
                    if i+patch_size[0] <= dim[0] and j+patch_size[1] <= dim[1]:
                        itter = itter+1 #just a calculator for finding when blanks occur
                        samp = im_h[i:i+patch_size[0],j:j+patch_size[1]]

                        if keep_blank or (samp==0).sum() <= patch_vol:#(samp.max() > 0):
                            stack[cnt,:,:] = samp
                            cnt += 1
                            not_blank.append(itter)
                        else:
                            blank += 1
                            #blank.append(_)
        elif len(dim) == 3:
            stack = np.zeros((patch_count,patch_size[0],patch_size[1], patch_size[2]))
            if verbose:
                print(f'stack size = {stack.shape}')

            for i in range(0,dim[0],step[0]):
                for j in range(0,dim[1],step[1]):
                    for k in range(0,dim[2],step[2]):
                        #itter = itter+1 #just a calculator for finding when blanks occur
                        if i+patch_size[0] <= dim[0] and j+patch_size[1] <= dim[1] and k+patch_size[2] <= dim[2]:
                            itter = itter+1
                            samp = im_h[i:i+patch_size[0],j:j+patch_size[1], k:k+patch_size[2]]

                            if keep_blank or (samp==0).sum() <= patch_vol:#(samp.max() > 0):
                                stack[cnt,:,:,:] = samp
                                cnt += 1
                                not_blank.append(itter)
                            else:
                                blank += 1
                                #blank.append(_)
            if verbose:
                print(itter)
        else:
            raise IndexError(f'Images of dimension {dim} not supported by this method. Only 2D and 3D data accepted.')
        

        #TODO: There MUST be a better way to organize this whole mess, lol

        fnames = []
        if slice_select:
            for i in range(len(slice_select)):
                fnames.append(f'{im_name}_{i}.{self.template["out_type"]}')
        else:
            for i in range(cnt):
                fnames.append(f'{im_name}_{i}.{self.template["out_type"]}')

        if save:
            if slice_select:
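                for idx, i in tqdm(enumerate(slice_select)):
                    # save_image expects (fname, im, form, verbose)
                    self.save_image(fnames[idx], stack[i], self.template['out_type'], verbose)
            else:
                for idx, i in tqdm(enumerate(fnames)):
                    self.save_image(i, stack[idx], self.template['out_type'], verbose)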
            if sanity_check:
                print(f'Number of patches: {len(not_blank)}')
                print(f'Number of blank patches: {blank}')
                return fnames, not_blank
            else:
                return fnames
        else:
            if sanity_check:
                return fnames, stack, not_blank
            else:
                return fnames, stack
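
The patch count used by `img2patches` is `floor((dim - patch) / step) + 1` per axis, multiplied across axes. As a quick check of that formula, the sketch below compares it with the number of patches produced by NumPy's `sliding_window_view` (available from NumPy 1.20). `count_patches` and `extract_patches` are hypothetical helper names, and the blank-patch filtering is left out.

```python
import math
from typing import Sequence
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def count_patches(dim: Sequence[int], patch: Sequence[int], step: Sequence[int]) -> int:
    '''Number of full patches that fit when sliding `patch` over `dim` by `step`.'''
    return math.prod((d - p) // s + 1 for d, p, s in zip(dim, patch, step))

def extract_patches(im: np.ndarray, patch: Sequence[int], step: Sequence[int]) -> np.ndarray:
    '''Return all full patches stacked along a new first axis.'''
    windows = sliding_window_view(im, tuple(patch))
    # Keep every `step`-th window along each image axis.
    strided = windows[tuple(slice(None, None, s) for s in step)]
    return strided.reshape(-1, *patch)

im = np.arange(10 * 7).reshape(10, 7)
patches = extract_patches(im, patch=[4, 3], step=[3, 2])
print(patches.shape)
print(count_patches(im.shape, [4, 3], [3, 2]))
```

The same two helpers work unchanged for 3D volumes, since the slicing is built from the length of `step`.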



In [3]:
type([2])

list

### ImgDL

In [None]:
# ImgAug already inherits from ImgIE; listing ImgIE first would raise an MRO TypeError
class ImgDL(ImgAug):
    def __init__(self) -> None:
        pass

### ImgMet

In [None]:
class ImgMet():
    '''Class for containing different performance metrics'''
    def __init__(self) -> None:
        pass

    def PSNR_calc(self) -> float:
        pass

    def SSIM_calc(self) -> float:
        pass
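
Two of the metrics named earlier, PSNR and Dice, are short enough to sketch directly in NumPy. The function names `psnr` and `dice`, the `data_range` default of 255, and the convention that two empty masks score 1.0 are assumptions for illustration. SSIM is left out because it needs windowed statistics; `skimage.metrics.structural_similarity` is one existing implementation.

```python
import numpy as np

def psnr(reference: np.ndarray, test: np.ndarray, data_range: float = 255.0) -> float:
    '''Peak signal-to-noise ratio in dB; infinite when the images are identical.'''
    mse = np.mean((reference.astype(np.float64) - test.astype(np.float64)) ** 2)
    if mse == 0:
        return float('inf')
    return float(10.0 * np.log10(data_range ** 2 / mse))

def dice(mask_a: np.ndarray, mask_b: np.ndarray) -> float:
    '''Dice coefficient of two binary masks: 2*|A and B| / (|A| + |B|).'''
    a = mask_a.astype(bool)
    b = mask_b.astype(bool)
    total = a.sum() + b.sum()
    if total == 0:
        return 1.0  # Assumption: two empty masks count as a perfect match.
    return float(2.0 * np.logical_and(a, b).sum() / total)

ref = np.full((4, 4), 100, dtype=np.uint8)
noisy = ref.copy()
noisy[0, 0] = 110  # One pixel off by 10, so MSE = 100 / 16
print(round(psnr(ref, noisy), 2))
print(dice(np.array([1, 1, 0, 0]), np.array([1, 0, 0, 0])))
```

Casting to `float64` before subtracting avoids the wrap-around that `uint8` arithmetic would otherwise introduce.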

## U2Net Implementation
This serves as a test run for using this collection of classes on a real problem set.
*Note: Task05 has two channels in its images*

### Outline for how to read in the unique file organization that U2Net has:
The model requires the use of 3 different datasets, each kept in their own directory
- Need to match both the original image and the label file together so they can be loaded at the same time
    - Would be good to have a "list of lists" setup for this
- Need to be able to select randomly from only one dataset at a time
    - Ideally each epoch is from a different dataset
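
The outline above can be sketched with `pathlib` alone. This example assumes that an image and its label share the same file name in sibling folders; the folder names (`task_a`, `images`, `labels`) and the helper `match_pairs` are made up for illustration, and a temporary directory stands in for the real datasets.

```python
import random
import tempfile
from pathlib import Path
from typing import Dict, List

def match_pairs(img_dir: Path, label_dir: Path) -> List[List[Path]]:
    '''Pair each image with the label file that has the same file name.'''
    labels = {p.name: p for p in label_dir.iterdir() if p.is_file()}
    pairs = []
    for img in sorted(img_dir.iterdir()):
        if img.name in labels:
            pairs.append([img, labels[img.name]])
    return pairs

# Build a throwaway directory layout so the example runs anywhere.
root = Path(tempfile.mkdtemp())
datasets: Dict[str, List[List[Path]]] = {}
for task in ['task_a', 'task_b']:
    for sub in ['images', 'labels']:
        (root / task / sub).mkdir(parents=True)
    for i in range(3):
        (root / task / 'images' / f'case_{i}.nii.gz').touch()
        (root / task / 'labels' / f'case_{i}.nii.gz').touch()
    datasets[task] = match_pairs(root / task / 'images', root / task / 'labels')

# Draw from a single dataset per epoch, cycling through the datasets in turn.
rng = random.Random(0)
names = sorted(datasets)
for epoch in range(4):
    task = names[epoch % len(names)]
    img, label = rng.choice(datasets[task])
    print(epoch, task, img.name, label.name)
```

Keeping one list of `[image, label]` pairs per dataset makes it straightforward to restrict an epoch to a single dataset, since the choice of dataset and the choice of sample are separate steps.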

In [None]:
## Rough outline

# ImgAug already inherits from ImgIE; listing ImgIE first would raise an MRO TypeError
class UData(ImgAug):
    '''Class for data management for U^2-Net training and testing
    
    Parameters
    ----------
    - path pairs for the folders containing the raw images and the labels
        [[/img_1, /label_1],[/img_2, /label_2]]
    - output directory for generated images (after patches/augmentation is applied)

    '''
    def __init__(self) -> None:
        pass

    def get_files(self) -> None:
        pass

    def match_files(self) -> None:
        pass

    def run(self) -> None:
        pass

## SrGen recreation
Using the defined classes above, is it possible to recreate the SrGen class?