# HM-RAG: Hierarchical Multi-Agent Multimodal RAG (Google Colab Setup)

**What this notebook does:**
1. Clones the repo
2. Installs all dependencies (from `requirements.txt`)
3. Patches source files to fix the **embedding dimension mismatch** (768 vs 1024)
4. Installs and starts Ollama, then pulls `qwen2.5:1.5b` and `nomic-embed-text`
5. Configures API keys (SerpAPI for web search, optional HF token)
6. Downloads the ScienceQA dataset
7. Verifies the Ollama server, source files, and imports
8. Runs inference

**Model used:** `qwen2.5:1.5b` (text) + `Qwen/Qwen2.5-VL-2B-Instruct` (vision, only if images present)
**Web search:** SerpAPI only (requires a free key from https://serpapi.com)

## Step 1: Clone the Repository

In [None]:
import os

# Clone only if not already cloned
if not os.path.exists('/content/HMRAG'):
    !git clone https://github.com/ab-2109/HMRAG.git /content/HMRAG
    print("✓ Repository cloned")
else:
    print("✓ Repository already exists")

%cd /content/HMRAG
print(f"Working directory: {os.getcwd()}")

## Step 2: Install Dependencies
Installs all packages from `requirements.txt`. This handles SerpAPI (`google-search-results`), LightRAG, LangChain, transformers, etc.
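
After the install cell finishes, it can help to confirm which versions actually landed, since pip warnings are easy to miss. The following is a minimal sketch using only the standard library; the helper name `installed_versions` is hypothetical and not part of the repo, and the distribution names are the ones mentioned in this notebook.

```python
from importlib import metadata


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # Not installed in this environment
            versions[name] = None
    return versions


if __name__ == "__main__":
    wanted = ["numpy", "lightrag-hku", "langchain-ollama", "google-search-results"]
    for name, version in installed_versions(wanted).items():
        print(f"{name}: {version or 'NOT INSTALLED'}")
```

A `None` entry suggests rerunning the install cell before moving on.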

In [None]:
import os
os.chdir('/content/HMRAG')

print("Installing dependencies from requirements.txt (this may take a few minutes)...")
print("=" * 60)

# Install numpy first to avoid conflicts
!pip install -q numpy==1.26.4

# Install from the repo's requirements.txt
!pip install -q -r requirements.txt

# Ensure critical packages are installed (in case requirements.txt missed any)
!pip install -q google-search-results langchain-ollama huggingface_hub

print("\n" + "=" * 60)
print("✓ All dependencies installed successfully!")
print("Note: Some dependency warnings are normal and won't affect functionality.")

## Step 3: Patch Source Files (Critical Fixes)
This cell patches the cloned source files to:
1. **Fix embedding dimension mismatch**: `nomic-embed-text` outputs 768 dims, but LightRAG's `ollama_embed` decorator defaults to 1024. We use `ollama_embed.func` (unwrapped) and set `embedding_dim=768`.
2. **Use `qwen2.5:1.5b`** everywhere instead of `qwen2.5:7b` (fits in Colab GPU memory).
3. **Reduce `num_ctx` to 4096** (a 65536-token context is more than the 1.5B model can handle on Colab).
4. **Use `Qwen2.5-VL-2B-Instruct`** for vision instead of the 7B variant.
5. **Add HF token support** for downloading gated models.
6. **Fix device handling** in the vision model to avoid dimension mismatches on Colab.
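
The dimension mismatch in item 1 can be checked directly once Ollama is running (Step 4). The sketch below is an illustration, not part of the repo: `embedding_dim` and `check_dim` are hypothetical helper names, and the request shape assumes Ollama's `/api/embeddings` endpoint, which takes a `model` and a `prompt` and returns an `embedding` list.

```python
import json
import urllib.request


def embedding_dim(model, host="http://localhost:11434"):
    """Ask a running Ollama server for one embedding and return its length."""
    payload = json.dumps({"model": model, "prompt": "dimension probe"}).encode("utf-8")
    request = urllib.request.Request(
        f"{host}/api/embeddings",
        data=payload,
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(request, timeout=30) as response:
        return len(json.load(response)["embedding"])


def check_dim(actual, expected=768):
    """Fail early instead of letting the vector store fail later."""
    if actual != expected:
        raise ValueError(
            f"embedding_dim mismatch: model returns {actual}, config expects {expected}"
        )
    return actual
```

With the server up, `check_dim(embedding_dim("nomic-embed-text"))` should pass if the model really returns 768-dimensional vectors, which is the value the patches write into `EmbeddingFunc`.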

In [None]:
import os
os.chdir('/content/HMRAG')

# =============================================================================
# PATCH 1: retrieval/vector_retrieval.py
# =============================================================================
with open('retrieval/vector_retrieval.py', 'w') as f:
    f.write(
        "from lightrag import LightRAG, QueryParam\n"
        "from lightrag.llm.ollama import ollama_model_complete, ollama_embed\n"
        "from lightrag.utils import EmbeddingFunc\n"
        "\n"
        "from retrieval.base_retrieval import BaseRetrieval\n"
        "\n"
        "\n"
        "class VectorRetrieval(BaseRetrieval):\n"
        "    def __init__(self, config):\n"
        "        self.config = config\n"
        "        self.mode = getattr(config, 'mode', 'naive')\n"
        "        self.top_k = getattr(config, 'top_k', 4)\n"
        "        ollama_host = getattr(config, 'ollama_base_url', 'http://localhost:11434')\n"
        "        model_name = getattr(config, 'llm_model_name', 'qwen2.5:1.5b')\n"
        "        working_dir = getattr(config, 'working_dir', './lightrag_workdir')\n"
        "\n"
        "        self.client = LightRAG(\n"
        "            working_dir=working_dir,\n"
        "            llm_model_func=ollama_model_complete,\n"
        "            llm_model_name=model_name,\n"
        "            llm_model_max_async=4,\n"
        '            llm_model_kwargs={"host": ollama_host, "options": {"num_ctx": 4096}},\n'
        "            embedding_func=EmbeddingFunc(\n"
        "                embedding_dim=768,\n"
        "                max_token_size=8192,\n"
        "                func=lambda texts: ollama_embed.func(\n"
        '                    texts, embed_model="nomic-embed-text", host=ollama_host\n'
        "                ),\n"
        "            ),\n"
        "        )\n"
        "        self.results = []\n"
        "\n"
        "    def find_top_k(self, query):\n"
        "        try:\n"
        "            self.results = self.client.query(\n"
        "                query,\n"
        '                param=QueryParam(mode="naive", top_k=self.top_k)\n'
        "            )\n"
        "        except Exception as e:\n"
        '            print(f"VectorRetrieval error: {e}")\n'
        '            self.results = f"Vector retrieval failed: {str(e)}"\n'
        "        return self.results\n"
    )
print("✓ Patched retrieval/vector_retrieval.py")

# =============================================================================
# PATCH 2: retrieval/graph_retrieval.py
# =============================================================================
with open('retrieval/graph_retrieval.py', 'w') as f:
    f.write(
        "from lightrag import LightRAG, QueryParam\n"
        "from lightrag.llm.ollama import ollama_model_complete, ollama_embed\n"
        "from lightrag.utils import EmbeddingFunc\n"
        "\n"
        "from retrieval.base_retrieval import BaseRetrieval\n"
        "\n"
        "\n"
        "class GraphRetrieval(BaseRetrieval):\n"
        "    def __init__(self, config):\n"
        "        self.config = config\n"
        "        self.mode = getattr(config, 'mode', 'mix')\n"
        "        self.top_k = getattr(config, 'top_k', 4)\n"
        "        ollama_host = getattr(config, 'ollama_base_url', 'http://localhost:11434')\n"
        "        model_name = getattr(config, 'llm_model_name', 'qwen2.5:1.5b')\n"
        "        working_dir = getattr(config, 'working_dir', './lightrag_workdir')\n"
        "\n"
        "        self.client = LightRAG(\n"
        "            working_dir=working_dir,\n"
        "            llm_model_func=ollama_model_complete,\n"
        "            llm_model_name=model_name,\n"
        "            llm_model_max_async=4,\n"
        '            llm_model_kwargs={"host": ollama_host, "options": {"num_ctx": 4096}},\n'
        "            embedding_func=EmbeddingFunc(\n"
        "                embedding_dim=768,\n"
        "                max_token_size=8192,\n"
        "                func=lambda texts: ollama_embed.func(\n"
        '                    texts, embed_model="nomic-embed-text", host=ollama_host\n'
        "                ),\n"
        "            ),\n"
        "        )\n"
        "        self.results = []\n"
        "\n"
        "    def find_top_k(self, query):\n"
        "        try:\n"
        "            self.results = self.client.query(\n"
        "                query,\n"
        "                param=QueryParam(mode=self.mode, top_k=self.top_k)\n"
        "            )\n"
        "        except Exception as e:\n"
        '            print(f"GraphRetrieval error: {e}")\n'
        '            self.results = f"Graph retrieval failed: {str(e)}"\n'
        "        return self.results\n"
    )
print("✓ Patched retrieval/graph_retrieval.py")

# =============================================================================
# PATCH 3: retrieval/web_retrieval.py (SerpAPI only, qwen2.5:1.5b)
# =============================================================================
with open('retrieval/web_retrieval.py', 'w') as f:
    f.write(
        "from langchain_community.utilities import SerpAPIWrapper\n"
        "from langchain_ollama import OllamaLLM\n"
        "\n"
        "from retrieval.base_retrieval import BaseRetrieval\n"
        "\n"
        "\n"
        "class WebRetrieval(BaseRetrieval):\n"
        "    def __init__(self, config):\n"
        "        self.config = config\n"
        '        self.search_engine = "Google"\n'
        "\n"
        "        serpapi_api_key = getattr(config, 'serpapi_api_key', '')\n"
        "        self.top_k = getattr(config, 'top_k', 4)\n"
        "        ollama_base_url = getattr(config, 'ollama_base_url', 'http://localhost:11434')\n"
        "        web_llm_model = getattr(config, 'web_llm_model_name', 'qwen2.5:1.5b')\n"
        "\n"
        "        self.client = SerpAPIWrapper(\n"
        "            serpapi_api_key=serpapi_api_key\n"
        "        )\n"
        "\n"
        "        self.llm = OllamaLLM(\n"
        "            base_url=ollama_base_url,\n"
        "            model=web_llm_model,\n"
        "            temperature=0.35,\n"
        "        )\n"
        "        self.results = []\n"
        "\n"
        "    def format_results(self, results):\n"
        '        """Format the search results into readable text."""\n'
        "        max_results = self.top_k\n"
        "        processed = []\n"
        "\n"
        "        if isinstance(results, dict):\n"
        "            if 'answerBox' in results:\n"
        "                answer = results['answerBox']\n"
        "                processed.append(\n"
        "                    f\"Direct answer: {answer.get('answer', '')}\\n\"\n"
        "                    f\"Source: {answer.get('link', '')}\\n\"\n"
        "                )\n"
        "\n"
        "            if 'organic' in results:\n"
        "                for item in results['organic'][:max_results]:\n"
        "                    processed.append(\n"
        "                        f\"[{item.get('title', 'No title')}]\\n\"\n"
        "                        f\"{item.get('snippet', 'No snippet')}\\n\"\n"
        "                        f\"Link: {item.get('link', '')}\\n\"\n"
        "                    )\n"
        "\n"
        '        return "\\n".join(processed) if processed else "No relevant results found"\n'
        "\n"
        "    def generation(self, results_with_query):\n"
        '        """Use Ollama model to generate an answer from search results."""\n'
        "        try:\n"
        "            answer = self.llm.invoke(results_with_query)\n"
        "        except Exception as e:\n"
        '            print(f"WebRetrieval generation error: {e}")\n'
        '            answer = f"Web generation failed: {str(e)}"\n'
        "        return answer\n"
        "\n"
        "    def find_top_k(self, query):\n"
        "        try:\n"
        "            raw_results = self.client.results(query)\n"
        "            formatted_results = self.format_results(raw_results)\n"
        '            self.results = self.generation(formatted_results + "\\n" + query)\n'
        "        except Exception as e:\n"
        '            print(f"WebRetrieval error: {e}")\n'
        '            self.results = f"Web retrieval failed: {str(e)}"\n'
        "        return self.results\n"
    )
print("✓ Patched retrieval/web_retrieval.py")

# =============================================================================
# PATCH 4: agents/decompose_agent.py (qwen2.5:1.5b)
# =============================================================================
with open('agents/decompose_agent.py', 'w') as f:
    f.write(
        "import os\n"
        "import re\n"
        "from typing import List\n"
        "from langchain_core.prompts import PromptTemplate\n"
        "from langchain_ollama import OllamaLLM\n"
        "\n"
        "\n"
        "class DecomposeAgent:\n"
        "    def __init__(self, config):\n"
        "        self.config = config\n"
        "        self.llm = OllamaLLM(\n"
        "            base_url=getattr(config, 'ollama_base_url', 'http://localhost:11434'),\n"
        "            model=getattr(config, 'llm_model_name', 'qwen2.5:1.5b'),\n"
        "            temperature=getattr(config, 'temperature', 0.35),\n"
        "        )\n"
        "\n"
        "    def count_intents(self, query: str) -> int:\n"
        "        prompt = PromptTemplate.from_template(\n"
        '            "Please calculate how many independent intents are contained in the following query. "\n'
        '            "Return only an integer:\\n{query}\\nNumber of intents: "\n'
        "        )\n"
        "        max_attempts = 3\n"
        "        for attempt in range(max_attempts):\n"
        "            formatted_prompt = prompt.format(query=query)\n"
        "            response = self.llm.invoke(formatted_prompt)\n"
        "            try:\n"
        "                numbers = re.findall(r'\\d+', response.strip())\n"
        "                if numbers:\n"
        "                    return int(numbers[0])\n"
        "            except (ValueError, IndexError):\n"
        "                pass\n"
        "            if attempt == max_attempts - 1:\n"
        "                return 1\n"
        "        return 1\n"
        "\n"
        "    def decompose(self, query: str) -> List[str]:\n"
        "        intent_count = self.count_intents(query)\n"
        "        intent_count = min(intent_count, 3)\n"
        "        if intent_count > 1:\n"
        "            return self._split_query(query)\n"
        "        return [query]\n"
        "\n"
        "    def _split_query(self, query: str) -> List[str]:\n"
        "        prompt = PromptTemplate.from_template(\n"
        '            "Split the following query into multiple independent sub-queries, "\n'
        "            \"separated by '||', without additional explanations:\\n{query}\\nList of sub-queries: \"\n"
        "        )\n"
        "        formatted_prompt = prompt.format(query=query)\n"
        "        response = self.llm.invoke(formatted_prompt)\n"
        '        sub_queries = [q.strip() for q in response.split("||") if q.strip()]\n'
        "        if not sub_queries:\n"
        "            return [query]\n"
        "        return sub_queries\n"
    )
print("✓ Patched agents/decompose_agent.py")

# =============================================================================
# PATCH 5: agents/summary_agent.py
# =============================================================================
with open('agents/summary_agent.py', 'w') as f:
    f.write(
        "from collections import Counter\n"
        "from langchain_ollama import OllamaLLM\n"
        "import re\n"
        "from transformers import AutoProcessor\n"
        "import random\n"
        "import os\n"
        "import torch\n"
        "\n"
        "from prompts.base_prompt import build_prompt\n"
        "\n"
        "\n"
        "class SummaryAgent:\n"
        "    def __init__(self, config):\n"
        "        self.config = config\n"
        "        self.text_llm = OllamaLLM(\n"
        "            base_url=getattr(config, 'ollama_base_url', 'http://localhost:11434'),\n"
        "            model=getattr(config, 'llm_model_name', 'qwen2.5:1.5b')\n"
        "        )\n"
        "        self.hf_token = getattr(config, 'hf_token', '') or os.environ.get('HF_TOKEN', '')\n"
        "        self._vision_model = None\n"
        "        self._processor = None\n"
        "\n"
        "    def _load_vision_model(self):\n"
        "        if self._vision_model is None:\n"
        "            try:\n"
        "                from transformers import Qwen2_5_VLForConditionalGeneration\n"
        "\n"
        '                model_name = "Qwen/Qwen2.5-VL-2B-Instruct"\n'
        "\n"
        "                token_kwargs = {}\n"
        "                if self.hf_token:\n"
        "                    token_kwargs['token'] = self.hf_token\n"
        "\n"
        "                self._processor = AutoProcessor.from_pretrained(\n"
        "                    model_name, use_fast=True, **token_kwargs\n"
        "                )\n"
        "                self._vision_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(\n"
        "                    model_name,\n"
        "                    torch_dtype=torch.float16,\n"
        '                    device_map="auto",\n'
        "                    **token_kwargs\n"
        "                )\n"
        "            except Exception as e:\n"
        '                print(f"Warning: Could not load vision model: {e}")\n'
        "                self._vision_model = None\n"
        "                self._processor = None\n"
        "\n"
        "    def summarize(self, problems, shot_qids, qid, cur_ans):\n"
        "        problem = problems[qid]\n"
        "        question = problem['question']\n"
        '        choices = problem["choices"]\n'
        "        answer = problem['answer']\n"
        "        image = problem.get('image', '')\n"
        "        caption = problem.get('caption', '')\n"
        '        split = problem.get("split", "test")\n'
        "\n"
        "        most_ans = self.get_most_common_answer(cur_ans)\n"
        "\n"
        "        if len(most_ans) == 1:\n"
        "            prediction = self.get_result(most_ans[0])\n"
        "            pred_idx = self.get_pred_idx(prediction, choices, self.config.options)\n"
        "        else:\n"
        '            if image and image == "image.png":\n'
        "                image_path = os.path.join(self.config.image_root, split, qid, image)\n"
        "            else:\n"
        '                image_path = ""\n'
        "\n"
        '            output_text = cur_ans[0] if len(cur_ans) > 0 else ""\n'
        '            output_graph = cur_ans[1] if len(cur_ans) > 1 else ""\n'
        '            output_web = cur_ans[2] if len(cur_ans) > 2 else ""\n'
        "\n"
        "            output = self.refine(output_text, output_graph, output_web,\n"
        "                                 problems, shot_qids, qid, self.config, image_path)\n"
        "            if output is None:\n"
        '                output = "FAILED"\n'
        '            print(f"output: {output}")\n'
        "\n"
        "            ans_fusion = self.get_result(output)\n"
        "            pred_idx = self.get_pred_idx(ans_fusion, choices, self.config.options)\n"
        "        return pred_idx, cur_ans\n"
        "\n"
        "    def get_most_common_answer(self, res):\n"
        "        if not res:\n"
        "            return []\n"
        "        counter = Counter(res)\n"
        "        max_count = max(counter.values())\n"
        "        most_common_values = [item for item, count in counter.items() if count == max_count]\n"
        "        return most_common_values\n"
        "\n"
        "    def refine(self, output_text, output_graph, output_web, problems, shot_qids, qid, args, image_path):\n"
        "        prompt = build_prompt(problems, shot_qids, qid, args)\n"
        '        prompt = f"{prompt} The answer is A, B, C, D, E or FAILED. \\n BECAUSE: "\n'
        "\n"
        "        if not image_path:\n"
        "            output = self.text_llm.invoke(prompt)\n"
        "        else:\n"
        "            output = self.qwen_reasoning(prompt, image_path)\n"
        "            if output:\n"
        '                print(f"**** output: {output}")\n'
        "                output = self.text_llm.invoke(\n"
        '                    f"{output[0]} Summary the above information with format "\n'
        "                    f\"'Answer: The answer is A, B, C, D, E or FAILED.\\n BECAUSE: '\"\n"
        "                )\n"
        "            else:\n"
        "                output = self.text_llm.invoke(prompt)\n"
        "        return output\n"
        "\n"
        "    def get_result(self, output):\n"
        "        pattern = re.compile(r'The answer is ([A-E])')\n"
        "        res = pattern.findall(output)\n"
        "        if len(res) == 1:\n"
        "            answer = res[0]\n"
        "        else:\n"
        '            answer = "FAILED"\n'
        "        return answer\n"
        "\n"
        "    def get_pred_idx(self, prediction, choices, options):\n"
        "        if prediction in options[:len(choices)]:\n"
        "            return options.index(prediction)\n"
        "        else:\n"
        "            return random.choice(range(len(choices)))\n"
        "\n"
        "    def qwen_reasoning(self, prompt, image_path):\n"
        "        self._load_vision_model()\n"
        "        if self._vision_model is None or self._processor is None:\n"
        '            print("Warning: Vision model not available, falling back to text-only.")\n'
        "            return None\n"
        "\n"
        "        try:\n"
        "            from qwen_vl_utils import process_vision_info\n"
        "        except ImportError:\n"
        '            print("Warning: qwen_vl_utils not installed, falling back to text-only.")\n'
        "            return None\n"
        "\n"
        "        messages = [\n"
        "            {\n"
        '                "role": "user",\n'
        '                "content": [\n'
        "                    {\n"
        '                        "type": "image",\n'
        '                        "image": image_path,\n'
        "                    },\n"
        '                    {"type": "text", "text": prompt},\n'
        "                ],\n"
        "            }\n"
        "        ]\n"
        "\n"
        "        text = self._processor.apply_chat_template(\n"
        "            messages, tokenize=False, add_generation_prompt=True\n"
        "        )\n"
        "        image_inputs, video_inputs = process_vision_info(messages)\n"
        "        inputs = self._processor(\n"
        "            text=[text],\n"
        "            images=image_inputs,\n"
        "            videos=video_inputs,\n"
        "            padding=True,\n"
        '            return_tensors="pt",\n'
        "        )\n"
        "\n"
        "        device = next(self._vision_model.parameters()).device\n"
        "        inputs = {k: v.to(device) if hasattr(v, 'to') else v for k, v in inputs.items()}\n"
        "\n"
        "        generated_ids = self._vision_model.generate(**inputs, max_new_tokens=512)\n"
        "        generated_ids_trimmed = [\n"
        "            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs['input_ids'], generated_ids)\n"
        "        ]\n"
        "        output_text = self._processor.batch_decode(\n"
        "            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False\n"
        "        )\n"
        "        return output_text\n"
    )
print("✓ Patched agents/summary_agent.py")

# =============================================================================
# PATCH 6: main.py (qwen2.5:1.5b default, add --hf_token argument)
# =============================================================================
with open('main.py', 'r') as f:
    main_content = f.read()

changes_made = []

if "default='qwen2.5:7b'" in main_content:
    main_content = main_content.replace("default='qwen2.5:7b'", "default='qwen2.5:1.5b'")
    changes_made.append("qwen2.5:7b -> qwen2.5:1.5b")

if '--hf_token' not in main_content:
    main_content = main_content.replace(
        "parser.add_argument('--top_k', type=int, default=4)\n",
        "parser.add_argument('--top_k', type=int, default=4)\n"
        "    parser.add_argument('--hf_token', type=str, default='',\n"
        "                        help='Hugging Face access token for downloading gated models')\n"
    )
    changes_made.append("Added --hf_token argument")

if 'HF_TOKEN' not in main_content:
    main_content = main_content.replace(
        "    agent = MRetrievalAgent(args)",
        "    # Set HF token if provided\n"
        "    if args.hf_token:\n"
        "        import os as _os\n"
        "        _os.environ['HF_TOKEN'] = args.hf_token\n"
        "        _os.environ['HUGGING_FACE_HUB_TOKEN'] = args.hf_token\n"
        "        try:\n"
        "            from huggingface_hub import login\n"
        "            login(token=args.hf_token)\n"
        "            print('Logged in to Hugging Face Hub')\n"
        "        except Exception as e:\n"
        "            print(f'Warning: Could not login to HF Hub: {e}')\n"
        "\n"
        "    agent = MRetrievalAgent(args)"
    )
    changes_made.append("Added HF login logic")

with open('main.py', 'w') as f:
    f.write(main_content)

if changes_made:
    print("✓ Patched main.py: " + ", ".join(changes_made))
else:
    print("✓ main.py already up to date")

# =============================================================================
# PATCH 7: YAML config files
# =============================================================================
for yaml_file in ['configs/decompose_agent.yaml', 'configs/multi_retrieval_agents.yaml']:
    if os.path.exists(yaml_file):
        with open(yaml_file, 'r') as f:
            content = f.read()
        if 'qwen2.5:7b' in content:
            content = content.replace('qwen2.5:7b', 'qwen2.5:1.5b')
            with open(yaml_file, 'w') as f:
                f.write(content)
            print(f"✓ Patched {yaml_file}: qwen2.5:7b -> qwen2.5:1.5b")
        else:
            print(f"✓ {yaml_file} already correct")

# Clean up stale working directory
!rm -rf ./lightrag_workdir

print("\n" + "=" * 60)
print("✓ All patches applied!")
print("  - embedding_dim = 768 (matches nomic-embed-text)")
print("  - ollama_embed.func (bypasses 1024-dim decorator)")
print("  - num_ctx = 4096 (fits qwen2.5:1.5b)")
print("  - Text model: qwen2.5:1.5b")
print("  - Vision model: Qwen/Qwen2.5-VL-2B-Instruct")
print("  - Web search: SerpAPI only")
print("  - HF token support added")

## Step 4: Install and Start Ollama + Pull Models
Ollama runs locally on the Colab VM. We pull `qwen2.5:1.5b` (~1GB) and `nomic-embed-text` (~270MB).
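
Server startup time varies between Colab sessions, so a fixed sleep can be either too short or wastefully long. One alternative is to poll the same `/api/tags` endpoint this notebook checks until it answers. This is a minimal sketch with hypothetical helper names (`ollama_is_up`, `wait_until`); it is not what the cell below does.

```python
import time
import urllib.request


def ollama_is_up(url="http://localhost:11434/api/tags"):
    """Return True if the Ollama HTTP API answers with status 200."""
    try:
        with urllib.request.urlopen(url, timeout=2) as response:
            return response.status == 200
    except OSError:
        # urllib.error.URLError and socket timeouts are OSError subclasses
        return False


def wait_until(probe, timeout=60.0, interval=1.0):
    """Call probe() repeatedly until it returns True or the timeout elapses."""
    deadline = time.monotonic() + timeout
    while True:
        if probe():
            return True
        if time.monotonic() >= deadline:
            return False
        time.sleep(interval)
```

After launching `ollama serve`, `wait_until(ollama_is_up)` returns as soon as the server responds and gives a clear `False` if it never does.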

In [None]:
import subprocess
import time
import os

# Install system dependencies
print("Installing system dependencies...")
!sudo apt-get update -qq 2>/dev/null
!sudo apt-get install -y -qq zstd 2>/dev/null

# Install Ollama
print("Installing Ollama...")
!curl -fsSL https://ollama.com/install.sh | sh 2>&1 | tail -3

# Find ollama binary
result = subprocess.run(['which', 'ollama'], capture_output=True, text=True)
ollama_path = result.stdout.strip()
if not ollama_path:
    for path in ['/usr/local/bin/ollama', '/usr/bin/ollama']:
        if os.path.exists(path):
            ollama_path = path
            break

if not ollama_path:
    print("❌ Ollama binary not found! Please restart runtime and try again.")
else:
    print(f"✓ Ollama found at: {ollama_path}")
    
    # Kill any existing ollama processes
    subprocess.run(['pkill', '-f', 'ollama'], stderr=subprocess.DEVNULL)
    time.sleep(2)
    
    # Start Ollama server in background
    print("Starting Ollama server...")
    # Discard server logs: an unread PIPE can fill up and block the server
    ollama_process = subprocess.Popen(
        [ollama_path, 'serve'],
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL
    )
    time.sleep(8)
    
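    # Verify server is running (curl exits non-zero if the connection fails;
    # -m 5 caps each attempt at 5 seconds)
    def server_ready():
        check = subprocess.run(['curl', '-s', '-m', '5', 'http://localhost:11434/api/tags'],
                               capture_output=True, text=True)
        return check.returncode == 0

    if server_ready():
        print("✓ Ollama server is running!")
    else:
        print("⚠️ Ollama server may not be ready yet. Waiting more...")
        time.sleep(10)
        if not server_ready():
            print("❌ Ollama server is still not responding; the pulls below will likely fail.")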
    
    # Pull the text model
    print("\nPulling qwen2.5:1.5b model (~1GB, may take 2-5 min)...")
    !{ollama_path} pull qwen2.5:1.5b
    
    # Pull the embedding model
    print("\nPulling nomic-embed-text model (~270MB)...")
    !{ollama_path} pull nomic-embed-text
    
    print("\n✓ Ollama setup complete!")
    print("\nAvailable models:")
    !{ollama_path} list

## Step 5: Configure API Keys
- **SerpAPI key** (required for web search): Get a free key at https://serpapi.com
- **HF token** (optional, for gated models): Get from https://huggingface.co/settings/tokens

You can set them via Colab's **Secrets** (left sidebar → 🔑 icon) or paste directly below.
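
The lookup order can also be written as a single helper that tries Colab Secrets first and then falls back to an environment variable. This is a sketch under assumptions: `resolve_secret` is a hypothetical name, and the only Colab call used is `google.colab.userdata.get`, the same one the cell below relies on.

```python
import os


def resolve_secret(name, default=""):
    """Return a secret from Colab Secrets if available, else from the environment."""
    try:
        from google.colab import userdata  # only importable inside Colab

        value = userdata.get(name)
        if value:
            return value
    except Exception:
        # Not running in Colab, secret missing, or notebook access not granted
        pass
    return os.environ.get(name, default)
```

For example, `resolve_secret("SERPAPI_API_KEY")` returns an empty string when the key is set nowhere, which makes a missing key easy to detect before inference starts.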

In [None]:
import os

# =====================================================
# OPTION 1: Use Colab Secrets (recommended)
# Add SERPAPI_API_KEY and HF_TOKEN in left sidebar → 🔑
# =====================================================
try:
    from google.colab import userdata
    SERPAPI_API_KEY = userdata.get('SERPAPI_API_KEY')
    print("✓ SERPAPI_API_KEY loaded from Colab Secrets")
except Exception:
    # OPTION 2: Paste your key directly here
    SERPAPI_API_KEY = ""  # <-- PASTE YOUR SERPAPI KEY HERE
    if SERPAPI_API_KEY:
        print("✓ SERPAPI_API_KEY set manually")
    else:
        print("⚠️ SERPAPI_API_KEY not set! Web search will fail.")

try:
    from google.colab import userdata
    HF_TOKEN = userdata.get('HF_TOKEN')
    print("✓ HF_TOKEN loaded from Colab Secrets")
except Exception:
    # OPTION 2: Paste your HF token directly here
    HF_TOKEN = ""  # <-- PASTE YOUR HF TOKEN HERE (optional)
    if HF_TOKEN:
        print("✓ HF_TOKEN set manually")
    else:
        print("ℹ️ HF_TOKEN not set (optional; only needed for gated models)")

# Store in environment for the subprocess calls
os.environ['SERPAPI_API_KEY'] = SERPAPI_API_KEY or ''
os.environ['HF_TOKEN'] = HF_TOKEN or ''

print("\nAPI keys configured!")

## Step 6: Download ScienceQA Dataset

In [None]:
import os
os.chdir('/content/HMRAG')

# Create dataset directory
os.makedirs('dataset', exist_ok=True)
os.chdir('dataset')

# Clone the ScienceQA repository
if not os.path.exists('ScienceQA'):
    print("Cloning ScienceQA repository...")
    !git clone https://github.com/lupantech/ScienceQA
else:
    print("✓ ScienceQA directory already exists")

if os.path.exists('ScienceQA'):
    os.chdir('ScienceQA')
    
    # Download the dataset
    if os.path.exists('tools/download.sh'):
        print("Downloading dataset files (this may take several minutes)...")
        !bash tools/download.sh
    else:
        print("download.sh not found, creating data directory...")
        os.makedirs('data', exist_ok=True)
    
    os.chdir('/content/HMRAG')

# Verify dataset structure
print("\n" + "=" * 50)
print("Checking required files:")
required_files = [
    'dataset/ScienceQA/data/scienceqa/problems.json',
    'dataset/ScienceQA/data/scienceqa/pid_splits.json'
]

# Also check alternative locations
alt_files = [
    'dataset/ScienceQA/data/problems.json',
    'dataset/ScienceQA/data/pid_splits.json'
]

data_root = None
for f in required_files:
    if os.path.exists(f):
        print(f"✓ Found: {f}")
        if 'problems.json' in f:
            data_root = os.path.dirname(f)
    else:
        print(f"  Not at: {f}")

if data_root is None:
    for f in alt_files:
        if os.path.exists(f):
            print(f"✓ Found: {f}")
            if 'problems.json' in f:
                data_root = os.path.dirname(f)
        else:
            print(f"  Not at: {f}")

if data_root:
    print(f"\n✓ Data root: {data_root}")
else:
    print("\n⚠️ Could not find problems.json automatically.")
    print("Please check the dataset structure manually:")
    !find dataset/ScienceQA -name "problems.json" 2>/dev/null | head -5
    print("\nYou'll need to set --data_root accordingly in the run command.")
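
If neither hard-coded location matches, the search can be done in pure Python rather than with `!find`. The helper below is a hypothetical sketch (`find_data_root` is not part of the repo); it returns the directory of the shallowest `problems.json` under a base folder.

```python
from pathlib import Path


def find_data_root(base, filename="problems.json"):
    """Return the directory holding the shallowest match for filename, or None."""
    base = Path(base)
    if not base.is_dir():
        return None
    # Prefer matches closest to the base directory; break ties alphabetically
    matches = sorted(base.rglob(filename), key=lambda p: (len(p.parts), str(p)))
    return matches[0].parent if matches else None
```

Calling `find_data_root("dataset/ScienceQA")` yields a path that can be passed as `--data_root`, or `None` if the download did not produce the file.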

## Step 7: Verify Everything Before Running
Quick check that Ollama server is running and models are available.

In [None]:
import subprocess
import time
import os

os.chdir('/content/HMRAG')

# Check Ollama server
print("=" * 50)
print("CHECKING OLLAMA SERVER")
print("=" * 50)
try:
    result = subprocess.run(['curl', '-s', 'http://localhost:11434/api/tags'],
                          capture_output=True, text=True, timeout=5)
    if result.returncode == 0:
        print("✓ Ollama server is running!")
    else:
        raise Exception("Not responding")
except Exception:
    print("⚠️ Ollama server not running. Restarting...")
    subprocess.run(['pkill', '-f', 'ollama'], stderr=subprocess.DEVNULL)
    time.sleep(2)
    
    result = subprocess.run(['which', 'ollama'], capture_output=True, text=True)
    ollama_path = result.stdout.strip() or '/usr/local/bin/ollama'
    
    # Discard server output: an undrained PIPE can fill up and stall the server
    subprocess.Popen([ollama_path, 'serve'], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
    time.sleep(8)
    
    result = subprocess.run(['curl', '-s', 'http://localhost:11434/api/tags'],
                          capture_output=True, text=True, timeout=5)
    if result.returncode == 0:
        print("✓ Ollama server restarted!")
    else:
        print("❌ Failed to start Ollama. Restart runtime and rerun.")

print("\nAvailable models:")
!ollama list

# Check critical files
print("\n" + "=" * 50)
print("CHECKING SOURCE FILES")
print("=" * 50)
critical_files = [
    'main.py',
    'agents/decompose_agent.py',
    'agents/summary_agent.py',
    'agents/multi_retrieval_agents.py',
    'retrieval/vector_retrieval.py',
    'retrieval/graph_retrieval.py',
    'retrieval/web_retrieval.py',
    'retrieval/base_retrieval.py',
    'prompts/base_prompt.py',
]
for f in critical_files:
    if os.path.exists(f):
        print(f"✓ {f}")
    else:
        print(f"❌ MISSING: {f}")

# Quick import test
print("\n" + "=" * 50)
print("TESTING IMPORTS")
print("=" * 50)
try:
    import sys
    sys.path.insert(0, '/content/HMRAG')
    from langchain_community.utilities import SerpAPIWrapper
    print("✓ SerpAPIWrapper (google-search-results)")
    from langchain_ollama import OllamaLLM
    print("✓ OllamaLLM (langchain-ollama)")
    from lightrag import LightRAG
    print("✓ LightRAG (lightrag-hku)")
    from lightrag.llm.ollama import ollama_model_complete, ollama_embed
    print("✓ ollama_model_complete, ollama_embed")
    print("\n✓ All imports successful!")
except ImportError as e:
    print(f"❌ Import error: {e}")
    print("Try rerunning Step 2 (Install Dependencies)")
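
`ollama list` shows the models but leaves the comparison to the reader. The check can be automated by parsing the `/api/tags` response. The sketch below assumes the response is a JSON object with a `models` list whose entries carry a `name` field, and that untagged pulls appear as `<model>:latest`; `missing_models` is a hypothetical helper, not part of the repo.

```python
import json

# Models this notebook pulls for text generation and embeddings.
REQUIRED_MODELS = ["qwen2.5:1.5b", "nomic-embed-text"]


def missing_models(tags_payload, required=REQUIRED_MODELS):
    """Return the required models that are absent from an /api/tags JSON payload."""
    available = set()
    for entry in json.loads(tags_payload).get("models", []):
        name = entry.get("name", "")
        available.add(name)
        # Assumption: untagged pulls are listed as "<model>:latest", so index the bare name too.
        if name.endswith(":latest"):
            available.add(name[: -len(":latest")])
    return [model for model in required if model not in available]
```

Passing the body returned by the `curl` call above (its `stdout`) should give an empty list when both models are present. Any name it returns would need to be pulled again.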

## Step 8: Run Inference - Small Test (5 examples)
**Important:** Always delete `./lightrag_workdir` before running to avoid stale dimension-mismatch errors from previous runs.

Adjust `--data_root` below if your dataset location differs. Common paths:
- `./dataset/ScienceQA/data/scienceqa`
- `./dataset/ScienceQA/data`

In [None]:
import os
os.chdir('/content/HMRAG')

# IMPORTANT: Clean lightrag_workdir to avoid dimension mismatch from old runs
!rm -rf ./lightrag_workdir
!mkdir -p outputs

# Auto-detect data_root
data_root = ""
for candidate in [
    "./dataset/ScienceQA/data/scienceqa",
    "./dataset/ScienceQA/data",
]:
    if os.path.exists(os.path.join(candidate, "problems.json")):
        data_root = candidate
        break

if not data_root:
    print("❌ Could not find problems.json. Please set data_root manually.")
    print("Searching for it...")
    !find dataset/ -name "problems.json" 2>/dev/null
else:
    print(f"Using data_root: {data_root}")
    
    # Build the command
    serpapi_key = os.environ.get('SERPAPI_API_KEY', '')
    hf_token = os.environ.get('HF_TOKEN', '')
    
    cmd = f"""python3 main.py \
    --data_root {data_root} \
    --image_root ./dataset/ScienceQA/data/scienceqa \
    --output_root ./outputs \
    --working_dir ./lightrag_workdir \
    --serpapi_api_key "{serpapi_key}" \
    --llm_model_name qwen2.5:1.5b \
    --web_llm_model_name qwen2.5:1.5b \
    --test_split test \
    --test_number 5 \
    --shot_number 0 \
    --label test_run \
    --save_every 5"""
    
    if hf_token:
        cmd += f' --hf_token "{hf_token}"'
    
    print(f"\nRunning command:\n{cmd}\n")
    !{cmd}
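
Interpolating the API key into a shell string works as long as the key has no quotes or other shell metacharacters. One alternative is to build the command as an argument list and quote it only for display. The sketch below reuses the flags from the cell above; `build_command` is a hypothetical helper, and the key `"abc 123"` is a made-up placeholder.

```python
import shlex


def build_command(data_root, serpapi_key="", hf_token="", test_number=5, label="test_run"):
    """Assemble the main.py invocation as a list of arguments."""
    args = [
        "python3", "main.py",
        "--data_root", data_root,
        "--image_root", "./dataset/ScienceQA/data/scienceqa",
        "--output_root", "./outputs",
        "--working_dir", "./lightrag_workdir",
        "--serpapi_api_key", serpapi_key,
        "--llm_model_name", "qwen2.5:1.5b",
        "--web_llm_model_name", "qwen2.5:1.5b",
        "--test_split", "test",
        "--test_number", str(test_number),
        "--shot_number", "0",
        "--label", label,
        "--save_every", "5",
    ]
    if hf_token:  # optional, mirrors the cell above
        args += ["--hf_token", hf_token]
    return args


# Placeholder key containing a space, to show that quoting is handled.
cmd_args = build_command("./dataset/ScienceQA/data/scienceqa", serpapi_key="abc 123")
print(shlex.join(cmd_args))
```

The list can then be handed to `subprocess.run(cmd_args, check=True)`, which bypasses the shell and so needs no manual quoting.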

## Step 9: Run Full Inference
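After the small test works, run on the full test set. Compared with Step 8, this run drops `--test_number`, uses `--shot_number 2`, adds `--use_caption`, and saves every 50 examples, so expect it to take considerably longer.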

In [None]:
import os
os.chdir('/content/HMRAG')

# IMPORTANT: Clean lightrag_workdir to avoid dimension mismatch
!rm -rf ./lightrag_workdir

# Auto-detect data_root
data_root = ""
for candidate in [
    "./dataset/ScienceQA/data/scienceqa",
    "./dataset/ScienceQA/data",
]:
    if os.path.exists(os.path.join(candidate, "problems.json")):
        data_root = candidate
        break

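# Stop early instead of launching main.py with an empty --data_root
if not data_root:
    raise FileNotFoundError("Could not find problems.json. Set data_root manually before running.")

serpapi_key = os.environ.get('SERPAPI_API_KEY', '')
hf_token = os.environ.get('HF_TOKEN', '')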

cmd = f"""python3 main.py \
    --data_root {data_root} \
    --image_root ./dataset/ScienceQA/data/scienceqa \
    --output_root ./outputs \
    --working_dir ./lightrag_workdir \
    --serpapi_api_key "{serpapi_key}" \
    --llm_model_name qwen2.5:1.5b \
    --web_llm_model_name qwen2.5:1.5b \
    --test_split test \
    --shot_number 2 \
    --label full_run \
    --save_every 50 \
    --use_caption"""

if hf_token:
    cmd += f' --hf_token "{hf_token}"'

print(f"Running command:\n{cmd}\n")
!{cmd}

## Step 10: View & Download Results

In [None]:
import os
import json
import glob

os.chdir('/content/HMRAG')

# List output files
print("Output files:")
!ls -lh outputs/

# Load and display results
output_files = sorted(glob.glob('outputs/*.json'))
if output_files:
    for fpath in output_files:
        print(f"\n{'=' * 50}")
        print(f"File: {os.path.basename(fpath)}")
        with open(fpath, 'r') as f:
            results = json.load(f)
        print(f"Total results: {len(results)}")
        print("Sample results:")
        for qid, answer in list(results.items())[:5]:
            print(f"  Question {qid}: Answer = {answer}")
else:
    print("No output files found yet. Run inference first.")
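
Beyond printing the first few entries, a tally of the answer values gives a quick sanity check, for example to spot a run that returns the same option for every question. This is a minimal sketch, assuming each output file is a JSON object mapping question IDs to answers, as the loop above implies; `answer_distribution` is a hypothetical helper.

```python
import json
from collections import Counter


def answer_distribution(path):
    """Count how often each answer value appears in one results file."""
    with open(path, "r") as f:
        results = json.load(f)
    # str() keeps the tally usable even if answers are not plain strings or ints.
    return Counter(str(answer) for answer in results.values())
```

Calling `answer_distribution(fpath).most_common()` for each path in `output_files` lists the answers from most to least frequent.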

In [None]:
# Download results to your local machine
from google.colab import files
import os

os.chdir('/content/HMRAG')

# Zip all outputs
!zip -r outputs.zip outputs/
files.download('outputs.zip')
print("✓ Download started!")

## Troubleshooting

### "Dimension mismatch" / "embedding_dim" errors
```
rm -rf ./lightrag_workdir
```
Then rerun. This happens when a previous run created the DB with a different embedding dimension.
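
The same cleanup can be done from Python, which is handy inside a script or a cell that should not depend on shell magics. This is a small sketch, and `clear_workdir` is a hypothetical helper name.

```python
import os
import shutil


def clear_workdir(path="./lightrag_workdir"):
    """Remove the working directory if present; return True if it existed."""
    existed = os.path.isdir(path)
    shutil.rmtree(path, ignore_errors=True)  # no error if the directory is absent
    return existed
```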

### "Ollama connection refused"
Re-run the Ollama setup cell (Step 4) or the verification cell (Step 7).
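
If the server was only just started, a fixed sleep may be too short on a slow runtime. Polling the HTTP endpoint until it answers is a common alternative. The sketch below assumes the default address `http://localhost:11434/api/tags` used elsewhere in this notebook. `wait_for_ollama`, its timeout, and its polling interval are illustrative choices, not part of the repo.

```python
import time
import urllib.error
import urllib.request


def wait_for_ollama(url="http://localhost:11434/api/tags", timeout=30.0, interval=1.0):
    """Poll the Ollama HTTP endpoint until it answers or the timeout expires."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            with urllib.request.urlopen(url, timeout=5) as response:
                if response.status == 200:
                    return True
        except (urllib.error.URLError, OSError):
            pass  # server not accepting connections yet
        time.sleep(interval)
    return False
```

A `False` return after the full timeout suggests the server never came up, in which case restarting the runtime is the next step.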

### "SerpAPI error"
Make sure your `SERPAPI_API_KEY` is set correctly in Step 5. Get a free key at https://serpapi.com

### "CUDA out of memory"
The notebook uses `qwen2.5:1.5b` (text) and `Qwen2.5-VL-2B-Instruct` (vision). If you still run out of memory:
- Use `Runtime → Change runtime type → T4 GPU`
- Restart runtime and rerun all cells