In [None]:
!pip install pandas matplotlib scikit-learn transformers optimum auto-gptq matplotlib seaborn deep_translator pony

In [None]:
# Execute template.ipynb inside this kernel before anything else.
# NOTE(review): presumably shared notebook setup — confirm which names/settings it defines.
%run template.ipynb

import torch
import pickle
import pandas as pd

from transformers import BertTokenizer

from vkr.ml import model, dataset, train
from vkr.llm import saiga_llama3
from vkr.data import datasets
from vkr.utils.vkr_root import VKR_ROOT

In [None]:
# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Tokenizer and classifier are both built from the same RuBERT checkpoint name.
tokenizer = BertTokenizer.from_pretrained('DeepPavlov/rubert-base-cased')

# Wrap the classifier in DataParallel, then move the wrapped module to `device`.
bert = torch.nn.DataParallel(
    model.BertForBinaryClassification('DeepPavlov/rubert-base-cased')
)
bert = bert.to(device)

In [None]:
# Debug switches for quicker iterations:
#   SMALL  - when True, keep 100 random rows of every train and test split.
#   MEDIUM - when > 0, keep 1/MEDIUM of every train and test split (0 disables).
SMALL = False
MEDIUM = 0

# Alternative dataset configuration, currently disabled:
# dts = [
#     ('WELFake', datasets.Welfake(VKR_ROOT / 'data/datasets/WELFake_Dataset.csv', True)),
#     ('FakeNews', datasets.FakeNewsPredictions(VKR_ROOT / 'data/datasets/FakeNewsPrediction.csv', True)),
# ]

# (name, dataset) pairs; the name later labels the matching validation split.
# NOTE(review): the saiga_llama3.prompt_* list is passed straight to the dataset
# constructors — presumably prompt variants for LLM-derived data; confirm.
dts = [
    ('Russian_WELFake',
     datasets.RussianWelfake(VKR_ROOT / 'data/datasets/Russian_WELFake_Dataset.csv',
                             [saiga_llama3.prompt_1, saiga_llama3.prompt_2,
                              saiga_llama3.prompt_3])),
    ('Kaggle', datasets.RussianKaggle(VKR_ROOT / 'data/datasets/russian_kaggle',
                                      [saiga_llama3.prompt_1, saiga_llama3.prompt_2,
                                       saiga_llama3.prompt_3])),
]

# train_datas: one train split per dataset.
# val_datas: (dataset name, test split) pairs, kept separate per dataset.
train_datas = []
val_datas = []
for dt_name, dt in dts:
    train_, test_ = dt.get_train_test()
    if SMALL:
        train_ = train_.sample(100)
        test_ = test_.sample(100)
    if MEDIUM > 0:
        train_ = train_.sample(len(train_) // MEDIUM)
        # Subsample the held-out split from itself. Drawing it from `train_`
        # (as the code previously did) leaks training rows into validation.
        test_ = test_.sample(len(test_) // MEDIUM)
    train_datas.append(train_)
    val_datas.append((dt_name, test_))

In [None]:
BATCH_SIZE = 32

# Training data: every per-dataset train split concatenated into one dataset,
# served through the project's "weighted" dataloader factory.
# NOTE(review): the weighting scheme lives in dataset.create_weighted_dataloader — confirm there.
train_dataset = dataset.NewsDataset(pd.concat(train_datas), tokenizer)
train_loader = dataset.create_weighted_dataloader(
    train_dataset,
    batch_size=BATCH_SIZE,
)

# Validation data: one standard loader per source dataset, paired with its name.
val_loaders = [
    (
        phase_name,
        dataset.create_standard_dataloader(
            dataset.NewsDataset(val_frame, tokenizer),
            batch_size=BATCH_SIZE,
        ),
    )
    for phase_name, val_frame in val_datas
]

In [None]:
# Number of passes over the training loader.
num_epochs = 10

# BCEWithLogitsLoss applies the sigmoid itself.
# NOTE(review): assumes the model emits raw logits — confirm against model.BertForBinaryClassification.
criterion = torch.nn.BCEWithLogitsLoss()

# Adam over all parameters of the (DataParallel-wrapped) model.
optimizer = torch.optim.Adam(params=bert.parameters(), lr=3e-5)

In [None]:
# Run training; `train_results` holds whatever train.train_binary returns.
train_results = train.train_binary(
    bert, train_loader, val_loaders, optimizer, criterion, num_epochs, device
)

# Persist the trained weights.
# NOTE(review): `bert` is a DataParallel wrapper, so the saved state_dict keys
# should carry a "module." prefix — loaders must wrap the model the same way or strip it.
torch.save(bert.state_dict(), 'ru_bert.torch')

# Persist the training results next to the weights.
with open('ru_train_results.pkl', mode='wb') as results_file:
    pickle.dump(train_results, results_file)