In [1]:
import os
import sys
import random
import itertools
import pickle
import json

import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
from tqdm import tqdm
from pandarallel import pandarallel

import pymorphy2
import nltk
from nltk.tokenize import word_tokenize, wordpunct_tokenize, sent_tokenize

from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, f1_score

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

from keras.preprocessing.sequence import pad_sequences

from transformers import T5Model, T5Tokenizer, T5ForConditionalGeneration

from tqdm.auto import trange
import gc

sys.path.append("..")
from src import *

def init_random_seed(value=0):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU and every CUDA device), and configures cuDNN for repeatable runs.

    Args:
        value: Integer seed shared by all generators.
    """
    random.seed(value)
    np.random.seed(value)
    torch.manual_seed(value)
    # Seed all visible GPUs, not only the currently selected device.
    torch.cuda.manual_seed_all(value)
    torch.backends.cudnn.deterministic = True
    # Autotuning can select different kernels from run to run; switch it off.
    torch.backends.cudnn.benchmark = False


init_random_seed()
    
# Let pandas display up to 255 characters per column before truncating.
pd.set_option('display.max_colwidth', 255)
# Register tqdm's progress-bar integration with pandas.
tqdm.pandas()
# Start pandarallel with 8 workers and a progress bar; the in-memory file
# system transfer is switched off.
pandarallel.initialize(progress_bar=True, nb_workers=8, use_memory_fs=False)

%load_ext autoreload
%autoreload 2

INFO: Pandarallel will run on 8 workers.
INFO: Pandarallel will use standard multiprocessing data transfer (pipe) to transfer data between the main process and workers.


In [3]:
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open("../data/fill_tasks_t5.pickle", "rb") as f:
    fill_tasks = pickle.load(f)

# Keep only the tasks whose first element (the model input) is shorter than 350.
new_tasks = [task for task in fill_tasks if len(task[0]) < 350]
# Fraction of tasks that survive the length filter.
print(len(new_tasks) / len(fill_tasks))
fill_tasks = new_tasks

# Name of the pretrained checkpoint that is fine-tuned in this cell.
MODEL_NAME = "sberbank-ai/ruT5-base"

tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)
model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME)
# Moves the model to the GPU; a CUDA device is required from here on.
model.cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

batch_size = 16  # how many examples the model sees in one step
report_steps = 200  # print the result once every this many steps
epochs = 1  # how many times the data is shown to the model

# Fine-tune the model on `fill_tasks`: each task is an (input text, target
# text) pair. A checkpoint is written if training stops early, and the final
# weights are saved once all epochs have finished.
model.train()
losses = []
# Bind the step counter up front so the emergency checkpoint below can always
# be named, even when the failure happens before the first step.
i = 0
try:
    for epoch in range(epochs):
        print('EPOCH', epoch)
        random.shuffle(fill_tasks)
        # Only full batches are used; a trailing partial batch is skipped.
        for i in trange(0, int(len(fill_tasks) / batch_size)):
            batch = fill_tasks[i * batch_size: (i + 1) * batch_size]
            # encode the question and the answer
            x = tokenizer([p[0] for p in batch], return_tensors='pt', padding=True).to(model.device)
            y = tokenizer([p[1] for p in batch], return_tensors='pt', padding=True).to(model.device)
            # -100 is the special label value that excludes a token from the
            # loss; apply it to padding, using the tokenizer's own pad id.
            y.input_ids[y.input_ids == tokenizer.pad_token_id] = -100
            # compute the loss
            loss = model(
                input_ids=x.input_ids,
                attention_mask=x.attention_mask,
                labels=y.input_ids,
                decoder_attention_mask=y.attention_mask,
                return_dict=True
            ).loss
            # take one gradient-descent step
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            # print the moving average of the loss
            losses.append(loss.item())
            if i % report_steps == 0:
                print('step', i, 'loss', np.mean(losses[-report_steps:]))

            # Release per-step tensors and cached GPU memory.
            del x, y
            gc.collect()
            torch.cuda.empty_cache()
except (Exception, KeyboardInterrupt):
    # KeyboardInterrupt does not derive from Exception, so it is listed
    # explicitly: a manually interrupted run also keeps its progress.
    print("ERROR")
    model.save_pretrained(f"../models/t5/t5_trained_{i}.model")
    raise

model.save_pretrained("../models/t5/t5_trained_full.model")

0.9776699305845254
EPOCH 0


  0%|          | 0/23582 [00:00<?, ?it/s]

step 0 loss 14.187641143798828
step 200 loss 3.074126591384411
step 400 loss 1.2904317392408848
step 600 loss 0.9118332402408122
step 800 loss 0.7503897586092353
step 1000 loss 0.644437567256391
step 1200 loss 0.6028430039249361
step 1400 loss 0.5449692789837718
step 1600 loss 0.46744242902845146
step 1800 loss 0.45731070790439843
step 2000 loss 0.41072798665612936
step 2200 loss 0.38936339505016804
step 2400 loss 0.3815291952341795
step 2600 loss 0.35272137619554994
step 2800 loss 0.36231855985708533
step 3000 loss 0.33569877207279203
step 3200 loss 0.33495756428223106
step 3400 loss 0.335749902902171
step 3600 loss 0.3357129411958158
step 3800 loss 0.3064714600984007
step 4000 loss 0.28765694751404225
step 4200 loss 0.2958027189038694
step 4400 loss 0.2664230223465711
step 4600 loss 0.26277364682406185
step 4800 loss 0.26945081979036334
step 5000 loss 0.2512307298090309
step 5200 loss 0.2564898593723774
step 5400 loss 0.23393741123378275
step 5600 loss 0.24112421906553208
step 5800 l

## Scoring 

In [2]:
lenta_test = pd.read_csv("../data/lenta_test_t5.csv")

# Split the space-separated text and label columns into per-row token lists.
test_texts = [row_text.split(" ") for row_text in lenta_test["text_upd"].to_list()]
test_labels = [row_labels.split(" ") for row_labels in lenta_test["labels_upd"].to_list()]

def get_sentences(text, labels):
    """Split a tokenised text into sentences and slice its labels to match.

    Labels are aligned by character offset rather than by a running token
    count. A running count drifts whenever a sentence boundary is not exactly
    one space wide (for example an empty token produced by a double space),
    which would shift the labels of every following sentence.

    Args:
        text: List of tokens, as produced by splitting the text on " ".
        labels: List of per-token labels, parallel to ``text``.

    Returns:
        A tuple ``(sentences, sent_labels)``: the sentence strings returned by
        ``sent_tokenize`` and, for each sentence, the space-joined labels of
        its tokens.
    """
    joined = " ".join(text)
    sentences = sent_tokenize(joined, language="russian")
    sent_labels = []

    search_from = 0    # offset in `joined` from which the next sentence is searched
    counted_upto = 0   # offset up to which spaces have already been counted
    spaces_before = 0  # number of " " in joined[:counted_upto]
    next_token = 0     # running token count, used only as a fallback
    for sent in sentences:
        n_tokens = len(sent.split(" "))
        start = joined.find(sent, search_from)
        if start == -1:
            # Sentence not found verbatim: fall back to sequential counting.
            first = next_token
        else:
            # Tokens are joined by single spaces and contain none themselves,
            # so the index of the token starting at `start` equals the number
            # of spaces in front of it.
            spaces_before += joined.count(" ", counted_upto, start)
            counted_upto = start
            search_from = start + len(sent)
            first = spaces_before
        sent_labels.append(" ".join(labels[first:first + n_tokens]))
        next_token = first + n_tokens
    return sentences, sent_labels

# Accumulate the sentences of every test text together with their labels.
sentences, sentences_labels = [], []
for doc_tokens, doc_labels in tqdm(zip(test_texts, test_labels)):
    doc_sentences, doc_sentence_labels = get_sentences(doc_tokens, doc_labels)
    sentences += doc_sentences
    sentences_labels += doc_sentence_labels

100853it [00:10, 9360.47it/s]


In [3]:
# Indices of the sentences that have at most 10 space-separated tokens.
kept_indices = [
    idx for idx, sentence in enumerate(sentences)
    if len(sentence.split(" ")) <= 10  # 50
]
new_sentences = [sentences[idx] for idx in kept_indices]
new_sentences_labels = [sentences_labels[idx] for idx in kept_indices]

# Fraction of sentences that pass the length filter.
print(len(new_sentences) / len(sentences))
sentences = new_sentences
sentences_labels = new_sentences_labels

0.16689332146699995


In [39]:
# NOTE(review): AbbrDetection is not defined in this notebook; presumably it is
# provided by `from src import *` — confirm.
abbr_detection = AbbrDetection()

# Russian morphological analyzer restricted to the dictionary analyzer unit.
morph = pymorphy2.MorphAnalyzer(lang="ru", 
                                units=[pymorphy2.units.DictionaryAnalyzer()])
def normalize(word):
    """Return the normal form of ``word``, or its lower-cased text as a fallback.

    Args:
        word: The token to normalise; it is converted with ``str`` first.

    Returns:
        ``normal_form`` of the first parse returned by ``morph.parse``; if
        there are no parses, the lower-cased string form of ``word``.
    """
    text = str(word)
    parse_list = morph.parse(text)
    if parse_list:
        return parse_list[0].normal_form
    # Lower-case the string form, not the raw argument, so that a non-string
    # input does not raise AttributeError here.
    return text.lower()
    
def create_fill_task(sent, i):
    """Build the "fill" prompt for the token at position ``i`` of ``sent``.

    The prompt names the token and then shows the sentence with that token
    replaced by the ``<extra_id_1>`` placeholder. ``sent`` is not modified.
    """
    target = sent[i]
    masked = sent.copy()
    masked[i] = "<extra_id_1>"
    masked_text = ' '.join(masked)
    return f"fill {target} | {masked_text}"

# tasks: (sentence index, word index, prompt) for every detected abbreviation.
# preds: one "_" placeholder per token of every sentence, to be overwritten
#        with model output at the task positions.
tasks = []
preds = []

for sent_i, sentence in enumerate(tqdm(sentences)):
    sent = sentence.split(" ")

    for word_i, word in enumerate(sent):
        # The token's label is not needed to build a task, so it is not looked
        # up here; this also avoids an IndexError on a sentence that has fewer
        # labels than tokens.
        if abbr_detection.word_is_abbr(normalize(word)):
            tasks.append((sent_i, word_i, create_fill_task(sent, word_i)))

    preds.append(["_"] * len(sent))

100%|██████████| 170461/170461 [02:04<00:00, 1368.82it/s]


In [4]:
# Cache writes for `preds` and `tasks`, left commented out; the cell reads the
# cached files below instead.
# with open("../data/preds_test_t5.pickle", "wb") as f:
#     pickle.dump(preds, f)
    
# with open("../data/tasks_test_t5.pickle", "wb") as f:
#     pickle.dump(tasks, f)  

# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open("../data/preds_test_t5.pickle", "rb") as f:
    preds = pickle.load(f)
    
with open("../data/tasks_test_t5.pickle", "rb") as f:
    tasks = pickle.load(f)  

# The tokenizer comes from the base checkpoint name; the weights come from the
# locally saved fine-tuned model.
MODEL_NAME = "sberbank-ai/ruT5-base"
tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)
model = T5ForConditionalGeneration.from_pretrained(f"../models/t5/t5_trained_full.model")
# Moves the model to the GPU; a CUDA device is required.
model.cuda()
# NOTE(review): get_model_params_num is not defined in this notebook;
# presumably it comes from `from src import *` — confirm.
get_model_params_num(model)

222903552

In [5]:
batch_size = 64 # 32

# Step through `tasks` in strides of batch_size so that the final, shorter
# batch is processed too. Truncating the batch count would leave the "_"
# placeholder at the positions of the last len(tasks) % batch_size tasks.
for start in trange(0, len(tasks), batch_size):
    batch = tasks[start: start + batch_size]

    batch_sent_i = [p[0] for p in batch]
    batch_word_i = [p[1] for p in batch]
    inputs = tokenizer([p[2] for p in batch], return_tensors='pt', padding=True).to(model.device)

    with torch.no_grad():
        hypotheses = model.generate(**inputs, num_beams=5)

    for j in range(len(batch)):
        sent_i = batch_sent_i[j]
        word_i = batch_word_i[j]
        # Spaces in the decoded text are replaced by "=" so each prediction
        # stays a single space-free token.
        preds[sent_i][word_i] = tokenizer.decode(hypotheses[j], skip_special_tokens=True).replace(" ", "=")

    # Drop the batch and its GPU tensors before clearing the CUDA cache.
    del batch, inputs, hypotheses
    gc.collect()
    torch.cuda.empty_cache()

  0%|          | 0/580 [00:00<?, ?it/s]

In [6]:
# Join each sentence's per-token predictions into one space-separated string.
# Entries that are already strings are left alone, so re-running this cell
# does not join the characters of an already-joined prediction.
preds = [pred if isinstance(pred, str) else " ".join(pred) for pred in preds]

In [7]:
# Cache write for the joined predictions, left commented out; the cell reads
# the cached file below instead, replacing any `preds` computed in this session.
# with open("../data/preds_t5.pickle", "wb") as f:
#     pickle.dump(preds, f) 

# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open("../data/preds_t5.pickle", "rb") as f:
    preds = pickle.load(f) 

In [8]:
full_preds = []
full_labels = []
for i, pred in enumerate(preds):
    true = sentences_labels[i]
    # Skip sentences whose prediction and label strings differ in token count.
    if len(pred.split(" ")) != len(true.split(" ")):
        continue
    full_preds.append(pred)
    full_labels.append(true)

# Flatten the kept sentences into parallel token-level lists.
full_preds = " ".join(full_preds).split(" ")
full_labels = " ".join(full_labels).split(" ")

In [9]:
# Sanity check: the token-level predictions and labels must be equally long.
len(full_preds) == len(full_labels)

True

In [10]:
# Russian morphological analyzer restricted to the dictionary analyzer unit;
# this rebinds `morph` with the same configuration used earlier in the notebook.
morph = pymorphy2.MorphAnalyzer(lang="ru", 
                                units=[pymorphy2.units.DictionaryAnalyzer()])

def norm_tokenize(line):
    """Tokenise ``line`` and map every token to its normal form.

    Tokens for which ``morph.parse`` returns no parses are kept unchanged.

    Returns:
        List of normalised tokens, in the order ``word_tokenize`` yields them.
    """
    def _lemma(token):
        analyses = morph.parse(str(token))
        return analyses[0].normal_form if analyses != [] else token

    return [_lemma(token) for token in word_tokenize(line)]

In [15]:
# Normalise every predicted token: "=" becomes a space again, the text is
# lower-cased and its words are mapped to normal forms; "_" is kept as is.
full_preds_norm = [
    "_" if token == "_" else " ".join(norm_tokenize(token.replace("=", " ").lower()))
    for token in tqdm(full_preds)
]

100%|██████████| 1360937/1360937 [00:06<00:00, 205580.27it/s]


In [16]:
# Normalise every gold label exactly as the predictions are normalised:
# "=" becomes a space, the text is lower-cased and its words are mapped to
# normal forms; the "_" placeholder is kept as is.
full_labels_norm = [
    " ".join(norm_tokenize(gold.replace("=", " ").lower())) if gold != "_" else "_"
    for gold in tqdm(full_labels)
]

100%|██████████| 1360937/1360937 [00:00<00:00, 1547142.42it/s]


In [26]:
# Macro-averaged F1 over the normalised token strings; every distinct string
# (including "_") counts as its own class.
f1_score(full_labels_norm, full_preds_norm, average="macro")

0.034699985738502095

In [27]:
def get_filtred_accuracy_score(true, pred):
    """Accuracy over the positions where label and prediction are not both "_".

    Args:
        true: Sequence of gold labels.
        pred: Sequence of predictions, parallel to ``true``.

    Returns:
        The value of ``accuracy_score`` computed on the retained pairs.

    Raises:
        ValueError: If no pair is left after the ("_", "_") pairs are removed.
    """
    kept = [pair for pair in zip(true, pred) if pair != ("_", "_")]
    if not kept:
        # Fail with an explicit message instead of an opaque unpacking error.
        raise ValueError('no positions left after dropping ("_", "_") pairs')
    kept_true, kept_pred = zip(*kept)
    return accuracy_score(kept_true, kept_pred)

In [28]:
# Accuracy restricted to positions where the label or the prediction is not "_".
get_filtred_accuracy_score(full_labels_norm, full_preds_norm)

0.1711422953645595

In [29]:
# (label, prediction) pairs, leaving out positions where both are "_".
stacked = zip(full_labels_norm, full_preds_norm)
filt_stacked = [pair for pair in stacked if pair != ("_", "_")]
true, pred = zip(*filt_stacked)

In [30]:
# Show only the first pairs; the full list is too long for a cell output.
filt_stacked[:30]

[('_', 'международный паралимпийский суд'),
 ('_', 'российский федерация'),
 ('_', 'рига чемпион'),
 ('передача данные', 'передача данные'),
 ('_', 'двухпортовой'),
 ('_', 'геймальному падение'),
 ('_', 'феерация лига чемпион'),
 ('_', 'бережно-космической станция'),
 ('_', 'управление по борьба с экономический преступление'),
 ('_', 'наса'),
 ('информация', 'информация'),
 ('_', 'республика штат право дума'),
 ('_', 'всероссийский автономный комиссия'),
 ('_', 'информационный агентство'),
 ('_', 'федеральный агентство'),
 ('_', 'голодный дума'),
 ('северный корея', 'следственный комитет'),
 ('международный', 'международный'),
 ('_', 'международный антидопинговый агентство'),
 ('_', 'марецкий уголовный розыск'),
 ('_', 'мейсон'),
 ('информация', 'информация'),
 ('_', 'инвестиционный фонд'),
 ('_', 'миллион'),
 ('_', 'оперативный'),
 ('_', 'законодательный собрание'),
 ('_', 'манерский автономный округ'),
 ('_', 'тунецкий автономный округ'),
 ('_', 'тимарский автономный округ'),
 ('фина