## Prepare Dataset List

In [1]:
import os
import random
from time import time

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.autograd import Variable
from torch.utils.data import Dataset as dataset

import SimpleITK as sitk
import scipy.ndimage as ndimage

In [2]:
size = 48
lower = -350
num_organ = 13

In [3]:
class Dataset(dataset):
    def  __init__ (self , ct_dir, seg_dir):
        self.ct_list = os.listdir(ct_dir)
        self.seg_list = list(map(lambda x: x.replace('img', 'label'), self.ct_list))

        self.ct_list = list(map(lambda x: os.path.normpath(os.path.join(ct_dir, x)), self.ct_list))
        self.seg_list = list(map(lambda x: os.path.normpath(os.path.join(seg_dir, x)), self.seg_list))

    def __getitem__(self, index):
        """
        :param index:
        :return: torch.Size([B, 1, 48, 256, 256]) torch.Size([B, 48, 256, 256])
        """
        
        ct_path = self.ct_list[index]
        seg_path  = self.seg_list[index]

        # Read CT and gold standard into memory
        ct  =  sitk.ReadImage(ct_path , sitk.sitkInt16)
        seg  =  sitk.ReadImage(seg_path , sitk.sitkUInt8)

        ct_array = sitk.GetArrayFromImage(ct)
        seg_array = sitk.GetArrayFromImage(seg)

        #Randomly select 64 slices in the slice plane
        start_slice = random.randint(0, ct_array.shape[0] - size)
        end_slice = start_slice + size - 1

        ct_array = ct_array[start_slice:end_slice + 1, :, :]
        seg_array = seg_array [start_slice:end_slice + 1 , :, :]

#         # Randomly rotate within 5 degrees with probability 0.5
#         # If the angle is negative, it will rotate clockwise, if the angle is positive, it will rotate counterclockwise
        if random.uniform(0, 1) >= 0.5:
            angle = random.uniform(-5, 5)
            ct_array = ndimage.rotate(ct_array, angle, axes =(1, 2), reshape = False , cval = lower);
            seg_array = ndimage.rotate(seg_array, angle, axes = (1, 2), reshape = False, cval = 0);

        #There is a probability of 0.5 without any modification, and the remaining 0.5 randomly selects a patch with a size of 0.8-0.5 and enlarges it to 256*256
        if random.uniform(0, 1) >= 0.5:
            ct_array, seg_array = self.zoom(ct_array, seg_array, patch_size=random.uniform(0.5, 0.8))

        # After processing, convert array to tensor
        ct_array = torch.FloatTensor(ct_array).unsqueeze(0)
        seg_array = torch.FloatTensor(seg_array)
        
        # ct_array = F.interpolate(ct_array.unsqueeze(0), size=(256, 256), mode='bilinear', align_corners=False)
        # seg_array = F.interpolate(seg_array.unsqueeze(0), size=(256, 256), mode='bilinear', align_corners=False).squeeze(0)
        
        # print("ct {} seg {}".format(ct_array.shape, seg_array.shape))

        return ct_array, seg_array

    def __len__(self):
        return len(self.ct_list)

    def zoom(self, ct_array, seg_array, patch_size):
        length = int(256*patch_size)

        x1 = int(random.uniform(0, 255-length))
        y1 = int(random.uniform(0, 255-length))

        x2 = x1 + length
        y2 = y1 + length

        ct_array = ct_array[:, x1:x2 + 1, y1:y2 + 1]
        seg_array = seg_array[:, x1:x2 + 1 , y1:y2 + 1]    

        with torch.no_grad():

            ct_array = torch.FloatTensor(ct_array).unsqueeze(dim=0).unsqueeze(dim=0)
            ct_array = Variable(ct_array)
            ct_array = F.interpolate(ct_array, (size, 256, 256), mode='trilinear').squeeze().detach().numpy()

            seg_array = torch.FloatTensor(seg_array).unsqueeze(dim=0).unsqueeze(dim=0)
            seg_array = Variable(seg_array)
            seg_array = F.interpolate(seg_array, (size, 256, 256), mode='trilinear').squeeze().detach().numpy()

            return ct_array, seg_array

In [4]:
ct_dir = "D:/skripsi/Dataset/train/CT"
seg_dir = "D:/skripsi/Dataset/train/GT"

train_ds = Dataset(ct_dir, seg_dir)

In [5]:
train_ds.seg_list

['D:\\skripsi\\Dataset\\train\\GT\\label-0.nii',
 'D:\\skripsi\\Dataset\\train\\GT\\label-1.nii',
 'D:\\skripsi\\Dataset\\train\\GT\\label-10.nii',
 'D:\\skripsi\\Dataset\\train\\GT\\label-11.nii',
 'D:\\skripsi\\Dataset\\train\\GT\\label-12.nii',
 'D:\\skripsi\\Dataset\\train\\GT\\label-13.nii',
 'D:\\skripsi\\Dataset\\train\\GT\\label-14.nii',
 'D:\\skripsi\\Dataset\\train\\GT\\label-15.nii',
 'D:\\skripsi\\Dataset\\train\\GT\\label-16.nii',
 'D:\\skripsi\\Dataset\\train\\GT\\label-17.nii',
 'D:\\skripsi\\Dataset\\train\\GT\\label-18.nii',
 'D:\\skripsi\\Dataset\\train\\GT\\label-19.nii',
 'D:\\skripsi\\Dataset\\train\\GT\\label-2.nii',
 'D:\\skripsi\\Dataset\\train\\GT\\label-20.nii',
 'D:\\skripsi\\Dataset\\train\\GT\\label-21.nii',
 'D:\\skripsi\\Dataset\\train\\GT\\label-22.nii',
 'D:\\skripsi\\Dataset\\train\\GT\\label-23.nii',
 'D:\\skripsi\\Dataset\\train\\GT\\label-24.nii',
 'D:\\skripsi\\Dataset\\train\\GT\\label-25.nii',
 'D:\\skripsi\\Dataset\\train\\GT\\label-26.nii',
 'D

## Model

In [6]:
# class DoubleConv(nn.Module):
#     def __init__(self, in_channels, out_channels):
#         super(DoubleConv, self).__init__()
#         self.conv = nn.Sequential(
#             nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1),
#             nn.BatchNorm3d(out_channels),
#             nn.ReLU(inplace=True),
#             nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1),
#             nn.BatchNorm3d(out_channels),
#             nn.ReLU(inplace=True)
#         )

#     def forward(self, x):
#         x = self.conv(x)
#         return x

In [7]:
# class Down(nn.Module):
#     def __init__(self, in_channels, out_channels):
#         super(Down, self).__init__()
#         self.mpconv = nn.Sequential(
#             nn.MaxPool3d(2),
#             DoubleConv(in_channels, out_channels)
#         )

#     def forward(self, x):
#         x = self.mpconv(x)
#         return x

In [8]:
# class Up(nn.Module):
#     def __init__(self, in_channels, out_channels, bilinear=True):
#         super(Up, self).__init__()
#         if bilinear:
#             self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
#         else:
#             self.up = nn.ConvTranspose3d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2)
#         self.conv = DoubleConv(in_channels, out_channels)

#     def forward(self, x1, x2):
#         x1 = self.up(x1)
#         diffZ = x2.size()[4] - x1.size()[4]
#         diffY = x2.size()[3] - x1.size()[3]
#         diffX = x2.size()[2] - x1.size()[2]
#         x1 = F.pad(x1, (diffX // 2, diffX - diffX // 2,
#                         diffY // 2, diffY - diffY // 2,
#                         diffZ // 2, diffZ - diffZ // 2))
#         x = torch.cat([x2, x1], dim=1)
#         x = self.conv(x)
#         return x

In [9]:
# class AttentionBlock(nn.Module):
#     def __init__(self, in_channels):
#         super(AttentionBlock, self).__init__()
#         self.query_conv = nn.Conv3d(in_channels, in_channels // 8, kernel_size=1)
#         self.key_conv = nn.Conv3d(in_channels, in_channels // 8, kernel_size=1)
#         self.value_conv = nn.Conv3d(in_channels, in_channels, kernel_size=1)
#         self.gamma = nn.Parameter(torch.zeros(1))

#     def forward(self, x):
#         batch_size, channels, height, width, depth = x.size()
#         proj_query = self.query_conv(x).view(batch_size, -1, height * width * depth).permute(0, 2, 1)
#         proj_key = self.key_conv(x).view(batch_size, -1, height * width * depth)
#         energy = torch.bmm(proj_query, proj_key)
#         attention = F.softmax(energy, dim=-1)
#         proj_value = self.value_conv(x).view(batch_size, -1, height * width * depth)
#         out = torch.bmm(proj_value, attention.permute(0, 2, 1))
#         out = out.view(batch_size, channels, height, width, depth)
#         out = self.gamma * out + x
#         return out

In [10]:
# class AttentionUNet(nn.Module):
#     def __init__(self, in_channels, out_channels):
#         super(AttentionUNet, self).__init__()
#         self.inc = DoubleConv(in_channels, 64)
#         self.down1 = Down(64, 128)
#         self.down2 = Down(128, 256)
#         self.down3 = Down(256, 512)
#         self.attention = AttentionBlock(512)
#         self.up3 = Up(512, 256)
#         self.up2 = Up(256, 128)
#         self.up1 = Up(128, 64)
#         self.outc = nn.Conv3d(64, out_channels, kernel_size=1)

#     def forward(self, x):
#         x1 = self.inc(x)
#         x2 = self.down1(x1)
#         x3 = self.down2(x2)
#         x4 = self.down3(x3)
#         x4 = self.attention(x4)
#         x = self.up3(x4, x3)
#         x = self.up2(x, x2)
#         x = self.up1(x, x1)
#         x = self.outc(x)
#         return x

In [11]:
# class AttentionUNet(nn.Module):
#     def __init__(self, in_channels, out_channels):
#         super().__init__()

#         # Downsample block 1
#         self.conv1a = nn.Conv3d(in_channels, 16, kernel_size=3, padding=1)
#         self.bn1a = nn.BatchNorm3d(16)
#         self.conv1b = nn.Conv3d(16, 16, kernel_size=3, padding=1)
#         self.bn1b = nn.BatchNorm3d(16)
#         self.pool1 = nn.MaxPool3d(kernel_size=2, stride=2)

#         # Downsample block 2
#         self.conv2a = nn.Conv3d(16, 32, kernel_size=3, padding=1)
#         self.bn2a = nn.BatchNorm3d(32)
#         self.conv2b = nn.Conv3d(32, 32, kernel_size=3, padding=1)
#         self.bn2b = nn.BatchNorm3d(32)
#         self.pool2 = nn.MaxPool3d(kernel_size=2, stride=2)

#         # Downsample block 3
#         self.conv3a = nn.Conv3d(32, 64, kernel_size=3, padding=1)
#         self.bn3a = nn.BatchNorm3d(64)
#         self.conv3b = nn.Conv3d(64, 64, kernel_size=3, padding=1)
#         self.bn3b = nn.BatchNorm3d(64)
#         self.pool3 = nn.MaxPool3d(kernel_size=2, stride=2)

#         # Downsample block 4
#         self.conv4a = nn.Conv3d(64, 128, kernel_size=3, padding=1)
#         self.bn4a = nn.BatchNorm3d(128)
#         self.conv4b = nn.Conv3d(128, 128, kernel_size=3, padding=1)
#         self.bn4b = nn.BatchNorm3d(128)
#         self.pool4 = nn.MaxPool3d(kernel_size=2, stride=2)

#         # Downsample block 5
#         self.conv5a = nn.Conv3d(128, 256, kernel_size=3, padding=1)
#         self.bn5a = nn.BatchNorm3d(256)
#         self.conv5b = nn.Conv3d(256, 256, kernel_size=3, padding=1)
#         self.bn5b = nn.BatchNorm3d(256)

#         # Upsample block 1
#         self.upconv1 = nn.ConvTranspose3d(256, 128, kernel_size=2, stride=2)
#         self.conv6a = nn.Conv3d(256, 128, kernel_size=3, padding=1)
#         self.bn6a = nn.BatchNorm3d(128)
#         self.conv6b = nn.Conv3d(128, 128, kernel_size=3, padding=1)
#         self.bn6b = nn.BatchNorm3d(128)

#         # Upsample block 2
#         self.upconv2 = nn.ConvTranspose3d(128, 64, kernel_size=2, stride=2)
#         self.conv7a = nn.Conv3d(128, 64, kernel_size=3, padding=1)
#         self.bn7a = nn.BatchNorm3d(64)
#         self.conv7b = nn.Conv3d(64, 64, kernel_size=3, padding=1)
#         self.bn7b = nn.BatchNorm3d(64)
        
#         # Upsample block 3
#         self.upconv3 = nn.ConvTranspose3d(64, 32, kernel_size=2, stride=2)
#         self.conv8a = nn.Conv3d(64, 32, kernel_size=3, padding=1)
#         self.bn8a = nn.BatchNorm3d(32)
#         self.conv8b = nn.Conv3d(32, 32, kernel_size=3, padding=1)
#         self.bn8b = nn.BatchNorm3d(32)

#         # Upsample block 4
#         self.upconv4 = nn.ConvTranspose3d(32, 16, kernel_size=2, stride=2)
#         self.conv9a = nn.Conv3d(32, 16, kernel_size=3, padding=1)
#         self.bn9a = nn.BatchNorm3d(16)
#         self.conv9b = nn.Conv3d(16, 16, kernel_size=3, padding=1)
#         self.bn9b = nn.BatchNorm3d(16)

#         # Attention gate
#         self.attention_gate = nn.Sequential(
#             nn.Conv3d(256, 64, kernel_size=1),
#             nn.BatchNorm3d(64),
#             nn.ReLU(inplace=True),
#             nn.Conv3d(64, 1, kernel_size=1),
#             nn.Sigmoid()
#         )
        
#         self.attention_gate_2 = nn.Sequential(
#             nn.Conv3d(128, 64, kernel_size=1),
#             nn.BatchNorm3d(64),
#             nn.ReLU(inplace=True),
#             nn.Conv3d(64, 1, kernel_size=1),
#             nn.Sigmoid()
#         )
        
#         self.attention_gate_3 = nn.Sequential(
#             nn.Conv3d(64, 64, kernel_size=1),
#             nn.BatchNorm3d(64),
#             nn.ReLU(inplace=True),
#             nn.Conv3d(64, 1, kernel_size=1),
#             nn.Sigmoid()
#         )
        
#         self.attention_gate_4 = nn.Sequential(
#             nn.Conv3d(32, 64, kernel_size=1),
#             nn.BatchNorm3d(64),
#             nn.ReLU(inplace=True),
#             nn.Conv3d(64, 1, kernel_size=1),
#             nn.Sigmoid()
#         )

#         # Output convolution
#         self.out_conv = nn.Conv3d(16, out_channels, kernel_size=1)
        
#     def forward(self, x):
#         # Downsample block 1
#         x1 = F.relu(self.bn1a(self.conv1a(x)))
#         x1 = F.relu(self.bn1b(self.conv1b(x1)))
#         p1 = self.pool1(x1)

#         # Downsample block 2
#         x2 = F.relu(self.bn2a(self.conv2a(p1)))
#         x2 = F.relu(self.bn2b(self.conv2b(x2)))
#         p2 = self.pool2(x2)

#         # Downsample block 3
#         x3 = F.relu(self.bn3a(self.conv3a(p2)))
#         x3 = F.relu(self.bn3b(self.conv3b(x3)))
#         p3 = self.pool3(x3)

#         # Downsample block 4
#         x4 = F.relu(self.bn4a(self.conv4a(p3)))
#         x4 = F.relu(self.bn4b(self.conv4b(x4)))
#         p4 = self.pool4(x4)

#         # Bottom block
#         x5 = F.relu(self.bn5a(self.conv5a(p4)))
#         x5 = F.relu(self.bn5b(self.conv5b(x5)))

#         # Upsample block 1
#         up1 = self.upconv1(x5)
#         gate1 = self.attention_gate(torch.cat([up1, x4], dim=1))
#         x6 = F.relu(self.bn6a(self.conv6a(torch.cat([up1, gate1*x4], dim=1))))
#         x6 = F.relu(self.bn6b(self.conv6b(x6)))

#         # Upsample block 2
#         up2 = self.upconv2(x6)
#         gate2 = self.attention_gate_2(torch.cat([up2, x3], dim=1))
#         x7 = F.relu(self.bn7a(self.conv7a(torch.cat([up2, gate2*x3], dim=1))))
#         x7 = F.relu(self.bn7b(self.conv7b(x7)))

#         # Upsample block 3
#         up3 = self.upconv3(x7)
#         gate3 = self.attention_gate_3(torch.cat([up3, x2], dim=1))
#         x8 = F.relu(self.bn8a(self.conv8a(torch.cat([up3, gate3*x2], dim=1))))
#         x8 = F.relu(self.bn8b(self.conv8b(x8)))
        
#         # Upsample block 4
#         up4 = self.upconv4(x8)
#         gate4 = self.attention_gate_4(torch.cat([up4, x1], dim=1))
#         x9 = F.relu(self.bn9a(self.conv9a(torch.cat([up4, gate4*x1], dim=1))))
#         x9 = F.relu(self.bn9b(self.conv9b(x9)))

#         # Output
#         out = self.out_conv(x9)

#         return out
        

In [6]:
num_organ = 13
dropout_rate = 0.3

class ResUNet(nn.Module):
    """
    共9332094个可训练的参数, 九百三十万左右
    """
    def __init__(self, training, inchannel, stage):
        """
        :param training: 标志网络是属于训练阶段还是测试阶段
        :param inchannel 网络最开始的输入通道数量
        :param stage 标志网络属于第一阶段，还是第二阶段
        """
        super().__init__()

        self.training = training
        self.stage = stage

        self.encoder_stage1 = nn.Sequential(
            nn.Conv3d(inchannel, 16, 3, 1, padding=1),
            nn.PReLU(16),
        )

        self.encoder_stage2 = nn.Sequential(
            nn.Conv3d(32, 32, 3, 1, padding=1),
            nn.PReLU(32),

            nn.Conv3d(32, 32, 3, 1, padding=1),
            nn.PReLU(32),
        )

        self.encoder_stage3 = nn.Sequential(
            nn.Conv3d(64, 64, 3, 1, padding=1),
            nn.PReLU(64),

            nn.Conv3d(64, 64, 3, 1, padding=2, dilation=2),
            nn.PReLU(64),

            nn.Conv3d(64, 64, 3, 1, padding=4, dilation=4),
            nn.PReLU(64),
        )

        self.encoder_stage4 = nn.Sequential(
            nn.Conv3d(128, 128, 3, 1, padding=3, dilation=3),
            nn.PReLU(128),

            nn.Conv3d(128, 128, 3, 1, padding=4, dilation=4),
            nn.PReLU(128),

            nn.Conv3d(128, 128, 3, 1, padding=5, dilation=5),
            nn.PReLU(128),
        )

        self.decoder_stage1 = nn.Sequential(
            nn.Conv3d(128, 256, 3, 1, padding=1),
            nn.PReLU(256),

            nn.Conv3d(256, 256, 3, 1, padding=1),
            nn.PReLU(256),

            nn.Conv3d(256, 256, 3, 1, padding=1),
            nn.PReLU(256),
        )

        self.decoder_stage2 = nn.Sequential(
            nn.Conv3d(128 + 64, 128, 3, 1, padding=1),
            nn.PReLU(128),

            nn.Conv3d(128, 128, 3, 1, padding=1),
            nn.PReLU(128),

            nn.Conv3d(128, 128, 3, 1, padding=1),
            nn.PReLU(128),
        )

        self.decoder_stage3 = nn.Sequential(
            nn.Conv3d(64 + 32, 64, 3, 1, padding=1),
            nn.PReLU(64),

            nn.Conv3d(64, 64, 3, 1, padding=1),
            nn.PReLU(64),
        )

        self.decoder_stage4 = nn.Sequential(
            nn.Conv3d(32 + 16, 32, 3, 1, padding=1),
            nn.PReLU(32),
        )

        self.down_conv1 = nn.Sequential(
            nn.Conv3d(16, 32, 2, 2),
            nn.PReLU(32)
        )

        self.down_conv2 = nn.Sequential(
            nn.Conv3d(32, 64, 2, 2),
            nn.PReLU(64)
        )

        self.down_conv3 = nn.Sequential(
            nn.Conv3d(64, 128, 2, 2),
            nn.PReLU(128)
        )

        self.down_conv4 = nn.Sequential(
            nn.Conv3d(128, 256, 3, 1, padding=1),
            nn.PReLU(256)
        )

        self.up_conv2 = nn.Sequential(
            nn.ConvTranspose3d(256, 128, 2, 2),
            nn.PReLU(128)
        )

        self.up_conv3 = nn.Sequential(
            nn.ConvTranspose3d(128, 64, 2, 2),
            nn.PReLU(64)
        )

        self.up_conv4 = nn.Sequential(
            nn.ConvTranspose3d(64, 32, 2, 2),
            nn.PReLU(32)
        )

        self.map = nn.Sequential(
            nn.Conv3d(32, num_organ + 1, 1),
            nn.Softmax(dim=1)
        )

    def forward(self, inputs):
        long_range1 = self.encoder_stage1(inputs) + inputs

        short_range1 = self.down_conv1(long_range1)

        long_range2 = self.encoder_stage2(short_range1) + short_range1
        long_range2 = F.dropout(long_range2, dropout_rate, self.training)

        short_range2 = self.down_conv2(long_range2)

        long_range3 = self.encoder_stage3(short_range2) + short_range2
        long_range3 = F.dropout(long_range3, dropout_rate, self.training)

        short_range3 = self.down_conv3(long_range3)

        long_range4 = self.encoder_stage4(short_range3) + short_range3
        long_range4 = F.dropout(long_range4, dropout_rate, self.training)

        short_range4 = self.down_conv4(long_range4)

        outputs = self.decoder_stage1(long_range4) + short_range4
        outputs = F.dropout(outputs, dropout_rate, self.training)

        short_range6 = self.up_conv2(outputs)

        outputs = self.decoder_stage2(torch.cat([short_range6, long_range3], dim=1)) + short_range6
        outputs = F.dropout(outputs, dropout_rate, self.training)

        short_range7 = self.up_conv3(outputs)

        outputs = self.decoder_stage3(torch.cat([short_range7, long_range2], dim=1)) + short_range7
        outputs = F.dropout(outputs, dropout_rate, self.training)

        short_range8 = self.up_conv4(outputs)

        outputs = self.decoder_stage4(torch.cat([short_range8, long_range1], dim=1)) + short_range8

        outputs = self.map(outputs)

        # 返回概率图
        return outputs


# 定义最终的级连3D FCN
class Net(nn.Module):
    def __init__(self, training):
        super().__init__()

        self.training = training

        self.stage1 = ResUNet(training=training, inchannel=1, stage='stage1')
        # self.stage2 = ResUNet(training=training, inchannel=num_organ + 2, stage='stage2')

    def forward(self, inputs):
        """
        首先将输入数据在轴向上缩小一倍，然后送入第一阶段网络中
        得到一个粗糙尺度下的分割结果
        然后将原始尺度大小的数据与第一步中得到的分割结果进行拼接，共同送入第二阶段网络中
        得到最终的分割结果
        共18656348个可训练的参数，一千八百万左右
        """
        # 首先将输入缩小一倍
        inputs_stage1 = F.upsample(inputs, (48, 128, 128), mode='trilinear')

        # 得到第一阶段的结果
        output_stage1 = self.stage1(inputs_stage1)
        output_stage1 = F.upsample(output_stage1, (48, 256, 256), mode='trilinear')

        # 将第一阶段的结果与原始输入数据进行拼接作为第二阶段的输入
        # inputs_stage2 = torch.cat((output_stage1, inputs), dim=1)

        # 得到第二阶段的结果
        # output_stage2 = self.stage2(inputs_stage2)

        # if self.training is True:
        #     return output_stage1, output_stage2
        # else:
        #     return output_stage2
        return output_stage1

# 网络参数初始化函数
def init(module):
    if isinstance(module, nn.Conv3d) or isinstance(module, nn.ConvTranspose3d):
        nn.init.kaiming_normal_(module.weight.data, 0.25)
        nn.init.constant_(module.bias.data, 0)


net = Net(training=True)
net.apply(init)

Net(
  (stage1): ResUNet(
    (encoder_stage1): Sequential(
      (0): Conv3d(1, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (1): PReLU(num_parameters=16)
    )
    (encoder_stage2): Sequential(
      (0): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (1): PReLU(num_parameters=32)
      (2): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (3): PReLU(num_parameters=32)
    )
    (encoder_stage3): Sequential(
      (0): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (1): PReLU(num_parameters=64)
      (2): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(2, 2, 2), dilation=(2, 2, 2))
      (3): PReLU(num_parameters=64)
      (4): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(4, 4, 4), dilation=(4, 4, 4))
      (5): PReLU(num_parameters=64)
    )
    (encoder_stage4): Sequential(
      (0): Conv3d(128, 128, kernel_size=(3, 3, 3)

In [13]:
# net = torch.compile(AttentionUNet)
# net = AttentionUNet(1,13)

## Loss Function

In [14]:
# def dice_coef(ground_truth, predicted_segmentation):
#     ground_truth_f = torch.flatten(ground_truth).cpu()
#     predicted_segmentation_f = torch.flatten(predicted_segmentation).cpu()
#     intersection = torch.sum(ground_truth_f * predicted_segmentation_f)
#     return (2. * intersection) / (torch.sum(ground_truth_f) + torch.sum(predicted_segmentation_f))

# def dice_coef_multilabel_loss(ground_truth, predicted_segmentation, numLabels=num_organ):
#     dice = 1
#     organ_target = torch.zeros((ground_truth.size(0), num_organ, 48, 256, 256))

#     for organ_index in range(1, num_organ + 1):
#         temp_target = torch.zeros(ground_truth.size())
#         temp_target[ground_truth == organ_index] = 1
#         organ_target[:, organ_index - 1, :, :, :] = temp_target
        
#     # organ_target.to(device)
#     # organ_target = organ_target.cpu()
#     # predicted_segmentation = predicted_segmentation.cpu()
        
#     for index in range(num_organ):
#         dice -= dice_coef(organ_target[:,index,:,:,:], predicted_segmentation[:,index,:,:,:])
#     return dice.to(device)

In [7]:
class DiceLoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, target, pred_stage1):
        """
        :param pred_stage1: 经过放大之后(B, 14, 48, 256, 256)
        :param pred_stage2: (B, 14, 48, 256, 256)
        :param target: (B, 48, 256, 256)
        :return: Dice距离
        """

        # 首先将金标准拆开
        organ_target = torch.zeros((target.size(0), num_organ, 48, 256, 256))

        for organ_index in range(1, num_organ + 1):
            temp_target = torch.zeros(target.size())
            temp_target[target == organ_index] = 1
            organ_target[:, organ_index - 1, :, :, :] = temp_target
            # organ_target: (B, 13, 48, 128, 128)

        organ_target = organ_target.cuda()

        # 计算第一阶段的loss
        dice_stage1 = 0.0

        for organ_index in range(1, num_organ + 1):
            dice_stage1 += 2 * (pred_stage1[:, organ_index, :, :, :] * organ_target[:, organ_index - 1, :, :, :]).sum(dim=1).sum(dim=1).sum(
                dim=1) / (pred_stage1[:, organ_index, :, :, :].pow(2).sum(dim=1).sum(dim=1).sum(dim=1) +
                          organ_target[:, organ_index - 1, :, :, :].pow(2).sum(dim=1).sum(dim=1).sum(dim=1) + 1e-5)

        dice_stage1 /= num_organ

        # 返回的是dice距离
        return (1 - dice_stage1).mean()

## Training

In [8]:
# Define hyperparameters

save_path = "D:/skripsi/weights2/"

num_epochs = 120
batch_size = 1
learning_rate = 1e-4
accumulation_steps = 8
cudnn.benchmark = True

# Define loss function and device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
loss_func = DiceLoss()

train_dl = DataLoader(train_ds, batch_size, pin_memory=True)

opt = torch.optim.Adam(net.parameters(), lr=learning_rate)
lr_decay = torch.optim.lr_scheduler.MultiStepLR(opt, [900])

net = net.to(device)

start = time()
for epoch in range(1, num_epochs+1):
    mean_loss = 0
    for step, (ct, seg) in enumerate(train_dl):
        
        ct = ct.cuda()

        outputs_stage1 = net(ct)
        loss = loss_func(seg, outputs_stage1)
        
        mean_loss += loss.item()

        opt.zero_grad()
        loss.backward()
        opt.step()

        if step % 4 == 0:
            print('epoch:{}, step:{}, loss:{:.3f}, time:{:.3f} min'
                  .format(epoch, step, loss.item(), (time() - start) / 60))

        
    mean_loss = mean_loss / step
    lr_decay.step()

    # 每十个个epoch保存一次模型参数
    # 网络模型的命名方式为：epoch轮数+当前minibatch的loss+本轮epoch的平均loss
    if epoch % 10 == 0:
        torch.save(net.state_dict(), save_path + '{}-{:.3f}-{:.3f}.pth'.format(epoch, loss.item(), mean_loss))



epoch:1, step:0, loss:0.977, time:0.527 min
epoch:1, step:4, loss:0.983, time:0.626 min
epoch:1, step:8, loss:0.973, time:0.823 min
epoch:1, step:12, loss:0.971, time:0.974 min
epoch:1, step:16, loss:0.960, time:1.109 min
epoch:1, step:20, loss:0.962, time:1.286 min
epoch:1, step:24, loss:0.964, time:1.462 min
epoch:1, step:28, loss:0.981, time:1.631 min
epoch:2, step:0, loss:0.947, time:1.667 min
epoch:2, step:4, loss:0.962, time:1.864 min
epoch:2, step:8, loss:0.950, time:2.019 min
epoch:2, step:12, loss:0.957, time:2.197 min
epoch:2, step:16, loss:0.945, time:2.257 min
epoch:2, step:20, loss:0.956, time:2.382 min
epoch:2, step:24, loss:0.951, time:2.537 min
epoch:2, step:28, loss:0.963, time:2.651 min
epoch:3, step:0, loss:0.927, time:2.750 min
epoch:3, step:4, loss:0.953, time:2.880 min
epoch:3, step:8, loss:0.939, time:2.965 min
epoch:3, step:12, loss:0.984, time:3.020 min
epoch:3, step:16, loss:0.942, time:3.192 min
epoch:3, step:20, loss:0.953, time:3.325 min
epoch:3, step:24, l