In [None]:
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import torch
from io import BytesIO
from PIL import Image, ImageOps

from train import TrainConfig
# Project training configuration; the cells below read the compute device and
# the target size (used for both sides of the resize target) from it.
train_config = TrainConfig()

# User-tunable inputs for this notebook.
THRESHOLD = 0.5  # forwarded to the model's predict_mask call
CHECKPOINT_PATH = Path("checkpoints/epoch_002.pt")  # checkpoint file handed to torch.load
IMAGE_PATH = Path("sample.jpg")  # image to run the detector on

print('Resolved device:', train_config.resolved_device())
print('Target size:', train_config.target_size)

In [None]:
from model.hybrid_forgery_detector import HybridForgeryConfig, HybridForgeryDetector

def _predict_image_array(image_arr: np.ndarray, checkpoint_path: Path, threshold: float) -> np.ndarray:
    """Run the forgery detector on a single image and return its predicted mask.

    Args:
        image_arr: Image in height x width x channels layout. It is converted
            to float32 before being handed to the model; the caller in this
            notebook passes RGB values scaled to [0, 1].
        checkpoint_path: Checkpoint file containing a ``config`` entry (with a
            ``model_config`` mapping) and a ``model_state`` entry.
        threshold: Forwarded unchanged to ``model.predict_mask``.

    Returns:
        The result of ``predict_mask`` with singleton dimensions squeezed out,
        moved to the CPU and converted to a NumPy array.

    Raises:
        ValueError: If ``image_arr`` is not a 3-D array.
    """
    # Fail fast on a bad input before paying for checkpoint loading.
    image_arr = np.asarray(image_arr)
    if image_arr.ndim != 3:
        raise ValueError(
            f"image_arr must have shape (H, W, C); got shape {image_arr.shape}"
        )

    # NOTE(security): torch.load unpickles the file, which (depending on the
    # installed PyTorch version's ``weights_only`` default) can execute
    # arbitrary code. Only load checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    cfg = HybridForgeryConfig(**checkpoint['config']['model_config'])
    device = train_config.resolved_device()
    model = HybridForgeryDetector(cfg).to(device)
    model.load_state_dict(checkpoint['model_state'])
    model.eval()

    # HWC -> CHW, forced to a contiguous float32 buffer so that integer or
    # float64 inputs do not surface later as a dtype mismatch inside the model.
    chw = np.ascontiguousarray(image_arr.transpose(2, 0, 1), dtype=np.float32)
    tensor = torch.from_numpy(chw).unsqueeze(0).to(device)

    # The model takes a 'residual' noise input; an all-zero 3-channel tensor
    # with the image's spatial size is supplied for it.
    zero_noise = torch.zeros(1, 3, tensor.shape[-2], tensor.shape[-1], device=device)
    noise_inputs = {'residual': zero_noise}
    with torch.no_grad():
        mask = model.predict_mask(tensor, threshold=threshold, noise_features=noise_inputs)
    return mask.squeeze().cpu().numpy()

def _load_image_from_path(image_path: str) -> tuple[Image.Image, str]:
    """Load an image file as RGB and return it together with its file name.

    Args:
        image_path: Location of the image. Any value accepted by
            ``pathlib.Path`` works (the notebook passes a ``Path``).

    Returns:
        A ``(image, name)`` tuple: the image converted to RGB mode and the
        final component of the path.

    Raises:
        FileNotFoundError: If nothing exists at ``image_path``.
    """
    path = Path(image_path)
    if not path.exists():
        raise FileNotFoundError(f"Image not found: {path}")
    # convert() returns a new in-memory image, so the source file can be
    # closed as soon as the conversion is done instead of lingering until
    # garbage collection (which matters for multi-frame formats).
    with Image.open(path) as source:
        return source.convert('RGB'), path.name

In [None]:
# Load the input once and keep an untouched copy for display and compositing.
uploaded_image, uploaded_label = _load_image_from_path(IMAGE_PATH)
original_image = uploaded_image.copy()

# Prefer the filter on Image.Resampling when this Pillow build has it;
# otherwise fall back to the module-level constant.
resample_mode = getattr(getattr(Image, "Resampling", Image), "BICUBIC")

# Crop/resize to the configured square size, then scale pixel values by 1/255.
target_hw = (train_config.target_size,) * 2
resized_image = ImageOps.fit(image=uploaded_image, size=target_hw, method=resample_mode)
normalized = np.asarray(resized_image, dtype=np.float32) / 255.0

threshold_for_single = THRESHOLD
checkpoint_for_single = Path(CHECKPOINT_PATH)

# Only a warning is emitted for a missing checkpoint; the prediction call
# below is still attempted.
if not checkpoint_for_single.exists():
    print(f'Warning: checkpoint not found: {checkpoint_for_single}. You may need to set CHECKPOINT_PATH.')

pred_mask = _predict_image_array(
    image_arr=normalized,
    checkpoint_path=checkpoint_for_single,
    threshold=threshold_for_single,
)

# ImageOps.fit center-crops the input to the model's aspect ratio before
# resizing, so the prediction only describes that centered window. Rebuild the
# window in original-image pixels so the mask lands on the region it was
# predicted for instead of being stretched across the whole frame.
orig_w, orig_h = original_image.size
fit_w, fit_h = resized_image.size
fit_ratio = fit_w / fit_h
if orig_w / orig_h >= fit_ratio:
    # Source is relatively wider than the model input: full height is kept.
    crop_w, crop_h = fit_ratio * orig_h, float(orig_h)
else:
    # Source is relatively taller than the model input: full width is kept.
    crop_w, crop_h = float(orig_w), orig_w / fit_ratio
crop_size = (max(1, round(crop_w)), max(1, round(crop_h)))
crop_origin = (round((orig_w - crop_w) / 2), round((orig_h - crop_h) / 2))

# Scale the mask to 0..255, resize it to the crop window with the same filter
# used for preprocessing, and place it on an empty full-size canvas.
mask_uint8 = (pred_mask * 255).astype(np.uint8)
mask_crop = Image.fromarray(mask_uint8).resize(crop_size, resample=resample_mode)
mask_img = Image.new("L", original_image.size, 0)
mask_img.paste(mask_crop, crop_origin)

# putalpha() replaces the whole alpha channel, so the intended 50% opacity has
# to be folded into the mask itself: 255 -> 128, 0 -> 0. Without this the
# flagged region is painted solid red and hides the image underneath.
overlay_alpha = mask_img.point(lambda value: value * 128 // 255)
mask_rgba = Image.new("RGBA", original_image.size, (255, 0, 0, 0))
mask_rgba.putalpha(overlay_alpha)
original_rgba = original_image.convert("RGBA")
overlayed_image = Image.alpha_composite(original_rgba, mask_rgba)

# One entry per panel, left to right: (image, title, extra imshow kwargs).
panels = [
    (original_image, f'Original ({uploaded_label})', {}),
    (resized_image, 'Resized', {}),
    (pred_mask, f'Prediction (thr={threshold_for_single:.2f})', {'cmap': 'magma'}),
    (overlayed_image, 'Overlayed', {}),
]

fig, axes = plt.subplots(nrows=1, ncols=len(panels), figsize=(16, 4))
for ax, (panel_image, panel_title, imshow_kwargs) in zip(axes, panels):
    ax.imshow(panel_image, **imshow_kwargs)
    ax.set_title(panel_title)
    ax.axis('off')
plt.tight_layout()
plt.show()
print('Prediction mask shape:', pred_mask.shape)