# In [None]:
import os
import warnings
warnings.filterwarnings('ignore')
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F

# =============================================================================
# CONFIG
# =============================================================================
class Config:
    """Static settings for the hybrid SIFT + CNN forgery-detection submission.

    Everything is a plain class attribute; nothing is computed at instance time.
    """

    # ---- Dataset locations -------------------------------------------------
    BASE_PATH = "/kaggle/input/recodai-luc-scientific-image-forgery-detection"
    TRAIN_IMAGES_FORGED = os.path.join(BASE_PATH, "train_images/forged")
    TRAIN_IMAGES_AUTH = os.path.join(BASE_PATH, "train_images/authentic")
    TRAIN_MASKS = os.path.join(BASE_PATH, "train_masks")
    TEST_IMAGES = os.path.join(BASE_PATH, "test_images")
    SAMPLE_SUB = os.path.join(BASE_PATH, "sample_submission.csv")

    # ---- SIFT matching -----------------------------------------------------
    SIFT_FEATURES = 5_500       # passed as nfeatures to cv2.SIFT_create
    SIFT_CONTRAST = 0.019       # passed as contrastThreshold to cv2.SIFT_create
    MATCH_RATIO = 0.79          # ratio-test cutoff: best < MATCH_RATIO * second best
    MIN_MATCHES = 4             # NOTE(review): not referenced by the visible code
    RANSAC_THRESH = 5.5         # NOTE(review): not referenced by the visible code
    MIN_DISPLACEMENT = 23       # NOTE(review): not referenced by the visible code
    MAX_IMAGE_SIZE = 1_600      # NOTE(review): not referenced by the visible code
    USE_CLAHE = True            # apply CLAHE to both grayscale images before SIFT

    # ---- SIFT variant tweaks -----------------------------------------------
    SIFT_CONFIDENCE_THRESHOLD = 0.31  # presumably a decision cutoff on the SIFT score
    SIFT_MIN_MASK_PIXELS = 85         # NOTE(review): not referenced by the visible code
    SIFT_MIN_COVERAGE = 0.00035       # NOTE(review): not referenced by the visible code
    SIFT_MAX_COVERAGE = 0.42          # NOTE(review): not referenced by the visible code

    # ---- CNN ---------------------------------------------------------------
    CNN_ENABLED = True          # when False the CNN score is taken as 0.0
    CNN_IMAGE_SIZE = 256        # images are resized to CNN_IMAGE_SIZE x CNN_IMAGE_SIZE
    CNN_THRESHOLD = 0.5         # CNN vote cutoff in the non-weighted (voting) mode

    # ---- Ensemble ----------------------------------------------------------
    ENSEMBLE_MODE = "weighted"  # "weighted" blends scores; any other value votes
    SIFT_WEIGHT = 0.75
    CNN_WEIGHT = 0.25

config = Config()
# Prefer the GPU when one is visible; every tensor/model below is moved here.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The banner emoji was stored mis-encoded (UTF-8 bytes shown as cp1252 "ðŸŽ¯");
# restored to the intended character. No placeholders, so no f-string needed.
print("🎯 HYBRID ENSEMBLE - SIFT + Fast CNN")
print(f" Device: {device}")

# =============================================================================
# Placeholder CNN with randomly initialized weights (replace with your trained model)
# =============================================================================
class SimpleCNN(nn.Module):
    """Minimal placeholder classifier producing one probability per image.

    Architecture: 3->8 channel 3x3 conv + ReLU, global average pooling, then a
    linear head followed by a sigmoid. Input is (batch, 3, H, W); output is
    (batch, 1) with values in (0, 1).
    """

    def __init__(self):
        super().__init__()
        # Attribute names (and creation order) are kept so a trained
        # state_dict with keys conv1.* / fc.* loads directly.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(output_size=1)
        self.fc = nn.Linear(in_features=8, out_features=1)

    def forward(self, x):
        pooled = self.pool(F.relu(self.conv1(x)))         # (batch, 8, 1, 1)
        logits = self.fc(torch.flatten(pooled, start_dim=1))  # (batch, 1)
        return torch.sigmoid(logits)

# SimpleCNN is built without loading trained weights, so its output depends on
# the random init. Seed first so the CNN scores (and thus submission.csv) are
# reproducible from run to run.
torch.manual_seed(0)
cnn_model = SimpleCNN().to(device)
cnn_model.eval()  # inference mode; matters once a trained model with dropout/BN is dropped in

# =============================================================================
# SIFT scoring
# =============================================================================
def sift_score(img_path, ref_path):
    """Return the fraction of SIFT matches between two images passing the ratio test.

    Both images are read as grayscale, optionally CLAHE-equalized
    (``config.USE_CLAHE``), and described with SIFT
    (``config.SIFT_FEATURES``, ``config.SIFT_CONTRAST``). Every descriptor of
    ``img_path`` is matched to its two nearest descriptors in ``ref_path`` and
    counted as good when ``best < config.MATCH_RATIO * second_best``.

    Args:
        img_path: Path of the image being scored.
        ref_path: Path of the reference image.

    Returns:
        ``good / total_matches`` as a float in [0, 1]; 0.0 when either image
        cannot be read or yields no descriptors.
    """
    img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    ref = cv2.imread(ref_path, cv2.IMREAD_GRAYSCALE)
    if img is None or ref is None:
        return 0.0

    if config.USE_CLAHE:
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        img = clahe.apply(img)
        ref = clahe.apply(ref)

    sift = cv2.SIFT_create(nfeatures=config.SIFT_FEATURES,
                           contrastThreshold=config.SIFT_CONTRAST)
    _, des1 = sift.detectAndCompute(img, None)
    _, des2 = sift.detectAndCompute(ref, None)
    if des1 is None or des2 is None:
        return 0.0

    matches = cv2.BFMatcher().knnMatch(des1, des2, k=2)
    # knnMatch returns fewer than k neighbours when the reference has fewer
    # than k descriptors; the ratio test is undefined there, so such entries
    # count as "not good" instead of crashing a two-way unpack.
    good = sum(
        1
        for pair in matches
        if len(pair) == 2 and pair[0].distance < config.MATCH_RATIO * pair[1].distance
    )
    return float(good / max(len(matches), 1))

# =============================================================================
# CNN scoring
# =============================================================================
def cnn_score(img_path):
    """Return ``cnn_model``'s output for a single image as a Python float.

    The image is converted to RGB and resized to
    ``config.CNN_IMAGE_SIZE x config.CNN_IMAGE_SIZE``. Pixels are scaled to
    [0, 1] and laid out as a (1, 3, H, W) float32 tensor on ``device``.

    Args:
        img_path: Path of the image to score.

    Returns:
        The model's scalar output (SimpleCNN ends in a sigmoid, so (0, 1)).
    """
    # Close the file handle deterministically; the original left one open per
    # image for the whole test loop. convert()/resize() return new in-memory
    # images, so using `rgb` after the `with` block is safe.
    with Image.open(img_path) as pil_img:
        rgb = pil_img.convert('RGB').resize((config.CNN_IMAGE_SIZE, config.CNN_IMAGE_SIZE))
    pixels = np.asarray(rgb) / 255.0  # float64 (H, W, 3), same arithmetic as before
    batch = torch.from_numpy(pixels).permute(2, 0, 1).unsqueeze(0).float().to(device)
    with torch.no_grad():
        out = cnn_model(batch).item()
    return float(out)

# =============================================================================
# Generate Submission
# =============================================================================
def _ensemble_score(s_score, c_score):
    """Combine one image's SIFT and CNN scores according to ``config.ENSEMBLE_MODE``.

    'weighted' returns ``SIFT_WEIGHT * s_score + CNN_WEIGHT * c_score``. Any
    other mode votes: 1.0 if either score exceeds its threshold, else 0.0.
    """
    if config.ENSEMBLE_MODE == 'weighted':
        return config.SIFT_WEIGHT * s_score + config.CNN_WEIGHT * c_score
    # Voting. The SIFT cutoff was a hard-coded 0.5 even though Config declares
    # SIFT_CONFIDENCE_THRESHOLD; read it from config like the CNN threshold.
    voted = s_score > config.SIFT_CONFIDENCE_THRESHOLD or c_score > config.CNN_THRESHOLD
    return 1.0 if voted else 0.0


def generate_submission():
    """Score every test case and write ``submission.csv`` in the working directory.

    For each ``case_id`` in the sample submission, the image
    ``<config.TEST_IMAGES>/<case_id>.png`` is scored two ways:
    - with ``sift_score`` against one authentic training image;
    - with ``cnn_score``, when ``config.CNN_ENABLED``; otherwise its score is 0.0.
    The two scores are combined by ``_ensemble_score`` and stored in the
    ``annotation`` column.

    Raises:
        FileNotFoundError: if ``config.TRAIN_IMAGES_AUTH`` contains no entries.
    """
    sub = pd.read_csv(config.SAMPLE_SUB)

    # Use one authentic image as the SIFT reference. os.listdir order is
    # arbitrary, so sort it to pick the same reference on every run.
    auth_names = sorted(os.listdir(config.TRAIN_IMAGES_AUTH))
    if not auth_names:
        raise FileNotFoundError(f"No reference images found in {config.TRAIN_IMAGES_AUTH}")
    ref_path = os.path.join(config.TRAIN_IMAGES_AUTH, auth_names[0])

    preds = []
    for case_id in tqdm(sub['case_id']):
        # NOTE(review): assumes every test image is stored as "<case_id>.png".
        test_path = os.path.join(config.TEST_IMAGES, f"{case_id}.png")
        s_score = sift_score(test_path, ref_path)
        c_score = cnn_score(test_path) if config.CNN_ENABLED else 0.0
        preds.append(_ensemble_score(s_score, c_score))

    # NOTE(review): this writes one numeric score per case; confirm that this is
    # the encoding the sample submission's 'annotation' column expects.
    sub['annotation'] = preds
    output_file = "submission.csv"
    sub.to_csv(output_file, index=False)
    # Emoji was stored mis-encoded ("âœ…" is UTF-8 ✅ read as cp1252).
    print(f"✅ Submission saved to {output_file}")

# =============================================================================
# MAIN
# =============================================================================
if __name__ == "__main__":
    generate_submission()
