In [6]:
import gc
import json
import os
from datetime import datetime

import gradio as gr
import torch
from transformers import pipeline, Pipeline

'''
Model picker
    - Hugging Face token prompt
Checkpoint picker
Chat stream
"Export to story" button

- Implement saving checkpoints to JSON
- Implement loading the full checkpoint state
- Implement saving previous
- Implement story export
'''
# Module-level state shared by the helpers and Gradio callbacks below.
# `pipe` holds the active transformers text-generation pipeline (None until a
# model has been loaded); `token` holds the Hugging Face token typed into the
# UI, where "" means no token is passed when loading a model.
pipe = None
token = ""

def unload_model():
    """Release the currently loaded pipeline and free cached GPU memory.

    The global ``pipe`` is reset to ``None`` rather than deleted, so code that
    reads it afterwards sees "no model loaded" instead of raising NameError,
    and calling this function repeatedly is safe.
    """
    global pipe
    pipe = None
    # Collect the now-unreferenced model objects first: empty_cache() can only
    # hand back cached blocks that no live tensor is still using.
    gc.collect()
    torch.cuda.empty_cache()

def load_model(model_name, device='cuda'):
    """Make ``model_name`` the active text-generation pipeline.

    Any previously loaded pipeline is released first via unload_model().

    Args:
        model_name: Model id or path passed to transformers.pipeline().
        device: Device the pipeline runs on; defaults to 'cuda' as before,
            but can be overridden (e.g. 'cpu') on machines without a GPU.
    """
    global pipe
    unload_model()
    # The global token is "" when none was entered; None is pipeline()'s own
    # default for `token`, so passing it is the same as omitting the argument.
    pipe = pipeline(
        "text-generation",
        model=model_name,
        device=device,
        token=token or None,
    )

def changeToken(newToken):
    """Store the Hugging Face token typed into the UI.

    The stored value is used the next time a model is loaded. Surrounding
    whitespace (such as a trailing newline from copy-paste, which would make
    authentication fail) is stripped, and None is treated as "no token".
    """
    global token
    token = (newToken or "").strip()

# Close the Gradio interface left over from a previous run of this cell, if any,
# before swapping the model, so the old server stops serving requests while
# its model is being unloaded.
if 'iface' in globals() and iface is not None:
    iface.close()
# load_model() unloads any previously loaded pipeline before loading this one.
load_model("Qwen/Qwen2-0.5B-Instruct")

def llama_inference(history, new_prompt):
    """Generate the assistant's reply to ``new_prompt`` given the chat so far.

    Args:
        history: Sequence of ``(user_prompt, assistant_response)`` pairs
            (tuples, or 2-item lists as loaded from a JSON checkpoint).
        new_prompt: The user's new message.

    Returns:
        ``(content, debug_info)``: the generated reply text, and a string with
        the messages sent plus the raw pipeline output.

    Raises:
        RuntimeError: if no model is currently loaded.
    """
    if pipe is None:
        raise RuntimeError("No model is loaded; select a model first.")

    # Send the conversation as alternating user/assistant turns so the chat
    # template sees the real turn structure, rather than flattening the whole
    # history into a single "User: ... Assistant: ..." user message.
    messages = []
    for prev_prompt, prev_response in history:
        messages.append({"role": "user", "content": prev_prompt})
        messages.append({"role": "assistant", "content": prev_response})
    messages.append({"role": "user", "content": new_prompt})

    # Match on the model class name (e.g. Qwen2ForCausalLM) instead of
    # str(pipe.model), which renders the entire module tree on every call.
    if "Qwen" in type(pipe.model).__name__:
        # max_length counts the prompt as well, so a long chat history would
        # eventually leave no room for a reply; cap only the new tokens.
        response = pipe(messages, max_new_tokens=1024)
    else:
        response = pipe(messages)

    # Log response and messages for debugging
    debug_info = f"{messages} response: {response}"

    generated_text = response[0]['generated_text']
    if isinstance(generated_text, list):
        # Chat-style output: the input messages followed by the new assistant
        # message, so the reply is the last entry regardless of history length.
        content = generated_text[-1]['content']
    else:
        content = generated_text  # Plain-text output
    return content, debug_info

# Custom CSS passed to gr.Blocks below:
#  - .white-background: forces black-on-white text in inputs/textareas of the
#    components tagged with elem_classes=["white-background"].
#  - .file-preview .empty / .full: hides elements with class "empty" and shows
#    elements with class "full" inside components tagged "file-preview".
# NOTE(review): the inline CSS comments mention "height", but these rules only
# toggle `display`; those comments are stale.
css = """
    .white-background textarea, .white-background input {
        background-color: white !important;
        color: black !important;
        -webkit-text-fill-color: black !important;
    }
    .file-preview .empty {
        display: none;  /* Adjust the height to your preference */
    }
    .file-preview .full {
        display: block;  /* Adjust the height to your preference */
    }
"""

# Chat history shared by the callbacks below: a list of (prompt, response) pairs.
history = []

# Create the checkpoint directory up front so listing it cannot fail on a
# fresh setup.
os.makedirs("checkpoints", exist_ok=True)

# Create the Gradio interface
with gr.Blocks(css=css) as iface:
    chatbot = gr.Chatbot(elem_id="chatbot")
    # Checkpoints are "%m-%d-%H-%M-%S.json" files, so a reverse lexical sort
    # lists them newest first (within a year). os.listdir() order is arbitrary,
    # so slicing its result would hide a random entry; filter by suffix instead.
    checkpoint_dropdown = gr.Dropdown(
        choices=sorted(
            (f for f in os.listdir("checkpoints") if f.endswith(".json")),
            reverse=True,
        ),
        label="CheckPoints",
        interactive=True,
    )

    with gr.Row():
        with gr.Column(scale=3):
            model_dropdown = gr.Dropdown(
                choices=["Qwen/Qwen2-0.5B-Instruct", "meta-llama/Meta-Llama-3.1-8B"],
                label="Select Model",
                value="Qwen/Qwen2-0.5B-Instruct",
                interactive=True,
            )
            hf_token_box = gr.Textbox(placeholder="Enter your Hugging Face token...", label="Hugging Face Token", type="password", elem_classes=["white-background"])
            clear_button = gr.Button("Clear")
            debug_output = gr.Textbox(lines=10, placeholder="Debug information will appear here...", label="Debug Output")

        with gr.Column(scale=7):
            export_file = gr.File(label="Export Storyline", elem_classes=["file-preview"])
            prompt_input = gr.Textbox(lines=2, placeholder="Enter your prompt here...", label="Prompt", elem_classes=["white-background"])
            submit_button = gr.Button("Submit")
            save_chkpt_button = gr.Button("Save Checkpoint")
            download_button = gr.Button("Export Storyline")

    # Reload the pipeline whenever another model is picked. Components accept
    # no `change=` constructor argument (Gradio ignores it with a warning), so
    # the listener has to be registered explicitly.
    model_dropdown.change(load_model, [model_dropdown], None)

    def update_chatbot(prompt):
        """Run one chat turn and record it in the shared history.

        Returns the history list (a list of (prompt, reply) tuples rendered by
        the Chatbot) together with the debug text from llama_inference().
        """
        global history
        reply, debug_text = llama_inference(history, prompt)
        history.append((prompt, reply))
        return history, debug_text
    
    def clear_history():
        """Empty the shared history in place; return it plus a status message."""
        global history
        del history[:]
        return history, "Cleared chatbot"
    
    def export_story():
        """Write the assistant's responses to a timestamped story file.

        Each response is preceded by a "---" separator line; user prompts are
        omitted. The file goes into ``story-lines/`` (created if missing), and
        the full path is returned. Returning only the bare file name would hand
        the File component a path that does not exist.
        """
        os.makedirs("story-lines", exist_ok=True)
        file_name = datetime.now().strftime("%m-%d-%H-%M-%S")
        file_path = f"story-lines/{file_name}.txt"
        output = "".join(f"---\n{response}\n" for _prompt, response in history)
        with open(file_path, "w", encoding="utf-8") as file:
            file.write(output)
        return file_path
    
    def load_chkpt(name: str):
        """Replace the shared history with the checkpoint file ``name``.

        Args:
            name: File name inside ``checkpoints/`` selected in the dropdown.

        Returns:
            (chatbot history, status message, dropdown update carrying a
            refreshed newest-first list of checkpoint files).
        """
        global history
        # Re-list the directory so checkpoints saved since startup appear.
        # Filenames are timestamps, so a reverse sort puts the newest first.
        choices = sorted(
            (f for f in os.listdir("checkpoints") if f.endswith(".json")),
            reverse=True,
        )
        if not name:
            # Selection was cleared; keep the current chat unchanged.
            return list(history), "No checkpoint selected", gr.update(choices=choices)
        # basename() keeps the lookup inside checkpoints/ even if the value
        # contains directory components.
        file_path = os.path.join("checkpoints", os.path.basename(name))
        with open(file_path, 'r', encoding="utf-8") as json_file:
            # JSON stores each (prompt, response) pair as a 2-item list; convert
            # back to tuples to match the pairs update_chatbot() appends.
            history = [(item[0], item[1]) for item in json.load(json_file)]
        # gr.update() works across Gradio versions; gr.Dropdown.update() and
        # component.update() no longer exist in Gradio 4.
        return list(history), f"Loaded {name}", gr.update(choices=choices)
    
    def save_chkpt():
        """Dump the current chat history to a timestamped JSON checkpoint.

        The file is written to ``checkpoints/`` and named after the current
        local time (month-day-hour-minute-second). The returned status string
        is shown in the debug box.
        """
        stamp = datetime.now().strftime("%m-%d-%H-%M-%S")
        with open(f"checkpoints/{stamp}.json", 'w') as handle:
            json.dump(history, handle)
        return f"Checkpoint {stamp} saved."
        
    submit_button.click(update_chatbot, [prompt_input], [chatbot, debug_output])
    clear_button.click(clear_history, [], [chatbot, debug_output])
    download_button.click(export_story, [], export_file)
    save_chkpt_button.click(save_chkpt, [], [debug_output])

    hf_token_box.change(changeToken, [hf_token_box], None)
    checkpoint_dropdown.change(load_chkpt, [checkpoint_dropdown], [chatbot, debug_output, checkpoint_dropdown])

# Launch the Gradio app
iface.launch()


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
  model_dropdown = gr.Dropdown(


Closing server running on port: 7860
Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.


