In [None]:
import itertools
import pickle

import torch

from dataloader.data_loader import dataloader

In [None]:
class Options:
    """Plain container for the run options consumed by the data loader and ``TDAnet``.

    Every constructor argument is stored unchanged as an attribute of the same
    name. The trailing keyword arguments cover attributes that ``TDAnet`` reads
    from the options object (``gpu_ids``, ``output_scale``, ``no_maxpooling``,
    ``update_language``) plus ``isTrain``.

    NOTE(review): the defaults of those trailing options were chosen here so the
    object is complete; confirm them against the intended training configuration.
    """

    def __init__(self, name, model, mask_type, img_file, mask_file, text_config, batchSize, prior_alpha, prior_beta,
                 lr_policy, lr, gan_mode, gpu_ids=None, output_scale=4, no_maxpooling=False,
                 update_language=False, isTrain=True):
        """Store the given option values.

        Args:
            name: run/experiment name.
            model: model identifier.
            mask_type: mask selection value (passed through untouched).
            img_file: file listing the paths of all the images.
            mask_file: file listing the mask paths.
            text_config: path of the text configuration file.
            batchSize: number of samples per batch.
            prior_alpha: stored as ``prior_alpha``; read by ``TDAnet.__init__``.
            prior_beta: stored as ``prior_beta``; read by ``TDAnet.__init__``.
            lr_policy: learning-rate policy name.
            lr: learning rate handed to the optimizers.
            gan_mode: GAN loss mode handed to ``GANLoss``.
            gpu_ids: iterable of GPU ids; ``None`` means an empty list (no GPU).
            output_scale: number of scales used for the image/mask pyramids.
            no_maxpooling: when true, ``TDAnet`` passes ``pool_attention=None``.
            update_language: when false, ``TDAnet`` freezes the text encoder.
            isTrain: whether losses and optimizers should be built.
        """
        self.name = name
        self.model = model
        self.mask_type = mask_type
        self.img_file = img_file  # has paths of all the images
        self.mask_file = mask_file
        self.text_config = text_config
        self.batchSize = batchSize
        self.prior_alpha = prior_alpha
        # Fixed: this used to be stored only under the misspelled name
        # ``prior_bets``, so reading ``opt.prior_beta`` raised AttributeError.
        self.prior_beta = prior_beta
        # Deprecated alias kept so code that read the old misspelled name still works.
        self.prior_bets = prior_beta
        self.lr_policy = lr_policy
        self.lr = lr
        self.gan_mode = gan_mode
        # Copy into a fresh list so instances never share a mutable default.
        self.gpu_ids = list(gpu_ids) if gpu_ids is not None else []
        self.output_scale = output_scale
        self.no_maxpooling = no_maxpooling
        self.update_language = update_language
        self.isTrain = isTrain

# Options for this notebook run, passed by keyword so each value lands on the
# intended field. The previous positional call supplied 11 values for 12
# required parameters (TypeError), and the values that were given were shifted
# onto the wrong fields.
opt = Options(
    name="tdanet",
    # NOTE(review): the original call had a single "tdanet" for name/model;
    # reusing it for `model` is an assumption - confirm the intended value.
    model="tdanet",
    mask_type=0,
    img_file='./datasets/CUB_200_2011/train.flist',
    mask_file='./datasets/CUB_200_2011/train_mask.flist',
    text_config='config.bird.yml',
    batchSize=10,
    prior_alpha=0.8,
    prior_beta=8,
    lr_policy='lambda',
    lr=1e-4,
    gan_mode='lsgan',
)


In [None]:
# Text-side configuration: caption length limit, DAMSM artefact paths,
# embedding width and the dataset cache files.
DEFAULT_CONFIG = dict(
    MAX_TEXT_LENGTH=128,

    # The path to DAMSM vocab pickle file
    VOCAB="./datasets/captions_vocab_bird.pickle",
    # The path to DAMSM text encoder
    LANGUAGE_ENCODER="./datasets/text_encoder_bird.pth",
    # The path to DAMSM image encoder
    IMAGE_ENCODER="./datasets/image_encoder_bird.pth",
    EMBEDDING_DIM=256,

    # The path to category-image mapping cache file
    CATE_IMAGE_TRAIN="./datasets/CUB_200_2011/cate_image_train.json",
    # The path to image-category mapping cache file
    IMAGE_CATE_TRAIN="./datasets/CUB_200_2011/image_cate.json",

    # The path to image-caption cache file
    CAPTION="./datasets/CUB_200_2011/caption.json",
)

In [None]:
# Builds multi-scale pyramids; used for both the ground-truth image and the mask
def scale_pyramid(img, num_scales):
    """Return ``img`` at ``num_scales`` resolutions, smallest first and the original last.

    Level ``k`` (for ``k`` in ``1 .. num_scales - 1``) is produced by
    ``scale_img`` with target size ``[h // 2**k, w // 2**k]``, where ``h`` and
    ``w`` are entries 2 and 3 of ``img.size()``.

    Args:
        img: object exposing ``size()`` with at least four entries.
        num_scales: number of pyramid levels including the original.

    Returns:
        list ordered from the coarsest level to ``img`` itself.
    """
    dims = img.size()
    height, width = dims[2], dims[3]

    # NOTE(review): `scale_img` is not defined in the visible cells; it must be
    # provided elsewhere before this is called with num_scales > 1.
    pyramid = [img] + [
        scale_img(img, size=[height // 2 ** level, width // 2 ** level])
        for level in range(1, num_scales)
    ]

    # Built fine-to-coarse; callers receive it coarse-to-fine.
    return pyramid[::-1]

In [None]:
# Build the data loader from the notebook options.
dataset = dataloader(opt)
# Total sample count estimate: loader length times the batch size
# (presumably len(dataset) is the number of batches - confirm).
dataset_size = len(dataset) * opt.batchSize


class TDAnet:
    """Inpainting model wiring: encoder ``E``, generator ``G``, two discriminators and the DAMSM encoders.

    NOTE(review): ``network``, ``external_function``, ``base_function`` and
    ``util`` are referenced here but are not imported in the visible cells, and
    ``setup``, ``get_distribution`` and ``get_G_inputs`` are called on ``self``
    but are not defined in this class. They have to be supplied elsewhere (for
    example by project imports and a base class) before the model can run.
    """

    def __init__(self, opt):
        """Create the networks, the language model and, in training mode, losses and optimizers.

        Args:
            opt: options object. This constructor reads ``gpu_ids``,
                ``prior_alpha``, ``prior_beta``, ``no_maxpooling``,
                ``output_scale``, ``gan_mode`` and ``lr`` from it;
                ``_init_language_model`` additionally reads ``update_language``.
        """
        # State used by the methods below. The class has no base class, so these
        # were previously read without ever being assigned.
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        # NOTE(review): falls back to training mode when `opt` has no `isTrain`
        # attribute - confirm this is the intended default.
        self.isTrain = getattr(opt, 'isTrain', True)
        self.optimizers = []

        self.loss_names = ['kl_rec', 'kl_g', 'l1_rec', 'l1_g', 'gan_g', 'word_g', 'sentence_g', 'ad_l2_g',
                           'gan_rec', 'ad_l2_rec', 'word_rec', 'sentence_rec',  'dis_img', 'dis_img_rec']
        self.log_names = []
        self.visual_names = ['img_m', 'img_truth', 'img_c', 'img_out', 'img_g', 'img_rec']
        self.text_names = ['text_positive']
        self.value_names = ['u_m', 'sigma_m', 'u_post', 'sigma_post', 'u_prior', 'sigma_prior']
        self.model_names = ['E', 'G', 'D', 'D_rec']
        self.distribution = []
        self.prior_alpha = opt.prior_alpha
        self.prior_beta = opt.prior_beta
        self.max_pool = None if opt.no_maxpooling else 'max'

        # inpainting model
        self.net_E = network.define_att_textual_e(ngf=32, z_nc=256, img_f=256, layers=5, norm='none', activation='LeakyReLU',
                          init_type='orthogonal', gpu_ids=opt.gpu_ids, image_dim=256, text_dim=256, multi_peak=False, pool_attention=self.max_pool)
        self.net_G = network.define_hidden_textual_g(f_text_dim=768, ngf=32, z_nc=256, img_f=256, L=0, layers=5, output_scale=opt.output_scale,
                                      norm='instance', activation='LeakyReLU', init_type='orthogonal', gpu_ids=opt.gpu_ids)

        # discriminator model
        self.net_D = network.define_d(ndf=32, img_f=128, layers=5, model_type='ResDis', init_type='orthogonal', gpu_ids=opt.gpu_ids)
        self.net_D_rec = network.define_d(ndf=32, img_f=128, layers=5, model_type='ResDis', init_type='orthogonal', gpu_ids=opt.gpu_ids)

        self._init_language_model(DEFAULT_CONFIG)

        if self.isTrain:
            # define the loss functions
            self.GANloss = external_function.GANLoss(opt.gan_mode)
            self.L1loss = torch.nn.L1Loss()
            self.L2loss = torch.nn.MSELoss()

            # Pretrained image encoder: loaded onto CPU first, switched to eval
            # mode, optionally moved to GPU, then frozen.
            self.image_encoder = network.CNN_ENCODER(DEFAULT_CONFIG['EMBEDDING_DIM'])
            state_dict = torch.load(
                DEFAULT_CONFIG['IMAGE_ENCODER'], map_location=lambda storage, loc: storage)
            self.image_encoder.load_state_dict(state_dict)
            self.image_encoder.eval()
            if len(self.gpu_ids) > 0 and torch.cuda.is_available():
                self.image_encoder.cuda()
            base_function._freeze(self.image_encoder)

            # define the optimizer (only parameters that still require gradients)
            self.optimizer_G = torch.optim.Adam(itertools.chain(filter(lambda p: p.requires_grad, self.net_G.parameters()),
                        filter(lambda p: p.requires_grad, self.net_E.parameters())), lr=opt.lr, betas=(0.0, 0.999))
            self.optimizer_D = torch.optim.Adam(itertools.chain(filter(lambda p: p.requires_grad, self.net_D.parameters()),
                                                filter(lambda p: p.requires_grad, self.net_D_rec.parameters())),
                                                lr=opt.lr, betas=(0.0, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)

        # NOTE(review): `setup` is not defined in this class.
        self.setup(opt)

    # The three methods below were previously indented inside __init__, which
    # made them unused local functions instead of methods of the class.

    def _init_language_model(self, text_config):
        """Load the vocabulary tables and the pretrained text encoder.

        Args:
            text_config: mapping with a ``'VOCAB'`` entry (pickle path) and a
                ``'LANGUAGE_ENCODER'`` entry (state-dict path).
        """
        # NOTE(security): pickle.load can execute arbitrary code; only load a
        # vocabulary file from a trusted source.
        with open(text_config['VOCAB'], 'rb') as vocab_file:
            x = pickle.load(vocab_file)
        # Entries 2 and 3 of the pickled sequence are the two vocabulary tables.
        self.ixtoword = x[2]
        self.wordtoix = x[3]

        word_len = len(self.wordtoix)
        self.text_encoder = network.RNN_ENCODER(word_len, nhidden=256)

        # `text_config` is a dict, so the path is looked up by key (attribute
        # access on it raised AttributeError).
        state_dict = torch.load(text_config['LANGUAGE_ENCODER'], map_location=lambda storage, loc: storage)
        self.text_encoder.load_state_dict(state_dict)
        self.text_encoder.eval()
        if not self.opt.update_language:
            self.text_encoder.requires_grad_(False)
        if len(self.gpu_ids) > 0 and torch.cuda.is_available():
            self.text_encoder.cuda()

    def set_input(self, input, epoch=0):
        """Get set of data from DataLoader and do necessary preprocessing.

        Args:
            input: mapping with the keys ``'img_path'``, ``'img'``, ``'mask'``,
                ``'caption_idx'`` and ``'caption_len'``.
            epoch: accepted for interface compatibility; not used here.
        """
        self.input = input
        self.image_paths = self.input['img_path']
        self.img = input['img']
        self.mask = input['mask']
        self.caption_idx = input['caption_idx']
        self.caption_length = input['caption_len']

        if len(self.gpu_ids) > 0:
            self.img = self.img.cuda(self.gpu_ids[0], True)
            self.mask = self.mask.cuda(self.gpu_ids[0], True)

        # get I_m and I_c for image with mask and complement regions for training
        # (img is mapped through x * 2 - 1 first; presumably [0, 1] -> [-1, 1] - confirm)
        self.img_truth = self.img * 2 - 1
        self.img_m = self.mask * self.img_truth
        self.img_c = (1 - self.mask) * self.img_truth

        # get multiple scales image ground truth and mask for training
        self.scale_img = scale_pyramid(self.img_truth, self.opt.output_scale)
        self.scale_mask = scale_pyramid(self.mask, self.opt.output_scale)

        # About text stuff: readable caption of the first sample, then the
        # embeddings and length mask for the whole batch.
        self.text_positive = util.idx_to_caption(
                                    self.ixtoword, self.caption_idx[0].tolist(), self.caption_length[0].item())
        self.word_embeddings, self.sentence_embedding = util.vectorize_captions_idx_batch(
                                                    self.caption_idx, self.caption_length, self.text_encoder)
        self.text_mask = util.lengths_to_mask(self.caption_length, max_length=self.word_embeddings.size(-1))
        # One label per sample: 0 .. batch_size - 1.
        self.match_labels = torch.LongTensor(range(len(self.img_m)))
        if len(self.gpu_ids) > 0:
            self.word_embeddings = self.word_embeddings.cuda(self.gpu_ids[0], True)
            self.sentence_embedding = self.sentence_embedding.cuda(self.gpu_ids[0], True)
            self.text_mask = self.text_mask.cuda(self.gpu_ids[0], True)
            self.match_labels = self.match_labels.cuda(self.gpu_ids[0], True)

    def forward(self):
        """Perform forward propagation of given inputs.

        Stores ``img_rec`` and ``img_g`` (one entry per generator output),
        ``img_out``, the KL terms and the image-encoder features on ``self``.
        """
        # encoder process
        distribution_factors, f, f_text = self.net_E(self.img_m, self.sentence_embedding, self.word_embeddings, self.text_mask, self.mask, self.img_c)
        # NOTE(review): `get_distribution` / `get_G_inputs` are not defined in this class.
        p_distribution, q_distribution, self.kl_rec, self.kl_g = self.get_distribution(distribution_factors)

        # decoder process
        z, f_m, f_e, mask = self.get_G_inputs(p_distribution, q_distribution, f) # prepare inputs: img, mask, distribute

        results, attn = self.net_G(z, f_text, f_e, mask)
        self.img_rec = []
        self.img_g = []
        for result in results:
            # Each generator output is split into two chunks: first the
            # reconstruction, then the generation.
            img_rec, img_g = result.chunk(2)
            self.img_rec.append(img_rec)
            self.img_g.append(img_g)
        # Composite: keep the ground truth where mask is 1, use the (detached)
        # last generated image where mask is 0.
        self.img_out = (1-self.mask) * self.img_g[-1].detach() + self.mask * self.img_truth

        self.region_features_rec, self.cnn_code_rec = self.image_encoder(self.img_rec[-1])
        self.region_features_g, self.cnn_code_g = self.image_encoder(self.img_g[-1])

In [1]:
# NOTE(review): scratch value; `a` is not used by any other visible cell.
a = dict(a=1, b=2)

In [None]:
# Instantiate the inpainting model from the notebook options.
model = TDAnet(opt)