In [None]:
# Colab-only setup: mount Google Drive so the notebook can read and write under /content/gdrive.
from google.colab import drive

drive.mount('/content/gdrive')
# Shell escape: print the GPU(s) visible to this runtime, if any.
!nvidia-smi

In [None]:
import tensorflow as tf

from tensorflow.keras.layers import Dense, Conv2D, Dropout, BatchNormalization, Input, Reshape, Flatten, Conv2DTranspose, MaxPooling2D, UpSampling2D, Add
from tensorflow.keras.layers import LeakyReLU, Lambda, ReLU, Concatenate
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint

from sklearn.model_selection import train_test_split

from functools import partial
import re
import numpy as np
import math, os
import seaborn as sns
from matplotlib import pyplot as plt
from datetime import datetime

from glob import glob

import pandas as pd


print(tf.__version__)

# Detecting TPU

Detect a TPU and use it if one is available.
Please note that using a TPU requires the TFRecord data to be stored in Google Cloud Storage.

In [None]:
# Try to locate a TPU. A ValueError raised here is treated as "no TPU available".
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver() 
    print('Running on TPU ', tpu.master())
except ValueError:
    tpu = None

if tpu is not None:
    # Connect to the TPU cluster and initialise it before building the strategy.
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
else:
    # No TPU: use whatever strategy tf.distribute.get_strategy() returns for the current context.
    strategy = tf.distribute.get_strategy() 

# The replica count is reused in `params` to scale the batch size and the peak learning rate.
print("REPLICAS: ", strategy.num_replicas_in_sync)

# Shorthand passed as num_parallel_calls / prefetch buffer size in the tf.data pipelines.
AUTO = tf.data.experimental.AUTOTUNE

# Training parameters

In [None]:
# Central configuration for the run. The whole dict is dumped, key by key, into
# config_used.conf next to every trained model, so key order is part of that file.
params = {
  # Weights of the two loss terms; passed to MSE_Perceptual as lambda_mse / lambda_perc.
  'MSE_LOSS_LAMBDA': 0.01,  
  'PERP_LOSS_LAMBDA': 1,
  #'PERP_LOSS_LAMBDA': 1e-4,
  # Indices into VGG19's `layers` list whose outputs are compared by the perceptual loss.
  'perceptual_layers' : [5,8,13,18],
    
  # Switch for the random-cutout augmentation of the training inputs.
  # NOTE(review): confirm that get_train_dataset actually consults this flag.
  'USE_CUTOUT' : True,

  # Number of epochs passed to fit(); also the horizon of the learning-rate schedule.
  'EPOCHS' : 1000,

  # 16 per replica when running on more than one replica, otherwise 64.
  'BATCH_SIZE' : 16 * strategy.num_replicas_in_sync if strategy.num_replicas_in_sync > 1 else 64,
  # Shape decoded TFRecord images are reshaped to: IMAGE_SIZE plus 20 px in height and width.
  'RECORD_SIZE' : (256+20,256+20,3),
  # Crop size fed to the autoencoder (also its input shape).
  'IMAGE_SIZE' : (256,256,3),
  # Number of units of the bottleneck Dense layer.
  'LATENT_DIM' : 500,

  # Learning-rate schedule parameters, consumed by lrfn().
  'LR_START' : 1e-5,
  'LR_MIN' : 1e-5,
  # The peak learning rate is scaled by the number of replicas.
  'LR_MAX' : 9e-4 * strategy.num_replicas_in_sync,
  'LR_RAMPUP_EPOCHS' : 3,
  'LR_SUSTAIN_EPOCHS' : 0,
  # lrfn() evaluates cos(pi * N_CYCLES * 2 * progress); 0.5 sweeps the argument from 0 to pi.
  'N_CYCLES' : .5,

  # Seed used for the dataset shuffle buffers.
  'SEED' : 100,

  # Batch normalisation switches: encoder conv / dense blocks, and transposed-conv blocks.
  'use_bn' : True,
  'use_bn_dconv' : True,
  # Negative slope of the LeakyReLU activations.
  'lrelu_slop' : 0.2,

  # True: read TFRecords from Google Cloud Storage; False: read PNG files from Google Drive.
  'USE_TFRECORD' : True
}

In [None]:
# (grid_x, grid_y) pairs to train on; one autoencoder is trained per pair.
# `grid_list` is kept as a second name bound to the same list.
grids = grid_list = [ [3,1], [4,1], [2,2], [3,2], [2,3], [4,3] ]

# Root folder (on the mounted Google Drive) under which every training run writes its outputs.
BASE_OUTPUT = '/content/gdrive/MyDrive/pcbs_cae/'

# The dataset is available in the Google Cloud Storage bucket below.
# Note that the data must be stored in GCS in order to use TPUs.
if params['USE_TFRECORD']:
  BASE_DATA = 'gs://mpi_pcb/'
else:
  BASE_DATA = '/content/gdrive/MyDrive/mpi_pcbs/'

# exist_ok=True keeps the cell idempotent and avoids the check-then-create race.
os.makedirs(BASE_OUTPUT, exist_ok=True)

# Cutout function

In [None]:
def random_cutout(image, label, height, width, channels=3, min_mask_size=(10, 10), max_mask_size=(80, 80), k=1):
  """Zero out `k` randomly sized, randomly placed rectangles of `image`.

  Args:
    image: image tensor; it is multiplied by a (height, width, channels) mask,
      so it is expected to have that shape.
    label: passed through unchanged.
    height: height of the mask (and of `image`).
    width: width of the mask (and of `image`).
    channels: number of channels of the mask.
    min_mask_size: (min_height, min_width) of a rectangle, inclusive.
    max_mask_size: (max_height, max_width) of a rectangle, exclusive, because
      tf.random.uniform never returns `maxval`.
    k: number of rectangles to cut. The caller in this notebook passes a scalar
      int32 tensor from inside Dataset.map.

  Returns:
    A tuple (image, label) where `image` is float32 with the rectangles set to 0.
  """
  assert height > min_mask_size[0]
  assert width > min_mask_size[1]
  assert height > max_mask_size[0]
  assert width > max_mask_size[1]

  # Cast once, before the loop, so `image` has the same dtype on every
  # iteration (and also when k == 0).
  image = tf.cast(image, tf.float32)

  for i in range(k):
    mask_height = tf.random.uniform(shape=[], minval=min_mask_size[0], maxval=max_mask_size[0], dtype=tf.int32)
    mask_width = tf.random.uniform(shape=[], minval=min_mask_size[1], maxval=max_mask_size[1], dtype=tf.int32)

    # `maxval` is exclusive: use pad + 1 so the rectangle may also sit flush
    # against the bottom / right border (pad_bottom or pad_right equal to 0).
    pad_h = height - mask_height
    pad_top = tf.random.uniform(shape=[], minval=0, maxval=pad_h + 1, dtype=tf.int32)
    pad_bottom = pad_h - pad_top

    pad_w = width - mask_width
    pad_left = tf.random.uniform(shape=[], minval=0, maxval=pad_w + 1, dtype=tf.int32)
    pad_right = pad_w - pad_left

    # Block of zeros surrounded by ones: multiplying keeps everything but the rectangle.
    cutout_area = tf.zeros(shape=[mask_height, mask_width, channels], dtype=tf.float32)
    cutout_mask = tf.pad(cutout_area, [[pad_top, pad_bottom], [pad_left, pad_right], [0, 0]], constant_values=1.0)
    image = tf.multiply(image, cutout_mask)

  return image, label

# Dataset functions

In [None]:
def decode_image(image):
  """Decode JPEG bytes into a float32 tensor in [0, 1] of shape params['RECORD_SIZE']."""
  record_shape = list(params['RECORD_SIZE'])
  pixels = tf.image.decode_jpeg(image, channels=record_shape[2])
  scaled = tf.cast(pixels, tf.float32) / 255.0
  # The reshape pins the static shape of the decoded image.
  return tf.reshape(scaled, record_shape)

def read_tfrecord(example):
  """Parse one serialized example and return its decoded image.

  The "image_name" feature is parsed as well but is not returned.
  """
  feature_spec = {
    "image": tf.io.FixedLenFeature([], tf.string),
    "image_name": tf.io.FixedLenFeature([], tf.string),
  }
  parsed = tf.io.parse_single_example(example, feature_spec)
  return decode_image(parsed['image'])

def read_image(img_path):
  """Load the image file at `img_path` as a 3-channel float32 tensor scaled by 1/255."""
  raw_bytes = tf.io.read_file(img_path)
  decoded = tf.image.decode_image(raw_bytes, channels=3)
  return tf.cast(decoded, tf.float32) / 255.0

def load_dataset(filenames, use_tfrecord=True, ordered=False):
  """Build a dataset of decoded images from TFRecord files or from image paths.

  Args:
    filenames: TFRecord file names (use_tfrecord=True) or image file paths.
    use_tfrecord: selects the TFRecord reader or the plain image reader.
    ordered: when False, non-deterministic element order is allowed. Only the
      TFRecord branch applies this option.

  Returns:
    A tf.data.Dataset yielding one decoded image per element.
  """
  if not use_tfrecord:
    print("Reading from file list!")
    paths = tf.data.Dataset.from_tensor_slices(filenames)
    return paths.map(read_image, num_parallel_calls=AUTO)

  print("Reading from tfrecord!")
  options = tf.data.Options()
  if not ordered:
    options.experimental_deterministic = False

  records = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)
  records = records.with_options(options)
  return records.map(read_tfrecord, num_parallel_calls=AUTO)

def create_label(image):
  """Take a random params['IMAGE_SIZE'] crop and return it twice, as (input, target)."""
  crop = tf.image.random_crop(image, list(params['IMAGE_SIZE']))
  return crop, crop

def get_center_crop(image):
  """Crop the central params['IMAGE_SIZE'] window out of a params['RECORD_SIZE'] image.

  The offsets are derived from the configured sizes instead of being hard-coded;
  with RECORD_SIZE = (276, 276, 3) and IMAGE_SIZE = (256, 256, 3) this is the
  same (10, 10, 256, 256) box as before.

  Args:
    image: an image, or a batch of images (get_test_dataset applies this after
      batching), whose spatial size is assumed to be RECORD_SIZE.

  Returns:
    The cropped tensor.
  """
  target_height, target_width = params['IMAGE_SIZE'][0], params['IMAGE_SIZE'][1]
  offset_height = (params['RECORD_SIZE'][0] - target_height) // 2
  offset_width = (params['RECORD_SIZE'][1] - target_width) // 2
  return tf.image.crop_to_bounding_box(image, offset_height, offset_width, target_height, target_width)

def get_train_dataset(filenames, use_tfrecord=True):
  """Build the infinite, shuffled, batched training pipeline.

  Every element is a (input, target) pair of the same random crop. When
  params['USE_CUTOUT'] is true, between 0 and 14 random rectangles are zeroed
  out of the input only; the target stays intact.

  Args:
    filenames: TFRecord file names or image paths, see load_dataset.
    use_tfrecord: forwarded to load_dataset.

  Returns:
    A repeated tf.data.Dataset of (input_batch, target_batch).
  """
  dataset = load_dataset(filenames, use_tfrecord=use_tfrecord)
  dataset = dataset.map(create_label, num_parallel_calls=AUTO)

  # Honour the configuration switch, which used to be ignored. A missing key
  # keeps the previous behaviour (cutout always applied).
  if params.get('USE_CUTOUT', True):
    def apply_cutout(image, label):
      # A fresh number of rectangles is drawn for every element (maxval is exclusive).
      n_rectangles = tf.random.uniform([], 0, 15, dtype=tf.int32)
      return random_cutout(image, label, params['IMAGE_SIZE'][0], params['IMAGE_SIZE'][1], k=n_rectangles)

    dataset = dataset.map(apply_cutout, num_parallel_calls=AUTO)

  dataset = dataset.repeat()
  dataset = dataset.shuffle(2048, seed=params['SEED'])
  dataset = dataset.batch(params['BATCH_SIZE'])
  dataset = dataset.prefetch(AUTO)

  return dataset

def get_validation_dataset(filenames, use_tfrecord=True):
  """Build the validation pipeline: random crop pairs, batched, then cached and prefetched.

  No cutout and no shuffling are applied here.
  """
  pairs = load_dataset(filenames, use_tfrecord=use_tfrecord).map(create_label, num_parallel_calls=AUTO)
  batches = pairs.batch(params['BATCH_SIZE'])
  return batches.cache().prefetch(AUTO)

def get_test_dataset(filenames, use_tfrecord=True):
  """Build the test pipeline: shuffled with the configured seed, batched, centre-cropped.

  The crop is applied to whole batches; the result is cached and prefetched.
  """
  images = load_dataset(filenames, use_tfrecord=use_tfrecord)
  shuffled = images.shuffle(2048, seed=params['SEED'])
  batches = shuffled.batch(params['BATCH_SIZE'])
  cropped = batches.map(get_center_crop, num_parallel_calls=AUTO)
  return cropped.cache().prefetch(AUTO)

# Learning rate decay function

In [None]:
def lrfn(epoch):
  """Learning-rate schedule: linear warm-up, optional plateau, then cosine decay.

  Reads the module-level constants LR_START, LR_MIN, LR_MAX, LR_RAMPUP_EPOCHS,
  LR_SUSTAIN_EPOCHS, N_CYCLES and EPOCHS.

  Args:
    epoch: zero-based epoch index.

  Returns:
    The learning rate for `epoch` as a plain Python float, so it can be used
    directly as the schedule of a Keras LearningRateScheduler.
  """
  if epoch < LR_RAMPUP_EPOCHS:
    # Linear ramp from LR_START to LR_MAX.
    lr = (LR_MAX - LR_START) / LR_RAMPUP_EPOCHS * epoch + LR_START
  elif epoch < LR_RAMPUP_EPOCHS + LR_SUSTAIN_EPOCHS:
    lr = LR_MAX
  else:
    decay_epochs = EPOCHS - LR_RAMPUP_EPOCHS - LR_SUSTAIN_EPOCHS
    progress = (epoch - LR_RAMPUP_EPOCHS - LR_SUSTAIN_EPOCHS) / decay_epochs
    # math.cos keeps the result a Python float instead of a tf.Tensor.
    lr = LR_MAX * (0.5 * (1.0 + math.cos(math.pi * N_CYCLES * 2.0 * progress)))
    if LR_MIN is not None:
      lr = max(LR_MIN, lr)

  return float(lr)
  
# Module-level schedule constants read by lrfn(), copied out of `params`.
LR_START, LR_MIN, LR_MAX = params['LR_START'], params['LR_MIN'], params['LR_MAX']
LR_RAMPUP_EPOCHS, LR_SUSTAIN_EPOCHS = params['LR_RAMPUP_EPOCHS'], params['LR_SUSTAIN_EPOCHS']
N_CYCLES, EPOCHS = params['N_CYCLES'], params['EPOCHS']

# CAE architecture definition

In [None]:
def conv_block(x, filters=16, kernel=5, stride=2, transpose=False, leaky=True, slope=0.2, padding='same', bn=False, bias=True, only_conv=False):
  """Apply a (transposed) convolution, optionally followed by batch norm and an activation.

  Args:
    x: input tensor.
    filters: number of convolution filters.
    kernel: side of the square kernel.
    stride: convolution stride.
    transpose: use Conv2DTranspose instead of Conv2D.
    leaky: use LeakyReLU(slope) instead of ReLU.
    slope: negative slope of the LeakyReLU.
    padding: padding mode of the convolution.
    bn: insert BatchNormalization between convolution and activation.
    bias: whether the convolution uses a bias.
    only_conv: return right after the convolution (no batch norm, no activation).

  Returns:
    The output tensor of the block.
  """
  layer_cls = Conv2DTranspose if transpose else Conv2D
  # The activation layer is instantiated up front, even when only_conv skips it.
  activation = LeakyReLU(slope) if leaky else ReLU()

  out = layer_cls(filters=filters, kernel_size=(kernel, kernel), strides=stride, padding=padding, use_bias=bias)(x)
  if only_conv:
    return out

  if bn:
    out = BatchNormalization()(out)
  return activation(out)

def get_cae():
  """Build the convolutional autoencoder inside the distribution strategy scope.

  Encoder: seven stride-2 conv blocks, then Dense -> Dense(LATENT_DIM).
  Decoder: Dense, reshape, six stride-2 transposed conv blocks and a final
  sigmoid transposed convolution with IMAGE_SIZE[2] output channels.

  Returns:
    The (uncompiled) Keras model named "ae".
  """
  slope = params['lrelu_slop']
  encoder_filters = (32, 64, 128, 128, 256, 256, 256)
  decoder_filters = (256, 256, 128, 128, 64, 32)

  with strategy.scope():
    # Spatial side left after every encoder block halved the resolution.
    n_dense = int(params['IMAGE_SIZE'][0] / (2 ** len(encoder_filters)))
    flat_units = n_dense * n_dense * 256

    inputs = Input(shape=params['IMAGE_SIZE'], name='encoder_input')
    x = inputs
    for n_filters in encoder_filters:
      x = conv_block(x, n_filters, bn=params['use_bn'], slope=slope)

    x = Flatten()(x)
    x = Dense(units=flat_units)(x)
    if params['use_bn']:
      x = BatchNormalization()(x)
    x = LeakyReLU(alpha=slope)(x)

    # Bottleneck.
    x = Dense(units=params['LATENT_DIM'])(x)
    x = LeakyReLU(alpha=slope)(x)

    x = Dense(units=flat_units)(x)
    if params['use_bn']:
      x = BatchNormalization()(x)
    x = LeakyReLU(alpha=slope)(x)

    x = Reshape((n_dense, n_dense, 256))(x)
    for n_filters in decoder_filters:
      x = conv_block(x, n_filters, transpose=True, bn=params['use_bn_dconv'], slope=slope)

    outputs = Conv2DTranspose(
      filters=params['IMAGE_SIZE'][2],
      kernel_size=(5, 5),
      strides=2,
      padding='same',
      use_bias=True,
      activation='sigmoid',
    )(x)

    ae_model = Model(inputs, outputs, name="ae")
  return ae_model

# MSE + Perceptual loss

In [None]:
def init_perceptual_loss(perp_layers):
  """Build a multi-output feature extractor from VGG19.

  VGG19 is created with Keras' default arguments apart from the 224x224x3
  input shape; the returned model outputs the activations of the layers whose
  indices are listed in `perp_layers`.
  """
  vgg = tf.keras.applications.VGG19(input_shape=(224,224,3))
  feature_maps = []
  for layer_index in perp_layers:
    feature_maps.append(vgg.layers[layer_index].output)
  return Model(vgg.inputs, feature_maps)

class MSE_Perceptual(tf.keras.losses.Loss):
  """Weighted sum of a pixel-space term and a VGG19 feature-space (perceptual) term."""

  def __init__(self, perc_layers, lambda_mse, lambda_perc):
    """Store the weights and build the feature extractor inside the strategy scope.

    Args:
      perc_layers: VGG19 layer indices used for the perceptual term.
      lambda_mse: weight of the pixel-space term.
      lambda_perc: weight of the perceptual term.
    """
    super().__init__()
    self.perc_layers = perc_layers
    self.lambda_mse = lambda_mse
    self.lambda_perc = lambda_perc
    with strategy.scope():
      self.perc_model = init_perceptual_loss(self.perc_layers)
      self.loss_fn = tf.keras.losses.mean_squared_error

  def call(self, y_true, y_pred):
    """Return lambda_mse * pixel term + lambda_perc * sum of per-layer feature terms."""
    # Pixel term: mean squared difference between the *squared* images.
    squared_gap = tf.math.pow(y_pred, 2) - tf.math.pow(y_true, 2)
    mse_loss = tf.reduce_mean(tf.square(squared_gap))

    # The feature extractor was built for 224x224 inputs.
    vgg_size = (224, 224)
    true_resized = tf.image.resize(y_true, vgg_size)
    pred_resized = tf.image.resize(y_pred, vgg_size)

    true_features = self.perc_model(true_resized)
    pred_features = self.perc_model(pred_resized)

    perceptual_terms = []
    for true_map, pred_map in zip(true_features, pred_features):
      # Per-sample feature-map dimensions, used to scale this layer's contribution.
      dim_a, dim_b, depth = pred_map[0].shape
      layer_term = (1/(dim_a*dim_b*depth))*tf.reduce_mean(tf.square(true_map - pred_map))
      perceptual_terms.append(layer_term)

    return self.lambda_mse*mse_loss + self.lambda_perc*tf.reduce_sum(perceptual_terms)



# Single loss instance (and therefore a single VGG19 feature extractor), passed to compile() for every model trained below.
loss_fn = MSE_Perceptual(params['perceptual_layers'], params['MSE_LOSS_LAMBDA'], params['PERP_LOSS_LAMBDA'])

In [None]:
def count_data_items(filenames):
  """Sum the per-file item counts encoded in the file names.

  Every name must contain "-<count>." (a dash, digits, a dot); the first such
  occurrence is used, e.g. "train00-1234.tfrecord" contributes 1234.

  Args:
    filenames: iterable of file names.

  Returns:
    The total count as a NumPy integer (0 for an empty list).

  Raises:
    ValueError: if a file name does not carry a count.
  """
  count_pattern = re.compile(r"-([0-9]*)\.")  # compiled once, not once per file
  counts = []
  for filename in filenames:
    match = count_pattern.search(filename)
    if match is None or not match.group(1):
      raise ValueError("Cannot read the item count from file name: {!r}".format(filename))
    counts.append(int(match.group(1)))
  # An explicit integer dtype keeps the result integral even for an empty list.
  return np.sum(counts, dtype=np.int64)

In [None]:
# Train one autoencoder per grid configuration. Every run writes its config,
# best checkpoint and training history into its own timestamped folder.
for grid_x, grid_y in grids:
  OUTPUT_PATH = '{}training_{}-{}_s1024/'.format(BASE_OUTPUT, grid_x, grid_y)

  if params['USE_TFRECORD']:
    GCS_PATH = '{}grid{}-{}_s1024'.format(BASE_DATA, grid_x, grid_y)
    TRAINING_FILENAMES = tf.io.gfile.glob(GCS_PATH + "/train*.tfrecord")
    VALID_FILENAMES = tf.io.gfile.glob(GCS_PATH + "/val*.tfrecord")

    print("Number of train tfrecords: {}".format(len(TRAINING_FILENAMES)))
    print("Number of validation tfrecords: {}".format(len(VALID_FILENAMES)))

    # The number of images is encoded in the TFRecord file names.
    NUM_TRAINING_IMAGES = count_data_items(TRAINING_FILENAMES)
    NUM_VALIDATION_IMAGES = count_data_items(VALID_FILENAMES)
  else:
    TRAINING_FILENAMES = glob(BASE_DATA + "extracted/grid{}-{}_s1024/train/good/*.png".format(grid_x, grid_y))
    print(TRAINING_FILENAMES)

    # Same train/validation split as the one used in the TFRecords.
    TRAINING_FILENAMES, VALID_FILENAMES, _, _ = train_test_split(TRAINING_FILENAMES, TRAINING_FILENAMES, test_size=0.1, random_state=100)
    NUM_TRAINING_IMAGES = len(TRAINING_FILENAMES)
    NUM_VALIDATION_IMAGES = len(VALID_FILENAMES)

  print("{} training images AND {} validation images".format(NUM_TRAINING_IMAGES, NUM_VALIDATION_IMAGES))

  ae_model = get_cae()
  # Timestamp-based folder name; spaces and colons are replaced to keep it path-friendly.
  output_folder = str(datetime.now()).replace(' ', '_').replace(':', '-')

  # exist_ok=True avoids the check-then-create race of exists() + makedirs().
  os.makedirs(OUTPUT_PATH+output_folder, exist_ok=True)

  # Record the configuration used for this run; `with` closes the file even if a write fails.
  with open('{}{}/config_used.conf'.format(OUTPUT_PATH, output_folder),"w+") as config_file:
    for key in params:
      config_file.write("{}={}\n".format(key, params[key]))

  lr_callback2 = tf.keras.callbacks.LearningRateScheduler(lrfn, verbose=True)
  cp = ModelCheckpoint(filepath='{}{}/best_model_val_loss.h5'.format(OUTPUT_PATH, output_folder), monitor='val_loss', save_best_only=True)

  # The training dataset repeats forever, so the epoch length has to be given explicitly.
  STEPS_PER_EPOCH = NUM_TRAINING_IMAGES // params['BATCH_SIZE']
  STEPS_PER_VAL = NUM_VALIDATION_IMAGES // params['BATCH_SIZE']

  ae_model.compile(optimizer="adam",
                 loss=loss_fn,
                 metrics=['mae', 'mse'])

  history = ae_model.fit(
    get_train_dataset(TRAINING_FILENAMES, use_tfrecord=params['USE_TFRECORD']),
    validation_data=get_validation_dataset(VALID_FILENAMES, use_tfrecord=params['USE_TFRECORD']),
    epochs=params['EPOCHS'],
    steps_per_epoch=STEPS_PER_EPOCH,
    callbacks=[lr_callback2, cp],
    validation_steps=STEPS_PER_VAL
  )

  pd.DataFrame.from_dict(history.history).to_csv('{}{}/history.csv'.format(OUTPUT_PATH, output_folder),index=False)
