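Datasets used
I have collected data from TrashNet, which contains images of waste items labelled "trash" (non-recyclable) together with several classes of recyclable waste: paper, cardboard, glass, metal and plastic. In this notebook TrashNet is used as the test set; the training and validation images come from RealWaste, whose folders are mapped onto the same six classes.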

The training set contains 1211 images and the test set contains 508 images.

Images are pre-labelled


In [None]:
import os
import shutil
import random
import pandas as pd
from sklearn.model_selection import train_test_split

### Create training and validation splits from the original RealWaste dataset

In [None]:
def create_splits(data_dir, output_dir, val_size=0.20):
    """Split a class-folder image dataset (RealWaste) into train/val folders.

    Every sub-folder of ``data_dir`` is treated as one class. Folder names are
    unified through ``class_mapping`` (e.g. "Miscellaneous Trash" -> "trash";
    unknown names are kept as-is) and folders in ``ignored_folders`` are
    skipped. Each class's .png/.jpg/.jpeg files (extension matched
    case-insensitively) are split with ``train_test_split(random_state=42)``
    and copied to ``<output_dir>/train/<mapped_class>/`` and
    ``<output_dir>/val/<mapped_class>/``.

    One row per copied file (filename, original_class, mapped_class,
    split_type) is written to ``<output_dir>/class_mapping.csv``. Rows of other
    splits already in that file (e.g. the "test" rows written by
    ``copy_all_to_test``) are kept; earlier "train"/"val" rows are replaced.

    Args:
        data_dir: Root folder with one sub-folder per class.
        output_dir: Destination root for the ``train`` and ``val`` folders.
        val_size: Fraction of each class's images sent to validation.
    """
    # Define class mapping for unifying class names
    class_mapping = {
        "cardboard": "cardboard",
        "Cardboard": "cardboard",
        "Glass": "glass",
        "glass": "glass",
        "Metal": "metal",
        "metal": "metal",
        "paper": "paper",
        "Paper": "paper",
        "plastic": "plastic",
        "Plastic": "plastic",
        "trash": "trash",
        "Miscellaneous Trash": "trash",
    }

    # Folders to ignore
    ignored_folders = {"Textile Trash", "Vegetation", "Food Organics"}
    image_exts = ('.png', '.jpg', '.jpeg')

    train_dir = os.path.join(output_dir, 'train')
    val_dir = os.path.join(output_dir, 'val')
    for d in (train_dir, val_dir):
        os.makedirs(d, exist_ok=True)

    records = []

    # os.listdir order is arbitrary; sorting makes the seeded split reproducible
    # across machines and file systems.
    for class_name in sorted(os.listdir(data_dir)):
        if class_name in ignored_folders:
            print(f"Skipping {class_name} (ignored)")
            continue

        class_dir = os.path.join(data_dir, class_name)
        if not os.path.isdir(class_dir):
            continue

        mapped_class = class_mapping.get(class_name, class_name)

        for d in (train_dir, val_dir):
            os.makedirs(os.path.join(d, mapped_class), exist_ok=True)

        images = [
            os.path.join(class_dir, img)
            for img in sorted(os.listdir(class_dir))
            if img.lower().endswith(image_exts)
        ]
        if len(images) < 2:
            # train_test_split raises on 0 or 1 samples; keep them all for training.
            train_images, val_images = images, []
        else:
            train_images, val_images = train_test_split(images, test_size=val_size, random_state=42)

        print(f"Class {class_name} mapped to {mapped_class}: {len(train_images)} training, {len(val_images)} validation")

        for split_type, split_dir, split_images in (
            ("train", train_dir, train_images),
            ("val", val_dir, val_images),
        ):
            for image in split_images:
                filename = os.path.basename(image)
                shutil.copy(image, os.path.join(split_dir, mapped_class, filename))
                records.append((filename, class_name, mapped_class, split_type))

    columns = ['filename', 'original_class', 'mapped_class', 'split_type']
    df = pd.DataFrame(records, columns=columns)
    csv_path = os.path.join(output_dir, 'class_mapping.csv')
    # copy_all_to_test writes its "test" rows to this same file; keep rows of
    # other splits instead of overwriting them.
    if os.path.exists(csv_path):
        previous = pd.read_csv(csv_path)
        if 'split_type' in previous.columns:
            previous = previous[~previous['split_type'].isin(['train', 'val'])]
            if not previous.empty:
                df = pd.concat([previous, df], ignore_index=True)
    df.to_csv(csv_path, index=False)

    print("Dataset splitting and mapping completed!")

# Source RealWaste folder and the destination folder (also used for the test split below).
original_data_dir = '../data/realwaste-main/RealWaste'
output_data_dir = '../data/dataset_split'

# Build the train/validation splits (val_size keeps its default of 0.20).
create_splits(data_dir=original_data_dir, output_dir=output_data_dir)


Class Cardboard mapped to cardboard: 368 training, 93 validation
Skipping Food Organics (ignored)
Class Glass mapped to glass: 336 training, 84 validation
Class Metal mapped to metal: 632 training, 158 validation
Class Miscellaneous Trash mapped to trash: 396 training, 99 validation
Class Paper mapped to paper: 400 training, 100 validation
Class Plastic mapped to plastic: 736 training, 185 validation
Skipping Textile Trash (ignored)
Skipping Vegetation (ignored)
Dataset splitting and mapping completed!


### Create the test set from TrashNet and add it to the dataset_split folder

In [None]:

def copy_all_to_test(data_dir, output_dir):
    """Copy every image of a class-folder dataset (TrashNet) into a test split.

    Every sub-folder of ``data_dir`` is treated as one class. Folder names are
    unified through ``class_mapping`` (unknown names are kept as-is) and
    folders in ``ignored_folders`` are skipped. All .png/.jpg/.jpeg files
    (extension matched case-insensitively) are copied to
    ``<output_dir>/test/<mapped_class>/``.

    One row per copied file (filename, original_class, mapped_class, "test")
    is written to ``<output_dir>/class_mapping.csv``. Rows of other splits
    already in that file (e.g. the train/val rows written by ``create_splits``)
    are kept; earlier "test" rows are replaced, so re-running is idempotent.

    Args:
        data_dir: Root folder with one sub-folder per class.
        output_dir: Destination root; the ``test`` folder is created inside it.
    """
    # Define class mapping for unifying class names
    class_mapping = {
        "cardboard": "cardboard",
        "Cardboard": "cardboard",
        "Glass": "glass",
        "glass": "glass",
        "Metal": "metal",
        "metal": "metal",
        "paper": "paper",
        "Paper": "paper",
        "plastic": "plastic",
        "Plastic": "plastic",
        "trash": "trash",
        "Miscellaneous Trash": "trash",
    }

    # Folders to ignore
    ignored_folders = {"Textile Trash", "Vegetation", "Food Organics"}
    image_exts = ('.png', '.jpg', '.jpeg')

    test_dir = os.path.join(output_dir, 'test')
    os.makedirs(test_dir, exist_ok=True)

    records = []

    # Sorted so the copy order and CSV row order do not depend on the file system.
    for class_name in sorted(os.listdir(data_dir)):
        if class_name in ignored_folders:
            print(f"Skipping {class_name} (ignored)")
            continue

        class_dir = os.path.join(data_dir, class_name)
        if not os.path.isdir(class_dir):
            continue

        mapped_class = class_mapping.get(class_name, class_name)

        class_test_dir = os.path.join(test_dir, mapped_class)
        os.makedirs(class_test_dir, exist_ok=True)

        images = [
            os.path.join(class_dir, img)
            for img in sorted(os.listdir(class_dir))
            if img.lower().endswith(image_exts)
        ]

        print(f"Copying {len(images)} images from {class_name} to {mapped_class} in test set...")

        for image in images:
            filename = os.path.basename(image)
            shutil.copy(image, os.path.join(class_test_dir, filename))
            records.append((filename, class_name, mapped_class, "test"))

    columns = ['filename', 'original_class', 'mapped_class', 'split_type']
    df = pd.DataFrame(records, columns=columns)
    csv_path = os.path.join(output_dir, 'class_mapping.csv')
    # create_splits writes its train/val rows to this same file; keep them
    # instead of overwriting the file with test rows only.
    if os.path.exists(csv_path):
        previous = pd.read_csv(csv_path)
        if 'split_type' in previous.columns:
            previous = previous[previous['split_type'] != 'test']
            if not previous.empty:
                df = pd.concat([previous, df], ignore_index=True)
    df.to_csv(csv_path, index=False)

    print("All images successfully copied to the test set!")

# Resized TrashNet images; every one of them goes into the test split.
original_data_dir = '../data/dataset-resized'
output_data_dir = '../data/dataset_split'

copy_all_to_test(data_dir=original_data_dir, output_dir=output_data_dir)

Copying 403 images from cardboard to cardboard in test set...
Copying 501 images from glass to glass in test set...
Copying 410 images from metal to metal in test set...
Copying 594 images from paper to paper in test set...
Copying 482 images from plastic to plastic in test set...
Copying 137 images from trash to trash in test set...
All images successfully copied to the test set!


### Check the number of images in each folder

In [None]:
main_dir = "../data/combined_dataset"

# Sub-folder name -> number of entries directly inside it whose name has an
# image extension (matched case-insensitively).
image_counts = {
    entry: sum(
        1
        for name in os.listdir(os.path.join(main_dir, entry))
        if name.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif', '.tiff'))
    )
    for entry in os.listdir(main_dir)
    if os.path.isdir(os.path.join(main_dir, entry))
}

for subfolder, count in image_counts.items():
    print(f"{subfolder}: {count} images")


cardboard: 1472 images
glass: 1472 images
metal: 1472 images
paper: 1472 images
plastic: 1472 images
trash: 1472 images


### Create the combined dataset by sampling the outputs of the different augmentation techniques

In [None]:
# Training folder produced by each augmentation technique. Insertion order
# matters: create_combined_dataset hands out remainder slots in this order.
source_dirs = dict(
    diffusion='../data/dataset_diffusion_balanced/train',
    manipulation='../data/dataset_balanced/train',
    erasing='../data/dataset_erasing_augmented/train',
)
# Destination root of the combined dataset.
output_dir = '../data/combined_dataset'
# Category sub-folders processed, in this order.
categories = ['cardboard', 'glass', 'plastic', 'trash', 'paper', 'metal']

# Filename infix for copied augmented images, derived from the technique name
# (e.g. 'diffusion' -> 'aug_diffusion_').
aug_suffixes = {technique: f'aug_{technique}_' for technique in source_dirs}

# Target number of images per category
TOTAL_IMAGES = 1472

def create_output_structure(root=None, category_names=None):
    """Create the combined-dataset root folder and one sub-folder per category.

    Existing folders are left untouched, so calling this repeatedly is safe.

    Args:
        root: Destination root; defaults to the module-level ``output_dir``.
        category_names: Sub-folders to create; defaults to the module-level
            ``categories``.
    """
    if root is None:
        root = output_dir
    if category_names is None:
        category_names = categories
    # exist_ok replaces the exists()/makedirs() check-then-act pair; the root
    # is created even when there are no categories.
    os.makedirs(root, exist_ok=True)
    for category in category_names:
        os.makedirs(os.path.join(root, category), exist_ok=True)

def get_original_images(source_dir, category):
    """Return the sorted original (non-augmented) image filenames of a category.

    A file counts as original when it is a regular file in
    ``<source_dir>/<category>`` with a .jpg/.jpeg/.png extension
    (case-insensitive) and 'aug' does not appear in its lowercased name.

    The result is sorted because ``os.listdir`` order is arbitrary, and
    create_combined_dataset shuffles/samples these lists with the seeded
    ``random`` module; a fixed input order keeps its output reproducible.
    """
    category_path = os.path.join(source_dir, category)
    # NOTE(review): the substring test also excludes originals whose name merely
    # contains "aug" (e.g. "august.jpg").
    return sorted(
        f for f in os.listdir(category_path)
        if os.path.isfile(os.path.join(category_path, f))
        and f.lower().endswith(('.jpg', '.jpeg', '.png'))
        and 'aug' not in f.lower()
    )

def get_augmented_images_for_instance(source_dir, category, instance_base):
    """Return the sorted augmented image filenames derived from one original.

    A file matches when it is a regular file in ``<source_dir>/<category>``
    with a .jpg/.jpeg/.png extension (case-insensitive), contains 'aug'
    (case-insensitive), and its name starts with ``instance_base``, the
    original's filename without extension.

    The character right after ``instance_base`` must not be a digit, so the
    base "cardboard1" does not also pick up "cardboard10_aug..." (an
    augmentation of a different instance). The result is sorted because
    ``os.listdir`` order is arbitrary and callers pick from it with the
    seeded ``random`` module.
    """
    category_path = os.path.join(source_dir, category)
    base_len = len(instance_base)
    matches = []
    for f in os.listdir(category_path):
        if not f.startswith(instance_base):
            continue
        if f[base_len:base_len + 1].isdigit():
            # Longer numeric stem -> another instance's file.
            continue
        lowered = f.lower()
        if ('aug' in lowered
                and lowered.endswith(('.jpg', '.jpeg', '.png'))
                and os.path.isfile(os.path.join(category_path, f))):
            matches.append(f)
    return sorted(matches)

def get_all_augmented_images(source_dir, category):
    """Return the sorted augmented image filenames of a category.

    A file counts as augmented when it is a regular file in
    ``<source_dir>/<category>`` with a .jpg/.jpeg/.png extension
    (case-insensitive) and 'aug' appears in its lowercased name.

    The result is sorted because ``os.listdir`` order is arbitrary, and
    create_combined_dataset draws from this list with ``random.sample``;
    a fixed input order keeps the seeded output reproducible.
    """
    category_path = os.path.join(source_dir, category)
    return sorted(
        f for f in os.listdir(category_path)
        if os.path.isfile(os.path.join(category_path, f))
        and f.lower().endswith(('.jpg', '.jpeg', '.png'))
        and 'aug' in f.lower()
    )

def validate_input_datasets():
    """Check that each source/category folder holds exactly TOTAL_IMAGES images.

    For every entry of ``source_dirs`` and every name in ``categories``, counts
    the regular files in ``<source_dir>/<category>`` whose name ends in
    .jpg/.jpeg/.png (case-insensitive); originals and augmentations both count.

    Raises:
        ValueError: for the first source/category whose count differs.
    """
    image_exts = ('.jpg', '.jpeg', '.png')
    for source, root in source_dirs.items():
        for category in categories:
            folder = os.path.join(root, category)
            found = sum(
                1
                for name in os.listdir(folder)
                if os.path.isfile(os.path.join(folder, name))
                and name.lower().endswith(image_exts)
            )
            if found == TOTAL_IMAGES:
                continue
            raise ValueError(f"Expected {TOTAL_IMAGES} total images in {source}/{category}, "
                             f"but found {found}.")

def create_combined_dataset():
    """Build ``output_dir`` with TOTAL_IMAGES images per category sampled from the sources.

    For each name in ``categories``:

    1. Every distinct original filename (no 'aug' in the name) found in any
       source is copied at most once. The source is picked so that originals
       are split roughly evenly: each source's quota is total // 3, and the
       remainder goes one each to the first sources in ``source_dirs`` order.
       An original whose sources are all at quota is skipped.
    2. Walking the originals in the same shuffled order, at most one unused
       augmented image per source per original is copied. The source furthest
       below its augmented target is served first. The augmented target is
       (TOTAL_IMAGES - number of originals) split the same way as in step 1.
    3. Any remaining per-source shortfall is filled with ``random.sample``
       from that source's still-unused augmented images.

    Copies are named ``<category>_<suffix><counter><ext>``. The suffix is
    'original_' or ``aug_suffixes[source]``, and the counter runs from 1 and
    never exceeds TOTAL_IMAGES.

    The result depends on the state of ``random`` (shuffle/choice/sample);
    the ``__main__`` block seeds it with 42 before calling this.

    Raises:
        ValueError: from ``validate_input_datasets``, or in step 3 when a
            source lacks enough unused augmented images.
    """
    validate_input_datasets()
    create_output_structure()
    
    for category in categories:
        # Original (non-'aug') filenames present in each source folder.
        orig_by_source = {source: get_original_images(dir_path, category) 
                         for source, dir_path in source_dirs.items()}
        
        # Map each distinct original filename -> the sources that contain it
        # (listed in source_dirs order).
        orig_images_with_sources = {}
        for source, images in orig_by_source.items():
            for img in images:
                if img not in orig_images_with_sources:
                    orig_images_with_sources[img] = []
                orig_images_with_sources[img].append(source)
        
        orig_images_list = list(orig_images_with_sources.keys())
        total_orig_images = len(orig_images_list)
        
        # Slots left for augmented images once every original is counted,
        # split evenly across the sources.
        # NOTE(review): the divisor 3 assumes exactly three entries in source_dirs.
        total_aug_images = TOTAL_IMAGES - total_orig_images
        aug_per_source = total_aug_images // 3
        aug_remainder = total_aug_images % 3
        
        # Remainder slots go one each to the first sources in source_dirs order.
        aug_target_counts = {}
        remaining = aug_remainder
        for source in source_dirs:
            aug_target_counts[source] = aug_per_source + (1 if remaining > 0 else 0)
            remaining -= 1 if remaining > 0 else 0
        
        # Pool of augmented filenames per source, consumed by step 2 and drawn from in step 3.
        aug_by_source = {source: get_all_augmented_images(dir_path, category) 
                        for source, dir_path in source_dirs.items()}
        used_aug_images = {source: set() for source in source_dirs}
        
        # Per-source tallies: all copies, originals only, augmented only.
        # NOTE(review): keys are hard-coded and must match the keys of source_dirs.
        count_by_source = {'diffusion': 0, 'manipulation': 0, 'erasing': 0}
        orig_copied_by_source = {'diffusion': 0, 'manipulation': 0, 'erasing': 0}
        aug_images_copied = {'diffusion': 0, 'manipulation': 0, 'erasing': 0}
        
        # Step 1: Copy all originals, balancing sources
        orig_images_shuffled = orig_images_list.copy()
        random.shuffle(orig_images_shuffled)
        
        # Per-source quota of originals, split like the augmented targets above.
        orig_per_source = total_orig_images // 3
        orig_remainder = total_orig_images % 3
        orig_counts = {}
        remaining_orig = orig_remainder
        for source in source_dirs:
            orig_counts[source] = orig_per_source + (1 if remaining_orig > 0 else 0)
            remaining_orig -= 1 if remaining_orig > 0 else 0
        
        # 1-based running index used in every output filename of this category.
        image_counter = 1
        for orig_image in orig_images_shuffled:
            if image_counter > TOTAL_IMAGES:
                break
            sources = orig_images_with_sources[orig_image]
            # First source holding this file whose originals quota is not yet full.
            selected_source = None
            for source in sources:
                if orig_copied_by_source[source] < orig_counts[source]:
                    selected_source = source
                    break
            # Every source holding this original is at quota: it is not copied.
            if selected_source is None:
                continue
            
            suffix = 'original_'
            src_path = os.path.join(source_dirs[selected_source], category, orig_image)
            ext = os.path.splitext(orig_image)[1]
            new_filename = f"{category}_{suffix}{image_counter}{ext}"
            dst_path = os.path.join(output_dir, category, new_filename)
            shutil.copy2(src_path, dst_path)
            count_by_source[selected_source] += 1
            orig_copied_by_source[selected_source] += 1
            image_counter += 1
        
        orig_copied = sum(orig_copied_by_source.values())
        
        # Step 2: Copy instance-specific augmented images to meet aug_target_counts
        for orig_image in orig_images_shuffled:
            if image_counter > TOTAL_IMAGES:
                break
            
            # Unused augmented files whose name starts with this original's stem.
            # NOTE(review): get_augmented_images_for_instance lists the directory on
            # every call, i.e. once per original per source.
            instance_base = os.path.splitext(orig_image)[0]
            aug_available = {}
            for source, dir_path in source_dirs.items():
                aug_images = get_augmented_images_for_instance(dir_path, category, instance_base)
                aug_available[source] = [img for img in aug_images if img not in used_aug_images[source]]
            
            # Serve the source furthest below its augmented target first
            # (sorted() is stable, so ties keep source_dirs order).
            sources_sorted = sorted(source_dirs.keys(), 
                                  key=lambda s: (aug_target_counts[s] - aug_images_copied[s]), 
                                  reverse=True)
            for source in sources_sorted:
                if image_counter > TOTAL_IMAGES:
                    break
                if aug_images_copied[source] >= aug_target_counts[source]:
                    continue
                if aug_available[source]:
                    aug_image = random.choice(aug_available[source])
                    suffix = aug_suffixes[source]
                    src_path = os.path.join(source_dirs[source], category, aug_image)
                    ext = os.path.splitext(aug_image)[1]
                    new_filename = f"{category}_{suffix}{image_counter}{ext}"
                    dst_path = os.path.join(output_dir, category, new_filename)
                    shutil.copy2(src_path, dst_path)
                    used_aug_images[source].add(aug_image)
                    # Remove it from the step-3 pool as well.
                    if aug_image in aug_by_source[source]:
                        aug_by_source[source].remove(aug_image)
                    count_by_source[source] += 1
                    aug_images_copied[source] += 1
                    image_counter += 1
        
        # Step 3: Fill remaining slots to reach aug_target_counts exactly
        for source in source_dirs:
            remaining_for_source = max(0, aug_target_counts[source] - aug_images_copied[source])
            if remaining_for_source > 0 and image_counter <= TOTAL_IMAGES:
                # Random top-up from this source's augmented images not used in step 2.
                available_images = [img for img in aug_by_source[source] if img not in used_aug_images[source]]
                if len(available_images) < remaining_for_source:
                    raise ValueError(f"Not enough remaining augmented images in {source}/{category}: "
                                   f"found {len(available_images)}, need {remaining_for_source}")
                selected_images = random.sample(available_images, remaining_for_source)
                suffix = aug_suffixes[source]
                for image in selected_images:
                    if image_counter > TOTAL_IMAGES:
                        break
                    src_path = os.path.join(source_dirs[source], category, image)
                    ext = os.path.splitext(image)[1]
                    new_filename = f"{category}_{suffix}{image_counter}{ext}"
                    dst_path = os.path.join(output_dir, category, new_filename)
                    shutil.copy2(src_path, dst_path)
                    count_by_source[source] += 1
                    aug_images_copied[source] += 1
                    image_counter += 1
        
        total_copied = sum(count_by_source.values())
        # Presumably a defensive clamp: every copy above happens only while
        # image_counter <= TOTAL_IMAGES, so this should be a no-op.
        total_copied = min(total_copied, TOTAL_IMAGES)
        total_aug_copied = sum(aug_images_copied.values())
        
        # Per-category summary with each source's share of the copied images.
        print(f"{category}: Copied {total_copied} images "
              f"({orig_copied} originals, {total_aug_copied} augmented - "
              f"diffusion: {count_by_source['diffusion']} (orig: {orig_copied_by_source['diffusion']}, aug: {aug_images_copied['diffusion']}, {count_by_source['diffusion']/total_copied:.1%}), "
              f"manipulation: {count_by_source['manipulation']} (orig: {orig_copied_by_source['manipulation']}, aug: {aug_images_copied['manipulation']}, {count_by_source['manipulation']/total_copied:.1%}), "
              f"erasing: {count_by_source['erasing']} (orig: {orig_copied_by_source['erasing']}, aug: {aug_images_copied['erasing']}, {count_by_source['erasing']/total_copied:.1%}))")

if __name__ == "__main__":
    # Fixed seed so the shuffle/choice/sample calls in create_combined_dataset repeat.
    random.seed(42)
    try:
        create_combined_dataset()
    except ValueError as err:
        # Validation failures and augmented-image shortfalls are printed, not raised.
        print(f"Error: {err}")
    else:
        print("Combined dataset creation completed!")

cardboard: Copied 1472 images (368 originals, 1104 augmented - diffusion: 491 (orig: 123, aug: 368, 33.4%), manipulation: 491 (orig: 123, aug: 368, 33.4%), erasing: 490 (orig: 122, aug: 368, 33.3%))
glass: Copied 1472 images (336 originals, 1136 augmented - diffusion: 491 (orig: 112, aug: 379, 33.4%), manipulation: 491 (orig: 112, aug: 379, 33.4%), erasing: 490 (orig: 112, aug: 378, 33.3%))
plastic: Copied 1472 images (736 originals, 736 augmented - diffusion: 492 (orig: 246, aug: 246, 33.4%), manipulation: 490 (orig: 245, aug: 245, 33.3%), erasing: 490 (orig: 245, aug: 245, 33.3%))
trash: Copied 1472 images (396 originals, 1076 augmented - diffusion: 491 (orig: 132, aug: 359, 33.4%), manipulation: 491 (orig: 132, aug: 359, 33.4%), erasing: 490 (orig: 132, aug: 358, 33.3%))
paper: Copied 1472 images (400 originals, 1072 augmented - diffusion: 492 (orig: 134, aug: 358, 33.4%), manipulation: 490 (orig: 133, aug: 357, 33.3%), erasing: 490 (orig: 133, aug: 357, 33.3%))
metal: Copied 1472 i