In [None]:
import numpy as np
# Raw training images and their segmentation masks, as stored on disk.
data = np.load("./data/gsn_img_uint8.npy")
mask = np.load("./data/gsn_msk_uint8.npy")
# One shape per line, images first.
print(data.shape, mask.shape, sep="\n")

In [None]:
import matplotlib.pyplot as plt



# Helpers for displaying images and masks
def maskshow(mask):
    """Display `mask` as an RGB image by broadcasting it into three channels.

    The mask is copied into a float (H, W, 3) buffer, so it must broadcast to
    that shape (e.g. an (H, W, 1) mask).
    """
    rgb = np.zeros((mask.shape[0], mask.shape[1], 3))
    rgb[...] = mask
    plt.imshow(rgb)
    plt.show()

def imshow(img):
    """Show a single image on a 3x3-inch figure, stretched to fill the axes."""
    _, axis = plt.subplots(figsize=(3, 3))
    axis.imshow(img, aspect='auto')
    plt.show()

# print(data[0].shape)
# for i in range(10):
#     maskshow(mask[i])
#     imshow(data[i])

def normalize(mask):
    """Rescale 8-bit values by dividing by 255 (0..255 maps to 0.0..1.0)."""
    scale = 255
    return mask / scale


def _grid_figure(n):
    """Create a figure sized for an `n`-panel grid and return (fig, rows, cols).

    Fewer than 20 panels use 5 columns of 4-inch panels; otherwise 10 columns
    of 2-inch panels.
    """
    cols, panel_inches = (5, 4) if n < 20 else (10, 2)
    rows = -(-n // cols)  # ceiling division
    fig = plt.figure(figsize=(cols * panel_inches, rows * panel_inches))
    return fig, rows, cols


def _drop_singleton_channel(arr):
    """Strip a trailing size-1 channel axis from a 4-D (N, H, W, 1) batch.

    Any other array is returned unchanged.
    """
    if arr.ndim == 4 and arr.shape[3] == 1:
        return arr[..., 0]
    return arr


def imshow_many(imgs):
    """Show every image of the batch `imgs` (first axis = sample) in one grid.

    A trailing singleton channel is dropped once, before plotting. The
    original code reshaped inside the loop, so the second iteration's
    `imgs.shape[3]` raised IndexError.
    """
    n = imgs.shape[0]
    fig, rows, cols = _grid_figure(n)
    imgs = _drop_singleton_channel(imgs)
    for i in range(n):
        sub = fig.add_subplot(rows, cols, i + 1)
        sub.imshow(imgs[i], interpolation='nearest')


def imshow_masked(imgs, masks):
    """Show each image with its mask overlaid (jet colormap, alpha 0.6) in one grid.

    `imgs` and `masks` are indexed along their first axis. A trailing
    singleton channel is removed from both once, up front, instead of
    reshaping `masks` on every loop iteration.
    """
    n = imgs.shape[0]
    fig, rows, cols = _grid_figure(n)
    imgs = _drop_singleton_channel(imgs)
    masks = _drop_singleton_channel(masks)
    for i in range(n):
        sub = fig.add_subplot(rows, cols, i + 1)
        sub.imshow(imgs[i], interpolation='nearest')
        sub.imshow(masks[i], interpolation='nearest', cmap='jet', alpha=0.6)


imshow_masked(data[:5], mask[:5])  # preview the first five samples with their masks


In [None]:
import random

def randomDataAug(data, mask):
    """Apply one randomly chosen square symmetry to both `data` and `mask`.

    The candidates are the eight symmetries of a square acting on axes 0 and 1:
    - identity;
    - row / column flips;
    - transpose and anti-transpose;
    - 90-degree left and right rotations, and a 180-degree rotation.

    The same transform is applied to both arrays, so the mask stays aligned
    with the image. Any trailing axes (e.g. channels) are carried along
    unchanged.

    Args:
        data: array whose first two axes are the spatial ones. The diagonal
            reflections and 90-degree rotations assume these two axes have
            equal length, as the original loop-based versions did.
        mask: array with the same leading spatial axes as `data`.

    Returns:
        (new_data, new_mask). The identity transform returns the inputs
        themselves; every other transform returns freshly allocated copies.
    """
    def identity(data):
        return data

    def horizontalSymmetry(data):
        # out[R-1-i, j] = data[i, j]
        return data[::-1].copy()

    def verticalSymmetry(data):
        # out[i, C-1-j] = data[i, j]
        return data[:, ::-1].copy()

    def diagonalSymmetry1(data):
        # Anti-transpose: out[R-1-i, C-1-j] = data[j, i]
        return np.swapaxes(data[::-1, ::-1], 0, 1).copy()

    def diagonalSymmetry2(data):
        # Transpose: out[i, j] = data[j, i]
        return np.swapaxes(data, 0, 1).copy()

    def rotateRight(data):
        # Clockwise: out[j, R-1-i] = data[i, j]
        return np.rot90(data, k=-1, axes=(0, 1)).copy()

    def rotateLeft(data):
        # Counter-clockwise: out[C-1-j, i] = data[i, j]
        return np.rot90(data, k=1, axes=(0, 1)).copy()

    def rotateTwice(data):
        # 180 degrees: out[R-1-i, C-1-j] = data[i, j]
        return data[::-1, ::-1].copy()

    dataAug = [identity, horizontalSymmetry, verticalSymmetry,
               diagonalSymmetry1, diagonalSymmetry2, rotateLeft,
               rotateRight, rotateTwice]

    # Pick uniformly over *all* transforms. The former randrange(0, 7)
    # excluded index 7, so rotateTwice was never chosen.
    f = random.choice(dataAug)
    return f(data), f(mask)

def randomDataAugForDataset(data, mask):
    """Augment every (data[k], mask[k]) pair in place with `randomDataAug`.

    Each sample gets its own randomly chosen transform. The results are
    written back into the input arrays, which are also returned.
    """
    assert data.shape[0] == mask.shape[0]
    for idx, (sample, sample_mask) in enumerate(zip(data, mask)):
        data[idx], mask[idx] = randomDataAug(sample, sample_mask)
    return data, mask


In [None]:

import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim

class UNet(nn.Module):
    """Three-level U-Net built from unpadded ("valid") 3x3 convolutions.

    Every convolution except the very last one is unpadded, so each trims two
    pixels from H and W. The output is therefore spatially smaller than the
    input: by the layer arithmetic below, a 128x128 input yields a 36x36
    output. Skip connections are center-cropped to the decoder's size before
    concatenation.

    Args:
        in_channel: number of channels of the input images.
        out_channel: number of channels produced by the final block.
    """

    def contracting_block(self, in_channels, out_channels, kernel_size=3):
        """Build one encoder block: two (Conv2d -> ReLU -> BatchNorm2d) stages."""
        block = torch.nn.Sequential(
            torch.nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=out_channels),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(out_channels),
            torch.nn.Conv2d(kernel_size=kernel_size, in_channels=out_channels, out_channels=out_channels),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(out_channels),
        )
        return block

    def expansive_block(self, in_channels, mid_channel, out_channels, kernel_size=3):
        """Build one decoder block: two (Conv2d -> ReLU -> BatchNorm2d) stages,
        then a stride-2 ConvTranspose2d (k=3, p=1, output_padding=1) that doubles H and W.
        """
        block = torch.nn.Sequential(
            torch.nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=mid_channel),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(mid_channel),
            torch.nn.Conv2d(kernel_size=kernel_size, in_channels=mid_channel, out_channels=mid_channel),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(mid_channel),
            torch.nn.ConvTranspose2d(in_channels=mid_channel, out_channels=out_channels, kernel_size=3, stride=2, padding=1, output_padding=1),
        )
        return block

    def final_block(self, in_channels, mid_channel, out_channels, kernel_size=3):
        """Build the output head.

        The head is two unpadded (Conv2d -> ReLU -> BatchNorm2d) stages, then a
        padding=1 Conv2d to `out_channels`, followed by ReLU and BatchNorm2d.
        NOTE(review): the trailing ReLU/BatchNorm act on the scores that are
        fed to the loss; confirm this is intended.
        """
        block = torch.nn.Sequential(
            torch.nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=mid_channel),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(mid_channel),
            torch.nn.Conv2d(kernel_size=kernel_size, in_channels=mid_channel, out_channels=mid_channel),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(mid_channel),
            torch.nn.Conv2d(kernel_size=kernel_size, in_channels=mid_channel, out_channels=out_channels, padding=1),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(out_channels),
        )
        return block

    def __init__(self, in_channel, out_channel):
        super(UNet, self).__init__()
        # Encoder: 64 -> 128 -> 256 channels, each level followed by 2x2 max-pooling.
        self.conv_encode1 = self.contracting_block(in_channels=in_channel, out_channels=64)
        self.conv_maxpool1 = torch.nn.MaxPool2d(kernel_size=2)
        self.conv_encode2 = self.contracting_block(64, 128)
        self.conv_maxpool2 = torch.nn.MaxPool2d(kernel_size=2)
        self.conv_encode3 = self.contracting_block(128, 256)
        self.conv_maxpool3 = torch.nn.MaxPool2d(kernel_size=2)
        # Bottleneck: 512 channels, then upsample x2 back to 256 channels.
        self.bottleneck = torch.nn.Sequential(
            torch.nn.Conv2d(kernel_size=3, in_channels=256, out_channels=512),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(512),
            torch.nn.Conv2d(kernel_size=3, in_channels=512, out_channels=512),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(512),
            torch.nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=3, stride=2, padding=1, output_padding=1),
        )
        # Decoder: input channel counts are doubled by the skip concatenations.
        self.conv_decode3 = self.expansive_block(512, 256, 128)
        self.conv_decode2 = self.expansive_block(256, 128, 64)
        self.final_layer = self.final_block(128, 64, out_channel)

    def crop_and_concat(self, upsampled, bypass, crop=False):
        """Concatenate decoder features with an encoder skip tensor along channels.

        Args:
            upsampled: decoder tensor, (N, C1, H, W).
            bypass: encoder tensor from the matching level, (N, C2, H', W').
            crop: when true, center-crop `bypass` to H x W first. Height and
                width are handled independently. When a difference is odd,
                the extra row/column is taken from the bottom/right, so the
                result always matches `upsampled`. Cropping the same amount
                from both sides would leave `bypass` one pixel too large and
                make `torch.cat` fail.

        Returns:
            Tensor of `upsampled` channels followed by `bypass` channels.
        """
        if crop:
            dh = bypass.size(2) - upsampled.size(2)
            dw = bypass.size(3) - upsampled.size(3)
            top, left = dh // 2, dw // 2
            # Negative F.pad amounts crop; the tuple order is (left, right, top, bottom).
            bypass = F.pad(bypass, (-left, -(dw - left), -top, -(dh - top)))
        return torch.cat((upsampled, bypass), 1)

    def forward(self, x):
        """Run the encoder, bottleneck and decoder (with cropped skips) on `x`."""
        # Encode
        encode_block1 = self.conv_encode1(x)
        encode_pool1 = self.conv_maxpool1(encode_block1)
        encode_block2 = self.conv_encode2(encode_pool1)
        encode_pool2 = self.conv_maxpool2(encode_block2)
        encode_block3 = self.conv_encode3(encode_pool2)
        encode_pool3 = self.conv_maxpool3(encode_block3)
        # Bottleneck
        bottleneck1 = self.bottleneck(encode_pool3)
        # Decode
        decode_block3 = self.crop_and_concat(bottleneck1, encode_block3, crop=True)
        cat_layer2 = self.conv_decode3(decode_block3)
        decode_block2 = self.crop_and_concat(cat_layer2, encode_block2, crop=True)
        cat_layer1 = self.conv_decode2(decode_block2)
        decode_block1 = self.crop_and_concat(cat_layer1, encode_block1, crop=True)
        final_layer = self.final_layer(decode_block1)
        return final_layer

In [None]:
# def net():
#     unet = UNet(in_channel=3, out_channel=1)
#     #out_channel represents number of segments desired
#     criterion = torch.nn.CrossEntropyLoss()
#     optimizer = torch.optim.SGD(unet.parameters(), lr = 0.01, momentum=0.99)
#     optimizer.zero_grad()
#     return net, optimizer, criterion

In [None]:

def _load_nchw_pair(img_path, msk_path, augment=False):
    """Load an (image, mask) pair of .npy arrays and move axis 3 to position 1.

    The stored arrays are presumably NHWC; they are returned as NCHW via
    `transpose(0, 3, 1, 2)`.

    When `augment` is true, `randomDataAugForDataset` runs first, while the
    arrays are still in the stored layout. That helper transforms axes 0 and 1
    of each sample (the spatial axes there). After the transpose, those axes
    would be channel x height.
    """
    data = np.load(img_path)
    mask = np.load(msk_path)
    if augment:
        data, mask = randomDataAugForDataset(data, mask)
    return data.transpose(0, 3, 1, 2), mask.transpose(0, 3, 1, 2)


def prepareAndRandDatasets(augment=False,
                           img_path="./data/gsn_img_uint8.npy",
                           msk_path="./data/gsn_msk_uint8.npy"):
    """Load the training images and masks as NCHW arrays.

    Random flip/rotation augmentation is opt-in via `augment` (off by default,
    matching the previous behavior where it was commented out).
    """
    return _load_nchw_pair(img_path, msk_path, augment=augment)


def prepareTestDatasets(img_path="./data/test_gsn_image.npy",
                        msk_path="./data/test_gsn_mask.npy"):
    """Load the test images and masks as NCHW arrays (never augmented)."""
    return _load_nchw_pair(img_path, msk_path)

# Sanity check: display the NCHW shapes of the training images and masks
# rather than echoing the (large) arrays themselves.
tuple(arr.shape for arr in prepareAndRandDatasets())

In [None]:
import torch
use_gpu = bool(torch.cuda.is_available())  # True when torch can see a CUDA device
from tqdm import tqdm, trange
import gc

def train_step(inputs, labels, optimizer, criterion, unet, width_out, height_out):
    """Run one optimization step on a single batch and return its loss.

    Args:
        inputs: float tensor fed to `unet`, presumably (batch, channels, H, W).
        labels: integer class-index tensor holding batch * width_out * height_out
            elements, e.g. (batch, 1, height_out, width_out).
        optimizer: optimizer over `unet`'s parameters.
        criterion: per-pixel classification loss taking (pixels, n_classes)
            scores and (pixels,) targets, e.g. CrossEntropyLoss.
        unet: model returning (batch, n_classes, H, W) scores. n_classes is
            read from the output rather than hard-coded.
        width_out, height_out: expected spatial size of the model output.

    Returns:
        0-dim loss tensor, detached from the autograd graph so callers that
        sum losses across batches do not keep every batch's graph alive.
    """
    optimizer.zero_grad()
    outputs = unet(inputs)
    m, n_classes = outputs.shape[0], outputs.shape[1]
    # (batch, n_classes, H, W) -> (batch, H, W, n_classes) -> one row of
    # class scores per pixel, in the same pixel order as the flattened labels.
    outputs = outputs.permute(0, 2, 3, 1).reshape(m * width_out * height_out, n_classes)
    labels = labels.reshape(m * width_out * height_out)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
    return loss.detach()

def get_val_loss(x_val, y_val, width_out, height_out, unet):
    """Return the cross-entropy of `unet` on the whole validation set.

    The forward pass runs under `torch.no_grad()` with the model in eval mode.
    This stops BatchNorm layers from updating their running statistics with
    validation data. The model's previous train/eval mode is restored
    afterwards. Inputs are moved to the device of the model's parameters.

    Args:
        x_val: numpy array of inputs, presumably (N, C, H, W).
        y_val: numpy array of integer class indices with
            N * width_out * height_out elements.
        width_out, height_out: expected spatial size of the model output.
        unet: model returning (N, n_classes, H, W) scores. n_classes is read
            from the output rather than hard-coded.

    Returns:
        0-dim loss tensor without autograd history.
    """
    device = next(unet.parameters()).device
    x_val = torch.from_numpy(x_val).float().to(device)
    y_val = torch.from_numpy(y_val).long().to(device)
    m = x_val.shape[0]
    was_training = unet.training
    unet.eval()
    try:
        with torch.no_grad():
            outputs = unet(x_val)
    finally:
        unet.train(was_training)
    n_classes = outputs.shape[1]
    # (batch, n_classes, H, W) -> one row of class scores per pixel.
    outputs = outputs.permute(0, 2, 3, 1).reshape(m * width_out * height_out, n_classes)
    labels = y_val.reshape(m * width_out * height_out)
    loss = F.cross_entropy(outputs, labels)
    return loss.detach()

def train(unet, batch_size, epochs, epoch_lapse, threshold, learning_rate, criterion, optimizer , x_train, y_train, x_val, y_val, width_out, height_out):
    """Train `unet` on the given numpy arrays with mini-batch steps.

    Per epoch, the data is walked in order in `batch_size` chunks via
    `train_step`, and the summed batch loss is accumulated as a Python float.
    Every `epoch_lapse` epochs the summed training loss and the validation
    loss are printed.

    The arrays passed in are the ones trained on. Earlier code reloaded the
    dataset from disk each epoch, silently discarding `x_train`/`y_train`.

    `threshold` and `learning_rate` are accepted for interface compatibility
    but are not used here; the optimizer already carries its learning rate.
    Batches are moved to the device of the model's parameters.
    """
    device = next(unet.parameters()).device
    epoch_iter = int(np.ceil(x_train.shape[0] / batch_size))
    unet.train()
    for epoch in trange(epochs, leave=True):
        total_loss = 0.0
        for i in range(epoch_iter):
            start, stop = i * batch_size, (i + 1) * batch_size
            batch_train_x = torch.from_numpy(x_train[start:stop]).float().to(device)
            batch_train_y = torch.from_numpy(y_train[start:stop]).long().to(device)
            batch_loss = train_step(batch_train_x, batch_train_y, optimizer, criterion, unet, width_out, height_out)
            # float() drops any autograd graph so it is not kept alive across batches.
            total_loss += float(batch_loss)
        if (epoch + 1) % epoch_lapse == 0:
            val_loss = get_val_loss(x_val, y_val, width_out, height_out, unet)
            print("Total loss in epoch %d : %f and validation loss : %f" % (epoch + 1, total_loss, val_loss))
    gc.collect()

def plot_examples(unet, datax, datay, num_examples=3):
    """Plot `num_examples` randomly chosen samples, one per row.

    Columns, left to right:
    1. channel 0 of the input;
    2. channel 0 of the network output;
    3. per-pixel argmax over the output channels;
    4. channel 0 of the target.

    `datax`/`datay` are indexed as (sample, channel, H, W). Inference runs
    without autograd on the device holding the model's parameters, so this
    also works on CPU.
    """
    # One row per example; squeeze=False keeps `ax` 2-D even for a single row.
    fig, ax = plt.subplots(nrows=num_examples, ncols=4, figsize=(18, 4 * num_examples), squeeze=False)
    device = next(unet.parameters()).device
    m = datax.shape[0]
    with torch.no_grad():
        for row_num in range(num_examples):
            image_indx = np.random.randint(m)
            batch = torch.from_numpy(datax[image_indx:image_indx + 1]).float().to(device)
            image_arr = unet(batch).squeeze(0).cpu().numpy()
            ax[row_num][0].imshow(datax[image_indx][0])
            ax[row_num][1].imshow(image_arr[0])
            ax[row_num][2].imshow(image_arr.argmax(0))
            ax[row_num][3].imshow(datay[image_indx][0])
    plt.show()

In [None]:
def main():
    """Load the data, train a U-Net, plot sample predictions, and return the model.

    NOTE(review): by UNet's layer arithmetic, a 128x128 input yields a 36x36
    output, so reshaping outputs to width_out * height_out = 128 * 128 in
    train_step/get_val_loss will fail. TODO: make the network output and the
    targets agree in size (e.g. padded convolutions or cropped targets).
    """
    width_out = 128
    height_out = 128
    x_train, y_train = prepareAndRandDatasets()
    x_val, y_val = prepareTestDatasets()
    print(x_train.shape, y_train.shape, x_val.shape, y_val.shape)
    batch_size = 1
    epochs = 1
    # Validation is reported every `epoch_lapse` epochs; with epochs < epoch_lapse it never runs.
    epoch_lapse = 10
    threshold = 0.5
    learning_rate = 0.01
    # CrossEntropyLoss over a single output channel is identically zero (the
    # softmax of one logit is always 1), so nothing would be learned; binary
    # segmentation needs two class channels.
    # NOTE(review): targets must be class indices {0, 1}; if the masks are
    # stored as 0/255 they must be binarized first. Confirm against the data.
    unet = UNet(in_channel=3, out_channel=2)
    if use_gpu:
        unet = unet.cuda()
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(unet.parameters(), lr=learning_rate, momentum=0.99)

    train(unet, batch_size, epochs, epoch_lapse, threshold, learning_rate, criterion, optimizer, x_train, y_train, x_val, y_val, width_out, height_out)

    unet.eval()
    print(unet)
    plot_examples(unet, x_train, y_train)
    plot_examples(unet, x_val, y_val)
    return unet


In [None]:
import time
start_time = time.time()
unet = main()
torch.save(unet.state_dict(), "./u-net")
elapsed = time.time() - start_time  # wall-clock time for training + saving
print("--- %s seconds ---" % elapsed)

