In [None]:
# Import Google Drive to access training and testing data
# The Drive is mounted at /content/drive.
# NOTE(review): force_remount=True presumably forces a fresh mount even when
# one already exists -- confirm against the Colab documentation.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [None]:
#@title Installs
!pip install datasets transformers sentence-transformers langdetect torchinfo

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
#@title Imports
# Import Libraries

import os
import pickle
import random
from collections import OrderedDict

import torch
import torch.nn as nn
from torchinfo import summary
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from transformers import AutoTokenizer, AutoModel, AdamW, get_linear_schedule_with_warmup

from tqdm import tqdm, trange
import csv
import logging
import sys

from sklearn.metrics import f1_score, precision_score, recall_score
from nltk.corpus import wordnet
import numpy as np

from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from sentence_transformers import SentenceTransformer, util

import requests
from bs4 import BeautifulSoup
import re
import unicodedata
from collections.abc import Iterable
from decimal import Decimal

from google.colab import files

In [None]:
# @title Stemmer
# Extracts the root word by removing prefixes, suffixes, and repetitions
"""
    CONSTANTS
"""
# Letter classes used by check_vowel / check_consonant. A character that is in
# neither string is treated as a non-letter by clean_stemmed and stripped from
# the ends of a token.
VOWELS = "aeiouAEIOU"
# NOTE(review): this string lists "n" and "g" twice and contains no j, x or z,
# so those three letters are classified as neither vowel nor consonant --
# confirm that this is intended.
CONSONANTS = "bcdfghklmnngpqrstvwyBCDFGHKLMNNGPQRSTVWY"

"""
    Affixes
"""
# Prefixes tried by clean_prefix in list order; clean_prefix returns on the
# first match, so longer prefixes are listed before the shorter ones they
# contain.
PREFIX_SET = [
    'nakikipag', 'pakikipag',
    'pinakama', 'pagpapa',
    'pinagka', 'panganga',
    'makapag', 'nakapag',
    'tagapag', 'makipag',
    'nakipag', 'tigapag',
    'pakiki', 'magpa',
    'napaka', 'pinaka',
    'ipinag', 'pagka',
    'pinag', 'mapag',
    'mapa', 'taga',
    'ipag', 'tiga',
    'pala', 'pina',
    'pang', 'naka',
    'nang', 'mang',
    'sing',
    'ipa', 'pam',
    'pan', 'pag',
    'tag', 'mai',
    'mag', 'nam',
    'nag', 'man',
    'may', 'ma',
    'na', 'ni',
    'pa', 'ka',
    'um', 'in',
    'i',
]

# Infixes tried by clean_infix in list order.
INFIX_SET = [
    'um', 'in',
]

# Suffixes tried by clean_suffix in list order (longest first).
SUFFIX_SET = [
    'syon','dor',
    'ita', 'han',
    'hin', 'ing',
    'ang', 'ng',
    'an', 'in',
    'g',
]

# Module-level state shared by stemmer() and clean_stemmed().
# PERIOD_FLAG: clean_stemmed sets it to True when a token ends with '.';
#   stemmer() sets it back to False after handling a token. stemmer() only
#   stems a capitalised token while this flag is True.
# PASS_FLAG: set to True by stemmer() while it runs its second stemming pass;
#   clean_stemmed does not touch PERIOD_FLAG while it is True.
PERIOD_FLAG = True
PASS_FLAG = False

def _stem_pass(token, PREFIX, INFIX, SUFFIX, DUPLICATE, REPITITION, CLEANERS):
    """
        Runs one complete affix-stripping pass over a lower-cased token.
            token: word to be stemmed
            PREFIX..CLEANERS: lists that receive every affix that is removed
        returns STRING
    """

    stem = clean_duplication(token, DUPLICATE)
    stem = clean_prefix(stem, PREFIX)
    stem = clean_repitition(stem, REPITITION)
    stem = clean_infix(stem, INFIX)
    stem = clean_repitition(stem, REPITITION)
    stem = clean_suffix(stem, SUFFIX)
    stem = clean_duplication(stem, DUPLICATE)
    stem = clean_stemmed(stem, CLEANERS, REPITITION)

    return clean_duplication(stem, DUPLICATE)


def stemmer(source, info_dis="1", mode="2"):
    """
        Stems the tokens in a sentence.
            source: the string to stem (mode "2") or a file name (mode "1")
            info_dis: kept for backward compatibility; it has no effect on
                      the returned value
            mode: "1" reads the tokens through read_file(source), "2" splits
                  source on whitespace; any other value calls sys.exit()
        returns TUPLE (stemmed, first_root)
            stemmed: one dict per token with the keys "word", "root",
                     "prefix", "infix", "suffix", "repeat", "dupli", "clean"
            first_root: the "root" of the first token
        raises IndexError when source contains no token

        The result depends on the module-level PERIOD_FLAG, which survives
        between calls: a capitalised token is only stemmed while the flag is
        True, a lower-case token only while it is False.
    """

    global PERIOD_FLAG
    global PASS_FLAG

    stemmed   = []
    root_only = []

    if mode == "1":
        # NOTE(review): read_file is not defined in the visible part of this
        # notebook -- confirm it exists before relying on mode "1".
        tokens = read_file(source)

    elif mode == "2":
        tokens = source.split()

    else:
        sys.exit()

    for token in tokens:
        word_info = {"word": token}

        # Fresh bookkeeping lists for every token. They used to be reset only
        # after a stemmed token, so entries appended while handling a skipped
        # token leaked into the next one (REPITITION even influences the
        # result inside clean_stemmed).
        PREFIX     = []
        INFIX      = []
        SUFFIX     = []
        DUPLICATE  = []
        REPITITION = []
        CLEANERS   = []

        if (PERIOD_FLAG and token[0].isupper()) or \
            (not PERIOD_FLAG and token[0].islower()):

            token = token.lower()

            try:
                cle_stem = _stem_pass(token, PREFIX, INFIX, SUFFIX,
                                      DUPLICATE, REPITITION, CLEANERS)

                # str.replace returns a new string; the result used to be
                # thrown away, so left-over hyphens were never removed. The
                # stem is kept unchanged if stripping would leave it empty.
                without_hyphen = cle_stem.replace('-', '')
                if without_hyphen:
                    cle_stem = without_hyphen

                # if stemmed is wrong, go to 2nd pass
                if not check_validation(cle_stem):
                    PASS_FLAG = True
                    cle_stem = _stem_pass(cle_stem, PREFIX, INFIX, SUFFIX,
                                          DUPLICATE, REPITITION, CLEANERS)

            finally:
                # Never leave the global second-pass marker set, even when a
                # helper raises (callers catch IndexError and keep going).
                PASS_FLAG = False

            word_info["root"]   = cle_stem
            word_info["prefix"] = PREFIX
            word_info["infix"]  = INFIX
            word_info["suffix"] = SUFFIX
            word_info["repeat"] = REPITITION
            word_info["dupli"]  = DUPLICATE
            word_info["clean"]  = CLEANERS

            PERIOD_FLAG = False

        else:
            PERIOD_FLAG = False
            # Called only for its side effect: clean_stemmed sets PERIOD_FLAG
            # again when the token ends with '.'. The token itself is kept.
            clean_stemmed(token, CLEANERS, REPITITION)
            word_info["root"]   = token
            word_info["prefix"] = '[]'
            word_info["infix"]  = '[]'
            word_info["suffix"] = '[]'
            word_info["repeat"] = '[]'
            word_info["dupli"]  = '[]'
            word_info["clean"]  = '[]'

        stemmed.append(word_info)
        root_only.append(word_info["root"])

    return stemmed, root_only[0]


def clean_duplication(token, DUPLICATE):
    """
        Checks token for duplication. (ex. araw-araw = araw)
            token: word to be stemmed duplication
            DUPLICATE: list that receives the part that is kept
        returns STRING

        Only tokens with a hyphen that is neither the first nor the last
        character are considered, and only when every hyphen-separated part
        has at least three letters. Tokens found in the validation list are
        returned unchanged.
    """

    if check_validation(token):
        return token

    if '-' in token and token.index('-') != 0 and \
        token.index('-') != len(token) -  1:

        parts = token.split('-')

        if all(len(part) >= 3 for part in parts):
            first, second = parts[0], parts[1]

            # Exact repetition, or repetition where the first half has 'u'
            # where the second half has 'o' (last or second-to-last letter).
            # The exact-repetition test used to compare the first half with
            # token[1], a single character, and therefore never matched.
            if first == second or \
                first[-1] == 'u' and change_letter(first, -1, 'o') == second or \
                first[-2] == 'u' and change_letter(first, -2, 'o') == second:
                DUPLICATE.append(first)
                return first

            # The second half starts with the first half.
            elif first == second[0:len(first)]:
                DUPLICATE.append(second)
                return second

            # The first half is the second half plus a linking 'ng'
            # (optionally with 'u' in place of the final 'o').
            elif first[-2:] == 'ng':
                if first[-3] == 'u':
                    if first[0:-3] + 'o' == second:
                        DUPLICATE.append(second)
                        return second

                if first[0:-2] == second:
                    DUPLICATE.append(second)
                    return second

        else:
            return '-'.join(parts)

    return token


def clean_repitition(token, REPITITION):
    """
        Removes a repeated leading letter or syllable.
        (ex. nakakabaliw = nabaliw)
            token: word to be stemmed repitition
            REPITITION: list that receives the removed part
        returns STRING
    """

    # Known words and tokens shorter than four letters are left alone.
    if check_validation(token) or len(token) < 4:
        return token

    head = token[0]

    if check_vowel(head):
        # A doubled leading vowel loses its first copy.
        if head == token[1]:
            REPITITION.append(head)
            return token[1:]
        return token

    if not (check_consonant(head) and count_vowel(token) >= 2):
        return token

    # A doubled 2- or 3-letter start is dropped when at least four letters
    # remain afterwards.
    for size in (2, 3):
        if token[0: size] == token[size: 2 * size] and len(token) - size >= 4:
            REPITITION.append(token[size: 2 * size])
            return token[size:]

    return token


def clean_prefix(token,     PREFIX):
    """
        Strips the first matching prefix from PREFIX_SET. (ex. naligo = ligo)
            token: word to be stemmed for prefixes
            PREFIX: list that receives the removed prefix
        returns STRING
    """

    if check_validation(token):
        return token

    for prefix in PREFIX_SET:
        size = len(prefix)
        remainder = token[size:]

        # At least three letters and two vowels must remain after the prefix.
        if len(token) - size < 3 or count_vowel(remainder) < 2:
            continue

        # 'i' is not treated as a prefix when the third letter is a consonant.
        if prefix == 'i' and check_consonant(token[2]):
            continue

        # Hyphenated form: the prefix is everything before the first hyphen
        # and the following part starts with a vowel.
        if '-' in token:
            pieces = token.split('-')

            if pieces[0] == prefix and check_vowel(pieces[1][0]):
                PREFIX.append(prefix)
                return pieces[1]

        if token.startswith(prefix):
            PREFIX.append(prefix)

            # 'panganga' is replaced by 'ka', not simply removed.
            if prefix == 'panganga':
                return 'ka' + remainder

            return remainder

    return token


def clean_infix(token, INFIX):
    """
        Removes an infix that follows the first letter. (ex. bumalik = balik)
            token: word to be stemmed for infixes
            INFIX: list that receives the removed infix
        returns STRING
    """

    if check_validation(token):
        return token

    for infix in INFIX_SET:
        size = len(infix)

        # At least three letters and two vowels must remain.
        if len(token) - size < 3 or count_vowel(token[size:]) < 2:
            continue

        # Compares three letters with the infix, so it cannot match the
        # two-letter infixes currently listed in INFIX_SET.
        if token[0] == token[4] and token[1: 4] == infix:
            INFIX.append(infix)
            return token[4:]

        # Infix right after the first letter, followed either by a repeat of
        # its last letter two positions on, or by a vowel.
        sits_after_first = token[1: 3] == infix
        if (token[2] == token[4] and sits_after_first) or \
            (sits_after_first and check_vowel(token[3])):
            INFIX.append(infix)
            return token[0] + token[3:]

    return token


def clean_suffix(token, SUFFIX):
    """
        Checks token for suffixes. (ex. bigayan = bigay)
            token: word to be stemmed for suffixes
            SUFFIX: list that receives the removed suffix
        returns STRING

        Suffixes are tried in SUFFIX_SET order. The first suffix whose
        remaining stem is in the validation list wins; if there is none, the
        first suffix that passed the structural checks is used.
    """

    if check_validation(token):
        return token

    candidate = None  # (suffix, stem) of the first unvalidated match

    for suffix in SUFFIX_SET:
        stem_length = len(token) - len(suffix)

        # At least three letters and two vowels must remain.
        if stem_length < 3:
            continue

        stem = token[0: stem_length]

        if count_vowel(stem) < 2 or not token.endswith(suffix):
            continue

        # A two-letter suffix needs at least one consonant in the stem.
        if len(suffix) == 2 and not count_consonant(stem) >= 1:
            continue

        if suffix == 'ang' and check_consonant(token[-4]) \
            and token[-4] != 'r' and token[-5] != 'u':
            continue

        if check_validation(stem):
            SUFFIX.append(suffix)
            # The stem was just validated, so it is returned as is. (The old
            # code compared the SUFFIX list with 'ita' here, which was always
            # False, so this path has always returned the bare stem.)
            return stem

        if candidate is None:
            candidate = (suffix, stem)

    if candidate is not None:
        suffix, stem = candidate
        # Record the suffix in the caller's list. The old code rebound the
        # local name instead, so the caller never saw the removed suffix.
        SUFFIX.append(suffix)
        return stem + 'a' if suffix == 'ita' else stem

    return token


def check_vowel(substring):
    """
        Tells whether every letter of the substring is a vowel.
            substring: text to be tested
        returns BOOLEAN (True for an empty substring)
    """

    for letter in substring:
        if letter not in VOWELS:
            return False

    return True


def check_consonant(substring):
    """
        Tells whether every letter of the substring is a consonant.
            substring: text to be tested
        returns BOOLEAN (True for an empty substring)
    """

    for letter in substring:
        if letter not in CONSONANTS:
            return False

    return True



def count_vowel(token):
    """
        Counts the characters of a token that check_vowel accepts.
            token: string to be counted for vowels
        returns INTEGER
    """

    return sum(1 for character in token if check_vowel(character))


def count_consonant(token):
    """
        Counts the characters of a token that check_consonant accepts.
            token: string to be counted for consonants
        returns INTEGER
    """

    return sum(1 for character in token if check_consonant(character))


def change_letter(token, index, letter):
    """
        Returns a copy of the token with one position replaced.
            token: word to be used
            index: position to replace (negative values count from the end)
            letter: replacement text
        returns STRING
    """

    characters = [*token]
    characters[index] = letter

    return ''.join(characters)


def clean_stemmed(token, CLEANERS, REPITITION):
    """
        Checks for left-over affixes and letters.
            token: word to be cleaned for excess affixes/letters
            CLEANERS: list that receives every letter/affix that is removed
                      or rewritten
            REPITITION: list of removed repetitions; passed on to
                        clean_repitition and inspected for entries that
                        start with 'r'
        returns STRING

        Side effect: sets the module-level PERIOD_FLAG to True when the
        token ends with '.' and PASS_FLAG is False.
        Raises IndexError when the token is empty (or becomes empty after
        its non-letter ends are stripped).
    """

    global PERIOD_FLAG
    global PASS_FLAG

    CC_EXP = ['dr', 'gl', 'gr', 'ng', 'kr', 'kl', 'kw', 'ts', 'tr', 'pr', 'pl', 'pw', 'sw', 'sy'] # Consonant + Consonant Exceptions

    if token[-1] == '.' and PASS_FLAG == False:
        PERIOD_FLAG = True

    # Strip one non-letter (neither vowel nor consonant) from each end.
    if not check_vowel(token[-1]) and not check_consonant(token[-1]):
        CLEANERS.append(token[-1])
        token = token[0:-1]

    if not check_vowel(token[0]) and not check_consonant(token[0]):
        CLEANERS.append(token[0])
        token = token[1:]

    if check_validation(token):
        return token

    if len(token) >= 3 and count_vowel(token) >= 2:
        token = clean_repitition(token,    REPITITION)

        # 'u' before a final consonant becomes 'o'.
        if check_consonant(token[-1]) and token[- 2] == 'u':
            CLEANERS.append('u')
            token = change_letter(token, -2, 'o')

        # Final 'u' becomes 'o'.
        if token[len(token) - 1] == 'u':
            CLEANERS.append('u')
            token = change_letter(token, -1, 'o')

        # Final 'r' becomes 'd'.
        if token[-1] == 'r':
            CLEANERS.append('r')
            token = change_letter(token, -1, 'd')

        # NOTE(review): this condition can never hold ('h' is not a vowel);
        # token[-2] was probably meant. Left as is because a final 'h' is
        # removed unconditionally further down.
        if token[-1] == 'h' and check_vowel(token[-1]):
            CLEANERS.append('h')
            token = token[0:-1]

        # Doubled first letter.
        if token[0] == token[1]:
            CLEANERS.append(token[0])
            token = token[1:]

        # Left-over 'ka'/'pa' in front of a consonant.
        if (token[0: 2] == 'ka' or token[0: 2] == 'pa') and check_consonant(token[2]) \
            and count_vowel(token) >= 3:

            CLEANERS.append(token[0: 2])
            token = token[2:]

        # Left-over 'han': replaced by 'i' when only one vowel precedes it,
        # otherwise removed.
        if(token[-3:]) == 'han' and count_vowel(token[0:-3]) == 1:
            CLEANERS.append('han')
            token = token[0:-3] + 'i'

        if(token[-3:]) == 'han' and count_vowel(token[0:-3]) > 1:
            CLEANERS.append('han')
            token = token[0:-3]

        if len(token) >= 2 and count_vowel(token) >= 3:
            if token[-1] == 'h' and check_vowel(token[-2]):
                CLEANERS.append('h')
                token = token[0:-1]

        # Doubled two-letter start.
        if len(token) >= 6 and token[0:2] == token[2:4]:
            CLEANERS.append('0:2')
            token = token[2:]

        # A removed repetition that started with 'r' turns the first letter
        # into 'd'.
        if any(REP[0] == 'r' for REP in REPITITION):
            CLEANERS.append('r')
            token = change_letter(token, 0, 'd')

        if token[-2:] == 'ng' and token[-3] == 'u':
            CLEANERS.append('u')
            token = change_letter(token, -3, 'o')

        if token[-1] == 'h':
            CLEANERS.append('h')
            token = token[0:-1]

        # Drop the first letter of a leading consonant pair unless the pair
        # is one of the allowed clusters. The old test,
        # any(token[0:2] != CC for CC in CC_EXP), was always True, so the
        # exception list was never honoured.
        if token[0:2] not in CC_EXP and check_consonant(token[0:2]):
            CLEANERS.append(token[0:2])
            token = token[1:]

    return token



# Parsed word lists keyed by file path: path -> ((mtime_ns, size), words).
_VALIDATION_CACHE = {}


def check_validation(token, file_path="/content/validation.txt"):
    """
        Tells whether the token is listed in the validation word file.
            token: word to look up
            file_path: text file with the valid words, separated by spaces
                       and/or newlines
        returns BOOLEAN
        raises FileNotFoundError when the file does not exist

        The stemmer calls this function many times per token, so the parsed
        word list is cached and only re-read when the modification time or
        the size of the file changes.
    """

    file_stat = os.stat(file_path)
    signature = (file_stat.st_mtime_ns, file_stat.st_size)

    cached = _VALIDATION_CACHE.get(file_path)

    if cached is None or cached[0] != signature:
        with open(file_path, 'r') as valid:
            # Same splitting as before: newlines count as spaces and every
            # single space separates two entries.
            words = frozenset(valid.read().replace('\n', ' ').split(' '))

        cached = (signature, words)
        _VALIDATION_CACHE[file_path] = cached

    return token in cached[1]



def validate(stemmed, errors, file_path="/content/validation.txt"):
    """
        Calculates accuracy.
            stemmed: list of stemmed words
            errors: list that receives every stem that is neither
                    capitalised nor in the validation file
            file_path: text file with the valid words, separated by spaces
                       and/or newlines
        returns STRING, the percentage of accepted stems with two decimals
                ('0.00' when stemmed is empty)
    """

    with open(file_path, 'r') as valid:
        data = set(valid.read().replace('\n', ' ').split(' '))

    check = 0

    for stem in stemmed:
        # stem[:1] keeps an empty stem from raising IndexError.
        if stem[:1].isupper() or stem in data:
            check += 1

        else:
            errors.append(stem)

    # Nothing to score: report zero, not a division by zero.
    if not stemmed:
        return format(0.0, '.2f')

    return format(check / len(stemmed) * 100, '.2f')



In [None]:
# @title Definitions
# Web scraper that fetches word definitions from an online dictionary

def flatten(coll):
    """Yields the leaves of an arbitrarily nested iterable, depth first.

    Strings count as leaves and are yielded whole; every other iterable is
    descended into.
    """
    for item in coll:
        if isinstance(item, str) or not isinstance(item, Iterable):
            yield item
        else:
            yield from flatten(item)

def clean(element):
    """Reduces text to its unaccented ASCII letters.

    The text is decomposed (NFD), everything that is not ASCII is dropped
    (which removes the combining accents), and every character other than
    a-z / A-Z is then deleted.
    """
    decomposed = unicodedata.normalize('NFD', element)
    ascii_text = decomposed.encode('ascii', 'ignore').decode('utf-8')
    return re.sub(r'[^a-zA-Z]', '', ascii_text)

def get_definitions(word, definitions, depth, timeout=None):
  """Scrapes the definitions of `word` from diksiyonaryo.ph.

  Args:
      word: word to look up.
      definitions: ignored; kept so existing callers keep working (the
          result list always starts empty).
      depth: recursion depth of this lookup. Definitions that contain a link
          are looked up recursively with depth + 1; a lookup at depth 2
          contributes nothing.
      timeout: optional timeout in seconds passed on to requests.get. The
          default (None) waits indefinitely, as before.

  Returns:
      A flat list of definition strings. Only entries listed under a
      pronunciation heading that equals `word` (after clean()) are used.

  Raises:
      Whatever requests.get raises (e.g. requests.exceptions.ConnectionError).
  """
  # A lookup at depth 2 always produced an empty list, but only after
  # downloading and parsing the page. Return before touching the network.
  if depth == 2:
      return []

  from urllib.parse import quote

  # Escape the word so characters such as '/', '?' or '#' cannot change the
  # meaning of the URL.
  URL = f"https://diksiyonaryo.ph/search/{quote(str(word), safe='')}"
  # NOTE(review): verify=False switches off TLS certificate verification.
  # Kept to preserve behaviour -- confirm whether the site still needs it.
  page = requests.get(URL, verify=False, timeout=timeout)
  soup = BeautifulSoup(page.content, "html.parser")
  same_word_flag = True
  definitions = []

  for element in soup.find_all(class_=["pronunciation", "definition"]):
      if element.get('class') == ['pronunciation']:
          # A pronunciation heading starts a new entry; only entries for the
          # requested word are collected.
          same_word_flag = clean(element.text) == word
      elif not same_word_flag:
          continue
      elif element.find("a"):
          # The definition links to another word: use that word's definitions.
          try:
              definitions.append(get_definitions(clean(element.text), [], depth + 1, timeout))
          except RecursionError:
              return list(flatten(definitions))
      else:
          definitions.append(element.text)
  return list(flatten(definitions))


def get_definition_embeddings(definitions):
  """Encodes every definition with the global `sentence_transformer` model.

  Returns a list with one embedding per definition, in the same order.
  """
  return [sentence_transformer.encode(text) for text in definitions]

def choose_definition(definitions, definition_embeddings, sentence):
  """Picks the definition whose embedding is closest to the sentence.

  The sentence is encoded with the global `sentence_transformer` model and
  compared with every definition embedding by cosine similarity. Without any
  embeddings the first definition is returned.
  """
  sentence_embeddings = sentence_transformer.encode(sentence)
  similarities = [
      util.pytorch_cos_sim(sentence_embeddings, embedding)
      for embedding in definition_embeddings
  ]
  if not similarities:
      return definitions[0]
  return definitions[np.argmax(similarities)]

In [None]:
#@title Utils
def set_random_seed(seed):
    """Seeds every random number generator this notebook can draw from.

    Seeds Python's `random` module, NumPy and PyTorch (CPU, plus all CUDA
    devices when CUDA is available) and asks cuDNN for deterministic kernels.

    Args:
        seed: integer seed shared by all generators.
    """
    # `random` is imported by this notebook but used to be left unseeded, so
    # anything drawing from it differed from run to run.
    random.seed(seed)
    np.random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [None]:
#@title Logger

class Logger:
    """Logger that writes INFO messages to the console and to a log file.

    The file is `<log_dir>/experiments.log`. The underlying `logging` logger
    is named after `log_dir`, so two instances created for the same directory
    share one logger.
    """

    def __init__(self, log_dir):
        """Creates the log directory if needed and attaches both handlers.

        Args:
            log_dir: directory that receives experiments.log.
        """
        log_file_format = "[%(lineno)d]%(asctime)s: %(message)s"
        log_console_format = "%(message)s"

        # Main logger
        self.log_dir = log_dir
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)

        self.logger = logging.getLogger(log_dir)
        self.logger.setLevel(logging.INFO)
        self.logger.propagate = False

        # logging.getLogger returns the same object for the same name, so
        # handlers from an earlier instance (e.g. a re-run notebook cell) are
        # still attached. Drop them, otherwise every message is written once
        # per instance.
        self._detach_handlers()

        console_handler = logging.StreamHandler()
        console_handler.setLevel(logging.INFO)
        console_handler.setFormatter(logging.Formatter(log_console_format))

        file_handler = logging.FileHandler(os.path.join(log_dir, "experiments.log"))
        file_handler.setLevel(logging.DEBUG)
        file_handler.setFormatter(logging.Formatter(log_file_format))

        self.logger.addHandler(console_handler)
        self.logger.addHandler(file_handler)

    def info(self, msg):
        """Logs `msg` at INFO level."""
        self.logger.info(msg)

    def close(self):
        """Detaches and closes this logger's handlers, then shuts logging down."""
        self._detach_handlers()
        logging.shutdown()

    def _detach_handlers(self):
        """Removes every handler from the logger and closes it."""
        for handle in self.logger.handlers[:]:
            self.logger.removeHandler(handle)
            handle.close()


def setup_logger(log_dir):
    """Configures the root logger to log to the console and to a file.

    The file is `<log_dir>/experiments.log`; the directory is created when it
    does not exist. Calling the function again replaces the handlers it
    installed earlier, so messages are not duplicated. Handlers installed by
    anyone else are left alone.

    Args:
        log_dir: directory that receives experiments.log.

    Returns:
        The root logger.
    """
    log_file_format = "[%(lineno)d]%(asctime)s: %(message)s"
    log_console_format = "%(message)s"

    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    # Main logger
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    logger.propagate = False

    # Remove only the handlers a previous call of this function attached
    # (they are tagged below).
    for handler in logger.handlers[:]:
        if getattr(handler, "_from_setup_logger", False):
            logger.removeHandler(handler)
            handler.close()

    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(logging.Formatter(log_console_format))

    file_handler = logging.FileHandler(os.path.join(log_dir, "experiments.log"))
    file_handler.setLevel(logging.DEBUG)
    file_handler.setFormatter(logging.Formatter(log_file_format))

    for handler in (console_handler, file_handler):
        handler._from_setup_logger = True
        logger.addHandler(handler)

    return logger

In [None]:
#@title Run Classifier Dataset
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" BERT classification fine-tuning: utilities to work with GLUE tasks """



# Module-level logger; convert_examples_to_two_features reports its progress
# through it.
logger = logging.getLogger(__name__)


class InputExample:
    """One raw example for sequence classification, before tokenization."""

    def __init__(
        self,
        guid,
        text_a,
        text_b=None,
        label=None,
        text_a_2=None,
        text_b_2=None,
    ):
        """Stores the given values unchanged.

        Args:
            guid: Unique id for the example.
            text_a: string. Untokenized text of the first sequence; the only
                sequence needed for single-sequence tasks.
            text_b: (Optional) string. Untokenized text of the second
                sequence, for sequence-pair tasks.
            label: (Optional) string. Label of the example; expected for
                train and dev examples, not for test examples.
            text_a_2: (Optional) extra field for the second feature set;
                TrofiProcessor stores the target-word index here.
            text_b_2: (Optional) extra field, stored as given.
        """
        self.guid, self.text_a, self.text_b = guid, text_a, text_b
        self.label, self.text_a_2, self.text_b_2 = label, text_a_2, text_b_2


class InputFeatures:
    """Model-ready features of one example.

    Holds the token ids, attention mask and segment ids of the sentence, the
    same three lists for a second sequence (the `_2` fields), the label id
    and some bookkeeping values. All values are stored exactly as given.
    """

    def __init__(
        self,
        input_ids,
        input_mask,
        segment_ids,
        label_id,
        guid=None,
        input_ids_2=None,
        input_mask_2=None,
        segment_ids_2=None,
        sentence_mask=None,
        word_index=None,
    ):
        fields = (
            ("input_ids", input_ids),
            ("input_mask", input_mask),
            ("segment_ids", segment_ids),
            ("label_id", label_id),
            ("guid", guid),
            ("input_ids_2", input_ids_2),
            ("input_mask_2", input_mask_2),
            ("segment_ids_2", segment_ids_2),
            ("sentence_mask", sentence_mask),
            ("word_index", word_index),
        )
        for attribute, value in fields:
            setattr(self, attribute, value)


class DataProcessor(object):
    """Base class for data converters for sequence classification data sets."""

    def get_train_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the train set."""
        raise NotImplementedError()

    def get_dev_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the dev set."""
        raise NotImplementedError()

    def get_test_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the test set."""
        raise NotImplementedError()

    def get_labels(self):
        """Gets the list of labels for this data set."""
        raise NotImplementedError()

    @classmethod
    def _read_tsv(cls, input_file, quotechar=None, max_lines=1000):
        """Reads a tab separated value file.

        Args:
            input_file: path of the UTF-8 encoded .tsv file.
            quotechar: passed on to csv.reader.
            max_lines: maximum number of rows to return, header row
                included; None reads the whole file.
                NOTE(review): the default keeps the historical hard-coded
                cap of 1000 rows, which silently truncates larger data
                sets -- confirm that this is intended.

        Returns:
            A list of rows, each row a list of strings.
        """
        with open(input_file, "r", encoding="utf-8") as f:
            reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
            lines = []
            for line in reader:
                if max_lines is not None and len(lines) >= max_lines:
                    break
                lines.append(line)
            return lines


class TrofiProcessor(DataProcessor):
    """Processor for the TroFi and MOH-X data set."""

    def get_train_examples(self, data_dir, k=None):
        """Reads train.tsv, or train<k>.tsv when k is given."""
        return self._load_split(data_dir, "train", k)

    def get_test_examples(self, data_dir, k=None):
        """Reads test.tsv, or test<k>.tsv when k is given."""
        return self._load_split(data_dir, "test", k)

    def get_dev_examples(self, data_dir, k=None):
        """Reads dev.tsv, or dev<k>.tsv when k is given."""
        return self._load_split(data_dir, "dev", k)

    def get_labels(self):
        """Returns the two class labels."""
        return ["0", "1"]

    def _load_split(self, data_dir, split, k):
        """Reads the file of one split and turns its rows into examples."""
        fold = "" if k is None else str(k)
        path = os.path.join(data_dir, split + fold + ".tsv")
        return self._create_examples(self._read_tsv(path), split)

    def _create_examples(self, lines, set_type):
        """Builds one InputExample per row, skipping the header row.

        Columns used: 0 = id, 1 = label (numeric text), 2 = sentence,
        3 = index of the target word. The index is stored both as text_b and
        as text_a_2.
        """
        return [
            InputExample(
                guid="%s-%s" % (set_type, row[0]),
                text_a=row[2],
                label=str(int(Decimal(row[1]))),
                text_b=row[3],
                text_a_2=row[3],
            )
            for position, row in enumerate(lines)
            if position != 0
        ]


def convert_examples_to_two_features(
    examples, label_list, max_seq_length, tokenizer, output_mode, use_pos, use_local_context, args, k
):
    """Converts examples to features for sequence classification.

    Every example yields two feature sets: one for the sentence (with the
    target word marked in the segment ids) and one for the dictionary
    definition chosen for the target word.

    The function checkpoints its work: the feature list is pickled to
    /content every 1000 examples (and offered for download every 10000), to
    an "_error" file when an example fails, and to a final file at the end.
    On start-up it loads the checkpoint with the highest example index and
    resumes after it.

    Args:
        examples (list): List of `InputExample` instances.
        label_list (list): List of labels.
        max_seq_length (int): Maximum sequence length.
        tokenizer: Tokenizer instance.
        output_mode (str): Output mode for the task; only "classification"
            is supported.
        use_pos (bool): Not used by this function.
        use_local_context (bool): Flag indicating whether to use local context.
        args: Not used by this function.
        k: Index for file naming.

    Returns:
        list: List of `InputFeatures` instances.

    Raises:
        Exception: any error raised while processing an example is re-raised
            after the features collected so far have been saved.
    """
    label_map = {label: i for i, label in enumerate(label_list)}
    # Define the directory to save the features
    save_dir = "/content"
    resume_index = 0  # Index to resume from (0 for starting from the beginning)

    # Check if there are existing features files. The example index is parsed
    # from each file name and compared as a number; a plain string sort would
    # rank "..._9000.pkl" after "..._10000.pkl". Files whose index cannot be
    # parsed are ignored.
    checkpoint_prefix = f"{str(k)}features_"
    checkpoints = []
    for file_name in os.listdir(save_dir):
        if not (file_name.startswith(checkpoint_prefix) and file_name.endswith(".pkl")):
            continue
        try:
            saved_index = int(file_name[len(checkpoint_prefix):].split("_")[0].split(".")[0])
        except ValueError:
            continue
        # For the same index a regular checkpoint (which includes that
        # example) is preferred over an "_error" one (which does not).
        checkpoints.append((saved_index, not file_name.endswith("_error.pkl"), file_name))

    if checkpoints:
        saved_index, _is_regular, last_file = max(checkpoints)
        resume_index = saved_index + 1

        # Load the last saved features file
        # NOTE(review): unpickling runs code stored in the file; only
        # checkpoints written by this notebook should be present in save_dir.
        load_path = os.path.join(save_dir, last_file)
        with open(load_path, "rb") as file:
            features = pickle.load(file)
    else:
        features = []
    for (ex_index, example) in tqdm(enumerate(examples), total=len(examples)):
        if ex_index < resume_index:
            continue  # Skip examples until the resume index is reached

        try:
            if ex_index % 10000 == 0:
                logger.info("Writing example %d of %d" % (ex_index, len(examples)))

            tokens_a = tokenizer.tokenize(example.text_a)  # tokenize the sentence
            tokens_b = None
            text_b = None
            word_index = int(example.text_b)

            try:
                text_b = int(example.text_b)  # index of target word
                tokens_b = text_b

                # truncate the sentence to max_seq_len
                if len(tokens_a) > max_seq_length - 6:
                    tokens_a = tokens_a[: (max_seq_length - 6)]

                # Find the target word index
                for i, w in enumerate(example.text_a.split()):
                    # If w is a target word, tokenize the word and save to text_b
                    if i == text_b:
                        # consider the index due to models that use a byte-level BPE as a tokenizer (e.g., GPT2, RoBERTa)
                        text_b = tokenizer.tokenize(w) if i == 0 else tokenizer.tokenize(" " + w)
                        break
                    w_tok = tokenizer.tokenize(w) if i == 0 else tokenizer.tokenize(" " + w)
                    # Count number of tokens before the target word to get the target word index
                    if w_tok:
                        tokens_b += len(w_tok) - 1
                # Skip the example when the target word falls outside the
                # truncated sentence.
                if tokens_b + len(text_b) > max_seq_length - 6:
                    continue

            except TypeError:
                if example.text_b:
                    tokens_b = tokenizer.tokenize(example.text_b)
                    # Account for [CLS], [SEP], [SEP] with "- 3"
                    _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
                else:
                    # Account for [CLS] and [SEP] with "- 2"
                    if len(tokens_a) > max_seq_length - 2:
                        tokens_a = tokens_a[: (max_seq_length - 2)]

            tokens = [tokenizer.cls_token] + tokens_a + [tokenizer.sep_token]
            # 1 for every sentence token between [CLS] and the first [SEP].
            sentence_mask = [0] * max_seq_length
            for ind in range(len(tokens)):
                if tokens[ind] == tokenizer.cls_token:
                    continue
                elif tokens[ind] == tokenizer.sep_token:
                    break
                else:
                    sentence_mask[ind] = 1

            # Local context
            if use_local_context:
                # The local context is the span between the commas that
                # surround the target word (segment id 2).
                local_start = 1
                local_end = local_start + len(tokens_a)
                comma1 = tokenizer.tokenize(",")[0]
                comma2 = tokenizer.tokenize(" ,")[0]
                if not isinstance(tokens_b, int):
                    tokens_b = tokens_b[0].lstrip('▁')
                    tokens_b = int(tokens_b)
                for i, w in enumerate(tokens):
                    if i < tokens_b + 1 and (w in [comma1, comma2]):
                        local_start = i
                    if i > tokens_b + 1 and (w in [comma1, comma2]):
                        local_end = i
                        break
                segment_ids = [
                    2 if i >= local_start and i <= local_end else 0 for i in range(len(tokens))
                ]

            else:
                segment_ids = [0] * len(tokens)

            # POS tag encoding
            after_token_a = False
            for i, t in enumerate(tokens):
                if t == tokenizer.sep_token:
                    after_token_a = True
                if after_token_a and t != tokenizer.sep_token:
                    segment_ids[i] = 3

            input_ids = tokenizer.convert_tokens_to_ids(tokens)

            try:
                tokens_b += 1  # add 1 to the target word index considering [CLS]
                for i in range(len(text_b)):
                    segment_ids[tokens_b + i] = 1
            except TypeError:
                pass

            input_mask = [1] * len(input_ids)
            padding = [tokenizer.convert_tokens_to_ids(tokenizer.pad_token)] * (
                max_seq_length - len(input_ids)
            )
            input_ids += padding
            input_mask += [0] * len(padding)
            segment_ids += [0] * len(padding)

            assert len(input_ids) == max_seq_length
            assert len(input_mask) == max_seq_length
            assert len(segment_ids) == max_seq_length

            if output_mode == "classification":
                label_id = label_map[example.label]
            else:
                raise KeyError(output_mode)

            # Second features (Target word)
            # Look up the definitions of the stemmed target word; fall back to
            # the word itself when the lookup fails or finds nothing.
            try:
                definitions = get_definitions(stemmer(example.text_a.split()[int(example.text_a_2)])[1], [], 0)
            except (IndexError, requests.exceptions.ConnectionError) as error:
                if len(example.text_a.split()) <= int(example.text_a_2):
                  example.text_a_2 = len(example.text_a.split()) - 1
                definitions = [example.text_a.split()[int(example.text_a_2)]]
            if not isinstance(definitions, list) or len(definitions) == 0: definitions = [example.text_a.split()[int(example.text_a_2)]]
            chosen_definition = choose_definition(definitions, get_definition_embeddings(definitions), example.text_a)
            tok = tokenizer.tokenize(chosen_definition)


            if len(tok) > max_seq_length - 2:
                tok = tok[: (max_seq_length - 2)]
            tokens = [tokenizer.cls_token] + tok + [tokenizer.sep_token]
            segment_ids_2 = [0] * len(tokens)

            # The mask has 1 for real tokens and 0 for padding tokens. Only real tokens are attended to.
            input_ids_2 = tokenizer.convert_tokens_to_ids(tokens)
            input_mask_2 = [1] * len(input_ids_2)

            padding = [tokenizer.convert_tokens_to_ids(tokenizer.pad_token)] * (
                max_seq_length - len(input_ids_2)
            )
            input_ids_2 += padding
            input_mask_2 += [0] * len(padding)
            segment_ids_2 += [0] * len(padding)

            assert len(input_ids_2) == max_seq_length
            assert len(input_mask_2) == max_seq_length
            assert len(segment_ids_2) == max_seq_length

            features.append(
                InputFeatures(
                    input_ids=input_ids,
                    input_mask=input_mask,
                    segment_ids=segment_ids,
                    label_id=label_id,
                    guid=example.guid + " " + str(example.text_b),
                    input_ids_2=input_ids_2,
                    input_mask_2=input_mask_2,
                    segment_ids_2=segment_ids_2,
                    sentence_mask=sentence_mask,
                    word_index=word_index,
                )
            )

            # Save features periodically
            if ex_index % 1000 == 0:
                save_path = os.path.join(save_dir, f"{str(k)}features_{ex_index}.pkl")
                with open(save_path, "wb") as file:
                    pickle.dump(features, file)
                # Download the file that was just written. The bare file name
                # used before only resolved while the working directory was
                # save_dir.
                if ex_index % 10000 == 0:
                    files.download(save_path)


        except Exception as e:
            # Handle the error
            print(f"Error occurred at example {ex_index}: {e}")
            # Save the features before the error occurred; the next run
            # resumes after this example.
            save_path = os.path.join(save_dir, f"{str(k)}features_{ex_index}_error.pkl")
            with open(save_path, "wb") as file:
                pickle.dump(features, file)
            # Re-raise with the original traceback to stop the execution
            raise

    # Save the remaining features
    if features:
        save_path = os.path.join(save_dir, f"{str(k)}final_features.pkl")
        with open(save_path, "wb") as file:
            pickle.dump(features, file)

    return features


def _truncate_seq_pair(tokens_a, tokens_b, max_length):
    """Truncates a sequence pair in place to the maximum length."""

    # This is a simple heuristic which will always truncate the longer sequence
    # one token at a time. This makes more sense than truncating an equal percent
    # of tokens from each, since if one sequence is very short then each token
    # that's truncated likely contains more information than a longer sequence.
    while True:
        total_length = len(tokens_a) + len(tokens_b)
        if total_length <= max_length:
            break
        if len(tokens_a) > len(tokens_b):
            tokens_a.pop()
        else:
            tokens_b.pop()


def simple_accuracy(preds, labels):
    """Return the fraction of positions at which ``preds`` equals ``labels``.

    The comparison result must expose ``.mean()``, so both arguments are
    expected to be array-like objects with element-wise ``==`` (e.g. numpy
    arrays) rather than plain Python lists.
    """
    matches = preds == labels
    return matches.mean()


def seq_accuracy(preds, labels):
    """Return the mean of the per-sequence accuracies.

    For every sequence ``preds[i]`` the fraction of positions equal to
    ``labels[i]`` is computed, and those fractions are then averaged.

    Args:
        preds: Iterable of array-like predicted label sequences.
        labels: Indexable collection of array-like gold label sequences,
            aligned with ``preds``.

    Returns:
        numpy.float64: Average per-sequence accuracy (``nan`` when ``preds``
        is empty, as produced by ``np.mean`` on an empty list).
    """
    # The per-sequence scores are collected in a plain list, which has no
    # ``.mean()`` method; averaging therefore has to go through ``np.mean``.
    per_sequence = [
        np.mean(np.asarray(pred) == np.asarray(labels[idx]))
        for idx, pred in enumerate(preds)
    ]
    return np.mean(per_sequence)


def acc_and_f1(preds, labels):
    """Compute accuracy, F1 and their arithmetic mean for the predictions.

    Returns:
        dict: Keys ``"acc"``, ``"f1"`` and ``"acc_and_f1"`` (the average of
        the two scores).
    """
    accuracy = simple_accuracy(preds, labels)
    f1 = f1_score(y_true=labels, y_pred=preds)
    combined = (accuracy + f1) / 2
    return dict(acc=accuracy, f1=f1, acc_and_f1=combined)


def all_metrics(preds, labels):
    """Compute accuracy, precision, recall and F1 for the predictions.

    Returns:
        dict: Keys ``"acc"``, ``"precision"``, ``"recall"`` and ``"f1"``,
        in that order.
    """
    scores = {"acc": simple_accuracy(preds, labels)}
    f1 = f1_score(y_true=labels, y_pred=preds)
    scores["precision"] = precision_score(y_true=labels, y_pred=preds)
    scores["recall"] = recall_score(y_true=labels, y_pred=preds)
    # F1 is computed second but reported last, matching the result layout.
    scores["f1"] = f1
    return scores


def compute_metrics(preds, labels):
    """Check that predictions and labels align, then return ``all_metrics``.

    Raises:
        AssertionError: If ``preds`` and ``labels`` differ in length.
    """
    n_preds, n_labels = len(preds), len(labels)
    assert n_preds == n_labels
    return all_metrics(preds, labels)


# Registry mapping a lowercase task name to the processor class that is
# instantiated (with no arguments) to read that task's examples and labels.
processors = {
    "trofi": TrofiProcessor,
}

# Registry mapping a task name to the output mode string handed to the
# feature-conversion step.
output_modes = {
    "trofi": "classification",
}

In [None]:
#@title Data Loader


def load_train_data(args, logger, processor, task_name, label_list, tokenizer, output_mode, model_type, data_dir, max_seq_length, use_pos, use_local_context, train_batch_size, log_dir, k=None):
    """Build the shuffled training DataLoader for fold ``k`` and save it to disk.

    The training examples are read through ``processor``, converted to
    features, packed into a ``TensorDataset`` and wrapped in a randomly
    sampled ``DataLoader``. The loader is also serialized with ``torch.save``
    to ``log_dir`` under the name ``f"{k}{TRAIN_DATALOADER_NAME}"``.

    Args:
        args: Accepted for interface compatibility; not used here (the
            literal string "args" is forwarded to the feature converter).
        logger: Accepted for interface compatibility; not used here.
        processor: Object providing ``get_train_examples(data_dir, k)``.
        task_name (str): Accepted for interface compatibility; not used here.
        label_list (list): Labels forwarded to the feature converter.
        tokenizer: Tokenizer forwarded to the feature converter.
        output_mode (str): Output mode forwarded to the feature converter.
        model_type (str): Accepted for interface compatibility; not used here.
        data_dir (str): Directory the processor reads the examples from.
        max_seq_length (int): Maximum sequence length of the features.
        use_pos (bool): Forwarded to the feature converter.
        use_local_context (bool): Forwarded to the feature converter.
        train_batch_size (int): Batch size of the returned DataLoader.
        log_dir (str): Directory the DataLoader is saved into.
        k: Fold index; used to pick the examples and to prefix file names.

    Returns:
        DataLoader: Loader whose batches are ordered as (input_ids,
        input_mask, segment_ids, label_ids, input_ids_2, input_mask_2,
        segment_ids_2, sentence_masks, word_ids).
    """
    features = convert_examples_to_two_features(
        examples=processor.get_train_examples(data_dir, k),
        label_list=label_list,
        max_seq_length=max_seq_length,
        tokenizer=tokenizer,
        output_mode=output_mode,
        args="args",
        use_pos=use_pos,
        use_local_context=use_local_context,
        k=k,
    )

    # Feature attributes, listed in the column order of the dataset.
    field_order = (
        "input_ids",
        "input_mask",
        "segment_ids",
        "label_id",
        "input_ids_2",
        "input_mask_2",
        "segment_ids_2",
        "sentence_mask",
        "word_index",
    )
    columns = [
        torch.tensor([getattr(feature, field) for feature in features], dtype=torch.long)
        for field in field_order
    ]
    dataset = TensorDataset(*columns)

    loader = DataLoader(
        dataset, sampler=RandomSampler(dataset), batch_size=train_batch_size
    )

    # Persist the loader so later runs can reload it instead of re-featurizing.
    torch.save(loader, os.path.join(log_dir, f"{str(k)}{TRAIN_DATALOADER_NAME}"))
    return loader


def load_test_data(args, logger, processor, task_name, label_list, tokenizer, output_mode, data_dir, model_type, max_seq_length, eval_batch_size, use_local_context, use_pos, log_dir,  k=None):
    """Build the sequential evaluation DataLoader for fold ``k`` and save it.

    The test examples are read through ``processor``, converted to features
    and packed into a ``TensorDataset`` iterated in order. Both the loader
    and the list of example GUIDs are serialized with ``torch.save`` into
    ``log_dir`` (file names prefixed with ``k``).

    Args:
        args: Accepted for interface compatibility; not used here (the
            literal string "args" is forwarded to the feature converter).
        logger: Logger used to announce the evaluation.
        processor: Object providing ``get_test_examples``.
        task_name (str): ``"vua"`` or ``"trofi"``; selects how the processor
            is called.
        label_list (list): Labels forwarded to the feature converter.
        tokenizer: Tokenizer forwarded to the feature converter.
        output_mode (str): Output mode forwarded to the feature converter.
        data_dir (str): Directory the processor reads the examples from.
        model_type (str): Must be ``"MELBERT_MIP"`` or ``"MELBERT"``.
        max_seq_length (int): Maximum sequence length of the features.
        eval_batch_size (int): Batch size of the returned DataLoader.
        use_local_context (bool): Forwarded to the feature converter.
        use_pos (bool): Forwarded to the feature converter.
        log_dir (str): Directory the artifacts are saved into.
        k: Fold index; used to pick the examples and to prefix file names.

    Returns:
        tuple: ``(all_guids, eval_dataloader)``. Batches are ordered as
        (input_ids, input_mask, segment_ids, label_ids, idx, input_ids_2,
        input_mask_2, segment_ids_2, sentence_masks, word_ids).

    Raises:
        ValueError: If ``task_name`` or ``model_type`` is not supported.
    """
    if task_name == "vua":
        eval_examples = processor.get_test_examples(data_dir)
    elif task_name == "trofi":
        eval_examples = processor.get_test_examples(data_dir, k)
    else:
        # A bare string cannot be raised; use a real exception type.
        raise ValueError("task_name should be 'vua' or 'trofi'!")

    # Fail with a clear message instead of hitting an unbound
    # ``eval_features`` further down.
    if model_type not in ["MELBERT_MIP", "MELBERT"]:
        raise ValueError(
            f"model_type should be 'MELBERT_MIP' or 'MELBERT', got {model_type!r}"
        )

    eval_features = convert_examples_to_two_features(
        examples=eval_examples, label_list=label_list, max_seq_length=max_seq_length, tokenizer=tokenizer, output_mode=output_mode, args="args", use_local_context=use_local_context, use_pos=use_pos, k=k
    )

    logger.info("***** Running evaluation *****")

    def _long_tensor(field):
        """Stack one feature attribute across all examples as a long tensor."""
        return torch.tensor(
            [getattr(feature, field) for feature in tqdm(eval_features)], dtype=torch.long
        )

    all_input_ids = _long_tensor("input_ids")
    all_input_mask = _long_tensor("input_mask")
    all_segment_ids = _long_tensor("segment_ids")
    all_guids = [feature.guid for feature in tqdm(eval_features)]
    # Position of each example, so predictions can be matched back to GUIDs.
    all_idx = torch.tensor([i for i in tqdm(range(len(eval_features)))], dtype=torch.long)
    all_label_ids = _long_tensor("label_id")
    all_input_ids_2 = _long_tensor("input_ids_2")
    all_input_mask_2 = _long_tensor("input_mask_2")
    all_segment_ids_2 = _long_tensor("segment_ids_2")
    all_sentence_masks = _long_tensor("sentence_mask")
    all_word_ids = _long_tensor("word_index")

    eval_data = TensorDataset(
        all_input_ids,
        all_input_mask,
        all_segment_ids,
        all_label_ids,
        all_idx,
        all_input_ids_2,
        all_input_mask_2,
        all_segment_ids_2,
        all_sentence_masks,
        all_word_ids,
    )

    # Run prediction for full data, in the original example order.
    eval_sampler = SequentialSampler(eval_data)
    eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=eval_batch_size)

    eval_dataloader_file = os.path.join(log_dir, f"{str(k)}{EVAL_DATALOADER_NAME}")
    guids_file = os.path.join(log_dir, f"{str(k)}{ALL_GUIDS_NAME}")

    torch.save(eval_dataloader, eval_dataloader_file)
    torch.save(all_guids, guids_file)

    return all_guids, eval_dataloader

In [None]:
#@title MelBERT
# Sentence encoder created once when this cell runs. The model class below
# reads it as a module-level global inside forward() to embed decoded text.
sentence_transformer = SentenceTransformer("danjohnvelasco/filipino-sentence-roberta-v1")

class AutoModelForSequenceClassification_SPV_MIP(nn.Module):
    """MelBERT model for sequence classification with SPV and MIP.

    The target-word representation comes from ``encoder``; the sentence and
    definition representations come from the module-level
    ``sentence_transformer``, applied to the decoded text of the inputs.

    Args:
        args: Additional arguments.
        Model: Pre-trained model instance.
        config: Configuration for the model.
        drop_ratio (float): Dropout ratio.
        classifier_hidden (int): Hidden size for the classifier.
        use_average (bool): Flag indicating whether to use average pooling.
        use_cls (bool): Flag indicating whether to use CLS token.
        tokenizer: Tokenizer instance.
        num_labels (int): Number of labels for classification.

    Attributes:
        num_labels (int): Number of labels for classification.
        encoder: Pre-trained model instance.
        config: Configuration for the model.
        dropout: Dropout layer.
        cosine: Cosine similarity layer.
        args: Additional arguments.
        drop_ratio (float): Dropout ratio.
        classifier_hidden (int): Hidden size for the classifier.
        use_average (bool): Flag indicating whether to use average pooling.
        use_cls (bool): Flag indicating whether to use CLS token.
        tokenizer: Tokenizer instance.
        SPV_linear: Linear layer for SPV mechanism.
        MIP_linear: Linear layer for MIP mechanism.
        classifier: Linear layer for final classification.
        logsoftmax: Log-Softmax layer.
    """

    def __init__(self, args, Model, config, drop_ratio, classifier_hidden, use_average, use_cls, tokenizer, num_labels=2):
        """Initialize the model"""
        super(AutoModelForSequenceClassification_SPV_MIP, self).__init__()
        self.num_labels = num_labels
        self.encoder = Model
        self.config = config
        self.dropout = nn.Dropout(drop_ratio)
        self.cosine = nn.CosineSimilarity()
        self.args = args
        self.drop_ratio = drop_ratio
        self.classifier_hidden = classifier_hidden
        self.use_average = use_average
        self.use_cls = use_cls
        self.tokenizer = tokenizer

        # Each head sees two hidden-size vectors plus one cosine-similarity
        # scalar, hence ``hidden_size * 2 + 1`` input features.
        self.SPV_linear = nn.Linear(config.hidden_size * 2+1, classifier_hidden)
        self.MIP_linear = nn.Linear(config.hidden_size * 2+1, classifier_hidden)
        self.classifier = nn.Linear(classifier_hidden * 2, num_labels)
        self._init_weights(self.SPV_linear)
        self._init_weights(self.MIP_linear)

        self.logsoftmax = nn.LogSoftmax(dim=1)
        self._init_weights(self.classifier)

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def _embed_decoded_text(self, token_ids, device):
        """Decode each row of ``token_ids`` to text and embed it.

        Args:
            token_ids: Tensor of token ids, one row per example.
            device: Device the returned embeddings are placed on.

        Returns:
            torch.FloatTensor: One sentence embedding per row, with no
            gradient flowing back into the sentence encoder.
        """
        texts = [
            self.tokenizer.decode(row, skip_special_tokens=True)
            for row in token_ids.cpu().detach().numpy().tolist()
        ]
        embeddings = sentence_transformer.encode(texts)
        return torch.tensor(embeddings, requires_grad=False, dtype=torch.float, device=device)

    def forward(
        self,
        input_ids,
        input_ids_2,
        target_mask,
        target_mask_2,
        attention_mask_2,
        token_type_ids=None,
        attention_mask=None,
        labels=None,
        head_mask=None,
        main_sentence_mask=None,
    ):
        """
        Inputs:
            `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] with the first input token indices in the vocabulary
            `input_ids_2`: a torch.LongTensor of shape [batch_size, sequence_length] with the second input token indices
            `target_mask`: a tensor of shape [batch_size, sequence_length] with the mask for target word in the first input. 1 for target word and 0 otherwise.
            `target_mask_2`: accepted for interface compatibility; not used by this implementation.
            `attention_mask_2`: accepted for interface compatibility; not used by this implementation.
            `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token types indices
                selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to a `sentence B` token (see BERT paper for more details).
            `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices selected in [0, 1] for the first input.
            `labels`: optional labels for the classification output, with indices selected in [0, ..., num_labels].
            `head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1.
                It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked.
            `main_sentence_mask`: accepted for interface compatibility; not used by this implementation.

        Returns:
            The unweighted NLL loss when `labels` is given, otherwise the
            log-probabilities of shape [batch_size, num_labels].
        """

        # First encoder for full sentence
        outputs = self.encoder(
            input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
        )
        sequence_output = outputs[0]  # [batch, max_len, hidden]

        # Keep every tensor on the device the encoder actually ran on. Moving
        # the sentence embeddings to CUDA merely because a GPU exists breaks
        # CPU runs on GPU machines and non-default-GPU placements.
        device = sequence_output.device

        # Get target output with target mask
        target_output = sequence_output * target_mask.unsqueeze(2)

        # dropout
        target_output = self.dropout(target_output)

        target_output = target_output.mean(1)  # [batch, hidden]

        # SPV: compare the target word with the embedding of the whole first input.
        pooled_output = self._embed_decoded_text(input_ids, device)
        cos_sim_SPV = self.cosine(target_output, pooled_output).unsqueeze(1)

        # MIP: compare the target word with the embedding of the second input.
        target_output_2 = self._embed_decoded_text(input_ids_2, device)
        cos_sim_MIP = self.cosine(target_output_2, target_output).unsqueeze(1)

        # Get hidden vectors each from SPV and MIP linear layers
        SPV_hidden = self.SPV_linear(torch.cat([pooled_output, target_output, cos_sim_SPV], dim=1))
        MIP_hidden = self.MIP_linear(torch.cat([target_output_2, target_output, cos_sim_MIP], dim=1))

        logits = self.classifier(self.dropout(torch.cat([SPV_hidden, MIP_hidden], dim=1)))
        logits = self.logsoftmax(logits)
        if labels is not None:
            loss_fct = nn.NLLLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

            return loss
        return logits

In [None]:
#@title Main


# File names for saved artifacts.
# NOTE(review): CONFIG_NAME, WEIGHTS_NAME and ARGS_NAME are not referenced in
# the code visible here; presumably used by the model save/load helpers — confirm.
CONFIG_NAME = "config.json"
WEIGHTS_NAME = "pytorch_model.bin"
ARGS_NAME = "training_args.bin"
# Suffixes of the per-fold files written by the data loaders; the fold index
# is prepended to each name.
TRAIN_DATALOADER_NAME = "train_dataloader.bin"
EVAL_DATALOADER_NAME = "eval_dataloader.bin"
ALL_GUIDS_NAME = "all_guids.bin"

# Results file that run_eval opens in append mode.
OUTPUT_FILE = "/content/drive/MyDrive/CMSC-190/saves/results.txt"


def _average_fold_results(fold_results):
    """Return the per-key arithmetic mean of a non-empty list of metric dicts."""
    # Imported locally because `copy` is not imported in the notebook's
    # Imports cell.
    import copy

    averaged = copy.deepcopy(fold_results[0])
    for fold_result in fold_results[1:]:
        for metric, value in fold_result.items():
            averaged[metric] += value
    for metric in averaged:
        averaged[metric] /= len(fold_results)
    return averaged


def _log_average_result(logger, fold_results):
    """Log the averaged metrics of all folds, one sorted key per line."""
    avg_result = _average_fold_results(fold_results)
    logger.info(f"-----Averge Result-----")
    for key in sorted(avg_result.keys()):
        logger.info(f"  {key} = {str(avg_result[key])}")


def main():
    """Run k-fold training and/or evaluation with the settings defined below.

    All settings are local variables of this function. When ``do_train`` is
    set, a fresh model is trained per fold from the saved training
    DataLoaders. When ``"saves"`` is part of ``bert_model`` the trained model
    is then loaded from ``log_dir``, and when ``do_eval`` or ``do_test`` is
    set it is evaluated on every fold. Per-fold metrics are averaged and
    logged.
    """

    # Bert pre-trained model selected in the list (default = roberta-base) (bert-base-cased / roberta-base / albert-base-v1 / albert-large-v1)
    bert_model = "/content/drive/MyDrive/CMSC-190/saves"
    # The input data dir. Should contain the .tsv files (default = vua-20) (vua-20 / vua-18 / vua-verb / moh / trofi / genres / pos )
    data_dir = "/content/drive/MyDrive/CMSC-190/Data2"

    log_dir = "/content/drive/MyDrive/CMSC-190/saves"
    # The name of the task to train (default = vua) (vua(1-fold) / trofi(10-fold))
    task_name = "trofi"
    # The name of model type (default = MELBERT) (BERT_BASE / BERT_SEQ / MELBERT_SPV / MELBERT_MIP / MELBERT)
    model_type = "MELBERT"
    # The hidden dimension for classifier (default = 768)
    classifier_hidden = 768
    # Learning rate scheduler (default = warmup_linear) (none / warmup_linear)
    lr_schedule = "none"
    # Training epochs to perform linear learning rate warmup for. (default = 2)
    warmup_epoch = 0
    # Dropout ratio (default = 0.2)
    drop_ratio = 0.2
    # K-fold (default = 10)
    kfold = 5
    # Number of bagging (default = 0) (0 not for using bagging technique)
    # (not read anywhere in this function)
    num_bagging = 2
    # The index of bagging only for the case using bagging technique (default = 0)
    # (not read anywhere in this function)
    bagging_index = 0

    # Use additional linguistic features
    # POS tag (default = True)
    use_pos = True
    # Local context (default = True)
    use_local_context= True

    # The type of sentence embedding representation
    # Can be either spv, mip, both or none (default = both)
    use_average = "both"
    # Can be either spv, mip, both or none (default = none)
    use_cls = "none"

    # The maximum total input sequence length after WordPiece tokenization. (default = 150)
    max_seq_length = 150
    # Whether to run training (default = True)
    do_train = False
    # Whether to run eval on the test set (default = True)
    do_test = True
    # Whether to run eval on the dev set. (default = True)
    do_eval = True
    # Set this flag if you are using an uncased model. (default = False)
    do_lower_case = False
    # Weight of metaphor. (default = 3.0)
    class_weight = 3
    # Total batch size for training. (default = 32)
    train_batch_size = 32
    # Total batch size for eval. (default = 8)
    eval_batch_size = 8
    # The initial learning rate for Adam (default = 3e-5)
    learning_rate = 3e-5
    # Total number of training epochs to perform. (default = 3.0)
    num_train_epoch = 1

    # Whether not to use CUDA when available (default = False)
    no_cuda = False
    # random seed for initialization (default = 1)
    seed = 1

    requests.packages.urllib3.disable_warnings()

    # logger: when starting from a non-saved model, make sure the log
    # directory exists before the logger is created.
    if "saves" not in bert_model:
        os.makedirs(log_dir, exist_ok=True)
    logger = Logger(log_dir)

    # set CUDA devices
    device = torch.device("cuda" if torch.cuda.is_available() and not no_cuda else "cpu")
    n_gpu = torch.cuda.device_count()

    logger.info("device: {} n_gpu: {}".format(device, n_gpu))

    # set seed
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(seed)

    # get dataset and processor
    task_name = task_name.lower()
    processor = processors[task_name]()
    output_mode = output_modes[task_name]
    label_list = processor.get_labels()
    num_labels = len(label_list)

    # build tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(bert_model, do_lower_case=do_lower_case)
    model = load_pretrained_model("args", bert_model=bert_model, model_type=model_type, num_labels=num_labels, device=device, n_gpu=n_gpu, no_cuda=no_cuda, drop_ratio=drop_ratio, classifier_hidden=classifier_hidden, use_average=use_average, use_cls=use_cls, tokenizer=tokenizer)

    ########### Training ###########
    # TroFi / MOH-X (K-fold)
    if do_train and task_name == "trofi":
        k_result = []
        for k in tqdm(range(kfold), desc="K-fold"):
            # Every fold starts from a freshly loaded pre-trained model.
            model = load_pretrained_model("args", bert_model=bert_model, model_type=model_type, num_labels=num_labels, device=device, n_gpu=n_gpu, no_cuda=no_cuda, drop_ratio=drop_ratio, classifier_hidden=classifier_hidden, use_average=use_average, use_cls=use_cls, tokenizer=tokenizer)
            # The training DataLoader is read from the file that
            # load_train_data saves, rather than being rebuilt here.
            train_dataloader = torch.load(os.path.join(log_dir, f"{str(k)}{TRAIN_DATALOADER_NAME}"))
            model, best_result = run_train(
                args="args",
                logger=logger,
                model=model,
                train_dataloader=train_dataloader,
                processor=processor,
                task_name=task_name,
                label_list=label_list,
                tokenizer=tokenizer,
                output_mode=output_mode,
                num_train_epoch=num_train_epoch,
                lr_schedule=lr_schedule,
                warmup_epoch=warmup_epoch,
                train_batch_size=train_batch_size,
                device=device,
                model_type=model_type,
                class_weight=class_weight,
                num_labels=num_labels,
                n_gpu=n_gpu,
                do_eval=do_eval,
                learning_rate=learning_rate,
                max_seq_length=max_seq_length,
                eval_batch_size=eval_batch_size,
                data_dir=data_dir,
                use_local_context=use_local_context,
                use_pos=use_pos,
                bert_model=bert_model,
                do_lower_case=do_lower_case,
                k=k,
                log_dir=log_dir,
            )
            k_result.append(best_result)

        # Calculate average result
        _log_average_result(logger, k_result)

    # Load trained model
    if "saves" in bert_model:
        model = load_trained_model(args="args", model=model, tokenizer=tokenizer, log_dir=log_dir)

    ########### Inference ###########
    # TroFi / MOH-X (K-fold)
    if (do_eval or do_test) and task_name == "trofi":
        logger.info(f"***** Evaluating with {data_dir}")
        k_result = []
        for k in tqdm(range(kfold), desc="K-fold"):
            all_guids, eval_dataloader = load_test_data(
                args="args", logger=logger, processor=processor, task_name=task_name, label_list=label_list, tokenizer=tokenizer, output_mode=output_mode, data_dir=data_dir, model_type=model_type, max_seq_length=max_seq_length, eval_batch_size=eval_batch_size, use_local_context=use_local_context, use_pos=use_pos, log_dir=log_dir, k=k
            )
            result = run_eval(args="args", logger=logger, model=model, eval_dataloader=eval_dataloader, all_guids=all_guids, task_name=task_name, bert_model=bert_model, do_lower_case=do_lower_case, device=device, model_type=model_type, num_labels=num_labels, during_train=False, k=k)
            k_result.append(result)

        # Calculate average result
        _log_average_result(logger, k_result)
    logger.info(f"Saved to {logger.log_dir}")


def run_train(
    args,
    logger,
    model,
    train_dataloader,
    processor,
    task_name,
    label_list,
    tokenizer,
    output_mode,
    num_train_epoch,
    lr_schedule,
    warmup_epoch,
    train_batch_size,
    device,
    model_type,
    class_weight,
    num_labels,
    n_gpu,
    do_eval,
    learning_rate,
    max_seq_length,
    eval_batch_size,
    data_dir,
    use_local_context,
    use_pos,
    bert_model,
    do_lower_case,
    log_dir,
    k=None,
):
    """Run the training loop for the model.

    Args:
        args: Additional arguments; forwarded to ``save_model``.
        logger: Logger for logging information.
        model: Pre-trained model instance.
        train_dataloader: DataLoader for training data.
        processor: Processor for data preprocessing (not used here).
        task_name: Name of the task; decides where the model is saved.
        label_list: List of labels (not used here).
        tokenizer: Tokenizer instance; forwarded to ``save_model``.
        output_mode: Output mode for the model (not used here).
        num_train_epoch: Number of training epochs.
        lr_schedule: Learning rate schedule; ``False``, ``None`` or "none"
            (any casing) disables the scheduler.
        warmup_epoch: Number of warmup epochs.
        train_batch_size: Batch size for training (only logged here).
        device: Device for training (e.g., "cuda" or "cpu").
        model_type: Type of the model.
        class_weight: Loss weight of the second class; the first has weight 1.
        num_labels: Number of labels for classification.
        n_gpu: Number of GPUs available.
        do_eval: Flag indicating whether to perform evaluation during training.
        learning_rate: Learning rate for optimization.
        max_seq_length: Maximum sequence length (not used here).
        eval_batch_size: Batch size for evaluation (not used here).
        data_dir: Directory the model is saved into for the "vua" task.
        use_local_context: Flag indicating whether to use local context (not used here).
        use_pos: Flag indicating whether to use part-of-speech information (not used here).
        bert_model: BERT model name or path.
        do_lower_case: Flag indicating whether to convert text to lowercase.
        log_dir: Directory holding the saved evaluation artifacts and, for
            the "trofi" task, the saved model.
        k: Fold index used to prefix the evaluation artifact file names.

    Returns:
        tuple: Trained model and the best evaluation result (an empty dict
        when ``do_eval`` is false).
    """

    num_train_optimization_steps = len(train_dataloader) * num_train_epoch

    # Prepare optimizer, scheduler
    param_optimizer = list(model.named_parameters())
    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
            "weight_decay": 0.01,
        },
        {
            "params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate)

    # Both conditions must hold. Joining them with `or` made the test true
    # for every string, so "none" still decayed the learning rate linearly.
    use_scheduler = lr_schedule not in (False, None) and str(lr_schedule).lower() != "none"
    scheduler = None
    if use_scheduler:
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=int(warmup_epoch * len(train_dataloader)),
            num_training_steps=num_train_optimization_steps,
        )

    logger.info("***** Running training *****")
    logger.info(f"  Batch size = {train_batch_size}")
    logger.info(f"  Num steps = { num_train_optimization_steps}")

    # The weighted loss does not change between steps; build it once.
    loss_fct = nn.NLLLoss(weight=torch.Tensor([1, class_weight]).to(device))

    # The saved evaluation artifacts are the same for every epoch; loading
    # them up front also surfaces a missing file before any training is done.
    if do_eval:
        all_guids = torch.load(os.path.join(log_dir, f"{str(k)}{ALL_GUIDS_NAME}"))
        eval_dataloader = torch.load(os.path.join(log_dir, f"{str(k)}{EVAL_DATALOADER_NAME}"))

    # Run training
    max_val_f1 = -1
    max_result = {}
    for epoch in trange(int(num_train_epoch), desc="Epoch"):
        # run_eval switches the model to eval mode; re-enable training mode
        # (dropout) at the start of every epoch.
        model.train()
        tr_loss = 0
        for batch in tqdm(train_dataloader, desc="Iteration"):
            # move batch data to gpu
            batch = tuple(t.to(device) for t in batch)

            (
                input_ids,
                input_mask,
                segment_ids,
                label_ids,
                input_ids_2,
                input_mask_2,
                segment_ids_2,
                sentence_mask,
                word_indices,
            ) = batch

            # compute loss values; the model returns log-probabilities
            # because no labels are passed to it.
            logits = model(
                input_ids,
                input_ids_2,
                target_mask=(segment_ids == 1),
                target_mask_2=segment_ids_2,
                attention_mask_2=input_mask_2,
                token_type_ids=segment_ids,
                attention_mask=input_mask,
                main_sentence_mask=sentence_mask,
            )
            loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1))

            # average loss if on multi-gpu.
            if n_gpu > 1:
                loss = loss.mean()

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            if scheduler is not None:
                scheduler.step()

            optimizer.zero_grad()

            tr_loss += loss.item()

        cur_lr = optimizer.param_groups[0]["lr"]
        logger.info(f"[epoch {epoch+1}] ,lr: {cur_lr} ,tr_loss: {tr_loss}")

        # evaluate
        if do_eval:
            result = run_eval(args="args", logger=logger, model=model, eval_dataloader=eval_dataloader, all_guids=all_guids, task_name=task_name, bert_model=bert_model, do_lower_case=do_lower_case, device=device, model_type=model_type, num_labels=num_labels, during_train=True, k=k)

            # update: keep the best F1 seen so far and checkpoint on improvement
            if result["f1"] > max_val_f1:
                max_val_f1 = result["f1"]
                max_result = result
                if task_name == "trofi":
                    save_model(args, model, tokenizer, log_dir)
            if task_name == "vua":
                save_model(args, model, tokenizer, data_dir)

    logger.info(f"-----Best Result-----")
    for key in sorted(max_result.keys()):
        logger.info(f"  {key} = {str(max_result[key])}")

    return model, max_result


def get_metaphor_groups(lst):
    """Collect the index runs of consecutive 1s found in ``lst``.

    Args:
        lst (list): A list of binary values indicating metaphor presence.

    Returns:
        list: One inner list per maximal run of values equal to 1, holding
        the consecutive indexes of that run, in order of appearance.
    """
    groups = []
    run_start = None  # index where the currently open run began, if any
    for position, flag in enumerate(lst):
        if flag == 1:
            if run_start is None:
                run_start = position
        elif run_start is not None:
            groups.append(list(range(run_start, position)))
            run_start = None
    # A run that reaches the end of the input is closed here.
    if run_start is not None:
        groups.append(list(range(run_start, position + 1)))
    return groups


def get_correct_predictions(actual_labels, predicted_labels):
    """Compare actual metaphor groups with predicted ones.

    Args:
        actual_labels (list): Index groups of the actual metaphors.
        predicted_labels (list): Index groups of the predicted metaphors.

    Returns:
        tuple: ``(n_actual, n_correct, correct_groups, wrong_literal,
        actual_multiword, predicted_multiword)`` where ``correct_groups``
        are the actual groups that also appear in the predictions,
        ``wrong_literal`` is True when there is no actual metaphor but at
        least one was predicted, ``actual_multiword`` is True when any actual
        group spans more than one index, and ``predicted_multiword`` is True
        when any correctly predicted group does.
    """
    matched = [group for group in actual_labels if group in predicted_labels]
    actual_multiword = any(len(group) > 1 for group in actual_labels)
    predicted_multiword = any(len(group) > 1 for group in matched)
    wrong_literal = len(actual_labels) == 0 and len(predicted_labels) > len(actual_labels)
    return (
        len(actual_labels),
        len(matched),
        matched,
        wrong_literal,
        actual_multiword,
        predicted_multiword,
    )


def run_eval(args, logger, model, eval_dataloader, all_guids, task_name, bert_model, do_lower_case, device, model_type, num_labels, during_train, k,  return_preds=False):
    """Evaluate ``model`` on ``eval_dataloader`` and write a detailed report.

    For every example the decoded input, its word indices, the softmax of the
    model output and the gold label are appended to ``OUTPUT_FILE``.  Examples
    are grouped into sentences by comparing the decoded text of consecutive
    examples; once a sentence is complete its per-word labels are turned into
    metaphor groups and tallied (single / multiple / multiword metaphors and
    literal sentences).  When ``during_train`` is false, correctly handled
    sentences are additionally written to the per-category files under
    ``/content/drive/MyDrive/CMSC-190/saves``.

    Args:
        args: Accepted for interface compatibility; not used by this function.
        logger: Logger that receives the final metric values.
        model: The model to be evaluated (switched to eval mode here).
        eval_dataloader: DataLoader yielding 10-tensor batches
            ``(input_ids, input_mask, segment_ids, label_ids, idx, input_ids_2,
            input_mask_2, segment_ids_2, sentence_mask, word_indices)``.
        all_guids: Accepted for interface compatibility; not used by this function.
        task_name: Accepted for interface compatibility; not used by this function.
        bert_model: Name or path passed to ``AutoTokenizer.from_pretrained``.
        do_lower_case: Forwarded to the tokenizer.
        device: Device the batch tensors are moved to.
        model_type: Accepted for interface compatibility; not used by this function.
        num_labels: Number of classes; used to reshape the logits for the loss.
        during_train: When true, the per-category breakdown and the category
            files are skipped (the files are still created/truncated).
        k: Suffix inserted into the per-category file names.
        return_preds: When true, return the predicted class ids instead of metrics.

    Returns:
        The value of ``compute_metrics(preds, out_label_ids)``, or the array of
        argmax predictions when ``return_preds`` is true.

    Raises:
        ZeroDivisionError: If ``eval_dataloader`` yields no batches.
    """
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(bert_model, do_lower_case=do_lower_case)
    loss_fct = nn.NLLLoss()

    counts = dict.fromkeys(
        (
            "total_sentences",
            "total_metaphors",
            "correct_metaphors",
            "total_literals",
            "correct_literals",
            "total_single_meta",
            "correct_single_meta",
            "total_multiple_meta",
            "correct_multiple_meta",
            "total_multiword_meta",
            "correct_multiword_meta",
        ),
        0,
    )
    eval_loss = 0
    nb_eval_steps = 0
    logits_per_batch = []
    labels_per_batch = []
    prev_sentence = None
    # Per-word labels of the sentence currently being read.
    predicted_labels = []
    actual_labels = []

    def percent(correct, total):
        """Return ``correct`` as a percentage of ``total`` (0 when ``total`` is 0)."""
        return correct / total * 100 if total else 0

    # All report files are opened in one ``with`` so they are flushed and
    # closed even when evaluation raises part-way through.
    with open(OUTPUT_FILE, "a") as writer, \
            open(f"/content/drive/MyDrive/CMSC-190/saves/single_metaphors{k}.txt", "w") as single_meta_file, \
            open(f"/content/drive/MyDrive/CMSC-190/saves/multiple_metaphors{k}.txt", "w") as multiple_meta_file, \
            open(f"/content/drive/MyDrive/CMSC-190/saves/multiword_metaphors{k}.txt", "w") as multiword_meta_file, \
            open(f"/content/drive/MyDrive/CMSC-190/saves/literals{k}.txt", "w") as literal_file:

        def score_sentence(sentence):
            """Tally the buffered word labels of one finished sentence, then reset the buffers."""
            (
                n_actual,
                n_correct,
                correct_groups,
                wrong_literal,
                actual_multiword,
                predicted_multiword,
            ) = get_correct_predictions(
                get_metaphor_groups(actual_labels), get_metaphor_groups(predicted_labels)
            )

            if n_actual > 0:
                counts["total_metaphors"] += n_actual
                counts["correct_metaphors"] += n_correct
                if not during_train:
                    if actual_multiword:
                        # Multiword metaphor, regardless of how many metaphors there are
                        counts["total_multiword_meta"] += n_actual
                        if predicted_multiword:
                            counts["correct_multiword_meta"] += n_correct
                            multiword_meta_file.write(f"{sentence}\n")
                            multiword_meta_file.write(f"Metaphors detected {correct_groups}\n")
                    elif n_actual == 1:
                        # Exactly one single-word metaphor
                        counts["total_single_meta"] += 1
                        if n_correct == 1 and not predicted_multiword:
                            counts["correct_single_meta"] += 1
                            single_meta_file.write(f"{sentence}\n")
                            single_meta_file.write(f"Metaphors detected {correct_groups}\n")
                    else:
                        # Several single-word metaphors
                        counts["total_multiple_meta"] += n_actual
                        if n_correct > 1 and not predicted_multiword:
                            counts["correct_multiple_meta"] += n_correct
                            multiple_meta_file.write(f"{sentence}\n")
                            multiple_meta_file.write(f"Metaphors detected {correct_groups}\n")
            else:
                # Literal sentence: correct only if no metaphor was predicted in it
                counts["total_literals"] += 1
                if not wrong_literal:
                    counts["correct_literals"] += 1
                    if not during_train:
                        literal_file.write(f"{sentence}\n")

            writer.write(f"\nMetaphors detected {correct_groups}\n")
            predicted_labels.clear()
            actual_labels.clear()

        for eval_batch in tqdm(eval_dataloader, desc="Evaluating"):
            eval_batch = tuple(t.to(device) for t in eval_batch)

            (
                input_ids,
                input_mask,
                segment_ids,
                label_ids,
                idx,
                input_ids_2,
                input_mask_2,
                segment_ids_2,
                sentence_mask,
                word_indices
            ) = eval_batch

            with torch.no_grad():
                logits = model(
                    input_ids,
                    input_ids_2,
                    target_mask=(segment_ids == 1),
                    target_mask_2=segment_ids_2,
                    attention_mask_2=input_mask_2,
                    token_type_ids=segment_ids,
                    attention_mask=input_mask,
                    main_sentence_mask=sentence_mask,
                )
                tmp_eval_loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1))
                # Convert once per batch instead of once per comparison.
                probs = torch.nn.functional.softmax(logits, dim=1).detach().cpu().numpy()

            logits_np = logits.detach().cpu().numpy()
            labels_np = label_ids.detach().cpu().numpy()
            word_indices_np = word_indices.cpu().numpy()
            tokens = [
                tokenizer.decode(ids, skip_special_tokens=True)
                for ids in input_ids.cpu().detach().numpy().tolist()
            ]

            writer.write("-------------------------------------\n")
            writer.write("Writing new batch\n")
            for i, curr_sentence in enumerate(tokens):
                if curr_sentence != prev_sentence:
                    counts["total_sentences"] += 1
                    # Nothing has been buffered before the very first sentence,
                    # so there is nothing to score yet.
                    if prev_sentence is not None:
                        score_sentence(prev_sentence)
                    prev_sentence = curr_sentence

                writer.write("Sentence " + str(i) + ": ")
                writer.write(curr_sentence)
                writer.write("\n")
                writer.write(str(word_indices_np[i]))
                writer.write("\n")
                writer.write("Prediction: ")
                np.savetxt(writer, probs[i], fmt="%10.5f", delimiter=' ', newline=' ')

                label = int(labels_np[i])
                # A tie counts as literal, matching np.argmax on the logits below;
                # every example contributes one entry so word positions stay aligned.
                predicted = 1 if probs[i][1] > probs[i][0] else 0
                predicted_labels.append(predicted)
                actual_labels.append(label)

                judge = ""
                if label == 1 and predicted == 0:
                    judge = "\t\t!!!WRONG!!! CORRECT LABEL = META, PRED = LIT"
                elif label == 0 and predicted == 1:
                    judge = "\t\t!!!WRONG!!! CORRECT LABEL = LIT, PRED = META"
                writer.write("\n")
                writer.write("Label: " + str(labels_np[i]) + str(judge) + "\n\n")

            writer.write("Loss:" + str(tmp_eval_loss.cpu().numpy()) + "\n")
            writer.write("-------------------------------------\n")
            eval_loss += tmp_eval_loss.mean().item()
            nb_eval_steps += 1

            logits_per_batch.append(logits_np)
            labels_per_batch.append(labels_np)

        # The last sentence has no successor to trigger its scoring.
        if prev_sentence is not None:
            score_sentence(prev_sentence)

        metaphor_pct = percent(counts["correct_metaphors"], counts["total_metaphors"])
        literal_pct = percent(counts["correct_literals"], counts["total_literals"])
        writer.write(f"Out of {counts['total_metaphors']} Metaphors, {counts['correct_metaphors']} were guessed correctly, or {metaphor_pct}%\n")
        writer.write(f"Out of {counts['total_literals']} Literals, {counts['correct_literals']} were guessed correctly, or {literal_pct}%\n")
        writer.write(f"There are {counts['total_sentences']} total sentences in this file\n")
        writer.write(f"There are {counts['total_single_meta']} total sentences with only 1 metaphor in this file\n")
        writer.write(f"There are {counts['total_multiple_meta']} total sentences with multiple metaphors in this file\n")
        writer.write(f"There are {counts['total_multiword_meta']} total sentences with multiword metaphors in this file\n")
        writer.write(f"There are {counts['correct_single_meta']} single metaphors guessed correctly\n")
        writer.write(f"There are {counts['correct_multiple_meta']} multiple metaphors guessed correctly\n")
        writer.write(f"There are {counts['correct_multiword_meta']} multiword metaphors guessed correctly\n")

        # Mean loss over batches; raises ZeroDivisionError for an empty dataloader.
        eval_loss = eval_loss / nb_eval_steps
        preds = np.argmax(np.concatenate(logits_per_batch, axis=0), axis=1)
        out_label_ids = np.concatenate(labels_per_batch, axis=0)
        print(preds, out_label_ids)

        # compute metrics
        result = compute_metrics(preds, out_label_ids)

        for key in sorted(result.keys()):
            logger.info(f"  {key} = {str(result[key])}")

        writer.write(f"Preds {preds}")

        if return_preds:
            return preds
        return result


def load_pretrained_model(args, bert_model, model_type, num_labels, device, n_gpu, no_cuda, drop_ratio, classifier_hidden, use_average, use_cls, tokenizer):
    """Build the classification model on top of a pre-trained encoder.

    The encoder is loaded with ``AutoModel.from_pretrained``, its token-type
    embedding table is replaced by a freshly initialised one with 4 rows, and
    the result is wrapped in ``AutoModelForSequenceClassification_SPV_MIP``.

    Args:
        args: Forwarded to the classification wrapper.
        bert_model: Pre-trained encoder name or path.
        model_type: Not used by this function.
        num_labels: Number of labels in the classification task.
        device: Device the finished model is moved to.
        n_gpu: Number of GPUs available.
        no_cuda: When true, the model is never wrapped in ``DataParallel``.
        drop_ratio: Dropout ratio, forwarded to the wrapper.
        classifier_hidden: Classifier hidden size, forwarded to the wrapper.
        use_average: Forwarded to the wrapper.
        use_cls: Forwarded to the wrapper.
        tokenizer: Forwarded to the wrapper.

    Returns:
        nn.Module: The model on ``device``; wrapped in ``torch.nn.DataParallel``
        when more than one GPU is available and CUDA is not disabled.
    """
    encoder = AutoModel.from_pretrained(bert_model)
    encoder_config = encoder.config

    # Swap in a 4-row token-type embedding table and initialise it the same
    # way the encoder initialises its own weights.
    encoder_config.type_vocab_size = 4
    type_embeddings = nn.Embedding(encoder_config.type_vocab_size, encoder_config.hidden_size)
    encoder.embeddings.token_type_embeddings = type_embeddings
    encoder._init_weights(type_embeddings)

    # Additional layers on top of the encoder
    model = AutoModelForSequenceClassification_SPV_MIP(
        args=args,
        Model=encoder,
        config=encoder_config,
        num_labels=num_labels,
        drop_ratio=drop_ratio,
        classifier_hidden=classifier_hidden,
        use_average=use_average,
        use_cls=use_cls,
        tokenizer=tokenizer,
    )
    summary(model)
    model.to(device)

    wrap_in_data_parallel = n_gpu > 1 and not no_cuda
    if not wrap_in_data_parallel:
        return model
    return torch.nn.DataParallel(model)


def save_model(args, model, tokenizer, log_dir):
    """Write the model weights, config, tokenizer vocabulary and training args to ``log_dir``.

    File names come from the module-level constants ``WEIGHTS_NAME``,
    ``CONFIG_NAME`` and ``ARGS_NAME``.

    Args:
        args: Training arguments; saved alongside the model with ``torch.save``.
        model: The trained model; if it exposes a ``module`` attribute, that
            inner module is what gets saved.
        tokenizer: Tokenizer whose vocabulary is saved into ``log_dir``.
        log_dir: Directory that receives all saved files.
    """
    weights_path = os.path.join(log_dir, WEIGHTS_NAME)
    config_path = os.path.join(log_dir, CONFIG_NAME)
    args_path = os.path.join(log_dir, ARGS_NAME)

    # Save the inner module when one is exposed, never the wrapper itself.
    unwrapped = getattr(model, "module", model)

    torch.save(unwrapped.state_dict(), weights_path)
    unwrapped.config.to_json_file(config_path)
    tokenizer.save_vocabulary(log_dir)

    # Keep the training arguments next to the weights they produced.
    torch.save(args, args_path)


def load_trained_model(args, model, tokenizer, log_dir):
    """Restore saved weights from ``log_dir`` into ``model``.

    The weights file is ``log_dir/WEIGHTS_NAME`` and is read onto the CPU.

    Args:
        args: Not used by this function.
        model: Model that receives the weights; if it exposes a ``module``
            attribute, the weights are loaded into that inner module.
        tokenizer: Not used by this function.
        log_dir: Directory containing the saved weights file.

    Returns:
        nn.Module: The same ``model`` object that was passed in, with weights loaded.
    """
    weights_path = os.path.join(log_dir, WEIGHTS_NAME)

    # Load into the inner module when one is exposed, otherwise the model itself.
    target = getattr(model, "module", model)
    state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
    target.load_state_dict(state_dict)

    return model


# Entry point: call main() when this code is executed directly rather than imported.
if __name__ == "__main__":
    main()

device: cuda n_gpu: 1
Some weights of the model checkpoint at /content/drive/MyDrive/CMSC-190/saves were not used when initializing XLMRobertaModel: ['encoder.encoder.layer.4.attention.self.key.weight', 'encoder.encoder.layer.5.attention.self.value.bias', 'encoder.encoder.layer.1.output.dense.weight', 'encoder.encoder.layer.0.output.LayerNorm.weight', 'encoder.encoder.layer.4.attention.output.dense.weight', 'encoder.encoder.layer.5.attention.output.dense.bias', 'encoder.encoder.layer.6.intermediate.dense.weight', 'encoder.encoder.layer.6.attention.self.key.bias', 'encoder.encoder.layer.10.attention.output.LayerNorm.bias', 'encoder.embeddings.LayerNorm.weight', 'encoder.encoder.layer.5.intermediate.dense.bias', 'encoder.encoder.layer.5.attention.self.value.weight', 'encoder.encoder.layer.3.attention.self.value.weight', 'encoder.encoder.layer.2.attention.output.dense.bias', 'encoder.encoder.layer.0.attention.output.dense.bias', 'encoder.encoder.layer.11.attention.output.LayerNorm.weight'

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>


1it [00:01,  1.44s/it][A
  result = getattr(asarray(obj), method)(*args, **kwds)

3it [00:03,  1.00s/it][A
4it [00:03,  1.27it/s][A
5it [00:04,  1.39it/s][A
6it [00:05,  1.01it/s][A
7it [00:07,  1.15s/it][A
8it [00:07,  1.11it/s][A
9it [00:07,  1.33it/s][A
10it [00:08,  1.32it/s][A
11it [00:11,  1.51s/it][A
12it [00:12,  1.23s/it][A
13it [00:12,  1.01it/s][A
14it [00:13,  1.24it/s][A
15it [00:14,  1.13it/s][A
16it [00:15,  1.06it/s][A
17it [00:16,  1.01s/it][A
18it [00:17,  1.20it/s][A
19it [00:17,  1.41it/s][A
20it [00:19,  1.13s/it][A
21it [00:20,  1.04s/it][A
22it [00:21,  1.07s/it][A
23it [00:22,  1.01s/it][A
24it [00:23,  1.10s/it][A
25it [00:25,  1.36s/it][A
26it [00:27,  1.43s/it][A
27it [00:31,  2.19s/it][A
28it [00:32,  1.92s/it][A
29it [00:32,  1.47s/it][A
30it [00:33,  1.31s/it][A
31it [00:34,  1.09s/it][A
32it [00:36,  1.32s/it][A
33it [00:37,  1.20s/it][A
34it [00:37,  1.07it/s][A
35it [00:39,  1.33s/it][A
36it [00:41,  1.53s/it][A
37it 

[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 

K-fold:  20%|██        | 1/5 [15:34<1:02:16, 934.10s/it]
  result = getattr(asarray(obj), method)(*args, **kwds)
  result = getattr(asarray(obj), method)(*args, **kwds)


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>


1it [00:01,  1.98s/it][A
2it [00:02,  1.13s/it][A
3it [00:03,  1.16s/it][A
4it [00:04,  1.03it/s][A
5it [00:04,  1.35it/s][A
6it [00:06,  1.03s/it][A
7it [00:08,  1.27s/it][A
8it [00:08,  1.00it/s][A
9it [00:09,  1.02it/s][A
10it [00:09,  1.26it/s][A
11it [00:10,  1.12it/s][A
12it [00:12,  1.01it/s][A
13it [00:12,  1.28it/s][A
15it [00:13,  1.57it/s][A
16it [00:15,  1.03it/s][A
17it [00:17,  1.33s/it][A
18it [00:19,  1.38s/it][A
19it [00:19,  1.17s/it][A
20it [00:20,  1.16s/it][A
21it [00:21,  1.02it/s][A
22it [00:21,  1.26it/s][A
23it [00:22,  1.50it/s][A
24it [00:22,  1.64it/s][A
25it [00:24,  1.11it/s][A
26it [00:24,  1.25it/s][A
27it [00:25,  1.25it/s][A
28it [00:27,  1.00it/s][A
29it [00:28,  1.23s/it][A
30it [00:29,  1.17s/it][A
31it [00:30,  1.02it/s][A
32it [00:31,  1.06it/s][A
33it [00:31,  1.32it/s][A
34it [00:31,  1.59it/s][A
35it [00:32,  1.84it/s][A
36it [00:33,  1.46it/s][A
37it [00:38,  1.96s/it][A
38it [00:39,  1.86s/it][A
39it [00:

[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 

K-fold:  40%|████      | 2/5 [30:41<45:55, 918.55s/it]  
  result = getattr(asarray(obj), method)(*args, **kwds)
  result = getattr(asarray(obj), method)(*args, **kwds)


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>


1it [00:01,  1.71s/it][A
2it [00:02,  1.30s/it][A
3it [00:03,  1.15it/s][A
4it [00:03,  1.17it/s][A
5it [00:04,  1.39it/s][A
6it [00:05,  1.19it/s][A
7it [00:06,  1.28it/s][A
8it [00:06,  1.26it/s][A
9it [00:08,  1.16s/it][A
10it [00:09,  1.05s/it][A
11it [00:10,  1.10it/s][A
12it [00:10,  1.24it/s][A
13it [00:11,  1.34it/s][A
14it [00:11,  1.54it/s][A
15it [00:12,  1.42it/s][A
16it [00:13,  1.69it/s][A
17it [00:14,  1.39it/s][A
18it [00:14,  1.66it/s][A
19it [00:15,  1.57it/s][A
20it [00:16,  1.35it/s][A
21it [00:16,  1.44it/s][A
22it [00:17,  1.63it/s][A
23it [00:18,  1.41it/s][A
24it [00:18,  1.60it/s][A
25it [00:18,  1.88it/s][A
26it [00:19,  1.87it/s][A
27it [00:19,  2.10it/s][A
28it [00:20,  2.32it/s][A
29it [00:20,  2.24it/s][A
30it [00:20,  2.43it/s][A
31it [00:21,  2.47it/s][A
32it [00:21,  2.17it/s][A
33it [00:22,  2.27it/s][A
34it [00:22,  2.07it/s][A
35it [00:24,  1.23it/s][A
36it [00:26,  1.14s/it][A
37it [00:27,  1.08s/it][A
38it [00:

[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 

K-fold:  60%|██████    | 3/5 [45:57<30:34, 917.32s/it]
  result = getattr(asarray(obj), method)(*args, **kwds)
  result = getattr(asarray(obj), method)(*args, **kwds)


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>


1it [00:00,  1.21it/s][A
2it [00:01,  1.80it/s][A
3it [00:02,  1.04it/s][A
4it [00:03,  1.36it/s][A
5it [00:05,  1.37s/it][A
6it [00:05,  1.02s/it][A
7it [00:07,  1.14s/it][A
8it [00:07,  1.00s/it][A
9it [00:10,  1.48s/it][A
10it [00:11,  1.30s/it][A
11it [00:11,  1.08s/it][A
12it [00:14,  1.44s/it][A
13it [00:14,  1.11s/it][A
14it [00:16,  1.21s/it][A
15it [00:19,  1.88s/it][A
16it [00:20,  1.49s/it][A
17it [00:20,  1.17s/it][A
18it [00:20,  1.02it/s][A
19it [00:21,  1.27it/s][A
20it [00:22,  1.18it/s][A
21it [00:22,  1.28it/s][A
22it [00:25,  1.17s/it][A
23it [00:25,  1.09it/s][A
24it [00:26,  1.06s/it][A
25it [00:27,  1.16it/s][A
26it [00:27,  1.43it/s][A
27it [00:28,  1.48it/s][A
28it [00:29,  1.08it/s][A
29it [00:29,  1.34it/s][A
30it [00:30,  1.47it/s][A
31it [00:31,  1.48it/s][A
32it [00:32,  1.19it/s][A
33it [00:33,  1.28it/s][A
34it [00:34,  1.14it/s][A
35it [00:34,  1.22it/s][A
36it [00:36,  1.07it/s][A
37it [00:36,  1.30it/s][A
38it [00:

[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 

K-fold:  80%|████████  | 4/5 [1:01:23<15:20, 920.89s/it]
  result = getattr(asarray(obj), method)(*args, **kwds)


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>


  result = getattr(asarray(obj), method)(*args, **kwds)

2it [00:00,  2.54it/s][A
3it [00:01,  2.57it/s][A
4it [00:01,  2.54it/s][A
5it [00:02,  1.31it/s][A
6it [00:04,  1.14it/s][A
7it [00:05,  1.04s/it][A
8it [00:05,  1.22it/s][A
9it [00:06,  1.45it/s][A
10it [00:06,  1.58it/s][A
11it [00:07,  1.84it/s][A
12it [00:07,  2.01it/s][A
13it [00:07,  2.22it/s][A
14it [00:08,  2.38it/s][A
15it [00:08,  2.56it/s][A
16it [00:08,  2.71it/s][A
17it [00:09,  1.89it/s][A
18it [00:10,  1.87it/s][A
19it [00:10,  1.79it/s][A
20it [00:13,  1.19s/it][A
21it [00:14,  1.08s/it][A
22it [00:15,  1.05s/it][A
23it [00:15,  1.11it/s][A
24it [00:16,  1.16it/s][A
25it [00:17,  1.39it/s][A
26it [00:17,  1.61it/s][A
27it [00:18,  1.48it/s][A
28it [00:19,  1.32it/s][A
29it [00:19,  1.52it/s][A
30it [00:20,  1.70it/s][A
31it [00:21,  1.03it/s][A
32it [00:24,  1.38s/it][A
33it [00:25,  1.48s/it][A
34it [00:28,  1.70s/it][A
35it [00:29,  1.58s/it][A
36it [00:30,  1.34s/it][A
37it 

[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 

K-fold: 100%|██████████| 5/5 [1:16:55<00:00, 923.14s/it]
-----Averge Result-----
  acc = 0.993993993993994
  f1 = 0.03333333333333334
  precision = 0.2
  recall = 0.01818181818181818
Saved to /content/drive/MyDrive/CMSC-190/saves
