<a target="_blank" href="https://colab.research.google.com/github/Prindle19/efcoa/blob/main/notebooks/EFM_Classification_Dask_ML_KNN_Local.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

In [None]:
# install dependencies: Zarr, Dask ML
!pip install -q zarr "dask-ml[complete]"

In [None]:
# Import dependencies and authenticate to Cloud Storage so the EFM Zarr store can be read.
import gcsfs  # not called directly; provides the gs:// filesystem used by xr.open_zarr below
import xarray as xr
import zarr
from dask.distributed import Client, progress
import dask_ml.cluster
import matplotlib.pyplot as plt

import google.auth
from google.colab import auth

# Colab-only: prompt for Google credentials (presumably picked up by gcsfs for the bucket read).
auth.authenticate_user()

# Open the dataset lazily from Zarr.
ds = xr.open_zarr("gs://imax-conus/data-10m/")
ds

In [None]:
# Crop the Dataset to the Manasquan Inlet, NJ area for a single time step.
# bbox is (min_lon, min_lat, max_lon, max_lat) and includes a 500 m buffer.
bbox = [-74.09972442,  40.0838885 , -74.02481765,  40.12451048]
min_lon, min_lat, max_lon, max_lat = bbox

# Label-based slices must follow the coordinate's own ordering: if latitude is
# stored in descending order, slice(min_lat, max_lat) would select nothing.
lat_ascending = bool(ds.lat[0] <= ds.lat[-1])
lat_slice = slice(min_lat, max_lat) if lat_ascending else slice(max_lat, min_lat)

ds_small = ds.sel(
    lat=lat_slice,
    lon=slice(min_lon, max_lon),
    time='2022-01-01',
)

ds_small

In [None]:
# Convert the Dataset to a single DataArray: each data variable becomes one
# entry along a new "variable" dimension.
da = ds_small.to_array()

# KMeans expects a 2-D (sample, feature) input, so collapse lat/lon into one
# "point" (sample) dimension and put the variables (features) last. Naming the
# order explicitly, instead of relying on transpose() reversing all dims, makes
# this raise if an unexpected extra dimension is ever present.
da = da.stack(point=['lat', 'lon']).transpose('point', 'variable')
da

In [None]:
# Spin up an in-process local Dask cluster (threads, no separate processes):
# a single worker running 4 threads, capped at 12 GB of memory.
client = Client(
    n_workers=1,
    threads_per_worker=4,
    processes=False,
    memory_limit='12GB',
)
client

In [None]:
# Create 10 classes using Dask ML KNN on the local cluster using unsupervised classification

%%time
km = dask_ml.cluster.KMeans(n_clusters=10, init_max_iter=2, oversampling_factor=10)
km.fit(da)

In [None]:
# Attach each point's cluster label as a `predicted_class` coordinate on the
# stacked DataArray. assign_coords returns a new object, so the array shown by
# the previous cell is not mutated. km.labels_ holds one label per "point" row.
# Then unstack "point" to restore the original lat/lon dimensions; the label
# coordinate is unstacked with it into a (lat, lon) grid.
# Run this cell once: afterwards `da` no longer has a "point" dimension.
da = (
    da
    .assign_coords(predicted_class=('point', km.labels_))
    .unstack('point')
)
da

In [None]:
# For reference, visualize the embedding_B35 band of the cropped Sentinel-2 composite.
fig, ax = plt.subplots(figsize=(10, 5))
# Draw explicitly on `ax` (as the predicted-class plot does) rather than the implicit current axes.
ds_small.embedding_B35.plot(ax=ax, x='lon', y='lat')
ax.set_title("Reference: embedding_B35")
plt.show()

In [None]:
# Map the K-Means cluster id of every lat/lon cell using a reversed tab10 colormap.
fig, ax = plt.subplots(figsize=(10, 5))
class_map = da.predicted_class
class_map.plot(
    x='lon',
    y='lat',
    ax=ax,
    cmap='tab10_r',
    add_colorbar=True,
)
ax.set_title("Unsupervised K-Means clustering on Zarr with Dask running locally")
plt.show()

In [None]:
# Shut down the Dask client together with the local cluster started above
# (its scheduler and single worker).
client.shutdown()