### Single-Image Inference

In [None]:
from io import BytesIO
from PIL import Image, ImageOps
try:
    from google.colab import files as colab_files
except ImportError:
    colab_files = None
    
def _predict_image_array(image_arr: np.ndarray, checkpoint_path: Path, threshold: float) -> np.ndarray:
    """Predict a mask for one image using a detector restored from a checkpoint.

    The checkpoint is deserialized onto the CPU, a ``HybridForgeryDetector`` is
    built from ``checkpoint["config"]["model_config"]``, moved to the device
    reported by ``train_config.resolved_device()``, and loaded with
    ``checkpoint["model_state"]`` before being switched to eval mode.

    Args:
        image_arr: 3-D image array. Its axes are reordered with ``(2, 0, 1)``,
            so it is presumably laid out as (height, width, channels) with
            float values already scaled for the model -- TODO confirm against
            the training pipeline.
        checkpoint_path: Checkpoint file passed to ``torch.load``.
        threshold: Forwarded unchanged to ``model.predict_mask``.

    Returns:
        The output of ``predict_mask`` with size-1 dimensions squeezed out,
        moved to the CPU and converted to a NumPy array.
    """
    from model.hybrid_forgery_detector import HybridForgeryDetector

    # Restore the model exactly as described by the checkpoint's stored config.
    saved = torch.load(checkpoint_path, map_location="cpu")
    model_cfg = HybridForgeryConfig(**saved["config"]["model_config"])
    run_device = train_config.resolved_device()
    detector = HybridForgeryDetector(model_cfg).to(run_device)
    detector.load_state_dict(saved["model_state"])
    detector.eval()

    # Move the last axis to the front and prepend a batch axis of size 1.
    channels_first = torch.from_numpy(image_arr.transpose(2, 0, 1))
    batch = channels_first.unsqueeze(0).to(run_device)

    # No noise residual is available for an ad-hoc image, so an all-zero
    # 3-channel tensor with the same spatial size is supplied in its place.
    height, width = batch.shape[-2], batch.shape[-1]
    residual = torch.zeros(1, 3, height, width, device=run_device)

    with torch.no_grad():
        predicted = detector.predict_mask(
            batch,
            threshold=threshold,
            noise_features={"residual": residual},
        )
    return predicted.squeeze().cpu().numpy()
    
def _load_uploaded_image(local_image_path: "Path | str | None" = None) -> tuple[Image.Image, str]:
    """Return an RGB image to run inference on, together with a display label.

    When ``USE_COLAB`` is truthy and ``google.colab.files`` was importable,
    ``colab_files.upload()`` is called and the first uploaded file is used.
    Otherwise the image is read from ``local_image_path``.

    Args:
        local_image_path: Image file to load when not running in Colab.
            Defaults to the ``./REPLACE_WITH_IMAGE.jpg`` placeholder, which is
            always rejected so that a real path has to be supplied.

    Returns:
        A ``(image, label)`` tuple: the image converted to RGB, and either the
        uploaded file name or the local path as a string.

    Raises:
        ValueError: If the upload returned no file, or if the local path still
            contains the ``REPLACE_WITH_IMAGE`` placeholder.
        FileNotFoundError: If the local path does not exist.
    """
    if USE_COLAB and colab_files is not None:
        uploaded = colab_files.upload()
        if not uploaded:
            raise ValueError("No file uploaded. Please select an image.")
        # Only the first uploaded file is used; any others are ignored.
        name, data = next(iter(uploaded.items()))
        # convert() returns an independent image, so the lazily opened source
        # can be closed as soon as the conversion is done.
        with Image.open(BytesIO(data)) as source:
            return source.convert("RGB"), name

    if local_image_path is None:
        local_image_path = "./REPLACE_WITH_IMAGE.jpg"
    local_image_path = Path(local_image_path)
    if "REPLACE_WITH_IMAGE" in str(local_image_path):
        raise ValueError("Set local_image_path to an actual image path when running outside Colab.")
    if not local_image_path.exists():
        raise FileNotFoundError(f"Local image not found: {local_image_path}")
    # Close the file handle deterministically instead of relying on garbage
    # collection of the discarded source image.
    with Image.open(local_image_path) as source:
        return source.convert("RGB"), str(local_image_path)
    
# --- Acquire the image; keep an untouched copy for the side-by-side plot. ---
uploaded_image, uploaded_label = _load_uploaded_image()
original_image = uploaded_image.copy()

# Newer Pillow releases expose resampling filters on Image.Resampling; older
# ones only provide the module-level constant, so fall back to Image itself.
resample_mode = getattr(Image, "Resampling", Image).BICUBIC

# --- Preprocess: fit to a square of side train_config.target_size, scale to [0, 1]. ---
target_hw = (train_config.target_size,) * 2
resized_image = ImageOps.fit(uploaded_image, target_hw, method=resample_mode)
normalized = np.asarray(resized_image, dtype=np.float32) / 255.0

# --- Choose checkpoint and threshold, preferring values left by earlier cells. ---
if "checkpoint_path" in locals():
    checkpoint_for_single = checkpoint_path
else:
    checkpoint_for_single = Path(train_config.checkpoint_dir) / "best.pt"

if "best_threshold" in locals():
    threshold_for_single = best_threshold
elif "trained_config" in locals():
    threshold_for_single = trained_config.primary_eval_threshold
else:
    threshold_for_single = train_config.primary_eval_threshold

pred_mask = _predict_image_array(normalized, checkpoint_for_single, threshold_for_single)

# --- Visualize original, resized input and predicted mask in one row. ---
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
_panels = (
    (original_image, f"Original ({uploaded_label})", {}),
    (resized_image, "Resized", {}),
    (pred_mask, f"Prediction (thr={threshold_for_single:.2f})", {"cmap": "magma"}),
)
for _ax, (_panel_image, _panel_title, _imshow_kwargs) in zip(axes, _panels):
    _ax.imshow(_panel_image, **_imshow_kwargs)
    _ax.set_title(_panel_title)
    _ax.axis("off")
plt.tight_layout()

# Last expression of the cell, so the notebook displays the mask shape.
pred_mask.shape