In [1]:
import os
import json
from transformers import AutoTokenizer, pipeline
import torch
from pdb import set_trace as breakpoint

In [2]:
model_name = "meta-llama/Llama-2-7b-chat-hf"

# Hugging Face download cache. Set LLM_CACHE_DIR to choose a location;
# None means the library's default cache directory.
CACHE_DIR = os.environ.get("LLM_CACHE_DIR")

tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=CACHE_DIR)
# Hand the already-loaded tokenizer to the pipeline and send the model weights
# to the same cache, instead of letting pipeline() load a second tokenizer and
# the weights into a different (default) cache location.
llm = pipeline(
    "text-generation",
    model=model_name,
    tokenizer=tokenizer,
    torch_dtype=torch.float16,
    device_map="auto",
    model_kwargs={"cache_dir": CACHE_DIR},
)

# Built-in scenarios: key -> scenario description shown to the model ("desc")
# and the role the AI plays in it ("role").
DEFAULT_SCENARIOS = {
    "restaurant": {"desc": "The user is in a restaurant ordering food.", "role": "a waiter"},
    "job_interview": {"desc": "The user is in a job interview for a software engineering position.", "role": "an interviewer"},
    "travel": {"desc": "The user is at an airport checking in for a flight.", "role": "a check-in agent"}
}


# One JSON file per user is stored here (see load_user_profile / save_user_profile).
PROFILE_DIR = "user_profiles"
os.makedirs(PROFILE_DIR, exist_ok=True)



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Device set to use cuda:0


In [53]:
def load_user_profile(username, profile_dir=None):
    """Load a user's saved profile, or build a fresh one if none exists.

    Args:
        username: Name used to locate ``<profile_dir>/<username>.json``.
        profile_dir: Directory holding profile files. Defaults to ``PROFILE_DIR``.

    Returns:
        A dict that always contains the keys ``username``, ``scenario``,
        ``chat_history`` and ``ai_role``. Keys missing from an older or partial
        profile file are filled with their defaults, so callers can index them
        directly.
    """
    if profile_dir is None:
        profile_dir = PROFILE_DIR
    profile_path = os.path.join(profile_dir, f"{username}.json")

    if os.path.exists(profile_path):
        with open(profile_path, "r", encoding="utf-8") as file:
            user_profile = json.load(file)
    else:
        user_profile = {}

    # Backfill every key the chat loop reads (not just ai_role) so profiles
    # written by earlier versions of this notebook don't raise KeyError.
    user_profile.setdefault("username", username)
    user_profile.setdefault("scenario", None)
    user_profile.setdefault("chat_history", [])
    user_profile.setdefault("ai_role", None)
    return user_profile


def save_user_profile(user_profile, profile_dir=None):
    """Persist a profile dict as ``<profile_dir>/<username>.json``.

    The JSON is written to a temporary file first and then moved over the real
    file with ``os.replace``. An interrupted save (e.g. a kernel interrupt
    mid-write) therefore leaves the previous profile intact instead of a
    truncated, unreadable file.

    Args:
        user_profile: Profile dict; must contain a ``username`` key.
        profile_dir: Directory holding profile files. Defaults to ``PROFILE_DIR``.
    """
    if profile_dir is None:
        profile_dir = PROFILE_DIR
    profile_path = os.path.join(profile_dir, f"{user_profile['username']}.json")
    tmp_path = profile_path + ".tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as file:
            json.dump(user_profile, file, indent=4)
        os.replace(tmp_path, profile_path)
    except BaseException:
        # Don't leave a half-written temp file behind (covers KeyboardInterrupt too).
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


def format_prompt(user_profile, user_input, max_turns=5):
    """Build a ``### Human`` / ``### Assistant`` prompt from recent chat history.

    Args:
        user_profile: Profile dict with ``scenario`` and ``chat_history`` keys;
            each history entry is a dict with ``user`` and ``ai`` strings.
        user_input: The new human message to append.
        max_turns: How many of the most recent history entries to include.
            ``0`` or a negative value includes none. (A plain ``[-0:]`` slice
            would include the whole history, hence the explicit check.)

    Returns:
        The prompt string. It ends with ``### Assistant:`` so the model's
        completion is the assistant's next line.
    """
    history = user_profile["chat_history"]
    recent = history[-max_turns:] if max_turns > 0 else []

    parts = [f"[Scenario]: {user_profile['scenario']}\n\n"]
    parts.extend(
        f"### Human: {entry['user']}\n### Assistant: {entry['ai']}\n" for entry in recent
    )
    parts.append(f"### Human: {user_input}\n### Assistant:")
    return "".join(parts)


def infer_ai_role(scenario_description, generator=None):
    """Ask the language model which role the AI should play in a scenario.

    Greedy decoding (``do_sample=False``) with a short token budget is used.
    Only the first non-empty line of the completion (with trailing periods
    removed) is kept. The result is used as a speaker label and inside prompt
    headers, so any follow-on explanation the model appends is discarded.

    Args:
        scenario_description: Free-text scenario, e.g. a DEFAULT_SCENARIOS "desc".
        generator: Text-generation callable with the same call/return shape as
            the ``llm`` pipeline; defaults to ``llm``.

    Returns:
        The inferred role string, or ``"assistant"`` if the model produced
        nothing usable.
    """
    if generator is None:
        generator = llm
    prompt = (
        "Based on the following scenario, what role should the AI play?\n\n"
        f"Scenario: {scenario_description}\n\nAI Role:"
    )
    generated = generator(
        prompt,
        do_sample=False,
        max_new_tokens=10,
    )[0]["generated_text"]
    # The pipeline echoes the prompt by default; keep only what follows the cue.
    completion = generated.split("AI Role:")[-1]
    for line in completion.splitlines():
        role = line.strip().rstrip(".").strip()
        if role:
            return role
    return "assistant"

def _choose_scenario(user_profile):
    """Prompt for a scenario, set ``scenario`` and ``ai_role`` on the profile, and save it."""
    print("Choose a scenario:")
    for key, scenario in DEFAULT_SCENARIOS.items():
        print(f"- {key}: {scenario['desc']}")
    scenario_key = input("Enter scenario name (or type 'custom' to create your own): ").strip().lower()

    if scenario_key in DEFAULT_SCENARIOS:
        user_profile["scenario"] = DEFAULT_SCENARIOS[scenario_key]["desc"]
        user_profile["ai_role"] = DEFAULT_SCENARIOS[scenario_key]["role"]
    else:
        # Custom (or unrecognized) scenario: there is no preset role, so ask the
        # model for one instead of indexing DEFAULT_SCENARIOS (which raised KeyError).
        description = ""
        while not description:
            description = input("Describe your custom scenario: ").strip()
        user_profile["scenario"] = description
        user_profile["ai_role"] = infer_ai_role(description)
    save_user_profile(user_profile)


def _generate_reply(prompt):
    """Sample one completion for ``prompt`` from ``llm`` and return only the new text."""
    # return_full_text=False makes the pipeline return just the completion, so
    # the prompt is never echoed back regardless of how the prompt ends.
    return llm(
        prompt,
        do_sample=True,
        top_k=50,
        top_p=0.7,
        num_return_sequences=1,
        repetition_penalty=1.1,
        max_new_tokens=100,
        return_full_text=False,
    )[0]["generated_text"].strip()


def talk_agent(username):
    """Run an interactive role-play chat session for ``username``.

    New users pick one of DEFAULT_SCENARIOS (the AI role comes from that table)
    or describe a custom scenario (the role is inferred with ``infer_ai_role``).
    Returning users resume their saved scenario; a missing role is inferred.
    The AI opens the conversation, then the loop alternates user input and
    model replies, saving the profile after every turn.

    Special inputs (case-insensitive, surrounding whitespace ignored):
        ``return``       -- save the profile and leave the chat.
        ``suggest line`` -- have the model suggest lines; the suggestion is
                            printed but not recorded in the chat history.

    NOTE(review): each turn's prompt contains only the latest user message;
    the saved chat history is not fed back to the model.

    Args:
        username: Profile name passed to ``load_user_profile``.

    Returns:
        None.
    """
    user_profile = load_user_profile(username)

    if not user_profile["scenario"]:
        _choose_scenario(user_profile)
    else:  # Existing user, infer role if missing
        print("EXISTING USER")
        if not user_profile.get("ai_role"):
            print("CREATING AI ROLE")
            user_profile["ai_role"] = infer_ai_role(user_profile["scenario"])
            save_user_profile(user_profile)

    scenario = user_profile["scenario"]
    role = user_profile["ai_role"]

    opening_prompt = (
        f"[Scenario]: {scenario}\n\n"
        f"### Assistant ({role}):\n"
        "You are playing the role of the assistant role above in the scenario.\n"
        "Start the conversation.\nPost your response here: "
    )
    print("\nChat started! Type 'return' to pause.\n")
    response = _generate_reply(opening_prompt)
    print("RESPONSE START")
    print(f"{role}: {response}")
    user_profile["chat_history"].append({"user": "AI INITIATED", "ai": response})
    save_user_profile(user_profile)

    while True:
        user_input = input("You: ").strip()
        command = user_input.lower()

        if command == "return":
            print("Chat paused. You can resume later.")
            save_user_profile(user_profile)
            break
        if not user_input:
            continue  # nothing to send; ask again

        if command == "suggest line":
            suggest_prompt = (
                f"[Scenario]: {scenario}\n\n"
                f"### Assistant ({role}): (Provide three sentences to start the conversation in the form of a numbered list.)"
            )
            print(f"Suggested lines:\n{_generate_reply(suggest_prompt)}")
            continue

        turn_prompt = (
            f"[Scenario]: {scenario}\n\n"
            f"### Assistant ({role}):\n"
            "You are playing the role of the assistant role in the scenario above.\n"
            f"The user has responded with the following: {user_input}\n"
            "Continue the conversation by saying only your response.\nPost your response here:"
        )
        response = _generate_reply(turn_prompt)
        print(f"{role}: {response}")

        user_profile["chat_history"].append({"user": user_input, "ai": response})
        save_user_profile(user_profile)

In [54]:
# To do
# 1. Figure out better chat functions.
# 2. Improve the beginning of the prompt.
# 3. Prompt to input for suggestion.
# 4. Resuming after 'return' jumps straight into the conversation; maybe show a quick history/review first.

# The username becomes a file name under PROFILE_DIR, so re-prompt until it is
# non-empty and contains no path separators.
username = ""
while not username or os.path.basename(username) != username:
    username = input("Enter your username: ").strip()
talk_agent(username)

Enter your username:  hzt


Choose a scenario:
- restaurant: You are in a restaurant ordering food.
- job_interview: You are in a job interview for a software engineering position.
- travel: You are at an airport checking in for a flight.


Enter scenario name (or type 'custom' to create your own):  travel



Chat started! Type 'return' to pause.

RESPONSE START
a check-in agent: "Good afternoon, sir/ma'am! Welcome to [Airline Name]. How can I assist you today?"


You:  Hi. I would like to check in to the airline


a check-in agent: "Of course, sir/ma'am! May I have your boarding pass please? Or would you like to check in online and print it out at the airport?"


You:  I would like to check in online


a check-in agent: "Great, that's easy! Can you please provide me with your boarding pass and ID so I can scan them?"


You:  Sure. Here is my boarding pass and ID


a check-in agent: "Great, thank you! *scans boarding pass* Everything looks good. Would you like to upgrade to a higher class of service? We have options available for Economy Plus or Business Class."


You:  How much is the higher class?


a check-in agent: "The higher class fare for this flight is $200."


You:  return


Chat paused. You can resume later.
