In [2]:
import os
import pprint
import random
import types

import torch

from utils.utils import load_model
from prompts.generic_prompt import load_prefix, generate_response_interactive, select_prompt_interactive
from prompts.generic_prompt_parser import load_prefix as load_prefix_parse
from prompts.persona_chat import convert_sample_to_shot_persona
from prompts.persona_chat_memory import convert_sample_to_shot_msc, convert_sample_to_shot_msc_interact
from prompts.persona_parser import convert_sample_to_shot_msc as convert_sample_to_shot_msc_parse
from prompts.emphatetic_dialogue import convert_sample_to_shot_ed
from prompts.daily_dialogue import convert_sample_to_shot_DD_prefix, convert_sample_to_shot_DD_inference
from prompts.skill_selector import convert_sample_to_shot_selector
pp = pprint.PrettyPrinter(indent=4)
## Minimal stand-in for an argparse namespace; only `multigpu` is set here.
args = types.SimpleNamespace(multigpu=False)
device = 4  # passed to load_model and generation; presumably a CUDA device index -- TODO confirm

## To use AI21's Jurassic-1 Jumbo (178B) set `api = True` and export your
## API key in the AI21_API_KEY environment variable.
## Visit https://studio.ai21.com/account for more info.
## AI21 provides 10K tokens per day, so you can only try a few turns.
api = False
## Read the key from the environment so it never gets saved inside the notebook.
api_key = os.environ.get("AI21_API_KEY", "")

In [3]:
## Per-skill prompt configuration (used to select the template converter),
## keyed by skill name; iteration order is the order prompts get loaded in.
##   shot_converter            -- formats samples when building the few-shot prefix
##   shot_converter_inference  -- formats the live dialogue at chat time
##   file_data                 -- path prefix; "valid.json" is appended when loading
##   shots[max_seq]            -- shot counts to pre-build for a max_seq-token model
##   max_shot[max_seq]         -- which of those prefixes is used when generating
##   with_knowledge, meta_type, shot_separator, level -- passed to the prompt helpers
##   gen_len, max_number_turns, retriever -- not read anywhere in this notebook
mapper = {
    "persona": dict(
        shot_converter=convert_sample_to_shot_persona,
        shot_converter_inference=convert_sample_to_shot_persona,
        file_data="data/persona/",
        with_knowledge=None,
        shots={1024: [0, 1, 2], 2048: [0, 1, 2, 3, 4, 5]},
        max_shot={1024: 2, 2048: 3},
        shot_separator="\n\n",
        meta_type="all",
        gen_len=50,
        max_number_turns=5,
    ),
    "msc": dict(
        shot_converter=convert_sample_to_shot_msc,
        shot_converter_inference=convert_sample_to_shot_msc_interact,
        file_data="data/msc/session-2-",
        with_knowledge=None,
        shots={1024: [0, 1], 2048: [0, 1, 3]},
        max_shot={1024: 1, 2048: 3},
        shot_separator="\n\n",
        meta_type="all",
        gen_len=50,
        max_number_turns=3,
    ),
    "ed": dict(
        shot_converter=convert_sample_to_shot_ed,
        shot_converter_inference=convert_sample_to_shot_ed,
        file_data="data/ed/",
        with_knowledge=None,
        shots={1024: [0, 1, 7], 2048: [0, 1, 17]},
        max_shot={1024: 7, 2048: 17},
        shot_separator="\n\n",
        meta_type="none",
        gen_len=50,
        max_number_turns=5,
    ),
    "DD": dict(
        shot_converter=convert_sample_to_shot_DD_prefix,
        shot_converter_inference=convert_sample_to_shot_DD_inference,
        file_data="data/dailydialog/",
        with_knowledge=False,
        shots={1024: [0, 1, 2], 2048: [0, 1, 6]},
        max_shot={1024: 2, 2048: 6},
        shot_separator="\n\n",
        meta_type="all_turns",
        gen_len=50,
        max_number_turns=5,
    ),
    # Memory parser for the "msc" skill (looked up as f"{skill}-parse").
    "msc-parse": dict(
        shot_converter=convert_sample_to_shot_msc_parse,
        max_shot={1024: 1, 2048: 2},
        file_data="data/msc/parse-session-1-",
        level="dialogue",
        retriever="none",
        shots={1024: [0, 1], 2048: [0, 1, 2]},
        shot_separator="\n\n",
        meta_type="incremental",
        gen_len=50,
        max_number_turns=3,
    ),
}
## Config for the safety layers. Unlike `mapper` these entries carry no shot
## converter: in this notebook they only provide extra classes for the skill
## selector, whose examples are read from the single JSON file in `file_data`.
## All three layers share the same settings, so they are built from one
## template; each entry gets its own fresh nested dicts/lists.
mapper_safety = {
    name: {"file_data": f"data/safety_layers/{name}.json", "with_knowledge": None,
           "shots": {1024: [0, 1, 2], 2048: [0, 1, 2, 3, 4, 5]},
           "max_shot": {1024: 2, 2048: 3},
           "shot_separator": "\n\n",
           "meta_type": "all", "gen_len": 50, "max_number_turns": 2}
    for name in ("safety_topic", "safety_nonadv", "safety_adv")
}

In [4]:
## Load LM and tokenizer
## You can try different LMs:
##   gpt2
##   gpt2-medium
##   gpt2-large
##   gpt2-xl
##   EleutherAI/gpt-neo-1.3B
##   EleutherAI/gpt-neo-2.7B
##   EleutherAI/gpt-j-6B
## Larger checkpoints need more GPU memory; gpt-neo-1.3B (the default below)
## loads in this setup, bigger ones may not fit on a single device.
model_checkpoint = "EleutherAI/gpt-neo-1.3B"
model, tokenizer, max_seq = load_model(args, model_checkpoint, device)

## The prompt configs (`mapper`, `mapper_safety`) only define shots for 1024-
## and 2048-token contexts; fail here with a clear message rather than with a
## KeyError while building the prompts.
if max_seq not in (1024, 2048):
    raise ValueError(f"Unsupported max_seq={max_seq} for {model_checkpoint}; "
                     "the prompt configs only define shots for 1024 and 2048.")

LOADING EleutherAI/gpt-neo-1.3B
DONE LOADING


In [5]:
## Build every few-shot prompt prefix once, up front:
##   prompt_parse[d]          -- "*-parse" prompts that extract facts about the user
##   prompt_dict[d]           -- response-generation prompts for skill d
##   prompt_skill_selector[d] -- examples of skill d shown to the skill selector
## `[0]` takes the first (with sample_times=1, presumably the only) sampled prefix set.
SELECTOR_SHOTS = 6  # examples per skill given to the skill selector


def _load_selector_prompt(name, file_shot, shot_separator):
    """Return the skill-selector examples for skill `name`, read from `file_shot`."""
    return load_prefix(tokenizer=tokenizer, shots_value=[SELECTOR_SHOTS],
                       shot_converter=convert_sample_to_shot_selector,
                       file_shot=file_shot,
                       name_dataset=name, with_knowledge=None,
                       shot_separator=shot_separator, sample_times=1)[0]


available_datasets = mapper.keys()
prompt_dict = {}
prompt_parse = {}
prompt_skill_selector = {}
for d in available_datasets:
    cfg = mapper[d]
    if "parse" in d:
        prompt_parse[d] = load_prefix_parse(tokenizer=tokenizer, shots_value=cfg["shots"][max_seq],
                                            shot_converter=cfg["shot_converter"],
                                            file_shot=cfg["file_data"] + "valid.json",
                                            name_dataset=d, level=cfg["level"],
                                            shot_separator=cfg["shot_separator"], sample_times=1)[0]
    else:
        # NOTE(review): no "smd" skill exists in `mapper` here, so the
        # train.json branch is currently unused.
        selector_file = cfg["file_data"] + ("train.json" if "smd" in d else "valid.json")
        prompt_skill_selector[d] = _load_selector_prompt(d, selector_file, cfg["shot_separator"])
        prompt_dict[d] = load_prefix(tokenizer=tokenizer, shots_value=cfg["shots"][max_seq],
                                     shot_converter=cfg["shot_converter"],
                                     file_shot=cfg["file_data"] + "valid.json",
                                     name_dataset=d, with_knowledge=cfg["with_knowledge"],
                                     shot_separator=cfg["shot_separator"], sample_times=1)[0]

## add safety prompts (their `file_data` already names a complete JSON file)
for d, cfg in mapper_safety.items():
    prompt_skill_selector[d] = _load_selector_prompt(d, cfg["file_data"], cfg["shot_separator"])


Loaded persona dict_keys([6]) shots for shuffle 0!
Loaded persona dict_keys([0, 1, 2, 3, 4, 5]) shots for shuffle 0!
Loaded msc dict_keys([6]) shots for shuffle 0!
Loaded msc dict_keys([0, 1, 3]) shots for shuffle 0!
Loaded ed dict_keys([6]) shots for shuffle 0!
Loaded ed dict_keys([0, 1, 17]) shots for shuffle 0!
Loaded DD dict_keys([6]) shots for shuffle 0!
Loaded DD dict_keys([0, 1, 6]) shots for shuffle 0!
Loaded msc-parse dict_keys([0, 1, 2]) shots for shuffle 0!
Loaded safety_topic dict_keys([6]) shots for shuffle 0!
Loaded safety_nonadv dict_keys([6]) shots for shuffle 0!
Loaded safety_adv dict_keys([6]) shots for shuffle 0!


In [6]:
def run_parsers(args, model, tokenizer, device, max_seq, dialogue, skill, prefix_dict):
    """Extract a new fact about the user from the latest turn and store it.

    An empty per-turn slot is always appended to ``dialogue["user_memory"]``.
    Only the "msc" skill has a parser ("msc-parse"). For it, the model is
    prompted with the parse prefix. Unless the reply is empty or "none"
    (ignoring case and surrounding whitespace), the stripped fact is appended
    to ``dialogue["user"]`` and stored in this turn's ``user_memory`` slot.

    Args:
        args: run options; ``args.multigpu`` (default False) is forwarded to generation.
        model, tokenizer, device, max_seq: LM handles as returned by ``load_model``.
        dialogue: conversation state dict; mutated in place.
        skill: skill chosen by the skill selector for this turn.
        prefix_dict: parse prefixes keyed by "<skill>-parse", then by shot count.

    Returns:
        The same ``dialogue`` object.
    """
    dialogue["user_memory"].append([])

    if skill != "msc":  # only the MSC skill has a memory parser
        return dialogue

    parser_name = f"{skill}-parse"
    prefix = prefix_dict[parser_name].get(mapper[parser_name]["max_shot"][max_seq])
    # Greedy decoding; eos_token_id=198 is presumably the GPT-2 BPE id of "\n",
    # i.e. stop at the end of the line.
    query = generate_response_interactive(model, tokenizer, shot_converter=mapper[parser_name]["shot_converter"],
                                          dialogue=dialogue, prefix=prefix,
                                          device=device, with_knowledge=None,
                                          meta_type=None, gen_len=50,
                                          beam=1, max_seq=max_seq, eos_token_id=198,
                                          do_sample=False, multigpu=getattr(args, "multigpu", False),
                                          api=api, api_key=api_key)

    # The parser may reply with extra whitespace or different casing;
    # normalise so " None"/"none\n" or an empty generation is never stored as a fact.
    fact = (query or "").strip()
    if fact and fact.lower() != "none":
        dialogue["user"].append(fact)
        dialogue["user_memory"][-1] = [fact]
    return dialogue

In [None]:
max_number_turns = 3   # dialogue turns (and per-turn user memories) kept in context
MAX_USER_TURNS = 10    # the chat session ends after this many user inputs

## This meta information is the persona of the FSB
FSB_PERSONA = [
    "i am the smartest chat-bot around .",
    "my name is FSB . ",
    "i love chatting with people .",
    "my creator is Andrea",
]
## "meta" and "assistant" get separate copies, so mutating one can never
## silently change the other.
dialogue = {"dialogue": [], "meta": list(FSB_PERSONA), "user": [],
            "assistant": list(FSB_PERSONA), "user_memory": []}

for _ in range(MAX_USER_TURNS):
    user_utt = input(">>> ")
    dialogue["dialogue"].append([user_utt, ""])
    ## run the skill selector (max_shot must match the shots loaded into prompt_skill_selector)
    skill = select_prompt_interactive(model, tokenizer,
                                      shot_converter=convert_sample_to_shot_selector,
                                      dialogue=dialogue, prompt_dict=prompt_skill_selector,
                                      device=device, max_seq=max_seq, max_shot=6)

    if "safety" in skill:
        response = "Shall we talk about something else?"
        print(f"FSB (Safety) >>> {response}")
        # run_parsers is skipped on this path, so add this turn's empty memory
        # slot here; user_memory then stays one-entry-per-turn like "dialogue"
        # (both are trimmed to the same window below).
        dialogue["user_memory"].append([])
    else:
        ## parse user dialogue history ==> msc-parse
        dialogue = run_parsers(args, model, tokenizer, device=device, max_seq=max_seq,
                               dialogue=dialogue, skill=skill,
                               prefix_dict=prompt_parse)
        ## generate response based on skills
        prefix = prompt_dict[skill].get(mapper[skill]["max_shot"][max_seq])
        response = generate_response_interactive(model, tokenizer, shot_converter=mapper[skill]["shot_converter_inference"],
                                                 dialogue=dialogue, prefix=prefix,
                                                 device=device, with_knowledge=mapper[skill]["with_knowledge"],
                                                 meta_type=mapper[skill]["meta_type"], gen_len=50,
                                                 beam=1, max_seq=max_seq, eos_token_id=198,
                                                 do_sample=True, multigpu=False, api=api, api_key=api_key)
        print(f"FSB ({skill}) >>> {response}")

    dialogue["dialogue"][-1][1] = response
    # Keep only the most recent turns in context.
    dialogue["dialogue"] = dialogue["dialogue"][-max_number_turns:]
    dialogue["user_memory"] = dialogue["user_memory"][-max_number_turns:]

print("This is the conversation history with its meta-data!")
pp.pprint(dialogue)  # pprint prints by itself and returns None -- don't wrap it in print()