In [1]:
!pip install transformers
!pip install scikit-learn
!pip install focal-loss-torch

[0m

In [2]:
# from google.colab import drive
# drive.mount('/content/drive')

In [3]:
import json
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
# Dataset
from PIL import Image
from torchvision import transforms
from torchvision.io import read_video, read_image
from torch.utils.data import Dataset, DataLoader
# Model
from transformers import AutoModel
from transformers import AutoTokenizer, BertModel, BertTokenizer, BertGenerationDecoder, GPT2Config, GPT2Model, GPT2LMHeadModel

# Training parameter
from torch.optim import Adam
# Training process
from tqdm import tqdm
# Metrics
from sklearn.metrics import classification_report, accuracy_score
from collections import OrderedDict
import copy
import sys
from focal_loss.focal_loss import FocalLoss

In [4]:
def read_json(file_path):
    """Load and return the JSON document stored at ``file_path``.

    The file is decoded explicitly as UTF-8 instead of relying on the
    platform's default text encoding, so non-ASCII content is read the same
    way on every machine.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        The deserialized Python object (dict, list, ...).
    """
    with open(file_path, 'r', encoding='utf-8') as json_file:
        return json.load(json_file)
def write_json(path, data):
    """Serialize ``data`` as 4-space-indented JSON into the file at ``path``.

    An existing file at ``path`` is overwritten.
    """
    with open(path, 'w') as out_file:
        out_file.write(json.dumps(data, indent=4))
def mkdir(path):
    """Create directory ``path`` (including parents) if it does not exist yet.

    A message is printed only when the directory is newly created. Nothing
    happens when ``path`` already exists.

    Args:
        path: Directory path to create.
    """
    if os.path.exists(path):
        return
    # exist_ok=True: tolerate the directory appearing between the check above
    # and this call instead of raising FileExistsError.
    os.makedirs(path, exist_ok=True)
    print("The new directory checkpoint is created!")
def count(data):
    """Tally how many items carry each emotion label.

    Args:
        data: Iterable of items; each item is indexed as
            ``item['emotion']['property']`` to obtain its label.

    Returns:
        A dict mapping each label to its number of occurrences, with keys in
        order of first appearance.
    """
    tally = {}
    for record in data:
        label = record['emotion']['property']
        tally[label] = tally.get(label, 0) + 1
    return tally

In [5]:
class EarlyStopper:
    """Signal when the validation loss has stopped improving.

    A loss counts as acceptable when it is below ``best + min_delta``; any
    other loss consumes one unit of patience. Training should stop once
    ``patience`` consecutive non-acceptable losses have been observed.

    Args:
        patience: Number of consecutive non-acceptable epochs tolerated.
        min_delta: Tolerance added to the best loss when judging an epoch.
    """

    def __init__(self, patience=1, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.min_validation_loss = float('inf')

    def reset(self):
        """Reset the patience counter; the best loss seen so far is kept."""
        self.counter = 0

    def early_stop(self, validation_loss):
        """Record one validation loss and report whether training should stop.

        Args:
            validation_loss: Validation loss of the epoch that just finished.

        Returns:
            True when the patience budget is exhausted, otherwise False.
        """
        if validation_loss < (self.min_validation_loss + self.min_delta):
            # Only ever lower the stored best. Overwriting it with a worse loss
            # that merely falls inside the tolerance band would let the
            # baseline drift upward, so a slowly rising loss would never stop.
            self.min_validation_loss = min(self.min_validation_loss, validation_loss)
            self.counter = 0
            return False
        self.counter += 1
        return self.counter >= self.patience

In [6]:
class ConversationDataset(Dataset):
    """Dataset that tokenizes one conversation window per item.

    Each element of ``data`` is read through the keys ``conversation_ID``,
    ``paragraph``, ``utterance_ID``, ``text_pool``, ``speaker_pool`` and, when
    ``option == "train"``, ``emotion`` (with ``property`` and ``casual``).

    Args:
        config_path: Path of the JSON config providing the label maps,
            tokenizer name, padding mode, maximum lengths and window sizes.
        data: Sequence of item dicts as described above.
        option: ``"train"`` to also return labels; any other value returns
            inputs only.
    """

    def __init__(self, config_path, data, option="train"):
        super(ConversationDataset, self).__init__()
        self.build(config_path)
        self.data = data
        self.option = option

    def build(self, config_path):
        """Read the config file and instantiate the tokenizers."""
        config = read_json(config_path)
        self.label2id = config["label2id"]
        self.id2label = config["id2label"]
        self.tokenizer_name = config["tokenizer_name"]
        self.padding = config["padding"]
        self.paragraph_max_length = config["paragraph_max_length"]
        self.text_max_length = config["text_max_length"]
        self.left_padding = config["left_padding"]
        self.right_padding = config["right_padding"]
        self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
        # Not used by __getitem__; kept because it is a public attribute.
        self.text_pool_tokenizer = AutoTokenizer.from_pretrained("gpt2")
        # GPT-2 ships without a pad token, so the EOS token is reused.
        self.text_pool_tokenizer.pad_token = self.text_pool_tokenizer.eos_token

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Tokenize item ``index`` and return it as a dict of tensors.

        Returns:
            A dict with ``conversation_ID``, ``paragraph``, ``utterance_ID``,
            ``token_text_pool`` and ``token_speaker_pool``; in ``"train"``
            mode also ``emotion_label`` (one-hot) and ``span_label``.
        """
        item = self.data[index]
        # Number of rows every pool is left-padded to.
        window = self.left_padding + self.right_padding + 1

        # truncation=True is required for max_length to cap the sequence:
        # without it, over-long inputs keep their full length and the rows
        # below become ragged, which breaks tensor conversion and batching.
        paragraph = self.tokenizer(
            item["paragraph"],
            padding=self.padding,
            truncation=True,
            max_length=self.paragraph_max_length,
            return_tensors="pt",
        )
        # Drop only the leading batch dimension added by return_tensors="pt".
        paragraph = paragraph['input_ids'].squeeze(0)

        text_pool = item["text_pool"]
        speaker_pool = item["speaker_pool"]
        token_text_pool = [self._encode_text(text) for text in text_pool]
        token_speaker_pool = [
            self._encode_speaker(speaker_pool[i]) for i in range(len(text_pool))
        ]

        # Left-pad short pools with tokenized empty strings.
        missing = window - len(text_pool)
        if missing > 0:
            pad_text = self._encode_text("")
            pad_speaker = self._encode_speaker("")
            token_text_pool = [pad_text] * missing + token_text_pool
            token_speaker_pool = [pad_speaker] * missing + token_speaker_pool

        sample = {
            "conversation_ID": item["conversation_ID"],
            "paragraph": paragraph,
            "utterance_ID": item["utterance_ID"],
            "token_text_pool": torch.FloatTensor(token_text_pool),
            "token_speaker_pool": torch.FloatTensor(token_speaker_pool),
        }
        if self.option != "train":
            return sample

        # One-hot emotion target.
        emotion_label = [0] * len(self.id2label)
        emotion_label[self.label2id[item["emotion"]["property"]]] = 1

        # [start, end] pairs, left-padded with [0, 0] up to the window size.
        casual_pool = item["emotion"]["casual"]
        token_casual = [[entry["start"], entry["end"]] for entry in casual_pool]
        token_casual = [[0, 0]] * max(window - len(casual_pool), 0) + token_casual

        sample["emotion_label"] = torch.FloatTensor(emotion_label)
        sample["span_label"] = torch.FloatTensor(token_casual)
        return sample

    def _encode_text(self, text):
        """Token ids of one utterance, capped at ``text_max_length``."""
        return self.tokenizer.encode(
            text, padding=self.padding, truncation=True, max_length=self.text_max_length
        )

    def _encode_speaker(self, speaker):
        """Token ids of one speaker name, capped at 8 tokens."""
        return self.tokenizer(
            speaker, padding=self.padding, truncation=True, max_length=8
        )["input_ids"]

In [7]:
class Trainer(nn.Module):
    """Stage-wise training loop for the emotion classifier.

    The config file supplies early-stopping settings, batch size, device,
    output directory and the lists of parameter-name fragments that make up
    each freezable group (emotion head, casual head, encoder, decoder, LSTM).
    ``alpha_focal_loss`` and ``gamma_focal_loss`` are read from the config but
    are not used by ``train``, which optimizes ``nn.CrossEntropyLoss``.

    Args:
        config_path: Path of the JSON config file.
    """

    def __init__(self, config_path):
        super(Trainer, self).__init__()
        self.layers_to_freeze = []
        self.history = {
            "train loss": [],
            "valid loss": [],
            "emotion_correct": [],
            "span_correct": []
        }
        self.build(config_path)

    def build(self, config_path):
        """Read the config file and prepare the history directory."""
        print(config_path)
        config = read_json(config_path)
        self.patience = config['patience']
        self.min_delta = config['min_delta']
        # Keep the configured value separately so that train(batch_size=None)
        # can fall back to it after an earlier call overrode self.batch_size.
        self.batch_size_config = config['batch_size']
        self.batch_size = self.batch_size_config
        self.device = config['device']
        self.save_checkpoint_dir = config['output_dir'] + "/checkpoint"
        self.save_history_dir = config['output_dir'] + "/history"
        self.emotion_head_layers = config['emotion_head_layers']
        self.casual_head_layers = config['casual_head_layers']
        self.decoder_layers = config['decoder_layers']
        self.encoder_layers = config['encoder_layers']
        self.lstm_layers = config['lstm_layers']
        self.alpha_focal_loss = config['alpha_focal_loss']
        self.gamma_focal_loss = config['gamma_focal_loss']
        self.label2id = config['label2id']
        self.target_names = list(self.label2id.keys())
        self.early_stopper = EarlyStopper(self.patience, self.min_delta)

        mkdir(self.save_history_dir)

    # --- Stage selectors: each one chooses which parameter groups to freeze.

    def emotion_head_training(self):
        self.layers_to_freeze = self.casual_head_layers + self.encoder_layers + self.decoder_layers + self.lstm_layers

    def emotion_encoder_training(self):
        self.layers_to_freeze = self.casual_head_layers + self.lstm_layers + self.decoder_layers

    def emotion_encoder_lstm_training(self):
        self.layers_to_freeze = self.casual_head_layers + self.decoder_layers

    def emotion_lstm_training(self):
        self.layers_to_freeze = self.casual_head_layers + self.encoder_layers + self.decoder_layers

    def casual_head_training(self):
        self.layers_to_freeze = self.emotion_head_layers + self.encoder_layers + self.lstm_layers + self.decoder_layers

    def full_training(self):
        self.layers_to_freeze = []

    def freeze(self, model):
        """Set ``requires_grad`` on every parameter of ``model``.

        A parameter is frozen when any entry of ``self.layers_to_freeze`` is a
        substring of its name; all other parameters are made trainable.

        Returns:
            The same ``model`` instance.
        """
        for name, param in model.named_parameters():
            param.requires_grad = not any(layer in name for layer in self.layers_to_freeze)
        return model

    @staticmethod
    def _snapshot_weights(model):
        """Return an independent CPU copy of the weights that can still change.

        ``state_dict()`` returns references to the live tensors, so it must be
        copied to remember the best epoch. Frozen parameters never change
        during training and are skipped to keep the snapshot small.
        """
        frozen = {name for name, param in model.named_parameters() if not param.requires_grad}
        return {
            name: tensor.detach().cpu().clone()
            for name, tensor in model.state_dict().items()
            if name not in frozen
        }

    def train(self, model, optimizer, train_dataset, valid_dataset, epochs, option="emotion_head", batch_size=None):
        """Train ``model`` for one stage and restore its best weights.

        Args:
            model: Module called as ``model(paragraph, token_text_pool,
                token_speaker_pool)``.
            optimizer: Optimizer stepping the parameters of ``model``.
            train_dataset: Dataset for the training loop.
            valid_dataset: Dataset for the validation loop.
            epochs: Maximum number of epochs.
            option: Stage name. ``"emotion_head"``, ``"emotion_encoder"``,
                ``"emotion_lstm"`` and ``"emotion_encoder_lstm"`` select a
                freeze configuration; any other value leaves the current one
                unchanged. It also names the checkpoint sub-directory and the
                history file.
            batch_size: Batch size for this call; ``None`` uses the config value.
        """
        self.optimizer = optimizer
        stage_setup = {
            "emotion_head": self.emotion_head_training,
            "emotion_encoder": self.emotion_encoder_training,
            "emotion_lstm": self.emotion_lstm_training,
            "emotion_encoder_lstm": self.emotion_encoder_lstm_training,
        }
        if option in stage_setup:
            print("TRAINING " + option + " MODEL")
            stage_setup[option]()

        # Stored under the name the epoch loops read (self.loss_fn).
        # NOTE(review): nn.CrossEntropyLoss applies log-softmax internally; a
        # model that already returns softmax probabilities is normalized
        # twice - confirm this is intended.
        self.loss_fn = nn.CrossEntropyLoss()

        self.batch_size = self.batch_size_config if batch_size is None else batch_size
        mkdir(self.save_checkpoint_dir + "/" + option)

        self.early_stopper.reset()
        self.model = self.freeze(model)
        self.model = nn.DataParallel(self.model)
        self.model = self.model.to(self.device)
        best_weight = self._snapshot_weights(model)

        self.train_dataloader = DataLoader(dataset=train_dataset, batch_size=self.batch_size, shuffle=True, drop_last=False)
        self.valid_dataloader = DataLoader(dataset=valid_dataset, batch_size=self.batch_size, shuffle=False, drop_last=False)

        best_val_loss = np.inf
        self.epochs = epochs
        for epoch in range(epochs):
            train_loss = self._train_epoch(epoch, option)
            val_loss, emotion_correct, emotion_label_list, emotion_pred_list = self._val_epoch(epoch, option)

            self.history["train loss"].append(train_loss)
            self.history["valid loss"].append(val_loss)
            self.history["emotion_correct"].append(emotion_correct)

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_weight = self._snapshot_weights(model)
                self.save_checkpoint(epoch, option)
                print(f"Epoch {epoch+1}: Train loss {train_loss:.4f},  Valid loss {val_loss:.4f},  emotion_correct {emotion_correct*100:.4f}%")
                # Passing `labels` keeps the report aligned with target_names
                # even when a class is absent from this validation split.
                print(classification_report(
                    emotion_label_list,
                    emotion_pred_list,
                    labels=list(self.label2id.values()),
                    target_names=self.target_names,
                ))
            if self.early_stopper.early_stop(val_loss):
                print("Early stop at epoch: "+str(epoch) + " with valid loss: "+str(val_loss))
                break

        # strict=False: the snapshot omits frozen parameters on purpose.
        model.load_state_dict(best_weight, strict=False)
        write_json(self.save_history_dir + "/" + option + "_history.json", self.history)

    def _train_epoch(self, epoch, option):
        """Run one optimization pass; return the mean loss per batch."""
        self.model.train()
        train_loss = 0.0
        logger_message = f'Training epoch {epoch}/{self.epochs}'

        progress_bar = tqdm(self.train_dataloader,
                            desc=logger_message, initial=0, dynamic_ncols=True)
        for data in progress_bar:
            paragraph = data["paragraph"].to(self.device)
            token_text_pool = data["token_text_pool"].to(self.device)
            token_speaker_pool = data["token_speaker_pool"].to(self.device)
            emotion_label = data["emotion_label"].to(self.device)

            self.optimizer.zero_grad()
            emotion_pred = self.model(paragraph, token_text_pool, token_speaker_pool)
            loss = self.loss_fn(emotion_pred, emotion_label)
            train_loss += loss.item()
            loss.backward()
            self.optimizer.step()
        return train_loss / max(len(self.train_dataloader), 1)

    def _val_epoch(self, epoch, option):
        """Evaluate on the validation loader.

        Returns:
            ``(mean_loss_per_batch, accuracy, label_ids, predicted_ids)`` where
            accuracy is the fraction of correctly classified samples and the
            two lists are flat lists of class indices.
        """
        self.model.eval()
        valid_loss = 0.0
        emotion_correct = 0
        sample_count = 0
        emotion_pred_list = []
        emotion_label_list = []
        logger_message = f'Validation epoch {epoch}/{self.epochs}'
        progress_bar = tqdm(self.valid_dataloader,
                            desc=logger_message, initial=0, dynamic_ncols=True)
        with torch.no_grad():
            for data in progress_bar:
                paragraph = data["paragraph"].to(self.device)
                token_text_pool = data["token_text_pool"].to(self.device)
                token_speaker_pool = data["token_speaker_pool"].to(self.device)
                emotion_label = data["emotion_label"].to(self.device)
                emotion_pred = self.model(paragraph, token_text_pool, token_speaker_pool)
                loss = self.loss_fn(emotion_pred, emotion_label)

                # Both index tensors stay 1-D so `==` compares sample by sample
                # rather than broadcasting to a (batch, batch) matrix.
                label_ids = torch.argmax(emotion_label, dim=-1)
                pred_ids = torch.argmax(emotion_pred, dim=-1)
                emotion_pred_list.extend(pred_ids.cpu().tolist())
                emotion_label_list.extend(label_ids.cpu().tolist())
                emotion_correct += (pred_ids == label_ids).sum().item()
                sample_count += label_ids.numel()
                valid_loss += loss.item()
        # Divide by the samples actually seen: the last batch may be partial.
        accuracy = emotion_correct / sample_count if sample_count else 0.0
        mean_loss = valid_loss / max(len(self.valid_dataloader), 1)
        return mean_loss, accuracy, emotion_label_list, emotion_pred_list

    def save_checkpoint(self, epoch, option):
        """Save model and optimizer state for ``epoch`` under the stage folder.

        The model is saved through its ``nn.DataParallel`` wrapper, so the
        state-dict keys carry the wrapper's ``module.`` prefix.
        """
        checkpoint_path = os.path.join(self.save_checkpoint_dir, option + '/checkpoint_{}.pth'.format(epoch))
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'epoch': epoch
        }, checkpoint_path)

In [8]:
class TGG(nn.Module):
    """Emotion classifier on top of components of a pretrained model.

    ``build`` loads ``AutoModel.from_pretrained(encoder_name)`` and keeps its
    ``embed_tokens``, ``layers`` and ``norm`` sub-modules as
    ``self.embeddings``, ``self.encoder`` and ``self.pooler``. A linear head
    followed by softmax maps the pooled representation to class
    probabilities.

    Args:
        config_path: Path of the JSON config file.
    """

    def __init__(self, config_path):
        super(TGG, self).__init__()
        self.build(config_path)
        self.emotion_fc = nn.Linear(self.embedding_dim, self.label_num)
        self.softmax = nn.Softmax(dim=-1)

    def build(self, config_path):
        """Read the config and load the pretrained sub-modules."""
        config = self.read_json(config_path)
        self.hidden_size = config["hidden_size"]
        self.embedding_dim = config["embedding_dim"]
        self.encoder_name = config["encoder_name"]
        self.decoder_name = config["decoder_name"]
        self.label2id = config["label2id"]
        self.label_num = len(self.label2id)
        self.num_layers = config["num_layers"]
        self.paragraph_max_length = config["paragraph_max_length"]
        self.text_max_length = config["text_max_length"]
        self.device = config["device"]

        main_part = AutoModel.from_pretrained(self.encoder_name)
        self.embeddings = main_part.embed_tokens
        self.encoder = main_part.layers
        self.pooler = main_part.norm

    def _run_encoder(self, hidden):
        """Apply ``self.encoder`` to already-embedded hidden states.

        An ``nn.ModuleList`` cannot be called directly, so its layers are
        applied in order, taking the first element when a layer returns a
        tuple. Any other module is called once and its
        ``'last_hidden_state'`` entry is returned.

        NOTE(review): layers are called with the hidden states only; layers
        that need extra inputs (attention mask, position information) are not
        handled here - confirm against the chosen ``encoder_name``.
        """
        if isinstance(self.encoder, nn.ModuleList):
            for layer in self.encoder:
                layer_out = layer(hidden)
                hidden = layer_out[0] if isinstance(layer_out, (tuple, list)) else layer_out
            return hidden
        return self.encoder(hidden)['last_hidden_state']

    def forward(self, paragraph, token_text_pool, token_speaker_pool):
        """Classify each paragraph.

        Args:
            paragraph: Token ids fed to the embedding layer.
            token_text_pool: Accepted for interface compatibility; unused.
            token_speaker_pool: Accepted for interface compatibility; unused.

        Returns:
            Class probabilities (softmax over the last dimension).
        """
        hidden = self.embeddings(paragraph)
        hidden = self._run_encoder(hidden)
        hidden = self.pooler(hidden)

        # The head needs one vector per sample. When the pooler keeps the
        # sequence axis, average over it.
        # NOTE(review): this mean is unmasked, so padding positions are
        # included - confirm that is acceptable.
        if hidden.dim() == 3:
            hidden = hidden.mean(dim=1)

        _e_category = self.emotion_fc(hidden)
        return self.softmax(_e_category)

    def read_json(self, file_path):
        """Load and return the JSON document stored at ``file_path``."""
        with open(file_path, 'r') as json_file:
            data = json.load(json_file)
        return data

In [9]:
# Path of the JSON config file; the same path is passed to the dataset,
# model and trainer constructors in the cells below.
config_path = "/root/NAACL2024/new model/config.json"
# Load the augmented training split and the validation split.
train_data = read_json("/root/NAACL2024/data/ECF 2.0/train/augment_train.json")
valid_data = read_json("/root/NAACL2024/data/ECF 2.0/train/augment_valid.json")
# Alternative inputs (non-augmented splits and trial data), currently disabled:
# train_data = read_json("/root/NAACL2024/data/ECF 2.0/train/train.json")
# valid_data = read_json("/root/NAACL2024/data/ECF 2.0/train/valid.json")
# trial_data = read_json("../data/ECF 2.0/trial/trial.json")
# raw_trial_data = read_json("../data/ECF 2.0/trial/Subtask_1_trial.json")

# Print the number of samples per emotion label in each split.
print(count(train_data))
print(count(valid_data))

{'neutral': 1845, 'surprise': 1845, 'anger': 1845, 'sadness': 1845, 'joy': 1845, 'disgust': 1845, 'fear': 1845}
{'neutral': 1188, 'anger': 346, 'surprise': 329, 'fear': 61, 'sadness': 263, 'joy': 456, 'disgust': 81}


In [10]:
# Wrap both splits; `option` defaults to "train", so both datasets also
# return the emotion and span labels.
train_dataset = ConversationDataset(config_path, train_data)
valid_dataset = ConversationDataset(config_path, valid_data)
# Build the classifier and an Adam optimizer over all of its parameters.
model = TGG(config_path)
optimizer = Adam(model.parameters(), lr=0.001)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [11]:
# The trainer reads its settings from the same config file and prints the path.
trainer = Trainer(config_path)

/root/NAACL2024/new model/config.json


In [12]:
# for name, param in model.named_parameters():
#     print("\"" + name + "\",")

In [13]:
# Run the "emotion_head" stage for at most 100 epochs with a batch size of 2;
# early stopping inside Trainer.train may end the run sooner.
trainer.train(model, optimizer, train_dataset=train_dataset, valid_dataset=valid_dataset, epochs=100, option="emotion_head", batch_size=2)

TRAINING emotion_head MODEL


RuntimeError: CUDA out of memory. Tried to allocate 172.00 MiB (GPU 0; 10.76 GiB total capacity; 9.25 GiB already allocated; 60.56 MiB free; 9.26 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF