In [None]:
import json
import os
import re
import sys

import torch
import tqdm
from accelerate import Accelerator
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
# NOTE: this import must stay after `import tqdm`; it rebinds the name `tqdm` to the progress-bar class
from tqdm import tqdm
from transformers import TextDataset, DataCollatorForLanguageModeling
from transformers import get_scheduler, AdamW, AutoModelForSequenceClassification


In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the pretrained weights and the matching tokenizer identified by `checkpoint`
checkpoint = "Kirili4ik/ruDialoGpt3-medium-finetuned-telegram"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint)

In [None]:
def get_user_param(text: dict, machine_name_in_chat: str) -> str:
    """Return the speaker flag for a single chat message.

    Args:
        text: Message record; its ``'from'`` entry is compared with the machine name.
        machine_name_in_chat: Sender name that identifies the machine side of the dialogue.

    Returns:
        ``'1'`` if the message was sent by the machine, ``'0'`` (human) otherwise.
    """
    sent_by_machine = text['from'] == machine_name_in_chat
    return '1' if sent_by_machine else '0'


def load_dataset(train_path, test_path, tokenizer, block_size=256):
    """Build the train/test datasets and the collate function for language modeling.

    Args:
        train_path: Path of the text file holding the training data.
        test_path: Path of the text file holding the test data.
        tokenizer: Tokenizer handed to both datasets and to the collator.
        block_size: Block size passed to ``TextDataset`` for both files.
            Defaults to 256, the value that was previously hard-coded.

    Returns:
        Tuple ``(train_dataset, test_dataset, data_collator)``.
    """
    def _make_dataset(file_path):
        # Both splits are built with identical settings; only the file differs
        return TextDataset(
            tokenizer=tokenizer,
            file_path=file_path,
            block_size=block_size,
        )

    train_dataset = _make_dataset(train_path)
    test_dataset = _make_dataset(test_path)

    # mlm=False: the collator is configured without masked-language-modeling masking
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    return train_dataset, test_dataset, data_collator

def get_length_param(text: str, tokenizer) -> str:
    """Classify ``text`` into a length bucket based on its encoded token count.

    Args:
        text: The phrase to measure.
        tokenizer: Object whose ``encode(text)`` result is measured with ``len``.

    Returns:
        ``'1'`` for at most 15 tokens, ``'2'`` for at most 50, ``'3'`` for at
        most 256, and ``'-'`` for anything longer.
    """
    n_tokens = len(tokenizer.encode(text))
    # Buckets are listed in ascending order; the first matching upper bound wins
    for upper_bound, bucket in ((15, '1'), (50, '2'), (256, '3')):
        if n_tokens <= upper_bound:
            return bucket
    return '-'

def build_text_file(data_json: list, dest_path: str,
                    tokenizer, machine_name_in_chat='Дима'):
    """Write the chat messages to ``dest_path`` in the ruDialoGPT-3 training format.

    Each written line has the form ``|<speaker>|<length>|<text><eos>``, where
    ``<length>`` is the length bucket of the *following* message and newlines
    inside ``<text>`` are replaced by ``". "``. A message is skipped when it, or
    the message after it, has an empty or non-string ``'text'``. The last
    message is never written because it has no successor.

    Args:
        data_json: Sequence of message records, each with a ``'text'`` entry.
        dest_path: Output file path; the file is overwritten (UTF-8).
        tokenizer: Tokenizer used for length bucketing and for its ``eos_token``.
        machine_name_in_chat: Currently unused, kept for interface compatibility.

    NOTE(review): the speaker flag is derived from the message index parity
    ('1' for even, '0' for odd), not from the sender name, so
    ``machine_name_in_chat`` has no effect. This is only correct for strictly
    alternating dialogues -- confirm whether ``get_user_param`` was intended.
    """
    def _is_plain_text(value):
        # Only non-empty plain strings count as usable message text
        return type(value) == str and value != ''

    lines = []
    for i in range(len(data_json) - 1):
        message, next_message = data_json[i], data_json[i + 1]
        if not _is_plain_text(message['text']) or not _is_plain_text(next_message['text']):
            continue

        user = '1' if i % 2 == 0 else '0'
        length = get_length_param(next_message['text'], tokenizer)
        message_text = message['text'].replace("\n", ". ")
        lines.append(f"|{user}|{length}|{message_text}{tokenizer.eos_token}\n")

    # Collect first, write once: avoids quadratic string concatenation, and the
    # context manager guarantees the file is flushed and closed.
    with open(dest_path, 'w', encoding='utf-8') as f:
        f.write(''.join(lines))


In [None]:
# Location of the JSON file that holds the chat history
PATH = 'data/result.json'

# Parse the file and keep only its 'messages' entry
with open(PATH, encoding='utf-8') as f:
    messages = json.load(f)['messages']

In [None]:
# Hold out the first 10% of the messages for testing and train on the rest
n_test = int(len(messages) * 0.1)
test = messages[:n_test]
train = messages[n_test:]

# Write both splits in the training text format
build_text_file(train, 'data/train_dataset.txt', tokenizer)
build_text_file(test, 'data/test_dataset.txt', tokenizer)

Each line of the generated files has the form `|speaker|length|text`. The first number is the speaker: '1' for GPT and '0' for the person.

The second number is the length of the expected answer, measured in tokens: '1' for short texts (up to 15 tokens), '2' for medium (up to 50), '3' for long (up to 256) and '-' for anything longer.


In [None]:
# Build the datasets and the collate function from the generated text files
train_dataset, test_dataset, data_collator = load_dataset(
    'data/train_dataset.txt',
    'data/test_dataset.txt',
    tokenizer,
)

# One example per batch; only the training loader is shuffled
train_loader = DataLoader(
    train_dataset,
    batch_size=1,
    shuffle=True,
    collate_fn=data_collator,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=1,
    collate_fn=data_collator,
)

In [None]:
# Sanity check: push a single training batch through the model before training.
# Any error is allowed to propagate with its full traceback instead of being
# caught by a bare `except:` that only printed the exception type and re-raised.
batch = next(iter(train_loader))

# Show the shape of every tensor in the batch (previously computed but discarded)
print({name: tuple(tensor.shape) for name, tensor in batch.items()})

# No gradients are needed for a smoke test, so skip building the autograd graph
with torch.no_grad():
    outputs = model(**batch)

print("loss:", outputs.loss.item())

## Fine-tuning

In [None]:
num_epochs = 3 #@param {type:"integer"}
# NOTE(review): `transformers.AdamW` is deprecated in favour of `torch.optim.AdamW`;
# the defaults differ, so confirm against the installed version before switching.
optimizer = AdamW(model.parameters(), lr=3e-5) #@param
save_checkpoint_path = 'models/model_GPT.pt' #@param {type:"string"}


# Let the accelerator place the data loaders, model and optimizer first, so that
# everything below is derived from the objects actually used for training.
accelerator = Accelerator()
train_dl, test_dl, model, optimizer = accelerator.prepare(
    train_loader, test_loader, model, optimizer
)

# One optimizer step is taken per *batch*, so count batches of the prepared
# loader rather than dataset items (the two only coincide when batch_size == 1).
num_training_steps = num_epochs * len(train_dl)

# Linear schedule: 100 warmup steps, then linear decay over the remaining steps
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=100,
    num_training_steps=num_training_steps
)

In [None]:
# `tqdm` is bound to the progress-bar class by `from tqdm import tqdm` (which
# shadows the earlier `import tqdm`), so it must be called directly;
# `tqdm.tqdm(...)` raises AttributeError.
progress_bar = tqdm(range(num_training_steps))

# Create the checkpoint directory up front so an epoch of training is not lost
# to a failing save.
checkpoint_dir = os.path.dirname(save_checkpoint_path)
if checkpoint_dir:
    os.makedirs(checkpoint_dir, exist_ok=True)

for epoch in range(num_epochs):

    # Train for one epoch: one optimizer and scheduler step per batch
    model.train()
    for batch in train_dl:
        optimizer.zero_grad()
        outputs = model(**batch)
        loss = outputs.loss
        accelerator.backward(loss)

        optimizer.step()
        lr_scheduler.step()
        progress_bar.update(1)

    # Save after every epoch (overwrites the previous checkpoint). The model is
    # unwrapped so the state-dict keys match a plain model when it is loaded
    # for inference, and only the main process writes the file.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        torch.save({
            'model_state_dict': accelerator.unwrap_model(model).state_dict(),
        }, save_checkpoint_path)

# Optional per-epoch validation (disabled; belongs inside the epoch loop when enabled)
#     ### VALIDATE ONCE
#     cum_loss = 0
#     model.eval()
#     with torch.inference_mode():
#         for batch in test_dl:
#             outputs = model(**batch)
#             cum_loss += float(outputs.loss.item())

#     print(cum_loss/len(test_loader))
#     # wandb.log({'val_mean_loss':cum_loss/len(test_loader)})

## Inference

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Rebuild the tokenizer and the base model from the pretrained checkpoint
checkpoint = "Kirili4ik/ruDialoGpt3-medium-finetuned-telegram"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint)

# Replace the weights with the fine-tuned ones written by the training loop.
# A separate name is used so `checkpoint` keeps meaning "checkpoint id string".
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source (consider `weights_only=True` if the installed torch supports it).
finetuned_state = torch.load('models/model_GPT.pt', map_location='cpu')
model.load_state_dict(finetuned_state['model_state_dict'])

# Run inference on CPU in eval mode; ending the cell with an assignment keeps
# the model repr out of the cell output without a dummy print().
model = model.to('cpu').eval()

In [None]:
# Interactive chat: the whole conversation is kept as one row of token ids that
# grows with every human phrase and every generated reply.
# Token ids are int64, so the empty history uses the same dtype.
chat_history_ids = torch.zeros((1, 0), dtype=torch.long)

print("Enter 'H' for a human phrase, 'G' for a generated reply, 'Q' to quit.")

while True:
    next_who = input("Who's phrase?\t")  # 'H' = human, 'G' = chatbot, 'Q' = quit

    if next_who == "Q":
        # Explicit way out; previously the loop could only be interrupted
        break

    if next_who == "H":
        input_user = input("===> Human: ")
        # Human turns are prefixed with speaker flag 0 and their own length bucket
        new_user_input_ids = tokenizer.encode(
            f"|0|{get_length_param(input_user, tokenizer)}|" + input_user + tokenizer.eos_token,
            return_tensors="pt",
        )
        chat_history_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1)

    elif next_who == "G":
        # The user picks the length bucket of the reply to generate
        next_len = input("Phrase len? 1/2/3/-\t")
        if next_len not in ("1", "2", "3", "-"):
            # Anything else would put a length code into the prompt that the
            # training format never contained
            print("Unknown length code; expected 1, 2, 3 or -.")
            continue

        # Open a machine turn (speaker flag 1) with the requested length code
        new_user_input_ids = tokenizer.encode(f"|1|{next_len}|", return_tensors="pt")
        chat_history_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1)

        # Remember where the prompt ends so only the new tokens are printed
        input_len = chat_history_ids.shape[-1]

        # NOTE(review): max_length counts the prompt too, so a history of 512+
        # tokens leaves no room for a reply -- consider truncating the history.
        chat_history_ids = model.generate(
            chat_history_ids,
            num_return_sequences=1,
            max_length=512,
            no_repeat_ngram_size=3,
            do_sample=True,
            top_k=50,
            top_p=0.9,
            temperature=0.6,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )

        print(f"===> GPT-3:  {tokenizer.decode(chat_history_ids[:, input_len:][0], skip_special_tokens=True)}")

    else:
        print("Unknown choice; enter 'H', 'G' or 'Q'.")