In [22]:
import cv2
import os
from numpy import *
import glob
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import torch.nn as nn
import torch
import torch.nn.functional as F
import shutil
import sys,time

In [23]:
total_data_path = 'D:/workspace/Projects/vae/PedCut2013_SegmentationDataset/data/completeData/left_images/'
total_label_path = 'D:/workspace/Projects/vae/PedCut2013_SegmentationDataset/data/completeData/left_groundTruth/'
val_data_path = 'D:/workspace/Projects/vae/PedCut2013_SegmentationDataset/data/validationData/left_images'
val_label_path = 'D:/workspace/Projects/vae/PedCut2013_SegmentationDataset/data/validationData/left_groundTruth'

test_data_path = 'D:/workspace/Projects/vae/PedCut2013_SegmentationDataset/data/testData/left_images'
test_label_path = 'D:/workspace/Projects/vae/PedCut2013_SegmentationDataset/data/testData/left_groundTruth'
total_pngs = glob.glob(total_data_path+"/*.png")
val_pngs = glob.glob(val_data_path+"/*.png")
test_pngs = glob.glob(test_data_path+"/*.png")
val_imgs = os.listdir(val_data_path)
test_imgs = os.listdir(test_data_path)
train_path = 'D:/workspace/Projects/vae/PedCut2013_SegmentationDataset/data/trainData/'
train_pngs = []
max_pixels = 0
min_pixels = inf
pixels = []
for png in total_pngs:
    tag = png.split('\\')[-1]
    if (tag not in val_imgs) and (tag not in test_imgs):
        src_ground_truth = png.replace("left_images","left_groundTruth")
        src_image = png
        dst_ground_truth = train_path+"left_groundTruth/" + tag
        dst_image = train_path+"left_images/"+tag
        train_pngs.append(png)
        shutil.copy(src=src_ground_truth,dst=dst_ground_truth)
        shutil.copy(src=src_image,dst=dst_image)
        
#     img = cv2.imread(png)
#     p = img.shape[0]*img.shape[1]
#     if max_pixels < p: max_pixels = p
#     if min_pixels > p: min_pixels = p
#     pixels.append([img.shape[0], img.shape[1]])

# pixels = array(pixels)
# ratios = []
# for height, width in (zip(pixels[:,0],pixels[:,1])):
#     ratio = height/width
#     ratios.append(ratio)

# ratio = sorted(ratios)[int(len(ratios)/2)]

In [25]:
def normalize(volume):
    mean = volume.mean()
    std = volume.std()
    out = (volume - mean)/(std+1e-20)
    out_random = zeros(volume.shape)
    out[volume == 0] = out_random[volume ==0]
    return out


def default_loader(image, label, ratio, width, train_set): 

    image = cv2.imread(image)
    mask = cv2.imread(label)
    new_size = (int(width * ratio), width)
    old_height, old_width, _ = image.shape
    assert(old_height == mask.shape[0])
    assert(old_width == mask.shape[1])
    
    height_ = int(old_width * ratio)
    if old_height < height_:
        padding = height_ - old_height
        padding_array = zeros((padding, old_width,3))
        image = concatenate((image, padding_array),0)
        mask = concatenate((mask, padding_array), 0)
    else:
        cutting = int(round((old_height - height_)/2))
        image = image[cutting: old_height-cutting]
        mask = mask[cutting: old_height-cutting]
    
    x_ratio = new_size[1] / image.shape[1]
    y_ratio = new_size[0] / image.shape[0]
    
    image = cv2.resize(image, (0, 0), fx=x_ratio, fy=y_ratio, interpolation=cv2.INTER_NEAREST)
    mask = cv2.resize(mask, (0, 0), fx=x_ratio, fy=y_ratio, interpolation=cv2.INTER_NEAREST)
    
    image = image.transpose(2, 0, 1)
    mask = mask[:,:,0].astype(float)
    mask[mask==255] = 1.
    return normalize(image), mask


class Data(Dataset):
    def __init__(self, files, ratio, width, label_path, train_set, loader = default_loader):
        
        image_label = []
        for file_ in files:
            tag = file_.split('\\')[-1]
            mask_name = label_path + tag
            image_label.append((file_, mask_name))
        self.image_label = image_label
        self.train_set = train_set
        self.loader = loader
        self.ratio = ratio
        self.width = width
    
    def __getitem__(self, index):
        image, label = self.image_label[index]
        image, label = self.loader(image, label, self.ratio, self.width, self.train_set)
        
        return image, label
    
    def __len__(self):
        return len(self.image_label)
    


# Models

In [26]:

class Encoder(nn.Module):
    def __init__(self, ):
        super(Encoder, self).__init__()
        #conv1
        self.conv1_1 = nn.Conv2d(3, 16, 3,padding=1)
        self.relu1_1 = nn.ReLU(inplace=True)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)  # 1/2
        
        # conv2
        self.conv2_1 = nn.Conv2d(16, 32, 3, padding=1)
        self.relu2_1 = nn.ReLU(inplace=True)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)  # 1/4
        
        # conv3
        self.conv3_1 = nn.Conv2d(32, 64, 3, padding=1)
        self.relu3_1 = nn.ReLU(inplace=True)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)  # 1/8
        
        
    def forward(self, x):
        
        x = self.conv1_1(x)
        
        x = self.relu1_1(x)
        x = self.pool1(x)
        
        x = self.conv2_1(x)
        x = self.relu2_1(x)
        x = self.pool2(x)
        
        x = self.conv3_1(x)
        x = self.relu3_1(x)
        x = self.pool3(x)
        
        
        return x
    
    
class Decoder(nn.Module):
    def __init__(self, ):
        super(Decoder, self).__init__()
        
        # upConv3
        self.upsample1 = nn.Upsample(scale_factor=2)
        self.upconv1_1 = nn.Conv2d(64,32, 3,padding=1)
        self.BN1 = nn.BatchNorm2d(32)
        self.uprelu1_1 = nn.ReLU(inplace=True)
        
        # upConv2
        self.upsample2 = nn.Upsample(scale_factor=2)
        self.upconv2_1 = nn.Conv2d(32,32, 3,padding=1)
        self.BN2 = nn.BatchNorm2d(32)
        self.uprelu2_1 = nn.ReLU(inplace=True)
        
        
        # upConv1
        self.upsample3 = nn.Upsample(scale_factor=2)
        self.upconv3_1 = nn.Conv2d(32,32, 3,padding=1)
        self.BN3 = nn.BatchNorm2d(32)
        self.uprelu3_1 = nn.ReLU(inplace=True)
        
        self.conv = nn.Conv2d(32, 2, 1)
        
    def forward(self, x):
        
        x = self.upsample1(x)
        x = self.upconv1_1(x)
        x = self.BN1(x)
        x = self.uprelu1_1(x)
        
        x = self.upsample2(x)
        x = self.upconv2_1(x)
        x = self.BN2(x)
        x = self.uprelu2_1(x)
        
        x = self.upsample3(x)
        x = self.upconv3_1(x)
        x = self.BN3(x)
        x = self.uprelu3_1(x)
        
        x = self.conv(x)
        
        return x

class net(nn.Module):
    def __init__(self, ):
        super(net, self).__init__()
        
        self.encoder = Encoder()
        self.decoder = Decoder()
        
    def forward(self,x):
        
        x = self.encoder(x)
        x = self.decoder(x)
        return x

In [27]:
model = net()
x = torch.ones((8,3,200,96))
y = model(x)
print(y.shape)

torch.Size([8, 2, 200, 96])


Eval function

In [56]:
def torch_binary_dice(input, target, eps = 0.000001):
    assert(input.shape == target.shape)
    input = input.float()
    target = target.float()
    N = target.size(0)
    input_flat = input.view(N, -1)
    target_flat = target.view(N, -1)
    intersaction = input_flat * target_flat
    dice = (2*intersaction.sum(1) + eps) / (input_flat.sum(1)+target_flat.sum(1)+eps)
    return dice.sum()/N

class multiclass_dice_loss(nn.Module):
    def __init__(self):
        super(multiclass_dice_loss, self).__init__()
        
    def forward(self, input, target):
        input = input[:,1,...]
        input = input.unsqueeze(1)
        
        dice = torch_binary_dice(input, target, eps=0.000001)
        return 1 - dice

    
def evaluation(model, val_loader, epoch, dice_loss):
    model.eval()
    total_dice = 0
    num_batch_processed=0
    total_loss = 0
    for step,(b_x,gt) in enumerate(val_loader):
        
        b_x.resize_(b_x.shape[0],3,b_x.shape[2],b_x.shape[3])
        gt.resize_(gt.shape[0],1,gt.shape[1],gt.shape[2])
        
        b_x = b_x.float()
        gt = gt.float()
        output = model(b_x)
        output_softmax = F.softmax(output, dim=1)
        loss= dice_loss(output_softmax,gt)
        dice = 1-loss
        total_dice += dice
        total_loss += loss
        num_batch_processed += 1
    mean_dice = total_dice/(num_batch_processed+1)
    mean_loss = total_loss/(num_batch_processed+1)
        
    return mean_dice, mean_loss
    

Train

In [61]:
def trainer(model, train_data,val_data, batch_size, lr):
    train_loader = DataLoader(dataset=train_data, batch_size=batch_size,shuffle=True)
    val_loader = DataLoader(dataset=val_data, batch_size=1,shuffle=False)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr,weight_decay=5e-4)
    dice_loss = multiclass_dice_loss()
    best_eval_loss = 999
    total_iterations_per_epoch = (train_data.__len__()//batch_size) + 1
    
    for epoch in range(20):
        print('    ----- training epoch {} -----'.format(epoch+1))
        model.train()
        loss_epoch = 0.0
        num_batch_processed = 0
        
        for step, (b_x, b_y) in enumerate(train_loader):
            
            percentage = int(step/total_iterations_per_epoch*100)
            str = '>'*(step//2)+' '*((total_iterations_per_epoch-step)//2)
            sys.stdout.write('\r'+str+'[%s%%]'%(percentage+1))
            sys.stdout.flush()
            
            b_x.resize_(b_x.shape[0],3,b_x.shape[2],b_x.shape[3])
            b_y.resize_(b_y.shape[0],1,b_y.shape[1],b_y.shape[2])
            
            b_x = b_x.float()
            b_y = b_y.float()
            optimizer.zero_grad()
            output = model(b_x)
            
            output_softmax = F.softmax(output, dim=1)
            loss = dice_loss(output_softmax,b_y)
            loss.backward()
            optimizer.step()
            loss_epoch += loss.item()
            num_batch_processed += 1
            
        loss_epoch /= num_batch_processed
        print("epoch {} train loss: {} train dice {}\n".format(epoch+1, loss, 1-loss))
        eval_dice,eval_loss = evaluation(model, val_loader, epoch, dice_loss)
        print("epoch {} validation loss: {} eval dice {}\n".format(epoch+1, eval_loss, eval_dice))

In [62]:
train_path = 'D:/workspace/Projects/vae/PedCut2013_SegmentationDataset/data/trainData/left_images/'
train_label_path = 'D:/workspace/Projects/vae/PedCut2013_SegmentationDataset/data/trainData/left_groundTruth/'
train_pngs = glob.glob(train_path+'*.png')
val_path = 'D:/workspace/Projects/vae/PedCut2013_SegmentationDataset/data/validationData/left_images/'
val_label_path = 'D:/workspace/Projects/vae/PedCut2013_SegmentationDataset/data/validationData/left_groundTruth/'
val_pngs = glob.glob(val_path+'*.png')
train_data = Data(files=train_pngs, ratio =2.09, width=96,label_path=train_label_path, train_set=True, loader = default_loader)
val_data = Data(files=val_pngs, ratio =2.09, width=96, label_path=val_label_path, train_set=True, loader = default_loader)
model = net()
trainer(model, train_data,val_data, batch_size = 8, lr = 0.0001)

    ----- training epoch 1 -----
>>>>>>>>>>>>>>>>>>>>>>>>>>>>[99%]epoch 1 train loss: 0.39857012033462524 train dice 0.6014298796653748

epoch 1 validation loss: 0.5208099484443665 eval dice 0.446931928396225

    ----- training epoch 2 -----
>>>>>>>>>>>>>>>>>>>>>>>>>>>>[99%]epoch 2 train loss: 0.35781192779541016 train dice 0.6421880722045898

epoch 2 validation loss: 0.39401817321777344 eval dice 0.5737237334251404

    ----- training epoch 3 -----
>>>>>>>>>>>>>>>>>>>>>>>>>>>>[99%]epoch 3 train loss: 0.32251864671707153 train dice 0.6774813532829285

epoch 3 validation loss: 0.37292781472206116 eval dice 0.5948140621185303

    ----- training epoch 4 -----
>>>>>>>>>>>>>>>>>>>>>>>>>>>>[99%]epoch 4 train loss: 0.32284027338027954 train dice 0.6771597266197205

epoch 4 validation loss: 0.3527725338935852 eval dice 0.6149693727493286

    ----- training epoch 5 -----
>>>>>>>>>>>>>>>>>>>>>>>>>>>>[99%]epoch 5 train loss: 0.2338932752609253 train dice 0.7661067247390747

epoch 5 validation 