In [1]:
import itertools
import os
import pickle as pkl

import ipywidgets as widgets
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.figure_factory as ff
import plotly.subplots as sp
from ipywidgets import interact, interact_manual
from tqdm import tqdm

In [26]:
# To observe data samples

def get_samples(data):
    """Open an interactive browser over the samples held in ``data``.

    ``data`` is indexed with the keys 'text', 'vision', 'audio', 'labels'
    and 'id'.  One selectable index is offered per entry of
    ``data['text']``.  For the selected index the ``.shape`` of the text,
    vision and audio entries is printed, followed by the label and the id.

    The function returns nothing; its only effect is the widget created by
    ``interact``.
    """
    sample_indices = range(len(data['text']))

    def get_sample(idx=sample_indices):
        # The three modality entries are reported by shape only.
        for tag, key in (("TEXT", 'text'), ("VISION", 'vision'), ("AUDIO", 'audio')):
            print(tag, data[key][idx].shape)
        print("LABEL", data['labels'][idx])
        print("ID", data['id'][idx])

    # Same effect as decorating get_sample with @interact.
    interact(get_sample)

In [27]:
# To observe total count distribution by class

def barplot_label(data, title="Ground-truth distribution", y_range=None):
    """Plot how many segments fall into each sentiment class.

    Two bar charts are drawn side by side in one figure:

    * "Binary": counts per polarity group ('neg', 'neutral', 'pos').
    * "7-class": counts per score after clamping to [-3, 3] and rounding
      with the built-in ``round``.

    Parameters
    ----------
    data : mapping
        Must provide ``data['labels']`` and ``data['id']``.  Each label is
        read as ``label[0][0]``; ``data['id']`` is indexed with the same
        position as the label.
    title : str, optional
        Title of the whole figure.
    y_range : sequence of two numbers, optional
        When given, it is applied as the range of every y-axis, which keeps
        figures for different splits visually comparable.  ``None`` leaves
        the y-axes auto-scaled.

    Returns
    -------
    None
        The figure is displayed through its ``show()`` method.
    """
    label_df = _label_frame(data)

    order = [-3, -2, -1, 0, 1, 2, 3]

    # count() gives the number of rows per group; the 'segment' column is
    # used as the bar height below.
    df_binary = label_df.groupby('binary').count().reset_index()
    df_7class = label_df.groupby('7class').count().reset_index()

    fig_binary = px.bar(df_binary, x='binary', y='segment')
    fig_7class = px.bar(df_7class, x='7class', y='segment')

    this_figure = sp.make_subplots(rows=1, cols=2, subplot_titles=("Binary", "7-class"))
    # add_trace(row=, col=) replaces the deprecated append_trace.
    for col, source in enumerate((fig_binary, fig_7class), start=1):
        for trace in source.data:
            this_figure.add_trace(trace, row=1, col=col)

    this_figure.update_xaxes(categoryorder='array', categoryarray=order)
    this_figure.update_layout(height=500, width=1200, title_text=title)
    if y_range is not None:
        this_figure.update_yaxes(range=y_range)
    this_figure.show()


def _label_frame(data):
    """Build one row per segment: clamped score, polarity, rounded class, segment id."""
    rows = []
    for idx, label in enumerate(data['labels']):
        score = label[0][0]

        # Clamp the score to the [-3, 3] interval.
        if score > 3.:
            score = 3
        elif score < -3.:
            score = -3

        # Exactly zero is its own group; everything else is split by sign.
        if score == 0.:
            polarity = 'neutral'
        elif score > 0.:
            polarity = 'pos'
        else:
            polarity = 'neg'

        rows.append([score, polarity, round(score), data['id'][idx]])

    return pd.DataFrame(rows, columns=['values', 'binary', '7class', 'segment'])

In [17]:
# Directory that holds the pre-processed dataset pickles.  It can be
# overridden with the MULTIMODAL_DATA_DIR environment variable; the default is
# the original location.  Joining with "" guarantees a trailing path
# separator, which the loading cells rely on because they build file paths by
# string concatenation (data_path + "<file name>").
data_path = os.path.join(
    os.environ.get("MULTIMODAL_DATA_DIR", "/data1/multimodal/processed_data/"), ""
)

## MOSI aligned

In [37]:
# Load the aligned MOSI pickle into `mosi`.
# NOTE: pickle.load can execute arbitrary code stored in the file, so
# data_path must only point at trusted data.
with open(data_path + "mosi_data.pkl", 'rb') as f:
    mosi = pkl.load(f)

In [19]:
# Pull the train / valid / test splits out of the loaded aligned MOSI data.
mosi_train, mosi_valid, mosi_test = (mosi[split] for split in ('train', 'valid', 'test'))

In [21]:
# Interactively browse individual samples of the aligned MOSI test split.
get_samples(mosi_test)

interactive(children=(Dropdown(description='idx', options=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 1…

In [23]:
# Label distribution of the aligned MOSI training split.
barplot_label(mosi_train)

In [25]:
# Label distribution of the aligned MOSI test split.
barplot_label(mosi_test)

## MOSEI aligned

In [29]:
# Load the aligned MOSEI pickle into `mosei`.
# NOTE: pickle.load can execute arbitrary code stored in the file, so
# data_path must only point at trusted data.
with open(data_path + "mosei_senti_data.pkl", 'rb') as f:
    mosei = pkl.load(f)

In [30]:
# Pull the train / valid / test splits out of the loaded aligned MOSEI data.
mosei_train, mosei_valid, mosei_test = (mosei[split] for split in ('train', 'valid', 'test'))

In [32]:
# Interactively browse individual samples of the aligned MOSEI test split.
get_samples(mosei_test)

interactive(children=(Dropdown(description='idx', options=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 1…

In [33]:
# Label distribution of the aligned MOSEI training split.
barplot_label(mosei_train)

In [34]:
# Label distribution of the aligned MOSEI test split.
barplot_label(mosei_test)

## MOSI unaligned

In [42]:
# Load the unaligned MOSI pickle into `mosi_noalign`.
# NOTE: pickle.load can execute arbitrary code stored in the file, so
# data_path must only point at trusted data.
with open(data_path + "mosi_data_noalign.pkl", 'rb') as f:
    mosi_noalign = pkl.load(f)

In [38]:
# Pull the train / valid / test splits out of the loaded unaligned MOSI data.
mosi_noalign_train, mosi_noalign_valid, mosi_noalign_test = (
    mosi_noalign[split] for split in ('train', 'valid', 'test')
)

In [39]:
# Interactively browse individual samples of the unaligned MOSI test split.
get_samples(mosi_noalign_test)

interactive(children=(Dropdown(description='idx', options=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 1…

## MOSEI unaligned

In [43]:
# Load the unaligned MOSEI pickle into `mosei_noalign`.
# NOTE: pickle.load can execute arbitrary code stored in the file, so
# data_path must only point at trusted data.
with open(data_path + "mosei_senti_data_noalign.pkl", 'rb') as f:
    mosei_noalign = pkl.load(f)

In [44]:
# Pull the train / valid / test splits out of the loaded unaligned MOSEI data.
mosei_noalign_train, mosei_noalign_valid, mosei_noalign_test = (
    mosei_noalign[split] for split in ('train', 'valid', 'test')
)

In [45]:
# Interactively browse individual samples of the unaligned MOSEI test split.
get_samples(mosei_noalign_test)

interactive(children=(Dropdown(description='idx', options=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 1…