In [None]:
import torch
from torch import nn
import numpy as np
import torch.nn.init as init
from PIL import Image
import SimpleITK as sitk
import matplotlib.pyplot as plt
import os
from os import listdir
from os.path import isfile, join, splitext
import torch.utils.data

from torch.nn import functional as F
from IPython import display
from IPython.display import clear_output
import time

In [None]:
class Timer:
    """Stopwatch that records the length of every start()/stop() interval in seconds."""

    def __init__(self):
        self.times = []  # elapsed seconds of each completed interval, in order
        self.start()

    def start(self):
        """Begin a new interval."""
        self.tik = time.time()

    def stop(self):
        """End the current interval, record it, and return its length in seconds."""
        elapsed = time.time() - self.tik
        self.times.append(elapsed)
        return elapsed

    def avg(self):
        """Mean interval length (raises ZeroDivisionError if nothing was recorded)."""
        return self.sum() / len(self.times)

    def sum(self):
        """Total recorded time in seconds."""
        return sum(self.times)

    def cumsum(self):
        """Running totals of the recorded intervals, as a plain Python list."""
        return np.cumsum(self.times).tolist()

In [None]:
def try_gpu(i = 0):
    """Return the CUDA device ``cuda:i`` if at least ``i + 1`` GPUs are visible, else the CPU device."""
    gpu_available = i < torch.cuda.device_count()
    return torch.device(f'cuda:{i}' if gpu_available else 'cpu')

In [None]:
def DiceCo(X,y):
    """Dice coefficient 2*sum(X*y) / (sum(X) + sum(y)) between a prediction and a mask.

    Args:
        X: predicted (soft or binary) foreground values.
        y: ground-truth mask; must broadcast against ``X``.

    Returns:
        A scalar of the same kind as ``X.sum()`` (a 0-dim tensor for tensor inputs).
        When both inputs sum to zero (two empty masks) the result is defined as 1.0
        (perfect agreement) instead of the 0/0 = NaN the plain formula would give;
        a NaN would otherwise poison any average of Dice scores.
    """
    intersection = (X * y).sum()
    denominator = X.sum() + y.sum()
    if denominator == 0:
        # denominator is 0 here, so this is 1.0 with the same type/device as the normal result
        return denominator + 1.0
    return (2 * intersection) / denominator

In [None]:
def evaluate_accuracy_gpu(model, data_iter, device=None):
    """Mean per-batch Dice coefficient (``DiceCo``) of ``model`` over ``data_iter``.

    If ``model`` is an ``nn.Module`` it is put into eval mode and, when no ``device``
    is given, the device of its first parameter is used. Inputs may be a tensor or a
    list of tensors; each batch contributes one Dice value and the result is their
    average (ZeroDivisionError for an empty iterator).
    """
    if isinstance(model, nn.Module):
        model.eval()
        device = device or next(model.parameters()).device
    totals = Accumulator(2)  # [sum of per-batch Dice, number of batches]
    with torch.no_grad():
        for X, y in data_iter:
            X = [part.to(device) for part in X] if isinstance(X, list) else X.to(device)
            y = y.to(device)
            totals.add(DiceCo(model(X), y), 1)
    return totals[0] / totals[1]

In [None]:
def accuracy(y_hat, y):
    """Return ``2 * sum(y_hat * y)`` as a Python float.

    Note: this is only the numerator of the Dice coefficient, not a normalised score.
    """
    overlap = (y_hat * y).sum()
    return float(2 * overlap)

In [None]:
class Accumulator:
    """Keep running sums over ``n`` variables (each value is converted with ``float``)."""

    def __init__(self,n):
        self.data = [0.0 for _ in range(n)]

    def add(self, *args):
        """Add ``args`` position-wise to the running sums (zip semantics: extra sums are dropped)."""
        updated = []
        for total, value in zip(self.data, args):
            updated.append(total + float(value))
        self.data = updated

    def reset(self):
        """Set every running sum back to 0.0."""
        self.data = [0.0] * len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

In [None]:
def set_axes(axes, xlabel, ylabel, xlim, ylim, xscale, yscale, legend):
    """Apply labels, limits and scales to a matplotlib Axes, add a legend if given, and turn on the grid."""
    setters = (
        ('set_xlabel', xlabel),
        ('set_ylabel', ylabel),
        ('set_xlim', xlim),
        ('set_ylim', ylim),
        ('set_xscale', xscale),
        ('set_yscale', yscale),
    )
    for method_name, value in setters:
        getattr(axes, method_name)(value)
    if legend:
        axes.legend(legend)
    axes.grid()

In [None]:
class Animator:
    """Plot several curves incrementally in a notebook, redrawing the figure on every ``add``.

    Args:
        xlabel, ylabel, xlim, ylim, xscale, yscale, legend: forwarded to ``set_axes``
            for the first subplot after every redraw.
        fmts: matplotlib format strings, one per curve (extra curves are not drawn).
        nrows, ncols: subplot grid; only the first Axes is drawn on.
        figsize: figure size in inches, passed to ``plt.subplots``.
    """
    def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None, ylim=None, xscale='linear', yscale='linear',fmts=('-','m--','g-.','r:'),nrows=1,ncols=1,figsize=(5,3)):
        if legend is None:
            legend = []
        # Honour the caller's figsize (it was previously ignored in favour of a hard-coded (5, 3)).
        self.fig, self.axes = plt.subplots(nrows, ncols, figsize=figsize)
        if nrows * ncols == 1:
            # plt.subplots returns a bare Axes for a 1x1 grid; wrap it so self.axes[0] always works
            self.axes = [self.axes, ]
        # Capture the axis configuration in a closure so it can be re-applied after cla()
        self.config_axes = lambda: set_axes(self.axes[0], xlabel, ylabel, xlim, ylim, xscale, yscale, legend)
        self.X, self.Y, self.fmts = None, None, fmts

    def add(self, x, y):
        """Append one point per curve and redraw.

        ``y`` is a scalar or a sequence with one value per curve; ``x`` is a scalar
        (shared by all curves) or a matching sequence. Points whose x or y is None
        are skipped, which lets a curve be updated less often than the others.
        """
        if not hasattr(y, "__len__"):
            y = [y]
        n = len(y)
        if not hasattr(x, "__len__"):
            x = [x] * n
        if not self.X:
            self.X = [[] for _ in range(n)]
        if not self.Y:
            self.Y = [[] for _ in range(n)]
        for i, (a, b) in enumerate(zip(x, y)):
            if a is not None and b is not None:
                self.X[i].append(a)
                self.Y[i].append(b)
        self.axes[0].cla()
        for xs, ys, fmt in zip(self.X, self.Y, self.fmts):
            self.axes[0].plot(xs, ys, fmt)
        self.config_axes()
        display.display(self.fig)
        display.clear_output(wait=True)

In [None]:
# Experiment configuration; all paths are resolved against the current working directory.
basePath = os.getcwd()

params = {
    # parameters of the DataManager
    'DataManagerParams': {
        # output voxel spacing used when resampling (passed to SetOutputSpacing)
        'dstRes': np.asarray([1, 1, 1.5], dtype=float),
        # size of the resampled and centre-cropped volume
        'VolSize': np.asarray([128, 128, 64], dtype=int),
        # whether to rotate the volume according to the direction stored in the image header; not recommended
        'normDir': False,
    },
    # parameters of the algorithm
    'ModelParams': {
        'dirTrain': os.path.join(basePath, 'training_data'),
        'dirTest': os.path.join(basePath, 'test_data'),
        # where results are saved (relative to the base path)
        'dirResult': os.path.join(basePath, 'Results'),
    },
}

In [None]:
class DataManager(object):
    """Load image volumes (and optional ground-truth masks) from a folder with SimpleITK
    and resample/crop them into fixed-size numpy arrays.

    Args:
        srcFolder: folder holding the images. A ground-truth file is expected next to
            its image, named ``<image stem>_segmentation<image extension>``.
        resultsDir: output folder (stored only; not used by the methods below).
        parameters: dict with ``'dstRes'`` (3 output spacings), ``'VolSize'`` (numpy int
            array with the 3 output sizes) and ``'normDir'`` (bool: undo the image
            direction matrix while resampling).
    """
    params = None
    srcFolder = None
    resultsDir = None

    fileList = None
    gtList = None

    sitkImages = None
    sitkGT = None
    meanIntensityTrain = None

    def __init__(self, srcFolder, resultsDir, parameters):
        self.params = parameters
        self.srcFolder = srcFolder
        self.resultsDir = resultsDir

    def createImageFileList(self):
        """List the image files of ``srcFolder`` into ``self.fileList``.

        Sub-directories and any file whose name contains ``'segmentation'`` or ``'raw'``
        are skipped. The list is sorted so the sample order does not depend on the
        arbitrary order returned by ``os.listdir``.
        """
        self.fileList = sorted(
            f for f in listdir(self.srcFolder)
            if isfile(join(self.srcFolder, f)) and 'segmentation' not in f and 'raw' not in f
        )
        print('FILE LIST: ' + str(self.fileList))

    def createGTFileList(self):
        """Derive ``self.gtList``: the ground-truth file name (``<stem>_segmentation<ext>``) of each image."""
        self.gtList = list()
        for f in self.fileList:
            filename, ext = splitext(f)
            self.gtList.append(filename + '_segmentation' + ext)
        print('GT LIST: ' + str(self.gtList))

    def loadImages(self):
        """Read every listed image as float32 rescaled to [0, 1] into ``self.sitkImages``.

        Also sets ``meanIntensityTrain`` to the average of the per-image mean intensities.

        Raises:
            ValueError: if there is no image to load (instead of a ZeroDivisionError).
        """
        if not self.fileList:
            raise ValueError('No image files to load from ' + str(self.srcFolder)
                             + ' (empty folder, or createImageFileList not called)')
        self.sitkImages = dict()
        rescalFilt = sitk.RescaleIntensityImageFilter()
        rescalFilt.SetOutputMaximum(1)
        rescalFilt.SetOutputMinimum(0)

        stats = sitk.StatisticsImageFilter()
        m = 0.
        for f in self.fileList:
            self.sitkImages[f] = rescalFilt.Execute(sitk.Cast(sitk.ReadImage(join(self.srcFolder, f)), sitk.sitkFloat32))
            stats.Execute(self.sitkImages[f])
            m += stats.GetMean()
        self.meanIntensityTrain = m / len(self.sitkImages)

    def loadGT(self):
        """Read every ground-truth file as a float32 binary mask (voxel > 0.5) into ``self.sitkGT``."""
        self.sitkGT = dict()
        for f in self.gtList:
            self.sitkGT[f] = sitk.Cast(sitk.ReadImage(join(self.srcFolder, f)) > 0.5, sitk.sitkFloat32)

    def loadTrainingData(self):
        """Load images and their ground truth."""
        self.createImageFileList()
        self.createGTFileList()
        self.loadImages()
        self.loadGT()

    def loadTestData(self):
        """Load images only (no ground truth)."""
        self.createImageFileList()
        self.loadImages()

    def getNumpyImages(self):
        """Resampled/cropped images as numpy arrays (linear interpolation)."""
        dat = self.getNumpyData(self.sitkImages, sitk.sitkLinear)
        return dat

    def getNumpyGT(self):
        """Resampled/cropped masks, linearly interpolated then re-binarised at 0.5 as float32."""
        dat = self.getNumpyData(self.sitkGT, sitk.sitkLinear)
        for key in dat:
            dat[key] = (dat[key] > 0.5).astype(dtype=np.float32)
        return dat

    def getNumpyData(self, dat, method):
        """Resample each SimpleITK image in ``dat`` to ``dstRes`` spacing and centre-crop it to ``VolSize``.

        Args:
            dat: dict mapping a key to a SimpleITK image.
            method: SimpleITK interpolator (e.g. ``sitk.sitkLinear``).

        Returns:
            dict with the same keys; values are float64 arrays of shape ``VolSize``,
            transposed with ``[2, 1, 0]`` because ``GetArrayFromImage`` returns the axes
            in reverse order relative to ``GetSize``.
        """
        ret = dict()
        for key in dat:
            img = dat[key]
            # Size after resampling to the target spacing, never smaller than VolSize so
            # that the centred crop below always fits inside the resampled image.
            factor = np.asarray(img.GetSpacing()) / np.asarray(self.params['dstRes'])
            factorSize = np.asarray(img.GetSize() * factor, dtype=float)
            newSize = np.max([factorSize, self.params['VolSize']], axis=0)
            newSize = newSize.astype(dtype=int)

            # Optional rotation according to the image direction, applied only when normDir is set
            T = sitk.AffineTransform(3)
            T.SetMatrix(img.GetDirection())

            resampler = sitk.ResampleImageFilter()
            resampler.SetReferenceImage(img)
            resampler.SetOutputSpacing([self.params['dstRes'][0], self.params['dstRes'][1], self.params['dstRes'][2]])
            resampler.SetSize(newSize.tolist())
            resampler.SetInterpolator(method)
            if self.params['normDir']:
                resampler.SetTransform(T.GetInverse())
            imgResampled = resampler.Execute(img)

            # Centred crop of exactly VolSize voxels
            imgCentroid = np.asarray(newSize, dtype=float) / 2.0
            imgStartPx = (imgCentroid - self.params['VolSize'] / 2.0).astype(dtype=int)
            regionExtractor = sitk.RegionOfInterestImageFilter()
            regionExtractor.SetSize(self.params['VolSize'].astype(dtype=int).tolist())
            regionExtractor.SetIndex(imgStartPx.tolist())
            imgResampledCropped = regionExtractor.Execute(imgResampled)
            # Reverse the array axes back to GetSize order (original note: transverse/sagittal/coronal)
            ret[key] = np.transpose(sitk.GetArrayFromImage(imgResampledCropped).astype(dtype=float), [2, 1, 0])
        return ret

In [None]:
# Data manager for the training set: loads the images and their "_segmentation" ground truth.
dataManagerTrain = DataManager(
    srcFolder=params['ModelParams']['dirTrain'],
    resultsDir=params['ModelParams']['dirResult'],
    parameters=params['DataManagerParams'],
)
dataManagerTrain.loadTrainingData()

In [None]:
# Sanity check: every training image must have a ground-truth volume.
howManyImages_forTraining, howManyGT_forTraining = len(dataManagerTrain.sitkImages), len(dataManagerTrain.sitkGT)

assert howManyGT_forTraining == howManyImages_forTraining

print(f"The dataset has shape: data - {howManyImages_forTraining}. labels - {howManyGT_forTraining}")

In [None]:
# Resample/crop every training volume and mask, then standardise each volume with the
# mean/std of its strictly positive voxels (the statistics are applied to the whole volume).
numpyImages_Train = dataManagerTrain.getNumpyImages()
numpyGT_Train = dataManagerTrain.getNumpyGT()

for key, volume in numpyImages_Train.items():
    foreground = volume[volume > 0]
    mean_Train = np.mean(foreground)
    std_Train = np.std(foreground)
    volume -= mean_Train  # in place: updates numpyImages_Train[key]
    volume /= std_Train

# Convert to float32 tensors and stack them along a new leading axis (one entry per volume).
tensorImages_Train = {key: torch.tensor(arr).float() for key, arr in numpyImages_Train.items()}
tensorGT_Train = {key: torch.tensor(mask).float() for key, mask in numpyGT_Train.items()}

merged_tensorImages_Train = torch.stack([*tensorImages_Train.values()])
merged_tensorGT_Train = torch.stack([*tensorGT_Train.values()])

In [None]:
# Data manager for the test set. loadTrainingData (not loadTestData) is used on purpose so the
# ground-truth masks are loaded as well; they are needed to score the test Dice.
dataManagerTest = DataManager(
    srcFolder=params['ModelParams']['dirTest'],
    resultsDir=params['ModelParams']['dirResult'],
    parameters=params['DataManagerParams'],
)
dataManagerTest.loadTrainingData()

In [None]:
# Sanity check: every test image must have a ground-truth volume.
howManyImages_forTest, howManyGT_forTest = len(dataManagerTest.sitkImages), len(dataManagerTest.sitkGT)

assert howManyGT_forTest == howManyImages_forTest

print(f"The dataset has shape: data - {howManyImages_forTest}. labels - {howManyGT_forTest}")

In [None]:
# Same preprocessing as for the training set: resample/crop, then per-volume standardisation
# using the statistics of the strictly positive voxels.
numpyImages_Test = dataManagerTest.getNumpyImages()
numpyGT_Test = dataManagerTest.getNumpyGT()

for key, volume in numpyImages_Test.items():
    foreground = volume[volume > 0]
    mean_Test = np.mean(foreground)
    std_Test = np.std(foreground)
    volume -= mean_Test  # in place: updates numpyImages_Test[key]
    volume /= std_Test

# Convert to float32 tensors and stack them along a new leading axis (one entry per volume).
tensorImages_Test = {key: torch.tensor(arr).float() for key, arr in numpyImages_Test.items()}
tensorGT_Test = {key: torch.tensor(mask).float() for key, mask in numpyGT_Test.items()}

merged_tensorImages_Test = torch.stack([*tensorImages_Test.values()])
merged_tensorGT_Test = torch.stack([*tensorGT_Test.values()])

In [None]:
# Pair each volume with its mask.
torch_dataset_Train = torch.utils.data.TensorDataset(merged_tensorImages_Train, merged_tensorGT_Train)
torch_dataset_Test = torch.utils.data.TensorDataset(merged_tensorImages_Test, merged_tensorGT_Test)

# One volume per optimisation step.
BATCH_SIZE = 1

# Training order is reshuffled every epoch; test order is fixed.
# (num_workers=2 could be passed to load batches in parallel.)
train_iter = torch.utils.data.DataLoader(dataset=torch_dataset_Train, batch_size=BATCH_SIZE, shuffle=True)
test_iter = torch.utils.data.DataLoader(dataset=torch_dataset_Test, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
def softmax(X, Y):
    """Two-class softmax: probability of the class whose logits are ``Y``.

    Mathematically equal to ``exp(Y) / (exp(X) + exp(Y))`` element-wise, but computed
    as ``sigmoid(Y - X)``. The direct form overflows to inf/inf = NaN once the logits
    exceed ~88 (float32); the sigmoid form stays finite for any input.

    Args:
        X: logits of the other class.
        Y: logits of the class whose probability is returned (same/broadcastable shape).
    """
    return torch.sigmoid(Y - X)

In [None]:
class Residual(nn.Module):
    """Residual stage: one to three 5x5x5 convolutions plus a skip connection.

    Args:
        input_channels: channels of the incoming tensor.
        num_channels: channels produced by every convolution of the stage.
        use_1x1conv: project the skip path with a 1x1x1 convolution (required when
            ``input_channels != num_channels`` or ``strides != 1``).
        strides: stride of the first 5x5x5 convolution and of the 1x1x1 projection.
        shortcuts: number of 5x5x5 convolutions (2 or 3; any other value means 1).
            The second and third convolutions are each preceded by a PReLU.
    """
    def __init__(self, input_channels, num_channels, use_1x1conv = False, strides = 1, shortcuts = 1):
        super().__init__()
        self.conv1 = nn.Conv3d(input_channels, num_channels, kernel_size=5, padding=2, stride=strides)
        # Creation order (conv1..conv4) and attribute names are kept, so parameter
        # initialisation order and state_dict keys are unchanged.
        self.conv2 = self._prelu_conv(num_channels) if shortcuts in (2, 3) else None
        self.conv3 = self._prelu_conv(num_channels) if shortcuts == 3 else None
        if use_1x1conv:
            self.conv4 = nn.Conv3d(input_channels, num_channels, kernel_size=1, stride=strides)
        else:
            self.conv4 = None

    @staticmethod
    def _prelu_conv(channels):
        """PReLU followed by a shape-preserving 5x5x5 convolution."""
        return nn.Sequential(
            nn.PReLU(),
            nn.Conv3d(channels, channels, kernel_size=5, padding=2, stride=1))

    def forward(self, X):
        # Explicit `is not None` checks: truthiness of an nn.Module depends on whether it
        # defines __len__ (an empty nn.Sequential would be falsy).
        Y = self.conv1(X)
        if self.conv2 is not None:
            Y = self.conv2(Y)
            if self.conv3 is not None:
                Y = self.conv3(Y)
        shortcut = self.conv4(X) if self.conv4 is not None else X
        return Y + shortcut

In [None]:
# Module-level residual stages b1..b7: (input channels, output channels, number of 5x5x5 convs).
# Only b1 changes the channel count (1 -> 16), so only it needs the 1x1x1 skip projection.
b1 = Residual(input_channels=1, num_channels=16, use_1x1conv=True)
b2 = Residual(input_channels=32, num_channels=32)
b3 = Residual(input_channels=64, num_channels=64, shortcuts=2)
b4 = Residual(input_channels=128, num_channels=128, shortcuts=3)
b5 = Residual(input_channels=256, num_channels=256, shortcuts=3)
b6 = Residual(input_channels=128, num_channels=128, shortcuts=2)
b7 = Residual(input_channels=64, num_channels=64, shortcuts=1)

In [None]:
class EndToEndModel(nn.Module):
    """V-Net-like 3-D encoder/decoder that outputs a per-voxel probability map.

    Every stage builds its own ``Residual`` block. Previously the module-level
    ``b1..b7`` were reused, so all model instances shared weights, and the encoder and
    decoder shared them too (``b5`` in layer5 and layer6, ``b2`` in layer2 and layer9).
    Submodule positions inside every ``layerN`` are unchanged, so state_dict keys
    are unchanged.

    Args:
        in_channel: number of channels of the input volume (was previously ignored
            and fixed to 1).

    Shapes:
        The channel axis is dim -4. The forward pass therefore accepts an unbatched
        ``(C, D, H, W)`` volume, returning ``(D, H, W)``, or a batched
        ``(N, C, D, H, W)`` tensor, returning ``(N, D, H, W)``. D, H and W must be
        divisible by 16, because there are four stride-2 down-samplings whose outputs
        are concatenated with the matching up-sampled features.
    """
    def __init__(self, in_channel):
        super(EndToEndModel, self).__init__()
        self.layer1 = nn.Sequential(
            Residual(in_channel, 16, use_1x1conv=True),
            nn.PReLU()
        )

        self.layer2 = nn.Sequential(
            nn.Conv3d(in_channels=16, out_channels=32, kernel_size=2, padding=0, stride=2),
            nn.PReLU(),
            Residual(input_channels=32, num_channels=32),
            nn.PReLU()
        )

        self.layer3 = nn.Sequential(
            nn.Conv3d(in_channels=32, out_channels=64, kernel_size=2, padding=0, stride=2),
            nn.PReLU(),
            Residual(input_channels=64, num_channels=64, shortcuts=2),
            nn.PReLU()
        )

        self.layer4 = nn.Sequential(
            nn.Conv3d(in_channels=64, out_channels=128, kernel_size=2, padding=0, stride=2),
            nn.PReLU(),
            Residual(input_channels=128, num_channels=128, shortcuts=3),
            nn.PReLU()
        )

        # Bottom of the "V": down-sample, process, then up-sample back to layer4's resolution
        self.layer5 = nn.Sequential(
            nn.Conv3d(in_channels=128, out_channels=256, kernel_size=2, padding=0, stride=2),
            nn.PReLU(),
            Residual(input_channels=256, num_channels=256, shortcuts=3),
            nn.PReLU(),
            nn.ConvTranspose3d(in_channels=256, out_channels=128, kernel_size=2, stride=2),
            nn.PReLU()
        )

        # Decoder stages: input channels = up-sampled features + concatenated skip features
        self.layer6 = nn.Sequential(
            Residual(input_channels=256, num_channels=256, shortcuts=3),
            nn.PReLU(),
            nn.ConvTranspose3d(in_channels=256, out_channels=64, kernel_size=2, padding=0, stride=2),
            nn.PReLU()
        )

        self.layer7 = nn.Sequential(
            Residual(input_channels=128, num_channels=128, shortcuts=2),
            nn.PReLU(),
            nn.ConvTranspose3d(in_channels=128, out_channels=32, kernel_size=2, padding=0, stride=2),
            nn.PReLU()
        )

        self.layer8 = nn.Sequential(
            Residual(input_channels=64, num_channels=64, shortcuts=1),
            nn.PReLU(),
            nn.ConvTranspose3d(in_channels=64, out_channels=16, kernel_size=2, padding=0, stride=2),
            nn.PReLU()
        )

        # Output head: two logit channels per voxel
        self.layer9 = nn.Sequential(
            Residual(input_channels=32, num_channels=32),
            nn.PReLU(),
            nn.Conv3d(in_channels=32, out_channels=2, kernel_size=5, padding=2, stride=1),
            nn.PReLU(),
            nn.Conv3d(in_channels=2, out_channels=2, kernel_size=1, stride=1),
        )

    def forward(self, x):
        # Channel axis: dim 0 for an unbatched (C, D, H, W) input, dim 1 for (N, C, D, H, W).
        ch = -4
        y1 = x = self.layer1(x)
        y2 = x = self.layer2(x)
        y3 = x = self.layer3(x)
        y4 = x = self.layer4(x)
        x = self.layer5(x)
        x = self.layer6(torch.cat((x, y4), dim=ch))
        x = self.layer7(torch.cat((x, y3), dim=ch))
        x = self.layer8(torch.cat((x, y2), dim=ch))
        x = self.layer9(torch.cat((x, y1), dim=ch))
        # Probability of logit channel 1 versus channel 0
        logits0, logits1 = x.unbind(dim=ch)
        return softmax(X=logits0, Y=logits1)

In [None]:
# Network for single-channel input volumes; printing it lists every layer.
model = EndToEndModel(in_channel=1)
print(model)

In [None]:
class dice_loss(nn.Module):
    """Soft Dice loss: ``1 - (2*sum(p*t) + s) / (sum(p**2) + sum(t**2) + s)``.

    Args:
        smooth: the constant ``s`` added to numerator and denominator; it keeps the
            ratio finite when both prediction and target are all zeros.
    """
    def __init__(self, smooth = 0.01):
        super().__init__()
        self.smooth = smooth

    def forward(self, pred, target):
        overlap = (pred * target).sum()
        squared_norms = (pred ** 2).sum() + (target ** 2).sum()
        dice = (2 * overlap + self.smooth) / (squared_norms + self.smooth)
        return 1 - dice

In [None]:
def msra_init(module: nn.Module) -> None:
    """Initialise one module in place (meant for ``model.apply(msra_init)``).

    - Convolutions (1-D/2-D/3-D, incl. transposed): Kaiming-normal weights, zero bias.
    - Linear: Kaiming-uniform weights, zero bias.
    - BatchNorm (1-D/2-D/3-D): weights ~ N(1, 0.02), zero bias.
    - Any other module (e.g. PReLU) is left untouched.
    """
    conv_types = (nn.Conv1d, nn.ConvTranspose1d, nn.Conv2d, nn.ConvTranspose2d,
                  nn.Conv3d, nn.ConvTranspose3d)
    if isinstance(module, conv_types):
        nn.init.kaiming_normal_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Linear):
        nn.init.kaiming_uniform_(module.weight)
        if module.bias is not None:
            # was `module.bias.data.zero()`, which does not exist and raised AttributeError
            nn.init.zeros_(module.bias)
    elif isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
        # affine=False BatchNorm has no weight/bias
        if module.weight is not None:
            nn.init.normal_(module.weight, 1.0, 0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)

In [None]:
def train(model, train_iter, test_iter, num_epochs, lr, device):
    """Re-initialise ``model`` and train it with SGD (momentum 0.9) on the soft Dice loss.

    Progress is plotted with ``Animator``:
    - train loss and train Dice, refreshed about five times per epoch;
    - test Dice (``evaluate_accuracy_gpu``) after every epoch.

    Train and test scores both use ``DiceCo``. Previously "train acc" was
    ``2*overlap / numel``, which is not comparable with the test Dice.

    Args:
        model: network whose ``layer1[0]`` has a ``conv4`` 1x1x1 projection.
        train_iter, test_iter: iterables of ``(X, y)`` batches.
        num_epochs: number of passes over ``train_iter``.
        lr: SGD learning rate.
        device: device to train on.
    """
    # Initialise every module exactly once (the old per-module callback looped over all
    # modules each time it was invoked), then force the first skip projection to ones.
    model.apply(msra_init)
    init.constant_(model.layer1[0].conv4.weight, 1)
    print('training on', device)
    model.to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    loss = dice_loss()
    animator = Animator(xlabel='epoch', xlim=[1, num_epochs], legend=['train loss', 'train acc', 'test acc'])
    timer, num_batches = Timer(), len(train_iter)
    # Plot ~5 times per epoch; max(..., 1) avoids a modulo-by-zero with fewer than 5 batches.
    plot_every = max(num_batches // 5, 1)
    train_l = train_acc = test_acc = float('nan')
    num_examples = 0
    for epoch in range(num_epochs):
        metric = Accumulator(3)  # [sum of batch losses, sum of batch Dice, number of batches]
        model.train()
        for i, (X, y) in enumerate(train_iter):
            timer.start()
            optimizer.zero_grad()
            X, y = X.to(device), y.to(device)
            y_hat = model(X)
            l = loss(y_hat, y)
            l.backward()
            optimizer.step()
            with torch.no_grad():
                metric.add(l, DiceCo(y_hat, y), 1)
            timer.stop()
            num_examples += len(X)
            train_l = metric[0] / metric[2]
            train_acc = metric[1] / metric[2]
            if (i + 1) % plot_every == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches, (train_l, train_acc, None))
        test_acc = evaluate_accuracy_gpu(model, test_iter)
        animator.add(epoch + 1, (None, None, test_acc))
    print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, test acc {test_acc:.3f}')
    elapsed = timer.sum()
    throughput = num_examples / elapsed if elapsed > 0 else float('nan')
    print(f'{throughput:.1f} examples/sec ' f'on {str(device)}')

In [None]:
# Hyper-parameters: SGD learning rate and number of passes over the training set.
lr = 0.0005
num_epochs = 30
train(model, train_iter, test_iter, num_epochs=num_epochs, lr=lr, device=try_gpu())