In [1]:
import functools
import os
import re
import time

import cv2
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from matplotlib import pyplot
from scipy import ndimage, misc
from skimage.transform import resize, rescale
from tensorflow.keras import regularizers
from tensorflow.keras.layers import Input, Dense, Conv2D, MaxPooling2D, Dropout
from tensorflow.keras.layers import Conv2DTranspose, UpSampling2D, add
from tensorflow.keras.layers import Dense, Flatten, Input, Conv2D, LeakyReLU
from tensorflow.keras.models import Model

np.random.seed(0)

In [2]:
class ResInResDenseBlock(tf.keras.layers.Layer):
    """Residual-in-Residual Dense Block (RRDB).

    Chains three ``ResDenseBlock_5C`` layers and adds the scaled result back
    to the input: ``out = rdb_3(rdb_2(rdb_1(x))) * res_beta + x``.

    Args:
        nf: Output filters of each inner dense block; expected to match the
            channel count of ``x`` so the residual addition is valid.
        gc: Growth channels (filters of the intermediate convolutions)
            passed to each inner block.
        res_beta: Scale applied to the chained output before the residual add.
        wd: L2 weight-decay factor passed to each inner block.
        name: Layer name.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """
    def __init__(self, nf=64, gc=32, res_beta=0.2, wd=0., name='RRDB',
                 **kwargs):
        super(ResInResDenseBlock, self).__init__(name=name, **kwargs)
        # Keep the constructor arguments so get_config() can round-trip them.
        self.nf = nf
        self.gc = gc
        self.res_beta = res_beta
        self.wd = wd
        self.rdb_1 = ResDenseBlock_5C(nf, gc, res_beta=res_beta, wd=wd)
        self.rdb_2 = ResDenseBlock_5C(nf, gc, res_beta=res_beta, wd=wd)
        self.rdb_3 = ResDenseBlock_5C(nf, gc, res_beta=res_beta, wd=wd)

    def call(self, x):
        out = self.rdb_1(x)
        out = self.rdb_2(out)
        out = self.rdb_3(out)
        # Outer residual connection around the three dense blocks.
        return out * self.res_beta + x

    def get_config(self):
        """Return the constructor arguments so ``from_config`` can rebuild
        this layer when a saved model is loaded."""
        config = super(ResInResDenseBlock, self).get_config()
        config.update({
            'nf': self.nf,
            'gc': self.gc,
            'res_beta': self.res_beta,
            'wd': self.wd,
        })
        return config

class ResDenseBlock_5C(tf.keras.layers.Layer):
    """Residual Dense Block built from five 3x3 convolutions.

    Each convolution receives the concatenation (along axis 3, the channel
    axis for Conv2D's default channels_last layout) of the block input and
    every previous convolution output. The last convolution produces ``nf``
    channels, which are scaled by ``res_beta`` and added back to the input:
    ``out = x5 * res_beta + x``.

    When loading a saved model that contains this layer, pass it (together
    with ``ResInResDenseBlock`` and ``LeakyReLU``) via ``custom_objects``.

    Args:
        nf: Output filters of the final convolution; expected to match the
            channel count of ``x`` so the residual addition is valid.
        gc: Growth channels, i.e. filters of the four intermediate
            convolutions.
        res_beta: Scale applied to the block output before the residual add.
        wd: L2 weight-decay factor applied to every convolution kernel.
        name: Layer name.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """
    def __init__(self, nf=64, gc=32, res_beta=0.2, wd=0., name='RDB5C',**kwargs):
        super(ResDenseBlock_5C, self).__init__(name=name, **kwargs)
        # Keep the constructor arguments so get_config() can round-trip them.
        self.nf = nf
        self.gc = gc
        self.res_beta = res_beta
        self.wd = wd
        lrelu_f = functools.partial(LeakyReLU, alpha=0.2)

        def _conv(filters):
            # 3x3 'same' conv with LeakyReLU, 0.1-scaled He-normal init and L2
            # decay. A fresh initializer is built per layer: recent Keras
            # versions warn that reusing one unseeded initializer instance
            # across layers repeats its values.
            # NOTE(review): conv5 is also activated here, which differs from
            # the reference ESRGAN block; kept because changing it would alter
            # the behaviour of already-trained weights.
            return Conv2D(filters=filters, kernel_size=3, padding='same',
                          activation=lrelu_f(),
                          kernel_initializer=_kernel_init(0.1),
                          bias_initializer='zeros',
                          kernel_regularizer=_regularizer(wd))

        self.conv1 = _conv(gc)
        self.conv2 = _conv(gc)
        self.conv3 = _conv(gc)
        self.conv4 = _conv(gc)
        self.conv5 = _conv(nf)

    def call(self, x):
        # Dense connectivity: each conv sees the input plus all earlier outputs.
        x1 = self.conv1(x)
        x2 = self.conv2(tf.concat([x, x1], 3))
        x3 = self.conv3(tf.concat([x, x1, x2], 3))
        x4 = self.conv4(tf.concat([x, x1, x2, x3], 3))
        x5 = self.conv5(tf.concat([x, x1, x2, x3, x4], 3))
        # Local residual connection scaled by res_beta.
        return x5 * self.res_beta + x

    def get_config(self):
        """Return the constructor arguments so ``from_config`` can rebuild
        this layer when a saved model is loaded."""
        config = super(ResDenseBlock_5C, self).get_config()
        config.update({
            'nf': self.nf,
            'gc': self.gc,
            'res_beta': self.res_beta,
            'wd': self.wd,
        })
        return config
def _kernel_init(scale=1.0, seed=None):
    """Build a scaled He-normal kernel initializer.

    The result is ``VarianceScaling(scale=2 * scale, mode='fan_in',
    distribution='truncated_normal')``; with ``scale=1.0`` this is the
    standard He-normal initializer.

    Args:
        scale: Extra multiplier applied on top of the He factor of 2.
        seed: Optional seed forwarded to the initializer.

    Returns:
        A ``tf.keras.initializers.VarianceScaling`` instance.
    """
    he_scale = scale * 2.
    return tf.keras.initializers.VarianceScaling(
        mode='fan_in',
        scale=he_scale,
        seed=seed,
        distribution="truncated_normal",
    )
def _regularizer(weights_decay=5e-4):
    """Build an L2 weight-decay regularizer.

    Args:
        weights_decay: L2 penalty factor (default 5e-4), passed positionally
            to ``tf.keras.regularizers.l2``.

    Returns:
        A Keras L2 regularizer instance.
    """
    l2_factory = tf.keras.regularizers.l2
    return l2_factory(weights_decay)

In [3]:
# Load the trained network; the custom layers must be registered so Keras can
# rebuild them from the configuration stored in the .h5 file.
MODEL_PATH = 'model9.h5'
CUSTOM_OBJECTS = {
    'ResInResDenseBlock': ResInResDenseBlock,
    'ResDenseBlock_5C': ResDenseBlock_5C,
    'LeakyReLU': LeakyReLU,
}
model = tf.keras.models.load_model(MODEL_PATH, custom_objects=CUSTOM_OBJECTS)



In [4]:
# Print the loaded model's layers, output shapes and parameter counts.
model.summary()

Model: "RRDB_model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_image (InputLayer)        [(None, None, None,  0                                            
__________________________________________________________________________________________________
conv_first (Conv2D)             (None, None, None, 6 1792        input_image[0][0]                
__________________________________________________________________________________________________
RRDB_trunk (Sequential)         (None, None, None, 6 16546752    conv_first[0][0]                 
__________________________________________________________________________________________________
conv_trunk (Conv2D)             (None, None, None, 6 36928       RRDB_trunk[0][0]                 
_________________________________________________________________________________________

In [5]:
import cv2

In [None]:
for root, dirnames, filenames in os.walk("./all_data_0.25/"):
    for filename in filenames:
        if re.search(".(jpg|jpeg|JPEG|png|bmp|tiff)$", filename):
            filepath = os.path.join(root, filename)
            Lbaby = mpimg.imread(filepath)
            Lbaby = np.array(Lbaby)
            Lbaby = [Lbaby]
            Lbaby = np.array(Lbaby)
            print("********************************************************")
            print(str(os.path.splitext(os.path.basename(os.path.normpath(filepath)))[0]))
            t0= time.time()
            Hbaby = np.clip(model.predict(Lbaby), 0.0, 1.0)
            t1 = time.time() - t0
            print("Time elapsed: ", t1)
            print("********************************************************")