<a href="https://colab.research.google.com/github/airsresincrop/AIRS/blob/master/inria_pix2pix.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
#@title Download Dataset (only have to run once per dataset per runtime)
# Colab form cell: choose a dataset variant, then download and unpack both its
# main archive and its companion "_MRA" archive into the current directory.
Dataset = "inria_train_1_val_1" #@param ["inria_train_5_val_5", "inria_train_5_val_1", "inria_train_1_val_5", "inria_train_1_val_1"]

# Google Drive download URL of the main archive for each dataset variant.
link_dict = {
  'inria_train_5_val_5': 'https://drive.google.com/uc?id=1-GLe20S5IXAKw1XY__-TSwRtmGmNNDbQ',
  'inria_train_5_val_1': 'https://drive.google.com/uc?id=1-55dfqcRB08zpsgsqFWVKy60ocajjIXq',
  'inria_train_1_val_5': 'https://drive.google.com/uc?id=1-PVdP2n6LczlVzP6bMrCRuoJplQarUJk',
  'inria_train_1_val_1': 'https://drive.google.com/uc?id=1-ROwYE-zqHEwOolymVhUjWEQy4WdVjTr'
}

# Google Drive download URL of the MRA archive for each dataset variant.
# NOTE(review): "MRA" presumably stands for multi-resolution analysis - confirm.
mra_link_dict = {
  'inria_train_5_val_5': 'https://drive.google.com/uc?id=11XGUyBfK87W7UGLv4xRQuDUKP6Liy0VH',
  'inria_train_5_val_1': 'https://drive.google.com/uc?id=1-0c6VFLoZN1Ldrtpg4yul3vlOxwuUyjX',
  'inria_train_1_val_5': 'https://drive.google.com/uc?id=1-6Z3ocQ-5pnf-_F5UWJ6oEtfhrVWVuwO',
  'inria_train_1_val_1': 'https://drive.google.com/uc?id=1-Bjji5CHm5LNxMB858QbGMeYxybwijTH'
}

dataset_link = link_dict[Dataset]

print('Installing Tensorboard...')
# NOTE(review): the upgrade is unpinned, so the installed TensorBoard version
# can differ between runtimes.
!pip install -q -U tensorboard
print('Downloading Dataset...')
!gdown -q {dataset_link}
print('Unzipping Dataset...')
# The unzip/rm steps expect gdown to have saved the download as "<Dataset>.zip".
!unzip -qq {Dataset}.zip
print('Deleting Dataset zip...')
!rm {Dataset}.zip

# Same download / unzip / delete sequence for the MRA archive; dataset_link is
# reused, so after this cell it holds the MRA URL.
dataset_link = mra_link_dict[Dataset]
print('Downloading MRA Dataset...')
!gdown -q {dataset_link}
print('Unzipping MRA Dataset...')
# Expects the download to be named "<Dataset>_MRA.zip".
!unzip -qq {Dataset}_MRA.zip
print('Deleting MRA Dataset zip...')
!rm {Dataset}_MRA.zip

print('Done.')
# !gdown https://drive.google.com/uc?id=1jcQtXg8bYzTlgCMcxDqEYuC62j1XKf3t # AIRS_128.zip
# !gdown https://drive.google.com/uc?id=1MohUSBykyDu94E1PTRelC5h0dA-LuuU1 # AIRS_256.zip
# !gdown https://drive.google.com/uc?id=1rxR0XTiDbPeJg88WLZctzPwuL9JquC4g # AIRS_512.zip
# !gdown https://drive.google.com/uc?id=1-5BDBZL8Tsxzi_wshN3d_EF6Zhn2iitY # AIRS_1024.zip
# !gdown https://drive.google.com/uc?id=1a5A8xat3wJmEx11ml1UUqhTjMKbMFkG9 # AIRS_2048.zip
# !gdown https://drive.google.com/uc?id=1-F3N0NBorYSIikB_my4NY8t8NI2NAjk6 # AIRS_4096.zip
# !gdown https://drive.google.com/uc?id=1-2p-z1L2CWI-dr0EDj-9Ve5MXGBurEQI # AIRS_8192.zip
# !gdown https://drive.google.com/uc?id=1-CsjdU3xVDiIuWQ0s95vKC9ceT6a7Vmb # AIRS_10000.zip

In [0]:
# Import libraries
import tensorflow as tf
import os
import cv2
import time
import numpy as np
from tqdm.notebook import tqdm
from pathlib import Path
from matplotlib import pyplot as plt
from IPython import display
from random import randint, shuffle
import itertools
import matplotlib
matplotlib.rc('image', cmap='gray')

In [0]:
#@title Curate Dataset for TF (only run the first time)
# Converts the downloaded image/label files into single PNGs that hold the
# image and its label side by side; `load` later splits them apart again.

resized_sizes = {}

num_cities_train = 1 #@param {type:"integer"}
resized_sizes['train'] = num_cities_train
num_cities_val =  1 #@param {type:"integer"}
resized_sizes['val'] = num_cities_val

IMG_HEIGHT = IMG_WIDTH = 256
print(f'IMG_WIDTH: {IMG_WIDTH}, IMG_HEIGHT: {IMG_HEIGHT}')

# Output root for the curated pairs; the DL1 pairs are written to PATH + '_DL1'.
# PATH embeds the IMG_WIDTH value set just above.
PATH = f'train_{num_cities_train}_val_{num_cities_val}_size_{IMG_WIDTH}'

# Leave unchecked when the curated folders already exist in this runtime.
curate_dataset = False #@param {type:"boolean"}

if curate_dataset:
  print('Curating Dataset...')
  # if USE_MRA:
  for split in ['train','val']:
    # --- Pass 1: images from <images_path>/<split>/images and labels ---
    img_count = 0
    images_path = f'inria_train_{num_cities_train}_val_{num_cities_val}'
    print(images_path)
    # File stems of the JPEG images. NOTE(review): 'jpg' is matched anywhere in
    # the name and the stem assumes a 4-character extension. os.listdir returns
    # names in arbitrary order, so the output numbering is not guaranteed to be
    # stable between runs.
    images = [f[:-4] for f in os.listdir(f'{images_path}/{split}/images/') if 'jpg' in f]
    Path(f'{PATH}/{split}').mkdir(parents = True, exist_ok = True)
    for f in tqdm(images, desc = split):
      # The image is divided by 255 and the label multiplied by 255 before saving.
      # NOTE(review): presumably the JPEG loads as 0-255 integers and the label
      # PNG as floats holding 0 or 1/255 - confirm against the dataset.
      img = plt.imread(f'{images_path}/{split}/images/{f}.jpg')/255
      lbl = plt.imread(f'{images_path}/{split}/labels/{f}.png')*255
      # Concatenate along axis 1 (width): image on the left, label on the right;
      # files are numbered with a zero-padded 6-digit counter.
      plt.imsave(f'{PATH}/{split}/{str(img_count).zfill(6)}.png',np.concatenate([img,lbl],axis=1))
      img_count += 1
    # --- Pass 2: the same file stems, read from the "<images_path>_DL1" tree ---
    img_count = 0
    Path(f'{PATH}_DL1/{split}').mkdir(parents = True, exist_ok = True)
    # The stems are listed again from the pass-1 image folder, not from the DL1 tree.
    images = [f[:-4] for f in os.listdir(f'{images_path}/{split}/images/') if 'jpg' in f]
    for f in tqdm(images, desc = split):
      # The DL1 tree is laid out as <kind>/<split>/ (LL/ and labels/), unlike pass 1.
      img = plt.imread(f'{images_path}_DL1/LL/{split}/{f}.jpg')/255
      # Unlike pass 1, this label is not multiplied by 255.
      # NOTE(review): presumably the DL1 labels are stored differently - confirm.
      lbl = plt.imread(f'{images_path}_DL1/labels/{split}/{f}.png')
      plt.imsave(f'{PATH}_DL1/{split}/{str(img_count).zfill(6)}.png',np.concatenate([img,lbl],axis=1))
      img_count += 1

## Run from here every time you train on the same dataset

In [0]:
#@title Choose losses for Generator
# Switches and weights read by `generator_loss`. The adversarial (GAN) term is
# always part of the generator loss; each enabled term below is added to it
# after being multiplied by its weight.
# L1 (mean absolute error) term, weighted by ALPHA.
USE_L1_LOSS = True #@param {type:"boolean"}
ALPHA = 100 #@param {type:"number"}
# L2 (mean squared error) term, weighted by BETA.
USE_L2_LOSS = False #@param {type:"boolean"}
BETA = 25 #@param {type:"number"}
# Dice term, weighted by GAMMA.
USE_DICE_LOSS = False #@param {type:"boolean"}
GAMMA =  25#@param {type:"number"}
# Binary cross-entropy term, weighted by DELTA.
USE_BCE_LOSS = False #@param {type:"boolean"}
DELTA = 25 #@param {type:"number"}

In [0]:
#@title Parameters
# Parameters
# Shuffle buffer size applied to the training dataset.
BUFFER_SIZE = 400 #@param {type:"number"}
# Batch size of the training and validation datasets.
BATCH_SIZE = 32 #@param {type:"number"}
# Divisors applied to every filter count in Generator / Discriminator
# (a larger factor gives a narrower network).
GEN_FACTOR = 4 #@param {type:"number"}
DISC_FACTOR = 4 #@param {type:"number"}
# Maximum number of epochs passed to fit().
EPOCHS = 100 #@param {type:"number"}
# fit() stops after this many consecutive epochs without a better validation IoU.
EARLYSTOPPING_LIMIT = 20 #@param {type:"number"}
# When True, load_image_train applies random_jitter to every training sample.
USE_RANDOM_JITTER = True #@param {type:"boolean"}
# When True, data is read from PATH + '_DL1' instead of PATH.
TRAIN_DL1 = True #@param {type:"boolean"}
# Crop size used by random_crop / random_jitter. This overrides the value set in
# the curation cell; PATH is not recomputed here and keeps the size it was built with.
if TRAIN_DL1:
  IMG_HEIGHT = IMG_WIDTH = 128
else:
  IMG_HEIGHT = IMG_WIDTH = 256

In [0]:
# Helper Functions 
def load(image_file):
  """Read one curated PNG and split it into an input image and a label.

  The file holds two pictures side by side: the left half is the RGB input,
  the right half is the label, of which only the first channel is used.

  Args:
    image_file: path of the PNG file to read.

  Returns:
    (input_image, real_image) as float32 tensors. `input_image` is the left
    half with 3 channels. `real_image` has 2 channels: `255 - label` in
    channel 0 and `label` in channel 1.
  """
  encoded = tf.io.read_file(image_file)
  # Keep only the first three channels (drops an alpha channel if present).
  pixels = tf.image.decode_png(encoded)[:, :, :3]

  half_width = tf.shape(pixels)[1] // 2

  input_image = tf.cast(pixels[:, :half_width, :], tf.float32)
  label = tf.cast(pixels[:, half_width:, 0], tf.float32)

  # Two-channel encoding: complement of the label first, label second.
  real_image = tf.stack([255 - label, label], axis=-1)

  return input_image, real_image

def resize(input_image, real_image, height, width):
  """Resize the input image and its label to (height, width).

  Both tensors are resampled with nearest-neighbour interpolation.

  Returns:
    (input_image, real_image) after resizing.
  """
  target_size = [height, width]
  nearest = tf.image.ResizeMethod.NEAREST_NEIGHBOR

  resized_input = tf.image.resize(input_image, target_size, method=nearest)
  resized_label = tf.image.resize(real_image, target_size, method=nearest)

  return resized_input, resized_label

def random_crop(input_image, real_image):
  """Cut the same random IMG_HEIGHT x IMG_WIDTH window out of image and label.

  Args:
    input_image: 3-channel image tensor.
    real_image: 2-channel label tensor with the same height and width.

  Returns:
    (cropped_input, cropped_label); the label has its 2 channels again.
  """
  # The label gets a copy of its first channel appended so that it has three
  # channels and can be stacked with the image; one crop then hits both.
  padded_label = tf.concat([real_image, real_image[:, :, :1]], axis=-1)
  pair = tf.stack([input_image, padded_label], axis=0)

  cropped_pair = tf.image.random_crop(
      pair, size=[2, IMG_HEIGHT, IMG_WIDTH, 3])

  cropped_input = cropped_pair[0]
  cropped_label = cropped_pair[1][:, :, :2]  # drop the helper channel again
  return cropped_input, cropped_label

def normalize(input_image, real_image):
  """Rescale a 0-255 image to [-1, 1] and a 0-255 label to [0, 1].

  Returns:
    (input_image, real_image) after rescaling.
  """
  scaled_input = input_image / 127.5 - 1
  scaled_label = real_image / 255
  return scaled_input, scaled_label

def load_image_train(image_file):
  """Load one training sample: read, optionally jitter, then normalize.

  Random jitter is applied only when the global USE_RANDOM_JITTER is true.

  Returns:
    (input_image, real_image) as produced by `normalize`.
  """
  pair = load(image_file)
  if USE_RANDOM_JITTER:
    pair = random_jitter(*pair)
  input_image, real_image = normalize(*pair)
  return input_image, real_image

def load_image_test(image_file):
  """Load one evaluation sample: read and normalize, with no augmentation.

  The sample is not resized; it keeps the size stored in the file.

  Returns:
    (input_image, real_image) as produced by `normalize`.
  """
  raw_input, raw_label = load(image_file)
  return normalize(raw_input, raw_label)

def downsample(filters, size, apply_batchnorm=True):
  """Build an encoder block: stride-2 convolution, optional batch norm, LeakyReLU.

  Args:
    filters: number of convolution filters.
    size: convolution kernel size.
    apply_batchnorm: insert BatchNormalization after the convolution.

  Returns:
    A `tf.keras.Sequential` that halves the spatial size of its input.
  """
  conv = tf.keras.layers.Conv2D(
      filters, size, strides=2, padding='same',
      kernel_initializer=tf.random_normal_initializer(0., 0.02),
      use_bias=False)

  block_layers = [conv]
  if apply_batchnorm:
    block_layers.append(tf.keras.layers.BatchNormalization())
  block_layers.append(tf.keras.layers.LeakyReLU())

  return tf.keras.Sequential(block_layers)

def upsample(filters, size, apply_dropout=False):
  """Build a decoder block: stride-2 transposed conv, batch norm, optional dropout, ReLU.

  Args:
    filters: number of transposed-convolution filters.
    size: kernel size.
    apply_dropout: insert Dropout(0.5) after the batch normalization.

  Returns:
    A `tf.keras.Sequential` that doubles the spatial size of its input.
  """
  deconv = tf.keras.layers.Conv2DTranspose(
      filters, size, strides=2, padding='same',
      kernel_initializer=tf.random_normal_initializer(0., 0.02),
      use_bias=False)

  block_layers = [deconv, tf.keras.layers.BatchNormalization()]
  if apply_dropout:
    block_layers.append(tf.keras.layers.Dropout(0.5))
  block_layers.append(tf.keras.layers.ReLU())

  return tf.keras.Sequential(block_layers)

def Generator(factor=1):
  """Build the U-Net style generator.

  Seven stride-2 encoder blocks are followed by six stride-2 decoder blocks
  with skip connections and one final stride-2 transposed convolution, so the
  output has the same height and width as the input. Height and width should
  be multiples of 128 so that every skip tensor matches its decoder tensor
  (a 128 x 128 input reaches a 1 x 1 bottleneck).

  Args:
    factor: divisor applied to every filter count (64..512) to shrink the net.

  Returns:
    A `tf.keras.Model` mapping (bs, H, W, 3) images to (bs, H, W, 2) per-pixel
    softmax scores.
  """
  inputs = tf.keras.layers.Input(shape=[None, None, 3])

  # Encoder: each block halves H and W. Only the first block skips batch norm.
  encoder_widths = [64, 128, 256, 512, 512, 512, 512]
  down_stack = [
    downsample(width // factor, 3, apply_batchnorm=(position > 0))
    for position, width in enumerate(encoder_widths)
  ]

  # Decoder: each block doubles H and W. The two deepest blocks use dropout.
  decoder_spec = [
    (512, True),
    (512, True),
    (512, False),
    (256, False),
    (128, False),
    (64, False),
  ]
  up_stack = [
    upsample(width // factor, 3, apply_dropout=use_dropout)
    for width, use_dropout in decoder_spec
  ]

  # Final upsampling step back to input resolution, two softmax channels.
  last = tf.keras.layers.Conv2DTranspose(
      2, 3, strides=2, padding='same',
      kernel_initializer=tf.random_normal_initializer(0., 0.02),
      activation='softmax')

  # Run the encoder and remember every intermediate activation.
  x = inputs
  encoder_outputs = []
  for down in down_stack:
    x = down(x)
    encoder_outputs.append(x)

  # The bottleneck itself is not a skip; pair the rest deepest-first.
  skips = encoder_outputs[-2::-1]

  # Run the decoder, concatenating the matching encoder activation each time.
  for up, skip in zip(up_stack, skips):
    x = up(x)
    x = tf.keras.layers.Concatenate()([x, skip])

  x = last(x)

  return tf.keras.Model(inputs=inputs, outputs=x)

def generator_loss(disc_generated_output, gen_output, target):
  """Build the generator's loss terms for one batch.

  The adversarial term is always included. The reconstruction terms are
  enabled by the module-level USE_*_LOSS flags and weighted by ALPHA (L1),
  BETA (L2), GAMMA (Dice) and DELTA (binary cross-entropy).

  Args:
    disc_generated_output: discriminator output for the generated masks,
      scored against all-ones labels with the module-level `loss_object`.
    gen_output: generator prediction, compared element-wise with `target`.
    target: ground-truth tensor with the same shape as `gen_output`.

  Returns:
    dict with 'gan_loss', one '<name>_loss' entry per enabled term (unweighted)
    and 'total_gen_loss' (the weighted sum that is optimised).
  """
  total_gen_loss = 0
  gen_loss_dict = {}

  # Adversarial term: the generator wants the discriminator to answer "real".
  gan_loss = loss_object(tf.ones_like(disc_generated_output), disc_generated_output)
  total_gen_loss += gan_loss
  gen_loss_dict['gan_loss'] = gan_loss

  if USE_L1_LOSS:
    l1_loss = tf.reduce_mean(tf.abs(target - gen_output))
    total_gen_loss += ALPHA * l1_loss
    gen_loss_dict['l1_loss'] = l1_loss
  if USE_L2_LOSS:
    l2_loss = tf.reduce_mean(tf.square(target - gen_output))
    total_gen_loss += BETA * l2_loss
    gen_loss_dict['l2_loss'] = l2_loss
  if USE_DICE_LOSS:
    # Dice coefficient with squared terms in the denominator:
    #   D = 2 * sum(t * p) / (sum(t^2) + sum(p^2))
    # The factor 2 makes a perfect prediction score D = 1 (loss 0); without it
    # the coefficient tops out at 0.5 and the loss can never reach 0.
    # The sums run over axes 1 and 2; 1e-6 keeps an all-zero pair at D = 1.
    numerator = 2 * tf.reduce_sum(target * gen_output, axis=[1,2]) + 1e-6
    denominator = tf.reduce_sum(tf.square(target) + tf.square(gen_output), axis=[1,2]) + 1e-6
    dice_loss = 1 - tf.reduce_mean(numerator / denominator)
    total_gen_loss += GAMMA * dice_loss
    gen_loss_dict['dice_loss'] = dice_loss
  if USE_BCE_LOSS:
    # gen_output is used as probabilities here (from_logits is left at its default).
    bce = tf.keras.losses.BinaryCrossentropy()
    bce_loss = bce(target, gen_output)
    total_gen_loss += DELTA * bce_loss
    gen_loss_dict['bce_loss'] = bce_loss

  gen_loss_dict['total_gen_loss'] = total_gen_loss

  return gen_loss_dict

def Discriminator(factor=1):
  """Build the convolutional discriminator.

  The input image and a 2-channel mask are concatenated channel-wise and passed
  through three stride-2 blocks and two padded 3x3 stride-1 convolutions. For
  H and W divisible by 8 the result is a (bs, H/8, W/8, 1) map of raw scores;
  the last layer has no activation.

  Args:
    factor: divisor applied to every filter count (64..512) to shrink the net.

  Returns:
    A `tf.keras.Model` taking `[input_image, target_image]`.
  """
  initializer = tf.random_normal_initializer(0., 0.02)

  inp = tf.keras.layers.Input(shape=[None, None, 3], name='input_image')
  tar = tf.keras.layers.Input(shape=[None, None, 2], name='target_image')

  # (bs, H, W, 3 + 2)
  features = tf.keras.layers.concatenate([inp, tar])

  # Three stride-2 blocks: H -> H/2 -> H/4 -> H/8. First block has no batch norm.
  features = downsample(64//factor, 3, False)(features)
  features = downsample(128//factor, 3)(features)
  features = downsample(256//factor, 3)(features)

  # Pad by 1 and apply a 3x3 convolution without padding: spatial size is kept.
  features = tf.keras.layers.ZeroPadding2D()(features)
  features = tf.keras.layers.Conv2D(512//factor, 3, strides=1,
                                    kernel_initializer=initializer,
                                    use_bias=False)(features)
  features = tf.keras.layers.BatchNormalization()(features)
  features = tf.keras.layers.LeakyReLU()(features)

  # Same pad + conv pattern again, reducing to a single score channel.
  features = tf.keras.layers.ZeroPadding2D()(features)
  scores = tf.keras.layers.Conv2D(1, 3, strides=1,
                                  kernel_initializer=initializer)(features)

  return tf.keras.Model(inputs=[inp, tar], outputs=scores)

def discriminator_loss(disc_real_output, disc_generated_output):
  """Score the discriminator on a real and a generated batch.

  Real outputs are compared with all-ones labels and generated outputs with
  all-zeros labels, both through the module-level `loss_object`.

  Returns:
    dict with 'real_loss', 'generated_loss' and their sum 'total_disc_loss'.
  """
  real_labels = tf.ones_like(disc_real_output)
  fake_labels = tf.zeros_like(disc_generated_output)

  real_loss = loss_object(real_labels, disc_real_output)
  generated_loss = loss_object(fake_labels, disc_generated_output)

  return {
    'real_loss': real_loss,
    'generated_loss': generated_loss,
    'total_disc_loss': real_loss + generated_loss,
  }

def _segmentation_metrics(pred, truth):
  """Compute binary segmentation metrics from two class-id masks.

  Args:
    pred: array-like of predicted class ids; non-zero counts as positive.
    truth: array-like of ground-truth class ids with the same shape.

  Returns:
    (accuracy, precision, recall, f1, iou) as Python floats. Any ratio whose
    denominator is zero (e.g. precision when nothing is predicted positive) is
    reported as 0.0 instead of NaN.
  """
  pred = np.asarray(pred).astype(bool)
  truth = np.asarray(truth).astype(bool)

  # Confusion counts. A false positive is "predicted positive, actually
  # negative"; a false negative is "predicted negative, actually positive".
  tp = int(np.count_nonzero(pred & truth))
  tn = int(np.count_nonzero(~pred & ~truth))
  fp = int(np.count_nonzero(pred & ~truth))
  fn = int(np.count_nonzero(~pred & truth))

  def ratio(numerator, denominator):
    return numerator / denominator if denominator else 0.0

  accuracy = ratio(tp + tn, tp + tn + fp + fn)
  precision = ratio(tp, tp + fp)
  recall = ratio(tp, tp + fn)
  f1 = ratio(2 * precision * recall, precision + recall)
  iou = ratio(tp, tp + fp + fn)
  return accuracy, precision, recall, f1, iou

def generate_images(model, test_input, tar):
  """Plot input, ground truth and prediction for the first sample of a batch.

  Also prints accuracy, precision, recall, F1 and IoU of that one sample, with
  class 1 treated as the positive class.

  Args:
    model: callable as `model(batch, training=False)` returning per-pixel
      class scores with the classes on the last axis.
    test_input: batch of normalized input images (values in [-1, 1]).
    tar: batch of targets with the classes on the last axis.
  """
  prediction = model(test_input, training=False)

  # Per-pixel class ids of the first sample in the batch.
  pred = np.argmax(np.array(prediction[0]), axis=-1)
  t = np.argmax(np.array(tar[0]), axis=-1)

  accuracy, precision, recall, f1, iou = _segmentation_metrics(pred, t)

  plt.figure(figsize=(15,15))

  title = ['Input Image', 'Ground Truth', 'Predicted Image']

  plt.subplot(1, 3, 1)
  plt.title(title[0])
  # The input is in [-1, 1]; map it to [0, 1] for display.
  plt.imshow(test_input[0] * 0.5 + 0.5)
  plt.axis('off')

  for i, mask in enumerate([t, pred], start=1):
    plt.subplot(1, 3, i+1)
    plt.title(title[i])
    # Masks are 0/1 class ids; fix the colour range so that an all-background
    # mask renders as background instead of being auto-scaled.
    plt.imshow(mask, vmin=0, vmax=1)
    plt.axis('off')
  plt.show()
  print(f'For this example:\nAccuracy: {accuracy}\nPrecision: {precision}\nRecall: {recall}\nF1 Score: {f1}\nIoU: {iou}')

def get_miou(model, dataset, dataset_length):
  """Run `model` over a dataset and report binary segmentation metrics.

  Despite the name, the reported IoU is the IoU of class 1 only, not a mean
  over classes.

  Args:
    model: callable as `model(batch, training=False)` returning per-pixel
      class scores with the classes on the last axis.
    dataset: object whose `take(n)` yields `(input_batch, target_batch)` pairs.
    dataset_length: number of batches to evaluate.

  Returns:
    (preds, ts, accuracy, precision, recall, f1, iou) where `preds` and `ts`
    are uint8 arrays of predicted / ground-truth class ids for every evaluated
    sample, concatenated along axis 0.

  Raises:
    ValueError: if the dataset yields no batch.
  """
  pred_batches = []
  truth_batches = []
  for inp, tar in dataset.take(dataset_length):
    prediction = model(inp, training=False)
    pred_batches.append(np.array(prediction).argmax(axis=-1).astype(np.uint8))
    truth_batches.append(np.array(tar).argmax(axis=-1).astype(np.uint8))

  if not pred_batches:
    raise ValueError('get_miou received an empty dataset; nothing to evaluate.')

  # Concatenate once at the end instead of re-copying the arrays per batch.
  preds = np.concatenate(pred_batches, axis = 0)
  ts = np.concatenate(truth_batches, axis = 0)

  accuracy, precision, recall, f1, iou = _segmentation_metrics(preds, ts)
  print(f'Accuracy: {accuracy}\nPrecision: {precision}\nRecall: {recall}\nF1 Score: {f1}\nIoU: {iou}')
  return preds, ts, accuracy, precision, recall, f1, iou

In [0]:
@tf.function()
def random_jitter(input_image, real_image):
  """Augment an (image, label) pair with a random crop and a random mirror.

  The pair is first enlarged to 1.125x the crop size, then a random
  IMG_HEIGHT x IMG_WIDTH window is cut out, and finally both tensors are
  flipped left-right with probability 0.5.

  Returns:
    (input_image, real_image) after augmentation.
  """
  # resize() expects (height, width) in that order.
  input_image, real_image = resize(input_image, real_image,
                                   int(1.125*IMG_HEIGHT), int(1.125*IMG_WIDTH))

  input_image, real_image = random_crop(input_image, real_image)

  if tf.random.uniform(()) > 0.5:
    # random mirroring, applied to both tensors so they stay aligned
    input_image = tf.image.flip_left_right(input_image)
    real_image = tf.image.flip_left_right(real_image)

  return input_image, real_image

In [0]:
# Load one fixed training sample (file 000061) to eyeball the data pipeline.
# `inp` and `re` are module-level and are reused by later cells.
if TRAIN_DL1:
  inp, re = load(PATH+'_DL1/train/000061.png')
else:
  inp, re = load(PATH+'/train/000061.png')
# `inp` is a float image in the 0-255 range; divide by 255 so matplotlib can show it
fig = plt.figure(figsize=(4,4))
plt.subplot(121)
plt.imshow(inp/255.0)
plt.axis('off')
plt.tight_layout()
plt.subplot(122)
# argmax over the two label channels gives the per-pixel class id
plt.imshow(tf.argmax(re,-1)/255)
plt.axis('off')
plt.tight_layout()
# Show four independent random_jitter results, label overlaid on the image.
fig = plt.figure(figsize=(4, 4))
for i in range(4):
  rj_inp, rj_re = random_jitter(inp, re)
  plt.subplot(2, 2, i+1)
  plt.imshow(rj_inp/255.0)
  plt.imshow(tf.argmax(rj_re,-1)/255, alpha=0.5)
  plt.axis('off')
fig.tight_layout()

In [0]:
# Create TF dataset
# Folder holding the curated PNG pairs for the selected training mode.
dataset_root = (PATH + '_DL1') if TRAIN_DL1 else PATH
train_pattern = dataset_root + '/train/*.png'
val_pattern = dataset_root + '/val/*.png'

# Training pipeline: load (with optional jitter) -> shuffle -> batch.
train_dataset = (
  tf.data.Dataset.list_files(train_pattern)
  .map(load_image_train, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  .shuffle(BUFFER_SIZE)
  .batch(BATCH_SIZE)
)
# The number of batches is obtained by iterating over the dataset once.
train_dataset_length = sum(1 for _ in train_dataset)
print(f'train_dataset_length: {train_dataset_length}')

# Validation pipeline: load -> batch, no augmentation and no shuffle buffer.
val_dataset = (
  tf.data.Dataset.list_files(val_pattern)
  .map(load_image_test)
  .batch(BATCH_SIZE)
)
val_dataset_length = sum(1 for _ in val_dataset)
print(f'val_dataset_length: {val_dataset_length}')

# "Test" pipeline: the same validation files, served one sample per batch.
test_dataset = (
  tf.data.Dataset.list_files(val_pattern)
  .map(load_image_test)
  .batch(1)
)

In [0]:
# Smoke test: one encoder block followed by one decoder block on the sample
# image; the printed shapes show the halving and doubling of height and width.
down_model = downsample(3, 4)
down_result = down_model(inp[tf.newaxis, ...])
print(down_result.shape)

up_model = upsample(3, 4)
up_result = up_model(down_result)
print(up_result.shape)

In [0]:
# Create Generator
generator = Generator(factor=GEN_FACTOR)
# The sample image as a batch of one, scaled from 0-255 to [-1, 1].
sample_batch = (inp[tf.newaxis, ...] / 127.5) - 1
gen_output = generator(sample_batch, training=False)
plt.imshow(tf.argmax(gen_output[0], -1))

# Create Discriminator
discriminator = Discriminator(DISC_FACTOR)
score_map_style = dict(vmin=-20, vmax=20, cmap='RdBu_r')
plt.figure(figsize=(12, 4))

# Left panel: discriminator scores for the (untrained) generator output.
plt.subplot(121)
disc_out = discriminator([sample_batch, gen_output], training=False)
plt.imshow(disc_out[0, ..., -1], **score_map_style)
plt.colorbar()
plt.title('Generated image input')

# Right panel: discriminator scores for the ground-truth label.
plt.subplot(122)
real_batch = re[tf.newaxis, ...] / 255
disc_out = discriminator([sample_batch, real_batch], training=False)
plt.imshow(disc_out[0, ..., -1], **score_map_style)
plt.colorbar()
plt.title('Ground truth input')

In [0]:
# Parameter counts: generator first, then discriminator.
for network in (generator, discriminator):
  print(network.count_params())

In [0]:
# Adversarial loss shared by generator_loss and discriminator_loss; the
# discriminator emits raw scores, hence from_logits=True.
loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)

# One Adam optimizer per network, same settings for both.
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)

# Run identifier: names the checkpoint folder and the TensorBoard sub-folder.
project_name = '4'

# Checkpoints cover both networks and both optimizers; only the newest is kept.
checkpoint_dir = f'./models/{project_name}'
checkpoint = tf.train.Checkpoint(
  generator_optimizer=generator_optimizer,
  discriminator_optimizer=discriminator_optimizer,
  generator=generator,
  discriminator=discriminator,
)
manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=1)

import datetime
log_dir="logs/"

# Every execution of this cell logs to a fresh timestamped folder.
run_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
summary_writer = tf.summary.create_file_writer(f'{log_dir}{project_name}/{run_stamp}')

In [0]:
@tf.function
def train_step(input_image, target, epoch):
  """Run one optimisation step of both networks on a single batch.

  Uses the module-level `generator`, `discriminator` and their optimizers.

  Args:
    input_image: batch of normalized input images.
    target: batch of ground-truth masks.
    epoch: current epoch index (not used inside the step).

  Returns:
    (gen_loss_dict, disc_loss_dict) as returned by `generator_loss` and
    `discriminator_loss` for this batch.
  """
  # Two tapes, because each network is differentiated against its own loss.
  with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
    fake_masks = generator(input_image, training=True)

    real_scores = discriminator([input_image, target], training=True)
    fake_scores = discriminator([input_image, fake_masks], training=True)

    gen_loss_dict = generator_loss(fake_scores, fake_masks, target)
    disc_loss_dict = discriminator_loss(real_scores, fake_scores)

  gen_variables = generator.trainable_variables
  disc_variables = discriminator.trainable_variables

  gen_gradients = gen_tape.gradient(gen_loss_dict['total_gen_loss'], gen_variables)
  disc_gradients = disc_tape.gradient(disc_loss_dict['total_disc_loss'], disc_variables)

  generator_optimizer.apply_gradients(zip(gen_gradients, gen_variables))
  discriminator_optimizer.apply_gradients(zip(disc_gradients, disc_variables))
  return gen_loss_dict, disc_loss_dict

In [0]:
def fit(train_ds, train_ds_length, val_ds, val_ds_length, epochs, test_ds):
  """Train the GAN with per-epoch evaluation, checkpointing and early stopping.

  Each epoch: show the current prediction for one fixed preview example, train
  on every batch of `train_ds`, log the losses of the last batch and the
  train/validation IoU to TensorBoard, save a checkpoint when the validation
  IoU improves, and stop once EARLYSTOPPING_LIMIT epochs pass without
  improvement.

  Relies on the notebook globals `generator`, `manager`, `summary_writer`,
  `EARLYSTOPPING_LIMIT`, `IMG_HEIGHT` and `IMG_WIDTH`.

  Args:
    train_ds: batched training dataset.
    train_ds_length: number of batches in `train_ds`.
    val_ds: batched validation dataset.
    val_ds_length: number of batches in `val_ds`.
    epochs: maximum number of epochs.
    test_ds: dataset of single-sample batches used to pick the preview example.

  Raises:
    ValueError: if `test_ds` yields no example.
  """
  best_val_iou = 0
  epochs_since_best_val_iou = 0

  # Pick a preview example in which the positive class (channel 1) covers at
  # least a quarter of the image. Only channel 1 is summed: the target is
  # one-hot, so a sum over all channels always equals the pixel count and would
  # never filter anything. If none of the first 100 examples qualifies, the
  # last one examined is used.
  example_input = example_target = None
  min_positive_pixels = IMG_HEIGHT*IMG_WIDTH/4
  for example_input, example_target in test_ds.take(100):
    if tf.reduce_sum(example_target[..., 1]) >= min_positive_pixels:
      break
  if example_input is None:
    raise ValueError('test_ds is empty; cannot pick a preview example.')

  for epoch in range(epochs):

    display.clear_output(wait=True)
    generate_images(generator, example_input, example_target)

    # Train. The dicts stay empty (and nothing is logged) if train_ds is empty.
    gen_loss_dict, disc_loss_dict = {}, {}
    for input_image, target in tqdm(train_ds, total = train_ds_length, desc = f'{epoch + 1}'):
      gen_loss_dict, disc_loss_dict = train_step(input_image, target, epoch)

    # Log the losses of the last batch of this epoch.
    with summary_writer.as_default():
      for key,value in gen_loss_dict.items():
        tf.summary.scalar(f'gen_losses/{key}', value, step=epoch)
      for key,value in disc_loss_dict.items():
        tf.summary.scalar(f'disc_losses/{key}', value, step=epoch)

    # Train IoU
    print('Training metrics:')
    *_, train_iou = get_miou(generator, train_ds, train_ds_length)
    # Validate IoU
    print('Validation metrics:')
    *_, val_iou = get_miou(generator, val_ds, val_ds_length)
    with summary_writer.as_default():
      tf.summary.scalar('ious/1_train_iou', train_iou*100, step=epoch)
      tf.summary.scalar('ious/2_val_iou', val_iou*100, step=epoch)
      tf.summary.scalar('ious/3_difference', (train_iou-val_iou)*100, step=epoch)

    if val_iou > best_val_iou:
      print('Best IoU so far, saving checkpoint')
      manager.save()
      best_val_iou = val_iou
      epochs_since_best_val_iou = 0
    else:
      epochs_since_best_val_iou += 1

    if epochs_since_best_val_iou >= EARLYSTOPPING_LIMIT:
      print(f'{epochs_since_best_val_iou} epochs since improvement, stopping.')
      print(f'Last improvement at epoch #{epoch+1-epochs_since_best_val_iou}')
      break

In [0]:
#docs_infra: no_execute
# Open TensorBoard inline on the log directory. It is started before fit() is
# called so the curves can be watched while training runs.
%load_ext tensorboard
%tensorboard --logdir {log_dir}

In [0]:
# Start training. test_dataset only supplies the preview example shown each epoch.
fit(train_dataset, train_dataset_length, val_dataset, val_dataset_length, EPOCHS, test_dataset)

### Evaluation

In [0]:
# Mount Google Drive at /content/drive (Colab only; prompts for authorisation).
from google.colab import drive
drive.mount('/content/drive')

In [0]:
# Bundle the TensorBoard logs and the saved checkpoints into one archive.
!zip -r vary_factors_DL1.zip logs models

In [0]:
# Load the newest checkpoint found in checkpoint_dir back into the generator,
# the discriminator and both optimizers.
latest_checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)
checkpoint.restore(latest_checkpoint_path)

In [0]:
# Final evaluation of the restored generator on every split.
# The configuration line uses the constants defined in the form cells above:
# the network width factors and the weight of each enabled generator loss.
enabled_loss_weights = {
  loss_name: weight
  for loss_name, enabled, weight in (
    ('l1', USE_L1_LOSS, ALPHA),
    ('l2', USE_L2_LOSS, BETA),
    ('dice', USE_DICE_LOSS, GAMMA),
    ('bce', USE_BCE_LOSS, DELTA),
  )
  if enabled
}
print(f'gen_factor: {GEN_FACTOR}, disc_factor: {DISC_FACTOR}, loss weights: {enabled_loss_weights}')

# test_dataset is batched one sample at a time and has no precomputed length,
# so count its batches here.
test_dataset_length = sum(1 for _ in test_dataset)

# get_miou prints the metrics itself; its return value (which includes the full
# prediction arrays) is not needed here and is discarded.
for split_title, split_dataset, split_length in (
  ('Training', train_dataset, train_dataset_length),
  ('Validating', val_dataset, val_dataset_length),
  ('Testing', test_dataset, test_dataset_length),
):
  print(split_title)
  get_miou(generator, split_dataset, split_length)