In [5]:
import os
import glob
import math
import torch
from torch.utils.data import random_split

In [14]:
def save_to_txt(file_path, data_list, ds, ext):
    """Write an "<image_path> <mask_path>" line for every image in ``data_list``.

    Each entry of ``data_list`` is expected to look like
    ``.../<scene_id>/<subdir>/<image_id>.<suffix>``; only the scene folder name
    (two directory levels up) and the file name before its first '.' are used.
    The written paths are relative, prefixed with ``ds``, and always use '/'.

    Args:
        file_path: Output text file; created or overwritten.
        data_list: Image file paths; written in sorted order.
        ds: Dataset directory name (e.g. 'train_pbr') used as the path prefix.
        ext: Extension written into each image path.
    """
    with open(file_path, 'w') as out_file:
        for img_file in sorted(data_list):
            scene_dir = os.path.dirname(os.path.dirname(img_file))
            scene_id = os.path.basename(scene_dir)
            # Everything before the first '.' of the file name.
            image_id = os.path.basename(img_file).split('.')[0]

            if ds == 'train_primesense':
                # Masks sit inside the dataset folder itself; only the *_000000 file is referenced.
                mask_rel = f"{ds}/{scene_id}/mask_visib/{image_id}_000000.png"
            else:
                # Other datasets read masks from a sibling "<ds>_masks" tree.
                mask_rel = f"{ds}_masks/{scene_id}/mask_visib/{image_id}.png"
            image_rel = f"{ds}/{scene_id}/rgb/{image_id}.{ext}"

            out_file.write(f"{image_rel} {mask_rel}\n")

def generate_split_files(root, split, datasets=None, seed=42):
    """Write labeled/unlabeled image lists for each training set.

    For each dataset, the RGB frames matching ``<root>/<ds>/*/rgb/*.<ext>`` are
    sorted, randomly partitioned with a seeded ``torch.Generator`` into a labeled
    fraction ``split`` and an unlabeled remainder, and written with
    ``save_to_txt`` to ``<root>/splits/1_<N>/<ds>_labeled.txt`` and
    ``<ds>_unlabeled.txt``, where ``N = round(1 / split)``.

    Args:
        root: Directory containing one sub-folder per dataset.
        split: Labeled fraction in (0, 1]. The output directory name assumes it
            has the form 1/N (e.g. ``1/32`` -> ``splits/1_32``).
        datasets: Optional mapping ``{dataset_dir: expected_frame_count}``.
            Defaults to ``{"train_pbr": 50000, "train_primesense": 37584}``.
            Frames of ``train_pbr`` are globbed as ``*.jpg``, all others as ``*.png``.
        seed: Seed for the split generator (42 reproduces the original splits).

    Raises:
        ValueError: If ``split`` is outside (0, 1], or if the number of frames
            found for a dataset differs from its expected count (e.g. wrong
            ``root`` or an incomplete download).
    """
    if not 0 < split <= 1:
        raise ValueError(f'split must be in (0, 1], got {split!r}')
    if datasets is None:
        datasets = {"train_pbr": 50000, "train_primesense": 37584}

    # round() rather than int(): 1/split can land a hair below the intended
    # integer due to float error, and int() would truncate it.
    split_dir_name = f'1_{round(1 / split)}'

    for ds, full_size_train in datasets.items():
        image_ext = "jpg" if ds == 'train_pbr' else "png"
        ids = sorted(glob.glob(os.path.join(root, ds, "*", "rgb", "*." + image_ext)))

        # The split indexes into `ids`; a count mismatch would otherwise crash with
        # an IndexError (too few files) or silently drop frames (too many).
        if len(ids) != full_size_train:
            raise ValueError(
                f'Expected {full_size_train} *.{image_ext} images for {ds!r} under '
                f'{os.path.join(root, ds)!r}, found {len(ids)}'
            )

        l_indexes, u_indexes = random_split(
            dataset=range(full_size_train),
            lengths=[split, 1 - split],
            generator=torch.Generator().manual_seed(seed),
        )
        labeled_ids = [ids[i] for i in l_indexes]
        unlabeled_ids = [ids[i] for i in u_indexes]

        # Create the directory for the txt files if it doesn't exist.
        txt_dir = os.path.join(root, 'splits', split_dir_name)
        os.makedirs(txt_dir, exist_ok=True)

        labeled_txt_path = os.path.join(txt_dir, f"{ds}_labeled.txt")
        unlabeled_txt_path = os.path.join(txt_dir, f"{ds}_unlabeled.txt")

        save_to_txt(labeled_txt_path, labeled_ids, ds, image_ext)
        save_to_txt(unlabeled_txt_path, unlabeled_ids, ds, image_ext)

In [23]:
# Dataset root containing train_pbr/ and train_primesense/. Set the DATA_ROOT
# environment variable (or edit the default) instead of relying on '/'.
root = os.environ.get('DATA_ROOT', '/')  # insert root path
# Labeled fraction of each training set: 1/32 -> <root>/splits/1_32/
split = 1 / 32

generate_split_files(root, split)

In [25]:
# Dataset root containing test_primesense/. Set the DATA_ROOT environment
# variable (or edit the default) instead of relying on '/'.
root = os.environ.get('DATA_ROOT', '/')  # insert root path

# Handle validation dataset: every test_primesense RGB frame goes into a single
# list (no labeled/unlabeled split).
val_dataset = 'test_primesense'
val_image_ext = "png"
val_ids = sorted(glob.glob(os.path.join(root, val_dataset, "*", "rgb", "*." + val_image_ext)))
# Fail loudly (e.g. wrong root) instead of silently writing an empty validation list.
if not val_ids:
    raise FileNotFoundError(
        f"No *.{val_image_ext} images found under "
        f"{os.path.join(root, val_dataset, '*', 'rgb')!r}"
    )

txt_dir = os.path.join(root, 'splits')
os.makedirs(txt_dir, exist_ok=True)

validation_txt_path = os.path.join(txt_dir, f"{val_dataset}.txt")
save_to_txt(validation_txt_path, val_ids, val_dataset, val_image_ext)