<a href="https://colab.research.google.com/github/artiumb/EDSR_iris/blob/master/EDSR_Iris.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
# based on https://github.com/Saafke/EDSR_Tensorflow
#  paper: http://openaccess.thecvf.com/content_cvpr_2017_workshops/w12/papers/Lim_Enhanced_Deep_Residual_CVPR_2017_paper.pdf

In [0]:
# data_utils.py 
import pathlib
import os
from PIL import Image
import numpy as np
import cv2
import tensorflow as tf
import random

def getpathsx(path):
    """
    Return every entry directly inside folder ``path`` as a string path.

    No filtering is applied: files of any type and sub-directories are all
    included, in whatever order ``pathlib.Path.glob`` yields them.
    """
    folder = pathlib.Path(path)
    return [str(entry) for entry in folder.glob('*')]

def getpaths(path):
    """
    Get all .png image paths from folder 'path' while avoiding ._ files.

    - The extension is matched on the file-name suffix, case-insensitively.
      Names like "x.png.bak" are rejected and "X.PNG" is accepted.
    - Only names that *start* with "._" (e.g. macOS AppleDouble side files)
      are skipped. A name merely containing "._" is kept.
    - Sub-directories are ignored.
    - Paths are returned sorted, so dataset order is reproducible between runs.

    Args:
        path: folder to scan (not recursive).

    Returns:
        List of ``os.path.join(path, name)`` strings.
    """
    return [
        os.path.join(path, name)
        for name in sorted(os.listdir(path))
        if name.lower().endswith('.png')
        and not name.startswith('._')
        and os.path.isfile(os.path.join(path, name))
    ]

def make_val_dataset(paths, scale, mean):
    """
    Yield one (low-res, high-res) pair per validation image.

    Written for ``tf.data.Dataset.from_generator``, which hands every path over
    as bytes (hence ``decode()``). The whole image is used, with no patching:

    1. Read the image and subtract ``mean``.
    2. Crop it so height and width are multiples of ``scale``; this is the ground truth.
    3. Bicubically shrink it by ``scale``; this is the network input.
    """
    for raw_path in paths:
        image = cv2.imread(raw_path.decode(), 3).astype(np.float32) - mean

        rows, cols = image.shape[0], image.shape[1]
        hr = image[:rows - rows % scale, :cols - cols % scale, :]

        # cv2 takes dsize as (width, height)
        lr_size = (int(hr.shape[1] / scale), int(hr.shape[0] / scale))
        lr = cv2.resize(hr, lr_size, interpolation=cv2.INTER_CUBIC)

        yield lr, hr

def make_dataset(paths, scale, mean):
    """
    Yield aligned training patch pairs: 48x48 low-res and (48*scale)^2 high-res.

    Intended for ``tf.data.Dataset.from_generator``, which passes paths as bytes.
    Each image is processed as follows:

    1. Read the image and subtract ``mean``.
    2. Flip it randomly.
    3. Crop it so both sides are multiples of ``scale``.
    4. Bicubically downscale it.
    5. Tile the low-res image into non-overlapping 48x48 patches. Incomplete
       tiles at the bottom/right edges are dropped.
    6. Pair every low-res patch with the matching high-res region.
    """
    patch_lr = 48
    patch_hr = patch_lr * scale

    for raw_path in paths:
        image = cv2.imread(raw_path.decode(), 3).astype(np.float32) - mean

        # randint is inclusive on both ends. The values -1, 0 and 1 are cv2 flip
        # codes (both axes / around x-axis / around y-axis); 2 means "no flip".
        flip_code = random.randint(-1, 2)
        if flip_code != 2:
            image = cv2.flip(image, flip_code)

        rows, cols = image.shape[0], image.shape[1]
        hr = image[:rows - rows % scale, :cols - cols % scale, :]
        lr = cv2.resize(hr, (int(hr.shape[1] / scale), int(hr.shape[0] / scale)),
                        interpolation=cv2.INTER_CUBIC)

        tiles_down = int(lr.shape[0] / patch_lr)
        tiles_across = int(lr.shape[1] / patch_lr)

        for row in range(tiles_down):
            r_lr, r_hr = row * patch_lr, row * patch_hr
            for col in range(tiles_across):
                c_lr, c_hr = col * patch_lr, col * patch_hr

                lr_patch = lr[r_lr:r_lr + patch_lr, c_lr:c_lr + patch_lr]
                hr_patch = hr[r_hr:r_hr + patch_hr, c_hr:c_hr + patch_hr]

                yield (lr_patch.reshape((patch_lr, patch_lr, 3)),
                       hr_patch.reshape((patch_hr, patch_hr, 3)))

def calcmean(imageFolder, bgr):
    """
    Calculate the per-channel mean pixel value of the .png images in a folder.

    Each image's channel means are computed first. The result is the
    unweighted average of those per-image means, so every image counts equally
    regardless of its size.

    Args:
        imageFolder: folder scanned with ``getpaths``.
        bgr: if truthy, return the mean in B, G, R order instead of R, G, B.

    Returns:
        numpy float64 array of shape (3,).

    Raises:
        ValueError: if the folder contains no matching images.
    """
    paths = getpaths(imageFolder)
    if not paths:
        raise ValueError("No .png images found in '{}'.".format(imageFolder))

    # Must be an array: with a plain list, `+=` would *extend* the list with
    # the new values instead of adding them element-wise.
    total_mean = np.zeros(3, dtype=np.float64)

    for im_counter, p in enumerate(paths):
        # Close the file handle; convert('RGB') makes grayscale/palette/RGBA
        # files yield exactly three channels too.
        with Image.open(p) as im:
            image = np.asarray(im.convert('RGB'))

        mean_rgb = np.mean(image, axis=(0, 1), dtype=np.float64)

        if im_counter % 50 == 0:
            print("Total mean: {} | current mean: {}".format(total_mean, mean_rgb))

        total_mean += mean_rgb

    total_mean /= len(paths)

    # rgb to bgr
    if bgr:
        total_mean = total_mean[::-1]

    return total_mean


In [0]:
# edsr.py
from __future__ import print_function

import cv2
import tensorflow as tf
import numpy as np
import os

class Edsr:
    """
    EDSR super-resolution network built from TensorFlow 1.x graph ops.

    All weights are created in ``__init__`` with explicit names
    (``resFilter%d``, ``resBias%d``, ``resFilter_one`` ... ``BiasThree``).
    Checkpoints therefore depend on those names staying unchanged.

    Args:
        B: number of residual blocks.
        F: number of feature maps in every hidden conv layer.
        scale: upscaling factor. The last conv outputs 3 * scale**2 channels,
            which depth_to_space rearranges into a 3-channel image that is
            ``scale`` times larger.
    """

    def __init__(self, B, F, scale):
        self.B = B
        self.F = F
        self.scale = scale
        # Fed by the training loop; drives the exponential learning-rate decay.
        self.global_step = tf.placeholder(tf.int32, shape=[], name="global_step")
        # Each residual branch is multiplied by this before being added back
        # (see resBlock).
        self.scaling_factor = 0.1
        self.bias_initializer = tf.constant_initializer(value=0.0)
        self.PS = 3 * (scale*scale) #channels x scale^2
        # tf.contrib exists only in TensorFlow 1.x.
        self.xavier = tf.contrib.layers.xavier_initializer()

        # -- Filters & Biases --
        # Two 3x3 FxF convs (and biases) per residual block.
        self.resFilters = list()
        self.resBiases = list()

        for i in range(0, B*2):
            self.resFilters.append( tf.get_variable("resFilter%d" % (i), shape=[3,3,F,F], initializer=self.xavier))
            self.resBiases.append(tf.get_variable(name="resBias%d" % (i), shape=[F], initializer=self.bias_initializer))

        # head conv (3 -> F), conv after the residual stack (F -> F),
        # and the sub-pixel conv (F -> 3 * scale^2)
        self.filter_one = tf.get_variable("resFilter_one", shape=[3,3,3,F], initializer=self.xavier)
        self.filter_two = tf.get_variable("resFilter_two", shape=[3,3,F,F], initializer=self.xavier)
        self.filter_three = tf.get_variable("resFilter_three", shape=[3,3,F,self.PS], initializer=self.xavier)

        self.bias_one = tf.get_variable(shape=[F], initializer=self.bias_initializer, name="BiasOne")
        self.bias_two = tf.get_variable(shape=[F], initializer=self.bias_initializer, name="BiasTwo")
        self.bias_three = tf.get_variable(shape=[self.PS], initializer=self.bias_initializer, name="BiasThree")


    def model(self, x, y, lr):
        """
        Implementation of EDSR: https://arxiv.org/abs/1707.02921.

        Builds the forward pass, the metrics and the training op.

        Args:
            x: low-resolution input batch, NHWC with 3 channels. It is presumably
               mean-subtracted, as the data pipeline in this project does.
            y: high-resolution ground-truth batch, NHWC, ``scale`` times larger.
            lr: initial learning rate (Python float).

        Returns:
            (out, loss, train_op, psnr, ssim, lr):
            - out: the NHWC prediction (op "NHWC_output"). An NCHW copy named
              "NCHW_output" is also created in the graph.
            - loss: the L1 loss.
            - train_op: one Adam step with global-norm gradient clipping at 5.0.
            - psnr, ssim: per-image metrics with max_val=255.
            - lr: the decayed learning-rate tensor.
        """

        # -- Model architecture --

        # first conv
        x = tf.nn.conv2d(x, filter=self.filter_one, strides=[1, 1, 1, 1], padding='SAME')
        x = x + self.bias_one
        out1 = tf.identity(x)  # kept for the long skip connection below

        # all residual blocks (each block uses filters 2*i and 2*i+1)
        for i in range(self.B):
            x = self.resBlock(x, (i*2))

        # last conv
        x = tf.nn.conv2d(x, filter=self.filter_two, strides=[1, 1, 1, 1], padding='SAME')
        x = x + self.bias_two
        x = x + out1

        # upsample via sub-pixel, equivalent to depth to space
        x = tf.nn.conv2d(x, filter=self.filter_three, strides=[1, 1, 1, 1], padding='SAME')
        x = x + self.bias_three
        out = tf.nn.depth_to_space(x, self.scale, data_format='NHWC', name="NHWC_output")

        # -- --

        # some outputs
        out_nchw = tf.transpose(out, [0, 3, 1, 2], name="NCHW_output")
        psnr = tf.image.psnr(out, y, max_val=255.0)
        # L1 loss. Keyword arguments make the ground truth the `labels` and the
        # network output the `predictions`, as the API defines them.
        loss = tf.losses.absolute_difference(labels=y, predictions=out)
        # NOTE(review): inputs are mean-subtracted while max_val=255.0 assumes
        # a [0, 255] range. PSNR only uses the difference, but SSIM's luminance
        # term is affected by the offset.
        ssim = tf.image.ssim(out, y, max_val=255.0)

        # (decaying) learning rate: multiplied by 0.95 every 15000 steps
        lr = tf.train.exponential_decay(lr,
                                        self.global_step,
                                        decay_steps=15000,
                                        decay_rate=0.95,
                                        staircase=True)
        # gradient clipping
        optimizer = tf.train.AdamOptimizer(lr)
        gradients, variables = zip(*optimizer.compute_gradients(loss))
        gradients, _ = tf.clip_by_global_norm(gradients, 5.0)
        train_op = optimizer.apply_gradients(zip(gradients, variables))

        return out, loss, train_op, psnr, ssim, lr

    def resBlock(self, inpt, f_nr):
        """
        One EDSR residual block: conv-ReLU-conv, scaled by ``scaling_factor``
        and added to the block input. No batch normalization is used.

        Args:
            inpt: NHWC feature map with F channels.
            f_nr: index of the first of the two filters/biases this block uses.
        """
        x = tf.nn.conv2d(inpt, filter=self.resFilters[f_nr], strides=[1, 1, 1, 1], padding='SAME')
        x = x + self.resBiases[f_nr]
        x = tf.nn.relu(x)

        x = tf.nn.conv2d(x, filter=self.resFilters[f_nr+1], strides=[1, 1, 1, 1], padding='SAME')
        x = x + self.resBiases[f_nr+1]
        x = x * self.scaling_factor

        return inpt + x

In [0]:
# main.py 
import tensorflow as tf
import data_utils
import run
import os
import cv2
import numpy as np
import pathlib
import argparse
from PIL import Image
import numpy
from tensorflow.python.client import device_lib

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' #gets rid of avx/fma warning

# TODO:
# When starting training for x3 and x4, start from pre-trained x2 model.

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # bools
    parser.add_argument('--train', help='Train the model', action="store_true")
    parser.add_argument('--test', help='Run PSNR test on an image', action="store_true")
    parser.add_argument('--upscale', help='Upscale an image with desired scale', action="store_true")
    parser.add_argument('--export', help='Export the model as .pb', action="store_true")
    # store_false: args.fromscratch is True by default ("load an existing
    # checkpoint") and becomes False when the flag is passed. It is forwarded
    # to run.run as its load_flag.
    parser.add_argument('--fromscratch', help='Train from scratch instead of loading a previous checkpoint', action="store_false")

    # numbers
    parser.add_argument('--quant', type=int, help='Quantize to shrink .pb file size. 1=round_weights. 2=quantize_weights. 3=round_weights&quantize.', default=0)
    parser.add_argument('--B', type=int, help='Number of resBlocks', default=32)
    parser.add_argument('--F', type=int, help='Number of filters', default=256)
    parser.add_argument('--scale', type=int, help='Scaling factor of the model', default=2)
    parser.add_argument('--batch', type=int, help='Batch size of the training', default=16)
    parser.add_argument('--epochs', type=int, help='Number of epochs during training', default=20)
    parser.add_argument('--lr', type=float, help='Learning_rate', default=0.0001)

    # paths
    parser.add_argument('--image', help='Specify test image', default="./images/original.png")
    parser.add_argument('--traindir', help='Path to train images')
    parser.add_argument('--validdir', help='Path to validation images')
    args = parser.parse_args()

    # Training reads both folders. Without them getpaths() would be given None.
    if args.train and (args.traindir is None or args.validdir is None):
        parser.error("--train requires both --traindir and --validdir.")

    # INIT
    scale = args.scale
    # Per-channel dataset mean, B/G/R order.
    meanbgr = [103.1545782, 111.561547, 114.35629928]

    # Set checkpoint paths for different scales and models
    ckpt_dirs = {
        2: "./CKPT_dir/x2/",
        3: "./CKPT_dir/x3/",
        4: "./CKPT_dir/x4/",
    }
    if scale not in ckpt_dirs:
        # SystemExit with a message prints it to stderr and exits with a
        # non-zero status, so calling scripts can detect the failure.
        raise SystemExit("No checkpoint directory. Choose scale 2, 3 or 4. Or add checkpoint directory for this scale.")
    ckpt_path = ckpt_dirs[scale]

    # Set gpu
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    # Create run instance. Named `runner` so the `run` module is not shadowed.
    runner = run.run(config, ckpt_path, scale, args.batch, args.epochs, args.B, args.F, args.lr, args.fromscratch, meanbgr)

    if args.train:
        runner.train(args.traindir, args.validdir)

    if args.test:
        # runner.test(args.image) would evaluate from the checkpoint instead of the .pb.
        runner.testFromPb(args.image)

    if args.upscale:
        # runner.upscale(args.image) would upscale from the checkpoint instead of the .pb.
        runner.upscaleFromPb(args.image)

    if args.export:
        runner.export(args.quant)

    print("I ran successfully.")
    

In [0]:
# run.py

import tensorflow as tf
import os
import cv2
import numpy as np
import math
import data_utils
from skimage import io
import edsr
from PIL import Image

from tensorflow.python.tools import freeze_graph
from tensorflow.python.tools import optimize_for_inference_lib
from tensorflow.tools.graph_transforms import TransformGraph

class run:
    """
    Driver for training, evaluating, running and exporting an EDSR model.

    Args:
        config: tf.ConfigProto used for the sessions created here.
        ckpt_path: directory holding the ``edsr_ckpt`` checkpoint files.
        scale: upscaling factor of the model.
        batch: training batch size.
        epochs: number of training epochs.
        B: number of residual blocks.
        F: number of filters per conv layer.
        lr: initial learning rate.
        load_flag: if True, ``train`` resumes from an existing checkpoint.
        meanBGR: per-channel mean. It is subtracted from every input image and
            added back to every model output.
    """

    def __init__(self, config, ckpt_path, scale, batch, epochs, B, F, lr, load_flag, meanBGR):
        self.config = config
        self.ckpt_path = ckpt_path
        self.scale = scale
        self.batch = batch
        self.epochs = epochs
        self.B = B
        self.F = F
        self.lr = lr
        self.load_flag = load_flag
        self.mean = meanBGR

    # -- private helpers shared by the public methods --

    def _ckpt_prefix(self):
        """Return the path prefix that checkpoints are saved to and restored from."""
        return os.path.join(self.ckpt_path, "edsr_ckpt")

    @staticmethod
    def _read_image(path):
        """Read ``path`` with cv2. Raise instead of passing on cv2's None result."""
        image = cv2.imread(path, 3)
        if image is None:
            raise IOError("Could not read image '{}'.".format(path))
        return image

    def _crop_and_downscale(self, fullimg):
        """
        Crop ``fullimg`` so both sides are multiples of ``scale``.

        Returns:
            (cropped ground truth, bicubically downscaled model input).
        """
        rows, cols = fullimg.shape[0], fullimg.shape[1]
        cropped = fullimg[0:(rows - (rows % self.scale)), 0:(cols - (cols % self.scale)), :]
        img = cv2.resize(cropped, None, fx=1. / self.scale, fy=1. / self.scale, interpolation=cv2.INTER_CUBIC)
        return cropped, img

    def _to_model_input(self, img):
        """Mean-subtract ``img`` and add a leading batch dimension of size 1."""
        floatimg = img.astype(np.float32) - self.mean
        return floatimg.reshape(1, floatimg.shape[0], floatimg.shape[1], 3)

    def _to_uint8_image(self, Y):
        """Add the mean back to one model output, clip it to [0, 255] and cast to uint8."""
        return (Y + self.mean).clip(min=0, max=255).astype(np.uint8)

    def _restore_io_tensors(self, sess):
        """
        Import the checkpoint meta graph into the current default graph and
        restore the latest weights into ``sess``.

        Returns:
            (input tensor, NHWC output tensor).
        """
        saver = tf.train.import_meta_graph(self._ckpt_prefix() + ".meta")
        saver.restore(sess, tf.train.latest_checkpoint(self.ckpt_path))
        graph = sess.graph
        return (graph.get_tensor_by_name("IteratorGetNext:0"),
                graph.get_tensor_by_name("NHWC_output:0"))

    def _show_upscaled(self, fullimg, HR_image):
        """Show the original, the EDSR result and a bicubic baseline until a key is pressed."""
        bicubic_image = cv2.resize(fullimg, None, fx=self.scale, fy=self.scale, interpolation=cv2.INTER_CUBIC)

        cv2.imshow('Original image', fullimg)
        cv2.imshow('EDSR upscaled image', HR_image)
        cv2.imshow('Bicubic upscaled image', bicubic_image)
        cv2.waitKey(0)
        cv2.destroyAllWindows()

    def _report_psnr(self, fullimg, cropped, img, Y, LR_input_, HR_image):
        """
        Print the PSNR of the EDSR and bicubic results against ``cropped``.
        Then show the images, write them to ./images/ and wait for a key press.

        NOTE: ./images/original.png, the default --image path, is overwritten
        with ``fullimg``.
        """
        bicubic_image = cv2.resize(img, None, fx=self.scale, fy=self.scale, interpolation=cv2.INTER_CUBIC)

        print(np.amax(Y), np.amax(LR_input_))
        print("PSNR of  EDSR   upscaled image: {}".format(self.psnr(cropped, HR_image)))
        print("PSNR of bicubic upscaled image: {}".format(self.psnr(cropped, bicubic_image)))

        cv2.imshow('Original image', fullimg)
        cv2.imshow('EDSR upscaled image', HR_image)
        cv2.imshow('Bicubic upscaled image', bicubic_image)

        # cv2.imwrite silently returns False when the folder is missing.
        if not os.path.exists("./images/"):
            os.makedirs("./images/")
        cv2.imwrite("./images/EdsrOutput.png", HR_image)
        cv2.imwrite("./images/BicubicOutput.png", bicubic_image)
        cv2.imwrite("./images/original.png", fullimg)
        cv2.imwrite("./images/input.png", img)

        cv2.waitKey(0)
        cv2.destroyAllWindows()

    # -- public API --

    def train(self, imagefolder, validfolder):
        """
        Train the model and report validation PSNR/SSIM after every epoch.

        Args:
            imagefolder: folder of training .png images. They are cut into
                patches by data_utils.make_dataset.
            validfolder: folder of validation .png images, used whole with batch size 1.

        Checkpoints are saved to ``<ckpt_path>/edsr_ckpt`` every 1000 steps and
        at the end of each epoch. When a checkpoint already exists and
        ``load_flag`` is true, its weights are restored before training.
        """
        # Create training dataset
        train_image_paths = data_utils.getpaths(imagefolder)
        train_dataset = tf.data.Dataset.from_generator(generator=data_utils.make_dataset,
                                                       output_types=(tf.float32, tf.float32),
                                                       output_shapes=(tf.TensorShape([None, None, 3]), tf.TensorShape([None, None, 3])),
                                                       args=[train_image_paths, self.scale, self.mean])
        train_dataset = train_dataset.padded_batch(self.batch, padded_shapes=([None, None, 3], [None, None, 3]))

        # Create validation dataset
        val_image_paths = data_utils.getpaths(validfolder)
        val_dataset = tf.data.Dataset.from_generator(generator=data_utils.make_val_dataset,
                                                     output_types=(tf.float32, tf.float32),
                                                     output_shapes=(tf.TensorShape([None, None, 3]), tf.TensorShape([None, None, 3])),
                                                     args=[val_image_paths, self.scale, self.mean])
        val_dataset = val_dataset.padded_batch(1, padded_shapes=([None, None, 3], [None, None, 3]))

        # One reinitializable iterator serves both datasets. The model reads
        # from a string-handle iterator over it. Its get_next() is the single
        # "IteratorGetNext" node that test/upscale/export feed by name.
        train_val_iterator = tf.data.Iterator.from_structure(train_dataset.output_types, train_dataset.output_shapes)
        train_initializer = train_val_iterator.make_initializer(train_dataset)
        val_initializer = train_val_iterator.make_initializer(val_dataset)

        handle = tf.placeholder(tf.string, shape=[])
        iterator = tf.data.Iterator.from_string_handle(handle, train_dataset.output_types, train_dataset.output_shapes)
        LR, HR = iterator.get_next()

        # Edsr model
        print("\nRunning EDSR.")
        edsrObj = edsr.Edsr(self.B, self.F, self.scale)
        _, loss, train_op, psnr, ssim, lr = edsrObj.model(x=LR, y=HR, lr=self.lr)

        # -- Training session
        with tf.Session(config=self.config) as sess:

            train_writer = tf.summary.FileWriter('./logs/train', sess.graph)
            sess.run(tf.global_variables_initializer())

            saver = tf.train.Saver()
            ckpt_prefix = self._ckpt_prefix()

            # Create the checkpoint directory if needed, otherwise optionally
            # resume from the checkpoint found there.
            if not os.path.exists(self.ckpt_path):
                os.makedirs(self.ckpt_path)
            elif os.path.isfile(ckpt_prefix + ".meta"):
                if self.load_flag:
                    saver.restore(sess, tf.train.latest_checkpoint(self.ckpt_path))
                    print("\nLoaded checkpoint.")
                else:
                    print("No checkpoint loaded. Training from scratch.")

            # Fed to Edsr.global_step to drive the learning-rate decay. It is a
            # Python counter, so it restarts at 0 on every call, even when
            # resuming from a checkpoint.
            global_step = 0

            train_val_handle = sess.run(train_val_iterator.string_handle())

            print("Training...")
            for e in range(1, self.epochs+1):

                sess.run(train_initializer)
                step, train_loss = 0, 0

                try:
                    while True:
                        # The prediction itself is not fetched; only the loss and LR are needed here.
                        batch_loss, _, l_rate = sess.run([loss, train_op, lr],
                                                         feed_dict={handle: train_val_handle,
                                                                    edsrObj.global_step: global_step})
                        train_loss += batch_loss
                        step += 1
                        global_step += 1

                        if step % 1000 == 0:
                            saver.save(sess, ckpt_prefix)
                            print("Step nr: [{}/{}] - Loss: {:.5f} - Lr: {:.7f}".format(step, "?", float(train_loss/step), l_rate))

                except tf.errors.OutOfRangeError:
                    pass  # end of the training data for this epoch

                # Perform end-of-epoch calculations here.
                sess.run(val_initializer)
                tot_val_psnr, tot_val_ssim, val_im_cntr = 0, 0, 0
                try:
                    while True:
                        val_psnr, val_ssim = sess.run([psnr, ssim], feed_dict={handle: train_val_handle})

                        tot_val_psnr += val_psnr[0]
                        tot_val_ssim += val_ssim[0]
                        val_im_cntr += 1

                except tf.errors.OutOfRangeError:
                    pass  # end of the validation data

                # Report NaN instead of crashing when a dataset produced nothing.
                mean_loss = float(train_loss / step) if step else float("nan")
                mean_psnr = (tot_val_psnr / val_im_cntr) if val_im_cntr else float("nan")
                mean_ssim = (tot_val_ssim / val_im_cntr) if val_im_cntr else float("nan")

                print("Epoch nr: [{}/{}]  - Loss: {:.5f} - val PSNR: {:.3f} - val SSIM: {:.3f}\n".format(e,
                                                                                                         self.epochs,
                                                                                                         mean_loss,
                                                                                                         mean_psnr,
                                                                                                         mean_ssim))
                saver.save(sess, ckpt_prefix)

            print("Training finished.")
            train_writer.close()

    def upscale(self, path):
        """
        Upscales an image via model. This loads a checkpoint, not a .pb file.
        """
        fullimg = self._read_image(path)
        LR_input_ = self._to_model_input(fullimg)

        # A fresh graph avoids name clashes with a graph built earlier in the
        # same process (e.g. by train()).
        graph = tf.Graph()
        with graph.as_default():
            with tf.Session(config=self.config, graph=graph) as sess:
                print("\nUpscale image by a factor of {}:\n".format(self.scale))
                LR_tensor, HR_tensor = self._restore_io_tensors(sess)
                output = sess.run(HR_tensor, feed_dict={LR_tensor: LR_input_})

        HR_image = self._to_uint8_image(output[0])
        self._show_upscaled(fullimg, HR_image)

    def test(self, path):
        """
        Test single image and calculate psnr. This loads a checkpoint, not a .pb file.

        The image is cropped and downscaled by ``scale``. The model's upscaled
        result and a bicubic baseline are then compared with the cropped
        original.
        """
        fullimg = self._read_image(path)
        cropped, img = self._crop_and_downscale(fullimg)
        LR_input_ = self._to_model_input(img)

        graph = tf.Graph()
        with graph.as_default():
            with tf.Session(config=self.config, graph=graph) as sess:
                print("\nTest model with psnr:\n")
                LR_tensor, HR_tensor = self._restore_io_tensors(sess)
                output = sess.run(HR_tensor, feed_dict={LR_tensor: LR_input_})

        Y = output[0]
        HR_image = self._to_uint8_image(Y)
        self._report_psnr(fullimg, cropped, img, Y, LR_input_, HR_image)

    def load_pb(self, path_to_pb):
        """Load a frozen GraphDef from ``path_to_pb`` into a new tf.Graph and return it."""
        with tf.gfile.GFile(path_to_pb, "rb") as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
        with tf.Graph().as_default() as graph:
            tf.import_graph_def(graph_def, name='')
            return graph

    def testFromPb(self, path):
        """
        Test single image and calculate psnr. This loads a .pb file
        (./models/EDSR_x{scale}.pb).
        """
        # Read model
        pbPath = "./models/EDSR_x{}.pb".format(self.scale)

        # Get graph
        graph = self.load_pb(pbPath)

        fullimg = self._read_image(path)
        cropped, img = self._crop_and_downscale(fullimg)
        LR_input_ = self._to_model_input(img)

        LR_tensor = graph.get_tensor_by_name("IteratorGetNext:0")
        HR_tensor = graph.get_tensor_by_name("NHWC_output:0")

        with tf.Session(config=self.config, graph=graph) as sess:
            print("Loading pb...")
            output = sess.run(HR_tensor, feed_dict={LR_tensor: LR_input_})

        Y = output[0]
        HR_image = self._to_uint8_image(Y)
        self._report_psnr(fullimg, cropped, img, Y, LR_input_, HR_image)
        print("Done.")

    def upscaleFromPb(self, path):
        """
        Upscale single image by desired model. This loads a .pb file
        (./models/EDSR_x{scale}.pb).
        """
        # Read model
        pbPath = "./models/EDSR_x{}.pb".format(self.scale)

        # Get graph
        graph = self.load_pb(pbPath)

        fullimg = self._read_image(path)
        LR_input_ = self._to_model_input(fullimg)

        LR_tensor = graph.get_tensor_by_name("IteratorGetNext:0")
        HR_tensor = graph.get_tensor_by_name("NHWC_output:0")

        with tf.Session(config=self.config, graph=graph) as sess:
            print("Loading pb...")
            output = sess.run(HR_tensor, feed_dict={LR_tensor: LR_input_})

        HR_image = self._to_uint8_image(output[0])
        self._show_upscaled(fullimg, HR_image)

    def export(self, quant):
        """
        Freeze the latest checkpoint into a .pb file under ./models/.

        Args:
            quant: selects the quantization and the output file name.
                - 0: no quantization, writes EDSRorig_x{scale}.pb.
                - 1: round_weights, writes EDSR_x{scale}_q1.pb.
                - 2: quantize_weights, writes EDSR_x{scale}.pb. This is the
                  file testFromPb and upscaleFromPb load.
                - 3: round_weights + quantize_weights, writes EDSR_x{scale}_q3.pb.
                Any other value behaves like 0.

        The final GraphDef is also written as text to ./train.pbtxt.
        """
        print("Exporting model...")

        export_dir = "./models/"
        if not os.path.exists(export_dir):
            os.makedirs(export_dir)

        # quant -> (message, extra graph transforms, output file name pattern)
        quant_options = {
            1: ("Rounding weights for export.", ["round_weights"], "EDSR_x{}_q1.pb"),
            2: ("Quantizing for export.", ["quantize_weights"], "EDSR_x{}.pb"),
            3: ("Round weights and quantizing for export.",
                ["round_weights", "quantize_weights"], "EDSR_x{}_q3.pb"),
        }

        # Import into a fresh graph so that a graph built earlier in the same
        # process (e.g. by train()) cannot clash with the imported node names.
        graph = tf.Graph()
        with graph.as_default():
            with tf.Session(config=self.config, graph=graph) as sess:
                ### Restore checkpoint
                saver = tf.train.import_meta_graph(self._ckpt_prefix() + ".meta")
                saver.restore(sess, tf.train.latest_checkpoint(self.ckpt_path))

                # All variables to constants
                graph_def = sess.graph.as_graph_def()
                graph_def = tf.graph_util.convert_variables_to_constants(sess, graph_def, ['NCHW_output'])

        # Optimize for inference
        graph_def = optimize_for_inference_lib.optimize_for_inference(graph_def, ["IteratorGetNext"],
                                                                      ["NCHW_output"],
                                                                      tf.float32.as_datatype_enum)

        # Implement certain file shrinking transforms. 2 is recommended.
        transforms = ["sort_by_execution_order"]
        export_file = "EDSRorig_x{}.pb".format(self.scale)
        if quant in quant_options:
            message, extra_transforms, file_pattern = quant_options[quant]
            print(message)
            transforms = transforms + extra_transforms
            export_file = file_pattern.format(self.scale)

        graph_def = TransformGraph(graph_def, ["IteratorGetNext"],
                                   ["NCHW_output"],
                                   transforms)

        print("Exported file = {}".format(export_dir+export_file))
        with tf.gfile.GFile(export_dir + export_file, 'wb') as f:
            f.write(graph_def.SerializeToString())

        tf.train.write_graph(graph_def, ".", 'train.pbtxt')

    def psnr(self, img1, img2):
        """
        Peak signal-to-noise ratio (dB) between two images with a 255 peak value.

        Both inputs are converted to float64 before subtracting. uint8 arrays
        (what cv2 returns) would otherwise wrap around modulo 256 and corrupt
        the MSE.

        Returns:
            The PSNR in dB, or 100 when the images are identical.
        """
        diff = np.asarray(img1, dtype=np.float64) - np.asarray(img2, dtype=np.float64)
        mse = np.mean(diff ** 2)
        if mse == 0:
            return 100
        PIXEL_MAX = 255.0
        return (20 * math.log10(PIXEL_MAX / math.sqrt(mse)))