In [None]:
import os
import numpy as np
import glob
import nibabel as nib
from skimage.util import random_noise

from fastai.vision import *
from fastai.callbacks import *
from fastai.callbacks.hooks import *
from fastai.utils.mem import *


In [None]:
def random_seed(seed_value, use_cuda):
    """Seed every RNG used in training so runs are reproducible.

    :param seed_value: integer seed applied to numpy, torch (CPU) and Python's ``random``.
    :param use_cuda: when truthy, also seed all CUDA devices and force cuDNN into
        deterministic, non-benchmark mode.
    """
    # CPU-side generators, seeded in the same order as before: numpy, torch, stdlib random.
    for seeder in (np.random.seed, torch.manual_seed, random.seed):
        seeder(seed_value)
    if not use_cuda:
        return
    # GPU-side generators plus the cuDNN switches needed for deterministic kernels.
    torch.cuda.manual_seed(seed_value)
    torch.cuda.manual_seed_all(seed_value)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [None]:
def acc_camvid(input, target):
    """Pixel accuracy that ignores pixels whose label equals the notebook-level ``void_code``.

    :param input: per-class scores with the class axis at dim 1.
    :param target: integer mask with a singleton axis at dim 1.
    :return: 0-dim float tensor; fraction of non-void pixels whose argmax class is correct.
    """
    labels = target.squeeze(1)
    keep = labels != void_code  # relies on the global ``void_code`` defined later in the notebook
    predicted = input.argmax(dim=1)
    hits = predicted[keep] == labels[keep]
    return hits.float().mean()

In [None]:
def dice(input, target):
    """Mean soft Dice coefficient over all classes (used as a training metric).

    Softmax is taken over dim 1 of ``input``. ``target`` is one-hot encoded along
    the same axis, and per-class overlap is summed over the batch axis and the two
    trailing axes. Every class, including the void class, counts equally in the mean.

    :param input: N x C x H x W class scores.
    :param target: N x 1 x H x W integer class mask.
    :return: 0-dim tensor in [0, 1].
    """
    smooth = 0.0001
    probs = F.softmax(input, dim=1)
    labels = torch.squeeze(target, dim=1)

    # One-hot target with the same shape/dtype/device as ``probs``, kept out of the graph.
    one_hot = probs.detach().mul(0).scatter_(1, labels.unsqueeze(1), 1)

    # Reduce over batch (0) then both spatial axes, leaving one value per class.
    overlap = (probs * one_hot).sum(0).sum(1).sum(1)
    total = (probs + one_hot).sum(0).sum(1).sum(1) + smooth
    per_class = 2 * overlap / total
    return per_class.sum() / probs.size(1)

In [None]:
class DiceLoss(nn.Module):
    """
    Soft Dice loss (1 - Dice) for a batch of samples, averaged over classes.
    """
    def __init__(self):
        super().__init__()

    def forward(self, output, target):
        """
        Forward pass
        :param output: NxCxHxW logits (softmax is applied over C)
        :param target: NxHxW or Nx1xHxW integer class mask
        :return: 0-dim torch.tensor, mean over classes of (1 - Dice)
        """
        output = F.softmax(output, dim=1)
        return self._dice_loss_multichannel(output, target)

    @staticmethod
    def _dice_loss_multichannel(output, target):
        """
        Dice loss on class probabilities.
        :param output: NxCxHxW probabilities
        :param target: NxHxW or Nx1xHxW integer class mask
        :return: 0-dim torch.tensor, mean over classes of (1 - Dice)
        """
        # Drop only the singleton channel axis. A bare ``torch.squeeze`` would also drop
        # the batch axis when N == 1 (or a spatial axis of size 1), after which
        # ``scatter_`` fails because index and output ranks no longer match.
        if target.dim() == output.dim():
            target = target.squeeze(1)
        # scatter_ requires an int64 index tensor.
        target = target.long()
        eps = 0.0001

        # One-hot target with output's shape/dtype/device; detach keeps it out of the graph.
        encoded_target = output.detach() * 0
        encoded_target = encoded_target.scatter_(1, target.unsqueeze(1), 1)

        # Reduce over the batch axis then both spatial axes, keeping one value per class.
        intersection = output * encoded_target
        numerator = 2 * intersection.sum(0).sum(1).sum(1)
        denominator = (output + encoded_target).sum(0).sum(1).sum(1) + eps
        loss_per_channel = 1 - (numerator / denominator)

        return loss_per_channel.sum() / output.size(1)

In [None]:
class CrossEntropyLoss2d(nn.Module):
    """
    Unweighted pixel-wise cross entropy for segmentation, wrapping nn.CrossEntropyLoss.
    """
    def __init__(self):
        super(CrossEntropyLoss2d, self).__init__()
        self.nll_loss = nn.CrossEntropyLoss()

    def forward(self, inputs, targets):
        """
        Forward pass

        :param inputs: torch.tensor (NxCxHxW) class scores
        :param targets: torch.tensor (NxHxW or Nx1xHxW) integer class mask
        :return: scalar, mean cross entropy over all pixels
        """
        # Remove only the channel axis; a bare squeeze would also drop N when N == 1
        # (or a spatial axis of size 1) and break the shape match with ``inputs``.
        if targets.dim() == inputs.dim():
            targets = targets.squeeze(1)
        # nn.CrossEntropyLoss expects int64 class indices.
        return self.nll_loss(inputs, targets.long())

In [None]:
class CombinedLoss(nn.Module):
    """
    Sum of soft Dice loss and (optionally boundary-weighted) pixel-wise cross entropy.
    """

    def __init__(self):
        super(CombinedLoss, self).__init__()
        self.cross_entropy_loss = CrossEntropyLoss2d()
        self.dice_loss = DiceLoss()

    def forward(self, input, target, weight=True):
        """
        Forward pass

        :param input: torch.tensor (NxCxHxW) class scores
        :param target: torch.tensor (NxHxW or Nx1xHxW) integer class mask
        :param weight: per-pixel cross-entropy weighting.
            ``True`` (default) -> map from ``rtn_weight(target)`` (label boundaries up-weighted);
            a tensor (NxHxW or Nx1xHxW) -> used as given;
            ``None`` / ``False`` -> plain unweighted cross entropy.
        :return: scalar, cross-entropy term + dice term
        """
        y_2 = self.dice_loss(input, target)

        if weight is None or weight is False:
            y_1 = self.cross_entropy_loss(input, target)
        else:
            # Class-index map N x H x W: drop only the channel axis so N == 1 still works.
            labels = target
            if labels.dim() == input.dim():
                labels = labels.squeeze(1)
            labels = labels.long()

            if weight is True:
                weight = self.rtn_weight(labels)
            else:
                weight = torch.as_tensor(weight)
                if weight.dim() == input.dim():
                    weight = weight.squeeze(1)

            # Per-pixel CE (reduction='none') so the map really weights individual pixels;
            # weighting an already-averaged scalar CE only rescales the whole loss.
            pixel_ce = F.cross_entropy(input, labels, reduction='none')
            weight = weight.to(device=pixel_ce.device, dtype=pixel_ce.dtype)
            y_1 = torch.mean(pixel_ce * weight)
        return y_1 + y_2

    def rtn_weight(self, labels, edge_weight=2.0):
        """
        Per-pixel weight map that up-weights label boundaries.

        A pixel counts as a boundary pixel when the numerical gradient of ``labels``
        (``np.gradient``) is non-zero along either of the last two (spatial) axes.
        The batch axis is never differentiated.

        :param labels: integer label tensor whose last two axes are H and W
        :param edge_weight: extra weight added on boundary pixels
        :return: float32 tensor shaped like ``labels`` on the same device as ``labels``;
            ``1 + edge_weight`` on boundary pixels and ``1`` elsewhere
        """
        device = labels.device
        lbl = labels.detach().cpu().numpy().astype(np.float64)
        edges = np.zeros(lbl.shape, dtype=bool)
        for axis in (-2, -1):
            # np.gradient needs at least two samples along the axis.
            if lbl.ndim >= -axis and lbl.shape[axis] >= 2:
                edges |= np.gradient(lbl, axis=axis) != 0
        weights = 1.0 + edge_weight * edges
        return torch.from_numpy(weights.astype(np.float32)).to(device)

In [None]:
def r_noise(x, use_on_y=True):
    """Add skimage ``random_noise`` (default settings) to an image and return a float32 tensor.

    :param x: image data handed in by fastai's ``TfmPixel``.
    :param use_on_y: accepted for signature compatibility; not used by this function.
    :return: float32 CPU torch tensor holding the noisy image.
    """
    noisy = random_noise(x)
    return torch.from_numpy(noisy).float()

## Dataset paths

In [None]:
# Seed every RNG before any data is loaded, split or augmented.
seed = 42
random_seed(seed, True)

# Dataset locations, relative to the notebook's working directory.
path = Path('./')
path_img = Path('./img(random)_10000_3ch')  # input images
path_lbl = Path('./gt(random)_10000_3ch')   # label masks, named <stem>_P<suffix> (see get_y_fn)
fnames = get_image_files(path_img)
lbl_names = get_image_files(path_lbl)
# Fail early with a clear message instead of an IndexError at ``fnames[0]`` further down.
assert len(fnames) > 0, f"no images found in {path_img.resolve()}"
assert len(lbl_names) > 0, f"no label masks found in {path_lbl.resolve()}"

print(f"fnames : {fnames[:3]}, label names : {lbl_names[:3]}")

### Inspect a sample image and its mask

In [None]:
def get_y_fn(x):
    """Map an image path to its mask path: ``path_lbl/<stem>_P<suffix>``."""
    return path_lbl/f'{x.stem}_P{x.suffix}'

# Show the first image next to its mask as a visual sanity check.
img_f = fnames[0]
img = open_image(img_f)
img.show(figsize=(5,5), cmap='gray')
mask = open_mask(get_y_fn(img_f))
mask.show(figsize=(5,5), alpha=1)

# Drop the leading axis of the mask shape (presumably the single mask channel) to get (H, W).
src_size = np.array(mask.shape[1:])
print(f"image size : {src_size}")

### Label Codes

In [None]:
# Class names in label-index order (index 0 is the 'Void' class).
codes = np.array(
    ['Void', 'Fat', 'Muscle', 'Visceral_fat'],
    dtype=str,
)
codes

In [None]:
# Lookup from class name to its integer label; void_code marks pixels acc_camvid ignores.
name2id = dict(zip(codes, range(len(codes))))
void_code = name2id['Void']

### Data augmentation (including a random-noise transform)

In [None]:
# Custom pixel-level noise transform built from r_noise.
rn = TfmPixel(r_noise)
tfms = get_transforms(flip_vert=True, max_rotate=180.0, max_zoom=1.5, max_warp=0.2)

# Configure the noise transform object itself instead of reaching into the list by a
# hard-coded position (new_tfms[0][7]). That position silently targets the wrong
# transform, or raises IndexError, whenever the get_transforms arguments change.
noise_tfm = rn()
noise_tfm.use_on_y = False  # add noise to the image only, never to the mask
noise_tfm.p = 0.5           # probability of applying the noise transform
new_tfms = (tfms[0] + [noise_tfm], tfms[1])
size = src_size

### Choose the batch size from free GPU memory

In [None]:
# The largest usable batch size depends on free GPU RAM (reported in MB).
MIN_FREE_MB_FOR_BS4 = 8200
free = gpu_mem_get_free_no_cache()
bs = 4 if free > MIN_FREE_MB_FOR_BS4 else 2
print(f"using bs={bs}, have {free}MB of GPU RAM free")

### Define DataLoaders

In [None]:
# Hold out 10% of the images for validation. Passing ``seed`` fixes the split itself.
# Re-running this cell, or running other cells first, therefore cannot reshuffle
# train/valid and leak validation images into training. (fastai re-seeds np.random
# with this value before drawing the split.)
src = (SegmentationItemList.from_folder(path_img)
       .split_by_rand_pct(valid_pct=0.1, seed=seed)
       .label_from_func(get_y_fn, classes=codes))
data = (src.transform(new_tfms, size=size, tfm_y=True)
        .databunch(bs=bs, num_workers=0)
        .normalize(imagenet_stats))

### Train the model

In [None]:
loss_func = CombinedLoss
metrics = [dice, acc_camvid]
wd = 1e-2

# Pass wd explicitly; previously it was defined here but never handed to the learner.
learn = unet_learner(data, models.resnet34, loss_func=loss_func(), metrics=metrics, wd=wd)
lr_find(learn)
learn.recorder.plot()
lr = 3e-4  # presumably picked by inspecting the LR-finder plot above

In [None]:
# Display fastai's summary of the U-Net built above.
learn.summary()

In [None]:
# One-cycle schedule whose peak learning rate is the ``lr`` chosen above.
n_epochs = 10
learn.fit_one_cycle(n_epochs, max_lr=lr)

### Save the model

In [None]:
# "path - " was an unfilled placeholder; save the weights under a descriptive name
# (fastai writes it into the learner's model directory).
learn.save(f"unet_resnet34_combinedloss_seed{seed}")