In [None]:
from dotenv import load_dotenv
from datasets import load_dataset
from typing import List, Any
import torch
import json
import os
from PIL import Image
from tqdm import tqdm
from unsloth import FastVisionModel
from unsloth.chat_templates import get_chat_template
from transformers import AutoTokenizer, AutoModelForCausalLM
# Load variables from a local .env file (if any) into the process environment.
load_dotenv()

# Fail fast when no usable Hugging Face token is available. An empty value
# (e.g. a bare `HF_TOKEN=` line in .env) is treated the same as a missing one,
# because it cannot authenticate either.
if not os.environ.get('HF_TOKEN'):
    raise ValueError("HF_TOKEN environment variable is not set")

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!


True

In [2]:
# Load the CRAG-MM single-turn public dataset; later cells read its
# 'validation' split by positional index.
dataset = load_dataset("crag-mm-2025/crag-mm-single-turn-public")

#### **Using Llama 8B to generate sub-queries**

In [3]:
# Single source of truth for the checkpoint used to generate sub-queries.
SUBQUERY_MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"

# `dtype` / `device_map` are model-loading options; a tokenizer holds no
# weights, so they are only passed to the model.
tokenizer = AutoTokenizer.from_pretrained(SUBQUERY_MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(SUBQUERY_MODEL_ID, dtype="auto", device_map="auto")

# Reuse EOS as the padding token so the tokenizer has a pad id defined.
tokenizer.pad_token_id = tokenizer.eos_token_id
print("Model device:", model.device)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Model device: cuda:0


In [4]:
# Task 2 implementation to generate sub question and query it to web search 
class Task2SubQuestionGenerator:
    """Generate a retrieval sub-question for an (image, query) pair and search the web with it.

    The flow is: image search -> cleaned entity payload -> LLM-generated
    sub-question -> web search -> snippets formatted as numbered documents.
    """

    def __init__(self, model, tokenizer) -> None:
        """Store the generation model/tokenizer and build the search pipeline.

        Args:
            model: Causal LM used to generate the sub-question.
            tokenizer: Tokenizer matching ``model``; must provide a chat template.
        """
        # model for subquestion generation
        self.model = model
        self.tokenizer = tokenizer
        # Imported here so that defining the class does not require the
        # cragmm_search package; it is only needed once an instance is built.
        from cragmm_search.search import UnifiedSearchPipeline
        # initiate both image and web search API
        ## validation
        self.search_pipeline = UnifiedSearchPipeline(
            image_model_name="openai/clip-vit-large-patch14-336",
            image_hf_dataset_id="crag-mm-2025/image-search-index-validation",
            text_model_name="BAAI/bge-large-en-v1.5",
            web_hf_dataset_id="crag-mm-2025/web-search-index-validation",
        )

    def remove_none_recursively(self, data) -> Any:
        """Recursively strip empty values from nested dicts/lists.

        Dict entries are dropped when their cleaned value is ``None``, ``{}``,
        ``[]`` or the placeholder string ``'<>'``. List items are dropped when
        their cleaned value is ``None``, ``{}`` or ``[]``. Any other value is
        returned unchanged.

        Returns:
            The cleaned structure, of the same container type as ``data``
            (dict, list, or the original scalar).
        """
        if isinstance(data, dict):
            # First, recursively clean the values
            cleaned_data = {
                k: self.remove_none_recursively(v)
                for k, v in data.items()
            }
            # Then, remove keys where the cleaned value is None or empty
            return {
                k: v for k, v in cleaned_data.items()
                if v is not None and v not in [{}, [], '<>']
            }
        elif isinstance(data, list):
            cleaned_list = [self.remove_none_recursively(item) for item in data]
            # Remove None and empty items from the list
            return [
                item for item in cleaned_list
                if item is not None and item != {} and item != []
            ]
        else:
            return data

    def generate_sub_question(self, query, image) -> str:
        """Generate one sub-question for ``query`` using entities found for ``image``.

        Args:
            query: The user's original question.
            image: Image passed to the search pipeline to look up entities.

        Returns:
            The first non-empty line of the model's answer, or ``""`` when the
            model produced no text.
        """
        # NOTE(review): the trailing "User Input: [Insert User Input Here]" line
        # is a placeholder that is never substituted; the real input is sent in
        # the user message. Left untouched so generations stay comparable.
        SYSTEM_PROMPT = """You are an expert query decomposition engine for a Retrieval-Augmented Generation (RAG) system. Your task is to analyze the user's complex input and break it down into distinct, simple sub-queries that are necessary to gather facts to answer the original question.

        Logic for Decomposition:
        1. Analyze the user's request to identify the underlying logical steps or variables needed.
        2. If the user asks a comparison question (e.g., "Is X better than Y?"), generate separate queries for the attributes of X and Y.
        3. If the user asks a conditional question (e.g., "Can a Toyota drive from A to B with 5 gallons?"), generate separate queries for the distance between A and B, and the fuel efficiency/tank capacity of the Toyota.
        4. Ensure every sub-query focuses on retrieving specific, factual information.

        Constraints:
        - Output only the list of sub-queries, one per line.
        - Do not include any conversational phrases, explanations, or numbering.
        - Focus on keywords and specific entities.
        - Do not include quotation marks.

        User Input: [Insert User Input Here]
        """
        # Look up entities for the image first, then strip empty fields so the
        # prompt only carries populated attributes.
        image_response = self.search_pipeline(image, k=1)
        extracted_entities = self.remove_none_recursively(image_response)
        text_query = f"query: {query}. Entity:'entity_name': '{extracted_entities}'. Generate sub-question that will help answer the query."
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": text_query},
        ]
        prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        # The rendered chat template already carries the special tokens, so the
        # tokenizer must not prepend another BOS. Use the instance's own model
        # for device placement rather than a notebook-level global.
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            add_special_tokens=False,
        ).to(self.model.device)

        # Take the pad id from the tokenizer instead of a hard-coded vocabulary id.
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = self.tokenizer.eos_token_id

        with torch.no_grad():
            output = self.model.generate(
                **inputs,
                max_new_tokens=512,
                do_sample=False,
                temperature=1.0,
                pad_token_id=pad_token_id,
            )

        # Decode only the newly generated tokens. Splitting the full decoded
        # text on the word "assistant" would truncate any answer that happens
        # to contain that word.
        prompt_length = inputs["input_ids"].shape[1]
        model_response = self.tokenizer.decode(
            output[0][prompt_length:],
            skip_special_tokens=True,
        ).strip()

        # One sub-question per line; drop blanks and duplicates, keep order.
        sub_questions = list(dict.fromkeys(
            line.strip() for line in model_response.splitlines() if line.strip()
        ))
        return sub_questions[0] if sub_questions else ""

    def sub_question_web_search(self, image, query: str, k=1) -> str:
        """Generate a sub-question for ``query``/``image`` and return formatted web results.

        Args:
            image: Image used for the entity lookup.
            query: The user's original question.
            k: Number of results requested from the search pipeline.

        Returns:
            Newline-joined ``"Document N: <snippet>"`` lines.
        """
        sub_question = self.generate_sub_question(query, image)
        response = self.search_pipeline(sub_question, k=k)
        return self.clean_up_search_results(response)

    def clean_up_search_results(self, results: List[Any]) -> str:
        """Format search results as numbered ``Document N: <snippet>`` lines.

        NOTE(review): the exact return shape of the search pipeline is not
        visible here, so both shapes are accepted:
          * a list of per-query hit lists -> the first hit of each list is used;
          * a flat list of hit dicts      -> every hit is used.
        Empty per-query lists are skipped instead of raising ``IndexError``.

        Args:
            results: Search results; each hit must provide a ``'page_snippet'`` key.

        Returns:
            The formatted documents joined by newlines (``""`` when there are none).
        """
        retrieval_documents = []
        for item in results or []:
            if isinstance(item, dict):
                top_hit = item
            elif item:
                top_hit = item[0]
            else:
                continue
            retrieval_documents.append(
                f"Document {len(retrieval_documents) + 1}: {top_hit['page_snippet']}"
            )
        return "\n".join(retrieval_documents)

In [None]:
# Build the sub-question generator around the Llama model/tokenizer loaded
# above; its constructor also creates the UnifiedSearchPipeline used for both
# image and web search.
question_generator = Task2SubQuestionGenerator(model, tokenizer)

Using device: cuda
Loading web search data from Hugging Face...


Fetching 7 files:   0%|          | 0/7 [00:00<?, ?it/s]

[backoff|INFO]Backing off send_request(...) for 0.3s (requests.exceptions.ConnectionError: HTTPSConnectionPool(host='us.i.posthog.com', port=443): Max retries exceeded with url: /batch/ (Caused by NewConnectionError('<urllib3.connection.HTTPSConnection object at 0x76606cb483a0>: Failed to establish a new connection: [Errno 111] Connection refused')))


Successfully loaded collection with 904899 entries


[backoff|INFO]Backing off send_request(...) for 1.2s (requests.exceptions.ConnectionError: HTTPSConnectionPool(host='us.i.posthog.com', port=443): Max retries exceeded with url: /batch/ (Caused by NewConnectionError('<urllib3.connection.HTTPSConnection object at 0x76606cb48a30>: Failed to establish a new connection: [Errno 111] Connection refused')))
[backoff|INFO]Backing off send_request(...) for 3.5s (requests.exceptions.ConnectionError: HTTPSConnectionPool(host='us.i.posthog.com', port=443): Max retries exceeded with url: /batch/ (Caused by NewConnectionError('<urllib3.connection.HTTPSConnection object at 0x76606cb490c0>: Failed to establish a new connection: [Errno 111] Connection refused')))
[backoff|ERROR]Giving up send_request(...) after 4 tries (requests.exceptions.ConnectionError: HTTPSConnectionPool(host='us.i.posthog.com', port=443): Max retries exceeded with url: /batch/ (Caused by NewConnectionError('<urllib3.connection.HTTPSConnection object at 0x76608ff73cd0>: Failed to 

In [None]:
# Number of web results requested per sub-question. It is part of the output
# file name so the file written here ("Task2_web_search_results_k_1.jsonl")
# is the one the answer-generation step loads.
SEARCH_TOP_K = 1
path = f"Task2_web_search_results_k_{SEARCH_TOP_K}.jsonl"

folder = "final_outputs/crop"

if not os.path.exists(folder):
    raise FileNotFoundError("You have to run Task1 first to get the crop image.")

# Crop images are named "<integer>.png"; sort them numerically.
# NOTE(review): images are paired with dataset rows by *position* in this
# sorted list — presumably the integer file names are dataset indices with no
# gaps; confirm against Task1's output.
png_files = sorted(
    [f for f in os.listdir(folder) if f.lower().endswith(".png")],
    key=lambda x: int(os.path.splitext(x)[0])
)

# Resume support. Records are tracked by dataset index so that two items
# sharing the same query text each get their own line (downstream code looks
# results up by position). Records written before the "index" field existed
# are still honoured through their query text.
processed_indices = set()
processed_queries = set()

if os.path.exists(path):
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                # Tolerate a truncated/corrupt line left by an interrupted run.
                continue
            if "index" in record:
                processed_indices.add(record["index"])
            else:
                processed_queries.add(record["query"])

print(f"Found {len(processed_indices) + len(processed_queries)} items already processed.")

validation_split = dataset['validation']

with open(path, "a", encoding="utf-8") as f:

    for index in tqdm(range(len(png_files)), desc="Processing items"):
        # SKIP if we have done this one already (checked before any I/O).
        if index in processed_indices:
            continue

        item = validation_split[index]
        current_query = item['turns']['query'][0]

        # Legacy records carry no index; fall back to matching on query text.
        if current_query in processed_queries:
            continue

        # Open the crop only when it is actually needed, and release the file
        # handle as soon as the RGB copy exists.
        image_path = os.path.join(folder, png_files[index])
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")

        # Run the sub question generation and web search
        response = question_generator.sub_question_web_search(
            image,
            current_query,
            k=SEARCH_TOP_K
        )

        entry = {
            "index": index,
            "query": current_query,
            "response": response
        }
        f.write(json.dumps(entry) + "\n")
        # Flush per record so an interruption loses at most the current item.
        f.flush()
        processed_indices.add(index)

In [None]:
# Read the retrieval results back into memory: one JSON object per line,
# kept in file order.
path = "Task2_web_search_results_k_1.jsonl"
data = []

with open(path, "r", encoding="utf-8") as f:
    data.extend(json.loads(line) for line in f)

In [None]:
def build_rag_prompt(query, retrieval_documents, image, tokenizer, model):
    """
    Build the RAG prompt for one image/question pair and return tokenized inputs.

    The system message restricts the answer to the retrieved documents; the
    user message carries the ``<|image|>`` placeholder, the documents and the
    question. The messages are rendered with the tokenizer's chat template and
    then tokenized together with the image.

    Args:
        query: User question.
        retrieval_documents: Retrieved document snippets, already formatted as text.
        image: Image passed to ``tokenizer`` alongside the prompt.
        tokenizer: Object providing ``apply_chat_template`` and callable as
            ``tokenizer(image, text, ...)``.
        model: Model instance; only its ``device`` attribute is used.

    Returns:
        The tokenized inputs, moved to ``model.device``, ready for
        ``model.generate(**inputs)``.
    """
    system_message = (
        "You are an expert assistant. Answer the user question ONLY based on "
        "the provided retrieved documents. If the documents do not contain "
        "enough information, say 'I do not have enough information to answer.' "
        "Do NOT hallucinate."
    )
    user_message = (
        f"<|image|>Retrieved Documents:\n{retrieval_documents}\n\n"
        f"User Question:\n{query}"
    )
    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": user_message},
    ]

    # Convert to final prompt using Llama chat template
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    # Tokenize with image and move to device. The rendered template already
    # contains the special tokens, so adding them again here would duplicate
    # the beginning-of-sequence token.
    inputs = tokenizer(
        image,
        prompt,
        add_special_tokens=False,
        return_tensors="pt",
    ).to(model.device)

    return inputs

In [None]:
# Load the 4-bit Llama 3.2 11B Vision Instruct checkpoint through Unsloth;
# from_pretrained returns the model together with its tokenizer object.
llama_model, llama_tokenizer = FastVisionModel.from_pretrained(
    model_name="unsloth/Llama-3.2-11B-Vision-Instruct-unsloth-bnb-4bit",
    load_in_4bit=True,
)

# Prepare the model for generation via Unsloth's inference hook.
FastVisionModel.for_inference(llama_model)

# Attach Unsloth's "llama-3.2" chat template to the tokenizer.
llama_tokenizer = get_chat_template(llama_tokenizer, chat_template="llama-3.2")

In [None]:
output_path = "Task2_final_answers.jsonl"

# Check what's already processed
processed_indices = set()
if os.path.exists(output_path):
    with open(output_path, "r", encoding="utf-8") as f:
        for line in f:
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                # Tolerate a truncated/corrupt line left by an interrupted run.
                continue
            processed_indices.add(record['index'])

print(f"Found {len(processed_indices)} items already processed.")

validation_split = dataset['validation']

# NOTE(review): the object returned by FastVisionModel may be a multimodal
# processor wrapping the text tokenizer — confirm. Resolve the pad id from the
# wrapped tokenizer when there is one, otherwise from the object itself.
text_tokenizer = getattr(llama_tokenizer, "tokenizer", llama_tokenizer)
pad_token_id = text_tokenizer.pad_token_id
if pad_token_id is None:
    pad_token_id = text_tokenizer.eos_token_id

# Only items that have retrieval results can be answered. Iterating over the
# number of *already processed* items would do nothing on a fresh run and
# would only revisit finished items on a resumed one.
num_items = min(len(data), len(validation_split))

# Process and save answers
with open(output_path, "a", encoding="utf-8") as f:
    for idx in tqdm(range(num_items), desc="Generating answers"):
        # Skip already processed items before touching the dataset row.
        if idx in processed_indices:
            continue

        item = validation_split[idx]
        if item['image'] is None:
            continue

        query = item['turns']['query'][0]
        image = item['image']

        # Retrieval results are matched to dataset rows by position. Refuse to
        # pair a question with documents retrieved for a different question.
        retrieved = data[idx]
        if retrieved.get('query') != query:
            tqdm.write(f"Skipping index {idx}: retrieval record belongs to a different query.")
            continue
        retrieval_documents = retrieved['response']

        inputs = build_rag_prompt(query, retrieval_documents, image, llama_tokenizer, llama_model)

        with torch.no_grad():
            # Greedy decoding; `temperature` is omitted because it is ignored
            # when do_sample=False and 0.0 is not a valid temperature.
            output = llama_model.generate(
                **inputs,
                max_new_tokens=100,
                do_sample=False,
                pad_token_id=pad_token_id,
            )

        # Decode only the newly generated tokens. Splitting the full decoded
        # text on the word "assistant" would truncate any answer containing it.
        prompt_length = inputs["input_ids"].shape[1]
        final_answer = llama_tokenizer.decode(
            output[0][prompt_length:],
            skip_special_tokens=True,
        ).strip()

        entry = {
            "index": idx,
            "query": query,
            "answer": final_answer
        }
        f.write(json.dumps(entry) + "\n")
        # Flush per record so an interruption loses at most the current item.
        f.flush()
    