In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory


# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

## To learn how the model was trained, see the notebook [Image Forgery Detection 1](https://www.kaggle.com/code/abdulahadshaik/image-forgery-detection-1)

In [None]:
# ====================================================
# Recod.ai/LUC - Enhanced Inference Notebook
# ----------------------------------------------------
# Multi-GPU + AMP + Tile-Based Inference + Cleanup
# ====================================================

import os, gc, json, cv2, warnings
import numpy as np
import pandas as pd
from glob import glob
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

import torch
import torch.nn as nn

# Process-wide switches: pin OMP_NUM_THREADS to "1" and mute all warnings.
os.environ["OMP_NUM_THREADS"] = "1"
warnings.filterwarnings("ignore")

# ====================================================
# 1. Paths & Config
# ====================================================
# --- input / output locations ---
BASE_PATH = "/kaggle/input/recodai-luc-scientific-image-forgery-detection"
TEST_IMG_PATH = os.path.join(BASE_PATH, "test_images")
MODEL_PATH = "/kaggle/input/imageforgerytrainedmodels/model_final1.pth"
SUBMISSION_PATH = "submission.csv"

# --- inference hyper-parameters ---
# IMG_SIZE: nominal high-resolution inference size (not referenced by the
#           tiled inference code in this section).
# BATCH_SIZE: number of tiles sent through the model per forward pass.
# THRESHOLD: probability cut-off for the forgery mask; tune within 0.25-0.45.
# MIN_BLOB_AREA: connected components smaller than this many pixels are
#                dropped as tiny false positives.
IMG_SIZE, BATCH_SIZE = 1024, 2
THRESHOLD = 0.35
MIN_BLOB_AREA = 60

# --- compute device: CUDA when available, otherwise CPU ---
DEVICE = (
    torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)
print(f"✅ Using {torch.cuda.device_count()} GPU(s): {DEVICE}")

# ====================================================
# 2. Model Definition (same as training)
# ====================================================
class ConvBlock(nn.Module):
    """Two consecutive 3x3 convolution -> BatchNorm -> ReLU stages.

    Args:
        c1: number of input channels of the first convolution.
        c2: number of output channels of both convolutions.

    The convolutions use padding=1, so the spatial size of the input is
    preserved.
    """

    def __init__(self, c1, c2):
        super().__init__()
        stages = []
        for in_channels in (c1, c2):
            stages.extend(
                (
                    nn.Conv2d(in_channels, c2, kernel_size=3, padding=1),
                    nn.BatchNorm2d(c2),
                    nn.ReLU(inplace=True),
                )
            )
        # The attribute name `seq` and the layer order determine the
        # state-dict keys (seq.0, seq.1, seq.3, seq.4), so they must not change.
        self.seq = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv stages to `x` and return the result."""
        return self.seq(x)

class UNet(nn.Module):
    """Four-level U-Net built from ConvBlock stages.

    Args:
        in_ch: channels of the input tensor.
        out_ch: channels of the returned logit map.
        base: channel width of the first encoder level; each deeper level
            doubles it.

    The encoder halves the resolution four times and the decoder doubles it
    four times, concatenating the matching encoder output at each level.
    Input height and width should therefore be multiples of 16; otherwise the
    upsampled maps and the skip tensors differ in size and the concatenation
    fails.
    """

    def __init__(self, in_ch=3, out_ch=1, base=32):
        super().__init__()
        w1, w2, w3, w4, w5 = base, base * 2, base * 4, base * 8, base * 16

        # Encoder path (attribute names are part of the checkpoint keys).
        self.enc1 = ConvBlock(in_ch, w1)
        self.enc2 = ConvBlock(w1, w2)
        self.enc3 = ConvBlock(w2, w3)
        self.enc4 = ConvBlock(w3, w4)
        self.pool = nn.MaxPool2d(2)
        self.bottleneck = ConvBlock(w4, w5)

        # Decoder path: each level upsamples x2, then fuses with the skip.
        self.up4 = nn.ConvTranspose2d(w5, w4, 2, stride=2)
        self.dec4 = ConvBlock(w5, w4)
        self.up3 = nn.ConvTranspose2d(w4, w3, 2, stride=2)
        self.dec3 = ConvBlock(w4, w3)
        self.up2 = nn.ConvTranspose2d(w3, w2, 2, stride=2)
        self.dec2 = ConvBlock(w3, w2)
        self.up1 = nn.ConvTranspose2d(w2, w1, 2, stride=2)
        self.dec1 = ConvBlock(w2, w1)

        # 1x1 projection to the output channels (raw logits, no activation).
        self.out_conv = nn.Conv2d(w1, out_ch, 1)

    def forward(self, x):
        """Return the logit map for input batch `x`."""
        skip1 = self.enc1(x)
        skip2 = self.enc2(self.pool(skip1))
        skip3 = self.enc3(self.pool(skip2))
        skip4 = self.enc4(self.pool(skip3))
        feat = self.bottleneck(self.pool(skip4))

        decoder_levels = (
            (self.up4, self.dec4, skip4),
            (self.up3, self.dec3, skip3),
            (self.up2, self.dec2, skip2),
            (self.up1, self.dec1, skip1),
        )
        for upsample, fuse, skip in decoder_levels:
            # Upsampled features first, skip second, joined on the channel axis.
            feat = fuse(torch.cat([upsample(feat), skip], 1))
        return self.out_conv(feat)

# ====================================================
# 3. Utilities
# ====================================================
def read_image(path):
    """Load an image as an (H, W, 3) uint8 array, never raising on bad files.

    Args:
        path: location of the image file, passed to ``plt.imread``.

    Returns:
        A uint8 array with three channels. If the file cannot be read or
        decodes to an empty array, a black 512x512x3 placeholder is returned
        so that the caller can still produce a prediction row for it.
    """
    try:
        img = plt.imread(path)
    except Exception:
        # Deliberate best-effort fallback for unreadable files. Only
        # `Exception` is caught so Ctrl-C / SystemExit still stop the run.
        img = None

    if img is None or img.size == 0:
        img = np.zeros((512, 512, 3), np.uint8)

    if img.ndim == 2:
        # Grayscale: replicate the single channel three times.
        img = np.stack([img, img, img], -1)
    elif img.shape[2] == 4:
        # RGBA: drop the alpha channel.
        img = img[:, :, :3]

    if img.dtype != np.uint8:
        # Assumes non-uint8 data is float in [0, 1] (what matplotlib returns
        # for PNG files) -- TODO confirm if other formats are ever loaded.
        img = (img * 255).astype(np.uint8)
    return img

def clean_mask(mask, min_area=50, threshold=0.5):
    """Binarise `mask` and drop connected components smaller than `min_area`.

    Args:
        mask: 2-D array of scores or an already binary mask.
        min_area: minimum component area, in pixels, that is kept.
        threshold: values strictly greater than this count as foreground.
            Defaults to 0.5, the previously hard-coded cut-off.

    Returns:
        uint8 array of the same shape as `mask`, 1 on kept components and
        0 elsewhere.
    """
    binary = (mask > threshold).astype(np.uint8)
    _, labels, stats, _ = cv2.connectedComponentsWithStats(binary, connectivity=8)

    # Per-label keep flag; label 0 is the background and is never kept.
    keep = stats[:, cv2.CC_STAT_AREA] >= min_area
    keep[0] = False

    # One lookup through the label image instead of a full-image comparison
    # per component.
    return keep[labels].astype(np.uint8)

def rle_encode(mask):
    """Run-length encode a binary mask as a JSON list string.

    Pixels are scanned in column-major order (the mask is transposed before
    flattening). The result is a JSON array of alternating
    ``start, length`` values, where ``start`` is the 1-based flat index of
    the first pixel of a run of ones.

    Args:
        mask: 2-D array whose foreground pixels equal 1.

    Returns:
        JSON string such as ``"[1, 2, 4, 1]"``; ``"[]"`` when the mask has no
        foreground pixel.
    """
    pixels = np.asarray(mask).T.flatten()
    ones = np.flatnonzero(pixels == 1)
    if ones.size == 0:
        return json.dumps([])

    # A new run begins wherever consecutive foreground indices are not adjacent.
    breaks = np.flatnonzero(np.diff(ones) > 1) + 1
    first = ones[np.concatenate(([0], breaks))]
    last = ones[np.concatenate((breaks - 1, [ones.size - 1]))]

    # .tolist() yields plain Python ints; json.dumps cannot serialise the
    # numpy integer scalars that index arithmetic produces.
    pairs = np.column_stack((first + 1, last - first + 1)).ravel().tolist()
    return json.dumps(pairs)

def mask_to_rle(mask, threshold=0.5):
    """Convert a score map into a submission annotation string.

    Args:
        mask: array of scores; values strictly above `threshold` are
            foreground.
        threshold: binarisation cut-off.

    Returns:
        The literal ``"authentic"`` when no pixel exceeds the threshold,
        otherwise the output of ``rle_encode`` for the binarised mask.
    """
    binary = (mask > threshold).astype(np.uint8)
    return rle_encode(binary) if binary.any() else "authentic"

# ====================================================
# 4. Load Model
# ====================================================
model = UNet(base=32).to(DEVICE)

# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
state = torch.load(MODEL_PATH, map_location=DEVICE)

# Load into the bare model *before* any DataParallel wrapping. A checkpoint
# saved from a DataParallel model prefixes every key with "module."; stripping
# it makes loading work regardless of how many GPUs were used when saving or
# are available now.
_dp_prefix = "module."
state = {
    (key[len(_dp_prefix):] if key.startswith(_dp_prefix) else key): value
    for key, value in state.items()
}
model.load_state_dict(state)

# Spread batches over all visible GPUs only after the weights are in place.
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
model.eval()
print("✅ Model loaded successfully")

# ====================================================
# 5. Tile-based Inference Function
# ====================================================
def _tile_starts(length, tile_size, stride):
    """Return tile start offsets along one axis so the tiles cover [0, length)."""
    if length <= 0:
        return []
    if length <= tile_size:
        return [0]
    starts = list(range(0, length - tile_size + 1, stride))
    if starts[-1] + tile_size < length:
        # Anchor a final full-size tile to the far edge instead of emitting a
        # thin sliver that would be mostly reflected padding.
        starts.append(length - tile_size)
    return starts


def predict_tiled(image, model, tile_size=512, overlap=64):
    """Predict a per-pixel probability map by tiling a large image.

    The image is cut into `tile_size` x `tile_size` tiles whose starts are
    `tile_size - overlap` apart. Tiles are run through `model` in groups of
    the module-level ``BATCH_SIZE`` on ``DEVICE``, passed through a sigmoid,
    and averaged where they overlap.

    Args:
        image: (H, W, 3) uint8 array.
        model: network mapping a (B, 3, tile, tile) float batch in [0, 1] to
            single-channel logits of the same spatial size.
        tile_size: side length of each square tile.
        overlap: number of pixels neighbouring tiles share.

    Returns:
        float32 array of shape (H, W) with values in [0, 1].

    Raises:
        ValueError: if `overlap` is not smaller than `tile_size`.
    """
    h, w = image.shape[:2]
    stride = tile_size - overlap
    if stride <= 0:
        raise ValueError("overlap must be smaller than tile_size")

    full_pred = np.zeros((h, w), dtype=np.float32)
    weight_map = np.zeros((h, w), dtype=np.float32)

    tiles = []
    coords = []
    for y in _tile_starts(h, tile_size, stride):
        for x in _tile_starts(w, tile_size, stride):
            y2 = min(y + tile_size, h)
            x2 = min(x + tile_size, w)
            tile = image[y:y2, x:x2]
            # Padding is only needed when the image is smaller than one tile.
            pad_y, pad_x = tile_size - tile.shape[0], tile_size - tile.shape[1]
            tile = cv2.copyMakeBorder(tile, 0, pad_y, 0, pad_x, cv2.BORDER_REFLECT)
            tiles.append(tile)
            coords.append((y, y2, x, x2))

    preds = []
    # Mixed precision is only requested when actually running on CUDA.
    use_amp = DEVICE.type == "cuda"
    with torch.no_grad(), torch.amp.autocast("cuda", enabled=use_amp):
        for i in range(0, len(tiles), BATCH_SIZE):
            batch = tiles[i:i + BATCH_SIZE]
            # HWC uint8 -> CHW float32 scaled to [0, 1].
            batch_t = [
                torch.tensor(t.transpose(2, 0, 1), dtype=torch.float32) / 255.
                for t in batch
            ]
            batch_t = torch.stack(batch_t).to(DEVICE)
            probs = torch.sigmoid(model(batch_t)).float().cpu().numpy()
            preds.extend(probs)

    # Stitch: accumulate probabilities and a hit count, then average.
    for (y, y2, x, x2), p in zip(coords, preds):
        p = p.squeeze()
        p = p[:(y2 - y), :(x2 - x)]  # drop the padded margin, if any
        full_pred[y:y2, x:x2] += p
        weight_map[y:y2, x:x2] += 1.0
    weight_map[weight_map == 0] = 1  # guard against division by zero
    full_pred /= weight_map
    return full_pred

# ====================================================
# 6. Run Inference on Test Set
# ====================================================
test_images = sorted(glob(os.path.join(TEST_IMG_PATH, "*.png")))
print(f"Found {len(test_images)} test images.")

predictions = []

for path in tqdm(test_images, desc="Predicting"):
    img = read_image(path)
    pred = predict_tiled(img, model, tile_size=512, overlap=64)
    # Binarise at THRESHOLD *before* the cleanup step. clean_mask applies its
    # own 0.5 cut-off, so passing raw probabilities made THRESHOLD a no-op.
    binary = (pred > THRESHOLD).astype(np.uint8)
    pred = clean_mask(binary, min_area=MIN_BLOB_AREA)
    # `pred` is now a 0/1 mask, so the default cut-off is all that is needed.
    rle = mask_to_rle(pred)
    case_id = os.path.splitext(os.path.basename(path))[0]
    predictions.append({"case_id": case_id, "annotation": rle})
    gc.collect()

# Explicit columns keep the CSV header intact even if no image was found.
df_sub = pd.DataFrame(predictions, columns=["case_id", "annotation"])
df_sub.to_csv(SUBMISSION_PATH, index=False)
print(f"\n✅ Submission file saved: {SUBMISSION_PATH}")
display(df_sub.head())

# ====================================================
# 7. Optional Visualization
# ====================================================
# Show the first few test images next to their predicted forgery masks.
for path in test_images[:3]:
    img = read_image(path)
    pred = predict_tiled(img, model)
    # Threshold first so THRESHOLD actually drives what is displayed;
    # clean_mask then only removes small components from the binary mask.
    pred = clean_mask((pred > THRESHOLD).astype(np.uint8), min_area=MIN_BLOB_AREA)
    plt.figure(figsize=(10, 4))
    plt.subplot(1, 2, 1)
    plt.imshow(img)
    plt.title(os.path.basename(path))
    plt.subplot(1, 2, 2)
    plt.imshow(pred > 0)
    plt.title("Predicted Forgery Mask")
    plt.show()
