In [None]:
# Dataset download link: https://mega.nz/#F!9RdjiDiB!06icNxE9XwcWRIlYWfFIgg

In [None]:
import numpy as np
import os
import pickle
from glob import glob
from tqdm import trange
import matplotlib.pyplot as plt
plt.style.use("ggplot")

import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data
import torchvision

%matplotlib inline

In [None]:
from itertools import tee, chain

def pairwise(iterable):
    """Yield consecutive overlapping pairs: s -> (s0, s1), (s1, s2), (s2, s3), ...

    An iterable with fewer than two items yields no pairs.
    """
    leading, trailing = tee(iterable)
    # Step the second copy forward once so zip() lines up neighbours; the
    # None default keeps an empty input from raising StopIteration.
    next(trailing, None)
    return zip(leading, trailing)

def flatten(listOfLists):
    """Concatenate the inner iterables of ``listOfLists`` into one lazy iterator.

    Only a single level of nesting is removed; deeper structure is kept as-is.
    """
    single_level = chain.from_iterable(listOfLists)
    return single_level

In [None]:
class UNetConvBlock(nn.Module):
    """Two convolutions, each followed by ReLU and dropout.

    Padding is ``kernel_size // 2`` with stride 1, so for odd ``kernel_size``
    the spatial size is preserved and only the channel count changes, from
    ``channels_in`` to ``n_filters``.

    Args:
        channels_in: number of channels of the input tensor.
        n_filters: number of output channels of both convolutions.
        kernel_size: square kernel size of both convolutions.
        dropout_p: drop probability applied after each activation while the
            module is in training mode (0.5, ``F.dropout``'s own default).
    """

    def __init__(self, channels_in, n_filters, kernel_size=3, dropout_p=0.5):
        super(UNetConvBlock, self).__init__()
        self.kernel_size = kernel_size
        self.c_in = channels_in
        self.c_out = n_filters
        self.dropout_p = dropout_p

        self.conv_fn1 = nn.Conv2d(
            in_channels=channels_in,
            out_channels=n_filters,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            bias=False)

        self.conv_fn2 = nn.Conv2d(
            in_channels=n_filters,
            out_channels=n_filters,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            bias=False)

    def forward(self, x):
        """Apply conv -> ReLU -> dropout twice and return the result."""
        activation_fn = F.relu

        x = self.conv_fn1(x)
        x = activation_fn(x)
        # The functional dropout does not look at self.training by itself;
        # pass it explicitly so dropout is active only in training mode.
        x = F.dropout(x, p=self.dropout_p, training=self.training)

        x = self.conv_fn2(x)
        x = activation_fn(x)
        x = F.dropout(x, p=self.dropout_p, training=self.training)

        return x

In [None]:
class UNetForward(nn.Module):
    """One encoder stage of the U-Net: a conv block followed by max-pooling.

    Returns both the pre-pool features (which ``UNet`` keeps as the skip
    connection) and the pooled features (fed to the next stage).

    Args:
        channels_in: channels of the input tensor.
        n_filters: channels produced by the conv block.
        kernel_size: kernel size of the conv block's convolutions.
        pool_size: max-pool window size and stride.
    """

    def __init__(self, channels_in, n_filters, kernel_size=3, pool_size=2):
        super(UNetForward, self).__init__()

        self.kernel_size = kernel_size
        self.pool_size = pool_size

        # Forward kernel_size so the constructor argument actually takes effect
        # instead of silently falling back to the conv block's default.
        self.conv_block = UNetConvBlock(
            channels_in, n_filters, kernel_size=kernel_size)

    def forward(self, x):
        """Return ``(conv, pool)``: the conv-block output and its max-pooled version."""
        conv = self.conv_block(x)
        pool = F.max_pool2d(
            conv,
            kernel_size=self.pool_size,
            stride=self.pool_size)
        return conv, pool

In [None]:
class UNetBackward(nn.Module):
    """One decoder stage of the U-Net.

    Upsamples the incoming features with a strided transposed convolution,
    concatenates the matching encoder (skip) features along the channel axis
    and fuses the result with a ``UNetConvBlock``.

    The conv block takes ``channels_in`` input channels, so the skip tensor
    must carry ``channels_in - n_filters`` channels. In ``UNet`` that is
    ``n_filters``, i.e. ``channels_in == 2 * n_filters``.

    Args:
        channels_in: channels of the incoming decoder features.
        n_filters: channels produced by the transposed conv and the conv block.
        kernel_size: kernel size and stride of the transposed convolution,
            i.e. the upsampling factor.
    """

    def __init__(self, channels_in, n_filters, kernel_size=2):
        super(UNetBackward, self).__init__()
        self.kernel_size = kernel_size

        self.conv_fn = nn.ConvTranspose2d(
            in_channels=channels_in,
            out_channels=n_filters,
            kernel_size=kernel_size,
            stride=kernel_size,
            padding=0)

        self.conv_block = UNetConvBlock(channels_in, n_filters)

    def forward(self, x):
        """Upsample, concatenate the skip features and convolve.

        Args:
            x: a pair ``(features, skip)`` of 4-D channel-first tensors.
        Returns:
            Tensor with ``n_filters`` channels at the upsampled spatial size.
        """
        x, pre_x = x

        x = self.conv_fn(x)

        # Tensors are channel-first (N, C, H, W); the last two dims are spatial.
        _, _, x_h, x_w = x.size()

        # If the encoder map had an odd size, pooling floored it and the
        # transposed conv cannot restore it exactly, so the skip features are
        # resampled to the decoder's size. When the sizes already match, the
        # skip is used as-is. An aligned-corners bilinear resize to the same
        # size is the identity, so the result is unchanged. This also avoids
        # building a resampling module on every call.
        if tuple(pre_x.size()[2:]) != (x_h, x_w):
            upsampling_fn = nn.UpsamplingBilinear2d(size=(x_h, x_w))
            pre_x = upsampling_fn(pre_x)

        x = torch.cat((x, pre_x), 1)

        conv = self.conv_block(x)
        return conv

In [None]:
class UNet(nn.Module):
    """U-Net style encoder/decoder that outputs per-pixel class probabilities.

    ``layers`` gives the number of filters at each depth, shallowest first.
    Every entry except the last gets an encoder stage (``UNetForward``). The
    last one is the bottleneck (``UNetConvBlock``). The decoder stages
    (``UNetBackward``) climb back up to ``layers[0]`` channels, which a 1x1
    convolution maps to ``c_out`` class scores followed by a softmax.

    Args:
        c_in: number of channels of the input image.
        c_out: number of output classes.
        layers: filters per depth (at least two entries); defaults to
            ``[32, 64, 128, 256, 512]``.
    """

    def __init__(self, c_in=3, c_out=2, layers=None):
        super(UNet, self).__init__()
        self.layers = (layers or [32, 64, 128, 256, 512])

        # nn.ModuleList (instead of a plain Python list) registers every stage
        # as a sub-module, so parameters(), cuda(), train()/eval() and
        # state_dict() all reach them.
        # The loop variables are named ch_in/ch_out so they cannot clobber the
        # c_in/c_out arguments. Python 2 comprehensions leak their variables,
        # which would change c_out before output_fn reads it below.
        self.forward_layers = nn.ModuleList([
            UNetForward(ch_in, ch_out)
            for ch_in, ch_out in pairwise([c_in] + self.layers[:-1])])

        self.conv_block = UNetConvBlock(self.layers[-2], self.layers[-1])
        self.backward_layers = nn.ModuleList([
            UNetBackward(ch_in, ch_out)
            for ch_in, ch_out in pairwise(self.layers[::-1])])

        self.output_fn = nn.Conv2d(
            in_channels=self.layers[0],
            out_channels=c_out,
            kernel_size=1,
            padding=0,
            bias=True)

    def set_training(self, is_training):
        """Put the whole network in training (True) or evaluation (False) mode."""
        # train()/eval() recurse into every registered sub-module, including
        # the stages held in the ModuleLists.
        if is_training:
            self.train()
        else:
            self.eval()

    def set_cuda(self):
        """Move every parameter and buffer of the network to the GPU."""
        self.cuda()

    def forward(self, x):
        """Segment a batch of channel-first images.

        Args:
            x: 4-D input batch with ``c_in`` channels.
        Returns:
            Softmax probabilities over the ``c_out`` classes.
        """
        layers_history = []

        # Encoder: keep each stage's pre-pool features for the skip connections.
        for layer in self.forward_layers:
            pre_x, x = layer(x)
            layers_history.append(pre_x)

        x = self.conv_block(x)

        # Decoder: pair each stage with the encoder features of the same
        # depth, deepest first.
        for layer, pre_x in zip(self.backward_layers, layers_history[::-1]):
            x = layer([x, pre_x])

        x = self.output_fn(x)

        # Softmax over the class (channel) axis. With no explicit dim,
        # PyTorch picks dim 1 for 4-D input.
        return F.softmax(x)

    def get_all_params(self):
        """Return a list of every trainable parameter, each listed exactly once."""
        # The stages are registered sub-modules, so parameters() already
        # includes them. Adding them again would duplicate entries.
        return list(self.parameters())

In [None]:
model = UNet()
# Adam with a small fixed learning rate over every parameter reported by
# UNet.get_all_params().
optimizer = optim.Adam(params=model.get_all_params(), lr=1e-5)

In [None]:
# Use the GPU only when one is present, so the notebook runs on the CPU on
# machines without CUDA instead of failing at the first .cuda() call.
use_cuda = torch.cuda.is_available()

In [None]:
# Per-class loss weights: class 1 counts 50x as much as class 0 (presumably
# to counter a rare foreground class -- TODO confirm).
class_weights = torch.Tensor([1.0, 50.0])
if use_cuda:
    model.set_cuda()
    class_weights = class_weights.cuda()

# NLLLoss2d expects log-probabilities, but UNet.forward returns softmax
# probabilities, so the log is taken in loss_fn below. Adding a small epsilon
# (rather than clamping) keeps log() finite without zeroing the gradient for
# pixels whose predicted probability is tiny.
_nll_loss_2d = nn.NLLLoss2d(weight=class_weights)


def loss_fn(output, target):
    """Class-weighted pixel-wise negative log-likelihood.

    Args:
        output: softmax probabilities from the model, (N, C, H, W).
        target: integer class indices, (N, H, W).
    Returns:
        The scalar loss computed by ``nn.NLLLoss2d`` on ``log(output)``.
    """
    return _nll_loss_2d((output + 1e-8).log(), target)

In [None]:
# NOTE(review): pickle.load can execute arbitrary code. Only load this file if
# it comes from a trusted source; the data link at the top is an external host.
with open("train_data.pkl", "rb") as train_file:
    train_data = pickle.load(train_file)

In [None]:
# NOTE(review): pickle.load can execute arbitrary code. Only load this file if
# it comes from a trusted source; the data link at the top is an external host.
with open("val_data.pkl", "rb") as val_file:
    val_data = pickle.load(val_file)

In [None]:
# NOTE(review): pickle.load can execute arbitrary code. Only load this file if
# it comes from a trusted source; the data link at the top is an external host.
with open("test_data.pkl", "rb") as test_file:
    test_data = pickle.load(test_file)

In [None]:
def channel_first_data(data):
    """Move the last axis of every array in each ``(x, y)`` pair to the front.

    Each array is transposed with axes ``[2, 0, 1]`` (e.g. H x W x C becomes
    C x H x W). Returns a new list of transposed pairs, in input order.
    """
    converted = []
    for image, label in data:
        image_cf = np.transpose(image, [2, 0, 1])
        label_cf = np.transpose(label, [2, 0, 1])
        converted.append((image_cf, label_cf))
    return converted

In [None]:
# Convert the train and validation splits to channel-first layout; the test
# split is left as loaded. NOTE: the names are rebound to the converted data,
# so re-running this cell would transpose everything a second time.
train_data, val_data = channel_first_data(train_data), channel_first_data(val_data)
# test_data = channel_first_data(test_data)

In [None]:
# Shuffled mini-batches of 2 pairs from the channel-first training split.
train_loader = torch.utils.data.DataLoader(
    dataset=train_data,
    batch_size=2,
    shuffle=True,
)

In [None]:
# Loader used by test() for evaluation; it iterates over the *validation*
# split. Evaluation needs no shuffling. A fixed order keeps the batch
# composition, and therefore the reported average loss, identical across
# epochs.
test_loader = torch.utils.data.DataLoader(
    val_data, batch_size=2, shuffle=False)

In [None]:
def train(epoch, log_interval=40):
    """Run one training epoch of ``model`` over ``train_loader``.

    Relies on the notebook-level ``model``, ``optimizer``, ``loss_fn`` and
    ``use_cuda``. Prints the current batch loss every ``log_interval``
    batches.

    Args:
        epoch: epoch number, used only in the progress message.
        log_interval: print progress every this many batches (default 40).
    """
    model.set_training(True)
    for batch_idx, (data, target) in enumerate(train_loader):
        # Cast on every device, not only on the GPU path: the conv layers need
        # float input and the NLL loss needs long (integer) class targets.
        data, target = data.float(), target.long()
        if use_cuda:
            data, target = data.cuda(), target.cuda()
        data, target = autograd.Variable(data), autograd.Variable(target)
        optimizer.zero_grad()
        output = model(data)
        # Drop the target's axis 1 (presumably a singleton channel axis).
        loss = loss_fn(output, torch.squeeze(target, 1))
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.data[0]))

In [None]:
def test(epoch):
    """Evaluate ``model`` on ``test_loader`` and print the mean batch loss.

    Relies on the notebook-level ``model``, ``loss_fn`` and ``use_cuda``.

    Args:
        epoch: current epoch number; accepted for symmetry with ``train``
            but not used.
    """
    model.set_training(False)
    test_loss = 0
    for data, target in test_loader:
        # Cast on every device, not only on the GPU path: the conv layers need
        # float input and the NLL loss needs long (integer) class targets.
        data, target = data.float(), target.long()
        if use_cuda:
            data, target = data.cuda(), target.cuda()
        # volatile=True: inference only, no autograd graph is recorded.
        data, target = autograd.Variable(data, volatile=True), autograd.Variable(target)
        output = model(data)
        test_loss += loss_fn(output, torch.squeeze(target, 1)).data[0]

    # The loss is already averaged within each batch, so average over
    # batches. max(..., 1) avoids a ZeroDivisionError on an empty loader.
    test_loss /= max(len(test_loader), 1)
    print('\nTest set: Average loss: {:.4f}\n'.format(
        test_loss))

In [None]:
# Number of training epochs; each is followed by an evaluation pass.
N_EPOCHS = 240

for epoch in range(1, N_EPOCHS + 1):
    train(epoch)
    test(epoch)

In [None]:
# def batch_generator(data, batch_size=16):
#     while True:
#         ids = np.random.choice(len(data), batch_size)
#         imgs = np.array([data[i][0] for i in ids])
#         labels = np.array([data[i][1] for i in ids])
#         yield imgs, labels

In [None]:
# def plot_unimetric(history, metric, save_dir=None):
#     plt.figure()
#     plt.plot(history[metric])
#     plt.title('model {}'.format(metric))
#     plt.ylabel(metric)
#     plt.xlabel('epoch')
#     if save_dir is None:
#         plt.show()
#     else:
#         plt.savefig("{}/{}.png".format(save_dir, metric),
#                     format='png', dpi=300)

In [None]:
# def save_stats(stats, save_dir="./"):
#     for key in stats:
#         plot_unimetric(stats, key, save_dir)

In [None]:
# def create_if_need(path):
#     if not os.path.exists(path):
#         os.makedirs(path)

In [None]:
# n_epochs = 240
# n_steps = 100
# gpu_option = 0.95
# batch_size = 2
# load = True
# model_dir = "./model"
# stats_dir = "./stats"
# train = False

In [None]:
# save_dir = "./val_predictions"
# create_if_need(save_dir)

In [None]:
# for i, (label, pred) in enumerate(val_predictions):
#     plt.figure(figsize=(10,8))
#     plt.subplot(1,2,1)
#     plt.imshow(label, 'gray')
#     plt.subplot(1,2,2)
# #     pred[pred < 0.5] = 0.0
#     plt.imshow(pred, 'gray')
#     plt.savefig("{}/{}.png".format(save_dir, i),
#                     format='png', dpi=300)
#     plt.show()

In [None]:
# save_dir = "./test_predictions"
# create_if_need(save_dir)

In [None]:
# for i, (label, pred) in enumerate(test_predictions):
#     plt.figure(figsize=(10,8))
#     plt.subplot(1,2,1)
#     plt.imshow((label + 0.5) * 255.)
#     plt.subplot(1,2,2)
#     plt.imshow(pred, 'gray')
#     plt.savefig("{}/{}.png".format(save_dir, i),
#                     format='png', dpi=300)
#     plt.show()