# Import

In [64]:
import numpy as np
import torch
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torchvision import transforms
import albumentations as albu
from torch.utils.data import Dataset
from typing import Union,List,Tuple
import pathlib
from timeit import default_timer as timer
from torch.utils.data import DataLoader
import h5py
from tqdm import tqdm
import torchvision
import os
import datetime
import matplotlib.pyplot as plt
from sklearn.utils import shuffle
import time
from torch.utils.tensorboard import SummaryWriter

In [65]:
#Put the device on GPU if possible to train the model faster than with CPU

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

batch_size = 32
period_size = 8
weight_decay = 0.1
epsilon = 10**(-8)
epochs_num = 20
learning_rate = 10**(-3)
n_channel = 1
class_num = 3

label_dir = "label"
input_dir =  "input"
runs_dir = os.path.join("runs", datetime.datetime.now().strftime('%Y-%m-%d%H-%M-%S'))

stride = 32
train_size = 0.7

output = 256

# Training loop

In [66]:
#rdn_train train the model for one iteration
# Input : 
# net: The neural network model being trained.
# optimizer: The optimizer used for updating the model's parameters.
# data_loader: DataLoader providing batches of training data.
# epoch: Current epoch number (optional).
# total_epoch: Total number of epochs (optional).
# tensorboard_plot: Boolean flag for visualizing results on TensorBoard (default is False).
#
# Output :
# returns the total number of iterations processed for this epoch (nb_ite + last_batches)

def rdn_train(net, optimizer, data_loader, epoch=None, total_epoch=None, tensorboard_plot=False, nb_ite=0):
    max_batches1 = len(data_loader.dataset) // data_loader.batch_size + (1 if (len(data_loader.dataset) % data_loader.batch_size) != 0 else 0)

    # the epoch message for printing
    epoch_print = 'Epoch:'
    if epoch is not None:
        epoch_print += f'{epoch + 1}'
    if total_epoch is not None:
        epoch_print += f'/{total_epoch}'
    last_batches = 0.0
    loss1_sum = 0.0
    loss2_sum = 0.0
    ite = 0

    #Initialize writer to visualize results on TensorBoard
    writer = SummaryWriter(runs_dir)
    with tqdm(total=len(data_loader.dataset), desc=epoch_print, unit=' batches') as pbar:
        for i_batches, sample_batched in enumerate(data_loader):
            last_batches = i_batches
            mask = sample_batched['mask']
            image = sample_batched['image']
            #index = sample_batched['index']

            # convert to gpu
            mask = mask.to(device).long()
            image = image.to(device)

            # prediction
            pred = net(image)
            
            #loss1 = DomainEnrichLoss()(net, index, mask)
            loss1 = torch.Tensor(0)
            mask = create_one_hot(mask)
            #Visualisation on Tensorboard
            if tensorboard_plot and nb_ite+ite == 0:
                writer.add_graph(net, image)
            if tensorboard_plot and (ite % (max_batches1 // 3) == 0):
                if epoch is not None:
                    with torch.no_grad():
                        pred2 = net(image)
                    m2 = mask.argmax(1)
                    m2= m2.cpu().squeeze().data.numpy()
                    pred2 = pred2.argmax(1)
                    pred2 = pred2.cpu().squeeze().data.numpy()
                    color_dict = [[0.0], [128.0/255.0], [1]]
                    pred_img = torch.empty_like(image).copy_(image)
                    mask_img = torch.empty_like(image).copy_(image)
                    pred_img.to(device)
                    mask_img.to(device)
                    for i in range(len(pred_img)):
                        for j in range(len(pred_img[i][0])):
                            for k in range(len(pred_img[i][0][j])):
                                pred_img[i][0][j][k] = color_dict[pred2[i][j][k]][0]
                                mask_img[i][0][j][k] = color_dict[m2[i][j][k]][0]

                    writer.add_image('input_image', torchvision.utils.make_grid(image),nb_ite + last_batches)
                    writer.add_image('prediction_image', torchvision.utils.make_grid(pred_img),nb_ite + last_batches)
                    writer.add_image('mask_image', torchvision.utils.make_grid(mask_img),nb_ite + last_batches)

            CE_loss = nn.CrossEntropyLoss()
            loss2 = CE_loss(pred, mask)
            loss2.to(device)


            # backward
            optimizer.zero_grad()
            loss2.backward()
            optimizer.step()

            # Print results
            pbar.update(mask.shape[0])
            pbar.set_postfix(loss=loss2.cpu().data.numpy(),loss1=loss1.cpu().data.numpy(),loss2=loss2.cpu().data.numpy())
            loss1_sum = loss1_sum + loss1.cpu().data.numpy()
            loss2_sum = loss2_sum + loss2.cpu().data.numpy()
            #writer.add_scalars('Losses',{'loss':loss1.cpu().data.numpy(),'loss2':loss2.cpu().data.numpy()}, nb_ite + last_batches)
            #writer.add_scalars('Average_Losses',{'loss':(loss2_sum / (last_batches + 1)),'loss2':(loss2_sum / (last_batches + 1))}, nb_ite + last_batches)
            
            ite += 1

        print(f'\nAverage, loss2: {(loss2_sum/ (last_batches + 1)):.6f}.')
    writer.close()
    return nb_ite + last_batches

#rdn_val test the accuracy of the model for the current epoch
# Input :
# net: The neural network model.
# data_set: The dataset for validation.
# i_epoch: Current epoch number (optional).
# class_num: Number of classes in the dataset (default is 3).
#
# Output :
# returns the average accuracy (criterion_value) and class-wise dice overlap results.

def rdn_val(net, data_set, i_epoch = None, class_num = 3):
    dice_overlap = DiceOverlap(class_num)

    # check whether net is in train mode or not
    origin_is_train_mode = net.training

    # change the net to eval mode
    if origin_is_train_mode:
        net.eval()

    # check whether data set is in train mode
    data_set.val()

    criterion_value_sum = 0.0
    data_loader = DataLoader(data_set, batch_size=1, num_workers=0)
    dice_overlap_results = 0.0

    for i_batches, sample_batched in enumerate(data_loader):
        mask = sample_batched['mask']
        image = sample_batched['image']

        mask = mask.to(device)
        image = image.to(device)

        # prediction
        with torch.no_grad():
            if image.shape==(1,1,256,256):
                pred = net(image)
                criterion_value_sum += Accuracy()(pred, mask.long()).cpu().data.numpy()

                if dice_overlap is not None:
                    dice_overlap_results += dice_overlap(pred, mask.long())

    criterion_value = criterion_value_sum / len(data_loader.dataset)
    dice_overlap_results = dice_overlap_results / len(data_loader.dataset)

    for i in range(dice_overlap_results.shape[0]):
        print(f'Class: {i:.0f}, Dice Overlap: {dice_overlap_results[i]:.6f}')

    if origin_is_train_mode:
        net.train()

    # print message
    if i_epoch is not None:
        print(f"Epoch: {i_epoch + 1}, Accuracy Value: {criterion_value:.6f}")
        writer = SummaryWriter(runs_dir)
        writer.add_scalars('Dice Overlap',{'Air':dice_overlap_results[0],'Dirt':dice_overlap_results[1],'Bone':dice_overlap_results[2]}, i_epoch)
        writer.close()
    return criterion_value, dice_overlap_results

class DiceOverlap():

    def __init__(self, class_num):
        self.len = class_num

    def __call__(self, input, target):
        input = F.sigmoid(input)
        input = torch.max(input, 1)[1]

        dice = []

        for i in range(self.len):
            sub_target = torch.zeros(target.shape).cuda()
            sub_target[target == i] = 1
            sub_input = torch.zeros(input.shape).cuda()
            sub_input[input == i] = 1

            tp_idx = target == i

            eps = 0.0001
            tp = torch.sum(sub_input[tp_idx] == sub_target[tp_idx])
            fn = torch.sum(sub_input != sub_target)
            tp = tp.float()
            fn = fn.float()
            result = (2*tp + eps) / (2*tp + fn + eps)
            dice.append(result.cpu().data.numpy())

        return np.asarray(dice)
    
class Accuracy():

    def __call__(self, input, target, **kwargs):

        input = torch.max(input, 1)[1]
        size = 1
        for i in range(len(input.shape)):
            size = size * input.shape[i]
        return torch.sum(input == target).float() / size

# Generate Patches

In [67]:
#get_minimum_dirt_patches create a subset of initial patches focusing on patches containing dirt (depending on dirt_rate)
# Input :
# dirt_choose_threshold: A threshold value for the dirt ratio to determine whether a patch is considered as a dirt patch.
# dirt_rate: The desired proportion of dirt patches in the selected subset.
# patches: A NumPy array containing information about patches (name, top, left, height, width).
# ratios: A NumPy array containing ratios for each patch (class ratios).
#
# Output :
# returns a subset of initial patches
def get_minimum_dirt_patches(dirt_choose_threshold: float, dirt_rate: float, patches: np.array, ratios:np.array):
    # get ratios
    ratios = np.array(ratios)
    #get index that would sort ratios by decreasing order
    ratios_idx = np.argsort(-ratios, axis=0)

    # ratios dimension is n_patches x n_classes
    # get the second column of ratio_idx which is the dirt index sorted by decreasing order
    dirt_idx = ratios_idx[:, 1]

    # get only the patches that dirt ratio is > dirt_choose_threshold
    last_idx = 0
    for i in range(dirt_idx.shape[0]):
        dirt_ratio = ratios[dirt_idx[i], 1]
        if dirt_ratio < dirt_choose_threshold:
            last_idx = i
            break
    
    #indexes of wanted dirt_patches
    dirt_patches_idx = dirt_idx[0:last_idx]
    #indexes of other patches
    rest_idx = dirt_idx[last_idx:-1]

    
    if not (dirt_rate == 0):
        rest_num = round(((last_idx - 1) / dirt_rate) * (1 - dirt_rate))
        if rest_num > rest_idx.shape[0]:
            rest_num = rest_idx.shape[0]
    else:
        rest_num = rest_idx.shape[0]
    
    #Getting picking other patches
    random_idx = np.random.choice(rest_idx.shape[0], size=rest_num, replace=False)
    non_dirt_patches_idx = rest_idx[random_idx]

    # Getting the final patches
    patches_idx = np.concatenate((dirt_patches_idx, non_dirt_patches_idx), axis=0)
    new_patches = np.asarray(patches)[patches_idx, :].tolist()

    new_patches = shuffle(new_patches)
    new_patches = [[name, int(top), int(left), int(h), int(w)] for [name, top, left, h, w] in new_patches]
    return new_patches

#get_dirt_bone_patches create a subset of initial patches focusing on patches containing air (depending on air_rate)
# Input :
# patches: A NumPy array containing information about patches (name, top, left, height, width).
# ratios: A NumPy array containing ratios for each patch (class ratios).
# air_rate: The desired proportion of air patches in the selected subset.
#
# Output :
# returns a list of extracted patches based on the specified conditions and an index representing the length of this list
def get_dirt_bone_patches(patches: np.array, ratios:np.array, air_rate: float):
    #get ratios 
    ratios = np.array(ratios)
    #get indices that would sort ratios by decreasing order
    ratios_idx = np.argsort(-ratios, axis=0) 

    # get the second column of ratio_idx which is the dirt index sorted by decreasing order
    dirt_idx = ratios_idx[:, 1]
    #get patches 
    patches = np.asarray(patches)

    #get ratios and patches sorted by decreasing dirt ratios 
    ratios_sort = ratios[dirt_idx, :]
    patches_sort = patches[dirt_idx, :]
    
    dirt_patches = []
    bone_patches = []
    
    while (len(dirt_patches) < 128 or len(bone_patches) < 128) and air_rate < 1: #in the case that we have not enough patches 
        dirt_patches = []
        bone_patches = []
        #get patches that contains significant differences between dirt and bone ratios (>0.1)
        for idx in range(ratios_sort.shape[0]):
            if (ratios_sort[idx, 0] < air_rate):
                if (ratios_sort[idx, 1] - ratios_sort[idx, 2] > 0.15):
                    dirt_patches.append(patches_sort[idx, :].tolist())
                elif (ratios_sort[idx, 2] - ratios_sort[idx, 1] > 0.15):
                    bone_patches.append(patches_sort[idx, :].tolist())
            else:
                pass
        air_rate += 0.1

    dirt_len = len(dirt_patches)
    bone_len = len(bone_patches)

    bone_index = [1 for i in range(bone_len)]
    dirt_index = [0 for i in range(dirt_len)]
    
    dirt_patches = shuffle(dirt_patches)
    bone_patches = shuffle(bone_patches)

 
    
    print(f"There are {bone_len} bone and {dirt_len} dirt patches in the training data...")

    end_idx = dirt_len if dirt_len < bone_len else bone_len
    #get the same quantity of dirt and bone patches
    new_patches = []
    for patch in dirt_patches:
        new_patches.append(patch)
    for patch in bone_patches:
        new_patches.append(patch)
    d_index = len(new_patches)

    #This is the use of sklearn shuffle, which can't be replaced by the random shuffle
    #new_patches, d_index = shuffle(new_patches, d_index)
    #new_patches = [[name, int(top), int(left), int(h), int(w)] for [name, top, left, h, w] in new_patches]
    #d_index = [int(idx) for idx in d_index]
    
    
    return new_patches, d_index

#slide_windows generate a list of patches based on a sliding window approach over an input image or data.
# Input :
# name : A string representing the name or identifier for the image or data.
# shape : A tuple representing the shape (height, width) of the input image or data.
# output_size (int): An integer or tuple representing the size of the output patches. Default is set to (128, 128).
# stride (int): An integer or tuple representing the stride of the sliding window. Default is set to 32.
#
# Output :
#  A list of lists, where each inner list contains information about a patch
def slide_windows(name, shape, output_size=128, stride=32):
    output_size = (output_size, output_size)
    strides = (stride, stride)

    patches_list = []
    idx = 0
    while idx * strides[0] + output_size[0] <= shape[0]:
        top = idx * strides[0]
        j = 0
        while j * strides[1] + output_size[1] <= shape[1]:
            left = j * strides[1]
            patches_list.append([name, top, left,output_size[0], output_size[1]])
            j += 1

        if j * strides[1] < shape[1]:
            left = shape[1] - output_size[1]
            patches_list.append([name, top, left, output_size[0], output_size[1]])
        idx += 1

    if idx * strides[0] < shape[0]:
        top = shape[0] - output_size[0]
        j = 0
        while j * strides[1] + output_size[1] <= shape[1]:
            left = j * strides[1]
            patches_list.append([name, top, left, output_size[0], output_size[1]])
            j += 1

        if j * strides[1] < shape[1]:
            left = shape[1] - output_size[1]
            patches_list.append([name, top, left, output_size[0], output_size[1]])
    return patches_list

#get_patches generate patches from a collection of images
# Input :
# data: A dictionary where keys represent image names, and values are the corresponding images (arrays).
# stride: An integer representing the stride of the sliding window. Default is set to 32.
# output_size: An integer representing the size of the output patches. Default is set to 256.
#
# Output :
#  A list of lists, where each inner list contains information about a patch
def get_patches(data, stride: int = 32, output_size: int = 256):
    patches = []
    # Iterate for each image
    for name in (data.keys()):
        shape = data[name].shape
        patches += slide_windows(name, shape, output_size=output_size, stride=stride)
    shuffle(patches)
        # I wrote out the before and after and did a comparison and there are no differences.
        #This section appears to simply cast the first two items in patches as strings.
        # for idx in range(len(patches)):
        #     for j in range(len(patches[idx])):
        #         patches[idx][j] = str(patches[idx][j])
    return patches

#generate_ratios calculates ratios for each patch based on the pixel values in the corresponding labeled masks
# Input :
# patches: A list of lists, where each inner list contains information about a patch (name, top, left, height, width)
# class_num: An integer representing the number of classes in the labeled masks. Default is set to 3.
#
# Output:
# A list of lists, where each inner list contains class ratios for a patch.
def generate_ratios(patches, class_num=3):
    #initialize ratios to 0
    ratios = []
    for i in range(len(patches)):
        mask = plt.imread(label_dir+"/"+patches[0][i])[patches[1][i]: patches[1][i]+patches[3][i], patches[2][i]: patches[2][i]+patches[4][i]]
        mask = adjustMask(mask, class_num)
        size = 1.0
        for idx in range(len(mask.shape)):
            size *= mask.shape[idx]

        #Get ratio for the patch
        ratio = []
        for idx in range(class_num):
            ratio.append(np.sum(mask == idx)/size) 

        ratios.append(ratio)
    return ratios

# Dataset

In [68]:
def load_patches(patches):

    if isinstance(patches, str):
        return np.array(pd.read_csv(patches, header=0)).tolist()
    else:
        return patches

class HDF52D(Dataset):

    # dataset for segmentation used
    def __init__(self, train_patches,val_patches, train_transform=None, val_transform=None, train_idx = None):

        self.patches = {'train': load_patches(train_patches),
                        'val': load_patches(val_patches)}

        self.transforms = {'train': train_transform,
                           'val': val_transform}

        self.train_idx = load_patches(train_idx)

        self.mode = 'train'

    def __getitem__(self, idx):

        p = self.patches[self.mode]
        
        image = input_list[p[0][idx]][int(p[1][idx]):int(p[1][idx]) + int(p[3][idx]), int(p[2][idx]):int(p[2][idx]) + int(p[4][idx]) ]
        mask = label_list[p[0][idx]][int(p[1][idx]):int(p[1][idx]) + int(p[3][idx]), int(p[2][idx]):int(p[2][idx]) + int(p[4][idx]) ]
        sample = {'image': image, 'mask': mask}

        if self.transforms[self.mode] is not None:
            sample = self.transforms[self.mode](sample)
            
        #if self.train_idx is not None and self.mode == 'train':
        #    sample['index'] = self.train_idx[idx]
        return sample


    def train(self):
        self.mode = 'train'

    def val(self):
        self.mode = 'val'

    def __len__(self):
        return len(self.patches[self.mode])

# Dataprocess

In [69]:
def create_one_hot(mask, num_classes = 3):
    one_hot_mask = torch.zeros([mask.shape[0],
                                num_classes,
                                mask.shape[1],
                                mask.shape[2]],
                               dtype=torch.float32)
    if mask.is_cuda:
        one_hot_mask = one_hot_mask.cuda()
    one_hot_mask = one_hot_mask.scatter(1, mask.long().data.unsqueeze(1), 1.0)

    return one_hot_mask

def adjustMask(mask, class_num):

    interval = int(256.0 / class_num)

    # Color_Dict must be a numpy type
    # mask.shape must be a H x W x C
    # do not have channel dimensions
    if len(mask.shape) == 2:
        new_mask = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.longlong)
        for i in range(class_num):
            if i <= class_num - 2:
                new_mask[(mask >= i*interval) & (mask < (i+1) * interval)] = i
            else:
                new_mask[i*interval <= mask] = i
        return new_mask

class AdjustMask(object):
    def __init__(self, class_num = 3):
        self.class_num = class_num

    def __call__(self, sample):
        sample['mask'] = adjustMask(sample['mask'], self.class_num)
        return sample

class ToTensor(object):

    def __init__(self, if_multi_img=False):
        self.if_multi_img = if_multi_img

    def __call__(self, sample):
        image, mask = sample['image'], sample['mask']

        # swap color axis because
        # numpy image: H x W x C
        # torch image: C x H x W

        if not self.if_multi_img:
            if len(image.shape) == 2:
                image = np.expand_dims(image, axis=2)
            image = image.transpose((2, 0, 1))
        else:
            if len(image.shape) == 3:
                image = np.expand_dims(image, axis=3)

            image = image.transpose((0, 3, 1, 2))

        sample['image'] = torch.from_numpy(image)
        sample['mask'] = torch.from_numpy(mask)

        if 'weights' in sample:
            sample['weights'] = torch.from_numpy(sample['weights'])
        if 'ratio' in sample:
            sample['ratio'] = torch.from_numpy(sample['ratio'])
        return sample

class Normalize(object):
    def __init__(self, max=255.0, min=0.0, tg_max=1.0, tg_min=0.0):
        self.max = max
        self.min = min
        self.tg_max = tg_max
        self.tg_min = tg_min

    def __call__(self, sample):
        image = sample['image'].astype('float32')
        image = self.tg_min + ((image - self.min)*(self.tg_max - self.tg_min)) / (self.max - self.min)
        sample['image'] = image
        return sample

class Augmentation(object):

    def __init__(self, output_size=256):
        self.aug = albu.Compose([
            albu.OneOf([
                albu.HorizontalFlip(p=1),
                albu.VerticalFlip(p=1),   
                albu.Compose([
                    albu.HorizontalFlip(p=1),
                    albu.VerticalFlip(p=1), 
                ])
            ], p=0.75),
            # albu.OneOf([
            # albu.RandomContrast(),
            # albu.RandomGamma(),
            # albu.RandomBrightness(),
            # ], p=0.5),
            # albu.OneOf([
            # albu.ElasticTransform(alpha=60, sigma=120 * 0.05, alpha_affine=120 * 0.03),
            # albu.GridDistortion(),
            # albu.OpticalDistortion(distort_limit=2, shift_limit=0.5),
            # ], p=0.),5
            albu.augmentations.geometric.rotate.RandomRotate90(p=1),
            albu.Resize(output_size, output_size, always_apply=True),
        ])
    def __call__(self,sample):
        augmented = self.aug(image=sample['image'], mask=sample['mask'])
        sample['image'] = augmented['image']
        sample['mask'] = augmented['mask']
        return sample

# UNet-parts

In [70]:
class DomainEnrich(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.domain_enrich = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.domain_enrich(x)


class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2)

        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is CHW
        diffY = torch.tensor([x2.size()[2] - x1.size()[2]])
        diffX = torch.tensor([x2.size()[3] - x1.size()[3]])

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)


# UNet Light RDN

In [71]:
class UNet_Light_RDN(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=True):
        super(UNet_Light_RDN, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        #self.rdn1 = DomainEnrich_Block(n_channels, 8)
        #self.rdn2 = DomainEnrich_Block(n_channels, 8)

        #self.inc = DoubleConv(17, 32)
        
        self.inc = DoubleConv(1, 32) #only unet
        
        self.down1 = Down(32, 64)
        self.down2 = Down(64, 128)
        self.down3 = Down(128, 256)
        self.down4 = Down(256, 256)
        self.up1 = Up(512, 128, bilinear)
        self.up2 = Up(256, 64, bilinear)
        self.up3 = Up(128, 32, bilinear)
        self.up4 = Up(64, 32, bilinear)
        self.outc = OutConv(32, n_classes)

    def forward(self, x):
        # identity = x
        # self.x_rdn1 = self.rdn1(x)
        # self.x_rdn2 = self.rdn2(x)
        # # self.x_rdn2 = self.rdn2(self.x_rdn1)
        # x1 = self.inc(torch.cat((self.x_rdn2, self.x_rdn1, identity), 1))
        
        x1 = self.inc(x) #only unet
        
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits

# Parameters

In [72]:
name_list = os.listdir(label_dir)

label_list = {name:plt.imread(label_dir+"/"+name) for name in name_list}
input_list = {name:plt.imread(input_dir+"/"+name) for name in name_list}

#Divide each unsegmented picture into smaller picture as the training set
train_patches = pd.DataFrame(get_patches(input_list, stride=stride, output_size=output))
#Same thing for the labelised pictures as the label set
val_patches = pd.DataFrame(get_patches(label_list,   stride=stride, output_size=output))
#Calculate the % of air, bones and dirt for each patch of the training set
ratios = pd.DataFrame(generate_ratios(val_patches, class_num=class_num))

In [73]:
net = UNet_Light_RDN(n_channels=n_channel, n_classes=class_num)
net.to(device)

UNet_Light_RDN(
  (inc): DoubleConv(
    (double_conv): Sequential(
      (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (down1): Down(
    (maxpool_conv): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): DoubleConv(
        (double_conv): Sequential(
          (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (4): BatchNorm2d(64, eps=1e-05, m

In [74]:
optimizer = Adam(net.parameters(),
                      lr=float(learning_rate),
                      eps=float(epsilon),
                      betas=(0.9, 0.999),
                      weight_decay=weight_decay)

#learning rate schedule
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=period_size, gamma=0.1)


In [75]:
# create train transform
train_transform = transforms.Compose([Augmentation(output_size=64), #config['output_size']
                                                    AdjustMask(class_num=class_num),
                                                    Normalize(max=255, min=0),
                                                    ToTensor()])

val_transform = transforms.Compose([AdjustMask(class_num=class_num),
                                                    Normalize(max=255, min=0),
                                                    ToTensor()])

In [76]:
epoch_count = 0
#st.write(f"TensorBoard is availaible, run this following command in a terminal : tensorboard --logdir=runs")
print(f"Epoch progress:")
print(f"Progress training {epochs_num} epochs...")
total_timer = timer()
iteration = 0
nb_ite = 0
#subprocess.call('echo "TensorBoard available, run this command to enable it : tensorboard --logdir=runs"', shell=True)
for i_epoch in range(epochs_num):
    print(f"Epoch {epoch_count + 1} of {epochs_num}")
    
    if i_epoch < period_size:
        #dirt_rate = 0.5
        air_rate = 0.1
    elif i_epoch < 2 * period_size and i_epoch >= period_size:
        #dirt_rate = 0.3
        air_rate = 0.2
    elif i_epoch < 3 * period_size and i_epoch >= 2 * period_size:
        #dirt_rate = 0.1
        air_rate = 0.4
    else:
        #dirt_rate = 0.0
        air_rate = 0.5
    
    #Get patches 
    patches = get_minimum_dirt_patches(dirt_choose_threshold=0.1, dirt_rate=0, patches=train_patches, ratios=ratios)
    
    DEB_patches, index = get_dirt_bone_patches(train_patches, ratios, air_rate)
    DEB_patches = pd.DataFrame(DEB_patches)
    
    data_set = HDF52D(patches, val_patches, train_transform=train_transform, val_transform=val_transform)
    
    DEB_data_set = HDF52D(DEB_patches, val_patches, train_transform=train_transform, val_transform=val_transform, train_idx=index)

    current_batch = int(batch_size)
    
    train_data_loader = DataLoader(dataset=DEB_data_set, batch_size=current_batch, shuffle=True, num_workers=0)
    
    
                            
    # train_data_loader.append(DataLoader(dataset=training_data_set,
    #                                     batch_size=current_batch,
    #                                     shuffle=True,
    #                                     num_workers=0))
                            
    print(f"learning rate {optimizer.param_groups[0]['lr']:.6f}")
    nb_ite = rdn_train(net, optimizer, train_data_loader, epoch=i_epoch, total_epoch=epochs_num, tensorboard_plot=True, nb_ite=nb_ite)
    #lr_scheduler.step()
    
    # validating
    val_loss, class_val = rdn_val(net, data_set, i_epoch=i_epoch, class_num=class_num)
    
    class_val = pd.DataFrame(class_val)
    class_val.columns = ["Class Dice overlap"]
    print(class_val)
    epoch_count += 1
    iteration = np.floor((100 * epoch_count) / int(epochs_num))

# save model
torch.save(net.state_dict(), "RDN.pth")

Epoch progress:
Progress training 20 epochs...
Epoch 1 of 20
There are 219 bone and 0 dirt patches in the training data...
learning rate 0.001000


  diffY = torch.tensor([x2.size()[2] - x1.size()[2]])
  diffX = torch.tensor([x2.size()[3] - x1.size()[3]])
Epoch:1/20:  29%|██▉       | 64/219 [00:08<00:20,  7.62 batches/s, loss=0.7925877, loss1=[], loss2=0.7925877]


KeyboardInterrupt: 