In [1]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
import os
import shutil
from tqdm import tqdm
from joblib import Parallel, delayed
from skimage.measure import label

In [2]:
def show_mask(mask, ax, random_color=False):
    """Overlay a mask on ``ax`` as a translucent RGBA image.

    Parameters
    ----------
    mask : array
        Its last two dimensions are taken as (height, width); the array is
        reshaped to (height, width, 1), so it must hold exactly
        height * width elements.
    ax : object with an ``imshow`` method
        Target the overlay is drawn on (e.g. a matplotlib Axes).
    random_color : bool
        When True the RGB part of the colour is drawn from
        ``np.random.random``; otherwise a fixed blue is used. Alpha is 0.6
        in both cases.
    """
    if random_color:
        rgba = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        rgba = np.array([30/255, 144/255, 255/255, 0.6])

    height, width = mask.shape[-2:]
    # Broadcasting (h, w, 1) against (1, 1, 4) gives an (h, w, 4) image whose
    # pixels are the colour scaled by the mask value.
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)
    
def show_points(coords, labels, ax, marker_size=375):
    """Scatter prompt points on ``ax`` as stars, coloured by label.

    Rows of ``coords`` whose entry in ``labels`` equals 1 are drawn in green,
    rows whose label equals 0 in red. Column 0 of ``coords`` is used as x and
    column 1 as y. ``marker_size`` is forwarded as the scatter ``s`` argument.
    """
    star_style = dict(marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    # Positive points first, then negative ones.
    for label_value, colour in ((1, 'green'), (0, 'red')):
        selected = coords[labels == label_value]
        ax.scatter(selected[:, 0], selected[:, 1], color=colour, **star_style)
    
def show_box(box, ax):
    """Draw the outline of a box on ``ax``.

    ``box`` is indexed as (x0, y0, x1, y1): the first two entries give the
    rectangle's anchor corner and the differences to the last two give its
    width and height. The outline is green with a fully transparent fill.
    """
    left, top, right, bottom = box[0], box[1], box[2], box[3]
    outline = plt.Rectangle(
        (left, top),
        right - left,
        bottom - top,
        edgecolor='green',
        facecolor=(0,0,0,0),
        lw=2,
    )
    ax.add_patch(outline)


In [3]:
def show_anns(anns):
    """Paint every annotation mask onto the current axes in a random colour.

    ``anns`` is a sequence of dicts providing an ``'area'`` value (used for
    ordering) and a ``'segmentation'`` array (used to index the canvas).
    Masks are painted from largest to smallest area, so smaller masks end up
    on top. Nothing is drawn when ``anns`` is empty.
    """
    if len(anns) < 1:
        return

    by_area_desc = sorted(anns, key=lambda ann: ann['area'], reverse=True)
    axes = plt.gca()
    axes.set_autoscale_on(False)

    # Canvas sized from the largest annotation; alpha starts at 0 so pixels
    # not covered by any mask stay transparent.
    seg_shape = by_area_desc[0]['segmentation'].shape
    canvas = np.ones((seg_shape[0], seg_shape[1], 4))
    canvas[:, :, 3] = 0

    for ann in by_area_desc:
        rgba = np.concatenate([np.random.random(3), [0.35]])
        canvas[ann['segmentation']] = rgba
    axes.imshow(canvas)

In [4]:
def get_files(PATH):
    """Recursively collect the paths of all files below ``PATH``.

    Returns a list of ``os.path.join(directory, filename)`` strings in the
    order ``os.walk`` yields them; directories themselves are not listed.
    """
    return [
        os.path.join(folder, name)
        for folder, _subdirs, names in os.walk(PATH)
        for name in names
    ]

In [5]:
def segment_connected_components(mask, min_area_ratio=0.004):
    """Split a mask into its connected components, dropping the small ones.

    Parameters
    ----------
    mask : array
        Label input passed to ``skimage.measure.label`` with
        ``connectivity=2``. Assumed to be a 2-D mask whose background is 0 —
        TODO confirm against callers.
    min_area_ratio : float, optional
        A component is discarded when its pixel count is below
        ``shape[0] * shape[1] * min_area_ratio``. The default of 0.004 is the
        value that used to be hard-coded.

    Returns
    -------
    dict
        Maps each kept component label (an int starting at 1) to an integer
        array of the mask's shape holding 1 inside the component and 0
        elsewhere. Empty when no component is large enough.
    """
    labeled_array, num_features = label(mask, connectivity=2, return_num=True)

    min_area = labeled_array.shape[0] * labeled_array.shape[1] * min_area_ratio
    # A single pass over the label image yields every component's pixel
    # count, instead of building and summing one full-size mask per label.
    areas = np.bincount(labeled_array.ravel(), minlength=num_features + 1)

    components = {}
    for label_idx in range(1, num_features + 1):
        if areas[label_idx] < min_area:
            continue
        components[label_idx] = (labeled_array == label_idx).astype(int)

    return components

In [6]:
# Make the parent directory importable — presumably so a local checkout of
# `segment_anything` can be found; confirm against the repository layout.
import sys
sys.path.append("..")

from segment_anything import sam_model_registry, SamPredictor
# Checkpoint file given as a relative path.
sam_checkpoint = "sam_vit_l_0b3195.pth"

# Key used to pick the model builder from `sam_model_registry`; presumably it
# has to match the architecture of the checkpoint above.
model_type = "vit_l"

# Prefer the GPU when torch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Build the model from the checkpoint and move it to the selected device.
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

# Module-level predictor; `istestdatas` further down reads this global.
predictor = SamPredictor(sam)

  from .autonotebook import tqdm as notebook_tqdm


In [7]:
# Ground-truth mask directories to process. The code below derives the image
# path ('/im') and the output directories ('/enhance_im', '/enhance_gt',
# '/enhance_sam') by substituting the '/gt' substring, so every entry must
# contain '/gt'.
# NOTE(review): absolute, machine-specific path — adjust before running on
# another machine.
pathlist = [r'D:/Code/deep_learning/PrivateWork/20240514DIS_SAM/IS-Net/DIS5K/DIS5K-test/gt/',
            ]

In [10]:
def istestdatas(impath):
    """Write per-component image / ground-truth / SAM-mask triples for one mask.

    The ground-truth mask at ``impath`` is binarised at 128 and split with
    ``segment_connected_components``. When more than one component is kept,
    the whole binarised mask is added as an extra entry under the key ``'0'``.
    For every entry the module-level ``predictor`` is prompted with the
    entry's bounding box, and three files are written, each tagged with
    ``_comp_<key>``:

    * a copy of the source image under ``/enhance_im``,
    * the component mask (0/255) under ``/enhance_gt``,
    * the predicted SAM mask (0/255) under ``/enhance_sam``.

    Paths are derived by substring replacement: ``/gt`` -> ``/im`` and
    ``.png`` -> ``.jpg`` for the source image, and ``/gt`` / ``/im`` ->
    ``/enhance_*`` for the outputs. The output directories must already exist.

    Parameters
    ----------
    impath : str
        Path of the ground-truth mask; expected to contain ``/gt`` and end
        in ``.png``.

    Raises
    ------
    FileNotFoundError
        If the mask or its source image cannot be read by ``cv2.imread``.
    """
    gt = cv2.imread(impath)
    if gt is None:
        raise FileNotFoundError(f"cannot read ground-truth mask: {impath}")
    gt = cv2.cvtColor(gt, cv2.COLOR_BGR2GRAY)

    tpath = impath.replace('/gt','/im').replace('.png','.jpg')
    image = cv2.imread(tpath)
    if image is None:
        raise FileNotFoundError(f"cannot read source image: {tpath}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Binarise: pixels >= 128 become 1, everything else 0.
    mask = (gt >= 128).astype(np.uint8)

    connected_components = segment_connected_components(mask)
    if len(connected_components) > 1:
        # Also emit the complete mask alongside the individual parts.
        connected_components['0'] = mask
    if not connected_components:
        return

    # The image is the same for every component, so hand it to the predictor
    # once per image instead of once per component.
    predictor.set_image(image)

    for idx, component in connected_components.items():
        mask_input = np.array(component, dtype="uint8")
        if mask_input.max() == 1:
            mask_input = mask_input * 255

        # Tight bounding box of the component in (left, top, right, bottom) order.
        rows, cols = np.where(mask_input > 125)
        input_box = np.array([np.min(cols), np.min(rows), np.max(cols), np.max(rows)])

        # Only a box prompt is given. The former `point_labels` argument was
        # dropped because the predictor ignores it when no `point_coords`
        # are supplied.
        masks, _, _ = predictor.predict(
            box=input_box,
            multimask_output=True,
        )
        # NOTE(review): with multimask_output=True several candidate masks and
        # their scores are returned; index 0 is kept regardless of score —
        # confirm this is intended.
        sam_mask = masks[0]

        suffix = '_comp_' + str(idx)
        enhance_im_file = tpath.replace('/im','/enhance_im').replace('.jpg', suffix + '.jpg')
        print(enhance_im_file)
        shutil.copy(tpath, enhance_im_file)
        cv2.imwrite(impath.replace('/gt','/enhance_gt').replace('.png', suffix + '.png'), mask_input)
        cv2.imwrite(impath.replace('/gt','/enhance_sam').replace('.png', suffix + '.png'), np.array(sam_mask, dtype="uint8")*255)
    
# Driver: for every ground-truth directory in `pathlist`, make sure the three
# sibling output directories exist, then run `istestdatas` on each file found.
for nums, k in enumerate(pathlist, start=1):
    print(nums, '/', len(pathlist))
    impaths = get_files(k)
    enhance_im_path = k.replace('/gt','/enhance_im')
    enhance_gt_path = k.replace('/gt','/enhance_gt')
    enhance_sam_path = k.replace('/gt','/enhance_sam')
    for out_dir in (enhance_im_path, enhance_gt_path, enhance_sam_path):
        # exist_ok=True already makes this a no-op for existing directories,
        # so no separate existence check is needed.
        os.makedirs(out_dir, exist_ok=True)
    print(len(impaths))
    for impath in tqdm(impaths):
        istestdatas(impath)

1 / 1
3


 33%|███▎      | 1/3 [00:00<00:01,  1.63it/s]

D:/Code/deep_learning/PrivateWork/20240514DIS_SAM/IS-Net/DIS5K/DIS5K-test/enhance_im/1#Accessories#1#Bag#2339506821_83cf9f1d22_o_comp_1.jpg


 67%|██████▋   | 2/3 [00:00<00:00,  2.59it/s]

D:/Code/deep_learning/PrivateWork/20240514DIS_SAM/IS-Net/DIS5K/DIS5K-test/enhance_im/1#Accessories#1#Bag#3292738108_c51336a8be_o_comp_1.jpg
D:/Code/deep_learning/PrivateWork/20240514DIS_SAM/IS-Net/DIS5K/DIS5K-test/enhance_im/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_1.jpg
D:/Code/deep_learning/PrivateWork/20240514DIS_SAM/IS-Net/DIS5K/DIS5K-test/enhance_im/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_2.jpg


100%|██████████| 3/3 [00:01<00:00,  1.74it/s]

D:/Code/deep_learning/PrivateWork/20240514DIS_SAM/IS-Net/DIS5K/DIS5K-test/enhance_im/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_0.jpg



