In [4]:
#!/usr/bin/env python
# coding: utf-8

# In[1]:
import gradio as gr
import requests
import threading
import time
from transformers import BartForConditionalGeneration, BartTokenizer

# Base URL of the Ollama HTTP API; every request in this notebook is built
# from it (`/api/tags`, `/api/pull`, `/api/chat`).
# NOTE(review): presumably the default address of a locally running Ollama
# daemon — confirm for your environment.
OLLAMA_API_BASE = "http://localhost:11434"

# Load summarization model
# The tokenizer and model are created once when this cell runs and are then
# used as module-level globals by `summarize_context`.
model_name = "facebook/bart-large-cnn"
tokenizer = BartTokenizer.from_pretrained(model_name)
model = BartForConditionalGeneration.from_pretrained(model_name)

# In[2]:
def list_models():
    """Return the names of the models installed on the Ollama server.

    Sends ``GET {OLLAMA_API_BASE}/api/tags`` and collects the ``name`` field
    of every entry under the ``models`` key of the JSON reply.

    Returns:
        list[str]: Model names exactly as reported by the server, or an empty
        list when the server cannot be reached, answers with an HTTP error,
        or sends a payload that does not have the expected shape.
    """
    try:
        # Bounded timeout: without it an unresponsive daemon would block the
        # UI (and notebook start-up, which calls this) indefinitely.
        response = requests.get(f"{OLLAMA_API_BASE}/api/tags", timeout=10)
        response.raise_for_status()
        # `or []` also covers an explicit `"models": null` in the payload.
        entries = response.json().get("models") or []
        return [entry["name"] for entry in entries]
    except (requests.RequestException, ValueError, KeyError, TypeError, AttributeError):
        # Deliberate best-effort: callers treat "no answer" the same as
        # "nothing installed", so network/HTTP/parsing problems yield [].
        return []

def list_available_models():
    """Return the fixed catalogue of model names offered for download.

    Returns:
        list[str]: A new list on every call, in catalogue order.
    """
    catalogue = (
        "llama2",
        "mistral",
        "gemma",
        "codellama",
        "orca-mini",
        "phi",
        "dolphin-mixtral",
        "llava",
        "qwen",
        "tinyllama",
    )
    return list(catalogue)

def is_model_installed(model_name, installed=None):
    """Tell whether ``model_name`` is among the installed Ollama models.

    Ollama identifies models as ``name:tag`` and treats a missing tag as
    ``latest``. The server lists installed models with their tag (for example
    ``llama2:latest``), so a bare catalogue name such as ``llama2`` is
    compared after giving both sides an explicit tag.

    Args:
        model_name: Name to look for, with or without a tag. ``None`` or an
            empty string is never considered installed.
        installed: Optional iterable of installed model names. When omitted,
            the list is fetched with ``list_models()``.

    Returns:
        bool: True when the (tag-normalised) name is installed.
    """
    if not model_name:
        return False

    def _canonical(name):
        # Only a colon in the last path segment is a tag separator; a colon
        # earlier in the name may belong to a registry "host:port" prefix.
        leaf = name.rsplit("/", 1)[-1]
        return name if ":" in leaf else f"{name}:latest"

    if installed is None:
        installed = list_models()

    wanted = _canonical(model_name)
    return any(_canonical(name) == wanted for name in installed)

def pull_model(model_name, progress_callback):
    """Ask the Ollama server to download a model and relay its progress.

    Posts to ``{OLLAMA_API_BASE}/api/pull`` with a streamed response and
    forwards every non-empty line of the reply to ``progress_callback``.

    Args:
        model_name: Name of the model to pull.
        progress_callback: Callable taking one ``str``; called once per
            progress line, and once with an ``"Error pulling model: ..."``
            message if anything goes wrong.

    Returns:
        None. Failures are reported through ``progress_callback`` rather than
        raised, because this function is run on a worker thread where an
        uncaught exception would never reach the UI.
    """
    try:
        # The `with` block closes the streamed response even when iteration
        # stops early (e.g. the callback raises), so the connection is not
        # left checked out of the pool.
        with requests.post(
            f"{OLLAMA_API_BASE}/api/pull",
            json={"name": model_name},
            stream=True,
            # Bound only the connect phase: a pull may legitimately stay
            # quiet for a long time between progress lines.
            timeout=(10, None),
        ) as response:
            for line in response.iter_lines():
                if line:
                    # errors="replace" keeps a malformed byte from aborting
                    # the whole progress stream.
                    progress_callback(line.decode("utf-8", errors="replace"))
    except Exception as e:
        progress_callback(f"Error pulling model: {e}")

def install_model_ui(model_name):
    """Install a model in the background and stream its progress log.

    Runs ``pull_model`` on a worker thread and, roughly once per second while
    it is running, yields everything it has reported so far joined with
    newlines. A final snapshot is yielded after the worker finishes.

    Args:
        model_name: Name of the model to install. When empty or ``None``
            (nothing selected in the dropdown) no request is made.

    Yields:
        str: The accumulated progress text, or a warning when no model was
        selected.
    """
    if not model_name:
        yield "⚠️ No model selected."
        return

    outputs = []
    # daemon=True so a stalled download cannot keep the interpreter alive
    # after the notebook/app is shut down.
    thread = threading.Thread(
        target=pull_model,
        args=(model_name, outputs.append),
        daemon=True,
    )
    thread.start()
    while thread.is_alive():
        # join(timeout) paces the updates like a sleep would, but returns as
        # soon as the worker finishes instead of waiting out the full second.
        thread.join(timeout=1)
        yield "\n".join(outputs)
    yield "\n".join(outputs)

def summarize_context(history):
    """Summarise a chat history with the module-level BART model.

    Each message is rendered as ``"User: ..."`` or ``"Assistant: ..."`` on its
    own line (any role other than ``"user"`` is labelled Assistant), the
    transcript is tokenised with truncation at 1024 tokens, and a summary is
    generated with beam search.

    Args:
        history: Sequence of dicts with ``"role"`` and ``"content"`` keys.

    Returns:
        str: The decoded summary, or ``""`` when ``history`` is empty or
        ``None`` (there is nothing to summarise, so the model is not run).
    """
    if not history:
        return ""

    lines = []
    for msg in history:
        speaker = "User" if msg["role"] == "user" else "Assistant"
        lines.append(f"{speaker}: {msg['content']}\n")
    transcript = "".join(lines)

    inputs = tokenizer(transcript, return_tensors="pt", max_length=1024, truncation=True)
    summary_ids = model.generate(
        inputs["input_ids"],
        # Pass the mask explicitly rather than letting generate() infer it.
        attention_mask=inputs["attention_mask"],
        max_length=150,
        min_length=40,
        length_penalty=2.0,
        num_beams=4,
        early_stopping=True
    )
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)

# In[3]:
def _render_raw_history(chat_state):
    """Render the history as the plain-text transcript for the Raw History box.

    Messages are consumed in (user, assistant) pairs; a trailing user message
    without a reply is shown with an empty bot line.
    """
    raw = ""
    for i in range(0, len(chat_state), 2):
        user = chat_state[i]["content"]
        bot = chat_state[i + 1]["content"] if i + 1 < len(chat_state) else ""
        raw += f"👤: {user}\n🤖: {bot}\n\n"
    return raw


def handle_user_input(user_message, selected_model, chat_state, context_summary):
    """Send the user's message to the selected Ollama model and update the UI.

    Args:
        user_message: Text typed by the user.
        selected_model: Model name chosen in the "Installed Models" dropdown.
        chat_state: Conversation so far, as a list of
            ``{"role": ..., "content": ...}`` dicts. It is not modified; an
            extended copy is returned instead.
        context_summary: Current context summary, passed through unchanged.

    Returns:
        tuple: ``(chatbot_messages, new_chat_state, install_button_update,
        raw_history_text, context_summary)`` in the order expected by the
        ``msg.submit`` binding.
    """
    history = list(chat_state or [])

    # Nothing to send: leave every component as it is.
    if not user_message or not user_message.strip():
        display = [{"role": m["role"], "content": m["content"]} for m in history]
        return display, history, gr.update(), _render_raw_history(history), context_summary

    if not is_model_installed(selected_model):
        # Reveal the install button but keep the existing conversation on
        # screen instead of blanking the chat and the raw-history box.
        display = [{"role": m["role"], "content": m["content"]} for m in history]
        return display, history, gr.update(visible=True), _render_raw_history(history), context_summary

    # Append user input and get response
    history.append({"role": "user", "content": user_message})
    try:
        response = requests.post(
            f"{OLLAMA_API_BASE}/api/chat",
            json={
                "model": selected_model,
                "messages": history,
                "stream": False
            },
            # Bound only the connect phase; generating a full, non-streamed
            # reply can legitimately take a long time.
            timeout=(10, None),
        )
        if not response.ok:
            # Surface the server's own explanation (Ollama replies with
            # {"error": "..."}) rather than a KeyError on "message".
            try:
                detail = response.json().get("error", response.text)
            except ValueError:
                detail = response.text
            raise RuntimeError(f"HTTP {response.status_code}: {detail}")
        reply = response.json()["message"]["content"]
    except Exception as e:
        reply = f"⚠️ Error: {str(e)}"

    # NOTE(review): error replies are stored as assistant turns (as before),
    # so they are also sent back to the model on the next request.
    history.append({"role": "assistant", "content": reply})

    formatted_chat = [{"role": m["role"], "content": m["content"]} for m in history]

    # Update raw history
    raw = _render_raw_history(history)

    return formatted_chat, history, gr.update(visible=False), raw, context_summary

def save_to_context(chat_state):
    """Summarise the chat history for display in the Context Bucket.

    Args:
        chat_state: Conversation history handed to ``summarize_context``.

    Returns:
        str: The summary, or a warning message containing the exception text
        if summarisation fails for any reason.
    """
    try:
        return summarize_context(chat_state)
    except Exception as err:
        return f"⚠️ Error summarizing context: {err}"

def on_model_select(model_name):
    """Show the install button only when the chosen model is not installed.

    Args:
        model_name: Model name selected in the download dropdown.

    Returns:
        A Gradio update that sets the target component's ``visible`` flag.
    """
    needs_install = not is_model_installed(model_name)
    return gr.update(visible=needs_install)

def refresh_model_lists():
    """Recompute the choices for the installed and downloadable dropdowns.

    The server lists installed models with an explicit tag (for example
    ``llama2:latest``) while the catalogue uses bare names, and Ollama treats
    a missing tag as ``latest``. Names are therefore compared after giving
    both sides an explicit tag; a plain set difference would never remove
    anything from the downloadable list.

    Returns:
        tuple: Two Gradio updates — the installed model names as reported by
        the server, and the sorted catalogue names that are not installed.
    """
    def _canonical(name):
        # Only a colon in the last path segment is a tag separator.
        leaf = name.rsplit("/", 1)[-1]
        return name if ":" in leaf else f"{name}:latest"

    installed = list_models()
    installed_canonical = {_canonical(name) for name in installed}
    not_installed = sorted(
        name
        for name in set(list_available_models())
        if _canonical(name) not in installed_canonical
    )
    return (
        gr.update(choices=installed),
        gr.update(choices=not_installed)
    )

# In[4]:
# Build the chat UI: model pickers, conversation and install controls on the
# left; raw history and the summarised "context bucket" on the right.
with gr.Blocks() as demo:
    gr.Markdown("## 💬 Chat with Ollama LLMs + Context Bucket")

    # Ask the Ollama server once and reuse the answer for both dropdowns.
    installed_at_startup = list_models()
    downloadable_at_startup = sorted(set(list_available_models()) - set(installed_at_startup))

    with gr.Row():
        with gr.Column(scale=3):
            installed_dropdown = gr.Dropdown(label="Installed Models", choices=installed_at_startup, interactive=True)
            download_dropdown = gr.Dropdown(label="Available to Download", choices=downloadable_at_startup, interactive=True)

            chatbot = gr.Chatbot(label="Conversation", type="messages")
            msg = gr.Textbox(label="Your Message", placeholder="Type a message and press Enter")
            state = gr.State([])              # Chat history
            context_state = gr.State("")      # Summarized context

            save_button = gr.Button("📥 Save to Context Bucket")
            install_button = gr.Button("Install Selected Model", visible=False)
            install_output = gr.Textbox(label="Installation Progress", lines=6, visible=True)

        with gr.Column(scale=1):
            gr.Markdown("### 📜 Raw Chat History")
            chat_history_box = gr.Textbox(label="Raw History", lines=15, interactive=False, show_copy_button=True)

            gr.Markdown("### 🪣 Context Bucket (Summarized)")
            context_bucket_box = gr.Textbox(label="", lines=10, interactive=False, show_copy_button=True)

    # Event bindings
    msg.submit(
        handle_user_input,
        inputs=[msg, installed_dropdown, state, context_state],
        outputs=[chatbot, state, install_button, chat_history_box, context_bucket_box]
    ).then(
        # Clear the input box once the message has been handled.
        lambda: "",
        outputs=msg
    )

    download_dropdown.change(on_model_select, inputs=download_dropdown, outputs=install_button)

    # After an install attempt, re-query the server so the new model shows up
    # under "Installed Models" and leaves "Available to Download".
    install_button.click(
        install_model_ui, inputs=download_dropdown, outputs=install_output
    ).then(
        refresh_model_lists,
        outputs=[installed_dropdown, download_dropdown]
    )

    # Keep `context_state` in sync with the saved summary. `msg.submit` writes
    # `context_state` back into the Context Bucket box, so without this step
    # the next chat message would replace the summary with the initial "".
    save_button.click(
        save_to_context, inputs=state, outputs=context_bucket_box
    ).then(
        lambda summary: summary,
        inputs=context_bucket_box,
        outputs=context_state
    )

demo.launch()


* Running on local URL:  http://127.0.0.1:7863
* To create a public link, set `share=True` in `launch()`.


