# Digitize 3 Training Classes

In [None]:
from ipyleaflet import Map, basemaps, basemap_to_tiles, DrawControl, TileLayer, GeoData, LayersControl
import pandas as pd
import geopandas as gpd

zxyURL = 'https://storage.googleapis.com/cloud-geo-efm-public/s2-composite-tiles/{z}/{x}/{y}'
center = (40.09351228982099, -74.07673459283767)
s2Layer = TileLayer(url=zxyURL, opacity=1, name="S2 CS+ 2022", max_zoom=16, max_native_zoom=14)

m = Map(center=center, zoom=15)
m.add(s2Layer)

# Only the circle-marker tool is enabled; passing {} for a tool disables it.
# Leaflet.draw's plain markers are icons and ignore `shapeOptions`, whereas
# circle markers are styled by it, so Class 1 points are drawn in red - the
# colour used for Class 1 in the class map and prediction plots further down.
# Circle markers are still captured as GeoJSON Points.
draw_control = DrawControl(
    marker={},
    circlemarker={"shapeOptions": {"color": "#FF0000", "fillColor": "#FF0000", "fillOpacity": 0.8}},
    polyline={},
    polygon={},
    rectangle={},
)

m.add(draw_control)
print("Digitize Class 1")
m

In [None]:
# Run this cell once you have finished digitizing Class 1. It starts a fresh
# list of training records, one per point drawn on the map above.
classes = [
    {
        "class": "class1",
        "lon": feature['geometry']['coordinates'][0],
        "lat": feature['geometry']['coordinates'][1],
    }
    for feature in draw_control.data
]


In [None]:
# Fresh map for Class 2 on the same S2 tile layer.
m = Map(center=center, zoom=15)
m.add(s2Layer)

# Only the circle-marker tool is enabled; passing {} for a tool disables it.
# Leaflet.draw's plain markers are icons and ignore `shapeOptions`, whereas
# circle markers are styled by it, so Class 2 points are drawn in blue - the
# colour used for Class 2 in the class map and prediction plots further down.
draw_control = DrawControl(
    marker={},
    circlemarker={"shapeOptions": {"color": "#0000FF", "fillColor": "#0000FF", "fillOpacity": 0.8}},
    polyline={},
    polygon={},
    rectangle={},
)

m.add(draw_control)
print("Digitize Class 2")
m

In [None]:
# Once you're done digitizing Class 2, run this cell to capture the points.
# Any class2 records from an earlier run of this cell are dropped first, so
# re-running it replaces the Class 2 points instead of duplicating them.
classes = [record for record in classes if record["class"] != "class2"]
classes.extend(
    {
        "class": "class2",
        "lon": feature['geometry']['coordinates'][0],
        "lat": feature['geometry']['coordinates'][1],
    }
    for feature in draw_control.data
)


In [None]:
# Fresh map for Class 3 on the same S2 tile layer.
m = Map(center=center, zoom=15)
m.add(s2Layer)

# Only the circle-marker tool is enabled; passing {} for a tool disables it.
# Leaflet.draw's plain markers are icons and ignore `shapeOptions`, whereas
# circle markers are styled by it, so Class 3 points are actually drawn in
# green - the colour used for Class 3 in the class map and prediction plots.
draw_control = DrawControl(
    marker={},
    circlemarker={"shapeOptions": {"color": "#00FF00", "fillColor": "#00FF00", "fillOpacity": 0.8}},
    polyline={},
    polygon={},
    rectangle={},
)

m.add(draw_control)
print("Digitize Class 3")
m

In [None]:
# Once you're done digitizing Class 3, run this cell to capture the points.
# Any class3 records from an earlier run of this cell are dropped first, so
# re-running it replaces the Class 3 points instead of duplicating them.
classes = [record for record in classes if record["class"] != "class3"]
classes.extend(
    {
        "class": "class3",
        "lon": feature['geometry']['coordinates'][0],
        "lat": feature['geometry']['coordinates'][1],
    }
    for feature in draw_control.data
)


In [None]:

# Build a GeoDataFrame of the captured training points (lon/lat, WGS 84).
# Fail loudly if nothing was captured: an empty list gives a DataFrame with no
# 'lon'/'lat' columns, which would otherwise surface as a confusing KeyError.
if not classes:
    raise ValueError(
        "No training points were captured - digitize points and run the capture cells first."
    )

df = pd.DataFrame(classes)

gdf = gpd.GeoDataFrame(
    df,
    geometry=gpd.points_from_xy(df['lon'], df['lat']),
    crs="EPSG:4326",
)

gdf

In [None]:
#@title Show all 3 Classes

def make_class_layer(class_gdf, color, name):
    """Build an ipyleaflet GeoData layer drawing one class's points in `color`.

    Args:
        class_gdf: GeoDataFrame holding the points of a single class.
        color: Colour name used for the point fill/outline and hover fill.
        name: Layer name shown in the LayersControl.

    Returns:
        The configured GeoData layer (not yet added to a map).
    """
    return GeoData(
        geo_dataframe=class_gdf,
        style={'color': 'black', 'radius': 8, 'fillColor': color, 'opacity': 0.5,
               'weight': 1.9, 'dashArray': '2', 'fillOpacity': 0.6},
        hover_style={'fillColor': color, 'fillOpacity': 0.2},
        point_style={'radius': 5, 'color': color, 'fillOpacity': 0.8,
                     'fillColor': color, 'weight': 3},
        name=name,
    )

# Split the GeoDataFrame by class
gdf_class1 = gdf[gdf['class'] == 'class1']
gdf_class2 = gdf[gdf['class'] == 'class2']
gdf_class3 = gdf[gdf['class'] == 'class3']

# Center the map on the mean location of all training points
m = Map(center=(gdf.geometry.y.mean(), gdf.geometry.x.mean()), zoom=13)
m.add(s2Layer)

# Same colours as the prediction plots: class1 red, class2 blue, class3 green
layer1 = make_class_layer(gdf_class1, 'red', 'Class 1')
layer2 = make_class_layer(gdf_class2, 'blue', 'Class 2')
layer3 = make_class_layer(gdf_class3, 'green', 'Class 3')

# Map.add accepts both layers and controls, as used for the DrawControl maps
m.add(layer1)
m.add(layer2)
m.add(layer3)

control = LayersControl(position='topright')
m.add(control)

m

In [None]:
#@title Optionally, save as a GeoJSON file
# Persist the training points so they can be reloaded without re-digitizing.
gdf.to_file(filename="classes.geojson", driver='GeoJSON')

# Load an Earth Engine-exported Zarr file and classify it with KNN using scikit-learn

In [None]:
#@title Install dependencies and authenticate to Cloud Storage
!pip install zarr rioxarray

import google.auth
from google.colab import auth
auth.authenticate_user()

import gcsfs
import xarray as xr
import zarr

# read the dataset from Zarr
ds = xr.open_zarr("gs://imax-conus/data-10m/")

In [None]:
#@title Buffer the GeoDataFrame's BBOX 500 meters and crop the Zarr to that BBOX

import geopandas as gpd
from shapely.geometry import Polygon
def buffer_bounding_box(gdf, buffer_distance_meters):
    """
    Calculates and buffers the bounding box of a GeoDataFrame in EPSG:4326.

    The rectangle spanned by ``gdf.total_bounds`` (in EPSG:4326) is projected
    to the UTM CRS that geopandas estimates for ``gdf``, so the buffer distance
    is applied in meters. The result is then projected back to EPSG:4326.

    Args:
        gdf: The GeoDataFrame; reprojected to EPSG:4326 first if it has another CRS.
        buffer_distance_meters: The buffer distance in meters.

    Returns:
        A one-element GeoSeries (CRS EPSG:4326 - not a GeoDataFrame) holding the
        buffered bounding-box polygon. Buffering rounds the corners, so use its
        ``total_bounds`` when a plain rectangle is needed.
    """
    # Work in lon/lat so the bounding box below is expressed in degrees.
    if gdf.crs != 'EPSG:4326':
        gdf = gdf.to_crs('EPSG:4326')

    xmin, ymin, xmax, ymax = gdf.total_bounds
    corners = [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin), (xmin, ymin)]
    bbox_lonlat = gpd.GeoSeries([Polygon(corners)], crs='EPSG:4326')

    # Buffer in a metric (UTM) CRS so `buffer_distance_meters` really is meters.
    bbox_utm = bbox_lonlat.to_crs(gdf.estimate_utm_crs())
    buffered_utm = bbox_utm.buffer(buffer_distance_meters)

    return buffered_utm.to_crs('EPSG:4326')

# Buffer distance (meters) added around the training points' bounding box
BUFFER_METERS = 500

buffered_bbox_gdf = buffer_bounding_box(gdf, BUFFER_METERS)

# (xmin, ymin, xmax, ymax) of the buffered area, in EPSG:4326 degrees
bbox = buffered_bbox_gdf.total_bounds
xmin, ymin, xmax, ymax = bbox

# A label slice passed to .sel only selects data when it runs in the same
# direction as the coordinate. Flip the bounds for descending axes (rasters
# are often stored north-up, i.e. with latitude decreasing) so the crop is
# never silently empty.
lat_descending = bool(ds['lat'][0] > ds['lat'][-1])
lon_descending = bool(ds['lon'][0] > ds['lon'][-1])

cropped_ds = ds.sel(
    lat=slice(ymax, ymin) if lat_descending else slice(ymin, ymax),
    lon=slice(xmax, xmin) if lon_descending else slice(xmin, xmax),
)
cropped_ds

Convert a copy of the cropped xarray Dataset to a GeoDataFrame so we can spatially join it with the training-point GeoDataFrame.

In [None]:
import geopandas as gpd
from shapely.geometry import Point
import pandas as pd
import xarray as xr

# Convert the xarray Dataset to a GeoDataFrame

def ds_to_gdf(ds):
    """Convert an xarray Dataset to a point GeoDataFrame in EPSG:4326.

    Each row of ``ds.to_dataframe()`` (one per coordinate combination) becomes
    a Point at its (lon, lat) location. The data variables become columns.

    Args:
        ds: Dataset with ``lat`` and ``lon`` coordinates, assumed to be WGS 84 degrees.

    Returns:
        geopandas.GeoDataFrame with a ``geometry`` column of Points and CRS EPSG:4326.
    """
    df = ds.to_dataframe().reset_index()

    # Vectorised point construction (as used for the training points); a
    # row-wise apply makes one Python call per pixel and is very slow on rasters.
    geometry = gpd.points_from_xy(df['lon'], df['lat'])

    # Setting the CRS in the constructor avoids mutating `.crs` afterwards.
    return gpd.GeoDataFrame(df, geometry=geometry, crs='EPSG:4326')

# Turn every pixel of the cropped dataset into a point feature
cropped_gdf = ds_to_gdf(cropped_ds)

# Attach to each training point the attributes of its nearest pixel point.
# NOTE(review): sjoin_nearest returns every equidistant match, so if the
# dataset has more than one entry along a non-spatial dimension (e.g. time),
# a training point may appear once per entry - confirm that dimension's size.
joined_gdf = gpd.sjoin_nearest(left_df=gdf, right_df=cropped_gdf, how="left")

joined_gdf


Train a k-nearest-neighbors (KNN) classifier on the training points and their joined embedding-field values, then check its accuracy on a held-out test split.

In [None]:
# Names of the 64 embedding bands, in plain string-sorted order
# (embedding_B0, embedding_B1, embedding_B10, ..., embedding_B19, embedding_B2, ...).
# This is the column order of the features the KNN model is trained on.
variable_names = sorted(f'embedding_B{band}' for band in range(64))

In [None]:
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score

# Features are the embedding bands; the target is the digitized class label
X = joined_gdf[variable_names]
y = joined_gdf['class']

# Hold out 20% of the points for testing (fixed seed for reproducibility)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# KNN cannot use more neighbors than it has training samples (scikit-learn
# raises at predict time). With only a handful of digitized points, fall back
# to using every training point as a neighbor.
N_NEIGHBORS = 5
knn = KNeighborsClassifier(n_neighbors=min(N_NEIGHBORS, len(X_train)))

# Train the model
knn.fit(X_train, y_train)

# Make predictions on the test set
y_pred = knn.predict(X_test)

# Evaluate the model
accuracy = accuracy_score(y_test, y_pred)
print("Accuracy:", accuracy)

Now use that model to predict the class of every pixel point and plot the classified map.

In [None]:
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.patches as mpatches

# Predict a class for every pixel point of the cropped dataset
X_cropped = cropped_gdf[variable_names]
cropped_gdf['predicted_class'] = knn.predict(X_cropped)

# Numeric code for each class label
cropped_gdf['predicted_class_numeric'] = cropped_gdf['predicted_class'].map({'class1': 1, 'class2': 2, 'class3': 3})

# Colour for each numeric class code
color_mapping = {1: 'red', 2: 'blue', 3: 'green'}

# Legend entries built from the same colour mapping
legend_patches = [
    mpatches.Patch(color=color_mapping[class_num], label=f'Class {class_num}')
    for class_num in color_mapping
]

# Colour each point explicitly. geopandas ignores `column`/`categorical`/`legend`
# (with a warning) whenever `color` is also given, so only `color` is passed
# and the legend is added by hand.
fig, ax = plt.subplots(figsize=(10, 8))
cropped_gdf.plot(ax=ax, color=[color_mapping[x] for x in cropped_gdf['predicted_class_numeric']])

ax.legend(handles=legend_patches)
ax.set_title('Predicted Classes (point features)')
ax.set_xlabel('Longitude')
ax.set_ylabel('Latitude')

plt.show()


Because that prediction was made on a GeoDataFrame, the result is a point-feature representation of the classified map.

However, the same KNN model can also make predictions directly on the xarray Dataset, producing a gridded result.

In [None]:
# Convert the Dataset to a single DataArray with a leading 'variable' dimension.
# Select the bands explicitly in `variable_names` order. The KNN model was
# trained on columns in that order. A bare to_array() would follow the
# Dataset's own variable order and include any non-embedding variables.
da = cropped_ds[variable_names].to_array()
da

In [None]:
# Flatten the lat/lon grid into a single 'point' dimension (lat outer, lon
# inner). Name the target order explicitly - points first, bands last - so
# each row is one pixel's feature vector, instead of relying on a full
# reversal of whatever dimension order the array happens to have.
da = da.stack(point=['lat', 'lon']).transpose('point', ..., 'variable')
da

In [None]:
# Keep only the first entry along the second dimension. After the stack and
# transpose above this is the non-spatial axis (presumably time - confirm it
# holds a single composite), leaving a (point, variable) feature matrix.
first_step = {da.dims[1]: 0}
da = da.isel(first_step)
da

In [None]:
# Run the trained KNN model on every pixel's feature vector. Then swap the
# string labels for numeric codes 1-3 so they can be plotted and rasterised.
class_codes = {'class1': 1, 'class2': 2, 'class3': 3}

predicted_classes = knn.predict(da)
for label, code in class_codes.items():
    predicted_classes[predicted_classes == label] = code

predicted_classes = predicted_classes.astype(float)

In [None]:
#@title Plot an RGB Composite from 3 bands of the Xarray Dataset next to the predicted classes
import numpy as np

# Functions to brighten and correct gamma for the Xarray Dataset RGB composite
def brighten(band):
    """Scale a band linearly (alpha * band + beta) and clip the result to [0, 255]."""
    alpha = 0.13
    beta = 0
    return np.clip(alpha * band + beta, 0, 255)

def gammacorr(band):
    """Gamma-correct a band as band ** (1 / gamma); with gamma = 1 this is a no-op."""
    gamma = 1
    return np.power(band, 1 / gamma)

# Embedding bands shown as the R, G and B channels of the composite.
# Transposing by name moves (lat, lon) to the last two axes regardless of the
# Zarr store's dimension order, so image rows are latitudes and columns longitudes.
red_band = cropped_ds["embedding_B0"].transpose(..., 'lat', 'lon')
green_band = cropped_ds["embedding_B14"].transpose(..., 'lat', 'lon')
blue_band = cropped_ds["embedding_B62"].transpose(..., 'lat', 'lon')

# (lat, lon) pair of every stacked point, in the same order as predicted_classes
coords = np.array([*da.point.values])

# Grid size of the original Dataset
reshape_lat = cropped_ds.sizes['lat']
reshape_lon = cropped_ds.sizes['lon']

# Colour for each numeric class code, as a discrete colormap
color_mapping = {1: 'red', 2: 'blue', 3: 'green'}
cmap = mcolors.ListedColormap(list(color_mapping.values()))

# Bin edges that pin class code k to the k-th colour. Without an explicit norm,
# pcolor scales the colours to the data's min/max. If some class is never
# predicted, the remaining codes would then get the wrong colours and would
# no longer match the legend.
norm = mcolors.BoundaryNorm([0.5, 1.5, 2.5, 3.5], cmap.N)

# Stack the three bands into a trailing RGB axis
rgb = np.stack([red_band, green_band, blue_band], axis=-1)

# Brighten every channel, then gamma-correct every channel
for channel in range(3):
    rgb[..., channel] = brighten(rgb[..., channel])
for channel in range(3):
    rgb[..., channel] = gammacorr(rgb[..., channel])

rgb = rgb / rgb.max()  # Normalize after adjustments

# Drop singleton dimensions (e.g. a single time step), leaving (lat, lon, 3)
rgb = rgb.squeeze()

# Image extent in degrees. Row 0 of `rgb` is the first latitude value. Draw it
# at the bottom when latitude increases along the array and at the top when it
# decreases; either way north is up and the tick labels match the data.
lat_values = cropped_ds['lat'].values
lon_values = cropped_ds['lon'].values
extent = [lon_values.min(), lon_values.max(), lat_values.min(), lat_values.max()]
origin = 'lower' if lat_values[0] < lat_values[-1] else 'upper'

# Side-by-side plots: RGB composite and predicted classes
fig, ax = plt.subplots(1, 2, figsize=(16, 9))

# Shrink the height so each of the two panels keeps a 16:9 aspect ratio
desired_height = fig.get_figwidth() / 2 / (16 / 9)
fig.set_figheight(desired_height)

ax[0].imshow(rgb, extent=extent, origin=origin)
ax[0].set_title('Zarr RGB Composite')
ax[0].set_xlabel('Longitude')
ax[0].set_ylabel('Latitude')

ax[1].set_title('Predicted Classes')
ax[1].set_xlabel('Longitude')
ax[1].set_ylabel('Latitude')

# Legend for the predicted-classes panel
legend_patches = [
    mpatches.Patch(color=color_mapping[class_num], label=f'Class {class_num}')
    for class_num in color_mapping
]
ax[1].legend(handles=legend_patches)

# Reshape the stacked points back onto the lat x lon grid (lat outer, lon inner)
grid_shape = [reshape_lat, reshape_lon]
ax[1].pcolor(
    coords[:, 1].reshape(grid_shape),
    coords[:, 0].reshape(grid_shape),
    predicted_classes.reshape(grid_shape),
    cmap=cmap,
    norm=norm,
)

plt.show()

In [None]:
# To build a GeoTIFF, attach the predicted class codes to the stacked array as
# a 'predicted_class' coordinate along 'point'. Then unstack 'point' back into
# the original lat/lon grid.
da = da.assign_coords(predicted_class=('point', predicted_classes)).unstack('point')
da

In [None]:
# Create a GeoTIFF Output of the predictions

import rioxarray as rxr


# Prepare the predicted-class grid for rioxarray:
# - rename lat/lon to the y/x dimension names rioxarray looks for,
# - order rows north-to-south (the conventional GeoTIFF layout),
# - record the CRS; without write_crs the GeoTIFF carries no CRS and is not
#   properly georeferenced.
da_renamed = (
    da['predicted_class']
    .rename({'lon': 'x', 'lat': 'y'})
    .sortby('y', ascending=False)
    .rio.write_crs("EPSG:4326")
)

# Write the renamed DataArray to a GeoTIFF file
da_renamed.rio.to_raster("output_predictions.tif", driver="GTiff")