In [1]:
import os
import sys

import pandas as pd
import numpy as np
import random
import itertools

from matplotlib import pyplot as plt

from tqdm import tqdm

from pandarallel import pandarallel

import pymorphy2
import nltk
import pickle
import gc

from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, f1_score

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from keras.preprocessing.sequence import pad_sequences

import gensim
from gensim.models.callbacks import CallbackAny2Vec
from gensim.models.phrases import Phrases, Phraser

import time


SEED = 1
def init_random_seed(value=0):
    random.seed(value)
    np.random.seed(value)
    torch.manual_seed(value)
    torch.cuda.manual_seed(value)
    torch.backends.cudnn.deterministic = True
init_random_seed(SEED)
    
pd.set_option('display.max_colwidth', 255)
tqdm.pandas()
pandarallel.initialize(progress_bar=True, nb_workers=8, use_memory_fs=False)

INFO: Pandarallel will run on 8 workers.
INFO: Pandarallel will use standard multiprocessing data transfer (pipe) to transfer data between the main process and workers.


In [2]:
from transformers import T5Model, T5Tokenizer, T5ForConditionalGeneration

In [3]:
abbr = pd.read_csv("../data/abbr.csv")
lenta_train = pd.read_csv("../data/lenta_train.csv")
lenta_test = pd.read_csv("../data/lenta_test.csv")

def get_abbr(abbr_id):
    return abbr[abbr["abbr_id"] == abbr_id]["abbr_norm"].iloc[0]

def get_desc(abbr_id):
    return abbr[abbr["abbr_id"] == abbr_id]["desc_norm"].iloc[0]

In [4]:
def get_fill_task(text, labels, window_size=10):
    text = text.copy()
    abbr_id = random.choices(list(set(label) - set("_")))[0]
    ind = labels.index(abbr_id)
    abbr_id = int(abbr_id.replace("W-", ""))
    desc = get_desc(abbr_id)
    abbr_norm = text[ind]
    text[ind] = "<extra_id_1>"
    l = max(0, ind - window_size)
    r = ind + window_size
    pair = f"fill {abbr_norm} | {' '.join(text[l:r])}" 
    return pair, desc

In [5]:
texts = lenta_train["text_new"].to_list()
labels = lenta_train["labels_new"].to_list()

pairs = []
for i in tqdm(range(len(texts))):
    text = texts[i].split()
    label = labels[i].split()
    
    try:
        pairs.append(get_fill_task(text, label))
    except:
        continue

100%|██████████| 548700/548700 [01:05<00:00, 8336.99it/s]


In [6]:
pairs[1]

('fill кг | обладминистрация в сб 20 фев подарочный набор состоять из семь <extra_id_1> куриный мясо пять кило гречневый крупа и шесть литр',
 'килограмм')

In [7]:
from transformers import T5ForConditionalGeneration, T5Tokenizer
MODEL_NAME = "sberbank-ai/ruT5-base"

tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)
model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME)

print('Количество параметров', sum(np.product(t.shape) for t in model.parameters()))

Количество параметров 222903552


In [8]:
def generate(model, text, **kwargs):
    inputs = tokenizer(text, return_tensors='pt')
    with torch.no_grad():
        hypotheses = model.generate(**inputs, **kwargs)
    return tokenizer.decode(hypotheses[0], skip_special_tokens=True)

In [9]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

In [10]:
from tqdm.auto import trange
import random
import numpy as np

batch_size = 16  # сколько примеров показываем модели за один шаг
report_steps = 20  # раз в сколько шагов печатаем результат
epochs = 10  # сколько раз мы покажем данные модели
model.cuda()

model.train()
losses = []
for epoch in range(epochs):
    print('EPOCH', epoch)
    random.shuffle(pairs)
    for i in trange(0, int(len(pairs) / batch_size)):
        batch = pairs[i * batch_size: (i + 1) * batch_size]
        # кодируем вопрос и ответ 
        x = tokenizer([p[0] for p in batch], return_tensors='pt', padding=True).to(model.device)
        y = tokenizer([p[1] for p in batch], return_tensors='pt', padding=True).to(model.device)
        # -100 - специальное значение, позволяющее не учитывать токены
        y.input_ids[y.input_ids == 0] = -100
        # вычисляем функцию потерь
        loss = model(
            input_ids=x.input_ids,
            attention_mask=x.attention_mask,
            labels=y.input_ids,
            decoder_attention_mask=y.attention_mask,
            return_dict=True
        ).loss
        # делаем шаг градиентного спуска
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        # печатаем скользящее среднее значение функции потерь
        losses.append(loss.item())
        if i % report_steps == 0:
            print('step', i, 'loss', np.mean(losses[-report_steps:]))

EPOCH 0


  0%|          | 0/34291 [00:00<?, ?it/s]

step 0 loss 11.86892032623291
step 20 loss 9.378391814231872
step 40 loss 6.49945182800293
step 60 loss 5.952630472183228
step 80 loss 5.5946629524230955
step 100 loss 5.101913475990296
step 120 loss 4.919951748847962
step 140 loss 4.662925553321839
step 160 loss 4.4793765306472775
step 180 loss 4.487955963611602
step 200 loss 4.281306004524231
step 220 loss 4.0929787993431095
step 240 loss 4.193723571300507
step 260 loss 4.217689514160156
step 280 loss 3.888678765296936
step 300 loss 3.7140239000320436
step 320 loss 3.9411176800727845
step 340 loss 3.798016381263733
step 360 loss 3.612498414516449
step 380 loss 3.5130589485168455
step 400 loss 3.6169383049011232
step 420 loss 3.33280154466629
step 440 loss 3.6183501482009888
step 460 loss 3.442434549331665
step 480 loss 3.36221821308136
step 500 loss 3.151871645450592
step 520 loss 3.210953783988953
step 540 loss 3.0059702515602114
step 560 loss 3.116533374786377
step 580 loss 3.090851294994354
step 600 loss 3.1066410541534424
step 62

step 4920 loss 0.8062367469072342
step 4940 loss 0.9005944952368736
step 4960 loss 0.87210054397583
step 4980 loss 0.8682508289813995
step 5000 loss 1.0000765338540076
step 5020 loss 0.7942536115646363
step 5040 loss 0.768225172907114
step 5060 loss 0.6879331141710281
step 5080 loss 0.6627293922007084
step 5100 loss 0.870412340760231
step 5120 loss 0.691646384447813
step 5140 loss 0.7494843900203705
step 5160 loss 0.9107828244566918
step 5180 loss 0.7551818184554577
step 5200 loss 0.8259865745902062
step 5220 loss 0.8284564286470413
step 5240 loss 0.8949624091386795
step 5260 loss 0.7425629980862141
step 5280 loss 0.7518613427877426
step 5300 loss 0.7725245520472527
step 5320 loss 0.7499701797962188
step 5340 loss 0.8532861918210983
step 5360 loss 0.7140013247728347
step 5380 loss 0.7850799322128296
step 5400 loss 0.6618482083082199
step 5420 loss 0.9246819466352463
step 5440 loss 0.76730320379138
step 5460 loss 0.6968156792223453
step 5480 loss 0.7334330350160598
step 5500 loss 0.7264

step 9760 loss 0.5437453594058752
step 9780 loss 0.4725322399288416
step 9800 loss 0.39577089473605154
step 9820 loss 0.5434926800429821
step 9840 loss 0.6477682664990425
step 9860 loss 0.45388770997524264
step 9880 loss 0.6213390033692121
step 9900 loss 0.5369862887077034
step 9920 loss 0.5711780320852995
step 9940 loss 0.5149630881845951
step 9960 loss 0.46584524437785146
step 9980 loss 0.5594268836081028
step 10000 loss 0.5472898080945015
step 10020 loss 0.3854832325130701
step 10040 loss 0.5152585089206696
step 10060 loss 0.38435007641091945
step 10080 loss 0.38758729174733164
step 10100 loss 0.37701262906193733
step 10120 loss 0.5778123959898949
step 10140 loss 0.49014346916228535
step 10160 loss 0.5743768196552992
step 10180 loss 0.448729232698679
step 10200 loss 0.6531313240528107
step 10220 loss 0.5236358620226383
step 10240 loss 0.4262387868016958
step 10260 loss 0.48917985484004023
step 10280 loss 0.44405491426587107
step 10300 loss 0.41068476922810077
step 10320 loss 0.41273

step 14420 loss 0.42521117962896826
step 14440 loss 0.4861670546233654
step 14460 loss 0.3129865301772952
step 14480 loss 0.4660312200896442
step 14500 loss 0.39245738722383977
step 14520 loss 0.3735300075262785
step 14540 loss 0.32824043473228814
step 14560 loss 0.32301296377554534
step 14580 loss 0.27353907357901336
step 14600 loss 0.38654768019914626
step 14620 loss 0.3145447187125683
step 14640 loss 0.3931882468983531
step 14660 loss 0.34590654689818623
step 14680 loss 0.40333166169002654
step 14700 loss 0.2734866398386657
step 14720 loss 0.39375204518437384
step 14740 loss 0.40275752209126947
step 14760 loss 0.3633188549429178
step 14780 loss 0.3613278169184923
step 14800 loss 0.38932064762338997
step 14820 loss 0.2367310424335301
step 14840 loss 0.336702230013907
step 14860 loss 0.3455478329211473
step 14880 loss 0.27943711122497916
step 14900 loss 0.4675520084798336
step 14920 loss 0.31843536123633387
step 14940 loss 0.3601026427000761
step 14960 loss 0.3232391245663166
step 149

step 19060 loss 0.2880386605858803
step 19080 loss 0.27087020529434086
step 19100 loss 0.25550264017656443
step 19120 loss 0.30119307078421115
step 19140 loss 0.31998858265578745
step 19160 loss 0.3213635457213968
step 19180 loss 0.34025832479819657
step 19200 loss 0.20846925107762218
step 19220 loss 0.18007696145214142
step 19240 loss 0.35490981237962843
step 19260 loss 0.2969955254346132
step 19280 loss 0.23687245529145
step 19300 loss 0.3015058010816574
step 19320 loss 0.3274524969048798
step 19340 loss 0.17902935678139328
step 19360 loss 0.2997012242209166
step 19380 loss 0.2946732981130481
step 19400 loss 0.3295608405023813
step 19420 loss 0.3338103686459363
step 19440 loss 0.36814675251953305
step 19460 loss 0.2806670473888516
step 19480 loss 0.26299281883984804
step 19500 loss 0.38605516669340434
step 19520 loss 0.2604771500453353
step 19540 loss 0.2365845251828432
step 19560 loss 0.32690901644527914
step 19580 loss 0.3539141910150647
step 19600 loss 0.2826032777316868
step 1962

step 23680 loss 0.22355380058288574
step 23700 loss 0.17847079113125802
step 23720 loss 0.24532783795148133
step 23740 loss 0.3166666870936751
step 23760 loss 0.20756356390193104
step 23780 loss 0.27835907647386193
step 23800 loss 0.19833559021353722
step 23820 loss 0.281295269401744
step 23840 loss 0.22231017190497368
step 23860 loss 0.20305025877896696
step 23880 loss 0.26177110970020295
step 23900 loss 0.25656790342181923
step 23920 loss 0.26556523055769504
step 23940 loss 0.23538550173398107
step 23960 loss 0.2028667011298239
step 23980 loss 0.23361551351845264
step 24000 loss 0.26086583230644467
step 24020 loss 0.19791225623339415
step 24040 loss 0.31000593106728047
step 24060 loss 0.24100807383656503
step 24080 loss 0.3003956195898354
step 24100 loss 0.2277471374720335
step 24120 loss 0.22398299314081668
step 24140 loss 0.20315952943637966
step 24160 loss 0.1695642032660544
step 24180 loss 0.25393454935401677
step 24200 loss 0.2029120283201337
step 24220 loss 0.2543921545147896
s

step 28300 loss 0.28833422511816026
step 28320 loss 0.30440572686493395
step 28340 loss 0.18050468103028833
step 28360 loss 0.2087681104429066
step 28380 loss 0.324215782340616
step 28400 loss 0.16543314922600985
step 28420 loss 0.23209027564153076
step 28440 loss 0.2226025358773768
step 28460 loss 0.22479028441011906
step 28480 loss 0.23920235394034534
step 28500 loss 0.2290478221140802
step 28520 loss 0.2159542660228908
step 28540 loss 0.1910565703175962
step 28560 loss 0.14084063302725552
step 28580 loss 0.21380701363086702
step 28600 loss 0.23083078414201735
step 28620 loss 0.17404953073710203
step 28640 loss 0.19052944835275412
step 28660 loss 0.24981674228329212
step 28680 loss 0.1881789357867092
step 28700 loss 0.23619517269544305
step 28720 loss 0.2091351472772658
step 28740 loss 0.2016155523713678
step 28760 loss 0.19398548319004477
step 28780 loss 0.23050591219216585
step 28800 loss 0.3134930341970176
step 28820 loss 0.2856887564063072
step 28840 loss 0.16912816604599357
step

step 32900 loss 0.22904159016907216
step 32920 loss 0.2712306783068925
step 32940 loss 0.17714457735419273
step 32960 loss 0.19144387478008867
step 32980 loss 0.14511584541760386
step 33000 loss 0.2119609140790999
step 33020 loss 0.20782706884201615
step 33040 loss 0.17741785626858472
step 33060 loss 0.15858169852290302
step 33080 loss 0.1161962991231121
step 33100 loss 0.14787431857548655
step 33120 loss 0.2580953758209944
step 33140 loss 0.13205309328623116
step 33160 loss 0.2455327871721238
step 33180 loss 0.19613259020261467
step 33200 loss 0.20973005653358995
step 33220 loss 0.1819356913678348
step 33240 loss 0.2119174921186641
step 33260 loss 0.20861450335942208
step 33280 loss 0.13584714690223337
step 33300 loss 0.17257836415665223
step 33320 loss 0.28262497978284956
step 33340 loss 0.29599538519978524
step 33360 loss 0.20067960342857988
step 33380 loss 0.18803939083591104
step 33400 loss 0.16285536924842745
step 33420 loss 0.13627954265102743
step 33440 loss 0.17981208092533052

  0%|          | 0/34291 [00:00<?, ?it/s]

step 0 loss 0.21089195092208685
step 20 loss 0.2376395340077579
step 40 loss 0.22385706841014325
step 60 loss 0.1521665283245966
step 80 loss 0.214618634339422
step 100 loss 0.10192731637507677
step 120 loss 0.2147181489272043
step 140 loss 0.15944940337794833


KeyboardInterrupt: 

In [59]:
model.save_pretrained("../data/t5_trainded.model")

In [68]:
text = "Нынешний ___ президент Джо Байден заявил, что вопрос пошлин на импорт китайских товаров в США находится на рассмотрении, и он намерен обсудить его с министром финансов страны Джанет Йеллен после своего возвращения из азиатского турне, сообщило агентство Bloomberg. Эти слова Байдена были интерпретированы инвесторами как сигнал возможной отмены некоторых пошлин."
generate(model, f" fill сша | {text}")

'западный штат'

In [46]:
abbr.sort_values("desc_count", ascending=False).head(50)

Unnamed: 0,abbr,desc,desc_norm,desc_len,abbr_norm,abbr_len,abbr_count,desc_count,abbr_id
1647,каэр,КР,кр,1,каэр,1,0,133344,1647
2568,гг.,годы,год,1,год,1,114333,114333,2568
1503,кот.,который,который,1,кот,1,81160,76083,1503
1701,нт,нит,нит,1,нт,1,215231,58791,1701
4085,дн,дина,дин,1,дн,1,104033,53587,4085
2327,Рос.,Россия,россия,1,расти,1,4549,44827,2327
5102,тж.,также;,также,1,тж,1,43,33827,5102
3152,вв.,века,век,1,век,1,29348,29348,3152
461,вр.,время,время,1,вр,1,76411,28925,461
4447,чел.,человек,человек,1,чел,1,30758,27581,4447
