## Imports

In [2]:
import leafmap
import geopandas as gpd
import rasterio
from rasterio.warp import calculate_default_transform, reproject, Resampling
from rasterio.features import rasterize
from rasterio.mask import mask
from rasterio.plot import show
import numpy as np
import matplotlib.pyplot as plt
import json
from shapely import wkt
import torch
from samgeo import SamGeo, SamGeo2, raster_to_vector, overlay_images

## Generating Feature Tables

In [None]:
# Prefer the GPU but fall back to the CPU: hard-coding device="cuda" makes this
# cell fail outright on machines without a CUDA-capable torch build.
device = "cuda" if torch.cuda.is_available() else "cpu"

# SAM 2 model built from the "sam2-hiera-large" checkpoint. The keyword
# arguments after `device` presumably configure SAM 2's automatic mask
# generator -- confirm their meaning in the samgeo / SAM 2 docs before tuning.
sam2 = SamGeo2(
    model_id="sam2-hiera-large",
    device=device,
    apply_postprocessing=False,
    points_per_side=32,
    points_per_batch=64,
    pred_iou_thresh=0.8,
    stability_score_thresh=0.92,
    stability_score_offset=0.7,
    # NOTE(review): each extra crop layer re-runs the model on additional image
    # crops; 4 layers is likely the dominant runtime cost here -- confirm intended.
    crop_n_layers=4,
    crop_nms_thresh=0.8,
    crop_overlap_ratio=0.1,
    box_nms_thresh=0.9,
    crop_n_points_downscale_factor=1,
    # NOTE(review): an area threshold below one pixel looks unable to remove any
    # region; confirm whether 0 (disabled) or a real pixel count was intended.
    min_mask_region_area=0.1,
    use_m2m=True,
)

In [None]:
# Group the SAM mask raster into connected regions and export them.
# `min_size` presumably drops regions smaller than that many pixels -- confirm
# in the samgeo docs. `out_vector` / `out_image` are the output paths for the
# vector (GeoJSON) and raster results.
# NOTE(review): the outputs are written to the working directory, not to
# ./Temporary_Data/ where the input lives -- confirm that is intended.
masks_raster = "./Temporary_Data/masks4.tif"
region_outputs = sam2.region_groups(
    masks_raster,
    min_size=25,
    out_image="masks4_regions.tif",
    out_vector="masks4_regions.geojson",
)
array, gdf = region_outputs

# Display the resulting table in the notebook.
gdf

## Class Labelling

In [None]:
image_path = "norm_img.tif"

# Mask polygons to review one at a time in the next cell; show a preview.
gdf = gpd.read_file("./Temporary_Data/mask5.shp")
preview = gdf.head()
print(preview)

# The dataset is deliberately left open because the review loop reads from it.
# NOTE(review): nothing in view closes it -- call image_src.close() when done.
image_src = rasterio.open(image_path)

In [None]:
# Interactively step through every mask polygon: crop the image to it, show
# where it sits in the full image next to the cropped view, and wait for the
# user (Enter = next mask, "q" = stop).

# rasterio.mask.mask expects shapes in the dataset's CRS; reproject a local
# copy so cropping works and the red outline lines up with the raster.
if gdf.crs is not None and image_src.crs is not None and gdf.crs != image_src.crs:
    review_gdf = gdf.to_crs(image_src.crs)
else:
    review_gdf = gdf

for idx, row in review_gdf.iterrows():
    if row.geometry is None or row.geometry.is_empty:
        print(f"Mask {idx} has no geometry, skipping...")
        continue
    geom = [row.geometry]

    try:
        out_image, out_transform = mask(image_src, geom, crop=True, nodata=0, filled=True)
    except ValueError as err:
        # Raised e.g. when the polygon does not overlap the raster; skip it
        # instead of aborting the whole review session.
        print(f"Mask {idx} could not be cropped ({err}), skipping...")
        continue

    # (bands, rows, cols) -> (rows, cols, bands), the layout imshow expects.
    out_image = out_image.transpose(1, 2, 0)

    # Everything outside the polygon was filled with 0, so an all-zero crop
    # carries no image content.
    if not out_image.any():
        print(f"Mask {idx} is empty, skipping...")
        continue

    # After the transpose axis 0 is rows (height) and axis 1 is columns (width).
    height, width, n_bands = out_image.shape
    if width < 5 or height < 5:
        print(f"Skipping mask {idx}: too small ({width}x{height} pixels)")
        continue

    # imshow only accepts (H, W), (H, W, 3) or (H, W, 4): show band 1 alone
    # for 1-2 band crops and the first three bands for wider stacks.
    if n_bands < 3:
        crop_view = out_image[..., 0]
    elif n_bands > 4:
        crop_view = out_image[..., :3]
    else:
        crop_view = out_image

    # Full-image overview with current mask highlighted
    fig, axes = plt.subplots(1, 2, figsize=(14, 7))

    # Overall image; .loc because idx is an index label from iterrows().
    show(image_src, ax=axes[0])
    review_gdf.loc[[idx]].boundary.plot(ax=axes[0], edgecolor="red", linewidth=2)
    axes[0].set_title(f"Mask {idx} Location in Full Image")

    # Cropped mask image. vmin/vmax only affect single-band (colormapped)
    # crops; matplotlib ignores them for RGB(A) arrays.
    axes[1].imshow(crop_view, vmin=0, vmax=255)
    axes[1].set_title(f"Mask {idx} Cropped View")
    axes[1].axis("off")

    plt.tight_layout()
    plt.show()
    # Release the figure so long review sessions do not accumulate open figures.
    plt.close(fig)

    user_input = input("Press Enter to continue, or type q to quit: ")
    if user_input.strip().lower() == 'q':
        break
