# **INSTALL LIBRARIES**

In [1]:
# !pip install transformers

# **LIBRARIES**

In [2]:
import json
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
# Dataset
from PIL import Image
from torchvision import transforms
from torchvision.io import read_video, read_image
from torch.utils.data import Dataset, DataLoader
# Model
from transformers import AutoModel
from transformers import AutoTokenizer, BertModel, BertTokenizer, BertGenerationDecoder

# Training parameter
from torch.optim import Adam
# Training process
from tqdm import tqdm
# Metrics
from sklearn.metrics import classification_report
from collections import OrderedDict

In [3]:
# from google.colab import drive
# drive.mount('/content/drive')

# **SETUP PARAMETERS**

In [4]:
# Root folder of the data set; every path below is derived from it.
FOLDER = '/data/ECF 2.0/'

# Output locations for checkpoints, exported models and reports.
save_checkpoint_dir = f"{FOLDER}checkpoints"
save_model_dir = f"{FOLDER}model"
save_report_dir = f"{FOLDER}report"

# Names of the run and of the Hugging Face checkpoints to load.
model_name = "BERT"
tokenizer_name = 'michellejieli/emotion_text_classifier'
encoder_name = 'michellejieli/emotion_text_classifier'
decoder_name = 'roberta-base'
optimizer_name = "BERT"
loss_fn_name = "BCE_Loss"
metrics = {}
output_json = f"{FOLDER}predict"

# Per-split file descriptors: both splits share one layout and differ only in
# the split name that appears in the folder and in the file names.
file_train, file_trial = (
    {
        "subtask_1_text_file": f"{FOLDER}{split}/Subtask_1_{split}.json",
        "subtask_2_text_file": f"{FOLDER}{split}/Subtask_2_{split}.json",
        "video_dir": f"{FOLDER}{split}/video_with_audio",
        "max_conversation_length": 32,
        "max_emotion_cause_pairs_length": 128,
    }
    for split in ("train", "trial")
)

subtask = "subtask_1"
# NOTE(review): this module-level name shadows the built-in `format`.
format = "ECF 2.0"

# Token ids that delimit an encoded span when searching for a sub-sequence.
# NOTE(review): 101/102 look like the BERT [CLS]/[SEP] ids — confirm they
# match the tokenizer that is actually loaded.
bos_token_id, eos_token_id = 101, 102
max_length = 512

# **UTIL FUNCTIONS**

In [5]:
def read_json_file(file_path):
  """Load and return the parsed contents of a JSON file.

  Args:
    file_path: Path of the JSON file to read.

  Returns:
    The decoded JSON value (whatever ``json.load`` produces for the file).

  Raises:
    OSError: If the file cannot be opened.
    json.JSONDecodeError: If the file is not valid JSON.
  """
  # Read as UTF-8 explicitly so non-ASCII text does not depend on the
  # platform's default encoding.
  with open(file_path, 'r', encoding='utf-8') as f:
    return json.load(f)
def find_sub_list_indices(main_list, sub_list, eos_token=None):
  """Find where an encoded span occurs inside a longer list of token ids.

  Only the first row of ``sub_list`` is used. Its first element and everything
  from the first end-of-sequence id onwards are dropped, and the remaining ids
  are searched for in ``main_list``.

  Args:
    main_list: List of token ids to search in.
    sub_list: Sequence whose first element is a list of token ids of the form
      ``[<first token>, ...span..., <eos>, ...]``.
    eos_token: Id that terminates the span. Defaults to the module-level
      ``eos_token_id``.

  Returns:
    ``(start_index, end_index)`` such that
    ``main_list[start_index:end_index]`` equals the span.

  Raises:
    ValueError: If the span is empty, has no end-of-sequence id, or does not
      occur in ``main_list``.
  """
  if eos_token is None:
    # The original code referenced an undefined name (`eof_token`); the
    # configured end-of-sequence id is `eos_token_id`.
    eos_token = eos_token_id
  tokens = sub_list[0]
  span = tokens[1:tokens.index(eos_token)]
  if not span:
    raise ValueError("sub_list contains no tokens before the end-of-sequence id")

  start_index = main_list.index(span[0])
  end_index = start_index + len(span)
  # Keep jumping to the next occurrence of the span's first id until the whole
  # span matches; list.index raises ValueError once no candidates are left.
  while main_list[start_index:end_index] != span:
    start_index = main_list.index(span[0], start_index + 1)
    end_index = start_index + len(span)

  return start_index, end_index

# **LOAD MODEL**

In [6]:
# Load the tokenizer and the encoder from the checkpoints named in the setup
# cell (both names point at the same checkpoint there).
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
# The recorded output of this cell reports a RobertaModel whose pooler weights
# are newly initialised; the VH model below reuses `encoder.pooler`, so that
# layer starts untrained.
encoder = AutoModel.from_pretrained(encoder_name)

Some weights of RobertaModel were not initialized from the model checkpoint at michellejieli/emotion_text_classifier and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


# **Evaluation**

In [7]:
# class Evaluation(nn.Module):
#     def __init__(self):
#         self.super(Evaluation, self).__init__()
#     def forward(self, emotion, sentence_index, sentence_pos, emotion_pred, sentence_index_pred, sentence_pos_pred):
#         emotion_acc = 0

# **TRAIN**

In [8]:

# prompt: create trainer class
class Trainer(object):
  """Train/validation loop for a model whose forward pass returns its own loss.

  The model is called as ``model(batch)`` and must return
  ``(loss, emotion_pred, sentence_index_pred, sentence_pos_pred)``.
  """

  def __init__(self, model, optimizer, metrics=None, save_checkpoint_dir='./', save_model_dir='./', save_report_dir='./'):
    """Store the model, the optimizer and the output directories.

    Args:
      model: Module to optimise.
      optimizer: Optimizer already bound to the model parameters.
      metrics: Stored on the instance; not used by the loop at the moment.
      save_checkpoint_dir: Directory that receives ``checkpoint_<epoch>.pth``.
      save_model_dir: Stored on the instance; not used by the loop.
      save_report_dir: Stored on the instance; not used by the loop.
    """
    self.model = model
    self.optimizer = optimizer
    self.metrics = metrics

    self.save_checkpoint_dir = save_checkpoint_dir
    self.save_model_dir = save_model_dir
    self.save_report_dir = save_report_dir

  def train(self, train_loader, val_loader, epochs):
    """Run ``epochs`` train/validation rounds and print one summary per epoch.

    A checkpoint is written every time the validation loss improves on the
    best value seen so far.
    """
    best_val_loss = np.inf
    self.epochs = epochs
    self.train_dataloader = train_loader
    self.valid_dataloader = val_loader
    for epoch in range(epochs):
      train_loss, train_metrics = self._train_epoch(epoch)
      val_loss, emotion_acc_avg, sentence_index_acc_avg, sentence_pos_acc_avg = self._val_epoch(epoch)
      if val_loss < best_val_loss:
        best_val_loss = val_loss
        self.save_checkpoint(epoch)
      print(f"Epoch {epoch+1}: Train loss {train_loss:.4f}, Train {train_metrics}, Valid loss {val_loss:.4f}, Emotion {emotion_acc_avg * 100:.4f}%, Sentence Index {sentence_index_acc_avg * 100:.4f}%, Sentence Pos {sentence_pos_acc_avg * 100:.4f}%")

  @staticmethod
  def _emotion_accuracy(emotion_pred, emotion_target):
    """Return the fraction of rows whose arg-max class matches the target's."""
    predicted = torch.argmax(emotion_pred.detach(), dim=1).to('cpu')
    expected = torch.argmax(emotion_target, dim=1).to('cpu')
    return (predicted == expected).sum().item() / expected.size(0)

  def _train_epoch(self, epoch):
    """Run one optimisation pass over the training loader.

    Returns:
      ``(mean loss per batch, mean emotion accuracy per batch)``.
    """
    self.model.train()
    train_loss = 0.0
    emotion_acc_avg = 0.0
    logger_message = f'Training epoch {epoch}/{self.epochs}'

    progress_bar = tqdm(self.train_dataloader,
                        desc=logger_message, initial=0, dynamic_ncols=True)
    for data in progress_bar:
      self.optimizer.zero_grad()
      loss, emotion_pred, _, _ = self.model(data)

      loss.backward()
      self.optimizer.step()

      train_loss += loss.item()
      # Normalise by the batch size so the value is a real accuracy in [0, 1],
      # exactly like the validation path (it used to be a raw count).
      emotion_acc_avg += self._emotion_accuracy(emotion_pred, data['emotion'])

    n_batches = len(self.train_dataloader)
    return train_loss / n_batches, emotion_acc_avg / n_batches

  def _val_epoch(self, epoch):
    """Evaluate the model on the validation loader without gradients.

    Returns:
      ``(mean loss, mean emotion accuracy, mean cause-index score,
      mean cause-position accuracy)``, each averaged over batches.
    """
    self.model.eval()
    test_loss = 0.0
    emotion_acc_avg = 0.0
    sentence_index_acc_avg = 0.0
    sentence_pos_acc_avg = 0.0
    logger_message = f'Validation epoch {epoch}/{self.epochs}'
    progress_bar = tqdm(self.valid_dataloader,
                        desc=logger_message, initial=0, dynamic_ncols=True)
    with torch.no_grad():
      for data in progress_bar:
        loss, emotion_pred, sentence_index_pred, sentence_pos_pred = self.model(data)
        test_loss += loss.item()

        emotion_acc_avg += self._emotion_accuracy(emotion_pred, data['emotion'])
        sentence_index_acc_avg += self.sentence_index_func(sentence_index_pred.to('cpu'), data['casual_sentence_index_label'])
        sentence_pos_acc_avg += self.sentence_pos_func(sentence_pos_pred.to('cpu'), data['casual_sentence_pos_label'])

    n_batches = len(self.valid_dataloader)
    return test_loss / n_batches, emotion_acc_avg / n_batches, sentence_index_acc_avg / n_batches, sentence_pos_acc_avg / n_batches

  def save_checkpoint(self, epoch):
    """Write model and optimizer state to ``checkpoint_<epoch>.pth``."""
    # Create the target directory on demand instead of failing on a fresh run.
    os.makedirs(self.save_checkpoint_dir, exist_ok=True)
    checkpoint_path = os.path.join(self.save_checkpoint_dir, 'checkpoint_{}.pth'.format(epoch))
    torch.save({
        'model_state_dict': self.model.state_dict(),
        'optimizer_state_dict': self.optimizer.state_dict(),
        'epoch': epoch
    }, checkpoint_path)

  def sentence_index_func(self, pred, target):
    """Return the share of positive targets that are also predicted as 1.

    Only entries where both ``pred`` and ``target`` equal 1 count as hits, so
    this is a recall-style score. Returns 0.0 when the target has no positive
    entry. The result is always a Python float.
    """
    positives = target.sum().item()
    if positives == 0:
      return 0.0
    hits = torch.sum((pred == 1) & (target == 1)).item()
    return hits / positives

  def sentence_pos_func(self, pred, target):
    """Return the fraction of (row, candidate) slots whose two position values both match exactly."""
    acc = (pred[:, :, 1] == target[:, :, 1])*(pred[:, :, 0] == target[:, :, 0])
    return acc.sum().item() / acc.size(0) / acc.size(1)

# **PREDICTION**

In [9]:
class Predictor(object):
  """Runs a trained model over a data loader and writes emotion-cause pairs
  back into the raw conversation list."""

  def __init__(self, model, output_json, subtask=None, batch_size=32, save_checkpoint_dir=None, save_model_dir=None):
    """Store the configuration and optionally restore model weights.

    Args:
      model: Module exposing ``predict(batch)``.
      output_json: Path used by ``save_prediction``.
      subtask: Only ``"subtask_1"`` triggers any post-processing in ``predict``.
      batch_size: Stored on the instance; not used by the methods below.
      save_checkpoint_dir: Despite the name this is passed to ``torch.load`` as
        a checkpoint file path; ``None`` skips loading.
      save_model_dir: Accepted but unused.
    """
    self.model = model
    self.output_json = output_json
    self.subtask = subtask
    self.batch_size = batch_size
    self.text_to_number_mapping =  {"anger": 0,
                                    "disgust": 1,
                                    "joy": 2,
                                    "neutral": 3,
                                    "fear": 4,
                                    "sadness": 5,
                                    "surprise": 6}
    # Inverse of the mapping above, used to turn class ids back into labels.
    self.number_to_text_mapping = {number: text for text, number in self.text_to_number_mapping.items()}
    self.predicted_dataset = []

    if save_checkpoint_dir is not None:
      self.load_checkpoint(save_checkpoint_dir)

  def emotion_cause_pairs(self, casual_text, predicted_emotion, casual = True):
    """Format one prediction tuple as an ``[emotion, cause]`` string pair.

    ``predicted_emotion`` is indexed as (0-based utterance index, emotion id,
    0-based cause index, span start, span end). With ``casual`` true the cause
    string also carries ``casual_text[start:end]``.
    """
    text = str(predicted_emotion[0] + 1) + "_" + str(self.number_to_text_mapping[predicted_emotion[1]])
    if casual == True:
      cause = str(predicted_emotion[2] + 1) + "_" + str(casual_text[predicted_emotion[3]: predicted_emotion[4]])
    else:
      cause = str(predicted_emotion[2] + 1)
    return [text, cause]

  def predict(self, dataset, dataloader):
    """Run the model over ``dataloader`` and store the annotated ``dataset``
    on ``self.dataset``.

    Args:
      dataset: Raw conversation list that receives ``'emotion-cause_pairs'``.
      dataloader: Iterable of batches accepted by ``model.predict``.
    """
    self.predicted_dataset = []
    self.model.eval()
    logger_message = 'Prediction'
    progress_bar = tqdm(dataloader,
                        desc=logger_message, initial=0, dynamic_ncols=True)
    with torch.no_grad():
      for data in progress_bar:
        emotion_pred, sentence_index_pred, sentence_pos_pred = self.model.predict(data)
        if self.subtask == "subtask_1":
          d, dataset = self.predict_subtask_1(data, emotion_pred, sentence_index_pred, sentence_pos_pred, dataset)
          # NOTE(review): `d` is the batch dict, so extend() adds its key names
          # rather than per-utterance records; recontruct_dataset() expects
          # records. Confirm what predicted_dataset is meant to hold.
          self.predicted_dataset.extend(d)
    # self.recontruct_dataset()
    # self.save_prediction(self.predicted_dataset)
    self.dataset = dataset

  def predict_subtask_1(self, data, emotion_pred, sentence_index_pred, sentence_pos_pred, dataset):
    """Append predicted emotion-cause pairs for one batch to ``dataset``.

    The three prediction tensors overwrite the matching label entries of
    ``data``. For every sample, each candidate whose predicted index value is
    non-zero yields one pair ``["<utterance>_<emotion>", "<cause>_<start>_<end>"]``.

    Returns:
      ``(data, dataset)`` — the same objects, updated in place.
    """
    data['emotion'] = emotion_pred
    data['casual_sentence_index_label'] = sentence_index_pred
    data['casual_sentence_pos_label'] = sentence_pos_pred

    for sample in range(len(data['emotion'])):
      # Collated ids may be 0-dim tensors; convert them so that they index
      # lists and format as plain numbers instead of "tensor(3)".
      conversation_ID = int(data['conversation_ID'][sample])
      utterance_ID = int(data['utterance_ID'][sample])
      candidate_ids = data['casual_sentence_index'][sample]
      index_labels = data['casual_sentence_index_label'][sample]
      pos_labels = data['casual_sentence_pos_label'][sample]

      # Resolve the emotion label once per sample (it was previously
      # recomputed, and overwritten with a string, inside the candidate loop).
      emotion = self.number_to_text_mapping[torch.argmax(data['emotion'][sample]).item()]

      # NOTE(review): assumes conversation ids are 1-based positions in
      # `dataset` — verify against the data files.
      pairs = dataset[conversation_ID - 1].setdefault('emotion-cause_pairs', [])
      for j in range(len(index_labels)):
        if index_labels[j] != 0:
          casual_sentence_index = int(candidate_ids[j])
          # Use the span predicted for candidate j; positions are rounded to
          # the nearest integer for output.
          start, end = (int(round(float(value))) for value in pos_labels[j])
          pairs.append([
            str(utterance_ID) + "_" + emotion,
            str(casual_sentence_index) + "_" + str(start) + "_" + str(end)
          ])
    return data, dataset

  def recontruct_dataset(self):
    """Group per-utterance records in ``self.predicted_dataset`` by conversation.

    Each record is read through the keys ``'conversation_ID'``,
    ``'utterance_ID'``, ``'text'``, ``'speaker'``, ``'emotion'``,
    ``'casual_sentence_index_label'`` and ``'casual_sentence_pos_label'``.
    ``self.predicted_dataset`` is replaced by one dict per conversation, in
    order of first appearance.
    """
    grouped = {}
    for record in self.predicted_dataset:
      entry = grouped.setdefault(record['conversation_ID'],
                                 {'conversation': [], 'emotion-cause_pairs': []})
      entry['conversation'].append({"utterance_ID": record['utterance_ID'],
                                    "text": record['text'],
                                    "speaker": record['speaker']})
      entry['emotion-cause_pairs'].append([
        str(record['utterance_ID'])+"_"+record['emotion'],
        str(record['casual_sentence_index_label'])+"_"+str(record['casual_sentence_pos_label'][0])+"_"+str(record['casual_sentence_pos_label'][1])
      ])
    self.predicted_dataset = [{"conversation_ID": conversation_ID,
                               "conversation": entry['conversation'],
                               "emotion-cause_pairs": entry['emotion-cause_pairs']}
                              for conversation_ID, entry in grouped.items()]

  def save_prediction(self, dataset):
    """Serialise ``dataset`` as JSON to ``self.output_json``."""
    with open(self.output_json, 'w') as f:
      json.dump(dataset, f)

  def load_checkpoint(self, PATH):
    """Restore ``model_state_dict`` from the checkpoint file at ``PATH``."""
    # Load onto the CPU first so a checkpoint written on a GPU machine can be
    # restored anywhere; load_state_dict copies into the model's own device.
    checkpoint = torch.load(PATH, map_location='cpu')
    self.model.load_state_dict(checkpoint['model_state_dict'])
    print("LOAD CHECHPOINT SUCCESFUL")

# **DATASET**

In [10]:
class ProcessDataset(nn.Module):
  """Turns the conversations of a subtask-1 JSON file into per-utterance
  records with a fixed window of candidate cause utterances.

  Every conversation is padded with empty utterances (6 before, 2 after) so
  that each real utterance gets a window of 8 candidates: the 5 preceding
  utterances, the utterance itself and the 2 following ones.
  """

  def __init__(self, file, format="ECF 2.0", subtask="subtask_1", tokenizer=None, max_length=512):
    """Load the file named in ``file`` and build ``self.dataset``.

    Args:
      file: Dict with the keys ``'subtask_1_text_file'``,
        ``'max_conversation_length'`` and ``'max_emotion_cause_pairs_length'``.
      format: Data is only loaded when this equals ``'ECF 2.0'``.
      subtask: Data is only loaded when this equals ``"subtask_1"``.
      tokenizer: Stored on the instance; not used by this class.
      max_length: Stored on the instance; not used by this class.
    """
    super(ProcessDataset, self).__init__()
    self.format = format
    self.subtask = subtask
    self.left_padding_side = 6
    self.right_padding_side = 2
    # Guards against padding the conversations twice if build() is re-run.
    self._is_padded = False
    if format == 'ECF 2.0':
      self.tokenizer = tokenizer
      self.max_conversation_length = file['max_conversation_length']
      self.max_emotion_cause_pairs_length = file['max_emotion_cause_pairs_length']
      if self.subtask == "subtask_1":
        self.text_file = file['subtask_1_text_file']
        # Close the file deterministically and read it as UTF-8.
        with open(self.text_file, 'r', encoding='utf-8') as f:
          self.text_data = json.load(f)
        self.max_length = max_length
        self.text_to_number_mapping = {"anger": 0,
                                       "disgust": 1,
                                       "joy": 2,
                                       "neutral": 3,
                                       "fear": 4,
                                       "sadness": 5,
                                       "surprise": 6}
        self.build()

  def split_train_test(self, test_size=0.2):
    """Split ``self.dataset`` sequentially into a head and a tail part.

    The head length is rounded down to a multiple of 32. With
    ``test_size == 0`` the whole dataset is returned as a single list (not a
    tuple).
    """
    if test_size == 0:
      return self.dataset
    n = len(self.dataset)
    n_train = int(n / 32 * (1 - test_size)) * 32
    print("N_TRAIN: "+str(n_train) + ", N_TEST: "+str(n-n_train))
    return self.dataset[:n_train], self.dataset[n_train:]

  def get_data(self):
    """Return the loaded conversations (including the padding added by build())."""
    return self.text_data

  def __len__(self):
    return len(self.dataset)

  def build(self):
    """Create one record per real utterance and store them in ``self.dataset``.

    Each record holds the joined text of the 5 preceding utterances plus the
    utterance itself, the 8 candidate utterance ids and texts, a 0/1 label per
    candidate, and a ``[start, end]`` word span per candidate (``[0, 0]`` when
    there is no cause or the span cannot be located).
    """
    self.padding_utterance()
    self.dataset = []
    window = self.left_padding_side + self.right_padding_side
    for record in self.text_data:
      conversation = record['conversation']
      conversation_ID = record['conversation_ID']
      pairs = record.get('emotion-cause_pairs', [])
      for i in range(len(conversation) - window):
        target = conversation[i + self.left_padding_side]
        utterance_ID = target['utterance_ID']
        context = conversation[i + 1:i + 1 + self.left_padding_side]
        candidates = conversation[i + 1:i + 1 + window]

        # Join with spaces so that adjacent utterances do not fuse their last
        # and first words; empty padding texts are skipped.
        text = " ".join(u['text'] for u in context if u['text'])
        casual_sentence_index_list = [u['utterance_ID'] for u in candidates]
        sentence_list = [u['text'] for u in candidates]
        casual_sentence_index_list_label = [0] * window
        # Independent inner lists (the old `[[0,0]] * 8` aliased one list).
        casual_sentence_pos_list_label = [[0, 0] for _ in range(window)]

        for pair in pairs:
          if int(pair[0].split("_", 1)[0]) != utterance_ID:
            continue
          # Split on the first underscore only: the cause span is free text
          # and may itself contain underscores.
          cause_id_text, _, cause_span = pair[1].partition("_")
          cause_id = int(cause_id_text)
          offset = cause_id - utterance_ID
          # Causes outside the candidate window cannot be labelled.
          if offset > self.right_padding_side or offset < -(self.left_padding_side - 1):
            continue
          slot = casual_sentence_index_list.index(cause_id)
          cause_text = conversation[cause_id + self.left_padding_side - 1]['text']
          span = self.find_sub_list(cause_span, cause_text)
          casual_sentence_index_list_label[slot] = 1
          # Keep the [0, 0] placeholder when the span is not found verbatim;
          # storing None would break tensor conversion later.
          casual_sentence_pos_list_label[slot] = span if span is not None else [0, 0]

        self.dataset.append({'conversation_ID': conversation_ID,
                             'utterance_ID': utterance_ID,
                             'text': text,
                             'speaker': target['speaker'],
                             'emotion': target.get('emotion', 'neutral'),
                             'casual_sentence_index': casual_sentence_index_list,
                             'sentence': sentence_list,
                             'casual_sentence_index_label': casual_sentence_index_list_label,
                             'casual_sentence_pos_label': casual_sentence_pos_list_label,
                             })

  def find_sub_list(self, sub_sentence,sentence):
    """Return ``[start, end]`` word indices of ``sub_sentence`` in ``sentence``.

    Both strings are split on whitespace. Returns ``None`` when the span is
    empty or does not occur.
    """
    words = sentence.split()
    span = sub_sentence.split()
    if not span:
      return None
    span_length = len(span)
    for start in range(len(words) - span_length + 1):
      if words[start:start + span_length] == span:
        return [start, start + span_length]
    return None

  def padding_utterance(self):
    """Pad every conversation in place with empty neutral utterances.

    Left padding gets the ids -5..0 and right padding continues after the last
    real id (l+1, l+2), so no padding id collides with a real utterance id.
    """
    if getattr(self, '_is_padded', False):
      return
    for record in self.text_data:
      conversation = record['conversation']
      l = len(conversation)
      for j in range(self.left_padding_side):
        conversation.insert(0, {"utterance_ID": -j,
                                "text": "",
                                "speaker": "",
                                "emotion": "neutral"})
      for j in range(self.right_padding_side):
        # NOTE(review): assumes real utterance ids run 1..l — the previous
        # `l+j` reused the id of the last real utterance.
        conversation.append({"utterance_ID": l + 1 + j,
                             "text": "",
                             "speaker": "",
                             "emotion": "neutral"})
    self._is_padded = True

  def padding_casual_sentence(self, sentence_list, casual_sentence_pos_list, casual_sentence_list):
    """Extend the three lists in place with empty entries up to length 8 and return them."""
    padding_side = 8 - len(sentence_list)
    for _ in range(padding_side):
      sentence_list.append('')
      casual_sentence_pos_list.append([0,0])
      casual_sentence_list.append('')
    return sentence_list, casual_sentence_pos_list, casual_sentence_list

  def emotion_EDA(self):
    """Print the distribution of emotion-to-cause distances and return the
    number of real utterances per emotion label.

    Padding utterances are excluded from the counts; conversations without
    ``'emotion-cause_pairs'`` and utterances without ``'emotion'`` are handled
    (the latter count as neutral, as in build()).
    """
    self.emotion_counts = {"neutral": 0, 'surprise': 0, 'anger': 0, "disgust": 0, "fear": 0, "joy": 0, "sadness": 0}
    print("DISTANCE")
    dictation = {}
    for record in self.text_data:
      for pair in record.get('emotion-cause_pairs', []):
        emotion = int(pair[0].split("_")[0])
        casual_emotion = int(pair[1].split("_")[0])
        string = str(emotion-casual_emotion)
        dictation[string] = dictation.get(string, 0) + 1
    dictation = sorted(dictation.items(), key=lambda x:x[1])
    print(dictation)
    print("COUNTING")
    for record in self.text_data:
      conversation = record['conversation']
      if getattr(self, '_is_padded', False):
        # Drop the artificial utterances so they do not inflate "neutral".
        conversation = conversation[self.left_padding_side:len(conversation) - self.right_padding_side]
      for utterance in conversation:
        self.emotion_counts[utterance.get('emotion', 'neutral')] += 1
    return self.emotion_counts
  
class Dataset(Dataset):
  """Map-style dataset that tokenises the records built by ProcessDataset.

  NOTE(review): this class deliberately keeps the name of the imported
  ``torch.utils.data.Dataset`` base it subclasses, because later cells refer to
  it as ``Dataset``; the definition cell therefore cannot simply be re-run
  without re-importing the base class first.
  """

  def __init__(self, dataset, tokenizer=None, max_length=512):
    """Store the records, the tokenizer and the padded sequence length.

    Args:
      dataset: List of record dicts (see ``__getitem__`` for the keys read).
      tokenizer: Callable used to encode ``'text'`` and ``'sentence'``.
      max_length: Length every encoded sequence is padded/truncated to.
    """
    self.dataset = dataset
    self.tokenizer = tokenizer
    self.max_length = max_length
    self.text_to_number_mapping = {"anger": 0,
                                    "disgust": 1,
                                    "joy": 2,
                                    "neutral": 3,
                                    "fear": 4,
                                    "sadness": 5,
                                    "surprise": 6}

  def __len__(self):
    return len(self.dataset)

  def _encode(self, text):
    """Tokenise ``text`` (a string or list of strings) to ``self.max_length``."""
    return self.tokenizer(text, padding="max_length", max_length=self.max_length, truncation=True, return_tensors="pt")

  def __getitem__(self, index):
    """Return the record at ``index`` with text tokenised and labels as tensors.

    The emotion label becomes a one-hot float vector over the 7 classes; the
    candidate ids and both label lists become float tensors.
    """
    item = self.dataset[index]
    emotion_binary_tensor = torch.zeros(len(self.text_to_number_mapping))
    emotion_binary_tensor[self.text_to_number_mapping[item['emotion']]] = 1
    return {'conversation_ID': item['conversation_ID'],
            'utterance_ID': item['utterance_ID'],
            'text': self._encode(item['text']),
            'speaker': item['speaker'],
            'emotion': emotion_binary_tensor,
            'casual_sentence_index': torch.FloatTensor(item['casual_sentence_index']),
            # Pad the candidate sentences to the same configured length as
            # 'text' instead of the tokenizer's own default length.
            'sentence': self._encode(item['sentence']),
            'casual_sentence_index_label': torch.FloatTensor(item['casual_sentence_index_label']),
            'casual_sentence_pos_label': torch.FloatTensor(item['casual_sentence_pos_label']),
            }

# **MODEL**

In [11]:
class VH(nn.Module):
  """Joint emotion / cause-utterance / cause-span model on top of a pretrained encoder.

  The encoder's ``embeddings``, ``encoder`` and ``pooler`` sub-modules are
  reused directly. The context text is encoded once; each candidate sentence is
  only embedded (not encoded) and compared with the context's hidden states.
  """

  def __init__(self, encoder, hidden_size, label_dim, padding_side, loss_fn, device):
    """Args:
      encoder: Object exposing ``embeddings``, ``encoder`` and ``pooler``.
      hidden_size: Token length of every encoded sequence (the code expands and
        flattens tensors with this as the sequence dimension).
      label_dim: Number of emotion classes.
      padding_side: Number of candidate cause sentences per sample.
      loss_fn: Callable ``loss_fn(prediction, target)`` applied to all three heads.
      device: Device the input tensors are moved to.
    """
    super(VH, self).__init__()
    self.embedding_dim = 768
    self.hidden_size = hidden_size
    self.label_dim = label_dim
    self.padding_side = padding_side
    self.loss_fn = loss_fn
    self.device = device
    self.build(encoder)

  def build(self, encoder):
    """Attach the encoder parts and create the task-specific layers."""
    self.embeddings = encoder.embeddings
    self.encoder = encoder.encoder
    self.pooler = encoder.pooler
    self.emotion_fc = nn.Linear(self.embedding_dim, self.label_dim)
    self.casual_matching_embed = nn.Linear(self.embedding_dim, self.embedding_dim)
    self.relu = nn.ReLU()
    self.tanh = nn.Tanh()
    self.casual_pos_retrieval = nn.Linear(self.embedding_dim*self.hidden_size, 2)
    self.emotion_softmax = nn.Softmax(dim=1)

    self.norm_768 = nn.LayerNorm(768)

  def pos_acc_func(self, pred, target):
    """Count the (sample, candidate) slots where both position values match.

    Uses the element-wise ``&``; Python's ``and`` raises on multi-element
    tensors. Returns a 0-dim integer tensor.
    """
    both_match = (pred[:, :, 0] == target[:, :, 0]) & (pred[:, :, 1] == target[:, :, 1])
    return torch.sum(both_match)

  def forward(self, x):
    """Compute the summed loss of the three heads together with their predictions.

    Args:
      x: Batch dict with ``'text'`` and ``'sentence'`` (each holding
        ``'input_ids'``), ``'casual_sentence_index'``, ``'emotion'``,
        ``'casual_sentence_index_label'`` and ``'casual_sentence_pos_label'``.

    Returns:
      ``(loss, emotion_pred, sentence_index_pred, sentence_pos_pred)``.
    """
    emotion_pred, sentence_index_pred, sentence_pos_pred = self.predict(x)

    emotion = x['emotion'].to(self.device)
    casual_sentence_index_label = x['casual_sentence_index_label'].to(self.device)
    casual_sentence_pos_label = x['casual_sentence_pos_label'].to(self.device)

    emotion_loss = self.loss_fn(emotion_pred, emotion)
    index_loss = self.loss_fn(sentence_index_pred, casual_sentence_index_label)
    pos_loss = self.loss_fn(sentence_pos_pred, casual_sentence_pos_label)

    loss = emotion_loss + index_loss + pos_loss
    return loss, emotion_pred, sentence_index_pred, sentence_pos_pred

  def emotion_predict(self, hidden_state):
    """Return class probabilities from the pooled last hidden state.

    NOTE(review): the output is already soft-maxed; the notebook passes
    ``nn.CrossEntropyLoss`` as ``loss_fn``, which applies log-softmax itself —
    confirm whether logits should be returned instead.
    """
    emotion = self.emotion_fc(self.pooler(hidden_state['last_hidden_state']))
    return self.emotion_softmax(emotion)

  def casual_emotion(self, hidden_state, sentence, casual_sentence_index, batch_size):
    """Score every candidate sentence against the encoded context.

    Args:
      hidden_state: Encoder output; ``'last_hidden_state'`` must be
        ``(batch, hidden_size, embedding_dim)``.
      sentence: Token ids of the candidates, ``(batch, padding_side, seq)``.
      casual_sentence_index: Candidate utterance ids, ``(batch, padding_side)``.
      batch_size: Number of samples in the batch.

    Returns:
      ``(sentence_index_pred, sentence_pos_pred)``.
    """
    # Embed each candidate slot and stack along a new candidate axis; this
    # yields (batch, padding_side, seq, dim) without a pre-allocated buffer.
    sentence_pool = torch.stack(
        [self.embeddings(sentence[:, i, :]) for i in range(self.padding_side)], dim=1)

    # Broadcast the context over the candidates through a local view; the
    # encoder output object passed in is left untouched.
    context = hidden_state['last_hidden_state'].unsqueeze(1)
    expanded_tensor = context.expand(batch_size, self.padding_side, self.hidden_size, self.embedding_dim)

    sentence_index_pred = self.index_predict(expanded_tensor, sentence_pool, casual_sentence_index, batch_size)
    sentence_pos_pred = self.pos_predict(expanded_tensor, sentence_pool, batch_size)
    return sentence_index_pred, sentence_pos_pred

  def pos_predict(self, expanded_tensor, sentence_pool, batch_size):
    """Predict two position values per candidate from the flattened difference
    between candidate embeddings and context states."""
    minus = sentence_pool.reshape(batch_size, self.padding_side, -1) - expanded_tensor.reshape(batch_size, self.padding_side, -1)
    effect_score = self.casual_pos_retrieval(minus)
    return effect_score

  def index_predict(self, expanded_tensor, sentence_pool, casual_sentence_index, batch_size):
    """Return the candidate ids gated by a similarity score.

    NOTE(review): ``ceil`` has zero gradient almost everywhere, so this head
    does not appear to train ``casual_matching_embed`` — confirm intent.
    """
    # Similarity between the input text and each candidate cause sentence.
    sentence_pool = self.norm_768(self.casual_matching_embed(sentence_pool))
    cosine = F.cosine_similarity(sentence_pool.reshape(batch_size, self.padding_side, -1), expanded_tensor.reshape(batch_size, self.padding_side, -1), dim=2)

    # Drop candidates with negative similarity: after tanh the score lies in
    # (-1, 1), so ceil() maps positive scores to 1 and the rest to 0 (or -0),
    # and that gate multiplies the candidate id.
    effect_score = self.tanh(cosine)
    casual_sentence_index = effect_score.ceil() * casual_sentence_index

    return casual_sentence_index

  def predict(self, x):
    """Run the three heads without computing a loss.

    Returns:
      ``(emotion_pred, sentence_index_pred, sentence_pos_pred)``.
    """
    batch_size = x['text']['input_ids'].size(0)
    text = x['text']['input_ids'].squeeze(1).to(self.device)
    sentence = x['sentence']['input_ids'].to(self.device)
    casual_sentence_index = x['casual_sentence_index'].to(self.device)

    # NOTE(review): no attention mask is passed, so padding tokens take part
    # in the encoder's attention — confirm this is intended.
    hidden_state = self.encoder(self.embeddings(text))

    emotion_pred = self.emotion_predict(hidden_state)
    sentence_index_pred, sentence_pos_pred = self.casual_emotion(hidden_state, sentence, casual_sentence_index, batch_size)

    return emotion_pred, sentence_index_pred, sentence_pos_pred

# jlkjl

# **TRAINING**

In [12]:
# Build the windowed utterance records from the training file, split them
# sequentially 80/20, and wrap each part in a tokenising Dataset.
train_processor = ProcessDataset(file=file_train, format="ECF 2.0", subtask="subtask_1", tokenizer=tokenizer, max_length=max_length)
train_records, valid_records = train_processor.split_train_test(test_size=0.2)
train_dataset = Dataset(train_records, tokenizer=tokenizer, max_length=max_length)
valid_dataset = Dataset(valid_records, tokenizer=tokenizer, max_length=max_length)
# Records are stored conversation by conversation, so shuffle the training
# batches; validation keeps its order.
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(valid_dataset, batch_size=16)

N_TRAIN: 10880, N_TEST: 2739


In [13]:
# Use the first GPU when one is available and fall back to the CPU otherwise.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
loss_fn = nn.CrossEntropyLoss()
# Arguments: hidden_size=512 (token length), label_dim=7 emotions,
# padding_side=8 candidate sentences.
model = VH(encoder, 512, 7, 8, loss_fn, device)
# Move the model to its device before the optimizer captures the parameters.
model = model.to(device)
optimizer = Adam(model.parameters(), lr=0.001)

In [14]:
# Train for up to 100 epochs; Trainer.train writes a checkpoint into the
# checkpoint directory whenever the validation loss improves.
# NOTE(review): these directories are hard-coded here and differ from
# save_checkpoint_dir / save_model_dir / save_report_dir in the setup cell; the
# prediction cell below loads from this same checkpoint directory.
trainer = Trainer(model=model, optimizer=optimizer, save_checkpoint_dir="/root/NAACL/data/checkpoint", save_model_dir="/root/NAACL/data/model", save_report_dir="/root/NAACL/data/report")
trainer.train(train_loader=train_loader, val_loader=val_loader, epochs=100)

Training epoch 0/100:   0%|          | 1/680 [00:00<04:55,  2.30it/s]


RuntimeError: CUDA out of memory. Tried to allocate 192.00 MiB (GPU 0; 10.76 GiB total capacity; 9.02 GiB already allocated; 138.56 MiB free; 9.17 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

# **PREDICTING**

In [None]:
# Build the windowed records for the trial split.
trial_dataset = ProcessDataset(file=file_trial, format="ECF 2.0", subtask="subtask_1", tokenizer=tokenizer, max_length=max_length)
# Keep a handle on the loaded conversations; the predictor appends its
# emotion-cause pairs to them. These already contain the padding utterances
# that build() inserted in place.
raw_trial_dataset = trial_dataset.get_data()
# test_size=0 returns the full record list (a single list, not a tuple).
trial_dataset = trial_dataset.split_train_test(test_size=0)
trial_dataset = Dataset(trial_dataset, tokenizer=tokenizer, max_length=max_length)
trial_loader = DataLoader(trial_dataset, batch_size=16)

In [None]:
# Use the first GPU when one is available and fall back to the CPU otherwise.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
loss_fn = nn.CrossEntropyLoss()
# Fresh task heads for prediction; the weights are restored from a checkpoint
# by the Predictor in the next cell.
# NOTE(review): `encoder` is the same object that was handed to the training
# model, so its sub-modules are shared with it within one session.
model = VH(encoder, 512, 7, 8, loss_fn, device)
model = model.to(device)

In [None]:
# Passing a checkpoint path makes the Predictor constructor load that file's
# 'model_state_dict' into the model.
predictor = Predictor(model, output_json, subtask='subtask_1', batch_size=16, save_checkpoint_dir="/root/NAACL/data/checkpoint/checkpoint_2.pth", save_model_dir=None)
# Runs the model over the trial loader and stores the conversations, extended
# with predicted 'emotion-cause_pairs', on predictor.dataset.
predictor.predict(raw_trial_dataset, trial_loader)

In [None]:
# Display the conversations that predict() stored, including the predicted pairs.
predictor.dataset