In [1]:
from collections import defaultdict, Counter
import os, operator
from progressbar import progressbar as pb
from nltk import word_tokenize
from typing import List, Tuple, Dict
import json

from torch.utils.data import DataLoader
from transformers import XLMRobertaTokenizer, XLMRobertaForMaskedLM
import torch
import numpy as np
from sklearn.metrics import classification_report, f1_score, accuracy_score
from nltk.tokenize import WhitespaceTokenizer
from nltk import word_tokenize
import re
from copy import deepcopy
import pandas as pd

import gdown #! pip install gdown

In [None]:
! gdown https://drive.google.com/uc?id=1-ZYVbUN691AD6hsMbkJfGsrC3zSayU4z

In [None]:
! tar xzf trained_model_template_free.tar.gz

In [2]:
data_path = "../SemEval2022-Task11_Train-Dev/"
pretrained_model_path = "../../pretrained/xlm-roberta-large/"
trained_model_path = "trained_model_template_free/"
device = 'cuda'

max_len = 96
batch_size = 24
num_epochs = 25
patience = 2

In [3]:
with open("label2word_label.json") as f:
    label2word_label = json.load(f)
    
word_label2label = {}

for lang, val in label2word_label.items():
    word_label2label[lang] = {v:k for k, v in val.items()}

In [4]:
dev_files = []

for folder in os.listdir(data_path):
    files = os.listdir(os.path.join(data_path, folder))
    dev_file = files[0] if "dev" in files[0] else files[1]
    
    dev_files.append(os.path.join(data_path, folder, dev_file))
    
len(dev_files)

11

In [5]:
def parse_conll(files) -> Tuple[Dict, Dict]:
    
    fin_texts, fin_labels = {}, {}
    
    for filename in pb(files):
    
        with open(filename) as f:
            data = f.read().splitlines()

        lang = os.path.basename(filename).split("_")[0].upper()
        texts, labels = [], []

        for row in data:
            if row.startswith("# id "):
                new_texts, new_labels = [], []
                continue

            if row == "":
                texts.append(new_texts)
                labels.append(new_labels)

            else:
                parts = row.split()
                new_texts.append(parts[0])
                new_labels.append(parts[-1])
                
        fin_texts[lang] = texts
        fin_labels[lang] = labels

    return fin_texts, fin_labels

In [6]:
dev_texts, dev_labels = parse_conll(dev_files)

100% (11 of 11) |########################| Elapsed Time: 0:00:00 Time:  0:00:00


In [7]:
for texts, labels in zip(dev_texts.values(), dev_labels.values()):
    assert len(texts) == 800
    assert len(labels) == 800

# Prepare Target

In [8]:
def prepare_target(tokens, labels, lang, label2word_label=label2word_label):
    """
        Replace entities with label words
    """
    new_tokens = []
    for token, label in zip(tokens, labels):
        if label.startswith("B-"):
            prefix, tag = label.split("-")
            new_token = label2word_label[lang][tag]
            new_tokens.append(new_token)
        elif label.startswith("I-"):
            continue
        else:
            new_tokens.append(token)
    
    return new_tokens

In [9]:
dev_targets = defaultdict(list)

for lang in pb(dev_texts):        
    texts, labels = dev_texts[lang], dev_labels[lang]
    
    for text, label in zip(texts, labels):
        target = prepare_target(text, label, lang)
        dev_targets[lang].append(target)

100% (11 of 11) |########################| Elapsed Time: 0:00:00 Time:  0:00:00


In [10]:
for targets, labels in zip(dev_targets.values(), dev_labels.values()):
    assert len(targets) == 800
    assert len(labels) == 800

In [11]:
dev_data = []

for lang in dev_texts:
    for x, y in zip(dev_texts[lang], dev_targets[lang]):
        x = " ".join(x)
        y = " ".join(y)
        dev_data.append((x,y))
        
len(dev_data)

8800

In [15]:
dev_data[3200]

('важным традиционным промыслом является производство пальмового масла .',
 'важным традиционным промыслом является производство dvd .')

# Load Model

In [16]:
tokenizer = XLMRobertaTokenizer.from_pretrained(pretrained_model_path, local_files_only=True)

In [17]:
model = XLMRobertaForMaskedLM.from_pretrained(trained_model_path, local_files_only=True)
model.to(device)
model.eval()
print()




In [18]:
def encode_data(data, max_len=max_len):
    X_data, y_data = [], []

    for item in pb(data):
        x, y = item
        x_enc = tokenizer.encode(x, max_length=max_len, padding="max_length", truncation=True)
        y_enc = tokenizer.encode(y, max_length=max_len, padding="max_length", truncation=True)
        X_data.append(x_enc)
        y_data.append(y_enc)
        
    return np.array(X_data), np.array(y_data)
    
X_dev, y_dev = encode_data(dev_data)
X_dev.shape, y_dev.shape

100% (8800 of 8800) |####################| Elapsed Time: 0:00:05 Time:  0:00:05


((8800, 96), (8800, 96))

In [19]:
dev = np.stack((X_dev, y_dev), axis=1)
dev_batches = DataLoader(dev, batch_size=batch_size, shuffle=False)
dev.shape

(8800, 2, 96)

In [20]:
def get_predictions(trained_model=model, batches=dev_batches, tokenizer=tokenizer):
    pred_labels = []

    for item in pb(batches):
        item = item[:, 0, :]
        out = trained_model(item.to(device))
        logits = out.logits
        tokens_encoded = logits.argmax(axis=-1).tolist()
        for enc in tokens_encoded:
            decoded = tokenizer.decode(enc, skip_special_tokens=True)
            pred_labels.append(decoded)
        
    return pred_labels

In [22]:
dev_x = [item[0] for item in dev_data]
dev_y_true = [item[1] for item in dev_data]
dev_y_pred = get_predictions(model, dev_batches)

100% (367 of 367) |######################| Elapsed Time: 0:01:32 Time:  0:01:32


In [23]:
idd = 800*4 + 70 # 800*2 + 80    -- *495
dev_y_pred[idd]

'настоящие лемуры почти исключительно травоядны : они питаются цветами, dvd, листьями, однако в неволе известны примеры питания насекомыми.'

In [24]:
dev_y_true[idd]

'настоящие лемуры почти исключительно травоядны : они питаются цветами , dvd , листьями , однако в неволе известны примеры питания насекомыми .'

In [25]:
dev_x[idd]

'настоящие лемуры почти исключительно травоядны : они питаются цветами , фрукт , листьями , однако в неволе известны примеры питания насекомыми .'

In [26]:
sentence2lang = {}

for key, items in dev_texts.items():
    for item in items:
        sentence2lang[" ".join(item)] = key

In [27]:
dev_texts2labels = defaultdict(dict)

for lang in dev_texts:
    for text, labels in zip(dev_texts[lang], dev_labels[lang]):
        dev_texts2labels[lang][" ".join(text)] = labels

In [28]:
def build_pred_labels(input_sent, 
                      pred_sent, 
                      word_label2label=word_label2label, 
                      sentence2lang=sentence2lang,
                      nltk_tokenizer=WhitespaceTokenizer()):
    
    lang = sentence2lang[input_sent]
    res, matched_spans = [], []

    pred_tokens = word_tokenize(pred_sent)
    input_tokens = input_sent.split()

    i, j = 0, 0  
    res = []
    
    while i < len(pred_tokens) and j < len(input_tokens):
        
        pred = pred_tokens[i]
        inp = input_tokens[j]
        if pred == inp and pred not in word_label2label[lang]:
            res.append("O")
            i += 1
            j += 1
        elif pred in word_label2label[lang]:
            res.append(word_label2label[lang][pred])
            j += 1
            while j < len(input_tokens) and i+1 < len(pred_tokens) and input_tokens[j] != pred_tokens[i+1]:
                res.append(word_label2label[lang][pred])
                j += 1
                
            i += 1
        elif pred != inp:
            break

    if len(res) < len(input_tokens):
        res.extend(["O"] * (len(input_tokens) - len(res)))
    
    return res

In [29]:
def add_iob(labels):
    iob_labels = []
    
    if labels[0] != "O":
        iob_labels.append("B-" + labels[0])
    else:
        iob_labels.append("O")
    
    for i, label in enumerate(labels[1:], 1):

        if label == "O":
            iob_labels.append("O")
            
        elif labels[i-1] == "O" or (label != labels[i-1] and labels[i-1] != "O"):
            iob_labels.append("B-" + label)
            
        elif not labels[i-1].startswith("O") and label != "O":
            iob_labels.append("I-" + label)
        
        
    return iob_labels

In [30]:
broken = 0
dev_labels_pred = defaultdict(list)

for x, y_pred in zip(dev_x, dev_y_pred):
    lang = sentence2lang[x]
    
    true_labels = dev_texts2labels[lang][x]
    try:
        pred_labels = build_pred_labels(x, y_pred)
    except:
        broken += 1
        pred_labels = ["O"] * len(true_labels)
        
    pred_labels = add_iob(pred_labels)
    assert len(true_labels) == len(pred_labels)
    
    dev_labels_pred[lang].append((true_labels, pred_labels))
broken

0

# Calc metrics by span

In [31]:
from collections import defaultdict
from typing import Set
from overrides import overrides

from allennlp.training.metrics.metric import Metric


class SpanF1(Metric):
    def __init__(self, non_entity_labels=['O']) -> None:
        self._num_gold_mentions = 0
        self._num_recalled_mentions = 0
        self._num_predicted_mentions = 0
        self._TP, self._FP, self._GT = defaultdict(int), defaultdict(int), defaultdict(int)
        self.non_entity_labels = set(non_entity_labels)

    @overrides
    def __call__(self, batched_predicted_spans, batched_gold_spans, sentences=None):
        non_entity_labels = self.non_entity_labels
        for predicted_spans, gold_spans in zip(batched_predicted_spans, batched_gold_spans):
            gold_spans_set = set([x for x, y in gold_spans.items() if y not in non_entity_labels])
            pred_spans_set = set([x for x, y in predicted_spans.items() if y not in non_entity_labels])

            self._num_gold_mentions += len(gold_spans_set)
            self._num_recalled_mentions += len(gold_spans_set & pred_spans_set)
            self._num_predicted_mentions += len(pred_spans_set)

            for ky, val in gold_spans.items():
                if val not in non_entity_labels:
                    self._GT[val] += 1

            for ky, val in predicted_spans.items():
                if val in non_entity_labels:
                    continue
                if ky in gold_spans and val == gold_spans[ky]:
                    self._TP[val] += 1
                else:
                    self._FP[val] += 1

    @overrides
    def get_metric(self, reset: bool = False) -> float:
        all_tags: Set[str] = set()
        all_tags.update(self._TP.keys())
        all_tags.update(self._FP.keys())
        all_tags.update(self._GT.keys())
        all_metrics = {}

        for tag in all_tags:
            precision, recall, f1_measure = self.compute_prf_metrics(true_positives=self._TP[tag],
                                                                     false_negatives=self._GT[tag] - self._TP[tag],
                                                                     false_positives=self._FP[tag])
            all_metrics['P@{}'.format(tag)] = precision
            all_metrics['R@{}'.format(tag)] = recall
            all_metrics['F1@{}'.format(tag)] = f1_measure

        # Compute the precision, recall and f1 for all spans jointly.
        precision, recall, f1_measure = self.compute_prf_metrics(true_positives=sum(self._TP.values()),
                                                                 false_positives=sum(self._FP.values()),
                                                                 false_negatives=sum(self._GT.values())-sum(self._TP.values()))
        all_metrics["micro@P"] = precision
        all_metrics["micro@R"] = recall
        all_metrics["micro@F1"] = f1_measure

        if self._num_gold_mentions == 0:
            entity_recall = 0.0
        else:
            entity_recall = self._num_recalled_mentions / float(self._num_gold_mentions)

        if self._num_predicted_mentions == 0:
            entity_precision = 0.0
        else:
            entity_precision = self._num_recalled_mentions / float(self._num_predicted_mentions)

        all_metrics['MD@R'] = entity_recall
        all_metrics['MD@P'] = entity_precision
        all_metrics['MD@F1'] = 2. * ((entity_precision * entity_recall) / (entity_precision + entity_recall + 1e-13))
        all_metrics['ALLTRUE'] = self._num_gold_mentions
        all_metrics['ALLRECALLED'] = self._num_recalled_mentions
        all_metrics['ALLPRED'] = self._num_predicted_mentions
        if reset:
            self.reset()
        return all_metrics

    @staticmethod
    def compute_prf_metrics(true_positives: int, false_positives: int, false_negatives: int):
        precision = float(true_positives) / float(true_positives + false_positives + 1e-13)
        recall = float(true_positives) / float(true_positives + false_negatives + 1e-13)
        f1_measure = 2. * ((precision * recall) / (precision + recall + 1e-13))
        return precision, recall, f1_measure

    @overrides
    def reset(self):
        self._num_gold_mentions = 0
        self._num_recalled_mentions = 0
        self._num_predicted_mentions = 0
        self._TP.clear()
        self._FP.clear()
        self._GT.clear()

In [32]:
def get_spans(labels):
    fin_spans = []
    for item_ in labels:

        item = deepcopy(item_)
        item.insert(0, "O")
        item.append("O")

        new_spans = {}
        for i, label in enumerate(item[1:-1], 1):

            if item[i] == "O":
                new_spans[(i-1, i-1)] = "O"
            else:
                if item[i-1] == 'O':
                    start_i = i
                if item[i+1] == 'O':
                    new_spans[(start_i-1, i-1)] = item[i].split('-')[1]
                    
        fin_spans.append(new_spans)
                
    return fin_spans    

In [33]:
metrics = {}

for lang in ["BN", "DE", "ES", "TR", "FA", "RU", "ZH", "NL", "KO", "EN", "HI"]:
    true_spans = get_spans([true for true, pred in dev_labels_pred[lang]])
    pred_spans = get_spans([pred for true, pred in dev_labels_pred[lang]])
    
    span_f1 = SpanF1()
    span_f1(pred_spans, true_spans)
    cur_metric = span_f1.get_metric()
    metrics[lang] = cur_metric

In [34]:
pd.options.display.float_format = '{:.3f}'.format
df = pd.DataFrame(index=list(metrics["RU"].keys()))

for lang, metric in metrics.items():
    df[lang] = list(metric.values())
  

In [35]:
df

Unnamed: 0,BN,DE,ES,TR,FA,RU,ZH,NL,KO,EN,HI
P@PER,0.797,0.776,0.885,0.827,0.792,0.86,0.5,0.923,0.805,0.884,0.788
R@PER,0.653,0.593,0.687,0.709,0.61,0.672,0.048,0.742,0.606,0.782,0.6
F1@PER,0.718,0.672,0.773,0.763,0.689,0.754,0.087,0.822,0.692,0.83,0.681
P@CW,0.667,0.672,0.748,0.855,0.566,0.743,0.0,0.752,0.762,0.76,0.632
R@CW,0.367,0.441,0.635,0.602,0.554,0.627,0.0,0.693,0.654,0.648,0.429
F1@CW,0.473,0.532,0.687,0.707,0.56,0.68,0.0,0.721,0.704,0.699,0.511
P@GRP,0.802,0.802,0.779,0.903,0.0,0.802,0.455,0.744,0.805,0.864,0.791
R@GRP,0.619,0.519,0.695,0.564,0.0,0.593,0.208,0.827,0.549,0.741,0.623
F1@GRP,0.699,0.63,0.734,0.694,0.0,0.682,0.286,0.784,0.653,0.798,0.697
P@LOC,0.741,0.82,0.922,0.853,0.685,0.829,0.813,0.906,0.523,0.901,0.798
