# In [None]:
import os
import json
import cv2
import numpy as np
import pandas as pd
from pathlib import Path
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm

# Decision threshold for pipeline_final: an image is labelled "forged" when its
# normalized mean intensity (gray.mean() / 255) is strictly greater than this.
# NOTE(review): for 8-bit images that score is at most 1.0, so any value above
# 1.0 labels every image "authentic" — confirm this is intended (e.g. an
# all-authentic baseline submission).
THRESHOLD: float = 1.15

# Define the RLE encoder
def rle_encode(mask: np.ndarray, fg_val: int = 1) -> str:
    """Run-length encode the pixels of ``mask`` that equal ``fg_val``.

    Pixels are scanned in column-major (Fortran) order: down each column
    before moving to the next. Every run of consecutive foreground pixels is
    emitted as a ``(start, length)`` pair, where ``start`` is the 1-based
    index of the run's first pixel in that flattened order.

    Args:
        mask: Array to encode (typically a 2-D binary mask). Array-likes are
            accepted and converted with ``np.asarray``.
        fg_val: Pixel value treated as foreground.

    Returns:
        ``"authentic"`` when no pixel equals ``fg_val``; otherwise a JSON list
        ``[start1, len1, start2, len2, ...]`` of plain Python ints.
    """
    # Transposing before flattening gives column-major order for a 2-D mask.
    pixels = np.asarray(mask).T.ravel()
    is_fg = (pixels == fg_val).astype(np.int8)

    if not is_fg.any():
        return "authentic"

    # Pad with background on both sides so every run has both a rising and a
    # falling edge; np.diff then marks all run boundaries in one vectorized
    # pass instead of a Python loop over every foreground pixel.
    edges = np.flatnonzero(np.diff(np.concatenate(([0], is_fg, [0]))))
    starts, ends = edges[0::2], edges[1::2]  # 0-based start, exclusive end

    runs = np.empty(2 * starts.size, dtype=np.int64)
    runs[0::2] = starts + 1       # 1-based start index
    runs[1::2] = ends - starts    # run length
    return json.dumps(runs.tolist())

# Define example pipeline (replace model logic here)
def pipeline_final(pil_img: Image.Image):
    """Placeholder forgery pipeline: score an image and build a dummy mask.

    The image is reduced to grayscale. Its normalized mean intensity
    (``gray.mean() / 255``) serves as a fake "forgery" score. Every pixel
    brighter than the raw mean is marked as mask foreground.

    Args:
        pil_img: Input image. The callers in this file pass RGB images;
            single-channel images are used directly as grayscale.

    Returns:
        ``(label, mask, dbg)``:
            - ``label``: ``"forged"`` if the score exceeds the module-level
              ``THRESHOLD``, else ``"authentic"``.
            - ``mask``: ``uint8`` array of 0/1 with the grayscale image's shape.
            - ``dbg``: dict with ``"mean_inside"`` (the score — note it is the
              mean of the whole image, not of the masked region), ``"area"``
              (number of mask pixels) and ``"thr"`` (the threshold used).

    NOTE(review): for 8-bit input the score is at most 1.0, so a THRESHOLD
    above 1.0 means the "forged" branch never fires.
    """
    img = np.array(pil_img)

    # A 2-D array is already grayscale; cv2's RGB->GRAY conversion would
    # reject it.
    gray = img if img.ndim == 2 else cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)

    # Compute the mean once and reuse it for both the score and the mask.
    raw_mean = gray.mean()
    mean_intensity = raw_mean / 255.0
    mask = (gray > raw_mean).astype(np.uint8)

    # Plain Python scalars keep the debug dict easy to serialize and print.
    dbg = {
        "mean_inside": float(mean_intensity),
        "area": int(mask.sum()),
        "thr": THRESHOLD,
    }

    label = "forged" if mean_intensity > THRESHOLD else "authentic"
    return label, mask, dbg

# Define main inference function
def main():
    """Run the pipeline on every test image and write ``submission.csv``.

    Steps:
        1. Score each file in the test directory with ``pipeline_final``.
        2. For images not labelled ``"authentic"``, run-length encode the mask
           with ``rle_encode``.
        3. Align the predictions to the case ids in ``sample_submission.csv``.
           Cases without a prediction default to ``"authentic"``.
        4. Write ``case_id``/``annotation`` to ``submission.csv``.
        5. Visualize a few test images.
    """
    # Define paths
    TEST_DIR = "/kaggle/input/recodai-luc-scientific-image-forgery-detection/test_images"
    SAMPLE_SUB = "/kaggle/input/recodai-luc-scientific-image-forgery-detection/sample_submission.csv"
    OUT_PATH = "submission.csv"

    test_dir = Path(TEST_DIR)
    rows = []

    for f in tqdm(sorted(os.listdir(TEST_DIR)), desc="Inference on Test Set"):
        img_path = test_dir / f
        # Sub-directories (e.g. checkpoint folders) are not images.
        if not img_path.is_file():
            continue

        # Close the file handle right away. convert() loads the pixel data,
        # so the RGB copy stays valid after the file is closed.
        with Image.open(img_path) as im:
            pil = im.convert("RGB")

        label, mask, dbg = pipeline_final(pil)

        # Normalize the mask to uint8; fall back to an all-background
        # (height, width) mask when the pipeline returns none.
        if mask is not None:
            mask = np.array(mask, dtype=np.uint8)
        else:
            mask = np.zeros(pil.size[::-1], np.uint8)

        if label == "authentic":
            annot = "authentic"
        else:
            annot = rle_encode((mask > 0).astype(np.uint8))

        rows.append({
            "case_id": Path(f).stem,
            "annotation": annot,
            "area": int(dbg.get("area", mask.sum())),
            "mean": float(dbg.get("mean_inside", 0.0)),
            "thr": float(dbg.get("thr", 0.0)),
        })

    # Explicit columns keep the frame well-formed even when no images were
    # found (otherwise sub["case_id"] below would raise KeyError).
    sub = pd.DataFrame(rows, columns=["case_id", "annotation", "area", "mean", "thr"])
    ss = pd.read_csv(SAMPLE_SUB)

    # Compare ids as strings on both sides so the merge keys line up.
    ss["case_id"] = ss["case_id"].astype(str)
    sub["case_id"] = sub["case_id"].astype(str)

    # Keep exactly the sample-submission rows, in its order.
    final = ss[["case_id"]].merge(sub, on="case_id", how="left")
    final["annotation"] = final["annotation"].fillna("authentic")

    final[["case_id", "annotation"]].to_csv(OUT_PATH, index=False)
    print(f"\nSaved submission file: {OUT_PATH}")
    print(final.head(10))

    visualize_results(TEST_DIR)

# Define visualization function
def visualize_results(test_dir: str, n_samples: int = 5):
    """Print and plot pipeline predictions for the first few test images.

    Images labelled ``"authentic"`` are shown alone. For any other label, the
    original is shown next to the predicted mask overlaid on it.

    Args:
        test_dir: Directory containing the test images.
        n_samples: How many files (in sorted filename order) to show.
            Defaults to 5.
    """
    test_path = Path(test_dir)
    # Only regular files are images; skip sub-directories.
    sample_files = [
        f for f in sorted(os.listdir(test_dir)) if (test_path / f).is_file()
    ][:n_samples]

    for f in sample_files:
        # Close the file handle right away; convert() has loaded the pixels.
        with Image.open(test_path / f) as im:
            pil = im.convert("RGB")
        label, mask, dbg = pipeline_final(pil)

        # Normalize the mask to uint8; fall back to an all-background
        # (height, width) mask when the pipeline returns none.
        if mask is not None:
            mask = np.array(mask, dtype=np.uint8)
        else:
            mask = np.zeros(pil.size[::-1], np.uint8)

        print(f"{'🔴' if label == 'forged' else '🟢'} {f}: {label} | area={mask.sum()} mean={dbg.get('mean_inside', 0):.3f}")

        if label == "authentic":
            fig = plt.figure(figsize=(5, 5))
            plt.imshow(pil)
            plt.title(f"{f} — Authentic")
            plt.axis("off")
        else:
            fig = plt.figure(figsize=(10, 5))
            plt.subplot(1, 2, 1)
            plt.imshow(pil)
            plt.title("Original")
            plt.axis("off")

            plt.subplot(1, 2, 2)
            plt.imshow(pil)
            plt.imshow(mask, alpha=0.45, cmap="Blues")
            plt.title("Predicted Mask")
            plt.axis("off")

        plt.show()
        # Release the figure so figures don't pile up under non-interactive
        # backends where show() does not close them.
        plt.close(fig)

# Entry point: run inference, write the submission, then visualize.
# A notebook kernel also executes cells with __name__ == "__main__".
if __name__ == "__main__":
    main()