In [1]:
import os
from pathlib import Path

import cv2
import matplotlib.image as mpimg
import numpy as np
import matplotlib.pyplot as plt
import torch

from skimage.morphology import thin, skeletonize
from skimage import measure

from segment_anything import sam_model_registry, SamPredictor


In [2]:
model_type = "vit_h"
# Prefer the GPU but fall back to CPU so the notebook still runs (slowly) on
# machines without CUDA instead of crashing in `sam.to`.
device = "cuda" if torch.cuda.is_available() else "cpu"
# The checkpoint location can be overridden via the SAM_CHECKPOINT environment
# variable without editing the notebook; the default is the original path.
SAM_CHECKPOINT = os.environ.get(
    "SAM_CHECKPOINT",
    '/mnt/offline_prod/Alon/transfer/tarp_sam_inference_26_10_25/sam_vit_h_4b8939.pth',
)
sam = sam_model_registry[model_type](checkpoint=SAM_CHECKPOINT)
sam.to(device=device)
# Module-level predictor used by SAM_predict.
predictor = SamPredictor(sam)

In [3]:
def calc_IOU(bounding_box, img, mask1, mask2, display_images=False, mask1_no_dilate=None):
    """Intersection-over-union of the foregrounds of two masks inside a bounding box.

    Parameters
    ----------
    bounding_box : tuple
        ``(x, y, w, h)``; the compared window is ``mask[y:y + h, x:x + w]``.
    img : np.ndarray
        Image used only by the optional visualisation (indexed as ``img[rows, cols, :]``).
    mask1, mask2 : np.ndarray
        2-D masks; any value > 0 counts as foreground.
    display_images : bool
        If true, print the IOU and the margin offsets, and show image / mask1 / mask2
        around the window with a 30 px margin (via ``lower``/``upper``/``display_arrays``).
    mask1_no_dilate : unused
        Kept for call compatibility; not read by this function.

    Returns
    -------
    float
        ``|fg1 & fg2| / |fg1 | fg2|`` within the window, or 0.0 when the union is empty.
    """
    x, y, w, h = bounding_box
    u, d, l, r = y, y + h, x, x + w

    fg1 = mask1[u:d, l:r] > 0
    fg2 = mask2[u:d, l:r] > 0

    # Vectorised counts; the builtin sum() iterated the pixels one by one in Python.
    intersection = np.count_nonzero(fg1 & fg2)
    union = np.count_nonzero(fg1 | fg2)
    IOU = intersection / union if union > 0 else 0.0

    if display_images:
        _buffer = 30

        n_u = lower(u, _buffer)
        n_d = upper(d, _buffer)
        n_r = upper(r, _buffer)
        n_l = lower(l, _buffer)

        print('\n\n')
        print('up diff - ', str(n_u - u))
        print('down diff - ', str(n_d - d))
        print('right diff - ', str(n_r - r))
        print('left diff - ', str(n_l - l))

        print('IOU - ', IOU)

        display_arrays([img[n_u:n_d, n_l:n_r, :],
                        mask1[n_u:n_d, n_l:n_r],
                        mask2[n_u:n_d, n_l:n_r]], (20, 20))

    return IOU



def new_calc_polygons(image, threshold=0.5, fig_size=(5,5), plt_fig=False):
    """Threshold a single-channel image and return one record per OpenCV contour.

    Parameters
    ----------
    image : np.ndarray
        Single-channel image accepted by ``cv2.threshold`` / ``cv2.findContours``
        (callers here pass a uint8 0/1 mask). It is not modified.
    threshold : float
        Pixels strictly above this value become foreground (255).
    fig_size : tuple or None
        Figure size used only when ``plt_fig`` is true.
    plt_fig : bool
        If true, draw the contours on a copy of ``image`` and show a figure.

    Returns
    -------
    dict[int, dict]
        Keys are 1-based contour indices. Each value has:
        ``"coords"`` – contour vertices as ``(x, y)`` = (col, row) int tuples;
        ``"bb"`` – ``cv2.boundingRect`` tuple ``(x, y, w, h)``;
        ``"area"`` – ``cv2.contourArea``;
        ``"pixels"`` – ``(row, col)`` tuples of every pixel of the filled polygon,
        in row-major order.

    NOTE(review): ``RETR_TREE`` also returns inner (hole) contours, so a hole becomes
    its own polygon, and an outer polygon's filled pixels include its holes.
    """
    _, thresh = cv2.threshold(image.copy(), threshold, 255, cv2.THRESH_BINARY)
    contours, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)

    if plt_fig:
        # Drawing is only needed for display, so do it lazily and on a copy.
        cont_im = cv2.drawContours(image.copy(), contours, -1, (0, 255, 0), 3)
        plt.figure(figsize=fig_size)
        plt.imshow(cont_im)
        plt.imshow(thresh)
        plt.show()

    polygons = {}
    for poly_index, contour in enumerate(contours, 1):
        object_area = cv2.contourArea(contour)
        x, y, w, h = cv2.boundingRect(contour)

        # Rasterise the polygon into a mask the size of its bounding box instead of
        # the whole image; the filled pixels are identical after shifting back by
        # (y, x), but the cost is O(bbox area) instead of O(image area) per contour.
        local_mask = np.zeros((h, w), dtype=np.uint8)
        shifted = contour - np.array([x, y], dtype=contour.dtype)
        cv2.fillPoly(local_mask, [shifted], 1)
        rows, cols = np.nonzero(local_mask)

        polygons[poly_index] = {
            "coords": [(int(pt[0][0]), int(pt[0][1])) for pt in contour],
            "bb": (x, y, w, h),
            "area": object_area,
            "pixels": list(zip(rows + y, cols + x)),
        }
    return polygons


def map_mask(mask, mapping_dict):
    """Build a new uint8 mask by relabelling selected values of ``mask``.

    Every pixel whose value equals a key of ``mapping_dict`` gets the
    corresponding value; all other pixels are 0. Keys are always compared against
    the input ``mask`` (never the output), so mappings do not chain. Works for
    masks of any dimensionality.

    Parameters
    ----------
    mask : array-like
        Label array.
    mapping_dict : dict
        ``{original_value: new_value}``; new values must fit in uint8.

    Returns
    -------
    np.ndarray
        uint8 array with the same shape as ``mask``.
    """
    mask = np.asarray(mask)
    new__mask = np.zeros(mask.shape, np.uint8)
    for original, new in mapping_dict.items():
        new__mask[mask == original] = new
    return new__mask
          

def modify_mask(mask, mapping_mask, mapping_dict):
    """Return a copy of ``mask`` with pixels overridden according to ``mapping_mask``.

    For each ``original -> new`` entry, every pixel where ``mapping_mask == original``
    is set to ``new`` in the copy. All other pixels keep their value and dtype.

    Parameters
    ----------
    mask : np.ndarray
        Base mask (not modified).
    mapping_mask : np.ndarray or None
        Same shape as ``mask``. ``None`` means "no overrides" and returns a plain copy.
    mapping_dict : dict
        ``{value_in_mapping_mask: value_to_write}``.

    Returns
    -------
    np.ndarray
        Modified copy of ``mask``.
    """
    new_mask = mask.copy()
    if mapping_mask is None:
        return new_mask
    mapping_mask = np.asarray(mapping_mask)
    for original, new in mapping_dict.items():
        new_mask[mapping_mask == original] = new
    return new_mask


def display_arrays(display_list, figure_size = (20,20), titles = None,
                   plt_fig = True, savefig_pth = None):
    """Show several arrays side by side in a single row of subplots.

    Parameters
    ----------
    display_list : sequence of array-like
        Arrays passed to ``plt.imshow``, one per panel.
    figure_size : tuple
        Matplotlib figure size.
    titles : sequence of str or None
        Optional per-panel titles (same length as ``display_list``).
    plt_fig : bool
        If false, the figure is closed instead of being left open for display.
    savefig_pth : str or path-like or None
        If given, the figure is saved to this path.

    Returns
    -------
    matplotlib.figure.Figure
        The created figure.

    A panel that fails to render is reported and skipped; the remaining panels are
    still drawn. Previously, a bare ``except`` hid the error and stopped the loop.
    """
    fig = plt.figure(figsize=figure_size)
    n_panels = len(display_list)
    for i, array in enumerate(display_list):
        try:
            plt.subplot(1, n_panels, i + 1)
            plt.imshow(array)
            plt.axis('off')
            # `is not None` also works for ndarray titles, where `!= None` is ambiguous.
            if titles is not None:
                plt.title(titles[i])
        except Exception as exc:  # best effort: one bad panel must not abort the rest
            print(f"display_arrays: could not render panel {i}: {exc!r}")

    if savefig_pth is not None:
        fig.savefig(savefig_pth)
    if not plt_fig:
        plt.close(fig)
    return fig

In [4]:
# Cached inputs for this post-processing run. The `with` block closes the .npz file
# handle; indexing an NpzFile reads each array fully into memory, so the arrays
# stay valid after the file is closed.
with np.load("new_sam_arrays.npz") as np_data:
    footprint = np_data["footprint"]
    tarp = np_data["tarp"]
    ground_tarp = np_data["ground_tarp"]
    img = np_data["img"]
    # presumably the reference output that final_tarp is compared against — TODO confirm
    result = np_data["result"]

# NOTE(review): torch.load unpickles the file, which can execute arbitrary code —
# only load trusted files (consider weights_only=True on torch versions that support it).
pred_logits = torch.load("sam_logits.pt")


In [5]:
def local_US_tarp_Apr_25_post(prediction, blue_roof_mask = None):
    """Rule-based cleanup of the tarp segmentation, applied per connected polygon.

    model classes
    1 - tarp & insulation on roof
    2 - tarp on the ground
    3 - blue roof
    4 - ignore (out of AOI)

    Steps:
    1. If ``blue_roof_mask`` is given, its pixels equal to 2 or 3 overwrite
       ``prediction`` with that value.
    2. Classes 1, 2 and 3 are merged into one binary mask, and its contours are
       extracted with ``new_calc_polygons``.
    3. For each polygon, the original classes of its filled pixels are counted:
       * more than 10% blue roof (3) or ground tarp (2) among tarp-like pixels
         (1+2+3) -> the whole polygon is set to 5;
       * otherwise, if roof tarp (1) is under 30% of (roof tarp + background 0)
         -> the whole polygon is set to 0.
       NOTE(review): class 5 is not listed above; its meaning is defined downstream.
    4. The ``blue_roof_mask`` overrides are re-applied so they win over step 3.

    Parameters
    ----------
    prediction : np.ndarray
        2-D class-id mask (not modified).
    blue_roof_mask : np.ndarray or None
        Same shape as ``prediction``; only values 2 and 3 are used.

    Returns
    -------
    np.ndarray
        Post-processed copy of the (possibly overridden) prediction.
    """
    if blue_roof_mask is not None:
        prediction = modify_mask(prediction, blue_roof_mask, {2:2, 3:3})

    combined_mask = map_mask(prediction, {1:1, 2:1, 3:1})
    post_pred = prediction.copy()

    polygons = new_calc_polygons(combined_mask,
                                 0.5,
                                 fig_size=None,
                                 plt_fig=False)

    for poly_dict in polygons.values():
        # "pixels" holds (row, col) pairs; split them into fancy-index arrays.
        rows, cols = np.asarray(poly_dict['pixels'], dtype=np.intp).reshape(-1, 2).T

        # Decisions read the un-modified `prediction`, never `post_pred`.
        poly_pred_class = prediction[rows, cols]

        roof_tarp_pixel_count = np.count_nonzero(poly_pred_class == 1)
        ground_tarp_pixels_count = np.count_nonzero(poly_pred_class == 2)
        blue_roof_pixels_count = np.count_nonzero(poly_pred_class == 3)
        background_pixels_count = np.count_nonzero(poly_pred_class == 0)
        tarp_like_count = roof_tarp_pixel_count + ground_tarp_pixels_count + blue_roof_pixels_count

        if (blue_roof_pixels_count > 0.1 * tarp_like_count
                or ground_tarp_pixels_count > 0.1 * tarp_like_count):
            post_pred[rows, cols] = 5
        elif roof_tarp_pixel_count < 0.3 * (roof_tarp_pixel_count + background_pixels_count):
            post_pred[rows, cols] = 0

    # Re-apply the overrides once, after all polygons. Doing it inside the loop
    # gave the same final mask, but cost a full-image pass per polygon. It also
    # crashed when blue_roof_mask was None.
    if blue_roof_mask is not None:
        post_pred = modify_mask(post_pred, blue_roof_mask, {2:2, 3:3})

    return post_pred

In [6]:
def SAM_predict(image, input_points, input_labels):
    """Run the module-level SAM ``predictor`` on one image with point prompts.

    The image is embedded first (``set_image``), then SAM is asked for several
    candidate masks (``multimask_output=True``).

    Returns
    -------
    tuple
        ``(masks, scores, logits)`` exactly as produced by ``predictor.predict``.
    """
    predictor.set_image(image)
    prompt = dict(point_coords=input_points,
                  point_labels=input_labels,
                  multimask_output=True)
    candidate_masks, candidate_scores, low_res_logits = predictor.predict(**prompt)
    return candidate_masks, candidate_scores, low_res_logits



def plot_connected_components(binary_mask):
    """Label the connected components of a 2-D binary mask and plot them beside the mask.

    Prints the sorted unique label values (0 is the background label of
    ``skimage.measure.label``). Then shows a two-panel figure: the input mask in
    grayscale, and the label image with a categorical colour map.

    Parameters
    ----------
    binary_mask : np.ndarray
        2-D mask; zero pixels are background.
    """
    # connectivity=2 is full connectivity for 2-D input, i.e. 8-connectivity:
    # diagonally touching pixels belong to the same component.
    labeled_mask = measure.label(binary_mask, connectivity=2)
    print(np.unique(labeled_mask))

    fig, (ax_mask, ax_labels) = plt.subplots(1, 2, figsize=(20, 20))

    ax_mask.imshow(binary_mask, cmap='gray')
    ax_mask.set_title('Binary Mask')
    ax_mask.axis('off')

    # A spectral colour map makes neighbouring label ids visually distinct.
    ax_labels.imshow(labeled_mask, cmap='nipy_spectral')
    ax_labels.set_title('Labeled Components')
    ax_labels.axis('off')

    plt.show()



def _show_mask_overlay(mask, ax, color=(30/255, 144/255, 255/255, 0.6)):
    """Draw a boolean/0-1 mask on ``ax`` as a semi-transparent RGBA overlay."""
    h, w = mask.shape[-2:]
    overlay = mask.reshape(h, w, 1) * np.asarray(color, dtype=float).reshape(1, 1, -1)
    ax.imshow(overlay)


def _show_points_overlay(coords, labels, ax, marker_size=375):
    """Scatter prompt points on ``ax`` as (x, y): label 1 as green stars, label 0 as red stars."""
    coords = np.asarray(coords)
    labels = np.asarray(labels)
    for label_value, colour in ((1, 'green'), (0, 'red')):
        pts = coords[labels == label_value]
        ax.scatter(pts[:, 0], pts[:, 1], color=colour, marker='*', s=marker_size,
                   edgecolor='white', linewidth=1.25)


def display_SAM_prediction(image, masks, scores, input_points, input_labels):
    """Show each SAM candidate mask over ``image`` in its own figure, with the prompt points.

    The original called ``show_mask``/``show_points``, which are defined nowhere in
    this notebook and raised NameError. The private overlay helpers above replace them.

    Parameters
    ----------
    image : np.ndarray
        Image shown underneath each mask.
    masks, scores : sequences
        Candidate masks and their scores (e.g. from ``SAM_predict``).
    input_points : array-like, shape (N, 2)
        Prompt points as (x, y).
    input_labels : array-like, shape (N,)
        1 for positive and 0 for negative prompts.
    """
    for i, (mask, score) in enumerate(zip(masks, scores)):
        plt.figure(figsize=(10, 10))
        plt.imshow(image)
        ax = plt.gca()
        _show_mask_overlay(mask, ax)
        _show_points_overlay(input_points, input_labels, ax)
        plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
        plt.axis('off')
        plt.show()


# Side length used to clamp crop windows; assumes inference tiles are 2048 px — TODO confirm
inference_img_size = 2048


def upper(x, buffer):
    """Return ``x + buffer`` clamped to ``inference_img_size`` (used as an exclusive slice end)."""
    return min(x + buffer, inference_img_size)


def lower(x, buffer):
    """Return ``x - buffer`` clamped at 0 (used as a slice start)."""
    return max(0, x - buffer)


# def calc_instance_intersection(labeled_mask, unlabeled_mask, mapping = None):
    
#     if mapping != None:
#         unlabeled_mask = map_mask(unlabeled_mask, mapping)

#     remaining_labels_mask = unlabeled_mask*labeled_mask
#     return np.unique(remaining_labels_mask.flatten())


def check_if_high_conf_poly(labeled_mask, componenet_index, prediction_logits,
                            conf_threshold=0.38, class_index=1, min_fraction=0.5):
    """Decide whether a labelled component is predicted with high confidence.

    A component is high-confidence when more than ``min_fraction`` of its pixels
    have a ``class_index`` score strictly above ``conf_threshold``.

    Parameters
    ----------
    labeled_mask : np.ndarray
        2-D integer label image.
    componenet_index : int
        Label of the component to test.
    prediction_logits : torch.Tensor
        Per-class scores indexed as ``[class, row, col]`` (the caller passes
        softmax probabilities).
    conf_threshold, class_index, min_fraction
        Defaults reproduce the original hard-coded 0.38 / class 1 / 0.5.

    Returns
    -------
    tuple
        ``(poly_high_conf, pxl_indx, pxl_confidence)``: the bool decision, the
        ``np.where`` index tuple of the component's pixels, and their scores as numpy.
    """
    pxl_indx = np.where(labeled_mask == componenet_index)
    pxl_confidence = prediction_logits[class_index, pxl_indx[0], pxl_indx[1]].cpu().detach().numpy()

    n_confident = np.count_nonzero(pxl_confidence > conf_threshold)
    # An empty component gives 0 > 0 -> False.
    poly_high_conf = bool(n_confident > min_fraction * pxl_confidence.size)

    return poly_high_conf, pxl_indx, pxl_confidence
    

def find_poly_corr_with_sam(pxl_indx, img, tarp_mask, format_img = False, buffer=100):
    """Prompt SAM with a component's pixels and measure its agreement with ``tarp_mask``.

    Steps:
    1. A crop of ``img`` is taken around the component, padded by ``buffer`` px
       (clamped by ``lower``/``upper``).
    2. Every component pixel is sent as a positive point prompt.
    3. SAM's highest-scoring mask is kept.
    4. The IOU between that mask and ``tarp_mask`` is computed inside the union of
       the component's and the SAM mask's bounding boxes.

    Parameters
    ----------
    pxl_indx : tuple of np.ndarray
        ``(rows, cols)`` of the component's pixels in full-image coordinates (non-empty).
    img : np.ndarray
        Full image to crop.
    tarp_mask : np.ndarray
        Full-size mask that the SAM mask is compared against.
    format_img : bool
        If true, scale the crop by 255, clip it to [0, 255] and cast to uint8
        (assumes a [0, 1] float image — TODO confirm).
    buffer : int
        Crop padding in pixels (default reproduces the original 100).

    Returns
    -------
    tuple
        ``(IOU, img_patch, masks, scores, input_points, input_labels, bb, max_index,
        pred_d, pred_l, pred_r, pred_u, d, l, r, u)``. ``pred_*`` are the component's
        inclusive extents; ``u, d, l, r`` are the crop's slice bounds; ``bb`` is
        ``(x, y, w, h)`` relative to the crop.
    """
    rows, cols = pxl_indx[0], pxl_indx[1]
    pred_u, pred_d = rows.min(), rows.max()
    pred_l, pred_r = cols.min(), cols.max()

    u, d = lower(pred_u, buffer), upper(pred_d, buffer)
    l, r = lower(pred_l, buffer), upper(pred_r, buffer)

    # SAM expects (x, y) = (col, row) prompts relative to the crop origin.
    input_points = np.stack([cols - l, rows - u], axis=1)
    input_labels = np.array([1] * len(input_points))

    img_patch = img[u:d, l:r]
    if format_img:
        img_patch = np.clip(img_patch * 255, 0, 255).astype(np.uint8)

    masks, scores, logits = SAM_predict(img_patch, input_points, input_labels)
    max_index = np.argmax(scores)
    best_mask = masks[max_index]

    sam_rows, sam_cols = np.nonzero(best_mask)
    if sam_rows.size:
        sam_u, sam_d = sam_rows.min() + u, sam_rows.max() + u
        sam_l, sam_r = sam_cols.min() + l, sam_cols.max() + l
    else:
        # SAM returned an empty mask: fall back to the component's own extent
        # (min() of an empty array used to raise).
        sam_u, sam_d, sam_l, sam_r = pred_u, pred_d, pred_l, pred_r

    iou_u = min(pred_u, sam_u)
    iou_d = max(pred_d, sam_d)
    iou_l = min(pred_l, sam_l)
    iou_r = max(pred_r, sam_r)

    # The extents are inclusive, so +1 is needed for the last row/column to be
    # inside the window that calc_IOU slices.
    bb = iou_l - l, iou_u - u, iou_r - iou_l + 1, iou_d - iou_u + 1

    IOU = calc_IOU(bb, img_patch, best_mask.astype(np.int32),
                   tarp_mask[u:d, l:r], False)

    return (
        IOU, img_patch, masks, scores, input_points, input_labels,
        bb, max_index, pred_d, pred_l, pred_r, pred_u, d, l, r, u
    )


In [7]:
def first_post_process(img, original_tarp_mask, ground_tarp_mask, footprint_mask, pred_logits, display_results = False):
    """First post-processing stage: rule-based cleanup and a confidence split of roof tarps.

    Steps:
    1. ``local_US_tarp_Apr_25_post`` cleans the raw tarp prediction, using
       ``ground_tarp_mask`` as the override mask.
    2. Each 8-connected class-1 component that has skeleton pixels is labelled 1
       (high confidence per ``check_if_high_conf_poly``) or 6 (low confidence).
    3. A box padded by 100 px around each component's skeleton is marked. The
       bounding box of each connected region of marked boxes becomes an image
       patch for the second stage.

    Parameters
    ----------
    img : np.ndarray
        Full image; patches are scaled by 255 and cast to uint8 (assumes a [0, 1]
        float image — TODO confirm).
    original_tarp_mask, ground_tarp_mask : np.ndarray
        Raw prediction and override mask, passed to ``local_US_tarp_Apr_25_post``.
    footprint_mask : np.ndarray
        Footprint class mask; classes 1, 3, 4 and 5 are encoded as +10.
    pred_logits : torch.Tensor
        Raw per-class logits ``[C, H, W]``; softmax over dim 0 gives the scores.
    display_results : bool
        Unused; kept for call compatibility.

    Returns
    -------
    return_mask : np.ndarray (uint8)
        The encoding is:
        * 1 or 6 on roof-tarp components;
        * classes 2-5 copied from the cleaned prediction;
        * plus 10 on the mapped footprint pixels.
    img_patch_dict : dict
        ``(u, d, l, r) -> uint8 crop`` where ``img[u:d, l:r]`` is the patch
        (exclusive slice ends).
    """
    prediction_mask = local_US_tarp_Apr_25_post(original_tarp_mask, ground_tarp_mask)

    # Softmax over the class axis; despite the name these are probabilities.
    prediction_logits = torch.softmax(pred_logits, dim=0)

    tarp_mask = map_mask(prediction_mask, {1:1})
    tarp_labeled_mask = measure.label(tarp_mask, connectivity=2)
    thin_tarp_mask = np.uint8(skeletonize(tarp_mask))

    # The skeleton must be binary. (It may be all zeros when there is no tarp; the
    # old `== 2 unique values` check crashed on such tiles.)
    assert set(np.unique(thin_tarp_mask).tolist()) <= {0, 1}

    thin_tarp_indexes = np.where(thin_tarp_mask == 1)
    # Every skeleton pixel must lie inside a labelled component.
    assert 0 not in np.unique(tarp_labeled_mask[thin_tarp_indexes[0], thin_tarp_indexes[1]])

    # Component labels restricted to skeleton pixels.
    labeled_mask = thin_tarp_mask * tarp_labeled_mask
    tarp_component_indexes = [c for c in np.unique(labeled_mask).tolist() if c != 0]

    img_patching_arr = np.zeros(prediction_mask.shape, np.uint8)
    buffer = 100
    poly_label_mapping = {}

    for componenet_index in tarp_component_indexes:
        poly_high_conf, pxl_indx, _ = check_if_high_conf_poly(labeled_mask,
                                                              componenet_index,
                                                              prediction_logits)
        poly_label_mapping[componenet_index] = 1 if poly_high_conf else 6

        u, d = lower(pxl_indx[0].min(), buffer), upper(pxl_indx[0].max(), buffer)
        l, r = lower(pxl_indx[1].min(), buffer), upper(pxl_indx[1].max(), buffer)
        img_patching_arr[u:d, l:r] = 1

    patching_labeled_mask = measure.label(img_patching_arr, connectivity=2)
    img_patch_dict = {}

    for patch_component_indx in np.unique(patching_labeled_mask).tolist():
        if patch_component_indx == 0:
            continue
        patch_rows, patch_cols = np.nonzero(patching_labeled_mask == patch_component_indx)
        # +1: max() is inclusive, but the key is used as exclusive slice ends. The
        # last row and column of every patch used to be lost.
        patch_u, patch_d = patch_rows.min(), patch_rows.max() + 1
        patch_l, patch_r = patch_cols.min(), patch_cols.max() + 1

        img_patch = img[patch_u:patch_d, patch_l:patch_r]
        img_patch_dict[(patch_u, patch_d, patch_l, patch_r)] = np.clip(
            img_patch * 255, 0, 255).astype(np.uint8)

    footprint_flag = map_mask(footprint_mask, {1:10, 3:10, 4:10, 5:10})  # previous - {1:1,2:1,3:1,4:1,5:1}

    return_mask = map_mask(tarp_labeled_mask, poly_label_mapping)
    remaining_classes_mask = map_mask(prediction_mask, {2:2, 3:3, 4:4, 5:5})
    return_mask = return_mask + remaining_classes_mask + footprint_flag

    return return_mask, img_patch_dict

In [8]:
def second_post_process(mask, img_patch_dict, iou_threshold=0.7):
    """Second stage: confirm or reject low-confidence (class 6) tarp components with SAM.

    Steps:
    1. The image is rebuilt from the stage-one patches; everything outside the
       patches is black.
    2. Pixels with value >= 10 are treated as footprint, and 10 is subtracted to
       recover the class mask.
    3. Each class-6 component is prompted to SAM (``find_poly_corr_with_sam``). It
       becomes class 1 if the IOU is >= ``iou_threshold``, otherwise 0.

    Components are visited footprint by footprint (label 0 covers the area outside
    all footprints). Within a footprint, the components with the most skeleton
    pixels come first, and each component is evaluated only once.

    Parameters
    ----------
    mask : np.ndarray
        Encoded mask from ``first_post_process``.
    img_patch_dict : dict
        ``(u, d, l, r) -> crop`` with ``img[u:d, l:r]`` slice bounds; crops are
        assumed to be HxWx3 uint8 — TODO confirm.
    iou_threshold : float
        Acceptance threshold (default reproduces the original 0.7).

    Returns
    -------
    np.ndarray
        Class mask with classes 1-5.
    """
    img = np.zeros((mask.shape[0], mask.shape[1], 3), np.uint8)
    for (patch_u, patch_d, patch_l, patch_r), patch_img in img_patch_dict.items():
        img[patch_u:patch_d, patch_l:patch_r] = patch_img

    footprint_mask = (mask >= 10).astype(np.uint8)
    prediction_mask = mask - 10 * footprint_mask

    tarp_mask = map_mask(prediction_mask, {6:1})
    tarp_labeled_mask = measure.label(tarp_mask, connectivity=2)
    thin_tarp_mask = np.uint8(skeletonize(tarp_mask))

    # The skeleton must be binary (it may be all zeros when no class-6 pixels exist).
    assert set(np.unique(thin_tarp_mask).tolist()) <= {0, 1}

    thin_tarp_indexes = np.where(thin_tarp_mask == 1)
    assert 0 not in np.unique(tarp_labeled_mask[thin_tarp_indexes[0], thin_tarp_indexes[1]])

    labeled_mask = thin_tarp_mask * tarp_labeled_mask
    num_components = len(np.unique(labeled_mask))

    labeled_fp_mask = measure.label(footprint_mask, connectivity=2)
    fp_num_componenets = len(np.unique(labeled_fp_mask))

    if num_components == 1 or fp_num_componenets == 1:
        # NOTE(review): when there are no footprints but class-6 components exist,
        # this returns them still labelled 6, whereas the main path maps 6 -> 0/1.
        # Confirm whether that is intended.
        return prediction_mask

    poly_label_mapping = {}

    # measure.label produces consecutive labels, so this covers background (0)
    # and every footprint.
    for fp_index in range(fp_num_componenets):
        fp_rows, fp_cols = np.nonzero(labeled_fp_mask == fp_index)
        labeled_mask_in_fp = labeled_mask[fp_rows, fp_cols]

        poly_ids, poly_counts = np.unique(labeled_mask_in_fp, return_counts=True)
        skeleton_counts = {pid: cnt for pid, cnt in zip(poly_ids.tolist(), poly_counts.tolist())
                           if pid != 0}

        # Largest component first; the stable sort keeps ties in ascending id order.
        for componenet_index in sorted(skeleton_counts, key=skeleton_counts.get, reverse=True):
            if componenet_index in poly_label_mapping:
                continue  # already decided inside another footprint

            pxl_indx = np.where(labeled_mask == componenet_index)
            IOU = find_poly_corr_with_sam(pxl_indx, img, tarp_mask, False)[0]
            poly_label_mapping[componenet_index] = 1 if IOU >= iou_threshold else 0

    final_mask = map_mask(tarp_labeled_mask, poly_label_mapping)
    remaining_classes_mask = map_mask(prediction_mask, {1:1, 2:2, 3:3, 4:4, 5:5})
    final_mask = final_mask + remaining_classes_mask

    return final_mask

In [9]:

# Stage 1: rule-based cleanup plus confidence split.
# Stage 2: SAM verification of the low-confidence tarp polygons.
return_mask, img_patch_dict = first_post_process(
    img=img,
    original_tarp_mask=tarp,
    ground_tarp_mask=ground_tarp,
    footprint_mask=footprint,
    pred_logits=pred_logits,
    display_results=False,
)
final_tarp = second_post_process(mask=return_mask, img_patch_dict=img_patch_dict)


In [None]:
# new_result = map_mask(result, {1:1,2:2,3:3,4:4,5:5})

# print(np.unique(final_tarp.flatten()))
# print(sum(final_tarp.flatten()))

# print(np.unique(new_result.flatten()))
# print(sum(new_result.flatten()))

# np.array_equal(new_result, final_tarp)

[0 1 2 5]
14422
[0 1 2 5]
14422


True