In [0]:
import scipy
from glob import glob
import numpy as np
import matplotlib.pyplot as plt
from keras.datasets import cifar10

class DataLoader():
  def __init__(self,dataset_name,img_res=(128,128)):
    self.dataset_name=dataset_name
    self.img_res=img_res
  def load_data(self,batch_size=1,is_testing=False):
    (x,y),(_,_)=cifar10.load_data()
    
    data_type="train" if not is_testing else "test"
    batch_images=np.random.choice(range(x.shape[0]),size=batch_size)
    imgs_hr=[]
    imgs_lr=[]
    
    for img_index in batch_images:
      img=x[img_index,:,:,:]
      h,w=self.img_res
      low_h,low_w=int(h/4),int(w/4)
      
      img_hr=scipy.misc.imresize(img,self.img_res)
      img_lr=scipy.misc.imresize(img,(low_h,low_w))
      
      if not is_testing and np.random.random()<0.5:
        img_hr=np.fliplr(img_hr)
        img_lr=np.fliplr(img_lr)
      imgs_hr.append(img_hr)
      imgs_lr.append(img_lr)
    imgs_hr=np.array(imgs_hr)/127.5 -1.
    imgs_lr=np.array(imgs_lr)/127.5 - 1.
    
    return(imgs_hr,imgs_lr)

In [0]:
from __future__ import print_function,division
import scipy
from keras.layers import BatchNormalization,Input,Dense,Reshape,Flatten,Dropout,Concatenate,Activation,ZeroPadding2D,Add,Conv2D,UpSampling2D
from keras.layers.advanced_activations import PReLU,LeakyReLU
from keras.applications import VGG19
from keras.models import Model,Sequential
from keras.optimizers import Adam
import datetime
import matplotlib.pyplot as plt
import numpy as np
import os
import sys

import keras.backend as k

In [0]:
class SRGAN():
  def __init__(self):
    
    #Image input shape
    self.channels=3
    self.lr_height=64
    self.lr_width=64
    
    self.lr_shape=(self.lr_height,self.lr_width,self.channels)
    
    self.hr_height=self.lr_height*4
    self.hr_width=self.lr_width*4
    
    self.hr_shape=(self.hr_height,self.hr_width,self.channels)
    
    #No of residual blocks in our generator:
    self.n_residual_blocks=16
    
    optimizer=Adam(0.0002,0.5)
    
    #We use pretrained VGG19 model to extract features from the high resolution and the generated high resolution images and we
    #minimize the mse between them
    self.vgg=self.build_vgg()
    self.vgg.trainable=False
    self.vgg.compile(loss='mse',optimizer=optimizer,metrics=['accuracy'])
    
    #Configure our data loader
    self.dataset_name="cifar_dataset"
    self.data_loader=DataLoader(dataset_name=self.dataset_name,img_res=(self.hr_height,self.hr_width))
    
    #Cal output shape of D
    patch=int(self.hr_height/2**4)
    self.disc_patch=(patch,patch,1)
    
    #Number of filters in the first layer of Generator and Discriminator
    self.gf=64
    self.df=64
    
    #Build and compile the discriminator
    self.discriminator=self.build_discriminator()
    self.discriminator.compile(loss='mse',optimizer=optimizer,metrics=['accuracy'])
    
    #Build the generator
    self.generator=self.build_generator()
    
    #High res. and low res. images:
    img_hr=Input(shape=self.hr_shape)
    img_lr=Input(shape=self.lr_shape)
    
    #Generate high resolution images from low resolution using our generator:
    fake_hr=self.generator(img_lr)
    
    #Extract VGG19 features of fake_hr image
    fake_features=self.vgg(fake_hr)
    
    #In combined model, just generator should be trainable.So setting discriminator as non-trainable
    self.discriminator.trainable=False
    
    #Discriminator will determine whether the hr image is real or fake:
    validity=self.discriminator(fake_hr)
    
    #Our combined model
    self.combined=Model([img_lr,img_hr],[validity,fake_features])
    self.combined.compile(loss=['binary_crossentropy','mse'],loss_weights=[1e-3,1],optimizer=optimizer)
    
  def build_vgg(self):    
    #Builds a pretrained VGG-19 model that outputs image features extracted at the third block of the model
    vgg=VGG19(weights='imagenet')
    
    #set outputs to the output of last conv layer in block3
    vgg.outputs=[vgg.layers[9].output]
    
    img=Input(shape=self.hr_shape)
    
    #Extract image features:
    img_features=vgg(img)
    
    return(Model(img,img_features))
    
    
  def build_generator(self):
    def residual_block(layer_input,filters):
      #Residual block as described in the paper
      d=Conv2D(filters,kernel_size=3,strides=1,padding='same')(layer_input)
      d=Activation('relu')(d)
      d=BatchNormalization(momentum=0.8)(d)
      d=Conv2D(filters,kernel_size=3,strides=1,padding='same')(d)
      d=Add()([d,layer_input])
      return(d)
    def deconv2d(layer_input):
      #layers used during upsampling
      u=UpSampling2D(size=2)(layer_input)
      u=Conv2D(256,kernel_size=3,strides=1,padding='same')(u)
      u=Activation('relu')(u)
      return(u)
    
    #Low resolution image input
    img_lr=Input(shape=self.lr_shape)
    
    #Pre-residual block
    c1=Conv2D(64,kernel_size=9,strides=1,padding='same')(img_lr)
    c1=Activation('relu')(c1)
    
    #Propagate through residual blocks
    r=residual_block(c1,self.gf)
    for _ in range(self.n_residual_blocks-1):
      r=residual_block(r,self.gf)
    
    #Post residual block
    c2=Conv2D(64,kernel_size=3,strides=1,padding='same')(r)
    c2=BatchNormalization(momentum=0.8)(c2)
    c=Add()([c2,c1])
    
    #UpSampling
    u1=deconv2d(c2)
    u2=deconv2d(u1)
    
    #Generate high resolution output
    gen_hr=Conv2D(self.channels,kernel_size=9,strides=1,padding='same',activation='tanh')(u2)
    
    return(Model(img_lr,gen_hr))

  def build_discriminator(self):
    def d_block(layer_input,filters,strides=1,bn=True):
      #Discriminator layer
      d=Conv2D(filters,kernel_size=3,strides=strides,padding='same')(layer_input)
      d=LeakyReLU(alpha=0.2)(d)
      if bn:
        d=BatchNormalization(momentum=0.8)(d)
      return(d)
    
    #Input image:
    d0=Input(shape=self.hr_shape)
    
    d1=d_block(d0,self.df,bn=False)
    d2=d_block(d1,self.df,strides=2)
    d3=d_block(d2,self.df*2)
    d4=d_block(d3,self.df*2,strides=2)
    d5=d_block(d4,self.df*4)
    
    d6=d_block(d5,self.df*4,strides=2)
    d7=d_block(d6,self.df*8)
    d8=d_block(d7,self.df*8,strides=2)
    
    d9=Dense(self.df*16)(d8)
    d10=LeakyReLU(alpha=0.2)(d9)
    validity=Dense(1,activation='sigmoid')(d10)
    
    return(Model(d0,validity))
    
  def train(self,epochs,batch_size=1,sample_interval=50):
    start_time=datetime.datetime.now()
    
    for epoch in range(epochs):
      #Train discriminator:
      
      #sample images and their conditioning counterparts:
      imgs_hr,imgs_lr=self.data_loader.load_data(batch_size)
      
      #from low resolution image to high resolution version
      fake_hr=self.generator.predict(imgs_lr)
      
      valid=np.ones((batch_size,)+self.disc_patch)
      fake=np.zeros((batch_size,)+self.disc_patch)

      #Train the discriminator
      d_loss_real=self.discriminator.train_on_batch(imgs_hr,valid)
      d_loss_fake=self.discriminator.train_on_batch(fake_hr,fake)
      d_loss=0.5*np.add(d_loss_real,d_loss_fake)
      
      #Train Generator:
      
      #Sample images and their conditioning counterparts:
      imgs_hr,imgs_lr=self.data_loader.load_data(batch_size)
      
      #Generators want discriminators to label the generated images as real:
      valid=np.ones((batch_size,)+self.disc_patch)
      
      #extract ground truth features from pretrained VGG19 model
      image_features=self.vgg.predict(imgs_hr)
      
      #Train the generators:
      g_loss=self.combined.train_on_batch([imgs_lr,imgs_hr],[valid,image_features])
      
      elapsed_time=datetime.datetime.now() - start_time
      
      #Plot the progress:
      print("%d time: %s"%(epoch,elapsed_time))
      
      #if at save_interval => save generated image samples:
      if epoch%sample_interval==0:
        self.sample_images(epoch)
        
  def sample_images(self,epoch):
    os.makedirs("images/%s"%self.dataset_name,exist_ok=True)
    r,c=2,2
    
    imgs_hr,imgs_lr=self.data_loader.load_data(batch_size=2,is_testing=True)
    
    fake_hr=self.generator.predict(imgs_lr)
    
    #rescale from 0 to 1
    imgs_lr=0.5*imgs_lr+0.5
    fake_hr=0.5*fake_hr+0.5
    imgs_hr=0.5*imgs_hr+0.5
    
    #Save generated images and high resolution original images:
    titles=['Generated','Original']
    fig,axs=plt.subplots(r,c)
    cnt=0
    
    for row in range(r):
      for col,image in enumerate([fake_hr,imgs_hr]):
        axs[row,col].imshow(image[row])
        axs[row,col].set_title(titles[col])
        axs[row,col].axis('off')
      cnt+=1
    plt.savefig('images/%s/%d.png'%(self.dataset_name,epoch))
    plt.close()
                                                
    #Save low resolution images for comparison:
    for i in range(r):
      fig=plt.figure()
      plt.imshow(imgs_lr[i])
      fig.savefig('images/%s/%d_lowres%d.png'%(self.dataset_name,epoch,i))
      plt.close()
      
  def save_model(self):
    def save(model,model_name):
      model_path="saved_model/%s.json"%model_name
      weights_path="saved_model/%s_weights.hdf5"%model_name
      options={'file_arch':model_path,
               'file_weight':weights_path}
      json_string=model.to_json()
      open(options['file_arch'],'w').write(json_string)
      model.save_weights(options['file_weight'])
      
    save(self.generator,"generator")
    save(self.discriminator,"discriminator")
    
    

In [0]:
if __name__=="__main__":
  gan=SRGAN()
  gan.train(epochs=25000,batch_size=1,sample_interval=1000)
  gan.save_model()

Downloading data from https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg19_weights_tf_dim_ordering_tf_kernels.h5
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz


`imresize` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``skimage.transform.resize`` instead.
`imresize` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``skimage.transform.resize`` instead.
  'Discrepancy between trainable weights and collected trainable'


0 time: 0:01:25.586424
1 time: 0:01:29.373222
2 time: 0:01:30.492679
3 time: 0:01:31.620949
4 time: 0:01:32.755940
5 time: 0:01:33.892750
6 time: 0:01:35.015248
7 time: 0:01:36.131153
8 time: 0:01:37.255673
9 time: 0:01:38.386602
10 time: 0:01:39.516373
11 time: 0:01:40.648084
12 time: 0:01:41.761067
13 time: 0:01:42.882684
14 time: 0:01:44.015679
15 time: 0:01:45.153529
16 time: 0:01:46.281530
17 time: 0:01:47.403283
18 time: 0:01:48.514578
19 time: 0:01:49.633360
20 time: 0:01:50.741758
21 time: 0:01:51.853439
22 time: 0:01:52.954910
23 time: 0:01:54.078469
24 time: 0:01:55.212593
25 time: 0:01:56.335661
26 time: 0:01:57.471409
27 time: 0:01:58.597183
28 time: 0:01:59.725212
29 time: 0:02:00.850855
30 time: 0:02:01.984122
31 time: 0:02:03.120983
32 time: 0:02:04.249693
33 time: 0:02:05.369192
34 time: 0:02:06.506169
35 time: 0:02:07.635609
36 time: 0:02:08.767884
37 time: 0:02:09.888708
38 time: 0:02:11.017265
39 time: 0:02:12.158631
40 time: 0:02:13.287975
41 time: 0:02:14.416287
42