# Auxiliary Classes

## Library

In [1]:
import json
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
# Dataset
from PIL import Image
from torchvision import transforms
from torchvision.io import read_video, read_image
from torch.utils.data import Dataset, DataLoader
# Model
from transformers import AutoModel
from transformers import AutoTokenizer, BertModel, BertTokenizer, BertGenerationDecoder

# Training parameter
from torch.optim import Adam
# Training process
from tqdm import tqdm
# Metrics
from sklearn.metrics import classification_report, accuracy_score
from collections import OrderedDict
import copy
import sys
from focal_loss.focal_loss import FocalLoss

## Functions

In [2]:
def read_json(file_path):
    """Load and return the JSON document stored at ``file_path``.

    The file is decoded explicitly as UTF-8 (the encoding JSON text is
    required to use) rather than the platform's locale default, so
    non-ASCII text is read identically on every OS.
    """
    with open(file_path, 'r', encoding='utf-8') as json_file:
        data = json.load(json_file)
    return data
def write_json(path, data):
    """Serialise ``data`` as JSON into ``path``, overwriting any existing file.

    Output is pretty-printed with a four-space indent.
    """
    indent_width = 4
    with open(path, mode='w') as out_file:
        json.dump(data, fp=out_file, indent=indent_width)
def mkdir(path):
    """Create directory ``path`` (with any missing parents) if it does not exist.

    A message naming the directory is printed only when it is actually
    created; calling it on an existing directory is a no-op.
    """
    if os.path.isdir(path):
        return
    # exist_ok=True avoids a crash if something else creates the directory
    # between the check above and this call.
    os.makedirs(path, exist_ok=True)
    print("The new directory " + str(path) + " is created!")

## Early Stopper

In [3]:
class EarlyStopper:
    """Signal early stopping once the validation loss stops improving.

    ``early_stop`` returns True once ``patience`` calls since the last new best
    loss have reported a loss more than ``min_delta`` above that best. Losses
    within ``min_delta`` of the best neither reset nor advance the counter.
    """

    def __init__(self, patience=1, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.min_validation_loss = float('inf')

    def reset(self):
        """Forget all previous observations so a new training run starts fresh.

        The best loss is cleared too: a stage trained with a different loss
        function would otherwise be judged against an unrelated minimum.
        """
        self.counter = 0
        self.min_validation_loss = float('inf')

    def early_stop(self, validation_loss):
        """Record ``validation_loss``; return True when training should stop."""
        if validation_loss < self.min_validation_loss:
            self.min_validation_loss = validation_loss
            self.counter = 0
        elif validation_loss > (self.min_validation_loss + self.min_delta):
            self.counter += 1
            if self.counter >= self.patience:
                return True
        return False

# Train

## Dataset

In [4]:
class ConversationDataset(Dataset):
    """Tokenised conversation samples for emotion / cause-span training and inference.

    Each item contains the tokenised ``paragraph``, the tokenised utterance
    window ``token_text_pool`` and speaker window ``token_speaker_pool``. Both
    windows are front-padded with encodings of the empty string up to
    ``left_padding + right_padding + 1`` entries. When ``option == "train"``
    the item also holds a one-hot ``emotion_label`` and ``span_label``, built
    from the ``[start, end]`` pairs of the sample's ``emotion["casual"]`` pool
    and front-padded with ``[0, 0]``.
    """

    # Token lengths used for every utterance / speaker entry of the pools.
    TEXT_MAX_LENGTH = 128
    SPEAKER_MAX_LENGTH = 8

    def __init__(self, config_path, data, option="train"):
        super(ConversationDataset, self).__init__()
        self.build(config_path)
        self.data = data
        self.option = option

    def build(self, config_path):
        """Load label maps, tokenisation settings and the tokenizer from the JSON config."""
        config = read_json(config_path)
        self.label2id = config["label2id"]
        self.id2label = config["id2label"]
        self.tokenizer_name = config["tokenizer_name"]
        self.padding = config["padding"]
        self.max_length = config["max_length"]
        self.left_padding = config["left_padding"]
        self.right_padding = config["right_padding"]
        self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)

    def __len__(self):
        return len(self.data)

    def _encode(self, text, max_length):
        """Return the token ids of ``text``, padded per config and truncated to ``max_length``."""
        # truncation=True must be explicit: when padding is enabled the tokenizer
        # does not cut sequences down to max_length by itself, leaving ragged
        # rows that cannot be stacked into a tensor.
        return self.tokenizer(text, padding=self.padding, max_length=max_length, truncation=True)["input_ids"]

    def __getitem__(self, index):
        """Build the tensors for sample ``index`` (labels only when option == "train")."""
        sample = self.data[index]
        window = self.left_padding + self.right_padding + 1

        paragraph = self.tokenizer(sample["paragraph"], padding=self.padding, max_length=self.max_length,
                                   truncation=True, return_tensors="pt")
        paragraph = torch.squeeze(paragraph['input_ids'])

        text_pool = sample["text_pool"]
        speaker_pool = sample["speaker_pool"]
        # speaker_pool is indexed in step with text_pool (assumed to be at least as long).
        token_text_pool = [self._encode(text_pool[i], self.TEXT_MAX_LENGTH) for i in range(len(text_pool))]
        token_speaker_pool = [self._encode(speaker_pool[i], self.SPEAKER_MAX_LENGTH) for i in range(len(text_pool))]

        # Windows shorter than the configured size are padded at the FRONT.
        missing = window - len(token_text_pool)
        if missing > 0:
            empty_text = self._encode("", self.TEXT_MAX_LENGTH)
            empty_speaker = self._encode("", self.SPEAKER_MAX_LENGTH)
            token_text_pool = [list(empty_text) for _ in range(missing)] + token_text_pool
            token_speaker_pool = [list(empty_speaker) for _ in range(missing)] + token_speaker_pool

        item = {
            "conversation_ID": sample["conversation_ID"],
            "paragraph": paragraph,
            "utterance_ID": sample["utterance_ID"],
            "token_text_pool": torch.FloatTensor(token_text_pool),
            "token_speaker_pool": torch.FloatTensor(token_speaker_pool),
        }
        if self.option != "train":
            return item

        # One-hot emotion target over all configured labels.
        emotion_label = [0] * len(self.id2label)
        emotion_label[self.label2id[sample["emotion"]["property"]]] = 1

        # [start, end] of every cause entry, front-padded with [0, 0] up to the window size.
        casual_pool = sample["emotion"]["casual"]
        token_casual = [[casual_pool[key]["start"], casual_pool[key]["end"]] for key in casual_pool.keys()]
        token_casual = [[0, 0] for _ in range(window - len(casual_pool))] + token_casual

        item["emotion_label"] = torch.FloatTensor(emotion_label)
        item["span_label"] = torch.FloatTensor(token_casual)
        return item

## Trainer

In [10]:
class Trainer(nn.Module):
    """Stage-wise trainer for a model returning ``(emotion_pred, span_pred)``.

    ``train(option=...)`` selects a training stage. The stage decides which
    parameter groups are frozen (substring match against parameter names) and
    which loss is used. Training then runs for up to ``epochs`` epochs with
    early stopping. Every epoch that improves the validation loss is
    checkpointed to ``<output_dir>/checkpoint/<option>/``, and the metric
    history is written to ``<output_dir>/history/<option>_history.json``.
    """

    def __init__(self, config_path):
        super(Trainer, self).__init__()
        self.layers_to_freeze = []
        # Shared by every train() call on this instance, so later stages append to it.
        self.history = {
            "train loss": [],
            "valid loss": [],
            "emotion_correct": [],
            "span_correct": []
        }
        self.build(config_path)

    def build(self, config_path):
        """Read training hyper-parameters and output locations from the JSON config."""
        print(config_path)
        config = read_json(config_path)
        self.patience = config['patience']
        self.min_delta = config['min_delta']
        # Configured default; train() falls back to it when batch_size is None.
        self.batch_size_config = config['batch_size']
        self.batch_size = self.batch_size_config
        self.device = config['device']
        self.save_checkpoint_dir = config['output_dir'] + "/checkpoint"
        self.save_history_dir = config['output_dir'] + "/history"
        self.emotion_head_layers = config['emotion_head_layers']
        self.casual_head_layers = config['casual_head_layers']
        self.encoder_layers = config['encoder_layers']
        self.lstm_layers = config['lstm_layers']
        self.alpha_focal_loss = config['alpha_focal_loss']
        self.gamma_focal_loss = config['gamma_focal_loss']
        self.label2id = config['label2id']
        # Class names ordered by their integer id so they line up with the
        # labels passed to classification_report.
        self.target_names = sorted(self.label2id, key=self.label2id.get)
        self.early_stopper = EarlyStopper(self.patience, self.min_delta)

        mkdir(self.save_history_dir)

    def _focal_loss(self):
        """Focal loss for the emotion classifier, with ``weights`` taken from ``alpha_focal_loss``."""
        return FocalLoss(gamma=self.gamma_focal_loss, weights=torch.tensor(self.alpha_focal_loss).to(self.device))

    def emotion_head_training(self):
        """Freeze the cause-head, encoder and LSTM groups; emotion focal loss."""
        self.layers_to_freeze = self.casual_head_layers + self.encoder_layers + self.lstm_layers
        self.fn_loss = self._focal_loss()

    def emotion_encoder_training(self):
        """Freeze the cause-head and LSTM groups; emotion focal loss."""
        self.layers_to_freeze = self.casual_head_layers + self.lstm_layers
        self.fn_loss = self._focal_loss()

    def emotion_encoder_lstm_training(self):
        """Freeze only the cause-head group; emotion focal loss."""
        self.layers_to_freeze = self.casual_head_layers
        self.fn_loss = self._focal_loss()

    def emotion_lstm_training(self):
        """Freeze the cause-head and encoder groups; emotion focal loss."""
        self.layers_to_freeze = self.casual_head_layers + self.encoder_layers
        self.fn_loss = self._focal_loss()

    def casual_head_training(self):
        """Freeze the emotion-head, encoder and LSTM groups; MSE loss on spans."""
        self.layers_to_freeze = self.emotion_head_layers + self.encoder_layers + self.lstm_layers
        self.fn_loss = nn.MSELoss()

    def full_training(self):
        """Train every parameter; focal loss on emotions plus MSE on spans."""
        self.layers_to_freeze = []
        self.fn_loss = self._focal_loss()
        self.span_fn_loss = nn.MSELoss()

    def freeze(self, model):
        """Set ``requires_grad`` on every parameter of ``model`` and return it.

        A parameter is frozen when its name contains any entry of
        ``layers_to_freeze``; all other parameters are unfrozen.
        """
        for name, param in model.named_parameters():
            frozen = any(layer in name for layer in self.layers_to_freeze)
            param.requires_grad = not frozen
            if frozen:
                # Drop any gradient left from an earlier stage so the optimizer
                # cannot keep stepping a frozen parameter from a stale .grad.
                param.grad = None
        return model

    def train(self, model, optimizer, train_dataset, valid_dataset, epochs, option="emotion_head", batch_size=None):
        """Run one training stage.

        Args:
            model: network called as ``model(paragraph, token_text_pool, token_speaker_pool)``
                and returning ``(emotion_pred, span_pred)``.
            optimizer: optimizer over ``model``'s parameters.
            train_dataset, valid_dataset: datasets yielding the item dicts used below.
            epochs: maximum number of epochs (early stopping may end sooner).
            option: stage name, one of "emotion_head", "emotion_encoder",
                "emotion_lstm", "emotion_encoder_lstm", "casual_head", "full".
            batch_size: overrides the configured batch size when not None.

        Raises:
            ValueError: if ``option`` is not a known stage.

        Note: this shadows ``nn.Module.train(mode)``, so ``trainer.eval()`` must not be used.
        """
        stage_setups = {
            "emotion_head": self.emotion_head_training,
            "emotion_encoder": self.emotion_encoder_training,
            "emotion_lstm": self.emotion_lstm_training,
            "emotion_encoder_lstm": self.emotion_encoder_lstm_training,
            "casual_head": self.casual_head_training,
            "full": self.full_training,
        }
        if option not in stage_setups:
            raise ValueError("Unknown training option: " + str(option))
        self.optimizer = optimizer
        print("TRAINING " + option + " MODEL")
        stage_setups[option]()

        self.batch_size = self.batch_size_config if batch_size is None else batch_size
        print("device: "+str(self.device))
        print("batch_size: "+str(self.batch_size))
        print("gamma_focal_loss: "+str(self.gamma_focal_loss))
        print("alpha_focal_loss: "+str(self.alpha_focal_loss))

        self.early_stopper.reset()
        self.model = self.freeze(model)
        self.model = self.model.to(self.device)

        self.train_dataloader = DataLoader(dataset=train_dataset, batch_size=self.batch_size, shuffle=True, drop_last=False)
        self.valid_dataloader = DataLoader(dataset=valid_dataset, batch_size=self.batch_size, shuffle=False, drop_last=False)

        # Must exist before the loop: save_checkpoint() writes into it on the
        # first epoch that improves the validation loss.
        mkdir(os.path.join(self.save_checkpoint_dir, option))

        best_val_loss = np.inf
        self.epochs = epochs
        for epoch in range(epochs):
            train_loss = self._train_epoch(epoch, option)
            val_loss, emotion_correct, span_correct, emotion_label_list, emotion_pred_list = self._val_epoch(epoch, option)

            self.history["train loss"].append(train_loss)
            self.history["valid loss"].append(val_loss)
            self.history["emotion_correct"].append(emotion_correct)
            self.history["span_correct"].append(span_correct)

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                self.save_checkpoint(epoch, option)
                print(f"Epoch {epoch+1}: Train loss {train_loss:.4f},  Valid loss {val_loss:.4f},  emotion_correct {emotion_correct*100:.4f}%,  span_correct {span_correct*100:.4f}%")
                # Pass labels explicitly so the report stays aligned with
                # target_names even when some classes are absent this epoch.
                print(classification_report(emotion_label_list, emotion_pred_list,
                                            labels=[self.label2id[name] for name in self.target_names],
                                            target_names=self.target_names, zero_division=0))
            if self.early_stopper.early_stop(val_loss):
                print("Early stop at epoch: "+str(epoch) + " with valid loss: "+str(val_loss))
                break

        write_json(self.save_history_dir + "/" + option + "_history.json", self.history)

    def _compute_loss(self, option, emotion_pred, emotion_label, span_pred, span_label):
        """Loss for the active stage: span-only, emotion plus span ("full"), or emotion-only."""
        if option == "casual_head":
            return self.fn_loss(span_pred, span_label)
        if option == "full":
            return self.fn_loss(emotion_pred, emotion_label) + self.span_fn_loss(span_pred, span_label)
        return self.fn_loss(emotion_pred, emotion_label)

    def _train_epoch(self, epoch, option):
        """Run one optimisation pass over the training loader; return the mean batch loss."""
        self.model.train()
        train_loss = 0.0
        logger_message = f'Training epoch {epoch}/{self.epochs}'

        progress_bar = tqdm(self.train_dataloader,
                            desc=logger_message, initial=0, dynamic_ncols=True)
        for data in progress_bar:
            paragraph = data["paragraph"].to(self.device)
            token_text_pool = data["token_text_pool"].to(self.device)
            token_speaker_pool = data["token_speaker_pool"].to(self.device)
            # One-hot targets -> class indices.
            emotion_label = torch.argmax(data["emotion_label"].to(self.device), dim=-1)
            span_label = data["span_label"].to(self.device)

            self.optimizer.zero_grad()
            emotion_pred, span_pred = self.model(paragraph, token_text_pool, token_speaker_pool)
            loss = self._compute_loss(option, emotion_pred, emotion_label, span_pred, span_label)

            train_loss += loss.item()
            loss.backward()
            self.optimizer.step()
        return train_loss / max(len(self.train_dataloader), 1)

    def _val_epoch(self, epoch, option):
        """Evaluate on the validation loader.

        Returns:
            (mean batch loss, emotion accuracy over samples, span accuracy over
            compared elements, true emotion ids, predicted emotion ids)
        """
        self.model.eval()
        valid_loss = 0.0
        emotion_correct = 0
        span_correct = 0
        n_samples = 0
        n_span_values = 0
        emotion_pred_list = []
        emotion_label_list = []
        logger_message = f'Validation epoch {epoch}/{self.epochs}'
        progress_bar = tqdm(self.valid_dataloader,
                            desc=logger_message, initial=0, dynamic_ncols=True)
        with torch.no_grad():
            for data in progress_bar:
                paragraph = data["paragraph"].to(self.device)
                token_text_pool = data["token_text_pool"].to(self.device)
                token_speaker_pool = data["token_speaker_pool"].to(self.device)
                emotion_label = torch.argmax(data["emotion_label"].to(self.device), dim=-1)
                span_label = data["span_label"].to(self.device)
                emotion_pred, span_pred = self.model(paragraph, token_text_pool, token_speaker_pool)

                loss = self._compute_loss(option, emotion_pred, emotion_label, span_pred, span_label)

                emotion_pred_ = torch.argmax(emotion_pred, dim=-1)
                emotion_pred_list.extend(emotion_pred_.cpu().tolist())
                emotion_label_list.extend(emotion_label.cpu().tolist())
                valid_loss += loss.item()
                emotion_correct += (emotion_pred_ == emotion_label).sum().item()
                span_matches = (span_pred == span_label)
                span_correct += span_matches.sum().item()
                n_samples += emotion_label.size(0)
                n_span_values += span_matches.numel()

        # Normalise by what was actually evaluated: the last batch may be
        # partial, and span accuracy is measured per compared element.
        return (valid_loss / max(len(self.valid_dataloader), 1),
                emotion_correct / max(n_samples, 1),
                span_correct / max(n_span_values, 1),
                emotion_label_list,
                emotion_pred_list)

    def save_checkpoint(self, epoch, option):
        """Save model/optimizer state and ``epoch`` under the checkpoint folder for ``option``."""
        checkpoint_path = os.path.join(self.save_checkpoint_dir, option, 'checkpoint_{}.pth'.format(epoch))
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'epoch': epoch
        }, checkpoint_path)

## Predictor

In [6]:
class Prediction(nn.Module):
    """Run a trained model over a dataset and write the submission JSON file.

    Predicted (emotion, cause span) pairs are grouped per conversation of the
    raw dataset and saved to ``config["submit path"]``.
    """
    def __init__(self, config_path):
        super(Prediction, self).__init__()
        self.build(config_path)

    def build(self, config_path):
        """Read label maps, output path, device and batch size from the JSON config."""
        config = self.read_json(config_path)
        self.id2label = config["id2label"]
        self.label2id = config["label2id"]
        self.submit_path = config["submit path"]
        self.device = config["device"]
        self.batch_size = config["batch_size"]

    def predict(self, model, raw_dataset, dataset):
        """Predict every item of ``dataset`` and save the submission built from ``raw_dataset``.

        Raises:
            KeyError: if a batch has no ``casual_span_pool`` entry; the candidate
                spans are not produced by the model and must come with the data.
        """
        self.model = model.to(self.device)
        self.model.eval()
        pred_list = []

        self.test_dataloader = DataLoader(dataset=dataset, batch_size=self.batch_size, shuffle=False)
        logger_message = 'Predict '
        progress_bar = tqdm(self.test_dataloader,
                            desc=logger_message, initial=0, dynamic_ncols=True)
        with torch.no_grad():
            for data in progress_bar:
                # NOTE(review): ConversationDataset in this notebook does not emit
                # this key yet; the candidate-span source still has to be added.
                if "casual_span_pool" not in data:
                    raise KeyError("Prediction needs a 'casual_span_pool' entry in every dataset item to decode cause spans")
                conversation_id = data["conversation_ID"]
                utter_id = data["utterance_ID"]
                paragraph = data["paragraph"].to(self.device)
                token_text_pool = data["token_text_pool"].to(self.device)
                token_speaker_pool = data["token_speaker_pool"].to(self.device)
                emotion_pred, span_pred = self.model(paragraph, token_text_pool, token_speaker_pool)

                pred = self.convertpred2list(conversation_id, utter_id, emotion_pred, span_pred, data["casual_span_pool"])
                pred_list.extend(pred)
        submit = self.convertdict2submit(self.convertpred2dict(raw_dataset, pred_list))
        self.save_json(self.submit_path, submit)

    def convertpred2list(self, conversation_id, utter_id, emotion_pred, span_pred, casual_span_pool):
        """Turn one batch of predictions into submission records.

        For every sample whose argmax emotion is not "neutral", each candidate
        ``j`` whose rounded span score equals 1 yields one record
        ``{"conversation_ID": ..., "emotion-cause_pairs": ["<utt>_<emotion>", "<a>_<b>_<c>"]}``,
        where ``a, b, c`` are ``casual_span_pool[i][j][0..2]``. Candidates with
        ``a == 0`` (padding) or ``b == c`` (empty span) are skipped.
        """
        pred_list = []
        for i in range(len(emotion_pred)):
            emotion = self.id2label[str(torch.argmax(emotion_pred[i]).item())]
            if emotion == "neutral":
                continue
            emotion_string = str(utter_id[i].item()) + "_" + emotion
            # NOTE(review): assumes span_pred[i] holds one score per candidate;
            # a multi-value span[j] would make the comparison below ambiguous.
            span = torch.round(span_pred[i])
            for j in range(len(span)):
                if span[j] != 1:
                    continue
                candidate = casual_span_pool[i][j]
                # Skip (rather than stop at) padding / empty candidates: padding
                # may precede real candidates, since this file's pools are front-padded.
                if candidate[0].item() == 0:
                    continue
                if candidate[1].item() == candidate[2].item():
                    continue
                span_string = "_".join(str(int(candidate[k].item())) for k in range(3))
                pred_list.append({
                    "conversation_ID": conversation_id[i].item(),
                    "emotion-cause_pairs": [emotion_string, span_string],
                })
        return pred_list

    def convertpred2dict(self, raw_dataset, pred_list):
        """Attach each predicted pair to its conversation; returns ``{conversation_ID: conversation dict}``."""
        dataset = self.convertlist2dict(raw_dataset)
        for pred in pred_list:
            conversation = dataset[pred["conversation_ID"]]
            # Create the list on first use AND keep that first pair.
            conversation.setdefault("emotion-cause_pairs", []).append(pred["emotion-cause_pairs"])
        return dataset

    def convertlist2dict(self, raw_dataset):
        """Index the raw conversations by ``conversation_ID``, keeping only the ID and the conversation."""
        dataset = {}
        for conversation in raw_dataset:
            conversation_ID = conversation["conversation_ID"]
            dataset[conversation_ID] = {
                "conversation_ID": conversation_ID,
                "conversation": conversation['conversation']
            }
        return dataset

    def convertdict2submit(self, raw_dataset):
        """Flatten the per-conversation dict back into a list, in insertion order."""
        return list(raw_dataset.values())

    def read_json(self, file_path):
        with open(file_path, 'r') as json_file:
            data = json.load(json_file)
        return data

    def save_json(self, file_path, data):
        with open(file_path, 'w') as json_file:
            json.dump(data, json_file)

# Model

In [7]:
class TGG(nn.Module):
    """Emotion classifier over a conversation window, with a placeholder cause-span output.

    A bidirectional LSTM (batch_first, input size 128) reads ``token_text_pool``,
    and its final hidden states are projected to ``embedding_dim``. A pretrained
    transformer (embeddings + encoder + pooler of ``encoder_name``) encodes
    ``paragraph``. Both vectors are concatenated, fused, and classified into
    ``label_num`` emotions (softmax probabilities).
    """
    def __init__(self, config_path):
        super(TGG, self).__init__()
        self.build(config_path)
        self.lstm = nn.LSTM(128, self.hidden_size, self.num_layers, batch_first=True, dropout=0.2, bidirectional=True)
        self.lstma_projection = nn.Linear(2*self.num_layers*self.hidden_size, self.embedding_dim)
        self.cat_layer = nn.Linear(2*self.embedding_dim, self.embedding_dim)
        self.emotion_fc = nn.Linear(self.embedding_dim, self.label_num)
        # Not used by forward() yet (cause-span head is still a placeholder).
        self.casual_feedforward = nn.Sequential(
            nn.Linear(self.embedding_dim, self.hidden_size),
            nn.ReLU(),
            nn.Linear(self.hidden_size, self.embedding_dim),
            nn.ReLU(),
            nn.Dropout(0.2)
        )
        self.softmax = nn.Softmax(dim=-1)

    def build(self, config_path):
        """Read sizes and labels from the JSON config and load the pretrained encoder parts."""
        config = self.read_json(config_path)
        self.hidden_size = config["hidden_size"]
        self.embedding_dim = config["embedding_dim"]
        self.encoder_name = config["encoder_name"]
        self.label2id = config["label2id"]
        self.label_num = len(self.label2id)
        self.num_layers = config["num_layers"]
        self.device = config["device"]

        main_part = AutoModel.from_pretrained(self.encoder_name)
        self.embeddings = main_part.embeddings
        self.encoder = main_part.encoder
        self.pooler = main_part.pooler

    def forward(self, paragraph, token_text_pool, token_speaker_pool):
        """Return ``(emotion probabilities (batch, label_num), span scores)``.

        ``token_speaker_pool`` is accepted but not used yet.
        """
        batch_size = paragraph.size(0)
        # Tensor.get_device() returns -1 for CPU tensors, which is not a valid
        # device argument; the device object works on CPU and GPU alike.
        device = paragraph.device

        # Zero initial state (nn.LSTM's own default): random states made
        # predictions differ between identical eval calls.
        state_shape = (2 * self.num_layers, batch_size, self.hidden_size)
        h0 = torch.zeros(state_shape, device=device, dtype=token_text_pool.dtype)
        c0 = torch.zeros(state_shape, device=device, dtype=token_text_pool.dtype)
        _, (hn, _) = self.lstm(token_text_pool, (h0, c0))
        # (2*num_layers, batch, hidden) -> (batch, 2*num_layers*hidden)
        hn = hn.permute(1, 0, 2).reshape(batch_size, -1)
        hn = self.lstma_projection(hn)

        _paragraph = self.embeddings(paragraph)
        # NOTE(review): no attention mask is passed, so padding tokens are attended to.
        _paragraph = self.encoder(_paragraph)['last_hidden_state']
        _paragraph = self.pooler(_paragraph)

        _x = self.cat_layer(torch.cat([hn, _paragraph], dim=-1))
        _e_category = self.emotion_fc(_x)

        # Placeholder: the cause-span head is not implemented; these are random
        # scores of a hard-coded (batch, 9, 2) shape that carry no gradient.
        _similarity = torch.rand((batch_size, 9, 2), device=device)
        return self.softmax(_e_category), _similarity

    def read_json(self, file_path):
        with open(file_path, 'r') as json_file:
            data = json.load(json_file)
        return data

# Test

## Load dataset

In [8]:
# Inputs: config file and the ECF 2.0 train / trial splits.
config_path = "config.json"
dataset = read_json("../data/ECF 2.0/train/train.json")
trial_data = read_json("../data/ECF 2.0/trial/trial.json")
raw_trial_data = read_json("../data/ECF 2.0/trial/Subtask_1_trial.json")

# The first TRAIN_FRACTION of the (unshuffled) training file is used for
# training, the remainder for validation.
TRAIN_FRACTION = 0.8
LEARNING_RATE = 0.001

n_train = int(TRAIN_FRACTION * len(dataset))
train_data = dataset[:n_train]
valid_data = dataset[n_train:]
train_dataset = ConversationDataset(config_path, train_data)
valid_dataset = ConversationDataset(config_path, valid_data)

model = TGG(config_path)
optimizer = Adam(model.parameters(), lr=LEARNING_RATE)

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

KeyboardInterrupt: 

In [None]:
# for name, param in model.named_parameters():
#     print("\"" + name + "\",")

## Train and predict

In [11]:
# Trainer reads its hyper-parameters and output folders from the shared config.
trainer = Trainer(config_path=config_path)

config.json
The new directory checkpoint is created!


In [None]:
# Stage "emotion_lstm" (cause head and encoder frozen), then predict on the trial split.
trainer.train(
    model,
    optimizer,
    train_dataset=train_dataset,
    valid_dataset=valid_dataset,
    epochs=100,
    option="emotion_lstm",
    batch_size=512,
)
trial_dataset = ConversationDataset(config_path, trial_data, option="trial")
predict = Prediction(config_path)
predict.predict(model, raw_trial_data, trial_dataset)

In [None]:
trainer = Trainer(config_path)
trainer.train(model, optimizer, train_dataset=train_dataset, valid_dataset=valid_dataset, epochs=100, option="emotion_encoder", batch_size=32)
trial_dataset = ConversationDataset(config_path, trial_data, option="trial")
trainer = Trainer(config_path)
trainer.train(model, optimizer, train_dataset=train_dataset, valid_dataset=valid_dataset, epochs=100, option="emotion_encoder_lstm", batch_size=32)
trial_dataset = ConversationDataset(config_path, trial_data, option="trial")