## Sub-Task 1 Inference

### Libraries

In [None]:
import pickle
import traceback
import types

import dill
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm
from torch.utils import data

### Testing Pickles

In [None]:
pickle_folder_path = "../Pickles/Task 1/"


def load_erc():
    """Load every pickled Task 1 artefact found under ``pickle_folder_path``.

    Each ``<stem>.pickle`` file is opened in binary mode and unpickled, in
    the fixed order listed below.

    Returns:
        A 14-tuple, in this order: idx2utt, utt2idx, idx2emo, emo2idx,
        idx2speaker, speaker2idx, weight_matrix, train data, test data,
        final_speaker_info, final_speaker_dialogues, final_speaker_emotions,
        final_speaker_indices, final_utt_len.
    """
    # File stems, listed in the order the values are returned.
    stems = (
        "idx2utt",
        "utt2idx",
        "idx2emo",
        "emo2idx",
        "idx2speaker",
        "speaker2idx",
        "weight_matrix",
        "train_data",
        "test_data",
        "final_speaker_info",
        "final_speaker_dialogues",
        "final_speaker_emotions",
        "final_speaker_indices",
        "final_utt_len",
    )

    loaded = []
    for stem in stems:
        with open(pickle_folder_path + stem + ".pickle", "rb") as handle:
            loaded.append(pickle.load(handle))

    return tuple(loaded)

In [None]:
# Bind everything returned by load_erc() to module-level names, in the
# same order in which load_erc() returns them.
(
    idx2utt, utt2idx, idx2emo, emo2idx, idx2speaker,
    speaker2idx, weight_matrix, my_dataset_train, my_dataset_test,
    final_speaker_info, final_speaker_dialogues, final_speaker_emotions,
    final_speaker_indices, final_utt_len,
) = load_erc()

### Layer

In [None]:
def create_emb_layer(weights_matrix, utt2idx):
    """Build a frozen ``nn.Embedding`` whose weights are ``weights_matrix``.

    Args:
        weights_matrix: 2-D tensor; its two sizes become the number of
            embeddings and the embedding dimension.
        utt2idx: mapping that contains a ``"<pad>"`` key; the mapped index
            is used as the layer's ``padding_idx``.

    Returns:
        Tuple ``(embedding_layer, num_embeddings, embedding_dim)``.
    """
    vocab_size, vector_dim = weights_matrix.size()
    pad_index = utt2idx["<pad>"]

    layer = nn.Embedding(vocab_size, vector_dim, padding_idx=pad_index)
    # Copy the given weights in (this also overwrites the padding row).
    layer.load_state_dict({'weight': weights_matrix})
    # The embedding table is not trained any further.
    layer.weight.requires_grad_(False)

    return layer, vocab_size, vector_dim

### Number of True Sentences

In [None]:
def get_num_sen(fsi):
    """Return the combined length of all values stored in mapping ``fsi``."""
    return sum(len(entries) for entries in fsi.values())

### Function to get one-hot speaker representation

In [None]:
top_speaker_names = ["maya", "indu", "sahil",
                     "monisha", "rosesh", "madhusudhan"]
NUM_SPK = len(top_speaker_names)


def get_spk_embedding(spk_original_ix):
    """Return the one-hot speaker vector for a speaker index.

    The speaker index is mapped to a name through the module-level
    ``idx2speaker``. Speakers listed in ``top_speaker_names`` get a one-hot
    vector at their position in that list; every other speaker gets an
    all-zero vector.

    The result is always a float32 tensor of length ``NUM_SPK`` on the
    module-level ``device``. (``torch.nn.functional.one_hot`` would return
    an int64 tensor, which does not match the float zero vector of the
    fallback branch nor the float embeddings it is concatenated with.)

    Args:
        spk_original_ix: key into ``idx2speaker``.

    Returns:
        1-D float32 tensor of length ``NUM_SPK``.

    Raises:
        KeyError / IndexError: if ``spk_original_ix`` is not a valid key or
            index of ``idx2speaker``.
    """
    name = idx2speaker[spk_original_ix]
    vec = torch.zeros(NUM_SPK, device=device)
    if name in top_speaker_names:
        vec[top_speaker_names.index(name)] = 1.0
    return vec

### Inference

In [None]:

def inference(model):
    """Run ``model`` over the test loader and collect emotion labels.

    The model is switched to eval mode and called, without gradient
    tracking, on every batch of the module-level ``data_iter_test``.
    For each dialogue in a batch:

    * ``alpha`` is ``get_num_sen(final_speaker_indices[dialogue_id])`` and
      ``beta`` is ``final_utt_len[dialogue_id]``;
    * when ``alpha > beta``, ``alpha - beta`` ``"neutral"`` labels are
      emitted first;
    * then one label is emitted for each of the first ``beta`` positions of
      the model output, namely ``idx2emo`` of the arg-max class.

    Args:
        model: callable module; it is invoked with the dialogue ids, the
            ``final_speaker_*`` structures, the batch inputs and
            ``mode="valid"``, and must return a pair whose second element
            is indexable as ``outputs[b][s]``.

    Returns:
        Flat list of emotion labels in loader order.
    """
    model.eval()
    ans = []

    with torch.no_grad():
        data_loader = data_iter_test

        for sample_batched in tqdm.tqdm(data_loader, total=len(data_loader)):
            # Ids coming from the test loader are shifted by the number of
            # training dialogues before they index the final_* structures.
            dialogue_ids = [train_cnt + d for d in sample_batched[0].tolist()]
            inputs = sample_batched[1]

            _, outputs = model(dialogue_ids, final_speaker_info, final_speaker_dialogues,
                               final_speaker_emotions, final_speaker_indices, inputs, mode="valid")

            for b in range(outputs.size()[0]):
                d_id = dialogue_ids[b]

                # True length
                alpha = get_num_sen(final_speaker_indices[d_id])
                beta = final_utt_len[d_id]

                # Positions without a model output get a default label.
                if alpha > beta:
                    ans.extend(["neutral"] * (alpha - beta))

                # Updating ans
                for s in range(beta):
                    probs = F.softmax(outputs[b][s], -1)
                    # .item() copies the value to the host, so no explicit
                    # device transfer is needed here.
                    ans.append(idx2emo[torch.argmax(probs, -1).item()])

    return ans

### For creating answer.txt

In [None]:
# Line ranges (0-based, END inclusive) that each sub-task occupies in the
# answer txt file: 1 = MaSaC ERC, 2 = MaSaC EFR, 3 = MELD EFR.

IX1_BEGIN, IX1_END = 0, 1579
IX2_BEGIN, IX2_END = 1580, 9269
IX3_BEGIN, IX3_END = 9270, 17911

In [None]:
def ans_to_txt(ans, file_name):
    """Write the answer file: ERC labels followed by placeholder lines.

    Lines ``IX1_BEGIN..IX1_END`` (inclusive) receive ``str(ans[k])`` for
    consecutive ``k`` starting at 0. Lines ``IX2_BEGIN..IX2_END`` and
    ``IX3_BEGIN..IX3_END`` are filled with ``"0.0"``. Finally
    ``len(ans)`` is printed.

    The file is opened with a context manager so that the handle is closed
    even when ``ans`` is too short and indexing raises part-way through.

    Args:
        ans: sequence of labels for the first section; needs at least
            ``IX1_END - IX1_BEGIN + 1`` items.
        file_name: path of the file to create or overwrite.

    Raises:
        IndexError: if ``ans`` has fewer items than the first section needs.
    """
    with open(file_name, 'w') as f:
        # MaSaC - ERC
        for i in range(IX1_BEGIN, IX1_END + 1):
            f.write(str(ans[i - IX1_BEGIN]) + "\n")

        # MaSaC - EFR
        for _ in range(IX2_BEGIN, IX2_END + 1):
            f.write("0.0\n")

        # MELD - EFR
        for _ in range(IX3_BEGIN, IX3_END + 1):
            f.write("0.0\n")

    print(len(ans))

### Updating Weight Matrix

In [None]:
device = 'cpu'

# Append a NUM_SPK-wide speaker vector to the embedding row of every
# utterance that occurs in the train or test data. Rows of utterances that
# never occur stay all-zero; an utterance index that occurs several times
# keeps the row written last.
utt_ix_set = set()
n, d = weight_matrix.shape

weight_matrix = weight_matrix.to(device)
new_weight_matrix = torch.zeros([n, d + NUM_SPK], device=device)

# final_speaker_info is indexed with test dialogues placed after the
# training dialogues (inference() adds the training count to every test
# dialogue id before indexing it), so the test pass needs the same offset.
# Without it, test utterances would be paired with the speakers of the
# first training dialogues.
for dialogue_offset, dataset in ((0, my_dataset_train),
                                 (len(my_dataset_train), my_dataset_test)):
    for ix1, sample in enumerate(dataset):
        for ix2, utt_ix in enumerate(sample[1]):
            ix_u = int(utt_ix)
            spk_ix = final_speaker_info[dialogue_offset + ix1][ix2]
            # Cast so the concatenation never mixes integer and float parts.
            spk_vec = get_spk_embedding(spk_ix).to(weight_matrix.dtype)
            new_weight_matrix[ix_u] = torch.cat([weight_matrix[ix_u], spk_vec])
            utt_ix_set.add(ix_u)

weight_matrix = new_weight_matrix

### Loading the Data

In [None]:
# Number of training dialogues; inference() adds it to the ids produced by
# the test loader.
train_cnt = len(my_dataset_train)

# Test data is served in its stored order, eight dialogues per batch.
batch_size = 8
data_iter_test = data.DataLoader(
    dataset=my_dataset_test,
    batch_size=batch_size,
    shuffle=False,
)

### Inference-Time Forward Pass and Running the Model

In [None]:
def forward_test(self, chat_ids, speaker_info, sp_dialogues, sp_ind, inputs):
        """Inference-time forward pass, written to be bound as a method.

        Embeds ``inputs`` with ``self.embedding``, runs ``self.rnnD`` and
        ``self.drop1`` over the result, and then walks every position ``s``
        of every batch element ``b`` one at a time, keeping one hidden state
        per speaker (via ``self.rnnS``) and one running hidden state
        ``d_h`` (via ``self.rnnG``).

        Args:
            self: object providing ``embedding``, ``rnnD``, ``drop1``,
                ``rnnS``, ``attn``, ``rnnG`` and ``hidden_size`` (all
                defined outside this file section).
            chat_ids: per-batch-element dialogue ids; ``chat_ids[b]`` is
                used to index ``speaker_info``.
            speaker_info: ``speaker_info[d_id][s]`` gives the speaker of
                position ``s`` in dialogue ``d_id``; it must be hashable
                because it is used as a dict key.
            sp_dialogues: not used in this function.
            sp_ind: not used in this function.
            inputs: tensor of indices fed to ``self.embedding``; its device
                is used for every tensor created here.

        Returns:
            ``(op, spop)``: ``op`` collects the ``self.rnnG`` outputs and
            has the same shape as the ``self.rnnD`` output; ``spop``
            collects the ``self.rnnS`` outputs (plus their input) and is
            twice as wide in the last dimension.
        """
        whole_dialogue_indices = inputs
        
        bert_embs = self.embedding(whole_dialogue_indices)
               
        # h1 (second return value of rnnD) is not used afterwards.
        dialogue, h1 = self.rnnD(bert_embs)
        dialogue = self.drop1(dialogue)

        device = inputs.device
        
        # Zero-initialised buffers shaped after `dialogue`; fop2 is 3x and
        # spop 2x as wide in the last dimension. This initial fop is
        # replaced by op.clone() at the start of every step below.
        fop = torch.zeros((dialogue.size()[0],dialogue.size()[1],dialogue.size()[2])).to(device)
        fop2 = torch.zeros((dialogue.size()[0],dialogue.size()[1],dialogue.size()[2]*3)).to(device)
        op = torch.zeros((dialogue.size()[0],dialogue.size()[1],dialogue.size()[2])).to(device)
        spop = torch.zeros((dialogue.size()[0],dialogue.size()[1],dialogue.size()[2]*2)).to(device)
               
        #################### Modified for testing ############################
        
        # Constant (all 0.5) initial states: h0 seeds each speaker's state,
        # d_h is the running state fed to rnnG, attn_h stands in for the
        # attention output at the first position.
        h0 = (0.5 * torch.ones(1, 1, self.hidden_size*2)).to(device)
        d_h = (0.5 * torch.ones(1, 1, self.hidden_size)).to(device)
        attn_h = (0.5 * torch.ones(1, 1, self.hidden_size)).to(device)
        
        ######################################################################
        
        # NOTE(review): d_h is not re-initialised per batch element, so it
        # carries over from one dialogue of the batch to the next - confirm
        # that this is intended.
        for b in range(dialogue.size()[0]):
            d_id = chat_ids[b]
            # Speaker hidden states are tracked per dialogue.
            speaker_hidden_states = {}
            for s in range(dialogue.size()[1]):
                # Snapshot of the outputs written so far.
                fop = op.clone()
                
                current_utt = dialogue[b][s]
                
                current_speaker = speaker_info[d_id][s]
                
                # First time this speaker appears in the dialogue: start from h0.
                if current_speaker not in speaker_hidden_states:
                    speaker_hidden_states[current_speaker] = h0
                
                h = speaker_hidden_states[current_speaker]
                # Add two leading singleton dimensions.
                current_utt_emb = torch.unsqueeze(torch.unsqueeze(current_utt,0),0)
                
                # Rows 0..s of this dialogue's outputs; row s itself has not
                # been written yet at this point, so it is still zero.
                key = fop[b][:s+1].clone()
                key = torch.unsqueeze(key,0)
                
                if s == 0:
                    # No attention at the first position; attn_h is used instead.
                    tmp = torch.cat([attn_h,current_utt_emb],-1).to(device)
                    spop[b][s], h_new = self.rnnS(tmp,h)
                else:
                    # Attend over the previous outputs with the current
                    # utterance as the query.
                    query = current_utt_emb
                    attn_op,_ = self.attn(key,query)
                    
                    tmp = torch.cat([attn_op,current_utt_emb],-1).to(device)
                    spop[b][s], h_new = self.rnnS(tmp,h)
                
                # Add the rnnS input back onto its output.
                spop[b][s] = spop[b][s].add(tmp)      
                speaker_hidden_states[current_speaker] = h_new
                
                # Concatenate the speaker-level output with the rnnD output
                # and feed it to rnnG, updating the running state d_h.
                fop2[b][s] = torch.cat([spop[b][s],dialogue[b][s]],-1)
                tmp = torch.unsqueeze(torch.unsqueeze(fop2[b][s].clone(),0),0)
                op[b][s],d_h = self.rnnG(tmp,d_h)

        return op,spop

In [None]:
model_path = "../Models/model-task1"
file_name = "answer1.txt"

# NOTE(security): dill.load, like pickle.load, can execute arbitrary code
# contained in the file - only load model files from a trusted source.
with open(model_path, "rb") as dill_file:
    model = dill.load(dill_file)

# Swap in an embedding layer built from the speaker-augmented weight matrix
# and replace the sub-module's forward with the inference-time version.
model.ia.embedding, num_embeddings, embedding_dim = create_emb_layer(weight_matrix, utt2idx)
model.ia.forward = types.MethodType(forward_test, model.ia)
try:
    model = model.to('cpu')
    ans = inference(model)
    ans_to_txt(ans, file_name)

except Exception:
    # Keep the run best-effort, but show what actually went wrong instead
    # of hiding it behind a bare "Error".
    print("Error")
    traceback.print_exc()