In [1]:
import argparse
from functools import partial
import json
from keras import optimizers
from pathlib import Path

from toolbox.data import load_set
from toolbox.models import get_model
from toolbox.experiment import Experiment

from keras.layers import Conv2D
from keras.layers import Conv2DTranspose
from keras.layers import InputLayer
from keras.models import Sequential


import numpy as np
from keras.preprocessing.image import img_to_array
from keras.preprocessing.image import load_img

from toolbox.image import bicubic_rescale
from toolbox.image import modcrop
from toolbox.paths import data_dir

# import matplotlib
# matplotlib.use('Agg')
import matplotlib.pyplot as plt

Using TensorFlow backend.


In [17]:
param = json.load(open('./espcn-example.json'))

In [20]:
scale = param['scale']
build_model = partial(get_model(param['model']['name']),
                      **param['model']['params'])
if 'optimizer' in param:
    optimizer = getattr(optimizers, param['optimizer']['name'].lower())
    optimizer = optimizer(**param['optimizer']['params'])
else:
    optimizer = 'adam'
    

In [22]:
# Data
load_set = partial(load_set,
                   lr_sub_size=param['lr_sub_size'],
                   lr_sub_stride=param['lr_sub_stride'])

In [24]:
# Training
expt = Experiment(scale=param['scale'], load_set=load_set,
                  build_model=build_model,optimizer=optimizer,
                  save_dir=param['save_dir'])

In [94]:
def espcn(x, f=[5, 3, 3], n=[64, 32], scale=3):
    """Build an ESPCN model.

    See https://arxiv.org/abs/1609.05158
    """
    assert len(f) == len(n) + 1
    model = Sequential()
    model.add(InputLayer(input_shape=x.shape[1:]))
    c = x.shape[-1]
    for ni, fi in zip(n, f):
        model.add(Conv2D(ni, fi, padding='same',
                         kernel_initializer='he_normal', activation='tanh'))
    model.add(Conv2D(c * scale ** 2, f[-1], padding='same',
                     kernel_initializer='he_normal'))
    model.add(Conv2DSubPixel(scale))
    return model

In [20]:


def load_image_pair(path, scale=3):
    image = load_img(path)
    image = image.convert('YCbCr')
    hr_image = modcrop(image, scale)
    lr_image = bicubic_rescale(hr_image, 1 / scale)
    return lr_image, hr_image


def generate_sub_images(image, size, stride):
    for i in range(0, image.size[0] - size + 1, stride):
        for j in range(0, image.size[1] - size + 1, stride):
            yield image.crop([i, j, i + size, j + size])

     

In [95]:
lr_sub_size=11
lr_sub_stride=5
scale=3
name = '91-image'

hr_sub_size = lr_sub_size * scale
hr_sub_stride = lr_sub_stride * scale
lr_gen_sub = partial(generate_sub_images, size=lr_sub_size,
                     stride=lr_sub_stride)
hr_gen_sub = partial(generate_sub_images, size=hr_sub_size,
                     stride=hr_sub_stride)

lr_sub_arrays = []
hr_sub_arrays = []
for path in (data_dir / name).glob('*'):
    lr_image, hr_image = load_image_pair(str(path), scale=scale)
    lr_sub_arrays += [img_to_array(img) for img in lr_gen_sub(lr_image)]
    hr_sub_arrays += [img_to_array(img) for img in hr_gen_sub(hr_image)]
x = np.stack(lr_sub_arrays)
y = np.stack(hr_sub_arrays)

In [106]:
matplotlib.pyplot.ion()
imgplot = plt.imshow(lr_sub_arrays[0]/255)
show()

NameError: name 'show' is not defined

In [8]:
img = load_img('../data/91-image/t1.bmp')
x = img_to_array(img)
plt.imshow(x/255.)
plt.show()

In [9]:
data_dir

PosixPath('/home/yixu/Study/SuperResolition/srcnn/data')

In [98]:
lr_sub_arrays[0]/255

array([[[0.6313726 , 0.64705884, 0.49803922],
        [0.63529414, 0.63529414, 0.49803922],
        [0.6392157 , 0.627451  , 0.49803922],
        [0.63529414, 0.63529414, 0.5019608 ],
        [0.62352943, 0.6392157 , 0.5019608 ],
        [0.627451  , 0.6431373 , 0.49803922],
        [0.6509804 , 0.627451  , 0.5019608 ],
        [0.5019608 , 0.6627451 , 0.5019608 ],
        [0.5176471 , 0.68235296, 0.49803922],
        [0.5529412 , 0.6627451 , 0.49803922],
        [0.5647059 , 0.65882355, 0.49803922]],

       [[0.5764706 , 0.64705884, 0.49803922],
        [0.6313726 , 0.6431373 , 0.49803922],
        [0.6431373 , 0.6313726 , 0.49803922],
        [0.627451  , 0.63529414, 0.5019608 ],
        [0.6156863 , 0.6392157 , 0.5019608 ],
        [0.6039216 , 0.6431373 , 0.5019608 ],
        [0.6       , 0.6509804 , 0.5019608 ],
        [0.4745098 , 0.68235296, 0.5058824 ],
        [0.50980395, 0.68235296, 0.5019608 ],
        [0.54901963, 0.6627451 , 0.49803922],
        [0.5529412 , 0.6627451 ,

In [62]:
for path in (data_dir / name).glob('*'):
    print(path)
    image = load_img(path)
    x = img_to_array(image)
    plt.figure(1)
    imgplot = plt.imshow(x/255.0)
    plt.show()
    image = image.convert('YCbCr')
    hr_image = modcrop(image, scale)
    lr_image = bicubic_rescale(hr_image, 1 / scale)
    break

/home/yixu/Study/SuperResolition/srcnn/data/91-image/t63.bmp
<class 'PIL.BmpImagePlugin.BmpImageFile'>
<class 'numpy.ndarray'>
(202, 245, 3)
