<a href="https://colab.research.google.com/github/abhipavi/SRGAN/blob/main/SRGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# NNDL Project
Submitted by
Abhinav Pavithran (17EC202) and Nadeem Roshan (17EC226)

## **Utilities**

In [None]:
%tensorflow_version 1.x

In [None]:
import tensorflow as tf
import numpy as np
import pickle
import skimage.transform
import skimage.filters
import datetime
import os
import shutil
import math
from scipy import misc
import scipy.ndimage
import glob

def process_individual_image(filename_queue, img_size, random_crop=False):
  """Read one image file and crop it to a square RGB tensor.

  Args:
    filename_queue: scalar string tensor holding the path of one image (a single
      dequeued filename, despite the parameter name).
    img_size: side length of the square output.
    random_crop: if True (training), take a random img_size crop, zero-padding
      the bottom/right first when the image is smaller than img_size.
      Otherwise (testing), take a center crop, padding if needed.

  Returns:
    A tensor with static shape (img_size, img_size, 3).
  """
  raw_bytes = tf.read_file(filename_queue)
  decoded = tf.image.decode_image(raw_bytes, 3)
  if not random_crop:
    # Deterministic center crop (or pad) so evaluation is repeatable.
    cropped = tf.image.resize_image_with_crop_or_pad(decoded, img_size, img_size)
  else:
    # Pad up to at least img_size in each dimension so random_crop cannot fail
    # on images smaller than the crop window.
    dims = tf.shape(decoded)
    padded_h = tf.maximum(img_size, dims[0])
    padded_w = tf.maximum(img_size, dims[1])
    padded = tf.image.pad_to_bounding_box(decoded, 0, 0, padded_h, padded_w)
    cropped = tf.random_crop(padded, size=[img_size, img_size, 3])
  cropped.set_shape((img_size, img_size, 3))
  return cropped

def build_input_pipeline(filenames, batch_size, img_size, random_crop=False, shuffle=True, num_threads=1):
  """Create a queue-based pipeline that yields batches of cropped images.

  The returned tensor only produces data once the queue runners have been
  started (tf.train.start_queue_runners).

  Args:
    filenames: sequence of image paths to draw from.
    batch_size: number of images per batch.
    img_size: side length of each square crop.
    random_crop: forwarded to process_individual_image (random vs center crop).
    shuffle: whether the filename queue is shuffled.
    num_threads: threads used by tf.train.batch to fill the batch queue.

  Returns:
    A tensor evaluating to a (batch_size, img_size, img_size, 3) batch.
  """
  path_queue = tf.train.string_input_producer(tf.constant(filenames), shuffle=shuffle)
  single_image = process_individual_image(path_queue.dequeue(), img_size, random_crop)
  # Buffer up to ten batches so training rarely waits on file decoding.
  return tf.train.batch([single_image],
                        batch_size=batch_size,
                        num_threads=num_threads,
                        capacity=10 * batch_size)

def build_inputs(args, sess):
  """Build the training, validation and train-subset evaluation pipelines.

  Args:
    args: configuration object. Reads overfit, train_dir, batch_size and
      image_size; sets args.num_test = 1 in overfit mode.
    sess: unused; kept for interface compatibility.

  Returns:
    (get_train_batch, get_val_batch, get_eval_batch) batch tensors. Only the
    training pipeline uses random crops; the other two use center crops.
  """
  if not args.overfit:
    # Training images: every file (any extension) anywhere below train_dir.
    train_pattern = os.path.join(args.train_dir, '**', '*.*')
    train_filenames = np.array(glob.glob(train_pattern, recursive=True))
    # Validation images: the high-res ground truth of all benchmark sets.
    val_pattern = os.path.join('/content/gdrive/MyDrive/SRGAN/Benchmarks', '**', '*_HR.png')
    val_filenames = np.array(glob.glob(val_pattern, recursive=True))
    # Random training images (with replacement) for measuring training error;
    # capped at 119 entries.
    picks = np.random.randint(len(train_filenames), size=len(val_filenames))
    eval_filenames = train_filenames[picks[:119]]
  else:
    # Overfit mode: every pipeline serves the same single image.
    train_filenames = np.array(['overfit.png'])
    val_filenames = np.array(['overfit.png'])
    eval_filenames = np.array(['overfit.png'])
    args.num_test = 1

  get_train_batch = build_input_pipeline(train_filenames, batch_size=args.batch_size, img_size=args.image_size, random_crop=True)
  get_val_batch = build_input_pipeline(val_filenames, batch_size=args.batch_size, img_size=args.image_size)
  get_eval_batch = build_input_pipeline(eval_filenames, batch_size=args.batch_size, img_size=args.image_size)
  return get_train_batch, get_val_batch, get_eval_batch

def downsample(image, factor):
  """Shrink an image by `factor` with bicubic interpolation (meant to match Photoshop).

  Relies on scipy.misc.imresize, which is deprecated and absent from newer
  SciPy releases; the output is (int(H / factor), int(W / factor), ...).
  """
  scale = 1.0 / factor
  return scipy.misc.imresize(image, scale, interp='bicubic')
  
def downsample_batch(batch, factor):
  """Downsample every image of a 4-D RGB batch by `factor`.

  Args:
    batch: array indexed as (image, row, col, channel) with 3 channels.
    factor: integer shrink factor for both spatial dimensions.

  Returns:
    A float64 array of shape (N, H // factor, W // factor, 3).
  """
  count, height, width = batch.shape[0], batch.shape[1], batch.shape[2]
  result = np.zeros((count, height // factor, width // factor, 3))
  for idx, image in enumerate(batch):
    result[idx] = downsample(image, factor)
  return result

def build_log_dir(args, arguments):
  """Create (or reuse) the results directory for this training session.

  The directory is results/<args.name>, or results/<timestamp> when no name is
  set. os.path.join drops the 'results' prefix when args.name is an absolute
  path. A copy of the training sources and the command-line arguments are
  stored alongside the logs.

  The loss.csv header is written only when the file is new or empty. Resuming
  into an existing folder therefore does not append a second header row in
  the middle of the logged data.

  Args:
    args: configuration object; only args.name is read.
    arguments: sequence of command-line tokens, written space-separated to
      args.txt.

  Returns:
    The path of the log directory.
  """
  if args.name:
    run_name = args.name
  else:
    run_name = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
  log_path = os.path.join('results', run_name)
  os.makedirs(log_path, exist_ok=True)
  print('Logging results for this session in folder "%s".' % log_path)

  # Output csv header, but only once per file.
  loss_csv = os.path.join(log_path, 'loss.csv')
  if not os.path.isfile(loss_csv) or os.path.getsize(loss_csv) == 0:
    with open(loss_csv, 'a') as f:
      f.write('iteration, val_error, eval_error, set5_psnr, set5_ssim, set14_psnr, set14_ssim, bsd100_psnr, bsd100_ssim\n')

  # Copy this code to folder
  shutil.copy2('/content/gdrive/MyDrive/SRGAN/srgan.py', os.path.join(log_path, 'srgan.py'))
  shutil.copy2('/content/gdrive/MyDrive/SRGAN/train.py', os.path.join(log_path, 'train.py'))
  shutil.copy2('/content/gdrive/MyDrive/SRGAN/utilities.py', os.path.join(log_path, 'utilities.py'))
  # Write command line arguments to file
  with open(os.path.join(log_path, 'args.txt'), 'w+') as f:
    f.write(' '.join(arguments))
  return log_path

def preprocess(lr, hr):
  """Rescale a low-res / high-res pair of 0..255 images for the networks.

  Returns:
    (lr, hr) where lr is mapped to [0, 1] and hr to [-1, 1].
  """
  hr_unit = hr / 255.0
  return lr / 255.0, 2.0 * hr_unit - 1.0

def save_image(path, data, highres=False):
  """Write a float image to `path` as an 8-bit image file.

  Args:
    path: destination file path.
    data: image array with values in [0, 1], or in [-1, 1] when highres is set.
    highres: True for high-res data, which is first mapped from [-1, 1] to [0, 1].
  """
  unit_range = (data + 1.0) * 0.5 if highres else data
  # Scale to 0..255, clip out-of-range values, then quantize to uint8.
  pixels = np.clip(unit_range * 255.0, 0.0, 255.0).astype(np.uint8)
  misc.toimage(pixels, cmin=0, cmax=255).save(path)

def evaluate_model(loss_function, get_batch, sess, num_images, batch_size):
  """Average a loss tensor over ceil(num_images / batch_size) batches.

  Each high-res batch pulled from `get_batch` is downsampled 4x, preprocessed
  and fed to the graph's named placeholders with both training flags off.

  Returns:
    The mean of the per-batch loss values.

  Raises:
    ZeroDivisionError: if no batch is evaluated (num_images <= 0).
  """
  batches = range(int(math.ceil(num_images / batch_size)))
  running_loss = 0
  for _ in batches:
    hr = sess.run(get_batch)
    lr = downsample_batch(hr, factor=4)
    lr, hr = preprocess(lr, hr)
    feed = {'g_training:0': False, 'd_training:0': False,
            'input_lowres:0': lr, 'input_highres:0': hr}
    running_loss += sess.run(loss_function, feed_dict=feed)
  return running_loss / len(batches)


# VGG-19

In [None]:
import tensorflow as tf
import tensorflow.contrib.slim as slim

# VGG19 net
def vgg_19(inputs,
           num_classes=1000,
           is_training=False,
           dropout_keep_prob=0.5,
           spatial_squeeze=True,
           scope='vgg_19',
           reuse=False,
           fc_conv_padding='VALID'):
  """Convolutional trunk of VGG-19 (configuration E), without a classifier head.

  Only the 16 3x3 conv layers and 5 2x2 max-pool layers are built. num_classes,
  is_training, dropout_keep_prob, spatial_squeeze and fc_conv_padding are
  accepted for signature compatibility but ignored.

  Args:
    inputs: a tensor of size [batch_size, height, width, channels].
    scope: variable scope for all layers.
    reuse: whether to reuse variables previously created under `scope`.

  Returns:
    (net, end_points): net is the pool5 output; end_points is the dict built
    from the collected conv/pool outputs, keyed by layer name.
  """
  # (repetitions, filters, conv scope, pool scope) for each VGG-19 stage.
  stages = [
      (2, 64, 'conv1', 'pool1'),
      (2, 128, 'conv2', 'pool2'),
      (4, 256, 'conv3', 'pool3'),
      (4, 512, 'conv4', 'pool4'),
      (4, 512, 'conv5', 'pool5'),
  ]
  with tf.variable_scope(scope, 'vgg_19', [inputs], reuse=reuse) as sc:
    collection_name = sc.name + '_end_points'
    # Record every conv/pool output so intermediate feature maps can be looked up.
    with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d],
                        outputs_collections=collection_name):
      net = inputs
      for repeats, filters, conv_scope, pool_scope in stages:
        net = slim.repeat(net, repeats, slim.conv2d, filters, 3, scope=conv_scope, reuse=reuse)
        net = slim.max_pool2d(net, [2, 2], scope=pool_scope)
      end_points = slim.utils.convert_collection_to_dict(collection_name)
  return net, end_points

# **Benchmark**

In [None]:
import numpy as np
import glob
import os
from scipy import misc
from skimage.measure import compare_ssim
from skimage.color import rgb2ycbcr,rgb2yuv

from skimage.measure import compare_psnr

class Benchmark:
  """A named benchmark image set used to score a super-resolution model.

  Ground truth is every '<prefix>_HR.png' file directly inside `path`. The
  low-res model inputs are produced by 4x bicubic downsampling of those images.
  """

  def __init__(self, path, name):
    """Load the HR images under `path` and derive their 4x-downsampled inputs.

    Args:
      path: directory holding the '*_HR.png' files (not searched recursively).
      name: label used in printed results and as the folder for saved images.
    """
    self.path = path
    self.name = name
    self.images_hr, self.names = self.load_images_by_model(model='HR')
    # The LR inputs come from the same downsample() helper the training code
    # uses, rather than from separate '*_LR.png' files.
    self.images_lr = [downsample(img, 4) for img in self.images_hr]

  @staticmethod
  def _name_from_filename(filename, model, file_format='png'):
    """Return the prefix of a '<prefix>_<model>.<file_format>' path.

    For example '/.../baby_HR.png' -> 'baby' and '/.../img_001_HR.png' ->
    'img_001'. Only the known suffix is stripped, so prefixes containing
    underscores stay distinct. Other names fall back to the text before the
    first underscore.
    """
    base = os.path.basename(filename)
    suffix = '_' + model + '.' + file_format
    if base.endswith(suffix):
      return base[:-len(suffix)]
    return base.split('_')[0]

  def load_images_by_model(self, model, file_format='png'):
    """Load every '*_{model}.{file_format}' image in self.path.

    Returns:
      (images, names): the images and their filename prefixes, both ordered by
      sorted filename.
    """
    pattern = os.path.join(self.path, '*_' + model + '.' + file_format)
    filenames = sorted(glob.glob(pattern))
    names = [self._name_from_filename(f, model, file_format) for f in filenames]
    return self.load_images(filenames), names

  def load_images(self, images):
    """Read each file in `images` as an RGB uint8 array; returns a list."""
    return [misc.imread(image, mode='RGB').astype(np.uint8) for image in images]

  def deprocess(self, image):
    """Map model output from [-1, 1] floats to clipped 0..255 uint8 pixels."""
    image = np.clip(255 * 0.5 * (image + 1.0), 0.0, 255.0).astype(np.uint8)
    return image

  def luminance(self, image):
    """Return the Y channel of rgb2ycbcr(image) with a 4-pixel border removed."""
    lum = rgb2ycbcr(image)[:, :, 0]
    # Crop off 4 border pixels on every side before scoring.
    lum = lum[4:lum.shape[0] - 4, 4:lum.shape[1] - 4]
    return lum

  def PSNR(self, gt, pred):
    """Peak signal-to-noise ratio with a 0..255 data range."""
    return compare_psnr(gt, pred, data_range=255)

  def SSIM(self, gt, pred):
    """Structural similarity with a 0..255 data range and Gaussian weighting."""
    ssim = compare_ssim(gt, pred, data_range=255, gaussian_weights=True)
    return ssim

  def test_images(self, gt, pred):
    """Score predictions against ground truth on the cropped luminance channel.

    Images are paired by index (pred[i] with gt[i]).

    Returns:
      (avg_psnr, avg_ssim, individual_psnr, individual_ssim). Both averages
      are 0 when `pred` is empty.
    """
    individual_psnr = []
    individual_ssim = []
    for i, pred_img in enumerate(pred):
      # Compute each luminance map once and reuse it for both metrics.
      gt_lum = self.luminance(gt[i])
      pred_lum = self.luminance(pred_img)
      individual_psnr.append(self.PSNR(gt_lum, pred_lum))
      individual_ssim.append(self.SSIM(gt_lum, pred_lum))
    avg_psnr = 0
    avg_ssim = 0
    if individual_psnr:
      avg_psnr = sum(individual_psnr) / len(individual_psnr)
      avg_ssim = sum(individual_ssim) / len(individual_ssim)
    return avg_psnr, avg_ssim, individual_psnr, individual_ssim

  def validate(self):
    """Score the reference outputs of published models stored next to the HR files."""
    for model in ['bicubic', 'SRGAN-MSE', 'SRGAN-VGG22', 'SRGAN-VGG54', 'SRResNet-MSE', 'SRResNet-VGG22']:
      model_output, _ = self.load_images_by_model(model)
      psnr, ssim, _, _ = self.test_images(self.images_hr, model_output)
      print('Validate %-6s for %-14s: PSNR: %.2f, SSIM: %.4f' % (self.name, model, psnr, ssim))

  def save_image(self, image, path):
    """Write `image` to `path`, creating the parent directory if needed."""
    directory = os.path.dirname(path)
    if directory:
      os.makedirs(directory, exist_ok=True)
    misc.toimage(image, cmin=0, cmax=255).save(path)

  def save_images(self, images, log_path, iteration, max_images=3):
    """Save model outputs under log_path/<benchmark name>/<image name>/.

    Only the first `max_images` images are written, which keeps the log folder
    small for large sets such as BSD100. The HR and LR reference images are
    saved only when iteration < 1.
    """
    entries = zip(images, self.images_lr, self.images_hr, self.names)
    for count, (output, lr, hr, name) in enumerate(entries):
      if count >= max_images:
        break
      image_dir = os.path.join(log_path, self.name, name)
      # Save output
      self.save_image(output, os.path.join(image_dir, '%d_out.png' % iteration))
      if iteration < 1:
        # Save ground truth and low res once, at the first evaluation.
        self.save_image(hr, os.path.join(image_dir, '%d_hr.png' % iteration))
        self.save_image(lr, os.path.join(image_dir, '%d_lr.png' % iteration))

  def evaluate(self, sess, g_y_pred, log_path=None, iteration=0):
    """Run the generator on every LR image, optionally save outputs, and score them.

    Returns:
      The same tuple as test_images: (avg_psnr, avg_ssim, psnr_list, ssim_list).
    """
    pred = []
    for lr in self.images_lr:
      # Feed images one at a time (batch of 1) because they have different sizes.
      lr = lr / 255.0
      output = sess.run(g_y_pred, feed_dict={'d_training:0': False, 'g_training:0': False, 'input_lowres:0': lr[np.newaxis]})
      pred.append(self.deprocess(np.squeeze(output, axis=0)))
    if log_path:
      self.save_images(pred, log_path, iteration)
    return self.test_images(self.images_hr, pred)


# **SRGAN**

In [None]:
import tensorflow as tf

class SRGanGenerator:
  """SRGAN Generator Model from Ledig et. al. 2017

  Reference: https://arxiv.org/pdf/1609.04802.pdf
  """
  # Content losses understood by _content_loss.
  _CONTENT_LOSSES = ('mse', 'L1', 'vgg22', 'vgg54')

  def __init__(self, discriminator, training, content_loss='mse', use_gan=True, learning_rate=1e-4, num_blocks=16, num_upsamples=2):
    """Configure the generator; no graph is built until forward() is called.

    Args:
      discriminator: object whose forward(y) returns (probabilities, logits);
        used for the adversarial term when use_gan is True.
      training: value passed as `training` to every batch-normalization layer.
      content_loss: one of 'mse', 'L1', 'vgg22', 'vgg54'.
      use_gan: add the adversarial loss (weighted 1e-3) to the content loss.
      learning_rate: Adam learning rate used by optimize().
      num_blocks: number of residual blocks.
      num_upsamples: number of 2x sub-pixel upsample blocks.

    Raises:
      ValueError: if content_loss is not one of the supported names.
    """
    if content_loss not in self._CONTENT_LOSSES:
      # Raise instead of calling exit(), which would terminate the whole
      # interpreter/kernel; the message also lists every accepted value.
      raise ValueError("Invalid content loss function %r. Must be 'mse', 'L1', 'vgg22', or 'vgg54'." % (content_loss,))
    self.learning_rate = learning_rate
    self.num_blocks = num_blocks
    self.num_upsamples = num_upsamples
    self.use_gan = use_gan
    self.discriminator = discriminator
    self.training = training
    # The first vgg_forward call creates the VGG variables; later calls reuse them.
    self.reuse_vgg = False
    self.content_loss = content_loss

  def ResidualBlock(self, x, kernel_size, filters, strides=1):
    """Residual block a la ResNet: conv-BN-PReLU-conv-BN plus identity skip."""
    skip = x
    x = tf.layers.conv2d(x, kernel_size=kernel_size, filters=filters, strides=strides, padding='same', use_bias=False)
    x = tf.layers.batch_normalization(x, training=self.training)
    x = tf.contrib.keras.layers.PReLU(shared_axes=[1,2])(x)
    x = tf.layers.conv2d(x, kernel_size=kernel_size, filters=filters, strides=strides, padding='same', use_bias=False)
    x = tf.layers.batch_normalization(x, training=self.training)
    x = x + skip
    return x

  def Upsample2xBlock(self, x, kernel_size, filters, strides=1):
    """Upsample 2x via SubpixelConv (conv, then depth_to_space with block size 2)."""
    x = tf.layers.conv2d(x, kernel_size=kernel_size, filters=filters, strides=strides, padding='same')
    x = tf.depth_to_space(x, 2)
    x = tf.contrib.keras.layers.PReLU(shared_axes=[1,2])(x)
    return x

  def forward(self, x):
    """Builds the forward pass network graph.

    The output is spatially 2**num_upsamples times larger than `x` and has
    3 channels. Layer creation order is kept fixed so that existing checkpoints
    restore by variable name.
    """
    with tf.variable_scope('generator'):
      x = tf.layers.conv2d(x, kernel_size=9, filters=64, strides=1, padding='same')
      x = tf.contrib.keras.layers.PReLU(shared_axes=[1,2])(x)
      skip = x

      # B x ResidualBlocks
      for _ in range(self.num_blocks):
        x = self.ResidualBlock(x, kernel_size=3, filters=64, strides=1)

      x = tf.layers.conv2d(x, kernel_size=3, filters=64, strides=1, padding='same', use_bias=False)
      x = tf.layers.batch_normalization(x, training=self.training)
      x = x + skip

      # Upsample blocks
      for _ in range(self.num_upsamples):
        x = self.Upsample2xBlock(x, kernel_size=3, filters=256)

      x = tf.layers.conv2d(x, kernel_size=9, filters=3, strides=1, padding='same', name='forward')
      return x

  def vgg_forward(self, x, layer, scope):
    """Return the VGG-19 feature map `layer` for image batch `x`.

    x is assumed to be in [-1, 1] (the range preprocess() gives HR images). It
    is mapped to 0..255, mean-subtracted and flipped to BGR before VGG-19.
    """
    # move to range 0-255
    x = 255.0 * (0.5 * (x + 1.0))
    # subtract means
    mean = tf.constant([123.68, 116.779, 103.939], dtype=tf.float32, shape=[1, 1, 1, 3], name='img_mean') # RGB means from VGG paper
    x = x - mean
    # convert to BGR
    x = x[:,:,:,::-1]
    # send through vgg19
    _, layers = vgg_19(x, is_training=False, reuse=self.reuse_vgg)
    self.reuse_vgg = True
    return layers[scope + layer]

  def _content_loss(self, y, y_pred):
    """Content loss selected by self.content_loss.

    'mse' / 'L1': pixel-wise squared / absolute error.
    'vgg22': 0.006 * MSE of VGG conv2_2 features plus a 2e-8-weighted
    total-variation penalty on y_pred.
    'vgg54': 0.006 * MSE of VGG conv5_4 features.
    """
    if self.content_loss == 'mse':
      return tf.reduce_mean(tf.square(y - y_pred))
    if self.content_loss == 'L1':
      return tf.reduce_mean(tf.abs(y - y_pred))
    if self.content_loss == 'vgg22':
      with tf.name_scope('vgg19_1') as scope:
        vgg_y = self.vgg_forward(y, 'vgg_19/conv2/conv2_2', scope)
      with tf.name_scope('vgg19_2') as scope:
        vgg_y_pred = self.vgg_forward(y_pred, 'vgg_19/conv2/conv2_2', scope)
      return 0.006*tf.reduce_mean(tf.square(vgg_y - vgg_y_pred)) + 2e-8*tf.reduce_sum(tf.image.total_variation(y_pred))
    if self.content_loss == 'vgg54':
      with tf.name_scope('vgg19_1') as scope:
        vgg_y = self.vgg_forward(y, 'vgg_19/conv5/conv5_4', scope)
      with tf.name_scope('vgg19_2') as scope:
        vgg_y_pred = self.vgg_forward(y_pred, 'vgg_19/conv5/conv5_4', scope)
      return 0.006*tf.reduce_mean(tf.square(vgg_y - vgg_y_pred))

  def _adversarial_loss(self, y_pred):
    """Generator GAN loss: cross-entropy pushing D(y_pred) towards 'real' (1)."""
    y_discrim, y_discrim_logits = self.discriminator.forward(y_pred)
    return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=y_discrim_logits, labels=tf.ones_like(y_discrim_logits)))

  def loss_function(self, y, y_pred):
    """Content loss, plus 1e-3 * adversarial loss when use_gan is set."""
    if self.use_gan:
      # Weighted sum of content loss and adversarial loss
      return self._content_loss(y, y_pred) + 1e-3*self._adversarial_loss(y_pred)
    # Content loss only
    return self._content_loss(y, y_pred)

  def optimize(self, loss):
    """Adam step over the 'generator' variables only.

    Batch-norm update ops are made control dependencies so that their moving
    statistics are refreshed on every training step.
    """
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='generator')
    with tf.control_dependencies(update_ops):
      return tf.train.AdamOptimizer(self.learning_rate).minimize(loss, var_list=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator'))


class SRGanDiscriminator:
  """SRGAN Discriminator Model from Ledig et. al. 2017

  Reference: https://arxiv.org/pdf/1609.04802.pdf
  """
  def __init__(self, training, learning_rate=1e-4, image_size=96):
    """
    Args:
      training: value passed as `training` to the batch-normalization layers.
      learning_rate: Adam learning rate used by optimize().
      image_size: fixed input height/width (required by the dense layers).
    """
    self.training = training
    self.learning_rate = learning_rate
    self.image_size = image_size
    # Flipped after the first forward() so later calls share variables.
    self.graph_created = False

  def ConvolutionBlock(self, x, kernel_size, filters, strides):
    """Conv2D (no bias) followed by batch norm and LeakyReLU(0.2)."""
    h = tf.layers.conv2d(x, kernel_size=kernel_size, filters=filters, strides=strides, padding='same', use_bias=False)
    h = tf.layers.batch_normalization(h, training=self.training)
    return tf.contrib.keras.layers.LeakyReLU(alpha=0.2)(h)

  def forward(self, x):
    """Build (or reuse) the discriminator graph on `x`.

    Note: x.set_shape pins the caller's tensor to
    [None, image_size, image_size, 3].

    Returns:
      (sigmoid probabilities, logits), each with one value per image.
    """
    with tf.variable_scope('discriminator') as scope:
      if self.graph_created:
        scope.reuse_variables()
      self.graph_created = True

      # The dense layers need a fully known spatial size.
      x.set_shape([None, self.image_size, self.image_size, 3])

      h = tf.layers.conv2d(x, kernel_size=3, filters=64, strides=1, padding='same')
      h = tf.contrib.keras.layers.LeakyReLU(alpha=0.2)(h)

      # (filters, stride) of the seven conv blocks, in creation order.
      for filters, stride in ((64, 2), (128, 1), (128, 2), (256, 1), (256, 2), (512, 1), (512, 2)):
        h = self.ConvolutionBlock(h, 3, filters, stride)

      h = tf.contrib.layers.flatten(h)
      h = tf.layers.dense(h, 1024)
      h = tf.contrib.keras.layers.LeakyReLU(alpha=0.2)(h)
      logits = tf.layers.dense(h, 1)
      return tf.sigmoid(logits), logits

  def loss_function(self, y_real_pred, y_fake_pred, y_real_pred_logits, y_fake_pred_logits):
    """Cross-entropy loss: real logits towards 1 plus fake logits towards 0.

    Only the logits are used; the probability arguments are ignored.
    """
    real_term = tf.reduce_mean(tf.losses.sigmoid_cross_entropy(tf.ones_like(y_real_pred_logits), y_real_pred_logits))
    fake_term = tf.reduce_mean(tf.losses.sigmoid_cross_entropy(tf.zeros_like(y_fake_pred_logits), y_fake_pred_logits))
    return real_term + fake_term

  def optimize(self, loss):
    """Adam step over the 'discriminator' variables, running its BN updates first."""
    bn_updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='discriminator')
    trainable = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')
    with tf.control_dependencies(bn_updates):
      return tf.train.AdamOptimizer(self.learning_rate).minimize(loss, var_list=trainable)


# **Train**

In [None]:
# The scipy.misc image helpers used above (imread/imresize/toimage) are
# deprecated and missing from newer SciPy releases, so record the version.
installed_scipy = scipy.__version__
print(installed_scipy)

1.2.1


In [None]:
# Mount Google Drive; the data, benchmarks, VGG weights and checkpoints used
# below live under /content/gdrive/MyDrive/SRGAN.
from google.colab import drive
GDRIVE_MOUNT_POINT = '/content/gdrive'
drive.mount(GDRIVE_MOUNT_POINT)

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [None]:
class arguments:
  """Training configuration for this notebook (stands in for command-line flags)."""
  # --- checkpoints and logging ---
  load = '/content/gdrive/MyDrive/SRGAN/logfolder/weights-28000'  # full checkpoint; step parsed from the '-N' suffix
  load_gen = None  # optional generator-only checkpoint
  name = '/content/gdrive/MyDrive/SRGAN/logfolder'  # log folder handed to build_log_dir
  log_freq = 1000  # iterations between evaluation + checkpoint
  # --- data ---
  train_dir = '/content/gdrive/MyDrive/SRGAN/Train'
  image_size = 92
  batch_size = 16
  overfit = False
  validate_benchmarks = False
  # --- model and optimisation ---
  learning_rate = 1e-4
  content_loss = 'mse'
  use_gan = True
  vgg_weights = '/content/gdrive/MyDrive/SRGAN/vgg_19.ckpt'
  gpu = '0'


args = arguments()

In [None]:
  import tensorflow as tf
  from tensorflow.python.training import queue_runner
  import numpy as np
  import os
  import sys
  tf.reset_default_graph()
  # Use the configured GPU rather than a hard-coded device id.
  os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

  # Number of validation / train-subset images scored at every log step.
  NUM_EVAL_IMAGES = 119

  # Set up models. Each network's batch-norm layers are driven by its own flag:
  # d_training -> discriminator, g_training -> generator.
  d_training = tf.placeholder(tf.bool, name='d_training')
  g_training = tf.placeholder(tf.bool, name='g_training')
  discriminator = SRGanDiscriminator(training=d_training, image_size=args.image_size)
  generator = SRGanGenerator(discriminator=discriminator, training=g_training, learning_rate=args.learning_rate, content_loss=args.content_loss, use_gan=args.use_gan)
  # Super-resolution factor: each generator upsample block doubles the size.
  # Training LR inputs must be downsampled by this same factor (4x, matching
  # evaluate_model and Benchmark), otherwise the generator output cannot match
  # the HR target shape.
  scale = 2 ** generator.num_upsamples
  # Generator
  g_x = tf.placeholder(tf.float32, [None, None, None, 3], name='input_lowres')
  g_y = tf.placeholder(tf.float32, [None, None, None, 3], name='input_highres')
  g_y_pred = generator.forward(g_x)
  g_loss = generator.loss_function(g_y, g_y_pred)
  g_train_step = generator.optimize(g_loss)
  # Discriminator
  d_x_real = tf.placeholder(tf.float32, [None, None, None, 3], name='input_real')
  d_y_real_pred, d_y_real_pred_logits = discriminator.forward(d_x_real)
  d_y_fake_pred, d_y_fake_pred_logits = discriminator.forward(g_y_pred)
  d_loss = discriminator.loss_function(d_y_real_pred, d_y_fake_pred, d_y_real_pred_logits, d_y_fake_pred_logits)
  d_train_step = discriminator.optimize(d_loss)

  # Set up benchmarks
  benchmarks = [Benchmark('/content/gdrive/MyDrive/SRGAN/Benchmarks/BSD100', name='BSD100')]
  if args.validate_benchmarks:
    for benchmark in benchmarks:
      benchmark.validate()

  # Create log folder
  if args.load and not args.name:
    log_path = os.path.dirname(args.load)
  else:
    log_path = build_log_dir(args, sys.argv)

  with tf.Session() as sess:
    # Build input pipeline
    get_train_batch, get_val_batch, get_eval_batch = build_inputs(args, sess)
    # Initialize
    sess.run(tf.local_variables_initializer())
    sess.run(tf.global_variables_initializer())
    # Start input pipeline thread(s)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    def next_train_pair():
      """Fetch one training batch and return the preprocessed (lr, hr) arrays."""
      batch_hr = sess.run(get_train_batch)
      batch_lr = downsample_batch(batch_hr, factor=scale)
      return preprocess(batch_lr, batch_hr)

    try:
      # Load saved weights
      iteration = 0
      saver = tf.train.Saver()
      # Load generator
      if args.load_gen:
        gen_saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='generator'))
        iteration = int(args.load_gen.split('-')[-1])
        gen_saver.restore(sess, args.load_gen)
      # Load all
      if args.load:
        iteration = int(args.load.split('-')[-1])
        saver.restore(sess, args.load)
      # Load VGG
      if 'vgg' in args.content_loss:
        vgg_saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='vgg_19'))
        vgg_saver.restore(sess, args.vgg_weights)

      # Train until interrupted; the finally clause stops the input threads.
      while True:
        if iteration % args.log_freq == 0:
          # Test every log-freq iterations
          val_error = evaluate_model(g_loss, get_val_batch, sess, NUM_EVAL_IMAGES, args.batch_size)
          eval_error = evaluate_model(g_loss, get_eval_batch, sess, NUM_EVAL_IMAGES, args.batch_size)
          # Log error
          print('[%d] Test: %.7f, Train: %.7f' % (iteration, val_error, eval_error), end='')
          # Evaluate benchmarks
          log_line = ''
          for benchmark in benchmarks:
            psnr, ssim, _, _ = benchmark.evaluate(sess, g_y_pred, log_path, iteration)
            print(' [%s] PSNR: %.2f, SSIM: %.4f' %( benchmark.name, psnr, ssim), end='')
            log_line += ',%.7f, %.7f' %(psnr, ssim)
          print()
          # Write to log
          with open(log_path + '/loss.csv', 'a') as f:
            f.write('%d, %.15f, %.15f%s\n' % (iteration, val_error, eval_error, log_line))
          # Save checkpoint
          saver.save(sess, os.path.join(log_path, 'weights'), global_step=iteration, write_meta_graph=False)

        # Train discriminator
        if args.use_gan:
          batch_lr, batch_hr = next_train_pair()
          sess.run(d_train_step, feed_dict={d_training: True, g_training: True, g_x: batch_lr, g_y: batch_hr, d_x_real: batch_hr})
        # Train generator
        batch_lr, batch_hr = next_train_pair()
        sess.run(g_train_step, feed_dict={d_training: True, g_training: True, g_x: batch_lr, g_y: batch_hr})

        iteration += 1
        if iteration % 100 == 0:
          print("Iteration:", iteration)
    finally:
      # Stop queue threads (also on KeyboardInterrupt or a failed restore).
      coord.request_stop()
      coord.join(threads)

`imread` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``imageio.imread`` instead.
`imresize` is deprecated in SciPy 1.0.0, and will be removed in 1.3.0.
Use Pillow instead: ``numpy.array(Image.fromarray(arr).resize())``.


Logging results for this session in folder "/content/gdrive/MyDrive/SRGAN/logfolder".
INFO:tensorflow:Restoring parameters from /content/gdrive/MyDrive/SRGAN/logfolder/weights-28000
[28000] Test: 0.0263971, Train: 0.0266956

`toimage` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use Pillow's ``Image.fromarray`` directly instead.


 [BSD100] PSNR: 25.49, SSIM: 0.6543
Iteration: 28100
Iteration: 28200
Iteration: 28300
Iteration: 28400
Iteration: 28500
Iteration: 28600
Iteration: 28700
Iteration: 28800
Iteration: 28900
Iteration: 29000
[29000] Test: 0.0256794, Train: 0.0231129 [BSD100] PSNR: 26.36, SSIM: 0.6799
Iteration: 29100
Iteration: 29200
Iteration: 29300
Iteration: 29400
Iteration: 29500
Iteration: 29600
Iteration: 29700
Iteration: 29800
Iteration: 29900
Iteration: 30000
[30000] Test: 0.0244063, Train: 0.0252296 [BSD100] PSNR: 25.96, SSIM: 0.6643
Iteration: 30100
Iteration: 30200
Iteration: 30300
Iteration: 30400
Iteration: 30500
Iteration: 30600
Iteration: 30700
Iteration: 30800
Iteration: 30900
Iteration: 31000
[31000] Test: 0.0246546, Train: 0.0261465 [BSD100] PSNR: 25.79, SSIM: 0.6464
Iteration: 31100
Iteration: 31200
Iteration: 31300
Iteration: 31400
Iteration: 31500
Iteration: 31600
Iteration: 31700
Iteration: 31800
Iteration: 31900
Iteration: 32000
[32000] Test: 0.0282151, Train: 0.0274004 [BSD100] PS

KeyboardInterrupt: ignored