In [1]:
import numpy as np
from keras.models import Model
from keras.layers import Input, BatchNormalization, Dense, Add, Reshape, Permute, Flatten
from keras.layers.convolutional import Conv2D
from keras.layers.advanced_activations import PReLU, LeakyReLU
from keras.optimizers import Adam
from keras.applications import VGG19
import datetime
import random
from PIL import Image
import glob
import cv2
import matplotlib.pyplot as plt


ratio = 4
LR_shape = (120, 160, 3)
g_weight_name = "g_weight_01.h5"
d_weight_name = "d_weight_01.h5"

L_h, L_w, channels = LR_shape
H_h = L_h * ratio
H_w = L_w * ratio
# HR_shape = np.array([H_h, H_w, c])
HR_shape = (H_h, H_w, channels)

optimizer = Adam()

Using TensorFlow backend.
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [2]:
def denormalize(img):
    img = (img + 1) * 127.5
    return img.astype(np.uint8)


def save_images(epoch, hr_img, lr_img, sr_img):
    hr_img = denormalize(hr_img)
    lr_img = denormalize(lr_img)
    sr_img = denormalize(sr_img)
    
    cv2.imwrite("../../images/output/" + str(epoch+1) + "_HRimage.png", hr_img)
    cv2.imwrite("../../images/output/" + str(epoch+1) + "_LRimage.png", lr_img)
    cv2.imwrite("../../images/output/" + str(epoch+1) + "_SRimage.png", sr_img)

    
def calc_psnr(img1: np.ndarray, img2: np.ndarray):
    def convert(img):
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    def extract_y(image: np.ndarray) -> np.ndarray:
        if image.ndim == 2:
            return image
        image = image.astype(np.int32)
        return ((image[:, :, 2] * 65.481 / 255.
                  + image[:, :, 1] * 128.553 / 255.
                  + image[:, :, 0] * 24.966 / 255.) + 16).astype(np.int32)


    def psnr(img1, img2):
        mse = np.mean((img1 - img2) ** 2)
        if mse == 0:
            return 100
        PIXEL_MAX = 255.0
        return 10 * math.log10(PIXEL_MAX * PIXEL_MAX / mse)

    img1_conv=convert(img1)
    img2_conv=convert(img2)

    # BGR -> YCrCb
    # 画像はcv2.imreadで読まれている前提 [0, 255]
    y1 = extract_y(img1_conv)
    y2 = extract_y(img2_conv)
    # 周囲のcropping
    # assert y1.shape == y2.shape
    h, w = y1.shape
    cr = ratio
    cropped_y1 = y1[cr:h - cr, cr:w - cr].astype(np.float64)
    cropped_y2 = y2[cr:h - cr, cr:w - cr].astype(np.float64)

    # psnr
    psnr_val = psnr(cropped_y1, cropped_y2)
    return psnr_val


def load_data(batch_size):

    files = glob.glob("../../images/train/*.png", recursive=True)
    batch_images = random.sample(files, batch_size)

    hr_imgs = []
    lr_imgs = []
    for img_path in batch_images:
        img = Image.open(img_path)

        hr_img = img.resize((H_w, H_h))  #(64, 64)
        lr_img = img.resize((L_w, L_h))
        hr_img = np.array(hr_img)
        #img_hr = (img_hr - 127.5) / 127.5
        lr_img = np.array(lr_img)
        #img_lr = (img_lr - 127.5) / 127.5

        hr_imgs.append(hr_img)
        lr_imgs.append(lr_img)

    hr_imgs = np.array(hr_imgs) / 127.5 - 1.
    lr_imgs = np.array(lr_imgs) / 127.5 - 1.

    return hr_imgs, lr_imgs
      

In [3]:

def pixel_shuffle(in_map, h, w, c):#((120, 160, 12), 120, 160, 3)
    
    x = Reshape((h, w, 2, 2, c))(in_map)
    x = Permute((3, 1, 4, 2, 5))(x)
    out_map = Reshape((2 * h, 2 * w, c))(x)
    
    return out_map


def upsampling(in_map, h, w, c):#((, 120, 160, 3), g:120, 160 , 64
    
    x = Conv2D(filters = 4 * c, 
                     kernel_size = 3,
                     strides = 1,
                     padding = "same")(in_map)# x: 120, 160, 12
    x = pixel_shuffle(x, h, w, c)# ( , 120, 160, 3)
    out_map = PReLU()(x)
    
    return out_map


def residual_block(in_map):
    x = Conv2D(filters = 64, kernel_size = 3, strides = 1, padding = "same")(in_map)
    x = BatchNormalization()(x)
#     x = PReLU()(x)
    x = LeakyReLU(alpha = 0.2)(x)
    x = Conv2D(filters = 64, kernel_size = 3, strides = 1, padding = "same")(x)
    x = BatchNormalization()(x)
    out_map = Add()([x, in_map])
    return out_map


def d_block(in_map, filters, kernel_size, strides, padding):
    d = Conv2D(filters = filters, kernel_size = kernel_size, strides = strides, padding = padding)(in_map)
    d = BatchNormalization()(d)
    d = LeakyReLU(alpha = 0.2)(d)
    return d


In [4]:
def build_generator():
    input_img = Input(shape = LR_shape)
    middle = Conv2D(filters = 64, kernel_size = 9, strides = 1, padding = "same")(input_img)
    middle = PReLU()(middle)
    
    g = residual_block(middle)
    for _ in range(4):
        g = residual_block(g)

    g = Conv2D(filters = 64, kernel_size = 3, strides = 1, padding = "same")(g)
    g = BatchNormalization()(g)
    g = Add()([g, middle])

    n = ratio
    i = 1
    while(n % 2 == 0):
        g = upsampling(g, L_h * i, L_w * i, channels)#(g, 120, 160, 3), g:120, 160 , 64
        i = i * 2
        n = n // 2

    output_img = Conv2D(filters = 3, kernel_size = 9, strides = 1, padding = "same")(g)

    return Model(input_img, output_img)


#------------------------------------------------------------------------#


def build_discriminator():
    input_img = Input(shape = HR_shape)
    
    d = Conv2D(filters = 64, kernel_size = 3, strides = 1, padding = "same")(input_img)
    d = d_block(d, filters = 64, kernel_size = 3, strides = 2, padding = "same")
    d = d_block(d, filters = 128, kernel_size = 3, strides = 1, padding = "same")
    d = d_block(d, filters = 128, kernel_size = 3, strides = 2, padding = "same")
    d = d_block(d, filters = 256, kernel_size = 3, strides = 1, padding = "same")
    d = d_block(d, filters = 256, kernel_size = 3, strides = 2, padding = "same")
    d = d_block(d, filters = 512, kernel_size = 3, strides = 1, padding = "same")
    d = d_block(d, filters = 512, kernel_size = 3, strides = 2, padding = "same")
    d = Flatten()(d)
    d = Dense(256)(d)
    d = LeakyReLU(alpha = 0.2)(d)
    output = Dense(1, activation = "sigmoid")(d)

    return Model(input_img, output)

#------------------------------------------------------------------------#

def build_vgg():
    vgg = VGG19(include_top = False)
    return Model(vgg.input, vgg.layers[9].output)
    
#------------------------------------------------------------------------#

def combined(generator, discriminator, vgg):
    input_img = Input(shape = LR_shape)
    fake_img = generator(input_img)
    
    validity = discriminator(fake_img)
    features = vgg(fake_img)
    
    return Model(input_img, [validity, features])




In [5]:
losses = []
epochs_checkpoint = []
psnr = []

def train(epochs, batch_size, interval):
    
    start_time = datetime.datetime.now()
    
    real = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))
    
    for epoch in range(epochs):
        
        real_imgs, lr_imgs = load_data(batch_size)
        fake_imgs = generator.predict(lr_imgs)
        
        #Dの訓練
        d_loss_real = discriminator.train_on_batch(real_imgs, real)
        d_loss_fake = discriminator.train_on_batch(fake_imgs, fake)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
        
        #Gの訓練
        vgg_features = vgg.predict(real_imgs)
        g_loss = srgan.train_on_batch(lr_imgs, [real, vgg_features])
        
        time = datetime.datetime.now() - start_time
        print("%d time: %s" % (epoch+1, time))
        
        if (epoch+1) % interval == 0:
            print("epoch: %d" % (epoch+1))
            losses.append((d_loss, g_loss))
            epochs_checkpoint.append(epoch+1)
            save_images(epoch, real_imgs[0], lr_imgs[0], fake_imgs[0])
#             psnr.append(calc_psnr(real_imgs[0], fake_imgs[0]))
            
    generator.save_weights(g_weight_name)
    discriminator.save_weights(d_weight_name)
    print("save seights")
        

In [6]:
#------------------------------------
# メインプログラム
#------------------------------------

#Discriminator
discriminator = build_discriminator()
discriminator.compile(loss = "mse",
                      optimizer = optimizer,
                      metrics = ["accuracy"])
discriminator.summary()

#Generator
generator = build_generator()
generator.summary()
vgg = build_vgg()
vgg.trainable = False
discriminator.trainable = False
srgan = combined(generator, discriminator, vgg)
srgan.compile(loss=['binary_crossentropy', 'mse'],
                              loss_weights=[1e-3, 1],
                              optimizer=optimizer)
srgan.summary()





Instructions for updating:
keep_dims is deprecated, use keepdims instead

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         (None, 480, 640, 3)       0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 480, 640, 64)      1792      
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 240, 320, 64)      36928     
_________________________________________________________________
batch_normalization_1 (Batch (None, 240, 320, 64)      256       
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 240, 320, 64)      0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 240, 320, 128)     73856     
_________________________________________________________________



Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
input_4 (InputLayer)             (None, 120, 160, 3)   0                                            
____________________________________________________________________________________________________
model_2 (Model)                  (None, 480, 640, 3)   2813432                                      
____________________________________________________________________________________________________
model_1 (Model)                  (None, 1)             161979713                                    
____________________________________________________________________________________________________
model_3 (Model)                  multiple              1735488                                  

In [None]:
train(epochs = 1000, batch_size = 1, interval = 10)

Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
1 time: 0:00:14.601319
2 time: 0:00:15.082756
3 time: 0:00:15.566779
4 time: 0:00:16.035002
5 time: 0:00:16.501269
6 time: 0:00:16.972697
7 time: 0:00:17.448129
8 time: 0:00:17.921062
9 time: 0:00:18.383063
10 time: 0:00:18.855492
epoch: 10
11 time: 0:00:19.350446
12 time: 0:00:19.823937
13 time: 0:00:20.281947
14 time: 0:00:20.756377
15 time: 0:00:21.234812
16 time: 0:00:21.715847
17 time: 0:00:22.165108
18 time: 0:00:22.650270
19 time: 0:00:23.124700
20 time: 0:00:23.601133
epoch: 20
21 time: 0:00:24.085237
22 time: 0:00:24.569657
23 time: 0:00:25.050094
24 time: 0:00:25.525029
25 time: 0:00:25.988008
26 time: 0:00:26.459044
27 time: 0:00:26.934475
28 time: 0:00:27.419916
29 time: 0:00:27.906653
30 time: 0:00:28.377923
epoch: 30
31 time: 0:00:28.891894
32 time: 0:00:29.372330
33 time: 0:00:29.856770
34 time: 0:00:30.321919
35 time: 0:00:30.801355
36 time: 0:00:31.2873

306 time: 0:02:34.577981
307 time: 0:02:35.020152
308 time: 0:02:35.479992
309 time: 0:02:35.938350
310 time: 0:02:36.380222
epoch: 310
311 time: 0:02:36.856880
312 time: 0:02:37.327090
313 time: 0:02:37.790511
314 time: 0:02:38.241920
315 time: 0:02:38.691328
316 time: 0:02:39.141242
317 time: 0:02:39.591650
318 time: 0:02:40.040563
319 time: 0:02:40.494975
320 time: 0:02:40.946384
epoch: 320
321 time: 0:02:41.425819
322 time: 0:02:41.882739
323 time: 0:02:42.344663
324 time: 0:02:42.794070
325 time: 0:02:43.254488
326 time: 0:02:43.743932
327 time: 0:02:44.227372
328 time: 0:02:44.704804
329 time: 0:02:45.153715
330 time: 0:02:45.602627
epoch: 330
331 time: 0:02:46.078563
332 time: 0:02:46.524976
333 time: 0:02:46.987396
334 time: 0:02:47.448815
335 time: 0:02:47.902226
336 time: 0:02:48.351138
337 time: 0:02:48.807057
338 time: 0:02:49.264472
339 time: 0:02:49.735902
340 time: 0:02:50.201326
epoch: 340
341 time: 0:02:50.684897
342 time: 0:02:51.133809
343 time: 0:02:51.588726
344 ti

In [9]:
for epoch in range(len(epochs_checkpoint)):
    print("epoch: %d  d_loss: %.3f  g_loss: %.3f  psnr: %.3f"
         .format(epochs_checkpoint[epoch],
                losses[epoch, 0],
                losses[epoch, 1],
                psnr[epoch]))
    

In [11]:
print(len(losses))

0


In [20]:
from PIL import Image
D = Discriminator(480, 640, 3)
img = Image.open("../../images/train/image_1.png")
img = np.array(img)
valid = D.discriminate(img)
print(valid)

Tensor("dense_2/Sigmoid:0", shape=(?, 30, 40, 1), dtype=float32)


In [None]:
import keras