In [None]:
import torch
import torch.nn as nn
from torchvision import models
import torch.utils.model_zoo as model_zoo
import torch.nn.functional as F
from typing import Any

from .conv_utils import conv2d

# Misc Functions

In [None]:
def conv1x1(in_planes, out_planes, bias=False):
    """
    1x1 convolution (stride 1, no padding); preserves the spatial size.

    :param int in_planes: Number of input channels
    :param int out_planes: Number of output channels
    :param bool bias: Whether the convolution has a learnable bias (default: False)
    :return: Configured convolution layer
    :rtype: nn.Conv2d
    """
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1,
                     padding=0, bias=bias)

def conv3x3(in_planes, out_planes):
    """
    Bias-free 3x3 convolution, stride 1, padding 1 (keeps the spatial size).

    :param int in_planes: Number of input channels
    :param int out_planes: Number of output channels
    :rtype: nn.Conv2d
    """
    settings = dict(kernel_size=3, stride=1, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **settings)


def conv2d(
    in_channels: int,
    out_channels: int,
    kernel_size: int = 3,
    stride: int = 1,
    padding: int = 1,
) -> nn.Conv2d:
    """
    Build the project's standard 2d convolution (with bias).

    The defaults (3x3 kernel, stride 1, padding 1) keep the spatial size.

    :param int in_channels: Number of input channels
    :param int out_channels: Number of output channels
    :param int kernel_size: Side length of the square kernel
    :param int stride: Step of the kernel while sliding
    :param int padding: Zero-padding added to each spatial border
    :return: Convolution layer with the given parameters
    :rtype: nn.Conv2d
    """
    layer = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
    return layer

def conv1d(
    in_channels: int,
    out_channels: int,
    kernel_size: int = 1,
    stride: int = 1,
    padding: int = 0,
) -> nn.Conv1d:
    """
    Template 1d convolution which is typically used throughout the project.

    With the defaults (kernel 1, stride 1, no padding) it is a per-position
    linear projection of the channel dimension.

    :param int in_channels: Number of input channels
    :param int out_channels: Number of output channels
    :param int kernel_size: Length of the sliding kernel
    :param int stride: Step of the kernel while sliding
    :param int padding: Zero-padding added to both ends of the sequence
    :return: Convolution layer with parameters
    :rtype: nn.Conv1d
    """
    return nn.Conv1d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
    )


class GLU(nn.Module):
    """
    Gated linear unit over dim 1: the first half of the channels is multiplied
    by the sigmoid of the second half, halving the channel count.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        channels = x.size(1)
        assert channels % 2 == 0, 'channels dont divide 2!'
        half = channels // 2
        value, gate = x[:, :half], x[:, half:]
        return value * torch.sigmoid(gate)

## ACM

In [None]:
class ACM(nn.Module):
    """
    Affine Combination Module from ManiGAN.

    Image features predict a per-position scale and shift (each with
    ``text_chans`` channels) that are applied to the text-side features:
    ``text * scale + shift``.
    """

    def __init__(self, text_chans: int, img_chans: int, inner_dim: int = 64) -> None:
        """
        Build the image encoder conv and the scale/shift heads.

        :param int text_chans: Channels of the features being modulated
        :param int img_chans: Channels of the image input
        :param int inner_dim: Channels of the intermediate image features
        """
        super().__init__()
        # shared image encoder followed by two heads
        self.conv = conv2d(in_channels=img_chans, out_channels=inner_dim)
        self.weights = conv2d(in_channels=inner_dim, out_channels=text_chans)
        self.biases = conv2d(in_channels=inner_dim, out_channels=text_chans)

    def forward(self, text: torch.Tensor, img: torch.Tensor) -> Any:
        """
        Modulate ``text`` with scale and shift maps predicted from ``img``.

        :param torch.Tensor text: Features to modulate (``text_chans`` channels)
        :param torch.Tensor img: Image input (``img_chans`` channels); its spatial
            size must match ``text`` for the element-wise combination
        :return: Affine combination of text and image
        :rtype: torch.Tensor
        """
        hidden = self.conv(img)
        scale = self.weights(hidden)
        shift = self.biases(hidden)
        return text * scale + shift


# Spatial Attention

In [None]:
class SpatialAttention(nn.Module):
    """
    Spatial attention module for attending textual context to visual features.

    Word features are projected into the visual channel space with a 1x1
    ``Conv1d``; every visual location then attends over the words.
    """

    def __init__(self, d: int, d_hat: int) -> None:
        """
        Set up softmax and conv layers

        :param int d: Channels of the textual (word) features. D from paper
        :param int d_hat: Channels of the visual features that the word
            features are projected into. D_hat from paper
        """
        super().__init__()
        # normalise scores over the word axis (dim 2 of the (batch, N, T) score tensor)
        self.softmax = nn.Softmax(2)
        # 1x1 projection of word features: d -> d_hat channels
        self.conv = conv1d(d, d_hat)

    def forward(self, text_context: torch.Tensor, image: torch.Tensor) -> Any:
        """
        Project textual features into the visual feature space and let every
        visual location attend over the words.

        :param torch.Tensor text_context: (batch, D, T) word features
        :param torch.Tensor image: (batch, D_hat, N) visual features, N = number
            of (flattened) spatial locations
        :return: (batch, D_hat, N) word context for every visual location
        :rtype: torch.Tensor
        """
        text_context = self.conv(text_context)                 # (batch, D_hat, T)
        image = torch.transpose(image, 1, 2)                   # (batch, N, D_hat)
        s_i_j = image @ text_context                           # (batch, N, T) similarity scores
        b_i_j = self.softmax(s_i_j)                            # attention weights over words
        c_i_j = b_i_j @ torch.transpose(text_context, 1, 2)    # (batch, N, D_hat)
        return torch.transpose(c_i_j, 1, 2)


# Channel Attention

In [None]:
class ChannelWiseAttention(nn.Module):
    """
    Channel-wise attention adapted from ControlGAN.

    Word features are projected to ``fm_size`` channels (one per flattened
    spatial position of the visual feature map); each visual channel then
    attends over the words.
    """

    def __init__(self, fm_size: int, text_d: int) -> None:
        """
        Initialize the Channel-Wise attention module

        :param int fm_size: Flattened size of the visual feature map on the
            k-th iteration, H_k * W_k in the paper
        :param int text_d: Dimensionality of the word features, D in the paper
        """
        super().__init__()
        # 1x1 projection of word features: text_d -> fm_size channels
        self.text_conv = conv1d(text_d, fm_size)
        # normalise the (batch, C, T) scores over the word axis
        self.softmax = nn.Softmax(2)

    def forward(self, v_k: torch.Tensor, w_text: torch.Tensor) -> Any:
        """
        Apply attention to visual features taking into account features of words

        :param torch.Tensor v_k: Visual features, expected (batch, C, fm_size)
        :param torch.Tensor w_text: Word features, expected (batch, text_d, T)
        :return: Word context per visual channel, (batch, C, fm_size)
        :rtype: Any
        """
        projected = self.text_conv(w_text)                       # (batch, fm_size, T)
        weights = self.softmax(torch.matmul(v_k, projected))     # (batch, C, T)
        return torch.matmul(weights, projected.transpose(1, 2))  # (batch, C, fm_size)

# Image Encoder

In [None]:
class CNN_ENCODER(nn.Module):
    """
    Image encoder built on a pretrained Inception-v3 backbone.

    The backbone parameters are frozen (``requires_grad = False``); only the
    two projection layers are trainable. Returns region features taken after
    ``Mixed_6e`` (projected by a 1x1 conv) and a global image code taken from
    the pooled ``Mixed_7c`` output (projected by a linear layer).
    """

    # Inception-v3 sub-modules reused as-is, in forward order.
    _BACKBONE_LAYERS = (
        "Conv2d_1a_3x3", "Conv2d_2a_3x3", "Conv2d_2b_3x3",
        "Conv2d_3b_1x1", "Conv2d_4a_3x3",
        "Mixed_5b", "Mixed_5c", "Mixed_5d",
        "Mixed_6a", "Mixed_6b", "Mixed_6c", "Mixed_6d", "Mixed_6e",
        "Mixed_7a", "Mixed_7b", "Mixed_7c",
    )

    def __init__(self, nef, train):
        """
        Download the pretrained backbone, freeze it and build the projections.

        :param int nef: Output dimensionality of both projections (used only when ``train`` is truthy)
        :param bool train: When falsy, ``nef`` is ignored and fixed to 256
        """
        super(CNN_ENCODER, self).__init__()
        if train:
            self.nef = nef
        else:
            self.nef = 256  # define a uniform ranker, this is TEXT.embedding_dimension

        # NOTE: the initial weights are immediately replaced by the downloaded state dict
        model = models.inception_v3(init_weights = True)
        url = 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth'
        model.load_state_dict(model_zoo.load_url(url))
        for param in model.parameters():
            param.requires_grad = False
        print('Load pretrained model from ', url)

        self.define_module(model)
        self.init_trainable_weights()

    def define_module(self, model):
        """Borrow the backbone blocks from ``model`` and create the trainable projections."""
        # same attribute names and registration order as the explicit assignments
        for name in self._BACKBONE_LAYERS:
            setattr(self, name, getattr(model, name))

        # region features: Mixed_6e outputs 768 channels
        self.emb_features = conv1x1(768, self.nef)
        # global code: pooled Mixed_7c output has 2048 channels
        self.emb_cnn_code = nn.Linear(2048, self.nef)

    def init_trainable_weights(self):
        """Uniformly initialise the projection weights in [-0.1, 0.1] (biases keep their defaults)."""
        initrange = 0.1
        self.emb_features.weight.data.uniform_(-initrange, initrange)
        self.emb_cnn_code.weight.data.uniform_(-initrange, initrange)

    def forward(self, x):
        """
        Encode a batch of images.

        :param torch.Tensor x: Images, (batch, 3, H, W); resized to 299 x 299 internally
        :return: ``(features, cnn_code)`` -- projected region features
            (batch, nef, 17, 17) and global image code (batch, nef)
        """
        # Shape comments are C x H x W.
        # The backbone expects a fixed-size 299 x 299 input.
        x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=False)  # 3 x 299 x 299
        x = self.Conv2d_1a_3x3(x)                          # 32 x 149 x 149
        x = self.Conv2d_2a_3x3(x)                          # 32 x 147 x 147
        x = self.Conv2d_2b_3x3(x)                          # 64 x 147 x 147
        x = F.max_pool2d(x, kernel_size=3, stride=2)       # 64 x 73 x 73
        x = self.Conv2d_3b_1x1(x)                          # 80 x 73 x 73
        x = self.Conv2d_4a_3x3(x)                          # 192 x 71 x 71
        x = F.max_pool2d(x, kernel_size=3, stride=2)       # 192 x 35 x 35
        x = self.Mixed_5b(x)                               # 256 x 35 x 35
        x = self.Mixed_5c(x)                               # 288 x 35 x 35
        x = self.Mixed_5d(x)                               # 288 x 35 x 35
        x = self.Mixed_6a(x)                               # 768 x 17 x 17
        x = self.Mixed_6b(x)                               # 768 x 17 x 17
        x = self.Mixed_6c(x)                               # 768 x 17 x 17
        x = self.Mixed_6d(x)                               # 768 x 17 x 17
        x = self.Mixed_6e(x)                               # 768 x 17 x 17

        # image region features
        features = x                                       # 768 x 17 x 17

        x = self.Mixed_7a(x)                               # 1280 x 8 x 8
        x = self.Mixed_7b(x)                               # 2048 x 8 x 8
        x = self.Mixed_7c(x)                               # 2048 x 8 x 8
        x = F.avg_pool2d(x, kernel_size=8)                 # 2048 x 1 x 1
        x = torch.flatten(x, 1)                            # 2048

        # global image features
        cnn_code = self.emb_cnn_code(x)                    # nef
        features = self.emb_features(features)             # nef x 17 x 17
        return features, cnn_code


In [None]:
# Smoke test: encode a random batch of 256 x 256 images and check the output shapes.
img_encoder = CNN_ENCODER(256, True)
images = torch.randn(10, 3, 256, 256)
with torch.no_grad():
    features, cnn_code = img_encoder(images)
assert features.shape == (10, 256, 17, 17), features.shape
assert cnn_code.shape == (10, 256), cnn_code.shape
print(features.shape)
print(cnn_code.shape)

# Upsample / Downsample

In [None]:
def upBlock(in_planes, out_planes):
    """
    Upsample the spatial size by a factor of 2 (nearest), then conv3x3,
    InstanceNorm and GLU.

    The convolution produces ``2 * out_planes`` channels, which GLU halves back
    to ``out_planes``.
    """
    doubled = out_planes * 2
    layers = [
        nn.Upsample(scale_factor=2, mode='nearest'),
        conv3x3(in_planes, doubled),
        nn.InstanceNorm2d(doubled),
        GLU(),
    ]
    return nn.Sequential(*layers)


def imgUpBlock(in_planes, out_planes):
    """
    Like ``upBlock`` but upsamples by a factor of 1.9 instead of 2.

    With nearest upsampling the output side is floor(side * 1.9), e.g. 17 -> 32.
    Presumably this maps the encoder's 17 x 17 region features onto 32 x 32
    generator maps -- TODO confirm.
    """
    doubled = out_planes * 2
    layers = [
        nn.Upsample(scale_factor=1.9, mode='nearest'),
        conv3x3(in_planes, doubled),
        nn.InstanceNorm2d(doubled),
        GLU(),
    ]
    return nn.Sequential(*layers)

def downBlock(in_planes, out_planes):
    """
    Halve the spatial size (for even inputs) with a bias-free strided 4x4 conv,
    followed by BatchNorm and LeakyReLU(0.2).
    """
    conv = nn.Conv2d(in_planes, out_planes, kernel_size=4, stride=2, padding=1, bias=False)
    norm = nn.BatchNorm2d(out_planes)
    act = nn.LeakyReLU(negative_slope=0.2, inplace=True)
    return nn.Sequential(conv, norm, act)


# Residual

In [None]:
class ResBlock(nn.Module):
    """
    Residual block with an identity skip:
    conv3x3 -> InstanceNorm -> GLU -> conv3x3 -> InstanceNorm, added to the input.
    Input and output both have ``channel_num`` channels.
    """

    def __init__(self, channel_num):
        super().__init__()
        doubled = channel_num * 2
        self.block = nn.Sequential(
            conv3x3(channel_num, doubled),
            nn.InstanceNorm2d(doubled),
            GLU(),  # halves doubled -> channel_num
            conv3x3(channel_num, channel_num),
            nn.InstanceNorm2d(channel_num),
        )

    def forward(self, x):
        return self.block(x) + x

In [None]:
# Sanity check: a residual block must preserve the input shape.
residual_block = ResBlock(3)
x = torch.randn(10, 3, 256, 256)
with torch.no_grad():
    out = residual_block(x)
assert out.shape == x.shape, out.shape
print(out.shape)

# Generator

In [None]:
class CA_NET(nn.Module):
    """
    Conditioning augmentation: maps a sentence embedding to a stochastic
    condition code ``mu + eps * exp(0.5 * logvar)``.
    """
    # some code is modified from vae examples
    # (https://github.com/pytorch/examples/blob/master/vae/main.py)
    def __init__(self):
        super(CA_NET, self).__init__()
        # NOTE(review): relies on a global `cfg` that is not defined/imported in this notebook
        self.t_dim = cfg.TEXT.EMBEDDING_DIM  # e.g. 256
        self.c_dim = cfg.GAN.CONDITION_DIM   # e.g. 100
        self.fc = nn.Linear(self.t_dim, self.c_dim * 4, bias=True)
        # GLU halves the 4 * c_dim outputs to 2 * c_dim, split into [mu | logvar]
        self.relu = GLU()

    def encode(self, text_embedding):
        """Return ``(mu, logvar)``, each (batch, c_dim), for a (batch, t_dim) embedding."""
        x = self.relu(self.fc(text_embedding))
        mu = x[:, :self.c_dim]
        logvar = x[:, self.c_dim:]
        return mu, logvar

    def reparametrize(self, mu, logvar, device=None):
        """
        Sample ``mu + eps * std`` with ``eps ~ N(0, 1)``.

        :param device: Device for the noise; ``None`` (default) uses the device
            of ``logvar``. ``forward`` calls this without a device.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std, device=device)
        return mu + eps * std

    def forward(self, text_embedding):
        """Return ``(c_code, mu, logvar)`` for a batch of sentence embeddings."""
        mu, logvar = self.encode(text_embedding)
        c_code = self.reparametrize(mu, logvar)
        return c_code, mu, logvar


class INIT_STAGE_G(nn.Module):
    """
    Initial generator stage.

    Maps the concatenation [c_code, z_code, cnn_code] through a linear layer to
    a 4 x 4 map with ``ngf`` channels, upsamples it three times to 32 x 32
    (``ngf // 8`` channels), fuses word-level spatial and channel-wise attention
    (giving ``ngf // 8 * 3`` channels), modulates the result with image features
    via ACM, applies residual blocks and upsamples once more to 64 x 64 with
    ``ngf // 16`` channels.

    NOTE(review): relies on a global `cfg` that is not defined/imported in this notebook.
    """
    def __init__(self, ngf, ncf, nef):
        """
        :param int ngf: Channels of the initial 4 x 4 map (G_NET passes cfg.GAN.GF_DIM * 16)
        :param int ncf: Dimensionality of the condition code c_code
        :param int nef: Dimensionality of the word embeddings
        """
        super(INIT_STAGE_G, self).__init__()
        self.gf_dim = ngf # e.g. 512, as G_NET passes ngf * 16 to this class.
        self.in_dim = cfg.GAN.Z_DIM + ncf + cfg.TEXT.EMBEDDING_DIM # noise + condition code + cnn_code dims (e.g. 100 + 100 + 256)
        self.ef_dim = nef # e.g. 256, i.e. text.embedding_dim

        self.define_module()

    def define_module(self):
        """Create the FC head, the upsampling blocks, attention modules, ACM and residual blocks."""
        nz, ngf = self.in_dim, self.gf_dim
        # GLU halves ngf * 4 * 4 * 2 -> ngf * 4 * 4, reshaped to (ngf, 4, 4) in forward
        self.fc = nn.Sequential(
            nn.Linear(nz, ngf * 4 * 4 * 2, bias=False),
            nn.BatchNorm1d(ngf * 4 * 4 * 2),
            GLU())

        self.upsample1 = upBlock(ngf, ngf // 2)          # 4 -> 8
        self.upsample2 = upBlock(ngf // 2, ngf // 4)     # 8 -> 16
        self.upsample3 = upBlock(ngf // 4, ngf // 8)     # 16 -> 32; e.g. 512 // 4 = 128 -> 512 // 8 = 64
        self.upsample4 = upBlock(ngf // 8 * 3, ngf // 16)  # 32 -> 64, after the 3-way concat

        # 3 * (ngf // 8) = [out_code32 | spatial-attention context | channel-attention context]
        self.residual = self._make_layer(ResBlock, ngf // 8 * 3)
        self.ACM = ACM(ngf // 8 * 3, img_chans = cfg.GAN.GF_DIM) 

        self.att = SpatialAttention(self.ef_dim, ngf // 8) # words (ef_dim) projected to ngf // 8 channels, e.g. 256 -> 64
        self.channel_att = ChannelWiseAttention(32*32, self.ef_dim) # fm_size = 32 * 32, the out_code32 spatial size

    def _make_layer(self, block, channel_num):
        """Stack cfg.GAN.R_NUM instances of ``block(channel_num)`` in an nn.Sequential."""
        layers = []
        for i in range(cfg.GAN.R_NUM): # e.g. R_NUM = 2
            layers.append(block(channel_num))
        return nn.Sequential(*layers)

    def forward(self, z_code, c_code, cnn_code, imgs, mask, word_embs):
        """
        :param z_code: noise, batch x cfg.GAN.Z_DIM
        :param c_code: condition code from CA_NET, batch x ncf
        :param cnn_code: global image code; when not training and not validating it is
            repeated along dim 0 (presumably a single row -- confirm)
        :param imgs: image features for ACM with cfg.GAN.GF_DIM channels; spatial size
            must match out_code32 (32 x 32)
        :param mask: word mask; currently unused (masking is commented out below)
        :param word_embs: word features, batch x ef_dim x seq_len
        :return: hidden features, batch x (ngf // 16) x 64 x 64
        """

        c_z_code = torch.cat((c_code, z_code), 1)

        # for testing
        if not cfg.TRAIN.FLAG and not cfg.B_VALIDATION:
            cnn_code = cnn_code.repeat(c_z_code.size(0), 1)

        c_z_cnn_code = torch.cat((c_z_code, cnn_code), 1)

        out_code = self.fc(c_z_cnn_code)
        out_code = out_code.view(-1, self.gf_dim, 4, 4)
        out_code = self.upsample1(out_code)
        out_code = self.upsample2(out_code)
        out_code32 = self.upsample3(out_code) # (batch, ngf // 8, 32, 32)
        # flatten spatial dims: (batch, ngf // 8, 32 * 32), the layout SpatialAttention expects
        out_code32_comb = out_code32.view(out_code32.shape[0], -1, out_code32.shape[2] * out_code32.shape[3])

        # self.att.applyMask(mask)  -- NOTE: SpatialAttention has no applyMask; mask is ignored
        c_code = self.att(word_embs, out_code32_comb) # spatial context: (batch, ngf // 8, 32 * 32)
        c_code_channel = self.channel_att(c_code, word_embs) # channel context: (batch, ngf // 8, 32 * 32)
        c_code = c_code.view(word_embs.size(0), -1, out_code32.size(2), out_code32.size(3)) # (batch, channel, H, W)
        h_c_code = torch.cat((out_code32, c_code), 1)
        c_code_channel = c_code_channel.view(word_embs.size(0), -1, out_code32.size(2), out_code32.size(3)) # (batch, channel, H, W)
        h_c_c_code = torch.cat((h_c_code, c_code_channel), 1) # (batch, 3 * (ngf // 8), 32, 32)


        out_imgs_code32 = self.ACM(h_c_c_code, imgs)
        out_imgs_code32 = self.residual(out_imgs_code32)
        out_code64 = self.upsample4(out_imgs_code32)
        return out_code64

class NEXT_STAGE_G(nn.Module):
    """
    Refinement generator stage.

    Fuses the previous stage's hidden features with word-level spatial and
    channel-wise attention (3 * ngf channels), modulates the result with image
    features via ACM, applies residual blocks and upsamples twice, returning
    ``ngf`` channels at 4x the input spatial size.

    NOTE(review): relies on a global `cfg` that is not defined/imported in this notebook.
    """

    def __init__(self, ngf, nef, ncf, size):
        """
        :param int ngf: Channels of the incoming hidden features (G_NET passes cfg.GAN.GF_DIM)
        :param int nef: Dimensionality of the word embeddings
        :param int ncf: Dimensionality of the condition code (stored only)
        :param int size: Height/width of the incoming hidden feature map (G_NET passes 64)
        """
        super(NEXT_STAGE_G, self).__init__()
        self.gf_dim = ngf
        self.ef_dim = nef
        self.cf_dim = ncf
        self.num_residual = cfg.GAN.R_NUM
        self.size = size
        self.define_module()

    def _make_layer(self, block, channel_num):
        """Stack cfg.GAN.R_NUM instances of ``block(channel_num)`` in an nn.Sequential."""
        layers = []
        for i in range(cfg.GAN.R_NUM):
            layers.append(block(channel_num))
        return nn.Sequential(*layers)

    def define_module(self):
        """Create the attention modules, residual blocks, ACM and upsampling blocks."""
        ngf = self.gf_dim
        # SpatialAttention(d, d_hat): words have ef_dim channels, visual features ngf
        self.att = SpatialAttention(self.ef_dim, ngf)
        # ChannelWiseAttention(fm_size, text_d): fm_size is the flattened size x size map
        self.channel_att = ChannelWiseAttention(self.size * self.size, self.ef_dim)
        self.residual = self._make_layer(ResBlock, ngf * 3)
        self.upsample = upBlock(ngf * 3, ngf)
        # seg_img from G_NET comes out of downBlock(nef // 2, ngf), i.e. ngf channels
        self.ACM = ACM(ngf * 3, img_chans=ngf)
        self.upsample2 = upBlock(ngf, ngf)

    def forward(self, h_code, c_code, word_embs, mask, seg_img):
        """
        :param h_code: previous-stage features, batch x ngf x H x W (H * W must equal size * size)
        :param c_code: condition code; unused by this stage
        :param word_embs: word features, batch x nef x seq_len
        :param mask: word mask; TODO(review) currently unused -- SpatialAttention
            has no masking support
        :param seg_img: image features for ACM, ngf channels, same H x W as h_code
        :return: ``(out_code, att)`` -- features batch x ngf x 4H x 4W, and ``att``
            which is always None because SpatialAttention returns no attention map
        """
        batch, _, height, width = h_code.size()
        # flatten spatial dims: (batch, ngf, H * W), the layout SpatialAttention expects
        h_code_flat = h_code.view(batch, -1, height * width)

        c_code = self.att(word_embs, h_code_flat)               # (batch, ngf, H * W)
        c_code_channel = self.channel_att(c_code, word_embs)    # (batch, ngf, H * W)
        c_code = c_code.view(batch, -1, height, width)
        c_code_channel = c_code_channel.view(batch, -1, height, width)

        h_c_c_code = torch.cat((h_code, c_code, c_code_channel), 1)  # (batch, 3 * ngf, H, W)
        h_c_c_seg_code = self.ACM(h_c_c_code, seg_img)

        out_code = self.residual(h_c_c_seg_code)
        out_code = self.upsample(out_code)
        out_code = self.upsample2(out_code)

        att = None
        return out_code, att

class GET_IMAGE_G(nn.Module):
    """Map ``ngf``-channel hidden features to a 3-channel image in [-1, 1] (conv3x3 + tanh)."""

    def __init__(self, ngf):
        super().__init__()
        self.gf_dim = ngf
        self.img = nn.Sequential(conv3x3(ngf, 3), nn.Tanh())

    def forward(self, h_code):
        return self.img(h_code)



In [None]:
class G_NET(nn.Module):
    """
    Full generator: conditioning augmentation, an initial stage producing
    64 x 64 features, and a refinement stage upsampling to 256 x 256, with image
    features injected at each stage via ACM.

    NOTE(review): relies on a global `cfg` that is not defined/imported in this notebook.
    """

    def __init__(self):
        super(G_NET, self).__init__()
        ngf = cfg.GAN.GF_DIM  # e.g. 32
        nef = cfg.TEXT.EMBEDDING_DIM  # e.g. 256
        ncf = cfg.GAN.CONDITION_DIM  # e.g. 100
        self.ca_net = CA_NET()

        if cfg.TREE.BRANCH_NUM > 0: # BRANCH_NUM for train_bird is set at 3
            self.h_net1 = INIT_STAGE_G(ngf * 16, ncf, nef)
            self.imgUpSample1 = imgUpBlock(nef, ngf)

        if cfg.TREE.BRANCH_NUM > 2:
            self.h_net3 = NEXT_STAGE_G(ngf, nef, ncf, 64)
            self.img_net = GET_IMAGE_G(ngf)
            # both inputs have ngf channels: h_code2 (h_net3 output) and img_code256 (imgUpSample4)
            self.ACM = ACM(ngf, img_chans=ngf)
            # NOTE: despite its name this is a downBlock (halves the spatial size)
            self.imgUpSample2 = downBlock(nef//2, ngf)
            self.imgUpSample3 = upBlock(ngf, ngf)
            self.imgUpSample4 = upBlock(ngf, ngf)

    def forward(self, z_code, sent_emb, word_embs, mask, cnn_code, region_features, vgg_features):
        """
            :param z_code: batch x cfg.GAN.Z_DIM
            :param sent_emb: batch x cfg.TEXT.EMBEDDING_DIM
            :param word_embs: batch x cdf x seq_len
            :param mask: batch x seq_len
            :param cnn_code: global image code, concatenated with the noise/condition code in h_net1
            :param region_features: image region features with nef channels (upsampled by imgUpSample1)
            :param vgg_features: image features with nef // 2 channels (downsampled by imgUpSample2)
            :return: (fake_imgs, att_maps, mu, logvar)
        """
        fake_imgs = []
        att_maps = []
        c_code, mu, logvar = self.ca_net(sent_emb)
        if cfg.TREE.BRANCH_NUM > 0:
            img_code32 = self.imgUpSample1(region_features)
            h_code1 = self.h_net1(z_code, c_code, cnn_code, img_code32, mask, word_embs)

        if cfg.TREE.BRANCH_NUM > 2:
            img_code64 = self.imgUpSample2(vgg_features)
            h_code2, att2 = \
                self.h_net3(h_code1, c_code, word_embs, mask, img_code64)
            img_code128 = self.imgUpSample3(img_code64)
            img_code256 = self.imgUpSample4(img_code128)
            h_code3 = self.ACM(h_code2, img_code256)
            fake_img = self.img_net(h_code3)
            fake_imgs.append(fake_img)
            if att2 is not None:
                att_maps.append(att2)

        return fake_imgs, att_maps, mu, logvar
