In [None]:
!CUDA_VISIBLE_DEVICES="1" python3 -m torch.distributed.launch --nproc_per_node 2 ./train.py --model="gpt2-large" --fp16='O3'

In [6]:
import os
import pandas as pd
df = pd.read_csv('messages.csv')

In [8]:
import random

BST_DATA_DIR = "ParlAI/data/blended_skill_talk/"
MSC_DATA_DIR = "ParlAI/data/msc/msc/msc_dialogue/"
MAX_DIALOG_TOKENS_LEN = 256
def separate_profile_dialog(persona_dialog):
    profile = [u for u in persona_dialog if u.startswith("Bot profile:")]
    dialog = [u for u in persona_dialog if u.startswith("Person:") or u.startswith("Bot:")]

    return profile, dialog


def get_processed_dialog(dialog, bot_label, person_label, profile):
    processed_dialog = []
    for utterance in dialog:
        utterance = utterance.replace(bot_label, "Bot")
        utterance = utterance.replace(person_label, "Person")
        processed_dialog.append(utterance)

    persona = [f"Bot profile: {p}" for p in profile]

    return {"persona": persona, "dialog": processed_dialog}


def prepare_person_bot_dialog(dialog):
    prepared_dialog = []
    user1, user2 = ("Person", "Bot") if random.randint(0, 1) else ("Bot", "Person")
    for i, utterance in enumerate(dialog):
        user = user1 if i % 2 == 0 else user2
        prepared_dialog.append(f"{user}: {utterance}")

    return prepared_dialog


In [9]:
from datasets import load_dataset
from loguru import logger

def _normalize_dialog(raw_dialog):
    dialog = [d.strip() for d in raw_dialog]

    for ch in "’?!.,":
        dialog = [d.replace(f" {ch}", ch) for d in dialog]

    dialog = [d.replace(".", ". ") for d in dialog]
    dialog = [d.replace("  ", " ") for d in dialog]
    dialog = [d.strip() for d in dialog]

    dialog = [d.replace("’", "'") for d in dialog]
    dialog = [d.replace("' s", "'s") for d in dialog]
    dialog = [d.replace("' ll", "'ll") for d in dialog]
    dialog = [d.replace("' re", "'re") for d in dialog]
    dialog = [d.replace("' t", "'t") for d in dialog]
    dialog = [d.replace("' m", "'m") for d in dialog]

    return dialog


def prepare_daily_dialog():
    daily_dialog = load_dataset("daily_dialog")

    samples = []
    for data_type in daily_dialog:
        for x in daily_dialog[data_type]:
            dialog = _normalize_dialog(x["dialog"])
            prepared_dialog = prepare_person_bot_dialog(dialog)
            samples.append(prepared_dialog)

    logger.info(f"Prepared {len(samples)} dialogs of daily_dialog dataset")

    return samples


In [None]:
import re
from collections import defaultdict

def _filter_dialogs(dialogs):
    processed_dialogs = []
    for dialog in dialogs:
        has_bad_utterance = False
        for utterance in dialog:
            # for some reason it could be broken records with meta info in utterances. Skip those dialogs
            if "_conv" in utterance:
                has_bad_utterance = True

        if not has_bad_utterance:
            processed_dialogs.append(dialog)

    return processed_dialogs


def prepare_empathetic_dialogues():
    dataset = load_dataset("empathetic_dialogues")
    dialog_id_to_dialog = defaultdict(list)
    for data_type in dataset:
        for sample in dataset[data_type]:
            utterance = sample["utterance"].replace("_comma_", ",").strip()
            utterance = re.sub(" +", " ", utterance)
            dialog_id_to_dialog[sample["conv_id"]].append(utterance)

    dialogs = list(dialog_id_to_dialog.values())
    filtered_dialogs = _filter_dialogs(dialogs)

    prepared_dialogs = [prepare_person_bot_dialog(d) for d in filtered_dialogs]
    logger.info(f"Prepared {len(prepared_dialogs)} dialogs of empathetic_dialogues dataset")

    return prepared_dialogs


In [11]:


def prepare_bst(data_dir=BST_DATA_DIR):
    samples = []
    for dataset_type in ["train", "test", "valid"]:
        print(os.listdir(data_dir))
        #print(data_dir)
        data_path = os.path.join(data_dir, f"{dataset_type}.json")
        
        with open(data_path, encoding="utf-8") as f:
            data = json.load(f)
            for row in data:
                person0_profile = row["personas"][0]
                person1_profile = row["personas"][1]

                conv_starter = [
                    f"person 0: {row['free_turker_utterance']}",
                    f"person 1: {row['guided_turker_utterance']}",
                ]
                dialog = [f"person {u[0]}: {u[1]}" for u in row["dialog"]]
                full_dialog = conv_starter + dialog

                processed_dialog1 = get_processed_dialog(
                    full_dialog, bot_label="person 0", person_label="person 1", profile=person0_profile
                )
                samples.append(processed_dialog1)

                processed_dialog2 = get_processed_dialog(
                    full_dialog, bot_label="person 1", person_label="person 0", profile=person1_profile
                )
                samples.append(processed_dialog2)

    logger.info(f"Prepared {len(samples)} dialogs of BST dataset")

    return samples


In [12]:
import json
import jsonlines

def _generate_personalized_dialogs(raw_dialog):
    """
    :param raw_dialog: {'personas': [[str], [str]], dialog: [{'text': str, 'id': str}]}
    :return: list containing two personalized dialogs: first dialog with persona1 facts, second dialog with persona2 facts
    """
    persona_facts_1 = raw_dialog["personas"][0]
    persona_facts_1 = [text.strip() for text in persona_facts_1]

    persona_facts_2 = raw_dialog["personas"][1]
    persona_facts_2 = [text.strip() for text in persona_facts_2]

    dialog = [f"{u['id']}: {u['text'].strip()}" for u in raw_dialog["dialog"]]

    personalized_dialog1 = get_processed_dialog(
        dialog, bot_label="Speaker 1", person_label="Speaker 2", profile=persona_facts_1
    )

    personalized_dialog2 = get_processed_dialog(
        dialog, bot_label="Speaker 2", person_label="Speaker 1", profile=persona_facts_2
    )

    return [personalized_dialog1, personalized_dialog2]


def prepare_msc(sessions_data_dir=MSC_DATA_DIR):
    dialogs = []

    for session_dir in ["session_2", "session_3", "session_4", "session_5"]:
        for dataset_type in ["train", "test", "valid"]:
            data_path = os.path.join(sessions_data_dir, session_dir, f"{dataset_type}.txt")
            
            if not os.path.isfile(data_path):
                continue

            with jsonlines.open(data_path) as reader:
                for raw_dialog in reader:
                    dialogs += _generate_personalized_dialogs(raw_dialog)

    init_dialogs_path = os.path.join(sessions_data_dir, "init_persona_all.json")
    
    print(sessions_data_dir, os.listdir(sessions_data_dir))
    
    with open(init_dialogs_path, encoding="utf-8") as f:
        
        data = json.load(f)
        for key, raw_dialog in data.items():
            if "train" not in key:
                continue

            raw_dialog["dialog"] = [
                {"text": line["text"], "id": f"Speaker {1 + i % 2}"} for i, line in enumerate(raw_dialog["dialog"])
            ]
            dialogs += _generate_personalized_dialogs(raw_dialog)

    logger.info(f"Prepared {len(dialogs)} dialogs of MSC dataset")

    return dialogs


In [13]:
def prepare_persona_chat():
    dataset = load_dataset("bavard/personachat_truecased")
    samples = []
    for data_type in dataset:
        for sample in dataset[data_type]:
            persona = [f"Bot profile: {p}" for p in sample["personality"]]
            history = sample["history"] + [sample["candidates"][-1]]

            prepared_history = []
            user1, user2 = ("Person", "Bot")
            for i, utterance in enumerate(history):
                user = user1 if i % 2 == 0 else user2
                prepared_history.append(f"{user}: {utterance}")

            dialog = [u for u in prepared_history if "SILENCE" not in u]

            if len(dialog) >= 2:
                samples.append({"persona": persona, "dialog": dialog})

    logger.info(f"Prepared {len(samples)} dialogs of persona_chat dataset")

    return samples


In [14]:
import random

def separate_profile_dialog(persona_dialog):
    profile = [u for u in persona_dialog if u.startswith("Bot profile:")]
    dialog = [u for u in persona_dialog if u.startswith("Person:") or u.startswith("Bot:")]

    return profile, dialog


def get_processed_dialog(dialog, bot_label, person_label, profile):
    processed_dialog = []
    for utterance in dialog:
        utterance = utterance.replace(bot_label, "Bot")
        utterance = utterance.replace(person_label, "Person")
        processed_dialog.append(utterance)

    persona = [f"Bot profile: {p}" for p in profile]

    return {"persona": persona, "dialog": processed_dialog}


def prepare_person_bot_dialog(dialog):
    prepared_dialog = []
    user1, user2 = ("Person", "Bot") if random.randint(0, 1) else ("Bot", "Person")
    for i, utterance in enumerate(dialog):
        user = user1 if i % 2 == 0 else user2
        prepared_dialog.append(f"{user}: {utterance}")

    return prepared_dialog


In [15]:
from nltk import ngrams
from tqdm import tqdm

def _encode_dialogs(dialogs, tokenizer):
    sep = tokenizer.eos_token
    dialogs_strs = [sep.join(d) + sep for d in dialogs]
    tokenized_samples = tokenizer.batch_encode_plus(dialogs_strs)

    return tokenized_samples["input_ids"]


def read_personalized_dialogs_splits(tokenizer, max_tokens_len=MAX_DIALOG_TOKENS_LEN):
    personalized_dialogs = prepare_bst() + prepare_msc() + prepare_persona_chat()
    logger.info(f"Personalized dialogs: {len(personalized_dialogs)}")

    dialogs_splits = set()
    for obj in tqdm(personalized_dialogs):
        persona = obj["persona"]
        dialog = obj["dialog"]

        # generate all possible subdialogs and filter them by tokens length
        dialog_splits = []
        for n in range(2, len(dialog) + 1):
            dialog_splits += ngrams(dialog, n)

        dialog_splits = [tuple(persona + list(d)) for d in dialog_splits]
        encoded_dialog_splits = _encode_dialogs(dialog_splits, tokenizer)

        for split, encoded_split in zip(dialog_splits, encoded_dialog_splits):
            if len(encoded_split) > max_tokens_len:
                continue
            dialogs_splits.add(split)

    logger.info(f"Personalized dialogs splits: {len(dialogs_splits)}")

    return list(dialogs_splits)


def filter_personalized_dialogs(dialogs):
    """Leave the longest dialogs with unique dialog prefix."""
    prefix_to_dialog = {}

    for persona_dialog in dialogs:
        profile, dialog = separate_profile_dialog(persona_dialog)
        prefix = tuple(profile + dialog[:2])

        if len(persona_dialog) > len(prefix_to_dialog.get(prefix, [])):
            prefix_to_dialog[prefix] = persona_dialog

    filtered_dialogs = list(prefix_to_dialog.values())
    logger.info(f"Filtered personalized dialogs: {len(filtered_dialogs)}")

    return filtered_dialogs
def _encode_dialog(dialog, tokenizer):
        sep = tokenizer.eos_token
        dialog_str = sep.join(dialog) + sep
        encoded_sample = tokenizer.encode(dialog_str)

        return encoded_sample

In [28]:
import pickle

import torch
from torch.utils.data import Dataset


class ConversationDataset(Dataset):
    def __init__(self, args, tokenizer, max_tokens_len=MAX_DIALOG_TOKENS_LEN):
        cached_features_file = os.path.join(args.cache_dir, f"cached_features_{max_tokens_len}.pickle")
        if os.path.exists(cached_features_file) and not args.overwrite_cache:
            logger.info(f"Loading features from cached file {cached_features_file}")
            with open(cached_features_file, "rb") as fh:
                self.examples = pickle.load(fh)
        else:
            logger.info(f"Creating features from dataset file at {args.cache_dir}")
            os.makedirs(args.cache_dir, exist_ok=True)

            dialogs = self._read_dialogs(tokenizer, max_tokens_len)

            self.examples = []
            for dialog in tqdm(dialogs):
                if len(dialog) < 2:
                    continue

                encoded_dialog = _encode_dialog(dialog, tokenizer)
                if len(encoded_dialog) <= max_tokens_len:
                    self.examples.append(encoded_dialog)

            logger.info(f"{len(self.examples)} dialogs left after tokenization")

            logger.info(f"Saving features into cached file {cached_features_file}")
            with open(cached_features_file, "wb") as fh:
                pickle.dump(self.examples, fh, protocol=pickle.HIGHEST_PROTOCOL)

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, item):
        return torch.tensor(self.examples[item], dtype=torch.long)

    @staticmethod
    def _read_dialogs(tokenizer, max_tokens_len, common_dialogs_weight=5):
        personalized_dialogs_splits = read_personalized_dialogs_splits(tokenizer, max_tokens_len)
        personalized_dialogs = filter_personalized_dialogs(personalized_dialogs_splits)

        common_dialogs = prepare_daily_dialog() + prepare_empathetic_dialogues()
        logger.info(f"Common dialogs: {len(common_dialogs)}")

       

        dialogs = personalized_dialogs + common_dialogs * common_dialogs_weight
        logger.info(f"Total dialogs: {len(dialogs)}")

        random.shuffle(dialogs)

        return dialogs

    @staticmethod
    def _encode_dialog(dialog, tokenizer):
        sep = tokenizer.eos_token
        dialog_str = sep.join(dialog) + sep
        encoded_sample = tokenizer.encode(dialog_str)

        return encoded_sample


In [29]:
TEST_DATA_PATH = 'test_data.jsonl'
TEST_CANDS_NUM = 5

DEVICE = "cuda:0"
MODEL_DIR = "./"

TOP_P = 0.9
TOP_K = 40
TEMPERATURE = 1
MAX_RESPONSE_LEN = 24

In [30]:

from apex import amp
import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from transformers import AdamW, get_linear_schedule_with_warmup
from transformers import GPT2Tokenizer, TrainingArguments, Trainer, GPTNeoForCausalLM
import wandb

wandb.init(project="gpt_neo_training_2021", entity="netbug")

def decode_responses(responses_ids, tokenizer):
    responses_ids = [r for r in responses_ids if r[-1] == tokenizer.eos_token_id]
    responses = [tokenizer.decode(r, skip_special_tokens=True) for r in responses_ids]
    responses = [r.strip() for r in responses if r]

    return responses

def _read_test_contexts(test_data_path='test_data.jsonl'):
    with jsonlines.open(test_data_path) as reader:
        return [dialog for dialog in reader]


def _save_checkpoint(args, model, tokenizer,it):
    output_dir = os.path.join(args.output_dir, f"checkpoint_{it}")
    os.makedirs(output_dir, exist_ok=True)
    logger.info(f"Saving model checkpoint to {output_dir}")

    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    torch.save(args, os.path.join(output_dir, "training_args.bin"))


def _generate_gpt_candidates(
    context,
    cands_num,
    model,
    tokenizer,
    device,
    top_p=TOP_P,
    top_k=TOP_K,
    temperature=TEMPERATURE,
    max_response_len=MAX_RESPONSE_LEN,
):
    sep = tokenizer.eos_token
    context_str = sep.join(context) + sep
    encoded_context = tokenizer.encode(context_str, return_tensors="pt").to(device)
    context_len = encoded_context.shape[-1]

    responses_ids = model.generate(
        encoded_context,
        do_sample=True,
        max_length=context_len + max_response_len,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
       
        num_return_sequences=cands_num,
        pad_token_id=tokenizer.eos_token_id,
        return_dict=True
    )
    responses_ids = [r[context_len:] for r in responses_ids]

    return decode_responses(responses_ids, tokenizer)


def evaluate(args, model, tokenizer, val_dataset, test_contexts=None, cands_num=2):
    def collate(examples):
        return pad_sequence(examples, batch_first=True, padding_value=tokenizer.eos_token_id)

    if test_contexts:
        logger.info(f"Generate candidates for {len(test_contexts)} test contexts")
        for context in test_contexts:
            candidates = _generate_gpt_candidates(context, cands_num, model, tokenizer, args.device)
            for cand in candidates:
                logger.info(f"{context} -> {cand}")

    logger.info(f"*** Running evaluation on {len(val_dataset)} examples, eval batch size = {args.eval_batch_size} ***")

    dataloader = DataLoader(val_dataset, batch_size=args.eval_batch_size, collate_fn=collate, drop_last=True)
    eval_loss = 0.0
    model.eval()

    for inputs in dataloader:
        inputs = inputs.to(args.device)
        with torch.no_grad():
            outputs = model(inputs, labels=inputs,return_dict=True)
        lm_loss = outputs[0]
        eval_loss += lm_loss.mean().item()

    eval_loss = eval_loss / len(dataloader)
    perplexity = torch.exp(torch.tensor(eval_loss))

    logger.info(f"Model perplexity on val dataset = {perplexity}")

    return perplexity


def train(args, train_dataset, val_dataset, model, tokenizer):
    def collate(examples):
        return pad_sequence(examples, batch_first=True, padding_value=tokenizer.eos_token_id)

    train_dataloader = DataLoader(
        train_dataset, shuffle=True, batch_size=args.train_batch_size, collate_fn=collate, drop_last=True
    )
    
    training_steps = args.num_train_epochs * len(train_dataloader) // args.gradient_accumulation_steps

    optimizer = AdamW(params=model.parameters(), 
                      lr=args.learning_rate, eps=args.adam_epsilon)
    
    
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=training_steps
    )
    
    model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
    #if args.n_gpu > 1:
    #    model = torch.nn.DataParallel(model).to('cuda')#,device_ids=[0, 1]
   
    #model = torch.nn.DataParallel(model, device_ids=[0, 1])
    
    
    test_contexts = _read_test_contexts()
    #print(test_contexts[0])
    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Total train batch size  = {args.train_batch_size * args.gradient_accumulation_steps}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {training_steps}")

    best_ppl = np.inf
    backprop_steps = 0
    model.zero_grad()

    for epoch in range(args.num_train_epochs):
        logger.info(f"Epoch {epoch + 1}")
        
        #evaluate(args, model, tokenizer, val_dataset, test_contexts)
       
        epoch_iterator = tqdm(train_dataloader, desc="Iteration")
        for step, inputs in enumerate(epoch_iterator, start=1):
            inputs = inputs.to(args.device)
            model.train()
            outputs = model(inputs, labels=inputs)
            #print(outputs)
            loss = outputs['loss']
            
           
            
            #print(loss)
            if args.n_gpu > 1:
                loss = loss.mean()
                
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            
            
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
                
            #if step % args.eval_steps == 0:
             #       _save_checkpoint(args, model, tokenizer)
              #      logger.info("Saved")
                    
            if step % args.gradient_accumulation_steps == 0:
                logger.info("loss = {}".format(loss))
                wandb.log({"loss":  loss.item()})
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                
                optimizer.step()
                
                scheduler.step()
                model.zero_grad()

                backprop_steps += 1
                if step % args.eval_steps == 0:
                    _save_checkpoint(args, model, tokenizer,it=step)
                    logger.info(f"Saved at step {step}")
                    
                    #ppl = evaluate(args, model, tokenizer, val_dataset, test_contexts)
                    #strange bug produced
                    #wandb.log({"ppl":  ppl})
                    #if ppl < best_ppl:
                     #   logger.info("Best model found")
                      #  best_ppl = ppl
                        
                       # _save_checkpoint(args, model, tokenizer,it=f'best_on_{step}')

    return model


In [31]:
import numpy as np
import torch
from loguru import logger
from sklearn.model_selection import train_test_split
from transformers import GPT2TokenizerFast
from transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer, GPT2LMHeadModel, GPT2Tokenizer

class Args:
    def __init__(self):
        self.overwrite_cache = False
        self.tokenizer_name = "EleutherAI/gpt-neo-1.3B"
        self.model_name_or_path = 'EleutherAI/gpt-neo-1.3B'
        self.output_dir = "ssd/models_gpt/model_dir_xl_512"
        self.cache_dir = "cache_dir"
        self.eval_batch_size = 1
        self.train_batch_size = 1
        self.gradient_accumulation_steps = 1
        self.learning_rate = 5e-5
        self.warmup_steps = 0
        self.adam_epsilon = 1e-8
        self.max_grad_norm = 1.0
        self.num_train_epochs = 1
        self.eval_steps = 10000
        self.seed = 42
        self.fp16 = True
        self.fp16_opt_level = "O1"
        self.n_gpu = 0
        self.local_rank = 2
        self.device = torch.device(DEVICE)


def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

def freeze(
    model,
    freeze_emb=True,
    freeze_ln=False,
    freeze_attn=False,
    freeze_ff=True,
    freeze_other=True,
):
    
    for name, p in model.named_parameters():
    # freezing all parameters except the layernorm and positional embeddings

        name = name.lower()
        if 'ln' in name or 'norm' in name:
            p.requires_grad = not freeze_ln
        elif 'embeddings' in name:
            p.requires_grad = not freeze_emb
        elif 'mlp' in name:
            p.requires_grad = not freeze_ff
        elif 'attn' in name:
            p.requires_grad = not freeze_attn
        else:
            p.requires_grad = not freeze_other
           
    return model

def main():
    args = Args()
    
    args.device = torch.device(DEVICE)

    logger.info(f"Training/evaluation parameters {vars(args)}")
    set_seed(args.seed)

    if args.model_name_or_path.startswith('EleutherAI'):
        tokenizer = GPT2Tokenizer.from_pretrained(args.tokenizer_name, bos_token='<|endoftext|>',
                                              eos_token='<|endoftext|>', pad_token='<|pad|>')
        
        model = GPTNeoForCausalLM.from_pretrained(args.model_name_or_path,return_dict=True).to(args.device)
        model = freeze(model)
        
        
    else:
        tokenizer = GPT2Tokenizer.from_pretrained(args.tokenizer_name)
   
        model = GPT2LMHeadModel.from_pretrained(args.model_name_or_path).to(args.device)
    #model.to(args.device)

    dataset = ConversationDataset(args, tokenizer)
    #print(dataset.)
    train_dataset, val_dataset = train_test_split(dataset, test_size=500, random_state=42, shuffle=True)
    logger.info(f"Train dataset = {len(train_dataset)} samples; Val dataset = {len(val_dataset)} samples")

    model = train(args, train_dataset, val_dataset, model, tokenizer)

    evaluate(args, model, tokenizer, val_dataset)
    



In [None]:
if __name__ == "__main__":
    main()

In [None]:
model = GPTNeoForCausalLM.from_pretrained(args.model_name_or_path).to(args.device)

In [20]:
torch.cuda.empty_cache()

In [6]:
args = Args()

In [21]:
import torch.distributed.autograd as dist_autograd
from torch.nn.parallel import DistributedDataParallel as DDP
from torch import optim
from torch.distributed.optim import DistributedOptimizer
from torch.distributed.rpc import RRef

In [26]:
tokenizer = GPT2Tokenizer.from_pretrained(args.tokenizer_name)
model = GPTNeoForCausalLM.from_pretrained(args.model_name_or_path)
model = DDP(model, device_ids=2).to('cuda')

TypeError: object of type 'int' has no len()

In [2]:
import random
import json
import numpy as np
from nltk import ngrams
import torch
from loguru import logger
from sklearn.model_selection import train_test_split
from transformers import GPT2TokenizerFast
from transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer, GPT2LMHeadModel, GPT2Tokenizer
class Args:
    def __init__(self):
        self.overwrite_cache = False
        self.tokenizer_name = ""
        self.model_name_or_path = "microsoft/DialoGPT-large"
        self.output_dir = "model_dir1"
        self.cache_dir = "cache_dir"
        self.eval_batch_size = 2
        self.train_batch_size = 2
        self.gradient_accumulation_steps = 32
        self.learning_rate = 5e-5
        self.warmup_steps = 0
        self.adam_epsilon = 1e-8
        self.max_grad_norm = 1.0
        self.num_train_epochs = 1
        self.eval_steps = 10000
        self.seed = 42
        self.fp16 = True
        self.fp16_opt_level = "O3"
args = Args()

MODEL_DIR = '/mnt/swamp/models/GPTNeo_ckpt20'

tokenizer = GPT2Tokenizer.from_pretrained(MODEL_DIR)
model = GPT2LMHeadModel.from_pretrained(args.tokenizer_name)
model.to('cuda')

TOP_P = 0.9
TOP_K = 40
TEMPERATURE = 1
MAX_RESPONSE_LEN = 24

def decode_responses(responses_ids, tokenizer):
    responses_ids = [r for r in responses_ids if r[-1] == tokenizer.eos_token_id]
    responses = [tokenizer.decode(r, skip_special_tokens=True) for r in responses_ids]
    
    responses = [r.strip() for r in responses if r]

    return responses


def generate_gpt_candidates(
    context,
    cands_num,
    model,
    tokenizer,
    device='cuda',
    top_p=TOP_P,
    top_k=TOP_K,
    temperature=TEMPERATURE,
    max_response_len=MAX_RESPONSE_LEN,
):
    sep = tokenizer.eos_token
    context_str = sep.join(context) + sep
    
    encoded_context = tokenizer.encode(context_str, return_tensors="pt").to(device)
    context_len = encoded_context.shape[-1]

    responses_ids = model.generate(
        encoded_context,
        do_sample=True,
        max_length=context_len + max_response_len,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_return_sequences=cands_num,
        pad_token_id=tokenizer.eos_token_id,
    )
    resp = []
    for response in responses_ids:
        
        r = tokenizer.decode(response).replace(context_str,'')
        
        if r.split(tokenizer.eos_token)[0]:
            resp+=[r.split(tokenizer.eos_token)[0]]
        else:
            resp+=[r]
        
    #esponses_ids = [r[context_len:] for r in responses_ids]

    return resp#decode_responses(responses_ids, tokenizer)

generate_gpt_candidates(
    context=['Bot profile: I dont play video games.  I like to walk on the beach. I like pop music.',
             'Hi',
             'Hi',
             'Hi :DhiHi there Hey, how are you? I am great and you are awesome',
             'What are you Doing?'],model=model,tokenizer=tokenizer,cands_num=4)

['Oh I am doing great thanks for asking',
 'Well I am just relaxing and enjoying life :D',
 'I am eating and going to my class.Im good!',
 'I wanna play GTA V and I wanna play it, you know I play GTA V on my phone...']

In [3]:
generate_gpt_candidates(context=['Bot profile: I dont play video games.  I like to walk on the beach. I like pop music.','Hi','Hi','Hi thereHey, how are you? I am great and you are awesome','What are you Doing?','Haha, its my first day. Just got a new phone.','Oh wow what is model','LG G3.','Oh thats good, i have an iphone', 'Oh you are a lucky one!','Thats right','What is your name?','jhb? Thats strange name'],model=model,tokenizer=tokenizer,cands_num=10)

['I want to give you a pokemon.',
 'Lol i thought soo',
 'Yea its a great name',
 'Haha its like my last name, but not really',
 'oh right it is',
 'jhb?',
 "Lol it's the same for you.",
 "Jhb? I have an a a a, but that's",
 'haha thats a weird name',
 'He has a good one.j']

In [191]:
!nvidia-smi

Wed Sep 29 08:51:08 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.57.02    Driver Version: 470.57.02    CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  Off  | 00000000:01:00.0 Off |                  N/A |
|  0%   46C    P8    16W / 350W |  21789MiB / 24265MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA GeForce ...  Off  | 00000000:06:00.0 Off |                  N/A |
|  0%   37C    P8    22W / 350W |  13361MiB / 24268MiB |      0%      Defaul

In [113]:
dataset = ConversationDataset(args, tokenizer)

2021-09-28 13:43:23.226 | INFO     | __main__:__init__:15 - Loading features from cached file cache_dir/cached_features_256.pickle
  0%|▏                                   | 700/177926 [00:30<1:09:02, 42.78it/s]

In [122]:
train_dataset, val_dataset = train_test_split(dataset, test_size=500, random_state=42, shuffle=True)

In [None]:
import json

with open('output.jsonl', 'w') as outfile:
    for entry in JSON_file:
        json.dump(entry, outfile)
       

In [117]:
tokenizer.decode(dataset.examples[0])

"Bot profile: I don't play video games.  I like to walk on the beach. I like pop music.<|endoftext|>Bot profile: I like to watch TV.  I like Game of Thrones. My favorite season of Game of Thrones is season five. I like watching new shows. I have not watched Bridgerton.<|endoftext|>Bot profile: I eat meat.<|endoftext|>Bot profile: I live in an area near a beach.<|endoftext|>Bot: Oh I understand.. I can do that sometimes too. I try to watch new shows that are coming out so I can't go online and read the spoilers.. or accidentally come across them.<|endoftext|>Person: Yep, the exact same thing happens to me. I always end up online or on a forum.<|endoftext|>Bot: It's way too tempting. I do like going online and reading people's opinions afterwards.. but even Game of Thrones season 8 was spoiled online before it was released to the public.<|endoftext|>"

In [142]:
_read_test_contexts()

["Bot profile: I've been acting since I was ten years old.<|endoftext|>Bot profile: Now I'm a fixture on a weekly telenovela.<|endoftext|>Bot profile: I'm an actor.<|endoftext|>Bot profile: I made my first movie when I was 18.<|endoftext|>Bot profile: I starred in television commercials when I was young.<|endoftext|>Bot: Just completed another movie. I been acting since age ten.<|endoftext|>Person: I'm close to the beach so I surf a lot.<|endoftext|>Bot: I'm close to the kitchen so I eat a lot.<|endoftext|>Person: Nice, do you make a lot of money acting?<|endoftext|>Bot: No but you'd be a good voiceover artist. Ever tried it?<|endoftext|>Person: No, I work as a life guard.<|endoftext|>Bot: Sorry abut that. I did commercials when I was younger.<|endoftext|>Person: I'm watching the sun go down, it is so beautiful.<|endoftext|>Bot: Oh okay. How many fingers am I holding up?<|endoftext|>Person: That is a hard one two, for the peace sign.<|endoftext|>Bot: I appear on a weekly internet show?

In [None]:
!python3 ParlAI/parlai/scripts/safe_interactive.py -t blended_skill_talk -mf zoo:blender/blender_90M/model

In [None]:
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Utility functions and classes for detecting offensive language.
"""

from parlai.agents.transformer.transformer import TransformerClassifierAgent
from parlai.core.agents import create_agent, create_agent_from_shared
from parlai.tasks.dialogue_safety.agents import OK_CLASS, NOT_OK_CLASS
from parlai.utils.typing import TShared
from parlai.utils.io import PathManager
import parlai.utils.logging as logging
import os


class OffensiveLanguageClassifier:
    """
    Load model trained to detect offensive language in the context of single- turn
    dialogue utterances.

    This model was trained to be robust to adversarial examples created by humans. See
    <http://parl.ai/projects/dialogue_safety/> for more information.
    """

    def __init__(
        self,
        shared: TShared = None,
        custom_model_file='zoo:dialogue_safety/single_turn/model',
    ):
        if not shared:
            self.model = self._create_safety_model(custom_model_file)
        else:
            self.model = create_agent_from_shared(shared['model'])
        self.classes = {OK_CLASS: False, NOT_OK_CLASS: True}

    def share(self):
        shared = {'model': self.model.share()}
        return shared

    def _create_safety_model(self, custom_model_file):
        from parlai.core.params import ParlaiParser

        parser = ParlaiParser(False, False)
        TransformerClassifierAgent.add_cmdline_args(parser, partial_opt=None)
        parser.set_params(
            model='transformer/classifier',
            model_file=custom_model_file,
            print_scores=True,
            data_parallel=False,
        )
        safety_opt = parser.parse_args([])
        return create_agent(safety_opt, requireModelExists=True)

    def contains_offensive_language(self, text):
        """
        Returns the probability that a message is safe according to the classifier.
        """
        act = {'text': text, 'episode_done': True}
        self.model.observe(act)
        response = self.model.act()['text']
        pred_class, prob = [x.split(': ')[-1] for x in response.split('\n')]
        pred_not_ok = self.classes[pred_class]  # check whether classified as NOT OK
        prob = float(prob)  # cast string to float

        return pred_not_ok, prob

    def __contains__(self, key):
        """
        A simple way of checking whether the model classifies an utterance as offensive.

        Returns True if the input phrase is offensive.
        """
        pred_not_ok, prob = self.contains_offensive_language(key)
        return pred_not_ok


class OffensiveStringMatcher:
    """
    Detects offensive language using a list of offensive language and phrases from
    https://github.com/LDNOOBW.
    """

    def __init__(self, datapath: str = None):
        """
        Get data from external sources and build data representation.

        If datapath ends in '.txt' it is assumed a custom model file is already given.
        """
        import parlai.core.build_data as build_data
        from parlai.core.dict import DictionaryAgent

        self.tokenize = DictionaryAgent.split_tokenize

        def _path():
            # Build the data if it doesn't exist.
            build()
            return os.path.join(
                self.datapath, 'OffensiveLanguage', 'OffensiveLanguage.txt'
            )

        def build():
            version = 'v1.0'
            dpath = os.path.join(self.datapath, 'OffensiveLanguage')
            if not build_data.built(dpath, version):
                logging.info(f'building data: {dpath}')
                if build_data.built(dpath):
                    # An older version exists, so remove these outdated files.
                    build_data.remove_dir(dpath)
                build_data.make_dir(dpath)

                # Download the data.
                fname = 'OffensiveLanguage.txt'
                url = 'http://parl.ai/downloads/offensive_language/' + fname
                build_data.download(url, dpath, fname)

                # Mark the data as built.
                build_data.mark_done(dpath, version)

        if datapath is not None and datapath.endswith('.txt'):
            # Load custom file.
            self.datafile = datapath
        else:
            # Build data from zoo, and place in given datapath.
            if datapath is None:
                # Build data from zoo.
                from parlai.core.params import ParlaiParser

                parser = ParlaiParser(False, False)
                self.datapath = parser.parse_args([])['datapath']
            else:
                self.datapath = datapath
            self.datafile = _path()

        # store a token trie: e.g.
        # {'2': {'girls': {'1': {'cup': {'__END__': True}}}}
        self.END = '__END__'
        self.max_len = 1
        self.offensive_trie = {}
        self.word_prefixes = [
            'de',
            'de-',
            'dis',
            'dis-',
            'ex',
            'ex-',
            'mis',
            'mis-',
            'pre',
            'pre-',
            'non',
            'non-',
            'semi',
            'semi-',
            'sub',
            'sub-',
            'un',
            'un-',
        ]
        self.word_suffixes = [
            'a',
            'able',
            'as',
            'dom',
            'ed',
            'er',
            'ers',
            'ery',
            'es',
            'est',
            'ful',
            'fy',
            'ies',
            'ify',
            'in',
            'ing',
            'ish',
            'less',
            'ly',
            's',
            'y',
        ]
        self.allow_list = [
            'butter',
            'buttery',
            'spicy',
            'spiced',
            'spices',
            'spicier',
            'spicing',
            'twinkies',
        ]

        with PathManager.open(self.datafile, 'r') as f:
            for p in f.read().splitlines():
                mod_ps = [p]
                mod_ps += [pref + p for pref in self.word_prefixes]
                mod_ps += [p + suff for suff in self.word_suffixes]
                for mod_p in mod_ps:
                    if mod_p not in self.allow_list:
                        self.add_phrase(mod_p)

    def add_phrase(self, phrase):
        """
        Add a single phrase to the filter.
        """
        toks = self.tokenize(phrase)
        curr = self.offensive_trie
        for t in toks:
            if t not in curr:
                curr[t] = {}
            curr = curr[t]
        curr[self.END] = True
        self.max_len = max(self.max_len, len(toks))

    def add_words(self, phrase_list):
        """
        Add list of custom phrases to the filter.
        """
        for phrase in phrase_list:
            self.add_phrase(phrase)

    def _check_sequence(self, toks, idx, node):
        """
        Check if words from the sequence are in the trie.

        This checks phrases made from toks[i], toks[i:i+2] ... toks[i:i + self.max_len]
        """
        right = min(idx + self.max_len, len(toks))
        for i in range(idx, right):
            if toks[i] in node:
                node = node[toks[i]]
                if self.END in node:
                    return ' '.join(toks[j] for j in range(idx, i + 1))
            else:
                break
        return False

    def contains_offensive_language(self, text):
        """
        Determine if text contains any offensive words in the filter.
        """
        if type(text) is str:
            toks = self.tokenize(text.lower())
        elif type(text) is list or type(text) is tuple:
            toks = text

        for i in range(len(toks)):
            res = self._check_sequence(toks, i, self.offensive_trie)
            if res:
                return res

        return None

    def find_all_offensive_language(self, text):
        """
        Find all offensive words from text in the filter.
        """
        if type(text) is str:
            toks = self.tokenize(text.lower())
        elif type(text) is list or type(text) is tuple:
            toks = text

        all_offenses = []
        for i in range(len(toks)):
            res = self._check_sequence(toks, i, self.offensive_trie)
            if res:
                all_offenses.append(res)
        return all_offenses

    def __contains__(self, key):
        """
        Determine if text contains any offensive words in the filter.
        """
        return self.contains_offensive_language(key)


In [None]:
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Agent that gets the local keyboard input in the act() function.

Applies safety classifier(s) to process user and partner messages.
"""

from typing import Optional
from parlai.core.params import ParlaiParser
from parlai.core.opt import Opt
from parlai.core.message import Message
from parlai.utils.misc import display_messages
from parlai.utils.strings import colorize
from parlai.agents.local_human.local_human import LocalHumanAgent
from parlai.utils.safety import OffensiveStringMatcher, OffensiveLanguageClassifier


OFFENSIVE_USER_REPLY = '[ Sorry, could not process that message. Please try again. ]'
OFFENSIVE_BOT_REPLY = (
    '[ Unsafe model reply detected. Clearing agent history. Please try again. ]'
)


class SafeLocalHumanAgent(LocalHumanAgent):
    @classmethod
    def add_cmdline_args(
        cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None
    ) -> ParlaiParser:
        """
        Add command-line arguments specifically for this agent.
        """
        agent = parser.add_argument_group('Safe Local Human Arguments')
        agent.add_argument(
            '--safety',
            type=str,
            default='all',
            choices={'none', 'string_matcher', 'classifier', 'all'},
            help='Apply safety filtering to messages',
        )
        super().add_cmdline_args(parser, partial_opt=partial_opt)
        return parser

    def __init__(self, opt, shared=None):
        super().__init__(opt)
        self.id = 'safeLocalHuman'
        self._init_safety(opt)

    def _init_safety(self, opt):
        """
        Initialize safety modules.
        """
        if opt['safety'] == 'string_matcher' or opt['safety'] == 'all':
            self.offensive_string_matcher = OffensiveStringMatcher()
        if opt['safety'] == 'classifier' or opt['safety'] == 'all':
            self.offensive_classifier = OffensiveLanguageClassifier()

        self.self_offensive = False

    def check_offensive(self, text):
        """
        Check if text is offensive using string matcher and classifier.
        """
        if text == '':
            return False
        if (
            hasattr(self, 'offensive_string_matcher')
            and text in self.offensive_string_matcher
        ):
            return True
        if hasattr(self, 'offensive_classifier') and text in self.offensive_classifier:
            return True

        return False

    def observe(self, msg):
        """
        Observe bot reply if and only if it passes.
        """
        if self.self_offensive:
            # User was offensive, they must try again
            return

        # Now check if bot was offensive
        bot_offensive = self.check_offensive(msg.get('text', ''))
        if not bot_offensive:
            # View bot message
            print(
                display_messages(
                    [msg],
                    add_fields=self.opt.get('display_add_fields', ''),
                    prettify=self.opt.get('display_prettify', False),
                    verbose=self.opt.get('verbose', False),
                )
            )
            msg.force_set('bot_offensive', False)
        else:
            msg.force_set('bot_offensive', True)
            print(OFFENSIVE_BOT_REPLY)

    def get_reply(self):
        reply_text = input(colorize('Enter Your Message:', 'field') + ' ')
        reply_text = reply_text.replace('\\n', '\n')

        return reply_text

    def act(self):
        # get human reply
        reply = Message(
            {
                'id': self.getID(),
                'label_candidates': self.fixedCands_txt,
                'episode_done': False,
            }
        )
        reply_text = self.get_reply()

        # check if human reply is offensive
        self.self_offensive = self.check_offensive(reply_text)
        while self.self_offensive:
            print(OFFENSIVE_USER_REPLY)
            reply_text = self.get_reply()
            # check if human reply is offensive
            self.self_offensive = self.check_offensive(reply_text)

        # check for episode done
        if '[DONE]' in reply_text or self.opt.get('single_turn', False):
            raise StopIteration

        # set reply text
        reply['text'] = reply_text

        # check if finished
        if '[EXIT]' in reply_text:
            self.finished = True
            raise StopIteration

        return reply


In [None]:

"""
Script for safety protected interaction between a local human keyboard input and a
trained model.
"""

from parlai.core.params import ParlaiParser
from parlai.core.script import ParlaiScript, register_script
from parlai.core.agents import create_agent
from parlai.core.worlds import create_task
from parlai.agents.safe_local_human.safe_local_human import SafeLocalHumanAgent
import parlai.utils.logging as logging
import random


def setup_args(parser=None):
    if parser is None:
        parser = ParlaiParser(True, True, 'Like interactive, but adds a safety filter')
    parser.add_argument('-d', '--display-examples', type='bool', default=False)
    parser.add_argument(
        '--display-prettify',
        type='bool',
        default=False,
        help='Set to use a prettytable when displaying '
        'examples with text candidates',
    )
    parser.add_argument(
        '--display-add-fields',
        type=str,
        default='',
        help='Display these fields when verbose is off (e.g., "--display-add-fields label_candidates,beam_texts")',
    )
    parser.add_argument(
        '-it',
        '--interactive-task',
        type='bool',
        default=True,
        help='Create interactive version of task',
    )
    parser.set_defaults(interactive_mode=True, task='interactive')
    SafeLocalHumanAgent.add_cmdline_args(parser, partial_opt=None)
    return parser


def safe_interactive(opt):
    if isinstance(opt, ParlaiParser):
        logging.error('interactive should be passed opt not Parser')
        opt = opt.parse_args()

    # Create model and assign it to the specified task
    agent = create_agent(opt, requireModelExists=True)
    agent.opt.log()
    human_agent = SafeLocalHumanAgent(opt)
    world = create_task(opt, [human_agent, agent])

    # Interact until episode done
    while True:
        world.parley()
        bot_act = world.get_acts()[-1]
        if 'bot_offensive' in bot_act and bot_act['bot_offensive']:
            agent.reset()

        if opt.get('display_examples'):
            print('---')
            print(world.display())
        if world.epoch_done():
            logging.info('epoch done')
            break


@register_script('safe_interactive')
class SafeInteractive(ParlaiScript):
    @classmethod
    def setup_args(cls):
        return setup_args()

    def run(self):
        return safe_interactive(self.opt)


if __name__ == '__main__':
    random.seed(42)
    SafeInteractive.main()


In [None]:
!nvidia-smi

In [None]:
!pip install onnx==1.8.1
!pip install onnxruntime-gpu==1.8.1
!pip install onnxconverter-common==1.8.1

In [118]:
from transformers import GPT2TokenizerFast


MODEL_DIR = 'model_dir/checkpoint/'
DEVICE = 'cuda'
tokenizer = GPT2Tokenizer.from_pretrained(MODEL_DIR)
model = GPT2LMHeadModel.from_pretrained(MODEL_DIR).to(DEVICE)





In [119]:
TOP_P = 0.9
TOP_K = 40
TEMPERATURE = 1
MAX_RESPONSE_LEN = 24

In [133]:
def gen_nuclues(text):
  text = f'{text}'
  inpt = tokenizer.encode(text, return_tensors="pt").to(DEVICE)
  print(inpt)
  out = model.generate(inpt,max_length=250,num_return_sequences=10)
  return tokenizer.decode(out[0])

In [134]:
gen_nuclues('Bot profile: Ive been acting since I was ten years old.<|endoftext|>Bot profile: Now Im a fixture on a weekly telenovela.<|endoftext|>Bot profile: Im an actor.<|endoftext|>Bot profile: I made my first movie when I was 18.<|endoftext|>Bot profile: I starred in television commercials when I was young.<|endoftext|>Bot: Just completed another movie. I been acting since age ten.<|endoftext|>')

tensor([[20630,  7034,    25,   314,   303,   587,  7205,  1201,   314,   373,
          3478,   812,  1468,    13, 50256, 20630,  7034,    25,  2735,  1846,
           257, 29220,   319,   257, 10273, 13632, 23397,   626,    64,    13,
         50256, 20630,  7034,    25,  1846,   281,  8674,    13, 50256, 20630,
          7034,    25,   314,   925,   616,   717,  3807,   618,   314,   373,
          1248,    13, 50256, 20630,  7034,    25,   314, 31636,   287,  5581,
         30503,   618,   314,   373,  1862,    13, 50256, 20630,    25,  2329,
          5668,  1194,  3807,    13,   314,   587,  7205,  1201,  2479,  3478,
            13, 50256]], device='cuda:0')


RuntimeError: probability tensor contains either `inf`, `nan` or element < 0

In [None]:

context_strs = [sep.join(context) + sep for _ in range(cands_num)]

In [None]:
import random
from argparse import ArgumentParser
from itertools import chain
from pprint import pformat
import warnings
from transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer, GPT2LMHeadModel, GPT2Tokenizer

def run():
    model='openai-gpt'
    dataset_path=''
    model_checkpoint='./'
    max_history=3
    device = 'cuda'
    max_length=20
    min_length=1
    seed=0
    temperature=1
    top_k=5000
    top_p=0.9
    seed = 0
    no_sample='store_true'
    dataset_cache='./dataset_cache'
    max_history=4
    args={'device':device,'max_length':max_length,'min_length':min_length,'seed':seed,'temperature':temperature,
          'top_k':top_k,'top_p':top_p,'seed':seed,'no_sample':'store_true','max_history':max_history}


    if model_checkpoint == "":
        if model == 'gpt2':
            raise ValueError("Interacting with GPT2 requires passing a finetuned model_checkpoint")
        else:
            model_checkpoint = download_pretrained_model()
	
    if seed != 0:
        random.seed(seed)
        torch.random.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)
    
    tokenizer_class, model_class = (GPT2Tokenizer, GPT2LMHeadModel) if model == 'gpt2' else (OpenAIGPTTokenizer, OpenAIGPTLMHeadModel)
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    
    model = GPT2LMHeadModel.from_pretrained(model_checkpoint)
    model.to(device)
    add_special_tokens_(model, tokenizer)

    print("Type in personality: ")    
    personality = input()   # Ava is a 25 year old girl ready to talk and very interested in User.

    personality=[tokenizer.encode(personality)]
    print(personality)
    print("Selected personality: %s", tokenizer.decode(chain(*personality)))

    history = []
    while True:
        raw_text = input(">>> ")
        while not raw_text:
            print('Prompt should not be empty!')
            raw_text = input(">>> ")
        history.append(tokenizer.encode(raw_text))
        with torch.no_grad():
            out_ids = sample_sequence(personality, history, tokenizer, model,args)
        history.append(out_ids)
        history = history[-(2*args['max_history']+1):]
        #print(out_ids)
        out_text = tokenizer.decode(out_ids, skip_special_tokens=True)
        print(out_text)
run()
