# Imports and configs

In [None]:
!pip install segmentation-models-pytorch

In [None]:
import json

import numba
import numpy as np
from numba import types
import numpy.typing as npt
import pandas as pd
import scipy.optimize


class ParticipantVisibleError(Exception):
    """Error whose message is meant to be shown to the person who submitted.

    Raised by the RLE decoding helpers in this file when a submitted
    annotation is malformed (unsorted starts, odd number of values,
    overlapping runs).
    """


@numba.jit(nopython=True)
def _rle_encode_jit(x: npt.NDArray, fg_val: int = 1) -> list[int]:
    """Numba-jitted RLE encoder.

    The array is transposed and flattened, so pixels are numbered
    column-by-column. The result is a flat list of alternating
    ``start, length`` values where ``start`` is 1-based.

    Parameters:
    x: array whose entries equal to ``fg_val`` are treated as foreground.
    fg_val: value that marks a foreground pixel.

    Returns:
    list of ints ``[start_1, length_1, start_2, length_2, ...]``.

    NOTE(review): not called anywhere in the visible part of this file;
    ``rle_encode`` below re-implements the same loop in plain Python.
    """
    # Flat (column-major) indices of all foreground pixels, in ascending order.
    dots = np.where(x.T.flatten() == fg_val)[0]
    run_lengths = []
    # Sentinel guaranteeing that the very first pixel opens a new run.
    prev = -2
    for b in dots:
        if b > prev + 1:
            # Gap since the previous pixel: open a new run (1-based start, length 0).
            run_lengths.extend((b + 1, 0))
        # Extend the current run by one pixel.
        run_lengths[-1] += 1
        prev = b
    return run_lengths


def rle_encode(mask):
    """Run-length encode a single binary mask as a JSON list string.

    The mask is cast to bool, transposed and flattened, so pixels are
    numbered column-by-column. The output is
    ``"[start_1, length_1, start_2, length_2, ...]"`` with 1-based starts,
    which is the format ``rle_decode`` parses.

    Runs are located with vectorised NumPy operations instead of a Python
    loop over every foreground pixel; the produced string is unchanged.

    Parameters:
    mask: array-like; any non-zero entry counts as foreground.

    Returns:
    str: JSON list of ints; ``"[]"`` when the mask has no foreground pixel.
    """
    flat = np.asarray(mask).astype(bool).T.flatten()
    pixels = np.flatnonzero(flat)

    if pixels.size == 0:
        return json.dumps([])

    # A new run begins wherever two consecutive foreground indices are not adjacent.
    breaks = np.flatnonzero(np.diff(pixels) > 1) + 1
    run_first = np.concatenate(([0], breaks))
    run_stop = np.concatenate((breaks, [pixels.size]))

    starts = pixels[run_first] + 1  # 1-based start positions
    lengths = run_stop - run_first

    # Interleave as start, length, start, length, ... and convert to Python ints.
    return json.dumps(np.column_stack((starts, lengths)).ravel().tolist())


@numba.njit
def _rle_decode_jit(mask_rle: npt.NDArray, height: int, width: int) -> npt.NDArray:
    """
    Decode a flat run-length array into a flat boolean mask.

    mask_rle: numpy array of run-length encoding pairs (start, length),
              with 1-based starts
    height, width: dimensions of the image; the output has height * width entries
    Returns a 1-D boolean numpy array, True - mask, False - background.
    The caller is responsible for reshaping it to 2-D.

    Raises ValueError if the array has an odd number of values or if a run
    extends past the start of the next run.

    NOTE: ``starts`` is a slice (view) of ``mask_rle``, so the ``starts -= 1``
    below modifies the caller's array in place.
    """
    if len(mask_rle) % 2 != 0:
        # Numba requires raising a standard exception.
        raise ValueError('One or more rows has an odd number of values.')

    starts, lengths = mask_rle[0::2], mask_rle[1::2]
    # Convert 1-based starts to 0-based indices.
    starts -= 1
    ends = starts + lengths
    # Each run must finish before the next one begins.
    for i in range(len(starts) - 1):
        if ends[i] > starts[i + 1]:
            raise ValueError('Pixels must not be overlapping.')
    img = np.zeros(height * width, dtype=np.bool_)
    for lo, hi in zip(starts, ends):
        img[lo:hi] = 1
    return img


def rle_decode(mask_rle: str, shape: tuple[int, int]) -> npt.NDArray:
    """
    Decode one JSON-encoded run-length string into a 2-D boolean mask.

    mask_rle: JSON list string ``"[start_1, length_1, start_2, length_2, ...]"``
              with 1-based starts counted column-by-column
    shape: (height, width) of array to return
    Returns a boolean numpy array of the given shape, True - mask, False - background.

    Raises ParticipantVisibleError if the starts are not in ascending order,
    or if the low-level decoder rejects the input with a ValueError.
    """
    encoded = np.asarray(json.loads(mask_rle), dtype=np.int32)

    run_starts = encoded[0::2]
    # Any start smaller than its predecessor means the runs are out of order.
    if np.any(run_starts[1:] < run_starts[:-1]):
        raise ParticipantVisibleError('Submitted values must be in ascending order.')

    height, width = shape[0], shape[1]
    try:
        flat_mask = _rle_decode_jit(encoded, height, width)
        # Column-major reshape mirrors the transposed flattening used by the encoder.
        return flat_mask.reshape(shape, order='F')
    except ValueError as e:
        raise ParticipantVisibleError(str(e)) from e


def calculate_f1_score(pred_mask: npt.NDArray, gt_mask: npt.NDArray):
    """Pixel-wise F1 score between one predicted and one ground-truth mask.

    Both masks are flattened; entries equal to 1 are positives and entries
    equal to 0 are negatives.

    Returns:
    The harmonic mean of precision and recall, or 0 when it is undefined
    (no true positives).
    """
    pred = pred_mask.flatten()
    gt = gt_mask.flatten()

    pred_pos = pred == 1
    gt_pos = gt == 1

    tp = np.sum(pred_pos & gt_pos)
    fp = np.sum(pred_pos & (gt == 0))
    fn = np.sum((pred == 0) & gt_pos)

    # Guard each ratio against an empty denominator.
    precision = 0
    if (tp + fp) > 0:
        precision = tp / (tp + fp)

    recall = 0
    if (tp + fn) > 0:
        recall = tp / (tp + fn)

    if not (precision + recall) > 0:
        return 0
    return 2 * (precision * recall) / (precision + recall)


def calculate_f1_matrix(pred_masks: list[npt.NDArray], gt_masks: list[npt.NDArray]):
    """
    Build the matrix of pairwise F1 scores between predictions and ground truth.

    Parameters:
    pred_masks (sequence of np.ndarray):
            One binary mask per predicted instance.
    gt_masks (sequence of np.ndarray):
            One binary mask per ground truth instance.

    Returns:
    np.ndarray whose entry [i, j] is the F1 score of prediction i against
    ground truth j. When there are fewer predictions than ground truth
    instances, all-zero rows are appended so that the matrix has one row
    per ground truth instance.
    """
    n_pred = len(pred_masks)
    n_gt = len(gt_masks)
    scores = np.zeros((n_pred, n_gt))

    # Score every (prediction, ground truth) pair.
    for row, pred in enumerate(pred_masks):
        for col, gt in enumerate(gt_masks):
            scores[row, col] = calculate_f1_score(pred_mask=pred.flatten(), gt_mask=gt.flatten())

    # Missing predictions are represented by zero rows.
    missing_rows = n_gt - n_pred
    if missing_rows > 0:
        scores = np.vstack((scores, np.zeros((missing_rows, n_gt))))

    return scores


def oF1_score(pred_masks: list[npt.NDArray], gt_masks: list[npt.NDArray]):
    """
    Optimal-matching F1 score of a set of predicted masks against ground truth.

    A pairwise F1 matrix is computed (padded with zero rows when there are
    fewer predictions than ground truth masks) and the Hungarian algorithm
    picks the assignment of predictions to ground truth masks that maximises
    the total F1. The mean F1 over the matched pairs is then scaled down when
    there are more predictions than ground truth masks.

    Parameters:
    pred_masks (list of np.ndarray): List of predicted binary masks.
    gt_masks (list of np.ndarray): List of ground truth binary masks.
    Returns:
    float: Optimal F1 score.
    """
    scores = calculate_f1_matrix(pred_masks, gt_masks)

    # linear_sum_assignment minimises cost, so negate to maximise F1.
    matched_rows, matched_cols = scipy.optimize.linear_sum_assignment(-scores)
    matched_f1 = scores[matched_rows, matched_cols]

    # Unmatched surplus predictions are dropped by the assignment, so they
    # are penalised separately through this ratio.
    n_pred = len(pred_masks)
    n_gt = len(gt_masks)
    excess_predictions_penalty = n_gt / max(n_pred, n_gt)

    return np.mean(matched_f1) * excess_predictions_penalty


def evaluate_single_image(label_rles: str, prediction_rles: str, shape_str: str) -> float:
    """Score one image: decode both annotations and return their oF1 score.

    label_rles / prediction_rles: one or more JSON RLE strings joined by ';',
        one per instance.
    shape_str: JSON string holding the [height, width] of the image.
    """
    shape = json.loads(shape_str)

    def decode_instances(joined: str):
        # One RLE string per instance, separated by ';'.
        return [rle_decode(part, shape=shape) for part in joined.split(';')]

    gt_masks = decode_instances(label_rles)
    pred_masks = decode_instances(prediction_rles)
    return oF1_score(pred_masks, gt_masks)


def score(solution: pd.DataFrame, submission: pd.DataFrame, row_id_column_name: str) -> float:
    """
    Mean per-image score of a submission against the ground truth.

    An image where either side says 'authentic' scores 1.0 if both
    annotations are identical and 0.0 otherwise. Every other image is scored
    with ``evaluate_single_image`` (optimal-matching F1 over instances).

    Args:
        solution (pd.DataFrame): The ground truth DataFrame. Must contain the
            columns 'annotation' and 'shape'.
        submission (pd.DataFrame): The submission DataFrame. Must contain the
            column 'annotation'.
        row_id_column_name (str): The name of the column containing row IDs.
            NOTE: this argument is not read by the body; predictions are
            attached to the solution by DataFrame index, so both frames must
            share the same index and row order.
    Returns:
        float

    Examples
    --------
    >>> solution = pd.DataFrame({'row_id': [0, 1, 2], 'annotation': ['authentic', 'authentic', 'authentic'], 'shape': ['authentic', 'authentic', 'authentic']})
    >>> submission = pd.DataFrame({'row_id': [0, 1, 2], 'annotation': ['authentic', 'authentic', 'authentic']})
    >>> score(solution.copy(), submission.copy(), row_id_column_name='row_id')
    1.0

    >>> solution = pd.DataFrame({'row_id': [0, 1, 2], 'annotation': ['authentic', 'authentic', 'authentic'], 'shape': ['authentic', 'authentic', 'authentic']})
    >>> submission = pd.DataFrame({'row_id': [0, 1, 2], 'annotation': ['[101, 102]', '[101, 102]', '[101, 102]']})
    >>> score(solution.copy(), submission.copy(), row_id_column_name='row_id')
    0.0

    >>> solution = pd.DataFrame({'row_id': [0, 1, 2], 'annotation': ['[101, 102]', '[101, 102]', '[101, 102]'], 'shape': ['[720, 960]', '[720, 960]', '[720, 960]']})
    >>> submission = pd.DataFrame({'row_id': [0, 1, 2], 'annotation': ['[101, 102]', '[101, 102]', '[101, 102]']})
    >>> score(solution.copy(), submission.copy(), row_id_column_name='row_id')
    1.0

    >>> solution = pd.DataFrame({'row_id': [0, 1, 2], 'annotation': ['[101, 103]', '[101, 102]', '[101, 102]'], 'shape': ['[720, 960]', '[720, 960]', '[720, 960]']})
    >>> submission = pd.DataFrame({'row_id': [0, 1, 2], 'annotation': ['[101, 102]', '[101, 102]', '[101, 102]']})
    >>> score(solution.copy(), submission.copy(), row_id_column_name='row_id')
    0.9983739837398374

    >>> solution = pd.DataFrame({'row_id': [0, 1, 2], 'annotation': ['[101, 102];[300, 100]', '[101, 102]', '[101, 102]'], 'shape': ['[720, 960]', '[720, 960]', '[720, 960]']})
    >>> submission = pd.DataFrame({'row_id': [0, 1, 2], 'annotation': ['[101, 102]', '[101, 102]', '[101, 102]']})
    >>> score(solution.copy(), submission.copy(), row_id_column_name='row_id')
    0.8333333333333334
    """
    df = solution
    # rename() returns a new frame, so the columns added below do not touch `solution`.
    df = df.rename(columns={'annotation': 'label'})

    # Aligned on the DataFrame index, not on row_id_column_name.
    df['prediction'] = submission['annotation']
    # Check for correct 'authentic' label
    authentic_indices = (df['label'] == 'authentic') | (df['prediction'] == 'authentic')
    # Rows involving 'authentic' score 1.0 only on an exact match; other rows get a 0.0 placeholder.
    df['image_score'] = ((df['label'] == df['prediction']) & authentic_indices).astype(float)

    # Rows where both sides carry RLE annotations are scored mask-by-mask.
    df.loc[~authentic_indices, 'image_score'] = df.loc[~authentic_indices].apply(
        lambda row: evaluate_single_image(row['label'], row['prediction'], row['shape']), axis=1
    )
    return float(np.mean(df['image_score']))

In [None]:
from sklearn.model_selection import StratifiedKFold
from torch.utils.data import Dataset, DataLoader
from albumentations.pytorch import ToTensorV2
from tqdm.notebook import tqdm
from PIL import Image
import segmentation_models_pytorch as smp
import matplotlib.pyplot as plt
import albumentations as A
import torch.nn as nn
import numpy as np
import warnings
import joblib
import random
import torch
import glob
import cv2
import gc
import os

warnings.filterwarnings('ignore')

In [None]:
# Central configuration for the notebook; read as CFG.<name> by the cells below.
class CFG:
    # Root directory of the competition data (train/test images and masks).
    dataset_path = "/kaggle/input/recodai-luc-scientific-image-forgery-detection"

    # Architecture / encoder identifiers; also used to build the checkpoint directory.
    model_name = 'Unet'
    encoder_name = 'efficientnet-b5'
    encoder_weights = 'imagenet'

    # Cross-validation setup and the seed used for all random generators.
    n_folds = 5
    seed = 42

    # Images and masks are resized to image_size x image_size before entering the model.
    image_size = 512
    batch_size = 4
    learning_rate = 2e-4
    weight_decay = 2e-4
    # Maximum number of epochs per fold.
    n_epochs = 18
    # Epochs without validation-loss improvement before a fold stops early.
    es_patience = 5

    # Post-processing values; they equal the defaults of Trainer.predict.
    # Intended use: probability cutoff for binarizing a prediction, and the
    # minimum number of foreground pixels below which an image is reported
    # as 'authentic'.
    mask_threshold = 0.5
    mask_sum_threshold = 20

In [None]:
# Reproducibility: export the determinism-related environment variables,
# seed every random number generator with CFG.seed and switch cuDNN to
# deterministic mode.
os.environ.update({
    'TF_CUDNN_DETERMINISTIC': '1',
    'TF_DETERMINISTIC_OPS': '1',
    'PYTHONHASHSEED': str(CFG.seed),
})

random.seed(CFG.seed)
np.random.seed(CFG.seed)

torch.manual_seed(CFG.seed)
torch.cuda.manual_seed(CFG.seed)
torch.cuda.manual_seed_all(CFG.seed)

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Data loading and preprocessing

In [None]:
def get_data(dataset_path, is_train=True):
    """Index the image files of the dataset into a DataFrame.

    Parameters:
    dataset_path: root directory of the dataset.
    is_train: when True, list ``train_images/<label>/*.png``; otherwise list
        ``test_images/*.png``.

    Returns:
    pd.DataFrame with the columns ``case_id``, ``label``, ``image_path`` and
    ``mask_path``. For training data the label is 'authentic' when the image
    sits in a directory named 'authentic' and 'forged' otherwise, and
    ``mask_path`` points at ``train_masks/<case_id>.npy`` (the file is not
    checked for existence). For test data both are None.

    Paths are sorted because glob returns files in arbitrary, filesystem
    dependent order; without sorting, the seeded cross-validation split
    built on this frame is not reproducible across machines. The columns are
    present even when no image is found.
    """
    pattern = "train_images/**/*.png" if is_train else "test_images/*.png"
    image_paths = sorted(glob.glob(f"{dataset_path}/{pattern}"))

    dataset = []
    for image_path in image_paths:
        image_id = os.path.basename(image_path).split(".")[0]
        if is_train:
            # Use the immediate parent directory so that an 'authentic'
            # substring elsewhere in the path cannot change the label.
            parent_dir = os.path.basename(os.path.dirname(image_path))
            label = "authentic" if parent_dir == "authentic" else "forged"
            mask_path = f"{dataset_path}/train_masks/{image_id}.npy"
        else:
            label = None
            mask_path = None

        dataset.append({
            "case_id": image_id,
            "label": label,
            "image_path": image_path,
            "mask_path": mask_path
        })

    return pd.DataFrame(dataset, columns=["case_id", "label", "image_path", "mask_path"])

In [None]:
def visualize(image_id=None):
    """Plot a 2x3 overview of one training case.

    The panels show the authentic image, the forged image, the forgery
    mask, the mask overlaid on the forged image, the absolute pixel
    difference between both images, and the forged image with the forged
    regions tinted.

    Parameters:
    image_id: id of the case to show. When None, an id is picked at random
        from the files in ``train_images/authentic``.

    NOTE: the same id is used to load ``train_images/authentic/<id>.png``,
    ``train_images/forged/<id>.png`` and ``train_masks/<id>.npy``, so all
    three files must exist for the chosen id. The difference panel also
    requires both images to have the same shape.
    """
    if image_id is None:
        files = glob.glob(f"{CFG.dataset_path}/train_images/authentic/*.png")
        image_ids = [f.split("/")[-1].split(".")[0] for f in files]
        image_id = random.choice(image_ids)

    print(f"Visualizing image {image_id}\n")
    authentic_img = np.array(Image.open(f"{CFG.dataset_path}/train_images/authentic/{image_id}.png"))
    forged_img = np.array(Image.open(f"{CFG.dataset_path}/train_images/forged/{image_id}.png"))
    mask = np.load(f"{CFG.dataset_path}/train_masks/{image_id}.npy")

    # Kept only for the printed summary; `mask` itself is reduced below.
    original_mask_shape = mask.shape

    # Drop the alpha channel of 4-channel images.
    if len(authentic_img.shape) == 3 and authentic_img.shape[2] == 4:
        authentic_img = authentic_img[:, :, :3]
    if len(forged_img.shape) == 3 and forged_img.shape[2] == 4:
        forged_img = forged_img[:, :, :3]

    # A 3-D mask with more than one entry on axis 0 is treated as one mask per forged region.
    if len(mask.shape) == 3 and mask.shape[0] > 1:
        num_masks = mask.shape[0]
        individual_masks = [(mask[i] > 0).astype(float) for i in range(num_masks)]

        mask_combined = np.max(mask, axis=0)
        mask_combined = (mask_combined > 0).astype(float)
    else:
        num_masks = 1
        mask_combined = np.squeeze(mask)
        mask_combined = (mask_combined > 0).astype(float)
        individual_masks = [mask_combined]

    print("\tAuthentic image shape: ", authentic_img.shape)
    print("\tForged image shape:    ", forged_img.shape)
    print("\tMask shape:            ", original_mask_shape)
    print(f"\tNumber of masks:        {num_masks}")
    print()

    fig, axes = plt.subplots(2, 3, figsize=(15, 10))

    axes[0, 0].imshow(authentic_img, cmap='gray' if len(authentic_img.shape) == 2 else None)
    axes[0, 0].set_title('Authentic Image', fontsize=12, fontweight='bold')
    axes[0, 0].axis('off')

    axes[0, 1].imshow(forged_img, cmap='gray' if len(forged_img.shape) == 2 else None)
    axes[0, 1].set_title('Forged Image', fontsize=12, fontweight='bold')
    axes[0, 1].axis('off')

    if num_masks > 1:
        # One colour per region; the palette is reused cyclically beyond ten regions.
        colors = [
            [1, 0, 0],      # Red
            [0, 1, 0],      # Green
            [0, 0, 1],      # Blue
            [1, 1, 0],      # Yellow
            [1, 0, 1],      # Magenta
            [0, 1, 1],      # Cyan
            [1, 0.5, 0],    # Orange
            [0.5, 0, 1],    # Purple
            [0, 0.5, 0.5],  # Teal
            [1, 0.5, 0.5],  # Pink
        ]

        # Create RGB mask visualization
        # Colours are added, so pixels covered by several regions can exceed 1.
        mask_rgb = np.zeros((*mask_combined.shape, 3))
        for i, m in enumerate(individual_masks):
            color = colors[i % len(colors)]
            for c in range(3):
                mask_rgb[:, :, c] += m * color[c]

        axes[0, 2].imshow(mask_rgb)
    else:
        axes[0, 2].imshow(mask_combined, cmap='Reds', vmin=0, vmax=1)

    title = f'Forgery Mask ({num_masks} region{"s" if num_masks > 1 else ""})'
    axes[0, 2].set_title(title, fontsize=12, fontweight='bold')
    axes[0, 2].axis('off')

    # Semi-transparent mask drawn on top of the forged image.
    axes[1, 0].imshow(forged_img, cmap='gray' if len(forged_img.shape) == 2 else None)
    if num_masks > 1:
        axes[1, 0].imshow(mask_rgb, alpha=0.4)
    else:
        axes[1, 0].imshow(mask_combined, cmap='Reds', alpha=0.4, vmin=0, vmax=1)
    axes[1, 0].set_title('Forged + Mask Overlay', fontsize=12, fontweight='bold')
    axes[1, 0].axis('off')

    # Per-pixel absolute difference, averaged over channels for colour images.
    if len(authentic_img.shape) == 3:
        diff = np.abs(forged_img.astype(float) - authentic_img.astype(float)).mean(axis=2)
    else:
        diff = np.abs(forged_img.astype(float) - authentic_img.astype(float))

    axes[1, 1].imshow(diff, cmap='hot')
    axes[1, 1].set_title('Absolute Difference', fontsize=12, fontweight='bold')
    axes[1, 1].axis('off')

    # Darken the forged image and add a colour tint on the forged regions.
    # Grayscale (2-D) images are shown without any tint.
    forged_highlighted = forged_img.copy().astype(float)
    if len(forged_highlighted.shape) == 3:
        if num_masks > 1:
            forged_highlighted = forged_highlighted * 0.6
            for i, m in enumerate(individual_masks):
                color = colors[i % len(colors)]
                mask_3d = np.stack([m, m, m], axis=2)
                color_overlay = np.array(color) * 255 * 0.6
                forged_highlighted = forged_highlighted + mask_3d * color_overlay
        else:
            # Single region: tint the red channel only.
            mask_3d = np.stack([mask_combined, mask_combined * 0, mask_combined * 0], axis=2)
            forged_highlighted = forged_highlighted * 0.7 + mask_3d * 255 * 0.5

        forged_highlighted = np.clip(forged_highlighted, 0, 255).astype(forged_img.dtype)

    axes[1, 2].imshow(forged_highlighted, cmap='gray' if len(forged_highlighted.shape) == 2 else None)
    axes[1, 2].set_title('Highlighted Forgery', fontsize=12, fontweight='bold')
    axes[1, 2].axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
# Show a randomly chosen training case (no id given, so one is sampled).
visualize()

In [None]:
# Show another randomly chosen training case.
visualize()

In [None]:
# Show a third randomly chosen training case.
visualize()

## Custom training dataset 

In [None]:
class CustomDataset(Dataset):
    """Torch dataset over the rows of the DataFrame produced by ``get_data``.

    In ``mode='train'`` an item is ``(image, mask)`` where the mask has a
    leading channel axis. In any other mode an item is
    ``(case_id, image, original_size)`` with ``original_size`` being the
    ``(height, width)`` of the image before resizing.

    Parameters:
    dataset: DataFrame with the columns ``case_id``, ``label``,
        ``image_path`` and ``mask_path``.
    image_size: images (and masks) are resized to ``image_size x image_size``.
    transform: callable applied as ``transform(image=..., mask=...)`` in
        train mode and ``transform(image=...)`` otherwise; it must return a
        dict with the key 'image' (and 'mask' in train mode).
    mode: 'train' to return targets, anything else for inference.
    """

    def __init__(self, dataset, image_size, transform, mode='train'):
        self.dataset = dataset
        self.image_size = image_size
        self.transform = transform
        self.mode = mode

    def __len__(self):
        return len(self.dataset)

    @staticmethod
    def _load_rgb(image_path):
        """Read an image from disk as an RGB array, failing loudly if unreadable."""
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising,
            # which would otherwise surface as an opaque cvtColor error.
            raise FileNotFoundError(f"Could not read image: {image_path}")
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    @staticmethod
    def _load_binary_mask(row, height, width):
        """Return the float32 0/1 target mask for a row (all zeros unless 'forged')."""
        if row['label'] != 'forged':
            return np.zeros((height, width), dtype=np.float32)

        mask = np.load(row["mask_path"])
        if len(mask.shape) == 3:
            # Several region masks stacked on axis 0: merge them into one.
            mask = np.max(mask, axis=0)
        else:
            mask = np.squeeze(mask)
        return (mask > 0).astype(np.float32)

    def __getitem__(self, idx):
        row = self.dataset.iloc[idx]
        image = self._load_rgb(row['image_path'])
        target_size = (self.image_size, self.image_size)

        if self.mode == 'train':
            mask = self._load_binary_mask(row, image.shape[0], image.shape[1])

            image = cv2.resize(image, target_size)
            # NOTE(review): default (bilinear) interpolation produces fractional
            # values along region borders; kept as in the original pipeline.
            mask = cv2.resize(mask, target_size)

            augmented = self.transform(image=image, mask=mask)
            # Add the channel axis expected by the loss: (H, W) -> (1, H, W).
            return augmented['image'], augmented['mask'].unsqueeze(0)

        # Inference: remember the original size so predictions can be resized back.
        original_size = image.shape[:2]
        image = cv2.resize(image, target_size)
        augmented = self.transform(image=image)

        return row['case_id'], augmented['image'], original_size

## Augmentations

In [None]:
def get_augmentations(is_train):
    """Return the albumentations pipeline for training or for evaluation.

    Both pipelines end with normalisation and conversion to a tensor. The
    training pipeline additionally applies random flips and rotations, one
    optional colour change, one optional noise/blur degradation and random
    coarse dropout.
    """
    if not is_train:
        return A.Compose([
            A.Normalize(),
            ToTensorV2()
        ])

    geometric = [
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
        A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=15, p=0.5),
    ]

    # At most one colour change per sample.
    color_change = A.OneOf([
        A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=1),
        A.HueSaturationValue(hue_shift_limit=10, sat_shift_limit=20, val_shift_limit=20, p=1),
        A.RandomGamma(gamma_limit=(80, 120), p=1),
    ], p=0.5)

    # At most one noise / blur degradation per sample.
    degradation = A.OneOf([
        A.GaussNoise(var_limit=(10.0, 50.0), p=1),
        A.GaussianBlur(blur_limit=(3, 5), p=1),
        A.MotionBlur(blur_limit=5, p=1),
    ], p=0.3)

    dropout = A.CoarseDropout(
        max_holes=8, max_height=32, max_width=32,
        min_holes=1, min_height=8, min_width=8,
        fill_value=0, p=0.3
    )

    return A.Compose([
        *geometric,
        color_change,
        degradation,
        dropout,
        A.Normalize(),
        ToTensorV2()
    ])

# Training

In [None]:
def get_val_solution(mask_path, case_ids):
    """Build the ground-truth frame expected by ``score`` from the mask files.

    Parameters:
    mask_path: directory holding one ``<case_id>.npy`` mask file per case.
    case_ids: iterable of case ids; the output rows follow this order.

    Returns:
    pd.DataFrame with the columns ``case_id``, ``annotation`` and ``shape``.
    A case whose mask has no positive pixel gets 'authentic' in both
    columns. Otherwise ``annotation`` holds one JSON RLE string per mask
    instance joined by ';' and ``shape`` holds ``"[height, width]"``, which
    is the layout ``evaluate_single_image`` splits and decodes.

    Each instance is encoded on its own because ``rle_encode`` works on a
    single 2-D mask; feeding it a stacked (N, H, W) array mixes the
    instances into one meaningless run list. A plain 2-D mask file is
    treated as a single instance.
    """
    solution = []
    for case_id in case_ids:
        mask = np.load(f"{mask_path}/{case_id}.npy")
        if mask.sum() > 0:
            # Normalise to (instances, height, width) whatever the stored rank.
            instances = mask.reshape((-1,) + mask.shape[-2:])
            height, width = instances.shape[-2:]
            annotation = ';'.join(rle_encode(instance) for instance in instances)
            solution.append({'case_id': case_id, 'annotation': annotation, 'shape': f"[{height}, {width}]"})
        else:
            solution.append({'case_id': case_id, 'annotation': "authentic", "shape": 'authentic'})

    return pd.DataFrame(solution)

In [None]:
def save_val_samples(visualize_data, path):
    """Plot and save image / ground truth / prediction triptychs.

    Parameters:
    visualize_data: iterable of dicts with the keys 'image_id', 'label',
        'original_image', 'gt_mask' and 'pred_mask'.
    path: existing directory; each figure is written to
        ``<path>/<image_id>.png``.

    The figure is written to disk before it is displayed: once
    ``plt.show()`` has run, the current figure may already be released, and
    a later ``plt.savefig`` would then store an empty canvas. Each figure
    is closed afterwards so that figures do not accumulate in memory.
    """
    for data in visualize_data:
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))

        axes[0].imshow(data['original_image'])
        axes[0].set_title(f"Image ID: {data['image_id']} Label: {data['label']}")
        axes[0].axis('off')

        axes[1].imshow(data['original_image'])
        axes[1].imshow(data['gt_mask'], cmap='Greens', alpha=0.5, vmin=0, vmax=1)
        axes[1].set_title(f"Ground Truth Mask (Sum: {data['gt_mask'].sum():.0f})")
        axes[1].axis('off')

        axes[2].imshow(data['original_image'])
        axes[2].imshow(data['pred_mask'], cmap='Reds', alpha=0.5, vmin=0, vmax=1)
        axes[2].set_title(f"Predicted Mask (Sum: {data['pred_mask'].sum():.0f})")
        axes[2].axis('off')

        # Save through the figure handle first, then display, then release.
        fig.savefig(f"{path}/{data['image_id']}.png")
        plt.show()
        plt.close(fig)

## Custom metrics

In [None]:
class CustomLoss(nn.Module):
    """Weighted sum of binary cross-entropy (on logits) and Dice loss.

    Parameters:
    bce_weight: multiplier of the BCE-with-logits term.
    dice_weight: multiplier of the Dice term.
    """

    def __init__(self, bce_weight=0.5, dice_weight=0.5):
        super().__init__()
        self.bce_weight = bce_weight
        self.dice_weight = dice_weight
        self.bce = nn.BCEWithLogitsLoss()
        self.dice = smp.losses.DiceLoss(mode="binary")

    def forward(self, pred, target):
        """Return ``bce_weight * BCE(pred, target) + dice_weight * Dice(pred, target)``."""
        bce_term = self.bce_weight * self.bce(pred, target)
        dice_term = self.dice_weight * self.dice(pred, target)
        return bce_term + dice_term

In [None]:
def dice_score(pred, target, threshold=0.5):
    """Dice coefficient between thresholded predictions and a target tensor.

    Parameters:
    pred: tensor of logits; a sigmoid is applied before thresholding.
    target: tensor of the same shape as ``pred``.
    threshold: probability above which a pixel counts as predicted foreground.

    Returns:
    Scalar tensor computed over all elements at once (not per sample). A
    small constant in the denominator keeps the result finite when both
    tensors are empty.
    """
    binary_pred = torch.sigmoid(pred).gt(threshold).float()
    overlap = torch.sum(binary_pred * target)
    denominator = binary_pred.sum() + target.sum() + 1e-6
    return (2. * overlap) / denominator

## Training utility class

In [None]:
class Trainer:
    """Cross-validated training, inference and scoring of the segmentation model.

    Parameters:
    dataset_path: root directory of the dataset; ``<dataset_path>/train_masks``
        is read when a fold is scored.

    The compute device is 'cuda:0' when CUDA is available and 'cpu' otherwise.
    """

    def __init__(self, dataset_path):
        self.dataset_path = dataset_path
        self.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    def train_one_epoch(self, model, dataloader, optimizer, criterion, current_epoch, n_epochs):
        """Run one optimisation pass over ``dataloader``.

        Gradients are clipped to a global norm of 1.0 before every step.

        Returns:
        (mean batch loss, mean batch dice) over the epoch.
        """
        model.train()
        running_loss = 0.0
        running_dice = 0.0

        pbar = tqdm(dataloader, desc=f'Epoch {current_epoch+1}/{n_epochs}')
        for images, masks in pbar:
            images = images.to(self.device)
            masks = masks.to(self.device)

            optimizer.zero_grad()

            outputs = model(images)
            loss = criterion(outputs, masks)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            # Compute each batch metric once and reuse it for totals and display.
            batch_loss = loss.item()
            batch_dice = dice_score(outputs, masks).item()
            running_loss += batch_loss
            running_dice += batch_dice

            pbar.set_postfix({'train_loss': f'{batch_loss:.4f}', 'train_dice': f'{batch_dice:.4f}'})

        epoch_loss = running_loss / len(dataloader)
        epoch_dice = running_dice / len(dataloader)

        return epoch_loss, epoch_dice

    def validate_one_epoch(self, model, dataloader, criterion, current_epoch, n_epochs):
        """Evaluate ``model`` on ``dataloader`` without gradient tracking.

        Returns:
        (mean batch loss, mean batch dice) over the epoch.
        """
        model.eval()
        running_loss = 0.0
        running_dice = 0.0

        with torch.no_grad():
            pbar = tqdm(dataloader, desc=f'Epoch {current_epoch+1}/{n_epochs}')
            for images, masks in pbar:
                images = images.to(self.device)
                masks = masks.to(self.device)

                outputs = model(images)

                batch_loss = criterion(outputs, masks).item()
                batch_dice = dice_score(outputs, masks).item()
                running_loss += batch_loss
                running_dice += batch_dice

                pbar.set_postfix({'val_loss': f'{batch_loss:.4f}', 'val_dice': f'{batch_dice:.4f}'})

        epoch_loss = running_loss / len(dataloader)
        epoch_dice = running_dice / len(dataloader)

        return epoch_loss, epoch_dice

    def train_one_fold(self, checkpoint_path, model_name, encoder_name, encoder_weights, train, valid, image_size, batch_size, n_epochs, learning_rate, weight_decay, es_patience):
        """Train one model on ``train`` and keep the best checkpoint by validation loss.

        The model is saved to ``checkpoint_path`` every time the validation
        loss improves. Training stops early after ``es_patience`` consecutive
        epochs without improvement.

        NOTE(review): ``model_name`` is accepted but the architecture built
        here is always ``smp.Unet``.

        Returns:
        dict of per-epoch lists with the keys 'epoch', 'train_loss',
        'train_dice', 'valid_loss', 'valid_dice' and 'learning_rate'.
        """
        train_dataset = CustomDataset(
            train,
            image_size=image_size,
            transform=get_augmentations(is_train=True),
            mode='train'
        )
        # Validation also uses mode='train' so that targets are returned,
        # but with the evaluation (non-random) transform.
        valid_dataset = CustomDataset(
            valid,
            image_size=image_size,
            transform=get_augmentations(is_train=False),
            mode='train'
        )

        train_dataloader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=2,
            pin_memory=True
        )
        valid_dataloader = DataLoader(
            valid_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=2,
            pin_memory=True
        )

        model = smp.Unet(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            in_channels=3,
            classes=1,
        )

        model = model.to(self.device)
        criterion = CustomLoss()
        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs, eta_min=1e-6)

        best_loss = float("inf")
        iterations_without_improvement = 0

        history = {'epoch': [], 'train_loss': [], 'train_dice': [], 'valid_loss': [], 'valid_dice': [], 'learning_rate': []}
        for epoch in range(n_epochs):
            train_loss, train_dice = self.train_one_epoch(model, train_dataloader, optimizer, criterion, epoch, n_epochs)
            valid_loss, valid_dice = self.validate_one_epoch(model, valid_dataloader, criterion, epoch, n_epochs)

            # Learning rate used during this epoch (read before the scheduler
            # advances); kept in its own name so the `learning_rate` argument
            # is not overwritten.
            current_lr = optimizer.param_groups[0]['lr']

            history['epoch'].append(epoch)
            history['train_loss'].append(train_loss)
            history['train_dice'].append(train_dice)
            history['valid_loss'].append(valid_loss)
            history['valid_dice'].append(valid_dice)
            history['learning_rate'].append(current_lr)

            scheduler.step()

            if valid_loss < best_loss:
                model.save_pretrained(checkpoint_path)

                best_loss = valid_loss
                iterations_without_improvement = 0
                print(f"Epoch {epoch+1}/{n_epochs} - Train Loss: {train_loss:.4f} - Train Dice: {train_dice:.4f} - Valid Loss: {valid_loss:.4f} - Valid Dice: {valid_dice:.4f} (Improved)")
            else:
                print(f"Epoch {epoch+1}/{n_epochs} - Train Loss: {train_loss:.4f} - Train Dice: {train_dice:.4f} - Valid Loss: {valid_loss:.4f} - Valid Dice: {valid_dice:.4f}")
                iterations_without_improvement += 1
                if iterations_without_improvement >= es_patience:
                    break

        print()

        # Release GPU memory before the next fold.
        del model, criterion, optimizer, scheduler, train_dataloader, valid_dataloader
        gc.collect()
        torch.cuda.empty_cache()

        return history

    def predict(self, model_paths, dataset, image_size=512, mode='test', use_tta=False, mask_threshold=0.5, mask_sum_threshold=20, save_prediction_samples=False):
        """Predict one annotation per row of ``dataset`` with an ensemble of checkpoints.

        Probabilities are averaged over all models in ``model_paths`` (and
        over four flip variants per model when ``use_tta`` is True), resized
        to the original image size and thresholded at ``mask_threshold``. A
        mask with fewer than ``mask_sum_threshold`` foreground pixels is
        reported as 'authentic'; otherwise it is run-length encoded.

        When ``save_prediction_samples`` is True, up to ten randomly chosen
        rows are plotted and saved into ``model_paths[0]``; this requires
        ``dataset`` to carry 'label' and 'mask_path'.

        Returns:
        pd.DataFrame with the columns 'case_id' and 'annotation', in the row
        order of ``dataset``.
        """
        def predict_with_tta(model, images):
            """Average sigmoid outputs over identity, horizontal, vertical and double flips."""
            model.eval()
            predictions = []

            with torch.no_grad():
                # dims=None stands for the un-flipped pass.
                for dims in (None, [3], [2], [2, 3]):
                    if dims is None:
                        predictions.append(torch.sigmoid(model(images)))
                    else:
                        pred = torch.sigmoid(model(torch.flip(images, dims=dims)))
                        # Undo the flip so every prediction is aligned with the input.
                        predictions.append(torch.flip(pred, dims=dims))

            return torch.stack(predictions).mean(0)

        models = []
        for model_path in tqdm(model_paths, desc='Loading models'):
            model = smp.from_pretrained(model_path).eval().to(self.device)
            models.append(model)

        test_dataset = CustomDataset(dataset, image_size=image_size, transform=get_augmentations(is_train=False), mode=mode)
        test_loader = DataLoader(
            test_dataset,
            batch_size=16,
            shuffle=False,
            num_workers=2,
            pin_memory=True
        )

        # Positions (in dataset order) of the rows to plot.
        visualize_indices = set(random.sample(range(len(dataset)), min(10, len(dataset))))
        visualize_data = []

        predictions = []
        sample_idx = 0
        for batch_ids, batch_images, batch_original_sizes in tqdm(test_loader, desc='Running inference'):
            batch_images = batch_images.to(self.device)
            batch_size = batch_images.size(0)

            batch_fold_preds = []
            for model in models:
                if use_tta:
                    batch_pred = predict_with_tta(model, batch_images)
                else:
                    with torch.no_grad():
                        batch_pred = torch.sigmoid(model(batch_images))

                batch_fold_preds.append(batch_pred.cpu().numpy())

            # Ensemble: average probabilities across models.
            batch_final_preds = np.mean(batch_fold_preds, axis=0)

            for i in range(batch_size):
                image_id = batch_ids[i]
                # The (height, width) tuples are collated into two tensors:
                # one with all heights and one with all widths.
                original_h = batch_original_sizes[0][i].item()
                original_w = batch_original_sizes[1][i].item()

                final_pred = batch_final_preds[i, 0]
                final_pred = cv2.resize(final_pred, (original_w, original_h))
                mask = (final_pred > mask_threshold).astype(np.uint8)

                if sample_idx in visualize_indices and save_prediction_samples:
                    row = dataset.iloc[sample_idx]
                    original_image = cv2.imread(row['image_path'])
                    original_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)

                    if row['label'] == 'forged':
                        gt_mask = np.load(row['mask_path'])
                        if len(gt_mask.shape) == 3:
                            gt_mask = np.max(gt_mask, axis=0)
                        else:
                            gt_mask = np.squeeze(gt_mask)
                        gt_mask = (gt_mask > 0).astype(np.float32)
                    else:
                        gt_mask = np.zeros((original_image.shape[0], original_image.shape[1]), dtype=np.float32)

                    visualize_data.append({
                        'image_id': image_id,
                        'original_image': original_image,
                        'gt_mask': gt_mask,
                        'pred_mask': mask.astype(np.float32),
                        'label': row['label']
                    })

                # Very small masks are treated as noise.
                if mask.sum() < mask_sum_threshold:
                    annotation = 'authentic'
                else:
                    annotation = rle_encode(mask)

                predictions.append({
                    'case_id': image_id,
                    'annotation': annotation
                })

                sample_idx += 1

        if len(visualize_data) > 0 and save_prediction_samples:
            save_val_samples(visualize_data, model_paths[0])

        return pd.DataFrame(predictions)

    def get_fold_score(self, checkpoint_path, dataset, image_size=512, mask_threshold=0.5, mask_sum_threshold=20, save_prediction_samples=False):
        """Score the checkpoint at ``checkpoint_path`` on ``dataset`` with the competition metric.

        Predictions and the reference are built in the row order of
        ``dataset``, which is what ``score`` relies on.
        """
        val_submission = self.predict([checkpoint_path], dataset, image_size=image_size, mode='test', mask_threshold=mask_threshold, mask_sum_threshold=mask_sum_threshold, save_prediction_samples=save_prediction_samples)
        val_solution = get_val_solution(f"{self.dataset_path}/train_masks", dataset.case_id.values.tolist())

        # get_val_solution looks masks up by case_id only. Rows labelled
        # 'authentic' are forced to 'authentic' so the reference agrees with
        # how CustomDataset builds targets (authentic rows always get an empty
        # mask), even if a mask file with the same id contains forged regions.
        authentic_rows = dataset.label.values == 'authentic'
        if authentic_rows.any():
            val_solution.loc[authentic_rows, ['annotation', 'shape']] = 'authentic'

        return score(val_solution, val_submission, "case_id")

    def cross_validate(self, dataset, cv, model_name, encoder_name, encoder_weights, image_size, batch_size, n_epochs, learning_rate, weight_decay, es_patience, mask_threshold=0.5, mask_sum_threshold=20, save_prediction_samples=False):
        """Train and score one model per split of ``cv``.

        Checkpoints are written to
        ``<model_name>/<encoder_name>/<encoder_weights>/fold_<k>``; the
        training histories and fold scores are dumped with joblib next to
        them as ``histories.pkl`` and ``fold_scores.pkl``.

        Parameters:
        dataset: DataFrame from ``get_data``; splits are stratified on its
            'label' column.
        cv: splitter exposing ``split(X, y)`` and ``get_n_splits()``.

        Returns:
        (histories, fold_scores): one training history and one competition
        score per fold.
        """
        histories = []
        fold_scores = []

        checkpoint_path = f"{model_name}/{encoder_name}/{encoder_weights}"
        os.makedirs(checkpoint_path, exist_ok=True)

        splits = cv.split(dataset, dataset.label)
        for fold_idx, (train_idx, valid_idx) in enumerate(splits):
            print(f"Training fold {fold_idx+1}/{cv.get_n_splits()}\n")

            _train = dataset.iloc[train_idx]
            _valid = dataset.iloc[valid_idx]

            fold_checkpoint_path = f"{checkpoint_path}/fold_{fold_idx}"
            os.makedirs(fold_checkpoint_path, exist_ok=True)

            history = self.train_one_fold(fold_checkpoint_path, model_name, encoder_name, encoder_weights, _train, _valid, image_size, batch_size, n_epochs, learning_rate, weight_decay, es_patience)
            histories.append(history)

            fold_score = self.get_fold_score(fold_checkpoint_path, _valid, image_size, mask_threshold, mask_sum_threshold, save_prediction_samples)
            fold_scores.append(fold_score)

            print(f"\nFold {fold_idx+1} competition score: {fold_score:.4f}\n\n")

            gc.collect()
            torch.cuda.empty_cache()

        print(f"Final competition score: {np.mean(fold_scores):.4f}")

        joblib.dump(histories, f"{checkpoint_path}/histories.pkl")
        joblib.dump(fold_scores, f"{checkpoint_path}/fold_scores.pkl")

        return histories, fold_scores

In [None]:
# Train and score one model per fold using the settings in CFG.
trainer = Trainer(CFG.dataset_path)

histories, fold_scores = trainer.cross_validate(
    dataset=get_data(dataset_path=CFG.dataset_path),
    cv=StratifiedKFold(n_splits=CFG.n_folds, shuffle=True, random_state=CFG.seed),
    model_name=CFG.model_name,
    encoder_name=CFG.encoder_name,
    encoder_weights=CFG.encoder_weights,
    image_size=CFG.image_size,
    batch_size=CFG.batch_size,
    n_epochs=CFG.n_epochs,
    learning_rate=CFG.learning_rate,
    weight_decay=CFG.weight_decay,
    es_patience=CFG.es_patience,
    # Pass the post-processing thresholds explicitly; otherwise edits to
    # CFG are silently ignored in favour of the method defaults.
    mask_threshold=CFG.mask_threshold,
    mask_sum_threshold=CFG.mask_sum_threshold,
    save_prediction_samples=True
)

In [None]:
# Per-fold summary. The loop variable is deliberately not called `score`:
# that name is the competition metric function defined earlier in this
# notebook, and rebinding it to a float would break any later scoring call.
for fold_idx, fold_score in enumerate(fold_scores):
    print(f"Fold {fold_idx+1} competition score: {fold_score:.4f}")

print(f"\nMean competition score: {np.mean(fold_scores):.4f}")

# Visualizing the results

In [None]:
# Training curves: one panel per metric, one line per fold.
fig, axes = plt.subplots(2, 2, figsize=(15, 10))
axes = axes.flatten()

# (history key, panel title, y-axis label) in panel order.
panels = [
    ('train_loss', 'Training Loss', 'Loss'),
    ('valid_loss', 'Validation Loss', 'Loss'),
    ('train_dice', 'Training Dice Score', 'Dice Score'),
    ('valid_dice', 'Validation Dice Score', 'Dice Score'),
]

# Iterate over the histories actually produced rather than a fixed fold count.
for fold_idx, history in enumerate(histories):
    for ax, panel in zip(axes, panels):
        ax.plot(history['epoch'], history[panel[0]], label=f'Fold {fold_idx+1}', alpha=0.7)

for ax, panel in zip(axes, panels):
    ax.set_title(panel[1])
    ax.set_xlabel('Epoch')
    ax.set_ylabel(panel[2])
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()