In [1]:
import torch
import argparse
import numpy as np

from utils import *
from torch.utils.data import DataLoader
from solver import Solver
from config import get_args, get_config, output_dim_dict, criterion_dict
from data_loader import get_loader
from test_instance import TestMOSI, TestMOSEI

  '"sox" backend is being deprecated. '
loading file https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt from cache at /home/ubuntu/.cache/huggingface/transformers/45c3f7a79a80e1cf0a489e5c62b43f173c15db47864303a55d623bb3c96f72a5.d789d64ebfe299b0e416afc4a169632f903f693095b4629a7ea271d5a0cf2c99
loading file https://huggingface.co/bert-base-uncased/resolve/main/added_tokens.json from cache at None
loading file https://huggingface.co/bert-base-uncased/resolve/main/special_tokens_map.json from cache at None
loading file https://huggingface.co/bert-base-uncased/resolve/main/tokenizer_config.json from cache at /home/ubuntu/.cache/huggingface/transformers/c1d7f0a763fb63861cc08553866f1fc3e5a6f4f07621be277452d26d71303b7e.20430bd8e10ef77a7d2977accefe796051e01bc2fc4aa146bc862997a1a15e79
loading configuration file https://huggingface.co/bert-base-uncased/resolve/main/config.json from cache at /home/ubuntu/.cache/huggingface/transformers/3c61d016573b14f7f008c02c4e51a366c67ab274726fe29106

In [2]:
def set_seed(seed):
    """Seed every random number generator this notebook can draw from.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU, plus every CUDA device when CUDA is available). When CUDA is
    available, cuDNN is additionally switched to deterministic mode with
    autotuning disabled.

    Args:
        seed: Integer seed shared by all generators.

    Returns:
        None.
    """
    # Local import: `random` is not part of the notebook's import cell.
    import random

    random.seed(seed)
    # NumPy only accepts seeds in [0, 2**32 - 1]; fold larger/negative values
    # into that range so any seed torch accepts keeps working here too.
    np.random.seed(seed % (2 ** 32))
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

        # Prefer run-to-run determinism over cuDNN's autotuned kernels.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

In [3]:
# Pretrained word-embedding file (GloVe, judging by the file name).
# Earlier location on another machine:
#   '/mnt/soyeon/workspace/glove.840B.300d.txt'
word_emb_path = '/home/ubuntu/soyeon/glove.840B.300d.txt'
# Fail fast if the path was blanked out while editing this cell.
assert word_emb_path is not None

In [4]:
from datetime import datetime
from pathlib import Path
import pprint
from torch import optim
import torch.nn as nn

# Earlier ways of locating the project directories, kept for reference:
#   username = Path.home().name
#   project_dir = Path(__file__).resolve().parent.parent
#   sdk_dir = project_dir.joinpath('CMU-MultimodalSDK')
#   data_dir = project_dir.joinpath('datasets')
#   sdk_dir = Path('/mnt/soyeon/workspace/multimodal/CMU-MultimodalSDK')
#   data_dir = Path('/mnt/soyeon/workspace/multimodal/datasets')

# Machine-specific roots of the SDK checkout and of the dataset folders.
sdk_dir = Path('/home/ubuntu/soyeon/CMU-MultimodalSDK')
data_dir = Path('/home/ubuntu/soyeon/MSIR/datasets')

# Lower-case dataset name -> directory of that dataset under `data_dir`.
data_dict = {
    'mosi': data_dir / 'MOSI',
    'mosei': data_dir / 'MOSEI',
    'ur_funny': data_dir / 'UR_FUNNY',
}

# Optimizer name -> torch.optim class.
optimizer_dict = {
    'RMSprop': optim.RMSprop,
    'Adam': optim.Adam,
}

# Activation name -> torch.nn module class.
activation_dict = {
    'elu': nn.ELU,
    "hardshrink": nn.Hardshrink,
    "hardtanh": nn.Hardtanh,
    "leakyrelu": nn.LeakyReLU,
    "prelu": nn.PReLU,
    "relu": nn.ReLU,
    "rrelu": nn.RReLU,
    "tanh": nn.Tanh,
}

# NOTE: the two tables below rebind names that the first cell also imports
# from `config`; from here on these local definitions are the ones in effect.

# Dataset name -> output dimension.
output_dim_dict = {'mosi': 1, 'mosei_senti': 1}

# Dataset name -> name of the loss to use.
criterion_dict = {
    'mosi': 'L1Loss',
    'iemocap': 'CrossEntropyLoss',
    'ur_funny': 'CrossEntropyLoss',
}

In [5]:
import easydict

# Hyper-parameters for this run, exposed with attribute access (args.batch_size, ...).
# Later cells add further fields (word2id, input dimensions, n_class, criterion, ...).
args = easydict.EasyDict(dict(
    # --- Tasks ---
    dataset="mosi",
    data_path="datasets",

    # --- Dropouts ---
    dropout_a=0.1,
    dropout_v=0.1,
    dropout_prj=0.1,

    # --- Architecture ---
    multiseed=True,
    contrast=True,
    add_va=True,
    n_layer=1,
    cpc_layers=1,
    d_vh=16,
    d_ah=16,
    d_vout=16,
    d_aout=16,
    bidirectional=True,
    d_prjh=128,
    pretrain_emb=768,

    # --- Activations ---
    mmilb_mid_activation="ReLU",
    mmilb_last_activation="Tanh",
    cpc_activation="Tanh",

    # --- Training setting ---
    batch_size=32,
    clip=1.0,
    lr_main=1e-3,
    lr_bert=5e-5,
    lr_mmilb=1e-3,
    alpha=0.1,
    beta=0.1,
    weight_decay_main=1e-4,
    weight_decay_bert=1e-4,
    weight_decay_club=1e-4,
    optim="Adam",
    num_epochs=40,
    when=20,
    patience=10,
    update_batch=1,

    # --- Logistics ---
    log_interval=100,
    seed=1111,
))

In [6]:
def str2bool(v):
    """Convert a textual flag into a bool.

    Args:
        v: Either a bool (returned unchanged) or a string. Strings are
            matched case-insensitively against 'yes'/'true'/'t'/'y'/'1'
            and 'no'/'false'/'f'/'n'/'0'.

    Returns:
        The corresponding boolean value.

    Raises:
        argparse.ArgumentTypeError: If a string is not one of the
            recognised spellings.
    """
    # A real bool (e.g. a default value passed through unchanged) has no
    # .lower(); accept it as-is instead of failing with AttributeError.
    if isinstance(v, bool):
        return v

    token = v.lower()
    if token in ('yes', 'true', 't', 'y', '1'):
        return True
    if token in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')

In [7]:
class Config(object):
    def __init__(self, data, mode='train'):
        """Bundle the paths and split name used when loading one dataset.

        Args:
            data: Dataset name; lower-cased and looked up in ``data_dict``
                (an unknown name raises ``KeyError``).
            mode: Data split, e.g. 'train', 'valid' or 'test'.
        """
        dataset_root = data_dict[data.lower()]

        self.dataset_dir = dataset_root
        self.sdk_dir = sdk_dir
        self.mode = mode
        # Path of the pretrained word-embedding (GloVe) file.
        self.word_emb_path = word_emb_path

        # Same directory as `dataset_dir`, exposed under a second name.
        self.data_dir = dataset_root

    def __str__(self):
        """Return a header line followed by the pretty-printed attributes."""
        return 'Configurations\n' + pprint.pformat(vars(self))


def get_config(dataset='mosi', mode='train', batch_size=32):
    """Create a ``Config`` for one split and attach the run-level settings.

    Args:
        dataset: Dataset name forwarded to ``Config`` and stored as ``dataset``.
        mode: Split name forwarded to ``Config``.
        batch_size: Stored on the returned object as ``batch_size``.

    Returns:
        The populated ``Config`` instance.
    """
    cfg = Config(data=dataset, mode=mode)

    extras = (('dataset', dataset), ('batch_size', batch_size))
    for attr_name, attr_value in extras:
        setattr(cfg, attr_name, attr_value)

    return cfg

In [8]:
# Normalise the dataset name once; it is reused for config lookup and dict keys below.
dataset = args.dataset.strip().lower()

set_seed(args.seed)
print("Start loading the data....")
train_config = get_config(dataset, mode='train', batch_size=args.batch_size)
valid_config = get_config(dataset, mode='valid', batch_size=args.batch_size)
test_config = get_config(dataset, mode='test',  batch_size=args.batch_size)

# get_loader stores extra fields on the config it receives; word2id and
# tva_dim are read back from train_config further down.
train_loader = get_loader(args, train_config, shuffle=True)
print('Training data loaded!')
valid_loader = get_loader(args, valid_config, shuffle=False)
print('Validation data loaded!')
test_loader = get_loader(args, test_config, shuffle=False)
print('Test data loaded!')
print('Finish loading the data....')

# NOTE(review): anomaly detection is a debugging aid and slows down every
# backward pass; consider disabling it once training is known to be stable.
torch.autograd.set_detect_anomaly(True)

# additional fields appended to args
args.word2id = train_config.word2id

# architecture parameters
args.d_tin, args.d_vin, args.d_ain = train_config.tva_dim
args.dataset = args.data = dataset
# Fall back to a single output / MSE loss for datasets missing from the tables.
args.n_class = output_dim_dict.get(dataset, 1)
args.criterion = criterion_dict.get(dataset, 'MSELoss')

Start loading the data....
train
Training data loaded!
valid
Validation data loaded!
test
Test data loaded!
Finish loading the data....


In [9]:
# Build the solver in training mode around the three data loaders.
solver = Solver(
    args,
    train_loader=train_loader,
    dev_loader=valid_loader,
    test_loader=test_loader,
    is_train=True,
)

loading configuration file https://huggingface.co/bert-base-uncased/resolve/main/config.json from cache at /home/ubuntu/.cache/huggingface/transformers/3c61d016573b14f7f008c02c4e51a366c67ab274726fe2910691e2a761acf43e.37395cee442ab11005bcd270f3c34464dc1704b715b5d7d52b1a461abe3b9e4e
Model config BertConfig {
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "output_hidden_states": true,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.18.0",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}

loading weights file https://huggingface.co/bert-base-uncased/res

In [10]:
# Bind the solver's `model` attribute to a notebook-level name.
# (The next cell rebinds `model` to the return value of train_and_eval().)
model = solver.model

In [11]:
# Run the solver's train/eval loop; its return value replaces `model`.
# NOTE(review): later cells call model.state_dict(), so train_and_eval() is
# presumably expected to return the trained model — confirm in solver.py.
model = solver.train_and_eval()

100%|██████████| 41/41 [00:21<00:00,  1.89it/s]
100%|██████████| 41/41 [00:25<00:00,  1.61it/s]


--------------------------------------------------
Epoch  1 | Time 48.8909 sec | Valid Loss 0.8566 | Test Loss 0.9392
--------------------------------------------------
MAE:  0.9391528
Correlation Coefficient:  0.755333239532547
mult_acc_7:  0.3629737609329446
mult_acc_5:  0.41836734693877553
F1 score all/non0: 0.7898/0.7918 over 686/656
Accuracy all/non0: 0.7901/0.7912
--------------------------------------------------
Saved model at pre_trained_models/MM.pt!


100%|██████████| 41/41 [00:21<00:00,  1.89it/s]
100%|██████████| 41/41 [00:25<00:00,  1.61it/s]
  0%|          | 0/41 [00:00<?, ?it/s]

--------------------------------------------------
Epoch  2 | Time 49.0569 sec | Valid Loss 0.8426 | Test Loss 0.9577
--------------------------------------------------


100%|██████████| 41/41 [00:21<00:00,  1.90it/s]
100%|██████████| 41/41 [00:25<00:00,  1.59it/s]


--------------------------------------------------
Epoch  3 | Time 49.3721 sec | Valid Loss 0.7948 | Test Loss 0.7741
--------------------------------------------------
MAE:  0.77414805
Correlation Coefficient:  0.7570627137874315
mult_acc_7:  0.4446064139941691
mult_acc_5:  0.5160349854227405
F1 score all/non0: 0.8107/0.8254 over 686/656
Accuracy all/non0: 0.8105/0.8247
--------------------------------------------------
Saved model at pre_trained_models/MM.pt!


100%|██████████| 41/41 [00:21<00:00,  1.89it/s]
100%|██████████| 41/41 [00:25<00:00,  1.61it/s]
  0%|          | 0/41 [00:00<?, ?it/s]

--------------------------------------------------
Epoch  4 | Time 49.0612 sec | Valid Loss 0.8010 | Test Loss 0.8798
--------------------------------------------------


100%|██████████| 41/41 [00:21<00:00,  1.90it/s]
100%|██████████| 41/41 [00:25<00:00,  1.61it/s]


--------------------------------------------------
Epoch  5 | Time 48.8552 sec | Valid Loss 0.7263 | Test Loss 0.7359
--------------------------------------------------
MAE:  0.7358665
Correlation Coefficient:  0.7750612027685916
mult_acc_7:  0.47230320699708456
mult_acc_5:  0.5524781341107872
F1 score all/non0: 0.8181/0.8364 over 686/656
Accuracy all/non0: 0.8192/0.8369
--------------------------------------------------
Saved model at pre_trained_models/MM.pt!


100%|██████████| 41/41 [00:21<00:00,  1.89it/s]
100%|██████████| 41/41 [00:25<00:00,  1.60it/s]
  0%|          | 0/41 [00:00<?, ?it/s]

--------------------------------------------------
Epoch  6 | Time 49.2005 sec | Valid Loss 0.7660 | Test Loss 0.7368
--------------------------------------------------


100%|██████████| 41/41 [00:21<00:00,  1.89it/s]
100%|██████████| 41/41 [00:25<00:00,  1.61it/s]
  0%|          | 0/41 [00:00<?, ?it/s]

--------------------------------------------------
Epoch  7 | Time 48.8719 sec | Valid Loss 0.8066 | Test Loss 0.7451
--------------------------------------------------


100%|██████████| 41/41 [00:21<00:00,  1.91it/s]
100%|██████████| 41/41 [00:25<00:00,  1.61it/s]
  0%|          | 0/41 [00:00<?, ?it/s]

--------------------------------------------------
Epoch  8 | Time 48.8006 sec | Valid Loss 0.7626 | Test Loss 0.8214
--------------------------------------------------


100%|██████████| 41/41 [00:21<00:00,  1.89it/s]
100%|██████████| 41/41 [00:25<00:00,  1.60it/s]
  0%|          | 0/41 [00:00<?, ?it/s]

--------------------------------------------------
Epoch  9 | Time 49.2763 sec | Valid Loss 0.7738 | Test Loss 0.7432
--------------------------------------------------


100%|██████████| 41/41 [00:21<00:00,  1.90it/s]
100%|██████████| 41/41 [00:25<00:00,  1.60it/s]
  0%|          | 0/41 [00:00<?, ?it/s]

--------------------------------------------------
Epoch 10 | Time 49.0904 sec | Valid Loss 0.8040 | Test Loss 0.7631
--------------------------------------------------


100%|██████████| 41/41 [00:21<00:00,  1.88it/s]
100%|██████████| 41/41 [00:25<00:00,  1.58it/s]
  0%|          | 0/41 [00:00<?, ?it/s]

--------------------------------------------------
Epoch 11 | Time 49.5970 sec | Valid Loss 0.7530 | Test Loss 0.7523
--------------------------------------------------


100%|██████████| 41/41 [00:21<00:00,  1.88it/s]
100%|██████████| 41/41 [00:25<00:00,  1.60it/s]
  0%|          | 0/41 [00:00<?, ?it/s]

--------------------------------------------------
Epoch 12 | Time 49.2314 sec | Valid Loss 0.7693 | Test Loss 0.7591
--------------------------------------------------


100%|██████████| 41/41 [00:21<00:00,  1.90it/s]
100%|██████████| 41/41 [00:25<00:00,  1.61it/s]
  0%|          | 0/41 [00:00<?, ?it/s]

--------------------------------------------------
Epoch 13 | Time 48.9105 sec | Valid Loss 0.7308 | Test Loss 0.7778
--------------------------------------------------


100%|██████████| 41/41 [00:21<00:00,  1.91it/s]
100%|██████████| 41/41 [00:25<00:00,  1.60it/s]
  0%|          | 0/41 [00:00<?, ?it/s]

--------------------------------------------------
Epoch 14 | Time 48.9292 sec | Valid Loss 0.7822 | Test Loss 0.7318
--------------------------------------------------


100%|██████████| 41/41 [00:21<00:00,  1.90it/s]
100%|██████████| 41/41 [00:25<00:00,  1.60it/s]


--------------------------------------------------
Epoch 15 | Time 49.0583 sec | Valid Loss 0.7369 | Test Loss 0.7569
--------------------------------------------------
Best epoch: 5
MAE:  0.7358665
Correlation Coefficient:  0.7750612027685916
mult_acc_7:  0.47230320699708456
mult_acc_5:  0.5524781341107872
F1 score all/non0: 0.8181/0.8364 over 686/656
Accuracy all/non0: 0.8192/0.8369
--------------------------------------------------


In [12]:
# Write the model's state_dict to a file in the current working directory.
torch.save(model.state_dict(), "./saved_models_MMIM_mosi.pt")

### Model Load

In [None]:
# Restore the weights saved by the previous cell. The tensors are mapped to
# CPU while unpickling so the checkpoint also loads on a machine without CUDA;
# load_state_dict then copies them into the model's existing parameters.
model.load_state_dict(torch.load("./saved_models_MMIM_mosi.pt", map_location="cpu"))
# Switch to evaluation mode (last expression, so its value is displayed).
model.eval()

In [13]:
# Start from an empty list so `segment_list` exists even if the run below fails.
segment_list = []
# NOTE(review): the recorded output of this cell is a RuntimeError
# ("input.size(-1) must be equal to input_size. Expected 5, got 74"), i.e. a
# feature width seen by the model differs from the one it was built with —
# verify the tester's input data against the training data.
tester = TestMOSI(model)
segment_list, preds, preds_2, preds_7 = tester.start()

  0%|          | 0/69 [00:00<?, ?it/s]


RuntimeError: input.size(-1) must be equal to input_size. Expected 5, got 74

In [None]:
import pickle
# Gold-truth
# Exclusive upper bound of each of the first six sentiment classes, in
# ascending order (bounds are spaced 6/7 apart); anything at or above the
# last bound is 'very positive'.
_SEVEN_CLASS_BOUNDS = (
    (-15/7, 'very negative'),
    (-9/7, 'negative'),
    (-3/7, 'slightly negative'),
    (3/7, 'Neutral'),
    (9/7, 'slightly positive'),
    (15/7, 'positive'),
)


def _binary_sentiment(score):
    """Return 'positive' for a score above zero, 'negative' otherwise."""
    return 'positive' if score > 0 else 'negative'


def _seven_class_sentiment(score):
    """Return the name of the first class whose upper bound exceeds `score`."""
    for upper_bound, class_name in _SEVEN_CLASS_BOUNDS:
        if score < upper_bound:
            return class_name
    return 'very positive'


labels = []
labels_2 = []
labels_7 = []
# NOTE: pickle.load can execute arbitrary code stored in the file; only load
# dataset dumps that come from a trusted source.
with open(f"../datasets/{args.dataset}.pkl", "rb") as handle:
    data = pickle.load(handle)

test_data = data["test"]

for idx, (_features, label, segment) in enumerate(test_data):
    # Predictions must be aligned one-to-one with the gold test samples.
    assert segment_list[idx] == segment

    # The gold score sits at label[0][0]; compare that scalar rather than the
    # whole label container.
    score = label[0][0]
    labels.append(score)
    labels_2.append(_binary_sentiment(score))
    labels_7.append(_seven_class_sentiment(score))

In [None]:
from ipywidgets import interact

@interact
def get_predict_result(idx = range(len(segment_list))):
    """Print the gold and predicted values of the test sample at `idx`.

    Args:
        idx: Position of the sample in the aligned per-sample lists; the
            default (every valid position) is what `interact` receives as
            the set of choices.
    """
    report_rows = (
        ("SEGMENT:", segment_list),
        ("GOLD_VALUE:", labels),
        ("GOLD_BINARY:", labels_2),
        ("GOLD_7_CLASS:", labels_7),
        ("PREDICTED_VALUE:", preds),
        ("PREDICTED_BINARY:", preds_2),
        ("PREDICTED _7_CLASS:", preds_7),
    )
    for caption, per_sample_values in report_rows:
        print(caption, per_sample_values[idx])

In [None]:
# Sanity check: these sequences are combined column-wise in the next cell,
# so all three printed lengths should match.
for per_sample_sequence in (segment_list, labels, preds):
    print(len(per_sample_sequence))

In [None]:
import plotly.express as px
import plotly.subplots as sp
import pandas as pd

# One row per test segment: gold annotations next to the model's predictions.
d = {'segmentID': segment_list, 'labels': labels, 'labels_2': labels_2, 'labels_7': labels_7, 'preds': preds, 'preds_2': preds_2, 'preds_7': preds_7}
df = pd.DataFrame(data=d)
# Display order of the seven classes on the x axis, most negative first.
order = ['very negative', 'negative', 'slightly negative', 'Neutral', 'slightly positive', 'positive', 'very positive']

# Plotly Express builds the bar traces; they are then copied into one
# two-panel figure so gold and predicted distributions sit side by side.
fig1 = px.bar(df, x="labels_7")
fig2 = px.bar(df, x="preds_7")

this_figure = sp.make_subplots(rows=1, cols=2, subplot_titles=("Gold", "MIM"))
for panel_col, source_figure in enumerate((fig1, fig2), start=1):
    for bar_trace in source_figure.data:
        # add_trace(row=, col=) is the supported replacement for the
        # deprecated append_trace.
        this_figure.add_trace(bar_trace, row=1, col=panel_col)

this_figure.update_layout(height=600, width=1500, title_text="CMU-MOSI 7 Class Sentiment Intensity")
this_figure.update_xaxes(categoryorder='array', categoryarray= order)
# Same fixed y range on both panels so bar heights are directly comparable.
this_figure.update_yaxes(range=[0,250])
this_figure.show()