In [1]:
import os

def count_images_in_folders(directory):
    """
    Recursively counts image files in a given directory and its subdirectories.
    """
    # A tuple of common image file extensions (case-insensitive)
    image_extensions = ('.png', '.jpg', '.jpeg', '.gif', '.bmp', '.tiff', '.webp')
    
    image_count = 0
    
    # Check if the provided path is a valid directory
    if not os.path.isdir(directory):
        print(f"Error: The path '{directory}' is not a valid directory.")
        return 0

    # os.walk() efficiently traverses all directories and subdirectories
    for root, _, files in os.walk(directory):
        for file in files:
            # Check if the file ends with one of the image extensions
            if file.lower().endswith(image_extensions):
                image_count += 1
                
    return image_count

# --- MAIN EXECUTION ---
if __name__ == "__main__":
    # ❗️ IMPORTANT: Change this to the path of your main folder!
    folder_path = '/Users/natalyagrokh/AI/ml_expressions/img_datasets/ferckjalfag_dataset_full/hard_case'
    
    total_images = count_images_in_folders(folder_path)
    
    print(f"✅ Total number of images found: {total_images}")

✅ Total number of images found: 20334


In [1]:
import os
import vertexai
from vertexai.generative_models import GenerativeModel, Part
from PIL import Image
from pathlib import Path
import random
from collections import defaultdict
from tqdm import tqdm
from google.api_core import client_options
from google.api_core.client_options import ClientOptions

In [2]:
# --- Configuration ---
# Set your Google Cloud project details here.
GCP_PROJECT_ID = "mlexpimgsorting-v2"  # Your new, working Project ID
GCP_LOCATION = "us-central1"

# This sets the specific environment variable the Google SDK requires
os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = "/Users/natalyagrokh/AI/img_curation/mlexpimgsorting-v2-c5a570b110c3.json"

DATASET_ROOT_DIR = Path("/Users/natalyagrokh/AI/ml_expressions/img_datasets/ferckjalfaga_dataset_adult")
OUTPUT_DIR = Path("/Users/natalyagrokh/AI/ml_expressions/img_datasets/inpained_neutral_speech")
TOTAL_SAMPLE_SIZE = 100

In [3]:
# Uses the Gemini 1.5 Pro model on Vertex AI to perform image inpainting.
def perform_inpainting(image_path: Path, mask_path: Path, prompt: str, model: GenerativeModel):
    """Processes a single image and returns the inpainted image bytes or an error string."""
    try:
        # Prepare the image and mask for the API call.
        with open(image_path, "rb") as f:
            base_image_bytes = f.read()
        with open(mask_path, "rb") as f:
            mask_image_bytes = f.read()

        base_image_part = Part.from_data(data=base_image_bytes, mime_type="image/jpeg")
        mask_image_part = Part.from_data(data=mask_image_bytes, mime_type="image/png")

        contents = [prompt, base_image_part, mask_image_part]

        # Make the API call
        response = model.generate_content(contents)

        # Extract the generated image data
        return response.candidates[0].content.parts[0].inline_data.data

    except Exception as e:
        # Return the error message if the API call fails
        return f"Error: {e}"

In [4]:
def get_image_samples(root_dir: Path, total_samples: int):
    """Scans subdirectories for images and returns a balanced, random sample."""
    images_by_category = defaultdict(list)
    for filepath in root_dir.rglob("*.jpeg"):
        category = filepath.parent.name
        images_by_category[category].append(filepath)

    num_categories = len(images_by_category)
    if num_categories == 0:
        return []

    samples_per_category = total_samples // num_categories
    
    sampled_files = []
    for category, files in images_by_category.items():
        num_to_sample = min(samples_per_category, len(files))
        if num_to_sample > 0:
            sampled_files.extend(random.sample(files, num_to_sample))
            
    return sampled_files

In [5]:
# --- MAIN EXECUTION BLOCK ---
if __name__ == "__main__":
    # 1. Initialize the Vertex AI SDK.
    vertexai.init(project=GCP_PROJECT_ID, location=GCP_LOCATION)
    
    # 2. Initialize the Model
    # Create a custom client with a longer timeout.
    api_endpoint = f"{GCP_LOCATION}-aiplatform.googleapis.com"
    client_options_config = client_options.ClientOptions(api_endpoint=api_endpoint)
    
    # Define a transport with a 600-second (10-minute) timeout.
    custom_transport = transport.grpc.GenerativeServiceGrpcTransport(
        client_options=client_options_config,
        timeout=600.0 
    )
    
    # Initialize the model using the custom transport client.
    model = GenerativeModel(
        "gemini-2.5-pro",
        transport=custom_transport
    )
    
    OUTPUT_DIR.mkdir(exist_ok=True)
    
    # 3. Get a sample of images to process
    print(f"Scanning for images in '{DATASET_ROOT_DIR}'...")
    image_samples = get_image_samples(DATASET_ROOT_DIR, TOTAL_SAMPLE_SIZE)
    
    if not image_samples:
        print("No images found. Exiting.")
        exit()
        
    print(f"Found and sampled {len(image_samples)} images for processing.")

    # 4. Define the mask and prompt
    mask_path = Path("mask.png")
    prompt = "Inpaint the masked area to remove the glasses, keeping the expression neutral."

    if not mask_path.exists():
        print(f"Creating a dummy mask file at '{mask_path}'.")
        # ... (dummy mask creation code) ...
        dummy_mask.save(mask_path)
    
    # 5. Loop through and process images
    print("\n--- Starting Inpainting Process ---")
    for image_path in tqdm(image_samples, desc="Inpainting Images"):
        result_data = perform_inpainting(image_path, mask_path, prompt, model)
        
        if isinstance(result_data, bytes):
            relative_path = image_path.relative_to(DATASET_ROOT_DIR)
            output_path = OUTPUT_DIR / relative_path
            output_path.parent.mkdir(parents=True, exist_ok=True)
            
            with open(output_path, "wb") as f:
                f.write(result_data)
            # This print statement is removed from the loop to avoid clutter
        else:
            print(f"FAILURE: Could not process {image_path.name}. Reason: {result_data}")
            
    print("\n--- Process Complete ---")

NameError: name 'transport' is not defined