In [None]:
import pandas as pd
import numpy as np
import os
import glob
from PIL import Image
import cv2
from skimage import transform
import matplotlib.pyplot as plt
import random

import torch
import torch.nn as nn
import torch.utils as utils
import torch.nn.init as init
import torchvision.utils
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.autograd import Variable

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from skimage.transform import resize

from sklearn.metrics import *
import logging 
log = logging.getLogger("basic")

In [None]:
# Training images and masks. glob returns files in arbitrary (filesystem) order, so both lists
# are sorted: CustomDataset pairs image i with mask i, which only works if the orders agree
# (assumes matching image/mask files share sortable names — TODO confirm naming scheme).
train_image = sorted(glob.glob(os.path.join('/data/train_data/balanced_image/', '*.jpg')))
train_mask = sorted(glob.glob(os.path.join('/data/train_data/balanced_mask/', '*.jpg')))
assert len(train_image) == len(train_mask), "train images and masks differ in count"


In [None]:
# Test images and masks, sorted so that image i and mask i refer to the same sample
# (glob order is arbitrary; assumes paired files share sortable names — TODO confirm).
test_image = sorted(glob.glob(os.path.join('/data/test_data/balanced_image/', '*.jpg')))
test_mask = sorted(glob.glob(os.path.join('/data/test_data/balanced_mask/', '*.jpg')))
assert len(test_image) == len(test_mask), "test images and masks differ in count"


In [None]:
# Every image and mask is resized to this (width, height) before conversion to a tensor.
width, height = 496, 384


class CustomDataset(Dataset):
    """Paired image/mask dataset for segmentation.

    Item ``i`` is built from ``image_paths[i]`` and ``mask_paths[i]``, so the two sequences must
    be aligned (same length, same ordering). Both files are converted to RGB, resized to
    ``size`` and returned as float tensors via ``transforms.ToTensor()``.

    Args:
        image_paths: paths of the input images.
        mask_paths: paths of the matching masks, in the same order as ``image_paths``.
        train: stored as ``self.train``; the loading logic does not use it.
        size: optional ``(width, height)``; defaults to the module-level ``(width, height)``.

    Raises:
        ValueError: if ``image_paths`` and ``mask_paths`` differ in length.
    """

    def __init__(self, image_paths, mask_paths, train=True, size=None):
        if len(image_paths) != len(mask_paths):
            raise ValueError(
                "image_paths and mask_paths must have the same length (got %d and %d)"
                % (len(image_paths), len(mask_paths))
            )
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.train = train
        self.size = tuple(size) if size is not None else (width, height)
        self.transforms = transforms.ToTensor()

    def _load(self, path, **resize_kwargs):
        """Open ``path``, convert to RGB, resize to ``self.size`` and return it as a numpy array."""
        # The context manager closes the underlying file instead of leaving it to the GC.
        with Image.open(path) as img:
            return np.array(img.convert('RGB').resize(self.size, **resize_kwargs))

    def __getitem__(self, index):
        image = self._load(self.image_paths[index])
        # Nearest-neighbour resampling keeps mask values from being blended across label edges.
        mask = self._load(self.mask_paths[index], resample=Image.NEAREST)
        return self.transforms(image), self.transforms(mask)

    def __len__(self):
        return len(self.image_paths)

In [None]:
# Shuffled training loader; `batch_size` is also used by the test loader.
batch_size = 4
train_dataset = CustomDataset(list(train_image), list(train_mask), train=True)
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=1,
)

In [None]:
# Test the loader
#for _,(images, masks) in enumerate(train_loader):
    #print (images.shape)
    #break


In [None]:
# Unshuffled loader over the test split, so per-batch metrics come out in a stable order.
# The dataset is flagged train=False (it was mislabelled train=True).
test_dataset = CustomDataset(test_image[:], test_mask[:], train=False)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=1)

In [None]:
def conv_block(in_dim, out_dim, act_fn):
    """3x3 convolution (stride 1, padding 1) -> BatchNorm2d -> ``act_fn``.

    The padding keeps the spatial size unchanged. ``act_fn`` is inserted as the given module
    instance, so one activation object can be shared across several blocks.
    """
    conv = nn.Conv2d(in_dim, out_dim, kernel_size=3, stride=1, padding=1)
    norm = nn.BatchNorm2d(out_dim)
    return nn.Sequential(conv, norm, act_fn)


def conv_trans_block(in_dim, out_dim, act_fn):
    """Stride-2 3x3 transposed convolution -> BatchNorm2d -> ``act_fn``.

    With padding=1 and output_padding=1 the layer exactly doubles height and width.
    """
    upsample = nn.ConvTranspose2d(
        in_dim, out_dim,
        kernel_size=3, stride=2, padding=1, output_padding=1,
    )
    norm = nn.BatchNorm2d(out_dim)
    return nn.Sequential(upsample, norm, act_fn)


def maxpool():
    """2x2 max-pooling with stride 2 and no padding (halves height and width, flooring)."""
    return nn.MaxPool2d(2, 2, 0)


def conv_block_2(in_dim, out_dim, act_fn):
    """Two 3x3 convolutions: ``conv_block`` (conv-BN-act) followed by conv-BN.

    Note that the second convolution is followed by BatchNorm2d only, with no activation.
    """
    first_stage = conv_block(in_dim, out_dim, act_fn)
    second_conv = nn.Conv2d(out_dim, out_dim, kernel_size=3, stride=1, padding=1)
    second_norm = nn.BatchNorm2d(out_dim)
    return nn.Sequential(first_stage, second_conv, second_norm)


def conv_block_3(in_dim, out_dim, act_fn):
    """Three 3x3 convolutions: two ``conv_block`` stages (conv-BN-act) then conv-BN.

    Like ``conv_block_2``, the final convolution has no activation after its BatchNorm2d.
    """
    stages = [conv_block(in_dim, out_dim, act_fn)]
    stages.append(conv_block(out_dim, out_dim, act_fn))
    stages.append(nn.Conv2d(out_dim, out_dim, kernel_size=3, stride=1, padding=1))
    stages.append(nn.BatchNorm2d(out_dim))
    return nn.Sequential(*stages)

class UnetGenerator(nn.Module):
    """U-Net with four down-sampling levels and a sigmoid output layer.

    Args:
        in_dim: number of input channels.
        out_dim: number of output channels.
        num_filter: channel width of the first level; level k uses ``num_filter * 2**k``.

    Each encoder level is ``conv_block_2`` followed by 2x2 max-pooling. Each decoder level is a
    stride-2 transposed convolution, concatenation with the matching encoder feature map
    (skip connection), then ``conv_block_2``. The output goes through ``nn.Sigmoid``.

    If the input height or width is not divisible by 16, pooling floors the size. The encoder
    feature maps are then cropped (top-left) to the decoder's size before concatenation, and
    the output is correspondingly smaller than the input.
    """

    def __init__(self,in_dim,out_dim,num_filter):
        super(UnetGenerator,self).__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.num_filter = num_filter
        # A single activation instance is shared by all blocks (it holds no parameters).
        act_fn = nn.LeakyReLU(0.2, inplace=True)

        # Encoder
        self.down_1 = conv_block_2(self.in_dim,self.num_filter,act_fn)
        self.pool_1 = maxpool()
        self.down_2 = conv_block_2(self.num_filter*1,self.num_filter*2,act_fn)
        self.pool_2 = maxpool()
        self.down_3 = conv_block_2(self.num_filter*2,self.num_filter*4,act_fn)
        self.pool_3 = maxpool()
        self.down_4 = conv_block_2(self.num_filter*4,self.num_filter*8,act_fn)
        self.pool_4 = maxpool()

        self.bridge = conv_block_2(self.num_filter*8,self.num_filter*16,act_fn)

        # Decoder: each up_k takes the upsampled map concatenated with the skip, hence 2x channels in.
        self.trans_1 = conv_trans_block(self.num_filter*16,self.num_filter*8,act_fn)
        self.up_1 = conv_block_2(self.num_filter*16,self.num_filter*8,act_fn)
        self.trans_2 = conv_trans_block(self.num_filter*8,self.num_filter*4,act_fn)
        self.up_2 = conv_block_2(self.num_filter*8,self.num_filter*4,act_fn)
        self.trans_3 = conv_trans_block(self.num_filter*4,self.num_filter*2,act_fn)
        self.up_3 = conv_block_2(self.num_filter*4,self.num_filter*2,act_fn)
        self.trans_4 = conv_trans_block(self.num_filter*2,self.num_filter*1,act_fn)
        self.up_4 = conv_block_2(self.num_filter*2,self.num_filter*1,act_fn)

        self.out = nn.Sequential(
            nn.Conv2d(self.num_filter,self.out_dim,3,1,1),
            nn.Sigmoid(),
        )

    @staticmethod
    def _crop_to(skip, ref):
        """Crop the last two (spatial) dims of ``skip`` to those of ``ref``, keeping the top-left."""
        return skip[:, :, :ref.shape[-2], :ref.shape[-1]]

    def forward(self,input):
        """Run the U-Net on a batch ``input`` of shape (N, in_dim, H, W)."""
        down_1 = self.down_1(input)
        down_2 = self.down_2(self.pool_1(down_1))
        down_3 = self.down_3(self.pool_2(down_2))
        down_4 = self.down_4(self.pool_3(down_3))

        up = self.bridge(self.pool_4(down_4))

        decoder = (
            (self.trans_1, self.up_1, down_4),
            (self.trans_2, self.up_2, down_3),
            (self.trans_3, self.up_3, down_2),
            (self.trans_4, self.up_4, down_1),
        )
        for trans, up_block, skip in decoder:
            upsampled = trans(up)
            # Crop both height and width so odd encoder sizes still concatenate cleanly.
            up = up_block(torch.cat([upsampled, self._crop_to(skip, upsampled)], dim=1))

        return self.out(up)

In [None]:
# Run on the first CUDA GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
device

In [None]:
# U-Net with 3 input channels, 3 output channels and 64 base filters, trained with Adam.
lr = 5e-4
model = UnetGenerator(in_dim=3, out_dim=3, num_filter=64).to(device, dtype=torch.float)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

In [None]:
# Two loss options: plain MSE (`loss_func`) and a column-weighted MSE (`weighted_mse_loss`).
# To try the weighted one, swap the loss call in the training loop.

# Weighted MSE: per-pixel weights of shape (3, HEIGHT, WIDTH), one plane per RGB channel.
HEIGHT, WIDTH = height, width
WEIGHTS = np.ones((3, HEIGHT, WIDTH))

# Column bounds are derived from WIDTH (124 / 372 at the default width of 496).
QUARTER = WIDTH // 4

# option1: left quarter and right quarter both weighted 0.25
#WEIGHTS[:, :, :QUARTER] *= 0.25
#WEIGHTS[:, :, WIDTH - QUARTER:] *= 0.25

# option2 (active): left quarter weighted 0.2, right quarter weighted 0.1
WEIGHTS[:, :, :QUARTER] *= 0.2
WEIGHTS[:, :, WIDTH - QUARTER:] *= 0.1

def weighted_mse_loss(inp, target, weights=None):
    """Mean of the element-wise squared error scaled by per-element weights.

    Args:
        inp: prediction tensor.
        target: ground-truth tensor with the same shape as ``inp``.
        weights: tensor or array that ``expand_as`` can broadcast to the shape of ``inp``.
            Defaults to the module-level ``WEIGHTS``.

    Returns:
        Scalar tensor ``mean(weights * (inp - target) ** 2)``.
    """
    # `if not weights` raises for any multi-element tensor, so compare against None explicitly.
    if weights is None:
        weights = WEIGHTS
    # Put the weights on the prediction's device and dtype (replaces the deprecated Variable wrapper).
    weights = torch.as_tensor(weights, dtype=inp.dtype, device=inp.device)
    out = (inp - target) ** 2
    out = out * weights.expand_as(out)
    return torch.mean(out)


In [None]:
# Default training loss: plain mean squared error between predicted and ground-truth masks.
loss_func = torch.nn.MSELoss()

In [None]:
def evaluate_f1score(model, test_loader, threshold=0.5):
    """Compute a micro-averaged F1 score for each batch of ``test_loader``.

    Predictions and ground-truth masks are both binarised at ``threshold`` and flattened over
    all channels and pixels before scoring.

    Args:
        model: segmentation network. It is put in eval mode and left in eval mode on return.
        test_loader: iterable of ``(images, masks)`` tensor batches.
        threshold: cut-off applied to both predictions and ground-truth masks.

    Returns:
        List with one F1 score per batch.
    """
    model.eval()
    f1_list = []
    # Inference only: skipping autograd avoids building (and holding) a graph per batch.
    with torch.no_grad():
        for images, masks in test_loader:
            predicted_labels = model(images.to(device))
            flat_pred = (predicted_labels.cpu().numpy().ravel() >= threshold).astype('int32')
            # Masks loaded from JPEGs are not exactly 0/1. A plain astype('int32') would
            # truncate e.g. 0.99 to 0, so threshold them the same way as the predictions.
            flat_label = (masks.cpu().numpy().ravel() >= threshold).astype('int32')
            # sklearn's argument order is f1_score(y_true, y_pred).
            f1_list.append(f1_score(flat_label, flat_pred, average="micro"))
    return f1_list

In [None]:
# Train for `epoch` epochs. After each epoch:
#  - the model is checkpointed to best_model.pth if its mean training loss is the lowest so far;
#  - the mean per-batch F1 on the test set is printed.
min_loss = float("inf")
epoch = 10

for i in range(epoch):
    model.train()
    epoch_loss_sum = 0.0
    n_batches = 0
    for batch_idx, (images, masks) in enumerate(train_loader):
        optimizer.zero_grad()

        x = images.to(device)
        y_ = masks.to(device)
        y = model(x)

        loss = loss_func(y, y_)  # alternative: weighted_mse_loss(y, y_)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        epoch_loss_sum += batch_loss
        n_batches += 1
        if batch_idx % 10 == 0:
            print("Epoch: %s Batch ID: %s Loss: %s" % (i, batch_idx, batch_loss))

    if n_batches == 0:
        raise RuntimeError("train_loader yielded no batches; check the training data paths")

    # Checkpoint on the epoch-mean loss; the last batch's loss alone is a noisy signal.
    numerical_loss = epoch_loss_sum / n_batches
    if numerical_loss < min_loss:
        min_loss = numerical_loss
        torch.save(model, 'best_model.pth')

    # Evaluate predictions on the test set using F1 score
    f1_list = evaluate_f1score(model, test_loader)
    print ("Evaluation F1", i, np.mean(f1_list))

In [None]:
### Get Formatted Results
# Run the model over the test set, print per-batch metrics, and collect per-sample
# (input image, predicted mask, ground-truth mask) arrays for plotting.
# NOTE(review): this uses the model as it is after the last epoch, not the best_model.pth
# checkpoint; reload the checkpoint first if the best model is what should be reported.

out_images = []
out_masks = []
orig_masks = []


def mse(A, B, ax=0):
    """Mean squared error between ``A`` and ``B`` after flattening both (``ax`` goes to ``.mean``)."""
    return np.square(A.flatten() - B.flatten()).mean(axis=ax)


model.eval()
with torch.no_grad():
    for batch_idx, (images, masks) in enumerate(test_loader):
        predicted_labels = model(images.to(device))
        # Convert each batch to numpy once, rather than once per sample.
        pred_np = predicted_labels.cpu().numpy()
        image_np = images.cpu().numpy()
        mask_np = masks.cpu().numpy()

        flat_pred = (pred_np.ravel() >= 0.5).astype('float32')
        flat_label = mask_np.ravel()
        # Ground-truth masks are binarised by thresholding (astype truncation would map 0.99 to 0);
        # sklearn's argument order is f1_score(y_true, y_pred).
        label_bin = (flat_label >= 0.5).astype('int32')
        pred_bin = flat_pred.astype('int32')
        f1 = f1_score(label_bin, pred_bin, average="micro")
        f1_weighted = f1_score(label_bin, pred_bin, average="weighted")
        batch_mse = mse(flat_pred, flat_label)
        print ("Batch: %s, F1.micro: %.4f, F1.weighted: %.4f, MSE: %s" % (batch_idx, f1, f1_weighted, batch_mse))

        out_images.extend(image_np)
        out_masks.extend(pred_np)
        orig_masks.extend(mask_np)

In [None]:
def get_img(im):
    '''Move the leading (channel) axis to position 2, e.g. (C, H, W) -> (H, W, C) for imshow.'''
    return np.moveaxis(im, 0, 2)

def show_image_and_mask(pred_mask_img, o_img, o_mask, idx, keep_cols=400,
                        output_dir='best_model_results', show=True):
    """Plot input, ground-truth mask, predicted mask and a column-pruned image, and save it.

    The fourth panel keeps the ``keep_cols`` image columns with the lowest summed predicted
    mask value, in their original left-to-right order.

    Args:
        pred_mask_img: predicted mask, channels-first array.
        o_img: input image, channels-first array.
        o_mask: ground-truth mask, channels-first array.
        idx: sample index, used as the output file name.
        keep_cols: number of columns kept in the fourth panel.
        output_dir: directory the figure is saved to (created if missing).
        show: display the figure before it is closed.
    """
    seam_pred_mask = np.mean(get_img(pred_mask_img), axis=2)  # average over channels
    orig_img = get_img(o_img)
    col_sum = np.sum(seam_pred_mask, axis=0)
    idx_to_keep = np.sort(col_sum.argsort()[:keep_cols])
    resized_img = orig_img[:, idx_to_keep]

    fig, ax = plt.subplots(nrows=1, ncols=4, figsize=(14, 8), sharex=False, sharey=True)
    ax[0].imshow(orig_img)
    ax[0].set_title("original image")

    ax[1].imshow(get_img(o_mask))
    ax[1].set_title("original mask")

    ax[2].imshow(seam_pred_mask, cmap='gray')
    ax[2].set_title("predicted mask")

    ax[3].imshow(resized_img)
    ax[3].set_title("u-net resized image")

    os.makedirs(output_dir, exist_ok=True)
    fig.savefig(os.path.join(output_dir, '%s.jpg' % idx))
    if show:
        plt.show()
    # This runs once per test sample; closing prevents open figures from accumulating in memory.
    plt.close(fig)

    print (resized_img.shape, orig_img.shape)

In [None]:
# Plot and save one figure per collected test sample.
for sample_idx, pred_mask in enumerate(out_masks):
    show_image_and_mask(pred_mask, out_images[sample_idx], orig_masks[sample_idx], sample_idx)