# Data Processing

This notebook is for experimenting with the creation of a class which loads/organizes/saves images used for segmentation

## Import List:

In [1]:
from typing import List, Iterable, Tuple, Any, Union, Generator, Dict
import numpy as np
from PIL import Image
import os
import nibabel as nib
from pathlib import Path
import pydicom
import cv2

## Base Loading and Saving Images Class
While the particulars of each segmentation problem are different, the loading/augmentation/saving steps are largely the same, so they are collected into a set of base classes.

Note: Look at https://github.com/MIC-DKFZ/batchgenerators for list of different types of image editing

The general breakdown that I am thinking of is:

1. **ImgIE** : "Image Import Export", this class will contain the methods for loading the images into numpy/tensors and exporting numpy/tensors back to images. This should also have methods for searching through files and storing them.
2. **ImgAug** : "Image Augmentation", this class will handle any of the different image augmentations that one might think of doing for 2D and 3D images
3. **ImgDL** : "Image Dataloader", this class will function to make a Pytorch DataLoader that is easy to use for training
4. **ImgMet** : "Image Metrics", this class will contain methods for calculating common model performance metrics (SSIM, PSNR, Dice, etc.)

So the overall use of these classes will be through wrappers around them for loading and processing for the particular problem

Also, as an exercise/building good habits, the classes will include typing for all variables
https://peps.python.org/pep-0484/#acceptable-type-hints

# Formatting Rules:
1. Data is a first-class citizen: The raw data/numerical values should always be at the beginning of any list of arguments. For example, the numpy array being saved or augmented should always be the first argument.
2. Methods should, when possible, only return 1 to 2 values. If multiple variables are returned, then there is a good chance the method can be broken into smaller pieces
    a. Methods should try to be consistent in the typing of their returned values, returning `None` or raising errors when a method isn't applicable
3. 

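A minimal sketch of rules 1 and 2, using a hypothetical `flip_array` helper that is not part of the classes in this notebook: the array comes first in the argument list, two values of consistent type are returned, and an error is raised when the method isn't applicable.

```python
from typing import Tuple

import numpy as np


def flip_array(im: np.ndarray, axis: int = 0) -> Tuple[np.ndarray, str]:
    """Hypothetical helper: flip `im` along `axis`, return the result and a label."""
    if not -im.ndim <= axis < im.ndim:
        # Rule 2a: raise instead of returning a differently typed value
        raise IndexError(f'Axis {axis} is out of range for an array with shape {im.shape}')
    # Rule 1: the data (`im`) is the first argument; Rule 2: exactly two return values
    return np.flip(im, axis=axis), f'_flip{axis}'


flipped, label = flip_array(np.arange(6).reshape(2, 3), axis=1)
print(flipped.tolist(), label)
```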
### ImgIE
Note the use of the `Path` class for file paths, which makes the code more explicit.
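One thing to keep in mind with `Path` is how it treats compound extensions such as `.nii.gz`: `suffix` only reports the last extension and `stem` keeps the `.nii` part. The sketch below shows this behaviour along with a hypothetical `split_name` helper (not part of `ImgIE`) that separates the base name from the full extension. The file name is a placeholder.

```python
from pathlib import Path
from typing import Tuple


def split_name(path: Path) -> Tuple[str, str]:
    """Hypothetical helper: split a file name into (base, full extension)."""
    ext = ''.join(path.suffixes)  # e.g. '.nii.gz' for a compressed NIfTI file
    base = path.name[:len(path.name) - len(ext)] if ext else path.name
    return base, ext


p = Path('scans') / 'heart_001.nii.gz'  # placeholder file name
print(p.suffix)       # '.gz'
print(p.stem)         # 'heart_001.nii'
print(split_name(p))  # ('heart_001', '.nii.gz')
```

This simple approach treats every dot-separated piece as part of the extension, so a name such as `scan.v2.nii.gz` would need extra handling.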

In [53]:
class ImgIE():
    '''Class for the loading of images into numpy arrays and the saving of numpy arrays into images.
    Also handles rudimentary processing of the images.'''
    def __init__(self) -> None:
        '''Test this out'''
        pass

    def load_image(self, im_path: Path, unit: str='intensity', verbose: bool=False) -> np.ndarray:
        # Given an image path, determines the function required to load the contents
        # as a numpy array, which is returned.
        fil_typ = os.path.splitext(im_path)[1]

        if fil_typ == '.png':
            # If file is a png
            with Image.open(im_path) as f_im:
                img = np.array(f_im)

            if verbose:
                print(f'Loading {im_path} as png')
                print(f'Image shape:{img.shape}')

            if unit == 'intensity':
                if len(img.shape)==3:
                    img = self.rgba2ycbcr(img)
                    img = img[:,:,0] #Just deal with intensity values at the moment because 
                                    # having multiple channels throws off cv2 when saving, 
                                    # since it also does BGR instead of RGB and will save a blue image
                elif len(img.shape)==2:
                    pass            # If the png is just greyscale, then there is nothing that can
                                    # be done except take the single channel
                else:
                    raise ImportError("Provided png image is not 2 or 3 dimensional, something is wrong with the image.")

            elif unit == 'color':
                raise NotImplementedError("""Loading and creation of patches from color png images is
                currently not supported. Please use unit='intensity' for conversion of png
                images to greyscale intensity images.""")

        elif fil_typ == '.nii' or fil_typ == '.gz':
            img = nib.load(im_path).get_fdata()
            if verbose:
                print(f'Loading {im_path} as nii')
                print(f'Image shape:{img.shape}')

        elif fil_typ == '.dcm':
            img = pydicom.dcmread(im_path).pixel_array
            if verbose:
                print(f'Loading {im_path} as dicom')
                print(f'Image shape:{img.shape}')

        else:
            raise FileNotFoundError(f'Image file type {fil_typ} not supported.')

        return img


    def load_png(self, im_path: Path, unit: str='raw') -> np.ndarray:
        '''Load png image
        
        Parameters
        ----------
        im_path : Path
            Path to the file you wish to load.
        
        unit : str
            What the unit of the given image should be. Current options are:
            - `'raw'` : the raw RGBA values stored in the image (Default)
            - `'luminance'` :  the intensity channel from converting the RGB image to YCbCr.
            This will result in the image only having one channel

        Returns
        -------
        img : float ndarray
            The loaded image as a numpy array

        '''
        with Image.open(im_path) as f_im:
            if unit == 'raw':
                return np.array(f_im)
            if unit == 'luminance':
                img = np.array(f_im)
                
                if len(img.shape) == 3:
                    # convert to YCbCr then take first channel (luminance)
                    return self.rgba2ycbcr(img)[:,:,0]
                
                elif len(img.shape) == 2:
                    return img
                
                else:
                    raise ImportError("Provided png image is not 2 or 3 dimensional, something is wrong with the image.")
        
            # If the unit type is not supported
            raise NotImplementedError(f'Loading of png using unit value {unit} is currently not supported.')


    def load_jpg(self, im_path: Path, unit: str='raw') -> np.ndarray:
        '''Load jpg image

        '''
        pass

    def load_nifti(self, im_path: Path) -> np.ndarray:
        '''Load nifti file from provided path

        Parameters
        ----------
        im_path : Path
            Path to the file you wish to load.
        
        Returns
        -------
        img : float ndarray
            The loaded image as a numpy array
        
        '''
        return nib.load(im_path).get_fdata()

    def load_dicom(self, im_path: Path) -> np.ndarray:
        '''Load DICOM file from provided path

        Parameters
        ----------
        im_path : Path
            Path to the file you wish to load.
        
        Returns
        -------
        img : float ndarray
            The loaded image as a numpy array
        
        '''
        return pydicom.dcmread(im_path).pixel_array

    def rgba2ycbcr(self, img_rgba: np.ndarray) -> np.ndarray:
        '''Takes an RGB(A) image and returns it as a YCbCr image 

        Parameters:
        ----------
        img_rgba : ndarray
            The RGB or RGBA image (height x width x channels) which you want to convert to YCbCr

        Returns:
        --------
        img_ycbcr : float ndarray
            The converted image

        '''
        # A colour image is a 3D array (height, width, channels) with 3 or 4 channels.
        # The previous check (len(shape) != 4) rejected every valid image.
        if img_rgba.ndim != 3 or img_rgba.shape[2] not in (3, 4):
            raise ValueError('Input image is not RGB(A)')

        # OpenCV returns the channels in Y, Cr, Cb order; swap the last two to get YCbCr
        img_ycrcb = cv2.cvtColor(img_rgba, cv2.COLOR_RGB2YCR_CB)
        img_ycbcr = img_ycrcb[:,:,(0,2,1)].astype(np.float32)
        img_ycbcr[:,:,0] = (img_ycbcr[:,:,0]*(235-16)+16)/255.0
        img_ycbcr[:,:,1:] = (img_ycbcr[:,:,1:]*(240-16)+16)/255.0

        return img_ycbcr

    
    def save_image(self, im: np.ndarray, fname: Path, verbose: bool = False) -> Path:
        # Take a given image and save it as the specified format. The use of a Path variable type
        # is intentional, as there should be less chance of incorrect entries
        # fname = output name of the saved file, including the suffix
        # im = numpy array of image
        #
        # returns fname, so it can be easily appended to a path if necessary

        dim = im.shape #Get number of dimensions of image

        if fname.suffix == '.png':
            # Check that you aren't saving a 3D image
            #TODO: Scale inputs to [0,255] so data isn't lost/image isn't saturated
            cv2.imwrite(f'{fname}',im)
            if verbose:
                print(f'Saving: {fname}')
        elif fname.suffix == '.nii' or fname.suffix == '.gz':

            # TODO: Add option to transpose image for some reason because mricron hates the first dim[0] = 1
            # Still gets loaded fine in terms of loading into python, but visualizing it is bad
            # np.transpose(im, (1,2,0))


            # TODO: If image is 2D then append a third  dimension before saving(?)
            # NIfTI affines are always 4x4, regardless of the number of image dimensions
            nib.save(nib.Nifti1Image(im, np.eye(4)), fname)
            if verbose:
                print(f'Saving: {fname}')
        elif fname.suffix == '.dcm':
            raise NotImplementedError('DICOM saving currently not supported')
        else:
            raise NotImplementedError(f'Specified file type {fname.suffix} for {fname} is currently not supported for saving')

        return fname

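For reference, the luminance extraction used by the `'luminance'` / `'intensity'` options can also be sketched directly in NumPy. This is an illustration using the full-range BT.601 weights, not a drop-in replacement: `rgba2ycbcr` above goes through OpenCV and applies an additional rescaling, so the numbers will differ. `rgb_to_luma` is a hypothetical name.

```python
import numpy as np


def rgb_to_luma(img: np.ndarray) -> np.ndarray:
    """Hypothetical helper: return a single luminance channel as float32."""
    if img.ndim == 2:
        # Greyscale input already has a single channel
        return img.astype(np.float32)
    if img.ndim != 3 or img.shape[2] not in (3, 4):
        raise ValueError(f'Expected a 2D or (H, W, 3|4) array, got shape {img.shape}')

    # BT.601 luma weights; any alpha channel is ignored
    weights = np.array([0.299, 0.587, 0.114], dtype=np.float32)
    return img[:, :, :3].astype(np.float32) @ weights


rgba = np.full((2, 2, 4), 255, dtype=np.uint8)  # opaque white test image
print(rgb_to_luma(rgba).shape)
```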
### ImgAug

The idea behind the majority of these methods is that they intake a numpy array and some parameters and output a numpy array and string indicating what has changed.
Is it better to assign the numpy array being worked on to an attribute of the class?
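One consequence of the "array in, array and label out" convention is that augmentations can be chained while the labels accumulate into a file-name suffix. The sketch below illustrates this with hypothetical `shift` and `apply_chain` helpers. `np.roll` is used as a stand-in for the translation, so values wrap around instead of being padded the way `array_translate` does.

```python
from typing import Callable, List, Tuple

import numpy as np

# Each step takes an array and returns (new array, label suffix)
AugStep = Callable[[np.ndarray], Tuple[np.ndarray, str]]


def shift(im: np.ndarray, trans: List[int]) -> Tuple[np.ndarray, str]:
    """Hypothetical stand-in for a translation: circular shift along every axis."""
    if im.ndim != len(trans):
        raise IndexError(f'Shift {trans} is not compatible with array shape {im.shape}')
    out = np.roll(im, shift=tuple(trans), axis=tuple(range(im.ndim)))
    return out, '_tr' + '_'.join(str(t) for t in trans)


def apply_chain(im: np.ndarray, steps: List[AugStep]) -> Tuple[np.ndarray, str]:
    """Apply each step in order, concatenating the labels."""
    label = ''
    for step in steps:
        im, suffix = step(im)
        label += suffix
    return im, label


im = np.arange(9).reshape(3, 3)
out, label = apply_chain(im, [lambda a: shift(a, [1, 0]), lambda a: shift(a, [0, -1])])
print(label)  # _tr1_0_tr0_-1
```

Keeping the methods stateless like this, rather than storing the array on the class, makes it straightforward to apply the same parameters to an image and its label one after the other.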

In [54]:
from skimage.transform import rotate, AffineTransform, warp, rescale, resize
import math

class ImgAug(ImgIE):
    def __init__(self) -> None:
        #super(ImgIE, self).__init__()
        # super().__init__()
        pass

    # def aug_run(self, inp: np.ndarray, aug_key: "OrderedDict[str, Any]", randomize: bool=False) -> None:
    #     '''Run provided augmentation using the settings defined by the template dictionary.
    #     Possibly work storing the various methods into a dictionary
    #     https://stackoverflow.com/questions/9168340/using-a-dictionary-to-select-function-to-execute

    #     '''

    #     function_key = {'translation': self.array_translate,
    #     'rotation': self.array_rotate,
    #     'scale': self.array_scale,
    #     'patch': self.img2patches}


    #     for func, kwrd in aug_key:
    #         if func == 'translation':
    #             inp, label = function_key[func](inp,**kwrd)
    #         elif func == 'rotation':
    #             inp, label = function_key[func](inp,**kwrd)
    #         elif func == 'scale':
    #             inp, label = function_key[func](inp,**kwrd)
    #         elif func == 'patch':
    #             inp, label,  = function_key[func](inp,**kwrd)
    #     pass


    def gen_random_aug(self, params: Dict[str,List[int]], float_params: bool=False, negative=True) -> Generator[Dict[str,List[int]],None, None]:
        while True:
            out_params = {}
            if float_params:
                for ky, i in params.items():
                    out_params[ky] = [np.random.uniform(-k*negative, k) if k!=0 else k for k in i]
            else:
                for ky, i in params.items():
                    out_params[ky] = [np.random.randint(-k*negative,k) if k!=0 else k for k in i]

            yield out_params


    def array_translate(self, im: np.ndarray, trans: List[int], mode: str='symmetric') -> "tuple[np.ndarray, str]":
        # Translation
        dim = im.shape

        if len(dim) != len(trans):
            raise IndexError(f'Translation of numpy array with dimensions: {dim} is not compatible with translation {trans}')
        
        if len(dim) == 2:
            transform = AffineTransform(translation=(trans[0], trans[1]))
            im = warp(im, transform, mode=mode)
            label = f'_tr{trans[0]}_{trans[1]}'

        elif len(trans) == 3:
            transform = AffineTransform(translation=(trans[1], trans[2]))
            for i in range(dim[0]):
                im[i,:,:] = warp(im[i,:,:], transform, mode = mode)

            for i in range(dim[1]):
                # Because two dimensions were already translated, you only need to translate
                # along one dimension
                im[:,i,:] = warp(im[:,i,:], AffineTransform(translation=(trans[0],0)), mode=mode)
                
            label = f'_tr{trans[0]}_{trans[1]}_{trans[2]}'

        else:
            raise NotImplementedError(f"Translation of objects with dimension {len(dim)} is not currently supported.")
        
        return im, label


    def array_rotate(self, im: np.ndarray, rot: List[int], order: int=1) -> "tuple[np.ndarray, str]":
        # TODO: Issue with low resolution not necessarily having the same dimensions
        # Rotation 2D
        dim = im.shape

        if len(dim) != len(rot):
            raise IndexError(f'Rotation of numpy array with dimensions: {dim} is not compatible with rotation {rot}')

        if len(dim) == 2:
            im = rotate(im, rot[0], order=order)
            label = f'_rot{rot[0]}'

        # Rotation 3D
        elif len(dim) == 3:
            for i in range(dim[0]):
                im[i,:,:] = rotate(im[i,:,:],rot[0], order=order)
            for i in range(dim[1]):
                im[:,i,:] = rotate(im[:,i,:],rot[1], order=order)
            for i in range(dim[2]):
                im[:,:,i] = rotate(im[:,:,i],rot[2], order=order)
            label = f'_rot{rot[0]}_{rot[1]}_{rot[2]}'
        
        else:
            raise NotImplementedError(f"Rotation of objects with dimension {len(rot)} is not currently supported.")

        return im, label

    def array_scale(self, im: np.ndarray, scale: List[float], order: int=1, mode: str='symmetric', int_dims: bool=False) -> "tuple[np.ndarray, str]":
        '''Either upscales or downscales provided array

        https://scikit-image.org/docs/stable/auto_examples/transform/plot_rescale.html
        '''
        # Scaling
        dim = im.shape

        if len(dim) != len(scale):
            raise IndexError(f'Scaling of numpy array with dimensions: {dim} is not compatible with scale {scale}')

        if int_dims:
            new_dims = scale
            label = '_si_'
        else:
            # Element-wise product, so each dimension is scaled by its own factor
            # (np.matmul on two 1D inputs returns a single scalar)
            new_dims = [math.floor(x) for x in np.multiply(dim, scale)]
            label = '_sr_'

        im = resize(im, new_dims, order=order, mode=mode) #If anti-aliasing is not specified, it is set to True when downsampling an image whose data type is not bool


        if len(dim) == 2:
            label = label + f'{scale[0]}_{scale[1]}'
        elif len(dim) == 3:
            label = label + f'{scale[0]}_{scale[1]}_{scale[2]}'
        else:
            raise NotImplementedError(f"Scaling of objects with dimension {len(scale)} is not currently supported.")

        return im, label
        
    def array_degrade(self, im: np.ndarray, scale: List[float], order: int=1, mode: str='symmetric', int_dims: bool=False) -> np.ndarray:
        '''Uses array_scale to scale an array down and back up using the specified order'''

        dim = im.shape

        im, _ = self.array_scale(im, scale, order=order, mode=mode, int_dims=int_dims)

        im, _ = self.array_scale(im, dim, order=order, mode=mode, int_dims=True)

        return im

    
    def gen_noise(self) -> None:
        '''Add noise to provided image'''
        pass
    

    def img2patches(self, im: np.ndarray, patch: List[int], step: List[int], fname:str, min_nonzero: float = 0,
        slice_select: List[int] = [], save: List[str] = [], verbose=False) -> "tuple[np.ndarray, List[str], List[int]]":
        # Depending on the number of dimensions in the `patch` value, either make 2D
        # or 3D images

        '''
        Needs:
            - image: tuple of numpy arrays? That way if you want to apply the same patching to multiple different images you can?
            - min_nonzero: [float] either a flat value or a fraction of the minimum amount of input needs to be non-zero before you keep it
            - slice_select: [list of ints] a list of slices to preserve from the patches (used when pairing slices between images)
            - save_individual patches: tuple[path/filename, form] whether to save each of the patches as separate files and what format to save them as.
                                    if the tuple is empty, then just return the stack of patches and list of filenames 
            - verbose: [bool]
            - patch: [list of ints] if they input -1 as a patch size, then take the full size of that dimension

        Returns:
            - image stack
            - the paths to the image stack or file names for each of the images in the stack
            - the slice_select list

            not_blank: List[int] a list of all the entries in the stack that passed the min_nonzero quota
        '''

        # Check patch size
        dim = im.shape

        if len(dim) != len(patch):
            raise IndexError(f'Patch selection of numpy array with dimensions: {dim} is not compatible with patch size: {patch}')

        if len(dim) != len(step):
            raise IndexError(f'Patch selection step size of numpy array with dimensions: {dim} is not compatible with step size: {step}')

        # If they have input something for "slice_select", then they are trying to replicate patch selection
        if len(slice_select):
            min_nonzero = 0

        for idx, i in enumerate(patch):
            if i == -1:
                patch[idx] = dim[idx]
                step[idx] = 1
        

        # Create a numpy stack following Pytorch protocols, so 1 dimension more than patch
        
        # Count number of non-zero entries
        cnt = 0
        blank = 0
        not_blank = []
        itter = -1

        # Get total number of patches that will be created:
        #patch_count = np.prod([len(range(0,i,step[idx])) for idx, i in enumerate(dim)])
        if verbose:
            print(f'patch guess = {np.prod([math.floor((i-patch[idx])/step[idx])+1 for idx,i in enumerate(dim)])}')
        patch_count = np.prod([math.floor((i-patch[idx])/step[idx])+1 for idx,i in enumerate(dim)])
        
        if min_nonzero > 1: #If they have given pixel/voxel numbers instead of fractions
            patch_vol = math.prod(patch) - min_nonzero
        else:
            # else, calculate the number of pixels/voxels which must be nonzero
            patch_vol = math.prod(patch) - math.prod(patch)*min_nonzero


        #TODO: There MUST be a better way to organize this whole mess, lol
        if len(dim) == 2:
            stack = np.zeros((patch_count,patch[0],patch[1]))
            if verbose:
                print(f'stack size = {stack.shape}')

            for i in range(0,dim[0],step[0]):
                for j in range(0,dim[1],step[1]):
                    if i+patch[0] <= dim[0] and j+patch[1] <= dim[1]:
                        itter = itter+1 #just a calculator for finding when blanks occur
                        samp = im[i:i+patch[0],j:j+patch[1]]

                        if min_nonzero == 0 or (samp==0).sum() <= patch_vol:
                            stack[cnt,:,:] = samp
                            cnt += 1
                            not_blank.append(itter)
                        else:
                            blank += 1
    
        elif len(dim) == 3:
            stack = np.zeros((patch_count,patch[0],patch[1], patch[2]))
            if verbose:
                print(f'stack size = {stack.shape}')

            for i in range(0,dim[0],step[0]):
                for j in range(0,dim[1],step[1]):
                    for k in range(0,dim[2],step[2]):
                        #itter = itter+1 #just a calculator for finding when blanks occur
                        if i+patch[0] <= dim[0] and j+patch[1] <= dim[1] and k+patch[2] <= dim[2]:
                            itter = itter+1
                            samp = im[i:i+patch[0],j:j+patch[1], k:k+patch[2]]

                            if min_nonzero == 0 or (samp==0).sum() <= patch_vol:
                                stack[cnt,:,:,:] = samp
                                cnt += 1
                                not_blank.append(itter)
                            else:
                                blank += 1
        else:
            raise IndexError(f'Images of dimension {dim} not supported by this method. Only 2D and 3D data accepted.')
        

        fnames = []
        pnames = [] #Pnames should be used if patching is the last process you plan on doing. These are paths
        out_dir = save[0] if save else '' # `save` is empty when the patches are only returned, not written
        if slice_select:
            for i in range(len(slice_select)):
                fnames.append(f'{out_dir}{fname}_{i}')
        else:
            for i in range(cnt):
                fnames.append(f'{out_dir}{fname}_{i}')

        if verbose:
            print(f'Number of patches: {len(not_blank)}')
            print(f'Number of blank patches: {blank}')


        if save: #If they want to save the intermediate files (return a list of paths instead)
            if slice_select:
                for idx, i in enumerate(slice_select):
                    pnames.append(self.save_image(stack[i], Path(fnames[idx]+save[1]), verbose=verbose))
            else:
                for idx, i in enumerate(fnames):
                    pnames.append(self.save_image(stack[idx], Path(fnames[idx]+save[1]), verbose=verbose))

            return np.array([]), pnames, not_blank # Send back not_blank for some comparison tests between running of this
        
        else:
            if slice_select:
                return stack[slice_select], fnames, [] #Only return the patches which match slice_select

            else:
                return stack[:cnt], fnames, not_blank  #Only return a stack of the patches that passed the min_nonzero check



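The nested loops in `img2patches` visit every position where a full patch fits, which gives `floor((dim - patch) / step) + 1` patches per axis. As a point of comparison, the same selection can be sketched for any number of dimensions with `numpy.lib.stride_tricks.sliding_window_view` (available in NumPy 1.20 and later). `extract_patches` is a hypothetical helper and, as an assumption to keep it short, only accepts `min_nonzero` as a fraction and does not handle saving or `slice_select`.

```python
import math
from typing import List, Tuple

import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def extract_patches(im: np.ndarray, patch: List[int], step: List[int],
                    min_nonzero: float = 0.0) -> Tuple[np.ndarray, List[int]]:
    """Hypothetical helper: return (stack of kept patches, indices of kept patches)."""
    if im.ndim != len(patch) or im.ndim != len(step):
        raise IndexError(f'patch {patch} / step {step} do not match array shape {im.shape}')

    # All windows at stride 1: shape (n_0, ..., n_k, patch_0, ..., patch_k)
    windows = sliding_window_view(im, tuple(patch))
    # Keep every step-th window along each of the leading (position) axes
    windows = windows[tuple(slice(None, None, s) for s in step)]
    # Flatten the position axes into one stack axis, in the same order as nested loops
    stack = windows.reshape(-1, *patch)

    # Fraction of nonzero pixels/voxels in every patch
    nonzero_frac = np.count_nonzero(stack.reshape(len(stack), -1), axis=1) / math.prod(patch)
    keep = np.flatnonzero(nonzero_frac >= min_nonzero)
    return stack[keep], keep.tolist()


im = np.zeros((4, 6))
im[:2, :2] = 1
stack, kept = extract_patches(im, [2, 2], [2, 2], min_nonzero=0.5)
expected_total = math.prod(math.floor((d - p) / s) + 1 for d, p, s in zip(im.shape, [2, 2], [2, 2]))
print(expected_total, stack.shape, kept)
```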
### ImgDL

In [None]:
class ImgDL(ImgAug):
    def __init__(self) -> None:
        pass

### ImgMet

In [None]:
class ImgMet():
    '''Class for containing different performance metrics'''
    def __init__(self) -> None:
        pass

    def PSNR_calc(self) -> float:
        pass

    def SSIM_calc(self) -> float:
        pass

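The metric methods above are still stubs. As a rough sketch of what two of the metrics listed earlier (PSNR and Dice) could look like in plain NumPy, under the assumptions that `data_range` is passed explicitly and that two empty masks count as perfect agreement (`psnr` and `dice` are hypothetical function names):

```python
import numpy as np


def psnr(ref: np.ndarray, test: np.ndarray, data_range: float) -> float:
    """Peak signal-to-noise ratio in dB; `data_range` is the maximum possible value span."""
    mse = np.mean((ref.astype(np.float64) - test.astype(np.float64)) ** 2)
    if mse == 0:
        return float('inf')  # identical images
    return float(10.0 * np.log10(data_range ** 2 / mse))


def dice(pred: np.ndarray, target: np.ndarray) -> float:
    """Dice coefficient between two binary masks (nonzero values count as foreground)."""
    pred = pred.astype(bool)
    target = target.astype(bool)
    denom = pred.sum() + target.sum()
    if denom == 0:
        return 1.0  # assumption: two empty masks agree perfectly
    return float(2.0 * np.logical_and(pred, target).sum() / denom)


print(psnr(np.zeros((4, 4)), np.ones((4, 4)), data_range=1.0))
print(dice(np.array([1, 1, 0, 0]), np.array([1, 0, 0, 0])))
```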
## U2Net Implementation
This serves as a test run of this collection of classes on a real problem set.
*Note: Task05 has two channels in its images*

#### Outline for how to read in the unique file organization that U2Net has:
The model requires the use of 3 different datasets, each kept in their own directory
- Need to match both the original image and the label file together so they can be loaded at the same time
    - Would be good to have a "list of lists" setup for this
- Need to be able to select randomly from only one dataset at a time
    - Ideally each epoch is from a different dataset
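The outline above can be sketched roughly as follows. `pair_files` and `epoch_schedule` are hypothetical helpers, the file names are placeholders, and the assumption is that an image and its label share the same file name in their respective directories.

```python
import random
import tempfile
from pathlib import Path
from typing import Dict, List, Tuple

Pair = Tuple[Path, Path]


def pair_files(img_dir: Path, label_dir: Path) -> List[Pair]:
    """Hypothetical helper: pair images and labels that share a file name."""
    img_names = {p.name for p in img_dir.iterdir() if p.is_file()}
    label_names = {p.name for p in label_dir.iterdir() if p.is_file()}
    # Sorted so the pairing order is reproducible between runs
    return [(img_dir / n, label_dir / n) for n in sorted(img_names & label_names)]


def epoch_schedule(datasets: Dict[str, List[Pair]], n_epochs: int, seed: int = 0) -> List[str]:
    """Hypothetical helper: pick exactly one dataset per epoch, cycling through all of them."""
    rng = random.Random(seed)
    names = sorted(datasets)
    schedule: List[str] = []
    while len(schedule) < n_epochs:
        rng.shuffle(names)      # new random order for every full cycle
        schedule.extend(names)
    return schedule[:n_epochs]


with tempfile.TemporaryDirectory() as tmp:
    img_dir, label_dir = Path(tmp) / 'imagesTr', Path(tmp) / 'labelsTr'
    img_dir.mkdir()
    label_dir.mkdir()
    for name in ('case_001.nii', 'case_002.nii'):  # placeholder file names
        (img_dir / name).touch()
    (label_dir / 'case_001.nii').touch()           # only one image has a label
    pairs = pair_files(img_dir, label_dir)
    print([(i.name, l.name) for i, l in pairs])

print(epoch_schedule({'heart': [], 'liver': [], 'spleen': []}, n_epochs=5))
```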

In [None]:
## Rough outline
import shutil

class UData(ImgAug):
    '''Class for data management for U^2-Net training and testing
    
    Parameters
    -------
    - path pairs for the folders containing the raw images and the labels
        [[/img_1, /label_1],[/img_2, /label_2]]
    - output directory for generated images (after patches/augmentation is applied)

    '''
    def __init__(self, img_dir: str, label_dir: str, img_out_dir: str, label_out_dir: str, prefix: str='', suffix: str='') -> None:
        self.img_dir = img_dir
        self.label_dir = label_dir
        self.img_out_dir = img_out_dir
        self.label_out_dir = label_out_dir
        self.in_img_files, self.in_img_paths = self.get_files(img_dir, prefix, suffix)
        self.in_label_files, self.in_label_paths = self.get_files(label_dir, prefix, suffix)
        self.out_img_files = []
        self.out_label_files = []


    def get_files(self, file_dir: Union[str, List[str]], prefix: str, suffix: str) -> Tuple[List[str],List[str]]:
        files = []
        paths = []
        # Accept either a single directory or a list of directories
        # (in the case of DICOM or scattered data)
        dirs = file_dir if isinstance(file_dir, list) else [file_dir]

        for inp_dir in dirs:
            for fil in os.listdir(inp_dir):
                if fil.startswith(prefix) and fil.endswith(suffix):
                    paths.append(os.path.join(inp_dir, fil))
                    files.append(fil)

        # Only check once every directory has been scanned; checking inside the loop
        # raised as soon as the first file listed did not match
        if not files:
            raise FileNotFoundError('No applicable files found in input directory')

        return files, paths

    def match_files(self, img_dir: str, label_dir: str, update=False, paths=True) -> Tuple[List[Path], List[Path]]:
        # Get the files that have been generated in the output directory
        # If update is false, then just return a list of matched names, if true then
        # change the class variable values accordingly.
        hr_files = os.listdir(img_dir)
        lr_files = os.listdir(label_dir)

        # Keep only the file names present in both directories (sorted so the
        # image/label order is reproducible between runs)
        matches = sorted(set(hr_files) & set(lr_files))

        img_paths = [Path(img_dir) / name for name in matches]
        label_paths = [Path(label_dir) / name for name in matches]

        if update:
            # If you want to save these matched files as class variables
            self.out_img_files = img_paths
            self.out_label_files = label_paths
            print('Image and label file locations updated')
        
        if paths:
            return img_paths, label_paths
        
        return [], [] # returned so the annotated return type still holds when paths=False

    def load_image_pair(self, im_id: Union[int, str] ) -> Tuple[np.ndarray, np.ndarray]:
        # im_id can either be the index value or the name of the file
        
        if self.out_label_files:
            if isinstance(im_id, int):
                img_file = self.out_img_files[im_id]
                label_file = self.out_label_files[im_id]
            elif isinstance(im_id, str):
                _ = self.out_img_files.index(Path(im_id))
                img_file = self.out_img_files[_]
                label_file = self.out_label_files[_]
            else:
                raise TypeError("Invalid image identifier, please input a string or integer")

            img = self.load_image(img_file)
            lab = self.load_image(label_file)

            return img, lab
        else:
            raise ValueError("No paths for processed image/label files are stored in this class")


    def run(self, clear=False, save=False, contain_lab: bool=False, verbose=False) -> None:
        

        if clear:
            print('Clearing existing output directories')
            shutil.rmtree(self.img_out_dir, ignore_errors=True)
            shutil.rmtree(self.label_out_dir, ignore_errors=True)
            

        os.makedirs(self.img_out_dir, exist_ok=True)
        os.makedirs(self.label_out_dir, exist_ok=True)
        
        fnames_h = []
        fnames_l = []

        # match in_image_files and in_label_files

        #TODO: Come up with good way for match_files to handle multiple input directories
        self.in_img_paths, self.in_label_paths = self.match_files(self.img_dir, self.label_dir, update=False, paths=True)

        aug_params = {"translation":[10,10,10]}
        patch = [50, 50, 1]
        step = [20, 20, 2]

        rand_params_gen = self.gen_random_aug(aug_params)

        # for each image, label in in_img_files:
        out_img_files = []
        out_label_files = []

        for im_p, lab_p in zip(self.in_img_paths, self.in_label_paths):

            # generate a random parameter set
            rand_params = next(rand_params_gen)

            # Load images
            im = self.load_image(im_p)
            lab = self.load_image(lab_p)

            # apply image augmentations to pairs of images
            im, im_suf = self.array_translate(im, rand_params['translation'])
            lab, lab_suf = self.array_translate(lab, rand_params['translation'])

            # save as patches of size [x,y,z]
            
            if contain_lab: #Whether to only take patches which contain the label of interest

                fname = lab_p.stem
                _, b, not_lab = self.img2patches(lab, patch[:], step[:], min_nonzero= 0.3, fname=fname+lab_suf, save=[self.label_out_dir,'.nii'], verbose = False)
                out_label_files.extend(b)

                fname = im_p.stem
                _, a, not_img = self.img2patches(im, patch[:], step[:], fname=fname+im_suf, slice_select=not_lab, save=[self.img_out_dir,'.nii'], verbose = False)
                out_img_files.extend(a)

            else:
                fname = im_p.stem
                _, a, not_img = self.img2patches(im, patch[:], step[:], fname=fname+im_suf, save=[self.img_out_dir,'.nii'], verbose = False)
                out_img_files.extend(a)

                fname = lab_p.stem
                _, b, not_lab = self.img2patches(lab, patch[:], step[:], fname=fname+lab_suf, slice_select=not_img, save=[self.label_out_dir,'.nii'], verbose = False)
                out_label_files.extend(b)
            


        # update file locations for use with load_image_pair
        self.out_img_files = out_img_files
        self.out_label_files = out_label_files
        


test = UData('../data/U2Net/Task02_Heart/imagesTr/','../data/U2Net/Task02_Heart/labelsTr/','../data/U2Net/Task02_Heart/IMG_Patch/','../data/U2Net/Task02_Heart/Label_Patch/')

test.run(clear=True, save=True, contain_lab=True)


In [60]:
test.load_image_pair(1)[1].shape

(50, 50, 1)

## SrGen recreation
Using the defined classes above, is it possible to recreate the SrGen class?

In [None]:
import shutil
class SrGen(ImgAug):
    def __init__(self, inp_dir, HR_out_dir, LR_out_dir, prefix='', suffix='') -> None:
        self.inp_dir = inp_dir
        self.HR_out_dir = HR_out_dir
        self.HR_files = None
        self.LR_out_dir = LR_out_dir
        self.LR_files = None
        self.inp_files, self.inp_paths = self._get_inp_(prefix, suffix)

    def _get_inp_(self, prefix='', suffix='')->Tuple[List[str],List[str]]:

        files = []
        paths = []
        # Accept either a single directory or a list of directories
        # (in the case of DICOM or scattered data)
        dirs = self.inp_dir if isinstance(self.inp_dir, list) else [self.inp_dir]

        for inp_dir in dirs:
            for fil in os.listdir(inp_dir):
                if fil.startswith(prefix) and fil.endswith(suffix):
                    paths.append(os.path.join(inp_dir, fil))
                    files.append(fil)

        # Only check once every directory has been scanned; checking inside the loop
        # raised as soon as the first file listed did not match
        if not files:
            raise FileNotFoundError('No applicable files found in input directory')

        return files, paths
    
    def _get_LR_out_(self) -> List[str]:
        # get list of files in output directory and determine matching files
        return os.listdir(self.LR_out_dir)

    def _get_HR_out_(self) -> List[str]:
        return os.listdir(self.HR_out_dir)

    def match_altered(self, update=True, paths=False, sort=False):
        # Find the file names present in both the HR and LR output directories.
        # If update is True, store the matched paths on the instance.
        # If paths is True, return the matched (HR, LR) path lists.
        matches = set(self._get_HR_out_()) & set(self._get_LR_out_())

        # TODO: make sort natural so it isn't [*1.*, *10.*, *100.*, ..., *2.*, ...]
        matches = sorted(matches) if sort else list(matches)

        hr_paths = [os.path.join(self.HR_out_dir, m) for m in matches]
        lr_paths = [os.path.join(self.LR_out_dir, m) for m in matches]

        if update:
            # Save the matched files as instance attributes
            self.HR_files = hr_paths
            self.LR_files = lr_paths
            print('HR and LR file locations updated')

        if paths:
            # Return the freshly matched paths even when update is False
            return hr_paths, lr_paths
        
    def change_out(self, HR_out_dir, LR_out_dir):
        # Change the output locations so you can save into a new file
        self.HR_out_dir = HR_out_dir
        self.HR_files = None
        self.LR_out_dir = LR_out_dir
        self.LR_files = None


    def run(self, clear=False, save=False, verbose=False):
        # This method is called to generate the data

        make_lr = self.template['resolution'] is not None

        if clear:
            print('Clearing existing output directories')
            shutil.rmtree(self.HR_out_dir, ignore_errors=True)
            if make_lr:
                shutil.rmtree(self.LR_out_dir, ignore_errors=True)

        # Make sure the output directories exist whether or not they were cleared
        os.makedirs(self.HR_out_dir, exist_ok=True)
        if make_lr:
            os.makedirs(self.LR_out_dir, exist_ok=True)
        fnames_h = []
        fnames_l = []

        for im in self.inp_paths:
            im_h = self.load_image(im, verbose)
            opp, im_h = self.img_transform(im_h)

            # Prevents weird new file naming issues when input is compressed file (.nii.gz)
            im = Path(im)
            while im.suffix in {'.tar', '.gz', '.zip'}:
                im = im.with_suffix('')
            
            im = os.path.splitext(im)[0]+opp # Add transformations to file name
            im = os.path.split(im)[1]
        
            # Generate Low Resolution
            if self.template['resolution']:

                dim = im_h.shape
                # Broadcast a scalar resolution to one factor per axis; a vector that
                # already has one entry per axis passes through unchanged
                # TODO: replace transformation if statements with this
                self.template['resolution'] = [int(x) for x in np.multiply(np.ones(len(dim)), self.template['resolution'])]

                # Trim each axis of the HR image down to a multiple of its resolution factor
                for i in range(len(dim)):
                    remainder = dim[i] % self.template['resolution'][i]
                    if remainder:
                        # Not a clean scaling down: drop the trailing entries along this axis
                        keep = dim[i] - remainder
                        im_h = np.delete(im_h, list(range(keep, dim[i])), axis=i)

                im_l = self.gen_LR_img(im_h, self.template['resolution'])

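            # Create image patches and save them
            if self.template['patch'] and save:
                # Extend rather than overwrite so names from earlier images are kept
                # (assumes img2patches returns a list of file names)
                patches_h, slice_select = self.img2patches(im_h, os.path.join(self.HR_out_dir, im), save=True, sanity_check=True)
                fnames_h.extend(patches_h)
                if self.template['resolution']:
                    patches_l = self.img2patches(im_l, os.path.join(self.LR_out_dir, im), same_size=self.template['same_size'], save=True, slice_select=slice_select, sanity_check=False)
                    fnames_l.extend(patches_l)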

                    # if not _a == _:
                    #     raise FileExistsError('''WARNING: The patches for High and Low resolution do not match, this is
                    #         most likely due to resolution scaling or patches/steps not being divisible by resolution''')

            elif save:
                fname_h = os.path.join(self.HR_out_dir, f'{im}.{self.template["out_type"]}')
                self.save_image(fname_h, im_h, verbose)
                fnames_h.append(fname_h)
                if self.template['resolution']:
                    fname_l = os.path.join(self.LR_out_dir, f'{im}.{self.template["out_type"]}')
                    self.save_image(fname_l, im_l, verbose)
                    fnames_l.append(fname_l)


        self.HR_files = fnames_h
        self.LR_files = fnames_l

        print('Files processed successfully')
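
The `sort` TODO in `match_altered` (avoiding the order `[*1.*, *10.*, *100.*, ..., *2.*, ...]`) could be handled with a natural sort key. The following is only a minimal sketch: `natural_key` is a hypothetical helper that is not part of the classes above, and it assumes the file names differ mainly in their embedded integers.

```python
import re
from typing import List, Union


def natural_key(name: str) -> List[Union[int, str]]:
    # Hypothetical helper (not part of ImgIE/ImgAug): split the name into
    # alternating non-digit and digit runs, and compare digit runs as integers.
    # re.split with a capturing group always yields text at even positions and
    # digits at odd positions, so keys for any two names are comparable.
    return [int(tok) if tok.isdigit() else tok.lower()
            for tok in re.split(r'(\d+)', name)]


names = ['img_10.png', 'img_2.png', 'img_1.png', 'img_100.png']

plain_sorted = sorted(names)                     # lexicographic order
natural_sorted = sorted(names, key=natural_key)  # numeric-aware order

print(plain_sorted)
print(natural_sorted)
```

If this approach fits, `match_altered` could pass `key=natural_key` to `sorted` when `sort=True`; whether the key should be case-insensitive (as sketched here) depends on how the output files are named.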