In [1]:
import numpy as np
import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from collections import OrderedDict

import importlib

import dataloader
from audio_conv2d import AudioConv2d
from Densenet import densenet169
from simple_densenet import simple_densenet169

In [2]:
class MultimodalAutoencoder(nn.Module):
    """
    Autoencoder for multi-modal (audio + video) feature fusion.

    The audio and video feature vectors are concatenated and passed through a
    shared encoder. The encoded representation then feeds three heads: an
    audio decoder, a video decoder and a two-class classifier. Training on the
    combined loss pushes the hidden representation to correlate the audio and
    video features with the binary classification.

    Args:
        a_in_shape (int): width of the audio feature vector.
        v_in_shape (int): width of the video feature vector.
    """
    # Hidden layer widths of the shared encoder (input width is a_in + v_in).
    ENCODER_HIDDEN = [2048, 1024]
    # Hidden layer widths of the audio / video decoders (output widths are
    # a_in / v_in respectively).
    DECODER_A_HIDDEN = [1024]
    DECODER_V_HIDDEN = [1024]
    # Layer widths of the classifier trunk; the last entry is the width of the
    # classification embedding returned by forward().
    CLASSIFIER = [512, 64]
    # Dropout probability applied to the encoder input.
    P_DROPOUT = 0.2
    # Weight of the classification loss; (1 - ZETA) weights the reconstruction.
    ZETA = 0.8

    def __init__(self, a_in_shape, v_in_shape):
        super().__init__()

        self.a_in = a_in_shape
        self.v_in = v_in_shape
        self.activation = nn.ReLU6()
        self.p_dropout = self.P_DROPOUT
        self.zeta = self.ZETA

        encoded_width = self.ENCODER_HIDDEN[-1]
        encoder_layer_sizes = [self.a_in + self.v_in] + self.ENCODER_HIDDEN
        decoder_a_layer_sizes = [encoded_width] + self.DECODER_A_HIDDEN + [self.a_in]
        decoder_v_layer_sizes = [encoded_width] + self.DECODER_V_HIDDEN + [self.v_in]
        classifier_layer_sizes = [encoded_width] + self.CLASSIFIER

        self.encoder = nn.Sequential(*self.get_layers(encoder_layer_sizes, dropout=True))
        self.decoder_a = nn.Sequential(*self.get_layers(decoder_a_layer_sizes))
        self.decoder_v = nn.Sequential(*self.get_layers(decoder_v_layer_sizes))
        self.classifier = nn.Sequential(*self.get_layers(classifier_layer_sizes))
        self.classifier_scoring = nn.Linear(classifier_layer_sizes[-1], 2)

    def get_layers(self, layer_sizes, dropout=False):
        """
        Build the module list for a fully connected stack.

        Args:
            layer_sizes (list[int]): consecutive layer widths; each adjacent
                pair becomes one ``nn.Linear`` followed by the activation.
            dropout (bool): when True, a ``nn.Dropout(self.p_dropout)`` is
                placed in front of the first linear layer.

        Returns:
            list[nn.Module]: modules ready to be unpacked into ``nn.Sequential``.
        """
        layers = []
        if dropout:
            layers.append(nn.Dropout(self.p_dropout))
        for in_width, out_width in zip(layer_sizes[:-1], layer_sizes[1:]):
            layers.append(nn.Linear(in_width, out_width))
            # The same stateless activation instance is reused for every layer.
            # NOTE(review): this also follows the LAST linear layer, so decoder
            # outputs are clamped to ReLU6's [0, 6] range. If reconstruction
            # targets can be negative they cannot be matched exactly - confirm
            # this is intended.
            layers.append(self.activation)
        return layers

    def forward(self, a_in, v_in):
        """
        Encode the concatenated features and run all three heads.

        Args:
            a_in (Tensor): audio features of shape (batch, self.a_in).
            v_in (Tensor): video features of shape (batch, self.v_in).

        Returns:
            tuple: ``(a_out, v_out, binary_out, classification_embedding)`` -
            the audio reconstruction, the video reconstruction, the two-class
            scores and the classifier trunk output.

        Raises:
            AssertionError: if an input does not have the expected shape.
        """
        batch_size = v_in.shape[0]
        # Validate against the widths given to the constructor rather than
        # fixed numbers, so the module works for any configured feature size.
        assert v_in.size() == (batch_size, self.v_in), (
            f"expected video input of shape {(batch_size, self.v_in)}, got {tuple(v_in.size())}"
        )
        assert a_in.size() == (batch_size, self.a_in), (
            f"expected audio input of shape {(batch_size, self.a_in)}, got {tuple(a_in.size())}"
        )
        x = torch.cat((a_in, v_in), dim=1)
        x = self.encoder(x)
        a_out = self.decoder_a(x)
        v_out = self.decoder_v(x)
        classification_embedding = self.classifier(x)
        binary_out = self.classifier_scoring(classification_embedding)
        return a_out, v_out, binary_out, classification_embedding

    def loss(self, model_output, target):
        """
        Weighted sum of reconstruction and classification losses.

        Note the ordering: both tuples are (video, audio, binary), which is
        different from the (audio, video, ...) order returned by forward().

        Args:
            model_output (tuple): ``(model_v, model_a, model_binary)`` where
                ``model_binary`` has shape (batch, 2).
            target (tuple): ``(target_v, target_a, target_binary)`` where
                ``target_binary`` has shape (batch,) and holds class indices.

        Returns:
            Tensor: scalar ``(mse_a + mse_v) * (1 - zeta) + cross_entropy * zeta``.

        Raises:
            AssertionError: if model outputs and targets disagree in shape.
        """
        model_v, model_a, model_binary = model_output
        target_v, target_a, target_binary = target

        batch_size = model_v.size(0)
        assert model_v.shape == target_v.shape, "video output/target shape mismatch"
        assert model_a.shape == target_a.shape, "audio output/target shape mismatch"
        assert model_binary.size() == (batch_size, 2), "binary output must be (batch, 2)"
        assert target_binary.size() == (batch_size,), "binary target must be (batch,)"

        v_loss = F.mse_loss(model_v, target_v)
        a_loss = F.mse_loss(model_a, target_a)
        binary_loss = F.cross_entropy(model_binary, target_binary)

        return (a_loss + v_loss) * (1 - self.zeta) + binary_loss * self.zeta

In [3]:
class MultiModalAESimpleDensenetModel(nn.Module):
    """
    Full audio/video pipeline around the multimodal autoencoder.

    A simple DenseNet-169 embeds the video frame, a linear layer widens that
    embedding from VIDEO_IN_DIM to VIDEO_DIM, AudioConv2d embeds the audio,
    and both embeddings are fed to a MultimodalAutoencoder.
    """
    # Width of the audio embedding handed to the autoencoder.
    AUDIO_DIM = 2300
    # Width of the DenseNet output before the linear projection.
    VIDEO_IN_DIM = 640
    # Width of the video embedding handed to the autoencoder.
    VIDEO_DIM = 1664
    # Not referenced inside this class.
    NUM_SUBTRACT_LAYERS = 8

    def __init__(self):
        super().__init__()
        self.densenet = simple_densenet169(
            pretrained=False,
            progress=True,
            memory_efficient=False,
        )
        self.densenet_linear = nn.Linear(self.VIDEO_IN_DIM, self.VIDEO_DIM)
        self.audionet = AudioConv2d()
        self.ae = MultimodalAutoencoder(self.AUDIO_DIM, self.VIDEO_DIM)

    def forward(self, videos, audios):
        """
        Embed both modalities and run them through the autoencoder.

        Args:
            videos (Tensor): frames of shape (N, 3, 224, 224).
            audios (Tensor): audio input of shape (N, 5, 50).

        Returns:
            tuple: ``(video_out, audio_out, binary_out,
            classification_embedding, video_embed, audio_embed)``.

        Raises:
            AssertionError: on unexpected input or intermediate shapes.
        """
        n_samples = videos.shape[0]
        assert videos.size() == (n_samples, 3, 224, 224)
        assert audios.size() == (n_samples, 5, 50)

        # (N, C, H, W) -> (N, 640) -> (N, 1664)
        video_embed = self.densenet_linear(self.densenet(videos))
        # -> (N, 2300)
        audio_embed = self.audionet(audios)

        # The autoencoder takes and returns audio first, video second.
        ae_result = self.ae(audio_embed, video_embed)
        audio_out, video_out, binary_out, classification_embedding = ae_result

        # Reconstructions must mirror the embeddings they were decoded from.
        assert video_embed.shape == video_out.shape
        assert audio_embed.shape == audio_out.shape
        assert binary_out.size() == (n_samples, 2)
        assert classification_embedding.size() == (n_samples, self.ae.CLASSIFIER[-1])

        return (
            video_out,
            audio_out,
            binary_out,
            classification_embedding,
            video_embed,
            audio_embed,
        )

    def loss(self, *args):
        """Forward all arguments to the autoencoder's loss and return its result."""
        autoencoder = self.ae
        return autoencoder.loss(*args)
        

In [4]:
def load_data(batch_size=2):
    """
    Build the combined audio/video training loader.

    Args:
        batch_size (int): number of samples per batch. Defaults to 2, the
            value that was previously hard-coded in this function.

    Returns:
        The ``dataloader.AVDataLoader`` over the single-frame video dataset
        and the audio dataset, created with ``shuffle=False`` and
        ``single_frame=True``.
    """
    train_video_dataset = dataloader.get_dataset(dataloader.TRAIN_JSON_PATH, dataloader.SINGLE_FRAME)
    train_audio_dataset = dataloader.AudioDataset()
    train_loader = dataloader.AVDataLoader(
        train_video_dataset,
        train_audio_dataset,
        batch_size=batch_size,
        shuffle=False,
        single_frame=True,
    )
    return train_loader

In [5]:
def verify_model(train_loader, checkpoint_path="densenet169.pth"):
    """
    Smoke-test the full model on the first batch of ``train_loader``.

    Builds the model, loads the DenseNet weights from ``checkpoint_path``,
    prints the model, then runs a single forward pass plus loss computation
    and prints the resulting shapes and values.

    Args:
        train_loader: iterable yielding 5-tuples whose first, second and last
            items are the video batch, the audio batch and the labels.
        checkpoint_path (str): checkpoint file containing a
            ``"model_state_dict"`` entry for the DenseNet. Defaults to the
            path that was previously hard-coded.
    """
    model = MultiModalAESimpleDensenetModel()
    # The model stays on the CPU here, so map the checkpoint tensors to the
    # CPU as well; otherwise a checkpoint saved from a CUDA device cannot be
    # loaded on a machine without one.
    # NOTE: torch.load unpickles the file - only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model.densenet.load_state_dict(checkpoint["model_state_dict"])
    print(model)

    for videos, audios, _, _, labels in train_loader:
        print('videos shape:', videos.shape) # batch_size*3(channel)*224*224
        print('audios shape:', audios.shape) # batch_size*5*50(channel)
        print('labels shape:', labels.shape) # batch_size

        video_out, audio_out, binary_out, classification_embedding, video_embed, audio_embed = model(videos, audios)
        print("out")
        print(video_out.shape)
        print(audio_out.shape)
        print(binary_out.shape)
        print(classification_embedding.shape)
        print(binary_out)
        # The reconstruction targets are the embeddings themselves; the
        # classification target is the label batch.
        loss = model.loss((video_out, audio_out, binary_out), (video_embed, audio_embed, labels))
        print(loss)
        # One batch is enough for a smoke test.
        break

In [6]:
if __name__ == "__main__":
    # Smoke test: build the training loader, then push one batch through the
    # model. `train_loader` is bound at module level.
    train_loader = load_data()
    verify_model(train_loader)
    # Disabled alternative call kept from the original notebook:
#     verify_model(None)

loaded 733589 images 
MultiModalAESimpleDensenetModel(
  (densenet): SimpleDenseNet(
    (features): Sequential(
      (conv0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (norm0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu0): ReLU(inplace=True)
      (pool0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (denseblock1): _DenseBlock(
        (layers): ModuleDict(
          (denselayer1): _DenseLayer(
            (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (relu1): ReLU(inplace=True)
            (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (relu2): ReLU(inplace=True)
            (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
       

torch.Size([2, 640])
out
torch.Size([2, 1664])
torch.Size([2, 2300])
torch.Size([2, 2])
torch.Size([2, 64])
tensor([[ 0.0167,  0.0117],
        [ 0.0269, -0.0039]], grad_fn=<AddmmBackward>)
tensor(0.9538, grad_fn=<AddBackward0>)
