In [None]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch
from fastai.text.all import *
from sklearn.model_selection import train_test_split

In [None]:
# Name of the pretrained checkpoint. The same identifier is passed to both the
# tokenizer and the model loaders below so the two are loaded from one source.
pretrained_weights = 'gpt2'
tokenizer = GPT2Tokenizer.from_pretrained(pretrained_weights)
model = GPT2LMHeadModel.from_pretrained(pretrained_weights)

In [None]:
# Base directory that the corpus folders and the exported model files are
# resolved against. Path comes from the fastai star import (presumably
# pathlib.Path, in which case Path() is the current directory — TODO confirm).
path = Path()

In [None]:
# Corpus directory for the Stephen King texts.
# NOTE(review): the next cell assigns data_path again, so when the notebook is
# run top to bottom this value is overwritten before it is read. Run only the
# one data_path cell that matches the corpus you want to train on.
data_path = (path/'StephenKing_txt')

In [None]:
# Corpus directory for the Harry Potter texts. This assignment replaces any
# data_path set earlier; get_text_files(data_path) in the loading cell reads
# from whichever value was assigned last.
data_path = (path/'HarryPotter_txt')

In [None]:
# Split every text file under data_path into consecutive chunks of at most
# 2048 characters and store them as the rows of a one-column DataFrame
# (column label 0, default integer index).
#
# The chunks are collected in a plain list and the frame is built once at the
# end: DataFrame.append was removed in pandas 2.0, and calling it per row
# copied the whole frame every time (quadratic in the number of chunks).
chunk_size = 2048  # characters per chunk (files are opened in text mode)
chunks = []
for txt_f in get_text_files(data_path):
    with open(txt_f, encoding='utf-8') as f:
        # iter(callable, sentinel) keeps calling f.read until it returns ''
        # (end of file), so the empty string is never stored as a chunk.
        chunks.extend(iter(lambda: f.read(chunk_size), ''))
df = pd.DataFrame(chunks)

In [None]:
class TransformersTokenizer(Transform):
    """Tokenizer transform for working with GPT2.

    Wraps the tokenizer given to the constructor: `encodes` turns a text into
    a tensor of token ids, `decodes` turns ids back into a `TitledStr`.
    """

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def encodes(self, x):
        # Two steps through the wrapped tokenizer: text -> tokens -> ids.
        token_ids = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(x))
        return tensor(token_ids)

    def decodes(self, x):
        # The wrapped tokenizer is handed a NumPy array, so move the ids to
        # the CPU first.
        ids = x.cpu().numpy()
        return TitledStr(self.tokenizer.decode(ids))

In [None]:
# Hold out 10% of the chunks for validation. The split works on positional
# indices into df[0].values, so the index range is built from len(df) (one
# index per row). random_state is fixed so that Restart & Run All yields the
# same train/validation split, and therefore comparable metrics, on every run.
splits = train_test_split(
    list(range(len(df))),
    test_size=0.1,
    random_state=42,
)
# Tokenize each chunk lazily with the GPT-2 tokenizer; dl_type tells fastai to
# build language-model dataloaders from these lists.
tls = TfmdLists(
    df[0].values,
    TransformersTokenizer(tokenizer),
    splits=splits,
    dl_type=LMDataLoader,
)

In [None]:
# Batch size and sequence length used to build the dataloaders from tls.
bs = 8
sl = 256
dls = tls.dataloaders(bs=bs, seq_len=sl)

In [None]:
# Display at most two samples from a batch (max_n=2) to eyeball what the
# dataloaders produce before training.
dls.show_batch(max_n=2)

In [None]:
class DropOutput(Callback):
    """Support callback for training GPT2.

    After each prediction it replaces the learner's `pred` with the first
    element of that prediction (presumably the logits of the model's output
    tuple — TODO confirm against the transformers version in use).
    """

    def after_pred(self):
        first_output = self.pred[0]
        self.learn.pred = first_output

In [None]:
# Learner around the GPT-2 model: flattened cross-entropy as the loss,
# perplexity as the reported metric, the DropOutput callback attached, and
# to_fp16() applied to the resulting learner.
learn = Learner(
    dls,
    model,
    loss_func=CrossEntropyLossFlat(),
    cbs=[DropOutput],
    metrics=Perplexity(),
).to_fp16()

In [None]:
# Evaluate the pretrained model before any fine-tuning, as a baseline for the
# loss and perplexity reported after training.
learn.validate()

In [None]:
# Run the learning-rate finder to help choose the rate passed to
# fit_one_cycle in the training cell.
learn.lr_find()

In [None]:
%%time
# Fine-tune for a single epoch with the one-cycle schedule, using 1e-4 as the
# learning-rate argument.
learn.fit_one_cycle(1, 1e-4)

In [None]:
# Encode the seed text into token ids and add a leading batch dimension
# ([None]) so the tensor has shape (1, n_tokens) for generate().
prompt = "It was a bright day."
prompt_ids = tokenizer.encode(prompt)
# Put the prompt on whatever device the model's weights live on instead of
# hard-coding CUDA, so the cell also works on a CPU-only machine and the
# input always matches the model's device.
model_device = next(learn.model.parameters()).device
inp = tensor(prompt_ids)[None].to(model_device)

In [None]:
# Sample a continuation of the prompt from learn.model. Sampling is enabled
# (do_sample=True), so the text differs from run to run.
preds = learn.model.generate(
    inp,
    max_length=500,
    repetition_penalty=6.0,
    temperature=1.5,
    no_repeat_ngram_size=2,
    do_sample=True,
    top_k=5,
    top_p=0.95,
)
# Decode the first returned sequence back to text; as the last expression of
# the cell it is shown as the cell output.
tokenizer.decode(preds[0].cpu().numpy())

In [None]:
# Export the learner to a pickle file under the Harry Potter model name.
# NOTE(review): only meaningful when data_path pointed at the Harry Potter
# corpus for this run — confirm before running.
learn.export(path/'WriterHP_transf_model.pkl')

In [None]:
# Export the learner to a pickle file under the Stephen King model name.
# NOTE(review): this exports the same `learn` object as the Harry Potter
# export cell, only under a different file name. Run just the export that
# matches the corpus selected in data_path.
learn.export(path/'WriterStKng_transf_model.pkl')

In [None]:
# Save a checkpoint of the learner under the Harry Potter name after the
# one-epoch run.
# NOTE(review): only meaningful when data_path pointed at the Harry Potter
# corpus for this run — confirm before running.
learn.save('writer_HarPot_transf_1epoch')

In [None]:
# Save a checkpoint of the learner under the Stephen King name after the
# one-epoch run.
# NOTE(review): this saves the same `learn` object as the Harry Potter save
# cell. Run just the save that matches the corpus selected in data_path.
learn.save('writer_StKng_transf_1epoch')