In [1]:
from __future__ import absolute_import, division, print_function

import csv
import os
import random
import pickle
import sys
import numpy as np
from typing import *

from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix
from sklearn.metrics import precision_recall_fscore_support
from sklearn.metrics import accuracy_score, f1_score

import wandb
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange

from torch.nn import CrossEntropyLoss, L1Loss, MSELoss
from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import matthews_corrcoef
from transformers import BertTokenizer, XLNetTokenizer, BertForSequenceClassification, get_linear_schedule_with_warmup
from transformers.models.bert.configuration_bert import BertConfig
from transformers.optimization import AdamW
from bert import MAG_BertForSequenceClassification
from xlnet import MAG_XLNetForSequenceClassification

from argparse_utils import str2bool, seed
from global_configs import ACOUSTIC_DIM, VISUAL_DIM, DEVICE

In [2]:
import easydict

args = easydict.EasyDict({
    "dataset": "mosi",
    "max_seq_length": 50,
    "train_batch_size": 48,
    "dev_batch_size" : 128,
    "test_batch_size": 128,
    "n_epochs": 40,
    "beta_shift": 1.0,
    "dropout_prob": 0.5,
    "model": "bert-base-uncased",
    "learning_rate": 1e-5,
    "gradient_accumulation_step": 1,
    "warmup_proportion": 0.1,
    "seed": seed("random")
})

In [3]:
class InputFeatures(object):
    """A single set of features of data."""

    def __init__(self, input_ids, visual, acoustic, input_mask, segment_ids, label_id):
        self.input_ids = input_ids
        self.visual = visual
        self.acoustic = acoustic
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.label_id = label_id

In [4]:
class MultimodalConfig(object):
    def __init__(self, beta_shift, dropout_prob):
        self.beta_shift = beta_shift
        self.dropout_prob = dropout_prob

In [5]:
def convert_to_features(examples, max_seq_length, tokenizer):
    features = []

    for (ex_index, example) in enumerate(examples):

        (words, visual, acoustic), label_id, segment = example

        tokens, inversions = [], []
        for idx, word in enumerate(words):
            tokenized = tokenizer.tokenize(word)
            tokens.extend(tokenized)
            inversions.extend([idx] * len(tokenized))

        # Check inversion
        assert len(tokens) == len(inversions)

        aligned_visual = []
        aligned_audio = []

        for inv_idx in inversions:
            aligned_visual.append(visual[inv_idx, :])
            aligned_audio.append(acoustic[inv_idx, :])

        visual = np.array(aligned_visual)
        acoustic = np.array(aligned_audio)

        # Truncate input if necessary
        if len(tokens) > max_seq_length - 2:
            tokens = tokens[: max_seq_length - 2]
            acoustic = acoustic[: max_seq_length - 2]
            visual = visual[: max_seq_length - 2]

        if args.model == "bert-base-uncased":
            prepare_input = prepare_bert_input
        elif args.model == "xlnet-base-cased":
            prepare_input = prepare_xlnet_input

        input_ids, visual, acoustic, input_mask, segment_ids = prepare_input(
            tokens, visual, acoustic, tokenizer
        )

        # Check input length
        assert len(input_ids) == args.max_seq_length
        assert len(input_mask) == args.max_seq_length
        assert len(segment_ids) == args.max_seq_length
        assert acoustic.shape[0] == args.max_seq_length
        assert visual.shape[0] == args.max_seq_length

        features.append(
            InputFeatures(
                input_ids=input_ids,
                input_mask=input_mask,
                segment_ids=segment_ids,
                visual=visual,
                acoustic=acoustic,
                label_id=label_id,
            )
        )
    return features


def prepare_bert_input(tokens, visual, acoustic, tokenizer):
    CLS = tokenizer.cls_token
    SEP = tokenizer.sep_token
    tokens = [CLS] + tokens + [SEP]

    # Pad zero vectors for acoustic / visual vectors to account for [CLS] / [SEP] tokens
    acoustic_zero = np.zeros((1, ACOUSTIC_DIM))
    acoustic = np.concatenate((acoustic_zero, acoustic, acoustic_zero))
    visual_zero = np.zeros((1, VISUAL_DIM))
    visual = np.concatenate((visual_zero, visual, visual_zero))

    input_ids = tokenizer.convert_tokens_to_ids(tokens)
    segment_ids = [0] * len(input_ids)
    input_mask = [1] * len(input_ids)

    pad_length = args.max_seq_length - len(input_ids)

    acoustic_padding = np.zeros((pad_length, ACOUSTIC_DIM))
    acoustic = np.concatenate((acoustic, acoustic_padding))

    visual_padding = np.zeros((pad_length, VISUAL_DIM))
    visual = np.concatenate((visual, visual_padding))

    padding = [0] * pad_length

    # Pad inputs
    input_ids += padding
    input_mask += padding
    segment_ids += padding

    return input_ids, visual, acoustic, input_mask, segment_ids


def prepare_xlnet_input(tokens, visual, acoustic, tokenizer):
    CLS = tokenizer.cls_token
    SEP = tokenizer.sep_token
    PAD_ID = tokenizer.pad_token_id

    # PAD special tokens
    tokens = tokens + [SEP] + [CLS]
    audio_zero = np.zeros((1, ACOUSTIC_DIM))
    acoustic = np.concatenate((acoustic, audio_zero, audio_zero))
    visual_zero = np.zeros((1, VISUAL_DIM))
    visual = np.concatenate((visual, visual_zero, visual_zero))

    input_ids = tokenizer.convert_tokens_to_ids(tokens)
    input_mask = [1] * len(input_ids)
    segment_ids = [0] * (len(tokens) - 1) + [2]

    pad_length = (args.max_seq_length - len(segment_ids))

    # then zero pad the visual and acoustic
    audio_padding = np.zeros((pad_length, ACOUSTIC_DIM))
    acoustic = np.concatenate((audio_padding, acoustic))

    video_padding = np.zeros((pad_length, VISUAL_DIM))
    visual = np.concatenate((video_padding, visual))

    input_ids = [PAD_ID] * pad_length + input_ids
    input_mask = [0] * pad_length + input_mask
    segment_ids = [3] * pad_length + segment_ids

    return input_ids, visual, acoustic, input_mask, segment_ids


def get_tokenizer(model):
    if model == "bert-base-uncased":
        return BertTokenizer.from_pretrained(model)
    elif model == "xlnet-base-cased":
        return XLNetTokenizer.from_pretrained(model)
    else:
        raise ValueError(
            "Expected 'bert-base-uncased' or 'xlnet-base-cased, but received {}".format(
                model
            )
        )


def get_appropriate_dataset(data):

    tokenizer = get_tokenizer(args.model)

    features = convert_to_features(data, args.max_seq_length, tokenizer)
    all_input_ids = torch.tensor(
        [f.input_ids for f in features], dtype=torch.long)
    all_input_mask = torch.tensor(
        [f.input_mask for f in features], dtype=torch.long)
    all_segment_ids = torch.tensor(
        [f.segment_ids for f in features], dtype=torch.long)
    all_visual = torch.tensor([f.visual for f in features], dtype=torch.float)
    all_acoustic = torch.tensor(
        [f.acoustic for f in features], dtype=torch.float)
    all_label_ids = torch.tensor(
        [f.label_id for f in features], dtype=torch.float)

    dataset = TensorDataset(
        all_input_ids,
        all_visual,
        all_acoustic,
        all_input_mask,
        all_segment_ids,
        all_label_ids,
    )
    return dataset, tokenizer


def set_up_data_loader():
    with open(f"../datasets/{args.dataset}.pkl", "rb") as handle:
        data = pickle.load(handle)

    train_data = data["train"]
    dev_data = data["dev"]
    test_data = data["test"]

    train_dataset, train_tokenizer = get_appropriate_dataset(train_data)
    dev_dataset, dev_tokenizer = get_appropriate_dataset(dev_data)
    test_dataset, test_tokenizer = get_appropriate_dataset(test_data)

    num_train_optimization_steps = (
        int(
            len(train_dataset) / args.train_batch_size /
            args.gradient_accumulation_step
        )
        * args.n_epochs
    )

    train_dataloader = DataLoader(
        train_dataset, batch_size=args.train_batch_size, shuffle=True
    )

    dev_dataloader = DataLoader(
        dev_dataset, batch_size=args.dev_batch_size, shuffle=True
    )

    test_dataloader = DataLoader(
        test_dataset, batch_size=args.test_batch_size, shuffle=True,
    )

    return (
        train_dataloader,
        dev_dataloader,
        test_dataloader,
        num_train_optimization_steps,
        train_tokenizer,
        dev_tokenizer,
        test_tokenizer
    )


def set_random_seed(seed: int):
    """
    Helper function to seed experiment for reproducibility.
    If -1 is provided as seed, experiment uses random seed from 0~9999

    Args:
        seed (int): integer to be used as seed, use -1 to randomly seed experiment
    """
    print("Seed: {}".format(seed))

    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.enabled = False
    torch.backends.cudnn.deterministic = True

    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def prep_for_training(num_train_optimization_steps: int):
    multimodal_config = MultimodalConfig(
        beta_shift=args.beta_shift, dropout_prob=args.dropout_prob
    )
    bert_config = BertConfig(
        hidden_dropout_prob=args.dropout_prob
    )

    if args.model == "bert-base-uncased":
        model = MAG_BertForSequenceClassification.from_pretrained(
            args.model, multimodal_config=multimodal_config, num_labels=1,
        )
        # model = BertForSequenceClassification.from_pretrained(
        #     args.model,
        #     num_labels = 1
        # )
    elif args.model == "xlnet-base-cased":
        model = MAG_XLNetForSequenceClassification.from_pretrained(
            args.model, multimodal_config=multimodal_config, num_labels=1
        )

    model.to(DEVICE)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in param_optimizer if not any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.01,
        },
        {
            "params": [
                p for n, p in param_optimizer if any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.0,
        },
    ]

    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=num_train_optimization_steps,
        num_training_steps=args.warmup_proportion * num_train_optimization_steps,
    )
    return model, optimizer, scheduler


def train_epoch(model: nn.Module, train_dataloader: DataLoader, optimizer, scheduler):
    model.train()
    tr_loss = 0
    nb_tr_examples, nb_tr_steps = 0, 0
    for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
        batch = tuple(t.to(DEVICE) for t in batch)
        input_ids, visual, acoustic, input_mask, segment_ids, label_ids = batch
        visual = torch.squeeze(visual, 1)
        acoustic = torch.squeeze(acoustic, 1)
        model.zero_grad()
        outputs = model(
            input_ids,
            visual,
            acoustic,
            token_type_ids=segment_ids,
            attention_mask=input_mask,
            labels=None
        )

        logits = outputs[0]
        loss_fct = MSELoss()
        loss = loss_fct(logits.view(-1), label_ids.view(-1))

        if args.gradient_accumulation_step > 1:
            loss = loss / args.gradient_accumulation_step

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        tr_loss += loss.item()
        nb_tr_steps += 1

        if (step + 1) % args.gradient_accumulation_step == 0:
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

    return tr_loss / nb_tr_steps


def eval_epoch(model: nn.Module, dev_dataloader: DataLoader, optimizer):
    model.eval()
    dev_loss = 0
    nb_dev_examples, nb_dev_steps = 0, 0
    with torch.no_grad():
        for step, batch in enumerate(tqdm(dev_dataloader, desc="Iteration")):
            batch = tuple(t.to(DEVICE) for t in batch)

            input_ids, visual, acoustic, input_mask, segment_ids, label_ids = batch
            visual = torch.squeeze(visual, 1)
            acoustic = torch.squeeze(acoustic, 1)
            outputs = model(
                input_ids,
                visual,
                acoustic,
                token_type_ids=segment_ids,
                attention_mask=input_mask,
                labels=None
            )
            logits = outputs[0]

            loss_fct = MSELoss()
            loss = loss_fct(logits.view(-1), label_ids.view(-1))

            if args.gradient_accumulation_step > 1:
                loss = loss / args.gradient_accumulation_step

            dev_loss += loss.item()
            nb_dev_steps += 1

    return dev_loss / nb_dev_steps


def test_epoch(model: nn.Module, test_dataloader: DataLoader, tokenizer):
    model.eval()
    preds = []
    labels = []
    
    with torch.no_grad():
        for i, batch in enumerate(tqdm(test_dataloader)):
            batch = tuple(t.to(DEVICE) for t in batch)

            input_ids, visual, acoustic, input_mask, segment_ids, label_ids = batch
            visual = torch.squeeze(visual, 1)
            acoustic = torch.squeeze(acoustic, 1)
            outputs = model(
                input_ids,
                visual,
                acoustic,
                token_type_ids=segment_ids,
                attention_mask=input_mask,
                labels=None
            )

            logits = outputs[0]

            logits = logits.detach().cpu().numpy()
            label_ids = label_ids.detach().cpu().numpy()

            logits = np.squeeze(logits).tolist()
            label_ids = np.squeeze(label_ids).tolist()

            preds.extend(logits)
            labels.extend(label_ids)

            # print(i, " th batch")
            # for i, s in enumerate(input_ids):
            #     tokens = tokenizer.convert_ids_to_tokens(s, skip_special_tokens = True)
            #     print(tokens, logits[i], label_ids[i])

        preds = np.array(preds)
        labels = np.array(labels)

        # for s in input_ids:
        #     tokens = tokenizer.convert_ids_to_tokens(s)
        #     print(tokens)
        # print(preds)
        # print(labels)
        # print("--------------------------------------------------------------------")

    return preds, labels


def test_score_model(model: nn.Module, test_dataloader: DataLoader, tokenizer, use_zero=False):

    preds, y_test = test_epoch(model, test_dataloader, tokenizer)
    non_zeros = np.array(
        [i for i, e in enumerate(y_test) if e != 0 or use_zero])

    preds = preds[non_zeros]
    y_test = y_test[non_zeros]

    mae = np.mean(np.absolute(preds - y_test))
    corr = np.corrcoef(preds, y_test)[0][1]

    preds = preds >= 0
    y_test = y_test >= 0

    f_score = f1_score(y_test, preds, average="weighted")
    acc = accuracy_score(y_test, preds)

    return acc, mae, corr, f_score


def test_instance(model: nn.Module, test_tokenizer):
    model.eval()
    segment_list = []
    words_list = []
    preds = []
    preds_2 = []
    preds_7 = []
    labels = []
    labels_2 = []
    labels_7 = []

    with open(f"../datasets/{args.dataset}.pkl", "rb") as handle:
        data = pickle.load(handle)

    # test_data[idx] = (words, visual, acoustic), label, segment
    test_data = data["test"]
    test_dataset, test_tokenizer = get_appropriate_dataset(test_data)
    test_dataloader = DataLoader(
        test_dataset, batch_size=args.test_batch_size, shuffle=False,
    )

    video = set()
    count = 0

    for idx in range(len(test_data)):
        (words, visual, acoustic), label, segment = test_data[idx]
        if args.dataset == 'mosi':
            segment_list.append(segment)
        else:
            video_name = segment[0]
            if video_name in video:
                count += 1
            else:
                video.add(video_name)
                count = 0
            segment_list.append(video_name + '[' + str(count) + ']')

        words_list.append(words)
        labels.append(label[0][0])

        # label_2 appending
        if label > 0:
            labels_2.append('positive')
        else:
            labels_2.append('negative')
        
        # label_7 appending
        if label < -15/7:
            labels_7.append('very negative')
        elif label < -9/7:
            labels_7.append('negative')
        elif label < -3/7:
            labels_7.append('slightly negative')
        elif label < 3/7:
            labels_7.append('Neutral')
        elif label < 9/7:
            labels_7.append('slightly positive')
        elif label < 15/7:
            labels_7.append('positive')
        else:
            labels_7.append('very positive')
            
    # prediction
    with torch.no_grad():
        for i, batch in enumerate(tqdm(test_dataloader)):
            batch = tuple(t.to(DEVICE) for t in batch)

            input_ids, visual, acoustic, input_mask, segment_ids, label_ids = batch
            visual = torch.squeeze(visual, 1)
            acoustic = torch.squeeze(acoustic, 1)
            outputs = model(
                input_ids,
                visual,
                acoustic,
                token_type_ids=segment_ids,
                attention_mask=input_mask,
                labels=None
            )
            logits = outputs[0]

            logits = logits.detach().cpu().numpy()
            label_ids = label_ids.detach().cpu().numpy()

            logits = np.squeeze(logits).tolist()
            label_ids = np.squeeze(label_ids).tolist()

            preds.extend(logits)

            for logit in logits:
                # preds_2 appending
                if logit > 0:
                    preds_2.append('positive')
                else:
                    preds_2.append('negative')

                # label_7 appending
                if logit < -15/7:
                    preds_7.append('very negative')
                elif logit < -9/7:
                    preds_7.append('negative')
                elif logit < -3/7:
                    preds_7.append('slightly negative')
                elif logit < 3/7:
                    preds_7.append('Neutral')
                elif logit < 9/7:
                    preds_7.append('slightly positive')
                elif logit < 15/7:
                    preds_7.append('positive')
                else:
                    preds_7.append('very positive')

            

    count = 0
    for i in range(len(segment_list)):
        print(i, "th data")
        print(segment_list[i])
        print(words_list[i])
        print(labels[i])
        print(labels_2[i])
        print(labels_7[i])
        print(preds[i])
        print(preds_2[i])
        print(preds_7[i])


def train(
    model,
    train_dataloader,
    validation_dataloader,
    test_data_loader,
    optimizer,
    scheduler,
    tokenizer
):
    valid_losses = []
    test_accuracies = []

    for epoch_i in range(int(args.n_epochs)):
        train_loss = train_epoch(model, train_dataloader, optimizer, scheduler)
        valid_loss = eval_epoch(model, validation_dataloader, optimizer)
        test_acc, test_mae, test_corr, test_f_score = test_score_model(
            model, test_data_loader, tokenizer
        )

        print(
            "epoch:{}, train_loss:{}, valid_loss:{}, test_acc:{}".format(
                epoch_i, train_loss, valid_loss, test_acc
            )
        )

        valid_losses.append(valid_loss)
        test_accuracies.append(test_acc)
    
    print("Total Result:")
    print("best_accuracy: ", sorted(test_accuracies)[-1])
    print("best loss: ", sorted(valid_losses)[0])
    
    return model


In [6]:
set_random_seed(args.seed)

Seed: 2949


In [7]:
(
    train_data_loader,
    dev_data_loader,
    test_data_loader,
    num_train_optimization_steps,
    train_tokenizer,
    dev_tokenizer,
    test_tokenizer
) = set_up_data_loader()

model, optimizer, scheduler = prep_for_training(
    num_train_optimization_steps)

model = train(
    model,
    train_data_loader,
    dev_data_loader,
    test_data_loader,
    optimizer,
    scheduler,
    test_tokenizer
)

Initializing MAG with beta_shift:1.0 hidden_prob:0.5


Some weights of the model checkpoint at bert-base-uncased were not used when initializing MAG_BertForSequenceClassification: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing MAG_BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing MAG_BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of MAG_BertForSequenceClassification were not initialized from the mod

epoch:0, train_loss:2.312387563564159, valid_loss:2.760565757751465, test_acc:0.47022900763358777


Iteration: 100%|██████████| 27/27 [00:02<00:00, 12.88it/s]
Iteration: 100%|██████████| 2/2 [00:00<00:00, 35.38it/s]
100%|██████████| 6/6 [00:00<00:00, 35.04it/s]
Iteration:   7%|▋         | 2/27 [00:00<00:01, 13.08it/s]

epoch:1, train_loss:2.2905799636134394, valid_loss:2.7379735708236694, test_acc:0.46564885496183206


Iteration: 100%|██████████| 27/27 [00:02<00:00, 12.92it/s]
Iteration: 100%|██████████| 2/2 [00:00<00:00, 35.01it/s]
100%|██████████| 6/6 [00:00<00:00, 34.83it/s]
Iteration:   7%|▋         | 2/27 [00:00<00:01, 12.80it/s]

epoch:2, train_loss:2.2826085267243563, valid_loss:2.7523261308670044, test_acc:0.45648854961832064


Iteration: 100%|██████████| 27/27 [00:02<00:00, 12.26it/s]
Iteration: 100%|██████████| 2/2 [00:00<00:00, 34.98it/s]
100%|██████████| 6/6 [00:00<00:00, 34.59it/s]
Iteration:   7%|▋         | 2/27 [00:00<00:02, 12.44it/s]

epoch:3, train_loss:2.297285887930128, valid_loss:2.722131133079529, test_acc:0.4549618320610687


Iteration: 100%|██████████| 27/27 [00:02<00:00, 12.72it/s]
Iteration: 100%|██████████| 2/2 [00:00<00:00, 33.11it/s]
100%|██████████| 6/6 [00:00<00:00, 34.74it/s]
Iteration:   7%|▋         | 2/27 [00:00<00:01, 12.86it/s]

epoch:4, train_loss:2.274548711600127, valid_loss:2.707979202270508, test_acc:0.46412213740458014


Iteration: 100%|██████████| 27/27 [00:02<00:00, 13.00it/s]
Iteration: 100%|██████████| 2/2 [00:00<00:00, 35.28it/s]
100%|██████████| 6/6 [00:00<00:00, 33.91it/s]
Iteration:   7%|▋         | 2/27 [00:00<00:02, 11.84it/s]

epoch:5, train_loss:2.252543149171052, valid_loss:2.6691495180130005, test_acc:0.48854961832061067


Iteration: 100%|██████████| 27/27 [00:02<00:00, 12.41it/s]
Iteration: 100%|██████████| 2/2 [00:00<00:00, 34.95it/s]
100%|██████████| 6/6 [00:00<00:00, 34.54it/s]
Iteration:   7%|▋         | 2/27 [00:00<00:01, 12.78it/s]

epoch:6, train_loss:2.248262634983769, valid_loss:2.6556795835494995, test_acc:0.549618320610687


Iteration: 100%|██████████| 27/27 [00:02<00:00, 12.75it/s]
Iteration: 100%|██████████| 2/2 [00:00<00:00, 35.07it/s]
100%|██████████| 6/6 [00:00<00:00, 34.47it/s]
Iteration:   7%|▋         | 2/27 [00:00<00:02, 12.27it/s]

epoch:7, train_loss:2.185562844629641, valid_loss:2.570175886154175, test_acc:0.5923664122137404


Iteration: 100%|██████████| 27/27 [00:02<00:00, 12.71it/s]
Iteration: 100%|██████████| 2/2 [00:00<00:00, 34.24it/s]
100%|██████████| 6/6 [00:00<00:00, 34.78it/s]
Iteration:   7%|▋         | 2/27 [00:00<00:01, 12.83it/s]

epoch:8, train_loss:2.153342498673333, valid_loss:2.5185261964797974, test_acc:0.5893129770992367


Iteration: 100%|██████████| 27/27 [00:02<00:00, 12.99it/s]
Iteration: 100%|██████████| 2/2 [00:00<00:00, 34.71it/s]
100%|██████████| 6/6 [00:00<00:00, 34.75it/s]
Iteration:   7%|▋         | 2/27 [00:00<00:01, 12.88it/s]

epoch:9, train_loss:2.0808859533733792, valid_loss:2.405665636062622, test_acc:0.5923664122137404


Iteration: 100%|██████████| 27/27 [00:02<00:00, 13.04it/s]
Iteration: 100%|██████████| 2/2 [00:00<00:00, 34.66it/s]
100%|██████████| 6/6 [00:00<00:00, 34.77it/s]
Iteration:   7%|▋         | 2/27 [00:00<00:01, 12.94it/s]

epoch:10, train_loss:1.9690306407433968, valid_loss:2.2676409482955933, test_acc:0.5984732824427481


Iteration: 100%|██████████| 27/27 [00:02<00:00, 13.04it/s]
Iteration: 100%|██████████| 2/2 [00:00<00:00, 34.95it/s]
100%|██████████| 6/6 [00:00<00:00, 34.72it/s]
Iteration:   7%|▋         | 2/27 [00:00<00:01, 12.81it/s]

epoch:11, train_loss:1.8319963658297505, valid_loss:2.054684817790985, test_acc:0.6274809160305344


Iteration: 100%|██████████| 27/27 [00:02<00:00, 13.05it/s]
Iteration: 100%|██████████| 2/2 [00:00<00:00, 34.62it/s]
100%|██████████| 6/6 [00:00<00:00, 34.57it/s]
Iteration:   7%|▋         | 2/27 [00:00<00:01, 12.97it/s]

epoch:12, train_loss:1.6378055987534699, valid_loss:1.764476716518402, test_acc:0.7251908396946565


Iteration: 100%|██████████| 27/27 [00:02<00:00, 13.05it/s]
Iteration: 100%|██████████| 2/2 [00:00<00:00, 34.77it/s]
100%|██████████| 6/6 [00:00<00:00, 34.70it/s]
Iteration:   7%|▋         | 2/27 [00:00<00:01, 12.92it/s]

epoch:13, train_loss:1.4116337255195335, valid_loss:1.523555040359497, test_acc:0.7954198473282442


Iteration: 100%|██████████| 27/27 [00:02<00:00, 13.06it/s]
Iteration: 100%|██████████| 2/2 [00:00<00:00, 34.70it/s]
100%|██████████| 6/6 [00:00<00:00, 34.68it/s]
Iteration:   7%|▋         | 2/27 [00:00<00:01, 12.97it/s]

epoch:14, train_loss:1.2695813046561346, valid_loss:1.4019355177879333, test_acc:0.815267175572519


Iteration: 100%|██████████| 27/27 [00:02<00:00, 12.96it/s]
Iteration: 100%|██████████| 2/2 [00:00<00:00, 34.73it/s]
100%|██████████| 6/6 [00:00<00:00, 34.43it/s]
Iteration:   7%|▋         | 2/27 [00:00<00:02, 12.31it/s]

epoch:15, train_loss:1.269033913259153, valid_loss:1.3111392259597778, test_acc:0.8183206106870229


Iteration: 100%|██████████| 27/27 [00:02<00:00, 12.91it/s]
Iteration: 100%|██████████| 2/2 [00:00<00:00, 34.67it/s]
100%|██████████| 6/6 [00:00<00:00, 34.66it/s]
Iteration:   7%|▋         | 2/27 [00:00<00:01, 12.96it/s]

epoch:16, train_loss:1.1556970521255776, valid_loss:1.2672472596168518, test_acc:0.8213740458015267


Iteration: 100%|██████████| 27/27 [00:02<00:00, 13.04it/s]
Iteration: 100%|██████████| 2/2 [00:00<00:00, 34.76it/s]
100%|██████████| 6/6 [00:00<00:00, 34.64it/s]
Iteration:   7%|▋         | 2/27 [00:00<00:01, 12.88it/s]

epoch:17, train_loss:1.0562519740175318, valid_loss:1.2886595129966736, test_acc:0.815267175572519


Iteration: 100%|██████████| 27/27 [00:02<00:00, 12.77it/s]
Iteration: 100%|██████████| 2/2 [00:00<00:00, 35.10it/s]
100%|██████████| 6/6 [00:00<00:00, 34.62it/s]
Iteration:   7%|▋         | 2/27 [00:00<00:01, 13.01it/s]

epoch:18, train_loss:1.018346987388752, valid_loss:1.2306369543075562, test_acc:0.8183206106870229


Iteration: 100%|██████████| 27/27 [00:02<00:00, 13.06it/s]
Iteration: 100%|██████████| 2/2 [00:00<00:00, 34.45it/s]
100%|██████████| 6/6 [00:00<00:00, 34.53it/s]
Iteration:   7%|▋         | 2/27 [00:00<00:01, 12.66it/s]

epoch:19, train_loss:0.9539690017700195, valid_loss:1.2030708193778992, test_acc:0.8244274809160306


Iteration: 100%|██████████| 27/27 [00:02<00:00, 12.69it/s]
Iteration: 100%|██████████| 2/2 [00:00<00:00, 35.44it/s]
100%|██████████| 6/6 [00:00<00:00, 35.06it/s]
Iteration:   7%|▋         | 2/27 [00:00<00:01, 13.25it/s]

epoch:20, train_loss:0.8786712333008095, valid_loss:1.160897672176361, test_acc:0.8229007633587786


Iteration: 100%|██████████| 27/27 [00:02<00:00, 13.03it/s]
Iteration: 100%|██████████| 2/2 [00:00<00:00, 34.55it/s]
100%|██████████| 6/6 [00:00<00:00, 34.33it/s]
Iteration:   7%|▋         | 2/27 [00:00<00:01, 12.85it/s]

epoch:21, train_loss:0.8501026122658341, valid_loss:1.1630886793136597, test_acc:0.8183206106870229


Iteration: 100%|██████████| 27/27 [00:02<00:00, 13.05it/s]
Iteration: 100%|██████████| 2/2 [00:00<00:00, 34.61it/s]
100%|██████████| 6/6 [00:00<00:00, 34.51it/s]
Iteration:   7%|▋         | 2/27 [00:00<00:02, 12.48it/s]

epoch:22, train_loss:0.808682797131715, valid_loss:1.1121562123298645, test_acc:0.8305343511450382


Iteration: 100%|██████████| 27/27 [00:02<00:00, 12.74it/s]
Iteration: 100%|██████████| 2/2 [00:00<00:00, 33.59it/s]
100%|██████████| 6/6 [00:00<00:00, 33.41it/s]
Iteration:   7%|▋         | 2/27 [00:00<00:01, 12.68it/s]

epoch:23, train_loss:0.7770996424886916, valid_loss:1.1341722011566162, test_acc:0.8198473282442749


Iteration: 100%|██████████| 27/27 [00:02<00:00, 12.62it/s]
Iteration: 100%|██████████| 2/2 [00:00<00:00, 33.90it/s]
100%|██████████| 6/6 [00:00<00:00, 34.17it/s]
Iteration:   7%|▋         | 2/27 [00:00<00:01, 12.86it/s]

epoch:24, train_loss:0.7155413439980259, valid_loss:1.1064651608467102, test_acc:0.8366412213740458


Iteration: 100%|██████████| 27/27 [00:02<00:00, 13.04it/s]
Iteration: 100%|██████████| 2/2 [00:00<00:00, 34.28it/s]
100%|██████████| 6/6 [00:00<00:00, 34.39it/s]
Iteration:   7%|▋         | 2/27 [00:00<00:01, 12.91it/s]

epoch:25, train_loss:0.6986485057406955, valid_loss:1.154703974723816, test_acc:0.8366412213740458


Iteration: 100%|██████████| 27/27 [00:02<00:00, 13.06it/s]
Iteration: 100%|██████████| 2/2 [00:00<00:00, 34.68it/s]
100%|██████████| 6/6 [00:00<00:00, 34.52it/s]
Iteration:   7%|▋         | 2/27 [00:00<00:01, 12.93it/s]

epoch:26, train_loss:0.6514985197120242, valid_loss:1.1779365539550781, test_acc:0.833587786259542


Iteration: 100%|██████████| 27/27 [00:02<00:00, 13.06it/s]
Iteration: 100%|██████████| 2/2 [00:00<00:00, 34.42it/s]
100%|██████████| 6/6 [00:00<00:00, 34.51it/s]
Iteration:   7%|▋         | 2/27 [00:00<00:01, 12.87it/s]

epoch:27, train_loss:0.6149280071258545, valid_loss:1.103727638721466, test_acc:0.8458015267175573


Iteration: 100%|██████████| 27/27 [00:02<00:00, 13.02it/s]
Iteration: 100%|██████████| 2/2 [00:00<00:00, 34.31it/s]
100%|██████████| 6/6 [00:00<00:00, 34.43it/s]
Iteration:   7%|▋         | 2/27 [00:00<00:01, 12.95it/s]

epoch:28, train_loss:0.5374401289003866, valid_loss:1.1724441349506378, test_acc:0.8351145038167939


Iteration: 100%|██████████| 27/27 [00:02<00:00, 12.77it/s]
Iteration: 100%|██████████| 2/2 [00:00<00:00, 34.10it/s]
100%|██████████| 6/6 [00:00<00:00, 34.16it/s]
Iteration:   7%|▋         | 2/27 [00:00<00:01, 12.80it/s]

epoch:29, train_loss:0.5435559661300094, valid_loss:1.15543931722641, test_acc:0.833587786259542


Iteration: 100%|██████████| 27/27 [00:02<00:00, 13.02it/s]
Iteration: 100%|██████████| 2/2 [00:00<00:00, 34.05it/s]
100%|██████████| 6/6 [00:00<00:00, 34.37it/s]
Iteration:   7%|▋         | 2/27 [00:00<00:01, 12.92it/s]

epoch:30, train_loss:0.5155844390392303, valid_loss:1.1258782148361206, test_acc:0.8351145038167939


Iteration: 100%|██████████| 27/27 [00:02<00:00, 13.04it/s]
Iteration: 100%|██████████| 2/2 [00:00<00:00, 34.28it/s]
100%|██████████| 6/6 [00:00<00:00, 34.04it/s]
Iteration:   7%|▋         | 2/27 [00:00<00:01, 12.76it/s]

epoch:31, train_loss:0.516019156685582, valid_loss:1.1507461667060852, test_acc:0.8396946564885496


Iteration: 100%|██████████| 27/27 [00:02<00:00, 13.03it/s]
Iteration: 100%|██████████| 2/2 [00:00<00:00, 34.27it/s]
100%|██████████| 6/6 [00:00<00:00, 34.18it/s]
Iteration:   7%|▋         | 2/27 [00:00<00:01, 12.89it/s]

epoch:32, train_loss:0.44699162355175726, valid_loss:1.220868706703186, test_acc:0.8305343511450382


Iteration: 100%|██████████| 27/27 [00:02<00:00, 12.94it/s]
Iteration: 100%|██████████| 2/2 [00:00<00:00, 34.86it/s]
100%|██████████| 6/6 [00:00<00:00, 34.28it/s]
Iteration:   7%|▋         | 2/27 [00:00<00:01, 12.92it/s]

epoch:33, train_loss:0.4541569837817439, valid_loss:1.1318506002426147, test_acc:0.8366412213740458


Iteration: 100%|██████████| 27/27 [00:02<00:00, 13.02it/s]
Iteration: 100%|██████████| 2/2 [00:00<00:00, 34.68it/s]
100%|██████████| 6/6 [00:00<00:00, 34.27it/s]
Iteration:   7%|▋         | 2/27 [00:00<00:01, 12.84it/s]

epoch:34, train_loss:0.4021299680074056, valid_loss:1.104996144771576, test_acc:0.8351145038167939


Iteration: 100%|██████████| 27/27 [00:02<00:00, 12.64it/s]
Iteration: 100%|██████████| 2/2 [00:00<00:00, 34.71it/s]
100%|██████████| 6/6 [00:00<00:00, 34.36it/s]
Iteration:   7%|▋         | 2/27 [00:00<00:01, 12.85it/s]

epoch:35, train_loss:0.375795422328843, valid_loss:1.1957168579101562, test_acc:0.833587786259542


Iteration: 100%|██████████| 27/27 [00:02<00:00, 12.62it/s]
Iteration: 100%|██████████| 2/2 [00:00<00:00, 35.16it/s]
100%|██████████| 6/6 [00:00<00:00, 34.09it/s]
Iteration:   7%|▋         | 2/27 [00:00<00:01, 12.95it/s]

epoch:36, train_loss:0.3660257994024842, valid_loss:1.1234041452407837, test_acc:0.8320610687022901


Iteration: 100%|██████████| 27/27 [00:02<00:00, 12.94it/s]
Iteration: 100%|██████████| 2/2 [00:00<00:00, 34.51it/s]
100%|██████████| 6/6 [00:00<00:00, 34.38it/s]
Iteration:   7%|▋         | 2/27 [00:00<00:01, 12.93it/s]

epoch:37, train_loss:0.338266560876811, valid_loss:1.1783284842967987, test_acc:0.8412213740458016


Iteration: 100%|██████████| 27/27 [00:02<00:00, 12.98it/s]
Iteration: 100%|██████████| 2/2 [00:00<00:00, 34.23it/s]
100%|██████████| 6/6 [00:00<00:00, 34.48it/s]
Iteration:   7%|▋         | 2/27 [00:00<00:01, 13.02it/s]

epoch:38, train_loss:0.32681266246018587, valid_loss:1.1483153104782104, test_acc:0.8381679389312977


Iteration: 100%|██████████| 27/27 [00:02<00:00, 12.91it/s]
Iteration: 100%|██████████| 2/2 [00:00<00:00, 34.54it/s]
100%|██████████| 6/6 [00:00<00:00, 33.54it/s]

epoch:39, train_loss:0.31838869441438605, valid_loss:1.1747119724750519, test_acc:0.8381679389312977
Total Result:
best_accuracy:  0.8458015267175573
best loss:  1.103727638721466





### Sentiment Intensity Reflection of Fustion Embedding Space

In [13]:
with open(f"../datasets/{args.dataset}.pkl", "rb") as handle:
    data = pickle.load(handle)

train_data = data["train"]
dev_data = data["dev"]
test_data = data["test"]

train_dataset, train_tokenizer = get_appropriate_dataset(train_data)
dev_dataset, dev_tokenizer = get_appropriate_dataset(dev_data)
test_dataset, test_tokenizer = get_appropriate_dataset(test_data)

test_data_loader = DataLoader(test_dataset, batch_size=args.test_batch_size, shuffle=False)

model.eval()

test_embeddings = torch.zeros((0, 100), dtype=torch.float32)
preds = []
labels = []
classes = []
pred_classes = []

# Gold 7-Class
for idx in range(len(test_data)):
    (word, visual, acoustic), label, segment = test_data[idx]
    if label < -15/7:
        classes.append(-3)
    elif label < -9/7:
        classes.append(-2)
    elif label < -3/7:
        classes.append(-1)
    elif label < 3/7:
        classes.append(0)
    elif label < 9/7:
        classes.append(1)
    elif label < 15/7:
        classes.append(2)
    else:
        classes.append(3)
classes = np.array(classes)

# MAG-BERT Model output
with torch.no_grad():
    for i, batch in enumerate(tqdm(test_data_loader)):
        batch = tuple(t.to(DEVICE) for t in batch)

        input_ids, visual, acoustic, input_mask, segment_ids, label_ids = batch
        visual = torch.squeeze(visual, 1)
        acoustic = torch.squeeze(acoustic, 1)
        outputs = model(
            input_ids,
            visual,
            acoustic,
            token_type_ids=segment_ids,
            attention_mask=input_mask,
            labels=None
        )

        logits = outputs[0]
        embeddings = outputs[1:]

        test_embeddings = torch.cat((test_embeddings, embeddings.detach().cpu()), 0)

        logits = logits.detach().cpu().numpy()
        label_ids = label_ids.detach().cpu().numpy()

        preds.extend(np.squeeze(logits).tolist())
        labels.extend(np.squeeze(label_ids).tolist())

        preds = np.array(preds)
        labels = np.array(labels)

        # 7-class prediction
        for logit in logits:
            if logit < -15/7:
                pred_classes.append(-3)
            elif logit < -9/7:
                pred_classes.append(-2)
            elif logit < -3/7:
                pred_classes.append(-1)
            elif logit < 3/7:
                pred_classes.append(0)
            elif logit < 9/7:
                pred_classes.append(1)
            elif logit < 15/7:
                pred_classes.append(2)
            else:
                pred_classes.append(3)
        pred_classes = np.array(pred_classes)
    
    

  0%|          | 0/6 [00:00<?, ?it/s]


AttributeError: 'tuple' object has no attribute 'detach'

In [None]:
# Create a two dimensional t-SNE projection of the embeddings
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import cm
from sklearn.manifold import TSNE

tsne = TSNE(2, verbose=1)
tsne_proj = tsne.fit_transform(test_embeddings)
cmap = cm.get_cmap('tab20')
fig, ax = plt.subplot(figsize=(8,8))
# num_categories = 7
for lab in range(-3, 3):
    indices = pred_classes==lab
    