In [1]:
""" Capybara image dataset generator.
This notebook downloads images of capybaras and not capybaras and produces an h5 dataset of labelled images.
"""
import os
import requests
import time
from io import BytesIO
from PIL import Image
from bs4 import BeautifulSoup
import hashlib
import logging
from tqdm import tqdm
import cv2
import numpy as np
import h5py
from sklearn.model_selection import train_test_split

def calculate_image_hash(image_bytes):
    """
    Calculate a unique hash for an image.
    """
    hasher = hashlib.md5()
    hasher.update(image_bytes)
    return hasher.hexdigest()

def download_images(query, num_images = 1, save_path = './images', timeout = 10, max_stalled_pages = 5):
    """
    Download up to num_images unique images for query from Google image search.

    Each result page is scraped for <img> tags whose src is an http(s) URL.
    Every such image is fetched, de-duplicated by the MD5 of its bytes, and
    saved as '<query>_<n>.png' in the target folder.

    Parameters:
        query: search string; it is also used in file names and, for the
            default save_path, as the sub-folder name.
        num_images: number of unique images to save.
        save_path: destination folder. With the default './images' a
            sub-folder named after the query is created inside it; any other
            value is used as the destination folder directly.
        timeout: seconds to wait on each HTTP request before giving up on it.
        max_stalled_pages: stop after this many consecutive result pages that
            add no new image (request errors, no results, or duplicates only).
            Pass None to keep trying indefinitely.

    Returns None. Request and image decoding/saving errors are logged and
    skipped rather than raised.
    """
    # Only the default save_path gets a per-query sub-folder.
    if save_path == './images':
        query_folder = os.path.join(save_path, query)
    else:
        query_folder = save_path

    os.makedirs(query_folder, exist_ok=True)

    count = 0                        # images saved so far
    downloaded_image_hashes = set()  # digests of every image already seen
    page = 0                         # result page index (20 results per page)
    stalled_pages = 0                # consecutive pages that added nothing
    headers = {'User-Agent': 'Mozilla/5.0'}

    with tqdm(total = num_images, desc = f"Downloading {query} images", unit = "image") as pbar:
        while count < num_images:
            count_before_page = count

            try:
                # Let requests build the query string so that spaces, '&',
                # '#' and similar characters in the query are escaped.
                page_response = requests.get(
                    "https://www.google.com/search",
                    params = {'q': query, 'tbm': 'isch', 'start': page * 20},
                    headers = headers,
                    timeout = timeout,
                )
                page_response.raise_for_status()
                img_tags = BeautifulSoup(page_response.text, "html.parser").find_all("img")
                if not img_tags:
                    logging.warning(f"Couldn't find any images for: {query}")
            except requests.exceptions.RequestException as e:
                logging.error(f"Error fetching images for {query}: {e}")
                img_tags = []

            for img_tag in img_tags:
                if count >= num_images:
                    break

                img_link = img_tag.get("src")
                if not img_link or not img_link.startswith("http"):
                    continue

                # A single broken image must not discard the rest of the page.
                try:
                    img_response = requests.get(img_link, timeout = timeout)
                    img_response.raise_for_status()
                except requests.exceptions.RequestException as e:
                    logging.error(f"Error fetching images for {query}: {e}")
                    continue

                image_bytes = img_response.content
                image_hash = calculate_image_hash(image_bytes)

                if image_hash in downloaded_image_hashes:
                    logging.info(f'{query} - Skipped duplicate image')
                    continue
                downloaded_image_hashes.add(image_hash)

                # Pillow signals undecodable or unwritable images with OSError
                # (UnidentifiedImageError is a subclass of it).
                try:
                    img = Image.open(BytesIO(image_bytes))
                    img.save(os.path.join(query_folder, f'{query}_{count + 1}.png'))
                except OSError as e:
                    logging.error(f"{query} - Could not decode or save image: {e}")
                    continue

                # Count the image only once it is safely on disk.
                count += 1
                logging.info(f'{query} - Downloaded image {count}')
                pbar.update(1)

            # Bail out instead of looping forever when the search is exhausted
            # or every request keeps failing.
            if count == count_before_page:
                stalled_pages += 1
                if max_stalled_pages is not None and stalled_pages >= max_stalled_pages:
                    logging.warning(
                        f"{query} - Stopping after {stalled_pages} pages without new images "
                        f"({count}/{num_images} downloaded)"
                    )
                    break
            else:
                stalled_pages = 0

            page += 1

            # Pause 2 seconds between result pages; no need to wait once done.
            if count < num_images:
                time.sleep(2)

def load_and_preprocess_images(folder):
    """
    Load every .jpeg/.jpg/.png file in folder and resize it to 64x64.

    Files are visited in sorted name order so the result is reproducible
    (os.listdir gives no ordering guarantee). The extension check is
    case-insensitive. Files that OpenCV cannot decode are logged and skipped.

    Parameters:
        folder: path of the directory to scan (not recursive).

    Returns a list of the resized images as produced by cv2.resize.
    """
    images = []
    for filename in sorted(os.listdir(folder)):
        if not filename.lower().endswith((".jpeg", ".jpg", ".png")):
            continue

        path = os.path.join(folder, filename)
        img = cv2.imread(path)
        # cv2.imread reports failure by returning None instead of raising;
        # passing that on to cv2.resize would abort the whole load.
        if img is None:
            logging.warning(f"Skipping unreadable image: {path}")
            continue

        images.append(cv2.resize(img, (64, 64)))  # Resizing to (64, 64)
    return images


if __name__ == "__main__":
    # Write INFO-and-above log records to a file next to the notebook
    logging.basicConfig(
        filename = 'image_downloader.log',
        level = logging.INFO,
        format = '%(asctime)s - %(levelname)s - %(message)s',
    )

    query = 'capybara'
    limit = 250

    # Image folders, output file and class names (index 0 = negative class)
    capybara_folder = "./images/capybara"
    non_capybara_folder = "./images/not_capybara"
    h5_dataset_filename = "capybara_dataset.h5"
    classes = ["not capybara", "capybara"]

    # Positive class: photos of the animal, excluding drawings and the like
    download_images('-cartoon -drawing -vector ' + query, num_images = limit, save_path = capybara_folder)

    # Negative class: several unrelated topics, each a fixed share of limit
    negative_topics = [
        ('animal', .1),
        ('objects', .1),
        ('city', .2),
        ('nature', .2),
        ('human', .2),
        ('food', .2),
    ]
    for topic, share in negative_topics:
        download_images(f'-{query} {topic}', num_images = int(share * limit), save_path = non_capybara_folder)

    # Read both folders back as lists of resized images
    capybara_images = load_and_preprocess_images(capybara_folder)
    non_capybara_images = load_and_preprocess_images(non_capybara_folder)

    # Capybara images come first (label 1), then the rest (label 0)
    all_images = np.array(capybara_images + non_capybara_images)
    all_labels = np.array([1] * len(capybara_images) + [0] * len(non_capybara_images))

    # Store images, labels and class names in a single H5 file
    with h5py.File(h5_dataset_filename, "w") as h5file:
        h5file.create_dataset("x", data = all_images)
        h5file.create_dataset("y", data = all_labels)
        h5file.create_dataset("list_classes", data = classes, dtype = h5py.special_dtype(vlen=str))

    print(f"H5 dataset saved as {h5_dataset_filename}")

Downloading -cartoon -drawing -vector capybara images: 100%|█| 250/250 [00:50<00
Downloading -capybara animal images: 100%|███| 25/25 [00:06<00:00,  3.97image/s]
Downloading -capybara objects images: 100%|██| 25/25 [00:06<00:00,  3.70image/s]
Downloading -capybara city images: 100%|█████| 50/50 [00:11<00:00,  4.35image/s]
Downloading -capybara nature images: 100%|███| 50/50 [00:11<00:00,  4.19image/s]
Downloading -capybara human images: 100%|████| 50/50 [00:10<00:00,  4.80image/s]
Downloading -capybara food images: 100%|█████| 50/50 [00:10<00:00,  4.66image/s]


H5 dataset saved as capybara_dataset.h5
