# MISCC

## init:

# In [2]:

from __future__ import division
from __future__ import print_function


## config:

# In [5]:
from __future__ import division
from __future__ import print_function
import os.path as osp
import numpy as np
from easydict import EasyDict as edict


# Default configuration. ``cfg`` is the public alias; cfg_from_file() merges a
# YAML file over these values key by key (keys must already exist here).
__C = edict({
    # Dataset name, e.g. 'coco' (the original comment also listed 'flowers'
    # and 'birds').
    'DATASET_NAME': 'coco',
    'EMBEDDING_TYPE': 'cnn-rnn',
    'CONFIG_NAME': '',
    'GPU_ID': '0',
    'CUDA': True,
    'WORKERS': 6,

    'NET_G': '',
    'NET_D': '',
    'STAGE1_G': '',
    'DATA_DIR': '',
    'IMG_DIR': '',
    'VIS_COUNT': 64,

    'Z_DIM': 100,
    'IMSIZE': 64,
    'STAGE': 1,

    'USE_LOCAL_PATHWAY': True,
    'USE_BBOX_LAYOUT': True,

    # Training options
    'TRAIN': {
        'FLAG': True,
        'BATCH_SIZE': 64,
        'MAX_EPOCH': 600,
        'SNAPSHOT_INTERVAL': 50,
        'PRETRAINED_MODEL': '',
        'PRETRAINED_EPOCH': 600,
        'LR_DECAY_EPOCH': 600,
        'DISCRIMINATOR_LR': 2e-4,
        'GENERATOR_LR': 2e-4,
        'COEFF': {
            'KL': 2.0,
        },
    },

    # Model options
    'GAN': {
        'CONDITION_DIM': 128,
        'DF_DIM': 64,
        'GF_DIM': 128,
        'R_NUM': 4,
    },

    'TEXT': {
        'DIMENSION': 1024,
    },
})
cfg = __C


def _merge_a_into_b(a, b):
    """Merge config dictionary ``a`` into config dictionary ``b`` in place.

    Options in ``b`` are clobbered whenever they are also specified in ``a``;
    nested ``edict`` values are merged recursively rather than replaced.

    Args:
        a: source config. Ignored (no-op) unless it is exactly an ``edict``.
        b: destination config, e.g. the module-level defaults.

    Raises:
        KeyError: if ``a`` contains a key that ``b`` does not define.
        ValueError: if a value's type differs from the default's type. NumPy
            array defaults are the exception: the value is converted to an
            array with the default's dtype.
    """
    if type(a) is not edict:
        return

    # ``items()`` and ``in`` work on Python 2 and 3; the previous
    # ``iteritems()`` / ``has_key()`` exist only on Python 2.
    for k, v in a.items():
        # a must specify keys that are in b
        if k not in b:
            raise KeyError('{} is not a valid config key'.format(k))

        # the types must match, too
        old_type = type(b[k])
        if old_type is not type(v):
            if isinstance(b[k], np.ndarray):
                v = np.array(v, dtype=b[k].dtype)
            else:
                raise ValueError(('Type mismatch ({} vs. {}) '
                                  'for config key: {}').format(type(b[k]),
                                                               type(v), k))

        # recursively merge dicts
        if type(v) is edict:
            try:
                _merge_a_into_b(a[k], b[k])
            except Exception:
                # Report which (possibly nested) key failed, then propagate.
                print('Error under config key: {}'.format(k))
                raise
        else:
            b[k] = v


def cfg_from_file(filename):
    """Load a YAML config file and merge it into the default options.

    The file's top-level mapping is merged over the module defaults, so every
    key it contains must already exist in the defaults, with a matching type.
    An empty file leaves the defaults unchanged.

    Args:
        filename: path to a YAML config file.
    """
    import yaml
    with open(filename, 'r') as f:
        # safe_load: a bare ``yaml.load(f)`` warns on PyYAML 5.x, raises
        # TypeError on PyYAML >= 6 (missing Loader), and can construct
        # arbitrary Python objects from the file.
        yaml_cfg = edict(yaml.safe_load(f))

    _merge_a_into_b(yaml_cfg, __C)


## datasets:

# In [None]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals


import torch.utils.data as data
import PIL
import os
import os.path
import random
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
import torch
import sys

if sys.version_info[0] == 2:
    import cPickle as pickle
else:
    import pickle

#from miscc.config import cfg


class TextDataset(data.Dataset):
    '''
    MS-COCO dataset returning ``(image, bbox, label, embedding)`` per item.

    Per-image metadata is read from pickle files in ``<data_dir>/<split>``:
    ``filenames.pickle``, ``bboxes.pickle``, ``labels.pickle`` and one
    caption-embedding pickle. The four collections are indexed in parallel by
    the dataset index. Images are read from ``<img_dir>/<filename>.jpg``.

    Args:
        data_dir: root data directory, e.g. '../../../data/MS-COCO'.
        img_dir: directory containing the .jpg images.
        imsize: side length of the square crop produced by ``crop_imgs``.
        split: name of the split sub-directory of ``data_dir``.
        embedding_type: 'cnn-rnn', 'cnn-gru' or 'skip-thought'.
        transform: optional callable applied to the PIL image. ``crop_imgs``
            expects it to yield a (C, H, W) tensor.
        crop: if True, ``__getitem__`` randomly crops/flips via ``crop_imgs``.
        stage: 1 or 2. Stage 2 also returns boxes rescaled for the stage-1 G.
    '''

    # embedding_type -> pickle file name inside the split directory.
    _EMBEDDING_FILES = {
        'cnn-rnn': 'char-CNN-RNN-embeddings.pickle',
        'cnn-gru': 'char-CNN-GRU-embeddings.pickle',
        'skip-thought': 'skip-thought-embeddings.pickle',
    }

    def __init__(self, data_dir, img_dir, imsize, split='train', embedding_type='cnn-rnn', transform=None, crop=True, stage=1):

        self.transform = transform  # optional callable applied to the PIL image
        self.imsize = imsize  # side length of the random crop
        self.crop = crop  # whether __getitem__ crops/flips the image
        self.data = []  # unused; kept for backward compatibility
        self.data_dir = data_dir  # e.g. '../../../data/MS-COCO'
        self.split_dir = os.path.join(data_dir, split)  # e.g. '../../../data/MS-COCO/train'
        self.img_dir = img_dir  # e.g. '../../../data/MS-COCO/train/train2014'
        self.max_objects = 3  # number of bbox slots per image
        self.stage = stage  # GAN stage: 1 or 2

        self.filenames = self.load_filenames()
        self.bboxes = self.load_bboxes()
        self.labels = self.load_labels()
        self.embeddings = self.load_embedding(self.split_dir, embedding_type)

    @staticmethod
    def _load_pickle(path):
        '''
        Unpickle ``path``. On Python 3, 8-bit strings from Python-2 pickles are
        decoded as latin1, which NumPy arrays written by Python 2 require;
        pickles written by Python 3 are unaffected by this setting.
        '''
        with open(path, 'rb') as f:
            if sys.version_info[0] >= 3:
                return pickle.load(f, encoding='latin1')
            return pickle.load(f)

    def get_img(self, img_path):
        '''
        Open ``img_path`` as an RGB PIL image and apply ``self.transform`` if set.
        '''
        img = Image.open(img_path).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        return img

    def load_bboxes(self):
        '''
        Return the bboxes as a NumPy array (the original notes give
        (82783, 3, 4) for the train2014 split).
        '''
        return np.array(self._load_pickle(os.path.join(self.split_dir, 'bboxes.pickle')))

    def load_labels(self):
        '''
        Return the labels as a NumPy array (the original notes give
        (82783, 3, 1) for the train2014 split).
        '''
        return np.array(self._load_pickle(os.path.join(self.split_dir, 'labels.pickle')))

    # Caption text loading (load_all_captions / load_captions) was removed as
    # dead code: this dataset consumes precomputed caption embeddings only.

    def load_embedding(self, data_dir, embedding_type):
        '''
        Return the caption embeddings in ``data_dir`` as a NumPy array (the
        original notes give (82783, 5, 1024) for the train2014 split).

        Raises:
            ValueError: if ``embedding_type`` is not a supported type. Before
                this check existed, an unknown type raised UnboundLocalError.
        '''
        try:
            embedding_filename = self._EMBEDDING_FILES[embedding_type]
        except KeyError:
            raise ValueError('Unknown embedding_type {!r}; expected one of {}'.format(
                embedding_type, sorted(self._EMBEDDING_FILES)))

        embeddings = self._load_pickle(os.path.join(data_dir, embedding_filename))
        return np.array(embeddings)

    def load_filenames(self):
        '''
        Return the list of image file names (without extension) for the split.
        '''
        filepath = os.path.join(self.split_dir, 'filenames.pickle')
        filenames = self._load_pickle(filepath)
        print('Load filenames from: %s (%d)' % (filepath, len(filenames)))
        return filenames

    @staticmethod
    def _scale_bbox(box, ori_size, imsize, x_off, y_off, flip):
        '''
        Map one normalized ``[x, y, w, h]`` box from an ``ori_size`` image into
        an ``imsize`` crop starting at pixel offset (``x_off``, ``y_off``).

        Returns the box normalized to the crop. Width/height are clamped so the
        far edge stays just inside 1.0, and x is mirrored when ``flip`` is set.
        '''
        ori_size = float(ori_size)
        x_new = max(box[0] * ori_size - x_off, 0) / float(imsize)
        y_new = max(box[1] * ori_size - y_off, 0) / float(imsize)

        width_new = min((ori_size / imsize) * box[2], 1.0)
        if x_new + width_new > 0.999:
            width_new = 1.0 - x_new - 0.001

        height_new = min((ori_size / imsize) * box[3], 1.0)
        if y_new + height_new > 0.999:
            height_new = 1.0 - y_new - 0.001

        if flip:
            x_new = 1.0 - x_new - width_new

        return [x_new, y_new, width_new, height_new]

    def crop_imgs(self, image, bbox):
        '''
        Randomly crop ``image`` to ``self.imsize`` (and flip it horizontally
        with probability 0.5), rescaling the bounding boxes to match.

        Args:
            image: (C, H, W) tensor; H is used as the original size, so the
                image is assumed square.
            bbox: array of shape (max_objects, 4), normalized ``[x, y, w, h]``.
                Rows starting with -1 mark unused slots.

        Returns:
            ``(image, bbox_scaled)``. ``image`` has shape (C, imsize, imsize).
            For stage 1, ``bbox_scaled`` is an array shaped like ``bbox``. For
            stage 2 it is a list ``[stage-1 G boxes, stage-2 G boxes]``. Unused
            slots are filled with -1.
        '''
        ori_size = image.shape[1]
        imsize = self.imsize

        flip_img = random.random() < 0.5
        img_crop = ori_size - imsize
        # Python's ``random`` is reseeded per DataLoader worker by PyTorch;
        # NumPy's global RNG may be duplicated across forked workers on older
        # PyTorch versions, so it is not used for the crop offsets.
        h1 = int(np.floor(img_crop * random.random()))  # column (x) offset
        w1 = int(np.floor(img_crop * random.random()))  # row (y) offset

        if self.stage == 1:
            bbox_scaled = np.zeros_like(bbox)
            bbox_scaled[...] = -1.0
            targets = [(bbox_scaled, ori_size, imsize)]
        else:
            # need two bboxes for stage 1 G and stage 2 G
            bbox_scaled = [np.zeros_like(bbox), np.zeros_like(bbox)]
            bbox_scaled[0][...] = -1.0
            bbox_scaled[1][...] = -1.0
            # Stage-1 G boxes assume a 64px crop of a 76px image.
            # NOTE(review): h1/w1 are drawn for the stage-2 sizes but reused
            # unscaled for the stage-1 boxes; confirm this is intended.
            targets = [(bbox_scaled[0], 76, 64), (bbox_scaled[1], ori_size, imsize)]

        for idx in range(self.max_objects):
            box = bbox[idx]
            if box[0] == -1:
                break
            for out, src_size, dst_size in targets:
                out[idx] = self._scale_bbox(box, src_size, dst_size, h1, w1, flip_img)

        cropped_image = image[:, w1: w1 + imsize, h1: h1 + imsize]
        if flip_img:
            # Horizontal mirror: reverse the width (last) axis.
            cropped_image = torch.flip(cropped_image, dims=[2])

        return cropped_image, bbox_scaled

    def __getitem__(self, index):
        '''
        Return ``(img, bbox, label, embedding)`` for ``index``.

        ``embedding`` is one caption embedding chosen uniformly at random among
        the image's captions. When ``self.crop`` is set, ``img`` and ``bbox``
        are the outputs of ``crop_imgs``.
        '''
        key = self.filenames[index]
        # NOTE: the dataset index, not the COCO image id, selects the entry.
        img_name = self.img_dir + "/" + key + ".jpg"
        img = self.get_img(img_name)

        bbox = self.bboxes[index]
        label = self.labels[index]

        embeddings = self.embeddings[index, :, :]
        embedding_ix = random.randint(0, embeddings.shape[0] - 1)
        embedding = embeddings[embedding_ix, :]

        if self.crop:
            img, bbox = self.crop_imgs(img, bbox)

        return img, bbox, label, embedding

    def __len__(self):
        '''
        Return the number of items (image file names) in the dataset.
        '''
        return len(self.filenames)


## Utils:

# In [None]:
import os
import errno
import numpy as np
import cPickle as pickle
import glob

from copy import deepcopy
from miscc.config import cfg

from torch.nn import init
import torch
import torch.nn as nn
import torchvision.utils as vutils
from torch.autograd import grad
from torch.autograd import Variable


def compute_transformation_matrix_inverse(bbox):
    """Return per-box affine matrices that are the inverse of compute_transformation_matrix.

    Args:
        bbox: tensor of shape (N, 4). Each row is a normalized box
            ``(x, y, w, h)`` with ``(x, y)`` the top-left corner.

    Returns:
        Tensor of shape (N, 2, 3), on the same device and with the same dtype
        as ``bbox``. Each matrix is ``[[1/w, 0, t_x], [0, 1/h, t_y]]`` in
        [-1, 1] normalized coordinates, suitable for
        ``torch.nn.functional.affine_grid``.
    """
    x, y = bbox[:, 0], bbox[:, 1]
    w, h = bbox[:, 2], bbox[:, 3]

    scale_x = 1.0 / w
    scale_y = 1.0 / h

    t_x = 2 * scale_x * (0.5 - (x + 0.5 * w))
    t_y = 2 * scale_y * (0.5 - (y + 0.5 * h))

    # Allocate on bbox's device/dtype; this was hard-coded to
    # torch.cuda.FloatTensor, which failed on CPU and for non-float32 inputs.
    zeros = torch.zeros(bbox.shape[0], 1, dtype=bbox.dtype, device=bbox.device)

    transformation_matrix = torch.cat([scale_x.unsqueeze(-1), zeros, t_x.unsqueeze(-1),
                                       zeros, scale_y.unsqueeze(-1), t_y.unsqueeze(-1)], 1).view(-1, 2, 3)

    return transformation_matrix


def compute_transformation_matrix(bbox):
    """Return per-box affine matrices ``[[w, 0, t_x], [0, h, t_y]]``.

    Args:
        bbox: tensor of shape (N, 4). Each row is a normalized box
            ``(x, y, w, h)`` with ``(x, y)`` the top-left corner.

    Returns:
        Tensor of shape (N, 2, 3), on the same device and with the same dtype
        as ``bbox``. ``t_x`` and ``t_y`` are the box centre in [-1, 1]
        normalized coordinates, suitable for ``torch.nn.functional.affine_grid``.
        compute_transformation_matrix_inverse returns the inverse transform.
    """
    x, y = bbox[:, 0], bbox[:, 1]
    w, h = bbox[:, 2], bbox[:, 3]

    scale_x = w
    scale_y = h

    t_x = 2 * ((x + 0.5 * w) - 0.5)
    t_y = 2 * ((y + 0.5 * h) - 0.5)

    # Allocate on bbox's device/dtype; this was hard-coded to
    # torch.cuda.FloatTensor, which failed on CPU and for non-float32 inputs.
    zeros = torch.zeros(bbox.shape[0], 1, dtype=bbox.dtype, device=bbox.device)

    transformation_matrix = torch.cat([scale_x.unsqueeze(-1), zeros, t_x.unsqueeze(-1),
                                       zeros, scale_y.unsqueeze(-1), t_y.unsqueeze(-1)], 1).view(-1, 2, 3)

    return transformation_matrix


def load_validation_data(datapath, ori_size=76, imsize=64):
    """Load validation labels and bounding boxes as torch tensors.

    ``datapath`` is used as a plain string prefix, so it must end with a path
    separator. ``ori_size`` and ``imsize`` are accepted but currently unused.

    Returns:
        ``(labels, bboxes)``, each created with ``torch.from_numpy`` from the
        unpickled ``labels.pickle`` / ``bboxes.pickle``.
    """
    def _read_array(name):
        with open(datapath + name, "rb") as handle:
            return np.array(pickle.load(handle))

    bboxes = _read_array("bboxes.pickle")
    labels = _read_array("labels.pickle")
    return torch.from_numpy(labels), torch.from_numpy(bboxes)


#############################
def KL_loss(mu, logvar):
    """Mean KL divergence of N(mu, exp(logvar)) from N(0, 1).

    Computes ``-0.5 * mean(1 + logvar - mu**2 - exp(logvar))``, averaged over
    every element. The result is a 0-dim tensor.
    """
    per_element = logvar + (1 - (mu.pow(2) + logvar.exp()))
    return -0.5 * torch.mean(per_element)


def compute_discriminator_loss(netD, real_imgs, fake_imgs,real_labels, fake_labels,local_label,
                               transf_matrices, transf_matrices_inv, conditions, gpus):
    """Compute the conditional GAN discriminator loss.

    Image features come from ``netD`` and are scored against the condition by
    ``netD.get_cond_logits`` with ``BCEWithLogitsLoss``, for three pair types:
      * real pairs  -- real image features with their own condition,
        target ``real_labels``;
      * wrong pairs -- real features of sample i with the condition of sample
        i+1 (``real_features[:batch_size-1]`` vs ``cond[1:]``), target
        ``fake_labels[1:]``;
      * fake pairs  -- generated image features with the condition, target
        ``fake_labels``.
    If ``netD.get_uncond_logits`` is not None, unconditional real/fake terms
    are averaged in as well. STAGE1_D, for example, sets it to None.

    ``conditions``, ``fake_imgs`` and ``local_label`` are detached, so this loss
    does not backpropagate into whatever produced them.

    Returns:
        ``(errD, errD_real, errD_wrong, errD_fake)``: the total loss tensor and
        the component losses as Python floats. In the unconditional branch,
        errD_real and errD_fake are the averaged cond/uncond values.
    """
    criterion = nn.BCEWithLogitsLoss()
    batch_size = real_imgs.size(0)
    cond = conditions.detach()
    fake = fake_imgs.detach()
    local_label = local_label.detach()
    real_features = nn.parallel.data_parallel(netD, (real_imgs, local_label, transf_matrices, transf_matrices_inv), gpus)
    fake_features = nn.parallel.data_parallel(netD, (fake, local_label, transf_matrices, transf_matrices_inv), gpus)
    # real pairs
    inputs = (real_features, cond)
    real_logits = nn.parallel.data_parallel(netD.get_cond_logits, inputs, gpus)
    errD_real = criterion(real_logits, real_labels)
    # wrong pairs: shift the conditions by one so each real image is paired
    # with another sample's condition
    inputs = (real_features[:(batch_size-1)], cond[1:])
    wrong_logits = nn.parallel.data_parallel(netD.get_cond_logits, inputs, gpus)
    errD_wrong = criterion(wrong_logits, fake_labels[1:])
    # fake pairs
    inputs = (fake_features, cond)
    fake_logits = nn.parallel.data_parallel(netD.get_cond_logits, inputs, gpus)
    errD_fake = criterion(fake_logits, fake_labels)

    if netD.get_uncond_logits is not None:
        # NOTE(review): ``(real_features)`` is a bare tensor, not a 1-tuple;
        # data_parallel is expected to wrap non-tuple inputs itself -- confirm.
        real_logits = nn.parallel.data_parallel(netD.get_uncond_logits, (real_features), gpus)
        fake_logits = nn.parallel.data_parallel(netD.get_uncond_logits, (fake_features), gpus)
        uncond_errD_real = criterion(real_logits, real_labels)
        uncond_errD_fake = criterion(fake_logits, fake_labels)
        # average the real terms (2) and the fake/wrong terms (3) separately
        errD = ((errD_real + uncond_errD_real) / 2. +
                (errD_fake + errD_wrong + uncond_errD_fake) / 3.)
        errD_real = (errD_real + uncond_errD_real) / 2.
        errD_fake = (errD_fake + uncond_errD_fake) / 2.
    else:
        errD = errD_real + (errD_fake + errD_wrong) * 0.5
    return errD, errD_real.item(), errD_wrong.item(), errD_fake.item()


def compute_generator_loss(netD, fake_imgs, real_labels, local_label, transf_matrices, transf_matrices_inv, conditions, gpus):
    """Adversarial generator loss: the discriminator's logits on generated
    images, scored with ``BCEWithLogitsLoss`` against ``real_labels``.

    The conditional term always contributes. When
    ``netD.get_uncond_logits`` is not None, the unconditional term is added to
    it. Only ``conditions`` is detached; gradients flow back through
    ``fake_imgs``.

    Returns:
        The summed loss tensor.
    """
    bce = nn.BCEWithLogitsLoss()
    detached_cond = conditions.detach()
    d_inputs = (fake_imgs, local_label, transf_matrices, transf_matrices_inv)
    features = nn.parallel.data_parallel(netD, d_inputs, gpus)

    # conditional (fake image, caption) pairs
    cond_logits = nn.parallel.data_parallel(netD.get_cond_logits, (features, detached_cond), gpus)
    loss = bce(cond_logits, real_labels)

    uncond_head = netD.get_uncond_logits
    if uncond_head is None:
        return loss

    uncond_logits = nn.parallel.data_parallel(uncond_head, (features), gpus)
    loss += bce(uncond_logits, real_labels)
    return loss


#############################
def weights_init(m):
    """DCGAN-style initializer, intended for ``module.apply(weights_init)``.

    Dispatches on the module's class name:
      * ``*Conv*``: weight ~ N(0, 0.02); bias is left untouched.
      * ``*BatchNorm*``: weight ~ N(1, 0.02), bias = 0. Either is skipped when
        it is None, as for ``affine=False`` layers.
      * ``*Linear*``: weight ~ N(0, 0.02), bias = 0 when present.
    Other modules are left unchanged.
    """
    classname = m.__class__.__name__
    if 'Conv' in classname:
        m.weight.data.normal_(0.0, 0.02)
    elif 'BatchNorm' in classname:
        # BatchNorm(affine=False) has weight = bias = None; previously this
        # raised AttributeError.
        if m.weight is not None:
            m.weight.data.normal_(1.0, 0.02)
        if m.bias is not None:
            m.bias.data.fill_(0)
    elif 'Linear' in classname:
        m.weight.data.normal_(0.0, 0.02)
        if m.bias is not None:
            m.bias.data.fill_(0.0)


#############################
def save_img_results(data_img, fake, epoch, image_dir):
    """Save the first ``cfg.VIS_COUNT`` images of a batch as PNG grids.

    If ``data_img`` is given, writes ``real_samples.png`` and
    ``fake_samples_epoch_XXX.png``. Otherwise writes only
    ``lr_fake_samples_epoch_XXX.png``. All files go to ``image_dir`` and are
    saved with ``normalize=True``.
    """
    limit = cfg.VIS_COUNT
    fake = fake[0:limit]

    if data_img is None:
        vutils.save_image(
            fake.data, '%s/lr_fake_samples_epoch_%03d.png' % (image_dir, epoch),
            normalize=True)
        return

    # real images first, then the generated ones
    vutils.save_image(
        data_img[0:limit], '%s/real_samples.png' % image_dir, normalize=True)
    vutils.save_image(
        fake.data, '%s/fake_samples_epoch_%03d.png' % (image_dir, epoch),
        normalize=True)


def save_model(netG, netD, optimG, optimD, epoch, model_dir, saveD=False, saveOptim=False, max_to_keep=5):
    """Save a training checkpoint and rotate old ones.

    Writes ``<model_dir>/checkpoint_<epoch:04>.pth``: a dict with ``epoch``
    and ``netG``, plus ``netD`` when ``saveD`` is set and ``optimG``/``optimD``
    when ``saveOptim`` is set. Unsaved entries are stored as ``{}``.

    If ``max_to_keep`` is a positive int, only the newest ``max_to_keep``
    checkpoints written by this function are kept, ordered numerically by
    epoch. Other ``.pth`` files in ``model_dir`` are never deleted.
    """
    checkpoint = {
        'epoch': epoch,
        'netG': netG.state_dict(),
        'optimG': optimG.state_dict() if saveOptim else {},
        'netD': netD.state_dict() if saveD else {},
        'optimD': optimD.state_dict() if saveOptim else {}}
    torch.save(checkpoint, "{}/checkpoint_{:04}.pth".format(model_dir, epoch))
    print('Save G/D models')

    if max_to_keep is not None and max_to_keep > 0:
        # Only match this function's own files; the old '*.pth' glob also
        # deleted unrelated weights (e.g. pretrained models) in model_dir.
        checkpoints = []
        for path in glob.glob(os.path.join(model_dir, 'checkpoint_*.pth')):
            stem = os.path.basename(path)[len('checkpoint_'):-len('.pth')]
            if stem.isdigit():
                checkpoints.append((int(stem), path))
        # Numeric sort: the zero-padded names sort lexically only up to epoch 9999.
        checkpoints.sort()
        for _, path in checkpoints[:-max_to_keep]:
            os.remove(path)


def mkdir_p(path):
    """Create ``path`` and any missing parents, like ``mkdir -p``.

    An already-existing directory is not an error. Any other ``OSError`` is
    re-raised, including ``path`` existing as a non-directory.
    """
    try:
        os.makedirs(path)
    except OSError as err:
        already_a_dir = err.errno == errno.EEXIST and os.path.isdir(path)
        if not already_a_dir:
            raise


# model

# In [1]:
import torch
import torch.nn as nn
import torch.nn.parallel
#from miscc.config import cfg
#from miscc.utils import compute_transformation_matrix, compute_transformation_matrix_inverse
from torch.autograd import Variable


def conv3x3(in_planes, out_planes, stride=1):
    '''
    Return a bias-free 3x3 ``nn.Conv2d`` with padding 1.

    With the default ``stride=1`` the spatial size is preserved and only the
    channel count changes, from ``in_planes`` to ``out_planes``.
    '''
    conv_options = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_options)


# Upscale the spatial size by a factor of 2
def upBlock(in_planes, out_planes):
    '''
    Return a block that doubles H and W and maps ``in_planes`` to ``out_planes``
    channels: nearest-neighbour upsample -> 3x3 conv -> BatchNorm -> ReLU.
    '''
    layers = [
        nn.Upsample(scale_factor=2, mode='nearest'),
        conv3x3(in_planes, out_planes),
        nn.BatchNorm2d(out_planes),
        nn.ReLU(True),
    ]
    return nn.Sequential(*layers)


class ResBlock(nn.Module):
    '''
    Residual block that keeps both the channel count and the spatial size:
    out = relu(x + BN(conv(relu(BN(conv(x)))))), with 3x3 convolutions.
    '''
    def __init__(self, channel_num):
        super(ResBlock, self).__init__()
        layers = [
            conv3x3(channel_num, channel_num),
            nn.BatchNorm2d(channel_num),
            nn.ReLU(True),
            conv3x3(channel_num, channel_num),
            nn.BatchNorm2d(channel_num),
        ]
        self.block = nn.Sequential(*layers)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.block(x)
        out += x  # skip connection
        return self.relu(out)


class CA_NET(nn.Module):
    '''
    Conditioning Augmentation: maps a caption embedding to a sampled condition code.

    bs is the batch size; c_dim = cfg.GAN.CONDITION_DIM (128 by default).
    encode        : text_embedding (bs, cfg.TEXT.DIMENSION) -> Linear -> ReLU -> (bs, 2 * c_dim),
                    split into mu = out[:, :c_dim] and logvar = out[:, c_dim:].
    reparametrize : std = exp(0.5 * logvar); eps ~ N(0, I) with std's shape;
                    c_code = mu + eps * std  (bs, c_dim).
    forward       : returns (c_code, mu, logvar).
    '''
    # some code is modified from vae examples
    # (https://github.com/pytorch/examples/blob/master/vae/main.py)
    def __init__(self):
        super(CA_NET, self).__init__()
        self.t_dim = cfg.TEXT.DIMENSION  # 1024 by default
        self.c_dim = cfg.GAN.CONDITION_DIM  # 128 by default
        self.fc = nn.Linear(self.t_dim, self.c_dim * 2, bias=True)
        self.relu = nn.ReLU()

    def encode(self, text_embedding):
        # NOTE: the ReLU is applied before the split, so both mu and logvar
        # are non-negative.
        x = self.relu(self.fc(text_embedding))
        mu = x[:, :self.c_dim]
        logvar = x[:, self.c_dim:]
        return mu, logvar

    def reparametrize(self, mu, logvar):
        std = logvar.mul(0.5).exp_()
        # randn_like samples on std's device/dtype, so this works on CPU and on
        # any GPU. The old cfg.CUDA switch crashed when it disagreed with where
        # the inputs live.
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)

    def forward(self, text_embedding):
        mu, logvar = self.encode(text_embedding)
        c_code = self.reparametrize(mu, logvar)
        return c_code, mu, logvar


class D_GET_LOGITS(nn.Module):
    '''
    Discriminator head that turns a 4x4 image feature map into one logit per sample.

    Constructor:
        ndf        : base discriminator width; the feature map has ndf * 8 channels.
        nef        : size of the condition vector.
        bcondition : if True, the condition is tiled to 4x4 and concatenated
                     with the features before a 3x3 conv -> BN -> LeakyReLU.
    forward(h_code, c_code=None):
        h_code : (bs, ndf * 8, 4, 4) image features
        c_code : (bs, nef) condition. It is ignored when bcondition is False
                 or when c_code is None.
    The final 4x4 conv with stride 4 reduces the map to 1x1.

    Returns a tensor of shape (bs,).
    '''
    def __init__(self, ndf, nef, bcondition=True):
        super(D_GET_LOGITS, self).__init__()
        self.df_dim = ndf  # cfg.GAN.DF_DIM
        self.ef_dim = nef  # cfg.GAN.CONDITION_DIM
        self.bcondition = bcondition
        if not bcondition:
            head = [nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=4)]
        else:
            head = [
                conv3x3(ndf * 8 + nef, ndf * 8),  # (ndf*8 + nef) x 4 x 4 -> ndf*8 x 4 x 4
                nn.BatchNorm2d(ndf * 8),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=4),  # -> 1 x 1 x 1
            ]
        self.outlogits = nn.Sequential(*head)

    def forward(self, h_code, c_code=None):
        use_condition = self.bcondition and c_code is not None
        if use_condition:
            # tile the condition over the 4x4 grid and append it as channels
            tiled = c_code.view(-1, self.ef_dim, 1, 1).repeat(1, 1, 4, 4)
            h_code = torch.cat((h_code, tiled), 1)
        return self.outlogits(h_code).view(-1)


def stn(image, transformation_matrix, size):
    '''
    Spatial-transformer sampling step.

    Builds a sampling grid of output size ``size`` (N, C, H, W) from the
    (N, 2, 3) affine ``transformation_matrix`` with ``affine_grid``, then
    samples ``image`` on that grid with ``grid_sample``, using library defaults.
    '''
    functional = torch.nn.functional
    sampling_grid = functional.affine_grid(transformation_matrix, torch.Size(size))
    return functional.grid_sample(image, sampling_grid)


class BBOX_NET(nn.Module):  # obtains the layout encoding
    '''
    Encodes the spatial layout of the per-object labels into a flat vector.

    forward(labels, transf_matr_inv, max_objects):
        labels          : (bs, max_objects, c_dim) per-object label codes
                          (STAGE1_G passes its local_labels). The last dim is
                          assumed to equal c_dim = cfg.GAN.CONDITION_DIM.
        transf_matr_inv : per-object affine matrices for stn, indexed as [:, idx].
        max_objects     : number of objects to place.

    Each object's code is tiled to (bs, c_dim, 16, 16) and warped by stn.
    The warped maps are summed into one layout map, which is encoded:
        c_dim x 16 x 16 -> conv(stride 2) -> LeakyReLU            -> c_dim/2 x 8 x 8
                        -> conv(stride 2) -> BN -> LeakyReLU      -> c_dim/4 x 4 x 4
                        -> conv(stride 2) -> BN -> LeakyReLU      -> c_dim/8 x 2 x 2
    and flattened to (bs, c_dim/8 * 4), i.e. (bs, 64) for c_dim = 128.
    STAGE1_G concatenates this encoding with the noise and condition codes.
    '''
    def __init__(self):
        super(BBOX_NET, self).__init__()
        self.c_dim = cfg.GAN.CONDITION_DIM  # 128 by default
        self.encode = nn.Sequential(
            # 128 * 16 x 16
            conv3x3(self.c_dim, self.c_dim // 2, stride=2),
            nn.LeakyReLU(0.2, inplace=True),
            # 64 x 8 x 8
            conv3x3(self.c_dim // 2, self.c_dim // 4, stride=2),
            nn.BatchNorm2d(self.c_dim // 4),
            nn.LeakyReLU(0.2, inplace=True),
            # 32 x 4 x 4
            conv3x3(self.c_dim // 4, self.c_dim // 8, stride=2),
            nn.BatchNorm2d(self.c_dim // 8),
            nn.LeakyReLU(0.2, inplace=True),
            # 16 x 2 x 2
         )

    def forward(self, labels, transf_matr_inv, max_objects):
        batch_size = labels.shape[0]
        # Empty layout canvas on labels' device/dtype. This was hard-coded to
        # torch.cuda.FloatTensor, which failed on CPU.
        label_layout = labels.new_zeros(batch_size, self.c_dim, 16, 16)
        for idx in range(max_objects):
            current_label = labels[:, idx]  # (bs, c_dim)
            current_label = current_label.view(current_label.shape[0], current_label.shape[1], 1, 1)
            current_label = current_label.repeat(1, 1, 16, 16)  # (bs, c_dim, 16, 16)
            # warp the tiled code into this object's bbox region
            current_label = stn(current_label, transf_matr_inv[:, idx], current_label.shape)
            label_layout += current_label

        layout_encoding = self.encode(label_layout).view(batch_size, -1)  # (bs, c_dim/8 * 2 * 2)

        return layout_encoding

# ############# Networks for stageI GAN #############
class STAGE1_G(nn.Module):
    '''
    Stage-I generator with a global pathway and a local (per-object) pathway.

    forward(text_embedding, noise, transf_matrices_inv, label_one_hot, max_objects=3):
        text_embedding      : (bs, cfg.TEXT.DIMENSION) caption embedding. CA_NET
                              turns it into c_code (bs, cfg.GAN.CONDITION_DIM).
        noise               : (bs, cfg.Z_DIM)
        transf_matrices_inv : per-object affine matrices for stn, indexed as [:, idx].
        label_one_hot       : (bs, max_objects, 81) one-hot bbox labels.
        max_objects         : number of bbox slots to process.

    Local pathway: each object's label code is upsampled to 16x16 features and
    warped into its bbox with stn. Global pathway: noise + c_code (+ layout
    encoding when cfg.USE_BBOX_LAYOUT) -> fc -> 4x4 -> 16x16. The two 16x16
    maps are concatenated and upsampled to the image.

    Returns a 5-tuple ``(None, fake_img, mu, logvar, local_labels)``:
        fake_img     : (bs, 3, 64, 64), tanh output
        mu, logvar   : (bs, cfg.GAN.CONDITION_DIM) from CA_NET
        local_labels : (bs, max_objects, cfg.GAN.CONDITION_DIM) per-object label codes
    '''
    def __init__(self):
        super(STAGE1_G, self).__init__()
        self.gf_dim = cfg.GAN.GF_DIM * 8  # e.g. 192 * 8 = 1536 when GF_DIM = 192
        self.ef_dim = cfg.GAN.CONDITION_DIM  # 128 by default
        self.z_dim = cfg.Z_DIM  # 100 by default
        self.define_module()

    def define_module(self):
        ninput = self.z_dim + self.ef_dim  # noise + conditioning code
        linput = self.ef_dim + 81  # conditioning code + 81-way one-hot bbox label
        ngf = self.gf_dim
        # TEXT.DIMENSION -> GAN.CONDITION_DIM
        self.ca_net = CA_NET()

        if cfg.USE_BBOX_LAYOUT:
            self.bbox_net = BBOX_NET()
            ninput += 64  # + layout encoding from BBOX_NET

        # -> ngf x 4 x 4
        self.fc = nn.Sequential(
            nn.Linear(ninput, ngf * 4 * 4, bias=False),
            nn.BatchNorm1d(ngf * 4 * 4),
            nn.ReLU(True))

        # local pathway
        self.label = nn.Sequential(
            nn.Linear(linput, self.ef_dim, bias=False),  # linput -> ef_dim
            nn.BatchNorm1d(self.ef_dim),
            nn.ReLU(True))
        self.local1 = upBlock(self.ef_dim, ngf // 2)  # ef_dim x 4 x 4 -> ngf/2 x 8 x 8
        self.local2 = upBlock(ngf // 2, ngf // 4)  # ngf/2 x 8 x 8 -> ngf/4 x 16 x 16

        # global pathway
        # ngf x 4 x 4 -> ngf/2 x 8 x 8
        self.upsample1 = upBlock(ngf, ngf // 2)

        # ngf/2 x 8 x 8 -> ngf/4 x 16 x 16
        self.upsample2 = upBlock(ngf // 2, ngf // 4)

        # (global ngf/4 + local ngf/4) = ngf/2 x 16 x 16 -> ngf/8 x 32 x 32
        self.upsample3 = upBlock(ngf // 2, ngf // 8)

        # ngf/8 x 32 x 32 -> ngf/16 x 64 x 64
        self.upsample4 = upBlock(ngf // 8, ngf // 16)

        # ngf/16 x 64 x 64 -> 3 x 64 x 64
        self.img = nn.Sequential(conv3x3(ngf // 16, 3), nn.Tanh())

    def forward(self, text_embedding, noise, transf_matrices_inv, label_one_hot, max_objects=3):

        c_code, mu, logvar = self.ca_net(text_embedding)
        batch_size = noise.shape[0]
        # Buffers are allocated on noise's device/dtype. They were hard-coded to
        # torch.cuda.FloatTensor, which made the generator CUDA-only.
        local_labels = noise.new_zeros(batch_size, max_objects, self.ef_dim)

        # local (object) pathway
        # h_code_locals is the empty canvas on which the features are added at the locations given by the bbox
        h_code_locals = noise.new_zeros(batch_size, self.gf_dim // 4, 16, 16)
        for idx in range(max_objects):
            # generate individual label for each bounding box, based on bbox label and caption
            current_label = self.label(torch.cat((c_code, label_one_hot[:, idx]), 1))  # (bs, ef_dim)
            local_labels[:, idx] = current_label
            # replicate label spatially
            current_label = current_label.view(current_label.shape[0], self.ef_dim, 1, 1)
            current_label = current_label.repeat(1, 1, 4, 4)  # (bs, ef_dim, 4, 4)
            # apply object pathway to the label to generate object features
            h_code_local = self.local1(current_label)  # (bs, ngf/2, 8, 8)
            h_code_local = self.local2(h_code_local)  # (bs, ngf/4, 16, 16)
            # transform features to the shape of the bounding box and add to empty canvas
            h_code_local = stn(h_code_local, transf_matrices_inv[:, idx], h_code_local.shape)
            h_code_locals += h_code_local

        # global pathway
        if cfg.USE_BBOX_LAYOUT:
            bbox_code = self.bbox_net(local_labels, transf_matrices_inv, max_objects)
            z_c_code = torch.cat((noise, c_code, bbox_code), 1)
        else:
            z_c_code = torch.cat((noise, c_code), 1)
        h_code = self.fc(z_c_code)  # (bs, ngf * 4 * 4)
        h_code = h_code.view(-1, self.gf_dim, 4, 4)
        h_code = self.upsample1(h_code)  # (bs, ngf/2, 8, 8)
        h_code = self.upsample2(h_code)  # (bs, ngf/4, 16, 16)

        # combine local and global along channels -> (bs, ngf/2, 16, 16)
        h_code = torch.cat((h_code, h_code_locals), 1)

        h_code = self.upsample3(h_code)  # (bs, ngf/8, 32, 32)
        h_code = self.upsample4(h_code)  # (bs, ngf/16, 64, 64)

        # state size 3 x 64 x 64
        fake_img = self.img(h_code)
        return None, fake_img, mu, logvar, local_labels


class STAGE1_D(nn.Module):
    """Stage-I discriminator image encoder with a global and an object (local) pathway.

    The local pathway crops every bounding box out of the image with ``stn``,
    concatenates the crop with the box's 81-way one-hot label, convolves it
    and pastes the result back at the box location on a 16 x 16 canvas.  The
    global pathway downsamples the whole image to 16 x 16; both feature maps
    are concatenated and downsampled to 8ndf x 4 x 4.

    ``forward`` only returns that feature map; logits are produced separately
    through ``get_cond_logits`` (``get_uncond_logits`` is ``None`` here).
    """

    def __init__(self):
        super(STAGE1_D, self).__init__()
        self.df_dim = cfg.GAN.DF_DIM           # base number of discriminator filters (ndf)
        self.ef_dim = cfg.GAN.CONDITION_DIM    # size of the conditioning code
        self.define_module()

    def define_module(self):
        ndf, nef = self.df_dim, self.ef_dim

        # local pathway: (bbox crop: 3 channels + 81 one-hot label channels) -> 2ndf features
        self.local = nn.Sequential(
            nn.Conv2d(3 + 81, ndf * 2, 4, 1, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True)
        )

        self.act = nn.LeakyReLU(0.2, inplace=True)

        # global pathway; concatenating with the 16 x 16 local canvas requires a 64 x 64 input
        self.conv1 = nn.Conv2d(3, ndf, 4, 2, 1, bias=False)            # -> ndf x 32 x 32
        self.conv2 = nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False)      # -> 2ndf x 16 x 16
        self.bn2 = nn.BatchNorm2d(ndf * 2)
        # input = global (2ndf) + local (2ndf) features
        self.conv3 = nn.Conv2d(ndf * 4, ndf * 4, 4, 2, 1, bias=False)  # -> 4ndf x 8 x 8
        self.bn3 = nn.BatchNorm2d(ndf * 4)
        self.conv4 = nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False)  # -> 8ndf x 4 x 4
        self.bn4 = nn.BatchNorm2d(ndf * 8)

        self.get_cond_logits = D_GET_LOGITS(ndf, nef)
        self.get_uncond_logits = None

    def _encode_img(self, image, label, transf_matrices, transf_matrices_inv, max_objects):
        """Encode ``image`` into a (bs, 8ndf, 4, 4) feature map.

        Args:
            image: (bs, 3, 64, 64) image batch.
            label: per-object one-hot labels; ``label[:, idx]`` is viewed as
                81 channels, i.e. (bs, max_objects, 81).
            transf_matrices: (bs, max_objects, 2, 3) matrices used to crop each bbox.
            transf_matrices_inv: (bs, max_objects, 2, 3) matrices used to paste back.
            max_objects: number of bbox slots per image.
        """
        batch_size = image.shape[0]
        # Empty canvas on which per-object features are added at their bbox
        # location; allocated like `image` (device/dtype) instead of forcing CUDA.
        h_code_locals = image.new_zeros(batch_size, self.df_dim * 2, 16, 16)
        for idx in range(max_objects):
            # bbox label replicated spatially: (bs, 81, 16, 16)
            current_label = label[:, idx].view(label.shape[0], 81, 1, 1).repeat(1, 1, 16, 16)
            # crop the bbox region to 16 x 16 and attach its label
            h_code_local = stn(image, transf_matrices[:, idx], (batch_size, image.shape[1], 16, 16))
            h_code_local = torch.cat((h_code_local, current_label), 1)
            h_code_local = self.local(h_code_local)
            # move the object features back to the bbox layout and accumulate
            h_code_local = stn(h_code_local, transf_matrices_inv[:, idx],
                               (h_code_local.shape[0], h_code_local.shape[1], 16, 16))
            h_code_locals += h_code_local

        # global pathway -> (bs, 2ndf, 16, 16)
        h_code = self.act(self.conv1(image))
        h_code = self.act(self.bn2(self.conv2(h_code)))

        # combine global and local pathway -> (bs, 4ndf, 16, 16)
        h_code = torch.cat((h_code, h_code_locals), 1)

        h_code = self.act(self.bn3(self.conv3(h_code)))
        h_code = self.act(self.bn4(self.conv4(h_code)))
        return h_code

    def forward(self, image, label, transf_matrices, transf_matrices_inv, max_objects=3):
        """Return the encoded image features (see ``_encode_img``)."""
        img_embedding = self._encode_img(image, label, transf_matrices, transf_matrices_inv, max_objects)

        return img_embedding


# ############# Networks for stageII GAN #############
class STAGE2_G(nn.Module):
    """Stage-II generator: refines the frozen stage-I 64x64 output into a 256x256 image.

    The stage-I image is encoded to 4ngf x 16 x 16 and joined with the
    replicated caption code (plus, when ``cfg.USE_BBOX_LAYOUT`` is set, a
    spatial layout of per-object label codes).  An object pathway extracts
    features inside every stage-II bbox, upsamples them and pastes them onto a
    64 x 64 canvas that is concatenated with the global upsampling path.
    """

    def __init__(self, STAGE1_G):
        super(STAGE2_G, self).__init__()
        self.gf_dim = cfg.GAN.GF_DIM             # base number of generator filters (ngf)
        self.ef_dim = cfg.GAN.CONDITION_DIM      # size of the conditioning code
        self.z_dim = cfg.Z_DIM                   # size of the noise vector
        self.STAGE1_G = STAGE1_G
        # fix parameters of stageI GAN
        for param in self.STAGE1_G.parameters():
            param.requires_grad = False
        self.define_module()

    def _make_layer(self, block, channel_num):
        """Stack ``cfg.GAN.R_NUM`` residual ``block(channel_num)`` modules."""
        return nn.Sequential(*[block(channel_num) for _ in range(cfg.GAN.R_NUM)])

    def define_module(self):
        ngf = self.gf_dim
        # TEXT.DIMENSION -> GAN.CONDITION_DIM
        self.ca_net = CA_NET()

        # local pathway: (caption code + 81-way one-hot label) -> per-object label code
        linput = self.ef_dim + 81
        self.label = nn.Sequential(
            nn.Linear(linput, self.ef_dim, bias=False),
            nn.BatchNorm1d(self.ef_dim),
            nn.ReLU(True))
        # Input = bbox features cropped from h_code (4ngf channels) + label code
        # (ef_dim). Previously hard-coded as ef_dim + 768, which only matched GF_DIM == 192.
        self.local1 = upBlock(self.ef_dim + ngf * 4, ngf * 2)
        self.local2 = upBlock(ngf * 2, ngf)

        # encoder for the stage-I image --> 4ngf x 16 x 16
        self.encoder = nn.Sequential(
            conv3x3(3, ngf),                                    # ngf x 64 x 64
            nn.ReLU(True),
            nn.Conv2d(ngf, ngf * 2, 4, 2, 1, bias=False),       # 2ngf x 32 x 32
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            nn.Conv2d(ngf * 2, ngf * 4, 4, 2, 1, bias=False),   # 4ngf x 16 x 16
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True))
        # joint input: encoded image (4ngf) + caption code (ef_dim) [+ label layout (ef_dim)]
        if cfg.USE_BBOX_LAYOUT:
            joint_in = self.ef_dim * 2 + ngf * 4
        else:
            joint_in = self.ef_dim + ngf * 4
        self.hr_joint = nn.Sequential(
            conv3x3(joint_in, ngf * 4),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True))
        self.residual = self._make_layer(ResBlock, ngf * 4)
        # --> 2ngf x 32 x 32
        self.upsample1 = upBlock(ngf * 4, ngf * 2)
        # --> ngf x 64 x 64
        self.upsample2 = upBlock(ngf * 2, ngf)
        # input = global (ngf) + local (ngf) --> ngf // 2 x 128 x 128
        self.upsample3 = upBlock(ngf * 2, ngf // 2)
        # --> ngf // 4 x 256 x 256
        self.upsample4 = upBlock(ngf // 2, ngf // 4)
        # --> 3 x 256 x 256
        self.img = nn.Sequential(
            conv3x3(ngf // 4, 3),
            nn.Tanh())

    def forward(self, text_embedding, noise, transf_matrices_inv,
                transf_matrices_s2, transf_matrices_inv_s2, label_one_hot, max_objects=3):
        """Generate a stage-II image.

        Args:
            text_embedding: caption embeddings fed to ``CA_NET``.
            noise: (bs, z_dim) noise, forwarded to the stage-I generator.
            transf_matrices_inv: (bs, max_objects, 2, 3) stage-I paste matrices
                (also used to build the label layout).
            transf_matrices_s2: (bs, max_objects, 2, 3) stage-II crop matrices.
            transf_matrices_inv_s2: (bs, max_objects, 2, 3) stage-II paste matrices.
            label_one_hot: (bs, max_objects, 81) one-hot bbox labels.
            max_objects: number of bbox slots per image.

        Returns:
            (stage1_img, fake_img, mu, logvar, local_labels) where
            ``local_labels`` is (bs, max_objects, ef_dim).
        """
        _, stage1_img, _, _, _ = self.STAGE1_G(text_embedding, noise, transf_matrices_inv, label_one_hot)
        stage1_img = stage1_img.detach()
        encoded_img = self.encoder(stage1_img)  # (bs, 4ngf, 16, 16)

        batch_size = noise.shape[0]
        # generated label code of every bbox; allocated like `noise` (device/dtype)
        local_labels = noise.new_zeros(batch_size, max_objects, self.ef_dim)

        c_code, mu, logvar = self.ca_net(text_embedding)
        c_code_ = c_code.view(-1, self.ef_dim, 1, 1).repeat(1, 1, 16, 16)  # (bs, ef_dim, 16, 16)

        if cfg.USE_BBOX_LAYOUT:
            labels_layout = noise.new_zeros(batch_size, self.ef_dim, 16, 16)
            # bbox layout: each bbox label code at the bbox location, zeros elsewhere
            for idx in range(max_objects):
                # label code from caption code + one-hot bbox label: (bs, ef_dim)
                current_label = self.label(torch.cat((c_code, label_one_hot[:, idx]), 1))
                local_labels[:, idx] = current_label
                # replicate spatially, move to the bbox location and add to the layout
                current_label = current_label.view(current_label.shape[0], current_label.shape[1], 1, 1)
                current_label = current_label.repeat(1, 1, 16, 16)
                label_local = stn(current_label, transf_matrices_inv[:, idx],
                                  (labels_layout.shape[0], labels_layout.shape[1], 16, 16))
                labels_layout += label_local
            i_c_code = torch.cat([encoded_img, c_code_, labels_layout], 1)
        else:
            i_c_code = torch.cat([encoded_img, c_code_], 1)
        h_code = self.hr_joint(i_c_code)  # (bs, 4ngf, 16, 16)
        h_code = self.residual(h_code)

        # local pathway: object features pasted onto a (bs, ngf, 64, 64) canvas
        h_code_locals = h_code.new_zeros(h_code.shape[0], self.gf_dim, 64, 64)
        for idx in range(max_objects):
            if not cfg.USE_BBOX_LAYOUT:
                # generate local labels if not already done above
                current_label = self.label(torch.cat((c_code, label_one_hot[:, idx]), 1))
                local_labels[:, idx] = current_label
            # label code replicated spatially: (bs, ef_dim, 16, 16)
            current_label = local_labels[:, idx].view(h_code.shape[0], self.ef_dim, 1, 1)
            current_label = current_label.repeat(1, 1, 16, 16)
            # features inside the bbox, concatenated with the label: (bs, 4ngf + ef_dim, 16, 16)
            current_patch = stn(h_code, transf_matrices_s2[:, idx], (h_code.shape[0], h_code.shape[1], 16, 16))
            current_input = torch.cat((current_patch, current_label), 1)

            h_code_local = self.local1(current_input)  # (bs, 2ngf, 32, 32)
            h_code_local = self.local2(h_code_local)   # (bs, ngf, 64, 64)
            # move features to the bbox location and add to the canvas
            h_code_local = stn(h_code_local, transf_matrices_inv_s2[:, idx], h_code_locals.shape)
            h_code_locals += h_code_local

        # global upsampling path
        h_code = self.upsample1(h_code)  # (bs, 2ngf, 32, 32)
        h_code = self.upsample2(h_code)  # (bs, ngf, 64, 64)

        # combine global and local -> (bs, 2ngf, 64, 64)
        h_code = torch.cat((h_code, h_code_locals), 1)

        h_code = self.upsample3(h_code)  # (bs, ngf // 2, 128, 128)
        h_code = self.upsample4(h_code)  # (bs, ngf // 4, 256, 256)

        fake_img = self.img(h_code)      # (bs, 3, 256, 256)

        return stage1_img, fake_img, mu, logvar, local_labels


class STAGE2_D(nn.Module):
    """Stage-II discriminator image encoder with a global and an object (local) pathway.

    Each bounding box is cropped to 32 x 32, concatenated with its 81-way
    one-hot label, convolved and pasted back onto a 32 x 32 canvas.  The
    global pathway downsamples the image to 32 x 32 (which requires a
    256 x 256 input); both are concatenated and reduced to 8ndf x 4 x 4.
    """

    def __init__(self):
        super(STAGE2_D, self).__init__()
        self.df_dim = cfg.GAN.DF_DIM           # base number of discriminator filters (ndf)
        self.ef_dim = cfg.GAN.CONDITION_DIM    # size of the conditioning code
        self.define_module()

    def define_module(self):
        ndf, nef = self.df_dim, self.ef_dim
        # Local pathway: (bbox crop: 3 + 81 one-hot label channels) -> 2ndf features.
        # NOTE(review): the original author flagged that the kernel size "should
        # be 3". With k=4, s=1, p=1 each conv shrinks the map by one pixel
        # (32 -> 31 -> 30); the following stn call is given the 32 x 32 canvas
        # shape. Kept as-is so existing checkpoints still load.
        self.local = nn.Sequential(
            nn.Conv2d(3 + 81, ndf * 2, 4, 1, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 2, ndf * 2, 4, 1, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
        )

        self.act = nn.LeakyReLU(0.2, inplace=True)

        self.conv1 = nn.Conv2d(3, ndf, 4, 2, 1, bias=False)              # -> ndf x 128 x 128
        self.conv2 = nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False)        # -> 2ndf x 64 x 64
        self.bn2 = nn.BatchNorm2d(ndf * 2)
        self.conv3 = nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False)    # -> 4ndf x 32 x 32
        self.bn3 = nn.BatchNorm2d(ndf * 4)
        # input = global (4ndf) + local (2ndf) features
        self.conv4 = nn.Conv2d(ndf * 6, ndf * 8, 4, 2, 1, bias=False)    # -> 8ndf x 16 x 16
        self.bn4 = nn.BatchNorm2d(ndf * 8)
        self.conv5 = nn.Conv2d(ndf * 8, ndf * 16, 4, 2, 1, bias=False)   # -> 16ndf x 8 x 8
        self.bn5 = nn.BatchNorm2d(ndf * 16)
        self.conv6 = nn.Conv2d(ndf * 16, ndf * 32, 4, 2, 1, bias=False)  # -> 32ndf x 4 x 4
        self.bn6 = nn.BatchNorm2d(ndf * 32)
        self.conv7 = conv3x3(ndf * 32, ndf * 16)                         # -> 16ndf x 4 x 4
        self.bn7 = nn.BatchNorm2d(ndf * 16)
        self.conv8 = conv3x3(ndf * 16, ndf * 8)                          # -> 8ndf x 4 x 4
        self.bn8 = nn.BatchNorm2d(ndf * 8)

        self.get_cond_logits = D_GET_LOGITS(ndf, nef, bcondition=True)
        self.get_uncond_logits = D_GET_LOGITS(ndf, nef, bcondition=False)

    def _encode_img(self, image, label, transf_matrices, transf_matrices_inv, max_objects):
        """Encode a (bs, 3, 256, 256) image into a (bs, 8ndf, 4, 4) feature map.

        Args:
            image: (bs, 3, 256, 256) image batch.
            label: per-object one-hot labels; ``label[:, idx]`` is viewed as
                81 channels, i.e. (bs, max_objects, 81).
            transf_matrices: (bs, max_objects, 2, 3) matrices used to crop each bbox.
            transf_matrices_inv: (bs, max_objects, 2, 3) matrices used to paste back.
            max_objects: number of bbox slots per image.
        """
        batch_size = image.shape[0]
        # canvas for object features, allocated like `image` instead of forcing CUDA
        h_code_locals = image.new_zeros(batch_size, self.df_dim * 2, 32, 32)
        for idx in range(max_objects):
            # bbox label replicated spatially: (bs, 81, 32, 32)
            current_label = label[:, idx].view(label.shape[0], 81, 1, 1).repeat(1, 1, 32, 32)
            # crop bbox to 32 x 32 and attach its label: (bs, 84, 32, 32)
            h_code_local = stn(image, transf_matrices[:, idx], (batch_size, image.shape[1], 32, 32))
            h_code_local = torch.cat((h_code_local, current_label), 1)
            h_code_local = self.local(h_code_local)
            # move features back to the bbox location and accumulate: (bs, 2ndf, 32, 32)
            h_code_local = stn(h_code_local, transf_matrices_inv[:, idx], h_code_locals.shape)
            h_code_locals += h_code_local

        # global downsampling path
        h_code = self.act(self.conv1(image))                # (bs, ndf, 128, 128)
        h_code = self.act(self.bn2(self.conv2(h_code)))     # (bs, 2ndf, 64, 64)
        h_code = self.act(self.bn3(self.conv3(h_code)))     # (bs, 4ndf, 32, 32)

        # combine global and local -> (bs, 6ndf, 32, 32)
        h_code = torch.cat((h_code, h_code_locals), 1)

        h_code = self.act(self.bn4(self.conv4(h_code)))     # (bs, 8ndf, 16, 16)
        h_code = self.act(self.bn5(self.conv5(h_code)))     # (bs, 16ndf, 8, 8)
        h_code = self.act(self.bn6(self.conv6(h_code)))     # (bs, 32ndf, 4, 4)
        h_code = self.act(self.bn7(self.conv7(h_code)))     # (bs, 16ndf, 4, 4)
        h_code = self.act(self.bn8(self.conv8(h_code)))     # (bs, 8ndf, 4, 4)

        return h_code

    def forward(self, image, label, transf_matrices, transf_matrices_inv, max_objects=3):
        """Return the encoded image features (see ``_encode_img``)."""
        img_embedding = self._encode_img(image, label, transf_matrices, transf_matrices_inv, max_objects)

        return img_embedding


# trainer

In [None]:
from __future__ import print_function
from six.moves import range
from PIL import Image

import torch.backends.cudnn as cudnn
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
import os
import time

import numpy as np
import torchfile

from miscc.config import cfg
from miscc.utils import mkdir_p
from miscc.utils import weights_init
from miscc.utils import save_img_results, save_model
from miscc.utils import KL_loss
from miscc.utils import compute_discriminator_loss, compute_generator_loss
from miscc.utils import compute_transformation_matrix, compute_transformation_matrix_inverse
from miscc.utils import load_validation_data

from tensorboard import summary
from tensorboard import FileWriter


class GANTrainer(object):
    """Train and sample the two-stage, layout-conditioned GAN.

    The network pair is chosen by the ``stage`` argument of :meth:`train` /
    :meth:`sample`, while per-batch input handling follows ``cfg.STAGE``;
    the two are expected to agree.
    """

    def __init__(self, output_dir):
        """Create output directories (when training) and select the GPUs.

        Args:
            output_dir: root directory; ``Model``, ``Image`` and ``Log``
                sub-directories are created in it when ``cfg.TRAIN.FLAG`` is set.
        """
        if cfg.TRAIN.FLAG:
            self.model_dir = os.path.join(output_dir, 'Model')
            self.image_dir = os.path.join(output_dir, 'Image')
            self.log_dir = os.path.join(output_dir, 'Log')
            mkdir_p(self.model_dir)
            mkdir_p(self.image_dir)
            mkdir_p(self.log_dir)
            # event file for the scalar summaries written by train()
            self.summary_writer = FileWriter(self.log_dir)

        self.max_epoch = cfg.TRAIN.MAX_EPOCH
        self.snapshot_interval = cfg.TRAIN.SNAPSHOT_INTERVAL

        s_gpus = cfg.GPU_ID.split(',')
        self.gpus = [int(ix) for ix in s_gpus]
        self.num_gpus = len(self.gpus)
        # NOTE(review): not scaled by num_gpus. The noise and real/fake label
        # tensors are sized with it, so the data loader must yield batches of
        # exactly this size.
        self.batch_size = cfg.TRAIN.BATCH_SIZE
        torch.cuda.set_device(self.gpus[0])
        cudnn.benchmark = True

    # ############# For training stageI GAN #############
    def load_network_stageI(self):
        """Build the stage-I G/D pair, optionally loading cfg.NET_G / cfg.NET_D.

        Returns:
            (netG, netD), moved to the GPU when ``cfg.CUDA`` is set.
        """
        from model import STAGE1_G, STAGE1_D
        netG = STAGE1_G()
        netG.apply(weights_init)
        print(netG)
        netD = STAGE1_D()
        netD.apply(weights_init)
        print(netD)

        if cfg.NET_G != '':
            # generator checkpoints store the weights under the "netG" key;
            # the map_location lambda loads all tensors onto the CPU first
            state_dict = torch.load(cfg.NET_G, map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict["netG"])
            print('Load from: ', cfg.NET_G)
        if cfg.NET_D != '':
            # discriminator checkpoints are plain state dicts
            state_dict = torch.load(cfg.NET_D, map_location=lambda storage, loc: storage)
            netD.load_state_dict(state_dict)
            print('Load from: ', cfg.NET_D)
        if cfg.CUDA:
            netG.cuda()
            netD.cuda()
        return netG, netD

    # ############# For training stageII GAN  #############
    def load_network_stageII(self):
        """Build the stage-II G/D pair.

        The generator is loaded from ``cfg.NET_G`` (full stage-II checkpoint)
        or, failing that, its frozen stage-I part from ``cfg.STAGE1_G``.

        Returns:
            (netG, netD), moved to the GPU when ``cfg.CUDA`` is set.

        Raises:
            ValueError: if neither ``cfg.NET_G`` nor ``cfg.STAGE1_G`` is set.
                (Previously this printed and returned ``None``, which only
                failed later with an unrelated unpacking TypeError.)
        """
        from model import STAGE1_G, STAGE2_G, STAGE2_D

        Stage1_G = STAGE1_G()
        netG = STAGE2_G(Stage1_G)
        netG.apply(weights_init)
        print(netG)
        if cfg.NET_G != '':
            state_dict = torch.load(cfg.NET_G, map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict["netG"])
            print('Load from: ', cfg.NET_G)
        elif cfg.STAGE1_G != '':
            state_dict = torch.load(cfg.STAGE1_G, map_location=lambda storage, loc: storage)
            netG.STAGE1_G.load_state_dict(state_dict["netG"])
            print('Load from: ', cfg.STAGE1_G)
        else:
            raise ValueError("Please give the Stage1_G path")

        netD = STAGE2_D()
        netD.apply(weights_init)
        if cfg.NET_D != '':
            state_dict = torch.load(cfg.NET_D, map_location=lambda storage, loc: storage)
            netD.load_state_dict(state_dict)
            print('Load from: ', cfg.NET_D)
        print(netD)

        if cfg.CUDA:
            netG.cuda()
            netD.cuda()
        return netG, netD

    @staticmethod
    def _one_hot_labels(label, max_objects):
        """Return a float (bs, max_objects, 81) one-hot encoding of ``label``.

        ``label`` holds one class index per bbox slot (scattered along dim 2,
        so it must be 3-D, e.g. (bs, max_objects, 1)). Negative entries
        (unused slots) are mapped to class 80. The input is not modified, and
        the result lives on the same device as ``label``.
        """
        labels = label.long().clone()
        labels[labels < 0] = 80
        one_hot = torch.zeros(labels.shape[0], max_objects, 81, device=labels.device)
        return one_hot.scatter_(2, labels, 1.0)

    @staticmethod
    def _generator_inputs(txt_embedding, noise, transf_matrices_inv, label_one_hot,
                          transf_matrices_s2=None, transf_matrices_inv_s2=None):
        """Assemble the positional generator inputs for ``cfg.STAGE``."""
        if cfg.STAGE == 1:
            return (txt_embedding, noise, transf_matrices_inv, label_one_hot)
        if cfg.STAGE == 2:
            return (txt_embedding, noise, transf_matrices_inv,
                    transf_matrices_s2, transf_matrices_inv_s2, label_one_hot)
        raise ValueError('cfg.STAGE must be 1 or 2, got {}'.format(cfg.STAGE))

    def _save_samples(self, netG, inputs, real_img_cpu, epoch):
        """Run ``netG`` without gradients and save real/fake (and stage-I) images."""
        with torch.no_grad():
            lr_fake, fake, _, _, _ = nn.parallel.data_parallel(netG, inputs, self.gpus)
            save_img_results(real_img_cpu, fake, epoch, self.image_dir)
            if lr_fake is not None:
                save_img_results(None, lr_fake, epoch, self.image_dir)

    def train(self, data_loader, stage=1, max_objects=3):
        """Alternately optimise D and G for ``self.max_epoch`` epochs.

        Each batch is ``(images, bbox, label, txt_embedding)``. For
        ``cfg.STAGE == 2``, ``bbox`` is a pair: ``bbox[0]`` feeds
        ``transf_matrices_inv`` and ``bbox[1]`` the ``*_s2`` matrices.
        Scalars and sample images are written every 500 iterations, samples
        again at the end of every epoch, and checkpoints every
        ``snapshot_interval`` epochs plus once after the last epoch.

        Args:
            data_loader: iterable of training batches.
            stage: 1 builds the stage-I networks, anything else stage II.
            max_objects: number of bbox slots per image.
        """
        if stage == 1:
            netG, netD = self.load_network_stageI()
        else:
            netG, netD = self.load_network_stageII()

        nz = cfg.Z_DIM
        batch_size = self.batch_size
        noise = Variable(torch.FloatTensor(batch_size, nz))
        real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
        fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
        if cfg.CUDA:
            noise = noise.cuda()
            real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()

        generator_lr = cfg.TRAIN.GENERATOR_LR
        discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
        lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH

        # only trainable generator params (the frozen stage-I part is excluded in stage II)
        netG_para = [p for p in netG.parameters() if p.requires_grad]
        optimizerD = optim.Adam(netD.parameters(), lr=cfg.TRAIN.DISCRIMINATOR_LR, betas=(0.5, 0.999))
        optimizerG = optim.Adam(netG_para, lr=cfg.TRAIN.GENERATOR_LR, betas=(0.5, 0.999))

        count = 0
        epoch = 0  # defined even if max_epoch == 0, for the final save_model
        for epoch in range(self.max_epoch):
            start_t = time.time()
            if epoch % lr_decay_step == 0 and epoch > 0:
                # halve both learning rates every lr_decay_step epochs
                generator_lr *= 0.5
                for param_group in optimizerG.param_groups:
                    param_group['lr'] = generator_lr
                discriminator_lr *= 0.5
                for param_group in optimizerD.param_groups:
                    param_group['lr'] = discriminator_lr

            for i, data in enumerate(data_loader, 0):
                ######################################################
                # (1) Prepare training data
                ######################################################
                real_img_cpu, bbox, label, txt_embedding = data

                real_imgs = Variable(real_img_cpu)
                txt_embedding = Variable(txt_embedding)
                if cfg.CUDA:
                    real_imgs = real_imgs.cuda()
                    if cfg.STAGE == 1:
                        bbox = bbox.cuda()
                    elif cfg.STAGE == 2:
                        bbox = [bbox[0].cuda(), bbox[1].cuda()]
                    label = label.cuda()
                    txt_embedding = txt_embedding.cuda()

                n_imgs = real_imgs.shape[0]
                transf_matrices = transf_matrices_inv = None
                transf_matrices_s2 = transf_matrices_inv_s2 = None
                if cfg.STAGE == 1:
                    flat_bbox = bbox.view(-1, 4)
                    transf_matrices_inv = compute_transformation_matrix_inverse(flat_bbox)
                    transf_matrices_inv = transf_matrices_inv.view(n_imgs, max_objects, 2, 3)
                    transf_matrices = compute_transformation_matrix(flat_bbox)
                    transf_matrices = transf_matrices.view(n_imgs, max_objects, 2, 3)
                elif cfg.STAGE == 2:
                    # (the former debug print(bbox.shape) crashed here: bbox is a list)
                    flat_bbox = bbox[0].view(-1, 4)
                    transf_matrices_inv = compute_transformation_matrix_inverse(flat_bbox)
                    transf_matrices_inv = transf_matrices_inv.view(n_imgs, max_objects, 2, 3)

                    flat_bbox = bbox[1].view(-1, 4)
                    transf_matrices_inv_s2 = compute_transformation_matrix_inverse(flat_bbox)
                    transf_matrices_inv_s2 = transf_matrices_inv_s2.view(n_imgs, max_objects, 2, 3)
                    transf_matrices_s2 = compute_transformation_matrix(flat_bbox)
                    transf_matrices_s2 = transf_matrices_s2.view(n_imgs, max_objects, 2, 3)

                # matrices the discriminator uses to crop/paste object regions
                if cfg.STAGE == 1:
                    d_transf, d_transf_inv = transf_matrices, transf_matrices_inv
                else:
                    d_transf, d_transf_inv = transf_matrices_s2, transf_matrices_inv_s2

                label_one_hot = self._one_hot_labels(label, max_objects)

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                inputs = self._generator_inputs(txt_embedding, noise, transf_matrices_inv, label_one_hot,
                                                transf_matrices_s2, transf_matrices_inv_s2)
                _, fake_imgs, mu, logvar, _ = nn.parallel.data_parallel(netG, inputs, self.gpus)

                ############################
                # (3) Update D network
                ###########################
                netD.zero_grad()
                errD, errD_real, errD_wrong, errD_fake = \
                    compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                               real_labels, fake_labels,
                                               label_one_hot, d_transf, d_transf_inv,
                                               mu, self.gpus)
                errD.backward(retain_graph=True)
                optimizerD.step()

                ############################
                # (4) Update G network
                ###########################
                netG.zero_grad()
                errG = compute_generator_loss(netD, fake_imgs,
                                              real_labels, label_one_hot, d_transf, d_transf_inv,
                                              mu, self.gpus)
                kl_loss = KL_loss(mu, logvar)
                errG_total = errG + kl_loss * cfg.TRAIN.COEFF.KL
                errG_total.backward()
                optimizerG.step()

                count += 1
                if i % 500 == 0:
                    scalars = (('D_loss', errD.item()),
                               ('D_loss_real', errD_real),
                               ('D_loss_wrong', errD_wrong),
                               ('D_loss_fake', errD_fake),
                               ('G_loss', errG.item()),
                               ('KL_loss', kl_loss.item()))
                    for tag, value in scalars:
                        self.summary_writer.add_summary(summary.scalar(tag, value), count)

                    # save intermediate image results
                    self._save_samples(netG, inputs, real_img_cpu, epoch)

            # save image results of the last batch of the epoch
            self._save_samples(netG, inputs, real_img_cpu, epoch)
            end_t = time.time()
            print('''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_KL: %.4f
                     Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f
                     Total Time: %.2fsec
                  '''
                  % (epoch, self.max_epoch, i, len(data_loader),
                     errD.item(), errG.item(), kl_loss.item(),
                     errD_real, errD_wrong, errD_fake, (end_t - start_t)))
            if epoch % self.snapshot_interval == 0:
                save_model(netG, netD, optimizerG, optimizerD, epoch, self.model_dir)

        save_model(netG, netD, optimizerG, optimizerD, epoch, self.model_dir)
        self.summary_writer.close()

    def sample(self, datapath, num_samples=25, stage=1, draw_bbox=True, max_objects=3):
        """Save ``num_samples`` grids of 1 real + 9 generated validation images.

        A random validation caption is drawn per grid; the generator is run
        with 9 noise vectors for it. Grids are written as
        ``<NET_G stem>_visualize_bbox/<caption>.png``, optionally with the
        ground-truth boxes drawn in white.

        Args:
            datapath: directory prefix containing ``val_captions.t7`` and
                ``filenames.pickle`` (concatenated directly, so include the
                trailing separator).
            num_samples: number of grids to write.
            stage: 1 for the stage-I generator, anything else stage II.
            draw_bbox: overlay the bounding boxes on every image of the grid.
            max_objects: number of bbox slots per image.
        """
        from PIL import Image
        from six.moves import cPickle as pickle  # works on Python 2 and 3
        import torchvision
        import torchvision.utils as vutils
        img_dir = cfg.IMG_DIR
        if stage == 1:
            netG, _ = self.load_network_stageI()
        else:
            netG, _ = self.load_network_stageII()
        netG.eval()

        # Load text embeddings generated from the encoder
        t_file = torchfile.load(datapath + "val_captions.t7")
        captions_list = t_file.raw_txt
        embeddings = np.concatenate(t_file.fea_txt, axis=0)
        num_embeddings = len(captions_list)
        label, bbox = load_validation_data(datapath)

        filepath = os.path.join(datapath, 'filenames.pickle')
        # NOTE: pickle can execute arbitrary code; only load trusted dataset files
        with open(filepath, 'rb') as f:
            filenames = pickle.load(f)
        print('Successfully load sentences from: ', datapath)
        print('Total number of sentences:', num_embeddings)
        # path to save generated samples (NET_G without its ".pth" suffix)
        net_g_stem = cfg.NET_G[:cfg.NET_G.find('.pth')] if '.pth' in cfg.NET_G else cfg.NET_G
        save_dir = net_g_stem + "_visualize_bbox"
        print("saving to:", save_dir)
        mkdir_p(save_dir)

        # stage II uses the same boxes for the stage-I and stage-II transforms
        if cfg.STAGE == 2:
            bbox = [bbox.clone(), bbox]
        if cfg.CUDA:
            if cfg.STAGE == 2:
                bbox = [b.cuda() for b in bbox]
            else:
                bbox = bbox.cuda()
            label = label.cuda()

        #######################################
        transf_matrices_inv_s2 = transf_matrices_s2 = None
        if cfg.STAGE == 1:
            bbox_ = bbox.clone()  # kept for drawing
            flat_bbox = bbox.view(-1, 4)
            transf_matrices_inv = compute_transformation_matrix_inverse(flat_bbox)
            transf_matrices_inv = transf_matrices_inv.view(num_embeddings, max_objects, 2, 3)
        elif cfg.STAGE == 2:
            bbox_ = bbox[0].clone()  # kept for drawing
            flat_bbox = bbox[0].view(-1, 4)
            transf_matrices_inv = compute_transformation_matrix_inverse(flat_bbox)
            transf_matrices_inv = transf_matrices_inv.view(num_embeddings, max_objects, 2, 3)

            flat_bbox = bbox[1].view(-1, 4)
            transf_matrices_inv_s2 = compute_transformation_matrix_inverse(flat_bbox)
            transf_matrices_inv_s2 = transf_matrices_inv_s2.view(num_embeddings, max_objects, 2, 3)
            transf_matrices_s2 = compute_transformation_matrix(flat_bbox)
            transf_matrices_s2 = transf_matrices_s2.view(num_embeddings, max_objects, 2, 3)

        label_one_hot = self._one_hot_labels(label, max_objects)
        #######################################

        nz = cfg.Z_DIM
        n_fake = 9  # generated images per grid (plus one real image)
        noise = Variable(torch.FloatTensor(n_fake, nz))
        if cfg.CUDA:
            noise = noise.cuda()

        imsize = 64 if stage == 1 else 256

        for _ in range(num_samples):
            index = int(np.random.randint(0, num_embeddings))
            key = filenames[index]
            img_name = img_dir + "/" + key + ".jpg"
            # LANCZOS is the filter formerly aliased as ANTIALIAS (removed in Pillow 10)
            img = Image.open(img_name).convert('RGB').resize((imsize, imsize), Image.LANCZOS)
            val_image = torchvision.transforms.functional.to_tensor(img)
            val_image = val_image.view(1, 3, imsize, imsize)
            val_image = (val_image - 0.5) * 2  # [0, 1] -> [-1, 1]

            # replicate the chosen caption/boxes/labels for every noise vector
            embeddings_batch = np.reshape(embeddings[index], (1, -1)).repeat(n_fake, 0)
            transf_matrices_inv_batch = transf_matrices_inv[index].view(1, max_objects, 2, 3).repeat(n_fake, 1, 1, 1)
            label_one_hot_batch = label_one_hot[index].view(1, max_objects, 81).repeat(n_fake, 1, 1)

            transf_matrices_s2_batch = transf_matrices_inv_s2_batch = None
            if cfg.STAGE == 2:
                transf_matrices_s2_batch = transf_matrices_s2[index].view(
                    1, max_objects, 2, 3).repeat(n_fake, 1, 1, 1)
                transf_matrices_inv_s2_batch = transf_matrices_inv_s2[index].view(
                    1, max_objects, 2, 3).repeat(n_fake, 1, 1, 1)

            txt_embedding = Variable(torch.FloatTensor(embeddings_batch))
            if cfg.CUDA:
                label_one_hot_batch = label_one_hot_batch.cuda()
                txt_embedding = txt_embedding.cuda()

            #######################################################
            # (2) Generate fake images
            ######################################################
            noise.data.normal_(0, 1)
            inputs = self._generator_inputs(txt_embedding, noise, transf_matrices_inv_batch,
                                            label_one_hot_batch, transf_matrices_s2_batch,
                                            transf_matrices_inv_s2_batch)
            with torch.no_grad():
                _, fake_imgs, _, _, _ = nn.parallel.data_parallel(netG, inputs, self.gpus)

            data_img = torch.FloatTensor(n_fake + 1, 3, imsize, imsize).fill_(0)
            data_img[0] = val_image
            data_img[1:n_fake + 1] = fake_imgs

            if draw_bbox:
                for idx in range(max_objects):
                    x, y, w, h = tuple([int(imsize * coord) for coord in bbox_[index, idx]])
                    if x <= -1:
                        # negative box marks an unused slot (presumably -1 padding)
                        break
                    # clip so the right/bottom edge indices (x + w, y + h) stay in range
                    w = min(w, imsize - 1, imsize - 1 - x)
                    h = min(h, imsize - 1, imsize - 1 - y)
                    data_img[:, :, y, x:x + w] = 1
                    data_img[:, :, y:y + h, x] = 1
                    data_img[:, :, y + h, x:x + w] = 1
                    data_img[:, :, y:y + h, x + w] = 1

            vutils.save_image(data_img, '{}/{}.png'.format(save_dir, captions_list[index]), normalize=True, nrow=10)

        print("Saved {} files to {}".format(num_samples, save_dir))


# main

In [None]:
from __future__ import print_function
import torch.backends.cudnn as cudnn
import torch
import torchvision.transforms as transforms
import torch.backends.cudnn as cudnn

import argparse
import os
import random
import sys
import pprint
import datetime
import dateutil
import dateutil.tz
from shutil import copyfile


# Directory that contains this script (not the script file itself), appended to
# sys.path -- presumably so the project-local imports below (miscc, trainer)
# resolve even when launched from another working directory; verify layout.
dir_path = os.path.dirname(os.path.realpath(__file__))
sys.path.append(dir_path)

from miscc.datasets import TextDataset
from miscc.config import cfg, cfg_from_file
from miscc.utils import mkdir_p
from trainer import GANTrainer


def parse_args(argv=None):
    """Parse the command-line options of the GAN training/sampling script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default), argparse reads ``sys.argv[1:]`` exactly as before;
            passing a list allows programmatic use and testing.

    Returns:
        argparse.Namespace with attributes:
            cfg_file (str): config file path, default ``'birds_stage1.yml'``.
            gpu_id (str): GPU id string as given on the command line
                (kept as a string), default ``'0'``.
            data_dir (str): data directory override, default ``''``.
            manualSeed (int or None): random seed, ``None`` when omitted.
    """
    parser = argparse.ArgumentParser(description='Train a GAN network')
    parser.add_argument('--cfg', dest='cfg_file',
                        help='optional config file',
                        default='birds_stage1.yml', type=str)
    parser.add_argument('--gpu', dest='gpu_id', type=str, default='0')
    parser.add_argument('--data_dir', dest='data_dir', type=str, default='')
    parser.add_argument('--manualSeed', type=int, help='manual seed')
    return parser.parse_args(argv)

def _make_output_dir(output_dir):
    """Create ``output_dir`` (including parents), tolerating an existing directory.

    Raises:
        OSError: if creation fails for any reason other than the path already
            existing as a directory (e.g. a regular file is in the way).
    """
    import errno  # not imported at the top of this file
    try:
        os.makedirs(output_dir)
    except OSError as exc:  # Python 2-compatible stand-in for exist_ok=True
        if not (exc.errno == errno.EEXIST and os.path.isdir(output_dir)):
            raise


def _stage_sizes(stage):
    """Return ``(resize, imsize)`` for the given training stage.

    ``resize`` is the size passed to ``transforms.Resize`` and ``imsize`` is
    the size handed to ``TextDataset``: (76, 64) for stage 1 and (268, 256)
    for stage 2.

    Raises:
        ValueError: if ``stage`` is neither 1 nor 2 (previously this left
            ``resize``/``imsize`` unbound and failed later with a NameError).
    """
    if stage == 1:
        return 76, 64
    if stage == 2:
        return 268, 256
    raise ValueError('Unsupported cfg.STAGE: {!r} (expected 1 or 2)'.format(stage))


def _snapshot_sources(output_dir, cfg_file):
    """Copy the launching script, trainer/model sources and config into ``output_dir``.

    Apart from the launching script and ``cfg_file``, source paths are
    relative to the current working directory.
    """
    script = sys.argv[0]
    # Only the file name is used for the destination: sys.argv[0] may contain
    # directories (e.g. ``python code/main.py``) that do not exist in output_dir.
    copyfile(script, os.path.join(output_dir, os.path.basename(script)))
    for src, dst in (("trainer.py", "trainer.py"),
                     ("model.py", "model.py"),
                     ("miscc/utils.py", "utils.py"),
                     ("miscc/datasets.py", "datasets.py"),
                     (cfg_file, "cfg_file.yml")):
        copyfile(src, os.path.join(output_dir, dst))


if __name__ == "__main__":
    args = parse_args()
    if args.cfg_file is not None:
        cfg_from_file(args.cfg_file)
    # --gpu is parsed with type=str, so compare against the string '-1';
    # comparing with the int -1 was always True and never skipped the override.
    if args.gpu_id != '-1':
        cfg.GPU_ID = args.gpu_id
    if args.data_dir != '':
        cfg.DATA_DIR = args.data_dir
    print('Using config:')
    pprint.pprint(cfg)

    # Seed every RNG in use so runs are reproducible for a given seed.
    if args.manualSeed is None:
        args.manualSeed = random.randint(1, 10000)
    random.seed(args.manualSeed)
    torch.manual_seed(args.manualSeed)
    if cfg.CUDA:
        torch.cuda.manual_seed_all(args.manualSeed)

    now = datetime.datetime.now(dateutil.tz.tzlocal())
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')
    output_dir = '../../..//output/%s_%s_%s' % \
                 (cfg.DATASET_NAME, cfg.CONFIG_NAME, timestamp)

    cudnn.benchmark = True

    if cfg.TRAIN.FLAG:
        _make_output_dir(output_dir)
        _snapshot_sources(output_dir, args.cfg_file)

        resize, imsize = _stage_sizes(cfg.STAGE)
        img_transform = transforms.Compose([
            transforms.Resize((resize, resize)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        dataset = TextDataset(cfg.DATA_DIR, cfg.IMG_DIR, split="train", imsize=imsize, transform=img_transform,
                              crop=True, stage=cfg.STAGE)
        assert dataset
        dataloader = torch.utils.data.DataLoader(
            dataset, batch_size=cfg.TRAIN.BATCH_SIZE,
            drop_last=True, shuffle=True, num_workers=int(cfg.WORKERS))

        algo = GANTrainer(output_dir)
        algo.train(dataloader, cfg.STAGE)
    else:
        datapath = '%s/test/' % (cfg.DATA_DIR)
        algo = GANTrainer(output_dir)
        algo.sample(datapath, num_samples=25, stage=cfg.STAGE, draw_bbox=True)
