In [0]:
import scipy
from glob import glob
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import imageio
#以Agg作为后端
matplotlib.use('Agg')
class DataLoader():
    def __init__(self, dataset_name, img_res=(128, 128)):
        self.dataset_name = dataset_name
        self.img_res = img_res

    def load_data(self, batch_size=1, is_testing=False):
        data_type = "train" if not is_testing else "test"
        
        #path = glob('./datasets/%s/*' % (self.dataset_name))
        path = glob('./drive/My Drive/img_align_celeba/*')
        #print(path)
        #path = glob('./A/*')

        batch_images = np.random.choice(path, size=batch_size)

        imgs_hr = []
        imgs_lr = []
        for img_path in batch_images:
            img = self.imread(img_path)

            h, w = self.img_res
            low_h, low_w = int(h / 4), int(w / 4)

            img_hr = scipy.misc.imresize(img, self.img_res)
            img_lr = scipy.misc.imresize(img, (low_h, low_w))
            #img_hr=np.array(img.resize(self.img_res))
            #img_lr=np.array(img.resize(low_h, low_w))

            # If training => do random flip
            if not is_testing and np.random.random() < 0.5:
                img_hr = np.fliplr(img_hr)
                img_lr = np.fliplr(img_lr)

            imgs_hr.append(img_hr)
            imgs_lr.append(img_lr)

        imgs_hr = np.array(imgs_hr) / 127.5 - 1.
        imgs_lr = np.array(imgs_lr) / 127.5 - 1.
        #imgs_hr = imgs_hr / 127.5 - 1.
        #imgs_lr = imgs_lr/ 127.5 - 1.

        return imgs_hr, imgs_lr
    def load_datax(self, batch_size=1, is_testing=False):
        data_type = "train" if not is_testing else "test"
        
        #path = glob('./datasets/%s/*' % (self.dataset_name))
        path = glob('./drive/My Drive/SRtest/*')
        #path = glob('./drive/My Drive/C/*')

        batch_images = np.random.choice(path, size=batch_size)

        imgs_hr = []
        imgs_lr = []
        for img_path in batch_images:
            img = self.imread(img_path)

            h, w = self.img_res
            low_h, low_w = h,w

            img_hr = scipy.misc.imresize(img, (64,64))
            img_lr = scipy.misc.imresize(img, (low_h, low_w))

            # If training => do random flip
            if not is_testing and np.random.random() < 0.5:
                img_hr = np.fliplr(img_hr)
                img_lr = np.fliplr(img_lr)

            imgs_hr.append(img_hr)
            imgs_lr.append(img_lr)

        imgs_hr = np.array(imgs_hr) / 127.5 - 1.
        imgs_lr = np.array(imgs_lr) / 127.5 - 1.
        #imgs_hr = imgs_hr / 127.5 - 1.
        #imgs_lr = imgs_lr/ 127.5 - 1.

        return imgs_hr, imgs_lr


    def imread(self, path):
        #return imageio.imread(path, pilmode='RGB').astype(np.float)
        return scipy.misc.imread(path, mode='RGB').astype(np.float)


In [0]:
import warnings
 
warnings.filterwarnings('ignore')

In [0]:
from __future__ import print_function, division 
import scipy
import matplotlib
matplotlib.use('Agg')    #不画图，只写文件
from keras.datasets import mnist
#from keras_contrib.layers.normalization import InstanceNormalization
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, Concatenate
from keras.layers import BatchNormalization, Activation, ZeroPadding2D, Add
from keras.layers.advanced_activations import PReLU, LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.applications import VGG19
from keras.models import Sequential, Model
from keras.optimizers import Adam
import datetime
import matplotlib.pyplot as plt
import sys
import numpy as np
import os

#os.environ["CUDA_VISIBLE_DEVICES"] = "0"     
import keras.backend as K
from skimage import io,data

from scipy.misc import *
#os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
def merge(images, size):
  h, w= images.shape[1], images.shape[2]
  img = np.zeros((h * size[0], w * size[1]))
  for idx, image in enumerate(images):
    i = idx % size[1]
    j = idx // size[1]
    img[j*h:j*h+h, i*w:i*w+w] = image
  return img

#imsave("images/mnist_%d.png" % epoch,gen_imgs)
class SRGAN():
    def __init__(self):
        # Input shape
        self.channels = 3
        #低分辨率图片大小
        self.lr_height = 64                
        self.lr_width = 64                  
        self.lr_shape = (self.lr_height, self.lr_width, self.channels)
        #高分辨率图片大小
        self.hr_height = self.lr_height*4   
        self.hr_width = self.lr_width*4     
        self.hr_shape = (self.hr_height, self.hr_width, self.channels)
        self.latent_dim=100

        # 生成器中的 residual blocks 
        self.n_residual_blocks = 16

        optimizer = Adam(0.0002, 0.5)

        # 使用预训练的 VGG19 模型来提取高分率图片的特征，并不对 VGG19 部分进行训练
        self.vgg = self.build_vgg()
        self.vgg.trainable = False
        self.vgg.compile(loss='mse',
            optimizer=optimizer,
            metrics=['accuracy'])

        # 进行图片加载        
        self.dataset_name = 'img_align_celeba'          
        self.data_loader = DataLoader(dataset_name=self.dataset_name,
                                      img_res=(self.hr_height, self.hr_width))

        # 计算 D (PatchGAN) 的输出形状
        patch = int(self.hr_height / 2**4)
        self.disc_patch = (patch, patch, 1)

        #  G 和 D 的第一层卷积核数量
        self.gf = 64
        self.df = 64

        # 创建判别器
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='mse',
            optimizer=optimizer,
            metrics=['accuracy'])

        # 创建生成器
        self.generator = self.build_generator()

        # 喂入高低分辨率图片
        img_hr = Input(shape=self.hr_shape)
        img_lr = Input(shape=self.lr_shape)

        # 使用生成器生成高分辨率图片
        fake_hr = self.generator(img_lr)

        # 提取生成图片的特征
        fake_features = self.vgg(fake_hr)

        # 只训练生成器
        self.discriminator.trainable = False

        # 对生成的图片的质量进行判别
        validity = self.discriminator(fake_hr)

        #创建合成模型
        self.combined = Model([img_lr, img_hr], [validity, fake_features])
        self.combined.compile(loss=['binary_crossentropy', 'mse'],
                              loss_weights=[1e-3, 1],
                              optimizer=optimizer)


    def build_vgg(self):
        """
        创建预训练的 VG199 模型来提取图片特征
        """
        vgg = VGG19(weights="imagenet")
        # See architecture at: https://github.com/keras-team/keras/blob/master/keras/applications/vgg19.py
        vgg.outputs = [vgg.layers[9].output]

        img = Input(shape=self.hr_shape)

        # 提取第3块最后一层的输出
        img_features = vgg(img)

        return Model(img, img_features)

    def build_generator(self):

        def residual_block(layer_input):
            """论文中的残差模块"""
            d = Conv2D(64, kernel_size=3, strides=1, padding='same')(layer_input)
            d = Activation('relu')(d)
            d = BatchNormalization(momentum=0.8)(d)
            d = Conv2D(64, kernel_size=3, strides=1, padding='same')(d)
            d = BatchNormalization(momentum=0.8)(d)
            d = Add()([d, layer_input])
            return d

        def deconv2d(layer_input):
            """反卷积层"""
            u = UpSampling2D(size=2)(layer_input)
            u = Conv2D(256, kernel_size=3, strides=1, padding='same')(u)
            u = Activation('relu')(u)
            return u

        # 低分辨率输入
        img_lr = Input(shape=self.lr_shape)

        # 前残差块
        c1 = Conv2D(64, kernel_size=9, strides=1, padding='same')(img_lr)
        c1 = Activation('relu')(c1)

        # 前向传播
        r = residual_block(c1)
        for _ in range(self.n_residual_blocks - 1):
            r = residual_block(r)

        # 后残差块
        c2 = Conv2D(64, kernel_size=3, strides=1, padding='same')(r)
        c2 = BatchNormalization(momentum=0.8)(c2)
        c2 = Add()([c2, c1])

        # 反卷积
        u1 = deconv2d(c2)
        u2 = deconv2d(u1)

        # 生成高分辨率输出
        gen_hr = Conv2D(self.channels, kernel_size=9, strides=1, padding='same', activation='tanh')(u2)
        
        return Model(img_lr, gen_hr)

    def build_discriminator(self):

        def d_block(layer_input, filters, strides=1, bn=True):
            """判别层"""
            d = Conv2D(filters, kernel_size=3, strides=strides, padding='same')(layer_input)
            d = LeakyReLU(alpha=0.2)(d)
            if bn:
                d = BatchNormalization(momentum=0.8)(d)
            return d

        # 输入图片
        d0 = Input(shape=self.hr_shape)

        d1 = d_block(d0, self.df, bn=False)
        d2 = d_block(d1, self.df, strides=2)
        d3 = d_block(d2, self.df*2)
        d4 = d_block(d3, self.df*2, strides=2)
        d5 = d_block(d4, self.df*4)
        d6 = d_block(d5, self.df*4, strides=2)
        d7 = d_block(d6, self.df*8)
        d8 = d_block(d7, self.df*8, strides=2)

        d9 = Dense(self.df*16)(d8)
        d10 = LeakyReLU(alpha=0.2)(d9)
        validity = Dense(1, activation='sigmoid')(d10)

        return Model(d0, validity)

    def train(self, epochs, batch_size=1, sample_interval=50):

        start_time = datetime.datetime.now()
        self.generator.summary()
        self.discriminator.summary()
        for epoch in range(epochs):
        
            # ----------------------
            #  训练判别器
            # ----------------------

            # 从对应数据集进行采样
            imgs_hr, imgs_lr = self.data_loader.load_data(batch_size)


            # 从低分辨率图片生成高分辨率图片
            fake_hr = self.generator.predict(imgs_lr)

            valid = np.ones((batch_size,) + self.disc_patch)
            fake = np.zeros((batch_size,) + self.disc_patch)

            # 训练判别器 (oiginal images = real / generated = Fake)
            d_loss_real = self.discriminator.train_on_batch(imgs_hr, valid)
            d_loss_fake = self.discriminator.train_on_batch(fake_hr, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ------------------
            #  训练生成器
            # ------------------

            # 从对应数据集进行采样
            imgs_hr, imgs_lr = self.data_loader.load_data(batch_size)

            # 生成器希望判别器能将生成的图片标签为正样本
            valid = np.ones((batch_size,) + self.disc_patch)

            # 用预训练的 VGG19 提取真实图片的特征
            image_features = self.vgg.predict(imgs_hr)

            # 训练生成器
            g_loss = self.combined.train_on_batch([imgs_lr, imgs_hr], [valid, image_features])

            elapsed_time = datetime.datetime.now() - start_time
            
            print ("%d time: %s" % (epoch, elapsed_time))

            # 根据对应轮数进行保存
            if epoch % sample_interval == 0:
                self.sample_images(epoch)
                ''' 
                r, c = 10, 10   
                noise = np.random.normal(0, 1, (r * c, self.latent_dim))
                gen_imgs = self.generator.predict(noise)
              
                # Rescale images 0 - 1
                gen_imgs = 0.5 * gen_imgs + 1
                gen_imgs=gen_imgs.reshape(-1,28,28)
                gen_imgs = merge(gen_imgs[:49], [7,7])
                
                os.makedirs('images/%s' % 'noise', exist_ok=True)
                imsave("images/sr_%d.png" % epoch,gen_imgs)
                '''
                self.generator.save
                

                
        self.dataset_name = 'C'
        os.makedirs('imagesa/%s' % self.dataset_name, exist_ok=True)
        
        
        r, c = 2, 2
        imgs_hr, imgs_lr = self.data_loader.load_datax(batch_size=1, is_testing=True)
        fake_hr = self.generator.predict(imgs_hr)

        # Rescale images 0 - 1
        imgs_lr = 0.5 * imgs_lr + 0.5
        fake_hr = 0.5 * fake_hr + 0.5
        imgs_hr = 0.5 * imgs_hr + 0.5
        
        fg = plt.figure()
        plt.axis('off')
        plt.imshow(fake_hr[0])
        fg.savefig("imagesa/%s/%d.png" % ('C', epoch))
        #Fp=Image.fromarray(fake_hr[0], 'RGB')
        #Fp.save("imagesa/%s/%d.png" % (self.dataset_name, epoch))
        plt.close()
       

        # 保存生成的图片和原高清图像
       
        
        titles = ['Generated', 'Original']
        fig, axs = plt.subplots(r, c)
        cnt = 0
        for row in range(r):
            for col, image in enumerate([fake_hr, imgs_hr]):
                axs[row, col].imshow(image[row])
                axs[row, col].set_title(titles[col])
                axs[row, col].axis('off')
            cnt += 1
        fig.savefig("imagesa/%s/%d.png" % (self.dataset_name, epoch))
        plt.close()
        
        imgs_hr, imgs_lr = self.data_loader.load_data(batch_size=49, is_testing=True)
        fake_hr = self.generator.predict(imgs_lr)
        # Rescale images 0 - 1
        imgs_lr = 0.5 * imgs_lr + 0.5
        fake_hr = 0.5 * fake_hr + 0.5
        imgs_hr = 0.5 * imgs_hr + 0.5           
            
        # 保存生成的图片和原高清图像
        titles = ['Generated', 'Original']
        gen_imgs=fake_hr.reshape(-1,28,28,3)
        gen_imgs = merge(gen_imgs[:49], [7,7])
        imsave("images/sr_%d.png" % epoch,gen_imgs)
             

    def sample_images(self, epoch):
        os.makedirs('images/%s' % self.dataset_name, exist_ok=True)
        r, c = 2, 2

        imgs_hr, imgs_lr = self.data_loader.load_data(batch_size=2, is_testing=True)
        fake_hr = self.generator.predict(imgs_lr)

        # Rescale images 0 - 1
        imgs_lr = 0.5 * imgs_lr + 0.5
        fake_hr = 0.5 * fake_hr + 0.5
        imgs_hr = 0.5 * imgs_hr + 0.5

        # 保存生成的图片和原高清图像
        titles = ['Generated', 'Original']
        fig, axs = plt.subplots(r, c)
        cnt = 0
        for row in range(r):
            for col, image in enumerate([fake_hr, imgs_hr]):
                axs[row, col].imshow(image[row])
                axs[row, col].set_title(titles[col])
                axs[row, col].axis('off')
            cnt += 1
        fig.savefig("images/%s/%d.png" % (self.dataset_name, epoch))
        plt.close()
        
        os.makedirs('imagesa/%s' % 'C', exist_ok=True)
        r, c = 2, 2
        imgs_hr, imgs_lr = self.data_loader.load_datax(batch_size=2, is_testing=True)
        fake_hr = self.generator.predict(imgs_hr)

        # Rescale images 0 - 1
        imgs_lr = 0.5 * imgs_lr + 0.5
        fake_hr = 0.5 * fake_hr + 0.5
        imgs_hr = 0.5 * imgs_hr + 0.5
        
        fg = plt.figure()
        plt.axis('off')
        plt.imshow(fake_hr[0])
        fg.savefig("imagesa/%s/%d.png" % ('C', epoch))
        #Fp=Image.fromarray(fake_hr[0], 'RGB')
        #Fp.save("imagesa/%s/%d.png" % (self.dataset_name, epoch))
        plt.close()
        '''
        # 保存低分辨率图片进行对比
        for i in range(r):
            fig = plt.figure()
            plt.imshow(imgs_lr[i])
            fig.savefig('images/%s/%d_lowres%d.png' % (self.dataset_name, epoch, i))
            plt.close()
        '''

if __name__ == '__main__':
    gan = SRGAN()
    gan.train(epochs=5000, batch_size=1, sample_interval=100)

Model: "model_19"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_28 (InputLayer)           (None, 64, 64, 3)    0                                            
__________________________________________________________________________________________________
conv2d_189 (Conv2D)             (None, 64, 64, 64)   15616       input_28[0][0]                   
__________________________________________________________________________________________________
activation_77 (Activation)      (None, 64, 64, 64)   0           conv2d_189[0][0]                 
__________________________________________________________________________________________________
conv2d_190 (Conv2D)             (None, 64, 64, 64)   36928       activation_77[0][0]              
___________________________________________________________________________________________

IndexError: ignored

In [0]:
# 运行此单元格即可装载您的 Google 云端硬盘。
from google.colab import drive
drive.mount('/content/drive')


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [0]:
import os
img_dir = '/tmp/nst'
if not os.path.exists(img_dir):
    os.makedirs(img_dir)
!wget -P /tmp/nst/ https://drive.google.com/file/d/1Q7CXPq8pb4NF8k5kMKxLjzJAGpAfqfcw/view?usp=sharing

In [0]:
pip install --upgrade Pillow

Requirement already up-to-date: Pillow in /usr/local/lib/python3.6/dist-packages (6.1.0)


In [0]:
pip install scipy==1.2.1

Collecting scipy==1.2.1
[?25l  Downloading https://files.pythonhosted.org/packages/7f/5f/c48860704092933bf1c4c1574a8de1ffd16bf4fde8bab190d747598844b2/scipy-1.2.1-cp36-cp36m-manylinux1_x86_64.whl (24.8MB)
[K     |████████████████████████████████| 24.8MB 1.2MB/s 
[31mERROR: albumentations 0.1.12 has requirement imgaug<0.2.7,>=0.2.5, but you'll have imgaug 0.2.9 which is incompatible.[0m
Installing collected packages: scipy
  Found existing installation: scipy 1.3.1
    Uninstalling scipy-1.3.1:
      Successfully uninstalled scipy-1.3.1
Successfully installed scipy-1.2.1


In [0]:
imgpath = glob('/imagesa/C')
for filename in os.listdir:
  img.append(self.imread(img_path))