In [None]:
import json
import pandas as pd


import torch
from torch.utils.data import Dataset
from transformers import TrainingArguments, Trainer
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from transformers import DataCollatorForLanguageModeling
from transformers import StoppingCriteria, StoppingCriteriaList

In [None]:
def ds_df(path):
    """Load dialogue records with ``input``/``output`` fields into a DataFrame.

    The file may be either a single JSON document (an array of records, or one
    record) or JSON Lines (one JSON object per line). The notebook's default
    file is ``dataset.jsonl``, which a plain ``json.load`` cannot parse.

    Newlines inside each field are replaced with spaces so every sample is a
    single line of text.

    Args:
        path: path to the JSON / JSONL file (decoded as UTF-8, BOM tolerated).

    Returns:
        pandas.DataFrame with columns ``i`` (inputs) and ``o`` (outputs),
        one row per record, in file order.

    Raises:
        json.JSONDecodeError: if the file is neither valid JSON nor valid JSONL.
        KeyError: if a record lacks ``input`` or ``output``.
    """
    # Explicit encoding: the data is Russian text and the platform default
    # encoding is not guaranteed to be UTF-8. 'utf-8-sig' also strips a BOM.
    with open(path, 'r', encoding='utf-8-sig') as file:
        text = file.read()

    try:
        records = json.loads(text)
    except json.JSONDecodeError:
        # Not a single JSON document -> treat as JSON Lines, skipping blank lines.
        records = [json.loads(line) for line in text.splitlines() if line.strip()]

    # A one-record file parses as a dict; iterating it would yield its keys.
    if isinstance(records, dict):
        records = [records]

    inputs = [record['input'].replace('\n', ' ') for record in records]
    outputs = [record['output'].replace('\n', ' ') for record in records]

    return pd.DataFrame({'i': inputs, 'o': outputs})

In [None]:
# Raw dialogue data: column 'i' holds the prompts, column 'o' the replies.
DATASET_PATH = 'dataset.jsonl'
df = ds_df(DATASET_PATH)

In [None]:
# Eyeball one training pair (row label 4): the prompt, a blank line, then the reply.
print(df.loc[4, 'i'] + '\n')
print(df.loc[4, 'o'])

In [None]:
# Pick the best available accelerator instead of hard-coding Apple's MPS
# backend. MPS keeps first priority, so Apple Silicon behaves as before; other
# machines fall back to CUDA and then to the CPU.
if torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
elif torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

In [None]:
# Checkpoint identifier passed to from_pretrained for both tokenizer and model.
model_name = "sberbank-ai/rugpt3medium_based_on_gpt2"

tokenizer = GPT2Tokenizer.from_pretrained(model_name)

# Load the weights, then move the module onto the selected device.
model = GPT2LMHeadModel.from_pretrained(model_name)
model = model.to(DEVICE)

In [None]:
# Markers used to build the training strings ('<bos>' prompt '<sep>' reply '<eos>')
# plus a padding token. After registering them, grow the embedding matrix so
# the newly assigned token ids have rows.
SPECIAL_TOKENS = dict(
    bos_token='<bos>',
    eos_token='<eos>',
    pad_token='<pad>',
    sep_token='<sep>',
)
tokenizer.add_special_tokens(SPECIAL_TOKENS)
model.resize_token_embeddings(len(tokenizer))

In [None]:
class DS(Dataset):
    """Causal-LM training dataset built from an ``i``/``o`` dialogue DataFrame.

    Each row becomes the string ``'<bos>' + i + '<sep>' + o + '<eos>'``, which
    is tokenized once at construction time and padded/truncated to exactly
    ``max_length`` tokens.
    """

    def __init__(self, data, tokenizer, max_length=150):
        """
        Args:
            data: DataFrame with string columns ``i`` (prompt) and ``o`` (reply).
            tokenizer: callable tokenizer that understands the ``<bos>``,
                ``<sep>`` and ``<eos>`` markers and returns ``input_ids`` and
                ``attention_mask``.
            max_length: fixed sequence length after padding/truncation.
        """
        self.tokenizer = tokenizer
        self.input_ids = []
        self.attn_masks = []

        # zip walks the rows in order and, unlike label lookups, also works
        # when the index has duplicate labels.
        for inp, out in zip(data['i'], data['o']):
            # NOTE(review): samples longer than max_length are cut from the
            # right, which drops the trailing '<eos>', so for long samples
            # the model never sees an end marker.
            encodings_dict = tokenizer('<bos>' + inp + '<sep>' + out + '<eos>',
                                       truncation=True,
                                       max_length=max_length,
                                       padding="max_length")

            self.input_ids.append(torch.tensor(encodings_dict['input_ids']))
            self.attn_masks.append(torch.tensor(encodings_dict['attention_mask']))

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        """Return one example keyed by the argument names the model's forward expects."""
        mask = self.attn_masks[idx]
        return {
            'input_ids': self.input_ids[idx],
            # Hugging Face models take ``attention_mask``. With Trainer's
            # default remove_unused_columns=True, a key that matches no
            # forward argument is dropped before collation, so the mask would
            # otherwise be lost.
            'attention_mask': mask,
            # Old key, kept so existing code that reads it keeps working.
            'attn_masks': mask,
        }

In [None]:
# Tokenize every dialogue pair once, up front.
train_dataset = DS(data=df, tokenizer=tokenizer)

In [None]:
# mlm=False selects causal language modelling (no random token masking).
data_collator = DataCollatorForLanguageModeling(mlm=False, tokenizer=tokenizer)

In [None]:
# Base directory for checkpoints and saved artifacts. Later cells build paths
# as f'{dir}...', so it must end with a separator.
# NOTE: this name shadows the built-in dir() for the rest of the notebook.
dir = "./"

In [None]:
# Fine-tuning schedule: 8 epochs, 3 samples per device, 100 warmup steps,
# a checkpoint every 3000 steps under f'{dir}Checkouts'.
training_args = TrainingArguments(
    output_dir=f'{dir}Checkouts',
    overwrite_output_dir=True,
    num_train_epochs=8,
    per_device_train_batch_size=3,
    per_device_eval_batch_size=3,
    warmup_steps=100,
    gradient_accumulation_steps=1,
    save_steps=3000,
)

# Custom optimizer with a small learning rate. Passing None as the scheduler
# leaves Trainer to create its default one from training_args.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    optimizers=(optimizer, None),
)

In [None]:
# Run the fine-tuning loop configured above; the returned summary is shown as
# this cell's output.
trainer.train()

In [None]:
# Persist the fine-tuned weights and the tokenizer for the reload cell below.
trainer.save_model(f'{dir}model_with_summary')
# save_pretrained creates the directory if needed. It also writes the added
# special tokens (<bos>, <sep>, <eos>, <pad>) and the tokenizer config, so the
# reload gets the same vocabulary and token ids.
# save_vocabulary, by contrast, writes only vocab/merges files, and only into
# an already existing directory.
tokenizer.save_pretrained(f'{dir}tokenizer')

In [None]:
# Reload the fine-tuned artifacts from disk. These fresh objects replace the
# in-memory ones; the model is moved back onto the selected device.
tokenizer = GPT2Tokenizer.from_pretrained(dir + 'tokenizer')
model = GPT2LMHeadModel.from_pretrained(dir + 'model_with_summary')
model = model.to(DEVICE)

In [None]:
# Re-register the training-time special tokens on the reloaded tokenizer.
# The mapping must stay identical to the one used before fine-tuning.
SPECIAL_TOKENS = dict(
    bos_token='<bos>',
    eos_token='<eos>',
    pad_token='<pad>',
    sep_token='<sep>',
)
tokenizer.add_special_tokens(SPECIAL_TOKENS)

In [None]:
class KeywordsStoppingCriteria(StoppingCriteria):
    """Stop generation once the newest token of the first sequence is a keyword.

    Only row 0 of ``input_ids`` is inspected. With beam search or batched
    generation, the decision therefore follows whichever sequence is in row 0.
    """

    def __init__(self, keywords_ids: list):
        """
        Args:
            keywords_ids: token ids that end generation. Accepts a list of ints
                or a tensor of any shape, e.g. the 2-D tensor produced by
                ``tokenizer.encode(..., return_tensors="pt")``.
        """
        # Flatten to a set of plain ints: constant-time membership and no
        # device/shape subtleties when comparing against generated tokens.
        self.keywords = set(torch.as_tensor(keywords_ids).flatten().tolist())

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when the last token of ``input_ids[0]`` is one of the keywords."""
        return int(input_ids[0, -1]) in self.keywords

In [None]:
# Token id(s) of the '<eos>' marker, placed on DEVICE. Generation is stopped
# once the first sequence emits one of them.
eos_ids = tokenizer.encode(tokenizer.eos_token, return_tensors="pt").to(DEVICE)
stop_criteria = KeywordsStoppingCriteria(eos_ids)

In [None]:
# One-shot test prompt: the instruction, one user turn, then the bot cue.
# NOTE(review): training strings were built as '<bos>' + input + '<sep>' + output.
# This prompt has no '<bos>' and puts spaces around '<sep>'; the trailing space
# likely becomes a token the model never saw after '<sep>' during training.
# Confirm against the dataset's 'input' format.
inp = 'Продолжи диалог: Собеседник: Привет, чем ты сегодня занимался? Ты: <sep> '

In [None]:
# Tokenize the prompt into a batch of one id sequence on the model's device.
input_ids = tokenizer(inp, return_tensors="pt")["input_ids"].to(DEVICE)

In [None]:
# Beam sampling for the one-shot reply. max_length bounds the total sequence,
# prompt included.
generation_kwargs = dict(
    do_sample=True,
    num_beams=3,
    temperature=2.0,
    top_p=0.9,
    max_length=100,
    stopping_criteria=StoppingCriteriaList([stop_criteria]),
    eos_token_id=tokenizer.eos_token_id,
    bos_token_id=tokenizer.bos_token_id,
)

with torch.no_grad():
    out = model.generate(input_ids, **generation_kwargs).to(DEVICE)

decoded = tokenizer.batch_decode(out, skip_special_tokens=False)
print(decoded[0])

In [None]:
# Keep only the newly generated tokens. Slicing the decoded text by len(inp)
# assumes decode(encode(inp)) == inp, which is not guaranteed around special
# tokens such as '<sep>' (decoding may re-insert spaces). So slice the token
# ids at the prompt length instead. Special tokens such as '<eos>' are dropped
# from the reply.
prompt_len = input_ids.shape[-1]
s = tokenizer.decode(out[0, prompt_len:], skip_special_tokens=True)

# The model may continue into the user's next turn; cut the reply there.
i = s.find("Собеседник:")
if i != -1:
    s = s[:i]

s = s.strip()
print(s)

In [None]:
import time


def _build_chat_prompt(promt, dialog):
    """Assemble the generation prompt for the current dialog.

    The optional instruction, the 'Продолжи диалог:' header, every dialog turn
    and the bot cue are joined with single spaces. Training samples are
    single-line, space-separated text, so the parts must not be glued together.
    An empty ``promt`` is skipped so the prompt does not start with a stray
    space.

    Args:
        promt: optional instruction prefix (may be ``''``).
        dialog: list of ``'Собеседник: ...'`` / ``'Ты: ...'`` turn strings.

    Returns:
        The prompt string, ending with the bot cue ``'Ты: <sep>'``.
    """
    parts = [promt, 'Продолжи диалог:', *dialog, 'Ты: <sep>']
    return ' '.join(part for part in parts if part)


def _extract_reply(new_token_ids):
    """Decode freshly generated token ids into the bot's reply text.

    Special tokens such as ``<eos>`` are dropped. If the model carried on into
    the user's next turn, the text is cut at ``'Собеседник:'``. Surrounding
    whitespace is stripped.
    """
    reply = tokenizer.decode(new_token_ids, skip_special_tokens=True)
    cut = reply.find("Собеседник:")
    if cut != -1:
        reply = reply[:cut]
    return reply.strip()


def chat(promt):
    """Run an interactive console chat with the fine-tuned model.

    Each outer iteration starts a new dialog with an empty history. Within a
    dialog, an empty message starts a new dialog and ``'й'`` quits entirely.

    Args:
        promt: optional instruction text prepended to every prompt
            (``''`` for none).
    """
    while True:
        print('-' * 80)
        dialog = []
        while True:
            # The first input() always runs, so msg is bound for the check
            # after this loop.
            msg = input('Сообщение:').strip()
            if not msg or msg == 'й':
                break
            msg = msg[0].upper() + msg[1:]
            dialog.append('Собеседник: ' + msg)
            inp = _build_chat_prompt(promt, dialog)

            input_ids = tokenizer.encode(inp, return_tensors="pt").to(DEVICE)

            # NOTE(review): max_length counts prompt tokens too, so once the
            # dialog history nears 400 tokens there is no room left for a
            # reply. Consider trimming old turns or using max_new_tokens.
            with torch.no_grad():
                out = model.generate(input_ids,
                                     do_sample=True,
                                     num_beams=3,
                                     temperature=2.0,
                                     top_p=0.9,
                                     max_length=400,
                                     stopping_criteria=StoppingCriteriaList([stop_criteria]),
                                     eos_token_id=tokenizer.eos_token_id,
                                     bos_token_id=tokenizer.bos_token_id,
                                     )

            # Decode only the tokens generated after the prompt. Slicing the
            # decoded text by len(inp) is fragile because decoding is not
            # guaranteed to reproduce the prompt's spacing around special tokens.
            s = _extract_reply(out[0, input_ids.shape[-1]:])

            print(msg)
            print('Бот:> {}'.format(s))
            dialog.append('Ты: ' + s)
            time.sleep(2)

        if msg == 'й':
            break

In [None]:
# Start the interactive session with no extra instruction text.
chat('')

In [None]:
# Bare expression: the notebook displays the model's repr (module structure).
model