<p style="text-align: center">
<img src="./images/landsat_8_rend-sm1.png" width=250 alt="Landsat 8"></img>
</p>

# Spectral Clustering

---

## Overview

With [data ingestion](01_Data_Ingestion) and [preprocessing](02_Preprocessing) under our belts, this notebook demonstrates a simple machine learning workflow to identify water in our satellite images. We will use spectral clustering to assign a label to each x, y pixel location based on the similarity of its values across the wavelength bands in our image stack. Our example uses a version of spectral clustering from [dask_ml](http://ml.dask.org/clustering.html#spectral-clustering) that is a scalable equivalent of what is available in [scikit-learn](https://scikit-learn.org/stable/modules/clustering.html#spectral-clustering). To keep the focus on the analysis, we will begin with a single image and then conclude by comparing across images, combining the regridding steps from the previous notebook with spectral clustering.

Our present approach is just one example of an analysis, but any library, algorithm, or simulator could be used at this stage if it can accept our processed array data.

## Prerequisites

| Concepts | Importance | Notes |
| --- | --- | --- |
| [Xarray](https://foundations.projectpythia.org/core/xarray.html) | Necessary |  |

- **Time to learn**: 20 minutes.


---

## Imports


In [76]:
import intake
import numpy as np
import xarray as xr
xr.set_options(keep_attrs=True)
from dask_ml.cluster import SpectralClustering
from dask.distributed import Client
import cartopy.crs as ccrs
import geoviews as gv
import hvplot.xarray

import warnings 
# Ignore a warning about the format of epsg codes
warnings.simplefilter('ignore', FutureWarning)

## Loading data

Let's start by loading the small version of the Landsat data. This should be familiar from the previous notebooks.

In [77]:
cat = intake.open_catalog('./data/catalog.yml')
landsat_5_da = cat.landsat_5_small.to_dask()
landsat_5_da

|  | Array | Chunk |
| --- | --- | --- |
| Bytes | 4.12 MiB | 19.53 kiB |
| Shape | (6, 300, 300) | (1, 50, 50) |
| Dask graph | 216 chunks in 13 graph layers |  |
| Data type | float64 numpy.ndarray |  |


## Reshaping Data

The shape of our data is currently `n_bands`, `n_y`, `n_x`. In order for dask-ml / scikit-learn to consume our data, we'll need to reshape our image stacks into `n_samples, n_features`, where `n_features` is the number of wavelength-bands and `n_samples` is the total number of pixels in each wavelength-band image. Essentially, we'll be creating a vector of pixels out of each image, where each pixel has multiple features (bands), but the ordering of the pixels is no longer relevant to the computation. We'll first look at using NumPy, then Xarray.

### NumPy

Data can be reshaped at the lowest level using NumPy, by getting the underlying values from the `xarray.DataArray`, and using flatten and transpose to get the right shape. 

In [78]:
arr = landsat_5_da.values
arr.shape

(6, 300, 300)

In [79]:
flattened_npa = np.array([arr[i].flatten() for i in range(arr.shape[0])])
flattened_npa

array([[ 640.,  842.,  864., ..., 1309., 1636., 1199.],
       [ 810., 1096., 1191., ..., 1736., 2250., 1736.],
       [1007., 1345., 1471., ..., 2202., 2783., 1994.],
       [1221., 1662., 1809., ..., 2755., 3431., 2223.],
       [1819., 2596., 2495., ..., 3067., 3802., 2665.],
       [1682., 2215., 2070., ..., 2860., 3724., 2333.]])

In [80]:
flattened_npa.shape

(6, 90000)

In [81]:
flattened_t_npa = flattened_npa.transpose()
flattened_t_npa.shape

(90000, 6)

Now we have the data in `n_samples, n_features`, but since these are bare NumPy arrays without any coordinates or labeled dimensions, it will be harder to recreate the images after the analysis.
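To see what recreating the image by hand involves, here is a minimal sketch on a small made-up array (the names `toy`, `toy_flat`, and `toy_labels` are illustrative and not part of the notebook). Because `flatten` walks each band row by row, a per-pixel result can be reshaped back to `(n_y, n_x)`, but only if we keep track of the original shape and ordering ourselves:

```python
import numpy as np

# Toy stack with 2 bands, 3 rows (y) and 4 columns (x); values are made up.
toy = np.arange(2 * 3 * 4, dtype=float).reshape(2, 3, 4)
n_bands, n_y, n_x = toy.shape

# Same flatten-and-transpose steps as above: (n_bands, n_y * n_x) -> (n_samples, n_features)
toy_flat = np.array([toy[i].flatten() for i in range(n_bands)]).transpose()

# Stand-in for an analysis result: one value per pixel (here, the index of the largest band).
toy_labels = toy_flat.argmax(axis=1)

# Reshaping back only works because we kept n_y and n_x around ourselves;
# the x and y coordinate values are not carried along at all.
toy_label_image = toy_labels.reshape(n_y, n_x)
```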

### Xarray

Let's consider a better way to reshape the data that preserves the metadata. By using xarray methods to flatten the data, we can keep track of the coordinate labels 'x' and 'y' along the way. This means that we have the ability to reshape back to our original array at any time with no information loss!
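Before applying this to the Landsat data, here is a minimal sketch of the round trip on a toy array. The array, its coordinate values, and the `toy_*` names are made up for illustration; only `stack`, `unstack`, and `transpose` are the xarray methods the notebook relies on.

```python
import numpy as np
import xarray as xr

# Toy DataArray with the same dimension names as the Landsat stack; values are made up.
toy_da = xr.DataArray(
    np.arange(2 * 3 * 4, dtype=float).reshape(2, 3, 4),
    dims=("band", "y", "x"),
    coords={"band": [1, 2], "y": [10.0, 20.0, 30.0], "x": [100.0, 110.0, 120.0, 130.0]},
)

# Collapse 'x' and 'y' into a single stacked dimension 'z' that remembers each (x, y) pair.
toy_stacked = toy_da.stack(z=("x", "y"))

# Unstack restores 'x' and 'y'; transpose puts the dimensions back in the original order.
toy_roundtrip = toy_stacked.unstack("z").transpose("band", "y", "x")
```

Comparing `toy_roundtrip` with `toy_da` (for example with `.equals`) is a quick way to confirm that values and coordinates survive the round trip.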

In [82]:
flattened_xda = landsat_5_da.stack(z=('x','y'))
flattened_xda

|  | Array | Chunk |
| --- | --- | --- |
| Bytes | 4.12 MiB | 23.44 kiB |
| Shape | (6, 90000) | (1, 3000) |
| Dask graph | 216 chunks in 16 graph layers |  |
| Data type | float64 numpy.ndarray |  |


We can reorder the dimensions using `DataArray.transpose`:

In [83]:
flattened_t_xda = flattened_xda.transpose('z', 'band')
flattened_t_xda

|  | Array | Chunk |
| --- | --- | --- |
| Bytes | 4.12 MiB | 23.44 kiB |
| Shape | (90000, 6) | (3000, 1) |
| Dask graph | 216 chunks in 17 graph layers |  |
| Data type | float64 numpy.ndarray |  |


## Standardize Data

Now that we have the data in the correct shape, let's standardize (or rescale) the values of the data. We do this to get all the flattened image vectors onto a common scale while preserving the differences in the ranges of values. Again, we'll demonstrate doing this first in NumPy and then xarray.
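As a sketch of what "standardize" means in the code of this section: each value $x$ is shifted by the mean $\mu$ and divided by the standard deviation $\sigma$, where both are computed over every value in the flattened array (all pixels and all bands together). The symbols $z$, $x$, $\mu$, and $\sigma$ are notation introduced for this note only.

$$
z = \frac{x - \mu}{\sigma}
$$

After this step the array as a whole has mean 0 and standard deviation 1, while the relative differences between bands are kept because a single $\mu$ and $\sigma$ are shared by all of them.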


In [85]:
rescaled_npa = (flattened_t_npa - flattened_t_npa.mean()) / flattened_t_npa.std()
rescaled_npa

array([[-1.29960701, -1.10062865, -0.87004784, -0.6195692 ,  0.08036645,
        -0.0799867 ],
       [-1.0631739 , -0.76587681, -0.47443204, -0.10339592,  0.98981461,
         0.54386898],
       [-1.03742375, -0.65468302, -0.32695396,  0.06866184,  0.87159805,
         0.37415215],
       ...,
       [-0.51656863, -0.01678181,  0.52865299,  1.1759179 ,  1.54110171,
         1.2988163 ],
       [-0.1338279 ,  0.58483512,  1.2086908 ,  1.9671495 ,  2.40139051,
         2.31009455],
       [-0.64531934, -0.01678181,  0.28519712,  0.55323267,  1.07057641,
         0.68198338]])

In [86]:
rescaled_xda = (flattened_t_xda - flattened_t_xda.mean()) / flattened_t_xda.std()
rescaled_xda

|  | Array | Chunk |
| --- | --- | --- |
| Bytes | 4.12 MiB | 23.44 kiB |
| Shape | (90000, 6) | (3000, 1) |
| Dask graph | 216 chunks in 36 graph layers |  |
| Data type | float64 numpy.ndarray |  |


As `rescaled_xda` is still a Dask object, if you want to actually run the rescaling at this point (provided that all the data can fit into memory), use `.compute()`:

In [87]:
rescaled_xda.compute()


## ML pipeline
Now that our data is in the proper shape and value range, we are ready to conduct spectral clustering. Here we will use a version of [spectral clustering from dask_ml](https://ml.dask.org/modules/generated/dask_ml.cluster.SpectralClustering.html) that is a scalable equivalent to operations from scikit-learn that cluster pixels based on similarity (across all bands, which makes it spectral clustering by spectra!).

The Machine Learning pipeline shown below is just for demonstration purposes, including the shaping/reshaping of data. In practice you will likely be using a more sophisticated pipeline. 
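For intuition about what spectral clustering does, here is a minimal NumPy sketch of the textbook idea on made-up data: build an affinity matrix between samples, embed the samples using its leading eigenvectors, then cluster the embedded points. This is not dask_ml's implementation, which relies on an approximation to scale to larger datasets. The toy data, the `gamma` value, and the tiny 2-means loop are all assumptions chosen for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy "pixels": two well-separated groups of samples with 2 features each (made-up data).
group_a = rng.normal(loc=0.0, scale=0.1, size=(20, 2))
group_b = rng.normal(loc=3.0, scale=0.1, size=(20, 2))
toy_X = np.vstack([group_a, group_b])

# 1. Affinity: RBF kernel on pairwise squared distances (gamma chosen for this toy data).
gamma = 1.0
sq_dists = ((toy_X[:, None, :] - toy_X[None, :, :]) ** 2).sum(axis=-1)
affinity = np.exp(-gamma * sq_dists)

# 2. Symmetrically normalize the affinity matrix: D^(-1/2) A D^(-1/2).
inv_sqrt_degree = 1.0 / np.sqrt(affinity.sum(axis=1))
normalized = affinity * inv_sqrt_degree[:, None] * inv_sqrt_degree[None, :]

# 3. Embed each sample using the eigenvectors of the 2 largest eigenvalues,
#    then scale every row of the embedding to unit length.
_, eigvecs = np.linalg.eigh(normalized)  # eigenvalues are returned in ascending order
embedding = eigvecs[:, -2:]
embedding = embedding / np.linalg.norm(embedding, axis=1, keepdims=True)

# 4. Cluster the embedded samples with a tiny 2-means loop.
#    Deterministic start: the first sample and the sample farthest from it.
farthest = np.argmax(np.linalg.norm(embedding - embedding[0], axis=1))
centers = np.stack([embedding[0], embedding[farthest]])
for _ in range(10):
    dists = np.linalg.norm(embedding[:, None, :] - centers[None, :, :], axis=-1)
    toy_labels = dists.argmin(axis=1)
    centers = np.stack([embedding[toy_labels == k].mean(axis=0) for k in range(2)])
```

With groups this well separated, every sample in `group_a` should receive one label and every sample in `group_b` the other. The label numbers themselves carry no meaning, which is also why the labels for water differ between images later in this notebook.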

In [88]:
client = Client(processes=False)
client

| Client |  |
| --- | --- |
| Connection method | Cluster object |
| Cluster type | distributed.LocalCluster |
| Dashboard | http://192.168.178.39:65366/status |
| Workers | 1 |
| Total threads | 10 |
| Total memory | 32.00 GiB |
| Status | running |
| Using processes | False |


Now we will compute and persist the rescaled data to feed into the ML pipeline. Notice that our `X` matrix below has the shape: `n_samples, n_features` as discussed earlier. 

In [89]:
X = client.persist(rescaled_xda)
X.shape

(90000, 6)

First we will set up the model with the number of clusters, and other options.

In [90]:
clf = SpectralClustering(n_clusters=4, random_state=0, gamma=None,
                         kmeans_params={'init_max_iter': 5},
                         persist_embedding=True)

**This is the slow-ish part.** Next we fit the model to our matrix `X`. Depending on your setup, this could take about 30 seconds for the small version of the data (on a relatively beefy laptop) or around 10 minutes for a full-size Landsat image.

In [91]:
%time clf.fit(X)



CPU times: user 21.6 s, sys: 20.7 s, total: 42.3 s
Wall time: 31.3 s


In [92]:
labels = clf.assign_labels_.labels_.compute()
labels.shape

(90000,)

The result is a vector of cluster labels! OK, I know this doesn't seem all that exciting yet, but we're getting there. Next we will reshape the results into human-friendly image form.

In [93]:
labels

array([0, 0, 3, ..., 0, 0, 3], dtype=int32)

## Un-flattening

Once the computation is done, the output can be used to create a new array with the same structure as the input array. This new output array will have the coordinates needed to be unstacked similarly to how they were stacked. One of the main benefits of using `xarray` for this stacking and unstacking is that it allows `xarray` to keep track of the coordinate information for us.

Since the original array is n_samples by n_features (90000, 6) and the output only contains one feature (90000,), the template structure for this data needs to have the shape (n_samples). We achieve this by just taking one of the bands.

In [94]:
template = flattened_t_xda[:, 0]
output_array = template.copy(data=labels)
output_array

With this new output array in hand, we can unstack back to the original dimensions:

In [95]:
unstacked = output_array.unstack()
unstacked
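To check that labels end up at the right coordinates, here is a minimal sketch of the same template, copy, and unstack pattern on a toy array. The data and the `toy_*` names are made up, and a simple threshold stands in for real cluster labels.

```python
import numpy as np
import xarray as xr

# Toy stack: 2 bands on a 3 x 4 grid; values are made up.
toy_da = xr.DataArray(
    np.arange(2 * 3 * 4, dtype=float).reshape(2, 3, 4),
    dims=("band", "y", "x"),
    coords={"band": [1, 2], "y": [10.0, 20.0, 30.0], "x": [100.0, 110.0, 120.0, 130.0]},
)
toy_flat = toy_da.stack(z=("x", "y")).transpose("z", "band")

# Stand-in for cluster labels: 1 where the first band exceeds 5, else 0 (one value per sample).
toy_labels = (toy_flat.values[:, 0] > 5).astype("int32")

# Use one band as a template so the labels inherit the stacked (x, y) coordinates.
toy_template = toy_flat[:, 0]
toy_output = toy_template.copy(data=toy_labels)

# Unstack back to a 2D image and put the dimensions in (y, x) order.
toy_label_image = toy_output.unstack().transpose("y", "x")
```

Because the labels were attached to the stacked coordinates before unstacking, `toy_label_image` lines up pixel for pixel with a threshold applied directly to `toy_da.sel(band=1)`.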

And finally, bring the results to life! 

In [96]:
landsat_5_da.sel(band=4).hvplot.image(x='x', y='y', geo=True, datashade=True, cmap='greys', title='Raw Image') + \
               unstacked.hvplot(x='x', y='y', cmap='Set3', geo=True, colorbar=False, title='Spectral Clustering Labels')

## Spectral Clustering over time

Now that we have conducted the spectral clustering for a single time point, let's bring it together with what we learned about regridding in the previous [Preprocessing](02_Preprocessing) notebook to compare the results of this analysis from two different time points. The important conceptual goal here is to get the images from different acquisitions onto the same spatial grid so that we can have a chance to run computations that directly compare the images.

We already have Landsat 5 data (from 1988), so let's just load Landsat 8 (from 2017).

In [97]:
landsat_8_da = cat.landsat_8_small.read_chunked()

See the previous preprocessing notebook for a detailed walkthrough on the following steps, but in summary, we are creating a bounding box and grid around our region of interest and then interpolating our data onto this new grid.

In [98]:
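# EPSG:32611 is WGS 84 / UTM zone 11N, so x and y are in meters
crs = ccrs.epsg(32611)
# Center of the region of interest, given as (lon, lat) and transformed into the UTM grid
x_center, y_center = crs.transform_point(-118.7081, 38.6942, ccrs.PlateCarree())

# Half-width of the bounding box (15 km in each direction from the center)
buffer = 1.5e4

xmin = x_center - buffer
xmax = x_center + buffer
ymin = y_center - buffer
ymax = y_center + buffer

bounding_box = [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin)]

# Spacing of the new grid, in the same units as x and y
res = 200
x = np.arange(xmin, xmax, res)
y = np.arange(ymin, ymax, res)

# Interpolate both images onto the same grid so they can be compared directly
landsat_8_da_regridded = landsat_8_da.interp(x=x, y=y)
landsat_5_da_regridded = landsat_5_da.interp(x=x, y=y)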

Let's take a look at our regridded data. Notice that hvPlot understands that the two arrays have a common dimension `band`, and automatically links them to the same widget.

In [99]:
landsat_8_da_regridded.hvplot.image(x='x', y='y', geo=True, title='Landsat 8 2017', colorbar=False, rasterize=True, cmap='viridis') +\
    landsat_5_da_regridded.hvplot.image(x='x', y='y', geo=True, title='Landsat 5 1988', colorbar=False, rasterize=True, cmap='viridis')

Now let's run the same spectral clustering steps that we saw earlier, but on this new regridded data. Again, we will start with reshaping and rescaling the data.

In [100]:
l5_rg_flat_xda = landsat_5_da_regridded.stack(z=('x','y')).transpose('z', 'band')
l8_rg_flat_xda = landsat_8_da_regridded.stack(z=('x','y')).transpose('z', 'band')

l5_rg_rescale_xda = (l5_rg_flat_xda - l5_rg_flat_xda.mean()) / l5_rg_flat_xda.std()
l8_rg_rescale_xda = (l8_rg_flat_xda - l8_rg_flat_xda.mean()) / l8_rg_flat_xda.std()

l5_X = client.persist(l5_rg_rescale_xda)
l8_X = client.persist(l8_rg_rescale_xda)


And now we fit a model to each dataset.

In [101]:
l5_clf = SpectralClustering(n_clusters=4, random_state=0, gamma=None,
                            kmeans_params={'init_max_iter': 5},
                            persist_embedding=True)
%time l5_clf.fit(l5_X)



CPU times: user 21.1 s, sys: 14.2 s, total: 35.3 s
Wall time: 31.1 s


In [102]:
l8_clf = SpectralClustering(n_clusters=4, random_state=0, gamma=None,
                            kmeans_params={'init_max_iter': 5},
                            persist_embedding=True)
%time l8_clf.fit(l8_X)



CPU times: user 17.1 s, sys: 10.6 s, total: 27.7 s
Wall time: 24.5 s


In [103]:
l5_labels = l5_clf.assign_labels_.labels_.compute()
l8_labels = l8_clf.assign_labels_.labels_.compute()

And the last step before the big reveal is to reshape the results back into image form:

In [104]:
l5_template = l5_rg_flat_xda[:, 0]
l5_output_array = l5_template.copy(data=l5_labels)

l8_template = l8_rg_flat_xda[:, 0]
l8_output_array = l8_template.copy(data=l8_labels)

l5_labels_unstacked = l5_output_array.unstack()
l8_labels_unstacked = l8_output_array.unstack()

Ta da!

In [105]:
l5_labels_unstacked.hvplot(x='x', y='y', width=400, height=400, cmap='Set3', geo=True, colorbar=False, title='1988 Labels') +\
l8_labels_unstacked.hvplot(x='x', y='y', width=400, height=400, cmap='Set3', geo=True, colorbar=False, title='2017 Labels')

But wait, the spectral clustering labels of water are clearly different between the two years. If we want to directly compare the amount of water across these images, we'll have to create a mask using the appropriate label from each image that is indicative of water. Since we are using interactive plotting, we can just hover over the lake in these images to discover that we are interested in cluster label 1 (blue) for the 1988 data and cluster label 3 (yellow) for the 2017 data. Great, now let's create those water masks.

In [106]:
l5_labels_mask = l5_labels_unstacked.where(l5_labels_unstacked == 1, 0) # set non-1 to 0
l8_labels_mask = l8_labels_unstacked.where(l8_labels_unstacked == 3, 0) # set non-3 to 0
l8_labels_mask = l8_labels_mask.where(l8_labels_mask != 3, 1) # set 3 -> 1
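As a small aside, the two-step `where` pattern can also be written as a single comparison. The sketch below uses a made-up label image (`toy_labels`) and treats label 3 as water purely for illustration; it shows that both approaches give the same 0/1 mask.

```python
import numpy as np
import xarray as xr

# Toy label image with made-up cluster labels in the range 0..3.
toy_labels = xr.DataArray(
    np.array([[0, 3, 3], [1, 3, 2], [0, 1, 3]]),
    dims=("y", "x"),
)

# Two-step approach from the cell above, treating label 3 as water.
two_step = toy_labels.where(toy_labels == 3, 0)  # keep 3s, set everything else to 0
two_step = two_step.where(two_step != 3, 1)      # turn the remaining 3s into 1s

# One-step alternative: compare, then cast the boolean result to integers.
one_step = (toy_labels == 3).astype(int)
```

The one-step form also avoids a subtle dependency in the two-step version, which only works when the water label is not already 0 or 1.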

In [107]:
l5_labels_mask.hvplot(x='x', y='y', cmap='greys', geo=True, colorbar=False, title='1988 Water Mask') +\
l8_labels_mask.hvplot(x='x', y='y', cmap='greys', geo=True, colorbar=False, title='2017 Water Mask')

Now we can take the difference of these water label masks to see exactly where the water level has changed.

In [108]:
l8_l5_specdiff = l8_labels_mask - l5_labels_mask

<div class="admonition alert alert-warning">
    <p class="admonition-title" style="font-weight:bold">Warning</p>
    By default, this last operation between two xarray arrays will strip the attributes (like crs) from the result unless you have told xarray to hang on to them, as we did in our import cell at the top with xr.set_options(keep_attrs=True).
</div>

In [109]:
l8_l5_specdiff.hvplot(x='x', y='y', width=400, height=400, cmap='blues', geo=True, alpha=.7, colorbar=False, title='2017-1988 Labels', tiles='ESRI')

Congratulations, you did it! Above, the white pixels are regions where there was water in 1988 but not 2017 around the lake.
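To make the sign convention of the difference concrete, here is a minimal sketch with two made-up binary masks (the `toy_*` names and values are illustrative only). Subtracting the earlier mask from the later one gives -1 where water disappeared, 1 where it appeared, and 0 where nothing changed; -1 is the smallest value, so it falls at the light end of a sequential colormap.

```python
import numpy as np

# Made-up binary water masks (1 = water, 0 = not water) for two dates.
toy_mask_1988 = np.array([[1, 1, 0],
                          [1, 1, 0],
                          [0, 0, 0]])
toy_mask_2017 = np.array([[0, 1, 0],
                          [0, 1, 1],
                          [0, 0, 0]])

# Later minus earlier, mirroring `l8_labels_mask - l5_labels_mask`.
toy_diff = toy_mask_2017 - toy_mask_1988

water_lost = int((toy_diff == -1).sum())    # water in 1988 but not in 2017
water_gained = int((toy_diff == 1).sum())   # water in 2017 but not in 1988
unchanged = int((toy_diff == 0).sum())      # same class on both dates
```

Counting pixels per category like this, multiplied by the area of one grid cell, is one simple way to turn the difference image into a number.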

---

## Summary
Nice work. In this notebook we covered reshaping and rescaling the data to get it into a format ready for machine learning. Then we conducted spectral clustering to get label-images of spots where there was likely water, and finally used our regridding approach to compare the water regions from different time points.

### What's next?
Now that we have conducted a simple machine learning workflow, it's time for you to adapt and extend these methods to your own projects.


## Resources and references
- Authored/adapted by Demetris Roumis circa Dec, 2022
- This cookbook was inspired by the [EarthML](https://github.com/pyviz-topics/EarthML) tutorial. See a list of the EarthML contributors [here](https://github.com/pyviz-topics/EarthML/graphs/contributors).
<a href="https://github.com/pyviz-topics/EarthML/graphs/contributors">
  <img src="https://contrib.rocks/image?repo=pyviz-topics/EarthML" />
</a>
- The Landsat 8 banner image is from [NASA](https://svs.gsfc.nasa.gov/10812)
