# In [None]:
import os
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image


# ============================================================
# MODULE 1: Mask Encoding and Visualization
# ============================================================

def rle_encode(mask, fg_val=1):
    """Run-length encode the pixels of ``mask`` that equal ``fg_val``.

    Pixels are scanned in column-major order (``mask.T.flatten()``: down each
    column, then left to right). Each run of consecutive foreground pixels
    becomes a pair ``(start, length)`` with a 1-indexed ``start``, and the
    pairs are flattened into a single list.

    Args:
        mask: Array-like mask (anything ``np.asarray`` accepts).
        fg_val: Pixel value treated as foreground.

    Returns:
        list[int]: ``[start_1, length_1, start_2, length_2, ...]`` as plain
        Python ints, so the result is directly JSON-serializable. The list
        is empty when no pixel equals ``fg_val``.

    >>> rle_encode(np.array([[1, 0], [1, 1]]))
    [1, 2, 4, 1]
    """
    pixels = np.asarray(mask).T.flatten() == fg_val
    # Pad with background on both sides so every run has exactly one rising
    # and one falling edge; an edge is where neighbouring values differ.
    padded = np.concatenate(([False], pixels, [False]))
    edges = np.flatnonzero(padded[1:] != padded[:-1])
    starts, ends = edges[0::2], edges[1::2]
    # Interleave (1-indexed start, run length) pairs and convert to Python ints.
    return np.column_stack((starts + 1, ends - starts)).ravel().tolist()


def visualize_mask(mask, title="Mask"):
    """Display a 2-D binary mask with a cell grid and each pixel's value.

    Args:
        mask: 2-D array indexed ``mask[row, col]``. It is drawn in grey with
            ``vmin=0, vmax=1`` (0 -> black, 1 -> white).
        title: Figure title.
    """
    n_rows, n_cols = mask.shape
    plt.figure(figsize=(6, 6))
    plt.imshow(mask, cmap='gray', vmin=0, vmax=1)
    plt.title(title)
    plt.axis('off')

    # Cell borders sit at half-integer offsets because imshow centres pixel
    # (row, col) at (col, row). Rows and columns are counted separately so
    # non-square masks get a complete, correctly sized grid.
    for r in range(n_rows + 1):
        plt.axhline(r - 0.5, color='red', alpha=0.3, linewidth=0.5)
    for c in range(n_cols + 1):
        plt.axvline(c - 0.5, color='red', alpha=0.3, linewidth=0.5)

    # Pixel values: blue on black (0) cells, white on everything else.
    for r in range(n_rows):
        for c in range(n_cols):
            value = mask[r, c]
            plt.text(
                c, r, str(value),
                ha='center', va='center',
                color='blue' if value == 0 else 'white',
                fontweight='bold'
            )
    plt.show()


# ============================================================
# MODULE 2: Example Mask Generators
# ============================================================

def create_plus_mask(size=9):
    """Return a ``size x size`` uint8 mask containing a centred plus sign.

    The plus is two one-pixel-wide bars of ones crossing at
    ``(size // 2, size // 2)``. Each bar covers indices ``2`` through
    ``size - 3``, leaving a two-pixel empty border.
    """
    centre = size // 2
    bar = slice(2, size - 2)
    plus = np.zeros((size, size), dtype=np.uint8)
    plus[bar, centre] = 1   # vertical bar
    plus[centre, bar] = 1   # horizontal bar
    return plus


def create_minus_mask(size=9):
    """Return a ``size x size`` uint8 mask containing a centred horizontal bar.

    Row ``size // 2`` is set to one over columns ``2`` through ``size - 3``;
    every other pixel is zero.
    """
    minus = np.zeros((size, size), dtype=np.uint8)
    minus[size // 2, slice(2, size - 2)] = 1
    return minus


# ============================================================
# MODULE 3: Mask Distribution Analysis
# ============================================================

def _load_2d_mask(mask_path):
    """Load a ``.npy`` mask and return it as 2-D, or None if it cannot be made 2-D.

    A 3-D array with a singleton leading axis ``(1, H, W)`` or trailing axis
    ``(H, W, 1)`` is reduced to ``(H, W)``; any other non-2-D shape yields None.
    ``np.load`` keeps its default ``allow_pickle=False``, so pickled payloads
    are rejected rather than executed.
    """
    mask = np.load(mask_path)
    if mask.ndim == 3:
        if mask.shape[0] == 1:
            mask = mask[0]
        elif mask.shape[-1] == 1:
            mask = mask[..., 0]
    return mask if mask.ndim == 2 else None


def _bin_foreground(mask, n_rows, n_cols):
    """Count a 2-D mask's foreground (> 0) pixels per heatmap cell.

    Pixel ``(y, x)`` of an ``H x W`` mask falls into cell
    ``(y * n_rows // H, x * n_cols // W)``. Because ``y < H`` and ``x < W``,
    the index is always in range. Returns a flat int64 array of length
    ``n_rows * n_cols`` in row-major cell order.
    """
    height, width = mask.shape
    ys, xs = np.nonzero(mask > 0)
    rows = ys * n_rows // height
    cols = xs * n_cols // width
    return np.bincount(rows * n_cols + cols, minlength=n_rows * n_cols)


def mask_distribution(train_masks_dir, heatmap_size=(100, 100)):
    """Build a heatmap of where foreground pixels fall across the training masks.

    Every ``*.npy`` file in ``train_masks_dir`` is loaded. Each pixel ``> 0``
    of a 2-D mask (or of a 3-D mask with a leading/trailing singleton axis)
    is counted in a ``heatmap_size`` grid. Positions are normalised by each
    mask's own height/width, so masks of different sizes are comparable.
    Files that fail to load or process are skipped with a warning; masks of
    any other shape are skipped silently.

    Args:
        train_masks_dir: Directory containing the ``.npy`` mask files.
        heatmap_size: ``(rows, cols)`` of the heatmap grid.

    Returns:
        tuple: ``(hottest_pos, heatmap)``.
            ``hottest_pos`` is ``(x, y)``: the normalised top-left corner of
            the most populated cell, or ``(0.5, 0.5)`` when no foreground
            pixel was found.
            ``heatmap`` is a float32 array of shape ``heatmap_size``, or
            ``None`` if ``train_masks_dir`` does not exist.
    """
    import warnings

    if not os.path.exists(train_masks_dir):
        return (0.5, 0.5), None

    n_rows, n_cols = heatmap_size
    # Accumulate exact integer counts; float32 would lose precision past 2**24.
    counts = np.zeros(n_rows * n_cols, dtype=np.int64)

    for mask_file in os.listdir(train_masks_dir):
        if not mask_file.endswith('.npy'):
            continue

        mask_path = os.path.join(train_masks_dir, mask_file)
        try:
            mask = _load_2d_mask(mask_path)
            if mask is None:
                continue
            file_counts = _bin_foreground(mask, n_rows, n_cols)
        except Exception as err:
            # Best effort: one bad file must not abort the scan, but say so.
            warnings.warn(f"Skipping mask {mask_path!r}: {err}")
            continue
        counts += file_counts

    heatmap = counts.reshape(n_rows, n_cols).astype(np.float32)
    if not counts.any():
        return (0.5, 0.5), heatmap

    # argmax on the exact integer counts (ties -> first cell in row-major order).
    max_row, max_col = np.unravel_index(np.argmax(counts), (n_rows, n_cols))
    return (float(max_col / n_cols), float(max_row / n_rows)), heatmap


def plot_heatmap(heatmap):
    """Show a heatmap on axes labelled in normalised image coordinates.

    The array is stretched over the unit square, so the tick labels read as
    normalised x (columns, left to right) and y (rows, top to bottom) in
    ``[0, 1]`` instead of raw bin indices. This matches the axis labels.

    Args:
        heatmap: 2-D array indexed ``[row, col]``, e.g. the counts returned
            by ``mask_distribution``.
    """
    plt.figure(figsize=(10, 8))
    # extent is (left, right, bottom, top); bottom=1 / top=0 keeps row 0 at
    # the top, consistent with imshow's default origin='upper'.
    plt.imshow(heatmap, cmap='hot', interpolation='nearest', extent=(0, 1, 1, 0))
    plt.colorbar()
    plt.title('Forgery Location Heatmap')
    plt.xlabel('Normalized X')
    plt.ylabel('Normalized Y')
    plt.show()


# ============================================================
# MODULE 4: Submission Generation
# ============================================================

def generate_submission(test_images_dir, sample_csv, hottest_pos, output_path='submission.csv'):
    """Generate submission CSV file with RLE or 'authentic' labels"""
    sample_submission = pd.read_csv(sample_csv)
    submission_data = []

    for case_id in sample_submission['case_id']:
        img_path = os.path.join(test_images_dir, f"{case_id}.png")

        with Image.open(img_path) as img:
            width, height = img.size

        if np.random.random() < 0.01:
            mask = np.zeros((height, width), dtype=np.uint8)
            center_x = int(hottest_pos[0] * width)
            center_y = int(hottest_pos[1] * height)
            center_x = np.clip(center_x, 4, width - 5)
            center_y = np.clip(center_y, 4, height - 5)

            h = w = min(8, min(width, height) // 20)
            mask[center_y - h//2:center_y + h//2, center_x - w//2:center_x + w//2] = 1

            rle = rle_encode(mask)
            annotation = json.dumps([int(x) for x in rle])
        else:
            annotation = 'authentic'

        submission_data.append({'case_id': case_id, 'annotation': annotation})

    submission = pd.DataFrame(submission_data)
    submission.to_csv(output_path, index=False)
    print(f"âœ… Submission saved to {output_path}")


# ============================================================
# MAIN EXECUTION
# ============================================================

if __name__ == "__main__":
    # Seed NumPy's global RNG so the random draws made inside
    # generate_submission are identical on every run.
    np.random.seed(52)

    # --- 1. Toy masks: print each RLE, then draw the mask ---
    example = np.array([[1, 0], [1, 1]])
    print(f"Our example:\n{example}")
    print(f"RLE encoding: {rle_encode(example)}")
    visualize_mask(example, "Example Mask")

    plus_mask, minus_mask = create_plus_mask(), create_minus_mask()
    for _shape_name, _shape_mask in (("Plus", plus_mask), ("Minus", minus_mask)):
        print(f"RLE {_shape_name}: {rle_encode(_shape_mask)}")
        visualize_mask(_shape_mask, f"{_shape_name} Mask")

    # --- 2. Where do forged regions sit in the training masks? ---
    data_root = '/kaggle/input/recodai-luc-scientific-image-forgery-detection'
    train_masks_dir = f'{data_root}/train_masks'
    hottest_pos, heatmap = mask_distribution(train_masks_dir)
    if heatmap is not None:  # None means the masks directory was not found
        plot_heatmap(heatmap)

    # --- 3. Write the submission file ---
    test_images_dir = f'{data_root}/test_images'
    sample_csv = f'{data_root}/sample_submission.csv'
    generate_submission(test_images_dir, sample_csv, hottest_pos)
