In [1]:
!git clone "https://github.com/rezazad68/smunet.git"

Cloning into 'smunet'...
remote: Enumerating objects: 41, done.[K
remote: Counting objects: 100% (41/41), done.[K
remote: Compressing objects: 100% (41/41), done.[K
remote: Total 41 (delta 9), reused 0 (delta 0), pack-reused 0[K
Receiving objects: 100% (41/41), 650.61 KiB | 14.46 MiB/s, done.
Resolving deltas: 100% (9/9), done.


In [2]:
!pip install opendatasets --upgrade --quiet

In [3]:
# get the data from
import opendatasets as od
dataset_url = 'https://www.kaggle.com/datasets/sanglequang/brats2018'
od.download(dataset_url)

# https://www.kaggle.com/datasets/sanglequang/brats2018
# https://www.kaggle.com/datasets/awsaf49/brats20-dataset-training-validation


Please provide your Kaggle credentials to download this dataset. Learn more: http://bit.ly/kaggle-creds
Your Kaggle username: iamsanaullah
Your Kaggle Key: ··········
Downloading brats2018.zip to ./brats2018


100%|██████████| 3.18G/3.18G [00:35<00:00, 95.9MB/s]





data.py

In [4]:
#!/usr/bin/env python3
# encoding: utf-8
# Code modified from https://github.com/Wangyixinxin/ACN
import glob
import os
import numpy as np
import nibabel as nib
import torch
from torch.utils.data import Dataset, DataLoader
import random

class Brats2018(Dataset):

    def __init__(self, patients_dir, crop_size, modes, train=True, normalization = True):
        self.patients_dir = patients_dir
        self.modes = modes
        self.train = train
        self.crop_size = crop_size
        self.normalization = normalization

    def __len__(self):
        return len(self.patients_dir)

    def __getitem__(self, index):
        patient_dir = self.patients_dir[index]
        volumes = []
        modes = list(self.modes) + ['seg']
        for mode in modes:
            patient_id = os.path.split(patient_dir)[-1]
            volume_path = os.path.join(patient_dir, patient_id + "_" + mode + '.nii')
            volume = nib.load(volume_path).get_data()
            if not mode == "seg" and self.normalization:
                volume = self.normlize(volume)  # [0, 1.0]
            volumes.append(volume)                  # [h, w, d]
        seg_volume = volumes[-1]
        volumes = volumes[:-1]
        volume, seg_volume = self.aug_sample(volumes, seg_volume)
        ed_volume = (seg_volume == 2) # peritumoral edema ED
        net_volume = (seg_volume == 1) # enhancing tumor core NET
        et_volume = (seg_volume == 4) # enhancing tumor ET
        bg_volume = (seg_volume == 0)

        seg_volume = [ed_volume, net_volume, et_volume, bg_volume]
        seg_volume = np.concatenate(seg_volume, axis=0).astype("float32")

        return (torch.tensor(volume.copy(), dtype=torch.float),
                torch.tensor(seg_volume.copy(), dtype=torch.float))


    def aug_sample(self, volumes, mask):
        """
            Args:
                volumes: list of array, [h, w, d]
                mask: array [h, w, d], segmentation volume
            Ret: x, y: [channel, h, w, d]

        """
        x = np.stack(volumes, axis=0)       # [N, H, W, D]
        y = np.expand_dims(mask, axis=0)    # [channel, h, w, d]

        if self.train:
            # crop volume
            x, y = self.random_crop(x, y)
            if random.random() < 0.5:
                x = np.flip(x, axis=1)
                y = np.flip(y, axis=1)
            if random.random() < 0.5:
                x = np.flip(x, axis=2)
                y = np.flip(y, axis=2)
            if random.random() < 0.5:
                x = np.flip(x, axis=3)
                y = np.flip(y, axis=3)
        else:
            x, y = self.center_crop(x, y)

        return x, y

    def random_crop(self, x, y):
        """
        Args:
            x: 4d array, [channel, h, w, d]
        """
        crop_size = self.crop_size
        height, width, depth = x.shape[-3:]
        sx = random.randint(0, height - crop_size[0] - 1)
        sy = random.randint(0, width - crop_size[1] - 1)
        sz = random.randint(0, depth - crop_size[2] - 1)
        crop_volume = x[:, sx:sx + crop_size[0], sy:sy + crop_size[1], sz:sz + crop_size[2]]
        crop_seg = y[:, sx:sx + crop_size[0], sy:sy + crop_size[1], sz:sz + crop_size[2]]

        return crop_volume, crop_seg

    def center_crop(self, x, y):
        crop_size = self.crop_size
        height, width, depth = x.shape[-3:]
        sx = (height - crop_size[0] - 1) // 2
        sy = (width - crop_size[1] - 1) // 2
        sz = (depth - crop_size[2] - 1) // 2
        crop_volume = x[:, sx:sx + crop_size[0], sy:sy + crop_size[1], sz:sz + crop_size[2]]
        crop_seg = y[:, sx:sx + crop_size[0], sy:sy + crop_size[1], sz:sz + crop_size[2]]

        return crop_volume, crop_seg

    def normlize(self, x):
        return (x - x.min()) / (x.max() - x.min())


    def normlize_brain(self, x, epsilon=1e-8):
        average        = x[np.nonzero(x)].mean()
        std            = x[np.nonzero(x)].std() + epsilon
        mask           = x>0
        sub_mean       = np.where(mask, x-average, x)
        x_normalized   = np.where(mask, sub_mean/std, x)
        return x_normalized

def split_dataset(data_root, test_p):
    patients_dir = glob.glob(os.path.join(data_root, "*GG", "Brats18*"))
    patients_dir.sort()
    N = int(len(patients_dir)*test_p)
    train_patients_list =  patients_dir[N:]
    val_patients_list   =  patients_dir[:N]

    return train_patients_list, val_patients_list

def make_data_loaders(config):
    train_list, val_list = split_dataset(config['path_to_data'], float(config['test_p']))
    crop_size = np.zeros((3))
    crop_size[0] = config['inputshape'][0]
    crop_size[1] = config['inputshape'][1]
    crop_size[2] = config['inputshape'][2]
    crop_size    = crop_size.astype(np.uint16)
    crop_size    = (160, 192, 128)
    train_ds = Brats2018(train_list, crop_size=crop_size, modes=config['modalities'], train=True)
    val_ds = Brats2018(val_list,     crop_size=crop_size, modes=config['modalities'], train=False)
    loaders = {}
    loaders['train'] = DataLoader(train_ds, batch_size=int(config['batch_size_tr']),
                                  num_workers=4,
                                  pin_memory=True,
                                  shuffle=True)
    loaders['eval'] = DataLoader(val_ds, batch_size=int(config['batch_size_va']),
                                  num_workers=4,
                                  pin_memory=True,
                                  shuffle=False)
    return loaders

models => unet.py

In [5]:
#!/usr/bin/env python3
# encoding: utf-8
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class BasicBlock(nn.Module):
    def __init__(self, in_channels, out_channels, n_groups=8):
        super(BasicBlock, self).__init__()
        self.gn1 = nn.GroupNorm(n_groups, in_channels)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=(3, 3, 3), padding=(1, 1, 1))
        self.gn2 = nn.GroupNorm(n_groups, in_channels)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=(3, 3, 3), padding=(1, 1, 1))

    def forward(self, x):
        residul = x
        x = self.relu1(self.gn1(x))
        x = self.conv1(x)

        x = self.relu2(self.gn2(x))
        x = self.conv2(x)
        x = x + residul

        return x

class UNet3D(nn.Module):
    """3d unet
    Ref:
        3D MRI brain tumor segmentation using autoencoder regularization. Andriy Myronenko
    Args:
        input_shape: tuple, (height, width, depth)
    """

    def __init__(self, input_shape, in_channels=4, out_channels=3, init_channels=32, p=0.2):
        super(UNet3D, self).__init__()
        self.input_shape = input_shape
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.init_channels = init_channels
        self.make_encoder()
        self.make_decoder()
        self.dropout = nn.Dropout(p=p)

    def make_encoder(self):
        init_channels = self.init_channels
        self.conv1a = nn.Conv3d(self.in_channels, init_channels, (3, 3, 3), padding=(1, 1, 1))
        self.conv1b = BasicBlock(init_channels, init_channels)  # 32

        self.ds1 = nn.Conv3d(init_channels, init_channels * 2, (3, 3, 3), stride=(2, 2, 2),
                             padding=(1, 1, 1))  # down sampling and add channels

        self.conv2a = BasicBlock(init_channels * 2, init_channels * 2)
        self.conv2b = BasicBlock(init_channels * 2, init_channels * 2)

        self.ds2 = nn.Conv3d(init_channels * 2, init_channels * 4, (3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))

        self.conv3a = BasicBlock(init_channels * 4, init_channels * 4)
        self.conv3b = BasicBlock(init_channels * 4, init_channels * 4)

        self.ds3 = nn.Conv3d(init_channels * 4, init_channels * 8, (3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))

        self.conv4a = BasicBlock(init_channels * 8, init_channels * 8)
        self.conv4b = BasicBlock(init_channels * 8, init_channels * 8)
        self.conv4c = BasicBlock(init_channels * 8, init_channels * 8)
        self.conv4d = BasicBlock(init_channels * 8, init_channels * 8)

    def make_decoder(self):
        init_channels = self.init_channels
        self.up4conva = nn.Conv3d(init_channels * 8, init_channels * 4, (1, 1, 1))
        self.up4 = nn.Upsample(scale_factor=2)  # mode='bilinear'
        self.up4convb = BasicBlock(init_channels * 4, init_channels * 4)

        self.up3conva = nn.Conv3d(init_channels * 4, init_channels * 2, (1, 1, 1))
        self.up3 = nn.Upsample(scale_factor=2)
        self.up3convb = BasicBlock(init_channels * 2, init_channels * 2)

        self.up2conva = nn.Conv3d(init_channels * 2, init_channels, (1, 1, 1))
        self.up2 = nn.Upsample(scale_factor=2)
        self.up2convb = BasicBlock(init_channels, init_channels)

        self.pool     = nn.MaxPool3d(kernel_size = 2)
        self.convc    = nn.Conv3d(init_channels * 20, init_channels * 8, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0))
        self.convco   = nn.Conv3d(init_channels * 16, init_channels * 8, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0))
        self.up1conv  = nn.Conv3d(init_channels, self.out_channels, (1, 1, 1))

    def forward(self, x):
        c1 = self.conv1a(x)
        c1 = self.conv1b(c1)
        c1d = self.ds1(c1)
        #print("c1d shape:", c1d.shape)

        c2 = self.conv2a(c1d)
        c2 = self.conv2b(c2)
        c2d = self.ds2(c2)
        c2d_p = self.pool(c2d)
#         print("c2d shape:", c2d_p.shape)

        c3 = self.conv3a(c2d)
        c3 = self.conv3b(c3)
        c3d = self.ds3(c3)
#         print("c3d shape:", c3d.shape)

        c4 = self.conv4a(c3d)
        c4 = self.conv4b(c4)
        c4 = self.conv4c(c4)
        c4d = self.conv4d(c4) #[1, 128, 20, 24, 16]
#         print("c4d shape:", c4d.shape)

        style = self.convc(torch.cat([c2d_p, c3d, c4d], dim = 1))
        content = c4d

        c4d = self.convco(torch.cat([style, content], dim = 1))

        c4d = self.dropout(c4d)

        u4 = self.up4conva(c4d)
        u4 = self.up4(u4)
        u4 = u4 + c3
        u4 = self.up4convb(u4)

        u3 = self.up3conva(u4)
        u3 = self.up3(u3)
        u3 = u3 + c2
        u3 = self.up3convb(u3)

        u2 = self.up2conva(u3)
        u2 = self.up2(u2)
        u2 = u2 + c1
        u2 = self.up2convb(u2)

        uout = self.up1conv(u2)
        uout = F.sigmoid(uout)

        return uout, style, content


class Unet_module(nn.Module):

    def __init__(self, input_shape, in_channels=4, out_channels=3, init_channels=32, p=0.2):
        super(Unet_module, self).__init__()
        self.unet = UNet3D(input_shape, in_channels, out_channels, init_channels, p)

    def forward(self, x):
        uout, style, content = self.unet(x)
        return uout, style, content

model=> build.py

In [6]:
#!/usr/bin/env python3
# encoding: utf-8

def build_model(inp_shape = (160, 192, 128), inp_dim1=4, inp_dim2 = 1):
    model_full    = Unet_module(inp_shape,
                      in_channels=inp_dim1,
                      out_channels=4,
                      init_channels=16,
                      p=0.2)

    model_missing = Unet_module(inp_shape,
                      in_channels=inp_dim2,
                      out_channels=4,
                      init_channels=16,
                      p=0.2)
    return model_full, model_missing


model => discriminator.py

In [7]:
from torch import nn

def get_style_discriminator(num_classes, ndf=64):
    return nn.Sequential(
        nn.Conv3d(num_classes, ndf, kernel_size=4, stride=2, padding=1),
        nn.LeakyReLU(negative_slope=0.2, inplace=True),
        nn.Conv3d(ndf, ndf * 2, kernel_size=4, stride=2, padding=1),
        nn.LeakyReLU(negative_slope=0.2, inplace=True),
        nn.Conv3d(ndf * 2, ndf * 4, kernel_size=4, stride=2, padding=1),
        nn.LeakyReLU(negative_slope=0.2, inplace=True),
        nn.Conv3d(ndf * 4, 1, kernel_size=(2,3,2), stride=1, padding=0)
    )


solver => build.py

In [8]:
#!/usr/bin/env python3
# encoding: utf-8
import torch
from torch.optim import lr_scheduler

def make_optimizer_double(config, model1, model2):
    lr = float(config['lr'])
    print('initial learning rate is ', lr)
    optimizer = torch.optim.Adam([
    {'params': model1.parameters()},
    {'params': model2.parameters()}], lr=lr, weight_decay=float(config['weight_decay']))
    scheduler = PolyLR(optimizer, max_epoch=int(config['epochs']), power=float(config['power']))

    return optimizer, scheduler


class PolyLR(lr_scheduler._LRScheduler):
    """Set the learning rate of each parameter group to the initial lr decayed
    by gamma every epoch. When last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        gamma (float): Multiplicative factor of learning rate decay.
        last_epoch (int): The index of last epoch. Default: -1.
    """

    def __init__(self, optimizer, max_epoch, power=0.9, last_epoch=-1):
        self.max_epoch = max_epoch
        self.power = power
        super(PolyLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        return [base_lr * (1 - self.last_epoch / self.max_epoch) ** self.power
                for base_lr in self.base_lrs]



losses.py

In [9]:
#!/usr/bin/env python3
# encoding: utf-8
# Modified from https://github.com/Wangyixinxin/ACN
import torch
from torch.nn import functional as F
import numpy as np
import numpy as np
from matplotlib import cm
import matplotlib.pyplot as plt
import numpy as np
import cv2
from PIL import Image
import cv2
import torch.nn as nn

def sigmoid_rampup(current, rampup_length):
    """Exponential rampup from https://arxiv.org/abs/1610.02242"""
    if rampup_length == 0:
        return 1.0
    else:
        current = np.clip(current, 0.0, rampup_length)
        phase = 1.0 - current / rampup_length
        return float(np.exp(-5.0 * phase * phase))

def get_current_consistency_weight(epoch, consistency = 10, consistency_rampup = 20.0):
    # Consistency ramp-up from https://arxiv.org/abs/1610.02242
    return consistency * sigmoid_rampup(epoch, consistency_rampup)

def bce_loss(y_pred, y_label):
    y_truth_tensor = torch.FloatTensor(y_pred.size())
    y_truth_tensor.fill_(y_label)
    y_truth_tensor = y_truth_tensor.to(y_pred.get_device())
    return nn.BCEWithLogitsLoss()(y_pred, y_truth_tensor)


def dice_loss(input, target):
    """soft dice loss"""
    eps = 1e-7
    iflat = input.view(-1)
    tflat = target.view(-1)
    intersection = (iflat * tflat).sum()

    return 1 - 2. * intersection / ((iflat ** 2).sum() + (tflat ** 2).sum() + eps)

def gram_matrix(input):
    a, b, c, d, e = input.size()
    features = input.view(a * b, c * d * e)
    G = torch.mm(features, features.t())  # compute the gram product
    return G.div(a * b * c * d * e)

def get_style_loss(sf, sm):
    g_f = gram_matrix(sf)
    g_m = gram_matrix(sm)
    channels = sf.size(1)
    size     = sf.size(2)*sf.size(3)
    sloss = torch.sum(torch.square(g_f-g_m)) / (4.0 * (channels ** 2) * (size ** 2))
    return sloss*0.0001

def unet_Co_loss(config, batch_pred_full, content_full, batch_y, batch_pred_missing, content_missing, sf, sm, epoch):
    loss_dict = {}
    loss_dict['ed_dc_loss']  = dice_loss(batch_pred_full[:, 0], batch_y[:, 0])  # whole tumor
    loss_dict['net_dc_loss'] = dice_loss(batch_pred_full[:, 1], batch_y[:, 1])  # tumore core
    loss_dict['et_dc_loss']  = dice_loss(batch_pred_full[:, 2], batch_y[:, 2])  # enhance tumor

    loss_dict['ed_miss_dc_loss']  = dice_loss(batch_pred_missing[:, 0], batch_y[:, 0])  # whole tumor
    loss_dict['net_miss_dc_loss'] = dice_loss(batch_pred_missing[:, 1], batch_y[:, 1])  # tumore core
    loss_dict['et_miss_dc_loss']  = dice_loss(batch_pred_missing[:, 2], batch_y[:, 2])  # enhance tumor

    ## Dice loss predictions
    loss_dict['loss_dc'] = loss_dict['ed_dc_loss'] + loss_dict['net_dc_loss'] + loss_dict['et_dc_loss']
    loss_dict['loss_miss_dc'] = loss_dict['ed_miss_dc_loss'] + loss_dict['net_miss_dc_loss'] + loss_dict['et_miss_dc_loss']

    ## Consistency loss
    loss_dict['ed_mse_loss']  = F.mse_loss(batch_pred_full[:, 0], batch_pred_missing[:, 0], reduction='mean')
    loss_dict['net_mse_loss'] = F.mse_loss(batch_pred_full[:, 1], batch_pred_missing[:, 1], reduction='mean')
    loss_dict['et_mse_loss']  = F.mse_loss(batch_pred_full[:, 2], batch_pred_missing[:, 2], reduction='mean')
    loss_dict['consistency_loss'] = loss_dict['ed_mse_loss'] + loss_dict['net_mse_loss'] + loss_dict['et_mse_loss']

    ## Content loss
    loss_dict['content_loss'] = F.mse_loss(content_full, content_missing, reduction='mean')

    ## Style loss
    sloss = get_style_loss(sf, sm)


    ## Weights for each loss the lamba values
    weight_content = float(config['weight_content'])
    weight_missing = float(config['weight_mispath'])
    weight_full    = 1 - float(config['weight_mispath'])

    weight_consistency = get_current_consistency_weight(epoch)
    loss_dict['loss_Co'] = weight_full * loss_dict['loss_dc'] + weight_missing * loss_dict['loss_miss_dc'] + \
                            weight_consistency * loss_dict['consistency_loss'] + weight_content * loss_dict['content_loss']+sloss

    return loss_dict

def get_losses(config):
    losses = {}
    losses['co_loss'] = unet_Co_loss
    return losses


class DiceLoss(torch.nn.Module):
    def __init__(self, smooth=1.0):
        super(DiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self, prediction, target):
        prediction = torch.Tensor(prediction)
        target = torch.Tensor(target)
        iflat = prediction.reshape(-1)
        tflat = target.reshape(-1)
        intersection = (iflat * tflat).sum()

        return ((2.0 * intersection + self.smooth) / (iflat.sum() + tflat.sum() + self.smooth)).numpy()

train.ipynb


In [10]:
#!/usr/bin/env python3
# encoding: utf-8
import os
import random
import torch
import warnings
import numpy as np

def init_env(gpu_id='0', seed=42):
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = True
    warnings.filterwarnings('ignore')

make_data_loader

In [11]:
## Config file
lr:              1e-4 # Initial learning rate
weight_decay:    1e-5
power:           0.9
epochs:          2 # Number of epochs to train the model
number_classes:  4 # Number of classes in the target dataset
batch_size_tr:   1 # Batch size for train
batch_size_va:   1 # Batch size for validationn
modalities:      ['flair', 't1', 't1ce', 't2']# List of modalities needd to be used for training and evaluating the model
path_to_data:    '/content/brats2018/MICCAI_BraTS_2018_Data_Training/' # path to dataset
path_to_log:     './results/' # path to save results
progress_p:      0.1 # value between 0-1 shows the number of time we need to report training progress in each epoch
validation_p:    0.1 # validation percentage
test_p:          0.2 # Test percentage (20%)
inputshape:      [160, 192, 128]
weight_mispath:  0.6
weight_content:  0.2

In [12]:
def make_data_loaders(config):
    train_list, val_list = split_dataset(config['path_to_data'], float(config['test_p']))
    print(train_list)
    print(val_list)
    crop_size = np.zeros((3))
    crop_size[0] = config['inputshape'][0]
    crop_size[1] = config['inputshape'][1]
    crop_size[2] = config['inputshape'][2]
    crop_size    = crop_size.astype(np.uint16)
    crop_size    = (160, 192, 128)
    train_ds = Brats2018(train_list, crop_size=crop_size, modes=config['modalities'], train=True)
    val_ds = Brats2018(val_list,     crop_size=crop_size, modes=config['modalities'], train=False)
    loaders = {}
    loaders['train'] = DataLoader(train_ds, batch_size=int(config['batch_size_tr']),
                                  num_workers=4,
                                  pin_memory=True,
                                  shuffle=True)
    loaders['eval'] = DataLoader(val_ds, batch_size=int(config['batch_size_va']),
                                  num_workers=4,
                                  pin_memory=True,
                                  shuffle=False)
    return loaders



In [14]:

# The code is extensively uses the ACN implementation, please see:
## https://github.com/Wangyixinxin/ACN##
#!/usr/bin/env python3
# encoding: utf-8
import yaml
# from data import make_data_loaders
# from models import build_model
# from models.discriminator import get_style_discriminator
# from solver import make_optimizer_double
# from losses import get_losses, bce_loss, DiceLoss
import os
import torch
import torch.optim as optim
import torch.nn.functional as F
# from utils import init_env
import nibabel as nib
import numpy as np




def load_old_model(model_full, model_missing, d_style, optimizer, saved_model_path):
    print("Constructing model from saved file... ")
    checkpoint = torch.load(saved_model_path)
    model_full.load_state_dict(checkpoint["model_full"])
    model_missing.load_state_dict(checkpoint["model_missing"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    d_style.load_state_dict(checkpoint["d_style"])
    epoch = checkpoint["epochs"]

    return model_full, model_missing, d_style, optimizer, epoch

def to_numpy(tensor):
    if isinstance(tensor, (int, float)):
        return tensor
    else:
        return tensor.data.cpu().numpy()




## Main section
config = yaml.load(open('/content/smunet/config.yml'), Loader=yaml.FullLoader)
init_env('0')
print(config)
loaders = make_data_loaders(config)
model_full, model_missing = build_model(inp_dim1 = 4, inp_dim2 = 1)
model_full    = model_full.cuda()
model_missing = model_missing.cuda()
d_style       = get_style_discriminator(num_classes = 128).cuda()
task_name = 'brats2018_flair'
log_dir = os.path.join(config['path_to_log'], task_name)
optimizer, scheduler = make_optimizer_double(config, model_full, model_missing)
losses = get_losses(config)

continue_training = False
epoch = 0

if not os.path.exists(log_dir):
    os.makedirs(log_dir)

criteria = DiceLoss()


## evalute the performance
def get_mask(seg_volume):
    seg_volume = seg_volume.detach().cpu().numpy()
    seg_volume = np.squeeze(seg_volume)
    wt_pred = seg_volume[0]
    tc_pred = seg_volume[1]
    et_pred = seg_volume[2]
    mask = np.zeros_like(wt_pred)
    mask[wt_pred > 0.5] = 2
    mask[tc_pred > 0.5] = 1
    mask[et_pred > 0.5] = 4
    mask = mask.astype("uint8")
    return mask

def eval_metrics(gt, pred):
    loss_wt = criteria(np.where(gt>0, 1, 0), np.where(pred>0, 1, 0))
    loss_ct = criteria(np.where(gt==1, 1, 0)+np.where(gt==4, 1, 0), np.where(pred==1, 1, 0)+np.where(pred==4, 1, 0))
    loss_et = criteria(np.where(gt==4, 1, 0), np.where(pred==4, 1, 0))
    return loss_wt, loss_et, loss_ct

def measure_dice_score(batch_pred, batch_y):
    pred = get_mask(batch_pred)
    gt   = get_mask(batch_y)
    loss_wt, loss_et, loss_ct = eval_metrics(gt, pred)
    score = (loss_wt+loss_et+loss_ct)/3.0

    return score




{'lr': '1e-4', 'weight_decay': '1e-5', 'power': 0.9, 'epochs': 250, 'number_classes': 4, 'batch_size_tr': 1, 'batch_size_va': 1, 'modalities': ['flair', 't1', 't1ce', 't2'], 'path_to_data': '/content/brats2018/MICCAI_BraTS_2018_Data_Training/', 'path_to_log': './results/', 'progress_p': 0.1, 'validation_p': 0.1, 'test_p': 0.2, 'inputshape': [160, 192, 128], 'weight_mispath': 0.6, 'weight_content': 0.2}
['/content/brats2018/MICCAI_BraTS_2018_Data_Training/HGG/Brats18_CBICA_AQU_1', '/content/brats2018/MICCAI_BraTS_2018_Data_Training/HGG/Brats18_CBICA_AQV_1', '/content/brats2018/MICCAI_BraTS_2018_Data_Training/HGG/Brats18_CBICA_AQY_1', '/content/brats2018/MICCAI_BraTS_2018_Data_Training/HGG/Brats18_CBICA_AQZ_1', '/content/brats2018/MICCAI_BraTS_2018_Data_Training/HGG/Brats18_CBICA_ARF_1', '/content/brats2018/MICCAI_BraTS_2018_Data_Training/HGG/Brats18_CBICA_ARW_1', '/content/brats2018/MICCAI_BraTS_2018_Data_Training/HGG/Brats18_CBICA_ARZ_1', '/content/brats2018/MICCAI_BraTS_2018_Data_Trai

In [24]:

def train_val(model_full, model_missing, d_style, loaders, optimizer, scheduler, losses, epoch_init=0):
    n_epochs = int(config['epochs'])
    iter_num = 0
    best_dice = 0.0
    for epoch in range(epoch_init, n_epochs):
        scheduler.step()
        train_loss = 0.0
        val_scores_full = 0.0
        val_scores_miss = 0.0

        for phase in ['train', 'eval']:
            loader = loaders[phase]
            total = len(loader)
            for batch_id, (batch_x, batch_y) in enumerate(loader):
                iter_num = iter_num + 1
                batch_x, batch_y = batch_x.cuda(non_blocking=True), batch_y.cuda(non_blocking=True)
                with torch.set_grad_enabled(phase == 'train'):
                    seg_f, style_f, content_f = model_full(batch_x[:,0:])
                    seg_m, style_m, content_m = model_missing(batch_x[:,0:1])
                    loss_dict = losses['co_loss'](config, seg_f, content_f, batch_y, seg_m, content_m, style_f, style_m, epoch)

                    d_style.train()
                    optimizer_d_style = optim.Adam(d_style.parameters(), lr = float(config['lr']), betas=(0.9, 0.99))
                    # labels for style adversarial training
                    source_label = 0
                    target_label = 1

                    optimizer.zero_grad()
                    optimizer_d_style.zero_grad()

                    # only train. Don't accumulate grads in disciminators
                    for param in d_style.parameters():
                        param.requires_grad = False

                    if phase == 'train':
                        (loss_dict['loss_Co']).backward(retain_graph=True)
                        train_loss += loss_dict['loss_Co'].item()




                    ##################### adversarial training to fool the discriminator ######################
                    df_src_main = style_f
                    df_trg_main = style_m
                    d_df_out_main = d_style(df_trg_main)
                    loss_adv_df_trg_main = bce_loss(d_df_out_main, source_label)
                    loss = 0.0002 * loss_adv_df_trg_main
                    if phase == 'train':
                        loss.backward()


                    ####################### Train discriminator networks ######################################
                    # enable training mode on discriminator networks
                    for param in d_style.parameters():
                        param.requires_grad = True
                    df_src_main = df_src_main.detach()
                    d_df_out_main = d_style(df_src_main)
                    loss_d_feature_main = bce_loss(d_df_out_main, source_label)
                    if phase == 'train':
                        loss_d_feature_main.backward()

                    ####################### train with target ##################################################
                    df_trg_main = df_trg_main.detach()
                    d_df_out_main = d_style(df_trg_main)
                    loss_d_feature_main = bce_loss(d_df_out_main, target_label)

                    if phase == 'train':
                        loss_d_feature_main.backward()


                num_classes = 4

                if phase == 'train':
                    optimizer.step()
                    optimizer_d_style.step()
                    if (batch_id + 1) % 20 == 0:
                        print(f'Epoch {epoch+1}>> itteration {batch_id+1}>> training loss>> {train_loss/(batch_id+1)}')

                else:
                    val_scores_full += measure_dice_score(seg_f, batch_y)
                    val_scores_miss += measure_dice_score(seg_m, batch_y)

            if phase == 'train':
                print(f'Epoch {epoch+1} overall training loss>> {train_loss/(batch_id+1)}')
            else:
                dice = (val_scores_miss/(batch_id+1))
                print(f'Epoch {epoch+1} validation dice score for missing modality>> {dice}')
                state = {}
                state['model_full'] = model_full.state_dict()
                state['model_missing'] = model_missing.state_dict()
                state['d_style'] = d_style.state_dict()
                state['optimizer'] = optimizer.state_dict()
                state['epochs'] = epoch
                file_name = log_dir+'/model_weights.pth'
                torch.save(state, file_name)
                if dice > best_dice:
                    torch.save(state, file_name)
                    best_dice = dice



saved_model_path = log_dir+'/model_weights.pth'
if continue_training:
    model_full, model_missing, d_style, optimizer, epoch = load_old_model(model_full, model_missing, d_style, optimizer, saved_model_path)
train_val(model_full, model_missing, d_style, loaders, optimizer, scheduler, losses, epoch)
print('Training process is finished')


OutOfMemoryError: ignored

##Testing time

evolution.py

In [36]:

#!/usr/bin/env python3
# encoding: utf-8
import yaml
# from data import make_data_loaders
# from models import build_model
# from models.discriminator import get_style_discriminator
# from solver import make_optimizer_double
# from losses import get_losses, bce_loss, DiceLoss
import os
import torch
import torch.optim as optim
import torch.nn.functional as F
# from utils import init_env
import nibabel as nib
import numpy as np
def load_old_model(model_full, model_missing, saved_model_path):
    print("Constructing model from saved file... ")
    checkpoint = torch.load(saved_model_path)
    model_full.load_state_dict(checkpoint["model_full"])
    model_missing.load_state_dict(checkpoint["model_missing"])

    return model_full, model_missing

def to_numpy(tensor):
    if isinstance(tensor, (int, float)):
        return tensor
    else:
        return tensor.data.cpu().numpy()

def eval_metrics(gt, pred):
    loss_hw = criteria(np.where(gt>0, 1, 0), np.where(pred>0, 1, 0))
    loss_ed = criteria(np.where(gt==1, 1, 0)+np.where(gt==4, 1, 0), np.where(pred==1, 1, 0)+np.where(pred==4, 1, 0))
    loss_ct = criteria(np.where(gt==4, 1, 0), np.where(pred==4, 1, 0))
    return (loss_hw, loss_ed, loss_ct)


def evaluate_sample(batch_pred_full, batch_pred_missing, batch_y):
    def get_mask(seg_volume):
        seg_volume = seg_volume.cpu().numpy()
        seg_volume = np.squeeze(seg_volume)
        wt_pred = seg_volume[0]
        tc_pred = seg_volume[1]
        et_pred = seg_volume[2]
        mask = np.zeros_like(wt_pred)
        TH = 0.4
        mask[wt_pred >= TH] = 2
        mask[tc_pred >= TH] = 1
        mask[et_pred >= TH] = 4
        mask = mask.astype("uint8")
        #mask_nii = nib.Nifti1Image(mask, np.eye(4))
        return mask

    pred_nii_full = get_mask(batch_pred_full)
    pred_nii_miss = get_mask(batch_pred_missing)
    gt_nii = get_mask(batch_y)

    metric_full  = eval_metrics(gt_nii, pred_nii_full)
    metric_miss  = eval_metrics(gt_nii, pred_nii_miss)

    return metric_full, metric_miss
## Main section
config = yaml.load(open('/content/smunet/config.yml'), Loader=yaml.FullLoader)
init_env('0')
loaders = make_data_loaders(config)
model_full, model_missing = build_model(inp_dim1 = 4, inp_dim2 = 1)
model_full    = model_full.cuda()
model_missing = model_missing.cuda()
d_style = get_style_discriminator(num_classes = 128).cuda()
task_name = 'brats2018_flair'
log_dir = os.path.join(config['path_to_log'], task_name)
criteria = DiceLoss()
def evaluate_performance(model_full, model_missing, loaders):
    class_score_full  = np.array((0.,0.,0.))
    class_score_mono  = np.array((0.,0.,0.))
    loader = loaders['eval']
    total = len(loader)
    with torch.no_grad():
        for batch_id, (batch_x, batch_y) in enumerate(loader):
            batch_x, batch_y = batch_x.cuda(non_blocking=True), batch_y.cuda(non_blocking=True)
            seg_f, style_f, content_f = model_full(batch_x[:,:])
            seg_m, style_m, content_m = model_missing(batch_x[:,0:1])
            metric_full, metric_miss = evaluate_sample(seg_f, seg_m, batch_y)
            class_score_full += metric_full
            class_score_mono += metric_miss


    class_score_full  /= total
    class_score_mono /= total
    print(f' validation Dise score full modalities class>> whole: {class_score_full[0]} core:{class_score_full[1]}  enh:{class_score_full[2]}')
    print(f' validation Dise score missing modality  class>> whole: {class_score_mono[0]} core:{class_score_mono[1]}  enh:{class_score_mono[2]}')


saved_model_path ='/content/results/brats2018_flair/model_weights.pth'
print(len(saved_model_path))
model_full, model_missing = load_old_model(model_full, model_missing, saved_model_path)
evaluate_performance(model_full, model_missing, loaders)

['/content/brats2018/MICCAI_BraTS_2018_Data_Training/HGG/Brats18_CBICA_AQU_1', '/content/brats2018/MICCAI_BraTS_2018_Data_Training/HGG/Brats18_CBICA_AQV_1', '/content/brats2018/MICCAI_BraTS_2018_Data_Training/HGG/Brats18_CBICA_AQY_1', '/content/brats2018/MICCAI_BraTS_2018_Data_Training/HGG/Brats18_CBICA_AQZ_1', '/content/brats2018/MICCAI_BraTS_2018_Data_Training/HGG/Brats18_CBICA_ARF_1', '/content/brats2018/MICCAI_BraTS_2018_Data_Training/HGG/Brats18_CBICA_ARW_1', '/content/brats2018/MICCAI_BraTS_2018_Data_Training/HGG/Brats18_CBICA_ARZ_1', '/content/brats2018/MICCAI_BraTS_2018_Data_Training/HGG/Brats18_CBICA_ASA_1', '/content/brats2018/MICCAI_BraTS_2018_Data_Training/HGG/Brats18_CBICA_ASE_1', '/content/brats2018/MICCAI_BraTS_2018_Data_Training/HGG/Brats18_CBICA_ASG_1', '/content/brats2018/MICCAI_BraTS_2018_Data_Training/HGG/Brats18_CBICA_ASH_1', '/content/brats2018/MICCAI_BraTS_2018_Data_Training/HGG/Brats18_CBICA_ASK_1', '/content/brats2018/MICCAI_BraTS_2018_Data_Training/HGG/Brats18