In [None]:
import sys
from pathlib import Path

# Add project root to path.
# Any existing entry is removed first so that re-running this cell leaves
# exactly one copy at the front of sys.path instead of stacking duplicates.
project_root = Path.cwd().parent
_project_root_str = str(project_root)
while _project_root_str in sys.path:
    sys.path.remove(_project_root_str)
sys.path.insert(0, _project_root_str)

import torch
import matplotlib.pyplot as plt
from PIL import Image

from utils import load_config, set_seed
from models import MAIRA2Model, MAIRA2Config

In [None]:
# Fix the random seed for reproducibility (set_seed is the project's helper
# imported from utils)
set_seed(42)

# Pick the compute device: CUDA when torch reports it as available, CPU otherwise
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

## 1. Load MAIRA-2 Model

In [None]:
# Load configuration from <project_root>/configs/maira2_config.yaml
config_path = project_root / "configs" / "maira2_config.yaml"
config = load_config(config_path)

# Display model config: one "key: value" line per entry of the "model" section
model_settings = config["model"]
print("Model configuration:")
for key, value in model_settings.items():
    print(f"  {key}: {value}")

In [None]:
# Load model (requires GPU with sufficient memory)
# Uncomment to load:

# model = MAIRA2Model.from_pretrained(
#     checkpoint=config["model"]["checkpoint"],
#     device=device,
#     load_in_8bit=True,  # Use 8-bit quantisation to reduce memory
# )
# print("Model loaded successfully!")

## 2. Run Inference

In [None]:
def run_inference(model, image_path, prompt_type="findings"):
    """Run inference on a single image.

    Args:
        model: Object exposing ``generate(images=..., prompt_type=...)``.
            Presumably a loaded ``MAIRA2Model`` -- confirm against the
            model-loading cell.
        image_path: Location of the image file, passed to ``Image.open``.
        prompt_type: Forwarded unchanged to ``model.generate``.
            Defaults to ``"findings"``.

    Returns:
        The value returned by ``model.generate``. Other code in this notebook
        reads its ``"generated_text"`` entry.

    Raises:
        Any exception raised while opening/converting the image or by
        ``model.generate`` propagates to the caller.
    """
    # Image.open is lazy and keeps the file open; the context manager releases
    # the handle deterministically. convert() yields an independent copy, so
    # the RGB image stays usable after the source file is closed.
    with Image.open(image_path) as source_image:
        image = source_image.convert("RGB")

    # Generate report
    return model.generate(
        images=image,
        prompt_type=prompt_type,
    )

# Example usage (uncomment when model is loaded):
# image_path = "path/to/your/chest_xray.jpg"
# result = run_inference(model, image_path)
# print("Generated Report:")
# print(result["generated_text"])

## 3. Batch Inference

In [None]:
def batch_inference(model, image_paths, prompt_type="findings"):
    """Run inference on multiple images, one result record per input.

    Each path is processed independently through ``run_inference``; a failure
    on one image is recorded and does not stop the remaining images.

    Args:
        model: Passed through to ``run_inference``.
        image_paths: Iterable of image locations.
        prompt_type: Passed through to ``run_inference``.

    Returns:
        A list of dicts in input order. Successful entries have the keys
        ``"path"``, ``"report"`` and ``"success"`` (True); failed entries have
        ``"path"``, ``"error"`` (the exception's message) and ``"success"``
        (False).
    """

    def _process(item):
        # Build the record for one image; any Exception becomes a failure record.
        try:
            generated = run_inference(model, item, prompt_type)
            record = {
                "path": str(item),
                "report": generated["generated_text"],
                "success": True,
            }
        except Exception as exc:
            record = {
                "path": str(item),
                "error": str(exc),
                "success": False,
            }
        return record

    return [_process(item) for item in image_paths]

## 4. Visual Question Answering

In [None]:
def ask_question(model, image_path, question):
    """Ask a question about the image and return the generated answer text.

    Args:
        model: Object exposing
            ``generate(images=..., question=..., prompt_type=...)``.
            Presumably a loaded ``MAIRA2Model`` -- confirm against the
            model-loading cell.
        image_path: Location of the image file, passed to ``Image.open``.
        question: The question, forwarded unchanged to ``model.generate``.

    Returns:
        The ``"generated_text"`` entry of the value returned by
        ``model.generate`` (called with ``prompt_type="vqa"``).

    Raises:
        Any exception raised while opening/converting the image or by
        ``model.generate`` propagates; a ``KeyError`` is raised if the output
        is a dict without ``"generated_text"``.
    """
    # Image.open is lazy and keeps the file open; the context manager releases
    # the handle deterministically. convert() yields an independent copy, so
    # the RGB image stays usable after the source file is closed.
    with Image.open(image_path) as source_image:
        image = source_image.convert("RGB")

    output = model.generate(
        images=image,
        question=question,
        prompt_type="vqa",
    )

    return output["generated_text"]

# Example:
# answer = ask_question(model, image_path, "Is there cardiomegaly present?")
# print(f"Answer: {answer}")