<a href="https://colab.research.google.com/github/DE50LAT10N/ya-handbook-ml/blob/main/lr0.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install s3fs straight from its GitHub repository; --no-deps makes pip skip its dependencies.
# NOTE(review): the install is unpinned and s3fs is not referenced by any visible cell - confirm it is still needed.
%pip install --upgrade git+https://github.com/dask/s3fs --no-deps


In [None]:
# Hugging Face packages for the cells below: tokenizers, transformers and datasets are imported directly.
# --no-deps makes pip skip their dependencies, so the runtime must already provide them.
# NOTE(review): versions are unpinned, and `evaluate` / `accelerate` are not imported by any visible cell.
%pip install transformers tokenizers datasets evaluate accelerate --no-deps

In [None]:
# Pages that are downloaded and concatenated into the training corpus (dataset.txt).
# NOTE(review): every entry sits under a single author directory on az.lib.ru
# (presumably Lev Tolstoy, going by the path) - confirm the selection is intentional.
urls = [
    "http://az.lib.ru/t/tolstoj_lew_nikolaewich/text_0039.shtml",
    "http://az.lib.ru/t/tolstoj_lew_nikolaewich/text_0040.shtml",
    "http://az.lib.ru/t/tolstoj_lew_nikolaewich/text_0050.shtml",
    "http://az.lib.ru/t/tolstoj_lew_nikolaewich/text_0060.shtml",
    "http://az.lib.ru/t/tolstoj_lew_nikolaewich/text_0070.shtml",
    "http://az.lib.ru/t/tolstoj_lew_nikolaewich/text_0080.shtml",
    "http://az.lib.ru/t/tolstoj_lew_nikolaewich/text_0090.shtml",
    "http://az.lib.ru/t/tolstoj_lew_nikolaewich/text_1860_dekabristy.shtml",
]

In [None]:
import html
import re
import requests

def download(url, timeout=30):
    """Fetch ``url`` over HTTP and return the decoded response body.

    Args:
        url: Address of the page to download.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests.get`` may block indefinitely.

    Returns:
        The response body as ``str``.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.RequestException: On connection problems or timeouts.
    """
    response = requests.get(url, timeout=timeout)
    # Fail loudly: otherwise an error page would silently end up in the corpus.
    response.raise_for_status()
    # If the server does not declare a charset, requests assumes ISO-8859-1 for
    # text responses, which garbles non-Latin pages; detect the encoding from
    # the payload instead.
    if "charset" not in response.headers.get("Content-Type", "").lower():
        response.encoding = response.apparent_encoding
    return response.text

# Matches an HTML comment that opens and closes on the same line (``.`` does
# not cross newlines here), or any ``<...>`` span up to the next ``>``.
striptags_re = re.compile(r"(<!--.*?-->|<[^>]*>)")
# Matches a character-entity reference such as ``&amp;``.
# Not referenced by the visible code.
entity_re = re.compile(r"&([^;]+);")


def to_text(s):
    """Remove markup from ``s`` and decode its HTML entities.

    Args:
        s: Raw HTML source.

    Returns:
        ``s`` with every ``striptags_re`` match deleted and entity references
        (e.g. ``&amp;``) replaced by the characters they stand for.
    """
    without_markup = re.sub(striptags_re, "", s)
    return html.unescape(without_markup)

def beautify(s, marker="-->", max_header_lines=100):
    """Normalize whitespace and cut off the page header, if one is present.

    Every line is stripped and blank lines are dropped. If a line equal to
    ``marker`` occurs among the first ``max_header_lines`` remaining lines,
    everything up to and including its first occurrence is discarded.
    If the marker is absent (or the text is empty) all lines are kept; the
    previous implementation dropped text in that case and raised on empty
    input.

    Args:
        s: Text to clean up.
        marker: Line that terminates the header.
        max_header_lines: How many leading lines are searched for ``marker``.

    Returns:
        The remaining lines joined with ``"\\n"``.
    """
    stripped = (line.strip() for line in s.split("\n"))
    lines = [line for line in stripped if line]
    try:
        # index() returns the first occurrence within the search window.
        start = lines.index(marker, 0, max_header_lines) + 1
    except ValueError:
        # No header marker found: keep the whole text.
        start = 0
    return "\n".join(lines[start:])


# Build the corpus file: fetch each page in order, reduce it to plain text and
# append it to dataset.txt followed by a blank line. The generator is consumed
# lazily, so each page is downloaded right before it is written.
with open("dataset.txt", "w", encoding="utf-8") as corpus:
    corpus.writelines(
        beautify(to_text(download(page))) + "\n\n" for page in urls
    )

In [None]:
import tokenizers as tok
import transformers as tr

In [None]:
# Train a BPE tokenizer from scratch on the downloaded corpus.
tokenizer = tok.Tokenizer(tok.models.BPE(unk_token="[UNK]"))
tokenizer.pre_tokenizer = tok.pre_tokenizers.Whitespace()
# "[PAD]" stays first so it keeps id 0, which is the pad id/token that
# enable_padding() uses by default.
# "[UNK]" has to be registered as well: the BPE model above names it as its
# unk_token, so it must exist in the vocabulary for unseen symbols to map to it.
bpe_trainer = tok.trainers.BpeTrainer(special_tokens=["[PAD]", "[UNK]"])
tokenizer.train(["dataset.txt"], bpe_trainer)
tokenizer.enable_padding()

In [None]:
# Token -> id mapping of the freshly trained tokenizer.
vocab = tokenizer.get_vocab()
# Wrap the raw `tokenizers` object so it can be called through the transformers tokenizer interface.
ttokenizer = tr.PreTrainedTokenizerFast(tokenizer_object=tokenizer)
# Cell output: vocabulary size.
len(vocab)

In [None]:
import datasets

# Load the corpus file through the generic "text" builder of `datasets`.
dataset = datasets.load_dataset("text", data_files="dataset.txt")
# Peek at one example as a sanity check; nothing in this notebook gives index 13 a special meaning.
dataset["train"][13]

In [None]:
def tokenize(x):
    """Encode a batch of texts with ``ttokenizer`` and attach LM labels.

    Args:
        x: Batch mapping that provides the raw strings under ``"text"``.

    Returns:
        The tokenizer output with an extra ``"labels"`` entry holding a
        shallow copy of ``"input_ids"``.
    """
    encoded = ttokenizer(x["text"])
    # Causal LM training predicts the input itself, so labels mirror the ids.
    encoded["labels"] = list(encoded["input_ids"])
    return encoded


# Tokenize in batches and drop the raw "text" column so only the tokenizer outputs and labels remain.
ds = dataset.map(tokenize, batched=True, remove_columns=["text"])
# Show the first tokenized example as a sanity check.
ds["train"][0]

In [None]:
from itertools import chain

# Number of tokens in every training example produced by group_texts.
block_size = 1024


def group_texts(examples):
    """Concatenate a batch of token sequences and cut it into equal blocks.

    Args:
        examples: Batch mapping of column name -> list of per-row lists.
            Must contain ``"input_ids"``.

    Returns:
        A dict with the same columns, where each column is a list of
        ``block_size``-long chunks; the tail that does not fill a whole block
        is dropped. ``"labels"`` is set to a shallow copy of ``"input_ids"``.
    """
    # Flatten every column into one long sequence.
    flattened = {}
    for column in examples.keys():
        flattened[column] = list(chain.from_iterable(examples[column]))

    # The usable length is derived from the first column and rounded down to
    # a multiple of block_size.
    first_column = list(flattened)[0]
    usable = len(flattened[first_column]) // block_size * block_size

    result = {}
    for column, tokens in flattened.items():
        result[column] = [
            tokens[start : start + block_size]
            for start in range(0, usable, block_size)
        ]
    result["labels"] = list(result["input_ids"])
    return result

# Re-chunk the tokenized rows into fixed-size blocks; batched=True lets group_texts see many rows per call.
dsb = ds.map(group_texts, batched=True)

In [None]:
# Load the pretrained checkpoint together with its own tokenizer.
# From here on `tokenizer` is the pretrained tokenizer, not the BPE tokenizer
# trained earlier in the notebook.
checkpoint = "ai-forever/rugpt3small_based_on_gpt2"
tokenizer = tr.AutoTokenizer.from_pretrained(checkpoint)
gpt = tr.GPT2LMHeadModel.from_pretrained(checkpoint)

# Baseline sample before fine-tuning: top-k sampling of up to 50 new tokens.
prompt = tokenizer("Мне нравится, что вы ", return_tensors="pt")
res = gpt.generate(**prompt, max_new_tokens=50, top_k=3, do_sample=True)
tokenizer.decode(res[0])

In [None]:
def _encode_batch(batch):
    """Tokenize the "text" column of ``batch`` with the current ``tokenizer``."""
    return tokenizer(batch["text"])


# Rebuild the training blocks, this time encoded with the pretrained model's
# tokenizer; group_texts adds the "labels" column.
dataset = datasets.load_dataset("text", data_files="dataset.txt")
ds = dataset.map(_encode_batch, batched=True, remove_columns=["text"])
dsb = ds.map(group_texts, batched=True)

In [None]:
import os
import json
# Set before the Trainer is created; presumably keeps it from trying to log to Weights & Biases.
os.environ["WANDB_DISABLED"] = "true"

In [None]:
# Fine-tuning hyper-parameters; anything not listed keeps the
# TrainingArguments default. Checkpoints go to ./gpt2-finetune.
targs = tr.TrainingArguments(
    output_dir="gpt2-finetune",
    learning_rate=5e-5,
    warmup_steps=200,
    num_train_epochs=30,
    save_steps=1500,
)

# The examples in dsb["train"] are already equal-length blocks that carry
# their own labels (see group_texts), so the plain default collator is used.
trainer = tr.Trainer(
    model=gpt,
    args=targs,
    data_collator=tr.default_data_collator,
    train_dataset=dsb["train"],
    tokenizer=tokenizer,
)
trainer.train()

In [None]:
# Training leaves the model in train mode; switch to eval mode so dropout does
# not perturb the sampled text.
gpt.eval()
# Move the prompt to whatever device the model currently lives on instead of
# hard-coding "cuda", so this cell also runs on CPU-only runtimes.
prompt = tokenizer("Мне нравится, что вы ", return_tensors="pt").to(gpt.device)
# Same sampling settings as the pre-training baseline, for a like-for-like comparison.
res = gpt.generate(
    **prompt,
    max_new_tokens=50,
    top_k=3,
    do_sample=True,
)
tokenizer.decode(res[0])