In [None]:
!cp -r /kaggle/input/segment-anything /kaggle/working/segment-anything
%cd segment-anything
!pip install -e .

In [None]:
# Build the SAM ViT-B model and an automatic mask generator. `mask_generator` is the global
# that extract_feature() (next cell) calls to segment each image.
# Run from the repo directory, presumably so `segment_anything` is importable from the
# working directory even before the editable install is picked up — verify.
%cd /kaggle/working/segment-anything
import sys  # NOTE(review): not used in this cell.
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor  # SamPredictor is not used below

sam_checkpoint = "/kaggle/input/sam-vit-b/sam_vit_b_01ec64.pth" # the official vit-b checkpoint from https://github.com/facebookresearch/segment-anything?tab=readme-ov-file
# NOTE: the automatic masks are somewhat coarse for this data; specialised models such as
# https://github.com/DevoLearn/CellSAM may segment better.
model_type = "vit_b"  # must match the checkpoint above

device = "cuda"  # hard-coded: requires a GPU session, there is no CPU fallback

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

# Automatic mask generator with default settings (no generator arguments are passed).
mask_generator = SamAutomaticMaskGenerator(sam)

In [None]:
import numpy as np
def _mask_iou(mask1, mask2):
    """Intersection-over-union of two boolean masks of equal shape (0 when the union is empty)."""
    intersection = np.logical_and(mask1, mask2).sum()
    union = np.logical_or(mask1, mask2).sum()
    return intersection / union if union > 0 else 0


def _masked_bbox_region(image, seg):
    """Crop the bounding box of mask `seg` out of `image` as a normalised float tensor.

    A value is 1 where the image channel is non-zero AND the pixel lies inside the mask,
    else 0; the tensor is laid out (channels, box_h, box_w).

    Returns:
        (region, box_area) where box_area is the bounding-box pixel count, or None when
        the mask is empty.
    """
    import torch

    rows = np.any(seg, axis=1)
    cols = np.any(seg, axis=0)
    if not (rows.any() and cols.any()):
        return None
    y_min, y_max = np.flatnonzero(rows)[[0, -1]]
    x_min, x_max = np.flatnonzero(cols)[[0, -1]]
    box = (slice(y_min, y_max + 1), slice(x_min, x_max + 1))
    # Cropping before the element-wise AND gives the same crop as AND-ing the full image
    # and cropping afterwards, but only touches the box.
    crop = np.logical_and(image[box], seg[box][:, :, np.newaxis])
    region = torch.tensor(crop).permute(2, 0, 1).to(torch.float32)
    lo, hi = region.min(), region.max()
    # Guard the min-max normalisation: a constant crop used to produce 0/0 = NaN.
    # Zeros binarise the same way downstream (NaN > 0 and 0 > 0 are both False).
    region = (region - lo) / (hi - lo) if hi > lo else torch.zeros_like(region)
    box_area = (y_max - y_min + 1) * (x_max - x_min + 1)
    return region, box_area


def extract_feature(path, spatial_iou_threshold=0.8, min_mask_area=800):
    """Segment an image with SAM and score how alike its most similar pair of regions is.

    The image is read with OpenCV, converted to RGB and, if larger, downscaled (aspect ratio
    preserved) to roughly 512*512 pixels. Masks from the global `mask_generator` are kept
    only if:
      * their area is at least `min_mask_area` pixels,
      * their IoU with every already kept mask is <= `spatial_iou_threshold`,
      * the cropped region has between 300 and 256*256 values (channels * h * w), and
      * the mask fills less than 90% of its bounding box.
    Every pair of kept regions is then compared with `simple_centroid_dice`.

    Args:
        path: image file path readable by OpenCV.
        spatial_iou_threshold: mask IoU above which a mask is treated as a duplicate.
        min_mask_area: minimum SAM mask area, in pixels of the (downscaled) image.

    Returns:
        (best_score, scores, image, filtered_masks, record_coords, regions):
        best_score is the highest pairwise Dice (0.0 when fewer than two regions survive);
        scores holds the Dice of each pair (i, j), i < j, in row-major order; image is the
        possibly downscaled RGB image; filtered_masks are the kept SAM mask dicts;
        record_coords is [i, j] of the best pair (empty if no pair scored above 0);
        regions are the cropped tensors, index-aligned with filtered_masks.

    Raises:
        ValueError: if OpenCV cannot read `path`.
    """
    import itertools
    from math import prod

    import cv2

    image = cv2.imread(path)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise ValueError(f"Could not read image: {path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # Downscale large images to ~512*512 pixels to bound SAM's cost.
    factor = np.sqrt(image.shape[0] * image.shape[1] / (512 * 512))
    if factor > 1:
        image = cv2.resize(image, (int(image.shape[1] / factor), int(image.shape[0] / factor)))
    masks = mask_generator.generate(image)

    filtered_masks = []
    regions = []
    for item in masks:
        if item['area'] < min_mask_area:
            continue
        seg = item['segmentation']
        # Skip masks that nearly coincide with one already kept.
        if any(_mask_iou(seg, kept['segmentation']) > spatial_iou_threshold for kept in filtered_masks):
            continue
        extracted = _masked_bbox_region(image, seg)
        if extracted is None:
            continue
        region, box_area = extracted
        n_values = prod(region.shape)  # channels * box_h * box_w
        area_ratio = item['area'] / box_area
        # Keep mid-sized crops whose mask does not (almost) fill its bounding box.
        if 300 < n_values < 256 * 256 and area_ratio < 0.9:
            filtered_masks.append(item)
            regions.append(region)

    print(f"Filtered from {len(masks)} to {len(filtered_masks)} masks")

    # With fewer than two regions there is no pair to compare.
    if len(regions) < 2:
        return 0.0, [], image, filtered_masks, [], regions

    scores = []
    record = 0.0
    record_coords = []
    for i, j in itertools.combinations(range(len(regions)), 2):
        dice_score, _, _ = simple_centroid_dice(regions[i], regions[j])
        scores.append(dice_score)
        if dice_score > record:
            record = dice_score
            record_coords = [i, j]

    return max(scores) if scores else 0.0, scores, image, filtered_masks, record_coords, regions

In [None]:
def calculate_centroid(binary_patch):
    """Return the integer (row, col) centroid of the foreground pixels of a 2-D patch.

    Foreground means value > 0. The mean row and column indices are truncated to int.

    Returns:
        (centroid_y, centroid_x), or None when the patch has no pixel > 0.
    """
    # Use the same "> 0" criterion for both the emptiness test and the coordinates; the
    # previous `sum() == 0` test disagreed with it on non-binary input (e.g. [0.5, -0.5]
    # was reported empty, and an all-negative patch crashed on int(nan)).
    y_coords, x_coords = np.nonzero(np.asarray(binary_patch) > 0)
    if y_coords.size == 0:
        return None
    return (int(y_coords.mean()), int(x_coords.mean()))
def visualize_alignment(patch1, patch2, aligned1, aligned2, centroid1, centroid2, dice):
    """Plot two patches, their centroid-aligned canvases and a red/green overlay.

    Args:
        patch1, patch2: 2-D patches, shown in grayscale with their centroid marked.
        aligned1, aligned2: equally shaped 2-D canvases after centroid alignment.
        centroid1, centroid2: (row, col) centroids of patch1 / patch2.
        dice: Dice score shown in the overlay title.
    """
    # `plt` is not imported anywhere else in this notebook, so this function used to fail
    # with NameError whenever visualisation was requested.
    import matplotlib.pyplot as plt

    fig, axes = plt.subplots(2, 3, figsize=(15, 10))

    # Top row: original patches with their centroids (plot takes x = column, y = row).
    for ax, patch, centroid, name in (
        (axes[0, 0], patch1, centroid1, 'Patch 1'),
        (axes[0, 1], patch2, centroid2, 'Patch 2'),
    ):
        ax.imshow(patch, cmap='gray')
        ax.plot(centroid[1], centroid[0], 'r+', markersize=15, markeredgewidth=2)
        ax.set_title(f'{name}\nCentroid: {centroid}')
        ax.axis('off')

    # Bottom row: the aligned canvases.
    for ax, canvas, title in (
        (axes[1, 0], aligned1, 'Aligned Patch 1'),
        (axes[1, 1], aligned2, 'Aligned Patch 2'),
    ):
        ax.imshow(canvas, cmap='gray')
        ax.set_title(title)
        ax.axis('off')

    # imshow reads float RGB on a [0, 1] scale but uint8 RGB on [0, 255]; stacking the
    # uint8 0/1 canvases directly rendered an almost black overlay. Binarise to float 0/1.
    overlay = np.stack(
        [aligned1 > 0, aligned2 > 0, np.zeros(np.shape(aligned1), dtype=bool)], axis=2
    ).astype(float)
    axes[1, 2].imshow(overlay)
    axes[1, 2].set_title(f'Overlay (Red: Patch1, Green: Patch2)\nDice Score: {dice:.3f}')
    axes[1, 2].axis('off')

    # Unused panel, kept only for the 2x3 layout.
    axes[0, 2].axis('off')

    plt.tight_layout()
    plt.show()


def simple_centroid_dice(region1, region2, visualize=False):
    """Dice overlap of two binary patches after aligning their centroids.

    Accepted region kinds: a torch tensor (for a 3-D tensor only the first channel is used),
    a list (its first element is used) or an array. Pixels > 0 count as foreground. Both
    patches are pasted onto square zero canvases of side 3 * (largest patch dimension) so
    that their integer centroids sit on the canvas centre. The score is
    2 * |A & B| / (|A| + |B|) over the two canvases.

    Returns:
        (dice, canvas1, canvas2); (0.0, None, None) when either patch has no foreground,
        or when placement / scoring / plotting raises (the error is printed).
    """
    def as_patch(region):
        # Reduce every supported input kind to a single 2-D array.
        if not torch.is_tensor(region):
            return region[0] if isinstance(region, list) else region
        return region[0].numpy() if region.dim() == 3 else region.numpy()

    patch_a, patch_b = as_patch(region1), as_patch(region2)
    mask_a = (patch_a > 0).astype(np.uint8)
    mask_b = (patch_b > 0).astype(np.uint8)

    cen_a = calculate_centroid(mask_a)
    cen_b = calculate_centroid(mask_b)
    if cen_a is None or cen_b is None:
        return 0.0, None, None

    (ha, wa), (hb, wb) = mask_a.shape, mask_b.shape

    # A canvas three times the largest side always has room for either patch on both
    # sides of the centre.
    side = 3 * max(ha, wa, hb, wb)
    canvas_a = np.zeros((side, side), dtype=np.uint8)
    canvas_b = np.zeros((side, side), dtype=np.uint8)
    mid = side // 2

    # Top-left corners that put each centroid exactly on the canvas centre.
    top_a, left_a = mid - cen_a[0], mid - cen_a[1]
    top_b, left_b = mid - cen_b[0], mid - cen_b[1]

    try:
        canvas_a[top_a:top_a + ha, left_a:left_a + wa] = mask_a
        canvas_b[top_b:top_b + hb, left_b:left_b + wb] = mask_b

        overlap = np.logical_and(canvas_a, canvas_b).sum()
        total = canvas_a.sum() + canvas_b.sum()
        dice = 2 * overlap / total if total > 0 else 0.0
        if visualize:
            visualize_alignment(mask_a, mask_b, canvas_a, canvas_b, cen_a, cen_b, dice)
        return dice, canvas_a, canvas_b
    except Exception as e:
        print(f"Alignment error: {e}")
        return 0.0, None, None

In [None]:
import pandas as pd
# Competition sample submission, presumably with the same case_id / annotation columns as the
# submission built below — verify.
# NOTE(review): `subm_df` is not referenced by any later cell; the submission rows come from
# the test_images directory listing instead.
subm_df = pd.read_csv("/kaggle/input/recodai-luc-scientific-image-forgery-detection/sample_submission.csv")

In [None]:
#for item in subm_df['case_id']:
#    print(item)

In [None]:
import torch
import numpy as np
import os
import numpy.typing as npt
import json
from skimage.transform import resize
from PIL import Image
import cv2
from numba import types
import numba
#roots = os.listdir("/app/notebooks/recodai-luc-scientific-image-forgery-detection/train_images/forged/") 
@numba.jit(nopython=True)
def _rle_encode_jit(x: npt.NDArray, fg_val: int = 1) -> list[int]:
    """Numba-jitted run-length encoder for a single mask.

    Pixels are scanned in column-major order (the transpose is flattened). Each run of
    pixels equal to ``fg_val`` becomes a (start, length) pair with a 1-based start index.
    All pairs are flattened into one list: ``[start1, len1, start2, len2, ...]``.
    """
    # Flat (column-major) indices of the foreground pixels.
    dots = np.where(x.T.flatten() == fg_val)[0]
    run_lengths = []
    prev = -2  # sentinel so the first foreground index always opens a new run
    for b in dots:
        if b > prev + 1:
            # Gap since the previous foreground pixel: open a new run at 1-based index b + 1.
            run_lengths.extend((b + 1, 0))
        # Grow the current run by one pixel.
        run_lengths[-1] += 1
        prev = b
    return run_lengths


def rle_encode(masks: list[npt.NDArray], fg_val: int = 1) -> str:
    """Encode masks in the submission's run-length format.

    Adapted from the contrails RLE helper
    (https://www.kaggle.com/code/inversion/contrails-rle-submission).

    Args:
        masks: 2-D arrays of shape (height, width); pixels equal to ``fg_val`` are the
            mask, everything else is background.
        fg_val: value that marks mask pixels.

    Returns:
        The JSON-encoded run-length list of each mask, joined with ';'.
    """
    encoded = (json.dumps(_rle_encode_jit(mask, fg_val)) for mask in masks)
    return ';'.join(encoded)
from sklearn.naive_bayes import GaussianNB
import joblib
# Decision settings. The defaults reproduce the original decision rule: an image is flagged
# as forged when its best pairwise centroid-Dice score exceeds the threshold. The
# classifier's prediction used to be computed and then discarded. A flagged image is
# submitted with an all-ones mask at its original resolution.
FORGERY_DICE_THRESHOLD = 0.98
USE_CLASSIFIER = False  # True -> decide with clf instead; assumes label 1 means "forged" — TODO confirm
USE_FULL_IMAGE_MASK = True  # False -> submit the union of the two best-matching SAM masks
TEST_DIR = "/kaggle/input/recodai-luc-scientific-image-forgery-detection/test_images"

clf = joblib.load("/kaggle/input/sam-vit-b/gaussian_nb_forgery1")
preds = []
# Sorted so the submission row order does not depend on the directory listing order.
test_file_names = sorted(os.listdir(TEST_DIR))
files = [os.path.splitext(name)[0] for name in test_file_names]  # case ids
gt = []  # ground-truth annotations; only filled by the train-set evaluation variant below
shapes = []
counter = 0
# Train-set evaluation variant (disabled). To score locally, replace the test-image `path`
# in the loop with alternating forged / authentic training images and record ground truth:
#     if counter % 2 == 1:
#         path = f"/kaggle/input/recodai-luc-scientific-image-forgery-detection/train_images/forged/{fn}.png"
#         mask = f"/kaggle/input/recodai-luc-scientific-image-forgery-detection/train_masks/{fn}.npy"
#         gt.append(rle_encode_from_mask(np.load(mask)))  # NOTE: rle_encode_from_mask is not defined in this notebook
#     else:
#         path = f"/kaggle/input/recodai-luc-scientific-image-forgery-detection/train_images/authentic/{fn}.png"
#         gt.append("authentic")
for fn, file_name in zip(files, test_file_names):
    path = os.path.join(TEST_DIR, file_name)
    try:
        original = cv2.imread(path)
        if original is None:
            raise ValueError(f"could not read image {path}")
        orig_shape = original.shape[:2]
        best_score, pair_scores, image, masks, best_pair, regions = extract_feature(path)
        # The original code always built the classifier features [max, sorted(scores)[-2]].
        # With exactly two regions there is a single pair score, so that indexing raised
        # IndexError and every such image silently became "authentic". That is now handled
        # explicitly.
        if not pair_scores:
            # Fewer than two usable regions: there is no pair to compare.
            is_forged = False
        elif USE_CLASSIFIER:
            ranked = sorted(pair_scores)
            second_best = ranked[-2] if len(ranked) > 1 else 0.0
            is_forged = clf.predict(np.array([[ranked[-1], second_best]]))[0] == 1
        else:
            is_forged = best_score > FORGERY_DICE_THRESHOLD

        if is_forged and best_pair:
            if USE_FULL_IMAGE_MASK:
                mask = np.ones(orig_shape, dtype=float)
            else:
                first, second = best_pair
                union = np.logical_or(masks[first]['segmentation'], masks[second]['segmentation'])
                # SAM ran on the downscaled image; nearest-neighbour resize back to full size.
                mask = resize(union.astype(float), orig_shape, order=0, anti_aliasing=False, preserve_range=True)
            pred = rle_encode([mask])
        else:
            pred = "authentic"
    except Exception as e:
        # Best effort: any per-image failure falls back to "authentic" so every case id gets a row.
        print(e, fn)
        pred = "authentic"
    preds.append(pred)
    counter += 1
# TODO: check how often an image is reduced to exactly 2 masks; a couple of the successful
# detections seem to have only 2 masks left after filtering.

In [None]:
# Number of rows to emit. Normally len(files) == len(preds); taking the minimum keeps the
# frame rectangular if the prediction loop was interrupted part-way.
idx = min(len(preds), len(files))

submission = pd.DataFrame(
    {
        'case_id': files[:idx],
        'annotation': preds[:idx],
    }
)

In [None]:
# Display the submission frame for a quick visual check before writing it.
submission

In [None]:
# Write to Kaggle's output directory with an absolute path. The former relative
# "../submission.csv" only resolved there because an earlier cell had %cd-ed into
# /kaggle/working/segment-anything.
submission.to_csv("/kaggle/working/submission.csv", index=False)

In [None]:
# Read the submission back from Kaggle's output directory (absolute path, independent of
# the current working directory) to confirm it was written correctly.
pd.read_csv("/kaggle/working/submission.csv")