## Multimodal Autoencoder

In [1]:
import time
import torch
import torch.nn as nn
import scipy.sparse as sp
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

In [None]:
# Select the compute device once: use CUDA when PyTorch reports it as
# available, otherwise fall back to the CPU.
cuda = torch.cuda.is_available()
device = torch.device("cuda") if cuda else torch.device("cpu")

In [None]:
class MultimodalAutoencoder(nn.Module):
    """
    Autoencoder for multi-modal data fusion.

    The audio and video feature vectors are concatenated and encoded into a
    shared hidden representation. That representation feeds three heads: an
    audio decoder, a video decoder and a classifier, so that the hidden
    representation correlates the audio and video features with the binary
    classification.

    Class attributes:
        ENCODER_HIDDEN: output widths of the encoder's linear layers; the
            last entry is the width of the shared hidden representation.
        DECODER_A_HIDDEN / DECODER_V_HIDDEN: layer widths of the audio/video
            decoders. The first entry is used as the decoder's *input* width,
            so it must equal ``ENCODER_HIDDEN[-1]``.
        CLASSIFIER: output widths of the classifier's linear layers.

    NOTE(review): ``get_layers`` appends the activation after every linear
    layer, including the last layer of each head, so all three outputs pass
    through ReLU6 (range [0, 6]). Confirm this is intended for the
    reconstruction targets and for the loss applied to the classifier output.
    """
    ENCODER_HIDDEN = [2048, 1024]
    DECODER_A_HIDDEN = [1024]
    DECODER_V_HIDDEN = [1024]
    CLASSIFIER = [512, 64, 2]

    def __init__(self, a_in_shape, v_in_shape, p_dropout):
        """
        Build the encoder and the three heads.

        Args:
            a_in_shape: shape of the audio input; only index 1 is read and
                used as the audio feature width (presumably
                ``(n_samples, n_features)`` -- confirm against the caller).
            v_in_shape: shape of the video input; index 1 is used as the
                video feature width.
            p_dropout: dropout probability applied to the encoder input.
        """
        super().__init__()

        self.a_in = a_in_shape[1]
        self.v_in = v_in_shape[1]
        self.activation = nn.ReLU6()
        self.p_dropout = p_dropout

        # Class-level constants must be reached through ``self``; bare names
        # are not visible from inside a method. ``+`` builds new lists, so the
        # class attributes themselves are never mutated.
        encoder_layer_sizes = [self.a_in + self.v_in] + self.ENCODER_HIDDEN
        decoder_a_layer_sizes = self.DECODER_A_HIDDEN + [self.a_in]
        decoder_v_layer_sizes = self.DECODER_V_HIDDEN + [self.v_in]
        # The classifier consumes the encoder output, so its first linear
        # layer has to accept ENCODER_HIDDEN[-1] features.
        classifier_layer_sizes = [self.ENCODER_HIDDEN[-1]] + self.CLASSIFIER

        self.encoder = nn.Sequential(*self.get_layers(encoder_layer_sizes, dropout=True))
        self.decoder_a = nn.Sequential(*self.get_layers(decoder_a_layer_sizes))
        self.decoder_v = nn.Sequential(*self.get_layers(decoder_v_layer_sizes))
        self.classifier = nn.Sequential(*self.get_layers(classifier_layer_sizes))

    def get_layers(self, layer_sizes, dropout=False):
        """
        Return the list of modules for a stack of fully connected layers.

        Args:
            layer_sizes: consecutive layer widths; each adjacent pair
                ``(n_in, n_out)`` becomes one ``nn.Linear`` followed by the
                shared activation.
            dropout: when true, the stack starts with ``nn.Dropout`` using
                ``self.p_dropout``.

        Returns:
            A list of modules, ready to be unpacked into ``nn.Sequential``.
        """
        layers = []
        if dropout:
            layers.append(nn.Dropout(self.p_dropout))
        for n_in, n_out in zip(layer_sizes, layer_sizes[1:]):
            layers.append(nn.Linear(n_in, n_out))
            layers.append(self.activation)
        return layers

    def forward(self, a_in, v_in):
        """
        Encode the concatenated inputs and run all three heads.

        Args:
            a_in: audio features; concatenated with ``v_in`` along dim 1.
            v_in: video features.

        Returns:
            Tuple ``(a_out, v_out, binary_out)``: the audio reconstruction,
            the video reconstruction and the classifier output.
        """
        x = torch.cat((a_in, v_in), 1)
        x = self.encoder(x)
        a_out = self.decoder_a(x)
        v_out = self.decoder_v(x)
        binary_out = self.classifier(x)
        return a_out, v_out, binary_out

In [None]:
def loss(model_output, target, zeta):
    """
    Weighted sum of the reconstruction losses and the classification loss.

    Computes ``(mse_audio + mse_video) * (1 - zeta) + bce * zeta``.

    Args:
        model_output: tuple ``(audio, video, binary)`` produced by the model.
            The binary component is passed to binary cross entropy as the
            prediction, which requires values in [0, 1].
        target: tuple ``(audio, video, binary)`` of targets; each element
            must have the same shape as its counterpart in ``model_output``.
        zeta: weight of the classification term; the reconstruction terms
            are weighted by ``1 - zeta``.

    Returns:
        The combined loss as a scalar tensor.

    Raises:
        AssertionError: if any prediction/target pair differs in shape.
    """
    a_pred, v_pred, binary_pred = model_output
    a_true, v_true, binary_true = target

    assert a_pred.shape == a_true.shape, "audio output/target shape mismatch"
    assert v_pred.shape == v_true.shape, "video output/target shape mismatch"
    assert binary_pred.shape == binary_true.shape, "binary output/target shape mismatch"

    # Use the functional forms: nn.MSELoss(x, y) / nn.BCELoss(x, y) would pass
    # the tensors to the module constructors instead of computing a loss.
    a_loss = nn.functional.mse_loss(a_pred, a_true)
    v_loss = nn.functional.mse_loss(v_pred, v_true)
    binary_loss = nn.functional.binary_cross_entropy(binary_pred, binary_true)

    return (a_loss + v_loss) * (1 - zeta) + binary_loss * zeta