In [None]:
# Cell1:  Shared Cache Bootstrap
import os, pathlib, torch
import sys
from datetime import datetime

# Shared cache configuration (copy this cell into every notebook).
# All Hugging Face / PyTorch cache locations are placed under one root, which
# can be overridden through the AI_CACHE_ROOT environment variable.
# NOTE(review): recent transformers releases reportedly deprecate
# TRANSFORMERS_CACHE in favour of HF_HOME — confirm for the pinned version.
AI_CACHE_ROOT = os.getenv("AI_CACHE_ROOT", "../ai_warehouse/cache")

_CACHE_ENV_SUBDIRS = (
    ("HF_HOME", "hf"),
    ("TRANSFORMERS_CACHE", "hf/transformers"),
    ("HF_DATASETS_CACHE", "hf/datasets"),
    ("HUGGINGFACE_HUB_CACHE", "hf/hub"),
    ("TORCH_HOME", "torch"),
)

for env_name, sub_dir in _CACHE_ENV_SUBDIRS:
    cache_dir = f"{AI_CACHE_ROOT}/{sub_dir}"
    os.environ[env_name] = cache_dir
    # Create each directory up front (no error if it already exists).
    pathlib.Path(cache_dir).mkdir(parents=True, exist_ok=True)

print("[Cache]", AI_CACHE_ROOT, "| GPU:", torch.cuda.is_available())

In [None]:
# Cell 2: Import Dependencies and LLMAdapter
import gradio as gr
import threading
import time
import traceback
from typing import Iterator, Optional, Dict, Any
from dataclasses import dataclass
from transformers import AutoTokenizer, AutoModelForCausalLM
import queue


@dataclass()
class GenerationState:
    """Mutable status flags shared between the chat handlers and the UI.

    StreamingChatbot sets these while it streams a reply; UI code reads them
    to decide which status to display.
    """

    # True while a reply is being streamed.
    is_generating: bool = False
    # Cooperative stop request, polled inside the streaming loop.
    should_stop: bool = False
    # Description of the last failure, or None after a clean run.
    error_message: Optional[str] = None


class LLMAdapter:
    """Lightweight streaming wrapper around a Hugging Face causal LM.

    Generation runs token by token in Python: one full forward pass over the
    whole sequence per new token (no KV cache). This keeps cooperative
    cancellation simple and is intended for small demo models.
    """

    def __init__(self, model_id: str = "microsoft/DialoGPT-small", **kwargs):
        """Load the tokenizer and model.

        Args:
            model_id: Hugging Face hub id or local path.
            **kwargs: Extra keyword arguments forwarded to
                ``AutoModelForCausalLM.from_pretrained``.
        """
        print(f"Loading model: {model_id}")
        self.tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
        if self.tokenizer.pad_token is None:
            # Reuse EOS as the pad token so `padding=True` in stream_generate
            # has a pad token to use.
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map="auto",
            torch_dtype="auto",
            low_cpu_mem_usage=True,
            **kwargs,
        )
        print(f"Model loaded on device: {self.model.device}")

    def stream_generate(
        self,
        messages: list,
        max_new_tokens: int = 256,
        temperature: float = 0.7,
        stop_event: Optional[threading.Event] = None,
        stream_delay: float = 0.05,
    ) -> Iterator[str]:
        """Stream a reply, yielding the cumulative decoded text.

        Args:
            messages: List of ``{"role": ..., "content": ...}`` dicts.
            max_new_tokens: Upper bound on generated tokens.
            temperature: Sampling temperature; values <= 0 select greedy
                (argmax) decoding instead of dividing by zero.
            stop_event: When set, generation stops before the next token and
                the last chunk gets a " [CANCELLED]" suffix.
            stream_delay: Seconds to sleep after each token so streaming is
                visible in the UI (0 disables the pause).

        Yields:
            The full text generated so far (not just the newest token). A chunk
            is withheld while the text ends in an incomplete multi-byte
            character, and the final text is always yielded. Exceptions are not
            raised: they are reported as a single
            ``"Error during generation: ..."`` chunk.
        """
        try:
            prompt = self._messages_to_prompt(messages)

            inputs = self.tokenizer(prompt, return_tensors="pt", padding=True)
            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

            generated_ids = []
            generated_text = ""
            last_yielded = None

            with torch.no_grad():
                for _ in range(max_new_tokens):
                    if stop_event is not None and stop_event.is_set():
                        # Drop a dangling partial character before annotating.
                        yield generated_text.rstrip("\ufffd") + " [CANCELLED]"
                        return

                    outputs = self.model(**inputs)
                    next_token_logits = outputs.logits[0, -1, :]

                    if temperature > 0:
                        probs = torch.softmax(next_token_logits / temperature, dim=-1)
                        next_token = torch.multinomial(probs, 1)
                    else:
                        next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)

                    token_id = int(next_token.item())
                    if token_id == self.tokenizer.eos_token_id:
                        break

                    # Decode the whole generated sequence rather than each token
                    # in isolation: byte-level BPE can split one character
                    # (e.g. CJK) across several tokens, and decoding those
                    # pieces separately produces U+FFFD replacement characters.
                    generated_ids.append(token_id)
                    generated_text = self.tokenizer.decode(
                        generated_ids, skip_special_tokens=True
                    )

                    # Withhold chunks that end mid-character; the next token
                    # usually completes it.
                    if not generated_text.endswith("\ufffd"):
                        yield generated_text
                        last_yielded = generated_text

                    inputs["input_ids"] = torch.cat(
                        [inputs["input_ids"], next_token.view(1, 1)], dim=-1
                    )
                    attention_mask = inputs["attention_mask"]
                    # new_ones keeps the mask's dtype/device (a float
                    # torch.ones would promote the whole mask to float).
                    inputs["attention_mask"] = torch.cat(
                        [attention_mask, attention_mask.new_ones((1, 1))], dim=-1
                    )

                    if stream_delay > 0:
                        time.sleep(stream_delay)

            # Emit the final text if it was withheld (or nothing was yielded).
            if generated_text != last_yielded:
                yield generated_text

        except Exception as e:
            yield f"Error during generation: {str(e)}"

    def _messages_to_prompt(self, messages: list) -> str:
        """Render messages as ``Role: content`` lines ending with ``Assistant:``.

        Messages whose role is not system/user/assistant are skipped.
        """
        prompt_parts = []
        for msg in messages:
            role = msg.get("role", "user")
            content = msg.get("content", "")
            if role == "system":
                prompt_parts.append(f"System: {content}")
            elif role == "user":
                prompt_parts.append(f"User: {content}")
            elif role == "assistant":
                prompt_parts.append(f"Assistant: {content}")

        prompt_parts.append("Assistant:")
        return "\n".join(prompt_parts)


# Shared instances for the later cells: one loaded model, plus the status
# object that StreamingChatbot updates while it streams.
llm_adapter = LLMAdapter(model_id="microsoft/DialoGPT-small")
generation_state = GenerationState()

In [None]:
# Cell 3: Streaming Generator Implementation
class StreamingChatbot:
    """Chat front-end that streams replies from an ``LLMAdapter``.

    ``history`` uses the Gradio "pairs" layout: a list of
    ``[user_message, assistant_message]`` lists. Progress is mirrored into the
    module-level ``generation_state`` so UI handlers can report status.
    """

    def __init__(self, llm_adapter: LLMAdapter):
        self.llm_adapter = llm_adapter
        # Passed to stream_generate, which checks it before every token.
        self.stop_event = threading.Event()
        self.generation_thread = None

    def generate_streaming_response(
        self, history: list, message: str
    ) -> Iterator[tuple]:
        """Stream the assistant reply to ``message``.

        Args:
            history: Previous exchanges. The caller's list is not mutated; a new
                outer list with the new exchange appended is built.
            message: The new user message.

        Yields:
            ``(history, "")`` after each partial reply; the empty string is
            used by the UI to clear the input textbox.
        """
        if not message.strip():
            yield history, ""
            return

        # A previous stop_generation() leaves the event set; without clearing
        # it every later request would be cancelled before its first token.
        self.stop_event.clear()

        history = history + [[message, ""]]

        # Prior turns only; the new user message is appended explicitly.
        messages = self._history_to_messages(history[:-1])
        messages.append({"role": "user", "content": message})

        try:
            generation_state.is_generating = True
            generation_state.should_stop = False
            generation_state.error_message = None

            for partial_response in self.llm_adapter.stream_generate(
                messages,
                max_new_tokens=200,
                temperature=0.7,
                stop_event=self.stop_event,
            ):
                if generation_state.should_stop:
                    break

                history[-1][1] = partial_response
                yield history, ""

        except Exception as e:
            generation_state.error_message = f"Generation failed: {str(e)}"
            history[-1][1] = f"❌ Error: {str(e)}"
            yield history, ""
        finally:
            generation_state.is_generating = False

    def _history_to_messages(self, history: list) -> list:
        """Convert ``[user, assistant]`` pairs into role/content messages.

        Empty (falsy) entries are skipped.
        """
        messages = []
        for user_msg, assistant_msg in history:
            if user_msg:
                messages.append({"role": "user", "content": user_msg})
            if assistant_msg:
                messages.append({"role": "assistant", "content": assistant_msg})
        return messages

    def stop_generation(self):
        """Request that the current generation stop before its next token."""
        generation_state.should_stop = True
        self.stop_event.set()

    def retry_last_response(self, history: list) -> Iterator[tuple]:
        """Regenerate the reply to the most recent user message.

        Yields nothing for an empty history. If the last exchange has no user
        text, the history is yielded back unchanged.
        """
        if not history:
            return

        last_user_message = history[-1][0]
        if not last_user_message or not str(last_user_message).strip():
            yield history, ""
            return

        # Always drop the last exchange: generate_streaming_response re-appends
        # [last_user_message, ""]. Keeping an exchange whose reply is still
        # empty (e.g. stopped before the first token) would duplicate the turn.
        yield from self.generate_streaming_response(history[:-1], last_user_message)


# Initialize chatbot
chatbot = StreamingChatbot(llm_adapter)

In [None]:
# Cell 4: Gradio Interface with Event Handling
def create_streaming_ui():
    """Build the basic streaming chat UI (not launched).

    Handlers delegate to the module-level ``chatbot`` and read the module-level
    ``generation_state`` for status. Streaming handlers are generator
    functions: Gradio streams partial outputs only when the handler itself is a
    generator. A plain function that *returns* a generator object hands Gradio
    the object itself.

    Returns:
        gr.Blocks: the constructed demo.
    """

    with gr.Blocks(title="Streaming Chat with Retry", theme=gr.themes.Soft()) as demo:
        gr.Markdown("## 🤖 串流對話機器人")
        gr.Markdown("支援即時串流輸出、錯誤重試與取消功能")

        with gr.Row():
            with gr.Column(scale=4):
                chatbot_ui = gr.Chatbot(
                    label="對話記錄",
                    height=400,
                    show_copy_button=True,
                    bubble_full_width=False,
                )

                with gr.Row():
                    msg_input = gr.Textbox(
                        label="輸入訊息",
                        placeholder="在此輸入您的問題...",
                        scale=4,
                        lines=2,
                    )

                with gr.Row():
                    send_btn = gr.Button("發送", variant="primary", scale=1)
                    stop_btn = gr.Button("停止", variant="stop", scale=1)
                    retry_btn = gr.Button("重試", variant="secondary", scale=1)
                    clear_btn = gr.Button("清除", scale=1)

            with gr.Column(scale=1):
                gr.Markdown("### 狀態面板")
                status_text = gr.Textbox(
                    label="生成狀態", value="就緒", interactive=False
                )

                gr.Markdown("### 設定")
                model_info = gr.Textbox(
                    label="目前模型",
                    value="microsoft/DialoGPT-small",
                    interactive=False,
                )

                error_display = gr.Textbox(
                    label="錯誤訊息", value="", visible=False, interactive=False
                )

        # Event handlers
        def update_status(is_generating: bool, error: str = ""):
            """Return (status text, error-box update) for the given state."""
            if error:
                return "❌ 錯誤", gr.update(value=error, visible=True)
            elif is_generating:
                return "🔄 生成中...", gr.update(value="", visible=False)
            else:
                return "✅ 就緒", gr.update(value="", visible=False)

        def _final_status():
            """Status outputs after a stream ends (error > stopped > ready)."""
            error = generation_state.error_message or ""
            if not error and generation_state.should_stop:
                return "⏹️ 已停止", gr.update(value="", visible=False)
            return update_status(False, error)

        def _stream_with_status(stream):
            """Re-yield each history from a (history, textbox) stream with status outputs."""
            streamed = False
            history = None
            for history, _ in stream:
                streamed = True
                yield (history, *update_status(True))
            if streamed:
                yield (history, *_final_status())

        def send_message(history, message):
            """Handle send/submit: stream the reply and clear the textbox."""
            if not message.strip():
                yield history, "", "請輸入有效訊息", gr.update(value="", visible=False)
                return

            for new_history, status, error_update in _stream_with_status(
                chatbot.generate_streaming_response(history, message)
            ):
                yield new_history, "", status, error_update

        def stop_generation():
            """Handle stop: signal the running generation to stop."""
            chatbot.stop_generation()
            return "⏹️ 已停止", gr.update(value="", visible=False)

        def retry_generation(history):
            """Handle retry: regenerate the last reply (textbox left untouched)."""
            yield from _stream_with_status(chatbot.retry_last_response(history))

        def clear_chat():
            """Clear chat history, input and status."""
            return [], "", "✅ 已清除", gr.update(value="", visible=False)

        # Wire up events: every handler's return/yield arity matches its outputs.
        send_outputs = [chatbot_ui, msg_input, status_text, error_display]

        msg_input.submit(
            send_message,
            inputs=[chatbot_ui, msg_input],
            outputs=send_outputs,
        )

        send_btn.click(
            send_message,
            inputs=[chatbot_ui, msg_input],
            outputs=send_outputs,
        )

        stop_btn.click(stop_generation, outputs=[status_text, error_display])

        retry_btn.click(
            retry_generation,
            inputs=[chatbot_ui],
            outputs=[chatbot_ui, status_text, error_display],
        )

        clear_btn.click(
            clear_chat,
            outputs=[chatbot_ui, msg_input, status_text, error_display],
        )

    return demo

In [None]:
# Cell 5: Error Handling and Retry Logic
class RetryManager:
    """Run a callable with a bounded number of attempts and record the outcome.

    Args:
        max_retries: Total number of attempts (not extra attempts after the
            first). Values below 1 are treated as 1, so ``func`` always runs.
        retry_delay: Seconds to sleep between a failed attempt and the next
            one (default 1.0, the previously hard-coded pause).
    """

    def __init__(self, max_retries: int = 3, retry_delay: float = 1.0):
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        # Zero-based index of the most recent attempt.
        self.retry_count = 0
        # str(exception) of the last failure, or None after a success.
        self.last_error = None

    def attempt_generation(self, func, *args, **kwargs):
        """Call ``func(*args, **kwargs)`` until it succeeds or attempts run out.

        Returns:
            Whatever ``func`` returns on its first successful call.

        Raises:
            The exception from the final attempt, re-raised unchanged.
        """
        attempts = max(1, self.max_retries)
        for attempt in range(attempts):
            self.retry_count = attempt
            try:
                result = func(*args, **kwargs)
            except Exception as e:
                self.last_error = str(e)
                if attempt < attempts - 1:
                    print(f"Attempt {attempt + 1} failed: {e}. Retrying...")
                    if self.retry_delay > 0:
                        time.sleep(self.retry_delay)
                else:
                    print(f"All {attempts} attempts failed. Last error: {e}")
                    raise
            else:
                self.last_error = None
                return result

    def get_retry_status(self) -> dict:
        """Snapshot of retry state for display."""
        return {
            "retry_count": self.retry_count,
            "max_retries": self.max_retries,
            "last_error": self.last_error,
            "has_failed": self.last_error is not None,
        }


# Enhanced chatbot with retry management
class RobustStreamingChatbot(StreamingChatbot):
    """StreamingChatbot that turns an escaping exception into a chat reply.

    ``generate_with_retry`` does NOT retry on its own. On failure it appends an
    error reply that asks the user to press the retry button. ``retry_manager``
    records the last such error so status displays can show it.
    """

    def __init__(self, llm_adapter: LLMAdapter, max_retries: int = 2):
        super().__init__(llm_adapter)
        self.retry_manager = RetryManager(max_retries)

    def generate_with_retry(self, history: list, message: str) -> Iterator[tuple]:
        """Stream a reply. On an exception, yield history plus an error exchange.

        ``generate_streaming_response`` already catches exceptions raised while
        streaming, so this handler is a last-resort safety net.

        Yields:
            ``(history, "")`` tuples, like ``generate_streaming_response``.
        """
        try:
            yield from self.generate_streaming_response(history, message)
        except Exception as e:
            # Record the failure before yielding. A consumer may stop iterating
            # once it receives the error chunk, and code placed after the yield
            # would then never run.
            generation_state.error_message = str(e)
            self.retry_manager.last_error = str(e)

            error_history = history + [
                [message, f"❌ 生成失敗: {str(e)}\n💡 點擊「重試」按鈕重新生成"]
            ]
            yield error_history, ""

In [None]:
# Cell 6: Cancellation Mechanism
class CancellableGenerator:
    """Cooperative cancellation token backed by a flag and a threading.Event.

    Either signal counts as "cancelled"; ``reset`` clears both.
    """

    def __init__(self):
        self.cancel_event = threading.Event()
        self.cancelled = False

    def cancel(self):
        """Mark the current operation as cancelled."""
        self.cancel_event.set()
        self.cancelled = True

    def is_cancelled(self) -> bool:
        """True when either the flag or the event has been set."""
        return any((self.cancelled, self.cancel_event.is_set()))

    def reset(self):
        """Clear both cancellation signals."""
        self.cancel_event.clear()
        self.cancelled = False


# Global cancellation manager shared by the UI handlers.
cancel_manager = CancellableGenerator()


def create_enhanced_ui():
    """Build the enhanced streaming chat UI (not launched).

    Handlers delegate to a ``RobustStreamingChatbot`` wrapping the module-level
    ``llm_adapter``. Streaming handlers are generator functions: Gradio streams
    partial outputs only when the handler itself is a generator. Returning a
    generator object from a plain function hands Gradio the object itself.

    Returns:
        gr.Blocks: the constructed demo.
    """

    enhanced_chatbot = RobustStreamingChatbot(llm_adapter)

    with gr.Blocks(title="Enhanced Streaming Chat") as demo:
        gr.Markdown("# 🚀 增強版串流對話")

        chatbot_ui = gr.Chatbot(height=450, show_copy_button=True)

        with gr.Row():
            msg_input = gr.Textbox(
                placeholder="輸入您的問題 (支援中文)...", scale=4, lines=3
            )

        with gr.Row():
            send_btn = gr.Button("📤 發送", variant="primary")
            stop_btn = gr.Button("⏹️ 停止", variant="stop")
            retry_btn = gr.Button("🔄 重試", variant="secondary")
            clear_btn = gr.Button("🗑️ 清除")

        with gr.Row():
            status_display = gr.Textbox(
                label="狀態", value="✅ 就緒", interactive=False, scale=2
            )
            retry_info = gr.Textbox(
                label="重試資訊", value="尚未重試", interactive=False, scale=2
            )

        def _reset_cancellation():
            """Clear every cancellation signal before a new request starts."""
            cancel_manager.reset()
            # The adapter polls enhanced_chatbot.stop_event, not cancel_manager.
            # It stays set after a stop click, so clear it explicitly.
            enhanced_chatbot.stop_event.clear()

        def enhanced_send(history, message):
            """Stream a reply to ``message`` as (history, textbox) updates."""
            if not message.strip():
                yield history, ""
                return

            _reset_cancellation()
            yield from enhanced_chatbot.generate_with_retry(history, message)

        def enhanced_stop():
            """Stop the running generation and report when it was stopped."""
            cancel_manager.cancel()
            enhanced_chatbot.stop_generation()
            return "⏹️ 生成已停止", "停止於: " + time.strftime("%H:%M:%S")

        def enhanced_retry(history):
            """Regenerate the last reply, streaming (history, status, retry info)."""
            if not history:
                # Nothing to retry; gr.update() leaves a component unchanged.
                yield history, gr.update(), gr.update()
                return

            _reset_cancellation()
            status = enhanced_chatbot.retry_manager.get_retry_status()
            retry_text = f"重試次數: {status['retry_count']}/{status['max_retries']}"

            latest_history = history
            for latest_history, _ in enhanced_chatbot.retry_last_response(history):
                yield latest_history, "🔄 重試中...", retry_text

            final_status = (
                "⏹️ 生成已停止" if generation_state.should_stop else "✅ 就緒"
            )
            yield latest_history, final_status, retry_text

        def enhanced_clear():
            """Stop any in-flight generation and reset the conversation."""
            cancel_manager.cancel()
            # cancel_manager is not polled by the generator; this actually stops it.
            enhanced_chatbot.stop_generation()
            return [], "", "✅ 對話已清除", "已重置"

        # Wire events: every handler's return/yield arity matches its outputs.
        send_btn.click(enhanced_send, [chatbot_ui, msg_input], [chatbot_ui, msg_input])
        stop_btn.click(enhanced_stop, outputs=[status_display, retry_info])
        retry_btn.click(
            enhanced_retry, [chatbot_ui], [chatbot_ui, status_display, retry_info]
        )
        clear_btn.click(
            enhanced_clear, outputs=[chatbot_ui, msg_input, status_display, retry_info]
        )

        msg_input.submit(
            enhanced_send, [chatbot_ui, msg_input], [chatbot_ui, msg_input]
        )

    return demo

In [None]:
# Cell 7: Complete UI Integration Test
def run_comprehensive_test():
    """Exercise streaming, cancellation, error reporting and RetryManager.

    Uses the module-level ``llm_adapter`` (a real model) and ``RetryManager``.

    Returns:
        bool: True only if every check passed.
    """
    print("=== 串流與重試功能綜合測試 ===")
    results = {}

    # Test 1: Basic streaming
    print("\n1. 測試基本串流生成...")
    test_messages = [{"role": "user", "content": "Hello, how are you?"}]

    response_parts = []
    for partial in llm_adapter.stream_generate(test_messages, max_new_tokens=50):
        response_parts.append(partial)
        if len(response_parts) % 5 == 0:  # Show progress every 5 chunks
            print(f"   部分回應: {partial[:50]}...")

    final_text = response_parts[-1] if response_parts else ""
    results["streaming"] = bool(response_parts) and not final_text.startswith(
        "Error during generation"
    )
    print(f"   {'✅' if results['streaming'] else '❌'} 完整回應: {final_text or 'No response'}")

    # Test 2: Cancellation
    print("\n2. 測試取消機制...")
    stop_event = threading.Event()

    cancelled_parts = []
    for partial in llm_adapter.stream_generate(
        test_messages, max_new_tokens=100, stop_event=stop_event
    ):
        cancelled_parts.append(partial)
        if not stop_event.is_set():
            # Cancel right after the first chunk. A wall-clock timer raced
            # against the model reaching EOS, which made this check flaky.
            stop_event.set()
            print("   ⏹️ 已發送取消信號")

    final_response = cancelled_parts[-1] if cancelled_parts else ""
    results["cancellation"] = "[CANCELLED]" in final_response
    print(f"   ✅ 取消測試: {'成功' if results['cancellation'] else '失敗'}")

    # Test 3: Error handling
    print("\n3. 測試錯誤處理...")
    # stream_generate catches exceptions and reports them as an
    # "Error during generation: ..." chunk instead of raising. An unknown role
    # is not an error (it is simply skipped), so use a non-dict message, which
    # makes prompt construction fail inside stream_generate.
    try:
        error_output = list(llm_adapter.stream_generate([None], max_new_tokens=10))
        reported = bool(error_output) and error_output[-1].startswith(
            "Error during generation"
        )
        if reported:
            print(f"   ✅ 錯誤處理成功: {error_output[-1]}")
        else:
            print("   ⚠️ 預期錯誤未發生")
    except Exception as e:
        reported = True
        print(f"   ✅ 錯誤處理成功: {type(e).__name__}")
    results["error_handling"] = reported

    # Test 4: Retry manager
    print("\n4. 測試重試管理器...")
    retry_mgr = RetryManager(max_retries=2)

    def failing_function():
        if retry_mgr.retry_count < 1:  # Fail on first attempt
            raise ValueError("模擬失敗")
        return "成功"

    try:
        result = retry_mgr.attempt_generation(failing_function)
        print(f"   ✅ 重試成功: {result}")
        results["retry"] = result == "成功"
    except Exception as e:
        print(f"   ❌ 重試失敗: {e}")
        results["retry"] = False

    status = retry_mgr.get_retry_status()
    print(f"   📊 重試狀態: {status['retry_count']}/{status['max_retries']} 次嘗試")

    print("\n=== 測試完成 ===")
    return all(results.values())


# Run comprehensive test
test_result = run_comprehensive_test()

In [None]:
# Cell 8: Smoke Test - Launch UI
def smoke_test(
    launch: bool = True,
    server_name: str = "0.0.0.0",
    server_port: int = 7861,
    debug: bool = True,
) -> bool:
    """Build the enhanced UI and optionally launch it for manual testing.

    Args:
        launch: When False, only build the UI and return. This is useful for
            "Restart & Run All", because ``launch(debug=True)`` blocks the
            calling cell until the server is stopped.
        server_name: Interface to bind. The default "0.0.0.0" exposes this
            unauthenticated demo on every network interface; pass "127.0.0.1"
            to keep it local.
        server_port: Port to serve on.
        debug: Forwarded to ``demo.launch``.

    Returns:
        bool: False if building or launching raised, True otherwise.
    """
    print("🧪 Smoke Test: Gradio 串流介面")

    try:
        demo = create_enhanced_ui()
        print("✅ UI 建立成功")

        # Reached only if create_enhanced_ui() built every component and event.
        print("✅ 所有組件初始化完成")
        print("✅ 事件處理器綁定成功")
        print("✅ 錯誤處理機制就緒")

        if not launch:
            return True

        print("\n🚀 啟動 Gradio 界面...")
        print("📝 功能測試清單:")
        print("   - 輸入訊息並觀察串流輸出")
        print("   - 點擊「停止」按鈕測試取消功能")
        print("   - 點擊「重試」按鈕測試重新生成")
        print("   - 點擊「清除」按鈕清空對話")

        demo.launch(
            server_name=server_name,
            server_port=server_port,
            share=False,
            debug=debug,
            show_error=True,
            quiet=False,
        )

    except Exception as e:
        print(f"❌ Smoke test 失敗: {str(e)}")
        print(f"錯誤詳情: {traceback.format_exc()}")
        return False

    return True


# Run smoke test
print("開始 Smoke Test...")
smoke_result = smoke_test()

In [None]:
# Low-VRAM optimization options.
# NOTE(review): this cell is YAML, not Python. Executed in a Python kernel it
# raises SyntaxError, which stops "Run All" here. Keep it as a raw/markdown
# cell, or prefix the cell with `%%writefile <name>.yaml`.
# These values mirror settings hard-coded in the cells above. Nothing visible
# in this notebook reads this block, so editing it does not change behavior.
model_config:
  device_map: "auto"           # automatic device placement
  torch_dtype: "auto"          # automatic precision selection
  low_cpu_mem_usage: true      # reduce CPU RAM usage while loading

# Streaming generation parameters
streaming_config:
  max_new_tokens: 200          # cap on generated length
  temperature: 0.7             # sampling temperature (creativity)
  stream_delay: 0.05           # pause between streamed tokens (seconds)

# Retry settings
retry_config:
  max_retries: 2               # max retries (RetryManager treats this as total attempts)
  retry_delay: 1.0             # pause between retries (seconds)

# UI settings
ui_config:
  chatbot_height: 450          # chat area height
  show_copy_button: true       # show the copy button
  bubble_full_width: false     # do not stretch message bubbles to full width