<a href="https://colab.research.google.com/github/LucasOsco/AI-RemoteSensing/blob/main/image_notebook/sam_point_v01.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Part of this code utilizes the `segment-geospatial` package, which is a Python package designed for segmenting geospatial data using the Segment Anything Model (SAM). This package was created by Professor Dr. Qiusheng Wu.

The `segment-geospatial` package has been adapted from the `segment-anything-eo` repository, originally authored by Aliaksandr Hancharenka. The main goal of the `segment-geospatial` package is to simplify the process of applying SAM to geospatial data analysis, making it more accessible and requiring minimal coding effort.

The package can be downloaded and installed from PyPI or conda-forge.

For more information, details, and examples on how to use this package, you can visit Professor Wu's GitHub page at [https://github.com/opengeos/segment-geospatial](https://github.com/opengeos/segment-geospatial).

In [None]:
# Install the necessary libraries (the leading "!" makes the notebook run pip as a shell command)
!pip install pycrs segment-geospatial leafmap localtileserver

In [None]:
# Import required libraries
import os
import zipfile
import leafmap
import geopandas as gpd
import rasterio
import glob
from rasterio.merge import merge
from rasterio.features import shapes
from samgeo import SamGeo, tms_to_geotiff
from segment_anything import sam_model_registry

In [None]:
# Define a function to convert the GeoDataFrame to a list of [x, y] pairs
def convert_to_coord_pairs(gdf):
    """Extract the coordinates of every geometry as ``[x, y]`` pairs.

    Args:
        gdf (GeoDataFrame): Data whose ``geometry`` entries each expose
            ``x`` and ``y`` attributes (i.e. point geometries).

    Returns:
        list: One ``[x, y]`` list per geometry, in the order in which
        ``gdf.geometry`` yields them.
    """
    coord_pairs = []
    for geom in gdf.geometry:
        coord_pairs.append([geom.x, geom.y])
    return coord_pairs

In [None]:
# Define a function to add a file to the zip file
def add_to_zip(zipf, output_file):
    """Append one file from disk to an already-open zip archive.

    Args:
        zipf (zipfile.ZipFile): Archive opened in a writable mode
        output_file (str): Path of the file to store in the archive
    """
    source_path = output_file
    zipf.write(source_path)

In [None]:
# Path of the image to segment (only the path is stored here; the file is
# handed to SAM in a later cell)
image = 'Image_Tree.tif' # Switch to your image instead

# Load the prompt points from the shapefile into a GeoDataFrame
shapefile = 'ROI_Point.shp' # Switch to your shapefile instead
gdf = gpd.read_file(shapefile)

# Convert the GeoDataFrame to a list of [x, y] pairs
# (assumes every geometry is a point exposing .x / .y)
point_coords = convert_to_coord_pairs(gdf)

# Location of the SAM checkpoint: ~/Downloads/sam_vit_h_4b8939.pth
out_dir = os.path.join(os.path.expanduser("~"), "Downloads")
checkpoint = os.path.join(out_dir, "sam_vit_h_4b8939.pth")

In [None]:
# Create the SamGeo wrapper for the "vit_h" model using the checkpoint path
# defined above.
# NOTE(review): automatic=False presumably selects prompt-based prediction
# instead of automatic mask generation — confirm against the samgeo docs.
sam = SamGeo(
    model_type="vit_h",
    checkpoint=checkpoint,
    automatic=False,
    sam_kwargs=None,
)

# Specify the image to segment
sam.set_image(image)

In [None]:
# Prepare zip file to be saved
zip_name = 'masks.zip'

# The "with" block guarantees the archive is finalized and closed even if a
# prediction fails part-way through the loop.
with zipfile.ZipFile(zip_name, 'w', zipfile.ZIP_DEFLATED) as zipf:
    # Loop through each point in point_coords
    for i, point in enumerate(point_coords):
        # Predict and save to a unique file per point.
        # NOTE(review): the prompt points are declared as EPSG:4326 here —
        # confirm that the shapefile really uses that CRS.
        output_file = f'mask_{i}.tif'
        sam.predict([point], point_labels=1, point_crs="EPSG:4326", output=output_file)

        # Add file to zip
        add_to_zip(zipf, output_file)

In [None]:
def load_rasters(file_pattern):
    """Open every raster whose path matches a glob pattern.

    The matching paths are sorted before opening because ``glob.glob``
    returns them in arbitrary order; a fixed order keeps any order-dependent
    downstream step (such as merging overlapping rasters) reproducible.

    Args:
        file_pattern (str): File pattern to match (e.g., "mask_*.tif")

    Returns:
        list: A list of open rasterio dataset objects (empty when nothing
        matches). The caller is responsible for closing them.

    Raises:
        Exception: Whatever ``rasterio.open`` raises for an unreadable file;
            datasets opened before the failure are closed first.
    """
    datasets = []
    try:
        for fp in sorted(glob.glob(file_pattern)):
            datasets.append(rasterio.open(fp))
    except Exception:
        # Do not leak the handles that were opened before the failure.
        for ds in datasets:
            ds.close()
        raise
    return datasets

def merge_rasters(raster_datasets):
    """Combine several raster datasets into a single mosaic.

    Args:
        raster_datasets (list): Open rasterio dataset objects to merge

    Returns:
        tuple: Whatever ``rasterio.merge.merge`` returns for the datasets
        (the merged data together with its transform)
    """
    merged_result = merge(raster_datasets)
    return merged_result

def save_merged_raster(merged, transform, meta, output_file, crs=None):
    """Save the merged raster to disk as a GeoTIFF.

    Args:
        merged (numpy array): The merged raster data; ``shape[1]`` and
            ``shape[2]`` are written as the output height and width
        transform (Affine): The transformation associated with the merged raster
        meta (dict): Base metadata for the output raster. It is copied, so
            the caller's dictionary is left unmodified
        output_file (str): The path to the output file
        crs (optional): CRS to record in the output. When omitted, the CRS
            already present in ``meta`` is kept, because ``transform`` is
            expressed in that CRS; "EPSG:4326" is used only when ``meta``
            carries no CRS
    """
    # Work on a copy so the caller's metadata is not changed as a side effect.
    out_meta = dict(meta)

    if crs is None:
        # Keep the source CRS: relabelling data as EPSG:4326 while leaving
        # the transform untouched would misplace the raster.
        crs = out_meta.get("crs") or "EPSG:4326"

    out_meta.update({
        "driver": "GTiff",
        "height": merged.shape[1],
        "width": merged.shape[2],
        "transform": transform,
        "crs": crs,
    })

    with rasterio.open(output_file, "w", **out_meta) as dest:
        dest.write(merged)

# Get a list of your mask files
dem_fps = load_rasters("mask_*.tif")

# Fail early with a clear message instead of an obscure error further down
if not dem_fps:
    raise FileNotFoundError("No raster files match 'mask_*.tif'")

try:
    # Merge them into a single file
    mosaic, out_trans = merge_rasters(dem_fps)

    # Copy the metadata of the first mask as the template for the output
    out_meta = dem_fps[0].meta.copy()

    # Define output file
    output_file = 'mosaic.tif'

    # Save merged raster
    save_merged_raster(mosaic, out_trans, out_meta, output_file)
finally:
    # Close the files, even when merging or saving failed
    for src in dem_fps:
        src.close()

In [None]:
def mask_to_shapes(mask, transform, value=255):
    """Convert the regions of a mask that carry a given value into shapes.

    Args:
        mask (numpy array): The mask
        transform (Affine): The transformation associated with the mask
        value (int, optional): Raster value to keep. Defaults to 255, the
            value used for 'trees' in this notebook.

    Returns:
        generator: Yields one dictionary per kept shape, of the form
        ``{'properties': {'raster_val': v}, 'geometry': s}``. It can be
        consumed only once; wrap it in ``list()`` to reuse the result.
    """
    return (
        {'properties': {'raster_val': v}, 'geometry': s}
        for s, v in shapes(mask, transform=transform)
        if v == value  # only keep regions with the requested value
    )

def shapes_to_geodataframe(shapes, crs):
    """Build a GeoDataFrame from shape features and tag it with a CRS.

    Args:
        shapes (list): Feature dictionaries (each with 'properties' and
            'geometry' entries)
        crs (CRS or dict): Coordinate Reference System assigned to the result

    Returns:
        GeoDataFrame: The features as a GeoDataFrame whose ``crs`` is ``crs``
    """
    frame = gpd.GeoDataFrame.from_features(shapes)
    # Assign (not reproject): the coordinates are taken to be in ``crs`` already.
    frame.crs = crs
    return frame

def save_geodataframe(gdf, filename):
    """Write a GeoDataFrame to disk through its ``to_file`` method.

    Args:
        gdf (GeoDataFrame): The data to write
        filename (str): Destination path (e.g. a ``.shp`` file)
    """
    destination = filename
    gdf.to_file(destination)

# Convert the merged mosaic array to 16-bit integers before extracting shapes
mask = mosaic.astype('int16')

# Here, assuming out_trans is your transform and "EPSG:4326" is your crs
# NOTE(review): the CRS is hard-coded — confirm it matches the CRS of the masks
transform = out_trans
crs = "EPSG:4326"

# Convert the mask to a list of shapes (list() materializes the generator
# returned by mask_to_shapes)
geoms = list(mask_to_shapes(mask, transform))

# Convert the list of shapes to a GeoDataFrame
# NOTE: this rebinds "gdf", which previously held the prompt points
gdf = shapes_to_geodataframe(geoms, crs)

# Save the GeoDataFrame to a Shapefile
save_geodataframe(gdf, "mask.shp")

In [None]:
# Display the results
# NOTE: "mosaic" is rebound here from the merged array to the raster's path.
mosaic = 'mosaic.tif' # Switch to your directory instead
style={'color': '#a37aa9',}

# The map center must be a [lat, lon] pair, not a file path, so derive it
# from the extent of the prompt points (reprojected to lat/lon when the
# shapefile declares a CRS).
roi_points = gpd.read_file(shapefile)
if roi_points.crs is not None:
    roi_points = roi_points.to_crs("EPSG:4326")
min_x, min_y, max_x, max_y = roi_points.total_bounds
map_center = [(min_y + max_y) / 2, (min_x + max_x) / 2]

m = leafmap.Map(center=map_center)
m.add_raster(mosaic, layer_name="Mask Mosaic",cmap="viridis_r", opacity=0.5)
m.add_vector(shapefile, layer_name='Vector', style=style)
m