# 3D U^2-Net: A 3D Universal U-Net for Multi-Domain Medical Image Segmentation
Test things here
https://theaisummer.com/unet-architectures/

https://github.com/patrick-kidger/torchtyping


## Setup
The Medical Segmentation Decathlon data is assumed to be organized in the folder structure shown below, relative to this folder. For the sake of this exercise, I will only be using three of the provided datasets. If you have these files in the correct folders, you should get the same (or similar) results.

```
../data/U2Net/
        Task02_Heart/
            imagesTr/
                la_003.nii.gz
                ...
                la_030.nii.gz
            imagesTs/
                la_001.nii.gz
                ...
                la_028.nii.gz
            labelsTr/
                la_003.nii.gz
                ...
                la_030.nii.gz
        
        Task04_Hippocampus/
            imagesTr/
                hippocampus_001.nii.gz
                ...
                hippocampus_394.nii.gz
            imagesTs/
                hippocampus_002.nii.gz
                ...
                hippocampus_392.nii.gz
            labelsTr/
                hippocampus_001.nii.gz
                ...
                hippocampus_394.nii.gz

        Task05_Prostate/
            imagesTr/
                prostate_00.nii.gz
                ...
                prostate_47.nii.gz
            imagesTs/
                prostate_03.nii.gz
                ...
                prostate_45.nii.gz
            labelsTr/
                prostate_00.nii.gz
                ...
                prostate_47.nii.gz


```

## Imports

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import sys
import os
import shutil
import numpy as np
sys.path.append('..') # Stupid thing Python makes you do to import from a sibling directory
from gen_utils.ImgTools import ImgAug # Custom class for image generation

## Image creation class

In [None]:
from typing import List, Iterable, Tuple, Any, Union, Generator, Dict
from pathlib import Path

class UData(ImgAug):
    '''Data management for U^2-Net training and testing.

    Collects the raw image/label files found in the input locations and,
    through ``run``, writes augmented patches of every matched image/label
    pair to the output directories.

    Parameters
    -------
    img_dir : str or list of str
        Directory (or list of directories) containing the raw images.
    label_dir : str or list of str
        Directory (or list of directories) containing the labels.
    img_out_dir : str
        Output directory for the generated image patches.
    label_out_dir : str
        Output directory for the generated label patches.
    prefix, suffix : str
        Only input files whose name starts with ``prefix`` and ends with
        ``suffix`` are collected; the empty defaults accept every file.

    Raises
    -------
    FileNotFoundError
        If an input location holds no file matching ``prefix``/``suffix``.
    '''
    # NOTE(review): ImgAug.__init__ is never called here; confirm the base
    # class does not need any initialisation of its own.
    def __init__(self, img_dir: Union[str, List[str]], label_dir: Union[str, List[str]], img_out_dir: str, label_out_dir: str, prefix: str='', suffix: str='') -> None:
        self.img_dir = img_dir
        self.label_dir = label_dir
        self.img_out_dir = img_out_dir
        self.label_out_dir = label_out_dir
        self.in_img_files, self.in_img_paths = self.get_files(img_dir, prefix, suffix)
        self.in_label_files, self.in_label_paths = self.get_files(label_dir, prefix, suffix)
        # Filled in by run() (or match_files(update=True)); read by load_image_pair().
        self.out_img_files = []
        self.out_label_files = []


    def get_files(self, file_dir: Union[str, List[str]], prefix: str, suffix: str) -> Tuple[List[str], List[str]]:
        '''List the files in ``file_dir`` whose names match ``prefix``/``suffix``.

        Parameters
        -------
        file_dir : a directory, or a list of directories (in the case of DICOM
            or scattered data)
        prefix, suffix : required start / end of the file name

        Returns
        -------
        (files, paths) : the matching file names and, in the same order, their
            paths (directory joined with file name).

        Raises
        -------
        FileNotFoundError
            If no matching file is found in any of the given directories.
        '''
        # Treat a single directory as a one-element list so both cases share a loop.
        directories = file_dir if isinstance(file_dir, list) else [file_dir]

        files = []
        paths = []
        for inp_dir in directories:
            # Sorted so the result does not depend on the OS listing order.
            for fil in sorted(os.listdir(inp_dir)):
                if fil.startswith(prefix) and fil.endswith(suffix):
                    # os.path.join works with or without a trailing separator
                    # on the directory, unlike plain string concatenation.
                    paths.append(os.path.join(inp_dir, fil))
                    files.append(fil)

        # Checked once, after everything has been scanned: checking inside the
        # loop would raise on the first non-matching file and would never
        # raise for an empty directory.
        if not files:
            raise FileNotFoundError('No applicable files found in input directory')

        return files, paths

    def match_files(self, img_dir: str, label_dir: str, update=False, paths=True) -> Tuple[List[Path], List[Path]]:
        '''Pair up the files that exist under the same name in both directories.

        Parameters
        -------
        img_dir, label_dir : directories whose listings are compared
        update : if True, store the matched paths in ``self.out_img_files``
            and ``self.out_label_files``
        paths : if True, return the matched paths; otherwise return two empty
            lists

        Returns
        -------
        (image paths, label paths), index-aligned and sorted by file name,
        or ``([], [])`` when ``paths`` is False.
        '''
        # File names present in both directories. Sorted so that the pairing
        # order is reproducible between runs (set order is not).
        matches = sorted(set(os.listdir(img_dir)) & set(os.listdir(label_dir)))

        img_paths = [Path(img_dir) / name for name in matches]
        label_paths = [Path(label_dir) / name for name in matches]

        if update:
            # If you want to save these matched files as class variables.
            # Copies are stored so the returned lists stay independent.
            self.out_img_files = list(img_paths)
            self.out_label_files = list(label_paths)
            print('Image and Lable file locations updated')

        if paths:
            return img_paths, label_paths

        return [], [] #lazy to make typing work out

    def load_image_pair(self, im_id: Union[int, str] ) -> Tuple[np.ndarray, np.ndarray]:
        '''Load one processed image together with its label.

        Parameters
        -------
        im_id : either the index into the stored output file lists, or the
            path of the image file (as a string)

        Returns
        -------
        (image, label) as returned by ``self.load_image``
        (presumably numpy arrays, per the annotation -- TODO confirm).

        Raises
        -------
        ValueError
            If no output files are stored yet, or a string ``im_id`` is not
            among the stored image paths.
        TypeError
            If ``im_id`` is neither a string nor an integer.
        '''
        if not self.out_label_files:
            raise ValueError("No paths for processed image/label files are stored in this class")

        if isinstance(im_id, str):
            index = self.out_img_files.index(Path(im_id))
        elif isinstance(im_id, (int, np.integer)):
            # numpy integers are accepted too, e.g. when iterating np.arange.
            index = im_id
        else:
            # The exception has to be raised; merely constructing it would
            # fall through to an UnboundLocalError further down.
            raise TypeError("Invalid image identifier, please input a string or integer")

        img = self.load_image(self.out_img_files[index])
        lab = self.load_image(self.out_label_files[index])

        return img, lab


    def run(self, clear=False, save=False, contain_lab: bool=False, verbose=False) -> None:
        '''Augment every matched image/label pair and write it out as patches.

        For each pair the same random translation is applied to image and
        label, both are cut into patches via ``self.img2patches`` and the file
        lists returned by those calls are stored in ``self.out_img_files`` /
        ``self.out_label_files`` for use with ``load_image_pair``.

        Parameters
        -------
        clear : delete the output directories before writing
        save : currently unused
        contain_lab : if True the label is patched first (with
            ``min_nonzero=0.3``) and its third return value is forwarded as
            ``slice_select`` when patching the image; if False the image is
            patched first and drives the label instead
        verbose : currently unused (``img2patches`` is always called with
            ``verbose=False``)
        '''
        # NOTE(review): `save` and `verbose` are accepted but never read --
        # confirm whether they should be forwarded to img2patches.
        if clear:
            print('Clearing existing output directories')
            shutil.rmtree(self.img_out_dir, ignore_errors=True)
            shutil.rmtree(self.label_out_dir, ignore_errors=True)

        os.makedirs(self.img_out_dir, exist_ok=True)
        os.makedirs(self.label_out_dir, exist_ok=True)

        # match in_image_files and in_label_files

        #TODO: Come up with good way for match_files to handle multiple input directories
        self.in_img_paths, self.in_label_paths = self.match_files(self.img_dir, self.label_dir, update=False, paths=True)

        aug_params = {"translation":[10,10,10]}
        patch = [50, 50, 1]
        step = [20, 20, 2]

        rand_params_gen = self.gen_random_aug(aug_params)

        out_img_files = []
        out_label_files = []

        for im_p, lab_p in zip(self.in_img_paths, self.in_label_paths):

            # generate a random parameter set
            rand_params = next(rand_params_gen)

            # Load images
            im = self.load_image(im_p)
            lab = self.load_image(lab_p)

            # apply the same translation to both members of the pair
            im, im_suf = self.array_translate(im, rand_params['translation'])
            lab, lab_suf = self.array_translate(lab, rand_params['translation'])

            # save as patches of size [x,y,z]; patch/step are passed as copies
            # (presumably because img2patches may modify them -- TODO confirm)
            if contain_lab: #Whether to only take patches which contain the label of interest

                _, lab_patches, not_lab = self.img2patches(lab, patch[:], step[:], min_nonzero= 0.3, fname=lab_p.stem+lab_suf, save=[self.label_out_dir,'.nii'], verbose = False)
                out_label_files.extend(lab_patches)

                _, img_patches, _ = self.img2patches(im, patch[:], step[:], fname=im_p.stem+im_suf, slice_select=not_lab, save=[self.img_out_dir,'.nii'], verbose = False)
                out_img_files.extend(img_patches)

            else:
                _, img_patches, not_img = self.img2patches(im, patch[:], step[:], fname=im_p.stem+im_suf, save=[self.img_out_dir,'.nii'], verbose = False)
                out_img_files.extend(img_patches)

                _, lab_patches, _ = self.img2patches(lab, patch[:], step[:], fname=lab_p.stem+lab_suf, slice_select=not_img, save=[self.label_out_dir,'.nii'], verbose = False)
                out_label_files.extend(lab_patches)

        # update file locations for use with load_image_pair
        self.out_img_files = out_img_files
        self.out_label_files = out_label_files

## U^2-net blocks 

In [None]:
class DoubleConv(nn.Module):
    '''Two consecutive (3x3 Conv2d -> BatchNorm2d -> ReLU) stages.

    The first stage maps ``in_ch`` to ``out_ch`` channels, the second keeps
    ``out_ch``. The 3x3 convolutions use padding=1.
    '''
    def __init__(self, in_ch: int, out_ch: int) -> None:
        super().__init__()
        stages = []
        # Same three layers twice; only the input width of the conv differs.
        for stage_in in (in_ch, out_ch):
            stages.extend([
                nn.Conv2d(stage_in, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ])
        self.conv = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(x)


class InConv(nn.Module):
    '''Thin wrapper around a single DoubleConv from ``in_ch`` to ``out_ch``.'''
    def __init__(self, in_ch: int, out_ch: int) -> None:
        super().__init__()
        self.conv = DoubleConv(in_ch, out_ch)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(x)


class Down(nn.Module):
    '''Encoder step: MaxPool2d(2) followed by a DoubleConv(in_ch, out_ch).'''
    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        pool = nn.MaxPool2d(2)
        double_conv = DoubleConv(in_ch, out_ch)
        self.mpconv = nn.Sequential(pool, double_conv)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.mpconv(x)


class Up(nn.Module):
    '''Decoder step: upsample ``x1``, pad it to ``x2``'s size, concatenate, DoubleConv.

    Parameters
    -------
    in_ch : channel count of the concatenated tensor fed to the DoubleConv
    out_ch : channel count produced by the DoubleConv
    bilinear : if True upsample with bilinear interpolation (scale factor 2),
        otherwise with a stride-2 ConvTranspose2d over ``in_ch // 2`` channels
    '''
    def __init__(self, in_ch: int, out_ch: int, bilinear: bool=True) -> None:
        super().__init__()

        half_ch = in_ch // 2
        self.up = (
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            if bilinear
            else nn.ConvTranspose2d(half_ch, half_ch, 2, stride=2)
        )
        self.conv = DoubleConv(in_ch, out_ch)

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        upsampled = self.up(x1)

        # Size difference between the skip tensor and the upsampled tensor
        # along dims 2 and 3.
        gap_h = x2.size(2) - upsampled.size(2)
        gap_w = x2.size(3) - upsampled.size(3)
        before_w = gap_w // 2
        before_h = gap_h // 2

        # F.pad takes the last dimension first: (left, right, top, bottom).
        upsampled = F.pad(upsampled, (before_w, gap_w - before_w,
                                      before_h, gap_h - before_h))
        merged = torch.cat([x2, upsampled], dim=1)
        return self.conv(merged)


class OutConv(nn.Module):
    '''Output head: a single 1x1 Conv2d mapping ``in_ch`` to ``out_ch`` channels.'''
    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(x)

class Unet(nn.Module):
    '''2D U-Net built from InConv, four Down steps, four Up steps and OutConv.

    Parameters
    -------
    in_channels : channel count of the input tensor
    classes : channel count of the output tensor
    '''
    def __init__(self, in_channels: int, classes: int) -> None:
        super().__init__()
        self.n_channels = in_channels
        self.n_classes = classes

        # Encoder
        self.inc = InConv(in_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 512)
        # Decoder: each Up input width is the upsampled tensor plus its skip
        self.up1 = Up(1024, 256)
        self.up2 = Up(512, 128)
        self.up3 = Up(256, 64)
        self.up4 = Up(128, 64)
        self.outc = OutConv(64, classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Run the encoder, remembering every level for the skip connections.
        features = [self.inc(x)]
        for down in (self.down1, self.down2, self.down3, self.down4):
            features.append(down(features[-1]))

        # Start from the deepest level and merge the skips back in, deepest first.
        out = features.pop()
        for up in (self.up1, self.up2, self.up3, self.up4):
            out = up(out, features.pop())

        return self.outc(out)

In [None]:
# U2-Net steps

class U2Net3D(nn.Module):
    '''Work-in-progress 3D U^2-Net.

    Only the task-specific input layers and the encoder blocks are built so
    far; ``forward`` is still a placeholder that returns its input unchanged.

    Parameters
    -------
    inChans_list : input channel count per task (defaults to ``[2]``)
    base_outChans : channel count produced by the input layers; level ``i`` of
        the encoder uses ``base_outChans * 2**i`` output channels
    num_class_list : class count per task (defaults to ``[4]``); its length
        sets the number of tasks
    depth : number of encoder levels
    '''
    def __init__(self, inChans_list: Union[List[int], None]=None, base_outChans: int=16, num_class_list: Union[List[int], None]=None, depth: int=5):
        super().__init__()
        # Build the defaults per call: list literals in the signature would be
        # shared by every instance.
        inChans_list = [2] if inChans_list is None else list(inChans_list)
        num_class_list = [4] if num_class_list is None else list(num_class_list)

        self.depth = depth

        nb_tasks = len(num_class_list)

        # Create a module list of the InputTransitions for each input data type
        # In other words, if you are running this model on a set of images with 2 channels and a set of
        # images with 3 channels, then self.in_tr_list = nn.ModuleList([[2x InputTran],[3x InputTran]])
        # NOTE(review): inChans_list must have at least nb_tasks entries,
        # otherwise the indexing below raises IndexError.

        # TODO: In this implementation, they are applying a 3x3x3 instead of a 1x1x1 like they say in the paper
        self.in_tr_list = nn.ModuleList(
            [InputTransition(inChans_list[j], base_outChans) for j in range(nb_tasks)]
        )

        outChans_list = list() # not read yet
        self.down_blocks = nn.ModuleList() #register modules from regular python list.
        self.down_samps = nn.ModuleList()
        self.down_pads = list() # used to pad as padding='same' in tensorflow

        inChans = base_outChans

        # Create each level of the encoder and decoder
        # NOTE(review): inChans is never updated inside this loop, so every
        # DownBlock is built with base_outChans input channels -- confirm
        # against the planned down-sampling step.
        for i in range(depth):
            outChans = base_outChans * (2**i)
            outChans_list.append(outChans)

            # Add another encoder level to the down_block module list
            self.down_blocks.append(DownBlock(nb_tasks, inChans, outChans, kernel_size=3, stride=1, padding=1))

    def forward(self, x:torch.Tensor) -> torch.Tensor:
        # Placeholder: the network is not wired up yet.
        return x


class DownBlock(nn.Module):
    '''Encoder block: conv_unit -> activation -> conv_unit -> (residual add) -> activation.

    Parameters
    -------
    nb_tasks : forwarded to ``conv_unit``
    inChans, outChans : channel counts of the first conv unit; the second maps
        ``outChans`` to ``outChans``
    module : adapter type. For 'parallel_adapter' and 'separable_adapter' the
        conv units are expected to return ``(out, share_map, para_map)``;
        for anything else they are expected to return the output alone.
        The historical default spelling 'seperable_adapter' is accepted as an
        alias of 'separable_adapter'.
    residual : if True the block input is added to the second conv unit's
        output before the final activation
    kernel_size, stride, padding : forwarded to both conv units
    '''
    # The default value of `module` is spelled 'seperable_adapter'; without
    # the alias it would never match the adapter check in forward().
    _ADAPTER_MODULES = ('parallel_adapter', 'separable_adapter', 'seperable_adapter')

    def __init__(self, nb_tasks, inChans, outChans, module: str='seperable_adapter', residual: bool=True, kernel_size=3, stride=1, padding=1) -> None:
        super().__init__()
        self.module = module
        self.residual = residual
        # NOTE(review): conv_unit is not defined in the visible part of this
        # file, and `module` is not forwarded to it -- confirm its signature
        # and return value.
        self.op1 = conv_unit(nb_tasks, inChans, outChans, kernel_size=kernel_size, stride=stride, padding=padding)
        self.act1 = norm_act(outChans, meth="act")
        self.op2 = conv_unit(nb_tasks, outChans, outChans, kernel_size=kernel_size, stride=stride, padding=padding)
        self.act2 = norm_act(outChans, meth="act")

    def _run_conv(self, op, x):
        # Apply one conv unit and return only its output tensor; the adapter
        # variants additionally return two maps that are not used here.
        if self.module in self._ADAPTER_MODULES:
            out, _share_map, _para_map = op(x)
            return out
        return op(x)

    def forward(self, x)-> torch.Tensor:
        out = self._run_conv(self.op1, x)
        out = self.act1(out)
        out = self._run_conv(self.op2, out)
        if self.residual: # same to ResNet
            # NOTE(review): the addition presumably needs x and out to have
            # matching channel counts (inChans == outChans) -- confirm.
            out = x + out
        return self.act2(out)



class InputTransition(nn.Module):
    '''
    Task-specific input layer: a 3x3x3 Conv3d from ``inChans`` to
    ``base_outChans`` channels followed by ``norm_act(base_outChans)``.
    '''
    def __init__(self, inChans, base_outChans):
        super().__init__()
        conv = nn.Conv3d(inChans, base_outChans, kernel_size=3, stride=1, padding=1)
        post = norm_act(base_outChans)
        self.op1 = nn.Sequential(conv, post)

    def forward(self, x):
        return self.op1(x)



def norm_act(nchan, meth='both') -> nn.InstanceNorm3d | nn.LeakyReLU | nn.Sequential:
    '''The normalization activation function, or what normalization you are applying to the
    data. Either a InstanceNorm3D for meth='norm', a Leaky ReLU for meth = 'act', 
    or an LeakyReLU(InstanceNorm3D(x)) for meth='both'
     '''
    norm = nn.InstanceNorm3d(nchan, affine=True)
    # act = nn.ReLU() # activation
    act = nn.LeakyReLU(negative_slope=1e-2)
    if meth=='norm':
        return norm
    elif meth=='act':
        return act
    else:
        return nn.Sequential(norm, act)


## Prepare Dataset

### Load Data

### Create DataLoader

## Train Model

## Test Model