In [1]:
import os
import numpy as np
from PIL import Image
import imagehash

# Root folders of the dataset splits. Each one can be overridden with an
# environment variable of the same name so the notebook runs on machines other
# than the author's; the defaults are the original hard-coded paths.
DIR_REAL_TRAIN = os.environ.get(
    'DIR_REAL_TRAIN', '/Users/alexchilton/Downloads/working/train'
)
DIR_REAL_TEST = os.environ.get(
    'DIR_REAL_TEST', '/Users/alexchilton/Downloads/working/test'
)
DIR_REAL_VALIDATION = os.environ.get(
    'DIR_REAL_VALIDATION', '/Users/alexchilton/Downloads/working/validation'
)

def load_images_with_hashes(directory):
    """Compute an average (perceptual) hash for every image in a class-per-folder tree.

    ``directory`` is expected to hold one sub-directory per class. Plain files
    directly inside ``directory`` are ignored, and only the immediate contents
    of each class folder are scanned (no recursion). Only files whose name ends
    in ``.png``, ``.jpg`` or ``.jpeg`` (case-insensitive) are hashed.

    Args:
        directory: Root folder whose immediate sub-folders are class names.

    Returns:
        ``{class_name: {file_name: hash}}`` where each hash is the result of
        ``imagehash.average_hash``. Every class sub-folder gets an entry, even
        if it contains no hashable images. Files that fail to open or hash are
        reported on stdout and skipped.
    """
    images_with_hashes = {}
    for class_name in os.listdir(directory):
        class_path = os.path.join(directory, class_name)
        if not os.path.isdir(class_path):
            continue

        class_hashes = {}
        images_with_hashes[class_name] = class_hashes
        for file_name in os.listdir(class_path):
            if not file_name.lower().endswith(('.png', '.jpg', '.jpeg')):
                continue
            file_path = os.path.join(class_path, file_name)
            try:
                # The context manager releases the file handle as soon as the
                # pixels are copied out, even if decoding/conversion raises;
                # without it the handle lingers until garbage collection.
                with Image.open(file_path) as img:
                    rgb = img.convert('RGB')  # Ensure all images are in RGB format
                class_hashes[file_name] = imagehash.average_hash(rgb)
            except Exception as e:
                # Deliberately best-effort: one corrupt file must not abort the scan.
                print(f"Error loading image {file_path}: {e}")
    return images_with_hashes

def find_duplicates(train_hashes, test_hashes):
    """Return training images whose hash equals that of a test image of the same class.

    Args:
        train_hashes: ``{class_name: {file_name: hash}}`` for the training split,
            as produced by ``load_images_with_hashes``.
        test_hashes: Same structure for the test split.

    Returns:
        A list of ``(class_name, train_file)`` tuples. Each training file is
        listed at most once, even if several test images match it. Only classes
        present in both splits are compared; images with different class
        labels are never matched against each other.
    """
    duplicates = []
    for class_name, test_images in test_hashes.items():
        # One set per class turns the pairwise test x train scan into a single
        # membership check per training image. Relies on the hash objects being
        # hashable (imagehash.ImageHash implements __hash__ and __eq__).
        test_hash_set = set(test_images.values())
        if not test_hash_set:
            continue
        for train_file, train_hash in train_hashes.get(class_name, {}).items():
            if train_hash in test_hash_set:
                # Appended once per training file, so callers never try to
                # delete the same file twice.
                duplicates.append((class_name, train_file))
    return duplicates

def delete_duplicates(duplicates, train_directory, *, dry_run=False):
    """Delete the listed training images from disk.

    Args:
        duplicates: Iterable of ``(class_name, train_file)`` tuples, e.g. the
            output of ``find_duplicates``. Repeated entries are processed once.
        train_directory: Root of the training split; each file is resolved as
            ``train_directory/class_name/train_file``.
        dry_run: If True, only print which files would be deleted and leave
            the filesystem untouched.

    Filesystem errors (missing file, permission denied, ...) are printed and
    skipped so that one bad entry does not abort the whole cleanup.
    """
    # dict.fromkeys drops repeats while keeping first-seen order; otherwise a
    # file listed twice is deleted once and then reported as an error.
    for class_name, train_file in dict.fromkeys(duplicates):
        file_path = os.path.join(train_directory, class_name, train_file)
        if dry_run:
            print(f"Would delete duplicate file: {file_path}")
            continue
        try:
            os.remove(file_path)
        except OSError as e:
            print(f"Error deleting file {file_path}: {e}")
        else:
            print(f"Deleted duplicate file: {file_path}")

# Hash every image of the train and test splits ({class: {file: hash}}).
# The train split is hashed first, then the test split.
train_hashes, test_hashes = (
    load_images_with_hashes(split_dir) for split_dir in (DIR_REAL_TRAIN, DIR_REAL_TEST)
)

# Drop from the training split every image whose hash matches a test image of
# the same class, presumably to prevent train/test leakage.
# NOTE(review): DIR_REAL_VALIDATION is not checked in this cell — confirm
# whether train/validation overlap should be removed as well.
duplicates = find_duplicates(train_hashes, test_hashes)
delete_duplicates(duplicates, DIR_REAL_TRAIN)



KeyboardInterrupt: 