<a href="https://colab.research.google.com/github/airsresincrop/AIRS/blob/master/dynamic_inria_pretrainer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
#@title Download Dataset (only have to run once per dataset per runtime)
# Name of the archive to fetch; it is the key into both link tables below and
# the stem of the zip files that get unpacked.
Dataset = "inria_train_1_val_5" #@param ["inria_train_1_val_1", "inria_train_1_val_5", "inria_train_5_val_1", "inria_train_5_val_5", "inria_train_10_val_5", "inria_train_15_val_5", "inria_train_31_val_5"]

# Google Drive download URLs of the base archives, keyed by dataset name.
link_dict = {
  'inria_train_1_val_1': 'https://drive.google.com/uc?id=1-b_6cWLvzQiyEj1V5uZm51hmXjNjcTF0',
  'inria_train_1_val_5': 'https://drive.google.com/uc?id=1-avaIA0OVN3ee2yZAhQEcOSIKuLbnRJ1',
  'inria_train_5_val_1': 'https://drive.google.com/uc?id=1-aLH5GUtqUTuF1nLRnCr1xOw7td1qsDr',
  'inria_train_5_val_5': 'https://drive.google.com/uc?id=1-_vwFzOrkdyVGwIvZadsEgbT-UoH7TPT',
  'inria_train_10_val_5': 'https://drive.google.com/uc?id=1-_bFTMLKQC3nGEzlQiETtrwMg4vd8Ath',
  'inria_train_15_val_5': 'https://drive.google.com/uc?id=1-YfYq5gXHB_pos6b7EQN17Oqfc1NcUJB',
  'inria_train_31_val_5': 'https://drive.google.com/uc?id=1-Me0Bp-G4vInbybLbR3vVyM-3jOxZ5mg'
}

# Google Drive download URLs of the matching "<name>_MRA" archives, keyed by
# the same dataset names.
mra_link_dict = {
  'inria_train_1_val_1': 'https://drive.google.com/uc?id=1-DCJXY3Cze1E4TwjtEmThTzQNCy1qVn4',
  'inria_train_1_val_5': 'https://drive.google.com/uc?id=1-6DzYB889BUSXHF2to1NV1CHB0WZe9DC',
  'inria_train_5_val_1': 'https://drive.google.com/uc?id=1-KBgxPwWrsTJSJ8mcxMAdrJkd1eA_kqJ',
  'inria_train_5_val_5': 'https://drive.google.com/uc?id=1-53ARNyjnRqlmfGWQrIyAtt84B385g1k',
  'inria_train_10_val_5': 'https://drive.google.com/uc?id=1-1KjpGv-bUtrmWkr07eQ7tw_HFoUxQ0q',
  'inria_train_15_val_5': 'https://drive.google.com/uc?id=1-0ae-aaZRI4DtYt00hg6QWfYWOxPIyd_',
  'inria_train_31_val_5': 'https://drive.google.com/uc?id=1D6p3JxA6AFM8w42pxjLEttb3l22B6nOH'
}

dataset_link = link_dict[Dataset]

print('Installing Tensorboard...')
!pip install -q -U tensorboard
# Base archive: download, unpack into the working directory, then remove the zip.
print('Downloading Dataset...')
!gdown -q {dataset_link}
print('Unzipping Dataset...')
!unzip -qq {Dataset}.zip
print('Deleting Dataset zip...')
!rm {Dataset}.zip

# Same three steps for the MRA archive; `dataset_link` is reused for its URL.
dataset_link = mra_link_dict[Dataset]
print('Downloading MRA Dataset...')
!gdown -q {dataset_link}
print('Unzipping MRA Dataset...')
!unzip -qq {Dataset}_MRA.zip
print('Deleting MRA Dataset zip...')
!rm {Dataset}_MRA.zip

print('Done.')
# !gdown https://drive.google.com/uc?id=1jcQtXg8bYzTlgCMcxDqEYuC62j1XKf3t # AIRS_128.zip
# !gdown https://drive.google.com/uc?id=1MohUSBykyDu94E1PTRelC5h0dA-LuuU1 # AIRS_256.zip
# !gdown https://drive.google.com/uc?id=1rxR0XTiDbPeJg88WLZctzPwuL9JquC4g # AIRS_512.zip
# !gdown https://drive.google.com/uc?id=1-5BDBZL8Tsxzi_wshN3d_EF6Zhn2iitY # AIRS_1024.zip
# !gdown https://drive.google.com/uc?id=1a5A8xat3wJmEx11ml1UUqhTjMKbMFkG9 # AIRS_2048.zip
# !gdown https://drive.google.com/uc?id=1-F3N0NBorYSIikB_my4NY8t8NI2NAjk6 # AIRS_4096.zip
# !gdown https://drive.google.com/uc?id=1-2p-z1L2CWI-dr0EDj-9Ve5MXGBurEQI # AIRS_8192.zip
# !gdown https://drive.google.com/uc?id=1-CsjdU3xVDiIuWQ0s95vKC9ceT6a7Vmb # AIRS_10000.zip

In [0]:
# Import libraries
!pip install -U segmentation-models
import segmentation_models as sm
import tensorflow as tf
import os
import cv2
import time
import numpy as np
from tqdm.notebook import tqdm
from pathlib import Path
from matplotlib import pyplot as plt
from IPython import display
from tensorflow.keras import backend as K
from random import randint, shuffle
import itertools
import matplotlib
import datetime
from tensorflow.keras.utils import plot_model
from tensorflow.keras.mixed_precision import experimental as mixed_precision
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Dropout, concatenate, UpSampling2D, Activation
# Select the tf.keras backend for segmentation_models.
sm.set_framework('tf.keras')
# Install 'mixed_float16' as the global Keras mixed-precision policy.
policy = mixed_precision.Policy('mixed_float16')
mixed_precision.set_policy(policy)
# Default matplotlib colormap for images: grayscale.
matplotlib.rc('image', cmap='gray')
%load_ext tensorboard

In [0]:
import ipywidgets as widgets
# Button that triggers deletion of the `logs` and `models` directories, and an
# Output widget that is displayed together with it.
button = widgets.Button(description="del logs and models")
output = widgets.Output()

def on_button_clicked(b):
  """Delete the `logs` and `models` directories and report that it happened.

  Args:
    b: The clicked button instance passed in by ipywidgets (unused).
  """
  import shutil  # local import: only this handler needs it

  # Print inside the Output widget that is displayed with the button; text
  # printed from a widget callback outside such a context may not be shown.
  with output:
    for directory in ('logs', 'models'):
      # ignore_errors mirrors `rm -rf`: a missing directory is not an error.
      shutil.rmtree(directory, ignore_errors=True)
    print('logs and models deleted.')

# Register the click handler, then render the button and the output area.
button.on_click(on_button_clicked)
display.display(button, output)

In [0]:
#@title Parameters

# Empty dict; not referenced by any other cell visible in this notebook.
resized_sizes = {}

# These two numbers are formatted into the dataset directory name (PATH below).
num_per_cities_train = 1 #@param {type:"integer"}
num_per_cities_val =  5 #@param {type:"integer"}
# Shuffle-buffer size and batch size used when the tf.data pipelines are built.
BUFFER_SIZE = 400 #@param {type:"number"}
BATCH_SIZE = 32 #@param {type:"number"}
# Not referenced by any other cell visible in this notebook.
GEN_FACTOR = 1 #@param {type:"number"}
DISC_FACTOR = 1 #@param {type:"number"}
# When True, load_image_train() passes every training sample through random_jitter().
USE_RANDOM_JITTER = False #@param {type:"boolean"}
# Switches to the "<PATH>_DL1" directories and sets IMG_HEIGHT/IMG_WIDTH to 128
# instead of 256.
TRAIN_DL1 = False #@param {type:"boolean"}
if TRAIN_DL1:
  IMG_HEIGHT = IMG_WIDTH = 128
else:
  IMG_HEIGHT = IMG_WIDTH = 256
print(f'IMG_WIDTH: {IMG_WIDTH}, IMG_HEIGHT: {IMG_HEIGHT}')
# NOTE(review): presumably this must name the directory unpacked by the download
# cell; the two are configured independently, so confirm they agree.
PATH = f'inria_train_{num_per_cities_train}_val_{num_per_cities_val}'

In [0]:
def load(image_file):
  """Read an image and its label mask from disk as float32 tensors.

  The label path is derived from `image_file` by replacing "images" with
  "labels" and then "jpg" with "png". Both substitutions apply to the whole
  path string, not only to its last component.

  Args:
    image_file: Path of the JPEG image (string tensor).

  Returns:
    Tuple `(input_image, real_image)`: the decoded image, and the decoded
    single-channel mask concatenated as `[1 - mask, mask]` on the last axis.
  """
  label_file = tf.strings.regex_replace(
      tf.strings.regex_replace(image_file, "images", "labels"), "jpg", "png")

  image = tf.cast(tf.image.decode_jpeg(tf.io.read_file(image_file)), tf.float32)
  mask = tf.cast(
      tf.image.decode_png(tf.io.read_file(label_file), channels = 1), tf.float32)

  # NOTE(review): `1 - mask` only yields a complementary channel if the PNG
  # stores 0/1 values - confirm against the dataset.
  two_channel_mask = tf.concat([1-mask, mask], axis=-1)

  return image, two_channel_mask

def resize(input_image, real_image, height, width):
  """Resize an image/label pair to `height` x `width` with nearest-neighbour sampling.

  Args:
    input_image: Image tensor to resize.
    real_image: Label tensor to resize the same way.
    height: Target height.
    width: Target width.

  Returns:
    Tuple `(input_image, real_image)` of the resized tensors.
  """
  target_size = [height, width]
  nearest = tf.image.ResizeMethod.NEAREST_NEIGHBOR
  resized_input = tf.image.resize(input_image, target_size, method=nearest)
  resized_real = tf.image.resize(real_image, target_size, method=nearest)
  return resized_input, resized_real

def random_crop(input_image, real_image):
  """Take the same random IMG_HEIGHT x IMG_WIDTH crop from image and label.

  Args:
    input_image: Image tensor; the crop size requests 3 channels.
    real_image: Label tensor; the first two channels are returned.

  Returns:
    Tuple `(input_image, real_image)` cropped at the same location.
  """
  # Append a copy of the label's first channel so the label has as many
  # channels as the crop size asks for and can be stacked with the image.
  padded_label = tf.concat([real_image, real_image[:,:,:1]], axis=-1)
  pair = tf.stack([input_image, padded_label], axis=0)

  # Cropping the stacked pair in one call applies one offset to both tensors.
  crop_shape = [2, IMG_HEIGHT, IMG_WIDTH, 3]
  cropped_pair = tf.image.random_crop(pair, size=crop_shape)

  cropped_input = cropped_pair[0]
  cropped_label = cropped_pair[1][:,:,:2]  # drop the padding channel again
  return cropped_input, cropped_label

def normalize(input_image, real_image):
  """Rescale the image as `x / 127.5 - 1`; the label passes through unchanged.

  Args:
    input_image: Image values to rescale.
    real_image: Label, returned as-is.

  Returns:
    Tuple `(input_image, real_image)`.
  """
  scaled = input_image / 127.5
  shifted = scaled - 1
  return shifted, real_image

def load_image_train(image_file):
  """Load one training sample, optionally jitter it, then normalize it.

  Args:
    image_file: Path of the image to load.

  Returns:
    Tuple `(input_image, real_image)`.
  """
  sample = load(image_file)
  # Augmentation is applied only when the USE_RANDOM_JITTER flag is set.
  if USE_RANDOM_JITTER:
    sample = random_jitter(*sample)
  input_image, real_image = normalize(*sample)
  return input_image, real_image

def load_image_test(image_file):
  """Load one evaluation sample and normalize it; no resizing or augmentation.

  Args:
    image_file: Path of the image to load.

  Returns:
    Tuple `(input_image, real_image)`.
  """
  image, label = load(image_file)
  normalized_image, label = normalize(image, label)
  return normalized_image, label

def generate_images(test_input, tar):
  """Plot input, ground truth and prediction for the first sample of a batch.

  Runs the global `model` on `test_input`, reduces prediction and target to
  class ids with argmax over the last axis, prints accuracy, precision,
  recall, F1 and IoU for that one sample, and shows the three panels.

  Args:
    test_input: Batch of model inputs; only element 0 is displayed.
    tar: Batch of targets; only element 0 is used.
  """
  prediction = model(test_input, training=False)

  pred = np.array(tf.argmax(prediction[0],axis=-1))
  t = np.array(tf.argmax(tar[0],axis=-1))

  # Confusion counts with class 1 as the positive class.
  not_pred = 1-pred
  not_t = 1-t
  tp = (pred*t).sum()
  tn = (not_pred*not_t).sum()
  fp = (pred*not_t).sum()
  fn = (not_pred*t).sum()

  eps = 1e-14  # keeps the ratios defined when a count is zero
  accuracy = (tp + tn + eps) / (tp + tn + fp + fn + eps)
  precision = (tp + eps) / (tp + fp + eps)
  recall = (tp + eps) / (tp + fn + eps)
  f1 = (2 * precision * recall) / (precision + recall)
  iou = (tp + eps) / (tp + fp + fn + eps)

  plt.figure(figsize=(15,15))
  panels = [('Input Image', test_input[0]), ('Ground Truth', t), ('Predicted Image', pred)]
  for position, (panel_title, panel) in enumerate(panels, start=1):
    plt.subplot(1, 3, position)
    plt.title(panel_title)
    # The same `* 0.5 + 0.5` rescaling is applied to all three panels.
    plt.imshow(panel * 0.5 + 0.5)
    plt.axis('off')
  plt.show()
  print(f'For this example:\nAccuracy: {accuracy}\nPrecision: {precision}\nRecall: {recall}\nF1 Score: {f1}\nIoU: {iou}')

def get_metrics(dataset, dataset_length, desc = None):
  """Evaluate the global `model` on `dataset` and return per-sample metrics.

  Predictions and targets are reduced to class ids with argmax over the last
  axis (class 1 is treated as positive), and the confusion counts are summed
  over axes 1 and 2, so every sample contributes one value per metric. The
  means over all evaluated samples are printed.

  Args:
    dataset: Object whose `take(dataset_length)` yields `(inputs, targets)`
      batches.
    dataset_length: Number of batches to evaluate.
    desc: Unused; kept so existing callers keep working.

  Returns:
    Tuple `(accuracys, precisions, recalls, f1s, ious)` of 1-D arrays with one
    entry per evaluated sample. All five are empty when no batch was produced
    (for example when `dataset_length` is 0).
  """
  eps = 1e-14  # keeps every ratio defined when a count is zero
  accuracy_parts, precision_parts, recall_parts, f1_parts, iou_parts = [], [], [], [], []

  for inp, tar in dataset.take(dataset_length):
    prediction = model(inp, training=False)
    pred = np.array(prediction).argmax(axis=-1).astype(np.uint8)
    t = np.array(tar).argmax(axis=-1).astype(np.uint8)
    tp = (pred*t).sum(axis=(1,2))
    tn = ((1-pred)*(1-t)).sum(axis=(1,2))
    fp = (pred*(1-t)).sum(axis=(1,2))
    fn = ((1-pred)*t).sum(axis=(1,2))
    precision = (tp + eps) / (tp + fp + eps)
    recall = (tp + eps) / (tp + fn + eps)
    accuracy_parts.append((tp + tn + eps) / (tp + tn + fp + fn + eps))
    precision_parts.append(precision)
    recall_parts.append(recall)
    f1_parts.append((2 * precision * recall) / (precision + recall))
    iou_parts.append((tp + eps) / (tp + fp + fn + eps))

  if not accuracy_parts:
    # Nothing was evaluated: report it instead of failing on a missing result.
    print('no batches to evaluate.')
    return np.empty(0), np.empty(0), np.empty(0), np.empty(0), np.empty(0)

  # Join the per-batch results once instead of re-copying them every batch.
  accuracys = np.concatenate(accuracy_parts, axis = 0)
  precisions = np.concatenate(precision_parts, axis = 0)
  recalls = np.concatenate(recall_parts, axis = 0)
  f1s = np.concatenate(f1_parts, axis = 0)
  ious = np.concatenate(iou_parts, axis = 0)

  accuracy = accuracys.mean()
  precision = precisions.mean()
  recall = recalls.mean()
  f1 = f1s.mean()
  iou = ious.mean()
  print(f'Accuracy: {accuracy:.2f}, Precision: {precision:.2f}, Recall: {recall:.2f}, F1 Score: {f1:.2f}, IoU: {iou:.2f}')
  return accuracys, precisions, recalls, f1s, ious

@tf.function()
def random_jitter(input_image, real_image):
  """Augment an image/label pair: upscale, random crop, random horizontal flip.

  The pair is resized to 1.125x the configured IMG_HEIGHT x IMG_WIDTH, cropped
  back at a random offset by random_crop(), and mirrored left-right with
  probability 0.5. Image and label always receive the same transformation.

  Args:
    input_image: Image tensor.
    real_image: Label tensor.

  Returns:
    Tuple `(input_image, real_image)` after augmentation.
  """
  # resize() expects (height, width) in that order.
  input_image, real_image = resize(input_image, real_image,
                                   int(1.125*IMG_HEIGHT), int(1.125*IMG_WIDTH))

  input_image, real_image = random_crop(input_image, real_image)

  if tf.random.uniform(()) > 0.5:
    # random mirroring
    input_image = tf.image.flip_left_right(input_image)
    real_image = tf.image.flip_left_right(real_image)

  return input_image, real_image

In [0]:
# Create TF dataset
def _image_dir(split):
  """Return the directory holding the *.jpg images of `split` ('train' or 'val')."""
  return f'{PATH}_DL1/{split}/' if TRAIN_DL1 else f'{PATH}/{split}/images/'

def _num_batches(split):
  """Return how many batches of BATCH_SIZE the *.jpg files of `split` fill."""
  # Count exactly the files list_files() will yield. Counting every directory
  # entry over-counts when the .png labels sit next to the images, which is
  # where load() looks for them when the path has no "images" component.
  num_images = len(list(Path(_image_dir(split)).glob('*.jpg')))
  return int(np.ceil(num_images / BATCH_SIZE))

train_dataset = tf.data.Dataset.list_files(_image_dir('train') + '*.jpg')
train_dataset_length = _num_batches('train')
train_dataset = train_dataset.map(load_image_train, num_parallel_calls=tf.data.experimental.AUTOTUNE)
# Decoded samples are cached before shuffling/batching; repeat() makes the
# stream endless, so consumers must bound it with a step count.
train_dataset = train_dataset.cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE).repeat()
train_dataset = train_dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

val_dataset = tf.data.Dataset.list_files(_image_dir('val') + '*.jpg')
val_dataset_length = _num_batches('val')
val_dataset = val_dataset.map(load_image_test)
val_dataset = val_dataset.batch(BATCH_SIZE)

# Same validation files as `val_dataset`, served one sample at a time.
test_dataset = tf.data.Dataset.list_files(_image_dir('val') + '*.jpg')
test_dataset = test_dataset.map(load_image_test)
test_dataset = test_dataset.batch(1)

print(f'train_dataset_length: {train_dataset_length}, val_dataset_length: {val_dataset_length}')

class DisplayCallback(tf.keras.callbacks.Callback):
  """Keras callback that logs get_metrics() results to TensorBoard every epoch.

  Args:
    writer: Summary writer whose `as_default()` context receives the scalars.
  """
  def __init__(self, writer):
    # Let the Keras base class initialise its own attributes first.
    super().__init__()
    self.writer = writer

  def on_epoch_end(self, epoch, logs=None):
    """Evaluate part of the validation set and write the mean of each metric.

    Args:
      epoch: Epoch index; used as the summary step.
      logs: Metric dict supplied by Keras (unused).
    """
    print('Validation:',end='')
    # An eighth of the validation batches, but never fewer than one.
    num_batches = max(1, val_dataset_length // 8)
    accuracys, precisions, recalls, f1s, ious = get_metrics(val_dataset, num_batches)
    with self.writer.as_default():
      tf.summary.scalar('my_metrics/accuracy', accuracys.mean(), step=epoch)
      tf.summary.scalar('my_metrics/precision', precisions.mean(), step=epoch)
      tf.summary.scalar('my_metrics/recall', recalls.mean(), step=epoch)
      tf.summary.scalar('my_metrics/f1', f1s.mean(), step=epoch)
      tf.summary.scalar('my_metrics/iou', ious.mean(), step=epoch)

# Scan up to 100 single-sample batches and stop at the first one whose label
# channel 1 sums to at least half of IMG_HEIGHT*IMG_WIDTH. If none qualifies,
# `example_input` / `example_target` keep the last sample examined.
for example_input, example_target in tqdm(test_dataset.take(100),desc = 'Finding nice example', total = 100):
    channel_one_sum = tf.reduce_sum(example_target[:,:,:,1])
    if not (channel_one_sum < IMG_HEIGHT*IMG_WIDTH/2):
      break

In [0]:
def l1_loss(y_true, y_pred):
  """Return the mean absolute difference between `y_true` and `y_pred`."""
  residual = y_true - y_pred
  return tf.reduce_mean(tf.abs(residual))

def l2_loss(y_true, y_pred):
  """Return the mean squared difference between `y_true` and `y_pred`."""
  residual = y_true - y_pred
  return tf.reduce_mean(tf.square(residual))

def dice_loss(y_true, y_pred):
  """Soft Dice loss: `1 - mean(2*sum(t*p) / (sum(t^2) + sum(p^2)))`.

  The sums run over axes 1 and 2; the mean runs over the remaining axes. A
  prediction identical to the target gives a loss of 0.

  Args:
    y_true: Ground-truth tensor.
    y_pred: Predicted tensor of the same shape.

  Returns:
    Scalar loss tensor.
  """
  eps = 1e-14  # keeps the ratio defined when both tensors are all zero
  intersection = tf.reduce_sum(y_true * y_pred, axis=[1,2])
  total = tf.reduce_sum(tf.square(y_true) + tf.square(y_pred), axis=[1,2])
  # The factor 2 makes the coefficient reach 1 for a perfect match; without it
  # the coefficient tops out at 0.5.
  dice = (2.0 * intersection + eps) / (total + eps)
  return 1 - tf.reduce_mean(dice)

def bce_loss(y_true, y_pred):
  """Return Keras binary cross-entropy (default settings) of the two tensors."""
  return tf.keras.losses.BinaryCrossentropy()(y_true, y_pred)

def whole_batch_iou_metric(y_true,y_pred):
  """IoU of class 1 with confusion counts pooled over the whole batch.

  Both tensors are reduced to class ids with argmax over the last axis. The
  counts are smoothed with 1e-14, as in the per-sample IoU metrics, so a batch
  with no positive pixel in either tensor yields 1.0 instead of NaN.

  Args:
    y_true: Ground-truth tensor.
    y_pred: Predicted tensor.

  Returns:
    Scalar float64 tensor.
  """
  y_pred = tf.argmax(y_pred,-1)
  y_true = tf.argmax(y_true,-1)
  # float64 keeps the dtype that true division of the int64 counts produced.
  tp = tf.cast(tf.reduce_sum(y_pred*y_true), tf.float64)
  fp = tf.cast(tf.reduce_sum(y_pred*(1-y_true)), tf.float64)
  fn = tf.cast(tf.reduce_sum((1-y_pred)*y_true), tf.float64)
  iou = (tp + 1e-14) / (tp + fp + fn + 1e-14)
  return iou

def smpl_iou(y_true,y_pred):
  """Mean over the batch of the per-sample IoU of class 1.

  Both tensors are reduced to class ids with argmax over the last axis and the
  confusion counts are summed over axes 1 and 2. The 1e-14 smoothing makes a
  sample with no positive pixel in either tensor score 1.

  Args:
    y_true: Ground-truth tensor.
    y_pred: Predicted tensor.

  Returns:
    Scalar float32 tensor.
  """
  y_pred = tf.argmax(y_pred,-1)
  y_true = tf.argmax(y_true,-1)
  # True negatives do not enter the IoU, so they are not computed.
  tp = tf.cast(tf.reduce_sum(y_pred*y_true,axis=[1,2]), 'float32')
  fp = tf.cast(tf.reduce_sum(y_pred*(1-y_true),axis=[1,2]), 'float32')
  fn = tf.cast(tf.reduce_sum((1-y_pred)*y_true,axis=[1,2]), 'float32')
  iou = (tp + 1e-14) / (tp + fp + fn + 1e-14)
  return tf.reduce_mean(iou)

def smpl_miou(y_true, y_pred):
  """Mean over the batch of the per-sample average of class-1 and class-0 IoU.

  Args:
    y_true: Ground-truth tensor; reduced to class ids with argmax on the last axis.
    y_pred: Predicted tensor; reduced the same way.

  Returns:
    Scalar float32 tensor.
  """
  pred_ids = tf.argmax(y_pred,-1)
  true_ids = tf.argmax(y_true,-1)

  def _count(indicator):
    # Per-sample count: sum over axes 1 and 2, as float32.
    return tf.cast(tf.reduce_sum(indicator,axis=[1,2]), 'float32')

  tp = _count(pred_ids*true_ids)
  tn = _count((1-pred_ids)*(1-true_ids))
  fp = _count(pred_ids*(1-true_ids))
  fn = _count((1-pred_ids)*true_ids)

  class_one_iou = (tp + 1e-14)/(tp+fp+fn + 1e-14)
  class_zero_iou = (tn + 1e-14)/(tn+fp+fn + 1e-14)
  miou = (class_one_iou + class_zero_iou)/2
  return tf.reduce_mean(miou)

class MyMeanIOU(tf.keras.metrics.MeanIoU):
    """MeanIoU that takes the argmax over the last axis of both inputs first."""

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Convert both tensors to class ids, then defer to `MeanIoU.update_state`."""
        true_ids = tf.argmax(y_true, axis=-1)
        pred_ids = tf.argmax(y_pred, axis=-1)
        return super().update_state(true_ids, pred_ids, sample_weight)

In [0]:
# og_model = sm.Unet(BACKBONE, encoder_weights='imagenet', classes=2, activation='softmax',encoder_freeze=True)
# act = tf.keras.layers.Activation('softmax',dtype='float32')(og_model.layers[-2].output)
# model = tf.keras.models.Model(inputs=og_model.inputs, outputs=act)
# plot_model(model,show_shapes=True, show_layer_names=False)
# !pip install -U tensorboard_plugin_profile

In [0]:
BACKBONE = 'resnet34'
# NOTE(review): obtained here but not applied anywhere in this cell.
preprocess_input = sm.get_preprocessing(BACKBONE)
EPOCHS = 40

# Train one model per loss function; each run gets its own models/ and logs/ sub-directory.
for loss_type, loss_func in [
    ('l1_loss', l1_loss),
    ('l2_loss', l2_loss),
    ('bce_dice_loss', sm.losses.bce_dice_loss),
    ('bce_jaccard_loss', sm.losses.bce_jaccard_loss),
    ('binary_focal_loss', sm.losses.binary_focal_loss),
    ('binary_focal_dice_loss', sm.losses.binary_focal_dice_loss),
    ('binary_focal_jaccard_loss', sm.losses.binary_focal_jaccard_loss),
]:
  # Run identifier, e.g. 'resnet34/full/jitter_no/1e-3/l1_loss'.
  modelname = f'{BACKBONE}/'
  modelname += 'DL1/' if TRAIN_DL1 else 'full/'
  modelname += 'jitter_yes/' if USE_RANDOM_JITTER else 'jitter_no/'
  modelname += '1e-3/'
  modelname += loss_type
  Path(f'models/{modelname}').mkdir(parents=True,exist_ok = True)

  # One timestamp shared by the TensorBoard callback and the custom summary
  # writer, so both always log into the same run directory.
  # NOTE(review): the +5:30 shift presumably converts the runtime clock to a
  # local time zone - confirm.
  run_stamp = (datetime.datetime.now()+datetime.timedelta(hours=5,minutes=30)).strftime("%Y-%m-%d %H:%M:%S")
  run_log_dir = f"logs/{modelname}/" + run_stamp

  # Checkpointing and LR reduction both follow the validation value of smpl_iou.
  checkpointer = tf.keras.callbacks.ModelCheckpoint(filepath=f'models/{modelname}'+'/model.h5', monitor='val_smpl_iou', mode='max', save_best_only=True, verbose = 0)
  tb_callback = tf.keras.callbacks.TensorBoard(log_dir=run_log_dir,write_graph=False)
  reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_smpl_iou', mode = 'max', factor=1/np.sqrt(10), patience = 5, min_lr=1e-6, verbose = 0)
  summary_writer = tf.summary.create_file_writer(run_log_dir)
  callbacks = [checkpointer, tb_callback, reduce_lr, DisplayCallback(summary_writer)]
  metrics = [smpl_iou, sm.metrics.iou_score]

  # Build the U-Net, then replace its last layer with a softmax that is forced
  # to float32 and applied to the output of the second-to-last layer.
  og_model = sm.Unet(BACKBONE, encoder_weights='imagenet', classes=2, activation='softmax',encoder_freeze=True)
  act = tf.keras.layers.Activation('softmax',dtype='float32')(og_model.layers[-2].output)
  model = tf.keras.models.Model(inputs=og_model.inputs, outputs=act)
  model.compile(optimizer=tf.keras.optimizers.Adam(1e-3), loss=loss_func, metrics=metrics)

  # Validate on an eighth of the validation batches, but never on zero batches.
  model_history = model.fit(train_dataset, epochs=EPOCHS, steps_per_epoch=train_dataset_length,
                            validation_data=val_dataset,validation_steps=max(1, val_dataset_length // 8),
                            callbacks=callbacks, verbose = 2)

In [0]:
%tensorboard --logdir logs

In [0]:
!pip install image-classifiers

In [0]:
from classification_models.tfkeras import Classifiers

In [0]:
# Look up the 'resnet34' entry of classification_models. This rebinds
# `preprocess_input`, which the training cell also assigns.
ResNet34, preprocess_input = Classifiers.get('resnet34')

In [0]:
# Instantiate ResNet34 without its top for 256x256x3 inputs.
resnet34 = ResNet34(include_top=False, input_shape = (256,256,3))

In [0]:
# Print the layer summary of the model built in the previous cell.
resnet34.summary()