In [4]:
%pip install -r requirements.txt

Note: you may need to restart the kernel to use updated packages.



[notice] A new release of pip available: 22.3.1 -> 23.2
[notice] To update, run: python.exe -m pip install --upgrade pip


In [1]:
import gradio as gr

# Settings shared between the settings app (this cell) and the chat app (next
# cell). The callbacks below rebind these names with `global`; a `global`
# statement at module level is a no-op, so none is needed here.

# LLM backend, one of the LLM dropdown values: "GPT" or "Alpaca Lora".
llm = "GPT"

# Recommendation domain, the lower-cased domain dropdown value.
domain = "restaurant"

# Google API key for geocoding in the restaurant domain; None selects the
# Nominatim geocoder instead.
google_api_key = None

# OpenAI API key (llm == "GPT") or Gradio URL (llm == "Alpaca Lora").
# Initialized here so later cells can always read the name, even before
# anything has been submitted.
openai_api_key_or_gradio_url = None

with open("demo_css/set_param.css", "r", encoding="utf-8") as f:
    set_param_css = f.read()

# Prompts for the credential textbox, one pair per LLM backend. Shared by the
# initial layout and `set_llm` so both always show the same wording.
OPENAI_KEY_PLACEHOLDER = "Enter OpenAI API Key"
OPENAI_KEY_LABEL = "To get OpenAI API Key, go to https://platform.openai.com/playground and after logging into your account, click on 'View API keys'."
GRADIO_URL_PLACEHOLDER = "Enter Gradio URL"
GRADIO_URL_LABEL = "To get Gradio URL, go to https://colab.research.google.com/drive/1FfKTLmVV0rQSQWkvoGpiyb1RuK7E1l6k?usp=sharing and run cells. The url is after 'Running on public URL:', outputted by the cell below 'Gradio.live API hosting'."

with gr.Blocks(css=set_param_css) as demo:
    gr.Markdown("# Please choose a domain of recommendation you would like to try")
    with gr.Row():
        select_domain = gr.Dropdown(
            ["Restaurant", "Clothing"], value="Restaurant",
            show_label=False
        )
    with gr.Row():
        with gr.Column(scale=9):
            google_api_key_box = gr.Textbox(
                elem_id="google_api_key_box",
                placeholder="Enter Google API Key. If you don't have one, leave it blank.",
                label="To get Google API Key, go to https://console.developers.google.com/ and navigate to 'Credentials' tab after logging in to your google cloud account and then click on 'CREATE CREDENTIALS'.",
                interactive=True)
        with gr.Column(min_width=70, scale=1):
            google_api_key_submit_button = gr.Button(value="Submit", elem_id="google_api_key_submit_button")
    gr.Markdown("# Please choose LLM")
    gr.Markdown("### It's a lot better to use GPT but if you don't want to get openai api key, you can use Alpaca Lora.")
    with gr.Row():
        select_llm = gr.Dropdown(
            ["GPT", "Alpaca Lora"], value="GPT",
            show_label=False
        )
    with gr.Row():
        with gr.Column(scale=9):
            openai_api_key_or_gradio_url_box = gr.Textbox(
                elem_id="openai_api_key_or_gradio_url_box",
                placeholder=OPENAI_KEY_PLACEHOLDER,
                label=OPENAI_KEY_LABEL)
        with gr.Column(min_width=70, scale=1):
            llm_key_submit_button = gr.Button(value="Submit", elem_id="llm_key_submit_button")

    def set_domain(selected_domain: str) -> dict:
        """
        Set the global `domain` and show the Google API key box only where it is used.

        The key only feeds the geocoder, which is built solely for the
        restaurant domain, so the box is hidden for any other domain.

        :param selected_domain: user selected domain, e.g. "Restaurant" or "Clothing"
        :return: a gr.update(...) toggling the visibility of google_api_key_box
        """
        global domain
        domain = selected_domain.lower()
        return gr.update(visible=(domain == "restaurant"))

    def set_google_api_key(google_key_input: str) -> None:
        """
        Store the submitted Google API key in the global `google_api_key`.

        Surrounding whitespace (e.g. from copy-pasting) is stripped; a blank or
        whitespace-only submission resets the key to None.

        :param google_key_input: user input for google_api_key
        """
        global google_api_key
        google_api_key = (google_key_input or "").strip() or None

    def set_llm(selected_llm: str) -> dict:
        """
        Set the global `llm` and reset the credential textbox for the new backend.

        :param selected_llm: user selected llm ("GPT" or "Alpaca Lora")
        :return: a gr.update(...) that clears the textbox and swaps its placeholder/label
        """
        global llm, openai_api_key_or_gradio_url
        llm = selected_llm
        # The textbox is cleared below, so forget the stored value as well;
        # otherwise an OpenAI key could be passed on as a Gradio URL (or vice versa).
        openai_api_key_or_gradio_url = None
        if llm == "GPT":
            return gr.update(value="", placeholder=OPENAI_KEY_PLACEHOLDER, label=OPENAI_KEY_LABEL)
        return gr.update(value="", placeholder=GRADIO_URL_PLACEHOLDER, label=GRADIO_URL_LABEL)

    def set_openai_api_key_or_gradio_url(openai_api_key_or_gradio_url_input: str) -> None:
        """
        Store the submitted OpenAI API key / Gradio URL in the global
        `openai_api_key_or_gradio_url`, with surrounding whitespace stripped.

        :param openai_api_key_or_gradio_url_input: user input for openai_api_key_or_gradio_url
        """
        global openai_api_key_or_gradio_url
        openai_api_key_or_gradio_url = (openai_api_key_or_gradio_url_input or "").strip()

    select_domain.input(
        fn=set_domain, inputs=select_domain, outputs=google_api_key_box)

    # Pressing Enter in a textbox and clicking its Submit button do the same thing.
    google_api_key_box.submit(fn=set_google_api_key, inputs=google_api_key_box)
    google_api_key_submit_button.click(fn=set_google_api_key, inputs=google_api_key_box)

    select_llm.input(
        fn=set_llm, inputs=select_llm, outputs=openai_api_key_or_gradio_url_box)

    openai_api_key_or_gradio_url_box.submit(
        fn=set_openai_api_key_or_gradio_url, inputs=openai_api_key_or_gradio_url_box)
    llm_key_submit_button.click(
        fn=set_openai_api_key_or_gradio_url, inputs=openai_api_key_or_gradio_url_box)

if __name__ == "__main__":
    demo.launch()

  from .autonotebook import tqdm as notebook_tqdm


Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.


In [5]:
import os
import time
from collections.abc import Iterator

import gradio as gr
import yaml

from conv_rec_system import ConvRecSystem
from domain_specific.classes.restaurants.geocoding.nominatim_wrapper import NominatimWrapper
from domain_specific.classes.restaurants.geocoding.google_v3_wrapper import GoogleV3Wrapper
from domain_specific.classes.restaurants.location_constraint_merger import LocationConstraintMerger
from domain_specific.classes.restaurants.location_status import LocationStatus
from domain_specific.classes.restaurants.location_filter import LocationFilter
from information_retrievers.filter.word_in_filter import WordInFilter

# The settings cell leaves this undefined (or None) until a key/URL is
# submitted; fail fast with an actionable message before touching
# system_config.yaml.
if globals().get("openai_api_key_or_gradio_url") is None:
    raise ValueError(
        "No OpenAI API key or Gradio URL has been submitted. Enter it in the settings app above, "
        "click 'Submit', and re-run this cell.")

# Point the shared config at the LLM and domain chosen in the settings app and
# write it back to system_config.yaml. The file holds plain data, so safe_load
# is sufficient and mirrors the safe_dump used to write it.
with open("system_config.yaml", encoding="utf-8") as f:
    config = yaml.safe_load(f)
config['LLM'] = llm
config['PATH_TO_DOMAIN_CONFIGS'] = f'domain_specific/configs/{domain}_configs'

with open("system_config.yaml", "w", encoding="utf-8") as f:
    f.write(yaml.safe_dump(config, sort_keys=False))

if domain == "restaurant":
    if google_api_key is None:
        # No Google key: use Nominatim, probing it with a known city. A None
        # result is treated as "geocoder unusable" and disables geocoding.
        geocoder = NominatimWrapper()
        if geocoder.geocode("toronto") is None:
            geocoder = None
    else:
        # GoogleV3Wrapper takes no arguments here; presumably it reads the key
        # from GOOGLE_API_KEY — TODO confirm.
        os.environ['GOOGLE_API_KEY'] = google_api_key
        geocoder = GoogleV3Wrapper()

    if geocoder is None:
        # Without a geocoder, fall back to a word-based filter on the
        # "address" field (presumably matching the "location" constraint).
        user_filter_objects = [WordInFilter(["location"], "address")]

        conv_rec_system = ConvRecSystem(
            config, openai_api_key_or_gradio_url,
            user_defined_filter=user_filter_objects, user_interface_str="demo")
    else:
        user_constraint_merger_objects = [LocationConstraintMerger(geocoder)]
        user_constraint_status_objects = [LocationStatus(geocoder)]
        # NOTE(review): 2 is presumably a distance threshold — confirm its units.
        user_filter_objects = [LocationFilter("location", ["latitude", "longitude"], 2, geocoder)]

        conv_rec_system = ConvRecSystem(
            config, openai_api_key_or_gradio_url, user_defined_constraint_mergers=user_constraint_merger_objects,
            user_constraint_status_objects=user_constraint_status_objects,
            user_defined_filter=user_filter_objects, user_interface_str="demo")
else:
    conv_rec_system = ConvRecSystem(
        config, openai_api_key_or_gradio_url, user_interface_str="demo")

with open("demo_css/chatbot.css", "r", encoding="utf-8") as f:
    chatbot_css = f.read()

# Delay between streamed characters of a bot reply, in seconds.
TYPING_DELAY_S = 0.015


def initial_chat() -> list:
    """
    Return a fresh chat history holding only the system's greeting.

    A new list is built on every call so the chatbot value, the history state
    and later resets never share (and mutate) the same object.
    """
    return [[None, conv_rec_system.init_msg]]


with gr.Blocks(css=chatbot_css) as demo:
    gr.Markdown("# LLM Convrec")
    history = gr.State(initial_chat())
    with gr.Row():
        chatbot = gr.Chatbot(
            value=initial_chat(), show_label=False, elem_id="llm_conv_rec")
    with gr.Row(equal_height=True):
        with gr.Column(scale=8):
            user_input = gr.Textbox(show_label=False, placeholder="Enter text", container=False)
        with gr.Column(min_width=70, scale=1):
            send_button = gr.Button(value="Send")
        with gr.Column(min_width=70, scale=3):
            new_conv_button = gr.Button(value="New Conversation")

    def display_user_input(user_message: str, chatbot: list, history: list) -> tuple[str, list, list]:
        """
        Display user input and open an empty slot for the recommender's reply.

        :param user_message: user input
        :param chatbot: [user, bot] pairs currently displayed
        :param history: [user, bot] pairs holding the complete bot replies
        :return: cleared textbox value, updated chatbot value, updated history
        """
        # If the previous reply was still being streamed, show it in full
        # (history already holds the complete text) before adding the new turn.
        chatbot[-1][1] = history[-1][1]
        return "", chatbot + [[user_message, None]], history + [[user_message, None]]

    def display_recommender_response(chatbot: list, history: list) -> Iterator[tuple[list, list]]:
        """
        Get the recommender's reply to the latest user message and stream it into the chat.

        :param chatbot: displayed [user, bot] pairs; the last pair holds the new user message
        :param history: [user, bot] pairs; receives the complete reply up front
        :yield: (chatbot, history) after each revealed character, or once for an empty reply
        """
        bot_message = conv_rec_system.get_response(chatbot[-1][0])
        history[-1][1] = bot_message
        chatbot[-1][1] = ""
        if not bot_message:
            # Nothing to stream; still emit one update so the reply slot renders.
            yield chatbot, history
            return
        for character in bot_message:
            chatbot[-1][1] += character
            time.sleep(TYPING_DELAY_S)
            yield chatbot, history

    def reset_state() -> tuple[str, list, list]:
        """
        Start a new conversation: reset the dialogue state and restore the greeting.

        :return: cleared textbox value, fresh chatbot value, fresh history
        """
        conv_rec_system.dialogue_manager.state_manager.reset_state()
        return "", initial_chat(), initial_chat()

    # Pressing Enter and clicking "Send" run the same two-step pipeline: echo
    # the user's message, then stream the recommender's reply.
    for send_event in (user_input.submit, send_button.click):
        send_event(
            fn=display_user_input, inputs=[user_input, chatbot, history],
            outputs=[user_input, chatbot, history], queue=True).then(
                fn=display_recommender_response, inputs=[chatbot, history], outputs=[chatbot, history])

    new_conv_button.click(
        fn=reset_state, outputs=[user_input, chatbot, history], queue=True)


if __name__ == "__main__":
    # Gradio streams updates from generator callbacks (display_recommender_response)
    # through its queue, so enable it before launching.
    demo.queue()
    demo.launch()
    

Loaded as API: https://51316bc1ca4d44f687.gradio.live/ ✔


Exception: The provided Gradio URL is invalid. Please input a correct url and retry.