In [82]:
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import pandas as pd
import numpy as np
import random
import shutil
from tqdm.auto import trange

In [83]:
model_checkpoint = 'cointegrated/rut5-small-chitchat2'
new_model_name = 'rut5-small-chitchat2-fine-tuned'
shutil.rmtree(new_model_name)

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

In [84]:
#fake data for training
pairs = [[str(n), str(n)] for n in range(1000)]

In [85]:
batch_size = 16
report_steps = 200
epochs = 3

model.train()
losses = []
for epoch in range(epochs):
    print('EPOCH', epoch)
    random.shuffle(pairs)
    for i in trange(0, int(len(pairs) / batch_size)):
        batch = pairs[i * batch_size: (i + 1) * batch_size]
        
        x = tokenizer([p[0] for p in batch], return_tensors='pt', padding=True).to(model.device)
        y = tokenizer([p[1] for p in batch], return_tensors='pt', padding=True).to(model.device)
        
        # -100 - специальное значение, позволяющее не учитывать токены
        y.input_ids[y.input_ids == 0] = -100
        
        loss = model(
            input_ids=x.input_ids,
            attention_mask=x.attention_mask,
            labels=y.input_ids,
            decoder_attention_mask=y.attention_mask,
            return_dict=True
        ).loss
        
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # печатаем скользящее среднее значение функции потерь
        losses.append(loss.item())
        if i % report_steps == 0:
            print('step', i, 'loss', np.mean(losses[-report_steps:]))

EPOCH 0


  0%|          | 0/62 [00:00<?, ?it/s]

step 0 loss 2.9394114017486572
EPOCH 1


  0%|          | 0/62 [00:00<?, ?it/s]

step 0 loss 1.8304987776847113
EPOCH 2


  0%|          | 0/62 [00:00<?, ?it/s]

step 0 loss 1.2506193270683288


In [86]:
#model.eval()

In [87]:
def answer(x, **kwargs):
    inputs = tokenizer(x, return_tensors='pt').to(model.device)
    with torch.no_grad():
        hypotheses = model.generate(
            **inputs, 
            do_sample=True, top_p=0.5, num_return_sequences=3, 
            repetition_penalty=2.5,
            max_length=32,
        )
    return tokenizer.decode(hypotheses[0], skip_special_tokens=True)

In [88]:
model.save_pretrained(new_model_name)
tokenizer.save_pretrained(new_model_name)

('rut5-rut5-small-chitchat2-fine-tuned/tokenizer_config.json',
 'rut5-rut5-small-chitchat2-fine-tuned/special_tokens_map.json',
 'rut5-rut5-small-chitchat2-fine-tuned/spiece.model',
 'rut5-rut5-small-chitchat2-fine-tuned/added_tokens.json',
 'rut5-rut5-small-chitchat2-fine-tuned/tokenizer.json')

In [89]:
answer('какое-то рандомное совершенно предложение')

'Что это такое?'