In [1]:
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.utils as utils
import torch.nn.functional as F

from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision.utils import save_image

from torchvision.transforms import Compose, CenterCrop, Normalize, ToTensor

from glob import glob

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation

from PIL import Image
import cv2 as cv

import binary as bin_eval

# Fixed seed so that the Python-level augmentation draws and torch weight
# initialisation are repeatable between runs.
manualSeed = 999

print("Random Seed:", manualSeed)
# Seeds Python's `random` module (used by the dataset augmentation below).
random.seed(manualSeed)
# Seeds torch's default generator. Because this is the last expression in the
# cell, the notebook also displays the returned Generator object.
# NOTE(review): numpy's RNG is not seeded here; nothing visible in this
# notebook draws from it, but confirm before adding numpy-based randomness.
torch.manual_seed(manualSeed)

Random Seed: 999


<torch._C.Generator at 0x7f9dcc037450>

In [2]:
# Mini-batch size handed to the training DataLoader.
batch_size = 10

# Mini-batch size handed to the test and validation DataLoaders.
# NOTE(review): the evaluation code in the training cell slices its buffers in
# steps of 10 and groups 8 batches per volume, so it implicitly assumes this
# value stays 10.
test_size = 10

In [3]:
class HorizontalFlip(object):
    """Mirror a PIL image left-to-right.

    The flip is unconditional: every call returns the mirrored image. Any
    randomness (e.g. flipping only half of the samples) has to be decided by
    the caller before invoking this transform.
    """

    def __call__(self, img):
        mirrored = img.transpose(Image.FLIP_LEFT_RIGHT)
        return mirrored

class VerticalFlip(object):
    """Mirror a PIL image top-to-bottom.

    The flip is unconditional: every call returns the flipped image. Any
    randomness has to be decided by the caller before invoking this transform.
    """

    def __call__(self, img):
        upside_down = img.transpose(Image.FLIP_TOP_BOTTOM)
        return upside_down
       
class ReLabel(object):
    """Binarise a tensor in place around a fixed 0.1 threshold.

    Values >= 0.1 become 1 and values < 0.1 become 0. The input tensor itself
    is modified and returned.
    """

    def __init__(self, nlabel):
        # Stored for reference only; __call__ does not consult it.
        self.nlabel = nlabel

    def __call__(self, tensor):
        # Foreground first: the freshly written 1s are still >= 0.1, so the
        # second mask only ever selects the original sub-threshold values.
        foreground = tensor >= 0.1
        tensor[foreground] = 1

        background = tensor < 0.1
        tensor[background] = 0

        return tensor

class Resize(object):
    """Crop a PIL image to the window rows 8..231, columns 8..231.

    Despite the name no rescaling happens: the image is converted to a numpy
    array, sliced with ``[8:232, 8:232]`` and converted back to a PIL image.
    For an input of at least 232x232 pixels the result is 224x224.
    """

    def __call__(self, img):
        pixels = np.array(img)
        window = pixels[8:232, 8:232]
        return Image.fromarray(window)
    
class Dataset(torch.utils.data.Dataset):
    """Training split: paired image/label JPEGs with random augmentation.

    Images are read from ``<root>/data_jpg_all/train_data`` and labels from
    ``<root>/data_jpg_all/train_lab``. Both lists are sorted by path and paired
    by position, so the two folders are expected to contain matching files in
    the same order (this is not checked).
    """

    def __init__(self, root):
        """Collect the file lists and build the tensor transforms.

        Args:
            root: directory that contains the ``data_jpg_all`` folder.

        Raises:
            Exception: if ``root`` does not exist, or if either folder yields
                no ``*.jpg`` files.
        """

        self.root = root
        
        if not os.path.exists(self.root):
            raise Exception("[!] {} not exists.".format(root))
        
        self.img_transform = Compose([
            ToTensor(),
        ])
        
        # Labels are additionally binarised by ReLabel (threshold 0.1).
        self.label_transform = Compose([
            ToTensor(),
            ReLabel(1),
        ])
    
        self.input_paths = sorted(glob(os.path.join(self.root, '{}/*.jpg'.format("data_jpg_all/train_data"))))
        self.label_paths = sorted(glob(os.path.join(self.root, '{}/*.jpg'.format("data_jpg_all/train_lab"))))
        self.name = os.path.basename(root)
        
        if len(self.input_paths) == 0 or len(self.label_paths) == 0:
            raise Exception("No images/labels are found in {}".format(self.root))

    def __getitem__(self, index):
        """Return one augmented ``(image, label)`` tensor pair.

        The same geometric augmentation is applied to image and label. The
        order of the ``random`` calls below determines the augmentation drawn
        for a given seed, so it should not be rearranged casually.
        """
        
        image = Image.open(self.input_paths[index])
        label = Image.open(self.label_paths[index])
               
        # Apply a random perspective warp to roughly 70% of the samples.
        if random.random() > 0.3:
            image = np.array(image)
            label = np.array(label)

            # Offsets for the corner coordinates that sit near 0 (range -35..35).
            s1 = random.randint(-35, 35)
            s2 = random.randint(-35, 35)
            s3 = random.randint(-35, 35)
            s4 = random.randint(-35, 35)        

            # Positions for the corner coordinates that sit near 239 (range 210..270).
            d1 = random.randint(210, 270)
            d2 = random.randint(210, 270)
            d3 = random.randint(210, 270)
            d4 = random.randint(210, 270)

            # Four source corners of a 240x240 frame and their jittered targets
            # (top-left, top-right, bottom-left, bottom-right). Despite the "Tri"
            # names these are quadrilaterals, as getPerspectiveTransform needs.
            # NOTE(review): assumes the source JPEGs are 240x240 - confirm.
            srcTri = np.array( [[0, 0], [239, 0], [0, 239], [239, 239]] ).astype(np.float32)
            dstTri = np.array( [[s1, s2], [d1, s3], [s4, d3], [d2, d4]] ).astype(np.float32)

            warp_mat = cv.getPerspectiveTransform(srcTri, dstTri)

            # Same matrix for image and label keeps them aligned. Any in-between
            # label values produced by the warp are re-binarised later by ReLabel.
            image = cv.warpPerspective(image, warp_mat, (240, 240))
            label = cv.warpPerspective(label, warp_mat, (240, 240))

            image = Image.fromarray(image)
            label = Image.fromarray(label)

            #image.show()
            #label.show()

        # Crop both to the [8:232, 8:232] window.
        image = Resize()(image)
        label = Resize()(label)
        
        # Randomly flip image and label together, each direction with probability ~0.5.
        if random.random() > 0.5:
            image = HorizontalFlip()(image)
            label = HorizontalFlip()(label)
              
        if random.random() > 0.5:
            image = VerticalFlip()(image)
            label = VerticalFlip()(label)
        
        image = self.img_transform(image)
        label = self.label_transform(label)

        return image, label

    def __len__(self):
        """Number of samples, taken from the image list."""
        return len(self.input_paths)

    
class Dataset_val(torch.utils.data.Dataset):
    """Validation split: paired image/label JPEGs, cropped, no augmentation.

    Images come from ``<root>/data_jpg_all/val_data`` and labels from
    ``<root>/data_jpg_all/val_lab``; both lists are sorted by path and paired
    by position.
    """

    def __init__(self, root):
        """Gather file lists and build transforms.

        Raises:
            Exception: if ``root`` is missing or either folder has no ``*.jpg``.
        """
        self.root = root
        if not os.path.exists(self.root):
            raise Exception("[!] {} not exists.".format(root))

        self.img_transform = Compose([ToTensor()])
        # Labels are binarised by ReLabel after conversion to a tensor.
        self.label_transform = Compose([ToTensor(), ReLabel(1)])

        # Plain lexicographic ordering of the matched paths.
        image_pattern = os.path.join(self.root, '{}/*.jpg'.format("data_jpg_all/val_data"))
        label_pattern = os.path.join(self.root, '{}/*.jpg'.format("data_jpg_all/val_lab"))
        self.input_paths = sorted(glob(image_pattern))
        self.label_paths = sorted(glob(label_pattern))
        self.name = os.path.basename(root)

        if not self.input_paths or not self.label_paths:
            raise Exception("No images/labels are found in {}".format(self.root))

    def __getitem__(self, index):
        """Return the cropped ``(image, label)`` tensor pair at ``index``."""
        raw_image = Image.open(self.input_paths[index])
        raw_label = Image.open(self.label_paths[index])

        crop = Resize()
        cropped_image = crop(raw_image)
        cropped_label = crop(raw_label)

        return self.img_transform(cropped_image), self.label_transform(cropped_label)

    def __len__(self):
        """Number of samples, taken from the image list."""
        return len(self.input_paths)
    
class Dataset_test(torch.utils.data.Dataset):
    """Test split: paired image/label JPEGs in numeric slice order, no augmentation.

    Images come from ``<root>/data_jpg_all/test_data`` and labels from
    ``<root>/data_jpg_all/test_lab``. Files are ordered by the integer that
    follows a fixed-length prefix in the file name, so that e.g. slice 2 sorts
    before slice 10. Image and label lists are paired by position.
    """

    # Number of leading characters of the file name that precede the slice
    # number. The original code sliced the full path at absolute offsets 32
    # (images) and 31 (labels), which for root '../' is exactly 6 characters
    # into the file name in both folders.
    _NAME_PREFIX_LEN = 6

    @staticmethod
    def _slice_index(path):
        """Return the integer slice number encoded in ``path``'s file name.

        Works from the base name rather than from absolute string offsets, so
        the result does not depend on the length of the root directory.

        Raises:
            ValueError: if the part of the name after the prefix is not an integer.
        """
        stem = os.path.splitext(os.path.basename(path))[0]
        return int(stem[Dataset_test._NAME_PREFIX_LEN:])

    def __init__(self, root):
        """Gather numerically sorted file lists and build transforms.

        Args:
            root: directory that contains the ``data_jpg_all`` folder.

        Raises:
            Exception: if ``root`` is missing or either folder has no ``*.jpg``.
            ValueError: if a file name does not carry an integer slice number.
        """

        self.root = root

        if not os.path.exists(self.root):
            raise Exception("[!] {} not exists.".format(root))

        self.img_transform = Compose([
            ToTensor(),
        ])

        # Labels are binarised by ReLabel after conversion to a tensor.
        self.label_transform = Compose([
            ToTensor(),
            ReLabel(1),
        ])

        # Sort by slice number, not lexicographically.
        self.input_paths = sorted(
            glob(os.path.join(self.root, '{}/*.jpg'.format("data_jpg_all/test_data"))),
            key=self._slice_index,
        )
        self.label_paths = sorted(
            glob(os.path.join(self.root, '{}/*.jpg'.format("data_jpg_all/test_lab"))),
            key=self._slice_index,
        )
        self.name = os.path.basename(root)

        if len(self.input_paths) == 0 or len(self.label_paths) == 0:
            raise Exception("No images/labels are found in {}".format(self.root))

    def __getitem__(self, index):
        """Return the cropped ``(image, label)`` tensor pair at ``index``."""

        image = Image.open(self.input_paths[index])
        label = Image.open(self.label_paths[index])

        # Crop both to the [8:232, 8:232] window.
        image = Resize()(image)
        label = Resize()(label)

        image = self.img_transform(image)
        label = self.label_transform(label)

        return image, label

    def __len__(self):
        """Number of samples, taken from the image list."""
        return len(self.input_paths)

In [4]:
def loader(dataset, batch_size, num_workers=1, shuffle = False, drop_last=False):
    """Wrap ``dataset`` in a ``torch.utils.data.DataLoader``.

    Args:
        dataset: the dataset to iterate over.
        batch_size: number of samples per batch.
        num_workers: number of loader worker processes (default 1).
        shuffle: whether to reshuffle the samples every epoch.
        drop_last: whether to discard a final incomplete batch.

    Returns:
        The configured DataLoader.
    """
    return torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        drop_last=drop_last,
    )

# All three loaders read from '../' and drop a trailing incomplete batch.
# Training batches are shuffled; test and validation keep dataset order.
# NOTE(review): drop_last=True on the test/validation loaders silently discards
# any samples that do not fill a whole batch; the evaluation code in the
# training cell relies on every batch holding exactly 10 slices.
train_loader = loader(Dataset('../'), batch_size= batch_size, shuffle = True, drop_last=True)
test_loader = loader(Dataset_test('../'), batch_size= test_size, shuffle = False, drop_last=True)
val_loader = loader(Dataset_val('../'), batch_size= test_size, shuffle = False, drop_last=True)


In [5]:
# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")

# Sanity output: whether CUDA was detected ...
print(torch.cuda.is_available())

# ... and the batch size the training loader was built with.
print(train_loader.batch_size)

True
10


In [6]:
#High-level Feature Enhancement Module

class Vox_Att(nn.Module):
    """Residual attention gate: returns ``in_x * gate(in_x) + in_x``.

    The gate is a 1x1 -> dilated 3x3 -> 1x1 bottleneck (channels halved in the
    middle) followed by a sigmoid. Every stage uses GroupNorm with 8 groups,
    so ``in_ch / 2`` must be divisible by 8. Because the last stage ends in a
    ReLU, the sigmoid only ever sees non-negative values and the gate is
    therefore never below 0.5.
    """

    def __init__(self, in_ch):
        super(Vox_Att, self).__init__()

        mid_ch = int(in_ch / 2)

        # 1x1 squeeze to half the channels.
        squeeze = [
            nn.Conv2d(in_ch, mid_ch, kernel_size=1, padding=0),
            nn.GroupNorm(8, mid_ch),
            nn.ReLU(inplace=True),
        ]
        # Dilated 3x3; padding == dilation keeps the spatial size unchanged.
        context = [
            nn.Conv2d(mid_ch, mid_ch, kernel_size=3, padding=5, dilation=5),
            nn.GroupNorm(8, mid_ch),
            nn.ReLU(inplace=True),
        ]
        # 1x1 expand back to the input width.
        expand = [
            nn.Conv2d(mid_ch, in_ch, kernel_size=1, padding=0),
            nn.GroupNorm(8, in_ch),
            nn.ReLU(inplace=True),
        ]

        self.conv1 = nn.Sequential(*squeeze)
        self.conv2 = nn.Sequential(*context)
        self.conv3 = nn.Sequential(*expand)
        self.output = nn.Sigmoid()

    def forward(self, in_x):
        gate = self.output(self.conv3(self.conv2(self.conv1(in_x))))
        return in_x * gate + in_x
    

In [7]:
class double_conv(nn.Module):
    """Two stacked (3x3 conv -> GroupNorm(8) -> ReLU) stages.

    The first stage maps ``in_ch`` to ``out_ch``; the second keeps ``out_ch``.
    Spatial size is preserved (padding 1). ``out_ch`` must be divisible by 8.
    """

    def __init__(self, in_ch, out_ch):
        super(double_conv, self).__init__()

        stages = []
        for stage_in in (in_ch, out_ch):
            stages.extend([
                nn.Conv2d(stage_in, out_ch, 3, padding=1),
                nn.GroupNorm(8, out_ch),
                nn.ReLU(inplace=True),
            ])
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv(x)
    
class double_conv_HL(nn.Module):
    """Two stacked (3x3 conv -> GroupNorm(8) -> ReLU) stages.

    Same layer stack as ``double_conv`` in this file, kept as a separate
    class. Spatial size is preserved; ``out_ch`` must be divisible by 8.
    """

    def __init__(self, in_ch, out_ch):
        super(double_conv_HL, self).__init__()

        def conv_stage(stage_in):
            # One conv/norm/activation triple producing out_ch channels.
            return (
                nn.Conv2d(stage_in, out_ch, 3, padding=1),
                nn.GroupNorm(8, out_ch),
                nn.ReLU(inplace=True),
            )

        first = conv_stage(in_ch)
        second = conv_stage(out_ch)
        self.conv = nn.Sequential(*first, *second)

    def forward(self, x):
        out = self.conv(x)
        return out
    
class single_conv(nn.Module):
    """One 3x3 conv -> GroupNorm(8) -> ReLU stage mapping ``in_ch`` to ``out_ch``.

    Spatial size is preserved (padding 1). ``out_ch`` must be divisible by 8.
    """

    def __init__(self, in_ch, out_ch):
        super(single_conv, self).__init__()

        convolution = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        normalisation = nn.GroupNorm(8, out_ch)
        activation = nn.ReLU(inplace=True)
        self.conv = nn.Sequential(convolution, normalisation, activation)

    def forward(self, x):
        return self.conv(x)
    

#Low-level Feature Enhancement Module  

class High_map(nn.Module):
    """Small two-level encoder/decoder that produces a one-channel sigmoid map.

    The input is convolved, pooled twice (so height and width must be
    divisible by 4), widened to ``2 * out_ch`` at the bottom, then upsampled
    twice with transposed convolutions and skip concatenations. A final 1x1
    convolution plus sigmoid yields a map with the input's spatial size and a
    single channel.
    """

    def __init__(self, in_ch, out_ch):
        super(High_map, self).__init__()

        wide_ch = out_ch * 2

        # Contracting path.
        self.Sconv1 = single_conv(in_ch, out_ch)
        self.pool1 = nn.MaxPool2d(2)

        self.Sconv2 = single_conv(out_ch, out_ch)
        self.pool2 = nn.MaxPool2d(2)

        # Bottom of the mini U.
        self.Sconv3 = single_conv(out_ch, wide_ch)

        # Expanding path; each Sconv consumes a skip concatenated with the upsample.
        self.Tconv1 = nn.ConvTranspose2d(wide_ch, out_ch, kernel_size=2, stride=2, padding=0)
        self.Sconv4 = single_conv(wide_ch, out_ch)

        self.Tconv2 = nn.ConvTranspose2d(out_ch, out_ch, kernel_size=2, stride=2, padding=0)
        self.Sconv5 = single_conv(wide_ch, out_ch)

        # Projection to the single-channel map.
        self.conv6 = nn.Conv2d(out_ch, 1, kernel_size=1, stride=1, padding=0)
        self.output = nn.Sigmoid()

    def forward(self, x):
        full_res = self.Sconv1(x)
        half_res = self.Sconv2(self.pool1(full_res))
        bottom = self.Sconv3(self.pool2(half_res))

        up_half = self.Tconv1(bottom)
        merged_half = self.Sconv4(torch.cat([half_res, up_half], dim=1))

        up_full = self.Tconv2(merged_half)
        merged_full = self.Sconv5(torch.cat([full_res, up_full], dim=1))

        logits = self.conv6(merged_full)
        return self.output(logits)

      
class High_to_Low(nn.Module):
    """Encoder stage: double-conv features re-weighted by a ``High_map`` map.

    Both branches read the same input. The features are combined as
    ``features * map + features`` and passed through one more conv stage.
    """

    def __init__(self, in_ch, out_ch):
        super(High_to_Low, self).__init__()

        # Feature branch, map branch, then the fusing conv (creation order kept).
        self.Dconv1 = double_conv_HL(in_ch, out_ch)
        self.High_Map = High_map(in_ch, out_ch)
        self.Sconv1 = single_conv(out_ch, out_ch)

    def forward(self, x):
        features = self.Dconv1(x)
        weight_map = self.High_Map(x)

        # The single-channel map broadcasts over the feature channels.
        enhanced = torch.add(features * weight_map, features)

        return self.Sconv1(enhanced)

In [8]:
# Number of channels of the network input (G_unet reads this at construction).
# NOTE(review): this module-level name shadows Python's built-in `input()` for
# the rest of the kernel session.
input = 1
# Base channel width; G_unet uses multiples of it (x1 ... x16).
numf =16

class G_unet(nn.Module):
    """U-Net style segmentation network with a one-channel sigmoid output.

    Encoder: four ``High_to_Low`` stages, each followed by 2x2 max pooling,
    then a ``double_conv`` bottleneck. Decoder: four steps of transposed-conv
    upsampling, concatenation with the matching encoder output, a ``Vox_Att``
    gate and a ``double_conv``. Channel widths are multiples of the
    module-level ``numf``; the input channel count is the module-level
    ``input``. With four poolings (plus two more inside the deepest
    ``High_map``) the spatial size should be divisible by 32.
    """

    def __init__(self):
        super(G_unet, self).__init__()

        w1, w2, w4, w8, w16 = numf, numf*2, numf*4, numf*8, numf*16

        # ---- encoder -------------------------------------------------
        self.Dconv1 = High_to_Low(input, w1)
        self.pool1 = nn.MaxPool2d(2)

        self.Dconv2 = High_to_Low(w1, w2)
        self.pool2 = nn.MaxPool2d(2)

        self.Dconv3 = High_to_Low(w2, w4)
        self.pool3 = nn.MaxPool2d(2)

        self.Dconv4 = High_to_Low(w4, w8)
        self.pool4 = nn.MaxPool2d(2)

        # ---- bottleneck ----------------------------------------------
        self.Dconv5 = double_conv(w8, w16)

        # ---- decoder: upsample, gate the concatenation, fuse ----------
        self.Tconv1 = nn.ConvTranspose2d(w16, w8, kernel_size=2, stride=2, padding=0)
        self.ACSgate1 = Vox_Att(w16)
        self.Dconv6 = double_conv(w16, w8)

        self.Tconv2 = nn.ConvTranspose2d(w8, w4, kernel_size=2, stride=2, padding=0)
        self.ACSgate2 = Vox_Att(w8)
        self.Dconv7 = double_conv(w8, w4)

        self.Tconv3 = nn.ConvTranspose2d(w4, w2, kernel_size=2, stride=2, padding=0)
        self.ACSgate3 = Vox_Att(w4)
        self.Dconv8 = double_conv(w4, w2)

        self.Tconv4 = nn.ConvTranspose2d(w2, w1, kernel_size=2, stride=2, padding=0)
        self.ACSgate4 = Vox_Att(w2)
        self.Dconv9 = double_conv(w2, w1)

        # ---- head ------------------------------------------------------
        self.conv19 = nn.Conv2d(w1, 1, kernel_size=1, stride=1, padding=0)
        self.output = nn.Sigmoid()

    def forward(self, input):
        # Encoder: keep every stage output for the skip connections.
        enc1 = self.Dconv1(input)
        enc2 = self.Dconv2(self.pool1(enc1))
        enc3 = self.Dconv3(self.pool2(enc2))
        enc4 = self.Dconv4(self.pool3(enc3))

        decoded = self.Dconv5(self.pool4(enc4))

        # Decoder: deepest skip first; skip goes before the upsampled tensor.
        decoder_steps = (
            (enc4, self.Tconv1, self.ACSgate1, self.Dconv6),
            (enc3, self.Tconv2, self.ACSgate2, self.Dconv7),
            (enc2, self.Tconv3, self.ACSgate3, self.Dconv8),
            (enc1, self.Tconv4, self.ACSgate4, self.Dconv9),
        )
        for skip, upsample, gate, fuse in decoder_steps:
            joined = torch.cat([skip, upsample(decoded)], dim=1)
            decoded = fuse(gate(joined))

        return self.output(self.conv19(decoded))
        

In [9]:
# Build the network and move its parameters to the selected device.
CFENet = G_unet().to(device)

# Dump the module tree (long output).
print(CFENet)

# Total number of parameter elements in the model.
print('# CFENet parameters:', sum(param.numel() for param in CFENet.parameters()))

G_unet(
  (Dconv1): High_to_Low(
    (Dconv1): double_conv_HL(
      (conv): Sequential(
        (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): GroupNorm(8, 16, eps=1e-05, affine=True)
        (2): ReLU(inplace)
        (3): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): GroupNorm(8, 16, eps=1e-05, affine=True)
        (5): ReLU(inplace)
      )
    )
    (High_Map): High_map(
      (Sconv1): single_conv(
        (conv): Sequential(
          (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): GroupNorm(8, 16, eps=1e-05, affine=True)
          (2): ReLU(inplace)
        )
      )
      (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (Sconv2): single_conv(
        (conv): Sequential(
          (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): GroupNorm(8, 16, eps=1e-05, affine=True)
          (2): ReLU(inplace)

In [10]:
class DiceLoss(nn.Module):
    """Soft Dice loss over the whole batch: ``1 - (2*|P.L| + s) / (|P| + |L| + s)``.

    All elements of ``preds`` and ``labels`` are summed together (no
    per-sample averaging). ``smooth`` (s) keeps the ratio defined when both
    tensors are empty. ``eps`` is stored but not used by ``forward``.
    """

    def __init__(self, smooth=1, eps=1e-7):
        super().__init__()
        self.smooth, self.eps = smooth, eps

    def forward(self, preds, labels):
        overlap = torch.sum(preds * labels)
        dice = (2 * overlap + self.smooth) / (torch.sum(preds) + torch.sum(labels) + self.smooth)
        return 1 - dice


class jaccard_Loss(nn.Module):
    """Soft Jaccard (IoU) loss over the whole batch.

    Computes ``1 - (|P.L| + s) / (|P| + |L| - |P.L| + s)`` with every element
    of the two tensors summed together. ``eps`` is stored but not used by
    ``forward``.
    """

    def __init__(self, smooth=1, eps=1e-7):
        super().__init__()
        self.smooth, self.eps = smooth, eps

    def forward(self, preds, labels):
        overlap = torch.sum(preds * labels)
        union = torch.sum(preds) + torch.sum(labels) - overlap
        iou = (overlap + self.smooth) / (union + self.smooth)
        return 1 - iou
    

def jaccard_coef(preds, labels, smooth=1):
    """Smoothed Jaccard index (IoU) of two tensors, summed over all elements.

    Args:
        preds: predicted probabilities or a binary mask.
        labels: ground-truth mask with the same shape as ``preds``.
        smooth: additive smoothing applied to numerator and denominator
            (default 1, the value that used to be hard-coded).

    Returns:
        A 0-dim tensor ``(|P.L| + smooth) / (|P| + |L| - |P.L| + smooth)``.
    """
    intersection = torch.sum(preds * labels)
    union = torch.sum(preds) + torch.sum(labels) - intersection
    return (intersection + smooth) / (union + smooth)
          
def dice_coef(preds, labels, smooth=1):
    """Smoothed Dice coefficient of two tensors, summed over all elements.

    Args:
        preds: predicted probabilities or a binary mask.
        labels: ground-truth mask with the same shape as ``preds``.
        smooth: additive smoothing applied to numerator and denominator
            (default 1, the value that used to be hard-coded).

    Returns:
        A 0-dim tensor ``(2*|P.L| + smooth) / (|P| + |L| + smooth)``.
    """
    intersection = torch.sum(preds * labels)
    return (2 * intersection + smooth) / (torch.sum(preds) + torch.sum(labels) + smooth)


def com_evaluation(preds, labels):
    """Compute Jaccard, Dice and the ``bin_eval.assd`` score for one volume.

    Args:
        preds: prediction tensor (the caller passes an already thresholded
            0/1 mask).
        labels: ground-truth tensor with the same shape as ``preds``.

    Returns:
        ``(jaccard, dice, assd)`` where ``jaccard`` and ``dice`` are 0-dim
        tensors smoothed with 1, and ``assd`` is whatever ``bin_eval.assd``
        returns for the two arrays.
        NOTE(review): presumably the average symmetric surface distance -
        confirm against the ``binary`` module.
    """
    smooth = 1

    intersection = torch.sum(labels * preds)
    num_pred = torch.sum(preds)
    num_lab = torch.sum(labels)

    jaccard = (intersection + smooth) / (num_pred + num_lab - intersection + smooth)
    dice = (2 * intersection + smooth) / (num_pred + num_lab + smooth)

    # detach()/cpu() make the numpy conversion safe for CUDA tensors and for
    # tensors that still track gradients; both are no-ops for plain CPU tensors.
    pre_list = preds.detach().cpu().numpy()
    lab_list = labels.detach().cpu().numpy()
    assd = bin_eval.assd(pre_list, lab_list)

    return jaccard, dice, assd

In [11]:
# Training objective: soft Dice loss with the default smoothing.
Criterion = DiceLoss()

# Adam over all model parameters, initial learning rate 1e-4.
C_optimizer = torch.optim.Adam(CFENet.parameters(), lr=0.0001)

# Multiply the learning rate by 0.9 every 20 scheduler steps.
scheduler = optim.lr_scheduler.StepLR(C_optimizer, step_size=20, gamma=0.9)

In [None]:
# History of the summed validation Jaccard, one entry per evaluation pass; the
# model is checkpointed whenever the newest entry is the maximum so far.
# NOTE(review): this name shadows the built-in `list` for the rest of the kernel
# session. It is kept because other cells may read it; nothing in this cell
# calls the built-in.
list = []

EVAL_AT_ITERATION = 600            # evaluate once an epoch has run this many training batches
SLICES_PER_VOLUME = 80             # number of consecutive slices scored together as one volume
CHECKPOINT_PATH = "model/HL-LH.pkl"


def _evaluate_volumes(model, data_loader, slices_per_volume=SLICES_PER_VOLUME):
    """Score ``model`` on ``data_loader`` one volume at a time.

    Consecutive batches are collected until ``slices_per_volume`` slices are
    available, predictions are thresholded at 0.5, and the volume is scored
    with ``com_evaluation``. Trailing batches that do not complete a volume
    are ignored.

    Returns:
        A Python list of ``(jaccard, dice, assd)`` float tuples, one per
        complete volume (empty if the loader holds less than one volume).
    """
    batches_per_volume = max(1, slices_per_volume // data_loader.batch_size)

    scores = []
    lab_chunks, pre_chunks = [], []

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for batch_data, batch_lab in data_loader:
                batch_pre = model(batch_data.to(device))

                # Metrics are computed on the CPU, so gather everything there.
                lab_chunks.append(batch_lab.float().cpu())
                pre_chunks.append(batch_pre.float().cpu())

                if len(pre_chunks) < batches_per_volume:
                    continue

                lab_volume = torch.cat(lab_chunks, dim=0)
                pre_volume = (torch.cat(pre_chunks, dim=0) >= 0.5).float()
                lab_chunks, pre_chunks = [], []

                jaccard, dice, assd = com_evaluation(pre_volume, lab_volume)
                scores.append((float(jaccard), float(dice), float(assd)))
    finally:
        # Leave the model in whatever mode it was in before evaluation.
        model.train(was_training)

    return scores


for epoch in range(300):

    run_dice_loss = 0.0
    running_jaccard = 0.0
    train_i = 0

    # NOTE(review): stepping the scheduler before the epoch's optimizer steps is
    # the pre-1.1 PyTorch convention; from 1.1.0 on the documentation asks for
    # scheduler.step() after optimizer.step(), otherwise the first value of the
    # schedule is skipped. Left in place because the right position depends on
    # the installed version - confirm and move to the end of the epoch if >= 1.1.
    scheduler.step()

    if epoch % 20 == 0:
        for param_group in C_optimizer.param_groups:
            print(param_group['lr'])

    for i_1, train_data in enumerate(train_loader):

        inputs, labels = train_data

        inputs = inputs.to(device)
        labels = labels.to(device)

        pre_labs = CFENet(inputs)
        dice_loss = Criterion(pre_labs, labels)

        C_optimizer.zero_grad()
        dice_loss.backward()
        C_optimizer.step()

        # Accumulate plain floats: adding the graph-attached tensor itself would
        # keep every iteration's autograd graph alive until the next reset.
        with torch.no_grad():
            running_jaccard += jaccard_coef(pre_labs, labels).item()
        run_dice_loss += dice_loss.item()

        train_i += 1

        if train_i != EVAL_AT_ITERATION:
            continue

        # ---- test set: per-volume scores, reported as mean and std ----------
        test_scores = _evaluate_volumes(CFENet, test_loader)
        test_i = len(test_scores)

        jaccard_list = [score[0] for score in test_scores]
        dice_list = [score[1] for score in test_scores]
        assd_list = [score[2] for score in test_scores]

        # An empty score list yields NaN statistics (numpy warns) instead of
        # reusing stale arrays from an earlier pass.
        Dice_npy = np.array(dice_list)
        Jaccard_npy = np.array(jaccard_list)
        Assd_npy = np.array(assd_list)

        # ---- validation set: summed scores drive checkpointing ---------------
        val_scores = _evaluate_volumes(CFENet, val_loader)
        val_i = len(val_scores)

        valing_jaccard = float(sum(score[0] for score in val_scores))
        valing_dice = float(sum(score[1] for score in val_scores))
        valing_assd = float(sum(score[2] for score in val_scores))

        list.append(valing_jaccard)

        if max(list) == valing_jaccard:

            # torch.save does not create missing directories.
            os.makedirs(os.path.dirname(CHECKPOINT_PATH), exist_ok=True)
            torch.save(CFENet, CHECKPOINT_PATH)
            print("-------save %d epoch model---------"% epoch)

        # Avoid a ZeroDivisionError when no validation volume was completed.
        val_den = max(val_i, 1)

        print('e: %d, d_loss: %.5f, tr_jacc: %.5f, val_dice: %.5f, val_jacc: %.5f, val_assd: %.5f,' % 
              (epoch, run_dice_loss/train_i, running_jaccard/train_i, valing_dice/val_den, valing_jaccard/val_den, valing_assd/val_den))
        print('test_dice: %.5f, dice_std: %.5f, test_jacc: %.5f, jacc_std: %.5f, test_assd: %.5f, assd_std: %.5f,' % 
              (Dice_npy.mean(), Dice_npy.std(), Jaccard_npy.mean(), Jaccard_npy.std(), Assd_npy.mean(), Assd_npy.std()))
        print('-----------------------------------------------')

        # Start a fresh training-statistics window after each report.
        run_dice_loss = 0.0
        running_jaccard = 0.0
        train_i = 0

print("Finished Training")