In [53]:
import os
from sklearn.model_selection import train_test_split
import pickle
import random
from collections import defaultdict

# Step 1 - Segment into Train/Validation/Testing + Stratify Data

In [54]:
def segment_and_save_stratified_data(data_dir, save_path):
    """
    Splits the dataset into stratified training, validation, and test sets and pickles each split.

    Every category sub-directory of ``data_dir`` is scanned for files ending in ``.jpg`` or
    ``.png``. An image's label is the index of its category in the fixed ``categories`` list
    below. Files are visited in sorted order so that, together with the fixed ``random_state``,
    the resulting splits do not depend on the order in which the filesystem lists the files.

    Resulting proportions: 15% test, then 17.5% of the remaining 85% as validation
    (about 15% of the whole), leaving about 70% for training.

    Parameters:
    - data_dir (str): Directory containing the dataset, organized into subdirectories for each category.
    - save_path (str): Prefix for the output files. Three pickles are written:
      ``{save_path}_train_data.pkl``, ``{save_path}_val_data.pkl`` and
      ``{save_path}_test_data.pkl``, each holding a ``(paths, labels)`` tuple of lists.

    Returns:
    - None

    Raises:
    - OSError: If a category sub-directory cannot be listed (e.g. it does not exist).
    """
    categories = ['Trash', 'Plastic', 'Paper', 'Metal', 'Glass', 'Cardboard']
    image_paths = []
    labels = []

    # Collect all image paths and their corresponding labels
    for label, category in enumerate(categories):
        category_dir = os.path.join(data_dir, category)
        # os.listdir() returns entries in arbitrary order; sorting makes the input to
        # train_test_split (and therefore the seeded split) reproducible across machines.
        for file in sorted(os.listdir(category_dir)):
            if file.endswith(('.jpg', '.png')):
                image_paths.append(os.path.join(category_dir, file))
                labels.append(label)

    # Split data into train and test with stratification (85% train, 15% test)
    train_paths, test_paths, train_labels, test_labels = train_test_split(
        image_paths, labels, test_size=0.15, stratify=labels, random_state=0)

    # Carve the validation set out of the 85% that is left: 17.5% of it is about 15% of the
    # full dataset (0.175 x 0.85 ≈ 0.15), so validation and test end up roughly the same size.
    train_paths, val_paths, train_labels, val_labels = train_test_split(
        train_paths, train_labels, test_size=0.175, stratify=train_labels, random_state=0)

    # Save the data splits
    with open(f"{save_path}_train_data.pkl", 'wb') as f:
        pickle.dump((train_paths, train_labels), f)
    with open(f"{save_path}_val_data.pkl", 'wb') as f:
        pickle.dump((val_paths, val_labels), f)
    with open(f"{save_path}_test_data.pkl", 'wb') as f:
        pickle.dump((test_paths, test_labels), f)

    print(f"Data splits saved with prefix '{save_path}'")



In [55]:
# Root of the image dataset; the splitter looks for one sub-directory per category inside it.
# This is a relative path, so it resolves against the kernel's current working directory.
folder_path = '../../data/dataset-resized' # My local path
# Sets an environment variable for this process; presumably silences a Hugging Face Hub
# symlink warning triggered via timm — not used by the splitting code itself.
os.environ['HF_HUB_DISABLE_SYMLINKS_WARNING'] = '1' # for timm
# Prefix shared by the three split pickles written by the call below; later cells reuse it
# to rebuild the filenames.
desired_name = "full_dataset_segmented"
# Writes <desired_name>_train_data.pkl, _val_data.pkl and _test_data.pkl into the working directory.
segment_and_save_stratified_data(folder_path, desired_name)

Data splits saved with prefix 'full_dataset_segmented'


### Now, double-check that it worked and show the class distribution

In [56]:
def load_dataset(file_path):
    """
    Reads one pickled dataset split back from disk.

    The file is expected to contain a two-element sequence ``(image_paths, labels)``;
    anything that cannot be unpacked into exactly two items raises during unpacking.

    Parameters:
    - file_path (str): Location of the pickle file to read.

    Returns:
    - tuple: ``(image_paths, labels)`` exactly as stored in the file.
    """
    # pickle can execute arbitrary code while loading; only use this on files
    # that were written by this notebook.
    with open(file_path, 'rb') as handle:
        stored = pickle.load(handle)
    paths, targets = stored
    return paths, targets


def count_categories(labels):
    """
    Tallies how many times each category occurs in a list of label indices.

    Parameters:
    - labels (list of int): Label indices; each one is used to index the fixed
      category list below.

    Returns:
    - dict: Maps every category name to its count, in category-list order.
      Categories that never occur are present with a count of 0.
    """
    categories = ['Trash', 'Plastic', 'Paper', 'Metal', 'Glass', 'Cardboard']
    # Start every category at zero so absent classes still show up in the result.
    tally = dict.fromkeys(categories, 0)
    for index in labels:
        name = categories[index]
        tally[name] = tally[name] + 1
    return tally

In [57]:
# Filenames for the dataset splits: the same "<prefix>_<split>_data.pkl" names that
# segment_and_save_stratified_data wrote, rebuilt from the prefix in desired_name.
train_file = desired_name + '_train_data.pkl'
val_file = desired_name + '_val_data.pkl'
test_file = desired_name + '_test_data.pkl'

# Load the datasets; each pickle holds a (paths, labels) tuple.
train_paths, train_labels = load_dataset(train_file)
val_paths, val_labels = load_dataset(val_file)
test_paths, test_labels = load_dataset(test_file)

# Count categories in each dataset (only the labels are needed for the per-class tally).
train_counts = count_categories(train_labels)
val_counts = count_categories(val_labels)
test_counts = count_categories(test_labels)

# Print the results; with a stratified split the class proportions should look
# similar across the three lines.
print("Training Data Category Counts:", train_counts)
print("Validation Data Category Counts:", val_counts)
print("Test Data Category Counts:", test_counts)

Training Data Category Counts: {'Trash': 96, 'Plastic': 338, 'Paper': 417, 'Metal': 287, 'Glass': 351, 'Cardboard': 282}
Validation Data Category Counts: {'Trash': 20, 'Plastic': 72, 'Paper': 88, 'Metal': 61, 'Glass': 75, 'Cardboard': 60}
Test Data Category Counts: {'Trash': 21, 'Plastic': 72, 'Paper': 89, 'Metal': 62, 'Glass': 75, 'Cardboard': 61}


# Step 2 - Segment into smaller subsets (In this case, groups of 25, 50, and 100 of each class)

In [58]:
def load_data(filename):
    """
    Reads a pickled ``(image_paths, labels)`` pair from ``filename``.

    Parameters:
    - filename (str): Path of the pickle file to read.

    Returns:
    - tuple: ``(image_paths, labels)`` as stored in the file.
    """
    # pickle can execute arbitrary code while loading; only use this on files
    # that were written by this notebook.
    with open(filename, 'rb') as source:
        contents = pickle.load(source)
    paths, targets = contents
    return paths, targets

def save_data(filename, image_paths, labels):
    """
    Pickles ``(image_paths, labels)`` as a single tuple into ``filename``.

    Parameters:
    - filename (str): Destination path; an existing file is overwritten.
    - image_paths (list of str): Image paths to store.
    - labels (list of int): Labels to store alongside the paths.

    Returns:
    - None
    """
    payload = (image_paths, labels)
    with open(filename, 'wb') as sink:
        pickle.dump(payload, sink)

def create_subsets(image_paths, labels, sizes, prefix, seed=0):
    """
    Creates per-category subsets of the dataset for each requested size and saves them.

    For every entry in ``sizes`` one file ``{prefix}_{size}_data_segmented.pkl`` is written.
    It holds up to ``size`` paths per category, drawn without replacement from that
    category's images, so a path appears at most once per subset provided the input
    itself has no duplicate paths. A category with fewer than ``size`` images contributes
    all of its images, unshuffled. Labels outside the fixed category list are ignored.

    Each size reshuffles the category lists, so a smaller subset is not guaranteed to be
    contained in a larger one.

    Parameters:
    - image_paths (list of str): Paths to the images in the dataset.
    - labels (list of int): Corresponding labels for the images.
    - sizes (list of int): Sizes of the subsets to create for each category.
    - prefix (str): Prefix for naming the output files.
    - seed (int): Seed for the private random generator used for shuffling. The default
      of 0 selects the same subsets as seeding the global generator with 0 would.

    Returns:
    - None
    """
    # Group image paths by label
    categorized_data = defaultdict(list)
    for path, label in zip(image_paths, labels):
        categorized_data[label].append(path)

    categories = ['Trash', 'Plastic', 'Paper', 'Metal', 'Glass', 'Cardboard']

    # Ensure deterministic randomness, in case we re-run this cell. A private generator
    # is used instead of random.seed() so the process-wide RNG state that other cells
    # may rely on is left untouched.
    rng = random.Random(seed)

    for size in sizes:
        subset_paths = []
        subset_labels = []
        # Create subset for each category
        for label in range(len(categories)):
            candidates = categorized_data[label]
            if len(candidates) >= size:
                rng.shuffle(candidates)
                selected_paths = candidates[:size]
            else:
                selected_paths = candidates  # If not enough data, take all

            subset_paths.extend(selected_paths)
            subset_labels.extend([label] * len(selected_paths))

        # Save the subset to a file
        filename = f"{prefix}_{size}_data_segmented.pkl"
        save_data(filename, subset_paths, subset_labels)
        print(f"Subset of size {size} for each category saved to {filename}")

# Load original training data (subsets are drawn from the training split only).
# NOTE(review): this filename is hard-coded; it matches desired_name + '_train_data.pkl'
# only while desired_name is "full_dataset_segmented".
train_file = 'full_dataset_segmented_train_data.pkl'
train_paths, train_labels = load_data(train_file)

# Define sizes for the subsets (number of images requested per category)
sizes = [25, 50, 100]

# Create and save subsets; writes one "<prefix>_<size>_data_segmented.pkl" file per size
prefix = 'subset'
create_subsets(train_paths, train_labels, sizes, prefix)

# Loop through each size, load the dataset, and print the category counts.
# A count below `size` means the training split had fewer images of that category,
# in which case create_subsets keeps all of them.
for size in sizes:
    subset_file = f'{prefix}_{size}_data_segmented.pkl'
    paths, labels = load_data(subset_file)
    counts = count_categories(labels)
    print(f"Data Category Counts for {size} instances of each category:", counts)



Subset of size 25 for each category saved to subset_25_data_segmented.pkl
Subset of size 50 for each category saved to subset_50_data_segmented.pkl
Subset of size 100 for each category saved to subset_100_data_segmented.pkl
Data Category Counts for 25 instances of each category: {'Trash': 25, 'Plastic': 25, 'Paper': 25, 'Metal': 25, 'Glass': 25, 'Cardboard': 25}
Data Category Counts for 50 instances of each category: {'Trash': 50, 'Plastic': 50, 'Paper': 50, 'Metal': 50, 'Glass': 50, 'Cardboard': 50}
Data Category Counts for 100 instances of each category: {'Trash': 96, 'Plastic': 100, 'Paper': 100, 'Metal': 100, 'Glass': 100, 'Cardboard': 100}
