# VLM Comparison: Base vs Fine-tuned

Compare your fine-tuned VLM against the base model on validation images to evaluate the effect of fine-tuning.

## Usage
1. Set `EXP_DIR` to your experiment directory, relative to the notebook's working directory (e.g., `../runs/exp_20250102_xxx`)
2. Run all cells - paths are auto-loaded from experiment config

## Notes
- Config is auto-loaded from `{EXP_DIR}/config.json` (saved during training)
- For end-to-end testing (YOLO + VLM), use `predict.py --vlm`

## 1. Configuration

In [None]:
import json
from pathlib import Path

# ============================================================
# ONLY SET THIS - everything else is auto-loaded from config
# ============================================================
EXP_DIR = "../runs/exp_xxx"  # Your experiment directory
# ============================================================

# Load the experiment config written by train.py.
config_path = Path(EXP_DIR) / "config.json"
if not config_path.exists():
    raise FileNotFoundError(
        f"Config not found: {config_path}\n"
        "Make sure you're using an experiment trained with the latest train.py"
    )

with open(config_path, encoding="utf-8") as f:
    config = json.load(f)

# Every VLM setting used below lives under the "vlm" section. Fail with a
# message naming the file instead of a bare KeyError.
vlm_cfg = config.get("vlm")
if vlm_cfg is None:
    raise KeyError(f"No 'vlm' section in {config_path}")

# Auto-load settings from config
MODEL_NAME = vlm_cfg["model"]
PRECISION = vlm_cfg["precision"]
VLM_DATA_DIR = vlm_cfg["data_dir"]
# An explicit adapter path in the config wins. If the key is absent, null,
# or empty, fall back to the best checkpoint saved under this experiment.
ADAPTER_PATH = vlm_cfg.get("adapter") or f"{EXP_DIR}/vlm/best"

print(f"Loaded config from: {config_path}")
print(f"  Model: {MODEL_NAME}")
print(f"  Precision: {PRECISION}")
print(f"  VLM Data: {VLM_DATA_DIR}")
print(f"  Adapter: {ADAPTER_PATH}")

In [None]:
# Evaluation images: the validation split of the VLM dataset named in the config.
TEST_FOLDER = "{}/images/val".format(VLM_DATA_DIR)

# Question asked for every image - edit the lines to customize the evaluation.
PROMPT = "\n".join([
    "Look at the bounding box in the image.",
    "What do you see inside the marked area?",
])

# Optional system prompt. For a fair comparison it should match the one used
# during training.
SYSTEM_PROMPT = "\n".join([
    "You are an object detection assistant.",
    "When shown an image with a bounding box, identify what is inside the marked area.",
])

# Set to None if you didn't use system prompt during training
# SYSTEM_PROMPT = None

In [None]:
# Find test images (limited for quick comparison)
import glob
import random
from pathlib import Path

MAX_IMAGES = 20  # Limit for comparison - adjust as needed
SEED = 42  # Fixed seed so re-running this cell samples the same images

# Suffixes are compared case-insensitively, so "IMG_001.JPG" is found too.
image_extensions = {".jpg", ".jpeg", ".png", ".webp"}

# With recursive=True, "**" also matches zero directories, so this single
# pattern covers files directly in TEST_FOLDER as well as in subfolders.
all_images = sorted({
    path
    for path in glob.glob(f"{TEST_FOLDER}/**/*", recursive=True)
    if Path(path).suffix.lower() in image_extensions and Path(path).is_file()
})

# Filter out global images (_all.jpg) - they have multiple boxes.
# Only keep per-box images (_box*.jpg) for "marked area" prompts.
# Only the file name is inspected, so directory names cannot trigger the filter.
per_box_images = [img for img in all_images if "_all." not in Path(img).name]
global_images = [img for img in all_images if "_all." in Path(img).name]

print(f"Found {len(all_images)} total images:")
print(f"  - {len(per_box_images)} per-box images (single bbox)")
print(f"  - {len(global_images)} global images (multiple bboxes) - excluded")

# Fail here with a clear message rather than later inside matplotlib.
if not per_box_images:
    raise FileNotFoundError(
        f"No per-box images found under {TEST_FOLDER}. "
        "Check 'data_dir' in the experiment config."
    )

# Sample from per-box images only
if len(per_box_images) > MAX_IMAGES:
    rng = random.Random(SEED)
    test_images = sorted(rng.sample(per_box_images, MAX_IMAGES))
    print(f"\nSampled {MAX_IMAGES} per-box images for comparison:")
else:
    test_images = per_box_images
    print(f"\nUsing {len(test_images)} per-box images:")

for img in test_images[:10]:
    print(f"  - {Path(img).name}")
if len(test_images) > 10:
    print(f"  ... and {len(test_images) - 10} more")

## 2. Load Base Model

In [None]:
import sys
sys.path.insert(0, '..')

from yologen.models.vlm.qwen import QwenVLM

# Quantization flags follow the precision recorded in the experiment config.
# Any value other than "4bit"/"8bit" leaves both flags off.
quant_flags = {
    "load_in_4bit": PRECISION == "4bit",
    "load_in_8bit": PRECISION == "8bit",
}

print("Loading base model...")
base_vlm = QwenVLM(model_name=MODEL_NAME, use_lora=False, **quant_flags)
base_vlm.load_model()
print("Base model ready!")

## 3. Run Base Model Inference

In [None]:
from pathlib import Path

# Run the base (un-adapted) model on every sampled image. Results are kept in
# test_images order so they can be paired with the fine-tuned results later.
base_results = []

for i, image_path in enumerate(test_images, start=1):
    # Path(...).name is separator-agnostic. On Windows, glob returns
    # backslash-separated paths, so splitting on "/" would not isolate the name.
    print(f"\r[Base] Processing {i}/{len(test_images)}: {Path(image_path).name}", end="")

    response = base_vlm.generate(
        image=image_path,
        question=PROMPT,
        system_prompt=SYSTEM_PROMPT,
    )

    base_results.append({
        "image": image_path,
        "response": response.strip(),
    })

print(f"\nBase model: {len(base_results)} images processed")

## 4. Clear GPU Memory

In [None]:
import gc
import torch

# Drop the notebook's reference to the base model and run a collection
# before asking PyTorch to release its cached CUDA memory, leaving room to
# load the fine-tuned model next.
del base_vlm
gc.collect()
torch.cuda.empty_cache()
print("GPU memory cleared.")

## 5. Load Fine-tuned Model

In [None]:
print("Loading fine-tuned model...")
# Same model name and quantization as the base run. The trained adapter is
# attached explicitly from ADAPTER_PATH after the weights are loaded
# (use_lora=False presumably means no fresh LoRA layers are created here -
# confirm against QwenVLM).
ft_kwargs = dict(
    model_name=MODEL_NAME,
    load_in_4bit=PRECISION == "4bit",
    load_in_8bit=PRECISION == "8bit",
    use_lora=False,
)
finetuned_vlm = QwenVLM(**ft_kwargs)
finetuned_vlm.load_model()
finetuned_vlm.load_adapter(ADAPTER_PATH)
print(f"Fine-tuned model ready! Adapter: {ADAPTER_PATH}")

## 6. Run Fine-tuned Model Inference

In [None]:
from pathlib import Path

# Run the adapter-loaded model with the same prompt on the same images, in
# test_images order, so each entry lines up with base_results.
finetuned_results = []

for i, image_path in enumerate(test_images, start=1):
    # Path(...).name is separator-agnostic. On Windows, glob returns
    # backslash-separated paths, so splitting on "/" would not isolate the name.
    print(f"\r[Fine-tuned] Processing {i}/{len(test_images)}: {Path(image_path).name}", end="")

    response = finetuned_vlm.generate(
        image=image_path,
        question=PROMPT,
        system_prompt=SYSTEM_PROMPT,
    )

    finetuned_results.append({
        "image": image_path,
        "response": response.strip(),
    })

print(f"\nFine-tuned model: {len(finetuned_results)} images processed")

## 7. Visual Comparison

In [None]:
import matplotlib.pyplot as plt
from PIL import Image
import textwrap
from pathlib import Path

MAX_ROWS = 12     # At most this many image rows in the figure
WRAP_WIDTH = 45   # Characters per wrapped response line


def _draw_response(ax, label_y, text_y, label, response, color, facecolor, edgecolor):
    """Draw a bold header and a wrapped, boxed response on ``ax`` (axes coordinates)."""
    ax.text(0.02, label_y, label, fontsize=12, fontweight='bold',
            color=color, transform=ax.transAxes, va='top')
    ax.text(0.02, text_y, textwrap.fill(response, width=WRAP_WIDTH), fontsize=11,
            color=color, transform=ax.transAxes, va='top',
            bbox=dict(boxstyle='round', facecolor=facecolor, edgecolor=edgecolor, pad=0.5))


# Pair results row by row. Sizing by the results rather than test_images
# keeps this cell working if an inference cell was interrupted part-way.
pairs = list(zip(base_results, finetuned_results))[:MAX_ROWS]
n_images = len(pairs)
if n_images == 0:
    raise RuntimeError("No results to compare - run both inference cells first.")

# squeeze=False always returns a 2-D axes array, even for a single row.
fig, axes = plt.subplots(n_images, 2, figsize=(24, 6 * n_images),
                         gridspec_kw={'width_ratios': [1.5, 1]}, squeeze=False)

for (img_ax, text_ax), (base, ft) in zip(axes, pairs):
    # Both lists must come from the same test_images run (e.g. the sampling
    # cell was not re-run between the two inference cells).
    if base["image"] != ft["image"]:
        raise ValueError(
            f"Result mismatch: {base['image']} vs {ft['image']} - "
            "re-run both inference cells on the same test_images."
        )

    # Left: Image (larger)
    try:
        # Close the file after drawing. matplotlib's imshow converts a PIL
        # image to an array when called, so the handle is no longer needed.
        with Image.open(base["image"]) as img:
            img_ax.imshow(img)
        img_ax.set_title(Path(base["image"]).name, fontsize=11, fontweight='bold')
    except Exception as e:
        # Best effort: show the error in place of the image, keep plotting.
        img_ax.text(0.5, 0.5, f"Error: {e}", ha='center', va='center')
    img_ax.axis('off')

    # Right: Responses (base on gray, fine-tuned on blue)
    text_ax.axis('off')
    _draw_response(text_ax, 0.78, 0.68, "BASE MODEL:", base["response"],
                   '#333', '#f0f0f0', '#ccc')
    _draw_response(text_ax, 0.38, 0.28, "FINE-TUNED:", ft["response"],
                   '#0066cc', '#e6f2ff', '#99ccff')

plt.tight_layout()
save_path = f"{EXP_DIR}/comparison_results.png"
plt.savefig(save_path, dpi=150, bbox_inches='tight')
plt.show()
print(f"\nCompared {n_images} images. Saved to: {save_path}")

In [None]:
# Release the fine-tuned model: drop the notebook's reference, collect, then
# return PyTorch's cached CUDA memory.
del finetuned_vlm
gc.collect()
torch.cuda.empty_cache()
print("Done!")