# Retrieval-Augmented Generation (RAG) with VLMs — Demo

This notebook compares the RAG capabilities of different VLM architectures.

## 1. Environment Setup
Derive the repository root from the notebook's working directory and verify that the Phi-3.5 runner script exists.

In [None]:
# Resolve the notebook directory and the repository root that sits one level above it.
from pathlib import Path

notebook_dir = Path.cwd()
# Assumes the notebook is launched from a directory directly below the repo root
# (e.g. /.../vlm_laboratories/vlm_laboratories/<notebook dir>).
repo_root = notebook_dir.parent
# Location of the Phi-3.5 runner module imported by the next cell.
script_path = repo_root.joinpath('prompt_engineering_lab', 'live_vlm_test', 'vlm_runners.py')

print('Notebook directory:', notebook_dir)
print('Derived repo root:', repo_root)
print('Phi3.5 runner script exists:', script_path.exists())


## 2. Import VLMRunnerPhi
Configure pyglet for headless use, then load the class from the `vlm_runners.py` script.

In [None]:
# Import VLMRunnerPhi (headless safe)
# pyglet is configured for headless operation *before* it is imported, and before
# vlm_runners is imported (presumably that module pulls in pyglet — TODO confirm).
# Statement order in this cell matters.
import os, sys, importlib
os.environ['PYGLET_HEADLESS'] = 'true'
# Remove DISPLAY so pyglet does not try to open an X display.
os.environ.pop('DISPLAY', None)
# Evict any already-imported pyglet modules so the import below re-executes pyglet
# with the headless settings (relevant when this cell is re-run in the same kernel).
for m in list(sys.modules):
    if m.startswith('pyglet'):
        del sys.modules[m]
importlib.invalidate_caches()
import pyglet
pyglet.options['headless'] = True
# Disable pyglet's hidden shadow window.
pyglet.options['shadow_window'] = False
# Make the live_vlm_test directory importable; `repo_root` comes from the setup cell.
module_dir = repo_root / 'prompt_engineering_lab' / 'live_vlm_test'
if str(module_dir) not in sys.path:
    sys.path.append(str(module_dir))
from vlm_runners import VLMRunnerPhi
print('Imported VLMRunnerPhi from', module_dir)


## 3. Smoke Test Generation
Instantiate the runner and generate a short description of the first example image (falling back to a synthetic solid-colour image if no example images are found).

In [None]:
# Smoke test (simple description)
from PIL import Image
import torch, os, gc
# NOTE(review): PYTORCH_CUDA_ALLOC_CONF is read when the CUDA caching allocator is first
# initialised; setting it after `import torch` only takes effect if no CUDA memory has
# been allocated yet in this kernel.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True,max_split_size_mb:64'
torch.cuda.empty_cache(); gc.collect()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)
example_img_dir = repo_root / 'rag_database_lab' / 'example_images'
# Glob once and take the alphabetically first PNG, if any.
example_pngs = sorted(example_img_dir.glob('*.png'))
example_img_path = example_pngs[0] if example_pngs else None
print('Using example image:', example_img_path)
try:
    # Pinned to CPU for portability, independent of `device` reported above.
    runner = VLMRunnerPhi(torch.device('cpu'))
    if example_img_path:
        # The context manager closes the file handle; convert() returns a new image.
        with Image.open(example_img_path) as src:
            test_img = src.convert('RGB')
    else:
        # Synthetic solid-colour fallback when no example images exist.
        test_img = Image.new('RGB', (224, 224), color=(100, 120, 140))
    out = runner.generate(test_img, 'Describe the scene briefly.')
    print(out[:500])
except Exception as e:
    # Best-effort smoke test: report the failure instead of stopping the notebook.
    print('Generation failed:', e)


## 4. Load Traffic Rules
Load and normalize rules from `rules.json` for retrieval tests.

In [None]:
# Load rules
import json, pandas as pd
rules_path = repo_root / 'rag_database_lab' / 'rules.json'
with open(rules_path, 'r', encoding='utf-8') as f:
    raw_rules = json.load(f)

# Keep only well-formed entries (dicts with an 'id' and a *string* 'rule_text').
# The text is whitespace-normalised: str.split() with no argument already drops
# leading/trailing whitespace and collapses internal runs of spaces/newlines.
rules = []
skipped = 0
for r in raw_rules:
    if isinstance(r, dict) and 'id' in r and isinstance(r.get('rule_text'), str):
        rules.append({
            'id': r['id'],
            'text': ' '.join(r['rule_text'].split())
        })
    else:
        skipped += 1
if skipped:
    # Make dropped entries visible instead of silently ignoring them.
    print(f'WARNING: skipped {skipped} malformed rule entries')
print(f'Loaded {len(rules)} rules')
pd.DataFrame(rules)


## 5. Select an Example Image
Randomly choose one of the example traffic-scene images to evaluate the rules against. Re-run the next cell to sample a different image.

In [None]:
import random
from IPython.display import display

images_dir = repo_root / 'rag_database_lab' / 'example_images'
all_images = sorted(images_dir.glob('*.png'))
if not all_images:
    raise FileNotFoundError(f'No PNG images found in {images_dir}')

# Unseeded on purpose: re-running the cell samples a different image.
selected_image = random.choice(all_images)
print(f'Selected image: {selected_image.name}')

# Display the chosen image; the context manager releases the file handle afterwards
# (display() renders the image immediately, so it is not needed past this point).
try:
    from PIL import Image
    with Image.open(selected_image) as img:
        display(img)
except Exception as e:
    print('Could not open image with PIL:', e)

selected_image_path = str(selected_image)  # keep a str path for model calls

## 6. Iterative Rule Evaluation
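We ask Phi about each rule separately to obtain an applicability score. The prompt requests a JSON object with the fields `rule_id`, `score` (0–100) and `reason`. Note: the comparison in section 8 clamps parsed scores to the range [0, 1], so scores on a 0–100 scale will saturate unless they are rescaled.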

In [None]:
import time
from statistics import mean

# Ensure we have the runner instance; reuse the one from the smoke-test cell if present.
try:
    runner  # noqa: F821
except NameError:
    runner = None

if runner is None:
    # Fallback import if not already done
    from vlm_runners import VLMRunnerPhi  # type: ignore
    runner = VLMRunnerPhi(torch.device("cpu"))

# Per-rule instruction header. Fail fast, as the batch cell does for its instruction
# file. Otherwise a missing file would cause a NameError below, or silently reuse a
# stale `base_instruction` from an earlier run.
prompt_txt_path = repo_root / 'rag_database_lab' / 'base_instruction.txt'
if not prompt_txt_path.exists():
    raise FileNotFoundError(f'Base instruction file not found: {prompt_txt_path}')
base_instruction = prompt_txt_path.read_text(encoding='utf-8').strip()

# Load the image once (closing the file); every rule prompt reuses this RGB copy.
with Image.open(selected_image_path) as src:
    image = src.convert('RGB')

iterative_results = []  # raw model outputs, index-aligned with `rules`
start_time = time.time()
for r in rules:  # one model call per rule
    rule_id = r["id"]
    rule_text = r["text"]
    prompt = (
        f"{base_instruction}\nRule ID: {rule_id}, rule {rule_text}\n"
        f"Analyze applicability. Return JSON only."
    )

    raw = runner.generate(image, prompt)
    iterative_results.append(raw)
    print(raw)

print(f'Iterative evaluation of {len(iterative_results)} rules took {time.time() - start_time:.1f}s')

## 7. Batch Evaluation (All Rules at Once)
All rules are provided in a single prompt that requests a JSON array of rule IDs, ordered by their applicability to the displayed image.

In [None]:
# One "<id>. <text>" line per candidate rule.
rules_block_lines = [f"{rule['id']}. {rule['text']}" for rule in rules]
rules_block = "\n".join(rules_block_lines)

# The batch instruction header is mandatory: stop early when it is missing.
batch_txt_path = repo_root / 'rag_database_lab' / 'batch_instruction.txt'
if not batch_txt_path.exists():
    raise FileNotFoundError(f'Batch instruction file not found: {batch_txt_path}')
batch_instruction = batch_txt_path.read_text(encoding='utf-8').strip()

batch_prompt = "".join([
    f"{batch_instruction}\nCandidate Rules:\n{rules_block}\n",
    "Analyze the image and produce the output in the correct form.",
])

# A single model call over the whole rule list, reusing the RGB `image` from the iterative cell.
batch_raw = runner.generate(image, batch_prompt)
print(batch_raw)



## 8. Compare Iterative vs Batch Results
Iterative scores and batch rank positions are aligned by `rule_id`. We then compute their Pearson correlation and inspect agreement among the top-ranked rules.

In [None]:
# Repair comparison logic for batch prompt that returns ONLY an ordered JSON array of rule_id strings.
# Batch output (batch_raw) contains strictly: ["RULE_ID_1", "RULE_ID_2", ..., "RULE_ID_N"]. No scores.
# Iterative results still include per-rule JSON objects with score + reason.

import re, json
from math import sqrt

# Helpers

def _extract_first_json_obj(text: str):
    m = re.search(r'\{[^{}]*\}', text, re.DOTALL)
    if not m:
        return None
    snippet = m.group(0)
    cleaned = snippet.replace("'", '"')
    cleaned = re.sub(r',\s*}', '}', cleaned)
    try:
        return json.loads(cleaned)
    except Exception:
        return None


def _norm_score(v):
    try:
        x = float(v)
    except Exception:
        return 0.0
    return max(0.0, min(1.0, x))

# --- Helpers for this comparison ---------------------------------------------

def _parse_batch_order(text):
    """Return the first JSON array in ``text`` as a list of str ids ([] if absent/invalid).

    Single quotes become double quotes and a trailing comma before ``]`` is dropped,
    to tolerate slightly malformed model output.
    """
    m = re.search(r'\[.*?\]', text, re.DOTALL)
    if not m:
        return []
    cleaned = m.group(0).replace("'", '"')
    cleaned = re.sub(r',\s*]', ']', cleaned)
    try:
        data = json.loads(cleaned)
    except ValueError:
        return []
    return [str(x) for x in data] if isinstance(data, list) else []


def _unique_known_ids(ids, expected):
    """Drop ids not in ``expected`` and repeated ids, keeping first-occurrence order."""
    known = set(expected)
    return list(dict.fromkeys(rid for rid in ids if rid in known))


def _recover_order_from_text(text, expected):
    """Return the ids of ``expected`` found in ``text``, ordered by first occurrence.

    The boundary assertions stop short ids from matching inside longer tokens
    (e.g. '1' inside '10', '21' or '3.1'), which a plain substring test would accept.
    """
    first_pos = {}
    for rid in expected:
        m = re.search(r'(?<![\w.])' + re.escape(rid) + r'(?!\w|\.\w)', text)
        if m:
            first_pos[rid] = m.start()
    return sorted(first_pos, key=first_pos.get)


def _pearson(xs, ys):
    """Pearson correlation of two equal-length sequences.

    Returns 0.0 when undefined (no pairs, or either sequence is constant).
    """
    n = len(xs)
    if n == 0:
        return 0.0
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    num = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    den = sqrt(sum((x - mean_x) ** 2 for x in xs) * sum((y - mean_y) ** 2 for y in ys))
    return num / den if den else 0.0


# --- Parse iterative results ----------------------------------------------------
# One raw model string per rule, in `rules` order. zip() stops at the shorter list,
# so an interrupted iterative run gives a shorter `parsed_iter` instead of an IndexError.
parsed_iter = []
for rule, raw in zip(rules, iterative_results):
    obj = _extract_first_json_obj(raw) or {}
    parsed_iter.append({
        'rule_id': rule['id'],
        'rule_text': rule['text'],
        'score': _norm_score(obj.get('score', 0)),
        'reason': str(obj.get('reason', ''))[:300],
        'raw': raw[:350],
    })

print(f"Parsed {len(parsed_iter)} iterative results with scores.")

# --- Parse batch ordering (array of rule_ids, most applicable first) -------------
expected_ids = [str(r['id']) for r in rules]
n_expected = len(set(expected_ids))
raw_batch_ids = _parse_batch_order(batch_raw)
# Duplicates and unknown ids would otherwise inflate N and distort the ranks.
batch_order = _unique_known_ids(raw_batch_ids, expected_ids)
if len(batch_order) != len(raw_batch_ids):
    print(f"WARNING: Ignored {len(raw_batch_ids) - len(batch_order)} duplicate or unknown rule_ids in batch output.")
if len(batch_order) != n_expected:
    print("WARNING: Batch output is missing rule_ids. Attempting recovery.")
    # Fallback: ids mentioned anywhere in the raw text, ordered by where they first
    # appear (i.e. the model's ordering, not the rules-file ordering).
    recovered = _recover_order_from_text(batch_raw, expected_ids)
    if len(recovered) == n_expected:
        batch_order = recovered
    else:
        print(f"WARNING: Recovery failed; keeping partial ordering ({len(batch_order)}/{n_expected}).")

print(f"Parsed batch ordering of {len(batch_order)} rule_ids.")

# Rank mapping (1-based). Earlier position => higher applicability.
batch_rank = {rid: i for i, rid in enumerate(batch_order, start=1)}
N = len(batch_order)

# Correlate iterative scores with batch ranking. The rank is inverted
# (rank_score = N + 1 - rank) so that for both variables higher means more applicable.
# Rules absent from the batch ordering are skipped.
pairs = [
    (entry['score'], N + 1 - batch_rank[str(entry['rule_id'])])
    for entry in parsed_iter
    if str(entry['rule_id']) in batch_rank
]
corr = _pearson([x for x, _ in pairs], [y for _, y in pairs])

print(f"Correlation (iterative score vs batch rank position): {corr:.3f} over {len(pairs)} rules")

# --- Top-k agreement (first k of each ordering) ------------------------------------
K = 5
iter_top_ids = [str(e['rule_id']) for e in sorted(parsed_iter, key=lambda e: e['score'], reverse=True)[:K]]
batch_top_ids = batch_order[:K]
overlap = sorted(set(iter_top_ids) & set(batch_top_ids))
print(f"Top-{K} overlap count: {len(overlap)} | IDs: {overlap}")

# Display alignment for overlapping rules (dict lookups instead of a scan per id).
iter_score_by_id = {str(e['rule_id']): e['score'] for e in parsed_iter}
rule_text_by_id = {str(r['id']): r['text'] for r in rules}
for rid in overlap:
    text = rule_text_by_id.get(rid, '')
    print(f"Rule {rid}: iter_score={iter_score_by_id[rid]:.2f} batch_rank={batch_rank.get(rid)} | {text[:90]}")

# Build comparison summary structure
comparison_summary = {
    'correlation_iterative_score_vs_batch_rank': corr,
    'iter_top': iter_top_ids,
    'batch_top': batch_top_ids,
    'top_overlap': overlap,
    'iter_count': len(parsed_iter),
    'batch_count': N
}
comparison_summary

## 9. CLIP Embedding-Based Rule Retrieval
Use a dual-encoder (CLIP) to embed the selected image and all rule texts, rank rules by cosine similarity, and compare the result with the Phi batch ordering and iterative scores.

In [None]:
# Initialize CLIP retriever and compute rule/image embeddings
from clip_rule_retrieval import CLIPRuleImageRetriever, compute_metrics
from PIL import Image
import torch

# CLIP runs on the GPU when one is available.
clip_device = 'cuda' if torch.cuda.is_available() else 'cpu'
retriever = CLIPRuleImageRetriever(device=clip_device)
print(f'Loaded CLIP model on {clip_device}')

# Reuse `image` from the earlier cells; reload it only when it is not defined yet.
if 'image' not in globals():
    image = Image.open(selected_image_path).convert('RGB')

# Text side: `rules` is the list of {'id', 'text'} dicts built in the rules cell.
rule_embs, rule_ids, rule_texts = retriever.embed_rules(rules)
print(f'Embedded {len(rule_ids)} rules -> tensor shape {tuple(rule_embs.shape)}')

# Image side.
img_emb = retriever.embed_image(image)
print('Image embedding shape:', tuple(img_emb.shape))

# Display ranking: at most 10 rules.
top_k = min(10, len(rule_ids))
ranking = retriever.rank_rules(img_emb, rule_embs, rule_ids, rule_texts, top_k=top_k)

print('\nTop CLIP rule ranking:')
for r in ranking:
    sim = r['similarity']
    snippet = r['rule_text'][:90]
    print(f"rule_id={r['rule_id']} sim={sim:.4f} | {snippet}")

# Metric inputs: truncated and full CLIP orderings plus iterative scores keyed by rule_id.
clip_rank_ids = [str(hit['rule_id']) for hit in ranking]  # truncated list
full_ranking = retriever.rank_rules(img_emb, rule_embs, rule_ids, rule_texts, top_k=None)
full_clip_rank_ids = [str(hit['rule_id']) for hit in full_ranking]
iter_scores_map = {str(e['rule_id']): e['score'] for e in parsed_iter}

# Full CLIP ordering vs the VLM batch ordering (empty if batch parsing failed).
metrics = compute_metrics(full_clip_rank_ids, batch_order, iter_scores_map, k=5)
print('\nExample metrics summary:')
for k, v in metrics.items():
    print(f'{k}: {v:.4f}' if isinstance(v, float) else f'{k}: {v}')

clip_metrics = metrics  # expose for later cells