# 🤖 Gradio ChatUI for Google Colab

<a href="https://colab.research.google.com/github/LaansDole/colab-llm/blob/main/colab_ui_enhanced.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook runs an LLM with Ollama on a Google Colab GPU, exposes the Ollama API through a Cloudflare tunnel, and provides an enhanced Gradio chat interface for talking to the model.

## Step 1: Install and Set Up Requirements

First, we'll install all necessary packages and set up the environment.

In [None]:
#@title Keep this tab alive to prevent Colab from disconnecting you { display-mode: "form" }

#@markdown Press play on the music player that will appear below:
%%html
<audio src="https://oobabooga.github.io/silence.m4a" controls>

In [None]:
# Model selection - Change this to your preferred model
MODEL_NAME = "maryasov/qwen2.5-coder-cline:7b-instruct-q8_0"

# Set environment variables
%env OLLAMA_CONTEXT_LENGTH=16384
%env OLLAMA_HOST=0.0.0.0
%env OLLAMA_KEEP_ALIVE=-1

# Install lshw and pciutils so the Ollama installer can detect the GPU
!apt-get install -y lshw pciutils

# Check CUDA and GPU
!nvcc --version
!nvidia-smi
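Before pulling a model, it helps to check that its weights fit in the GPU memory that `nvidia-smi` reports. A rough rule of thumb is: weight size ≈ parameter count × bits per weight ÷ 8. The sketch below assumes that `q8_0` stores about 8.5 bits per weight (8-bit values plus a per-block scale) and that Qwen2.5-Coder 7B has roughly 7.6 billion parameters; it ignores the KV cache and runtime overhead, so treat the result as a lower bound. `estimate_weight_gb` is a hypothetical helper, not part of Ollama.

```python
def estimate_weight_gb(params_billions: float, bits_per_weight: float) -> float:
    """Rough size of the model weights in GB (1 GB = 1e9 bytes).

    Ignores the KV cache and runtime overhead, so real VRAM use is higher.
    """
    total_bits = params_billions * 1e9 * bits_per_weight
    return total_bits / 8 / 1e9

# Assumptions: ~7.6B parameters, q8_0 at ~8.5 bits per weight
size_gb = estimate_weight_gb(7.6, 8.5)
print(f"~{size_gb:.1f} GB of weights")
```

The free-tier Colab GPU is typically a T4 with 16 GB (a little over 15 GB usable), so a 7B `q8_0` model should fit with some headroom for the context cache.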

In [None]:
# Check available RAM
from psutil import virtual_memory
ram_gb = virtual_memory().total / 1e9
print(f"\n🧠 Available RAM: {ram_gb:.1f} GB")
print("✅ High-RAM runtime!" if ram_gb >= 20 else "❌ Not a high-RAM runtime.")

# Install Ollama
!curl -fsSL https://ollama.com/install.sh | sh

# Install Gradio and other dependencies
!pip install -q gradio==4.14.0 markdown rich
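Colab runtimes come with many packages preinstalled, so it can be worth confirming that the pinned Gradio version is the one now on disk. A small sketch using the standard library's `importlib.metadata`; `installed_version` is a hypothetical helper. If an older `gradio` was already imported in this session, the runtime must be restarted before the newly installed version is used.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional

def installed_version(package: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is not installed."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None

EXPECTED_GRADIO = "4.14.0"  # keep in sync with the pip pin above
found = installed_version("gradio")
if found == EXPECTED_GRADIO:
    print(f"✅ gradio {found} installed")
else:
    print(f"⚠️ Expected gradio {EXPECTED_GRADIO}, found {found}; check the pip output above")
```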

## 🛠️ Step 2: Start Ollama and Set Up Cloudflare Tunnel

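This cell starts the Ollama server, pulls the selected model, opens a Cloudflare quick tunnel that exposes the Ollama API at a public `trycloudflare.com` URL, and sends a test prompt through it.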

In [None]:
import subprocess
import time
import requests
import threading
import re

# Function to start ollama in a background thread
def start_ollama():
    subprocess.call(['ollama', 'serve'])

# Start Ollama server
print("📡 Starting Ollama server...")
ollama_thread = threading.Thread(target=start_ollama)
ollama_thread.daemon = True
ollama_thread.start()

# Wait for Ollama HTTP API to be ready
def wait_for_ollama(timeout=60):
    for i in range(timeout):
        try:
            r = requests.get("http://localhost:11434", timeout=2)
            if r.status_code in [200, 404]:
                print(f"✅ Ollama is up (after {i+1}s).")
                return True
        except requests.exceptions.RequestException:
            pass
        print(f"⏳ Waiting for Ollama to start... {i+1}s")
        time.sleep(1)
    print("❌ Ollama did not start in time.")
    return False

# Wait for Ollama to start
if not wait_for_ollama():
    raise RuntimeError("Failed to start Ollama server")

# Pull the model
print(f"📥 Downloading model {MODEL_NAME}...")
!ollama pull {MODEL_NAME}

# Setup Cloudflare tunnel
print("🌐 Setting up Cloudflare tunnel...")
!wget -q https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-linux-amd64 -O cloudflared
!chmod +x cloudflared

# Run cloudflared in the background. Its output goes to a log file rather than a pipe:
# an unread pipe eventually fills up and blocks cloudflared, stalling the tunnel.
cloudflared_log = open("cloudflared.log", "w")
cloudflared_proc = subprocess.Popen(
    ['./cloudflared', 'tunnel', '--url', 'http://localhost:11434', '--no-autoupdate'],
    stdout=cloudflared_log,
    stderr=subprocess.STDOUT,
)

# Poll the log for the public URL, waiting up to 30 seconds
public_url = None
for _ in range(30):
    time.sleep(1)
    with open("cloudflared.log") as f:
        match = re.search(r'https://[a-z0-9-]+\.trycloudflare\.com', f.read())
    if match:
        public_url = match.group(0)
        break

if public_url:
    print(f"\n✅ Public URL for Ollama API:\n{public_url}")
else:
    raise RuntimeError("❌ Could not find public Cloudflare URL. See cloudflared.log for details.")

# Test the connection with a simple query
print("🧪 Testing connection with a quick query...")
data = {
    "model": MODEL_NAME,
    "prompt": "Say hello in one sentence:",
    "stream": False
}

try:
    # The first request also loads the model into GPU memory, so allow extra time
    response = requests.post(f"{public_url}/api/generate", json=data, timeout=90)
    if response.status_code == 200:
        print(f"✅ Connection test successful! Response: {response.json()['response']}")
    else:
        print(f"⚠️ API returned status code {response.status_code}")
except Exception as e:
    print(f"❌ Connection test failed: {str(e)}")
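The tunnel URL is scraped from cloudflared's log output. Isolating that step in a pure function makes the regular expression easy to test without starting a tunnel. The sample log lines below are illustrative, not captured cloudflared output, and `extract_tunnel_url` is a hypothetical helper; the pattern assumes quick-tunnel hostnames are lowercase words joined by hyphens, as they typically are.

```python
import re
from typing import Iterable, Optional

# Assumed shape of a quick-tunnel URL: hyphenated lowercase words under trycloudflare.com
TUNNEL_URL_RE = re.compile(r"https://[a-z0-9-]+\.trycloudflare\.com")

def extract_tunnel_url(lines: Iterable[str]) -> Optional[str]:
    """Return the first trycloudflare.com URL found in the given log lines."""
    for line in lines:
        match = TUNNEL_URL_RE.search(line)
        if match:
            return match.group(0)
    return None

# Illustrative log lines; the first mentions the domain but has no https:// URL
sample_log = [
    "INF Requesting new quick Tunnel on trycloudflare.com...",
    "INF |  https://example-words-here.trycloudflare.com  |",
]
print(extract_tunnel_url(sample_log))
```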


## Step 3: Create Gradio Chat Interface

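Now we'll create a chat interface on top of the Ollama API, with a configurable system prompt, adjustable generation parameters, and buttons to clear, retry, and save the conversation.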
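The cell below flattens the whole conversation into one `User:`/`Assistant:` string for Ollama's `/api/generate`, so the model sees the transcript as a single prompt. Ollama also exposes `/api/chat`, which accepts a list of role-tagged messages and applies the model's own chat template. The sketch below shows one way to build that request body from Gradio's `[user, assistant]` pair history; `build_chat_payload` is a hypothetical helper, and it assumes the metrics footer format that the UI appends to each reply.

```python
from typing import Any, Dict, List, Optional

METRICS_SEPARATOR = "\n\n---\n*Generated"  # start of the footer the UI appends to replies

def build_chat_payload(message: str,
                       history: List[List[str]],
                       model: str,
                       system_prompt: Optional[str] = None,
                       options: Optional[Dict[str, Any]] = None) -> dict:
    """Build a request body for Ollama's /api/chat endpoint."""
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": user_msg})
        # Drop the token-metrics footer so it is not fed back to the model
        messages.append({"role": "assistant",
                         "content": assistant_msg.split(METRICS_SEPARATOR)[0]})
    messages.append({"role": "user", "content": message})
    return {"model": model, "messages": messages, "stream": False,
            "options": options or {}}

payload = build_chat_payload(
    "And in French?",
    [["Say hi", "Hi!\n\n---\n*Generated 2 tokens in 0.10s (20.0 tokens/s)*"]],
    model="maryasov/qwen2.5-coder-cline:7b-instruct-q8_0",
    system_prompt="You are a helpful AI assistant.",
)
print([m["role"] for m in payload["messages"]])
```

With `/api/chat`, the reply text is read from `response.json()["message"]["content"]` instead of `response.json()["response"]`.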

In [None]:
import gradio as gr
import json
import time
from typing import List
from datetime import datetime

# Keep track of conversation history
conversation_history = []

# Function to format messages for the API
def format_prompt(message: str, history: List[List[str]], system_prompt: str = None) -> str:
    """Format the message and history into a single prompt string for /api/generate."""
    formatted_prompt = ""
    
    # Add system prompt if provided
    if system_prompt:
        formatted_prompt = f"{system_prompt}\n\n"
    
    # Add conversation history, dropping the token-metrics footer that
    # chat_with_settings appends to each reply so the model never sees it
    for user_msg, assistant_msg in history:
        assistant_msg = assistant_msg.split("\n\n---\n*Generated")[0]
        formatted_prompt += f"User: {user_msg}\n\nAssistant: {assistant_msg}\n\n"
    
    # Add the current message
    formatted_prompt += f"User: {message}\n\nAssistant: "
    
    return formatted_prompt

# Function to chat with the LLM
def chat_with_settings(message: str, 
                      history: List[List[str]], 
                      system_prompt: str,
                      temp: float, 
                      top_p_val: float, 
                      max_tok: int,
                      context_window: int):
    """Chat with the LLM using custom settings."""
    
    if not message.strip():
        return "", history
    
    try:
        # Format the prompt with conversation history
        formatted_prompt = format_prompt(message, history, system_prompt)
        
        # Prepare data for API request
        data = {
            "model": MODEL_NAME,
            "prompt": formatted_prompt,
            "stream": False,
            "options": {
                "temperature": temp,
                "top_p": top_p_val,
                "num_predict": max_tok,
                "num_ctx": context_window
            }
        }
        
        # Check if cloudflared process is still running
        if cloudflared_proc.poll() is None:
            # Start timer to measure response time
            start_time = time.time()
            
            # Send request to Ollama API
            response = requests.post(f"{public_url}/api/generate", json=data, timeout=120)
            
            # Calculate response time
            response_time = time.time() - start_time
            
            if response.status_code == 200:
                result = response.json()
                bot_response = result.get('response', 'No response generated.')
                
                # Get token metrics if available
                eval_count = result.get('eval_count', 0)
                prompt_eval_count = result.get('prompt_eval_count', 0)
                
                # Format metrics for display
                metrics = f"\n\n---\n*Generated {eval_count} tokens in {response_time:.2f}s ({eval_count/response_time:.1f} tokens/s)*"
                
                # Update history with the exchange
                history.append([message, bot_response + metrics])
                conversation_history.append({"role": "user", "content": message})
                conversation_history.append({"role": "assistant", "content": bot_response})
                
                return "", history
            else:
                error_msg = f"❌ API Error: {response.status_code} - {response.text}"
                history.append([message, error_msg])
                return "", history
        else:
            error_msg = "❌ Cloudflared tunnel is not running. Please restart the tunnel."
            history.append([message, error_msg])
            return "", history
            
    except requests.exceptions.Timeout:
        error_msg = "⏱️ Request timed out. The model might be taking too long to respond."
        history.append([message, error_msg])
        return "", history
    except requests.exceptions.ConnectionError:
        error_msg = "🔌 Connection error. Please check if the tunnel is still active."
        history.append([message, error_msg])
        return "", history
    except Exception as e:
        error_msg = f"❌ Error: {str(e)}"
        history.append([message, error_msg])
        return "", history

# Function to clear chat history
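def clear_chat():
    """Clear the chat history and the input box."""
    global conversation_history
    conversation_history = []
    # Outputs are [chatbot, msg]: an empty history and an empty string for the textbox
    return [], ""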

# Function to retry the last message
def retry_last_message(history, system_prompt, temp, top_p_val, max_tok, context_window):
    """Retry the last message with potentially new settings."""
    if not history:
        return history
    
    last_user_msg = history[-1][0]
    # Remove the last exchange
    history = history[:-1]
    # Resend the last user message
    _, updated_history = chat_with_settings(last_user_msg, history, system_prompt, temp, top_p_val, max_tok, context_window)
    return updated_history

# Function to get model info
def get_model_info():
    """Get information about the current model."""
    try:
        if cloudflared_proc.poll() is None:
            response = requests.get(f"{public_url}/api/tags", timeout=10)
            if response.status_code == 200:
                models = response.json().get('models', [])
                current_model = next((m for m in models if m['name'] == MODEL_NAME), None)
                if current_model:
                    return f"📋 **Current Model:** {MODEL_NAME}\n📏 **Size:** {current_model.get('size', 'Unknown')} bytes\n🏷️ **Modified:** {current_model.get('modified_at', 'Unknown')}"
        return f"📋 **Current Model:** {MODEL_NAME}\n⚠️ **Status:** Model info unavailable"
    except Exception as e:
        return f"📋 **Current Model:** {MODEL_NAME}\n❌ **Status:** Cannot retrieve model info ({str(e)})"

# Function to save conversation
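def save_conversation(history):
    """Save the conversation currently shown in the chat window to a JSON file."""
    if not history:
        return "❌ No conversation to save."
    
    try:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"conversation_{timestamp}.json"
        
        # Rebuild role/content messages from the visible history so retried
        # exchanges are not duplicated, and strip the token-metrics footer
        messages = []
        for user_msg, assistant_msg in history:
            messages.append({"role": "user", "content": user_msg})
            messages.append({"role": "assistant", "content": assistant_msg.split("\n\n---\n*Generated")[0]})
        
        with open(filename, "w", encoding="utf-8") as f:
            json.dump(messages, f, indent=2, ensure_ascii=False)
        
        return f"✅ Conversation saved to {filename}"
    except Exception as e:
        return f"❌ Error saving conversation: {str(e)}"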

# Custom CSS for better styling
custom_css = """
.gradio-container {
    max-width: 1200px !important;
    margin: auto;
}

.chat-message {
    padding: 10px;
    margin: 5px 0;
    border-radius: 10px;
}

.user-message {
    background-color: #e3f2fd;
    margin-left: 20%;
}

.bot-message {
    background-color: #f5f5f5;
    margin-right: 20%;
}

.message-header {
    font-weight: bold;
    margin-bottom: 5px;
}
"""

print("🎨 Setting up Gradio interface...")

# Create the Gradio interface
with gr.Blocks(css=custom_css, title="🤖 Enhanced LLM Chat", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🤖 Enhanced LLM Chat Interface
        
        Chat with your locally running LLM via Ollama. The model is accessible through a Cloudflare tunnel.
        """
    )
    
    with gr.Row():
        with gr.Column(scale=3):
            # Main chat interface
            chatbot = gr.Chatbot(
                [],
                elem_id="chatbot",
                bubble_full_width=False,
                height=500,
                show_label=False
            )
            
            with gr.Row():
                msg = gr.Textbox(
                    placeholder="Type your message here...",
                    container=False,
                    scale=4,
                    show_label=False
                )
                send_btn = gr.Button("Send 📤", scale=1, variant="primary")
            
            with gr.Row():
                clear_btn = gr.Button("Clear Chat 🗑️", scale=1)
                retry_btn = gr.Button("Retry Last 🔄", scale=1)
                save_btn = gr.Button("Save Chat 💾", scale=1)
            
            # Status message area
            status_msg = gr.Markdown("")
        
        with gr.Column(scale=1):
            # Sidebar with model info and controls
            gr.Markdown("### 🔧 Model Information")
            model_info = gr.Markdown(get_model_info())
            refresh_info_btn = gr.Button("Refresh Model Info 🔄", size="sm")
            
            gr.Markdown("### 🧠 System Prompt")
            system_prompt = gr.Textbox(
                placeholder="Optional: Set a system prompt to guide the model's behavior",
                label="System Prompt",
                value="You are a helpful AI assistant.",
                lines=3
            )
            
            gr.Markdown("### ⚙️ Settings")
            with gr.Accordion("Generation Parameters", open=True):
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=2.0,
                    value=0.7,
                    step=0.1,
                    label="Temperature",
                    info="Controls randomness (lower = more focused)"
                )
                
                top_p = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.9,
                    step=0.05,
                    label="Top P",
                    info="Controls diversity (lower = more focused)"
                )
                
                max_tokens = gr.Slider(
                    minimum=100,
                    maximum=4096,
                    value=2048,
                    step=100,
                    label="Max Tokens",
                    info="Maximum response length"
                )
                
                context_window = gr.Slider(
                    minimum=1024,
                    maximum=16384,
                    value=8192,
                    step=1024,
                    label="Context Window",
                    info="Maximum context length"
                )
    
    # Event handlers
    msg.submit(
        chat_with_settings,
        inputs=[msg, chatbot, system_prompt, temperature, top_p, max_tokens, context_window],
        outputs=[msg, chatbot]
    )
    
    send_btn.click(
        chat_with_settings,
        inputs=[msg, chatbot, system_prompt, temperature, top_p, max_tokens, context_window],
        outputs=[msg, chatbot]
    )
    
    clear_btn.click(
        clear_chat,
        outputs=[chatbot, msg]
    )
    
    retry_btn.click(
        retry_last_message,
        inputs=[chatbot, system_prompt, temperature, top_p, max_tokens, context_window],
        outputs=[chatbot]
    )
    
    refresh_info_btn.click(
        get_model_info,
        outputs=[model_info]
    )
    
    save_btn.click(
        save_conversation,
        inputs=[chatbot],
        outputs=[status_msg]
    )

print("✅ Gradio interface created successfully!")
print("🚀 Launching the web interface...")

# Launch the interface
demo.launch(
    share=True,  # Create a public link
    server_name="0.0.0.0",  # Allow external connections
    server_port=7860,  # Use a specific port
    show_error=True,
    debug=False
)
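The metrics footer divides `eval_count` by the wall-clock time of the HTTP request, which also includes prompt processing, model loading on the first call, and the round trip through the tunnel. Ollama's non-streaming responses also report `eval_duration` (in nanoseconds) for the generation phase alone, so a generation-only rate can be computed from the response body. A sketch, using a hypothetical `generation_speed` helper and made-up sample values:

```python
from typing import Optional

def generation_speed(result: dict) -> Optional[float]:
    """Tokens per second for the generation phase of an Ollama response body.

    Assumes eval_duration is in nanoseconds; returns None if the fields are missing or zero.
    """
    eval_count = result.get("eval_count", 0)
    eval_duration_ns = result.get("eval_duration", 0)
    if not eval_count or not eval_duration_ns:
        return None
    return eval_count / (eval_duration_ns / 1e9)

# Illustrative values, not a real measurement
sample = {"response": "Hello!", "eval_count": 50, "eval_duration": 2_000_000_000}
print(f"{generation_speed(sample):.1f} tokens/s")
```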

### How to Use:

1. **Click the public Gradio link** that appears above
2. **Type your questions** in the chat input field
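3. **Adjust settings** in the right sidebar:
   - **System Prompt**: Sets the personality/behavior of the AI
   - **Temperature**: Controls randomness (0.1 = focused and near-deterministic, 2.0 = highly varied)
   - **Top P**: Restricts sampling to the most likely tokens whose probabilities sum to P (0.1 = narrow, 1.0 = no restriction)
   - **Max Tokens**: Maximum length of each response, in tokens (`num_predict`)
   - **Context Window**: Model context length in tokens (`num_ctx`), shared by the system prompt, the full conversation history, and the response; if the conversation outgrows it, Ollama truncates the input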
4. **Use the buttons**:
   - **Send**: Submit your message
   - **Clear Chat**: Reset conversation
   - **Retry Last**: Regenerate the last response with current settings
   - **Save Chat**: Export conversation as JSON
   - **Refresh Model Info**: Update model status
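To make the Temperature and Top P settings concrete: temperature divides the model's logits before the softmax (lower values sharpen the distribution), and top-p (nucleus sampling) keeps only the smallest set of most likely tokens whose probabilities add up to at least P, then renormalizes. The sketch below is a plain-Python illustration of these two standard techniques on a toy four-token vocabulary; it is not Ollama's sampler, whose ordering and additional filters (such as top-k) may differ.

```python
import math
from typing import Dict

def softmax_with_temperature(logits: Dict[str, float], temperature: float) -> Dict[str, float]:
    """Convert logits to probabilities after dividing them by the temperature."""
    scaled = {tok: value / temperature for tok, value in logits.items()}
    max_val = max(scaled.values())  # subtract the max for numerical stability
    exps = {tok: math.exp(v - max_val) for tok, v in scaled.items()}
    total = sum(exps.values())
    return {tok: e / total for tok, e in exps.items()}

def top_p_filter(probs: Dict[str, float], top_p: float) -> Dict[str, float]:
    """Keep the most likely tokens until their cumulative probability reaches top_p."""
    kept, cumulative = {}, 0.0
    for tok, p in sorted(probs.items(), key=lambda kv: kv[1], reverse=True):
        kept[tok] = p
        cumulative += p
        if cumulative >= top_p:
            break
    total = sum(kept.values())
    return {tok: p / total for tok, p in kept.items()}

# Toy logits for the next token
logits = {"the": 2.0, "a": 1.0, "one": 0.5, "zebra": -1.0}
for temp in (0.1, 0.7, 2.0):
    probs = softmax_with_temperature(logits, temp)
    print(temp, {tok: round(p, 3) for tok, p in probs.items()})
print(top_p_filter(softmax_with_temperature(logits, 0.7), 0.9))
```

At low temperature almost all probability lands on `"the"`, while at 2.0 even `"zebra"` becomes a plausible pick; a Top P of 0.9 at temperature 0.7 removes the two least likely tokens entirely.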