## superresolution

### options


In [0]:
#global options for the whole notebook; every later cell reads these names
normalize = True #if set to True, input and output images will be normalized to [0, 1]
separator_epochs = 50 #number of outer training rounds for the color separation model
epochs = 1000 #Number of epochs to train
#number of fit iterations per outer epoch; the logging callbacks and the
#generator training loops read this name before the pruning section assigns it
#NOTE(review): 100 mirrors the value used in the pruning section - confirm
steps_per_epoch = 100
scale = 4 #How much should we upscale images
channels = 3 #channels of low resolution image
batch_size = 8 #what batch-size should we use (decrease if you encounter video memory errors)
height_lr = 128 #height of low resolution image
width_lr = height_lr #width of low resolution image
learning_rate = 0.0001 #learning rate
#learning rate handed to Adam when the generator variation and the pruned model are compiled
#NOTE(review): this name was used but never assigned; reusing learning_rate is an assumption - confirm
gen_lr = learning_rate
logging_steps = epochs // 20 #how often to update the training log
discriminator_weight_file = 'sr_4x_discriminator_weights' #name of weight file
generator_weight_file = 'sr_4x_generator_weights' #name of weight file
separator_weight_file = 'sr_4x_separator_weights' #name of weight file
pruning_weight_file = 'sr_4x_pruning_weights' #name of weight file
filters = 16 #width of network
kernelSize = 3 #kernel size of convolution

### imports

In [0]:
import os
%cd /content
!git clone https://github.com/BenjaminWegener/superresolution #download Dataset
%cd superresolution
#%tensorflow_version 2.x
import tensorflow as tf
!pip install tensorflow-model-optimization
import tensorflow_model_optimization as tfmot
print(tf.__version__)
import numpy as np
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import *
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import backend as K
from tensorflow.keras.callbacks import LambdaCallback
from tensorflow.keras.applications.vgg16 import VGG16
from IPython.display import clear_output
import matplotlib.pyplot as plt
%matplotlib inline
from keras.preprocessing.image import ImageDataGenerator
from math import isnan
import random
import dill
import time
from PIL import Image
from skimage.transform import resize
from contextlib import contextmanager
import sys
@contextmanager
def silence_stdout():
    """Temporarily send everything written to ``sys.stdout`` to the null device.

    Yields the open null-device file object. On exit (including when the body
    raises) the previous ``sys.stdout`` is restored and the null-device handle
    is closed, so repeated use does not leak file descriptors.
    """
    old_target = sys.stdout
    with open(os.devnull, "w") as new_target:
        sys.stdout = new_target
        try:
            yield new_target
        finally:
            #restore first, so nothing is written to the handle once it is closed
            sys.stdout = old_target

### functions for image visualization

In [0]:
def show(tensors):
    """Plot the given images side by side in one matplotlib figure.

    tensors: sequence of image arrays. A leading axis of length 1 and a
        trailing (third) axis of length 1 are dropped before plotting.
        Arrays that still have a third axis are drawn with matplotlib's
        default colour handling, all others with the 'gray' colormap.

    When the global `normalize` flag is set the values are multiplied by 255;
    in every case they are clipped to [0, 255] and cast to uint8 for display.
    The caller's sequence is left untouched (only local copies are modified).
    """
    plt.rcParams['figure.figsize'] = [20, 10]
    fig = plt.figure()
    count = len(tensors)
    for position, tensor in enumerate(tensors, start = 1):
        image = np.asarray(tensor)
        #explicit shape checks instead of try/except around np.squeeze
        if image.ndim > 0 and image.shape[0] == 1:
            image = np.squeeze(image, axis = 0)
        if image.ndim > 2 and image.shape[2] == 1:
            image = np.squeeze(image, axis = 2)
        cmap = None if image.ndim > 2 else 'gray'
        if normalize:
            image = image * 255
        image = np.clip(image, 0, 255)
        fig.add_subplot(1, count, position)
        plt.imshow(image.astype(np.uint8), cmap = cmap, interpolation = 'spline36')
    plt.setp(fig.get_axes(), xticks=[], yticks=[])
    plt.show()

In [0]:
# show image in actual size https://stackoverflow.com/a/42314798/
def display_image_in_actual_size(im_data):
    """Draw an image so that one array element maps to one figure pixel at 100 dpi.

    im_data: image array; a leading axis of length 1 is dropped. Values are
        clipped to [0, 255] and cast to uint8 (no rescaling is applied here,
        so callers must pass data in the 0..255 range).

    Prints the shape of the array that is drawn.
    """
    im_data = np.asarray(im_data)
    #explicit check instead of a bare except around np.squeeze
    if im_data.ndim > 0 and im_data.shape[0] == 1:
        im_data = np.squeeze(im_data, axis = 0)
    im_data = np.clip(im_data, 0, 255)
    dpi = 100
    print(im_data.shape)
    height, width = im_data.shape[0], im_data.shape[1]
    #figure size in inches so that the figure is exactly width x height pixels
    figsize = width / float(dpi), height / float(dpi)
    fig = plt.figure(figsize=figsize)
    #axes fill the whole figure, without ticks or frame
    ax = fig.add_axes([0, 0, 1, 1])
    ax.axis('off')
    ax.imshow(im_data.astype(np.uint8), interpolation = 'spline36')

### dataset function

In [0]:
# return batch of augmented train and target images with quantity n_samples
def get_batch_separator(n_samples, height, width):
    """Return one randomly augmented training batch for the separator model.

    n_samples: requested number of images (used as the generator batch size).
    height, width: size of the low resolution images.

    Returns (x_train, y_train, z_train):
        y_train: augmented images read from the current directory at
            (height * scale, width * scale).
        x_train: y_train resized down to (height, width) with order-5
            interpolation.
        z_train: per-pixel sum of the first three channels of y_train, each
            weighted with 0.33 (single-channel target without a channel axis).
    When the global `normalize` flag is set, x_train and y_train are divided
    by 255 before z_train is computed.

    The batch can hold fewer than n_samples images when the directory does not
    provide that many; all three arrays then have that smaller length.
    """
    # define a ImageGenerator instance from keras with augmentations
    image_gen = ImageDataGenerator(rotation_range = 359,
                           width_shift_range = 2,
                           height_shift_range = 2,
                           zoom_range = [0.25, 0.8],
                           shear_range = 0.1,
                           horizontal_flip = True,
                           vertical_flip = True,
                           fill_mode = 'reflect',
                           data_format = 'channels_last',
                           interpolation_order = 5,
                           brightness_range = [0.5, 1.5])
    #seed for random augmentations
    random_seed = int(random.random() * 100000)
    #generate augmented images (stdout is silenced while the directory is scanned)
    with silence_stdout():
        batches = image_gen.flow_from_directory('.',
                                                target_size = (height * scale, width * scale),
                                                batch_size = n_samples,
                                                class_mode = None,
                                                interpolation = 'lanczos',
                                                seed = random_seed)
        y_train = batches[0].copy() #fix for 'array doesn't own its data'
    n_images = len(y_train)
    x_train = np.empty((n_images, height, width, 3))
    #iterate over the images actually delivered; the original looped over
    #n_samples and raised IndexError when fewer images were available
    for i in range(n_images):
        #random_zoom = random.random() * 2.5 + 2.5 #random blur/zoom between 2.5 and 5
        #dummy = resize(x_train[i], (height // random_zoom, width // random_zoom, 3))
        #x_train[i] = resize(dummy, (height, width, 3))
        x_train[i] = resize(y_train[i], (height, width, 3), order = 5)
    if normalize:
        x_train = x_train / 255
        y_train = y_train / 255
    z_train = np.dot(y_train[...,:3], [0.33, 0.33, 0.33])
    return x_train, y_train, z_train

### initialize TPU backend

In [0]:
#locate the TPU attached to this runtime and print its worker addresses
tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
print('Running on TPU ', tpu.cluster_spec().as_dict()['worker'])

#connect to the TPU cluster first, then initialize the TPU system
tf.config.experimental_connect_to_cluster(tpu)
tf.tpu.experimental.initialize_tpu_system(tpu)

#distribution strategy; later cells build and compile their models inside strategy.scope()
strategy = tf.distribute.experimental.TPUStrategy(tpu)
print("REPLICAS: ", strategy.num_replicas_in_sync)

### combined perceptual and L1 loss

In [0]:
from tensorflow.keras.applications.vgg16 import VGG16
#VGG16 convolutional base (no classifier head) with ImageNet weights; a later cell
#cuts it at 'block2_conv2' to obtain the feature extractor used by multi_loss
#NOTE(review): input_shape is the low resolution size, while the generator output is
#scale times larger in both dimensions - confirm the loss model accepts that size
perceptual_model = VGG16(include_top=False, weights='imagenet', input_shape=(height_lr, width_lr, 3)) 
perceptual_model.summary()

def multi_loss(y_true, y_pred): #combination of different losses
    """Combined perceptual and mean-absolute-error loss.

    Both images are passed through the global `loss_model`; with
    P = mean(|features(y_pred) - features(y_true)|) and
    L = mean(|y_true - y_pred|) the result is (1 + P) * (1 + L) - 1,
    clipped to the range [0.00001, 999].

    y_true, y_pred: batches of target and predicted images.
    Returns the scalar loss tensor.
    """
    true = loss_model(y_true)
    pred = loss_model(y_pred)

    #perceptual term (the original line was missing its closing parenthesis)
    loss = 1 + K.mean(K.abs(pred - true))

    loss = loss * (K.mean(K.abs(y_true - y_pred)) + 1) # mean absolute error loss

    loss = loss - 1
    return K.clip(loss, 0.00001, 999)

### build color separation model

In [0]:
def build_separator():
  """Build the separator model.

  A 1x1 separable convolution ('color_sep') reduces the input to one channel;
  the difference between the input and that channel forms the second branch.
  The one-channel branch goes through a 16-filter 3x3 separable convolution
  ('gray_filters') and a stride-4 transposed convolution ('gray_upsample');
  the difference branch goes through a stride-4 transposed convolution
  ('color_upsample').

  Returns a Model with two outputs: the sum of both upsampled branches and
  the upsampled one-channel branch.
  """
  lr_image = Input(shape = (height_lr, width_lr, channels))
  single_channel = SeparableConv2D(1, 1, padding = 'same', name = 'color_sep')(lr_image)
  remainder = lr_image - single_channel

  features = SeparableConv2D(16, kernel_size = 3, padding = 'same', activation = 'relu', name = 'gray_filters')(single_channel)
  #gray = upsample4xGray(gray) #wait for tf2.3-rc0 release
  gray_up = Conv2DTranspose(1, kernel_size = 4, strides = 4, padding = 'same', name = 'gray_upsample')(features)

  #color = UpSampling2D(size = 4, interpolation = 'bilinear')(color) #wait for tf2.3-rc0 release
  color_up = Conv2DTranspose(3, kernel_size = 4, strides = 4, padding = 'same', name = 'color_upsample')(remainder)
  combined = Add()([color_up, gray_up])
  return Model(inputs = lr_image, outputs = [combined, gray_up])

### define callback

In [0]:
def logging_separator(epoch, logs, log_every = 200):
  """End-of-epoch hook for separator training: report progress and preview results.

  epoch: 0-based epoch index passed by Keras.
  logs: metrics dict passed by Keras; logs['loss'] is printed.
  log_every: report whenever the 1-based epoch number is a multiple of this.

  The original test `epoch % 200 == 0 and epoch > 0` could never be true for
  a fit() call with 200 epochs (indices 0..199), so nothing was ever logged;
  the 1-based epoch number is tested instead.

  On a reporting step the separator is converted to a TFLite model, written
  to 'superresolution.tflite', run on one fresh sample and the input, targets
  and predictions are plotted with show(). Reads the globals `real_epoch`,
  `separator_epochs` and `separator`; updates the global `this_time`.
  """
  global this_time
  step = epoch + 1
  if step % log_every != 0:
    return
  last_time = this_time
  this_time = time.time()

  clear_output()
  print('epoch', real_epoch + 1, '/', separator_epochs, '--> step', step
        , '| loss:', logs['loss'], '| time taken:', this_time - last_time
       )
  TFLITE_MODEL = "superresolution.tflite"
  run_model = tf.function(lambda x : separator(x))
  concrete_func = run_model.get_concrete_function(tf.TensorSpec(separator.inputs[0].shape, separator.inputs[0].dtype))
  converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
  #converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
  converted_tflite_model = converter.convert()
  #context manager so the model file is flushed and closed before it is read back
  with open(TFLITE_MODEL, "wb") as tflite_file:
    tflite_file.write(converted_tflite_model)
  tflite_separator_interpreter = tf.lite.Interpreter(model_path=TFLITE_MODEL)
  testX, testY, testZ = get_batch_separator(1, height_lr, width_lr)
  input_details = tflite_separator_interpreter.get_input_details()
  output_details = tflite_separator_interpreter.get_output_details()

  tflite_separator_interpreter.allocate_tensors()
  tflite_separator_interpreter.set_tensor(input_details[0]['index'], testX.astype(np.float32))
  tflite_separator_interpreter.invoke()
  #NOTE(review): assumes output_details keeps the model's output order
  #(combined image first, single-channel image second) - confirm
  predY = tflite_separator_interpreter.get_tensor(output_details[0]['index'])
  predZ = tflite_separator_interpreter.get_tensor(output_details[1]['index'])

  show([testX[0], testY[0], predY[0], testZ[0], predZ[0]])

#Keras calls the hook with (epoch, logs) at the end of every epoch
separator_logging_callback = LambdaCallback(on_epoch_end = logging_separator)


### compile separator

In [0]:
#build and compile the separator inside the TPU strategy scope;
#the single 'mae' loss is applied to both model outputs
with strategy.scope():
  separator = build_separator()
  separator.compile(optimizer = Adam(learning_rate), loss = 'mae')
  separator.summary()

### train separator

In [0]:
#reference time for the "time taken" value printed by the logging callback
this_time = time.time()
print('trying to load last saved weights...', end = ' ') 
try:
  #NOTE(review): dill.load can execute arbitrary code - only load weight files you created yourself
  with open(separator_weight_file, 'rb') as file:
    separator.set_weights(dill.load(file))
  print('success.')
except Exception as err:
  #best effort: a missing or incompatible file just means training starts from scratch
  print('failed:', err)

#evaluate on one fresh batch and only train while the loss is still above the threshold
X, Y, Z = get_batch_separator(batch_size, height_lr, width_lr)
test_loss = separator.evaluate(X, [Y, Z], return_dict = True)
if test_loss['loss'] > 0.08:
  for real_epoch in range(separator_epochs):
    #a new augmented batch per round, fitted for 200 epochs
    X, Y, Z = get_batch_separator(batch_size, height_lr, width_lr)
    separator.fit(X, [Y, Z], batch_size, epochs = 200, verbose = 0, callbacks = [separator_logging_callback], shuffle = True)

    print('trying to save weights...', end = ' ')
    try:
      with open(separator_weight_file, 'wb') as file:
        dill.dump(separator.get_weights(), file)
      print('success.')
    except Exception as err:
      #keep training even if the checkpoint could not be written, but say why
      print('failed:', err)

### save weights for generator to local variables

In [0]:
#snapshot the weights of the four named separator layers, one variable per layer
color_sep, gray_filters, color_upsample, gray_upsample = [
  separator.get_layer(layer_name).get_weights()
  for layer_name in ('color_sep', 'gray_filters', 'color_upsample', 'gray_upsample')
]

### dataset function for generator

In [0]:
# return batch of augmented train and target images with quantity n_samples
def get_batch_generator(n_samples, height, width):
    """Return one randomly augmented (x_train, y_train) batch for the generator.

    n_samples: requested number of images.
    height, width: size of the low resolution images.

    The batch is produced by get_batch_separator (same augmentation, resizing
    and normalization; this function used to be a line-for-line copy of it);
    the additional single-channel target it returns is discarded.
    """
    x_train, y_train, _ = get_batch_separator(n_samples, height, width)
    return x_train, y_train

### build models

In [0]:
def build_generator():
  """Build the generator model.

  Same stem and upsampling layers as the separator ('color_sep' and
  'gray_filters' are created with trainable = False), with 30 residual
  blocks inserted on the 16-channel branch. Each block applies two 3x3
  convolutions, scales the result per channel with weights computed from its
  global average (Dense 8 -> Dense 16 with sigmoid) and adds the block input.

  Returns a Model with a single output: the sum of both upsampled branches.
  """
  lr_image = Input(shape = (height_lr, width_lr, channels))
  single_channel = SeparableConv2D(1, 1, padding = 'same', name = 'color_sep', trainable = False)(lr_image)
  remainder = lr_image - single_channel

  features = SeparableConv2D(16, kernel_size = 3, padding = 'same', activation = 'relu', name = 'gray_filters', trainable = False)(single_channel)

  for block in range(30):
    suffix = '_block' + str(block)
    shortcut = features
    #gray = Conv2D(32, 5, activation = 'swish', padding = 'same', strides = 2)(gray)
    #gray = Conv2DTranspose(16, 4, padding = 'same', strides = 2)(gray)
    features = Conv2D(16, 3, activation = 'swish', padding = 'same', name = 'conv1' + suffix)(features)
    features = Conv2D(16, 3, padding = 'same', name = 'conv2' + suffix)(features)
    #per-channel scaling weights derived from the globally pooled features
    channel_weights = GlobalAveragePooling2D()(features)
    channel_weights = Dense(8, activation = 'swish', name = 'dense1' + suffix)(channel_weights)
    channel_weights = Dense(16, activation = 'sigmoid', name = 'dense2' + suffix)(channel_weights)
    features = Multiply()([features, channel_weights])
    features = Add()([features, shortcut])

  #gray = upsample4xGray(gray) #wait for tensorflow 2.3.0-rc0
  gray_up = Conv2DTranspose(1, kernel_size = 4, strides = 4, padding = 'same', name = 'gray_upsample')(features)
  #color = UpSampling2D(size = 4, interpolation = 'bilinear')(color) #wait for tensorflow 2.3.0-rc0
  color_up = Conv2DTranspose(3, kernel_size = 4, strides = 4, padding = 'same', name = 'color_upsample')(remainder)
  combined = Add()([color_up, gray_up])
  return Model(inputs = lr_image, outputs = combined)

def build_discriminator():
  """Build the discriminator model.

  Takes an image of the upscaled size, applies three 3x3 'swish' convolutions
  (30, 60 and 120 filters; the last two with stride 2), flattens the result
  and maps it to one sigmoid unit.
  """
  hr_image = Input(shape = (height_lr * scale, width_lr * scale, channels))
  x = hr_image
  #(filters, strides) of the consecutive convolution layers
  for n_filters, stride in ((30, 1), (60, 2), (120, 2)):
    x = Conv2D(n_filters, 3, strides = stride, activation = 'swish')(x)
  x = Flatten()(x)
  score = Dense(units = 1, activation = 'sigmoid')(x)
  return Model(inputs = hr_image, outputs = score)

### define callback for generator

In [0]:
def loggingGenerator(epoch, logs):
  """End-of-epoch hook for generator training: report progress and preview results.

  epoch: 0-based epoch index passed by Keras.
  logs: metrics dict passed by Keras; logs['loss'] is printed.

  Acts only when epoch is a positive multiple of the global `logging_steps`.
  Then the generator is converted to a TFLite model, written to
  'superresolution.tflite', run on one fresh sample and the input, target and
  prediction are plotted with show(). Reads the globals `real_epoch`,
  `epochs`, `steps_per_epoch` and `generator`; updates the global `this_time`.
  """
  global this_time
  if (epoch % logging_steps == 0) and (epoch > 0):
    last_time = this_time
    this_time = time.time()

    clear_output()
    print('epoch', real_epoch + 1, '/', epochs, '--> step', (epoch), '/', steps_per_epoch
          , '| loss:', logs['loss'], '| time taken:', this_time - last_time
         )
    TFLITE_MODEL = "superresolution.tflite"
    run_model = tf.function(lambda x : generator(x))
    #use the generator's own input signature (the original read the shape from the separator)
    concrete_func = run_model.get_concrete_function(tf.TensorSpec(generator.inputs[0].shape, generator.inputs[0].dtype))
    converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
    converted_tflite_model = converter.convert()
    #context manager so the model file is flushed and closed before it is read back
    with open(TFLITE_MODEL, "wb") as tflite_file:
      tflite_file.write(converted_tflite_model)
    tflite_generator_interpreter = tf.lite.Interpreter(model_path=TFLITE_MODEL)
    testX, testY = get_batch_generator(1, height_lr, width_lr)
    input_details = tflite_generator_interpreter.get_input_details()
    output_details = tflite_generator_interpreter.get_output_details()

    tflite_generator_interpreter.allocate_tensors()
    tflite_generator_interpreter.set_tensor(input_details[0]['index'], testX.astype(np.float32))
    tflite_generator_interpreter.invoke()
    predY = tflite_generator_interpreter.get_tensor(output_details[0]['index'])

    show([testX[0], testY[0], predY[0]])

#Keras calls the hook with (epoch, logs) at the end of every epoch
generator_logging_callback = LambdaCallback(on_epoch_end = loggingGenerator)


### build model

In [0]:
with strategy.scope():
  #feature extractor for multi_loss: VGG16 cut at the output of 'block2_conv2'
  #NOTE(review): perceptual_model itself was created outside the strategy scope - confirm this works on TPU
  loss_model = Model(inputs=perceptual_model.input, outputs=perceptual_model.get_layer('block2_conv2').output) 
  generator = build_generator()
  #`learning_rate` is the supported keyword; `lr` is a deprecated alias that newer Keras versions reject
  generator.compile(optimizer = Adam(learning_rate=0.00001, beta_1=0.5, beta_2=0.999), loss = 'mae')

#initialise the layers shared with the separator from its trained weights
generator.get_layer('color_sep').set_weights(color_sep)
generator.get_layer('gray_filters').set_weights(gray_filters)
generator.get_layer('color_upsample').set_weights(color_upsample)
generator.get_layer('gray_upsample').set_weights(gray_upsample)

### load weights and train model

In [0]:
#reference time for the "time taken" value printed by the logging callback
this_time = time.time()
print('trying to load last saved weights...', end = ' ') 
try:
    #the options cell names this file generator_weight_file; the previously used
    #name weight_file does not exist, so loading always reported 'failed.'
    #NOTE(review): dill.load can execute arbitrary code - only load weight files you created yourself
    with open(generator_weight_file, 'rb') as file:
        generator.set_weights(dill.load(file))
    print('success.')
except Exception as err:
    #best effort: a missing or incompatible file just means training starts from the current weights
    print('failed:', err)

for real_epoch in range(epochs):
    #a new augmented batch per outer epoch; the extra fit epoch lets the logging
    #callback fire when the epoch index reaches steps_per_epoch
    X, Y = get_batch_generator(batch_size, height_lr, width_lr)
    generator.fit(X, Y, batch_size, epochs = steps_per_epoch + 1, verbose = 0, callbacks = [generator_logging_callback], shuffle = True)

    print('trying to save weights...', end = ' ')
    try:
        with open(generator_weight_file, 'wb') as file:
            dill.dump(generator.get_weights(), file)
        print('success.')
    except Exception as err:
        #keep training even if the checkpoint could not be written, but say why
        print('failed:', err)

### try several speed and size optimization strategies

In [0]:
#get layer weights
#weights of the stem and upsampling layers, one variable per layer
color_sep, gray_filters, color_upsample, gray_upsample = [
  generator.get_layer(layer_name).get_weights()
  for layer_name in ('color_sep', 'gray_filters', 'color_upsample', 'gray_upsample')
]

#weights of every residual block layer, keyed by layer name
weights = {}
for x in range(30):
  for prefix in ('conv1_block', 'conv2_block', 'dense1_block', 'dense2_block'):
    layer_name = prefix + str(x)
    weights[layer_name] = generator.get_layer(layer_name).get_weights()

In [0]:
#build variation model
def buildGeneratorVariation():
  """Build a generator variant whose first block convolution is depthwise.

  Layer names match build_generator so trained weights can be copied by name.
  In each of the 30 residual blocks 'conv1_block<i>' is a 3x3 DepthwiseConv2D;
  it is the only layer created without trainable = False.

  Returns a Model with a single output: the sum of both upsampled branches.
  """
  lr_image = Input(shape = (height_lr, width_lr, channels))
  single_channel = SeparableConv2D(1, 1, padding = 'same', name = 'color_sep', trainable = False)(lr_image)
  remainder = lr_image - single_channel

  features = SeparableConv2D(16, kernel_size = 3, padding = 'same', activation = 'relu', name = 'gray_filters', trainable = False)(single_channel)

  for block in range(30):
    suffix = '_block' + str(block)
    shortcut = features
    #gray = Conv2D(32, 5, activation = 'swish', padding = 'same', strides = 2)(gray)
    #gray = Conv2DTranspose(16, 4, padding = 'same', strides = 2)(gray)
    features = DepthwiseConv2D(3, activation = 'swish', padding = 'same', name = 'conv1' + suffix)(features)
    features = Conv2D(16, 3, padding = 'same', name = 'conv2' + suffix, trainable = False)(features)
    #per-channel scaling weights derived from the globally pooled features
    channel_weights = GlobalAveragePooling2D()(features)
    channel_weights = Dense(8, activation = 'swish', name = 'dense1' + suffix, trainable = False)(channel_weights)
    channel_weights = Dense(16, activation = 'sigmoid', name = 'dense2' + suffix, trainable = False)(channel_weights)
    features = Multiply()([features, channel_weights])
    features = Add()([features, shortcut])

  #gray = upsample4xGray(gray) #wait for tensorflow 2.3.0-rc0
  gray_up = Conv2DTranspose(1, kernel_size = 4, strides = 4, padding = 'same', name = 'gray_upsample', trainable = False)(features)
  #color = UpSampling2D(size = 4, interpolation = 'bilinear')(color) #wait for tensorflow 2.3.0-rc0
  color_up = Conv2DTranspose(3, kernel_size = 4, strides = 4, padding = 'same', name = 'color_upsample', trainable = False)(remainder)
  combined = Add()([color_up, gray_up])
  return Model(inputs = lr_image, outputs = combined)

#replace `generator` with the depthwise variant and compile it with the
#combined loss (multi_loss) inside the TPU strategy scope
with strategy.scope():
  generator = buildGeneratorVariation()
  generator.compile(optimizer = Adam(gen_lr), loss = multi_loss)
  generator.summary()

In [0]:
#transfer old weights
#copy the saved stem and upsampling weights into the layers of the same name
for layer_name, layer_weights in (('color_sep', color_sep),
                                  ('gray_filters', gray_filters),
                                  ('color_upsample', color_upsample),
                                  ('gray_upsample', gray_upsample)):
  generator.get_layer(layer_name).set_weights(layer_weights)

for x in range(30):
  #the conv1 transfer stays disabled (conv1_block* is a DepthwiseConv2D in buildGeneratorVariation)
  #generator.get_layer('conv1_block' + str(x)).set_weights(weights["conv1_block{0}".format(x)])
  for prefix in ('conv2_block', 'dense1_block', 'dense2_block'):
    layer_name = prefix + str(x)
    generator.get_layer(layer_name).set_weights(weights[layer_name])

In [0]:
#train to test
#real_epoch is read as a global by the logging callback, so the loop variable name matters
for real_epoch in range(epochs):
    #a new augmented batch per outer epoch, fitted for steps_per_epoch + 1 Keras epochs
    X, Y = get_batch_generator(batch_size, height_lr, width_lr)
    generator.fit(X, Y, batch_size, epochs = steps_per_epoch + 1, verbose = 0, callbacks = [generator_logging_callback], shuffle = True)
    


### prune the model for faster inference

In [0]:
#these reassign the global training length for the pruning cells below
epochs = 5 #Number of epochs to train
steps_per_epoch = 100 #How much iterations per epoch to train

In [0]:
prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

#begin_step / end_step of the schedule count optimizer steps, not outer epochs.
#The pruning loop runs `epochs` rounds of fit(epochs = steps_per_epoch + 1) on a
#single batch, i.e. one step per fit epoch. With the former end_step = epochs (5)
#and the schedule's default frequency of 100 steps, pruning only happened at
#step 0 and the final sparsity was never reached.
#NOTE(review): assumes each fit epoch is exactly one optimizer step - confirm
pruning_end_step = epochs * (steps_per_epoch + 1)
pruning_params = {
      'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.50,
                                                               final_sparsity=0.80,
                                                               begin_step=0,
                                                               end_step=pruning_end_step)
}

#wrap the generator for magnitude pruning and compile it with the combined loss
with strategy.scope():
  model_for_pruning = prune_low_magnitude(generator, **pruning_params)
  model_for_pruning.compile(optimizer = Adam(gen_lr), loss = multi_loss)
  
  model_for_pruning.summary()

In [0]:
this_time = time.time()
print('trying to load last saved weights...', end = ' ') 
try:
    #NOTE(review): dill.load can execute arbitrary code - only load weight files you created yourself
    with open(pruning_weight_file, 'rb') as file:
        model_for_pruning.set_weights(dill.load(file))
    print('success.')
except Exception as err:
    #best effort: a missing or incompatible file just means pruning starts from the current weights
    print('failed:', err)

for real_epoch in range(epochs):
    #a new augmented batch per outer epoch; UpdatePruningStep advances the pruning schedule
    X, Y = get_batch_generator(batch_size, height_lr, width_lr)
    model_for_pruning.fit(X, Y, batch_size, epochs = steps_per_epoch + 1, verbose = 1, callbacks = [tfmot.sparsity.keras.UpdatePruningStep()], shuffle = True)
    clear_output()
    print('trying to save weights...', end = ' ')
    try:
        with open(pruning_weight_file, 'wb') as file:
            dill.dump(model_for_pruning.get_weights(), file)
        print('success.')
    except Exception as err:
        #keep training even if the checkpoint could not be written, but say why
        print('failed:', err)

In [0]:
#remove the pruning wrappers again to obtain a plain Keras model
#NOTE(review): model_for_export is not used by any later cell shown here - the export cell saves `generator`; confirm
model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)

### validate on complete picture

In [0]:
#load the sample as RGB: the generator input has `channels` (3) channels,
#so the former grayscale conversion could not be fed to the model
img = Image.open('./DIV2K-sample.png').convert('RGB')
img = np.array(img)
img = img.astype(np.float32)

print('ground truth:')
display_image_in_actual_size(img)

print('trying to load last saved weights...', end = ' ') 
try:
    #the options cell names this file generator_weight_file; the previously used
    #name weight_file does not exist, so loading always reported 'failed.'
    #NOTE(review): loading fails (and is reported) when `generator` is currently the
    #depthwise variation, whose conv1 weight shapes differ - confirm which model is intended
    #NOTE(review): dill.load can execute arbitrary code - only load weight files you created yourself
    with open(generator_weight_file, 'rb') as file:
        generator.set_weights(dill.load(file))
    print('success.')
except Exception as err:
    print('failed:', err)


print('superresolution:')
#The generator is built for a fixed (height_lr, width_lr) input, so the complete
#picture is padded to a multiple of that size, cut into a grid of tiles, upscaled
#tile by tile and stitched back together.
img_height, img_width = img.shape[:2]
tile_rows = -(-img_height // height_lr) #ceiling division
tile_cols = -(-img_width // width_lr)
padded = np.pad(img,
                ((0, tile_rows * height_lr - img_height),
                 (0, tile_cols * width_lr - img_width),
                 (0, 0)),
                mode = 'reflect')
#(rows, height_lr, cols, width_lr, c) -> (rows * cols, height_lr, width_lr, c)
tiles = padded.reshape(tile_rows, height_lr, tile_cols, width_lr, channels)
tiles = tiles.transpose(0, 2, 1, 3, 4).reshape(-1, height_lr, width_lr, channels)
if normalize:
    tiles = tiles / 255 #same scaling as the training batches
predicted = generator.predict(tiles, batch_size = batch_size)
#(rows * cols, H, W, c) -> (rows * H, cols * W, c), then crop away the padding
predicted = predicted.reshape(tile_rows, tile_cols, height_lr * scale, width_lr * scale, channels)
predicted = predicted.transpose(0, 2, 1, 3, 4).reshape(tile_rows * height_lr * scale, tile_cols * width_lr * scale, channels)
predicted = predicted[:img_height * scale, :img_width * scale]
if normalize:
    predicted = predicted * 255 #back to the 0..255 range expected for display
predicted = np.clip(predicted, 0, 255)
print(predicted.shape)
display_image_in_actual_size(predicted)
predicted = Image.fromarray(predicted.astype(np.uint8))
#saving the result is disabled, as in the original cell:
#print('trying to save image as \'superresolution_result.png\'...', end = ' ')
#try:
#    predicted.save('superresolution_result.png', "PNG")
#    print('success.')
#except:
#    print('failed.')
#    pass

### export to tensorflow.js

In [0]:
#save the current `generator` as a Keras HDF5 file
#NOTE(review): this exports `generator`, not the stripped pruned model (model_for_export) - confirm which one is intended
generator.save('sr_tfjs.h5')
#convert the HDF5 file to the TensorFlow.js format in ./model and zip it for download
!pip install tensorflowjs
!tensorflowjs_converter --input_format=keras sr_tfjs.h5 model/
!ls -la
!zip -r model.zip model 
print('you can download model.zip from the menu...')