# Auxiliary Class

## Library

In [None]:
import json
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
# Dataset
from PIL import Image
from torchvision import transforms
from torchvision.io import read_video, read_image
from torch.utils.data import Dataset, DataLoader
# Model
from transformers import AutoModel
from transformers import AutoTokenizer, BertModel, BertTokenizer, BertGenerationDecoder

# Training parameter
from torch.optim import Adam
# Training process
from tqdm import tqdm
# Metrics
from sklearn.metrics import classification_report, accuracy_score
from collections import OrderedDict
import copy
import sys

## Function

In [None]:
def read_json(file_path):
    """Parse the JSON file at ``file_path`` and return the decoded object."""
    with open(file_path, 'r') as handle:
        return json.load(handle)
def write_json(path, data):
    """Dump ``data`` as JSON (4-space indentation) into the file at ``path``.

    The file is opened in write mode, so any existing content is replaced.
    """
    with open(path, 'w') as handle:
        json.dump(data, handle, indent=4)

## Early Stopper

In [None]:
class EarlyStopper:
    """Signal when the validation loss has stopped improving.

    The stopper remembers the lowest validation loss seen so far. Every call
    to :meth:`early_stop` with a loss more than ``min_delta`` above that
    minimum increments a counter; once the counter reaches ``patience`` the
    method returns ``True``.
    """

    def __init__(self, patience=1, min_delta=0):
        """
        Args:
            patience: number of non-improving checks tolerated before stopping.
            min_delta: margin above the best loss that still does not count
                as a degradation.
        """
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.min_validation_loss = float('inf')

    def reset(self):
        """Return the stopper to its initial state before a new training run.

        The best loss is cleared together with the counter: a minimum recorded
        by a previous run (possibly with a different loss function and scale)
        would otherwise make every epoch of the new run look like a regression.
        """
        self.counter = 0
        self.min_validation_loss = float('inf')

    def early_stop(self, validation_loss):
        """Record ``validation_loss`` and return ``True`` when training should stop."""
        if validation_loss < self.min_validation_loss:
            # New best: remember it and forgive earlier bad epochs.
            self.min_validation_loss = validation_loss
            self.counter = 0
        elif validation_loss > (self.min_validation_loss + self.min_delta):
            self.counter += 1
            if self.counter >= self.patience:
                return True
        # Losses within [best, best + min_delta] leave the counter untouched.
        return False

# Train

## Dataset

In [None]:
class ConversationDataset(Dataset):
    """Dataset that tokenizes one target utterance context and its candidate causes.

    Each element of ``data`` is a dict providing ``conversation_ID``,
    ``paragraph``, ``utterance_ID``, ``casual_pool`` and ``casual_span_pool``;
    in ``"train"`` mode it must additionally provide ``emotion`` and
    ``span_label``.
    """

    def __init__(self, config_path, data, option="train"):
        """
        Args:
            config_path: path of the JSON configuration read by :meth:`build`.
            data: indexable collection of sample dicts.
            option: ``"train"`` to also return labels; any other value
                returns inputs only.
        """
        super(ConversationDataset, self).__init__()
        self.build(config_path)
        self.data = data
        self.option = option

    def build(self, config_path):
        """Read tokenizer and label settings from the config and load the tokenizer."""
        config = self.read_json(config_path)
        self.label2id = config["label2id"]
        self.id2label = config["id2label"]
        self.tokenizer_name = config["tokenizer_name"]
        self.padding = config["padding"]
        self.max_length = config["max_length"]
        self.candidate_num = config["candidate num"]
        self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)

    def __len__(self):
        return len(self.data)

    def _encode(self, text):
        """Tokenize a string (or a list of strings) and return the ``input_ids`` tensor.

        Truncation is enabled so that no sequence exceeds ``max_length``;
        without it, over-long texts would produce tensors of different widths
        that cannot be stacked into a batch.
        """
        encoded = self.tokenizer(
            text,
            padding=self.padding,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        return encoded['input_ids']

    def __getitem__(self, index):
        """Return the tensors (and, in train mode, the labels) for sample ``index``."""
        sample = self.data[index]

        # A single string is encoded as (1, seq_len); drop only that leading axis.
        paragraph = self._encode(sample["paragraph"]).squeeze(0)

        # The whole candidate pool is encoded in one call so the result is a
        # tensor that the training loop can move to the device.
        # NOTE(review): assumes "casual_pool" is a list of utterance strings
        # and that the configured padding yields equal widths - confirm.
        casual_pool = self._encode(sample["casual_pool"])
        casual_span_pool = torch.FloatTensor(sample["casual_span_pool"])

        item = {
            "conversation_ID": sample["conversation_ID"],
            "paragraph": paragraph,
            "utterance_ID": sample["utterance_ID"],
            "casual_pool": casual_pool,
            "casual_span_pool": casual_span_pool,
        }
        if self.option == "train":
            item["emotion_label"] = self.label2id[sample["emotion"]]
            item["span_label"] = torch.FloatTensor(sample["span_label"])
        return item

    def read_json(self, file_path):
        """Load and return the JSON document stored at ``file_path``."""
        with open(file_path, 'r') as json_file:
            data = json.load(json_file)
        return data

## Trainer

In [None]:
class Trainer(nn.Module):
    """Run staged training (emotion head, causal head or both) with early stopping.

    The trainer freezes the parameter groups that are not being trained in
    the selected stage, runs train/validation epochs, stores a checkpoint
    whenever the validation loss improves and writes the loss/accuracy
    history to JSON when training ends.
    """

    def __init__(self, config_path):
        super(Trainer, self).__init__()
        self.layers_to_freeze = []
        # Every list appended to in train() must exist up front.
        self.history = {
            "train loss": [],
            "valid loss": [],
            "valid metrics": [],
            "emotion_correct": [],
            "span_correct": []
        }
        # Per-head criteria used when both heads are trained together; a
        # single shared criterion cannot serve class logits and span scores.
        self.emotion_fn_loss = nn.CrossEntropyLoss()
        self.span_fn_loss = nn.MSELoss()
        # NOTE(review): Evaluation is not defined in the visible source, so it
        # is only instantiated when the surrounding module provides it.
        evaluation_cls = globals().get("Evaluation")
        self.evaluator = evaluation_cls(config_path) if evaluation_cls is not None else None
        self.build(config_path)

    def build(self, config_path):
        """Load training hyper-parameters and layer-group names from the config."""
        print(config_path)
        config = read_json(config_path)
        self.patience = config['patience']
        self.min_delta = config['min_delta']
        self.fn_loss_name = config['loss_fn_name']
        self.batch_size = config['batch_size']
        self.device = config['device']
        self.save_checkpoint_dir = config['save_checkpoint_dir']
        self.save_history_dir = config['save_history_dir']
        self.emotion_head_layers = config['emotion_head_layers']
        self.casual_emotion_layers = config['casual_emotion_layers']
        self.main_layers = config['main_layers']
        self.early_stopper = EarlyStopper(self.patience, self.min_delta)

    def emotion_head_training(self):
        """Train only the emotion head: freeze the causal head and the backbone."""
        self.layers_to_freeze = self.casual_emotion_layers + self.main_layers
        self.fn_loss = nn.CrossEntropyLoss()

    def casual_emotion_training(self):
        """Train only the causal head: freeze the emotion head and the backbone."""
        self.layers_to_freeze = self.emotion_head_layers + self.main_layers
        self.fn_loss = nn.MSELoss()

    def full_training(self):
        """Train every layer; losses come from the two per-head criteria."""
        self.layers_to_freeze = []

    def freeze(self, model):
        """Set ``requires_grad`` so only parameters outside ``layers_to_freeze`` train.

        A parameter is frozen when any entry of ``layers_to_freeze`` occurs as
        a substring of its name.
        """
        for name, param in model.named_parameters():
            param.requires_grad = not any(layer in name for layer in self.layers_to_freeze)
        return model

    def _resolve_batch_size(self, batch_size):
        """Return the explicit ``batch_size`` or fall back to the configured one."""
        return self.batch_size if batch_size is None else batch_size

    def _compute_loss(self, option, emotion_pred, emotion_label, span_pred, span_label):
        """Return the loss for the current stage.

        Single-head stages use ``self.fn_loss`` chosen by the stage setup;
        any other option sums the emotion and span losses.
        """
        if option == "emotion":
            return self.fn_loss(emotion_pred, emotion_label)
        if option == "casual_emotion":
            return self.fn_loss(span_pred, span_label)
        emotion_loss = self.emotion_fn_loss(emotion_pred, emotion_label)
        span_loss = self.span_fn_loss(span_pred, span_label)
        return emotion_loss + span_loss

    @staticmethod
    def _count_correct(emotion_pred, emotion_label, span_pred, span_label):
        """Return ``(emotion_hits, span_hits)`` for one batch.

        Emotion predictions are the argmax over the last axis of the logits.
        Span scores are rounded before comparison, which mirrors how
        predictions are decoded at inference time.
        """
        emotion_hits = (torch.argmax(emotion_pred, dim=-1) == emotion_label).sum().item()
        span_hits = (torch.round(span_pred) == span_label).sum().item()
        return emotion_hits, span_hits

    def train(self, model, optimizer, train_dataset, valid_dataset, epochs, option="emotion", batch_size=None):
        """Train ``model`` for up to ``epochs`` epochs in the stage named by ``option``.

        Args:
            model: network returning ``(emotion_pred, span_pred)``.
            optimizer: optimizer already bound to the model parameters.
            train_dataset: dataset yielding training samples with labels.
            valid_dataset: dataset yielding validation samples with labels.
            epochs: maximum number of epochs.
            option: ``"emotion"``, ``"casual_emotion"`` or anything else for
                joint training of both heads.
            batch_size: overrides the configured batch size when given.
        """
        self.optimizer = optimizer
        if option == "emotion":
            print("TRAINING " + option + " MODEL")
            self.emotion_head_training()
        elif option == "casual_emotion":
            print("TRAINING " + option + " MODEL")
            self.casual_emotion_training()
        else:
            # Clear any freeze list left over from an earlier staged run.
            self.full_training()

        batch_size = self._resolve_batch_size(batch_size)

        self.early_stopper.reset()
        self.model = self.freeze(model)
        self.model = self.model.to(self.device)

        self.train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, drop_last=False)
        # Validation order does not influence the averaged metrics.
        self.valid_dataloader = DataLoader(dataset=valid_dataset, batch_size=batch_size, shuffle=False, drop_last=False)

        best_val_loss = np.inf
        self.epochs = epochs
        for epoch in range(epochs):
            train_loss = self._train_epoch(epoch, option)
            val_loss, emotion_correct, span_correct = self._val_epoch(epoch, option)

            self.history["train loss"].append(train_loss)
            self.history["valid loss"].append(val_loss)
            self.history["emotion_correct"].append(emotion_correct)
            self.history["span_correct"].append(span_correct)

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                self.save_checkpoint(epoch, option)
                print(f"Epoch {epoch+1}: Train loss {train_loss:.4f},  Valid loss {val_loss:.4f},  emotion_correct {emotion_correct*100:.4f}%,  span_correct {span_correct*100:.4f}%")
            if self.early_stopper.early_stop(val_loss):
                print("Early stop at epoch: "+str(epoch) + " with valid loss: "+str(val_loss))
                break

        if self.save_history_dir:
            os.makedirs(self.save_history_dir, exist_ok=True)
        write_json(self.save_history_dir + "/" + option + "_history.json", self.history)

    def _train_epoch(self, epoch, option):
        """Run one optimisation pass over the training loader; return the mean batch loss."""
        self.model.train()
        train_loss = 0.0
        logger_message = f'Training epoch {epoch}/{self.epochs}'

        progress_bar = tqdm(self.train_dataloader,
                            desc=logger_message, initial=0, dynamic_ncols=True)
        for data in progress_bar:
            paragraph = data["paragraph"].to(self.device)
            casual_pool = data["casual_pool"].to(self.device)
            emotion_label = data["emotion_label"].to(self.device)
            span_label = data["span_label"].to(self.device)

            self.optimizer.zero_grad()
            emotion_pred, span_pred = self.model(paragraph, casual_pool)
            loss = self._compute_loss(option, emotion_pred, emotion_label, span_pred, span_label)

            train_loss += loss.item()
            loss.backward()
            self.optimizer.step()
        return train_loss / max(len(self.train_dataloader), 1)

    def _val_epoch(self, epoch, option):
        """Evaluate on the validation loader.

        Returns:
            ``(mean batch loss, emotion accuracy, span accuracy)`` where the
            accuracies are fractions of correctly predicted labels.
        """
        self.model.eval()
        valid_loss = 0.0
        emotion_correct = 0
        span_correct = 0
        emotion_total = 0
        span_total = 0
        logger_message = f'Validation epoch {epoch}/{self.epochs}'
        progress_bar = tqdm(self.valid_dataloader,
                            desc=logger_message, initial=0, dynamic_ncols=True)
        with torch.no_grad():
            for data in progress_bar:
                paragraph = data["paragraph"].to(self.device)
                casual_pool = data["casual_pool"].to(self.device)
                emotion_label = data["emotion_label"].to(self.device)
                span_label = data["span_label"].to(self.device)
                emotion_pred, span_pred = self.model(paragraph, casual_pool)

                loss = self._compute_loss(option, emotion_pred, emotion_label, span_pred, span_label)
                valid_loss += loss.item()

                emotion_hits, span_hits = self._count_correct(emotion_pred, emotion_label, span_pred, span_label)
                emotion_correct += emotion_hits
                span_correct += span_hits
                emotion_total += emotion_label.numel()
                span_total += span_label.numel()

        # Accuracy is normalised by label count, not by number of batches.
        return (
            valid_loss / max(len(self.valid_dataloader), 1),
            emotion_correct / max(emotion_total, 1),
            span_correct / max(span_total, 1),
        )

    def save_checkpoint(self, epoch, option):
        """Save model and optimizer state for ``epoch`` under ``<checkpoint_dir>/<option>/``."""
        checkpoint_path = os.path.join(self.save_checkpoint_dir, option + '/checkpoint_{}.pth'.format(epoch))
        # torch.save does not create missing parent directories.
        os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'epoch': epoch
        }, checkpoint_path)

## Predictor

In [None]:
class Prediction(nn.Module):
    """Run a trained model over a dataset and write emotion-cause pairs to a submission file."""

    def __init__(self, config_path):
        super(Prediction, self).__init__()
        self.build(config_path)

    def build(self, config_path):
        """Load label maps, output path, device and batch size from the config."""
        config = self.read_json(config_path)
        self.id2label = config["id2label"]
        self.label2id = config["label2id"]
        self.submit_path = config["submit path"]
        self.device = config["device"]
        self.batch_size = config["batch_size"]

    def predict(self, model, raw_dataset, dataset):
        """Predict on ``dataset`` and save the merged result to ``self.submit_path``.

        Args:
            model: network returning ``(emotion_pred, span_pred)``.
            raw_dataset: list of conversation dicts with ``conversation_ID``
                and ``conversation``; predictions are attached to these.
            dataset: dataset yielding the tokenized model inputs.
        """
        self.model = model.to(self.device)
        self.model.eval()
        pred_list = []

        self.test_dataloader = DataLoader(dataset=dataset, batch_size=self.batch_size, shuffle=False)
        logger_message = 'Predict '
        progress_bar = tqdm(self.test_dataloader,
                            desc=logger_message, initial=0, dynamic_ncols=True)
        with torch.no_grad():
            for data in progress_bar:
                conversation_id = data["conversation_ID"]
                paragraph = data["paragraph"].to(self.device)
                utter_id = data["utterance_ID"]
                casual_pool = data["casual_pool"].to(self.device)
                casual_span_pool = data["casual_span_pool"]
                emotion_pred, span_pred = self.model(paragraph, casual_pool)

                pred = self.convertpred2list(conversation_id, utter_id, emotion_pred, span_pred, casual_span_pool)
                pred_list.extend(pred)
        submit = self.convertdict2submit(self.convertpred2dict(raw_dataset, pred_list))
        self.save_json(self.submit_path, submit)

    def convertpred2list(self, conversation_id, utter_id, emotion_pred, span_pred, casual_span_pool):
        """Turn one batch of model outputs into a flat list of pair records.

        For every sample whose predicted emotion is not ``"neutral"``, each
        candidate whose rounded span score equals 1 produces a record
        ``{"conversation_ID": ..., "emotion-cause_pairs": [emotion, span]}``
        where ``emotion`` is ``"<utterance>_<label>"`` and ``span`` is
        ``"<a>_<b>_<c>"`` built from the candidate's three pool values.
        """
        pred_list = []
        for i, logits in enumerate(emotion_pred):
            emotion = self.id2label[str(torch.argmax(logits).item())]
            if emotion == "neutral":
                continue
            emotion_string = str(utter_id[i].item()) + "_" + emotion
            span = torch.round(span_pred[i])
            for j, flag in enumerate(span):
                if flag.item() != 1:
                    continue
                candidate = casual_span_pool[i][j]
                first = candidate[0].item()
                second = candidate[1].item()
                third = candidate[2].item()
                # A zero first value or an empty (second == third) span ends
                # the scan of this sample's remaining candidates.
                # NOTE(review): presumably these mark padding entries - confirm.
                if first == 0 or second == third:
                    break
                span_string = str(int(first)) + "_" + str(int(second)) + "_" + str(int(third))
                pred_list.append({
                    "conversation_ID": conversation_id[i].item(),
                    "emotion-cause_pairs": [emotion_string, span_string],
                })
        return pred_list

    def convertpred2dict(self, raw_dataset, pred_list):
        """Attach every predicted pair to its conversation and return the dict keyed by ID."""
        dataset = self.convertlist2dict(raw_dataset)
        for pred in pred_list:
            conversation = dataset[pred["conversation_ID"]]
            # setdefault creates the list on first use AND keeps that first
            # pair; creating the list without appending would lose it.
            conversation.setdefault("emotion-cause_pairs", []).append(pred["emotion-cause_pairs"])
        return dataset

    def convertlist2dict(self, raw_dataset):
        """Index the raw conversations by ``conversation_ID``, keeping ID and utterances only."""
        dataset = {}
        for conversation in raw_dataset:
            conversation_ID = conversation["conversation_ID"]
            dataset[conversation_ID] = {
                "conversation_ID": conversation_ID,
                "conversation": conversation['conversation']
            }
        return dataset

    def convertdict2submit(self, raw_dataset):
        """Return the conversations of ``raw_dataset`` as a list in insertion order."""
        return list(raw_dataset.values())

    def read_json(self, file_path):
        """Load and return the JSON document stored at ``file_path``."""
        with open(file_path, 'r') as json_file:
            data = json.load(json_file)
        return data

    def save_json(self, file_path, data):
        """Write ``data`` as JSON to ``file_path``."""
        with open(file_path, 'w') as json_file:
            json.dump(data, json_file)

# Model

In [None]:
class TGG(nn.Module):
    """Shared-encoder model with an emotion classifier and a candidate-cause scorer.

    The pretrained backbone (embeddings, encoder, pooler) encodes both the
    paragraph and every candidate utterance. ``emotion_fc`` maps the pooled
    paragraph vector to emotion logits; ``casual_feedforward`` transforms the
    pooled candidate vectors before they are scored against the paragraph.
    """

    def __init__(self, config_path):
        super(TGG, self).__init__()
        self.build(config_path)
        self.emotion_fc = nn.Linear(self.embedding_dim, self.label_num)
        self.casual_feedforward = nn.Sequential(
            nn.Linear(self.embedding_dim, self.hidden_size),
            nn.ReLU(),
            nn.Linear(self.hidden_size, self.embedding_dim),
            nn.ReLU(),
            nn.Dropout(0.2)
        )

    def build(self, config_path):
        """Read model sizes from the config and load the pretrained backbone parts."""
        config = self.read_json(config_path)
        self.hidden_size = config["hidden_size"]
        self.embedding_dim = config["embedding_dim"]
        self.token_size = config["token_size"]
        self.main_part_name = config["main_part_name"]
        self.label2id = config["label2id"]
        self.device = config["device"]
        # Output width of the emotion classifier: one logit per known label
        # unless the config states the count explicitly.
        self.label_num = config.get("label_num", len(self.label2id))

        main_part = AutoModel.from_pretrained(self.main_part_name)
        self.embeddings = main_part.embeddings
        self.encoder = main_part.encoder
        self.pooler = main_part.pooler

    def _pool(self, input_ids):
        """Run token ids through embeddings, encoder and pooler; return the pooled vectors."""
        # NOTE(review): no attention mask is passed, so padding tokens take
        # part in self-attention - confirm this is intended.
        hidden = self.embeddings(input_ids)
        hidden = self.encoder(hidden)['last_hidden_state']
        return self.pooler(hidden)

    def forward(self, x, candidate):
        """Compute emotion logits and one score per candidate.

        Args:
            x: paragraph token ids, indexed as ``(batch, seq_len)``.
            candidate: candidate token ids whose last axis is the sequence;
                all leading axes after the batch axis are treated as
                candidates.

        Returns:
            ``(_e_category, _similarity)``: the emotion logits produced by
            ``emotion_fc`` and, per sample, the dot product between the
            pooled paragraph vector and each transformed candidate vector.
        """
        batch_size = x.size(0)
        _x = self._pool(x)
        _e_category = self.emotion_fc(_x)

        # The backbone expects 2-D ids, so fold the candidate axis into the
        # batch axis for encoding and unfold it afterwards.
        flat_candidate = candidate.reshape(-1, candidate.size(-1))
        _candidate = self._pool(flat_candidate)
        _candidate = self.casual_feedforward(_candidate)
        _candidate = _candidate.reshape(batch_size, -1, _candidate.size(-1))

        # NOTE(review): the emotion logits and the candidate vectors have
        # different widths and cannot be multiplied; the pooled paragraph
        # vector is scored against each candidate instead - confirm intent.
        _similarity = torch.bmm(_candidate, _x.unsqueeze(-1)).squeeze(-1)

        return _e_category, _similarity

    def read_json(self, file_path):
        """Load and return the JSON document stored at ``file_path``."""
        with open(file_path, 'r') as json_file:
            data = json.load(json_file)
        return data

# Test

## Load dataset

In [None]:
# Shared configuration file and the raw JSON inputs.
config_path = "config.json"
dataset = read_json("data/ECF 2.0/train/train.json")
trial_data = read_json("data/ECF 2.0/trial/trial.json")
raw_trial_data = read_json("data/ECF 2.0/trial/Subtask_1_trial.json")

# First 80% of the samples are used for training, the rest for validation.
n_train = int(0.8 * len(dataset))
train_data = dataset[:n_train]
valid_data = dataset[n_train:]
train_dataset = ConversationDataset(config_path, train_data)
valid_dataset = ConversationDataset(config_path, valid_data)

# Model and its optimizer over all model parameters.
model = TGG(config_path)
optimizer = Adam(model.parameters(), lr=0.001)

## Train and predict

In [None]:
# Train the emotion stage, then predict on the trial split.
trainer = Trainer(config_path)
trainer.train(
    model,
    optimizer,
    train_dataset=train_dataset,
    valid_dataset=valid_dataset,
    epochs=100,
    option="emotion",
    batch_size=128,
)

# Any option other than "train" makes the dataset return inputs without labels.
trial_dataset = ConversationDataset(config_path, trial_data, option="trial")
predict = Prediction(config_path)
predict.predict(model, raw_trial_data, trial_dataset)