# In [None]:
import os
import time
import requests
import base64
import json
from selenium import webdriver
from selenium.webdriver.chrome.service import Service
from selenium.webdriver.common.by import By
from selenium.webdriver.support.wait import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC
from urllib.parse import urlparse, parse_qs

def get_image_extension(image_url):
    """Return the file extension (without a dot) that matches ``image_url``.

    For a data URL (``data:<mediatype>[;base64],<payload>``) the media type
    is read from the header before the first comma and mapped to an
    extension. Any other input goes through the legacy lookup of an
    ``image`` field in the URL path.

    Args:
        image_url: A data URL or a regular URL string.

    Returns:
        ``"jpg"``, ``"png"`` or ``"gif"``; ``"jpg"`` when the media type is
        missing or not in the table.
    """
    extensions = {
        "image/jpeg": "jpg",
        "image/png": "png",
        "image/gif": "gif",
        # Add more MIME types and extensions as needed
    }

    if isinstance(image_url, str) and image_url[:5].lower() == "data:":
        # Parsing the data URL as a query string never produces an "image"
        # key, so take the media type straight from the header instead.
        header = image_url[5:].split(",", 1)[0]
        mime_type = header.split(";", 1)[0].strip().lower()
    else:
        mime_type = parse_qs(urlparse(image_url).path).get("image", [""])[0]

    return extensions.get(mime_type, "jpg")

def download_image(image_url, save_path, timeout=30):
    """Save the image referenced by ``image_url`` to ``save_path`` plus an extension.

    Data URLs (``data:image...``) are base64-decoded and written with the
    extension reported by ``get_image_extension``. Any other URL is fetched
    over HTTP and always written as ``<save_path>.jpg``.

    Args:
        image_url: Data URL or HTTP(S) URL of the image. ``None`` or an
            empty string is ignored.
        save_path: Destination path without a file extension.
        timeout: Seconds to wait for the HTTP server before giving up.

    Raises:
        ValueError: If a data URL has no comma separating header and payload.
        requests.RequestException: On network failures or timeouts.
    """
    if not image_url:
        # An element without a usable URL yields None or ""; nothing to save.
        return

    if image_url.startswith('data:image'):
        # Extract the base64 data from the URL
        _, encoded_data = image_url.split(',', 1)
        image_data = base64.b64decode(encoded_data)

        # Get the original image extension
        image_extension = get_image_extension(image_url)
        with open(f"{save_path}.{image_extension}", 'wb') as file:
            file.write(image_data)
        return

    # The context manager releases the streamed connection even when the
    # body is not fully read; the timeout prevents a stalled server from
    # blocking the run indefinitely.
    with requests.get(image_url, stream=True, timeout=timeout) as response:
        if response.status_code != 200:
            print(f"Skipping {image_url}: HTTP status {response.status_code}")
            return
        with open(f"{save_path}.jpg", 'wb') as file:
            for chunk in response.iter_content(1024):
                file.write(chunk)

def create_folder_if_not_exists(folder_name):
    """Create ``folder_name`` (including missing parents) unless it already exists.

    Args:
        folder_name: Path of the directory to create.

    Raises:
        FileExistsError: If ``folder_name`` exists but is not a directory.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(folder_name, exist_ok=True)

def save_image_records(image_records, output_path):
    """Write ``image_records`` as JSON to ``output_path``.

    The file is opened in write mode, so any previous content is replaced.
    """
    with open(output_path, mode='w') as handle:
        json.dump(obj=image_records, fp=handle)

def load_image_records(input_path):
    """Read the JSON file at ``input_path`` and return the decoded object."""
    with open(input_path, mode='r') as handle:
        return json.load(handle)

# Path to the ChromeDriver executable on your computer
# (get ChromeDriver here: https://sites.google.com/chromium.org/driver/downloads)
chrome_driver_path = ""

# URL the browser opens to start a search
google_url = "https://www.google.com/"

# Folder containing the images that you want to search for
images_folder = ""

# Path where the records are saved (eg: '/Users/yourname/projectname/record.json')
image_records_path = ""

# Path where the last queried info is saved
# (eg: '/Users/yourname/projectname/last_queried_info.json')
last_queried_info_path = ""

def save_last_image(idx, image_file):
    """Persist the index and file name of the most recently queried image.

    Writes a JSON object with the keys ``last_queried_index`` and
    ``last_queried_image`` to the module-level ``last_queried_info_path``,
    replacing whatever the file held before.
    """
    with open(last_queried_info_path, 'w') as handle:
        json.dump(
            {
                'last_queried_index': idx,
                'last_queried_image': image_file,
            },
            handle,
        )

def _search_by_image(driver, image_path):
    """Run a Google search-by-image for ``image_path`` and return the result ``<img>`` elements."""
    # Open Google
    driver.get(google_url)

    # Click the camera button to search by image
    search_by_img_btn = driver.find_element(By.XPATH, '//div[@aria-label="Search by image"]')
    search_by_img_btn.click()

    # Wait till the encoded_image input is loaded; until() hands back the
    # located element, so no second lookup is needed.
    upload_input = WebDriverWait(driver, 20).until(
        EC.presence_of_element_located((By.XPATH, "//input[@name='encoded_image']"))
    )

    # Send an absolute path so the upload does not depend on how the driver
    # resolves relative paths.
    upload_input.send_keys(os.path.abspath(image_path))

    # Wait for the result to load (adjust the time.sleep if needed)
    time.sleep(3)

    # Find the returned images by class name
    return driver.find_elements(By.XPATH, '//img[@class="wETe9b jFVN1"]')


def _download_returned_images(image_elements, nested_folder, image_file):
    """Download each result image into ``nested_folder`` and return the URLs that were saved."""
    returned_image_urls = []
    for index, image_element in enumerate(image_elements):
        # Prefer "data-src"; fall back to "src" when it is missing or empty
        returned_image_url = (
            image_element.get_attribute('data-src')
            or image_element.get_attribute('src')
        )
        if not returned_image_url:
            # No usable URL on this element; skip it instead of aborting the run.
            continue

        save_path = os.path.join(nested_folder, f'{image_file}_returned_{index+1}')
        download_image(returned_image_url, save_path)
        returned_image_urls.append(returned_image_url)
    return returned_image_urls


def process_images_from_folder(folder_path, start_index=0):
    """Reverse-search every image in ``folder_path`` on Google and save the results.

    For each image file the returned images are downloaded into
    ``output/<image name without extension>/`` and their URLs are stored in
    the module-level ``image_records`` dict under the image's file name.
    When the run ends (normally or through an error) the records are
    written to ``image_records_path`` and the last visited index and file
    name are written through ``save_last_image``.

    Files are visited in sorted name order so that ``start_index`` refers to
    the same file on every run; ``os.listdir`` alone gives no ordering
    guarantee.

    Args:
        folder_path: Directory containing the query images.
        start_index: Position in the sorted listing at which to start;
            earlier entries are skipped.
    """
    # Create the output folder before launching the browser, so a failure
    # here cannot leave a browser process behind.
    output_folder = "output"
    create_folder_if_not_exists(output_folder)

    service = Service(executable_path=chrome_driver_path)
    options = webdriver.ChromeOptions()
    driver = webdriver.Chrome(service=service, options=options)

    image_suffixes = ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.gif', '.webp')

    # Pre-bind the loop variables: the cleanup below reads them even when the
    # folder is empty or listing it fails.
    idx = None
    image_file = None

    try:
        # Loop through all files in the folder starting from the specified index
        for idx, image_file in enumerate(sorted(os.listdir(folder_path))):
            if idx < start_index:
                continue

            # Process only image files
            if not image_file.lower().endswith(image_suffixes):
                print(f"Skipping non-image file: {image_file}")
                continue

            image_path = os.path.join(folder_path, image_file)
            returned_image_elements = _search_by_image(driver, image_path)

            # Create a nested folder based on the image file name
            nested_folder = os.path.join(output_folder, os.path.splitext(image_file)[0])
            create_folder_if_not_exists(nested_folder)

            # Store the image record in the dictionary
            image_records[image_file] = _download_returned_images(
                returned_image_elements, nested_folder, image_file
            )

    except Exception as e:
        # Top-level boundary: report the error; progress is saved below.
        print(f"An error occurred: {e}")
    finally:
        try:
            driver.quit()
        finally:
            # Save progress even if closing the browser fails.
            save_image_records(image_records, image_records_path)
            if image_file is not None:
                save_last_image(idx, image_file)

if __name__ == "__main__":
    # Reuse the saved records when the records file exists, else start empty
    if not os.path.exists(image_records_path):
        print('No existing records, initialize it now')
        image_records = {}
    else:
        print('Loading existing records')
        image_records = load_image_records(image_records_path)

    # Resume from the last queried image index when that info was saved
    start_index = 0
    if os.path.exists(last_queried_info_path):
        with open(last_queried_info_path, 'r') as f:
            last_queried_info = json.load(f)
        start_index = last_queried_info.get('last_queried_index', 0)

    # Time the whole run
    start_all = time.time()
    process_images_from_folder(images_folder, start_index)
    end_all = time.time()

    print('All takes: ', end_all - start_all)
