Get FREE API keys:
- Pexels: https://www.pexels.com/api/ 

In [1]:
import os
import requests
from PIL import Image, ImageStat
from io import BytesIO
import time
from tqdm import tqdm

In [None]:
# ==================== CONFIGURATION ====================
TARGET_SIZE = 512
TOTAL_IMAGES = 10000 # will end early if not enough valid images found and dataset is exhausted
TRAIN_SPLIT = 0.8
SATURATION_THRESHOLD = 20 # Threshold (0-255). 0 is pure grey. 20 filters out B&W and very washed-out sepia.`

PEXELS_API_KEY = "" # Insert your Pexels API key here  
# =======================================================

# Download dataset

In [None]:
def resize_and_crop_to_square(image, size=512):
    """Resize and center crop to exact square"""
    width, height = image.size
    
    if width < height:
        new_width = size
        new_height = int(height * (size / width))
    else:
        new_height = size
        new_width = int(width * (size / height))
    
    image = image.resize((new_width, new_height), Image.LANCZOS)
    
    left = (new_width - size) // 2
    top = (new_height - size) // 2
    image = image.crop((left, top, left + size, top + size))
    
    return image


def is_greyscale(image, threshold=20):
    """
    Converts image to HSV and checks if the average saturation 
    is below the threshold.
    """
    # Convert to Hue, Saturation, Value space
    hsv_img = image.convert('HSV')
    
    # Calculate the average value of the Saturation channel (index 1)
    saturation = ImageStat.Stat(hsv_img).mean[1]
    
    # If average saturation is low, it's likely greyscale/B&W
    return saturation < threshold

def download_and_process(img_url, timeout=15):
    """Download, check color, and ensure RGB 512x512"""
    try:
        response = requests.get(img_url, timeout=timeout)
        response.raise_for_status()
        
        image = Image.open(BytesIO(response.content))
        
        if is_greyscale(image, SATURATION_THRESHOLD):
            return None, False

        
        # Convert to RGB
        if image.mode != 'RGB':
            image = image.convert('RGB')
        
        # Resize to 512x512
        image = resize_and_crop_to_square(image, TARGET_SIZE)
        
        return image, True
    except Exception as e:
        # print(f"Debug error: {e}") 
        return None, False

def download_from_pexels(api_key, total_images, train_split):
    print("\nDownloading from Pexels...")
    
    os.makedirs('dataset/train', exist_ok=True)
    os.makedirs('dataset/val', exist_ok=True)
    
    headers = {'Authorization': api_key}
    train_count = int(total_images * train_split)
    
    train_idx = 0
    val_idx = 0
    downloaded = 0
    failed = 0
    page = 1
    per_page = 80
    
    with tqdm(total=total_images, desc="Downloading") as pbar:
        while downloaded < total_images:
            try:
                url = f'https://api.pexels.com/v1/curated?per_page={per_page}&page={page}'
                response = requests.get(url, headers=headers, timeout=15)
                response.raise_for_status()
                data = response.json()
                
                if 'photos' not in data or len(data['photos']) == 0:
                    break
                
                for photo in data['photos']:
                    if downloaded >= total_images:
                        break
                    
                    # Get large image
                    img_url = photo['src']['large2x']
                    image, success = download_and_process(img_url)
                    
                    if success:
                        # Save
                        if train_idx < train_count:
                            filename = f'dataset/train/image_{train_idx + 1:06d}.jpg'
                            train_idx += 1
                        else:
                            filename = f'dataset/val/val_{val_idx + 1:06d}.jpg'
                            val_idx += 1
                        
                        image.save(filename, 'JPEG', quality=95)
                        downloaded += 1
                        pbar.update(1)
                    else:
                        failed += 1
                    
                    time.sleep(0.05)
                
                page += 1
                time.sleep(1)
                
            except Exception as e:
                print(f"\n❌ Error: {e}")
                time.sleep(5)
                continue
    
    print_summary(downloaded, train_idx, val_idx, failed)


def print_summary(downloaded, train_idx, val_idx, failed):
    print(f"\n{'='*60}")
    print(f"Download Complete!")
    print(f"{'='*60}")
    print(f"Statistics:")
    print(f"  Downloaded: {downloaded} images (512x512 RGB)")
    print(f"  Train: {train_idx} images")
    print(f"  Val: {val_idx} images")
    print(f"  Failed: {failed}")
    print(f"{'='*60}")

def main():
    if PEXELS_API_KEY:
        print("Pexels API key found")
        download_from_pexels(PEXELS_API_KEY, TOTAL_IMAGES, TRAIN_SPLIT)
    else:
        print("No API key found!")

if __name__ == "__main__":
    main()

Pexels API key found

Downloading from Pexels...


Downloading:  69%|██████▉   | 6919/10000 [2:46:41<1:14:13,  1.45s/it]


Download Complete!
Statistics:
  Downloaded: 6919 images (512x512 RGB)
  Train: 6919 images
  Val: 0 images
  Failed: 1081





# Split into test/train
Only need to run this if you exhausted the dataset and need to split the data from train to test

In [None]:
import shutil
import random
from pathlib import Path


# Define paths
train_dir = Path('dataset/train')
val_dir = Path('dataset/val')

# Create validation directory if it doesn't exist
val_dir.mkdir(parents=True, exist_ok=True)

# Get all .jpg files from train directory
jpg_files = list(train_dir.glob('*.jpg'))

print(f"Total images in train folder: {len(jpg_files)}")

# Calculate number of files to move (15%)
num_val = int(len(jpg_files) * 0.15)
print(f"Moving {num_val} images to validation folder ({num_val/len(jpg_files)*100:.1f}%)")

# Randomly select files to move
files_to_move = random.sample(jpg_files, num_val)

# Move the files
for file_path in files_to_move:
    destination = val_dir / file_path.name
    shutil.move(str(file_path), str(destination))
    
print(f"\nCompleted!")
print(f"Train folder now has: {len(list(train_dir.glob('*.jpg')))} images")
print(f"Validation folder now has: {len(list(val_dir.glob('*.jpg')))} images")

Total images in train folder: 6919
Moving 1037 images to validation folder (15.0%)

Completed!
Train folder now has: 5882 images
Validation folder now has: 1037 images


# Rename files in each folder

If you've shuffle the data and want clean names

In [5]:
def rename_images_in_folder(folder_path, prefix):
    """
    Rename all .jpg images in a folder with a given prefix and zero-padded numbering.
    
    Args:
        folder_path: Path to the folder containing images
        prefix: Prefix for the new filenames (e.g., 'train_img' or 'val_img')
    """
    folder = Path(folder_path)
    
    # Get all .jpg files
    jpg_files = sorted(folder.glob('*.jpg'))
    
    if not jpg_files:
        print(f"No .jpg files found in {folder_path}")
        return
    
    # Determine the number of digits needed for zero-padding
    num_digits = len(str(len(jpg_files)))
    num_digits = max(num_digits, 4)  # Use at least 4 digits
    
    print(f"\nRenaming {len(jpg_files)} images in '{folder_path}'...")
    
    # Create temporary names first to avoid conflicts
    temp_mapping = []
    for i, old_path in enumerate(jpg_files, start=1):
        temp_name = folder / f"temp_{i}_{old_path.name}"
        old_path.rename(temp_name)
        temp_mapping.append((temp_name, i))
    
    # Now rename to final names
    for temp_path, i in temp_mapping:
        new_name = f"{prefix}_{str(i).zfill(num_digits)}.jpg"
        new_path = folder / new_name
        temp_path.rename(new_path)
        print(f"  {temp_path.name} -> {new_name}")
    
    print(f"✓ Completed renaming in '{folder_path}'")

# Define paths
train_dir = 'dataset/train'
val_dir = 'dataset/val'

# Rename images in train folder
rename_images_in_folder(train_dir, 'train_img')

# Rename images in validation folder
rename_images_in_folder(val_dir, 'val_img')

print("\n" + "="*50)
print("All images have been renamed successfully!")
print("="*50)


Renaming 5882 images in 'dataset/train'...
  temp_1_image_000001.jpg -> train_img_0001.jpg
  temp_2_image_000002.jpg -> train_img_0002.jpg
  temp_3_image_000003.jpg -> train_img_0003.jpg
  temp_4_image_000006.jpg -> train_img_0004.jpg
  temp_5_image_000007.jpg -> train_img_0005.jpg
  temp_6_image_000008.jpg -> train_img_0006.jpg
  temp_7_image_000009.jpg -> train_img_0007.jpg
  temp_8_image_000010.jpg -> train_img_0008.jpg
  temp_9_image_000011.jpg -> train_img_0009.jpg
  temp_10_image_000012.jpg -> train_img_0010.jpg
  temp_11_image_000013.jpg -> train_img_0011.jpg
  temp_12_image_000015.jpg -> train_img_0012.jpg
  temp_13_image_000016.jpg -> train_img_0013.jpg
  temp_14_image_000017.jpg -> train_img_0014.jpg
  temp_15_image_000019.jpg -> train_img_0015.jpg
  temp_16_image_000020.jpg -> train_img_0016.jpg
  temp_17_image_000021.jpg -> train_img_0017.jpg
  temp_18_image_000022.jpg -> train_img_0018.jpg
  temp_19_image_000023.jpg -> train_img_0019.jpg
  temp_20_image_000024.jpg -> trai