# Load Libraries

In [1]:
import numpy as np
from torch.utils import data
import torch
from __future__ import print_function
import argparse
import torch
import numpy as np
from torch.nn import functional
import torch.optim as optim
from torch.utils import data
from torchvision import transforms
from torch.autograd import Variable
from skimage import exposure
from skimage.io import imread
from skimage.util import *
import matplotlib.pyplot as plt
import natsort as ns
import os
from tqdm import tqdm
from collections import OrderedDict
from IPython.display import clear_output

In [2]:
# import custom library

import sys
#base_dir = '/content'  # for Colaboratory
#base_dir = '/Users/cyrilwendl/Documents/EPFL'  # for local machine
base_dir = '/home/cyrilwendl'  # for GCE

sys.path.append(base_dir + '/SIE-Master/Zurich/helpers') # Path to density Tree package
sys.path.append(base_dir + '/SIE-Master/Zurich/data_augment') # Path to density Tree package
from helpers.helpers import *
from helpers.data_augment import *

# UNet

In [3]:
# import torch libraries

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn import init
import numpy as np

def conv3x3(in_channels, out_channels, stride=1,padding=1, bias=True, groups=1):
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=3,
        stride=stride,
        padding=padding,
        bias=bias,
        groups=groups) 

def conv1x1(in_channels, out_channels, groups=1):
    return nn.Conv2d(
      in_channels,
      out_channels,
      kernel_size=1,
      groups=groups,
      stride=1)

def upconv2x2(in_channels, out_channels, mode='transpose'):
    """
    upsampling helper in mode 'transpose' or 'sequential'
    """
    if mode == 'transpose':
        return nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=2,
            stride=2)
    else:
        # out_channels is always going to be the same
        # as in_channels
        return nn.Sequential(
            nn.Upsample(mode='bilinear', scale_factor=2),
            conv1x1(in_channels, out_channels))





class DownConv(nn.Module):
    """
    A helper Module that performs 2 convolutions and 1 MaxPool.
    A ReLU activation follows each convolution.
    """

    def __init__(self, in_channels, out_channels, pooling=True):
        super(DownConv, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.pooling = pooling

        self.conv1 = conv3x3(self.in_channels, self.out_channels)
        self.conv2 = conv3x3(self.out_channels, self.out_channels)

        if self.pooling:
            self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        before_pool = x
        if self.pooling:
            x = self.pool(x)
        return x, before_pool


class UpConv(nn.Module):
    """
    A helper Module that performs 2 convolutions and 1 UpConvolution.
    A ReLU activation follows each convolution.
    """

    def __init__(self, in_channels, out_channels,
                 merge_mode='concat', up_mode='transpose'):
        super(UpConv, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.merge_mode = merge_mode
        self.up_mode = up_mode

        self.upconv = upconv2x2(self.in_channels, self.out_channels,
                                mode=self.up_mode)

        if self.merge_mode == 'concat':
            self.conv1 = conv3x3(
                2 * self.out_channels, self.out_channels)
        else:
            # num of input channels to conv2 is same
            self.conv1 = conv3x3(self.out_channels, self.out_channels)
        self.conv2 = conv3x3(self.out_channels, self.out_channels)

    def forward(self, from_down, from_up):
        """ Forward pass
        Arguments:
            from_down: tensor from the encoder pathway
            from_up: upconv'd tensor from the decoder pathway
        """
        from_up = self.upconv(from_up)
        if self.merge_mode == 'concat':
            x = torch.cat((from_up, from_down), 1)
        else:
            x = from_up + from_down
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        return x


class UNet(nn.Module):
    """ `UNet` class is based on https://arxiv.org/abs/1505.04597
    The U-Net is a convolutional encoder-decoder neural network.
    Contextual spatial information (from the decoding,
    expansive pathway) about an input tensor is merged with
    information representing the localization of details
    (from the encoding, compressive pathway).
    Modifications to the original paper:
    (1) padding is used in 3x3 convolutions to prevent loss
        of border pixels
    (2) merging outputs does not require cropping due to (1)
    (3) residual connections can be used by specifying
        UNet(merge_mode='add')
    (4) if non-parametric upsampling is used in the decoder
        pathway (specified by upmode='upsample'), then an
        additional 1x1 2d convolution occurs after upsampling
        to reduce channel dimensionality by a factor of 2.
        This channel halving happens with the convolution in
        the tranpose convolution (specified by upmode='transpose')
    """

    def __init__(self, num_classes, in_channels=3, depth=5,
                 start_filts=64, up_mode='transpose',
                 merge_mode='concat'):
        """
        Arguments:
            in_channels: int, number of channels in the input tensor.
                Default is 3 for RGB images.
            depth: int, number of MaxPools in the U-Net.
            start_filts: int, number of convolutional filters for the
                first conv.
            up_mode: string, type of upconvolution. Choices: 'transpose'
                for transpose convolution or 'upsample' for nearest neighbour
                tupsampling.
        """
        super(UNet, self).__init__()

        if up_mode in ('transpose', 'upsample'):
            self.up_mode = up_mode
        else:
            raise ValueError("\"{}\" is not a valid mode for "
                             "upsampling. Only \"transpose\" and "
                             "\"upsample\" are allowed.".format(up_mode))

        if merge_mode in ('concat', 'add'):
            self.merge_mode = merge_mode
        else:
            raise ValueError("\"{}\" is not a valid mode for"
                             "merging up and down paths. "
                             "Only \"concat\" and "
                             "\"add\" are allowed.".format(up_mode))

        # NOTE: up_mode 'upsample' is incompatible with merge_mode 'add'
        if self.up_mode == 'upsample' and self.merge_mode == 'add':
            raise ValueError("up_mode \"upsample\" is incompatible "
                             "with merge_mode \"add\" at the moment "
                             "because it doesn't make sense to use "
                             "nearest neighbour to reduce "
                             "depth channels (by half).")

        self.num_classes = num_classes
        self.in_channels = in_channels
        self.start_filts = start_filts
        self.depth = depth

        self.down_convs = []
        self.up_convs = []

        # create the encoder pathway and add to a list
        for i in range(depth):
            ins = self.in_channels if i == 0 else outs
            outs = self.start_filts * (2 ** i)
            pooling = True if i < depth - 1 else False

            down_conv = DownConv(ins, outs, pooling=pooling)
            self.down_convs.append(down_conv)

        # create the decoder pathway and add to a list
        # - careful! decoding only requires depth-1 blocks
        for i in range(depth - 1):
            ins = outs
            outs = ins // 2
            up_conv = UpConv(ins, outs, up_mode=up_mode,
                             merge_mode=merge_mode)
            self.up_convs.append(up_conv)

        self.conv_final = conv1x1(outs, self.num_classes)

        # add the list of modules to current module
        self.down_convs = nn.ModuleList(self.down_convs)
        self.up_convs = nn.ModuleList(self.up_convs)

        self.reset_params()

    @staticmethod
    def weight_init(m):
        if isinstance(m, nn.Conv2d):
            init.xavier_normal(m.weight)
            init.constant(m.bias, 0)

    def reset_params(self):
        for i, m in enumerate(self.modules()):
            self.weight_init(m)

    def forward(self, x):
        encoder_outs = []

        # encoder pathway, save outputs for merging
        for i, module in enumerate(self.down_convs):
            x, before_pool = module(x)
            encoder_outs.append(before_pool)

        for i, module in enumerate(self.up_convs):
            before_pool = encoder_outs[-(i + 2)]
            x = module(before_pool, x)

        # No softmax is used. This means you need to use
        # nn.CrossEntropyLoss is your training script,
        # as this module includes a softmax already.
        x = self.conv_final(x)
        return x


if __name__ == "__main__":
    """
    testing
    """
    model = UNet(3, depth=5, merge_mode='concat')
    x = Variable(torch.FloatTensor(np.random.random((1, 3, 320, 320))))
    out = model(x)
    loss = torch.sum(out)
    loss.backward()

# Data Loader

In [4]:
import numpy as np
from torch.utils import data
import torch


class ZurichLoader(data.Dataset):
    """
    Data loader for Zurich dataset
    """
    def __init__(self, im_patches, gt_patches, split, data_augmentation = False):
        """
        Load data.
        :param split: 'train', 'val' or empty
        :param transform_data: list of data transformations for images
        :param transform_labels: list of data transformations for labels
        :param im_size: size to which to crop labels and images
        :param patch_size: size of image patches
        """
        # data transformations
        self.data_augmentation = data_augmentation
        self.split = split
         
        # Load image indexes, depending on set:
        if split == 'train':
            self.img_idx = np.arange(0, int(len(im_patches)*.8))
        else:  # elif split == 'val':
            self.img_idx = np.arange(int(len(im_patches)*.8), int(len(im_patches)))
        

        self.im_patches = [im_patches[i] for i in self.img_idx]
        self.gt_patches= [gt_patches[i] for i in self.img_idx]

        # translate to data and label paths

    def __getitem__(self, idx):
        """
        function must be overridden: returns data-label pair of tensors for data point at index
        Here we just return the entire images for demonstration reasons. In reality, you would crop
        from each image at random here, or would have a pre-defined list of coordinates initialised
        in the constructor and crop according to it.
        """

        img = self.im_patches[idx]
        gt = self.gt_patches[idx]

        # convert image
        # img = Image.fromarray((img*255).astype(np.uint8))
        # gt = Image.fromarray(gt.astype(np.uint8)).convert('L')
        # apply transformations

        if self.data_augmentation:
            img, gt = augment_images_and_gt(img, gt)
        

        # If you want to do special transforms like rotation, do them here.
        # Don't forget to apply the same transforms to both the data and label tensors.
        # You can use Torchsample, or else convert the data to numpy (e.g.: img.numpy())
        # and then load it again into a torch tensor (img = torch.from_numpy(img)).

        # TODO transformations using torchsample
        img = np.asarray(img).transpose((2, 0, 1)).astype(np.float64)
        img = torch.from_numpy(img.copy()).type(torch.FloatTensor)
        gt = torch.from_numpy(gt.copy()).type(torch.LongTensor)  # .astype(np.double))

        return img, gt

    def __len__(self):
        # function must be overridden: returns number of data points in data set
        return len(self.gt_patches)

# Load Data and Model

# Train Model

In [5]:
import psutil
psutil.virtual_memory()

svmem(total=31616704512, available=26244857856, percent=17.0, used=4915703808, free=23746056192, active=5223993344, inactive=2155925504, buffers=101834752, cached=2853109760, shared=9654272)

In [6]:
# Training settings
# Training settings
batch_size = 10
test_batch_size = 20
epochs = 10
lr = 1e-3
momentum = 0.9
no_cuda = False
seed = 1
log_interval = 50

cuda = not no_cuda and torch.cuda.is_available()

torch.manual_seed(seed)
if cuda:
    torch.cuda.manual_seed(seed)

kwargs = {'num_workers': 1, 'pin_memory': True} if cuda else {}

def imgs_stretch_eq_(imgs):
    """
    perform histogram stretching and equalization
    :param imgs: images to stretch and equalize
    :return imgs_eq: equalized images
    """

    imgs_eq = imgs.copy()
    for idx_im, im in enumerate(imgs):
        # Contrast stretching
        p2 = np.percentile(im, 2)
        p98 = np.percentile(im, 98)
        for band in range(im.shape[-1]):
            imgs_eq[idx_im][:, :, band] = exposure.rescale_intensity(im[:, :, band], in_range=(p2, p98))  # stretch
            imgs_eq[idx_im][:, :, band] = exposure.equalize_hist(imgs_eq[idx_im][:, :, band])  # equalize

    # convert to np arrays
    imgs_eq = np.asarray(imgs_eq)
    return imgs_eq


# load data
def load_data(base_dir):
    base_dir = base_dir + "/SIE-Master/Zurich/Zurich_dataset"
    im_dir = base_dir + '/images_tif/'
    gt_dir = base_dir + '/groundtruth/'
    im_names = ns.natsorted(os.listdir(im_dir))
    gt_names = ns.natsorted(os.listdir(gt_dir))
    imgs = np.asarray([im_load(im_dir + im_name) for im_name in im_names])
    gt = np.asarray([im_load(gt_dir + gt_name) for gt_name in gt_names])
    # histogram stretching
    
    print(sys.getsizeof(imgs))
    imgs_eq = imgs_stretch_eq_(imgs)
    imgs = imgs_eq    # continue using stretched image
    patch_size = 64
    stride = 64
    im_patches = get_padded_patches(imgs[:10], patch_size = patch_size, stride = stride)
    print(sys.getsizeof(im_patches))
    gt_patches = get_gt_patches(gt[:10], patch_size = patch_size, stride = stride)
    return im_patches, gt_patches

im_patches, gt_patches = load_data(base_dir)

256


IndexError: too many indices for array

In [None]:
# create datasets
train_loader = data.DataLoader(
    ZurichLoader(im_patches, gt_patches, 'train', data_augmentation=True),
    batch_size=batch_size, shuffle=True)

test_loader = data.DataLoader(
    ZurichLoader(im_patches, gt_patches, 'val'),
    batch_size=test_batch_size, shuffle=False)

n_classes = 9  # TODO parse

model = UNet(n_classes, in_channels=4, depth=7)

if cuda:
    model.cuda()

#optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
optimizer = optim.Adam(model.parameters(), lr=lr)


def train(epochs=epochs):
    """Train model"""
    model.train()
    for batch_idx, (im_data, labels) in enumerate(train_loader):
        im_data, labels = Variable(im_data), Variable(labels)
        if cuda:
            im_data, labels = im_data.cuda(), labels.cuda()
        # class_weights = class_weight.compute_class_weight('balanced', np.unique(labels.data.numpy().flatten()),
        #                                                np.arange(10))
        # class_weights=Variable(torch.from_numpy(class_weights).type(torch.FloatTensor))
        optimizer.zero_grad()
        output = model(im_data)
        loss = functional.cross_entropy(output, labels, ignore_index=0)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.4f}'.format(
                epochs, batch_idx * len(im_data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.data[0]))


def test():
    """Show accuracies for training and test sets"""
    model.eval()
    loss_test = 0
    loss_train = 0
    correct_test = 0
    correct_train = 0
    for im_data, labels in test_loader:
        im_data, labels = Variable(im_data, volatile=True), Variable(labels)
        if cuda:
            im_data, labels = im_data.cuda(), labels.cuda()
        output = model(im_data)
        loss_test += functional.cross_entropy(output, labels, ignore_index=0).data[0]  # sum up batch loss
        pred = output.data.max(1, keepdim=True)[1]  # get the index of the max log-probability
        correct_test += pred.eq(labels.data.view_as(pred)).cpu().sum()
        
    for im_data, labels in train_loader:
        im_data, labels = Variable(im_data, volatile=True), Variable(labels)
        if cuda:
            im_data, labels = im_data.cuda(), labels.cuda()
        output = model(im_data)
        loss_train += functional.cross_entropy(output, labels, ignore_index=0).data[0]  # sum up batch loss
        pred = output.data.max(1, keepdim=True)[1]  # get the index of the max log-probability
        correct_train += pred.eq(labels.data.view_as(pred)).cpu().sum()

    loss_test /= len(test_loader.dataset)
    print('\nTraining set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        loss_train, correct_train, len(train_loader.dataset) * 64 * 64,
        100. * correct_train / (len(train_loader.dataset) * 64 * 64)))
      
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        loss_test, correct_test, len(test_loader.dataset) * 64 * 64,
        100. * correct_test / (len(test_loader.dataset) * 64 * 64)))

    
def predict():
    """return predictions after training"""
    # TODO test
    model.eval()
    pred = []
    for im_data, labels in test_loader:
        im_data, labels = Variable(im_data, volatile=True), Variable(labels)
        if cuda:
            im_data, labels = im_data.cuda(), labels.cuda()
        output = model(im_data)
        pred.append(output.data.max(1, keepdim=True)[1])   # get the index of the max log-probability
    return pred
        
    


def test_show_some_images():
    # get data
    test_im = train_loader.dataset[8]
    im_test = Variable(test_im[0]).data.numpy()
    im_test_l = Variable(test_im[1]).data.numpy()
    im_test = np.transpose(im_test, (1, 2, 0))

    # show figures
    plt.figure()
    plt.imshow(im_test[:, :, :3])
    plt.show()

    plt.figure()
    plt.imshow(im_test_l)
    plt.show()

In [None]:
for epoch in range(1, epochs + 1):
      train(epoch)
      test()