In [None]:
import torch
import argparse
import numpy as np

from utils import *
from torch.utils.data import DataLoader
from solver import Solver
from config import get_args, get_config, output_dim_dict, criterion_dict
from data_loader import get_loader
from test_instance import TestMOSI, TestMOSEI

In [None]:
def set_seed(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Integer seed applied to every generator.

    When CUDA is available, every GPU generator is seeded as well and cuDNN
    is put in deterministic mode with benchmark auto-tuning disabled.
    """
    import random  # local import: `random` is not imported at the top of the notebook

    random.seed(seed)
    # NumPy only accepts seeds in [0, 2**32 - 1]; wrap so any int that torch accepts also works here.
    np.random.seed(seed % (2 ** 32))
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

In [None]:
# Path to a pretrained word-embedding file (a GloVe text file, judging by the file name).
# It is stored on every Config instance as `word_emb_path`.
word_emb_path = '/mnt/soyeon/workspace/glove.840B.300d.txt'
# NOTE: this only guards against the path being None; it does not check that the file exists.
assert(word_emb_path is not None)

In [None]:
from datetime import datetime
from pathlib import Path
import pprint
from torch import optim
import torch.nn as nn

# username = Path.home().name
# project_dir = Path(__file__).resolve().parent.parent
# sdk_dir = project_dir.joinpath('CMU-MultimodalSDK')
# data_dir = project_dir.joinpath('datasets')
# Absolute locations of the CMU-MultimodalSDK checkout and of the dataset root.
sdk_dir = Path('/mnt/soyeon/workspace/multimodal/CMU-MultimodalSDK')
data_dir = Path('/mnt/soyeon/workspace/multimodal/datasets')

# Dataset key -> directory below `data_dir`.
data_dict = {
    'mosi': data_dir / 'MOSI',
    'mosei': data_dir / 'MOSEI',
    'ur_funny': data_dir / 'UR_FUNNY',
}

# Optimizer name -> torch.optim class.
optimizer_dict = {
    'RMSprop': optim.RMSprop,
    'Adam': optim.Adam,
}

# Activation name -> torch.nn module class.
activation_dict = {
    'elu': nn.ELU,
    "hardshrink": nn.Hardshrink,
    "hardtanh": nn.Hardtanh,
    "leakyrelu": nn.LeakyReLU,
    "prelu": nn.PReLU,
    "relu": nn.ReLU,
    "rrelu": nn.RReLU,
    "tanh": nn.Tanh,
}

# NOTE: the two dicts below rebind names that the first cell also imports from `config`.
# Dataset key -> model output dimension.
output_dim_dict = {'mosi': 1, 'mosei_senti': 1}

# Dataset key -> name of the loss to use.
criterion_dict = {
    'mosi': 'L1Loss',
    'iemocap': 'CrossEntropyLoss',
    'ur_funny': 'CrossEntropyLoss',
}

In [None]:
import easydict

# Run configuration, exposed with attribute access (args.batch_size, args.seed, ...).
# Later cells add more fields to this object (word2id, input dims, n_class, criterion).
args = easydict.EasyDict(dict(
    # Tasks
    dataset="mosi",
    data_path="datasets",

    # Dropouts
    dropout_a=0.1,
    dropout_v=0.1,
    dropout_prj=0.1,

    # Architecture
    multiseed=True,
    contrast=True,
    add_va=True,
    n_layer=1,
    cpc_layers=1,
    d_vh=16,
    d_ah=16,
    d_vout=16,
    d_aout=16,
    bidirectional=True,
    d_prjh=128,
    pretrain_emb=768,

    # Activations
    mmilb_mid_activation="ReLU",
    mmilb_last_activation="Tanh",
    cpc_activation="Tanh",

    # Training Setting
    batch_size=32,
    clip=1.0,
    lr_main=1e-3,
    lr_bert=5e-5,
    lr_mmilb=1e-3,
    alpha=0.1,
    beta=0.1,
    weight_decay_main=1e-4,
    weight_decay_bert=1e-4,
    weight_decay_club=1e-4,
    optim="Adam",
    num_epochs=40,
    when=20,
    patience=10,
    update_batch=1,

    # Logistics
    log_interval=100,
    seed=1111,
))

In [None]:
def str2bool(v):
    """Convert a yes/no style string to a bool.

    Args:
        v: A bool, which is returned unchanged, or a string compared
            case-insensitively against the accepted spellings.

    Returns:
        True for 'yes', 'true', 't', 'y', '1'; False for 'no', 'false',
        'f', 'n', '0'.

    Raises:
        argparse.ArgumentTypeError: If the string is not one of the
            accepted spellings.
    """
    # Already a bool: nothing to parse (previously this raised AttributeError).
    if isinstance(v, bool):
        return v

    lowered = v.lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')

In [None]:
class Config:
    """Holds the dataset/SDK paths and the split name for one data loader."""

    def __init__(self, data, mode='train'):
        """Resolve the directories for dataset `data` and record the split.

        Args:
            data: Dataset key; lower-cased before the lookup in `data_dict`
                (an unknown key raises KeyError).
            mode: Split name stored on the instance, e.g. 'train', 'valid', 'test'.
        """
        dataset_root = data_dict[data.lower()]

        self.dataset_dir = dataset_root
        self.sdk_dir = sdk_dir
        self.mode = mode
        # Path of the pretrained word-embedding file
        self.word_emb_path = word_emb_path
        # Same directory as `dataset_dir`
        self.data_dir = dataset_root

    def __str__(self):
        """Return a 'Configurations' header followed by the pretty-printed attributes."""
        return 'Configurations\n' + pprint.pformat(vars(self))


def get_config(dataset='mosi', mode='train', batch_size=32):
    """Create a `Config` for one dataset split and attach the loader settings.

    Args:
        dataset: Dataset key handed to `Config` and stored as `config.dataset`.
        mode: Split name handed to `Config`.
        batch_size: Stored as `config.batch_size`.

    Returns:
        The populated `Config` instance.
    """
    cfg = Config(data=dataset, mode=mode)
    cfg.dataset, cfg.batch_size = dataset, batch_size
    return cfg

In [None]:
# Normalise the dataset key once; it is reused for every lookup below.
dataset = args.dataset.strip().lower()

set_seed(args.seed)
print("Start loading the data....")
train_config = get_config(dataset, mode='train', batch_size=args.batch_size)
valid_config = get_config(dataset, mode='valid', batch_size=args.batch_size)
test_config = get_config(dataset, mode='test',  batch_size=args.batch_size)

# get_loader also attaches extra fields to the config it receives; `word2id`
# and `tva_dim` are read back from train_config below.
train_loader = get_loader(args, train_config, shuffle=True)
print('Training data loaded!')
valid_loader = get_loader(args, valid_config, shuffle=False)
print('Validation data loaded!')
test_loader = get_loader(args, test_config, shuffle=False)
print('Test data loaded!')
print('Finish loading the data....')

# Anomaly detection is a debugging aid that slows down every backward pass;
# disable it once training is known to be numerically stable.
torch.autograd.set_detect_anomaly(True)

# additional fields derived from the loaded training data
args.word2id = train_config.word2id

# architecture parameters: text / visual / acoustic input dimensions
args.d_tin, args.d_vin, args.d_ain = train_config.tva_dim
args.dataset = args.data = dataset
# Datasets missing from the lookup tables fall back to 1 output and MSELoss.
args.n_class = output_dim_dict.get(dataset, 1)
args.criterion = criterion_dict.get(dataset, 'MSELoss')

In [None]:
# Build the solver around the three data loaders, in training mode.
solver = Solver(
    args,
    train_loader=train_loader,
    dev_loader=valid_loader,
    test_loader=test_loader,
    is_train=True,
)

In [None]:
# Grab the model held by the solver; `model` is rebound to the return value of
# solver.train_and_eval() in the next cell.
model = solver.model

In [None]:
# Run training and evaluation; the returned object replaces `model` and is the
# one whose state_dict is saved and reloaded in the following cells.
model = solver.train_and_eval()

In [None]:
# Persist only the parameters (state_dict) to a checkpoint file in the working directory.
torch.save(model.state_dict(), "./saved_models_MMIM_mosi.pt")

### Model Load

In [None]:
# Restore the parameters from the checkpoint file, then switch the model to eval mode.
# NOTE(review): torch.load unpickles the file, so only load checkpoints you trust.
model.load_state_dict(torch.load("./saved_models_MMIM_mosi.pt"))
model.eval()

In [None]:
# Run the test harness on the trained model. start() returns the segment ids
# together with the raw, binary and 7-class predictions.
# NOTE(review): TestMOSI is used whatever args.dataset is; TestMOSEI is imported
# but not used in this cell. Confirm before running on another dataset.
tester = TestMOSI(model)
segment_list, preds, preds_2, preds_7 = tester.start()

In [None]:
import pickle
# Gold labels for the test split at three granularities:
# raw value, binary polarity and 7-class intensity.

# Exclusive upper bounds of the first six intensity classes, in ascending order.
# Values at or above the last bound are 'very positive'.
_SEVEN_CLASS_UPPER_BOUNDS = (
    (-15 / 7, 'very negative'),
    (-9 / 7, 'negative'),
    (-3 / 7, 'slightly negative'),
    (3 / 7, 'Neutral'),
    (9 / 7, 'slightly positive'),
    (15 / 7, 'positive'),
)


def _binary_label(value):
    """Return 'positive' for values strictly above zero, otherwise 'negative'."""
    return 'positive' if value > 0 else 'negative'


def _seven_class_label(value):
    """Return the name of the first intensity class whose upper bound exceeds `value`."""
    for upper_bound, name in _SEVEN_CLASS_UPPER_BOUNDS:
        if value < upper_bound:
            return name
    return 'very positive'


# NOTE(security): pickle.load can execute arbitrary code from the file; only
# open dataset dumps from a trusted source.
with open(f"../datasets/{args.dataset}.pkl", "rb") as handle:
    data = pickle.load(handle)

test_data = data["test"]

# Fail early if the predictions and the gold data do not cover the same segments.
assert len(test_data) == len(segment_list), (
    f"{len(test_data)} gold samples vs {len(segment_list)} predicted segments"
)

labels = []
labels_2 = []
labels_7 = []

for idx, ((words, visual, acoustic), label, segment) in enumerate(test_data):
    # Predictions and gold data must be in the same segment order.
    assert segment_list[idx] == segment

    # `label` is indexed as label[0][0] to get the scalar sentiment value;
    # use that scalar for every derived label as well.
    value = label[0][0]
    labels.append(value)
    labels_2.append(_binary_label(value))
    labels_7.append(_seven_class_label(value))

In [None]:
from ipywidgets import interact

@interact
def get_predict_result(idx = range(len(segment_list))):
    """Print the gold and predicted sentiment of the test segment at `idx`.

    Args:
        idx: Position in the test split; the `range` default is handed to
            `interact` to build the selection widget.
    """
    report = (
        ("SEGMENT:", segment_list),
        ("GOLD_VALUE:", labels),
        ("GOLD_BINARY:", labels_2),
        ("GOLD_7_CLASS:", labels_7),
        ("PREDICTED_VALUE:", preds),
        ("PREDICTED_BINARY:", preds_2),
        ("PREDICTED _7_CLASS:", preds_7),
    )
    for caption, values in report:
        print(caption, values[idx])

In [None]:
# Sanity check: print the three lengths, one per line, so they can be compared.
print(len(segment_list), len(labels), len(preds), sep="\n")

In [None]:
import plotly.express as px
import plotly.subplots as sp
import pandas as pd

# One row per test segment: gold labels and predictions at all three granularities.
d = {'segmentID': segment_list, 'labels': labels, 'labels_2': labels_2, 'labels_7': labels_7, 'preds': preds, 'preds_2': preds_2, 'preds_7': preds_7}
df = pd.DataFrame(data=d)
# Left-to-right category order for the x axes.
order = ['very negative', 'negative', 'slightly negative', 'Neutral', 'slightly positive', 'positive', 'very positive']

# One bar figure per source; their traces are copied into the subplots below.
fig1 = px.bar(df, x="labels_7")
fig2 = px.bar(df, x="preds_7")

this_figure = sp.make_subplots(rows=1, cols=2, subplot_titles=("Gold", "MIM"))
# add_trace(row=, col=) is the supported replacement for the deprecated append_trace.
for trace in fig1.data:
    this_figure.add_trace(trace, row=1, col=1)
for trace in fig2.data:
    this_figure.add_trace(trace, row=1, col=2)

this_figure.update_layout(height=600, width=1500, title_text="CMU-MOSI 7 Class Sentiment Intensity")
this_figure.update_xaxes(categoryorder='array', categoryarray=order)
# Fixed y range so both panels share the same scale.
this_figure.update_yaxes(range=[0, 250])
this_figure.show()