In [None]:
import gc
import re
from datetime import datetime as dt

import nltk
import torch
import yake
from nltk import tokenize as sentence_delim
from transformers import pipeline

from citation_fetcher import Citation_Fetcher as cf

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
class ResponseGenerator():
    """Generate a chatbot reply to a user message, optionally backed by an online search.

    generate() pulls a search phrase out of the last sentence of the input with
    the YAKE keyword extractor, passes it to ``cf.search_online``, and then
    prompts a text-generation pipeline with the system message, the user
    message and any search results. The pipeline is created for each call and
    released again afterwards.
    """

    # Model id passed to transformers.pipeline.
    BASE_MODEL_NAME = "HuggingFaceH4/zephyr-7b-alpha"

    # System message sent with every request. The line break and the run of
    # spaces are part of the prompt text and are kept exactly as written.
    BASE_SYSTEM_PROMPT = (
        "You are a friendly chatbot that provides reliable information to the user.\n"
        "                    Your goals are to reduce suffering in the universe, increase prosperity in the universe, and increase understanding in the universe."
    )

    # Extra system instructions added only when search results are attached.
    SEARCH_INSTRUCTIONS = "You have submitted a query search engine that can help you answer the user's question. Please summarize the query results that can best answer the user's question. Cite each result by copying the \"href\" value."

    def __init__(self):
        # The pipeline is only held while generate() is running.
        self.base_model = None
        self.base_system_msg = None
        self.kw_extractor = yake.KeywordExtractor()

    def generate(self, user_input: str, DEBUG_MODE=True):
        """Answer ``user_input`` and return ``(query_results, model_output)``.

        Args:
            user_input: The user's message. Only its last sentence is used to
                derive the search phrase; the whole message is sent to the model.
            DEBUG_MODE: When true, print how long keyword extraction and the
                search-plus-generation phase took.

        Returns:
            A tuple ``(query_results, model_output)``. ``query_results`` is
            whatever ``cf.search_online`` returned, or ``None`` when no keyword
            was found and the search was skipped. ``model_output`` is the
            model's reply as a string.
        """
        # Bound up front so the return statement is valid when the search is skipped.
        query_results = None

        if DEBUG_MODE:
            print("Debug mode has been enabled. Capturing runtime...\n")
            query_gen_start = dt.now()

        keywords = self._extract_keywords(user_input)

        if DEBUG_MODE:
            query_gen_time = dt.now() - query_gen_start
            print("Search Query Generation time: " + str(query_gen_time))
            base_model_gen_start = dt.now()

        base_model_gen_msg = [
            {"role": "system", "content": self.BASE_SYSTEM_PROMPT},
            {"role": "user", "content": user_input},
        ]

        if keywords:
            # extract_keywords yields (phrase, score) pairs; the first phrase is the query.
            search_query = keywords[0][0]
            print("Searching for: \"" + search_query + "\"...")
            query_results = cf.search_online(search_query)
            # NOTE(review): "query_results" is not one of the usual
            # system/user/assistant roles - confirm the model's chat template
            # actually renders this message instead of dropping it.
            base_model_gen_msg.append({"role": "query_results", "content": query_results})
            # The content is a str, so it is extended by concatenation.
            base_model_gen_msg[0]["content"] += " " + self.SEARCH_INSTRUCTIONS
        else:
            print("Skipping search query...")

        model_output = self._run_base_model(base_model_gen_msg)

        if DEBUG_MODE:
            base_model_gen_time = dt.now() - base_model_gen_start
            print("Base/Response Model generation time: " + str(base_model_gen_time))
            total_gen_time = query_gen_time + base_model_gen_time
            print("Total generation time: " + str(total_gen_time) + "\n")
        return query_results, model_output

    def _extract_keywords(self, user_input):
        """Return the extractor's keywords for the last sentence, or [] if there is none."""
        # Recreate the extractor if it is missing or None so one instance can
        # serve any number of generate() calls.
        if getattr(self, "kw_extractor", None) is None:
            self.kw_extractor = yake.KeywordExtractor()
        nltk.download('punkt')
        sentences = sentence_delim.sent_tokenize(user_input)
        if not sentences:
            # Empty or whitespace-only input: nothing to search for.
            return []
        return self.kw_extractor.extract_keywords(sentences[-1])

    def _run_base_model(self, messages):
        """Create the pipeline, generate a reply for ``messages``, and release the pipeline."""
        gc.collect()
        torch.cuda.empty_cache()
        try:
            self.base_model = pipeline(
                "text-generation",
                model=self.BASE_MODEL_NAME,
                torch_dtype=torch.bfloat16,
                device_map="auto",
            )
            prompt = self.base_model.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            outputs = self.base_model(
                prompt,
                max_new_tokens=256,
                do_sample=True,
                temperature=0.7,
                top_k=50,
                top_p=0.95,
            )
        finally:
            # Drop the reference (the attribute itself stays) even if loading
            # or generation raised, then try to reclaim memory.
            self.base_model = None
            gc.collect()
            torch.cuda.empty_cache()
        return self._extract_reply(prompt, outputs)

    @staticmethod
    def _extract_reply(prompt, outputs):
        """Return only the newly generated text from the pipeline output."""
        # The text-generation pipeline returns a list with one dict per sequence.
        generated_text = outputs[0]["generated_text"]
        if generated_text.startswith(prompt):
            # The prompt is echoed back in front of the reply; cut it off.
            reply = generated_text[len(prompt):]
        else:
            # Otherwise keep what follows the last assistant tag (any letter
            # case); without a tag the whole text is the reply.
            reply = re.split(r"<\|assistant\|>\n?", generated_text, flags=re.IGNORECASE)[-1]
        return reply.strip()


In [None]:
# Build one generator and ask a first question with timing output enabled.
# generate() hands back a pair, unpacked here as search results and reply text.
rg = ResponseGenerator()
query_results, out = rg.generate(
    user_input="Hello, Alethianomous. What is the weather at Atlanta, Georgia?",
    DEBUG_MODE=True,
)

Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:03<00:00,  2.33it/s]


Debug mode has been enabled. Capturing runtime...



[nltk_data] Downloading package punkt to
[nltk_data]     /home/sp15-chatbot/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


KeyboardInterrupt: 

In [None]:
# Show the reply text unpacked from rg.generate() in the earlier cell.
# `out` only exists if that call ran to completion.
print(out)

In [None]:
# Prints the whole (query_results, model_output) pair returned by generate().
print(rg.generate(user_input="When is OwlCon at Kennesaw State University?"))

In [None]:
# Prints the whole (query_results, model_output) pair returned by generate().
print(rg.generate(user_input="When is the next solar eclipse?"))

In [None]:
# Prints the whole (query_results, model_output) pair returned by generate().
print(rg.generate(user_input="Do you think you are sentient?"))

In [None]:
# Prints the whole (query_results, model_output) pair returned by generate().
print(rg.generate("How are cars made?"))