In [2]:
import pandas as pd
import numpy as np

from code_.process_conll import process_file, advanced_process_file
from code_.evaluation import class_report_base, class_report_advanced, shrink_predictions
from code_.bert import Tokenizer, convert_to_dataset, get_labels_list_from_dataset
from transformers import AutoModelForTokenClassification, TrainingArguments, Trainer
from transformers import DataCollatorForTokenClassification

from datasets import Dataset, load_metric

  from .autonotebook import tqdm as notebook_tqdm
  metric = load_metric("seqeval")


In [3]:
labels_list = ['ARG0', 'ARG1', 'ARG1-DSP', 'ARG2', 'ARG3', 'ARG4', 'ARG5', 'ARGA', 'ARGM-ADJ', 'ARGM-ADV', 'ARGM-CAU', 'ARGM-COM', 'ARGM-CXN', 'ARGM-DIR', 'ARGM-DIS', 'ARGM-EXT', 'ARGM-GOL', 'ARGM-LOC', 'ARGM-LVB', 'ARGM-MNR', 'ARGM-MOD', 'ARGM-NEG', 'ARGM-PRD', 'ARGM-PRP', 'ARGM-PRR', 'ARGM-REC', 'ARGM-TMP', 'C-ARG0', 'C-ARG1', 'C-ARG1-DSP', 'C-ARG2', 'C-ARG3', 'C-ARG4', 'C-ARGM-ADV', 'C-ARGM-COM', 'C-ARGM-CXN', 'C-ARGM-DIR', 'C-ARGM-EXT', 'C-ARGM-GOL', 'C-ARGM-LOC', 'C-ARGM-MNR', 'C-ARGM-PRP', 'C-ARGM-PRR', 'C-ARGM-TMP', 'C-V', 'R-ARG0', 'R-ARG1', 'R-ARG2', 'R-ARG3', 'R-ARG4', 'R-ARGM-ADJ', 'R-ARGM-ADV', 'R-ARGM-CAU', 'R-ARGM-COM', 'R-ARGM-DIR', 'R-ARGM-GOL', 'R-ARGM-LOC', 'R-ARGM-MNR', 'R-ARGM-TMP', 'V', '_']

In [4]:
metric = load_metric("seqeval")

def compute_metrics(p):

    predictions, labels = p
    predictions = np.argmax(predictions, axis=2)

    # Remove ignored index (special tokens)
    true_predictions = [
        [labels_list[p] for (p, l) in zip(prediction, label) if l != -100 and p < len(labels_list)]
        for prediction, label in zip(predictions, labels)
    ]
    true_labels = [
        [labels_list[l] for (p, l) in zip(prediction, label) if l != -100 and p < len(labels_list)]
        for prediction, label in zip(predictions, labels)
    ]

    results = metric.compute(predictions=true_predictions, references=true_labels)

    return {
            "precision": results["overall_precision"],
            "recall": results["overall_recall"],
            "f1": results["overall_f1"],
            "accuracy": results["overall_accuracy"],
        }


In [5]:
tokenizer = Tokenizer("distilbert-base-uncased", labels_list)

In [6]:
model = AutoModelForTokenClassification.from_pretrained('model_checkpoints/baseline', num_labels=len(labels_list))
# model = AutoModelForTokenClassification.from_pretrained('model_checkpoints/advanced', num_labels=len(labels_list))

In [7]:
model_name = "distilbert-base-uncased"
args = TrainingArguments(
    "model_checkpoints/baseline",
    evaluation_strategy = 'epoch',
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=5,
    weight_decay=0.01,
    push_to_hub=False,
)
trainer = Trainer(
    model, args,
    # train_dataset=tokenized_datasets["train"],
    # eval_dataset=tokenized_datasets["validation"],
    data_collator=DataCollatorForTokenClassification(tokenizer.tokenizer, padding=True),
    tokenizer=tokenizer.tokenizer,
    compute_metrics=compute_metrics
)

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False)


In [12]:
def run_model(tok:Tokenizer, trainer:Trainer, examples:list, use_context:bool, save_file:str):
    if not use_context:
        tokenized_data = tok.tokenize_and_align_labels_pred(examples)
    else: 
        tokenized_data = tok.tokenize_and_align_labels_context(examples)
    # todo context

    dataset = Dataset.from_dict(tokenized_data)


    predictions_raw, labels_pred, _ = trainer.predict(dataset)

    predictions = np.argmax(predictions_raw, axis=2)

    list_predictions = [
        [labels_list[p] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels_pred)
    ]
    true_labels = [
        [labels_list[l] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels_pred)
    ]
    



    val_word_ids = [tok.tokenizer(sentence, truncation=False, is_split_into_words=True).word_ids() for sentence in examples['sentence']]

    df = pd.DataFrame(columns=['sentence', 'prediction', 'gold', 'word_ids'])
    for tokens, prediction, gold, word_ids in zip(tokenized_data['input_ids'], list_predictions, true_labels, val_word_ids):
        sentence = tok.tokenizer.decode(tokens)
        df.loc[len(df.index)] = [sentence, prediction, gold, word_ids]

    gold_restored = []
    pred_restored = []
    for i, row in df.iterrows():
        sentence = row[0]
        orig_sentence = sentence.split('[SEP]')[0].split(' ')[1:]
        prediction = row[1]
        gold = row[2]
        word_ids = row[3][1:-1]
        gold_restored.append(shrink_predictions(word_ids, gold))
        pred_restored.append(shrink_predictions(word_ids, prediction))

    df['gold_restored'] = gold_restored
    df['pred_restored'] = pred_restored

    print(df.columns)
    df.to_csv(save_file)
    class_report_base(save_file)


In [13]:
examples = {
    'sentence': [['some','fucking','sentence'], ['another','fucking','sentence']],
    'predicate': ['fucking', 'fucking'],
    'labels': ['ARG0, V, _', 'ARG0, V, _']
}


run_model(tokenizer,
          trainer,
          examples,
          use_context=False, 
          save_file='data/challenge_results/test.csv')

  0%|          | 0/1 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
100%|██████████| 1/1 [00:00<00:00, 62.50it/s]

Index(['sentence', 'prediction', 'gold', 'word_ids', 'gold_restored',
       'pred_restored'],
      dtype='object')
              precision    recall  f1-score   support

         'V'       1.00      1.00      1.00         2
         '_'       0.00      0.00      0.00         2
       'C-V'       0.00      0.00      0.00         0
      'ARGA'       0.00      0.00      0.00         0
      'ARG3'       0.00      0.00      0.00         0
      'ARG2'       0.00      0.00      0.00         0
      'ARG5'       0.00      0.00      0.00         0
      'ARG0'       0.00      0.00      0.00         2
      'ARG4'       0.00      0.00      0.00         0
      'ARG1'       0.00      0.00      0.00         0
    'C-ARG4'       0.00      0.00      0.00         0
    'C-ARG2'       0.00      0.00      0.00         0
    'R-ARG2'       0.00      0.00      0.00         0
    'C-ARG0'       0.00      0.00      0.00         0
    'R-ARG0'       0.00      0.00      0.00         0
    'C-ARG1'      


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
