In [None]:
import os
import time
from pathlib import Path

from dds_cloudapi_sdk import DDSCloudAPI, V2Task
import supervision as sv
import numpy as np
from PIL import Image
from tqdm import tqdm
import cv2
from pycocotools import mask as maskUtils
from deepdataspace.client import Client
from deepdataspace.model import create_task_with_local_image_auto_resize
from deepdataspace import Config
import torch
from segment_anything import sam_model_registry, SamPredictor

# NOTE(review): the SAM cells below also use `rio`, `img_as_ubyte`, `plt` and
# `LangSAM`, none of which are imported in the visible cells (presumably
# rasterio, skimage's img_as_ubyte, matplotlib.pyplot and samgeo's LangSAM).
# Confirm and add those imports here so the notebook survives Restart & Run All.


# Implementing SAM

Helper class and function for loading image tiles, converting them to 8-bit RGB, and reading mask rasters.

In [None]:
class Masking:
    """Load the first three bands of an image and write an 8-bit copy of them.

    NOTE(review): this class relies on ``rio`` and ``img_as_ubyte`` being in
    scope; neither is imported in the visible cells (presumably ``rasterio``
    and ``skimage``'s ``img_as_ubyte``) - confirm.
    """

    def __init__(self, img_path):
        """Read bands 1-3 of ``img_path`` and convert them with ``img_as_ubyte``.

        Args:
            img_path: path of the image file to open.
        """
        self.img_path = img_path
        self.img_name = os.path.basename(self.img_path)
        # Path of the converted copy. It stays None until write_image() has run,
        # so callers can test for it instead of hitting an AttributeError.
        self.rgb_path = None
        with rio.open(self.img_path) as src:
            self.img_16bit = src.read((1, 2, 3))  # RGB bands
        self.img = img_as_ubyte(self.img_16bit)

    def img_profile(self):
        """Return the source file's profile adjusted for a single-band uint8 raster.

        Returns:
            The profile read from ``img_path`` with ``count=1``,
            ``dtype='uint8'`` and ``nodata=0`` applied; every other key is
            whatever the source file reports.
        """
        with rio.open(self.img_path) as src:
            profile = src.profile
            profile.update({
                'count': 1,
                'dtype': 'uint8',
                'nodata': 0
            })
        return profile

    def write_image(self):
        """Write the converted three-band image into ``RGB_8bit/`` and return its path.

        The output directory is created relative to the current working
        directory. The file is named ``converted_<original file name>``.

        Returns:
            The path of the written file (also stored on ``self.rgb_path``).
        """
        output_dir = 'RGB_8bit'
        os.makedirs(output_dir, exist_ok=True)
        self.rgb_path = os.path.join(output_dir, f'converted_{self.img_name}')
        # Start from the single-band mask profile and widen it back to 3 bands.
        profile = self.img_profile()
        profile.update({'count': 3})
        with rio.open(self.rgb_path, 'w', **profile) as dst:
            dst.write(np.clip(self.img, 0, 255))
        return self.rgb_path

def open_mask(mask_path):
    """Return band 1 of the raster stored at ``mask_path``.

    The dataset is opened through ``rio.open`` and closed again before the
    band data is handed back to the caller.
    """
    with rio.open(mask_path) as dataset:
        first_band = dataset.read(1)
    return first_band

Processing images using SAM

In [None]:

def process_images(image_paths, text_prompts):
    """Segment each image once per text prompt and save a combined label mask.

    For every image an 8-bit RGB copy is written via ``Masking.write_image``,
    ``sam.predict`` is run once per prompt, and the binary results are merged
    into one label raster in which prompt ``i`` (0-based) is stored as value
    ``i + 1``. Where masks overlap, the later prompt overwrites the earlier.
    Three files are written per image under ``output_folder/``: the label
    raster, a colour-mapped PNG, and a side-by-side plot.

    Args:
        image_paths: iterable of image paths accepted by ``Masking``.
        text_prompts: sequence of prompt strings, one per class.

    Returns:
        None. Results are written to disk and progress/timing is printed.

    NOTE(review): relies on ``LangSAM``, ``rio`` and ``plt`` being in scope;
    they are not imported in the visible cells - confirm.
    """
    sam = LangSAM()

    tiff_output_dir = r"output_folder/TIFF_combined_new"
    png_output_dir = r"output_folder/Masks_PNG_combined_new"
    plot_output_dir = r"output_folder/Plots_combined_new"

    for out_dir in (tiff_output_dir, png_output_dir, plot_output_dir):
        os.makedirs(out_dir, exist_ok=True)

    # Scratch file that sam.predict is asked to write each prompt's mask to.
    temp_mask_path = "temp_mask.tif"
    n_classes = len(text_prompts)
    total_time = 0.0

    for idx_img, img_path in enumerate(image_paths, start=1):
        tile = Masking(img_path)
        print(f"\nProcessing: {tile.img_name}")
        start_time = time.time()   # <-- start timing

        tile.write_image()
        profile = tile.img_profile()
        combined_mask = None

        for idx, prompt in enumerate(text_prompts):
            print(f"  → Predicting class {idx + 1}: '{prompt}'")

            # Drop any mask left over from a previous prompt or image. Without
            # this, a prediction that writes no file (presumably what happens
            # when nothing is detected - confirm) would silently reuse the
            # previous prompt's mask under the wrong class label.
            if os.path.exists(temp_mask_path):
                os.remove(temp_mask_path)

            sam.predict(tile.rgb_path, prompt,
                        box_threshold=0.24, text_threshold=0.24,
                        output=temp_mask_path)

            if not os.path.exists(temp_mask_path):
                print(f"    No mask produced for '{prompt}'; skipping.")
                continue

            mask = open_mask(temp_mask_path)
            binary_mask = (mask > 0).astype(np.uint8)

            if combined_mask is None:
                combined_mask = np.zeros_like(binary_mask, dtype=np.uint8)

            combined_mask[binary_mask == 1] = idx + 1  # label values: 1..N

        if combined_mask is None:
            # No prompt produced a mask (or no prompts were given): write an
            # all-background raster instead of failing on `dst.write(None)`.
            # tile.img is band-first, so its last two axes are (rows, cols).
            combined_mask = np.zeros(tile.img.shape[-2:], dtype=np.uint8)

        base_name = os.path.splitext(tile.img_name)[0]

        # Save combined mask as TIFF
        # NOTE(review): `profile` inherits the source file's driver; for PNG
        # inputs this ".tif" is presumably written with the PNG driver. Set
        # driver='GTiff' here if a real GeoTIFF is required - confirm.
        combined_tiff_path = os.path.join(
            tiff_output_dir, f"{base_name}_combined_mask.tif"
        )
        with rio.open(combined_tiff_path, 'w', **profile) as dst:
            dst.write(combined_mask, 1)
        print(f"Saved combined TIFF mask to: {combined_tiff_path}")

        # Save colored PNG
        combined_png_path = os.path.join(
            png_output_dir, f"{base_name}_combined_mask.png"
        )
        plt.imsave(combined_png_path, combined_mask, cmap='tab10', vmin=0, vmax=n_classes)
        print(f"Saved colored PNG to: {combined_png_path}")

        # Save side-by-side plot
        fig, ax = plt.subplots(1, 2, figsize=(14, 7))
        ax[0].imshow(np.moveaxis(tile.img, 0, -1))  # band-first -> band-last for imshow
        ax[0].set_title(f"Original: {tile.img_name}")
        ax[0].axis('off')

        ax[1].imshow(combined_mask, cmap='tab10', vmin=0, vmax=n_classes)
        ax[1].set_title("Combined Multi-class Mask")
        ax[1].axis('off')

        plot_path = os.path.join(plot_output_dir, f"{base_name}_plot.png")
        fig.tight_layout()
        fig.savefig(plot_path)
        # Close this figure explicitly so figures do not pile up across images.
        plt.close(fig)
        print(f"Saved side-by-side plot to: {plot_path}")

        # ---- Timing ----
        elapsed = time.time() - start_time
        total_time += elapsed
        avg_time = total_time / idx_img
        print(f"Time for {tile.img_name}: {elapsed:.2f} s | "
              f"Running average: {avg_time:.2f} s")

# Script entry point: run the text-prompted SAM pipeline over every PNG that
# sits directly inside the input folder (sub-folders are not searched).
if __name__ == "__main__":
    text_prompts = ["plant", "forest", "house", "road", "brown soil"]
    image_folder = r"input_folder"
    image_paths = [
        os.path.join(image_folder, file_name)
        for file_name in os.listdir(image_folder)
        if file_name.lower().endswith('.png')  # extension match is case-insensitive
    ]
    process_images(image_paths, text_prompts)

# Implementing DINO-X

## Configuration

In [None]:
# The API token is a credential, so it is read from the environment instead of
# being committed with the notebook. Export DDS_API_TOKEN before starting the
# kernel. The token that used to be hardcoded here should be treated as
# compromised and rotated.
API_TOKEN = os.environ.get("DDS_API_TOKEN", "")
if not API_TOKEN:
    print("WARNING: DDS_API_TOKEN is not set; DINO-X API requests are expected to be rejected.")

# Class names sent to the detector, separated by " . ". The class-id mapping
# further down is derived from this string by splitting on '.'.
TEXT_PROMPT = "plant . forest . house . road . brown soil"

# Input and output locations used by the DINO-X cells.
INPUT_FOLDER = "/input_folder"   # Folder containing PNG images
OUTPUT_FOLDER = "/output_folder"

# Directory that receives the DINO-X label masks; created up front.
dinox_mask_dir = Path(OUTPUT_FOLDER) / "dinox_masks"
dinox_mask_dir.mkdir(parents=True, exist_ok=True)

# Prompt classes (trimmed, lower-cased, empty pieces dropped) and their
# 1-based label ids; 0 is left free for background.
classes = list(filter(None, (piece.strip().lower() for piece in TEXT_PROMPT.split('.'))))
class_name_to_id = dict(zip(classes, range(1, len(classes) + 1)))


## Initializing the DINO-X client

In [None]:
# Build the API configuration from the token defined in the configuration
# cell, then create the client that the detection loop uses to run tasks.
config = Config(API_TOKEN)
client = Client(config)

## Looping over the input files

In [None]:

# Collect every PNG under INPUT_FOLDER (recursively), skipping files whose
# names carry one of the pipeline's own output markers. The result is sorted so
# the processing order is the same on every run.
image_paths = sorted(
    p for p in Path(INPUT_FOLDER).rglob("*.png")
    if not any(suffix in p.name for suffix in ["_dinosam_mask", "_dinox_mask", "_annotated"])
)

if not image_paths:
    print("No PNG images found.")
    # `exit()` is the interactive-shell helper and is not meant for program
    # code; inside a Jupyter kernel it presumably asks the kernel itself to
    # shut down, discarding all state. Raising SystemExit stops execution here
    # without that side effect.
    raise SystemExit

# Run DINO-X on every input image and save one label mask per image.
# Each pixel of the saved mask holds the 1-based id of the detected class
# (0 = background); when detections overlap, the later one overwrites the
# earlier. Failures are reported per image and do not stop the loop.
for image_path in image_paths:
    print(f"Processing: {image_path.name}")
    try:
        img = cv2.imread(str(image_path))
        if img is None:
            print(f"Could not read image: {image_path}")
            continue

        height, width = img.shape[:2]
        stem = image_path.stem

        # DINO-X detection (API Call)
        task = create_task_with_local_image_auto_resize(
            api_path="/v2/task/dinox/detection",
            api_body_without_image={
                "model": "DINO-X-1.0",
                "prompt": {"type": "text", "text": TEXT_PROMPT},
                "targets": ["bbox", "mask"],
                "bbox_threshold": 0.1,
                "iou_threshold": 0.5,
                "mask_format": "coco_rle"
            },
            image_path=str(image_path)
        )

        client.run_task(task)
        predictions = task.result.get("objects", [])

        # Process Predictions & Decode Masks
        # Initialize blank mask (also what gets saved when nothing is detected)
        dinox_mask = np.zeros((height, width), dtype=np.uint8)

        if not predictions:
            print(f"No detections for {stem}.")
        else:
            for obj in predictions:
                cls = obj["category"].strip().lower()
                class_id = class_name_to_id.get(cls, 0)

                # Categories that are not part of the prompt are ignored.
                if class_id == 0:
                    continue

                # Mask Decoding (COCO RLE)
                if "mask" in obj and obj["mask"] is not None:
                    try:
                        # 1. Normalise the RLE structure.
                        size = obj["mask"]["size"]
                        counts = obj["mask"]["counts"]
                        if isinstance(counts, list):
                            # Uncompressed RLE (list of run lengths): decode()
                            # only accepts compressed RLE, so convert it first.
                            rle = maskUtils.frPyObjects(
                                {"size": list(size), "counts": counts}, size[0], size[1]
                            )
                        else:
                            rle = {"size": size, "counts": counts}

                        # 2. Decode
                        m = maskUtils.decode(rle)

                        # 3. Fix dimensions (Sometimes decodes to HxWx1)
                        if m.ndim == 3:
                            m = np.squeeze(m)

                        # 4. Resize (If API auto-resized the image)
                        if m.shape != (height, width):
                            m = cv2.resize(m, (width, height), interpolation=cv2.INTER_NEAREST)

                        # 5. Apply to final mask
                        dinox_mask[m > 0] = class_id

                    except Exception as e:
                        print(f"Mask decode failed for {cls} in {image_path.name}: {e}")

        # Save Output
        output_path = dinox_mask_dir / f"{stem}_dinox_mask.png"
        # cv2.imwrite reports failure through its return value rather than by
        # raising, so check it before claiming success.
        if cv2.imwrite(str(output_path), dinox_mask):
            print(f"Saved: {output_path.name}")
        else:
            print(f"Could not write mask: {output_path}")

    except Exception as e:
        print(f"Failed on {image_path.name}: {e}")

print("\nAll images processed.")

# DiSEG-X 

In [None]:
# Run SAM on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Locations of the inputs, the outputs and the SAM checkpoint file.
INPUT_FOLDER = "/input_folder"
OUTPUT_FOLDER = "/output_folder"
SAM_CHECKPOINT = "/sam_model/sam_vit_h_4b8939.pth"

# Output directories for the label masks and the annotated previews.
disegx_mask_dir = Path(OUTPUT_FOLDER) / "disegx_masks"
annotated_dir = Path(OUTPUT_FOLDER) / "annotated"

for d in (disegx_mask_dir, annotated_dir):
    d.mkdir(parents=True, exist_ok=True)

# Prompt classes (trimmed, lower-cased, empty pieces dropped) and their
# 1-based label ids. TEXT_PROMPT comes from the DINO-X configuration cell.
classes = list(filter(None, (piece.strip().lower() for piece in TEXT_PROMPT.split('.'))))
class_name_to_id = dict(zip(classes, range(1, len(classes) + 1)))

# Build the ViT-H SAM model from the checkpoint, move it to DEVICE and wrap it
# in a predictor.
print("Loading SAM model...")
build_sam = sam_model_registry["vit_h"]
sam = build_sam(checkpoint=SAM_CHECKPOINT).to(DEVICE)
sam_predictor = SamPredictor(sam)

# API client for the DINO-X detection requests (API_TOKEN comes from the
# DINO-X configuration cell).
config = Config(API_TOKEN)
client = Client(config)

## Process each PNG image

In [None]:
# Collect every PNG under INPUT_FOLDER (recursively), skipping files whose
# names carry one of this pipeline's own output markers. The result is sorted
# so the processing order is the same on every run.
image_paths = sorted(
    p for p in Path(INPUT_FOLDER).rglob("*.png")
    if not any(suffix in p.name for suffix in ["_disegx_mask", "_annotated"])
)

if not image_paths:
    print("No PNG images found.")
    # `exit()` is the interactive-shell helper and is not meant for program
    # code; inside a Jupyter kernel it presumably asks the kernel itself to
    # shut down, discarding all state. Raising SystemExit stops execution here
    # without that side effect.
    raise SystemExit

In [None]:

# DiSEG-X: detect boxes with DINO-X, then let SAM turn each box into a mask.
# Per image this writes a label mask (pixel value = 1-based class id,
# 0 = background; later detections overwrite earlier ones) and, when at least
# one box was segmented, an annotated preview. Failures are reported per image
# and do not stop the loop.
for image_path in image_paths:
    print(f"Processing: {image_path.name}")
    try:
        img = cv2.imread(str(image_path))
        if img is None:
            print(f"Could not read image: {image_path}")
            continue

        height, width = img.shape[:2]
        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        stem = image_path.stem

        # Step 1: DINO-X Detection (Get Bboxes only)
        task = create_task_with_local_image_auto_resize(
            api_path="/v2/task/dinox/detection",
            api_body_without_image={
                "model": "DINO-X-1.0",
                "prompt": {"type": "text", "text": TEXT_PROMPT},
                "targets": ["bbox"], # We only need boxes for DiSEG-X
                "bbox_threshold": 0.1,
                "iou_threshold": 0.5,
            },
            image_path=str(image_path)
        )
        client.run_task(task)
        predictions = task.result.get("objects", [])

        if not predictions:
            print(f"No detections found for {stem}.")
            continue

        # Step 2: DiSEG-X Workflow (SAM Refinement)
        sam_predictor.set_image(img_rgb)

        disegx_mask = np.zeros((height, width), dtype=np.uint8)
        boxes, labels, class_ids = [], [], []
        det_masks_sam = []

        for obj in predictions:
            cls = obj["category"].strip().lower()
            class_id = class_name_to_id.get(cls, 0)
            # Categories that are not part of the prompt are ignored.
            if class_id == 0:
                continue

            # Extracting Box
            # NOTE(review): this treats obj["bbox"] as [x, y, width, height] and
            # converts it to corner form. If the API already returns
            # [xmin, ymin, xmax, ymax], the boxes handed to SAM are too large -
            # confirm against the DINO-X API response format.
            x, y, w, h = obj["bbox"]
            box = np.array([x, y, x + w, y + h])

            # Build the label before touching any result list: formatting a
            # missing score inside the try block used to abort half-way through
            # the appends and leave the lists with different lengths.
            score = obj.get("score")
            if isinstance(score, (int, float)):
                label = f"{cls} {score:.2f}"
            else:
                label = cls

            try:
                # SAM Inference (Prompted by Box)
                masks, scores, _ = sam_predictor.predict(box=box, multimask_output=True)

                # Selecting best mask (highest score)
                if masks is not None and len(masks) > 0:
                    best_mask_idx = np.argmax(scores)
                    binary_mask = masks[best_mask_idx] > 0.5

                    # Add to DiSEG-X consolidated mask
                    disegx_mask[binary_mask] = class_id

                    # Store for Annotation (all four lists grow together)
                    boxes.append(box)
                    class_ids.append(class_id)
                    labels.append(label)
                    det_masks_sam.append(binary_mask)
                else:
                    print(f"No valid SAM masks for {cls}")

            except Exception as e:
                print(f"SAM failed for {cls}: {e}")

        # Step 3: Save DiSEG-X Output
        # Saving Mask
        mask_path = disegx_mask_dir / f"{stem}_disegx_mask.png"
        # cv2.imwrite reports failure through its return value rather than by
        # raising, so check it before claiming success.
        mask_saved = cv2.imwrite(str(mask_path), disegx_mask)

        # Annotate
        if len(boxes) > 0:
            detections = sv.Detections(
                xyxy=np.array(boxes),
                mask=np.array(det_masks_sam).astype(bool),
                class_id=np.array(class_ids),
            )
            annotated = img.copy()
            annotated = sv.BoxAnnotator().annotate(scene=annotated, detections=detections)
            annotated = sv.LabelAnnotator().annotate(scene=annotated, detections=detections, labels=labels)
            annotated = sv.MaskAnnotator().annotate(scene=annotated, detections=detections)

            annotated_path = annotated_dir / f"{stem}_annotated.jpg"
            if not cv2.imwrite(str(annotated_path), annotated):
                print(f"Could not write annotated image: {annotated_path}")

        if mask_saved:
            print(f"Saved DiSEG-X mask: {stem}")
        else:
            print(f"Could not write mask: {mask_path}")

    except Exception as e:
        print(f"Failed on {image_path.name}: {e}")

print("\nAll images processed.")