In [1]:
from citation_fetcher import Citation_Fetcher as cf
from datetime import datetime as dt
from transformers import pipeline

import re
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class ResponseGenerator():

    def __init__(self):
        self.base_model = (pipeline("text-generation", model="HuggingFaceH4/zephyr-7b-alpha", 
        torch_dtype = torch.bfloat16, device_map="auto"))
        self.search_query_model = (pipeline("text-generation", model="HuggingFaceH4/zephyr-7b-alpha", 
            torch_dtype = torch.bfloat16, device_map="auto"))
        self.base_system_msg = None

    def generate(self, user_input: str, DEBUG_MODE=True):
        if DEBUG_MODE:
            print("Debug mode has been enabled. Capturing runtime...\n")
            query_gen_start: dt = dt.now()
        query_gen_msg = [
            {
                "role": "system",
                "content": """You are a friendly chatbot that provides reliable information to the user. 
                    Your goals are to reduce suffering in the universe, increase prosperity in the universe, and increase understanding in the universe.
                        If the user is asking questions that require you to provide information, please output a short search query that the user
                        can use to search online. Otherwise, reply by only saying "-1"."""
            },
            {
                "role": "user",
                "content": user_input
            }
        ]
        srch_qry_prompt = self.search_query_model.tokenizer.apply_chat_template(query_gen_msg, tokenize=False, add_generation_prompt=True)
        search_query = self.search_query_model(srch_qry_prompt,max_new_tokens=256, do_sample=True, temperature=0.7, top_k=50, top_p=0.95)
        search_query = re.search("<|assistant|>\n", search_query)
        
        if DEBUG_MODE:
            query_gen_end: dt = dt.now()
            query_gen_time: dt = query_gen_end - query_gen_start
            print("Search Query Generation time: " + str(query_gen_time))
            base_model_gen_start: dt = dt.now()
        base_model_gen_msg  = [
                {
                    "role": "system",
                "content": """You are a friendly chatbot that provides reliable information to the user. 
                    Your goals are to reduce suffering in the universe, increase prosperity in the universe, and increase understanding in the universe."""
                },
                {
                    "role": "user",
                    "content": user_input
                }
            ]
        if search_query != "\0":
            query_results = cf.search_online(search_query)
            base_model_gen_msg.append({"role": "query_results", "content": query_results})
            base_model_gen_msg[0]["content"].append("You have submitted a query search engine that can help you answer the user's question. Please summarize the query results that can best answer the user's question. For each result you summarize, cite it by inserting the value of the 'href' attribute.")
        else:
            print("Skipping search query...")
            pass
        base_model_prompt = self.base_model.tokenizer.apply_chat_template(base_model_gen_msg, tokenize=False, add_generation_prompt=True)
        model_output = self.base_model(base_model_prompt, max_new_tokens=256, do_sample=True, temperature=0.7, top_k=50, top_p=0.95)

        if DEBUG_MODE:
            base_model_gen_end: dt = dt.now()
            base_model_gen_time: dt = base_model_gen_end - base_model_gen_start
            print("Base/Response Model generation time: " + str(base_model_gen_time))
            total_gen_time: dt = query_gen_time + base_model_gen_time
            print("Total generation time: " + str(total_gen_time) + "\n")
        return model_output    
        

In [3]:
rg = ResponseGenerator()
out = rg.generate("Hello, Alethianomous. How are you?", DEBUG_MODE=True)

Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:01<00:00,  4.81it/s]
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 11.06it/s]


Debug mode has been enabled. Capturing runtime...



Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


KeyboardInterrupt: 

In [None]:
print(out)