# Implementation and run with U-Net architecture

In [1]:
from __future__ import division, print_function
%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib
import numpy as np

In [2]:
import os
import matplotlib.image as mpimg
import cv2

In [3]:
import unet

In [4]:
%load_ext autoreload
%autoreload 2

In [5]:
def rotate_img(img, angle, rgb):
    rows, cols = img.shape[0:2]
    if rgb:
        id = 1
    else:
        id = 0
    rot_M = cv2.getRotationMatrix2D((cols/2, rows/2), angle, id)
    return cv2.warpAffine(img, rot_M, (cols, rows))

In [6]:
def flip_img(img, border_id):
    return cv2.flip(img, border_id)

In [18]:
def extract_data(filename, num_images, mytype='train'):
    """Extract the images into a 4D tensor [image index, y, x, channels].
    Values are rescaled from [0, 255] down to [-0.5, 0.5].
    """
    print('Extracting data...')
    imgs = []
    for i in range(1, num_images+1):
        if i%10==0:
            print('Extract original images... i=',i)
        if mytype == 'train':
            imageid = "satImage_%.3d" % i
        else:
            imageid = "test_%.1d" % i
        image_filename = filename + imageid + ".png"
        print ('Loading ' + image_filename) # TOREMOOOOOOOOOOOOOOOOOOOOOOVE
        img = mpimg.imread(image_filename) # to remove
        img = cv2.resize(img, (256,256), interpolation = cv2.INTER_AREA) # to remove
        imgs.append(img) # to remove
        """if os.path.isfile(image_filename):
            #print ('Loading ' + image_filename)
            img = mpimg.imread(image_filename)
            imgs.append(img)
            
            img_cv2 = cv2.imread(image_filename)
            img_flip = np.flip(flip_img(img_cv2, 1),2)/255
            imgs.append(img_flip)
            
            imgs.append(np.flip(rotate_img(img_cv2, 90, True),2)/255)
            imgs.append(np.flip(rotate_img(img_cv2, 180, True),2)/255)
            imgs.append(np.flip(rotate_img(img_cv2, 270, True),2)/255)
            
            if i==2:
                plt.subplot(151),plt.imshow(img),plt.title('Input')
                plt.subplot(152),plt.imshow(np.flip(rotate_img(img_cv2, 90, True),2)/255),plt.title('Output')
                plt.subplot(153),plt.imshow(np.flip(rotate_img(img_cv2, 180, True),2)/255),plt.title('Output')
                plt.subplot(154),plt.imshow(np.flip(rotate_img(img_cv2, 270, True),2)/255),plt.title('Output')
                plt.subplot(155),plt.imshow(img_flip),plt.title('Output')
                plt.show()
        
            
        else:
            print ('File ' + image_filename + ' does not exist')"""
        
    img_size = imgs[0].shape[0]
    img_height = imgs[0].shape[1]
    if img_size != img_height:
        print('Error!! The images should have their height equal to their width.')

    return np.asarray(imgs).astype(np.float32)

In [19]:
# Assign a label to a patch v
def value_to_class(v):
    # you can remark the hot encoding
    foreground_threshold = 0.25 # percentage of pixels > 1 required to assign a foreground label to a patch TODOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOO
    df = np.sum(v)
    if df > foreground_threshold:
        return [0, 1]
    else:
        return [1, 0]

In [20]:
# Extract label images
def extract_labels(filename, num_images):
    """Extract the labels into a 1-hot matrix [image index, label index]."""
    print('Extracting labels...')
    gt_imgs = []
    for i in range(1, num_images+1):
        if i%10==0:
            print('Extract groundtruth images... i=',i)
        imageid = "satImage_%.3d" % i
        image_filename = filename + imageid + ".png"
        print ('Loading ' + image_filename) # remooooooooooove
        img = mpimg.imread(image_filename) #remoooooooooooooooove
        gt_imgs.append(img) # remoooooooove
        """if os.path.isfile(image_filename):
            #print ('Loading ' + image_filename)
            img = mpimg.imread(image_filename)
            gt_imgs.append(img)
            
            img_cv2 = cv2.imread(image_filename,0)
            gt_img_flip = flip_img(img_cv2, 1)/255
            gt_imgs.append(gt_img_flip)
            
            gt_imgs.append(rotate_img(img_cv2, 90, True)/255)
            gt_imgs.append(rotate_img(img_cv2, 180, True)/255)
            gt_imgs.append(rotate_img(img_cv2, 270, True)/255)
            
            if i==2:
                plt.subplot(151),plt.imshow(img),plt.title('Input')
                plt.subplot(152),plt.imshow(rotate_img(img_cv2, 90, True)/255),plt.title('Output')
                plt.subplot(153),plt.imshow(rotate_img(img_cv2, 180, True)/255),plt.title('Output')
                plt.subplot(154),plt.imshow(rotate_img(img_cv2, 270, True)/255),plt.title('Output')
                plt.subplot(155),plt.imshow(gt_img_flip),plt.title('Output')
                plt.show()
            
        else:
            print ('File ' + image_filename + ' does not exist')"""

    data = np.asarray(gt_imgs)
    out_lab = [[[value_to_class(data[i][j][k]) \
                 for k in range(data.shape[2])] \
                for j in range(data.shape[1])] \
               for i in range(data.shape[0])]

    # Convert to dense 1-hot representation.
    return np.asarray(out_lab).astype(np.float32)

In [21]:
########### define directory of the training images ############################
data_dir = '../training/'
train_data_filename = data_dir + 'images/'
train_labels_filename = data_dir + 'groundtruth/'
TRAINING_SIZE = 10

data = extract_data(train_data_filename, TRAINING_SIZE)
labels = extract_labels(train_labels_filename, TRAINING_SIZE)

Extracting data...
Loading ../training/images/satImage_001.png
Loading ../training/images/satImage_002.png
Loading ../training/images/satImage_003.png
Loading ../training/images/satImage_004.png
Loading ../training/images/satImage_005.png
Loading ../training/images/satImage_006.png
Loading ../training/images/satImage_007.png
Loading ../training/images/satImage_008.png
Loading ../training/images/satImage_009.png
Extract original images... i= 10
Loading ../training/images/satImage_010.png
Extracting labels...
Extract groundtruth images... i= 10


In [16]:
initial_data = data[range(0,data.shape[0],5)]
initial_labels = labels[range(0,data.shape[0],5)]

NameError: name 'labels' is not defined

In [None]:
print(data.shape)
print(labels.shape)

In [None]:
net = unet.Unet(channels=3, n_class=2, layers=5, features_root=4) 
    #, cost_kwargs={'regularizer':1e-4}) # class_weights

In [None]:
# Optimizer = "momentum" or "adam"
trainer = unet.Trainer(net, batch_size=3, optimizer="adam") 
    #, opt_kwargs=dict(momentum=0.2)), learning_rate, decay_rate

In [None]:
path = trainer.train(data=data, labels=labels, output_path="./unet_trained", training_iters=5, \
                     epochs=3, dropout=1.0, display_step=5, prediction_path='prediction')

In [None]:
prediction = net.predict(path, data)

In [None]:
id_end = 5
for num in range(0,id_end):
    fig, ax = plt.subplots(1, 3, sharex=True, sharey=True, figsize=(12,5))
    ax[0].imshow(initial_data[num], aspect="auto")
    ax[1].imshow(initial_labels[num,:,:,1], aspect="auto")
    #mask = prediction[num,:,:,1] > 0.5
    
    #ax[2].imshow(mask, aspect="auto")
    ax[2].imshow(prediction[num,:,:,1], aspect="auto")
    ax[0].set_title("Input")
    ax[1].set_title("Ground truth")
    ax[2].set_title("Prediction")
    fig.tight_layout()
    fig.savefig("output/roadSegmentation.png")

In [None]:
id_end = 10
for num in range(0,id_end):
    fig, ax = plt.subplots(1, 1, sharex=True, sharey=True, figsize=(5,5))
    #mask = prediction[num,:,:,1] > 0.5
    
    #ax[2].imshow(mask, aspect="auto")
    ax.imshow(prediction[num,:,:,1], aspect="auto")
    ax.set_title("Prediction")
    #fig.tight_layout()
    fig.savefig("output/roadSegmentation"+str(num)+".png")
