In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import os
import random
import re

from tqdm import tqdm
import time

import pydicom as dicom
import nibabel as nib
import SimpleITK as sitk
import monai

import torch
import torch.nn as nn
import torch.optim as optim


In [2]:
SEED = 344
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def seed_everything(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and
    every CUDA device), exports ``PYTHONHASHSEED`` and switches cuDNN to
    deterministic behaviour.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    # PYTHONHASHSEED is read at interpreter start-up, so this only influences
    # Python processes launched after this point (e.g. worker subprocesses).
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed all GPUs, not just the current one; this is silently ignored when
    # CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True  # Only use deterministic cuDNN algorithms
    # The cuDNN autotuner may pick a different algorithm from run to run, which
    # defeats the seeding above, so make sure it is switched off.
    torch.backends.cudnn.benchmark = False
    print('Finish seeding with seed {}'.format(seed))


seed_everything(SEED)
print('Training on device {}'.format(device))

Finish seeding with seed 344
Training on device cuda


In [3]:
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F

class inputBlock(nn.Module):
    """Stem shared by the 3D ResNets: 7x7x7 convolution, batch norm, ReLU, max pool.

    Both the convolution and the pooling use stride 2, and the result always
    has 64 channels.

    Args:
        ch_in: Number of channels of the input volume.
    """

    def __init__(self, ch_in):
        super().__init__()
        stem_layers = [
            nn.Conv3d(ch_in, 64, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm3d(64),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=3, stride=2, padding=1),
        ]
        self.seq = nn.Sequential(*stem_layers)

    def forward(self, x):
        """Apply the stem to ``x`` and return the resulting feature map."""
        stem_out = self.seq(x)
        return stem_out
    
    
class BasicBlock(nn.Module):
    """Residual block made of two 3x3x3 convolutions, each followed by batch norm.

    When the input and output shapes match (``ch_in == ch_out`` and
    ``stride == 1``) the shortcut is the identity and ``conv_skip`` is ``None``.
    Otherwise the shortcut is a 1x1x1 convolution with the same stride plus
    batch norm, so that it can be added to the residual branch.

    Args:
        ch_in: Number of input channels.
        ch_out: Number of output channels.
        stride: Stride of the first convolution (and of the projection shortcut).
    """

    def __init__(self, ch_in, ch_out, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv3d(ch_in, ch_out, kernel_size=3, padding=1, stride=stride)
        self.conv2 = nn.Conv3d(ch_out, ch_out, kernel_size=3, padding=1, stride=1)

        self.bn1 = nn.BatchNorm3d(ch_out)
        self.bn2 = nn.BatchNorm3d(ch_out)

        if ch_in == ch_out and stride == 1:
            self.conv_skip = None
        else:
            self.conv_skip = nn.Sequential(
                nn.Conv3d(ch_in, ch_out, kernel_size=1, stride=stride),
                nn.BatchNorm3d(ch_out)
            )

    def forward(self, x):
        """Return ``relu(residual(x) + shortcut(x))``."""
        fx = F.relu(self.bn1(self.conv1(x)))
        fx = self.bn2(self.conv2(fx))
        # Compare against None explicitly: the truthiness of an nn.Sequential
        # is its length, which is an accident of the container type rather
        # than a statement about whether a projection exists.
        shortcut = x if self.conv_skip is None else self.conv_skip(x)
        return F.relu(fx + shortcut)
    
class ResNet18(nn.Module):
    """3D ResNet-18: stem, eight ``BasicBlock`` modules and a pooled linear head.

    Args:
        ch_in: Number of channels of the input volume.
        ch_out: Number of output features of the final linear layer.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        self.input_conv = inputBlock(ch_in=ch_in)

        # (input channels, output channels, stride) of each residual block, in order.
        block_specs = [
            (64, 64, 1),
            (64, 64, 1),
            (64, 128, 2),
            (128, 128, 1),
            (128, 256, 2),
            (256, 256, 1),
            (256, 512, 2),
            (512, 512, 1),
        ]
        self.blocks = nn.Sequential(
            *[BasicBlock(ch_in=c_in, ch_out=c_out, stride=step)
              for c_in, c_out, step in block_specs]
        )

        # Global average pooling collapses the spatial dimensions before the linear layer.
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool3d(output_size=(1, 1, 1)),
            nn.Flatten(),
            nn.Linear(in_features=512, out_features=ch_out),
        )

    def forward(self, x):
        """Return the raw outputs of the linear head for the batch ``x``."""
        features = self.blocks(self.input_conv(x))
        scores = self.classifier(features)
        # squeeze(-1) only removes the last dimension when it has size 1,
        # i.e. when the head has a single output.
        return scores.squeeze(-1)
    
    
class ResNet34(nn.Module):
    """3D ResNet-34: stem, sixteen ``BasicBlock`` modules and a pooled linear head.

    Args:
        ch_in: Number of channels of the input volume.
        ch_out: Number of output features of the final linear layer.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        self.input_conv = inputBlock(ch_in=ch_in)

        # One row per stage:
        # (channels entering the stage, channels leaving it, block count, stride of first block).
        stage_plan = (
            (64, 64, 3, 1),
            (64, 128, 4, 2),
            (128, 256, 6, 2),
            (256, 512, 3, 2),
        )
        residual_blocks = []
        for width_in, width_out, depth, first_stride in stage_plan:
            # Only the first block of a stage changes width or resolution.
            residual_blocks.append(
                BasicBlock(ch_in=width_in, ch_out=width_out, stride=first_stride)
            )
            for _ in range(depth - 1):
                residual_blocks.append(
                    BasicBlock(ch_in=width_out, ch_out=width_out, stride=1)
                )
        self.blocks = nn.Sequential(*residual_blocks)

        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool3d(output_size=(1, 1, 1)),
            nn.Flatten(),
            nn.Linear(in_features=512, out_features=ch_out),
        )

    def forward(self, x):
        """Return the raw outputs of the linear head for the batch ``x``."""
        features = self.blocks(self.input_conv(x))
        scores = self.classifier(features)
        # No softmax is applied here; the linear outputs are returned as they are.
        # squeeze(-1) only has an effect when the head has a single output.
        return scores.squeeze(-1)


class BottleNeck(nn.Module):
    """Bottleneck residual block: 1x1x1 -> 3x3x3 -> 1x1x1 convolutions.

    Unlike ``BasicBlock``, the shortcut is always a 1x1x1 convolution followed
    by batch norm, even when the input and output shapes already match.

    Args:
        channels: Indexable with three entries: input channels, inner
            (bottleneck) channels and output channels.
        stride: Stride of the 3x3x3 convolution and of the shortcut convolution.
    """

    def __init__(self, channels, stride=1):
        super().__init__()
        c_in, c_mid, c_out = channels[0], channels[1], channels[2]

        # Residual branch; the stride is carried by the middle convolution.
        self.conv1 = nn.Conv3d(c_in, c_mid, kernel_size=1, stride=1)
        self.conv2 = nn.Conv3d(c_mid, c_mid, kernel_size=3, padding=1, stride=stride)
        self.conv3 = nn.Conv3d(c_mid, c_out, kernel_size=1, stride=1)

        self.bn1 = nn.BatchNorm3d(c_mid)
        self.bn2 = nn.BatchNorm3d(c_mid)
        self.bn3 = nn.BatchNorm3d(c_out)

        # Projection shortcut that matches the residual branch's channels and stride.
        projection = [
            nn.Conv3d(c_in, c_out, kernel_size=1, stride=stride),
            nn.BatchNorm3d(c_out),
        ]
        self.conv_skip = nn.Sequential(*projection)

    def forward(self, x):
        """Return ``relu(residual(x) + projection(x))``."""
        reduced = F.relu(self.bn1(self.conv1(x)))
        mixed = F.relu(self.bn2(self.conv2(reduced)))
        expanded = self.bn3(self.conv3(mixed))
        summed = expanded + self.conv_skip(x)
        return F.relu(summed)
    
class MiddleBottleNeck(nn.Module):
    """Chain of ``num`` identical stride-1 ``BottleNeck`` blocks (1024 -> 256 -> 1024).

    Args:
        num: How many blocks to stack.
    """

    def __init__(self, num):
        super().__init__()
        self.blocks = nn.Sequential(
            *(BottleNeck(channels=(1024, 256, 1024), stride=1) for _ in range(num))
        )

    def forward(self, x):
        """Pass ``x`` through every block in order."""
        stacked_out = self.blocks(x)
        return stacked_out
    
class ResNet50(nn.Module):
    """3D ResNet-50: stem, sixteen ``BottleNeck`` modules and a pooled linear head.

    Args:
        ch_in: Number of channels of the input volume.
        ch_out: Number of output features of the final linear layer.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        self.input_conv = inputBlock(ch_in=ch_in)

        # One row per stage:
        # (channels entering the stage, bottleneck width, channels leaving it,
        #  block count, stride of first block).
        stage_table = (
            (64, 64, 256, 3, 1),
            (256, 128, 512, 4, 2),
            (512, 256, 1024, 6, 2),
            (1024, 512, 2048, 3, 2),
        )
        bottlenecks = []
        for width_in, width_mid, width_out, depth, first_stride in stage_table:
            # The first block of a stage widens the features and may downsample.
            bottlenecks.append(
                BottleNeck(channels=(width_in, width_mid, width_out), stride=first_stride)
            )
            for _ in range(depth - 1):
                bottlenecks.append(
                    BottleNeck(channels=(width_out, width_mid, width_out), stride=1)
                )
        self.blocks = nn.Sequential(*bottlenecks)

        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool3d(output_size=(1, 1, 1)),
            nn.Flatten(),
            nn.Linear(in_features=2048, out_features=ch_out),
        )

    def forward(self, x):
        """Return the raw outputs of the linear head for the batch ``x``."""
        features = self.blocks(self.input_conv(x))
        scores = self.classifier(features)
        # squeeze(-1) only has an effect when the head has a single output.
        return scores.squeeze(-1)
    
class ResNet101(nn.Module):
    """3D ResNet-101: stem, bottleneck stages and a pooled linear head.

    The 1024-channel stage is one widening ``BottleNeck`` followed by a
    ``MiddleBottleNeck`` holding 22 further blocks.

    Args:
        ch_in: Number of channels of the input volume.
        ch_out: Number of output features of the final linear layer.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        self.input_conv = inputBlock(ch_in=ch_in)

        # 256-channel stage
        stages = [BottleNeck(channels=(64, 64, 256), stride=1)]
        stages += [BottleNeck(channels=(256, 64, 256), stride=1) for _ in range(2)]
        # 512-channel stage
        stages.append(BottleNeck(channels=(256, 128, 512), stride=2))
        stages += [BottleNeck(channels=(512, 128, 512), stride=1) for _ in range(3)]
        # 1024-channel stage
        stages.append(BottleNeck(channels=(512, 256, 1024), stride=2))
        stages.append(MiddleBottleNeck(num=22))
        # 2048-channel stage
        stages.append(BottleNeck(channels=(1024, 512, 2048), stride=2))
        stages += [BottleNeck(channels=(2048, 512, 2048), stride=1) for _ in range(2)]
        self.blocks = nn.Sequential(*stages)

        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool3d(output_size=(1, 1, 1)),
            nn.Flatten(),
            nn.Linear(in_features=2048, out_features=ch_out),
        )

    def forward(self, x):
        """Return the raw outputs of the linear head for the batch ``x``."""
        features = self.blocks(self.input_conv(x))
        scores = self.classifier(features)
        # squeeze(-1) only has an effect when the head has a single output.
        return scores.squeeze(-1)
    

In [5]:
# Build the 3D ResNet-18 with 1 input channel and 9 outputs and move it to `device`.
net = ResNet18(1, 9).to(device)

# Print a per-layer table of output shapes and parameter counts.
# `input_size` does not include the batch dimension (shown as -1 in the table).
from torchsummary import summary
summary(net, input_size=(1, 160, 128, 128))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1       [-1, 64, 80, 64, 64]          22,016
       BatchNorm3d-2       [-1, 64, 80, 64, 64]             128
              ReLU-3       [-1, 64, 80, 64, 64]               0
         MaxPool3d-4       [-1, 64, 40, 32, 32]               0
        inputBlock-5       [-1, 64, 40, 32, 32]               0
            Conv3d-6       [-1, 64, 40, 32, 32]         110,656
       BatchNorm3d-7       [-1, 64, 40, 32, 32]             128
            Conv3d-8       [-1, 64, 40, 32, 32]         110,656
       BatchNorm3d-9       [-1, 64, 40, 32, 32]             128
       BasicBlock-10       [-1, 64, 40, 32, 32]               0
           Conv3d-11       [-1, 64, 40, 32, 32]         110,656
      BatchNorm3d-12       [-1, 64, 40, 32, 32]             128
           Conv3d-13       [-1, 64, 40, 32, 32]         110,656
      BatchNorm3d-14       [-1, 64, 40,