In [None]:
import os
import tqdm
import json
from PIL import Image
import matplotlib.pyplot as plt
from pycocotools.mask import encode, decode
import cv2
import numpy as np
import random

# Root directory of the dataset. Every directory directly under it (except
# 'Merged') is treated as a subset containing JSON sample files.
dataset_root = '/home/dchenbs/workspace/datasets/DirectSAPlus/DirectSAM-1800px-0424'
subsets = os.listdir(dataset_root)

# Index: image path -> list of sample JSON files that reference that image.
image_path_to_sample = {}
for subset in subsets:
    # 'Merged' is the output directory written by the merge step further
    # down in this notebook; it must not be indexed as an input subset.
    if subset == 'Merged':
        continue
    subset_path = os.path.join(dataset_root, subset)
    # Ignore stray regular files in the root: os.listdir() would raise
    # NotADirectoryError on them.
    if not os.path.isdir(subset_path):
        continue
    files = os.listdir(subset_path)
    print(subset, len(files))
    for file in tqdm.tqdm(files):
        file = os.path.join(subset_path, file)
        # The context manager closes each handle immediately instead of
        # leaving one open descriptor per sample file to the garbage collector.
        with open(file) as f:
            sample = json.load(f)

        # Group sample files by the image they annotate.
        image_path_to_sample.setdefault(sample['image_path'], []).append(file)

In [None]:
# Report how many distinct images were indexed, then narrow the index down to
# the images that are referenced by at least two sample files.
print(f'Total number of images: {len(image_path_to_sample)}')
image_path_to_sample = dict(
    entry for entry in image_path_to_sample.items() if len(entry[1]) >= 2
)
print(f'Total more than one target images: {len(image_path_to_sample)}')

In [None]:
def _mask_to_rle(mask):
    """Encode a binary mask as a COCO RLE dict that json can serialise.

    pycocotools returns ``counts`` as bytes; it is decoded to ``str`` so the
    dict can be passed to ``json.dump``.
    """
    rle = encode(np.array(mask, order='F', dtype=np.uint8))
    rle['counts'] = rle['counts'].decode('utf-8')
    return rle


# For every indexed image: OR together the 'target' masks of all its sample
# files, write one combined sample into <dataset_root>/Merged, then mark the
# source files as consumed by prefixing their names with '_'.
merged_dir = os.path.join(dataset_root, 'Merged')
os.makedirs(merged_dir, exist_ok=True)
image_paths = list(image_path_to_sample.keys())
for image_path in tqdm.tqdm(image_paths):
    files = image_path_to_sample[image_path]

    # Strip the extension from the basename only. Splitting the whole joined
    # path on '.' would truncate it at the first dot in any directory name.
    stem = os.path.splitext(os.path.basename(image_path))[0]
    output_file = os.path.join(merged_dir, stem + '.json')

    target = None
    for file in files:
        with open(file) as f:
            sample = json.load(f)
        label = decode(sample['target']) > 0

        # Boolean OR instead of an in-place uint8 sum: no overflow and the
        # first decoded mask is not mutated.
        target = label if target is None else np.logical_or(target, label)

    # NOTE: `sample` is the last file loaded above; its remaining fields
    # (including 'prediction') are carried over into the merged output as-is.
    sample['target'] = _mask_to_rle(target)

    # 'merged' is the union of the stored prediction and the combined target.
    merged = np.logical_or(decode(sample['prediction']) > 0, target)
    sample['merged'] = _mask_to_rle(merged)

    # Close (and therefore flush) the output before the sources are renamed.
    with open(output_file, 'w') as f:
        json.dump(sample, f, indent=4)

    for file in files:
        # Build the new path from dirname + basename; str.replace() on the
        # full path would also rewrite a directory component that happens to
        # contain the file name.
        renamed = os.path.join(os.path.dirname(file), '_' + os.path.basename(file))
        os.rename(file, renamed)


In [None]:
# Undo helper, intentionally left commented out. It strips the leading '_'
# from every file name in the LVIS subset directory, i.e. it reverts the '_'
# prefix rename that the merge step applies to its source files.
# Uncomment to run. NOTE: `dir` shadows the Python builtin of the same name.
# dir = '/home/dchenbs/workspace/datasets/DirectSAPlus/DirectSAM-1800px-0424/LVIS/'

# for file in os.listdir(dir):
#     if file.startswith('_'):
#         print(file)
#         os.rename(os.path.join(dir, file), os.path.join(dir, file[1:]))

In [None]:
# Visual sanity check: for up to 10 randomly chosen images show the image,
# one panel per source target mask, and a final panel with their union.
color_maps = ['Reds', 'Greens', 'Blues', 'Greys']

# Loop-invariant setup, hoisted out of the sampling loop.
# The masks are dilated with a 7x7 rectangular kernel before plotting.
kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (7, 7))
candidate_paths = list(image_path_to_sample.keys())
if not candidate_paths:
    # random.choice() raises IndexError on an empty sequence.
    print('No images to visualize.')

for _ in range(10 if candidate_paths else 0):
    image_path = random.choice(candidate_paths)
    files = image_path_to_sample[image_path]

    print(image_path, files)
    # Close the underlying image file as soon as the resized copy exists.
    with Image.open(image_path) as raw_image:
        image = raw_image.resize((1800, 1800))

    # Panels: [image] + one per sample file + [merged].
    n_panels = len(files) + 2
    fig, axes = plt.subplots(1, n_panels, figsize=(n_panels * 5, 5))
    axes[0].imshow(image)
    axes[0].axis('off')

    target = None
    for i, file in enumerate(files):
        # The merge step renames its inputs to '_<name>'. Prefer that path
        # when it exists, otherwise fall back to the original so this cell
        # also works before the merge step has been run.
        renamed = os.path.join(os.path.dirname(file), '_' + os.path.basename(file))
        if os.path.exists(renamed):
            file = renamed
        with open(file) as f:
            sample = json.load(f)
        label = decode(sample['target'])
        label = cv2.dilate(label.astype('uint8'), kernel, iterations=1)

        ax = axes[i + 1]
        # Wrap around the palette so more than four sample files per image
        # no longer raise IndexError.
        ax.imshow(label * 255, cmap=color_maps[i % len(color_maps)])
        ax.imshow(image, alpha=0.3)
        ax.axis('off')
        ax.set_title(sample['info']['dataset'])

        # Union of the dilated masks, kept boolean so nothing is mutated
        # in place and nothing can overflow.
        target = label > 0 if target is None else np.logical_or(target, label > 0)

    ax = axes[-1]
    ax.imshow(target * 255, cmap=color_maps[-1])
    ax.imshow(image, alpha=0.3)
    ax.axis('off')
    ax.set_title('Merged')
    plt.show()
    # Release the figure; otherwise every iteration keeps one alive.
    plt.close(fig)