In [1]:
!pip install segment-geospatial

Collecting segment-geospatial
  Downloading segment_geospatial-0.12.4-py2.py3-none-any.whl.metadata (11 kB)
Collecting fiona (from segment-geospatial)
  Downloading fiona-1.10.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (56 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.6/56.6 kB[0m [31m5.4 MB/s[0m eta [36m0:00:00[0m
Collecting ipympl (from segment-geospatial)
  Downloading ipympl-0.9.7-py3-none-any.whl.metadata (8.7 kB)
Collecting leafmap (from segment-geospatial)
  Downloading leafmap-0.42.12-py2.py3-none-any.whl.metadata (16 kB)
Collecting localtileserver (from segment-geospatial)
  Downloading localtileserver-0.10.6-py3-none-any.whl.metadata (5.2 kB)
Collecting patool (from segment-geospatial)
  Downloading patool-4.0.0-py2.py3-none-any.whl.metadata (4.5 kB)
Collecting rasterio (from segment-geospatial)
  Downloading rasterio-1.4.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.1 kB)
Collecting rioxarray

In [2]:
import os
import time
from typing import Iterable, List, Optional, Tuple

import google.colab.drive as drive
import torch
from samgeo import SamGeo
from tqdm import tqdm

In [3]:
# Mount Google Drive so the input tiles and export folder under
# /content/drive/MyDrive (used further down) are reachable from this runtime.
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
# Prefer the GPU when torch can see one; otherwise fall back to the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
print(device)

cuda


In [5]:
def samgeo_unit(
        sam: "SamGeo",
        input_img_path: str,
        output_folder: str,
        batch: bool = True,
        foreground: bool = True,
        erosion_kernel: Tuple[int, int] = (3, 3),
        mask_multiplier: int = 255,
        save_shp: bool = True,
    ) -> str:
    """Segment one image with SamGeo and vectorise the resulting mask raster.

    Writes ``<stem>_delineation.tif`` (via ``sam.generate``),
    ``<stem>_delineation.gpkg`` (via ``sam.tiff_to_gpkg``) and, unless
    ``save_shp`` is False, ``<stem>_delineation.shp`` (via
    ``sam.tiff_to_vector``) into ``output_folder``.

    Args:
        sam: Initialised ``SamGeo`` model.
        input_img_path: Path of the image to segment.
        output_folder: Destination folder; created if it does not exist.
        batch, foreground, erosion_kernel, mask_multiplier: Forwarded unchanged
            as keyword arguments to ``sam.generate``.
        save_shp: When True, also write a Shapefile next to the GeoPackage.

    Returns:
        Path of the GeoPackage passed to ``sam.tiff_to_gpkg``.
    """
    os.makedirs(output_folder, exist_ok=True)

    # splitext strips only the final extension ("patch.v2.tif" -> "patch.v2"),
    # so inputs that differ after their first dot no longer collide on output.
    stem = os.path.splitext(os.path.basename(input_img_path))[0]
    output_path = os.path.join(output_folder, f'{stem}_delineation.tif')

    sam.generate(
        input_img_path,
        output_path,
        batch=batch,
        foreground=foreground,
        erosion_kernel=erosion_kernel,
        mask_multiplier=mask_multiplier,
    )

    base_path = os.path.splitext(output_path)[0]
    output_vector_path = base_path + '.gpkg'
    sam.tiff_to_gpkg(output_path, output_vector_path, simplify_tolerance=None)
    if save_shp:
        sam.tiff_to_vector(output_path, base_path + '.shp')

    return output_vector_path

def segment_tiles(
        input_img_tiles: Iterable[str],
        output_folder: str,
        sam: Optional["SamGeo"] = None,
        **unit_kwargs,
    ) -> List[str]:
    """Run ``samgeo_unit`` over several image tiles and report the total wall time.

    Args:
        input_img_tiles: Paths of the tiles to segment. A single path string is
            accepted too and treated as a one-element list.
        output_folder: Folder that receives every tile's outputs.
        sam: Initialised ``SamGeo`` model used for every tile (required).
        **unit_kwargs: Extra keyword arguments forwarded to ``samgeo_unit``
            (e.g. ``batch``, ``foreground``, ``erosion_kernel``).

    Returns:
        The vector path returned by ``samgeo_unit`` for each tile, in input order.

    Raises:
        TypeError: If ``sam`` is not provided.
    """
    if sam is None:
        raise TypeError('segment_tiles() requires an initialised SamGeo model via `sam=`')

    # A bare string would otherwise be iterated character by character.
    if isinstance(input_img_tiles, str):
        img_list = [input_img_tiles]
    else:
        img_list = list(input_img_tiles)

    start = time.time()
    output_paths = []

    for idx, img in enumerate(tqdm(img_list)):
        # tqdm.write prints above the bar instead of corrupting it like print().
        tqdm.write(f'Working on tile {idx}')

        output_paths.append(samgeo_unit(sam, img, output_folder, **unit_kwargs))
        # is_available must be *called*; the bare function object is always truthy.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    end = time.time()
    print('Segmentations finished')
    print(f'Total time taken {(end-start)/60} mins')
    return output_paths

In [6]:
# Keyword arguments handed to SamGeo via `sam_kwargs`; presumably forwarded to
# segment-anything's automatic mask generator — confirm names against the
# installed samgeo version.
sam_kwargs = dict(
    points_per_side=128,
    pred_iou_thresh=0.80,
    stability_score_thresh=0.95,
    crop_n_layers=1,
    crop_n_points_downscale_factor=2,
    min_mask_region_area=100,
)

# NOTE(review): the logged run reports this relative checkpoint as not found and
# downloads ~2.56 GB into the torch hub cache; consider pointing it at a
# persistent copy (e.g. on Drive) to avoid re-downloading in fresh sessions.
sam = SamGeo(
    checkpoint='sam_vit_h_4b8939.pth',
    model_type='vit_h',
    device=device,
    sam_kwargs=sam_kwargs,
)

Model checkpoint for vit_h not found.


Downloading...
From: https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
To: /root/.cache/torch/hub/checkpoints/sam_vit_h_4b8939.pth
100%|██████████| 2.56G/2.56G [00:16<00:00, 152MB/s]
  state_dict = torch.load(f)


In [7]:
# Root of the experiment on Drive, shared by every input/output path below.
# NOTE(review): "Stage /Rafid" contains a space before the slash — presumably the
# real Drive folder name "(4) Execute Stage "; verify before "fixing" it.
EXPERIMENT_DIR = '/content/drive/MyDrive/(4) Execute Stage /Rafid/samgeo_experiments'
input_img = f'{EXPERIMENT_DIR}/ground_truths/patch_1.tif'
export_folder = f'{EXPERIMENT_DIR}/validation'

In [8]:
# Segment the ground-truth patch. Bind the GeoPackage path to a descriptive name:
# assigning to `_` would shadow IPython's last-output history variable.
validation_gpkg_path = samgeo_unit(sam, input_img, export_folder)
validation_gpkg_path

100%|██████████| 25/25 [1:05:50<00:00, 158.04s/it]
