In [1]:
import os
import json
import logging
import requests
from urllib.parse import urlparse
import mimetypes

In [2]:
from PIL import Image
from collections import defaultdict
import shutil
import numpy as np
import sqlite3
import pickle
from scipy.spatial.distance import cosine
from transformers import ViTFeatureExtractor, ViTModel
import torch

In [3]:
# Load the image preprocessor and the ViT backbone from the same pretrained
# checkpoint, preprocessor first and model second.
feature_extractor, vit_model = (
    loader.from_pretrained("google/vit-base-patch16-224-in21k")
    for loader in (ViTFeatureExtractor, ViTModel)
)



In [4]:
# Route INFO-and-above records to both a log file and the notebook output stream.
logging.basicConfig(
    handlers=[logging.FileHandler("image_fetch.log"), logging.StreamHandler()],
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

In [5]:
# Google Custom Search credentials come from the environment only.
# SECURITY: real credentials used to be hardcoded here as fallback values. Anything
# that was committed or shared must be treated as leaked and rotated.
GI_API_KEY = os.getenv('API_KEY')
GI_SEARCH_ENGINE_ID = os.getenv('SEARCH_ENGINE_ID')

if not (GI_API_KEY and GI_SEARCH_ENGINE_ID):
    # Warn instead of raising so the notebook remains usable with the other sources.
    logging.warning(
        "Google Custom Search credentials are not configured; set the API_KEY and "
        "SEARCH_ENGINE_ID environment variables to use the 'Google Image' source."
    )

In [6]:
# Pexels credential comes from the environment only.
# PEXELS_API_KEY is preferred so it no longer collides with the Google key, which is
# also read from API_KEY; API_KEY is still honoured as a fallback for existing setups.
# SECURITY: a real key used to be hardcoded here as a fallback value; treat it as
# leaked and rotate it.
PEXELS_API_KEY = os.getenv('PEXELS_API_KEY') or os.getenv('API_KEY')

if not PEXELS_API_KEY:
    # Warn instead of raising so the notebook remains usable with the other sources.
    logging.warning(
        "Pexels API key is not configured; set the PEXELS_API_KEY environment "
        "variable to use the 'Pexels' source."
    )

In [None]:
def print_image_counts(image_output_folder):
    """
    Report how many files each category subfolder currently holds.

    Categories are taken from the module-level ``categories`` list. A category
    whose subfolder is missing is skipped silently and contributes nothing to
    the totals.

    Args:
        image_output_folder (str): Base folder containing one subfolder per category.
    """
    print("\nCurrent Image Counts:")

    # Nothing to report when the base folder has not been created yet.
    if not os.path.exists(image_output_folder):
        print(f"Output folder {image_output_folder} does not exist.")
        return

    total_images = 0
    categories_with_images = 0

    for category in categories:
        folder = os.path.join(image_output_folder, category)
        if not os.path.exists(folder):
            continue

        # Only regular files count; nested directories are ignored.
        image_count = sum(
            1
            for entry in os.listdir(folder)
            if os.path.isfile(os.path.join(folder, entry))
        )

        print(f"{category}: {image_count} images")
        total_images += image_count
        categories_with_images += 1 if image_count > 0 else 0

    print(f"\nTotal Images: {total_images}")
    print(f"Categories with Images: {categories_with_images}/{len(categories)}")

In [7]:
def init_database(db_path="image_embeddings.db"):
    """
    Open the SQLite database at ``db_path`` and make sure the ``embeddings`` table exists.

    Args:
        db_path (str): Path to the database file. Defaults to "image_embeddings.db".

    Returns:
        sqlite3.Connection: An open connection; the caller is responsible for closing it.
    """
    create_table_sql = """
        CREATE TABLE IF NOT EXISTS embeddings (
            id INTEGER PRIMARY KEY,
            url TEXT UNIQUE,
            embedding BLOB
        )
    """
    connection = sqlite3.connect(db_path)
    connection.cursor().execute(create_table_sql)
    connection.commit()
    return connection

def copy_existing_data(input_folder, output_folder, db_file):
    """Copy existing images and database to a writable output folder.

    The output folder is always created, so the returned database path points
    into an existing directory even when there is nothing to copy.

    Args:
        input_folder (str): Folder holding previously collected data (may not exist).
        output_folder (str): Writable folder that receives the copy.
        db_file (str): File name of the database, relative to both folders.

    Returns:
        str: Path of the database inside ``output_folder``. The file itself exists
        only if it was present in ``input_folder`` (or already in ``output_folder``).
    """
    db_output = os.path.join(output_folder, db_file)
    os.makedirs(output_folder, exist_ok=True)

    if not os.path.exists(input_folder):
        return db_output

    # Copying a tree onto itself makes shutil raise (SameFileError per file), so an
    # identical input/output folder is treated as "already in place".
    if os.path.samefile(input_folder, output_folder):
        return db_output

    shutil.copytree(input_folder, output_folder, dirs_exist_ok=True)

    # Kept from the original flow: explicitly (re)copy the database file on top of
    # whatever copytree produced.
    db_input = os.path.join(input_folder, db_file)
    if os.path.exists(db_input):
        shutil.copy(db_input, db_output)
    return db_output

def save_embedding_to_db(conn, url, embedding):
    """Persist one image embedding, keyed by its source URL.

    The embedding is pickled into a BLOB. If the URL is already stored, the
    resulting integrity error is logged at INFO level and nothing is written.

    Args:
        conn (sqlite3.Connection): Connection to the database.
        url (str): URL of the image.
        embedding (numpy.ndarray): The image embedding to save.
    """
    insert_sql = "INSERT INTO embeddings (url, embedding) VALUES (?, ?)"
    row = (url, pickle.dumps(embedding))
    db_cursor = conn.cursor()
    try:
        db_cursor.execute(insert_sql, row)
        conn.commit()
    except sqlite3.IntegrityError:
        logging.info(f"URL already exists in the database: {url}")

def is_similar_to_existing(conn, new_embedding, threshold=0.5):
    """
    Check whether the given embedding is close to any embedding stored in the database.

    Closeness is measured with cosine distance (``1 - cosine similarity``). An
    existing embedding counts as a match when its distance to ``new_embedding``
    is strictly below ``threshold``.

    Args:
        conn (sqlite3.Connection): Connection to the database.
        new_embedding (numpy.ndarray): The embedding to check.
        threshold (float, optional): Maximum cosine *distance* for two embeddings to
            be considered similar (equivalently, similarity must exceed
            ``1 - threshold``). Defaults to 0.5.
            NOTE(review): with the default, any pair whose cosine similarity is above
            0.5 is treated as a duplicate; confirm this is strict enough for these
            embeddings, since the run log shows most downloads being rejected.

    Returns:
        bool: True if a stored embedding is within ``threshold``, False otherwise
        (including when the table is empty).
    """
    cursor = conn.cursor()
    cursor.execute("SELECT embedding FROM embeddings")

    # Iterate the cursor instead of fetchall() so rows are streamed and the scan can
    # stop at the first match without materialising the whole table.
    for (embedding_blob,) in cursor:
        # SECURITY: pickle.loads can execute arbitrary code. This is only acceptable
        # because the blobs are written by save_embedding_to_db; never point this at
        # a database from an untrusted source.
        existing_embedding = pickle.loads(embedding_blob)
        if cosine(new_embedding, existing_embedding) < threshold:
            return True
    return False

def get_image_embedding(image_path):
    """
    Generate an embedding for an image using a Vision Transformer (ViT).

    Relies on the module-level ``feature_extractor`` and ``vit_model``.

    Args:
        image_path (str): Path to the image file for which the embedding is to be generated.

    Returns:
        numpy.ndarray: A 1D array representing the image embedding generated by the ViT model.
    """
    # Open through a context manager so the file handle is released before the
    # caller moves or deletes the file; convert() returns an independent RGB copy.
    with Image.open(image_path) as source_image:
        image = source_image.convert("RGB")

    inputs = feature_extractor(images=image, return_tensors="pt")

    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = vit_model(**inputs)
        embedding = outputs.last_hidden_state[:, 0, :]  # [CLS] token embedding

    # Drop the batch dimension and hand back a plain numpy array.
    return embedding.squeeze().numpy()

def load_page_state(state_file="page_state.json"):
    """
    Read the crawl position recorded for each category.

    Args:
        state_file (str): Path to the state file. Defaults to "page_state.json".

    Returns:
        dict: Parsed JSON content of the state file, or an empty dict when the
        file does not exist.
    """
    if not os.path.exists(state_file):
        return {}
    with open(state_file, "r") as handle:
        return json.load(handle)

def save_page_state(page_state, state_file="page_state.json"):
    """
    Write the crawl position of every category to disk as JSON.

    Args:
        page_state (dict): Mapping from category to its last fetched position.
        state_file (str): Path to the state file. Defaults to "page_state.json".
    """
    with open(state_file, mode="w") as handle:
        json.dump(page_state, handle)

def fetch_images_from_google_image(category, num_results, state_file="page_state.json", api_key=GI_API_KEY, search_engine_id=GI_SEARCH_ENGINE_ID):
    """
    Fetch image URLs for a given category using the Google Custom Search API, starting from the last saved page state.

    The saved state is the 1-based ``start`` index of the next result to request;
    it is advanced by the number of items actually returned and written back
    after the loop.

    Args:
        category (str): The search category or query term.
        num_results (int): The total number of image URLs to fetch.
        state_file (str, optional): Path to the state file.
        api_key (str, optional): The API key for Google Custom Search.
        search_engine_id (str, optional): The search engine ID for Google Custom Search.

    Returns:
        list[str]: A list of image URLs fetched for the given category (at most
        ``num_results``; fewer if the API runs out of items).

    Raises:
        requests.RequestException: On network failure, timeout, or an HTTP error
            status. In that case the page state is not updated.
    """
    page_state = load_page_state(state_file)
    start = page_state.get(category, 1)

    url = "https://www.googleapis.com/customsearch/v1"
    image_urls = []

    # NOTE(review): the Custom Search JSON API is believed to serve only the first
    # ~100 results per query; once the saved start index passes that, requests are
    # expected to fail with an HTTP error - confirm and decide whether to reset.
    while len(image_urls) < num_results:
        params = {
            'key': api_key,
            'cx': search_engine_id,
            'q': category,
            'searchType': 'image',
            'num': min(10, num_results - len(image_urls)),
            'start': start,
        }
        # A timeout keeps a stalled connection from hanging the whole run; 10s
        # matches the image download timeout used in process_image_batch.
        response = requests.get(url, params=params, timeout=10)
        response.raise_for_status()
        items = response.json().get('items', [])
        if not items:
            break
        image_urls.extend(item['link'] for item in items)
        start += len(items)

    # Update and save page state
    page_state[category] = start
    save_page_state(page_state, state_file)

    return image_urls[:num_results]

def fetch_images_from_pexels(category, num_results, state_file="pexels_page_state.json", api_key=PEXELS_API_KEY):
    """
    Fetch image URLs for a given category using the Pexels API, starting from the last saved page state.

    Every request uses the same page size. Page numbers only identify a stable
    slice of the result list when ``per_page`` is constant; shrinking it for the
    final request (as an earlier version did) makes that page overlap results
    that were already returned.

    Args:
        category (str): The search category or query term.
        num_results (int): The total number of image URLs to fetch.
        state_file (str, optional): Path to the state file. Defaults to "pexels_page_state.json".
        api_key (str, optional): The API key for Pexels API.

    Returns:
        list[str]: A list of image URLs fetched for the given category (at most
        ``num_results``). Surplus URLs from the last page are discarded.

    Raises:
        requests.RequestException: On network failure, timeout, or an HTTP error
            status. In that case the page state is not updated.
    """
    headers = {"Authorization": api_key}
    base_url = "https://api.pexels.com/v1/search"
    per_page = 80  # largest page size this code has ever requested; kept constant

    page_state = load_page_state(state_file)
    page = page_state.get(category, 1)
    image_urls = []

    while len(image_urls) < num_results:
        params = {
            "query": category,
            "per_page": per_page,
            "page": page,
        }
        # A timeout keeps a stalled connection from hanging the whole run; 10s
        # matches the image download timeout used in process_image_batch.
        response = requests.get(base_url, headers=headers, params=params, timeout=10)
        response.raise_for_status()

        photos = response.json().get("photos", [])
        if not photos:
            break

        # Collect image URLs
        image_urls.extend(photo["src"]["original"] for photo in photos)
        page += 1  # Move to the next page

    # Save the updated page state
    page_state[category] = page
    save_page_state(page_state, state_file)

    return image_urls[:num_results]

def fetch_images_from_openverse(category, num_results, state_file="openverse_page_state.json"):
    """
    Fetch image URLs for a given category using the Openverse API.

    The page size is chosen once per call and reused for every request, because
    changing it between pages shifts the page boundaries and makes later pages
    overlap earlier ones. Request errors are logged and end the fetch early; the
    URLs collected so far are still returned and the page state is still saved.

    Args:
        category (str): Search category or query term.
        num_results (int): Total number of image URLs to fetch.
        state_file (str, optional): Path to the state file. Defaults to "openverse_page_state.json".

    Returns:
        list[str]: List of image URLs fetched for the given category (at most
        ``num_results``).
    """
    base_url = "https://api.openverse.engineering/v1/images"

    page_state = load_page_state(state_file)
    page = page_state.get(category, 1)
    image_urls = []

    # NOTE(review): Openverse may enforce a smaller page_size for unauthenticated
    # clients - confirm that 100 is accepted.
    page_size = min(100, num_results)

    while len(image_urls) < num_results:
        params = {
            "q": category,
            "page": page,
            "page_size": page_size,
            "license": "cc0,by",
        }

        try:
            # A timeout keeps a stalled connection from hanging the whole run; 10s
            # matches the image download timeout used in process_image_batch.
            response = requests.get(base_url, params=params, timeout=10)
            response.raise_for_status()

            results = response.json().get("results", [])
            if not results:
                break

            image_urls.extend(item["url"] for item in results if "url" in item)
            page += 1

        except Exception as e:
            # Deliberately best-effort: keep what was collected so far.
            logging.error(f"Error fetching images from Openverse: {e}")
            break

    # Save updated page state
    page_state[category] = page
    save_page_state(page_state, state_file)

    return image_urls[:num_results]

def _unique_file_name(folder, url, content_type):
    """Derive a file name for ``url`` that has an extension and does not clash with a file already in ``folder``."""
    base_name = os.path.basename(urlparse(url).path) or 'image'
    stem, ext = os.path.splitext(base_name)

    if not ext:
        # guess_extension returns None for unknown MIME types; fall back to .jpg.
        mime_type = content_type.split(';')[0].strip()
        ext = (mimetypes.guess_extension(mime_type) if mime_type else None) or '.jpg'

    candidate = stem + ext
    suffix = 1
    while os.path.exists(os.path.join(folder, candidate)):
        candidate = f"{stem}_{suffix}{ext}"
        suffix += 1
    return candidate


def process_image_batch(urls, output_folder, db_path, category, target_count, fetch_func, state_file):
    """
    Process a batch of image URLs with deduplication checks and fetch more if needed.

    Each URL is downloaded to a temporary file, embedded, and compared against
    the embeddings already stored in the database. Images that are not similar to
    any stored one are moved into the category folder and recorded in the
    database. Failures for a single URL are printed and skipped.

    Args:
        urls (list): List of image URLs to process. Consumed in place (items are
            popped from the front and refill URLs are appended).
        output_folder (str): Base output folder path
        db_path (str): Path to embeddings database
        category (str): Category name for subfolder organization
        target_count (int): Desired number of unique images
        fetch_func (callable): Function to fetch more image URLs; called with
            ``category``, ``num_results`` and ``state_file`` keyword arguments.
        state_file (str): Path to the state file for tracking pages

    Returns:
        int: Number of new images saved for this category.
    """
    # Ensure category folder exists
    category_folder = os.path.join(output_folder, category)
    os.makedirs(category_folder, exist_ok=True)

    # Create temporary folder for downloads
    temp_dir = os.path.join(output_folder, 'temp')
    os.makedirs(temp_dir, exist_ok=True)
    temp_path = os.path.join(temp_dir, 'temp_image.jpg')

    # Initialize database connection
    conn = init_database(db_path)

    processed_count = 0
    processed_urls = set()

    # try/finally so the connection and the temp folder are released even when the
    # run is interrupted (e.g. KeyboardInterrupt in the notebook).
    try:
        while processed_count < target_count and urls:
            url = urls.pop(0)
            if url in processed_urls:
                continue

            processed_urls.add(url)

            try:
                # Download to temp file
                response = requests.get(url, stream=True, timeout=10)
                response.raise_for_status()

                with open(temp_path, 'wb') as f:
                    f.write(response.content)

                # Generate embedding
                embedding = get_image_embedding(temp_path)

                # Check for duplicates
                if not is_similar_to_existing(conn, embedding):
                    # Pick a name that has an extension and never overwrites an
                    # image saved earlier (different URLs can share a basename).
                    file_name = _unique_file_name(
                        category_folder, url, response.headers.get('Content-Type', '')
                    )

                    # Save image and embedding
                    final_path = os.path.join(category_folder, file_name)
                    shutil.move(temp_path, final_path)
                    save_embedding_to_db(conn, url, embedding)
                    processed_count += 1
                    print(f"Saved new image: {file_name} ({processed_count}/{target_count})")
                else:
                    print(f"Skipped duplicate from URL: {url}")
                    os.remove(temp_path)

                    # If we're running low on URLs, fetch more
                    if len(urls) < (target_count - processed_count):
                        additional_urls = fetch_func(
                            category=category,
                            num_results=target_count - processed_count,
                            state_file=state_file
                        )
                        urls.extend(
                            candidate for candidate in additional_urls
                            if candidate not in processed_urls
                        )
                        print(f"Fetched {len(additional_urls)} additional URLs")

            except Exception as e:
                # Best-effort per URL: report, drop any partial download, continue.
                print(f"Error processing {url}: {e}")
                if os.path.exists(temp_path):
                    os.remove(temp_path)
    finally:
        # Cleanup
        conn.close()
        shutil.rmtree(temp_dir, ignore_errors=True)

    return processed_count

In [8]:
# Search terms to collect, grouped by theme. Order matters: categories are
# processed in exactly this sequence.
categories = (
    # Kitchen
    ["bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl"]
    # Indoor
    + ["book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush"]
    # Vehicle
    + ["bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat"]
    # Animal
    + ["bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe"]
)

In [9]:
def daily_run(image_source, KAGGLE_INPUT, KAGGLE_OUTPUT, num_results_per_category):
    """
    Top up every category in ``categories`` to ``num_results_per_category`` images.

    For each category the files already present in its output subfolder are
    counted, and only the shortfall is fetched from the selected source and
    passed through ``process_image_batch`` for deduplication.

    Args:
        image_source (str): One of 'Pexels', 'Google Image' or 'Openverse'.
        KAGGLE_INPUT (str): Folder that may hold an existing "image_embeddings.db".
        KAGGLE_OUTPUT (str): Writable folder for the database and the "images" tree.
        num_results_per_category (int): Target number of images per category.

    Raises:
        ValueError: If ``image_source`` is not a supported source.
    """
    # Select fetch function based on source
    fetch_functions = {
        'Pexels': fetch_images_from_pexels,
        'Google Image': fetch_images_from_google_image,
        'Openverse': fetch_images_from_openverse
    }

    # State files for different sources
    state_files = {
        'Pexels': "page_state_pexels.json",
        'Google Image': "page_state_gi.json",
        'Openverse': "page_state_openverse.json"
    }

    # Validate before touching the filesystem or the database.
    if image_source not in fetch_functions:
        raise ValueError(f"Invalid image source. Choose from: {list(fetch_functions.keys())}")

    fetch_func = fetch_functions[image_source]
    state_file = state_files[image_source]

    # Paths
    db_path = os.path.join(KAGGLE_OUTPUT, "image_embeddings.db")
    image_output_folder = os.path.join(KAGGLE_OUTPUT, "images")

    # sqlite3 cannot create a database inside a missing directory.
    os.makedirs(KAGGLE_OUTPUT, exist_ok=True)

    # Copy existing database if it exists
    existing_db = os.path.join(KAGGLE_INPUT, "image_embeddings.db")
    if os.path.exists(existing_db) and not os.path.exists(db_path):
        shutil.copy(existing_db, db_path)
        print("Copied existing database")

    # Make sure the schema exists before any category is processed.
    conn = init_database(db_path)
    conn.close()

    # Process each category
    for category in categories:
        # Count existing images in the category folder. On a fresh output
        # directory the folder does not exist yet, which simply means zero images.
        category_folder = os.path.join(image_output_folder, category)
        if os.path.isdir(category_folder):
            existing_images = sum(
                1
                for entry in os.listdir(category_folder)
                if os.path.isfile(os.path.join(category_folder, entry))
            )
        else:
            existing_images = 0

        # Calculate remaining images to fetch
        remaining_images = max(0, num_results_per_category - existing_images)

        if remaining_images <= 0:
            print(f"{category} already has {existing_images} images. Skipping.")
            continue

        print(f"\nProcessing category: {category}")
        print(f"Existing images: {existing_images}, Remaining to fetch: {remaining_images}")

        urls = fetch_func(
            category=category,
            num_results=remaining_images,
            state_file=state_file
        )
        print(f"Fetched {len(urls)} URLs for {category}")

        # Process images with deduplication and additional fetching if needed
        processed_count = process_image_batch(
            urls=urls,
            output_folder=image_output_folder,
            db_path=db_path,
            category=category,
            target_count=remaining_images,
            fetch_func=fetch_func,
            state_file=state_file
        )

        print(f"Completed processing {category}: {processed_count}/{remaining_images} unique images saved")

In [10]:
# Target number of images to hold per category once a run completes.
num_results_per_category = 1_000

In [11]:
# Dataset locations. The defaults are the original local paths; set the
# VQA_DATASET_INPUT / VQA_DATASET_OUTPUT environment variables to run the
# notebook on another machine without editing this cell.
KAGGLE_INPUT = os.getenv("VQA_DATASET_INPUT", "D:/Project/VQA/vqa_dataset")
KAGGLE_OUTPUT = os.getenv("VQA_DATASET_OUTPUT", "D:/Project/VQA/vqa_dataset")

# Which image provider daily_run should query.
image_source = 'Pexels'

Supported values for `image_source`: `'Pexels'`, `'Google Image'`, or `'Openverse'`.

In [None]:
# Snapshot of the per-category image counts before the run.
print_image_counts(image_output_folder=os.path.join(KAGGLE_OUTPUT, "images"))

In [12]:
# Top up every category from the configured source; arguments are passed in
# the same order as daily_run's parameters.
daily_run(image_source, KAGGLE_INPUT, KAGGLE_OUTPUT, num_results_per_category)


Processing category: bottle
Fetched 300 URLs for bottle
Skipped duplicate from URL: https://images.pexels.com/photos/1188649/pexels-photo-1188649.jpeg
Fetched 300 additional URLs
Skipped duplicate from URL: https://images.pexels.com/photos/1000084/pexels-photo-1000084.jpeg
Skipped duplicate from URL: https://images.pexels.com/photos/1435598/pexels-photo-1435598.jpeg
Skipped duplicate from URL: https://images.pexels.com/photos/1342529/pexels-photo-1342529.jpeg
Skipped duplicate from URL: https://images.pexels.com/photos/932577/pexels-photo-932577.jpeg
Skipped duplicate from URL: https://images.pexels.com/photos/1407846/pexels-photo-1407846.jpeg
Skipped duplicate from URL: https://images.pexels.com/photos/2421953/pexels-photo-2421953.jpeg
Skipped duplicate from URL: https://images.pexels.com/photos/159403/bottle-liquid-clean-equipment-159403.jpeg
Skipped duplicate from URL: https://images.pexels.com/photos/1479706/pexels-photo-1479706.jpeg
Skipped duplicate from URL: https://images.pexe



Skipped duplicate from URL: https://images.pexels.com/photos/15061732/pexels-photo-15061732.png
Skipped duplicate from URL: https://images.pexels.com/photos/15328478/pexels-photo-15328478.jpeg
Skipped duplicate from URL: https://images.pexels.com/photos/16065311/pexels-photo-16065311.jpeg
Skipped duplicate from URL: https://images.pexels.com/photos/15538110/pexels-photo-15538110.jpeg
Skipped duplicate from URL: https://images.pexels.com/photos/20431435/pexels-photo-20431435.jpeg
Skipped duplicate from URL: https://images.pexels.com/photos/25216807/pexels-photo-25216807.jpeg
Skipped duplicate from URL: https://images.pexels.com/photos/25155253/pexels-photo-25155253.jpeg
Skipped duplicate from URL: https://images.pexels.com/photos/29148452/pexels-photo-29148452.jpeg
Skipped duplicate from URL: https://images.pexels.com/photos/27489058/pexels-photo-27489058.jpeg
Skipped duplicate from URL: https://images.pexels.com/photos/29675497/pexels-photo-29675497.jpeg
Skipped duplicate from URL: htt

KeyboardInterrupt: 

In [None]:
# Snapshot of the per-category image counts after the run.
print_image_counts(image_output_folder=os.path.join(KAGGLE_OUTPUT, "images"))