In [1]:
!source ../../../scripts/mvp/bin/activate
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
!export CUDA_VISIBLE_DEVICES=1
%set_env CUDA_VISIBLE_DEVICES=1

env: CUDA_VISIBLE_DEVICES=1


In [142]:
import argparse
import os
import sys
import gc
import logging
import pickle
import random
import json
import time
import re
from itertools import permutations
from functools import partial
from collections import Counter
import pandas as pd

utils = os.path.abspath('../../utils/') # Relative path to utils scripts
sys.path.append(utils)

from evaluation import createResults, convertLabels

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
import torch.nn.functional as F

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, TQDMProgressBar, LearningRateMonitor

from transformers import AdamW, T5Tokenizer, T5ForConditionalGeneration, get_linear_schedule_with_warmup
from transformers.file_utils import ModelOutput
from transformers.models.t5.modeling_t5 import *

from const import *

from tqdm import tqdm

In [143]:
LABEL_SPACE = ['ambience general:POSITIVE', 'ambience general:NEUTRAL', 'ambience general:NEGATIVE', 'drinks prices:POSITIVE', 'drinks prices:NEUTRAL', 'drinks prices:NEGATIVE', 'drinks quality:POSITIVE', 'drinks quality:NEUTRAL', 'drinks quality:NEGATIVE', 'drinks style_options:POSITIVE', 'drinks style_options:NEUTRAL', 'drinks style_options:NEGATIVE', 'food prices:POSITIVE', 'food prices:NEUTRAL', 'food prices:NEGATIVE', 'food quality:POSITIVE', 'food quality:NEUTRAL', 'food quality:NEGATIVE', 'food style_options:POSITIVE', 'food style_options:NEUTRAL', 'food style_options:NEGATIVE', 'location general:POSITIVE', 'location general:NEUTRAL', 'location general:NEGATIVE', 'restaurant general:POSITIVE', 'restaurant general:NEUTRAL', 'restaurant general:NEGATIVE', 'restaurant miscellaneous:POSITIVE', 'restaurant miscellaneous:NEUTRAL', 'restaurant miscellaneous:NEGATIVE', 'restaurant prices:POSITIVE', 'restaurant prices:NEUTRAL', 'restaurant prices:NEGATIVE', 'service general:POSITIVE', 'service general:NEUTRAL', 'service general:NEGATIVE']

LANGUAGES = ['nl', 'de', 'en', 'cs', 'ru', 'fr', 'es']


def extract_spans_para(seq, seq_type):
    quads = []
    sents = [s.strip() for s in seq.split('[SSEP]')]
    for s in sents:
        try:
            tok_list = ["[C]", "[S]", "[A]", "[O]"]

            for tok in tok_list:
                if tok not in s:
                    s += " {} null".format(tok)
            index_ac = s.index("[C]")
            index_sp = s.index("[S]")
            index_at = s.index("[A]")
            index_ot = s.index("[O]")

            combined_list = [index_ac, index_sp, index_at, index_ot]
            arg_index_list = list(np.argsort(combined_list))

            result = []
            for i in range(len(combined_list)):
                start = combined_list[i] + 4
                sort_index = arg_index_list.index(i)
                if sort_index < 3:
                    next_ = arg_index_list[sort_index + 1]
                    re = s[start:combined_list[next_]]
                else:
                    re = s[start:]
                result.append(re.strip())

            ac, sp, at, ot = result

            # if the aspect term is implicit
            if at.lower() == 'it' or at.lower() == 'es':
                at = 'null'
        except ValueError:
            try:
                print(f'In {seq_type} seq, cannot decode: {s}')
                pass
            except UnicodeEncodeError:
                print(f'In {seq_type} seq, a string cannot be decoded')
                pass
            ac, at, sp, ot = '', '', '', ''

        quads.append((ac, at, sp, ot))

    return quads


def compute_f1_scores(pred_pt, gold_pt, verbose=False):
    """
    Function to compute F1 scores with pred and gold quads
    The input needs to be already processed
    """
    # number of true postive, gold standard, predictions
    n_tp, n_gold, n_pred = 0, 0, 0

    for i in range(len(pred_pt)):
        n_gold += len(gold_pt[i])
        n_pred += len(pred_pt[i])

        for t in pred_pt[i]:
            if t in gold_pt[i]:
                n_tp += 1

    if verbose:
        print(
            f"number of gold spans: {n_gold}, predicted spans: {n_pred}, hit: {n_tp}"
        )

    precision = float(n_tp) / float(n_pred) if n_pred != 0 else 0
    recall = float(n_tp) / float(n_gold) if n_gold != 0 else 0
    f1 = 2 * precision * recall / (
        precision + recall) if precision != 0 or recall != 0 else 0
    scores = {
        'precision': precision * 100,
        'recall': recall * 100,
        'f1': f1 * 100
    }

    return scores


def compute_scores(pred_seqs, gold_seqs, verbose=True, task="asqp"):
    """
    Compute model performance
    """
    assert len(pred_seqs) == len(gold_seqs), (len(pred_seqs), len(gold_seqs))
    num_samples = len(gold_seqs)

    all_labels, all_preds = [], []

    for i in range(num_samples):
        gold_list = extract_spans_para(gold_seqs[i], 'gold')
        pred_list = extract_spans_para(pred_seqs[i], 'pred')
        if (task == "tasd"):
          gold_list = [tup[:-1] for tup in gold_list]
          pred_list = [tup[:-1] for tup in pred_list]

        all_labels.append(gold_list)
        all_preds.append(pred_list)

    try:
        preds = [[f'{labels[0]}:{opinion2sentword[labels[2]].upper() if labels[2] in opinion2sentword else ""}:{labels[1]}' for labels in pred] for pred in all_preds]
        golds = [[f'{labels[0]}:{opinion2sentword[labels[2]].upper() if labels[2] in opinion2sentword else ""}:{labels[1]}' for labels in gold] for gold in all_labels]
    except KeyError:
        print('KeyError!')
        print(all_labels)
    except:
        print(all_labels)

    scores_dfs = createResults(preds, golds, LABEL_SPACE, task)
    
    scores = compute_f1_scores(all_preds, all_labels)
    scores["all_preds"] = all_preds
    scores["all_labels"] = all_labels
    print('MVP F1-Micro: ', scores['f1'])
    # return scores, all_labels, all_preds
    return scores_dfs, all_labels, pred_seqs


def get_element_tokens(task):
    dic = {
        "aste":
            ["[A]", "[O]", "[S]"],
        "tasd":
            ["[A]", "[C]", "[S]"],
        "aocs":
        ["[A]", "[O]", "[C]", "[S]"],
        "asqp":
            ["[A]", "[O]", "[C]", "[S]"],
    }
    return dic[task]

optim_orders_all = None

def get_orders(task, data, data_type, args, sents, labels):
    # Empfehlung: Optimale Reihenfolge der Elemente vorab 1x berechnen. 
    # uncomment to calculate orders from scratch
    
    if torch.cuda.is_available():
          device = torch.device('cuda')
    else:
          device = torch.device("cpu")
    tokenizer = T5Tokenizer.from_pretrained(args.model_name_or_path)
    model = MyT5ForConditionalGenerationScore.from_pretrained(
            args.model_name_or_path).to(device)

    # if args.dataset in optim_orders_all_predefined[task].keys():
    #     if args.
    
    global optim_orders_all

    if optim_orders_all == None:
        optim_orders_all = choose_best_order_global(sents, labels, model,
                                             tokenizer, device,
                                             args.task)

    #     print(optim_orders_all)

    if args.single_view_type == 'rank':
           orders = optim_orders_all#[task]["rest16"] # delete [task][data] falls selber berechnet werden
    elif args.single_view_type == 'rand':
           orders = [random.Random(args.seed).choice(
               optim_orders_all[task][data])]
    elif args.single_view_type == "heuristic":
           orders = heuristic_orders[task]

    del model
    return orders




def cal_entropy(inputs, preds, model_path, tokenizer, device=torch.device('cuda')):
    all_entropy = []
    model = MyT5ForConditionalGenerationScore.from_pretrained(model_path).to(
        device)
    batch_size = 8
    _inputs = [' '.join(s) for s in inputs]
    _preds = [' '.join(s) for s in preds]
    for id in range(0, len(inputs), batch_size):
        in_batch = _inputs[id: min(id + batch_size, len(inputs))]
        pred_batch = _preds[id: min(id + batch_size, len(inputs))]
        assert len(in_batch) == len(pred_batch)
        tokenized_input = tokenizer.batch_encode_plus(in_batch,
                                                      max_length=200,
                                                      padding="max_length",
                                                      truncation=True,
                                                      return_tensors="pt")
        tokenized_target = tokenizer.batch_encode_plus(pred_batch,
                                                       max_length=200,
                                                       padding="max_length",
                                                       truncation=True,
                                                       return_tensors="pt")

        target_ids = tokenized_target["input_ids"].to(device)

        target_ids[target_ids[:, :] == tokenizer.pad_token_id] = -100
        outputs = model(
            input_ids=tokenized_input["input_ids"].to(device),
            attention_mask=tokenized_input["attention_mask"].to(device),
            labels=target_ids,
            decoder_attention_mask=tokenized_target["attention_mask"].to(device))

        loss, entropy = outputs[0]
        all_entropy.extend(entropy)
    return all_entropy


def order_scores_function(quad_list, cur_sent, model, tokenizer, device, task):
    q = get_element_tokens(task)

    all_orders = permutations(q)
    all_orders_list = []

    all_targets = []
    all_inputs = []
    cur_sent = " ".join(cur_sent)

    for each_order in all_orders:
        cur_order = "  ".join(each_order) + " "
        all_orders_list.append(cur_order)
        cur_target = []
        for each_q in quad_list:
            cur_target.append(each_q[cur_order][0])

        all_inputs.append(cur_sent)
        all_targets.append(" ".join(cur_target))

    tokenized_input = tokenizer.batch_encode_plus(all_inputs,
                                                  max_length=200,
                                                  padding="max_length",
                                                  truncation=True,
                                                  return_tensors="pt")
    tokenized_target = tokenizer.batch_encode_plus(all_targets,
                                                   max_length=200,
                                                   padding="max_length",
                                                   truncation=True,
                                                   return_tensors="pt")

    target_ids = tokenized_target["input_ids"].to(device)

    target_ids[target_ids[:, :] == tokenizer.pad_token_id] = -100
    outputs = model(
        input_ids=tokenized_input["input_ids"].to(device),
        attention_mask=tokenized_input["attention_mask"].to(device),
        labels=target_ids,
        decoder_attention_mask=tokenized_target["attention_mask"].to(device))

    loss, entropy = outputs[0]
    results = {}
    for i, _ in enumerate(all_orders_list):
        cur_order = all_orders_list[i]
        results[cur_order] = {"loss": loss[i], "entropy": entropy[i]}

    return results


def choose_best_order_global(sents, labels, model, tokenizer, device, task):
    q = get_element_tokens(task)
    all_orders = permutations(q)
    all_orders_list = []
    scores = []

    for each_order in all_orders:
        cur_order = " ".join(each_order)
        all_orders_list.append(cur_order)
        scores.append(0)

    for i in range(len(sents)):
        label = labels[i]
        sent = sents[i]

        quad_list = []
        for _tuple in label:
            # parse ASTE tuple
            if task == "aste":
                _tuple = parse_aste_tuple(_tuple, sent)

            at, ac, sp, ot = get_task_tuple(_tuple, task, args.lang)

            element_dict = {"[A]": at, "[O]": ot, "[C]": ac, "[S]": sp}
            element_list = []
            for key in q:
                element_list.append("{} {}".format(key, element_dict[key]))

            x = permutations(element_list)

            permute_object = {}
            for each in x:
                order = []
                content = []
                for e in each:
                    order.append(e[0:4])
                    content.append(e[4:])
                order_name = " ".join(order)
                content = " ".join(content)
                permute_object[order_name] = [content, " ".join(each)]

            quad_list.append(permute_object)

        order_scores = order_scores_function(quad_list, sent, model, tokenizer,
                                             device, task)

        for e in order_scores:
            e = e[:-1].replace("  ", " ")
            index = all_orders_list.index(e)
            scores[index] += order_scores[e.replace(" ", "  ")+" "]['entropy']

    indexes = np.argsort(np.array(scores))  # [::-1]
    returned_orders = []
    for i in indexes:
        returned_orders.append(all_orders_list[i])

    return returned_orders


def parse_aste_tuple(_tuple, sent):
    if isinstance(_tuple[0], str):
        res = _tuple
    elif isinstance(_tuple[0], list):
        # parse at
        start_idx = _tuple[0][0]
        end_idx = _tuple[0][-1] if len(_tuple[0]) > 1 else start_idx
        at = ' '.join(sent[start_idx:end_idx + 1])

        # parse ot
        start_idx = _tuple[1][0]
        end_idx = _tuple[1][-1] if len(_tuple[1]) > 1 else start_idx
        ot = ' '.join(sent[start_idx:end_idx + 1])
        res = [at, ot, _tuple[2]]
    else:
        print(_tuple)
        raise NotImplementedError
    return res


def get_task_tuple(_tuple, task, lang):
    if task == "aste":
        at, ot, sp = _tuple
        ac = None
    elif task == "tasd":
        at, ac, sp = _tuple
        ot = None
    elif task in ["asqp", "acos"]:
        at, ac, sp, ot = _tuple
    else:
        raise NotImplementedError

    if sp:
        if True:
        # if lang == 'en':
            sp = sentword2opinion[sp.lower()] if sp in sentword2opinion \
                else senttag2opinion[sp.lower()]  # 'POS' -> 'good'
        # else:
        #     sp = sentword2opinion_GER[sp.lower()] if sp in sentword2opinion \
        #         else senttag2opinion_GER[sp.lower()]  # 'POS' -> 'gut'
    if at and at.lower() == 'null':  # for implicit aspect term
        if True:
        # if lang == 'en':
            at = 'it'
        # else:
        #     at = 'es'

    return at, ac, sp, ot


def add_prompt(sent, orders, task, data_name, args):

    # add ctrl_token
    if args.ctrl_token == "none":
        pass
    elif args.ctrl_token == "post":
        sent = sent + orders
    elif args.ctrl_token == "pre":
        sent = orders + sent
    else:
        raise NotImplementedError
    return sent


def get_para_targets(sents, labels, data_name, data_type, top_k, task, args):
    """
    Obtain the target sentence under the paraphrase paradigm
    """
    targets = []
    new_sents = []
    if task in ['aste', 'tasd']:
        # at most 5 orders for triple tasks
        top_k = min(5, top_k)

    optim_orders = get_orders(task, data_name, data_type, args, sents, labels)[:top_k]

    for i in range(len(sents)):
        label = labels[i]
        cur_sent = sents[i]
        cur_sent_str = " ".join(cur_sent)

        # ASTE: parse at & ot
        if task == 'aste':
            assert len(label[0]) == 3
            parsed_label = []
            for _tuple in label:
                parsed_tuple = parse_aste_tuple(_tuple, sents[i])
                parsed_label.append(parsed_tuple)
            label = parsed_label

        # sort label by order of appearance
        # at, ac, sp, ot
        if args.sort_label and len(label) > 1:
            label_pos = {}
            for _tuple in label:
                at, ac, sp, ot = get_task_tuple(_tuple, task, args.lang)

                # get last at / ot position
                at_pos = cur_sent_str.find(at) if at else -1
                ot_pos = cur_sent_str.find(ot) if ot else -1
                last_pos = max(at_pos, ot_pos)
                last_pos = 1e4 if last_pos < 0 else last_pos
                label_pos[tuple(_tuple)] = last_pos
            new_label = [
                list(k)
                for k, _ in sorted(label_pos.items(), key=lambda x: x[1])
            ]
            label = new_label

        quad_list = []
        for _tuple in label:
            at, ac, sp, ot = get_task_tuple(_tuple, task, args.lang)
            element_dict = {"[A]": at, "[O]": ot, "[C]": ac, "[S]": sp}
            token_end = 3

            element_list = []
            for key in optim_orders[0].split(" "):
                element_list.append("{} {}".format(key, element_dict[key]))

            x = permutations(element_list)
            permute_object = {}
            for each in x:
                order = []
                content = []
                for e in each:
                    order.append(e[0:token_end])
                    content.append(e[token_end:])
                order_name = " ".join(order)
                content = " ".join(content)
                permute_object[order_name] = [content, " ".join(each)]

            quad_list.append(permute_object)

        for o in optim_orders:
            tar = []
            for each_q in quad_list:
                tar.append(each_q[o][1])

            targets.append(" [SSEP] ".join(tar))
            # add prompt
            new_sent = add_prompt(cur_sent, o.split(), task, data_name, args)
            new_sents.append(new_sent)

    return new_sents, targets


def get_para_targets_dev(sents, labels, data_name, task, args):
    """
    Obtain the target sentence under the paraphrase paradigm
    """
    new_sents = []
    targets = []
    optim_orders = get_orders(task, data_name, "test", args, sents=None, labels=None)
    top_order = optim_orders[0].split(" ")
    for sent, label in zip(sents, labels):
        all_quad_sentences = []
        for _tuple in label:
            # parse ASTE tuple
            if task == "aste":
                _tuple = parse_aste_tuple(_tuple, sent)

            at, ac, sp, ot = get_task_tuple(_tuple, task, args.lang)

            element_dict = {"[A]": at, "[O]": ot, "[C]": ac, "[S]": sp}
            element_list = []
            for key in top_order:
                element_list.append("{} {}".format(key, element_dict[key]))

            one_quad_sentence = " ".join(element_list)
            all_quad_sentences.append(one_quad_sentence)

        target = ' [SSEP] '.join(all_quad_sentences)
        targets.append(target)

        # add prompt
        sent = add_prompt(sent, top_order, task, data_name, args)

        new_sents.append(sent)
    return new_sents, targets
    
def formatText(text):
    text = re.sub(r'([(".,!?;:/)])', r" \1", text)
    text = re.sub(r'(["„“…])', r'', text)
    text = re.sub(r'([\'])', r' \1', text)
    # text = re.sub(r'([-])', r' \1 ', text)
    text = re.sub(r'([\s\s])', r' ', text)
    text = re.sub(r"\b(I|You|We|They|He|She|It|Don|Didn|Doesn|Can|Couldn|Wouldn|Shouldn|Won|Would|Wasn|Aren|Ain|Isn|Hasn|Haven|Weren|Mightn|Mustn)('|’)(m|t|ll|ve|re|s|d)\b", r"\1 \2\3", text)
    return re.sub(r"\s+", " ", text).strip()
    
def read_line_examples_from_file(data_path,
                                 task_name,
                                 data_name,
                                 lowercase,
                                 silence=True):
    """
    Read data from file, each line is: sent####labels
    Return List[List[word]], List[Tuple]
    """
    tasks, datas = [], []
    sents, labels = [], []
    with open(data_path, "r", encoding="utf-8") as f:
        df = pd.read_json(f, orient="records", lines=True).set_index('id')
        for index, row in df.iterrows():
            tasks.append(task_name)
            datas.append(data_name)
            if lowercase:
                sents.append(formatText(row['text'].lower()).split())
                labels.append([(label[2].lower(), label[0], label[1]) for label in row['labels']])
            else:
                sents.append(formatText(row['text']).split())
                labels.append([(label[2], label[0], label[1]) for label in row['labels']])

    if silence:
        print(f"Total examples = {len(sents)}")
    return tasks, datas, sents, labels

def get_transformed_io(data_path, data_name, data_type, top_k, args):
    """
    The main function to transform input & target according to the task
    """
    tasks, datas, sents, labels = read_line_examples_from_file(
        data_path, args.task, args.dataset, args.lowercase)

    # the input is just the raw sentence
    inputs = [s.copy() for s in sents]

    # low resource
    if data_type == 'train' and args.data_ratio != 1.0:
        num_sample = int(len(inputs) * args.data_ratio)
        sample_indices = random.sample(list(range(0, len(inputs))), num_sample)
        sample_inputs = [inputs[i] for i in sample_indices]
        sample_labels = [labels[i] for i in sample_indices]
        inputs, labels = sample_inputs, sample_labels
        print(
            f"Low resource: {args.data_ratio}, total train examples = {num_sample}")
        if num_sample <= 20:
            print("Labels:", sample_labels)

    if data_type == "train" or args.eval_type == "dev" or data_type == "test":
        new_inputs, targets = get_para_targets(inputs, labels, data_name,
                                               data_type, top_k, args.task,
                                               args)
    else:
        new_inputs, targets = get_para_targets_dev(inputs, labels, data_name,
                                                   args.task, args)

    print(len(inputs), len(new_inputs), len(targets))
    return new_inputs, targets


class ABSADataset(Dataset):

    def __init__(self,
                 tokenizer,
                 task_name,
                 data_setting,
                 data_name,
                 language,
                 data_type,
                 top_k,
                 args,
                 max_len=128):

        
        if data_type == 'train':
            print('Loading Train Dataset')
            if 'multi' in data_setting:
                self.data_path = []
                for lang in [lang for lang in LANGUAGES if lang != language]:
                    self.data_path.append(f'{args.data_path}/{lang}/train_b.json')
                if data_setting == 'multi_id':
                    self.data_path.append(f'{args.data_path}/{language}/train_b.json')
            else:
                self.data_path = f'{args.data_path}/{language}/train{"_b" if data_setting == "balanced" else ""}.json'
        elif data_type == 'val':
            print('Loading Validation Dataset')
            self.data_path = f'{args.data_path}/{language}/test{"_b" if data_setting == "balanced" else ""}.json'
        else:
            print('Loading Test Dataset')
            self.data_path = f'{args.data_path}/{language}/test{"_b" if data_setting == "balanced" else ""}.json'
        self.max_len = max_len
        self.tokenizer = tokenizer
        self.task_name = task_name
        self.language = language
        self.data_setting = data_setting
        self.data_name = data_name
        self.data_type = data_type
        self.args = args

        self.top_k = top_k

        self.inputs = []
        self.targets = []

        self._build_examples()

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, index):
        source_ids = self.inputs[index]["input_ids"].squeeze()
        target_ids = self.targets[index]["input_ids"].squeeze()

        src_mask = self.inputs[index]["attention_mask"].squeeze(
        )  # might need to squeeze
        target_mask = self.targets[index]["attention_mask"].squeeze(
        )  # might need to squeeze
        return {
            "source_ids": source_ids,
            "source_mask": src_mask,
            "target_ids": target_ids,
            "target_mask": target_mask
        }

    def _build_examples(self):


        if type(self.data_path) == list:
            for data_path in self.data_path:
                if self.args.multi_task:
                    inputs, targets = get_transformed_io_unified(
                        data_path, self.task_name, f'{self.data_name}-{self.language}', self.data_type,
                        self.top_k, self.args)
                else:
                    inputs, targets = get_transformed_io(data_path,
                                                         f'{self.data_name}-{self.language}',
                                                         self.data_type, self.top_k,
                                                         self.args)
        
                for i in range(len(inputs)):
                    # change input and target to two strings
                    input = ' '.join(inputs[i])
                    target = targets[i]
        
                    tokenized_input = self.tokenizer.batch_encode_plus(
                        [input],
                        max_length=self.max_len,
                        padding="max_length",
                        truncation=True,
                        return_tensors="pt")
                    
                    # for ACOS Restaurant and Laptop dataset
                    # the max target length is much longer than 200
                    # we need to set a larger max length for inference
                    target_max_length = 1024 if self.data_type == "test" else self.max_len
        
                    tokenized_target = self.tokenizer.batch_encode_plus(
                        [target],
                        max_length=target_max_length,
                        padding="max_length",
                        truncation=True,
                        return_tensors="pt")
        
                    self.inputs.append(tokenized_input)
                    self.targets.append(tokenized_target)
        else:
            if self.args.multi_task:
                inputs, targets = get_transformed_io_unified(
                    self.data_path, self.task_name, f'{self.data_name}-{self.language}', self.data_type,
                    self.top_k, self.args)
            else:
                inputs, targets = get_transformed_io(self.data_path,
                                                     f'{self.data_name}-{self.language}',
                                                     self.data_type, self.top_k,
                                                     self.args)
    
            for i in range(len(inputs)):
                # change input and target to two strings
                input = ' '.join(inputs[i])
                target = targets[i]
    
                tokenized_input = self.tokenizer.batch_encode_plus(
                    [input],
                    max_length=self.max_len,
                    padding="max_length",
                    truncation=True,
                    return_tensors="pt")
                
                # for ACOS Restaurant and Laptop dataset
                # the max target length is much longer than 200
                # we need to set a larger max length for inference
                target_max_length = 1024 if self.data_type == "test" else self.max_len
    
                tokenized_target = self.tokenizer.batch_encode_plus(
                    [target],
                    max_length=target_max_length,
                    padding="max_length",
                    truncation=True,
                    return_tensors="pt")
    
                self.inputs.append(tokenized_input)
                self.targets.append(tokenized_target)

_CONFIG_FOR_DOC = "T5Config"

def calc_entropy(input_tensor):
    lsm = nn.LogSoftmax()
    log_probs = lsm(input_tensor)
    probs = torch.exp(log_probs)
    p_log_p = log_probs * probs
    entropy = -p_log_p.sum()
    return entropy

# add_start_docstrings("""T5 Model with a `language modeling` head on top. """, T5_START_DOCSTRING)
@add_start_docstrings("""T5 Model with a `language modeling` head on top. """, T5_START_DOCSTRING)
class MyT5ForConditionalGenerationScore(T5PreTrainedModel):
    authorized_missing_keys = [r"encoder\.embed_tokens\.weight", r"decoder\.embed_tokens\.weight", r"lm_head\.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model_dim = config.d_model

        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        encoder_config = copy.deepcopy(config)
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = T5Stack(encoder_config, self.shared)

        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.is_encoder_decoder = False
        decoder_config.num_layers = config.num_decoder_layers
        self.decoder = T5Stack(decoder_config, self.shared)

        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        self.init_weights()

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)
        self.decoder.set_input_embeddings(new_embeddings)

    def get_output_embeddings(self):
        return self.lm_head

    def get_encoder(self):
        return self.encoder

    def get_decoder(self):
        return self.decoder

    @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        head_mask=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[-100, 0, ...,
            config.vocab_size - 1]`. All labels set to ``-100`` are ignored (masked), the loss is only computed for
            labels in ``[0, ..., config.vocab_size]``
        kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
            Used to hide legacy arguments that have been deprecated.

        Returns:

        Examples::

            >>> from transformers import T5Tokenizer, T5ForConditionalGeneration

            >>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
            >>> model = T5ForConditionalGeneration.from_pretrained('t5-small', return_dict=True)

            >>> input_ids = tokenizer('The <extra_id_0> walks in <extra_id_1> park', return_tensors='pt').input_ids
            >>> labels = tokenizer('<extra_id_0> cute dog <extra_id_1> the <extra_id_2> </s>', return_tensors='pt').input_ids
            >>> outputs = model(input_ids=input_ids, labels=labels)
            >>> loss = outputs.loss
            >>> logits = outputs.logits

            >>> input_ids = tokenizer("summarize: studies have shown that owning a dog is good for you ", return_tensors="pt").input_ids  # Batch size 1
            >>> outputs = model.generate(input_ids)
        """

        if "lm_labels" in kwargs:
            warnings.warn(
                "The `lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
                FutureWarning,
            )
            labels = kwargs.pop("lm_labels")
        if "decoder_past_key_value_states" in kwargs:
            warnings.warn(
                "The `decoder_past_key_value_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
                FutureWarning,
            )
            past_key_values = kwargs.pop("decoder_past_key_value_states")
        if "decoder_past_key_values" in kwargs:
            warnings.warn(
                "The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
                FutureWarning,
            )
            past_key_values = kwargs.pop("decoder_past_key_values")
        assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."

        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            # Convert encoder inputs in embeddings if needed
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        hidden_states = encoder_outputs[0]

        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            # get decoder inputs from shifting lm labels to the right
            decoder_input_ids = self._shift_right(labels)

        # If decoding with past key value states, only the last tokens
        # should be given as an input
        if past_key_values is not None:
            assert labels is None, "Decoder should not use cached key value states when training."
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids[:, -1:]
            if decoder_inputs_embeds is not None:
                decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = decoder_outputs[0]
        # Rescale output before projecting on vocab
        # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
        sequence_output = sequence_output * (self.model_dim ** -0.5)
        lm_logits = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-100, reduction="sum")
            loss = []
            entropy = []
            for i in range(lm_logits.size()[0]):
                loss_i = loss_fct(lm_logits[i], labels[i])
                ent = calc_entropy(lm_logits[i, 0: decoder_attention_mask[i].sum().item()])
                loss.append(loss_i.item())
                entropy.append(ent.item())
            loss = [loss, entropy]
            # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
        if not return_dict:
            output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
            return ((loss,) + output) if loss is not None else output

        return Seq2SeqLMOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
    ):

        # cut decoder_input_ids if past is used
        if past is not None:
            input_ids = input_ids[:, -1:]

        return {
            "decoder_input_ids": input_ids,
            "past_key_values": past,
            "encoder_outputs": encoder_outputs,
            "attention_mask": attention_mask,
            "use_cache": use_cache,
        }

    def _reorder_cache(self, past, beam_idx):
        # if decoder past is not included in output
        # speedy decoding is disabled and no need to reorder
        if past is None:
            logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
            return past

        reordered_decoder_past = ()
        for layer_past_states in past:
            # get the correct batch idx from layer past batch dim
            # batch dim of `past` is at 2nd position
            reordered_layer_past_states = ()
            for layer_past_state in layer_past_states:
                # need to set correct `past` for each of the four key / value states
                reordered_layer_past_states = reordered_layer_past_states + (
                    layer_past_state.index_select(0, beam_idx),
                )

            assert reordered_layer_past_states[0].shape == layer_past_states[0].shape
            assert len(reordered_layer_past_states) == len(layer_past_states)

            reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
        return reordered_decoder_past

# add_start_docstrings("""T5 Model with a `language modeling` head on top. """, T5_START_DOCSTRING)
@add_start_docstrings("""T5 Model with a `language modeling` head on top. """, T5_START_DOCSTRING)
class MyT5ForConditionalGeneration(T5PreTrainedModel):
    authorized_missing_keys = [r"encoder\.embed_tokens\.weight", r"decoder\.embed_tokens\.weight", r"lm_head\.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model_dim = config.d_model

        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        encoder_config = copy.deepcopy(config)
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = T5Stack(encoder_config, self.shared)

        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.is_encoder_decoder = False
        decoder_config.num_layers = config.num_decoder_layers
        self.decoder = T5Stack(decoder_config, self.shared)

        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        self.init_weights()

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)
        self.decoder.set_input_embeddings(new_embeddings)

    def get_output_embeddings(self):
        return self.lm_head

    def get_encoder(self):
        return self.encoder

    def get_decoder(self):
        return self.decoder

    @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        head_mask=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[-100, 0, ...,
            config.vocab_size - 1]`. All labels set to ``-100`` are ignored (masked), the loss is only computed for
            labels in ``[0, ..., config.vocab_size]``
        kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
            Used to hide legacy arguments that have been deprecated.

        Returns:

        Examples::

            >>> from transformers import T5Tokenizer, T5ForConditionalGeneration

            >>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
            >>> model = T5ForConditionalGeneration.from_pretrained('t5-small', return_dict=True)

            >>> input_ids = tokenizer('The <extra_id_0> walks in <extra_id_1> park', return_tensors='pt').input_ids
            >>> labels = tokenizer('<extra_id_0> cute dog <extra_id_1> the <extra_id_2> </s>', return_tensors='pt').input_ids
            >>> outputs = model(input_ids=input_ids, labels=labels)
            >>> loss = outputs.loss
            >>> logits = outputs.logits

            >>> input_ids = tokenizer("summarize: studies have shown that owning a dog is good for you ", return_tensors="pt").input_ids  # Batch size 1
            >>> outputs = model.generate(input_ids)
        """

        if "lm_labels" in kwargs:
            warnings.warn(
                "The `lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
                FutureWarning,
            )
            labels = kwargs.pop("lm_labels")
        if "decoder_past_key_value_states" in kwargs:
            warnings.warn(
                "The `decoder_past_key_value_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
                FutureWarning,
            )
            past_key_values = kwargs.pop("decoder_past_key_value_states")
        if "decoder_past_key_values" in kwargs:
            warnings.warn(
                "The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
                FutureWarning,
            )
            past_key_values = kwargs.pop("decoder_past_key_values")
        assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."

        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            # Convert encoder inputs in embeddings if needed
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        hidden_states = encoder_outputs[0]

        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            # get decoder inputs from shifting lm labels to the right
            decoder_input_ids = self._shift_right(labels)

        # If decoding with past key value states, only the last tokens
        # should be given as an input
        if past_key_values is not None:
            assert labels is None, "Decoder should not use cached key value states when training."
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids[:, -1:]
            if decoder_inputs_embeds is not None:
                decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = decoder_outputs[0]
        # Rescale output before projecting on vocab
        # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
        sequence_output = sequence_output * (self.model_dim ** -0.5)
        lm_logits = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
            #lm_logits_max, lm_logits_max_index = torch.max(lm_logits, dim=-1)

        if not return_dict:
            output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
            return ((loss,) + output) if loss is not None else output

        return Seq2SeqLMOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
    ):

        # cut decoder_input_ids if past is used
        if past is not None:
            input_ids = input_ids[:, -1:]

        return {
            "decoder_input_ids": input_ids,
            "past_key_values": past,
            "encoder_outputs": encoder_outputs,
            "attention_mask": attention_mask,
            "use_cache": use_cache,
        }

    def _reorder_cache(self, past, beam_idx):
        # if decoder past is not included in output
        # speedy decoding is disabled and no need to reorder
        if past is None:
            logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
            return past

        reordered_decoder_past = ()
        for layer_past_states in past:
            # get the correct batch idx from layer past batch dim
            # batch dim of `past` is at 2nd position
            reordered_layer_past_states = ()
            for layer_past_state in layer_past_states:
                # need to set correct `past` for each of the four key / value states
                reordered_layer_past_states = reordered_layer_past_states + (
                    layer_past_state.index_select(0, beam_idx),
                )

            assert reordered_layer_past_states[0].shape == layer_past_states[0].shape
            assert len(reordered_layer_past_states) == len(layer_past_states)

            reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
        return reordered_decoder_past

def set_seed(seed: int = 42) -> None:
    np.random.seed(seed)
    random.seed(seed)
    # torch
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # When running on the CuDNN backend, two further options must be set
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Set a fixed value for the hash seed
    os.environ["PYTHONHASHSEED"] = str(seed)
    print(f"Random seed set as {seed}")


class T5FineTuner(pl.LightningModule):
    """
    Fine tune a pre-trained T5 model
    """

    def __init__(self, config, tfm_model, tokenizer, args):
        super().__init__()
        self.save_hyperparameters(ignore=['tfm_model'])
        self.config = config
        self.model = tfm_model
        self.tokenizer = tokenizer
        self.args = args
        
        self.precompute_tokens()

    def forward(self,
                input_ids,
                attention_mask=None,
                decoder_input_ids=None,
                decoder_attention_mask=None,
                labels=None):
        return self.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            labels=labels,
        )

    def _step(self, batch):
        lm_labels = batch["target_ids"]
        lm_labels[lm_labels[:, :] == self.tokenizer.pad_token_id] = -100

        outputs = self(input_ids=batch["source_ids"],
                       attention_mask=batch["source_mask"],
                       labels=lm_labels,
                       decoder_attention_mask=batch['target_mask'])

        loss = outputs[0]
        return loss

    def training_step(self, batch, batch_idx):
        loss = self._step(batch)
        self.log("train_loss", loss)
        return loss

    def evaluate(self, batch, stage=None):
        # get f1
        outs = self.model.generate(input_ids=batch['source_ids'],
                                   attention_mask=batch['source_mask'],
                                   max_length=self.config.max_seq_length,
                                   return_dict_in_generate=True,
                                   output_scores=True,
                                   num_beams=1)

        dec = [
            self.tokenizer.decode(ids, skip_special_tokens=True)
            for ids in outs.sequences
        ]
        target = [
            self.tokenizer.decode(ids, skip_special_tokens=True)
            for ids in batch["target_ids"]
        ]
        scores, _, _ = compute_scores(dec, target, verbose=False, task=self.args.task)
        # f1 = torch.tensor(scores['f1'], dtype=torch.float64)
        f1 = torch.tensor(scores[4]['Micro-AVG']['f1'], dtype=torch.float64)
        
        # get loss
        loss = self._step(batch)

        if stage:
            self.log(f"{stage}_loss",
                     loss,
                     prog_bar=True,
                     on_step=False,
                     on_epoch=True)
            self.log(f"{stage}_f1",
                     f1,
                     prog_bar=True,
                     on_step=False,
                     on_epoch=True)

    def validation_step(self, batch, batch_idx):
        self.evaluate(batch, "val")

    def test_step(self, batch, batch_idx):
        self.evaluate(batch, "test")

    def configure_optimizers(self):
        """ Prepare optimizer and schedule (linear warmup and decay) """
        model = self.model
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [
                    p for n, p in model.named_parameters()
                    if not any(nd in n for nd in no_decay)
                ],
                "weight_decay":
                self.config.weight_decay,
            },
            {
                "params": [
                    p for n, p in model.named_parameters()
                    if any(nd in n for nd in no_decay)
                ],
                "weight_decay":
                0.0,
            },
        ]
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=self.config.learning_rate,
                          eps=self.config.adam_epsilon)
        scheduler = {
            "scheduler":
            get_linear_schedule_with_warmup(optimizer,
                                            **self.config.lr_scheduler_init),
            "interval":
            "step",
        }
        return [optimizer], [scheduler]

    def train_dataloader(self):
        print("load training data.")
        train_dataset = ABSADataset(tokenizer=self.tokenizer,
                                    task_name=args.task,
                                    data_setting=args.data_setting,
                                    data_name=args.dataset,
                                    language=args.lang,
                                    data_type="train",
                                    top_k=self.config.top_k,
                                    args=self.config,
                                    max_len=self.config.max_seq_length)

        dataloader = DataLoader(
            train_dataset,
            batch_size=self.config.train_batch_size,
            drop_last=True
            if args.data_ratio > 0.3 else False, # don't drop on few-shot
            shuffle=True,
            num_workers=4)

        return dataloader

    def val_dataloader(self):
        val_dataset = ABSADataset(tokenizer=self.tokenizer,
                                  task_name=args.task,
                                  data_setting=args.data_setting,
                                    data_name=args.dataset,
                                    language=args.lang,
                                    data_type="val",
                                  top_k=self.config.num_path,
                                  args=self.config,
                                  max_len=self.config.max_seq_length)
        return DataLoader(val_dataset,
                          batch_size=self.config.eval_batch_size,
                          num_workers=4)


    @staticmethod
    def rindex(_list, _value):
        return len(_list) - _list[::-1].index(_value) - 1

    def precompute_tokens(self):
        dic = {"cate_tokens":{}, "all_tokens":{}, "sentiment_tokens":{}, 'special_tokens':[]}
        for task in force_words.keys():
            dic["all_tokens"][task] = {}
            for dataset in force_words[task].keys():
                cur_list = force_words[task][dataset]
                tokenize_res = []
                for w in cur_list:
                    tokenize_res.extend(self.tokenizer(w, return_tensors='pt')['input_ids'].tolist()[0])
                dic["all_tokens"][task][dataset] = tokenize_res
        for k,v in cate_list.items():
            tokenize_res = []
            for w in v:
                tokenize_res.extend(self.tokenizer(w, return_tensors='pt')['input_ids'].tolist()[0]) 
            dic["cate_tokens"][str(k)] = tokenize_res
        sp_tokenize_res = []
        for sp in ['great', 'ok', 'bad', 'gut', 'ok', 'schlecht']:
            sp_tokenize_res.extend(self.tokenizer(sp, return_tensors='pt')['input_ids'].tolist()[0])
        for task in force_words.keys():
            dic['sentiment_tokens'][str(task)] = sp_tokenize_res
        #dic['sentiment_tokens'] = sp_tokenize_res
        special_tokens_tokenize_res = []
        for w in ['[O','[A','[S','[C','[SS']:
            special_tokens_tokenize_res.extend(self.tokenizer(w, return_tensors='pt')['input_ids'].tolist()[0]) 
        if self.args.model_name_or_path == 'google/mt5-small':
            special_tokens_tokenize_res = [r for r in special_tokens_tokenize_res if r != 491]
        else:
            special_tokens_tokenize_res = [r for r in special_tokens_tokenize_res if r != 784]
        dic['special_tokens'] = special_tokens_tokenize_res

        self.force_tokens = dic 
    
    def prefix_allowed_tokens_fn(self, task, data_name, source_ids, batch_id,
                                 input_ids):
        """
        Constrained Decoding
        # ids = self.tokenizer("text", return_tensors='pt')['input_ids'].tolist()[0]
        """
            
        force_tokens = self.force_tokens
        
        # google/mt5-small
        to_id = {
            'OT': [646],
            'AT': [357],
            'SP': [399],
            'AC': [424],
            'SS': [14826],
            'EP': [22818],
            '[': [491],
            ']': [259, 439],
            'it': [609],
            # 'es': [3, 15, 7],
            'null': [259, 1181]
        }

        left_brace_index = (input_ids == to_id['['][0]).nonzero()
        right_brace_index = (input_ids == to_id[']'][0]).nonzero()
        num_left_brace = len(left_brace_index)
        num_right_brace = len(right_brace_index)
        last_right_brace_pos = right_brace_index[-1][
            0] if right_brace_index.nelement() > 0 else -1
        last_left_brace_pos = left_brace_index[-1][
            0] if left_brace_index.nelement() > 0 else -1
        cur_id = input_ids[-1]

        if cur_id in to_id['[']:
            return force_tokens['special_tokens']
        elif cur_id in to_id['AT'] + to_id['OT'] + to_id['EP'] + to_id['SP'] + to_id['AC']:  
            return to_id[']']  
        elif cur_id in to_id['SS']:  
            return to_id['EP'] 
       
        # get cur_term
        if last_left_brace_pos == -1:
            return to_id['['] + [1]   # start of sentence: [
        elif (last_left_brace_pos != -1 and last_right_brace_pos == -1) \
            or last_left_brace_pos > last_right_brace_pos:
            return to_id[']']  # ]
        else:
            cur_term = input_ids[last_left_brace_pos + 1]

        ret = []
        if cur_term in to_id['SP']:  # SP
            ret = force_tokens['sentiment_tokens'][str(task)]
        elif cur_term in to_id['AT']:  # AT
            force_list = source_ids[batch_id].tolist()
            if task != 'aste': 
                if True:
                # if data_name == 'rest-16':
                    force_list.extend(to_id['it'] + [1])  
                # elif data_name == 'GERestaurant':
                #     force_list.extend(to_id['es'] + [1])  
            ret = force_list  
        elif cur_term in to_id['SS']:
            ret = [3] + to_id[']'] + [1]
        elif cur_term in to_id['AC']:  # AC
            ret = force_tokens['cate_tokens'][str(data_name)]
        elif cur_term in to_id['OT']:  # OT
            force_list = source_ids[batch_id].tolist()
            if task == "acos":
                force_list.extend(to_id['null'])  # null
            ret = force_list
        else:
            raise ValueError(cur_term)    

        if num_left_brace == num_right_brace:
            ret = set(ret)
            ret.discard(to_id[']'][0]) # remove ]
            for w in force_tokens['special_tokens']:
                ret.discard(w)
            ret = list(ret)
        elif num_left_brace > num_right_brace:
            ret += to_id[']'] 
        else:
            raise ValueError
        ret.extend(to_id['['] + [1]) # add [
        return ret


def evaluate(model, task, lang, data_setting, dataset, args, data_type):
    """
    Compute scores given the predictions and gold labels
    """

    outputs, targets, probs = [], [], []
    num_path = args.num_path
    
    if task in ['aste', 'tasd']:
        num_path = min(5, num_path)
        
    dataset = ABSADataset(tokenizer=model.tokenizer,
                                  task_name=task,
                                  data_setting=data_setting,
                                    data_name=dataset,
                                    language=lang,
                                    data_type=data_type,
                                  top_k=num_path,
                                  args=args,
                                  max_len=args.max_seq_length)
    
    data_loader = DataLoader(dataset,
                             batch_size=args.eval_batch_size,
                             num_workers=2)
    
    device = torch.device('cuda:0')
    model.model.to(device)
    model.model.eval()

    for batch in tqdm(data_loader):
        # beam search

        outs = model.model.generate(
            input_ids=batch['source_ids'].to(device),
            attention_mask=batch['source_mask'].to(device),
            max_length=args.max_seq_length,
            num_beams=args.beam_size,
            early_stopping=True,
            return_dict_in_generate=True,
            output_scores=True,
            prefix_allowed_tokens_fn=partial(
               model.prefix_allowed_tokens_fn, task, args.dataset,
               batch['source_ids']) if args.constrained_decode else None,
        )
        dec = [
            model.tokenizer.decode(ids, skip_special_tokens=True)
            for ids in outs.sequences
        ]
        target = [
            model.tokenizer.decode(ids, skip_special_tokens=True)
            for ids in batch["target_ids"]
        ]
        outputs.extend(dec)
        targets.extend(target)

    if args.multi_path:
        targets = targets[::num_path]

        # get outputs
        _outputs = outputs # backup
        outputs = [] # new outputs
        if args.agg_strategy == 'post_rank':
            inputs = [ele for ele in sents for _ in range(num_path)]
            assert len(_outputs) == len(inputs), (len(_outputs), len(inputs))
            preds = [[o] for o in _outputs] 
            model_path = os.path.join(args.output_dir, "final")
            scores = cal_entropy(inputs, preds, model_path, model.tokenizer)

        for i in range(0, len(targets)):
            o_idx = i * num_path
            multi_outputs = _outputs[o_idx:o_idx + num_path]

            if args.agg_strategy == 'post_rank':
                multi_probs = scores[o_idx:o_idx + args.num_path]
                assert len(multi_outputs) == len(multi_probs)

                sorted_outputs = [i for _,i in sorted(zip(multi_probs,multi_outputs))]
                outputs.append(sorted_outputs[0])
                continue
            elif args.agg_strategy == "pre_rank":
                outputs.append(multi_outputs[0])
                continue
            elif args.agg_strategy == 'rand':
                outputs.append(random.choice(multi_outputs))
                continue
            elif args.agg_strategy == 'vote':
                all_quads = []
                for s in multi_outputs:
                    all_quads.extend(
                        extract_spans_para(seq=s, seq_type='pred'))

                output_quads = []
                counter = dict(Counter(all_quads))
                for quad, count in counter.items():
                    # keep freq >= num_path / 2
                    if count >= len(multi_outputs) / 2:
                        output_quads.append(quad)

                # recover output
                output = []
                for q in output_quads:
                    ac, at, sp, ot = q
                    if task == "aste":
                        if 'null' not in [at, ot, sp]:  # aste has no 'null', for zero-shot only
                            output.append(f'[A] {at} [O] {ot} [S] {sp}')

                    elif task  == "tasd":
                        output.append(f"[A] {at} [S] {sp} [C] {ac}")

                    elif task in ["asqp", "acos"]:
                        output.append(f"[A] {at} [O] {ot} [S] {sp} [C] {ac}")

                    else:
                        raise NotImplementedError

                target_quads = extract_spans_para(seq=targets[i],
                                                seq_type='gold')

                # if no output, use the first path
                output_str = " [SSEP] ".join(
                    output) if output else multi_outputs[0]

                outputs.append(output_str)

    # stats
    labels_counts = Counter([len(l.split('[SSEP]')) for l in outputs])

    print('After Prediction')
    print('Preds')
    print(outputs[:5])
    print('Golds')
    print(targets[:5])
    
    scores, all_labels, preds = compute_scores(outputs,
                                                   targets,
                                                   verbose=True, task=args.task)
    return scores, preds


import signal
import sys

def signal_handler(signal, frame):
    print("Training interrupted by user!")
    sys.exit(0)

signal.signal(signal.SIGINT, signal_handler)


set_seed(5)

Random seed set as 5


In [144]:
class Namespace:
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

args = Namespace()

args.model_name_or_path = 'google/mt5-small'
args.lang = 'en'
args.data_path = '../../../data/restaurant/'
args.lang_setting = 'orig'
args.eval_type = 'test'
args.data_setting = 'orig'
args.output_dir = 'results/'
args.task = 'tasd'
args.num_train_epochs = 30
args.save_top_k = 0
args.top_k = 5
args.ctrl_token = 'post' # Ändern?
args.multi_path = True
args.num_path = 5
args.lowercase = True
args.train_batch_size = 8
args.gradient_accumulation_steps = 2
args.learning_rate = 1e-4
args.sort_label = True
args.agg_strategy = 'vote'
args.data_ratio = 1.0
args.eval_batch_size = 16
args.constrained_decode = True
args.do_train = True
args.max_seq_length = 200
args.multi_task = False
args.single_view_type = 'rank'
args.check_val_every_n_epoch = 20
args.load_path_cache = False
args.beam_size = 1
args.warmup_steps = 0
args.adam_epsilon = 1e-8
args.weight_decay = 0
args.n_gpu = 1
args.load_ckpt_name = False

In [145]:
import torch
import sys
import os
from subprocess import call
print('_____Python, Pytorch, Cuda info____')
print('__Python VERSION:', sys.version)
print('__pyTorch VERSION:', torch.__version__)
print('__CUDA RUNTIME API VERSION')
#os.system('nvcc --version')
print('__CUDNN VERSION:', torch.backends.cudnn.version())
print('_____nvidia-smi GPU details____')
call(["nvidia-smi", "--format=csv", "--query-gpu=index,name,driver_version,memory.total,memory.used,memory.free"])
print('_____Device assignments____')
print('Number CUDA Devices:', torch.cuda.device_count())
print ('Current cuda device: ', torch.cuda.current_device(), ' **May not correspond to nvidia-smi ID above, check visibility parameter')
print("Device name: ", torch.cuda.get_device_name(torch.cuda.current_device()))

_____Python, Pytorch, Cuda info____
__Python VERSION: 3.10.12 (main, Nov  6 2024, 20:22:13) [GCC 11.4.0]
__pyTorch VERSION: 2.1.2+cu121
__CUDA RUNTIME API VERSION
__CUDNN VERSION: 8902
_____nvidia-smi GPU details____
index, name, driver_version, memory.total [MiB], memory.used [MiB], memory.free [MiB]
0, NVIDIA RTX A5000, 525.147.05, 24564 MiB, 3 MiB, 24243 MiB
1, NVIDIA RTX A5000, 525.147.05, 24564 MiB, 10901 MiB, 13337 MiB
_____Device assignments____
Number CUDA Devices: 1
Current cuda device:  0  **May not correspond to nvidia-smi ID above, check visibility parameter
Device name:  NVIDIA RTX A5000


In [None]:
tokenizer = T5Tokenizer.from_pretrained(args.model_name_or_path)
print("\n****** Conduct Training ******")

# for token in ['OT','AT',
#     'SP',
#     'AC',
#     'SS',
#     'EP',
#     '[',
#     ']',
#     'it',
#     'null']:
#     print(token, tokenizer(token, return_tensors='pt')['input_ids'].tolist())
args.dataset = f'rest-{args.lang}'

# initialize the T5 model
tfm_model = MyT5ForConditionalGeneration.from_pretrained(args.model_name_or_path)
model = T5FineTuner(args, tfm_model, tokenizer, args)

# load data
train_loader = model.train_dataloader()

# config optimizer
t_total = ((len(train_loader.dataset) //
            (args.train_batch_size * max(1, args.n_gpu))) //
           args.gradient_accumulation_steps *
           float(args.num_train_epochs))

args.lr_scheduler_init = {
    "num_warmup_steps": args.warmup_steps,
    "num_training_steps": t_total
}

checkpoint_callback = pl.callbacks.ModelCheckpoint(
    dirpath=args.output_dir,
    filename='{epoch}-{val_f1:.2f}-{val_loss:.2f}',
    monitor='val_f1',
    mode='max',
    save_top_k=args.save_top_k,
    save_last=False)

early_stop_callback = EarlyStopping(monitor="val_f1",
                                    min_delta=0.00,
                                    patience=20,
                                    verbose=True,
                                    mode="max")
lr_monitor = LearningRateMonitor(logging_interval='step')

# prepare for trainer
train_params = dict(
    accelerator="gpu",
    devices=1,
    default_root_dir=args.output_dir,
    accumulate_grad_batches=args.gradient_accumulation_steps,
    gradient_clip_val=1.0,
    max_epochs=args.num_train_epochs,
    check_val_every_n_epoch=1,
    callbacks=[
        checkpoint_callback, early_stop_callback,
        TQDMProgressBar(refresh_rate=10), lr_monitor
    ],
)

trainer = pl.Trainer(**train_params)

try:
    trainer.fit(model)
except KeyboardInterrupt:
    print("Training has been stopped manually.")
    
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model.to(device)

# save the final model
#model.model.save_pretrained(os.path.join(args.output_dir, "final"))
#tokenizer.save_pretrained(os.path.join(args.output_dir, "final"))
print("Finish training and saving the model!")


****** Conduct Training ******


You are using a model of type mt5 to instantiate a model of type t5. This is not supported for all configurations of models and can yield errors.


load training data.
Loading Train Dataset
Total examples = 1708


You are using a model of type mt5 to instantiate a model of type t5. This is not supported for all configurations of models and can yield errors.
  return self._call_impl(*args, **kwargs)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA RTX A5000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision


1708 8540 8540


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [1]

  | Name  | Type                         | Params
-------------------------------------------------------
0 | model | MyT5ForConditionalGeneration | 300 M 
-------------------------------------------------------
300 M     Trainable params
0         Non-trainable params
300 M     Total params
1,200.707 Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Loading Validation Dataset
Total examples = 587


You are using a model of type mt5 to instantiate a model of type t5. This is not supported for all configurations of models and can yield errors.


587 587 587
MVP F1-Micro:  0
MVP F1-Micro:  0
load training data.
Loading Train Dataset
Total examples = 1708


You are using a model of type mt5 to instantiate a model of type t5. This is not supported for all configurations of models and can yield errors.


1708 8540 8540


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

MVP F1-Micro:  0
MVP F1-Micro:  0
MVP F1-Micro:  0
MVP F1-Micro:  0
MVP F1-Micro:  0
MVP F1-Micro:  0
MVP F1-Micro:  0
MVP F1-Micro:  0
MVP F1-Micro:  0
MVP F1-Micro:  0
MVP F1-Micro:  0
MVP F1-Micro:  0
MVP F1-Micro:  0
MVP F1-Micro:  0
MVP F1-Micro:  0
MVP F1-Micro:  0
MVP F1-Micro:  0
MVP F1-Micro:  0
MVP F1-Micro:  0
MVP F1-Micro:  0
MVP F1-Micro:  0
MVP F1-Micro:  0
MVP F1-Micro:  0
MVP F1-Micro:  0
MVP F1-Micro:  0
MVP F1-Micro:  0
MVP F1-Micro:  0
MVP F1-Micro:  0
MVP F1-Micro:  0
MVP F1-Micro:  0
MVP F1-Micro:  0
MVP F1-Micro:  0
MVP F1-Micro:  0
MVP F1-Micro:  0
MVP F1-Micro:  0
MVP F1-Micro:  0


Metric val_f1 improved. New best score: 0.001


MVP F1-Micro:  0


Validation: 0it [00:00, ?it/s]

MVP F1-Micro:  5.555555555555555
MVP F1-Micro:  5.714285714285714
MVP F1-Micro:  3.92156862745098
MVP F1-Micro:  0
MVP F1-Micro:  4.705882352941177
MVP F1-Micro:  6.896551724137931
MVP F1-Micro:  10.714285714285714
MVP F1-Micro:  0
MVP F1-Micro:  4.545454545454546
MVP F1-Micro:  13.333333333333334
MVP F1-Micro:  5.194805194805195
MVP F1-Micro:  5.2631578947368425
MVP F1-Micro:  0
MVP F1-Micro:  5.555555555555556
MVP F1-Micro:  19.04761904761905
MVP F1-Micro:  15.789473684210526
MVP F1-Micro:  3.3898305084745757
MVP F1-Micro:  4.651162790697675
MVP F1-Micro:  4.347826086956522
MVP F1-Micro:  0
MVP F1-Micro:  0
MVP F1-Micro:  0
MVP F1-Micro:  3.3333333333333335
MVP F1-Micro:  20.0
MVP F1-Micro:  0
MVP F1-Micro:  0
MVP F1-Micro:  0
MVP F1-Micro:  0
MVP F1-Micro:  3.3333333333333335
MVP F1-Micro:  1.9230769230769231
MVP F1-Micro:  3.4782608695652173
MVP F1-Micro:  1.8691588785046727
MVP F1-Micro:  13.114754098360654
MVP F1-Micro:  2.4691358024691357
MVP F1-Micro:  6.349206349206349
MVP F1-

Metric val_f1 improved by 0.047 >= min_delta = 0.0. New best score: 0.048


MVP F1-Micro:  0


Validation: 0it [00:00, ?it/s]

MVP F1-Micro:  22.22222222222222
MVP F1-Micro:  10.526315789473683
MVP F1-Micro:  20.634920634920633
MVP F1-Micro:  4.705882352941176
MVP F1-Micro:  29.78723404255319
MVP F1-Micro:  18.6046511627907
MVP F1-Micro:  21.276595744680854
MVP F1-Micro:  6.521739130434782
MVP F1-Micro:  16.3265306122449
MVP F1-Micro:  21.21212121212121
MVP F1-Micro:  9.876543209876543
MVP F1-Micro:  9.836065573770494
MVP F1-Micro:  8.333333333333334
MVP F1-Micro:  31.034482758620697
MVP F1-Micro:  22.78481012658228
MVP F1-Micro:  13.043478260869565
MVP F1-Micro:  24.390243902439025
MVP F1-Micro:  10.526315789473685
MVP F1-Micro:  25.641025641025646
MVP F1-Micro:  9.6
MVP F1-Micro:  0
MVP F1-Micro:  6.451612903225808
MVP F1-Micro:  8.51063829787234
MVP F1-Micro:  8.333333333333332
MVP F1-Micro:  4.166666666666666
MVP F1-Micro:  0
MVP F1-Micro:  6.153846153846154
MVP F1-Micro:  9.00900900900901
MVP F1-Micro:  16.129032258064516
MVP F1-Micro:  9.523809523809522
MVP F1-Micro:  15.384615384615385
MVP F1-Micro:  9.

Metric val_f1 improved by 0.079 >= min_delta = 0.0. New best score: 0.126


MVP F1-Micro:  0


Validation: 0it [00:00, ?it/s]

MVP F1-Micro:  35.294117647058826
MVP F1-Micro:  16.666666666666668
MVP F1-Micro:  50.70422535211269
MVP F1-Micro:  4.25531914893617
MVP F1-Micro:  34.146341463414636
MVP F1-Micro:  25.641025641025646
MVP F1-Micro:  34.14634146341463
MVP F1-Micro:  17.02127659574468
MVP F1-Micro:  13.636363636363638
MVP F1-Micro:  20.833333333333336
MVP F1-Micro:  40.909090909090914
MVP F1-Micro:  15.0
MVP F1-Micro:  10.0
MVP F1-Micro:  48.78048780487806
MVP F1-Micro:  55.319148936170215
MVP F1-Micro:  25.641025641025646
MVP F1-Micro:  27.77777777777778
MVP F1-Micro:  14.545454545454547
MVP F1-Micro:  27.58620689655172
MVP F1-Micro:  43.99999999999999
MVP F1-Micro:  15.0
MVP F1-Micro:  10.256410256410255
MVP F1-Micro:  15.384615384615383
MVP F1-Micro:  18.18181818181818
MVP F1-Micro:  5.0
MVP F1-Micro:  4.761904761904763
MVP F1-Micro:  11.538461538461538
MVP F1-Micro:  59.70149253731344
MVP F1-Micro:  23.076923076923077
MVP F1-Micro:  23.809523809523807
MVP F1-Micro:  31.57894736842105
MVP F1-Micro:  3

Metric val_f1 improved by 0.077 >= min_delta = 0.0. New best score: 0.203


MVP F1-Micro:  0


Validation: 0it [00:00, ?it/s]

MVP F1-Micro:  52.94117647058824
MVP F1-Micro:  21.62162162162162
MVP F1-Micro:  65.3061224489796
MVP F1-Micro:  30.434782608695656
MVP F1-Micro:  46.51162790697674
MVP F1-Micro:  41.02564102564102
MVP F1-Micro:  32.558139534883715
MVP F1-Micro:  25.531914893617024
MVP F1-Micro:  13.636363636363638
MVP F1-Micro:  11.76470588235294
MVP F1-Micro:  40.0
MVP F1-Micro:  14.814814814814813
MVP F1-Micro:  23.25581395348837
MVP F1-Micro:  58.536585365853654
MVP F1-Micro:  48.275862068965516
MVP F1-Micro:  39.02439024390244
MVP F1-Micro:  22.222222222222225
MVP F1-Micro:  18.18181818181818
MVP F1-Micro:  28.070175438596497
MVP F1-Micro:  34.04255319148936
MVP F1-Micro:  12.76595744680851
MVP F1-Micro:  29.268292682926827
MVP F1-Micro:  15.0
MVP F1-Micro:  48.387096774193544
MVP F1-Micro:  5.128205128205128
MVP F1-Micro:  4.545454545454545
MVP F1-Micro:  23.076923076923077
MVP F1-Micro:  48.888888888888886
MVP F1-Micro:  35.8974358974359
MVP F1-Micro:  28.571428571428566
MVP F1-Micro:  30.769230

Metric val_f1 improved by 0.032 >= min_delta = 0.0. New best score: 0.235


MVP F1-Micro:  8.51063829787234


Validation: 0it [00:00, ?it/s]

MVP F1-Micro:  34.285714285714285
MVP F1-Micro:  59.45945945945946
MVP F1-Micro:  57.14285714285714
MVP F1-Micro:  68.0
MVP F1-Micro:  59.57446808510638
MVP F1-Micro:  53.658536585365844
MVP F1-Micro:  31.818181818181817
MVP F1-Micro:  40.0
MVP F1-Micro:  32.0
MVP F1-Micro:  23.52941176470588
MVP F1-Micro:  34.04255319148936
MVP F1-Micro:  22.727272727272727
MVP F1-Micro:  44.44444444444445
MVP F1-Micro:  63.63636363636365
MVP F1-Micro:  65.3061224489796
MVP F1-Micro:  34.14634146341463
MVP F1-Micro:  55.00000000000001
MVP F1-Micro:  21.428571428571427
MVP F1-Micro:  33.898305084745765
MVP F1-Micro:  38.297872340425535
MVP F1-Micro:  36.84210526315789
MVP F1-Micro:  42.857142857142854
MVP F1-Micro:  28.57142857142857
MVP F1-Micro:  25.0
MVP F1-Micro:  19.047619047619047
MVP F1-Micro:  13.33333333333333
MVP F1-Micro:  31.818181818181817
MVP F1-Micro:  49.999999999999986
MVP F1-Micro:  36.36363636363636
MVP F1-Micro:  31.111111111111118
MVP F1-Micro:  31.57894736842105
MVP F1-Micro:  60.

Metric val_f1 improved by 0.114 >= min_delta = 0.0. New best score: 0.349


MVP F1-Micro:  9.75609756097561


Validation: 0it [00:00, ?it/s]

MVP F1-Micro:  40.0
MVP F1-Micro:  64.86486486486486
MVP F1-Micro:  65.21739130434783
MVP F1-Micro:  62.22222222222222
MVP F1-Micro:  63.63636363636365
MVP F1-Micro:  56.41025641025642
MVP F1-Micro:  32.558139534883715
MVP F1-Micro:  48.0
MVP F1-Micro:  45.83333333333333
MVP F1-Micro:  28.57142857142857
MVP F1-Micro:  50.0
MVP F1-Micro:  23.809523809523807
MVP F1-Micro:  66.66666666666666
MVP F1-Micro:  65.11627906976744
MVP F1-Micro:  51.06382978723405
MVP F1-Micro:  30.0
MVP F1-Micro:  56.41025641025641
MVP F1-Micro:  21.428571428571427
MVP F1-Micro:  39.28571428571428
MVP F1-Micro:  47.82608695652174
MVP F1-Micro:  42.10526315789474
MVP F1-Micro:  60.0
MVP F1-Micro:  39.02439024390244
MVP F1-Micro:  22.222222222222225
MVP F1-Micro:  24.390243902439025
MVP F1-Micro:  18.6046511627907
MVP F1-Micro:  34.14634146341463
MVP F1-Micro:  48.888888888888886
MVP F1-Micro:  38.095238095238095
MVP F1-Micro:  51.162790697674424
MVP F1-Micro:  36.84210526315789
MVP F1-Micro:  60.465116279069775
M

Metric val_f1 improved by 0.064 >= min_delta = 0.0. New best score: 0.413


MVP F1-Micro:  10.526315789473685


Validation: 0it [00:00, ?it/s]

MVP F1-Micro:  51.42857142857144
MVP F1-Micro:  59.45945945945946
MVP F1-Micro:  73.91304347826087
MVP F1-Micro:  66.66666666666666
MVP F1-Micro:  72.72727272727272
MVP F1-Micro:  56.41025641025642
MVP F1-Micro:  32.558139534883715
MVP F1-Micro:  48.97959183673469
MVP F1-Micro:  45.83333333333333
MVP F1-Micro:  28.57142857142857
MVP F1-Micro:  50.0
MVP F1-Micro:  41.86046511627907
MVP F1-Micro:  61.904761904761905
MVP F1-Micro:  55.81395348837209
MVP F1-Micro:  63.829787234042556
MVP F1-Micro:  50.0
MVP F1-Micro:  51.28205128205129
MVP F1-Micro:  25.53191489361702
MVP F1-Micro:  35.71428571428571
MVP F1-Micro:  52.17391304347826
MVP F1-Micro:  54.05405405405405
MVP F1-Micro:  58.536585365853654
MVP F1-Micro:  38.095238095238095
MVP F1-Micro:  22.727272727272727
MVP F1-Micro:  29.268292682926827
MVP F1-Micro:  23.25581395348837
MVP F1-Micro:  35.0
MVP F1-Micro:  45.454545454545446
MVP F1-Micro:  48.78048780487805
MVP F1-Micro:  54.54545454545454
MVP F1-Micro:  36.84210526315789
MVP F1-M

Metric val_f1 improved by 0.021 >= min_delta = 0.0. New best score: 0.434


MVP F1-Micro:  15.384615384615383


Validation: 0it [00:00, ?it/s]

MVP F1-Micro:  51.42857142857144
MVP F1-Micro:  48.64864864864865
MVP F1-Micro:  89.36170212765957
MVP F1-Micro:  74.99999999999999
MVP F1-Micro:  73.91304347826086
MVP F1-Micro:  56.41025641025642
MVP F1-Micro:  37.2093023255814
MVP F1-Micro:  48.97959183673469
MVP F1-Micro:  35.99999999999999
MVP F1-Micro:  24.000000000000004
MVP F1-Micro:  54.54545454545454
MVP F1-Micro:  38.0952380952381
MVP F1-Micro:  61.904761904761905
MVP F1-Micro:  62.22222222222222
MVP F1-Micro:  76.0
MVP F1-Micro:  58.536585365853654
MVP F1-Micro:  56.41025641025641
MVP F1-Micro:  32.6530612244898
MVP F1-Micro:  35.08771929824561
MVP F1-Micro:  60.86956521739131
MVP F1-Micro:  57.89473684210527
MVP F1-Micro:  63.41463414634146
MVP F1-Micro:  42.857142857142854
MVP F1-Micro:  31.11111111111111
MVP F1-Micro:  33.33333333333333
MVP F1-Micro:  31.818181818181824
MVP F1-Micro:  33.33333333333333
MVP F1-Micro:  53.33333333333332
MVP F1-Micro:  47.61904761904761
MVP F1-Micro:  53.33333333333332
MVP F1-Micro:  36.842

Metric val_f1 improved by 0.020 >= min_delta = 0.0. New best score: 0.454


MVP F1-Micro:  9.75609756097561


Validation: 0it [00:00, ?it/s]

MVP F1-Micro:  51.42857142857144
MVP F1-Micro:  64.86486486486486
MVP F1-Micro:  82.6086956521739
MVP F1-Micro:  66.66666666666666
MVP F1-Micro:  73.91304347826086
MVP F1-Micro:  55.000000000000014
MVP F1-Micro:  31.818181818181817
MVP F1-Micro:  48.0
MVP F1-Micro:  36.73469387755102
MVP F1-Micro:  28.000000000000004
MVP F1-Micro:  56.52173913043478
MVP F1-Micro:  37.2093023255814
MVP F1-Micro:  43.90243902439025
MVP F1-Micro:  59.09090909090908
MVP F1-Micro:  62.499999999999986
MVP F1-Micro:  43.90243902439025
MVP F1-Micro:  51.28205128205129
MVP F1-Micro:  28.57142857142857
MVP F1-Micro:  38.59649122807017
MVP F1-Micro:  57.777777777777786
MVP F1-Micro:  57.89473684210527
MVP F1-Micro:  63.41463414634146
MVP F1-Micro:  42.857142857142854
MVP F1-Micro:  22.727272727272727
MVP F1-Micro:  42.857142857142854
MVP F1-Micro:  36.36363636363636
MVP F1-Micro:  33.33333333333333
MVP F1-Micro:  44.44444444444445
MVP F1-Micro:  48.78048780487805
MVP F1-Micro:  48.888888888888886
MVP F1-Micro:  3

Validation: 0it [00:00, ?it/s]

MVP F1-Micro:  51.42857142857144
MVP F1-Micro:  64.86486486486486
MVP F1-Micro:  89.79591836734694
MVP F1-Micro:  73.46938775510205
MVP F1-Micro:  76.59574468085107
MVP F1-Micro:  55.000000000000014
MVP F1-Micro:  35.55555555555555
MVP F1-Micro:  48.0
MVP F1-Micro:  44.89795918367348
MVP F1-Micro:  26.229508196721312
MVP F1-Micro:  56.52173913043478
MVP F1-Micro:  45.45454545454545
MVP F1-Micro:  55.81395348837208
MVP F1-Micro:  62.22222222222222
MVP F1-Micro:  73.46938775510203
MVP F1-Micro:  43.90243902439025
MVP F1-Micro:  51.28205128205129
MVP F1-Micro:  35.08771929824562
MVP F1-Micro:  38.59649122807017
MVP F1-Micro:  65.21739130434783
MVP F1-Micro:  63.1578947368421
MVP F1-Micro:  66.66666666666666
MVP F1-Micro:  47.61904761904761
MVP F1-Micro:  21.73913043478261
MVP F1-Micro:  38.095238095238095
MVP F1-Micro:  22.727272727272723
MVP F1-Micro:  37.2093023255814
MVP F1-Micro:  46.80851063829787
MVP F1-Micro:  52.38095238095239
MVP F1-Micro:  47.82608695652174
MVP F1-Micro:  46.153

Metric val_f1 improved by 0.009 >= min_delta = 0.0. New best score: 0.464


MVP F1-Micro:  15.0


Validation: 0it [00:00, ?it/s]

MVP F1-Micro:  62.85714285714287
MVP F1-Micro:  64.86486486486486
MVP F1-Micro:  89.79591836734694
MVP F1-Micro:  70.83333333333334
MVP F1-Micro:  70.83333333333334
MVP F1-Micro:  58.536585365853654
MVP F1-Micro:  40.0
MVP F1-Micro:  39.21568627450981
MVP F1-Micro:  43.99999999999999
MVP F1-Micro:  29.629629629629633
MVP F1-Micro:  57.777777777777786
MVP F1-Micro:  45.45454545454545
MVP F1-Micro:  48.888888888888886
MVP F1-Micro:  66.66666666666666
MVP F1-Micro:  69.3877551020408
MVP F1-Micro:  53.65853658536586
MVP F1-Micro:  61.53846153846154
MVP F1-Micro:  42.30769230769231
MVP F1-Micro:  38.59649122807017
MVP F1-Micro:  59.57446808510638
MVP F1-Micro:  68.42105263157895
MVP F1-Micro:  74.4186046511628
MVP F1-Micro:  47.61904761904761
MVP F1-Micro:  29.166666666666668
MVP F1-Micro:  37.2093023255814
MVP F1-Micro:  34.78260869565217
MVP F1-Micro:  32.558139534883715
MVP F1-Micro:  51.06382978723404
MVP F1-Micro:  52.38095238095239
MVP F1-Micro:  51.06382978723404
MVP F1-Micro:  42.10

Metric val_f1 improved by 0.015 >= min_delta = 0.0. New best score: 0.478


MVP F1-Micro:  16.666666666666664


Validation: 0it [00:00, ?it/s]

MVP F1-Micro:  51.42857142857144
MVP F1-Micro:  64.86486486486486
MVP F1-Micro:  83.33333333333334
MVP F1-Micro:  70.83333333333334
MVP F1-Micro:  79.16666666666667
MVP F1-Micro:  53.658536585365844
MVP F1-Micro:  31.818181818181817
MVP F1-Micro:  47.05882352941176
MVP F1-Micro:  40.816326530612244
MVP F1-Micro:  33.9622641509434
MVP F1-Micro:  62.222222222222236
MVP F1-Micro:  45.45454545454545
MVP F1-Micro:  51.162790697674424
MVP F1-Micro:  71.11111111111111
MVP F1-Micro:  69.3877551020408
MVP F1-Micro:  48.78048780487805
MVP F1-Micro:  61.53846153846154
MVP F1-Micro:  43.99999999999999
MVP F1-Micro:  39.28571428571428
MVP F1-Micro:  68.08510638297872
MVP F1-Micro:  61.53846153846153
MVP F1-Micro:  61.904761904761905
MVP F1-Micro:  42.857142857142854
MVP F1-Micro:  28.57142857142857
MVP F1-Micro:  36.36363636363636
MVP F1-Micro:  38.297872340425535
MVP F1-Micro:  32.558139534883715
MVP F1-Micro:  55.319148936170215
MVP F1-Micro:  47.61904761904761
MVP F1-Micro:  52.17391304347826
MV

Metric val_f1 improved by 0.003 >= min_delta = 0.0. New best score: 0.482


MVP F1-Micro:  19.51219512195122


Validation: 0it [00:00, ?it/s]

MVP F1-Micro:  51.42857142857144
MVP F1-Micro:  64.86486486486486
MVP F1-Micro:  85.71428571428572
MVP F1-Micro:  65.3061224489796
MVP F1-Micro:  79.16666666666667
MVP F1-Micro:  63.63636363636365
MVP F1-Micro:  40.0
MVP F1-Micro:  57.6923076923077
MVP F1-Micro:  42.30769230769231
MVP F1-Micro:  33.9622641509434
MVP F1-Micro:  59.09090909090909
MVP F1-Micro:  45.45454545454545
MVP F1-Micro:  52.38095238095239
MVP F1-Micro:  66.66666666666666
MVP F1-Micro:  65.3061224489796
MVP F1-Micro:  48.78048780487805
MVP F1-Micro:  61.53846153846154
MVP F1-Micro:  43.13725490196078
MVP F1-Micro:  42.10526315789474
MVP F1-Micro:  70.83333333333334
MVP F1-Micro:  70.27027027027027
MVP F1-Micro:  69.76744186046511
MVP F1-Micro:  42.857142857142854
MVP F1-Micro:  24.489795918367346
MVP F1-Micro:  41.86046511627907
MVP F1-Micro:  43.47826086956522
MVP F1-Micro:  36.36363636363637
MVP F1-Micro:  55.319148936170215
MVP F1-Micro:  37.2093023255814
MVP F1-Micro:  48.888888888888886
MVP F1-Micro:  42.105263

Metric val_f1 improved by 0.018 >= min_delta = 0.0. New best score: 0.499


MVP F1-Micro:  20.0


Validation: 0it [00:00, ?it/s]

MVP F1-Micro:  51.42857142857144
MVP F1-Micro:  75.67567567567566
MVP F1-Micro:  89.79591836734694
MVP F1-Micro:  65.3061224489796
MVP F1-Micro:  80.85106382978724
MVP F1-Micro:  63.63636363636365
MVP F1-Micro:  36.36363636363637
MVP F1-Micro:  53.84615384615385
MVP F1-Micro:  47.05882352941176
MVP F1-Micro:  38.46153846153847
MVP F1-Micro:  55.81395348837209
MVP F1-Micro:  50.0
MVP F1-Micro:  51.162790697674424
MVP F1-Micro:  66.66666666666666
MVP F1-Micro:  69.3877551020408
MVP F1-Micro:  48.78048780487805
MVP F1-Micro:  61.53846153846154
MVP F1-Micro:  43.99999999999999
MVP F1-Micro:  42.10526315789474
MVP F1-Micro:  66.66666666666666
MVP F1-Micro:  70.27027027027027
MVP F1-Micro:  65.11627906976743
MVP F1-Micro:  42.857142857142854
MVP F1-Micro:  28.57142857142857
MVP F1-Micro:  46.51162790697674
MVP F1-Micro:  40.909090909090914
MVP F1-Micro:  32.558139534883715
MVP F1-Micro:  55.319148936170215
MVP F1-Micro:  52.38095238095239
MVP F1-Micro:  47.82608695652174
MVP F1-Micro:  36.84

Metric val_f1 improved by 0.010 >= min_delta = 0.0. New best score: 0.509


MVP F1-Micro:  20.0


Validation: 0it [00:00, ?it/s]

MVP F1-Micro:  57.14285714285715
MVP F1-Micro:  75.67567567567566
MVP F1-Micro:  93.87755102040816
MVP F1-Micro:  65.3061224489796
MVP F1-Micro:  76.59574468085107
MVP F1-Micro:  63.63636363636365
MVP F1-Micro:  40.0
MVP F1-Micro:  56.60377358490566
MVP F1-Micro:  43.99999999999999
MVP F1-Micro:  45.28301886792453
MVP F1-Micro:  65.11627906976743
MVP F1-Micro:  50.0
MVP F1-Micro:  53.33333333333332
MVP F1-Micro:  66.66666666666666
MVP F1-Micro:  71.99999999999999
MVP F1-Micro:  43.90243902439025
MVP F1-Micro:  61.53846153846154
MVP F1-Micro:  43.99999999999999
MVP F1-Micro:  42.10526315789474
MVP F1-Micro:  66.66666666666666
MVP F1-Micro:  73.68421052631578
MVP F1-Micro:  65.11627906976743
MVP F1-Micro:  47.61904761904761
MVP F1-Micro:  24.489795918367346
MVP F1-Micro:  42.857142857142854
MVP F1-Micro:  40.909090909090914
MVP F1-Micro:  31.818181818181817
MVP F1-Micro:  51.06382978723404
MVP F1-Micro:  51.162790697674424
MVP F1-Micro:  52.17391304347826
MVP F1-Micro:  41.02564102564102

Metric val_f1 improved by 0.004 >= min_delta = 0.0. New best score: 0.513


MVP F1-Micro:  23.809523809523807


Validation: 0it [00:00, ?it/s]

MVP F1-Micro:  51.42857142857144
MVP F1-Micro:  75.67567567567566
MVP F1-Micro:  83.33333333333334
MVP F1-Micro:  66.66666666666666
MVP F1-Micro:  73.91304347826086
MVP F1-Micro:  63.41463414634146
MVP F1-Micro:  53.33333333333332
MVP F1-Micro:  50.98039215686274
MVP F1-Micro:  47.05882352941176
MVP F1-Micro:  41.50943396226415
MVP F1-Micro:  60.46511627906976
MVP F1-Micro:  45.45454545454545
MVP F1-Micro:  61.904761904761905
MVP F1-Micro:  71.11111111111111
MVP F1-Micro:  68.08510638297872
MVP F1-Micro:  43.90243902439025
MVP F1-Micro:  61.53846153846154
MVP F1-Micro:  44.89795918367347
MVP F1-Micro:  42.10526315789474
MVP F1-Micro:  68.08510638297872
MVP F1-Micro:  78.94736842105262
MVP F1-Micro:  69.76744186046511
MVP F1-Micro:  47.61904761904761
MVP F1-Micro:  24.489795918367346
MVP F1-Micro:  42.857142857142854
MVP F1-Micro:  50.0
MVP F1-Micro:  37.2093023255814
MVP F1-Micro:  51.06382978723404
MVP F1-Micro:  53.658536585365844
MVP F1-Micro:  47.82608695652174
MVP F1-Micro:  47.36

Metric val_f1 improved by 0.017 >= min_delta = 0.0. New best score: 0.530


MVP F1-Micro:  20.0


Validation: 0it [00:00, ?it/s]

MVP F1-Micro:  51.42857142857144
MVP F1-Micro:  75.67567567567566
MVP F1-Micro:  85.71428571428572
MVP F1-Micro:  62.500000000000014
MVP F1-Micro:  69.56521739130434
MVP F1-Micro:  53.658536585365844
MVP F1-Micro:  45.45454545454545
MVP F1-Micro:  48.0
MVP F1-Micro:  46.15384615384615
MVP F1-Micro:  46.15384615384615
MVP F1-Micro:  60.46511627906976
MVP F1-Micro:  46.51162790697674
MVP F1-Micro:  57.14285714285713
MVP F1-Micro:  71.11111111111111
MVP F1-Micro:  66.66666666666667
MVP F1-Micro:  45.0
MVP F1-Micro:  66.66666666666667
MVP F1-Micro:  50.0
MVP F1-Micro:  42.10526315789474
MVP F1-Micro:  72.3404255319149
MVP F1-Micro:  76.92307692307692
MVP F1-Micro:  58.536585365853654
MVP F1-Micro:  52.38095238095239
MVP F1-Micro:  36.734693877551024
MVP F1-Micro:  48.78048780487805
MVP F1-Micro:  50.0
MVP F1-Micro:  36.36363636363637
MVP F1-Micro:  51.06382978723404
MVP F1-Micro:  48.78048780487805
MVP F1-Micro:  56.52173913043478
MVP F1-Micro:  36.84210526315789
MVP F1-Micro:  65.11627906

Validation: 0it [00:00, ?it/s]

MVP F1-Micro:  57.14285714285715
MVP F1-Micro:  75.67567567567566
MVP F1-Micro:  85.71428571428572
MVP F1-Micro:  64.0
MVP F1-Micro:  66.66666666666666
MVP F1-Micro:  63.63636363636365
MVP F1-Micro:  48.888888888888886
MVP F1-Micro:  51.85185185185185
MVP F1-Micro:  47.27272727272727
MVP F1-Micro:  46.42857142857143
MVP F1-Micro:  55.81395348837209
MVP F1-Micro:  47.82608695652174
MVP F1-Micro:  55.81395348837208
MVP F1-Micro:  71.11111111111111
MVP F1-Micro:  74.50980392156863
MVP F1-Micro:  47.61904761904761
MVP F1-Micro:  66.66666666666667
MVP F1-Micro:  50.98039215686274
MVP F1-Micro:  44.827586206896555
MVP F1-Micro:  66.66666666666666
MVP F1-Micro:  76.92307692307692
MVP F1-Micro:  74.4186046511628
MVP F1-Micro:  52.38095238095239
MVP F1-Micro:  23.52941176470588
MVP F1-Micro:  48.888888888888886
MVP F1-Micro:  47.82608695652174
MVP F1-Micro:  37.2093023255814
MVP F1-Micro:  51.06382978723404
MVP F1-Micro:  41.86046511627907
MVP F1-Micro:  60.86956521739131
MVP F1-Micro:  41.0256

In [157]:
"""
Compute scores given the predictions and gold labels
"""

outputs, targets, probs = [], [], []
num_path = args.num_path

if args.task in ['aste', 'tasd']:
    num_path = min(5, num_path)
    
dataset = ABSADataset(tokenizer=model.tokenizer,
                              task_name=args.task,
                              data_setting=args.data_setting,
                                data_name=args.dataset,
                                language=args.lang,
                                data_type='test',
                              top_k=args.num_path,
                              args=args,
                              max_len=args.max_seq_length)

data_loader = DataLoader(dataset,
                         batch_size=args.eval_batch_size,
                         num_workers=2)

device = torch.device('cuda:0')
model.model.to(device)
model.model.eval()
"a"

Loading Test Dataset
Total examples = 587


You are using a model of type mt5 to instantiate a model of type t5. This is not supported for all configurations of models and can yield errors.


587 2935 2935


'a'

In [158]:
args.constrained_decode = True

for batch in tqdm(data_loader):
    # beam search

    outs = model.model.generate(
        input_ids=batch['source_ids'].to(device),
        attention_mask=batch['source_mask'].to(device),
        max_length=args.max_seq_length,
        num_beams=args.beam_size,
        early_stopping=True,
        return_dict_in_generate=True,
        output_scores=True,
        prefix_allowed_tokens_fn=partial(
           prefix_allowed_tokens_fn, model, args.task, args.dataset,
           batch['source_ids']) if args.constrained_decode else None,
    )
    dec = [
        model.tokenizer.decode(ids, skip_special_tokens=True)
        for ids in outs.sequences
    ]
    target = [
        model.tokenizer.decode(ids, skip_special_tokens=True)
        for ids in batch["target_ids"]
    ]
    outputs.extend(dec)
    targets.extend(target)



100%|██████████| 184/184 [02:33<00:00,  1.20it/s]


In [156]:
def compute_scores(pred_seqs, gold_seqs, verbose=True, task="asqp"):
    global all_labels
    global all_preds
    """
    Compute model performance
    """
    assert len(pred_seqs) == len(gold_seqs), (len(pred_seqs), len(gold_seqs))
    num_samples = len(gold_seqs)

    all_labels, all_preds = [], []

    for i in range(num_samples):
        gold_list = extract_spans_para(gold_seqs[i], 'gold')
        pred_list = extract_spans_para(pred_seqs[i], 'pred')
        if (task == "tasd"):
          gold_list = [tup[:-1] for tup in gold_list]
          pred_list = [tup[:-1] for tup in pred_list]

        all_labels.append(gold_list)
        all_preds.append(pred_list)

    try:
        preds = [[f'{labels[0]}:{opinion2sentword[labels[2]].upper() if labels[2] in opinion2sentword else ""}:{labels[1]}' for labels in pred] for pred in all_preds]
        golds = [[f'{labels[0]}:{opinion2sentword[labels[2]].upper() if labels[2] in opinion2sentword else ""}:{labels[1]}' for labels in gold] for gold in all_labels]
    except KeyError:
        print('KeyError!')
        print(all_labels)
    except:
        print(all_labels)

    all_label = all_labels
    
    scores_dfs = createResults(preds, golds, LABEL_SPACE, task)
    
    scores = compute_f1_scores(all_preds, all_labels)
    scores["all_preds"] = all_preds
    scores["all_labels"] = all_labels
    print('MVP F1-Micro: ', scores['f1'])
    # return scores, all_labels, all_preds
    return scores_dfs, all_labels, pred_seqs

a = compute_scores(outputs, targets, verbose=True, task=args.task)

MVP F1-Micro:  17.14865550481989


In [159]:
targets = targets[::num_path]

# get outputs
_outputs = outputs # backup
outputs = [] # new outputs
if args.agg_strategy == 'post_rank':
    inputs = [ele for ele in sents for _ in range(num_path)]
    assert len(_outputs) == len(inputs), (len(_outputs), len(inputs))
    preds = [[o] for o in _outputs] 
    model_path = os.path.join(args.output_dir, "final")
    scores = cal_entropy(inputs, preds, model_path, model.tokenizer)

for i in range(0, len(targets)):
    o_idx = i * num_path
    multi_outputs = _outputs[o_idx:o_idx + num_path]

    if args.agg_strategy == 'post_rank':
        multi_probs = scores[o_idx:o_idx + args.num_path]
        assert len(multi_outputs) == len(multi_probs)

        sorted_outputs = [i for _,i in sorted(zip(multi_probs,multi_outputs))]
        outputs.append(sorted_outputs[0])
        continue
    elif args.agg_strategy == "pre_rank":
        outputs.append(multi_outputs[0])
        continue
    elif args.agg_strategy == 'rand':
        outputs.append(random.choice(multi_outputs))
        continue
    elif args.agg_strategy == 'vote':
        all_quads = []
        for s in multi_outputs:
            all_quads.extend(
                extract_spans_para(seq=s, seq_type='pred'))

        output_quads = []
        counter = dict(Counter(all_quads))
        for quad, count in counter.items():
            # keep freq >= num_path / 2
            if count >= len(multi_outputs) / 2:
                output_quads.append(quad)

        # recover output
        output = []
        for q in output_quads:
            ac, at, sp, ot = q
            if args.task == "aste":
                if 'null' not in [at, ot, sp]:  # aste has no 'null', for zero-shot only
                    output.append(f'[A] {at} [O] {ot} [S] {sp}')

            elif args.task  == "tasd":
                output.append(f"[A] {at} [S] {sp} [C] {ac}")

            elif args.task in ["asqp", "acos"]:
                output.append(f"[A] {at} [O] {ot} [S] {sp} [C] {ac}")

            else:
                raise NotImplementedError

        target_quads = extract_spans_para(seq=targets[i],
                                        seq_type='gold')

        # if no output, use the first path
        output_str = " [SSEP] ".join(
            output) if output else multi_outputs[0]

        outputs.append(output_str)

# stats
labels_counts = Counter([len(l.split('[SSEP]')) for l in outputs])

print('After Prediction')
print('Preds')
print(outputs[:5])
print('Golds')
print(targets[:5])

scores, all_labels, preds = compute_scores(outputs,
                                               targets,
                                               verbose=True, task=args.task)

After Prediction
Preds
['[A] null [S] great [C] restaurant general', '[A] sushi [S] great [C] food quality', '[A] portions [S] ok [C] food style_options', '[A] green tea creme brulee [S] great [C] food quality', '[A] restaurant [S] bad [C] restaurant general']
Golds
['[S] great [A] it [C] food quality', '[S] great [A] sushi [C] food quality', '[S] ok [A] portions [C] food style_options', '[S] great [A] green tea creme brulee [C] food quality', '[S] great [A] it [C] food quality']
MVP F1-Micro:  56.39781685870224


In [160]:
scores[4]

{'ambience general': {'precision': 0.6538,
  'recall': 0.7727,
  'f1': 0.7083,
  'accuracy': 0.5484,
  'support': 93},
 'drinks prices': {'precision': 0.3333,
  'recall': 0.25,
  'f1': 0.2857,
  'accuracy': 0.1667,
  'support': 6},
 'drinks quality': {'precision': 0.8125,
  'recall': 0.5909,
  'f1': 0.6842,
  'accuracy': 0.52,
  'support': 25},
 'drinks style_options': {'precision': 0.4615,
  'recall': 0.5,
  'f1': 0.48,
  'accuracy': 0.3158,
  'support': 19},
 'food prices': {'precision': 0.4118,
  'recall': 0.6087,
  'f1': 0.4912,
  'accuracy': 0.3256,
  'support': 43},
 'food quality': {'precision': 0.5651,
  'recall': 0.4856,
  'f1': 0.5223,
  'accuracy': 0.3535,
  'support': 430},
 'food style_options': {'precision': 0.4419,
  'recall': 0.3455,
  'f1': 0.3878,
  'accuracy': 0.2405,
  'support': 79},
 'location general': {'precision': 0.5455,
  'recall': 0.4615,
  'f1': 0.5,
  'accuracy': 0.3333,
  'support': 18},
 'restaurant general': {'precision': 0.713,
  'recall': 0.5775,
  'f

In [152]:
def prefix_allowed_tokens_fn(model, task, data_name, source_ids, batch_id,
                                 input_ids):
    """
    Constrained Decoding
    # ids = self.tokenizer("text", return_tensors='pt')['input_ids'].tolist()[0]
    """
    
    force_tokens = model.force_tokens
    
    # google/mt5-small
    to_id = {
        'OT': [646],
        'AT': [357],
        'SP': [399],
        'AC': [424],
        'SS': [399],
        'EP': [155719],
        '[': [491],
        ']': [439],
        'it': [609],
        # 'es': [3, 15, 7],
        'null': [259, 1181]
    }

    left_brace_index = (input_ids == to_id['['][0]).nonzero()
    right_brace_index = (input_ids == to_id[']'][0]).nonzero()
    num_left_brace = len(left_brace_index)
    num_right_brace = len(right_brace_index)
    last_right_brace_pos = right_brace_index[-1][
        0] if right_brace_index.nelement() > 0 else -1
    last_left_brace_pos = left_brace_index[-1][
        0] if left_brace_index.nelement() > 0 else -1
    cur_id = input_ids[-1]

    if cur_id in to_id['[']:
        return force_tokens['special_tokens']
    elif cur_id in to_id['AT'] + to_id['OT'] + to_id['EP'] + to_id['SP'] + to_id['AC']:  
        # return to_id[']']  
        ### MT5 ONLY
        return to_id[']'] + to_id['EP']  
    elif cur_id in to_id['SS']:  
        return to_id['EP'] 
   
    # get cur_term
    if last_left_brace_pos == -1:
        return to_id['['] + [1]   # start of sentence: [
    elif (last_left_brace_pos != -1 and last_right_brace_pos == -1) \
        or last_left_brace_pos > last_right_brace_pos:
        return to_id[']']  # ]
    else:
        cur_term = input_ids[last_left_brace_pos + 1]

    ret = []
    if cur_term in to_id['SP']:  # SP
        ret = force_tokens['sentiment_tokens'][str(task)]
    elif cur_term in to_id['AT']:  # AT
        force_list = source_ids[batch_id].tolist()
        if task != 'aste': 
            if True:
            # if data_name == 'rest-16':
                force_list.extend(to_id['it'] + [1])  
            # elif data_name == 'GERestaurant':
            #     force_list.extend(to_id['es'] + [1])  
        ret = force_list  
    elif cur_term in to_id['SS']:
        #ret = [3] + to_id[']'] + [1]
        ### MT5 ONLY
        ret = [259] + to_id[']'] + [1]
    elif cur_term in to_id['AC']:  # AC
        ret = force_tokens['cate_tokens'][str(data_name)]
    elif cur_term in to_id['OT']:  # OT
        force_list = source_ids[batch_id].tolist()
        if task == "acos":
            force_list.extend(to_id['null'])  # null
        ret = force_list
    else:
        raise ValueError(cur_term)    

    if num_left_brace == num_right_brace:
        ret = set(ret)
        ret.discard(to_id[']'][0]) # remove ]
        for w in force_tokens['special_tokens']:
            ret.discard(w)
        ret = list(ret)
    elif num_left_brace > num_right_brace:
        ret += to_id[']'] 
    else:
        raise ValueError
    ret.extend(to_id['['] + [1]) # add [
    return ret


In [153]:
st = 'the lemon chicken tasted like sticky sweet donuts and the honey walnut prawns, the few they actually give you.....were not good.'

tokenized_input = model.tokenizer.batch_encode_plus(
                        [st],
                        max_length=1000,
                        padding="max_length",
                        truncation=True,
                        return_tensors="pt")

outs1 = model.model.generate(
        input_ids=tokenized_input['input_ids'].to(device),
        attention_mask=tokenized_input['attention_mask'].to(device),
        max_length=1000,
        num_beams=args.beam_size,
        early_stopping=True,
        return_dict_in_generate=True,
        output_scores=True,
        prefix_allowed_tokens_fn=partial(
           prefix_allowed_tokens_fn, model, args.task, args.dataset,
           batch['source_ids']) if args.constrained_decode else None,
)

model.tokenizer.decode(outs1.sequences[0], skip_special_tokens=True)


'[S] bad [A] food  food [C] food quality [SSEP] [S] bad [A] food [C] food quality [SSEP] [C] food quality [SSEP] [C] food quality [SSEP] [C] food quality [A] food [S] bad [SSEP] [C] food quality [A] it'

In [151]:
from collections import Counter
print(Counter([a  for labels in all_labels for a,at,p in labels]))
print(Counter([a  for labels in all_preds for a,at,p in labels]))

Counter({'food quality': 313, 'service general': 155, 'restaurant general': 142, 'ambience general': 66, 'food style_options': 55, 'restaurant miscellaneous': 33, 'food prices': 23, 'drinks quality': 22, 'restaurant prices': 21, 'location general': 13, 'drinks style_options': 12, 'drinks prices': 4})
Counter({'food quality': 269, 'service general': 150, 'restaurant general': 115, 'ambience general': 78, 'food style_options': 43, 'restaurant miscellaneous': 35, 'food prices': 34, 'restaurant prices': 18, 'drinks quality': 16, 'drinks style_options': 13, 'location general': 11, 'null': 5, 'drinks prices': 3})


In [121]:
list(set([a  for labels in all_labels for a,at,p in labels]))

['service general',
 'food quality',
 'restaurant general',
 'food style_options',
 'ambience general',
 'location general',
 'drinks quality',
 'food prices',
 'drinks style_options',
 'drinks prices',
 'restaurant prices',
 'restaurant miscellaneous']