In [None]:
# Root directory of the competition input data; the test-image folder and the
# sample submission CSV are located by joining sub-paths onto this directory.
KAGGLEHUB_PATH = "/kaggle/input/recodai-luc-scientific-image-forgery-detection"

In [None]:
import numpy as np
import pandas as pd
import os
import csv
import warnings
import gc
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import cv2
from tqdm.auto import tqdm
import sys
import logging

# --- FINAL SUBMISSION CONFIGURATION ---
IMAGE_SIZE = 256            # Images are resized to IMAGE_SIZE x IMAGE_SIZE before inference.
MODEL_INPUT_CHANNELS = 4    # RGB (3) + ELA (1) channels stacked as the model input.
OUTPUT_FILENAME = "submission.csv"

# Path to the trained weights; the file is loaded as a state_dict into UNet.
FINAL_MODEL_PATH = "/kaggle/input/rluc-sfic-st/submission_B060_Final.pth"

# Inference Parameters
FIXED_THRESHOLD = 0.50      # A pixel is marked forged when its predicted probability is > this value.
MIN_FORGERY_AREA = 64       # Connected components smaller than this many pixels are discarded.
                            # Measured on the IMAGE_SIZE x IMAGE_SIZE mask, before resizing back.
# Tversky loss weights. They are only the default arguments of TverskyLoss;
# they play no part in loading the weights or in inference.
# beta weights false negatives, alpha weights false positives.
Tversky_BETA = 0.60
alpha = 0.40

# Kaggle Paths
TEST_IMAGE_ROOT = os.path.join(KAGGLEHUB_PATH, "test_images")
SAMPLE_SUBMISSION_FILE = os.path.join(KAGGLEHUB_PATH, "sample_submission.csv")

# --- CORE FUNCTIONS (Required for loading the model and inference) ---

class TverskyLoss(nn.Module):
    """Tversky loss: ``1 - (TP + s) / (TP + alpha*FP + beta*FN + s)`` over all elements.

    A loss module has no parameters, so it is not needed to load model weights.
    It is not referenced by the inference code in this notebook. It is kept for
    reference, with defaults taken from the ``alpha`` / ``Tversky_BETA`` config values.

    ``inputs`` are presumably per-pixel probabilities (e.g. sigmoid outputs) and
    ``targets`` binary masks with the same number of elements (TODO confirm).
    Both are flattened before TP/FP/FN are summed.
    """

    def __init__(self, alpha=alpha, beta=Tversky_BETA, smooth=1e-7):
        super().__init__()
        self.alpha = alpha    # weight on false positives
        self.beta = beta      # weight on false negatives
        self.smooth = smooth  # keeps the ratio defined when TP, FP and FN are all zero

    def forward(self, inputs, targets):
        # reshape (unlike view) also accepts non-contiguous tensors, e.g. after permute().
        inputs = inputs.reshape(-1)
        targets = targets.reshape(-1)
        tp = (inputs * targets).sum()
        fp = ((1 - targets) * inputs).sum()
        fn = (targets * (1 - inputs)).sum()
        tversky_index = (tp + self.smooth) / (tp + self.alpha * fp + self.beta * fn + self.smooth)
        return 1 - tversky_index

# Helper that builds the double-convolution unit used throughout the U-Net.
def ConvBlock(in_channels, out_channels):
    """Return ``Conv3x3 -> ReLU -> Conv3x3 -> ReLU`` as an ``nn.Sequential``.

    padding=1 keeps the spatial size unchanged. The submodule indices (0..3) are
    part of the state_dict keys (e.g. ``enc1.0.weight``), so the layer order is fixed.
    """
    layers = []
    for stage_in_channels in (in_channels, out_channels):
        layers.append(nn.Conv2d(stage_in_channels, out_channels, kernel_size=3, padding=1))
        layers.append(nn.ReLU(inplace=True))
    return nn.Sequential(*layers)

class UNet(nn.Module):
    """Two-level U-Net mapping an (N, in_channels, H, W) tensor to sigmoid probabilities.

    Layout:
    - Encoder: enc1 (64 channels), 2x2 max-pool, enc2 (128), 2x2 max-pool, then the bottleneck (256).
    - Decoder: two stages. Each is a stride-2 transposed conv, concatenated with the
      matching encoder output (skip connection) and refined by a ConvBlock.
    - Head: a 1x1 conv with ``out_channels`` outputs, followed by a sigmoid.

    The attribute names below become the state_dict keys, so they must stay as-is
    for the checkpoint to load. The original comment says this layout was
    reconstructed to match an existing checkpoint.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Contracting path
        self.enc1 = ConvBlock(in_channels, 64)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.enc2 = ConvBlock(64, 128)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Deepest level
        self.bottleneck = ConvBlock(128, 256)

        # Expanding path: the input channels of each dec* block = upsampled + skip channels
        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec2 = ConvBlock(128 + 128, 128)
        self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec1 = ConvBlock(64 + 64, 64)

        # Per-pixel output head
        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)

    @staticmethod
    def _match_spatial(tensor, reference):
        """Nearest-neighbour resize of ``tensor`` to ``reference``'s H x W.

        This keeps the concatenation valid when pooling floored an odd size.
        """
        return F.interpolate(tensor, size=reference.shape[2:], mode='nearest')

    def forward(self, x):
        skip_hi = self.enc1(x)
        skip_lo = self.enc2(self.pool1(skip_hi))
        features = self.bottleneck(self.pool2(skip_lo))

        features = self._match_spatial(self.upconv2(features), skip_lo)
        features = self.dec2(torch.cat((features, skip_lo), dim=1))

        features = self._match_spatial(self.upconv1(features), skip_hi)
        features = self.dec1(torch.cat((features, skip_hi), dim=1))

        return torch.sigmoid(self.final_conv(features))


def get_ela_feature_data(img_path, size=None):
    """Return the single-channel "ELA" input stacked as the 4th model channel.

    NOTE(review): no error-level analysis is computed. The result is an all-zero
    ``(size, size)`` float32 map for every image. The image is therefore not
    decoded here, because the pixels would be unused.
    - If the model was trained with this same zero channel, this is consistent.
    - If training used a real ELA map, inference is out of sync with training.
      In that case this function must reproduce that preprocessing exactly.
    TODO confirm against the training code.

    Args:
        img_path: path of the source image. Currently unused; it is kept so that
            callers, and any future real ELA implementation, share one signature.
        size: side length of the square output. Defaults to ``IMAGE_SIZE``.

    Returns:
        np.ndarray of shape (size, size), dtype float32, filled with zeros.
    """
    if size is None:
        size = IMAGE_SIZE
    return np.zeros((size, size), dtype=np.float32)

def rle_encode(mask):
    """Run-length encode a binary mask, scanning it in column-major (Fortran) order.

    Returns the string "authentic" when the mask sums to zero. Otherwise returns
    space-separated pairs: the 1-based start of each run, then its length
    (e.g. "2 3").
    """
    if mask.sum() == 0:
        return "authentic"
    # Zero padding on both ends guarantees every run has a rising and a falling edge.
    padded = np.concatenate([[0], mask.flatten(order='F'), [0]])
    edges = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    run_starts, run_ends = edges[0::2], edges[1::2]
    runs = np.empty_like(edges)
    runs[0::2] = run_starts
    runs[1::2] = run_ends - run_starts
    return ' '.join(map(str, runs))

def create_test_df_robust(test_image_root, sample_submission_path):
    """Pair each ``case_id`` from the sample submission with its image file on disk.

    ``test_image_root`` is walked recursively. A file is indexed when both hold:
    - its stem is all digits;
    - its extension (case-insensitive) is one of .png/.jpg/.jpeg/.tif/.tiff/.npy.
    When several files share a stem, the one visited last by ``os.walk`` wins.

    Returns:
        DataFrame with columns ``case_id`` (str) and ``img_path``. It has one row
        per sample-submission case that has a matching file, in sample order,
        with the index reset. Cases without a file are dropped. A missing root
        directory yields an empty frame.
    """
    accepted_extensions = ('.png', '.jpg', '.jpeg', '.tif', '.tiff', '.npy')

    sample_df = pd.read_csv(sample_submission_path)
    sample_df['case_id'] = sample_df['case_id'].astype(str)

    path_by_case = {}
    if os.path.exists(test_image_root):
        for dir_path, _sub_dirs, file_names in os.walk(test_image_root):
            for file_name in file_names:
                stem = os.path.splitext(file_name)[0]
                if not file_name.lower().endswith(accepted_extensions):
                    continue
                if not stem.isdigit():
                    continue
                path_by_case[stem] = os.path.join(dir_path, file_name)

    sample_df['img_path'] = sample_df['case_id'].map(path_by_case)
    has_file = sample_df['img_path'].notna()
    return sample_df.loc[has_file, ['case_id', 'img_path']].reset_index(drop=True)

def _prepare_model_input(img_rgb, img_path):
    """Build the (1, 4, IMAGE_SIZE, IMAGE_SIZE) float32 input: RGB scaled to [0, 1] plus the ELA channel."""
    rgb_resized = cv2.resize(img_rgb, (IMAGE_SIZE, IMAGE_SIZE)) / 255.0
    ela_feature = get_ela_feature_data(img_path)
    input_4ch = np.dstack([rgb_resized, ela_feature[..., np.newaxis]])
    # HWC -> CHW, then add the batch dimension.
    return torch.from_numpy(input_4ch).permute(2, 0, 1).float().unsqueeze(0)


def _drop_small_components(binary_mask, min_area):
    """Keep only the 4-connected foreground components with at least ``min_area`` pixels (uint8 0/1 result)."""
    _, labels, stats, _ = cv2.connectedComponentsWithStats(binary_mask, 4, cv2.CV_32S)
    # One lookup table over all labels replaces a full-image comparison per label.
    keep = stats[:, cv2.CC_STAT_AREA] >= min_area
    keep[0] = False  # label 0 is the background
    return keep[labels].astype(np.uint8)


def run_submission_inference(unet_model, test_df, fixed_threshold, min_forgery_area, device='cpu'):
    """Predict a forgery mask for each test image and RLE-encode it.

    Per image, the steps are:
    1. Resize to IMAGE_SIZE x IMAGE_SIZE and stack the RGB channels with the
       ELA channel to form a 4-channel input.
    2. Threshold the prediction at ``fixed_threshold``.
    3. Drop connected components smaller than ``min_forgery_area`` pixels,
       measured at IMAGE_SIZE resolution.
    4. Resize back to the original size with nearest-neighbour interpolation.
    5. Encode with ``rle_encode``.

    Args:
        unet_model: model mapping a (1, 4, IMAGE_SIZE, IMAGE_SIZE) tensor to
            per-pixel probabilities. It is moved to ``device`` and set to eval mode.
        test_df: DataFrame with ``case_id`` and ``img_path`` columns.
        fixed_threshold: probability above which a pixel counts as forged.
        min_forgery_area: minimum component size, in pixels, that is kept.
        device: torch device used for inference. Defaults to 'cpu', as before.

    Returns:
        DataFrame with columns ``case_id`` (str) and ``annotation`` (an RLE string
        or 'authentic'). Unreadable images are skipped with a warning, so callers
        must treat missing case_ids as authentic. The columns are present even
        when no image could be processed.
    """
    results = []
    unet_model.to(device).eval()

    with torch.no_grad():
        cases = zip(test_df['case_id'].astype(str), test_df['img_path'])
        for case_id, img_path in tqdm(cases, total=len(test_df), desc="Generating Submission"):
            img_bgr = cv2.imread(img_path)
            if img_bgr is None:
                warnings.warn(f"Could not read image for case {case_id} ({img_path}); skipping.")
                continue
            img_rgb_orig = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

            input_tensor = _prepare_model_input(img_rgb_orig, img_path).to(device)
            output_prob = unet_model(input_tensor).squeeze().cpu().numpy()

            binary_mask = (output_prob > fixed_threshold).astype(np.uint8)
            clean_mask_resized = _drop_small_components(binary_mask, min_forgery_area)

            height, width = img_rgb_orig.shape[:2]
            final_mask = cv2.resize(clean_mask_resized, (width, height), interpolation=cv2.INTER_NEAREST)
            results.append({'case_id': case_id, 'annotation': rle_encode(final_mask)})

    return pd.DataFrame(results, columns=['case_id', 'annotation'])

# --- 3. FINAL EXECUTION BLOCK FOR INFERENCE ---

def _load_model_weights(model, weights_path):
    """Load a checkpoint into ``model`` on CPU, failing loudly if it does not fit.

    ``strict=False`` alone would silently leave unmatched layers at their random
    initialisation and still produce a (useless) submission. Missing keys are
    therefore treated as fatal.

    Raises:
        FileNotFoundError: ``weights_path`` does not exist.
        RuntimeError: the checkpoint lacks some of the model's tensors.
    """
    if not os.path.exists(weights_path):
        raise FileNotFoundError(f"Final model weights not found at: {weights_path}.")

    # NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", UserWarning)
        state = torch.load(weights_path, map_location=torch.device('cpu'))

    # Some training scripts wrap the weights, e.g. {"state_dict": {...}}. Unwrap if so.
    if isinstance(state, dict):
        for wrapper_key in ('state_dict', 'model_state_dict'):
            if isinstance(state.get(wrapper_key), dict):
                state = state[wrapper_key]
                break

    load_result = model.load_state_dict(state, strict=False)
    if load_result.missing_keys:
        raise RuntimeError(
            f"Checkpoint is missing {len(load_result.missing_keys)} model tensors "
            f"(e.g. {load_result.missing_keys[:5]}); the model would run with random weights."
        )
    if load_result.unexpected_keys:
        print(f"Warning: ignored {len(load_result.unexpected_keys)} unexpected checkpoint keys "
              f"(e.g. {load_result.unexpected_keys[:5]}).")


def _build_submission_frame(results_df):
    """Return one row per sample-submission case_id, in the sample file's order.

    Cases without a prediction are labelled 'authentic'.
    """
    sample_ids = pd.read_csv(SAMPLE_SUBMISSION_FILE)[['case_id']].astype(str)
    if results_df.empty:
        annotation_by_case = {}
    else:
        annotation_by_case = dict(zip(results_df['case_id'].astype(str), results_df['annotation']))
    return sample_ids.assign(
        annotation=sample_ids['case_id'].map(annotation_by_case).fillna('authentic')
    ).reset_index(drop=True)


def _format_annotation(annotation):
    """Convert an annotation into the form written to the CSV.

    'authentic' passes through unchanged. Space-separated runs become
    bracketed, comma-separated text, e.g. "1 2 5 3" -> "[1, 2, 5, 3]".
    """
    if annotation.lower() == 'authentic':
        return annotation
    return "[" + ", ".join(annotation.split()) + "]"


def _write_submission_csv(submission_df, path):
    """Write the case_id/annotation CSV using csv.writer's minimal quoting.

    Bracketed RLE cells contain commas, so they get quoted.
    """
    with open(path, "w", newline='') as f:
        writer = csv.writer(f, quoting=csv.QUOTE_MINIMAL)
        writer.writerow(['case_id', 'annotation'])
        for case_id, annotation in zip(submission_df['case_id'], submission_df['annotation']):
            writer.writerow([case_id, _format_annotation(annotation)])


if __name__ == "__main__":

    print("\n--- Starting FINAL SUBMISSION INFERENCE ---")

    # 1. Load Model
    model = UNet(in_channels=MODEL_INPUT_CHANNELS, out_channels=1)
    try:
        _load_model_weights(model, FINAL_MODEL_PATH)
        print(f"✅ Loaded Final CTF Model weights: {FINAL_MODEL_PATH}")
    except Exception as e:
        print(f"🛑 FATAL Error loading final weights: {e}. Aborting submission.")
        sys.exit(1)

    # 2. Predict. Every sample case still gets a row (defaulting to 'authentic'),
    # so the file stays a valid submission even if no test image is found.
    print("\n--- Generating Kaggle Submission File ---")
    test_df = create_test_df_robust(TEST_IMAGE_ROOT, SAMPLE_SUBMISSION_FILE)

    if test_df.empty:
        print("No test images found; every case will be marked 'authentic'.")
        results_df = pd.DataFrame(columns=['case_id', 'annotation'])
    else:
        print(f"Processing {len(test_df)} test case(s)...")
        results_df = run_submission_inference(model, test_df, FIXED_THRESHOLD, MIN_FORGERY_AREA)

    submission_df = _build_submission_frame(results_df)

    # 3. Write Final CSV
    _write_submission_csv(submission_df, OUTPUT_FILENAME)

    print(f"\n✅ FINAL SUBMISSION CREATED: {OUTPUT_FILENAME} with {len(submission_df)} total rows. Please submit this file.")

In [None]:
def _parse_rle(annotation):
    """Return the integers in an RLE annotation, or None if it is not a list of integers.

    Accepts the space-separated form from rle_encode ("1 2 5 3") and the
    bracketed, comma-separated form written to the CSV ("[1, 2, 5, 3]").
    """
    if not isinstance(annotation, str):
        return None  # e.g. NaN from an empty CSV cell
    body = annotation.strip()
    if body.startswith('[') and body.endswith(']'):
        body = body[1:-1]
    try:
        return [int(token) for token in body.replace(',', ' ').split()]
    except ValueError:
        return None


def _is_valid_rle(annotation):
    """Check that an annotation is a well-formed list of (start, length) pairs.

    Valid means: non-empty, an even number of values, every start >= 1 and
    length >= 1, and runs in increasing, non-overlapping order.
    """
    values = _parse_rle(annotation)
    if not values or len(values) % 2:
        return False
    starts, lengths = values[0::2], values[1::2]
    if min(starts) < 1 or min(lengths) < 1:
        return False
    return all(prev_start + prev_len <= next_start
               for prev_start, prev_len, next_start in zip(starts, lengths, starts[1:]))


def validate_and_print_rle(submission_df):
    """Print counts of authentic/forged rows and check every RLE annotation's structure.

    An annotation that is not exactly 'authentic' must parse as valid
    (start, length) pairs; NaN or non-numeric values count as malformed.
    Both the space-separated and the bracketed, comma-separated formats are
    accepted.

    Returns:
        True if every non-authentic annotation is well-formed, else False.
    """
    print("\n--- RLE Output Validation Check ---")

    annotations = submission_df['annotation']
    is_authentic = annotations.eq('authentic')
    rle_annotations = annotations[~is_authentic]

    print(f"Total Submissions: {len(submission_df)}")
    print(f"Authentic (No Forgery) Count: {int(is_authentic.sum())}")
    print(f"RLE Annotated (Forged) Count: {len(rle_annotations)}")

    bad_rle_count = sum(not _is_valid_rle(a) for a in rle_annotations)

    if bad_rle_count == 0:
        print(f"✅ RLE Structure: All {len(rle_annotations)} RLE strings are well-formed "
              f"(even number of positive integers, ordered non-overlapping runs).")
    else:
        print(f"🛑 RLE ERROR: Found {bad_rle_count} malformed RLE strings "
              f"(odd element count, non-integer values, or invalid/overlapping runs).")
    return bad_rle_count == 0

In [None]:
# Re-read the file exactly as it was written to disk and validate its annotations.
# Reading case_id as str matches how the submission was built; pandas would
# otherwise parse the numeric ids as integers.
submission_df = pd.read_csv(OUTPUT_FILENAME, dtype={'case_id': str})
validate_and_print_rle(submission_df);

In [None]:
# Preview only the header and the first rows. Forged cases carry very long RLE
# strings, so printing the whole file would flood the notebook output.
from itertools import islice

PREVIEW_ROWS = 10
with open(OUTPUT_FILENAME) as submission_file:
    for line in islice(submission_file, PREVIEW_ROWS + 1):  # +1 for the header line
        print(line.rstrip("\n"))