In [1]:
import os

import itertools
import logging
import pandas as pd
import sklearn.metrics as metrics
import yt.wrapper as yt

from deeppavlov.models.preprocessors.bert_preprocessor import BertPreprocessor
from deeppavlov.models.bert.bert_classifier import BertClassifierModel

from tqdm.autonotebook import tqdm

[nltk_data] Downloading package punkt to /home/lyubanenko/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to
[nltk_data]     /home/lyubanenko/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package perluniprops to
[nltk_data]     /home/lyubanenko/nltk_data...
[nltk_data]   Package perluniprops is already up-to-date!
[nltk_data] Downloading package nonbreaking_prefixes to
[nltk_data]     /home/lyubanenko/nltk_data...
[nltk_data]   Package nonbreaking_prefixes is already up-to-date!









In [2]:
tqdm.pandas()

In [3]:
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

In [4]:
MODEL_DIR = '/home/lyubanenko/data/deeppavlov_models/rubert_cased_L-12_H-768_A-12_v2'

In [5]:
bert_preprocessor_config = {
    'vocab_file': os.path.join(MODEL_DIR, 'vocab.txt'),
    'do_lower_case': False,
    'max_seq_length': 64
}
bert_preprocessor = BertPreprocessor(**bert_preprocessor_config)




In [6]:
input_feats = bert_preprocessor(['Простое ведро (на илл.) стало причиной кровопролитного сражения и гибели более 2000 человек'])

In [7]:
print(input_feats[0].tokens)
print(input_feats[0].input_ids[:10])

['[CLS]', 'Просто', '##е', 'ведро', '(', 'на', 'илл', '.', ')', 'стало', 'причиной', 'кровопролит', '##ного', 'сражения', 'и', 'гибели', 'более', '2000', 'человек', '[SEP]']
[101, 60949, 842, 114147, 120, 1469, 49629, 132, 122, 8488]


In [8]:
yt.config["proxy"]["url"] = "hahn.yt.yandex.net"

TRAIN_TABLE = '//home/ynews/ai/datasets/rbc_politics/train'
TEST_TABLE = '//home/ynews/ai/datasets/rbc_politics/test'

In [9]:
columns=['title', 'target']

force = False

if not 'train_data' in locals() or force:
    train_data = pd.DataFrame(yt.read_table(yt.TablePath(TRAIN_TABLE, columns=columns)))
if not 'test_data' in locals() or force:
    test_data = pd.DataFrame(yt.read_table(yt.TablePath(TEST_TABLE, columns=columns)))

In [10]:
train_data_sample_size = 10240
train_data_sample = test_data.sample(train_data_sample_size, axis=0)

train_data_ = list(train_data_sample['title'].progress_apply(lambda row: bert_preprocessor([row])[0]))
train_true_ = list(train_data_sample['target'])

HBox(children=(IntProgress(value=0, max=10240), HTML(value='')))




In [11]:
test_data_sample_size = 256
test_data_sample = test_data.sample(test_data_sample_size, axis=0)

test_data_ = list(test_data_sample['title'].progress_apply(lambda row: bert_preprocessor([row])[0]))
test_true_ = list(test_data_sample['target'])

HBox(children=(IntProgress(value=0, max=256), HTML(value='')))




In [12]:
print(sum(train_true_), len(train_true_))
print(sum(test_true_), len(test_true_))

3663 10240
102 256


In [13]:
OUTPUT_DIR = './output'

In [14]:
bert_model_config = {
    'bert_config_file': os.path.join(MODEL_DIR, 'bert_config.json'),
    'n_classes': len(train_data.target.unique()),
    'keep_prob': 0.9,
    'save_path': OUTPUT_DIR,
    
    'learning_rate': 2e-6,
    'return_probas': True
}
bert_model = BertClassifierModel(**bert_model_config)

Using TensorFlow backend.






The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.

Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.
Instructions for updating:
Use keras.layers.dense instead.

















Instructions for updating:
Use standard file APIs to check for files with this prefix.


In [15]:
def apply_model_batch(iterable, size=128):
    it = iter(iterable)
    with tqdm() as pbar:
        while True:
            batch = tuple(itertools.islice(it, size))
            if not batch:
                return
            yield from bert_model(batch)
            pbar.update(len(batch))

In [18]:
y_prob = list(apply_model_batch(test_data_, 96))
y_score = list(map(lambda x: x[1], y_prob))

print(list(zip(y_score, test_true_)))
print('AUC', metrics.roc_auc_score(test_true_, y_score))

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


[(0.53992677, 0), (0.53243244, 1), (0.52184415, 0), (0.54752475, 0), (0.5265158, 1), (0.53639954, 0), (0.5351596, 0), (0.53699946, 1), (0.53076404, 0), (0.5207457, 1), (0.5218589, 1), (0.5150393, 1), (0.5272572, 1), (0.526088, 1), (0.5176361, 1), (0.5215692, 1), (0.5196207, 1), (0.5212312, 0), (0.5179006, 0), (0.53152084, 0), (0.5265077, 0), (0.5265114, 1), (0.53802377, 1), (0.52786833, 0), (0.525029, 1), (0.53945476, 1), (0.51790357, 0), (0.51310676, 0), (0.5200862, 0), (0.5103293, 1), (0.5351488, 1), (0.52843297, 1), (0.5247577, 0), (0.5271485, 1), (0.50772977, 1), (0.53914714, 0), (0.5321631, 1), (0.54132605, 1), (0.5151163, 0), (0.54040956, 0), (0.5167253, 0), (0.5142531, 0), (0.516151, 0), (0.5321675, 0), (0.53853184, 1), (0.5113825, 0), (0.55579036, 0), (0.5257399, 1), (0.5421491, 0), (0.51477593, 1), (0.5346692, 0), (0.5398905, 1), (0.527149, 0), (0.53053755, 1), (0.5241328, 0), (0.522869, 1), (0.53534365, 0), (0.5158377, 1), (0.516496, 1), (0.5269769, 1), (0.53554815, 1), (0.5

In [19]:
def train_model_batch(iterable, size=128):
    it = iter(iterable)
    with tqdm() as pbar:
        while True:
            batch = tuple(itertools.islice(it, size))
            if not batch:
                return
            (train_data_, train_true_) = zip(*batch)
            step = bert_model.train_on_batch(train_data_, train_true_)
            pbar.write(f'{step}')
            pbar.update(len(batch))

In [20]:
train_model_batch(zip(train_data_, train_true_), 32)

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

{'loss': 0.6968312, 'learning_rate': 2e-06}
{'loss': 0.7128269, 'learning_rate': 2e-06}
{'loss': 0.68937165, 'learning_rate': 2e-06}
{'loss': 0.77659744, 'learning_rate': 2e-06}
{'loss': 0.7715614, 'learning_rate': 2e-06}
{'loss': 0.6766455, 'learning_rate': 2e-06}
{'loss': 0.7095424, 'learning_rate': 2e-06}
{'loss': 0.7067834, 'learning_rate': 2e-06}
{'loss': 0.6861931, 'learning_rate': 2e-06}
{'loss': 0.6705745, 'learning_rate': 2e-06}
{'loss': 0.67282426, 'learning_rate': 2e-06}
{'loss': 0.5410495, 'learning_rate': 2e-06}
{'loss': 0.68599594, 'learning_rate': 2e-06}
{'loss': 0.744066, 'learning_rate': 2e-06}
{'loss': 0.75826824, 'learning_rate': 2e-06}
{'loss': 0.7165104, 'learning_rate': 2e-06}
{'loss': 0.62501156, 'learning_rate': 2e-06}
{'loss': 0.6578826, 'learning_rate': 2e-06}
{'loss': 0.64691734, 'learning_rate': 2e-06}
{'loss': 0.67216176, 'learning_rate': 2e-06}
{'loss': 0.6479987, 'learning_rate': 2e-06}
{'loss': 0.6255169, 'learning_rate': 2e-06}
{'loss': 0.6965535, 'lear

{'loss': 0.6365743, 'learning_rate': 2e-06}
{'loss': 0.57487106, 'learning_rate': 2e-06}
{'loss': 0.62990284, 'learning_rate': 2e-06}
{'loss': 0.6418253, 'learning_rate': 2e-06}
{'loss': 0.67594606, 'learning_rate': 2e-06}
{'loss': 0.6650083, 'learning_rate': 2e-06}
{'loss': 0.5343714, 'learning_rate': 2e-06}
{'loss': 0.6002871, 'learning_rate': 2e-06}
{'loss': 0.6458644, 'learning_rate': 2e-06}
{'loss': 0.60537636, 'learning_rate': 2e-06}
{'loss': 0.72260296, 'learning_rate': 2e-06}
{'loss': 0.5957589, 'learning_rate': 2e-06}
{'loss': 0.5650791, 'learning_rate': 2e-06}
{'loss': 0.7108805, 'learning_rate': 2e-06}
{'loss': 0.6071268, 'learning_rate': 2e-06}
{'loss': 0.6119791, 'learning_rate': 2e-06}
{'loss': 0.663414, 'learning_rate': 2e-06}
{'loss': 0.65172446, 'learning_rate': 2e-06}
{'loss': 0.64490616, 'learning_rate': 2e-06}
{'loss': 0.6604345, 'learning_rate': 2e-06}
{'loss': 0.58588886, 'learning_rate': 2e-06}
{'loss': 0.6191325, 'learning_rate': 2e-06}
{'loss': 0.6469478, 'lear

In [21]:
y_prob = list(apply_model_batch(test_data_, 96))
y_score = list(map(lambda x: x[1], y_prob))

print(list(zip(y_score, test_true_)))
print('AUC', metrics.roc_auc_score(test_true_, y_score))

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


[(0.24437527, 0), (0.23488504, 1), (0.22981948, 0), (0.13666928, 0), (0.20167416, 1), (0.2971854, 0), (0.3942679, 0), (0.4376179, 1), (0.3150537, 0), (0.59521925, 1), (0.32211435, 1), (0.19389792, 1), (0.43277887, 1), (0.37747285, 1), (0.65493226, 1), (0.19636887, 1), (0.5615346, 1), (0.38698795, 0), (0.23794456, 0), (0.3131529, 0), (0.37061986, 0), (0.2513588, 1), (0.45984462, 1), (0.4882022, 0), (0.23242934, 1), (0.3969434, 1), (0.19502546, 0), (0.10313329, 0), (0.11177452, 0), (0.2329289, 1), (0.19584127, 1), (0.600192, 1), (0.1907665, 0), (0.34531248, 1), (0.15879016, 1), (0.21486291, 0), (0.23007298, 1), (0.15267642, 1), (0.14425042, 0), (0.26472068, 0), (0.11452936, 0), (0.21640041, 0), (0.19540477, 0), (0.28980583, 0), (0.75518924, 1), (0.18571968, 0), (0.26095113, 0), (0.5190089, 1), (0.24244614, 0), (0.24992885, 1), (0.29815975, 0), (0.477792, 1), (0.21606818, 0), (0.44444466, 1), (0.28924066, 0), (0.7505495, 1), (0.46090767, 0), (0.33174416, 1), (0.5076641, 1), (0.7451736, 1