In [None]:
import numpy as np
import cv2
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import pandas as pd
from sklearn.model_selection import train_test_split 
import time
from torch.optim.lr_scheduler import ReduceLROnPlateau 
import warnings

# Suppress the specific UserWarning from the LR scheduler
warnings.filterwarnings("ignore", category=UserWarning, module="torch.optim.lr_scheduler")

# NOTE: the blanket filter below silences every warning category for the rest
# of the session, which makes the targeted scheduler filter above redundant.
from warnings import filterwarnings
filterwarnings('ignore')

# --- CONFIGURATION --
# Side length (pixels) every image, ELA map and mask is resized to.
TARGET_SIZE = 256
# Run on the GPU when one is visible to PyTorch, otherwise on the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# Mini-batch size passed to the DataLoaders.
BATCH_SIZE = 8
# Number of training epochs. train_model() checkpoints whenever the validation
# loss improves, so the saved weights are those of the best epoch seen so far.
EPOCHS = 2
# Learning rate handed to the Adam optimizer in train_model().
LEARNING_RATE = 1e-4

# --- PATHS ---
# Directory that is walked recursively for training images.
TRAIN_ROOT = "/kaggle/input/recodai-luc-scientific-image-forgery-detection/train_images"
# Directory holding one "<case_id>.npy" mask per forged image.
MASK_ROOT = "/kaggle/input/recodai-luc-scientific-image-forgery-detection/train_masks"

# Where train_model() writes the best state_dict.
MODEL_SAVE_PATH = "/tmp/model_new_scratch.pth"

# --- UTILITY FUNCTIONS ---
def compute_ela(img_path, quality=95, scale=10):
    """Compute a single-channel Error Level Analysis (ELA) map for an image.

    The image is resized to ``TARGET_SIZE`` x ``TARGET_SIZE``, re-compressed as
    JPEG at ``quality``, and the absolute per-pixel difference between the two
    versions (averaged over the colour channels) is multiplied by ``scale``.

    The JPEG round-trip is done in memory with ``cv2.imencode``/``cv2.imdecode``
    so no temporary file is written (no disk I/O and no risk of two workers
    picking the same temp-file name).

    Args:
        img_path: Path to an image readable by ``cv2.imread``, or to a ``.npy``
            array with 2 (grayscale) or 3 (RGB) dimensions.
        quality: JPEG quality used for the re-compression.
        scale: Multiplier applied to the mean absolute error.

    Returns:
        ``float32`` array of shape ``(TARGET_SIZE, TARGET_SIZE)``. An all-zero
        map is returned when the image cannot be loaded or re-compressed.
    """
    blank = np.zeros((TARGET_SIZE, TARGET_SIZE), dtype=np.float32)

    img = cv2.imread(img_path)
    if img is None or img.size == 0:
        # cv2 could not decode the file: try it as a raw NumPy array instead.
        try:
            img_data = np.load(img_path)
            if img_data.ndim == 3:
                img = cv2.cvtColor(img_data, cv2.COLOR_RGB2BGR)
            elif img_data.ndim == 2:
                img = cv2.cvtColor(img_data, cv2.COLOR_GRAY2BGR)
        except Exception:
            return blank

    # Still nothing usable (e.g. an array with an unsupported rank).
    if img is None or img.size == 0:
        return blank

    img_resized = cv2.resize(img, (TARGET_SIZE, TARGET_SIZE))

    # In-memory JPEG round-trip; equivalent to imwrite + imread on a temp file.
    ok, encoded = cv2.imencode(".jpg", img_resized, [cv2.IMWRITE_JPEG_QUALITY, quality])
    if not ok:
        return blank
    compressed_img = cv2.imdecode(encoded, cv2.IMREAD_COLOR)
    if compressed_img is None:
        return blank

    error = np.abs(img_resized.astype(np.float32) - compressed_img.astype(np.float32))
    # The map is already TARGET_SIZE x TARGET_SIZE, so no further resize is needed.
    return (np.mean(error, axis=2) * scale).astype(np.float32)

class DiceLoss(nn.Module):
    """Soft Dice loss computed over the whole flattened batch.

    ``pred`` is used as given (no sigmoid is applied here) and is compared
    against ``target`` element-wise; ``smooth`` is added to numerator and
    denominator so that two empty tensors yield a loss of 0 instead of 0/0.
    """

    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, pred, target):
        """Return ``1 - dice`` as a scalar tensor."""
        flat_pred = pred.reshape(-1)
        flat_target = target.reshape(-1)
        overlap = (flat_pred * flat_target).sum()
        numerator = 2. * overlap + self.smooth
        denominator = flat_pred.sum() + flat_target.sum() + self.smooth
        return 1 - numerator / denominator

# Model architecture includes Dropout to match inference cell
class UNet(nn.Module):
    """Two-level U-Net that maps an ``in_channels`` image to per-pixel probabilities.

    Encoder: two conv blocks (64, 128 channels) each followed by 2x2 max
    pooling, then a 256-channel bottleneck. Decoder: two transposed-conv
    upsampling steps, each concatenated with the matching encoder output.
    The output goes through a sigmoid, so values lie in [0, 1].
    Input height and width need to be divisible by 4 for the skip
    concatenations to line up.
    """

    @staticmethod
    def _conv_block(in_c, out_c):
        """Conv3x3 -> ReLU -> Dropout(0.2) -> Conv3x3 -> ReLU."""
        return nn.Sequential(
            nn.Conv2d(in_c, out_c, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Dropout(p=0.2),
            nn.Conv2d(out_c, out_c, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
        )

    def __init__(self, in_channels=4, num_classes=1):
        super().__init__()
        make_block = self._conv_block

        # Attribute names and creation order are kept so state_dict keys match.
        self.enc1 = make_block(in_channels, 64)
        self.enc2 = make_block(64, 128)
        self.bottleneck = make_block(128, 256)
        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec2 = make_block(128 + 128, 128)
        self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec1 = make_block(64 + 64, 64)
        self.final_conv = nn.Conv2d(64, num_classes, kernel_size=1)

    def forward(self, x):
        # Contracting path
        skip1 = self.enc1(x)
        skip2 = self.enc2(F.max_pool2d(skip1, 2))
        deepest = self.bottleneck(F.max_pool2d(skip2, 2))

        # Expanding path with skip connections
        up2 = self.dec2(torch.cat((self.upconv2(deepest), skip2), dim=1))
        up1 = self.dec1(torch.cat((self.upconv1(up2), skip1), dim=1))

        return torch.sigmoid(self.final_conv(up1))

class ForgeryDataset(Dataset):
    """Dataset yielding ``(image, mask)`` pairs for forgery segmentation.

    ``df`` must provide ``img_path`` and ``mask_path`` columns. Each item is

    * ``image``: float32 tensor ``(4, TARGET_SIZE, TARGET_SIZE)`` holding the
      RGB channels plus one ELA channel, all divided by 255;
    * ``mask``: float32 tensor ``(1, TARGET_SIZE, TARGET_SIZE)`` with values
      in {0, 1}.
    """

    def __init__(self, df):
        self.df = df

    def __len__(self):
        return len(self.df)

    @staticmethod
    def _load_rgb(img_path):
        """Load ``img_path`` as an RGB array, falling back to ``np.load``.

        ``cv2.imread`` returns ``None`` for files it cannot decode (such as
        ``.npy``), so the result is checked *before* the colour conversion;
        converting first would raise and make the fallback unreachable.

        Raises:
            RuntimeError: if neither loader yields a 2-D or 3-D array.
        """
        bgr = cv2.imread(img_path)
        if bgr is not None and bgr.size > 0:
            return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

        try:
            img_data = np.load(img_path)
            if img_data.ndim == 3:
                return img_data
            if img_data.ndim == 2:
                return cv2.cvtColor(img_data, cv2.COLOR_GRAY2RGB)
        except Exception as e:
            raise RuntimeError(f"Failed to load image from {img_path}: {e}") from e

        raise RuntimeError(
            f"Failed to load image from {img_path}: unsupported array shape {img_data.shape}"
        )

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_path = row['img_path']
        mask_path = row['mask_path']

        # --- Load Image ---
        rgb_image = self._load_rgb(img_path)

        # --- Load Mask ---
        try:
            mask = np.load(mask_path)
            if mask.ndim > 2:
                # NOTE(review): assumes a channel-last (H, W, C) mask and keeps
                # the first channel -- confirm against the mask files.
                mask = mask[:, :, 0]
        except Exception as e:
            raise RuntimeError(f"Failed to load mask from {mask_path}: {e}") from e

        ela_feature_2d = compute_ela(img_path)

        # Resize all features
        rgb_image_resized = cv2.resize(rgb_image, (TARGET_SIZE, TARGET_SIZE))
        ela_feature_resized = cv2.resize(ela_feature_2d, (TARGET_SIZE, TARGET_SIZE))

        # Binarise before resizing so that both 0/1 and 0/255 masks end up as
        # {0, 1} targets. Dividing a 0/1 mask by 255 would give targets of
        # ~0.004 and the Dice loss would have nothing to learn from.
        mask_binary = (mask > 0).astype(np.uint8)
        # Use INTER_NEAREST so resizing cannot introduce intermediate values
        mask_resized = cv2.resize(
            mask_binary, (TARGET_SIZE, TARGET_SIZE), interpolation=cv2.INTER_NEAREST
        )

        # Stack RGB (3) and ELA (1) for a 4-channel input
        ela_feature_3d = np.expand_dims(ela_feature_resized, axis=-1)
        stacked_input = np.concatenate([rgb_image_resized, ela_feature_3d], axis=-1)

        # Convert to PyTorch tensors (channels first); the image is scaled by 1/255
        image = torch.tensor(stacked_input.transpose(2, 0, 1) / 255.0, dtype=torch.float32)
        mask = torch.tensor(mask_resized, dtype=torch.float32).unsqueeze(0)

        return image, mask

def train_model(model, train_loader, val_loader, epochs=EPOCHS, save_path=MODEL_SAVE_PATH):
    """Train ``model`` with Dice loss and checkpoint the best validation epoch.

    Each epoch runs one optimisation pass over ``train_loader`` and one
    gradient-free pass over ``val_loader``. The mean validation loss drives a
    ``ReduceLROnPlateau`` scheduler and, whenever it improves on the best
    value seen so far, the model's ``state_dict`` is written to ``save_path``.

    Args:
        model: Network whose output is compared to the targets by ``DiceLoss``.
        train_loader: Iterable of ``(inputs, targets)`` training batches.
        val_loader: Iterable of ``(inputs, targets)`` validation batches.
        epochs: Number of epochs to run.
        save_path: Destination file for the best ``state_dict``.
    """
    criterion = DiceLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # Halve the learning rate once the validation loss stops improving.
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=1)

    best_val_loss = float('inf')

    model.to(DEVICE)
    print(f"Starting training on {DEVICE} for {epochs} epochs...")

    for epoch_idx in range(epochs):
        epoch_no = epoch_idx + 1

        # ---- training pass ----
        model.train()
        train_losses = []
        for inputs, targets in tqdm(train_loader, desc=f"Epoch {epoch_no} Training"):
            inputs = inputs.to(DEVICE)
            targets = targets.to(DEVICE)

            optimizer.zero_grad()
            batch_loss = criterion(model(inputs), targets)
            batch_loss.backward()
            optimizer.step()

            train_losses.append(batch_loss.item())
        avg_train_loss = sum(train_losses) / len(train_loader)

        # ---- validation pass (no gradients, dropout disabled) ----
        model.eval()
        val_losses = []
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs = inputs.to(DEVICE)
                targets = targets.to(DEVICE)
                val_losses.append(criterion(model(inputs), targets).item())
        avg_val_loss = sum(val_losses) / len(val_loader)

        # The scheduler watches the validation loss.
        scheduler.step(avg_val_loss)

        print(f"Epoch {epoch_no} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")

        # Keep only the weights of the best epoch so far.
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), save_path)
            print(f"Model saved successfully to {save_path}. (New best Val Loss: {best_val_loss:.4f})")

# --- MAIN EXECUTION BLOCK ---

if __name__ == '__main__':

    print("Preparing training data paths...")

    # File types accepted as training images.
    valid_extensions = ('.png', '.jpg', '.jpeg', '.tif', '.tiff', '.npy')

    # Walk TRAIN_ROOT recursively. Only directories whose path contains
    # 'forged' are used, because only forged images have masks.
    data_list = []
    for root, _, files in os.walk(TRAIN_ROOT):
        if 'forged' not in root.lower():
            continue
        for f in files:
            if not f.lower().endswith(valid_extensions):
                continue
            case_id = os.path.splitext(f)[0]
            data_list.append({
                'case_id': case_id,
                'img_path': os.path.join(root, f),
                # Masks are stored as "<case_id>.npy" under MASK_ROOT
                'mask_path': os.path.join(MASK_ROOT, f"{case_id}.npy"),
            })

    full_df = pd.DataFrame(data_list, columns=['case_id', 'img_path', 'mask_path'])

    if not full_df.empty:
        # Keep only images whose mask file actually exists on disk
        has_mask = full_df['mask_path'].apply(os.path.exists)
        full_df = full_df[has_mask].reset_index(drop=True)

    if full_df.empty:
        print("ðŸ›‘ FATAL ERROR: No valid image/mask pairs found in the input paths. Cannot train. (Check file extensions/paths again)")
    else:
        print(f"âœ… Found {len(full_df)} valid forged samples for training.")

        # 90/10 train/validation split with a fixed seed
        train_df, val_df = train_test_split(full_df, test_size=0.1, random_state=42)

        # Datasets and loaders (only the training loader is shuffled)
        train_dataset = ForgeryDataset(train_df)
        val_dataset = ForgeryDataset(val_df)
        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

        # 4 input channels: 3 RGB + 1 ELA
        model = UNet(in_channels=4)

        # START TRAINING
        train_model(model, train_loader, val_loader)

        print("\nâœ… TRAINING COMPLETE. The trained model is saved and ready for inference.")


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Double convolution block used in UNet
class DoubleConv(nn.Module):
    """Two consecutive (Conv3x3 -> BatchNorm -> ReLU) stages.

    The first stage maps ``in_ch`` to ``out_ch`` channels, the second keeps
    ``out_ch``. Padding of 1 preserves the spatial size.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        # Stage 1 consumes in_ch channels, stage 2 consumes out_ch channels.
        for stage_in in (in_ch, out_ch):
            stages.extend([
                nn.Conv2d(stage_in, out_ch, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ])
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv(x)

# UNet Model
class UNet(nn.Module):
    """Three-level U-Net built from ``DoubleConv`` blocks.

    The encoder widens 64 -> 128 -> 256 -> 512 channels with 2x2 max pooling
    between levels; the decoder upsamples bilinearly and concatenates the
    matching encoder output before each ``DoubleConv``. The final 1x1
    convolution returns raw logits (no sigmoid is applied).
    """

    def __init__(self, in_ch=3, out_ch=1):
        super().__init__()
        # Encoder
        self.dconv_down1 = DoubleConv(in_ch, 64)
        self.dconv_down2 = DoubleConv(64, 128)
        self.dconv_down3 = DoubleConv(128, 256)
        self.dconv_down4 = DoubleConv(256, 512)
        # Parameter-free resampling layers shared by all levels
        self.maxpool = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        # Decoder: input channels = skip channels + upsampled channels
        self.dconv_up3 = DoubleConv(256 + 512, 256)
        self.dconv_up2 = DoubleConv(128 + 256, 128)
        self.dconv_up1 = DoubleConv(128 + 64, 64)
        self.conv_last = nn.Conv2d(64, out_ch, 1)

    def forward(self, x):
        # Contracting path, keeping every level's output for the skips
        skip1 = self.dconv_down1(x)
        skip2 = self.dconv_down2(self.maxpool(skip1))
        skip3 = self.dconv_down3(self.maxpool(skip2))
        x = self.dconv_down4(self.maxpool(skip3))

        # Expanding path: upsample, concatenate the skip, refine
        decoder_levels = (
            (skip3, self.dconv_up3),
            (skip2, self.dconv_up2),
            (skip1, self.dconv_up1),
        )
        for skip, refine in decoder_levels:
            x = refine(torch.cat([self.upsample(x), skip], dim=1))

        return self.conv_last(x)


In [None]:
import torch

# Assuming UNet is already defined (from previous code)
# Smoke test: push one random batch through a 3-channel UNet and print the
# shape of what comes out.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Create UNet instance
model = UNet(in_ch=3, out_ch=1).to(DEVICE)

# Create a dummy input image batch: batch_size=1, 3 channels, 256x256 image
dummy_input = torch.randn(1, 3, 256, 256).to(DEVICE)

# Forward pass
# NOTE: runs with gradients enabled and the model in its default (training)
# mode; only the output shape is inspected.
output = model(dummy_input)

# Print output shape
print("Output shape:", output.shape)


In [None]:
from torch.utils.data import Dataset
import numpy as np
from PIL import Image
import torch

class ForgeryDataset(Dataset):
    """Dataset returning ``(image, mask)`` tensors at their native resolution.

    ``df`` must provide ``img_path`` and ``mask_path`` columns. The image is
    read with PIL as RGB and scaled to [0, 1]; the mask is read from a ``.npy``
    file and cast to float. No resizing is done here.
    NOTE(review): ``transform`` is called separately on the image and on the
    mask, so a random transform would not be applied identically to both.
    """

    def __init__(self, df, transform=None):
        self.df = df
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        record = self.df.iloc[idx]

        # Image as H x W x 3 in [0, 1]; mask as float32 (assumed stored as .npy)
        rgb = np.array(Image.open(record['img_path']).convert("RGB")) / 255.0
        mask_arr = np.load(record['mask_path']).astype(np.float32)

        # Channels-first tensors; the mask gets a leading channel axis
        image = torch.tensor(rgb.transpose(2, 0, 1), dtype=torch.float)
        mask = torch.tensor(mask_arr[np.newaxis, :, :], dtype=torch.float)

        if self.transform:
            image, mask = self.transform(image), self.transform(mask)

        return image, mask


In [None]:
from torch.utils.data import DataLoader

BATCH_SIZE = 4  # you can adjust depending on GPU
# NOTE: train_df / val_df are not defined in this cell; they come from the
# training cell's main block and only exist if that cell found data.
train_dataset = ForgeryDataset(train_df)
val_dataset = ForgeryDataset(val_df)

# Shuffle only the training data; validation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)


In [None]:
import torch.optim as optim

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# 3-channel (RGB-only) UNet; `nn` is expected to be imported by an earlier cell.
model = UNet(in_ch=3, out_ch=1).to(DEVICE)
# BCEWithLogitsLoss applies the sigmoid itself, so the model must output raw logits.
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)


In [None]:
def compute_multi_ela(img_path, scales=(5, 10, 15)):
    """Build an RGB + multi-scale ELA feature stack for one image.

    The image is resized to ``TARGET_SIZE`` x ``TARGET_SIZE`` and re-compressed
    once as JPEG (quality 95, in memory). The channel-averaged absolute
    difference is then multiplied by every entry of ``scales`` to form one ELA
    channel per scale.

    Args:
        img_path: Path to an image readable by ``cv2.imread``, or to a ``.npy``
            array with 2 (grayscale) or 3 (RGB) dimensions.
        scales: Multipliers applied to the ELA error, one output channel each.

    Returns:
        ``float32`` array of shape ``(TARGET_SIZE, TARGET_SIZE, 3 + len(scales))``:
        the three colour channels in OpenCV's BGR order followed by the ELA
        channels, every channel divided by 255. An all-zero array is returned
        when the image cannot be loaded; if only the re-compression fails the
        ELA channels are zero.
    """
    scales = tuple(scales)
    blank = np.zeros((TARGET_SIZE, TARGET_SIZE, len(scales) + 3), dtype=np.float32)

    img = cv2.imread(img_path)
    if img is None or img.size == 0:
        # cv2 could not decode the file: try it as a raw NumPy array instead.
        try:
            img_data = np.load(img_path)
            if img_data.ndim == 3:
                img = cv2.cvtColor(img_data, cv2.COLOR_RGB2BGR)
            elif img_data.ndim == 2:
                img = cv2.cvtColor(img_data, cv2.COLOR_GRAY2BGR)
        except Exception:
            return blank

    # Arrays with an unsupported rank leave `img` unset; bail out before resize.
    if img is None or img.size == 0:
        return blank

    img_resized = cv2.resize(img, (TARGET_SIZE, TARGET_SIZE))

    # The JPEG round-trip does not depend on the scale, so do it once, in
    # memory (no temp files to name, write, or clean up).
    ok, encoded = cv2.imencode(".jpg", img_resized, [cv2.IMWRITE_JPEG_QUALITY, 95])
    compressed_img = cv2.imdecode(encoded, cv2.IMREAD_COLOR) if ok else None
    if compressed_img is None:
        base_error = np.zeros((TARGET_SIZE, TARGET_SIZE), dtype=np.float32)
    else:
        diff = np.abs(img_resized.astype(np.float32) - compressed_img.astype(np.float32))
        base_error = np.mean(diff, axis=2)

    ela_channels = [base_error * scale for scale in scales]

    # Combine colour + ELA channels
    combined = np.concatenate(
        [img_resized.astype(np.float32) / 255.] + [c[..., None] / 255. for c in ela_channels],
        axis=2,
    )
    return combined.astype(np.float32)


In [None]:
def post_process_mask(mask, threshold=0.45, min_area=64):
    """Threshold a probability map and drop small connected components.

    Args:
        mask: 2-D array of scores; pixels strictly above ``threshold`` are foreground.
        threshold: Cut-off used for binarisation.
        min_area: Components (8-connectivity) with fewer pixels are removed.

    Returns:
        ``uint8`` array with the same shape as ``mask`` holding 0/1 values.
    """
    binary = (mask > threshold).astype(np.uint8)
    _, labels, stats, _ = cv2.connectedComponentsWithStats(binary, connectivity=8)

    # One keep/drop flag per label; label 0 is the background and never kept.
    keep = stats[:, cv2.CC_STAT_AREA] >= min_area
    keep[0] = False

    # Look every pixel's label up in the flag table.
    return keep[labels].astype(np.uint8)


In [None]:
import cv2
import matplotlib.pyplot as plt
import os

# --- Configuration ---
# Most probable location of the test image in the Kaggle environment
IMAGE_PATH = "/kaggle/input/recodai-luc-scientific-image-forgery-detection/test_images/45.png" 

# --- Image Display Script ---

try:
    # 1. Resolve the file, falling back to other extensions used in the dataset
    if not os.path.exists(IMAGE_PATH):
        print(f"ðŸ›‘ ERROR: Image file '{IMAGE_PATH}' not found. Trying common alternatives...")

        base_dir = "/kaggle/input/recodai-luc-scientific-image-forgery-detection/test_images"
        for candidate in ('45.npy', '45.jpg'):
            candidate_path = os.path.join(base_dir, candidate)
            if os.path.exists(candidate_path):
                IMAGE_PATH = candidate_path
                break
        else:
            # None of the alternatives exists either
            raise FileNotFoundError(f"Image 45 not found at {base_dir} with .png, .npy, or .jpg extension.")

    # 2. Load the image; cv2.imread yields None for formats it cannot decode
    img_bgr = cv2.imread(IMAGE_PATH)

    if img_bgr is None or img_bgr.size == 0:
        # Probably a raw .npy array: 3-D is treated as RGB, 2-D as grayscale
        try:
            img_data = np.load(IMAGE_PATH)
            conversion_by_rank = {3: cv2.COLOR_RGB2BGR, 2: cv2.COLOR_GRAY2BGR}
            conversion = conversion_by_rank.get(img_data.ndim)
            if conversion is not None:
                img_bgr = cv2.cvtColor(img_data, conversion)
        except Exception:
            raise RuntimeError(f"Failed to load image data from '{IMAGE_PATH}' using both cv2.imread and np.load.")

    # 3. matplotlib expects RGB, OpenCV works in BGR
    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

    # 4. Show the image without axis ticks or labels
    plt.figure(figsize=(10, 10))
    plt.imshow(img_rgb)
    plt.title(f"Image for Case ID: 45 ({os.path.basename(IMAGE_PATH)})")
    plt.axis('off')
    plt.show()

except Exception as e:
    print(f"An error occurred during image display: {e}")

In [None]:
class UNetEnhanced(nn.Module):
    """Single-level U-Net for ``in_channels`` inputs (default 6: RGB + 3 ELA).

    The encoder produces a 64-channel full-resolution map (``e1``) and a
    128-channel half-resolution map (``e2``). ``e2`` is upsampled back and
    concatenated with ``e1`` (128 + 64 channels) before the decoder
    convolution. The output passes through a sigmoid, so values lie in [0, 1].

    The skip connection is a concatenation because ``e1`` and the upsampled
    ``e2`` have different channel counts (64 vs 128) and therefore cannot be
    added element-wise.
    """

    def __init__(self, in_channels=6, out_channels=1):
        super().__init__()
        self.enc1 = nn.Sequential(nn.Conv2d(in_channels, 64, 3, padding=1), nn.ReLU(), nn.BatchNorm2d(64))
        self.enc2 = nn.Sequential(nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(), nn.BatchNorm2d(128))
        self.pool = nn.MaxPool2d(2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        # Decoder input = 128 upsampled channels + 64 skip channels
        self.dec1 = nn.Sequential(nn.Conv2d(128 + 64, 64, 3, padding=1), nn.ReLU(), nn.BatchNorm2d(64))
        self.final = nn.Conv2d(64, out_channels, 1)

    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        d1 = self.up(e2)
        # Pooling floors odd sizes, so the upsampled map can be one pixel
        # short; align it with the skip tensor before concatenating.
        if d1.shape[-2:] != e1.shape[-2:]:
            d1 = F.interpolate(d1, size=e1.shape[-2:], mode='bilinear', align_corners=True)
        d1 = self.dec1(torch.cat([d1, e1], dim=1))  # skip connection
        return torch.sigmoid(self.final(d1))


In [None]:
import numpy as np
import cv2
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
from tqdm import tqdm
import time
import csv

# --- CONFIGURATION & PATHS ---
# Side length (pixels) of the square network input.
TARGET_SIZE = 256
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# Number of images sent through the model per forward pass.
BATCH_SIZE = 8
# ENHANCEMENT: Adjusted threshold to favor True Positives (from 0.5)
FIXED_THRESHOLD = 0.45
# ENHANCEMENT: Minimum area filter for noise reduction (Tune this value on public LB)
# Measured in pixels of the TARGET_SIZE x TARGET_SIZE prediction, i.e. before
# the mask is resized back to the original resolution.
MIN_FORGERY_AREA = 64

TEST_IMAGE_ROOT = "/kaggle/input/recodai-luc-scientific-image-forgery-detection/test_images"
SAMPLE_SUBMISSION_FILE = "/kaggle/input/recodai-luc-scientific-image-forgery-detection/sample_submission.csv"

# Model path should point to the successfully trained model
# (same location as MODEL_SAVE_PATH in the training cell).
model_path = "/tmp/model_new_scratch.pth"
OUTPUT_FILENAME = "submission.csv"

# --- UTILITY FUNCTIONS (Must be defined here for the block to run) ---
def compute_ela(img_path, quality=95, scale=10):
    """Compute a single-channel Error Level Analysis (ELA) map for an image.

    The image is resized to ``TARGET_SIZE`` x ``TARGET_SIZE``, re-compressed as
    JPEG at ``quality``, and the absolute per-pixel difference between the two
    versions (averaged over the colour channels) is multiplied by ``scale``.

    The JPEG round-trip is done in memory with ``cv2.imencode``/``cv2.imdecode``
    so no temporary file is written (no disk I/O and no risk of two workers
    picking the same temp-file name).

    Args:
        img_path: Path to an image readable by ``cv2.imread``, or to a ``.npy``
            array with 2 (grayscale) or 3 (RGB) dimensions.
        quality: JPEG quality used for the re-compression.
        scale: Multiplier applied to the mean absolute error.

    Returns:
        ``float32`` array of shape ``(TARGET_SIZE, TARGET_SIZE)``. An all-zero
        map is returned when the image cannot be loaded or re-compressed.
    """
    blank = np.zeros((TARGET_SIZE, TARGET_SIZE), dtype=np.float32)

    img = cv2.imread(img_path)
    if img is None or img.size == 0:
        # cv2 could not decode the file: try it as a raw NumPy array instead.
        try:
            img_data = np.load(img_path)
            if img_data.ndim == 3:
                img = cv2.cvtColor(img_data, cv2.COLOR_RGB2BGR)
            elif img_data.ndim == 2:
                img = cv2.cvtColor(img_data, cv2.COLOR_GRAY2BGR)
        except Exception:
            return blank

    # Still nothing usable (e.g. an array with an unsupported rank).
    if img is None or img.size == 0:
        return blank

    img_resized = cv2.resize(img, (TARGET_SIZE, TARGET_SIZE))

    # In-memory JPEG round-trip; equivalent to imwrite + imread on a temp file.
    ok, encoded = cv2.imencode(".jpg", img_resized, [cv2.IMWRITE_JPEG_QUALITY, quality])
    if not ok:
        return blank
    compressed_img = cv2.imdecode(encoded, cv2.IMREAD_COLOR)
    if compressed_img is None:
        return blank

    error = np.abs(img_resized.astype(np.float32) - compressed_img.astype(np.float32))
    # The map is already TARGET_SIZE x TARGET_SIZE, so no further resize is needed.
    return (np.mean(error, axis=2) * scale).astype(np.float32)

def create_test_df_robust(test_image_root, sample_submission_path):
    """Pair every case in the sample submission with its image file.

    Args:
        test_image_root: Directory searched recursively for image files.
        sample_submission_path: CSV file with a ``case_id`` column.

    Returns:
        DataFrame with columns ``case_id`` (as ``str``) and ``img_path``; cases
        without a matching file get the placeholder ``'MISSING_FILE'``. Files
        are matched on their name without extension; if several files share a
        stem, the one visited last by ``os.walk`` wins.
    """
    image_suffixes = ('.png', '.jpg', '.jpeg', '.tif', '.tiff', '.npy')

    submission = pd.read_csv(sample_submission_path)
    submission['case_id'] = submission['case_id'].astype(str)

    # Map "file name without extension" -> full path for every image found
    path_by_case = {}
    if os.path.exists(test_image_root):
        for folder, _, names in os.walk(test_image_root):
            path_by_case.update(
                (os.path.splitext(name)[0], os.path.join(folder, name))
                for name in names
                if name.lower().endswith(image_suffixes)
            )

    submission['img_path'] = submission['case_id'].map(path_by_case).fillna('MISSING_FILE')
    return submission[['case_id', 'img_path']]

def rle_encode(mask):
    """Run-length encode a binary mask in column-major order.

    Returns the string ``"authentic"`` when the mask sums to zero; otherwise a
    ``", "``-separated list of ``start, length`` pairs, where ``start`` is the
    1-based index into the column-major flattened mask.
    """
    if mask.sum() == 0:
        return "authentic"

    # Transposing before flattening walks the mask column by column.
    column_major = mask.T.flatten()
    # Zero sentinels on both ends make every run start and end on a transition.
    padded = np.concatenate([[0], column_major, [0]])

    # 1-based positions at which the value changes: start, end, start, end, ...
    boundaries = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    # Turn every "end" into a run length.
    boundaries[1::2] = boundaries[1::2] - boundaries[::2]

    return ', '.join(map(str, boundaries))

# --- MODEL ARCHITECTURE (FIXED: Includes Dropout to match trained model) ---
class UNet(nn.Module):
    """Two-level U-Net that maps an ``in_channels`` image to per-pixel probabilities.

    The layer layout (including the ``Dropout(p=0.2)`` inside every conv
    block) mirrors the training cell's model so that its saved ``state_dict``
    loads here. The output goes through a sigmoid, so values lie in [0, 1].
    Input height and width need to be divisible by 4 for the skip
    concatenations to line up.
    """

    @staticmethod
    def _conv_block(in_c, out_c):
        """Conv3x3 -> ReLU -> Dropout(0.2) -> Conv3x3 -> ReLU."""
        return nn.Sequential(
            nn.Conv2d(in_c, out_c, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Dropout(p=0.2),
            nn.Conv2d(out_c, out_c, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
        )

    def __init__(self, in_channels=4, num_classes=1):
        super().__init__()
        make_block = self._conv_block

        # Attribute names and creation order are kept so state_dict keys match.
        self.enc1 = make_block(in_channels, 64)
        self.enc2 = make_block(64, 128)
        self.bottleneck = make_block(128, 256)
        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec2 = make_block(128 + 128, 128)
        self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec1 = make_block(64 + 64, 64)
        self.final_conv = nn.Conv2d(64, num_classes, kernel_size=1)

    def forward(self, x):
        # Contracting path
        skip1 = self.enc1(x)
        skip2 = self.enc2(F.max_pool2d(skip1, 2))
        deepest = self.bottleneck(F.max_pool2d(skip2, 2))

        # Expanding path with skip connections
        up2 = self.dec2(torch.cat((self.upconv2(deepest), skip2), dim=1))
        up1 = self.dec1(torch.cat((self.upconv1(up2), skip1), dim=1))

        return torch.sigmoid(self.final_conv(up1))

# --- INFERENCE FUNCTION WITH POST-PROCESSING ---
def _load_rgb_for_inference(img_path):
    """Load ``img_path`` as an RGB array: ``cv2.imread`` first, then ``np.load``."""
    bgr = cv2.imread(img_path)
    # cv2.imread returns None for files it cannot decode (e.g. .npy), so the
    # result must be checked before any colour conversion.
    if bgr is not None and bgr.size > 0:
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    img_data = np.load(img_path)
    if img_data.ndim == 3:
        return img_data
    if img_data.ndim == 2:
        return cv2.cvtColor(img_data, cv2.COLOR_GRAY2RGB)
    raise ValueError(f"Unsupported array shape {img_data.shape}")


def _segment_batch(unet_model, pending):
    """Predict, clean and RLE-encode one batch.

    ``pending`` is a list of ``(case_id, original_shape, stacked_input)``
    tuples. Returns one ``{'case_id', 'annotation'}`` dict per entry; nothing
    is returned if any step raises, so callers never see a partial batch.
    """
    batch_inputs = torch.stack([
        torch.tensor(stacked.transpose(2, 0, 1) / 255.0, dtype=torch.float32)
        for _, _, stacked in pending
    ]).to(DEVICE)

    with torch.no_grad():
        outputs = unet_model(batch_inputs).detach().cpu().numpy()

    batch_results = []
    for (case_id_out, original_shape_out, _), output in zip(pending, outputs):
        output_prob = output.squeeze()

        # Log the max probability for debugging
        max_prob = np.max(output_prob)
        print(f"|--- Case {case_id_out} Max Forgery Probability: {max_prob:.4f} ---|")

        # Apply threshold
        final_mask_resized = (output_prob > FIXED_THRESHOLD).astype(np.uint8)

        # Minimum-area filtering: keep only 4-connected components of at least
        # MIN_FORGERY_AREA pixels (label 0 is the background).
        num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(
            final_mask_resized, 4, cv2.CV_32S
        )
        clean_mask_resized = np.zeros_like(final_mask_resized)
        for label in range(1, num_labels):
            if stats[label, cv2.CC_STAT_AREA] >= MIN_FORGERY_AREA:
                clean_mask_resized[labels == label] = 1

        # Resize the cleaned mask back to the original size (cv2 takes (w, h))
        final_mask = cv2.resize(
            clean_mask_resized,
            (original_shape_out[1], original_shape_out[0]),
            interpolation=cv2.INTER_NEAREST
        )

        batch_results.append({'case_id': case_id_out, 'annotation': rle_encode(final_mask)})
    return batch_results


def run_inference_and_segment(unet_model, test_df):
    """Predict a forgery annotation for every row of ``test_df``.

    Args:
        unet_model: Model taking ``(N, 4, TARGET_SIZE, TARGET_SIZE)`` inputs and
            returning per-pixel probabilities.
        test_df: DataFrame with ``case_id`` and ``img_path`` columns; rows whose
            path is ``'MISSING_FILE'`` or ``'NOT_FOUND'`` are marked authentic.

    Returns:
        DataFrame with columns ``case_id`` and ``annotation`` containing exactly
        one row per input row. ``annotation`` is ``'authentic'`` or the RLE
        string from ``rle_encode``. Any image or batch that fails to process
        defaults to ``'authentic'``.
    """
    results = []
    unet_model.eval()
    pending = []

    def flush_pending():
        # Process whatever is queued; on failure every queued case falls back
        # to 'authentic' so no case is lost or reported twice.
        if not pending:
            return
        try:
            results.extend(_segment_batch(unet_model, pending))
        except Exception as e:
            for failed_case, _, _ in pending:
                print(f"Error processing case {failed_case}: {e}. Defaulting to authentic.")
                results.append({'case_id': failed_case, 'annotation': 'authentic'})
        finally:
            pending.clear()

    for _, row in tqdm(test_df.iterrows(), total=len(test_df), desc="Processing"):
        case = row['case_id']
        img_path = row['img_path']

        if img_path == 'MISSING_FILE' or img_path == 'NOT_FOUND':
            results.append({'case_id': case, 'annotation': 'authentic'})
            continue

        try:
            rgb_image = _load_rgb_for_inference(img_path)
            if rgb_image is None or rgb_image.size == 0:
                raise ValueError(f"Invalid image data for {case}")

            original_shape = rgb_image.shape[:2]
            ela_feature_2d = compute_ela(img_path)

            rgb_image_resized = cv2.resize(rgb_image, (TARGET_SIZE, TARGET_SIZE))
            ela_feature_resized = cv2.resize(ela_feature_2d, (TARGET_SIZE, TARGET_SIZE))
            ela_feature_3d = np.expand_dims(ela_feature_resized, axis=-1)
            stacked_input = np.concatenate([rgb_image_resized, ela_feature_3d], axis=-1)

            # Reject odd channel counts here so one bad image cannot break
            # the whole batch later on.
            if stacked_input.shape != (TARGET_SIZE, TARGET_SIZE, 4):
                raise ValueError(f"Unexpected input shape {stacked_input.shape} for {case}")
        except Exception as e:
            print(f"Error processing case {case}: {e}. Defaulting to authentic.")
            results.append({'case_id': case, 'annotation': 'authentic'})
            continue

        pending.append((case, original_shape, stacked_input))
        if len(pending) >= BATCH_SIZE:
            flush_pending()

    # Flush the final partial batch regardless of how the last row ended
    # (missing file, load error, ...).
    flush_pending()

    return pd.DataFrame(results, columns=['case_id', 'annotation'])

# --- MAIN EXECUTION BLOCK ---
if __name__ == "__main__":

    print(f"--- Starting inference on {DEVICE} at {pd.Timestamp.now()} ---")

    # 1. Load Model -- any failure leaves `model` as None, which later turns
    #    every case into an 'authentic' prediction.
    model = None
    try:
        model = UNet(in_channels=4).to(DEVICE)
        state_dict = torch.load(model_path, map_location=DEVICE)
        model.load_state_dict(state_dict)
        model.eval()  # evaluation mode (disables dropout)
        print(f"Model loaded successfully from {model_path}")
    except Exception as e:
        print(f"Error loading model from {model_path}. Submitting 'authentic' for all cases. Error: {e}")
        model = None

    # 2. Prepare Data
    test_df = create_test_df_robust(TEST_IMAGE_ROOT, SAMPLE_SUBMISSION_FILE)
    test_df['case_id'] = test_df['case_id'].astype(str)

    # 3. Run Inference (or fall back to all-authentic without a model)
    results_df = (
        run_inference_and_segment(model, test_df)
        if model
        else test_df[['case_id']].assign(annotation='authentic')
    )

    # 4. Finalize Submission DF: one row per case, unpredicted cases -> authentic
    submission_df = test_df[['case_id']].copy().merge(results_df, on='case_id', how='left')
    submission_df['annotation'] = submission_df['annotation'].fillna('authentic')
    submission_df = (
        submission_df[['case_id', 'annotation']]
        .sort_values('case_id')
        .reset_index(drop=True)
    )

    # 5. Write CSV: 'authentic' is written as-is, RLE strings are wrapped in brackets
    with open(OUTPUT_FILENAME, "w", newline='') as f:
        writer = csv.writer(f, quoting=csv.QUOTE_MINIMAL)
        writer.writerow(['case_id', 'annotation'])
        writer.writerows(
            [str(cid), ann if ann.lower() == 'authentic' else f"[{ann}]"]
            for cid, ann in zip(submission_df['case_id'], submission_df['annotation'])
        )

    print(f"\nâœ… Created {OUTPUT_FILENAME} with {len(submission_df)} rows at {pd.Timestamp.now()}")


In [None]:
# IPython shell escape: print the generated submission file for a quick visual check.
!cat submission.csv

In [None]:
def validate_and_print_rle(submission_df):
    """
    Validates RLE output structure and prints debugging info.
    Checks for: 1. Authentic/RLE count. 2. Even number of RLE elements.

    RLE strings are tokenised independently of their exact formatting:
    surrounding brackets are ignored and elements may be separated by commas
    and/or any amount of whitespace. Annotations that are not strings (e.g.
    NaN from an empty CSV field) or contain no elements are reported as
    invalid instead of raising.

    Args:
        submission_df: DataFrame with an ``annotation`` column.
    """
    import re  # local import: the cell defining this function imports nothing

    def has_even_element_count(annotation):
        if not isinstance(annotation, str):
            return False
        elements = [t for t in re.split(r'[\s,]+', annotation.strip().strip('[]')) if t]
        return len(elements) > 0 and len(elements) % 2 == 0

    print("\n--- RLE Output Validation Check ---")

    # Analyze the annotations
    authentic_count = submission_df['annotation'].apply(lambda x: x == 'authentic').sum()
    rle_rows = submission_df[submission_df['annotation'] != 'authentic']

    print(f"Total Submissions: {len(submission_df)}")
    print(f"Authentic (No Forgery) Count: {authentic_count}")
    print(f"RLE Annotated (Forged) Count: {len(rle_rows)}")

    # CRITICAL CHECK: RLE strings must always have an even number of elements (start, length, start, length...)
    rle_check = rle_rows['annotation'].apply(has_even_element_count)

    if rle_check.all():
        print(f"âœ… RLE Structure: All {len(rle_rows)} RLE strings contain an even number of elements.")
    else:
        # Prints a warning if any RLE string has an odd number of elements (a common error)
        bad_rle_count = len(rle_rows) - rle_check.sum()
        print(f"ðŸ›‘ RLE ERROR: Found {bad_rle_count} RLE strings with an odd number of elements (Invalid pairing).")

In [None]:
# Re-read the written file (rather than reusing the in-memory frame) so the
# check sees exactly what was written to submission.csv.
submission_df = pd.read_csv("submission.csv")
validate_and_print_rle(submission_df)