# Hybrid DINOv2-UNet Model for Image Forgery Detection
## Thanks to Hossam Hamouda and Pankaj Gupta for the ideas and the dataset. (links at the end of the notebook)
## If you have any ideas or suggestions for improving this solution, feel free to share them in the comments.
### Model Architecture

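Our hybrid model combines a frozen DINOv2 encoder with a UNet-style decoder for binary segmentation of forged regions in images. When the DINOv2 weights are available, inference averages its output with a plain UNet trained on the same data (ensemble weights 0.7 for the UNet and 0.3 for the hybrid model).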

**Encoder (DINOv2-based):**
Given an input image $I \in \mathbb{R}^{H \times W \times 3}$, we first preprocess it to size $S \times S$ (where $S = 720$):

$$I' = \text{Resize}(I, (S, S))$$

The DINOv2 encoder extracts features:

$$\mathbf{F}_{\text{dino}} = \text{DINOv2}(I') \in \mathbb{R}^{B \times N \times 768}$$

where $B$ is the batch size, $N$ is the number of tokens (one CLS token followed by $N-1$ patch tokens), and 768 is the DINOv2 base dimension. We drop the CLS token and reshape the patch tokens to spatial feature maps:

$$\mathbf{F}_{\text{spatial}} = \text{Reshape}(\mathbf{F}_{\text{dino}}[:, 1:, :]) \in \mathbb{R}^{B \times 768 \times \sqrt{N-1} \times \sqrt{N-1}}$$
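
The reshape step can be sketched with NumPy on toy data. This is an illustration rather than the notebook's code: the helper name `tokens_to_map` is hypothetical, and it assumes the patch tokens are stored row by row over a square grid.

```python
import math
import numpy as np

def tokens_to_map(tokens: np.ndarray) -> np.ndarray:
    """Turn (B, 1 + s*s, C) tokens into a (B, C, s, s) feature map."""
    b, n, c = tokens.shape
    s = math.isqrt(n - 1)
    if s * s != n - 1:
        raise ValueError(f"{n - 1} patch tokens do not form a square grid")
    patches = tokens[:, 1:, :]             # drop the CLS token at index 0
    patches = patches.transpose(0, 2, 1)   # (B, C, s*s): channels first
    return patches.reshape(b, c, s, s)     # split the token axis into rows and columns

# Toy check: 2 images, a 4x4 patch grid, 8 channels.
tokens = np.arange(2 * 17 * 8, dtype=np.float32).reshape(2, 17, 8)
fmap = tokens_to_map(tokens)
print(fmap.shape)  # (2, 8, 4, 4)
```

The channel axis has to be moved in front of the token axis before reshaping. Reshaping `(B, N-1, C)` directly to `(B, C, s, s)` would mix channels and positions.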

A projection layer ($1 \times 1$ convolution, batch normalization, then ReLU) maps to decoder dimensions, and the result is bilinearly resized to $H' = W' = S/16$:

$$\mathbf{E}_4 = \text{Proj}(\mathbf{F}_{\text{spatial}}) = \text{ReLU}(\text{BN}(\text{Conv}_{1 \times 1}(\mathbf{F}_{\text{spatial}}))) \in \mathbb{R}^{B \times 512 \times H' \times W'}$$

**Decoder (UNet-style):**

The bottleneck processes the encoder output at the same resolution (the DINOv2 branch applies no pooling before it):

$$\mathbf{B} = \text{Bottleneck}(\text{Dropout}(\mathbf{E}_4)) \in \mathbb{R}^{B \times 1024 \times H' \times W'}$$

The decoder uses transposed convolutions with skip connections:

$$\mathbf{D}_4 = \text{Dec}_4(\text{Concat}(\text{Up}_4(\mathbf{B}), \mathbf{E}_4))$$
$$\mathbf{D}_3 = \text{Dec}_3(\text{Concat}(\text{Up}_3(\mathbf{D}_4), \mathbf{E}_3))$$
$$\mathbf{D}_2 = \text{Dec}_2(\text{Concat}(\text{Up}_2(\mathbf{D}_3), \mathbf{E}_2))$$
$$\mathbf{D}_1 = \text{Dec}_1(\text{Concat}(\text{Up}_1(\mathbf{D}_2), \mathbf{E}_1))$$

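where $\text{Up}_i$ denotes transposed-convolution upsampling and $\text{Dec}_i$ are decoder blocks with SE attention. Because DINOv2 yields a single feature scale, the skip features $\mathbf{E}_3$, $\mathbf{E}_2$ and $\mathbf{E}_1$ are not separate encoder stages: each is $\mathbf{E}_4$ bilinearly resized to $S/4$, $S/2$ and $S$ and projected by a $1 \times 1$ convolution to 256, 128 and 64 channels. $\mathbf{E}_4$ itself is resized to match $\text{Up}_4(\mathbf{B})$ before concatenation.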

**Output:**
The final segmentation probability map:

$$P = \sigma(\text{Conv}_{1 \times 1}(\mathbf{D}_1)) \in [0, 1]^{B \times 1 \times S \times S}$$

where $\sigma$ is the sigmoid function.
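
To see why the output comes back at $S \times S$, the spatial sizes can be traced with plain arithmetic. This is a sketch that follows the `DINOv2UNet.forward` code in this notebook, where $\mathbf{E}_4$ is resized to $S/16$ and the bottleneck is applied without pooling. The function name `decoder_sizes` is hypothetical.

```python
def decoder_sizes(s: int = 720, stride: int = 16, n_up: int = 4) -> list[int]:
    """Side length at the bottleneck and after each 2x transposed convolution."""
    side = s // stride   # E_4 is resized to S/16 in the forward pass
    sizes = [side]       # the bottleneck keeps this resolution (3x3 convs, padding 1)
    for _ in range(n_up):
        side *= 2        # ConvTranspose2d(kernel_size=2, stride=2) doubles the side
        sizes.append(side)
    return sizes

print(decoder_sizes())  # [45, 90, 180, 360, 720]
```

Four doublings from $720 / 16 = 45$ give exactly 720, so no cropping or padding is needed at the output.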

### Loss Function

We use a combination of Binary Cross-Entropy and Dice loss:

$$\mathcal{L}_{\text{BCE}} = -\frac{1}{N}\sum_{i=1}^{N} [y_i \log(p_i) + (1-y_i) \log(1-p_i)]$$

$$\mathcal{L}_{\text{Dice}} = 1 - \frac{2\sum_{i} p_i y_i + \epsilon}{\sum_{i} p_i + \sum_{i} y_i + \epsilon}$$

$$\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{BCE}} + \mathcal{L}_{\text{Dice}}$$

where $N$ is the number of pixels (not the token count used above), $y_i$ are ground truth labels, $p_i$ are predicted probabilities, and $\epsilon = 10^{-6}$ is a smoothing term.
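
A minimal NumPy sketch of the combined loss, useful for checking values by hand. The notebook itself uses PyTorch for this. The function name `bce_dice` is hypothetical, and the probability clipping is an addition made here to keep the logarithm finite.

```python
import numpy as np

def bce_dice(p: np.ndarray, y: np.ndarray, eps: float = 1e-6) -> float:
    """Binary cross-entropy plus Dice loss over all pixels."""
    p = p.ravel().astype(np.float64)
    y = y.ravel().astype(np.float64)
    p_safe = np.clip(p, 1e-7, 1 - 1e-7)  # assumption: clip so log() stays finite
    bce = -np.mean(y * np.log(p_safe) + (1 - y) * np.log(1 - p_safe))
    dice = 1 - (2 * np.sum(p * y) + eps) / (np.sum(p) + np.sum(y) + eps)
    return float(bce + dice)

y = np.array([1, 1, 0, 0])
good = np.array([0.9, 0.8, 0.1, 0.2])  # mostly correct prediction
bad = np.array([0.2, 0.1, 0.9, 0.8])   # mostly wrong prediction
print(bce_dice(good, y) < bce_dice(bad, y))  # True
```

The BCE term penalizes every pixel independently, while the Dice term measures overlap and is less sensitive to the large share of background pixels.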

### Post-processing Pipeline

**Test Time Augmentation (TTA):**
We average predictions from original and flipped images:

$$P_{\text{final}} = \frac{1}{3}\left[P_{\text{orig}} + \text{Flip}_H(P_{\text{hflip}}) + \text{Flip}_V(P_{\text{vflip}})\right]$$
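
The averaging can be sketched for a single 2D map as follows. Here `predict` stands for any function that maps an image to a probability map, and `toy_predict` is a hypothetical stand-in chosen so that the flips make a visible difference.

```python
import numpy as np

def tta_average(predict, img: np.ndarray) -> np.ndarray:
    """Average the plain prediction with two flip-corrected predictions (img is H x W)."""
    p_orig = predict(img)
    p_h = np.fliplr(predict(np.fliplr(img)))  # flip input, predict, flip back
    p_v = np.flipud(predict(np.flipud(img)))
    return (p_orig + p_h + p_v) / 3.0

def toy_predict(img: np.ndarray) -> np.ndarray:
    # Hypothetical model whose response grows from left to right.
    ramp = np.linspace(0.0, 1.0, img.shape[1])
    return img * ramp

img = np.ones((2, 3))
out = tta_average(toy_predict, img)
```

Each flipped prediction has to be flipped back before averaging, otherwise the maps are not aligned pixel by pixel. For a model that is already flip-equivariant the average equals the plain prediction.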

**Adaptive Mask Enhancement:**
We compute gradient magnitude using Sobel operators:

$$G_x = \text{Sobel}_x(P_{\text{final}}), \quad G_y = \text{Sobel}_y(P_{\text{final}})$$
$$|\nabla P| = \sqrt{G_x^2 + G_y^2}, \quad |\nabla P|_{\text{norm}} = \frac{|\nabla P|}{\max(|\nabla P|) + \epsilon}$$

The enhanced probability map combines original probabilities with gradient information:

$$P_{\text{enhanced}} = (1-\alpha) \cdot P_{\text{final}} + \alpha \cdot |\nabla P|_{\text{norm}}$$

where $\alpha = 0.35$ is the gradient weight. The result is smoothed with a $3 \times 3$ Gaussian kernel (OpenCV derives the standard deviation from the kernel size):

$$P_{\text{blur}} = \text{GaussianBlur}(P_{\text{enhanced}}, k=3)$$
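
The gradient blend can be reproduced on a toy step edge with NumPy alone. This is a sketch: the notebook calls `cv2.Sobel`, whereas the hypothetical helpers `sobel_xy` and `enhance` below re-implement the 3x3 Sobel kernels with reflected borders and leave out the Gaussian blur.

```python
import numpy as np

def sobel_xy(p: np.ndarray):
    """3x3 Sobel derivatives with reflected borders (written out with slices)."""
    q = np.pad(p, 1, mode="reflect")
    gx = (q[:-2, 2:] + 2 * q[1:-1, 2:] + q[2:, 2:]) - (q[:-2, :-2] + 2 * q[1:-1, :-2] + q[2:, :-2])
    gy = (q[2:, :-2] + 2 * q[2:, 1:-1] + q[2:, 2:]) - (q[:-2, :-2] + 2 * q[:-2, 1:-1] + q[:-2, 2:])
    return gx, gy

def enhance(p: np.ndarray, alpha: float = 0.35, eps: float = 1e-6) -> np.ndarray:
    gx, gy = sobel_xy(p)
    grad = np.sqrt(gx ** 2 + gy ** 2)
    grad_norm = grad / (grad.max() + eps)       # scale gradients to [0, 1)
    return (1 - alpha) * p + alpha * grad_norm  # blend probability and edge strength

# Left half background (0), right half forged (1).
p = np.zeros((6, 6))
p[:, 3:] = 1.0
e = enhance(p)
print(np.round(e[0], 2))
```

On this toy input, the blend lowers a flat region of probability 1 to $1 - \alpha = 0.65$ and raises the background pixel next to the edge to about $\alpha = 0.35$. The enhanced map is therefore no longer a probability in the usual sense, which is one reason the threshold that follows is relative rather than fixed.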

**Adaptive Thresholding:**
The threshold is computed adaptively:

$$\tau = \mu(P_{\text{blur}}) + 0.3 \cdot \sigma(P_{\text{blur}})$$

where $\mu$ and $\sigma$ are the mean and standard deviation of $P_{\text{blur}}$ (here $\sigma$ is not the sigmoid). The binary mask:

$$M = \mathbb{1}[P_{\text{blur}} > \tau]$$
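
A small sketch of the thresholding rule on a toy map. The helper name `adaptive_mask` is hypothetical, and the factor 0.3 is exposed as the parameter `k`.

```python
import numpy as np

def adaptive_mask(p_blur: np.ndarray, k: float = 0.3):
    """Threshold at mean + k * std of the map itself."""
    tau = float(p_blur.mean() + k * p_blur.std())
    return (p_blur > tau).astype(np.uint8), tau

# Toy map: low background with one 3x3 high-probability patch.
p = np.full((10, 10), 0.1)
p[2:5, 2:5] = 0.9
mask, tau = adaptive_mask(p)
print(int(mask.sum()))  # 9
```

Because $\tau$ is relative to the map's own statistics, any map with some variation has pixels above the threshold, including maps for authentic images. The area and mean-probability checks in the filtering step are what reject those cases.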

**Morphological Operations:**
We apply closing and opening to clean the mask:

$$M' = \text{Open}(\text{Close}(M, k_1=5), k_2=3)$$
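
The effect of closing followed by opening can be shown on a toy mask. The notebook uses `cv2.morphologyEx` for this. The sketch below re-implements square-kernel dilation and erosion in NumPy only to make the behaviour visible, and all helper names are hypothetical.

```python
import numpy as np

def dilate(mask: np.ndarray, k: int) -> np.ndarray:
    pad = k // 2
    padded = np.pad(mask, pad, mode="constant", constant_values=0)
    out = np.zeros_like(mask)
    for dy in range(k):
        for dx in range(k):
            out |= padded[dy:dy + mask.shape[0], dx:dx + mask.shape[1]]
    return out

def erode(mask: np.ndarray, k: int) -> np.ndarray:
    pad = k // 2
    # Border counts as foreground so it does not eat into the mask.
    padded = np.pad(mask, pad, mode="constant", constant_values=1)
    out = np.ones_like(mask)
    for dy in range(k):
        for dx in range(k):
            out &= padded[dy:dy + mask.shape[0], dx:dx + mask.shape[1]]
    return out

def close_then_open(mask: np.ndarray, k1: int = 5, k2: int = 3) -> np.ndarray:
    closed = erode(dilate(mask, k1), k1)   # closing fills small holes
    return dilate(erode(closed, k2), k2)   # opening removes small specks

mask = np.zeros((20, 20), dtype=np.uint8)
mask[10:15, 4:9] = 1   # 5x5 blob ...
mask[12, 6] = 0        # ... with a one-pixel hole
mask[3, 15] = 1        # isolated speck
cleaned = close_then_open(mask)
print(int(cleaned[12, 6]), int(cleaned[3, 15]))  # 1 0
```

Closing comes first so that holes inside a detected region are filled before opening removes the specks. In the reverse order, opening could erode a thin or perforated region before it is consolidated.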

**Filtering:**
Final decision based on area and mean probability:

$$\text{Label} = \begin{cases}
\text{"forged"} & \text{if } \text{Area}(M') \geq 100 \text{ and } \bar{P}_{\text{inside}} \geq 0.15 \\
\text{"authentic"} & \text{otherwise}
\end{cases}$$

where $\bar{P}_{\text{inside}} = \frac{1}{|\Omega|}\sum_{(i,j) \in \Omega} P_{\text{final}}(i,j)$ and $\Omega = \{(i,j) : M'(i,j) = 1\}$.
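
The decision rule can be sketched as a small function. The thresholds are passed as arguments, the helper name `decide` is hypothetical, and the values in the usage lines are the `AREA_THR = 100` and `MEAN_THR = 0.15` constants from the inference cell.

```python
import numpy as np

def decide(mask: np.ndarray, prob: np.ndarray, area_thr: int, mean_thr: float):
    """Return (label, area, mean probability inside the mask)."""
    area = int(mask.sum())
    mean_inside = float(prob[mask == 1].mean()) if area > 0 else 0.0
    is_forged = area >= area_thr and mean_inside >= mean_thr
    return ("forged" if is_forged else "authentic"), area, mean_inside

mask = np.zeros((20, 20), dtype=np.uint8)
mask[4:16, 4:16] = 1                      # 144-pixel region
prob = np.where(mask == 1, 0.6, 0.05)     # confident inside, low outside
print(decide(mask, prob, area_thr=100, mean_thr=0.15)[0])  # forged
```

Both conditions must hold. A large region with low average probability and a small confident speck are both treated as authentic.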


In [None]:
import warnings
warnings.filterwarnings('ignore')

In [None]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import cv2
from pathlib import Path
import os
import json
import math
from tqdm import tqdm
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import seaborn as sns


# DINOv2
try:
    from transformers import AutoImageProcessor, AutoModel
    HAS_TRANSFORMERS = True
except ImportError as e:
    HAS_TRANSFORMERS = False
    print(e)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {device}')

plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")

In [None]:
def rle_encode(mask: np.ndarray, fg_val: int = 1) -> str:
    pixels = mask.T.flatten()
    dots = np.where(pixels == fg_val)[0]
    if len(dots) == 0:
        return "authentic"
    run_lengths = []
    prev = -2
    for b in dots:
        if b > prev + 1:
            run_lengths.extend((b + 1, 0))
        run_lengths[-1] += 1
        prev = b
    return json.dumps([int(x) for x in run_lengths])

def load_mask(path):
    if os.path.exists(path):
        mask = np.load(path)
        if mask.ndim == 3:
            mask = mask.max(axis=0)
        return (mask > 0).astype(np.uint8)
    return None

def combine_masks(mask_paths):
    masks = []
    for path in mask_paths:
        mask = load_mask(path)
        if mask is not None:
            masks.append(mask)
    if masks:
        combined = np.zeros_like(masks[0])
        for mask in masks:
            combined = np.maximum(combined, mask)
        return combined
    return None

In [None]:
class ImgDataset(Dataset):
    def __init__(self, img_paths, mask_paths=None, train=True, sz=720):
        self.img_paths = img_paths
        self.mask_paths = mask_paths
        self.train = train
        self.sz = sz
        
        if train:
            self.tfms = transforms.Compose([
                transforms.RandomHorizontalFlip(0.5),
                transforms.RandomVerticalFlip(0.5),
                transforms.RandomRotation(90),
                transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1)
            ])
        
    def __len__(self):
        return len(self.img_paths)
    
    def __getitem__(self, idx):
        img_path = self.img_paths[idx]
        
        try:
            img = Image.open(img_path).convert('RGB')
        except Exception:
            # Unreadable file: fall back to a neutral gray image of the target size.
            img = Image.new('RGB', (self.sz, self.sz), (128, 128, 128))
        
        img = img.resize((self.sz, self.sz), Image.LANCZOS)
        img = np.array(img).astype(np.float32) / 255.0
        
        if self.mask_paths:
            mask_paths = self.mask_paths[idx]
            mask = combine_masks(mask_paths) if isinstance(mask_paths, list) else load_mask(mask_paths)
            if mask is None:
                mask = np.zeros((img.shape[0], img.shape[1]), dtype=np.uint8)
            else:
                mask = cv2.resize(mask, (self.sz, self.sz), interpolation=cv2.INTER_NEAREST)
            
            img_tensor = torch.from_numpy(img).permute(2, 0, 1)
            mask_tensor = torch.from_numpy(mask).unsqueeze(0).float()
            
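            if self.train:
                seed = np.random.randint(2147483647)
                torch.manual_seed(seed)
                img_tensor = self.tfms(img_tensor)
                # Replay only the geometric transforms on the mask with the same seed.
                # ColorJitter would rescale the 0/1 labels into non-binary values.
                torch.manual_seed(seed)
                for t in self.tfms.transforms:
                    if not isinstance(t, transforms.ColorJitter):
                        mask_tensor = t(mask_tensor)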
            
            return img_tensor, mask_tensor
        else:
            img_tensor = torch.from_numpy(img).permute(2, 0, 1)
            return img_tensor, str(img_path)

In [None]:
class SEBlock(nn.Module):
    def __init__(self, ch, r=16):
        super().__init__()
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(ch, ch // r, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch // r, ch, 1),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        return x * self.se(x)

class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch, use_se=True):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )
        self.se = SEBlock(out_ch) if use_se else nn.Identity()
    
    def forward(self, x):
        return self.se(self.conv(x))

# UNet (fallback)
class UNet(nn.Module):
    def __init__(self, n_channels=3, n_classes=1):
        super().__init__()
        self.enc1 = ConvBlock(n_channels, 64)
        self.enc2 = ConvBlock(64, 128)
        self.enc3 = ConvBlock(128, 256)
        self.enc4 = ConvBlock(256, 512)
        
        self.pool = nn.MaxPool2d(2)
        self.drop = nn.Dropout2d(0.1)
        
        self.bottleneck = ConvBlock(512, 1024)
        
        self.up4 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.dec4 = ConvBlock(1024, 512)
        
        self.up3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.dec3 = ConvBlock(512, 256)
        
        self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.dec2 = ConvBlock(256, 128)
        
        self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec1 = ConvBlock(128, 64)
        
        self.final = nn.Conv2d(64, n_classes, 1)
    
    def crop(self, x, target):
        _, _, h, w = target.size()
        _, _, xh, xw = x.size()
        diff_h = (xh - h) // 2
        diff_w = (xw - w) // 2
        return x[:, :, diff_h:diff_h+h, diff_w:diff_w+w]
    
    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))
        
        b = self.bottleneck(self.drop(self.pool(e4)))
        
        d4 = self.up4(b)
        d4 = self.crop(d4, e4)
        d4 = torch.cat([d4, e4], dim=1)
        d4 = self.dec4(d4)
        
        d3 = self.up3(d4)
        d3 = self.crop(d3, e3)
        d3 = torch.cat([d3, e3], dim=1)
        d3 = self.dec3(d3)
        
        d2 = self.up2(d3)
        d2 = self.crop(d2, e2)
        d2 = torch.cat([d2, e2], dim=1)
        d2 = self.dec2(d2)
        
        d1 = self.up1(d2)
        d1 = self.crop(d1, e1)
        d1 = torch.cat([d1, e1], dim=1)
        d1 = self.dec1(d1)
        
        return torch.sigmoid(self.final(d1))

# Hybrid model (DINO's encoder and UNet decoder)
class DINOv2UNet(nn.Module):
    def __init__(self, dino_path=None, n_classes=1, img_size=720, use_dino=True):
        super().__init__()
        self.use_dino = use_dino and HAS_TRANSFORMERS
        self.img_size = img_size
        
        if self.use_dino and dino_path:
            print(f"Loading DINOv2 from {dino_path}")
            try:
                self.processor = AutoImageProcessor.from_pretrained(
                    dino_path, local_files_only=True, use_fast=False
                )
                self.dino_encoder = AutoModel.from_pretrained(
                    dino_path, local_files_only=True
                ).to('cpu')
                for p in self.dino_encoder.parameters():
                    p.requires_grad = False
                self.dino_encoder.eval()
                dino_dim = 768
                print("DINOv2 loaded on CPU")
            except Exception as e:
                print(e)
                self.use_dino = False
                dino_dim = 0
        else:
            self.use_dino = False
            dino_dim = 0

        self.drop = nn.Dropout2d(0.1)

        if self.use_dino:
            self.dino_proj = nn.Sequential(
                nn.Conv2d(dino_dim, 512, 1),
                nn.BatchNorm2d(512),
                nn.ReLU(inplace=True)
            )
            self.dino_proj_e3 = nn.Sequential(
                nn.Conv2d(512, 256, 1),
                nn.BatchNorm2d(256),
                nn.ReLU(inplace=True)
            )
            self.dino_proj_e2 = nn.Sequential(
                nn.Conv2d(512, 128, 1),
                nn.BatchNorm2d(128),
                nn.ReLU(inplace=True)
            )
            self.dino_proj_e1 = nn.Sequential(
                nn.Conv2d(512, 64, 1),
                nn.BatchNorm2d(64),
                nn.ReLU(inplace=True)
            )
            bottleneck_in = 512
        else:
            self.enc1 = ConvBlock(3, 64)
            self.enc2 = ConvBlock(64, 128)
            self.enc3 = ConvBlock(128, 256)
            self.enc4 = ConvBlock(256, 512)
            self.pool = nn.MaxPool2d(2)
            bottleneck_in = 512
        
        self.bottleneck = ConvBlock(bottleneck_in, 1024)
        
        self.up4 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.dec4 = ConvBlock(1024, 512)
        
        self.up3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.dec3 = ConvBlock(512, 256)
        
        self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.dec2 = ConvBlock(256, 128)
        
        self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec1 = ConvBlock(128, 64)
        
        self.final = nn.Conv2d(64, n_classes, 1)
    
    def forward_features_dino(self, x):
        imgs = (x * 255).clamp(0, 255).byte().permute(0, 2, 3, 1).cpu().numpy()
        inputs = self.processor(images=list(imgs), return_tensors="pt").to('cpu')
        
        # The encoder is frozen, so gradients are never needed here.
        with torch.no_grad():
            feats = self.dino_encoder(**inputs).last_hidden_state
        
        B, N, C = feats.shape
        fmap = feats[:, 1:, :].permute(0, 2, 1)
        s = int(math.sqrt(N - 1))
        fmap = fmap.reshape(B, C, s, s)
        
        return fmap.to(x.device)
    
    def forward(self, x):
        if self.use_dino:
            dino_feats = self.forward_features_dino(x)
            e4 = self.dino_proj(dino_feats)
            target_size = x.shape[2] // 16
            e4 = F.interpolate(e4, size=(target_size, target_size), 
                             mode='bilinear', align_corners=False)
            e3_interp = F.interpolate(e4, size=(x.shape[2]//4, x.shape[3]//4), 
                                     mode='bilinear', align_corners=False)
            e3 = self.dino_proj_e3(e3_interp)
            
            e2_interp = F.interpolate(e4, size=(x.shape[2]//2, x.shape[3]//2), 
                                     mode='bilinear', align_corners=False)
            e2 = self.dino_proj_e2(e2_interp)
            
            e1_interp = F.interpolate(e4, size=(x.shape[2], x.shape[3]), 
                                     mode='bilinear', align_corners=False)
            e1 = self.dino_proj_e1(e1_interp)
            
            b = self.bottleneck(self.drop(e4))
        else:
            e1 = self.enc1(x)
            e2 = self.enc2(self.pool(e1))
            e3 = self.enc3(self.pool(e2))
            e4 = self.enc4(self.pool(e3))
            b = self.bottleneck(self.drop(self.pool(e4)))
        
        d4 = self.up4(b)
        if d4.shape[2:] != e4.shape[2:]:
            e4 = F.interpolate(e4, size=d4.shape[2:], mode='bilinear', align_corners=False)
        d4 = torch.cat([d4, e4], dim=1)
        d4 = self.dec4(d4)
        
        d3 = self.up3(d4)
        if d3.shape[2:] != e3.shape[2:]:
            e3 = F.interpolate(e3, size=d3.shape[2:], mode='bilinear', align_corners=False)
        d3 = torch.cat([d3, e3], dim=1)
        d3 = self.dec3(d3)
        
        d2 = self.up2(d3)
        if d2.shape[2:] != e2.shape[2:]:
            e2 = F.interpolate(e2, size=d2.shape[2:], mode='bilinear', align_corners=False)
        d2 = torch.cat([d2, e2], dim=1)
        d2 = self.dec2(d2)
        
        d1 = self.up1(d2)
        if d1.shape[2:] != e1.shape[2:]:
            e1 = F.interpolate(e1, size=d1.shape[2:], mode='bilinear', align_corners=False)
        d1 = torch.cat([d1, e1], dim=1)
        d1 = self.dec1(d1)
        
        return torch.sigmoid(self.final(d1))

In [None]:
def dice_loss(pred, target, smooth=1e-6):
    pred_flat = pred.view(-1)
    target_flat = target.view(-1)
    intersection = (pred_flat * target_flat).sum()
    return 1 - (2. * intersection + smooth) / (pred_flat.sum() + target_flat.sum() + smooth)

def bce_dice_loss(pred, target):
    bce = nn.functional.binary_cross_entropy(pred, target)
    dice = dice_loss(pred, target)
    return bce + dice

def train_epoch(model, loader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    for imgs, masks in tqdm(loader, desc='Training'):
        imgs = imgs.to(device)
        masks = masks.to(device)
        
        optimizer.zero_grad()
        preds = model(imgs)
        loss = criterion(preds, masks)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    return total_loss / len(loader)

In [None]:
BASE_PATH = '/kaggle/input/recodai-luc-scientific-image-forgery-detection'

train_img_dir = Path(BASE_PATH) / 'train_images'
train_mask_dir = Path(BASE_PATH) / 'train_masks'
supp_img_dir = Path(BASE_PATH) / 'supplemental_images'
supp_mask_dir = Path(BASE_PATH) / 'supplemental_masks'
test_img_dir = Path(BASE_PATH) / 'test_images'

train_imgs = sorted(list(train_img_dir.glob('*.png')))
supp_imgs = sorted(list(supp_img_dir.glob('*.png'))) if supp_img_dir.exists() else []
test_imgs = sorted(list(test_img_dir.glob('*.png')))

print(f'Train images: {len(train_imgs)}')
print(f'Supplemental images: {len(supp_imgs)}')
print(f'Test images: {len(test_imgs)}')

In [None]:
def get_mask_paths(img_path, train_dir, mask_dir):
    img_id = img_path.stem
    mask_path = mask_dir / f'{img_id}.npy'
    if mask_path.exists():
        return mask_path
    return None

train_mask_paths = []
train_valid_imgs = []

for img_path in train_imgs:
    mask_path = get_mask_paths(img_path, train_img_dir, train_mask_dir)
    if mask_path:
        train_valid_imgs.append(img_path)
        train_mask_paths.append(mask_path)

supp_mask_paths = []
supp_valid_imgs = []

if supp_mask_dir.exists():
    for img_path in supp_imgs:
        mask_path = get_mask_paths(img_path, supp_img_dir, supp_mask_dir)
        if mask_path:
            supp_valid_imgs.append(img_path)
            supp_mask_paths.append(mask_path)

all_train_imgs = train_valid_imgs + supp_valid_imgs
all_train_masks = train_mask_paths + supp_mask_paths

print(f'Valid train images with masks: {len(all_train_imgs)}')

In [None]:
from sklearn.model_selection import train_test_split

train_imgs_split, val_imgs_split, train_masks_split, val_masks_split = train_test_split(
    all_train_imgs, all_train_masks, test_size=0.1, random_state=42
)

train_ds = ImgDataset(train_imgs_split, train_masks_split, train=True, sz=720)
val_ds = ImgDataset(val_imgs_split, val_masks_split, train=False, sz=720)

train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=2, pin_memory=True, persistent_workers=True)
val_loader = DataLoader(val_ds, batch_size=2, shuffle=False, num_workers=2, pin_memory=True, persistent_workers=True)

print(f'Train batches: {len(train_loader)}')
print(f'Val batches: {len(val_loader)}')

# DINOv2UNet or UNet 


In [None]:
torch.cuda.empty_cache()
import gc
gc.collect()

DINO_PATH = "/kaggle/input/dinov2/pytorch/base/1"
USE_ENSEMBLE = True

print("Creating ensemble models:")

print("Creating UNet model")
model_unet = UNet(n_channels=3, n_classes=1).to(device)
print(f"UNet created - Parameters: {sum(p.numel() for p in model_unet.parameters()):,}")

if USE_ENSEMBLE and HAS_TRANSFORMERS and os.path.exists(DINO_PATH):
    print("Creating DINOv2UNet model")
    model_dino = DINOv2UNet(
        dino_path=DINO_PATH,
        n_classes=1,
        img_size=720,
        use_dino=True
    )
    if model_dino.use_dino:
        model_dino.dino_proj = model_dino.dino_proj.to(device)
        model_dino.dino_proj_e3 = model_dino.dino_proj_e3.to(device)
        model_dino.dino_proj_e2 = model_dino.dino_proj_e2.to(device)
        model_dino.dino_proj_e1 = model_dino.dino_proj_e1.to(device)
    model_dino.bottleneck = model_dino.bottleneck.to(device)
    model_dino.up4 = model_dino.up4.to(device)
    model_dino.dec4 = model_dino.dec4.to(device)
    model_dino.up3 = model_dino.up3.to(device)
    model_dino.dec3 = model_dino.dec3.to(device)
    model_dino.up2 = model_dino.up2.to(device)
    model_dino.dec2 = model_dino.dec2.to(device)
    model_dino.up1 = model_dino.up1.to(device)
    model_dino.dec1 = model_dino.dec1.to(device)
    model_dino.final = model_dino.final.to(device)
    model_dino.drop = model_dino.drop.to(device)
    print(f"DINOv2UNet created - Parameters: {sum(p.numel() for p in model_dino.parameters()):,}")
    print(f"DINOv2 encoder: on CPU")
    print(f"Decoder: on {device}")
else:
    model_dino = None
    print("DINOv2UNet not available, using UNet only")

ENSEMBLE_WEIGHTS = [0.7, 0.3]
print(f"Ensemble weights: UNet={ENSEMBLE_WEIGHTS[0]}, DINOv2UNet={ENSEMBLE_WEIGHTS[1]}")

In [None]:
print("Training UNet model")
optimizer_unet = optim.AdamW(model_unet.parameters(), lr=2e-4, weight_decay=1e-5)
scheduler_unet = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_unet, T_0=5, T_mult=2, eta_min=1e-6)
criterion = bce_dice_loss

best_val_loss_unet = float('inf')
epochs = 20
patience = 5
patience_counter = 0

for epoch in range(epochs):
    train_loss = train_epoch(model_unet, train_loader, optimizer_unet, criterion, device)
    
    model_unet.eval()
    val_loss = 0
    with torch.no_grad():
        for imgs, masks in val_loader:
            imgs = imgs.to(device)
            masks = masks.to(device)
            preds = model_unet(imgs)
            loss = criterion(preds, masks)
            val_loss += loss.item()
    val_loss /= len(val_loader)
    
    scheduler_unet.step()
    
    print(f'UNet - Epoch {epoch+1}/{epochs} - Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, LR: {optimizer_unet.param_groups[0]["lr"]:.6f}')
    
    if val_loss < best_val_loss_unet:
        best_val_loss_unet = val_loss
        patience_counter = 0
        torch.save(model_unet.state_dict(), 'best_model_unet.pth')
        print(f'UNet model saved!')
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print(f'Early stopping at epoch {epoch+1}')
            break

model_unet.load_state_dict(torch.load('best_model_unet.pth'))
print("UNet training completed!")

In [None]:
if model_dino is not None:
    print("Training DINOv2UNet model...")
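    # Optimize only the trainable parameters; the DINOv2 encoder is frozen.
    trainable_params = [p for p in model_dino.parameters() if p.requires_grad]
    optimizer_dino = optim.AdamW(trainable_params, lr=2e-4, weight_decay=1e-5)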
    scheduler_dino = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_dino, T_0=5, T_mult=2, eta_min=1e-6)
    
    best_val_loss_dino = float('inf')
    patience_counter = 0
    
    for epoch in range(epochs):
        train_loss = train_epoch(model_dino, train_loader, optimizer_dino, criterion, device)
        
        model_dino.eval()
        val_loss = 0
        with torch.no_grad():
            for imgs, masks in val_loader:
                imgs = imgs.to(device)
                masks = masks.to(device)
                preds = model_dino(imgs)
                loss = criterion(preds, masks)
                val_loss += loss.item()
        val_loss /= len(val_loader)
        
        scheduler_dino.step()
        
        print(f'DINOv2UNet - Epoch {epoch+1}/{epochs} - Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, LR: {optimizer_dino.param_groups[0]["lr"]:.6f}')
        
        if val_loss < best_val_loss_dino:
            best_val_loss_dino = val_loss
            patience_counter = 0
            torch.save(model_dino.state_dict(), 'best_model_dino.pth')
            print(f'DINOv2UNet model saved!')
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f'Early stopping at epoch {epoch+1}')
                break
    
    model_dino.load_state_dict(torch.load('best_model_dino.pth'))
    print("DINOv2UNet training completed!")
else:
    print("Skipping DINOv2UNet training (model not available)")

In [None]:
AREA_THR = 100
MEAN_THR = 0.15
USE_TTA = True

def enhanced_adaptive_mask(prob, alpha_grad=0.35):
    gx = cv2.Sobel(prob, cv2.CV_32F, 1, 0, ksize=3)
    gy = cv2.Sobel(prob, cv2.CV_32F, 0, 1, ksize=3)
    grad_mag = np.sqrt(gx**2 + gy**2)
    grad_norm = grad_mag / (grad_mag.max() + 1e-6)
    
    enhanced = (1 - alpha_grad) * prob + alpha_grad * grad_norm
    enhanced = cv2.GaussianBlur(enhanced, (3, 3), 0)
    
    thr = np.mean(enhanced) + 0.3 * np.std(enhanced)
    mask = (enhanced > thr).astype(np.uint8)
    
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, np.ones((5, 5), np.uint8))
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, np.ones((3, 3), np.uint8))
    
    return mask, thr

def finalize_mask(prob, orig_size, img_size=720):
    mask, thr = enhanced_adaptive_mask(prob)
    mask = cv2.resize(mask, orig_size, interpolation=cv2.INTER_NEAREST)
    return mask, thr

def postprocess_mask(mask, min_area=100):
    mask = mask.astype(np.uint8)
    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=8)
    for i in range(1, num_labels):
        if stats[i, cv2.CC_STAT_AREA] < min_area:
            mask[labels == i] = 0
    return mask

def predict_tta(model, img_tensor, device):
    model.eval()
    preds = []
    
    with torch.no_grad():
        img = img_tensor.unsqueeze(0).to(device)
        
        pred = model(img).cpu().numpy()[0, 0]
        preds.append(pred)
        
        img_hflip = torch.flip(img, [3])
        pred_hflip = model(img_hflip).cpu().numpy()[0, 0]
        preds.append(np.fliplr(pred_hflip))
        
        img_vflip = torch.flip(img, [2])
        pred_vflip = model(img_vflip).cpu().numpy()[0, 0]
        preds.append(np.flipud(pred_vflip))
    
    return np.mean(preds, axis=0)

def predict_ensemble(models, weights, img_tensor, device, use_tta=True):
    probs = []
    for model in models:
        if model is None:
            continue
        if use_tta:
            prob = predict_tta(model, img_tensor, device)
        else:
            with torch.no_grad():
                model.eval()
                img = img_tensor.unsqueeze(0).to(device)
                prob = model(img).cpu().numpy()[0, 0]
        probs.append(prob)
    
    if len(probs) == 0:
        return None
    
    if len(probs) == 1:
        return probs[0]
    
    weights_norm = np.array(weights[:len(probs)])
    weights_norm = weights_norm / weights_norm.sum()
    
    ensemble_prob = np.zeros_like(probs[0])
    for prob, w in zip(probs, weights_norm):
        ensemble_prob += w * prob
    
    return ensemble_prob

def pipeline_final(models, weights, pil_img, device, img_size=720):
    # Use the same LANCZOS resampling as ImgDataset so inference matches training.
    img_tensor = torch.from_numpy(
        np.array(pil_img.resize((img_size, img_size), Image.LANCZOS), np.float32) / 255.
    ).permute(2, 0, 1)
    
    prob = predict_ensemble(models, weights, img_tensor, device, use_tta=USE_TTA)
    
    if prob is None:
        return "authentic", None, {"area": 0, "mean_inside": 0.0, "thr": 0.0}
    
    mask, thr = finalize_mask(prob, pil_img.size, img_size)
    
    area = int(mask.sum())
    prob_resized = cv2.resize(prob, pil_img.size, interpolation=cv2.INTER_LINEAR)
    mean_inside = float(prob_resized[mask == 1].mean()) if area > 0 else 0.0
    
    if area < AREA_THR or mean_inside < MEAN_THR:
        return "authentic", None, {"area": area, "mean_inside": mean_inside, "thr": thr}
    
    return "forged", mask, {"area": area, "mean_inside": mean_inside, "thr": thr}

In [None]:
sample_submission_original = pd.read_csv(Path(BASE_PATH) / 'sample_submission.csv')
sample_submission_original['case_id'] = sample_submission_original['case_id'].astype(str)
print(f"Original sample_submission: {len(sample_submission_original)} cases")

test_imgs_available = sorted(list(test_img_dir.glob('**/*.png')))
print(f"Available test images: {len(test_imgs_available)}")

if len(test_imgs_available) > len(sample_submission_original) * 2:
    print(f"\n Found {len(test_imgs_available)} test images,")
    print(f"but sample_submission has only {len(sample_submission_original)} case_id")
    print(f"Creating expanded submission for all found images...")
    
    all_case_ids = [str(Path(p).stem) for p in test_imgs_available]
    sample_submission = pd.DataFrame({
        'case_id': all_case_ids
    })
    
    print(f"Created expanded submission with {len(sample_submission)} case_id")
    print(f"First 10 case_id: {sample_submission['case_id'].head(10).tolist()}")
else:
    sample_submission = sample_submission_original.copy()
    print(f"Using original sample_submission with {len(sample_submission)} case_id")

In [None]:
test_imgs_all = sorted(list(test_img_dir.glob('**/*.png')))
test_imgs_dict = {}
for p in test_imgs_all:
    stem = Path(p).stem
    test_imgs_dict[stem] = p  # stem is already a str
    try:
        test_imgs_dict[int(stem)] = p  # also index by integer id when the stem is numeric
    except ValueError:
        pass

models_list = [model_unet, model_dino]
if model_dino is None:
    models_list = [model_unet]
    ensemble_weights = [1.0]
else:
    ensemble_weights = ENSEMBLE_WEIGHTS

for model in models_list:
    if model is not None:
        model.eval()

predictions_dict = {}

torch.cuda.empty_cache()

print(f"\n Processing {len(sample_submission)} cases with ensemble")
print(f"   Models: {len([m for m in models_list if m is not None])}")
print(f"   Weights: {ensemble_weights}")

with torch.no_grad():
    for idx, row in tqdm(sample_submission.iterrows(), total=len(sample_submission), desc='Predicting'):
        case_id = row['case_id']
        case_id_str = str(case_id)
        
        img_path = None
        if case_id_str in test_imgs_dict:
            img_path = test_imgs_dict[case_id_str]
        elif case_id in test_imgs_dict:
            img_path = test_imgs_dict[case_id]
        else:
            for test_stem, test_path in test_imgs_dict.items():
                if str(test_stem) == case_id_str or str(test_stem) == str(case_id):
                    img_path = test_path
                    break
        
        if img_path is None:
            predictions_dict[case_id_str] = "authentic"
            continue
        
        try:
            orig_img = Image.open(img_path).convert('RGB')
            
            label, mask, dbg = pipeline_final(models_list, ensemble_weights, orig_img, device, img_size=720)
            
            if label == "authentic":
                predictions_dict[case_id_str] = "authentic"
            else:
                if mask is not None:
                    mask = postprocess_mask(mask, min_area=50)
                    if mask.sum() == 0:
                        predictions_dict[case_id_str] = "authentic"
                    else:
                        rle = rle_encode((mask > 0).astype(np.uint8))
                        predictions_dict[case_id_str] = rle
                else:
                    predictions_dict[case_id_str] = "authentic"
        except Exception as e:
            print(f"Error processing {case_id}: {e}")
            predictions_dict[case_id_str] = "authentic"
        
        if (idx + 1) % 10 == 0:
            torch.cuda.empty_cache()

print(f'\nProcessed {len(predictions_dict)} predictions')

submission = sample_submission.copy()
submission['case_id'] = submission['case_id'].astype(str)
submission['annotation'] = submission['case_id'].map(predictions_dict).fillna('authentic')

print(f"\nPredictions summary:")
print(f"   Total cases: {len(submission)}")
print(f"   With predictions: {len(predictions_dict)}")
print(f"   Authentic: {(submission['annotation'] == 'authentic').sum()}")
print(f"   Forged: {(submission['annotation'] != 'authentic').sum()}")

In [None]:
submission_final = sample_submission[['case_id']].merge(
    submission[['case_id', 'annotation']],
    on='case_id',
    how='left'
)
submission_final['annotation'] = submission_final['annotation'].fillna('authentic')

forged_annotations = submission_final[submission_final['annotation'] != 'authentic']['annotation']
if len(forged_annotations) > 0:
    try:
        for ann in forged_annotations.head(3):
            json.loads(ann)
        print(f"RLE encoding is valid (JSON format)")
    except Exception as e:
        print(f"Problem with RLE encoding: {e}")

submission_final.to_csv('submission.csv', index=False)
print(f'\nSubmission saved! Shape: {submission_final.shape}')
print(f"\nFinal statistics:")
print(submission_final.head(10))
print(f"\n   Total case_id: {len(submission_final)}")
print(f"   Authentic: {(submission_final['annotation'] == 'authentic').sum()}")
print(f"   Forged: {(submission_final['annotation'] != 'authentic').sum()}")

missing = set(sample_submission_original['case_id'].astype(str)) - set(submission_final['case_id'].astype(str))
if missing:
    print(f"\nWARNING: Missing case_id: {list(missing)[:10]}...")
else:
    print(f"\nAll case_id from sample_submission processed!")

### Links
- **Pankaj Gupta | https://www.kaggle.com/pankajiitr**
- **Hossam Hamouda | https://www.kaggle.com/hossam82**