In [48]:
import os
import shutil
import random

In [49]:
# Input paths: the source dataset keeps all images in one folder and all labels in another
detection_path = '../data/segmentation-dataset/segmentation-dataset/'
images_path, labels_path = (os.path.join(detection_path, sub) for sub in ('images', 'labels'))

# Output paths: every split gets its own images/ and labels/ folder (created if missing)
destination_dir = "../data/segmentation-dataset-6by2/"
splits = ['train', 'val', 'test']
for split in splits:
    for subdir in ('images', 'labels'):
        os.makedirs(os.path.join(destination_dir, split, subdir), exist_ok=True)


In [50]:
# Collect 6by2 samples: keep only file names containing the '6by2' tag, images first, then labels
images_6by2, labels_6by2 = (
    [name for name in os.listdir(folder) if '6by2' in name]
    for folder in (images_path, labels_path)
)

In [51]:
# Compare how many 6by2 images and labels were found (display only; pairs are matched by name later)
(len(images_6by2), len(labels_6by2))

(500, 500)

In [52]:
# Match filenames on their stem (name without extension).
# os.path.splitext strips the real extension instead of assuming it is exactly
# 4 characters (".jpeg" would otherwise become "name."). Sorting the
# intersection gives a deterministic order: str hashing is randomized per
# interpreter run, so raw set iteration order is not reproducible.
img_bases = {os.path.splitext(f)[0] for f in images_6by2}
lbl_bases = {os.path.splitext(f)[0] for f in labels_6by2}
common_bases = sorted(img_bases & lbl_bases)

In [53]:
# Filter to only existing file pairs: the name match only compares stems, so
# confirm that both the .jpg image and the .txt label are actually on disk.
valid_bases = [
    base
    for base in common_bases
    if os.path.exists(os.path.join(images_path, base + '.jpg'))
    and os.path.exists(os.path.join(labels_path, base + '.txt'))
]


In [54]:
# Fail fast when nothing matched; otherwise every split would silently be empty.
assert valid_bases, f"no 6by2 image/label pairs found under {detection_path!r}"
len(valid_bases)

500

In [55]:
# Shuffle and split
random.shuffle(valid_bases)
n = len(common_bases)
train_split = int(0.7 * n)
val_split   = int(0.85 * n)

In [56]:
train_bases = common_bases[:train_split]
val_bases   = common_bases[train_split:val_split]
test_bases  = common_bases[val_split:]

In [57]:
# Copy function
def copy_data(basenames, split):
    """Copy each image/label pair named in ``basenames`` into the ``split`` folder.

    For every base name, ``<base>.jpg`` is copied from ``images_path`` into
    ``destination_dir/<split>/images/`` and ``<base>.txt`` from ``labels_path``
    into ``destination_dir/<split>/labels/``, keeping the file names. The image
    of a pair is copied before its label. Destination folders must already exist.
    """
    def _copy_one(src_dir, kind, filename):
        # Same file name on both sides; only the parent folder changes.
        shutil.copy(os.path.join(src_dir, filename),
                    os.path.join(destination_dir, split, kind, filename))

    for base in basenames:
        _copy_one(images_path, 'images', base + '.jpg')
        _copy_one(labels_path, 'labels', base + '.txt')


In [58]:
# Copy files: materialize each split's image/label pairs in its output folder, train → val → test
for split_name, split_bases in (('train', train_bases), ('val', val_bases), ('test', test_bases)):
    copy_data(split_bases, split_name)

In [60]:
# Sanity check: the three splits must be disjoint and together cover exactly
# the validated pairs (nothing lost, duplicated, or pulled in from elsewhere).
split_total = len(train_bases) + len(val_bases) + len(test_bases)
assert split_total == len(valid_bases), f"splits hold {split_total} items, expected {len(valid_bases)}"
assert set(train_bases) | set(val_bases) | set(test_bases) == set(valid_bases), "splits do not match valid_bases"
split_total

500