In [1]:
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.utils as utils
import torch.nn.functional as F

from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision.utils import save_image

from torchvision.transforms import Compose, CenterCrop, Normalize, ToTensor
from glob import glob

import nibabel as nib 

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation

import binary as bin_eval

from PIL import Image
import cv2 as cv


# Seed every RNG the notebook touches (Python, NumPy, PyTorch) so runs are reproducible.
manualSeed = 999

print("Random Seed:", manualSeed)
random.seed(manualSeed)
np.random.seed(manualSeed)
# The trailing ';' suppresses the Generator repr that would otherwise be the cell output.
torch.manual_seed(manualSeed);

Random Seed: 999


<torch._C.Generator at 0x7fca9b906170>

In [2]:
# Mini-batch sizes: `batch_size` feeds the training DataLoader, `test_size` the test one.
batch_size, test_size = 1, 1

In [3]:
class ReLabel(object):
    """Binarise a label volume in place.

    Strictly positive entries become 1 and strictly negative entries become 0;
    entries equal to 0 are left as they are. The same (mutated) object that was
    passed in is returned.
    """

    def __call__(self, tensor):
        # Both masks are taken up front; the two assignments touch disjoint entries.
        positive = tensor > 0
        negative = tensor < 0
        tensor[positive] = 1
        tensor[negative] = 0
        return tensor

class ReImage(object):
    """Divide an intensity volume by its maximum value (a positive peak maps to 1).

    Works on anything supporting ``.max()`` and ``/`` (NumPy arrays, torch
    tensors). An input whose maximum is exactly 0 (e.g. a blank volume) is
    returned unchanged instead of turning into NaNs via ``0 / 0``.
    """

    def __call__(self, tensor):
        t_max = tensor.max()
        if t_max == 0:
            # Dividing would produce NaN everywhere and poison the loss downstream.
            return tensor
        return tensor / t_max
    
class Dataset(torch.utils.data.Dataset):
    """Training set of paired 3-D volumes stored as ``.npy`` files.

    Images are read from ``<root>/<image_dir>/*.npy`` and labels from
    ``<root>/<label_dir>/*.npy``. Both path lists are sorted, and the i-th image
    is paired with the i-th label, so the two directories must list in the same
    order.

    Args:
        root: base directory that contains the data sub-directories.
        image_dir: image sub-directory relative to ``root``.
        label_dir: label sub-directory relative to ``root``.

    Raises:
        Exception: if ``root`` does not exist, if either directory contains no
            ``.npy`` files, or if the numbers of images and labels differ.
    """

    def __init__(self, root, image_dir="3D_data_4/train_data", label_dir="3D_data_4/train_lab"):

        self.root = root

        if not os.path.exists(self.root):
            raise Exception("[!] {} not exists.".format(root))

        # Intensities are divided by the volume maximum; labels are binarised.
        self.img_transform = Compose([
            ReImage(),
        ])

        self.label_transform = Compose([
            ReLabel(),
        ])

        # Sort file names so image i and label i correspond.
        self.input_paths = sorted(glob(os.path.join(self.root, image_dir, '*.npy')))
        self.label_paths = sorted(glob(os.path.join(self.root, label_dir, '*.npy')))
        self.name = os.path.basename(root)

        if len(self.input_paths) == 0 or len(self.label_paths) == 0:
            raise Exception("No images/labels are found in {}".format(self.root))

        # Pairing is by index, so a count mismatch would silently misalign samples.
        if len(self.input_paths) != len(self.label_paths):
            raise Exception("Found {} images but {} labels in {}".format(
                len(self.input_paths), len(self.label_paths), self.root))

    def __getitem__(self, index):
        """Return ``(image, label)`` as float32 arrays with a leading channel axis.

        Assumes each stored array is 3-D (the ``[np.newaxis, :, :, :]`` indexing
        requires exactly three axes).
        """
        image = np.load(self.input_paths[index])
        label = np.load(self.label_paths[index])

        # Add the channel axis expected by Conv3d: (D, H, W) -> (1, D, H, W).
        image = image[np.newaxis, :, :, :]
        label = label[np.newaxis, :, :, :]

        image = image.astype(np.float32)
        label = label.astype(np.float32)

        image = self.img_transform(image)
        label = self.label_transform(label)

        return image, label

    def __len__(self):
        return len(self.input_paths)

    
class Dataset_test(torch.utils.data.Dataset):
    """Test set of paired 3-D volumes stored as ``.npy`` files.

    Images are read from ``<root>/<image_dir>/*.npy`` and labels from
    ``<root>/<label_dir>/*.npy``. Both path lists are sorted, and the i-th image
    is paired with the i-th label, so the two directories must list in the same
    order.

    Args:
        root: base directory that contains the data sub-directories.
        image_dir: image sub-directory relative to ``root``.
        label_dir: label sub-directory relative to ``root``.

    Raises:
        Exception: if ``root`` does not exist, if either directory contains no
            ``.npy`` files, or if the numbers of images and labels differ.
    """

    def __init__(self, root, image_dir="3D_data_4/test_data", label_dir="3D_data_4/test_lab"):

        self.root = root

        if not os.path.exists(self.root):
            raise Exception("[!] {} not exists.".format(root))

        # Intensities are divided by the volume maximum; labels are binarised.
        self.img_transform = Compose([
            ReImage(),
        ])

        self.label_transform = Compose([
            ReLabel(),
        ])

        # Sort file names so image i and label i correspond.
        self.input_paths = sorted(glob(os.path.join(self.root, image_dir, '*.npy')))
        self.label_paths = sorted(glob(os.path.join(self.root, label_dir, '*.npy')))
        self.name = os.path.basename(root)

        if len(self.input_paths) == 0 or len(self.label_paths) == 0:
            raise Exception("No images/labels are found in {}".format(self.root))

        # Pairing is by index, so a count mismatch would silently misalign samples.
        if len(self.input_paths) != len(self.label_paths):
            raise Exception("Found {} images but {} labels in {}".format(
                len(self.input_paths), len(self.label_paths), self.root))

    def __getitem__(self, index):
        """Return ``(image, label)`` as float32 arrays with a leading channel axis.

        Assumes each stored array is 3-D (the ``[np.newaxis, :, :, :]`` indexing
        requires exactly three axes).
        """
        image = np.load(self.input_paths[index])
        label = np.load(self.label_paths[index])

        # Add the channel axis expected by Conv3d: (D, H, W) -> (1, D, H, W).
        image = image[np.newaxis, :, :, :]
        label = label[np.newaxis, :, :, :]

        image = image.astype(np.float32)
        label = label.astype(np.float32)

        image = self.img_transform(image)
        label = self.label_transform(label)

        return image, label

    def __len__(self):
        return len(self.input_paths)

In [4]:
def loader(dataset, batch_size, num_workers=6, shuffle=False, drop_last=False):
    """Wrap ``dataset`` in a ``torch.utils.data.DataLoader``.

    Args:
        dataset: any map-style dataset.
        batch_size: samples per batch.
        num_workers: worker processes used for loading (default 6).
        shuffle: reshuffle the data every epoch.
        drop_last: discard the final incomplete batch.

    Returns:
        The configured ``DataLoader``.
    """
    return torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        drop_last=drop_last,
    )

# Shuffle and drop the ragged final batch only for training; evaluation must see
# every test volume, so the test loader keeps its last (possibly smaller) batch.
train_loader = loader(Dataset('../'), batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = loader(Dataset_test('../'), batch_size=test_size, shuffle=False, drop_last=False)

In [5]:
# Prefer the second GPU (cuda:1) as before, but fall back to cuda:0 when only one
# GPU is visible; otherwise the first `.to(device)` fails with an invalid ordinal.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")

print(torch.cuda.is_available())

print(train_loader.batch_size)

True
1


In [6]:
#High-level Feature Enhancement Module

class Vox_Att(nn.Module):
    """Residual voxel-attention gate.

    Three Conv3d -> GroupNorm(8) -> ReLU blocks map the input through
    ``in_ch -> in_ch // 2 -> in_ch // 2 -> in_ch`` channels. The middle block
    uses a dilation-5 kernel, and all blocks keep the spatial size. A sigmoid
    turns the result into a per-voxel, per-channel gate, applied residually as
    ``x * gate + x``.

    Because the last block ends in ReLU, the sigmoid input is >= 0, so the gate
    lies in [0.5, 1]. ``in_ch // 2`` and ``in_ch`` must be divisible by 8
    (GroupNorm groups).
    """

    def __init__(self, in_ch):
        super(Vox_Att, self).__init__()
        half = in_ch // 2

        self.conv1 = nn.Sequential(
            nn.Conv3d(in_ch, half, kernel_size=3, padding=1),
            nn.GroupNorm(8, half),
            nn.ReLU(inplace=True),
        )
        # padding == dilation keeps the spatial size for a 3x3x3 dilated kernel.
        self.conv2 = nn.Sequential(
            nn.Conv3d(half, half, kernel_size=3, padding=5, dilation=5),
            nn.GroupNorm(8, half),
            nn.ReLU(inplace=True),
        )
        self.conv3 = nn.Sequential(
            nn.Conv3d(half, in_ch, kernel_size=3, padding=1),
            nn.GroupNorm(8, in_ch),
            nn.ReLU(inplace=True),
        )
        self.output = nn.Sigmoid()

    def forward(self, in_x):
        gate = in_x
        for block in (self.conv1, self.conv2, self.conv3):
            gate = block(gate)
        gate = self.output(gate)
        return in_x * gate + in_x
    

In [7]:
class double_conv(nn.Module):
    """Two stacked ``Conv3d(3x3x3, padding=1) -> GroupNorm(8) -> ReLU`` stages.

    Channels go ``in_ch -> out_ch -> out_ch`` and the spatial size is preserved.
    ``out_ch`` must be divisible by 8 (GroupNorm groups).
    """

    def __init__(self, in_ch, out_ch):
        super(double_conv, self).__init__()
        layers = []
        # Built in the same order as conv, norm, relu, conv, norm, relu.
        for stage_in in (in_ch, out_ch):
            layers += [
                nn.Conv3d(stage_in, out_ch, 3, padding=1),
                nn.GroupNorm(8, out_ch),
                nn.ReLU(inplace=True),
            ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)
    
    
    
class double_conv_HL(nn.Module):
    """Two stacked ``Conv3d(3x3x3, padding=1) -> GroupNorm(8) -> ReLU`` stages.

    Structurally the same as ``double_conv``; used as the feature branch of
    ``High_to_Low``. Channels go ``in_ch -> out_ch -> out_ch``, spatial size is
    preserved, and ``out_ch`` must be divisible by 8 (GroupNorm groups).
    """

    def __init__(self, in_ch, out_ch):
        super(double_conv_HL, self).__init__()
        first = [nn.Conv3d(in_ch, out_ch, 3, padding=1), nn.GroupNorm(8, out_ch), nn.ReLU(inplace=True)]
        second = [nn.Conv3d(out_ch, out_ch, 3, padding=1), nn.GroupNorm(8, out_ch), nn.ReLU(inplace=True)]
        self.conv = nn.Sequential(*first, *second)

    def forward(self, x):
        return self.conv(x)
    
class single_conv(nn.Module):
    """A single ``Conv3d(3x3x3, padding=1) -> GroupNorm(8) -> ReLU`` stage.

    Maps ``in_ch`` to ``out_ch`` channels and preserves the spatial size.
    ``out_ch`` must be divisible by 8 (GroupNorm groups).
    """

    def __init__(self, in_ch, out_ch):
        super(single_conv, self).__init__()
        conv = nn.Conv3d(in_ch, out_ch, 3, padding=1)
        norm = nn.GroupNorm(8, out_ch)
        act = nn.ReLU(inplace=True)
        self.conv = nn.Sequential(conv, norm, act)

    def forward(self, x):
        return self.conv(x)
    
#Low-level Feature Enhancement Module

class High_map(nn.Module):
    """Small two-level U-Net that outputs a one-channel sigmoid map.

    The encoder has two ``single_conv`` + ``MaxPool3d((1, 2, 2))`` levels (only
    H and W are pooled) and a ``2 * out_ch`` bottleneck. The decoder has two
    ``(1, 2, 2)`` transposed-conv upsamples, each concatenated with the matching
    encoder skip and fused by ``single_conv``. A 1x1x1 conv to one channel and
    a sigmoid follow.

    Presumably needs H and W divisible by 4 so that the skip shapes match
    (two 2x pools) -- TODO confirm for the data used.
    """

    def __init__(self, in_ch, out_ch):
        super(High_map, self).__init__()
        wide = out_ch * 2
        plane = dict(kernel_size=(1, 2, 2), stride=(1, 2, 2), padding=0)

        # Encoder
        self.Sconv1 = single_conv(in_ch, out_ch)
        self.pool1 = nn.MaxPool3d((1, 2, 2))
        self.Sconv2 = single_conv(out_ch, out_ch)
        self.pool2 = nn.MaxPool3d((1, 2, 2))
        self.Sconv3 = single_conv(out_ch, wide)

        # Decoder
        self.Tconv1 = nn.ConvTranspose3d(wide, out_ch, **plane)
        self.Sconv4 = single_conv(wide, out_ch)
        self.Tconv2 = nn.ConvTranspose3d(out_ch, out_ch, **plane)
        self.Sconv5 = single_conv(wide, out_ch)

        # Head
        self.conv6 = nn.Conv3d(out_ch, 1, kernel_size=1, stride=1, padding=0)
        self.output = nn.Sigmoid()

    def forward(self, x):
        skip1 = self.Sconv1(x)
        skip2 = self.Sconv2(self.pool1(skip1))
        bottom = self.Sconv3(self.pool2(skip2))

        up1 = self.Sconv4(torch.cat([skip2, self.Tconv1(bottom)], dim=1))
        up2 = self.Sconv5(torch.cat([skip1, self.Tconv2(up1)], dim=1))

        return self.output(self.conv6(up2))

      
class High_to_Low(nn.Module):
    """Encoder stage whose conv features are re-weighted by a learned map.

    ``Dconv1`` (two conv/GroupNorm/ReLU stages) produces features ``D``, and
    ``High_Map`` (a small U-Net ending in a sigmoid, one output channel)
    produces a map ``H`` from the same input. The stage returns
    ``Sconv1(D * H + D)``; ``H`` broadcasts across ``D``'s channels.
    """

    def __init__(self, in_ch, out_ch):
        super(High_to_Low, self).__init__()
        self.Dconv1 = double_conv_HL(in_ch, out_ch)
        self.High_Map = High_map(in_ch, out_ch)
        self.Sconv1 = single_conv(out_ch, out_ch)

    def forward(self, x):
        feats = self.Dconv1(x)
        weight = self.High_Map(x)
        return self.Sconv1(feats * weight + feats)

In [8]:
# Network configuration read by G_unet at construction time:
#   input -- number of input channels (NOTE: this global shadows the builtin `input`)
#   numf  -- base feature width; deeper stages use multiples of it
input, numf = 1, 16

class G_unet(nn.Module):
    """3-D U-Net (``CFENet``) producing a one-channel sigmoid probability volume.

    Encoder: four ``High_to_Low`` stages with ``numf``, ``2*numf``, ``3*numf``
    and ``4*numf`` channels, separated by max-pooling. The first two pools use
    kernel ``(1, 2, 2)`` (H and W only); the last two use kernel 2 (D, H and W).
    A ``double_conv`` bottleneck widens to ``8*numf``.

    Decoder: four stages. Each one upsamples with a transposed conv, concatenates
    the matching encoder skip (skip first), applies ``Vox_Att`` and fuses with
    ``double_conv``. A 1x1x1 conv and a sigmoid finish the network.

    Input channels and the base width come from the module-level ``input`` and
    ``numf`` globals. Submodules are registered in a fixed order, which
    determines parameter order and initialisation draws. Shape requirement
    (unverified): D divisible by 4 and H, W divisible by 32, counting the extra
    pools inside each ``High_map`` -- TODO confirm.
    """

    def __init__(self):
        super(G_unet, self).__init__()
        f = numf
        plane_pool = (1, 2, 2)  # pool H and W only
        plane_up = dict(kernel_size=(1, 2, 2), stride=(1, 2, 2), padding=0)

        # Encoder
        self.Dconv1 = High_to_Low(input, f)
        self.pool1 = nn.MaxPool3d(plane_pool)
        self.Dconv2 = High_to_Low(f, f * 2)
        self.pool2 = nn.MaxPool3d(plane_pool)
        self.Dconv3 = High_to_Low(f * 2, f * 3)
        self.pool3 = nn.MaxPool3d(2)
        self.Dconv4 = High_to_Low(f * 3, f * 4)
        self.pool4 = nn.MaxPool3d(2)

        # Bottleneck
        self.Dconv5 = double_conv(f * 4, f * 8)

        # Decoder: upsample -> concat skip -> attention -> fuse
        self.Tconv1 = nn.ConvTranspose3d(f * 8, f * 4, kernel_size=2, stride=2, padding=0)
        self.AVgate1 = Vox_Att(f * 8)
        self.Dconv6 = double_conv(f * 8, f * 4)

        self.Tconv2 = nn.ConvTranspose3d(f * 4, f * 3, kernel_size=2, stride=2, padding=0)
        self.AVgate2 = Vox_Att(f * 6)
        self.Dconv7 = double_conv(f * 6, f * 3)

        self.Tconv3 = nn.ConvTranspose3d(f * 3, f * 2, **plane_up)
        self.AVgate3 = Vox_Att(f * 4)
        self.Dconv8 = double_conv(f * 4, f * 2)

        self.Tconv4 = nn.ConvTranspose3d(f * 2, f, **plane_up)
        self.AVgate4 = Vox_Att(f * 2)
        self.Dconv9 = double_conv(f * 2, f)

        # Head
        self.conv19 = nn.Conv3d(f, 1, kernel_size=1, stride=1, padding=0)
        self.output = nn.Sigmoid()

    def forward(self, input):
        skip1 = self.Dconv1(input)
        skip2 = self.Dconv2(self.pool1(skip1))
        skip3 = self.Dconv3(self.pool2(skip2))
        skip4 = self.Dconv4(self.pool3(skip3))
        x = self.Dconv5(self.pool4(skip4))

        x = self._up(x, skip4, self.Tconv1, self.AVgate1, self.Dconv6)
        x = self._up(x, skip3, self.Tconv2, self.AVgate2, self.Dconv7)
        x = self._up(x, skip2, self.Tconv3, self.AVgate3, self.Dconv8)
        x = self._up(x, skip1, self.Tconv4, self.AVgate4, self.Dconv9)

        return self.output(self.conv19(x))

    @staticmethod
    def _up(x, skip, tconv, gate, fuse):
        """One decoder stage: upsample ``x``, concat after ``skip`` on channels, gate, fuse."""
        return fuse(gate(torch.cat([skip, tconv(x)], dim=1)))
        

In [9]:
# Build the network on the selected device and report its size.
CFENet = G_unet().to(device)
print(CFENet)
print('# CFENet parameters:', sum(t.numel() for t in CFENet.parameters()))

G_unet(
  (Dconv1): High_to_Low(
    (Dconv1): double_conv_HL(
      (conv): Sequential(
        (0): Conv3d(1, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        (1): GroupNorm(8, 16, eps=1e-05, affine=True)
        (2): ReLU(inplace)
        (3): Conv3d(16, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        (4): GroupNorm(8, 16, eps=1e-05, affine=True)
        (5): ReLU(inplace)
      )
    )
    (High_Map): High_map(
      (Sconv1): single_conv(
        (conv): Sequential(
          (0): Conv3d(1, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
          (1): GroupNorm(8, 16, eps=1e-05, affine=True)
          (2): ReLU(inplace)
        )
      )
      (pool1): MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2), padding=0, dilation=1, ceil_mode=False)
      (Sconv2): single_conv(
        (conv): Sequential(
          (0): Conv3d(16, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
          (1): GroupNorm(8, 16, 

In [10]:
class DiceLoss(nn.Module):
    """Soft Dice loss computed over all elements of the inputs at once.

    ``loss = 1 - (2 * sum(p * l) + smooth) / (sum(p) + sum(l) + smooth)``

    The formula is symmetric in ``preds`` and ``labels``. ``eps`` is stored
    but not used by ``forward``.
    """

    def __init__(self, smooth=1, eps=1e-7):
        super().__init__()
        self.smooth = smooth
        self.eps = eps

    def forward(self, preds, labels):
        overlap = torch.sum(preds * labels)
        denominator = torch.sum(preds) + torch.sum(labels) + self.smooth
        return 1 - (2 * overlap + self.smooth) / denominator


class jaccard_Loss(nn.Module):
    """Soft Jaccard (IoU) loss computed over all elements of the inputs at once.

    ``loss = 1 - (I + smooth) / (sum(p) + sum(l) - I + smooth)`` with
    ``I = sum(p * l)``. ``eps`` is stored but not used by ``forward``.
    """

    def __init__(self, smooth=1, eps=1e-7):
        super().__init__()
        self.smooth = smooth
        self.eps = eps

    def forward(self, preds, labels):
        overlap = torch.sum(preds * labels)
        union = torch.sum(preds) + torch.sum(labels) - overlap
        return 1 - (overlap + self.smooth) / (union + self.smooth)
    

def jaccard_coef(preds, labels, smooth=1):
    """Soft Jaccard (IoU) coefficient over all elements of the inputs.

    Args:
        preds: prediction tensor (probabilities or a binary mask).
        labels: ground-truth tensor broadcastable with ``preds``.
        smooth: additive smoothing on numerator and denominator; keeps the
            ratio defined (and equal to 1) when both inputs are empty.

    Returns:
        0-dim tensor ``(I + smooth) / (sum(p) + sum(l) - I + smooth)`` with
        ``I = sum(p * l)``.
    """
    intersection = torch.sum(preds * labels)
    union = torch.sum(preds) + torch.sum(labels) - intersection
    return (intersection + smooth) / (union + smooth)
          
def dice_coef(preds, labels, smooth=1):
    """Soft Dice coefficient over all elements of the inputs.

    Args:
        preds: prediction tensor (probabilities or a binary mask).
        labels: ground-truth tensor broadcastable with ``preds``.
        smooth: additive smoothing on numerator and denominator; keeps the
            ratio defined (and equal to 1) when both inputs are empty.

    Returns:
        0-dim tensor ``(2 * sum(p * l) + smooth) / (sum(p) + sum(l) + smooth)``.
    """
    overlap = torch.sum(preds * labels)
    return (2 * overlap + smooth) / (torch.sum(preds) + torch.sum(labels) + smooth)


def com_evaluation(preds, labels):
    """Compute Jaccard, Dice and ASSD between a prediction and its label.

    Args:
        preds: prediction tensor; the training loop passes a 0/1 mask.
        labels: ground-truth tensor of the same shape.

    Returns:
        ``(jaccard, dice, assd)``. Jaccard and Dice are 0-dim tensors using
        additive smoothing of 1. ``assd`` is whatever ``bin_eval.assd`` returns
        for the two arrays moved to CPU.

    NOTE(review): ``bin_eval`` is the project's ``binary`` module, presumably a
    medpy-style ``assd(result, reference)``; verify how it behaves when either
    mask is empty (it may raise).
    """
    smooth = 1

    overlap = torch.sum(labels * preds)
    pred_total = torch.sum(preds)
    label_total = torch.sum(labels)

    jaccard = (overlap + smooth) / (pred_total + label_total - overlap + smooth)
    dice = (2 * overlap + smooth) / (pred_total + label_total + smooth)

    assd = bin_eval.assd(preds.cpu().numpy(), labels.cpu().numpy())

    return jaccard, dice, assd

In [11]:
# Training objective and optimisation setup:
#   soft Dice loss, Adam at lr 1e-4, learning rate multiplied by 0.9 every 20 scheduler steps.
Criterion = DiceLoss()
C_optimizer = torch.optim.Adam(CFENet.parameters(), lr=1e-4)
scheduler = optim.lr_scheduler.StepLR(C_optimizer, step_size=20, gamma=0.9)

In [None]:
# Per-epoch summed test Jaccard. Renamed from `list`, which shadowed the builtin.
jaccard_history = []
best_test_jaccard = float('-inf')

for epoch in range(300):

    # Report the learning rate in effect for this epoch every 20 epochs.
    if epoch % 20 == 0:
        for param_group in C_optimizer.param_groups:
            print(param_group['lr'])

    # ---------------- training ----------------
    CFENet.train()
    run_dice_loss = 0.0
    running_jaccard = 0.0
    train_i = 0

    for inputs, labels in train_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        pre_labs = CFENet(inputs)
        dice_loss = Criterion(pre_labs, labels)

        C_optimizer.zero_grad()
        dice_loss.backward()
        C_optimizer.step()

        # Accumulate plain floats: summing the raw tensor would keep every
        # iteration's autograd graph alive for the whole epoch.
        running_jaccard += jaccard_coef(pre_labs.detach(), labels).item()
        run_dice_loss += dice_loss.item()
        train_i += 1

    # Since PyTorch 1.1 the scheduler must step after the epoch's optimizer steps;
    # stepping first skips the initial learning-rate value.
    scheduler.step()

    # ---------------- evaluation ----------------
    CFENet.eval()
    test_dice = 0.0
    test_jaccard = 0.0
    test_assd = 0.0
    test_i = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            # Binarise the sigmoid output at 0.5 before overlap/surface metrics.
            pre_labs = (CFENet(inputs) >= 0.5).float()

            jaccard, dice, assd = com_evaluation(pre_labs, labels)
            test_jaccard += float(jaccard)
            test_dice += float(dice)
            test_assd += float(assd)
            test_i += 1

    jaccard_history.append(test_jaccard)

    # Checkpoint whenever the test Jaccard matches or beats every previous epoch.
    if test_jaccard >= best_test_jaccard:
        best_test_jaccard = test_jaccard
        # NOTE(review): saves the whole pickled module; torch.load of it unpickles code.
        torch.save(CFENet, "HL-LH_4.pkl")
        print("-------Save %d epoch model---------" % epoch)

    # Guard against empty loaders so the report never divides by zero.
    n_train = max(train_i, 1)
    n_test = max(test_i, 1)
    print('e: %d, d_loss: %.5f, tr_jacc: %.5f,' %
          (epoch, run_dice_loss / n_train, running_jaccard / n_train))
    print('test_dice: %.5f, test_jacc: %.5f, test_assd: %.5f,' %
          (test_dice / n_test, test_jaccard / n_test, test_assd / n_test))
    print('-----------------------------------------------')

print("Finished Training")