<a href="https://colab.research.google.com/github/Christobaltobbin/Segment_Anything_Model-SAM_Kaza_Region/blob/main/Sam_Point_Prompt.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## **SAM Code Using Point Prompts**


In [None]:
!pip install segment-geospatial

Collecting segment-geospatial
  Downloading segment_geospatial-0.12.3-py2.py3-none-any.whl.metadata (11 kB)
Collecting fiona (from segment-geospatial)
  Downloading fiona-1.10.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (56 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.6/56.6 kB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m
Collecting ipympl (from segment-geospatial)
  Downloading ipympl-0.9.4-py3-none-any.whl.metadata (8.7 kB)
Collecting leafmap (from segment-geospatial)
  Downloading leafmap-0.42.4-py2.py3-none-any.whl.metadata (16 kB)
Collecting localtileserver (from segment-geospatial)
  Downloading localtileserver-0.10.5-py3-none-any.whl.metadata (5.2 kB)
Collecting patool (from segment-geospatial)
  Downloading patool-3.1.0-py2.py3-none-any.whl.metadata (4.3 kB)
Collecting rasterio (from segment-geospatial)
  Downloading rasterio-1.4.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.1 kB)
Collecting rioxarray 

Mount or Connect to your Drive

In [None]:
from google.colab import drive

# Google Drive is exposed under this directory; every data path below lives inside it.
mount_point = '/content/drive'
drive.mount(mount_point)

Mounted at /content/drive


Import necessary Libraries

In [None]:
import os
import leafmap
import rasterio
import numpy as np
import geopandas as gpd
import pandas as pd
import matplotlib.pyplot as plt

from samgeo import SamGeo, show_image, download_file, overlay_images, tms_to_geotiff
from google.colab import output
output.enable_custom_widget_manager()

Load Tiff file

In [None]:
# Source raster on the mounted Google Drive; the following cells open this path with rasterio.
geotiff_path = "/content/drive/MyDrive/Kaza2/ROI_1_WET.tif"

In [None]:
# Load the source raster and report its basic properties.
with rasterio.open(geotiff_path) as dataset:
    # All bands, as returned by rasterio's read()
    data = dataset.read()

    # Collect the report while the dataset is still open, print it in one go afterwards
    report_lines = (
        f"Width: {dataset.width}",
        f"Height: {dataset.height}",
        f"Number of bands: {dataset.count}",
        f"Coordinate Reference System (CRS): {dataset.crs}",
        f"Transform: {dataset.transform}",
    )

print("\n".join(report_lines))

Width: 1976
Height: 1873
Number of bands: 4
Coordinate Reference System (CRS): EPSG:32734
Transform: | 10.00, 0.00, 1196980.00|
| 0.00,-10.00, 7997290.00|
| 0.00, 0.00, 1.00|


Let's reorder the bands for Sentinel-2. Remember that a 4-band Sentinel-2 export stores the bands as B, G, R, NIR, so they have to be rearranged to get R, G, B.

In [None]:
# Write a copy of the source raster whose bands follow `new_order`.
geotiff_path_reorder = "/content/drive/MyDrive/Kaza2/ROI_1_WET_reordered.tif"

# 1-based source band indices, listed in the order they should appear in the output
new_order = [3, 2, 1]

with rasterio.open(geotiff_path) as src:
    meta = src.meta.copy()
    num_bands = src.count

    # Fall back to the original band order when the raster is too small to reorder
    if num_bands < max(new_order):
        print(f"Warning: Dataset only has {num_bands} bands. Reordering skipped.")
        new_order = list(range(1, num_bands + 1))

    # The output holds exactly the bands that were selected
    meta.update(count=len(new_order))

    # One 2-D array per selected band, in output order
    bands = [src.read(band_index) for band_index in new_order]

    with rasterio.open(geotiff_path_reorder, "w", **meta) as dst:
        for out_index, band_data in zip(range(1, len(bands) + 1), bands):
            dst.write(band_data, out_index)

Let's check if the re-ordering worked

In [None]:
# Compare the reordered file against the source, band by band.
geotiff_path = "/content/drive/MyDrive/Kaza2/ROI_1_WET.tif"
geotiff_path_reorder = "/content/drive/MyDrive/Kaza2/ROI_1_WET_reordered.tif"

with rasterio.open(geotiff_path) as src_original:
    data_original = src_original.read()  # (bands, rows, cols)
with rasterio.open(geotiff_path_reorder) as src_reordered:
    data_reordered = src_reordered.read()

# 1-based source band expected at each output position
new_order = [3, 2, 1]

# Output band k must equal source band new_order[k]; stops at the first mismatch
reordered_correctly = all(
    np.array_equal(data_original[source_band - 1], data_reordered[out_position])
    for out_position, source_band in enumerate(new_order)
)

message = (
    "Bands were reordered successfully!"
    if reordered_correctly
    else "Bands were not reordered as expected."
)
print(message)

Bands were reordered successfully!


## Let us normalize the input image to the range [0, 1]


In [None]:
# Min-max normalise every band of the reordered raster to [0, 1] and save the result
# twice: once with all normalised bands and once restricted to the first three (RGB).
output_geotiff_path = "/content/drive/MyDrive/Kaza2/ROI_1_WET_Normalized.tif"
output_geotiff_path_rgb = "/content/drive/MyDrive/Kaza2/ROI_1_WET_Normalized_RGB.tif"


def _min_max_normalize(band):
    """Scale a 2-D band to the range [0, 1] with min-max scaling.

    Non-finite pixels (NaN/inf) are ignored when computing the minimum and
    maximum and are set to 0 in the result. A band without any value range
    (constant, or entirely non-finite) is returned as all zeros instead of
    dividing by zero.
    """
    band = np.asarray(band, dtype=float)
    finite = np.isfinite(band)
    if not finite.any():
        return np.zeros_like(band)

    band_min = band[finite].min()
    value_range = band[finite].max() - band_min
    if value_range == 0:
        return np.zeros_like(band)

    # Non-finite pixels are replaced by the minimum first, so they end up as 0
    return (np.where(finite, band, band_min) - band_min) / value_range


def _write_float32_geotiff(path, bands, base_meta):
    """Write a (n_bands, rows, cols) array to `path` as a float32 GeoTIFF.

    The band count is taken from the array itself, so the file never declares
    more bands than are actually written.
    """
    out_meta = base_meta.copy()
    out_meta.update(dtype=rasterio.float32, count=len(bands))
    with rasterio.open(path, "w", **out_meta) as dst:
        dst.write(bands.astype(rasterio.float32))


with rasterio.open(geotiff_path_reorder) as dataset:
    # Take the metadata from the file that is actually normalised; it was written
    # with the source raster's metadata, so the georeferencing is unchanged.
    meta = dataset.meta.copy()

    # Bands are 1-indexed in rasterio
    normalized_bands = [
        _min_max_normalize(dataset.read(band_index))
        for band_index in range(1, dataset.count + 1)
    ]

# (n_bands, rows, cols)
normalized_data = np.stack(normalized_bands, axis=0)

_write_float32_geotiff(output_geotiff_path, normalized_data, meta)
print(f"Normalized GeoTIFF created at {output_geotiff_path}")

# RGB-only copy: just the first three normalised bands
_write_float32_geotiff(output_geotiff_path_rgb, normalized_data[:3], meta)
print(f"Normalized GeoTIFF created at {output_geotiff_path_rgb}")


Normalized GeoTIFF created at /content/drive/MyDrive/Kaza2/ROI_1_WET_Normalized.tif
Normalized GeoTIFF created at /content/drive/MyDrive/Kaza2/ROI_1_WET_Normalized_RGB.tif


## **Create a rgb_stack that can be used for the model**

- here we should use the normalized data

- **IMPORTANT**: rasterio stores the data as (n_bands, height, width), but the model requires the input to be in (height, width, n_bands)

- this is why we need to re-arrange the bands

In [None]:
# rgb_image = np.dstack((normalized_data[0,:,:],
#                           normalized_data[1,:,:],
#                           normalized_data[2,:,:]))

In [None]:
# rgb_image.shape

(1873, 1976, 3)

## **Read the GeoTIFF and Prepare the RGB NumPy Array**
  Read the tiff file and also preserve the georeferencing data

In [None]:
# Read the normalised RGB raster, keep its georeferencing, and build the
# (rows, cols, 3) uint8 image that is handed to SAM.
with rasterio.open(output_geotiff_path_rgb) as dataset:
    crs = dataset.crs  # Coordinate Reference System, reused for the points and outputs
    transform = dataset.transform  # Affine geotransform, reused for map -> pixel conversion
    red_band = dataset.read(1)  # Band 1
    green_band = dataset.read(2)  # Band 2
    blue_band = dataset.read(3)  # Band 3

# rasterio returns (rows, cols) per band; stack along a new last axis -> (rows, cols, 3)
rgb_float = np.dstack((red_band, green_band, blue_band))

# SAM's predictor expects an HWC uint8 image with values in [0, 255]. The file holds
# floats normalised to [0, 1], so rescale; NaNs become 0 and anything outside the
# range is clipped before the cast.
rgb_image = (
    np.clip(np.nan_to_num(rgb_float, nan=0.0) * 255.0, 0.0, 255.0)
    .round()
    .astype(np.uint8)
)

# Print the shape of the RGB image for confirmation
print(f"RGB Image Shape: {rgb_image.shape}")

RGB Image Shape: (1873, 1976, 3)


## Import Points

In [None]:
# Prompt points: load the shapefile and bring it into the raster's CRS in one step,
# so point coordinates and the raster transform refer to the same coordinate system.
point_shp = '/content/drive/MyDrive/Kaza2/SHP_ROI_1.shp'
gdf_points = gpd.read_file(point_shp).to_crs(crs)

# coord_list = gdf_points.get_coordinates()
# coord_list

In [None]:
# Convert every point from map coordinates to (row, col) indices of the raster grid
from rasterio.transform import rowcol

point_pixels = [
    tuple(rowcol(transform, point.x, point.y))
    for point in gdf_points.geometry
]

In [None]:
# Extract coordinates from the points
# coord_list = [(geom.x, geom.y) for geom in gdf_points.geometry]

In [None]:
# image_normalized = rasterio.open("/content/drive/MyDrive/Kaza2/ROI_1_WET_Normalized_RGB.tif")
# image_normalized.count

3

## **Load the SAM model: Point Prompt**

- here we load the original SAM with the `vit_h` (ViT-Huge) backbone; smaller backbones (`vit_b`, `vit_l`) and more advanced alternatives should also be explored

In [None]:
# Build the SamGeo wrapper with the ViT-H backbone. `automatic=False` because the
# masks are requested from point prompts further down rather than generated automatically.
sam_options = {
    "model_type": "vit_h",
    "automatic": False,
    "sam_kwargs": None,
}
sam = SamGeo(**sam_options)

In [None]:
# Hand the in-memory RGB array to the SAM wrapper; the predict calls below operate on it.
# NOTE(review): the band-stacking cell mentions scaling to 0-255 for SAM — confirm that
# `rgb_image` has the dtype and value range the predictor expects before running this.
sam.set_image(rgb_image)

## **Use Pixel Coordinates for Segmentation**

In [None]:
# Run one point-prompted prediction per pixel and merge the results into a single
# label raster: 0 = background, k = object found for the k-th prompt point.
img_height, img_width = rgb_image.shape[:2]

# Smallest unsigned integer type that can hold one label per point
label_dtype = np.min_scalar_type(max(len(point_pixels), 1))
segmentation_mask = np.zeros((img_height, img_width), dtype=label_dtype)

for label, pixel in enumerate(point_pixels, start=1):
    row, col = pixel  # Row and column indices

    # A point outside the raster cannot be used as a prompt
    if not (0 <= row < img_height and 0 <= col < img_width):
        print(f"Skipping pixel {pixel}: outside the image extent")
        continue

    # `predict` only returns (masks, scores, logits) when return_results=True;
    # without it the call returns None. SAM expects prompts as (x, y) = (col, row).
    masks, scores, _ = sam.predict(
        point_coords=[[col, row]],
        point_labels=[1],  # 1 = foreground point
        return_results=True,
    )

    # Several candidate masks are returned per prompt; keep the best-scoring one
    best_index = int(np.argmax(scores))
    best_mask = np.asarray(masks[best_index]).astype(bool)

    # Do not overwrite pixels already claimed by an earlier point
    segmentation_mask[best_mask & (segmentation_mask == 0)] = label

    print(
        f"Segmentation mask for pixel {pixel}: "
        f"{int(best_mask.sum())} px (score {float(scores[best_index]):.3f})"
    )

Segmentation mask for pixel (1121, 1459): None
Segmentation mask for pixel (1132, 1479): None
Segmentation mask for pixel (586, 1141): None
Segmentation mask for pixel (539, 1140): None
Segmentation mask for pixel (1080, 1402): None
Segmentation mask for pixel (967, 1348): None
Segmentation mask for pixel (638, 1060): None
Segmentation mask for pixel (498, 1152): None
Segmentation mask for pixel (439, 1101): None
Segmentation mask for pixel (522, 1113): None
Segmentation mask for pixel (999, 1755): None
Segmentation mask for pixel (1007, 1798): None
Segmentation mask for pixel (995, 1740): None
Segmentation mask for pixel (1032, 1454): None
Segmentation mask for pixel (1000, 1347): None
Segmentation mask for pixel (1051, 1263): None
Segmentation mask for pixel (905, 266): None
Segmentation mask for pixel (872, 301): None
Segmentation mask for pixel (1158, 1276): None
Segmentation mask for pixel (1057, 1514): None
Segmentation mask for pixel (879, 503): None


In [None]:
# Save the segmentation result as a single-band GeoTIFF that carries the same CRS
# and geotransform as the image that was segmented.
output_segmentation_path = "/content/drive/MyDrive/Kaza2/Segmentation_Result.tif"

if segmentation_mask is None:
    # Fail with an explicit message instead of an AttributeError on `.shape`
    raise ValueError(
        "segmentation_mask is None - the prediction step did not return a mask, "
        "so there is nothing to save."
    )

mask_to_save = np.asarray(segmentation_mask)
if mask_to_save.ndim != 2:
    raise ValueError(
        f"Expected a 2-D mask (rows, cols), got an array of shape {mask_to_save.shape}."
    )
if mask_to_save.dtype == np.bool_:
    # GeoTIFF has no boolean pixel type; store the mask as 0/1 bytes
    mask_to_save = mask_to_save.astype(np.uint8)

with rasterio.open(
    output_segmentation_path,
    "w",
    driver="GTiff",
    height=mask_to_save.shape[0],
    width=mask_to_save.shape[1],
    count=1,  # Single band for the mask
    dtype=mask_to_save.dtype,
    crs=crs,
    transform=transform,
    compress="lzw",  # Masks are large uniform regions and compress well
) as dst:
    dst.write(mask_to_save, 1)

print(f"Saved segmentation result to {output_segmentation_path}")

In [None]:
# # Set the image for segmentation
# sam.set_image(output_geotiff_path_rgb) #ChatGPT

error: OpenCV(4.10.0) /io/opencv/modules/imgproc/src/color.cpp:196: error: (-215:Assertion failed) !_src.empty() in function 'cvtColor'


In [None]:
# for point in gdf_points.geometry:
#     segmentation_mask = sam.predict(image=ROI_1_WET_Normalized_RGB.tif, point_coords=[(point.x, point.y)])

#     # Here, segmentation_mask is the predicted mask for the input point
#     # Do something with the mask, like displaying or saving it
#     print("Segmentation mask for point", (point.x, point.y), ":", segmentation_mask)

NameError: name 'ROI_1_WET_Normalized_RGB' is not defined

In [None]:
# # Perform segmentation for each point, ChatGPT
# segmentation_masks = []
# for point in coord_list:
#     segmentation_mask = sam.predict(point_coords=[point])
#     segmentation_masks.append(segmentation_mask)
#     print(f"Segmentation mask for point {point}: {segmentation_mask.shape}")

AttributeError: 'NoneType' object has no attribute 'shape'

In [None]:
# Echo the current value of `segmentation_mask`; if it is None, Jupyter shows no output.
segmentation_mask

In [None]:
# Visualize the segmentation result on top of the RGB image (optional).
# Only names created by the executed cells above are used here: `segmentation_mask`,
# `rgb_image` and `point_pixels`.
if segmentation_mask is None:
    print("No segmentation mask available to display.")
else:
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(rgb_image)

    # Hide the background (0) so that only segmented pixels are drawn over the image
    overlay = np.ma.masked_equal(np.asarray(segmentation_mask), 0)
    ax.imshow(overlay, cmap="tab20", alpha=0.6, interpolation="nearest")

    if point_pixels:
        # point_pixels holds (row, col); matplotlib plots (x, y) = (col, row)
        prompt_rows, prompt_cols = zip(*point_pixels)
        ax.scatter(prompt_cols, prompt_rows, c="red", s=15, marker="x", label="prompt points")
        ax.legend(loc="upper right")

    ax.set_title(f"Segmentation masks for {len(point_pixels)} prompt points")
    ax.axis("off")
    plt.show()

In [None]:
# Save one binary GeoTIFF per segmented object (optional).
# Every distinct non-zero value in `segmentation_mask` is treated as one object.
# The files carry the CRS and geotransform read from the segmented image, so they
# line up with the input raster.
if segmentation_mask is None:
    print("No segmentation mask available - nothing to save.")
else:
    label_raster = np.asarray(segmentation_mask)
    for label in np.unique(label_raster):
        if label == 0:
            continue  # 0 is background

        # GeoTIFF has no boolean pixel type; store the mask as 0/1 bytes
        mask = (label_raster == label).astype(np.uint8)

        output_mask_path = f"/content/drive/MyDrive/Kaza2/Segmentation_Mask_{int(label)}.tif"
        with rasterio.open(
            output_mask_path,
            "w",
            driver="GTiff",
            height=mask.shape[0],
            width=mask.shape[1],
            count=1,
            dtype=mask.dtype,
            crs=crs,
            transform=transform,
            compress="lzw",
        ) as dst:
            dst.write(mask, 1)
        print(f"Saved segmentation mask to {output_mask_path}")