In [32]:
# NOTE(review): hardcoded absolute, machine-specific working directory. The
# relative paths used further down ("../runs/...", "../exps/...",
# "../supplemental_images") are resolved against this directory, so it must be
# adjusted when the notebook is run elsewhere.
%cd /home/hotson/kaggle_work/recodai-luc-scientific-image-forgery-detection/code

/home/hotson/kaggle_work/recodai-luc-scientific-image-forgery-detection/code


# Inference

In [None]:
import torch
import numpy as np
import cv2
import json
from pathlib import Path
from PIL import Image
from tqdm import tqdm

from ultralytics import YOLO
from transformers import AutoProcessor
from torchvision import transforms

from dinov2_uperhead import DinoV2_UPerNet
from dinov2_unet import DinoV2UNet
from utils import rle_encode, evaluate_single_image
from matplotlib import pyplot as plt


class Config:
    """Settings forwarded to the YOLO detector's ``predict`` call."""

    # Passed as ``imgsz``.
    yolo_img_size: int = 1024
    # Passed as ``conf``.
    yolo_conf: float = 0.25
    # Passed as ``iou``.
    yolo_iou: float = 0.6

# ---------------------------------------------------------------------------
# Segmentation / post-processing settings (module-level globals read by the
# helper functions and the inference loop below).
# ---------------------------------------------------------------------------

# Pixels trimmed from every side of a detector box before it is cropped.
margin = 20
# Side length of the square image fed to the segmentation model.
img_size = 532

# Image-level gating: a prediction is kept only if the mask has at least
# `area_thres` pixels AND the mean probability inside it is >= `mean_in_thres`.
area_thres = 400
mean_in_thres = 0.55
# Connected components smaller than this many pixels are dropped from the mask.
min_cc_area = 50

# Alternative (stricter) setting kept for reference:
# area_thres = 600
# mean_in_thres = 0.65
# min_cc_area = 200

# Resize to the model resolution, convert to a tensor and normalise per channel
# (the mean/std values are the commonly used ImageNet statistics).
img_transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std =[0.229, 0.224, 0.225],
    ),
])

# Single device definition shared by checkpoint loading and inference.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Loading detector
detect_model = YOLO("../runs/detect/train/weights/best.pt")

# Loading segmentation model
# NOTE(review): torch.load unpickles the checkpoint file, so only load
# checkpoints from a trusted source (consider `weights_only=True` if the file
# holds nothing but tensors -- not verified here).
state = torch.load("../exps/20260115_182834/vit_base_patch14_dinov2.lvd142m_fold0.pt", map_location=device)
# state = torch.load("../exps/20260110_193809/vit_base_patch14_dinov2.lvd142m_fold0.pt", map_location=device)
model_state = state["model_state"]
seg_model = DinoV2_UPerNet(dinov2_id="facebook/dinov2-base", num_classes=1).to(device)
# seg_model = DinoV2UNet(dinov2_id="facebook/dinov2-base", out_classes=1).to(device)
seg_model.load_state_dict(model_state)
seg_model.eval()

# Directory whose *.png files are evaluated by the inference loop.
supplemental_images = Path("../supplemental_images")

def apply_clahe(image: Image.Image) -> Image.Image:
    """Apply CLAHE to the lightness channel of an RGB image.

    The image is converted RGB -> LAB, the L channel is equalised with
    ``clipLimit=2.0`` on an 8x8 tile grid, and the result is converted back
    to RGB. The A and B channels pass through unchanged.

    Args:
        image: PIL image; the RGB->LAB conversion assumes it is in RGB mode.

    Returns:
        A new PIL image built from the enhanced RGB array.
    """
    lab_pixels = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2LAB)
    lightness, chan_a, chan_b = cv2.split(lab_pixels)

    equalizer = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    merged = cv2.merge((equalizer.apply(lightness), chan_a, chan_b))

    rgb_pixels = cv2.cvtColor(merged, cv2.COLOR_LAB2RGB)
    return Image.fromarray(rgb_pixels)


def enhanced_adaptive_mask(prob: np.ndarray, alpha_grad=0.45):
    """Binarise a probability map using a gradient-boosted adaptive threshold.

    Steps:
      1. Blend the probability map with its normalised Sobel gradient
         magnitude (weight ``alpha_grad`` on the gradient term).
      2. Smooth the blend with a 3x3 Gaussian blur.
      3. Threshold at ``mean + 0.1 * std`` of the smoothed blend.
      4. Clean the mask with a 5x5 morphological close, then a 3x3 open.

    Args:
        prob: probability map; a 3-D input is reduced to its first channel
            along the last axis.
        alpha_grad: weight of the gradient term in the blend.

    Returns:
        ``(mask, thr)``: a uint8 {0,1} mask at the input resolution and the
        threshold that was applied.
    """
    prob = np.asarray(prob, dtype=np.float32)
    if prob.ndim == 3:
        prob = prob[..., 0]

    # Gradient magnitude, scaled so its maximum is ~1 (epsilon avoids 0/0).
    sobel_x = cv2.Sobel(prob, cv2.CV_32F, 1, 0, ksize=3)
    sobel_y = cv2.Sobel(prob, cv2.CV_32F, 0, 1, ksize=3)
    edge_strength = np.sqrt(sobel_x**2 + sobel_y**2)
    edge_strength = edge_strength / (edge_strength.max() + 1e-6)

    blended = (1 - alpha_grad) * prob + alpha_grad * edge_strength
    blended = cv2.GaussianBlur(blended, (3, 3), 0)

    # The threshold adapts to the statistics of this particular map.
    thr = float(np.mean(blended) + 0.1 * np.std(blended))
    mask = (blended > thr).astype(np.uint8)

    # Close first (fill small holes), then open (remove small specks).
    for morph_op, kernel_side in ((cv2.MORPH_CLOSE, 5), (cv2.MORPH_OPEN, 3)):
        kernel = np.ones((kernel_side, kernel_side), np.uint8)
        mask = cv2.morphologyEx(mask, morph_op, kernel)
    return mask, thr


def finalize_mask(prob: np.ndarray, orig_size_wh: tuple[int, int]):
    """Binarise ``prob`` and resize the mask to the original image size.

    Args:
        prob: probability map passed to ``enhanced_adaptive_mask``.
        orig_size_wh: target size as ``(width, height)``.

    Returns:
        ``(mask, thr)``: the mask resized with nearest-neighbour interpolation
        (so it stays binary) and the threshold used to create it.
    """
    small_mask, thr = enhanced_adaptive_mask(prob)
    target_w, target_h = orig_size_wh
    resized_mask = cv2.resize(
        small_mask,
        (int(target_w), int(target_h)),
        interpolation=cv2.INTER_NEAREST,
    )
    return resized_mask, thr


def remove_small_components(mask01: np.ndarray, min_area: int) -> np.ndarray:
    """Drop 8-connected foreground components smaller than ``min_area`` pixels.

    Args:
        mask01: array whose positive entries are treated as foreground.
        min_area: minimum component area (in pixels) to keep. ``None`` or a
            non-positive value disables filtering.

    Returns:
        uint8 {0,1} mask with the same shape as ``mask01``.
    """
    binary = (mask01 > 0).astype(np.uint8)
    if min_area is None or min_area <= 0:
        return binary

    _, labels, stats, _ = cv2.connectedComponentsWithStats(binary, connectivity=8)

    # Per-label keep flag, applied to every pixel in one indexing pass instead
    # of rescanning the whole image once per component.
    keep = stats[:, cv2.CC_STAT_AREA] >= min_area
    keep[0] = False  # label 0 is the background
    return keep[labels].astype(np.uint8)


def preprocess(image: Image.Image, base_model_preprocessor, img_size: int = 532):
    """Prepare one image for the backbone via its own preprocessor.

    CLAHE is applied first, then the image is resized to a square of side
    ``img_size``. The preprocessor is asked only to rescale and normalise;
    its own resizing and centre-cropping are switched off because the resize
    has already been done here.

    Args:
        image: input PIL image.
        base_model_preprocessor: callable accepting the keyword arguments used
            below and returning a mapping with a ``"pixel_values"`` entry.
        img_size: side length of the square the image is resized to.

    Returns:
        ``pixel_values`` with dimension 0 squeezed out (expected [C,H,W]).
    """
    target_hw = (img_size, img_size)
    resized = transforms.functional.resize(apply_clahe(image), target_hw)

    batch = base_model_preprocessor(
        images=resized,
        return_tensors="pt",
        do_resize=False,
        do_center_crop=False,
        do_normalize=True,
        do_rescale=True,
    )
    pixel_values = batch["pixel_values"]
    return pixel_values.squeeze(0)  # [C,H,W]


@torch.no_grad()
def get_seg_mask_raw(image: Image.Image, H: int, W: int):
    """
    Run the segmentation model on one image (or crop) and return its raw outputs.

    The image is CLAHE-enhanced, passed through ``img_transform`` and the
    module-level ``seg_model``; the sigmoid of the logits is binarised with
    ``finalize_mask``. No connected-component filtering or image-level gating
    happens here; the caller applies those once after merging all crops.

    Args:
      image : PIL image to segment
      H, W  : output height / width in pixels

    Returns:
      pred_bin   : (H,W) uint8 {0,1}, adaptive-threshold mask
      prob_full  : (H,W) float32, probabilities bilinearly resized to (H,W)
      prob_small : float32 probabilities at the model's output resolution
                   (expected to be img_size x img_size)
    """
    # (3,S,S) -> (1,3,S,S) on the inference device
    batch = img_transform(apply_clahe(image)).unsqueeze(0).to(device)

    model_out = seg_model(batch)
    # Some models return a tuple/list; the logits are then the first element.
    logits = model_out if torch.is_tensor(model_out) else model_out[0]

    # First sample, first channel -> 2-D probability map.
    prob_small = torch.sigmoid(logits)[0, 0].detach().cpu().numpy().astype(np.float32)

    # Binarise at model resolution, then bring the mask to the requested size.
    full_res_mask, _thr = finalize_mask(prob_small, (W, H))
    pred_bin = (full_res_mask > 0).astype(np.uint8)

    # Soft probabilities at the requested size, used for averaging and viz.
    prob_full = cv2.resize(prob_small, (W, H), interpolation=cv2.INTER_LINEAR).astype(np.float32)

    return pred_bin, prob_full, prob_small


def load_gt_mask_from_image_path(image_path: str, w: int, h: int) -> np.ndarray:
    """Load the ground-truth mask that belongs to ``image_path``.

    The mask path is derived from the image path by plain string replacement:
    ``train_images`` -> ``train_masks``, ``supplemental_images`` ->
    ``supplemental_masks``, ``/forged`` removed, ``.png`` -> ``.npy``.

    Args:
        image_path: path of the image file, as a string.
        w: expected mask width in pixels.
        h: expected mask height in pixels.

    Returns:
        uint8 {0,1} mask of shape ``(h, w)``. An all-zero mask is returned when
        the mask file does not exist or the stored array is empty (for example
        a stack of zero instance masks).
    """
    mask_path = (
        image_path
        .replace("train_images", "train_masks")
        .replace("supplemental_images", "supplemental_masks")
        .replace("/forged", "")
        .replace(".png", ".npy")
    )
    if not Path(mask_path).exists():
        return np.zeros((h, w), dtype=np.uint8)

    # NOTE(review): allow_pickle=True will unpickle object arrays, which can
    # execute arbitrary code; only load mask files from a trusted source.
    arr = np.load(mask_path, allow_pickle=True)

    # A zero-size array would make the max-reduction below raise.
    if arr.size == 0:
        return np.zeros((h, w), dtype=np.uint8)

    if arr.ndim == 3:
        # Collapse the stack along axis 0 into one mask.
        arr = arr.max(axis=0)
    gt_mask = (arr > 0).astype(np.uint8)
    if gt_mask.shape[:2] != (h, w):
        # Nearest-neighbour keeps the mask binary.
        gt_mask = cv2.resize(gt_mask, (w, h), interpolation=cv2.INTER_NEAREST)
    return gt_mask

def apply_overlay_rgb(img_rgb: np.ndarray, mask01: np.ndarray, color=(255, 0, 0), alpha=0.45):
    """
    Tint the masked pixels of an image with a colour; the input is not modified.

    img_rgb: (H,W,3) uint8
    mask01:  (H,W), any positive value counts as "inside the mask"
    color:   RGB colour blended into the masked pixels
    alpha:   weight of the colour (1 - alpha is the weight of the image)

    Returns a copy of img_rgb with the blend applied where the mask is set.
    """
    blended = img_rgb.copy()
    selected = mask01 > 0
    if not selected.any():
        return blended

    tint = np.array(color, dtype=np.float32)
    base = blended[selected].astype(np.float32)
    blended[selected] = (alpha * tint + (1 - alpha) * base).astype(np.uint8)
    return blended

def draw_yolo_boxes(img_rgb: np.ndarray, boxes_xyxy: np.ndarray, color=(255, 0, 0), thickness=2):
    """
    Draw rectangles for xyxy boxes on an RGB image.

    The image is converted to BGR for drawing and back to RGB afterwards, so
    the RGB ``color`` is reversed before it is handed to cv2.

    boxes_xyxy: (N,4) float or int in xyxy; coordinates are rounded to ints.

    Returns the input image object itself when there is nothing to draw,
    otherwise a new array with the boxes drawn.
    """
    nothing_to_draw = boxes_xyxy is None or len(boxes_xyxy) == 0
    if nothing_to_draw:
        return img_rgb

    bgr_color = color[::-1]
    canvas = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
    for box in boxes_xyxy:
        left, top, right, bottom = (int(round(coord)) for coord in box.tolist())
        cv2.rectangle(canvas, (left, top), (right, bottom), bgr_color, thickness)
    return cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)

def overlay_heatmap_on_rgb(img_rgb: np.ndarray, heat01: np.ndarray, alpha=0.45):
    """
    Blend a JET-coloured heatmap over an RGB image.

    img_rgb: (H,W,3) uint8
    heat01:  (H,W) float, clipped to [0,1] before colouring
    alpha:   weight of the heatmap (1 - alpha is the weight of the image)

    Returns the blended image as uint8.
    """
    clipped = np.clip(heat01, 0.0, 1.0)
    heat_u8 = (clipped * 255).astype(np.uint8)

    # applyColorMap produces BGR; convert so it matches the RGB image.
    jet_rgb = cv2.cvtColor(cv2.applyColorMap(heat_u8, cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB)

    blend = alpha * jet_rgb.astype(np.float32) + (1 - alpha) * img_rgb.astype(np.float32)
    return blend.astype(np.uint8)

def normalize01(x: np.ndarray, eps=1e-6):
    """Min-max scale ``x`` to the range [0, 1] as float32.

    Args:
        x: array (or array-like) of numbers.
        eps: if ``max - min`` is smaller than this, the input is treated as
            constant.

    Returns:
        float32 array with the shape of ``x``. All zeros when the input is
        constant; an empty array when the input is empty.
    """
    x = np.asarray(x, dtype=np.float32)
    if x.size == 0:
        # min()/max() raise on zero-size arrays; there is nothing to scale.
        return np.zeros_like(x, dtype=np.float32)

    lo, hi = float(x.min()), float(x.max())
    span = hi - lo
    if span < eps:
        return np.zeros_like(x, dtype=np.float32)
    return (x - lo) / span


# Debug aid: only used by the commented-out early-exit lines at the end of the
# loop body; it has no effect while they stay commented out.
break_loop = 3

# One evaluation score per processed image, in sorted file order.
scores = []
for img_path in tqdm(sorted(supplemental_images.glob("*.png")), desc="Inference"):
    case_id = img_path.stem

    image = Image.open(img_path).convert("RGB")
    W, H = image.size

    # Full-resolution accumulators: binary union of the per-crop masks, plus
    # sum / count of probabilities so overlapping crops can be averaged.
    predicted_mask = np.zeros((H, W), dtype=np.uint8)
    prob_accum = np.zeros((H, W), dtype=np.float32)
    prob_count = np.zeros((H, W), dtype=np.float32)

    preds = detect_model.predict(
        source=str(img_path),
        conf=Config.yolo_conf,
        iou=Config.yolo_iou,
        imgsz=Config.yolo_img_size,
        verbose=False
    )

    boxes = preds[0].boxes
    crops = []  # per-box crops of the current image, kept for ad-hoc inspection

    if boxes is None or len(boxes) <= 1:
        # Zero or one detection: segment the whole image in a single pass.
        pred_bin, prob_full, prob_small = get_seg_mask_raw(image, H=H, W=W)
        predicted_mask = np.maximum(predicted_mask, pred_bin)

        prob_accum += prob_full
        prob_count += 1.0

        # nothing to draw in the visualisation
        boxes_xyxy = None
    else:
        boxes_xyxy = boxes.xyxy.cpu().numpy()  # for drawing

        for bbox in boxes_xyxy:
            x1, y1, x2, y2 = bbox.tolist()

            # ints + clamp to the image bounds
            x1 = int(max(0, x1)); y1 = int(max(0, y1))
            x2 = int(min(W, x2)); y2 = int(min(H, y2))

            # shrink the box by `margin` pixels on every side
            x1m = int(min(W, x1 + margin))
            y1m = int(min(H, y1 + margin))
            x2m = int(max(0, x2 - margin))
            y2m = int(max(0, y2 - margin))

            # boxes that collapse after shrinking are skipped
            if x2m <= x1m or y2m <= y1m:
                continue

            crop = image.crop((x1m, y1m, x2m, y2m))
            cw, ch = crop.size
            crops.append(crop)

            pred_bin, crop_prob_full, crop_prob_small = get_seg_mask_raw(crop, H=ch, W=cw)

            # paste the crop results back at their location in the full image
            predicted_mask[y1m:y2m, x1m:x2m] = np.maximum(
                predicted_mask[y1m:y2m, x1m:x2m],
                pred_bin
            )

            prob_accum[y1m:y2m, x1m:x2m] += crop_prob_full
            prob_count[y1m:y2m, x1m:x2m] += 1.0

    # average probability where we have coverage (uncovered pixels stay 0)
    prob_vis = prob_accum / np.maximum(prob_count, 1e-6)

    # CC filter ONCE (global)
    predicted_mask = remove_small_components(predicted_mask, min_cc_area)

    # Image-level gating on mask area and mean probability inside the mask.
    area = int(predicted_mask.sum())
    if area > 0:
        mean_inside = float(prob_vis[predicted_mask == 1].mean())
    else:
        mean_inside = 0.0

    is_forged = (area >= area_thres) and (mean_inside >= mean_in_thres)

    if not is_forged:
        predicted_mask[:] = 0

    # Evaluate vs GT (supplemental has masks)
    gt_mask = load_gt_mask_from_image_path(str(img_path), w=W, h=H)

    label_rles = rle_encode([gt_mask])
    prediction_rles = rle_encode([predicted_mask])

    score = evaluate_single_image(
        label_rles=label_rles,
        prediction_rles=prediction_rles,
        shape_str=json.dumps([int(H), int(W)])
    )
    scores.append(score)

    # ----- visualize -----
    img_np = np.array(image)  # RGB uint8

    img_with_boxes = draw_yolo_boxes(img_np, boxes_xyxy, color=(255, 0, 0), thickness=3)

    pred_overlay = apply_overlay_rgb(img_np, predicted_mask, color=(255, 0, 0), alpha=0.45)  # red
    gt_overlay   = apply_overlay_rgb(img_np, gt_mask,        color=(0, 255, 0), alpha=0.45)  # green

    both_overlay = img_np.copy()
    both_overlay = apply_overlay_rgb(both_overlay, gt_mask,        color=(0, 255, 0), alpha=0.35)
    both_overlay = apply_overlay_rgb(both_overlay, predicted_mask, color=(255, 0, 0), alpha=0.35)
    both_overlay = draw_yolo_boxes(both_overlay, boxes_xyxy, color=(255, 0, 0), thickness=3)

    # heatmap + overlay
    heat = np.clip(prob_vis, 0.0, 1.0)
    heat_overlay = overlay_heatmap_on_rgb(img_np, heat, alpha=0.45)
    heat_overlay = draw_yolo_boxes(heat_overlay, boxes_xyxy, color=(255, 0, 0), thickness=3)

    # (optional) show "contrast enhanced" heat for better viewing (doesn't change model)
    heat_norm = normalize01(heat)
    heat_norm_overlay = overlay_heatmap_on_rgb(img_np, heat_norm, alpha=0.45)
    heat_norm_overlay = draw_yolo_boxes(heat_norm_overlay, boxes_xyxy, color=(255, 0, 0), thickness=3)

    # (title, image, imshow kwargs) for the 3x3 grid, in row-major order
    panels = [
        ("Input + YOLO boxes", img_with_boxes, {}),
        (f"Pred mask (score={score:.4f})", predicted_mask, {"cmap": "gray"}),
        ("GT mask", gt_mask, {"cmap": "gray"}),
        ("Pred overlay (red)", pred_overlay, {}),
        ("GT overlay (green)", gt_overlay, {}),
        ("GT+Pred overlay + boxes", both_overlay, {}),
        ("Prob heatmap (0..1)", heat, {"cmap": "jet", "vmin": 0, "vmax": 1}),
        ("Heatmap overlay + boxes", heat_overlay, {}),
        ("Heatmap normalized overlay (viz)", heat_norm_overlay, {}),
    ]

    fig, axes = plt.subplots(3, 3, figsize=(22, 14))
    for ax, (title, panel, imshow_kwargs) in zip(axes.ravel(), panels):
        ax.set_title(title)
        ax.imshow(panel, **imshow_kwargs)
        ax.axis("off")

    fig.tight_layout()
    plt.show()
    # One figure is created per image; close it so figures do not pile up.
    plt.close(fig)

    # break_loop -= 1
    # if break_loop == 0:
    #     break

    # # for crop in crops:
    # #     plt.title(f"mean_pixel: {np.mean(np.asarray(crop))}")
    # #     plt.imshow(crop)
    # #     plt.show();

In [36]:
# Mean per-image score; NaN (instead of ZeroDivisionError) when no image was scored.
# NOTE(review): the saved output of this cell is 0.0 -- worth confirming that
# the image-level gating is not zeroing every prediction and that the GT mask
# paths actually resolve.
sum(scores) / len(scores) if scores else float("nan")

np.float64(0.0)