In [None]:
import tensorflow as tf
import os
import datetime
import time
from pathlib import Path
import matplotlib.pyplot as plt

In [None]:
dataset_name = "maps"
path_to_zip = tf.keras.utils.get_file(
    fname=f'{dataset_name}.tar.gz',
    origin=f'http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/{dataset_name}.tar.gz',
    extract=True
)
path_to_zip = Path(path_to_zip)

# The extraction location of get_file(extract=True) differs between Keras
# versions: older releases unpack next to the archive, newer ones unpack into
# a "<name>_extracted" directory and may return that directory instead of the
# archive path. NOTE(review): probe every known layout rather than assuming one.
_candidate_dirs = [
    path_to_zip.parent / dataset_name,
    path_to_zip / dataset_name,
    path_to_zip.parent / f"{dataset_name}_extracted" / dataset_name,
]
path = next((d for d in _candidate_dirs if d.is_dir()), None)
if path is None:
    # Fail here instead of later with an empty Dataset.list_files glob.
    raise FileNotFoundError(
        f"Could not find the extracted '{dataset_name}' dataset; looked in: "
        + ", ".join(str(d) for d in _candidate_dirs)
    )

In [None]:
def load(image_file):
  """Read a side-by-side pix2pix JPEG and split it into its two halves.

  Args:
    image_file: path (string tensor) to a JPEG whose left half is the input
      image and whose right half is the target image.

  Returns:
    (input_image, output_image): float32 tensors holding the left and right
    halves, with pixel values in [0, 255].
  """
  raw_bytes = tf.io.read_file(image_file)
  # channels=3 forces an RGB decode: the channel dimension becomes static and
  # grayscale JPEGs cannot slip through, since random_crop and the models
  # assume exactly 3 channels.
  image = tf.image.decode_jpeg(raw_bytes, channels=3)
  half_width = tf.shape(image)[1] // 2
  input_image = tf.cast(image[:, :half_width, :], dtype=tf.float32)
  output_image = tf.cast(image[:, half_width:, :], dtype=tf.float32)
  return input_image, output_image

In [None]:
# Shuffle buffer (in elements) for the tf.data pipelines.
BUFFER_SIZE = 400
# One image pair per optimisation step.
BATCH_SIZE = 1
# Spatial size fed to the Generator and Discriminator.
IMG_HEIGHT = IMG_WIDTH = 256

In [None]:
def resize(input_img,target_img,height,width):
  """Resize an image pair to (height, width) using nearest-neighbour sampling.

  Returns:
    (resized_input, resized_target) in the same order as the arguments.
  """
  target_size = [height, width]
  nearest = tf.image.ResizeMethod.NEAREST_NEIGHBOR
  resized_input, resized_target = (
      tf.image.resize(img, target_size, method=nearest)
      for img in (input_img, target_img)
  )
  return resized_input, resized_target

In [None]:
def random_crop(input_img,target_img):
  """Cut the same random IMG_HEIGHT x IMG_WIDTH window out of both images.

  The pair is stacked into one tensor first, so a single crop offset is
  sampled and applied to both. This keeps input and target spatially aligned.
  Both images are assumed to have 3 channels.
  """
  pair = tf.stack([input_img, target_img], axis=0)
  crop_shape = [2, IMG_HEIGHT, IMG_WIDTH, 3]
  cropped_pair = tf.image.random_crop(pair, size=crop_shape)
  cropped_input, cropped_target = cropped_pair[0], cropped_pair[1]
  return cropped_input, cropped_target

In [None]:
def normalize(input_img,target_img):
  """Rescale both images from [0, 255] to [-1, 1], matching the Generator's tanh output."""
  def to_signed_unit(img):
    return img / 127.5 - 1
  return to_signed_unit(input_img), to_signed_unit(target_img)

In [None]:
def random_jitter(input_img,real_img,jitter_size=286):
  """Random-jitter augmentation applied identically to an aligned image pair.

  The pair is first upscaled to jitter_size x jitter_size. random_crop then
  takes the same random IMG_HEIGHT x IMG_WIDTH crop from both images. Finally
  both are mirrored left-right with probability 0.5.

  Args:
    input_img: conditioning image tensor.
    real_img: target image tensor aligned with input_img.
    jitter_size: side length to resize to before cropping. It should be at
      least IMG_HEIGHT / IMG_WIDTH. Defaults to 286, the value previously
      hard-coded here.

  Returns:
    (input_img, real_img) after augmentation.
  """
  input_img, real_img = resize(input_img, real_img, jitter_size, jitter_size)
  input_img, real_img = random_crop(input_img, real_img)
  # A single random draw decides the flip for both images so they stay aligned.
  if tf.random.uniform(()) > 0.5:
    input_img = tf.image.flip_left_right(input_img)
    real_img = tf.image.flip_left_right(real_img)
  return input_img, real_img

In [None]:
def load_train_dataset(ds):
  """Dataset.map function for the training split.

  Decodes the file at path `ds`, applies random jitter, and scales to [-1, 1].

  Returns:
    (input_image, real_image) tuple.
  """
  image_pair = load(ds)
  image_pair = random_jitter(*image_pair)
  return normalize(*image_pair)

In [None]:
def load_test_dataset(ds):
  """Dataset.map function for the validation split.

  Decodes the file at path `ds`, deterministically resizes to
  IMG_HEIGHT x IMG_WIDTH (no augmentation), and scales to [-1, 1].

  Returns:
    (input_image, real_image) tuple.
  """
  image_pair = load(ds)
  resized_pair = resize(*image_pair, IMG_HEIGHT, IMG_WIDTH)
  return normalize(*resized_pair)

In [None]:
# Training input pipeline.
# - Shuffle the file paths *before* decoding. The shuffle buffer then holds
#   BUFFER_SIZE strings instead of BUFFER_SIZE decoded 256x256x3 float32
#   images (~315 MB).
# - Decode, jitter and normalise in parallel, then batch.
# - Prefetch so input processing overlaps with the training step.
train_ds = (
    tf.data.Dataset.list_files(str(path/"train/*.jpg"))
    .shuffle(BUFFER_SIZE)
    .map(load_train_dataset, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

In [None]:
# Validation input pipeline: deterministic resize + normalise (no jitter).
# The shuffle only changes which pair is drawn first (e.g. as a preview example).
test_ds = (
    tf.data.Dataset.list_files(str(path/"val/*.jpg"))
    .map(load_test_dataset, num_parallel_calls=tf.data.AUTOTUNE)
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
)

In [None]:
def downsample(filters,size,apply_batchnorm=True):
  """Encoder block: Conv2D(stride 2, 'same') -> optional BatchNorm -> LeakyReLU.

  Halves the spatial resolution. The convolution has no bias and its kernel
  is initialised from N(0, 0.02).

  Args:
    filters: number of output channels.
    size: convolution kernel size.
    apply_batchnorm: insert BatchNormalization after the convolution.

  Returns:
    A tf.keras.Sequential block.
  """
  block_layers = [
      tf.keras.layers.Conv2D(
          filters, size, strides=2, padding='same',
          kernel_initializer=tf.random_normal_initializer(0., 0.02),
          use_bias=False),
  ]
  if apply_batchnorm:
    block_layers.append(tf.keras.layers.BatchNormalization())
  block_layers.append(tf.keras.layers.LeakyReLU())
  return tf.keras.Sequential(block_layers)

In [None]:
def upsample(filters,size,apply_dropout=False):
  """Decoder block: Conv2DTranspose(stride 2, 'same') -> BatchNorm -> optional Dropout(0.5) -> ReLU.

  Doubles the spatial resolution. The transposed convolution has no bias and
  its kernel is initialised from N(0, 0.02).

  Args:
    filters: number of output channels.
    size: kernel size.
    apply_dropout: insert Dropout(0.5) after the BatchNorm.

  Returns:
    A tf.keras.Sequential block.
  """
  block_layers = [
      tf.keras.layers.Conv2DTranspose(
          filters, size, strides=2, padding='same',
          kernel_initializer=tf.random_normal_initializer(0., 0.02),
          use_bias=False),
      tf.keras.layers.BatchNormalization(),
  ]
  if apply_dropout:
    block_layers.append(tf.keras.layers.Dropout(0.5))
  block_layers.append(tf.keras.layers.ReLU())
  return tf.keras.Sequential(block_layers)

In [None]:
def Generator():
  """Build the pix2pix U-Net generator.

  Encoder: eight stride-2 downsample blocks take a 256x256x3 image to 1x1x512.

  Decoder: seven upsample blocks bring it back to 128x128. After each one,
  the output is concatenated with the mirror-image encoder activation (skip
  connection). The deepest encoder output, the bottleneck, is not reused as
  a skip.

  Output: a stride-2 transposed conv with tanh produces a 256x256x3 image in
  [-1, 1].

  Returns:
    tf.keras.Model mapping a (batch, 256, 256, 3) input to a (batch, 256, 256, 3) output.
  """
  inputs = tf.keras.Input(shape=[256,256,3])

  encoder_blocks = [downsample(64, 4, apply_batchnorm=False)]
  encoder_blocks.extend(downsample(n, 4) for n in (128, 256, 512, 512, 512, 512, 512))

  # The three innermost decoder blocks use dropout.
  decoder_blocks = [upsample(512, 4, apply_dropout=True) for _ in range(3)]
  decoder_blocks.extend(upsample(n, 4) for n in (512, 256, 128, 64))

  features = inputs
  encoder_outputs = []
  for block in encoder_blocks:
    features = block(features)
    encoder_outputs.append(features)

  # Skips run from the second-deepest encoder output back to the first.
  for block, skip in zip(decoder_blocks, encoder_outputs[-2::-1]):
    features = block(features)
    features = tf.keras.layers.Concatenate()([features, skip])

  output_layer = tf.keras.layers.Conv2DTranspose(
      3, 4, strides=2,
      kernel_initializer=tf.random_normal_initializer(0., 0.02),
      use_bias=False, padding='same', activation='tanh')
  outputs = output_layer(features)
  return tf.keras.Model(inputs=inputs, outputs=outputs)

In [None]:
# Instantiate the U-Net generator: 256x256x3 input -> 256x256x3 tanh output.
generator = Generator()

In [None]:
# Weight of the L1 reconstruction term relative to the adversarial term in
# the generator loss.
LAMBDA = 100
# from_logits=True: the sigmoid is applied inside the loss, so the inputs are
# expected to be raw logits (no activation on the discriminator's last layer).
loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=True)

In [None]:
def gen_loss_func(disc_generated_output,gen_output,target):
  """Generator objective: adversarial loss + LAMBDA * mean absolute (L1) error.

  Args:
    disc_generated_output: discriminator logits for (input, generated) pairs.
      The adversarial term pushes these towards the "real" label (ones).
    gen_output: generator output images.
    target: ground-truth images.

  Returns:
    Scalar total generator loss.
  """
  adversarial_term = loss_obj(tf.ones_like(disc_generated_output), disc_generated_output)
  reconstruction_term = tf.reduce_mean(tf.abs(target - gen_output))
  return adversarial_term + LAMBDA * reconstruction_term

In [None]:
def Discriminator():
  """Build the conditional PatchGAN discriminator.

  The conditioning input image and a (real or generated) target image are
  concatenated on the channel axis. The model returns a 30x30x1 map of
  unnormalised logits, one per receptive-field patch. These are meant for a
  from_logits=True loss.

  Returns:
    tf.keras.Model taking [inp, tar] (each (batch, 256, 256, 3)) and returning
    (batch, 30, 30, 1) logits.
  """
  # Same N(0, 0.02) initialisation as downsample/upsample/Generator. The
  # default tf.random_normal_initializer() would use stddev 0.05.
  initializer = tf.random_normal_initializer(0., 0.02)

  inp = tf.keras.layers.Input(shape=[256,256,3],name='inp')
  tar = tf.keras.layers.Input(shape=[256,256,3],name='tar')
  x = tf.keras.layers.concatenate([inp,tar])            # (bs, 256, 256, 6)

  down1 = downsample(64,4,apply_batchnorm=False)(x)     # (bs, 128, 128, 64)
  down2 = downsample(128,4)(down1)                      # (bs, 64, 64, 128)
  down3 = downsample(256,4)(down2)                      # (bs, 32, 32, 256)

  # Explicit zero padding followed by 'valid' convolutions. Combining
  # ZeroPadding2D with padding='same' padded the feature map twice.
  zero_padding_2D_1 = tf.keras.layers.ZeroPadding2D()(down3)   # (bs, 34, 34, 256)
  # No activation here: this conv feeds BatchNorm + LeakyReLU. A tanh at this
  # point would squash the features before normalisation.
  conv = tf.keras.layers.Conv2D(512,4,strides=1,kernel_initializer=initializer,use_bias=False)(zero_padding_2D_1)  # (bs, 31, 31, 512)
  batch_norm_1 = tf.keras.layers.BatchNormalization()(conv)
  relu = tf.keras.layers.LeakyReLU()(batch_norm_1)
  zero_padding_2D_2 = tf.keras.layers.ZeroPadding2D()(relu)    # (bs, 33, 33, 512)
  # The bias is kept on the final logit layer: no normalisation follows it,
  # so it is the only way to shift the logits.
  last = tf.keras.layers.Conv2D(1,4,strides=1,kernel_initializer=initializer)(zero_padding_2D_2)  # (bs, 30, 30, 1)
  return tf.keras.Model(inputs=[inp,tar],outputs=last)

In [None]:
# Instantiate the conditional discriminator. It is called as
# discriminator([input_image, target_image]) and returns a map of logits.
discriminator = Discriminator()

In [None]:
def disc_loss_func(real_loss,fake_loss):
  """Discriminator objective: score real pairs as 1 and generated pairs as 0.

  Args:
    real_loss: discriminator logits for (input, real target) pairs. Despite
      the name, these are raw discriminator outputs, not losses.
    fake_loss: discriminator logits for (input, generated image) pairs.
      Also raw outputs, not losses.

  Returns:
    Scalar sum of the real and fake binary cross-entropy terms.
  """
  real_term = loss_obj(tf.ones_like(real_loss), real_loss)
  fake_term = loss_obj(tf.zeros_like(fake_loss), fake_loss)
  return real_term + fake_term

In [None]:
# pix2pix (Isola et al., 2017) trains both networks with Adam at lr=2e-4 and
# beta_1=0.5. beta_1 already followed the paper, but the learning rate was 5x
# higher (1e-3), which risks destabilising adversarial training.
LEARNING_RATE = 2e-4
gen_optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE,beta_1=0.5)
disc_optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE,beta_1=0.5)

In [None]:
def generate_images(model,test_input,tar):
  """Show the input, the model's prediction, and the ground truth side by side.

  Only the first element of each batch is plotted. Images are mapped from
  [-1, 1] back to [0, 1] for imshow. training=True keeps Dropout active and
  BatchNorm in training mode during this preview.

  Args:
    model: the generator.
    test_input: batch of conditioning images.
    tar: batch of ground-truth images.
  """
  prediction = model(test_input, training=True)
  panels = (
      ('Input Image', test_input[0]),
      ('Predicted Image', prediction[0]),
      ('Ground Truth', tar[0]),
  )
  plt.figure(figsize=(15,15))
  for position, (caption, image) in enumerate(panels, start=1):
    plt.subplot(1, 3, position)
    plt.title(caption)
    plt.imshow(image*0.5 + 0.5)
    plt.axis('off')
  plt.show()
  

In [None]:
@tf.function
def train_step(image,target):
  """Run one pix2pix optimisation step for both networks on a batch.

  Forward passes and losses are recorded on two GradientTapes. Gradients are
  taken and applied only after the tapes close; calling tape.gradient inside
  the `with` block made gen_tape also record the discriminator's gradient
  computation. Compiled with tf.function, so each new input shape triggers a
  retrace.

  Args:
    image: batch of conditioning (input) images.
    target: batch of ground-truth images.
  """
  with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
    gen_out = generator(image, training=True)
    disc_gen = discriminator([image, gen_out], training=True)
    disc_true = discriminator([image, target], training=True)
    gen_loss = gen_loss_func(disc_gen, gen_out, target)
    disc_loss = disc_loss_func(disc_true, disc_gen)

  disc_gradient = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
  gen_gradient = gen_tape.gradient(gen_loss, generator.trainable_variables)
  disc_optimizer.apply_gradients(zip(disc_gradient, discriminator.trainable_variables))
  gen_optimizer.apply_gradients(zip(gen_gradient, generator.trainable_variables))

In [None]:
def fit(train_ds,test_ds,steps):
  """Train for `steps` batches, previewing the generator every 1000 steps.

  Args:
    train_ds: batched dataset of (input_image, target) pairs. It is repeated
      as needed to reach `steps`.
    test_ds: dataset whose first batch is used as a fixed preview example.
    steps: total number of training steps.
  """
  example_inp, example_target = next(iter(test_ds.take(1)))
  # Start the clock once, before the loop. Resetting it on every iteration
  # made the "1000 steps" timing report only the last step.
  start = time.time()
  for step, (input_image, target) in train_ds.repeat().take(steps).enumerate():
    step = int(step)  # enumerate() yields a scalar tensor; use a Python int
    if step % 1000 == 0:
      if step != 0:
        print(f'Time taken for 1000 steps: {time.time()-start}')
      start = time.time()
      generate_images(generator, example_inp, example_target)
      print(f'Step: {step//1000}')
    train_step(input_image, target)
    if (step+1) % 10 == 0:
      print('.', end='', flush=True)


In [None]:
# Long-running: 40k single-image steps, with a preview every 1000.
fit(train_ds, test_ds, steps=40000)

In [None]:
# Checkpoint both models and both optimizers (the optimizers carry Adam's
# state), so training can be resumed or the models reloaded later.
# NOTE(review): this cell runs after fit(), so no checkpoint is written during
# training. Consider creating it before training and saving periodically.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir,"ckpt")
checkpoint = tf.train.Checkpoint(
    generator_optimizer=gen_optimizer,
    discriminator_optimizer=disc_optimizer,
    generator=generator,
    discriminator=discriminator
)

In [None]:
# Save the current weights and optimizer state under checkpoint_prefix
# (TensorFlow appends a save counter to the prefix).
checkpoint.save(file_prefix=checkpoint_prefix)

In [None]:
# Load the most recent checkpoint found in checkpoint_dir into the tracked
# models and optimizers.
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))