# **LIBRARY**

In [1]:
import json
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
# Dataset
from PIL import Image
from torchvision import transforms
from torchvision.io import read_video, read_image
from torch.utils.data import Dataset, DataLoader
# Model
from transformers import AutoModel
from transformers import AutoTokenizer, BertModel, BertTokenizer, BertGenerationDecoder

# Training parameter
from torch.optim import Adam
# Training process
from tqdm import tqdm
# Metrics
from sklearn.metrics import classification_report, accuracy_score
from collections import OrderedDict
import copy
import sys

  from .autonotebook import tqdm as notebook_tqdm


# Auxiliary classes

In [2]:
def read_json(file_path):
    """Load and return the JSON document stored at ``file_path``."""
    with open(file_path, 'r') as handle:
        return json.load(handle)

## Early Stopper

In [3]:
class EarlyStopper:
    """Signal when a monitored validation loss has stopped improving.

    A loss strictly below the best seen so far becomes the new best and
    clears the strike counter. A loss worse than the best by more than
    ``min_delta`` adds one strike; anything in between leaves the counter
    untouched. ``early_stop`` returns True once the strikes reach
    ``patience``.
    """

    def __init__(self, patience=1, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.min_validation_loss = float('inf')

    def reset(self):
        """Clear the strike counter; the best loss recorded so far is kept."""
        self.counter = 0

    def early_stop(self, validation_loss):
        """Record ``validation_loss`` and report whether training should stop."""
        best = self.min_validation_loss
        if validation_loss < best:
            self.min_validation_loss = validation_loss
            self.counter = 0
            return False
        if validation_loss > best + self.min_delta:
            self.counter += 1
            return self.counter >= self.patience
        return False

## Dataset

In [4]:
class ConversationDataset(Dataset):
    """Tokenised view over a list of conversation records.

    Each record is a dict read by ``__getitem__`` with keys ``conversation_ID``,
    ``paragraph``, ``utterance_ID``, ``casual_pool`` (list of candidate texts),
    ``casual_span_pool`` (one 3-element list per candidate) and, when
    ``option == "train"``, ``emotion`` and ``span_label``. The candidate pool is
    padded with empty-string candidates up to ``candidate_num`` so that items
    can be stacked into a batch.
    """

    def __init__(self, config_path, data, option="train"):
        super(ConversationDataset, self).__init__()
        self.build(config_path)
        self.data = data
        self.option = option

    def build(self, config_path):
        """Read label maps and tokenizer settings from the JSON config at ``config_path``."""
        config = self.read_json(config_path)
        self.label2id = config["label2id"]
        self.id2label = config["id2label"]
        self.tokenizer_name = config["tokenizer_name"]
        self.padding = config["padding"]
        self.max_length = config["max_length"]
        # Other classes in this notebook read "candidate_num"; accept either spelling.
        if "candidate num" in config:
            self.candidate_num = config["candidate num"]
        else:
            self.candidate_num = config["candidate_num"]
        self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)

    def __len__(self):
        return len(self.data)

    def _encode(self, text, **kwargs):
        """Return the tokenizer's ``input_ids`` for ``text``, truncated to ``max_length``."""
        # truncation=True: without it, a text longer than max_length produces a longer
        # id list, and the candidate pool can no longer be stacked into one LongTensor.
        return self.tokenizer(text, padding=self.padding, max_length=self.max_length,
                              truncation=True, **kwargs)['input_ids']

    def __getitem__(self, index):
        """Return the tensors for record ``index``; label fields only when option == "train"."""
        record = self.data[index]
        is_train = self.option == "train"

        paragraph = torch.squeeze(self._encode(record["paragraph"], return_tensors="pt"))
        casual_pool = [self._encode(text) for text in record["casual_pool"]]
        # Copy so padding below never mutates the caller's data.
        casual_span_pool = record["casual_span_pool"].copy()
        if is_train:
            span_label = record["span_label"].copy()
            emotion_label = self.label2id[record["emotion"]]

        missing = self.candidate_num - len(casual_pool)
        if missing > 0:
            empty_ids = self._encode("")
            for _ in range(missing):
                casual_pool.append(list(empty_ids))
                casual_span_pool.append([0, 0, 0])
                if is_train:
                    span_label.append(0)

        sample = {
            "conversation_ID": record["conversation_ID"],
            "paragraph": paragraph,
            "utterance_ID": record["utterance_ID"],
            "casual_pool": torch.LongTensor(casual_pool),
            "casual_span_pool": torch.FloatTensor(casual_span_pool),
        }
        if is_train:
            sample["emotion_label"] = emotion_label
            sample["span_label"] = torch.FloatTensor(span_label)
        return sample

    def read_json(self, file_path):
        with open(file_path, 'r') as json_file:
            data = json.load(json_file)
        return data

## Evaluation

In [5]:
class Evaluation(nn.Module):
    """Emotion-cause span metrics (strict and proportional) averaged over a batch.

    Gold/predicted spans are stored in dicts keyed
    ``'dia{conv}_emoutt{emotion_utt}_causeutt{cause_utt}'`` whose values are
    lists of ``[span_start, span_end, emotion_id]``. Emotion ids index 7-row
    matrices in which row/column 0 doubles as the "no match" bucket, so ids must
    lie in ``0..6`` and class 0 is excluded from the averaged scores.
    """
    def __init__(self, config_path):
        super(Evaluation, self).__init__()
        self.build(config_path)

    def build(self, config_path):
        config = self.read_json(config_path)
        self.label2id = config["label2id"]
        self.id2label = config["id2label"]
        self.batch_size = config["batch_size"]
        self.candidate_num = config["candidate_num"]

    @staticmethod
    def _pair_key(conv_id, emo_utt_id, cau_utt_id):
        """Dict key identifying one (conversation, emotion utterance, cause utterance) pair."""
        return 'dia{}_emoutt{}_causeutt{}'.format(conv_id, emo_utt_id, cau_utt_id)

    def forward(self, conversation_id, utter_id, emotion_pred, span_pred, emotion_label, span_label):
        """Score one batch.

        Args:
            conversation_id, utter_id: per-sample ids indexable by sample (e.g. shape (batch,)).
            emotion_pred: emotion logits (batch, num_labels); argmax over dim 1 gives the id.
            span_pred, span_label: per sample, the first three entries are read as
                ``[cause_utt_id, span_start, span_end]``.
                NOTE(review): the TGG models output (batch, candidate_num) scores, so a
                conversion step is presumably needed before calling this — TODO confirm.
            emotion_label: gold emotion ids (batch,).

        Returns:
            dict of batch-mean scores: emotion accuracy, strict and proportional P/R/F1.
        """
        emotion_ids = torch.argmax(emotion_pred, dim=1)
        score_emo, score_list, score_list_1 = [], [], []
        # Iterate over the actual batch: the last DataLoader batch can be smaller than
        # the configured batch_size.
        for i in range(emotion_ids.size(0)):
            conv_id = int(conversation_id[i])
            emo_utt_id = int(utter_id[i])
            pred_emo = int(emotion_ids[i])
            true_emo = int(emotion_label[i])
            cau_utt_id, span_start_id, span_end_id = (int(v) for v in span_pred[i][:3])
            cau_utt_id_label, span_start_id_label, span_end_id_label = (int(v) for v in span_label[i][:3])

            # [conv_id, emo_utt_id, cau_utt_id, span_start_id, span_end_id, emotion_id]
            true_pair = [conv_id, emo_utt_id, cau_utt_id_label, span_start_id_label, span_end_id_label, true_emo]
            pred_pair = [conv_id, emo_utt_id, cau_utt_id, span_start_id, span_end_id, pred_emo]
            true_spans = {self._pair_key(*true_pair[:3]): [true_pair[3:]]}
            pred_spans = {self._pair_key(*pred_pair[:3]): [pred_pair[3:]]}

            score_emo.append(self.cal_emotion(true_pair, pred_pair))
            # judge_cause_span_pair_emocate removes matched gold spans, so the strict metric
            # gets its own copy and the proportional metric still sees every gold span.
            score_list.append(self.cal_prf_span_pair_emocate(copy.deepcopy(true_spans), [pred_pair]))
            score_list_1.append(self.cal_prf_span_pair_emocate_proportional(true_spans, pred_spans))

        score_emo = np.array(score_emo)
        score_list = np.array(score_list)
        score_list_1 = np.array(score_list_1)
        return {
            "emotion_acc": score_emo.mean(),
            "weighted_strict_precision": score_list[:, 0].mean(),
            "weighted_strict_recall": score_list[:, 1].mean(),
            "weighted_strict_f1": score_list[:, 2].mean(),
            "strict_precision": score_list[:, 3].mean(),
            "strict_recall": score_list[:, 4].mean(),
            "strict_f1": score_list[:, 5].mean(),
            "weighted_Proportional_precision": score_list_1[:, 0].mean(),
            "weighted_Proportional_recall": score_list_1[:, 1].mean(),
            "weighted_Proportional_f1": score_list_1[:, 2].mean(),
            "Proportional_precision": score_list_1[:, 3].mean(),
            "Proportional_recall": score_list_1[:, 4].mean(),
            "Proportional_f1": score_list_1[:, 5].mean()
        }

    '''
    Strict Match: emotion_utt and cause_utt are the same, and the cause spans completely match.
    Fuzzy Match: emotion_utt and cause_utt are the same, and the cause spans overlap
    '''
    def judge_cause_span_pair_emocate(self, pred_span_pair, true_spans_pos_dict, span_mode='fuzzy'): # strict/fuzzy
        """Return True if ``pred_span_pair`` matches a gold span; the matched gold span is removed."""
        d_id, emo_id, cau_id, start_cur, end_cur, emo = pred_span_pair
        cur_key = self._pair_key(d_id, emo_id, cau_id)
        if cur_key in true_spans_pos_dict:
            if span_mode == 'strict':
                if [start_cur, end_cur, emo] in true_spans_pos_dict[cur_key]:
                    true_spans_pos_dict[cur_key].remove([start_cur, end_cur, emo])
                    return True
            else:
                for t_start, t_end, emo_y in true_spans_pos_dict[cur_key]:
                    if emo == emo_y and not(end_cur<=t_start or start_cur>=t_end):
                        # Safe despite iterating the same list: we return right after removing.
                        true_spans_pos_dict[cur_key].remove([t_start, t_end, emo_y])
                        return True
        return False

    def cal_prf_span_pair_emocate(self, span_pair_dict, pred_pairs, span_mode='strict'):
        """Span-pair P/R/F1 from a 7x7 confusion matrix (rows: gold, cols: predicted).

        ``span_pair_dict`` is consumed (matched gold spans are removed).
        Returns ``[w_avg_p, w_avg_r, w_avg_f1, micro_p, micro_r, micro_f1]``.
        """
        conf_mat = np.zeros([7,7])
        for p in pred_pairs: # [conv_id, emo_utt_id, cau_utt_id, span_start_id, span_end_id, emotion_category]
            if self.judge_cause_span_pair_emocate(p, span_pair_dict, span_mode=span_mode):
                conf_mat[p[5]][p[5]] += 1
            else:
                conf_mat[0][p[5]] += 1  # false positive
        for k, v in span_pair_dict.items():
            for p in v:
                conf_mat[p[2]][0] += 1  # unmatched gold span: false negative
        p = np.diagonal(conf_mat / np.reshape(np.sum(conf_mat, axis = 0)+(1e-8), [1,7]))
        r = np.diagonal(conf_mat / np.reshape(np.sum(conf_mat, axis = 1)+(1e-8), [7,1]))
        f = 2*p*r/(p+r+(1e-8))
        weight0 = np.sum(conf_mat, axis = 1)
        gold_total = np.sum(weight0[1:])
        # Guard: with no gold spans outside class 0 the weights would be 0/0 = NaN.
        weight = weight0[1:] / gold_total if gold_total > 0 else np.zeros_like(weight0[1:])
        w_avg_p = np.sum(p[1:] * weight)
        w_avg_r = np.sum(r[1:] * weight)
        w_avg_f1 = np.sum(f[1:] * weight)

        micro_acc = np.sum(np.diagonal(conf_mat)[1:])
        micro_p = micro_acc / (sum(np.sum(conf_mat, axis = 0)[1:])+(1e-8))
        micro_r = micro_acc / (sum(np.sum(conf_mat, axis = 1)[1:])+(1e-8))
        micro_f1 = 2*micro_p*micro_r/(micro_p+micro_r+1e-8)

        return [w_avg_p, w_avg_r, w_avg_f1, micro_p, micro_r, micro_f1]

    def get_match_scores(self, pred_span, true_spans):
        """Pick the gold span best covered by ``pred_span`` (same emotion, overlapping).

        Ties in coverage ratio are broken by the longer overlap.
        Returns ``(match_id, match_gold_length, match_length, match_score)``; all zero if none overlap.
        """
        match_id, match_gold_length, match_length, match_score = 0, 0, 0, 0
        p_start, p_end, p_emo = pred_span
        for ii, (t_start, t_end, t_emo) in enumerate(true_spans):
            if p_emo == t_emo and not (p_end<=t_start or p_start>=t_end):
                cur_match_length = min(p_end, t_end) - max(p_start, t_start)
                cur_gold_length = t_end - t_start
                cur_match_score = cur_match_length / float(cur_gold_length)
                if cur_match_score > match_score:
                    match_id = ii
                    match_gold_length = cur_gold_length
                    match_length = cur_match_length
                    match_score = cur_match_score
                if (cur_match_score == match_score) and (cur_match_score > 0):
                    if cur_match_length > match_length:
                        match_id = ii
                        match_gold_length = cur_gold_length
                        match_length = cur_match_length
                        match_score = cur_match_score
        return match_id, match_gold_length, match_length, match_score

    '''
    Proportional Match (span): Each predicted span is compared with all golden spans, and determine which golden span it matches based on the overlap ratio (match score). Then the Precision, Recall, and F1 are calculated based on the overlapping tokens.
    '''
    def cal_prf_span_pair_emocate_proportional(self, true_span_pair_dict, pred_span_pair_dict): # 'dia{}_emoutt{}_causeutt{}': [[span_start_id, span_end_id, emotion_category], ...]
        """Token-overlap P/R/F1; ``true_span_pair_dict`` is not modified (a deep copy is consumed).

        Returns ``[w_avg_p, w_avg_r, w_avg_f1, micro_p, micro_r, micro_f1]``.
        """
        prf_mat = np.zeros([7,5]) # row: emotion category; col: correct_num, true_num, pred_num, matched_true_span_num, true_span_num
        true_span_pair_dict_copy = copy.deepcopy(true_span_pair_dict)
        for k, v in pred_span_pair_dict.items():
            for pred_span in v:
                prf_mat[pred_span[2]][2] += pred_span[1] - pred_span[0]
                if k in true_span_pair_dict:
                    true_spans = true_span_pair_dict[k]
                    match_id, match_gold_length, match_length, match_score = self.get_match_scores(pred_span, true_spans)
                    if match_length > 0:
                        prf_mat[pred_span[2]][0] += match_length
                        prf_mat[pred_span[2]][1] += match_gold_length # Multiple predicted spans may match the same golden span.
                        prf_mat[pred_span[2]][3] += 1
                        if true_spans[match_id] in true_span_pair_dict_copy[k]:
                            true_span_pair_dict_copy[k].remove(true_spans[match_id])

        for k, v in true_span_pair_dict_copy.items():
            for true_span in v:
                prf_mat[true_span[2]][1] += true_span[1] - true_span[0]
                prf_mat[true_span[2]][3] += 1
        for k, v in true_span_pair_dict.items():
            for true_span in v:
                prf_mat[true_span[2]][4] += 1

        p_scores = prf_mat[1:,0] / (prf_mat[1:,2]+(1e-8))
        r_scores = prf_mat[1:,0] / (prf_mat[1:,1]+(1e-8))
        f1_scores = 2*p_scores*r_scores/(p_scores+r_scores+(1e-8))
        # Weight by the actual number of golden spans; guard against 0/0 when there are none.
        gold_counts = prf_mat[1:,4]
        gold_total = sum(gold_counts)
        weight = gold_counts / gold_total if gold_total > 0 else np.zeros_like(gold_counts)
        w_avg_p = sum(p_scores*weight)
        w_avg_r = sum(r_scores*weight)
        w_avg_f1 = sum(f1_scores*weight)

        total_correct = sum(prf_mat[1:,0])
        micro_p = total_correct / (sum(prf_mat[1:,2])+(1e-8))
        micro_r = total_correct / (sum(prf_mat[1:,1])+(1e-8))
        micro_f1 = 2*micro_p*micro_r/(micro_p+micro_r+1e-8)

        return [w_avg_p, w_avg_r, w_avg_f1, micro_p, micro_r, micro_f1]

    def cal_emotion(self, true_span_pair_dict, pred_span_pair_dict):
        """True when two six-field pair records share the same emotion id (index 5)."""
        # Index 0 is the conversation id, which is identical for gold and prediction
        # by construction; the emotion category lives at index 5.
        return true_span_pair_dict[5] == pred_span_pair_dict[5]

    def read_json(self, file_path):
        with open(file_path, 'r') as json_file:
            data = json.load(json_file)
        return data

# **Model**

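## Two-encoder variant

Note: the next cell redefines `EmotionHead` and `TGG`, so after Run All the one-encoder definitions shadow the ones in this cell.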

In [6]:
class CasualSpanHead(nn.Module):
    """Span-scoring head for the two-encoder variant.

    Owns a second pretrained encoder/pooler (config "casual_span_encoder_name")
    and two linear projections (token_size -> 1 and embedding_dim -> 1), and
    scores candidate embeddings against the paragraph representation ``x``.

    NOTE(review): as written, forward() does not look runnable end to end; see
    the notes inside it before relying on this class.
    """
    def __init__(self, config_path):
        super(CasualSpanHead, self).__init__()
        self.build(config_path)
        
        self.hidden_projection = nn.Linear(self.token_size, 1)
        self.embedding_projection = nn.Linear(self.embedding_dim, 1)
        
    def build(self, config_path):
        # Read sizes from the JSON config and load the second pretrained model;
        # only its encoder and pooler submodules are kept.
        config = self.read_json(config_path)
        self.candidate_num = config["candidate_num"]
        self.token_size = config["token_size"]
        self.hidden_size = config["hidden_size"]
        self.embedding_dim = config["embedding_dim"]
        self.feedforward_size = config["feedforward_size"]
        # Fixed batch size from config; forward() expands/reshapes with it, so a
        # smaller final batch would not fit.
        self.batch_size = config["batch_size"]
        self.casual_span_encoder_name = config["casual_span_encoder_name"]
        main_part = AutoModel.from_pretrained(self.casual_span_encoder_name)
        self.casual_span_encoder = main_part.encoder
        self.casual_span_pooler = main_part.pooler
    def forward(self, x, candidate):
        # The expand below requires x to be (batch, token_size, embedding_dim) with
        # batch == config "batch_size" — TODO confirm against the caller.
        _x = x.unsqueeze(1)
        _x = _x.expand(self.batch_size, self.candidate_num, self.token_size, self.embedding_dim)
        
        # candidate must hold batch*candidate_num*token_size*embedding_dim values.
        # reshape (not permute) regroups the sequences; the reshape back below returns
        # each slot to its original position.
        _candidate = candidate.reshape(self.candidate_num, self.batch_size, self.token_size, self.embedding_dim)
        for i in range(self.candidate_num):
            # NOTE(review): the one-encoder TGG indexes encoder output with
            # ['last_hidden_state'], which suggests this returns a model-output object,
            # not a tensor; assigning it into a tensor slice likely fails — TODO confirm.
            _candidate[i] = self.casual_span_encoder(_candidate[i])
            # NOTE(review): a BERT-style pooler drops the token axis, so its output would
            # not fit the (batch, token_size, embedding_dim) slot — TODO confirm.
            _candidate[i] = self.casual_span_pooler(_candidate[i])
        _candidate = _candidate.reshape(self.batch_size, self.candidate_num, self.token_size, self.embedding_dim)
        
        # NOTE(review): dim=1 is the candidate axis, so this gives
        # (batch, token_size, embedding_dim); the two projections then reduce it to
        # (batch,), i.e. one score per sample rather than one per candidate. The
        # training loss presumably expects (batch, candidate_num) — TODO confirm intent.
        _similarity = F.cosine_similarity(_x, _candidate, dim=1)
        _similarity = self.embedding_projection(_similarity).squeeze(-1)
        _similarity = self.hidden_projection(_similarity).squeeze(-1)
        return _similarity
    def read_json(self, file_path):
        with open(file_path, 'r') as json_file:
            data = json.load(json_file)
        return data

class EmotionHead(nn.Module):
    """Map a pooled ``embedding_dim`` vector to one logit per entry of ``label2id``."""

    def __init__(self, config_path):
        super(EmotionHead, self).__init__()
        self.build(config_path)
        self.emotion_fc = nn.Linear(self.embedding_dim, self.label_num)

    def build(self, config_path):
        """Pull candidate_num, token_size, embedding_dim and label2id out of the JSON config."""
        cfg = self.read_json(config_path)
        self.candidate_num, self.token_size, self.embedding_dim = (
            cfg["candidate_num"],
            cfg["token_size"],
            cfg["embedding_dim"],
        )
        self.label2id = cfg["label2id"]
        self.label_num = len(self.label2id)

    def forward(self, x):
        """Return raw (un-softmaxed) emotion scores for ``x``."""
        logits = self.emotion_fc(x)
        return logits

    def read_json(self, file_path):
        """Load the JSON file at ``file_path``."""
        with open(file_path, 'r') as fh:
            return json.load(fh)
    
        
class TGG(nn.Module):
    """Two-encoder TGG: a pretrained backbone plus a separate span-encoder head.

    ``forward(x, candidate)`` returns ``(emotion_logits, span_scores)``, where
    ``x`` holds paragraph token ids and ``candidate`` holds candidate token ids
    with ``batch * candidate_num * token_size`` elements.
    """
    def __init__(self, config_path):
        super(TGG, self).__init__()
        self.emotion_head = EmotionHead(config_path)
        self.casual_span_head = CasualSpanHead(config_path)
        self.build(config_path)

    def build(self, config_path):
        """Read sizes from the JSON config and take embeddings/encoder/pooler from the backbone."""
        config = self.read_json(config_path)
        self.candidate_num = config["candidate_num"]
        self.hidden_size = config["hidden_size"]
        self.embedding_dim = config["embedding_dim"]
        self.token_size = config["token_size"]
        self.main_part_name = config["main_part_name"]

        main_part = AutoModel.from_pretrained(self.main_part_name)
        # BERT-style AutoModel instances expose `.embeddings` (plural), as the
        # one-encoder TGG uses; `.embedding` raised AttributeError.
        self.embedding = main_part.embeddings
        self.encoder = main_part.encoder
        self.pooler = main_part.pooler

    def forward(self, x, candidate):
        batch_size = x.size(0)
        # Previously read but never assigned; derive it from the batch, like the one-encoder TGG.
        self.batch_size = batch_size

        # The encoder returns a model-output mapping; the pooler needs the hidden-state tensor.
        hidden = self.encoder(self.embedding(x))['last_hidden_state']
        pooled = self.pooler(hidden)
        _e_category = self.emotion_head(pooled)

        # Candidate ids -> token embeddings (batch, candidate_num, token_size, embedding_dim),
        # embedding every candidate sequence in one call.
        flat_ids = candidate.reshape(batch_size * self.candidate_num, self.token_size)
        _candidate = self.embedding(flat_ids).reshape(
            batch_size, self.candidate_num, self.token_size, self.embedding_dim)

        # CasualSpanHead.forward expands x to (batch, candidate_num, token_size, embedding_dim),
        # so it needs the token-level hidden states rather than the pooled vector.
        _casual_span = self.casual_span_head(hidden, _candidate)
        return _e_category, _casual_span

    def read_json(self, file_path):
        with open(file_path, 'r') as json_file:
            data = json.load(json_file)
        return data

## One-encoder variant

In [7]:
class EmotionHead(nn.Module):
    """Linear emotion classifier: ``embedding_dim`` features -> ``len(label2id)`` logits."""

    def __init__(self, config_path):
        super(EmotionHead, self).__init__()
        self.build(config_path)
        self.emotion_fc = nn.Linear(self.embedding_dim, self.label_num)

    def build(self, config_path):
        """Copy the head's settings from the JSON config onto the instance."""
        settings = self.read_json(config_path)
        for key in ("candidate_num", "token_size", "embedding_dim", "label2id"):
            setattr(self, key, settings[key])
        self.label_num = len(self.label2id)

    def forward(self, x):
        """Return unnormalised class scores for ``x``."""
        return self.emotion_fc(x)

    def read_json(self, file_path):
        """Parse the JSON file at ``file_path``."""
        with open(file_path, 'r') as source:
            parsed = json.load(source)
        return parsed
    
        
class TGG(nn.Module):
    """One-encoder TGG: a single pretrained backbone encodes paragraph and candidates.

    ``forward(x, candidate)`` returns ``(emotion_logits, similarity)`` where the
    logits come from the pooled paragraph and ``similarity`` (batch, candidate_num)
    is the cosine similarity between paragraph and each candidate after a shared
    feed-forward projection.
    """

    def __init__(self, config_path):
        super(TGG, self).__init__()
        self.build(config_path)
        self.emotion_head = EmotionHead(config_path)
        projection_layers = [
            nn.Linear(self.embedding_dim, self.hidden_size), nn.ReLU(), nn.Dropout(),
            nn.Linear(self.hidden_size, self.embedding_dim), nn.ReLU(), nn.Dropout(),
        ]
        self.casual_feedforward = nn.Sequential(*projection_layers)

    def build(self, config_path):
        """Read sizes/device from the JSON config and borrow the backbone's submodules."""
        cfg = self.read_json(config_path)
        self.candidate_num = cfg["candidate_num"]
        self.hidden_size = cfg["hidden_size"]
        self.embedding_dim = cfg["embedding_dim"]
        self.token_size = cfg["token_size"]
        self.main_part_name = cfg["main_part_name"]
        self.device = cfg["device"]

        backbone = AutoModel.from_pretrained(self.main_part_name)
        self.embeddings = backbone.embeddings
        self.encoder = backbone.encoder
        self.pooler = backbone.pooler

    def forward(self, x, candidate):
        self.batch_size = x.size(0)
        n_batch, n_cand = self.batch_size, self.candidate_num

        para_vec = self.pooler(self.encoder(self.embeddings(x))['last_hidden_state'])
        _e_category = self.emotion_head(para_vec)

        # reshape (not permute) to (n_cand, n_batch, token_size); the reshape back
        # below returns each pooled vector to its original (batch, candidate) slot.
        cand_ids = candidate.reshape(n_cand, n_batch, self.token_size)
        cand_vecs = torch.zeros((n_cand, n_batch, self.embedding_dim), device=self.device)
        for slot, ids in enumerate(cand_ids):
            token_states = self.encoder(self.embeddings(ids))['last_hidden_state']
            cand_vecs[slot] = self.pooler(token_states)
        cand_vecs = cand_vecs.reshape(n_batch, n_cand, self.embedding_dim)

        cand_proj = self.casual_feedforward(cand_vecs)
        para_proj = self.casual_feedforward(para_vec).unsqueeze(1)
        para_proj = para_proj.expand(n_batch, n_cand, self.embedding_dim)
        return _e_category, F.cosine_similarity(para_proj, cand_proj, dim=2)

    def read_json(self, file_path):
        with open(file_path, 'r') as cfg_file:
            return json.load(cfg_file)

# Trainer

In [8]:
class Trainer(nn.Module):
    """Train a model with summed emotion + span CrossEntropy loss, early stopping and checkpoints.

    NOTE(review): ``train`` below shadows ``nn.Module.train(mode)``, so calling
    ``trainer.eval()`` on a Trainer would dispatch into it. The name is kept for
    interface compatibility.
    """
    def __init__(self, config_path):
        super(Trainer, self).__init__()
        self.layers_to_freeze = []
        self.history = {
            "train loss": [],
            "valid loss": [],
            "valid metrics": []
        }
        self.evaluator = Evaluation(config_path)
        self.build(config_path)

    def build(self, config_path):
        """Read optimisation, freezing and checkpoint settings from the JSON config."""
        print(config_path)
        config = self.read_json(config_path)
        self.patience = config['patience']
        self.min_delta = config['min_delta']
        self.fn_loss_name = config['loss_fn_name']
        self.batch_size_config = config['batch_size']
        self.device = config['device']
        self.save_checkpoint_dir = config['save_checkpoint_dir']
        self.emotion_head_layers = config['emotion_head_layers']
        self.casual_emotion_layers = config['casual_emotion_layers']
        self.main_layers = config['main_layers']
        # Only CrossEntropyLoss is implemented; `loss_fn_name` is read but not used to choose.
        self.fn_loss = nn.CrossEntropyLoss()
        self.early_stopper = EarlyStopper(self.patience, self.min_delta)

    def emotion_head_training(self):
        """Train only the emotion head: freeze span and backbone layers."""
        self.layers_to_freeze = self.casual_emotion_layers + self.main_layers

    def casual_emotion_training(self):
        """Train only the span part: freeze emotion-head and backbone layers."""
        self.layers_to_freeze = self.emotion_head_layers + self.main_layers

    def full_training(self):
        """Train every parameter."""
        self.layers_to_freeze = []

    def freeze(self, model):
        """Set requires_grad=False on parameters whose name contains any entry of layers_to_freeze."""
        for name, param in model.named_parameters():
            if any(layer in name for layer in self.layers_to_freeze):
                param.requires_grad = False
            else:
                param.requires_grad = True
        return model

    def train(self, model, optimizer, train_dataset, valid_dataset, epochs, option="full", batch_size=None):
        """Run up to ``epochs`` train/validation epochs with checkpointing and early stopping.

        Args:
            model: called as ``model(paragraph, casual_pool)``, returning
                ``(emotion_logits, span_scores)``.
            optimizer: optimizer over ``model``'s parameters.
            train_dataset, valid_dataset: datasets yielding dicts with ``paragraph``,
                ``casual_pool``, ``emotion_label`` and ``span_label``.
            epochs: maximum number of epochs.
            option: "emotion" freezes casual_emotion_layers + main_layers;
                "casual_emotion" freezes emotion_head_layers + main_layers;
                anything else trains every parameter.
            batch_size: overrides the config ``batch_size`` when given.
        """
        self.optimizer = optimizer
        print("TRAINING " + option + " MODEL")
        if option == "emotion":
            self.emotion_head_training()
        elif option == "casual_emotion":
            self.casual_emotion_training()
        else:
            self.full_training()

        self.batch_size = batch_size if batch_size is not None else self.batch_size_config

        self.early_stopper.reset()
        self.model = self.freeze(model)
        self.model = self.model.to(self.device)

        self.train_dataloader = DataLoader(dataset=train_dataset, batch_size=self.batch_size, shuffle=True)
        # No shuffling for validation: with a partial last batch the mean of per-batch
        # losses depends on batch composition, which would add noise to early stopping.
        self.valid_dataloader = DataLoader(dataset=valid_dataset, batch_size=self.batch_size, shuffle=False)

        best_val_loss = np.inf
        self.epochs = epochs
        for epoch in range(epochs):
            train_loss = self._train_epoch(epoch)
            val_loss = self._val_epoch(epoch)

            self.history["train loss"].append(train_loss)
            self.history["valid loss"].append(val_loss)

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                self.save_checkpoint(epoch)
                print(f"Epoch {epoch+1}: Train loss {train_loss:.4f},  Valid loss {val_loss:.4f}")
            if self.early_stopper.early_stop(val_loss):
                print("Early stop at epoch: "+str(epoch) + " with valid loss: "+str(val_loss))
                break

    def _batch_loss(self, data):
        """Forward one batch and return emotion loss + span loss as a tensor."""
        paragraph = data["paragraph"].to(self.device)
        casual_pool = data["casual_pool"].to(self.device)
        emotion_label = data["emotion_label"].to(self.device)
        # NOTE(review): span_label is presumably a float tensor per candidate, which
        # CrossEntropyLoss treats as class-probability targets.
        span_label = data["span_label"].to(self.device)
        emotion_pred, span_pred = self.model(paragraph, casual_pool)
        return self.fn_loss(emotion_pred, emotion_label) + self.fn_loss(span_pred, span_label)

    def _train_epoch(self, epoch):
        """One optimisation pass over the training loader; returns the mean batch loss."""
        self.model.train()
        train_loss = 0.0
        logger_message = f'Training epoch {epoch}/{self.epochs}'

        progress_bar = tqdm(self.train_dataloader,
                            desc=logger_message, initial=0, dynamic_ncols=True)
        for data in progress_bar:
            self.optimizer.zero_grad()
            loss = self._batch_loss(data)
            train_loss += loss.item()
            loss.backward()
            self.optimizer.step()
        return train_loss / len(self.train_dataloader)

    def _val_epoch(self, epoch):
        """Gradient-free pass over the validation loader; returns the mean batch loss."""
        self.model.eval()
        valid_loss = 0.0
        logger_message = f'Validation epoch {epoch}/{self.epochs}'
        progress_bar = tqdm(self.valid_dataloader,
                            desc=logger_message, initial=0, dynamic_ncols=True)
        with torch.no_grad():
            for data in progress_bar:
                valid_loss += self._batch_loss(data).item()
        return valid_loss / len(self.valid_dataloader)

    def recalculation(self, evaluate_score, valid_metrics, option="sum"):
        """Accumulate per-batch metric values ("sum") or replace the lists with their means (other)."""
        if option=="sum":
            for key in evaluate_score.keys():
                if key not in valid_metrics.keys():
                    valid_metrics[key]=[]
                valid_metrics[key].append(evaluate_score[key])
            return valid_metrics
        else:
            for key in evaluate_score.keys():
                valid_metrics[key] = sum(valid_metrics[key]) / len(valid_metrics[key])
            return valid_metrics

    def save_checkpoint(self, epoch):
        """Write model/optimizer state to ``<save_checkpoint_dir>/checkpoint_<epoch>.pth``."""
        # torch.save does not create parent directories; a missing dir would abort training.
        os.makedirs(self.save_checkpoint_dir, exist_ok=True)
        checkpoint_path = os.path.join(self.save_checkpoint_dir, 'checkpoint_{}.pth'.format(epoch))
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'epoch': epoch
        }, checkpoint_path)

    def read_json(self, file_path):
        with open(file_path, 'r') as json_file:
            data = json.load(json_file)
        return data

# Predict

In [9]:
class Prediction(nn.Module):
    """Run a trained model over a dataset and write a submission JSON file.

    Configuration is read from a JSON file providing the keys ``id2label``,
    ``label2id``, ``submit path``, ``device`` and ``batch_size``.
    """

    def __init__(self, config_path):
        super(Prediction, self).__init__()
        self.build(config_path)

    def build(self, config_path):
        """Load label maps, output path, device and batch size from ``config_path``."""
        config = self.read_json(config_path)
        self.id2label = config["id2label"]
        self.label2id = config["label2id"]
        self.submit_path = config["submit path"]
        self.device = config["device"]
        self.batch_size = config["batch_size"]

    def predict(self, model, raw_dataset, dataset):
        """Predict emotion-cause pairs for ``dataset`` and save them to ``self.submit_path``.

        Args:
            model: module called as ``model(paragraph, casual_pool)`` and
                returning ``(emotion_pred, span_pred)``.
            raw_dataset: list of conversation dicts (each with
                ``conversation_ID`` and ``conversation``) used as the skeleton
                of the submission file.
            dataset: ``Dataset`` whose items provide ``conversation_ID``,
                ``paragraph``, ``utterance_ID``, ``casual_pool`` and
                ``casual_span_pool``.
        """
        self.model = model.to(self.device)
        self.model.eval()
        pred_list = []

        # shuffle=False keeps prediction order aligned with the dataset order.
        self.test_dataloader = DataLoader(dataset=dataset, batch_size=self.batch_size, shuffle=False)
        logger_message = f'Predict '
        progress_bar = tqdm(self.test_dataloader,
                            desc=logger_message, initial=0, dynamic_ncols=True)
        with torch.no_grad():
            for data in progress_bar:
                paragraph = data["paragraph"].to(self.device)
                casual_pool = data["casual_pool"].to(self.device)
                emotion_pred, span_pred = self.model(paragraph, casual_pool)

                pred_list.extend(self.convertpred2list(
                    data["conversation_ID"],
                    data["utterance_ID"],
                    emotion_pred,
                    span_pred,
                    data["casual_span_pool"],
                ))
        submit = self.convertdict2submit(self.convertpred2dict(raw_dataset, pred_list))
        self.save_json(self.submit_path, submit)

    def convertpred2list(self, conversation_id, utter_id, emotion_pred, span_pred, casual_span_pool):
        """Turn one batch of model outputs into flat emotion-cause records.

        For every sample whose arg-max emotion is not ``"neutral"``, each
        candidate ``j`` whose rounded span score equals 1 produces a record
        ``{"conversation_ID": ..., "emotion-cause_pairs": ["<utt>_<emotion>",
        "<a>_<b>_<c>"]}``. Here ``a, b, c`` are the three fields of
        ``casual_span_pool[i][j]``.

        Candidate filtering:
          * first field == 0 stops the scan for that sample (treated as
            padding — assumes padding rows sit at the end; TODO confirm);
          * second field == third field skips just that candidate (presumably
            an empty span — TODO confirm).

        Returns:
            list of record dicts, possibly empty.
        """
        pred_list = []
        for i in range(len(emotion_pred)):
            emotion = self.id2label[str(torch.argmax(emotion_pred[i]).item())]
            if emotion == "neutral":
                continue
            emotion_string = str(utter_id[i].item()) + "_" + emotion
            span = torch.round(span_pred[i])
            for j in range(len(span)):
                if span[j] != 1:
                    continue
                candidate = casual_span_pool[i][j]
                cause_utt = candidate[0].item()
                span_start = candidate[1].item()
                span_end = candidate[2].item()
                if cause_utt == 0:
                    break
                if span_start == span_end:
                    # A degenerate candidate says nothing about later ones:
                    # skip it instead of abandoning the remaining candidates.
                    continue
                span_string = "{}_{}_{}".format(int(cause_utt), int(span_start), int(span_end))
                pred_list.append({
                    "conversation_ID": conversation_id[i].item(),
                    "emotion-cause_pairs": [emotion_string, span_string],
                })
        return pred_list

    def convertpred2dict(self, raw_dataset, pred_list):
        """Group flat prediction records under their conversation.

        Returns a dict keyed by ``conversation_ID`` (see ``convertlist2dict``).
        Each record's pair is appended to that conversation's
        ``"emotion-cause_pairs"`` list. Every conversation carries the key,
        empty when nothing was predicted.
        """
        dataset = self.convertlist2dict(raw_dataset)
        for pred in pred_list:
            record = dataset[pred["conversation_ID"]]
            # setdefault + append: the first pair of a conversation must be
            # kept too (it used to be dropped when the list was created).
            record.setdefault("emotion-cause_pairs", []).append(pred["emotion-cause_pairs"])
        return dataset

    def convertlist2dict(self, raw_dataset):
        """Index ``raw_dataset`` by ``conversation_ID``.

        Keeps only ``conversation_ID`` and ``conversation`` from each entry and
        starts every conversation with an empty ``"emotion-cause_pairs"`` list.
        """
        dataset = {}
        for conversation in raw_dataset:
            conversation_ID = conversation["conversation_ID"]
            dataset[conversation_ID] = {
                "conversation_ID": conversation_ID,
                "conversation": conversation['conversation'],
                "emotion-cause_pairs": [],
            }
        return dataset

    def convertdict2submit(self, raw_dataset):
        """Return the conversation records of ``raw_dataset`` as a list, in key order."""
        return list(raw_dataset.values())

    def read_json(self, file_path):
        """Parse the JSON file at ``file_path`` (read as UTF-8, as JSON requires)."""
        with open(file_path, 'r', encoding='utf-8') as json_file:
            data = json.load(json_file)
        return data

    def save_json(self, file_path, data):
        """Serialise ``data`` as JSON to ``file_path``, overwriting it."""
        with open(file_path, 'w') as json_file:
            json.dump(data, json_file)

# TEST

## Data loading and model setup

In [10]:
# Experiment settings: tune here rather than inline.
config_path = "config.json"
TRAIN_FRACTION = 0.8   # share of train.json conversations used for training; the rest validates
LEARNING_RATE = 1e-3
SEED = 42

# Seed before the model is built so weight initialisation (and any
# torch/numpy-driven shuffling) is reproducible across Restart & Run All.
torch.manual_seed(SEED)
np.random.seed(SEED)

dataset = read_json("data/ECF 2.0/train/train.json")
trial_data = read_json("data/ECF 2.0/trial/trial.json")
raw_trial_data = read_json("data/ECF 2.0/trial/Subtask_1_trial.json")

# Deterministic, unshuffled split: the tail of train.json becomes validation.
n_train = int(TRAIN_FRACTION * len(dataset))
assert 0 < n_train < len(dataset), "train/valid split produced an empty side"
train_data, valid_data = dataset[:n_train], dataset[n_train:]
train_dataset = ConversationDataset(config_path, train_data)
valid_dataset = ConversationDataset(config_path, valid_data)

model = TGG(config_path)
optimizer = Adam(model.parameters(), lr=LEARNING_RATE)

## Staged training and trial prediction

In [11]:
# Stage 1 (option="emotion"): train, then write a trial-set submission.
trainer = Trainer(config_path)
trainer.train(
    model,
    optimizer,
    train_dataset=train_dataset,
    valid_dataset=valid_dataset,
    epochs=100,
    option="emotion",
    batch_size=128,
)
trial_dataset = ConversationDataset(config_path, trial_data, option="trial")
predict = Prediction(config_path)
predict.predict(model, raw_trial_data, trial_dataset)

config.json
TRAINING emotion MODEL


Training epoch 0/100: 100%|██████████| 85/85 [07:49<00:00,  5.52s/it]
Validation epoch 0/100: 100%|██████████| 22/22 [02:01<00:00,  5.52s/it]


Epoch 1: Train loss 2.6259,  Valid loss 9.4572


Training epoch 1/100: 100%|██████████| 85/85 [08:47<00:00,  6.20s/it]
Validation epoch 1/100: 100%|██████████| 22/22 [02:14<00:00,  6.09s/it]


Epoch 2: Train loss 2.5920,  Valid loss 9.0473


Training epoch 2/100: 100%|██████████| 85/85 [09:11<00:00,  6.49s/it]
Validation epoch 2/100: 100%|██████████| 22/22 [02:10<00:00,  5.91s/it]
Training epoch 3/100: 100%|██████████| 85/85 [08:56<00:00,  6.31s/it]
Validation epoch 3/100: 100%|██████████| 22/22 [02:08<00:00,  5.84s/it]
Training epoch 4/100: 100%|██████████| 85/85 [09:13<00:00,  6.51s/it]
Validation epoch 4/100: 100%|██████████| 22/22 [02:16<00:00,  6.23s/it]
Training epoch 5/100: 100%|██████████| 85/85 [09:17<00:00,  6.56s/it]
Validation epoch 5/100: 100%|██████████| 22/22 [02:13<00:00,  6.05s/it]
Training epoch 6/100: 100%|██████████| 85/85 [09:09<00:00,  6.47s/it]
Validation epoch 6/100: 100%|██████████| 22/22 [02:12<00:00,  6.04s/it]


Early stop at epoch: 6 with valid loss: 9.535531130704014


Predict : 100%|██████████| 45/45 [00:18<00:00,  2.49it/s]


In [12]:
# Stage 2 (option="casual_emotion"): continue training the same model and
# optimizer, then re-predict. NOTE: the same config means the same submit
# path, so this overwrites the previous submission file.
trainer.train(
    model,
    optimizer,
    train_dataset=train_dataset,
    valid_dataset=valid_dataset,
    epochs=100,
    option="casual_emotion",
    batch_size=128,
)
trial_dataset = ConversationDataset(config_path, trial_data, option="trial")
predict = Prediction(config_path)
predict.predict(model, raw_trial_data, trial_dataset)

TRAINING casual_emotion MODEL


Training epoch 0/100: 100%|██████████| 85/85 [08:59<00:00,  6.35s/it]
Validation epoch 0/100: 100%|██████████| 22/22 [02:11<00:00,  5.97s/it]


Epoch 1: Train loss 2.5378,  Valid loss 9.4816


Training epoch 1/100: 100%|██████████| 85/85 [09:00<00:00,  6.36s/it]
Validation epoch 1/100: 100%|██████████| 22/22 [02:11<00:00,  5.99s/it]


Epoch 2: Train loss 2.5345,  Valid loss 9.3875


Training epoch 2/100: 100%|██████████| 85/85 [09:01<00:00,  6.38s/it]
Validation epoch 2/100: 100%|██████████| 22/22 [02:12<00:00,  6.01s/it]
Training epoch 3/100: 100%|██████████| 85/85 [09:07<00:00,  6.44s/it]
Validation epoch 3/100: 100%|██████████| 22/22 [02:08<00:00,  5.86s/it]
Training epoch 4/100: 100%|██████████| 85/85 [09:02<00:00,  6.38s/it]
Validation epoch 4/100: 100%|██████████| 22/22 [02:11<00:00,  5.97s/it]


Epoch 5: Train loss 2.5296,  Valid loss 9.3270
Early stop at epoch: 4 with valid loss: 9.327030485326594


Predict : 100%|██████████| 45/45 [00:17<00:00,  2.55it/s]


In [13]:
# Stage 3 (option="full"): continue training the same model and optimizer,
# then re-predict. batch_size is deliberately not passed, so Trainer's default
# is used (the log shows 1347 training batches here vs 85 at batch_size=128).
trainer.train(
    model,
    optimizer,
    train_dataset=train_dataset,
    valid_dataset=valid_dataset,
    epochs=100,
    option="full",
)
trial_dataset = ConversationDataset(config_path, trial_data, option="trial")
predict = Prediction(config_path)
predict.predict(model, raw_trial_data, trial_dataset)

TRAINING full MODEL


Training epoch 0/100: 100%|██████████| 1347/1347 [23:17<00:00,  1.04s/it]
Validation epoch 0/100: 100%|██████████| 337/337 [02:03<00:00,  2.73it/s]


Epoch 1: Train loss 2.7653,  Valid loss 11.4463


Training epoch 1/100: 100%|██████████| 1347/1347 [22:27<00:00,  1.00s/it]
Validation epoch 1/100: 100%|██████████| 337/337 [02:02<00:00,  2.74it/s]
Training epoch 2/100: 100%|██████████| 1347/1347 [22:30<00:00,  1.00s/it]
Validation epoch 2/100: 100%|██████████| 337/337 [02:02<00:00,  2.74it/s]


Epoch 3: Train loss 2.6326,  Valid loss 7.4694


Training epoch 3/100: 100%|██████████| 1347/1347 [22:30<00:00,  1.00s/it]
Validation epoch 3/100: 100%|██████████| 337/337 [02:02<00:00,  2.75it/s]
Training epoch 4/100: 100%|██████████| 1347/1347 [22:28<00:00,  1.00s/it]
Validation epoch 4/100: 100%|██████████| 337/337 [02:01<00:00,  2.77it/s]
Training epoch 5/100: 100%|██████████| 1347/1347 [22:20<00:00,  1.00it/s]
Validation epoch 5/100: 100%|██████████| 337/337 [02:00<00:00,  2.79it/s]
Training epoch 6/100: 100%|██████████| 1347/1347 [22:14<00:00,  1.01it/s]
Validation epoch 6/100: 100%|██████████| 337/337 [02:00<00:00,  2.79it/s]
Training epoch 7/100: 100%|██████████| 1347/1347 [22:12<00:00,  1.01it/s]
Validation epoch 7/100: 100%|██████████| 337/337 [02:00<00:00,  2.79it/s]


Early stop at epoch: 7 with valid loss: 11.001746106925987


Predict : 100%|██████████| 45/45 [00:15<00:00,  2.89it/s]
