In [1]:
# Project-specific environment setup (env_setup.py itself is not shown here).
%run env_setup.py
# Render matplotlib figures with the interactive "notebook" backend.
%matplotlib notebook

In [42]:
from matplotlib import pyplot as plt
#from tensorflow.contrib import keras
import keras
from keras import backend as K
import numpy as np
import bcolz
from lessdeep.utils import download_file, extract_file
from lessdeep.model.vgg16n import Vgg16N

# Super resolution

In [3]:
# Fetch and unpack both image archives (72px and 288px, going by the archive
# names); each path is whatever extract_file returns for its archive.
resized_72_path, resized_288_path = (
    extract_file(download_file(url))
    for url in (
        'http://files.fast.ai/data/trn_resized_72.tar.gz',
        'http://files.fast.ai/data/trn_resized_288.tar.gz',
    )
)

In [4]:
# Open the extracted bcolz arrays from disk: low-resolution inputs first,
# then the high-resolution targets.
arr_lr, arr_hr = bcolz.open(resized_72_path), bcolz.open(resized_288_path)

In [31]:
def conv_block(x, filters, size, strides=(2,2), padding='same', activation='relu'):
    """Apply Conv2D -> BatchNormalization -> optional Activation to ``x``.

    Args:
        x: input tensor.
        filters: number of convolution filters.
        size: convolution kernel size.
        strides: convolution strides (default ``(2, 2)``).
        padding: Conv2D padding mode (default ``'same'``).
        activation: name passed to ``keras.layers.Activation``; any falsy value
            (e.g. ``None``) skips the activation layer entirely.

    Returns:
        Output tensor of the last layer applied.
    """
    y = keras.layers.Conv2D(filters, kernel_size=size, strides=strides, padding=padding)(x)
    y = keras.layers.BatchNormalization()(y)
    if not activation:
        # Used e.g. for the second conv of a residual branch, before the add.
        return y
    return keras.layers.Activation(activation)(y)

def res_block(x, filters=64):
    """Residual block: two stride-1 3x3 conv_blocks added back onto the input.

    The second conv_block has no activation, so the sum is returned
    un-activated. ``keras.layers.add`` needs both operands to have the same
    shape, so ``x`` presumably already has ``filters`` channels (64 by default).
    TODO confirm for callers that pass a different ``filters``.

    Args:
        x: input tensor; also used as the identity shortcut.
        filters: number of filters in both convolutions.

    Returns:
        Tensor ``branch + x``.
    """
    shortcut = x
    branch = conv_block(x, filters, 3, strides=(1,1))
    branch = conv_block(branch, filters, 3, strides=(1,1), activation=None)
    return keras.layers.add([branch, shortcut])

def up_block(x, filters, size):
    """Upsample ``x``, then apply Conv2D -> BatchNormalization -> ReLU.

    ``UpSampling2D()`` uses Keras' default ``size=(2, 2)``, so height and width
    double. The convolution keeps them unchanged (stride 1, ``padding='same'``).

    Args:
        x: input tensor.
        filters: number of convolution filters.
        size: convolution kernel size.

    Returns:
        Output tensor after the ReLU.
    """
    pipeline = [
        keras.layers.UpSampling2D(),
        keras.layers.Conv2D(filters, size, padding='same'),
        keras.layers.BatchNormalization(),
        keras.layers.Activation('relu'),
    ]
    for layer in pipeline:
        x = layer(x)
    return x

def deconv_block(x, filters, size, strides=(2,2)):
    """Apply Conv2DTranspose -> BatchNormalization -> ReLU to ``x``.

    This is a learned-upsampling alternative to ``up_block``. It is not used by
    the model built in this notebook.

    Args:
        x: input tensor.
        filters: number of transposed-convolution filters.
        size: kernel size.
        strides: transposed-convolution strides (default ``(2, 2)``).

    Returns:
        Output tensor after the ReLU.
    """
    stages = (
        keras.layers.Conv2DTranspose(filters, size, strides=strides, padding='same'),
        keras.layers.BatchNormalization(),
        keras.layers.Activation('relu'),
    )
    out = x
    for stage in stages:
        out = stage(out)
    return out

In [32]:
# Generator network: low-resolution image in, image upsampled 4x out
# (two up_blocks, each doubling height and width).
inp = keras.layers.Input(arr_lr.shape[1:])  # per-sample shape: arr_lr's shape without the leading axis
x = conv_block(inp, 64, 9, strides=(1,1))  # wide 9x9 stem, resolution preserved
for _ in range(4):  # four residual blocks at the input resolution
    x = res_block(x)
x = up_block(x, 64, 3)  # 2x upsample
x = up_block(x, 64, 3)  # 2x upsample again (4x in total)
x = keras.layers.Conv2D(3, 9, activation='tanh', padding='same')(x)  # 3 output channels, values in [-1, 1]
out = keras.layers.Lambda(lambda x: (x + 1)*127.5)(x)  # rescale tanh range [-1, 1] to [0, 255]

In [33]:
# Project-local VGG wrapper built with include_top=False and sized from the
# high-resolution array. shape[1:3] presumably picks (height, width), which
# assumes arr_hr is laid out (N, H, W, C). TODO confirm.
vgg = Vgg16N(include_top=False, image_size=arr_hr.shape[1:3])
# Symbolic input of the wrapped Keras model; later used as the second model input.
vgg_inp = vgg.model.input

In [34]:
# Freeze every VGG layer: it is only a fixed feature extractor for the content
# loss, so only the generator's layers should remain trainable.
for l in vgg.model.layers: l.trainable=False

In [35]:
def get_layer_out(model, i):
    """Return the output tensor of layer ``i`` of a wrapper's Keras model.

    Args:
        model: wrapper object exposing the Keras model as ``.model``
            (e.g. the Vgg16N instance), not a bare Keras model.
        i: layer index into ``model.model.layers``; negative indices work as
            for any Python list.

    Returns:
        ``model.model.layers[i].output``.
    """
    layer = model.model.layers[i]
    return layer.output

# Multi-output sub-model returning the activations of VGG layers 1, 4 and 7
# (presumably early convolution layers; TODO confirm these indices against
# Vgg16N's layer list).
conv_layers = keras.models.Model(vgg_inp, [get_layer_out(vgg, i) for i in [1, 4, 7]])
vgg1 = conv_layers(out)      # list of features of the generator's output image
vgg2 = conv_layers(vgg_inp)  # list of features of the image fed to vgg_inp (presumably the real high-res image)

In [53]:
def mean_sqr(diff):
    """Root-mean-square of a difference tensor, reduced over all axes.

    The square root is taken *after* the mean. Applying it element-wise, as in
    ``mean(sqrt(diff**2))``, gives the mean absolute difference instead. It also
    yields NaN gradients wherever an element of ``diff`` is exactly 0, because
    the derivative of sqrt at 0 is infinite. With sqrt outside the mean, that
    singularity only occurs when the whole tensor is zero.

    Args:
        diff: backend tensor, e.g. the difference of two feature maps.

    Returns:
        Scalar tensor ``sqrt(mean(diff ** 2))``.
    """
    return K.sqrt(K.mean(K.square(diff)))

def lost_func(vec_concat, weights=(0.1, 0.8, 0.1)):
    """Weighted content loss between two parallel lists of feature tensors.

    Args:
        vec_concat: sequence of ``2 * len(weights)`` tensors. The first half
            holds features of one image, the second half the matching features
            of the other (here ``vgg1 + vgg2``).
        weights: per-layer weights; defaults to the original ``(0.1, 0.8, 0.1)``.
            Keras' ``Lambda`` calls this with ``vec_concat`` only, so the
            default applies there.

    Returns:
        ``sum(mean_sqr(vec_concat[i] - vec_concat[i + n]) * weights[i])``.

    Raises:
        ValueError: if ``len(vec_concat) != 2 * len(weights)``.
    """
    n = len(weights)
    if len(vec_concat) != 2 * n:
        raise ValueError('expected %d tensors (2 x %d weights), got %d'
                         % (2 * n, n, len(vec_concat)))
    # A stray early `return` used to short-circuit this loop, so only the first
    # layer's (unweighted) distance was ever used.
    res = 0
    for i, w in enumerate(weights):
        res += mean_sqr(vec_concat[i] - vec_concat[i + n]) * w
    return res

In [54]:
# Loss model: its inputs are the low-res image (generator input) and the image
# fed to VGG; its output is lost_func applied to the generated-image features
# followed by the VGG-input features (list concatenation vgg1 + vgg2).
content_lost = keras.Model(
    inputs=[inp, vgg_inp],
    outputs=keras.layers.Lambda(lost_func)(vgg1 + vgg2),
)

In [43]:
# All-zero dummy targets, presumably for fitting content_lost: its output is
# itself the loss, so 'mse' against zeros pushes that output towards 0.
# NOTE(review): mean_sqr reduces with K.mean over every axis, so lost_func
# yields one value per batch rather than one per sample. Confirm the Lambda
# output shape is compatible with these (len(arr_hr), 1) targets before fitting.
tgt = np.zeros((len(arr_hr), 1))

In [None]:
# Adam optimiser, mean-squared-error on the model's (already loss-valued) output.
content_lost.compile(optimizer='adam', loss='mse')

In [45]:
# NOTE(review): leftover debugging toggle. %pdb with no argument flips automatic
# post-mortem debugging on exceptions; remove it (or use `%pdb off`) before
# sharing. Keep this comment on its own line, because anything after %pdb on
# the same line is parsed as its argument.
%pdb

Automatic pdb calling has been turned ON


In [None]:
def image_compare(img, other=None, titles=None):
    """Show two images side by side for visual comparison.

    The original version drew ``img`` in both panels, so nothing was actually
    compared. The right panel now shows ``other``.

    Args:
        img: image for the left panel (anything ``Axes.imshow`` accepts).
        other: image for the right panel. Defaults to ``img`` so existing
            one-argument calls keep their previous output.
        titles: optional ``(left_title, right_title)`` pair.
    """
    if other is None:
        other = img
    fig, (ax1, ax2) = plt.subplots(1, 2)
    ax1.imshow(img)
    ax2.imshow(other)
    if titles is not None:
        ax1.set_title(titles[0])
        ax2.set_title(titles[1])
    plt.show()