# Contributors:
* Shivangana Rawat (cs20mtech12001)
* Pranoy Panda (cs20mtech12002)

## Importing Libraries

In [1]:
# deep learning framework
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable
import torch.nn.functional as F

# for loading and manipulating data
import numpy as np
import pandas as pd
from collections import defaultdict
from textwrap import wrap

# for visualization
import seaborn as sns
from pylab import rcParams
import matplotlib.pyplot as plt
from matplotlib import rc

# for splitting data and evaluation metrics
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report

# miscellaneous
import os
import random

# global configuration
RANDOM_SEED = 42
EPOCHS = 50
LR = 5e-07
BATCH_SIZE = 4
SEQLEN = 5  # consecutive utterances of one scene per sample
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# modality -> [feature extractor name, feature dimensionality]
# NOTE(review): the video features downloaded later come from
# '..._resnext101.txt', although the name recorded here is 'resnet101'.
modality_and_repr_type = {'text':['sentence-BERT',512], 'video':['resnet101',2048], 'audio':['opensmile',65]}
text_size = modality_and_repr_type['text'][1]
visual_size = modality_and_repr_type['video'][1]
acoustic_size = modality_and_repr_type['audio'][1]

# setting seeds for reproducibility; Python's own `random` module is used
# for the scene-wise split (random.shuffle), so it is seeded as well
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

## Download datasets

In [None]:
# declaring train and test dataframes
df_train = pd.DataFrame()
df_test = pd.DataFrame()

if 'text' in modality_and_repr_type: 
  !gdown --id 1Z-Mt2kMtA6ZJ5YA704SRqFtCZsNbPMMR -q # train raw
  !gdown --id 1TWO58qHYYEVcUeaEltVv2xVH1HLVsUnK -q # test raw
  df_train = pd.read_csv('train.tsv',sep='\t')
  df_test = pd.read_csv('test.tsv',sep='\t')
  print("Downloaded the raw text M2H2 data !")

  # download feature representations
  if modality_and_repr_type['text'][0]=='sentence-BERT':
    !gdown --id 1decZ9lPDjxlKLJZfpvKW8vUynz6UA01n -q # train
    !gdown --id 1--0jt4tgOGRfajVMYBHuoKKDw-aZjgUW -q # test
    df_train['text'] = pd.DataFrame({'text':np.loadtxt('train_utterance_embeddings_sentenceBERT.txt').tolist()})
    df_test['text'] = pd.DataFrame({'text':np.loadtxt('test_utterance_embeddings_sentenceBERT.txt').tolist()})
    print("Downloaded the sentence-BERT embeddings !")
  elif modality_and_repr_type['text'][0]=='FastText':
    !gdown --id 1CP9Q83PQ1eD6D3QpTxQQWhdmpZ4Bb70r -q # train
    !gdown --id 11-89-yI6uwslPACgqsTNMxjTOV_LzNK1 -q # test
    df_train['text'] = pd.DataFrame({'text':np.loadtxt('train_utterance_embeddings_FastText.txt').tolist()})
    df_test['text'] = pd.DataFrame({'text':np.loadtxt('test_utterance_embeddings_FastText.txt').tolist()})
    print("Downloaded the FastText embeddings !")

if 'video' in modality_and_repr_type:
  !gdown --id 1Z-Mt2kMtA6ZJ5YA704SRqFtCZsNbPMMR -q # train raw
  !gdown --id 1TWO58qHYYEVcUeaEltVv2xVH1HLVsUnK -q # test raw
  df_train['Label'] = np.array(pd.read_csv('train.tsv',sep='\t')['Label'])
  df_test['Label'] = np.array(pd.read_csv('test.tsv',sep='\t')['Label'])
  print("Downloaded the raw text M2H2 data !")

  !gdown --id 1J0cc2mf2n03zAGwbLZ9TO1SEHtWAs9Rb -q # train
  !gdown --id 191WO9nVckQnjbiAy3NROZXQOId5AX_w9 -q # test
  df_train['video'] = pd.DataFrame({'video':np.loadtxt('train_utterance_features_resnext101.txt').tolist()})
  df_test['video'] = pd.DataFrame({'video':np.loadtxt('test_utterance_features_resnext101.txt').tolist()})
  print("Downloaded the resnect101 features !")

if 'audio' in modality_and_repr_type:
  !gdown --id 1Z-Mt2kMtA6ZJ5YA704SRqFtCZsNbPMMR -q # train raw
  !gdown --id 1TWO58qHYYEVcUeaEltVv2xVH1HLVsUnK -q # test raw
  df_train['Label'] = np.array(pd.read_csv('train.tsv',sep='\t')['Label'])
  df_test['Label'] = np.array(pd.read_csv('test.tsv',sep='\t')['Label'])
  print("Downloaded the raw text M2H2 data !")

  !gdown --id 1-2isFu4OFEpg4ftrcdpOeHNrRCJ9OLPo -q # train
  !gdown --id 1-GlUVqGL4oLtYzfz7Ik1HuzMiGLAGUL3 -q # test
  df_train['audio'] = pd.DataFrame({'audio':np.loadtxt('train_features_opensmile_avg.txt').tolist()})
  df_test['audio'] = pd.DataFrame({'audio':np.loadtxt('test_features_opensmile_avg.txt').tolist()})
  print("Downloaded the opensmile averaged features !")

''' Vanilla data split'''
# train-val split
df_train, df_val = train_test_split(
  df_train,
  test_size=0.1,
  random_state=RANDOM_SEED
)
print("Number of utterances in train : ", len(df_train))
print("Number of utterances in val : ", len(df_val))

## Create a scene- and episode-wise train/val split

In [None]:
# creating scenewise train val split: every utterance of an (episode, scene)
# pair goes entirely to train or entirely to val, so the SEQLEN-long windows
# built by the dataset class never mix utterances from the two splits.


def _episode_scene_key(df):
  """Return a per-row string identifying the row's (episode, scene) pair."""
  # Cast to str and join with a separator: a plain `+` would add the columns
  # if they are numeric, and can collide for strings ("ab"+"c" == "a"+"bc").
  return df['episode'].astype(str) + ' | ' + df['Scenes'].astype(str)


def _order_for_sequences(df):
  """Return ``df`` sorted by episode and scene, keeping index order within a scene.

  The dataset class builds windows from consecutive rows, so each scene's
  utterances must be contiguous and in their original (index) order.
  Note that sort_values returns a new frame; the result must be assigned.
  """
  return (df.assign(_row=df.index)
            .sort_values(by=['episode', 'Scenes', '_row'])
            .drop(columns='_row'))


df_train['Episode Scene'] = _episode_scene_key(df_train)
df_test['Episode Scene'] = _episode_scene_key(df_test)

# sorted() before shuffling: the iteration order of a set of strings depends
# on PYTHONHASHSEED, so shuffling it directly is not reproducible across runs.
ep_scene = sorted(set(df_train['Episode Scene']))
random.Random(RANDOM_SEED).shuffle(ep_scene)
# first 90% of the shuffled scenes -> train, remaining 10% -> val
val_scenes = set(ep_scene[int(0.9*len(ep_scene)):])

# one vectorised membership test instead of concatenating scene by scene
in_val = df_train['Episode Scene'].isin(val_scenes)
df_val = _order_for_sequences(df_train[in_val])
df_train = _order_for_sequences(df_train[~in_val])
df_test = _order_for_sequences(df_test)

print("Number of utterances in train : ", len(df_train))
print("Number of utterances in val : ", len(df_val))

## Dataset class

In [6]:
# creating dataset
class M2H2_Dataset(Dataset):
  """Fixed-length windows of consecutive utterances taken from a single scene.

  Item ``k`` is the ``k``-th window (in row order) of ``seqlen`` consecutive
  rows of ``df`` that all share the same 'Episode Scene' value. It is a dict
  with the stacked 'textf', 'videof', 'audiof' features and the 'labels' of
  those rows.

  Args:
    df: utterance table with list-valued 'text', 'video' and 'audio'
      columns, a 'Label' column and an 'Episode Scene' column. Rows of one
      scene are expected to be contiguous and in utterance order. The frame
      is not modified.
    modality_and_repr_type: not used by this class (kept for the interface).
    seqlen: number of consecutive utterances per item.
  """
  def __init__(self, df, modality_and_repr_type, seqlen):
    # work on a copy so adding the 'Index' column does not mutate the caller's frame
    self.df = df.copy()
    self.text_features = torch.from_numpy(np.array(list(df['text'])))
    self.video_features = torch.from_numpy(np.array(list(df['video'])))
    self.audio_features = torch.from_numpy(np.array(list(df['audio'])))

    scenes = self.df['Episode Scene'].tolist()
    self._starts = []  # positional row where each window starts, in window order
    indexlist = []     # window number for rows that start a window, else None
    for i in range(len(scenes)):
      window = scenes[i:i + seqlen]
      # a window is valid only if it is full length and stays inside one scene
      if seqlen > 0 and len(window) == seqlen and all(s == window[0] for s in window):
        indexlist.append(len(self._starts))
        self._starts.append(i)
      else:
        indexlist.append(None)
    self.df['Index'] = indexlist
    self.labels = torch.from_numpy(np.array(list(df['Label'])))
    self.df = self.df.reset_index()
    # number of windows (window numbers are 0-based, so this is max + 1)
    self.len = len(self._starts)
    self.seqlen = seqlen

  def __len__(self):
    """Number of complete single-scene windows."""
    return int(self.len)

  def __getitem__(self, index):
    """Return the ``index``-th window as a dict of per-utterance tensors."""
    # O(1) lookup of the window's first row instead of scanning the frame
    start = self._starts[index]
    rows = slice(start, start + self.seqlen)
    return {'textf': self.text_features[rows],
            'videof': self.video_features[rows],
            'audiof': self.audio_features[rows],
            'labels': self.labels[rows]}

In [8]:
# creating data loader
def create_data_loader(df, batch_size, seqlen, dataset_class, num_workers=2, shuffle=False):
  """Wrap ``df`` in ``dataset_class`` and return a DataLoader over it.

  Args:
    df: utterance table passed to ``dataset_class``.
    batch_size: windows per batch.
    seqlen: utterances per window.
    dataset_class: Dataset class called as ``dataset_class(df, modality_and_repr_type, seqlen)``.
    num_workers: DataLoader worker processes (default 2).
    shuffle: reshuffle the windows every epoch (default False: row order).
  """
  ds = dataset_class(df, modality_and_repr_type, seqlen)
  return DataLoader(ds, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)


# one loader per split, built in train / val / test order
dataset_class = M2H2_Dataset
train_data_loader, val_data_loader, test_data_loader = [
  create_data_loader(split_df, BATCH_SIZE, SEQLEN, dataset_class)
  for split_df in (df_train, df_val, df_test)
]

## MISA Code

In [None]:
import numpy as np
import random

import torch
import torch.nn as nn
from torch.autograd import Function
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence

class ReverseLayerF(Function):
    """Gradient reversal layer.

    Forward is the identity. Backward returns ``-p * grad_output`` for ``x``
    and no gradient for the scalar ``p``. Use as ``ReverseLayerF.apply(x, p)``.
    """

    @staticmethod
    def forward(ctx, x, p):
        # keep the scale for backward; view_as returns an unchanged view of x
        ctx.p = p
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # negate and scale the gradient; `p` is not a tensor, so it gets None
        output = grad_output.neg() * ctx.p
        return output, None


def to_gpu(x, on_cpu=False, gpu_id=None):
    """Return ``x`` moved to CUDA device ``gpu_id``, or ``x`` itself.

    ``x`` is returned untouched when CUDA is unavailable or ``on_cpu`` is set.
    """
    if not torch.cuda.is_available() or on_cpu:
        return x
    return x.cuda(gpu_id)


def masked_mean(tensor, mask, dim):
    """Mean of ``tensor`` along ``dim`` over the positions selected by ``mask``.

    ``mask`` is multiplied into ``tensor`` and summed as the denominator, so it
    acts as per-element weights (typically 0/1). A slice whose mask sums to
    zero gives 0/0.
    """
    weighted_sum = (tensor * mask).sum(dim=dim)
    weight_total = mask.sum(dim=dim)
    return weighted_sum / weight_total

def masked_max(tensor, mask, dim):
    """Max of ``tensor`` along ``dim``, ignoring positions where ``mask`` is False.

    Args:
        tensor: floating-point values to reduce.
        mask: boolean tensor broadcastable to ``tensor``; True marks kept positions.
        dim: dimension to reduce.

    Returns:
        The ``(values, indices)`` result of ``Tensor.max(dim=dim)``; a slice
        with no True entry has value ``-inf``.
    """
    # masked_fill with -inf (instead of adding -inf to tensor * mask) needs no
    # `math` import and keeps masked-out infinite values from becoming NaN via inf * 0
    return tensor.masked_fill(~mask, float('-inf')).max(dim=dim)

In [None]:
# MISA encoder: projects each modality into a common space, splits it into
# private and shared codes, and fuses the six codes with a Transformer.
class MISA(nn.Module):
    """MISA-style encoder producing a fused multimodal sequence representation.

    Each modality's features are projected to ``PROJ_SIZE`` (512) dims and
    mapped to a private code (per-modality encoder) and a shared code (one
    encoder used for all modalities). The six codes are concatenated per time
    step (``N_CODES * PROJ_SIZE`` = 3072 features) and passed through a
    1-layer Transformer encoder that attends over the sequence axis.

    Each forward pass stores its intermediate results on ``self``
    (``utt_private_*``, ``utt_shared_*``, ``utt_*_orig``, ``utt_*_recon``,
    ``domain_label_*``, ``shared_or_private_*``), so the auxiliary MISA losses
    can read them afterwards.

    Args:
        text_tensor: text feature size (an int, despite the name).
        visual_size: video feature size.
        acoustic_size: audio feature size.
    """

    # width of the common space all modalities are projected to
    PROJ_SIZE = 512
    # private t/v/a + shared t/v/a codes concatenated before fusion
    N_CODES = 6

    def __init__(self, text_tensor,visual_size,acoustic_size):
        super(MISA, self).__init__()

        self.text_size = text_tensor
        self.visual_size = visual_size
        self.acoustic_size = acoustic_size

        self.input_sizes = [self.text_size, self.visual_size, self.acoustic_size]
        self.hidden_sizes = hidden_sizes = [int(self.text_size), int(self.visual_size), int(self.acoustic_size)]
        self.output_size = 7  # not used by this module
        self.dropout_rate = dropout_rate = 0.03
        self.activation = nn.ELU()
        self.tanh = nn.Tanh()
        proj = self.PROJ_SIZE

        # NOTE: the reference MISA implementation first encodes each modality
        # with a two-layer bidirectional RNN. Here the inputs are already
        # fixed-size per-utterance feature vectors and are projected directly,
        # so no per-modality RNN is built (it would never be used in forward).

        ##########################################
        # mapping modalities to same sized space
        ##########################################
        self.project_t = nn.Sequential()
        self.project_t.add_module('project_t', nn.Linear(in_features=hidden_sizes[0], out_features=proj))
        self.project_t.add_module('project_t_activation', self.activation)
        self.project_t.add_module('project_t_layer_norm', nn.LayerNorm(proj))

        self.project_v = nn.Sequential()
        self.project_v.add_module('project_v', nn.Linear(in_features=hidden_sizes[1], out_features=proj))
        self.project_v.add_module('project_v_activation', self.activation)
        self.project_v.add_module('project_v_layer_norm', nn.LayerNorm(proj))

        self.project_a = nn.Sequential()
        self.project_a.add_module('project_a', nn.Linear(in_features=hidden_sizes[2], out_features=proj))
        self.project_a.add_module('project_a_activation', self.activation)
        self.project_a.add_module('project_a_layer_norm', nn.LayerNorm(proj))

        ##########################################
        # private encoders
        ##########################################
        self.private_t = nn.Sequential()
        self.private_t.add_module('private_t_1', nn.Linear(in_features=proj, out_features=proj))
        self.private_t.add_module('private_t_activation_1', nn.Sigmoid())

        self.private_v = nn.Sequential()
        self.private_v.add_module('private_v_1', nn.Linear(in_features=proj, out_features=proj))
        self.private_v.add_module('private_v_activation_1', nn.Sigmoid())

        self.private_a = nn.Sequential()
        self.private_a.add_module('private_a_3', nn.Linear(in_features=proj, out_features=proj))
        self.private_a.add_module('private_a_activation_3', nn.Sigmoid())

        ##########################################
        # shared encoder (the same module for all modalities)
        ##########################################
        self.shared = nn.Sequential()
        self.shared.add_module('shared_1', nn.Linear(in_features=proj, out_features=proj))
        self.shared.add_module('shared_activation_1', nn.Sigmoid())

        ##########################################
        # reconstruct
        ##########################################
        self.recon_t = nn.Sequential()
        self.recon_t.add_module('recon_t_1', nn.Linear(in_features=proj, out_features=proj))
        self.recon_v = nn.Sequential()
        self.recon_v.add_module('recon_v_1', nn.Linear(in_features=proj, out_features=proj))
        self.recon_a = nn.Sequential()
        self.recon_a.add_module('recon_a_1', nn.Linear(in_features=proj, out_features=proj))

        ##########################################
        # shared space adversarial discriminator (one output per modality)
        ##########################################
        self.discriminator = nn.Sequential()
        self.discriminator.add_module('discriminator_layer_1', nn.Linear(in_features=proj, out_features=proj))
        self.discriminator.add_module('discriminator_layer_1_activation', self.activation)
        self.discriminator.add_module('discriminator_layer_1_dropout', nn.Dropout(dropout_rate))
        self.discriminator.add_module('discriminator_layer_2', nn.Linear(in_features=proj, out_features=len(hidden_sizes)))

        ##########################################
        # shared-private collaborative discriminator
        ##########################################
        self.sp_discriminator = nn.Sequential()
        self.sp_discriminator.add_module('sp_discriminator_layer_1', nn.Linear(in_features=proj, out_features=4))

        # layer norms for extract_features (not used by forward)
        self.tlayer_norm = nn.LayerNorm((hidden_sizes[0]*2,))
        self.vlayer_norm = nn.LayerNorm((hidden_sizes[1]*2,))
        self.alayer_norm = nn.LayerNorm((hidden_sizes[2]*2,))

        # fusion Transformer over the concatenated codes (d_model = 6 * 512).
        # It is sequence-first (batch_first=False): input is (seq_len, bs, d_model).
        encoder_layer = torch.nn.TransformerEncoderLayer(d_model=self.N_CODES * proj, nhead=2)
        self.transformer_encoder = torch.nn.TransformerEncoder(encoder_layer, num_layers=1)

    #original MISA
    def extract_features(self, sequence, lengths, rnn1, rnn2, layer_norm):
        """Two-stage RNN encoding from the reference MISA code (not used by forward).

        Returns the final hidden states of ``rnn1`` and ``rnn2``; ``lengths`` is unused.
        """
        packed_h1, final_h1 = rnn1(sequence)
        normed_h1 = layer_norm(packed_h1)
        _, final_h2 = rnn2(normed_h1)

        return final_h1, final_h2

    def alignment(self, sentences, visual, acoustic, lengths):
        """Compute all MISA codes for a batch and return the fused sequence.

        Inputs are (bs, seq_len, features) tensors per modality; ``lengths`` is
        unused. Returns a (bs, seq_len, N_CODES * PROJ_SIZE) tensor.
        """
        self.shared_private(sentences, visual, acoustic)

        # gradient reversal in front of the modality discriminator
        reversed_shared_code_t = ReverseLayerF.apply(self.utt_shared_t, 1.0)
        reversed_shared_code_v = ReverseLayerF.apply(self.utt_shared_v, 1.0)
        reversed_shared_code_a = ReverseLayerF.apply(self.utt_shared_a, 1.0)

        self.domain_label_t = self.discriminator(reversed_shared_code_t)
        self.domain_label_v = self.discriminator(reversed_shared_code_v)
        self.domain_label_a = self.discriminator(reversed_shared_code_a)

        self.shared_or_private_p_t = self.sp_discriminator(self.utt_private_t)
        self.shared_or_private_p_v = self.sp_discriminator(self.utt_private_v)
        self.shared_or_private_p_a = self.sp_discriminator(self.utt_private_a)
        self.shared_or_private_s = self.sp_discriminator( (self.utt_shared_t + self.utt_shared_v + self.utt_shared_a)/3.0 )

        # For reconstruction
        self.reconstruct()

        # 1-LAYER TRANSFORMER FUSION
        h = torch.stack((self.utt_private_t, self.utt_private_v, self.utt_private_a,
                         self.utt_shared_t, self.utt_shared_v, self.utt_shared_a), dim=0)  # (6, bs, seq_len, 512)
        h = h.permute(1, 2, 0, 3)  # (bs, seq_len, 6, 512)
        h = h.reshape(h.size(0), h.size(1), -1)  # (bs, seq_len, 3072)
        # The encoder is sequence-first and expects (seq_len, bs, d_model).
        # Transpose in and out so attention runs over the utterances of each
        # window, not across the samples of the batch.
        h = self.transformer_encoder(h.transpose(0, 1)).transpose(0, 1)
        return h

    def reconstruct(self,):
        """Reconstruct each projected modality from its private + shared code."""
        self.utt_t = (self.utt_private_t + self.utt_shared_t)
        self.utt_v = (self.utt_private_v + self.utt_shared_v)
        self.utt_a = (self.utt_private_a + self.utt_shared_a)

        self.utt_t_recon = self.recon_t(self.utt_t)
        self.utt_v_recon = self.recon_v(self.utt_v)
        self.utt_a_recon = self.recon_a(self.utt_a)

    def shared_private(self, utterance_t, utterance_v, utterance_a):
        """Project each modality and compute its private and shared codes."""
        # Projecting to same sized space
        self.utt_t_orig = utterance_t = self.project_t(utterance_t)
        self.utt_v_orig = utterance_v = self.project_v(utterance_v)
        self.utt_a_orig = utterance_a = self.project_a(utterance_a)

        # Private-shared components
        self.utt_private_t = self.private_t(utterance_t)
        self.utt_private_v = self.private_v(utterance_v)
        self.utt_private_a = self.private_a(utterance_a)

        self.utt_shared_t = self.shared(utterance_t)
        self.utt_shared_v = self.shared(utterance_v)
        self.utt_shared_a = self.shared(utterance_a)

    def forward(self, sentences, video, acoustic, lengths):
        """Return the fused (bs, seq_len, 3072) representation; see ``alignment``."""
        o = self.alignment(sentences, video, acoustic, lengths)
        return o

## Training and evaluation functions

In [None]:
def train_epoch(model,
                misa_model,
                data_loader,
                loss_fn,
                loss_recon,
                loss_cmd,
                loss_diff,
                optimizer,
                device,
                n_examples):
  """Run one training epoch of the MISA encoder + classifier over ``data_loader``.

  For every batch the MISA encoder fuses the three modalities and ``model``
  classifies the window. The task loss is computed on the label of the
  window's last utterance and summed with MISA's reconstruction, difference
  and CMD similarity losses. One optimizer step is taken per batch.

  Returns:
    (accuracy, mean_loss): correct predictions divided by ``n_examples`` (a
    double tensor), and the mean total loss per batch.
  """
  model = model.train()
  misa_model = misa_model.train()
  batch_losses = []
  n_correct = 0

  for batch in data_loader:
    text = batch['textf'].to(device)
    audio = batch['audiof'].to(device)
    video = batch['videof'].to(device)
    labels = batch['labels'].to(device)

    # fused representation, shape (bs, seq_len, Z)
    fused = misa_model(text.float(), video.float(), audio.float(), len(text))
    logits = model(fused)
    # each window is scored against the label of its last utterance
    labels = labels[:, -1].long()
    preds = torch.max(logits, dim=1).indices

    # task loss + MISA auxiliary losses (CMD similarity is MISA's default option:
    # https://github.com/declare-lab/MISA/blob/ec42faddde0d210cf7368aebf2118fe9570e7102/src/config.py#L81)
    total_loss = (loss_fn(logits, labels)
                  + get_recon_loss(misa_model, loss_recon)
                  + get_diff_loss(misa_model, loss_diff)
                  + get_cmd_loss(misa_model, loss_cmd))

    n_correct += torch.sum(preds == labels)
    batch_losses.append(total_loss.item())

    # backprop through both models, update, then clear the gradients
    total_loss.backward()
    optimizer.step()
    optimizer.zero_grad()

  return n_correct.double() / n_examples, np.mean(batch_losses)

In [None]:
def eval_model(model, misa_model, data_loader, loss_fn, loss_recon, loss_cmd, loss_diff, device, n_examples):
  """Evaluate the MISA encoder + classifier on ``data_loader`` without updating weights.

  The loss is computed exactly as in training: the task loss on the last
  utterance's label plus MISA's reconstruction, difference and CMD losses.

  Returns:
    (accuracy, mean_loss): correct predictions divided by ``n_examples`` (a
    double tensor), and the mean total loss per batch.
  """
  model = model.eval()
  misa_model = misa_model.eval()
  batch_losses = []
  n_correct = 0
  with torch.no_grad():
    for batch in data_loader:
      text = batch['textf'].to(device)
      audio = batch['audiof'].to(device)
      video = batch['videof'].to(device)
      labels = batch['labels'].to(device)
      fused = misa_model(text.float(), video.float(), audio.float(), len(text))  # (bs, seq_len, Z)
      logits = model(fused)
      labels = labels[:, -1].long()
      preds = torch.max(logits, dim=1).indices
      # CMD similarity is MISA's default option (https://github.com/declare-lab/MISA/blob/ec42faddde0d210cf7368aebf2118fe9570e7102/src/config.py#L81)
      total_loss = (loss_fn(logits, labels)
                    + get_recon_loss(misa_model, loss_recon)
                    + get_diff_loss(misa_model, loss_diff)
                    + get_cmd_loss(misa_model, loss_cmd))
      n_correct += torch.sum(preds == labels)
      batch_losses.append(total_loss.item())
  return n_correct.double() / n_examples, np.mean(batch_losses)


## Declaring the sequence classifier

In [None]:
class SequenceHumorClassifier(nn.Module):
  """LSTM + MLP head that classifies a sequence from its last time step.

  Args:
    n_classes: number of output logits.
    feature_len: size of each time step's feature vector.
    hidden_size: LSTM hidden size (default 512).
    proj_size: LSTM projection size, must be < hidden_size; 0 disables the
      projection (default 128).

  Input is (batch, time_steps, feature_len); output is (batch, n_classes) logits.
  """
  def __init__(self, n_classes, feature_len, hidden_size=512, proj_size=128):
    super(SequenceHumorClassifier, self).__init__()
    self.lstm = nn.LSTM(feature_len, hidden_size, proj_size=proj_size, batch_first=True)
    # the LSTM emits proj_size features when projecting, hidden_size otherwise
    lstm_out = proj_size if proj_size > 0 else hidden_size
    self.classifier = nn.Sequential(
                                nn.Linear(lstm_out,512),
                                nn.ReLU(),
                                nn.Linear(512,1024),
                                nn.ReLU(),
                                nn.Linear(1024,512),
                                nn.ReLU(),
                                nn.Linear(512,256),
                                nn.ReLU(),
                                nn.Linear(256,n_classes)
                                )

  def forward(self, x):
    out, _ = self.lstm(x)
    # classify from the output at the last time step
    return self.classifier(out[:, -1, :])

## Defining the sequence classifier and MISA models

In [None]:
# two-class classifier (not humorous / humorous) over the fused MISA
# sequence, which has 3072 features per time step
model = SequenceHumorClassifier(2, feature_len=3072).to(device)

# the MISA encoder that produces the fused representation
misa_model = MISA(text_size, visual_size, acoustic_size).to(device)

## Loss functions used to train MISA

In [None]:
class MSE(nn.Module):
  """Mean squared error over all elements: sum((real - pred) ** 2) / numel."""
  def __init__(self):
    super(MSE, self).__init__()

  def forward(self, pred, real):
    squared_error = (real - pred).pow(2)
    return torch.sum(squared_error) / torch.numel(squared_error)


class DiffLoss(nn.Module):
  """Soft orthogonality loss between two batches of representations.

  Each input is flattened to (batch, -1), centred over the batch and
  L2-normalised per sample (the norm is detached from autograd). The loss is
  the mean squared entry of the cross-correlation matrix ``a.T @ b``.
  """
  def __init__(self):
    super(DiffLoss, self).__init__()

  @staticmethod
  def _center_and_normalize(x):
    """Centre ``x`` over dim 0, then divide each row by its (detached) L2 norm."""
    x = x - torch.mean(x, dim=0, keepdim=True)
    row_norm = torch.norm(x, p=2, dim=1, keepdim=True).detach()
    # 1e-6 guards against division by zero for all-zero rows
    return x.div(row_norm.expand_as(x) + 1e-6)

  def forward(self, input1, input2):
    batch_size = input1.size(0)
    a = self._center_and_normalize(input1.view(batch_size, -1))
    b = self._center_and_normalize(input2.view(batch_size, -1))
    return torch.mean(a.t().mm(b).pow(2))


class CMD(nn.Module):
  """
  Central Moment Discrepancy between two samples, with moments taken over dim 0.

  Adapted from https://github.com/wzell/cmd/blob/master/models/domain_regularizer.py
  """

  def __init__(self):
    super(CMD, self).__init__()

  def forward(self, x1, x2, n_moments):
    """Distance of the means plus distances of central moments 2..n_moments."""
    mean1 = torch.mean(x1, 0)
    mean2 = torch.mean(x2, 0)
    centered1 = x1 - mean1
    centered2 = x2 - mean2
    total = self.matchnorm(mean1, mean2)
    for k in range(2, n_moments + 1):
      total = total + self.scm(centered1, centered2, k)
    return total

  def matchnorm(self, x1, x2):
    """Euclidean distance between ``x1`` and ``x2`` over all elements."""
    return torch.sum(torch.pow(x1 - x2, 2)) ** 0.5

  def scm(self, sx1, sx2, k):
    """Distance between the k-th central moments of two centred samples."""
    moment1 = torch.mean(torch.pow(sx1, k), 0)
    moment2 = torch.mean(torch.pow(sx2, k), 0)
    return self.matchnorm(moment1, moment2)

# reconstruction loss
def get_recon_loss(misa_model, loss_recon):
  """Average of ``loss_recon(recon, orig)`` over text, video and audio.

  Reads each modality's reconstruction (``utt_*_recon``) and projected input
  (``utt_*_orig``) as stored on ``misa_model`` by its last forward pass.
  """
  m = misa_model
  per_modality = [loss_recon(recon, orig) for recon, orig in (
      (m.utt_t_recon, m.utt_t_orig),
      (m.utt_v_recon, m.utt_v_orig),
      (m.utt_a_recon, m.utt_a_orig),
  )]
  return (per_modality[0] + per_modality[1] + per_modality[2]) / 3.0


def get_cmd_loss(misa_model, loss_cmd):
  """Average 5-moment CMD between the shared codes of every pair of modalities."""
  shared_t = misa_model.utt_shared_t
  shared_v = misa_model.utt_shared_v
  shared_a = misa_model.utt_shared_a
  total = (loss_cmd(shared_t, shared_v, 5)
           + loss_cmd(shared_t, shared_a, 5)
           + loss_cmd(shared_a, shared_v, 5))
  return total / 3.0

    
def get_diff_loss(misa_model, loss_diff):
  """Sum of difference (orthogonality) losses between MISA codes.

  Covers each private code vs. the shared code of the same modality, and
  every pair of private codes.
  """
  m = misa_model
  pairs = (
      # between private and shared
      (m.utt_private_t, m.utt_shared_t),
      (m.utt_private_v, m.utt_shared_v),
      (m.utt_private_a, m.utt_shared_a),
      # across privates
      (m.utt_private_a, m.utt_private_t),
      (m.utt_private_a, m.utt_private_v),
      (m.utt_private_t, m.utt_private_v),
  )
  (first_x, first_y), *remaining = pairs
  total = loss_diff(first_x, first_y)
  for x, y in remaining:
    total = total + loss_diff(x, y)
  return total

## Setting up optimizer and losses

In [None]:
# one Adam optimizer over the classifier and the MISA encoder, so both are trained jointly
optimizer = torch.optim.Adam([*model.parameters(), *misa_model.parameters()], lr=LR)
# mode 'min': reduce the LR when the monitored quantity stops decreasing
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
total_steps = EPOCHS * len(train_data_loader)
# MISA losses
loss_recon, loss_cmd, loss_diff = MSE(), CMD(), DiffLoss()
# task loss
loss_fn = nn.CrossEntropyLoss().to(device)

## Training and Validation

In [None]:
history = defaultdict(list)
best_accuracy = 0
for epoch in range(EPOCHS):
  print(f'Epoch {epoch + 1}/{EPOCHS}')
  print('-' * 10)
  # Accuracy is normalised by the number of windows the loader actually
  # scores (the dataset length), not by the number of utterances in the split.
  train_acc, train_loss = train_epoch(
    model,
    misa_model,
    train_data_loader,
    loss_fn,
    loss_recon,
    loss_cmd,
    loss_diff,
    optimizer,
    device,
    len(train_data_loader.dataset))
  print(f'Train loss {train_loss} accuracy {train_acc}')
  val_acc, val_loss = eval_model(
    model,
    misa_model,
    val_data_loader,
    loss_fn,
    loss_recon,
    loss_cmd,
    loss_diff,
    device,
    len(val_data_loader.dataset)
  )

  # ReduceLROnPlateau monitors the validation loss
  scheduler.step(val_loss)

  print(f'Val   loss {val_loss} accuracy {val_acc}')
  print()
  history['train_acc'].append(train_acc)
  history['train_loss'].append(train_loss)
  history['val_acc'].append(val_acc)
  history['val_loss'].append(val_loss)
  if val_acc > best_accuracy:
    # Checkpoint both halves of the pipeline: the classifier only makes sense
    # together with the MISA encoder weights from the same epoch.
    torch.save(model.state_dict(), 'best_model_state.bin')
    torch.save(misa_model.state_dict(), 'best_misa_model_state.bin')
    best_accuracy = val_acc

## Results (precision, recall and F1 scores, along with the confusion matrix)

In [None]:
def get_predictions(model, data_loader):
  """Run the MISA encoder + classifier over ``data_loader`` and collect outputs.

  The global ``misa_model`` is used as the encoder and is put in eval mode.

  Returns:
    (review_texts, predictions, prediction_probs, real_values):
    review_texts is always an empty list (kept for the return signature);
    predictions are argmax class ids, prediction_probs are softmax class
    probabilities, and real_values are the labels of each window's last
    utterance. All tensors are on the CPU.
  """
  model = model.eval()
  # the encoder contains dropout, so it must be in eval mode as well
  misa_model.eval()
  review_texts = []
  predictions = []
  prediction_probs = []
  real_values = []
  with torch.no_grad():
    for data in data_loader:
      textf = data['textf'].to(device)
      acouf = data['audiof'].to(device)
      visuf = data['videof'].to(device)
      fused_repr = misa_model(textf.float(),visuf.float(),acouf.float(),len(textf)) # shape: (bs,seq_len, Z)
      outputs = model(fused_repr)
      targets = data['labels'][:, -1].long()
      _, preds = torch.max(outputs, dim=1)
      predictions.extend(preds)
      # the classifier emits logits; convert them to probabilities
      prediction_probs.extend(F.softmax(outputs, dim=1))
      real_values.extend(targets)
  predictions = torch.stack(predictions).cpu()
  prediction_probs = torch.stack(prediction_probs).cpu()
  real_values = torch.stack(real_values).cpu()
  return review_texts, predictions, prediction_probs, real_values


model.load_state_dict(torch.load('best_model_state.bin', map_location=device))
# restore the encoder from the same best epoch when its checkpoint is available
if os.path.exists('best_misa_model_state.bin'):
  misa_model.load_state_dict(torch.load('best_misa_model_state.bin', map_location=device))
y_review_texts, y_pred, y_pred_probs, y_test = get_predictions(
  model,
  test_data_loader
)

In [None]:
print('Stats for ',modality_and_repr_type,'\n')
class_names = ['not humorous', 'humorous']
# labels=[0, 1] keeps the report aligned with class_names even when a class
# is missing from both y_test and y_pred (e.g. on a small split)
print(classification_report(y_test, y_pred, labels=[0, 1], target_names=class_names))

In [None]:
# fixed label order [0, 1] always yields a 2x2 matrix (rows: true, cols: predicted)
print(confusion_matrix(y_test, y_pred, labels=[0, 1]))