In [5]:
import torch
import torch.nn as nn
from torch.nn import init
import functools
from torch.optim import lr_scheduler
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from torch.autograd import Variable
from external_function import SpectralNorm
import os, ntpath
import torch
from collections import OrderedDict
from util import util
import pickle
import torch
from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F
import copy
import torch.nn.functional as F
from util import task
from util import task, util
import itertools
from options.global_config import TextConfig
import pickle
import importlib
import torch
import random
import json
import pickle
import os
import numpy as np
import imageio
from PIL import Image, ImageFile
import torchvision.transforms as transforms
import torch.utils.data as data
from image_folder import make_dataset
from util import task, util
from global_config import TextConfig

In [6]:
def init_weights(net, init_type='normal', gain=0.02):
    """Get different initial method for the network weights"""
    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv')!=-1 or classname.find('Linear')!=-1):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find('BatchNorm2d') != -1:
            init.normal_(m.weight.data, 1.0, 0.02)
            init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)


def get_norm_layer(norm_type='batch'):
    """Get the normalization layer for the networks"""
    if norm_type == 'batch':
        norm_layer = functools.partial(nn.BatchNorm2d, momentum=0.1, affine=True)
    elif norm_type == 'instance':
        norm_layer = functools.partial(nn.InstanceNorm2d, affine=True)
    elif norm_type == 'none':
        norm_layer = None
    else:
        raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
    return norm_layer


def get_nonlinearity_layer(activation_type='PReLU'):
    """Get the activation layer for the networks"""
    if activation_type == 'ReLU':
        nonlinearity_layer = nn.ReLU()
    elif activation_type == 'SELU':
        nonlinearity_layer = nn.SELU()
    elif activation_type == 'LeakyReLU':
        nonlinearity_layer = nn.LeakyReLU(0.1)
    elif activation_type == 'PReLU':
        nonlinearity_layer = nn.PReLU()
    else:
        raise NotImplementedError('activation layer [%s] is not found' % activation_type)
    return nonlinearity_layer


def get_scheduler(optimizer, opt):
    """Get the training learning rate for different epoch"""
    if opt.lr_policy == 'lambda':
        def lambda_rule(epoch):
            lr_l = 1.0 - max(0, epoch+1+1+opt.iter_count-opt.niter) / float(opt.niter_decay+1)
            return lr_l
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    elif opt.lr_policy == 'step':
        scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
    elif opt.lr_policy == 'exponent':
        scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.95)
    else:
        raise NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy)
    return scheduler


def print_network(net):
    """print the network"""
    num_params = 0
    for param in net.parameters():
        num_params += param.numel()
    print(net)
    print('total number of parameters: %.3f M' % (num_params/1e6))


def init_net(net, init_type='normal', activation='relu', gpu_ids=[]):
    """print the network structure and initial the network"""
    print_network(net)

    if len(gpu_ids) > 0:
        assert(torch.cuda.is_available())
        net.cuda()
        net = torch.nn.DataParallel(net, gpu_ids)
    init_weights(net, init_type)
    return net


def _freeze(*args):
    """freeze the network for forward process"""
    for module in args:
        if module:
            for p in module.parameters():
                p.requires_grad = False


def _unfreeze(*args):
    """ unfreeze the network for parameter update"""
    for module in args:
        if module:
            for p in module.parameters():
                p.requires_grad = True


def spectral_norm(module, use_spect=True):
    """use spectral normal layer to stable the training process"""
    if use_spect:
        return SpectralNorm(module)
    else:
        return module


def coord_conv(input_nc, output_nc, use_spect=False, use_coord=False, with_r=False, **kwargs):
    """use coord convolution layer to add position information"""
    if use_coord:
        return CoordConv(input_nc, output_nc, with_r, use_spect, **kwargs)
    else:
        return spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect)

def conv1x1(in_planes, out_planes):
    "1x1 convolution with padding"
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1,
                     padding=0, bias=False)

def func_attention(query, context, gamma1):
    """
    query: batch x ndf x queryL
    context: batch x ndf x ih x iw (sourceL=ihxiw)
    mask: batch_size x sourceL
    """
    batch_size, queryL = query.size(0), query.size(2)
    ih, iw = context.size(2), context.size(3)
    sourceL = ih * iw

    # --> batch x sourceL x ndf
    context = context.view(batch_size, -1, sourceL)
    contextT = torch.transpose(context, 1, 2).contiguous()

    # Get attention
    # (batch x sourceL x ndf)(batch x ndf x queryL)
    # -->batch x sourceL x queryL
    attn = torch.bmm(contextT, query)  # Eq. (7) in AttnGAN paper
    # --> batch*sourceL x queryL
    attn = attn.view(batch_size * sourceL, queryL)
    attn = nn.Softmax()(attn)  # Eq. (8)

    # --> batch x sourceL x queryL
    attn = attn.view(batch_size, sourceL, queryL)
    # --> batch*queryL x sourceL
    attn = torch.transpose(attn, 1, 2).contiguous()
    attn = attn.view(batch_size * queryL, sourceL)
    #  Eq. (9)
    attn = attn * gamma1
    attn = nn.Softmax()(attn)
    attn = attn.view(batch_size, queryL, sourceL)
    # --> batch x sourceL x queryL
    attnT = torch.transpose(attn, 1, 2).contiguous()

    # (batch x ndf x sourceL)(batch x sourceL x queryL)
    # --> batch x ndf x queryL
    weightedContext = torch.bmm(context, attnT)

    return weightedContext, attn.view(batch_size, -1, ih, iw)

class AddCoords(nn.Module):
    """
    Add Coords to a tensor
    """
    def __init__(self, with_r=False):
        super(AddCoords, self).__init__()
        self.with_r = with_r

    def forward(self, x):
        """
        :param x: shape (batch, channel, x_dim, y_dim)
        :return: shape (batch, channel+2, x_dim, y_dim)
        """
        B, _, x_dim, y_dim = x.size()

        # coord calculate
        xx_channel = torch.arange(x_dim).repeat(B, 1, y_dim, 1).type_as(x)
        yy_cahnnel = torch.arange(y_dim).repeat(B, 1, x_dim, 1).permute(0, 1, 3, 2).type_as(x)
        # normalization
        xx_channel = xx_channel.float() / (x_dim-1)
        yy_cahnnel = yy_cahnnel.float() / (y_dim-1)
        xx_channel = xx_channel * 2 - 1
        yy_cahnnel = yy_cahnnel * 2 - 1

        ret = torch.cat([x, xx_channel, yy_cahnnel], dim=1)

        if self.with_r:
            rr = torch.sqrt(xx_channel ** 2 + yy_cahnnel ** 2)
            ret = torch.cat([ret, rr], dim=1)

        return ret


class CoordConv(nn.Module):
    """
    CoordConv operation
    """
    def __init__(self, input_nc, output_nc, with_r=False, use_spect=False, **kwargs):
        super(CoordConv, self).__init__()
        self.addcoords = AddCoords(with_r=with_r)
        input_nc = input_nc + 2
        if with_r:
            input_nc = input_nc + 1
        self.conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect)

    def forward(self, x):
        ret = self.addcoords(x)
        ret = self.conv(ret)

        return ret


class ResBlock(nn.Module):
    """
    Define an Residual block for different types
    """
    def __init__(self, input_nc, output_nc, hidden_nc=None, norm_layer=nn.BatchNorm2d, nonlinearity= nn.LeakyReLU(),
                 sample_type='none', use_spect=False, use_coord=False):
        super(ResBlock, self).__init__()

        hidden_nc = output_nc if hidden_nc is None else hidden_nc
        self.sample = True
        if sample_type == 'none':
            self.sample = False
        elif sample_type == 'up':
            output_nc = output_nc * 4
            self.pool = nn.PixelShuffle(upscale_factor=2)
        elif sample_type == 'down':
            self.pool = nn.AvgPool2d(kernel_size=2, stride=2)
        else:
            raise NotImplementedError('sample type [%s] is not found' % sample_type)

        kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1}
        kwargs_short = {'kernel_size': 1, 'stride': 1, 'padding': 0}

        self.conv1 = coord_conv(input_nc, hidden_nc, use_spect, use_coord, **kwargs)
        self.conv2 = coord_conv(hidden_nc, output_nc, use_spect, use_coord, **kwargs)
        self.bypass = coord_conv(input_nc, output_nc, use_spect, use_coord, **kwargs_short)

        if type(norm_layer) == type(None):
            self.model = nn.Sequential(nonlinearity, self.conv1, nonlinearity, self.conv2,)
        else:
            self.model = nn.Sequential(norm_layer(input_nc), nonlinearity, self.conv1, norm_layer(hidden_nc), nonlinearity, self.conv2,)

        self.shortcut = nn.Sequential(self.bypass,)

    def forward(self, x):
        if self.sample:
            out = self.pool(self.model(x)) + self.pool(self.shortcut(x))
        else:
            out = self.model(x) + self.shortcut(x)

        return out


class ResBlockEncoderOptimized(nn.Module):
    """
    Define an Encoder block for the first layer of the discriminator and representation network
    """
    def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, nonlinearity= nn.LeakyReLU(), use_spect=False, use_coord=False):
        super(ResBlockEncoderOptimized, self).__init__()

        kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1}
        kwargs_short = {'kernel_size': 1, 'stride': 1, 'padding': 0}

        self.conv1 = coord_conv(input_nc, output_nc, use_spect, use_coord, **kwargs)
        self.conv2 = coord_conv(output_nc, output_nc, use_spect, use_coord, **kwargs)
        self.bypass = coord_conv(input_nc, output_nc, use_spect, use_coord, **kwargs_short)

        if type(norm_layer) == type(None):
            self.model = nn.Sequential(self.conv1, nonlinearity, self.conv2, nn.AvgPool2d(kernel_size=2, stride=2))
        else:
            self.model = nn.Sequential(self.conv1, norm_layer(output_nc), nonlinearity, self.conv2, nn.AvgPool2d(kernel_size=2, stride=2))

        self.shortcut = nn.Sequential(nn.AvgPool2d(kernel_size=2, stride=2), self.bypass)

    def forward(self, x):
        out = self.model(x) + self.shortcut(x)

        return out


class ResBlockDecoder(nn.Module):
    """
    Define a decoder block
    """
    def __init__(self, input_nc, output_nc, hidden_nc=None, norm_layer=nn.BatchNorm2d, nonlinearity= nn.LeakyReLU(),
                 use_spect=False, use_coord=False):
        super(ResBlockDecoder, self).__init__()

        hidden_nc = output_nc if hidden_nc is None else hidden_nc

        self.conv1 = spectral_norm(nn.Conv2d(input_nc, hidden_nc, kernel_size=3, stride=1, padding=1), use_spect)
        self.conv2 = spectral_norm(nn.ConvTranspose2d(hidden_nc, output_nc, kernel_size=3, stride=2, padding=1, output_padding=1), use_spect)
        self.bypass = spectral_norm(nn.ConvTranspose2d(input_nc, output_nc, kernel_size=3, stride=2, padding=1, output_padding=1), use_spect)

        if type(norm_layer) == type(None):
            self.model = nn.Sequential(nonlinearity, self.conv1, nonlinearity, self.conv2,)
        else:
            self.model = nn.Sequential(norm_layer(input_nc), nonlinearity, self.conv1, norm_layer(hidden_nc), nonlinearity, self.conv2,)

        self.shortcut = nn.Sequential(self.bypass)

    def forward(self, x):
        out = self.model(x) + self.shortcut(x)

        return out


class Output(nn.Module):
    """
    Define the output layer
    """
    def __init__(self, input_nc, output_nc, kernel_size = 3, norm_layer=nn.BatchNorm2d, nonlinearity= nn.LeakyReLU(),
                 use_spect=False, use_coord=False):
        super(Output, self).__init__()

        kwargs = {'kernel_size': kernel_size, 'padding':0, 'bias': True}

        self.conv1 = coord_conv(input_nc, output_nc, use_spect, use_coord, **kwargs)

        if type(norm_layer) == type(None):
            self.model = nn.Sequential(nonlinearity, nn.ReflectionPad2d(int(kernel_size/2)), self.conv1, nn.Tanh())
        else:
            self.model = nn.Sequential(norm_layer(input_nc), nonlinearity, nn.ReflectionPad2d(int(kernel_size / 2)), self.conv1, nn.Tanh())

    def forward(self, x):
        out = self.model(x)

        return out


class Auto_Attn(nn.Module):
    """ Short+Long attention Layer"""

    def __init__(self, input_nc, norm_layer=nn.BatchNorm2d):
        super(Auto_Attn, self).__init__()
        self.input_nc = input_nc

        self.query_conv = nn.Conv2d(input_nc, input_nc // 4, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.alpha = nn.Parameter(torch.zeros(1))

        self.softmax = nn.Softmax(dim=-1)

        self.model = ResBlock(int(input_nc*2), input_nc, input_nc, norm_layer=norm_layer, use_spect=True)

    def forward(self, x, pre=None, mask=None):
        """
        inputs :
            x : input feature maps( B X C X W X H)
        returns :
            out : self attention value + input feature
            attention: B X N X N (N is Width*Height)
        """
        B, C, W, H = x.size()
        proj_query = self.query_conv(x).view(B, -1, W * H)  # B X (N)X C
        proj_key = proj_query  # B X C x (N)

        energy = torch.bmm(proj_query.permute(0, 2, 1), proj_key)  # transpose check
        attention = self.softmax(energy)  # BX (N) X (N)
        proj_value = x.view(B, -1, W * H)  # B X C X N

        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(B, C, W, H)

        out = self.gamma * out + x

        if type(pre) != type(None):
            # using long distance attention layer to copy information from valid regions
            context_flow = torch.bmm(pre.view(B, -1, W*H), attention.permute(0, 2, 1)).view(B, -1, W, H)
            context_flow = self.alpha * (1-mask) * context_flow + (mask) * pre
            out = self.model(torch.cat([out, context_flow], dim=1))

        return out, attention


class ImageTextAttention(nn.Module):
    """
    Global attention takes a matrix and a query metrix.
    Based on each query vector q, it computes a parameterized convex combination of the matrix
    based.
    H_1 H_2 H_3 ... H_n
      q   q   q       q
        |  |   |       |
          \ |   |      /
                  .....
              \   |  /
                      a
    Constructs a unit mapping.
    $$(H_1 + H_n, q) => (a)$$
    Where H is of `batch x n x dim` and q is of `batch x dim`.
    References:
    https://github.com/OpenNMT/OpenNMT-py/tree/fc23dfef1ba2f258858b2765d24565266526dc76/onmt/modules
    http://www.aclweb.org/anthology/D15-1166

    :param idf: image dimension
    :param cdf: text dimension
    :param multi_peak: use sigmoid when computing word attention
    :param pooling: pooling layer type on weightedConext
    """
    def __init__(self, idf, cdf, multi_peak=False, pooling='max'):
        super(ImageTextAttention, self).__init__()
        self.conv_image = conv1x1(idf, cdf)
        self.sm = nn.Softmax()
        self.multi_peak = multi_peak
        self.sigmoid = nn.Sigmoid()
        self.pooling = pooling
        if self.pooling == 'max':
            self.pooling_layer = nn.AdaptiveMaxPool2d(1)
        elif self.pooling == 'avg':
            self.pooling_layer = nn.AdaptiveAvgPool2d(1)
        else:
            self.pooling = False

    def forward_softmax(self, image, text, mask=None, image_mask=None, inverse_attention=False):
        """
            input: batch x idf x ih x iw (image_L=ihxiw)
            context: batch x cdf x text_L
        """
        ih, iw = image.size(2), image.size(3)
        image_L = ih * iw
        batch_size, text_L = text.size(0), text.size(2)

        # --> batch x image_L x idf
        image = self.conv_image(image)
        image_flat = image.view(batch_size, -1, image_L)
        image_flat_T = torch.transpose(image_flat, 1, 2).contiguous()

        # Get attention
        # (batch x image_L x idf)(batch x idf x text_L)
        # -->batch x image_L x text_L
        attn = torch.bmm(image_flat_T, text)
        if inverse_attention:
            attn *= -1

        if image_mask is not None:
            # in img_mask, 0 is masked, so here we inverse the mask value
            image_mask = (1-image_mask).bool()#torch.logical_not(image_mask.bool())
            image_mask = image_mask.view(-1, image_L, 1).repeat(1, 1, text_L)

            attn.data.masked_fill_(image_mask.data, -float('inf'))

        # --> batch*image_L x text_L
        attn = attn.view(batch_size*image_L, text_L)
        if mask is not None:
            # batch_size x text_L --> batch_size*image_L x text_L
            mask = mask.repeat(image_L, 1)
            attn.data.masked_fill_(mask.data, -float('inf'))

        attn = self.sm(attn)  # Eq. (2)
        attn.data.masked_fill_(attn != attn, 0)
        # --> batch x image_L x text_L
        attn = attn.view(batch_size, image_L, text_L)
        # --> batch x text_L x image_L
        attn = torch.transpose(attn, 1, 2).contiguous()

        # (batch x idf x text_L)(batch x text_L x image_L)
        # --> batch x idf x image_L
        weightedContext = torch.bmm(text, attn)
        weightedContext = weightedContext.view(batch_size, -1, ih, iw)
        attn = attn.view(batch_size, -1, ih, iw)

        # torch.save(attn.detach(), 'attention_map.pt')

        return weightedContext

    def forward_sigmoid(self, image, text, mask=None, image_mask=None, inverse_attention=False):
        """
            input: batch x idf x ih x iw (image_L=ihxiw)
            context: batch x cdf x text_L
        """
        ih, iw = image.size(2), image.size(3)
        image_L = ih * iw
        batch_size, text_L = text.size(0), text.size(2)

        # --> batch x image_L x idf
        image = self.conv_image(image)
        image_flat = image.view(batch_size, -1, image_L)
        image_flat_T = torch.transpose(image_flat, 1, 2).contiguous()

        # Get attention
        # (batch x image_L x idf)(batch x idf x text_L)
        # -->batch x image_L x text_L
        attn = torch.bmm(image_flat_T, text)

        # Apply mask
        if image_mask is not None:
            # in img_mask, 0 is masked, so here we inverse the mask value
            image_mask = (1-image_mask).bool()#torch.logical_not(image_mask.bool())
            image_mask = image_mask.view(-1, image_L, 1).repeat(1, 1, text_L)

            attn.data.masked_fill_(image_mask.data, -float('inf'))

        # --> batch*image_L x text_L
        attn = attn.view(batch_size*image_L, text_L)
        if mask is not None:
            # batch_size x text_L --> batch_size*image_L x text_L
            mask = mask.repeat(image_L, 1)
            attn.data.masked_fill_(mask.data, -float('inf'))

        attn = self.sigmoid(attn)  # Eq. (2)
        if inverse_attention:
            attn = 1 - attn

        attn.data.masked_fill_(attn != attn, 0)
        # --> batch x image_L x text_L
        attn = attn.view(batch_size, image_L, text_L)
        # --> batch x text_L x image_L
        attn = torch.transpose(attn, 1, 2).contiguous()

        # (batch x idf x text_L)(batch x text_L x image_L)
        # --> batch x idf x image_L
        weightedContext = torch.bmm(text, attn)
        weightedContext = weightedContext.view(batch_size, -1, ih, iw)
        attn = attn.view(batch_size, -1, ih, iw)

        return weightedContext

    def forward(self, image, text, mask=None, image_mask=None, inverse_attention=False):
        if self.multi_peak:
            weightedContext = self.forward_sigmoid(image, text, mask, image_mask, inverse_attention)
        else:
            weightedContext = self.forward_softmax(image, text, mask, image_mask, inverse_attention)
        if self.pooling is not False:
            ih, iw = weightedContext.size(2), weightedContext.size(3)
            weightedContext = self.pooling_layer(weightedContext)
            weightedContext = weightedContext.repeat(1, 1, ih, iw)

        return weightedContext

class GlobalAttentionGeneral(nn.Module):
    """
    Global attention takes a matrix and a query metrix.
    Based on each query vector q, it computes a parameterized convex combination of the matrix
    based.
    H_1 H_2 H_3 ... H_n
      q   q   q       q
        |  |   |       |
          \ |   |      /
                  .....
              \   |  /
                      a
    Constructs a unit mapping.
    $$(H_1 + H_n, q) => (a)$$
    Where H is of `batch x n x dim` and q is of `batch x dim`.
    References:
    https://github.com/OpenNMT/OpenNMT-py/tree/fc23dfef1ba2f258858b2765d24565266526dc76/onmt/modules
    http://www.aclweb.org/anthology/D15-1166
    """

    def __init__(self, idf, cdf):
        super(GlobalAttentionGeneral, self).__init__()
        self.conv_context = conv1x1(cdf, idf)
        self.sm = nn.Softmax()

    def forward(self, input, context, mask=None, inverse_attention=False):
        """
            input: batch x idf x ih x iw (queryL=ihxiw)
            context: batch x cdf x sourceL
        """
        ih, iw = input.size(2), input.size(3)
        queryL = ih * iw
        batch_size, sourceL = context.size(0), context.size(2)

        # --> batch x queryL x idf
        target = input.view(batch_size, -1, queryL)
        targetT = torch.transpose(target, 1, 2).contiguous()
        # batch x cdf x sourceL --> batch x cdf x sourceL x 1
        sourceT = context.unsqueeze(3)
        # --> batch x idf x sourceL
        sourceT = self.conv_context(sourceT).squeeze(3)

        # Get attention
        # (batch x queryL x idf)(batch x idf x sourceL)
        # -->batch x queryL x sourceL
        attn = torch.bmm(targetT, sourceT)
        if inverse_attention:
            attn *= -1
        attn = attn.view(batch_size * queryL, sourceL)
        if mask is not None:
            # batch_size x sourceL --> batch_size*queryL x sourceL
            mask = mask.repeat(queryL, 1)
            attn.data.masked_fill_(mask.data, -float('inf'))

        attn = self.sm(attn)  # Eq. (2)
        # --> batch x queryL x sourceL
        attn = attn.view(batch_size, queryL, sourceL)
        # --> batch x sourceL x queryL
        attn = torch.transpose(attn, 1, 2).contiguous()

        # (batch x idf x sourceL)(batch x sourceL x queryL)
        # --> batch x idf x queryL
        weightedContext = torch.bmm(sourceT, attn)
        weightedContext = weightedContext.view(batch_size, -1, ih, iw)
        attn = attn.view(batch_size, -1, ih, iw)

        return weightedContext

# ##################Loss for matching text-image###################
def cosine_similarity(x1, x2, dim=1, eps=1e-8):
    """Returns cosine similarity between x1 and x2, computed along dim.
    """
    w12 = torch.sum(x1 * x2, dim)
    w1 = torch.norm(x1, 2, dim)
    w2 = torch.norm(x2, 2, dim)
    return (w12 / (w1 * w2).clamp(min=eps))


def sent_loss(cnn_code, rnn_code, labels, eps=1e-8, smooth_gama3=10.0):

    # --> seq_len x batch_size x nef
    if cnn_code.dim() == 2:
        cnn_code = cnn_code.unsqueeze(0)
        rnn_code = rnn_code.unsqueeze(0)

    # cnn_code_norm / rnn_code_norm: seq_len x batch_size x 1
    cnn_code_norm = torch.norm(cnn_code, 2, dim=2, keepdim=True)
    rnn_code_norm = torch.norm(rnn_code, 2, dim=2, keepdim=True)
    # scores* / norm*: seq_len x batch_size x batch_size
    scores0 = torch.bmm(cnn_code, rnn_code.transpose(1, 2))
    norm0 = torch.bmm(cnn_code_norm, rnn_code_norm.transpose(1, 2))
    scores0 = scores0 / norm0.clamp(min=eps) * smooth_gama3

    # --> batch_size x batch_size
    scores0 = scores0.squeeze(0)

    scores1 = scores0.transpose(0, 1)
    if labels is not None:
        loss0 = nn.CrossEntropyLoss()(scores0, labels)
        loss1 = nn.CrossEntropyLoss()(scores1, labels)
    else:
        loss0, loss1 = None, None
    return loss0 + loss1


def words_loss(img_features, words_emb, labels, cap_lens, batch_size,
                            smooth_gamma1=5.0, smooth_gamma2=5.0, smooth_gamma3=10.0):
    """
        words_emb(query): batch x nef x seq_len
        img_features(context): batch x nef x 17 x 17
    """

    att_maps = []
    similarities = []
    cap_lens = cap_lens.data.tolist()
    for i in range(batch_size):

        # Get the i-th text description
        words_num = cap_lens[i]
        # -> 1 x nef x words_num
        word = words_emb[i, :, :words_num].unsqueeze(0).contiguous()
        # -> batch_size x nef x words_num
        word = word.repeat(batch_size, 1, 1)
        # batch x nef x 17*17
        context = img_features
        """
            word(query): batch x nef x words_num
            context: batch x nef x 17 x 17
            weiContext: batch x nef x words_num
            attn: batch x words_num x 17 x 17
        """
        weiContext, attn = func_attention(word, context, smooth_gamma1)
        att_maps.append(attn[i].unsqueeze(0).contiguous())
        # --> batch_size x words_num x nef
        word = word.transpose(1, 2).contiguous()
        weiContext = weiContext.transpose(1, 2).contiguous()
        # --> batch_size*words_num x nef
        word = word.view(batch_size * words_num, -1)
        weiContext = weiContext.view(batch_size * words_num, -1)
        #
        # -->batch_size*words_num
        row_sim = cosine_similarity(word, weiContext)
        # --> batch_size x words_num
        row_sim = row_sim.view(batch_size, words_num)

        # Eq. (10)
        row_sim.mul_(smooth_gamma2).exp_()
        row_sim = row_sim.sum(dim=1, keepdim=True)
        row_sim = torch.log(row_sim)

        # --> 1 x batch_size
        # similarities(i, j): the similarity between the i-th image and the j-th text description
        similarities.append(row_sim)

    # batch_size x batch_size
    similarities = torch.cat(similarities, 1)

    similarities = similarities * smooth_gamma3

    similarities1 = similarities.transpose(0, 1)
    if labels is not None:
        loss0 = nn.CrossEntropyLoss()(similarities, labels)
        loss1 = nn.CrossEntropyLoss()(similarities1, labels)
    else:
        loss0, loss1 = None, None
    return loss0 + loss1, att_maps

class GANHingeLoss(nn.Module):
    def __init__(self):
        super(GANHingeLoss, self).__init__()
        self.activation = nn.ReLU(inplace=True)

    def __call__(self, pos, neg):
        hinge_pos = torch.mean(self.activation(1-pos))
        hinge_neg = torch.mean(self.activation(1+neg))
        d_loss = .5 * hinge_pos + .5 * hinge_neg
        g_loss = -torch.mean(neg)

        return d_loss, g_loss

In [7]:
class BaseModel():
    def __init__(self, opt):
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.isTrain = opt.isTrain
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
        self.loss_names = []
        self.log_names = []
        self.model_names = []
        self.visual_names = []
        self.text_names = []
        self.value_names = []
        self.image_paths = []
        self.optimizers = []
        self.schedulers = []

    def name(self):
        return 'BaseModel'

    @staticmethod
    def modify_options(parser, is_train):
        """Add new options and rewrite default values for existing options"""
        return parser

    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps"""
        pass

    def setup(self, opt):
        """Load networks, create schedulers"""
        if self.isTrain:
            self.schedulers = [base_function.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
        if not self.isTrain or opt.continue_train:
            self.load_networks(opt.which_iter, opt.gpu_ids)

    def eval(self):
        """Make models eval mode during test time"""
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net_' + name)
                net.eval()

    def get_image_paths(self):
        """Return image paths that are used to load current data"""
        return self.image_paths

    def update_learning_rate(self):
        """Update learning rate"""
        for scheduler in self.schedulers:
            scheduler.step()
        lr = self.optimizers[0].param_groups[0]['lr']
        print('learning rate=%.7f' % lr)

    def get_current_errors(self):
        """Return training loss"""
        errors_ret = OrderedDict()
        for name in self.loss_names:
            if isinstance(name, str):
                errors_ret[name] = getattr(self, 'loss_' + name).item()

        if 'img_truth' in self.visual_names:
            truth = getattr(self, 'img_truth')
            outputs_names = ['img_out', 'img_g', 'img_rec']
            for name in outputs_names:
                if name in self.visual_names:
                    out = getattr(self, name)
                    psnr = util.PSNR(util.tensor2im(out[-1].data), util.tensor2im(truth[-1].data))
                    errors_ret['psnr_'+name] = psnr

        return errors_ret

    def get_current_visuals(self):
        """Return visualization images"""
        visual_ret = OrderedDict()
        for name in self.visual_names:
            if isinstance(name, str):
                value = getattr(self, name)
                if isinstance(value, list):
                    visual_ret[name] = util.tensor2im(value[-1].data)
                else:
                    visual_ret[name] = util.tensor2im(value.data)
        return visual_ret

    def get_current_text(self):
        """Return the last image's caption of current batch"""
        text_ret = OrderedDict()
        for name in self.text_names:
            if isinstance(name, str):
                text = getattr(self, name)
                if isinstance(text, list):
                    text_ret[name] = text[-1] + '\n'+ self.image_paths[0]
                else:
                    text_ret[name] = text + '\n' + self.image_paths[0]
        return text_ret

    def get_current_dis(self):
        """Return the distribution of encoder features"""
        dis_ret = OrderedDict()
        value = getattr(self, 'distribution')
        for i in range(1):
            for j, name in enumerate(self.value_names):
                if isinstance(name, str):
                    dis_ret[name+str(i)] =util.tensor2array(value[i][j].data)

        return dis_ret

    # save model
    def save_networks(self, which_epoch):
        """Save all the networks to the disk"""
        for name in self.model_names:
            if isinstance(name, str):
                save_filename = '%s_net_%s.pth' % (which_epoch, name)
                save_path = os.path.join(self.save_dir, save_filename)
                net = getattr(self, 'net_' + name)
                torch.save(net.cpu().state_dict(), save_path)
                if len(self.gpu_ids) > 0 and torch.cuda.is_available():
                    net.cuda()

    # load models
    def load_networks(self, which_epoch, gpu_ids):
        """Load all the networks from the disk"""
        for name in self.model_names:
            if isinstance(name, str):
                filename = '%s_net_%s.pth' % (which_epoch, name)
                path = os.path.join(self.save_dir, filename)
                net = getattr(self, 'net_' + name)
                pretrained_dict = torch.load(path)
                try:
                    if len(gpu_ids) != 0:
                        net.load_state_dict(pretrained_dict)
                    else:
                        pretrained_dict_cpu = {key[7:]:value for key, value in pretrained_dict.items()}
                        net.load_state_dict(pretrained_dict_cpu)
                except:
                    model_dict = net.state_dict()
                    try:
                        pretrained_dict = {k:v for k,v in pretrained_dict.items() if k in model_dict}
                        net.load_state_dict(pretrained_dict)
                        print('Pretrained network %s has excessive layers; Only loading layers that are used' % name)
                    except:
                        print('Pretrained network %s has fewer layers; The following are not initialized:' % name)
                        not_initialized = set()
                        for k, v in pretrained_dict.items():
                            if v.size() == model_dict[k].size():
                                model_dict[k] = v

                        for k, v in model_dict.items():
                            if k not in pretrained_dict or v.size() != pretrained_dict[k].size():
                                not_initialized.add(k.split('.')[0])
                        print(sorted(not_initialized))
                        net.load_state_dict(model_dict)
                if len(self.gpu_ids) > 0 and torch.cuda.is_available():
                    net.cuda()
                if not self.isTrain:
                    net.eval()

    def save_results(self, save_data, score=None, data_name='none', mark=None):
        """Save the training or testing results to disk"""
        img_paths = self.get_image_paths()

        for i in range(save_data.size(0)):
            print('process image ...... %s' % img_paths[i])
            short_path = ntpath.basename(img_paths[i])  # get image path
            name = os.path.splitext(short_path)[0]
            if type(score) == type(None):
                img_name = '%s_%s.png' % (name, data_name)
            else:
                # d_score = score[i].mean()
                # img_name = '%s_%s_%s.png' % (name, data_name, str(round(d_score.item(), 3)))
                if type(mark) == type(None):
                    img_name = '%s_%s_%s.png' % (name, data_name, str(score))
                else:
                    img_name = '%s_%s_%s_%s.png' % (name, data_name, str(score), str(mark))
            # save predicted image with discriminator score
            util.mkdir(self.opt.results_dir)
            img_path = os.path.join(self.opt.results_dir, img_name)
            img_numpy = util.tensor2im(save_data[i].data)
            util.save_image(img_numpy, img_path)

In [8]:
####################################################################################################
# spectral normalization layer to decouple the magnitude of a weight tensor
####################################################################################################

def l2normalize(v, eps=1e-12):
    return v / (v.norm() + eps)


class SpectralNorm(nn.Module):
    """
    spectral normalization
    code and idea originally from Takeru Miyato's work 'Spectral Normalization for GAN'
    https://github.com/christiancosgrove/pytorch-spectral-normalization-gan
    """
    def __init__(self, module, name='weight', power_iterations=1):
        super(SpectralNorm, self).__init__()
        self.module = module
        self.name = name
        self.power_iterations = power_iterations
        if not self._made_params():
            self._make_params()

    def _update_u_v(self):
        u = getattr(self.module, self.name + "_u")
        v = getattr(self.module, self.name + "_v")
        w = getattr(self.module, self.name + "_bar")

        height = w.data.shape[0]
        for _ in range(self.power_iterations):
            v.data = l2normalize(torch.mv(torch.t(w.view(height,-1).data), u.data))
            u.data = l2normalize(torch.mv(w.view(height,-1).data, v.data))

        sigma = u.dot(w.view(height, -1).mv(v))
        setattr(self.module, self.name, w / sigma.expand_as(w))

    def _made_params(self):
        try:
            u = getattr(self.module, self.name + "_u")
            v = getattr(self.module, self.name + "_v")
            w = getattr(self.module, self.name + "_bar")
            return True
        except AttributeError:
            return False

    def _make_params(self):
        w = getattr(self.module, self.name)

        height = w.data.shape[0]
        width = w.view(height, -1).data.shape[1]

        u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
        v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
        u.data = l2normalize(u.data)
        v.data = l2normalize(v.data)
        w_bar = Parameter(w.data)

        del self.module._parameters[self.name]

        self.module.register_parameter(self.name + "_u", u)
        self.module.register_parameter(self.name + "_v", v)
        self.module.register_parameter(self.name + "_bar", w_bar)

    def forward(self, *args):
        self._update_u_v()
        return self.module.forward(*args)


####################################################################################################
# adversarial loss for different gan mode
####################################################################################################


class GANLoss(nn.Module):
    """Define different GAN objectives.
    The GANLoss class abstracts away the need to create the target label tensor
    that has the same size as the input.
    """

    def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
        """ Initialize the GANLoss class.
        Parameters:
            gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.
            target_real_label (bool) - - label for a real image
            target_fake_label (bool) - - label of a fake image
        Note: Do not use sigmoid as the last layer of Discriminator.
        LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
        """
        super(GANLoss, self).__init__()
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        self.gan_mode = gan_mode
        if gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        elif gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode == 'hinge':
            self.loss = nn.ReLU()
        elif gan_mode == 'wgangp':
            self.loss = None
        else:
            raise NotImplementedError('gan mode %s not implemented' % gan_mode)

    def __call__(self, prediction, target_is_real, is_disc=False):
        """ Calculate loss given Discriminator's output and grount truth labels.
        Parameters:
            prediction (tensor) - - tpyically the prediction output from a discriminator
            target_is_real (bool) - - if the ground truth label is for real images or fake images
        Returns:
            the calculated loss.
        """
        if self.gan_mode in ['lsgan', 'vanilla']:
            labels = (self.real_label if target_is_real else self.fake_label).expand_as(prediction).type_as(prediction)
            loss = self.loss(prediction, labels)
        elif self.gan_mode in ['hinge', 'wgangp']:
            if is_disc:
                if target_is_real:
                    prediction = -prediction
                if self.gan_mode == 'hinge':
                    loss = self.loss(1 + prediction).mean()
                elif self.gan_mode == 'wgangp':
                    loss = prediction.mean()
            else:
                loss = -prediction.mean()
        return loss


def cal_gradient_penalty(netD, real_data, fake_data, type='mixed', constant=1.0, lambda_gp=10.0):
    """Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028
    Arguments:
        netD (network)              -- discriminator network
        real_data (tensor array)    -- real images
        fake_data (tensor array)    -- generated images from the generator
        type (str)                  -- if we mix real and fake data or not [real | fake | mixed].
        constant (float)            -- the constant used in formula ( | |gradient||_2 - constant)^2
        lambda_gp (float)           -- weight for this loss
    Returns the gradient penalty loss
    """
    if lambda_gp > 0.0:
        if type == 'real':   # either use real images, fake images, or a linear interpolation of two.
            interpolatesv = real_data
        elif type == 'fake':
            interpolatesv = fake_data
        elif type == 'mixed':
            alpha = torch.rand(real_data.shape[0], 1)
            alpha = alpha.expand(real_data.shape[0], real_data.nelement() // real_data.shape[0]).contiguous().view(*real_data.shape)
            alpha = alpha.type_as(real_data)
            interpolatesv = alpha * real_data + ((1 - alpha) * fake_data)
        else:
            raise NotImplementedError('{} not implemented'.format(type))
        interpolatesv.requires_grad_(True)
        disc_interpolates = netD(interpolatesv)
        gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolatesv,
                                        grad_outputs=torch.ones(disc_interpolates.size()).type_as(real_data),
                                        create_graph=True, retain_graph=True, only_inputs=True)
        gradients = gradients[0].view(real_data.size(0), -1)  # flat the data
        gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp        # added eps
        return gradient_penalty, gradients
    else:
        return 0.0, None


####################################################################################################
# neural style transform loss from neural_style_tutorial of pytorch
####################################################################################################


def ContentLoss(input, target):
    target = target.detach()
    loss = F.l1_loss(input, target)
    return loss


def GramMatrix(input):
    s = input.size()
    features = input.view(s[0], s[1], s[2]*s[3])
    features_t = torch.transpose(features, 1, 2)
    G = torch.bmm(features, features_t).div(s[1]*s[2]*s[3])
    return G


def StyleLoss(input, target):
    target = GramMatrix(target).detach()
    input = GramMatrix(input)
    loss = F.l1_loss(input, target)
    return loss


def img_crop(input, size=224):
    input_cropped = F.upsample(input, size=(size, size), mode='bilinear', align_corners=True)
    return input_cropped


class Normalization(nn.Module):
    def __init__(self, mean, std):
        super(Normalization, self).__init__()
        self.mean = mean.view(-1, 1, 1)
        self.std = std.view(-1, 1, 1)

    def forward(self, input):
        return (input-self.mean) / self.std


class get_features(nn.Module):
    def __init__(self, cnn):
        super(get_features, self).__init__()

        vgg = copy.deepcopy(cnn)

        self.conv1 = nn.Sequential(vgg[0], vgg[1], vgg[2], vgg[3], vgg[4])
        self.conv2 = nn.Sequential(vgg[5], vgg[6], vgg[7], vgg[8], vgg[9])
        self.conv3 = nn.Sequential(vgg[10], vgg[11], vgg[12], vgg[13], vgg[14], vgg[15], vgg[16])
        self.conv4 = nn.Sequential(vgg[17], vgg[18], vgg[19], vgg[20], vgg[21], vgg[22], vgg[23])
        self.conv5 = nn.Sequential(vgg[24], vgg[25], vgg[26], vgg[27], vgg[28], vgg[29], vgg[30])

    def forward(self, input, layers):
        input = img_crop(input)
        output = []
        for i in range(1, layers):
            layer = getattr(self, 'conv'+str(i))
            input = layer(input)
            output.append(input)
        return output

In [9]:
##############################################################################################################
# Network function
##############################################################################################################
def define_e(input_nc=3, ngf=64, z_nc=512, img_f=512, L=6, layers=5, norm='none', activation='ReLU', use_spect=True,
             use_coord=False, init_type='orthogonal', gpu_ids=[]):

    net = ResEncoder(input_nc, ngf, z_nc, img_f, L, layers, norm, activation, use_spect, use_coord)

    return init_net(net, init_type, activation, gpu_ids)

def define_textual_e(input_nc=3, ngf=64, z_nc=512, img_f=512, L=6, layers=5, norm='none', activation='ReLU', use_spect=True,
             use_coord=False, init_type='orthogonal', gpu_ids=[], image_dim=256, text_dim=256, multi_peak=True, pool_attention='max'):
    net = TextualResEncoder(input_nc, ngf, z_nc, img_f, L, layers, norm, activation, use_spect, use_coord, image_dim, text_dim, multi_peak, pool_attention)
    return init_net(net, init_type, activation, gpu_ids)

def define_att_textual_e(input_nc=3, ngf=64, z_nc=512, img_f=512, L=6, layers=5, norm='none', activation='ReLU', use_spect=True,
             use_coord=False, init_type='orthogonal', gpu_ids=[], image_dim=256, text_dim=256, multi_peak=True, pool_attention='max'):
    net = AttTextualResEncoder(input_nc, ngf, z_nc, img_f, L, layers, norm, activation, use_spect, use_coord, image_dim, text_dim, multi_peak, pool_attention)
    return init_net(net, init_type, activation, gpu_ids)

def define_contract_e(input_nc=3, ngf=64, z_nc=512, img_f=512, L=6, layers=5, norm='none', activation='ReLU', use_spect=True,
             use_coord=False, init_type='orthogonal', gpu_ids=[], image_dim=256, text_dim=256, multi_peak=True, pool_attention='max'):
    net = ContrastResEncoder(input_nc, ngf, z_nc, img_f, L, layers, norm, activation, use_spect, use_coord, image_dim, text_dim, multi_peak, pool_attention)
    return init_net(net, init_type, activation, gpu_ids)

def define_pos_textual_e(input_nc=3, ngf=64, z_nc=512, img_f=512, L=6, layers=5, norm='none', activation='ReLU', use_spect=True,
             use_coord=False, init_type='orthogonal', gpu_ids=[], image_dim=256, text_dim=256, multi_peak=True, pool_attention='max'):
    net = PosAttTextualResEncoder(input_nc, ngf, z_nc, img_f, L, layers, norm, activation, use_spect, use_coord, image_dim, text_dim, multi_peak, pool_attention)
    return init_net(net, init_type, activation, gpu_ids)

def define_word_attn_e(input_nc=3, ngf=64, z_nc=512, img_f=512, L=6, layers=5, norm='none', activation='ReLU', use_spect=True,
             use_coord=False, init_type='orthogonal', gpu_ids=[], image_dim=256, text_dim=256, multi_peak=True, pool_attention='max'):
    net = WordAttnEncoder(input_nc, ngf, z_nc, img_f, L, layers, norm, activation, use_spect, use_coord, image_dim, text_dim, multi_peak, pool_attention)
    return init_net(net, init_type, activation, gpu_ids)

def define_constraint_e(input_nc=3, ngf=64, z_nc=512, img_f=512, L=6, layers=5, norm='none', activation='ReLU', use_spect=True,
             use_coord=False, init_type='orthogonal', gpu_ids=[], image_dim=256, text_dim=256, multi_peak=True, pool_attention='max'):
    net = ConstraintResEncoder(input_nc, ngf, z_nc, img_f, L, layers, norm, activation, use_spect, use_coord, image_dim, text_dim, multi_peak, pool_attention)
    return init_net(net, init_type, activation, gpu_ids)

def define_g(output_nc=3, ngf=64, z_nc=512, img_f=512, L=1, layers=5, norm='instance', activation='ReLU', output_scale=1,
             use_spect=True, use_coord=False, use_attn=True, init_type='orthogonal', gpu_ids=[]):

    net = ResGenerator(output_nc, ngf, z_nc, img_f, L, layers, norm, activation, output_scale, use_spect, use_coord, use_attn)

    return init_net(net, init_type, activation, gpu_ids)

def define_constrast_g(output_nc=3, ngf=64, z_nc=512, img_f=512, L=1, layers=5, norm='instance', activation='ReLU', output_scale=1,
             use_spect=True, use_coord=False, use_attn=True, init_type='orthogonal', gpu_ids=[]):

    net = ContrastResGenerator(output_nc, ngf, z_nc, img_f, L, layers, norm, activation, output_scale, use_spect, use_coord, use_attn)

    return init_net(net, init_type, activation, gpu_ids)


def define_textual_g(output_nc=3, f_text_dim=384, ngf=64, z_nc=512, img_f=512, L=1, layers=5, norm='instance', activation='ReLU', output_scale=1,
             use_spect=True, use_coord=False, use_attn=True, init_type='orthogonal', gpu_ids=[]):

    net = TextualResGenerator(output_nc, f_text_dim, ngf, z_nc, img_f, L, layers, norm, activation, output_scale, use_spect, use_coord, use_attn)

    return init_net(net, init_type, activation, gpu_ids)

def define_hidden_textual_g(output_nc=3, f_text_dim=384, ngf=64, z_nc=512, img_f=512, L=1, layers=5, norm='instance', activation='ReLU', output_scale=1,
             use_spect=True, use_coord=False, use_attn=True, init_type='orthogonal', gpu_ids=[]):

    net = HiddenResGenerator(output_nc, f_text_dim, ngf, z_nc, img_f, L, layers, norm, activation, output_scale, use_spect, use_coord, use_attn)

    return init_net(net, init_type, activation, gpu_ids)

def define_d(input_nc=3, ndf=64, img_f=512, layers=6, norm='none', activation='LeakyReLU', use_spect=True, use_coord=False,
             use_attn=True,  model_type='ResDis', init_type='orthogonal', gpu_ids=[]):

    if model_type == 'ResDis':
        net = ResDiscriminator(input_nc, ndf, img_f, layers, norm, activation, use_spect, use_coord, use_attn)
    elif model_type == 'PatchDis':
        net = SNPatchDiscriminator(input_nc, ndf, img_f, layers, norm, activation, use_spect, use_coord, use_attn)

    return init_net(net, init_type, activation, gpu_ids)

def define_textual_attention(image_dim, text_dim, multi_peak=True, init_type='orthogonal',  gpu_ids=[]):
    net = ImageTextAttention(idf=image_dim, cdf=text_dim, multi_peak=multi_peak)
    return init_net(net, init_type, gpu_ids=gpu_ids)


#############################################################################################################
# Network structure
#############################################################################################################
class ResEncoder(nn.Module):
    """
    ResNet Encoder Network
    :param input_nc: number of channels in input
    :param ngf: base filter channel
    :param z_nc: latent channels
    :param img_f: the largest feature channels
    :param L: Number of refinements of density
    :param layers: down and up sample layers
    :param norm: normalization function 'instance, batch, group'
    :param activation: activation function 'ReLU, SELU, LeakyReLU, PReLU'
    """
    def __init__(self, input_nc=3, ngf=64, z_nc=128, img_f=1024, L=6, layers=6, norm='none', activation='ReLU',
                 use_spect=True, use_coord=False):
        super(ResEncoder, self).__init__()

        self.layers = layers
        self.z_nc = z_nc
        self.L = L

        norm_layer = get_norm_layer(norm_type=norm)
        nonlinearity = get_nonlinearity_layer(activation_type=activation)
        # encoder part
        self.block0 = ResBlockEncoderOptimized(input_nc, ngf, norm_layer, nonlinearity, use_spect, use_coord)

        mult = 1
        for i in range(layers-1):
            mult_prev = mult
            mult = min(2 ** (i + 1), img_f // ngf)
            block = ResBlock(ngf * mult_prev, ngf * mult, ngf * mult_prev, norm_layer, nonlinearity, 'down', use_spect, use_coord)
            setattr(self, 'encoder' + str(i), block)

        # inference part
        for i in range(self.L):
            block = ResBlock(ngf * mult, ngf * mult, ngf *mult, norm_layer, nonlinearity, 'none', use_spect, use_coord)
            setattr(self, 'infer_prior' + str(i), block)

        self.posterior = ResBlock(ngf * mult, 2*z_nc, ngf * mult, norm_layer, nonlinearity, 'none', use_spect, use_coord)
        self.prior = ResBlock(ngf * mult, 2*z_nc, ngf * mult, norm_layer, nonlinearity, 'none', use_spect, use_coord)

    def forward(self, img_m, img_c=None):
        """
        :param img_m: image with mask regions I_m
        :param img_c: complement of I_m, the mask regions
        :return distribution: distribution of mask regions, for training we have two paths, testing one path
        :return feature: the conditional feature f_m, and the previous f_pre for auto context attention
        """

        if type(img_c) != type(None):
            img = torch.cat([img_m, img_c], dim=0)
        else:
            img = img_m

        # encoder part
        out = self.block0(img)
        feature = [out]
        for i in range(self.layers-1):
            model = getattr(self, 'encoder' + str(i))
            out = model(out)
            feature.append(out)

        # infer part
        # during the training, we have two paths, during the testing, we only have one paths
        if type(img_c) != type(None):
            distribution = self.two_paths(out)
            return distribution, feature
        else:
            distribution = self.one_path(out)
            return distribution, feature

    def one_path(self, f_in):
        """one path for baseline training or testing"""
        f_m = f_in
        distribution = []

        # infer state
        for i in range(self.L):
            infer_prior = getattr(self, 'infer_prior' + str(i))
            f_m = infer_prior(f_m)

        # get distribution
        o = self.prior(f_m)
        q_mu, q_std = torch.split(o, self.z_nc, dim=1)
        distribution.append([q_mu, F.softplus(q_std)])

        return distribution

    def two_paths(self, f_in):
        """two paths for the training"""
        f_m, f_c = f_in.chunk(2)
        distributions = []

        # get distribution
        o = self.posterior(f_c)
        p_mu, p_std = torch.split(o, self.z_nc, dim=1)
        distribution = self.one_path(f_m)
        distributions.append([p_mu, F.softplus(p_std), distribution[0][0], distribution[0][1]])

        return distributions

class TextualResEncoder(nn.Module):
    """
    ResNet Encoder Network
    :param input_nc: number of channels in input
    :param ngf: base filter channel
    :param z_nc: latent channels
    :param img_f: the largest feature channels
    :param L: Number of refinements of density
    :param layers: down and up sample layers
    :param norm: normalization function 'instance, batch, group'
    :param activation: activation function 'ReLU, SELU, LeakyReLU, PReLU'
    """
    def __init__(self, input_nc=3, ngf=32, z_nc=256, img_f=256, L=6, layers=5, norm='none', activation='ReLU',
                 use_spect=True, use_coord=False, image_dim=256, text_dim=256, multi_peak=True, pool_attention='max'):
        super(TextualResEncoder, self).__init__()

        self.layers = layers
        self.z_nc = z_nc
        self.L = L

        norm_layer = get_norm_layer(norm_type=norm)
        nonlinearity = get_nonlinearity_layer(activation_type=activation)
        # encoder part
        self.block0 = ResBlockEncoderOptimized(input_nc, ngf, norm_layer, nonlinearity, use_spect, use_coord)
        self.word_attention = ImageTextAttention(idf=image_dim, cdf=text_dim, multi_peak=multi_peak, pooling=pool_attention)

        mult = 1
        for i in range(layers-1):
            mult_prev = mult
            mult = min(2 ** (i + 2), img_f // ngf)
            block = ResBlock(ngf * mult_prev, ngf * mult, ngf * mult_prev, norm_layer, nonlinearity, 'down', use_spect, use_coord)
            setattr(self, 'encoder' + str(i), block)

        # inference part
        for i in range(self.L):
            block = ResBlock(ngf * mult, ngf * mult, ngf *mult, norm_layer, nonlinearity, 'none', use_spect, use_coord)
            setattr(self, 'infer_prior' + str(i), block)

        # For textual, only change input and hidden dimension, z_nc is set when called.
        self.posterior = ResBlock(2*text_dim, 2*z_nc, ngf * mult * 2, norm_layer, nonlinearity, 'none', use_spect, use_coord)
        self.prior =     ResBlock(2*text_dim, 2*z_nc, ngf * mult * 2, norm_layer, nonlinearity, 'none', use_spect, use_coord)

    def forward(self, img_m, sentence_embedding, word_embeddings, text_mask, image_mask, img_c=None):
        """
        :param img_m: image with mask regions I_m
        :param sentence_embedding: the sentence embedding of I
        :param word_embeddings: word embedding of I
        :param text_mask: mask of word sequence of word_embeddings
        :param image_mask: mask of Im and Ic, need to scale if apply to fm
        :param img_c: complement of I_m, the mask regions
        :return distribution: distribution of mask regions, for training we have two paths, testing one path
        :return feature: the conditional feature f_m, and the previous f_pre for auto context attention
        :return text_feature: word and sentence features
        """

        if type(img_c) != type(None):
            img = torch.cat([img_m, img_c], dim=0)
        else:
            img = img_m

        # encoder part
        out = self.block0(img)
        feature = [out]
        for i in range(self.layers-1):
            model = getattr(self, 'encoder' + str(i))
            out = model(out)
            feature.append(out)

        # infer part
        # during the training, we have two paths, during the testing, we only have one paths
        image_mask = task.scale_img(image_mask, size=[feature[-1].size(2), feature[-1].size(3)])
        if image_mask.size(1) == 3:
            image_mask = image_mask.chunk(3, dim=1)[0]

        if type(img_c) != type(None):
            # adapt to word embedding, compute weighted word embedding with fm separately
            f_m_g, f_m_rec = feature[-1].chunk(2)
            img_mask_g = image_mask
            img_mask_rec = 1 - img_mask_g
            weighted_word_embedding_rec = self.word_attention(
                        f_m_rec, word_embeddings, mask=text_mask, image_mask=img_mask_rec, inverse_attention=False)
            weighted_word_embedding_g = self.word_attention(
                        f_m_g, word_embeddings, mask=text_mask, image_mask=img_mask_g, inverse_attention=True)

            weighted_word_embedding =  torch.cat([weighted_word_embedding_g, weighted_word_embedding_rec])
            distribution, f_text = self.two_paths(out, sentence_embedding, weighted_word_embedding)

            return distribution, feature, f_text
        else:
            # adapt to word embedding, compute weighted word embedding with fm of one path
            f_m = feature[-1]
            weighted_word_embedding = self.word_attention(
                f_m, word_embeddings, mask=text_mask, image_mask=image_mask, inverse_attention=True)

            distribution, f_text = self.one_path(out, sentence_embedding, weighted_word_embedding)

            return distribution, feature, f_text

    def one_path(self, f_in, sentence_embedding, weighted_word_embedding):
        """one path for baseline training or testing"""
        # TOTEST: adapt to word embedding, compute distribution with word embedding.
        f_m = f_in
        distribution = []

        # infer state
        for i in range(self.L):
            infer_prior = getattr(self, 'infer_prior' + str(i))
            f_m = infer_prior(f_m)

        # get distribution
        # use sentence embedding here
        ix, iw = f_m.size(2), f_m.size(3)
        sentence_dim = sentence_embedding.size(1)
        sentence_embedding_replication = sentence_embedding.view(-1, sentence_dim, 1, 1).repeat(1, 1, ix, iw)
        f_text = torch.cat([sentence_embedding_replication, weighted_word_embedding], dim=1)

        o = self.prior(f_text)
        q_mu, q_std = torch.split(o, self.z_nc, dim=1)
        distribution.append([q_mu, F.softplus(q_std)])

        return distribution, f_text

    def two_paths(self, f_in, sentence_embedding, weighted_word_embedding):
        """two paths for the training"""
        f_m, f_c = f_in.chunk(2)
        weighted_word_embedding_m, weighted_word_embedding_c = weighted_word_embedding.chunk(2)
        distributions = []

        # get distribution
        # use text embedding here
        ix, iw = f_c.size(2), f_c.size(3)
        sentence_dim = sentence_embedding.size(1)
        sentence_embedding_replication = sentence_embedding.view(-1, sentence_dim, 1, 1).repeat(1, 1, ix, iw)

        f_text_c = torch.cat([sentence_embedding_replication, weighted_word_embedding_c], dim=1)
        o = self.posterior(f_text_c)
        p_mu, p_std = torch.split(o, self.z_nc, dim=1)

        distribution, f_text_m = self.one_path(f_m, sentence_embedding, weighted_word_embedding_m)
        distributions.append([p_mu, F.softplus(p_std), distribution[0][0], distribution[0][1]])

        return distributions, torch.cat([f_text_m, f_text_c], dim=0)

class AttTextualResEncoder(nn.Module):
    """
    Attentive Encoder Network
    :param input_nc: number of channels in input
    :param ngf: base filter channel
    :param z_nc: latent channels
    :param img_f: the largest feature channels
    :param L: Number of refinements of density
    :param layers: down and up sample layers
    :param norm: normalization function 'instance, batch, group'
    :param activation: activation function 'ReLU, SELU, LeakyReLU, PReLU'
    :param image_dim: num of image feature maps
    :param text_dim: num of text embedding dimension
    :param multi_peak: use sigmoid in text attention if set to True
    """
    def __init__(self, input_nc=3, ngf=32, z_nc=256, img_f=256, L=6, layers=5, norm='none', activation='ReLU',
                 use_spect=True, use_coord=False, image_dim=256, text_dim=256, multi_peak=True, pool_attention='max'):
        super(AttTextualResEncoder, self).__init__()

        self.layers = layers
        self.z_nc = z_nc
        self.L = L

        norm_layer = get_norm_layer(norm_type=norm)
        nonlinearity = get_nonlinearity_layer(activation_type=activation)
        # encoder part
        self.block0 = ResBlockEncoderOptimized(input_nc, ngf, norm_layer, nonlinearity, use_spect, use_coord)
        self.word_attention = ImageTextAttention(idf=image_dim, cdf=text_dim, multi_peak=multi_peak, pooling=pool_attention)

        mult = 1
        for i in range(layers-1):
            mult_prev = mult
            mult = min(2 ** (i + 2), img_f // ngf)
            block = ResBlock(ngf * mult_prev, ngf * mult, ngf * mult_prev, norm_layer, nonlinearity, 'down', use_spect, use_coord)
            setattr(self, 'encoder' + str(i), block)

        # inference part
        for i in range(self.L):
            block = ResBlock(ngf * mult, ngf * mult, ngf *mult, norm_layer, nonlinearity, 'none', use_spect, use_coord)
            setattr(self, 'infer_prior' + str(i), block)

        # For textual, only change input and hidden dimension, z_nc is set when called.
        self.posterior = ResBlock(ngf * mult + 2*text_dim, 2*z_nc, ngf * mult * 2, norm_layer, nonlinearity, 'none', use_spect, use_coord)
        self.prior =     ResBlock(ngf * mult + 2*text_dim, 2*z_nc, ngf * mult * 2, norm_layer, nonlinearity, 'none', use_spect, use_coord)

    def forward(self, img_m, sentence_embedding, word_embeddings, text_mask, image_mask, img_c=None):
        """
        :param img_m: image with mask regions I_m
        :param sentence_embedding: the sentence embedding of I
        :param word_embeddings: word embedding of I
        :param text_mask: mask of word sequence of word_embeddings
        :param image_mask: mask of Im and Ic, need to scale if apply to fm
        :param img_c: complement of I_m, the mask regions
        :return distribution: distribution of mask regions, for training we have two paths, testing one path
        :return feature: the conditional feature f_m, and the previous f_pre for auto context attention
        :return text_feature: word and sentence features
        """

        if type(img_c) != type(None):
            img = torch.cat([img_m, img_c], dim=0)
        else:
            img = img_m

        # encoder part
        out = self.block0(img)
        feature = [out]
        for i in range(self.layers-1):
            model = getattr(self, 'encoder' + str(i))
            out = model(out)
            feature.append(out)

        # infer part
        # during the training, we have two paths, during the testing, we only have one paths
        image_mask = task.scale_img(image_mask, size=[feature[-1].size(2), feature[-1].size(3)])
        if image_mask.size(1) == 3:
            image_mask = image_mask.chunk(3, dim=1)[0]

        if type(img_c) != type(None):
            # adapt to word embedding, compute weighted word embedding with fm separately
            f_m_g, f_m_rec = feature[-1].chunk(2)
            img_mask_g = image_mask
            img_mask_rec = 1 - img_mask_g
            weighted_word_embedding_rec = self.word_attention(
                        f_m_rec, word_embeddings, mask=text_mask, image_mask=img_mask_rec, inverse_attention=False)
            weighted_word_embedding_g = self.word_attention(
                        f_m_g, word_embeddings, mask=text_mask, image_mask=img_mask_g, inverse_attention=True)

            weighted_word_embedding =  torch.cat([weighted_word_embedding_g, weighted_word_embedding_rec])
            distribution, f_text = self.two_paths(out, sentence_embedding, weighted_word_embedding)

            return distribution, feature, f_text
        else:
            # adapt to word embedding, compute weighted word embedding with fm of one path
            f_m = feature[-1]
            weighted_word_embedding = self.word_attention(
                f_m, word_embeddings, mask=text_mask, image_mask=image_mask, inverse_attention=True)

            distribution, f_m_text = self.one_path(out, sentence_embedding, weighted_word_embedding)
            f_text = torch.cat([f_m_text, weighted_word_embedding], dim=1)
            return distribution, feature, f_text

    def one_path(self, f_in, sentence_embedding, weighted_word_embedding):
        """one path for baseline training or testing"""
        f_m = f_in
        distribution = []

        # infer state
        for i in range(self.L):
            infer_prior = getattr(self, 'infer_prior' + str(i))
            f_m = infer_prior(f_m)

        # get distribution
        # use sentence embedding here
        ix, iw = f_m.size(2), f_m.size(3)
        sentence_dim = sentence_embedding.size(1)
        sentence_embedding_replication = sentence_embedding.view(-1, sentence_dim, 1, 1).repeat(1, 1, ix, iw)
        f_m_sent = torch.cat([f_m, sentence_embedding_replication], dim=1)
        f_m_text = torch.cat([f_m_sent, weighted_word_embedding], dim=1)

        o = self.prior(f_m_text)
        q_mu, q_std = torch.split(o, self.z_nc, dim=1)
        distribution.append([q_mu, F.softplus(q_std)])

        return distribution, f_m_sent

    def two_paths(self, f_in, sentence_embedding, weighted_word_embedding):
        """two paths for the training"""
        f_m, f_c = f_in.chunk(2)
        weighted_word_embedding_m, weighted_word_embedding_c = weighted_word_embedding.chunk(2)
        distributions = []

        # get distribution
        # use text embedding here
        ix, iw = f_c.size(2), f_c.size(3)
        sentence_dim = sentence_embedding.size(1)
        sentence_embedding_replication = sentence_embedding.view(-1, sentence_dim, 1, 1).repeat(1, 1, ix, iw)
        f_c_sent = torch.cat([f_c, sentence_embedding_replication], dim=1)
        f_c_text = torch.cat([f_c_sent, weighted_word_embedding_c], dim=1)
        o = self.posterior(f_c_text)
        p_mu, p_std = torch.split(o, self.z_nc, dim=1)

        distribution, f_m_sent = self.one_path(f_m, sentence_embedding, weighted_word_embedding_m)
        distributions.append([p_mu, F.softplus(p_std), distribution[0][0], distribution[0][1]])

        f_m_text = torch.cat([f_m_sent, weighted_word_embedding_m], dim=1)
        # TODO: rm weighted_word_emb_c for consis generation
        f_c_text = torch.cat([f_m_sent, weighted_word_embedding_c], dim=1)
        return distributions, torch.cat([f_m_text, f_c_text], dim=0)

class ContrastResEncoder(nn.Module):
    """
    Contrastive Encoder Network
    :param input_nc: number of channels in input
    :param ngf: base filter channel
    :param z_nc: latent channels
    :param img_f: the largest feature channels
    :param L: Number of refinements of density
    :param layers: down and up sample layers
    :param norm: normalization function 'instance, batch, group'
    :param activation: activation function 'ReLU, SELU, LeakyReLU, PReLU'
    :param image_dim: num of image feature maps
    :param text_dim: num of text embedding dimension
    :param multi_peak: use sigmoid in text attention if set to True
    """
    def __init__(self, input_nc=3, ngf=32, z_nc=256, img_f=256, L=6, layers=5, norm='none', activation='ReLU',
                 use_spect=True, use_coord=False, image_dim=256, text_dim=256, multi_peak=True, pool_attention='max'):
        super(ContrastResEncoder, self).__init__()

        self.layers = layers
        self.z_nc = z_nc
        self.L = L

        norm_layer = get_norm_layer(norm_type=norm)
        nonlinearity = get_nonlinearity_layer(activation_type=activation)
        # encoder part
        self.block0 = ResBlockEncoderOptimized(input_nc, ngf, norm_layer, nonlinearity, use_spect, use_coord)
        self.word_attention = ImageTextAttention(idf=image_dim, cdf=text_dim, multi_peak=multi_peak, pooling=pool_attention)

        mult = 1
        for i in range(layers-1):
            mult_prev = mult
            mult = min(2 ** (i + 2), img_f // ngf)
            block = ResBlock(ngf * mult_prev, ngf * mult, ngf * mult_prev, norm_layer, nonlinearity, 'down', use_spect, use_coord)
            setattr(self, 'encoder' + str(i), block)

        # inference part
        for i in range(self.L):
            block = ResBlock(ngf * mult, ngf * mult, ngf *mult, norm_layer, nonlinearity, 'none', use_spect, use_coord)
            setattr(self, 'infer_prior' + str(i), block)

        # For textual, only change input and hidden dimension, z_nc is set when called.
        self.posterior = ResBlock(ngf * mult, 2*z_nc, ngf * mult, norm_layer, nonlinearity, 'none', use_spect, use_coord)
        self.prior     = ResBlock(text_dim + ngf * mult, 2*z_nc, 2*ngf * mult, norm_layer, nonlinearity, 'none', use_spect, use_coord)

    def forward(self, img_m, sentence_embedding, word_embeddings, text_mask, image_mask, img_c=None):
        """
        :param img_m: image with mask regions I_m
        :param sentence_embedding: the sentence embedding of I
        :param word_embeddings: word embedding of I
        :param text_mask: mask of word sequence of word_embeddings
        :param image_mask: mask of Im and Ic, need to scale if apply to fm
        :param img_c: complement of I_m, the mask regions
        :return distribution: distribution of mask regions, for training we have two paths, testing one path
        :return feature: the conditional feature f_m, and the previous f_pre for auto context attention
        :return text_feature: word and sentence features
        """

        if type(img_c) != type(None):
            img = torch.cat([img_m, img_c], dim=0)
        else:
            img = img_m

        # encoder part
        out = self.block0(img)
        feature = [out]
        for i in range(self.layers-1):
            model = getattr(self, 'encoder' + str(i))
            out = model(out)
            feature.append(out)

        # infer part
        # during the training, we have two paths, during the testing, we only have one paths
        image_mask = task.scale_img(image_mask, size=[feature[-1].size(2), feature[-1].size(3)])
        if image_mask.size(1) == 3:
            image_mask = image_mask.chunk(3, dim=1)[0]

        if type(img_c) != type(None):
            # adapt to word embedding, compute weighted word embedding with fm separately
            f_m_g, f_m_rec = feature[-1].chunk(2)
            img_mask_g = image_mask
            img_mask_rec = 1 - img_mask_g
            weighted_word_embedding_rec = self.word_attention(
                        f_m_rec, word_embeddings, mask=text_mask, image_mask=img_mask_rec, inverse_attention=False)
            weighted_word_embedding_g = self.word_attention(
                        f_m_g, word_embeddings, mask=text_mask, image_mask=img_mask_g, inverse_attention=True)

            weighted_word_embedding =  torch.cat([weighted_word_embedding_g, weighted_word_embedding_rec])
            distribution, h_word = self.two_paths(out, weighted_word_embedding)

            return distribution, feature, h_word
        else:
            # adapt to word embedding, compute weighted word embedding with fm of one path
            f_m = feature[-1]
            weighted_word_embedding = self.word_attention(
                f_m, word_embeddings, mask=text_mask, image_mask=image_mask, inverse_attention=True)

            distribution, h_word = self.one_path(weighted_word_embedding, f_m)
            return distribution, feature, h_word

    def one_path(self, weighted_word_embedding, v_h):
        """one path for baseline training or testing"""
        h_word = weighted_word_embedding
        distribution = []

        # infer state
        for i in range(self.L):
            infer_prior = getattr(self, 'infer_prior' + str(i))
            h_word = infer_prior(h_word)

        # get distribution
        o = self.prior(torch.cat([h_word,v_h], dim=1))
        q_mu, q_std = torch.split(o, self.z_nc, dim=1)
        distribution.append([q_mu, F.softplus(q_std)])

        return distribution, h_word

    def two_paths(self, f_in, weighted_word_embedding):
        """two paths for the training"""
        # use text embedding here
        f_m, f_c = f_in.chunk(2)
        weighted_word_embedding_m, weighted_word_embedding_c = weighted_word_embedding.chunk(2)

        h_word_c = weighted_word_embedding_c
        # infer state
        for i in range(self.L):
            infer_prior = getattr(self, 'infer_prior' + str(i))
            h_word_c = infer_prior(h_word_c)

        # get distribution
        distributions = []
        o = self.posterior(f_c)
        p_mu, p_std = torch.split(o, self.z_nc, dim=1)
        distribution, h_word = self.one_path(weighted_word_embedding_m, f_m)
        distributions.append([p_mu, F.softplus(p_std), distribution[0][0], distribution[0][1]])

        return distributions, torch.cat([h_word, h_word_c], dim=0)


class PosAttTextualResEncoder(nn.Module):
    """
    Positive Attentive Encoder Network
    :param input_nc: number of channels in input
    :param ngf: base filter channel
    :param z_nc: latent channels
    :param img_f: the largest feature channels
    :param L: Number of refinements of density
    :param layers: down and up sample layers
    :param norm: normalization function 'instance, batch, group'
    :param activation: activation function 'ReLU, SELU, LeakyReLU, PReLU'
    :param image_dim: num of image feature maps
    :param text_dim: num of text embedding dimension
    :param multi_peak: use sigmoid in text attention if set to True
    """
    def __init__(self, input_nc=3, ngf=32, z_nc=256, img_f=256, L=6, layers=5, norm='none', activation='ReLU',
                 use_spect=True, use_coord=False, image_dim=256, text_dim=256, multi_peak=True, pool_attention='max'):
        super(PosAttTextualResEncoder, self).__init__()

        self.layers = layers
        self.z_nc = z_nc
        self.L = L

        norm_layer = get_norm_layer(norm_type=norm)
        nonlinearity = get_nonlinearity_layer(activation_type=activation)
        # encoder part
        self.block0 = ResBlockEncoderOptimized(input_nc, ngf, norm_layer, nonlinearity, use_spect, use_coord)
        self.word_attention = ImageTextAttention(idf=image_dim, cdf=text_dim, multi_peak=multi_peak, pooling=pool_attention)

        mult = 1
        for i in range(layers-1):
            mult_prev = mult
            mult = min(2 ** (i + 2), img_f // ngf)
            block = ResBlock(ngf * mult_prev, ngf * mult, ngf * mult_prev, norm_layer, nonlinearity, 'down', use_spect, use_coord)
            setattr(self, 'encoder' + str(i), block)

        # inference part
        for i in range(self.L):
            block = ResBlock(ngf * mult, ngf * mult, ngf *mult, norm_layer, nonlinearity, 'none', use_spect, use_coord)
            setattr(self, 'infer_prior' + str(i), block)

        # For textual, only change input and hidden dimension, z_nc is set when called.
        self.posterior = ResBlock(ngf * mult + 2*text_dim, 2*z_nc, ngf * mult * 2, norm_layer, nonlinearity, 'none', use_spect, use_coord)
        self.prior =     ResBlock(ngf * mult + 2*text_dim, 2*z_nc, ngf * mult * 2, norm_layer, nonlinearity, 'none', use_spect, use_coord)

    def forward(self, img_m, sentence_embedding, word_embeddings, text_mask, image_mask, img_c=None):
        """
        :param img_m: image with mask regions I_m
        :param sentence_embedding: the sentence embedding of I
        :param word_embeddings: word embedding of I
        :param text_mask: mask of word sequence of word_embeddings
        :param image_mask: mask of Im and Ic, need to scale if apply to fm
        :param img_c: complement of I_m, the mask regions
        :return distribution: distribution of mask regions, for training we have two paths, testing one path
        :return feature: the conditional feature f_m, and the previous f_pre for auto context attention
        :return text_feature: word and sentence features
        """

        if type(img_c) != type(None):
            img = torch.cat([img_m, img_c], dim=0)
        else:
            img = img_m

        # encoder part
        out = self.block0(img)
        feature = [out]
        for i in range(self.layers-1):
            model = getattr(self, 'encoder' + str(i))
            out = model(out)
            feature.append(out)

        # infer part
        # during the training, we have two paths, during the testing, we only have one paths
        image_mask = task.scale_img(image_mask, size=[feature[-1].size(2), feature[-1].size(3)])
        if image_mask.size(1) == 3:
            image_mask = image_mask.chunk(3, dim=1)[0]

        if type(img_c) != type(None):
            # adapt to word embedding, compute weighted word embedding with fm separately
            f_m_g, f_m_rec = feature[-1].chunk(2)
            img_mask_g = image_mask
            img_mask_rec = 1 - img_mask_g
            weighted_word_embedding_rec = self.word_attention(
                        f_m_rec, word_embeddings, mask=text_mask, image_mask=img_mask_rec, inverse_attention=True)
            weighted_word_embedding_g = self.word_attention(
                        f_m_g, word_embeddings, mask=text_mask, image_mask=img_mask_g, inverse_attention=False)

            weighted_word_embedding =  torch.cat([weighted_word_embedding_g, weighted_word_embedding_rec])
            distribution, f_text = self.two_paths(out, sentence_embedding, weighted_word_embedding)

            return distribution, feature, f_text
        else:
            # adapt to word embedding, compute weighted word embedding with fm of one path
            f_m = feature[-1]
            weighted_word_embedding = self.word_attention(
                f_m, word_embeddings, mask=text_mask, image_mask=image_mask, inverse_attention=False)

            distribution, f_m_text = self.one_path(out, sentence_embedding, weighted_word_embedding)
            f_text = torch.cat([f_m_text, weighted_word_embedding], dim=1)
            return distribution, feature, f_text

    def one_path(self, f_in, sentence_embedding, weighted_word_embedding):
        """one path for baseline training or testing"""
        # TOTEST: adapt to word embedding, compute distribution with word embedding.
        f_m = f_in
        distribution = []

        # infer state
        for i in range(self.L):
            infer_prior = getattr(self, 'infer_prior' + str(i))
            f_m = infer_prior(f_m)

        # get distribution
        # use sentence embedding here
        ix, iw = f_m.size(2), f_m.size(3)
        sentence_dim = sentence_embedding.size(1)
        sentence_embedding_replication = sentence_embedding.view(-1, sentence_dim, 1, 1).repeat(1, 1, ix, iw)
        f_m_sent = torch.cat([f_m, sentence_embedding_replication], dim=1)
        f_m_text = torch.cat([f_m_sent, weighted_word_embedding], dim=1)

        o = self.prior(f_m_text)
        q_mu, q_std = torch.split(o, self.z_nc, dim=1)
        distribution.append([q_mu, F.softplus(q_std)])

        return distribution, f_m_sent

    def two_paths(self, f_in, sentence_embedding, weighted_word_embedding):
        """two paths for the training"""
        f_m, f_c = f_in.chunk(2)
        weighted_word_embedding_m, weighted_word_embedding_c = weighted_word_embedding.chunk(2)
        distributions = []

        # get distribution
        # use text embedding here
        ix, iw = f_c.size(2), f_c.size(3)
        sentence_dim = sentence_embedding.size(1)
        sentence_embedding_replication = sentence_embedding.view(-1, sentence_dim, 1, 1).repeat(1, 1, ix, iw)
        f_c_sent = torch.cat([f_c, sentence_embedding_replication], dim=1)
        f_c_text = torch.cat([f_c_sent, weighted_word_embedding_c], dim=1)
        o = self.posterior(f_c_text)
        p_mu, p_std = torch.split(o, self.z_nc, dim=1)

        distribution, f_m_sent = self.one_path(f_m, sentence_embedding, weighted_word_embedding_m)
        distributions.append([p_mu, F.softplus(p_std), distribution[0][0], distribution[0][1]])

        f_m_text = torch.cat([f_m_sent, weighted_word_embedding_m], dim=1)
        # TODO: rm weighted_word_emb_c for consis generation
        f_c_text = torch.cat([f_m_sent, weighted_word_embedding_c], dim=1)
        return distributions, torch.cat([f_m_text, f_c_text], dim=0)


class WordAttnEncoder(nn.Module):
    """
    WordAttn Encoder Network
    :param input_nc: number of channels in input
    :param ngf: base filter channel
    :param z_nc: latent channels
    :param img_f: the largest feature channels
    :param L: Number of refinements of density
    :param layers: down and up sample layers
    :param norm: normalization function 'instance, batch, group'
    :param activation: activation function 'ReLU, SELU, LeakyReLU, PReLU'
    :param image_dim: num of image feature maps
    :param text_dim: num of text embedding dimension
    :param multi_peak: use sigmoid in text attention if set to True
    """
    def __init__(self, input_nc=3, ngf=32, z_nc=256, img_f=256, L=6, layers=5, norm='none', activation='ReLU',
                 use_spect=True, use_coord=False, image_dim=256, text_dim=256, multi_peak=True, pool_attention='max'):
        super(WordAttnEncoder, self).__init__()

        self.layers = layers
        self.z_nc = z_nc
        self.L = L

        norm_layer = get_norm_layer(norm_type=norm)
        nonlinearity = get_nonlinearity_layer(activation_type=activation)
        # encoder part
        self.block0 = ResBlockEncoderOptimized(input_nc, ngf, norm_layer, nonlinearity, use_spect, use_coord)
        self.word_attention = ImageTextAttention(idf=image_dim, cdf=text_dim, multi_peak=multi_peak, pooling=pool_attention)

        mult = 1
        for i in range(layers-1):
            mult_prev = mult
            mult = min(2 ** (i + 2), img_f // ngf)
            block = ResBlock(ngf * mult_prev, ngf * mult, ngf * mult_prev, norm_layer, nonlinearity, 'down', use_spect, use_coord)
            setattr(self, 'encoder' + str(i), block)

        # inference part
        for i in range(self.L):
            block = ResBlock(ngf * mult, ngf * mult, ngf *mult, norm_layer, nonlinearity, 'none', use_spect, use_coord)
            setattr(self, 'infer_prior' + str(i), block)

        # For textual, only change input and hidden dimension, z_nc is set when called.
        self.posterior = ResBlock(ngf * mult + text_dim, 2*z_nc, ngf * mult, norm_layer, nonlinearity, 'none', use_spect, use_coord)
        self.prior =     ResBlock(ngf * mult + text_dim, 2*z_nc, ngf * mult, norm_layer, nonlinearity, 'none', use_spect, use_coord)

    def forward(self, img_m, word_embeddings, text_mask, image_mask, img_c=None):
        """
        :param img_m: image with mask regions I_m
        :param word_embeddings: word embedding of I
        :param text_mask: mask of word sequence of word_embeddings
        :param image_mask: mask of Im and Ic, need to scale if apply to fm
        :param img_c: complement of I_m, the mask regions
        :return distribution: distribution of mask regions, for training we have two paths, testing one path
        :return feature: the conditional feature f_m, and the previous f_pre for auto context attention
        :return text_feature: word and sentence features
        """

        if type(img_c) != type(None):
            img = torch.cat([img_m, img_c], dim=0)
        else:
            img = img_m

        # encoder part
        out = self.block0(img)
        feature = [out]
        for i in range(self.layers-1):
            model = getattr(self, 'encoder' + str(i))
            out = model(out)
            feature.append(out)

        # infer part
        # during the training, we have two paths, during the testing, we only have one paths
        image_mask = task.scale_img(image_mask, size=[feature[-1].size(2), feature[-1].size(3)])
        if image_mask.size(1) == 3:
            image_mask = image_mask.chunk(3, dim=1)[0]

        if type(img_c) != type(None):
            # adapt to word embedding, compute weighted word embedding with fm separately
            f_m_g, f_m_rec = feature[-1].chunk(2)
            img_mask_g = image_mask
            img_mask_rec = 1 - img_mask_g
            weighted_word_embedding_rec = self.word_attention(
                        f_m_rec, word_embeddings, mask=text_mask, image_mask=img_mask_rec, inverse_attention=False)
            weighted_word_embedding_g = self.word_attention(
                        f_m_g, word_embeddings, mask=text_mask, image_mask=img_mask_g, inverse_attention=True)

            weighted_word_embedding =  torch.cat([weighted_word_embedding_g, weighted_word_embedding_rec])
            distribution, f_text = self.two_paths(out, weighted_word_embedding)

            return distribution, feature, f_text
        else:
            # adapt to word embedding, compute weighted word embedding with fm of one path
            f_m = feature[-1]
            weighted_word_embedding = self.word_attention(
                f_m, word_embeddings, mask=text_mask, image_mask=image_mask, inverse_attention=True)

            distribution = self.one_path(out, weighted_word_embedding)
            f_text = weighted_word_embedding
            return distribution, feature, f_text

    def one_path(self, f_in, weighted_word_embedding):
        """one path for baseline training or testing"""
        # TOTEST: adapt to word embedding, compute distribution with word embedding.
        f_m = f_in
        distribution = []

        # infer state
        for i in range(self.L):
            infer_prior = getattr(self, 'infer_prior' + str(i))
            f_m = infer_prior(f_m)

        # get distribution
        # use sentence embedding here
        f_m_text = torch.cat([f_m, weighted_word_embedding], dim=1)

        o = self.prior(f_m_text)
        q_mu, q_std = torch.split(o, self.z_nc, dim=1)
        distribution.append([q_mu, F.softplus(q_std)])

        return distribution

    def two_paths(self, f_in, weighted_word_embedding):
        """two paths for the training"""
        f_m, f_c = f_in.chunk(2)
        weighted_word_embedding_m, weighted_word_embedding_c = weighted_word_embedding.chunk(2)
        distributions = []

        # get distribution
        # use text embedding here
        f_c_text = torch.cat([f_c, weighted_word_embedding_c], dim=1)
        o = self.posterior(f_c_text)
        p_mu, p_std = torch.split(o, self.z_nc, dim=1)

        distribution = self.one_path(f_m, weighted_word_embedding_m)
        distributions.append([p_mu, F.softplus(p_std), distribution[0][0], distribution[0][1]])

        f_m_text = weighted_word_embedding_m
        f_c_text = weighted_word_embedding_c
        return distributions, torch.cat([f_m_text, f_c_text], dim=0)


class ConstraintResEncoder(nn.Module):
    """
    Constraint Encoder Network
    :param input_nc: number of channels in input
    :param ngf: base filter channel
    :param z_nc: latent channels
    :param img_f: the largest feature channels
    :param L: Number of refinements of density
    :param layers: down and up sample layers
    :param norm: normalization function 'instance, batch, group'
    :param activation: activation function 'ReLU, SELU, LeakyReLU, PReLU'
    :param image_dim: num of image feature maps
    :param text_dim: num of text embedding dimension
    :param multi_peak: use sigmoid in text attention if set to True
    """
    def __init__(self, input_nc=3, ngf=32, z_nc=256, img_f=256, L=6, layers=5, norm='none', activation='ReLU',
                 use_spect=True, use_coord=False, image_dim=256, text_dim=256, multi_peak=True, pool_attention='max'):
        super(ConstraintResEncoder, self).__init__()

        self.layers = layers
        self.z_nc = z_nc
        self.L = L

        norm_layer = get_norm_layer(norm_type=norm)
        nonlinearity = get_nonlinearity_layer(activation_type=activation)
        # encoder part
        self.block0 = ResBlockEncoderOptimized(input_nc, ngf, norm_layer, nonlinearity, use_spect, use_coord)
        self.word_attention = ImageTextAttention(idf=image_dim, cdf=text_dim, multi_peak=multi_peak, pooling=pool_attention)

        mult = 1
        for i in range(layers-1):
            mult_prev = mult
            mult = min(2 ** (i + 2), img_f // ngf)
            block = ResBlock(ngf * mult_prev, ngf * mult, ngf * mult_prev, norm_layer, nonlinearity, 'down', use_spect, use_coord)
            setattr(self, 'encoder' + str(i), block)

        # inference part
        for i in range(self.L):
            block = ResBlock(ngf * mult, ngf * mult, ngf *mult, norm_layer, nonlinearity, 'none', use_spect, use_coord)
            setattr(self, 'infer_prior' + str(i), block)

        for i in range(self.L):
            block = ResBlock(ngf * mult, ngf * mult, ngf *mult, norm_layer, nonlinearity, 'none', use_spect, use_coord)
            setattr(self, 'infer_prior_word' + str(i), block)

        # For textual, only change input and hidden dimension, z_nc is set when called.
        self.posterior = ResBlock(ngf * mult + 2*text_dim, 2*z_nc, ngf * mult * 2, norm_layer, nonlinearity, 'none', use_spect, use_coord)
        self.prior =     ResBlock(ngf * mult + 2*text_dim, 2*z_nc, ngf * mult * 2, norm_layer, nonlinearity, 'none', use_spect, use_coord)

    def forward(self, img_m, sentence_embedding, word_embeddings, text_mask, image_mask, img_c=None):
        """
        :param img_m: image with mask regions I_m
        :param sentence_embedding: the sentence embedding of I
        :param word_embeddings: word embedding of I
        :param text_mask: mask of word sequence of word_embeddings
        :param image_mask: mask of Im and Ic, need to scale if apply to fm
        :param img_c: complement of I_m, the mask regions
        :return distribution: distribution of mask regions, for training we have two paths, testing one path
        :return feature: the conditional feature f_m, and the previous f_pre for auto context attention
        :return text_feature: word and sentence features
        """

        if type(img_c) != type(None):
            img = torch.cat([img_m, img_c], dim=0)
        else:
            img = img_m

        # encoder part
        out = self.block0(img)
        feature = [out]
        for i in range(self.layers-1):
            model = getattr(self, 'encoder' + str(i))
            out = model(out)
            feature.append(out)

        # infer part
        # during the training, we have two paths, during the testing, we only have one paths
        image_mask = task.scale_img(image_mask, size=[feature[-1].size(2), feature[-1].size(3)])
        if image_mask.size(1) == 3:
            image_mask = image_mask.chunk(3, dim=1)[0]

        if type(img_c) != type(None):
            # During training
            f_m_g, f_m_rec = feature[-1].chunk(2)
            img_mask_g = image_mask
            img_mask_rec = 1 - img_mask_g
            weighted_word_embedding_rec = self.word_attention(
                        f_m_rec, word_embeddings, mask=text_mask, image_mask=img_mask_rec, inverse_attention=False)
            weighted_word_embedding_g = self.word_attention(
                        f_m_g, word_embeddings, mask=text_mask, image_mask=img_mask_g, inverse_attention=True)

            weighted_word_embedding =  torch.cat([weighted_word_embedding_g, weighted_word_embedding_rec])
            distribution, f_text, dual_word_embedding = self.two_paths(out, sentence_embedding, weighted_word_embedding)

            return distribution, feature, f_text, dual_word_embedding
        else:
            # During test
            f_m = feature[-1]
            weighted_word_embedding = self.word_attention(
                f_m, word_embeddings, mask=text_mask, image_mask=image_mask, inverse_attention=True)

            distribution, f_m_text, infered_word_embedding = self.one_path(out, sentence_embedding, weighted_word_embedding)
            f_text = torch.cat([f_m_text, weighted_word_embedding], dim=1)
            return distribution, feature, f_text

    def one_path(self, f_in, sentence_embedding, weighted_word_embedding):
        """one path for baseline training or testing"""
        # TOTEST: adapt to word embedding, compute distribution with word embedding.
        f_m = f_in
        distribution = []

        # infer state
        for i in range(self.L):
            infer_prior = getattr(self, 'infer_prior' + str(i))
            f_m = infer_prior(f_m)

        # infer state
        for i in range(self.L):
            infer_prior_word = getattr(self, 'infer_prior_word' + str(i))
            infered_word_embedding = infer_prior_word(weighted_word_embedding)

        # get distribution
        # use sentence embedding here
        ix, iw = f_m.size(2), f_m.size(3)
        sentence_dim = sentence_embedding.size(1)
        sentence_embedding_replication = sentence_embedding.view(-1, sentence_dim, 1, 1).repeat(1, 1, ix, iw)
        f_m_sent = torch.cat([f_m, sentence_embedding_replication], dim=1)
        f_m_text = torch.cat([f_m_sent, infered_word_embedding], dim=1)

        o = self.prior(f_m_text)
        q_mu, q_std = torch.split(o, self.z_nc, dim=1)
        distribution.append([q_mu, F.softplus(q_std)])

        return distribution, f_m_sent, infered_word_embedding

    def two_paths(self, f_in, sentence_embedding, weighted_word_embedding):
        """two paths for the training"""
        f_m, f_c = f_in.chunk(2)
        weighted_word_embedding_m, weighted_word_embedding_c = weighted_word_embedding.chunk(2)
        distributions = []

        # get distribution
        # use text embedding here
        ix, iw = f_c.size(2), f_c.size(3)
        sentence_dim = sentence_embedding.size(1)
        sentence_embedding_replication = sentence_embedding.view(-1, sentence_dim, 1, 1).repeat(1, 1, ix, iw)
        f_c_sent = torch.cat([f_c, sentence_embedding_replication], dim=1)
        f_c_text = torch.cat([f_c_sent, weighted_word_embedding_c], dim=1)
        o = self.posterior(f_c_text)
        p_mu, p_std = torch.split(o, self.z_nc, dim=1)

        distribution, f_m_sent, infered_word_embedding = self.one_path(f_m, sentence_embedding, weighted_word_embedding_m)
        distributions.append([p_mu, F.softplus(p_std), distribution[0][0], distribution[0][1]])
        dual_word_embedding = torch.cat([infered_word_embedding, weighted_word_embedding_c], dim=0)

        f_m_text = torch.cat([f_m_sent, infered_word_embedding], dim=1)
        # TODO: evaluate wether to replace infered to weighted_c
        f_c_text = torch.cat([f_m_sent, infered_word_embedding], dim=1)

        return distributions, torch.cat([f_m_text, f_c_text], dim=0), dual_word_embedding

class ResGenerator(nn.Module):
    """
    ResNet Generator Network
    :param output_nc: number of channels in output
    :param ngf: base filter channel
    :param z_nc: latent channels
    :param img_f: the largest feature channels
    :param L: Number of refinements of density
    :param layers: down and up sample layers
    :param norm: normalization function 'instance, batch, group'
    :param activation: activation function 'ReLU, SELU, LeakyReLU, PReLU'
    :param output_scale: Different output scales
    """
    def __init__(self, output_nc=3, ngf=64, z_nc=128, img_f=1024, L=1, layers=6, norm='batch', activation='ReLU',
                 output_scale=1, use_spect=True, use_coord=False, use_attn=True):
        super(ResGenerator, self).__init__()

        self.layers = layers
        self.L = L
        self.output_scale = output_scale
        self.use_attn = use_attn

        norm_layer = get_norm_layer(norm_type=norm)
        nonlinearity = get_nonlinearity_layer(activation_type=activation)
        # latent z to feature
        mult = min(2 ** (layers-1), img_f // ngf)
        # input -> hidden
        self.generator = ResBlock(z_nc, ngf * mult, ngf * mult, None, nonlinearity, 'none', use_spect, use_coord)

        # transform
        for i in range(self.L):
            block = ResBlock(ngf * mult, ngf * mult, ngf * mult, None, nonlinearity, 'none', use_spect, use_coord)
            setattr(self, 'generator' + str(i), block)

        # decoder part
        for i in range(layers):
            mult_prev = mult
            mult = min(2 ** (layers - i - 1), img_f // ngf)
            if i > layers - output_scale:
                # upconv = ResBlock(ngf * mult_prev + output_nc, ngf * mult, ngf * mult, norm_layer, nonlinearity, 'up', True)
                upconv = ResBlockDecoder(ngf * mult_prev + output_nc, ngf * mult, ngf * mult, norm_layer, nonlinearity, use_spect, use_coord)
            else:
                # upconv = ResBlock(ngf * mult_prev, ngf * mult, ngf * mult, norm_layer, nonlinearity, 'up', True)
                upconv = ResBlockDecoder(ngf * mult_prev , ngf * mult, ngf * mult, norm_layer, nonlinearity, use_spect, use_coord)
            setattr(self, 'decoder' + str(i), upconv)
            # output part
            if i > layers - output_scale - 1:
                outconv = Output(ngf * mult, output_nc, 3, None, nonlinearity, use_spect, use_coord)
                setattr(self, 'out' + str(i), outconv)
            # short+long term attention part
            if i == 1 and use_attn:
                attn = Auto_Attn(ngf*mult, None)
                setattr(self, 'attn' + str(i), attn)

    def forward(self, z, f_m=None, f_e=None, mask=None):
        """
        ResNet Generator Network
        :param z: latent vector
        :param f_m: feature of valid regions for conditional VAG-GAN
        :param f_e: previous encoder feature for short+long term attention layer
        :return results: different scale generation outputs
        """

        f = self.generator(z)
        for i in range(self.L):
             generator = getattr(self, 'generator' + str(i)) # dimension not change
             f = generator(f)

        # the features come from mask regions and valid regions, we directly add them together
        out = f_m + f
        results= []
        attn = 0
        for i in range(self.layers):
            model = getattr(self, 'decoder' + str(i))
            out = model(out)
            if i == 1 and self.use_attn:
                # auto attention
                model = getattr(self, 'attn' + str(i))
                out, attn = model(out, f_e, mask)
            if i > self.layers - self.output_scale - 1:
                model = getattr(self, 'out' + str(i))
                output = model(out)
                results.append(output)
                out = torch.cat([out, output], dim=1)

        return results, attn

class ContrastResGenerator(nn.Module):
    """
    Contrast ResNet Generator Network
    :param output_nc: number of channels in output
    :param ngf: base filter channel
    :param z_nc: latent channels
    :param img_f: the largest feature channels
    :param L: Number of refinements of density
    :param layers: down and up sample layers
    :param norm: normalization function 'instance, batch, group'
    :param activation: activation function 'ReLU, SELU, LeakyReLU, PReLU'
    :param output_scale: Different output scales
    """
    def __init__(self, output_nc=3, ngf=64, z_nc=128, img_f=1024, L=1, layers=6, norm='batch', activation='ReLU',
                 output_scale=1, use_spect=True, use_coord=False, use_attn=True):
        super(ContrastResGenerator, self).__init__()

        self.layers = layers
        self.L = L
        self.output_scale = output_scale
        self.use_attn = use_attn

        norm_layer = get_norm_layer(norm_type=norm)
        nonlinearity = get_nonlinearity_layer(activation_type=activation)
        # latent z to feature
        mult = min(2 ** (layers-1), img_f // ngf)
        # input -> hidden
        self.generator = ResBlock(z_nc, ngf * mult, ngf * mult, None, nonlinearity, 'none', use_spect, use_coord)

        # transform
        for i in range(self.L):
            block = ResBlock(ngf * mult, ngf * mult, ngf * mult, None, nonlinearity, 'none', use_spect, use_coord)
            setattr(self, 'generator' + str(i), block)

        # decoder part
        for i in range(layers):
            mult_prev = mult
            mult = min(2 ** (layers - i - 1), img_f // ngf)
            if i > layers - output_scale:
                # upconv = ResBlock(ngf * mult_prev + output_nc, ngf * mult, ngf * mult, norm_layer, nonlinearity, 'up', True)
                upconv = ResBlockDecoder(ngf * mult_prev + output_nc, ngf * mult, ngf * mult, norm_layer, nonlinearity, use_spect, use_coord)
            else:
                # upconv = ResBlock(ngf * mult_prev, ngf * mult, ngf * mult, norm_layer, nonlinearity, 'up', True)
                upconv = ResBlockDecoder(ngf * mult_prev , ngf * mult, ngf * mult, norm_layer, nonlinearity, use_spect, use_coord)
            setattr(self, 'decoder' + str(i), upconv)
            # output part
            if i > layers - output_scale - 1:
                outconv = Output(ngf * mult, output_nc, 3, None, nonlinearity, use_spect, use_coord)
                setattr(self, 'out' + str(i), outconv)
            # short+long term attention part
            if i == 1 and use_attn:
                attn = Auto_Attn(ngf*mult, None)
                setattr(self, 'attn' + str(i), attn)

    def forward(self, z, f_m=None, f_e=None, mask=None):
        """
        ResNet Generator Network
        :param z: latent vector
        :param f_m: feature of valid regions for conditional VAG-GAN
        :param f_e: previous encoder feature for short+long term attention layer
        :return results: different scale generation outputs
        """

        f = self.generator(z)
        for i in range(self.L):
             generator = getattr(self, 'generator' + str(i)) # dimension not change
             f = generator(f)

        # the features come from mask regions and valid regions, we directly add them together
        # out = f_m + f
        out = f
        results= []
        attn = 0
        for i in range(self.layers):
            model = getattr(self, 'decoder' + str(i))
            out = model(out)
            if i == 1 and self.use_attn:
                # auto attention
                model = getattr(self, 'attn' + str(i))
                out, attn = model(out, f_e, mask)
            if i > self.layers - self.output_scale - 1:
                model = getattr(self, 'out' + str(i))
                output = model(out)
                results.append(output)
                out = torch.cat([out, output], dim=1)

        return results, attn

class HiddenResGenerator(nn.Module):
    """
    ResNet Generator Network
    :param output_nc: number of channels in output
    :param ngf: base filter channel
    :param z_nc: latent channels
    :param img_f: the largest feature channels
    :param L: Number of refinements of density
    :param layers: down and up sample layers
    :param norm: normalization function 'instance, batch, group'
    :param activation: activation function 'ReLU, SELU, LeakyReLU, PReLU'
    :param output_scale: Different output scales
    """
    def __init__(self, output_nc=3, f_text_dim=384, ngf=64, z_nc=128, img_f=256, L=1, layers=6, norm='batch', activation='ReLU',
                 output_scale=1, use_spect=True, use_coord=False, use_attn=True):
        super(HiddenResGenerator, self).__init__()

        self.layers = layers
        self.L = L
        self.output_scale = output_scale
        self.use_attn = use_attn

        norm_layer = get_norm_layer(norm_type=norm)
        nonlinearity = get_nonlinearity_layer(activation_type=activation)
        # latent z to feature
        mult = min(2 ** (layers-1), img_f // ngf)
        # input -> hidden
        self.generator = ResBlock(z_nc, ngf * mult, ngf * mult, None, nonlinearity, 'none', use_spect, use_coord)
        self.f_transformer = ResBlock(f_text_dim, ngf * mult, ngf * mult, None, nonlinearity, 'none', use_spect, use_coord)

        # transform
        for i in range(self.L):
            block = ResBlock(ngf * mult, ngf * mult, ngf * mult, None, nonlinearity, 'none', use_spect, use_coord)
            setattr(self, 'generator' + str(i), block)

        # decoder part
        for i in range(layers):
            mult_prev = mult
            mult = min(2 ** (layers - i - 1), img_f // ngf)
            if i > layers - output_scale:
                # upconv = ResBlock(ngf * mult_prev + output_nc, ngf * mult, ngf * mult, norm_layer, nonlinearity, 'up', True)
                upconv = ResBlockDecoder(ngf * mult_prev + output_nc, ngf * mult, ngf * mult, norm_layer, nonlinearity, use_spect, use_coord)
            else:
                # upconv = ResBlock(ngf * mult_prev, ngf * mult, ngf * mult, norm_layer, nonlinearity, 'up', True)
                upconv = ResBlockDecoder(ngf * mult_prev , ngf * mult, ngf * mult, norm_layer, nonlinearity, use_spect, use_coord)
            setattr(self, 'decoder' + str(i), upconv)
            # output part
            if i > layers - output_scale - 1:
                outconv = Output(ngf * mult, output_nc, 3, None, nonlinearity, use_spect, use_coord)
                setattr(self, 'out' + str(i), outconv)
            # short+long term attention part
            if i == 1 and use_attn:
                attn = Auto_Attn(ngf*mult, None)
                setattr(self, 'attn' + str(i), attn)

    def forward(self, z, f_text=None, f_e=None, mask=None):
        """
        ResNet Generator Network
        :param z: latent vector
        :param f_m: feature of valid regions for conditional VAG-GAN
        :param f_e: previous encoder feature for short+long term attention layer
        :return results: different scale generation outputs
        """
        f = self.generator(z)
        for i in range(self.L):
             generator = getattr(self, 'generator' + str(i)) # dimension not change
             f = generator(f)
        f_text_trans = self.f_transformer(f_text)
        # the features come from mask regions and valid regions, we directly add them together
        out = f_text_trans + f
        results= []
        attn = 0
        for i in range(self.layers):
            model = getattr(self, 'decoder' + str(i))
            out = model(out)
            if i == 1 and self.use_attn:
                # auto attention
                model = getattr(self, 'attn' + str(i))
                out, attn = model(out, f_e, mask)
            if i > self.layers - self.output_scale - 1:
                model = getattr(self, 'out' + str(i))
                output = model(out)
                results.append(output)
                out = torch.cat([out, output], dim=1)

        return results, attn

class TextualResGenerator(nn.Module):
    """
    Textual ResNet Generator Network.
    This fucking code is hard to maintenance, just list a trip of shit.
    :param output_nc: number of channels in output
    :param ngf: base filter channel
    :param z_nc: latent channels
    :param img_f: the largest feature channels
    :param L: Number of refinements of density
    :param layers: down and up sample layers
    :param norm: normalization function 'instance, batch, group'
    :param activation: activation function 'ReLU, SELU, LeakyReLU, PReLU'
    :param output_scale: Different output scales
    """
    def __init__(self, output_nc=3, f_text_dim=384, ngf=64, z_nc=128, img_f=256, L=1, layers=6, norm='batch', activation='ReLU',
                 output_scale=1, use_spect=True, use_coord=False, use_attn=True):
        super(TextualResGenerator, self).__init__()

        self.layers = layers
        self.L = L
        self.output_scale = output_scale
        self.use_attn = use_attn

        norm_layer = get_norm_layer(norm_type=norm)
        nonlinearity = get_nonlinearity_layer(activation_type=activation)
        # latent z to feature
        mult = min(2 ** (layers-1), img_f // ngf)
        # input -> hidden
        self.generator = ResBlock(z_nc, ngf * mult, ngf * mult, None, nonlinearity, 'none', use_spect, use_coord)
        self.f_transformer = ResBlock(f_text_dim, ngf * mult, ngf * mult, None, nonlinearity, 'none', use_spect, use_coord)

        # transform
        for i in range(self.L):
            block = ResBlock(ngf * mult, ngf * mult, ngf * mult, None, nonlinearity, 'none', use_spect, use_coord)
            setattr(self, 'generator' + str(i), block)

        # decoder part
        for i in range(layers):
            mult_prev = mult
            mult = min(2 ** (layers - i - 1), img_f // ngf)
            if i > layers - output_scale:
                # upconv = ResBlock(ngf * mult_prev + output_nc, ngf * mult, ngf * mult, norm_layer, nonlinearity, 'up', True)
                upconv = ResBlockDecoder(ngf * mult_prev + output_nc, ngf * mult, ngf * mult, norm_layer, nonlinearity, use_spect, use_coord)
            else:
                # upconv = ResBlock(ngf * mult_prev, ngf * mult, ngf * mult, norm_layer, nonlinearity, 'up', True)
                upconv = ResBlockDecoder(ngf * mult_prev , ngf * mult, ngf * mult, norm_layer, nonlinearity, use_spect, use_coord)
            setattr(self, 'decoder' + str(i), upconv)
            # output part
            if i > layers - output_scale - 1:
                outconv = Output(ngf * mult, output_nc, 3, None, nonlinearity, use_spect, use_coord)
                setattr(self, 'out' + str(i), outconv)
            # short+long term attention part
            if i == 1 and use_attn:
                attn = Auto_Attn(ngf*mult, None)
                setattr(self, 'attn' + str(i), attn)

    def forward(self, z, f_text=None, f_e=None, mask=None):
        """
        ResNet Generator Network
        :param z: latent vector
        :param f_m: feature of valid regions for conditional VAG-GAN
        :param f_e: previous encoder feature for short+long term attention layer
        :return results: different scale generation outputs
        """
        f = self.generator(z)
        for i in range(self.L):
             generator = getattr(self, 'generator' + str(i)) # dimension not change
             f = generator(f)
        f_text_trans = self.f_transformer(f_text)
        # the features come from mask regions and valid regions, we directly add them together
        out = f_text_trans + f
        results= []
        attn = 0
        for i in range(self.layers):
            model = getattr(self, 'decoder' + str(i))
            out = model(out)
            if i == 1 and self.use_attn:
                # auto attention
                model = getattr(self, 'attn' + str(i))
                out, attn = model(out, f_e, mask)
            if i > self.layers - self.output_scale - 1:
                model = getattr(self, 'out' + str(i))
                output = model(out)
                results.append(output)
                out = torch.cat([out, output], dim=1)

        return results, attn

class ResDiscriminator(nn.Module):
    """
    ResNet Discriminator Network
    :param input_nc: number of channels in input
    :param ndf: base filter channel
    :param layers: down and up sample layers
    :param img_f: the largest feature channels
    :param norm: normalization function 'instance, batch, group'
    :param activation: activation function 'ReLU, SELU, LeakyReLU, PReLU'
    """
    def __init__(self, input_nc=3, ndf=64, img_f=1024, layers=6, norm='none', activation='LeakyReLU', use_spect=True,
                 use_coord=False, use_attn=True):
        super(ResDiscriminator, self).__init__()

        self.layers = layers
        self.use_attn = use_attn

        norm_layer = get_norm_layer(norm_type=norm)
        nonlinearity = get_nonlinearity_layer(activation_type=activation)
        self.nonlinearity = nonlinearity

        # encoder part
        self.block0 = ResBlockEncoderOptimized(input_nc, ndf,norm_layer, nonlinearity, use_spect, use_coord)

        mult = 1
        for i in range(layers - 1):
            mult_prev = mult
            mult = min(2 ** (i + 1), img_f // ndf)
            # self-attention
            if i == 2 and use_attn:
                attn = Auto_Attn(ndf * mult_prev, norm_layer)
                setattr(self, 'attn' + str(i), attn)
            block = ResBlock(ndf * mult_prev, ndf * mult, ndf * mult_prev, norm_layer, nonlinearity, 'down', use_spect, use_coord)
            setattr(self, 'encoder' + str(i), block)

        self.block1 = ResBlock(ndf * mult, ndf * mult, ndf * mult, norm_layer, nonlinearity, 'none', use_spect, use_coord)
        self.conv = SpectralNorm(nn.Conv2d(ndf * mult, 1, 3))

    def forward(self, x):
        out = self.block0(x)
        for i in range(self.layers - 1):
            if i == 2 and self.use_attn:
                attn = getattr(self, 'attn' + str(i))
                out, attention = attn(out)
            model = getattr(self, 'encoder' + str(i))
            out = model(out)
        out = self.block1(out)
        out = self.conv(self.nonlinearity(out))
        return out



class SNPatchDiscriminator(nn.Module):
    """
    SN Patch Discriminator Network for Local 70*70 fake/real
    :param input_nc: number of channels in input
    :param ndf: base filter channel
    :param img_f: the largest channel for the model
    :param layers: down sample layers
    :param activation: activation function 'ReLU, SELU, LeakyReLU, PReLU'
    :param use_spect: use spectral normalization or not
    :param use_coord: use CoordConv or nor
    """
    def __init__(self, input_nc=4, ndf=64, img_f=256, layers=6, activation='LeakyReLU',
                 use_spect=True, use_coord=False):
        super(SNPatchDiscriminator, self).__init__()

        nonlinearity = get_nonlinearity_layer(activation_type=activation)

        kwargs = {'kernel_size': 4, 'stride': 2, 'padding': 1, 'bias': False}
        sequence = [
            coord_conv(input_nc, ndf, use_spect, use_coord, **kwargs),
            nonlinearity,
        ]

        mult = 1
        for i in range(1, layers):
            mult_prev = mult
            mult = min(2 ** i, img_f // ndf)
            sequence +=[
                    coord_conv(ndf * mult_prev, ndf * mult, use_spect, use_coord, **kwargs),
                    nonlinearity,
                ]

        self.model = nn.Sequential(*sequence)

    def forward(self, x):
        out = self.model(x)
        return out


# ############## Text2Image Encoder-Decoder #######
class RNN_ENCODER(nn.Module):
    def __init__(self, ntoken, ninput=300, drop_prob=0.5,
                 nhidden=128, nlayers=1, bidirectional=True):
        super(RNN_ENCODER, self).__init__()

        self.n_steps = 25
        self.rnn_type = 'LSTM'

        self.ntoken = ntoken  # size of the dictionary
        self.ninput = ninput  # size of each embedding vector
        self.drop_prob = drop_prob  # probability of an element to be zeroed
        self.nlayers = nlayers  # Number of recurrent layers
        self.bidirectional = bidirectional

        if bidirectional:
            self.num_directions = 2
        else:
            self.num_directions = 1
        # number of features in the hidden state
        self.nhidden = nhidden // self.num_directions

        self.define_module()
        self.init_weights()

    def define_module(self):
        self.encoder = nn.Embedding(self.ntoken, self.ninput)
        self.drop = nn.Dropout(self.drop_prob)
        if self.rnn_type == 'LSTM':
            # dropout: If non-zero, introduces a dropout layer on
            # the outputs of each RNN layer except the last layer
            self.rnn = nn.LSTM(self.ninput, self.nhidden,
                               self.nlayers, batch_first=True,
                               dropout=self.drop_prob,
                               bidirectional=self.bidirectional)
        elif self.rnn_type == 'GRU':
            self.rnn = nn.GRU(self.ninput, self.nhidden,
                              self.nlayers, batch_first=True,
                              dropout=self.drop_prob,
                              bidirectional=self.bidirectional)
        else:
            raise NotImplementedError

    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        # Do not need to initialize RNN parameters, which have been initialized
        # http://pytorch.org/docs/master/_modules/torch/nn/modules/rnn.html#LSTM
        # self.decoder.weight.data.uniform_(-initrange, initrange)
        # self.decoder.bias.data.fill_(0)

    def init_hidden(self, bsz):
        weight = next(self.parameters()).data
        if self.rnn_type == 'LSTM':
            return (Variable(weight.new(self.nlayers * self.num_directions,
                                        bsz, self.nhidden).zero_()),
                    Variable(weight.new(self.nlayers * self.num_directions,
                                        bsz, self.nhidden).zero_()))
        else:
            return Variable(weight.new(self.nlayers * self.num_directions,
                                       bsz, self.nhidden).zero_())

    def forward(self, captions, cap_lens, hidden, mask=None):
        # input: torch.LongTensor of size batch x n_steps
        # --> emb: batch x n_steps x ninput
        emb = self.drop(self.encoder(captions))
        #
        # Returns: a PackedSequence object
        cap_lens = cap_lens.data.tolist()
        emb = pack_padded_sequence(emb, cap_lens, batch_first=True, enforce_sorted=False)
        # #hidden and memory (num_layers * num_directions, batch, hidden_size):
        # tensor containing the initial hidden state for each element in batch.
        # #output (batch, seq_len, hidden_size * num_directions)
        # #or a PackedSequence object:
        # tensor containing output features (h_t) from the last layer of RNN
        output, hidden = self.rnn(emb, hidden)
        # PackedSequence object
        # --> (batch, seq_len, hidden_size * num_directions)
        output = pad_packed_sequence(output, batch_first=True)[0]
        # output = self.drop(output)
        # --> batch x hidden_size*num_directions x seq_len
        words_emb = output.transpose(1, 2)
        # --> batch x num_directions*hidden_size
        if self.rnn_type == 'LSTM':
            sent_emb = hidden[0].transpose(0, 1).contiguous()
        else:
            sent_emb = hidden.transpose(0, 1).contiguous()
        sent_emb = sent_emb.view(-1, self.nhidden * self.num_directions)
        return words_emb, sent_emb

from torchvision import models
import torch.utils.model_zoo as model_zoo

class CNN_ENCODER(nn.Module):
    def __init__(self, nef, pre_train=False):
        super(CNN_ENCODER, self).__init__()
        self.nef = nef  # define a uniform ranker

        model = models.inception_v3()
        if pre_train:
            url = 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth'
            model.load_state_dict(model_zoo.load_url(url))
            for param in model.parameters():
                param.requires_grad = False
            print('Load pretrained model from ', url)

        self.define_module(model)
        self.init_trainable_weights()

    def define_module(self, model):
        self.Conv2d_1a_3x3 = model.Conv2d_1a_3x3
        self.Conv2d_2a_3x3 = model.Conv2d_2a_3x3
        self.Conv2d_2b_3x3 = model.Conv2d_2b_3x3
        self.Conv2d_3b_1x1 = model.Conv2d_3b_1x1
        self.Conv2d_4a_3x3 = model.Conv2d_4a_3x3
        self.Mixed_5b = model.Mixed_5b
        self.Mixed_5c = model.Mixed_5c
        self.Mixed_5d = model.Mixed_5d
        self.Mixed_6a = model.Mixed_6a
        self.Mixed_6b = model.Mixed_6b
        self.Mixed_6c = model.Mixed_6c
        self.Mixed_6d = model.Mixed_6d
        self.Mixed_6e = model.Mixed_6e
        self.Mixed_7a = model.Mixed_7a
        self.Mixed_7b = model.Mixed_7b
        self.Mixed_7c = model.Mixed_7c

        self.emb_features = conv1x1(768, self.nef)
        self.emb_cnn_code = nn.Linear(2048, self.nef)

    def init_trainable_weights(self):
        initrange = 0.1
        self.emb_features.weight.data.uniform_(-initrange, initrange)
        self.emb_cnn_code.weight.data.uniform_(-initrange, initrange)

    def forward(self, x):
        features = None
        # --> fixed-size input: batch x 3 x 299 x 299
        x = nn.Upsample(size=(299, 299), mode='bilinear')(x)
        # 299 x 299 x 3
        x = self.Conv2d_1a_3x3(x)
        # 149 x 149 x 32
        x = self.Conv2d_2a_3x3(x)
        # 147 x 147 x 32
        x = self.Conv2d_2b_3x3(x)
        # 147 x 147 x 64
        x = F.max_pool2d(x, kernel_size=3, stride=2)
        # 73 x 73 x 64
        x = self.Conv2d_3b_1x1(x)
        # 73 x 73 x 80
        x = self.Conv2d_4a_3x3(x)
        # 71 x 71 x 192

        x = F.max_pool2d(x, kernel_size=3, stride=2)
        # 35 x 35 x 192
        x = self.Mixed_5b(x)
        # 35 x 35 x 256
        x = self.Mixed_5c(x)
        # 35 x 35 x 288
        x = self.Mixed_5d(x)
        # 35 x 35 x 288

        x = self.Mixed_6a(x)
        # 17 x 17 x 768
        x = self.Mixed_6b(x)
        # 17 x 17 x 768
        x = self.Mixed_6c(x)
        # 17 x 17 x 768
        x = self.Mixed_6d(x)
        # 17 x 17 x 768
        x = self.Mixed_6e(x)
        # 17 x 17 x 768

        # image region features
        features = x
        # 17 x 17 x 768

        x = self.Mixed_7a(x)
        # 8 x 8 x 1280
        x = self.Mixed_7b(x)
        # 8 x 8 x 2048
        x = self.Mixed_7c(x)
        # 8 x 8 x 2048
        x = F.avg_pool2d(x, kernel_size=8)
        # 1 x 1 x 2048
        # x = F.dropout(x, training=self.training)
        # 1 x 1 x 2048
        x = x.view(x.size(0), -1)
        # 2048

        # global image features
        cnn_code = self.emb_cnn_code(x)
        # 512
        if features is not None:
            features = self.emb_features(features)
        return features, cnn_code


In [10]:
class TDAnet(BaseModel):
    """This class implements the text-guided image completion, for 256*256 resolution"""
    def name(self):
        return "TDAnet Image Completion"

    @staticmethod
    def modify_options(parser, is_train=True):
        """Add new options and rewrite default values for existing options"""
        parser.add_argument('--prior_alpha', type=float, default=0.8,
                            help='factor to contorl prior variation: 1/(1+e^((x-0.8)*8))')
        parser.add_argument('--prior_beta', type=float, default=8,
                            help='factor to contorl prior variation: 1/(1+e^((x-0.8)*8))')
        parser.add_argument('--no_maxpooling', action='store_true', help='rm maxpooling in DMA for ablation')
        parser.add_argument('--update_language', action='store_true', help='update language encoder while training')
        parser.add_argument('--detach_embedding', action='store_true',
                            help='do not pass grad to embedding in DAMSM-text end')

        if is_train:
            parser.add_argument('--train_paths', type=str, default='two', help='training strategies with one path or two paths')
            parser.add_argument('--dynamic_sigma', action='store_true', help='change sigma base on mask area')
            parser.add_argument('--lambda_rec_l1', type=float, default=20.0, help='weight for image reconstruction loss')
            parser.add_argument('--lambda_gen_l1', type=float, default=20.0, help='weight for image reconstruction loss')
            parser.add_argument('--lambda_kl', type=float, default=20.0, help='weight for kl divergence loss')
            parser.add_argument('--lambda_gan', type=float, default=1.0, help='weight for generation loss')
            parser.add_argument('--lambda_match', type=float, default=0.1, help='weight for image-text match loss')

        return parser

    def __init__(self, opt):
        """Initial the pluralistic model"""
        BaseModel.__init__(self, opt)

        self.loss_names = ['kl_rec', 'kl_g', 'l1_rec', 'l1_g', 'gan_g', 'word_g', 'sentence_g', 'ad_l2_g',
                           'gan_rec', 'ad_l2_rec', 'word_rec', 'sentence_rec',  'dis_img', 'dis_img_rec']
        self.log_names = []
        self.visual_names = ['img_m', 'img_truth', 'img_c', 'img_out', 'img_g', 'img_rec']
        self.text_names = ['text_positive']
        self.value_names = ['u_m', 'sigma_m', 'u_post', 'sigma_post', 'u_prior', 'sigma_prior']
        self.model_names = ['E', 'G', 'D', 'D_rec']
        self.distribution = []
        self.prior_alpha = opt.prior_alpha
        self.prior_beta = opt.prior_beta
        self.max_pool = None if opt.no_maxpooling else 'max'

        # define the inpainting model
        self.net_E = network.define_att_textual_e(ngf=32, z_nc=256, img_f=256, layers=5, norm='none', activation='LeakyReLU',
                          init_type='orthogonal', gpu_ids=opt.gpu_ids, image_dim=256, text_dim=256, multi_peak=False, pool_attention=self.max_pool)
        self.net_G = network.define_hidden_textual_g(f_text_dim=768, ngf=32, z_nc=256, img_f=256, L=0, layers=5, output_scale=opt.output_scale,
                                      norm='instance', activation='LeakyReLU', init_type='orthogonal', gpu_ids=opt.gpu_ids)
        # define the discriminator model
        self.net_D = network.define_d(ndf=32, img_f=128, layers=5, model_type='ResDis', init_type='orthogonal', gpu_ids=opt.gpu_ids)
        self.net_D_rec = network.define_d(ndf=32, img_f=128, layers=5, model_type='ResDis', init_type='orthogonal', gpu_ids=opt.gpu_ids)

        text_config = TextConfig(opt.text_config)
        self._init_language_model(text_config)

        if self.isTrain:
            # define the loss functions
            self.GANloss = external_function.GANLoss(opt.gan_mode)
            self.L1loss = torch.nn.L1Loss()
            self.L2loss = torch.nn.MSELoss()

            self.image_encoder = network.CNN_ENCODER(text_config.EMBEDDING_DIM)
            state_dict = torch.load(
                text_config.IMAGE_ENCODER, map_location=lambda storage, loc: storage)
            self.image_encoder.load_state_dict(state_dict)
            self.image_encoder.eval()
            if len(self.gpu_ids) > 0 and torch.cuda.is_available():
                self.image_encoder.cuda()
            base_function._freeze(self.image_encoder)

            # define the optimizer
            self.optimizer_G = torch.optim.Adam(itertools.chain(filter(lambda p: p.requires_grad, self.net_G.parameters()),
                        filter(lambda p: p.requires_grad, self.net_E.parameters())), lr=opt.lr, betas=(0.0, 0.999))
            self.optimizer_D = torch.optim.Adam(itertools.chain(filter(lambda p: p.requires_grad, self.net_D.parameters()),
                                                filter(lambda p: p.requires_grad, self.net_D_rec.parameters())),
                                                lr=opt.lr, betas=(0.0, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)

        self.setup(opt)

    def _init_language_model(self, text_config):
        x = pickle.load(open(text_config.VOCAB, 'rb'))
        self.ixtoword = x[2]
        self.wordtoix = x[3]

        word_len = len(self.wordtoix)
        self.text_encoder = network.RNN_ENCODER(word_len, nhidden=256)

        state_dict = torch.load(text_config.LANGUAGE_ENCODER, map_location=lambda storage, loc: storage)
        self.text_encoder.load_state_dict(state_dict)
        self.text_encoder.eval()
        if not self.opt.update_language:
            self.text_encoder.requires_grad_(False)
        if len(self.gpu_ids) > 0 and torch.cuda.is_available():
            self.text_encoder.cuda()

    def set_input(self, input, epoch=0):
        """Unpack input data from the data loader and perform necessary pre-process steps"""
        self.input = input
        self.image_paths = self.input['img_path']
        self.img = input['img']
        self.mask = input['mask']
        self.caption_idx = input['caption_idx']
        self.caption_length = input['caption_len']

        if len(self.gpu_ids) > 0:
            self.img = self.img.cuda(self.gpu_ids[0], True)
            self.mask = self.mask.cuda(self.gpu_ids[0], True)

        # get I_m and I_c for image with mask and complement regions for training
        self.img_truth = self.img * 2 - 1
        self.img_m = self.mask * self.img_truth
        self.img_c =  (1 - self.mask) * self.img_truth

        # get multiple scales image ground truth and mask for training
        self.scale_img = task.scale_pyramid(self.img_truth, self.opt.output_scale)
        self.scale_mask = task.scale_pyramid(self.mask, self.opt.output_scale)

        # About text stuff
        self.text_positive = util.idx_to_caption(
                                    self.ixtoword, self.caption_idx[0].tolist(), self.caption_length[0].item())
        self.word_embeddings, self.sentence_embedding = util.vectorize_captions_idx_batch(
                                                    self.caption_idx, self.caption_length, self.text_encoder)
        self.text_mask = util.lengths_to_mask(self.caption_length, max_length=self.word_embeddings.size(-1))
        self.match_labels = torch.LongTensor(range(len(self.img_m)))
        if len(self.gpu_ids) > 0:
            self.word_embeddings = self.word_embeddings.cuda(self.gpu_ids[0], True)
            self.sentence_embedding = self.sentence_embedding.cuda(self.gpu_ids[0], True)
            self.text_mask = self.text_mask.cuda(self.gpu_ids[0], True)
            self.match_labels = self.match_labels.cuda(self.gpu_ids[0], True)

    def test(self, mark=None):
        """Forward function used in test time"""
        # save the groundtruth and masked image
        self.save_results(self.img_truth, data_name='truth')
        self.save_results(self.img_m, data_name='mask')

        # encoder process
        distribution, f, f_text = self.net_E(
            self.img_m, self.sentence_embedding, self.word_embeddings, self.text_mask, self.mask)
        variation_factor = 0. if self.opt.no_variance else 1.
        q_distribution = torch.distributions.Normal(distribution[-1][0], distribution[-1][1] * variation_factor)
        scale_mask = task.scale_img(self.mask, size=[f[2].size(2), f[2].size(3)])

        # decoder process
        for i in range(self.opt.nsampling):
            z = q_distribution.sample()

            self.img_g, attn = self.net_G(z, f_text, f_e=f[2], mask=scale_mask.chunk(3, dim=1)[0])
            self.img_out = (1 - self.mask) * self.img_g[-1].detach() + self.mask * self.img_m
            self.score = self.net_D(self.img_out)
            self.save_results(self.img_out, i, data_name='out', mark=mark)

    def get_distribution(self, distribution_factors):
        """Calculate encoder distribution for img_m, img_c only in train, all about distribution layer of VAE model"""
        # get distribution
        sum_valid = (torch.mean(self.mask.view(self.mask.size(0), -1), dim=1) - 1e-5).view(-1, 1, 1, 1)
        m_sigma = 1 if not self.opt.dynamic_sigma else (1 / (1 + ((sum_valid - self.prior_alpha) * self.prior_beta).exp_()))
        p_distribution, q_distribution, kl_rec, kl_g = 0, 0, 0, 0
        self.distribution = []
        for distribution in distribution_factors:
            p_mu, p_sigma, q_mu, q_sigma = distribution
            # the assumption distribution for different mask regions
            std_distribution = torch.distributions.Normal(torch.zeros_like(p_mu), m_sigma * torch.ones_like(p_sigma))
            # m_distribution = torch.distributions.Normal(torch.zeros_like(p_mu), torch.ones_like(p_sigma))
            # the post distribution from mask regions
            p_distribution = torch.distributions.Normal(p_mu, p_sigma)
            p_distribution_fix = torch.distributions.Normal(p_mu.detach(), p_sigma.detach())
            # the prior distribution from valid region
            q_distribution = torch.distributions.Normal(q_mu, q_sigma)

            # kl divergence
            kl_rec += torch.distributions.kl_divergence(std_distribution, p_distribution)
            if self.opt.train_paths == "one":
                kl_g += torch.distributions.kl_divergence(std_distribution, q_distribution)
            elif self.opt.train_paths == "two":
                kl_g += torch.distributions.kl_divergence(p_distribution_fix, q_distribution)
            self.distribution.append([torch.zeros_like(p_mu), m_sigma * torch.ones_like(p_sigma), p_mu, p_sigma, q_mu, q_sigma])

        return p_distribution, q_distribution, kl_rec, kl_g

    def get_G_inputs(self, p_distribution, q_distribution, f):
        """Process the encoder feature and distributions for generation network, combine two dataflow when implement."""
        f_m = torch.cat([f[-1].chunk(2)[0], f[-1].chunk(2)[0]], dim=0)
        f_e = torch.cat([f[2].chunk(2)[0], f[2].chunk(2)[0]], dim=0)
        scale_mask = task.scale_img(self.mask, size=[f_e.size(2), f_e.size(3)])
        mask = torch.cat([scale_mask.chunk(3, dim=1)[0], scale_mask.chunk(3, dim=1)[0]], dim=0)
        z_p = p_distribution.rsample()
        z_q = q_distribution.rsample()
        z = torch.cat([z_p, z_q], dim=0)
        return z, f_m, f_e, mask

    def forward(self):
        """Run forward processing to get the inputs"""
        # encoder process
        distribution_factors, f, f_text = self.net_E(
            self.img_m, self.sentence_embedding, self.word_embeddings, self.text_mask, self.mask, self.img_c)

        p_distribution, q_distribution, self.kl_rec, self.kl_g = self.get_distribution(distribution_factors)

        # decoder process
        z, f_m, f_e, mask = self.get_G_inputs(p_distribution, q_distribution, f) # prepare inputs: img, mask, distribute

        results, attn = self.net_G(z, f_text, f_e, mask)
        self.img_rec = []
        self.img_g = []
        for result in results:
            img_rec, img_g = result.chunk(2)
            self.img_rec.append(img_rec)
            self.img_g.append(img_g)
        self.img_out = (1-self.mask) * self.img_g[-1].detach() + self.mask * self.img_truth

        self.region_features_rec, self.cnn_code_rec = self.image_encoder(self.img_rec[-1])
        self.region_features_g, self.cnn_code_g = self.image_encoder(self.img_g[-1])


    def backward_D_basic(self, netD, real, fake):
        """Calculate GAN loss for the discriminator"""
        # Real
        D_real = netD(real)
        D_real_loss = self.GANloss(D_real, True, True)
        # fake
        D_fake = netD(fake.detach())
        D_fake_loss = self.GANloss(D_fake, False, True)
        # loss for discriminator
        D_loss = (D_real_loss + D_fake_loss) * 0.5
        # gradient penalty for wgan-gp
        if self.opt.gan_mode == 'wgangp':
            gradient_penalty, gradients = external_function.cal_gradient_penalty(netD, real, fake.detach())
            D_loss +=gradient_penalty

        D_loss.backward()

        return D_loss

    def backward_D(self):
        """Calculate the GAN loss for the discriminators"""
        base_function._unfreeze(self.net_D, self.net_D_rec)
        ## Note: changed gen path gan loss to rec path
        # self.loss_dis_img = self.backward_D_basic(self.net_D, self.img_truth, self.img_g[-1])
        self.loss_dis_img = self.backward_D_basic(self.net_D, self.img_truth, self.img_rec[-1])
        self.loss_dis_img_rec = self.backward_D_basic(self.net_D_rec, self.img_truth, self.img_rec[-1])

    def backward_G(self):
        """Calculate training loss for the generator"""

        # encoder kl loss
        self.loss_kl_rec = self.kl_rec.mean() * self.opt.lambda_kl * self.opt.output_scale
        self.loss_kl_g = self.kl_g.mean() * self.opt.lambda_kl * self.opt.output_scale

        # Adversarial loss
        base_function._freeze(self.net_D, self.net_D_rec)

        # D loss fake
        D_fake_g = self.net_D(self.img_g[-1])
        self.loss_gan_g = self.GANloss(D_fake_g, True, False) * self.opt.lambda_gan
        D_fake_rec = self.net_D(self.img_rec[-1])
        self.loss_gan_rec = self.GANloss(D_fake_rec, True, False) * self.opt.lambda_gan

        # LSGAN loss
        D_fake = self.net_D_rec(self.img_rec[-1])
        D_real = self.net_D_rec(self.img_truth)
        D_fake_g = self.net_D_rec(self.img_g[-1])
        self.loss_ad_l2_rec = self.L2loss(D_fake, D_real) * self.opt.lambda_gan
        self.loss_ad_l2_g = self.L2loss(D_fake_g, D_real) * self.opt.lambda_gan

        # Text-image consistent loss
        if not self.opt.detach_embedding:
            sentence_embedding = self.sentence_embedding
            word_embeddings = self.word_embeddings
        else:
            sentence_embedding = self.sentence_embedding.detach()
            word_embeddings = self.word_embeddings.detach()


        loss_sentence = base_function.sent_loss(self.cnn_code_rec, sentence_embedding, self.match_labels)
        loss_word, _ = base_function.words_loss(self.region_features_rec, word_embeddings, self.match_labels, \
                                 self.caption_length, len(word_embeddings))
        self.loss_word_rec = loss_word * self.opt.lambda_match
        self.loss_sentence_rec = loss_sentence * self.opt.lambda_match

        loss_sentence = base_function.sent_loss(self.cnn_code_g, sentence_embedding, self.match_labels)
        loss_word, _ = base_function.words_loss(self.region_features_g, word_embeddings, self.match_labels, \
                                 self.caption_length, len(word_embeddings))
        self.loss_word_g = loss_word * self.opt.lambda_match
        self.loss_sentence_g = loss_sentence * self.opt.lambda_match


        # calculate l1 loss ofr multi-scale, multi-depth-level outputs
        loss_l1_rec, loss_l1_g, log_PSNR_rec, log_PSNR_out = 0, 0, 0, 0
        for i, (img_rec_i, img_fake_i, img_out_i, img_real_i, mask_i) in enumerate(zip(self.img_rec, self.img_g, self.img_out, self.scale_img, self.scale_mask)):
            loss_l1_rec += self.L1loss(img_rec_i, img_real_i)
            if self.opt.train_paths == "one":
                loss_l1_g += self.L1loss(img_fake_i, img_real_i)
            elif self.opt.train_paths == "two":
                loss_l1_g += self.L1loss(img_fake_i, img_real_i)

        self.loss_l1_rec = loss_l1_rec * self.opt.lambda_rec_l1
        self.loss_l1_g = loss_l1_g * self.opt.lambda_gen_l1

        # if one path during the training, just calculate the loss for generation path
        if self.opt.train_paths == "one":
            self.loss_l1_rec = self.loss_l1_rec * 0
            self.loss_ad_l2_rec = self.loss_ad_l2_rec * 0
            self.loss_kl_rec = self.loss_kl_rec * 0

        total_loss = 0

        for name in self.loss_names:
            if name != 'dis_img' and name != 'dis_img_rec':
                total_loss += getattr(self, "loss_" + name)

        total_loss.backward()

    def optimize_parameters(self):
        """update network weights"""
        # compute the image completion results
        self.forward()
        # optimize the discrinimator network parameters
        self.optimizer_D.zero_grad()
        self.backward_D()
        self.optimizer_D.step()
        # optimize the completion network parameters
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()

In [11]:
"""This package contains modules related to function, network architectures, and models"""

def find_model_using_name(model_name):
    """Import the module "model/[model_name]_model.py"."""
    model_file_name = "model." + model_name + "_model"
    modellib = importlib.import_module(model_file_name)
    model = None
    for name, cls in modellib.__dict__.items():
        if name.lower() == model_name.lower() and issubclass(cls, BaseModel):
            model = cls

    if model is None:
        print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_file_name, model_name))
        exit(0)

    return model


def get_option_setter(model_name):
    """Return the static method <modify_commandline_options> of the model class."""
    model = find_model_using_name(model_name)
    return model.modify_options


def create_model(opt):
    """Create a model given the option."""
    model = find_model_using_name(opt.model)
    instance = model(opt)
    print("model [%s] was created" % type(instance).__name__)
    return instance

In [12]:
class CreateDataset(data.Dataset):
    def __init__(self, opt, debug=False):
        self.opt = opt
        self.debug = debug
        self.img_paths, self.img_size = make_dataset(opt.img_file)
        # provides random file for training and testing
        if opt.mask_file != 'none':
            if not opt.mask_file.endswith('.json'):
                self.mask_paths, self.mask_size = make_dataset(opt.mask_file)
            else:
                with open(opt.mask_file, 'r') as f:
                    self.image_bbox = json.load(f)

        self.transform = get_transform(opt)

        ## ========Abnout text stuff===============
        text_config = TextConfig(opt.text_config)
        self.max_length = text_config.MAX_TEXT_LENGTH
        if 'coco' in text_config.CAPTION.lower():
            self.num_captions = 5
        elif 'place' in text_config.CAPTION.lower():
            self.num_captions = 1
        else:
            self.num_captions = 10

        # load caption file
        with open(text_config.CAPTION, 'r') as f:
            self.captions = json.load(f)

        x = pickle.load(open(text_config.VOCAB, 'rb'))
        self.ixtoword = x[2]
        self.wordtoix = x[3]

        self.epoch = 0 # Used for iter on captions.

    def __getitem__(self, index):
        # load image
        index = self.epoch*self.img_size+index

        img, img_path = self.load_img(index)
        # load mask
        mask = self.load_mask(img, index, img_path)
        assert sum(img.shape) == sum(mask.shape), (img.shape, mask.shape)
        caption_idx, caption_len, caption, img_name= self._load_text_idx(index)
        return {'img': img, 'img_path': img_path, 'mask': mask, \
                'caption_idx' : torch.Tensor(caption_idx).long(), 'caption_len':caption_len,\
                'caption_text': caption, 'image_path': img_name}

    def __len__(self):
        return self.img_size

    def name(self):
        return "inpainting dataset"

    def load_img(self, index):
        ImageFile.LOAD_TRUNCATED_IMAGES = True
        img_path = self.img_paths[index % self.img_size]
        img_pil = Image.open(img_path).convert('RGB')
        img = self.transform(img_pil)
        img_pil.close()
        return img, img_path

    def _load_text_idx(self, image_index):
        img_name = self.img_paths[image_index % self.img_size]
        caption_index_of_image = image_index // self.img_size  % self.num_captions
        img_name = os.path.basename(img_name)
        captions = self.captions[img_name]
        caption = captions[caption_index_of_image] if type(captions) == list else captions
        caption_idx, caption_len = util._caption_to_idx(self.wordtoix, caption, self.max_length)

        return caption_idx, caption_len, caption, img_name

    def load_mask(self, img, index, img_path):
        """Load different mask types for training and testing"""
        mask_type_index = random.randint(0, len(self.opt.mask_type) - 1)
        mask_type = self.opt.mask_type[mask_type_index]

        # center mask
        if mask_type == 0:
            return task.center_mask(img)

        # random regular mask
        if mask_type == 1:
            return task.random_regular_mask(img)

        # random irregular mask
        if mask_type == 2:
            return task.random_irregular_mask(img)

        if mask_type == 3:
            # file masks, e.g. CUB object mask
            mask_index = index
            mask_pil = Image.open(self.mask_paths[mask_index]).convert('RGB')

            mask_transform = get_transform_mask(self.opt)

            mask = (mask_transform(mask_pil) == 0).float()
            mask_pil.close()
            return mask

        if mask_type == 4:
            # coco json file object mask
            if os.path.basename(img_path) not in self.image_bbox:
                return task.random_regular_mask(img)

            img_original = np.asarray(Image.open(img_path).convert('RGB'))

            # create a mask matrix same as img_original
            mask = np.zeros_like(img_original)
            bboxes = self.image_bbox[os.path.basename(img_path)]

            # choose max area box
            choosen_box = 0,0,0,0
            max_area = 0
            for x1,x2,y1,y2 in bboxes:
                area = (x2-x1) * (y2-y1)
                if area > max_area:
                    max_area = area
                    choosen_box = x1,x2,y1,y2
            x1, x2, y1, y2 = choosen_box
            mask[x1:x2, y1:y2] = 1

            # apply same transform as img to the mask
            mask_pil = Image.fromarray(mask)

            mask_transform = get_transform_mask(self.opt)

            mask = (mask_transform(mask_pil) == 0).float()

            mask_pil.close()

            return mask


def dataloader(opt):
    datasets = CreateDataset(opt)
    dataset = data.DataLoader(datasets, batch_size=opt.batchSize, shuffle=not opt.no_shuffle, num_workers=int(opt.nThreads), pin_memory=True)

    return dataset

def get_transform_mask(opt):
    """Basic process to transform PIL image to torch tensor"""
    transform_list = []
    osize = [opt.loadSize[0], opt.loadSize[1]]
    fsize = [opt.fineSize[0], opt.fineSize[1]]
    if opt.isTrain:
        if opt.resize_or_crop == 'resize_and_crop':
            transform_list.append(transforms.Resize(osize))
            transform_list.append(transforms.RandomCrop(fsize))
        elif opt.resize_or_crop == 'crop':
            transform_list.append(transforms.RandomCrop(fsize))
        if not opt.no_flip:
            transform_list.append(transforms.RandomHorizontalFlip())
        if not opt.no_rotation:
            transform_list.append(transforms.RandomRotation(3))
    else:
        transform_list.append(transforms.Resize(fsize))

    transform_list += [transforms.ToTensor()]

    return transforms.Compose(transform_list)

def get_transform(opt):
    """Basic process to transform PIL image to torch tensor"""
    transform_list = []
    osize = [opt.loadSize[0], opt.loadSize[1]]
    fsize = [opt.fineSize[0], opt.fineSize[1]]
    if opt.isTrain:
        if opt.resize_or_crop == 'resize_and_crop':
            transform_list.append(transforms.Resize(osize))
            transform_list.append(transforms.RandomCrop(fsize))
        elif opt.resize_or_crop == 'crop':
            transform_list.append(transforms.RandomCrop(fsize))
        if not opt.no_augment:
            transform_list.append(transforms.ColorJitter(0.0, 0.0, 0.0, 0.0))
        if not opt.no_flip:
            transform_list.append(transforms.RandomHorizontalFlip())
        if not opt.no_rotation:
            transform_list.append(transforms.RandomRotation(3))
    else:
        transform_list.append(transforms.Resize(fsize))

    transform_list += [transforms.ToTensor()]

    return transforms.Compose(transform_list)

In [13]:
import os
import os.path

IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


def make_dataset(path_files):
    if os.path.isfile(path_files):
        paths, size = make_dataset_txt(path_files)
    else:
        paths, size = make_dataset_dir(path_files)

    return paths, size


def make_dataset_txt(files):
    """
    :param path_files: the path of txt file that store the image paths
    :return: image paths and sizes
    """
    img_paths = []

    with open(files) as f:
        paths = f.readlines()

    for path in paths:
        path = path.strip()
        img_paths.append(os.path.join(os.path.dirname(files), path))

    return img_paths, len(img_paths)


def make_dataset_dir(dir):
    """
    :param dir: directory paths that store the image
    :return: image paths and sizes
    """
    img_paths = []

    assert os.path.isdir(dir), '%s is not a valid directory' % dir

    for root, _, fnames in os.walk(dir):
        for fname in sorted(fnames):
            if is_image_file(fname):
                path = os.path.join(root, fname)
                img_paths.append(path)

    return img_paths, len(img_paths)

In [14]:
import argparse
import os
import torch
import model
from util import util


class BaseOptions():
    def __init__(self):
        self.parser = argparse.ArgumentParser()
        self.initialized = False

    def initialize(self, parser):
        # base define
        parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment.')
        parser.add_argument('--model', type=str, default='tdanet', help='name of the model type. [pluralistic]')
        parser.add_argument('--mask_type', type=int, default=[1, 2, 3], nargs='+',
                            help='mask type, 0: center mask, 1:random regular mask, '
                            '2: random irregular mask. 3: external irregular mask. 4: external json bbox mask'
                            ' [0],[1,2],[1,2,3]')
        parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are save here')
        parser.add_argument('--which_iter', type=str, default='latest', help='which iterations to load')
        parser.add_argument('--gpu_ids', type=str, default='-1', help='gpu ids: e.g. 0, 1, 2 use -1 for CPU')
        parser.add_argument('--text_config', type=str, default='config.bird.yml', help='path to text config')
        parser.add_argument('--output_scale', type=int, default=4, help='# of number of the output scale')

        # data pattern define
        parser.add_argument('--img_file', type=str, default='/home/hemanthgaddey/Documents/tdanet_/tdanet/dataset/CUB_200_2011/images/002.Laysan_Albatross', help='training and testing dataset')
        parser.add_argument('--mask_file', type=str, default='none', help='load test mask')
        parser.add_argument('--loadSize', type=int, default=[266, 266], help='scale images to this size')
        parser.add_argument('--fineSize', type=int, default=[256, 256], help='then crop to this size')
        parser.add_argument('--resize_or_crop', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop|crop|]')
        parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the image for data augmentation')
        parser.add_argument('--no_rotation', action='store_true', help='if specified, do not rotation for data augmentation')
        parser.add_argument('--no_augment', action='store_true', help='if specified, do not augment the image for data augmentation')
        parser.add_argument('--batchSize', type=int, default=10, help='input batch size')
        parser.add_argument('--nThreads', type=int, default=8, help='# threads for loading data')
        parser.add_argument('--no_shuffle', action='store_true',help='if true, takes images serial')

        # display parameter define
        parser.add_argument('--display_winsize', type=int, default=256, help='display window size')
        parser.add_argument('--display_id', type=int, default=1, help='display id of the web')
        parser.add_argument('--display_port', type=int, default=8097, help='visidom port of the web display')
        parser.add_argument('--display_single_pane_ncols', type=int, default=0, help='if positive, display all images in a single visidom web panel')

        return parser

    def gather_options(self):
        """Add additional model-specific options"""

        if not self.initialized:
            parser = self.initialize(self.parser)

        # get basic options
        opt, _ = parser.parse_known_args()

        # modify the options for different models
        model_option_set = model.get_option_setter(opt.model)
        parser = model_option_set(parser, self.isTrain)
        opt = parser.parse_args('')

        return opt

    def parse(self):
        """Parse the options"""

        opt = self.gather_options()
        opt.isTrain = self.isTrain

        self.print_options(opt)

        # set gpu ids
        str_ids = opt.gpu_ids.split(',')
        opt.gpu_ids = []
        for str_id in str_ids:
            id = int(str_id)
            if id >= 0:
                opt.gpu_ids.append(id)
        if len(opt.gpu_ids):
            torch.cuda.set_device(opt.gpu_ids[0])

        self.opt = opt

        return self.opt

    @staticmethod
    def print_options(opt):
        """print and save options"""

        print('--------------Options--------------')
        for k, v in sorted(vars(opt).items()):
            print('%s: %s' % (str(k), str(v)))
        print('----------------End----------------')

        # save to the disk
        expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
        util.mkdirs(expr_dir)
        if opt.isTrain:
            file_name = os.path.join(expr_dir, 'train_opt.txt')
        else:
            file_name = os.path.join(expr_dir, 'test_opt.txt')
        with open(file_name, 'wt') as opt_file:
            opt_file.write('--------------Options--------------\n')
            for k, v in sorted(vars(opt).items()):
                opt_file.write('%s: %s\n' % (str(k), str(v)))
            opt_file.write('----------------End----------------\n')

In [15]:
class TrainOptions(BaseOptions):
    def initialize(self, parser):
        parser = BaseOptions.initialize(self, parser)

        # training epoch
        parser.add_argument('--iter_count', type=int, default=1, help='the starting epoch count')
        parser.add_argument('--niter', type=int, default=5000000, help='# of iter with initial learning rate')
        parser.add_argument('--niter_decay', type=int, default=0, help='# of iter to decay learning rate to zero')
        parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
        parser.add_argument('--valid_file', type=str, default='/data/dataset/valid', help='valid dataset')

        # learning rate and loss weight
        parser.add_argument('--lr_policy', type=str, default='lambda', help='learning rate policy[lambda|step|plateau]')
        parser.add_argument('--lr', type=float, default=1e-4, help='initial learning rate for adam')
        parser.add_argument('--gan_mode', type=str, default='lsgan', choices=['wgan-gp', 'hinge', 'lsgan'])

        # display the results
        parser.add_argument('--display_freq', type=int, default=100, help='frequency of showing training results on screen')
        parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
        parser.add_argument('--save_latest_freq', type=int, default=1000, help='frequency of saving the latest results at each batch')
        parser.add_argument('--save_iters_freq', type=int, default=10000, help='frequency of saving checkpoints at the end of batch')
        parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results')

        self.isTrain = True

        return parser

In [16]:
import numpy as np
import os
import imageio
import torch

from nltk.tokenize import word_tokenize

def idx_to_caption(ixtoword, caption, length):
    """ Turn idx to word"""
    return ' '.join([ixtoword[caption[i]] for i in range(length)])

def _caption_to_idx(wordtoix, caption, max_length):
    '''Transform single text caption to idx and length tensor'''
    caption_token = word_tokenize(caption.lower())

    caption_idx = []
    for token in caption_token:
        t = token.encode('ascii', 'ignore').decode('ascii')
        if len(t) > 0 and t in wordtoix:
            caption_idx.append(wordtoix[t])

    length = len(caption_idx)
    if length <= max_length:
        caption_idx = caption_idx + [0] * (max_length - len(caption_idx))
    else:
        caption_idx = caption_idx[:max_length]

    return caption_idx, length

def vectorize_captions_idx_batch(batch_padded_captions_idx, batch_length, language_encoder):
    '''Transform batch_padded_captions_idx to sentence embedding'''
    batch_size = len(batch_length)

    with torch.no_grad():
        hidden = language_encoder.init_hidden(batch_size)
        device = hidden[0].device
        word_embs, sent_emb = language_encoder(batch_padded_captions_idx.to(device), \
                                       batch_length.to(device), hidden)
    return word_embs, sent_emb

def lengths_to_mask(lengths, max_length, device=None):
    '''transform digital lengths to tensor mask.'''
    masks = torch.ones(len(lengths), max_length)
    for i, length in enumerate(lengths):
        masks[i,:length] = 0
    masks = masks.bool()
    return masks if device is None else masks.to(device)


def PSNR(a, b):
    '''compute PSNR for a and b image'''
    mse = np.mean((a - b) ** 2) + 1e-8

    return 20 * np.log10(255.0 / np.sqrt(mse))

def tensor_image_scale(tensor):
    '''scale the value in tensor as image'''
    return (tensor + 1) / 2.0 * 255.0

# convert a tensor into a numpy array
def tensor2im(image_tensor, bytes=255.0, imtype=np.uint8):
    if image_tensor.dim() == 3:
        image_numpy = image_tensor.cpu().float().numpy()
    else:
        image_numpy = image_tensor[0].cpu().float().numpy()
    image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * bytes

    return image_numpy.astype(imtype)


# conver a tensor into a numpy array
def tensor2array(value_tensor):
    if value_tensor.dim() == 3:
        numpy = value_tensor.view(-1).cpu().float().numpy()
    else:
        numpy = value_tensor[0].view(-1).cpu().float().numpy()
    return numpy


def save_image(image_numpy, image_path):
    if image_numpy.shape[2] == 1:
        image_numpy = image_numpy.reshape(image_numpy.shape[0], image_numpy.shape[1])

    imageio.imwrite(image_path, image_numpy)


def mkdirs(paths):
    if isinstance(paths, list) and not isinstance(paths, str):
        for path in paths:
            mkdir(path)
    else:
        mkdir(paths)


def mkdir(path):
    if not os.path.exists(path):
        os.makedirs(path)

In [17]:
import dominate
from dominate.tags import *
import os


class HTML:
    def __init__(self, web_dir, title, reflesh=0):
        self.title = title
        self.web_dir = web_dir
        self.img_dir = os.path.join(self.web_dir, 'images')
        if not os.path.exists(self.web_dir):
            os.makedirs(self.web_dir)
        if not os.path.exists(self.img_dir):
            os.makedirs(self.img_dir)
        # print(self.img_dir)

        self.doc = dominate.document(title=title)
        if reflesh > 0:
            with self.doc.head:
                meta(http_equiv="reflesh", content=str(reflesh))

    def get_image_dir(self):
        return self.img_dir

    def add_header(self, str):
        with self.doc:
            h3(str)

    def add_table(self, border=1):
        self.t = table(border=border, style="table-layout: fixed;")
        self.doc.add(self.t)

    def add_images(self, ims, txts, links, width=400):
        self.add_table()
        with self.t:
            with tr():
                for im, txt, link in zip(ims, txts, links):
                    with td(style="word-wrap: break-word;", halign="center", valign="top"):
                        with p():
                            with a(href=os.path.join('images', link)):
                                img(style="width:%dpx" % width, src=os.path.join('images', im))
                            br()
                            p(txt)

    def save(self):
        html_file = '%s/index.html' % self.web_dir
        f = open(html_file, 'wt')
        f.write(self.doc.render())
        f.close()


if __name__ == '__main__':
    html = HTML('web/', 'test_html')
    html.add_header('hello world')

    ims = []
    txts = []
    links = []
    for n in range(4):
        ims.append('image_%d.png' % n)
        txts.append('text_%d' % n)
        links.append('image_%d.png' % n)
    html.add_images(ims, txts, links)
    html.save()

In [18]:
import numpy as np
import os
import ntpath
import time

class Visualizer():
    def __init__(self, opt):
        # self.opt = opt
        self.display_id = opt.display_id
        self.use_html = opt.isTrain and not opt.no_html
        self.win_size = opt.display_winsize
        self.name = opt.name
        if self.display_id > 0:
            import visdom
            self.vis = visdom.Visdom(port = opt.display_port)
            self.display_single_pane_ncols = opt.display_single_pane_ncols

        if self.use_html:
            self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
            self.img_dir = os.path.join(self.web_dir, 'images')
            print('create web directory %s...' % self.web_dir)
            util.mkdirs([self.web_dir, self.img_dir])
        self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
        with open(self.log_name, "a") as log_file:
            now = time.strftime("%c")
            log_file.write('================ Training Loss (%s) ================\n' % now)

    # |visuals|: dictionary of images to display or save
    def display_current_results(self, visuals, text, epoch):
        if self.display_id > 0: # show images in the browser
            if self.display_single_pane_ncols > 0:
                h, w = next(iter(visuals.values())).shape[:2]
                table_css = """<style>
    table {border-collapse: separate; border-spacing:4px; white-space:nowrap; text-align:center}
    table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black}
</style>""" % (w, h)
                ncols = self.display_single_pane_ncols
                title = self.name
                label_html = ''
                label_html_row = ''
                nrows = int(np.ceil(len(visuals.items()) / ncols))
                images = []
                idx = 0
                for label, image_numpy in visuals.items():
                    label_html_row += '<td>%s</td>' % label
                    images.append(image_numpy.transpose([2, 0, 1]))
                    idx += 1
                    if idx % ncols == 0:
                        label_html += '<tr>%s</tr>' % label_html_row
                        label_html_row = ''
                white_image = np.ones_like(image_numpy.transpose([2, 0, 1]))*255
                while idx % ncols != 0:
                    images.append(white_image)
                    label_html_row += '<td></td>'
                    idx += 1
                if label_html_row != '':
                    label_html += '<tr>%s</tr>' % label_html_row
                # pane col = image row
                self.vis.images(images, nrow=ncols, win=self.display_id + 1,
                                padding=2, opts=dict(title=title + ' images'))
                label_html = '<table>%s</table>' % label_html
                self.vis.text(table_css + label_html, win = self.display_id + 2,
                              opts=dict(title=title + ' labels'))
            else:
                idx = 1
                for label, image_numpy in visuals.items():
                    #image_numpy = np.flipud(image_numpy)
                    self.vis.image(image_numpy.transpose([2,0,1]), opts=dict(title=label),
                                       win=self.display_id + idx)
                    idx += 1

        for key, value in text.items():
            self.vis.text(value, win=key, opts=dict(title=key))

        if self.use_html: # save images to a html file
            for label, image_numpy in visuals.items():
                img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
                util.save_image(image_numpy, img_path)
            # update website
            webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, reflesh=1)
            for n in range(epoch, 0, -1):
                webpage.add_header('epoch [%d]' % n)
                ims = []
                txts = []
                links = []

                for label, image_numpy in visuals.items():
                    img_path = 'epoch%.3d_%s.png' % (n, label)
                    ims.append(img_path)
                    txts.append(label)
                    links.append(img_path)
                webpage.add_images(ims, txts, links, width=self.win_size)
            webpage.save()

    # errors: dictionary of error labels and values
    def plot_current_errors(self, iters, errors):
        if not hasattr(self, 'plot_data'):
            self.plot_data = {'X': [], 'Y': [], 'legend': list(errors.keys())}
        self.plot_data['X'].append(iters)
        self.plot_data['Y'].append([errors[k] for k in self.plot_data['legend']])
        self.vis.line(
            X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1),
            Y=np.array(self.plot_data['Y']),
            opts={'title': self.name + ' loss over time',
                  'legend': self.plot_data['legend'],
                  'xlabel': 'iterations',
                  'ylabel': 'loss'},
            win=self.display_id)

    def plot_current_score(self, epoch, counter_ratio, scores):
        if not hasattr(self, 'plot_score'):
            self.plot_score = {'X':[],'Y':[], 'legend':list(scores.keys())}
        self.plot_score['X'].append(epoch + counter_ratio)
        self.plot_score['Y'].append([scores[k] for k in self.plot_score['legend']])
        self.vis.line(
            X=np.stack([np.array(self.plot_score['X'])] * len(self.plot_score['legend']), 1),
            Y=np.array(self.plot_score['Y']),
            opts={
                'title': self.name + ' Inception Score over time',
                'legend': self.plot_score['legend'],
                'xlabel': 'epoch',
                'ylabel': 'score'},
            win=self.display_id + 29
        )

    # statistics distribution: draw data histogram
    def plot_current_distribution(self, distribution):
        name = list(distribution.keys())
        value = np.array(list(distribution.values())).swapaxes(1, 0)
        self.vis.boxplot(
            X=value,
            opts=dict(legend=name),
            win=self.display_id+30
        )

    # errors: same format as |errors| of plotCurrentErrors
    def print_current_errors(self, epoch, i, errors, t):
        message = '(epoch: %d, iters: %d, time: %.3f) ' % (epoch, i, t)
        for k, v in errors.items():
            message += '%s: %.3f ' % (k, v)

        print(message)
        with open(self.log_name, "a") as log_file:
            log_file.write('%s\n' % message)

    # save image to the disk
    def save_images(self, webpage, visuals, image_path):
        image_dir = webpage.get_image_dir()
        short_path = ntpath.basename(image_path[0])
        name = os.path.splitext(short_path)[0]

        webpage.add_header(name)
        ims = []
        txts = []
        links = []

        for label, image_numpy in visuals.items():
            image_name = '%s_%s.png' % (name, label)
            save_path = os.path.join(image_dir, image_name)
            util.save_image(image_numpy, save_path)

            ims.append(image_name)
            txts.append(label)
            links.append(image_name)
        webpage.add_images(ims, txts, links, width=self.win_size)

In [20]:
%tb
import time
s = '--name tda_bird  --gpu_ids 0 --model tdanet --mask_type 0 1 2 3 --img_file ./datasets/CUB_200_2011/train.flist --mask_file ./datasets/CUB_200_2011/train_mask.flist --text_config config.bird.yml'
opt = TrainOptions().parse()
# create a dataset
dataset = dataloader(opt)
dataset_size = len(dataset) * opt.batchSize
print('training images = %d' % dataset_size)
# create a model
model = create_model(opt)
# create a visualizer
visualizer = Visualizer(opt)
# training flag
keep_training = True
max_iteration = opt.niter+opt.niter_decay
epoch = 0
total_iteration = opt.iter_count

# training process
while(keep_training):
    epoch_start_time = time.time()
    epoch+=1
    print('\n Training epoch: %d' % epoch)

    for i, data in enumerate(dataset):
        dataset.epoch = epoch - 1
        iter_start_time = time.time()
        total_iteration += 1
        model.set_input(data)
        model.optimize_parameters()

        # display images on visdom and save images
        if total_iteration % opt.display_freq == 0:
            visualizer.display_current_results(model.get_current_visuals(), model.get_current_text(), epoch)
            visualizer.plot_current_distribution(model.get_current_dis())

        # print training loss and save logging information to the disk
        if total_iteration % opt.print_freq == 0:
            losses = model.get_current_errors()
            t = (time.time() - iter_start_time) / opt.batchSize
            visualizer.print_current_errors(epoch, total_iteration, losses, t)
            if opt.display_id > 0:
                visualizer.plot_current_errors(total_iteration, losses)

        # save the latest model every <save_latest_freq> iterations to the disk
        if total_iteration % opt.save_latest_freq == 0:
            print('saving the latest model (epoch %d, total_steps %d)' % (epoch, total_iteration))
            model.save_networks('latest')

        # save the model every <save_iter_freq> iterations to the disk
        if total_iteration % opt.save_iters_freq == 0:
            print('saving the model of iterations %d' % total_iteration)
            model.save_networks(total_iteration)

        if total_iteration > max_iteration:
            keep_training = False
            break

    model.update_learning_rate()

    print('\nEnd training')

AssertionError: /data/dataset/train is not a valid directory

--------------Options--------------
batchSize: 10
checkpoints_dir: ./checkpoints
continue_train: False
detach_embedding: False
display_freq: 100
display_id: 1
display_port: 8097
display_single_pane_ncols: 0
display_winsize: 256
dynamic_sigma: False
fineSize: [256, 256]
gan_mode: lsgan
gpu_ids: -1
img_file: /data/dataset/train
isTrain: True
iter_count: 1
lambda_gan: 1.0
lambda_gen_l1: 20.0
lambda_kl: 20.0
lambda_match: 0.1
lambda_rec_l1: 20.0
loadSize: [266, 266]
lr: 0.0001
lr_policy: lambda
mask_file: none
mask_type: [1, 2, 3]
model: tdanet
nThreads: 8
name: experiment_name
niter: 5000000
niter_decay: 0
no_augment: False
no_flip: False
no_html: False
no_maxpooling: False
no_rotation: False
no_shuffle: False
output_scale: 4
print_freq: 100
prior_alpha: 0.8
prior_beta: 8
resize_or_crop: resize_and_crop
save_iters_freq: 10000
save_latest_freq: 1000
text_config: config.bird.yml
train_paths: two
update_language: False
valid_file: /data/dataset/valid
which_iter: latest
----------------End---

AssertionError: /data/dataset/train is not a valid directory

In [43]:
print('mo')

mo
