In [None]:
import glob
import os
from typing import Any, Dict, List
from pycocotools.coco import COCO
from pycocotools.mask import decode
import json
import matplotlib.pyplot as plt
import cv2  # type: ignore
import numpy as np
import shutil

def sort_by_num(path, dataset='SA_1B'):
    """Return the integer index embedded in an annotation file path.

    Intended as a ``sorted`` key so files come out in numeric rather than
    lexicographic order (``sa_2`` before ``sa_10``).

    Args:
        path: Path to the file.
        dataset: Selects the naming scheme. ``"./LVIS_output/"`` means the whole
            file stem is the number (``.../123.json``); any other value means the
            number follows the last underscore in the path (``.../sa_123.json``).

    Returns:
        The parsed integer (``int`` raises ValueError if it is not numeric).
    """
    if dataset != "./LVIS_output/":
        tail = path.split("_")[-1]
    else:
        tail = path.split("/")[-1]
    stem = tail.split(".")[0]
    return int(stem)

# RGBA overlay palette read by get_color to tint masks: 26 entries, components
# in [0, 1], alpha fixed at 0.35 so the underlying image stays visible.
# Names below are approximate descriptions of each RGB triple.
color_list = [
    [1.0, 0.0, 0.0, 0.35],  # Red
    [0.0, 1.0, 0.0, 0.35],  # Green
    [0.0, 0.0, 1.0, 0.35],  # Blue
    [1.0, 1.0, 0.0, 0.35],  # Yellow
    [1.0, 0.0, 1.0, 0.35],  # Magenta
    [0.0, 1.0, 1.0, 0.35],  # Cyan
    [0.5, 0.0, 0.0, 0.35],  # Maroon
    [0.0, 0.5, 0.0, 0.35],  # Dark green
    [0.0, 0.0, 0.5, 0.35],  # Navy
    [0.5, 0.5, 0.0, 0.35],  # Olive
    [0.0, 0.5, 0.5, 0.35],  # Teal
    [0.5, 0.0, 0.5, 0.35],  # Purple
    [0.3, 0.3, 0.3, 0.35],  # Dark gray
    [1.0, 0.5, 0.0, 0.35],  # Orange
    [0.5, 1.0, 0.0, 0.35],  # Chartreuse (yellow-green)
    [0.0, 0.5, 1.0, 0.35],  # Azure / sky blue
    [1.0, 0.0, 0.5, 0.35],  # Rose
    [0.5, 0.0, 1.0, 0.35],  # Violet / indigo
    [1.0, 0.5, 0.5, 0.35],  # Light coral / salmon
    [0.5, 1.0, 0.5, 0.35],  # Light green
    [0.5, 0.5, 1.0, 0.35],  # Light blue / periwinkle
    [0.8, 0.2, 0.2, 0.35],  # Brick red
    [0.2, 0.8, 0.2, 0.35],  # Bright (lime) green
    [0.2, 0.2, 0.8, 0.35],  # Medium blue
    [0.8, 0.8, 0.2, 0.35],  # Dark lemon yellow
    [0.2, 0.8, 0.8, 0.35],  # Lake blue (turquoise)
]

def get_color(index):
    """Return the RGBA color for ``index`` from the module-level ``color_list``.

    The index wraps around the palette length, so any integer is accepted and
    consecutive indices cycle through every entry. Indices already within
    ``range(-len(color_list), len(color_list))`` select the same entry as plain
    list indexing did.

    Args:
        index: Integer position, e.g. a mask's ordinal.

    Returns:
        The ``[r, g, b, a]`` list stored in ``color_list`` at the wrapped index.
    """
    # Reading a module-level name needs no `global`; the modulo prevents
    # IndexError for indices past the end of the palette.
    return color_list[index % len(color_list)]

def show_anns(anns):
    """Overlay a list of masks on the current matplotlib axes.

    All masks are painted into one RGBA layer (transparent everywhere else),
    which is drawn with a single ``imshow``. Colors come from the palette,
    cycled by each annotation's position. Later masks overwrite earlier ones
    where they overlap.

    Args:
        anns: Sequence of dicts whose ``'segmentation'`` entry is a 2-D mask of
            shape (H, W); non-zero pixels count as covered. All masks are
            assumed to share the first mask's shape.
    """
    if len(anns) == 0:
        return

    ax = plt.gca()
    ax.set_autoscale_on(False)
    height, width = anns[0]['segmentation'].shape[:2]
    # Start fully transparent; only pixels covered by some mask get a color.
    img = np.ones((height, width, 4))
    img[:, :, 3] = 0
    for idx, ann in enumerate(anns):
        # Coerce to bool so 0/1 integer masks select pixels rather than being
        # interpreted as integer row indices.
        m = np.asarray(ann['segmentation'], dtype=bool)
        # Use this file's palette helper (the old `amg.get_color` referenced a
        # module that is never imported) and wrap by the real palette size.
        img[m] = get_color(idx % len(color_list))
    ax.imshow(img)

def write_output(filename: str, path: str, appendix: str) -> str:
    """Create a fresh, empty output folder for one annotation file.

    The folder is ``<path>/<stem>``, where ``<stem>`` is the base name of
    ``filename`` with a trailing ``.json`` removed.

    Destructive: if the folder already exists, it and all of its contents are
    deleted before it is recreated.

    Args:
        filename: Path to the source JSON annotation file.
        path: Parent directory under which the folder is created.
        appendix: Currently unused; kept so existing call sites keep working.

    Returns:
        Path of the newly created, empty folder.
    """
    # os.path.basename handles the platform's separators, unlike split("/").
    base = os.path.basename(filename)
    stem = base[:-len('.json')] if base.endswith('.json') else base
    save_folder = os.path.join(path, stem)

    if os.path.exists(save_folder):
        shutil.rmtree(save_folder)
    os.makedirs(save_folder, exist_ok=True)
    return save_folder
def show_mask(mask, ax, idx, random_color=False):
    """Draw one mask as a translucent colored layer on ``ax``.

    Args:
        mask: Array whose last two dimensions are (H, W), expected to hold 0/1
            values (larger values would push the RGBA channels above 1).
        ax: Matplotlib axes to draw on.
        idx: Mask ordinal; selects the palette color, wrapping around.
        random_color: If True, use a random RGB color with alpha 0.6 instead of
            the palette color.
    """
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        # Wrap by the actual palette size so every palette entry is reachable
        # (a hard-coded `% 25` skipped the last of the 26 colors).
        color = get_color(idx % len(color_list))

    h, w = mask.shape[-2:]
    # (H, W, 1) * (1, 1, 4) broadcasts to an (H, W, 4) RGBA image whose alpha
    # is zero outside the mask, so only masked pixels are tinted.
    mask_image = mask.reshape(h, w, 1) * np.array(color).reshape(1, 1, -1)
    ax.imshow(mask_image)



Set the paths and the number of images to process

In [None]:
# Output root: generated ground-truth masks and overlay images go here.
path = "../dataset/SA_1B_GT/"

# Input directory holding the images and their per-image JSON annotations.
directory = '../dataset/SA_1B'

# How many annotation files (taken in numeric order) to process.
num_files = 100

# Collect every annotation file, then order them numerically (sa_2 before sa_10).
json_files = glob.glob(f"{directory}/*.json")
sorted_targets = sorted(json_files, key=sort_by_num)

Generate ground-truth mask PNGs from each JSON annotation file

In [None]:
for json_file in sorted_targets[:num_files]:
    print(f'generating masks for {json_file}')
    # The image sits next to its annotation with the same stem. Swap only the
    # extension: str.replace('json', 'jpg') would also rewrite any "json"
    # appearing in directory names.
    image_path = os.path.splitext(json_file)[0] + '.jpg'
    # The pixels are not needed in this cell; loading the image only serves to
    # skip annotation files whose image is missing or unreadable.
    image = cv2.imread(image_path)
    if image is None:
        print(f"Could not load '{image_path}' as an image, skipping...")
        continue
    with open(json_file, 'r') as file:
        data = json.load(file)
    annotations = data['annotations']
    print(f"There are {len(annotations)} masks in this image.")
    # Note: write_output wipes and recreates this image's output folder.
    save_folder = write_output(json_file, path, 'gt')
    for i, annotation in enumerate(annotations):
        # COCO run-length encoding of the mask.
        rle = {
            'counts': annotation['segmentation']['counts'],
            'size': annotation['segmentation']['size'],
        }
        # For a single RLE dict, pycocotools returns a 2-D uint8 array of 0/1.
        mask = decode(rle)
        # Scale 0/1 to 0/255 so the PNG is viewable as black/white.
        cv2.imwrite(os.path.join(save_folder, f"{i}.png"), mask * 255)


Overlay the colored ground-truth masks on each image and save the result

In [None]:
for json_file in sorted_targets[:num_files]:
    print(f'generating entire masks for {json_file}')
    stem = os.path.splitext(os.path.basename(json_file))[0]
    # The image sits next to its annotation; swap only the extension, since
    # str.replace('json', 'jpg') would also rewrite directory names.
    image_path = os.path.splitext(json_file)[0] + '.jpg'
    image = cv2.imread(image_path)
    if image is None:
        print(f"Could not load '{image_path}' as an image, skipping...")
        continue
    # OpenCV loads BGR; matplotlib expects RGB.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    with open(json_file, 'r') as file:
        data = json.load(file)

    fig, ax = plt.subplots(figsize=(20, 20))
    ax.imshow(image)
    for i, annotation in enumerate(data['annotations']):
        rle = {
            'counts': annotation['segmentation']['counts'],
            'size': annotation['segmentation']['size'],
        }
        mask = decode(rle)
        show_mask(mask, ax, i)
    ax.axis('off')

    # Create the folder if needed so this cell also works on its own; use
    # makedirs rather than a wipe so any masks already written there survive.
    save_folder = os.path.join(path, stem)
    os.makedirs(save_folder, exist_ok=True)
    output_name = os.path.join(save_folder, f"{stem}_gt.png")

    print(f'writing to {output_name}')
    fig.savefig(output_name, bbox_inches='tight', pad_inches=0)
    # Close explicitly: one 20x20-inch figure per image would otherwise pile up.
    plt.close(fig)