In [1]:
# !pip install transformers
# !pip install scikit-learn
# !pip install focal-loss-torch

# Auxility Class

## Library

In [2]:
import json
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
# Dataset
from PIL import Image
from torchvision import transforms
from torchvision.io import read_video, read_image
from torch.utils.data import Dataset, DataLoader
# Model
from transformers import AutoModel
from transformers import AutoTokenizer, BertModel, BertTokenizer, BertGenerationDecoder, GPT2Config, GPT2Model, GPT2LMHeadModel, EncoderDecoderModel

# Training parameter
from torch.optim import Adam
# Training process
from tqdm import tqdm
# Metrics
from sklearn.metrics import classification_report, accuracy_score
from collections import OrderedDict
import copy
import sys
from focal_loss.focal_loss import FocalLoss

## Function

In [3]:
def read_json(file_path):
    with open(file_path, 'r') as json_file:
        data = json.load(json_file)
    return data
def write_json(path, data):
    with open(path, 'w') as json_file:
        json.dump(data, json_file, indent=4)
def mkdir(path):
    isExist = os.path.exists(path)
    if not isExist:
        # Create a new directory because it does not exist
        os.makedirs(path)
        print("The new directory checkpoint is created!")

## Early Stopper

In [4]:
class EarlyStopper:
    def __init__(self, patience=1, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.min_validation_loss = float('inf')

    def reset(self):
        self.counter = 0

    def early_stop(self, validation_loss):
        if validation_loss < self.min_validation_loss:
            self.min_validation_loss = validation_loss
            self.counter = 0
        elif validation_loss > (self.min_validation_loss + self.min_delta):
            self.counter += 1
            if self.counter >= self.patience:
                return True
        return False

# Train

## Dataset

In [5]:
class ConversationDataset(Dataset):
    def __init__(self, config_path, data, option="train"):
        super(ConversationDataset, self).__init__()
        self.build(config_path)
        self.data = data
        self.option = option
    def build(self, config_path):
        config = read_json(config_path)
        self.label2id = config["label2id"]
        self.id2label = config["id2label"]
        self.tokenizer_name = config["tokenizer_name"]
        self.padding = config["padding"]
        self.paragraph_max_length = config["paragraph_max_length"]
        self.text_max_length = config["text_max_length"]
        self.left_padding = config["left_padding"]
        self.right_padding = config["right_padding"]
        self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
        self.text_pool_tokenizer = AutoTokenizer.from_pretrained("gpt2")
        # self.text_pool_tokenizer.add_special_tokens({'pad_token': 'eos_token_id'})
        self.text_pool_tokenizer.pad_token = self.text_pool_tokenizer.eos_token
    def span(self, text, subtext, ss=0):
        if len(subtext) == 0:
            return []
        for i in range(ss, len(text) - len(subtext)):
            for j in range(len(subtext)):
                if text[i+j] != subtext[j]:
                    match = 0
                    break
                else:
                    match += 1
                    if match == len(subtext):
                        return [i, i+len(subtext)]
    def __len__(self):
        return len(self.data)
    def __getitem__(self, index):
        conversation_ID = self.data[index]["conversation_ID"]
        paragraph = self.tokenizer(self.data[index]["paragraph"], padding=self.padding, max_length=self.paragraph_max_length, return_tensors="pt")
        paragraph = torch.squeeze(paragraph['input_ids'])
        utterance_ID = self.data[index]["utterance_ID"]
        text_pool = self.data[index]["text_pool"]
        speaker_pool = self.data[index]["speaker_pool"]
        token_text_pool = []
        token_speaker_pool = []
        list_utter = copy.copy(self.data[index]["text_utter"])

        for i in range(len(text_pool)):
            text = self.tokenizer.encode(text_pool[i], padding=self.padding, max_length=self.text_max_length)
            speaker = self.tokenizer(speaker_pool[i], padding=self.padding, max_length=8)["input_ids"]
            token_text_pool.append(text)
            token_speaker_pool.append(speaker)

        if len(text_pool) < self.left_padding+self.right_padding+1:
            for i in range(len(text_pool), self.left_padding+self.right_padding+1):
                text = self.tokenizer.encode("", padding=self.padding, max_length=self.text_max_length)
                speaker = self.tokenizer("", padding=self.padding, max_length=8)["input_ids"]
                token_text_pool.insert(0, text)
                token_speaker_pool.insert(0, speaker)
                list_utter.insert(0, 0)


        if self.option == "train":
            label = self.label2id[self.data[index]["emotion"]["property"]]
            emotion_label = [0] * len(self.id2label)
            emotion_label[label] = 1
            emotion_label = torch.FloatTensor(emotion_label)
            casual_pool = self.data[index]["emotion"]["casual"]
            token_casual = []


            list = []
            pool_label = ""
         
            for i in range(len(casual_pool)):
                casual = [casual_pool[i]["start"], casual_pool[i]["end"]]
                token_casual.append(casual)
                token = self.tokenizer.encode(casual_pool[i]["casual_text"], add_special_tokens=False)
                list.extend(self.span(paragraph.tolist(), token))
                if len(casual_pool[i]["casual_text"]) != 0:
                    pool_label += " " + casual_pool[i]["casual_text"]

            for i in range(len(casual_pool), self.left_padding+self.right_padding+1):
                casual = [0, 0]
                token_casual.insert(0, casual)
            # print(casual_pool)
            # print(list)
            paragraph_label = [0] * self.paragraph_max_length
            for i in list:
                paragraph_label[i] = 1

            # print("list_utter"+str(list_utter))
            # print("token_text_pool"+str(token_text_pool))

            return {
                "conversation_ID": conversation_ID,
                "paragraph": paragraph,
                "paragraph_label": torch.FloatTensor(paragraph_label),
                "pool_label": self.tokenizer(pool_label, padding=self.padding, max_length=self.paragraph_max_length, return_tensors="pt")["input_ids"],
                "utterance_ID": utterance_ID,
                "token_text_pool": torch.FloatTensor(token_text_pool),
                "text_utter_pool": torch.FloatTensor(list_utter),
                "token_speaker_pool": torch.FloatTensor(token_speaker_pool),
                "emotion_label": emotion_label,
                "span_label": torch.FloatTensor(token_casual)
            }
        else:
            return {
                "conversation_ID": conversation_ID,
                "paragraph": paragraph,
                "utterance_ID": utterance_ID,
                "token_text_pool": torch.FloatTensor(token_text_pool),
                "text_utter_pool": torch.FloatTensor(list_utter),
                "token_speaker_pool": torch.FloatTensor(token_speaker_pool),
            }

## Trainer

In [6]:
class Trainer(nn.Module):
    def __init__(self, config_path):
        super(Trainer, self).__init__()
        self.layers_to_freeze = []
        self.history = {
            "train loss": [],
            "valid loss": [],
            "emotion_correct": [],
            "span_correct": []
        }
        self.build(config_path)
    def build(self, config_path):
        print(config_path)
        config = read_json(config_path)
        self.patience = config['patience']
        self.min_delta = config['min_delta']
        self.batch_size = config['batch_size']
        self.device = config['device']
        self.save_checkpoint_dir = config['output_dir'] + "/checkpoint"
        self.save_history_dir = config['output_dir'] + "/history"
        self.emotion_head_layers = config['emotion_head_layers']
        self.casual_head_layers = config['casual_head_layers']
        self.decoder_layers = config['decoder_layers']
        self.encoder_layers = config['encoder_layers']
        self.lstm_layers = config['lstm_layers']
        self.alpha_focal_loss = config['alpha_focal_loss']
        self.gamma_focal_loss = config['gamma_focal_loss']
        self.label2id = config['label2id']
        self.target_names = self.label2id.keys()
        self.early_stopper = EarlyStopper(self.patience, self.min_delta)

        mkdir(self.save_history_dir)

    def emotion_head_training(self):
        self.layers_to_freeze = self.casual_head_layers + self.encoder_layers + self.lstm_layers
        # self.fn_loss = FocalLoss(gamma=self.gamma_focal_loss, weights=torch.tensor(self.alpha_focal_loss).to(self.device))
    def emotion_encoder_training(self):
        self.layers_to_freeze = self.casual_head_layers + self.lstm_layers
        # self.fn_loss = FocalLoss(gamma=self.gamma_focal_loss, weights=torch.tensor(self.alpha_focal_loss).to(self.device))
    def emotion_encoder_lstm_training(self):
        self.layers_to_freeze = self.casual_head_layers + self.decoder_layers 
    def emotion_casual_head_training(self):
        self.layers_to_freeze = self.encoder_layers + self.lstm_layers
    def emotion_casual_lstm_head_training(self):
        self.layers_to_freeze = self.encoder_layers
        # self.fn_loss = FocalLoss(gamma=self.gamma_focal_loss, weights=torch.tensor(self.alpha_focal_loss).to(self.device))
    def emotion_lstm_training(self):
        self.layers_to_freeze = self.casual_head_layers + self.encoder_layers
        # self.fn_loss = FocalLoss(gamma=self.gamma_focal_loss, weights=torch.tensor(self.alpha_focal_loss).to(self.device))
        
    def casual_head_training(self):
        self.layers_to_freeze = self.emotion_head_layers + self.encoder_layers + self.lstm_layers

    def full_training(self):
        self.layers_to_freeze = []
    def freeze(self, model):
        for name, param in model.named_parameters():
            if any(layer in name for layer in self.layers_to_freeze):
                param.requires_grad = False
            else:
                param.requires_grad = True
        return model
    def train(self, model, optimizer, train_dataset, valid_dataset, epochs, option="emotion_head", batch_size=None):
        self.optimizer = optimizer
        if option == "emotion_head":
            print("TRAINING " + option + " MODEL")
            self.emotion_head_training()
        elif option == "emotion_encoder":
            print("TRAINING " + option + " MODEL")
            self.emotion_encoder_training()
        elif option == "emotion_lstm":
            print("TRAINING " + option + " MODEL")
            self.emotion_lstm_training()
        elif option == "emotion_encoder_lstm":
            print("TRAINING " + option + " MODEL")
            self.emotion_encoder_lstm_training()
        elif option == "emotion_casual_head":
            print("TRAINING " + option + " MODEL")
            self.emotion_casual_head_training()
        elif option == "emotion_casual_lstm_head":
            print("TRAINING " + option + " MODEL")
            self.emotion_casual_lstm_head_training()
        elif option == "casual_head":
            print("TRAINING " + option + " MODEL")
            self.casual_head_training()
        elif option == "full":
            print("TRAINING " + option + " MODEL")
            self.full_training()
        else:
            print("Option fail")
            return

        if batch_size == None:
            self.batch_size = self.batch_size_config
        else:
            self.batch_size = batch_size
        mkdir(self.save_checkpoint_dir + "/" + option)
        self.fn_loss = nn.CrossEntropyLoss()
        self.span_fn_loss = nn.CrossEntropyLoss()



        print("device: "+str(self.device))
        print("batch_size: "+str(self.batch_size))
        print("gamma_focal_loss: "+str(self.gamma_focal_loss))
        print("alpha_focal_loss: "+str(self.alpha_focal_loss))

        self.early_stopper.reset()
        self.model = self.freeze(model)
        # if self.device == 'cuda' and torch.cuda.device_count() > 1:
        #     print("Training Parallel!")
        #     self.model = nn.DataParallel(self.model)
        self.model = self.model.to(self.device)

        self.train_dataloader = DataLoader(dataset=train_dataset, batch_size=self.batch_size, shuffle=True, drop_last=False)
        self.valid_dataloader = DataLoader(dataset=valid_dataset, batch_size=self.batch_size, shuffle=False, drop_last=False)
        
        best_val_loss = np.inf
        self.epochs = epochs
        for epoch in range(epochs):
            train_loss = self._train_epoch(epoch, option)
            val_loss, emotion_correct, span_correct, emotion_label_list, emotion_pred_list = self._val_epoch(epoch, option)
            
            self.history["train loss"].append(train_loss)
            self.history["valid loss"].append(val_loss)
            self.history["emotion_correct"].append(emotion_correct)
            self.history["span_correct"].append(span_correct)
            
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                self.save_checkpoint(epoch, option)
                print(f"Epoch {epoch+1}: Train loss {train_loss:.4f},  Valid loss {val_loss:.4f},  emotion_correct {emotion_correct*100:.4f}%,  span_correct {span_correct*100:.4f}%")
                print(classification_report(emotion_label_list, emotion_pred_list, target_names=self.target_names))
            if self.early_stopper.early_stop(val_loss):
                print("Early stop at epoch: "+str(epoch) + " with valid loss: "+str(val_loss))
                break

        
        write_json(self.save_history_dir + "/" + option + "_history.json", self.history) 
        
    def _train_epoch(self, epoch, option):
        self.model.train()
        train_loss = 0.0
        logger_message = f'Training epoch {epoch}/{self.epochs}'

        progress_bar = tqdm(self.train_dataloader,
                            desc=logger_message, initial=0, dynamic_ncols=True)
        for batch, data in enumerate(progress_bar):
            conversation_id = data["conversation_ID"]
            paragraph = data["paragraph"].to(self.device)
            span_label = data["paragraph_label"].to(self.device)
            pool_label = data["pool_label"].to(self.device)
            utter_id = data["utterance_ID"]
            token_text_pool = data["token_text_pool"].to(self.device)
            token_speaker_pool = data["token_speaker_pool"].to(self.device)
            emotion_label = data["emotion_label"].to(self.device)
            # span_label = data["span_label"].to(self.device)
            self.optimizer.zero_grad()
            # print("train:    333 2")
            
            emotion_pred, span_pred = self.model(paragraph, token_text_pool, token_speaker_pool)
            
            emotion_label = torch.argmax(emotion_label, dim=-1)
            if option == "emotion_head":
                loss = self.fn_loss(emotion_pred, emotion_label)
            elif option == "emotion_lstm":
                loss = self.fn_loss(emotion_pred, emotion_label)
            elif option == "emotion_encoder":
                loss = self.fn_loss(emotion_pred, emotion_label)
            elif option == "emotion_encoder_lstm":
                loss = self.fn_loss(emotion_pred, emotion_label)
            elif option == "emotion_casual_head":
                emotion_loss = self.fn_loss(emotion_pred, emotion_label)
                span_loss = self.span_fn_loss(span_pred, span_label)
                loss = emotion_loss + span_loss
            elif option == "emotion_casual_lstm_head":
                emotion_loss = self.fn_loss(emotion_pred, emotion_label)
                span_loss = self.span_fn_loss(span_pred, span_label)
                loss = emotion_loss + span_loss
            elif option == "full":
                emotion_loss = self.fn_loss(emotion_pred, emotion_label)
                span_loss = self.span_fn_loss(span_pred, span_label)
                loss = emotion_loss + span_loss
            elif option == "casual_head":
                loss = self.fn_loss(span_pred, span_label)
            else:
                emotion_loss = self.fn_loss(emotion_pred, emotion_label)
                span_loss = self.span_fn_loss(span_pred, span_label)
                loss = emotion_loss + span_loss

            train_loss += loss.item()
            loss.backward()
            self.optimizer.step()
        return train_loss / len(self.train_dataloader)

    def _val_epoch(self, epoch, option):
        self.model.eval()
        valid_loss = 0.0
        emotion_correct=0
        span_correct=0
        emotion_pred_list = []
        emotion_label_list = []
        logger_message = f'Validation epoch {epoch}/{self.epochs}'
        progress_bar = tqdm(self.valid_dataloader,
                            desc=logger_message, initial=0, dynamic_ncols=True)
        with torch.no_grad():
            for _, data in enumerate(progress_bar):
                conversation_id = data["conversation_ID"]
                paragraph = data["paragraph"].to(self.device)
                span_label = data["paragraph_label"].to(self.device)
                pool_label = data["pool_label"].to(self.device)
                utter_id = data["utterance_ID"]
                token_text_pool = data["token_text_pool"].to(self.device)
                token_speaker_pool = data["token_speaker_pool"].to(self.device)
                emotion_label = data["emotion_label"].to(self.device)
                # span_label = data["span_label"].to(self.device)
                emotion_pred, span_pred = self.model(paragraph, token_text_pool, token_speaker_pool)
                
                
                
                
                emotion_label = torch.argmax(emotion_label, dim=-1)
                if option == "emotion_head":
                    loss = self.fn_loss(emotion_pred, emotion_label)
                elif option == "emotion_lstm":
                    loss = self.fn_loss(emotion_pred, emotion_label)
                elif option == "emotion_encoder":
                    loss = self.fn_loss(emotion_pred, emotion_label)
                elif option == "emotion_encoder_lstm":
                    loss = self.fn_loss(emotion_pred, emotion_label)
                elif option == "emotion_casual_head":
                    emotion_loss = self.fn_loss(emotion_pred, emotion_label)
                    span_loss = self.span_fn_loss(span_pred, span_label)
                    loss = emotion_loss + span_loss
                elif option == "emotion_casual_lstm_head":
                    emotion_loss = self.fn_loss(emotion_pred, emotion_label)
                    span_loss = self.span_fn_loss(span_pred, span_label)
                    loss = emotion_loss + span_loss
                elif option == "full":
                    emotion_loss = self.fn_loss(emotion_pred, emotion_label)
                    span_loss = self.span_fn_loss(span_pred, span_label)
                    loss = emotion_loss + span_loss
                elif option == "casual_head":
                    loss = self.fn_loss(span_pred, span_label)
                else:
                    emotion_loss = self.fn_loss(emotion_pred, emotion_label)
                    span_loss = self.span_fn_loss(span_pred, span_label)
                    loss = emotion_loss + span_loss

                # emotion_label_ = torch.argmax(emotion_label, dim=-1)
                emotion_pred_ = torch.argmax(emotion_pred, dim=-1)
                emotion_pred_list.extend(emotion_pred_.cpu().tolist())
                emotion_label_list.extend(emotion_label.cpu().tolist())
                # print("emotion_label_: "+ str(emotion_label_))
                # print("emotion_pred_: "+ str(emotion_pred_))
                valid_loss += loss.item()
                emotion_correct += (emotion_pred_ == emotion_label).sum().item()
                span_correct += (torch.round(span_pred) == span_label).sum().item() / span_label.sum().item()
        
        return valid_loss / len(self.valid_dataloader), emotion_correct / len(self.valid_dataloader) / self.batch_size, span_correct / len(self.valid_dataloader) / self.batch_size / 16, emotion_label_list, emotion_pred_list
        
    def save_checkpoint(self, epoch, option):
        checkpoint_path = os.path.join(self.save_checkpoint_dir, option + '/checkpoint_{}.pth'.format(epoch))
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'epoch': epoch
        }, checkpoint_path)

## Predictor

In [7]:
class Prediction(nn.Module):
    def __init__(self, config_path):
        super(Prediction, self).__init__()
        self.build(config_path)
    def build(self, config_path):
        config = self.read_json(config_path)
        self.id2label = config["id2label"]
        self.label2id = config["label2id"]
        self.submit_path = config["submit path"]
        self.device = config["device"]
        self.batch_size = config["batch_size"]
    def span(self, text, subtext, ss=0):
        if len(subtext) == 0:
            return []
        for i in range(ss, len(text) - len(subtext)):
            for j in range(len(subtext)):
                if text[i+j] != subtext[j]:
                    match = 0
                    break
                else:
                    match += 1
                    if match == len(subtext):
                        return [i, i+len(subtext)]
    def predict(self, model, raw_dataset, dataset):
        self.model = model.to(self.device)
        self.model.eval()
        pred_list = []
        
        self.test_dataloader = DataLoader(dataset=dataset, batch_size=self.batch_size, shuffle=False)
        logger_message = f'Predict '
        progress_bar = tqdm(self.test_dataloader,
                            desc=logger_message, initial=0, dynamic_ncols=True)
        with torch.no_grad():
            for _, data in enumerate(progress_bar):
                conversation_id = data["conversation_ID"]
                paragraph = data["paragraph"].to(self.device)
                utter_id = data["utterance_ID"]
                token_text_pool = data["token_text_pool"].to(self.device)
                text_utter_pool = data["text_utter_pool"]
                token_speaker_pool = data["token_speaker_pool"].to(self.device)
                emotion_pred, span_pred = self.model(paragraph, token_text_pool, token_speaker_pool)
                
                pred = self.convertpred2list(conversation_id, utter_id, emotion_pred, span_pred, paragraph.to('cpu').tolist(), token_text_pool.to('cpu').tolist(), text_utter_pool.tolist())
                pred_list.extend(pred)
        submit = self.convertdict2submit(self.convertpred2dict(raw_dataset, pred_list))
        self.save_json(self.submit_path, submit)
    def convertpred2list(self, conversation_id, utter_id, emotion_pred, span_pred, paragraph, token_text_pool):
        pred_list = []
        for i in range(len(emotion_pred)):
            emotion = self.id2label[str(torch.argmax(emotion_pred[i]).item())]
            emotion_string = str(utter_id[i].item()) + "_" + emotion
            if emotion != "neutral":
                # 
                span = torch.round(span_pred[i])

                list_pos = []
                for j in range(len(span)):
                    if (span[j]) == 1:
                        list_pos.append(j)

                for j in range(len(token_text_pool[i])):
                    inner_index = []
                    ss, se = self.span(paragraph[i], token_text_pool[i][j][1:-1])
                    for k in range(len(list_pos)):
                        if np.logical_and(list_pos[k] >= ss, list_pos[k] <= se):
                            inner_index.append(list_pos[k])
                    
                    if len(inner_index) == 0:
                        continue
                    elif len(inner_index) == 1:
                        start_index = inner_index[0]
                        end_index = se
                    elif len(inner_index) > 1:
                        start_index = inner_index[0]
                        end_index = inner_index[-1]
                    span_string = str(text_utter_pool[i]) + "_" + str(start_index) + "_" + str(end_index)
                    pred_dict = {}
                    pred_dict["conversation_ID"] = conversation_id[i].item()
                    pred_dict["emotion-cause_pairs"] = [emotion_string, span_string]
                    # print("pred_dict: "+str(pred_dict))
                    pred_list.append(pred_dict)

        return pred_list
    def convertpred2dict(self, raw_dataset, pred_list):
        dataset = self.convertlist2dict(raw_dataset)
        # dict = 
        for i in range(len(pred_list)):
            pred = pred_list[i]
            conversation_id = pred["conversation_ID"]
            # print("conversation_id: "+str(conversation_id))
            emotion_cause_pairs = pred["emotion-cause_pairs"]
            # print("emotion_cause_pairs: "+str(emotion_cause_pairs))
            if "emotion-cause_pairs" in dataset[conversation_id]:
                dataset[conversation_id]["emotion-cause_pairs"].append(emotion_cause_pairs)
            else:
                dataset[conversation_id]["emotion-cause_pairs"] = []
        # print(dataset)
        return dataset
    def convertlist2dict(self, raw_dataset):
        dataset = {}
        for i in range(len(raw_dataset)):
            conversation = raw_dataset[i]
            conversation_ID = conversation["conversation_ID"]
            # if conversation_ID not in dataset.key
            dataset[conversation_ID] = {
                "conversation_ID": conversation_ID,
                "conversation": conversation['conversation']
            }
        return dataset
    def convertdict2submit(self, raw_dataset):
        dataset = []
        for key in raw_dataset.keys():
            # print("key: "+str(key))
            sample = raw_dataset[key]
            # print("sample: "+str(sample))
            dataset.append(sample)
        return dataset
    def read_json(self, file_path):
        with open(file_path, 'r') as json_file:
            data = json.load(json_file)
        return data
    def save_json(self, file_path, data):
        with open(file_path, 'w') as json_file:
            json.dump(data, json_file)

# Model

In [8]:
class TGG(nn.Module):
    def __init__(self, config_path):
        super(TGG, self).__init__()
        self.build(config_path)
        self.lstm = nn.LSTM(self.text_max_length, self.hidden_size, self.num_layers, batch_first=True, dropout=0.2, bidirectional=True)
        self.lstma_projection = nn.Linear(2*self.num_layers*self.hidden_size, self.embedding_dim)
        self.cat_layer = nn.Linear(2*self.embedding_dim, self.embedding_dim)
        self.emotion_fc = nn.Linear(self.embedding_dim, self.label_num)
        self.casual_feedforward = nn.Sequential(
            nn.Linear(self.embedding_dim , self.feedforward_dim),
            nn.ReLU(),
            nn.Linear(self.feedforward_dim , self.paragraph_max_length)
        )
        # self.casual_fc = nn.Linear(self.embedding_dim, 2)
        self.softmax = nn.Softmax(dim=-1)
        self.sigmoid = nn.Sigmoid()
    def build(self, config_path):
        config = self.read_json(config_path)
        self.hidden_size = config["hidden_size"]
        self.embedding_dim = config["embedding_dim"]
        self.feedforward_dim = config["feedforward_dim"]
        self.encoder_name = config["encoder_name"]
        self.decoder_name = config["decoder_name"]
        self.label2id = config["label2id"]
        self.label_num = len(self.label2id)
        self.num_layers = config["num_layers"]
        self.paragraph_max_length = config["paragraph_max_length"]
        self.text_max_length = config["text_max_length"]
        self.left_padding = config["left_padding"]
        self.right_padding = config["right_padding"]
        self.padding = config["padding"]
        self.num_text = 1 + self.left_padding + self.right_padding
        self.device = config["device"]
        
        main_part = AutoModel.from_pretrained(self.encoder_name)
        self.encoder_embeddings = main_part.embeddings
        self.encoder_encoder = main_part.encoder
        self.encoder_pooler = main_part.pooler
        # self.encoder_decoder = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased")
        
        # self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        # self.encoder_decoder.config.decoder_start_token_id = self.tokenizer.cls_token_id
        # self.encoder_decoder.config.pad_token_id = self.tokenizer.pad_token_id
        # self.encoder_decoder.config.vocab_size = self.encoder_decoder.config.decoder.vocab_size
        # self.encoder_embeddings = main_part.encoder.embeddings
        # self.encoder_encoder = main_part.encoder.encoder
        # self.encoder_pooler = main_part.encoder.pooler
        # self.decoder_embeddings = main_part.decoder.bert.embeddings
        # self.decoder_encoder = main_part.decoder.bert.encoder
        # self.decoder_pooler = main_part.decoder.bert.pooler

        # main_part_2 = GPT2LMHeadModel.from_pretrained(self.decoder_name, output_hidden_states =True)
        
    def forward(self, paragraph, token_text_pool, token_speaker_pool):
        batch_size = paragraph.size(0)
        device = paragraph.get_device()

        h0 = torch.randn((2 * self.num_layers, batch_size, self.hidden_size), device=device)
        c0 = torch.randn((2 * self.num_layers, batch_size, self.hidden_size), device=device)
        output, (hn, cn) = self.lstm(token_text_pool, (h0, c0))
        hn = hn.permute(1,0,2)
        hn = self.lstma_projection(hn.reshape(batch_size, -1))
        
        _paragraph = self.encoder_embeddings(paragraph)
        _paragraph = self.encoder_encoder(_paragraph)['last_hidden_state']
        _paragraph = self.encoder_pooler(_paragraph)
        # _paragraph_de = self.encoder_embeddings(paragraph)
        # _paragraph_de = self.encoder_encoder(_paragraph_de)['last_hidden_state']
        # _paragraph_de = self.encoder_pooler(_paragraph_de)
        # print("_paragraph:"+str(_paragraph.shape))

        hidden_state = torch.cat([hn, _paragraph], dim=-1)

        _x = self.cat_layer(hidden_state)
        # print("_x:"+str(_x.shape))
        # labels = self.tokenizer("", padding=self.padding, max_length=self.paragraph_max_length, return_tensors="pt").input_ids.to(device).expand(batch_size, -1)
        # print("labels:"+str(labels.shape))
        # print("paragraph:"+str(paragraph.shape))
        # _x = self.encoder_decoder(input_ids=paragraph, labels=labels, output_hidden_states=True)
        # print("_x:"+str(_x.shape))
        _e_category = self.emotion_fc(_x)
        # print("_e_category:"+str(_e_category.shape))

        _x = self.casual_feedforward(_x)
        # print("_x:"+str(_x.shape))
        # _x = F.cosine_similarity(_x, paragraph, dim=-1)
        # print("_x:"+str(_x.shape))
        _x = self.sigmoid(_x)
        # print("_x:"+str(_x.shape))
        return self.softmax(_e_category), _x
    def read_json(self, file_path):
        with open(file_path, 'r') as json_file:
            data = json.load(json_file)
        return data

# Test

## Load dataset

In [9]:
config_path = "config.json"
train_data = read_json("../data/ECF 2.0/train/augment_train.json")
valid_data = read_json("../data/ECF 2.0/train/augment_valid.json")
# trial_data = read_json("../data/ECF 2.0/trial/trial.json")
# raw_trial_data = read_json("../data/ECF 2.0/trial/Subtask_1_trial.json")

train_dataset = ConversationDataset(config_path, train_data)
valid_dataset = ConversationDataset(config_path, valid_data)
model = TGG(config_path)
optimizer = Adam(model.parameters(), lr=0.001)

In [10]:
# for name, param in model.named_parameters():
#     print("\"" + name + "\",")

## Predict

In [11]:
trainer = Trainer(config_path)

config.json


In [12]:
trainer.train(model, optimizer, train_dataset=train_dataset, valid_dataset=valid_dataset, epochs=100, option="emotion_head", batch_size=128)

TRAINING emotion_head MODEL
device: cuda
batch_size: 128
gamma_focal_loss: 2
alpha_focal_loss: [1, 5, 5, 1, 1, 3, 1]


Training epoch 0/100: 100%|██████████| 101/101 [02:00<00:00,  1.19s/it]
Validation epoch 0/100: 100%|██████████| 26/26 [00:29<00:00,  1.14s/it]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch 1: Train loss 2.0193,  Valid loss 2.0213,  emotion_correct 13.8522%,  span_correct 3.6415%
              precision    recall  f1-score   support

       anger       0.00      0.00      0.00       461
     disgust       0.00      0.00      0.00       461
        fear       0.00      0.00      0.00       461
         joy       0.00      0.00      0.00       461
     neutral       0.00      0.00      0.00       461
     sadness       0.00      0.00      0.00       461
    surprise       0.14      1.00      0.25       461

    accuracy                           0.14      3227
   macro avg       0.02      0.14      0.04      3227
weighted avg       0.02      0.14      0.04      3227



Training epoch 1/100: 100%|██████████| 101/101 [01:51<00:00,  1.10s/it]
Validation epoch 1/100: 100%|██████████| 26/26 [00:27<00:00,  1.06s/it]
Training epoch 2/100: 100%|██████████| 101/101 [01:52<00:00,  1.11s/it]
Validation epoch 2/100: 100%|██████████| 26/26 [00:27<00:00,  1.04s/it]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch 3: Train loss 2.0228,  Valid loss 2.0213,  emotion_correct 13.8522%,  span_correct 3.6423%
              precision    recall  f1-score   support

       anger       0.00      0.00      0.00       461
     disgust       0.00      0.00      0.00       461
        fear       0.00      0.00      0.00       461
         joy       0.00      0.00      0.00       461
     neutral       0.00      0.00      0.00       461
     sadness       0.00      0.00      0.00       461
    surprise       0.14      1.00      0.25       461

    accuracy                           0.14      3227
   macro avg       0.02      0.14      0.04      3227
weighted avg       0.02      0.14      0.04      3227



Training epoch 3/100: 100%|██████████| 101/101 [01:48<00:00,  1.07s/it]
Validation epoch 3/100: 100%|██████████| 26/26 [00:26<00:00,  1.04s/it]
Training epoch 4/100: 100%|██████████| 101/101 [01:50<00:00,  1.09s/it]
Validation epoch 4/100: 100%|██████████| 26/26 [00:26<00:00,  1.03s/it]
Training epoch 5/100: 100%|██████████| 101/101 [01:49<00:00,  1.09s/it]
Validation epoch 5/100: 100%|██████████| 26/26 [00:26<00:00,  1.02s/it]
Training epoch 6/100: 100%|██████████| 101/101 [01:49<00:00,  1.09s/it]
Validation epoch 6/100: 100%|██████████| 26/26 [00:26<00:00,  1.03s/it]
Training epoch 7/100: 100%|██████████| 101/101 [01:52<00:00,  1.11s/it]
Validation epoch 7/100: 100%|██████████| 26/26 [00:27<00:00,  1.05s/it]

Early stop at epoch: 7 with valid loss: 2.021279752254486





In [13]:
trainer.train(model, optimizer, train_dataset=train_dataset, valid_dataset=valid_dataset, epochs=100, option="emotion_lstm", batch_size=128)

TRAINING emotion_lstm MODEL
device: cuda
batch_size: 128
gamma_focal_loss: 2
alpha_focal_loss: [1, 5, 5, 1, 1, 3, 1]


Training epoch 0/100: 100%|██████████| 101/101 [01:52<00:00,  1.11s/it]
Validation epoch 0/100: 100%|██████████| 26/26 [00:27<00:00,  1.04s/it]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch 1: Train loss 2.0225,  Valid loss 2.0213,  emotion_correct 13.8522%,  span_correct 3.6937%
              precision    recall  f1-score   support

       anger       0.00      0.00      0.00       461
     disgust       0.00      0.00      0.00       461
        fear       0.00      0.00      0.00       461
         joy       0.00      0.00      0.00       461
     neutral       0.00      0.00      0.00       461
     sadness       0.00      0.00      0.00       461
    surprise       0.14      1.00      0.25       461

    accuracy                           0.14      3227
   macro avg       0.02      0.14      0.04      3227
weighted avg       0.02      0.14      0.04      3227



Training epoch 1/100: 100%|██████████| 101/101 [01:47<00:00,  1.06s/it]
Validation epoch 1/100: 100%|██████████| 26/26 [00:26<00:00,  1.02s/it]
Training epoch 2/100: 100%|██████████| 101/101 [01:51<00:00,  1.10s/it]
Validation epoch 2/100: 100%|██████████| 26/26 [00:26<00:00,  1.03s/it]
Training epoch 3/100: 100%|██████████| 101/101 [01:51<00:00,  1.10s/it]
Validation epoch 3/100: 100%|██████████| 26/26 [00:26<00:00,  1.03s/it]
Training epoch 4/100: 100%|██████████| 101/101 [01:52<00:00,  1.11s/it]
Validation epoch 4/100: 100%|██████████| 26/26 [00:27<00:00,  1.04s/it]

Early stop at epoch: 4 with valid loss: 2.021279752254486





In [None]:
trainer.train(model, optimizer, train_dataset=train_dataset, valid_dataset=valid_dataset, epochs=100, option="emotion_casual_lstm_head", batch_size=128)

TRAINING emotion_casual_lstm_head MODEL
device: cuda
batch_size: 128
gamma_focal_loss: 2
alpha_focal_loss: [1, 5, 5, 1, 1, 3, 1]
Training Parallel!


Training epoch 0/100: 100%|██████████| 101/101 [01:32<00:00,  1.09it/s]
Validation epoch 0/100: 100%|██████████| 26/26 [00:22<00:00,  1.17it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch 1: Train loss 11.1833,  Valid loss 11.1017,  emotion_correct 13.8522%,  span_correct 5.3666%
              precision    recall  f1-score   support

       anger       0.00      0.00      0.00       461
     disgust       0.00      0.00      0.00       461
        fear       0.00      0.00      0.00       461
         joy       0.00      0.00      0.00       461
     neutral       0.00      0.00      0.00       461
     sadness       0.14      1.00      0.25       461
    surprise       0.00      0.00      0.00       461

    accuracy                           0.14      3227
   macro avg       0.02      0.14      0.04      3227
weighted avg       0.02      0.14      0.04      3227



  result = _VF.lstm(input, hx, self._flat_weights, self.bias, self.num_layers,
Training epoch 1/100: 100%|██████████| 101/101 [01:32<00:00,  1.09it/s]
Validation epoch 1/100: 100%|██████████| 26/26 [00:22<00:00,  1.16it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch 2: Train loss 11.1607,  Valid loss 11.0931,  emotion_correct 13.8522%,  span_correct 5.3940%
              precision    recall  f1-score   support

       anger       0.00      0.00      0.00       461
     disgust       0.00      0.00      0.00       461
        fear       0.00      0.00      0.00       461
         joy       0.00      0.00      0.00       461
     neutral       0.00      0.00      0.00       461
     sadness       0.14      1.00      0.25       461
    surprise       0.00      0.00      0.00       461

    accuracy                           0.14      3227
   macro avg       0.02      0.14      0.04      3227
weighted avg       0.02      0.14      0.04      3227



  result = _VF.lstm(input, hx, self._flat_weights, self.bias, self.num_layers,
Training epoch 2/100: 100%|██████████| 101/101 [01:31<00:00,  1.11it/s]
Validation epoch 2/100: 100%|██████████| 26/26 [00:21<00:00,  1.19it/s]
Training epoch 3/100: 100%|██████████| 101/101 [01:31<00:00,  1.10it/s]
Validation epoch 3/100: 100%|██████████| 26/26 [00:21<00:00,  1.20it/s]
Training epoch 4/100: 100%|██████████| 101/101 [01:31<00:00,  1.10it/s]
Validation epoch 4/100: 100%|██████████| 26/26 [00:21<00:00,  1.20it/s]

Early stop at epoch: 4 with valid loss: 11.093136714054989





In [None]:
trainer.train(model, optimizer, train_dataset=train_dataset, valid_dataset=valid_dataset, epochs=100, option="full", batch_size=32)

TRAINING full MODEL
device: cuda
batch_size: 32
gamma_focal_loss: 2
alpha_focal_loss: [1, 5, 5, 1, 1, 3, 1]
Training Parallel!


Training epoch 0/100:   0%|          | 0/403 [00:00<?, ?it/s]


RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/root/anaconda3/envs/diffusion/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
    output = module(*input, **kwargs)
  File "/root/anaconda3/envs/diffusion/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/tmp/ipykernel_4112826/2008320485.py", line 65, in forward
    _paragraph = self.encoder_encoder(_paragraph)['last_hidden_state']
  File "/root/anaconda3/envs/diffusion/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/anaconda3/envs/diffusion/lib/python3.8/site-packages/transformers/models/bert/modeling_bert.py", line 607, in forward
    layer_outputs = layer_module(
  File "/root/anaconda3/envs/diffusion/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/anaconda3/envs/diffusion/lib/python3.8/site-packages/transformers/models/bert/modeling_bert.py", line 539, in forward
    layer_output = apply_chunking_to_forward(
  File "/root/anaconda3/envs/diffusion/lib/python3.8/site-packages/transformers/pytorch_utils.py", line 242, in apply_chunking_to_forward
    return forward_fn(*input_tensors)
  File "/root/anaconda3/envs/diffusion/lib/python3.8/site-packages/transformers/models/bert/modeling_bert.py", line 552, in feed_forward_chunk
    layer_output = self.output(intermediate_output, attention_output)
  File "/root/anaconda3/envs/diffusion/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/anaconda3/envs/diffusion/lib/python3.8/site-packages/transformers/models/bert/modeling_bert.py", line 466, in forward
    hidden_states = self.LayerNorm(hidden_states + input_tensor)
  File "/root/anaconda3/envs/diffusion/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/anaconda3/envs/diffusion/lib/python3.8/site-packages/torch/nn/modules/normalization.py", line 189, in forward
    return F.layer_norm(
  File "/root/anaconda3/envs/diffusion/lib/python3.8/site-packages/torch/nn/functional.py", line 2347, in layer_norm
    return torch.layer_norm(input, normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled)
RuntimeError: CUDA out of memory. Tried to allocate 12.00 MiB (GPU 0; 10.76 GiB total capacity; 7.08 GiB already allocated; 5.56 MiB free; 7.15 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF


In [None]:
trainer.train(model, optimizer, train_dataset=train_dataset, valid_dataset=valid_dataset, epochs=100, option="emotion_casual_head", batch_size=128)

In [None]:
trainer.train(model, optimizer, train_dataset=train_dataset, valid_dataset=valid_dataset, epochs=100, option="emotion_head", batch_size=512)

In [None]:
trainer.train(model, optimizer, train_dataset=train_dataset, valid_dataset=valid_dataset, epochs=100, option="emotion_encoder", batch_size=32)

In [None]:
trainer.train(model, optimizer, train_dataset=train_dataset, valid_dataset=valid_dataset, epochs=100, option="emotion_encoder_lstm", batch_size=32)

In [None]:

trial_dataset = ConversationDataset(config_path, trial_data, option="trial")
predict = Prediction(config_path)
predict.predict(model, raw_trial_data, trial_dataset)