In [None]:
import numpy as np
import os
import glob
import shutil

In [None]:
# SPECIFY DIRECTORY HERE
# Directory containing the source images that this notebook splits into
# training, validation and test sets. Edit this path to point at your data.
img_dir = 'cnn_data2'

In [None]:
# Fractions of each class assigned to the training and validation sets;
# whatever remains goes to the test set.
ratio_train = 0.7
ratio_val = 0.15

# Output directories, one per split.
train_dir = 'cnn_data_train'
val_dir = 'cnn_data_val'
test_dir = 'cnn_data_test'

# Start every run from empty output directories: an existing directory with
# the same name is deleted (including its contents) and then recreated.
split_dirs = (train_dir, val_dir, test_dir)
for directory in split_dirs:
    already_exists = os.path.isdir('./' + directory)
    if already_exists:
        shutil.rmtree(directory)
    os.mkdir(directory)

In [None]:
# Load all data.
# glob returns paths in an arbitrary, filesystem-dependent order, so the
# paths are sorted first; otherwise the seeded shuffle below would produce a
# different split on a different machine.
img_path = os.path.join(img_dir, '*g')
all_imgs = np.array(sorted(glob.glob(img_path)))

# Fail fast with a clear message instead of silently producing empty splits.
if len(all_imgs) == 0:
    raise FileNotFoundError('No images matching ' + repr(img_path) + ' were found')

# Split into robot and non-robot images. An image counts as a robot image
# when the last character of its file name, before the extension, is '1'.
# The extension is stripped with splitext rather than assuming it is exactly
# three characters long, so names ending in e.g. '.jpeg' (which the '*g'
# pattern also matches) are labelled correctly.
is_robot = np.array(
    [os.path.splitext(path)[0].endswith('1') for path in all_imgs],
    dtype=bool,
)
robot_imgs = all_imgs[is_robot]        # boolean indexing returns a copy
non_robot_imgs = all_imgs[~is_robot]

# Shuffle each class independently with a fixed seed for reproducibility.
np.random.seed(0)
np.random.shuffle(robot_imgs)
np.random.shuffle(non_robot_imgs)

# Length of the robot and non-robot sets
n_robot = len(robot_imgs)
n_non_robot = len(non_robot_imgs)

# Split points: the split is done per class so that every set keeps the same
# robot / non-robot proportions as the full data set.
robot_train_idx = round(ratio_train * n_robot)
robot_val_idx = round((ratio_train + ratio_val) * n_robot)
non_robot_train_idx = round(ratio_train * n_non_robot)
non_robot_val_idx = round((ratio_train + ratio_val) * n_non_robot)

# Training set
train_robot = robot_imgs[:robot_train_idx]
train_non_robot = non_robot_imgs[:non_robot_train_idx]
train_set = np.concatenate([train_robot, train_non_robot])

# Validation set
val_robot = robot_imgs[robot_train_idx:robot_val_idx]
val_non_robot = non_robot_imgs[non_robot_train_idx:non_robot_val_idx]
val_set = np.concatenate([val_robot, val_non_robot])

# Test set (everything that is left)
test_robot = robot_imgs[robot_val_idx:]
test_non_robot = non_robot_imgs[non_robot_val_idx:]
test_set = np.concatenate([test_robot, test_non_robot])

# Sanity check: the three sets must partition the data. This catches
# misconfigured ratios (e.g. a negative ratio_val) that would drop images.
assert len(train_set) + len(val_set) + len(test_set) == len(all_imgs), \
    'train/val/test sets do not add up to the full data set'

In [None]:
# Copy every image into the directory of the split it was assigned to,
# keeping its original file name.
split_targets = (
    (train_set, train_dir),
    (val_set, val_dir),
    (test_set, test_dir),
)
for split_paths, split_dir in split_targets:
    for src_path in split_paths:
        dst_path = os.path.join(split_dir, os.path.basename(src_path))
        shutil.copyfile(src_path, dst_path)