In [1]:
import os
import cv2
from facenet_pytorch import MTCNN
import torch
from PIL import Image, ImageFile
import psutil
import time

In [2]:
# Allow processing of large images
Image.MAX_IMAGE_PIXELS = None  # Disable the decompression bomb limit
ImageFile.LOAD_TRUNCATED_IMAGES = True  # Allow loading truncated images

In [3]:
def is_valid_image(file_path):
    """
    Check if a file is a valid image.

    Args:
        file_path (str): Path to the image file.

    Returns:
        bool: True if the file is a valid image, False otherwise.
    """
    try:
        with Image.open(file_path) as img:
            img.verify()
        return True
    except Exception as e:
        print(f"Invalid image {file_path}: {e}")
        return False

In [4]:
def process_in_batches(image_paths, batch_size, mtcnn, output_folder):
    """
    Process images in batches, detecting and cropping faces.

    Args:
        image_paths (list): List of image file paths.
        batch_size (int): Number of images to process per batch.
        mtcnn (MTCNN): MTCNN face detection model.
        output_folder (str): Path to the folder where cropped faces will be saved.

    Returns:
        None
    """
    processed_images = 0
    errors = 0

    for i in range(0, len(image_paths), batch_size):
        batch = image_paths[i:i + batch_size]
        for file_path in batch:
            try:
                # Open and validate image
                with Image.open(file_path) as img:
                    if img.size[0] * img.size[1] > 89478485:  # Check for oversized images
                        print(f"Skipping large image: {file_path}, size: {img.size[0]}x{img.size[1]}")
                        continue

                if not is_valid_image(file_path):
                    print(f"Skipping invalid file: {file_path}")
                    continue

                image = Image.open(file_path).convert("RGB")
                boxes, _ = mtcnn.detect(image)
                if boxes is not None:
                    for idx, box in enumerate(boxes):
                        left, top, right, bottom = map(int, box)
                        face = image.crop((left, top, right, bottom))
                        output_path = os.path.join(output_folder, f"{os.path.basename(file_path)}_face{idx+1}.jpg")
                        face.save(output_path)
                        print(f"Cropped face saved to: {output_path}")
                else:
                    print(f"No faces detected in: {file_path}")
                processed_images += 1

                # Remove the processed image
                os.remove(file_path)
                print(f"Deleted processed file: {file_path}")

            except Exception as e:
                errors += 1
                with open("error_log.txt", "a") as log_file:
                    log_file.write(f"Error processing {file_path}: {e}\n")
                print(f"Error processing file {file_path}: {e}")
                continue

        print(f"Processed {len(batch)} images in batch. Total processed so far: {processed_images}. Errors: {errors}")

        # Monitor memory usage
        memory_info = psutil.virtual_memory()
        print(f"Memory usage: {memory_info.percent}%")
        if memory_info.percent > 90:
            print("High memory usage detected. Pausing for 30 seconds.")
            time.sleep(30)

    print(f"Total cropped face files in {output_folder}: {len(os.listdir(output_folder))}")

In [5]:
def crop_and_save_faces(input_folder, output_folder, batch_size=25):
    """
    Detect and crop faces from images in a folder.

    Args:
        input_folder (str): Path to the folder containing input images.
        output_folder (str): Path to the folder where cropped faces will be saved.
        batch_size (int): Number of images to process per batch.

    Returns:
        None
    """
    mtcnn = MTCNN(keep_all=True, device="cuda" if torch.cuda.is_available() else "cpu")
    os.makedirs(output_folder, exist_ok=True)

    # Gather all image paths
    image_paths = [
        os.path.join(root, file)
        for root, _, files in os.walk(input_folder)
        for file in files if is_valid_image(os.path.join(root, file))
    ]
    print(f"Total images found: {len(image_paths)}")

    process_in_batches(image_paths, batch_size, mtcnn, output_folder)
    print("Face cropping complete.")

In [6]:
# folder paths
if __name__ == "__main__":
    INPUT_FOLDER = "/home/natalyagrokh/img_datasets/temp_scraped_images/flickr_images_10"
    OUTPUT_FOLDER = "/home/natalyagrokh/img_datasets/curated_images/flickr_dataset_curated"
    BATCH_SIZE = 25

    crop_and_save_faces(INPUT_FOLDER, OUTPUT_FOLDER, BATCH_SIZE)

Total images found: 902
Skipping large image: /home/natalyagrokh/img_datasets/temp_scraped_images/flickr_images_10/image_6674.jpg, size: 11656x8742
Skipping large image: /home/natalyagrokh/img_datasets/temp_scraped_images/flickr_images_10/image_11137.jpg, size: 15370x10639
Skipping large image: /home/natalyagrokh/img_datasets/temp_scraped_images/flickr_images_10/image_5633.jpg, size: 8532x11164
Skipping large image: /home/natalyagrokh/img_datasets/temp_scraped_images/flickr_images_10/image_15751.jpg, size: 15370x10639
Skipping large image: /home/natalyagrokh/img_datasets/temp_scraped_images/flickr_images_10/image_14650.jpg, size: 10608x10639
Skipping large image: /home/natalyagrokh/img_datasets/temp_scraped_images/flickr_images_10/image_6723.jpg, size: 11648x7765
Skipping large image: /home/natalyagrokh/img_datasets/temp_scraped_images/flickr_images_10/image_9731.jpg, size: 15370x10639
Skipping large image: /home/natalyagrokh/img_datasets/temp_scraped_images/flickr_images_10/image_6141