In [54]:
from google.colab import drive

# Mount Google Drive; the feature pickle and checkpoints are read/written under MyDrive.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import os
import pickle
import numpy as np
import argparse
from random import random

from torch import optim
from sklearn.metrics import f1_score

import torch
import torch.nn as nn
from torch.nn import functional as F

from sklearn.utils import shuffle

import numpy as np
import random

import torch
import torch.nn as nn
from torch.autograd import Function
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence

In [None]:
def preprocess(au_mfcc_path, random_state=None):
    """Load the pickled AU+MFCC feature dict and return shuffled ``(data, labels)``.

    Each key of the pickled dict is a '-'-separated identifier whose third
    field is a 1-based emotion code; it is converted to a 0-based class index.
    (Presumably RAVDESS-style file names -- TODO confirm against the dataset.)

    Args:
        au_mfcc_path: path to a pickle holding ``{key: feature_vector}``.
            NOTE(security): ``pickle.load`` can execute arbitrary code -- only
            use this on trusted files.
        random_state: ``None`` (NumPy's global RNG, the previous behaviour),
            an int seed, or an RNG object with a ``permutation`` method.
            Fixing it makes the downstream train/dev/test split reproducible.

    Returns:
        ``(data, labels)``: ``data`` stacks the feature vectors row-wise in
        shuffled order; ``labels`` is an int array of the matching 0-based
        emotion indices.
    """
    with open(au_mfcc_path, 'rb') as f:
        au_mfcc = pickle.load(f)

    print(len(au_mfcc))

    features, emotions = [], []
    for key, vector in au_mfcc.items():
        emotions.append(int(key.split('-')[2]) - 1)
        features.append(vector)

    data = np.array(features)
    labels = np.array(emotions, dtype=int)

    if random_state is None:
        rng = np.random                    # global legacy RNG, as before
    elif hasattr(random_state, 'permutation'):
        rng = random_state                 # caller-supplied RNG object
    else:
        rng = np.random.RandomState(random_state)

    # One shared permutation keeps every feature row paired with its label,
    # without the old hstack round-trip that cast labels to float and back.
    order = rng.permutation(len(labels))
    return data[order], labels[order]

In [59]:
class MMF_Model(nn.Module):
    """Late-fusion classifier over AU and MFCC feature sequences.

    Each modality goes through two stacked bidirectional LSTMs. The final
    hidden states of both layers (both directions) are concatenated per
    sample, the two modality vectors are concatenated, and a linear layer
    maps the result to class logits.

    Inputs are sequence-first, ``(seq_len, batch, features)``, matching the
    ``nn.LSTM`` / ``pack_padded_sequence`` defaults used here. ``lengths``
    must be a CPU tensor sorted in decreasing order.

    The defaults reproduce the original architecture (35 AU dims, 259 MFCC
    dims, hidden size 16, 8 classes); submodule names are unchanged so
    existing checkpoints still load.
    """

    def __init__(self, au_dim=35, mfcc_dim=259, hidden_size=16, num_classes=8):
        super(MMF_Model, self).__init__()

        self.au_rnn1 = nn.LSTM(au_dim, hidden_size, bidirectional=True)
        self.au_rnn2 = nn.LSTM(2 * hidden_size, hidden_size, bidirectional=True)

        self.mfccs_rnn1 = nn.LSTM(mfcc_dim, hidden_size, bidirectional=True)
        self.mfccs_rnn2 = nn.LSTM(2 * hidden_size, hidden_size, bidirectional=True)

        # Per modality: 2 layers x 2 directions x hidden_size; two modalities.
        # With the defaults this is 128, as before.
        self.fusion_layer = nn.Linear(in_features=2 * 2 * 2 * hidden_size,
                                      out_features=num_classes)

    def _encode(self, first_rnn, second_rnn, x, lengths):
        """Run ``x`` through two stacked BiLSTMs; return ``(batch, 4*hidden)``."""
        packed = pack_padded_sequence(x, lengths)
        packed_h1, (final_h1, _) = first_rnn(packed)
        # Feed the packed output straight into the second LSTM; unpadding and
        # re-packing with the same lengths produced the same packed sequence.
        _, (final_h2, _) = second_rnn(packed_h1)
        # final_h*: (2 directions, batch, hidden) -> (batch, 2 * 2 * hidden).
        # The batch size comes from the tensor itself: the old code read an
        # undefined name that only resolved to a leaked notebook global.
        batch = final_h1.size(1)
        return torch.cat((final_h1, final_h2), dim=2).permute(1, 0, 2).reshape(batch, -1)

    def extract_au(self, au, lengths):
        """Encode the AU sequence into a fixed-size per-sample vector."""
        return self._encode(self.au_rnn1, self.au_rnn2, au, lengths)

    def extract_mfccs(self, mfccs, lengths):
        """Encode the MFCC sequence into a fixed-size per-sample vector."""
        return self._encode(self.mfccs_rnn1, self.mfccs_rnn2, mfccs, lengths)

    def forward(self, au, mfccs, lengths):
        """Return class logits of shape ``(batch, num_classes)``."""
        extracted_au = self.extract_au(au, lengths)
        extracted_mfccs = self.extract_mfccs(mfccs, lengths)

        au_mfccs_fusion = torch.cat((extracted_au, extracted_mfccs), dim=1)
        return self.fusion_layer(au_mfccs_fusion)

def eval(data, labels, mode=None, to_print=False, model=None, criterion=None,
         batch_size=60):
    """Evaluate a model on ``data``/``labels``; return ``(mean_loss, accuracy)``.

    NOTE: this name shadows Python's builtin ``eval`` in the notebook
    namespace; it is kept so existing callers keep working.

    Args:
        data: 2-D array, one row per sample; columns ``[:35]`` are fed as AU
            features and ``[35:]`` as MFCC features.
        labels: target rows (one-hot in this notebook), one per sample; the
            true class is taken as each row's argmax.
        mode: required tag ("train", "dev", "test"); only "test" changes
            behaviour.
        to_print: in "test" mode, reload the best checkpoint from Drive before
            evaluating (despite the name, nothing is printed).
        model, criterion: default to the notebook-level globals of the same
            names, so the original call sites work unchanged.
        batch_size: rows per forward pass.

    Returns:
        ``(mean_loss, accuracy)``: the unweighted mean of per-batch losses and
        the fraction of rows whose argmax prediction matches the argmax label.
        Both are NaN when ``data`` is empty.
    """
    assert mode is not None

    if model is None:
        model = globals()['model']
    if criterion is None:
        criterion = globals()['criterion']

    # Run on whatever device the model lives on instead of assuming CUDA.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    if mode == "test" and to_print:
        model.load_state_dict(torch.load(
            '/content/drive/MyDrive/multimodal-fusion/model.ckpt',
            map_location=device))

    model.eval()

    if len(data) == 0:
        return float('nan'), float('nan')

    batch_losses = []
    correct = 0
    with torch.no_grad():
        for start in range(0, len(data), batch_size):
            batch_labels = labels[start:start + batch_size]
            # Add a leading length-1 time axis: (1, batch, features).
            batch = np.expand_dims(data[start:start + batch_size], axis=0)
            au = torch.from_numpy(batch[:, :, :35]).float().to(device)
            mfccs = torch.from_numpy(batch[:, :, 35:]).float().to(device)
            y = torch.from_numpy(batch_labels).float().to(device)
            # One length per sequence; pack_padded_sequence wants them on CPU.
            lengths = torch.LongTensor([au.shape[0]] * au.size(1))

            output = model(au, mfccs, lengths)
            batch_losses.append(criterion(output, y).item())

            preds = output.detach().cpu().numpy()
            correct += int((preds.argmax(axis=1) == np.asarray(batch_labels).argmax(axis=1)).sum())

    return np.mean(batch_losses), correct / (1.0 * len(labels))


In [63]:
if __name__ == '__main__':
    # ---- configuration ---------------------------------------------------
    SEED = 0
    PROJECT_DIR = '/content/drive/MyDrive/multimodal-fusion'
    MODEL_CKPT = os.path.join(PROJECT_DIR, 'model.ckpt')
    OPTIM_CKPT = os.path.join(PROJECT_DIR, 'optim_best.std')
    NUM_CLASSES = 8      # must match MMF_Model's output size
    AU_DIM = 35          # leading feature columns are AU; the rest are MFCC
    N_TEST = 180         # rows held out at the end for the final test
    N_TRAIN = 1020       # rows used for training; the remainder is dev
    N_EPOCHS = 50
    batch_size = 60      # keep this global name: other cells may read it

    # Seed NumPy (used by the shuffle in `preprocess`) and torch (weight init)
    # so the data split and training run are reproducible.
    np.random.seed(SEED)
    torch.manual_seed(SEED)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    data, labels = preprocess(os.path.join(PROJECT_DIR, 'au_mfcc.pkl'))
    print('u:', np.unique(labels.astype(int)).size)
    # One-hot targets; CrossEntropyLoss accepts class-probability targets.
    labels = np.eye(NUM_CLASSES)[labels.astype(int)]

    # Hold out the last N_TEST rows. The previous `[-181:-1]` slice skipped
    # the final row and put row -181 in both the test and train/dev sets.
    test_data, test_labels = data[-N_TEST:], labels[-N_TEST:]
    train_data, train_labels = data[:N_TRAIN], labels[:N_TRAIN]
    dev_data, dev_labels = data[N_TRAIN:-N_TEST], labels[N_TRAIN:-N_TEST]
    assert len(dev_data) > 0, "N_TRAIN + N_TEST leaves no rows for dev"

    model = MMF_Model().to(device)
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)
    criterion = nn.CrossEntropyLoss()

    best_loss = float('inf')
    for epoch in range(N_EPOCHS):
        model.train()
        print(f"=====Epoch{epoch + 1}======")
        for start in range(0, len(train_data), batch_size):
            batch_labels = train_labels[start:start + batch_size]
            # (1, batch, features): every sample is a length-1 sequence.
            batch = np.expand_dims(train_data[start:start + batch_size], axis=0)

            model.zero_grad()
            au = torch.from_numpy(batch[:, :, :AU_DIM]).float().to(device)
            mfccs = torch.from_numpy(batch[:, :, AU_DIM:]).float().to(device)
            y = torch.from_numpy(batch_labels).float().to(device)
            lengths = torch.LongTensor([au.shape[0]] * au.size(1))

            loss = criterion(model(au, mfccs, lengths), y)
            loss.backward()
            optimizer.step()

        train_loss, train_acc = eval(train_data, train_labels, mode="train")
        print('train_loss: {:.3f}, train_acc: {:.2f}%'.format(train_loss, 100*train_acc))

        valid_loss, valid_acc = eval(dev_data, dev_labels, mode="dev")
        print('valid_loss: {:.3f}, valid_acc: {:.2f}%'.format(valid_loss, 100*valid_acc))

        if valid_loss < best_loss:
            best_loss = valid_loss
            torch.save(model.state_dict(), MODEL_CKPT)
            torch.save(optimizer.state_dict(), OPTIM_CKPT)
        else:
            # No dev improvement: roll model and optimizer back to the best checkpoint.
            model.load_state_dict(torch.load(MODEL_CKPT, map_location=device))
            optimizer.load_state_dict(torch.load(OPTIM_CKPT, map_location=device))

    test_loss, test_acc = eval(test_data, test_labels, mode="test", to_print=True)
    print('test_loss: {:.3f} test_acc: {:.2f}%'.format(test_loss, 100*test_acc))

1440
u: 8
train_loss: 2.030, train_acc: 26.47%
valid_loss: 2.044, valid_acc: 21.25%
train_loss: 1.969, train_acc: 39.51%
valid_loss: 1.995, valid_acc: 28.33%
train_loss: 1.877, train_acc: 37.65%
valid_loss: 1.911, valid_acc: 29.17%
train_loss: 1.734, train_acc: 45.00%
valid_loss: 1.775, valid_acc: 35.83%
train_loss: 1.580, train_acc: 44.51%
valid_loss: 1.628, valid_acc: 39.17%
train_loss: 1.450, train_acc: 50.49%
valid_loss: 1.502, valid_acc: 45.00%
train_loss: 1.354, train_acc: 53.33%
valid_loss: 1.416, valid_acc: 48.75%
train_loss: 1.276, train_acc: 56.27%
valid_loss: 1.338, valid_acc: 52.92%
train_loss: 1.216, train_acc: 58.04%
valid_loss: 1.279, valid_acc: 55.42%
train_loss: 1.157, train_acc: 59.71%
valid_loss: 1.225, valid_acc: 54.17%
train_loss: 1.105, train_acc: 61.96%
valid_loss: 1.175, valid_acc: 58.33%
train_loss: 1.057, train_acc: 63.43%
valid_loss: 1.124, valid_acc: 59.58%
train_loss: 1.023, train_acc: 64.80%
valid_loss: 1.092, valid_acc: 60.00%
train_loss: 0.987, train_acc