In [1]:
import torch
import torch.nn as nn

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [14]:
class ResBlock(nn.Module):
    def __init__(self, in_chns, out_chns, kernels=[[3, 3, 1], [3, 3, 1]], strides=[[1,1,1], [1,1,1]], dilation_rate=[[1,1,1], [1,1,1]], activation=None, w_init=None, w_reg=None, res=True):
        super().__init__()
        self.in_chns = in_chns
        self.out_chns = out_chns
        self.kernels = kernels
        self.strides = strides
        self.dilation_rate = dilation_rate
        if not activation:
            self.activation = nn.PReLU()
        else:
            self.activation = activation
        self.w_init = w_init
        self.w_reg = w_reg
        self.res = res
        
    
    def forward(self, x):
        output = x
        #print(len(self.kernels))
        in_chns = self.in_chns
        for i in range(len(self.kernels)):
            kernel, stride, dilation = self.kernels[i], self.strides[i], self.dilation_rate[i]
            #print(self.in_chns, self.out_chns, kernel, stride, dilation)
            #print(output.shape)
            self.conv3d = nn.Conv3d(in_chns, self.out_chns, kernel_size=kernel, padding='same', dilation=dilation)
            self.batchnorm = nn.BatchNorm3d(in_chns)
            in_chns = self.out_chns
            output = self.batchnorm(output)
            output = self.activation(output)
            output = self.conv3d(output)
        if self.res:
            if self.in_chns != self.out_chns:
                self.projector = nn.Conv3d(self.in_chns, self.out_chns, kernel_size=1, stride=1, padding='same')
                x = self.projector(x)
            output += x
        #print("Finish block")
        return output
    
            
            
class Conv2dBlock(nn.Module):
    def __init__(self,in_chns, out_chns, kernels, padding=0, strides=[1, 1, 1], activation=nn.PReLU(), w_init=None, w_reg=None, b_init=None, b_reg=None, with_bn=True, deconv=False):
        super().__init__()
        self.in_chns = in_chns
        self.out_chns = out_chns
        self.kernels = kernels
        self.strides = strides
        self.padding = padding
        self.activation = activation
        self.w_init = w_init
        self.w_reg = w_reg
        self.b_init = b_init
        self.b_reg = b_reg
        if not deconv:
            self.conv_block = nn.Conv3d(in_chns, out_chns, kernel_size=kernels, padding=padding, stride=strides, bias=True)
        else:
            self.conv_block = nn.ConvTranspose3d(in_chns, out_chns, kernel_size=kernels, padding=padding, stride=strides, bias=True)
            
        if with_bn:
            self.bn = nn.BatchNorm3d(self.out_chns)
        else:
            self.bn = nn.Identity()
    
    def forward(self, x):
        output = self.conv_block(x)
        output = self.bn(output)
        output = self.activation(output)
        return output
    
class SliceLayer(nn.Module):
    
    def __init__(self, margin=1):
        super().__init__()
        self.margin = margin
        
    def forward(self, x):
        # TODO: Fix this
        return x[:, :, :, :, self.margin:-(self.margin)]
        
        

In [20]:


class MSNet(nn.Module):

    def __init__(
        self,
        in_chns,
        num_classes,
        w_init=None,
        w_reg=None,
        b_init=None,
        b_reg=None,
        activation=nn.PReLU(),
        ):

        # TODO: Add weight init

        super().__init__()
        self.num_classes = num_classes
        (self.w_init, self.w_reg, self.b_init, self.b_reg) = (w_init,
                w_reg, b_init, b_reg)
        self.activation = activation
        self.base_chns = [32, 32, 32, 32]
        self.is_WTNet = True

        # First Block

        self.block1 = nn.Sequential(ResBlock(in_chns,
                                    self.base_chns[0],
                                    activation=activation,
                                    w_init=w_init, w_reg=w_reg),
                                    ResBlock(self.base_chns[0],
                                    self.base_chns[0],
                                    activation=activation,
                                    w_init=w_init, w_reg=w_reg))
        self.fuse1 = Conv2dBlock(
            self.base_chns[0],
            self.base_chns[0],
            kernels=[1, 1, 3],
            padding='valid',
            activation=self.activation,
            w_init=self.w_init,
            w_reg=self.w_reg,
            b_init=self.b_init,
            b_reg=self.b_reg,
            )

        self.downsample1 = Conv2dBlock(
            self.base_chns[0],
            self.base_chns[0],
            kernels=[3, 3, 1],
            strides=[2, 2, 1],
            # padding='same',
            activation=self.activation,
            w_init=self.w_init,
            w_reg=self.w_reg,
            b_init=self.b_init,
            b_reg=self.b_reg,
            )
        self.feature_expand1 = Conv2dBlock(
            self.base_chns[0],
            self.base_chns[1],
            kernels=[1, 1, 1],
            strides=[1, 1, 1],
            # padding='same',
            activation=self.activation,
            w_init=self.w_init,
            w_reg=self.w_reg,
            b_init=self.b_init,
            b_reg=self.b_reg,
            )

        # Second Block

        self.block2 = nn.Sequential(ResBlock(self.base_chns[1],
                                    self.base_chns[1],
                                    activation=activation,
                                    w_init=w_init, w_reg=w_reg),
                                    ResBlock(self.base_chns[1],
                                    self.base_chns[1],
                                    activation=activation,
                                    w_init=w_init, w_reg=w_reg))
        self.fuse2 = Conv2dBlock(
            self.base_chns[1],
            self.base_chns[1],
            kernels=[1, 1, 3],
            padding='valid',
            activation=self.activation,
            w_init=self.w_init,
            w_reg=self.w_reg,
            b_init=self.b_init,
            b_reg=self.b_reg,
            )
        self.downsample2 = Conv2dBlock(
            self.base_chns[1],
            self.base_chns[1],
            kernels=[3, 3, 1],
            strides=[2, 2, 1],
            # padding='same',
            activation=self.activation,
            w_init=self.w_init,
            w_reg=self.w_reg,
            b_init=self.b_init,
            b_reg=self.b_reg,
            )
        self.feature_expand2 = Conv2dBlock(
            self.base_chns[1],
            self.base_chns[2],
            kernels=[1, 1, 1],
            strides=[1, 1, 1],
            # padding='same',
            activation=self.activation,
            w_init=self.w_init,
            w_reg=self.w_reg,
            b_init=self.b_init,
            b_reg=self.b_reg,
            )
        self.pred_1E = nn.Conv3d(self.base_chns[1], self.num_classes,
                                 kernel_size=[3, 3, 1], 
                                 padding='same'
                                 )
        self.pred_1WT = Conv2dBlock(
            self.base_chns[1]
            ,
            self.num_classes,
            kernels=[3, 3, 1],
            strides=[2, 2, 1],
            # padding='same',
            activation=self.activation,
            w_init=self.w_init,
            w_reg=self.w_reg,
            b_init=self.b_init,
            b_reg=self.b_reg,
            deconv=True,
            )

        # Third Block

        self.block3 = nn.Sequential(ResBlock(
            self.base_chns[2],
            self.base_chns[2],
            dilation_rate=[[1, 1, 1], [1, 1, 1]],
            activation=activation,
            w_init=w_init,
            w_reg=w_reg,
            ), ResBlock(
            self.base_chns[2],
            self.base_chns[2],
            strides=[[2, 2, 1], [2, 2, 1]],
            activation=activation,
            w_init=w_init,
            w_reg=w_reg,
            ), ResBlock(
            self.base_chns[2],
            self.base_chns[2],
            dilation_rate=[[3, 3, 1], [3, 3, 1]],
            activation=activation,
            w_init=w_init,
            w_reg=w_reg,
            ))
        self.fuse3 = Conv2dBlock(
            self.base_chns[2],
            self.base_chns[2],
            kernels=[1, 1, 3],
            padding='valid',
            activation=self.activation,
            w_init=self.w_init,
            w_reg=self.w_reg,
            b_init=self.b_init,
            b_reg=self.b_reg,
            )
        self.feature_expand3 = Conv2dBlock(
            self.base_chns[2],
            self.base_chns[3],
            kernels=[1, 1, 1],
            strides=[1, 1, 1],
            # padding='same',
            activation=self.activation,
            w_init=self.w_init,
            w_reg=self.w_reg,
            b_init=self.b_init,
            b_reg=self.b_reg,
            )
        self.pred_21 = Conv2dBlock(
            self.base_chns[2],
            self.num_classes * 2,
            kernels=[3, 3, 1],
            strides=[2, 2, 1],
            # padding='same',
            activation=self.activation,
            w_init=self.w_init,
            w_reg=self.w_reg,
            b_init=self.b_init,
            b_reg=self.b_reg,
            deconv=True,
            )
        self.pred_22 = Conv2dBlock(
            self.num_classes * 2,
            self.num_classes * 2,
            kernels=[3, 3, 1],
            strides=[2, 2, 1],
            # padding='same',
            activation=self.activation,
            w_init=self.w_init,
            w_reg=self.w_reg,
            b_init=self.b_init,
            b_reg=self.b_reg,
            deconv=True,
            )

        # Fourth Block

        self.block4 = nn.Sequential(ResBlock(
            self.base_chns[3],
            self.base_chns[3],
            dilation_rate=[[3, 3, 1], [3, 3, 1]],
            activation=activation,
            w_init=w_init,
            w_reg=w_reg,
            ), ResBlock(
            self.base_chns[3],
            self.base_chns[3],
            dilation_rate=[[2, 2, 1], [2, 2, 1]],
            activation=activation,
            w_init=w_init,
            w_reg=w_reg,
            ), ResBlock(
            self.base_chns[3],
            self.base_chns[3],
            dilation_rate=[[1, 1, 1], [1, 1, 1]],
            activation=activation,
            w_init=w_init,
            w_reg=w_reg,
            ))
        self.fuse4 = Conv2dBlock(
            self.base_chns[3],
            self.base_chns[3],
            kernels=[1, 1, 3],
            padding='valid',
            activation=self.activation,
            w_init=self.w_init,
            w_reg=self.w_reg,
            b_init=self.b_init,
            b_reg=self.b_reg,
            )
        self.pred_31 = Conv2dBlock(
            self.base_chns[3],
            self.num_classes * 4,
            kernels=[3, 3, 1],
            strides=[2, 2, 1],
            # padding='same',
            activation=self.activation,
            w_init=self.w_init,
            w_reg=self.w_reg,
            b_init=self.b_init,
            b_reg=self.b_reg,
            deconv=True,
            )
        self.pred_32 = Conv2dBlock(
            self.num_classes * 4,
            self.num_classes * 4,
            kernels=[3, 3, 1],
            strides=[2, 2, 1],
           #  padding='same',
            activation=self.activation,
            w_init=self.w_init,
            w_reg=self.w_reg,
            b_init=self.b_init,
            b_reg=self.b_reg,
            deconv=True,
            )

        # TODO: Change this MAYBE

        self.final_pred = nn.Conv3d(14,
                                    self.num_classes, kernel_size=[3, 3, 1], padding='same')
        self.centra_slice1 = SliceLayer(margin=2)
        self.centra_slice2 = SliceLayer(margin=1)

    def forward(self, x):
        f1 = x
        # print("f1", f1.shape)
        f1 = self.block1(f1)
        # print("f1", f1.shape)
        f1 = self.fuse1(f1)
        # print("f1", f1.shape)
        if self.is_WTNet:
            f1 = self.downsample1(f1)
        if self.base_chns[0] != self.base_chns[1]:
            f1 = self.feature_expand1(f1)
        # print("f1", f1.shape)
        f1 = self.block2(f1)
        # print("f1", f1.shape)
        f1 = self.fuse2(f1)
        # print("f1", f1.shape)
        f2 = self.downsample2(f1)
        if self.base_chns[1] != self.base_chns[2]:
            f2 = self.feature_expand1(f2)
        f2 = self.block3(f2)
        f2 = self.fuse3(f2)
        # print("f2", f2.shape)
        f3 = f2
        if self.base_chns[2] != self.base_chns[3]:
            f3 = self.feature_expand1(f3)
        f3 = self.block4(f3)
        f3 = self.fuse3(f3)
        # print("f3", f3.shape)
        # Prediction

        p1 = self.centra_slice1(f1)
        # print(f1.shape, p1.shape)
        # print(p1.shape)
        if self.is_WTNet:
            p1 = self.pred_1WT(p1)
        else:
            p1 = self.pred_1E(p1)

        p2 = self.centra_slice2(f2)
        # print(f2.shape, p2.shape)
        p2 = self.pred_21(p2)
        if self.is_WTNet:
            p2 = self.pred_22(p2)

        p3 = self.pred_31(f3)
        if self.is_WTNet:
            p3 = self.pred_32(p3)

        # print(p1.shape, p2.shape, p3.shape)
        combine = torch.cat([p1, p2, p3], 1)
        # print(combine.shape)
        return self.final_pred(combine)


In [5]:
from __future__ import absolute_import, print_function

import numpy as np
import random
from scipy import ndimage
import time
import os
import sys
# import tensorflow as tf
# from niftynet.layer.loss_segmentation import LossFunction
from util.data_loader import *
from util.train_test_func import *
from util.parse_config import parse_config
from kornia.losses import dice_loss as DiceLoss
import torch
import torch.nn as nn

In [6]:
class NetFactory(object):
    @staticmethod
    def create(name):
        if name == 'MSNet':
            return MSNet
        # add your own networks here
        print('unsupported network:', name)
        exit()

In [7]:
config_file = './config17/train_wt_ax.txt'

## 1. load configuration parameters

In [8]:
config = parse_config(config_file)
config_data  = config['data']
config_net   = config['network']
config_train = config['training']
     
random.seed(config_train.get('random_seed', 1))
assert(config_data['with_ground_truth'])

net_type    = config_net['net_type']
net_name    = config_net['net_name']
class_num   = config_net['class_num']
batch_size  = config_data.get('batch_size', 5)

data data_root /home/dd/CSC490_Braindon/brats17/Brats17TrainingData /home/dd/CSC490_Braindon/brats17/Brats17TrainingData
data data_names config17/train_names_temp.txt config17/train_names_temp.txt
data modality_postfix [flair, t1, t1ce, t2] ['flair', 't1', 't1ce', 't2']
data label_postfix seg seg
data file_postfix nii.gz nii.gz
data with_ground_truth True True
data batch_size 5 5
data data_shape [4, 144, 144, 19] [4, 144, 144, 19]
data label_shape [1, 144, 144, 11] [1, 144, 144, 11]
data label_convert_source [0, 1, 2, 4] [0, 1, 2, 4]
data label_convert_target [0, 1, 1, 1] [0, 1, 1, 1]
data batch_slice_direction axial axial
data train_with_roi_patch False False
data label_roi_mask  None
data roi_patch_margin  None
network net_type MSNet MSNet
network net_name MSNet_WT32 MSNet_WT32
network downsample_twice True True
network class_num 2 2
training learning_rate 1e-3 0.001
training decay 1e-7 1e-07
training maximal_iteration 20000 20000
training snapshot_iteration 5000 5000
training start_

## 2. construct graph

In [21]:
full_data_shape  = [batch_size] + config_data['data_shape']
full_label_shape = [batch_size] + config_data['label_shape']
x = torch.zeros(full_data_shape, dtype=torch.float32, requires_grad=True)
y = torch.zeros(full_label_shape, dtype=torch.float32, requires_grad=True)
w = torch.zeros(full_label_shape, dtype=torch.float32, requires_grad=True)
   
w_regularizer = config_train.get('decay', 1e-7)
b_regularizer = config_train.get('decay', 1e-7)
net_class = NetFactory.create(net_type)
net = net_class(in_chns = full_data_shape[1], # not sure
                    num_classes = class_num,
                    w_reg = w_regularizer,
                    b_reg = b_regularizer)
predicty = net(x)


f1 torch.Size([5, 4, 144, 144, 19])
f1 torch.Size([5, 32, 144, 144, 19])
f1 torch.Size([5, 32, 144, 144, 17])
f1 torch.Size([5, 32, 71, 71, 17])
f1 torch.Size([5, 32, 71, 71, 17])
f1 torch.Size([5, 32, 71, 71, 15])
f2 torch.Size([5, 32, 35, 35, 13])
f3 torch.Size([5, 32, 35, 35, 11])
torch.Size([5, 32, 71, 71, 15]) torch.Size([5, 32, 71, 71, 11])
torch.Size([5, 32, 71, 71, 11])
torch.Size([5, 32, 35, 35, 13]) torch.Size([5, 32, 35, 35, 11])
torch.Size([5, 2, 143, 143, 11]) torch.Size([5, 4, 143, 143, 11]) torch.Size([5, 8, 143, 143, 11])
torch.Size([5, 14, 143, 143, 11])


In [22]:
predicty

tensor([[[[[-0.2741, -0.2741, -0.2741,  ..., -0.2741, -0.2741, -0.2741],
           [-1.0037, -1.0037, -1.0037,  ..., -1.0037, -1.0037, -1.0037],
           [-1.2208, -1.2208, -1.2208,  ..., -1.2208, -1.2208, -1.2208],
           ...,
           [-1.1790, -1.1790, -1.1790,  ..., -1.1790, -1.1790, -1.1790],
           [-0.5814, -0.5814, -0.5814,  ..., -0.5814, -0.5814, -0.5814],
           [-0.3892, -0.3892, -0.3892,  ..., -0.3892, -0.3892, -0.3892]],

          [[ 0.0617,  0.0617,  0.0617,  ...,  0.0617,  0.0617,  0.0617],
           [-0.2043, -0.2043, -0.2043,  ..., -0.2043, -0.2043, -0.2043],
           [-0.1849, -0.1849, -0.1849,  ..., -0.1849, -0.1849, -0.1849],
           ...,
           [-1.2913, -1.2913, -1.2913,  ..., -1.2913, -1.2913, -1.2913],
           [-1.2203, -1.2203, -1.2203,  ..., -1.2203, -1.2203, -1.2203],
           [ 0.0426,  0.0426,  0.0426,  ...,  0.0426,  0.0426,  0.0426]],

          [[-0.2761, -0.2761, -0.2761,  ..., -0.2761, -0.2761, -0.2761],
           [-0.

In [105]:
full_data_shape

[5, 19, 4, 144, 144]

In [113]:
config_train

{'learning_rate': 0.001,
 'decay': 1e-07,
 'maximal_iteration': 20000,
 'snapshot_iteration': 5000,
 'start_iteration': 0,
 'test_iteration': 100,
 'test_step': 10,
 'model_pre_trained': None,
 'model_save_prefix': 'model17/msnet_wt32'}

In [114]:
dataloader = DataLoader(config_data)
dataloader.load_data()


* deprecated from version: 3.0
* Will raise <class 'nibabel.deprecator.ExpiredDeprecationError'> as of version: 5.0
  data = img.get_data()

* deprecated from version: 3.0
* Will raise <class 'nibabel.deprecator.ExpiredDeprecationError'> as of version: 5.0
  data = img.get_data()

* deprecated from version: 3.0
* Will raise <class 'nibabel.deprecator.ExpiredDeprecationError'> as of version: 5.0
  data = img.get_data()

* deprecated from version: 3.0
* Will raise <class 'nibabel.deprecator.ExpiredDeprecationError'> as of version: 5.0
  data = img.get_data()

* deprecated from version: 3.0
* Will raise <class 'nibabel.deprecator.ExpiredDeprecationError'> as of version: 5.0
  data = img.get_data()

* deprecated from version: 3.0
* Will raise <class 'nibabel.deprecator.ExpiredDeprecationError'> as of version: 5.0
  data = img.get_data()

* deprecated from version: 3.0
* Will raise <class 'nibabel.deprecator.ExpiredDeprecationError'> as of version: 5.0
  data = img.get_data()

* deprecated

Data load, 100.0% finished


In [116]:
dic = dataloader.get_subimage_batch()
dic

{'images': array([[[[[ 5.43297231e-01,  4.85446453e-01, -1.09068483e-01,
            -1.28326559e+00],
           [ 1.49805272e+00,  2.62325346e-01, -1.84744799e+00,
             1.20542037e+00],
           [-4.36174989e-01,  6.58175886e-01,  1.16142821e+00,
            -3.09877038e+00],
           ...,
           [-1.29220104e+00, -7.32044995e-01,  9.02587593e-01,
             2.73444414e-01],
           [ 4.20852862e-02, -1.46292734e+00,  1.23183382e+00,
             1.33232594e-01],
           [-4.04372439e-02, -5.39924145e-01, -1.74720681e+00,
            -1.36388624e+00]],
 
          [[-1.50022341e-03, -6.56122342e-02, -4.68413115e-01,
            -7.61003852e-01],
           [-1.51097703e+00, -1.36721110e+00,  5.04708588e-01,
             9.26590785e-02],
           [ 1.67009068e+00, -1.50380397e+00,  1.15556195e-01,
            -3.33671361e-01],
           ...,
           [ 3.01709026e-01,  3.68712306e-01,  1.29268634e+00,
             9.95337307e-01],
           [ 1.61970401e+

In [119]:
dic['images'].shape

(5, 32, 144, 144, 4)

In [40]:
a = torch.zeros([5, 2, 4, 143, 137], dtype=torch.float32, requires_grad=True)
b = torch.zeros([5, 4, 4, 143, 127], dtype=torch.float32, requires_grad=True)
c = torch.zeros([5, 8, 4, 143, 119], dtype=torch.float32, requires_grad=True)

In [41]:
torch.cat([a, b, c], 1).shape

RuntimeError: torch.cat(): Sizes of tensors must match except in dimension 1. Got 2 and 3 in dimension 3 (The offending index is 1)