### Main method: running TA Chatbot

Implemented:
* Contriever (IR)
* OPT (generation)
* MS_Marco (ranking)
* Separately, DocQuery retrieves direct "quotes" from the textbook, giving short, direct answers.

Todo: 
* Implement Flan-T5 (instruction-finetuned) instead of, or _maybe_ in addition to, OPT. 
* UI/UX.
  * Let users supply their own context so the bot can return direct answers about a piece of text they are studying.
* Fine-tuned models. 

Dream ideas:
* Implement the RoBERTa QA pipeline (from Jerome's DocQuery notebook).
* Implement Hugging Face SentenceTransformers.
* Implement context search with GPT embeddings ('QA'-specialized models).

To run: clone all of our repos into the same top-level directory.

Required directory structure:
```text
-- ta_chatbot
---- main_fun
---- data-generator
---- info-retrieval
```
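
Because the notebook assumes this layout, a quick check before importing `main` can catch a missing clone early. The sketch below is only an illustration and not part of the project: the helper name `missing_repos` is hypothetical, and it assumes the three repository names listed above sit directly under the top directory.

```python
from pathlib import Path

# Repository names taken from the directory listing above.
REQUIRED_REPOS = ("main_fun", "data-generator", "info-retrieval")


def missing_repos(top_dir, required=REQUIRED_REPOS):
    """Return the required repo names that are not directories under top_dir."""
    top = Path(top_dir)
    return [name for name in required if not (top / name).is_dir()]


if __name__ == "__main__":
    # Assumption: the notebook runs from inside one of the repos,
    # so the parent of the working directory is the top directory.
    missing = missing_repos(Path.cwd().parent)
    if missing:
        print("Missing repos:", ", ".join(missing))
    else:
        print("All required repos found.")
```

An empty list means every expected repository is present.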

In [1]:
import torch
import main


document-question-answering is already registered. Overwriting pipeline for task document-question-answering...


In [2]:
# Initialize the TA pipeline class and its parameters
USER_QUESTION = "What is the Moore machine?"
NUM_ANSWERS_GENERATED = 5
# Device on which to load the models (OPT)
device = torch.device("cuda:1")
# Fine-tuned weights; set to None to load the pretrained weights
opt_weight_path = "../lgm/data/model_weight/opt_finetune_b128_e20_lr5e06.pt"
# TA pipeline instance
ta = main.TA_Pipeline(device=device, opt_weight_path=opt_weight_path)

In [3]:
# TA pipeline

# Contriever: find relevant passages (document + question -> 5 passages)
top_context_list = ta.retrieve(user_question=USER_QUESTION, num_answers_generated=NUM_ANSWERS_GENERATED)
# OPT-1.3B: 5 passages + 1 question -> 5 answers
generated_answers_list = ta.OPT(
    user_question=USER_QUESTION,
    top_context_list=top_context_list,
    num_answers_generated=NUM_ANSWERS_GENERATED,
    print_answers_to_stdout=True,
)
# MS MARCO: score the 5 generated answers
scores = ta.re_ranking_ms_marco(generated_answers_list)
# Print the best answer (the one with the highest score)
index_of_best_answer = torch.argmax(scores)
print("\n-------------------------------------------------------------\n")
print("[Question]: \n", USER_QUESTION)
print("[Best answer]: \n", generated_answers_list[index_of_best_answer])

User question:  What is the Moore machine?
["Beginning with the state 0000, at the rising clock edge, the left (S_0)flip-flop toggles to 1.  The second (S_1) flip-flop sees this change as afalling clock edge and does nothing, leaving the counter instate 0001.  When the next rising clock edge arrives, the leftflip-flop toggles back to 0, which the second flip-flop sees as arising clock edge, causing it to toggle to 1.  The third (S_2) flip-flopsees the second flip-flop's change as a falling edge and does nothing,and the state settles as 0010.  We leave verification of the remainder of the cycle as an exercise.", 'The signal A in the timing diagram is an output from the FSM, andindicates whether or not the coin should be accepted.  This signal controls the servo that drives the gate, and thus determines whetherthe coin is accepted (A=1) as payment or rejected (A=0) and returnedto the user.', "A Mealy implementation of the FSM appears on the left below, andan example timing diagram illust
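
The cell above follows a retrieve, generate, re-rank, select pattern. The sketch below isolates just that control flow in plain Python so it can be read without loading any models. The three callables are stand-ins, not the real `TA_Pipeline` methods: `retrieve`, `generate`, and `score` are hypothetical parameters, and the toy corpus and word-overlap scorer are invented purely for illustration.

```python
def best_answer(question, retrieve, generate, score, num_answers=5):
    """Retrieve passages, generate one answer per passage, return the top-scored answer."""
    passages = retrieve(question, num_answers)
    answers = [generate(question, passage) for passage in passages]
    scores = [score(question, answer) for answer in answers]
    # Plain-Python argmax: index of the highest score (first one wins on ties).
    best_index = max(range(len(scores)), key=scores.__getitem__)
    return answers[best_index], scores[best_index]


# Toy stand-ins (assumptions; the notebook uses Contriever, OPT, and MS MARCO here).
def toy_retrieve(question, k):
    corpus = [
        "A ripple counter chains flip-flops together.",
        "The outputs of a Mealy machine depend on state and inputs.",
        "The outputs of a Moore machine depend only on its current state.",
    ]
    return corpus[:k]


def toy_generate(question, passage):
    return passage  # echo the passage as the "answer"


def _tokens(text):
    return {word.strip("?.,").lower() for word in text.split()}


def toy_score(question, answer):
    # Number of distinct words shared by the question and the answer.
    return len(_tokens(question) & _tokens(answer))
```

Calling `best_answer("What is the Moore machine?", toy_retrieve, toy_generate, toy_score)` selects the Moore-machine sentence, since it shares the most words with the question.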

In [None]:
answer = ta.doc_query(user_question=USER_QUESTION, num_answers_generated=NUM_ANSWERS_GENERATED)
print("The user question is: ", USER_QUESTION)
print("DocQuery's concise, direct-quote answer: \n", answer)

#### UX via Gradio

In [4]:
import gradio as gr
import numpy as np

## DOCS: https://gradio.app/getting_started/
## See figma for our ideal design.

def greet(name):
    return "Hello " + name + "!"

demo = gr.Interface(
    fn=greet,
    inputs=gr.Textbox(lines=2, placeholder="Name Here..."),
    outputs="text",
)

def sepia(input_img):
    sepia_filter = np.array([
        [0.393, 0.769, 0.189], 
        [0.349, 0.686, 0.168], 
        [0.272, 0.534, 0.131]
    ])
    sepia_img = input_img.dot(sepia_filter.T)
    sepia_img /= sepia_img.max()
    return sepia_img

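# Note: this reassigns `demo`, so only the sepia interface is launched;
# the greet interface defined above is never shown.
demo = gr.Interface(sepia, gr.Image(shape=(200, 200)), "image")
demo.launch(share=True)

# Alternative without a share link: launch locally and forward the port over SSH.
# demo.launch()
# ssh -L 7860:localhost:7860 kastan@kastan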

Running on local URL:  http://127.0.0.1:7860
Running on public URL: https://782d8e592970c581.gradio.app

This share link expires in 72 hours. For free permanent hosting and GPU upgrades (NEW!), check out Spaces: https://huggingface.co/spaces


(<gradio.routes.App at 0x7f8871b995e0>,
 'http://127.0.0.1:7860/',
 'https://782d8e592970c581.gradio.app')

In [6]:
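import gradio as gr

def chat(message, history):
    # `history` is None on the first turn, then a list of (message, response) pairs
    history = history or []
    message = message.lower()
    response = "hello"  # placeholder reply; the TA pipeline is not wired in yet

    history.append((message, response))
    # Returned twice: once for the chatbot display, once for the state
    return history, history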

chatbot = gr.Chatbot().style(color_map=("green", "pink"))
demo = gr.Interface(
    chat,
    ["text", "state"],
    [chatbot, "state"],
    allow_flagging="never",
)
if __name__ == "__main__":
    demo.launch(share = True)

Running on local URL:  http://127.0.0.1:7865
Running on public URL: https://2cb9f79ca59371ba.gradio.app



In [None]:
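import torch
import main


class ui_ta():
    def __init__(self):
        # Load the TA pipeline once so every chat turn reuses the same models
        self.ta = main.TA_Pipeline(
            device=torch.device("cuda:1"),
            opt_weight_path="../lgm/data/model_weight/opt_finetune_b128_e20_lr5e06.pt",
        )
    # TODO: the remaining methods of this class have not been written yet


def gradio_chat(message, history):
    history = history or []
    response = "hello"  # placeholder reply; not yet connected to ui_ta

    history.append((message, response))
    return history, history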
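
One possible way to connect the placeholder handler to the pipeline is to build the handler around an answer function, so the `"hello"` stub can later be swapped for a real call. This is a sketch under assumptions, not the project's implementation: `make_chat_handler` and `answer_fn` are hypothetical names, and the `(history, history)` return shape simply mirrors the `gradio_chat` function above.

```python
def make_chat_handler(answer_fn):
    """Build a chat handler that delegates to answer_fn(message) -> str."""
    def chat(message, history):
        history = list(history or [])  # copy so the caller's list is not mutated
        response = answer_fn(message)
        history.append((message, response))
        # Returned twice: once for the chatbot display, once for the state.
        return history, history
    return chat


# Stand-in answer function; a real one would call the TA pipeline instead.
echo_chat = make_chat_handler(lambda message: "You asked: " + message)
```

Keeping the answer function as a parameter means the handler can be exercised with a cheap stub like `echo_chat` before any model is loaded.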