### Install libraries

In [0]:
!pip install transformers -q

[K     |████████████████████████████████| 675kB 3.5MB/s 
[K     |████████████████████████████████| 890kB 9.9MB/s 
[K     |████████████████████████████████| 3.8MB 27.1MB/s 
[K     |████████████████████████████████| 1.1MB 54.6MB/s 
[?25h  Building wheel for sacremoses (setup.py) ... [?25l[?25hdone


### Get a pre-trained model

In [0]:
!mkdir -p model/pretrained
!wget https://www.dropbox.com/s/2ysavun8x5m6duu/checkpoint.tar?dl=0 -O model/pretrained/checkpoint.tar -q

### Get modules

In [0]:
!git clone https://github.com/IgnatovD/ruBart/ -q

Cloning into 'ruBart'...
remote: Enumerating objects: 19, done.[K
remote: Counting objects: 100% (19/19), done.[K
remote: Compressing objects: 100% (16/16), done.[K
remote: Total 19 (delta 3), reused 10 (delta 1), pack-reused 0[K
Unpacking objects: 100% (19/19), done.


In [0]:
cd ruBart/modules

/content/ruBart


### Get data

In [0]:
!wget https://github.com/RossiyaSegodnya/ria_news_dataset/raw/master/ria.json.gz -q

In [0]:
!gunzip ria.json.gz

In [0]:
import math
import random
import warnings

import numpy as np
import torch
import transformers
from transformers import BartConfig

# Project modules from the cloned ruBart repository (importable after `cd ruBart/modules`).
from fine_tune import FineTune
from preprocessing_sum import GenerateDataloader

# Silences every warning for the rest of the session (behaviour kept from the original).
# NOTE(review): this also hides deprecation warnings from the unpinned transformers
# install — consider narrowing the filter to the specific warning being suppressed.
warnings.filterwarnings('ignore')

# Seed the Python, NumPy and PyTorch RNGs so that any randomness in the data split,
# batching or weight initialisation (inside the ruBart modules, not visible here)
# is repeatable across runs.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [0]:
#@title Set data parameters { vertical-output: true, display-mode: "form" }

# Mini-batch size, passed to the dataloader helper and later to the trainer.
batch_size = 8 #@param
# Passed to GenerateDataloader; presumably the maximum sequence length in tokens — confirm in preprocessing_sum.py.
max_length = 512 #@param
# Decompressed RIA news dump produced by the "Get data" cells (relative to the current working directory).
path_data = 'ria.json' #@param

In [0]:
# Build the data helper from the form parameters above; it is reused below to
# split the data (`gd.split()`) and to build the training dataloader.
gd = GenerateDataloader(path_data, batch_size, max_length)

In [0]:
# Split the dataset into train / validation / test parts. Only `train_data` is
# used by the cells shown here; `val_data` and `test_data` are not referenced again.
train_data, val_data, test_data = gd.split()

HBox(children=(FloatProgress(value=0.0, max=1003869.0), HTML(value='')))




In [0]:
#@title Set training parameters { vertical-output: true, display-mode: "form" }

lr = 1e-4 #@param
weight_decay = 0.0 #@param
num_epoch = 2 #@param
accum_batch_size = 8000 #@param
path_save = 'model/finetune/' #@param
# The checkpoint is downloaded into `model/pretrained/` *before* the `cd ruBart/modules`
# cell runs, i.e. relative to Colab's default working directory `/content`. A bare
# relative path would now be resolved against `ruBart/modules` and miss the file.
path_pretrained_model = '/content/model/pretrained/checkpoint.tar' #@param

# Single source of truth for the vocabulary size, shared by the trainer args and the
# model config so the two cannot drift apart.
vocab_size = 30000

# Mini-batches of `batch_size` accumulated per update, so each update covers roughly
# `accum_batch_size` examples.
accum_steps = math.ceil(accum_batch_size / batch_size)
# Updates over the whole run: ceil(batches per epoch / accum_steps), times the epochs.
# NOTE(review): assumes FineTune steps the optimizer/scheduler once per `accum_steps`
# batches — confirm in fine_tune.py.
total_steps = math.ceil(len(train_data) / batch_size / accum_steps) * num_epoch

args = {
    'batch_size': batch_size,
    'device': torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'),
    'lr': lr,
    'weight_decay': weight_decay,
    'total_steps': total_steps,
    # Warm up over the first 1% of `total_steps`, rounded up.
    'warmup_steps': math.ceil(total_steps / 100),
    'vocab_size': vocab_size,
    'accum_steps': accum_steps,
    'path_save': path_save,
    'path_pretrained_model': path_pretrained_model,
}

# Small BART: 4 encoder + 4 decoder layers, d_model 256, 16 attention heads, FFN 1024.
# Presumably has to match the architecture stored in the pre-trained checkpoint.
config = BartConfig()
config.d_model = 256
config.decoder_attention_heads = 16
config.decoder_ffn_dim = 1024
config.decoder_layers = 4
config.encoder_attention_heads = 16
config.encoder_ffn_dim = 1024
config.encoder_layers = 4
config.num_hidden_layers = 4
config.pad_token_id = 1
config.vocab_size = vocab_size

In [0]:
# Dataloader over the training split only; it is consumed by `trainer.train_sum` below.
dataloader = gd.get_dataloader(train_data)

### Load the pre-trained checkpoint and fine-tune the model

In [0]:
# Create the trainer from the BART config and the training arguments dict defined above.
trainer = FineTune(config, **args)

In [0]:
# Load the pre-trained checkpoint (`args['path_pretrained_model']`). The returned model,
# optimizer, scheduler and bookkeeping (loss history, learning-rate trace, best loss)
# are passed straight into `train_sum` below.
model, optimizer, scheduler, history, learning_rate, best_loss = trainer.load_pretrained()

In [0]:
# Fine-tune on the summarization data, continuing from the state returned by
# `load_pretrained`; the updated model and training traces are returned.
model, history, learning_rate, best_loss = trainer.train_sum(
    dataloader,
    model,
    optimizer,
    scheduler,
    history,
    learning_rate,
    best_loss,
)

HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))






### Training visualization

In [0]:
# Plot the loss history returned by `train_sum`.
trainer.visualization(history, mode='loss')

In [0]:
# Plot the learning-rate trace returned by `train_sum`.
trainer.visualization(learning_rate, mode='lr')