In [None]:
import os
import tqdm
import json
from PIL import Image
import matplotlib.pyplot as plt
from pycocotools.mask import encode, decode
import cv2
import numpy as np
import random

# Root directory of the dataset; every sub-directory is scanned as one subset
# of per-sample JSON files.
dataset_root = '/home/dchenbs/workspace/datasets/DSA/DirectSAM-1800px-0424'
subsets = os.listdir(dataset_root)

# Maps each sample's 'image_path' value to the list of JSON files that carry it.
image_path_to_sample = {}
for subset in subsets:
    # The 'Merged' subset is excluded from the index.
    if subset == 'Merged':
        continue
    subset_path = os.path.join(dataset_root, subset)
    if not os.path.isdir(subset_path):
        # A stray file next to the subset folders cannot be listed; skip it
        # instead of letting os.listdir raise.
        continue
    files = os.listdir(subset_path)
    print(subset, len(files))
    for file in tqdm.tqdm(files):
        # Files whose name starts with '_' are ignored; the merge cell of this
        # notebook renames the files it has consumed with that prefix.
        if file.startswith('_'):
            continue
        file = os.path.join(subset_path, file)
        # Context manager so the handle is closed right away rather than
        # leaking one open file per sample.
        with open(file) as f:
            sample = json.load(f)

        image_path_to_sample.setdefault(sample['image_path'], []).append(file)

In [None]:
# Number of distinct image paths indexed before filtering.
print(f'Total number of images: {len(image_path_to_sample)}')
# Keep only the image paths that are referenced by more than one sample file.
image_path_to_sample = dict(
    entry for entry in image_path_to_sample.items() if len(entry[1]) > 1
)
print(f'Total more than one target images: {len(image_path_to_sample)}')

In [None]:
# Distinct parent directories of the image paths in image_path_to_sample,
# listed in the order they are first encountered.
folders = list(
    dict.fromkeys(
        os.path.dirname(image_path)
        for image_path in tqdm.tqdm(image_path_to_sample)
    )
)

print(folders)

In [None]:
# For every image path in image_path_to_sample, merge the sample files that
# reference it into a single JSON file under <dataset_root>/COCO or
# <dataset_root>/ImageNet, then rename the source files with a leading '_'.
os.makedirs(os.path.join(dataset_root, 'COCO'), exist_ok=True)
os.makedirs(os.path.join(dataset_root, 'ImageNet'), exist_ok=True)

image_paths = list(image_path_to_sample.keys())
for image_path in tqdm.tqdm(image_paths):
    files = image_path_to_sample[image_path]

    # The output folder is chosen from a substring of the image path.
    if 'coco2017' in image_path:
        output_folder = 'COCO'
    elif 'imagenet' in image_path:
        output_folder = 'ImageNet'
    else:
        raise ValueError(f'Unknown dataset: {image_path}')

    # Strip the extension from the basename only. Splitting the whole joined
    # path on '.' would truncate it at the first dot anywhere in dataset_root.
    image_stem = os.path.splitext(os.path.basename(image_path))[0]
    output_file = os.path.join(dataset_root, output_folder, image_stem + '.json')

    # Collect the first 'human_label' entry of every source file.
    # NOTE(review): any further entries in a file's 'human_label' list are
    # dropped — confirm each source file carries exactly one.
    human_labels = []
    for file in files:
        with open(file) as f:
            sample = json.load(f)
        human_labels.append(sample['human_label'][0])
    # `sample` is the last file loaded: its remaining fields are written
    # unchanged, with 'human_label' replaced by the collected list.
    sample['human_label'] = human_labels

    with open(output_file, 'w') as f:
        json.dump(sample, f, indent=4)

    # Rename consumed files to '_<name>' in the same directory. The path is
    # rebuilt from its parts because str.replace would also rewrite a
    # directory component that happens to contain the file name.
    for file in files:
        parent_dir, file_name = os.path.split(file)
        os.rename(file, os.path.join(parent_dir, '_' + file_name))