In [34]:
#data Loader
#this will prepare training data for the model

import tensorflow as tf
import os

In [35]:
from tensorflow.python.ops import array_ops, math_ops

In [36]:
class DataLoader(object):
    def __init__(self, image_dir, hr_image_size):
        self.image_paths = [os.path.join(image_dir, x) for x in os.listdir(image_dir)]
        self.image_size = hr_image_size
    def _parse_image(self, image_path):
        image = tf.io.read_file(image_path)
        image = tf.image.decode_jpeg(image, channels=3)
        image = tf.image.convert_image_dtype(image, tf.float32)
        
        if tf.keras.backend.image_data_format() == 'channels_last':
            shape = array_ops.shape(image)[:2]
        else:
            shape = array_ops_shape(image)[1:]
        cond = math_ops.reduce_all(shape >= tf.constant(self, image_size))
        image = tf.cond(cond, lambda: tf.identity(image)), lambda: tf.image.resize(image, [self.image_size, self.image_size])
        return image
    def _random_crop(self, image):
        image = tf.image.random_crop(image, [self.image_size, self.image_size, 3])
        return image
    def _high_low_res_pairs(self, high_res):
        low_res = tf.image.resize(high_res, [self.image_size // 4 , self.image_size//4], method = 'bicubic')
        return low_res, high_res
    def _rescale(self, low_res, high_res):
        high_res = high_res * 2.0 - 1.0
        return low_res, high_res

In [37]:
def dataset(self, batch_size, threads=4):
    dataset = tf.data.Dataset.from_tensor_slices(self.image_paths)
    dataset = dataset.map(self._parse_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    dataset = dataset.map(self._random_crop, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    dataset = dataset.map(self._high_low_res_pairs, num_parallel_calls=tf.data.experimental.AUTOTUNE) 
    dataset = dataset.map(self._rescale, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    dataset = dataset.shuffle(30).batch(batch_size, drop_remainder=True).prefetch(tf.data.experimental.AUTOTUNE) 
    return dataset

In [38]:
#Now We are creating the model architecture or brain of the model

In [39]:
from tensorflow import keras

In [40]:
class FastSRGAN(object):
    def __init__(self, args):
        self.hr_height = args.hr.size
        self.hr_width = args.hr.size
        self.lr_height = self.hr_height // 4
        self.lr_width = self.hr_width // 4
        self.lr_shape = (self.lr_height, self.lr_width, 3)
        self.hr_shape = (self.hr_height, self.hr_width, 3)
        self.iteractions = 0
        
        self.n_residual_blocks = 6
        self.gen_schedule = keras.optimizers.schedules.ExponentialDecay(args.lr, decay_steps=100000, decay_rate=0.1, staircase=True)
        self.disc_schedule = keras.optimizers.schedules.ExponentialDecay(args.lr*5, decay_steps=100000, decay_rate=0.1, staircase=True)
        
        self.gen_optimizer = keras.optimizers.Adam(learning_rate=self.gen_schedule)
        self.disc_optimizer = keras.optimizers.Adam(learning_rate=self.disc_schedule)
        
        self.vgg = self.build_vgg()
        self.vgg.trainable = False
        
        patch = int(self.hr_height/2**4)
        self.disc_patch = (patch, patch, 1)
        
        self.gf = 32
        self.df = 32
        
        self.discriminator = self.build_discriminator()
        self.generator = self.build_generator()
    
    @tf.function
    def content_loss(self, hr, sr):
        
        sr = keras.applications.vgg19.preprocess_input(((sr+1.0)*255)/2.0)
        hr = keras.applications.vgg19.preprocess_input(((sr+1.0)*255)/2.0)
        sr_features = self.vgg(sr) / 12.75
        hr_features = self.vgg(sr) / 12.75
        return tf.keras.losses.MeanSquaredError()(hr_features, sr_features)
    
    def build_vgg(self):
        vgg = keras.applications.VGG19(weights='imagenet', input_shape=self.hr_shape, include_top=False)
        vgg.trainable = False
        
        for layer in vgg.layers:
            layer.trainable = False
            
        model = keras.models.Model(inputs=vgg.input, outputs=vgg.get_layer("block5_conv4").output)
        
        return model
    
    def build_generator(self):
        
        def _make_divisible(v, divisor, min_value=None):
            if min_value is None:
                min_value = divisor
            new_v = max(min_value, int(v+divisor/2) // divisor * divisor)
            
            if new_v < 0.9*v:
                new_v += divisor
            return new_v
        def residual_block(inputs, filters, block_id, expansion=6, stride=1, alpha=1.0):
            channel_axis = 1 if keras.backend.image_data_format() == 'channels_first' else -1
            
            in_channels = keras.backend.int_shape(inputs)[channel_axis]
            
            pointwise_conv_filters = int(filters * alpha)
            pointwise_filters = _make_divisible(pointwise_conv_filters, 8)
            x = inputs
            prefix = 'block_{}_'.format(block_id)
            
            if block_id:
                
                x = keras.layers.Conv2D(expansion * in_channels, kernel_size = 1, padding = 'same', use_bias=True, activation=None, name=prefix + 'expand')(x)
                x = keras.layers.BatchNormalization(axis=channel_axis, epsilon=1e-3, padding = 'same', use_bias=True, activation=None, name=prefix + 'expand')(x)
                x = keras.layers.Activation('relu', name=prefix + 'expand_relu')(x)
            else:
                prefix =  'expanded_conv_'
                
            x = keras.layers.DepthwiseConv2D(kernel_size=3,
                                            strides=stride,
                                            activation=None,
                                            use_bias=True,
                                            padding='same' if stride == 1 else 'valid',
                                            name= prefix+'depthwise')(x)
            x = keras.layers.BatchNormalization(axis=channel_axis,
                                               epsilon=1e-3,
                                               momentum=0.999,
                                               name=prefix+'depthwise_BN')(x)
            x = keras.layers.Activation('relu', name=prefix+'depthwise_relu')(x)
            
            x = keras.layers.Conv2D(pointwise_filters,
                                   kernel_size = 1,
                                   padding='same',
                                   use_bias=True,
                                   activation=None,
                                   name=prefix+'project')(x)
            
            x = keras.layers.BatchNormalization(axis=channel_axis,
                                               epsilon=1e-3,
                                               momentum=0.999,
                                               name=prefix+'project_BN')(x)
            
            if in_channels == pointwise_filters and stride == 1: 
                return keras.layers.Add(name=prefix+'add')([inputs, x]) 
            return x 
        
        def deconv2D(layer_input, filters): 
            
            u = keras.layers.Conv2D(filters, kernel_size=3, strides=1, padding='same')(layer_input) 
            u = tf.nn.depth_to_space(u,2) 
            u = keras.layers.PReLU(shared_axes=[1, 2])(u) 
            return u 
        
        img_lr = keras.Input(shape=self.lr_shape) 
        
        c1 = keras.layers.Conv2D(self.gf, kernel_size=3, strides=1, padding='same')(img_lr) 
        c1 = keras.layers.BatchNormalization()(c1)
        c1 = keras.layers.PReLU(shared_axes=[1, 2])(c1) 
        r = residual_block(cl, self.gf, 0) 
        
        for idx in range(1, self.n_residual_blocks):
            r = residual_block(r, self.gf, idx) 
            
        c2 = keras.layers.Conv2D(self.gf, kernel_size=3, strides=1, padding='same')(r)
        c2 = keras.layers.BatchNormalization()(c2) 
        c2 = keras.layers.Add([c2, c1]) 
        
        u1 = deconv2d(c2, self.gf*4) 
        u2 = deconv2d(u1, self.gf*4) 
        
        gen_hr = keras.layers.Conv2D(3, kernel_size=3, strides=1, padding='same', activation='tanh')(u2) 
        return keras.models.Model(img_lr, gen_hr) 
    
    
    def build_discriminator(self): 
        
        def d_block(layer_input, filters, strides=1, bn=True): 
            d = keras.layers.Conv2D(filters, kernel_size=3, strides = strides, padding='same')(layer_input) 
            if bn:
                d = keras.layers.BatchNormalization(momentum=0.8)(d)
            d = keras.layers.LeakyReLU(alpha=0.2)(d) 
            return d 

        d0 = keras.layers.input(shape=self.hr_shape)

        d1 = d_block(d0, self.df, bn=False)
        d2 = d_block(d1, self.df, strides=2) 
        d3 = d_block(d2, self.df) 
        d4 = d_block(d3, self.df, strides=2) 
        d5 = d_block(d4, self.df*2) 
        d6 = d_block(d5, self.df*2, strides=2) 
        d7 = d_block(d6, self.df*2) 
        d8 = d_block(d7, self.df*2, strides=2) 

        validity = keras.layers.Conv2D(1, kernel_size=1, strides=1, activation='sigmoid', padding='same')(d8) 
        return keras.models.Model(d0, validity) 
        
        

            

In [41]:
#training Code

In [46]:
from argparse import ArgumentParser

parser = ArgumentParser() 
parser.add_argument('--image_dir', type=str, help='Path to high resolution image directory.') 
parser.add_argument('--batch_size', default=8, type=int, help='Batch size for training.') 
parser.add_argument('--epochs', default=1, type=int, help='Number of epochs for training') 
parser.add_argument('--hr_size', default=384, type=int, help='Low resolution input size.') 
parser.add_argument('--lr', default=1e-4, type=float, help='Learning rate for optimizers.') 
parser.add_argument('--save_iter', default=200, type=int, 
                    help='The number of iterations to save the tensorboard summaries and models.') 
@tf.function 
def pretrain_step(model, x, y): 
    """
    Single step of generator pre-training.
    Args: 
        model: A model object with a tf keras compiled generator. 
        x: The low resolution image tensor. 
        y: The high resolution image tensor. 
    """    
        
    with tf.GradientTape() as tape: 
        fake_hr = model.generator(x) 
        loss_mse = tf.keras.losses.MeanSquaredError()(y, fake_hr) 
    grads = tape.gradient(loss_mse, model.generator.trainable_variables)
    model.gen_optimizer.apply_gradients(zip(grads, model.generator.trainable_variables)) 
    return loss_mse 



def pretrain_generator(model, dataset, writer):
    """Function that pretrains the generator slightly, to avoid local minima. 
    Args: 
        model: The keras model to train. 
        dataset: A tf dataset object of low and high res images to pretrain over. 
        writer: A summary writer object. 
    Returns: 
        None
    """
    with writer.as_default(): 
        iteration = 0 
        for _ in range(1): 
            for x, y in dataset: 
                loss = pretrain_step(model, x, y) 
                if iteration % 20 == 0: 
                    tf.summary.scalar('MSE Loss', loss, step=tf.cast(iteration, tf.int64)) 
                    writer.flush() 
                iteration += 1 
@tf.function 
def train_step(model, x, y): 
    """
    Single train step function for the SRGAN. 
    Args: 
        model: An object that contains a tf keras compiled discriminator model. 
        x: The low resolution input image. 
        y: The desired high resolution output image. 
        Returns: 
        d_loss: The mean loss of the discriminator. """ 
    
    # label smoothing for better gradient flow 
    valid = tf.ones((x.shape[0],) + model.disc_patch) 
    fake = tf.zeros((x.shape[0],) + model.disc_patch)
    
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape: 
        # From Low res. image generate high res. version 
        fake_hr = model.generator(x) 
        # Train the discriminators (original images = real / generated = Fake) 
        valid_prediction = model.discriminator(y) 
        fake_prediction = model.discriminator(fake_hr) 
        # Generator Loss 
        content_loss = model.content_loss(y, fake_hr) 
        adv_loss = 1e-3 * tf.keras.losses.BinaryCrossentropy()(valid, fake_prediction)
        mse_loss = tf.keras.losses.MeanSquaredError()(y, fake_hr) 
        perceptual_loss = content_loss + adv_loss + mse_loss 
        # Discriminator Loss 
        valid_loss = tf.keras.losses.BinaryCrossentropy()(valid, valid_prediction) 
        fake_loss = tf.keras.losses.BinaryCrossentropy()(fake, fake_prediction) 
        d_loss = tf.add(valid_loss, fake_loss) 
    # Backprop on Generator 
    gen_grads = gen_tape.gradient(perceptual_loss, model.generator.trainable_variables) 
    model.gen_optimizer.apply_gradients(zip(gen_grads, model.generator.trainable_variables)) 
    # Backprop on Discriminator 
    disc_grads = disc_tape.gradient(d_loss, model.discriminator.trainable_variables) 
    model.disc_optimizer.apply_gradients(zip(disc_grads, model.discriminator.trainable_variables)) 
    return d_loss, adv_loss, content_loss, mse_loss

def train(model, dataset, log_iter, writer):
    """"
    Function that defines a single training step for the SR-GAN. 
    Args: 
        model: An object that contains tf keras compiled generator and 
                discriminator models. 
        dataset: A tf data object that contains low and high res images. 
        log_iter: Number of iterations after which to add logs in 
                  tensorboard. 
        writer: Summary writer 
    """
    with writer.as_default(): 
        # Iterate over dataset 
        for x, y in dataset: 
            disc_loss, adv_loss, content_loss, mse_loss = train_step(model, x, y)
            # Log tensorboard summaries if Log iteration is reached. 
            if model.iterations % log_iter == 0: 
                tf.summary.scalar('Adversarial Loss', adv_loss, step=model.iterations) 
                tf.summary.scalar('Content Loss', content_loss, step=model.iterations) 
                tf.summary.scalar('MSE Loss', mse_loss, step=model.iterations) 
                tf.summary.scalar('Discriminator Loss', disc_loss, step=model.iterations) 
                tf.summary.image('Low Res', tf.cast(255 * x, tf.uint8), step=model.iterations) 
                tf.summary.image('High Res', tf.cast(255 * (y + 1.0) / 2.0, tf.uint8), step=model.iterations) 
                tf.summary.image('Generated', tf.cast(255 * (model.generator.predict(x) + 1.0) / 2.0, tf.uint8), 
                                 step=model.iterations) 
                model.generator.save('models/generator.h5') 
                model.discriminator.save('models/discriminator.h5')
                writer.flush()     
            model.iterations += 1
            
def main(): 
    # Parse the CLI arguments. 
    args = parser.parse_args() 
    
    # create directory for saving troined models. 
    if not os.path.exists('models'): 
        os.makedirs('models') 
    # Create the tensorftow dataset. 
    ds = Dataloader(args.image_dir, args.hr_size).dataset(args.batch_size) 
    
    # Initialize the GAN object. 
    gan = FastSRGAN(args) 
    
    # Define the directory for saving pretrainig Loss tensorboard summary.
    pretrain_summary_writer = tf.summary.create_file_writer('logs/pretrain') 
    # Run pre-training. 
    pretrain_generator(gan, ds, pretrain_summary_writer) 
    # Define the directory for saving the SRGAN training tensorboord summary. 
    train_summary_writer = tf.summary.create_file_writer('logs/train') 
    # Run training.
    for _ in range(args.epochs): 
        train(gan, ds, args.save_iter, train_summary_writer) 



In [47]:
main()

usage: ipykernel_launcher.py [-h] [--image_dir IMAGE_DIR] [--batch_size BATCH_SIZE] [--epochs EPOCHS]
                             [--hr_size HR_SIZE] [--lr LR] [--save_iter SAVE_ITER]
ipykernel_launcher.py: error: unrecognized arguments: -f C:\Users\Artis\AppData\Roaming\jupyter\runtime\kernel-b15b227f-167e-4a68-8cd8-2defd9c3cd6c.json


SystemExit: 2