In [22]:

# Tab-separated dialogue files read by DATA_loader for the train / validation / test splits.
training_data_path='../data/train_t.txt'
validation_data_path='../data/val_t.txt'
testing_data_path='../data/testy_t.txt'
# File the training loop writes the test-set predictions to.
answer_path='../data/answer.txt'

In [4]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import os, sys
import math
import pandas as pd
import pdb
import string
import re
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from transformers import get_linear_schedule_with_warmup
import argparse, logging
from sklearn.metrics import precision_recall_fscore_support
from transformers import RobertaTokenizer, RobertaModel
from transformers import BertTokenizer, BertModel

MODEL CLASS

In [5]:
class ERC_model(nn.Module):
    """BERT encoder followed by a linear classifier on the first-token hidden state."""

    def __init__(self, model_type, cls_num):
        """
        model_type : name or path of the pretrained BERT checkpoint to load
                     (e.g. "bert-base-multilingual-uncased"). A falsy value
                     falls back to "bert-base-multilingual-uncased".
        cls_num    : number of classes for classification (output size of the
                     linear head).
        """
        super(ERC_model, self).__init__()
        self.gpu = True

        # Honour the requested checkpoint instead of ignoring the argument.
        # No tokenizer is created here: tokenization happens in the collate
        # function, and the one previously built here was discarded.
        checkpoint = model_type or "bert-base-multilingual-uncased"
        self.model = BertModel.from_pretrained(checkpoint)

        self.hiddenDim = self.model.config.hidden_size

        self.W = nn.Linear(self.hiddenDim, cls_num)

    def forward(self, batch_input_tokens):
        """
        batch_input_tokens : (batch, len) tensor of token ids; shorter rows are
                             expected to be filled with the model's pad token id.

        Returns context_logit, the unnormalized predictions for each class,
        shape (batch, cls_num).
        """
        # Mask padded positions so they are not attended to when a batch mixes
        # sequences of different lengths. With batch size 1 there is no padding
        # and the mask is all ones, i.e. the encoder's default behaviour.
        pad_token_id = self.model.config.pad_token_id
        attention_mask = None
        if pad_token_id is not None:
            attention_mask = batch_input_tokens.ne(pad_token_id).long()

        encoder_output = self.model(batch_input_tokens, attention_mask=attention_mask)
        batch_context_output = encoder_output.last_hidden_state[:, 0, :]  # (batch, hiddenDim)

        context_logit = self.W(batch_context_output)  # (batch, cls_num)

        return context_logit

UTILS



> **Cleaning function**  -  preprocesses the text of each utterance in a conversation: it strips ASCII punctuation, removes English / romanized-Hindi stopwords, and then deletes digits.



In [24]:
# The lookup structures below are built once at definition time; rebuilding the
# stopword set on every call is wasteful because `cleaning` runs once per utterance.
_CLEANING_PUNCT_TABLE = str.maketrans('', '', string.punctuation)
_CLEANING_DIGITS_RE = re.compile('[0-9]+')
# Entries containing an apostrophe can never match, because punctuation is
# stripped before stopword filtering; they are kept so the list stays unchanged.
_CLEANING_STOPWORDS = frozenset([
    'a', 'aadi', 'aaj', 'aap', 'aapne', 'aata', 'aati', 'aaya', 'aaye', 'ab', 'abbe', 'abbey', 'abe', 'abhi',
    'able', 'about', 'above', 'accha', 'according', 'accordingly', 'acha', 'achcha', 'across', 'actually',
    'after', 'afterwards', 'again', 'against', 'agar', 'ain', 'aint', "ain't", 'aisa', 'aise', 'aisi', 'alag',
    'all', 'allow', 'allows', 'almost', 'alone', 'along', 'already', 'also', 'although', 'always', 'am',
    'among', 'amongst', 'an', 'and', 'andar', 'another', 'any', 'anybody', 'anyhow', 'anyone', 'anything',
    'anyway', 'anyways', 'anywhere', 'ap', 'apan', 'apart', 'apna', 'apnaa', 'apne', 'apni', 'appear', 'are',
    'aren', 'arent', "aren't", 'around', 'arre', 'as', 'aside', 'ask', 'asking', 'at', 'aur', 'avum', 'aya',
    'aye', 'baad', 'baar', 'bad', 'bahut', 'bana', 'banae', 'banai', 'banao', 'banaya', 'banaye', 'banayi',
    'banda', 'bande', 'bandi', 'bane', 'bani', 'bas', 'bata', 'batao', 'bc', 'be', 'became', 'because',
    'become', 'becomes', 'becoming', 'been', 'before', 'beforehand', 'behind', 'being', 'below', 'beside',
    'besides', 'best', 'better', 'between', 'beyond', 'bhai', 'bheetar', 'bhi', 'bhitar', 'bht', 'bilkul',
    'bohot', 'bol', 'bola', 'bole', 'boli', 'bolo', 'bolta', 'bolte', 'bolti', 'both', 'brief', 'bro', 'btw',
    'but', 'by', 'came', 'can', 'cannot', 'cant', "can't", 'cause', 'causes', 'certain', 'certainly',
    'chahiye', 'chaiye', 'chal', 'chalega', 'chhaiye', 'clearly', "c'mon", 'com', 'come', 'comes', 'could',
    'couldn', 'couldnt', "couldn't", 'd', 'de', 'dede', 'dega', 'degi', 'dekh', 'dekha', 'dekhe', 'dekhi',
    'dekho', 'denge', 'dhang', 'di', 'did', 'didn', 'didnt', "didn't", 'dijiye', 'diya', 'diyaa', 'diye',
    'diyo', 'do', 'does', 'doesn', 'doesnt', "doesn't", 'doing', 'done', 'dono', 'dont', "don't", 'doosra',
    'doosre', 'down', 'downwards', 'dude', 'dunga', 'dungi', 'during', 'dusra', 'dusre', 'dusri', 'dvaara',
    'dvara', 'dwaara', 'dwara', 'each', 'edu', 'eg', 'eight', 'either', 'ek', 'else', 'elsewhere', 'enough',
    'etc',
    'even', 'ever', 'every', 'everybody', 'everyone', 'everything', 'everywhere', 'ex', 'exactly', 'example',
    'except', 'far', 'few', 'fifth', 'fir', 'first', 'five', 'followed', 'following', 'follows', 'for',
    'forth', 'four', 'from', 'further', 'furthermore', 'gaya', 'gaye', 'gayi', 'get', 'gets', 'getting',
    'ghar', 'given', 'gives', 'go', 'goes', 'going', 'gone', 'good', 'got', 'gotten', 'greetings', 'haan',
    'had', 'hadd', 'hadn', 'hadnt', "hadn't", 'hai', 'hain', 'hamara', 'hamare', 'hamari', 'hamne', 'han',
    'happens', 'har', 'hardly', 'has', 'hasn', 'hasnt', "hasn't", 'have', 'haven', 'havent', "haven't",
    'having', 'he', 'hello', 'help', 'hence', 'her', 'here', 'hereafter', 'hereby', 'herein', "here's",
    'hereupon', 'hers', 'herself', "he's", 'hi', 'him', 'himself', 'his', 'hither', 'hm', 'hmm', 'ho', 'hoga',
    'hoge', 'hogi', 'hona', 'honaa', 'hone', 'honge', 'hongi', 'honi', 'hopefully', 'hota', 'hotaa', 'hote',
    'hoti', 'how', 'howbeit', 'however', 'hoyenge', 'hoyengi', 'hu', 'hua', 'hue', 'huh', 'hui', 'hum',
    'humein', 'humne', 'hun', 'huye', 'huyi', 'i', "i'd", 'idk', 'ie', 'if', "i'll", "i'm", 'imo', 'in',
    'inasmuch', 'inc', 'inhe', 'inhi', 'inho', 'inka', 'inkaa', 'inke', 'inki', 'inn', 'inner', 'inse',
    'insofar', 'into', 'inward', 'is', 'ise', 'isi', 'iska', 'iskaa', 'iske', 'iski', 'isme', 'isn', 'isne',
    'isnt', "isn't", 'iss', 'isse', 'issi', 'isski', 'it', "it'd", "it'll", 'itna', 'itne', 'itni', 'itno',
    'its', "it's", 'itself', 'ityaadi', 'ityadi', "i've", 'ja', 'jaa', 'jab', 'jabh', 'jaha', 'jahaan',
    'jahan', 'jaisa', 'jaise', 'jaisi', 'jata', 'jayega', 'jidhar', 'jin', 'jinhe', 'jinhi', 'jinho',
    'jinhone', 'jinka', 'jinke', 'jinki', 'jinn', 'jis', 'jise', 'jiska', 'jiske', 'jiski', 'jisme', 'jiss',
    'jisse', 'jitna', 'jitne', 'jitni', 'jo', 'just', 'jyaada', 'jyada', 'k', 'ka', 'kaafi', 'kab', 'kabhi',
    'kafi', 'kaha', 'kahaa', 'kahaan', 'kahan', 'kahi', 'kahin', 'kahte', 'kaisa', 'kaise', 'kaisi', 'kal',
    'kam', 'kar', 'kara', 'kare', 'karega', 'karegi',
    'karen', 'karenge', 'kari', 'karke', 'karna', 'karne', 'karni', 'karo', 'karta', 'karte', 'karti', 'karu',
    'karun', 'karunga', 'karungi', 'kaun', 'kaunsa', 'kayi', 'kch', 'ke', 'keep', 'keeps', 'keh', 'kehte',
    'kept', 'khud', 'ki', 'kin', 'kine', 'kinhe', 'kinho', 'kinka', 'kinke', 'kinki', 'kinko', 'kinn', 'kino',
    'kis', 'kise', 'kisi', 'kiska', 'kiske', 'kiski', 'kisko', 'kisliye', 'kisne', 'kitna', 'kitne', 'kitni',
    'kitno', 'kiya', 'kiye', 'know', 'known', 'knows', 'ko', 'koi', 'kon', 'konsa', 'koyi', 'krna', 'krne',
    'kuch', 'kuchch', 'kuchh', 'kul', 'kull', 'kya', 'kyaa', 'kyu', 'kyuki', 'kyun', 'kyunki', 'lagta',
    'lagte', 'lagti', 'last', 'lately', 'later', 'le', 'least', 'lekar', 'lekin', 'less', 'lest', 'let',
    "let's", 'li', 'like', 'liked', 'likely', 'little', 'liya', 'liye', 'll', 'lo', 'log', 'logon', 'lol',
    'look', 'looking', 'looks', 'ltd', 'lunga', 'm', 'maan', 'maana', 'maane', 'maani', 'maano', 'magar',
    'mai', 'main', 'maine', 'mainly', 'mana', 'mane', 'mani', 'mano', 'many', 'mat', 'may', 'maybe', 'me',
    'mean', 'meanwhile', 'mein', 'mera', 'mere', 'merely', 'meri', 'might', 'mightn', 'mightnt', "mightn't",
    'mil', 'mjhe', 'more', 'moreover', 'most', 'mostly', 'much', 'mujhe', 'must', 'mustn', 'mustnt',
    "mustn't", 'my', 'myself', 'na', 'naa', 'naah', 'nahi', 'nahin', 'nai', 'name', 'namely', 'nd', 'ne',
    'near', 'nearly', 'necessary', 'neeche', 'need', 'needn', 'neednt', "needn't", 'needs', 'neither',
    'never', 'nevertheless', 'new', 'next', 'nhi', 'nine', 'no', 'nobody', 'non', 'none', 'noone', 'nope',
    'nor', 'normally', 'not', 'nothing', 'novel', 'now', 'nowhere', 'o', 'obviously', 'of', 'off', 'often',
    'oh', 'ok', 'okay', 'old', 'on', 'once', 'one', 'ones', 'only', 'onto', 'or', 'other', 'others',
    'otherwise', 'ought', 'our', 'ours', 'ourselves', 'out', 'outside', 'over', 'overall', 'own', 'par',
    'pata', 'pe', 'pehla', 'pehle', 'pehli', 'people', 'per', 'perhaps', 'phla', 'phle', 'phli', 'placed',
    'please', 'plus', 'poora', 'poori', 'provides',
    'pura', 'puri', 'q', 'que', 'quite', 'raha', 'rahaa', 'rahe', 'rahi', 'rakh', 'rakha', 'rakhe', 'rakhen',
    'rakhi', 'rakho', 'rather', 're', 'really', 'reasonably', 'regarding', 'regardless', 'regards', 'rehte',
    'rha', 'rhaa', 'rhe', 'rhi', 'ri', 'right', 's', 'sa', 'saara', 'saare', 'saath', 'sab', 'sabhi', 'sabse',
    'sahi', 'said', 'sakta', 'saktaa', 'sakte', 'sakti', 'same', 'sang', 'sara', 'sath', 'saw', 'say',
    'saying', 'says', 'se', 'second', 'secondly', 'see', 'seeing', 'seem', 'seemed', 'seeming', 'seems',
    'seen', 'self', 'selves', 'sensible', 'sent', 'serious', 'seriously', 'seven', 'several', 'shall', 'shan',
    'shant', "shan't", 'she', "she's", 'should', 'shouldn', 'shouldnt', "shouldn't", "should've", 'si',
    'since', 'six', 'so', 'soch', 'some', 'somebody', 'somehow', 'someone', 'something', 'sometime',
    'sometimes', 'somewhat', 'somewhere', 'soon', 'still', 'sub', 'such', 'sup', 'sure', 't', 'tab', 'tabh',
    'tak', 'take', 'taken', 'tarah', 'teen', 'teeno', 'teesra', 'teesre', 'teesri', 'tell', 'tends', 'tera',
    'tere', 'teri', 'th', 'tha', 'than', 'thank', 'thanks', 'thanx', 'that', "that'll", 'thats', "that's",
    'the', 'theek', 'their', 'theirs', 'them', 'themselves', 'then', 'thence', 'there', 'thereafter',
    'thereby', 'therefore', 'therein', 'theres', "there's", 'thereupon', 'these', 'they', "they'd",
    "they'll", "they're", "they've", 'thi', 'thik', 'thing', 'think', 'thinking', 'third', 'this', 'tho',
    'thoda', 'thodi', 'thorough', 'thoroughly', 'those', 'though', 'thought', 'three', 'through',
    'throughout', 'thru', 'thus', 'tjhe', 'to', 'together', 'toh', 'too', 'took', 'toward', 'towards',
    'tried', 'tries', 'true', 'truly', 'try', 'trying', 'tu', 'tujhe', 'tum', 'tumhara', 'tumhare', 'tumhari',
    'tune', 'twice', 'two', 'um', 'umm', 'un', 'under', 'unhe', 'unhi', 'unho', 'unhone', 'unka', 'unkaa',
    'unke', 'unki', 'unko', 'unless', 'unlikely', 'unn', 'unse', 'until', 'unto', 'up', 'upar', 'upon', 'us',
    'use', 'used', 'useful', 'uses', 'usi', 'using', 'uska',
    'uske', 'usne', 'uss', 'usse', 'ussi', 'usually', 'vaala', 'vaale', 'vaali', 'vahaan', 'vahan', 'vahi',
    'vahin', 'vaisa', 'vaise', 'vaisi', 'vala', 'vale', 'vali', 'various', 've', 'very', 'via', 'viz', 'vo',
    'waala', 'waale', 'waali', 'wagaira', 'wagairah', 'wagerah', 'waha', 'wahaan', 'wahan', 'wahi', 'wahin',
    'waisa', 'waise', 'waisi', 'wala', 'wale', 'wali', 'want', 'wants', 'was', 'wasn', 'wasnt', "wasn't",
    'way', 'we', "we'd", 'well', "we'll", 'went', 'were', "we're", 'weren', 'werent', "weren't", "we've",
    'what', 'whatever', "what's", 'when', 'whence', 'whenever', 'where', 'whereafter', 'whereas', 'whereby',
    'wherein', "where's", 'whereupon', 'wherever', 'whether', 'which', 'while', 'who', 'whoever', 'whole',
    'whom', "who's", 'whose', 'why', 'will', 'willing', 'with', 'within', 'without', 'wo', 'woh', 'wohi',
    'won', 'wont', "won't", 'would', 'wouldn', 'wouldnt', "wouldn't", 'y', 'ya', 'yadi', 'yah', 'yaha',
    'yahaan', 'yahan', 'yahi', 'yahin', 'ye', 'yeah', 'yeh', 'yehi', 'yes', 'yet', 'you', "you'd", "you'll",
    'your', "you're", 'yours', 'yourself', 'yourselves', "you've", 'yup',
])


def cleaning(text):
    """Normalize one utterance before tokenization.

    Steps, applied in this order:
      1. delete every character in ``string.punctuation`` (ASCII punctuation only);
      2. split on whitespace, drop tokens that are in the stopword list, and
         re-join the rest with single spaces (the comparison is case-sensitive
         and the list is lowercase, so "The" is kept while "the" is dropped);
      3. delete every run of ASCII digits.

    Because digits are removed after re-joining, a purely numeric token leaves
    its neighbouring space behind (e.g. "room 101" -> "room ").

    text    : the utterance as a ``str``.
    returns : the cleaned string.
    """
    without_punctuation = text.translate(_CLEANING_PUNCT_TABLE)
    kept_words = [
        word for word in str(without_punctuation).split()
        if word not in _CLEANING_STOPWORDS
    ]
    return _CLEANING_DIGITS_RE.sub('', " ".join(kept_words))


>  **encode_right_truncated function** :

*   Tokenizes the input text using the provided tokenizer.
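*   If the tokenized sequence is longer than '*max_length*', keeps only the last '*max_length*' tokens (the earliest tokens are dropped, so the most recent part of the context is preserved).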
*   Converts the tokens to their corresponding IDs.
*   Returns a list of token IDs starting with the ID of the [CLS] token followed by the truncated token IDs.

>  **padding function** :

*   Calculates the maximum length among the token ID lists in ids_list.
*   Pads each token ID list with the padding token ID from the tokenizer to match the maximum length.
*   Returns a PyTorch tensor of padded token ID sequences.



In [7]:

# Module-level mBERT tokenizer; make_batch_bert uses it for encoding and padding.
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-uncased')


def encode_right_truncated(text, tokenizer, max_length=511):
    """Tokenize ``text`` and keep at most the last ``max_length`` tokens.

    When the text is too long the tokens at the start are dropped, so the
    right-hand (most recent) end of the text survives.

    text       : string to encode.
    tokenizer  : object providing ``tokenize``, ``convert_tokens_to_ids`` and
                 ``cls_token_id``.
    max_length : maximum number of tokens kept, not counting the leading
                 [CLS] id; values <= 0 keep no tokens.
    returns    : list of ids, ``[cls_token_id]`` followed by the kept token ids.
    """
    tokens = tokenizer.tokenize(text)
    # `tokens[-0:]` is the whole list, so a zero budget needs an explicit branch.
    kept = tokens[-max_length:] if max_length > 0 else []
    return [tokenizer.cls_token_id] + tokenizer.convert_tokens_to_ids(kept)



def padding(ids_list, tokenizer):
    """Right-pad every id list to the longest one and stack them into a tensor.

    ids_list  : non-empty list of lists of token ids.
    tokenizer : object whose ``pad_token_id`` is used as the fill value.
    returns   : ``torch.tensor`` with one row per input list, each row as long
                as the longest input list.
    """
    target_len = max(map(len, ids_list))

    padded_rows = [
        row + [tokenizer.pad_token_id] * (target_len - len(row))
        for row in ids_list
    ]

    return torch.tensor(padded_rows)




>  **make_batch_roberta and make_batch_bert collate function** :

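*   Prepares a batch of inputs and labels for the mBERT model (`make_batch_bert` is the collate function defined in this notebook).
*   For each session, joins the dialogue turns into one string with `<s1>`, `<s2>`, … speaker tags, tokenizes it with the mBERT tokenizer, and pads the token IDs to the longest sequence in the batch.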
*   Assigns labels based on the provided data.
*   Returns the batch of input tokens and corresponding labels as PyTorch tensors.




In [11]:


def make_batch_bert(sessions):
    """Collate function turning dataset items into (token tensor, label tensor).

    sessions : list of dataset items. ``item[0]`` is
               ``[context_speaker, context, emotion, sentiment]`` and
               ``item[1]`` is the list of label names.

    Each dialogue is flattened to ``"<s1> utt <s2> utt ..."`` (speaker indices
    are shifted to start at 1), encoded with the module-level ``bert_tokenizer``
    and padded to the longest sequence in the batch.

    The label index is looked up in the label list: with more than three label
    names the emotion is used, otherwise the sentiment.

    returns : ``(batch_input_tokens, batch_labels)`` as tensors.
    """
    batch_input, batch_labels = [], []
    for session in sessions:
        data, label_list = session[0], session[1]
        context_speaker, context, emotion, sentiment = data

        # The per-speaker utterance encodings that used to be collected here
        # were never read, so that extra tokenization work has been dropped.
        concat_string = " ".join(
            '<s' + str(speaker + 1) + '> ' + utt  # s1, s2, s3...
            for speaker, utt in zip(context_speaker, context)
        ).strip()
        batch_input.append(encode_right_truncated(concat_string, bert_tokenizer))

        target_name = emotion if len(label_list) > 3 else sentiment
        batch_labels.append(label_list.index(target_name))

    batch_input_tokens = padding(batch_input, bert_tokenizer)
    batch_labels = torch.tensor(batch_labels)

    return batch_input_tokens, batch_labels



DATALOADER

In [12]:
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pack_sequence
import random

class DATA_loader(Dataset):
    """Dataset of dialogue prefixes read from a tab-separated text file.

    The first line of the file is skipped as a header. Every other non-blank
    line must hold four tab-separated fields: speaker, utterance, emotion,
    sentiment. A blank line ends the current dialogue.

    Each non-blank line yields one sample: the dialogue so far, up to and
    including that line, paired with that line's emotion and sentiment.
    """

    def __init__(self, txt_file, dataclass):
        """
        txt_file  : path of the tab-separated dialogue file.
        dataclass : accepted for interface compatibility; it is not read here,
                    and the label list is always the sorted set of emotions.
        """
        self.dialogs = []
        self.speakerNum = []

        # Explicit UTF-8: the utterances contain non-ASCII text, so the
        # platform default encoding must not be relied on.
        with open(txt_file, 'r', encoding='utf-8') as f:
            dataset = f.readlines()

        temp_speakerList = []
        context = []
        context_speaker = []

        # 'anger', 'disgust', 'fear', 'joy', 'neutral', 'sadness', 'surprise'
        # NOTE(review): 'sadness' is stored as "sad" here while sentidict below
        # lists "sadness"; left unchanged because stored labels depend on it.
        emodict = {'anger': "anger", 'disgust': "disgust", 'fear': "fear", 'joy': "joy", 'neutral': "neutral", 'sadness': "sad", 'surprise': 'surprise','contempt':"contempt"}
        self.sentidict = {'positive': ["joy"], 'negative': ["anger", "disgust", "fear", "sadness","contempt"], 'neutral': ["neutral", "surprise"]}
        self.emoSet = set()
        self.sentiSet = set()

        for i, data in enumerate(dataset):
            if i < 1:
                continue  # header line

            fields = data.strip().split('\t')

            if fields == ['']:
                # Blank line: close the current dialogue and start a new one.
                self.speakerNum.append(len(temp_speakerList))
                temp_speakerList = []
                context = []
                context_speaker = []
                continue

            if len(fields) != 4:
                raise ValueError(
                    "{}: line {} has {} tab-separated fields, expected 4 "
                    "(speaker, utterance, emotion, sentiment)".format(txt_file, i + 1, len(fields))
                )
            speaker, utt, emo, senti = fields

            utt = cleaning(utt)
            context.append(utt)
            # Speakers are numbered per dialogue in order of first appearance.
            if speaker not in temp_speakerList:
                temp_speakerList.append(speaker)
            speakerCLS = temp_speakerList.index(speaker)
            context_speaker.append(speakerCLS)

            # Copies are stored so later turns do not mutate earlier samples.
            self.dialogs.append([context_speaker[:], context[:], emodict[emo], senti])
            self.emoSet.add(emodict[emo])
            self.sentiSet.add(senti)

        self.emoList = sorted(self.emoSet)

        self.labelList = self.emoList

        # Speaker count of the last dialogue (no trailing blank line needed).
        self.speakerNum.append(len(temp_speakerList))

    def __len__(self):
        """Number of samples (one per utterance line in the file)."""
        return len(self.dialogs)

    def __getitem__(self, idx):
        """Return ``(sample, label list, sentiment dictionary)`` for ``idx``."""
        return self.dialogs[idx], self.labelList, self.sentidict

ACCURACY AND LOSS CALCULATE FUNCTION

In [13]:
def _CalACC(model, dataloader):
    model.eval()
    correct = 0
    label_list = []
    pred_list = []

    p1num, p2num, p3num = 0, 0, 0
    # label arragne
    with torch.no_grad():
        for i_batch, data in enumerate(dataloader):
            """Prediction"""
            batch_input_tokens, batch_labels = data
            batch_input_tokens, batch_labels = batch_input_tokens.cuda(), batch_labels.cuda()

            pred_logits = model(batch_input_tokens) # (1, clsNum)

            """Calculation"""
            pred_logits_sort = pred_logits.sort(descending=True)
            indices = pred_logits_sort.indices.tolist()[0]

            pred_label = indices[0] # pred_logits.argmax(1).item()
            true_label = batch_labels.item()

            pred_list.append(pred_label)
            label_list.append(true_label)
            if pred_label == true_label:
                correct += 1

            """Calculation precision"""
            if true_label in indices[:1]:
                p1num += 1
            if true_label in indices[:2]:
                p2num += 1/2
            if true_label in indices[:3]:
                p3num += 1/3

        p1 = round(p1num/len(dataloader)*100, 2)
        p2 = round(p2num/len(dataloader)*100, 2)
        p3 = round(p3num/len(dataloader)*100, 2)
    return [p1, p2, p3], pred_list, label_list

def _CalACCT(model, dataloader):
    model.eval()
    pred_list = []
    with torch.no_grad():
        for i_batch, data in enumerate(dataloader):
            """Prediction"""
            batch_input_tokens, batch_labels = data
            batch_input_tokens, batch_labels = batch_input_tokens.cuda(), batch_labels.cuda()

            pred_logits = model(batch_input_tokens) # (1, clsNum)

            """Calculation"""
            pred_logits_sort = pred_logits.sort(descending=True)
            indices = pred_logits_sort.indices.tolist()[0]

            pred_label = indices[0] # pred_logits.argmax(1).item()
            true_label = batch_labels.item()

            pred_list.append(pred_label)

    return pred_list

def CELoss(pred_outs, labels):
    """Cross-entropy loss with default settings.

        pred_outs: [batch, clsNum] unnormalized class scores
        labels: [batch] target class indices
    """
    return F.cross_entropy(pred_outs, labels)

In [14]:
def _SaveModel(model, path):
    if not os.path.exists(path):
        os.makedirs(path)
    torch.save(model.state_dict(),"./model_mBERT_no_transliteratec.pth" )


MAIN-TRAINING

(Platform used : Google Colab T4 GPU)

In [25]:
# Release cached GPU memory left over from earlier cell runs.
torch.cuda.empty_cache()
"""Dataset Loading"""
# --- Hyperparameters and run configuration ---
batch_size = 1        # training batch size
epoch=10              # number of training epochs (copied into training_epochs later)
norm=10               # max gradient norm for clipping (copied into max_grad_norm later)
lr=1e-6               # AdamW learning rate
dataclass = 'emotion' # passed to DATA_loader and used in the checkpoint directory name
sample = 1.0          # fraction of the training set used to compute train_sample_num
model_type = "bert-base-multilingual-uncased"
dataset="MELD"        # only used to build the checkpoint directory name
make_batch = make_batch_bert  # collate function shared by all three dataloaders



# Training split: shuffled every epoch.
train_dataset = DATA_loader(training_data_path,dataclass)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2, collate_fn=make_batch)
# Batch index after which the training loop stops an epoch early.
train_sample_num = int(len(train_dataset)*sample)

# Validation split: one sample per batch, file order preserved.
dev_dataset = DATA_loader(validation_data_path, dataclass)
dev_dataloader = DataLoader(dev_dataset, batch_size=1, shuffle=False, num_workers=2, collate_fn=make_batch)

# Test split: one sample per batch, file order preserved.
test_dataset = DATA_loader(testing_data_path, dataclass)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=2, collate_fn=make_batch)

In [26]:
# Peek at one training sample: [speaker indices, cleaned utterances, emotion, sentiment].
train_dataset[3][0]

[[0, 1, 2, 1],
 ['क्य क्य भर् के रख है इन्द्रवदन् ने इस् घर् मेइन् इन्द्रवदन् प्लेअसे तुम् सरि बेकर् कि चिजेन् बहर् क्योन् नहिन् फ़ेक्ते',
  'ओक् चलो रोसेश् चलो बहर्',
  'मोम्म हथ् छोदिये दद्',
  'देखो मय य न बोल् रह है फ़िर् तुम् हि कहोगि फ़ल्तु चिजेन् जम कर्ते हो फ़ेक्ते नहिन् हो चलो'],
 'neutral',
 'neutral']

In [29]:


print('DataClass: ', dataclass, '!!!') # emotion
# The classifier gets one output per emotion label found in the training file.
no_of_emotion_classes = len(train_dataset.labelList)


# Build the classifier with that number of output classes.
model = ERC_model(model_type, no_of_emotion_classes)

# Uncomment to resume from the checkpoint written by _SaveModel:
#model.load_state_dict(torch.load("./model_mBERT_no_transliteratec.pth"))


DataClass:  emotion !!!


Some weights of the model checkpoint at bert-base-multilingual-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [30]:

"""Training Setting"""

model = model.cuda()
model.train()

training_epochs = epoch
save_term = int(training_epochs/5)  # not read by any code visible in this notebook
max_grad_norm = norm

# The scheduler is stepped once per batch in the training loop, so its horizon
# is counted in batches, not samples. With batch_size = 1 the two are equal;
# counting batches keeps the schedule correct for larger batch sizes.
steps_per_epoch = len(train_dataloader)
num_training_steps = steps_per_epoch*training_epochs
num_warmup_steps = steps_per_epoch  # linear warm-up over the first epoch
optimizer = torch.optim.AdamW(model.parameters(), lr=lr) # , eps=1e-06, weight_decay=0.01
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
# Directory handed to _SaveModel, e.g. MELD_models/<model_type>/<dataclass>.
save_path = os.path.join(dataset+'_models', model_type, dataclass)

In [None]:
# History of test-set predictions, one list per epoch that improved the dev score.
test_preds_listes=[]

"""Input & Label Setting"""

best_dev_fscore, best_test_fscore = 0, 0
best_dev_fscore_macro, best_dev_fscore_micro, best_test_fscore_macro, best_test_fscore_micro = 0, 0, 0, 0
best_epoch = 0

# Class index -> emotion name used in the answer file.
# NOTE(review): this assumes all eight emotions occur in the training file, so
# that the sorted label list has exactly this order - confirm against the data.
id2emo={0:'anger',1: 'contempt', 2:'disgust', 3:'fear', 4:'joy',5: 'neutral', 6:'sadness', 7:'surprise'}

# A dedicated loop variable keeps the `epoch` hyperparameter intact, so
# re-running the training-settings cell still sees the configured value.
for epoch_idx in tqdm(range(training_epochs)):
    model.train()

    for i_batch, data in (enumerate(train_dataloader)):

        if i_batch%1000==0 or i_batch==10:
          print('iterater crossed    ---   ',i_batch)
        if i_batch > train_sample_num:
            print(i_batch, train_sample_num)
            break

        """Prediction"""
        batch_input_tokens, batch_labels = data
        batch_input_tokens, batch_labels = batch_input_tokens.cuda(), batch_labels.cuda()
        # Errors are left to propagate: the old bare `except` opened a debugger,
        # which hangs unattended runs and leaves `pred_logits` undefined.
        pred_logits = model(batch_input_tokens)

        """Loss calculation & training"""
        loss_val = CELoss(pred_logits, batch_labels)

        loss_val.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)  # Gradient clipping is not in AdamW anymore (so you can use amp without issue)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

    """Dev & Test evaluation"""
    model.eval()

    dev_prek, dev_pred_list, dev_label_list = _CalACC(model, dev_dataloader)
    dev_pre, dev_rec, dev_fbeta, _ = precision_recall_fscore_support(dev_label_list, dev_pred_list, average='weighted')

    """Best Score & Model Save"""
    if dev_fbeta > best_dev_fscore:
        best_dev_fscore = dev_fbeta
        test_pred_list = _CalACCT(model, test_dataloader)
        test_preds_listes.append(test_pred_list)

        # Overwrite the answer file with the predictions of the best epoch so
        # far: one emotion name per line, in test-set order. The earlier code
        # tried to map ids to names by rebinding a loop variable, which had no
        # effect, so it wrote raw integer lists for every improving epoch.
        with open(answer_path, 'w') as file:
            for pred_id in test_pred_list:
                file.write(f'{id2emo[pred_id]}\n')

        best_epoch = epoch_idx
        _SaveModel(model, save_path)

    print('Epoch: {}'.format(epoch_idx))


    print('Devleopment ## precision: {}, precision: {}, recall: {}, fscore: {}'.format(dev_prek, dev_pre, dev_rec, dev_fbeta))
    print('')