In [None]:
import requests
import json
from io import BytesIO
from PIL import Image
import numpy as np
import cv2

In [None]:
def generate_equally_spaced_colors(k):
    """Return ``k`` fully saturated colours with hues evenly spaced on the colour wheel.

    Args:
        k: number of colours to produce.

    Returns:
        List of ``k`` (r, g, b) tuples of ints in 0-255. Each tuple comes from
        ``hsv_to_rgb(hue, 1, 1)``, scaled by 255 and truncated with ``int``.
    """
    hue_step = 360 / k  # degrees between neighbouring hues
    return [
        tuple(int(channel * 255) for channel in hsv_to_rgb(idx * hue_step, 1, 1))
        for idx in range(k)
    ]


def hsv_to_rgb(h, s, v):
    """Convert an HSV colour to RGB floats.

    Args:
        h: hue in degrees; values in [0, 360) select the usual 60-degree sectors.
            Anything else (negative, >= 300, NaN) uses the last sector's formula.
        s: saturation.
        v: value.

    Returns:
        (r, g, b) tuple; each channel is the sector component plus ``v - v*s``.
    """
    chroma = v * s
    secondary = chroma * (1 - abs((h / 60) % 2 - 1))
    offset = v - chroma

    # (exclusive upper hue bound, channel triple) for the sectors starting at 0.
    sectors = (
        (60, (chroma, secondary, 0)),
        (120, (secondary, chroma, 0)),
        (180, (0, chroma, secondary)),
        (240, (0, secondary, chroma)),
        (300, (secondary, 0, chroma)),
    )
    components = (chroma, 0, secondary)  # fallback: h >= 300, h < 0, or non-comparable
    if h >= 0:
        for upper, triple in sectors:
            if h < upper:
                components = triple
                break

    return tuple(channel + offset for channel in components)

In [None]:
import numpy as np
from PIL import Image, ImageDraw

def generate_binary_mask(annotations, image_size, target_label="Wire"):
    """Rasterise the polygon annotations of one label into a binary mask.

    Args:
        annotations: iterable of annotation dicts. Each is read for ``"label"`` and
            ``"segmentation"``, a list of segment dicts with ``"extPoints"`` (the
            exterior ring, a sequence of (x, y) pairs) and ``"intPoints"`` (a list of
            interior rings / holes, each a sequence of (x, y) pairs).
        image_size: (width, height) of the mask, as accepted by ``Image.new``.
        target_label: only annotations whose ``"label"`` equals this are drawn.
            Defaults to ``"Wire"``, the value that used to be hard-coded.

    Returns:
        numpy array built from an 8-bit ('L') image. It is 1 inside exterior rings
        and 0 elsewhere, including inside holes.
    """
    canvas = Image.new('L', image_size, 0)
    draw = ImageDraw.Draw(canvas)
    for ann in annotations:
        if ann["label"] != target_label:
            continue
        for seg in ann["segmentation"]:
            exterior = [tuple(pt) for pt in seg["extPoints"]]
            # PIL's polygon drawing raises on fewer than two vertices. The old guard
            # compared the *number of holes* with 2 and only skipped when both
            # checks failed, so degenerate exteriors still crashed here.
            if len(exterior) < 2:
                continue
            draw.polygon(exterior, outline=1, fill=1)

            # Punch out holes; skip degenerate rings for the same reason as above.
            for ring in seg["intPoints"]:
                hole = [tuple(pt) for pt in ring]
                if len(hole) >= 2:
                    draw.polygon(hole, outline=0, fill=0)

    return np.array(canvas)

In [None]:
# Annotation export. JSON is UTF-8 by spec, so pass the encoding explicitly and do
# not depend on the platform's default locale encoding (e.g. cp1252 on Windows).
with open("wires.json", encoding="utf-8") as f:
    data = json.load(f)

In [None]:
# url_set = set()
# for image in data["images"]:
#     url_set.add((image["url"], image["id"]))

# image_set = set()
# imageid_set = set()

# for url, id in url_set:
#     req = requests.get("http://" + url)
#     img = Image.open(BytesIO(req.content)).convert("RGB").resize((448, 448))
#     byte_img = img.tobytes()
#     if byte_img not in image_set:
#         image_set.add(byte_img)
#         imageid_set.add(id)
        
# disclude_id_set = set(
#     [
#         "64696be14d61f800078e5be9",
#         "648b7c226f177e0007432879",
#         "646c4baeeee7ce000765e791",
#         "646c5784eee7ce00076678f9",
#         "646c6335eee7ce0007670362",
#         "646ce574eee7ce00076f6273",
#         "646ce65665bbde0007f09705",
#         "646ce5dce326cf0007b927f2",
#         "646d9cae09affa0007c9a1f0",
#         "646da0bf09affa0007c9cd35",
#         "646dbe8b65bbde0007fe0640",
#         "646df9dd09affa0007cea11e",
#         "646e15bfe326cf0007ca4d5f",
#         "646e185b4d61f80007ca4af1",
#         "646e185ae326cf0007ca7d3f",
#         "646e276109affa0007d1dab3",
#         "646e325809affa0007d2c5ec",
#         "646e361909affa0007d31a96",
#         "646e372b4d61f80007cc9555",
#         "646e37de09affa0007d341ca",
#         "646e37d609affa0007d340e0",
#         "646e3ea109affa0007d3cb65",
#         "646ebed7e326cf0007d4e660",
#         "646ed5ac4d61f80007d69cee",
#         "646f6a304d61f80007def5ca",
#         "647000e2cdccc800070edc63",
#         "647023199999350007945f9d",
#         "647058565ad8d50007eb67e9",
#         "64715fa42687e40007e9fa99",
#         "64716edb2687e40007eaa357",
#         "647179542687e40007eb1590",
#         "647185f2fa532900077861da",
#         "6471ae9a2687e40007eddabc",
#         "6471b3822687e40007ee10b3",
#         "6471c9892687e40007ef465c",
#         "6471e04ffa532900077c4ffc",
#         "6474cead50ebc70007df89d7",
#         "64755af42687e40007187283",
#         "64687349eee7ce00073655c8",
# "646882634d61f8000783b0f1",
# "646c3d0feee7ce00076544da",
# "646c4c97eee7ce000765f48f",
# "646c72a565bbde0007e89b6c",
# "646c7346eee7ce000767cca7",
# "646cc34deee7ce00076c53f3",
# "646cc69ce326cf0007b6b0de",
# "646cdb89eee7ce00076e65e5",
# "646ce0e565bbde0007f02c55",
# "646ce56eeee7ce00076f6223",
# "646d991a65bbde0007fc2c0d",
# "646da40465bbde0007fcaebc",
# "646da6f509affa0007ca141a",
# "646dc01509affa0007cb4e9c",
# "6476d3f97d099e000720e9c9",
# "646dcd5865bbde0007fee2d2",
# "646de2af09affa0007cd4764",
# "646de35165bbde0007002d58",
# "646df3c365bbde0007012a84",
# "646e058f09affa0007cf4fee",
# "646e11db09affa0007d011a3",
# "646e188e09affa0007d09cca",
# "646ec2bb4d61f80007d5cefb",
# "646ecb05e326cf0007d56e94",
# "646ed5704d61f80007d69904",
# "646f5cad65bbde00071876a7",
# "646f713c65bbde00071a3273",
# "647005228f02f700070dcc1c",
# "64700117cdccc800070edfd4",
# "647027d011dca20007a66a7a",
# "64700133cdccc800070ee2cf",
# "6470324c11dca20007a6e55a",
# "6470592111dca20007a8a3e4",
# "6470a16011dca20007ac705b",
# "6470b42ed49b3a0007515686",
# "6470da9cd49b3a00075373c7",
# "64716676fa532900077701a1",
# "6471744e2687e40007eae331",
# "64756bb82687e400071917ef",
# "646c64fceee7ce000767182a",
# "646c654c4d61f80007b27c02",
# "646de34d09affa0007cd5079",
# "646defed4d61f80007c7e165",
# "646e2e8f09affa0007d2760b",
# "646eb85f65bbde00070ee93c",
# "646ed60f65bbde000710303e",
# "64704bbd9999350007962c15",
#     ]
# )

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Count candidate prompt images: de-duplicated images (imageid_set) that are not on
# the manual exclusion list (disclude_id_set).
prompt_imgs = []
prompt_masks = []
# Both id sets are built only by the commented-out cell above. Fall back instead of
# raising NameError, so this cell survives "Restart & Run All".
_kept_ids = globals().get("imageid_set")
if _kept_ids is None:
    print("imageid_set is not defined (dedup cell not run); counting all images")
    _kept_ids = {img["id"] for img in data["images"]}
_excluded_ids = globals().get("disclude_id_set", set())
i = 0
for image in data["images"]:
    if image["id"] in _kept_ids and image["id"] not in _excluded_ids:
        # req = requests.get("http://" + image["url"])
        # img = Image.open(BytesIO(req.content)).convert("RGB")
        # prompt_imgs.append(img)
        # # prompt_mask = construct_mask(img, image["tags"])
        # prompt_mask = generate_binary_mask(image["tags"],img.size)
        # prompt_masks.append(prompt_mask)
        # print(image["id"])
        # out = overlay_segmentation(img,prompt_mask)
        # plt.imshow(out)
        # plt.axis("off")
        # plt.show()
        i += 1
print(i)
        

In [None]:
import numpy as np
import cv2

def overlay_segmentation(image, mask):
    """Blend a colour-coded instance mask onto an image.

    Each non-zero instance id in ``mask`` is filled with its own hue on a copy of the
    image, and the copy is blended 50/50 with the original.

    Args:
        image: PIL image (anything ``np.array`` accepts) to draw on. Presumably
            3-channel uint8 so the colour tuples apply; TODO confirm for other modes.
        mask: 2-D array of instance ids with the image's height/width; 0 is background.

    Returns:
        PIL image containing the blended overlay.
    """
    image_array = np.array(image)
    overlay = image_array.copy()

    instance_ids = np.unique(mask)
    # One colour slot per distinct value (background included), so masks with
    # contiguous ids 0..k keep exactly the colours they had before.
    colors = generate_equally_spaced_colors(len(instance_ids))

    # Look colours up by the id's rank rather than the id itself. Ids need not be
    # contiguous (e.g. {0, 3, 7}), and colors[instance_id] would raise IndexError.
    for rank, instance_id in enumerate(instance_ids):
        if instance_id == 0:
            continue

        instance_mask = np.where(mask == instance_id, 255, 0).astype(np.uint8)
        contours, _ = cv2.findContours(instance_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        cv2.drawContours(overlay, contours, -1, colors[rank], thickness=cv2.FILLED)

    blended_image = cv2.addWeighted(overlay, 0.5, image_array, 0.5, 0)
    return Image.fromarray(blended_image)


In [None]:
import matplotlib.pyplot as plt

In [None]:
import seggpt_inference

In [None]:
import torch

In [None]:
from models_seggpt import LearnablePrompt

prompt = LearnablePrompt()

# map_location="cpu": the checkpoint may hold GPU tensors, so deserialise onto CPU so
# this cell also runs on CPU-only hosts. load_state_dict copies the values into the
# module's existing parameters, so their final device is unaffected.
# NOTE(review): torch.load unpickles arbitrary objects -- only load trusted checkpoints.
state_dict = torch.load("bestwires_learned_promptv4.pt", map_location="cpu")
prompt.load_state_dict(state_dict)
torch.cuda.empty_cache()

In [None]:
import numpy as np
from PIL import Image

def split_image_into_grids(image, num_horizontal_cells, num_vertical_cells):
    """Cut an image into a row-major grid of tiles.

    Each tile is ``width // num_horizontal_cells`` by ``height // num_vertical_cells``.
    The last column/row absorbs the division remainder, so the tiles cover the whole
    image.

    Args:
        image: object with ``.size`` == (width, height) and ``.crop(box)`` (a PIL image).
        num_horizontal_cells: number of tile columns.
        num_vertical_cells: number of tile rows.

    Returns:
        (tiles, origins): the cropped tiles, and the (left, top) pixel offset of each
        one, both in row-major order.
    """
    width, height = image.size
    step_x = width // num_horizontal_cells
    step_y = height // num_vertical_cells

    def _spans(step, count, limit):
        # (start, end) for every cell along one axis; the final cell ends at `limit`.
        return [
            (idx * step, limit if idx == count - 1 else (idx + 1) * step)
            for idx in range(count)
        ]

    column_spans = _spans(step_x, num_horizontal_cells, width)
    row_spans = _spans(step_y, num_vertical_cells, height)

    tiles, origins = [], []
    for top, bottom in row_spans:
        for left, right in column_spans:
            tiles.append(image.crop((left, top, right, bottom)))
            origins.append((left, top))
    return tiles, origins


def stitch_masks(image_size, masks, grid_positions):
    """Paste per-tile masks back into a single full-size mask.

    Args:
        image_size: (width, height) of the full image.
        masks: per-tile 2-D arrays, in the same order as ``grid_positions``.
        grid_positions: (left, top) pixel offset of each tile.

    Returns:
        uint8 array of shape (height, width). Pixels not covered by any tile are 0,
        and later tiles overwrite earlier ones where they overlap.
    """
    full_width, full_height = image_size
    canvas = np.zeros((full_height, full_width), dtype=np.uint8)
    for tile, origin in zip(masks, grid_positions):
        left, top = origin
        tile_rows, tile_cols = tile.shape[0], tile.shape[1]
        canvas[top:top + tile_rows, left:left + tile_cols] = tile
    return canvas




In [None]:
import numpy as np
from skimage.measure import label, regionprops

def separate_masks(binary_mask, area_threshold):
    """Split a binary mask into one mask per connected component.

    Args:
        binary_mask: array whose non-zero pixels are foreground, as understood by
            ``skimage.measure.label``.
        area_threshold: components with fewer pixels than this are dropped.

    Returns:
        List of uint8 arrays (1 on the component, 0 elsewhere), in the order
        ``regionprops`` reports the components.
    """
    components = label(binary_mask)
    return [
        (components == props.label).astype(np.uint8)
        for props in regionprops(components)
        if props.area >= area_threshold
    ]


In [None]:
def combine_masks(masks):
    """Merge binary instance masks into a single label image.

    Args:
        masks: non-empty sequence of same-shaped arrays. Pixels equal to 1 (or True)
            belong to that instance.

    Returns:
        Array of the masks' shape. Pixels of the i-th mask (1-based) hold i, and the
        background holds 0. Later masks overwrite earlier ones where they overlap.
        The dtype is ``masks[0].dtype`` widened just enough to hold ``len(masks)``.
        This keeps uint8 for up to 255 uint8 masks, and stops bool masks collapsing
        every id to True and larger counts wrapping around.

    Raises:
        ValueError: if ``masks`` is empty.
    """
    if len(masks) == 0:
        raise ValueError("combine_masks() needs at least one mask")

    first = np.asarray(masks[0])
    label_dtype = np.promote_types(first.dtype, np.min_scalar_type(len(masks)))
    combined_mask = np.zeros(first.shape, dtype=label_dtype)

    for instance_id, mask in enumerate(masks, start=1):
        combined_mask[np.asarray(mask) == 1] = instance_id

    return combined_mask

In [None]:
import numpy as np
import cv2

def overlay_segmentation(image, mask):
    """Blend a colour-coded instance mask onto an image.

    Each non-zero instance id in ``mask`` is filled with its own hue on a copy of the
    image, and the copy is blended 50/50 with the original.

    Args:
        image: PIL image (anything ``np.array`` accepts) to draw on. Presumably
            3-channel uint8 so the colour tuples apply; TODO confirm for other modes.
        mask: 2-D array of instance ids with the image's height/width; 0 is background.

    Returns:
        PIL image containing the blended overlay.
    """
    image_array = np.array(image)
    overlay = image_array.copy()

    instance_ids = np.unique(mask)
    # One colour slot per distinct value (background included), so masks with
    # contiguous ids 0..k keep exactly the colours they had before.
    colors = generate_equally_spaced_colors(len(instance_ids))

    # Look colours up by the id's rank rather than the id itself. Ids need not be
    # contiguous (e.g. {0, 3, 7}), and colors[instance_id] would raise IndexError.
    for rank, instance_id in enumerate(instance_ids):
        if instance_id == 0:
            continue

        instance_mask = np.where(mask == instance_id, 255, 0).astype(np.uint8)
        contours, _ = cv2.findContours(instance_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        cv2.drawContours(overlay, contours, -1, colors[rank], thickness=cv2.FILLED)

    blended_image = cv2.addWeighted(overlay, 0.5, image_array, 0.5, 0)
    return Image.fromarray(blended_image)


In [None]:
# ##Multilivel!!!
# ##Now lets auto-annotate the rest of the images using the single prompt_img/mask
# for image in data['images']:
#     if image['id'] == "64baf8d5832640000758762c": ##Don't predict your prompt_image
#         continue 
#     if image["url"].endswith("comundefined"): ##Matroid Backend Saving Image error
#         continue
#     req = requests.get("http://" + image["url"])
#     test_image = Image.open(BytesIO(req.content)).convert("RGB")
#     image_size = test_image.size  
#     ##Add Whole Image to Grids to be predicted on   
#     test_grids = [test_image]
    
#     num_horizontal_cells = 4
#     num_vertical_cells = 4
#     # Split the PIL image into grids
#     grids4,grids_pos_4 = split_image_into_grids(test_image, num_horizontal_cells, num_vertical_cells)
#     ##Add 4*4 Tiled Image to Grids to be predicted on   
#     test_grids = test_grids + grids4
    
#     num_horizontal_cells = 2
#     num_vertical_cells = 2
#     # Split the PIL image into grids
#     grids2,grids_pos_2 = split_image_into_grids(test_image, num_horizontal_cells, num_vertical_cells)
#     ##Add 2*2 Tiled Image to Grids to be predicted on 
#     test_grids = test_grids + grids2
#     # Predict instance segmentation masks for each grid
#     masks,_ = seggpt_inference.predict_batch(prompt_img,prompt_mask,test_grids,400)
#     # Stitch the masks together to create one segmentation mask
#     result_mask4 = stitch_masks(image_size, masks[1:17], grids_pos_4)
#     result_mask2 = stitch_masks(image_size, masks[17:], grids_pos_2)
#     out_mask = masks[0]
#     area_whole = total_segmented_area(out_mask)
#     area_grid3 = total_segmented_area(result_mask2)
#     area_grid4 = total_segmented_area(result_mask4)
#     mask_levels = [out_mask,result_mask2,result_mask4]
#     mask_areas = [area_whole,area_grid3,area_grid4]
#     # if area_whole > area_grid:
#     #     masks = separate_masks(out_mask,30)
#     # else:
#     #     masks = separate_masks(result_mask,30)
#     ##Pick the mask at the most appropriate level
#     masks = separate_masks(mask_levels[np.argmax(mask_areas)],30)
#     if len(masks) > 0:
#         combined_mask = combine_masks(masks)
#         mask_overlay = overlay_segmentation(test_image,combined_mask)
#     else:
#         mask_overlay = test_image
    
#     # plt.imshow(test_image)
#     # plt.show()
#     # plt.imshow(out_mask)
#     # plt.show()
#     # plt.imshow(result_mask2)
#     # plt.show()
#     # plt.imshow(result_mask4)
#     # plt.show()
#     plt.imshow(mask_overlay)
#     plt.show()

In [None]:
import numpy as np
import matplotlib.pyplot as plt
def calculate_bounding_boxes(mask):
    """Compute an axis-aligned bounding box for every instance in a label mask.

    Parameters:
        mask (numpy.ndarray): 2-D mask where each object carries a unique non-zero
            integer id; 0 is treated as background and ignored.

    Returns:
        List of (x_min, y_min, x_max, y_max) tuples with inclusive pixel bounds, one
        per non-zero id in ascending id order. Empty when the mask has no objects.
    """
    boxes = []
    for obj_id in (value for value in np.unique(mask) if value != 0):
        coords = np.nonzero(mask == obj_id)
        rows, cols = coords[0], coords[1]
        boxes.append((cols.min(), rows.min(), cols.max(), rows.max()))
    return boxes


In [None]:
# ##Multilivel!!!
# ##Now lets auto-annotate the rest of the images using the single prompt_img/mask
# for image in data['images']:
#     if image['id'] == "64baf8d583264000075875da": ##Don't predict your prompt_image
#         continue 
#     if image["url"].endswith("comundefined"): ##Matroid Backend Saving Image error
#         continue
#     req = requests.get("http://" + image["url"])
#     test_image = Image.open(BytesIO(req.content)).convert("RGB")
#     combined_mask = seggpt_inference.predict_tiled(prompt_img,prompt_mask,test_image)
#     if combined_mask is not None:
#         bboxes = calculate_bounding_boxes(combined_mask)
#         mask_overlay = overlay_segmentation(test_image,combined_mask)
#         plt.imshow(mask_overlay)
#         for (x_min,y_min,x_max,y_max) in bboxes:
#             plt.gca().add_patch(plt.Rectangle((x_min, y_min), x_max - x_min, y_max - y_min, 
#                                         linewidth=0.5, edgecolor='r', facecolor='none'))
#         plt.axis('off')
#         plt.show()
#     else:
#         mask_overlay = test_image
#         plt.imshow(mask_overlay)
#         plt.axis('off')
#         plt.show()



In [None]:
def overlay_mask(image, mask):
    """Highlight a binary mask on a PIL image by saturating its red channel.

    Args:
        image: PIL image to draw on; it is converted to RGB.
        mask: 2-D numpy array (H, W). Non-zero entries are foreground. If its size
            differs from the image it is resized with nearest-neighbour sampling, so
            it stays binary.

    Returns:
        New RGB PIL image. Foreground pixels get red = 255, and all other pixels keep
        their original colour.
    """
    # Threshold first: `mask.astype(uint8) * 255` wrapped around for ids > 1.
    mask_img = Image.fromarray((np.asarray(mask) > 0).astype(np.uint8) * 255)
    if mask_img.size != image.size:
        # Nearest-neighbour keeps the mask binary. Smoothing filters would create
        # grey fringes that the `> 0` test below would then treat as foreground.
        mask_img = mask_img.resize(image.size, resample=Image.NEAREST)
    foreground = np.array(mask_img) > 0

    rgb = np.array(image.convert("RGB"))
    # The previous version tested the alpha channel of an 'L'->'RGBA' conversion,
    # which is opaque everywhere, so it zeroed red on every background pixel too.
    rgb[..., 0] = np.where(foreground, 255, rgb[..., 0])
    return Image.fromarray(rgb)

In [None]:
# Release unused memory cached by PyTorch's CUDA allocator before the inference cells
# below (see torch.cuda.empty_cache).
torch.cuda.empty_cache()

In [None]:
# Call the LearnablePrompt module to get the (prompt image, prompt mask) pair that the
# evaluation loop passes to seggpt_inference.predict_tiled_finetuned.
# NOTE(review): the types, shapes and devices of these outputs are defined in
# models_seggpt -- confirm there. If they are autograd tensors, consider wrapping the
# call in torch.no_grad() for pure inference.
prompt_img,prompt_mask = prompt()

In [None]:
def calculate_miou(pred, gt):
    """Mean intersection-over-union over the non-zero class ids present in ``gt``.

    Args:
        pred: predicted label (or binary) array, compared element-wise with ``gt``.
        gt: ground-truth label array; id 0 is background and is not scored.

    Returns:
        Mean over gt's non-zero ids c of
        |pred==c AND gt==c| / (|pred==c OR gt==c| + 1e-6).
        Returns 0 when ``gt`` has no non-zero id, even if ``pred`` has false positives.
    """
    eps = 1e-6
    gt_ids = np.unique(gt)
    gt_ids = gt_ids[gt_ids != 0]
    if len(gt_ids) == 0:
        return 0  # or nan/other default value

    def _iou(class_id):
        in_pred = pred == class_id
        in_gt = gt == class_id
        overlap = np.logical_and(in_pred, in_gt).sum()
        covered = np.logical_or(in_pred, in_gt).sum()
        return overlap / (covered + eps)

    return np.mean([_iou(class_id) for class_id in gt_ids])

In [23]:
##Multilevel!!!
##Now lets auto-annotate the rest of the images using the single prompt_img/mask
# For every annotated image, score each prediction level returned by
# predict_tiled_finetuned (labelled by `levels` below), and also the level it
# selects, against the rasterised ground-truth mask.
SHOW_PLOTS = False      # True: display overlay / GT / per-level masks for each image (slow)
MAX_IMAGES = None       # e.g. 5 for a quick smoke test; None evaluates every image
REQUEST_TIMEOUT_S = 30  # seconds; without a timeout a stalled download hangs the cell

levels = {0: "Original", 1: "2x2", 2: "4x4"}  # loop-invariant, so built once
count = 0
mIoU_no_tiling = []
mIoU_2_2 = []
mIoU_4_4 = []
mIoU_combined = []
for image in data['images']:
    if image["url"].endswith("comundefined"):  ##Matroid Backend Saving Image error
        continue
    if MAX_IMAGES is not None and count >= MAX_IMAGES:
        break
    count += 1
    req = requests.get("http://" + image["url"], timeout=REQUEST_TIMEOUT_S)
    req.raise_for_status()  # surface HTTP errors here rather than as a PIL decode error
    test_image = Image.open(BytesIO(req.content)).convert("RGB")
    masks, level = seggpt_inference.predict_tiled_finetuned(prompt_img, prompt_mask, test_image)
    gt_mask = generate_binary_mask(image["tags"], test_image.size)
    mIoU_no_tiling.append(calculate_miou(masks[0], gt_mask))
    mIoU_2_2.append(calculate_miou(masks[1], gt_mask))
    mIoU_4_4.append(calculate_miou(masks[2], gt_mask))
    mIoU_combined.append(calculate_miou(masks[level], gt_mask))

    if SHOW_PLOTS:
        # The overlay is only needed for display, so build it only when plotting.
        mask_overlay = overlay_mask(test_image, masks[level])
        panels = [
            (mask_overlay, "Overlayed Image"),
            (gt_mask, "GT mask"),
            (masks[0], "No Tiling Mask"),
            (masks[1], "2x2 Tiling Mask"),
            (masks[2], "4x4 Tiling Mask"),
        ]
        for panel, title in panels:
            plt.imshow(panel)
            plt.title(title)
            plt.axis('off')
            plt.show()
        print(f"Chosen Level: {levels[level]}")


KeyboardInterrupt: 

In [25]:
# Mean per-image IoU of the 4x4-tiled predictions.
np.asarray(mIoU_4_4).mean()

0.7820374339156037

In [26]:
# Mean per-image IoU of the 2x2-tiled predictions.
np.asarray(mIoU_2_2).mean()

0.7836701612463659

In [27]:
# Mean per-image IoU of the whole-image (untiled) predictions.
np.asarray(mIoU_no_tiling).mean()

0.6730860110201617

In [28]:
# Mean per-image IoU of the level chosen by predict_tiled_finetuned.
np.asarray(mIoU_combined).mean()

0.7845496576657801