In [None]:
import cv2
import numpy as np
from PIL import Image, ImageChops, ImageEnhance
from typing import Optional, Dict, Tuple
import os


# ============================================================
# DETECT REAL IMAGE FORMAT
# ============================================================

def detect_format(path: str) -> Optional[str]:
    """
    Identify an image's format by opening it with PIL (not from its extension).

    Args:
        path: Location of the image file on disk.
    Returns:
        The PIL format name in lowercase (e.g. "jpeg", "png"), or None when the
        file does not exist, PIL fails to open it, or PIL reports no format.
        Failures are reported with an "[ERROR]" print rather than raised.
    """
    if os.path.exists(path):
        try:
            with Image.open(path) as handle:
                name = handle.format
        except Exception as exc:
            print(f"[ERROR] Cannot detect format: {exc}")
            return None
        return name.lower() if name else None

    print(f"[ERROR] File not found: {path}")
    return None


# ============================================================
# JPEG ONLY — ERROR LEVEL ANALYSIS (ELA)
# ============================================================

def _exif_orientation(image) -> int:
    """Return the EXIF Orientation tag (0x0112) of a PIL image, or 1 ("normal") if absent/unreadable."""
    try:
        return int(image.getexif().get(0x0112, 1))
    except Exception:  # malformed EXIF must not abort the analysis
        return 1


def _apply_exif_orientation(image, orientation: int):
    """Transpose a PIL image the way EXIF ``orientation`` (1-8) requests; 1/unknown -> unchanged."""
    # Pillow >= 9.1 exposes the Image.Transpose enum; older versions use module constants.
    transpose = getattr(Image, "Transpose", Image)
    method = {
        2: transpose.FLIP_LEFT_RIGHT,
        3: transpose.ROTATE_180,
        4: transpose.FLIP_TOP_BOTTOM,
        5: transpose.TRANSPOSE,
        6: transpose.ROTATE_270,
        7: transpose.TRANSVERSE,
        8: transpose.ROTATE_90,
    }.get(orientation)
    return image if method is None else image.transpose(method)


def ela(path: str, quality: int = 90, output_path: str = "ela_result.png") -> str:
    """
    Perform Error Level Analysis (intended for JPEG input).

    The image is decoded as RGB and re-encoded as JPEG at ``quality`` in
    memory. The absolute per-pixel difference between the two is brightened
    so that the largest channel difference maps to 255. The map is computed
    in the file's stored pixel layout and only then rotated/flipped according
    to the EXIF Orientation tag. This keeps it aligned with the OpenCV-read
    maps (noise, sharpness, block variance), because OpenCV applies EXIF
    orientation when it reads an image.

    Args:
        path: Image file to analyse.
        quality: JPEG quality used for the re-encode.
        output_path: Where the ELA image is written (PNG by default).
    Returns:
        ``output_path``.
    Raises:
        ValueError: If PIL cannot open or decode ``path``.
    """
    import io  # local import: only ELA needs an in-memory JPEG buffer

    try:
        with Image.open(path) as src:
            original = src.convert("RGB")
            orientation = _exif_orientation(src)
    except Exception as e:
        raise ValueError(f"Unable to open image for ELA: {e}") from e

    # Re-encode in memory instead of via a shared temp file on disk.
    buffer = io.BytesIO()
    original.save(buffer, "JPEG", quality=quality)
    buffer.seek(0)
    with Image.open(buffer) as compressed:
        # Difference map (difference() loads both images while the buffer is open).
        ela_img = ImageChops.difference(original, compressed)

    # getextrema() on an RGB image yields one (min, max) pair per band.
    max_diff = max(high for _, high in ela_img.getextrema())
    scale = 255.0 / max_diff if max_diff > 0 else 1
    ela_img = ImageEnhance.Brightness(ela_img).enhance(scale)

    ela_img = _apply_exif_orientation(ela_img, orientation)
    ela_img.save(output_path)
    return output_path


# ============================================================
# JPEG ONLY — BLOCK VARIANCE ANALYSIS
# ============================================================

def block_variance(path: str) -> str:
    """
    Compute block variance map for JPEG images.
    Returns saved file path.
    """
    img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        raise FileNotFoundError(f"Could not read image: {path}")

    h, w = img.shape
    block = 8

    var_map = np.zeros((h // block, w // block))

    for i in range(0, h - block, block):
        for j in range(0, w - block, block):
            patch = img[i:i+block, j:j+block]
            var_map[i // block, j // block] = float(np.var(patch))

    var_map = cv2.resize(var_map, (w, h), interpolation=cv2.INTER_NEAREST)
    var_map = cv2.normalize(var_map, None, 0, 255, cv2.NORM_MINMAX)

    output = "block_variance.png"
    cv2.imwrite(output, var_map)
    return output


# ============================================================
# UNIVERSAL — NOISE MAP
# ============================================================

def noise_map(path: str) -> str:
    """
    Save a noise-residual map of the image as "noise_map.png".

    The residual is |gray - GaussianBlur(gray, 5x5)|, i.e. the fine detail
    that a small blur removes. It is min-max stretched to 0-255.

    Args:
        path: Image file to analyse (read as grayscale).
    Returns:
        Path of the written map.
    Raises:
        FileNotFoundError: If OpenCV cannot read ``path``.
    """
    gray = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if gray is None:
        raise FileNotFoundError(f"Could not read image: {path}")

    smoothed = cv2.GaussianBlur(gray, (5, 5), 0)
    residual = cv2.normalize(
        cv2.absdiff(gray, smoothed), None, 0, 255, cv2.NORM_MINMAX
    )

    destination = "noise_map.png"
    cv2.imwrite(destination, residual)
    return destination


# ============================================================
# UNIVERSAL — SHARPNESS MAP
# ============================================================

def sharpness_map(path: str) -> str:
    """
    Save an edge/sharpness map of the image as "sharpness_map.png".

    The map is the absolute Laplacian of the grayscale image, computed in
    float64 and min-max scaled to 0-255.

    Args:
        path: Image file to analyse (read as grayscale).
    Returns:
        Path of the written map.
    Raises:
        FileNotFoundError: If OpenCV cannot read ``path``.
    """
    gray = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if gray is None:
        raise FileNotFoundError(f"Could not read image: {path}")

    # CV_64F keeps negative Laplacian responses before taking magnitudes.
    edges = np.abs(cv2.Laplacian(gray, cv2.CV_64F))
    edges = cv2.normalize(edges, None, 0, 255, cv2.NORM_MINMAX)

    # NOTE(review): the array is still float64 here; writing PNG relies on
    # cv2.imwrite converting it to 8-bit.
    target = "sharpness_map.png"
    cv2.imwrite(target, edges)
    return target


# ============================================================
# MAIN FORENSIC PIPELINE
# ============================================================

def forensic_pipeline(path: str) -> Dict[str, str]:
    """
    Run every applicable forensic method on an image.

    Processing order:
      * JPEG input only: ELA and block variance.
      * Always: noise map, sharpness map, then the combined tampering heatmap.

    Args:
        path: Image file to analyse.
    Returns:
        Mapping of method name ("ela", "block_variance", "noise_map",
        "sharpness_map", "heatmap") to the file each method wrote. JPEG-only
        keys are absent for other formats.
    Raises:
        ValueError: If the image format cannot be detected.
    """
    fmt = detect_format(path)
    if fmt is None:
        raise ValueError("Cannot detect format. File may be corrupted.")
    print(f"[INFO] Detected format: {fmt}")

    outputs: Dict[str, str] = {}

    if fmt != "jpeg":
        print("[INFO] JPEG-specific methods skipped.")
    else:
        print("[INFO] Running JPEG forensic methods...")
        outputs["ela"] = ela(path)
        outputs["block_variance"] = block_variance(path)

    print("[INFO] Running universal forensic methods...")
    for key, method in (("noise_map", noise_map), ("sharpness_map", sharpness_map)):
        outputs[key] = method(path)

    # The heatmap combines whichever maps were produced above.
    outputs["heatmap"] = tampering_heatmap(path, outputs)
    return outputs


# ============================================================
# COMPUTE FORENSIC SCORES
# ============================================================

def _read_score_map(results: Dict[str, str], key: str) -> np.ndarray:
    """Load the grayscale map saved under ``results[key]``, raising if OpenCV cannot read it."""
    map_path = results[key]
    image = cv2.imread(map_path, cv2.IMREAD_GRAYSCALE)
    if image is None:
        # cv2.imread signals failure with None; np.mean(None) would give a cryptic TypeError.
        raise FileNotFoundError(f"Could not read {key} map: {map_path}")
    return image


def compute_scores(results: Dict[str, str]) -> Dict[str, float]:
    """
    Compute numerical forensic scores from the maps written by the pipeline.

    Args:
        results: Mapping of method name to map file path. It must contain
            "noise_map" and "sharpness_map"; "ela" is optional.
    Returns:
        {"ela_mean": mean ELA intensity (0.0 when no ELA map exists),
         "noise_std": standard deviation of the noise map,
         "sharpness_mean": mean of the sharpness map}
    Raises:
        KeyError: If a required map is missing from ``results``.
        FileNotFoundError: If a listed map file cannot be read.
    """
    scores: Dict[str, float] = {}

    # ELA score (optional: only produced for JPEG input)
    if "ela" in results:
        scores["ela_mean"] = float(np.mean(_read_score_map(results, "ela")))
    else:
        scores["ela_mean"] = 0.0

    scores["noise_std"] = float(np.std(_read_score_map(results, "noise_map")))
    scores["sharpness_mean"] = float(np.mean(_read_score_map(results, "sharpness_map")))

    return scores


# ============================================================
# DECISION ENGINE
# ============================================================

def decide_fakeness(
    scores: Dict[str, float],
    fmt: Optional[str],
    *,
    ela_threshold: float = 40,
    noise_threshold: float = 25,
    sharpness_threshold: float = 60,
) -> Tuple[str, int]:
    """
    Decide whether an image is likely fake by a simple majority vote.

    Each score strictly above its threshold casts one vote:
      * "ela_mean" > ela_threshold (counted only when fmt == "jpeg"; for
        other formats the key is never read, so it may be absent)
      * "noise_std" > noise_threshold
      * "sharpness_mean" > sharpness_threshold
    Two or more votes give "FAKE", otherwise "REAL".

    Args:
        scores: Score mapping as produced by ``compute_scores``. Values may be
            Python or NumPy numbers.
        fmt: Detected image format (e.g. "jpeg"), or None.
        ela_threshold, noise_threshold, sharpness_threshold: Vote thresholds.
    Returns:
        ("FAKE" | "REAL", fraud_score), where fraud_score is the vote count
        (0-3) as a Python int.
    """
    ela_flag = fmt == "jpeg" and scores["ela_mean"] > ela_threshold
    noise_flag = scores["noise_std"] > noise_threshold
    sharp_flag = scores["sharpness_mean"] > sharpness_threshold

    # Count explicitly: NumPy scores produce np.bool_ flags, and
    # np.bool_ + np.bool_ is a logical OR (True -> 1), not an integer sum.
    fraud_score = sum(1 for flag in (ela_flag, noise_flag, sharp_flag) if flag)
    result = "FAKE" if fraud_score >= 2 else "REAL"
    return result, fraud_score


# ============================================================
# TAMPERING HEATMAP
# ============================================================

def _load_unit_map(map_path: str, label: str, size: Tuple[int, int]) -> np.ndarray:
    """Read a grayscale map as float32 in [0, 1], resized to ``size`` = (width, height) if it differs."""
    gray = cv2.imread(map_path, cv2.IMREAD_GRAYSCALE)
    if gray is None:
        raise FileNotFoundError(f"Unable to read {label} map: {map_path}")
    if (gray.shape[1], gray.shape[0]) != size:
        # Defensive: a map produced at another size/orientation (e.g. by a
        # non-OpenCV reader) would otherwise fail to broadcast in the sum below.
        gray = cv2.resize(gray, size, interpolation=cv2.INTER_LINEAR)
    return gray.astype(np.float32) / 255.0


def tampering_heatmap(path: str, results: Dict[str, str]) -> str:
    """
    Combine the forensic maps into a colour heatmap overlaid on the original.

    The anomaly score per pixel is a weighted sum: 0.5*noise + 0.5*sharpness,
    plus 0.7*ELA and 0.4*block variance when those maps are present. It is
    normalised to a peak of 1, Gaussian-blurred (21x21), JET-coloured, and
    alpha-blended onto the original (60% image / 40% heatmap).

    Args:
        path: Original image file.
        results: Mapping of method name to map file path. It must contain
            "noise_map" and "sharpness_map"; "ela" and "block_variance" are
            optional.
    Returns:
        Path of the saved overlay ("heatmap_overlay.png").
    Raises:
        FileNotFoundError: If the original image or any listed map cannot be read.
        KeyError: If a required map is missing from ``results``.
    """
    orig = cv2.imread(path)
    if orig is None:
        raise FileNotFoundError(f"Unable to read original image: {path}")
    size = (orig.shape[1], orig.shape[0])  # (width, height), the order cv2.resize expects

    anomaly = (
        _load_unit_map(results["noise_map"], "noise", size) * 0.5
        + _load_unit_map(results["sharpness_map"], "sharpness", size) * 0.5
    )
    if "ela" in results:
        anomaly += _load_unit_map(results["ela"], "ela", size) * 0.7
    if "block_variance" in results:
        anomaly += _load_unit_map(results["block_variance"], "block_variance", size) * 0.4

    peak = float(anomaly.max())
    if peak > 0:
        anomaly /= peak
    anomaly = cv2.GaussianBlur(anomaly, (21, 21), 0)

    heatmap = (anomaly * 255).astype(np.uint8)
    heatmap_color = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)

    overlay = cv2.addWeighted(orig, 0.6, heatmap_color, 0.4, 0)

    output = "heatmap_overlay.png"
    cv2.imwrite(output, overlay)
    return output



# ============================================================
# RUNNER
# ============================================================

# Image under analysis (path relative to the notebook's working directory).
image_path = "../../input/fake-2.jpg"

# forensic_pipeline writes the map files but does not return the format,
# so it is detected again here for the decision step.
results = forensic_pipeline(image_path)
fmt = detect_format(image_path)

# Reduce the maps to scalar scores and report them.
scores = compute_scores(results)
print(f"\n[INFO] Forensic scores: {scores}")

# Majority-vote verdict.
decision, fraud_score = decide_fakeness(scores, fmt)
print(f"\n[RESULT] Decision: {decision} (fraud score = {fraud_score})")

[INFO] Detected format: jpeg
[INFO] Running JPEG forensic methods...
[INFO] Running universal forensic methods...

[INFO] Forensic analysis complete. Results saved:
 - ela: ela_result.png
 - block_variance: block_variance.png
 - noise_map: noise_map.png
 - sharpness_map: sharpness_map.png

[INFO] Forensic scores: {'ela_mean': 2.0112096943048576, 'noise_std': 25.687977778493764, 'sharpness_mean': 7.175243404522613}

[RESULT] Decision: REAL (fraud score = 1)
