In [None]:
from src.sr_utils import SuperResWrapper
from src.sam_utils import (
    load_sam_model,
    run_sam_inference,
    load_mask,
    visualize_prediction_vs_gt,
    dice_coefficient,
    iou_score,
)


In [None]:

# --- Inputs -------------------------------------------------------------
# EDSR x4 super-resolution weights, plus one iNaturalist sample image and
# its ground-truth segmentation mask.
sr_model_path = "models/EDSR_x4.pb"
image_path = "data/iNaturalist_blurry_images/sample_2247.jpg"
mask_path = "data/ground_truth_masks/sample_2247_mask.png"

# --- Models -------------------------------------------------------------
# Super-resolution wrapper (EDSR, upscale factor 4) and the SlimSAM checkpoint;
# load_sam_model hands back the model, its processor and the device it chose.
sr_wrapper = SuperResWrapper(sr_model_path, sr_model_name="edsr", scale=4)
model, processor, device = load_sam_model("nielsr/slimsam-77-uniform")

# --- Super-resolve the input ---------------------------------------------
# NOTE(review): `Image` is not imported in the visible import cell; it is
# presumably `PIL.Image` imported in another cell — confirm, or this line
# raises NameError.
original_image = Image.open(image_path).convert("RGB")
enhanced_image = sr_wrapper.enhance(original_image)

# --- Segment the SR image and load the reference mask -------------------
_, masks = run_sam_inference(enhanced_image, model, processor, device)
true_mask = load_mask(mask_path)

# Only the first predicted mask is visualized and scored.
# NOTE(review): the prediction is computed on the x4-upscaled image while the
# ground-truth mask is loaded as-is; confirm both share one resolution (or
# that load_mask / the metric helpers resize) before trusting these scores.
pred_mask = masks[0]

# --- Visualize and score -------------------------------------------------
visualize_prediction_vs_gt(enhanced_image, pred_mask, true_mask, "SAM Output on SR Image")
dice = dice_coefficient(pred_mask, true_mask)
iou = iou_score(pred_mask, true_mask)

print(f"Dice Coefficient (SR image): {dice:.4f}")
print(f"IoU Score (SR image): {iou:.4f}")