In [1]:
#1. build our own dataset

In [38]:
from datasets import Dataset, DatasetDict
from transformers import DataCollatorForTokenClassification, AutoModelForTokenClassification, TrainingArguments, Trainer
import pandas as pd
import numpy as np

# map POS to POS_id
cnt = 0
POS_id = {}
POS_ls = ['NN', 'IN', 'NNP', 'DT', 'NNS', 'JJ', 'COMMA', 'CD', '.', 'VBD', 'RB','VB', 'CC', 'VBN', 'VBZ', 
          'VBG', 'TO', 'PRP', 'VBP', 'POS', 'PRP$','MD', '$', '``', "''", 'WDT', ':', 'JJR', 'RP', 'RBR', 
          'WP', 'NNPS','JJS', ')', '(', 'EX', 'RBS', 'WRB', '-', 'UH', 'WP$', 'PDT', '/', '#', 'LS', 'SYM', 'FW', 'AUX']
for pos in POS_ls:
    POS_id[pos] = cnt
    cnt += 1

# map BIO to BIO_id
cnt = 0
BIO_id = {}
BIO_ls = ['O', 'B-NP', 'I-NP', 'B-PP', 'B-ADVP', 'B-ADJP', 'B-SBAR', 'B-CONJP',
       'I-ADJP', 'I-PP', 'I-ADVP', 'I-CONJP', 'B-INTJ', 'I-SBAR', 'B-LST',
       'B-VP', 'B-PRT', 'I-INTJ', 'I-VP']
for bio in BIO_ls:
    BIO_id[bio] = cnt
    cnt += 1

In [266]:


# map label to ARG_id
Label_id = {"ARG0":0,"ARG1":1,"ARG2":2,"PRED":3,"SUPPORT":4}

def mapLabel(label):
    return Label_id[label] if label in Label_id else 5

def mapPOS(label):
    return POS_id[label] if label in POS_id else len(POS_id)+1

def mapBIO(label):
    return BIO_id[label] if label in BIO_id else len(BIO_id)+1

# build datasets
def condense_df(file):
    df = pd.DataFrame()
    with open(file, 'r') as file:
        ls = [i.split('\t') for i in file.read().split('\n')]
        df = pd.DataFrame(ls)

    df['id'] = df.index
    df[0].replace('', np.nan, inplace=True)
    df.dropna(axis=0, subset = [0], inplace = True)
    df['BIO'] = df[2].map(mapBIO)
    df['POS'] = df[1].map(mapPOS)
    df['label'] = df[5].map(mapLabel)
    df['id'] = df[4].map(int)
    df.drop(columns = [1, 2, 3, 4, 5, 6], inplace = True)
    condense = df.groupby('id').apply(lambda x: [list(x[0]),list(x['POS']), list(x['BIO']), list(x['label'])]).apply(pd.Series)
    condense.columns =['tokens','BIO', 'POS','label']
    return condense


train = Dataset.from_pandas(condense_df("Partitive-Files/%_nombank.clean.train"))
eval_ = Dataset.from_pandas(condense_df("Partitive-Files/%_nombank.clean.dev"))
test = Dataset.from_pandas(condense_df("Partitive-Files/%_nombank.clean.test"))


datasets = DatasetDict({"train": train, "validation":eval_, "test":test})
datasets

DatasetDict({
    train: Dataset({
        features: ['tokens', 'BIO', 'POS', 'label', 'id'],
        num_rows: 2174
    })
    validation: Dataset({
        features: ['tokens', 'BIO', 'POS', 'label', 'id'],
        num_rows: 83
    })
    test: Dataset({
        features: ['tokens', 'BIO', 'POS', 'label', 'id'],
        num_rows: 150
    })
})

In [40]:
# 2. tokenize

In [311]:

from transformers import AutoTokenizer

model_checkpoint = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

def align_labels_with_tokens(labels, POSs, BIOs, word_ids):
    new_labels = []
    POS_labels = []
    BIO_labels = []
    current_word = None
    for word_id in word_ids:
        if not word_id:
            new_labels.append(-100)
            POS_labels.append(0)
            BIO_labels.append(0)
        else:
            if word_id != current_word:# Start of a new word!
                current_word = word_id       
            new_labels.append(labels[word_id])
            POS_labels.append(POSs[word_id])
            BIO_labels.append(BIOs[word_id])

    return new_labels, POS_labels, BIO_labels

def tokenize_and_align_labels(examples):
    tokenized_inputs = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True)
    all_BIO = examples["BIO"]
    all_POS = examples["POS"]
    all_labels = examples["label"]
    
    new_labels = []
    POS_labels = []
    BIO_labels = []
    for i in range(len(all_labels)):
        word_ids = tokenized_inputs.word_ids(i)
        new_label, POS_label, BIO_label = align_labels_with_tokens(all_labels[i], all_POS[i], all_BIO[i], word_ids)
        new_labels.append(new_label)
        POS_labels.append(POS_label)
        BIO_labels.append(BIO_label)

    tokenized_inputs["BIOL"] = BIO_labels
    tokenized_inputs["POSL"] = POS_labels
    tokenized_inputs["labels"] = new_labels
    #print("new_labels::"+str(new_labels))
    #print("POS_labels::"+str(POS_labels))

    return tokenized_inputs

tokenized_datasets = datasets.map(
    tokenize_and_align_labels,
    batched=True,
    remove_columns = datasets["train"].column_names,
)
"""
from transformers import AutoTokenizer

model_checkpoint = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

def align_labels_with_tokens(labels, word_ids):
    new_labels = []
    current_word = None
    for word_id in word_ids:
        if not word_id:
            new_labels.append(-100)
        else:
            if word_id != current_word:# Start of a new word!
                current_word = word_id       
            label = labels[word_id]
            new_labels.append(label)

    return new_labels

def tokenize_and_align_labels(examples):
    tokenized_inputs = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True)
    all_labels = examples["label"]
    new_labels = []
    for i, labels in enumerate(all_labels):
        word_ids = tokenized_inputs.word_ids(i)
        new_labels.append(align_labels_with_tokens(labels, word_ids))

    tokenized_inputs["labels"] = new_labels
    return tokenized_inputs

tokenized_datasets = datasets.map(
    tokenize_and_align_labels,
    batched=True,
    remove_columns = datasets["train"].column_names,
)
"""

  0%|          | 0/3 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

'\nfrom transformers import AutoTokenizer\n\nmodel_checkpoint = "bert-base-uncased"\ntokenizer = AutoTokenizer.from_pretrained(model_checkpoint)\n\ndef align_labels_with_tokens(labels, word_ids):\n    new_labels = []\n    current_word = None\n    for word_id in word_ids:\n        if not word_id:\n            new_labels.append(-100)\n        else:\n            if word_id != current_word:# Start of a new word!\n                current_word = word_id       \n            label = labels[word_id]\n            new_labels.append(label)\n\n    return new_labels\n\ndef tokenize_and_align_labels(examples):\n    tokenized_inputs = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True)\n    all_labels = examples["label"]\n    new_labels = []\n    for i, labels in enumerate(all_labels):\n        word_ids = tokenized_inputs.word_ids(i)\n        new_labels.append(align_labels_with_tokens(labels, word_ids))\n\n    tokenized_inputs["labels"] = new_labels\n    return tokenized_inputs\n\

In [None]:
# 3. train while evaluate

In [312]:

import evaluate
import numpy as np

data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)
metric = evaluate.load("seqeval")

small_train = tokenized_datasets["train"].shuffle(seed=42).select(range(64))
small_eval = tokenized_datasets["validation"].shuffle(seed=42).select(range(16))
small_test = tokenized_datasets["test"].shuffle(seed=42).select(range(8))


BatchSize = 1
from torch.utils.data import DataLoader
train_dataloader = DataLoader(
    small_train, shuffle=True, batch_size=BatchSize, collate_fn=data_collator#tokenized_datasets["train"]
)
eval_dataloader = DataLoader(
    small_eval, batch_size=BatchSize, collate_fn=data_collator#tokenized_datasets["validation"]
)
test_dataloader = DataLoader(
    small_test, batch_size=BatchSize, collate_fn=data_collator#tokenized_datasets["test"]
)

In [None]:
#len(small_train['POS'])

In [None]:
#len(small_train['POS'][2])

In [337]:
label_names = ["ARG0", "ARG1", "ARG2", "PRED", "SUPPORT", "None"]
POS_len = len(POS_ls)
BIO_len = len(BIO_ls)
feature_dim = 768

In [338]:
from transformers import AutoConfig, AutoModel
from transformers.modeling_outputs import TokenClassifierOutput

import torch.nn as nn
class CustomModel(nn.Module):
    def __init__(self,checkpoint,num_labels): 
        super(CustomModel,self).__init__() 
        self.num_labels = num_labels 

        #Load Model with given checkpoint and extract its body
        self.model = AutoModel.from_pretrained(checkpoint,config=AutoConfig.from_pretrained(checkpoint, output_attentions=True,output_hidden_states=True))
        self.dropout = nn.Dropout(0.1) 
        self.classifier = nn.Linear(feature_dim,num_labels) # load and initialize weights

    def forward(self, input_ids, token_type_ids, attention_mask, labels, POS, BIO):
        #TODO: add 2 more parameters as: forward(self, input_ids, token_type_ids, attention_mask, POS, BIO, labels)
        #Extract outputs from the body
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
#         print("input_ids:::"+str(input_ids.shape))
#         print("labels:::"+str(labels.shape))
        #Add custom layers
        sequence_output = self.dropout(outputs[0]) #outputs[0]=last hidden state
        
#         POS_f = torch.eye(POS_len)[POS]
#         BIO_f = torch.eye(POS_len)[BIO]
        
#         sequence_output = torch.cat((sequence_output,POS_f, BIO_f),2)
        #sequence_output = outputs
#         print(sequence_output.size())
        """
        TODO: concat
        After using print(sequence_output.size()) here, we have:
        torch.Size([8, 44, 768])
        torch.Size([8, 50, 768]) etc.
        第一维度是全局batch_size我设置成8
        第二维是一句话里面token的个数，是经过分割OOV过后的token数量
        第三维是Bert的输出
        
        假设POS 48维，BIO 10维，那么总的第三维在concat后的总维数是 768+48+10，concat不会改变第一维和第二维的长度
        
        logits = self.classifier(sequence_output[:,:,:].view(-1, 768+48+10)) # calculate losses
        """
        
        logits = self.classifier(sequence_output[:,:,:].view(-1, feature_dim)) # calculate losses
        self.logits = logits
        #print("logits::"+str(logits.shape))

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return TokenClassifierOutput(loss=loss, logits=logits, hidden_states=outputs.hidden_states,attentions=outputs.attentions)

In [339]:
import torch
from tqdm import tqdm
from transformers import AdamW,get_scheduler
from datasets import load_metric
metric = evaluate.load("seqeval")

num_epochs = 10
model_cc = CustomModel(checkpoint=model_checkpoint,num_labels=len(label_names))#.cuda()
optimizer = AdamW(model_cc.parameters(), lr=2e-5)
num_training_steps = num_epochs * len(train_dataloader)
progress_bar_train = tqdm(range(num_training_steps),miniters=2)
progress_bar_eval = tqdm(range(num_epochs * len(eval_dataloader)),miniters=2)
f1_best = 0
resume_flag = False
best_net = None

lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)

for epoch in range(num_epochs):
    if resume_flag:
        model_cc.load_state_dict(torch.load("sstcls_best.dat"))
    model_cc.train()
    for batch in train_dataloader:
        batch = {k: v for k, v in batch.items()}#.cuda()
        outputs = model_cc(batch['input_ids'], token_type_ids=batch['token_type_ids'], attention_mask=batch['attention_mask'],labels=batch['labels'], POS=batch['POSL'], BIO=batch['BIOL'])
#         assert()
        #TODO: add 2 more parameters as followed:
        #model_cc(batch['input_ids'], token_type_ids=batch['token_type_ids'], attention_mask=batch['attention_mask'],POS = batch['POS'], BIO = batch['BIO'], labels=batch['labels'])
        
        loss = outputs.loss
        loss.backward()

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar_train.update(1)
    model_cc.eval()
    f1_now = []
    for batch in eval_dataloader:
        batch = {k: v for k, v in batch.items()}#.cuda()
        with torch.no_grad():
            outputs = model_cc(batch['input_ids'], token_type_ids=batch['token_type_ids'], attention_mask=batch['attention_mask'],labels=batch['labels'], POS=batch['POSL'], BIO=batch['BIOL'])
            #TODO: add 2 more parameters as followed:
            #model_cc(batch['input_ids'], token_type_ids=batch['token_type_ids'], attention_mask=batch['attention_mask'],POS = batch['POS'], BIO = batch['BIO'], labels=batch['labels'])
        logits = outputs.logits
        predictions = torch.argmax(logits, dim=-1)

        true_labels = [[label_names[l] for l in label if l != -100] for label in batch["labels"]]
        
        reshaped_predictions = torch.reshape(predictions, (BatchSize,-1))
        true_predictions = [
            [label_names[p] for (p, l) in zip(prediction, label) if l != -100]
            for prediction, label in zip(reshaped_predictions, batch["labels"])
        ]
        
        all_metrics = metric.compute(predictions=true_predictions, references=true_labels)
        metric.add_batch(predictions=true_predictions, references=true_labels)
        progress_bar_eval.update(1)
        f1_now.append(all_metrics["overall_f1"])
        
    if np.mean(f1_now) > f1_best or not best_net:
        torch.save(model_cc.state_dict(), 'sstcls_best.dat')
        f1_best = np.mean(f1_now)
        print("the best f1 is now: "+ str(np.mean(f1_now)))
        best_net = model_cc
        

    print(metric.compute())

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).

  0%|                                                  | 0/640 [00:35<?, ?it/s][A


  0%|                                                  | 0/160 [00:35

the best f1 is now: 0.029824561403508774
{'RED': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 1}, 'RG1': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 1}, 'UPPORT': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 1}, 'one': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 3}, 'overall_precision': 0.0, 'overall_recall': 0.0, 'overall_f1': 0.0, 'overall_accuracy': 0.9142857142857143}



 10%|████▏                                    | 65/640 [00:57<38:08,  3.98s/it][A
 10%|████▏                                    | 66/640 [00:57<31:30,  3.29s/it][A
 10%|████▎                                    | 67/640 [00:58<25:40,  2.69s/it][A
 11%|████▎                                    | 68/640 [00:59<20:43,  2.17s/it][A
 11%|████▍                                    | 69/640 [00:59<16:45,  1.76s/it][A
 11%|████▍                                    | 70/640 [01:00<13:42,  1.44s/it][A
 11%|████▌                                    | 71/640 [01:00<11:25,  1.21s/it][A
 11%|████▌                                    | 72/640 [01:01<09:42,  1.03s/it][A
 11%|████▋                                    | 73/640 [01:01<08:26,  1.12it/s][A
 12%|████▋                                    | 74/640 [01:02<07:33,  1.25it/s][A
 12%|████▊                                    | 75/640 [01:03<06:56,  1.36it/s][A
 12%|████▊                                    | 76/640 [01:03<06:29,  1.45it/s][A
 12

the best f1 is now: 0.4471864385446294
{'RED': {'precision': 1.0, 'recall': 1.0, 'f1': 1.0, 'number': 1}, 'RG1': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 1}, 'UPPORT': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 1}, 'one': {'precision': 0.5, 'recall': 0.3333333333333333, 'f1': 0.4, 'number': 3}, 'overall_precision': 0.6666666666666666, 'overall_recall': 0.3333333333333333, 'overall_f1': 0.4444444444444444, 'overall_accuracy': 0.9428571428571428}



 20%|████████                                | 129/640 [01:53<54:55,  6.45s/it][A
 20%|████████▏                               | 130/640 [01:54<39:50,  4.69s/it][A
 20%|████████▏                               | 131/640 [01:55<29:16,  3.45s/it][A
 21%|████████▎                               | 132/640 [01:55<21:53,  2.59s/it][A
 21%|████████▎                               | 133/640 [01:56<16:44,  1.98s/it][A
 21%|████████▍                               | 134/640 [01:56<13:10,  1.56s/it][A
 21%|████████▍                               | 135/640 [01:57<10:38,  1.26s/it][A
 21%|████████▌                               | 136/640 [01:57<08:53,  1.06s/it][A
 21%|████████▌                               | 137/640 [01:58<07:40,  1.09it/s][A
 22%|████████▋                               | 138/640 [01:59<06:49,  1.23it/s][A
 22%|████████▋                               | 139/640 [01:59<06:14,  1.34it/s][A
 22%|████████▊                               | 140/640 [02:00<05:46,  1.44it/s][A
 22

the best f1 is now: 0.6823637963343845
{'RED': {'precision': 1.0, 'recall': 1.0, 'f1': 1.0, 'number': 1}, 'RG1': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 1}, 'UPPORT': {'precision': 1.0, 'recall': 1.0, 'f1': 1.0, 'number': 1}, 'one': {'precision': 0.6666666666666666, 'recall': 0.6666666666666666, 'f1': 0.6666666666666666, 'number': 3}, 'overall_precision': 0.8, 'overall_recall': 0.6666666666666666, 'overall_f1': 0.7272727272727272, 'overall_accuracy': 0.9714285714285714}



 30%|████████████                            | 193/640 [02:50<49:01,  6.58s/it][A
 30%|████████████▏                           | 194/640 [02:51<35:30,  4.78s/it][A
 30%|████████████▏                           | 195/640 [02:51<26:06,  3.52s/it][A
 31%|████████████▎                           | 196/640 [02:52<19:30,  2.64s/it][A
 31%|████████████▎                           | 197/640 [02:53<14:53,  2.02s/it][A
 31%|████████████▍                           | 198/640 [02:53<11:43,  1.59s/it][A
 31%|████████████▍                           | 199/640 [02:54<09:28,  1.29s/it][A
 31%|████████████▌                           | 200/640 [02:54<07:53,  1.08s/it][A
 31%|████████████▌                           | 201/640 [02:55<06:46,  1.08it/s][A
 32%|████████████▋                           | 202/640 [02:55<05:59,  1.22it/s][A
 32%|████████████▋                           | 203/640 [02:56<05:27,  1.33it/s][A
 32%|████████████▊                           | 204/640 [02:57<05:04,  1.43it/s][A
 32

the best f1 is now: 0.7248684573623666
{'RED': {'precision': 1.0, 'recall': 1.0, 'f1': 1.0, 'number': 1}, 'RG1': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 1}, 'UPPORT': {'precision': 1.0, 'recall': 1.0, 'f1': 1.0, 'number': 1}, 'one': {'precision': 0.6666666666666666, 'recall': 0.6666666666666666, 'f1': 0.6666666666666666, 'number': 3}, 'overall_precision': 0.8, 'overall_recall': 0.6666666666666666, 'overall_f1': 0.7272727272727272, 'overall_accuracy': 0.9714285714285714}



 40%|████████████████                        | 257/640 [03:46<39:49,  6.24s/it][A
 40%|████████████████▏                       | 258/640 [03:47<28:54,  4.54s/it][A
 40%|████████████████▏                       | 259/640 [03:47<21:17,  3.35s/it][A
 41%|████████████████▎                       | 260/640 [03:48<15:56,  2.52s/it][A
 41%|████████████████▎                       | 261/640 [03:48<12:13,  1.94s/it][A
 41%|████████████████▍                       | 262/640 [03:49<09:37,  1.53s/it][A
 41%|████████████████▍                       | 263/640 [03:49<07:47,  1.24s/it][A
 41%|████████████████▌                       | 264/640 [03:50<06:31,  1.04s/it][A
 41%|████████████████▌                       | 265/640 [03:51<05:38,  1.11it/s][A
 42%|████████████████▋                       | 266/640 [03:51<05:01,  1.24it/s][A
 42%|████████████████▋                       | 267/640 [03:52<04:35,  1.35it/s][A
 42%|████████████████▊                       | 268/640 [03:52<04:16,  1.45it/s][A
 42

{'RED': {'precision': 1.0, 'recall': 1.0, 'f1': 1.0, 'number': 1}, 'RG1': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 1}, 'UPPORT': {'precision': 1.0, 'recall': 1.0, 'f1': 1.0, 'number': 1}, 'one': {'precision': 0.6666666666666666, 'recall': 0.6666666666666666, 'f1': 0.6666666666666666, 'number': 3}, 'overall_precision': 0.8, 'overall_recall': 0.6666666666666666, 'overall_f1': 0.7272727272727272, 'overall_accuracy': 0.9714285714285714}



 50%|████████████████████                    | 321/640 [04:41<31:40,  5.96s/it][A
 50%|████████████████████▏                   | 322/640 [04:41<23:00,  4.34s/it][A
 50%|████████████████████▏                   | 323/640 [04:42<16:58,  3.21s/it][A
 51%|████████████████████▎                   | 324/640 [04:43<12:45,  2.42s/it][A
 51%|████████████████████▎                   | 325/640 [04:43<09:48,  1.87s/it][A
 51%|████████████████████▍                   | 326/640 [04:44<07:45,  1.48s/it][A
 51%|████████████████████▍                   | 327/640 [04:44<06:19,  1.21s/it][A
 51%|████████████████████▌                   | 328/640 [04:45<05:18,  1.02s/it][A
 51%|████████████████████▌                   | 329/640 [04:45<04:37,  1.12it/s][A
 52%|████████████████████▋                   | 330/640 [04:46<04:06,  1.26it/s][A
 52%|████████████████████▋                   | 331/640 [04:47<03:45,  1.37it/s][A
 52%|████████████████████▊                   | 332/640 [04:47<03:30,  1.46it/s][A
 52

{'RED': {'precision': 1.0, 'recall': 1.0, 'f1': 1.0, 'number': 1}, 'RG1': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 1}, 'UPPORT': {'precision': 1.0, 'recall': 1.0, 'f1': 1.0, 'number': 1}, 'one': {'precision': 0.6666666666666666, 'recall': 0.6666666666666666, 'f1': 0.6666666666666666, 'number': 3}, 'overall_precision': 0.8, 'overall_recall': 0.6666666666666666, 'overall_f1': 0.7272727272727272, 'overall_accuracy': 0.9714285714285714}



 60%|████████████████████████                | 385/640 [05:36<25:18,  5.96s/it][A
 60%|████████████████████████▏               | 386/640 [05:36<18:23,  4.34s/it][A
 60%|████████████████████████▏               | 387/640 [05:37<13:32,  3.21s/it][A
 61%|████████████████████████▎               | 388/640 [05:37<10:11,  2.42s/it][A
 61%|████████████████████████▎               | 389/640 [05:38<07:49,  1.87s/it][A
 61%|████████████████████████▍               | 390/640 [05:39<06:10,  1.48s/it][A
 61%|████████████████████████▍               | 391/640 [05:39<05:01,  1.21s/it][A
 61%|████████████████████████▌               | 392/640 [05:40<04:12,  1.02s/it][A
 61%|████████████████████████▌               | 393/640 [05:40<03:38,  1.13it/s][A
 62%|████████████████████████▋               | 394/640 [05:41<03:15,  1.26it/s][A
 62%|████████████████████████▋               | 395/640 [05:41<02:58,  1.37it/s][A
 62%|████████████████████████▊               | 396/640 [05:42<02:46,  1.47it/s][A
 62

the best f1 is now: 0.7554963364277457
{'RED': {'precision': 1.0, 'recall': 1.0, 'f1': 1.0, 'number': 1}, 'RG1': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 1}, 'UPPORT': {'precision': 1.0, 'recall': 1.0, 'f1': 1.0, 'number': 1}, 'one': {'precision': 0.6666666666666666, 'recall': 0.6666666666666666, 'f1': 0.6666666666666666, 'number': 3}, 'overall_precision': 0.8, 'overall_recall': 0.6666666666666666, 'overall_f1': 0.7272727272727272, 'overall_accuracy': 0.9714285714285714}



 70%|████████████████████████████            | 449/640 [06:31<19:52,  6.24s/it][A
 70%|████████████████████████████▏           | 450/640 [06:32<14:23,  4.55s/it][A
 70%|████████████████████████████▏           | 451/640 [06:33<10:34,  3.36s/it][A
 71%|████████████████████████████▎           | 452/640 [06:33<07:53,  2.52s/it][A
 71%|████████████████████████████▎           | 453/640 [06:34<06:02,  1.94s/it][A
 71%|████████████████████████████▍           | 454/640 [06:34<04:45,  1.53s/it][A
 71%|████████████████████████████▍           | 455/640 [06:35<03:50,  1.25s/it][A
 71%|████████████████████████████▌           | 456/640 [06:36<03:12,  1.05s/it][A
 71%|████████████████████████████▌           | 457/640 [06:36<02:45,  1.11it/s][A
 72%|████████████████████████████▋           | 458/640 [06:37<02:27,  1.24it/s][A
 72%|████████████████████████████▋           | 459/640 [06:37<02:13,  1.35it/s][A
 72%|████████████████████████████▊           | 460/640 [06:38<02:03,  1.45it/s][A
 72

{'RED': {'precision': 1.0, 'recall': 1.0, 'f1': 1.0, 'number': 1}, 'RG1': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 1}, 'UPPORT': {'precision': 1.0, 'recall': 1.0, 'f1': 1.0, 'number': 1}, 'one': {'precision': 0.6666666666666666, 'recall': 0.6666666666666666, 'f1': 0.6666666666666666, 'number': 3}, 'overall_precision': 0.8, 'overall_recall': 0.6666666666666666, 'overall_f1': 0.7272727272727272, 'overall_accuracy': 0.9714285714285714}



 80%|████████████████████████████████        | 513/640 [07:26<12:35,  5.95s/it][A
 80%|████████████████████████████████▏       | 514/640 [07:27<09:06,  4.34s/it][A
 80%|████████████████████████████████▏       | 515/640 [07:28<06:40,  3.21s/it][A
 81%|████████████████████████████████▎       | 516/640 [07:28<05:00,  2.42s/it][A
 81%|████████████████████████████████▎       | 517/640 [07:29<03:50,  1.87s/it][A
 81%|████████████████████████████████▍       | 518/640 [07:29<03:00,  1.48s/it][A
 81%|████████████████████████████████▍       | 519/640 [07:30<02:26,  1.21s/it][A
 81%|████████████████████████████████▌       | 520/640 [07:30<02:02,  1.02s/it][A
 81%|████████████████████████████████▌       | 521/640 [07:31<01:46,  1.12it/s][A
 82%|████████████████████████████████▋       | 522/640 [07:32<01:34,  1.25it/s][A
 82%|████████████████████████████████▋       | 523/640 [07:32<01:25,  1.36it/s][A
 82%|████████████████████████████████▊       | 524/640 [07:33<01:19,  1.46it/s][A
 82

{'RED': {'precision': 1.0, 'recall': 1.0, 'f1': 1.0, 'number': 1}, 'RG1': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 1}, 'UPPORT': {'precision': 1.0, 'recall': 1.0, 'f1': 1.0, 'number': 1}, 'one': {'precision': 0.6666666666666666, 'recall': 0.6666666666666666, 'f1': 0.6666666666666666, 'number': 3}, 'overall_precision': 0.8, 'overall_recall': 0.6666666666666666, 'overall_f1': 0.7272727272727272, 'overall_accuracy': 0.9714285714285714}



 90%|████████████████████████████████████    | 577/640 [08:21<06:14,  5.94s/it][A
 90%|████████████████████████████████████▏   | 578/640 [08:22<04:29,  4.34s/it][A
 90%|████████████████████████████████████▏   | 579/640 [08:22<03:16,  3.21s/it][A
 91%|████████████████████████████████████▎   | 580/640 [08:23<02:25,  2.42s/it][A
 91%|████████████████████████████████████▎   | 581/640 [08:24<01:50,  1.87s/it][A
 91%|████████████████████████████████████▍   | 582/640 [08:24<01:25,  1.48s/it][A
 91%|████████████████████████████████████▍   | 583/640 [08:25<01:09,  1.21s/it][A
 91%|████████████████████████████████████▌   | 584/640 [08:25<00:57,  1.02s/it][A
 91%|████████████████████████████████████▌   | 585/640 [08:26<00:49,  1.12it/s][A
 92%|████████████████████████████████████▋   | 586/640 [08:26<00:43,  1.25it/s][A
 92%|████████████████████████████████████▋   | 587/640 [08:27<00:39,  1.36it/s][A
 92%|████████████████████████████████████▊   | 588/640 [08:28<00:35,  1.45it/s][A
 92

{'RED': {'precision': 1.0, 'recall': 1.0, 'f1': 1.0, 'number': 1}, 'RG1': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 1}, 'UPPORT': {'precision': 1.0, 'recall': 1.0, 'f1': 1.0, 'number': 1}, 'one': {'precision': 0.6666666666666666, 'recall': 0.6666666666666666, 'f1': 0.6666666666666666, 'number': 3}, 'overall_precision': 0.8, 'overall_recall': 0.6666666666666666, 'overall_f1': 0.7272727272727272, 'overall_accuracy': 0.9714285714285714}


In [340]:
def predict(net, dataloader, gpu):
    net.eval()

    f1_now = []
    precision_now = []
    recall_now = []
    accuracy_now = []
    for batch in eval_dataloader:
        batch = {k: v for k, v in batch.items()}#.cuda()
        with torch.no_grad():
            outputs = model_cc(batch['input_ids'], token_type_ids=batch['token_type_ids'], attention_mask=batch['attention_mask'], labels=batch['labels'], POS=batch['POSL'], BIO=batch['BIOL'])
            #TODO: add 2 more parameters as followed:
            #model_cc(batch['input_ids'], token_type_ids=batch['token_type_ids'], attention_mask=batch['attention_mask'],POS = batch['POS'], BIO = batch['BIO'], labels=batch['labels'])
        logits = outputs.logits
        predictions = torch.argmax(logits, dim=-1)

        true_labels = [[label_names[l] for l in label if l != -100] for label in batch["labels"]]
        
        reshaped_predictions = torch.reshape(predictions, (BatchSize,-1))
        true_predictions = [
            [label_names[p] for (p, l) in zip(prediction, label) if l != -100]
            for prediction, label in zip(reshaped_predictions, batch["labels"])
        ]
        
        all_metrics = metric.compute(predictions=true_predictions, references=true_labels)
        metric.add_batch(predictions=true_predictions, references=true_labels)
        progress_bar_eval.update(1)
        f1_now.append(all_metrics["overall_f1"])
        precision_now.append(all_metrics["overall_precision"])
        recall_now.append(all_metrics["overall_recall"])
        accuracy_now.append(all_metrics["overall_accuracy"])


    return np.mean(f1_now), np.mean(precision_now), np.mean(recall_now), np.mean(accuracy_now),



In [341]:
preds = predict(best_net, test_dataloader, 0)
print("test_f1_score:::"+str(preds[0]))
print("test_precision_score:::"+str(preds[1]))
print("test_recall_score:::"+str(preds[2]))
print("test_accuracy_score:::"+str(preds[3]))



161it [09:18,  1.41s/it]                                                       [A[A

162it [09:19,  1.30s/it][A[A

163it [09:20,  1.23s/it][A[A

164it [09:21,  1.18s/it][A[A

165it [09:22,  1.14s/it][A[A

166it [09:23,  1.11s/it][A[A

167it [09:24,  1.10s/it][A[A

168it [09:25,  1.08s/it][A[A

169it [09:26,  1.07s/it][A[A

170it [09:27,  1.07s/it][A[A

171it [09:28,  1.06s/it][A[A

172it [09:29,  1.06s/it][A[A

173it [09:30,  1.06s/it][A[A

174it [09:31,  1.06s/it][A[A

175it [09:32,  1.06s/it][A[A

176it [09:34,  1.06s/it][A[A

test_f1_score:::0.7078461652481628
test_precision_score:::0.6690595698348256
test_recall_score:::0.776018772893773
test_accuracy_score:::0.9446854659849268
