
Even faster now

HasnainRaz committed Oct 27, 2019
1 parent 81b766c commit 58c949401642bc81ec989d6438089788b0b43baa
Showing with 49 additions and 55 deletions.
  1. +0 −1 dataloader.py
  2. +22 −25 main.py
  3. +27 −29 model.py
  4. BIN models/generator.h5
dataloader.py (+0 −1)
@@ -87,7 +87,6 @@ def _rescale(self, low_res, high_res):
             low_res: The tf tensor of the low res image, rescaled.
             high_res: The tf tensor of the high res image, rescaled.
         """
-        low_res = low_res * 2.0 - 1.0
         high_res = high_res * 2.0 - 1.0

         return low_res, high_res
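Note: only the high-res target is now mapped to [-1, 1] (matching the generator's tanh output); the low-res input stays in [0, 1]. A minimal sketch of the resulting behaviour, written as a free function for illustration:

    import tensorflow as tf

    def rescale(low_res, high_res):
        # Low-res input stays in [0, 1]; the high-res target is mapped to
        # [-1, 1] to match the tanh output range of the generator.
        high_res = high_res * 2.0 - 1.0
        return low_res, high_res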
main.py (+22 −25)
@@ -5,11 +5,10 @@
 import os

 parser = ArgumentParser()
-parser.add_argument('--image_dir', default='/home/wattx/Downloads/experimental/div2k', type=str,
-                    help='Path to high resolution image directory.')
-parser.add_argument('--batch_size', default=12, type=int, help='Batch size for training.')
-parser.add_argument('--epochs', default=3000, type=int, help='Number of epochs for training.')
-parser.add_argument('--hr_size', default=128, type=int, help='High resolution input size.')
+parser.add_argument('--image_dir', type=str, help='Path to high resolution image directory.')
+parser.add_argument('--batch_size', default=8, type=int, help='Batch size for training.')
+parser.add_argument('--epochs', default=1, type=int, help='Number of epochs for training.')
+parser.add_argument('--hr_size', default=384, type=int, help='High resolution input size.')
 parser.add_argument('--lr', default=1e-4, type=float, help='Learning rate for optimizers.')
 parser.add_argument('--save_iter', default=200, type=int,
                     help='The number of iterations to save the tensorboard summaries and models.')
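Since --image_dir no longer has a default, it must now be passed explicitly. A quick check of the new defaults (the directory path below is only an example):

    # Hypothetical invocation; '/data/div2k' is an example path.
    args = parser.parse_args(['--image_dir', '/data/div2k'])
    assert (args.batch_size, args.epochs, args.hr_size) == (8, 1, 384)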
@@ -35,8 +34,7 @@ def pretrain_step(model, x, y):


 def pretrain_generator(model, dataset, writer):
-    """Function that pretrains the generator slightly, so that it can keep up
-    with the discriminator at the start.
+    """Function that pretrains the generator slightly to avoid local minima.
     Args:
         model: The keras model to train.
         dataset: A tf dataset object of low and high res images to pretrain over.
@@ -46,7 +44,7 @@ def pretrain_generator(model, dataset, writer):
"""
with writer.as_default():
iteration = 0
for epoch in range(2):
for _ in range(1):
for x, y in dataset:
loss = pretrain_step(model, x, y)
if iteration % 20 == 0:
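For context, a plausible sketch of the step this loop drives; the real pretrain_step body sits outside this diff, and the optimizer attribute name is assumed:

    @tf.function
    def pretrain_step(model, x, y):
        # Sketch (assumed body): minimize pixel-wise MSE so the generator
        # starts from a sensible point before adversarial training begins.
        with tf.GradientTape() as tape:
            fake_hr = model.generator(x)
            loss = tf.keras.losses.MeanSquaredError()(y, fake_hr)
        grads = tape.gradient(loss, model.generator.trainable_variables)
        model.gen_optimizer.apply_gradients(
            zip(grads, model.generator.trainable_variables))
        return loss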
@@ -62,13 +60,13 @@ def train_step(model, x, y):
         model: An object that contains a tf keras compiled discriminator model.
         x: The low resolution input image.
         y: The desired high resolution output image.
     Returns:
         d_loss: The mean loss of the discriminator.
     """
     # Label smoothing for better gradient flow
-    valid = tf.ones((x.shape[0], 1)) - tf.random.uniform((x.shape[0], 1)) * 0.1
-    fake = tf.ones((x.shape[0], 1)) * tf.random.uniform((x.shape[0], 1)) * 0.1
+    valid = tf.ones((x.shape[0],) + model.disc_patch)
+    fake = tf.zeros((x.shape[0],) + model.disc_patch)

     with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
         # From low res. image generate high res. version
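The labels now cover every spatial patch of the discriminator output rather than a single scalar, and the label smoothing noise is gone. Assuming disc_patch is derived from the four stride-2 blocks of the new discriminator (a 2**4 = 16x downsample), the shapes work out as:

    batch, hr_size = 8, 384
    disc_patch = (hr_size // 16, hr_size // 16, 1)  # assumed: (24, 24, 1)
    valid = tf.ones((batch,) + disc_patch)          # one 'real' label per patch
    fake = tf.zeros((batch,) + disc_patch)          # one 'fake' label per patch
    print(valid.shape)                              # (8, 24, 24, 1)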
@@ -87,7 +85,7 @@ def train_step(model, x, y):
         # Discriminator loss
         valid_loss = tf.keras.losses.BinaryCrossentropy()(valid, valid_prediction)
         fake_loss = tf.keras.losses.BinaryCrossentropy()(fake, fake_prediction)
-        d_loss = tf.divide(tf.add(valid_loss, fake_loss), 2)
+        d_loss = tf.add(valid_loss, fake_loss)

     # Backprop on Generator
     gen_grads = gen_tape.gradient(perceptual_loss, model.generator.trainable_variables)
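Dropping the division by two changes only the scale of the discriminator gradients, not the optimum:

    # Equivalent up to a constant factor:
    #   old: d_loss = (valid_loss + fake_loss) / 2
    #   new: d_loss =  valid_loss + fake_loss   # 2x gradient scale, same minimizer
    d_loss = tf.add(valid_loss, fake_loss)

With an adaptive optimizer such as Adam (assumed here), the rescaling is largely absorbed by its per-parameter normalization.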
@@ -121,7 +119,7 @@ def train(model, dataset, log_iter, writer):
                 tf.summary.scalar('Content Loss', content_loss, step=model.iterations)
                 tf.summary.scalar('MSE Loss', mse_loss, step=model.iterations)
                 tf.summary.scalar('Discriminator Loss', disc_loss, step=model.iterations)
-                tf.summary.image('Low Res', tf.cast(255 * (x + 1.0) / 2.0, tf.uint8), step=model.iterations)
+                tf.summary.image('Low Res', tf.cast(255 * x, tf.uint8), step=model.iterations)
                 tf.summary.image('High Res', tf.cast(255 * (y + 1.0) / 2.0, tf.uint8), step=model.iterations)
                 tf.summary.image('Generated', tf.cast(255 * (model.generator.predict(x) + 1.0) / 2.0, tf.uint8),
                                  step=model.iterations)
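The image summaries now denormalize from two different ranges: the low-res input lives in [0, 1] after the dataloader change, while the high-res target and the tanh generator output live in [-1, 1]. A small helper capturing that rule (a sketch, not part of the commit):

    def to_uint8(img, tanh_range):
        # Map a float image to uint8 for TensorBoard. tanh_range is True for
        # tensors in [-1, 1] (HR target, generator output) and False for the
        # LR input, which stays in [0, 1].
        if tanh_range:
            img = (img + 1.0) / 2.0
        return tf.cast(255 * img, tf.uint8)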
@@ -142,22 +140,21 @@ def main():
     # Create the tensorflow dataset.
     ds = DataLoader(args.image_dir, args.hr_size).dataset(args.batch_size)

-    with tf.device('GPU:1'):
-        # Initialize the GAN object.
-        gan = FastSRGAN(args)
+    # Initialize the GAN object.
+    gan = FastSRGAN(args)

-        # Define the directory for saving pretraining loss tensorboard summary.
-        pretrain_summary_writer = tf.summary.create_file_writer('logs/pretrain')
+    # Define the directory for saving pretraining loss tensorboard summary.
+    pretrain_summary_writer = tf.summary.create_file_writer('logs/pretrain')

-        # Run pre-training.
-        pretrain_generator(gan, ds, pretrain_summary_writer)
+    # Run pre-training.
+    pretrain_generator(gan, ds, pretrain_summary_writer)

-        # Define the directory for saving the SRGAN training tensorboard summary.
-        train_summary_writer = tf.summary.create_file_writer('logs/train')
+    # Define the directory for saving the SRGAN training tensorboard summary.
+    train_summary_writer = tf.summary.create_file_writer('logs/train')

-        # Run training.
-        for epoch in range(args.epochs):
-            train(gan, ds, args.save_iter, train_summary_writer)
+    # Run training.
+    for _ in range(args.epochs):
+        train(gan, ds, args.save_iter, train_summary_writer)


if __name__ == '__main__':
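With the hard-coded tf.device('GPU:1') scope removed, device placement falls back to TensorFlow's defaults. If pinning to a single GPU is still wanted, it can be done once at startup instead; a sketch using the TF 2.0 config API:

    # Optional: restrict TensorFlow to the first visible GPU (sketch only).
    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpus:
        tf.config.experimental.set_visible_devices(gpus[0], 'GPU')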
model.py (+27 −29)
@@ -22,7 +22,7 @@ def __init__(self, args):
         self.iterations = 0

         # Number of inverted residual blocks in the mobilenet generator
-        self.n_residual_blocks = 12
+        self.n_residual_blocks = 6

         # Define a learning rate decay schedule.
         self.gen_schedule = keras.optimizers.schedules.ExponentialDecay(
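The schedule's arguments fall outside this hunk; a representative configuration, with all values assumed rather than taken from this commit:

    self.gen_schedule = keras.optimizers.schedules.ExponentialDecay(
        args.lr,              # initial learning rate from --lr
        decay_steps=100000,   # assumed value, not from this diff
        decay_rate=0.5,       # assumed value, not from this diff
        staircase=True)
    self.gen_optimizer = keras.optimizers.Adam(learning_rate=self.gen_schedule)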
@@ -53,7 +53,7 @@ def __init__(self, args):

         # Number of filters in the first layer of G and D
         self.gf = 32  # Realtime Image Enhancement GAN Galteri et al.
-        self.df = 64
+        self.df = 32

         # Build and compile the discriminator
         self.discriminator = self.build_discriminator()
@@ -90,13 +90,13 @@ def build_generator(self):
         Based on the Mobilenet design. Idea from Galteri et al."""

         def _make_divisible(v, divisor, min_value=None):
-          if min_value is None:
-            min_value = divisor
-          new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
-          # Make sure that round down does not go down by more than 10%.
-          if new_v < 0.9 * v:
-            new_v += divisor
-          return new_v
+            if min_value is None:
+                min_value = divisor
+            new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+            # Make sure that round down does not go down by more than 10%.
+            if new_v < 0.9 * v:
+                new_v += divisor
+            return new_v

         def residual_block(inputs, filters, block_id, expansion=6, stride=1, alpha=1.0):
             """Inverted Residual block that uses depth wise convolutions for parameter efficiency.
@@ -164,15 +164,16 @@ def residual_block(inputs, filters, block_id, expansion=6, stride=1, alpha=1.0):
                 return keras.layers.Add(name=prefix + 'add')([inputs, x])
             return x

-        def deconv2d(layer_input):
+        def deconv2d(layer_input, filters):
             """Upsampling layer to increase height and width of the input.
             Uses PixelShuffle for upsampling.
             Args:
                 layer_input: The input tensor to upsample.
+                filters: Number of expansion filters.
             Returns:
                 u: Upsampled input by a factor of 2.
             """
-            u = keras.layers.Conv2D(self.gf * 2 ** 2, kernel_size=3, strides=1, padding='same')(layer_input)
+            u = keras.layers.Conv2D(filters, kernel_size=3, strides=1, padding='same')(layer_input)
             u = tf.nn.depth_to_space(u, 2)
             u = keras.layers.PReLU(shared_axes=[1, 2])(u)
             return u
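depth_to_space (pixel shuffle) trades channels for spatial resolution, so the preceding Conv2D must produce 4x the desired output channels for a 2x upscale. With filters = self.gf * 4 = 128, as passed at the call sites below, the shapes work out as:

    x = tf.zeros((1, 96, 96, 32))                       # some feature map
    u = keras.layers.Conv2D(128, 3, padding='same')(x)  # -> (1, 96, 96, 128)
    u = tf.nn.depth_to_space(u, 2)                      # -> (1, 192, 192, 32)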
@@ -181,7 +182,8 @@ def deconv2d(layer_input):
         img_lr = keras.Input(shape=self.lr_shape)

         # Pre-residual block
-        c1 = keras.layers.Conv2D(self.gf, kernel_size=9, strides=1, padding='same')(img_lr)
+        c1 = keras.layers.Conv2D(self.gf, kernel_size=3, strides=1, padding='same')(img_lr)
+        c1 = keras.layers.BatchNormalization()(c1)
         c1 = keras.layers.PReLU(shared_axes=[1, 2])(c1)

         # Propagate through residual blocks
@@ -191,15 +193,15 @@ def deconv2d(layer_input):

         # Post-residual block
         c2 = keras.layers.Conv2D(self.gf, kernel_size=3, strides=1, padding='same')(r)
-        c2 = keras.layers.BatchNormalization(momentum=0.8)(c2)
+        c2 = keras.layers.BatchNormalization()(c2)
         c2 = keras.layers.Add()([c2, c1])

         # Upsampling
-        u1 = deconv2d(c2)
-        u2 = deconv2d(u1)
+        u1 = deconv2d(c2, self.gf * 4)
+        u2 = deconv2d(u1, self.gf * 4)

         # Generate high resolution output
-        gen_hr = keras.layers.Conv2D(3, kernel_size=9, strides=1, padding='same', activation='tanh')(u2)
+        gen_hr = keras.layers.Conv2D(3, kernel_size=3, strides=1, padding='same', activation='tanh')(u2)

         return keras.models.Model(img_lr, gen_hr)
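Two pixel-shuffle stages fix the upscaling factor at 4x. Assuming the data loader feeds hr_size // 4 crops (96x96 under the new default of 384), a usage sketch:

    gan = FastSRGAN(args)                # hypothetical usage
    lr_batch = tf.zeros((1, 96, 96, 3))  # assumed LR crop size: hr_size // 4
    sr_batch = gan.generator(lr_batch)
    print(sr_batch.shape)                # (1, 384, 384, 3)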

@@ -226,17 +228,13 @@ def d_block(layer_input, filters, strides=1, bn=True):

 d1 = d_block(d0, self.df, bn=False)
 d2 = d_block(d1, self.df, strides=2)
-d3 = d_block(d2, self.df * 2)
-d4 = d_block(d3, self.df * 2, strides=2)
-d5 = d_block(d4, self.df * 4)
-d6 = d_block(d5, self.df * 4, strides=2)
-d7 = d_block(d6, self.df * 8)
-d8 = d_block(d7, self.df * 8, strides=2)
-
-d9 = keras.layers.Flatten()(d8)
-d10 = keras.layers.Dense(1024)(d9)
-d11 = keras.layers.LeakyReLU(alpha=0.2)(d10)
-
-validity = keras.layers.Dense(1, activation='sigmoid')(d11)
+d3 = d_block(d2, self.df)
+d4 = d_block(d3, self.df, strides=2)
+d5 = d_block(d4, self.df * 2)
+d6 = d_block(d5, self.df * 2, strides=2)
+d7 = d_block(d6, self.df * 2)
+d8 = d_block(d7, self.df * 2, strides=2)
+
+validity = keras.layers.Conv2D(1, kernel_size=1, strides=1, activation='sigmoid', padding='same')(d8)

 return keras.models.Model(d0, validity)
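Replacing Flatten -> Dense(1024) -> Dense(1) with a 1x1 convolution turns the discriminator into a fully convolutional PatchGAN: it emits one sigmoid validity score per receptive-field patch, accepts any input size, and drops the single largest weight matrix of the old model. A shape check under the new defaults:

    d_out = gan.discriminator(tf.zeros((1, 384, 384, 3)))  # hypothetical usage
    print(d_out.shape)  # (1, 24, 24, 1): four stride-2 blocks give 384 / 16 = 24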
BIN -153 KB (85%) models/generator.h5
Binary file not shown.
