# Urinary-Stone-Challenge

This notebook provides everything necessary to train and evaluate a urinary stone segmentation model.
The baseline network is Modified-UNet.

## import

In [None]:
from __future__ import print_function

import torch
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader,Dataset
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

import os
import argparse
import random
import numpy as np
import sys
import time

import cv2
from random import uniform
from imgaug import augmenters as iaa

from glob import glob
import SimpleITK as sitk

from tqdm import tqdm
from medpy.metric.binary import sensitivity, specificity, dc, hd95

from options import parse_option
from network import create_model


import warnings

from tensorboardX import SummaryWriter

from matplotlib import pyplot as plt 


In [None]:
# Silence every Python warning for the rest of the session.
warnings.filterwarnings('ignore')
# Let cuDNN benchmark convolution algorithms and keep the fastest one.
torch.backends.cudnn.benchmark = True
# NOTE(review): `fastest` does not look like a documented torch.backends.cudnn
# flag — confirm that this assignment has any effect.
torch.backends.cudnn.fastest = True

In [None]:
# Seed every random number generator used in this notebook with the same
# fixed value.
torch.manual_seed(0)       # PyTorch generator
torch.cuda.manual_seed(0)  # PyTorch CUDA generator (current device)
np.random.seed(0)          # NumPy global generator
random.seed(0)             # Python standard-library generator

# Network

## Modified-Unet

In [None]:
class Modified2DUNet(nn.Module):
    """2D U-Net with residual context blocks and deep supervision.

    The encoder ("context pathway") has five levels; levels 2-5 halve the
    spatial size with a stride-2 convolution.  The decoder ("localization
    pathway") upsamples with nearest-neighbour interpolation, concatenates
    the matching encoder feature map and adds two auxiliary segmentation
    heads (``ds2``/``ds3``) to the final prediction.

    Attribute names keep their historical ``3d`` suffixes so that existing
    checkpoints (state-dict keys) stay loadable.

    Args:
        in_channels: number of channels of the input image.
        n_classes: number of output channels (raw logits).
        base_n_filter: channel width of the first level; deeper levels use
            multiples of it.  GroupNorm group counts are derived from it
            (e.g. ``base_n_filter // 2``), so pick a value for which those
            divisions are exact (the defaults are powers of two).
    """

    def __init__(self, in_channels, n_classes, base_n_filter = 8):
        super(Modified2DUNet, self).__init__()
        self.in_channels = in_channels
        self.n_classes = n_classes
        self.base_n_filter = base_n_filter

        self.lrelu = nn.LeakyReLU()
        # Activations are 4-D (N, C, H, W).  Recent PyTorch releases treat a
        # 4-D tensor given to Dropout3d as an *unbatched* (C, D, H, W) volume,
        # which would zero whole samples instead of feature maps.  Dropout2d
        # is the channel-wise dropout for 2D feature maps.
        self.dropout3d = nn.Dropout2d(p=0.6)

        # Level 1 context pathway
        self.conv3d_c1_1 = nn.Conv2d(self.in_channels, self.base_n_filter, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv3d_c1_2 = nn.Conv2d(self.base_n_filter, self.base_n_filter, kernel_size=3, stride=1, padding=1, bias=False)
        self.lrelu_conv_c1 = self.lrelu_conv(self.base_n_filter, self.base_n_filter)
        self.gnorm3d_c1 = nn.GroupNorm(self.base_n_filter//2, self.base_n_filter)

        # Level 2 context pathway
        # (Each norm_lrelu_conv_c* module used to be constructed three times
        # in a row; only the last assignment survived, so one is enough.)
        self.conv3d_c2 = nn.Conv2d(self.base_n_filter, self.base_n_filter*2, kernel_size=3, stride=2, padding=1, bias=False)
        self.norm_lrelu_conv_c2 = self.norm_lrelu_conv(self.base_n_filter*2, self.base_n_filter*2)
        self.gnorm3d_c2 = nn.GroupNorm(self.base_n_filter, self.base_n_filter*2)

        # Level 3 context pathway
        self.conv3d_c3 = nn.Conv2d(self.base_n_filter*2, self.base_n_filter*4, kernel_size=3, stride=2, padding=1, bias=False)
        self.norm_lrelu_conv_c3 = self.norm_lrelu_conv(self.base_n_filter*4, self.base_n_filter*4)
        self.gnorm3d_c3 = nn.GroupNorm(self.base_n_filter, self.base_n_filter*4)

        # Level 4 context pathway
        self.conv3d_c4 = nn.Conv2d(self.base_n_filter*4, self.base_n_filter*8, kernel_size=3, stride=2, padding=1, bias=False)
        self.norm_lrelu_conv_c4 = self.norm_lrelu_conv(self.base_n_filter*8, self.base_n_filter*8)
        self.gnorm3d_c4 = nn.GroupNorm(self.base_n_filter*2, self.base_n_filter*8)

        # Level 5 context pathway, level 0 localization pathway
        self.conv3d_c5 = nn.Conv2d(self.base_n_filter*8, self.base_n_filter*16, kernel_size=3, stride=2, padding=1, bias=False)
        self.norm_lrelu_conv_c5 = self.norm_lrelu_conv(self.base_n_filter*16, self.base_n_filter*16)
        self.norm_lrelu_upscale_conv_norm_lrelu_l0_1 = self.norm_lrelu_upscale_conv_norm_lrelu_1(self.base_n_filter*16)
        self.norm_lrelu_upscale_conv_norm_lrelu_l0_2 = self.norm_lrelu_upscale_conv_norm_lrelu_2(self.base_n_filter*16, self.base_n_filter*8)

        self.conv3d_l0 = nn.Conv2d(self.base_n_filter*8, self.base_n_filter*8, kernel_size = 1, stride=1, padding=0, bias=False)
        self.gnorm3d_l0 = nn.GroupNorm(self.base_n_filter*2, self.base_n_filter*8)

        # Level 1 localization pathway
        self.conv_norm_lrelu_l1 = self.conv_norm_lrelu(self.base_n_filter*16, self.base_n_filter*16)
        self.conv3d_l1 = nn.Conv2d(self.base_n_filter*16, self.base_n_filter*8, kernel_size=1, stride=1, padding=0, bias=False)
        self.norm_lrelu_upscale_conv_norm_lrelu_l1_1 = self.norm_lrelu_upscale_conv_norm_lrelu_1(self.base_n_filter*8)
        self.norm_lrelu_upscale_conv_norm_lrelu_l1_2 = self.norm_lrelu_upscale_conv_norm_lrelu_2(self.base_n_filter*8, self.base_n_filter*4)

        # Level 2 localization pathway
        self.conv_norm_lrelu_l2 = self.conv_norm_lrelu(self.base_n_filter*8, self.base_n_filter*8)
        self.conv3d_l2 = nn.Conv2d(self.base_n_filter*8, self.base_n_filter*4, kernel_size=1, stride=1, padding=0, bias=False)
        self.norm_lrelu_upscale_conv_norm_lrelu_l2_1 = self.norm_lrelu_upscale_conv_norm_lrelu_1(self.base_n_filter*4)
        self.norm_lrelu_upscale_conv_norm_lrelu_l2_2 = self.norm_lrelu_upscale_conv_norm_lrelu_2(self.base_n_filter*4, self.base_n_filter*2)

        # Level 3 localization pathway
        self.conv_norm_lrelu_l3 = self.conv_norm_lrelu(self.base_n_filter*4, self.base_n_filter*4)
        self.conv3d_l3 = nn.Conv2d(self.base_n_filter*4, self.base_n_filter*2, kernel_size=1, stride=1, padding=0, bias=False)
        self.norm_lrelu_upscale_conv_norm_lrelu_l3_1 = self.norm_lrelu_upscale_conv_norm_lrelu_1(self.base_n_filter*2)
        self.norm_lrelu_upscale_conv_norm_lrelu_l3_2 = self.norm_lrelu_upscale_conv_norm_lrelu_2(self.base_n_filter*2, self.base_n_filter)

        # Level 4 localization pathway
        self.conv_norm_lrelu_l4 = self.conv_norm_lrelu(self.base_n_filter*2, self.base_n_filter*2)
        self.conv3d_l4 = nn.Conv2d(self.base_n_filter*2, self.n_classes, kernel_size=1, stride=1, padding=0, bias=False)

        # 1x1 heads of the deep-supervision branches
        self.ds2_1x1_conv3d = nn.Conv2d(self.base_n_filter*8, self.n_classes, kernel_size=1, stride=1, padding=0, bias=False)
        self.ds3_1x1_conv3d = nn.Conv2d(self.base_n_filter*4, self.n_classes, kernel_size=1, stride=1, padding=0, bias=False)

    def conv_norm_lrelu(self, feat_in, feat_out):
        """Build a 3x3 conv -> GroupNorm -> LeakyReLU block."""
        return nn.Sequential(
            nn.Conv2d(feat_in, feat_out, kernel_size=3, stride=1, padding=1, bias=False),
            nn.GroupNorm(feat_out//2, feat_out),
            nn.LeakyReLU())

    def norm_lrelu_conv(self, feat_in, feat_out):
        """Build a GroupNorm -> LeakyReLU -> 3x3 conv block (pre-activation)."""
        return nn.Sequential(
            nn.GroupNorm(feat_in//2, feat_in),
            nn.LeakyReLU(),
            nn.Conv2d(feat_in, feat_out, kernel_size=3, stride=1, padding=1, bias=False))

    def lrelu_conv(self, feat_in, feat_out):
        """Build a LeakyReLU -> 3x3 conv block."""
        return nn.Sequential(
            nn.LeakyReLU(),
            nn.Conv2d(feat_in, feat_out, kernel_size=3, stride=1, padding=1, bias=False))

    def norm_lrelu_upscale_conv_norm_lrelu_1(self, feat_in):
        """Build the GroupNorm -> LeakyReLU half that runs *before* upsampling."""
        return nn.Sequential(
            nn.GroupNorm(feat_in//2, feat_in),
            nn.LeakyReLU())

    def norm_lrelu_upscale_conv_norm_lrelu_2(self, feat_in, feat_out):
        """Build the conv -> GroupNorm -> LeakyReLU half that runs *after* upsampling."""
        return nn.Sequential(
            # should be feat_in*2 or feat_in
            nn.Conv2d(feat_in, feat_out, kernel_size=3, stride=1, padding=1, bias=False),
            nn.GroupNorm(feat_out//2, feat_out),
            nn.LeakyReLU())

    def forward(self, x):
        """Return raw logits with ``n_classes`` channels at the input's spatial size.

        Args:
            x: 4-D tensor with ``in_channels`` channels in dimension 1.
        """
        #  Level 1 context pathway
        out = self.conv3d_c1_1(x)
        residual_1 = out
        out = self.lrelu(out)
        out = self.conv3d_c1_2(out)
        out = self.dropout3d(out)
        out = self.lrelu_conv_c1(out)
        # Element Wise Summation
        out += residual_1
        context_1 = self.lrelu(out)
        out = self.gnorm3d_c1(out)
        out = self.lrelu(out)

        # Level 2 context pathway
        # The same norm_lrelu_conv_c2 module is applied four times, i.e. its
        # weights are shared across those applications (same for levels 3-5).
        out = self.conv3d_c2(out)
        residual_2 = out
        out = self.norm_lrelu_conv_c2(out)
        out = self.norm_lrelu_conv_c2(out)
        out = self.dropout3d(out)
        out = self.norm_lrelu_conv_c2(out)
        out = self.norm_lrelu_conv_c2(out)
        out += residual_2
        out = self.gnorm3d_c2(out)
        out = self.lrelu(out)
        context_2 = out

        # Level 3 context pathway
        out = self.conv3d_c3(out)
        residual_3 = out
        out = self.norm_lrelu_conv_c3(out)
        out = self.norm_lrelu_conv_c3(out)
        out = self.dropout3d(out)
        out = self.norm_lrelu_conv_c3(out)
        out = self.norm_lrelu_conv_c3(out)
        out += residual_3
        out = self.gnorm3d_c3(out)
        out = self.lrelu(out)
        context_3 = out

        # Level 4 context pathway
        out = self.conv3d_c4(out)
        residual_4 = out
        out = self.norm_lrelu_conv_c4(out)
        out = self.norm_lrelu_conv_c4(out)
        out = self.dropout3d(out)
        out = self.norm_lrelu_conv_c4(out)
        out = self.norm_lrelu_conv_c4(out)
        out += residual_4
        out = self.gnorm3d_c4(out)
        out = self.lrelu(out)
        context_4 = out

        # Level 5
        out = self.conv3d_c5(out)
        residual_5 = out
        out = self.norm_lrelu_conv_c5(out)
        out = self.norm_lrelu_conv_c5(out)
        out = self.dropout3d(out)
        out = self.norm_lrelu_conv_c5(out)
        out = self.norm_lrelu_conv_c5(out)
        out += residual_5
        out = self.norm_lrelu_upscale_conv_norm_lrelu_l0_1(out)
        out = F.interpolate(out, scale_factor=2, mode='nearest')
        out = self.norm_lrelu_upscale_conv_norm_lrelu_l0_2(out)

        out = self.conv3d_l0(out)
        out = self.gnorm3d_l0(out)
        out = self.lrelu(out)

        # Level 1 localization pathway
        # (the extra interpolate snaps to the skip connection's exact size,
        # which can differ by one pixel for odd feature-map sizes)
        out = F.interpolate(out, size = context_4.size()[-2:])
        out = torch.cat([out, context_4], dim=1)
        out = self.conv_norm_lrelu_l1(out)
        out = self.conv3d_l1(out)
        out = self.norm_lrelu_upscale_conv_norm_lrelu_l1_1(out)
        out = F.interpolate(out, scale_factor=2, mode='nearest')
        out = self.norm_lrelu_upscale_conv_norm_lrelu_l1_2(out)

        # Level 2 localization pathway
        out = F.interpolate(out, size = context_3.size()[-2:])
        out = torch.cat([out, context_3], dim=1)
        out = self.conv_norm_lrelu_l2(out)
        ds2 = out
        out = self.conv3d_l2(out)
        out = self.norm_lrelu_upscale_conv_norm_lrelu_l2_1(out)
        out = F.interpolate(out, scale_factor=2, mode='nearest')
        out = self.norm_lrelu_upscale_conv_norm_lrelu_l2_2(out)

        # Level 3 localization pathway
        out = F.interpolate(out, size = context_2.size()[-2:])
        out = torch.cat([out, context_2], dim=1)
        out = self.conv_norm_lrelu_l3(out)
        ds3 = out
        out = self.conv3d_l3(out)
        out = self.norm_lrelu_upscale_conv_norm_lrelu_l3_1(out)
        out = F.interpolate(out, scale_factor=2, mode='nearest')
        out = self.norm_lrelu_upscale_conv_norm_lrelu_l3_2(out)

        # Level 4 localization pathway
        out = F.interpolate(out, size = context_1.size()[-2:])
        out = torch.cat([out, context_1], dim=1)
        out = self.conv_norm_lrelu_l4(out)
        out_pred = self.conv3d_l4(out)

        # Deep supervision: project ds2/ds3 to class logits, upsample and sum.
        ds2_1x1_conv = self.ds2_1x1_conv3d(ds2)
        ds1_ds2_sum_upscale = F.interpolate(ds2_1x1_conv, scale_factor=2, mode='nearest')
        ds3_1x1_conv = self.ds3_1x1_conv3d(ds3)
        ds1_ds2_sum_upscale = F.interpolate(ds1_ds2_sum_upscale, size = ds3_1x1_conv.size()[-2:])
        ds1_ds2_sum_upscale_ds3_sum = ds1_ds2_sum_upscale + ds3_1x1_conv
        # Resize to the prediction's size instead of a fixed x2 so that odd
        # input sizes do not produce a one-pixel mismatch in the sum below.
        ds1_ds2_sum_upscale_ds3_sum_upscale = F.interpolate(ds1_ds2_sum_upscale_ds3_sum, size = out_pred.size()[-2:], mode='nearest')

        out = out_pred + ds1_ds2_sum_upscale_ds3_sum_upscale

        return out

## Create model

In [None]:

def create_model(opt):
    """Build the segmentation network and optionally restore saved weights.

    Args:
        opt: options object; reads ``base_n_filter``, ``use_gpu``, ``ngpu``,
            ``resume`` (checkpoint file name, falsy to skip) and ``exp``
            (directory that contains the checkpoint).

    Returns:
        The network, wrapped in ``torch.nn.DataParallel`` when
        ``opt.use_gpu`` is set and ``opt.ngpu > 1``.
    """
    # Load network
    net = Modified2DUNet(1, 1, opt.base_n_filter)

    # GPU settings
    if opt.use_gpu:
        net.cuda()
        if opt.ngpu > 1:
            net = torch.nn.DataParallel(net)

    if opt.resume:
        checkpoint_path = opt.exp + "/" + opt.resume
        if os.path.isfile(checkpoint_path):
            pretrained_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))

            # Checkpoints are written without the DataParallel "module."
            # prefix, so load into the unwrapped network; otherwise no key
            # would match on multi-GPU runs and the resume would silently
            # do nothing.
            target = net.module if isinstance(net, torch.nn.DataParallel) else net
            model_dict = target.state_dict()

            match_cnt = 0
            mismatch_cnt = 0
            pretrained_dict_matched = dict()
            for k, v in pretrained_dict.items():
                # Tolerate checkpoints that were saved from a wrapped model.
                key = k[len('module.'):] if k.startswith('module.') else k
                if key in model_dict and v.size() == model_dict[key].size():
                    pretrained_dict_matched[key] = v
                    match_cnt += 1
                else:
                    mismatch_cnt += 1

            # Only matching entries are overwritten; the rest keep their
            # freshly initialised values.
            model_dict.update(pretrained_dict_matched)
            target.load_state_dict(model_dict)

            print("=> Successfully loaded weights from %s (%d matched / %d mismatched)" % (opt.resume, match_cnt, mismatch_cnt))

        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))

    return net

# Dataloader

## Augmentation

In [None]:
def image_windowing(img, w_min=0, w_max=300):
    """Return a copy of ``img`` with its values limited to ``[w_min, w_max]``.

    The input array is left untouched; the result keeps the input's dtype.
    """
    clipped = img.copy()

    # Apply the lower bound first, then the upper bound on the updated array.
    for bound, outside in ((w_min, np.less), (w_max, np.greater)):
        clipped[outside(clipped, bound)] = bound

    return clipped
    
def image_minmax(img):
    """Linearly rescale ``img`` to the full uint8 range ``[0, 255]``.

    The minimum maps to 0 and the maximum to 255; intermediate values are
    truncated when cast to ``np.uint8``.  A constant image (max == min) has
    no range to stretch and is returned as all zeros instead of dividing by
    zero.
    """
    lowest = np.min(img)
    value_range = np.max(img) - lowest

    if value_range == 0:
        # Avoid 0/0 -> NaN, which would turn into arbitrary uint8 values.
        return np.zeros(np.shape(img), dtype=np.uint8)

    img_minmax = (img - lowest) / value_range
    img_minmax = (img_minmax * 255).astype(np.uint8)

    return img_minmax

def mask_binarization(mask_array):
    """Turn a mask into 0/1 ``uint8`` values.

    Pixels strictly greater than half of the mask's maximum become 1,
    everything else becomes 0.
    """
    peak = np.max(mask_array)
    is_foreground = mask_array > peak / 2

    return is_foreground.astype(np.uint8)

def augment_imgs_and_masks(imgs, masks, rot_factor, scale_factor, trans_factor, flip):
    """Apply one random affine transform (and optional flip) to images and masks.

    Rotation, scale and translation are sampled once and used for both
    arrays so that they stay aligned.  Images are resampled with bilinear
    interpolation, masks with nearest-neighbour so that a binary mask stays
    binary after warping.

    Args:
        imgs: array indexed as ``(N, H, W)`` — ``shape[1]``/``shape[2]`` are
            used as height/width (assumes the caller passes that layout —
            TODO confirm for other callers).
        masks: array with the same layout as ``imgs``.
        rot_factor: rotation is drawn from ``[-rot_factor, rot_factor]``.
        scale_factor: scale is drawn from ``[1 - scale_factor, 1 + scale_factor]``.
        trans_factor: translation is drawn from ``[-trans_factor, trans_factor]``
            as a fraction of the image size.
        flip: when true, flip along axis 2 with probability 0.5.

    Returns:
        Tuple ``(imgs, masks)`` of augmented arrays.
    """
    rot_factor = uniform(-rot_factor, rot_factor)
    scale_factor = uniform(1-scale_factor, 1+scale_factor)
    # x moves along the width (axis 2), y along the height (axis 1).
    translate_px = {"x": int(imgs.shape[2]*uniform(-trans_factor, trans_factor)),
                    "y": int(imgs.shape[1]*uniform(-trans_factor, trans_factor))}

    def build_affine(order):
        # All parameters are fixed numbers, so every augmenter built here
        # performs the same geometric transform; only the interpolation
        # order differs.
        return iaa.Sequential([
                iaa.Affine(
                    translate_px=translate_px,
                    scale=(scale_factor, scale_factor),
                    rotate=rot_factor,
                    order=order
                )
            ])

    imgs = build_affine(1).augment_images(imgs)    # bilinear for intensities
    masks = build_affine(0).augment_images(masks)  # nearest keeps labels crisp

    if flip and uniform(0, 1) > 0.5:
        imgs = np.flip(imgs, 2).copy()
        masks = np.flip(masks, 2).copy()

    return imgs, masks


def mask_binarization(mask_array):
    """Binarize a mask: 1 where the value exceeds half its maximum, else 0.

    NOTE: this re-defines (and shadows) the identical function declared
    earlier in this file.
    """
    return (mask_array > np.max(mask_array) / 2).astype(np.uint8)


def center_crop(img, width):
    """Crop a ``width`` x ``width`` square from the centre of a 2-D image.

    Args:
        img: 2-D array (rows, columns).
        width: side length of the crop; converted to ``int`` because the
            command-line option that feeds it is parsed as a float.

    Returns:
        The cropped view.  When ``width`` exceeds an image dimension the
        crop is limited to the image instead of wrapping around through
        negative slice indices.
    """
    width = int(width)
    y, x = img.shape
    # Clamp at 0: a negative start index would be read relative to the end
    # of the axis and select the wrong region.
    x_min = max(int(x / 2.0 - width / 2.0), 0)
    y_min = max(int(y / 2.0 - width / 2.0), 0)
    return img[y_min:y_min + width, x_min:x_min + width]


## dataset & loader

In [None]:
# UrinaryStoneDataset
class UrinaryStoneDataset(Dataset):
    """Dataset of DICOM slices paired with PNG label masks.

    DICOM files are collected from ``<data_root>/<Train|Valid>/DCM/*.dcm``.
    The label path is derived from the DICOM path by replacing ``'DCM'``
    with ``'Label'`` and ``'.dcm'`` with ``'.png'``.

    Args:
        opt: options object; reads ``data_root``, ``w_min``, ``w_max``,
            ``crop_size``, ``input_size`` and the augmentation factors.
        is_Train: selects the ``Train`` (True) or ``Valid`` (False) folder.
        augmentation: apply random affine/flip augmentation per sample.
    """

    def __init__(self, opt, is_Train=True, augmentation=True):
        super(UrinaryStoneDataset, self).__init__()

        split = 'Train' if is_Train else 'Valid'
        self.dcm_list = sorted(glob(os.path.join(opt.data_root, split, 'DCM', '*.dcm')))

        self.len = len(self.dcm_list)

        self.augmentation = augmentation
        self.opt = opt

        self.is_Train = is_Train

    def __getitem__(self, index):
        """Return ``(img, mask)`` as float32 arrays with a leading channel axis."""
        # Load Image and Mask
        dcm_path = self.dcm_list[index]
        # NOTE(review): str.replace rewrites *every* 'DCM' occurrence in the
        # path, including the data root or file name — confirm the dataset
        # layout never contains 'DCM' elsewhere.
        mask_path = dcm_path.replace('DCM', 'Label').replace('.dcm', '.png')

        img_sitk = sitk.ReadImage(dcm_path)
        img = sitk.GetArrayFromImage(img_sitk)[0]
        mask = cv2.imread(mask_path, 0)
        if mask is None:
            # cv2.imread signals a missing/unreadable file by returning None;
            # fail here with the offending path instead of an obscure error
            # further down.
            raise FileNotFoundError("Label image not found or unreadable: %s" % mask_path)

        # HU Windowing
        img = image_windowing(img, self.opt.w_min, self.opt.w_max)

        # Center Crop and MINMAX to [0, 255] and Resize
        img = center_crop(img, self.opt.crop_size)
        mask = center_crop(mask, self.opt.crop_size)

        img = image_minmax(img)

        img = cv2.resize(img, (self.opt.input_size, self.opt.input_size))
        mask = cv2.resize(mask, (self.opt.input_size, self.opt.input_size))

        # MINMAX to [0, 1]
        img = img / 255.

        # Mask Binarization (0 or 1)
        mask = mask_binarization(mask)

        # Add channel axis
        img = img[None, ...].astype(np.float32)
        mask = mask[None, ...].astype(np.float32)

        # Augmentation
        if self.augmentation:
            img, mask = augment_imgs_and_masks(img, mask, self.opt.rot_factor, self.opt.scale_factor, self.opt.trans_factor, self.opt.flip)

        return img, mask

    def __len__(self):
        """Number of DICOM files found for this split."""
        return self.len


In [None]:
# get_dataloader
def get_dataloader(opt):
    """Create the training and validation data loaders.

    The training set is augmented and shuffled; the validation set is
    neither.  Both use ``opt.batch_size`` and ``opt.workers``.

    Returns:
        Tuple ``(train_dataloader, valid_dataloader)``.
    """
    trn_dataset = UrinaryStoneDataset(opt, is_Train=True, augmentation=True)
    val_dataset = UrinaryStoneDataset(opt, is_Train=False, augmentation=False)

    # Settings shared by both loaders.
    common = dict(batch_size=opt.batch_size, num_workers=opt.workers)

    train_dataloader = DataLoader(trn_dataset, shuffle=True, **common)
    valid_dataloader = DataLoader(val_dataset, shuffle=False, **common)

    return train_dataloader, valid_dataloader

# Optimizer & Loss

## Loss 

### Dice loss

In [None]:
class DiceLoss(nn.Module):
    """Dice loss: the mean over channels of ``1 - Dice coefficient``.

    The input is treated as logits and normalised inside the loss (sigmoid
    by default, softmax over dimension 1 otherwise).  Optional per-class
    weights scale the per-channel intersection.

    Args:
        epsilon: lower clamp for the Dice denominator.
        weight: optional per-channel weights (list or tensor).
        ignore_index: target value to mask out of the computation.
        sigmoid_normalization: use ``nn.Sigmoid`` if true, else ``nn.Softmax(dim=1)``.
        skip_last_target: drop the last channel of the target before comparing.
    """

    def __init__(self, epsilon=1e-5, weight=None, ignore_index=None, sigmoid_normalization=True,
                 skip_last_target=False):
        super(DiceLoss, self).__init__()
        if isinstance(weight, list):
            weight = torch.Tensor(weight)

        self.epsilon = epsilon
        # Registered as a buffer so it follows the module across devices.
        self.register_buffer('weight', weight)
        self.ignore_index = ignore_index

        if sigmoid_normalization:
            self.normalization = nn.Sigmoid()
        else:
            self.normalization = nn.Softmax(dim=1)
        # if True skip the last channel in the target
        self.skip_last_target = skip_last_target

    def forward(self, input, target):
        """Return the scalar Dice loss for logits ``input`` against ``target``."""
        # get probabilities from logits
        input = self.normalization(input)

        # Buffers never require grad, so the deprecated Variable wrapper is
        # unnecessary; just make sure the weights sit on the input's device.
        weight = self.weight.to(input.device) if self.weight is not None else None

        if self.skip_last_target:
            target = target[:, :-1, ...]

        per_channel_dice = compute_per_channel_dice(input, target, epsilon=self.epsilon, ignore_index=self.ignore_index, weight=weight)
        # Average the Dice score across all channels/classes
        return torch.mean(1. - per_channel_dice)



In [None]:

def compute_per_channel_dice(input, target, epsilon=1e-5, ignore_index=None, weight=None):
    """Compute one Dice coefficient per channel.

    ``input`` is expected to already hold normalised probabilities and must
    have exactly the same shape as ``target``.

    Args:
        input: predicted probabilities.
        target: ground truth of the same shape.
        epsilon: lower clamp applied to the denominator.
        ignore_index: if given, positions where ``target`` equals this value
            are zeroed in both tensors.
        weight: optional per-channel factor applied to the intersection.

    Returns:
        Tensor with one Dice value per channel.
    """
    assert input.size() == target.size(), "'input' and 'target' must have the same shape"

    if ignore_index is not None:
        # 1 where the target is valid, 0 where it equals ignore_index.
        keep = target.clone().ne_(ignore_index)
        keep.requires_grad = False
        input, target = input * keep, target * keep

    # Channel-major 2-D views: one row per channel.
    flat_input, flat_target = flatten(input), flatten(target)

    overlap = (flat_input * flat_target).sum(-1)
    if weight is not None:
        overlap = weight * overlap

    total = (flat_input + flat_target).sum(-1)
    return 2. * overlap / total.clamp(min=epsilon)

In [None]:
def flatten(tensor):
    """Collapse a tensor to 2-D with the channel axis first.

    Dimensions 0 and 1 are swapped and everything after the channel axis is
    merged, e.g. ``(N, C, D, H, W) -> (C, N * D * H * W)``.
    """
    n_channels = tensor.size(1)
    # Swap batch and channel axes; the remaining axes keep their order.
    channel_first = tensor.transpose(0, 1).contiguous()
    return channel_first.view(n_channels, -1)

### IoU Loss

In [None]:
class IoULoss(nn.Module):
    """Soft IoU (Jaccard) loss; returns the negated mean IoU.

    paper: https://link.springer.com/chapter/10.1007/978-3-319-50835-1_22

    NOTE(review): ``forward`` relies on ``get_tp_fp_fn``, which is not
    defined in the visible part of this file — confirm it is available.
    """

    def __init__(self, apply_nonlin=None, batch_dice=False, do_bg=True, smooth=1.,
                 square=False):
        super(IoULoss, self).__init__()

        self.apply_nonlin = apply_nonlin
        self.batch_dice = batch_dice
        self.do_bg = do_bg
        self.smooth = smooth
        self.square = square

    def forward(self, x, y, loss_mask=None):
        """Return ``-mean(IoU)`` of prediction ``x`` against target ``y``."""
        # Reduce over the spatial axes, plus the batch axis when batch_dice is set.
        spatial_axes = list(range(2, len(x.shape)))
        axes = [0] + spatial_axes if self.batch_dice else spatial_axes

        if self.apply_nonlin is not None:
            x = self.apply_nonlin(x)

        tp, fp, fn = get_tp_fp_fn(x, y, axes, loss_mask, self.square)

        iou = (tp + self.smooth) / (tp + fp + fn + self.smooth)

        if not self.do_bg:
            # Drop the first (background) channel.
            iou = iou[1:] if self.batch_dice else iou[:, 1:]

        return -iou.mean()


In [None]:
def get_loss_function(opt):
    """Return the loss module selected by ``opt.loss`` (case-insensitive).

    Supported names: ``'dice'`` and ``'iou'``.

    Raises:
        ValueError: for any other loss name.
    """
    # Compare in lower case: the previous check compared a lower-cased
    # string against 'IoU', which could never match.
    name = opt.loss.lower()

    if name == 'dice':
        return DiceLoss(sigmoid_normalization=True)
    if name == 'iou':
        # NOTE(review): IoULoss() is built without apply_nonlin, so it
        # receives whatever the network outputs — confirm that is intended.
        return IoULoss()

    raise ValueError("Unsupported loss '%s': choose 'dice' or 'iou'." % opt.loss)

## Optimizer

In [None]:
def get_optimizer(net, opt):
    """Create the optimizer(s) described by ``opt``.

    Args:
        net: a network, or a list of networks (one optimizer is built per
            entry and a list is returned).
        opt: options object; reads ``optim`` ('RMSprop' | 'SGD' | 'Adam',
            case-insensitive), ``lr``, ``momentum``, ``wd`` and
            ``no_bias_decay``.

    Returns:
        A ``torch.optim`` optimizer, or a list of them for a list input.

    Raises:
        ValueError: if ``opt.optim`` names an unsupported optimizer.
    """
    if isinstance(net, list):
        return [get_optimizer(network, opt) for network in net]

    if opt.no_bias_decay:
        # Put bias parameters into their own group without weight decay.
        weight_params = []
        bias_params = []
        for n, p in net.named_parameters():
            if 'bias' in n:
                bias_params.append(p)
            else:
                weight_params.append(p)
        parameters = [{'params' : bias_params, 'weight_decay' : 0},
                      {'params' : weight_params}]
    else:
        parameters = net.parameters()

    name = opt.optim.lower()
    if name == 'rmsprop':
        return optim.RMSprop(parameters, lr=opt.lr, momentum=opt.momentum, weight_decay=opt.wd)
    if name == 'sgd':
        return optim.SGD(parameters, lr=opt.lr, momentum=opt.momentum, weight_decay=opt.wd)
    if name == 'adam':
        # Adam is created with the learning rate only (no opt.wd / opt.momentum).
        return optim.Adam(parameters, lr=opt.lr)

    # Previously an unknown name fell through to an UnboundLocalError.
    raise ValueError("Unsupported optimizer '%s': choose RMSprop, SGD or Adam." % opt.optim)


# ETC

## AverageMeter

In [None]:
class AverageMeter(object):
    """Keep the latest value and a running weighted average of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the latest value, the average, the weighted sum and the count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


## IoU

In [None]:

def iou_modified(preds, labels, opt):
    """Smoothed intersection-over-union per sample.

    Both tensors have dimension 1 squeezed out and are cast to int before
    the bitwise AND / OR; sums run over the remaining dimensions 1 and 2.
    ``opt.iou_smooth`` is added to numerator and denominator.

    Returns:
        One IoU value per sample; with a single sample the leading
        dimension is squeezed away as well.
    """
    smooth = opt.iou_smooth

    pred_bits = preds.squeeze(1).int()
    label_bits = labels.squeeze(1).int()

    # Zero wherever either the mask or the prediction is 0.
    overlap = torch.bitwise_and(pred_bits, label_bits).float().sum((1, 2))
    # Zero only where both are 0.
    combined = torch.bitwise_or(pred_bits, label_bits).float().sum((1, 2))

    scores = (overlap + smooth) / (combined + smooth)
    return scores.squeeze(0)


def avg_precision(iou_list):
    
    thresh1 = 0.5 
    thresh2 =0.75

    # thresh 0.5
    iou_list = np.array(iou_list)
    iou_list_thresh1= np.where(iou_list > thresh1, 1, 0)
    
    # thresh 0.75
    iou_list_thresh2 = np.where(iou_list > thresh2, 1, 0)
    
    prec_thresh1 = np.sum(iou_list_thresh1) / len(iou_list_thresh1)
    prec_thresh2 = np.sum(iou_list_thresh2) / len(iou_list_thresh2)
    
    iou_mean = (prec_thresh1 + prec_thresh2) / 2.
    
    return prec_thresh1, prec_thresh2, iou_mean

## dice_coef

In [None]:

class DiceCoef(nn.Module):
    """Dice coefficient metric.

    Args:
        epsilon: lower clamp passed to the per-channel Dice computation.
        return_score_per_channel: return one value per channel instead of
            their mean.
    """

    def __init__(self, epsilon=1e-5, return_score_per_channel=False):
        super(DiceCoef, self).__init__()
        self.epsilon = epsilon
        self.return_score_per_channel = return_score_per_channel

    def forward(self, input, target):
        """Return the Dice score(s) of ``input`` against ``target``."""
        scores = compute_per_channel_dice(input, target, epsilon=self.epsilon)

        if not self.return_score_per_channel:
            scores = torch.mean(scores)
        return scores

## Learning rate

In [None]:
def get_current_lr(optimizer):
  """Return the learning rate of the optimizer's first parameter group."""
  first_group = optimizer.state_dict()['param_groups'][0]
  return first_group['lr']


def lr_update(epoch, opt, optimizer):
  """Apply the per-epoch learning-rate schedule in place.

  During the first ``opt.lr_warmup_epoch`` epochs every parameter group's
  rate is multiplied by ``10 ** (1 / lr_warmup_epoch)``.  Afterwards, when
  ``epoch + 1`` is listed in ``opt.lr_decay_epoch``, every group's rate is
  multiplied by 0.1.

  Args:
      epoch: zero-based index of the current epoch.
      opt: options object; reads ``lr_warmup_epoch`` and ``lr_decay_epoch``.
      optimizer: optimizer whose ``param_groups`` are updated.
  """
  # The first group's rate is used for the progress messages.
  prev_lr = optimizer.param_groups[0]['lr']

  if 0 <= epoch < opt.lr_warmup_epoch:
    mul_rate = 10 ** (1/opt.lr_warmup_epoch)

    for param_group in optimizer.param_groups:
        param_group['lr'] *= mul_rate

    current_lr = optimizer.param_groups[0]['lr']
    print("LR warm-up : %.7f to %.7f" % (prev_lr, current_lr))

  # NOTE(review): the command-line option is declared as a comma-separated
  # string; decay only happens once it has been converted to a list —
  # confirm the option parser does that.
  elif isinstance(opt.lr_decay_epoch, list) and (epoch+1) in opt.lr_decay_epoch:
    # Scale each group by its own rate (instead of overwriting all groups
    # with the first group's value) and report once, not once per group.
    for param_group in optimizer.param_groups:
      param_group['lr'] *= 0.1
    print("LR Decay : %.7f to %.7f" % (prev_lr, prev_lr * 0.1))

# Core

## Train

In [None]:
def train(net, dataset_trn, optimizer, criterion, epoch, opt,train_writer):
    """Run one training epoch and log the epoch averages.

    Args:
        net: network to optimise (put into train mode here).
        dataset_trn: iterable of ``(img, mask)`` batches that supports ``len()``.
        optimizer: optimizer stepping ``net``'s parameters.
        criterion: loss called as ``criterion(logits, mask)``.
        epoch: zero-based epoch index (used for printing/logging).
        opt: options object; reads ``use_gpu``, ``max_epoch`` and what
            ``iou_modified`` needs.
        train_writer: summary writer exposing ``add_scalar``.
    """
    print("Start Training...")
    net.train()

    losses, total_dices, total_iou = AverageMeter(), AverageMeter(), AverageMeter()
    dice_metric = DiceCoef(return_score_per_channel=False)

    for it, (img, mask) in enumerate(dataset_trn):
        # Optimizer
        optimizer.zero_grad()

        # Load Data
        img = torch.as_tensor(img, dtype=torch.float32)
        mask = torch.as_tensor(mask, dtype=torch.float32)
        if opt.use_gpu:
            img, mask = img.cuda(non_blocking=True), mask.cuda(non_blocking=True)

        # Predict
        pred = net(img)

        # Loss Calculation
        loss = criterion(pred, mask)

        # Backward and step
        loss.backward()
        optimizer.step()

        # Metrics only: no autograd graph needed.
        with torch.no_grad():
            prob = pred.sigmoid()

            # Calculation Dice Coef Score
            dice = dice_metric(prob, mask)

            # Convert to Binary on the tensor's own device, so this also
            # works without CUDA and avoids a CPU round-trip.
            binary = (prob > 0.5).float()

            # Calculation IoU Score
            iou_score = iou_modified(binary, mask, opt)

        # Stack Results
        total_dices.update(dice.item(), img.size(0))
        total_iou.update(iou_score.mean().item(), img.size(0))
        losses.update(loss.item(), img.size(0))

        if (it==0) or (it+1) % 10 == 0:
            print('Epoch[%3d/%3d] | Iter[%3d/%3d] | Loss %.4f | Dice %.4f | Iou %.4f'
                % (epoch+1, opt.max_epoch, it+1, len(dataset_trn), losses.avg, total_dices.avg, total_iou.avg))

    print(">>> Epoch[%3d/%3d] | Training Loss : %.4f | Dice %.4f | Iou %.4f\n "
        % (epoch+1, opt.max_epoch, losses.avg, total_dices.avg, total_iou.avg))

    train_writer.add_scalar("train/loss", losses.avg, epoch+1)
    train_writer.add_scalar("train/dice", total_dices.avg, epoch+1)
    train_writer.add_scalar("train/IoU", total_iou.avg, epoch+1)

In [None]:

def validate(dataset_val, net, criterion, epoch, opt, best_iou, best_epoch,train_writer):
    """Evaluate on the validation set, log metrics and checkpoint on improvement.

    A checkpoint is written to ``opt.exp`` whenever the average IoU exceeds
    ``best_iou``.

    Args:
        dataset_val: iterable of ``(img, mask)`` batches.
        net: network to evaluate (put into eval mode here).
        criterion: loss called as ``criterion(logits, mask)``.
        epoch: zero-based epoch index.
        opt: options object; reads ``use_gpu``, ``max_epoch``, ``exp`` and
            what ``iou_modified`` needs.
        best_iou: best average IoU seen so far.
        best_epoch: epoch index at which ``best_iou`` was reached.
        train_writer: summary writer exposing ``add_scalar``.

    Returns:
        Tuple ``(best_iou, best_epoch)``, updated if this epoch improved.
    """
    print("Start Evaluation...")
    net.eval()

    # Result containers
    losses, total_dices, total_iou = AverageMeter(), AverageMeter(), AverageMeter()
    dice_metric = DiceCoef(return_score_per_channel=False)

    # No gradients are needed for validation; skipping the autograd graph
    # saves memory and time.
    with torch.no_grad():
        for it, (img, mask) in enumerate(dataset_val):
            # Load Data
            img = torch.as_tensor(img, dtype=torch.float32)
            mask = torch.as_tensor(mask, dtype=torch.float32)
            if opt.use_gpu:
                img, mask = img.cuda(non_blocking=True), mask.cuda(non_blocking=True)

            # Predict
            pred = net(img)

            # Loss Calculation
            loss = criterion(pred, mask)

            prob = pred.sigmoid()

            # Calculation Dice Coef Score
            dice = dice_metric(prob, mask)
            total_dices.update(dice.item(), img.size(0))

            # Convert to Binary on the tensor's own device (works without CUDA).
            binary = (prob > 0.5).float()

            # Calculation IoU Score
            iou_score = iou_modified(binary, mask, opt)
            total_iou.update(iou_score.mean().item(), img.size(0))

            # Stack Results
            losses.update(loss.item(), img.size(0))

    print(">>> Epoch[%3d/%3d] | Test Loss : %.4f | Dice %.4f | Iou %.4f"
        % (epoch+1, opt.max_epoch, losses.avg, total_dices.avg, total_iou.avg))

    train_writer.add_scalar("valid/loss", losses.avg, epoch+1)
    train_writer.add_scalar("valid/dice", total_dices.avg, epoch+1)
    train_writer.add_scalar("valid/IoU", total_iou.avg, epoch+1)

    # Update Result
    if total_iou.avg > best_iou:
        print('Best Score Updated...')
        best_iou = total_iou.avg
        best_epoch = epoch

        model_filename = '%s/epoch_%04d_iou_%.4f_loss_%.8f.pth' % (opt.exp, epoch+1, best_iou, losses.avg)

        # Save the unwrapped weights.  Checking the wrapper type (instead of
        # opt.ngpu) also covers nets that are not wrapped although
        # opt.ngpu != 1, e.g. CPU runs.
        to_save = net.module if isinstance(net, torch.nn.DataParallel) else net
        torch.save(to_save.state_dict(), model_filename)

    print('>>> Current best: IoU: %.8f in %3d epoch\n' % (best_iou, best_epoch+1))

    return best_iou, best_epoch

In [None]:
def evaluate(dataset_val, net, opt, save_dir):
    """Score the network sample by sample and save figures for weak cases.

    For every batch the Dice score and IoU are computed at threshold
    ``opt.threshold``.  When the IoU is below 0.75 a three-panel figure
    (image, image + label contour, image + prediction contour) is written
    to ``opt.exp/opt.save_dir``.  Finally the precision at IoU 0.5 / 0.75
    and their mean are printed.

    The per-batch ``.item()`` calls require a single IoU value per batch,
    i.e. the loader must yield one sample at a time.

    Args:
        dataset_val: iterable of ``(img, mask)`` batches that supports ``len()``.
        net: network to evaluate (put into eval mode here).
        opt: options object; reads ``use_gpu``, ``threshold``, ``exp``,
            ``save_dir`` and what ``iou_modified`` needs.
        save_dir: NOTE(review): currently unused — figures go to
            ``opt.save_dir``; confirm which one callers expect.
    """
    print("Start Evaluation...")
    net.eval()

    dice_metric = DiceCoef(return_score_per_channel=False)
    iou_scores = []
    for idx, (img, mask) in enumerate(dataset_val):
        # Load Data
        img = torch.as_tensor(img, dtype=torch.float32)
        if opt.use_gpu:
            img = img.cuda(non_blocking=True)

        # Predict
        with torch.no_grad():
            prob = net(img).sigmoid()
            # Follow the prediction's device instead of forcing CUDA.
            mask = torch.as_tensor(mask).to(prob.device)

            dice = dice_metric(prob, mask)

            # Convert to Binary
            y = (prob > opt.threshold).float()

            iou_score = iou_modified(y, mask, opt)

        if idx%10 ==0:
            print("{}/{} - dice {:.4f} | IoU {:.4f}".format(idx+1, len(dataset_val), dice.item(), iou_score.item()))

        iou_scores.append(iou_score.item())

        if iou_score < 0.75:
            ###### Plot & Save Figure #########
            origin = img.cpu().numpy()[0,0,:,:]
            pred = y.cpu().numpy()[0,0,:,:]
            true = mask.cpu().numpy()[0,0,:,:]

            fig = plt.figure()

            ax1 = fig.add_subplot(1,3,1)
            ax1.axis("off")
            ax1.imshow(origin, cmap = "gray")

            ax2= fig.add_subplot(1,3,2)
            ax2.axis("off")
            ax2.imshow(origin,cmap = "gray")
            ax2.contour(true, cmap='Greens', linewidths=0.3)

            ax3 = fig.add_subplot(1,3,3)
            ax3.axis("off")
            ax3.imshow(origin,cmap = "gray")
            ax3.contour(pred, cmap='Reds', linewidths=0.3)

            plt.axis('off')
            plt.subplots_adjust(left = 0, bottom = 0, right = 1, top = 1, hspace = 0, wspace = 0)

            # 'bbox_inches' was misspelled as 'bbox_inces', which savefig
            # does not recognise.
            plt.savefig(opt.exp + "/" + opt.save_dir + "/original_label_pred_image_file_{}_dice_{:.4f}_iou_{:.4f}.png".format(idx, dice.item(),iou_score.item()),bbox_inches='tight', dpi=300)
            plt.cla()
            plt.close(fig)
            plt.gray()
            ###############################

    prec_thresh1, prec_thresh2, iou_mean = avg_precision(iou_scores)

    print("Presion with threshold 0.5: {}, 0.75: {}, Average: {}".format(prec_thresh1, prec_thresh2, iou_mean))

# Main code

## Options

In [None]:
def str2bool(v):
    """Parse a command-line flag value into a bool.

    Booleans pass through unchanged.  Strings are compared
    case-insensitively against the accepted true/false spellings.

    Raises:
        argparse.ArgumentTypeError: if the string is not a recognised spelling.
    """
    if isinstance(v, bool):
        return v

    token = v.lower()
    if token in ('yes', 'true', 't', 'y', '1'):
        return True
    if token in ('no', 'false', 'f', 'n', '0'):
        return False

    raise argparse.ArgumentTypeError('Boolean value expected.')

def parse_option(print_option=True):
    """Build the experiment options and prepare the runtime environment.

    The parser is run on an empty argument list, so every option takes its
    declared default (presumably so that the notebook kernel's own
    ``sys.argv`` is never parsed). To change a setting, edit the defaults
    below or modify the returned namespace.

    Side effects:
        * Creates the checkpoint directory ``opt.exp`` if it is missing.
        * Sets the ``CUDA_DEVICE_ORDER`` and ``CUDA_VISIBLE_DEVICES``
          environment variables.

    Args:
        print_option: When true, print a summary of the main options.

    Returns:
        argparse.Namespace: The parsed options, extended with ``ngpu``
        (number of comma-separated GPU ids, or the string ``'cpu'`` when
        ``use_gpu`` is false). ``lr_decay_epoch`` is a list of ints when the
        option contains a comma and is left as the raw string otherwise.
    """
    p = argparse.ArgumentParser(description='')

    # Data Directory
    p.add_argument('--data_root', default='../DataSet', type=str, help='root directory of dataset files.')

    # Data augmentation
    p.add_argument('--rot_factor', default=30, type=float)
    p.add_argument('--scale_factor', default=0.15, type=float)
    p.add_argument('--flip', default='True', type=str2bool)
    p.add_argument('--trans_factor', default=0.1, type=float)

    # Input image
    # NOTE: argparse applies `type` only to string defaults, so the int
    # default below stays an int even though `type=float` is declared.
    p.add_argument('--crop_size', default=300, type=float, help='Center crop width')
    p.add_argument('--input_size', default=224, type=int, help='input resolution using resize process')
    p.add_argument('--w_min', default=-100., type=float, help='Min value of HU Windowing')
    p.add_argument('--w_max', default=300., type=float, help='Max value of HU Windowing')

    # Network
    p.add_argument('--base_n_filter', default=32, type=int)

    # Optimizer
    p.add_argument('--optim', default='Adam', type=str, help='RMSprop | SGD | Adam')
    p.add_argument('--lr', default=2e-5, type=float)
    p.add_argument('--lr_decay_epoch', default='150', type=str, help="decay epochs with comma (ex - '20,40,60')")
    p.add_argument('--lr_warmup_epoch', default=0, type=int)
    p.add_argument('--momentum', default=0.99, type=float, help='momentum')
    p.add_argument('--wd', default=1e-4, type=float, help='weight decay')
    p.add_argument('--no_bias_decay', default='True', type=str2bool, help='weight decay for bias')

    # Hyper-parameter
    p.add_argument('--batch_size', default=16, type=int, help='use 1 batch size in 3D training.')
    p.add_argument('--start_epoch', default=0, type=int)
    p.add_argument('--max_epoch', default=300, type=int)
    p.add_argument('--threshold', default=0.9, type=float)

    # Loss function
    p.add_argument('--loss', default='dice', type=str)
    p.add_argument('--iou_smooth', default=1e-6, type=float, help='avoid 0/0')

    # Resume trained network
    p.add_argument('--resume', default='', type=str, help="pth file path to resume")

    # Resource option
    p.add_argument('--workers', default=10, type=int, help='#data-loading worker-processes')
    p.add_argument('--use_gpu', default="True", type=str2bool, help='use gpu or not (cpu only)')
    p.add_argument('--gpu_id', default="3", type=str)

    # Output directory
    p.add_argument('--exp', default='./ckpt_crop_300', type=str, help='checkpoint dir.')
    p.add_argument('--save_dir', default='plots', type=str, help='evaluation plot directory')

    # Parse an explicit empty argument list: only the defaults above apply.
    opt = p.parse_args([])

    # Make output directory (exist_ok avoids the check-then-create race).
    os.makedirs(opt.exp, exist_ok=True)

    # GPU Setting
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu_id

    if opt.use_gpu:
        opt.ngpu = len(opt.gpu_id.split(","))
    else:
        opt.gpu_id = 'cpu'
        opt.ngpu = 'cpu'

    # lr decay setting: only a comma-separated value is converted to a list
    # of ints; a single epoch (no comma) is left as the original string.
    if ',' in opt.lr_decay_epoch:
        opt.lr_decay_epoch = opt.lr_decay_epoch.split(',')
        opt.lr_decay_epoch = [int(epoch) for epoch in opt.lr_decay_epoch]

    if print_option:
        print("\n==================================== Options ====================================\n")

        print('   Data root : %s' % (opt.data_root))
        print()
        print('   Data Crop size : Crop to (%d,%d)' % (opt.crop_size,opt.crop_size))
        print('   Data input size : Resized to (%d,%d)' % (opt.input_size,opt.input_size))
        print()
        print('   Base #Filters of Network : %d' % (opt.base_n_filter))
        print()
        print('   Optimizer : %s (weight decay %f)' % (opt.optim, opt.wd))
        print('   Loss function : %s' % opt.loss)
        print('   Batch size : %d' % opt.batch_size)
        print('   Max epoch : %d' % opt.max_epoch)
        print('   Learning rate : %s (linear warm-up until %s / decay at %s)' % (opt.lr, opt.lr_warmup_epoch, opt.lr_decay_epoch))
        print()
        print('   Resume pre-trained weights path : %s' % opt.resume)
        print('   Output dir : %s' % opt.exp)
        print()
        print('   GPU ID : %s' % opt.gpu_id)
        print('   #Workers : %s' % opt.workers)
        print('   pytorch version: %s (CUDA : %s)' % (torch.__version__, torch.cuda.is_available()))
        print("\n=================================================================================\n")

    return opt


## Train

In [None]:
# Training driver: builds every component from the options, then runs the
# train / validate / learning-rate-update cycle once per epoch.

# Option
opt = parse_option(print_option=True)

# Data Loader
dataset_trn, dataset_val = get_dataloader(opt)

# Network
net = create_model(opt)

# Loss Function
criterion = get_loss_function(opt)

# Optimizer
optimizer = get_optimizer(net, opt)

# Tensorboard
train_writer = SummaryWriter(opt.exp+'/logs/')

# Initial Best Score
best_iou, best_epoch = 0, 0

try:
    for epoch in range(opt.start_epoch, opt.max_epoch):
        # Train
        train(net, dataset_trn, optimizer, criterion, epoch, opt,train_writer)

        # Evaluate; validate() hands back the running best IoU and its epoch
        best_iou, best_epoch = validate(dataset_val, net, criterion, epoch, opt, best_iou, best_epoch,train_writer)

        # Per-epoch learning-rate update
        lr_update(epoch, opt, optimizer)
finally:
    # Close the writer even when training fails or is interrupted, so
    # buffered events are flushed and the log file handle is released.
    train_writer.close()

print('Training done')

## Evaluation

In [None]:
# Evaluation driver: rebuilds the options, validation data and network, then
# runs evaluate() on the validation set.

# Option (print_option=False skips the printed options summary)
opt = parse_option(print_option=False)

# Data Loader
# get_dataloader returns a (train, validation) pair; only the validation
# part is needed here.
_, dataset_val = get_dataloader(opt)

# Network
net = create_model(opt)

# Evaluate
# NOTE(review): this hard-coded directory differs from the default of --exp
# ('./ckpt_crop_300'); confirm it is the directory evaluate() is meant to use.
save_dir = "./ckpt_beta_100_300"
evaluate(dataset_val, net, opt, save_dir)