# Auxiliary Class

## Library

In [None]:
import json
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
# Dataset
from PIL import Image
from torchvision import transforms
from torchvision.io import read_video, read_image
from torch.utils.data import Dataset, DataLoader
# Model
from transformers import AutoModel
from transformers import AutoTokenizer, BertModel, BertTokenizer, BertGenerationDecoder, RobertaTokenizer, RobertaModel

# Training parameter
from torch.optim import Adam
# Training process
from tqdm import tqdm
# Metrics
from sklearn.metrics import classification_report, accuracy_score
from collections import OrderedDict
import copy
import sys
from sklearn.metrics import classification_report
from focal_loss.focal_loss import FocalLoss

## Function

In [None]:
def read_json(file_path):
    """Load and return the JSON document stored at ``file_path``.

    The file is decoded as UTF-8 explicitly so the result does not depend on
    the platform's default locale encoding.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        The deserialised Python object (dict, list, ...).
    """
    with open(file_path, 'r', encoding='utf-8') as json_file:
        return json.load(json_file)
def write_json(path, data):
    """Serialise ``data`` as JSON with a 4-space indent into the file at ``path``."""
    encoder = json.JSONEncoder(indent=4)
    with open(path, 'w') as out_file:
        # Stream the encoded chunks straight into the file.
        for chunk in encoder.iterencode(data):
            out_file.write(chunk)

## Early Stopper

In [None]:
class EarlyStopper:
    """Signal when the validation loss has stopped improving.

    A call to ``early_stop`` counts as "bad" when the loss exceeds the best
    loss seen so far by more than ``min_delta``. After ``patience`` bad calls
    without a new best loss in between, ``early_stop`` returns True.

    Args:
        patience: Number of bad calls tolerated before stopping.
        min_delta: Margin above the best loss that still counts as "not worse".
    """

    def __init__(self, patience=1, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.min_validation_loss = float('inf')

    def reset(self):
        """Restore the initial state so the stopper can be reused for a new run.

        The best loss is cleared together with the counter: losses from a
        previous run (possibly computed with a different loss function) must
        not be used as the baseline for the next one.
        """
        self.counter = 0
        self.min_validation_loss = float('inf')

    def early_stop(self, validation_loss):
        """Record ``validation_loss`` and report whether training should stop.

        Returns:
            True once ``patience`` bad calls have accumulated, else False.
        """
        if validation_loss < self.min_validation_loss:
            # New best loss: remember it and start counting from zero again.
            self.min_validation_loss = validation_loss
            self.counter = 0
        elif validation_loss > (self.min_validation_loss + self.min_delta):
            self.counter += 1
            if self.counter >= self.patience:
                return True
        return False

# Train

## Dataset

In [None]:
class ConversationDataset(Dataset):
    """Dataset of conversation samples tokenised for the emotion/cause model.

    Each element of ``data`` is a dict read with the keys ``conversation_ID``,
    ``paragraph``, ``main_text``, ``utterance_ID``, ``casual_pool`` and
    ``casual_span_pool``; in ``"train"`` mode ``emotion`` and ``span_label``
    are read as well.

    Args:
        config_path: Path of the JSON config (label maps, tokenizer name,
            padding strategy, max length, candidate number).
        data: Sequence of sample dicts.
        option: ``"train"`` returns labels too; any other value returns
            inputs only.
    """

    def __init__(self, config_path, data, option="train"):
        super(ConversationDataset, self).__init__()
        self.build(config_path)
        self.data = data
        self.option = option

    def build(self, config_path):
        """Read the config file and instantiate the matching tokenizer.

        Raises:
            ValueError: if ``tokenizer_name`` is not one of the supported
                names. Failing here avoids a later, less obvious
                AttributeError on ``self.tokenizer``.
        """
        config = self.read_json(config_path)
        self.label2id = config["label2id"]
        self.id2label = config["id2label"]
        self.tokenizer_name = config["tokenizer_name"]
        self.padding = config["padding"]
        self.max_length = config["max_length"]
        self.candidate_num = config["candidate num"]
        if self.tokenizer_name == 'bert-base-uncased':
            self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
        elif self.tokenizer_name == 'roberta-base':
            self.tokenizer = RobertaTokenizer.from_pretrained(self.tokenizer_name)
        else:
            raise ValueError(
                "Unsupported tokenizer_name: " + repr(self.tokenizer_name)
                + " (expected 'bert-base-uncased' or 'roberta-base')"
            )

    def __len__(self):
        return len(self.data)

    def _encode(self, text):
        """Tokenise ``text`` and return its ``input_ids`` without the batch axis.

        ``truncation=True`` makes ``max_length`` an upper bound as well as a
        padding target; without it, texts longer than ``max_length`` keep
        their full length and cannot be stacked into a batch.
        """
        encoded = self.tokenizer(
            text,
            padding=self.padding,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        # Drop only the leading batch axis added by return_tensors="pt".
        return torch.squeeze(encoded['input_ids'], dim=0)

    def __getitem__(self, index):
        """Return the tensors (and, in train mode, the labels) of sample ``index``."""
        sample = self.data[index]
        item = {
            "conversation_ID": sample["conversation_ID"],
            "paragraph": self._encode(sample["paragraph"]),
            "main_text": self._encode(sample["main_text"]),
            "utterance_ID": sample["utterance_ID"],
            "casual_pool": self._encode(sample["casual_pool"]),
            "casual_span_pool": torch.FloatTensor(sample["casual_span_pool"]),
        }

        if self.option == "train":
            # One-hot encode the emotion over all known labels.
            label = self.label2id[sample["emotion"]]
            emotion_label = [0] * len(self.id2label)
            emotion_label[label] = 1
            item["emotion_label"] = torch.FloatTensor(emotion_label)
            item["span_label"] = torch.FloatTensor([sample["span_label"]])
        return item

    def read_json(self, file_path):
        """Load and return the JSON document stored at ``file_path``."""
        with open(file_path, 'r') as json_file:
            data = json.load(json_file)
        return data

## Trainer

In [None]:
class Trainer(nn.Module):
    """Stage-wise trainer for a model returning ``(emotion_pred, span_pred)``.

    A training stage is selected with the ``option`` argument of ``train``:

    * ``"emotion_head"`` - focal loss on the emotion output; the layers listed
      under ``casual_emotion_layers`` and ``main_layers`` are frozen.
    * ``"emotion_main"`` - focal loss on the emotion output; only
      ``casual_emotion_layers`` are frozen.
    * ``"casual_head"`` - MSE loss on the span output; ``emotion_head_layers``
      and ``main_layers`` are frozen.
    * anything else - joint training: nothing is frozen and the loss is the
      focal emotion loss plus the MSE span loss.

    Args:
        config_path: Path of the JSON config file read by ``build``.
    """

    def __init__(self, config_path):
        super(Trainer, self).__init__()
        self.layers_to_freeze = []
        self.history = {
            "train loss": [],
            "valid loss": [],
            "emotion_correct": [],
            "span_correct": []
        }
        self.build(config_path)

    def build(self, config_path):
        """Read hyper-parameters, output directories and layer lists from the config."""
        print(config_path)
        config = read_json(config_path)
        self.patience = config['patience']
        self.min_delta = config['min_delta']
        self.fn_loss_name = config['loss_fn_name']
        self.batch_size_config = config['batch_size']
        self.device = config['device']
        self.save_checkpoint_dir = config['save_checkpoint_dir']
        self.save_history_dir = config['save_history_dir']
        self.emotion_head_layers = config['emotion_head_layers']
        self.casual_emotion_layers = config['casual_emotion_layers']
        self.rnn_layers = config['rnn_layers']
        self.main_layers = config['main_layers']
        self.label2id = config['label2id']
        self.alpha_focal_loss = config['alpha_focal_loss']
        self.gamma_focal_loss = config['gamma_focal_loss']
        # A concrete list (not a dict view) so it can be handed to sklearn as-is.
        self.target_names = list(self.label2id.keys())
        self.early_stopper = EarlyStopper(self.patience, self.min_delta)

    def _make_focal_loss(self):
        """Build the focal loss with its class weights on the training device."""
        weights = torch.tensor(self.alpha_focal_loss).to(self.device)
        return FocalLoss(gamma=self.gamma_focal_loss, weights=weights)

    def emotion_head_training(self):
        """Configure the stage that trains only the emotion head."""
        self.layers_to_freeze = self.casual_emotion_layers + self.main_layers
        self.fn_loss = self._make_focal_loss()

    def emotion_main_training(self):
        """Configure the stage that trains the emotion head and the main layers."""
        self.layers_to_freeze = self.casual_emotion_layers
        # Same device-aware loss as the emotion_head stage; the weights must
        # live on the device the predictions are on.
        self.fn_loss = self._make_focal_loss()

    def casual_emotion_training(self):
        """Configure the stage that trains only the span (cause) head."""
        self.layers_to_freeze = self.emotion_head_layers + self.main_layers
        self.fn_loss = nn.MSELoss()

    def full_training(self):
        """Configure joint training: nothing frozen, one loss per output."""
        self.layers_to_freeze = []
        self.fn_loss = self._make_focal_loss()
        self.span_fn_loss = nn.MSELoss()

    def freeze(self, model):
        """Set ``requires_grad`` on every parameter of ``model``.

        A parameter is frozen when any entry of ``self.layers_to_freeze`` is a
        substring of its name; all other parameters are made trainable.
        """
        for name, param in model.named_parameters():
            frozen = any(layer in name for layer in self.layers_to_freeze)
            param.requires_grad = not frozen
        return model

    def train(self, model, optimizer, train_dataset, valid_dataset, epochs, option="emotion_head", batch_size=None):
        """Run one training stage with early stopping and checkpointing.

        Args:
            model: Module called as ``model(main_text, paragraph, casual_pool)``.
            optimizer: Optimizer already bound to the model parameters.
            train_dataset: Dataset of training samples.
            valid_dataset: Dataset of validation samples.
            epochs: Maximum number of epochs.
            option: Training stage (see the class docstring).
            batch_size: Overrides the configured batch size when given.
        """
        self.optimizer = optimizer
        print("TRAINING " + option + " MODEL")
        if option == "emotion_head":
            self.emotion_head_training()
        elif option == "emotion_main":
            self.emotion_main_training()
        elif option == "casual_head":
            self.casual_emotion_training()
        else:
            # Any other option is treated as joint training so that the
            # freeze list and both losses are always configured.
            self.full_training()

        self.batch_size = self.batch_size_config if batch_size is None else batch_size
        print("device: "+str(self.device))
        print("batch_size: "+str(self.batch_size))
        print("gamma_focal_loss: "+str(self.gamma_focal_loss))
        print("alpha_focal_loss: "+str(self.alpha_focal_loss))

        # Fresh stopper per stage: losses of different stages are not
        # comparable, so the previous best loss must not carry over.
        self.early_stopper = EarlyStopper(self.patience, self.min_delta)
        self.model = self.freeze(model)
        self.model = self.model.to(self.device)

        self.train_dataloader = DataLoader(dataset=train_dataset, batch_size=self.batch_size, shuffle=True, drop_last=False)
        # Validation does not need shuffling; a fixed order keeps it reproducible.
        self.valid_dataloader = DataLoader(dataset=valid_dataset, batch_size=self.batch_size, shuffle=False, drop_last=False)

        best_val_loss = np.inf
        self.epochs = epochs
        label_ids = list(self.label2id.values())
        for epoch in range(epochs):
            train_loss = self._train_epoch(epoch, option)
            val_loss, emotion_correct, span_correct, emotion_label_list, emotion_pred_list = self._val_epoch(epoch, option)

            self.history["train loss"].append(train_loss)
            self.history["valid loss"].append(val_loss)
            self.history["emotion_correct"].append(emotion_correct)
            self.history["span_correct"].append(span_correct)

            # The epoch summary is printed only when a new best checkpoint is saved.
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                self.save_checkpoint(epoch, option)
                print(f"Epoch {epoch+1}: Train loss {train_loss:.4f},  Valid loss {val_loss:.4f},  emotion_correct {emotion_correct*100:.4f}%,  span_correct {span_correct*100:.4f}%")

                # Passing ``labels`` keeps the report valid even when some
                # class does not occur in the validation split.
                print(classification_report(emotion_label_list, emotion_pred_list, labels=label_ids, target_names=self.target_names))
            if self.early_stopper.early_stop(val_loss):
                print("Early stop at epoch: "+str(epoch) + " with valid loss: "+str(val_loss))
                break
        if self.save_history_dir:
            os.makedirs(self.save_history_dir, exist_ok=True)
        write_json(self.save_history_dir + "/" + option + "_history.json", self.history)

    def _compute_loss(self, option, emotion_pred, span_pred, emotion_label, span_label):
        """Return the loss of the current stage for one batch."""
        if option in ("emotion_head", "emotion_main"):
            return self.fn_loss(emotion_pred, emotion_label)
        if option == "casual_head":
            return self.fn_loss(span_pred, span_label)
        emotion_loss = self.fn_loss(emotion_pred, emotion_label)
        span_loss = self.span_fn_loss(span_pred, span_label)
        return emotion_loss + span_loss

    def _forward_batch(self, data, option):
        """Move one batch to the device, run the model and compute the loss.

        Returns:
            ``(loss, emotion_pred, span_pred, emotion_label, span_label)`` where
            ``emotion_label`` holds class indices and ``span_label`` has the
            same shape as ``span_pred``.
        """
        paragraph = data["paragraph"].to(self.device)
        main_text = data["main_text"].to(self.device)
        casual_pool = data["casual_pool"].to(self.device)
        emotion_label = data["emotion_label"].to(self.device)
        span_label = data["span_label"].to(self.device)

        emotion_pred, span_pred = self.model(main_text, paragraph, casual_pool)

        # One-hot labels -> class indices.
        emotion_label = torch.argmax(emotion_label, dim=-1)
        # Align the label with the prediction shape; a (B, 1) label against a
        # (B,) prediction would otherwise broadcast to (B, B) in the loss.
        span_label = span_label.reshape(span_pred.shape)

        loss = self._compute_loss(option, emotion_pred, span_pred, emotion_label, span_label)
        return loss, emotion_pred, span_pred, emotion_label, span_label

    def _train_epoch(self, epoch, option):
        """Train for one epoch and return the mean loss per batch."""
        self.model.train()
        train_loss = 0.0
        logger_message = f'Training epoch {epoch}/{self.epochs}'

        progress_bar = tqdm(self.train_dataloader,
                            desc=logger_message, initial=0, dynamic_ncols=True)
        for data in progress_bar:
            self.optimizer.zero_grad()
            loss = self._forward_batch(data, option)[0]
            train_loss += loss.item()
            loss.backward()
            self.optimizer.step()
        return train_loss / len(self.train_dataloader)

    def _val_epoch(self, epoch, option):
        """Evaluate on the validation set.

        Returns:
            ``(mean loss per batch, emotion accuracy, span accuracy,
            emotion label ids, emotion predicted ids)``. Accuracies are
            divided by the number of samples actually seen, so a partial last
            batch is accounted for.
        """
        self.model.eval()
        valid_loss = 0.0
        emotion_correct = 0
        span_correct = 0
        sample_count = 0
        emotion_pred_list = []
        emotion_label_list = []
        logger_message = f'Validation epoch {epoch}/{self.epochs}'
        progress_bar = tqdm(self.valid_dataloader,
                            desc=logger_message, initial=0, dynamic_ncols=True)
        with torch.no_grad():
            for data in progress_bar:
                loss, emotion_pred, span_pred, emotion_label, span_label = self._forward_batch(data, option)

                emotion_pred_ = torch.argmax(emotion_pred, dim=-1)
                emotion_pred_list.extend(emotion_pred_.cpu().tolist())
                emotion_label_list.extend(emotion_label.cpu().tolist())

                valid_loss += loss.item()
                sample_count += emotion_label.size(0)
                emotion_correct += (emotion_pred_ == emotion_label).sum().item()
                # NOTE(review): span scores are compared after thresholding at
                # 0.5, assuming span_label is a 0/1 target - confirm against
                # the data. Exact float equality never matches.
                span_correct += ((span_pred >= 0.5) == (span_label >= 0.5)).sum().item()

        return valid_loss / len(self.valid_dataloader), emotion_correct / sample_count, span_correct / sample_count, emotion_label_list, emotion_pred_list

    def save_checkpoint(self, epoch, option):
        """Save model and optimizer state to ``<save_checkpoint_dir>/<option>/checkpoint_<epoch>.pth``."""
        checkpoint_path = os.path.join(self.save_checkpoint_dir, option + '/checkpoint_{}.pth'.format(epoch))
        # The per-stage sub-directory may not exist yet.
        os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'epoch': epoch
        }, checkpoint_path)

## Predictor

In [None]:
class Prediction(nn.Module):
    """Run a trained model over a dataset and write a submission JSON file.

    Args:
        config_path: Path of the JSON config providing ``id2label``,
            ``label2id``, ``submit path``, ``device``, ``batch_size`` and
            ``top_k``.
    """

    def __init__(self, config_path):
        super(Prediction, self).__init__()
        self.build(config_path)

    def build(self, config_path):
        """Read the prediction settings from the config file."""
        config = self.read_json(config_path)
        self.id2label = config["id2label"]
        self.label2id = config["label2id"]
        self.submit_path = config["submit path"]
        self.device = config["device"]
        self.batch_size = config["batch_size"]
        self.top_k = config["top_k"]

    def predict(self, model, raw_dataset, dataset):
        """Predict emotion-cause pairs and save them to ``self.submit_path``.

        Args:
            model: Module called as ``model(main_text, paragraph, casual_pool)``
                and returning ``(emotion_pred, span_pred)``.
            raw_dataset: List of conversation dicts with ``conversation_ID``
                and ``conversation`` keys; used as the submission skeleton.
            dataset: Dataset yielding the model inputs.
        """
        self.model = model.to(self.device)
        self.model.eval()
        pred_list = []

        self.test_dataloader = DataLoader(dataset=dataset, batch_size=self.batch_size, shuffle=False)
        logger_message = 'Predict '
        progress_bar = tqdm(self.test_dataloader,
                            desc=logger_message, initial=0, dynamic_ncols=True)
        with torch.no_grad():
            for data in progress_bar:
                conversation_id = data["conversation_ID"]
                paragraph = data["paragraph"].to(self.device)
                main_text = data["main_text"].to(self.device)
                utter_id = data["utterance_ID"]
                casual_pool = data["casual_pool"].to(self.device)
                casual_span_pool = data["casual_span_pool"]
                emotion_pred, span_pred = self.model(main_text, paragraph, casual_pool)

                pred = self.convertpred2list(conversation_id, utter_id, emotion_pred, span_pred, casual_span_pool)
                pred_list.extend(pred)

        submit = self.convertpred2dict(raw_dataset, pred_list)
        submit = self.get_top(submit, self.top_k)
        submit = self.convertdict2submit(submit)
        self.save_json(self.submit_path, submit)

    def convertpred2list(self, conversation_id, utter_id, emotion_pred, span_pred, casual_span_pool):
        """Turn one batch of model outputs into a list of prediction dicts.

        A sample is skipped when its predicted emotion is ``"neutral"``, when
        the first entry of its ``casual_span_pool`` row is 0, or when the
        second and third entries are equal. Skipping affects only that
        sample; the remaining samples of the batch are still processed.

        Returns:
            List of dicts with ``conversation_ID``, ``emotion-cause_pairs``
            (``["<utt>_<emotion>", "<a>_<b>_<c>"]``) and ``logits`` (the span
            score).
        """
        pred_list = []
        for i in range(len(emotion_pred)):
            emotion = self.id2label[str(torch.argmax(emotion_pred[i]).item())]
            if emotion == "neutral":
                continue
            first, second, third = (casual_span_pool[i][k].item() for k in range(3))
            # ``continue`` (not ``break``): one invalid candidate must not
            # drop the predictions for the rest of the batch.
            if first == 0 or second == third:
                continue
            emotion_string = str(utter_id[i].item()) + "_" + emotion
            span_string = str(int(first)) + "_" + str(int(second)) + "_" + str(int(third))
            pred_list.append({
                "conversation_ID": conversation_id[i].item(),
                "emotion-cause_pairs": [emotion_string, span_string],
                "logits": span_pred[i].item(),
            })
        return pred_list

    def convertpred2dict(self, raw_dataset, pred_list):
        """Group predictions by conversation on top of the raw dataset skeleton.

        Raises:
            KeyError: if a prediction refers to a conversation id that is not
                present in ``raw_dataset``.
        """
        dataset = self.convertlist2dict(raw_dataset)
        for pred in pred_list:
            entry = dataset[pred["conversation_ID"]]
            entry["emotion-cause_pairs"].append(pred["emotion-cause_pairs"])
            entry["logits"].append(pred["logits"])
        return dataset

    def get_top(self, raw_dataset, top_k):
        """Keep, per conversation, the ``top_k`` pairs with the highest logits.

        The ``emotion-cause_pairs`` lists are replaced in place (ordered by
        descending score); the same dict is returned.
        """
        for entry in raw_dataset.values():
            logits_list = entry['logits']
            cause_pairs_list = entry['emotion-cause_pairs']
            indices_of_topk = sorted(range(len(logits_list)), key=lambda i: logits_list[i], reverse=True)[:top_k]
            entry['emotion-cause_pairs'] = [cause_pairs_list[index] for index in indices_of_topk]
        return raw_dataset

    def convertlist2dict(self, raw_dataset):
        """Index the raw conversations by ``conversation_ID`` with empty result lists."""
        dataset = {}
        for conversation in raw_dataset:
            conversation_ID = conversation["conversation_ID"]
            dataset[conversation_ID] = {
                "conversation_ID": conversation_ID,
                "conversation": conversation['conversation'],
                "emotion-cause_pairs": [],
                "logits": []
            }
        return dataset

    def convertdict2submit(self, raw_dataset):
        """Flatten the per-conversation dict into the submission list (logits dropped)."""
        return [
            {
                "conversation_ID": sample['conversation_ID'],
                "conversation": sample['conversation'],
                "emotion-cause_pairs": sample['emotion-cause_pairs']
            }
            for sample in raw_dataset.values()
        ]

    def read_json(self, file_path):
        """Load and return the JSON document stored at ``file_path``."""
        with open(file_path, 'r') as json_file:
            data = json.load(json_file)
        return data

    def save_json(self, file_path, data):
        """Write ``data`` as JSON to ``file_path``."""
        with open(file_path, 'w') as json_file:
            json.dump(data, json_file)

# Model

In [None]:
class TGG(nn.Module):
    """Emotion classifier and cause scorer on top of a pretrained encoder.

    ``forward`` returns a softmax distribution over the emotion labels for the
    paragraph and a sigmoid score obtained from the dot product between the
    pooled paragraph representation and the transformed pooled candidate.
    """

    def __init__(self, config_path):
        super().__init__()
        self.build(config_path)
        dim = self.embedding_dim
        hidden = self.hidden_size
        self.emotion_fc = nn.Linear(dim, self.label_num)
        self.casual_feedforward = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, dim),
            nn.ReLU(),
            nn.Dropout(0.2),
        )
        # Not used by forward(); kept so the parameter set stays the same.
        self.cat_layer = nn.Linear(2 * dim, dim)
        self.softmax = nn.Softmax(dim=-1)

    def build(self, config_path):
        """Read the model settings and attach the pretrained backbone parts.

        For an unrecognised ``main_part_name`` no backbone is attached.
        """
        config = self.read_json(config_path)
        self.hidden_size = config["hidden_size"]
        self.embedding_dim = config["embedding_dim"]
        self.token_size = config["token_size"]
        self.main_part_name = config["main_part_name"]
        self.label2id = config["label2id"]
        self.label_num = len(self.label2id)
        self.device = config["device"]

        if self.main_part_name == 'bert-base-uncased':
            backbone_cls = AutoModel
        elif self.main_part_name == 'roberta-base':
            backbone_cls = RobertaModel
        else:
            return

        main_part = backbone_cls.from_pretrained(self.main_part_name)
        self.embeddings = main_part.embeddings
        self.encoder = main_part.encoder
        self.pooler = main_part.pooler

    def _encode(self, token_ids):
        """Run embeddings -> encoder -> pooler on ``token_ids``."""
        hidden_states = self.encoder(self.embeddings(token_ids))['last_hidden_state']
        return self.pooler(hidden_states)

    def forward(self, main_text, paragraph, candidate):
        """Return ``(emotion probabilities, candidate score)``.

        ``main_text`` is accepted for interface compatibility but is not used.
        """
        pooled_paragraph = self._encode(paragraph)
        emotion_logits = self.emotion_fc(pooled_paragraph)

        pooled_candidate = self.casual_feedforward(self._encode(candidate))

        # Row-wise dot product between paragraph and candidate vectors.
        similarity = torch.einsum('ij,ij->i', pooled_paragraph, pooled_candidate)
        return self.softmax(emotion_logits), torch.sigmoid(similarity)

    def read_json(self, file_path):
        """Load and return the JSON document stored at ``file_path``."""
        with open(file_path, 'r') as handle:
            return json.load(handle)

In [None]:
# class TGG(nn.Module):
#     def __init__(self, config_path):
#         super(TGG, self).__init__()
#         self.build(config_path)
#         self.emotion_fc = nn.Linear(self.embedding_dim, self.label_num)
#         self.casual_feedforward = nn.Sequential(
#             nn.Linear(self.embedding_dim, self.hidden_size),
#             nn.ReLU(),
#             nn.Linear(self.hidden_size, self.embedding_dim),
#             nn.ReLU(),
#             nn.Dropout(0.2)
#         )
#         self.cat_layer = nn.Linear(3*self.embedding_dim, self.embedding_dim)
#         self.rnn = nn.LSTM(input_size=self.embedding_dim, hidden_size =self.embedding_dim, num_layers = 2, batch_first=True, bidirectional=True)
#         self.softmax = nn.Softmax(dim=-1)
#     def build(self, config_path):
#         config = self.read_json(config_path)
#         self.hidden_size = config["hidden_size"]
#         self.embedding_dim = config["embedding_dim"]
#         self.token_size = config["token_size"]
#         self.main_part_name = config["main_part_name"]
#         self.label2id = config["label2id"]
#         self.label_num = len(self.label2id)
#         self.device = config["device"]
#         if self.main_part_name == 'bert-base-uncased':
        
#             main_part = AutoModel.from_pretrained(self.main_part_name)
#             self.embeddings = main_part.embeddings
#             self.encoder = main_part.encoder
#             self.pooler = main_part.pooler
#         elif self.main_part_name == 'roberta-base':    
#             main_part = RobertaModel.from_pretrained(self.main_part_name)
#             self.embeddings = main_part.embeddings
#             self.encoder = main_part.encoder
#             self.pooler = main_part.pooler
        
#     def forward(self, main_text, paragraph, candidate):
#         batch_size = paragraph.size(0)
#         # print("x: "+str(x.shape))
#         # print("candidate: "+str(candidate.shape))
#         _paragraph = self.embeddings(paragraph)
#         _paragraph = self.encoder(_paragraph)['last_hidden_state']
#         _paragraph = self.pooler(_paragraph)

#         h0 = torch.randn(4, batch_size, self.embedding_dim)
#         c0 = torch.randn(4, batch_size, self.embedding_dim)
#         print("paragraph: "+str(paragraph.shape))
#         output, (hn, cn) = self.rnn(paragraph, (h0, c0))
#         print("output: "+str(output.shape))

#         _x = torch.cat([_paragraph, output], dim=-1)
#         _x = self.cat_layer(_x)
#         # print("_x: "+str(_x.shape))
#         _e_category = self.emotion_fc(_x)
#         # _e_category = self.emotion_fc(_paragraph)

#         # print("_e_category: "+str(_e_category.shape))
        
#         _candidate = self.embeddings(candidate)
#         _candidate = self.encoder(_candidate)['last_hidden_state']
#         _candidate = self.pooler(_candidate)
#         # print("_candidate: "+str(_candidate.shape))


#         _candidate = self.casual_feedforward(_candidate)
#         # print("_candidate: "+str(_candidate.shape))

#         _similarity = torch.einsum('ij,ij->i', _paragraph, _candidate)
        
#         # print("_similarity: "+str(_similarity.shape))
#         return self.softmax(_e_category), torch.sigmoid(_similarity)
#     def read_json(self, file_path):
#         with open(file_path, 'r') as json_file:
#             data = json.load(json_file)
#         return data

# Test

In [None]:
# Build the model and the trainer from the same JSON config file.
config_path = "config.json"
model = TGG(config_path)
# The optimizer is created once, over all model parameters, and is reused by
# every training stage below.
optimizer = Adam(model.parameters(), lr=0.001)
trainer = Trainer(config_path)

In [None]:
# Print each parameter name as a double-quoted string followed by a comma.
for name, param in model.named_parameters():
    print(f'"{name}",')

## Load dataset

In [None]:
# Stage 1: train with option "emotion_head".
dataset = read_json("data/ECF 2.0/train/original_emotion_train.json")
# 80/20 train/validation split in file order (no shuffling before the split).
n_train = int(0.8 * len(dataset))
train_data, valid_data = dataset[:n_train], dataset[n_train:]
train_dataset = ConversationDataset(config_path, train_data)
valid_dataset = ConversationDataset(config_path, valid_data)
# Up to 100 epochs with batch size 64 (overrides the configured batch size).
trainer.train(model, optimizer, train_dataset=train_dataset, valid_dataset=valid_dataset, epochs=100, option="emotion_head", batch_size=64)

In [None]:
# Stage 2: train with option "casual_head", reusing the same model, optimizer
# and trainer objects as the previous stage.
dataset = read_json("data/ECF 2.0/train/train.json")
# 80/20 train/validation split in file order (no shuffling before the split).
n_train = int(0.8 * len(dataset))
train_data, valid_data = dataset[:n_train], dataset[n_train:]
train_dataset = ConversationDataset(config_path, train_data)
valid_dataset = ConversationDataset(config_path, valid_data)
# Up to 100 epochs with batch size 64 (overrides the configured batch size).
trainer.train(model, optimizer, train_dataset=train_dataset, valid_dataset=valid_dataset, epochs=100, option="casual_head", batch_size=64)

## Predict

In [None]:
# trial_data is wrapped in a ConversationDataset below; raw_trial_data is
# passed to Prediction.predict as the skeleton of the submission.
trial_data = read_json("data/ECF 2.0/trial/trial.json")
raw_trial_data = read_json("data/ECF 2.0/trial/Subtask_1_trial.json")

In [None]:
# Any option other than "train" makes the dataset return inputs without labels.
trial_dataset = ConversationDataset(config_path, trial_data, option="trial")
predict = Prediction(config_path)
# Runs the model over the trial set and writes the submission JSON to the
# path configured under "submit path".
predict.predict(model, raw_trial_data, trial_dataset)

In [None]:
# dataset = read_json("data/ECF 2.0/train/train.json")
# n_train = int(0.8 * len(dataset))
# train_data, valid_data = dataset[:n_train], dataset[n_train:]
# train_dataset = ConversationDataset(config_path, train_data)
# valid_dataset = ConversationDataset(config_path, valid_data)
# trainer = Trainer(config_path)
# trainer.train(model, optimizer, train_dataset=train_dataset, valid_dataset=valid_dataset, epochs=100, option="casual_emotion", batch_size=32)

In [None]:
# trial_data = read_json("data/ECF 2.0/trial/trial.json")
# raw_trial_data = read_json("data/ECF 2.0/trial/Subtask_1_trial.json")

# trial_dataset = ConversationDataset(config_path, trial_data, option="trial")
# predict = Prediction(config_path)
# predict.predict(model, raw_trial_data, trial_dataset)