In [245]:
import os
import json
import openai
import random
import datetime
import pandas as pd

from pathlib import Path
from pprint import pprint
from collections import defaultdict

from dataclasses import dataclass, field
from typing import Optional

from langchain import PromptTemplate
from langchain.agents import create_pandas_dataframe_agent
from langchain.llms.openai import OpenAI
from langchain.callbacks import get_openai_callback
from langchain.agents import AgentType

# from src.e2e.e2e_utils import E2E_InstrucTOD

# with open("data/restaurant_db.json", "r") as f:
#     restaurant_db = json.load(f)
    
# keep_attr = ["address", "area", "food", "introduction", "name", "phone", "postcode", "pricerange", "signature"]
# restaurant_db = [{k:v for k, v in sample.items() if k in keep_attr} for sample in restaurant_db]

# with open("data/restaurant_db_new.json", "w") as f:
#     json.dump(restaurant_db, f)

In [246]:
@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(
        default="gpt-3.5-turbo-0301",
        metadata={"help": "The path of the HuggingFace model."}
    )
    temperature: Optional[int] = field(
        default=0,
        metadata={"help": "Temperature for the agent's generation"}
    )
    print_cost: Optional[bool] = field(
        default=False,
        metadata={"help": "Print the cost of using the agent for KB interaction at every turn"}
    )


@dataclass
class DataArguments:
    dataset_name: Optional[str] = field(
        default=None,
        metadata={"help": "Train dataset path"}
    )
    root_data_path: Optional[str] = field(
        default="./data", metadata={"help": "The path to the data directory."},
    )
    dialog_history_limit_bs: Optional[int] = field(
        default=3,
        metadata={"help": "Length of dialogue history to take for the proxy belief state"}
    )
    dialog_history_limit_rg: Optional[int] = field(
        default=7,
        metadata={"help": "Length of dialogue history to take for the response generation"}
    )
    log_path: Optional[str] = field(
        default="./demo/logs/",
        metadata={"help": "path for the log directory"}
    )
    load_path: Optional[str] = field(
        default="/data/",
        metadata={"help": "load path for the kb"}
    )
    config_path: Optional[str] = field(
        default="/config.json",
        metadata={"help": "load path"}
    )
    agent_max_iterations: Optional[int] = field(
        default=5,
        metadata={"help": "Max number of iterations for agents in e2e (higher=better but more expensive)"}
    )
    verbose: Optional[bool] = field(
        default=False,
        metadata={"help": "To output the intermediary steps"}
    )
    return_intermediate_steps: Optional[bool] = field(
        default=False,
        metadata={"help": "To log the intermediary steps (KB retrieval process)"}
    )
    
@dataclass
class TrainingArguments:
    pass

In [247]:
#config.json
CONFIG = {"task_objective":"Book a restaurant",
          "proxy_bs":{"instruction":"Suppose you have access to a database where all the column attributes are given in INFORMATION, what do you need to query to the database in order to reply to the user in the following conversation?",
                      "example":"USER: I need fruits.\nSYSTEM: Do you have any preferences?\nUSER: Yes, apples if possible. How expensive and how many are there?\n\nNeed: Information about pricerange and and count for apple\n\nUSER: Cool, that asnwers my question.\nSYSTEM: I am happy to help. Anything else needed?\nUSER: I'm done, thanks!\n\nNeed: Information about nothing\n\nUSER: I'm looking for a popular zoo around here\nSYSTEM: There are multiple zoos I can recommend. Any preference on the location?\nUSER: I want it to be in the west part of town\n\nNeed: Information about zoo in the west",
                      "template":"""{instruction}\n\nYou can follow these examples:{example}\n\nINFORMATION: {information}\n{dialogue_context}\n\nNeed:""",
                      "input_variables":["instruction", "information", "example", "dialogue_context"]},
          "response_generation":{"instruction":"In a task oriented dialogue setting, generate a natural and helpful SYSTEM response to the USER query in the conversation provided in CONTEXT. You should follow the information provided in ACT to generate this answer. Do not mention that you are referring to a dataframe and don't overload the user with too many choices. You have the ability to perform confirm bookings for users:",
                                 "example":"USER: I need a place to fish\nSYSTEM: Any preference in the type of fish?\nUSER: Preferably salmons, but sardines are also fine\n\nACT: blue lake, 37th avenue\n\nSYSTEM: How about blue lake, 37th avenue in that case?",
                                 "template":"""{instruction}\n\nYou can follow these examples:{example}\n\nCONTEXT:\n{dialogue_context}\n\nACT:{dialogue_act}\n\nSYSTEM:""",
                                 "input_variables":["instruction", "example", "dialogue_context", "dialogue_act"]},
          "chitchat":{"instruction":"Given the following dialogue context between a USER and a SYSTEM, generate the response of SYSTEM as naturally as possible. Only answer questions that you have factual knowledge about. If you do not have the knowledge to answer the question, simply state so to the user:",
                      "template":"""{instruction}\n\n{dialogue_context}""",
                      "input_variables":["instruction", "dialogue_context"]},
          "classification":{"instruction":"Given the TASK and the provided dialogue context, classify whether the user UTTERANCE is related to the conversation topic or not. As long as the utterance is somewhat relate to the domain of the conversation, then it's fine. Just answer with 'yes' or 'no':",
                            "example":"USER: Give me information about a 3 star hotel\nSYSTEM: Sure! I found 5 that could be interesting.\nTASK: Book a hotel\nUSER UTTERANCE: I want to buy toys\nRELATED: No\n\nUSER: Where is the turtle attraction?\nSYSTEM: Right next to the ice cream shop\nTASK: Recommend an attraction\nUSER UTTERANCE: Ah I see. How do I go there?\nRELATED: Yes\n\nUSER: Where are the balloons?\nSYSTEM: In a park nearby\nTASK: Book a hotel\nUSER UTTERANCE: Fine, how many stars is this hotel?\nRELATED: Yes\n\nUSER: Where is flight A301 going?\nSYSTEM: This flight is going to Spain\nUSER: And how about B563?\nSYSTEM: This one is going to Greece.\nTASK: Buy an airplane ticket\nUSER UTTERANCE: I am trying to go to Barcelona, so A301 right?\nRELATED: Yes",
                            "template":"""{instruction}\n\nYou can follow these examples:{example}\n\n{dialogue_context}\nTASK: {task}\nUSER UTTERANCE: {utterance}\nRELATED:""",
                            "input_variables":["instruction", "example", "task", "dialogue_context", "utterance"]},
          "mode":{"instruction":"",
                  "example":"",
                  "template":"",
                  "input_variables":""},
         }

In [248]:
class AgentConfig():
    def __init__(self,
                 config,
                 model_args,
                 data_args):
        self.config = config
        self.model_args = model_args
        self.data_args = data_args
        
        self.task_objective = config["task_objective"]
        self.kb_df = self._load_knowledge_base()
        self._load_prompt_config()
        self.attributes = list(self.kb_df.columns)
        
    
    def _load_prompt_config(self):
        
        self.prompt_bs_template = PromptTemplate(template=self.config["proxy_bs"]["template"],
                                                 input_variables=self.config["proxy_bs"]["input_variables"])
        self.bs_instruction = self.config["proxy_bs"]["instruction"]
        self.bs_example = self.config["proxy_bs"]["example"]

        self.prompt_rg_template = PromptTemplate(template=self.config["response_generation"]["template"],
                                                 input_variables=self.config["response_generation"]["input_variables"])
        self.rg_instruction = self.config["response_generation"]["instruction"]
        self.rg_example = self.config["response_generation"]["example"]
        
        self.prompt_chitchat_template = PromptTemplate(template=self.config["chitchat"]["template"],
                                                       input_variables=self.config["chitchat"]["input_variables"])
        self.chitchat_instruction = self.config["chitchat"]["instruction"]
        
        self.prompt_classification_template = PromptTemplate(template=self.config["classification"]["template"],
                                                             input_variables=self.config["classification"]["input_variables"])
        self.classification_instruction = self.config["classification"]["instruction"]
        self.classification_example = self.config["classification"]["example"]
        
        self.welcome_sentence = f"Hi! I am a TOD with chitchat capability. My current task is {self.task_objective}. What can I help you with?"
        
        
    def _load_knowledge_base(self):
        
        kb_ext = os.path.splitext(self.data_args.load_path)[-1]
        kb_path = self.data_args.load_path
        if kb_ext == ".json":
            kb_df = pd.read_json(kb_path)
        elif kb_ext == ".csv":
            kb_df = pd.read_csv(kb_path)
        elif kb_ext == ".xlsx":
            kb_df = pd.read_excel(kb_path) 
        else:
            raise ValueError(f"Knowledge base should be either json, csv or xslx. Current kb_path: {kb_path}")
        
        return kb_df
    

    def _completion(self, prompt):
        
        if "gpt-3.5-turbo" in self.model_args.model_name_or_path or "gpt-4" in self.model_args.model_name_or_path:
            try:
                completion = openai.ChatCompletion.create(
                        model=self.model_args.model_name_or_path.replace("openai/", ""),
                        messages=[
                            {"role": "user", "content": prompt}
                        ],
                        temperature=0
                    )
            except: #Try twice, API sometimes fails due to server issues
                completion = openai.ChatCompletion.create(
                    model=self.model_args.model_name_or_path.replace("openai/", ""),
                    messages=[
                        {"role": "user", "content": prompt}
                    ],
                    temperature=0
                )
            response = completion.choices[0].message.content.strip()

        else:
            raise ValueError("model_name_or_path should be gpt-3.5-turbo or gpt-4 for this setting")
        
        return response  
    
    
    def flush_logs(self):
        
        self.conversation_logs = defaultdict(list)
        
        
    def _create_session_logs(self):
        
        self.conversation_logs[self.session_id] = {0:{"utterance":self.welcome_sentence,
                                                      "speaker":"SYSTEM",
                                                      "mode":"",
                                                      "belief_state":"",
                                                      "database_query":""}}
         
    def _update_logs(self, utterance, speaker, mode="", bs="", kb=""):
        
        self.conversation_logs[self.session_id][self.turn] = {"utterance":utterance,
                                                              "speaker":speaker,
                                                              "mode":mode,
                                                              "belief_state":bs,
                                                              "database_query":kb}
        self.turn += 1
    
    def _save_logs(self):
        log_file = datetime.datetime.now().strftime("log__%d-%m-%Y__%H-%M-%S.json")
        with open(os.path.join(self.data_args.log_path, log_file), "w") as f:
            json.dump(self.conversation_logs, f, indent=4)
    
    
    
    
class InstructTODS(AgentConfig):

    def __init__(self, 
                 config, 
                 model_args, 
                 data_args):
        super().__init__(config, 
                         model_args, 
                         data_args)
        
        self.conversation_logs = defaultdict(list)
        self.agent = self._load_agent()
        self.turn = 1
        
    
    def _print_config(self):
        print("\n\n")
        print("=================="*5)
        print(f"Current task: {self.task_objective}")
        print(f"Max dialogue history for belief state context: {self.data_args.dialog_history_limit_bs}")
        print(f"Max dialogue history for response generation context: {self.data_args.dialog_history_limit_rg}")
        print(f"Config file: {self.data_args.config_path}")
        print(f"Logging directory: {self.data_args.log_path}")
        print(f"Knowledge base path: {self.data_args.load_path}\n")
        print("To exit the interaction, type 'q' or 'quit' in the prompt")
        print("=================="*5)
        print("\n\n")
        


    def _load_agent(self):
        #Only support using OpenAI models currently (GPT3.5, GPT4)
        
        model = OpenAI(model_name=self.model_args.model_name_or_path, 
                       temperature=self.model_args.temperature)

        agent = create_pandas_dataframe_agent(llm=model, 
                                              df=self.kb_df, 
                                              max_iterations=self.data_args.agent_max_iterations, 
                                              verbose=self.data_args.verbose)
        
        return agent

    
    def reset_knowledge_base(self, df):
        
        self.agent = create_pandas_dataframe_agent(llm=self.model, 
                                                   df=df, 
                                                   max_iterations=self.data_args.agent_max_iterations, 
                                                   verbose=self.data_args.verbose)
    
    
    def reset_task_objective(self, objective:str):
        
        self.task_objective = objective

        
    
    def _parse_dialogue_history(self, mode):
        
        prompt_dh = ""       
        speakers = [v for turn, content in self.conversation_logs[self.session_id].items() for k, v in content.items() if k == "speaker"]
        dh = [v for turn, content in self.conversation_logs[self.session_id].items() for k, v in content.items() if k == "utterance"]
        
        if mode == "bs":
            dh_limit = self.data_args.dialog_history_limit_bs
        elif mode == "rg":
            dh_limit = self.data_args.dialog_history_limit_rg
        else:
            raise ValueError("Can only parse dialogue history in two modes: bs or rg.")
        
        #Assume here that we only use this method when the last turn was from the user
        if len(dh) < dh_limit:
            L = len(dh)
        else:
            L = dh_limit

        for idx in reversed(range(1, L+1)):
            prompt_dh += f"{speakers[-idx]}: {dh[-idx]}\n"

        return prompt_dh


    def _tod_or_chitchat(self, utterance):
        
        dialogue_context = self._parse_dialogue_history(mode="rg")
        prompt = self.prompt_classification_template.format(instruction=self.classification_instruction,
                                                            example=self.classification_example,
                                                            task=self.task_objective,
                                                            dialogue_context=dialogue_context,
                                                            utterance=utterance)
        output = self._completion(prompt) 
        return output

    
    def _tod_turn(self, utterance):
        
        # print(f"USER: {utterance}")
        dialogue_context_bs = self._parse_dialogue_history(mode="bs") + f"USER: {utterance}"
        prompt = self.prompt_bs_template.format(instruction=self.bs_instruction,
                                                example=self.bs_example,
                                                information=", ".join(self.attributes),
                                                dialogue_context=dialogue_context_bs)
        output = self._completion(prompt)
        print(f"\n----------------\nCurrent belief state: {output}")
        self._update_logs(utterance, "USER", bs=output)

        with get_openai_callback() as cb:
            try:
                query_df = self.agent.run(f"If there are many fitting this criteria, pick a few to propose: {output}") #Use fake intermediary belief state
            except ValueError as e:
                response = str(e)
                if not response.startswith("Could not parse LLM output: `"):
                    raise e
                query_df = response.removeprefix("Could not parse LLM output: `").removesuffix("`")
            if self.model_args.print_cost:
                print(f"Total Tokens: {cb.total_tokens}")
                print(f"Prompt Tokens: {cb.prompt_tokens}")
                print(f"Completion Tokens: {cb.completion_tokens}")
                print(f"Total Cost (USD): ${cb.total_cost}")
                
        if query_df == "Agent stopped due to iteration limit or time limit.":
            query_df = "There is nothing that fits the criteria. Ask for more information."
        print(f"Query database results: {query_df}\n----------------\n")
        dialogue_context_rg = self._parse_dialogue_history(mode="rg") + f"USER: {utterance}\nSYSTEM:"
        prompt = self.prompt_rg_template.format(instruction=self.rg_instruction,
                                                example=self.rg_example,
                                                dialogue_context=dialogue_context_rg,
                                                dialogue_act=query_df)
        response = self._completion(prompt)
        print(f"SYSTEM: {response}")
        self._update_logs(response, "SYSTEM", mode=self.mode, kb=query_df)

        return response

    
    def _chitchat_turn(self, utterance):
        
        self._update_logs(utterance, "USER")
        # print(f"USER: {utterance}")
        dialogue_context = self._parse_dialogue_history(mode="rg") + f"USER: {utterance}\nSYSTEM:"
        prompt = self.prompt_chitchat_template.format(instruction=self.chitchat_instruction,
                                                      dialogue_context=dialogue_context)
        # print(dialogue_context)
        response = self._completion(prompt)
        print(f"SYSTEM: {response}")
        self._update_logs(response, "SYSTEM", self.mode)
        
        return response

    
    def interact(self):
        
        self.session_id = random.randint(100000000,999999999)
        self._create_session_logs()
        utterance = ""
        
        self._print_config()
        print(f"{self.welcome_sentence}\n=========================================\n")

        while True:
            utterance = input("USER: ")
            if utterance == "quit" or utterance == "q":
                self._save_logs()
                return "Finished" #TBD
            
            
            is_tod = self._tod_or_chitchat(utterance)
            
            if "yes" in is_tod.lower():
                self.mode = "Task-Oriented Dialogue"
                self._tod_turn(utterance)
            elif "no" in is_tod.lower():
                self.mode = "Chitchat"
                self._chitchat_turn(utterance)
            print("=========================================\n")

# TO-DO

- Add mode switch between TOD and chitchat
- Add chitchat with _chitchat_turn()  
- Add logging

In [239]:
model_args = ModelArguments()
data_args = DataArguments()

data_args.load_path = "/home/willy/instructod/demo/data/restaurant_db.json"
data_args.log_path = "/home/willy/instructod/demo/logs/"
data_args.verbose = False
data_args.dialog_history_limit_bs = 2
data_args.dialog_history_limit_rg = 5

model_args.model_name_or_path = "gpt-3.5-turbo-0301"

In [240]:
instructtods = InstructTODS(config=CONFIG,
                            model_args=model_args,
                            data_args=data_args)



In [241]:
#Testing single turn
# instructtods.session_id = "123456789"
# instructtods._create_session_log()
# instructtods._tod_turn("Test")

In [242]:
instructtods.interact()

Hi! I am a TOD with chitchat capability. My current task is Book a restaurant. What can I help you with?


Type: test


USER: test
SYSTEM: I'm sorry, I don't understand what you mean by "test". Can you please provide more information or ask a specific question?


Type: q


'Finished'

In [135]:
# db = json.load(open("data/restaurant_db.json", "r"))

In [202]:
instructtods.conversation_logs

defaultdict(list,
            {262811963: {0: {'utterance': 'Hi! I am a TOD with chitchat capability. My current task is Book a restaurant. What can I help you with?',
               'speaker': 'SYSTEM',
               'mode': '',
               'belief_state': '',
               'database_query': ''},
              1: {'utterance': "What's the biggest cat in the world?",
               'speaker': 'USER',
               'mode': '',
               'belief_state': '',
               'database_query': ''},
              2: {'utterance': "The biggest cat in the world is the Siberian tiger. However, as a restaurant booking assistant, I'm not an expert on animal facts. Is there anything else I can assist you with regarding booking a restaurant?",
               'speaker': 'SYSTEM',
               'mode': 'Chitchat',
               'belief_state': '',
               'database_query': ''},
              3: {'utterance': "It's fine. Can you recommend some british restaurant?",
               's

In [22]:
print(instructtods._parse_dialogue_history())

USER: information about "the oak bistro"
SYSTEM: I'm sorry, but I couldn't find any information about a restaurant that serves British food. Could you please provide more details or try searching for a different restaurant?
USER: this restaurant serves british food
SYSTEM: I'm sorry, but I couldn't find any information about "the oak bistro". Could you please provide more details or try searching for a different restaurant?



In [92]:
instructtods.kb_df

Unnamed: 0,address,area,food,id,introduction,location,name,phone,postcode,pricerange,type,signature
0,Regent Street City Centre,centre,italian,19210,Pizza hut is a large chain with restaurants na...,"[52.20103, 0.126023]",pizza hut city centre,1.223324e+09,cb21ab,cheap,restaurant,
1,Finders Corner Newmarket Road,east,international,30650,,"[52.21768, 0.224907]",the missing sock,1.223813e+09,cb259aq,cheap,restaurant,african babooti
2,106 Regent Street City Centre,centre,indian,19214,curry garden serves traditional indian and ban...,"[52.200187, 0.126407]",curry garden,1.223302e+09,cb21dp,expensive,restaurant,
3,82 Cherry Hinton Road Cherry Hinton,south,chinese,19192,,"[52.188528, 0.140627]",the good luck chinese food takeaway,1.223244e+09,cb17ag,expensive,restaurant,
4,G4 Cambridge Leisure Park Clifton Way Cherry H...,south,italian,19196,pizza hut is a large chain with restaurants na...,"[52.190176, 0.13699]",pizza hut cherry hinton,1.223324e+09,cb17dy,moderate,restaurant,
...,...,...,...,...,...,...,...,...,...,...,...,...
105,Midsummer Common,centre,british,508,,"[52.21251, 0.12774000000000002]",midsummer house restaurant,1.223369e+09,cb41ha,expensive,restaurant,seared scallops with truffle apple and celeriac
106,Bridge Street City Centre,centre,french,19230,cote is a modern french bistro offering some o...,"[52.209028, 0.118296]",cote,1.223311e+09,cb21uf,expensive,restaurant,
107,32 Bridge Street City Centre,centre,italian,19234,caffe uno is a chain of cafe style restaurants...,"[52.209632, 0.117213]",caffe uno,1.223449e+09,cb21uj,expensive,restaurant,
108,17 Hills Road City Centre,centre,chinese,19222,sesame restaurant and bar offers a wide variet...,"[52.197154, 0.129511]",sesame restaurant and bar,1.223359e+09,cb21nw,expensive,restaurant,
