# VLM Qualitative Inference

Interactive probing of the VLM with custom images and questions.

In [None]:

import os
import sys
import torch
import matplotlib.pyplot as plt
from PIL import Image
from pathlib import Path

# Make the repository root and its `src/` directory importable. The root is
# taken to be the parent of the current working directory, i.e. this notebook
# is expected to be launched from a first-level subdirectory of the project.
project_root = Path.cwd().parent
sys.path.extend([str(project_root), str(project_root / "src")])

from config import load_config
from models.alignment import MultimodalAlignmentModel
from models.trm_qwen_vlm import QwenVLM
from decoders.qwen import QwenDecoder
from data.transforms import get_image_transforms

%matplotlib inline

# Default to CPU and upgrade to CUDA when PyTorch can see a GPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"Using device: {device}")


In [None]:

# --- Configuration ---
config_path = "../configs/trm_vlm_qa_qwen1.5.yaml"
alignment_config_path = "../configs/pixmo_alignment.yaml"
# Update checkpoint path! Either a directory containing `pytorch_model.bin`,
# or a single file holding a dict with a 'model_state_dict' entry.
checkpoint_path = "../checkpoints/vlm_run/checkpoint-epoch-9"
alignment_checkpoint = "../notebooks/checkpoints/pixmo_alignment/checkpoint_best.pt"
use_trm = True


def _report_load(name, load_result):
    """Print key mismatches returned by a non-strict ``load_state_dict`` call.

    With ``strict=False`` PyTorch does not raise on mismatched keys; it only
    returns them. Surfacing the counts makes a wrong or partial checkpoint
    visible instead of silently leaving weights unloaded.
    """
    missing = load_result.missing_keys
    unexpected = load_result.unexpected_keys
    if missing or unexpected:
        print(f"[{name}] missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")


# --- Load Models ---
print("Loading Models...")
# 1. Vision: only the vision side of the alignment model is needed here.
alignment_config = load_config(alignment_config_path)
alignment_config.decoder = None
alignment_config.text_encoder = None
aligned_model = MultimodalAlignmentModel(alignment_config)
if os.path.exists(alignment_checkpoint):
    # weights_only=False: this is a full training checkpoint dict. It unpickles
    # arbitrary objects, so only load checkpoints you trust.
    ckpt = torch.load(alignment_checkpoint, map_location='cpu', weights_only=False)
    _report_load("vision", aligned_model.load_state_dict(ckpt['model_state_dict'], strict=False))
else:
    print(f"WARNING: alignment checkpoint not found at {alignment_checkpoint}; "
          "the vision encoder keeps its freshly constructed weights.")
aligned_model.eval().to(device)

# 2. VLM
config = load_config(config_path)
qwen_decoder = QwenDecoder(config.decoder.model_name, load_in_4bit=True, use_lora=True, device_map="auto")
model = QwenVLM(
    qwen_decoder, alignment_config.vision_encoder.projection_dim,
    use_trm_recursion=use_trm, num_trm_layers=4, num_recursion_steps=4
).to(device)

# Checkpoint: resolve a state dict from either layout, then load it non-strictly.
vlm_state_dict = None
if os.path.isdir(checkpoint_path):
    bin_path = Path(checkpoint_path) / "pytorch_model.bin"
    if bin_path.exists():
        vlm_state_dict = torch.load(bin_path, map_location='cpu')
    else:
        print(f"WARNING: {bin_path} not found; the VLM keeps its freshly constructed weights.")
else:
    # A training checkpoint dict may hold non-tensor objects. Load it the same
    # way as the alignment checkpoint above, since torch >= 2.6 defaults to
    # weights_only=True. Only load checkpoints you trust.
    vlm_state_dict = torch.load(checkpoint_path, map_location='cpu', weights_only=False)['model_state_dict']
if vlm_state_dict is not None:
    _report_load("vlm", model.load_state_dict(vlm_state_dict, strict=False))
model.eval()
print("Ready!")


## Interactive Inference

In [None]:

def run_inference(image_path, question):
    """Display an image, ask the VLM a question about it, and print the answer.

    Relies on the notebook globals ``config``, ``aligned_model``, ``model``,
    ``qwen_decoder`` and ``device`` created by the loading cell.

    Args:
        image_path: Path to an image file (``str`` or ``os.PathLike``), or an
            already-loaded PIL image.
        question: Free-text question, tokenized as-is by the decoder's tokenizer.

    Returns:
        The decoded answer string, or ``None`` if any step failed. Errors are
        printed with their traceback rather than raised, so an interactive
        session keeps running.
    """
    import traceback

    try:
        # Load and plot the image (always as 3-channel RGB).
        if isinstance(image_path, (str, os.PathLike)):
            # Context manager closes the underlying file handle; convert()
            # returns a fully loaded copy that stays valid after the close.
            with Image.open(image_path) as opened:
                image = opened.convert('RGB')
        else:
            # Allow passing a PIL object; normalise its mode like the file branch.
            image = image_path.convert('RGB')

        plt.figure(figsize=(6, 6))
        plt.imshow(image)
        plt.axis('off')
        plt.title(f"Q: {question}")
        plt.show()

        # Transform with the evaluation (non-training) pipeline; add a batch dim.
        transform = get_image_transforms(config.dataset.image_size, is_training=False)
        img_tensor = transform(image).unsqueeze(0).to(device)

        # Encode
        with torch.no_grad():
            vision_tokens = aligned_model.vision_encoder(img_tensor, return_sequence=True).sequence

        # Generate
        inputs = qwen_decoder.tokenizer([question], return_tensors='pt', padding=True).to(device)
        with torch.no_grad():
            gen_ids = model.generate(
                vision_tokens=vision_tokens,
                question_ids=inputs.input_ids,
                max_new_tokens=128,
                temperature=0.2
            )

        answer = qwen_decoder.tokenizer.decode(gen_ids[0], skip_special_tokens=True)
        print(f"🤖 Answer: {answer}")
        return answer

    except Exception as e:
        # Best-effort interactive helper: keep the session alive, but show the
        # full traceback so failures are debuggable.
        print(f"Error: {e}")
        traceback.print_exc()
        return None

# Example Usage
# run_inference("/path/to/image.jpg", "What is in this image?")
