<a href="https://colab.research.google.com/github/airsresincrop/AIRS/blob/master/inria_pretrainer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
#@title Download Dataset (only have to run once per dataset per runtime)
# Fetches the two archives belonging to the dataset picked in the form field
# below, unpacks them into the current working directory and removes the zips.
Dataset = "inria_train_1_val_5" #@param ["inria_train_1_val_1", "inria_train_1_val_5", "inria_train_5_val_1", "inria_train_5_val_5", "inria_train_10_val_5", "inria_train_15_val_5", "inria_train_31_val_5"]

# Dataset name -> Google Drive download URL of the main dataset archive.
link_dict = {
  'inria_train_1_val_1': 'https://drive.google.com/uc?id=1-b_6cWLvzQiyEj1V5uZm51hmXjNjcTF0',
  'inria_train_1_val_5': 'https://drive.google.com/uc?id=1-avaIA0OVN3ee2yZAhQEcOSIKuLbnRJ1',
  'inria_train_5_val_1': 'https://drive.google.com/uc?id=1-aLH5GUtqUTuF1nLRnCr1xOw7td1qsDr',
  'inria_train_5_val_5': 'https://drive.google.com/uc?id=1-_vwFzOrkdyVGwIvZadsEgbT-UoH7TPT',
  'inria_train_10_val_5': 'https://drive.google.com/uc?id=1-_bFTMLKQC3nGEzlQiETtrwMg4vd8Ath',
  'inria_train_15_val_5': 'https://drive.google.com/uc?id=1-YfYq5gXHB_pos6b7EQN17Oqfc1NcUJB',
  'inria_train_31_val_5': 'https://drive.google.com/uc?id=1-Me0Bp-G4vInbybLbR3vVyM-3jOxZ5mg'
}

# Dataset name -> Google Drive download URL of the companion "MRA" archive.
# NOTE(review): the curation cell reads from `<dataset>_DL1/`; presumably that
# directory is what this archive unpacks to -- confirm against the zip contents.
mra_link_dict = {
  'inria_train_1_val_1': 'https://drive.google.com/uc?id=1-DCJXY3Cze1E4TwjtEmThTzQNCy1qVn4',
  'inria_train_1_val_5': 'https://drive.google.com/uc?id=1-6DzYB889BUSXHF2to1NV1CHB0WZe9DC',
  'inria_train_5_val_1': 'https://drive.google.com/uc?id=1-KBgxPwWrsTJSJ8mcxMAdrJkd1eA_kqJ',
  'inria_train_5_val_5': 'https://drive.google.com/uc?id=1-53ARNyjnRqlmfGWQrIyAtt84B385g1k',
  'inria_train_10_val_5': 'https://drive.google.com/uc?id=1-1KjpGv-bUtrmWkr07eQ7tw_HFoUxQ0q',
  'inria_train_15_val_5': 'https://drive.google.com/uc?id=1-0ae-aaZRI4DtYt00hg6QWfYWOxPIyd_',
  'inria_train_31_val_5': 'https://drive.google.com/uc?id=1D6p3JxA6AFM8w42pxjLEttb3l22B6nOH'
}

# Raises KeyError if `Dataset` is not one of the keys above.
dataset_link = link_dict[Dataset]

print('Installing Tensorboard...')
# NOTE(review): unpinned upgrade; the installed version depends on when this runs.
!pip install -q -U tensorboard
print('Downloading Dataset...')
# The following steps expect the download to land as `{Dataset}.zip` in the cwd.
!gdown -q {dataset_link}
print('Unzipping Dataset...')
!unzip -qq {Dataset}.zip
print('Deleting Dataset zip...')
!rm {Dataset}.zip

# Same download -> unzip -> delete sequence for the MRA archive; `dataset_link`
# is reused, so after this cell it holds the MRA URL.
dataset_link = mra_link_dict[Dataset]
print('Downloading MRA Dataset...')
!gdown -q {dataset_link}
print('Unzipping MRA Dataset...')
# The archive is expected to be named `{Dataset}_MRA.zip`.
!unzip -qq {Dataset}_MRA.zip
print('Deleting MRA Dataset zip...')
!rm {Dataset}_MRA.zip

print('Done.')
# !gdown https://drive.google.com/uc?id=1jcQtXg8bYzTlgCMcxDqEYuC62j1XKf3t # AIRS_128.zip
# !gdown https://drive.google.com/uc?id=1MohUSBykyDu94E1PTRelC5h0dA-LuuU1 # AIRS_256.zip
# !gdown https://drive.google.com/uc?id=1rxR0XTiDbPeJg88WLZctzPwuL9JquC4g # AIRS_512.zip
# !gdown https://drive.google.com/uc?id=1-5BDBZL8Tsxzi_wshN3d_EF6Zhn2iitY # AIRS_1024.zip
# !gdown https://drive.google.com/uc?id=1a5A8xat3wJmEx11ml1UUqhTjMKbMFkG9 # AIRS_2048.zip
# !gdown https://drive.google.com/uc?id=1-F3N0NBorYSIikB_my4NY8t8NI2NAjk6 # AIRS_4096.zip
# !gdown https://drive.google.com/uc?id=1-2p-z1L2CWI-dr0EDj-9Ve5MXGBurEQI # AIRS_8192.zip
# !gdown https://drive.google.com/uc?id=1-CsjdU3xVDiIuWQ0s95vKC9ceT6a7Vmb # AIRS_10000.zip

In [0]:
# Import libraries
import tensorflow as tf
import os
import cv2
import time
import numpy as np
from tqdm.notebook import tqdm
from pathlib import Path
from matplotlib import pyplot as plt
from IPython import display
from tensorflow.keras import backend as K
from random import randint, shuffle
import itertools
import matplotlib
import datetime
from tensorflow.keras.utils import plot_model
from tensorflow.keras.mixed_precision import experimental as mixed_precision
matplotlib.rc('image', cmap='gray')
%load_ext tensorboard

In [0]:
# Install 'mixed_float16' as the global Keras dtype policy through the
# experimental mixed-precision API, so layers created afterwards pick it up.
# The generator's final softmax explicitly overrides this with dtype='float32'.
policy = mixed_precision.Policy('mixed_float16')
mixed_precision.set_policy(policy)

In [0]:
#@title Curate Dataset for TF (only run the first time)
# Writes one PNG per sample with the image on the left and its label on the
# right (concatenated along the width axis), numbered 000000.png, 000001.png, ...
# Two output trees are produced: `PATH/<split>` and `PATH_DL1/<split>`.

resized_sizes = {}

num_per_cities_train = 1 #@param {type:"integer"}
resized_sizes['train'] = num_per_cities_train
num_per_cities_val =  5 #@param {type:"integer"}
resized_sizes['val'] = num_per_cities_val

IMG_HEIGHT = IMG_WIDTH = 256
print(f'IMG_WIDTH: {IMG_WIDTH}, IMG_HEIGHT: {IMG_HEIGHT}')

PATH = f'train_{num_per_cities_train}_val_{num_per_cities_val}_size_{IMG_WIDTH}'

curate_dataset = True #@param {type:"boolean"}


def _read_image(path):
  """Read `path` with OpenCV, failing loudly instead of returning None."""
  img = cv2.imread(path)
  if img is None:
    # cv2.imread signals a missing/unreadable file by returning None, which
    # would otherwise surface later as an obscure TypeError.
    raise FileNotFoundError(f'Could not read image: {path}')
  return img


def _write_pairs(names, image_dir, label_dir, out_dir, label_scale, desc):
  """Write `<image | label * label_scale>` PNGs for every name into `out_dir`.

  Args:
    names: file stems; `<image_dir>/<name>.jpg` and `<label_dir>/<name>.png`
      are read for each of them.
    image_dir: directory holding the JPG inputs.
    label_dir: directory holding the PNG labels.
    out_dir: destination directory (created if missing).
    label_scale: integer factor applied to the label pixels before writing.
    desc: progress-bar caption.
  """
  Path(out_dir).mkdir(parents = True, exist_ok = True)
  for img_count, name in enumerate(tqdm(names, desc = desc)):
    img = _read_image(f'{image_dir}/{name}.jpg')
    lbl = _read_image(f'{label_dir}/{name}.png')*label_scale
    cv2.imwrite(f'{out_dir}/{str(img_count).zfill(6)}.png',np.concatenate([img,lbl],axis=1))


if curate_dataset:
  print('Curating Dataset...')
  # NOTE(review): the source folder is derived from the two form fields above,
  # not from the `Dataset` chosen in the download cell -- keep them in sync.
  images_path = f'inria_train_{num_per_cities_train}_val_{num_per_cities_val}'
  for split in ['train','val']:
    print(images_path)
    # Match on the extension (not a substring) and sort so the numbering of
    # the output files is reproducible and identical for both output trees.
    names = sorted(
        os.path.splitext(f)[0]
        for f in os.listdir(f'{images_path}/{split}/images/')
        if f.endswith('.jpg'))
    # Full-size data: labels are multiplied by 255 before being written.
    _write_pairs(names,
                 f'{images_path}/{split}/images',
                 f'{images_path}/{split}/labels',
                 f'{PATH}/{split}', 255, split)
    # DL1 data: same file stems, different directory layout, labels unscaled.
    _write_pairs(names,
                 f'{images_path}_DL1/{split}',
                 f'{images_path}_DL1/labels/{split}',
                 f'{PATH}_DL1/{split}', 1, split)

In [0]:
#@title Choose losses for Generator
# Colab form fields: one on/off switch plus one numeric weight per loss term.
# NOTE(review): none of these names is referenced by the other cells visible in
# this notebook (the training cell loops over all four losses unconditionally).
USE_L1_LOSS = True #@param {type:"boolean"}
ALPHA = 1 #@param {type:"number"}
USE_L2_LOSS = False #@param {type:"boolean"}
BETA = 1 #@param {type:"number"}
USE_DICE_LOSS = False #@param {type:"boolean"}
GAMMA =  1#@param {type:"number"}
USE_BCE_LOSS = False #@param {type:"boolean"}
DELTA = 1 #@param {type:"number"}
USE_FOCAL_LOSS = False #@param {type:"boolean"}
EPSILON = 1 #@param {type:"number"}

In [0]:
#@title Parameters
# Parameters
# BUFFER_SIZE feeds the training shuffle buffer, BATCH_SIZE the batching.
BUFFER_SIZE = 400 #@param {type:"number"}
BATCH_SIZE = 32 #@param {type:"number"}
GEN_FACTOR = 1 #@param {type:"number"}
DISC_FACTOR = 1 #@param {type:"number"}
# Toggles the resize / random-crop / random-flip augmentation for training.
USE_RANDOM_JITTER = True #@param {type:"boolean"}
# Selects the `_DL1` data tree instead of the full-size one.
TRAIN_DL1 = False #@param {type:"boolean"}
# The DL1 variant is trained on 128x128 crops, the full variant on 256x256.
IMG_HEIGHT = IMG_WIDTH = 128 if TRAIN_DL1 else 256

In [0]:
def load(image_file):
  """Read a side-by-side PNG and split it into (input image, two-channel target).

  The left half of the picture becomes the input (up to three channels). The
  first channel of the right half is the mask `m`; the target stacks
  `[255 - m, m]` along the last axis.

  Args:
    image_file: path of the PNG file.

  Returns:
    Tuple `(input_image, real_image)`, both float32 with values still on the
    0..255 scale of the file.
  """
  decoded = tf.image.decode_png(tf.io.read_file(image_file))[:,:,:3]

  # Both halves have equal width, so the split point is the horizontal centre.
  half_width = tf.shape(decoded)[1] // 2

  input_image = tf.cast(decoded[:, :half_width, :], tf.float32)
  mask = tf.cast(decoded[:, half_width:, 0], tf.float32)

  # Channel 0 is the inverted mask, channel 1 the mask itself.
  real_image = tf.stack([255-mask,mask],axis=-1)

  return input_image, real_image

def resize(input_image, real_image, height, width):
  """Resize both tensors to `height` x `width` with nearest-neighbour sampling.

  Returns:
    Tuple `(input_image, real_image)` at the new spatial size.
  """
  target_size = [height, width]
  nearest = tf.image.ResizeMethod.NEAREST_NEIGHBOR

  resized_input = tf.image.resize(input_image, target_size, method=nearest)
  resized_target = tf.image.resize(real_image, target_size, method=nearest)

  return resized_input, resized_target

def random_crop(input_image, real_image):
  """Take the same random IMG_HEIGHT x IMG_WIDTH window from input and target.

  The two-channel target is padded with a copy of its first channel so it can
  be stacked with the three-channel input and cropped in a single call; the
  padding channel is dropped again before returning.

  Returns:
    Tuple `(cropped_input, cropped_target)`; the target has two channels.
  """
  padded_target = tf.concat([real_image,real_image[:,:,:1]],axis=-1)
  pair = tf.stack([input_image, padded_target], axis=0)
  cropped_pair = tf.image.random_crop(
      pair, size=[2, IMG_HEIGHT, IMG_WIDTH, 3])

  cropped_input = cropped_pair[0]
  cropped_target = cropped_pair[1][:,:,:2]
  return cropped_input, cropped_target

def normalize(input_image, real_image):
  """Rescale the input from 0..255 to -1..1 and the target from 0..255 to 0..1.

  Returns:
    Tuple `(input_image, real_image)` after rescaling.
  """
  scaled_input = (input_image / 127.5) - 1
  scaled_target = real_image / 255
  return scaled_input, scaled_target

def load_image_train(image_file):
  """Load one training sample, optionally jitter it, then normalise it.

  Jitter is applied only when the notebook-level USE_RANDOM_JITTER flag is set.

  Returns:
    Tuple `(input_image, real_image)` as produced by `normalize`.
  """
  sample = load(image_file)
  if USE_RANDOM_JITTER:
    sample = random_jitter(sample[0], sample[1])
  input_image, real_image = normalize(sample[0], sample[1])

  return input_image, real_image

def load_image_test(image_file):
  """Load one evaluation sample and normalise it (no resizing, no augmentation).

  Returns:
    Tuple `(input_image, real_image)` as produced by `normalize`.
  """
  raw_input, raw_target = load(image_file)
  input_image, real_image = normalize(raw_input, raw_target)

  return input_image, real_image

def downsample(filters, size, apply_batchnorm=True):
  """Build an encoder block: stride-2 Conv2D -> optional BatchNorm -> LeakyReLU.

  Args:
    filters: number of convolution filters.
    size: convolution kernel size.
    apply_batchnorm: whether to insert a BatchNormalization layer.

  Returns:
    A `tf.keras.Sequential` that halves the spatial resolution ('same' padding).
  """
  conv = tf.keras.layers.Conv2D(
      filters, size, strides=2, padding='same',
      kernel_initializer=tf.random_normal_initializer(0., 0.02),
      use_bias=False)

  block_layers = [conv]
  if apply_batchnorm:
    block_layers.append(tf.keras.layers.BatchNormalization())
  block_layers.append(tf.keras.layers.LeakyReLU())

  return tf.keras.Sequential(block_layers)

def upsample(filters, size, apply_dropout=False):
  """Build a decoder block: stride-2 Conv2DTranspose -> BatchNorm -> optional Dropout -> ReLU.

  Args:
    filters: number of transposed-convolution filters.
    size: kernel size.
    apply_dropout: whether to insert a Dropout(0.5) layer before the ReLU.

  Returns:
    A `tf.keras.Sequential` that doubles the spatial resolution ('same' padding).
  """
  deconv = tf.keras.layers.Conv2DTranspose(
      filters, size, strides=2, padding='same',
      kernel_initializer=tf.random_normal_initializer(0., 0.02),
      use_bias=False)

  block_layers = [deconv, tf.keras.layers.BatchNormalization()]
  if apply_dropout:
    block_layers.append(tf.keras.layers.Dropout(0.5))
  block_layers.append(tf.keras.layers.ReLU())

  return tf.keras.Sequential(block_layers)

def Generator(factor=1,input_shape=[None, None, 3]):
  """Build the encoder-decoder segmentation network with skip connections.

  Eight stride-2 `downsample` blocks are followed by seven `upsample` blocks,
  each concatenated with the encoder output of matching resolution, and a
  final stride-2 transposed convolution producing two channels with a softmax
  over the channel axis (computed in float32).

  Because the encoder halves the resolution eight times, the skip
  concatenations only line up when the input height and width are multiples
  of 2**8 = 256.

  Args:
    factor: integer divisor applied to every block's filter count.
    input_shape: shape of one input sample (height, width, channels).

  Returns:
    A `tf.keras.Model` mapping an image batch to per-pixel two-class
    probabilities at the input resolution.
  """
  inputs = tf.keras.layers.Input(shape=input_shape)

  # Encoder. For a 256x256 input and factor=1 the outputs are
  # 128x128x64, 64x64x128, 32x32x256, 16x16x512, 8x8x512, 4x4x512, 2x2x512
  # and 1x1x512. Only the first block skips batch normalisation.
  encoder_filters = [64, 128, 256, 512, 512, 512, 512, 512]
  down_stack = [
    downsample(filters//factor, 3, apply_batchnorm=(idx > 0))
    for idx, filters in enumerate(encoder_filters)
  ]

  # Decoder. Each block doubles the resolution (2x2 up to 128x128 for a
  # 256x256 input); the first three blocks use dropout.
  decoder_filters = [512, 512, 512, 512, 256, 128, 64]
  up_stack = [
    upsample(filters//factor, 3, apply_dropout=(idx < 3))
    for idx, filters in enumerate(decoder_filters)
  ]

  # Output layer: back to the input resolution with 2 channels.
  last = tf.keras.layers.Conv2DTranspose(
      2, 3, strides=2, padding='same',
      kernel_initializer=tf.random_normal_initializer(0., 0.02))

  # Run the encoder, remembering every intermediate output.
  x = inputs
  encoder_outputs = []
  for down in down_stack:
    x = down(x)
    encoder_outputs.append(x)

  # Run the decoder; the bottleneck itself is not a skip, so start from the
  # second-to-last encoder output and walk back to the first.
  for up, skip in zip(up_stack, encoder_outputs[-2::-1]):
    x = tf.keras.layers.Concatenate()([up(x), skip])

  logits = last(x)
  # The softmax is pinned to float32 regardless of the global dtype policy.
  probabilities = tf.keras.layers.Activation(
      'softmax', dtype='float32', name='predictions')(logits)

  return tf.keras.Model(inputs=inputs, outputs=probabilities)

def generate_images(test_input, tar):
  """Plot input / ground truth / prediction for the first sample and print its metrics.

  Uses the notebook-level `model` for inference. Only element 0 of the batch is
  evaluated. Metrics treat class index 1 as the positive class; a metric whose
  denominator is zero is reported as NaN.

  Args:
    test_input: batch of normalised input images (values in -1..1).
    tar: batch of two-channel targets.
  """
  prediction = model(test_input, training=False)

  pred = np.array(tf.argmax(prediction[0],axis=-1))
  t = np.array(tf.argmax(tar[0],axis=-1))
  tp = (pred*t).sum()
  tn = ((1-pred)*(1-t)).sum()
  fp = (pred*(1-t)).sum()
  fn = ((1-pred)*t).sum()
  accuracy = (tp + tn) / (tp + tn + fp + fn)
  precision = tp / (tp + fp)
  recall = tp / (tp + fn)
  f1 = (2 * precision * recall) / (precision + recall)
  iou = tp / (tp + fp + fn)

  plt.figure(figsize=(15,15))

  # Only the input image lives in -1..1 and needs mapping back to 0..1.
  # The masks are already 0/1 class indices; pinning vmin/vmax keeps an
  # all-background mask visually distinct from an all-foreground one, which
  # auto-scaling would not.
  mask_range = {'vmin': 0, 'vmax': 1}
  panels = [
    ('Input Image', test_input[0] * 0.5 + 0.5, {}),
    ('Ground Truth', t, mask_range),
    ('Predicted Image', pred, mask_range),
  ]

  for i, (title, image, imshow_kwargs) in enumerate(panels):
    plt.subplot(1, 3, i+1)
    plt.title(title)
    plt.imshow(image, **imshow_kwargs)
    plt.axis('off')
  plt.show()
  print(f'For this example:\nAccuracy: {accuracy}\nPrecision: {precision}\nRecall: {recall}\nF1 Score: {f1}\nIoU: {iou}')

def get_miou(dataset, dataset_length, desc = None):
  """Evaluate the notebook-level `model` on `dataset` and print pixel metrics.

  Predictions and targets are reduced to class indices with argmax; class 1 is
  the positive class. All counts are pooled over every pixel of every batch,
  so the reported IoU is the positive-class IoU over the whole evaluated set.

  Args:
    dataset: batched `tf.data.Dataset` yielding `(input, target)` pairs.
    dataset_length: number of batches to evaluate.
    desc: optional progress-bar caption.

  Returns:
    Tuple `(preds, ts, accuracy, precision, recall, f1, iou)` where `preds`
    and `ts` are uint8 arrays of class indices for all evaluated samples.

  Raises:
    ValueError: if no batch was produced (e.g. `dataset_length` is 0).
  """
  # Collect per-batch arrays and concatenate once at the end; concatenating
  # inside the loop re-copies everything gathered so far on every step.
  pred_batches = []
  target_batches = []
  for inp, tar in tqdm(dataset.take(dataset_length),total = dataset_length, desc = desc):
    prediction = model(inp, training=False)
    pred_batches.append(np.array(prediction).argmax(axis=-1).astype(np.uint8))
    target_batches.append(np.array(tar).argmax(axis=-1).astype(np.uint8))

  if not pred_batches:
    raise ValueError('get_miou: no batches to evaluate')

  preds = np.concatenate(pred_batches, axis = 0)
  ts = np.concatenate(target_batches, axis = 0)

  tp = (preds*ts).sum()
  tn = ((1-preds)*(1-ts)).sum()
  fp = (preds*(1-ts)).sum()
  fn = ((1-preds)*ts).sum()
  accuracy = (tp + tn) / (tp + tn + fp + fn)
  precision = tp / (tp + fp)
  recall = tp / (tp + fn)
  f1 = (2 * precision * recall) / (precision + recall)
  iou = tp / (tp + fp + fn)
  print(f'TP: {tp}, TN: {tn}, FP: {fp}, FN: {fn}')
  print(f'Accuracy: {accuracy}\nPrecision: {precision}\nRecall: {recall}\nF1 Score: {f1}\nIoU: {iou}')
  return preds, ts, accuracy, precision, recall, f1, iou

@tf.function()
def random_jitter(input_image, real_image):
  """Augment a sample: upscale by 12.5 %, random-crop back, randomly mirror.

  Args:
    input_image: input image tensor (height, width, 3).
    real_image: two-channel target tensor (height, width, 2).

  Returns:
    Tuple `(input_image, real_image)` of size IMG_HEIGHT x IMG_WIDTH, with the
    same crop window and the same flip decision applied to both.
  """
  # `resize` takes (height, width) in that order.
  input_image, real_image = resize(input_image, real_image,
                                   int(1.125*IMG_HEIGHT), int(1.125*IMG_WIDTH))

  input_image, real_image = random_crop(input_image, real_image)

  if tf.random.uniform(()) > 0.5:
    # random mirroring
    input_image = tf.image.flip_left_right(input_image)
    real_image = tf.image.flip_left_right(real_image)

  return input_image, real_image

In [0]:
def l1_loss(y_true, y_pred):
  """Mean absolute error over every element of the batch."""
  residual = y_true - y_pred
  return tf.reduce_mean(tf.abs(residual))

def l2_loss(y_true, y_pred):
  """Mean squared error over every element of the batch."""
  residual = y_true - y_pred
  return tf.reduce_mean(tf.square(residual))

def dice_loss(y_true, y_pred):
  """Soft Dice loss, `1 - mean(dice)`, with squared terms in the denominator.

  The Dice coefficient is computed per sample and per channel by summing over
  the two spatial axes, then averaged:

      dice = (2 * sum(y_true * y_pred) + eps) / (sum(y_true^2) + sum(y_pred^2) + eps)

  The factor 2 makes a perfect match score 1 (loss 0). Without it a perfect
  non-empty match would only reach 0.5 while the empty/empty case, which is
  decided by `eps` alone, would reach 1.

  Args:
    y_true: target tensor of shape (batch, height, width, channels).
    y_pred: prediction tensor of the same shape.

  Returns:
    Scalar loss tensor.
  """
  smooth = 1e-6
  intersection = tf.reduce_sum(y_true * y_pred, axis=[1,2])
  total = tf.reduce_sum(tf.square(y_true) + tf.square(y_pred), axis=[1,2])
  dice = (2. * intersection + smooth) / (total + smooth)
  return 1 - tf.reduce_mean(dice)

def bce_loss(y_true, y_pred):
  """Binary cross-entropy via a freshly built Keras `BinaryCrossentropy` with default settings."""
  loss_fn = tf.keras.losses.BinaryCrossentropy()
  loss_value = loss_fn(y_true, y_pred)
  return loss_value

def metric_func(y_true,y_pred):
  """Positive-class IoU of one batch, used as the Keras training metric.

  Both tensors are reduced to class indices with argmax over the last axis;
  class 1 is the positive class and counts are pooled over the whole batch.

  A batch with no positive pixels in either tensor has a 0/0 IoU. That case
  returns 0 instead of NaN, because a single NaN batch would turn the
  epoch-averaged `val_metric_func` into NaN and blind the checkpoint and
  learning-rate callbacks that monitor it.

  Args:
    y_true: one-hot / two-channel target tensor.
    y_pred: per-class probability tensor of the same shape.

  Returns:
    Scalar float64 tensor with the IoU (0 when the union is empty).
  """
  pred_mask = K.argmax(y_pred,-1)
  true_mask = K.argmax(y_true,-1)
  # Count in integers, then divide in float64 (the dtype integer true-division
  # would have produced) so the result type stays the same.
  tp = K.sum(pred_mask*true_mask)
  fp = K.sum(pred_mask*(1-true_mask))
  fn = K.sum((1-pred_mask)*true_mask)
  intersection = tf.cast(tp, tf.float64)
  union = tf.cast(tp + fp + fn, tf.float64)
  return tf.math.divide_no_nan(intersection, union)

In [0]:
# Create TF dataset
# Builds `train_dataset`, `val_dataset` and `test_dataset` from the curated
# side-by-side PNGs. `*_dataset_length` is the number of batches per pass.

# Root of the curated data for the selected variant.
data_root = (PATH + '_DL1') if TRAIN_DL1 else PATH


def _num_batches(split_dir):
  """Number of BATCH_SIZE batches needed to see every PNG in `split_dir` once."""
  # Count with the same '*.png' pattern the datasets are built from.
  num_files = len(list(Path(split_dir).glob('*.png')))
  return int(np.ceil(num_files/BATCH_SIZE))


def _augment_and_normalize(input_image, real_image):
  """Per-sample training transform applied after the cache."""
  if USE_RANDOM_JITTER:
    input_image, real_image = random_jitter(input_image, real_image)
  return normalize(input_image, real_image)


autotune = tf.data.experimental.AUTOTUNE

train_dataset_length = _num_batches(f'{data_root}/train')
# Cache only the deterministic decode step. The random jitter has to run
# after `.cache()`; otherwise the crops and flips drawn in the first epoch
# would be replayed unchanged in every later epoch.
train_dataset = (
    tf.data.Dataset.list_files(f'{data_root}/train/*.png')
    .map(load, num_parallel_calls=autotune)
    .cache()
    .shuffle(BUFFER_SIZE)
    .map(_augment_and_normalize, num_parallel_calls=autotune)
    .batch(BATCH_SIZE)
    .repeat()
    .prefetch(buffer_size=autotune)
)

val_dataset_length = _num_batches(f'{data_root}/val')
val_dataset = (
    tf.data.Dataset.list_files(f'{data_root}/val/*.png')
    .map(load_image_test)
    .batch(BATCH_SIZE)
)

# Same validation files, one sample per batch, for picking a display example.
test_dataset = (
    tf.data.Dataset.list_files(f'{data_root}/val/*.png')
    .map(load_image_test)
    .batch(1)
)

print(f'train_dataset_length: {train_dataset_length}, val_dataset_length: {val_dataset_length}')

class DisplayCallback(tf.keras.callbacks.Callback):
  """Keras callback that prints validation metrics at the end of every epoch."""

  def on_epoch_end(self, epoch, logs=None):
    """Print a header, then run `get_miou` over a quarter of the validation batches."""
    print('Validation:')
    batches_to_check = val_dataset_length//4
    get_miou(val_dataset, batches_to_check)

# Scan up to 100 single-sample batches and stop at the first one whose
# channel-1 target sums to at least half of IMG_HEIGHT*IMG_WIDTH. If none
# qualifies, `example_input` / `example_target` keep the last sample seen.
min_positive_sum = IMG_HEIGHT*IMG_WIDTH/2
for example_input, example_target in tqdm(test_dataset.take(100),desc = 'Finding nice example', total = 100):
  if not (tf.reduce_sum(example_target[:,:,:,1]) < min_positive_sum):
    break

In [0]:
import ipywidgets as widgets
# A button that wipes the `logs` and `models` directories, plus an Output
# widget that shows the confirmation message underneath it.
button = widgets.Button(description="del logs and models")
output = widgets.Output()

def on_button_clicked(b):
  """Delete the `logs` and `models` directory trees and report it in `output`.

  Args:
    b: the clicked button instance (unused).
  """
  import shutil

  # Missing directories are ignored, matching `rm -rf`.
  shutil.rmtree('logs', ignore_errors=True)
  shutil.rmtree('models', ignore_errors=True)
  # Route the message into the displayed Output widget; text printed from a
  # widget callback is otherwise not tied to this cell's output area.
  with output:
    print('logs and models deleted.')

button.on_click(on_button_clicked)
display.display(button, output)

In [0]:
# Train one fresh generator per loss function. Every run gets its own
# checkpoint directory (`models/<variant>/<jitter>/<loss>`) and its own
# TensorBoard log directory (`logs/<variant>/<jitter>/<loss>/<timestamp>`).
EPOCHS = 50

# The variant/jitter part of the run name does not depend on the loss.
run_prefix = ('DL1/' if TRAIN_DL1 else 'full/') + \
             ('jitter_yes/' if USE_RANDOM_JITTER else 'jitter_no/')

for loss_type, loss_func in [('L1',l1_loss),('L2',l2_loss),('dice',dice_loss),('bce',bce_loss)]:
  modelname = run_prefix + loss_type
  Path(f'models/{modelname}').mkdir(parents=True,exist_ok = True)

  # Keep only the weights with the best validation IoU.
  checkpointer = tf.keras.callbacks.ModelCheckpoint(filepath=f'models/{modelname}'+'/best_val_model.h5', monitor='val_metric_func', mode='max', save_best_only=True, verbose = 0)
  # The timestamp is shifted by +5:30 h (presumably from a UTC runtime clock
  # to IST -- confirm).
  tb_callback = tf.keras.callbacks.TensorBoard(log_dir=f"logs/{modelname}/"+(datetime.datetime.now()+datetime.timedelta(hours=5,minutes=30)).strftime("%Y-%m-%d %H:%M:%S"))
  # Divide the learning rate by sqrt(10) after 10 epochs without improvement.
  reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_metric_func', mode = 'max', factor=1/np.sqrt(10), patience = 10, min_lr=1e-6, verbose = 0)
  callbacks = [checkpointer, tb_callback, reduce_lr]

  model = Generator()
  model.compile(optimizer=tf.keras.optimizers.Adam(1e-2), loss=loss_func, metrics=[metric_func])

  # Pass the flat callback list itself; wrapping it in another list hands
  # Keras a list where it expects Callback objects.
  model_history = model.fit(train_dataset, epochs=EPOCHS, steps_per_epoch=train_dataset_length,
                            validation_data=val_dataset,validation_steps=val_dataset_length//4,
                            callbacks=callbacks, verbose = 2)

In [0]:
   # Open TensorBoard inline on `logs`, the directory the training cell's
   # TensorBoard callback writes to.
   %tensorboard --logdir logs