In [1]:
pip install torchfile

Collecting torchfile
  Downloading torchfile-0.1.0.tar.gz (5.2 kB)
  Preparing metadata (setup.py) ... [?25l- \ done
[?25hBuilding wheels for collected packages: torchfile
  Building wheel for torchfile (setup.py) ... [?25l- \ | done
[?25h  Created wheel for torchfile: filename=torchfile-0.1.0-py3-none-any.whl size=5692 sha256=cab3eaa3e852b89d8730766d242de5112f3a9fa3f4a2679117fd0bc8858dec79
  Stored in directory: /root/.cache/pip/wheels/c7/e9/87/1c51daf8e468d5c14931f8ac3344880f903ba96b063675cac2
Successfully built torchfile
Installing collected packages: torchfile
Successfully installed torchfile-0.1.0
Note: you may need to restart the kernel to use updated packages.


In [2]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import torch
from tensorboard import summary
from six.moves import range
import PIL
from PIL import Image

from torch.autograd import Variable

from torch.nn import init
import numpy as np 
import pandas as pd 

import os
print(os.listdir("../input"))

import torch.utils.data as data
import os.path
import pickle
import random

import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
from torchvision.transforms import ToPILImage
from torchvision.models import inception_v3
import matplotlib.pyplot as plt

import argparse
import sys
import torchfile

import errno
import torch.nn as nn
import torchvision.utils as vutils

import torch.optim as optim
import time


import torch.nn.parallel

['stage1generator', 'text-to-image-cub-200-2011']


In [3]:
def load_class_ids_filenames(class_id_path, filename_path):
    with open(class_id_path, 'rb') as file:
        class_id=pickle.load(file, encoding='latin1')

    with open(filename_path, 'rb') as file:
        filename=pickle.load(file, encoding='latin1')

    return class_id, filename

def load_text_embeddings(text_embeddings):
    with open(text_embeddings, 'rb') as file:
        embeds=pickle.load(file, encoding='latin1')
        embeds=np.array(embeds)

    return embeds

def load_bbox(data_path):
    bbox_path=data_path+'/bounding_boxes.txt'
    image_path=data_path+'/images.txt'
    bbox_df=pd.read_csv(bbox_path, delim_whitespace=True, header=None).astype(int)
    filename_df=pd.read_csv(image_path, delim_whitespace=True, header=None)

    filenames=filename_df[1].tolist()
    bbox_dict={i[:-4]:[] for i in filenames[:2]}

    for i in range(0, len(filenames)):
        bbox=bbox_df.iloc[i][1:].tolist()
        dict_key=filenames[i][:-4]
        bbox_dict[dict_key]=bbox

    return bbox_dict

def load_images(image_path, bounding_box, size, transform):
    image=Image.open(image_path).convert('RGB')
    w, h=image.size
    if bounding_box is not None:
        r=int(np.maximum(bounding_box[2], bounding_box[3])*0.75)
        c_x=int((bounding_box[0]+bounding_box[2])/2)
        c_y=int((bounding_box[1]+bounding_box[3])/2)
        y1=np.maximum(0, c_y-r)
        y2=np.minimum(h, c_y+r)
        x1=np.maximum(0, c_x-r)
        x2=np.minimum(w, c_x+r)
        image=image.crop([x1, y1, x2, y2])

    image=image.resize(size, PIL.Image.BILINEAR)
    if transform is not None:
        image=transform(image)
    return image

def load_data(size, transform):
    """Loads the Dataset.
    """
    data_dir="/kaggle/input/text-to-image-cub-200-2011/CUB-200-2011"
    train_dir=data_dir+"/train"
    test_dir=data_dir+"/test"
    embeddings_path_train=train_dir+"/char-CNN-RNN-embeddings.pickle"
    embeddings_path_test=test_dir+"/char-CNN-RNN-embeddings.pickle"
    filename_path_train=train_dir+"/filenames.pickle"
    filename_path_test=test_dir+"/filenames.pickle"
    class_id_path_train=train_dir+"/class_info.pickle"
    class_id_path_test=test_dir+"/class_info.pickle"
    if(args.TEST_OR_TRAIN=='train'):
        class_id, filenames=load_class_ids_filenames(class_id_path_train, filename_path_train)
        embeddings=load_text_embeddings(embeddings_path_train)
    else:
        class_id, filenames=load_class_ids_filenames(class_id_path_test, filename_path_test)
        embeddings=load_text_embeddings(embeddings_path_test)
    bbox_dict=load_bbox(data_dir)

    x, y, embeds=[], [], []

    for i, filename in enumerate(filenames):
        bbox=bbox_dict[filename]

        try:    
            image_path=f'{data_dir}/images/{filename}.jpg'
            image=load_images(image_path, bbox, size, transform)
            e=embeddings[i, :, :]
            embed_index=np.random.randint(0, e.shape[0]-1)
            embed=e[embed_index, :]

            x.append(np.array(image))
            y.append(class_id[i])
            embeds.append(embed)

        except Exception as e:
            print(f'{e}')
    
    x=np.array(x)
    y=np.array(y)
    embeds=np.array(embeds)
    
    return x, y, embeds

In [4]:

def KL_loss(mu, logvar):
    KLD_element=mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
    KLD=torch.mean(KLD_element).mul_(-0.5)
    return KLD

def calculate_inception_score(logits):

    probs=torch.nn.functional.softmax(logits, dim=1)

    kl_div=(probs*(torch.log(probs)-torch.log(probs.mean(dim=0)))).sum(dim=1)
    entropy=-torch.sum(probs * torch.log(probs), dim=1)

    inception_score=torch.exp(torch.mean(kl_div - entropy))
    return inception_score


def compute_discriminator_loss(disNet, realImgs, fakeImgs, reallabels, fakelabels, conditions, gpus):
    criterion=nn.BCELoss()
    batch_size=realImgs.size(0)
    cond=conditions.detach()
    fake=fakeImgs.detach()
    realfeatures=nn.parallel.data_parallel(disNet, (realImgs), gpus)
    fakefeatures=nn.parallel.data_parallel(disNet, (fake), gpus)
    inputs=(realfeatures, cond)
    reallogits=nn.parallel.data_parallel(disNet.get_cond_logits, inputs, gpus)
    disErr_real=criterion(reallogits, reallabels)
    inputs=(realfeatures[:(batch_size-1)], cond[1:])
    wrong_logits =nn.parallel.data_parallel(disNet.get_cond_logits, inputs, gpus)
    disErr_wrong=criterion(wrong_logits, fakelabels[1:])
    inputs=(fakefeatures, cond)
    fakelogits=nn.parallel.data_parallel(disNet.get_cond_logits, inputs, gpus)
    disErr_fake=criterion(fakelogits, fakelabels)

    if disNet.get_uncond_logits is not None:
        reallogits=nn.parallel.data_parallel(disNet.get_uncond_logits, (realfeatures), gpus)
        fakelogits=nn.parallel.data_parallel(disNet.get_uncond_logits, (fakefeatures), gpus)
        uncond_disErr_real=criterion(reallogits, reallabels)
        uncond_disErr_fake=criterion(fakelogits, fakelabels)
        #
        disErr=((disErr_real+uncond_disErr_real)/2.+(disErr_fake+disErr_wrong+uncond_disErr_fake)/3.)
        disErr_real=(disErr_real+uncond_disErr_real)/2.
        disErr_fake=(disErr_fake+uncond_disErr_fake)/2.
    else:
        disErr=disErr_real+(disErr_fake+disErr_wrong)*0.5
    return disErr, disErr_real.data, disErr_wrong.data, disErr_fake.data


def compute_generator_loss(disNet, fakeImgs, reallabels, conditions, gpus):
    criterion=nn.BCELoss()
    cond=conditions.detach()
    fakefeatures=nn.parallel.data_parallel(disNet, (fakeImgs), gpus)
    inputs=(fakefeatures, cond)
    fakelogits=nn.parallel.data_parallel(disNet.get_cond_logits, inputs, gpus)
    disErr_fake=criterion(fakelogits, reallabels)
    if disNet.get_uncond_logits is not None:
        fakelogits=nn.parallel.data_parallel(disNet.get_uncond_logits,(fakefeatures), gpus)
        uncond_disErr_fake=criterion(fakelogits, reallabels)
        disErr_fake += uncond_disErr_fake
    return disErr_fake


def weights_init(m):
    classname=m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
    elif classname.find('Linear') != -1:
        m.weight.data.normal_(0.0, 0.02)
        if m.bias is not None:
            m.bias.data.fill_(0.0)


def save_img_results(data_img, fake, epoch, image_dir):
    num=args.VIS_COUNT
    fake=fake[0:num]
    if data_img is not None:
        data_img=data_img[0:num]
        vutils.save_image(data_img, '%s/realsamples_epoch_%03d.png'%(image_dir, epoch), normalize=True)
        vutils.save_image(fake.data, '%s/fakesamples_epoch_%03d.png'%(image_dir, epoch), normalize=True)
    else:
        vutils.save_image(fake.data, '%s/lr_fakesamples_epoch_%03d.png'%(image_dir, epoch), normalize=True)


def save_model(genNet, disNet, epoch, model_dir):
    torch.save(
        genNet.state_dict(),
        '%s/genNet_epoch_%d.pth'%(model_dir, epoch))
    torch.save(
        disNet.state_dict(),
        '%s/disNet_epoch_last.pth'%(model_dir))
    print('Save G/D models')


def mkdir_p(path):
    try:
        os.makedirs(path)
    except OSError as exc: 
        if exc.errno == errno.EEXIST and os.path.isdir(path):
            pass
        else:
            raise



def conv3x3(in_planes, out_planes, stride=1):
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)



def upBlock(in_planes, out_planes):
    block=nn.Sequential(
        nn.Upsample(scale_factor=2, mode='nearest'),
        conv3x3(in_planes, out_planes),
        nn.BatchNorm2d(out_planes),
        nn.ReLU(True))
    return block


class ResNet(nn.Module):
    def __init__(self, channel_num):
        super(ResNet, self).__init__()
        self.block=nn.Sequential(
            conv3x3(channel_num, channel_num),
            nn.BatchNorm2d(channel_num),
            nn.ReLU(True),
            conv3x3(channel_num, channel_num),
            nn.BatchNorm2d(channel_num))
        self.relu=nn.ReLU(inplace=True)

    def forward(self, x):
        residual=x
        out=self.block(x)
        out += residual
        out=self.relu(out)
        return out
    
class CA_NET(nn.Module):
    def __init__(self):
        super(CA_NET, self).__init__()
        self.t_dim=args.DIMENSION
        self.c_dim=args.CONDITION_DIM
        self.fc=nn.Linear(self.t_dim, self.c_dim*2, bias=True)
        self.relu=nn.ReLU()

    def encode(self, text_embedding):
        x=self.relu(self.fc(text_embedding))
        mu=x[:, :self.c_dim]
        logvar=x[:, self.c_dim:]
        return mu, logvar

    def reparametrize(self, mu, logvar):
        std=logvar.mul(0.5).exp_()
        if args.CUDA:
            eps=torch.cuda.FloatTensor(std.size()).normal_()
        else:
            eps=torch.FloatTensor(std.size()).normal_()
        eps=Variable(eps)
        return eps.mul(std).add_(mu)

    def forward(self, text_embedding):
        mu, logvar=self.encode(text_embedding)
        c_code=self.reparametrize(mu, logvar)
        return c_code, mu, logvar

class D_GET_LOGITS(nn.Module):
    def __init__(self, ndf, nef, bcondition=True):
        super(D_GET_LOGITS, self).__init__()
        self.df_dim=ndf
        self.ef_dim=nef
        self.bcondition=bcondition
        if bcondition:
            self.outlogits=nn.Sequential(
                conv3x3(ndf*8+nef, ndf*8),
                nn.BatchNorm2d(ndf*8),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Conv2d(ndf*8, 1, kernel_size=4, stride=4),
                nn.Sigmoid())
        else:
            self.outlogits=nn.Sequential(
                nn.Conv2d(ndf*8, 1, kernel_size=4, stride=4),
                nn.Sigmoid())

    def forward(self, h_code, c_code=None):
        if self.bcondition and c_code is not None:
            c_code=c_code.view(-1, self.ef_dim, 1, 1)
            c_code=c_code.repeat(1, 1, 4, 4)
            h_c_code=torch.cat((h_code, c_code), 1)
        else:
            h_c_code=h_code

        output=self.outlogits(h_c_code)
        return output.view(-1)
    

In [5]:

class STAGE1_G(nn.Module):
    def __init__(self):
        super(STAGE1_G, self).__init__()
        self.gf_dim=args.GF_DIM*8
        self.ef_dim=args.CONDITION_DIM
        self.z_dim=args.Z_DIM
        self.define_module()

    def define_module(self):
        ninput=self.z_dim+self.ef_dim
        ngf=self.gf_dim
        self.ca_net=CA_NET()

        self.fc=nn.Sequential(
            nn.Linear(ninput, ngf*4*4, bias=False),
            nn.BatchNorm1d(ngf*4*4),
            nn.ReLU(True))


        self.upsample1=upBlock(ngf, ngf//2)
        self.upsample2=upBlock(ngf//2, ngf//4)
        self.upsample3=upBlock(ngf//4, ngf//8)
        self.upsample4=upBlock(ngf//8, ngf//16)
        self.img=nn.Sequential(conv3x3(ngf//16, 3), nn.Tanh())

    def forward(self, text_embedding, noise):
        c_code, mu, logvar=self.ca_net(text_embedding)
        z_c_code=torch.cat((noise, c_code), 1)
        h_code=self.fc(z_c_code)

        h_code=h_code.view(-1, self.gf_dim, 4, 4)
        h_code=self.upsample1(h_code)
        h_code=self.upsample2(h_code)
        h_code=self.upsample3(h_code)
        h_code=self.upsample4(h_code)
        fakeImg=self.img(h_code)
        return None, fakeImg, mu, logvar


class STAGE1_D(nn.Module):
    def __init__(self):
        super(STAGE1_D, self).__init__()
        self.df_dim=args.DF_DIM
        self.ef_dim=args.CONDITION_DIM
        self.define_module()

    def define_module(self):
        ndf, nef=self.df_dim, self.ef_dim
        self.encode_img=nn.Sequential(
            nn.Conv2d(3, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf, ndf*2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf*2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf*2, ndf*4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf*4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf*4, ndf*8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf*8),
            nn.LeakyReLU(0.2, inplace=True)
        )

        self.get_cond_logits=D_GET_LOGITS(ndf, nef)
        self.get_uncond_logits=None

    def forward(self, image):
        img_embedding=self.encode_img(image)

        return img_embedding

In [6]:
class GANTrainer(object):
    def __init__(self, output_dir):
        if args.FLAG:
            self.model_dir=os.path.join(output_dir, 'Model')
            self.image_dir=os.path.join(output_dir, 'Image')
            self.log_dir=os.path.join(output_dir, 'Log')
            mkdir_p(self.model_dir)
            mkdir_p(self.image_dir)
            mkdir_p(self.log_dir)

        self.max_epoch=args.MAX_EPOCH
        self.snapshot_interval=args.SNAPSHOT_INTERVAL

        s_gpus=args.GPU_ID.split(',')
        self.gpus=[int(ix) for ix in s_gpus]
        self.num_gpus=len(self.gpus)
        self.batch_size=args.BATCH_SIZE*self.num_gpus
        self.output_dir=output_dir
        cudnn.benchmark=True
    def stage_I_network(self):
        genNet=STAGE1_G()
        genNet.apply(weights_init)
        print(genNet)
        disNet=STAGE1_D()
        disNet.apply(weights_init)
        print(disNet)
        print('***********************************************************')

        if args.NET_G != '':
            print('generator 1')
            print('Load from: ', args.NET_G)
        if args.NET_D != '':
            print('discriminator 1')
            print('Load from: ', args.NET_D)
        if args.CUDA:
            genNet.cuda()
            disNet.cuda()
        return genNet, disNet
        
    def stage_II_network(self):

        Stage1_G=STAGE1_G()
        genNet=STAGE2_G(Stage1_G)
        genNet.apply(weights_init)
        if args.NET_G != '':
            print('Load from: ', args.NET_G)
            state_dict = torch.load(args.NET_G,map_location=lambda storage, loc: storage)
            genNet.load_state_dict(state_dict)
        if args.STAGE1_G != '':
            print('Load from: ', args.STAGE1_G)
            state_dict = torch.load(args.STAGE1_G,map_location=lambda storage, loc: storage)
            genNet.STAGE1_G.load_state_dict(state_dict)
        else:
            print("Please give the Stage1_G path")
            return

        disNet=STAGE2_D()
        disNet.apply(weights_init)
        if args.NET_D != '':
            print('Load from: ', args.NET_D)
            state_dict = torch.load(args.NET_D,map_location=lambda storage, loc: storage)
            disNet.load_state_dict(state_dict)

        if args.CUDA:
            genNet.cuda()
            disNet.cuda()
        return genNet, disNet
    
    def train(self, dataset, stage=1):
        if stage == 1:
            genNet, disNet=self.stage_I_network()
        else:
            genNet, disNet=self.stage_II_network()

        nz=args.Z_DIM
        batch_size=self.batch_size
        noise=Variable(torch.FloatTensor(batch_size, nz))
        fixed_noise=Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1),volatile=True)
        reallabels=Variable(torch.FloatTensor(batch_size).fill_(1))
        fakelabels=Variable(torch.FloatTensor(batch_size).fill_(0))
        if args.CUDA:
            noise, fixed_noise=noise.cuda(), fixed_noise.cuda()
            reallabels, fakelabels=reallabels.cuda(), fakelabels.cuda()

        generator_lr=args.GENERATOR_LR
        discriminator_lr=args.DISCRIMINATOR_LR
        lr_decay_step=args.LR_DECAY_EPOCH
        disOptimizer=optim.Adam(disNet.parameters(), lr=args.DISCRIMINATOR_LR, betas=(0.5, 0.999))
        genNet_para=[]
        for p in genNet.parameters():
            if p.requires_grad:
                genNet_para.append(p)
        genOptimizer=optim.Adam(genNet_para,lr=args.GENERATOR_LR,betas=(0.5, 0.999))
        count=0
        c=0
        fake=None
        for epoch in range(args.START_EPOCHS, self.max_epoch):
            torch.cuda.empty_cache()
            print('Inside Epoch:', (epoch+1))
            print('---------------------')
            start_t=time.time()
            if epoch%lr_decay_step == 0 and epoch > 0:
                generator_lr *= 0.5
                for param_group in genOptimizer.param_groups:
                    param_group['lr']=generator_lr
                discriminator_lr *= 0.5
                for param_group in disOptimizer.param_groups:
                    param_group['lr']=discriminator_lr
            br=0
            c2=0
            for i in range(len(dataset[0])//batch_size):
                if(i%30==0):
                    print(">>", end=" ")
                realImg_cpu=torch.tensor(dataset[0][(i*batch_size):((i+1)*batch_size)])
                txt_embedding=torch.tensor(dataset[2][(i*batch_size):((i+1)*batch_size)])
                realImgs=Variable(realImg_cpu)
                txt_embedding=Variable(txt_embedding)
                if args.CUDA:
                    realImgs=realImgs.cuda()
                    txt_embedding=txt_embedding.cuda()
                noise.data.normal_(0, 1)
                inputs=(txt_embedding, noise)
                _, fakeImgs, mu, logvar=nn.parallel.data_parallel(genNet, inputs, self.gpus)
                
                disNet.zero_grad()
                disErr, disErr_real, disErr_wrong, disErr_fake=compute_discriminator_loss(disNet, realImgs, fakeImgs,reallabels, fakelabels, mu, self.gpus)
                disErr.backward()
                disOptimizer.step()
                
                
                genNet.zero_grad()
                genErr=compute_generator_loss(disNet, fakeImgs,
                                              reallabels, mu, self.gpus)
                kl_loss=KL_loss(mu, logvar)
                genErr_total=genErr+kl_loss*args.KL
                genErr_total.backward()
                genOptimizer.step()
                
                br=br+1
                
                count=count+1
                if(i==len(dataset[0])//batch_size-1):
                    lr_fake, fake, _, _=nn.parallel.data_parallel(genNet, inputs, self.gpus) 
                    end_t=time.time()
                    c=1
                    print('''[%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_KL: %.4f
                             Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f
                             Total Time: %.2fsec
                          '''
                         %(epoch, self.max_epoch,
                             disErr.data, genErr.data, kl_loss.data,
                             disErr_real, disErr_wrong, disErr_fake, (end_t-start_t)))
                    print("Learning Rate:", args.DISCRIMINATOR_LR)
            if epoch%self.snapshot_interval == 0:
                save_model(genNet, disNet, epoch, self.model_dir)
                save_img_results(realImg_cpu, fake, epoch, self.image_dir)
            grid=vutils.make_grid(fake[0:args.VIS_COUNT], normalize=True, scale_each=True)
            grid_pil=ToPILImage()(grid)
            plt.imshow(grid_pil)
            plt.axis('off')
            plt.show()
            plt.close()
        
        save_model(genNet, disNet, self.max_epoch, self.model_dir)
        save_img_results(None, fake, self.max_epoch, self.image_dir)

In [7]:
class STAGE2_G(nn.Module):
    def __init__(self, STAGE1_G):
        super(STAGE2_G, self).__init__()
        self.gf_dim=args.GF_DIM
        self.ef_dim=args.CONDITION_DIM
        self.z_dim=args.Z_DIM
        self.STAGE1_G=STAGE1_G
            
        for param in self.STAGE1_G.parameters():
            param.requires_grad=False
        self.define_module()

    def _make_layer(self, block, channel_num):
        layers=[]
        for i in range(args.R_NUM):
            layers.append(block(channel_num))
        return nn.Sequential(*layers)

    def define_module(self):
        ngf=self.gf_dim
        self.ca_net=CA_NET()
        self.encoder=nn.Sequential(
            conv3x3(3, ngf),
            nn.ReLU(True),
            nn.Conv2d(ngf, ngf*2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf*2),
            nn.ReLU(True),
            nn.Conv2d(ngf*2, ngf*4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf*4),
            nn.ReLU(True))
        self.hr_joint=nn.Sequential(
            conv3x3(self.ef_dim+ngf*4, ngf*4),
            nn.BatchNorm2d(ngf*4),
            nn.ReLU(True))
        self.residual=self._make_layer(ResNet, ngf*4)
        self.upsample1=upBlock(ngf*4, ngf*2)
        self.upsample2=upBlock(ngf*2, ngf)
        self.upsample3=upBlock(ngf, ngf//2)
        self.upsample4=upBlock(ngf//2, ngf//4)
        self.img=nn.Sequential(
            conv3x3(ngf//4, 3),
            nn.Tanh())

    def forward(self, text_embedding, noise):
        _, stage1_img, _, _=self.STAGE1_G(text_embedding, noise)
        stage1_img=stage1_img.detach()
        encoded_img=self.encoder(stage1_img)

        c_code, mu, logvar=self.ca_net(text_embedding)
        c_code=c_code.view(-1, self.ef_dim, 1, 1)
        c_code=c_code.repeat(1, 1, 16, 16)
        i_c_code=torch.cat([encoded_img, c_code], 1)
        h_code=self.hr_joint(i_c_code)
        h_code=self.residual(h_code)

        h_code=self.upsample1(h_code)
        h_code=self.upsample2(h_code)
        h_code=self.upsample3(h_code)
        h_code=self.upsample4(h_code)

        fakeImg=self.img(h_code)
        return stage1_img, fakeImg, mu, logvar


class STAGE2_D(nn.Module):
    def __init__(self):
        super(STAGE2_D, self).__init__()
        self.df_dim=args.DF_DIM
        self.ef_dim=args.CONDITION_DIM
        self.define_module()

    def define_module(self):
        ndf, nef=self.df_dim, self.ef_dim
        self.encode_img=nn.Sequential(
            nn.Conv2d(3, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf, ndf*2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf*2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf*2, ndf*4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf*4),
            nn.LeakyReLU(0.2, inplace=True), 
            nn.Conv2d(ndf*4, ndf*8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf*8),
            nn.LeakyReLU(0.2, inplace=True), 
            nn.Conv2d(ndf*8, ndf*16, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf*16),
            nn.LeakyReLU(0.2, inplace=True), 
            nn.Conv2d(ndf*16, ndf*32, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf*32),
            nn.LeakyReLU(0.2, inplace=True),
            conv3x3(ndf*32, ndf*16),
            nn.BatchNorm2d(ndf*16),
            nn.LeakyReLU(0.2, inplace=True), 
            conv3x3(ndf*16, ndf*8),
            nn.BatchNorm2d(ndf*8),
            nn.LeakyReLU(0.2, inplace=True)  
        )

        self.get_cond_logits=D_GET_LOGITS(ndf, nef, bcondition=True)
        self.get_uncond_logits=D_GET_LOGITS(ndf, nef, bcondition=False)

    def forward(self, image):
        img_embedding=self.encode_img(image)

        return img_embedding

In [8]:
class GANEval(object):
    def __init__(self, output_dir):
        if args.FLAG:
            self.model_dir=os.path.join(output_dir, 'Model')
            self.image_dir=os.path.join(output_dir, 'Image')
            self.log_dir=os.path.join(output_dir, 'Log')
            mkdir_p(self.model_dir)
            mkdir_p(self.image_dir)
            mkdir_p(self.log_dir)
            
        s_gpus=args.GPU_ID.split(',')
        self.gpus=[int(ix) for ix in s_gpus]
        self.num_gpus=len(self.gpus)
        self.batch_size=args.BATCH_SIZE*self.num_gpus
        self.output_dir=output_dir
        cudnn.benchmark=True
    
    def stage_I_network(self):
        genNet=STAGE1_G()
        genNet.apply(weights_init)
        print(genNet)
        print('***********************************************************')

        if args.STAGE1_G != '':
            print('Generator 1')
            print('Load from: ', args.STAGE1_G)
            state_dict = torch.load(args.STAGE1_G,map_location=lambda storage, loc: storage)
            genNet.load_state_dict(state_dict)
        else:
            print('Please provide a generator model path!')
            return
        if args.CUDA:
            genNet.cuda()
        return genNet
    
    def stage_II_network(self):

        Stage1_G=STAGE1_G()
        genNet=STAGE2_G(Stage1_G)
        genNet.apply(weights_init)
        if args.NET_G != '':
            print('Load from: ', args.NET_G)
            state_dict = torch.load(args.NET_G,map_location=lambda storage, loc: storage)
            genNet.load_state_dict(state_dict)
        else:
            print("Please give NET_G path")
            return
        if args.STAGE1_G != '':
            print('Load from: ', args.STAGE1_G)
            state_dict = torch.load(args.STAGE1_G,map_location=lambda storage, loc: storage)
            genNet.STAGE1_G.load_state_dict(state_dict)
        else:
            print("Please give the Stage1_G path")
            return

        if args.CUDA:
            genNet.cuda()
        return genNet
    
    def evaluate(self, dataset, stage):
        if stage == 1:
            genNet=self.stage_I_network()
        else:
            genNet=self.stage_II_network()

        nz=args.Z_DIM
        batch_size=self.batch_size
        noise=Variable(torch.FloatTensor(batch_size, nz))
        fixed_noise=Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1),volatile=True)
        if args.CUDA:
            noise=noise.cuda()
        for i in range(len(dataset[0])//batch_size):
            if(i%30==0):
                print(">>", end=" ")
            realImg_cpu=torch.tensor(dataset[0][(i*batch_size):((i+1)*batch_size)])
            txt_embedding=torch.tensor(dataset[2][(i*batch_size):((i+1)*batch_size)])
            realImgs=Variable(realImg_cpu)
            txt_embedding=Variable(txt_embedding)
            if args.CUDA:
                realImgs=realImgs.cuda()
                txt_embedding=txt_embedding.cuda()
            noise.data.normal_(0, 1)
            inputs=(txt_embedding, noise)
            lr_fakeImgs, fakeImgs, mu, logvar=nn.parallel.data_parallel(genNet, inputs, self.gpus)
            save_img_results(lr_fakeImgs, fakeImgs, i, self.image_dir)

In [9]:
class Struct:
    def __init__(self, **entries):
        self.__dict__.update(entries)

if __name__ == '__main__':
    params=dict()
    # Parameters
    parser = argparse.ArgumentParser()
    params = dict()
    params['CONFIG_NAME']='stageI'
    params['DATASET_NAME']='cub'
    params['EMBEDDING_TYPE']='cnn-rnn'
    params['GPU_ID']='0'
    params['CUDA']='TRUE'
    params['WORKERS']=4
    params['STAGE1_G']='/kaggle/input/stage1generator/netG_epoch_600.pth'
    params['NET_G']='/kaggle/input/stage1generator/S2_genNet_epoch_250.pth'
    params['NET_D']='/kaggle/input/stage1generator/S2_disNet_epoch_250.pth'
    params['DATA_DIR']='../input/text-to-image-cub-200-2011/CUB-200-2011/'
    params['IMG_DIR'] = '../input/text-to-image-cub-200-2011/CUB-200-2011/images'
    params['VIS_COUNT']=64
    params['Z_DIM']=100
    params['IMSIZE']=256
    params['STAGE']=1
    params['FLAG']='TRUE'
    params['BATCH_SIZE']=30
    params['START_EPOCHS']=250
    params['MAX_EPOCH']=370
    params['SNAPSHOT_INTERVAL']=10
    params['LR_DECAY_EPOCH']=20
    params['DISCRIMINATOR_LR']=6.25e-05
    params['GENERATOR_LR']=6.25e-05
    params['KL']=2.0
    
    params['CONDITION_DIM']=128
    params['DF_DIM']=96
    params['GF_DIM']=192
    params['R_NUM']=2
    params['DIMENSION']=1024
    params['TEST_OR_TRAIN']='test'
    args=Struct(**params)
    
    manualSeed=random.randint(1, 10000)
    random.seed(manualSeed)
    torch.manual_seed(manualSeed)
    torch.cuda.manual_seed_all(manualSeed)
    
    output_dir='/kaggle/working/'
    print(output_dir)
    num_gpu=len(args.GPU_ID.split(','))
    

/kaggle/working/


In [10]:
if args.FLAG:
    image_transform=transforms.Compose([
        transforms.RandomCrop(args.IMSIZE),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    dataset=load_data((args.IMSIZE, args.IMSIZE), transform=image_transform)
    assert dataset
    dataloader=torch.utils.data.DataLoader(
        dataset, batch_size=args.BATCH_SIZE*num_gpu,
        drop_last=True, shuffle=True, num_workers=int(args.WORKERS))

In [11]:
print(len(dataset[2]))
if args.FLAG:
    if(args.TEST_OR_TRAIN=='train'):
        algo=GANTrainer(output_dir)
        algo.train(dataset, args.STAGE)
    else:
        algo=GANEval(output_dir)
        algo.evaluate(dataset, args.STAGE)

2933
STAGE1_G(
  (ca_net): CA_NET(
    (fc): Linear(in_features=1024, out_features=256, bias=True)
    (relu): ReLU()
  )
  (fc): Sequential(
    (0): Linear(in_features=228, out_features=24576, bias=False)
    (1): BatchNorm1d(24576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (upsample1): Sequential(
    (0): Upsample(scale_factor=2.0, mode='nearest')
    (1): Conv2d(1536, 768, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (2): BatchNorm2d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): ReLU(inplace=True)
  )
  (upsample2): Sequential(
    (0): Upsample(scale_factor=2.0, mode='nearest')
    (1): Conv2d(768, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (2): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): ReLU(inplace=True)
  )
  (upsample3): Sequential(
    (0): Upsample(scale_factor=2.0, mode='nearest')
    (1): 

  fixed_noise=Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1),volatile=True)


>> >> >> >> 

In [12]:
torch.cuda.empty_cache()