# Rafael Espinosa Mena
### Llama 2 Medical Q&A Project UI - Part 1
### joseespi@usc.edu

### Install and Import Required Libraries

In [None]:
!pip install gradientai --upgrade

In [None]:
import os
import queue
import textwrap
import threading
import time
import tkinter as tk
from tkinter.scrolledtext import ScrolledText

from gradientai import Gradient

In [None]:
# Gradient credentials are read from the environment instead of being
# hard-coded: a secret committed to a notebook leaks with every copy of it.
# Set GRADIENT_ACCESS_TOKEN and GRADIENT_WORKSPACE_ID (e.g. in the shell that
# launches Jupyter) before running this cell.
_missing_gradient_vars = [
    name
    for name in ('GRADIENT_ACCESS_TOKEN', 'GRADIENT_WORKSPACE_ID')
    if not os.environ.get(name)
]
if _missing_gradient_vars:
    raise RuntimeError(
        "Missing Gradient credentials; set these environment variables: "
        + ", ".join(_missing_gradient_vars)
    )

# ID of the fine-tuned Llama 2 model adapter hosted on Gradient.
model_id = 'dfd79298-d184-4a95-a62d-ae46ef38701a_model_adapter'

### Load the Fine-Tuned Model Adapter

In [None]:
# Gradient() is constructed without arguments, so the client takes its
# credentials from the GRADIENT_ACCESS_TOKEN / GRADIENT_WORKSPACE_ID
# environment variables. The adapter handle is what the GUI queries.
gradient = Gradient()
new_model_adapter = gradient.get_model_adapter(
    model_adapter_id=model_id,
)

### Launch the Interactive GUI

In [None]:
# Note: please be patient when launching the GUI and when generating a response;
# both are somewhat slow because Llama 2 is a very large model.
class App:
    """Tkinter front end for querying the fine-tuned Llama 2 medical Q&A adapter.

    The window holds a prompt box, a "Generate" button and a read-only output
    box. The model call runs on a background thread so the UI stays
    responsive. Its result is handed back through a queue, and every widget
    update happens on the Tk main thread, because Tkinter widgets must not be
    touched from other threads.
    """

    # Placeholder shown in an empty prompt box; it is never sent to the model.
    PLACEHOLDER = "Write your prompt here"
    # How often (ms) the main thread checks whether the model call finished.
    POLL_INTERVAL_MS = 100
    # Delay (ms) between frames of the "Generating response..." animation.
    ANIMATION_INTERVAL_MS = 600

    def __init__(self, root):
        """Build the widgets inside *root* (the application's ``tk.Tk`` window)."""
        self.root = root
        self.root.title("Finetuned Llama2 LLM for Medical Q&A Part 1")

        # Customize background colors here
        input_bg_color = 'white'  # Background color for the input text area
        output_bg_color = 'white'  # Background color for the output text area
        root_bg_color = 'lightgray'  # Background color for the root window

        self.root.configure(bg=root_bg_color)

        # Prompt area, pre-filled with a placeholder that clears on focus
        self.input_text = ScrolledText(root, width=100, height=10, bg=input_bg_color)
        self.input_text.insert(tk.END, self.PLACEHOLDER)
        self.input_text.bind("<FocusIn>", self.on_focus_in)
        self.input_text.bind("<FocusOut>", self.on_focus_out)
        self.input_text.pack(padx=10, pady=10)

        # Button that starts a model request
        self.generate_button = tk.Button(root, text='Generate', command=self.on_generate_clicked)
        self.generate_button.pack(padx=10, pady=10)

        # Read-only output area; display_output() unlocks it briefly to write
        self.output = ScrolledText(root, width=100, height=10, bg=output_bg_color, state='disabled')
        self.output.pack(padx=10, pady=10)

        self.generating = False  # True while a model request is in flight
        # Worker threads put finished text here; only the main thread reads it.
        self._results = queue.Queue()
        # after() ids of the pending animation / polling callbacks (None when idle).
        self._animation_job = None
        self._poll_job = None

    def on_focus_in(self, event):
        """Clear the placeholder when the prompt box gains focus."""
        if self.input_text.get("1.0", tk.END).strip() == self.PLACEHOLDER:
            self.input_text.delete("1.0", tk.END)

    def on_focus_out(self, event):
        """Restore the placeholder if the prompt box is left empty."""
        if not self.input_text.get("1.0", tk.END).strip():
            self.input_text.insert(tk.END, self.PLACEHOLDER)

    def on_generate_clicked(self):
        """Start one model request in the background and animate until it finishes."""
        if self.generating:
            return  # a request is already in flight; ignore extra clicks

        user_input = self.input_text.get("1.0", tk.END).strip()
        if user_input == self.PLACEHOLDER:
            user_input = ""

        self.generating = True
        self.generate_button.configure(state='disabled')
        self.update_generating_message()

        # Daemon thread: a slow API call must not keep the process alive after
        # the window is closed.
        threading.Thread(target=self.generate_text, args=(user_input,), daemon=True).start()
        self._poll_job = self.root.after(self.POLL_INTERVAL_MS, self._poll_results)

    def update_generating_message(self, dot_count=0):
        """Show one frame of the "Generating response..." animation (main thread).

        Reschedules itself with ``after`` while ``self.generating`` is true, so
        no extra thread is needed and it can never overwrite the final answer.
        """
        if not self.generating:
            self._animation_job = None
            return
        dot_count = (dot_count + 1) % 4
        text = "Generating response" + "." * dot_count + "\n"
        self.display_output(text, clear_previous=True)
        self._animation_job = self.root.after(
            self.ANIMATION_INTERVAL_MS, self.update_generating_message, dot_count
        )

    def generate_text(self, user_input):
        """Query the model for *user_input* and queue the wrapped answer.

        Runs on a worker thread, so it must not touch any Tk widget; the main
        thread picks the result up in _poll_results().
        """
        try:
            completion = new_model_adapter.complete(query=user_input, max_generated_token_count=260).generated_output
            wrapped_completion = self.wrap_text(completion)  # Wrap the text properly
        except Exception as e:
            # Top-level boundary of the worker: report any failure in the UI.
            wrapped_completion = f"An error occurred: {str(e)}"
        self._results.put(wrapped_completion)

    def _poll_results(self):
        """Main-thread check for a finished request; reschedules itself until one arrives."""
        try:
            text = self._results.get_nowait()
        except queue.Empty:
            self._poll_job = self.root.after(self.POLL_INTERVAL_MS, self._poll_results)
            return
        self._poll_job = None
        self._finish_generation(text)

    def _finish_generation(self, text):
        """Stop the animation, show *text*, and re-enable the Generate button."""
        self.generating = False
        if self._animation_job is not None:
            self.root.after_cancel(self._animation_job)
            self._animation_job = None
        self.display_output(text)
        self.generate_button.configure(state='normal')

    def display_output(self, text, clear_previous=True):
        """Write *text* into the read-only output box (call from the main thread only).

        Args:
            text: Text to insert at the end of the output box.
            clear_previous: If true, the box is emptied first.
        """
        if self.output.winfo_exists():  # Check if the output widget exists
            self.output.configure(state='normal')
            if clear_previous:
                self.output.delete("1.0", tk.END)
            self.output.insert(tk.END, text)
            self.output.configure(state='disabled')

    def wrap_text(self, text, width=90):
        """Wrap *text* to at most *width* columns, preserving its existing line breaks.

        Each line is wrapped on its own, so paragraph breaks and list items in
        the model output survive. Words longer than *width* are kept whole
        rather than split.

        Returns:
            The wrapped text joined with newlines ("" for empty input).
        """
        wrapped_lines = []
        for line in text.splitlines():
            # textwrap.wrap("") returns [], so keep blank lines explicitly.
            wrapped_lines.extend(textwrap.wrap(line, width=width, break_long_words=False) or [''])
        return '\n'.join(wrapped_lines)

    def on_close(self):
        """Cancel pending UI callbacks and close the window.

        A request still in flight runs on a daemon thread; its result is
        simply never displayed.
        """
        self.generating = False
        for job in (self._animation_job, self._poll_job):
            if job is not None:
                self.root.after_cancel(job)
        self._animation_job = None
        self._poll_job = None
        self.root.destroy()

if __name__ == "__main__":
    # Create the top-level window and attach the application to it.
    root = tk.Tk()
    app = App(root)
    # Route the title-bar close button through App.on_close so the app can
    # stop its background work before the window is destroyed.
    root.wm_protocol("WM_DELETE_WINDOW", app.on_close)
    # Process Tk events until the window is closed.
    root.mainloop()