<a href="https://www.kaggle.com/code/lovrorabuzin/densenet-zavrad?scriptVersionId=91691855" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

# Importi i hiperparametri

In [None]:
import torch
from torch import nn
from torch.nn import functional
import torch.nn.functional as F

import sklearn
from sklearn import model_selection

import torchvision
import torchvision.transforms as torch_transforms
import torch.utils.data as data
import torch.optim as optim
import torchvision.models as models

from torch.utils.data import Dataset, DataLoader
from torch.utils.data import random_split
from torch.utils.data import TensorDataset
from torch.utils.data import Subset
import torch.utils.checkpoint as cp

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib.patches as mpatches

import pickle
import numpy as np
import math
import pandas as pd

import skimage as ski
import skimage.io
import os

import random

import nibabel as nib
from PIL import Image
import imageio

import sys
import json

from math import e

# The DataLoader batch size is 1: mini-batches are assembled while the data is
# read (see `slice_no`), a workaround for how datasets behave on Kaggle.
b_s = 1
# Hyperparameters for the optimiser / learning-rate schedule.
learning_rate = 5e-4
weight_decay = 5e-4
scaling_factor = 2
gamma = 0.95
# Roughly how many epochs fit into a 9-hour Kaggle kernel session.
num_epochs = 24
# Number of slices trained on at once, i.e. the effective batch size.
slice_no = 200
# Seed Python's `random` from OS entropy (or the current time).
random.seed()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
root_train_dir = (
    '../input/brats20-dataset-training-validation/BraTS2020_TrainingData/'
    'MICCAI_BraTS2020_TrainingData'
)

# Učitavanje podataka

In [None]:
name_mapping = pd.read_csv(root_train_dir + '/name_mapping.csv')
name_mapping.rename({'BraTS_2020_subject_ID': 'ID'}, axis=1, inplace=True)

survival_info = pd.read_csv(root_train_dir + '/survival_info.csv')
survival_info.rename({'Brats20ID': 'ID'}, axis=1, inplace=True)

# Right join: keep every patient listed in name_mapping; survival columns are
# left empty (NaN) for patients missing from survival_info.
patient_info = survival_info.merge(name_mapping, on="ID", how="right")

# File-name suffixes appended to a patient's path prefix when loading: the four
# imaging modalities and the segmentation map.
modalities = ['_flair.nii', '_t1.nii', '_t1ce.nii', '_t2.nii']
mask_path = '_seg.nii'

train_scan_files = []
valid_scan_files = []
test_scan_files = []

HGG_names = list(patient_info[patient_info['Grade'] == 'HGG'].ID)
LGG_names = list(patient_info[patient_info['Grade'] == 'LGG'].ID)


def _split_patients(names, train_files, valid_files, test_files):
    """Append each patient's path prefix to the train, valid or test list.

    The split is deterministic and positional. Within ``names``, index
    ``p % 7 == 5`` goes to validation, ``p % 7 == 4`` to test, and everything
    else to training (roughly 5/7 : 1/7 : 1/7). The prefix appended is
    ``<root_train_dir>/<ID>/<ID>``; modality and mask suffixes are added later.
    """
    for position, name in enumerate(names):
        prefix = root_train_dir + '/' + name + '/' + name
        bucket = position % 7
        if bucket == 5:
            valid_files.append(prefix)
        elif bucket == 4:
            test_files.append(prefix)
        else:
            train_files.append(prefix)


# High-grade (HGG) and low-grade (LGG) glioma patients are split separately so
# that both grades are proportionally represented in every subset.
_split_patients(HGG_names, train_scan_files, valid_scan_files, test_scan_files)
_split_patients(LGG_names, train_scan_files, valid_scan_files, test_scan_files)

        
# Training-time augmentation: convert an HxWxC array to a CxHxW tensor, then
# mirror it horizontally with probability 0.5 (random mirroring gave better
# results).
toTensor = torch_transforms.ToTensor()
horizontalFlipTransform = torch_transforms.RandomHorizontalFlip(p=0.5)

trainTransform = torch_transforms.Compose([
    toTensor,
    horizontalFlipTransform,
])

#Svaki volumen pojedinacno normaliziram da ima srednju vrijednost 0 i std 1
def normalize_vol(volume, i):
    """Standardise one modality volume to zero mean and unit std over brain voxels.

    Only non-zero voxels (zero is treated as background) contribute to the
    statistics. The shift and scale are applied to the whole array, so
    background voxels end up at ``-mean / std``.

    Args:
        volume: numpy array of intensities (any shape).
        i: unused; kept for compatibility with existing callers.

    Returns:
        A new array of the same shape. Two degenerate cases would otherwise
        produce NaN/inf:
        - If the volume has no non-zero voxels, it is returned unchanged (as a
          float copy).
        - If all non-zero voxels share one value, only the mean is subtracted.
    """
    volume = np.asarray(volume)
    brain = volume != 0.
    if not brain.any():
        # np.mean over an empty selection is NaN and would poison the whole volume.
        return volume.astype(float, copy=True)
    mean = np.mean(volume[brain])
    std = np.std(volume[brain])
    if std == 0:
        # Constant brain region: dividing by std would yield inf/NaN.
        return volume - mean
    return (volume - mean) / std

class SegmentationDataset(Dataset):
    """Dataset yielding fixed-size chunks of tumour-containing slices.

    Each item is ``(volume, mask)``:
    - ``volume`` is a float tensor of shape ``(n, len(modalities), H, W)``.
    - ``mask`` is a long one-hot tensor of shape ``(n, 4, H, W)``.
    ``n`` is at most the module-level ``slice_no``. Slices are gathered from
    consecutive patients, keeping only slices in which one of the three tumour
    channels is set. Slices left over from the last patient read are buffered
    and served first by the next call.

    Items must be requested in order 0, 1, 2, ... (the loaders use
    ``shuffle=False``) because the next patient read is ``paths[idx + passes]``.
    Call :meth:`reset` before every new pass over the data.

    Args:
        paths: patient path prefixes; modality/mask suffixes are appended to
            them. A shuffled copy is stored and the caller's list is not modified.
        modalities: file-name suffixes of the input modalities (one channel each).
        mask_path: file-name suffix of the segmentation map.
        tumor_slices: total number of slices expected; only used by ``__len__``.
        slice_no: NOTE(review): ignored. The module-level ``slice_no`` is used
            everywhere instead; the parameter is kept so the signature is unchanged.
        transform: optional per-slice transform, expected to map an HxWxC array
            to a CxHxW tensor. It is applied with the same random seed to images
            and masks.
    """
    in_channels = 4
    out_channels = 4

    # Segmentation label -> output channel. Channel c is 1 where the label map
    # equals _CHANNEL_LABELS[c]. The original code named these channels WT, TC,
    # ET and background (BG); channels 0-2 are tumour, channel 3 is background.
    _CHANNEL_LABELS = (2, 1, 4, 0)

    def __init__(self, paths, modalities, mask_path, tumor_slices, slice_no = 155, transform = None):
        # Copy so that shuffling does not reorder the caller's list.
        self.paths = list(paths)
        self.modalities = modalities
        self.mask_path = mask_path
        self.transform = transform
        self.mask_buffer = np.array([])
        self.volume_buffer = np.array([])
        self.passes = 0
        self.tumor_slices = tumor_slices
        random.shuffle(self.paths)

    def __len__(self):
        return math.ceil(self.tumor_slices/slice_no)

    def reset(self):
        """Start a new pass: rewind the patient cursor and reshuffle the patients.

        NOTE(review): the slice buffers are not cleared here.
        """
        self.passes = 0
        random.shuffle(self.paths)

    @staticmethod
    def _one_hot_mask(labels):
        """Map a (slices, H, W) label array to a (slices, 4, H, W) float one-hot array."""
        channels = [(labels == label).astype(float)
                    for label in SegmentationDataset._CHANNEL_LABELS]
        return np.stack(channels, axis=1)

    def __getitem__(self, idx):
        # 1) Start from slices buffered by the previous call, at most slice_no of them.
        if self.mask_buffer.size > 0:
            bound = min(slice_no, self.mask_buffer.shape[0])
            res_mask = self.mask_buffer[:bound]
            res_volume = self.volume_buffer[:bound]
            if bound < self.mask_buffer.shape[0]:
                self.mask_buffer = self.mask_buffer[bound:]
                self.volume_buffer = self.volume_buffer[bound:]
            else:
                self.mask_buffer = np.array([])
                self.volume_buffer = np.array([])
        else:
            res_mask = np.array([])
            res_volume = np.array([])

        # 2) Read further patients until the chunk is full or the patients run out.
        while res_mask.shape[0] < slice_no and idx + self.passes < len(self.paths):
            patient = self.paths[idx + self.passes]
            self.passes += 1

            mask_file = patient + self.mask_path
            # One segmentation map is misnamed in the (third-party) dataset.
            if mask_file == root_train_dir + '/BraTS20_Training_355/BraTS20_Training_355_seg.nii':
                mask_file = root_train_dir + '/BraTS20_Training_355/W39_1998.09.19_Segm.nii'
            # `float` instead of the removed `np.float` alias (identical dtype).
            # Assumes the NIfTI array is (H, W, slices); the transpose puts slices first.
            labels = np.asarray(nib.load(mask_file).dataobj, dtype=float).transpose(2, 0, 1)
            mask_full = self._one_hot_mask(labels)

            # Keep only slices where any of the three tumour channels is set.
            tumor_indices = [s for s in range(mask_full.shape[0])
                             if np.sum(mask_full[s, 0:3]) != 0]
            if not tumor_indices:
                # No tumour in this patient: skip it instead of failing on an empty stack.
                continue
            mask = mask_full[tumor_indices]

            volumes = []
            for modality in self.modalities:
                modality_volume = np.asarray(nib.load(patient + modality).dataobj,
                                             dtype=float).transpose(2, 0, 1)
                volumes.append(modality_volume[tumor_indices])

            if self.transform:
                # Re-seed before transforming the images and again before the masks,
                # so that (assuming the transform draws from `random`/`torch`)
                # slice s of both receives the same random flip.
                seed = random.randint(0, 2**32)
                random.seed(seed)
                torch.manual_seed(seed)
                stacked = np.transpose(np.stack(volumes), (1, 0, 2, 3))  # (slices, modalities, H, W)
                stacked = np.stack([self.transform(np.transpose(s, (1, 2, 0))) for s in stacked])
                volumes = np.transpose(stacked, (1, 0, 2, 3))  # back to (modalities, slices, H, W)
                random.seed(seed)
                torch.manual_seed(seed)
                mask = np.stack([self.transform(np.transpose(s, (1, 2, 0))) for s in mask])
            volumes = [normalize_vol(volumes[m], m) for m in range(np.shape(volumes)[0])]
            volumes = np.transpose(np.stack(volumes), (1, 0, 2, 3))

            # Take what still fits into this chunk; buffer the rest for the next call.
            bound = min(slice_no - res_mask.shape[0], mask.shape[0])
            if res_mask.size == 0:
                res_mask = mask[:bound]
                res_volume = volumes[:bound]
            else:
                res_mask = np.concatenate((res_mask, mask[:bound]), axis=0)
                res_volume = np.concatenate((res_volume, volumes[:bound]), axis=0)
            if bound < mask.shape[0]:
                if self.mask_buffer.size == 0:
                    self.mask_buffer = mask[bound:]
                    self.volume_buffer = volumes[bound:]
                else:
                    self.mask_buffer = np.concatenate((self.mask_buffer, mask[bound:]), axis=0)
                    self.volume_buffer = np.concatenate((self.volume_buffer, volumes[bound:]), axis=0)
        # idx advances by one per call, so compensate to keep idx + passes on the
        # next unread patient.
        self.passes -= 1
        res_volume = torch.from_numpy(res_volume).float()
        res_mask = torch.from_numpy(res_mask).long()
        return res_volume, res_mask
    
class TestingDataset(Dataset):
    """Dataset returning one whole patient volume per item.

    Each item is ``(volume, mask)``:
    - ``volume`` is a float tensor ``(slices, len(modalities), H, W)``, with
      each modality standardised by :func:`normalize_vol`.
    - ``mask`` is a long one-hot tensor ``(slices, 4, H, W)``.

    NOTE(review): SegmentationDataset normalises each modality over the
    tumour-containing slices only, whereas here the whole volume is used, so
    intensity statistics may differ between training and testing. Confirm this
    is intended.
    """
    in_channels = 4
    out_channels = 4

    # Segmentation label -> output channel (channels 0-2 tumour, 3 background).
    _CHANNEL_LABELS = (2, 1, 4, 0)

    def __init__(self, paths, modalities, mask_path):
        self.paths = paths
        self.modalities = modalities
        self.mask_path = mask_path
        # Not used by this class; kept for attribute compatibility.
        self.mask_buffer = np.array([])
        self.volume_buffer = np.array([])

    def __len__(self):
        return len(self.paths)

    @staticmethod
    def _one_hot_mask(labels):
        """Map a (slices, H, W) label array to a (slices, 4, H, W) float one-hot array."""
        channels = [(labels == label).astype(float)
                    for label in TestingDataset._CHANNEL_LABELS]
        return np.stack(channels, axis=1)

    def __getitem__(self, idx):
        patient = self.paths[idx]
        mask_file = patient + self.mask_path
        # One segmentation map is misnamed in the (third-party) dataset.
        if mask_file == root_train_dir + '/BraTS20_Training_355/BraTS20_Training_355_seg.nii':
            mask_file = root_train_dir + '/BraTS20_Training_355/W39_1998.09.19_Segm.nii'
        # `float` instead of the removed `np.float` alias (identical dtype).
        # Assumes the NIfTI array is (H, W, slices); the transpose puts slices first.
        labels = np.asarray(nib.load(mask_file).dataobj, dtype=float).transpose(2, 0, 1)
        mask_full = self._one_hot_mask(labels)

        volumes = []
        for modality in self.modalities:
            modality_volume = np.asarray(nib.load(patient + modality).dataobj,
                                         dtype=float).transpose(2, 0, 1)
            volumes.append(normalize_vol(modality_volume, len(volumes)))
        # (modalities, slices, H, W) -> (slices, modalities, H, W)
        volumes = np.transpose(np.stack(volumes), (1, 0, 2, 3))

        res_volume = torch.from_numpy(volumes).float()
        res_mask = torch.from_numpy(mask_full).long()
        return res_volume, res_mask



# Chunked datasets for training (with random mirroring) and validation. The
# fourth argument, `tumor_slices`, sets how many chunks each dataset reports
# (presumably the number of tumour-containing slices in that split).
trainset = SegmentationDataset(
    train_scan_files, modalities, mask_path, tumor_slices=17227, transform=trainTransform,
)
validset = SegmentationDataset(valid_scan_files, modalities, mask_path, tumor_slices=3661)
# Whole-volume dataset for testing, one patient per item.
testset = TestingDataset(test_scan_files, modalities, mask_path)

# The datasets serve items in a fixed order, so the loaders must not shuffle.
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=b_s, shuffle=False, num_workers=0,
)
validloader = torch.utils.data.DataLoader(
    validset, batch_size=b_s, shuffle=False, num_workers=0,
)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=b_s, shuffle=False, num_workers=0,
)

# Neuronska mreža (Ladder Densenet)

In [None]:
def upsample(x, size):
    """Bilinearly resize a 4-D (N, C, H, W) tensor ``x`` to spatial ``size``."""
    return F.interpolate(x, size, mode='bilinear', align_corners=False)


def checkpoint(func, *inputs):
    """Run ``func(*inputs)`` with gradient checkpointing, preserving RNG state."""
    return cp.checkpoint(func, *inputs, preserve_rng_state=True)


# Batch-norm momentum for ordinary layers.
batchnorm_momentum = 0.05
# Momentum for batch norms inside checkpointed segments: the smaller root m' of
# -x^2 + 2x - m = 0, i.e. 1 - (1 - m')**2 == m. Presumably this compensates for
# the forward pass running twice under checkpointing.
batchnorm_momentum_ckpt = min(np.roots([-1, 2, -batchnorm_momentum]))
use_batchnorm = True
avg_pooling_k = 2
use_dws_up = False
use_dws_down = False
checkpoint_stem = True
use_pyl_in_spp = True
checkpoint_upsample = False

def get_pyramid_loss_scales(downsampling_factor, upsampling_factor):
    """Return the successive halvings of ``downsampling_factor``.

    The list has ``int(log2(downsampling_factor // upsampling_factor))`` entries
    (at least one), starting at ``downsampling_factor``. Each value that gets
    halved must be even.
    """
    levels = int(math.log2(downsampling_factor // upsampling_factor))
    result = [downsampling_factor]
    while len(result) < levels:
        current = result[-1]
        assert current % 2 == 0
        result.append(current // 2)
    return result

def _batchnorm_factory(num_maps, momentum):
    return nn.BatchNorm2d(num_maps, eps=1e-5, momentum=momentum)

def _checkpoint_unit_nobt(bn, relu, conv):
    def func(*x):
        x = torch.cat(x, 1)
        return conv(relu(bn(x)))
    return func

def _checkpoint_unit(bn1, relu1, conv1, bn2, relu2, conv2):
    def func(*x):
        x = torch.cat(x, 1)
        x = conv1(relu1(bn1(x)))
        return conv2(relu2(bn2(x)))
    return func

class _Transition(nn.Sequential):
    """DenseNet transition: [BN] -> ReLU -> 1x1 conv -> optional average pooling.

    Submodules are registered in that order (``norm`` only if ``use_batchnorm``).
    With ``stride > 1`` an ``AvgPool2d`` of size ``avg_pooling_k`` (2 or 3) is
    added; otherwise ``pool`` is a plain identity function, not a submodule.
    When ``checkpointing`` is set, training-mode forwards are run through
    gradient checkpointing.
    """

    @staticmethod
    def _checkpoint_function(bn, relu, conv, pool):
        def run(inputs):
            return pool(conv(relu(bn(inputs))))
        return run

    def __init__(self, num_input_features, num_output_features, stride=2, checkpointing=False):
        super(_Transition, self).__init__()
        self.stride = stride
        if use_batchnorm:
            momentum = batchnorm_momentum if not checkpointing else batchnorm_momentum_ckpt
            self.add_module('norm', _batchnorm_factory(num_input_features, momentum))
        self.add_module('relu', nn.ReLU(inplace=True))
        self.add_module('conv', nn.Conv2d(num_input_features, num_output_features,
                                          kernel_size=1, stride=1, bias=False))
        if stride <= 1:
            # Identity "pool": a plain attribute, so Sequential.forward skips it.
            self.pool = lambda x: x
        elif avg_pooling_k == 2:
            self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=stride))
        elif avg_pooling_k == 3:
            self.add_module('pool', nn.AvgPool2d(kernel_size=3, stride=stride, padding=1,
                                                 ceil_mode=False, count_include_pad=False))
        self.checkpointing = checkpointing
        if checkpointing:
            self.func = _Transition._checkpoint_function(self.norm, self.relu, self.conv, self.pool)

    def forward(self, x):
        if not (self.checkpointing and self.training):
            return super(_Transition, self).forward(x)
        return checkpoint(self.func, x)

class _BNReluConv(nn.Sequential):
    """[BN] -> ReLU -> Conv2d unit; padding keeps the spatial size for odd ``k``.

    ``output_conv`` enables the conv bias (used for logits layers).
    ``drop_rate`` is stored but not applied by ``forward``. When
    ``checkpointing`` is set, training-mode forwards go through gradient
    checkpointing.

    NOTE(review): with ``use_dws_up`` enabled this needs ``SeparableConv2d``,
    which is not defined in the visible code.
    """

    @staticmethod
    def _checkpoint_function(bn, relu, conv):
        def run(inputs):
            return conv(relu(bn(inputs)))
        return run

    def __init__(self, num_maps_in, num_maps_out, k=3, output_conv=False,
                 dilation=1, drop_rate=0, checkpointing=False):
        super(_BNReluConv, self).__init__()
        self.drop_rate = drop_rate
        if use_batchnorm:
            momentum = batchnorm_momentum if not checkpointing else batchnorm_momentum_ckpt
            self.add_module('norm', _batchnorm_factory(num_maps_in, momentum))
        self.add_module('relu', nn.ReLU(inplace=True))
        padding = dilation * ((k - 1) // 2)
        conv_cls = SeparableConv2d if k >= 3 and use_dws_up else nn.Conv2d
        self.add_module('conv', conv_cls(num_maps_in, num_maps_out, kernel_size=k,
                                         padding=padding, bias=output_conv, dilation=dilation))
        self.checkpointing = checkpointing
        if checkpointing:
            self.func = _BNReluConv._checkpoint_function(self.norm, self.relu, self.conv)

    def forward(self, x):
        if self.checkpointing and self.training:
            return checkpoint(self.func, x)
        return super(_BNReluConv, self).forward(x)

class SpatialPyramidPooling(nn.Module):
    """Spatial pyramid pooling head.

    Stages:
    1. A 1x1 bottleneck (``num_maps_in -> bt_size``).
    2. For each entry of ``grids``: adaptive average pooling to that grid, a
       1x1 conv (``bt_size -> level_size``), and upsampling back to the input
       size.
    3. A 1x1 fuse conv over the concatenation of the bottleneck output and all
       levels.

    Args:
        conv_class: factory ``conv_class(in, out, k=...)`` returning a module.
        upsample_func: ``upsample_func(tensor, size)`` resizing to ``size``.
        num_maps_in: input channels.
        bt_size, level_size, out_size: channel widths described above.
        grids: pooled grid heights (tuple by default; any sequence of ints works).
        square_grid: if False, each grid's width follows the input aspect ratio.
    """

    def __init__(self, conv_class, upsample_func, num_maps_in, bt_size=512, level_size=128,
                 out_size=256, grids=(6, 3, 2, 1), square_grid=False):
        # `grids` defaults to a tuple: a mutable list default would be shared by
        # every instance created without an explicit `grids`.
        super(SpatialPyramidPooling, self).__init__()
        self.upsample = upsample_func
        self.grids = grids
        self.num_levels = len(grids)
        self.square_grid = square_grid
        self.spp = nn.Sequential()
        self.spp.add_module('spp_bn', conv_class(num_maps_in, bt_size, k=1))
        num_features = bt_size
        final_size = num_features
        for i in range(self.num_levels):
            final_size += level_size
            self.spp.add_module('spp'+str(i), conv_class(num_features, level_size, k=1))
        self.spp.add_module('spp_fuse', conv_class(final_size, out_size, k=1))

    def forward(self, x):
        target_size = x.size()[2:4]
        ar = target_size[1] / target_size[0]
        x = self.spp[0](x)
        levels = [x]
        # self.spp = [bottleneck, one conv per pyramid level..., fuse conv]
        for i in range(1, len(self.spp) - 1):
            grid = self.grids[i - 1]
            if self.square_grid:
                x_pooled = F.adaptive_avg_pool2d(x, grid)
            else:
                # Keep the pooled grid's aspect ratio close to the feature map's.
                x_pooled = F.adaptive_avg_pool2d(x, (grid, max(1, round(ar * grid))))
            level = self.spp[i](x_pooled)
            levels.append(self.upsample(level, target_size))
        return self.spp[-1](torch.cat(levels, 1))

class UpsampleResidual(nn.Module):
    """Decoder unit fusing an upsampled coarse map with a skip connection.

    Steps:
    1. ``skip`` goes through a 1x1 bottleneck to ``num_maps_in`` channels.
    2. ``bottom`` is upsampled to the skip's size and added to it.
    3. If ``num_maps_out > 0``, the sum is blended by an optional 1x1
       bottleneck (to 128 maps) and a ``k``x``k`` conv.
    4. The upsampled ``bottom`` (projected by a 1x1 conv when the widths
       differ) is added again as a residual.

    With ``produce_aux``, auxiliary logits computed from ``bottom`` are also
    returned.
    """

    def __init__(self, conv_class, upsample_func, num_maps_in, skip_maps_in, num_maps_out, k,
                 produce_aux=False, num_classes=0, dws_conv=False):
        super(UpsampleResidual, self).__init__()
        self.upsample_func = upsample_func
        self.bottleneck = conv_class(skip_maps_in, num_maps_in, k=1)
        self.produce_aux = produce_aux
        self.has_blend_conv = num_maps_out > 0
        self.num_maps_out = num_maps_out if self.has_blend_conv else num_maps_in
        self.skip_bt = (conv_class(num_maps_in, num_maps_out, k=1)
                        if num_maps_out != num_maps_in else None)
        if produce_aux:
            self.aux_logits = conv_class(num_maps_in, num_classes, k=1, output_conv=True)
        if not self.has_blend_conv:
            return
        bt_maps = 128
        blend_in = num_maps_in
        self.blend_bt = None
        if not dws_conv and k >= 3 and num_maps_in > bt_maps:
            # Reduce width before the (more expensive) k x k blend conv.
            self.blend_bt = conv_class(num_maps_in, bt_maps, k=1)
            blend_in = bt_maps
        self.blend_conv = conv_class(blend_in, num_maps_out, k=k)

    def forward(self, bottom, skip):
        skip = self.bottleneck(skip)
        skip_size = skip.size()[2:4]
        if self.produce_aux:
            aux = self.aux_logits(bottom)

        bottom = self.upsample_func(bottom, skip_size)
        fused = skip
        fused += bottom

        if self.has_blend_conv:
            if self.blend_bt is not None:
                fused = self.blend_bt(fused)
            fused = self.blend_conv(fused)
        if self.skip_bt is not None:
            bottom = self.skip_bt(bottom)
        fused += bottom
        return (fused, aux) if self.produce_aux else fused

class _DenseLayer(nn.Sequential):
    """One DenseNet layer: BN-ReLU-1x1 conv bottleneck, then BN-ReLU-3x3 conv.

    The bottleneck has ``bn_size * growth_rate`` maps and the layer outputs
    ``growth_rate`` maps.
    - With ``checkpointing``, ``forward`` takes a list of tensors that the
      closure concatenates (checkpointed in training mode).
    - Otherwise ``forward`` takes a single tensor.
    Dropout is applied to the output when ``drop_rate > 0``.
    """

    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate,
                 dilation=1, checkpointing=True):
        super(_DenseLayer, self).__init__()
        momentum = batchnorm_momentum if not checkpointing else batchnorm_momentum_ckpt
        width = bn_size * growth_rate
        self.add_module('norm1', _batchnorm_factory(num_input_features, momentum))
        self.add_module('relu1', nn.ReLU(inplace=True))
        self.add_module('conv1', nn.Conv2d(num_input_features, width,
                                           kernel_size=1, stride=1, bias=False))
        self.add_module('norm2', _batchnorm_factory(width, momentum))
        self.add_module('relu2', nn.ReLU(inplace=True))
        self.add_module('conv2', nn.Conv2d(width, growth_rate, kernel_size=3, stride=1,
                                           padding=dilation, bias=False, dilation=dilation))
        self.drop_rate = drop_rate
        self.checkpointing = checkpointing
        if not checkpointing:
            return
        # Six submodules are always registered above, so the first branch is the one taken.
        if len(self) == 6:
            self.conv_func = _checkpoint_unit(self.norm1, self.relu1, self.conv1,
                                              self.norm2, self.relu2, self.conv2)
        else:
            self.conv_func = _checkpoint_unit_nobt(self.norm2, self.relu2, self.conv2)

    def forward(self, x):
        if not self.checkpointing:
            x = super(_DenseLayer, self).forward(x)
        elif self.training:
            x = checkpoint(self.conv_func, *x)
        else:
            x = self.conv_func(*x)

        if self.drop_rate > 0:
            x = F.dropout(x, p=self.drop_rate, training=self.training, inplace=True)
        return x


class _DenseBlock(nn.Sequential):
    """Stack of ``num_layers`` dense layers whose outputs are concatenated.

    Layer ``j`` sees ``num_input_features + j * growth_rate`` channels.
    - With ``checkpointing``, features are kept as a list and concatenated
      once at the end.
    - With ``split``, before the middle layer the accumulated features are
      average-pooled with stride 2, and the pre-pool map is also returned as
      ``(output, split)``.
    """

    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate,
                 split=False, dilation=1, checkpointing=True):
        super(_DenseBlock, self).__init__()
        self.checkpointing = checkpointing
        self.split = split

        for layer_idx in range(num_layers):
            self.add_module('denselayer%d' % (layer_idx + 1), _DenseLayer(
                num_input_features + layer_idx * growth_rate, growth_rate=growth_rate,
                bn_size=bn_size, drop_rate=drop_rate, dilation=dilation,
                checkpointing=checkpointing))
        if split:
            self.split_size = num_input_features + (num_layers // 2) * growth_rate
            k = avg_pooling_k
            pad = (k - 1) // 2
            self.pool_func = lambda x: F.avg_pool2d(x, k, 2, padding=pad, ceil_mode=False,
                                                    count_include_pad=False)

    def forward(self, x):
        middle = len(self) // 2
        features = [x] if self.checkpointing else x
        for i, layer in enumerate(self.children()):
            if self.split and i == middle:
                if self.checkpointing:
                    split = torch.cat(features, 1)
                    features = [self.pool_func(split)]
                else:
                    split = features
                    features = self.pool_func(split)
            if self.checkpointing:
                features.append(layer(features))
            else:
                features = torch.cat([features, layer(features)], 1)
        if self.checkpointing:
            features = torch.cat(features, 1)
        return (features, split) if self.split else features

class DenseNet(nn.Module):
    """Ladder-style DenseNet for 2D semantic segmentation.

    Encoder:
    - A stem: 7x7 stride-2 ``conv0`` over 4 input channels, ``norm0``,
      ``relu0``, and a 2x2 stride-2 max-pool.
    - ``len(block_config)`` dense blocks, separated by transitions that halve
      the channel count and downsample by 2.
    The default ``growth_rate`` / ``block_config`` / ``num_init_features``
    match the DenseNet-121 layout.

    Decoder:
    - Spatial pyramid pooling on the deepest features.
    - One ``UpsampleResidual`` unit per entry of ``up_sizes``, each fusing the
      feature map saved before a strided transition (deepest first).
    - A 1x1 logits conv and a bilinear resize to the input (or
      ``target_size``) resolution.

    Args:
        growth_rate: channels added by each dense layer.
        block_config: number of dense layers per block. NOTE(review): ``splits``
            has 4 entries, so more than 4 blocks would raise IndexError.
        num_init_features: output channels of ``conv0``.
        bn_size: bottleneck multiplier (dense 1x1 conv has ``bn_size * growth_rate`` maps).
        drop_rate: NOTE(review): accepted but unused; blocks are built with ``drop_rate=0``.
    """
    def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
                 num_init_features=64, bn_size=4, drop_rate=0):
        super(DenseNet, self).__init__()
        # Local copy of the module-level checkpointing momentum (same formula);
        # used for the stem batch norm below.
        batchnorm_momentum_ckpt = min(np.roots([-1, 2, -batchnorm_momentum]))
        self.first_stride = 2
        self.num_classes = 4
        self.num_logits = 4
        self.checkpointing = True
        self.growth_rate = growth_rate
        self.block_config = block_config
        self.num_blocks = len(block_config)
        self.growth_rate = growth_rate  # duplicate of the assignment above; harmless
        self.features = nn.Sequential()
        # When True, the stem is run under gradient checkpointing in training mode.
        self.checkpoint_stem = True
        # Stem: 4 input channels (one per input modality) -> num_init_features.
        self.features.add_module('conv0', nn.Conv2d(4, num_init_features, kernel_size=7,
                                 stride=self.first_stride, padding=3, bias=False))
        m = batchnorm_momentum_ckpt
        self.features.add_module('norm0', _batchnorm_factory(num_init_features, m))
        self.features.add_module('relu0', nn.ReLU(inplace=True))
        self.features.add_module('pool0', nn.MaxPool2d(kernel_size=2, stride=2))
        # Index of the first dense block in self.features (after conv0, norm0, relu0, pool0).
        self.first_block_idx = len(self.features)
        # Closure running the stem (features[0:first_block_idx]); checkpointed in forward().
        self.first_ckpt_func = self._checkpoint_segment(0, self.first_block_idx)
        # Module groups (freshly initialised decoder vs. encoder). Not read in the
        # visible code; presumably meant for per-group optimiser settings.
        self.random_init = []
        self.fine_tune = []
        self.fine_tune.append(self.features)
        # Decoder / pyramid configuration.
        splits = [False, False, False, False]
        up_sizes = [256, 256, 128]
        spp_square_grid = False
        spp_grids = [8,4,2,1]
        self.spp_size = 512
        bt_size = 512
        level_size = 128
        dilations = [1] * len(block_config)
        strides = [2] * (len(block_config) - 1)
        # Total number of 2x downsamplings. NOTE(review): this adds the stem stride
        # *value*; it equals the stem's two downsamplings (strided conv0 + pool0)
        # only because first_stride == 2.
        num_downs = self.first_stride + strides.count(2) + sum(splits)
        num_ups = len(up_sizes)
        # With the defaults: 2**5 = 32 and 2**(5 - 3) = 4.
        self.downsampling_factor = 2**num_downs
        self.upsampling_factor = 2**(num_downs-num_ups)
        self.use_upsampling_path = True
        # Channel counts of the skip connections, shallowest first.
        skip_sizes = []
        num_features = num_init_features
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(
                num_layers=num_layers, num_input_features=num_features, bn_size=bn_size,
                growth_rate=growth_rate, drop_rate=0, split=splits[i],
                dilation=dilations[i], checkpointing=self.checkpointing)
            self.features.add_module('denseblock%d' % (i + 1), block)
            num_features = num_features + num_layers * growth_rate
            # Never taken with the hard-coded `splits` (all False).
            if block.split and self.use_upsampling_path:
                skip_sizes.append(block.split_size)
            if i != len(block_config) - 1:
                # The dense-block output entering a strided transition is a skip
                # connection; forward() saves the same tensor.
                if strides[i] > 1 and self.use_upsampling_path:
                    skip_sizes.append(num_features)
                trans = _Transition(
                    num_input_features=num_features, num_output_features=num_features // 2,
                    stride=strides[i], checkpointing=self.checkpointing)
                self.features.add_module('transition%d' % (i + 1), trans)
                num_features = num_features // 2
        self.use_aux = False
        self.spp = SpatialPyramidPooling(
            _BNReluConv, upsample, num_features, bt_size, level_size,
            self.spp_size, spp_grids, spp_square_grid)
        self.random_init.append(self.spp)
        num_features = self.spp_size
        # NOTE(review): `args` is not defined in the visible code; enabling use_aux
        # would raise NameError here.
        if self.use_aux:
            self.pyramid_loss_scales = get_pyramid_loss_scales(
                    args.downsampling_factor, args.upsampling_factor)
            spp_scales = []
            for scale in reversed(spp_grids):
                assert args.crop_size % scale == 0
                spp_scales.append(args.crop_size // scale)
            self.pyramid_loss_scales = spp_scales + self.pyramid_loss_scales
        if self.use_upsampling_path:
            up_convs = [3] * len(up_sizes)
            self.upsample_layers = nn.Sequential()
            self.random_init.append(self.upsample_layers)
            # One upsampling unit per skip connection.
            assert len(up_sizes) == len(skip_sizes)
            for i in range(num_ups):
                # Skips are consumed deepest first.
                upsample_unit = UpsampleResidual(
                    _BNReluConv, upsample, num_features, skip_sizes[-1-i], up_sizes[i],
                    up_convs[i], False, self.num_classes)
                num_features = upsample_unit.num_maps_out
                self.upsample_layers.add_module('upsample_'+str(i), upsample_unit)
        self.logits = _BNReluConv(num_features, self.num_logits, k=1, output_conv=True,
                                  checkpointing=self.checkpointing)
        self.random_init.append(self.logits)
        # Kaiming-normal conv weights, zero biases; BN scale 1, shift 0.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
        # The final logits conv is re-initialised with Xavier-normal weights.
        nn.init.xavier_normal_(self.logits.conv.weight.data)
        if self.use_aux and self.use_upsampling_path:
            for module in self.upsample_layers:
                nn.init.xavier_normal_(module.aux_logits.conv.weight.data)


    def forward(self, x, target_size=None):
        """Return ``(logits, aux_logits)``.

        ``logits`` has ``self.num_logits`` channels and is resized bilinearly to
        ``target_size`` (default: the spatial size of ``x``). ``aux_logits`` is a
        list that stays empty unless ``use_aux`` is enabled.
        """
        skip_layers = []
        if target_size is None:
            target_size = x.size()[2:4]

        # Stem + first dense block. In training with checkpoint_stem the stem is
        # run through gradient checkpointing (recomputed during backward).
        if not self.training or not self.checkpoint_stem:
            for i in range(self.first_block_idx+1):
                x = self.features[i].forward(x)
        else:
            x = checkpoint(self.first_ckpt_func, x)
            x = self.features[self.first_block_idx].forward(x)
        # Remaining features alternate (transition, dense block); the input of each
        # strided transition is kept as a skip connection.
        for i in range(self.first_block_idx+1, len(self.features), 2):
            if self.features[i].stride > 1 and self.use_upsampling_path:
                skip_layers.append(x)
            x = self.features[i].forward(x)
            x = self.features[i+1].forward(x)
            if isinstance(x, tuple) and self.use_upsampling_path:
                x, split = x
                skip_layers.append(split)

        x = self.spp(x)

        aux_logits = []
        if self.use_upsampling_path:
            # Decoder: fuse skips from deepest to shallowest.
            for i, up in enumerate(self.upsample_layers):
                x = up(x, skip_layers[-1-i])
                if self.use_aux:
                    x, aux = x
                    aux_logits.append(aux)

        x = self.logits(x)
        x = upsample(x, target_size)

        return x, aux_logits

    def _checkpoint_segment(self, start, end):
        """Return a closure applying ``self.features[start:end]`` in sequence."""
        def func(x):
            for i in range(start, end):
                x = self.features[i](x)
            return x
        return func

    def forward_loss(self, batch, return_outputs=False):
        """Compute the segmentation loss for ``batch['image']``.

        NOTE(review): ``losses``, ``self.args`` and ``self.dataset`` are not
        defined in the visible code, so this method would raise as-is. It
        appears to be unused here.
        """
        x = batch['image']
        logits, aux_logits = self.forward(x)
        if not self.training:
            aux_logits = []
        self.output = logits
        loss = losses.segmentation_loss(logits, aux_logits, batch, self.args.aux_loss_weight,
                                        self.dataset.ignore_id, equal_level_weights=use_pyl_in_spp)
        loss, self.aux_losses = loss
        if return_outputs:
            return loss, (logits, aux_logits)
        return loss

class SemSegModel(nn.Module):
    """Wrap a backbone and turn its logits into per-pixel log-probabilities.

    The backbone must return a pair ``(logits, extra)``; ``extra`` is ignored.
    Log-softmax is taken over dim 1. ``num_classes`` is accepted but unused.
    """

    def __init__(self, backbone, num_classes):
        super(SemSegModel, self).__init__()
        self.backbone = backbone

    def forward(self, image):
        logits, _ = self.backbone(image)
        return F.log_softmax(logits, dim=1)

# Pomoćne funkcije

In [None]:
def dice_score(prediction, ground_truth, smooth = 1.0):
    """Smoothed Dice coefficient over all elements of two same-sized tensors.

    Returns ``(2 * sum(p * g) + smooth) / (sum(p) + sum(g) + smooth)`` with
    both tensors flattened.
    """
    pred_flat = prediction.reshape(-1)
    truth_flat = ground_truth.reshape(-1)
    overlap = torch.sum(pred_flat * truth_flat)
    numerator = 2 * overlap + smooth
    denominator = pred_flat.sum() + truth_flat.sum() + smooth
    return numerator / denominator

def un_one_hot(targets):
    """Collapse one-hot targets along dim 1 into class indices."""
    return torch.argmax(targets, dim=1)

def output_metrics(loss_avg, dice_scores, phase):
    """Print the phase name, the average loss and the Dice scores, one per line."""
    report = (
        phase,
        f"Average loss: {loss_avg}",
        f"Dice score: {dice_scores}",
    )
    for line in report:
        print(line)

def _region_masks(labels):
    """Collapse per-pixel class indices into 0/1 masks for the WT, TC and ET regions.

    Class grouping: WT = {0, 1, 2}, TC = {1, 2}, ET = {2}; class 3 belongs to
    no region. Masks keep ``labels``' dtype.
    """
    def as_mask(*classes):
        mask = torch.zeros_like(labels, dtype=torch.bool)
        for cls in classes:
            mask |= labels == cls
        return mask.to(labels.dtype)

    return as_mask(0, 1, 2), as_mask(1, 2), as_mask(2)


def dataset_evaluate(model, loader, loss_fn, threshold = -1, single = False):
    """Evaluate ``model`` on every batch of ``loader``; return mean loss and mean Dice.

    Args:
        model: network whose output is passed to ``loss_fn`` and argmax-ed over
            dim 1 to obtain predicted class indices.
        loader: iterable of ``(images, ground_truths)`` batches. The first
            ``b_s`` items of each are concatenated along dim 0; ground truths
            are argmax-ed over dim 1 (presumably one-hot encoded).
        loss_fn: criterion called as ``loss_fn(outputs, class_indices)``.
        threshold, single: unused; kept for backward compatibility.

    Returns:
        ``(loss_avg, dice)``: the loss averaged over batches and a length-3
        numpy array with the batch-averaged Dice scores for WT, TC and ET.

    Raises:
        ValueError: if ``loader`` yields no batches.

    The model's train/eval mode is restored to what it was on entry.
    """
    was_training = model.training
    loss_sum = 0.0
    dice_scores = []

    model.eval()
    with torch.no_grad():
        for images, ground_truths in loader:
            images = torch.cat([images[k] for k in range(b_s)]).to(device)
            ground_truths = torch.cat([ground_truths[k] for k in range(b_s)]).to(device)

            outputs = model(images)
            loss_sum += loss_fn(outputs, un_one_hot(ground_truths)).item()

            predicted = _region_masks(outputs.argmax(1))
            expected = _region_masks(ground_truths.argmax(1))
            # .item() moves each score to the host: np.stack cannot convert
            # CUDA tensors to numpy.
            dice_scores.append([dice_score(p, t).item() for p, t in zip(predicted, expected)])
    model.train(was_training)

    if not dice_scores:
        raise ValueError("dataset_evaluate: loader yielded no batches")

    loss_avg = loss_sum / len(dice_scores)
    return loss_avg, np.stack(dice_scores).mean(0)

# Iscrtavanje grafa gubitka

In [None]:
def plot_progress(data):
    """Draw the per-epoch train and validation loss curves and save them to ``./loss.png``.

    Args:
        data: dict holding the per-epoch sequences ``'valid loss'`` and
            ``'train loss'`` (as filled in by ``train_network``).
    """
    valid_curve = data['valid loss']
    train_curve = data['train loss']
    line_style = {'marker': 'o', 'linewidth': 2, 'linestyle': '-'}

    _fig, ax = plt.subplots(figsize=(16, 8))
    ax.set_title('Loss')
    for curve, colour, label in ((train_curve, 'm', 'train'),
                                 (valid_curve, 'c', 'validation')):
        ax.plot(curve, color=colour, label=label, **line_style)
    ax.legend(loc='upper right', fontsize=10)

    save_path = os.path.join('./', 'loss.png')
    print('Plotting in: ', save_path)
    plt.savefig(save_path)

# Treniranje i evaluacija modela

In [None]:
# Build the model once just to print its architecture; train_network()
# constructs its own fresh copy.
mydensenet = DenseNet(
    num_init_features=64,
    growth_rate=32,
    block_config=(6, 12, 24, 16),
)
net = SemSegModel(mydensenet, 4)
print(net)

In [None]:
def _load_pretrained_densenet():
    """Build the DenseNet backbone and initialise it from torchvision's ImageNet DenseNet-121.

    The pretrained stem convolution is adapted to 4 input channels by averaging
    its kernels over the input-channel axis and replicating the mean 4 times.

    Returns:
        ``(backbone, pretrained_keys)``: the initialised ``DenseNet`` and the set
        of its ``state_dict`` keys whose values were copied from the checkpoint.
    """
    densenet121 = models.densenet121(pretrained=True, memory_efficient=True)
    backbone = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16))

    pretrained_dict = densenet121.state_dict()
    model_dict = backbone.state_dict()

    conv_weight = pretrained_dict.pop('features.conv0.weight')
    mean_kernel = conv_weight.mean(dim=1, keepdim=True)
    pretrained_dict['features.conv0.weight'] = torch.cat([mean_kernel] * 4, dim=1)

    # Copy only entries whose key also exists in our backbone; everything else
    # keeps its fresh initialisation.
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
    model_dict.update(pretrained_dict)
    backbone.load_state_dict(model_dict)
    return backbone, set(pretrained_dict)


def _split_param_groups(net, backbone, pretrained_keys):
    """Split ``net``'s parameters into ``(freshly initialised, pretrained)`` lists.

    Matching is done by parameter identity against ``backbone``'s own names.
    ``net.named_parameters()`` prefixes names with the attribute path (e.g.
    ``backbone.``), so comparing those names to backbone ``state_dict`` keys
    would never match.
    """
    pretrained_ids = {id(p) for name, p in backbone.named_parameters() if name in pretrained_keys}
    base_params, pretrained_params = [], []
    for param in net.parameters():
        if id(param) in pretrained_ids:
            pretrained_params.append(param)
        else:
            base_params.append(param)
    return base_params, pretrained_params


def _train_one_epoch(net, loader, loss_fn, optimizer, epoch):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Prints the running mean loss every 10 batches. Returns ``nan`` when the
    loader yields no batches.
    """
    net.train()
    acc_loss = 0.0
    num_batches = 0
    for batch_idx, (inputs, ground_truths) in enumerate(loader):
        # Merge the first b_s samples of the loader batch along dim 0.
        inputs = torch.cat([inputs[k] for k in range(b_s)])
        ground_truths = torch.cat([ground_truths[k] for k in range(b_s)])
        # Kept deliberately: with reentrant torch.utils.checkpoint, a
        # checkpointed segment whose inputs do not require grad yields no
        # parameter gradients. NOTE(review): confirm whether the backbone
        # checkpoints its first segment; if not, this only costs extra compute.
        inputs.requires_grad = True
        inputs = inputs.to(device)
        ground_truths = ground_truths.to(device)

        outputs = net(inputs)
        loss = loss_fn(outputs, un_one_hot(ground_truths))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        acc_loss += loss.item()
        num_batches = batch_idx + 1
        if batch_idx % 10 == 0:
            print("Epoch: %d, Iteration: %5d, Loss: %.3f"
                  % (epoch + 1, num_batches, acc_loss / num_batches))
    return acc_loss / num_batches if num_batches else float('nan')


def train_network():
    """Fine-tune the DenseNet segmentation model, then evaluate and save it.

    The setup is:
    - Trains for ``num_epochs`` epochs on ``trainloader`` with Adam and cosine
      LR annealing.
    - Evaluates on ``validloader`` after every epoch and on ``testloader`` at
      the end.
    - Parameters copied from the ImageNet checkpoint train at
      ``learning_rate / scaling_factor``; all others at ``learning_rate``.

    Outputs:
    - The weights are written to ``./network.pt``.
    - ``repr`` of the metrics dict is written to ``./epoch_data.txt``.
    - The loss curves are plotted via ``plot_progress``.

    Returns:
        dict with per-epoch train/validation metrics, learning rates and the
        final test metrics.
    """
    plot_data = {
        "train loss": [],
        "valid loss": [],
        "WT valid dice": [],
        "TC valid dice": [],
        "ET valid dice": [],
        "test loss": 0,
        "WT test dice": [],
        "TC test dice": [],
        "ET test dice": [],
        "lr": [],
    }

    SAVE_PATH = "./network.pt"
    print("device:", device)

    mydensenet, pretrained_keys = _load_pretrained_densenet()
    net = SemSegModel(mydensenet, 4).to(device=device)
    lossFunc = nn.NLLLoss(weight=torch.tensor([2, 5, 3, 1]).float().to(device))

    base_params, pretrained_params = _split_param_groups(net, mydensenet, pretrained_keys)
    optimizer = optim.Adam(
        [{'params': base_params},
         {'params': pretrained_params, 'lr': learning_rate / scaling_factor}],
        lr=learning_rate, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-6)

    for epoch in range(num_epochs):
        train_loss = _train_one_epoch(net, trainloader, lossFunc, optimizer, epoch)
        trainset.reset()

        val_loss, val_dice = dataset_evaluate(net, validloader, lossFunc)
        validset.reset()
        output_metrics(val_loss, val_dice, "Validation")
        plot_data["valid loss"].append(val_loss)
        plot_data["WT valid dice"].append(val_dice[0])
        plot_data["TC valid dice"].append(val_dice[1])
        plot_data["ET valid dice"].append(val_dice[2])

        # Record the LRs used during this epoch before advancing the schedule.
        plot_data["lr"].append(scheduler.get_last_lr())
        plot_data["train loss"].append(train_loss)
        scheduler.step()

    test_loss, test_dice = dataset_evaluate(net, testloader, lossFunc)
    output_metrics(test_loss, test_dice, "Test:")

    plot_data["test loss"] = test_loss
    plot_data["WT test dice"] = test_dice[0]
    plot_data["TC test dice"] = test_dice[1]
    plot_data["ET test dice"] = test_dice[2]

    torch.save(net.state_dict(), SAVE_PATH)

    with open("./epoch_data.txt", 'w') as f:
        f.write(repr(plot_data))

    plot_progress(plot_data)

    return plot_data
    
    
# Run the whole training / evaluation pipeline and keep the returned metrics dict.
epoch_data = train_network()