In [1]:
from collections import defaultdict, Counter
import os, operator
from progressbar import progressbar as pb
from nltk import word_tokenize
from typing import List, Tuple

## Read data from disk

In [2]:
data_path = "../SemEval2022-Task11_Train-Dev/RU-Russian/"
model_path = "../../pretrained/sbert_large_nlu_ru/"
SAVE_PATH = "models/template_free/sbert_large_nlu_ru/"
device = 'cuda'

In [3]:
with open(os.path.join(data_path, "ru_train.conll")) as f:
    train_file = f.read().splitlines()
    
with open(os.path.join(data_path, "ru_dev.conll")) as f:
    dev_file = f.read().splitlines()

In [4]:
def parse_conll(file) -> Tuple[List, List]:
    texts, labels = [], []
    
    for row in file:
        if row.startswith("#"):
            new_texts, new_labels = [], []
            continue

        if row == "":
            texts.append(new_texts)
            labels.append(new_labels)

        else:
            parts = row.split()
            new_texts.append(parts[0])
            new_labels.append(parts[-1])

    return texts, labels

train_texts, train_labels = parse_conll(train_file)
dev_texts, dev_labels = parse_conll(dev_file)

In [5]:
len(train_texts), len(dev_texts)

(15300, 800)

In [6]:
assert len(train_texts) == len(train_labels)
assert len(dev_texts) == len(dev_labels)

for t, l in zip(train_texts, train_labels):
    assert len(t) == len(l)
    
for t, l in zip(dev_texts, dev_labels):
    assert len(t) == len(l)

In [7]:
ind = 33
print(train_texts[ind], train_labels[ind])

['через', 'год', 'дирекция', 'влезает', 'в', 'долги', 'и', 'берёт', 'под', 'определённый', 'процент', 'акций', 'кредит', 'в', 'казкоммерцбанк', '.'] ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-CORP', 'O']


## Find label words distribution

In [8]:
def parse_entity(text, labels):
    res_text, res_labels = [], []
    new_text = []
    
    for t, l in zip(text, labels):
        if l.startswith("B-"):
            new_text = [t]
            new_label = l.split("-")[-1]
            
        elif l.startswith("I-"):
            new_text.append(t)
            
        elif l == 'O' and len(new_text):
            new_text = " ".join(new_text)
            res_text.append(new_text)
            res_labels.append(new_label)
            new_text, new_label = [], []
            
    return res_text, res_labels

In [9]:
freq_dct = defaultdict(list)

for text, label in zip(train_texts, train_labels):
    parsed_text, parsed_labels = parse_entity(text, label)
    for t, l in zip(parsed_text, parsed_labels):
        freq_dct[l].append(t)
        
for k, v in freq_dct.items():
    freq_dct[k] = sorted(dict(Counter(v)).items(), key=operator.itemgetter(1), reverse=True)

In [10]:
freq_dct.keys()

dict_keys(['GRP', 'PER', 'CW', 'PROD', 'CORP', 'LOC'])

In [11]:
for k, v in freq_dct.items():
    print(k)
    print(v[:2]) # per - человек
    # попробовать вместо слов использовать сущности

GRP
[('бюро переписи населения сша', 25), ('колхоз', 25)]
PER
[('а. п. чехова', 6), ('женщин', 5)]
CW
[('государственный геральдический регистр российской федерации', 16), ('сингл', 15)]
PROD
[('dvd', 39), ('пулемёт', 25)]
CORP
[('rotten tomatoes', 70), ('mtv', 29)]
LOC
[('париж', 43), ('германии', 31)]


In [12]:
label2word_label = { # label words dict
    "GRP": "колхоз",
    "PER": "человек",
    "CW": "сингл",
    "PROD": "dvd",
    "CORP": "mtv",
    "LOC": "париж"
}

word_label2label = {v:k for k, v in label2word_label.items()}

In [13]:
def prepare_target(tokens, labels, label2word_label=label2word_label):
    """
        Replace entities with label words
    """
    new_tokens = []
    for token, label in zip(tokens, labels):
        if label.startswith("B-"):
            prefix, tag = label.split("-")
            new_token = label2word_label[tag]
            new_tokens.append(new_token)
        elif label.startswith("I-"):
            continue
        else:
            new_tokens.append(token)
    
    return new_tokens

In [14]:
train_data, dev_data = [], []

for tokens, label_list in zip(train_texts, train_labels):
    target = prepare_target(tokens, label_list)
    x = " ".join(tokens)
    y = " ".join(target)
    train_data.append((x,y))
    
for tokens, label_list in zip(dev_texts, dev_labels):
    target = prepare_target(tokens, label_list)
    x = " ".join(tokens)
    y = " ".join(target)
    dev_data.append((x,y))

In [15]:
dev_data[0]

('важным традиционным промыслом является производство пальмового масла .',
 'важным традиционным промыслом является производство dvd .')

In [16]:
dev_id_to_entity_map = {}

for i, (text, labels) in enumerate(zip(dev_texts, dev_labels)):
    dev_id_to_entity_map[i] = defaultdict(list)
    for token, label in zip(text, labels):
        if label == "O":
            continue
        prefix, tag = label.split("-")
        if prefix == "B":
            dev_id_to_entity_map[i][tag].append([])
            dev_id_to_entity_map[i][tag][-1].append(label)
        elif prefix == "I":
            dev_id_to_entity_map[i][tag][-1].append(label)

In [17]:
ind = 2
print(dev_texts[ind], dev_labels[ind])

['специальный', 'агент', 'секретной', 'службы', 'сша', 'джеззи', 'фланниган', ',', 'ответственная', 'за', 'нарушение', 'безопасности', ',', 'объединяется', 'с', 'кроссом', ',', 'чтобы', 'найти', 'пропавшую', 'девушку', '.'] ['O', 'O', 'B-GRP', 'I-GRP', 'I-GRP', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']


In [18]:
dev_id_to_entity_map[ind]

defaultdict(list, {'GRP': [['B-GRP', 'I-GRP', 'I-GRP']]})

## Prepare for Training

In [19]:
from torch.utils.data import DataLoader
from transformers import BertTokenizer, BertForMaskedLM
import torch
import numpy as np
from sklearn.metrics import classification_report, f1_score, accuracy_score

In [20]:
max_len = 64
batch_size = 32

In [21]:
tokenizer = BertTokenizer.from_pretrained(model_path, local_files_only=True)

In [22]:
def encode_data(data):
    X_data, y_data = [], []

    for item in data:
        x, y = item
        x_enc = tokenizer.encode(x, max_length=max_len, padding="max_length", truncation=True)
        y_enc = tokenizer.encode(y, max_length=max_len, padding="max_length", truncation=True)
        X_data.append(x_enc)
        y_data.append(y_enc)
        
    return np.array(X_data), np.array(y_data)
    
X_train, y_train = encode_data(train_data)
X_dev, y_dev = encode_data(dev_data)
X_train.shape, y_train.shape, X_dev.shape, y_dev.shape

((15300, 64), (15300, 64), (800, 64), (800, 64))

In [23]:
train = np.stack((X_train, y_train), axis=1)
dev = np.stack((X_dev, y_dev), axis=1)
train.shape, dev.shape

((15300, 2, 64), (800, 2, 64))

In [24]:
train_batches = DataLoader(train, batch_size=batch_size, shuffle=True)
dev_batches = DataLoader(dev, batch_size=batch_size, shuffle=True)

## Train

In [25]:
model = BertForMaskedLM.from_pretrained(model_path, local_files_only=True)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

model.to(device)
sum(p.numel() for p in model.parameters())

Some weights of BertForMaskedLM were not initialized from the model checkpoint at ../../pretrained/sbert_large_nlu_ru/ and are newly initialized: ['cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


427030858

In [27]:
def _train(model, train_loader, optimizer, epoch_num):
    train_loss, train_acc, train_f1 = [], [], []
    model.train()

    for batch_num, batch in enumerate(train_loader):
        X_batch, y_batch = batch[:, 0, :], batch[:, 1, :]
        X_batch = X_batch.type(torch.LongTensor).to(device)
        y_batch = y_batch.type(torch.LongTensor).to(device)
        optimizer.zero_grad()

        out = model(input_ids=X_batch, labels=y_batch.contiguous(), return_dict=True)
        loss = out.loss
        y_pred = out.logits
        y_pred = torch.argmax(y_pred, dim=2)

        y_pred_flatten = torch.flatten(y_pred).cpu().numpy()
        y_batch_flatten = torch.flatten(y_batch).cpu().numpy()
        f1 = f1_score(y_batch_flatten, y_pred_flatten, average="micro")
        accuracy = accuracy_score(y_batch_flatten, y_pred_flatten)

        train_loss.append(loss.item())
        train_acc.append(accuracy)
        train_f1.append(f1)
        
        if batch_num % 50 == 0:
            print(f"TRAIN: Epoch: {epoch_num}, Batch: {batch_num + 1} / {len(train_loader)}, "
                          f"Loss: {loss.item():.3f}, Accuracy: {accuracy:.3f}, F1: {f1:.3f}")

        loss.backward()
        optimizer.step()

    return np.mean(train_loss), np.mean(train_acc), np.mean(train_f1)

def _val(model, val_loader, epoch_num):
    val_loss, val_acc, val_f1 = [], [], []
    model.eval()

    for batch_num, batch in enumerate(val_loader):
        X_batch, y_batch = batch[:, 0, :], batch[:, 1, :]
        X_batch = X_batch.type(torch.LongTensor).to(device)
        y_batch = y_batch.type(torch.LongTensor).to(device)

        out = model(input_ids=X_batch, labels=y_batch.contiguous())
        loss = out.loss
        y_pred = out.logits
        y_pred = torch.argmax(y_pred, dim=2)

        y_pred_flatten = torch.flatten(y_pred).cpu().numpy()
        y_batch_flatten = torch.flatten(y_batch).cpu().numpy()
        f1 = f1_score(y_batch_flatten, y_pred_flatten, average="micro")
        accuracy = accuracy_score(y_batch_flatten, y_pred_flatten)

        val_loss.append(loss.item())
        val_acc.append(accuracy)
        val_f1.append(f1)
        
        if batch_num % 50 == 0:
            print(f"VAL: Epoch: {epoch_num}, Batch: {batch_num + 1} / {len(val_loader)}, "
                          f"Loss: {loss.item():.3f}, Accuracy: {accuracy:.3f}, F1: {f1:.3f}")

    return np.mean(val_loss), np.mean(val_acc), np.mean(val_f1)

In [28]:
last_epoch = 0
dev_losses = []
patience = 1

for epoch in range(1, 25 + 1):
    train_loss, train_acc, train_f1 = _train(model, train_batches, optimizer, epoch)
    dev_loss, dev_acc, dev_f1 = _val(model, dev_batches, epoch)

    if len(dev_losses) == 0 or dev_loss < dev_losses[-1]:
        model.save_pretrained(SAVE_PATH)

    elif last_epoch == 0:
        last_epoch = epoch + patience

    print(f"After epoch #{epoch}:")
    print(f"Train loss: {train_loss:.3f}, Train Accuracy: {train_acc:.3f}, Train F1: {train_f1:.3f}")
    print(f"Dev loss: {dev_loss:.3f}, Dev Accuracy: {dev_acc:.3f}, Dev F1: {dev_f1:.3f}\n")

    dev_losses.append(dev_loss)
    if epoch == last_epoch:
        break

TRAIN: Epoch: 1, Batch: 1 / 479, Loss: 12.386, Accuracy: 0.000, F1: 0.000
TRAIN: Epoch: 1, Batch: 51 / 479, Loss: 2.446, Accuracy: 0.679, F1: 0.679
TRAIN: Epoch: 1, Batch: 101 / 479, Loss: 1.882, Accuracy: 0.741, F1: 0.741
TRAIN: Epoch: 1, Batch: 151 / 479, Loss: 1.433, Accuracy: 0.776, F1: 0.776
TRAIN: Epoch: 1, Batch: 201 / 479, Loss: 1.365, Accuracy: 0.783, F1: 0.783
TRAIN: Epoch: 1, Batch: 251 / 479, Loss: 1.095, Accuracy: 0.833, F1: 0.833
TRAIN: Epoch: 1, Batch: 301 / 479, Loss: 1.092, Accuracy: 0.816, F1: 0.816
TRAIN: Epoch: 1, Batch: 351 / 479, Loss: 1.180, Accuracy: 0.815, F1: 0.815
TRAIN: Epoch: 1, Batch: 401 / 479, Loss: 0.663, Accuracy: 0.882, F1: 0.882
TRAIN: Epoch: 1, Batch: 451 / 479, Loss: 1.051, Accuracy: 0.834, F1: 0.834
VAL: Epoch: 1, Batch: 1 / 25, Loss: 0.934, Accuracy: 0.844, F1: 0.844
After epoch #1:
Train loss: 1.468, Train Accuracy: 0.785, Train F1: 0.785
Dev loss: 0.959, Dev Accuracy: 0.836, Dev F1: 0.836

TRAIN: Epoch: 2, Batch: 1 / 479, Loss: 0.911, Accuracy:

In [29]:
phrase = dev_data[ind][0]

tokenized = torch.LongTensor([
        tokenizer.encode(phrase, max_length=max_len, padding="max_length", truncation=True)
    ]).to(device)

In [30]:
phrase

'специальный агент секретной службы сша джеззи фланниган , ответственная за нарушение безопасности , объединяется с кроссом , чтобы найти пропавшую девушку .'

In [31]:
out = model(tokenized).logits
out = torch.argmax(out, dim=2)
tokenizer.decode(out[0], skip_special_tokens=True)

'специальныи сингл секретно за за за за, за за за засяся,,,,м,,вшую..'

# decode predictions

In [36]:
model = BertForMaskedLM.from_pretrained(SAVE_PATH, local_files_only=True)
model.to(device)
model.eval()
print()




In [37]:
def get_pred(text, model=model):
    tokenized = torch.LongTensor([
        tokenizer.encode(text, max_length=max_len, padding="max_length", truncation=True)
    ]).to(device)
    out = model(tokenized).logits
    out = torch.argmax(out, dim=2)
    return tokenizer.decode(out[0], skip_special_tokens=True) 

In [38]:
dev_x = [item[0] for item in dev_data]
dev_y_true = [item[1] for item in dev_data]
dev_y_pred = [get_pred(item) for item in pb(dev_x)]

100% (800 of 800) |######################| Elapsed Time: 0:00:45 Time:  0:00:45


In [148]:
ind = 7
dev_y_true[ind], dev_y_pred[ind]

('выпущен 15 февраля 2019 года лейблом mtv .',
 'выпущен 15 февраля 2019 года леtлом mtv.')

In [75]:
def decode_pred(y_true, y_pred):
    pred_labels = []

    y_pred = y_pred.replace("ё", "е").replace("й", "и")
    y_pred = word_tokenize(y_pred)
    
    y_true = y_true.replace("ё", "е").replace("й", "и")
    y_true = word_tokenize(y_true)
    
    true_labels = [word_label2label.get(token, "O") for token in y_true]
    pred_labels = [word_label2label.get(token, "O") for token in y_pred]
    
    true_len = len(true_labels)
    pred_len = len(pred_labels)
    
    if true_len > pred_len:
        pred_labels.extend(["O"] * (true_len - pred_len))
    elif pred_len > true_len:
        true_labels.extend(["O"] * (pred_len - true_len))
            
    assert len(true_labels) == len(pred_labels)
    return true_labels, pred_labels

In [178]:
from copy import deepcopy
# dev_labels

dev_id_to_entity_map_cp = deepcopy(dev_id_to_entity_map)

final_dev_pred = []

for i, (t, p, labels) in enumerate(zip(dev_y_true, dev_y_pred, dev_labels)):
    t, p  = decode_pred(t, p)
    fin = []

    for label in p:
        if label == "O":
            fin.append("O")
        else:
            if label in dev_id_to_entity_map_cp[i]:
                fin.extend(dev_id_to_entity_map_cp[i][label][0])
                if len(dev_id_to_entity_map_cp[i][label]) > 1:
                    dev_id_to_entity_map_cp[i][label] = dev_id_to_entity_map_cp[i][label][1:]
            else:
                fin.append("B-" + label)
    
    true_len = len(labels)
    pred_len = len(fin)
    
    if true_len > pred_len:
        fin.extend(["O"] * (true_len - pred_len))
    elif pred_len > true_len:
        fin = fin[:true_len]
    assert len(fin) == len(labels)
    final_dev_pred.append(fin)

In [184]:
flat_true = [item for sublist in dev_labels for item in sublist]
flat_pred = [item for sublist in final_dev_pred for item in sublist]
assert len(flat_true) == len(flat_pred)

In [188]:
from sklearn.metrics import classification_report

print(classification_report(flat_true, flat_pred, digits=4))

              precision    recall  f1-score   support

      B-CORP     0.8205    0.4051    0.5424       158
        B-CW     0.4129    0.4940    0.4499       168
       B-GRP     0.5221    0.4702    0.4948       151
       B-LOC     0.4193    0.6109    0.4972       221
       B-PER     0.4211    0.5417    0.4738       192
      B-PROD     0.6667    0.4503    0.5375       151
      I-CORP     0.8929    0.4132    0.5650       121
        I-CW     0.5774    0.5389    0.5575       180
       I-GRP     0.6762    0.5000    0.5749       142
       I-LOC     0.3793    0.3438    0.3607        32
       I-PER     0.6337    0.6598    0.6465       194
      I-PROD     0.6667    0.5000    0.5714        56
           O     0.9379    0.9446    0.9412     11002

    accuracy                         0.8852     12768
   macro avg     0.6174    0.5286    0.5548     12768
weighted avg     0.8890    0.8852    0.8844     12768

