In [10]:
import numpy as np
from mtcnn.mtcnn import MTCNN
from PIL import Image
import os
import zipfile
import torch
from torchvision import transforms
import torchvision.transforms as transforms
import torchvision.models.segmentation as segmentation
from torchvision.models.segmentation import deeplabv3_resnet101

In [17]:
# Index of the "person" class in the 21-class Pascal VOC label set predicted by
# torchvision's pretrained DeepLabV3 weights (index 0 is background).
_PERSON_CLASS_INDEX = 15

# Square (height, width) the face crop is resized to before segmentation.
_SEGMENTATION_INPUT_SIZE = (512, 512)


def _square_crop_box(box, image_width, image_height, surroundings_factor):
    """Return (x1, y1, x2, y2) of a square crop centred on an MTCNN face box.

    The square's side is the larger face dimension enlarged by ``surroundings_factor``
    (0.7 -> 70 % extra context). The box is clipped to the image, so crops that touch
    the image border may come out non-square.
    """
    x, y, width, height = box
    side = max(int(width * (1.0 + surroundings_factor)),
               int(height * (1.0 + surroundings_factor)))
    # max(0, ...) also absorbs the negative coordinates MTCNN can report near edges.
    x1 = max(0, x + (width - side) // 2)
    y1 = max(0, y + (height - side) // 2)
    x2 = min(image_width, x + (width + side) // 2)
    y2 = min(image_height, y + (height + side) // 2)
    return x1, y1, x2, y2


def _resize_short_side(img, short_side):
    """Scale a PIL image so its shorter side is ``short_side`` pixels, keeping aspect ratio."""
    aspect_ratio = img.width / img.height
    if aspect_ratio > 1:  # wider than tall: height is the short side
        size = (int(short_side * aspect_ratio), short_side)
    else:  # taller than wide, or square
        size = (short_side, int(short_side / aspect_ratio))
    return img.resize(size)


def _person_alpha_mask(model, preprocess, face_img):
    """Return an 8-bit mask (255 = person, 0 = background) with the same size as ``face_img``."""
    input_batch = preprocess(face_img).unsqueeze(0)
    with torch.no_grad():
        logits = model(input_batch)["out"][0]
    mask = (logits.argmax(0) == _PERSON_CLASS_INDEX).to(torch.uint8) * 255
    mask_image = Image.fromarray(mask.numpy())  # 2-D uint8 -> mode "L"
    # Resize with PIL: both PIL sizes are (width, height), which avoids the
    # (h, w) vs (w, h) mix-up that transposed the mask for non-square crops.
    return mask_image.resize(face_img.size, Image.NEAREST)


def _zip_and_remove(paths, zip_path):
    """Store ``paths`` in ``zip_path`` under their bare file names, then delete the originals."""
    with zipfile.ZipFile(zip_path, "w") as zipf:
        for path in paths:
            zipf.write(path, os.path.basename(path))
    # Only delete once the archive has been fully written and closed.
    for path in paths:
        os.remove(path)


def process_and_remove_background(image_paths, output_size_inches=1, dpi=300, surroundings_factor=0.7,
                                  output_path_template="face_{}_{}.png", output_folder="output_images",
                                  zip_path="output_faces.zip"):
    """Crop every detected face from each image, make its background transparent, and save it as PNG.

    Faces are found with MTCNN. Each face box is expanded by ``surroundings_factor``
    into a square crop (clipped at the image border). The crop is scaled so its
    shorter side is ``output_size_inches * dpi`` pixels. It then gets an alpha
    channel from a DeepLabV3 "person" segmentation, so background pixels become
    fully transparent.

    Args:
        image_paths: Paths of the input images.
        output_size_inches: Target length of each crop's shorter side, in inches.
        dpi: Pixels per inch; converts ``output_size_inches`` to pixels and is
            stored in the PNG metadata.
        surroundings_factor: Extra context around the face, as a fraction of its size.
        output_path_template: File name formatted with (image number, face number), both 1-based.
        output_folder: Directory the PNGs are written to (created if missing).
        zip_path: When more than one image path is given, the PNGs written by this
            call are bundled into this archive and then removed from ``output_folder``.
    """
    # DeepLabV3 semantic-segmentation model with a ResNet-50 backbone, loaded once.
    # NOTE(review): newer torchvision versions prefer `weights=` over `pretrained=True`.
    model = torch.hub.load("pytorch/vision", "deeplabv3_resnet50", pretrained=True)
    model.eval()

    # Loop-invariant helpers, built once for the whole batch.
    detector = MTCNN()
    preprocess = transforms.Compose([
        transforms.Resize(_SEGMENTATION_INPUT_SIZE),
        transforms.ToTensor(),
        # ImageNet statistics expected by the pretrained backbone.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    os.makedirs(output_folder, exist_ok=True)
    # PIL needs integer sizes; a fractional output_size_inches would otherwise give floats.
    output_size_pixels = int(round(output_size_inches * dpi))
    saved_paths = []

    for image_index, image_path in enumerate(image_paths):
        # Close the file promptly and force 3-channel RGB: MTCNN and the 3-channel
        # Normalize both fail on RGBA / grayscale / palette inputs.
        with Image.open(image_path) as image:
            img_array = np.array(image.convert("RGB"))

        face_locations = detector.detect_faces(img_array)
        if not face_locations:
            print(f"No faces found in image {image_index + 1}.")
            continue

        img_height, img_width = img_array.shape[:2]
        for i, face_location in enumerate(face_locations):
            x1, y1, x2, y2 = _square_crop_box(face_location['box'], img_width, img_height,
                                              surroundings_factor)
            if x2 <= x1 or y2 <= y1:
                # Degenerate box (e.g. entirely outside the image): nothing to crop.
                print(f"Skipping empty crop for face {i + 1} in image {image_index + 1}.")
                continue

            face_img = _resize_short_side(Image.fromarray(img_array[y1:y2, x1:x2]), output_size_pixels)

            # Segment the in-memory crop directly; no need to write and re-read it.
            rgba_image = face_img.convert("RGBA")
            rgba_image.putalpha(_person_alpha_mask(model, preprocess, face_img))

            output_path = os.path.join(output_folder, output_path_template.format(image_index + 1, i + 1))
            rgba_image.save(output_path, dpi=(dpi, dpi))
            saved_paths.append(output_path)
            print(f"Face {i+1} in image {image_index + 1} saved at {output_path}.")

    # Bundle this run's outputs when several images were processed. Only files
    # written by this call are zipped and deleted, never other files in the folder.
    if len(image_paths) > 1 and saved_paths:
        _zip_and_remove(saved_paths, zip_path)
        print(f"Output images have been zipped to '{zip_path}'.")
        print("Unzipped output images have been deleted.")

In [18]:
# Example run: six consecutively numbered photos from the Photos/ folder
image_paths = [
    f"Photos/photo_{photo_id}_y.jpg"
    for photo_id in range(6204012424814769816, 6204012424814769822)
]
process_and_remove_background(
    image_paths,
    surroundings_factor=0.7,
    output_path_template="face_{}_{}.png",
)

Using cache found in /home/msds2023/jrjimenez/.cache/torch/hub/pytorch_vision_main


Face 1 in image 1 saved at output_images/face_1_1.png.
Face 1 in image 2 saved at output_images/face_2_1.png.
Face 1 in image 3 saved at output_images/face_3_1.png.
Face 1 in image 4 saved at output_images/face_4_1.png.
Face 1 in image 5 saved at output_images/face_5_1.png.
Face 1 in image 6 saved at output_images/face_6_1.png.
Output images have been zipped to 'output_faces.zip'.
Unzipped output images have been deleted.
