In [1]:
import torch
import torch.nn as nn
from tqdm.notebook import tqdm
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import ExponentialLR
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from jupyterthemes import jtplot
import time
jtplot.style(theme='grade3')
from IPython.display import HTML, display
def set_css_in_cell_output():
    display(HTML('''
        <style>
            .jupyter-widgets {color: #d5d5d5 !important;}
            .widget-label {color: #d5d5d5 !important;}
        </style>
    '''))
get_ipython().events.register('pre_run_cell', set_css_in_cell_output)
import tensorflow as tf
gpu = tf.config.list_physical_devices(device_type='GPU')[0]
tf.config.experimental.set_memory_growth(gpu, True)



In [2]:
# Для классификации фейковых токенов(слов) будет использовать Soft F1 Loss, 
# так как количество фейковых слов много меньше правильных слов.
# Soft precision
def precision_score(pred, target, eps=1e-10):
    tp = (pred*target).sum()
    tp_fp = pred.sum()+eps
    return tp/tp_fp
# Soft precision loss
def precision_loss(pred, target, eps=1e-10):
    return 1-precision_score(pred, target, eps)
# Soft recall
def recall_score(pred, target, eps=1e-10):
    tp = (pred*target).sum()
    tp_fp = target.sum()+eps
    return tp/tp_fp
# Soft recall loss
def recall_loss(pred, target, eps=1e-10):
    return 1-recall_score(pred, target, eps)
# Soft F-beta metric
def f1_score(pred, target, beta=1, eps=1e-10):
    prec = precision_score(pred, target, eps)
    rec = recall_score(pred, target, eps)
    return (1+beta**2)*prec*rec/(prec*beta**2+rec)
# Soft F-beta loss
def f1_loss(pred, target, beta=1, eps=1e-10):
    return 1-f1_score(pred, target, beta, eps)

In [3]:
import torch.nn.utils as utils
from torch.utils.data import DataLoader, Dataset 
from transformers import BertForSequenceClassification, BertTokenizer
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Будем использовать русскоязычный BERT от Сбербанка
tokenizer = BertTokenizer.from_pretrained("sberbank-ai/ruBert-base")
bert_model = BertForSequenceClassification.from_pretrained("sberbank-ai/ruBert-base")
bert_model = bert_model.to(device)

Some weights of the model checkpoint at sberbank-ai/ruBert-base were not used when initializing BertForSequenceClassification: ['cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not in

In [4]:
import pandas as pd
# обучающий датасет, взято из статьи: https://habr.com/ru/post/564916/
# сам датасет: https://storage.yandexcloud.net/nlp/paranmt_ru_leipzig.zip
df = pd.read_csv("/data/paranmt_ru_leipzig.tsv", sep="\t", index_col="idx")
size_10 = 1000
df = df.iloc[:10*size_10]
df_train = df.iloc[:size_10*8]
df_test = df.iloc[size_10*8:]
del df

In [5]:
import os
import random
import pandas as pd

class Tokens:
    def __init__(self, filepath=os.path.join('data', 'token_clusters.csv')):
        self.filepath = filepath
        self.df = self.load_tokens()
        self.mapping_by_cluster = self.create_mapping_by_cluster()
        self.mapping_by_token_id = self.create_mapping_by_token_id()

    def load_tokens(self):
        return pd.read_csv(self.filepath)

    def create_mapping_by_cluster(self):
        mapping_by_cluster = dict()
        for cluster, group in self.df.groupby('cluster'):
            mapping_by_cluster[cluster] = list(group.token_id)

        return mapping_by_cluster

    def create_mapping_by_token_id(self):
        mapping_by_token_id = dict()
        for token_id, cluster in zip(self.df.token_id, self.df.cluster):
            mapping_by_token_id[token_id] = cluster

        return mapping_by_token_id

    def get_cluster(self, token_id):
        return self.mapping_by_token_id[token_id]

    def get_token_id_from_cluster(self, cluster):
        return random.sample(self.mapping_by_cluster[cluster], 1)[0]

    def get_random_token(self, token_id):
        cluster = self.get_cluster(token_id)
        random_token = self.get_token_id_from_cluster(cluster)
        return random_token

In [6]:
from random import randint, uniform
class MyDataset(torch.utils.data.Dataset):
    def __init__(self, df, tokenizer=tokenizer):
        self.vocab = list(tokenizer.vocab.values())
        self.len_vocab = len(self.vocab)
        self.original = df.original.values
        self.rewrite = df.ru.values
        self.tokenizer = tokenizer
        self.t = Tokens()
        self.emb_size = 512
        self.sep_id = self.tokenizer.convert_tokens_to_ids(["SEP"])
    def __len__(self):
        return self.original.shape[0]
    def __getitem__(self, idx):
        original = self.original[idx]
        mask = torch.zeros(self.emb_size)
        rewrite = self.rewrite[idx]
        # с вероятностью 20% оригинал и перефразирование совпадают
        no_rewrite = uniform(0, 1)<0.2
        if no_rewrite:
            rewrite = original
            
        data = self.tokenizer(original, rewrite, 
                              return_tensors="pt", max_length=self.emb_size, 
                              padding='max_length', truncation='longest_first')
        # находим 2 текст между SEP токенами(id=102)
        first_sep = 0
        second_sep = 0
        for i in range(512):
            if data["input_ids"][0][i]==102 and not first_sep:
                first_sep = i
            elif data["input_ids"][0][i]==102:
                second_sep=i
        # в 50% текстов делаем фейки и помечаем target как фейк(1) или не фейк(0)
        target = randint(0, 1)
        if target:
            # заменяет часать (на практике около 25%) второго текста на фейки, если target - фейк
            for i in range((second_sep-first_sep)):
                preplace_id = randint(first_sep+1, second_sep-1)
                token_id = int(data["input_ids"][0][preplace_id])
                try:
                    new_token_id = self.t.get_random_token(token_id)
                    data["input_ids"][0][preplace_id]=new_token_id
                    mask[preplace_id]=1
                except Exception as e:
                    pass
                    #print("ex",e)
        return (data["input_ids"].view((-1)), data["attention_mask"].view((-1)), 
                data["token_type_ids"].view((-1)), mask, target)

In [7]:
train_dataset = MyDataset(df_train)
test_dataset = MyDataset(df_test)

In [8]:
def train_epoch(model, optimizer, dataloader, device, batch_size):
    epoch_loss, length, acc, tp, tp_fp, tp_fn, acc_cls = 0, 0, 0, 0, 0, 0, 0
    for i, (input_ids, attention_mask, token_type_ids, mask, target) in enumerate(tqdm(dataloader)):
        if i%bs_accumulate==0:
            optimizer.zero_grad()
        (tokens_class, text_class) = model(input_ids, attention_mask, token_type_ids, device)
        criterion1 = f1_loss
        criterion2 = nn.BCELoss()
        loss1 = criterion1(tokens_class.to(device), mask.to(device))
        loss2 = criterion2(text_class.to(device), target.float().to(device))
        loss = loss1 + loss2
        loss.backward()
        epoch_loss += loss.item()
        pred = ((tokens_class>=0.5)*1)
        acc+=(pred.to(device)==mask.to(device)).sum().item()/mask.view((-1)).shape[0]
        tp+=(pred.to(device)*mask.to(device)).sum().item()
        tp_fp += (pred>=0.5).sum().item()
        tp_fn += mask.sum().item()
        acc_cls += ((((text_class>=0.5)*1)==target.to(device))*1.).mean().item()
        if i%bs_accumulate==0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
            optimizer.step()
    precision = tp/(tp_fp+1e-10)
    recall = tp/(tp_fn+1e-10)
    return epoch_loss/len(dataloader), acc/len(dataloader), precision, recall, acc_cls/len(dataloader)
def test_epoch(model, dataloader, device, batch_size):
    epoch_loss, length, acc, tp, tp_fp, tp_fn, acc_cls = 0, 0, 0, 0, 0, 0, 0
    for i, (input_ids, attention_mask, token_type_ids, mask, target) in enumerate(tqdm(dataloader)):
        (tokens_class, text_class) = model(input_ids, attention_mask, token_type_ids, device)
        criterion1 = f1_loss
        criterion2 = nn.BCELoss()
        loss1 = criterion1(tokens_class.to(device), mask.to(device))
        loss2 = criterion2(text_class.to(device), target.float().to(device))
        loss = loss1 + loss2
        epoch_loss += loss.item()
        pred = ((tokens_class>=0.5)*1)
        acc+=(pred.to(device)==mask.to(device)).sum().item()/mask.view((-1)).shape[0]
        tp+=(pred.to(device)*mask.to(device)).sum().item()
        tp_fp += (pred>=0.5).sum().item()
        tp_fn += mask.sum().item()
        acc_cls += ((((text_class>=0.5)*1)==target.to(device))*1.).mean().item()
    precision = tp/(tp_fp+1e-10)
    recall = tp/(tp_fn+1e-10)
    return epoch_loss/len(dataloader), acc/len(dataloader), precision, recall, acc_cls/len(dataloader)

In [9]:
class Article_Estimator(nn.Module):
    def __init__(self, bert_model = bert_model):
        super().__init__()
        self.bert_model = bert_model
        self.pooler_dim = 768
        self.is_fake = nn.Linear(self.pooler_dim, 1)
        self.is_token_fake = nn.Linear(self.pooler_dim, 1)
    def forward(self, input_ids, attention_mask, token_type_ids, device):
        out = self.bert_model.bert(input_ids = input_ids.int().to(device), 
                                   attention_mask = attention_mask.to(device), 
                                   token_type_ids = token_type_ids.int().to(device))
        tokens_cls = nn.Sigmoid()(self.is_token_fake(out.last_hidden_state)).squeeze(dim=-1)
        text_cls = nn.Sigmoid()(self.is_fake(out.pooler_output)).squeeze(dim=-1)
        return (tokens_cls, text_cls)
model = Article_Estimator().to(device)
lr=1e-5
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
i=0
layers=2
for param in model.parameters():
#     if i<5+16*8:
    if i<6+16*(12-layers):
        param.requires_grad=False
    i+=1
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 14,179,588 trainable parameters


In [10]:
bs_size = 250
bs_train, bs_test = bs_size, bs_size
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=bs_train, 
                                               shuffle=True, num_workers=4)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=bs_test, 
                                              shuffle=False, num_workers=4)

In [11]:
run_name = str(layers)+"L-base-BCE-correct-10-epochs"

In [12]:
epochs = 10
bs_accumulate = 1

In [13]:
import wandb
wandb.init(project='ruBERT', name=run_name)
config = wandb.config
config.learning_rate = lr
config.layers = layers
config.loss = "BCELoss"
config.epochs = epochs
config.tokens = 512
config.model = "bert-base-uncased"
config.batch_size = str(bs_accumulate)+"*"+str(bs_train)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mdiht404[0m (use `wandb login --relogin` to force relogin)


In [14]:
max_test_acc, max_test_prec, max_test_rec, max_test_f1, max_acc_cls = 0, 0, 0, 0, 0
max_acc_iter, max_prec_iter, max_rec_iter, max_f1_iter, max_acc_cls_iter = 0, 0, 0, 0, 0

In [25]:
epochs = 20

In [26]:
start = 0
for epoch in tqdm(range(start, start+epochs)):
    model.train()
    loss, acc, prec, rec, acc_cls = train_epoch(model, optimizer, train_dataloader, device, bs_train)
    f1_train = 2*rec*prec/(rec+prec+1e-10)
    model.eval()
    with torch.no_grad():
        loss_test, acc_test, prec_test, rec_test, acc_cls_test = test_epoch(model, test_dataloader, device, bs_test)
    f1_test = 2*rec_test*prec_test/(rec_test+prec_test+1e-10)
    wandb.log({"Train": {"Loss": loss, "Accuracy": acc, "Precision": prec, 
                     "Recall": rec, "F1": f1_train, "Acc_cls": acc_cls}, 
           "Test":{"Loss": loss_test, "Accuracy": acc_test, "Precision": prec_test, 
                   "Recall": rec_test, "F1": f1_test, "Acc_cls": acc_cls_test}})
    if acc_test > max_test_acc:
        max_test_acc, max_acc_iter = acc_test, epoch
    if prec_test > max_test_prec:
        max_test_prec, max_prec_iter = prec_test, epoch
    if rec_test > max_test_rec:
        max_test_rec, max_rec_iter = rec_test, epoch
    if f1_test > max_test_f1:
        max_test_f1, max_f1_iter = f1_test, epoch
    if acc_cls_test > max_acc_cls:
        max_acc_cls, max_acc_cls_iter = acc_cls_test, epoch
    print(" Epoch:", epoch, "Train Loss:", round(loss, 5), "Test Loss:", round(loss_test, 5), '\n',
          "Train acc:    ", round(acc,5),     "Test Acc:    ", round(acc_test,5),    "Max acc:     ", round(max_test_acc,5),  "Max acc_iter:    ", max_acc_iter, '\n',
          "Train prec:   ", round(prec,5),    "Test Prec:   ", round(prec_test,5),   "Max prec:    ", round(max_test_prec,5), "Max prec iter:   ", max_prec_iter, '\n',
          "Train recall: ", round(rec,5),     "Test Recall: ", round(rec_test,5),    "Max recall:  ", round(max_test_rec,5),  "Max recall iter: ", max_rec_iter, '\n',
          "Train F1:     ", round(f1_train,5),"Test F1:     ", round(f1_test,5),     "Max F1_score:", round(max_test_f1,5),   "Max F1 iter:     ", max_f1_iter, '\n',
          "Train acc cls:", round(acc_cls,5), "Test acc cls:", round(acc_cls_test,5),"Max acc_cls:",  round(max_acc_cls,5),   "Max acc_cls iter:", max_acc_cls_iter, '\n')
    torch.save(model.state_dict(), f'/data/Models/ruBERT_8k_2L_{epoch+1}_epochs.pt')

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

 Epoch: 10 Train Loss: 1.011 Test Loss: 0.97614 
 Train acc:     0.98836 Test Acc:     0.98929 Max acc:      0.98949 Max acc_iter:     7 
 Train prec:    0.12238 Test Prec:    0.14885 Max prec:     0.14885 Max prec iter:    10 
 Train recall:  0.21266 Test Recall:  0.22139 Max recall:   0.22139 Max recall iter:  10 
 Train F1:      0.15536 Test F1:      0.17802 Max F1_score: 0.17802 Max F1 iter:      10 
 Train acc cls: 0.96713 Test acc cls: 0.9705 Max acc_cls: 0.9755 Max acc_cls iter: 9 



  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

 Epoch: 11 Train Loss: 0.95076 Test Loss: 0.90514 
 Train acc:     0.99031 Test Acc:     0.99146 Max acc:      0.99146 Max acc_iter:     11 
 Train prec:    0.17872 Test Prec:    0.21625 Max prec:     0.21625 Max prec iter:    11 
 Train recall:  0.22834 Test Recall:  0.24362 Max recall:   0.24362 Max recall iter:  11 
 Train F1:      0.2005 Test F1:      0.22912 Max F1_score: 0.22912 Max F1 iter:      11 
 Train acc cls: 0.97325 Test acc cls: 0.976 Max acc_cls: 0.976 Max acc_cls iter: 11 



  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

 Epoch: 12 Train Loss: 0.87805 Test Loss: 0.82475 
 Train acc:     0.99265 Test Acc:     0.99391 Max acc:      0.99391 Max acc_iter:     12 
 Train prec:    0.27748 Test Prec:    0.39446 Max prec:     0.39446 Max prec iter:    12 
 Train recall:  0.27934 Test Recall:  0.28749 Max recall:   0.28749 Max recall iter:  12 
 Train F1:      0.27841 Test F1:      0.33258 Max F1_score: 0.33258 Max F1 iter:      12 
 Train acc cls: 0.97188 Test acc cls: 0.9675 Max acc_cls: 0.976 Max acc_cls iter: 11 



  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

 Epoch: 13 Train Loss: 0.76114 Test Loss: 0.70996 
 Train acc:     0.99467 Test Acc:     0.99542 Max acc:      0.99542 Max acc_iter:     13 
 Train prec:    0.48211 Test Prec:    0.57842 Max prec:     0.57842 Max prec iter:    13 
 Train recall:  0.33034 Test Recall:  0.3264 Max recall:   0.3264 Max recall iter:  13 
 Train F1:      0.39205 Test F1:      0.41731 Max F1_score: 0.41731 Max F1 iter:      13 
 Train acc cls: 0.97113 Test acc cls: 0.9755 Max acc_cls: 0.976 Max acc_cls iter: 11 



  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

 Epoch: 14 Train Loss: 0.66383 Test Loss: 0.62603 
 Train acc:     0.99535 Test Acc:     0.99556 Max acc:      0.99556 Max acc_iter:     14 
 Train prec:    0.57532 Test Prec:    0.64651 Max prec:     0.64651 Max prec iter:    14 
 Train recall:  0.38941 Test Recall:  0.38931 Max recall:   0.38931 Max recall iter:  14 
 Train F1:      0.46445 Test F1:      0.48598 Max F1_score: 0.48598 Max F1 iter:      14 
 Train acc cls: 0.9735 Test acc cls: 0.973 Max acc_cls: 0.976 Max acc_cls iter: 11 



  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

 Epoch: 15 Train Loss: 0.60664 Test Loss: 0.57856 
 Train acc:     0.99569 Test Acc:     0.99579 Max acc:      0.99579 Max acc_iter:     15 
 Train prec:    0.61963 Test Prec:    0.64036 Max prec:     0.64651 Max prec iter:    14 
 Train recall:  0.43111 Test Recall:  0.43704 Max recall:   0.43704 Max recall iter:  15 
 Train F1:      0.50846 Test F1:      0.51952 Max F1_score: 0.51952 Max F1 iter:      15 
 Train acc cls: 0.97325 Test acc cls: 0.975 Max acc_cls: 0.976 Max acc_cls iter: 11 



  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

 Epoch: 16 Train Loss: 0.57287 Test Loss: 0.57523 
 Train acc:     0.99588 Test Acc:     0.99594 Max acc:      0.99594 Max acc_iter:     16 
 Train prec:    0.63368 Test Prec:    0.68982 Max prec:     0.68982 Max prec iter:    16 
 Train recall:  0.4594 Test Recall:  0.43289 Max recall:   0.43704 Max recall iter:  15 
 Train F1:      0.53265 Test F1:      0.53195 Max F1_score: 0.53195 Max F1 iter:      16 
 Train acc cls: 0.97713 Test acc cls: 0.973 Max acc_cls: 0.976 Max acc_cls iter: 11 



  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

 Epoch: 17 Train Loss: 0.5573 Test Loss: 0.54749 
 Train acc:     0.99596 Test Acc:     0.99593 Max acc:      0.99594 Max acc_iter:     16 
 Train prec:    0.66014 Test Prec:    0.6511 Max prec:     0.68982 Max prec iter:    16 
 Train recall:  0.47132 Test Recall:  0.48053 Max recall:   0.48053 Max recall iter:  17 
 Train F1:      0.54998 Test F1:      0.55296 Max F1_score: 0.55296 Max F1 iter:      17 
 Train acc cls: 0.97138 Test acc cls: 0.9735 Max acc_cls: 0.976 Max acc_cls iter: 11 



  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

 Epoch: 18 Train Loss: 0.54426 Test Loss: 0.53867 
 Train acc:     0.99601 Test Acc:     0.9961 Max acc:      0.9961 Max acc_iter:     18 
 Train prec:    0.6547 Test Prec:    0.65157 Max prec:     0.68982 Max prec iter:    16 
 Train recall:  0.48813 Test Recall:  0.49003 Max recall:   0.49003 Max recall iter:  18 
 Train F1:      0.55928 Test F1:      0.55937 Max F1_score: 0.55937 Max F1 iter:      18 
 Train acc cls: 0.97163 Test acc cls: 0.9735 Max acc_cls: 0.976 Max acc_cls iter: 11 



  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

 Epoch: 19 Train Loss: 0.53044 Test Loss: 0.52382 
 Train acc:     0.99616 Test Acc:     0.99634 Max acc:      0.99634 Max acc_iter:     19 
 Train prec:    0.65994 Test Prec:    0.74779 Max prec:     0.74779 Max prec iter:    19 
 Train recall:  0.5055 Test Recall:  0.46783 Max recall:   0.49003 Max recall iter:  18 
 Train F1:      0.57249 Test F1:      0.57558 Max F1_score: 0.57558 Max F1 iter:      19 
 Train acc cls: 0.97313 Test acc cls: 0.9725 Max acc_cls: 0.976 Max acc_cls iter: 11 



In [29]:
model.load_state_dict(torch.load('/data/Models/ruBERT_8k_2L_20_epochs.pt'))

<All keys matched successfully>