# layers.py

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class myRNN(nn.Module):
    def __init__(self,input_size,hidden_size,num_layers,dp=0,bd=False):
        super(myRNN,self).__init__()
        self.hidden_dim = hidden_size
        self.n_layers = num_layers
        self.RNN = nn.GRU(input_size = input_size,hidden_size=hidden_size,num_layers=num_layers,dropout=dp,batch_first=True,bidirectional=bd)
       
    def forward(self,x,h0=None):
        out,h = self.RNN(x,h0)
        return out,h

class attention(nn.Module):
    def __init__(self,qembed_dim, kembed_dim=None, vembed_dim=None, hidden_dim=None, out_dim=None, dropout=0):
        super(attention, self).__init__()
        if kembed_dim is None:
            kembed_dim = qembed_dim
        if hidden_dim is None:
            hidden_dim = kembed_dim
        if out_dim is None:
            out_dim = kembed_dim
        if vembed_dim is None:
            vembed_dim = kembed_dim
            
        self.qembed_dim = qembed_dim
        self.kembed_dim = kembed_dim
        self.vembed_dim = vembed_dim
        
        self.hidden_dim = hidden_dim
        self.for_key = nn.Linear(kembed_dim,hidden_dim)
        self.for_query = nn.Linear(qembed_dim,hidden_dim)
        self.for_value = nn.Linear(vembed_dim,hidden_dim)
        self.normalise_factor = hidden_dim**(1/2)
    
    def mask_score(self,s,m):
        for i in range(s.size()[0]):
            for j in range(s.size()[1]):
                for k in range(s.size()[2]):
                    if m[i][j][k] == 0:
                        s[i][j][k] = float('-inf')   #So that after softmax, 0 weight is given to it
        return s
    
    def forward(self,key,query,mask=None):
        if len(query.shape) == 1:
            query = torch.unsqueeze(query, dim=0)
        if len(key.shape) == 1:
            key = torch.unsqueeze(key, dim=0)
            
        if len(query.shape) == 2:
            query = torch.unsqueeze(query, dim=1)
        if len(key.shape) == 2:
            key = torch.unsqueeze(key, dim=1)
            
        new_query = self.for_query(query)
        new_key = self.for_key(key)
        new_value = self.for_value(key)
        
        score = torch.bmm(new_query,new_key.permute(0,2,1))/self.normalise_factor
        
        if mask != None:
            score = self.mask_score(score,mask)
            
        score = F.softmax(score,-1)
        score.data[score!=score] = 0         #removing nan values
        
        output = torch.bmm(score,new_value)
        return output,score

class interact(nn.Module):
    def __init__(self,hidden_dim,weight_matrix,utt2idx):
        super(interact, self).__init__()
        self.hidden_size = hidden_dim

        self.embedding, num_embeddings, embedding_dim = create_emb_layer(weight_matrix,utt2idx)
        self.rnnD = myRNN(embedding_dim, hidden_dim,1)   #Dialogue
        self.drop1 = nn.Dropout()
        
        self.rnnG = myRNN(embedding_dim*3, hidden_dim,1)   #Global level
        self.drop2 = nn.Dropout()
        
        self.attn = attention(embedding_dim)
        
        self.rnnS = myRNN(embedding_dim*2, embedding_dim*2,1)   #Speaker representation
        self.drop3 = nn.Dropout()

    def forward(self, chat_ids, speaker_info, sp_dialogues, sp_ind, inputs):
        whole_dialogue_indices = inputs
        
        bert_embs = self.embedding(whole_dialogue_indices)
               
        dialogue, h1 = self.rnnD(bert_embs)    #Get global level representation
        dialogue = self.drop1(dialogue)

        device = inputs.device
        
        fop = torch.zeros((dialogue.size()[0],dialogue.size()[1],dialogue.size()[2])).to(device)
        fop2 = torch.zeros((dialogue.size()[0],dialogue.size()[1],dialogue.size()[2]*3)).to(device)
        op = torch.zeros((dialogue.size()[0],dialogue.size()[1],dialogue.size()[2])).to(device)
        spop = torch.zeros((dialogue.size()[0],dialogue.size()[1],dialogue.size()[2]*2)).to(device)
                    
        h0 = torch.randn(1, 1, self.hidden_size*2).to(device)
        d_h = torch.randn(1, 1, self.hidden_size).to(device)
        attn_h = torch.randn(1, 1, self.hidden_size).to(device)
        
        for b in range(dialogue.size()[0]):
            d_id = chat_ids[b]
            speaker_hidden_states = {}
            for s in range(dialogue.size()[1]):
                fop = op.clone()
                
                current_utt = dialogue[b][s]
                
                current_speaker = speaker_info[d_id][s]
                
                if current_speaker not in speaker_hidden_states:
                    speaker_hidden_states[current_speaker] = h0
                
                h = speaker_hidden_states[current_speaker]
                current_utt_emb = torch.unsqueeze(torch.unsqueeze(current_utt,0),0)
                
                key = fop[b][:s+1].clone()
                key = torch.unsqueeze(key,0)
                
                if s == 0:
                    tmp = torch.cat([attn_h,current_utt_emb],-1).to(device)
                    spop[b][s], h_new = self.rnnS(tmp,h)
                else:
                    query = current_utt_emb
                    attn_op,_ = self.attn(key,query)
                    
                    tmp = torch.cat([attn_op,current_utt_emb],-1).to(device)
                    spop[b][s], h_new = self.rnnS(tmp,h)
                
                spop[b][s] = spop[b][s].add(tmp)        # Residual Connection        
                speaker_hidden_states[current_speaker] = h_new
                
                fop2[b][s] = torch.cat([spop[b][s],dialogue[b][s]],-1)
                tmp = torch.unsqueeze(torch.unsqueeze(fop2[b][s].clone(),0),0)
                op[b][s],d_h = self.rnnG(tmp,d_h)

        return op,spop
    
class fc_e(nn.Module):
    def __init__(self,inp_dim,op_dim):
        super(fc_e,self).__init__()
        self.linear1 = nn.Linear(inp_dim,int(inp_dim/2))
        self.drop1 = nn.Dropout()
        
        self.linear2 = nn.Linear(int(inp_dim/2),int(inp_dim/4))
        self.drop2 = nn.Dropout(0.6)
        
        self.linear3 = nn.Linear(int(inp_dim/4),op_dim)
        self.drop3 = nn.Dropout(0.7)
    def forward(self,x):
        ip = x.float()
    
        op = self.linear1(ip)
        op = self.drop1(op)
        
        op = self.linear2(op)
        op = self.drop2(op)
        
        op = self.linear3(op)
        op = self.drop3(op)
        
        return op

class fc_t(nn.Module):
    def __init__(self,inp_dim,op_dim):
        super(fc_t,self).__init__()
        self.linear1 = nn.Linear(inp_dim,inp_dim)
        self.drop1 = nn.Dropout(0.7)
        
        self.linear2 = nn.Linear(inp_dim,inp_dim)
        self.drop2 = nn.Dropout(0.7)
        
        self.linear3 = nn.Linear(inp_dim,int(inp_dim/2))
        self.drop3 = nn.Dropout(0.7)
        
        self.linear4 = nn.Linear(int(inp_dim/2),int(inp_dim/4))
        self.drop4 = nn.Dropout(0.7)
        
        self.linear5 = nn.Linear(int(inp_dim/4),op_dim)
        self.drop5 = nn.Dropout(0.7)
    def forward(self,x):
        ip = x.float()
    
        op = self.linear1(ip)
        op = self.drop1(op)
        
        op = self.linear2(ip)
        op = self.drop2(op)
        
        op = self.linear3(ip)
        op = self.drop3(op)
        
        op = self.linear4(op)
        op = self.drop4(op)
        
        op = self.linear5(op)
        op = self.drop5(op)
        
        return op
    
class maskedattn(nn.Module):
    def __init__(self,batch_size, s_len, emb_size):
        super(maskedattn,self).__init__()
        self.b_len = batch_size
        self.s_len = s_len
        self.emb_size = emb_size
        self.attn = attention(emb_size*2, kembed_dim=emb_size, out_dim=emb_size)
    
    def create_mask(self,n):
        mask = torch.zeros((1, self.s_len, self.emb_size), dtype=torch.uint8)
        mask[:n+1] = torch.ones((self.emb_size), dtype=torch.uint8)
        mask = mask.repeat(self.b_len,1,1)
        return mask
        
    def forward(self,key,query):
        device = key.device

        ops = torch.zeros([key.size()[0],key.size()[1], key.size()[2]], dtype=torch.float32).to(device)
        for i in range(key.size()[1]):
          mask = self.create_mask(i)
          op,_ = self.attn(key,query,mask=mask)
          for b in range(op.size()[0]):
            ops[b][i] = op[b][i]
        return ops
    
class memnet(nn.Module):
  def __init__(self,num_hops,hidden_size,batch_size,seq_len):
    super(memnet,self).__init__()
    self.num_hops = num_hops
    self.rnn = myRNN(hidden_size, hidden_size, 1)
    self.masked_attention = maskedattn(batch_size,seq_len,hidden_size)
  
  def forward(self,globl,spl):
    X = globl
    for hop in range(self.num_hops):
      dialogue,h = self.rnn(X)
      X = self.masked_attention(dialogue,spl)
    return X

class pool(nn.Module):
    def __init__(self,mode="mean"):
        super(pool,self).__init__()
        self.mode = mode
    def forward(self,x):
        device = x.device
        op = torch.zeros((x.size()[0],x.size()[1],x.size()[2])).to(device)
        for b in range(x.size()[0]):
            this_tensor = []
            for s in range(x.size()[1]):
                this_tensor.append(x[b][s])
                if self.mode == "mean":
                    op[b][s] = torch.mean(torch.stack(this_tensor),0)
                elif self.mode == "max":
                    op[b][s],_ = torch.max(torch.stack(this_tensor),0)
                elif self.mode == "sum":
                    op[b][s] = torch.sum(torch.stack(this_tensor),0)
                else:
                    print("Error: Mode can be either mean or max only")
        return op

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

# Pickle Data Loader.py

In [7]:
pickle_folder_path = "/kaggle/input/meld-pickles/Pickles/"

import pickle

def load_erc():
    with open(pickle_folder_path + "idx2utt.pickle","rb") as f:
        idx2utt = pickle.load(f)
    with open(pickle_folder_path + "utt2idx.pickle","rb") as f:
        utt2idx = pickle.load(f)
        
    with open(pickle_folder_path + "idx2emo.pickle","rb") as f:
        idx2emo = pickle.load(f)
    with open(pickle_folder_path + "emo2idx.pickle","rb") as f:
        emo2idx = pickle.load(f)
        
    with open(pickle_folder_path + "idx2speaker.pickle","rb") as f:
        idx2speaker = pickle.load(f)
    with open(pickle_folder_path + "speaker2idx.pickle","rb") as f:
        speaker2idx = pickle.load(f)

    with open(pickle_folder_path + "weight_matrix.pickle","rb") as f:
        weight_matrix = pickle.load(f)

    with open(pickle_folder_path + "train_data.pickle","rb") as f:
        my_dataset_train = pickle.load(f)
        
    with open(pickle_folder_path + "test_data.pickle","rb") as f:
        my_dataset_test = pickle.load(f)
        
    with open(pickle_folder_path + "final_speaker_info.pickle","rb") as f:
        final_speaker_info = pickle.load(f)
        
    with open(pickle_folder_path + "final_speaker_dialogues.pickle","rb") as f:
        final_speaker_dialogues = pickle.load(f)
        
    with open(pickle_folder_path + "final_speaker_emotions.pickle","rb") as f:
        final_speaker_emotions = pickle.load(f)
        
    with open(pickle_folder_path + "final_speaker_indices.pickle","rb") as f:
        final_speaker_indices = pickle.load(f)
        
    with open(pickle_folder_path + "final_utt_len.pickle","rb") as f:
        final_utt_len = pickle.load(f)

    return idx2utt, utt2idx, idx2emo, emo2idx, idx2speaker,\
        speaker2idx, weight_matrix, my_dataset_train, my_dataset_test,\
        final_speaker_info, final_speaker_dialogues, final_speaker_emotions,\
        final_speaker_indices, final_utt_len

def load_efr():
    with open(pickle_folder_path + "idx2utt.pickle","rb") as f:
        idx2utt = pickle.load(f)
    with open(pickle_folder_path + "utt2idx.pickle","rb") as f:
        utt2idx = pickle.load(f)
        
    with open(pickle_folder_path + "idx2emo.pickle","rb") as f:
        idx2emo = pickle.load(f)
    with open(pickle_folder_path + "emo2idx.pickle","rb") as f:
        emo2idx = pickle.load(f)
        
    with open(pickle_folder_path + "idx2speaker.pickle","rb") as f:
        idx2speaker = pickle.load(f)
    with open(pickle_folder_path + "speaker2idx.pickle","rb") as f:
        speaker2idx = pickle.load(f)

    with open(pickle_folder_path + "weight_matrix.pickle","rb") as f:
        weight_matrix = pickle.load(f)

    with open(pickle_folder_path + "train_data_trig.pickle","rb") as f:
        my_dataset_train = pickle.load(f)

    with open(pickle_folder_path + "test_data_trig.pickle","rb") as f:
        my_dataset_test = pickle.load(f)
        
    with open(pickle_folder_path + "global_speaker_info_trig.pickle","rb") as f:
        global_speaker_info = pickle.load(f)
        
    with open(pickle_folder_path + "speaker_dialogues_trig.pickle","rb") as f:
        speaker_dialogues = pickle.load(f)
        
    with open(pickle_folder_path + "speaker_emotions_trig.pickle","rb") as f:
        speaker_emotions = pickle.load(f)
        
    with open(pickle_folder_path + "speaker_indices_trig.pickle","rb") as f:
        speaker_indices = pickle.load(f)
        
    with open(pickle_folder_path + "utt_len_trig.pickle","rb") as f:
        utt_len = pickle.load(f)
        
    with open(pickle_folder_path + "global_speaker_info_test_trig.pickle","rb") as f:
        global_speaker_info_test = pickle.load(f)
        
    with open(pickle_folder_path + "speaker_dialogues_test_trig.pickle","rb") as f:
        speaker_dialogues_test = pickle.load(f)
        
    with open(pickle_folder_path + "speaker_emotions_test_trig.pickle","rb") as f:
        speaker_emotions_test = pickle.load(f)
        
    with open(pickle_folder_path + "speaker_indices_test_trig.pickle","rb") as f:
        speaker_indices_test = pickle.load(f)
        
    with open(pickle_folder_path + "utt_len_test_trig.pickle","rb") as f:
        utt_len_test = pickle.load(f)

    return idx2utt, utt2idx, idx2emo, emo2idx, idx2speaker,\
        speaker2idx, weight_matrix, my_dataset_train, my_dataset_test,\
        global_speaker_info, speaker_dialogues, speaker_emotions, \
        speaker_indices, utt_len, global_speaker_info_test, speaker_dialogues_test, \
        speaker_emotions_test, speaker_indices_test, utt_len_test

# utils.py

In [8]:
import torch.nn as nn

##Source: https://medium.com/@martinpella/how-to-use-pre-trained-word-embeddings-in-pytorch-71ca59249f76
def create_emb_layer(weights_matrix, utt2idx, non_trainable=False):
    num_embeddings, embedding_dim = weights_matrix.size()
    emb_layer = nn.Embedding(num_embeddings, embedding_dim, padding_idx=utt2idx["<pad>"])
    emb_layer.load_state_dict({'weight': weights_matrix})
    if non_trainable:
        emb_layer.weight.requires_grad = False
    return emb_layer, num_embeddings, embedding_dim

# models.py

In [9]:
import torch
from torch.nn import TransformerEncoder, TransformerEncoderLayer

class ERC_MMN(nn.Module):
    def __init__(self,hidden_size,weight_matrix,utt2idx,batch_size,seq_len):
        super(ERC_MMN,self).__init__()
        self.ia = interact(hidden_size,weight_matrix,utt2idx)
        self.mn = memnet(4,hidden_size,batch_size,seq_len)
        self.pool = pool()
        
        self.rnn_c = myRNN(hidden_size*3,hidden_size*2,1)
        
        self.rnn_e = myRNN(hidden_size*2,hidden_size*2,1)
                
        self.linear1 = fc_e(hidden_size*2,7)

    def forward(self,c_ids,speaker_info,sp_dialogues,sp_em,sp_ind,x1,mode="train"):
        glob, splvl = self.ia(c_ids,speaker_info,sp_dialogues,sp_ind,x1)

        op = self.mn(glob,splvl)
        op = self.pool(op)

        op = torch.cat([splvl,op],dim=2)

        rnn_c_op,_ = self.rnn_c(op)

        rnn_e_op,_ = self.rnn_e(rnn_c_op)
        fip = rnn_e_op.add(rnn_c_op)      # Residual Connection
        fop1 = self.linear1(fip)

        return fip,fop1

class EFR_TX(nn.Module):
    def __init__(self, weight_matrix, utt2idx, nclass, ninp, count_speakers, nsp, nhead, nhid, nlayers, device, dropout=0.5):
        super(EFR_TX, self).__init__()
        self.model_type = 'Transformer'
        self.src_mask = None
        self.pos_encoder = PositionalEncoding(ninp, dropout)
        encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.encoder, num_embeddings, embedding_dim = create_emb_layer(weight_matrix, utt2idx)
        self.ninp = ninp
        self.decoder = nn.Linear(2*ninp, nclass)
        self.speakers_embedding = torch.nn.Embedding(count_speakers, nsp)

        self.init_weights()
        self.device = device

    def _generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, src, d_ids, sp_ids, ut_len):
        device = 'cuda'
        torch.set_default_device('cuda')
        if self.src_mask is None or self.src_mask.size(0) != len(src):
            device = src.device
            mask = self._generate_square_subsequent_mask(len(src)).to(device)
            self.src_mask = mask

        # Old Code
        # src = self.encoder(src) * math.sqrt(self.ninp)
        # New
        src = self.encoder(src)
        new_src = torch.zeros(src.shape[0],src.shape[1],self.ninp)
        for ix1,mat in enumerate(src):
            for ix2,vec in enumerate(mat):
                new_src[ix1][ix2] = torch.cat([self.speakers_embedding(torch.tensor(sp_ids[ix1][ix2], device=device, dtype=torch.long)), src[ix1][ix2]],-1)
        src = new_src
        src = src * math.sqrt(self.ninp)        
        
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src, self.src_mask)
        
        decoder_ip = torch.zeros(output.size()[0],output.size()[1],output.size()[2]*2).to(self.device)
        for b in range(output.size()[0]):
            d_id = d_ids[b][0]
            main_utt = output[b][ut_len[d_id]-1]
            for s in range(ut_len[d_id]):
                this_utt = output[b][s]
                decoder_ip[b][s] = torch.cat([this_utt,main_utt],-1)
        
        output = self.decoder(decoder_ip)
        
        return decoder_ip,output

class ERC_true_EFR(nn.Module):
    def __init__(self, weight_matrix, utt2idx, nclass, ninp, nhead, nhid, nlayers, dropout=0.5):
        super(ERC_true_EFR, self).__init__()
        self.model_type = 'Transformer'
        self.src_mask = None
        self.pos_encoder = PositionalEncoding(ninp, dropout)
        encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.encoder, num_embeddings, embedding_dim = create_emb_layer(weight_matrix,utt2idx)
        
        self.emoGRU = myRNN(7,100,1)
        self.ninp = ninp
        self.decoder = nn.Linear(2*ninp+100, nclass)

        self.init_weights()

    def _generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, src, em_seq, d_ids, ut_len):
        if self.src_mask is None or self.src_mask.size(0) != len(src):
            device = src.device
            mask = self._generate_square_subsequent_mask(len(src)).to(device)
            self.src_mask = mask

        src = self.encoder(src) * math.sqrt(self.ninp)
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src, self.src_mask)
        
        emo_seq,_ = self.emoGRU(em_seq.float())
        
        decoder_ip = torch.zeros(output.size()[0],output.size()[1],output.size()[2]*2).cuda()
        for b in range(output.size()[0]):
            d_id = d_ids[b][0]
            main_utt = output[b][ut_len[d_id]-1]
            for s in range(ut_len[d_id]):
                this_utt = output[b][s]
                decoder_ip[b][s] = torch.cat([this_utt,main_utt],-1)
        
        decoder_ip = torch.cat([decoder_ip,emo_seq],-1)
        output = self.decoder(decoder_ip)
        
        return output

class ERC_EFR_multitask(nn.Module):
    def __init__(self,hidden_size,weight_matrix,utt2idx,batch_size,seq_len):
        super(ERC_EFR_multitask,self).__init__()
        self.ia = interact(hidden_size,weight_matrix,utt2idx)
        self.mn = memnet(4,hidden_size,batch_size,seq_len)
        self.pool = pool()
        
        self.rnn_c = myRNN(hidden_size*3,hidden_size*2,1)
        
        self.rnn_e = myRNN(hidden_size*2,hidden_size*2,1)
        self.rnn_t = myRNN(hidden_size*2,hidden_size,1)

        self.linear1 = fc_e(hidden_size*2,7)
        self.linear2 = fc_t(hidden_size*2,2)

    def forward(self,c_ids,speaker_info,sp_dialogues,sp_em,sp_ind,freeze,x1,mode="train"):
        speaker_emo = {}
        speaker_emo_distance = {}
        
        for d_id in c_ids:
            speaker_emo[d_id] = {}
            speaker_emo_distance[d_id] = {}
                    
        if freeze:
            with torch.no_grad():
                glob, splvl = self.ia(c_ids,speaker_info,sp_dialogues,sp_ind,x1)
        
                op = self.mn(glob,splvl)
                op = self.pool(op)

                op = torch.cat([splvl,op],dim=2)

                rnn_c_op,_ = self.rnn_c(op)

                rnn_e_op,_ = self.rnn_e(rnn_c_op)
                rnn_e_op = rnn_e_op.add(rnn_c_op)      # Residual Connection
                fop1 = self.linear1(rnn_e_op)
        else:
            glob, splvl = self.ia(c_ids,speaker_info,sp_dialogues,sp_ind,x1)
        
            op = self.mn(glob,splvl)
            op = self.pool(op)
            
            op = torch.cat([splvl,op],dim=2)

            rnn_c_op,_ = self.rnn_c(op)

            rnn_e_op,_ = self.rnn_e(rnn_c_op)
            rnn_e_op = rnn_e_op.add(rnn_c_op)      # Residual Connection
            fop1 = self.linear1(rnn_e_op)
        
        rnn_t_op,_ = self.rnn_t(rnn_c_op)

        fop2_final = []
        for b in range(rnn_t_op.size()[0]):
            d_id = c_ids[b]
            fop2_final_tmp = []
            for s in range(rnn_t_op.size()[1]):
                fop2_final_tmp_tmp = []
                concerned_utt = rnn_t_op[b][s]
                
                if s < 4:
                    r = s+1
                else:
                    r = 4
                
                for s2 in range(r,-1,-1):
                    this_utt = rnn_t_op[b][s-s2]
                    tmp = torch.cat((concerned_utt,this_utt),-1)
                    fop2 = self.linear2(tmp)

                    fop2_final_tmp_tmp.append(fop2)
                fop2_final_tmp.append(fop2_final_tmp_tmp)
            fop2_final.append(fop2_final_tmp)
        return fop1,fop2_final

class cascade(nn.Module):
    def __init__(self,hidden_size,nclasses):
        super(cascade,self).__init__()        
        self.linear = fc_e(hidden_size*4,nclasses)
    
    def forward(self,x1):
        op = self.linear(x1)
        return op

# Train EFR-TX.py

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import tqdm
from torch.utils import data
from sklearn.metrics import classification_report
from sklearn.metrics import f1_score
from sklearn.metrics import confusion_matrix

batch_size = 128
seq_len = 5
seq2_len = seq_len
emb_size = 768
hidden_size = 768
batch_first = True

torch.set_default_device('cuda')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

idx2utt, utt2idx, idx2emo, emo2idx, idx2speaker,\
        speaker2idx, weight_matrix, my_dataset_train, my_dataset_test,\
        global_speaker_info, speaker_dialogues, speaker_emotions, \
        speaker_indices, utt_len, global_speaker_info_test, speaker_dialogues_test, \
        speaker_emotions_test, speaker_indices_test, utt_len_test = load_efr()
    
def get_train_test_loader(bs):
    print(len(my_dataset_train))
    train_data_iter = data.DataLoader(my_dataset_train, batch_size=bs)
    test_data_iter = data.DataLoader(my_dataset_test, batch_size=bs)
    
    return train_data_iter, test_data_iter
    
def train(model, train_data_loader, epochs):
    class_weights2 = torch.FloatTensor(weights2).to(device)
    criterion2 = nn.CrossEntropyLoss(weight=class_weights2,reduction='none').to(device)
    
    optimizer = torch.optim.Adam(model.parameters(),lr=5e-8,weight_decay=1e-5)
    
    max_f1_2 = 0
   
    for epoch in tqdm(range(epochs)):
        print("\n\n-------Epoch {}-------\n\n".format(epoch+1))
        model.train()
        
        avg_loss = 0
       
        y_true2 = []
        y_pred2 = []
            
        for i_batch, sample_batched in tqdm(enumerate(train_data_loader)):
            dialogue_ids = sample_batched[0].tolist()
            inputs = sample_batched[1].to(device)
            targets2 = sample_batched[3].to(device)
            
            # Creating the speaker_ids
            speaker_ids = []
            for d_ids_list in dialogue_ids:
              sp_id_list = [0] * len(d_ids_list)
              for ix, d_id in enumerate(d_ids_list):
                sp_id = global_speaker_info[d_id][0]
                sp_id_list[ix] = sp_id
              speaker_ids.append(sp_id_list)
            
            optimizer.zero_grad()
            
            _,outputs = model(inputs,dialogue_ids,speaker_ids,utt_len)
            
            loss = 0
            for b in range(outputs.size()[0]):
              loss2 = 0
              
              for s in range(utt_len[dialogue_ids[b][0]]):
                pred2 = outputs[b][s]
                pred_flip = torch.argmax(F.softmax(pred2.to(device),-1),-1)
                
                truth2 = targets2[b][s]

                y_pred2.append(pred_flip.item())
                y_true2.append(truth2.long().to(device).item())

                pred2_ = torch.unsqueeze(pred2,0)
                truth2_ = torch.unsqueeze(truth2,0)
                
                loss2 += criterion2(pred2_,truth2_)
              loss2 /= utt_len[dialogue_ids[b][0]]
            
            loss += loss2
            loss /= outputs.size()[0]
            avg_loss += loss

            loss.backward()            
            optimizer.step()
            
        avg_loss /= len(train_data_loader)
        
        print("Average Loss = ",avg_loss)

        f1_2_cls,v_loss = validate(model, data_iter_test, epoch)
        
        # if f1_2_cls[1] > max_f1_2:
        #     print(f"Saving model at epoch {epoch}")
        #     max_f1_2 = f1_2_cls[1]
        #     torch.save(model.state_dict(), "./best_model.pth")

    return model

def validate(model, test_data_loader,epoch):
    print("\n\n***VALIDATION ({})***\n\n".format(epoch))
    
    class_weights2 = torch.FloatTensor(weights2).to(device)
    criterion2 = nn.CrossEntropyLoss(weight=class_weights2,reduction='none').to(device)

    model.eval()

    with torch.no_grad():
      avg_loss = 0
      y_true2 = []
      y_pred2 = []

      for i_batch, sample_batched in tqdm(enumerate(test_data_loader)):
            dialogue_ids = sample_batched[0].tolist()           
            inputs = sample_batched[1].to(device)
            targets2 = sample_batched[3].to(device)
            
            # Creating the speaker_ids
            speaker_ids = []
            for d_ids_list in dialogue_ids:
              sp_id_list = [0] * len(d_ids_list)
              for ix, d_id in enumerate(d_ids_list):
                sp_id = global_speaker_info[d_id][0]
                sp_id_list[ix] = sp_id
              speaker_ids.append(sp_id_list)
                       
            _,outputs = model(inputs,dialogue_ids,speaker_ids,utt_len)
            
            loss = 0
            for b in range(outputs.size()[0]):
              loss2 = 0
              
              for s in range(utt_len_test[dialogue_ids[b][0]]):
                pred2 = outputs[b][s]
                pred_flip = torch.argmax(F.softmax(pred2.to(device),-1),-1)
                
                truth2 = targets2[b][s]

                y_pred2.append(pred_flip.item())
                y_true2.append(truth2.long().to(device).item())

                pred2_ = torch.unsqueeze(pred2,0)
                truth2_ = torch.unsqueeze(truth2,0)
                
                loss2 += criterion2(pred2_,truth2_)
              loss2 /= utt_len_test[dialogue_ids[b][0]]
            
            loss += loss2
            loss /= outputs.size()[0]
            avg_loss += loss

      avg_loss /= len(test_data_loader)

      class_report = classification_report(y_true2,y_pred2)
      conf_mat2 = confusion_matrix(y_true2,y_pred2)

      print(class_report)
      print("Confusion Matrix: \n",conf_mat2)
    
      f1 = f1_score(y_true2,y_pred2)
      return f1,avg_loss

nclass = 2
utt_emsize = 768
personality_size = 100
nhid = 768
nlayers = 6
nhead = 2
dropout = 0.2
count_speakers = len(speaker2idx)
model = EFR_TX(weight_matrix, utt2idx, nclass, personality_size + utt_emsize, count_speakers, personality_size, nhead, nhid, nlayers, device, dropout).to(device)

weights2 = [1.0, 2.5]
data_iter_train, data_iter_test = get_train_test_loader(batch_size)
model = train(model, data_iter_train, epochs = 1000)

4000


  0%|          | 0/1000 [00:00<?, ?it/s]



-------Epoch 1-------




0it [00:00, ?it/s]

Average Loss =  tensor([0.0205], device='cuda:0', grad_fn=<DivBackward0>)


***VALIDATION (0)***




0it [00:00, ?it/s]

              precision    recall  f1-score   support

           0       0.76      0.92      0.83      1497
           1       0.35      0.13      0.19       490

    accuracy                           0.72      1987
   macro avg       0.55      0.53      0.51      1987
weighted avg       0.66      0.72      0.68      1987

Confusion Matrix: 
 [[1372  125]
 [ 424   66]]


-------Epoch 2-------




0it [00:00, ?it/s]

Average Loss =  tensor([0.0197], device='cuda:0', grad_fn=<DivBackward0>)


***VALIDATION (1)***




0it [00:00, ?it/s]

              precision    recall  f1-score   support

           0       0.76      0.91      0.83      1497
           1       0.33      0.14      0.20       490

    accuracy                           0.72      1987
   macro avg       0.55      0.53      0.51      1987
weighted avg       0.66      0.72      0.67      1987

Confusion Matrix: 
 [[1358  139]
 [ 420   70]]


-------Epoch 3-------




0it [00:00, ?it/s]

Average Loss =  tensor([0.0181], device='cuda:0', grad_fn=<DivBackward0>)


***VALIDATION (2)***




0it [00:00, ?it/s]

              precision    recall  f1-score   support

           0       0.76      0.90      0.82      1497
           1       0.32      0.15      0.21       490

    accuracy                           0.71      1987
   macro avg       0.54      0.52      0.52      1987
weighted avg       0.65      0.71      0.67      1987

Confusion Matrix: 
 [[1342  155]
 [ 416   74]]


-------Epoch 4-------




0it [00:00, ?it/s]

Average Loss =  tensor([0.0176], device='cuda:0', grad_fn=<DivBackward0>)


***VALIDATION (3)***




0it [00:00, ?it/s]

              precision    recall  f1-score   support

           0       0.76      0.89      0.82      1497
           1       0.31      0.16      0.21       490

    accuracy                           0.71      1987
   macro avg       0.54      0.52      0.52      1987
weighted avg       0.65      0.71      0.67      1987

Confusion Matrix: 
 [[1327  170]
 [ 412   78]]


-------Epoch 5-------




0it [00:00, ?it/s]

Average Loss =  tensor([0.0168], device='cuda:0', grad_fn=<DivBackward0>)


***VALIDATION (4)***




0it [00:00, ?it/s]

              precision    recall  f1-score   support

           0       0.77      0.87      0.82      1497
           1       0.32      0.19      0.24       490

    accuracy                           0.70      1987
   macro avg       0.54      0.53      0.53      1987
weighted avg       0.66      0.70      0.67      1987

Confusion Matrix: 
 [[1305  192]
 [ 399   91]]


-------Epoch 6-------




0it [00:00, ?it/s]

Average Loss =  tensor([0.0160], device='cuda:0', grad_fn=<DivBackward0>)


***VALIDATION (5)***




0it [00:00, ?it/s]

              precision    recall  f1-score   support

           0       0.77      0.86      0.81      1497
           1       0.31      0.20      0.24       490

    accuracy                           0.69      1987
   macro avg       0.54      0.53      0.53      1987
weighted avg       0.65      0.69      0.67      1987

Confusion Matrix: 
 [[1280  217]
 [ 392   98]]


-------Epoch 7-------




0it [00:00, ?it/s]

Average Loss =  tensor([0.0149], device='cuda:0', grad_fn=<DivBackward0>)


***VALIDATION (6)***




0it [00:00, ?it/s]

              precision    recall  f1-score   support

           0       0.77      0.83      0.80      1497
           1       0.30      0.22      0.26       490

    accuracy                           0.68      1987
   macro avg       0.53      0.53      0.53      1987
weighted avg       0.65      0.68      0.66      1987

Confusion Matrix: 
 [[1246  251]
 [ 381  109]]


-------Epoch 8-------




0it [00:00, ?it/s]

Average Loss =  tensor([0.0145], device='cuda:0', grad_fn=<DivBackward0>)


***VALIDATION (7)***




0it [00:00, ?it/s]

              precision    recall  f1-score   support

           0       0.77      0.82      0.79      1497
           1       0.30      0.24      0.27       490

    accuracy                           0.68      1987
   macro avg       0.54      0.53      0.53      1987
weighted avg       0.65      0.68      0.66      1987

Confusion Matrix: 
 [[1228  269]
 [ 372  118]]


-------Epoch 9-------




0it [00:00, ?it/s]

Average Loss =  tensor([0.0145], device='cuda:0', grad_fn=<DivBackward0>)


***VALIDATION (8)***




0it [00:00, ?it/s]

              precision    recall  f1-score   support

           0       0.77      0.80      0.78      1497
           1       0.30      0.26      0.27       490

    accuracy                           0.67      1987
   macro avg       0.53      0.53      0.53      1987
weighted avg       0.65      0.67      0.66      1987

Confusion Matrix: 
 [[1199  298]
 [ 365  125]]


-------Epoch 10-------




0it [00:00, ?it/s]

Average Loss =  tensor([0.0143], device='cuda:0', grad_fn=<DivBackward0>)


***VALIDATION (9)***




0it [00:00, ?it/s]

              precision    recall  f1-score   support

           0       0.77      0.79      0.78      1497
           1       0.29      0.26      0.28       490

    accuracy                           0.66      1987
   macro avg       0.53      0.53      0.53      1987
weighted avg       0.65      0.66      0.66      1987

Confusion Matrix: 
 [[1187  310]
 [ 361  129]]


-------Epoch 11-------




0it [00:00, ?it/s]

Average Loss =  tensor([0.0127], device='cuda:0', grad_fn=<DivBackward0>)


***VALIDATION (10)***




0it [00:00, ?it/s]

              precision    recall  f1-score   support

           0       0.77      0.78      0.77      1497
           1       0.29      0.28      0.28       490

    accuracy                           0.65      1987
   macro avg       0.53      0.53      0.53      1987
weighted avg       0.65      0.65      0.65      1987

Confusion Matrix: 
 [[1162  335]
 [ 355  135]]


-------Epoch 12-------




0it [00:00, ?it/s]

Average Loss =  tensor([0.0159], device='cuda:0', grad_fn=<DivBackward0>)


***VALIDATION (11)***




0it [00:00, ?it/s]

              precision    recall  f1-score   support

           0       0.77      0.76      0.76      1497
           1       0.28      0.30      0.29       490

    accuracy                           0.64      1987
   macro avg       0.53      0.53      0.53      1987
weighted avg       0.65      0.64      0.65      1987

Confusion Matrix: 
 [[1133  364]
 [ 345  145]]


-------Epoch 13-------




0it [00:00, ?it/s]

Average Loss =  tensor([0.0135], device='cuda:0', grad_fn=<DivBackward0>)


***VALIDATION (12)***




0it [00:00, ?it/s]

              precision    recall  f1-score   support

           0       0.76      0.73      0.75      1497
           1       0.28      0.31      0.29       490

    accuracy                           0.63      1987
   macro avg       0.52      0.52      0.52      1987
weighted avg       0.64      0.63      0.64      1987

Confusion Matrix: 
 [[1096  401]
 [ 337  153]]


-------Epoch 14-------




0it [00:00, ?it/s]

Average Loss =  tensor([0.0134], device='cuda:0', grad_fn=<DivBackward0>)


***VALIDATION (13)***




0it [00:00, ?it/s]

              precision    recall  f1-score   support

           0       0.77      0.72      0.74      1497
           1       0.28      0.33      0.30       490

    accuracy                           0.62      1987
   macro avg       0.52      0.53      0.52      1987
weighted avg       0.65      0.62      0.63      1987

Confusion Matrix: 
 [[1076  421]
 [ 327  163]]


-------Epoch 15-------




0it [00:00, ?it/s]

Average Loss =  tensor([0.0110], device='cuda:0', grad_fn=<DivBackward0>)


***VALIDATION (14)***




0it [00:00, ?it/s]

              precision    recall  f1-score   support

           0       0.77      0.69      0.73      1497
           1       0.27      0.35      0.31       490

    accuracy                           0.61      1987
   macro avg       0.52      0.52      0.52      1987
weighted avg       0.64      0.61      0.62      1987

Confusion Matrix: 
 [[1040  457]
 [ 319  171]]


-------Epoch 16-------




0it [00:00, ?it/s]

Average Loss =  tensor([0.0141], device='cuda:0', grad_fn=<DivBackward0>)


***VALIDATION (15)***




0it [00:00, ?it/s]

              precision    recall  f1-score   support

           0       0.77      0.67      0.71      1497
           1       0.27      0.38      0.32       490

    accuracy                           0.60      1987
   macro avg       0.52      0.52      0.52      1987
weighted avg       0.65      0.60      0.62      1987

Confusion Matrix: 
 [[1001  496]
 [ 304  186]]


-------Epoch 17-------




0it [00:00, ?it/s]

Average Loss =  tensor([0.0129], device='cuda:0', grad_fn=<DivBackward0>)


***VALIDATION (16)***




0it [00:00, ?it/s]

              precision    recall  f1-score   support

           0       0.77      0.64      0.70      1497
           1       0.28      0.42      0.33       490

    accuracy                           0.58      1987
   macro avg       0.52      0.53      0.52      1987
weighted avg       0.65      0.58      0.61      1987

Confusion Matrix: 
 [[954 543]
 [283 207]]


-------Epoch 18-------




0it [00:00, ?it/s]

Average Loss =  tensor([0.0129], device='cuda:0', grad_fn=<DivBackward0>)


***VALIDATION (17)***




0it [00:00, ?it/s]

              precision    recall  f1-score   support

           0       0.77      0.62      0.68      1497
           1       0.27      0.44      0.34       490

    accuracy                           0.57      1987
   macro avg       0.52      0.53      0.51      1987
weighted avg       0.65      0.57      0.60      1987

Confusion Matrix: 
 [[921 576]
 [272 218]]


-------Epoch 19-------




0it [00:00, ?it/s]

Average Loss =  tensor([0.0139], device='cuda:0', grad_fn=<DivBackward0>)


***VALIDATION (18)***




0it [00:00, ?it/s]

              precision    recall  f1-score   support

           0       0.78      0.61      0.68      1497
           1       0.28      0.46      0.35       490

    accuracy                           0.57      1987
   macro avg       0.53      0.53      0.51      1987
weighted avg       0.65      0.57      0.60      1987

Confusion Matrix: 
 [[911 586]
 [264 226]]


-------Epoch 20-------




0it [00:00, ?it/s]

Average Loss =  tensor([0.0110], device='cuda:0', grad_fn=<DivBackward0>)


***VALIDATION (19)***




0it [00:00, ?it/s]

              precision    recall  f1-score   support

           0       0.78      0.59      0.67      1497
           1       0.28      0.48      0.35       490

    accuracy                           0.56      1987
   macro avg       0.53      0.53      0.51      1987
weighted avg       0.65      0.56      0.59      1987

Confusion Matrix: 
 [[879 618]
 [255 235]]


-------Epoch 21-------




0it [00:00, ?it/s]

Average Loss =  tensor([0.0110], device='cuda:0', grad_fn=<DivBackward0>)


***VALIDATION (20)***




0it [00:00, ?it/s]

              precision    recall  f1-score   support

           0       0.78      0.56      0.65      1497
           1       0.27      0.50      0.35       490

    accuracy                           0.55      1987
   macro avg       0.52      0.53      0.50      1987
weighted avg       0.65      0.55      0.58      1987

Confusion Matrix: 
 [[841 656]
 [244 246]]


-------Epoch 22-------




0it [00:00, ?it/s]

Average Loss =  tensor([0.0114], device='cuda:0', grad_fn=<DivBackward0>)


***VALIDATION (21)***




0it [00:00, ?it/s]

              precision    recall  f1-score   support

           0       0.78      0.55      0.64      1497
           1       0.27      0.52      0.36       490

    accuracy                           0.54      1987
   macro avg       0.53      0.53      0.50      1987
weighted avg       0.65      0.54      0.57      1987

Confusion Matrix: 
 [[819 678]
 [235 255]]


-------Epoch 23-------




0it [00:00, ?it/s]

Average Loss =  tensor([0.0137], device='cuda:0', grad_fn=<DivBackward0>)


***VALIDATION (22)***




0it [00:00, ?it/s]

              precision    recall  f1-score   support

           0       0.78      0.52      0.62      1497
           1       0.27      0.55      0.36       490

    accuracy                           0.52      1987
   macro avg       0.52      0.53      0.49      1987
weighted avg       0.65      0.52      0.56      1987

Confusion Matrix: 
 [[773 724]
 [222 268]]


-------Epoch 24-------




0it [00:00, ?it/s]