# Handwriting Style Transfer

In [71]:
import warnings
warnings.filterwarnings('ignore') # disable during imports due to TF addons EOL

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from keras_applications.imagenet_utils import _obtain_input_shape 
import tensorflow_addons as tfa

import matplotlib.pyplot as plt
import numpy as np
import re

import os

## Select the Training Device and Distribution Strategy

In [72]:
def get_strategy():
    """
    Select a tf.distribute strategy for the available hardware (TPU, GPU, or CPU).

    Detection order:
      1. TPU: resolve the cluster, connect to it, initialize the TPU system
         and build a TPUStrategy.
      2. GPU: if TPU resolution raises ValueError, restrict TensorFlow to the
         first physical GPU, enable memory growth and build a MirroredStrategy.
      3. CPU: if no GPU is listed (or GPU configuration raises RuntimeError),
         fall back to the current default strategy.

    The detected device type and the number of replicas in sync are printed.

    Returns:
        The selected tf.distribute strategy object.
    """
    try:
        tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
        print('Device: TPU', tpu.master())
        tf.config.experimental_connect_to_cluster(tpu)
        tf.tpu.experimental.initialize_tpu_system(tpu)
        # Use the stable API; tf.distribute.experimental.TPUStrategy is the
        # deprecated alias of this class.
        strategy = tf.distribute.TPUStrategy(tpu)
    except ValueError:
        # TPU detection would fail with a ValueError, catch this to proceed to check GPU.
        print('TPU not found')

        # Check for GPUs
        gpus = tf.config.list_physical_devices('GPU')
        if gpus:
            try:
                # Restrict TensorFlow to only use the first GPU
                tf.config.experimental.set_visible_devices(gpus[0], 'GPU')

                for gpu in gpus:
                    # Memory growth needs to be the same across GPUs
                    tf.config.experimental.set_memory_growth(gpu, True)
                print('Device: GPU', gpus)
                strategy = tf.distribute.MirroredStrategy()
            except RuntimeError as e:
                # Visible devices and memory growth must be set before GPUs have been initialized
                print(e)
                strategy = tf.distribute.get_strategy()
        else:
            print('Device: CPU')
            strategy = tf.distribute.get_strategy()

    print('Number of replicas:', strategy.num_replicas_in_sync)
    return strategy

# tf.data.AUTOTUNE is the stable (non-experimental) name of this constant.
AUTOTUNE = tf.data.AUTOTUNE
# Short alias kept because the dataset-building helpers refer to AUTO.
AUTO = AUTOTUNE
print('TensorFlow version:', tf.__version__)

# Distribution strategy shared by the rest of the notebook.
strategy = get_strategy()


TensorFlow version: 2.14.0
TPU not found
Device: CPU
Number of replicas: 1


In [75]:
# Dataset root under the current user's home directory.
dataset_dir = os.path.join(
    os.path.expanduser("~"),
    "datasets",
    "handwriting",
    "IAM_Handwriting_Top50",
)
# Directory holding one sub-folder per author id.
cropped_dir = os.path.join(dataset_dir, "cropped", "author")

# TFRecord shards for the two authors used in this notebook ("000" and "150").
first_author_pattern = os.path.join(cropped_dir, "000", "*.tfrec")
second_author_pattern = os.path.join(cropped_dir, "150", "*.tfrec")
FIRST_AUTHOR_FILENAMES = tf.io.gfile.glob(first_author_pattern)
SECOND_AUTHOR_FILENAMES = tf.io.gfile.glob(second_author_pattern)

print("# 1st Author TFREC files: ", len(FIRST_AUTHOR_FILENAMES))
print("# 2nd Author TFREC files: ", len(SECOND_AUTHOR_FILENAMES))

# 1st Author TFREC files:  1
# 2nd Author TFREC files:  1


In [76]:
def data_augment_flip(image):
    """
    Randomly flip `image` left-to-right using tf.image.random_flip_left_right.

    NOTE(review): a horizontal flip mirrors the handwriting itself; confirm
    this augmentation is intended for text images.
    """
    return tf.image.random_flip_left_right(image)

In [90]:
# Height and width used when reshaping decoded images.
IMAGE_SIZE = [150, 150]

def decode_image(image):
    """
    Decode PNG-encoded bytes into a float32 tensor of shape [*IMAGE_SIZE, 3].

    The image is decoded with 3 channels, cast to float32 and rescaled with
    x / 127.5 - 1 (mapping 0..255 onto -1..1) before being reshaped.
    """
    pixels = tf.image.decode_png(image, channels=3)
    scaled = (tf.cast(pixels, tf.float32) / 127.5) - 1
    return tf.reshape(scaled, [*IMAGE_SIZE, 3])

def read_tfrecord(example):
    """
    Parse one serialized example and return its decoded image.

    The record is expected to carry a single scalar string feature named
    "image", which is passed to decode_image.
    """
    feature_spec = {"image": tf.io.FixedLenFeature([], tf.string)}
    parsed = tf.io.parse_single_example(example, feature_spec)
    return decode_image(parsed["image"])

In [91]:
def load_dataset(filenames):
    """
    Build an unbatched dataset of decoded images from TFRecord file(s).

    Each record is parsed with read_tfrecord, using AUTOTUNE parallelism.
    """
    records = tf.data.TFRecordDataset(filenames)
    return records.map(read_tfrecord, num_parallel_calls=AUTOTUNE)

BATCH_SIZE = 32

# One image per batch for each author.
first_author_ds = load_dataset(FIRST_AUTHOR_FILENAMES).batch(1)
second_author_ds = load_dataset(SECOND_AUTHOR_FILENAMES).batch(1)

# Larger-batch datasets: 32 images per replica, prefetching 32 batches.
# NOTE(review): the "photo"/"monet" names look inherited from another project;
# here "photo" reads the second author's files and "monet" the first author's.
fast_photo_ds = (
    load_dataset(SECOND_AUTHOR_FILENAMES)
    .batch(32 * strategy.num_replicas_in_sync)
    .prefetch(32)
)
# Same source as fast_photo_ds, limited to the first 1024 images.
fid_photo_ds = (
    load_dataset(SECOND_AUTHOR_FILENAMES)
    .take(1024)
    .batch(32 * strategy.num_replicas_in_sync)
    .prefetch(32)
)
fid_monet_ds = (
    load_dataset(FIRST_AUTHOR_FILENAMES)
    .batch(32 * strategy.num_replicas_in_sync)
    .prefetch(32)
)

In [92]:
def get_gan_dataset(first_auth_files, second_auth_files, augment=None, repeat=True, shuffle=True, batch_size=1):
    """
    Build the paired two-author dataset used for GAN training.

    Both authors' TFRecord files go through the same pipeline:
    load -> (repeat) -> (shuffle, buffer 2048) -> batch -> (augment) -> prefetch,
    and the two resulting datasets are zipped together.

    Args:
        first_auth_files: TFRecord filename(s) for the first author.
        second_auth_files: TFRecord filename(s) for the second author.
        augment: optional callable mapped over each *batch* of images
            (it runs after batching); skipped when falsy.
        repeat: if true, repeat both datasets indefinitely.
        shuffle: if true, shuffle both datasets with a 2048-element buffer.
        batch_size: images per batch; incomplete final batches are dropped.

    Returns:
        A tf.data.Dataset yielding (first_author_batch, second_author_batch) tuples.
    """
    first_ds = load_dataset(first_auth_files)
    second_ds = load_dataset(second_auth_files)

    if repeat:
        first_ds = first_ds.repeat()
        second_ds = second_ds.repeat()
    if shuffle:
        first_ds = first_ds.shuffle(2048)
        second_ds = second_ds.shuffle(2048)

    first_ds = first_ds.batch(batch_size, drop_remainder=True)
    second_ds = second_ds.batch(batch_size, drop_remainder=True)
    if augment:
        # Reassign to first_ds: the augmented first-author dataset used to be
        # bound to an unused name, so its augmentation was silently discarded.
        first_ds = first_ds.map(augment, num_parallel_calls=AUTO)
        second_ds = second_ds.map(augment, num_parallel_calls=AUTO)

    first_ds = first_ds.prefetch(AUTO)
    second_ds = second_ds.prefetch(AUTO)

    gan_ds = tf.data.Dataset.zip((first_ds, second_ds))

    return gan_ds

def get_photo_dataset(second_auth_files, augment=None, repeat=False, shuffle=False, batch_size=1):
    """
    Build a single-author dataset from the given TFRecord file(s).

    Pipeline: load -> (repeat) -> (shuffle, buffer 2048) -> batch with
    drop_remainder=True -> (augment, applied per batch) -> prefetch.

    Args:
        second_auth_files: TFRecord filename(s) to read.
        augment: optional callable mapped over each batch; skipped when falsy.
        repeat: if true, repeat the dataset indefinitely.
        shuffle: if true, shuffle with a 2048-element buffer.
        batch_size: images per batch.

    Returns:
        The batched, prefetched tf.data.Dataset.
    """
    ds = load_dataset(second_auth_files)

    if repeat:
        ds = ds.repeat()
    if shuffle:
        ds = ds.shuffle(2048)

    ds = ds.batch(batch_size, drop_remainder=True)
    if augment:
        ds = ds.map(augment, num_parallel_calls=AUTO)

    return ds.prefetch(AUTO)

In [93]:
# Paired training dataset: repeated, shuffled, flipped at random, BATCH_SIZE images per author.
final_dataset = get_gan_dataset(
    first_auth_files=FIRST_AUTHOR_FILENAMES,
    second_auth_files=SECOND_AUTHOR_FILENAMES,
    augment=data_augment_flip,
    repeat=True,
    shuffle=True,
    batch_size=BATCH_SIZE,
)

In [94]:
def calculate_activation_statistics_mod(images, fid_model):
    """
    Compute the mean vector and covariance matrix of a model's activations.

    Args:
        images: input passed straight to `fid_model.predict`.
        fid_model: model whose predictions are treated as one feature row per
            sample (assumes a 2-D (samples, features) output -- TODO confirm).

    Returns:
        (mu, sigma): `mu` is the mean over axis 0; `sigma` is
        mean(x x^T) - mu mu^T, i.e. it divides by N rather than N - 1.
    """
    activations = tf.cast(fid_model.predict(images), tf.float32)
    sample_count = tf.cast(tf.shape(activations)[0], tf.float32)

    mu = tf.reduce_mean(activations, axis=0)
    mean_row = tf.reduce_mean(activations, axis=0, keepdims=True)

    # Outer product of the mean with itself, and the uncentered second moment.
    mean_outer = tf.matmul(tf.transpose(mean_row), mean_row)
    second_moment = tf.matmul(tf.transpose(activations), activations) / sample_count

    sigma = second_moment - mean_outer
    return mu, sigma


with strategy.scope():
    # Headless InceptionV3. The head requested by pooling="avg" is not used,
    # because the model is rebuilt below from the "mixed9" layer output.
    inception_model = tf.keras.applications.InceptionV3(
        input_shape=(150, 150, 3),
        pooling="avg",
        include_top=False,
    )

    # Feature extractor: global-average-pooled "mixed9" activations.
    mix3 = inception_model.get_layer("mixed9").output
    f0 = tf.keras.layers.GlobalAveragePooling2D()(mix3)

    inception_model = tf.keras.Model(inputs=inception_model.input, outputs=f0)
    inception_model.trainable = False

    # Reference statistics computed once from the first author's images (fid_monet_ds).
    myFID_mu2, myFID_sigma2 = calculate_activation_statistics_mod(fid_monet_ds, inception_model)
    fids = []

