<a href="https://www.kaggle.com/code/lovrorabuzin/unet-zavrad?scriptVersionId=91691982" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

# Importi i hiperparametri

In [None]:
import torch
from torch import nn
from torch.nn import functional
import torch.nn.functional as F

import sklearn
from sklearn import model_selection

import torchvision
import torchvision.transforms as torch_transforms
import torch.utils.data as data
import torch.optim as optim
import torchvision.models as models

from torch.utils.data import Dataset, DataLoader
from torch.utils.data import random_split
from torch.utils.data import TensorDataset
from torch.utils.data import Subset
import torch.utils.checkpoint as cp

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib.patches as mpatches

import pickle
import numpy as np
import math
import pandas as pd

import skimage as ski
import skimage.io
import os

import random

import nibabel as nib
from PIL import Image
import imageio

import torch.utils.checkpoint as cp

import sys
import json

# Hyper-parameters and global configuration.
b_s = 1                 # DataLoader batch size (each dataset item is already a stack of slices)
learning_rate = 1e-4    # Adam learning rate
weight_decay = 5e-4     # Adam weight decay
# Learning-rate reduction factor intended for layers initialised from
# pretrained weights; not referenced by the code visible in this notebook.
scaling_factor = 2
gamma = 0.95            # not referenced by the visible code (the scheduler is cosine annealing)
num_epochs = 30
slice_no = 100          # number of slices packed into one training/validation item
random.seed()           # seeded from system randomness, so runs are not reproducible
root_train_dir = '../input/brats20-dataset-training-validation/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData'
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Učitavanje podataka

In [None]:
# Subject metadata. Both CSVs are re-keyed on a common 'ID' column. The right
# merge keeps every subject listed in name_mapping.csv, which supplies the
# 'Grade' column (HGG / LGG) used for the split below.
name_mapping = pd.read_csv(root_train_dir + '/name_mapping.csv').rename(
    columns={'BraTS_2020_subject_ID': 'ID'})
survival_info = pd.read_csv(root_train_dir + '/survival_info.csv').rename(
    columns={'Brats20ID': 'ID'})
patient_info = survival_info.merge(name_mapping, on="ID", how="right")

# File-name suffixes appended to each '<dir>/<ID>/<ID>' patient prefix.
modalities = ['_flair.nii', '_t1.nii', '_t1ce.nii', '_t2.nii']
mask_path = '_seg.nii'

HGG_names = patient_info.loc[patient_info['Grade'] == 'HGG', 'ID'].tolist()
LGG_names = patient_info.loc[patient_info['Grade'] == 'LGG', 'ID'].tolist()

# Deterministic split, applied separately within each grade (HGG first, then LGG):
# position % 7 == 5 -> validation, == 4 -> test, otherwise training (roughly 5:1:1).
train_scan_files, valid_scan_files, test_scan_files = [], [], []
for grade_names in (HGG_names, LGG_names):
    for position, subject in enumerate(grade_names):
        remainder = position % 7
        if remainder == 5:
            target = valid_scan_files
        elif remainder == 4:
            target = test_scan_files
        else:
            target = train_scan_files
        target.append(f'{root_train_dir}/{subject}/{subject}')

# Training augmentation: convert an (H, W, C) array to a (C, H, W) tensor, then
# randomly flip it horizontally.
toTensor = torch_transforms.ToTensor()
horizontalFlipTransform = torch_transforms.RandomHorizontalFlip(p=0.5)
trainTransform = torch_transforms.Compose([toTensor, horizontalFlipTransform])

def normalize_vol(volume, i=None):
    """Z-score normalise *volume* using the statistics of its non-zero voxels.

    The mean and standard deviation are computed over voxels that are not
    exactly zero (zero is treated as background). They are then applied to the
    *whole* array, so background voxels end up at ``-mean / std`` rather than 0.

    Args:
        volume: array-like of intensities (any shape).
        i: unused; accepted so existing ``normalize_vol(v, i)`` calls keep working.

    Returns:
        A new array with the same shape as *volume*.
        * No non-zero voxels: the volume is returned unchanged (as a copy)
          instead of becoming all-NaN.
        * Non-zero voxels with zero standard deviation: the volume is only
          mean-centred, avoiding division by zero.
    """
    volume = np.asarray(volume)
    foreground = volume[volume != 0]
    if foreground.size == 0:
        return volume.copy()
    mean = foreground.mean()
    std = foreground.std()
    if std == 0:
        return volume - mean
    return (volume - mean) / std

class SegmentationDataset(Dataset):
    """Dataset yielding fixed-size stacks of tumour-containing slices.

    Patients in ``paths`` are visited in (shuffled) order. For each patient:
    * the segmentation and every modality volume are loaded;
    * only slices whose ground truth contains tumour are kept;
    * the optional ``transform`` is applied slice by slice;
    * each modality is z-score normalised with ``normalize_vol``.

    Slices are packed into items of at most ``slice_no`` slices. Slices that
    do not fit are buffered and returned first by the next call.

    Each item is ``(volume, mask)``:
      * ``volume``: float tensor ``(n, len(modalities), H, W)``.
      * ``mask``: long one-hot tensor ``(n, 4, H, W)`` whose channels are the
        segmentation labels [2, 1, 4, background].
    Once every patient has been consumed, further items are empty tensors of
    shape ``(0,)``.

    NOTE: the object keeps a cursor (``idx + self.passes``) into ``paths``.
    It only works when items are requested sequentially 0, 1, 2, ... in one
    process (``shuffle=False``, ``num_workers=0``), and ``reset()`` must be
    called between epochs.

    Args:
        paths: per-patient path prefixes ``<dir>/<ID>/<ID>``. The list is
            copied, so the caller's list is not reordered.
        modalities: file suffixes appended to a prefix to load the input channels.
        mask_path: file suffix of the segmentation volume.
        tumor_slices: total number of tumour slices; only used by ``__len__``.
        slice_no: slices per item. ``None`` (default) uses the module-level
            ``slice_no`` hyper-parameter.
        transform: optional callable applied to each ``(H, W, C)`` slice of both
            image and mask. The Python and torch RNGs are re-seeded identically
            before each pass, so random augmentations match between the two.
    """
    in_channels = 4
    out_channels = 4

    # Subjects whose segmentation file does not follow the '<ID>_seg.nii' naming.
    _MISNAMED_MASKS = {'BraTS20_Training_355': 'W39_1998.09.19_Segm.nii'}

    def __init__(self, paths, modalities, mask_path, tumor_slices, slice_no=None, transform=None):
        self.paths = list(paths)  # copy: shuffling must not reorder the caller's list
        self.modalities = modalities
        self.mask_path = mask_path
        self.transform = transform
        self.tumor_slices = tumor_slices
        self.slice_no = slice_no
        self.passes = 0
        self._clear_buffer()
        random.shuffle(self.paths)

    def _chunk_size(self):
        """Return the number of slices per item."""
        # Inside this method the bare name ``slice_no`` is the module-level
        # hyper-parameter; it is used when no per-instance value was given.
        return slice_no if self.slice_no is None else self.slice_no

    def __len__(self):
        return math.ceil(self.tumor_slices / self._chunk_size())

    def reset(self):
        """Rewind to the first patient, reshuffle, and drop buffered leftover slices."""
        self.passes = 0
        self._clear_buffer()
        random.shuffle(self.paths)

    def _clear_buffer(self):
        self.mask_buffer = np.array([])
        self.volume_buffer = np.array([])

    def _take_from_buffer(self, limit):
        """Pop up to *limit* buffered slices as ``(volume, mask)``.

        Returns empty 1-D arrays when the buffer is empty.
        """
        if self.mask_buffer.size == 0:
            return np.array([]), np.array([])
        bound = min(limit, self.mask_buffer.shape[0])
        volume, mask = self.volume_buffer[:bound], self.mask_buffer[:bound]
        if bound < self.mask_buffer.shape[0]:
            self.volume_buffer = self.volume_buffer[bound:]
            self.mask_buffer = self.mask_buffer[bound:]
        else:
            self._clear_buffer()
        return volume, mask

    def _push_to_buffer(self, volume, mask):
        """Append leftover slices to the buffer."""
        if self.mask_buffer.size == 0:
            self.volume_buffer, self.mask_buffer = volume, mask
        else:
            self.volume_buffer = np.concatenate((self.volume_buffer, volume), axis=0)
            self.mask_buffer = np.concatenate((self.mask_buffer, mask), axis=0)

    def _mask_file(self, patient):
        """Return the segmentation file path for *patient*, handling known misnamed files."""
        subject = os.path.basename(patient)
        if self.mask_path == '_seg.nii' and subject in self._MISNAMED_MASKS:
            return os.path.join(os.path.dirname(patient), self._MISNAMED_MASKS[subject])
        return patient + self.mask_path

    @staticmethod
    def _load_axial(path):
        """Load a NIfTI file as float64, with its last axis moved to the front (slice axis)."""
        return np.asarray(nib.load(path).dataobj, dtype=np.float64).transpose(2, 0, 1)

    @staticmethod
    def _one_hot_labels(labels):
        """Turn a ``(D, H, W)`` label volume into a ``(D, 4, H, W)`` float one-hot.

        The channels are [label 2, label 1, label 4, label 0].
        Assumes labels are in {0, 1, 2, 4}.
        """
        channels = [labels == 2, labels == 1, labels == 4, labels == 0]
        return np.stack(channels, axis=1).astype(np.float64)

    def _load_patient(self, patient):
        """Load *patient*'s tumour slices as ``(volumes, mask)`` arrays.

        Shapes are ``(n, M, H, W)`` and ``(n, 4, H, W)``. Returns ``None`` when
        the segmentation contains no tumour voxels.
        """
        mask_full = self._one_hot_labels(self._load_axial(self._mask_file(patient)))
        # Keep only slices where any of the three tumour channels is set.
        tumor_indices = np.flatnonzero(mask_full[:, :3].any(axis=(1, 2, 3)))
        if tumor_indices.size == 0:
            return None
        mask = mask_full[tumor_indices]
        volumes = [self._load_axial(patient + modality)[tumor_indices]
                   for modality in self.modalities]

        if self.transform:
            # Re-seed both RNGs before the image pass and before the mask pass
            # so that slice k receives the same random augmentation in both.
            seed = random.randint(0, 2**32)
            random.seed(seed)
            torch.manual_seed(seed)
            per_slice = np.transpose(np.stack(volumes), (1, 0, 2, 3))  # (n, M, H, W)
            transformed = np.stack([np.asarray(self.transform(np.transpose(s, (1, 2, 0))))
                                    for s in per_slice])
            volumes = np.transpose(transformed, (1, 0, 2, 3))  # back to (M, n, H, W)
            random.seed(seed)
            torch.manual_seed(seed)
            mask = np.stack([np.asarray(self.transform(np.transpose(s, (1, 2, 0))))
                             for s in mask])

        # Normalise each modality independently, then stack to (n, M, H, W).
        volumes = np.stack([normalize_vol(volumes[m], m) for m in range(len(volumes))], axis=1)
        return volumes, mask

    def __getitem__(self, idx):
        chunk = self._chunk_size()
        res_volume, res_mask = self._take_from_buffer(chunk)

        # ``idx + self.passes`` is the index of the next patient to load.
        while res_mask.shape[0] < chunk and idx + self.passes < len(self.paths):
            patient = self.paths[idx + self.passes]
            self.passes += 1
            loaded = self._load_patient(patient)
            if loaded is None:  # no tumour slices: nothing to contribute
                continue
            volumes, mask = loaded
            bound = min(chunk - res_mask.shape[0], mask.shape[0])
            if res_mask.size == 0:
                res_mask, res_volume = mask[:bound], volumes[:bound]
            else:
                res_mask = np.concatenate((res_mask, mask[:bound]), axis=0)
                res_volume = np.concatenate((res_volume, volumes[:bound]), axis=0)
            if bound < mask.shape[0]:
                self._push_to_buffer(volumes[bound:], mask[bound:])

        # The next call arrives with idx + 1, so undo one step of the cursor.
        self.passes -= 1
        return torch.from_numpy(res_volume).float(), torch.from_numpy(res_mask).long()
    
class TestingDataset(Dataset):
    """Dataset returning one whole patient (all slices) per item.

    Each item is ``(volume, mask)``:
      * ``volume``: float tensor ``(D, len(modalities), H, W)``. Each modality
        is z-score normalised with ``normalize_vol``.
      * ``mask``: long one-hot tensor ``(D, 4, H, W)`` whose channels are the
        segmentation labels [2, 1, 4, background].

    Args:
        paths: per-patient path prefixes ``<dir>/<ID>/<ID>``.
        modalities: file suffixes appended to a prefix to load the input channels.
        mask_path: file suffix of the segmentation volume.
    """
    in_channels = 4
    out_channels = 4

    # Subjects whose segmentation file does not follow the '<ID>_seg.nii' naming.
    _MISNAMED_MASKS = {'BraTS20_Training_355': 'W39_1998.09.19_Segm.nii'}

    def __init__(self, paths, modalities, mask_path):
        self.paths = paths
        self.modalities = modalities
        self.mask_path = mask_path
        # Not used by this class; kept for attribute compatibility.
        self.mask_buffer = np.array([])
        self.volume_buffer = np.array([])

    def __len__(self):
        return len(self.paths)

    def _mask_file(self, patient):
        """Return the segmentation file path for *patient*, handling known misnamed files."""
        subject = os.path.basename(patient)
        if self.mask_path == '_seg.nii' and subject in self._MISNAMED_MASKS:
            return os.path.join(os.path.dirname(patient), self._MISNAMED_MASKS[subject])
        return patient + self.mask_path

    @staticmethod
    def _load_axial(path):
        """Load a NIfTI file as float64, with its last axis moved to the front (slice axis)."""
        # np.float64 instead of the removed np.float alias (gone since NumPy 1.24).
        return np.asarray(nib.load(path).dataobj, dtype=np.float64).transpose(2, 0, 1)

    @staticmethod
    def _one_hot_labels(labels):
        """Turn a ``(D, H, W)`` label volume into a ``(D, 4, H, W)`` float one-hot.

        The channels are [label 2, label 1, label 4, label 0].
        Assumes labels are in {0, 1, 2, 4}.
        """
        channels = [labels == 2, labels == 1, labels == 4, labels == 0]
        return np.stack(channels, axis=1).astype(np.float64)

    def __getitem__(self, idx):
        patient = self.paths[idx]
        mask_full = self._one_hot_labels(self._load_axial(self._mask_file(patient)))
        volumes = [self._load_axial(patient + modality) for modality in self.modalities]
        # Normalise each modality independently, then stack to (D, M, H, W).
        volumes = np.stack([normalize_vol(volumes[m], m) for m in range(len(volumes))], axis=1)
        return torch.from_numpy(volumes).float(), torch.from_numpy(mask_full).long()



# Datasets for the three patient splits. tumor_slices is the total number of
# tumour-containing slices in each split. These values are hard-coded
# (presumably precomputed), only determine len() of the dataset, and must be
# updated if the split changes.
trainset = SegmentationDataset(train_scan_files, modalities, mask_path,
                               tumor_slices=17227, transform=trainTransform)
validset = SegmentationDataset(valid_scan_files, modalities, mask_path,
                               tumor_slices=3661)
testset = TestingDataset(test_scan_files, modalities, mask_path)

# The segmentation datasets hand out items through an internal sequential
# cursor, so the loaders must not shuffle and must load in the main process.
loader_options = dict(batch_size=b_s, shuffle=False, num_workers=0)
trainloader = DataLoader(trainset, **loader_options)
validloader = DataLoader(validset, **loader_options)
testloader = DataLoader(testset, **loader_options)

# Neuronska mreža (U-Net)

In [None]:
class UnetBlock(nn.Module):
    """Two stages of (3x3 conv -> batch norm -> ReLU); the spatial size is preserved.

    The attribute names (conv1, bn1, relu, conv2, bn2) define the state_dict
    keys of saved checkpoints, so they are kept stable.
    """

    def __init__(self, in_channels, features):
        super().__init__()
        # Conv bias is redundant because BatchNorm re-centres the output anyway.
        conv_kwargs = dict(kernel_size=3, padding=1, bias=False)
        self.conv1 = nn.Conv2d(in_channels, features, **conv_kwargs)
        self.bn1 = nn.BatchNorm2d(features)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(features, features, **conv_kwargs)
        self.bn2 = nn.BatchNorm2d(features)

    def forward(self, x):
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            x = self.relu(norm(conv(x)))
        return x

    
class Unet(nn.Module):
    """Four-level 2-D U-Net whose conv blocks are gradient-checkpointed during training.

    Architecture:
    * Encoder: ``UnetBlock`` followed by a 2x2 max-pool, four times. Features
      double at each level, starting from ``first_features``.
    * Bottleneck: one ``UnetBlock``.
    * Decoder: a 2x2 transposed conv, concatenation with the matching encoder
      output, then a ``UnetBlock``.
    * Output: a final 1x1 conv produces the class scores.

    Because of the four 2x2 poolings and the skip concatenations, the input
    height and width must be divisible by 16.

    Args:
        in_channels: number of input channels.
        single: if True, one output channel passed through a sigmoid. Otherwise
            four channels with log-softmax over dim 1 (the input NLLLoss expects).
        first_features: number of feature maps at the first level.

    NOTE(review): in training, a checkpointed block's forward is re-run during
    backward, so the BatchNorm running statistics inside it are updated twice
    per step.
    """
    def __init__(self, in_channels = 4, single=False, first_features = 32):
        super(Unet, self).__init__()
        self.in_channels = in_channels
        self.first_features = first_features
        self.single = single

        self.b1 = UnetBlock(in_channels, first_features)
        self.pool1 = nn.MaxPool2d(kernel_size = 2, stride = 2)
        self.b2 = UnetBlock(first_features, first_features*2)
        self.pool2 = nn.MaxPool2d(kernel_size = 2, stride = 2)
        self.b3 = UnetBlock(first_features*2, first_features*4)
        self.pool3 = nn.MaxPool2d(kernel_size = 2, stride = 2)
        self.b4 = UnetBlock(first_features*4, first_features*8)
        self.pool4 = nn.MaxPool2d(kernel_size = 2, stride = 2)

        self.intermediate = UnetBlock(first_features*8, first_features*16)

        self.uc4 = nn.ConvTranspose2d(first_features*16, first_features*8, kernel_size = 2, stride = 2)
        self.upb4 = UnetBlock(first_features*16, first_features*8)
        self.uc3 = nn.ConvTranspose2d(first_features*8, first_features*4, kernel_size = 2, stride = 2)
        self.upb3 = UnetBlock(first_features*8, first_features*4)
        self.uc2 = nn.ConvTranspose2d(first_features*4, first_features*2, kernel_size = 2, stride = 2)
        self.upb2 = UnetBlock(first_features*4, first_features*2)
        self.uc1 = nn.ConvTranspose2d(first_features*2, first_features, kernel_size = 2, stride = 2)
        self.upb1 = UnetBlock(first_features*2, first_features)

        final_features = 1 if single else 4
        self.blend = nn.Conv2d(first_features, final_features, kernel_size=1)

    def custom(self, module):
        """Wrap *module* in a plain function of ``*inputs`` (the form ``cp.checkpoint`` calls)."""
        def custom_forward(*inputs):
            return module(inputs[0])
        return custom_forward

    def _run_block(self, module, x):
        """Apply *module* to *x*, checkpointing only when that can save memory."""
        # Checkpointing only pays off when activations would be kept for a
        # backward pass. In eval / no-grad mode it would just recompute work
        # and trigger reentrant-checkpoint warnings about inputs without grad.
        if self.training and torch.is_grad_enabled():
            return cp.checkpoint(self.custom(module), x)
        return module(x)

    def forward(self, x):
        down1 = self._run_block(self.b1, x)
        down2 = self._run_block(self.b2, self.pool1(down1))
        down3 = self._run_block(self.b3, self.pool2(down2))
        down4 = self._run_block(self.b4, self.pool3(down3))

        # The bottleneck block is not checkpointed.
        intermediate = self.intermediate(self.pool4(down4))

        up4 = self._run_block(self.upb4, torch.cat((self.uc4(intermediate), down4), dim=1))
        up3 = self._run_block(self.upb3, torch.cat((self.uc3(up4), down3), dim=1))
        up2 = self._run_block(self.upb2, torch.cat((self.uc2(up3), down2), dim=1))
        up1 = self._run_block(self.upb1, torch.cat((self.uc1(up2), down1), dim=1))

        logits = self.blend(up1)
        if self.single:
            return torch.sigmoid(logits)
        return torch.log_softmax(logits, dim=1)

# Pomoćne funkcije

In [None]:
def dice_score(prediction, ground_truth, smooth=1.0):
    """Return the smoothed Dice coefficient of two tensors, computed over all elements.

    Computes ``(2 * sum(P * G) + smooth) / (sum(P) + sum(G) + smooth)`` on the
    flattened inputs and returns it as a 0-d tensor.
    """
    pred_flat = prediction.reshape(-1)
    true_flat = ground_truth.reshape(-1)
    overlap = torch.sum(pred_flat * true_flat)
    numerator = 2 * overlap + smooth
    denominator = torch.sum(pred_flat) + torch.sum(true_flat) + smooth
    return numerator / denominator

def un_one_hot(targets):
    """Collapse a one-hot (or score) tensor ``(N, C, ...)`` to class indices ``(N, ...)`` via argmax over dim 1."""
    return torch.argmax(targets, dim=1)

def output_metrics(loss_avg, dice_scores, phase):
    """Print the phase name, the average loss and the Dice scores, one per line."""
    print(phase)
    print(f"Average loss: {loss_avg}")
    print(f"Dice score: {dice_scores}")

def _tumour_regions(labels):
    """Map a class-index tensor to float (WT, TC, ET) region masks.

    Class indices follow the dataset channel order [label 2, label 1, label 4,
    background]:
      * WT = any non-background class (0, 1, 2)
      * TC = classes 1 and 2
      * ET = class 2
    (Region names assume the BraTS label convention.)
    """
    whole = labels != 3
    core = (labels == 1) | (labels == 2)
    enhancing = labels == 2
    return whole.float(), core.float(), enhancing.float()


def dataset_evaluate(model, loader, loss_fn, threshold = -1, single = False):
    """Evaluate *model* on *loader*: mean loss and mean per-batch Dice for WT, TC, ET.

    Args:
        model: network returning per-class log-probabilities ``(N, 4, H, W)``.
        loader: yields ``(images, one_hot_masks)`` batches of shape
            ``(B, n, C, H, W)``. The batch and slice axes are merged before the
            forward pass.
        loss_fn: called as ``loss_fn(log_probs, class_indices)``.
        threshold, single: unused; kept for API compatibility.

    Returns:
        ``(mean_loss, dice)``. ``mean_loss`` is averaged over the evaluated
        batches; ``dice`` is a numpy array ``[WT, TC, ET]`` of Dice scores
        averaged over those batches. Batches without slices are skipped. If no
        batch is evaluated, returns ``(nan, [nan, nan, nan])``.

    The model's previous train/eval mode is restored afterwards.
    """
    was_training = model.training
    model.eval()
    total_loss = 0.0
    dice_scores = []

    with torch.no_grad():
        for images, ground_truths in loader:
            # Merge the DataLoader batch axis with each item's slice axis; this
            # works for any actual batch size, including a short final batch.
            images = torch.cat(tuple(images))
            ground_truths = torch.cat(tuple(ground_truths))
            if images.numel() == 0:
                # The dataset returns empty items once it runs out of slices.
                continue
            images = images.to(device)
            ground_truths = ground_truths.to(device)

            outputs = model(images)
            targets = un_one_hot(ground_truths)
            total_loss += loss_fn(outputs, targets).item()

            predicted = outputs.argmax(1)
            # .item() moves each score to the host; np.stack cannot convert CUDA tensors.
            dice_scores.append([
                dice_score(pred, truth).item()
                for pred, truth in zip(_tumour_regions(predicted), _tumour_regions(targets))
            ])

    model.train(was_training)
    if not dice_scores:
        return float("nan"), np.full(3, np.nan)
    return total_loss / len(dice_scores), np.asarray(dice_scores).mean(0)

# Iscrtavanje grafa gubitka

In [None]:
def plot_progress(data):
    """Plot the per-epoch training and validation loss curves and save them to ./loss.png.

    Args:
        data: mapping with 'train loss' and 'valid loss' sequences (one value per epoch).
    """
    valid_curve = data['valid loss']
    train_curve = data['train loss']

    fig, ax = plt.subplots(figsize=(16, 8))
    ax.set_title('Loss')
    line_style = dict(marker='o', linewidth=2, linestyle='-')
    for curve, colour, label in ((train_curve, 'm', 'train'),
                                 (valid_curve, 'c', 'validation')):
        ax.plot(curve, color=colour, label=label, **line_style)
    ax.legend(loc='upper right', fontsize=10)

    save_path = os.path.join('./', 'loss.png')
    print('Plotting in: ', save_path)
    plt.savefig(save_path)

# Treniranje i evaluacija modela

In [None]:
def _train_one_epoch(net, loss_fn, optimizer, epoch):
    """Run one optimisation pass over ``trainloader`` and return the mean batch loss (nan if no batch)."""
    net.train()
    running_loss = 0.0
    steps = 0
    for inputs, ground_truths in trainloader:
        # Merge the DataLoader batch axis with each item's slice axis; this works
        # for any actual batch size, including a short final batch.
        inputs = torch.cat(tuple(inputs))
        ground_truths = torch.cat(tuple(ground_truths))
        if inputs.numel() == 0:
            # The dataset returns empty items once it runs out of slices.
            continue
        inputs = inputs.to(device)
        # Unet.forward uses reentrant torch.utils.checkpoint. It only
        # back-propagates into a checkpointed block if one of the block's inputs
        # requires grad, so the encoder blocks would otherwise receive no
        # gradient. The flag is set after the device move to keep the input
        # gradient on the device.
        inputs.requires_grad_()
        ground_truths = ground_truths.to(device)

        outputs = net(inputs)
        loss = loss_fn(outputs, un_one_hot(ground_truths))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        steps += 1
        if steps % 10 == 1:
            print("Epoch: %d, Iteration: %5d, Loss: %.3f" % ((epoch + 1), steps, (running_loss / steps)))
    return running_loss / steps if steps else float("nan")


def train_network():
    """Train a U-Net, validate after every epoch, then evaluate on the test split.

    Training setup:
    * Adam with ``learning_rate`` / ``weight_decay``.
    * Cosine-annealing learning-rate schedule over ``num_epochs``.
    * Class-weighted NLLLoss.

    Side effects:
    * saves the weights to ./network.pt;
    * writes ``repr`` of the metrics dict to ./epoch_data.txt;
    * plots the loss curves with ``plot_progress``.

    Returns:
        dict of per-epoch train/validation losses and Dice scores, the learning
        rates, and the final test loss and Dice scores (all plain Python floats).
    """
    plot_data = {
        "train loss": [],
        "valid loss": [],
        "WT valid dice": [],
        "TC valid dice": [],
        "ET valid dice": [],
        "test loss": 0,
        "WT test dice": [],
        "TC test dice": [],
        "ET test dice": [],
        "lr": [],
    }
    regions = ("WT", "TC", "ET")

    SAVE_PATH = "./network.pt"
    print("device:", device)

    net = Unet().float().to(device)
    # Per-class weights for output channels 0..3; the last (background) channel gets the lowest weight.
    loss_fn = nn.NLLLoss(weight=torch.tensor([2, 5, 3, 1]).float().to(device))

    optimizer = optim.Adam(net.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-6)

    for e in range(num_epochs):
        train_loss = _train_one_epoch(net, loss_fn, optimizer, e)
        trainset.reset()

        val_loss, val_dice = dataset_evaluate(net, validloader, loss_fn)
        validset.reset()
        output_metrics(val_loss, val_dice, "Validation")
        plot_data["valid loss"].append(float(val_loss))
        for k, region in enumerate(regions):
            plot_data[f"{region} valid dice"].append(float(val_dice[k]))

        # Record the learning rate used during this epoch, before stepping.
        plot_data["lr"].append(scheduler.get_last_lr())
        plot_data["train loss"].append(train_loss)

        scheduler.step()

    test_loss, test_dice = dataset_evaluate(net, testloader, loss_fn)
    output_metrics(test_loss, test_dice, "Test:")

    plot_data["test loss"] = float(test_loss)
    for k, region in enumerate(regions):
        plot_data[f"{region} test dice"] = float(test_dice[k])

    torch.save(net.state_dict(), SAVE_PATH)

    with open("./epoch_data.txt", 'w', encoding='utf-8') as f:
        f.write(repr(plot_data))

    plot_progress(plot_data)

    return plot_data
    
    
if __name__ == "__main__":
    # Train only when run as a script or notebook (where __name__ is "__main__"),
    # so that importing this module for its classes does not start training.
    epoch_data = train_network()