In [1]:
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.utils as utils
import torch.nn.functional as F

from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision.utils import save_image

from torchvision.transforms import Compose, CenterCrop, Normalize, ToTensor

from glob import glob

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation

from PIL import Image
import cv2 as cv

manualSeed = 999

print("Random Seed:", manualSeed)
# Seed every RNG library the notebook imports: Python's `random` (used by the
# augmentation in Dataset.__getitem__), numpy, and torch.
random.seed(manualSeed)
np.random.seed(manualSeed)
torch.manual_seed(manualSeed)

Random Seed: 999


<torch._C.Generator at 0x7f881c0d2290>

In [2]:
batch_size: int = 6  # samples per mini-batch for the train/test/val DataLoaders

In [3]:
class HorizontalFlip(object):
    """Mirror a PIL image left-to-right.

    The flip is applied unconditionally on every call; the random 50% decision is made
    by the caller (Dataset.__getitem__ draws it once so image and label flip together).
    """

    def __call__(self, img):
        return img.transpose(Image.FLIP_LEFT_RIGHT)


class VerticalFlip(object):
    """Mirror a PIL image top-to-bottom.

    The flip is applied unconditionally on every call; the random 50% decision is made
    by the caller (Dataset.__getitem__ draws it once so image and label flip together).
    """

    def __call__(self, img):
        return img.transpose(Image.FLIP_TOP_BOTTOM)
    

class ReLabel(object):
    """Binarise a label tensor in place at a fixed 0.5 threshold.

    Elements >= 0.5 become 1 and elements < 0.5 become 0; the same tensor object is
    returned. ``nlabel`` is stored on the instance but not used by ``__call__``.
    """

    def __init__(self, nlabel):
        self.nlabel = nlabel

    def __call__(self, tensor):
        tensor.masked_fill_(tensor >= 0.5, 1)
        tensor.masked_fill_(tensor < 0.5, 0)
        return tensor


class Com_my(object):
    """Stack a float ('F' mode) single-channel copy of a PIL image in front of the image itself.

    Both the 'F' conversion and the original image are passed through ``ToTensor`` and
    concatenated along the channel dimension, with the 'F' channel first.

    NOTE(review): per torchvision's ToTensor docs only uint8-style images are rescaled to
    [0, 1], so the 'F' channel presumably keeps its 0-255 range; Normalize_my later
    divides each channel by its max.
    """

    def __init__(self):
        self.to_tensor = Compose([ToTensor()])

    def __call__(self, image):
        gray_tensor = self.to_tensor(image.convert('F'))
        colour_tensor = self.to_tensor(image)
        return torch.cat([gray_tensor, colour_tensor], dim=0)
     
    
class Normalize_my(object):
    """Divide every channel of a (C, H, W) tensor, in place, by that channel's maximum.

    Channels whose maximum is exactly 0 (e.g. an all-black channel) are left unchanged
    instead of being divided by zero, which would fill them with NaN. Works for any
    number of channels; returns the same tensor object.
    """

    def __call__(self, tensor):
        for c in range(tensor.shape[0]):
            peak = tensor[c].max()
            if peak != 0:
                tensor[c] = tensor[c] / peak
        return tensor


class Dataset(torch.utils.data.Dataset):
    """Training split with random perspective, flip and colour augmentation.

    Images are read from ``<root>/2017_np/train_data_jpg/*.jpg`` and labels from
    ``<root>/2017_np/train_lab_jpg/*.jpg``. The two sorted file lists are paired by
    position, so the i-th image must correspond to the i-th label.

    ``__getitem__`` returns ``(image, label)``. ``image`` is the channel-stacked,
    per-channel max-normalised tensor produced by ``Com_my`` + ``Normalize_my``.
    ``label`` is the label tensor binarised at 0.5 by ``ReLabel``.

    Raises:
        Exception: if ``root`` does not exist, if either file list is empty, or if the
            number of images and labels differ.
    """

    def __init__(self, root):

        self.root = root

        if not os.path.exists(self.root):
            raise Exception("[!] {} not exists.".format(root))

        self.img_transform = Compose([
            transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.3),
            Com_my(),
            Normalize_my(),
        ])

        self.label_transform = Compose([
            ToTensor(),
            ReLabel(1),
        ])

        # Sort file names so images and labels line up by position.
        self.input_paths = sorted(glob(os.path.join(self.root, '{}/*.jpg'.format("2017_np/train_data_jpg"))))
        self.label_paths = sorted(glob(os.path.join(self.root, '{}/*.jpg'.format("2017_np/train_lab_jpg"))))
        self.name = os.path.basename(root)

        if len(self.input_paths) == 0 or len(self.label_paths) == 0:
            raise Exception("No images/labels are found in {}".format(self.root))
        # Pairing is positional, so a missing or extra file on either side would
        # silently attach labels to the wrong images.
        if len(self.input_paths) != len(self.label_paths):
            raise Exception("Found {} images but {} labels in {}".format(
                len(self.input_paths), len(self.label_paths), self.root))

    @staticmethod
    def _random_perspective(image, label):
        """Warp ``image`` and ``label`` with the same randomly jittered perspective transform.

        The source quadrilateral is the corner set of a 256-wide x 192-high frame, and
        both outputs are resampled to 256x192. Returns PIL images.
        NOTE(review): assumes the inputs are 256x192, since the source corners are hard-coded.
        """
        image = np.array(image)
        label = np.array(label)

        # Corner jitter. The draw order is fixed so seeded runs stay reproducible.
        s1 = random.randint(-25, 25)
        s2 = random.randint(-25, 25)
        s3 = random.randint(-30, 30)
        s4 = random.randint(-30, 30)

        d1 = random.randint(215, 295)
        d2 = random.randint(215, 295)
        d3 = random.randint(160, 220)
        d4 = random.randint(160, 220)

        srcTri = np.array([[0, 0], [255, 0], [0, 191], [255, 191]]).astype(np.float32)
        dstTri = np.array([[s1, s2], [d1, s3], [s4, d3], [d2, d4]]).astype(np.float32)

        warp_mat = cv.getPerspectiveTransform(srcTri, dstTri)

        image = cv.warpPerspective(image, warp_mat, (256, 192))
        label = cv.warpPerspective(label, warp_mat, (256, 192))

        return Image.fromarray(image), Image.fromarray(label)

    def __getitem__(self, index):

        image = Image.open(self.input_paths[index]).convert('RGB')
        label = Image.open(self.label_paths[index]).convert('P')

        # Roughly 70% of samples get a random perspective warp.
        if random.random() > 0.3:
            image, label = self._random_perspective(image, label)

        # Independent 50% horizontal and vertical flips, applied identically to both.
        if random.random() > 0.5:
            image = HorizontalFlip()(image)
            label = HorizontalFlip()(label)

        if random.random() > 0.5:
            image = VerticalFlip()(image)
            label = VerticalFlip()(label)

        image = self.img_transform(image)
        label = self.label_transform(label)

        return image, label

    def __len__(self):
        return len(self.input_paths)

    
class Dataset_val(torch.utils.data.Dataset):
    """Validation split: paired image/label JPEGs with no augmentation.

    Images are read from ``<root>/2017_np/val_data_jpg/*.jpg`` and labels from
    ``<root>/2017_np/val_lab_jpg/*.jpg``. The sorted lists are paired by position.
    Items are ``(image, label)`` built with ``Com_my`` + ``Normalize_my`` and
    ``ToTensor`` + ``ReLabel`` respectively.

    Raises:
        Exception: if ``root`` does not exist, if either file list is empty, or if the
            number of images and labels differ.
    """

    def __init__(self, root):

        self.root = root

        if not os.path.exists(self.root):
            raise Exception("[!] {} not exists.".format(root))

        self.img_transform = Compose([
            Com_my(),
            Normalize_my(),
        ])

        self.label_transform = Compose([
            ToTensor(),
            ReLabel(1),
        ])

        # Sort file names so images and labels line up by position.
        self.input_paths = sorted(glob(os.path.join(self.root, '{}/*.jpg'.format("2017_np/val_data_jpg"))))
        self.label_paths = sorted(glob(os.path.join(self.root, '{}/*.jpg'.format("2017_np/val_lab_jpg"))))
        self.name = os.path.basename(root)

        if len(self.input_paths) == 0 or len(self.label_paths) == 0:
            raise Exception("No images/labels are found in {}".format(self.root))
        # Pairing is positional, so differing counts would silently mismatch pairs.
        if len(self.input_paths) != len(self.label_paths):
            raise Exception("Found {} images but {} labels in {}".format(
                len(self.input_paths), len(self.label_paths), self.root))

    def __getitem__(self, index):

        image = Image.open(self.input_paths[index]).convert('RGB')
        label = Image.open(self.label_paths[index]).convert('P')

        image = self.img_transform(image)
        label = self.label_transform(label)

        return image, label

    def __len__(self):
        return len(self.input_paths)
    
class Dataset_test(torch.utils.data.Dataset):
    """Test split: paired image/label JPEGs with no augmentation.

    Images are read from ``<root>/2017_np/test_data_jpg/*.jpg`` and labels from
    ``<root>/2017_np/test_lab_jpg/*.jpg``. The sorted lists are paired by position.
    Items are ``(image, label)`` built with ``Com_my`` + ``Normalize_my`` and
    ``ToTensor`` + ``ReLabel`` respectively.

    Raises:
        Exception: if ``root`` does not exist, if either file list is empty, or if the
            number of images and labels differ.
    """

    def __init__(self, root):

        self.root = root

        if not os.path.exists(self.root):
            raise Exception("[!] {} not exists.".format(root))

        self.img_transform = Compose([
            Com_my(),
            Normalize_my(),
        ])

        self.label_transform = Compose([
            ToTensor(),
            ReLabel(1),
        ])

        # Sort file names so images and labels line up by position.
        self.input_paths = sorted(glob(os.path.join(self.root, '{}/*.jpg'.format("2017_np/test_data_jpg"))))
        self.label_paths = sorted(glob(os.path.join(self.root, '{}/*.jpg'.format("2017_np/test_lab_jpg"))))
        self.name = os.path.basename(root)

        if len(self.input_paths) == 0 or len(self.label_paths) == 0:
            raise Exception("No images/labels are found in {}".format(self.root))
        # Pairing is positional, so differing counts would silently mismatch pairs.
        if len(self.input_paths) != len(self.label_paths):
            raise Exception("Found {} images but {} labels in {}".format(
                len(self.input_paths), len(self.label_paths), self.root))

    def __getitem__(self, index):

        image = Image.open(self.input_paths[index]).convert('RGB')
        label = Image.open(self.label_paths[index]).convert('P')

        image = self.img_transform(image)
        label = self.label_transform(label)

        return image, label

    def __len__(self):
        return len(self.input_paths)

In [4]:
def loader(dataset, batch_size, num_workers=8, shuffle=True):
    """Wrap ``dataset`` in a ``torch.utils.data.DataLoader``.

    Args:
        dataset: any map-style dataset.
        batch_size: samples per batch.
        num_workers: worker processes used for loading (default 8).
        shuffle: whether to reshuffle the order every epoch (default True).

    Returns:
        The configured DataLoader.
    """
    return torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )

# All three splits live under the same data root; only the training split is shuffled.
data_root = '../../'
train_loader = loader(Dataset(data_root), batch_size=batch_size)
test_loader = loader(Dataset_test(data_root), batch_size=batch_size, shuffle=False)
val_loader = loader(Dataset_val(data_root), batch_size=batch_size, shuffle=False)

In [5]:
# Use the second GPU when the machine has more than one, otherwise the first GPU, and
# fall back to CPU. A fixed "cuda:1" would fail with an invalid device ordinal on
# single-GPU machines as soon as something is moved to it.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")

# Environment sanity checks: GPU visibility, then the training loader's batch size.
cuda_visible = torch.cuda.is_available()
print(cuda_visible)
print(train_loader.batch_size)

True
6


In [6]:
#High-level Feature Enhancement Module

class Vox_Att(nn.Module):
    """Residual attention gate: returns ``in_x * w + in_x`` for a learned map ``w``.

    ``w`` has the same shape as ``in_x``. It is produced by a 1x1 reduce (to half the
    channels), a dilated 3x3 (dilation 5, spatial size preserved) and a 1x1 expand.
    Each stage is followed by GroupNorm(8) and ReLU, and the result goes through a final
    Sigmoid.

    NOTE(review): since the last stage ends in ReLU, the sigmoid input is >= 0, so ``w``
    lies in [0.5, 1]. Confirm this is intended.

    Args:
        in_ch: input/output channels; both ``in_ch`` and ``in_ch / 2`` must be divisible
            by 8 for GroupNorm.
    """

    def __init__(self, in_ch):
        super().__init__()
        hidden = int(in_ch / 2)

        def stage(cin, cout, **conv_kwargs):
            # conv -> GroupNorm(8) -> ReLU, registered as Sequential indices 0, 1, 2.
            return nn.Sequential(
                nn.Conv2d(cin, cout, **conv_kwargs),
                nn.GroupNorm(8, cout),
                nn.ReLU(inplace=True),
            )

        self.conv1 = stage(in_ch, hidden, kernel_size=1, padding=0)
        self.conv2 = stage(hidden, hidden, kernel_size=3, padding=5, dilation=5)
        self.conv3 = stage(hidden, in_ch, kernel_size=1, padding=0)
        self.output = nn.Sigmoid()

    def forward(self, in_x):
        weights = self.output(self.conv3(self.conv2(self.conv1(in_x))))
        return in_x * weights + in_x
    

In [7]:
class double_conv(nn.Module):
    """Two stacked (3x3 conv -> GroupNorm(8) -> ReLU) stages; padding=1 keeps H and W.

    Args:
        in_ch: input channels.
        out_ch: output channels (must be divisible by 8 for GroupNorm).
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = []
        for stage_in in (in_ch, out_ch):
            layers += [
                nn.Conv2d(stage_in, out_ch, 3, padding=1),
                nn.GroupNorm(8, out_ch),
                nn.ReLU(inplace=True),
            ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)
    
    
class double_conv_HL(nn.Module):
    """Two stacked (3x3 conv -> GroupNorm(8) -> ReLU) stages used as the feature branch
    of High_to_Low; padding=1 keeps H and W.

    Args:
        in_ch: input channels.
        out_ch: output channels (must be divisible by 8 for GroupNorm).
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        first = [nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.GroupNorm(8, out_ch), nn.ReLU(inplace=True)]
        second = [nn.Conv2d(out_ch, out_ch, 3, padding=1), nn.GroupNorm(8, out_ch), nn.ReLU(inplace=True)]
        self.conv = nn.Sequential(*first, *second)

    def forward(self, x):
        return self.conv(x)
    
class single_conv(nn.Module):
    """One 3x3 conv -> GroupNorm(8) -> ReLU stage; padding=1 keeps H and W.

    Args:
        in_ch: input channels.
        out_ch: output channels (must be divisible by 8 for GroupNorm).
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        conv = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        norm = nn.GroupNorm(8, out_ch)
        self.conv = nn.Sequential(conv, norm, nn.ReLU(inplace=True))

    def forward(self, x):
        return self.conv(x)

    
#Low-level Feature Enhancement Module  

class High_map(nn.Module):
    """Small two-level encoder/decoder that produces a 1-channel sigmoid map.

    Encoder: single_conv at full, 1/2 and 1/4 resolution (2x max-pooling between).
    Decoder: two 2x transposed convolutions, each concatenated after the matching
    encoder feature and fused with a single_conv. A final 1x1 conv and Sigmoid give
    an (N, 1, H, W) map with values in [0, 1].

    NOTE(review): H and W presumably must be divisible by 4 for the skip
    concatenations to line up (two pool/upsample levels).
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        wide = out_ch * 2

        # Encoder
        self.Sconv1 = single_conv(in_ch, out_ch)
        self.pool1 = nn.MaxPool2d(2)
        self.Sconv2 = single_conv(out_ch, out_ch)
        self.pool2 = nn.MaxPool2d(2)
        self.Sconv3 = single_conv(out_ch, wide)

        # Decoder
        self.Tconv1 = nn.ConvTranspose2d(wide, out_ch, kernel_size=2, stride=2, padding=0)
        self.Sconv4 = single_conv(wide, out_ch)
        self.Tconv2 = nn.ConvTranspose2d(out_ch, out_ch, kernel_size=2, stride=2, padding=0)
        self.Sconv5 = single_conv(wide, out_ch)

        # Head
        self.conv6 = nn.Conv2d(out_ch, 1, kernel_size=1, stride=1, padding=0)
        self.output = nn.Sigmoid()

    def forward(self, x):
        enc1 = self.Sconv1(x)
        enc2 = self.Sconv2(self.pool1(enc1))
        bottom = self.Sconv3(self.pool2(enc2))

        # The skip feature goes first, then the upsampled one.
        dec2 = self.Sconv4(torch.cat([enc2, self.Tconv1(bottom)], dim=1))
        dec1 = self.Sconv5(torch.cat([enc1, self.Tconv2(dec2)], dim=1))

        return self.output(self.conv6(dec1))

      
class High_to_Low(nn.Module):
    """Encoder stage: double_conv_HL features modulated by a High_map gate.

    Computes ``Sconv1(F * M + F)``, where ``F = Dconv1(x)`` has ``out_ch`` channels and
    ``M = High_Map(x)`` is a 1-channel sigmoid map broadcast across those channels.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.Dconv1 = double_conv_HL(in_ch, out_ch)
        self.High_Map = High_map(in_ch, out_ch)
        self.Sconv1 = single_conv(out_ch, out_ch)

    def forward(self, x):
        features = self.Dconv1(x)
        gate = self.High_Map(x)
        return self.Sconv1(features * gate + features)

In [8]:
# Input channels fed to G_unet: Com_my stacks one 'F' channel in front of the RGB channels.
# NOTE(review): this name shadows the builtin input() for the rest of the notebook.
input = 4
# Base feature width; the encoder doubles it at every stage.
numf = 16


class G_unet(nn.Module):
    """U-Net style segmentation network returning a 1-channel sigmoid map.

    Encoder: four High_to_Low stages (numf, 2*numf, 4*numf, 8*numf channels), each
    followed by 2x max-pooling, then a double_conv bottleneck with 16*numf channels.
    Decoder: four steps of 2x transposed convolution, concatenation with the matching
    encoder output (skip first), a Vox_Att gate and a double_conv.
    Channel counts are read from the module-level ``input`` and ``numf`` at construction.

    NOTE(review): H and W presumably need to be multiples of 32 (four pools here plus two
    more inside High_map at the deepest stage) for the skip concatenations to line up.
    """

    def __init__(self):
        super().__init__()
        f = numf

        # Encoder
        self.Dconv1 = High_to_Low(input, f)
        self.pool1 = nn.MaxPool2d(2)
        self.Dconv2 = High_to_Low(f, f * 2)
        self.pool2 = nn.MaxPool2d(2)
        self.Dconv3 = High_to_Low(f * 2, f * 4)
        self.pool3 = nn.MaxPool2d(2)
        self.Dconv4 = High_to_Low(f * 4, f * 8)
        self.pool4 = nn.MaxPool2d(2)

        # Bottleneck
        self.Dconv5 = double_conv(f * 8, f * 16)

        # Decoder: upsample -> gate the concatenation -> fuse
        self.Tconv1 = nn.ConvTranspose2d(f * 16, f * 8, kernel_size=2, stride=2, padding=0)
        self.ACSgate1 = Vox_Att(f * 16)
        self.Dconv6 = double_conv(f * 16, f * 8)

        self.Tconv2 = nn.ConvTranspose2d(f * 8, f * 4, kernel_size=2, stride=2, padding=0)
        self.ACSgate2 = Vox_Att(f * 8)
        self.Dconv7 = double_conv(f * 8, f * 4)

        self.Tconv3 = nn.ConvTranspose2d(f * 4, f * 2, kernel_size=2, stride=2, padding=0)
        self.ACSgate3 = Vox_Att(f * 4)
        self.Dconv8 = double_conv(f * 4, f * 2)

        self.Tconv4 = nn.ConvTranspose2d(f * 2, f, kernel_size=2, stride=2, padding=0)
        self.ACSgate4 = Vox_Att(f * 2)
        self.Dconv9 = double_conv(f * 2, f)

        # Head
        self.conv19 = nn.Conv2d(f, 1, kernel_size=1, stride=1, padding=0)
        self.output = nn.Sigmoid()

    @staticmethod
    def _up_step(x, skip, up, gate, fuse):
        """One decoder step: ``fuse(gate(cat([skip, up(x)], dim=1)))``."""
        return fuse(gate(torch.cat([skip, up(x)], dim=1)))

    def forward(self, input):
        c1 = self.Dconv1(input)
        c2 = self.Dconv2(self.pool1(c1))
        c3 = self.Dconv3(self.pool2(c2))
        c4 = self.Dconv4(self.pool3(c3))
        bottleneck = self.Dconv5(self.pool4(c4))

        d = self._up_step(bottleneck, c4, self.Tconv1, self.ACSgate1, self.Dconv6)
        d = self._up_step(d, c3, self.Tconv2, self.ACSgate2, self.Dconv7)
        d = self._up_step(d, c2, self.Tconv3, self.ACSgate3, self.Dconv8)
        d = self._up_step(d, c1, self.Tconv4, self.ACSgate4, self.Dconv9)

        return self.output(self.conv19(d))
        

In [9]:
# Build the network, move its parameters to the selected device, and show its layers.
CFENet = G_unet()
CFENet.to(device)
print(CFENet)

G_unet(
  (Dconv1): High_to_Low(
    (Dconv1): double_conv_HL(
      (conv): Sequential(
        (0): Conv2d(4, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): GroupNorm(8, 16, eps=1e-05, affine=True)
        (2): ReLU(inplace)
        (3): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): GroupNorm(8, 16, eps=1e-05, affine=True)
        (5): ReLU(inplace)
      )
    )
    (High_Map): High_map(
      (Sconv1): single_conv(
        (conv): Sequential(
          (0): Conv2d(4, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): GroupNorm(8, 16, eps=1e-05, affine=True)
          (2): ReLU(inplace)
        )
      )
      (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (Sconv2): single_conv(
        (conv): Sequential(
          (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): GroupNorm(8, 16, eps=1e-05, affine=True)
          (2): ReLU(inplace)

In [10]:
class DiceLoss(nn.Module):
    """Soft Dice loss pooled over every element of the batch.

    ``loss = 1 - (2 * sum(p * y) + smooth) / (sum(p) + sum(y) + smooth)``

    The sums run over all elements, so this is a single batch-level Dice rather than
    a per-sample mean. ``eps`` is stored but not used by ``forward``.
    """

    def __init__(self, smooth=1, eps=1e-7):
        super().__init__()
        self.smooth, self.eps = smooth, eps

    def forward(self, preds, labels):
        overlap = (preds * labels).sum()
        total = preds.sum() + labels.sum()
        return 1 - (2 * overlap + self.smooth) / (total + self.smooth)


class jaccard_Loss(nn.Module):
    """Soft Jaccard (IoU) loss pooled over every element of the batch.

    ``loss = 1 - (I + smooth) / (sum(p) + sum(y) - I + smooth)`` with ``I = sum(p * y)``.
    ``eps`` is stored but not used by ``forward``.
    """

    def __init__(self, smooth=1, eps=1e-7):
        super().__init__()
        self.smooth, self.eps = smooth, eps

    def forward(self, preds, labels):
        overlap = (preds * labels).sum()
        union = preds.sum() + labels.sum() - overlap
        return 1 - (overlap + self.smooth) / (union + self.smooth)
    

def jaccard_coef(preds, labels):
    """Smoothed Jaccard index over all elements: ``(I + 1) / (sum(p) + sum(y) - I + 1)``, ``I = sum(p * y)``."""
    smooth = 1
    overlap = torch.sum(preds * labels)
    union = torch.sum(preds) + torch.sum(labels) - overlap
    return (overlap + smooth) / (union + smooth)
          
def dice_coef(preds, labels):
    """Smoothed Dice coefficient over all elements: ``(2 * sum(p * y) + 1) / (sum(p) + sum(y) + 1)``."""
    smooth = 1
    overlap = torch.sum(preds * labels)
    total = torch.sum(preds) + torch.sum(labels)
    return (2 * overlap + smooth) / (total + smooth)

  
def com_evaluation(preds, labels):
    """Per-sample Jaccard, Dice and pixel accuracy for binary masks, averaged over the batch.

    Args:
        preds: binary (0/1) predictions whose first dimension is the batch.
        labels: binary (0/1) ground truth with the same shape as ``preds``.

    Returns:
        ``(jaccard, dice, accuracy)`` as 0-d tensors. Each is the mean over the samples
        in the batch. Jaccard and Dice use additive smoothing of 1.

    The batch size and per-sample pixel count are taken from the tensors themselves.
    A partial final batch therefore no longer indexes past the end, and masks of any
    size are handled.
    """
    smooth = 1

    n = preds.shape[0]
    flat_preds = preds.reshape(n, -1)
    flat_labels = labels.reshape(n, -1)
    num_pixels = flat_labels.shape[1]

    intersection = (flat_labels * flat_preds).sum(dim=1)
    num_pred = flat_preds.sum(dim=1)
    num_lab = flat_labels.sum(dim=1)
    union = num_pred + num_lab - intersection

    jaccard = (intersection + smooth) / (union + smooth)
    dice = (2 * intersection + smooth) / (num_pred + num_lab + smooth)
    # For 0/1 masks: TP is the intersection, and TN is every pixel outside the union.
    accuracy = (intersection + (num_pixels - union)) / num_pixels

    return jaccard.mean(), dice.mean(), accuracy.mean()

In [11]:
# Dice loss on the sigmoid output. Adam at lr 1e-4; StepLR multiplies the lr by 0.9
# every 20 scheduler steps.
Criterion = DiceLoss()
C_optimizer = optim.Adam(CFENet.parameters(), lr=1e-4)
scheduler = optim.lr_scheduler.StepLR(C_optimizer, step_size=20, gamma=0.9)

In [None]:
# Training: EPOCHS passes over train_loader. Every EVAL_EVERY iterations the model is
# scored on the test and validation splits. It is checkpointed whenever the validation
# Jaccard matches or beats every earlier round.
EPOCHS = 300
EVAL_EVERY = 100
CHECKPOINT_PATH = "model/HL-LH_net.pkl"


def _evaluate_split(data_loader):
    """Score CFENet on ``data_loader`` with predictions binarised at 0.5.

    Returns ``(jaccard_sum, dice_sum, accuracy_sum, n_batches)``: the per-batch values
    from ``com_evaluation`` summed over the loader, plus the number of batches seen.
    The model is put in eval mode for the pass and returned to train mode afterwards.
    """
    jaccard_sum = dice_sum = accuracy_sum = 0.0
    n_batches = 0
    CFENet.eval()
    try:
        with torch.no_grad():
            for eval_inputs, eval_labels in data_loader:
                eval_inputs = eval_inputs.to(device)
                eval_labels = eval_labels.to(device)

                probs = CFENet(eval_inputs)
                eval_preds = (probs >= 0.5).to(probs.dtype)

                jaccard, dice, accuracy = com_evaluation(eval_preds, eval_labels)
                jaccard_sum += float(jaccard)
                dice_sum += float(dice)
                accuracy_sum += float(accuracy)
                n_batches += 1
    finally:
        CFENet.train()
    return jaccard_sum, dice_sum, accuracy_sum, n_batches


# torch.save does not create missing directories.
os.makedirs(os.path.dirname(CHECKPOINT_PATH), exist_ok=True)

# Summed validation Jaccard of every evaluation round (kept for later inspection).
val_jaccard_history = []

for epoch in range(EPOCHS):

    # Learning rate in effect for this epoch.
    for param_group in C_optimizer.param_groups:
        print(param_group['lr'])

    run_dice_loss = 0.0
    running_jaccard = 0.0
    train_i = 0

    for inputs, labels in train_loader:

        inputs = inputs.to(device)
        labels = labels.to(device)

        pre_labs = CFENet(inputs)
        dice_loss = Criterion(pre_labs, labels)

        C_optimizer.zero_grad()
        dice_loss.backward()
        C_optimizer.step()

        # Accumulate Python floats. Summing the tensors instead would chain every
        # iteration's Jaccard computation (and the predictions it saved) into one
        # autograd graph that stays alive until the next reset.
        with torch.no_grad():
            running_jaccard += jaccard_coef(pre_labs, labels).item()
        run_dice_loss += dice_loss.item()

        train_i += 1

        if train_i == EVAL_EVERY:

            testing_jaccard, testing_dice, testing_accuracy, test_i = _evaluate_split(test_loader)
            valing_jaccard, valing_dice, valing_accuracy, val_i = _evaluate_split(val_loader)

            val_jaccard_history.append(valing_jaccard)

            if max(val_jaccard_history) == valing_jaccard:

                torch.save(CFENet, CHECKPOINT_PATH)
                print("-                                     -")
                print("------- Save %d epoch model---------" % epoch)

            print('e: %d, d_loss: %.5f, tr_jacc: %.5f, val_dice: %.5f, val_jacc: %.5f, val_acc: %.5f,' %
                  (epoch, run_dice_loss/train_i, running_jaccard/train_i, valing_dice/val_i, valing_jaccard/val_i, valing_accuracy/val_i))
            print('test_dice: %.5f, test_jacc: %.5f, test_acc: %.5f,' %
                  (testing_dice/test_i, testing_jaccard/test_i, testing_accuracy/test_i))
            print('-----------------------------------------------')

            run_dice_loss = 0.0
            running_jaccard = 0.0
            train_i = 0

    # Advance the StepLR schedule once per epoch, after this epoch's optimizer.step()
    # calls; PyTorch >= 1.1 expects scheduler.step() to follow optimizer.step().
    scheduler.step()

print("Finished Training")