In [18]:
import os
import re
from typing import Optional
import numpy as np
from subprocess import check_call, CalledProcessError

# path to the folder where you want to store data
os.chdir('./src/datasets/CMU-MultimodalSDK')

# path to the SDK folder
SDK_PATH: Optional[str] = None
# path to the folder where you want to store data
DATA_PATH: Optional[str] = '.src/datasets/CMU-MultimodalSDK/data/'
# path to a pretrained word embedding file
WORD_EMB_PATH: Optional[str] = None
# path to loaded word embedding matrix and corresponding word2id mapping
CACHE_PATH: Optional[str] = '.src/datasets/CMU-MultimodalSDK/data/embedding_and_mapping.pt'

In [19]:
import sys
import mmsdk
from mmsdk import mmdatasdk as md
from subprocess import check_call, CalledProcessError

data_files = os.listdir(DATA_PATH)
print("Downloaded data: ",'\n'.join(data_files))

Downloaded data:  CMU_MOSI_ModifiedTimestampedWords.csd
CMU_MOSI_OpenSmile_EB10.csd
CMU_MOSI_openSMILE_IS09.csd
CMU_MOSI_Opinion_Labels.csd
CMU_MOSI_TimestampedWords.csd
CMU_MOSI_Visual_OpenFace_2.csd
CMU_MOSI_TimestampedWordVectors.csd
CMU_MOSI_Visual_OpenFace_1.csd
CMU_MOSI_Visual_Facet_41.csd
CMU_MOSI_TimestampedPhones.csd
CMU_MOSI_Visual_Facet_42.csd
CMU_MOSI_COVAREP.csd


In [51]:
# define your different modalities - refer to the filenames of the CSD files
visual_field = 'CMU_MOSI_Visual_Facet_41'
acoustic_field = 'CMU_MOSI_COVAREP'
text_field = 'CMU_MOSI_ModifiedTimestampedWords'


features = [
    text_field, 
    visual_field, 
    acoustic_field
]

recipe = {feat: os.path.join(DATA_PATH, feat) + '.csd' for feat in features}
dataset = md.mmdataset(recipe)

[92m[1m[2023-09-25 15:59:05.831] | Success | [0mComputational sequence read from file .src/datasets/CMU-MultimodalSDK/data/CMU_MOSI_ModifiedTimestampedWords.csd ...
[94m[1m[2023-09-25 15:59:05.840] | Status  | [0mChecking the integrity of the <b'CMU_MOSI_ModifiedTimestampedWords'> computational sequence ...
[94m[1m[2023-09-25 15:59:05.840] | Status  | [0mChecking the format of the data in <b'CMU_MOSI_ModifiedTimestampedWords'> computational sequence ...


                                                                   

[92m[1m[2023-09-25 15:59:05.861] | Success | [0m<b'CMU_MOSI_ModifiedTimestampedWords'> computational sequence data in correct format.
[94m[1m[2023-09-25 15:59:05.861] | Status  | [0mChecking the format of the metadata in <b'CMU_MOSI_ModifiedTimestampedWords'> computational sequence ...
[92m[1m[2023-09-25 15:59:05.862] | Success | [0mComputational sequence read from file .src/datasets/CMU-MultimodalSDK/data/CMU_MOSI_Visual_Facet_41.csd ...
[94m[1m[2023-09-25 15:59:05.876] | Status  | [0mChecking the integrity of the <FACET_4.1> computational sequence ...
[94m[1m[2023-09-25 15:59:05.876] | Status  | [0mChecking the format of the data in <FACET_4.1> computational sequence ...


                                                                   

[92m[1m[2023-09-25 15:59:05.907] | Success | [0m<FACET_4.1> computational sequence data in correct format.
[94m[1m[2023-09-25 15:59:05.907] | Status  | [0mChecking the format of the metadata in <FACET_4.1> computational sequence ...
[92m[1m[2023-09-25 15:59:05.908] | Success | [0mComputational sequence read from file .src/datasets/CMU-MultimodalSDK/data/CMU_MOSI_COVAREP.csd ...
[94m[1m[2023-09-25 15:59:05.917] | Status  | [0mChecking the integrity of the <COVAREP> computational sequence ...
[94m[1m[2023-09-25 15:59:05.917] | Status  | [0mChecking the format of the data in <COVAREP> computational sequence ...


                                                                   

[92m[1m[2023-09-25 15:59:05.942] | Success | [0m<COVAREP> computational sequence data in correct format.
[94m[1m[2023-09-25 15:59:05.943] | Status  | [0mChecking the format of the metadata in <COVAREP> computational sequence ...
[92m[1m[2023-09-25 15:59:05.943] | Success | [0mDataset initialized successfully ... 




In [52]:
print(list(dataset.keys()))
print("=" * 80)

print(list(dataset[visual_field].keys())[:10])
print("=" * 80)

some_id = list(dataset[visual_field].keys())[15]
print(list(dataset[visual_field][some_id].keys()))
print("=" * 80)

print(list(dataset[visual_field][some_id]['intervals'].shape))
print("=" * 80)

print(list(dataset[visual_field][some_id]['features'].shape))
print(list(dataset[text_field][some_id]['features'].shape))
print(list(dataset[acoustic_field][some_id]['features'].shape))
print("Different modalities have different number of time steps!")

['CMU_MOSI_ModifiedTimestampedWords', 'CMU_MOSI_Visual_Facet_41', 'CMU_MOSI_COVAREP']
['03bSnISJMiM', '0h-zjBukYpk', '1DmNV9C1hbY', '1iG0909rllw', '2WGyTLYerpo', '2iD-tVS8NPw', '5W7Z1C_fDaE', '6Egk_28TtTM', '6_0THN4chvY', '73jzhE8R1TQ']
['features', 'intervals']
[5404, 2]
[5404, 47]
[658, 1]
[18009, 74]
Different modalities have different number of time steps!


In [53]:
# we define a simple averaging function that does not depend on intervals
def avg(intervals: np.array, features: np.array) -> np.array:
    try:
        return np.average(features, axis=0)
    except:
        return features

# first we align to words with averaging, collapse_function receives a list of functions
dataset.align(text_field, collapse_functions=[avg])

[94m[1m[2023-09-25 15:59:10.644] | Status  | [0mUnify was called ...
[92m[1m[2023-09-25 15:59:10.644] | Success | [0mUnify completed ...
[94m[1m[2023-09-25 15:59:10.644] | Status  | [0mPre-alignment based on <CMU_MOSI_ModifiedTimestampedWords> computational sequence started ...
[94m[1m[2023-09-25 15:59:15.362] | Status  | [0mPre-alignment done for <CMU_MOSI_COVAREP> ...
[94m[1m[2023-09-25 15:59:16.349] | Status  | [0mPre-alignment done for <CMU_MOSI_Visual_Facet_41> ...
[94m[1m[2023-09-25 15:59:16.369] | Status  | [0mAlignment starting ...


                                                                                              

[92m[1m[2023-09-25 15:59:42.765] | Success | [0mAlignment to <CMU_MOSI_ModifiedTimestampedWords> complete.
[94m[1m[2023-09-25 15:59:42.765] | Status  | [0mReplacing dataset content with aligned computational sequences
[92m[1m[2023-09-25 15:59:42.782] | Success | [0mInitialized empty <CMU_MOSI_ModifiedTimestampedWords> computational sequence.
[94m[1m[2023-09-25 15:59:42.782] | Status  | [0mChecking the format of the data in <CMU_MOSI_ModifiedTimestampedWords> computational sequence ...


                                                                      

[92m[1m[2023-09-25 15:59:42.843] | Success | [0m<CMU_MOSI_ModifiedTimestampedWords> computational sequence data in correct format.
[94m[1m[2023-09-25 15:59:42.843] | Status  | [0mChecking the format of the metadata in <CMU_MOSI_ModifiedTimestampedWords> computational sequence ...
[92m[1m[2023-09-25 15:59:42.843] | Success | [0mInitialized empty <CMU_MOSI_Visual_Facet_41> computational sequence.
[94m[1m[2023-09-25 15:59:42.843] | Status  | [0mChecking the format of the data in <CMU_MOSI_Visual_Facet_41> computational sequence ...


                                                                      

[92m[1m[2023-09-25 15:59:42.879] | Success | [0m<CMU_MOSI_Visual_Facet_41> computational sequence data in correct format.
[94m[1m[2023-09-25 15:59:42.879] | Status  | [0mChecking the format of the metadata in <CMU_MOSI_Visual_Facet_41> computational sequence ...
[92m[1m[2023-09-25 15:59:42.879] | Success | [0mInitialized empty <CMU_MOSI_COVAREP> computational sequence.
[94m[1m[2023-09-25 15:59:42.879] | Status  | [0mChecking the format of the data in <CMU_MOSI_COVAREP> computational sequence ...


                                                                      

[92m[1m[2023-09-25 15:59:42.911] | Success | [0m<CMU_MOSI_COVAREP> computational sequence data in correct format.
[94m[1m[2023-09-25 15:59:42.911] | Status  | [0mChecking the format of the metadata in <CMU_MOSI_COVAREP> computational sequence ...


In [54]:
label_field = 'CMU_MOSI_Opinion_Labels'

# we add and align to lables to obtain labeled segments
# this time we don't apply collapse functions so that the temporal sequences are preserved
label_recipe = {label_field: os.path.join(DATA_PATH, label_field + '.csd')}
dataset.add_computational_sequences(label_recipe, destination=None)
dataset.align(label_field)

[92m[1m[2023-09-25 15:59:42.988] | Success | [0mComputational sequence read from file .src/datasets/CMU-MultimodalSDK/data/CMU_MOSI_Opinion_Labels.csd ...
[94m[1m[2023-09-25 15:59:42.994] | Status  | [0mChecking the integrity of the <Opinion Segment Labels> computational sequence ...
[94m[1m[2023-09-25 15:59:42.994] | Status  | [0mChecking the format of the data in <Opinion Segment Labels> computational sequence ...


                                                                   

[92m[1m[2023-09-25 15:59:43.008] | Success | [0m<Opinion Segment Labels> computational sequence data in correct format.
[94m[1m[2023-09-25 15:59:43.008] | Status  | [0mChecking the format of the metadata in <Opinion Segment Labels> computational sequence ...
[94m[1m[2023-09-25 15:59:43.008] | Status  | [0mUnify was called ...
[92m[1m[2023-09-25 15:59:43.045] | Success | [0mUnify completed ...
[94m[1m[2023-09-25 15:59:43.046] | Status  | [0mPre-alignment based on <CMU_MOSI_Opinion_Labels> computational sequence started ...
[94m[1m[2023-09-25 15:59:43.116] | Status  | [0mPre-alignment done for <CMU_MOSI_COVAREP> ...
[94m[1m[2023-09-25 15:59:43.179] | Status  | [0mPre-alignment done for <CMU_MOSI_ModifiedTimestampedWords> ...
[94m[1m[2023-09-25 15:59:43.244] | Status  | [0mPre-alignment done for <CMU_MOSI_Visual_Facet_41> ...
[94m[1m[2023-09-25 15:59:43.246] | Status  | [0mAlignment starting ...


                                                                                              

[92m[1m[2023-09-25 15:59:44.407] | Success | [0mAlignment to <CMU_MOSI_Opinion_Labels> complete.
[94m[1m[2023-09-25 15:59:44.407] | Status  | [0mReplacing dataset content with aligned computational sequences
[92m[1m[2023-09-25 15:59:44.474] | Success | [0mInitialized empty <CMU_MOSI_ModifiedTimestampedWords> computational sequence.
[94m[1m[2023-09-25 15:59:44.474] | Status  | [0mChecking the format of the data in <CMU_MOSI_ModifiedTimestampedWords> computational sequence ...


                                                                     

[92m[1m[2023-09-25 15:59:44.476] | Success | [0m<CMU_MOSI_ModifiedTimestampedWords> computational sequence data in correct format.
[94m[1m[2023-09-25 15:59:44.476] | Status  | [0mChecking the format of the metadata in <CMU_MOSI_ModifiedTimestampedWords> computational sequence ...
[92m[1m[2023-09-25 15:59:44.476] | Success | [0mInitialized empty <CMU_MOSI_Visual_Facet_41> computational sequence.
[94m[1m[2023-09-25 15:59:44.476] | Status  | [0mChecking the format of the data in <CMU_MOSI_Visual_Facet_41> computational sequence ...


                                                                     

[92m[1m[2023-09-25 15:59:44.479] | Success | [0m<CMU_MOSI_Visual_Facet_41> computational sequence data in correct format.
[94m[1m[2023-09-25 15:59:44.479] | Status  | [0mChecking the format of the metadata in <CMU_MOSI_Visual_Facet_41> computational sequence ...
[92m[1m[2023-09-25 15:59:44.479] | Success | [0mInitialized empty <CMU_MOSI_COVAREP> computational sequence.
[94m[1m[2023-09-25 15:59:44.479] | Status  | [0mChecking the format of the data in <CMU_MOSI_COVAREP> computational sequence ...


                                                                     

[92m[1m[2023-09-25 15:59:44.481] | Success | [0m<CMU_MOSI_COVAREP> computational sequence data in correct format.
[94m[1m[2023-09-25 15:59:44.481] | Status  | [0mChecking the format of the metadata in <CMU_MOSI_COVAREP> computational sequence ...
[92m[1m[2023-09-25 15:59:44.481] | Success | [0mInitialized empty <CMU_MOSI_Opinion_Labels> computational sequence.
[94m[1m[2023-09-25 15:59:44.481] | Status  | [0mChecking the format of the data in <CMU_MOSI_Opinion_Labels> computational sequence ...


                                                                     

[92m[1m[2023-09-25 15:59:44.483] | Success | [0m<CMU_MOSI_Opinion_Labels> computational sequence data in correct format.
[94m[1m[2023-09-25 15:59:44.483] | Status  | [0mChecking the format of the metadata in <CMU_MOSI_Opinion_Labels> computational sequence ...




In [55]:
# check out what the keys look like now
print(list(dataset[text_field].keys())[55])

1iG0909rllw[3]


In [56]:
DATASET = md.cmu_mosi
# obtain the train/dev/test splits - these splits are based on video IDs
train_split = DATASET.standard_folds.standard_train_fold
dev_split = DATASET.standard_folds.standard_valid_fold
test_split = DATASET.standard_folds.standard_test_fold

# inspect the splits: they only contain video IDs
print(test_split)

['tmZoasNr4rU', 'zhpQhgha_KU', 'lXPQBPVc5Cw', 'iiK8YX8oH1E', 'tStelxIAHjw', 'nzpVDcQ0ywM', 'etzxEpPuc6I', 'cW1FSBF59ik', 'd6hH302o4v8', 'k5Y_838nuGo', 'pLTX3ipuDJI', 'jUzDDGyPkXU', 'f_pcplsH_V0', 'yvsjCA6Y5Fc', 'nbWiPyCm4g0', 'rnaNMUZpvvg', 'wMbj6ajWbic', 'cM3Yna7AavY', 'yDtzw_Y-7RU', 'vyB00TXsimI', 'dq3Nf_lMPnE', 'phBUpBr1hSo', 'd3_k5Xpfmik', 'v0zCBqDeKcE', 'tIrG4oNLFzE', 'fvVhgmXxadc', 'ob23OKe5a9Q', 'cXypl4FnoZo', 'vvZ4IcEtiZc', 'f9O3YtZ2VfI', 'c7UH_rxdZv4']


In [57]:
# we can see they are in the format of 'video_id[segment_no]', but the splits was specified with video_id only
# we need to use regex or something to match the video IDs...
import torch
import torch.nn as nn

from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm_notebook
from collections import defaultdict

# a sentinel epsilon for safe division, without it we will replace illegal values with a constant
EPS = 0

# construct a word2id mapping that automatically takes increment when new words are encountered
word2id = defaultdict(lambda: len(word2id))
UNK = word2id['<unk>']
PAD = word2id['<pad>']

# place holders for the final train/dev/test dataset
train = []
dev = []
test = []

# define a regular expression to extract the video ID out of the keys
pattern = re.compile('(.*)\[.*\]')
num_drop = 0 # a counter to count how many data points went into some processing issues

for segment in dataset[label_field].keys():
    
    # get the video ID and the features out of the aligned dataset
    vid = re.search(pattern, segment).group(1)
    label = dataset[label_field][segment]['features']
    _words = dataset[text_field][segment]['features']
    _visual = dataset[visual_field][segment]['features']
    _acoustic = dataset[acoustic_field][segment]['features']

    # if the sequences are not same length after alignment, there must be some problem with some modalities
    # we should drop it or inspect the data again
    if not _words.shape[0] == _visual.shape[0] == _acoustic.shape[0]:
        print(f"Encountered datapoint {vid} with text shape {_words.shape}, visual shape {_visual.shape}, acoustic shape {_acoustic.shape}")
        num_drop += 1
        continue

    # remove nan values
    label = np.nan_to_num(label)
    _visual = np.nan_to_num(_visual)
    _acoustic = np.nan_to_num(_acoustic)

    # remove speech pause tokens - this is in general helpful
    # we should remove speech pauses and corresponding visual/acoustic features together
    # otherwise modalities would no longer be aligned
    words = []
    visual = []
    acoustic = []
    for i, word in enumerate(_words):
        if word[0] != b'sp':
            words.append(word2id[word[0].decode('utf-8')]) # SDK stores strings as bytes, decode into strings here
            visual.append(_visual[i, :])
            acoustic.append(_acoustic[i, :])

    words = np.asarray(words)
    visual = np.asarray(visual)
    acoustic = np.asarray(acoustic)

    # z-normalization per instance and remove nan/infs
    visual = np.nan_to_num((visual - visual.mean(0, keepdims=True)) / (EPS + np.std(visual, axis=0, keepdims=True)))
    acoustic = np.nan_to_num((acoustic - acoustic.mean(0, keepdims=True)) / (EPS + np.std(acoustic, axis=0, keepdims=True)))

    if vid in train_split:
        train.append(((words, visual, acoustic), label, segment))
    elif vid in dev_split:
        dev.append(((words, visual, acoustic), label, segment))
    elif vid in test_split:
        test.append(((words, visual, acoustic), label, segment))
    else:
        print(f"Found video that doesn't belong to any splits: {vid}")

print(f"Total number of {num_drop} datapoints have been dropped.")

# turn off the word2id - define a named function here to allow for pickling
def return_unk():
    return UNK
word2id.default_factory = return_unk

  acoustic = np.nan_to_num((acoustic - acoustic.mean(0, keepdims=True)) / (EPS + np.std(acoustic, axis=0, keepdims=True)))
  visual = np.nan_to_num((visual - visual.mean(0, keepdims=True)) / (EPS + np.std(visual, axis=0, keepdims=True)))


Total number of 0 datapoints have been dropped.


In [58]:
# let's see the size of each set and shape of data
print(len(train))
print(len(dev))
print(len(test))

print(train[0][0][1].shape)
print(train[0][1].shape)
print(train[0][1])

print(f"Total vocab size: {len(word2id)}")

1283
229
686
(5, 47)
(1, 1)
[[2.4]]
Total vocab size: 3143


In [59]:
def multi_collate(batch):
    '''
    Collate functions assume batch = [Dataset[i] for i in index_set]
    '''
    # for later use we sort the batch in descending order of length
    batch = sorted(batch, key=lambda x: x[0][0].shape[0], reverse=True)
    
    # get the data out of the batch - use pad sequence util functions from PyTorch to pad things
    labels = torch.cat([torch.from_numpy(sample[1]) for sample in batch], dim=0)
    sentences = pad_sequence([torch.LongTensor(sample[0][0]) for sample in batch], padding_value=PAD)
    visual = pad_sequence([torch.FloatTensor(sample[0][1]) for sample in batch])
    acoustic = pad_sequence([torch.FloatTensor(sample[0][2]) for sample in batch])
    
    # lengths are useful later in using RNNs
    lengths = torch.LongTensor([sample[0][0].shape[0] for sample in batch])
    return sentences, visual, acoustic, labels, lengths

# construct dataloaders, dev and test could use around ~X3 times batch size since no_grad is used during eval
batch_sz = 56
train_loader = DataLoader(train, shuffle=True, batch_size=batch_sz, collate_fn=multi_collate)
dev_loader = DataLoader(dev, shuffle=False, batch_size=batch_sz*3, collate_fn=multi_collate)
test_loader = DataLoader(test, shuffle=False, batch_size=batch_sz*3, collate_fn=multi_collate)

# let's create a temporary dataloader just to see how the batch looks like
temp_loader = iter(DataLoader(test, shuffle=True, batch_size=8, collate_fn=multi_collate))
batch = next(temp_loader)

print(batch[0].shape) # word vectors, padded to maxlen
print(batch[1].shape) # visual features
print(batch[2].shape) # acoustic features
print(batch[3]) # labels
print(batch[4]) # lengths

torch.Size([45, 8])
torch.Size([45, 8, 47])
torch.Size([45, 8, 74])
tensor([[-2.4000],
        [ 1.0000],
        [-0.6000],
        [-1.4000],
        [ 0.6000],
        [ 2.8000],
        [-0.6000],
        [-2.4000]])
tensor([45, 24, 15, 15, 14, 10,  7,  5])


In [60]:
# Let's actually inspect the transcripts to ensure it's correct
id2word = {v:k for k, v in word2id.items()}
examine_target = train
idx = np.random.randint(0, len(examine_target))
print(' '.join(list(map(lambda x: id2word[x], examine_target[idx][0][0].tolist()))))
# print(' '.join(examine_target[idx][0]))
print(examine_target[idx][1])
print(examine_target[idx][2])

like one thing i like with the aliens is that a they are not just weak links that hide behind the
[[1.4]]
2WGyTLYerpo[18]


In [72]:
!ls
from ....src.multimodal.architecture.late_fusion import LFLSTM

LICENSE       __init_.py    librerias.txt next_steps.md [34mwandb[m[m
LICENSE.txt   [31mclean.sh[m[m      [34mmmsdk[m[m         optim.std
README.md     [34mexamples[m[m      model.std     [34mrelated_repos[m[m


ImportError: attempted relative import with no known parent package

In [62]:
# define a function that loads data from GloVe-like embedding files
# we will add tutorials for loading contextualized embeddings later
# 2196017 is the vocab size of GloVe here.

def load_emb(w2i, path_to_embedding, embedding_size=300, embedding_vocab=2196017, init_emb=None):
    if init_emb is None:
        emb_mat = np.random.randn(len(w2i), embedding_size)
    else:
        emb_mat = init_emb
    f = open(path_to_embedding, 'r')
    found = 0
    for line in tqdm_notebook(f, total=embedding_vocab):
        content = line.strip().split()
        vector = np.asarray(list(map(lambda x: float(x), content[-300:])))
        word = ' '.join(content[:-300])
        if word in w2i:
            idx = w2i[word]
            emb_mat[idx, :] = vector
            found += 1
    print(f"Found {found} words in the embedding file.")
    return torch.tensor(emb_mat).float()

In [63]:
from tqdm import tqdm
from torch.optim import Adam, SGD
from sklearn.metrics import accuracy_score

torch.manual_seed(123)
torch.cuda.manual_seed_all(123)

CUDA = torch.cuda.is_available()
MAX_EPOCH = 1000

text_size = 300
visual_size = 47
acoustic_size = 74

# define some model settings and hyper-parameters
input_sizes = [text_size, visual_size, acoustic_size]
hidden_sizes = [int(text_size * 1.5), int(visual_size * 1.5), int(acoustic_size * 1.5)]
fc1_size = sum(hidden_sizes) // 2
dropout = 0.25
output_size = 1
curr_patience = patience = 8
num_trials = 3
grad_clip_value = 1.0
weight_decay = 0.1

if os.path.exists(CACHE_PATH):
    pretrained_emb, word2id = torch.load(CACHE_PATH)
elif WORD_EMB_PATH is not None:
    pretrained_emb = load_emb(word2id, WORD_EMB_PATH)
    torch.save((pretrained_emb, word2id), CACHE_PATH)
else:
    pretrained_emb = None

model = LFLSTM(input_sizes, hidden_sizes, fc1_size, output_size, dropout)
if pretrained_emb is not None:
    model.embed.weight.data = pretrained_emb
model.embed.requires_grad = False
optimizer = Adam([param for param in model.parameters() if param.requires_grad], weight_decay=weight_decay)

if CUDA:
    model.cuda()
criterion = nn.L1Loss(reduction='sum')
criterion_test = nn.L1Loss(reduction='sum')
best_valid_loss = float('inf')
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
lr_scheduler.step() # for some reason it seems the StepLR needs to be stepped once first
train_losses = []
valid_losses = []
for e in range(MAX_EPOCH):
    model.train()
    train_iter = tqdm(train_loader)
    train_loss = 0.0
    for batch in train_iter:
        model.zero_grad()
        t, v, a, y, l = batch
        batch_size = t.size(0)
        if CUDA:
            t = t.cuda()
            v = v.cuda()
            a = a.cuda()
            y = y.cuda()
            l = l.cuda()
        y_tilde = model(t, v, a, l)
        loss = criterion(y_tilde, y)
        loss.backward()
        torch.nn.utils.clip_grad_value_([param for param in model.parameters() if param.requires_grad], grad_clip_value)
        optimizer.step()
        train_iter.set_description(f"Epoch {e}/{MAX_EPOCH}, current batch loss: {round(loss.item()/batch_size, 4)}")
        train_loss += loss.item()
    train_loss = train_loss / len(train)
    train_losses.append(train_loss)
    print(f"Training loss: {round(train_loss, 4)}")

    model.eval()
    with torch.no_grad():
        valid_loss = 0.0
        for batch in dev_loader:
            model.zero_grad()
            t, v, a, y, l = batch
            if CUDA:
                t = t.cuda()
                v = v.cuda()
                a = a.cuda()
                y = y.cuda()
                l = l.cuda()
            y_tilde = model(t, v, a, l)
            loss = criterion(y_tilde, y)
            valid_loss += loss.item()
    
    valid_loss = valid_loss/len(dev)
    valid_losses.append(valid_loss)
    print(f"Validation loss: {round(valid_loss, 4)}")
    print(f"Current patience: {curr_patience}, current trial: {num_trials}.")
    if valid_loss <= best_valid_loss:
        best_valid_loss = valid_loss
        print("Found new best model on dev set!")
        torch.save(model.state_dict(), 'model.std')
        torch.save(optimizer.state_dict(), 'optim.std')
        curr_patience = patience
    else:
        curr_patience -= 1
        if curr_patience <= -1:
            print("Running out of patience, loading previous best model.")
            num_trials -= 1
            curr_patience = patience
            model.load_state_dict(torch.load('model.std'))
            optimizer.load_state_dict(torch.load('optim.std'))
            lr_scheduler.step()
            print(f"Current learning rate: {optimizer.state_dict()['param_groups'][0]['lr']}")
    
    if num_trials <= 0:
        print("Running out of patience, early stopping.")
        break

model.load_state_dict(torch.load('model.std'))
y_true = []
y_pred = []
model.eval()
with torch.no_grad():
    test_loss = 0.0
    for batch in test_loader:
        model.zero_grad()
        t, v, a, y, l = batch
        if CUDA:
            t = t.cuda()
            v = v.cuda()
            a = a.cuda()
            y = y.cuda()
            l = l.cuda()
        y_tilde = model(t, v, a, l)
        loss = criterion_test(y_tilde, y)
        y_true.append(y_tilde.detach().cpu().numpy())
        y_pred.append(y.detach().cpu().numpy())
        test_loss += loss.item()
print(f"Test set performance: {test_loss/len(test)}")
y_true = np.concatenate(y_true, axis=0)
y_pred = np.concatenate(y_pred, axis=0)
                  
y_true_bin = y_true >= 0
y_pred_bin = y_pred >= 0
bin_acc = accuracy_score(y_true_bin, y_pred_bin)
print(f"Test set accuracy is {bin_acc}")

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  train_iter = tqdm_notebook(train_loader)


  0%|          | 0/23 [00:00<?, ?it/s]

Training loss: 1.306
Validation loss: 1.3898
Current patience: 8, current trial: 3.
Found new best model on dev set!


  0%|          | 0/23 [00:00<?, ?it/s]

Training loss: 0.9546
Validation loss: 1.3244
Current patience: 8, current trial: 3.
Found new best model on dev set!


  0%|          | 0/23 [00:00<?, ?it/s]

Training loss: 0.7075
Validation loss: 1.3359
Current patience: 8, current trial: 3.


  0%|          | 0/23 [00:00<?, ?it/s]

Training loss: 0.558
Validation loss: 1.3321
Current patience: 7, current trial: 3.


  0%|          | 0/23 [00:00<?, ?it/s]

Training loss: 0.4775
Validation loss: 1.323
Current patience: 6, current trial: 3.
Found new best model on dev set!


  0%|          | 0/23 [00:00<?, ?it/s]

Training loss: 0.4025
Validation loss: 1.3075
Current patience: 8, current trial: 3.
Found new best model on dev set!


  0%|          | 0/23 [00:00<?, ?it/s]

Training loss: 0.3695
Validation loss: 1.3205
Current patience: 8, current trial: 3.


  0%|          | 0/23 [00:00<?, ?it/s]

Training loss: 0.3445
Validation loss: 1.315
Current patience: 7, current trial: 3.


  0%|          | 0/23 [00:00<?, ?it/s]

Training loss: 0.3184
Validation loss: 1.3173
Current patience: 6, current trial: 3.


  0%|          | 0/23 [00:00<?, ?it/s]

Training loss: 0.3056
Validation loss: 1.3121
Current patience: 5, current trial: 3.


  0%|          | 0/23 [00:00<?, ?it/s]

Training loss: 0.3035
Validation loss: 1.3274
Current patience: 4, current trial: 3.


  0%|          | 0/23 [00:00<?, ?it/s]

Training loss: 0.3056
Validation loss: 1.3169
Current patience: 3, current trial: 3.


  0%|          | 0/23 [00:00<?, ?it/s]

Training loss: 0.3151
Validation loss: 1.3042
Current patience: 2, current trial: 3.
Found new best model on dev set!


  0%|          | 0/23 [00:00<?, ?it/s]

Training loss: 0.3015
Validation loss: 1.3204
Current patience: 8, current trial: 3.


  0%|          | 0/23 [00:00<?, ?it/s]

Training loss: 0.3106
Validation loss: 1.3502
Current patience: 7, current trial: 3.


  0%|          | 0/23 [00:00<?, ?it/s]

Training loss: 0.298
Validation loss: 1.3147
Current patience: 6, current trial: 3.


  0%|          | 0/23 [00:00<?, ?it/s]

Training loss: 0.2717
Validation loss: 1.3174
Current patience: 5, current trial: 3.


  0%|          | 0/23 [00:00<?, ?it/s]

Training loss: 0.2749
Validation loss: 1.3216
Current patience: 4, current trial: 3.


  0%|          | 0/23 [00:00<?, ?it/s]

Training loss: 0.2783
Validation loss: 1.3391
Current patience: 3, current trial: 3.


  0%|          | 0/23 [00:00<?, ?it/s]

Training loss: 0.2588
Validation loss: 1.3139
Current patience: 2, current trial: 3.


  0%|          | 0/23 [00:00<?, ?it/s]

Training loss: 0.257
Validation loss: 1.3337
Current patience: 1, current trial: 3.


  0%|          | 0/23 [00:00<?, ?it/s]

Training loss: 0.2724
Validation loss: 1.3171
Current patience: 0, current trial: 3.
Running out of patience, loading previous best model.
Current learning rate: 1e-05


  0%|          | 0/23 [00:00<?, ?it/s]

Training loss: 0.2827
Validation loss: 1.3091
Current patience: 8, current trial: 2.


  0%|          | 0/23 [00:00<?, ?it/s]

Training loss: 0.2536
Validation loss: 1.3105
Current patience: 7, current trial: 2.


  0%|          | 0/23 [00:00<?, ?it/s]

Training loss: 0.2479
Validation loss: 1.3137
Current patience: 6, current trial: 2.


  0%|          | 0/23 [00:00<?, ?it/s]

Training loss: 0.2357
Validation loss: 1.3135
Current patience: 5, current trial: 2.


  0%|          | 0/23 [00:00<?, ?it/s]

Training loss: 0.233
Validation loss: 1.3087
Current patience: 4, current trial: 2.


  0%|          | 0/23 [00:00<?, ?it/s]

Training loss: 0.2135
Validation loss: 1.3105
Current patience: 3, current trial: 2.


  0%|          | 0/23 [00:00<?, ?it/s]

Training loss: 0.2062
Validation loss: 1.3092
Current patience: 2, current trial: 2.


  0%|          | 0/23 [00:00<?, ?it/s]

Training loss: 0.2092
Validation loss: 1.3157
Current patience: 1, current trial: 2.


  0%|          | 0/23 [00:00<?, ?it/s]

Training loss: 0.2005
Validation loss: 1.3154
Current patience: 0, current trial: 2.
Running out of patience, loading previous best model.
Current learning rate: 1e-05


  0%|          | 0/23 [00:00<?, ?it/s]

Training loss: 0.2858
Validation loss: 1.3083
Current patience: 8, current trial: 1.


  0%|          | 0/23 [00:00<?, ?it/s]

Training loss: 0.2692
Validation loss: 1.3057
Current patience: 7, current trial: 1.


  0%|          | 0/23 [00:00<?, ?it/s]

Training loss: 0.2398
Validation loss: 1.3087
Current patience: 6, current trial: 1.


  0%|          | 0/23 [00:00<?, ?it/s]

Training loss: 0.2463
Validation loss: 1.3085
Current patience: 5, current trial: 1.


  0%|          | 0/23 [00:00<?, ?it/s]

Training loss: 0.2198
Validation loss: 1.3198
Current patience: 4, current trial: 1.


  0%|          | 0/23 [00:00<?, ?it/s]

Training loss: 0.2157
Validation loss: 1.3167
Current patience: 3, current trial: 1.


  0%|          | 0/23 [00:00<?, ?it/s]

Training loss: 0.2058
Validation loss: 1.315
Current patience: 2, current trial: 1.


  0%|          | 0/23 [00:00<?, ?it/s]

Training loss: 0.2067
Validation loss: 1.3198
Current patience: 1, current trial: 1.


  0%|          | 0/23 [00:00<?, ?it/s]

Training loss: 0.2027
Validation loss: 1.3108
Current patience: 0, current trial: 1.
Running out of patience, loading previous best model.
Current learning rate: 1e-05
Running out of patience, early stopping.
Test set performance: 1.3241789945708071
Test set accuracy is 0.5947521865889213
