# HM-RAG: Hierarchical Multi-Agent Multimodal RAG (Google Colab Setup)

**What this notebook does:**
1. Clones the repo
2. Installs all dependencies (from `requirements.txt`)
3. Installs & starts Ollama, pulls `qwen2.5:1.5b` + `nomic-embed-text`
4. Patches source files to fix the **embedding dimension mismatch** (768 vs 1024)
5. Downloads ScienceQA dataset
6. Configures API keys (Serper API for web search, optional HF token)
7. Runs inference

**Model used:** `qwen2.5:1.5b` (text) + `Qwen/Qwen2-VL-2B-Instruct` (vision, loaded only when images are present)
**Web search:** Serper API only (requires a free key from https://serper.dev)
**LLM backend:** All LLM calls (including LightRAG's) go through a locally hosted Ollama server, so no OpenAI API key is needed.
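Under the hood, each of these LLM calls is an HTTP request to the local Ollama server. As a hedged illustration, the sketch below builds (but does not send) a non-streaming request for Ollama's `/api/generate` endpoint, passing the same `num_ctx` of 4096 that the patches use; clients built on the chat API call `/api/chat` instead. It assumes the default host `http://localhost:11434` used throughout this notebook, and `build_ollama_request` is a hypothetical helper, not part of the repo.

```python
import json
import urllib.request

OLLAMA_HOST = "http://localhost:11434"  # default host used throughout this notebook


def build_ollama_request(prompt: str, model: str = "qwen2.5:1.5b",
                         num_ctx: int = 4096, host: str = OLLAMA_HOST) -> urllib.request.Request:
    """Build (but do not send) a non-streaming POST to Ollama's /api/generate."""
    payload = {
        "model": model,
        "prompt": prompt,
        "stream": False,                  # ask for one JSON object instead of a stream
        "options": {"num_ctx": num_ctx},  # context window, matching the patches
    }
    return urllib.request.Request(
        url=f"{host}/api/generate",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


req = build_ollama_request("What is photosynthesis?")

# To actually send it (requires `ollama serve` to be running):
# with urllib.request.urlopen(req, timeout=60) as resp:
#     print(json.loads(resp.read())["response"])
```

Building the request separately from sending it keeps the example runnable without a server; uncomment the last lines once Ollama is up.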

## Step 1: Clone the Repository

In [None]:
import os

# Clone only if not already cloned
if not os.path.exists('/content/HMRAG'):
    !git clone https://github.com/ab-2109/HMRAG.git /content/HMRAG
    print("✓ Repository cloned")
else:
    print("✓ Repository already exists")

%cd /content/HMRAG
print(f"Working directory: {os.getcwd()}")

## Step 2: Install Dependencies
Installs all packages from `requirements.txt`, including `requests` (used for the Serper API), LightRAG, LangChain, and transformers, then makes sure the packages the patched code imports directly are present.
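The web agent needs nothing beyond an HTTP client: a Serper search is a JSON POST with the key in an `X-API-KEY` header. The stdlib sketch below mirrors `_serper_search` in the patched `retrieval/web_retrieval.py` (same endpoint, header, and `q`/`num` fields) but only builds the request. The helper name, the default `num=5`, and reading the key from a `SERPER_API_KEY` environment variable are assumptions for illustration.

```python
import json
import os
import urllib.request

SERPER_URL = "https://google.serper.dev/search"  # endpoint used by the patched WebRetrieval


def build_serper_request(query: str, api_key: str, num: int = 5) -> urllib.request.Request:
    """Build (but do not send) a Serper search request shaped like WebRetrieval._serper_search."""
    if not api_key:
        raise RuntimeError("Serper API key is not set")
    return urllib.request.Request(
        url=SERPER_URL,
        data=json.dumps({"q": query, "num": num}).encode("utf-8"),
        headers={"X-API-KEY": api_key, "Content-Type": "application/json"},
        method="POST",
    )


# Assumption: the key is exported as SERPER_API_KEY; a placeholder is used otherwise.
req = build_serper_request("boiling point of water",
                           os.environ.get("SERPER_API_KEY") or "demo-key")
```

The JSON response would then go through `format_results`, which reads the `answerBox`, `organic`, and `knowledgeGraph` fields.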

In [None]:
import os
os.chdir('/content/HMRAG')

print("Installing dependencies from requirements.txt (this may take a few minutes)...")
print("=" * 60)

# Install numpy first to avoid conflicts
!pip install -q numpy==1.26.4

# Install from the repo's requirements.txt
!pip install -q -r requirements.txt

# Ensure critical packages are installed (in case requirements.txt missed any)
!pip install -q requests langchain-ollama huggingface_hub

print("\n" + "=" * 60)
print("✓ Dependency installation finished")
print("Note: pip may print dependency-resolver warnings; these are usually harmless here, but check any ERROR lines.")
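Because `pip -q` can finish with resolver warnings, it may be worth confirming that the modules the patched files import are actually importable before running anything. A minimal stdlib sketch: the list below is taken from the imports in the Step 3 patches (import names, not pip package names), and whether `requirements.txt` covers `nest_asyncio` and `qwen_vl_utils` is not confirmed by this notebook. `missing_modules` is a hypothetical helper.

```python
import importlib.util

# Import names (not pip names) of modules the patched files import.
REQUIRED_MODULES = [
    "requests", "lightrag", "langchain_ollama", "langchain_core",
    "transformers", "torch", "huggingface_hub", "nest_asyncio", "qwen_vl_utils",
]


def missing_modules(names):
    """Return the top-level module names that cannot be found on sys.path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


missing = missing_modules(REQUIRED_MODULES)
print("Missing modules:", missing or "none")
```

`find_spec` locates a module without importing it, so the check is cheap and does not load heavy packages such as `torch`.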

## Step 3: Patch Source Files (Critical Fixes)
This cell patches the cloned source files to:
1. **Fix embedding dimension mismatch**: `nomic-embed-text` outputs 768-dimensional vectors, but LightRAG's `ollama_embed` is wrapped with embedding attributes that declare 1024 dimensions. We call the unwrapped `ollama_embed.func` inside our own `EmbeddingFunc` with `embedding_dim=768`.
2. **Use `qwen2.5:1.5b`** everywhere instead of `qwen2.5:7b` (fits in Colab GPU memory).
3. **LightRAG uses Ollama** (`ollama_model_complete`), so no GPT-4o-mini / OpenAI API key is needed.
4. **Reduce `num_ctx` to 4096**: 65536 exceeds the 1.5B model's 32K native context and would inflate KV-cache memory on a Colab GPU.
5. **Web search uses Serper API** (`requests.post` to `https://google.serper.dev/search`).
6. **Use `Qwen/Qwen2-VL-2B-Instruct`** for vision (preprocessing image captioning & decision agent).
7. **Add HF token support** for downloading gated models.
8. **Fix device handling** in the vision model: inputs are moved to the model's device, avoiding device-mismatch errors on Colab.
9. **Create the preprocessing module** for Phase 1 (Section 3.1): image→text captioning via the VLM, document assembly, and indexing into LightRAG.
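The mismatch in fix 1 is easy to catch before indexing: embed one probe string and compare the vector length with the `embedding_dim` you declare. The sketch below uses a hypothetical stand-in, `fake_nomic_embed` (zero vectors of length 768), instead of calling Ollama, so it runs anywhere; `check_embedding_dim` is likewise a hypothetical helper, not a LightRAG API.

```python
from typing import Callable, List, Sequence

Vector = List[float]


def check_embedding_dim(embed: Callable[[Sequence[str]], List[Vector]],
                        declared_dim: int, probe: str = "dimension probe") -> int:
    """Embed one probe string and fail fast if its length differs from declared_dim."""
    vectors = embed([probe])
    actual_dim = len(vectors[0])
    if actual_dim != declared_dim:
        raise ValueError(
            f"embedding_dim={declared_dim} but model returned {actual_dim}-dim vectors"
        )
    return actual_dim


def fake_nomic_embed(texts: Sequence[str]) -> List[Vector]:
    # Hypothetical stand-in: 768-dim zero vectors, the size nomic-embed-text produces.
    return [[0.0] * 768 for _ in texts]


print(check_embedding_dim(fake_nomic_embed, 768))  # declared size matches
try:
    check_embedding_dim(fake_nomic_embed, 1024)    # the mismatch the patch fixes
except ValueError as err:
    print("Mismatch detected:", err)
```

To probe the real model you would pass a synchronous wrapper around your embedding call; since `ollama_embed.func` is a coroutine, that wrapper would need something like `asyncio.run(...)` (an assumption about the installed LightRAG version).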

In [None]:
import os
os.chdir('/content/HMRAG')

# =============================================================================
# PATCH 1: retrieval/vector_retrieval.py
# =============================================================================
with open('retrieval/vector_retrieval.py', 'w') as f:
    f.write('''"""Vector-based Retrieval Agent (HM-RAG Layer 2, Section 3.3.1)."""

import asyncio
import logging
from typing import Any

from lightrag import LightRAG, QueryParam
from lightrag.llm.ollama import ollama_model_complete, ollama_embed
from lightrag.utils import EmbeddingFunc

from retrieval.base_retrieval import BaseRetrieval

logger = logging.getLogger(__name__)

_SEED_DOCUMENT = (
    "Science is the systematic study of the natural world through observation "
    "and experimentation. Key branches include physics, chemistry, biology, "
    "earth science, and astronomy."
)


class VectorRetrieval(BaseRetrieval):
    MODE = "naive"

    def __init__(self, config):
        super().__init__(config)
        self.mode = self.MODE
        self._initialised = False
        ollama_host = getattr(config, 'ollama_base_url', 'http://localhost:11434')
        model_name = getattr(config, 'llm_model_name', 'qwen2.5:1.5b')
        working_dir = getattr(config, 'working_dir', './lightrag_workdir')

        self.client = LightRAG(
            working_dir=working_dir,
            llm_model_func=ollama_model_complete,
            llm_model_name=model_name,
            llm_model_max_async=4,
            llm_model_kwargs={"host": ollama_host, "options": {"num_ctx": 4096}},
            embedding_func=EmbeddingFunc(
                embedding_dim=768,
                max_token_size=8192,
                func=lambda texts: ollama_embed.func(
                    texts, embed_model="nomic-embed-text", host=ollama_host
                ),
            ),
        )
        logger.info('VectorRetrieval initialised | mode=%s | dir=%s', self.mode, working_dir)

    def _ensure_initialised(self):
        if self._initialised:
            return
        try:
            import nest_asyncio
            nest_asyncio.apply()
            loop = asyncio.get_event_loop()
            loop.run_until_complete(self.client.ainsert(_SEED_DOCUMENT))
            logger.info('VectorRetrieval: seed document inserted (fallback)')
        except Exception as e:
            logger.debug('VectorRetrieval: seed insert skipped: %s', e)
        self._initialised = True

    def find_top_k(self, query):
        self._ensure_initialised()
        try:
            result = self.client.query(
                query,
                param=QueryParam(mode=self.MODE, top_k=self.top_k)
            )
            return str(result) if result else ''
        except Exception as e:
            logger.error('VectorRetrieval error: %s', e)
            return f'Vector retrieval failed: {e}'
''')
print("✓ Patched retrieval/vector_retrieval.py")

# =============================================================================
# PATCH 2: retrieval/graph_retrieval.py
# =============================================================================
with open('retrieval/graph_retrieval.py', 'w') as f:
    f.write('''"""Graph-based Retrieval Agent (HM-RAG Layer 2, Section 3.3.2)."""

import asyncio
import logging
from typing import Any

from lightrag import LightRAG, QueryParam
from lightrag.llm.ollama import ollama_model_complete, ollama_embed
from lightrag.utils import EmbeddingFunc

from retrieval.base_retrieval import BaseRetrieval

logger = logging.getLogger(__name__)

_SEED_DOCUMENT = (
    "Science is the systematic study of the natural world through observation "
    "and experimentation. Key branches include physics, chemistry, biology, "
    "earth science, and astronomy."
)


class GraphRetrieval(BaseRetrieval):
    def __init__(self, config):
        super().__init__(config)
        self.mode = getattr(config, 'graph_search_mode', getattr(config, 'mode', 'mix'))
        self._initialised = False
        ollama_host = getattr(config, 'ollama_base_url', 'http://localhost:11434')
        model_name = getattr(config, 'llm_model_name', 'qwen2.5:1.5b')
        working_dir = getattr(config, 'working_dir', './lightrag_workdir')

        self.client = LightRAG(
            working_dir=working_dir,
            llm_model_func=ollama_model_complete,
            llm_model_name=model_name,
            llm_model_max_async=4,
            llm_model_kwargs={"host": ollama_host, "options": {"num_ctx": 4096}},
            embedding_func=EmbeddingFunc(
                embedding_dim=768,
                max_token_size=8192,
                func=lambda texts: ollama_embed.func(
                    texts, embed_model="nomic-embed-text", host=ollama_host
                ),
            ),
        )
        logger.info('GraphRetrieval initialised | mode=%s | dir=%s', self.mode, working_dir)

    def _ensure_initialised(self):
        if self._initialised:
            return
        try:
            import nest_asyncio
            nest_asyncio.apply()
            loop = asyncio.get_event_loop()
            loop.run_until_complete(self.client.ainsert(_SEED_DOCUMENT))
            logger.info('GraphRetrieval: seed document inserted (fallback)')
        except Exception as e:
            logger.debug('GraphRetrieval: seed insert skipped: %s', e)
        self._initialised = True

    def find_top_k(self, query):
        self._ensure_initialised()
        try:
            result = self.client.query(
                query,
                param=QueryParam(mode=self.mode, top_k=self.top_k)
            )
            return str(result) if result else ''
        except Exception as e:
            logger.error('GraphRetrieval error: %s', e)
            return f'Graph retrieval failed: {e}'
''')
print("✓ Patched retrieval/graph_retrieval.py")

# =============================================================================
# PATCH 3: retrieval/web_retrieval.py (Serper API)
# =============================================================================
with open('retrieval/web_retrieval.py', 'w') as f:
    f.write('''"""Web-based Retrieval Agent (HM-RAG Layer 2, Section 3.3.3)."""

import logging
from typing import Any, Dict, List, Union

import requests
from langchain_ollama import OllamaLLM

from retrieval.base_retrieval import BaseRetrieval

logger = logging.getLogger(__name__)

_SERPER_URL = "https://google.serper.dev/search"

_SYNTHESIS_PROMPT = (
    "You are a helpful science question answering assistant.\\n"
    "Below are search results retrieved from the web for the given question.\\n"
    "Use ONLY the information in these search results to answer the question.\\n"
    "If the results do not contain enough information, say so.\\n"
    "Be concise and factual.\\n\\n"
    "Search results:\\n{results}\\n\\n"
    "Question: {query}\\n\\n"
    "Answer:"
)


class WebRetrieval(BaseRetrieval):
    def __init__(self, config):
        super().__init__(config)
        self.serper_api_key = getattr(config, 'serper_api_key', '')
        ollama_base_url = getattr(config, 'ollama_base_url', 'http://localhost:11434')
        web_llm_model = getattr(config, 'web_llm_model_name', 'qwen2.5:1.5b')

        self.llm = OllamaLLM(
            base_url=ollama_base_url,
            model=web_llm_model,
            temperature=0.35,
        )
        logger.info('WebRetrieval initialised | top_k=%d | model=%s', self.top_k, web_llm_model)

    def _serper_search(self, query):
        if not self.serper_api_key:
            raise RuntimeError('Serper API key is not set')
        headers = {'X-API-KEY': self.serper_api_key, 'Content-Type': 'application/json'}
        payload = {'q': query, 'num': self.top_k}
        resp = requests.post(_SERPER_URL, json=payload, headers=headers, timeout=30)
        resp.raise_for_status()
        return resp.json()

    def format_results(self, results):
        if isinstance(results, str):
            return results if results.strip() else 'No relevant results found.'
        if not isinstance(results, dict):
            return str(results) if results else 'No relevant results found.'
        snippets = []
        answer_box = results.get('answerBox')
        if answer_box and isinstance(answer_box, dict):
            answer_text = answer_box.get('answer') or answer_box.get('snippet') or ''
            if answer_text:
                snippets.append(f'Direct answer: {answer_text}')
        for item in results.get('organic', [])[:self.top_k]:
            title = item.get('title', 'No title')
            snippet = item.get('snippet', 'No snippet')
            snippets.append(f'[{title}]\\n{snippet}')
        knowledge = results.get('knowledgeGraph')
        if knowledge and isinstance(knowledge, dict):
            desc = knowledge.get('description', '')
            if desc:
                snippets.append(f'Knowledge Graph: {desc}')
        return '\\n\\n'.join(snippets) if snippets else 'No relevant results found.'

    def _generate(self, search_results, query):
        prompt = _SYNTHESIS_PROMPT.format(results=search_results, query=query)
        try:
            answer = self.llm.invoke(prompt)
            return answer.strip() if answer else ''
        except Exception as e:
            logger.error('WebRetrieval generation failed: %s', e)
            return f'Web generation failed: {e}'

    def find_top_k(self, query):
        try:
            raw_results = self._serper_search(query)
            formatted = self.format_results(raw_results)
            answer = self._generate(formatted, query)
            return answer
        except Exception as e:
            logger.error('WebRetrieval failed: %s', e)
            return f'Web retrieval failed: {e}'
''')
print("✓ Patched retrieval/web_retrieval.py (Serper API)")

# =============================================================================
# PATCH 4: agents/decompose_agent.py
# =============================================================================
with open('agents/decompose_agent.py', 'w') as f:
    f.write('''"""Decomposition Agent (HM-RAG Layer 1)."""

import re
from typing import List
from langchain_core.prompts import PromptTemplate
from langchain_ollama import OllamaLLM


class DecomposeAgent:
    def __init__(self, config):
        self.config = config
        self.llm = OllamaLLM(
            base_url=getattr(config, 'ollama_base_url', 'http://localhost:11434'),
            model=getattr(config, 'llm_model_name', 'qwen2.5:1.5b'),
            temperature=getattr(config, 'temperature', 0.35),
        )

    def count_intents(self, query: str) -> int:
        prompt = PromptTemplate.from_template(
            "Please calculate how many independent intents are contained in the following query. "
            "Return only an integer:\\n{query}\\nNumber of intents: "
        )
        for attempt in range(3):
            response = self.llm.invoke(prompt.format(query=query))
            numbers = re.findall(r'\\d+', response.strip())
            if numbers:
                return int(numbers[0])
        return 1

    def decompose(self, query: str) -> List[str]:
        intent_count = min(self.count_intents(query), 3)
        if intent_count > 1:
            return self._split_query(query)
        return [query]

    def _split_query(self, query: str) -> List[str]:
        prompt = PromptTemplate.from_template(
            "Split the following query into multiple independent sub-queries, "
            "separated by '||', without additional explanations:\\n{query}\\nList of sub-queries: "
        )
        response = self.llm.invoke(prompt.format(query=query))
        sub_queries = [q.strip() for q in response.split("||") if q.strip()]
        return sub_queries if sub_queries else [query]
''')
print("✓ Patched agents/decompose_agent.py")

# =============================================================================
# PATCH 5: agents/summary_agent.py (vision model with load-once guard)
# =============================================================================
with open('agents/summary_agent.py', 'w') as f:
    f.write('''"""Decision Agent (HM-RAG Layer 3): voting + expert refinement."""

from collections import Counter
from langchain_ollama import OllamaLLM
import re
from transformers import AutoProcessor
import random
import os
import torch

from prompts.base_prompt import build_prompt


class SummaryAgent:
    def __init__(self, config):
        self.config = config
        self.text_llm = OllamaLLM(
            base_url=getattr(config, 'ollama_base_url', 'http://localhost:11434'),
            model=getattr(config, 'llm_model_name', 'qwen2.5:1.5b')
        )
        self.hf_token = getattr(config, 'hf_token', '') or os.environ.get('HF_TOKEN', '')
        self._vision_model = None
        self._processor = None
        self._vision_load_attempted = False

    def _load_vision_model(self):
        if self._vision_load_attempted:
            return
        self._vision_load_attempted = True
        try:
            from transformers import Qwen2VLForConditionalGeneration

            model_name = "Qwen/Qwen2-VL-2B-Instruct"
            token_kwargs = {}
            if self.hf_token:
                token_kwargs['token'] = self.hf_token

            self._processor = AutoProcessor.from_pretrained(
                model_name, use_fast=True, **token_kwargs
            )
            self._vision_model = Qwen2VLForConditionalGeneration.from_pretrained(
                model_name,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                device_map="auto" if torch.cuda.is_available() else None,
                **token_kwargs
            )
            print(f"✓ Vision model {model_name} loaded")
        except Exception as e:
            print(f"Warning: Could not load vision model: {e}")
            self._vision_model = None
            self._processor = None

    def summarize(self, problems, shot_qids, qid, cur_ans):
        problem = problems[qid]
        question = problem['question']
        choices = problem["choices"]
        answer = problem['answer']
        image = problem.get('image', '')
        split = problem.get("split", "test")

        most_ans = self.get_most_common_answer(cur_ans)

        if len(most_ans) == 1:
            prediction = self.get_result(most_ans[0])
            pred_idx = self.get_pred_idx(prediction, choices, self.config.options)
        else:
            if image and image == "image.png":
                image_path = os.path.join(self.config.image_root, split, qid, image)
            else:
                image_path = ""

            output_text = cur_ans[0] if len(cur_ans) > 0 else ""
            output_graph = cur_ans[1] if len(cur_ans) > 1 else ""
            output_web = cur_ans[2] if len(cur_ans) > 2 else ""

            output = self.refine(output_text, output_graph, output_web,
                                 problems, shot_qids, qid, self.config, image_path)
            if output is None:
                output = "FAILED"
            print(f"output: {output}")
            ans_fusion = self.get_result(output)
            pred_idx = self.get_pred_idx(ans_fusion, choices, self.config.options)
        return pred_idx, cur_ans

    def get_most_common_answer(self, res):
        if not res:
            return []
        counter = Counter(res)
        max_count = max(counter.values())
        return [item for item, count in counter.items() if count == max_count]

    def refine(self, output_text, output_graph, output_web, problems, shot_qids, qid, args, image_path):
        prompt = build_prompt(problems, shot_qids, qid, args)
        prompt = f"{prompt} The answer is A, B, C, D, E or FAILED. \\n BECAUSE: "

        if not image_path:
            output = self.text_llm.invoke(prompt)
        else:
            output = self.qwen_reasoning(prompt, image_path)
            if output:
                output = self.text_llm.invoke(
                    f"{output[0]} Summarize the above information using the format "
                    f"'Answer: The answer is A, B, C, D, E or FAILED.\\n BECAUSE: '"
                )
            else:
                output = self.text_llm.invoke(prompt)
        return output

    def get_result(self, output):
        pattern = re.compile(r'The answer is ([A-E])')
        res = pattern.findall(output)
        return res[0] if len(res) == 1 else "FAILED"

    def get_pred_idx(self, prediction, choices, options):
        if prediction in options[:len(choices)]:
            return options.index(prediction)
        return random.choice(range(len(choices)))

    def qwen_reasoning(self, prompt, image_path):
        self._load_vision_model()
        if self._vision_model is None or self._processor is None:
            print("Warning: Vision model not available, falling back to text-only.")
            return None

        try:
            from qwen_vl_utils import process_vision_info
        except ImportError:
            print("Warning: qwen_vl_utils not installed, falling back to text-only.")
            return None

        if os.path.isfile(image_path) and not image_path.startswith(('http://', 'https://', 'file://')):
            image_uri = 'file://' + os.path.abspath(image_path)
        else:
            image_uri = image_path

        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image_uri},
                    {"type": "text", "text": prompt},
                ],
            }
        ]

        text = self._processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self._processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )

        device = next(self._vision_model.parameters()).device
        inputs = {k: v.to(device) if hasattr(v, 'to') else v for k, v in inputs.items()}

        generated_ids = self._vision_model.generate(**inputs, max_new_tokens=512)
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs['input_ids'], generated_ids)
        ]
        output_text = self._processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
        return output_text
''')
print("✓ Patched agents/summary_agent.py (load-once guard, file:// prefix, Qwen2-VL-2B)")

# =============================================================================
# PATCH 6: main.py ‚Äî add preprocessing step
# =============================================================================
with open('main.py', 'r') as f:
    main_content = f.read()

changes_made = []

# Fix model name
if "default='qwen2.5:7b'" in main_content:
    main_content = main_content.replace("default='qwen2.5:7b'", "default='qwen2.5:1.5b'")
    changes_made.append("qwen2.5:7b -> qwen2.5:1.5b")

# Fix API key naming
if '--serpapi_api_key' in main_content:
    main_content = main_content.replace('--serpapi_api_key', '--serper_api_key')
    main_content = main_content.replace('serpapi_api_key', 'serper_api_key')
    changes_made.append("serpapi -> serper")

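# Add --hf_token if missing. Match the 4-space indent of `return` so the inserted
# add_argument line is not double-indented (which would raise IndentationError).
if '--hf_token' not in main_content and "    return parser.parse_args()" in main_content:
    main_content = main_content.replace(
        "    return parser.parse_args()",
        "    parser.add_argument('--hf_token', type=str, default='',\n"
        "                        help='HF access token for gated models')\n"
        "    return parser.parse_args()"
    )
    changes_made.append("Added --hf_token")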

# Add preprocessing import and call
if 'KnowledgeBaseBuilder' not in main_content:
    # Add import
    main_content = main_content.replace(
        'from agents.multi_retrieval_agents import MRetrievalAgent',
        'from agents.multi_retrieval_agents import MRetrievalAgent\n'
        'from preprocessing.build_knowledge_base import KnowledgeBaseBuilder'
    )
    # Add preprocessing call before agent init
    main_content = main_content.replace(
        '    agent = MRetrievalAgent(args)',
        '    # Phase 1: Build Knowledge Base (Section 3.1)\n'
        '    splits_path = os.path.join(args.data_root, "pid_splits.json")\n'
        '    with open(splits_path, "r") as f:\n'
        '        _pid_splits = json.load(f)\n'
        '    train_qids = _pid_splits.get("train", [])\n'
        '    builder = KnowledgeBaseBuilder(args)\n'
        '    builder.build(problems, train_qids)\n'
        '\n'
        '    agent = MRetrievalAgent(args)'
    )
    changes_made.append("Added preprocessing step")

# Add HF login if missing
if '_setup_hf_token' not in main_content and 'HF_TOKEN' not in main_content:
    main_content = main_content.replace(
        '    agent = MRetrievalAgent(args)',
        '    # Set HF token if provided\n'
        '    if hasattr(args, "hf_token") and args.hf_token:\n'
        '        os.environ["HF_TOKEN"] = args.hf_token\n'
        '        os.environ["HUGGING_FACE_HUB_TOKEN"] = args.hf_token\n'
        '        try:\n'
        '            from huggingface_hub import login\n'
        '            login(token=args.hf_token)\n'
        '            print("✓ Logged in to Hugging Face Hub")\n'
        '        except Exception as e:\n'
        '            print(f"Warning: Could not login to HF Hub: {e}")\n'
        '\n'
        '    agent = MRetrievalAgent(args)'
    )
    changes_made.append("Added HF login")

with open('main.py', 'w') as f:
    f.write(main_content)

if changes_made:
    print("✓ Patched main.py: " + ", ".join(changes_made))
else:
    print("✓ main.py already up to date")

# =============================================================================
# PATCH 7: YAML configs
# =============================================================================
for yaml_file in ['configs/decompose_agent.yaml', 'configs/multi_retrieval_agents.yaml']:
    if os.path.exists(yaml_file):
        with open(yaml_file, 'r') as f:
            content = f.read()
        updated = False
        if 'qwen2.5:7b' in content:
            content = content.replace('qwen2.5:7b', 'qwen2.5:1.5b')
            updated = True
        if 'serpapi_api_key' in content:
            content = content.replace('serpapi_api_key', 'serper_api_key')
            updated = True
        if updated:
            with open(yaml_file, 'w') as f:
                f.write(content)
            print(f"✓ Patched {yaml_file}")
        else:
            print(f"✓ {yaml_file} already correct")

# =============================================================================
# PATCH 8: Create preprocessing module (Phase 1, Section 3.1)
# =============================================================================
os.makedirs('preprocessing', exist_ok=True)

with open('preprocessing/__init__.py', 'w') as f:
    f.write('''"""Preprocessing module for HM-RAG (Phase 1, Section 3.1)."""
from preprocessing.build_knowledge_base import KnowledgeBaseBuilder
__all__ = ["KnowledgeBaseBuilder"]
''')

with open('preprocessing/build_knowledge_base.py', 'w') as f:
    f.write('''"""
Knowledge-Base Builder - HM-RAG Phase 1 (Section 3.1).

Eq 1: D_img  = VLM(image)              - image-to-text via Qwen VLM
Eq 2: D_comb = concat(D_text, D_img)   - merge textual + visual info
Eq 3: KB     = LightRAG.insert(D_comb) - index into vector + graph DB
"""

import asyncio
import gc
import json
import logging
import os
from typing import Any, Dict, List, Optional

import torch

logger = logging.getLogger(__name__)
_MARKER = ".kb_built"
_INSERT_BATCH = 50
_VLM_MAX_TOKENS = 256


class ImageCaptioner:
    """Generate captions using Qwen2-VL-2B-Instruct (~1.5B non-embedding params)."""

    MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"

    def __init__(self, hf_token="", device=None):
        self.hf_token = hf_token
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self._model = None
        self._processor = None

    def _load(self):
        if self._model is not None:
            return
        from transformers import Qwen2VLForConditionalGeneration, AutoProcessor

        token_kw = {"token": self.hf_token} if self.hf_token else {}
        logger.info("Loading VLM %s ...", self.MODEL_ID)
        self._processor = AutoProcessor.from_pretrained(self.MODEL_ID, use_fast=True, **token_kw)
        dtype = torch.float16 if self.device == "cuda" else torch.float32
        self._model = Qwen2VLForConditionalGeneration.from_pretrained(
            self.MODEL_ID, torch_dtype=dtype,
            device_map="auto" if self.device == "cuda" else None,
            **token_kw,
        )
        if self.device != "cuda":
            self._model.to(self.device)
        logger.info("VLM loaded on %s", self.device)

    def unload(self):
        del self._model, self._processor
        self._model = None
        self._processor = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        logger.info("VLM unloaded")

    def caption(self, image_path):
        if not image_path or not os.path.isfile(image_path):
            return ""
        self._load()
        try:
            from qwen_vl_utils import process_vision_info
            messages = [{
                "role": "user",
                "content": [
                    {"type": "image", "image": f"file://{image_path}"},
                    {"type": "text", "text": (
                        "Describe this image in detail for a science "
                        "question-answering system. Include all visible "
                        "text, labels, diagrams, charts, and relevant "
                        "scientific information."
                    )},
                ],
            }]
            text = self._processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            image_inputs, video_inputs = process_vision_info(messages)
            inputs = self._processor(text=[text], images=image_inputs, videos=video_inputs,
                                     padding=True, return_tensors="pt")
            device = next(self._model.parameters()).device
            inputs = {k: v.to(device) if hasattr(v, "to") else v for k, v in inputs.items()}
            with torch.inference_mode():
                gen_ids = self._model.generate(**inputs, max_new_tokens=_VLM_MAX_TOKENS)
            trimmed = [out[len(inp):] for inp, out in zip(inputs["input_ids"], gen_ids)]
            return self._processor.batch_decode(trimmed, skip_special_tokens=True)[0].strip()
        except Exception as e:
            logger.warning("Caption failed for %s: %s", image_path, e)
            return ""


def _build_document(problem, qid, image_caption):
    parts = []
    subject = problem.get("subject", "")
    topic = problem.get("topic", "")
    if subject or topic:
        parts.append(f"Subject: {subject}  Topic: {topic}")

    q = problem.get("question", "")
    choices = problem.get("choices", [])
    if q:
        opts = " | ".join(f"({chr(65+i)}) {c}" for i, c in enumerate(choices))
        parts.append(f"Question: {q}\\nOptions: {opts}")

    hint = (problem.get("hint") or "").strip()
    if hint:
        parts.append(f"Hint: {hint}")

    lecture = (problem.get("lecture") or "").strip()
    if lecture:
        parts.append(f"Lecture: {lecture}")

    solution = (problem.get("solution") or "").strip()
    if solution:
        parts.append(f"Solution: {solution}")

    ds_caption = (problem.get("caption") or "").strip()
    if ds_caption:
        parts.append(f"Image caption: {ds_caption}")

    if image_caption:
        parts.append(f"Image description (VLM): {image_caption}")

    answer_idx = problem.get("answer")
    if answer_idx is not None and answer_idx < len(choices):
        parts.append(f"Correct answer: ({chr(65+answer_idx)}) {choices[answer_idx]}")

    return "\\n".join(parts)


class KnowledgeBaseBuilder:
    def __init__(self, config):
        self.config = config
        self.working_dir = getattr(config, "working_dir", "./lightrag_workdir")
        self.image_root = getattr(config, "image_root", "")
        self.hf_token = getattr(config, "hf_token", "") or os.environ.get("HF_TOKEN", "")

    def build(self, problems, qids):
        marker_path = os.path.join(self.working_dir, _MARKER)
        if os.path.exists(marker_path):
            logger.info("Knowledge base already built; skipping. Delete %s to rebuild.", marker_path)
            return

        os.makedirs(self.working_dir, exist_ok=True)
        logger.info("=" * 60)
        logger.info("Phase 1: Building Knowledge Base (%d problems)", len(qids))
        logger.info("=" * 60)

        captions = self._caption_images(problems, qids)
        documents = self._assemble_documents(problems, qids, captions)
        self._index_documents(documents)

        with open(marker_path, "w") as f:
            f.write(f"Built from {len(documents)} documents\\n")
        logger.info("Phase 1 complete; marker written to %s", marker_path)

    def _caption_images(self, problems, qids):
        cache_path = os.path.join(self.working_dir, "vlm_captions.json")
        if os.path.exists(cache_path):
            logger.info("Loading cached VLM captions from %s", cache_path)
            with open(cache_path, "r") as f:
                return json.load(f)

        image_qids = []
        for qid in qids:
            img = problems[qid].get("image", "")
            if img:
                split = problems[qid].get("split", "train")
                img_path = os.path.join(self.image_root, split, qid, img)
                if os.path.isfile(img_path):
                    image_qids.append((qid, img_path))

        captions = {qid: "" for qid in qids}
        if not image_qids:
            logger.info("No images found; skipping VLM captioning")
            return captions

        logger.info("Captioning %d images with VLM ...", len(image_qids))
        captioner = ImageCaptioner(hf_token=self.hf_token)
        try:
            for i, (qid, img_path) in enumerate(image_qids):
                captions[qid] = captioner.caption(img_path)
                if (i + 1) % 50 == 0 or (i + 1) == len(image_qids):
                    logger.info("  captioned %d / %d images", i + 1, len(image_qids))
        finally:
            captioner.unload()

        with open(cache_path, "w") as f:
            json.dump(captions, f)
        logger.info("VLM captions cached to %s", cache_path)
        return captions

    def _assemble_documents(self, problems, qids, captions):
        documents = []
        for qid in qids:
            doc = _build_document(problems[qid], qid, captions.get(qid, ""))
            if doc.strip():
                documents.append(doc)
        logger.info("Assembled %d documents", len(documents))
        return documents

    def _index_documents(self, documents):
        import nest_asyncio
        nest_asyncio.apply()

        from lightrag import LightRAG
        from lightrag.llm.ollama import ollama_model_complete, ollama_embed
        from lightrag.utils import EmbeddingFunc

        ollama_host = getattr(self.config, "ollama_base_url", "http://localhost:11434")
        model_name = getattr(self.config, "llm_model_name", "qwen2.5:1.5b")

        rag = LightRAG(
            working_dir=self.working_dir,
            llm_model_func=ollama_model_complete,
            llm_model_name=model_name,
            llm_model_max_async=4,
            llm_model_kwargs={"host": ollama_host, "options": {"num_ctx": 4096}},
            embedding_func=EmbeddingFunc(
                embedding_dim=768,
                max_token_size=8192,
                func=lambda texts: ollama_embed.func(
                    texts, embed_model="nomic-embed-text", host=ollama_host,
                ),
            ),
        )

        total = len(documents)
        logger.info("Indexing %d documents into LightRAG ...", total)
        loop = asyncio.get_event_loop()
        for start in range(0, total, _INSERT_BATCH):
            batch = documents[start : start + _INSERT_BATCH]
            combined = "\\n\\n---\\n\\n".join(batch)
            try:
                loop.run_until_complete(rag.ainsert(combined))
            except Exception as e:
                logger.error("Insert failed for batch %d-%d: %s", start, start + len(batch), e)
            done = min(start + _INSERT_BATCH, total)
            logger.info("  indexed %d / %d documents", done, total)
        logger.info("LightRAG indexing complete")
''')
print("✓ Created preprocessing/build_knowledge_base.py (Phase 1: VLM captioning + LightRAG indexing)")

# =============================================================================
# PATCH 9: Install nest_asyncio + Pillow
# =============================================================================
!pip install -q nest_asyncio Pillow

# Clean up stale working directory (preprocessing will rebuild it)
!rm -rf ./lightrag_workdir

print("\n" + "=" * 60)
print("✓ All patches applied!")
print("  - embedding_dim = 768 (matches nomic-embed-text)")
print("  - ollama_embed.func (bypasses 1024-dim decorator)")
print("  - num_ctx = 4096 (fits qwen2.5:1.5b)")
print("  - Text model: qwen2.5:1.5b (via Ollama)")
print("  - LightRAG LLM: ollama_model_complete (NOT GPT-4o-mini)")
print("  - Vision model: Qwen/Qwen2-VL-2B-Instruct")
print("  - Web search: Serper API (https://serper.dev)")
print("  - HF token support added")
print("  - Phase 1 preprocessing: VLM captioning + LightRAG indexing")
print("  - Knowledge base is built from training data before evaluation")

## Step 4: Install and Start Ollama + Pull Models
Ollama runs locally on the Colab VM. We pull `qwen2.5:1.5b` (~1GB) and `nomic-embed-text` (~270MB).

In [None]:
import subprocess
import time
import os

# Install system dependencies
print("Installing system dependencies...")
!sudo apt-get update -qq 2>/dev/null
!sudo apt-get install -y -qq zstd 2>/dev/null

# Install Ollama
print("Installing Ollama...")
!curl -fsSL https://ollama.com/install.sh | sh 2>&1 | tail -3

# Find ollama binary
result = subprocess.run(['which', 'ollama'], capture_output=True, text=True)
ollama_path = result.stdout.strip()
if not ollama_path:
    for path in ['/usr/local/bin/ollama', '/usr/bin/ollama']:
        if os.path.exists(path):
            ollama_path = path
            break

if not ollama_path:
    print("❌ Ollama binary not found! Please restart runtime and try again.")
else:
    print(f"✓ Ollama found at: {ollama_path}")
    
    # Kill any existing ollama processes
    subprocess.run(['pkill', '-f', 'ollama'], stderr=subprocess.DEVNULL)
    time.sleep(2)
    
    # Start Ollama server in background. Logs go to a file: an unread
    # subprocess.PIPE can fill up and stall the server during long runs.
    print("Starting Ollama server...")
    ollama_log = open('/content/ollama_serve.log', 'w')
    ollama_process = subprocess.Popen(
        [ollama_path, 'serve'],
        stdout=ollama_log,
        stderr=subprocess.STDOUT
    )
    
    # Poll the API until the server answers (up to ~30 s) instead of one fixed sleep
    server_ready = False
    for _ in range(15):
        time.sleep(2)
        result = subprocess.run(['curl', '-s', 'http://localhost:11434/api/tags'],
                                capture_output=True, text=True, timeout=5)
        if result.returncode == 0:
            server_ready = True
            break
    if server_ready:
        print("✓ Ollama server is running!")
    else:
        print("⚠️ Ollama server did not respond within ~30 s; check /content/ollama_serve.log")
    
    # Pull the text model
    print("\nPulling qwen2.5:1.5b model (~1GB, may take 2-5 min)...")
    !{ollama_path} pull qwen2.5:1.5b
    
    # Pull the embedding model
    print("\nPulling nomic-embed-text model (~270MB)...")
    !{ollama_path} pull nomic-embed-text
    
    print("\n✓ Ollama setup complete!")
    print("\nAvailable models:")
    !{ollama_path} list

## Step 5: Configure API Keys
- **Serper API key** (required for web search): Get a free key at https://serper.dev (2,500 free searches)
- **HF token** (optional, for gated models): Get from https://huggingface.co/settings/tokens

You can set them via Colab's **Secrets** (left sidebar → 🔑 icon) or paste them directly into the cell below.
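Before you launch a long run, a single request is enough to confirm that a Serper key works. The sketch below uses only the standard library. It assumes Serper's `POST https://google.serper.dev/search` endpoint with an `X-API-KEY` header and a JSON body `{"q": ...}`, and a response with an `organic` list of results. It is not the repo's `web_retrieval.py` code, and `serper_search`/`top_links` are hypothetical helper names.

```python
import json
import os
import urllib.request

SERPER_URL = "https://google.serper.dev/search"


def serper_search(query: str, api_key: str, timeout: float = 10.0) -> dict:
    """Send one search request to Serper and return the parsed JSON body."""
    if not api_key:
        raise ValueError("SERPER_API_KEY is empty; set it in Step 5 first")
    req = urllib.request.Request(
        SERPER_URL,
        data=json.dumps({"q": query}).encode("utf-8"),
        headers={"X-API-KEY": api_key, "Content-Type": "application/json"},
        method="POST",
    )
    # urlopen raises urllib.error.HTTPError for 4xx/5xx responses, e.g. a rejected key
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        return json.load(resp)


def top_links(result: dict, k: int = 3) -> list:
    """Return (title, link) pairs from the 'organic' section of a Serper response."""
    return [(r.get("title", ""), r.get("link", "")) for r in result.get("organic", [])[:k]]


if __name__ == "__main__":
    key = os.environ.get("SERPER_API_KEY", "")
    if key:
        print(top_links(serper_search("What is photosynthesis?", key)))
    else:
        print("SERPER_API_KEY not set; skipping the test query")
```

Each successful call uses one search from the free quota, so run it once rather than in a loop.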

In [None]:
import os

# =====================================================
# OPTION 1: Use Colab Secrets (recommended)
# Add SERPER_API_KEY and HF_TOKEN in left sidebar → 🔑
# =====================================================
try:
    from google.colab import userdata
    SERPER_API_KEY = userdata.get('SERPER_API_KEY')
    print("✓ SERPER_API_KEY loaded from Colab Secrets")
except Exception:
    # OPTION 2: Paste your key directly here
    SERPER_API_KEY = ""  # <-- PASTE YOUR SERPER KEY HERE (from https://serper.dev)
    if SERPER_API_KEY:
        print("✓ SERPER_API_KEY set manually")
    else:
        print("⚠️ SERPER_API_KEY not set! Web search will fail.")

try:
    from google.colab import userdata
    HF_TOKEN = userdata.get('HF_TOKEN')
    print("✓ HF_TOKEN loaded from Colab Secrets")
except Exception:
    # OPTION 2: Paste your HF token directly here
    HF_TOKEN = ""  # <-- PASTE YOUR HF TOKEN HERE (optional)
    if HF_TOKEN:
        print("✓ HF_TOKEN set manually")
    else:
        print("ℹ️ HF_TOKEN not set (optional; only needed for gated models)")

# Store in environment for the subprocess calls
os.environ['SERPER_API_KEY'] = SERPER_API_KEY or ''
os.environ['HF_TOKEN'] = HF_TOKEN or ''

print("\nAPI keys configured!")

## Step 6: Download ScienceQA Dataset
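After the download, it helps to confirm that the two JSON files agree. This sketch assumes the standard ScienceQA layout. `problems.json` maps each question id to a record with fields such as `question`, `choices`, `answer` (a 0-based index), `image` and `split`. `pid_splits.json` maps split names (such as `train`, `val`, `test`) to lists of ids. `summarize_splits` is a hypothetical helper, not part of the repo.

```python
import json
import os


def summarize_splits(data_root: str) -> dict:
    """Count questions per split, ids missing from problems.json, and image questions."""
    with open(os.path.join(data_root, "problems.json")) as f:
        problems = json.load(f)
    with open(os.path.join(data_root, "pid_splits.json")) as f:
        pid_splits = json.load(f)

    summary = {}
    for split, qids in pid_splits.items():
        missing = sum(1 for q in qids if q not in problems)
        # Questions whose record names an image file (these get VLM captions in Phase 1)
        with_image = sum(1 for q in qids if q in problems and problems[q].get("image"))
        summary[split] = {"total": len(qids), "missing": missing, "with_image": with_image}
    return summary


if __name__ == "__main__":
    root = "./dataset/ScienceQA/data/scienceqa"
    if os.path.exists(os.path.join(root, "problems.json")):
        for split, stats in summarize_splits(root).items():
            print(f"{split:>10}: {stats}")
```

A non-zero `missing` count points to a partial or mismatched download.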

In [None]:
import os
os.chdir('/content/HMRAG')

# Create dataset directory
os.makedirs('dataset', exist_ok=True)
os.chdir('dataset')

# Clone the ScienceQA repository
if not os.path.exists('ScienceQA'):
    print("Cloning ScienceQA repository...")
    !git clone https://github.com/lupantech/ScienceQA
else:
    print("✓ ScienceQA directory already exists")

if os.path.exists('ScienceQA'):
    os.chdir('ScienceQA')
    
    # Download the dataset
    if os.path.exists('tools/download.sh'):
        print("Downloading dataset files (this may take several minutes)...")
        !bash tools/download.sh
    else:
        print("download.sh not found, creating data directory...")
        os.makedirs('data', exist_ok=True)
    
# Return to the repo root whether or not the clone succeeded
os.chdir('/content/HMRAG')

# Verify dataset structure
print("\n" + "=" * 50)
print("Checking required files:")
required_files = [
    'dataset/ScienceQA/data/scienceqa/problems.json',
    'dataset/ScienceQA/data/scienceqa/pid_splits.json'
]

# Also check alternative locations
alt_files = [
    'dataset/ScienceQA/data/problems.json',
    'dataset/ScienceQA/data/pid_splits.json'
]

data_root = None
for f in required_files:
    if os.path.exists(f):
        print(f"✓ Found: {f}")
        if 'problems.json' in f:
            data_root = os.path.dirname(f)
    else:
        print(f"  Not at: {f}")

if data_root is None:
    for f in alt_files:
        if os.path.exists(f):
            print(f"✓ Found: {f}")
            if 'problems.json' in f:
                data_root = os.path.dirname(f)
        else:
            print(f"  Not at: {f}")

if data_root:
    print(f"\n✓ Data root: {data_root}")
else:
    print("\n⚠️ Could not find problems.json automatically.")
    print("Please check the dataset structure manually:")
    !find dataset/ScienceQA -name "problems.json" 2>/dev/null | head -5
    print("\nYou'll need to set --data_root accordingly in the run command.")

## Step 7: Verify Everything Before Running
Quick check that Ollama server is running and models are available.
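The whole setup depends on `nomic-embed-text` returning 768-dimensional vectors. A direct check against Ollama can catch a mismatch before LightRAG writes anything. This sketch assumes Ollama's `POST /api/embed` endpoint, with body `{"model", "input"}` and response `{"embeddings": [[...]]}`, which is available in recent Ollama releases. `embedding_dim` and `check_embedding_dim` are hypothetical names.

```python
import json
import urllib.request

EXPECTED_DIM = 768  # must match EmbeddingFunc(embedding_dim=...) in the Step 3 patch


def embedding_dim(embed_response: dict) -> int:
    """Return the vector length from an /api/embed response."""
    vectors = embed_response.get("embeddings") or []
    if not vectors:
        raise ValueError("response contains no embeddings")
    return len(vectors[0])


def check_embedding_dim(base_url="http://localhost:11434", model="nomic-embed-text",
                        expected=EXPECTED_DIM, timeout=30.0) -> bool:
    """Embed one short string and compare the vector length with `expected`."""
    payload = json.dumps({"model": model, "input": "dimension check"}).encode("utf-8")
    req = urllib.request.Request(f"{base_url}/api/embed", data=payload,
                                 headers={"Content-Type": "application/json"})
    with urllib.request.urlopen(req, timeout=timeout) as resp:  # POST, since data is set
        dim = embedding_dim(json.load(resp))
    print(f"{model}: {dim} dims (expected {expected})")
    return dim == expected


if __name__ == "__main__":
    try:
        check_embedding_dim()
    except OSError as exc:
        print(f"Ollama not reachable: {exc}")
```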

In [None]:
import subprocess
import time
import os

os.chdir('/content/HMRAG')

# Check Ollama server
print("=" * 50)
print("CHECKING OLLAMA SERVER")
print("=" * 50)
try:
    result = subprocess.run(['curl', '-s', 'http://localhost:11434/api/tags'],
                          capture_output=True, text=True, timeout=5)
    if result.returncode == 0:
        print("✓ Ollama server is running!")
    else:
        raise Exception("Not responding")
except Exception:
    print("⚠️ Ollama server not running. Restarting...")
    subprocess.run(['pkill', '-f', 'ollama'], stderr=subprocess.DEVNULL)
    time.sleep(2)
    
    result = subprocess.run(['which', 'ollama'], capture_output=True, text=True)
    ollama_path = result.stdout.strip() or '/usr/local/bin/ollama'
    
    # Log to a file instead of an unread PIPE, which can fill up and stall the server
    ollama_log = open('/content/ollama_serve.log', 'a')
    subprocess.Popen([ollama_path, 'serve'], stdout=ollama_log, stderr=subprocess.STDOUT)
    time.sleep(8)
    
    result = subprocess.run(['curl', '-s', 'http://localhost:11434/api/tags'],
                          capture_output=True, text=True, timeout=5)
    if result.returncode == 0:
        print("✓ Ollama server restarted!")
    else:
        print("❌ Failed to start Ollama. Restart runtime and rerun.")

print("\nAvailable models:")
!ollama list

# Check critical files
print("\n" + "=" * 50)
print("CHECKING SOURCE FILES")
print("=" * 50)
critical_files = [
    'main.py',
    'agents/decompose_agent.py',
    'agents/summary_agent.py',
    'agents/multi_retrieval_agents.py',
    'retrieval/vector_retrieval.py',
    'retrieval/graph_retrieval.py',
    'retrieval/web_retrieval.py',
    'retrieval/base_retrieval.py',
    'prompts/base_prompt.py',
]
for f in critical_files:
    if os.path.exists(f):
        print(f"✓ {f}")
    else:
        print(f"❌ MISSING: {f}")

# Quick import test
print("\n" + "=" * 50)
print("TESTING IMPORTS")
print("=" * 50)
try:
    import sys
    sys.path.insert(0, '/content/HMRAG')
    import requests
    print("✓ requests (for Serper API)")
    from langchain_ollama import OllamaLLM
    print("✓ OllamaLLM (langchain-ollama)")
    from lightrag import LightRAG
    print("✓ LightRAG (lightrag-hku)")
    from lightrag.llm.ollama import ollama_model_complete, ollama_embed
    print("✓ ollama_model_complete, ollama_embed (Ollama backend for LightRAG)")
    print("\n✓ All imports successful!")
    print("  LightRAG LLM: Ollama qwen2.5:1.5b (NOT GPT-4o-mini)")
    print("  Web search: Serper API (requests.post)")
except ImportError as e:
    print(f"❌ Import error: {e}")
    print("Try rerunning Step 2 (Install Dependencies)")

## Step 8: Run Inference (Small Test, 5 Examples)
**Phase 1 (Preprocessing)** runs automatically on first launch and builds the knowledge base from training data. This is cached in `./lightrag_workdir/`, so subsequent runs skip this step.

**To force a rebuild**, delete `./lightrag_workdir/.kb_built` (or the entire directory).

Adjust `--data_root` below if your dataset location differs.
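To see which parts of the Phase 1 cache already exist before you launch, you can inspect the working directory directly. This sketch relies only on the two file names this notebook mentions: the `.kb_built` marker, and `vlm_captions.json`, the `{qid: caption}` JSON that `_caption_images` writes in the Step 3 patch. `kb_status` is a hypothetical helper.

```python
import json
import os


def kb_status(working_dir: str = "./lightrag_workdir") -> dict:
    """Report whether Phase 1 will be skipped and how many captions are cached."""
    marker = os.path.join(working_dir, ".kb_built")
    captions_path = os.path.join(working_dir, "vlm_captions.json")
    status = {
        "exists": os.path.isdir(working_dir),
        "kb_built": os.path.isfile(marker),  # True -> Phase 1 is skipped
        "cached_captions": 0,
        "non_empty_captions": 0,
    }
    if os.path.isfile(captions_path):
        with open(captions_path) as f:
            captions = json.load(f)
        status["cached_captions"] = len(captions)
        # _caption_images stores "" for questions without a usable image
        status["non_empty_captions"] = sum(1 for c in captions.values() if c)
    return status


if __name__ == "__main__":
    print(kb_status())
```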

In [None]:
import os
os.chdir('/content/HMRAG')

# The preprocessing step (Phase 1) will build the knowledge base on
# first run and cache it in ./lightrag_workdir/.  To force a rebuild:
#   !rm -rf ./lightrag_workdir
!mkdir -p outputs

# Auto-detect data_root
data_root = ""
for candidate in [
    "./dataset/ScienceQA/data/scienceqa",
    "./dataset/ScienceQA/data",
]:
    if os.path.exists(os.path.join(candidate, "problems.json")):
        data_root = candidate
        break

if not data_root:
    print("❌ Could not find problems.json. Please set data_root manually.")
    print("Searching for it...")
    !find dataset/ -name "problems.json" 2>/dev/null
else:
    print(f"Using data_root: {data_root}")
    
    # Build the command
    serper_key = os.environ.get('SERPER_API_KEY', '')
    hf_token = os.environ.get('HF_TOKEN', '')
    
    cmd = f"""python3 main.py \
    --data_root {data_root} \
    --image_root ./dataset/ScienceQA/data/scienceqa \
    --output_root ./outputs \
    --working_dir ./lightrag_workdir \
    --serper_api_key "{serper_key}" \
    --llm_model_name qwen2.5:1.5b \
    --web_llm_model_name qwen2.5:1.5b \
    --test_split test \
    --test_number 5 \
    --shot_number 0 \
    --label test_run \
    --save_every 5"""
    
    if hf_token:
        cmd += f' --hf_token "{hf_token}"'
    
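    # Print the command with API keys masked so they don't end up in the notebook output
    display_cmd = cmd
    for secret in (serper_key, hf_token):
        if secret:
            display_cmd = display_cmd.replace(secret, '***')
    print(f"\nRunning command:\n{display_cmd}\n")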
    !{cmd}

## Step 9: Run Full Inference
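Once the small test succeeds, run on the full test split (4,241 questions in ScienceQA), which takes much longer. Note that this run differs from the small test in two settings: 2-shot prompting (`--shot_number 2`) and `--use_caption`.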
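Building the command as one f-string is fragile, because a value containing a quote or `$` changes how the shell parses it. An alternative is to build an argument list, quote it with `shlex`, and mask secrets in any printed copy. The flags below are the ones this notebook already passes to `main.py`. `build_main_args` and `masked` are hypothetical helpers, not part of the repo.

```python
import shlex


def build_main_args(data_root, serper_key="", hf_token="", label="full_run",
                    shot_number=2, save_every=50, test_number=None, use_caption=True):
    """Build the main.py argv used in Steps 8-9 as a list (no manual quoting)."""
    args = [
        "python3", "main.py",
        "--data_root", data_root,
        "--image_root", "./dataset/ScienceQA/data/scienceqa",
        "--output_root", "./outputs",
        "--working_dir", "./lightrag_workdir",
        "--serper_api_key", serper_key,
        "--llm_model_name", "qwen2.5:1.5b",
        "--web_llm_model_name", "qwen2.5:1.5b",
        "--test_split", "test",
        "--shot_number", str(shot_number),
        "--label", label,
        "--save_every", str(save_every),
    ]
    if test_number is not None:
        args += ["--test_number", str(test_number)]
    if use_caption:
        args.append("--use_caption")
    if hf_token:
        args += ["--hf_token", hf_token]
    return args


def masked(args, secrets):
    """Shell-quoted copy of args for display, with secret values replaced by ***."""
    hidden = {s for s in secrets if s}  # ignore empty keys
    return shlex.join("***" if a in hidden else a for a in args)
```

In a cell, print `masked(args, [serper_key, hf_token])`. Then set `cmd = shlex.join(args)` and run `!{cmd}`, so output still streams into the cell as in the original cells.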

In [None]:
import os
os.chdir('/content/HMRAG')

# Knowledge base in ./lightrag_workdir is reused from previous runs.
# To force a full rebuild: !rm -rf ./lightrag_workdir

# Auto-detect data_root
data_root = ""
for candidate in [
    "./dataset/ScienceQA/data/scienceqa",
    "./dataset/ScienceQA/data",
]:
    if os.path.exists(os.path.join(candidate, "problems.json")):
        data_root = candidate
        break

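if not data_root:
    raise FileNotFoundError("Could not find problems.json; set data_root manually (see Step 6).")

serper_key = os.environ.get('SERPER_API_KEY', '')
hf_token = os.environ.get('HF_TOKEN', '')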

cmd = f"""python3 main.py \
    --data_root {data_root} \
    --image_root ./dataset/ScienceQA/data/scienceqa \
    --output_root ./outputs \
    --working_dir ./lightrag_workdir \
    --serper_api_key "{serper_key}" \
    --llm_model_name qwen2.5:1.5b \
    --web_llm_model_name qwen2.5:1.5b \
    --test_split test \
    --shot_number 2 \
    --label full_run \
    --save_every 50 \
    --use_caption"""

if hf_token:
    cmd += f' --hf_token "{hf_token}"'

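# Print the command with API keys masked so they don't end up in the notebook output
display_cmd = cmd
for secret in (serper_key, hf_token):
    if secret:
        display_cmd = display_cmd.replace(secret, '***')
print(f"Running command:\n{display_cmd}\n")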
!{cmd}

## Step 10: View & Download Results
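The cell below only prints sample predictions. To turn an output file into an accuracy figure, you can compare the predictions with the `answer` index stored in `problems.json`. This notebook does not show the output format of `main.py`. This sketch therefore assumes that a results JSON maps question ids to either a 0-based index or an option letter (`"A"`, `"(B)"`, ...). Adjust `to_index` if your files look different. `score_predictions` and the results file name are hypothetical.

```python
import json
import os


def to_index(pred):
    """Map a prediction (int, digit string, or option letter) to a 0-based index, else None."""
    if isinstance(pred, bool):
        return None
    if isinstance(pred, int):
        return pred
    text = str(pred).strip().strip("()").upper()
    if text.isdigit():
        return int(text)  # assumed 0-based, like problems.json
    if len(text) == 1 and "A" <= text <= "Z":
        return ord(text) - ord("A")
    return None  # unparseable answers count as wrong


def score_predictions(results: dict, problems: dict) -> dict:
    """Accuracy over the question ids that appear in both results and problems."""
    scored = [qid for qid in results if qid in problems]
    correct = sum(1 for qid in scored if to_index(results[qid]) == problems[qid].get("answer"))
    return {"scored": len(scored), "correct": correct,
            "accuracy": correct / len(scored) if scored else 0.0}


if __name__ == "__main__":
    root = "./dataset/ScienceQA/data/scienceqa"
    results_path = "./outputs/test_run.json"  # hypothetical name; check `ls outputs/`
    if os.path.exists(results_path) and os.path.exists(os.path.join(root, "problems.json")):
        with open(os.path.join(root, "problems.json")) as f:
            problems = json.load(f)
        with open(results_path) as f:
            print(score_predictions(json.load(f), problems))
```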

In [None]:
import os
import json
import glob

os.chdir('/content/HMRAG')

# List output files
print("Output files:")
!ls -lh outputs/

# Load and display results
output_files = sorted(glob.glob('outputs/*.json'))
if output_files:
    for fpath in output_files:
        print(f"\n{'=' * 50}")
        print(f"File: {os.path.basename(fpath)}")
        with open(fpath, 'r') as f:
            results = json.load(f)
        print(f"Total results: {len(results)}")
        print("Sample results:")
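        # main.py's output format isn't shown here; handle both a qid->answer dict and a list
        items = results.items() if isinstance(results, dict) else enumerate(results)
        for qid, answer in list(items)[:5]:
            print(f"  Question {qid}: Answer = {answer}")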
else:
    print("No output files found yet. Run inference first.")

In [None]:
# Download results to your local machine
from google.colab import files
import os

os.chdir('/content/HMRAG')

# Zip all outputs
!zip -r outputs.zip outputs/
files.download('outputs.zip')
print("✓ Download started!")

## Troubleshooting

### "Dimension mismatch" / "embedding_dim" errors
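This happens when `./lightrag_workdir` still holds vector stores created by an earlier run with a different embedding dimension (for example, the 1024-dim default before the Step 3 patch). From `/content/HMRAG`, delete the directory and rerun:
```
rm -rf ./lightrag_workdir
```
This also removes the cached VLM captions (`vlm_captions.json`), so Phase 1 will caption the images again.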

### Preprocessing takes too long
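Only the first run pays this cost. Phase 1 builds the knowledge base from training data and writes the `.kb_built` marker, and later runs skip Phase 1 while that marker exists. Re-running Step 3 deletes `./lightrag_workdir`, so avoid re-running it unless you need to. To force a rebuild deliberately: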
```
rm ./lightrag_workdir/.kb_built
```

### VLM captioning is slow
Image captioning uses `Qwen2-VL-2B-Instruct`. On Colab T4 GPU, expect ~1-2 seconds per image. Captions are cached in `./lightrag_workdir/vlm_captions.json`.

### "Ollama connection refused"
Re-run the Ollama setup cell (Step 4) or the verification cell (Step 7).
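When the server is slow to start, polling the API tends to be more reliable than a fixed `sleep`. The minimal standard-library sketch below returns `True` once `/api/tags` answers with HTTP 200. The `wait_for_ollama` helper is hypothetical and does not come from the repo.

```python
import time
import urllib.request


def wait_for_ollama(url="http://localhost:11434/api/tags", timeout=60.0, interval=2.0) -> bool:
    """Poll url until it answers with HTTP 200 or `timeout` seconds elapse."""
    deadline = time.monotonic() + timeout
    while True:
        try:
            with urllib.request.urlopen(url, timeout=interval) as resp:
                if resp.status == 200:
                    return True
        except OSError:
            pass  # URLError, connection refused and socket timeouts are all OSError
        if time.monotonic() >= deadline:
            return False
        time.sleep(interval)


if __name__ == "__main__":
    print("Ollama ready" if wait_for_ollama(timeout=5) else "Ollama not responding; rerun Step 4")
```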

### "Serper API error"
Make sure your `SERPER_API_KEY` is set correctly in Step 5. Get a free key at https://serper.dev (2,500 free searches).

### "CUDA out of memory"
The notebook uses `qwen2.5:1.5b` (text) and `Qwen/Qwen2-VL-2B-Instruct` (vision). If you still run out of memory:
- Use `Runtime → Change runtime type → T4 GPU`
- Restart runtime and rerun all cells
- The VLM is automatically unloaded after preprocessing to free GPU memory

### LightRAG LLM
LightRAG is configured to use `ollama_model_complete` with `qwen2.5:1.5b` ‚Äî it does **NOT** use GPT-4o-mini. No OpenAI API key is needed.