In [None]:
!pip install unsloth

In [None]:
# The deep_research package must be installed in the same environment as this notebook, e.g.:

# cd ./tools/deep_research
# pip install -e .

In [None]:
from unsloth import FastModel
import os
import torch
from transformers import AutoTokenizer, TextIteratorStreamer, pipeline
from threading import Thread
import json
import time
from deep_research.query import process_query
from deep_research.search import search
from deep_research.crawler import WebCrawler
import gradio as gr
from unsloth.chat_templates import get_chat_template
import re

class StreamingResearchPipeline:
    def __init__(self, model_name="unsloth/gemma-3-4b-it", cache_dir="./cache", system_prompt_file=None):
        """
        Initialize a streaming pipeline with Unsloth's Gemma3 model that can integrate
        live research via deep_research during generation.

        Loads the model and tokenizer, applies the "gemma-3" chat template, builds
        the WebCrawler used for research, and resolves the system prompt.

        Args:
            model_name (str): The Unsloth model ID
            cache_dir (str): Directory for caching web pages
            system_prompt_file (str): Path to a text file containing the system prompt.
                When missing or not found on disk, the built-in default prompt is used.
        """
        print(f"Loading model: {model_name}")

        # Track timing metrics for each operation. Every counter that other
        # methods update is declared here so get_metrics() has a stable shape.
        self.metrics = {
            "model_load_time": 0,
            "research_time": 0,
            "total_tokens_generated": 0,
            "total_searches": 0,
            "successful_searches": 0,
            "failed_searches": 0,
            "skipped_searches": 0
        }

        start_time = time.time()

        # Load the model and tokenizer using FastModel
        self.model, self.tokenizer = FastModel.from_pretrained(
            model_name = model_name,
            max_seq_length = 2048,  # Choose any for long context!
            load_in_4bit = True,    # 4 bit quantization to reduce memory
            load_in_8bit = False,   # A bit more accurate, uses 2x memory
            full_finetuning = False # Full finetuning support
        )

        # Apply the chat template
        self.tokenizer = get_chat_template(
            self.tokenizer,
            chat_template = "gemma-3"
        )

        self.metrics["model_load_time"] = time.time() - start_time
        print(f"Model loaded successfully in {self.metrics['model_load_time']:.2f} seconds")

        # Initialize the deep research crawler with detailed logging
        print("Initializing research crawler...")
        self.crawler = WebCrawler(
            cache_dir=cache_dir,
            respect_robots=True,
            requests_per_second=1,
            max_depth=2,
            max_pages_per_site=5,
            summarization_model="facebook/bart-large-cnn",
            relevance_model="cross-encoder/ms-marco-MiniLM-L-6-v2",
            qa_model="deepset/roberta-base-squad2"
        )

        # Store model generation parameters for transparency
        self.generation_params = {
            "max_new_tokens": 1024,
            "do_sample": True,
            "temperature": 0.7,  # Slightly reduced for more focused responses
            "top_p": 0.92,
            "top_k": 50
        }

        # Load system prompt
        if system_prompt_file and os.path.exists(system_prompt_file):
            # Explicit encoding: the platform default is not UTF-8 everywhere,
            # and prompt files routinely contain non-ASCII characters.
            with open(system_prompt_file, 'r', encoding='utf-8') as f:
                self.system_prompt = f.read()
            print(f"Loaded system prompt from {system_prompt_file}")
        else:
            # Improved system prompt
            self.system_prompt = """
            You are an AI assistant with real-time research capabilities. When you need to search for information, 
            use the <deep_research>YOUR SEARCH QUERY</deep_research> tags. You should thoroughly research any topic 
            before providing your final answer.
            
            Follow these guidelines:
            1. Use <deep_research> tags ONLY for actual web searches, not when discussing the tags themselves.
            2. Perform as many searches as needed to gather comprehensive information.
            3. Use specific, focused search queries rather than broad ones.
            4. After completing all necessary research, provide your final answer between <ANSWER> and </ANSWER> tags.
            5. Include citations to your sources in the final answer.
            6. If answering factual questions, ensure your answer is based on the research results.
            
            Example usage:
            User: "What are the latest advancements in quantum computing?"
            
            You: "I'll research this topic for you.
            <deep_research>latest advancements in quantum computing 2025</deep_research>
            
            Let me check for more specific information about quantum error correction.
            <deep_research>quantum error correction breakthroughs 2025</deep_research>
            
            Now I'll look for information about quantum advantage demonstrations.
            <deep_research>quantum computational advantage demonstrations 2025</deep_research>
            
            <ANSWER>
            
            The latest advancements in quantum computing ....
            </ANSWER>"
            """
            print("Using default improved system prompt")
    
    def perform_research(self, query):
        """
        Execute the deep_research pipeline on a query with detailed logging
        
        Args:
            query (str): The search query
            
        Returns:
            dict: Research results with sources and debug information
        """
        # Skip research if it appears to be discussing the tags rather than an actual query
        if "deep_research" in query.lower() or "example" in query.lower() or len(query.strip()) < 3:
            print(f"Skipping research for meta-query: {query}")
            self.metrics["skipped_searches"] = self.metrics.get("skipped_searches", 0) + 1
            return f"\n\n[RESEARCH SKIPPED] Meta-query or system prompt example detected: '{query}'\n\n", {
                "original_query": query,
                "skipped": True,
                "reason": "Meta-query or system prompt example"
            }
        
        research_start = time.time()
        print(f"Performing research: {query}")
        
        try:
            # Process the query
            query_result = process_query(query)
            query_process_time = time.time() - research_start
            print(f"Query processing complete in {query_process_time:.2f}s")
            print(f"Parsed query: {query_result['parsed_query']}")
            
            # Search for relevant information
            search_start = time.time()
            search_results = search(
                query_result['parsed_query'],
                engines=['duckduckgo'],
                max_results=12  # Increased for better coverage
            )
            search_time = time.time() - search_start
            print(f"Search complete in {search_time:.2f}s, found {len(search_results)} results")
            
            # Log search results for visibility
            for i, result in enumerate(search_results[:3]):
                try:
                    print(f"Result {i+1}: {result.title} - {result.url}")
                except AttributeError:
                    print(f"Result {i+1}: Unable to extract title/URL from result format")
            
            # Crawl and synthesize information
            crawl_start = time.time()
            answer = self.crawler.generate_definitive_answer(
                query,
                search_results=search_results
            )
            crawl_time = time.time() - crawl_start
            print(f"Crawling and synthesis complete in {crawl_time:.2f}s")
            
            # Track total research time
            total_research_time = time.time() - research_start
            self.metrics["research_time"] += total_research_time
            self.metrics["total_searches"] += 1
            self.metrics["successful_searches"] += 1
            
            # Format the result with original text and debug info
            result = {
                "original_query": query,
                "parsed_query": query_result['parsed_query'],
                "answer": answer['answer'],
                "sources": answer['sources'],
                "metrics": {
                    "query_process_time": query_process_time,
                    "search_time": search_time,
                    "crawl_time": crawl_time,
                    "total_research_time": total_research_time,
                    "num_sources": len(answer['sources'])
                },
                "success": True
            }
            
            # Format the text result
            text_result = f"\n\n[RESEARCH RESULTS for: '{query}']\n\n"
            text_result += f"{answer['answer']}\n\n"
            text_result += "[SOURCES]:\n"
            
            # Extract source information
            for i, source in enumerate(answer['sources']):
                try:
                    title = source.title
                    url = source.url
                    text_result += f"- [{i+1}] {title}: {url}\n"
                except AttributeError:
                    try:
                        title = source.get('title', source['title'] if 'title' in source else 'No title')
                        url = source.get('url', source['url'] if 'url' in source else 'No URL')
                        text_result += f"- [{i+1}] {title}: {url}\n"
                    except (AttributeError, KeyError, TypeError):
                        text_result += f"- [{i+1}] Source information unavailable\n"
            
            text_result += "\n"
            text_result += f"[DEBUG]: Research completed in {total_research_time:.2f}s, "
            text_result += f"searched {len(search_results)} pages, used {len(answer['sources'])} sources\n"
            
            return text_result, result
            
        except Exception as e:
            # Handle exceptions during research
            error_time = time.time() - research_start
            error_message = f"Error during research: {str(e)}"
            print(error_message)
            
            self.metrics["total_searches"] += 1
            self.metrics["failed_searches"] += 1
            
            # Return error information
            text_result = f"\n\n[RESEARCH ERROR for: '{query}']\n\n"
            text_result += f"An error occurred during research: {str(e)}\n"
            text_result += f"The model will continue generation with available information.\n\n"
            
            result = {
                "original_query": query,
                "error": str(e),
                "error_time": error_time,
                "success": False
            }
            
            return text_result, result
    
    def create_combined_prompt(self, user_prompt, chat_history=None):
        """
        Create a chat-based prompt using the Gemma-3 format
        
        Args:
            user_prompt (str): The user's input
            chat_history (list): List of (user, assistant) message pairs
            
        Returns:
            str: The formatted prompt for the model
        """
        # Create messages list in the format expected by the chat template
        messages = []
        
        # Add system message if provided
        if self.system_prompt:
            messages.append({
                "role": "system",
                "content": [{"type": "text", "text": self.system_prompt}]
            })
        
        # Add chat history if provided
        if chat_history:
            for user_msg, assistant_msg in chat_history:
                messages.append({
                    "role": "user",
                    "content": [{"type": "text", "text": user_msg}]
                })
                if assistant_msg:  # Skip None responses (in case of still generating)
                    messages.append({
                        "role": "assistant",
                        "content": [{"type": "text", "text": assistant_msg}]
                    })
        
        # Add current user message
        messages.append({
            "role": "user",
            "content": [{"type": "text", "text": user_prompt}]
        })
        
        # Apply chat template
        text = self.tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True  # Must add for generation
        )
        
        # Log prompt information
        input_ids = self.tokenizer([text], return_tensors="pt").input_ids
        prompt_tokens = input_ids.shape[1]
        print(f"Created prompt with {prompt_tokens} tokens")
        
        return text
    
    def generate_streaming_with_research(self, prompt, chat_history=None):
        """
        Generate a response with live research integration and detailed process logging

        Generation runs on a background thread while this generator consumes the
        streamer. Each streamed chunk is reported as status lines, and whenever a
        <deep_research>...</deep_research> pair closes, perform_research is run on
        the enclosed query and its text is yielded. The research text is only
        yielded to the caller; it is not passed back to the running generation.

        Args:
            prompt (str): The user input
            chat_history (list): Optional chat history

        Yields:
            str: "[REALTIME DEBUG] ..." and "[TOKEN] '...'" lines for every chunk,
                bracketed status lines for tag/research/answer events, the
                research text blocks, a final "[GENERATION COMPLETE]" JSON summary,
                and, when an <ANSWER> block was found, a
                "[FINAL FORMATTED ANSWER]" block.
        """
        # Create the formatted chat prompt
        text = self.create_combined_prompt(prompt, chat_history)

        # Set up the text streamer. skip_prompt=True keeps the prompt out of the
        # stream; otherwise the system prompt's own <deep_research>/<ANSWER>
        # examples would be parsed below as if the model had written them.
        streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)

        # Set up generation parameters using the exact pattern from the example
        generation_kwargs = {
            **self.tokenizer([text], return_tensors="pt").to("cuda"),
            "max_new_tokens": self.generation_params["max_new_tokens"],
            "do_sample": self.generation_params["do_sample"],
            "temperature": self.generation_params["temperature"],
            "top_p": self.generation_params["top_p"],
            "top_k": self.generation_params["top_k"],
            "streamer": streamer
        }

        # Print generation information
        print(f"Starting generation with params: {json.dumps(self.generation_params, indent=2)}")
        gen_start_time = time.time()

        # Start generation in a separate thread
        generation_thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
        generation_thread.start()

        # Variables for tracking tags
        tag_buffer = ""
        in_tag = False
        research_query = ""
        # Counts streamed chunks; a chunk is whatever text the streamer emits at once.
        token_count = 0

        # Variables for tracking answer tags
        in_answer = False
        final_answer = ""

        # Debug data for UI display
        debug_data = {
            "token_count": 0,
            "token_rate": 0,
            "research_queries": [],
            "research_times": [],
            "has_final_answer": False
        }

        # Stream tokens and look for research tags
        generated_text = ""
        start_time = time.time()
        for new_text in streamer:
            current_time = time.time()
            token_count += 1
            debug_data["token_count"] = token_count
            elapsed = current_time - start_time
            if elapsed > 0:
                debug_data["token_rate"] = token_count / elapsed

            # Format debug info to show in UI
            debug_str = f"[REALTIME DEBUG] Tokens: {token_count} | Rate: {debug_data['token_rate']:.1f} t/s"
            if debug_data["research_queries"]:
                debug_str += f" | Searches: {len(debug_data['research_queries'])}"
            if debug_data["has_final_answer"]:
                debug_str += " | Final Answer: Found"

            # Include the debug info in what we yield
            yield debug_str + "\n"

            generated_text += new_text
            tag_buffer += new_text

            # Format token information for UI
            token_info = f"[TOKEN] '{new_text}'"
            yield token_info + "\n"

            # Check for opening research tag
            if not in_tag and "<deep_research>" in tag_buffer:
                in_tag = True
                tag_start_index = tag_buffer.find("<deep_research>")
                yield "[TAG DETECTED] Opening research tag\n"
                # Reset the buffer to just track what's inside the tag
                tag_buffer = tag_buffer[tag_start_index + len("<deep_research>"):]

            # Check for closing research tag when we're inside a tag
            elif in_tag and "</deep_research>" in tag_buffer:
                in_tag = False
                research_query = tag_buffer[:tag_buffer.find("</deep_research>")]

                # Yield the query and that we're researching
                yield f"[RESEARCH STARTED] Query: '{research_query}'\n"

                # Perform the research
                query_start_time = time.time()
                text_results, research_results = self.perform_research(research_query)
                query_duration = time.time() - query_start_time

                # Update debug data
                debug_data["research_queries"].append(research_query)
                debug_data["research_times"].append(query_duration)

                # Yield research status
                if research_results.get("success", False):
                    yield f"[RESEARCH COMPLETED] in {query_duration:.2f}s\n"
                else:
                    yield f"[RESEARCH FAILED] in {query_duration:.2f}s: {research_results.get('error', 'Unknown error')}\n"

                # Yield research data for UI
                try:
                    research_meta = {
                        "query": research_query,
                        "duration": query_duration,
                        "success": research_results.get("success", False)
                    }

                    if research_results.get("success", False):
                        research_meta["num_sources"] = len(research_results["sources"])
                        research_meta["sources"] = []
                        for s in research_results["sources"][:3]:
                            try:
                                research_meta["sources"].append({"title": s.title, "url": s.url})
                            except AttributeError:
                                try:
                                    research_meta["sources"].append({
                                        "title": s.get("title", "No title"),
                                        "url": s.get("url", "No URL")
                                    })
                                except (AttributeError, KeyError, TypeError):
                                    research_meta["sources"].append({"title": "Source info unavailable", "url": ""})

                    yield f"[RESEARCH DETAILS] {json.dumps(research_meta, indent=2)}\n"
                except Exception as e:
                    # Best-effort metadata: a non-serializable source must not stop the stream.
                    yield f"[RESEARCH DETAILS ERROR] Failed to generate research details: {str(e)}\n"

                # Yield the actual research results
                yield text_results

                # Reset the buffer
                tag_buffer = ""

            # Check for opening answer tag (only the tail is scanned, so the
            # same tag is not matched again on later chunks)
            elif not in_answer and "<ANSWER>" in generated_text[-20:]:
                in_answer = True
                debug_data["has_final_answer"] = True
                yield "[FINAL ANSWER STARTED]\n"

            # Check for closing answer tag
            elif in_answer and "</ANSWER>" in generated_text[-20:]:
                in_answer = False
                yield "[FINAL ANSWER COMPLETED]\n"

                # Extract the final answer
                answer_pattern = r"<ANSWER>(.*?)</ANSWER>"
                answer_matches = re.findall(answer_pattern, generated_text, re.DOTALL)
                if answer_matches:
                    final_answer = answer_matches[-1].strip()
                    yield f"[EXTRACTED ANSWER] Length: {len(final_answer)} characters\n"

            # Normal streaming when not at a tag boundary
            elif not in_tag:
                tag_buffer = tag_buffer[-30:]  # Keep a sliding window for tag detection

        # The streamer is exhausted, so generation has finished; reap the worker
        # thread instead of leaving it dangling.
        generation_thread.join()

        # Generation complete - yield final stats
        total_gen_time = time.time() - gen_start_time
        self.metrics["total_tokens_generated"] += token_count

        # Extract the final answer if it exists but wasn't caught during streaming
        if not final_answer:
            answer_pattern = r"<ANSWER>(.*?)</ANSWER>"
            answer_matches = re.findall(answer_pattern, generated_text, re.DOTALL)
            if answer_matches:
                final_answer = answer_matches[-1].strip()

        final_stats = {
            "total_tokens": token_count,
            "generation_time": total_gen_time, 
            "tokens_per_second": token_count / total_gen_time if total_gen_time > 0 else 0,
            "research_queries": len(debug_data["research_queries"]),
            "total_research_time": sum(debug_data["research_times"]),
            "average_research_time": sum(debug_data["research_times"]) / len(debug_data["research_times"]) if debug_data["research_times"] else 0,
            "has_final_answer": bool(final_answer)
        }

        yield f"\n\n[GENERATION COMPLETE]\n{json.dumps(final_stats, indent=2)}\n"

        # If there's a final answer, yield it separately for easy access
        if final_answer:
            yield f"\n\n[FINAL FORMATTED ANSWER]\n{final_answer}\n"
    
    def get_metrics(self):
        """Return the current pipeline metrics as a formatted string"""
        return json.dumps(self.metrics, indent=2)
    
    def chat(self, user_prompt, chat_history=None):
        """
        Convenience method for chat-style usage
        
        Args:
            user_prompt (str): The user's message
            chat_history (list): Optional list of (user, assistant) message pairs
            
        Returns:
            str: The complete response with research integrated
        """
        response = ""
        for token in self.generate_streaming_with_research(user_prompt, chat_history):
            response += token
        return response


# Improved Gradio interface with two panels:
# 1. Chat interface for normal interaction
# 2. Debug panel showing real-time information
def create_gradio_interface():
    """
    Build the two-panel Gradio demo: a chat panel and a real-time debug panel.

    Constructs a StreamingResearchPipeline (which loads the model), then wires a
    textbox, chatbot, clear button and debug textbox to it.

    Returns:
        gr.Blocks: The assembled (not yet launched) interface
    """
    # Initialize the pipeline (named so it does not shadow transformers.pipeline)
    research_pipeline = StreamingResearchPipeline(
        model_name="unsloth/gemma-3-4b-it",
        cache_dir="./cache",
        system_prompt_file=None
    )

    # Store debug information
    debug_history = []

    # Store final answers for each query
    final_answers = {}

    # Shape of the per-chunk lines emitted by generate_streaming_with_research
    token_prefix = "[TOKEN] '"
    token_suffix = "'\n"

    def user(message, history):
        # Return the user's message to display it
        debug_history.clear()  # Reset debug for new query
        return message, history + [[message, None]]

    def bot(history):
        # Extract the user's message
        user_message = history[-1][0]

        # Convert previous history to the format expected by the pipeline
        prev_history = history[:-1] if len(history) > 1 else None

        # For storing the actual response vs debug info
        actual_response = ""
        current_debug = ""
        final_answer = ""

        # Generate the response with research integration
        for token in research_pipeline.generate_streaming_with_research(user_message, prev_history):
            # Check for final formatted answer
            if token.startswith("\n\n[FINAL FORMATTED ANSWER]"):
                final_answer = token.split("\n", 3)[3].strip()
                # Don't add this to the debug or response
                continue

            if token.startswith(token_prefix) and token.endswith(token_suffix):
                # Model text only ever arrives wrapped in a [TOKEN] line: unwrap it
                # for the chat panel and keep the raw line for the debug panel.
                actual_response += token[len(token_prefix):-len(token_suffix)]
                debug_history.append(token)
            elif token.startswith("[") or token.startswith("\n\n[GENERATION COMPLETE]"):
                # Status lines and the closing statistics belong in the debug panel
                debug_history.append(token)
            else:
                # Research result blocks (they begin with blank lines) go to the response
                actual_response += token

            current_debug = "\n".join(debug_history[-20:])  # Keep last 20 debug lines

            # Update both panels
            history[-1][1] = actual_response
            yield history, current_debug

        # Store the final answer for this query
        if final_answer:
            final_answers[user_message] = final_answer

            # If we found a final answer, replace the response with it
            # This ensures the user sees a clean answer without all the debug info
            clean_response = f"I researched your question thoroughly.\n\n{final_answer}"
            history[-1][1] = clean_response

        # Final stats - add pipeline metrics
        debug_history.append("\n[PIPELINE METRICS]")
        debug_history.append(research_pipeline.get_metrics())
        current_debug = "\n".join(debug_history[-20:])

        # Final yield to ensure everything is displayed
        yield history, current_debug

    with gr.Blocks(css="footer {visibility: hidden}") as demo:
        gr.Markdown("# 🔍 Gemma3 with Deep Research Integration")
        gr.Markdown("""This enhanced demo shows the real-time inner workings of Unsloth's optimized Gemma3 model with research capabilities.
        
        The model streams every token as it's generated and shows you exactly when it searches for information using `<deep_research>` tags.
        It then provides the final researched answer between `<ANSWER>` tags.
        
        Try asking about current events, recent news, or specific factual information!
        
        **The right panel shows you real-time debug information including:**
        - Each token as it's generated
        - When research is triggered and completed
        - Source information and metrics
        - Performance statistics
        """)

        with gr.Row():
            with gr.Column(scale=6):
                chatbot = gr.Chatbot([], elem_id="chatbot", height=600)
                with gr.Row():
                    msg = gr.Textbox(placeholder="Ask me anything...", show_label=False)
                    clear = gr.Button("Clear Chat")

            with gr.Column(scale=4):
                debug_panel = gr.Textbox(
                    value="Debug information will appear here during generation...",
                    label="Real-time Debug Information",
                    lines=30,
                    max_lines=30
                )

        msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
            bot, [chatbot], [chatbot, debug_panel]
        )
        # Two output components need two return values: empty chat, empty debug text
        clear.click(lambda: ([], ""), None, [chatbot, debug_panel], queue=False)

        # <ANSWER> is backticked so Markdown shows it instead of treating it as HTML
        gr.Markdown("""
        ### System Information
        - **Model**: Unsloth's optimized Gemma-3-4b with custom research capabilities
        - **Hardware**: Running on CUDA for acceleration
        - **Research**: Uses DuckDuckGo search and fetches/summarizes web content
        - **Format**: Final answers are provided within `<ANSWER>` tags
        """)

    return demo

In [None]:
# Run the interface
# Building the interface constructs the pipeline, so the model is loaded here.
# NOTE(review): share=True presumably requests a publicly reachable Gradio
# link — confirm that exposing this demo outside the machine is intended.
if __name__ == "__main__":
    demo = create_gradio_interface()
    demo.launch(share=True, debug=True)