# Generate the masks with SAM and write them to JSON files

In [None]:
import sys
import cv2
import os
import json
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
sys.path.append("..")
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

# --- SAM model: which variant to build, where its weights are, where it runs ---
model_type = "vit_h"
sam_checkpoint = "../sam_vit_h_4b8939.pth"
device = "cuda"

# Look up the builder for the chosen variant, load the checkpoint and move the
# model onto the target device.
build_sam = sam_model_registry[model_type]
sam = build_sam(checkpoint=sam_checkpoint)
sam.to(device=device)

# --- Automatic mask generator settings ---
generator_options = {
    "points_per_side": 32,
    "pred_iou_thresh": 0.86,
    "stability_score_thresh": 0.1,
    "crop_n_layers": 1,
    "crop_overlap_ratio": 0,
    "crop_n_points_downscale_factor": 2,
    # Requires open-cv to run post-processing
    "min_mask_region_area": 1000,
}
mask_generator = SamAutomaticMaskGenerator(model=sam, **generator_options)

# Input: one sub-folder of images per entry of file_list.
# Alternative (currently disabled) source root:
# src_root = 'D://research/OnsetDetection/onset_imgs'
src_root = 'D://research/OnsetStatistics/data/isolated_jpgs3'

# Output: the JSON files are written below this root.
# Alternative (currently disabled) destination root:
# dst_root = 'D://research/OnsetDetection/pure_imgs2'
dst_root ='D://research/OnsetStatistics/data/isolated_jsons3'

# Sorted so that positional slices of file_list are reproducible.
file_list = sorted(os.listdir(src_root))

# Run the SAM mask generator on every image of the selected folder(s) and
# store the masks of each image as one JSON file in dst_root/<folder>/.
# NOTE: the slice below restricts the run to the single folder at index 540.
for file in file_list[540:541]:
    print(file)
    img_list = sorted(os.listdir(os.path.join(src_root, file)))

    # exist_ok makes a separate "does it exist" check unnecessary.
    os.makedirs(os.path.join(dst_root, file), exist_ok=True)

    for img in img_list:
        img_path = os.path.join(src_root, file, img)

        # cv2.imread signals an unreadable/non-image file by returning None
        # rather than raising, so check before using the result.
        image = cv2.imread(img_path)
        if image is None:
            print('skipped unreadable image:', img_path)
            continue

        # NOTE(review): cv2.imread yields BGR channel order; SAM's own examples
        # convert to RGB first. Left unchanged here - confirm whether a
        # conversion is wanted for these images.
        # image = cv2.fastNlMeansDenoisingColored(image, None, 10, 10, 7, 21)
        # Binarise: values above 3 become 255, everything else 0.
        _, image = cv2.threshold(image, 3, 255, cv2.THRESH_BINARY)
        # image = cv2.GaussianBlur(image, (5,5), 1.5)

        masks = mask_generator.generate(image)

        # Write to JSON: transpose the list of per-mask dicts into one dict of
        # parallel lists. The segmentation arrays are converted with .tolist()
        # because numpy arrays are not JSON serialisable.
        new_dict = {
            'segmentation': [mask['segmentation'].tolist() for mask in masks],
            'area': [mask['area'] for mask in masks],
            'bbox': [mask['bbox'] for mask in masks],
            'predicted_iou': [mask['predicted_iou'] for mask in masks],
            'point_coords': [mask['point_coords'] for mask in masks],
            'stability_score': [mask['stability_score'] for mask in masks],
            'crop_box': [mask['crop_box'] for mask in masks],
        }

        # img[:-6] drops the last six characters of the image file name
        # (e.g. an '_a.jpg' suffix) to form the JSON file name.
        json_path = os.path.join(dst_root, file, img[:-6] + '.json')
        with open(json_path, 'w') as f:
            json.dump(new_dict, f)

        print(img)

# Get annotated images from the JSON files

In [5]:
def show_anns(anns, dst_path):
    """Render the masks in *anns* as a colour overlay and save it to *dst_path*.

    Every entry of *anns* is read through its 'area' and 'segmentation' keys.
    Masks are painted in order of decreasing area, so later (smaller) masks
    overwrite earlier ones where they overlap. Each mask gets a random RGB
    colour with alpha 0.35. When *anns* is empty, a 240x240 all-ones RGB image
    is saved instead.
    """
    if len(anns) != 0:
        by_area = sorted(anns, key=lambda ann: ann['area'], reverse=True)
        first_seg = by_area[0]['segmentation']
        height, width = first_seg.shape[0], first_seg.shape[1]

        # RGBA canvas with alpha 0 everywhere, i.e. fully transparent.
        img = np.ones((height, width, 4))
        img[..., 3] = 0
        for ann in by_area:
            rgba = np.concatenate([np.random.random(3), [0.35]])
            img[ann['segmentation']] = rgba
    else:
        img = np.ones((240, 240, 3))

    # Remove axes, ticks and margins so only the image itself is saved.
    plt.axis('off')
    axes = plt.gca()
    for axis in (axes.xaxis, axes.yaxis):
        axis.set_major_locator(plt.NullLocator())
    plt.subplots_adjust(top=1, bottom=0, left=0, right=1, hspace=0, wspace=0)
    plt.margins(0, 0)

    plt.imshow(img)
    plt.savefig(dst_path, bbox_inches="tight", pad_inches=0.0)
    plt.close()

In [None]:
import os
import json
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
# Filter the stored masks of every image and render the survivors with
# show_anns. A mask is kept when its stored area exceeds 100 and the mean of
# the binarised source image under the mask exceeds 0.01.
# NOTE(review): src_root is on drive E: while the mask-generation cell writes
# its JSON files under D://research/... - confirm this is the intended input.
src_root = 'E://OnsetStatistics/data/isolated_jsons3'
file_list = sorted(os.listdir(src_root))

img_root = 'D://research/OnsetStatistics/data/isolated_jpgs3'
dst_root = 'D://research/OnsetStatistics/data/isolated_pures3'
for file in file_list:

    f_list = sorted(os.listdir(os.path.join(src_root, file)))
    print(file)

    for f in f_list:
        stem = f[:-5]  # file name without its last five characters ('.json')

        # Load the matching source image as a 0/1 float array. The context
        # manager closes the image file; the shape follows the real image size
        # instead of assuming 240x240.
        with Image.open(os.path.join(img_root, file, stem + '_a.jpg')) as imgg:
            width, height = imgg.size
            img_array = (
                np.array(imgg.convert('1').getdata())
                .reshape(height, width)
                .astype(np.float32) / 255
            )

        # Context manager so the JSON file handle is closed after reading.
        with open(os.path.join(src_root, file, f), 'r') as jjson:
            a = json.load(jjson)
        # print(type(a))

        final_masks = []
        for seg_data, area in zip(a['segmentation'], a['area']):
            if area <= 100:
                continue

            # Force bool so that indexing below is a mask selection, never
            # integer fancy indexing.
            seg = np.asarray(seg_data, dtype=bool).reshape(img_array.shape)

            num = np.count_nonzero(seg)
            if num == 0:
                # An empty mask has no mean intensity; skip it instead of
                # dividing by zero.
                continue

            # Mean image intensity under the mask, computed with boolean
            # indexing instead of a per-pixel Python loop.
            sum_intensity = img_array[seg].sum()
            # print(sum_intensity/num)
            if sum_intensity / num > 0.01:
                final_masks.append({'segmentation': seg, 'area': area})

        dst_path = os.path.join(dst_root, file)
        os.makedirs(dst_path, exist_ok=True)
        # plt.imshow(img_array, cmap='Greys_r')
        # plt.axis('off')
        show_anns(final_masks, os.path.join(dst_path, stem + '.jpg'))

        # plt.savefig(dst_path, bbox_inches="tight", pad_inches=0.0)
        # plt.close()
        # print(f)