In [1]:
from collections import defaultdict, Counter
import os, operator
from progressbar import progressbar as pb
from nltk import word_tokenize
from typing import List, Tuple, Dict
import json

from torch.utils.data import DataLoader
from transformers import XLMRobertaTokenizer, XLMRobertaForMaskedLM
import torch
import numpy as np
from sklearn.metrics import classification_report, f1_score, accuracy_score
from nltk.tokenize import WhitespaceTokenizer
from nltk import word_tokenize
import re
from copy import deepcopy
import pandas as pd

# Read data from disk

In [2]:
data_path = "../SemEval2022-Task11_Train-Dev/"
model_path = "../../pretrained/xlm-roberta-large/"
SAVE_PATH = "models/template_free/multilang_xlm-r/"
device = 'cuda'

max_len = 96
batch_size = 24
num_epochs = 25
patience = 2

In [3]:
train_files, dev_files = [], []

for folder in os.listdir(data_path):
    files = os.listdir(os.path.join(data_path, folder))
    train_file = files[0] if "train" in files[0] else files[1]
    dev_file = files[0] if "dev" in files[0] else files[1]
    
    train_files.append(os.path.join(data_path, folder, train_file))
    dev_files.append(os.path.join(data_path, folder, dev_file))
    
len(train_files), len(dev_files)

(11, 11)

In [4]:
def parse_conll(files) -> Tuple[Dict, Dict]:
    
    fin_texts, fin_labels = {}, {}
    
    for filename in pb(files):
    
        with open(filename) as f:
            data = f.read().splitlines()

        lang = os.path.basename(filename).split("_")[0].upper()
        texts, labels = [], []

        for row in data:
            if row.startswith("# id "):
                new_texts, new_labels = [], []
                continue

            if row == "":
                texts.append(new_texts)
                labels.append(new_labels)

            else:
                parts = row.split()
                new_texts.append(parts[0])
                new_labels.append(parts[-1])
                
        fin_texts[lang] = texts
        fin_labels[lang] = labels

    return fin_texts, fin_labels

In [5]:
train_texts, train_labels = parse_conll(train_files)
dev_texts, dev_labels = parse_conll(dev_files)

100% (11 of 11) |########################| Elapsed Time: 0:00:03 Time:  0:00:03
100% (11 of 11) |########################| Elapsed Time: 0:00:00 Time:  0:00:00


In [6]:
lens = []

for lang, texts in train_texts.items():
    for t in texts:
        lens.append(len(t))

In [7]:
for texts, labels in zip(train_texts.values(), train_labels.values()):
    assert len(texts) == 15300
    assert len(labels) == 15300

In [8]:
for texts, labels in zip(dev_texts.values(), dev_labels.values()):
    assert len(texts) == 800
    assert len(labels) == 800

# Find label words distribution

In [13]:
def parse_entity(text, labels):
    res = []
    new_text = []
    
    for t, l in zip(text, labels):
        if l.startswith("B-"):
            new_text = [t]
            new_label = l.split("-")[-1]
            
        elif l.startswith("I-"):
            new_text.append(t)
            
        elif l == 'O' and len(new_text):
            new_text = " ".join(new_text)
            res.append((new_text, new_label))
            new_text, new_label = [], []
            
    new_text = " ".join(new_text)
    if new_text:
        res.append((new_text, new_label))
            
    return res

In [17]:
freq_dct_by_lang = dict()

for lang in train_texts.keys():
    texts = train_texts[lang]
    labels = train_labels[lang]
    
    freq_dct = defaultdict(list)
    
    for text, label in zip(texts, labels):
        parsed_entities = parse_entity(text, label)
        for item in parsed_entities:
            t, l = item
            freq_dct[l].append(t)
            
    for k, v in freq_dct.items():
        freq_dct[k] = sorted(dict(Counter(v)).items(), key=operator.itemgetter(1), reverse=True)
        
    freq_dct_by_lang[lang] = freq_dct

In [18]:
label2word_label = defaultdict(dict)

for lang, val in freq_dct_by_lang.items():
    for label, label_words in val.items():
        for word, count in label_words:
            if " " not in word:
                label2word_label[lang][label] = word
                break
            
word_label2label = {}

for lang, val in label2word_label.items():
    word_label2label[lang] = {v:k for k, v in val.items()}

In [19]:
label2word_label["EN"]

{'PER': 'jesus',
 'GRP': 'nato',
 'CW': 'single',
 'LOC': 'france',
 'CORP': 'bbc',
 'PROD': 'stucco'}

In [20]:
word_label2label["RU"]

{'колхоз': 'GRP',
 'женщин': 'PER',
 'сингл': 'CW',
 'dvd': 'PROD',
 'mtv': 'CORP',
 'париж': 'LOC'}

# Prepare Target

In [21]:
def prepare_target(tokens, labels, lang, label2word_label=label2word_label):
    """
        Replace entities with label words
    """
    new_tokens = []
    for token, label in zip(tokens, labels):
        if label.startswith("B-"):
            prefix, tag = label.split("-")
            new_token = label2word_label[lang][tag]
            new_tokens.append(new_token)
        elif label.startswith("I-"):
            continue
        else:
            new_tokens.append(token)
    
    return new_tokens

In [23]:
train_targets, dev_targets = defaultdict(list), defaultdict(list)

for lang in pb(train_texts):
    texts, labels = train_texts[lang], train_labels[lang]
    
    for text, label in zip(texts, labels):
        target = prepare_target(text, label, lang)
        train_targets[lang].append(target)
        
    texts, labels = dev_texts[lang], dev_labels[lang]
    
    for text, label in zip(texts, labels):
        target = prepare_target(text, label, lang)
        dev_targets[lang].append(target)

100% (11 of 11) |########################| Elapsed Time: 0:00:01 Time:  0:00:01


In [24]:
for targets, labels in zip(train_targets.values(), train_labels.values()):
    assert len(targets) == 15300
    assert len(labels) == 15300

In [25]:
for targets, labels in zip(dev_targets.values(), dev_labels.values()):
    assert len(targets) == 800
    assert len(labels) == 800

In [27]:
train_data, dev_data = [], []

for lang in train_texts:
    for x, y in zip(train_texts[lang], train_targets[lang]):
        x = " ".join(x)
        y = " ".join(y)
        train_data.append((x,y))
    for x, y in zip(dev_texts[lang], dev_targets[lang]):
        x = " ".join(x)
        y = " ".join(y)
        dev_data.append((x,y))
        
len(train_data), len(dev_data)

(168300, 8800)

In [28]:
train_data[0]

('퀸 메리 런던 대학교 주도하는 18개의 영국대학 연합에 의해 건설되었고 , 2009년 12월에 유럽 남방 천문대 인도되었다 .',
 '퀸 메리 초등학교 주도하는 18개의 영국대학 연합에 의해 건설되었고 , 2009년 12월에 초등학교 인도되었다 .')

In [20]:
dev_id_to_entity_map = defaultdict(dict)

for lang in train_texts:
    for i, (text, labels) in enumerate(zip(dev_texts[lang], dev_labels[lang])):
        dev_id_to_entity_map[lang][i] = defaultdict(list)
        for token, label in zip(text, labels):
            if label == "O":
                continue
            prefix, tag = label.split("-")
            if prefix == "B":
                dev_id_to_entity_map[lang][i][tag].append([])
                dev_id_to_entity_map[lang][i][tag][-1].append(label)
            elif prefix == "I":
                dev_id_to_entity_map[lang][i][tag][-1].append(label)

In [21]:
dev_id_to_entity_map["RU"][0]

defaultdict(list, {'PROD': [['B-PROD', 'I-PROD']]})

# Prepare for Training

In [22]:
tokenizer = XLMRobertaTokenizer.from_pretrained(model_path, local_files_only=True)

In [23]:
def encode_data(data, max_len=max_len):
    X_data, y_data = [], []

    for item in pb(data):
        x, y = item
        x_enc = tokenizer.encode(x, max_length=max_len, padding="max_length", truncation=True)
        y_enc = tokenizer.encode(y, max_length=max_len, padding="max_length", truncation=True)
        X_data.append(x_enc)
        y_data.append(y_enc)
        
    return np.array(X_data), np.array(y_data)
    
X_train, y_train = encode_data(train_data)
X_dev, y_dev = encode_data(dev_data)
X_train.shape, y_train.shape, X_dev.shape, y_dev.shape

100% (168300 of 168300) |################| Elapsed Time: 0:01:10 Time:  0:01:10
100% (8800 of 8800) |####################| Elapsed Time: 0:00:03 Time:  0:00:03


((168300, 96), (168300, 96), (8800, 96), (8800, 96))

In [24]:
train = np.stack((X_train, y_train), axis=1)
dev = np.stack((X_dev, y_dev), axis=1)
train.shape, dev.shape

((168300, 2, 96), (8800, 2, 96))

In [25]:
train_batches = DataLoader(train, batch_size=batch_size, shuffle=True)
dev_batches = DataLoader(dev, batch_size=batch_size, shuffle=False)

# Training...

In [26]:
model = XLMRobertaForMaskedLM.from_pretrained(model_path, local_files_only=True)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

model.to(device)
sum(p.numel() for p in model.parameters())

560142482

In [27]:
def _train(model, train_loader, optimizer, epoch_num):
    train_loss, train_acc, train_f1 = [], [], []
    model.train()

    for batch_num, batch in enumerate(train_loader):
        X_batch, y_batch = batch[:, 0, :], batch[:, 1, :]
        X_batch = X_batch.type(torch.LongTensor).to(device)
        y_batch = y_batch.type(torch.LongTensor).to(device)
        optimizer.zero_grad()

        out = model(input_ids=X_batch, labels=y_batch.contiguous(), return_dict=True)
        loss = out.loss
        y_pred = out.logits
        y_pred = torch.argmax(y_pred, dim=2)

        y_pred_flatten = torch.flatten(y_pred).cpu().numpy()
        y_batch_flatten = torch.flatten(y_batch).cpu().numpy()
        f1 = f1_score(y_batch_flatten, y_pred_flatten, average="macro")
        accuracy = accuracy_score(y_batch_flatten, y_pred_flatten)

        train_loss.append(loss.item())
        train_acc.append(accuracy)
        train_f1.append(f1)
        
        if batch_num % 100 == 0:
            print(f"TRAIN: Epoch: {epoch_num}, Batch: {batch_num + 1} / {len(train_loader)}, "
                          f"Loss: {loss.item():.3f}, Accuracy: {accuracy:.3f}, F1: {f1:.3f}")

        loss.backward()
        optimizer.step()

    return np.mean(train_loss), np.mean(train_acc), np.mean(train_f1)

def _val(model, val_loader, epoch_num):
    val_loss, val_acc, val_f1 = [], [], []
    model.eval()

    for batch_num, batch in enumerate(val_loader):
        X_batch, y_batch = batch[:, 0, :], batch[:, 1, :]
        X_batch = X_batch.type(torch.LongTensor).to(device)
        y_batch = y_batch.type(torch.LongTensor).to(device)

        out = model(input_ids=X_batch, labels=y_batch.contiguous())
        loss = out.loss
        y_pred = out.logits
        y_pred = torch.argmax(y_pred, dim=2)

        y_pred_flatten = torch.flatten(y_pred).cpu().numpy()
        y_batch_flatten = torch.flatten(y_batch).cpu().numpy()
        f1 = f1_score(y_batch_flatten, y_pred_flatten, average="macro")
        accuracy = accuracy_score(y_batch_flatten, y_pred_flatten)

        val_loss.append(loss.item())
        val_acc.append(accuracy)
        val_f1.append(f1)
        
        if batch_num % 100 == 0:
            print(f"VAL: Epoch: {epoch_num}, Batch: {batch_num + 1} / {len(val_loader)}, "
                          f"Loss: {loss.item():.3f}, Accuracy: {accuracy:.3f}, F1: {f1:.3f}")

    return np.mean(val_loss), np.mean(val_acc), np.mean(val_f1)

In [28]:
last_epoch = 0
dev_losses = []
dev_f1s = []

for epoch in range(1, num_epochs + 1):
    train_loss, train_acc, train_f1 = _train(model, train_batches, optimizer, epoch)
    dev_loss, dev_acc, dev_f1 = _val(model, dev_batches, epoch)

    if len(dev_f1s) == 0 or dev_f1 >= dev_f1s[-1]:
        model.save_pretrained(SAVE_PATH)

    elif last_epoch == 0:
        last_epoch = epoch + patience

    print(f"After epoch #{epoch}:")
    print(f"Train loss: {train_loss:.3f}, Train Accuracy: {train_acc:.3f}, Train F1: {train_f1:.3f}")
    print(f"Dev loss: {dev_loss:.3f}, Dev Accuracy: {dev_acc:.3f}, Dev F1: {dev_f1:.3f}\n")

    dev_losses.append(dev_loss)
    dev_f1s.append(dev_f1)
    if epoch == last_epoch:
        break

TRAIN: Epoch: 1, Batch: 1 / 7013, Loss: 18.653, Accuracy: 0.138, F1: 0.206
TRAIN: Epoch: 1, Batch: 101 / 7013, Loss: 1.393, Accuracy: 0.838, F1: 0.356
TRAIN: Epoch: 1, Batch: 201 / 7013, Loss: 1.396, Accuracy: 0.818, F1: 0.318
TRAIN: Epoch: 1, Batch: 301 / 7013, Loss: 1.105, Accuracy: 0.844, F1: 0.392
TRAIN: Epoch: 1, Batch: 401 / 7013, Loss: 1.490, Accuracy: 0.785, F1: 0.259
TRAIN: Epoch: 1, Batch: 501 / 7013, Loss: 0.992, Accuracy: 0.845, F1: 0.354
TRAIN: Epoch: 1, Batch: 601 / 7013, Loss: 1.085, Accuracy: 0.822, F1: 0.318
TRAIN: Epoch: 1, Batch: 701 / 7013, Loss: 1.036, Accuracy: 0.836, F1: 0.338
TRAIN: Epoch: 1, Batch: 801 / 7013, Loss: 0.921, Accuracy: 0.829, F1: 0.389
TRAIN: Epoch: 1, Batch: 901 / 7013, Loss: 1.079, Accuracy: 0.802, F1: 0.354
TRAIN: Epoch: 1, Batch: 1001 / 7013, Loss: 0.983, Accuracy: 0.810, F1: 0.323
TRAIN: Epoch: 1, Batch: 1101 / 7013, Loss: 0.747, Accuracy: 0.864, F1: 0.434
TRAIN: Epoch: 1, Batch: 1201 / 7013, Loss: 0.558, Accuracy: 0.886, F1: 0.475
TRAIN: Epo

# Decode Predictions

In [29]:
torch.cuda.empty_cache()
import gc
del model
gc.collect()

866

In [30]:
model = XLMRobertaForMaskedLM.from_pretrained(SAVE_PATH, local_files_only=True)
model.to(device)
model.eval()
print()




In [31]:
def get_predictions(trained_model=model, batches=dev_batches, tokenizer=tokenizer):
    pred_labels = []

    for item in pb(batches):
        item = item[:, 0, :]
        out = trained_model(item.to(device))
        logits = out.logits
        tokens_encoded = logits.argmax(axis=-1).tolist()
        for enc in tokens_encoded:
            decoded = tokenizer.decode(enc, skip_special_tokens=True)
            pred_labels.append(decoded)
        
    return pred_labels

In [32]:
dev_x = [item[0] for item in dev_data]
dev_y_true = [item[1] for item in dev_data]
dev_y_pred = get_predictions(model, dev_batches)

100% (367 of 367) |######################| Elapsed Time: 0:01:31 Time:  0:01:31


In [39]:
idd = 800*4 + 70 # 800*2 + 80    -- *495
dev_y_pred[idd]

'настоящие лемуры почти исключительно травоядны : они питаются цветами, dvd, листьями, однако в неволе известны примеры питания насекомыми.'

In [40]:
dev_y_true[idd]

'настоящие лемуры почти исключительно травоядны : они питаются цветами , dvd , листьями , однако в неволе известны примеры питания насекомыми .'

In [41]:
dev_x[idd]

'настоящие лемуры почти исключительно травоядны : они питаются цветами , фрукт , листьями , однако в неволе известны примеры питания насекомыми .'

In [42]:
sentence2lang = {}

for key, items in dev_texts.items():
    for item in items:
        sentence2lang[" ".join(item)] = key

In [43]:
dev_texts2labels = defaultdict(dict)

for lang in dev_texts:
    for text, labels in zip(dev_texts[lang], dev_labels[lang]):
        dev_texts2labels[lang][" ".join(text)] = labels

In [44]:
word_label2label["EN"]

{'jesus': 'PER',
 'nato': 'GRP',
 'single': 'CW',
 'france': 'LOC',
 'bbc': 'CORP',
 'stucco': 'PROD'}

In [45]:
def build_pred_labels(input_sent, 
                      pred_sent, 
                      word_label2label=word_label2label, 
                      sentence2lang=sentence2lang,
                      nltk_tokenizer=WhitespaceTokenizer()
                     ):
    
    lang = sentence2lang[input_sent]
    res, matched_spans = [], []

    pred_tokens = word_tokenize(pred_sent)
    input_tokens = input_sent.split()
#     print(pred_tokens)
#     print(input_tokens)
#     print(len(pred_tokens), len(input_tokens))

    i, j = 0, 0  
    res = []
    
    while i < len(pred_tokens) and j < len(input_tokens):
        
        pred = pred_tokens[i]
        inp = input_tokens[j]
#         print(pred, i)
#         print(inp, j)
#         print(res)
#         print()
        if pred == inp and pred not in word_label2label[lang]:
            res.append("O")
            i += 1
            j += 1
        elif pred in word_label2label[lang]:
            res.append(word_label2label[lang][pred])
            j += 1
            while j < len(input_tokens) and i+1 < len(pred_tokens) and input_tokens[j] != pred_tokens[i+1]:
#                 print(j, i)
                res.append(word_label2label[lang][pred])
                j += 1
                
            i += 1
        elif pred != inp:
            break
#         input()

    if len(res) < len(input_tokens):
        res.extend(["O"] * (len(input_tokens) - len(res)))
    
    return res

In [46]:
def add_iob(labels):
    iob_labels = []
    
    if labels[0] != "O":
        iob_labels.append("B-" + labels[0])
    else:
        iob_labels.append("O")
    
    for i, label in enumerate(labels[1:], 1):

        if label == "O":
            iob_labels.append("O")
            
        elif labels[i-1] == "O" or (label != labels[i-1] and labels[i-1] != "O"):
            iob_labels.append("B-" + label)
            
        elif not labels[i-1].startswith("O") and label != "O":
            iob_labels.append("I-" + label)
        
        
    return iob_labels

In [47]:
x = "파일 : nyk maritime museum01s3200 . jpg | 일본우선 요코하마지점 ( 나카구 )"
y = "파일 : nyk maritime museum01s3200. jpg | 일본우선 요코하마지점 ( 대한민국"

build_pred_labels(x, y)

['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'LOC', 'O']

In [48]:
dev_texts2labels['KO'][x]

['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-CORP', 'O', 'O', 'O', 'O']

In [49]:
predd = add_iob(build_pred_labels(dev_x[idd], dev_y_pred[idd]))
truee = dev_texts2labels['RU'][dev_x[idd]]
predd == truee

True

In [50]:
print(predd)
print(truee)

['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-PROD', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-PROD', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']


In [51]:
broken = 0
dev_labels_pred = defaultdict(list)

for x, y_pred in zip(dev_x, dev_y_pred):
    lang = sentence2lang[x]
    
    true_labels = dev_texts2labels[lang][x]
    try:
        pred_labels = build_pred_labels(x, y_pred)
    except:
        broken += 1
        pred_labels = ["O"] * len(true_labels)
        
    pred_labels = add_iob(pred_labels)
    assert len(true_labels) == len(pred_labels)
    
    dev_labels_pred[lang].append((true_labels, pred_labels))
broken

0

# Calc metrics by span

In [52]:
from collections import defaultdict
from typing import Set
from overrides import overrides

from allennlp.training.metrics.metric import Metric


class SpanF1(Metric):
    def __init__(self, non_entity_labels=['O']) -> None:
        self._num_gold_mentions = 0
        self._num_recalled_mentions = 0
        self._num_predicted_mentions = 0
        self._TP, self._FP, self._GT = defaultdict(int), defaultdict(int), defaultdict(int)
        self.non_entity_labels = set(non_entity_labels)

    @overrides
    def __call__(self, batched_predicted_spans, batched_gold_spans, sentences=None):
        non_entity_labels = self.non_entity_labels
        for predicted_spans, gold_spans in zip(batched_predicted_spans, batched_gold_spans):
            gold_spans_set = set([x for x, y in gold_spans.items() if y not in non_entity_labels])
            pred_spans_set = set([x for x, y in predicted_spans.items() if y not in non_entity_labels])

            self._num_gold_mentions += len(gold_spans_set)
            self._num_recalled_mentions += len(gold_spans_set & pred_spans_set)
            self._num_predicted_mentions += len(pred_spans_set)

            for ky, val in gold_spans.items():
                if val not in non_entity_labels:
                    self._GT[val] += 1

            for ky, val in predicted_spans.items():
                if val in non_entity_labels:
                    continue
                if ky in gold_spans and val == gold_spans[ky]:
                    self._TP[val] += 1
                else:
                    self._FP[val] += 1

    @overrides
    def get_metric(self, reset: bool = False) -> float:
        all_tags: Set[str] = set()
        all_tags.update(self._TP.keys())
        all_tags.update(self._FP.keys())
        all_tags.update(self._GT.keys())
        all_metrics = {}

        for tag in all_tags:
            precision, recall, f1_measure = self.compute_prf_metrics(true_positives=self._TP[tag],
                                                                     false_negatives=self._GT[tag] - self._TP[tag],
                                                                     false_positives=self._FP[tag])
            all_metrics['P@{}'.format(tag)] = precision
            all_metrics['R@{}'.format(tag)] = recall
            all_metrics['F1@{}'.format(tag)] = f1_measure

        # Compute the precision, recall and f1 for all spans jointly.
        precision, recall, f1_measure = self.compute_prf_metrics(true_positives=sum(self._TP.values()),
                                                                 false_positives=sum(self._FP.values()),
                                                                 false_negatives=sum(self._GT.values())-sum(self._TP.values()))
        all_metrics["micro@P"] = precision
        all_metrics["micro@R"] = recall
        all_metrics["micro@F1"] = f1_measure

        if self._num_gold_mentions == 0:
            entity_recall = 0.0
        else:
            entity_recall = self._num_recalled_mentions / float(self._num_gold_mentions)

        if self._num_predicted_mentions == 0:
            entity_precision = 0.0
        else:
            entity_precision = self._num_recalled_mentions / float(self._num_predicted_mentions)

        all_metrics['MD@R'] = entity_recall
        all_metrics['MD@P'] = entity_precision
        all_metrics['MD@F1'] = 2. * ((entity_precision * entity_recall) / (entity_precision + entity_recall + 1e-13))
        all_metrics['ALLTRUE'] = self._num_gold_mentions
        all_metrics['ALLRECALLED'] = self._num_recalled_mentions
        all_metrics['ALLPRED'] = self._num_predicted_mentions
        if reset:
            self.reset()
        return all_metrics

    @staticmethod
    def compute_prf_metrics(true_positives: int, false_positives: int, false_negatives: int):
        precision = float(true_positives) / float(true_positives + false_positives + 1e-13)
        recall = float(true_positives) / float(true_positives + false_negatives + 1e-13)
        f1_measure = 2. * ((precision * recall) / (precision + recall + 1e-13))
        return precision, recall, f1_measure

    @overrides
    def reset(self):
        self._num_gold_mentions = 0
        self._num_recalled_mentions = 0
        self._num_predicted_mentions = 0
        self._TP.clear()
        self._FP.clear()
        self._GT.clear()

In [53]:
def get_spans(labels):
    fin_spans = []
    for item_ in labels:

        item = deepcopy(item_)
        item.insert(0, "O")
        item.append("O")

        new_spans = {}
        for i, label in enumerate(item[1:-1], 1):

            if item[i] == "O":
                new_spans[(i-1, i-1)] = "O"
            else:
                if item[i-1] == 'O':
                    start_i = i
                if item[i+1] == 'O':
                    new_spans[(start_i-1, i-1)] = item[i].split('-')[1]
                    
        fin_spans.append(new_spans)
                
    return fin_spans    

In [54]:
metrics = {}

for lang in ["BN", "DE", "ES", "TR", "FA", "RU", "ZH", "NL", "KO", "EN", "HI"]:
    true_spans = get_spans([true for true, pred in dev_labels_pred[lang]])
    pred_spans = get_spans([pred for true, pred in dev_labels_pred[lang]])
    
    span_f1 = SpanF1()
    span_f1(pred_spans, true_spans)
    cur_metric = span_f1.get_metric()
    metrics[lang] = cur_metric

In [55]:
pd.options.display.float_format = '{:.3f}'.format
df = pd.DataFrame(index=list(metrics["RU"].keys()))

for lang, metric in metrics.items():
    df[lang] = list(metric.values())
  

In [56]:
df

Unnamed: 0,BN,DE,ES,TR,FA,RU,ZH,NL,KO,EN,HI
P@CW,0.797,0.776,0.885,0.827,0.787,0.743,0.587,0.752,0.805,0.884,0.632
R@CW,0.653,0.593,0.687,0.709,0.61,0.627,0.344,0.693,0.606,0.782,0.429
F1@CW,0.718,0.672,0.773,0.763,0.687,0.68,0.434,0.721,0.692,0.83,0.511
P@PER,0.802,0.802,0.779,0.903,0.769,0.86,0.647,0.923,0.805,0.864,0.788
R@PER,0.619,0.519,0.695,0.564,0.613,0.672,0.087,0.742,0.549,0.741,0.6
F1@PER,0.699,0.63,0.734,0.694,0.683,0.754,0.154,0.822,0.653,0.798,0.681
P@GRP,0.738,0.859,0.831,0.719,0.758,0.802,0.455,0.744,0.797,0.766,0.791
R@GRP,0.474,0.508,0.641,0.59,0.583,0.593,0.208,0.827,0.624,0.646,0.623
F1@GRP,0.577,0.638,0.723,0.648,0.659,0.682,0.286,0.784,0.7,0.701,0.697
P@PROD,0.741,0.82,0.922,0.853,0.686,0.877,0.704,0.868,0.523,0.901,0.711
