In [22]:
import math
import torch
from torch import nn
import torch.nn.functional as F
#from .module import Module


class dotdict(dict):
    """Dictionary whose keys can also be read, written and deleted as attributes.

    Reading an attribute that is not a key returns ``None`` (lookup goes
    through ``dict.get``) instead of raising ``AttributeError``. Deleting a
    missing attribute raises ``KeyError``.
    """

    def __getattr__(self, key):
        # Only reached when normal attribute lookup fails, so dict methods still work.
        return dict.get(self, key)

    def __setattr__(self, key, value):
        dict.__setitem__(self, key, value)

    def __delattr__(self, key):
        dict.__delitem__(self, key)


class GatedMultimodalLayer(nn.Module):
    """ Gated Multimodal Layer based on 'Gated multimodal networks, Arevalo1 et al.' (https://arxiv.org/abs/1702.01992)

    Each modality is projected to ``size_out`` features by a bias-free linear
    map followed by ``tanh``; a sigmoid gate then mixes the two hidden states::

        h1  = tanh(x1 @ W1^T)
        h2  = tanh(x2 @ W2^T)
        z   = sigmoid([h1, h2] @ w_z)
        out = z * h1 + (1 - z) * h2

    ``w_z`` (``weight_sigmoid``) is a vector, so ``z`` is one scalar per
    sample. It is computed from the hidden states ``h1``/``h2``.
    NOTE(review): confirm this scalar, hidden-state-based gate is the intended
    variant of the paper's gate.

    Args:
        size_in1: Number of features of the first modality input ``x1``.
        size_in2: Number of features of the second modality input ``x2``.
        size_out: Number of features of the fused output.
    """
    def __init__(self, size_in1, size_in2, size_out):
        super(GatedMultimodalLayer, self).__init__()
        self.size_in1, self.size_in2, self.size_out = size_in1, size_in2, size_out

        # Projection weights of each modality, stored (out, in) like nn.Linear.weight.
        self.weights_hidden1 = nn.Parameter(torch.empty(size_out, size_in1))
        self.weights_hidden2 = nn.Parameter(torch.empty(size_out, size_in2))

        # Gate weights over the concatenated hidden states [h1, h2].
        self.weight_sigmoid = nn.Parameter(torch.empty(size_out * 2))

        # Initialise ALL parameters. The gate vector used to be left as
        # uninitialised memory. kaiming_uniform_ with a=sqrt(5) gives the bound
        # 1/sqrt(fan_in); the same bound is used for the 1-D gate, whose
        # fan_in is 2*size_out. kaiming_uniform_ itself needs >= 2 dims.
        nn.init.kaiming_uniform_(self.weights_hidden1, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.weights_hidden2, a=math.sqrt(5))
        gate_bound = 1.0 / math.sqrt(size_out * 2) if size_out > 0 else 0.0
        nn.init.uniform_(self.weight_sigmoid, -gate_bound, gate_bound)

        # Activation functions
        self.tanh_f = nn.Tanh()
        self.sigmoid_f = nn.Sigmoid()

    def forward(self, x1, x2):
        """Fuse a batch of the two modalities.

        Args:
            x1: 2-D tensor of shape (batch, size_in1).
            x2: 2-D tensor of shape (batch, size_in2).

        Returns:
            Tensor of shape (batch, size_out).
        """
        h1 = self.tanh_f(torch.mm(x1, self.weights_hidden1.t()))
        h2 = self.tanh_f(torch.mm(x2, self.weights_hidden2.t()))
        x = torch.cat((h1, h2), dim=1)
        # (batch, 2*size_out) @ (2*size_out,) -> (batch,); add a feature axis to broadcast.
        z = self.sigmoid_f(torch.matmul(x, self.weight_sigmoid)).unsqueeze(1)

        return z * h1 + (1 - z) * h2
    
    
class MaxOut(nn.Module):
    """Maxout unit: element-wise maximum over ``num_units`` parallel linear layers.

    Args:
        input_dim: Number of input features of every linear piece.
        output_dim: Number of output features of every linear piece.
        num_units: Number of linear pieces to take the maximum over (>= 1).

    Raises:
        ValueError: If ``num_units`` is smaller than 1.
    """
    def __init__(self, input_dim, output_dim, num_units=2):
        super(MaxOut, self).__init__()
        if num_units < 1:
            # An empty ModuleList would otherwise only fail later, in forward.
            raise ValueError("num_units must be >= 1, got {}".format(num_units))
        self.fc1_list = nn.ModuleList([nn.Linear(input_dim, output_dim) for _ in range(num_units)])

    def forward(self, x):
        """Return the element-wise maximum of all linear pieces applied to ``x``."""
        return self.maxout(x, self.fc1_list)

    def maxout(self, x, layer_list):
        """Return the element-wise max of ``layer(x)`` over every layer in ``layer_list``."""
        max_output = layer_list[0](x)
        # The first layer is already folded into max_output, so start from the
        # second. ``enumerate(..., start=1)`` only renumbers the index and would
        # evaluate layer 0 a second time.
        for layer in layer_list[1:]:
            max_output = torch.max(max_output, layer(x))
        return max_output


class MLPGenreClassifierModel(nn.Module):
    """MLP classifier head with element-wise sigmoid outputs.

    Stages: BatchNorm -> MaxOut -> Dropout -> BatchNorm -> MaxOut -> Dropout
    -> BatchNorm -> Linear -> Sigmoid.

    ``hyp_params`` must expose ``text_embedding_size``, ``image_feature_size``,
    ``hidden_size``, ``output_dim`` and ``mlp_dropout`` as attributes.

    Width of the first stage: when ``text_embedding_size`` equals
    ``image_feature_size`` the model expects ``hidden_size`` input features.
    GMUModel relies on this by setting both sizes to ``hidden_size``.
    Otherwise it expects ``text_embedding_size + image_feature_size`` features.
    NOTE(review): standalone use with equal text/image sizes and a
    concatenated input only works if ``2 * size == hidden_size``.
    """

    def __init__(self, hyp_params):
        super(MLPGenreClassifierModel, self).__init__()
        hidden = hyp_params.hidden_size
        p_drop = hyp_params.mlp_dropout
        sizes_match = hyp_params.text_embedding_size == hyp_params.image_feature_size
        if sizes_match:
            in_width = hidden
        else:
            in_width = hyp_params.text_embedding_size + hyp_params.image_feature_size

        # Registration order is kept stable: it fixes parameter order and RNG use.
        self.bn1 = nn.BatchNorm1d(in_width)
        self.linear1 = MaxOut(in_width, hidden)
        self.drop1 = nn.Dropout(p=p_drop)

        self.bn2 = nn.BatchNorm1d(hidden)
        self.linear2 = MaxOut(hidden, hidden)
        self.drop2 = nn.Dropout(p=p_drop)

        self.bn3 = nn.BatchNorm1d(hidden)
        self.linear3 = nn.Linear(hidden, hyp_params.output_dim)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_ids, feature_images=None):
        """Score a batch.

        Args:
            input_ids: Input features; used on their own when ``feature_images`` is None.
            feature_images: Optional features concatenated to ``input_ids`` along dim 1.

        Returns:
            Sigmoid of the final linear layer (``output_dim`` columns).
        """
        if feature_images is not None:
            h = torch.cat((input_ids, feature_images), dim=1)
        else:
            h = input_ids
        stages = (
            self.bn1, self.linear1, self.drop1,
            self.bn2, self.linear2, self.drop2,
            self.bn3, self.linear3,
        )
        for stage in stages:
            h = stage(h)
        return self.sigmoid(h)
    

class ConcatenateModel(nn.Module):
    """Early-fusion baseline: concatenate text and image features, then classify.

    Stages applied to the concatenated input: Dropout -> MaxOut -> BatchNorm
    -> Dropout -> MaxOut -> BatchNorm -> Linear -> Sigmoid.

    ``hyp_params`` must expose ``text_embedding_size``, ``image_feature_size``,
    ``hidden_size``, ``output_dim`` and ``mlp_dropout`` as attributes.
    """

    def __init__(self, hyp_params):
        super(ConcatenateModel, self).__init__()
        hidden = hyp_params.hidden_size
        fused_width = hyp_params.text_embedding_size + hyp_params.image_feature_size
        p_drop = hyp_params.mlp_dropout

        # Registration order is kept stable: it fixes parameter order and RNG use.
        self.linear1 = MaxOut(fused_width, hidden)
        self.bn1 = nn.BatchNorm1d(hidden)
        self.drop1 = nn.Dropout(p=p_drop)

        self.linear2 = MaxOut(hidden, hidden)
        self.bn2 = nn.BatchNorm1d(hidden)
        self.drop2 = nn.Dropout(p=p_drop)

        self.linear3 = nn.Linear(hidden, hyp_params.output_dim)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_ids, feature_images):
        """Concatenate both inputs along dim 1 and return sigmoid scores."""
        fused = torch.cat((input_ids, feature_images), dim=1)
        pipeline = (
            self.drop1, self.linear1, self.bn1,
            self.drop2, self.linear2, self.bn2,
            self.linear3,
        )
        for step in pipeline:
            fused = step(fused)
        return self.sigmoid(fused)
    
    
class GMUModel(nn.Module):
    """Fusion model: per-modality projections, a Gated Multimodal Layer, then an MLP head.

    Each modality is batch-normalised and linearly projected to
    ``hidden_size``. The two projections are fused by ``GatedMultimodalLayer``
    (visual as first input, text as second) and scored by
    ``MLPGenreClassifierModel``.

    ``hyp_params`` must expose ``text_embedding_size``, ``image_feature_size``,
    ``hidden_size``, ``output_dim`` and ``mlp_dropout`` as attributes.
    """

    def __init__(self, hyp_params):
        super(GMUModel, self).__init__()
        width = hyp_params.hidden_size
        # Hyper-parameters for the classifier head. Both input sizes are set to
        # ``hidden_size`` so that MLPGenreClassifierModel takes ``hidden_size``
        # input features, which is the width of the GMU output.
        self.hyp_params = dotdict(
            text_embedding_size=width,
            image_feature_size=width,
            hidden_size=width,
            output_dim=hyp_params.output_dim,
            mlp_dropout=hyp_params.mlp_dropout,
        )

        # Registration order is kept stable: it fixes parameter order and RNG use.
        self.visual_mlp = torch.nn.Sequential(
            nn.BatchNorm1d(hyp_params.image_feature_size),
            nn.Linear(hyp_params.image_feature_size, width),
        )
        self.textual_mlp = torch.nn.Sequential(
            nn.BatchNorm1d(hyp_params.text_embedding_size),
            nn.Linear(hyp_params.text_embedding_size, width),
        )
        self.gmu = GatedMultimodalLayer(width, width, width)
        self.logistic_mlp = MLPGenreClassifierModel(self.hyp_params)

    def forward(self, input_ids, feature_images):
        """Project both modalities, fuse them with the GMU and score the result."""
        # Arguments evaluate left to right: the visual branch runs before the text branch.
        fused = self.gmu(self.visual_mlp(feature_images), self.textual_mlp(input_ids))
        return self.logistic_mlp(fused)


In [18]:
linear = nn.Linear(in_features=10, out_features=10, bias=True)  # 10 -> 10 scratch layer for graph rendering

In [19]:
print(repr(linear))  # nn.Module has no __str__, so this matches print(linear)

Linear(in_features=10, out_features=10, bias=True)


In [5]:
gmu = GatedMultimodalLayer(size_in1=10, size_in2=10, size_out=10)  # scratch instance: 10-d inputs, 10-d output

In [6]:
# Display the module's repr in the notebook. Only the activation submodules
# appear because the layer's weights are bare nn.Parameters, not child modules.
gmu

GatedMultimodalLayer(
  (tanh_f): Tanh()
  (sigmoid_f): Sigmoid()
)

In [11]:
from torchviz import make_dot

In [23]:
y = gmu(torch.rand((1, 10)), torch.rand((1, 10)))  # smoke test: batch of one random sample per modality

In [25]:
# NOTE(review): this rebinds ``y`` and discards the GMU output computed in an
# earlier cell, so the make_dot graph rendered next is of ``linear``, not ``gmu``.
y = linear(torch.rand((1, 10)))

In [26]:
make_dot(y).render(filename="attached", format="png")  # writes the autograd graph of ``y`` to disk

'attachedLinear.png'