<a href="https://colab.research.google.com/github/airsresincrop/AIRS/blob/master/new_MRA_pix2pix.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
# Master switch for the MRA variant: when True the notebook also downloads the
# {Dataset}_MRA archive and load() returns the LL/LH/HL/HH tiles next to the RGB image.
# NOTE(review): "MRA" presumably stands for multi-resolution analysis - confirm.
USE_MRA = True #@param {type:"boolean"}

In [0]:
#@title Download Dataset (only have to run once per dataset per runtime)
Dataset = "AIRS_2048" #@param ["AIRS_256", "AIRS_512", "AIRS_1024", "AIRS_2048"]

# Google Drive download URLs of the base dataset archives, keyed by dataset name.
link_dict = {
  'AIRS_256': 'https://drive.google.com/uc?id=1q8MKNhQ4n05loymGAxI6vjMoQl9mEHbq',
  'AIRS_512': 'https://drive.google.com/uc?id=1WcEfywv97anuPs9AYE64oO3FEsROC0jg',
  'AIRS_1024': 'https://drive.google.com/uc?id=1x3WTLwWdbsbcDbkPmT5YvrX1WnvWPrK5',
  'AIRS_2048': 'https://drive.google.com/uc?id=1--eTWLDHAw0AlZdVTBIssbrdJ_1hEf-U'
}

# Google Drive download URLs of the matching MRA archives (same keys as link_dict).
mra_link_dict = {
  'AIRS_256': 'https://drive.google.com/uc?id=1-EFPPIJH2plGy7UqOyH2nXA75OOomxRZ',
  'AIRS_512': 'https://drive.google.com/uc?id=1-17C-mEy8YHwpKUbXFScLa0xYVnmdT2K',
  'AIRS_1024': 'https://drive.google.com/uc?id=1-R2hfd9OxiFy0pAjbqk4rEBsBJG92Fo4',
  'AIRS_2048': 'https://drive.google.com/uc?id=1-VhkFJRMa-hVm0fYsoi7skZX9qpqWy88'
}

dataset_link = link_dict[Dataset]

# Fetch the base archive, unpack it into the working directory, then delete the
# zip so it does not take up disk space on the runtime.
print('Installing Tensorboard...')
!pip install -q -U tensorboard
print('Downloading Dataset...')
!gdown -q {dataset_link}
print('Unzipping Dataset...')
!unzip -qq {Dataset}.zip
print('Deleting Dataset zip...')
!rm {Dataset}.zip

# dataset_link is reused for the MRA archive; it is looked up unconditionally
# but only downloaded when USE_MRA is set.
dataset_link = mra_link_dict[Dataset]

if USE_MRA:
  print('Downloading MRA Dataset...')
  !gdown -q {dataset_link}
  print('Unzipping MRA Dataset...')
  !unzip -qq {Dataset}_MRA.zip
  print('Deleting MRA Dataset zip...')
  !rm {Dataset}_MRA.zip

print('Done.')
# Unused alternative archive links, disabled.
# NOTE(review): presumably older uploads - confirm before relying on them.
# !gdown https://drive.google.com/uc?id=1jcQtXg8bYzTlgCMcxDqEYuC62j1XKf3t # AIRS_128.zip
# !gdown https://drive.google.com/uc?id=1MohUSBykyDu94E1PTRelC5h0dA-LuuU1 # AIRS_256.zip
# !gdown https://drive.google.com/uc?id=1rxR0XTiDbPeJg88WLZctzPwuL9JquC4g # AIRS_512.zip
# !gdown https://drive.google.com/uc?id=1-5BDBZL8Tsxzi_wshN3d_EF6Zhn2iitY # AIRS_1024.zip
# !gdown https://drive.google.com/uc?id=1a5A8xat3wJmEx11ml1UUqhTjMKbMFkG9 # AIRS_2048.zip
# !gdown https://drive.google.com/uc?id=1-F3N0NBorYSIikB_my4NY8t8NI2NAjk6 # AIRS_4096.zip
# !gdown https://drive.google.com/uc?id=1-2p-z1L2CWI-dr0EDj-9Ve5MXGBurEQI # AIRS_8192.zip
# !gdown https://drive.google.com/uc?id=1-CsjdU3xVDiIuWQ0s95vKC9ceT6a7Vmb # AIRS_10000.zip

In [0]:
# Import libraries
import tensorflow as tf
import os
import cv2
import time
import numpy as np
from tqdm.notebook import tqdm
from pathlib import Path
from matplotlib import pyplot as plt
from IPython import display
from random import randint, shuffle
import itertools
import matplotlib
matplotlib.rc('image', cmap='gray')

In [0]:
#@title Curate Dataset for TF (only run the first time)

# Per-split lists: entry i of resized_sizes[split] and datasets[split] together
# select the source folder full_data/AIRS_<size>/<dataset>.
resized_sizes = {}
datasets = {}

resized_sizes_train = ["512", "1024", "2048"] #@param {type:"raw"}
datasets_train =  ["crop_full", "crop_512", "crop_512"]#@param {type:"raw"}
resized_sizes['train'] = resized_sizes_train
datasets['train'] = datasets_train
resized_sizes_val =  ["1024"]#@param {type:"raw"}
datasets_val =  ["crop_512"]#@param {type:"raw"}
resized_sizes['val'] = resized_sizes_val
datasets['val'] = datasets_val

# The tile size is taken from the FIRST training entry only: the resized size
# for a "crop_full" dataset, otherwise the crop size embedded in the dataset name.
# Every other listed dataset must produce tiles of this same size, because the
# canvas slices below are assigned without any resizing.
IMG_HEIGHT = IMG_WIDTH = int(resized_sizes['train'][0])
CROP = datasets['train'][0].split('_')[1]
if CROP != 'full':
  IMG_WIDTH = int(CROP)
  IMG_HEIGHT = IMG_WIDTH
print(f'IMG_WIDTH: {IMG_WIDTH}, IMG_HEIGHT: {IMG_HEIGHT}')

# Output directory name; it is also reused later for checkpoints and logs.
PATH = f'train_{"_".join(resized_sizes["train"])}__val_{"_".join(resized_sizes["val"])}__size_{IMG_WIDTH}'

curate_dataset = True #@param {type:"boolean"}

if curate_dataset:
  print('Curating Dataset...')
  # if USE_MRA:
  for split in ['train','val']:
    # Running index across all source datasets of a split, used as the file name.
    img_count = 0
    for resized_size, dataset in zip(resized_sizes[split],datasets[split]):
      images_path = f'full_data/AIRS_{resized_size}/{dataset}'
      print(images_path)
      # File stems are expected to be integers; they are sorted numerically.
      images = sorted([f[:-4] for f in os.listdir(f'{images_path}/{split}/images/') if 'jpg' in f],key=lambda x:int(x))
      Path(f'{PATH}/{split}').mkdir(parents = True, exist_ok = True)
      for f in tqdm(images, desc = split):
        # One (2*H, 2*W) canvas per sample, laid out as:
        #   top-left     : RGB image
        #   top-right    : label
        #   bottom-left  : the four DL1 tiles LL | LH over HL | HH, each (H/2, W/2)
        #   bottom-right : DL1 label in its lower-right (H/2, W/2) corner, rest zero
        # The MRA files are read unconditionally, so this cell needs the MRA archive.
        # NOTE(review): labels are scaled by 255, presumably because they are stored as 0/1 - confirm.
        final_image = np.zeros((IMG_HEIGHT*2, IMG_WIDTH*2,3),dtype=np.uint8)
        final_image[:IMG_HEIGHT,:IMG_WIDTH] = plt.imread(f'{images_path}/{split}/images/{f}.jpg')
        final_image[:IMG_HEIGHT,IMG_WIDTH:] = plt.imread(f'{images_path}/{split}/labels/{f}.tif')*255
        final_image[IMG_HEIGHT:IMG_HEIGHT+IMG_HEIGHT//2,:IMG_WIDTH//2] = plt.imread(f'{images_path}_DL1/LL/{split}/{f}.jpg')
        final_image[IMG_HEIGHT:IMG_HEIGHT+IMG_HEIGHT//2,IMG_WIDTH//2:IMG_WIDTH] = plt.imread(f'{images_path}_DL1/LH/{split}/{f}.jpg')
        final_image[IMG_HEIGHT+IMG_HEIGHT//2:,:IMG_WIDTH//2] = plt.imread(f'{images_path}_DL1/HL/{split}/{f}.jpg')
        final_image[IMG_HEIGHT+IMG_HEIGHT//2:,IMG_WIDTH//2:IMG_WIDTH] = plt.imread(f'{images_path}_DL1/HH/{split}/{f}.jpg')
        final_image[IMG_HEIGHT+IMG_HEIGHT//2:,IMG_WIDTH+IMG_WIDTH//2:] = plt.imread(f'{images_path}_DL1/labels/{split}/{f}.tif')*255
        plt.imsave(f'{PATH}/{split}/{str(img_count).zfill(6)}.png',final_image)
        img_count += 1
  # Disabled non-MRA variant (image | label side by side), kept for reference:
  # else:
  #   for split in ['train','val']:
  #     img_count = 0
  #     for resized_size, dataset in zip(resized_sizes[split],datasets[split]):
  #       images_path = f'full_data/AIRS_{resized_size}/{dataset}'
  #       print(images_path)
  #       images = sorted([f[:-4] for f in os.listdir(f'{images_path}/{split}/images/') if 'jpg' in f],key=lambda x:int(x))
  #       Path(f'{PATH}/{split}').mkdir(parents = True, exist_ok = True)
  #       for f in tqdm(images, desc = split):
  #         img = plt.imread(f'{images_path}/{split}/images/{f}.jpg')
  #         lbl = plt.imread(f'{images_path}/{split}/labels/{f}.tif')*255
  #         stack = np.concatenate([img, lbl],axis=1)
  #         plt.imsave(f'{PATH}/{split}/{str(img_count).zfill(6)}.png',stack)
  #         img_count += 1

## Run from here every time you train on the same dataset

In [0]:
#@title Choose losses for Generator
# Each USE_* flag enables one extra term in generator_loss(); the constant that
# follows it is the weight that term is multiplied by before it is added to the
# adversarial loss.
# L1 (mean absolute error) term, weighted by ALPHA.
USE_L1_LOSS = True #@param {type:"boolean"}
ALPHA = 100 #@param {type:"number"}
# L2 (mean squared error) term, weighted by BETA.
USE_L2_LOSS = False #@param {type:"boolean"}
BETA = 25 #@param {type:"number"}
# Dice term, weighted by GAMMA.
USE_DICE_LOSS = False #@param {type:"boolean"}
GAMMA =  25#@param {type:"number"}
# Binary cross-entropy term, weighted by DELTA.
USE_BCE_LOSS = False #@param {type:"boolean"}
DELTA = 25 #@param {type:"number"}

In [0]:
#@title Parameters
# Parameters
# Shuffle buffer size of the training tf.data pipeline.
BUFFER_SIZE = 400 #@param {type:"number"}
# Batch size of the training and validation datasets.
BATCH_SIZE = 16 #@param {type:"number"}
# Integer divisors applied to every filter count of the generator / discriminator
# (1 keeps the full width).
GEN_FACTOR = 1 #@param {type:"number"}
DISC_FACTOR = 1 #@param {type:"number"}
# Number of channels the generator emits and the discriminator takes as target
# (load() builds a two-channel background/foreground target).
OUTPUT_CHANNELS = 2 #@param {type:"number"}
# Upper bound on training epochs passed to fit().
EPOCHS = 100 #@param {type:"number"}
# fit() stops after this many consecutive epochs without a better validation mIoU.
EARLYSTOPPING_LIMIT = 20 #@param {type:"number"}
# When True, training samples go through random_jitter() (resize, crop, mirror).
USE_RANDOM_JITTER = True #@param {type:"boolean"}

In [0]:
# Helper Functions 
def load(image_file):
  """Read one curated PNG canvas and split it into model input and target.

  The canvas is treated as a 2x2 grid of square tiles whose side is half the
  canvas width: the top-left tile is the RGB image and the first channel of
  the top-right tile is the label mask.

  Args:
    image_file: path of the PNG file.

  Returns:
    (input_image, real_image), both float32 with values still in [0, 255].
    real_image stacks (255 - mask, mask) on the last axis. When USE_MRA is set,
    input_image is the tuple (rgb, LL, LH, HL, HH), where the four extra tiles
    are the half-size quadrants of the bottom-left canvas tile.
  """
  canvas = tf.image.decode_png(tf.io.read_file(image_file))[:,:,:3]

  # Tile side and half tile side, both derived from the canvas width.
  side = tf.shape(canvas)[1] // 2
  half = side // 2

  rgb = tf.cast(canvas[:side, :side, :], tf.float32)
  mask = tf.cast(canvas[:side, side:, 0], tf.float32)
  target = tf.stack([255-mask, mask], axis=-1)

  if not USE_MRA:
    return rgb, target

  # Rows [side, side+half) hold LL | LH, the remaining rows hold HL | HH.
  split_row = side + half
  ll = tf.cast(canvas[side:split_row, :half], tf.float32)
  lh = tf.cast(canvas[side:split_row, half:side], tf.float32)
  hl = tf.cast(canvas[split_row:, :half], tf.float32)
  hh = tf.cast(canvas[split_row:, half:side], tf.float32)
  return (rgb, ll, lh, hl, hh), target

def resize(input_image, real_image, height, width):
  """Nearest-neighbour resize of an input/target pair.

  Args:
    input_image: the RGB image, or with USE_MRA the tuple
      (rgb, LL, LH, HL, HH).
    real_image: the target image.
    height: output height of the full-resolution images.
    width: output width of the full-resolution images.

  Returns:
    (input_image, real_image) in the same structure as given. The RGB image and
    the target are resized to (height, width); the four sub-band images are
    resized to (height//2, width//2).
  """
  nearest = tf.image.ResizeMethod.NEAREST_NEIGHBOR
  full_size = [height, width]

  if USE_MRA:
    rgb, ll, lh, hl, hh = input_image
    half_size = [height//2, width//2]
    rgb = tf.image.resize(rgb, full_size, method=nearest)
    ll, lh, hl, hh = (
        tf.image.resize(band, half_size, method=nearest)
        for band in (ll, lh, hl, hh)
    )
    resized_input = (rgb, ll, lh, hl, hh)
  else:
    resized_input = tf.image.resize(input_image, full_size, method=nearest)

  resized_target = tf.image.resize(real_image, full_size, method=nearest)
  return resized_input, resized_target

def random_crop(input_image, real_image):
  """Cut one random IMG_HEIGHT x IMG_WIDTH window out of an enlarged pair.

  The pair is expected to have been resized to int(1.125*IMG_HEIGHT) x
  int(1.125*IMG_WIDTH) beforehand (see random_jitter). The same window is
  applied to the input and the target so they stay aligned.

  With USE_MRA the window is sampled in sub-band coordinates (half resolution)
  and doubled for the full-resolution images, so the RGB image, the target and
  the four sub-bands all show the same region.

  Args:
    input_image: the RGB image, or with USE_MRA the tuple
      (rgb, LL, LH, HL, HH).
    real_image: the target image.

  Returns:
    (input_image, real_image) cropped, in the same structure as given.
  """
  if USE_MRA:
    input_image, LL_image, LH_image, HL_image, HH_image = input_image
    crop_w, crop_h = IMG_WIDTH//2, IMG_HEIGHT//2
    # tf.random.uniform excludes maxval, hence the +1: without it the window
    # touching the right/bottom border could never be drawn.
    x2 = tf.random.uniform((), minval=crop_w, maxval=int(1.125*IMG_WIDTH)//2 + 1, dtype=tf.int32)
    y2 = tf.random.uniform((), minval=crop_h, maxval=int(1.125*IMG_HEIGHT)//2 + 1, dtype=tf.int32)
    x1 = x2 - crop_w
    y1 = y2 - crop_h
    LL_image = LL_image[y1:y2,x1:x2]
    LH_image = LH_image[y1:y2,x1:x2]
    HL_image = HL_image[y1:y2,x1:x2]
    HH_image = HH_image[y1:y2,x1:x2]
    # Switch from sub-band coordinates to full-resolution coordinates.
    x1,x2,y1,y2 = x1*2,x2*2,y1*2,y2*2
  else:
    x2 = tf.random.uniform((), minval=IMG_WIDTH, maxval=int(1.125*IMG_WIDTH) + 1, dtype=tf.int32)
    y2 = tf.random.uniform((), minval=IMG_HEIGHT, maxval=int(1.125*IMG_HEIGHT) + 1, dtype=tf.int32)
    x1 = x2 - IMG_WIDTH
    y1 = y2 - IMG_HEIGHT

  input_image = input_image[y1:y2,x1:x2]
  real_image = real_image[y1:y2,x1:x2]
  if USE_MRA:
    input_image = (input_image, LL_image, LH_image, HL_image, HH_image)
  return input_image, real_image

def normalize(input_image, real_image):
  """Scale a [0, 255] pair into the ranges the networks are fed with.

  Args:
    input_image: the RGB image, or with USE_MRA the tuple
      (rgb, LL, LH, HL, HH).
    real_image: the target image.

  Returns:
    (input_image, real_image). Inputs are mapped to [-1, 1] and the target to
    [0, 1]. With USE_MRA the input becomes the pair (rgb, DL1_stacked), where
    DL1_stacked concatenates LL, LH, HL, HH along the channel axis.
  """
  def to_signed_unit(pixels):
    # [0, 255] -> [-1, 1]
    return (pixels/127.5) - 1

  if USE_MRA:
    rgb, ll, lh, hl, hh = input_image
    stacked_bands = tf.concat(
        [to_signed_unit(ll), to_signed_unit(lh), to_signed_unit(hl), to_signed_unit(hh)],
        axis = -1)
    scaled_input = (to_signed_unit(rgb), stacked_bands)
  else:
    scaled_input = to_signed_unit(input_image)

  return scaled_input, real_image/255

def load_image_train(image_file):
  """Training-time loader: read, optionally augment, then normalise.

  Augmentation (random_jitter) is applied only when USE_RANDOM_JITTER is set;
  otherwise the loaded pair goes straight to normalize() without any resizing.

  Args:
    image_file: path of one curated PNG.

  Returns:
    (input_image, real_image) as produced by normalize().
  """
  sample = load(image_file)
  if USE_RANDOM_JITTER:
    sample = random_jitter(*sample)
  augmented_input, augmented_target = sample
  return normalize(augmented_input, augmented_target)

def load_image_test(image_file):
  """Evaluation-time loader: read and normalise, with no augmentation.

  Args:
    image_file: path of one curated PNG.

  Returns:
    (input_image, real_image) as produced by normalize().
  """
  raw_input, raw_target = load(image_file)
  return normalize(raw_input, raw_target)

def dl_conv_layer(filters, size, apply_batchnorm=True):
  """Resolution-preserving conv block: Conv2D(stride 1) -> [BatchNorm] -> LeakyReLU.

  Args:
    filters: number of convolution filters.
    size: kernel size.
    apply_batchnorm: insert BatchNormalization after the convolution.

  Returns:
    A tf.keras.Sequential holding the layers.
  """
  layers = [
      tf.keras.layers.Conv2D(
          filters, size, strides=1, padding='same',
          kernel_initializer=tf.random_normal_initializer(0., 0.02),
          use_bias=False),
  ]
  if apply_batchnorm:
    layers.append(tf.keras.layers.BatchNormalization())
  layers.append(tf.keras.layers.LeakyReLU())

  return tf.keras.Sequential(layers)

def downsample(filters, size, apply_batchnorm=True):
  """Halving conv block: Conv2D(stride 2) -> [BatchNorm] -> LeakyReLU.

  Args:
    filters: number of convolution filters.
    size: kernel size.
    apply_batchnorm: insert BatchNormalization after the convolution.

  Returns:
    A tf.keras.Sequential holding the layers.
  """
  layers = [
      tf.keras.layers.Conv2D(
          filters, size, strides=2, padding='same',
          kernel_initializer=tf.random_normal_initializer(0., 0.02),
          use_bias=False),
  ]
  if apply_batchnorm:
    layers.append(tf.keras.layers.BatchNormalization())
  layers.append(tf.keras.layers.LeakyReLU())

  return tf.keras.Sequential(layers)

def upsample(filters, size, apply_dropout=False):
  """Doubling block: Conv2DTranspose(stride 2) -> BatchNorm -> [Dropout 0.5] -> ReLU.

  Args:
    filters: number of transposed-convolution filters.
    size: kernel size.
    apply_dropout: insert Dropout(0.5) after the batch normalisation.

  Returns:
    A tf.keras.Sequential holding the layers.
  """
  layers = [
      tf.keras.layers.Conv2DTranspose(
          filters, size, strides=2, padding='same',
          kernel_initializer=tf.random_normal_initializer(0., 0.02),
          use_bias=False),
      tf.keras.layers.BatchNormalization(),
  ]
  if apply_dropout:
    layers.append(tf.keras.layers.Dropout(0.5))
  layers.append(tf.keras.layers.ReLU())

  return tf.keras.Sequential(layers)

def Generator(factor=1,use_dl1=False,use_dl2=False):
  """Build the U-Net style generator as a Keras functional model.

  Args:
    factor: integer divisor applied to every filter count (1 = full width).
    use_dl1: add a second model input of shape
      (IMG_HEIGHT//2, IMG_WIDTH//2, 12) that is merged in after the first
      downsampling block.
    use_dl2: add a further model input of shape
      (IMG_HEIGHT//4, IMG_WIDTH//4, 12) that is merged in after the second
      downsampling block.

  Returns:
    tf.keras.Model whose inputs are [image, (dl1), (dl2)] in that order and
    whose output has OUTPUT_CHANNELS softmax channels at the input resolution.

  The shape comments below assume a 512x512 input, factor=1 and no extra
  dl inputs.
  """
  inputs = [tf.keras.layers.Input(shape=[IMG_HEIGHT,IMG_WIDTH,3])]
  # Feature maps computed from the optional extra inputs, in merge order.
  dl_inputs = []
  if use_dl1:
    inputs.append(tf.keras.layers.Input(shape=[IMG_HEIGHT//2,IMG_WIDTH//2,12]))
    dl_inputs.append(dl_conv_layer(64//factor, 3, apply_batchnorm=False)(inputs[-1]))
  if use_dl2:
    inputs.append(tf.keras.layers.Input(shape=[IMG_HEIGHT//4,IMG_WIDTH//4,12]))
    dl_inputs.append(dl_conv_layer(128//factor, 3, apply_batchnorm=False)(inputs[-1]))
  
  down_stack = [
    downsample(64//factor, 3, apply_batchnorm=False), # (bs, 256, 256, 64)
    downsample(128//factor, 3), # (bs, 128, 128, 128)
    downsample(256//factor, 3), # (bs, 64, 64, 256)
    downsample(512//factor, 3), # (bs, 32, 32, 512)
    downsample(512//factor, 3), # (bs, 16, 16, 512)
    downsample(512//factor, 3), # (bs, 8, 8, 512)
    downsample(512//factor, 3), # (bs, 4, 4, 512)
    downsample(512//factor, 3, apply_batchnorm=False), # (bs, 2, 2, 512)
  ]

  # Channel counts in these comments are after concatenation with the skip.
  up_stack = [
    upsample(512//factor, 3, apply_dropout=True), # (bs, 4, 4, 1024)
    upsample(512//factor, 3, apply_dropout=True), # (bs, 8, 8, 1024)
    upsample(512//factor, 3, apply_dropout=True), # (bs, 16, 16, 1024)
    upsample(512//factor, 3), # (bs, 32, 32, 1024)
    upsample(256//factor, 3), # (bs, 64, 64, 512)
    upsample(128//factor, 3), # (bs, 128, 128, 256)
    upsample(64//factor, 3), # (bs, 256, 256, 128)
  ]

  initializer = tf.random_normal_initializer(0., 0.02)
  last = tf.keras.layers.Conv2DTranspose(OUTPUT_CHANNELS, 3,
                                         strides=2,
                                         padding='same',
                                         kernel_initializer=initializer,
                                         activation='softmax') # (bs, 512, 512, OUTPUT_CHANNELS)

  x = inputs[0]

  # Downsampling through the model
  skips = []
  for num_layer, down in enumerate(down_stack):
    # Merge the dl1 features into the output of the first downsampling block.
    # They are popped afterwards so that dl_inputs[0] is the dl2 features,
    # whether or not dl1 was used.
    if num_layer == 1 and use_dl1:
      x = tf.keras.layers.Concatenate()([x, dl_inputs[0]])
      dl_inputs.pop(0)
    # Merge the dl2 features into the output of the second downsampling block.
    if num_layer == 2 and use_dl2:
      x = tf.keras.layers.Concatenate()([x, dl_inputs[0]])
    x = down(x)
    skips.append(x)

  # The bottleneck output is not a skip connection; the rest are used deepest first.
  skips = reversed(skips[:-1])

  # Upsampling and establishing the skip connections
  for up, skip in zip(up_stack, skips):
    x = up(x)
    x = tf.keras.layers.Concatenate()([x, skip])

  x = last(x)

  return tf.keras.Model(inputs=inputs, outputs=x)

def generator_loss(disc_generated_output, gen_output, target):
  """Compute the generator objective and its individual terms.

  The adversarial term is always present. The L1, L2, Dice and BCE terms are
  added when the matching USE_*_LOSS flag is set, weighted by ALPHA, BETA,
  GAMMA and DELTA respectively.

  Args:
    disc_generated_output: discriminator logits for the generated image.
    gen_output: generator output.
    target: ground-truth target, same shape as gen_output.

  Returns:
    dict of loss tensors: 'gan_loss', the enabled extra terms ('l1_loss',
    'l2_loss', 'dice_loss', 'bce_loss', unweighted) and 'total_gen_loss'.
  """
  gen_loss_dict = {}

  # The generator wants the discriminator to label its output as real (ones).
  gan_loss = loss_object(tf.ones_like(disc_generated_output), disc_generated_output)
  total_gen_loss = gan_loss
  gen_loss_dict['gan_loss'] = gan_loss

  if USE_L1_LOSS:
    l1_loss = tf.reduce_mean(tf.abs(target - gen_output))
    total_gen_loss += ALPHA * l1_loss
    gen_loss_dict['l1_loss'] = l1_loss
  if USE_L2_LOSS:
    l2_loss = tf.reduce_mean(tf.square(target - gen_output))
    total_gen_loss += BETA * l2_loss
    gen_loss_dict['l2_loss'] = l2_loss
  if USE_DICE_LOSS:
    # Dice coefficient with squared denominator: 2*sum(t*p) / (sum(t^2) + sum(p^2)).
    # The factor 2 is required for a perfect prediction to score 1 (loss 0);
    # without it the coefficient is capped at 0.5.
    numerator = 2 * tf.reduce_sum(target * gen_output,axis=[1,2]) + 1e-6
    denominator = tf.reduce_sum(tf.square(target) + tf.square(gen_output), axis=[1,2]) + 1e-6
    dice_loss = 1 - tf.reduce_mean(numerator / denominator)
    total_gen_loss += GAMMA * dice_loss
    gen_loss_dict['dice_loss'] = dice_loss
  if USE_BCE_LOSS:
    bce = tf.keras.losses.BinaryCrossentropy()
    bce_loss = bce(target, gen_output)
    total_gen_loss += DELTA * bce_loss
    gen_loss_dict['bce_loss'] = bce_loss

  gen_loss_dict['total_gen_loss'] = total_gen_loss

  return gen_loss_dict

def Discriminator(factor=1,use_dl1=False,use_dl2=False):
  """Build the patch discriminator as a Keras functional model.

  Args:
    factor: integer divisor applied to every filter count (1 = full width).
    use_dl1: add an input of shape (IMG_HEIGHT//2, IMG_WIDTH//2, 12) that is
      merged in after the first downsampling block.
    use_dl2: add an input of shape (IMG_HEIGHT//4, IMG_WIDTH//4, 12) that is
      merged in after the second downsampling block.

  Returns:
    tf.keras.Model with inputs ordered
    [input_image, (input_dl1), (input_dl2), target_image] and a single-channel
    map of unactivated scores (one per patch) as output.

  The shape comments below assume a 512x512 input, factor=1 and no extra
  dl inputs.
  """
  initializer = tf.random_normal_initializer(0., 0.02)

  inp = tf.keras.layers.Input(shape=[IMG_HEIGHT, IMG_WIDTH, 3], name='input_image')
  tar = tf.keras.layers.Input(shape=[IMG_HEIGHT, IMG_WIDTH, OUTPUT_CHANNELS], name='target_image')
  inputs = [inp, tar]

  x = tf.keras.layers.concatenate([inp, tar]) # (bs, 512, 512, channels*2)

  down1 = downsample(64//factor, 3, False)(x) # (bs, 256, 256, 64)

  if use_dl1:
    dl1_input = tf.keras.layers.Input(shape=[IMG_HEIGHT//2, IMG_WIDTH//2, 12], name='input_dl1')
    # insert(-1, ...) keeps the target image as the last model input.
    inputs.insert(-1, dl1_input)
    dl1_conv = dl_conv_layer(64//factor, 3, False)(dl1_input)
    down1 = tf.keras.layers.concatenate([down1, dl1_conv])

  down2 = downsample(128//factor, 3)(down1) # (bs, 128, 128, 128)

  if use_dl2:
    dl2_input = tf.keras.layers.Input(shape=[IMG_HEIGHT//4, IMG_WIDTH//4, 12], name='input_dl2')
    inputs.insert(-1, dl2_input)
    dl2_conv = dl_conv_layer(128//factor, 3, False)(dl2_input)
    down2 = tf.keras.layers.concatenate([down2, dl2_conv])

  down3 = downsample(256//factor, 3)(down2) # (bs, 64, 64, 256)

  # Explicit zero padding followed by an unpadded 3x3 convolution keeps the
  # spatial size unchanged.
  zero_pad1 = tf.keras.layers.ZeroPadding2D()(down3) # (bs, 66, 66, 256)
  conv = tf.keras.layers.Conv2D(512//factor, 3, strides=1,
                                kernel_initializer=initializer,
                                use_bias=False)(zero_pad1) # (bs, 64, 64, 512)

  batchnorm1 = tf.keras.layers.BatchNormalization()(conv)

  leaky_relu = tf.keras.layers.LeakyReLU()(batchnorm1)

  zero_pad2 = tf.keras.layers.ZeroPadding2D()(leaky_relu) # (bs, 66, 66, 512)

  # No activation here: the losses are computed with from_logits=True.
  last = tf.keras.layers.Conv2D(1, 3, strides=1,
                                kernel_initializer=initializer)(zero_pad2) # (bs, 64, 64, 1)

  return tf.keras.Model(inputs=inputs, outputs=last)

def discriminator_loss(disc_real_output, disc_generated_output):
  """Compute the discriminator objective and its two parts.

  Args:
    disc_real_output: discriminator output for the (input, ground truth) pair.
    disc_generated_output: discriminator output for the (input, generated) pair.

  Returns:
    dict with 'real_loss' (real pair scored against ones), 'generated_loss'
    (generated pair scored against zeros) and 'total_disc_loss' (their sum).
  """
  loss_on_real = loss_object(tf.ones_like(disc_real_output), disc_real_output)
  loss_on_generated = loss_object(tf.zeros_like(disc_generated_output), disc_generated_output)

  return {
      'real_loss': loss_on_real,
      'generated_loss': loss_on_generated,
      'total_disc_loss': loss_on_real + loss_on_generated,
  }

def generate_images(model, test_input, tar):
  """Plot the first example of a batch next to its prediction and print its mIoU.

  Args:
    model: generator, called as model(test_input, training=False).
    test_input: batched model input; with USE_MRA the pair
      (rgb_batch, stacked_subband_batch), otherwise the rgb batch.
    tar: batched target; the class map is its argmax over the last axis.

  Side effects:
    Shows one matplotlib figure (input, [MRA mosaic], ground truth,
    prediction) and prints the mean IoU over foreground and background for
    the plotted example.
  """
  prediction = model(test_input, training=False)

  pred = np.array(tf.argmax(prediction[0],axis=-1))
  t = np.array(tf.argmax(tar[0],axis=-1))
  # The 1e-6 terms keep the ratios defined when a class is absent from both maps.
  fore_inter = np.logical_and(pred,t).sum()+1e-6
  fore_union = np.logical_or(pred,t).sum()+1e-6
  back_inter = np.logical_and(1-pred,1-t).sum()+1e-6
  back_union = np.logical_or(1-pred,1-t).sum()+1e-6
  miou = (fore_inter/fore_union + back_inter/back_union)/2

  if USE_MRA:
    rgb, subbands = test_input[0][0], test_input[1][0]
    # Re-assemble the 12 stacked sub-band channels into one 2x2 mosaic
    # (LL | LH over HL | HH) the size of the RGB image.
    mosaic = np.zeros(rgb.shape,np.float32)
    half_h, half_w = IMG_HEIGHT//2, IMG_WIDTH//2
    mosaic[:half_h,:half_w] = subbands[:,:,:3]
    mosaic[:half_h,half_w:] = subbands[:,:,3:6]
    mosaic[half_h:,:half_w] = subbands[:,:,6:9]
    mosaic[half_h:,half_w:] = subbands[:,:,9:]
    panels = [('Input Image', rgb), ('MRA', mosaic),
              ('Ground Truth', t), ('Predicted Image', pred)]
  else:
    panels = [('Input Image', test_input[0]),
              ('Ground Truth', t), ('Predicted Image', pred)]

  plt.figure(figsize=(20,8))
  for position, (panel_title, picture) in enumerate(panels, start=1):
    plt.subplot(1, len(panels), position)
    plt.title(panel_title)
    # getting the pixel values between [0, 1] to plot it.
    plt.imshow(picture * 0.5 + 0.5)
    plt.axis('off')
  # tight_layout must run before show(); afterwards it would act on a new, empty figure.
  plt.tight_layout()
  plt.show()
  print(f'mIoU for this example: {miou}')

def get_miou(model, dataset, dataset_length):
  """Run the model over a dataset and compute the mean IoU over all pixels.

  Args:
    model: callable invoked as model(inp, training=False); its output is
      reduced to a class map with argmax over the last axis.
    dataset: object with a take(n) method yielding (input, target) batches.
    dataset_length: maximum number of batches to evaluate.

  Returns:
    (preds, ts, miou): the uint8 predicted and target class maps of every
    evaluated example concatenated along axis 0, and the mean of the
    foreground and background IoU computed over all of them.

  Raises:
    ValueError: if the dataset yields no batches.
  """
  pred_batches = []
  target_batches = []
  for inp, tar in dataset.take(dataset_length):
    prediction = model(inp, training=False)
    pred_batches.append(np.array(prediction).argmax(axis=-1).astype(np.uint8))
    target_batches.append(np.array(tar).argmax(axis=-1).astype(np.uint8))

  if not pred_batches:
    raise ValueError('get_miou: dataset yielded no batches')

  # Concatenate once at the end rather than once per batch, which re-copied
  # every earlier prediction on each iteration.
  preds = np.concatenate(pred_batches, axis = 0)
  ts = np.concatenate(target_batches, axis = 0)

  # The 1e-6 terms keep the ratios defined when a class never occurs.
  fore_inter = np.logical_and(preds,ts).sum()+1e-6
  fore_union = np.logical_or(preds,ts).sum()+1e-6
  back_inter = np.logical_and(1-preds,1-ts).sum()+1e-6
  back_union = np.logical_or(1-preds,1-ts).sum()+1e-6
  miou = (fore_inter/fore_union + back_inter/back_union)/2
  print(f'mIoU: {miou}')
  return preds, ts, miou

In [0]:
@tf.function()
def random_jitter(input_image, real_image):
  """Augment one input/target pair: enlarge, random-crop back, maybe mirror.

  Args:
    input_image: the RGB image, or with USE_MRA the tuple
      (rgb, LL, LH, HL, HH).
    real_image: the target image.

  Returns:
    (input_image, real_image) in the same structure as given, cropped to
    IMG_HEIGHT x IMG_WIDTH (sub-bands to half of that).
  """
  # Enlarge by 12.5% (e.g. 512 -> 576). resize() takes (height, width) in that order.
  input_image, real_image = resize(input_image, real_image, int(1.125*IMG_HEIGHT), int(1.125*IMG_WIDTH))

  # Randomly crop back to IMG_HEIGHT x IMG_WIDTH.
  input_image, real_image = random_crop(input_image, real_image)

  if tf.random.uniform(()) > 0.5:
    # Random mirroring: every image of the sample is flipped together so the
    # input, sub-bands and target stay aligned.
    if USE_MRA:
      input_image, LL_image, LH_image, HL_image, HH_image = input_image
      input_image = tf.image.flip_left_right(input_image)
      LL_image = tf.image.flip_left_right(LL_image)
      LH_image = tf.image.flip_left_right(LH_image)
      HL_image = tf.image.flip_left_right(HL_image)
      HH_image = tf.image.flip_left_right(HH_image)
      input_image = (input_image, LL_image, LH_image, HL_image, HH_image)
    else:
      input_image = tf.image.flip_left_right(input_image)
    real_image = tf.image.flip_left_right(real_image)

  return input_image, real_image

In [0]:
# Visual sanity check of load() and random_jitter() on one fixed training sample.
# `inp` and `re` stay in the notebook namespace and are reused by later cells.
# NOTE(review): `re` has the same name as the stdlib regex module; harmless
# while that module is not imported, but easy to trip over.
inp, re = load(f'{PATH}/train/000099.png')
# scale the [0, 255] float image to [0, 1] so matplotlib can display it
fig = plt.figure(figsize=(12,4))
plt.subplot(131)
if USE_MRA:
  plt.imshow(inp[0]/255.0)
else:
  plt.imshow(inp/255.0)
plt.axis('off')
plt.tight_layout()
plt.subplot(132)
# argmax over the two target channels gives the class map
plt.imshow(tf.argmax(re,-1)/255)
plt.axis('off')
if USE_MRA:
  plt.subplot(133)
  # Re-assemble the four half-size tiles into one mosaic: LL | LH over HL | HH.
  plotpic = np.zeros(inp[0].shape,np.uint8)
  plotpic[:len(plotpic)//2,:len(plotpic)//2] = inp[1]
  plotpic[:len(plotpic)//2,len(plotpic)//2:] = inp[2]
  plotpic[len(plotpic)//2:,:len(plotpic)//2] = inp[3]
  plotpic[len(plotpic)//2:,len(plotpic)//2:] = inp[4]
  plt.imshow(plotpic/255.0)
  plt.axis('off')
# Four independent random_jitter() draws: left column shows the image with the
# class map overlaid, right column (MRA only) shows the jittered tile mosaic.
fig = plt.figure(figsize=(12, 24))
for i in range(0,8,2):
  rj_inp, rj_re = random_jitter(inp, re)
  plt.subplot(4, 2, i+1)
  if USE_MRA:
    plt.imshow(rj_inp[0]/255.0)
  else:
    plt.imshow(rj_inp/255.0)
  plt.imshow(tf.argmax(rj_re,-1)/255, alpha=0.5)
  plt.axis('off')
  if USE_MRA:
    plt.subplot(4, 2, i+2)
    plotpic = np.zeros(inp[0].shape,np.uint8)
    plotpic[:len(plotpic)//2,:len(plotpic)//2] = rj_inp[1]
    plotpic[:len(plotpic)//2,len(plotpic)//2:] = rj_inp[2]
    plotpic[len(plotpic)//2:,:len(plotpic)//2] = rj_inp[3]
    plotpic[len(plotpic)//2:,len(plotpic)//2:] = rj_inp[4]
    plt.imshow(plotpic/255.0)
    plt.axis('off')
  #plt.tight_layout()

In [0]:
# Create TF dataset
# Training pipeline: list the curated PNGs, load and augment each one, shuffle, batch.
train_dataset = tf.data.Dataset.list_files(PATH+'/train/*.png')
train_dataset = train_dataset.map(load_image_train,
                                  num_parallel_calls=tf.data.experimental.AUTOTUNE)
train_dataset = train_dataset.shuffle(BUFFER_SIZE)
train_dataset = train_dataset.batch(BATCH_SIZE)
# The *_length values count BATCHES, not images. They are obtained by iterating
# the whole pipeline once, which loads and decodes every image.
train_dataset_length = 0
for _ in train_dataset:
  train_dataset_length += 1
print(f'train_dataset_length: {train_dataset_length}')

# Validation pipeline: same files-to-tensors path without augmentation or shuffle().
val_dataset = tf.data.Dataset.list_files(PATH+'/val/*.png')
val_dataset = val_dataset.map(load_image_test)
val_dataset = val_dataset.batch(BATCH_SIZE)
val_dataset_length = 0
for _ in val_dataset:
  val_dataset_length += 1
print(f'val_dataset_length: {val_dataset_length}')

# "Test" pipeline: reads the SAME val/ files, one image per batch. fit() uses it
# only to pick the example that is plotted every epoch.
test_dataset = tf.data.Dataset.list_files(PATH+'/val/*.png')
test_dataset = test_dataset.map(load_image_test)
test_dataset = test_dataset.batch(1)
test_dataset_length = 0
for _ in test_dataset:
  test_dataset_length += 1
print(f'test_dataset_length: {test_dataset_length}')

# Disabled variant that reads a separate test/ split:
# test_dataset = tf.data.Dataset.list_files(PATH+'/test/*.png')
# test_dataset = test_dataset.map(load_image_test)
# test_dataset = test_dataset.batch(BATCH_SIZE)
# test_dataset_length = 0
# for _ in test_dataset:
#   test_dataset_length += 1
# print(f'test_dataset_length: {test_dataset_length}')

In [0]:
# Shape sanity check: push the preview image through one throw-away downsample
# block and one throw-away upsample block and print the resulting shapes.
down_model = downsample(3, 3)
if USE_MRA:
  down_result = down_model(tf.expand_dims(inp[0], 0))
else:
  down_result = down_model(tf.expand_dims(inp, 0))
print (down_result.shape)
up_model = upsample(3, 4)
up_result = up_model(down_result)
print (up_result.shape)

In [0]:
# Build the generator and discriminator and run both once on the preview sample
# (untrained, so the plots only confirm that shapes and wiring are right).

# Preview sample in the form the networks are trained on: inputs scaled to
# [-1, 1]; with MRA the four sub-band tiles are stacked on the channel axis and
# passed as a second input.
if USE_MRA:
  example_dl1 = tf.concat([inp[1],inp[2],inp[3],inp[4]],axis=-1)
  example_model_input = ((tf.expand_dims(inp[0],0)/127.5)-1,
                         (tf.expand_dims(example_dl1,0)/127.5)-1)
else:
  example_model_input = (tf.expand_dims(inp,0)/127.5)-1
example_target = tf.expand_dims(re,0)/255

# Create Generator
# The extra DL1 input exists only in the MRA data, so it is enabled by USE_MRA;
# a hard-coded True makes the model unusable when USE_MRA is False.
generator = Generator(factor=GEN_FACTOR, use_dl1=USE_MRA)
gen_output = generator(example_model_input, training=False)
plt.imshow(tf.argmax(gen_output[0],-1))

# Create Discriminator
plt.figure(figsize=(12,4))
plt.subplot(121)
discriminator = Discriminator(DISC_FACTOR, use_dl1=USE_MRA)
# The same normalised input goes into both calls; only the second element
# (generated vs. ground truth) differs.
disc_out = discriminator([example_model_input, gen_output], training=False)
plt.imshow(disc_out[0,...,-1], vmin=-20, vmax=20, cmap='RdBu_r')
plt.colorbar()
plt.title('Generated image input')
plt.subplot(122)
disc_out = discriminator([example_model_input, example_target], training=False)
plt.imshow(disc_out[0,...,-1], vmin=-20, vmax=20, cmap='RdBu_r')
plt.colorbar()
plt.title('Ground truth input')

In [0]:
import datetime

# Adversarial loss shared by generator_loss() and discriminator_loss(); the
# discriminator emits raw scores, hence from_logits=True.
loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)
generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

# Checkpoints hold both models and both optimizers; only the two most recent
# are kept on disk.
checkpoint_dir = f'./models/{PATH}'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)
manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=2)

log_dir = "logs/"

# Label the run by the variant actually trained, so non-MRA runs do not end up
# under "use_mra" in TensorBoard. The MRA path is unchanged.
run_variant = "use_mra" if USE_MRA else "no_mra"
summary_writer = tf.summary.create_file_writer(
  log_dir + f"{PATH}/{run_variant}/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))

In [0]:
@tf.function
def train_step(input_image, target, epoch):
  """Run one optimisation step of both networks on a single batch.

  Args:
    input_image: batched generator input (a tuple when extra inputs are used).
    target: batched ground-truth target.
    epoch: accepted for the caller's convenience; not used here.

  Returns:
    (gen_loss_dict, disc_loss_dict) as produced by generator_loss() and
    discriminator_loss() for this batch.
  """
  # One tape per network, so each loss is differentiated only with respect to
  # its own variables.
  with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
    fake = generator(input_image, training=True)

    score_real = discriminator([input_image, target], training=True)
    score_fake = discriminator([input_image, fake], training=True)

    gen_losses = generator_loss(score_fake, fake, target)
    disc_losses = discriminator_loss(score_real, score_fake)

  gen_vars = generator.trainable_variables
  disc_vars = discriminator.trainable_variables

  gen_grads = gen_tape.gradient(gen_losses['total_gen_loss'], gen_vars)
  disc_grads = disc_tape.gradient(disc_losses['total_disc_loss'], disc_vars)

  generator_optimizer.apply_gradients(zip(gen_grads, gen_vars))
  discriminator_optimizer.apply_gradients(zip(disc_grads, disc_vars))
  return gen_losses, disc_losses

In [0]:
def fit(train_ds, train_ds_length, val_ds, val_ds_length, epochs, test_ds):
  """Train both networks with early stopping on validation mIoU.

  Each epoch: plot the preview example, run train_step() over train_ds, log the
  losses of the last batch to TensorBoard, compute train and validation mIoU,
  and save a checkpoint when the validation mIoU improves.

  Args:
    train_ds: batched training dataset.
    train_ds_length: number of batches in train_ds.
    val_ds: batched validation dataset.
    val_ds_length: number of batches in val_ds.
    epochs: maximum number of epochs.
    test_ds: dataset whose first element is used as the preview example.

  Raises:
    ValueError: if test_ds yields no element.
  """
  best_val_miou = 0
  epochs_since_best_val_miou = 0

  try:
    example_input, example_target = next(iter(test_ds.take(1)))
  except StopIteration:
    raise ValueError('fit: test_ds is empty, no example to preview') from None

  for epoch in range(epochs):

    display.clear_output(wait=True)
    generate_images(generator, example_input, example_target)

    # Pass the epoch as a tensor: a Python int argument makes tf.function trace
    # a new graph for every distinct value, i.e. once per epoch.
    epoch_tensor = tf.constant(epoch, dtype=tf.int64)

    # Train
    # Start from empty dicts so the summary loops below are no-ops if train_ds is empty.
    gen_loss_dict, disc_loss_dict = {}, {}
    for input_image, target in tqdm(train_ds, total = train_ds_length, desc = f'{epoch + 1}'):
      gen_loss_dict, disc_loss_dict = train_step(input_image, target, epoch_tensor)
    print()

    # The logged losses are those of the LAST batch of the epoch.
    with summary_writer.as_default():
      for key,value in gen_loss_dict.items():
        tf.summary.scalar(f'gen_losses/{key}', value, step=epoch)
      for key,value in disc_loss_dict.items():
        tf.summary.scalar(f'disc_losses/{key}', value, step=epoch)

    # Train mIoU
    _, _, train_miou = get_miou(generator, train_ds, train_ds_length)
    print(f'Train mIoU: {train_miou*100:.4f} %')
    # Validate mIoU
    _, _, val_miou = get_miou(generator, val_ds, val_ds_length)
    print(f'Validation mIoU: {val_miou*100:.4f} %')
    with summary_writer.as_default():
      tf.summary.scalar('mious/1_train_miou', train_miou*100, step=epoch)
      tf.summary.scalar('mious/2_val_miou', val_miou*100, step=epoch)
      tf.summary.scalar('mious/3_difference', (train_miou-val_miou)*100, step=epoch)

    if val_miou > best_val_miou:
      print('Best mIoU so far, saving checkpoint')
      manager.save()
      best_val_miou = val_miou
      epochs_since_best_val_miou = 0
    else:
      epochs_since_best_val_miou += 1

    if epochs_since_best_val_miou == EARLYSTOPPING_LIMIT:
      print(f'{epochs_since_best_val_miou} epochs since improvement, stopping.')
      print(f'Last improvement at epoch #{epoch+1-epochs_since_best_val_miou}')
      break

In [0]:
#docs_infra: no_execute
# Start TensorBoard inline, pointed at the directory the summary writer logs into.
%load_ext tensorboard
%tensorboard --logdir {log_dir}

In [0]:
# Start training. test_dataset is only used to pick the example that is plotted each epoch.
fit(train_dataset, train_dataset_length, val_dataset, val_dataset_length, EPOCHS, test_dataset)

### Evaluation

In [0]:
# restoring the latest checkpoint in checkpoint_dir
# fit() saves only when the validation mIoU improves, so the latest checkpoint
# is also the best-scoring one of the run that wrote it.
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

In [0]:
# Final evaluation of the restored generator on the training and validation data.
# a/d = predicted class maps, b/e = target class maps, c/f = mIoU (see get_miou()).
# NOTE(review): the *4 on the batch counts looks like a leftover; a count larger
# than the dataset should simply evaluate every batch - confirm.
# NOTE(review): train_dataset still applies random jitter when USE_RANDOM_JITTER
# is set, so the training score is measured on augmented images.
print(USE_MRA)
print('Training')
a,b,c = get_miou(generator, train_dataset, train_dataset_length*4)
print('Validating')
d,e,f = get_miou(generator, val_dataset, val_dataset_length*4)
# print('Testing')
# _, _, _ = get_miou(generator, test_dataset, test_dataset_length)