In [None]:
from tensorflow.keras.models import Sequential, Model
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, Concatenate, MaxPooling2D
from keras.layers import BatchNormalization, Activation, ZeroPadding2D, Add
from keras.layers.advanced_activations import PReLU, LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.applications import VGG19
from keras.models import Sequential, Model
import matplotlib.pyplot as plt
from keras.optimizers import Adam

class GAN:
    """Image-to-image GAN (SRGAN-style) trained with a VGG19 perceptual loss.

    Three networks are built in ``__init__``:

    * ``generator`` -- a convolutional encoder/decoder mapping an image from
      ``x`` to an image of the same spatial size with 3 sigmoid channels;
    * ``discriminator`` -- a patch discriminator whose output is a
      ``self.disc_patch``-shaped map of sigmoid scores per image;
    * ``vgg`` -- ImageNet VGG19 truncated at ``block3_conv3``, used as a fixed
      feature extractor.

    ``self.gan`` chains generator -> (discriminator, vgg) with the
    discriminator frozen, and is trained on a weighted sum of an adversarial
    binary cross-entropy (weight 1e-3) and an MSE on VGG features (weight 1).

    Args:
        x: Indexable array of input images; ``x[0].shape`` becomes the input
            shape of every network. Height and width must be divisible by 4
            for the generator (two 2x pools, two 2x upsamples) to return an
            image of the same size.
        y: Target images aligned index-for-index with ``x``.
    """

    def __init__(self, x, y):
        self.input_shape = x[0].shape
        self.x = x
        self.y = y
        # Number of residual blocks in the generator.
        # NOTE(review): not used by the current encoder/decoder generator.
        self.n_residual_blocks = 16

        # Each compiled model gets its own Adam instance so optimizer state
        # (moment estimates, iteration counter) is not shared between models.
        def make_optimizer():
            return Adam(0.0002, 0.5)

        # Discriminator output map size: four stride-2 'same'-padded convs,
        # each mapping n -> ceil(n / 2), i.e. ceil(n / 16) overall
        # (8x8 for 128x128 inputs, the previously hard-coded value).
        rows, cols = self.input_shape[0], self.input_shape[1]
        self.disc_patch = (-(-rows // 16), -(-cols // 16), 1)

        # Fixed VGG19 feature extractor for the perceptual (content) loss.
        self.vgg = self.build_vgg()
        self.vgg.trainable = False
        self.vgg.compile(loss='mse',
                         optimizer=make_optimizer(),
                         metrics=['accuracy'])

        # Build and compile the discriminator while it is still trainable so
        # its own train_on_batch calls update its weights.
        self.discriminator = self._get_discriminator()
        self.discriminator.compile(loss='mse',
                                   optimizer=make_optimizer(),
                                   metrics=['accuracy'])

        self.generator = self._get_generator()

        img_sh = Input(shape=self.input_shape)
        # NOTE(review): img_d feeds no output; it remains a model input only
        # because the training loops pass [input_batch, target_batch].
        img_d = Input(shape=self.input_shape)

        fake_d = self.generator(img_sh)
        # Features of the generated image, compared against target features.
        fake_features = self.vgg(fake_d)

        # Freeze the discriminator inside the combined model only. Keras
        # applies `trainable` changes when a model is compiled, so the
        # discriminator compiled above keeps training on its own.
        self.discriminator.trainable = False
        validity = self.discriminator(fake_d)

        self.gan = Model([img_sh, img_d], [validity, fake_features])
        self.gan.compile(loss=['binary_crossentropy', 'mse'],
                         loss_weights=[1e-3, 1],
                         optimizer=make_optimizer())

    def build_vgg(self):
        """Return ImageNet VGG19 (no classifier head) truncated at block3_conv3.

        ``block3_conv3`` is ``vgg.layers[9]`` in the Keras VGG19 layout -- the
        third of the four convolutions in block 3 -- so selecting it by name
        keeps the previous behaviour while making the choice explicit.
        """
        vgg = VGG19(weights="imagenet", include_top=False, input_shape=self.input_shape)
        return Model(inputs=vgg.inputs, outputs=vgg.get_layer("block3_conv3").output)

    def _get_generator(self):
        """Build the convolutional encoder/decoder generator.

        Layout: two 32-channel convs -> 2x max-pool -> two 64-channel convs ->
        2x max-pool -> two 128-channel convs -> 2x upsample -> two 64-channel
        convs -> 2x upsample -> two 32-channel convs -> 1x1 conv to 3 channels
        with a sigmoid. There are no skip connections.
        """

        def create_block(tensor, chs):
            """Apply two 3x3 'same' convs, each followed by ReLU then BatchNorm."""
            for _ in range(2):
                tensor = Conv2D(chs, 3, padding="same")(tensor)
                tensor = Activation("relu")(tensor)
                tensor = BatchNormalization()(tensor)
            return tensor

        inputs = Input(self.input_shape)

        # Encoder
        block1 = create_block(inputs, 32)
        x = MaxPooling2D(2)(block1)
        block2 = create_block(x, 64)
        x = MaxPooling2D(2)(block2)

        # Bottleneck
        middle = create_block(x, 128)

        # Decoder
        up1 = UpSampling2D((2, 2))(middle)
        block3 = create_block(up1, 64)
        up2 = UpSampling2D((2, 2))(block3)
        block4 = create_block(up2, 32)

        # Output head. This previously read from `up2`, which silently skipped
        # the last decoder block (its layers were built but never used).
        x = Conv2D(3, 1)(block4)
        output = Activation("sigmoid")(x)

        return Model(inputs, output)

    def _get_discriminator(self):
        """Build the patch discriminator.

        Eight 3x3 conv blocks (64, 64, 128, 128, 256, 256, 512, 512 filters;
        every second one has stride 2) are followed by Dense(1024) +
        LeakyReLU and Dense(1, sigmoid). Keras ``Dense`` acts on the last axis
        of a rank-4 tensor, so the output is a per-location score map of shape
        ``self.disc_patch`` rather than a single scalar.
        """

        def d_block(layer_input, filters=64, strides=1, bn=True):
            """Apply Conv -> LeakyReLU(0.2) -> optional BatchNorm(momentum=0.8)."""
            d = Conv2D(filters, kernel_size=3, strides=strides, padding='same')(layer_input)
            d = LeakyReLU(alpha=0.2)(d)
            if bn:
                d = BatchNormalization(momentum=0.8)(d)
            return d

        d0 = Input(shape=self.input_shape)

        d1 = d_block(d0, 64, bn=False)
        d2 = d_block(d1, 64, strides=2)
        d3 = d_block(d2, 64 * 2)
        d4 = d_block(d3, 64 * 2, strides=2)
        d5 = d_block(d4, 64 * 4)
        d6 = d_block(d5, 64 * 4, strides=2)
        d7 = d_block(d6, 64 * 8)
        d8 = d_block(d7, 64 * 8, strides=2)

        d9 = Dense(64 * 16)(d8)
        d10 = LeakyReLU(alpha=0.2)(d9)
        validity = Dense(1, activation='sigmoid')(d10)

        return Model(d0, validity)

    def train(self, epochs, batch_size=128, max_disc_steps=100):
        """Alternate accuracy-gated discriminator training with generator updates.

        Each epoch samples ``batch_size`` aligned (x, y) pairs (with
        replacement). The discriminator is trained on real targets (label 1)
        vs. generated images (label 0) until its accuracy reaches 0.9 or
        ``max_disc_steps`` steps have run. Then one combined-model update
        trains the generator, using the same targets as ``train_new``.

        The previous version could not run against the two-input/two-output
        combined model and could loop forever. The combined model reports no
        accuracy metric, so the generator step is no longer accuracy-gated.

        Args:
            epochs: Number of epochs (numbered from 1).
            batch_size: Number of pairs sampled per epoch.
            max_disc_steps: Upper bound on discriminator steps per epoch.
        """
        for e in range(1, epochs + 1):
            print('-' * 10, 'Epoch %s' % e, '-' * 10)

            batch = np.random.randint(0, self.x.shape[0], size=batch_size)
            image_noise_batch = self.x[batch]
            image_batch = self.y[batch]

            generated = self.generator.predict(image_noise_batch)
            X = np.concatenate([image_batch, generated])

            # Labels must match the discriminator's patch-map output shape.
            y_dis = np.zeros((2 * batch_size,) + self.disc_patch)
            y_dis[:batch_size] = 1

            for _ in range(max(1, max_disc_steps)):
                disc_loss, accuracy = self.discriminator.train_on_batch(X, y_dis)
                print('Discriminator accuracy:', accuracy)
                if accuracy >= 0.9:
                    break

            valid = np.ones((batch_size,) + self.disc_patch)
            image_features = self.vgg.predict(image_batch)
            gan_loss = self.gan.train_on_batch([image_noise_batch, image_batch],
                                               [valid, image_features])

            print('Discriminator loss:', disc_loss,
                  'Discriminator accuracy:', accuracy,
                  'GAN loss:', gan_loss)

            if e == 1 or e % 5 == 0:
                self.plot_images(e)

    def train_new(self, epoch_num, batch_size=128):
        """Train discriminator and generator for ``epoch_num`` iterations.

        Each iteration samples ``batch_size`` aligned pairs from
        ``self.x``/``self.y`` (with replacement), then:

        1. trains the discriminator once on real targets (label 1) and once
           on generated images (label 0);
        2. trains the generator through ``self.gan`` so the discriminator
           labels its output as real and its VGG features match the targets'.

        Losses are printed and sample images plotted every 50 iterations.
        """
        # Patch-map labels; constant across iterations and not modified.
        valid = np.ones((batch_size,) + self.disc_patch)
        fake = np.zeros((batch_size,) + self.disc_patch)

        for epoch in range(epoch_num):
            # ---------------------
            #  Train discriminator
            # ---------------------
            batch = np.random.randint(0, self.x.shape[0], size=batch_size)
            image_sh = np.array(self.x[batch])
            image_d = np.array(self.y[batch])

            fake_d = self.generator.predict(image_sh)

            d_loss_real = self.discriminator.train_on_batch(image_d, valid)
            d_loss_fake = self.discriminator.train_on_batch(fake_d, fake)
            d_loss = 0.5 * np.add(d_loss_fake, d_loss_real)

            # -----------------
            #  Train generator
            # -----------------
            # VGG features of the ground-truth images are the content targets.
            image_features = self.vgg.predict(image_d)
            # The generator wants the discriminator to label its output as real.
            g_loss = self.gan.train_on_batch([image_sh, image_d], [valid, image_features])

            if epoch % 50 == 0:
                print('d_loss_real:', d_loss_real,
                      'd_loss_fake:', d_loss_fake,
                      'd_loss:', d_loss,
                      'GAN loss:', g_loss)
                self.plot_images(epoch)

    def plot_images(self, epoch):
        """Show two random inputs (left) next to their generator outputs (right).

        Args:
            epoch: Training step, shown as the figure title.
        """
        image_noise_batch = self.x[np.random.randint(0, self.x.shape[0], size=2)]
        generated_images = self.generator.predict(image_noise_batch)

        fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(16, 16))
        fig.suptitle('Epoch %s' % epoch)

        for row in range(2):
            ax[row][0].imshow(image_noise_batch[row])
            ax[row][1].imshow(generated_images[row])
        plt.show()

import numpy as np
from keras.preprocessing.image import load_img

# Side length, in pixels, that load_reshape_img resizes every image to
# (image_shape x image_shape) before cropping it with im_crop.
image_shape = 256

def load_reshape_img(fname):
    """Load ``fname``, resize it to a square, scale to [0, 1] and crop it.

    The image is resized to ``image_shape`` x ``image_shape`` pixels,
    converted to float32, divided by 255 and passed through ``im_crop``.
    """
    side = (image_shape, image_shape)
    pixels = np.asarray(load_img(fname, target_size=side)).astype(np.float32)
    return im_crop(pixels / 255.)

def im_crop(image, top=12, left=72, size=128):
    """Return the ``size`` x ``size`` window of ``image`` starting at (top, left).

    The defaults reproduce the fixed crop applied by ``load_reshape_img``:
    rows 12:140 and columns 72:200. Any trailing axes (e.g. channels) are
    kept. For a NumPy array the result is a view, not a copy.

    Args:
        image: Array indexed as (rows, cols, ...).
        top: First row of the window.
        left: First column of the window.
        size: Height and width of the window.
    """
    return image[top:top + size, left:left + size]

def generate_x(main_dir="/kaggle/input/mayadatasetv2/input/train/train_x/x/"):
    """Load the training input images into one array.

    Files are named ``Female_<idx>.<angle:02>_scene<scene>.jpg``. They are
    read for idx 1-7, scene 1-10 and angle 1-12, nested in that order (idx
    outermost). This is the same order used for the targets, so inputs and
    targets line up index-for-index.

    Args:
        main_dir: Directory holding the images. File names are appended
            directly, so it must end with a path separator.

    Returns:
        np.ndarray stacking ``load_reshape_img`` of every file.
    """
    images = [
        load_reshape_img("{}Female_{}.{:02}_scene{}.jpg".format(main_dir, idx, angle, scene))
        for idx in range(1, 8)
        for scene in range(1, 11)
        for angle in range(1, 13)
    ]
    return np.array(images)

def generate_y(main_dir="/kaggle/input/mayadatasetv2/input/train/train_y/y/"):
    """Load the training target images into one array.

    Files are named ``Female_<idx>.<angle:02>_scene<scene>.jpg``. They are
    read for idx 1-7, scene 1-10 and angle 1-12, nested in that order (idx
    outermost). This is the same order used for the inputs, so targets and
    inputs line up index-for-index.

    Args:
        main_dir: Directory holding the images. File names are appended
            directly, so it must end with a path separator.

    Returns:
        np.ndarray stacking ``load_reshape_img`` of every file.
    """
    images = [
        load_reshape_img("{}Female_{}.{:02}_scene{}.jpg".format(main_dir, idx, angle, scene))
        for idx in range(1, 8)
        for scene in range(1, 11)
        for angle in range(1, 13)
    ]
    return np.array(images)




# Build the model on the full training set (inputs x, targets y) and train it.
gan = GAN(x=generate_x(), y=generate_y())
gan.train_new(epoch_num=3000)

In [None]:
def load_reshape_img_real(fname):
    """Load a real photo resized to 128x128 and scaled to float32 in [0, 1].

    No crop is applied; the whole photo is resized.
    """
    pixels = np.asarray(load_img(fname, target_size=(128, 128))).astype('float32')
    return pixels / 255.

def generate_real_test():
    """Load the real photos Capture0.JPG .. Capture7.JPG into one array."""
    base_dir = "../input/realspecularimages/real_specular_images/"
    photos = [load_reshape_img_real(base_dir + "Capture" + str(n) + ".JPG") for n in range(8)]
    return np.array(photos)


x = generate_real_test()

fake_d = gan.generator.predict(x)

# Left column: real photo; right column: generator output.
# Only the first 6 of the loaded photos are shown.
fig, ax = plt.subplots(nrows=6, ncols=2, figsize=(16,16))
for row in range(6):
    ax[row][0].imshow(x[row])
    ax[row][1].imshow(fake_d[row])

plt.show()


In [None]:
def generate_test(main_x_dir="../input/mayadatasetv2/input/test/test_x/x/",
                  main_y_dir="../input/mayadatasetv2/input/test/test_y/y/"):
    """Load ten hand-picked test samples as aligned (inputs, targets) arrays.

    Each sample is a ``(female, angle, scene)`` triple that selects the file
    ``Female_<female>.<angle:02>_scene<scene>.jpg``. The file is loaded with
    ``load_reshape_img`` from both directories.

    Args:
        main_x_dir: Directory of input images (must end with a separator).
        main_y_dir: Directory of target images (must end with a separator).

    Returns:
        Tuple ``(x, y)`` of np.ndarrays, one entry per sample, in the order
        listed below.
    """
    samples = [(18, 6, 4), (18, 8, 6), (21, 6, 6), (22, 2, 7), (19, 8, 2),
               (22, 10, 2), (21, 8, 1), (19, 8, 7), (21, 11, 9), (20, 6, 1)]
    image_x_df = []
    image_y_df = []
    for female, angle, scene in samples:
        image_path = "Female_{}.{:02}_scene{}.jpg".format(female, angle, scene)
        image_x_df.append(load_reshape_img(main_x_dir + image_path))
        image_y_df.append(load_reshape_img(main_y_dir + image_path))
    return np.array(image_x_df), np.array(image_y_df)

x, y = generate_test()
recons = gan.generator.predict(x)

# Five figures, two samples each: input (left) vs. reconstruction (right).
for first in range(0, 10, 2):
    fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(16,16))
    for row in range(2):
        ax[row][0].imshow(x[first + row])
        ax[row][1].imshow(recons[first + row])

plt.show()