In [37]:
import modal

In [41]:
def convert_filenames(file_path: str):
    """Rename files in ``file_path`` to the ``<stem>_<int>.<ext>`` convention.

    Every dot except the last one (the extension separator) is replaced with an
    underscore. ``_1`` is then appended to the stem unless the name already
    ends in ``_<digits>.<ext>``. For example, ``img.rf.abc.jpg`` becomes
    ``img_rf_abc_1.jpg``, while ``frame_12.png`` is left unchanged. Running the
    function twice is a no-op the second time for such names.

    Hidden files (leading ``.``, e.g. ``.DS_Store``) and entries that are not
    regular files (e.g. sub-directories) are skipped.

    All renames are planned before any is performed. ``os.rename`` silently
    replaces an existing destination on POSIX, so the function aborts without
    touching anything if a planned name is already taken, or is claimed by two
    different source files.

    Args:
        file_path: Directory whose files are renamed in place.

    Returns:
        Mapping of original filename -> new filename for every file renamed.

    Raises:
        FileExistsError: If a planned rename would overwrite another file.
    """
    import os
    import re

    entries = os.listdir(file_path)
    existing = set(entries)

    # Plan every rename first so a collision aborts before anything is touched.
    plan = {}
    for filename in entries:
        # Without this guard ".DS_Store" would become "_1.DS_Store", and
        # sub-directories would be renamed as if they were images.
        if filename.startswith(".") or not os.path.isfile(
            os.path.join(file_path, filename)
        ):
            continue
        # Replace all except the last dot (the extension separator) with underscores.
        new_filename = filename.replace(".", "_", max(filename.count(".") - 1, 0))
        if not re.search(r"_\d+\.\w+$", new_filename):
            # Add an int to the end of the base name.
            new_filename = new_filename.replace(".", "_1.")
        if new_filename != filename:
            plan[filename] = new_filename

    # Refuse to clobber an existing file or let two sources share one target.
    claimed = {}
    for old_name, new_name in plan.items():
        if new_name in existing or new_name in claimed:
            other = claimed.get(new_name, new_name)
            raise FileExistsError(
                f"Renaming {old_name!r} to {new_name!r} would overwrite "
                f"{other!r} in {file_path}"
            )
        claimed[new_name] = old_name

    for old_name, new_name in plan.items():
        os.rename(
            os.path.join(file_path, old_name), os.path.join(file_path, new_name)
        )
    return plan


# Normalize the training split's filenames in place before uploading/training.
TRAIN_DIR = "../data/images/Satellite-Curb-Segmentation-7/train"
convert_filenames(TRAIN_DIR)

In [29]:
app = modal.App(name="crossing-distance-sam2-inference")

# Container image for SAM2 inference: system packages (git, wget, OpenCV, ffmpeg),
# PyTorch and helpers, an editable install of the SAM2 repository, and the
# checkpoints fetched by the repo's download script.
# NOTE(review): torch and the sam2 checkout are unpinned, so image rebuilds may
# pick up newer versions.
infer_image = (
    modal.Image.debian_slim(python_version="3.10")
    .apt_install("git", "wget", "python3-opencv", "ffmpeg")
    .pip_install(
        "torch",
        "torchvision",
        "torchaudio",
        "opencv-python==4.10.0.84",
        "pycocotools~=2.0.8",
        "matplotlib~=3.9.2",
        "supervision",
    )
    .run_commands("git clone https://github.com/facebookresearch/sam2.git")
    # Installing the [dev] extra also installs the base package, so a single
    # editable install is enough.
    .run_commands("pip install -e 'sam2/.[dev]'")
    # `&&` so a missing directory fails the build instead of running the script
    # from the wrong working directory.
    .run_commands("cd sam2/checkpoints && ./download_ckpts.sh")
)

# Fine-tuned weights and inference results live in the "sam_test" environment
# and are created on first use.
weights_volume = modal.Volume.from_name(
    "sam2-weights", create_if_missing=True, environment_name="sam_test"
)
# Source imagery lives in the "sfo" environment and must already exist.
inputs_volume = modal.Volume.from_name("crosswalk-data-sf", environment_name="sfo")
outputs_volume = modal.Volume.from_name(
    "sam2-results", create_if_missing=True, environment_name="sam_test"
)

In [35]:
@app.function(
    volumes={
        "/weights": weights_volume,
        "/inputs": inputs_volume,
        "/outputs": outputs_volume,
    },
    image=infer_image,
    gpu="A10G",
    timeout=3600,
)
def run_inference(model_path: str):
    """Segment one randomly chosen input image with a SAM2 checkpoint and save an overlay.

    The function performs these steps:

    - Loads ``/weights/<model_path>/checkpoints/checkpoint.pt`` into the SAM2.1
      Hiera-B+ config.
    - Runs ``SAM2AutomaticMaskGenerator`` on a random image from ``/inputs``,
      resized to 1024x1024.
    - Draws every predicted mask with ``supervision``.
    - Writes the result to ``/outputs/<model_path>/<stem>_masked.jpg``.

    Args:
        model_path: Name of the run directory inside the weights volume; also
            used as the output sub-directory.

    Returns:
        Container path of the annotated image that was written.

    Raises:
        FileNotFoundError: If ``/inputs`` contains no image files.
        OSError: If OpenCV fails to write the annotated image.
    """
    import os
    import random

    import cv2
    import numpy as np
    import supervision as sv
    import torch
    from PIL import Image
    from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
    from sam2.build_sam import build_sam2

    # Only pick files with these extensions; other entries (annotations,
    # sub-directories) would make Image.open fail.
    image_exts = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp"}

    if torch.cuda.get_device_properties(0).major >= 8:
        # Compute capability >= 8 (Ampere or newer): allow TF32 matmuls/convs.
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    # Absolute mount path. A relative "../../weights" only resolves to /weights
    # when the working directory happens to be at most two levels deep.
    checkpoint = f"/weights/{model_path}/checkpoints/checkpoint.pt"
    model_cfg = "configs/sam2.1/sam2.1_hiera_b+.yaml"
    sam2 = build_sam2(model_cfg, checkpoint, device="cuda")
    mask_generator = SAM2AutomaticMaskGenerator(sam2)

    candidates = [
        name
        for name in os.listdir("/inputs")
        if os.path.splitext(name)[1].lower() in image_exts
        and os.path.isfile(os.path.join("/inputs", name))
    ]
    if not candidates:
        raise FileNotFoundError("No image files found in /inputs")

    image = random.choice(candidates)
    image_name = os.path.splitext(image)[0]
    image_path = os.path.join("/inputs", image)
    with Image.open(image_path) as pil_image:
        rgb_image = np.array(
            pil_image.convert("RGB").resize((1024, 1024), Image.Resampling.LANCZOS)
        )

    # Scope bf16 autocast to the forward pass. Calling __enter__() without ever
    # exiting would leave it active for later calls in the same warm container.
    with torch.autocast("cuda", dtype=torch.bfloat16):
        result = mask_generator.generate(rgb_image)

    detections = sv.Detections.from_sam(sam_result=result)

    # The model consumes RGB, but cv2.imwrite expects BGR. supervision also
    # draws colors in BGR order on numpy images. Convert before annotating so
    # both the photo and the mask colors come out correct.
    bgr_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR)
    mask_annotator = sv.MaskAnnotator(color_lookup=sv.ColorLookup.INDEX)
    annotated_image = mask_annotator.annotate(bgr_image, detections=detections)

    output_dir = f"/outputs/{model_path}"
    os.makedirs(output_dir, exist_ok=True)
    annotated_image_path = os.path.join(output_dir, f"{image_name}_masked.jpg")
    # cv2.imwrite signals failure by returning False rather than raising.
    if not cv2.imwrite(annotated_image_path, annotated_image):
        raise OSError(f"cv2.imwrite failed for {annotated_image_path}")
    return annotated_image_path


# Start an ephemeral run of the Modal app and execute inference remotely
# against the weights stored under this run name.
MODEL_RUN = "train_1"
with app.run():
    run_inference.remote(MODEL_RUN)