## Sub-Task 2 Inference

### Libraries

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
# Make new tensors default to the GPU only when one exists; a hard-coded 'cuda' makes every
# later tensor creation fail on CPU-only machines.
torch.set_default_device('cuda' if torch.cuda.is_available() else 'cpu')

import tqdm
import dill
import math
import pickle

### Loading Pickles

In [None]:
pickle_folder_path = "../Pickles/Task 2/"

# Pickle file stems (without ".pickle"), listed in the exact order load_efr() returns them.
_EFR_PICKLE_STEMS = (
    "idx2utt",
    "utt2idx",
    "idx2emo",
    "emo2idx",
    "idx2speaker",
    "speaker2idx",
    "weight_matrix",
    "train_data_trig",
    "test_data_trig",
    "global_speaker_info_trig",
    "speaker_dialogues_trig",
    "speaker_emotions_trig",
    "speaker_indices_trig",
    "utt_len_trig",
    "global_speaker_info_test_trig",
    "speaker_dialogues_test_trig",
    "speaker_emotions_test_trig",
    "speaker_indices_test_trig",
    "utt_len_test_trig",
)


def _load_pickle(folder, stem):
    """Return the object stored in ``<folder>/<stem>.pickle``."""
    import os

    # NOTE(security): pickle.load can execute arbitrary code; only load trusted files.
    with open(os.path.join(folder, stem + ".pickle"), "rb") as f:
        return pickle.load(f)


def load_efr(folder=None):
    """Load the 19 pickled EFR artefacts (vocabularies, datasets, speaker info).

    Args:
        folder: directory containing the ``*.pickle`` files. Defaults to the
            module-level ``pickle_folder_path`` (read at call time). A trailing
            path separator is optional.

    Returns:
        A 19-tuple, in this order: idx2utt, utt2idx, idx2emo, emo2idx,
        idx2speaker, speaker2idx, weight_matrix, my_dataset_train,
        my_dataset_test, global_speaker_info, speaker_dialogues,
        speaker_emotions, speaker_indices, utt_len, global_speaker_info_test,
        speaker_dialogues_test, speaker_emotions_test, speaker_indices_test,
        utt_len_test.

    Raises:
        FileNotFoundError: if one of the pickle files is missing (files are
            read in the order above, so the first missing one is reported).
    """
    if folder is None:
        folder = pickle_folder_path
    return tuple(_load_pickle(folder, stem) for stem in _EFR_PICKLE_STEMS)

In [None]:
# Unpack every EFR artefact into notebook globals; the order mirrors load_efr()'s return tuple.
(
    idx2utt, utt2idx, idx2emo, emo2idx, idx2speaker,
    speaker2idx, weight_matrix, my_dataset_train, my_dataset_test,
    global_speaker_info, speaker_dialogues, speaker_emotions,
    speaker_indices, utt_len_train, global_speaker_info_test,
    speaker_dialogues_test, speaker_emotions_test, speaker_indices_test,
    utt_len_test,
) = load_efr()

### Layer

In [None]:
def create_emb_layer(weights_matrix, utt2idx):
    """Build a frozen ``nn.Embedding`` initialised from a 2-D weight tensor.

    Args:
        weights_matrix: tensor of shape (num_embeddings, embedding_dim); its
            values are copied into the new layer's weight.
        utt2idx: vocabulary mapping; ``utt2idx["<pad>"]`` becomes padding_idx.

    Returns:
        (layer, num_embeddings, embedding_dim)
    """
    rows, cols = weights_matrix.size()
    layer = nn.Embedding(rows, cols, padding_idx=utt2idx["<pad>"])
    # load_state_dict copies the values into the layer's own parameter.
    layer.load_state_dict({'weight': weights_matrix})
    # Freeze: the pretrained embeddings are not updated.
    layer.weight.requires_grad_(False)
    return layer, rows, cols

### Number of True Sentences

In [None]:
# Speaker index used for padding positions.
PAD_SP_IX = speaker2idx["<pad>"]


def get_num_sen(gsp):
    """Return the number of real (non-padding) entries in a speaker mapping.

    Mappings with more than 5 entries are returned by length without inspecting
    their values; only mappings of length <= 5 have padding entries
    (value == PAD_SP_IX) excluded.
    NOTE(review): the 5-entry threshold is a magic number; its meaning (e.g. a
    padded context window size) is not visible here - confirm.
    """
    total = len(gsp)
    if total <= 5:
        return sum(1 for speaker in gsp.values() if speaker != PAD_SP_IX)
    return total

### Function to get one-hot speaker representation

In [None]:
top_speaker_names = ["maya", "indu", "sahil", "monisha", "rosesh", "madhusudhan"]
NUM_SPK = len(top_speaker_names)


def get_spk_embedding(spk_original_ix):
    """Return a length-NUM_SPK speaker vector for a speaker index.

    The vector is one-hot at the speaker's position in ``top_speaker_names``,
    or all zeros for any other speaker. It is always a floating-point tensor on
    the global ``device``. (Previously the one-hot branch returned int64 while
    the fallback returned float, so callers got different dtypes per speaker.)

    Args:
        spk_original_ix: key into the global ``idx2speaker`` mapping.
    """
    vec = torch.zeros(NUM_SPK, device=device)
    name = idx2speaker[spk_original_ix]
    if name in top_speaker_names:
        vec[top_speaker_names.index(name)] = 1.0
    return vec

### Inference

In [None]:
def inference(model, data_loader=None):
    """Run the EFR model over the test split and return flat per-utterance predictions.

    Args:
        model: trained model, called as
            ``model(inputs, emo_one_hot, dialogue_ids, speaker_ids, utt_len_test)``.
            Its output is indexed ``outputs[b, s]`` and argmax-ed over the last
            dimension (presumably (batch, seq, num_classes) logits - confirm).
        data_loader: iterable of batches whose items 0/1/2 are dialogue ids,
            model inputs and integer emotion labels. Defaults to the global
            ``data_iter_test``.

    Returns:
        list[int]: for each dialogue, ``max(0, alpha - beta)`` zeros followed
        by the predicted class of its first ``beta`` utterances, where
        ``alpha = get_num_sen(global_speaker_info_test[key])`` and
        ``beta = utt_len_test[key]`` (key = the dialogue's first id).

    Raises:
        IndexError: if ``utt_len_test`` claims more utterances than the model
            produced for a dialogue.
    """
    if data_loader is None:
        data_loader = data_iter_test
    num_emotions = 8  # width of the emotion one-hot the model expects

    model.eval()
    ans = []
    with torch.no_grad():
        for sample_batched in tqdm.tqdm(data_loader):
            dialogue_ids = sample_batched[0].tolist()
            inputs = sample_batched[1].to(device)
            emotions = sample_batched[2].to(device)

            # Vectorised one-hot (replaces a per-element Python loop that forced
            # a device sync on every .item()).
            emo_one_hot = F.one_hot(emotions.long(), num_classes=num_emotions).to(
                torch.get_default_dtype()
            )

            # These ids come from the *test* loader, so look speakers up in the
            # test table (the original used the training table global_speaker_info).
            speaker_ids = [
                [global_speaker_info_test[d_id][0] for d_id in d_ids_list]
                for d_ids_list in dialogue_ids
            ]

            outputs = model(inputs, emo_one_hot, dialogue_ids, speaker_ids, utt_len_test)

            for b in range(outputs.size(0)):
                key = dialogue_ids[b][0]
                alpha = get_num_sen(global_speaker_info_test[key])
                beta = utt_len_test[key]
                # Placeholder predictions for utterances beyond the model's input length.
                ans.extend([0] * max(0, alpha - beta))
                if beta > outputs.size(1):
                    raise IndexError(
                        f"utt_len_test[{key}]={beta} exceeds model output length {outputs.size(1)}"
                    )
                # argmax(softmax(x)) == argmax(x); one host transfer per dialogue.
                ans.extend(outputs[b, :beta].argmax(dim=-1).tolist())

    return ans

### Writing the answer file (`answer2.txt`)

In [None]:
# 0-based, inclusive line ranges of each sub-task's section in the answer file.
IX1_BEGIN = 0  # MaSaC - ERC
IX1_END = 1579
IX2_BEGIN = 1580  # MaSaC - EFR (filled from this notebook's predictions)
IX2_END = 9269
IX3_BEGIN = 9270  # MELD - EFR
IX3_END = 17911


def ans_to_txt(ans, file_name):
    """Write the combined answer file, filling only the MaSaC-EFR section from *ans*.

    Lines IX1_BEGIN..IX1_END get the placeholder "neutral". Lines
    IX2_BEGIN..IX2_END get ``str(label) + ".0"`` for the first
    ``IX2_END - IX2_BEGIN + 1`` labels of *ans*. Lines IX3_BEGIN..IX3_END get
    the placeholder "0.0". Extra labels in *ans* are ignored.

    Args:
        ans: sequence of predicted labels (ints, as returned by inference()).
        file_name: output path; overwritten if it exists.

    Raises:
        IndexError: if *ans* has fewer labels than the MaSaC-EFR section needs.
            This is raised before the file is opened, so no truncated/partial
            file is left behind.
    """
    n_efr = IX2_END - IX2_BEGIN + 1

    lines = ["neutral\n"] * (IX1_END - IX1_BEGIN + 1)
    # Index explicitly (not slice) so a too-short `ans` fails here, up front.
    lines += [str(ans[i]) + ".0\n" for i in range(n_efr)]
    lines += ["0.0\n"] * (IX3_END - IX3_BEGIN + 1)

    with open(file_name, "w", encoding="utf-8") as f:
        f.writelines(lines)

### Loading model and testing

In [None]:
# Use the GPU when present, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
batch_size = 128
# Test-split loader; DataLoader does not shuffle by default, so batches follow dataset order.
data_iter_test = data.DataLoader(my_dataset_test, batch_size=batch_size)

In [None]:
import traceback

model_path = "../Models/model-task2"
file_name = "answer2.txt"

# NOTE(security): dill.load can execute arbitrary code from the file; only load trusted models.
with open(model_path, "rb") as dill_file:
    model = dill.load(dill_file)

try:
    # Re-compact the recurrent layer's weights (presumably an nn.GRU/RNN) after unpickling.
    model.emoGRU.RNN.flatten_parameters()
    # Re-attach a frozen embedding layer built from the loaded weight matrix.
    model.encoder, _, _ = create_emb_layer(weight_matrix, utt2idx)
    ans = inference(model)
    ans_to_txt(ans, file_name)
except Exception:
    # Keep the notebook alive, but surface the real failure instead of a bare "Error".
    # (Unlike a bare `except:`, this lets KeyboardInterrupt/SystemExit propagate.)
    print("Error")
    traceback.print_exc()