In [2]:
import os
import numpy as np
from glob import glob
from skimage import io
from skimage.util import view_as_windows
import matplotlib.pyplot as plt
import cv2
from skimage.metrics import structural_similarity
from PIL import Image
import tensorflow as tf

In [18]:
def initialize_patch_extractor(input_path, output_path, patches_per_image=4, rotations=8, stride=8, mode='no_rot'):
    """
    Initializes parameters for patch extraction.
    """
    rotations_list = [0, 90, 180, 270][:rotations]  # Select rotations based on user input
    relevant_indices = [3, 6, 7, 12]
    
    authentic_images_list = glob(os.path.join(input_path, 'Au', '*'))
    authentic_images_dict = {
        path.split(os.sep)[-1][relevant_indices[0]:relevant_indices[1]] + path.split(os.sep)[-1][relevant_indices[2]:relevant_indices[3]]:
        path for path in authentic_images_list
    }

    return {
        'patches_per_image': patches_per_image,
        'stride': stride,
        'rotations': rotations_list,
        'mode': mode,
        'input_path': input_path,
        'output_path': output_path,
        'au_img_dict': authentic_images_dict,
        'background_index': [13, 21]
    }

In [24]:
def extract_patches_from_image(image, window_shape, stride, num_of_patches, rotations, output_path, im_name, rep_num, mode, patch_type):
    """
    Extracts and saves patches from an image, optionally applying rotations.
    """
    # Extract patches from the image
    windows = view_as_windows(image, window_shape, step=stride)
    patches = [window[0] for window in windows.reshape(-1, *window_shape)]

    # Select random patches
    selected_indices = np.random.choice(len(patches), num_of_patches, replace=False)

    for i, index in enumerate(selected_indices):
        patch = patches[index]
        if mode == 'rot':
            # Save rotated patches
            pil_patch = Image.fromarray(np.uint8(patch))
            for angle in rotations:
                rotated_patch = pil_patch.rotate(angle, resample=Image.BILINEAR)
                rotated_patch.save(f"{output_path}/{patch_type}/{im_name}_{i}_{angle}_{rep_num}.png")
        else:
            # Save patches without rotation
            io.imsave(f"{output_path}/{patch_type}/{im_name}_{i}_{rep_num}.png", patch)

In [16]:
def process_tampered_image(image_path, config, rep_num):
    """
    Processes a tampered image to extract and save patches.
    """
    image = io.imread(image_path)
    im_name = os.path.splitext(os.path.basename(image_path))[0]
    mask_path = os.path.join(config['input_path'], 'masks', f"{im_name}_gt.png")
    mask = io.imread(mask_path)
    image, mask = check_and_reshape(image, mask)

    tampered_patches, num_of_patches = find_tampered_patches(
        image, im_name, mask, (128, 128, 3), config['stride'], 'casia2', config['patches_per_image']
    )
    
    extract_patches_from_image(
        image, (128, 128, 3), config['stride'], num_of_patches, config['rotations'],
        config['output_path'], im_name, rep_num, config['mode'], 'tampered'
    )
    extract_authentic_patches(config, image_path, num_of_patches, rep_num)

In [15]:
def extract_authentic_patches(config, tampered_image_path, num_of_patches, rep_num):
    """
    Extracts and saves patches from an authentic image.
    """
    tampered_image_key = tampered_image_path.split(os.sep)[-1][config['background_index'][0]:config['background_index'][1]]
    
    if tampered_image_key in config['au_img_dict']:
        authentic_image_name = os.path.splitext(os.path.basename(config['au_img_dict'][tampered_image_key]))[0]
        au_image = plt.imread(config['au_img_dict'][tampered_image_key])
        
        extract_patches_from_image(
            au_image, (128, 128, 3), config['stride'], num_of_patches, config['rotations'],
            config['output_path'], authentic_image_name, rep_num, config['mode'], 'authentic'
        )

In [14]:
def extract_patches(config):
    """
    Extracts patches from tampered and authentic images.
    """
    os.makedirs(config['output_path'], exist_ok=True)
    for subdir in ['authentic', 'tampered']:
        os.makedirs(os.path.join(config['output_path'], subdir), exist_ok=True)

    tp_dir = os.path.join(config['input_path'], 'Tp')
    rep_num = 0

    for f in os.listdir(tp_dir):
        try:
            rep_num += 1
            process_tampered_image(os.path.join(tp_dir, f), config, rep_num)
        except (IOError, IndexError) as e:
            rep_num -= 1
            print(str(e))

In [13]:
def find_tampered_patches(image, im_name, mask, window_shape, stride, dataset, patches_per_image):
    """
    Finds tampered patches from an image and mask.
    """
    image_patches = view_as_windows(image, window_shape, step=stride)
    mask_patches = view_as_windows(mask, window_shape, step=stride)

    tampered_patches = []
    
    for i in range(image_patches.shape[0]):
        for j in range(image_patches.shape[1]):
            img_patch = image_patches[i, j, 0]
            mask_patch = mask_patches[i, j, 0]
            num_zeros = np.sum(mask_patch == 0)
            num_ones = np.sum(mask_patch == 255)
            total = num_ones + num_zeros
            
            if dataset == 'casia2' and num_zeros <= 0.99 * total:
                tampered_patches.append(img_patch)
            elif dataset == 'nc16' and 0.80 * total >= num_ones >= 0.20 * total:
                tampered_patches.append(img_patch)

    num_of_patches = min(len(tampered_patches), patches_per_image)
    if len(tampered_patches) < patches_per_image:
        print(f"Number of tampered patches for image {im_name} is only {len(tampered_patches)}")

    return tampered_patches, num_of_patches

In [12]:
def check_and_reshape(image, input_mask):
    """
    Ensures the mask matches the image size and format.
    """
    if input_mask.ndim == 2:
        mask = np.stack([input_mask] * 3, axis=-1)
    else:
        mask = input_mask

    if image.shape != mask.shape:
        if image.shape[0] == mask.shape[1] and image.shape[1] == mask.shape[0]:
            mask = np.reshape(mask, (image.shape[0], image.shape[1], mask.shape[2]))

    return image, mask


In [11]:
def find_mask(tampered_image_path, authentic_images_dict):
    """
    Creates and saves a mask of the image indicating the tampered region.
    """
    relevant_indices = [13, 21]
    save_name = os.path.splitext(os.path.basename(tampered_image_path))[0]
    tampered_image_key = os.path.basename(tampered_image_path)[relevant_indices[0]:relevant_indices[1]]

    if tampered_image_key in authentic_images_dict:
        authentic_image_path = authentic_images_dict[tampered_image_key]
        authentic_image = plt.imread(authentic_image_path)
        tampered_image = plt.imread(tampered_image_path)
        
        if tampered_image.shape == authentic_image.shape:
            gray_authentic_image = cv2.cvtColor(authentic_image, cv2.COLOR_BGR2GRAY)
            gray_tampered_image = cv2.cvtColor(tampered_image, cv2.COLOR_BGR2GRAY)
            
            (_, diff) = structural_similarity(gray_authentic_image, gray_tampered_image, full=True)
            diff = cv2.medianBlur(diff.astype(np.float32), 1)
            
            mask = np.where(diff < 0.98, 255, 0).astype("uint8")
            
            cv2.imwrite(f'masks/{save_name}_gt.png', mask)


In [10]:
def extract_masks():
    """
    Extract and save masks for all tampered images in the dataset.
    """
    save_dir = 'masks'
    os.makedirs(save_dir, exist_ok=True)

    authentic_image_paths = glob(os.path.join('CASIA2', 'Au', '*'))
    tampered_image_paths = glob(os.path.join('CASIA2', 'Tp', '*'))

    relevant_indices = [3, 6, 7, 12]
    authentic_images_dict = {
        path.split(os.sep)[-1][relevant_indices[0]:relevant_indices[1]] + path.split(os.sep)[-1][relevant_indices[2]:relevant_indices[3]]:
        path for path in authentic_image_paths
    }
    
    for tampered_image_path in tampered_image_paths:
        find_mask(tampered_image_path, authentic_images_dict)

In [19]:
# Configuration dictionary
config = {
    'input_path': 'CASIA2',
    'output_path': 'patches_with_rot',
    'patches_per_image': 2,
    'rotations': [0, 90, 180, 270],
    'stride': 128,
    'mode': 'rot',  # 'rot' for rotation, 'no_rot' for no rotation
    'background_index': [13, 21],  # Ensure this key is present
    'au_img_dict': {}  # This should be populated based on your image paths
}

In [29]:
config = initialize_patch_extractor( 
    input_path = 'CASIA2',
    output_path = 'patches_with_rot',
    patches_per_image = 2,
    rotations = 4,
    stride = 128,
    mode = 'rot',  # 'rot' for rotation, 'no_rot' for no rotation
)

In [22]:
extract_masks()

In [30]:
extract_patches(config)

Number of tampered patches for image Tp_S_NRD_S_N_pla20019_pla20019_02390 is only 0
Number of tampered patches for image Tp_S_NNN_S_B_art00070_art00070_01241 is only 1
Number of tampered patches for image Tp_D_NRD_S_N_sec00001_cha00042_00001 is only 1
Number of tampered patches for image Tp_D_CRN_S_N_sec00071_art00028_11281 is only 1
Number of tampered patches for image Tp_D_NRN_S_N_arc00098_ani00098_00318 is only 1
Number of tampered patches for image Tp_D_NRN_S_N_ind00022_ind00088_00438 is only 1
Number of tampered patches for image Tp_D_NRN_S_B_nat00023_cha00024_20076 is only 1
Number of tampered patches for image Tp_S_NRN_S_N_ind20010_ind20010_01751 is only 0
Number of tampered patches for image Tp_S_NNN_S_N_arc20062_arc20062_01702 is only 0
Number of tampered patches for image Tp_S_NNN_S_N_cha00002_cha00002_00837 is only 1
Number of tampered patches for image Tp_S_NRN_S_N_ani10115_ani10115_11656 is only 1
Number of tampered patches for image Tp_S_NNN_S_N_nat00091_nat00091_00993 is