Create a `HOME` constant.

In [1]:
import os
# Capture the kernel's current working directory and echo it for reference.
HOME = os.getcwd()
print(f"HOME: {HOME}")

HOME: /home/ec2-user/geoseg/segment-anything/notebooks


## Load Model

In [2]:
import torch

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU so the
# notebook still runs (slowly) on machines without CUDA instead of failing
# when the model is moved to the device.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# SAM backbone variant; keep in sync with the checkpoint file loaded later
# (sam_vit_b_*.pth holds vit_b weights).
MODEL_TYPE = "vit_b"

In [3]:
import torch
# Sanity check: report whether PyTorch can reach a CUDA device.
cuda_available = torch.cuda.is_available()
print(cuda_available)

True


In [4]:
import sys
!{sys.executable} -m pip install 'git+https://github.com/facebookresearch/segment-anything.git'

Collecting git+https://github.com/facebookresearch/segment-anything.git
  Cloning https://github.com/facebookresearch/segment-anything.git to /tmp/pip-req-build-i14qdd2t
  Running command git clone --filter=blob:none --quiet https://github.com/facebookresearch/segment-anything.git /tmp/pip-req-build-i14qdd2t
  Resolved https://github.com/facebookresearch/segment-anything.git to commit dca509fe793f601edb92606367a655c15ac00fdf
  Preparing metadata (setup.py) ... [?25ldone
[?25h

In [5]:
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

# Checkpoint file for the vit_b backbone; must agree with MODEL_TYPE.
chkpt_path = '../weights/sam_vit_b_01ec64.pth'

# Look up the builder for the chosen backbone, load the weights, then move to DEVICE.
sam_builder = sam_model_registry[MODEL_TYPE]
sam = sam_builder(checkpoint=chkpt_path)
sam = sam.to(device=DEVICE)

# Predictor wrapper used below via set_image() / predict_torch() for box prompts.
predictor = SamPredictor(sam)

In [8]:
from PIL import Image
import cv2
import numpy as np

def process_image(image_path):
    """Load an image from disk as an OpenCV (BGR) ``numpy`` array.

    Parameters
    ----------
    image_path : str
        Path to the image file.

    Returns
    -------
    numpy.ndarray or None
        The decoded image exactly as returned by ``cv2.imread`` (BGR channel
        order), or ``None`` if OpenCV could not open or decode the file; a
        message is printed in that case.
    """
    image = cv2.imread(image_path)

    # cv2.imread reports failure by returning None rather than raising.
    if image is None:
        print(f'Could not open or find the image: {image_path}')
        return None

    # No colour conversion is applied: converting BGR->RGB and straight back
    # RGB->BGR cancels out. Callers that need RGB (e.g. before
    # predictor.set_image) should convert with cv2.COLOR_BGR2RGB themselves.
    return image

In [26]:
import glob
import os
import json
import torch
import gc
import matplotlib.pyplot as plt
from matplotlib.patches import Polygon
import numpy as np
import cv2

# Path to the directories
imgs_path = './prompt_imgs'
labels_path = './prompt_labels'
output_path = './prompt_out'
# Original dataset JSONs; their 'metadata' entry is copied into each output file.
dataset_path = './dataset'

# Number of box prompts sent to SAM per predict_torch call (bounds GPU memory).
BATCH_SIZE = 20

os.makedirs(output_path, exist_ok=True)


def _stem(path):
    """Return the file name of `path` without its directory or final extension."""
    return os.path.splitext(os.path.basename(path))[0]


def _batched(items, size):
    """Split `items` into consecutive lists of at most `size` elements."""
    return [items[k:k + size] for k in range(0, len(items), size)]


def _xywh_to_xyxy(bbox):
    """Convert an [x, y, w, h] label box into [x0, y0, x1, y1] corner form."""
    x, y, w, h = bbox
    return [x, y, x + w, y + h]


def _mask_to_segments(mask):
    """Convert one predicted mask into a list of segment dicts.

    Each outer contour of the mask yields one dict with keys "bbox"
    ([x, y, w, h] from cv2.boundingRect), "polygon" (a one-element list holding
    the simplified contour vertices) and "area" (cv2.contourArea of the contour).
    Only external contours are used: with hierarchical retrieval, holes inside
    a mask would otherwise be emitted as separate segments.
    """
    # Drop singleton dimensions (e.g. a leading channel axis) to get a 2-D mask.
    if mask.ndim > 2:
        mask = mask.squeeze()
    if mask.ndim != 2 or mask.size == 0:
        return []

    mask_u8 = mask.astype(np.uint8) * 255
    contours, _ = cv2.findContours(mask_u8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    segments = []
    for contour in contours:
        poly = cv2.approxPolyDP(contour, 1, True).reshape(-1, 2).tolist()
        if len(poly) < 3:  # fewer than 3 vertices cannot form a polygon
            continue
        x, y, w, h = cv2.boundingRect(contour)
        area = cv2.contourArea(contour)
        segments.append({"bbox": [x, y, w, h], "polygon": [poly], "area": area})
    return segments


# Sorted so images are processed (and logged) in a reproducible order.
imgs_files = sorted(glob.glob(os.path.join(imgs_path, '*.png')))

for img_file in imgs_files:
    stem = _stem(img_file)

    # cv2.imread returns None (instead of raising) for missing/corrupt files.
    img = cv2.imread(img_file)
    if img is None:
        print(f"Skipping unreadable image: {img_file}")
        continue
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # Compute the image embedding once; every box batch below reuses it.
    predictor.set_image(img)

    # Box prompts for this image, each stored as [x, y, w, h].
    with open(os.path.join(labels_path, f"{stem}.json"), 'r') as label_file:
        bboxes = json.load(label_file)['bboxes']

    with open(os.path.join(dataset_path, f"{stem}.json"), 'r') as original_file:
        img_metadata = json.load(original_file)['metadata']

    segments = []
    boxes_xyxy = [_xywh_to_xyxy(bbox) for bbox in bboxes]
    for input_boxes in _batched(boxes_xyxy, BATCH_SIZE):
        tens_boxes = torch.tensor(input_boxes, device=predictor.device)
        transformed_boxes = predictor.transform.apply_boxes_torch(tens_boxes, img.shape[:2])
        masks, _, _ = predictor.predict_torch(
            point_coords=None,
            point_labels=None,
            boxes=transformed_boxes,
            multimask_output=False,
        )
        # Copy masks to host memory immediately so GPU usage does not grow with
        # the number of batches, then release the device tensors.
        masks_np = masks.cpu().numpy()
        del tens_boxes, transformed_boxes, masks
        torch.cuda.empty_cache()
        gc.collect()

        for mask in masks_np:
            segments.extend(_mask_to_segments(mask))

    # Save the segments together with the original metadata as a JSON file
    json_path = os.path.join(output_path, f"{stem}.json")
    with open(json_path, 'w') as f:
        json.dump({"segments": segments, "metadata": img_metadata}, f)

    print(f"Processed and saved results for {os.path.basename(img_file)}")

Processed and saved results for ZL0_0732_0731938166_896EBY_N0363294ZCAM07114_1100LMJ01.png
Processed and saved results for ZL0_0763_0734689015_318EBY_N0380000ZCAM07114_1100LMJ01.png
Processed and saved results for ZL0_0764_0734772775_443EBY_N0380944ZCAM07114_1100LMJ03.png
Processed and saved results for ZL0_0896_0746497876_285EBY_N0440820ZCAM07114_0340LMJ01.png
Processed and saved results for ZL0_0899_0746759154_159EBY_N0440898ZCAM07114_0340LMJ02.png
Processed and saved results for ZL0_0950_0751290309_364EBY_N0461870ZCAM07114_0340LMJ01.png
Processed and saved results for ZL0_0950_0751290403_363EBY_N0461870ZCAM07114_1100LMJ01.png
Processed and saved results for ZL0_1000_0755729009_894EBY_N0474404ZCAM07114_0340LMJ01.png
Processed and saved results for ZL0_1000_0755729038_894EBY_N0474404ZCAM07114_1100LMJ01.png
Processed and saved results for ZL0_1006_0756258025_456EBY_N0481156ZCAM07114_0340LMJ01.png
