# In [None]:
"""
🔥 NUCLEAR OPTION: Multi-Detector Ensemble 🔥
GO BIG OR GO HOME

Strategy:
1. Run SIFT at TWO scales (original + 0.75x)
2. Add ORB detector (different feature type)
3. Union ALL detections (if ANY detector finds it → include it)
4. Aggressive mask expansion
5. Lower all thresholds

Expected: 0.310-0.325 OR spectacular failure to 0.25
"""

import os
import json
import warnings
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm
import cv2
from scipy.ndimage import label as scipy_label, binary_fill_holes, binary_dilation
from skimage.measure import regionprops

class Config:
    """Static settings for the forgery-detection ensemble (accessed as ``config.X``)."""

    # --- Kaggle input locations ------------------------------------------
    BASE_PATH: str = '/kaggle/input/recodai-luc-scientific-image-forgery-detection'
    TEST_IMAGES: str = os.path.join(BASE_PATH, 'test_images')
    SAMPLE_SUB: str = os.path.join(BASE_PATH, 'sample_submission.csv')

    # --- SIFT keypoints and self-matching ---------------------------------
    SIFT_FEATURES: int = 8000       # keypoint budget passed to SIFT_create
    SIFT_CONTRAST: float = 0.015    # SIFT contrastThreshold
    MATCH_RATIO: float = 0.82       # ratio-test threshold (best vs second best)
    MIN_MATCHES: int = 3            # minimum matches/inliers before reporting
    RANSAC_THRESH: float = 6.0      # findHomography reprojection threshold (px)
    MIN_DISPLACEMENT: int = 18      # pairs closer than this (px) are ignored

    # --- ORB keypoints ----------------------------------------------------
    ORB_FEATURES: int = 3000
    ORB_MATCH_THRESHOLD: int = 50   # maximum Hamming distance of a match

    # --- Multi-scale SIFT -------------------------------------------------
    SCALES: list = [1.0, 0.75]      # scale factors SIFT is run at

    # --- Submission decision thresholds -----------------------------------
    CONFIDENCE_THRESHOLD: float = 0.20
    MIN_MASK_PIXELS: int = 50
    MIN_COVERAGE: float = 0.0002    # lower bound on mask fraction of image area
    MAX_COVERAGE: float = 0.50      # upper bound on mask fraction of image area

    # --- Mask rendering / preprocessing -----------------------------------
    CIRCLE_RADIUS: int = 20         # radius (px) of the disc drawn per keypoint
    USE_CLAHE: bool = True

# Shared settings instance; the functions and classes in this file read
# their parameters from it as ``config.<NAME>``.
config = Config()

# Start-up banner (text was previously mis-encoded UTF-8 shown as Mac Roman).
print("="*80)
print("🔥 NUCLEAR OPTION: Multi-Detector Ensemble")
print("="*80)
print("Strategy: SIFT(2 scales) + ORB + Aggressive fusion")
print("Risk Level: EXTREME")
print("Expected: Jump to 0.31+ OR drop to 0.25")
print("="*80 + "\n")

def rle_encode(mask):
    """Run-length encode a 2-D mask in column-major order.

    Pixels are scanned column by column (the mask is transposed before
    flattening) and every nonzero pixel counts as foreground.

    Args:
        mask: 2-D array; any nonzero value (1, 255, True, ...) is foreground.

    Returns:
        Flat list ``[start1, length1, start2, length2, ...]`` of Python ints,
        with 1-based start positions; ``[]`` when the mask has no foreground.
    """
    pixels = np.asarray(mask).T.ravel() != 0
    if not pixels.any():
        return []
    # Pad with background on both ends so every run has a start and an end edge.
    padded = np.concatenate(([False], pixels, [False]))
    # Index k here is a transition between pixel k-1 and pixel k.
    edges = np.flatnonzero(padded[1:] != padded[:-1])
    starts, ends = edges[0::2], edges[1::2]
    runs = np.empty(2 * len(starts), dtype=np.int64)
    runs[0::2] = starts + 1          # 1-based run starts
    runs[1::2] = ends - starts       # run lengths
    return runs.tolist()

def preprocess_image(img_array):
    """Convert an image array to a contrast-stretched 8-bit grayscale image.

    Args:
        img_array: 2-D (H, W) array, or 3-D (H, W, C) array. C in {3, 4} is
            treated as RGB / RGBA; C in {1, 2} as gray / gray+alpha (only
            channel 0 is used). Boolean arrays are accepted.

    Returns:
        2-D uint8 array, min-max normalised to 0..255 and, when
        ``config.USE_CLAHE`` is set, equalised with CLAHE.
    """
    img = np.asarray(img_array)
    if img.dtype == bool:
        # OpenCV has no boolean depth; normalize() would reject it.
        img = img.astype(np.uint8)

    if img.ndim == 3:
        if img.shape[2] in (1, 2):
            # cvtColor(RGB2GRAY) only accepts 3/4 channels; for gray(+alpha)
            # input channel 0 already holds the intensity.
            gray = np.ascontiguousarray(img[:, :, 0])
        else:
            gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    else:
        gray = img.copy()

    gray = cv2.normalize(gray, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)

    if config.USE_CLAHE:
        clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
        gray = clahe.apply(gray)

    return gray

class AggressiveSIFT:
    """Copy-move detector: SIFT self-matching followed by a RANSAC homography.

    Keypoints of one grayscale image are matched against each other. Pairs
    that pass the ratio test (``config.MATCH_RATIO``) and lie more than
    ``config.MIN_DISPLACEMENT`` pixels apart are fitted with a RANSAC
    homography, and a filled disc of radius ``config.CIRCLE_RADIUS`` is drawn
    around both endpoints of every inlier correspondence.
    """

    # cv2.findHomography needs at least four point correspondences.
    MIN_HOMOGRAPHY_POINTS = 4

    def __init__(self):
        self.sift = cv2.SIFT_create(
            nfeatures=config.SIFT_FEATURES,
            contrastThreshold=config.SIFT_CONTRAST,
            edgeThreshold=12,
        )

    @staticmethod
    def _empty(h, w):
        """Return the "nothing found" result: an all-zero (h, w) mask and 0.0."""
        return np.zeros((h, w), dtype=np.uint8), 0.0

    @staticmethod
    def _matched_pairs(desc):
        """Return sorted unique ``(i, j)`` (i < j) descriptor pairs passing the ratio test.

        When a descriptor set is matched against itself, each descriptor is
        its own nearest neighbour (distance 0). Three neighbours are requested
        and the self-match is dropped, so the ratio test compares the two best
        *other* descriptors. (With k=2 the best neighbour was the self-match,
        which the old filter then rejected, so real duplicates were never found.)
        """
        matcher = cv2.BFMatcher(cv2.NORM_L2)
        pairs = set()
        for neighbours in matcher.knnMatch(desc, desc, k=3):
            others = [m for m in neighbours if m.trainIdx != m.queryIdx]
            if len(others) < 2:
                continue
            best, second = others[0], others[1]
            if best.distance < config.MATCH_RATIO * second.distance:
                # i->j and j->i describe the same correspondence; keep it once.
                i, j = best.queryIdx, best.trainIdx
                pairs.add((min(i, j), max(i, j)))
        # Sorted so the RANSAC input order is deterministic.
        return sorted(pairs)

    def detect(self, gray):
        """Find duplicated regions in ``gray``.

        Args:
            gray: 2-D grayscale image (as returned by ``preprocess_image``).

        Returns:
            ``(mask, conf)``: a uint8 mask of ``gray``'s shape (1 near matched
            keypoints, 0 elsewhere) and ``min(1, inliers / 8)``; an all-zero
            mask and 0.0 when no consistent duplication is found.
        """
        h, w = gray.shape
        kp, desc = self.sift.detectAndCompute(gray, None)
        if desc is None or len(desc) < config.MIN_MATCHES * 2:
            return self._empty(h, w)

        pairs = self._matched_pairs(desc)
        if len(pairs) < config.MIN_MATCHES:
            return self._empty(h, w)

        idx = np.asarray(pairs)
        pts = np.float32([k.pt for k in kp])
        a, b = pts[idx[:, 0]], pts[idx[:, 1]]
        far = np.linalg.norm(b - a, axis=1) > config.MIN_DISPLACEMENT
        if far.sum() < max(config.MIN_MATCHES, self.MIN_HOMOGRAPHY_POINTS):
            return self._empty(h, w)
        a, b = a[far], b[far]

        # Which end of a pair is the copy's source is unknown, so feed every
        # correspondence in both directions; RANSAC locks onto one direction
        # and (for a non-involutive transform) each pair yields one inlier.
        src = np.concatenate([a, b])
        dst = np.concatenate([b, a])
        try:
            M, inliers = cv2.findHomography(src, dst, cv2.RANSAC, config.RANSAC_THRESH)
        except cv2.error:
            return self._empty(h, w)
        if M is None or inliers is None:
            return self._empty(h, w)

        keep = inliers.ravel() > 0
        n_inliers = int(keep.sum())
        if n_inliers < config.MIN_MATCHES:
            return self._empty(h, w)

        mask = np.zeros((h, w), dtype=np.uint8)
        for x, y in np.concatenate([src[keep], dst[keep]]).astype(int):
            if 0 <= x < w and 0 <= y < h:
                cv2.circle(mask, (int(x), int(y)), config.CIRCLE_RADIUS, 1, -1)
        return mask, min(1.0, n_inliers / 8.0)

class ORBDetector:
    """Copy-move detector: ORB self-matching with an absolute Hamming threshold.

    No geometric verification is done: every pair of distinct keypoints whose
    descriptors are within ``config.ORB_MATCH_THRESHOLD`` and whose locations
    are more than ``config.MIN_DISPLACEMENT`` pixels apart is marked.
    """

    def __init__(self):
        self.orb = cv2.ORB_create(nfeatures=config.ORB_FEATURES)

    @staticmethod
    def _empty(h, w):
        """Return the "nothing found" result: an all-zero (h, w) mask and 0.0."""
        return np.zeros((h, w), dtype=np.uint8), 0.0

    @staticmethod
    def _matched_pairs(desc):
        """Return sorted unique ``(i, j)`` (i < j) pairs within the Hamming threshold.

        Each descriptor's nearest neighbour in its own set is itself
        (distance 0), so the first *non-self* neighbour is tested. The old
        code only tested the first neighbour and then rejected it as a
        self-match, so genuine duplicates were never reported.
        """
        matcher = cv2.BFMatcher(cv2.NORM_HAMMING)
        pairs = set()
        for neighbours in matcher.knnMatch(desc, desc, k=2):
            best = next((m for m in neighbours if m.trainIdx != m.queryIdx), None)
            if best is not None and best.distance < config.ORB_MATCH_THRESHOLD:
                # i->j and j->i describe the same correspondence; keep it once.
                i, j = best.queryIdx, best.trainIdx
                pairs.add((min(i, j), max(i, j)))
        return sorted(pairs)

    def detect(self, gray):
        """Find duplicated regions in ``gray``.

        Args:
            gray: 2-D grayscale image (as returned by ``preprocess_image``).

        Returns:
            ``(mask, conf)``: a uint8 mask of ``gray``'s shape (1 near matched
            keypoints, 0 elsewhere) and ``min(1, pairs / 10)``; an all-zero
            mask and 0.0 when too few matches are found.
        """
        h, w = gray.shape
        kp, desc = self.orb.detectAndCompute(gray, None)
        if desc is None or len(desc) < config.MIN_MATCHES * 2:
            return self._empty(h, w)

        pairs = self._matched_pairs(desc)
        if len(pairs) < config.MIN_MATCHES:
            return self._empty(h, w)

        idx = np.asarray(pairs)
        pts = np.float32([k.pt for k in kp])
        a, b = pts[idx[:, 0]], pts[idx[:, 1]]
        far = np.linalg.norm(b - a, axis=1) > config.MIN_DISPLACEMENT
        n_far = int(far.sum())
        if n_far < config.MIN_MATCHES:
            return self._empty(h, w)

        mask = np.zeros((h, w), dtype=np.uint8)
        for x, y in np.concatenate([a[far], b[far]]).astype(int):
            if 0 <= x < w and 0 <= y < h:
                cv2.circle(mask, (int(x), int(y)), config.CIRCLE_RADIUS, 1, -1)
        return mask, min(1.0, n_far / 10.0)

class NuclearEnsemble:
    """Ensemble of multi-scale SIFT and ORB detectors fused by mask union.

    Every detector that returns a non-empty mask contributes to a pixel-wise
    union; the union is dilated, closed and hole-filled by
    ``aggressive_refine``, and the reported confidence is the maximum over
    the contributing detectors.
    """

    # Connected components smaller than this many pixels are removed.
    MIN_REGION_AREA = 30

    def __init__(self):
        self.sift = AggressiveSIFT()
        self.orb = ORBDetector()
        print("✓ Initialized: Aggressive SIFT + ORB")

    def detect(self, image):
        """Run all detectors on ``image`` and fuse their masks.

        Args:
            image: image array accepted by ``preprocess_image``.

        Returns:
            ``(mask, conf)``: refined uint8 union mask of shape
            ``image.shape[:2]`` and the maximum detector confidence, or an
            all-zero mask and 0.0 when no detector fires.
        """
        gray = preprocess_image(image)
        h, w = image.shape[:2]

        all_masks = []
        all_confs = []

        # SIFT at every configured scale; masks are resized back to (h, w).
        for scale in config.SCALES:
            if scale == 1.0:
                scaled_gray = gray
            else:
                scaled_w, scaled_h = int(w * scale), int(h * scale)
                if scaled_w < 1 or scaled_h < 1:
                    # cv2.resize rejects a zero-sized target; skip this scale
                    # instead of failing the whole image.
                    continue
                scaled_gray = cv2.resize(gray, (scaled_w, scaled_h))

            mask, conf = self.sift.detect(scaled_gray)
            if scale != 1.0 and mask.any():
                upscaled = cv2.resize(mask.astype(np.float32), (w, h))
                mask = (upscaled > 0.5).astype(np.uint8)
            if mask.any():
                all_masks.append(mask)
                all_confs.append(conf)

        # ORB at the original scale.
        orb_mask, orb_conf = self.orb.detect(gray)
        if orb_mask.any():
            all_masks.append(orb_mask)
            all_confs.append(orb_conf)

        if not all_masks:
            return np.zeros((h, w), dtype=np.uint8), 0.0

        # Pixel-wise union of all contributing masks.
        combined = np.maximum.reduce(all_masks).astype(np.uint8)
        combined = self.aggressive_refine(combined)
        return combined, max(all_confs)

    def aggressive_refine(self, mask):
        """Grow and clean a binary mask.

        Dilates (7x7 ellipse, 2 iterations), closes (9x9 ellipse,
        3 iterations), fills holes, then removes connected components smaller
        than ``MIN_REGION_AREA`` pixels.

        Args:
            mask: 2-D uint8 0/1 mask.

        Returns:
            Refined 2-D uint8 0/1 mask (the input itself if it is empty).
        """
        if not mask.any():
            return mask

        kernel_dilate = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
        mask = cv2.dilate(mask, kernel_dilate, iterations=2)

        kernel_close = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (9, 9))
        mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel_close, iterations=3)

        mask = binary_fill_holes(mask).astype(np.uint8)

        # Remove small components in one vectorised pass instead of one full
        # image scan per region.
        labeled, n_regions = scipy_label(mask)
        if n_regions:
            areas = np.bincount(labeled.ravel())
            too_small = areas < self.MIN_REGION_AREA
            too_small[0] = False  # label 0 is background
            mask[too_small[labeled]] = 0

        return mask

# PIL modes whose pixel values are not plain intensities (palette indices,
# bilevel) or whose channel layout is not gray/RGB/RGBA; converted to RGB
# before detection.
_CONVERT_TO_RGB_MODES = ('1', 'P', 'PA', 'LA', 'CMYK', 'YCbCr')


def _load_test_image(case_id):
    """Load ``<TEST_IMAGES>/<case_id>.png`` as a numpy array, closing the file."""
    path = os.path.join(config.TEST_IMAGES, f"{case_id}.png")
    with Image.open(path) as src:
        img = src.convert('RGB') if src.mode in _CONVERT_TO_RGB_MODES else src
        return np.array(img)


def _annotate_case(detector, case_id):
    """Return the submission annotation for one case: JSON RLE or 'authentic'."""
    img = _load_test_image(case_id)
    mask, conf = detector.detect(img)

    if conf <= config.CONFIDENCE_THRESHOLD or mask.sum() < config.MIN_MASK_PIXELS:
        return 'authentic'
    coverage = mask.sum() / (img.shape[0] * img.shape[1])
    if not (config.MIN_COVERAGE < coverage < config.MAX_COVERAGE):
        return 'authentic'

    rle = rle_encode(mask)
    if not rle:
        return 'authentic'
    return json.dumps([int(x) for x in rle])


def generate_submission(detector):
    """Run ``detector`` on every test case and write ``submission.csv``.

    Case ids come from ``config.SAMPLE_SUB``. A case that raises during
    loading or detection is logged and written as 'authentic' so that one bad
    image does not abort the run.

    Args:
        detector: object with ``detect(image) -> (mask, conf)``.
    """
    print("\n🚀 Generating NUCLEAR submission...")

    sample_sub = pd.read_csv(config.SAMPLE_SUB)
    submissions = []
    stats = {'authentic': 0, 'forged': 0}
    failures = 0

    for case_id in tqdm(sample_sub['case_id'], desc="Processing"):
        try:
            annotation = _annotate_case(detector, case_id)
        except Exception as exc:  # top-level boundary: log and fall back
            failures += 1
            tqdm.write(f"  ! {case_id}: {type(exc).__name__}: {exc}")
            annotation = 'authentic'

        stats['authentic' if annotation == 'authentic' else 'forged'] += 1
        submissions.append({'case_id': case_id, 'annotation': annotation})

    df = pd.DataFrame(submissions, columns=['case_id', 'annotation'])
    df.to_csv('submission.csv', index=False)

    total = max(len(df), 1)  # avoid ZeroDivisionError on an empty sample file
    print("\n✓ Complete:")
    print(f"  Forged: {stats['forged']} ({stats['forged']/total*100:.1f}%)")
    print(f"  Authentic: {stats['authentic']} ({stats['authentic']/total*100:.1f}%)")
    if failures:
        print(f"  Failed (written as authentic): {failures}")
    print("\n✓ Saved: submission.csv")

def main():
    """Build the ensemble detector and write ``submission.csv``.

    Banner text previously contained mis-encoded UTF-8 (rendered as Mac
    Roman); the intended symbols are restored.
    """
    print("\n" + "="*80)
    print("🔥 INITIALIZING NUCLEAR OPTION")
    print("="*80)

    detector = NuclearEnsemble()
    generate_submission(detector)

    print("\n" + "="*80)
    print("💣 NUCLEAR SUBMISSION READY")
    print("="*80)
    print("\nThis will either:")
    print("  ✅ Jump to 0.31+ and beat the leaders")
    print("  ❌ Drop to 0.25 spectacular failure")
    print("\n🎲 GO BIG OR GO HOME! Submit and find out!")
    print("="*80 + "\n")


if __name__ == "__main__":
    main()