# Fine-Tuning

Воспользуемся моделью ruDialoGpt3-medium-finetuned-telegram от Kirili4ik, дообученной на диалогах из Telegram. Это позволит смешать стиль непринуждённой беседы со стилем тяжёлых диалогов Достоевского.

In [6]:
import sys
import re
import json

from sklearn.model_selection import train_test_split
from tqdm import tqdm

import torch
from transformers import TextDataset, DataCollatorForLanguageModeling
from torch.utils.data import DataLoader

from accelerate import Accelerator
from transformers import AdamW, AutoModelForSequenceClassification, get_scheduler

import tqdm

In [7]:
from transformers import AutoModelForCausalLM, AutoTokenizer

# HuggingFace Hub id of the conversational checkpoint we start from
# (ruDialoGPT3-medium, fine-tuned on Telegram dialogues).
checkpoint = "Kirili4ik/ruDialoGpt3-medium-finetuned-telegram"

# Model weights and the matching tokenizer come from the same checkpoint id.
model = AutoModelForCausalLM.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

Определим функцию для загрузки датасета:

In [8]:
def load_dataset(train_path, test_path, tokenizer, block_size=256):
    """Create train and test PyTorch datasets and a collate_fn using HuggingFace.

    Parameters
    ----------
    train_path: str
        Path to the plain-text training corpus.

    test_path: str
        Path to the plain-text test corpus.

    tokenizer: HuggingFace tokenizer
        Tokenizer used to encode the text into token ids.
        For more info see https://huggingface.co/transformers/main_classes/tokenizer.html

    block_size: int, default 256
        Maximum number of tokens per example produced by ``TextDataset``.
        Both splits use the same value.

    Returns
    -------
    tuple
        ``(train_dataset, test_dataset, data_collator)``: two ``TextDataset``
        instances and a ``DataCollatorForLanguageModeling`` configured for
        causal language modelling (``mlm=False``).
    """
    def _make_dataset(path):
        # Same tokenizer and block size for both splits, so train and
        # validation examples are built identically.
        return TextDataset(
            tokenizer=tokenizer,
            file_path=path,
            block_size=block_size,
        )

    train_dataset = _make_dataset(train_path)
    test_dataset = _make_dataset(test_path)

    # mlm=False: labels are derived from the input ids themselves
    # (next-token prediction), as required by a GPT-style causal model.
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer, mlm=False
    )
    return train_dataset, test_dataset, data_collator

In [9]:
# Build the train/test datasets from the pre-split Dostoevsky corpus.
train_dataset, test_dataset, data_collator = load_dataset(
    'train_test/Dostoevsky_train.txt',
    'train_test/Dostoevsky_test.txt',
    tokenizer,
)

# Wrap them in DataLoaders (one example per batch); only the
# training split is shuffled.
train_loader = DataLoader(
    train_dataset,
    batch_size=1,
    shuffle=True,
    collate_fn=data_collator,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=1,
    collate_fn=data_collator,
)



Проверим, что всё работает: возьмём один батч и сделаем один прямой проход (forward) модели.

In [10]:
# Sanity check: take a single batch from the training loader, show the
# tensor shapes the collator produced, and run one forward pass.
batch = next(iter(train_loader))
print({k: v.shape for k, v in batch.items()})

# No gradients are needed for a smoke test; skipping the autograd graph
# saves memory on a medium-sized model.
with torch.no_grad():
    outputs = model(**batch)

# The collator (mlm=False) supplies labels, so the model returns a loss.
print("loss:", outputs.loss.item())

## Fine-tuning

Будем обучать модель 2 эпохи с оптимизатором AdamW и линейным расписанием learning rate (100 шагов разогрева):

In [11]:
from pathlib import Path

num_epochs = 2

# torch.optim.AdamW replaces the deprecated transformers.AdamW. eps and
# weight_decay are pinned to the old transformers.AdamW defaults so the
# optimisation behaves as before (torch's defaults differ).
optimizer = torch.optim.AdamW(
    model.parameters(), lr=3e-5, eps=1e-6, weight_decay=0.0
)

save_checkpoint_path = 'model/ruDialoGpt3_Dostoevsky.pt'
# torch.save does not create missing directories; make sure 'model/' exists.
Path(save_checkpoint_path).parent.mkdir(parents=True, exist_ok=True)

# The scheduler is stepped once per optimizer step, i.e. once per *batch*,
# so the total must be counted in batches, not in dataset examples.
num_training_steps = num_epochs * len(train_loader)
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=100,
    num_training_steps=num_training_steps,
)

accelerator = Accelerator()
# Prepare the scheduler together with the optimizer so Accelerate keeps
# them in sync (e.g. skips a scheduler step when a mixed-precision
# optimizer step is skipped).
train_dl, test_dl, model, optimizer, lr_scheduler = accelerator.prepare(
    train_loader, test_loader, model, optimizer, lr_scheduler
)
# wandb.watch(model, log="all")



In [12]:
progress_bar = tqdm.tqdm(range(num_training_steps))

for epoch in range(num_epochs):

    ### TRAIN EPOCH
    model.train()
    for batch in train_dl:
        optimizer.zero_grad()
        outputs = model(**batch)
        loss = outputs.loss
        # accelerator.backward handles gradient scaling / distributed
        # reduction that a plain loss.backward() would miss.
        accelerator.backward(loss)

        # wandb.log({'train_loss':loss.item()})
        optimizer.step()
        lr_scheduler.step()
        progress_bar.update(1)

    ### SAVE
    # Unwrap the (possibly distributed-wrapped) model so the saved keys match
    # the plain HuggingFace model; accelerator.save writes the file from the
    # main process only (by default) instead of from every process.
    accelerator.wait_for_everyone()
    accelerator.save({
        'model_state_dict': accelerator.unwrap_model(model).state_dict(),
    }, save_checkpoint_path)

    ### VALIDATE ONCE
    cum_loss = 0.0
    model.eval()
    with torch.inference_mode():
        for batch in test_dl:
            outputs = model(**batch)
            cum_loss += outputs.loss.item()

    # Average over the loader that was actually iterated (test_dl may be
    # sharded by Accelerate); max(..., 1) guards an empty test split.
    val_mean_loss = cum_loss / max(len(test_dl), 1)
    accelerator.print(
        f"epoch {epoch + 1}/{num_epochs}: val_mean_loss = {val_mean_loss}"
    )
    # wandb.log({'val_mean_loss': val_mean_loss})

progress_bar.close()

 50%|█████████████████████████████████████                                     | 6151/12300 [21:56<19:22:30, 11.34s/it]

3.0416363929937407


100%|████████████████████████████████████████████████████████████████████████████| 12300/12300 [43:40<00:00,  4.85it/s]

3.138288442559821
