In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!unzip /content/drive/MyDrive/archive.zip

In [None]:
%cd /content/
!git clone https://github.com/ryankiros/skip-thoughts.git

In [None]:
!mkdir pretrained
%cd pretrained
!wget http://www.cs.toronto.edu/~rkiros/models/dictionary.txt
!wget http://www.cs.toronto.edu/~rkiros/models/utable.npy
!wget http://www.cs.toronto.edu/~rkiros/models/btable.npy
!wget http://www.cs.toronto.edu/~rkiros/models/uni_skip.npz
!wget http://www.cs.toronto.edu/~rkiros/models/uni_skip.npz.pkl
!wget http://www.cs.toronto.edu/~rkiros/models/bi_skip.npz
!wget http://www.cs.toronto.edu/~rkiros/models/bi_skip.npz.pkl

In [3]:
import os
import sys
import numpy as np
from random import shuffle
import tensorflow as tf
from keras.preprocessing.image import img_to_array,load_img
from PIL import Image
import math
import random
from collections import Counter
import nltk
import json
import h5py
import pickle
import re
import urllib.request
from mpl_toolkits.axes_grid1 import ImageGrid
import zipfile
from keras.models import Model,Sequential
from keras.layers import Input,Dense,Reshape,concatenate,Flatten,Lambda,LeakyReLU
from keras.layers.core import Activation
from keras.layers.normalization import BatchNormalization
from keras.layers.convolutional import UpSampling2D,Conv2D,MaxPooling2D,Conv2DTranspose
from keras.optimizers import Adam
from keras.initializers import TruncatedNormal,Zeros,RandomNormal,Constant
from keras import backend as K
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import time
nltk.download('punkt')
print("GPU:",tf.test.gpu_device_name(),"TF version:",tf.__version__)

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.
GPU: /device:GPU:0 TF version: 2.3.0


In [4]:
class DCGan(object):
    model_name='dc_gan'

    def __init__(self):
        self.generator=None
        self.discriminator=None
        self.model=None
        self.img_width=7
        self.img_height=7
        self.img_channels=1
        self.text_input_dim=4800
        self.random_input_dim=100
        self.config=None
    
    @staticmethod
    def get_config_file_path(model_dir_path):
        return os.path.join(model_dir_path,DCGan.model_name+'-config.npy')
    
    @staticmethod
    def get_weight_file_path(model_dir_path,model_type):
        return os.path.join(model_dir_path,DCGan.model_name+'-'+model_type+'-weights.h5')
    
    def create_model(self):
        init_img_width=4
        init_img_height=4

        random_input=Input((self.random_input_dim,))
        text_input1=Input((self.text_input_dim,))
        text_layer1=Dense(256,kernel_initializer=RandomNormal(stddev=0.02),bias_initializer=Constant(value=0))(text_input1)
        text_layer1=LeakyReLU(alpha=0.2)(text_layer1)

        merged=concatenate([random_input,text_layer1])
        generator_layer=Activation('tanh')(merged)

        generator_layer=Dense(512*init_img_width*init_img_height,kernel_initializer=RandomNormal(stddev=0.02),bias_initializer=Constant(value=0))(generator_layer)
        generator_layer=BatchNormalization()(generator_layer)
        generator_layer=LeakyReLU(alpha=0.2)(generator_layer)

        generator_layer=Reshape((init_img_height,init_img_width,512),input_shape=(128*init_img_width*init_img_height,))(generator_layer)

        generator_layer=Conv2DTranspose(256,kernel_size=5,strides=(2,2),padding='same',kernel_initializer=RandomNormal(stddev=0.02),bias_initializer=Constant(value=0))(generator_layer)
        generator_layer=BatchNormalization()(generator_layer)
        generator_layer=LeakyReLU(alpha=0.2)(generator_layer)

        generator_layer=Conv2DTranspose(128,kernel_size=5,strides=(2,2),padding='same',kernel_initializer=RandomNormal(stddev=0.02),bias_initializer=Constant(value=0))(generator_layer)
        generator_layer=BatchNormalization()(generator_layer)
        generator_layer=LeakyReLU(alpha=0.2)(generator_layer)

        generator_layer=Conv2DTranspose(64,kernel_size=5,strides=(2,2),padding='same',kernel_initializer=RandomNormal(stddev=0.02),bias_initializer=Constant(value=0))(generator_layer)
        generator_layer=BatchNormalization()(generator_layer)
        generator_layer=LeakyReLU(alpha=0.2)(generator_layer)

        generator_layer=Conv2DTranspose(self.img_channels,kernel_size=5,strides=(2,2),padding='same',kernel_initializer=RandomNormal(stddev=0.02),bias_initializer=Constant(value=0))(generator_layer)
        generator_layer=Activation('tanh')(generator_layer)

        generator_layer=Lambda(lambda x:x/2.)(generator_layer)
        generator_output=Lambda(lambda x:x+0.5)(generator_layer)

        self.generator=Model([random_input,text_input1],generator_output, name='generator')
        g_optim=Adam(lr=0.0002,beta_1=0.5)

        self.generator.compile(loss='binary_crossentropy',optimizer=g_optim)
        self.generator.summary()
        print()

        #discriminator
        text_input2=Input((self.text_input_dim,))
        text_layer2=Dense(256,kernel_initializer=RandomNormal(stddev=0.02),bias_initializer=Constant(value=0))(text_input2)

        img_input2=Input((self.img_height,self.img_width,self.img_channels))

        img_layer2=Conv2D(64,kernel_size=5,padding='same',strides=(2,2),kernel_initializer=TruncatedNormal(stddev=0.02),bias_initializer=Constant(value=0))(img_input2)
        img_layer2=LeakyReLU(alpha=0.2)(img_layer2)

        img_layer2=Conv2D(128,kernel_size=5,padding='same',strides=(2,2),kernel_initializer=TruncatedNormal(stddev=0.02),bias_initializer=Constant(value=0))(img_layer2)
        img_layer2=BatchNormalization()(img_layer2)
        img_layer2=LeakyReLU(alpha=0.2)(img_layer2)

        img_layer2=Conv2D(256,kernel_size=5,padding='same',strides=(2,2),kernel_initializer=TruncatedNormal(stddev=0.02),bias_initializer=Constant(value=0))(img_layer2)
        img_layer2=BatchNormalization()(img_layer2)
        img_layer2=LeakyReLU(alpha=0.2)(img_layer2)
        
        img_layer2=Conv2D(512,kernel_size=5,padding='same',strides=(2,2),kernel_initializer=TruncatedNormal(stddev=0.02),bias_initializer=Constant(value=0))(img_layer2)
        img_layer2=BatchNormalization()(img_layer2)
        img_layer2=LeakyReLU(alpha=0.2)(img_layer2)

        text_layer2=Lambda(K.expand_dims,arguments={'axis':1})(text_layer2)
        text_layer2=Lambda(K.expand_dims,arguments={'axis':2})(text_layer2)
        text_layer2=Lambda(K.tile,arguments={'n':(1,4,4,1)})(text_layer2)

        img_layer2=concatenate([img_layer2,text_layer2],axis=3)
        
        img_layer2=Conv2D(512,kernel_size=5,padding='same',strides=(1,1),kernel_initializer=TruncatedNormal(stddev=0.02),bias_initializer=Constant(value=0))(img_layer2)
        img_layer2=BatchNormalization()(img_layer2)
        img_layer2=LeakyReLU(alpha=0.2)(img_layer2)

        img_layer2=Flatten()(img_layer2)

        discriminator_layer=Dense(1,kernel_initializer=RandomNormal(stddev=0.02),bias_initializer=Constant(value=0))(img_layer2)
        discriminator_output=Activation('sigmoid')(discriminator_layer)

        self.discriminator=Model([img_input2,text_input2],discriminator_output)
        d_optim=Adam(learning_rate=0.0001,beta_1=0.5)
        self.discriminator.compile(loss='binary_crossentropy',optimizer=d_optim)
        self.discriminator.summary()

        self.discriminator.trainable=False
        model_output=self.discriminator([self.generator.output,text_input1])

        self.model=Model([random_input,text_input1],model_output, name='discriminator')

        self.model.compile(loss='binary_crossentropy',optimizer=g_optim)
    
    def load_batch(self,batch_idx,batch_size,image_label_pairs):
        image_label_pair_batch=image_label_pairs[batch_idx*batch_size:(batch_idx+1)*batch_size]
        image_files_batch=[]
        wrong_image_batch=np.zeros((batch_size,self.img_height,self.img_width,self.img_channels))
        real_image_batch=np.zeros((batch_size,self.img_height,self.img_width,self.img_channels))
        noise=np.zeros((batch_size,self.random_input_dim))
        skipthought_batch=np.zeros((batch_size,self.text_input_dim))

        for i in range(batch_size):
            normalised_img=image_label_pair_batch[i][0]
            real_image_batch[i,:,:,:]=normalised_img

            idx=random.randint(0,len(image_label_pairs)-1)
            wrong_img=image_label_pairs[idx][0]
            wrong_image_batch[i,:,:,:]=wrong_img

            skipthought_batch[i,:]=image_label_pair_batch[i][1]
            noise[i,:]=np.random.uniform(-1,1,self.random_input_dim)
            image_files_batch.append(image_label_pairs[i][2])
        
        return real_image_batch,wrong_image_batch,noise,skipthought_batch,image_files_batch
        
    def fit(self,model_dir_path,image_label_pairs,epochs=None,batch_size=None,snapshot_dir_path=None):
        if epochs is None:
            epochs=100
        
        if batch_size is None:
            batch_size=128
        
        self.config=dict()
        self.config['img_width']=self.img_width
        self.config['img_height']=self.img_height
        self.config['random_input_dim']=self.random_input_dim
        self.config['text_input_dim']=self.text_input_dim
        self.config['img_channels']=self.img_channels

        config_file_path=DCGan.get_config_file_path(model_dir_path)

        np.save(config_file_path,self.config)

        n_batches=image_label_pairs.shape[0]//batch_size
        d_loss_list=[]
        g_loss_list=[]
        for epoch in range(epochs):
            epoch_d_loss=0
            epoch_g_loss=0
            start=time.time()
            for batch_idx in range(n_batches):
                real_images_batch,wrong_images_batch,noise,skipthought_batch,image_files_batch=self.load_batch(batch_idx,batch_size,image_label_pairs)

                fake_images_batch=self.generator.predict([noise,skipthought_batch],verbose=0)

                self.discriminator.trainable=True
                if (batch_idx+1)%4==0:
                    d_loss1=self.discriminator.train_on_batch([fake_images_batch,skipthought_batch],np.array([1]*batch_size))
                    d_loss2=self.discriminator.train_on_batch([wrong_images_batch,skipthought_batch],np.array([0]*batch_size))
                    d_loss3=self.discriminator.train_on_batch([fake_images_batch,skipthought_batch],np.array([0]*batch_size))
                
                else:
                    d_loss1=self.discriminator.train_on_batch([real_images_batch,skipthought_batch],np.array([1]*batch_size))
                    d_loss2=self.discriminator.train_on_batch([wrong_images_batch,skipthought_batch],np.array([0]*batch_size))
                    d_loss3=self.discriminator.train_on_batch([fake_images_batch,skipthought_batch],np.array([0]*batch_size))
                self.discriminator.trainable=False

                d_loss=d_loss1+0.5*(d_loss2+d_loss3)
                
                g_loss=self.model.train_on_batch([noise,skipthought_batch],np.array([1]*batch_size))

                epoch_d_loss+=d_loss
                epoch_g_loss+=g_loss

                if (batch_idx+1)%100==0 and snapshot_dir_path is not None:
                    generated_images=self.generator.predict([noise,skipthought_batch],verbose=0)
                    self.save_snapshots(generated_images,snapshot_dir_path=snapshot_dir_path,epoch=epoch,batch_idx=batch_idx)

            d_loss_list.append(epoch_d_loss/n_batches)
            g_loss_list.append(epoch_g_loss/n_batches)
            print('Epoch: '+str(epoch+1)+'/'+str(epochs)+' epoch_duration: '+str(time.time()-start)+' discriminator_loss: '+str(epoch_d_loss/n_batches)+' generator_loss: '+str(epoch_g_loss/n_batches))
            if (epoch+1)%5==0 or (epoch+1)==epochs:
                self.generator.save_weights(DCGan.get_weight_file_path(model_dir_path,'generator'),True)
                self.discriminator.save_weights(DCGan.get_weight_file_path(model_dir_path,'discriminator'),True)
                with h5py.File('losses_list.h5','w') as out:
                    out.create_dataset("discriminator",data=np.array(d_loss_list))
                    out.create_dataset("generator",data=np.array(g_loss_list))
    
    def generate_image_from_text(self,caption,skipthought_model):
        encoded_text=skipthoughts.encode(skipthought_model,caption)
        noise=np.random.uniform(-1,1,self.random_input_dim)
        noise=np.expand_dims(noise,axis=0)
        generated_image=self.generator.predict([noise,encoded_text],verbose=0)
        print('Caption: '+caption[0])
        plt.imshow(generated_image[0])

    def save_snapshots(self,generated_images,snapshot_dir_path,epoch,batch_idx):
        plot_batch(generated_images,DCGan.model_name,epoch,batch_idx,snapshot_dir_path)

In [None]:
from tqdm import tqdm

def plot_batch(generated_images,model_name,epoch,batch_idx,snapshot_dir_path):
    fig=plt.figure(1)
    grid=ImageGrid(fig,111,nrows_ncols=(2,8),axes_pad=0.05)
    size=2*8
    for i in range(size):
        grid[i].axis('off')
        grid[i].imshow((generated_images[i]*255).astype(np.uint8))
    plt.savefig(os.path.join(snapshot_dir_path,model_name+'-'+str(epoch)+'-'+str(batch_idx)+'.png'))

def img_from_normalised_img(normalised_img):
    image=normalised_img.astype(float)*255
    image=image.astype('uint8')
    return image

def load_normalised_img_and_cap(img_path, caption_file):
    IMAGES_COUNT = 50000
    imgs = []
    names = []
    for pic_file in tqdm(os.listdir(img_path)[:IMAGES_COUNT]):
      pic = Image.open(img_path + pic_file).resize((img_width, img_height))
      pic.thumbnail((img_width, img_height), Image.ANTIALIAS)
      imgs.append(np.uint8(pic))
      names.append(pic_file)
    caps = {}
    with open(caption_file,'r') as c:
        for line in c.readlines():
            img_name = line.strip().split('\t')[0]
            cap = " ".join(line.split('\t')[1].split('|'))
            if cap == "" or cap == " ":
                cap = "This is a person with a face and nothing else."
            caps[img_name] = cap
    # with open(caption_file, 'r') as c:
    #     for line in c.readlines()[1:]:
    #         img_name = line.strip().split(' ')[0]
    #         cap = np.array(line.strip().split(' ')[1:])
    #         caps[img_name] = cap
    # print(caps['000001.jpg'].shape)
    captions=[]
    for i in range(len(names)):
        captions.append(caps.setdefault(names[i], "This is a person with a face and nothing else."))
    vectors = captions
    if not os.path.exists('/content/drive/MyDrive/text2face_gan/skipthought_vectors_50k.pkl'):
        model=skipthoughts.load_model()
        vectors=skipthoughts.encode(model,captions)
        with open('/content/drive/MyDrive/text2face_gan/skipthought_vectors.pkl','wb') as f:
            pickle.dump(vectors,f)
    else:
        with open('/content/drive/MyDrive/text2face_gan/skipthought_vectors_50k.pkl','rb') as f:
            vectors=pickle.load(f)

    result=[]
    for i in range(len(imgs)):
        result.append([imgs[i],vectors[i],names[i]])
    
    return np.array(result)

def resize(img,input_shape):
  height,width=input_shape
  return cv2.resize(img,(width,height))

In [None]:
img_width=80
img_height=96
img_channels=3

WIDTH = 80
HEIGHT = 96

In [None]:
%cd /content/

/content


In [None]:
import skipthoughts

seed=2020
np.random.seed(seed)
model_dir_path='models'
img_path='/content/img_align_celeba/img_align_celeba/'
caption_file = '/content/caps.txt'
image_label_pairs=load_normalised_img_and_cap(img_path, caption_file)
shuffle(image_label_pairs)

100%|██████████| 50000/50000 [01:40<00:00, 495.39it/s]


In [5]:
dcgan=DCGan()
dcgan.img_width=64
dcgan.img_height=64
dcgan.img_channels=3
dcgan.random_input_dim=100
dcgan.text_input_dim=4800
batch_size=64
epochs=200
dcgan.create_model()

Model: "generator"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_2 (InputLayer)            [(None, 4800)]       0                                            
__________________________________________________________________________________________________
dense (Dense)                   (None, 256)          1229056     input_2[0][0]                    
__________________________________________________________________________________________________
input_1 (InputLayer)            [(None, 100)]        0                                            
__________________________________________________________________________________________________
leaky_re_lu (LeakyReLU)         (None, 256)          0           dense[0][0]                      
__________________________________________________________________________________________

In [None]:
with tf.device('/gpu:0'):
  dcgan.fit(model_dir_path=model_dir_path,
            image_label_pairs=image_label_pairs,
            snapshot_dir_path='snapshots',
            batch_size=batch_size,
            epochs=epochs)

Epoch: 1/200 epoch_duration: 339.9365015029907 discriminator_loss: 1.862405559541741 generator_loss: 0.8810346476094518
Epoch: 2/200 epoch_duration: 356.5037784576416 discriminator_loss: 1.6974921808715209 generator_loss: 0.9056111268136321
Epoch: 3/200 epoch_duration: 375.8615710735321 discriminator_loss: 1.6480528653316706 generator_loss: 0.9252977371979004
Epoch: 4/200 epoch_duration: 395.00529885292053 discriminator_loss: 1.6176933720707893 generator_loss: 0.9273769383851759
Epoch: 5/200 epoch_duration: 418.76893377304077 discriminator_loss: 1.5996374668274433 generator_loss: 0.9202556036956484
Epoch: 6/200 epoch_duration: 450.0394198894501 discriminator_loss: 1.587459905428404 generator_loss: 0.9210686414250949
Epoch: 7/200 epoch_duration: 482.3057882785797 discriminator_loss: 1.5788968922367628 generator_loss: 0.9236420818846601
Epoch: 8/200 epoch_duration: 509.61283898353577 discriminator_loss: 1.5730991708389013 generator_loss: 0.9382573537118304
Epoch: 9/200 epoch_duration: 54

In [None]:
dcgan.generator.load_weights(model_dir_path+'/dc_gan-generator-weights.h5')
skipthought_model=skipthoughts.load_model()

In [None]:
caption=['The woman has oval face. She has straight hair which is brown in colour. The smiling, young attractive woman has heavy makeup. She’s wearing lipstick.']
dcgan.generate_image_from_text(caption,skipthought_model)

In [None]:
!rm -r snap*