In [103]:
import os
import shutil
import random
from PIL import Image

# Define the root directory and the split percentages
root_dir = 'Data_test/Rocks'
train_percentage = 0.8
validation_test_split = 0.5

# Get a list of all rock types (subdirectories in the root directory)
rock_types = [d for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))]

for rock_type in rock_types:
    # Create train, validation, and test directories for each rock type
    train_dir = os.path.join(root_dir, 'train', rock_type)
    validation_dir = os.path.join(root_dir, 'validation', rock_type)
    test_dir = os.path.join(root_dir, 'test', rock_type)

    os.makedirs(train_dir, exist_ok=True)
    os.makedirs(validation_dir, exist_ok=True)
    os.makedirs(test_dir, exist_ok=True)

    # Get a list of all image files in the current rock type directory
    images = [f for f in os.listdir(os.path.join(root_dir, rock_type)) if os.path.isfile(os.path.join(root_dir, rock_type, f))]

    # Shuffle the images to ensure randomness
    random.shuffle(images)

    # Calculate the split index
    split_index = int(len(images) * train_percentage)

    # Split the images into training and testing sets
    train_images = images[:split_index]
    validation_test_images = images[split_index:]

    validation_test_index = int(len(validation_test_images) * validation_test_split)

    # Split the images into validation and test sets
    test_images = validation_test_images[:validation_test_index]
    validation_images = validation_test_images[validation_test_index:]

    # Function to convert an image to PNG
    def convert_to_png(src_path, dst_dir, filename):
        try:
            with Image.open(src_path) as img:
                # Convert to RGBA for proper PNG conversion
                img = img.convert('RGB')
                # Define path for PNG file
                png_filename = os.path.splitext(filename)[0] + '.png'
                dst_path = os.path.join(dst_dir, png_filename)
                img.save(dst_path, 'PNG')
        except Exception as e:
            print(f"Failed to convert {src_path}. Error: {e}")

    # Move the images to their respective directories
    for image in train_images:
        src_path = os.path.join(root_dir, rock_type, image)
        convert_to_png(src_path, train_dir, image)

    for image in test_images:
        src_path = os.path.join(root_dir, rock_type, image)
        convert_to_png(src_path, test_dir, image)

    for image in validation_images:
        src_path = os.path.join(root_dir, rock_type, image)
        convert_to_png(src_path, validation_dir, image)

print("Image conversion, splitting, and moving completed.")

Failed to convert Data_test/Rocks/Basalt/.DS_Store. Error: cannot identify image file 'Data_test/Rocks/Basalt/.DS_Store'
Failed to convert Data_test/Rocks/Granite/.DS_Store. Error: cannot identify image file 'Data_test/Rocks/Granite/.DS_Store'
Failed to convert Data_test/Rocks/Sandstone/.DS_Store. Error: cannot identify image file 'Data_test/Rocks/Sandstone/.DS_Store'
Failed to convert Data_test/Rocks/Chalk/.DS_Store. Error: cannot identify image file 'Data_test/Rocks/Chalk/.DS_Store'
Failed to convert Data_test/Rocks/Slate/.DS_Store. Error: cannot identify image file 'Data_test/Rocks/Slate/.DS_Store'
Image conversion, splitting, and moving completed.


In [106]:
from datasets import load_dataset

dataset = load_dataset("imagefolder", data_dir="Data_test/Rocks")

Resolving data files:   0%|          | 0/249 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/34 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/32 [00:00<?, ?it/s]

Downloading data:   0%|          | 0/249 [00:00<?, ?files/s]

Downloading data:   0%|          | 0/34 [00:00<?, ?files/s]

Downloading data:   0%|          | 0/32 [00:00<?, ?files/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating validation split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

In [107]:
dataset

DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 249
    })
    validation: Dataset({
        features: ['image', 'label'],
        num_rows: 34
    })
    test: Dataset({
        features: ['image', 'label'],
        num_rows: 32
    })
})