# VLM Comparison: Base vs Fine-tuned

Compare your fine-tuned VLM against the base model to evaluate improvements.

## Usage
1. Set `EXP_DIR` to your experiment directory, relative to the notebook's working directory (e.g., `../runs/exp_20250102_xxx`)
2. Run all cells - paths are auto-loaded from experiment config

## Notes
- Config is auto-loaded from `{EXP_DIR}/config.json` (saved during training)
- For end-to-end testing (YOLO + VLM), use `predict.py --vlm`

## 1. Configuration

In [None]:
import json
from pathlib import Path

# ============================================================
# ONLY SET THIS - everything else is auto-loaded from config
# ============================================================
EXP_DIR = "../runs/exp_xxx"  # Your experiment directory
# ============================================================

# Load experiment config
config_path = Path(EXP_DIR) / "config.json"
if not config_path.exists():
    raise FileNotFoundError(
        f"Config not found: {config_path}\n"
        "Make sure you're using an experiment trained with the latest train.py"
    )

# Pass the encoding explicitly so the file is decoded the same way on every
# platform instead of with the locale's default encoding
# (assumes config.json is UTF-8, the standard encoding for JSON files).
with open(config_path, encoding="utf-8") as f:
    config = json.load(f)

# Auto-load settings from the "vlm" section of the config
vlm_config = config["vlm"]
MODEL_NAME = vlm_config["model"]
PRECISION = vlm_config["precision"]
VLM_DATA_DIR = vlm_config["data_dir"]
# A missing, null or empty "adapter" entry falls back to the checkpoint
# stored inside the experiment directory.
ADAPTER_PATH = vlm_config.get("adapter") or f"{EXP_DIR}/vlm/best"

print(f"Loaded config from: {config_path}")
print(f"  Model: {MODEL_NAME}")
print(f"  Precision: {PRECISION}")
print(f"  VLM Data: {VLM_DATA_DIR}")
print(f"  Adapter: {ADAPTER_PATH}")

In [None]:
# Evaluation images: the validation split of the VLM dataset named in the config
TEST_FOLDER = f"{VLM_DATA_DIR}/images/val"

# Question asked about every image - edit to suit your task
PROMPT = (
    "Look at the bounding box in the image.\n"
    "What do you see inside the marked area?"
)

# Optional system prompt - keep it in sync with the one used for training
SYSTEM_PROMPT = (
    "You are an object detection assistant.\n"
    "When shown an image with a bounding box, identify what is inside the marked area."
)

# Uncomment the next line if training was done without a system prompt
# SYSTEM_PROMPT = None

In [None]:
# Find all test images
import fnmatch
import glob
import os

image_extensions = ["*.jpg", "*.jpeg", "*.png", "*.webp"]

# With recursive=True, "**" also matches zero directories, so one walk returns
# the files directly inside TEST_FOLDER as well as those in any subfolder.
candidates = glob.glob(f"{TEST_FOLDER}/**/*", recursive=True)

# File names are lower-cased before matching so that e.g. "IMG_01.JPG" is
# found on case-sensitive filesystems too. The set removes any duplicates.
test_images = sorted(
    {
        path
        for path in candidates
        if any(
            fnmatch.fnmatchcase(os.path.basename(path).lower(), pattern)
            for pattern in image_extensions
        )
    }
)

print(f"Found {len(test_images)} test images:")
for img in test_images[:10]:  # Show first 10
    print(f"  - {img}")
if len(test_images) > 10:
    print(f"  ... and {len(test_images) - 10} more")
if not test_images:
    # The later cells have nothing to process; point at the likely cause.
    print(f"  WARNING: no images found under {TEST_FOLDER} - check the VLM data_dir in the config")

## 2. Load Base Model

In [None]:
import sys
sys.path.insert(0, '..')

from yologen.models.vlm.qwen import QwenVLM

# Build the base model without any adapter. The quantisation flags come from
# the configured precision: any value other than "4bit" or "8bit" leaves both off.
print("Loading base model...")
base_model_kwargs = {
    "model_name": MODEL_NAME,
    "load_in_4bit": PRECISION == "4bit",
    "load_in_8bit": PRECISION == "8bit",
    "use_lora": False,
}
base_vlm = QwenVLM(**base_model_kwargs)
base_vlm.load_model()
print("Base model ready!")

## 3. Run Base Model Inference

In [None]:
# Run the base model over every test image; results keep the order of test_images.
base_results = []
n_total = len(test_images)

for i, image_path in enumerate(test_images, start=1):
    # Path(...).name applies the platform's own separator rules, so the file
    # name is shown correctly for Windows-style paths as well.
    print(f"\r[Base] Processing {i}/{n_total}: {Path(image_path).name}", end="")

    response = base_vlm.generate(
        image=image_path,
        question=PROMPT,
        system_prompt=SYSTEM_PROMPT,
    )

    base_results.append({
        "image": image_path,
        "response": response.strip(),
    })

print(f"\nBase model: {len(base_results)} images processed")

## 4. Clear GPU Memory

In [None]:
import gc

import torch

# Release the base model before the fine-tuned one is loaded:
# drop this notebook's reference to it, run a garbage-collection pass,
# then ask PyTorch to release the CUDA memory it is caching.
del base_vlm
gc.collect()
torch.cuda.empty_cache()
print("GPU memory cleared.")

## 5. Load Fine-tuned Model

In [None]:
# Build the model with the same settings as the base model, then attach the
# fine-tuned adapter from ADAPTER_PATH on top of it.
print("Loading fine-tuned model...")
finetuned_model_kwargs = {
    "model_name": MODEL_NAME,
    "load_in_4bit": PRECISION == "4bit",
    "load_in_8bit": PRECISION == "8bit",
    "use_lora": False,
}
finetuned_vlm = QwenVLM(**finetuned_model_kwargs)
finetuned_vlm.load_model()
finetuned_vlm.load_adapter(ADAPTER_PATH)
print(f"Fine-tuned model ready! Adapter: {ADAPTER_PATH}")

## 6. Run Fine-tuned Model Inference

In [None]:
# Run the fine-tuned model over the same images, in the same order as the
# base run, so the two result lists line up index by index.
finetuned_results = []
n_total = len(test_images)

for i, image_path in enumerate(test_images, start=1):
    # Path(...).name applies the platform's own separator rules, so the file
    # name is shown correctly for Windows-style paths as well.
    print(f"\r[Fine-tuned] Processing {i}/{n_total}: {Path(image_path).name}", end="")

    response = finetuned_vlm.generate(
        image=image_path,
        question=PROMPT,
        system_prompt=SYSTEM_PROMPT,
    )

    finetuned_results.append({
        "image": image_path,
        "response": response.strip(),
    })

print(f"\nFine-tuned model: {len(finetuned_results)} images processed")

## 7. Comparison Table

In [None]:
# Only `escape` is imported from the html module, so the `html` variable
# assigned below does not clash with it (also when the cell is re-run).
from html import escape

from IPython.display import display, HTML

# Side-by-side table of both models' answers, one row per image.
header_cells = "".join(
    f"<th style='padding:8px; border:1px solid #ddd;'>{title}</th>"
    for title in ("#", "Image", "Base Model", "Fine-tuned")
)
parts = [
    "<table style='width:100%; border-collapse: collapse; font-size:12px;'>",
    f"<tr style='background:#333; color:white;'>{header_cells}</tr>",
]

for i, (base, ft) in enumerate(zip(base_results, finetuned_results), start=1):
    # Truncate first, then escape: model output and file names are free text
    # and may contain '<', '>' or '&', which would otherwise be parsed as markup.
    cells = (
        str(i),
        escape(Path(base["image"]).name),
        escape(base["response"][:50]),
        escape(ft["response"][:50]),
    )
    row_cells = "".join(
        f"<td style='padding:6px; border:1px solid #ddd;'>{cell}</td>"
        for cell in cells
    )
    parts.append(f"<tr>{row_cells}</tr>")

parts.append("</table>")

html = "".join(parts)
display(HTML(html))

## 8. Visual Comparison

In [None]:
import matplotlib.pyplot as plt
from PIL import Image
import math

# Grid layout
# Plot only images for which both models produced a result, at most 12.
n_images = min(len(base_results), len(finetuned_results), 12)
n_cols = 3
n_rows = math.ceil(n_images / n_cols)

if n_images == 0:
    # A subplot grid needs at least one row, so there is nothing to draw.
    print("No results to plot - run the inference cells first.")
else:
    # squeeze=False always yields a 2-D array of Axes, whatever the grid size,
    # so flattening gives a uniform 1-D sequence even for a single row.
    fig, axes = plt.subplots(
        n_rows, n_cols, figsize=(5*n_cols, 5*n_rows), squeeze=False
    )
    axes = axes.flatten()

    for ax, base, ft in zip(axes, base_results[:n_images], finetuned_results[:n_images]):
        try:
            # The context manager closes the image file once it has been drawn.
            with Image.open(base["image"]) as img:
                ax.imshow(img)

            # Labels on image
            ax.text(0.02, 0.98, f"Base: {base['response'][:20]}",
                    transform=ax.transAxes, fontsize=9, fontweight='bold',
                    color='white', backgroundcolor='#333333',
                    verticalalignment='top')
            ax.text(0.02, 0.88, f"FT: {ft['response'][:20]}",
                    transform=ax.transAxes, fontsize=9, fontweight='bold',
                    color='white', backgroundcolor='#0066cc',
                    verticalalignment='top')

            ax.set_title(Path(base["image"]).name, fontsize=8)
            ax.axis('off')
        except Exception as exc:
            # Best effort: one unreadable image must not abort the whole grid,
            # so the error is shown in that image's panel instead.
            ax.text(0.5, 0.5, f"Error: {exc}", ha='center', va='center')
            ax.axis('off')

    # Hide empty subplots
    for ax in axes[n_images:]:
        ax.axis('off')

    plt.tight_layout()
    plt.savefig("comparison_results.png", dpi=150, bbox_inches='tight')
    plt.show()
    print("\nSaved to: comparison_results.png")

In [None]:
# Final cleanup: drop this notebook's reference to the fine-tuned model,
# run a garbage-collection pass, then ask PyTorch to release cached CUDA memory.
# `gc` and `torch` are the modules imported in the earlier memory-clearing cell.
del finetuned_vlm
gc.collect()
torch.cuda.empty_cache()
print("Done!")