In [None]:
pip install numpy matplotlib tqdm shapely xarray rioxarray dask scikit-learn pystac-client planetary-computer requests rasterio

import os
import requests
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
from shapely.geometry import box, mapping

import xarray as xr
import rioxarray as rxr
from rioxarray.merge import merge_datasets
from rasterio.enums import Resampling

import dask
import dask.array as da
from dask import delayed
from dask.diagnostics import ProgressBar
from dask.distributed import Client

from sklearn.cluster import MiniBatchKMeans

import pystac_client
import planetary_computer

In [None]:
# Start the Dask client
# NOTE(review): per the dask.distributed docs, memory_limit is applied to each worker
# rather than to the cluster as a whole - confirm the total fits on this machine.
client = Client(memory_limit='8GB')  # Adjust the memory limit based on your system's capacity

# Print the dashboard link (URL of the diagnostics dashboard for this client)
print(client.dashboard_link)

In [None]:
def retrieve_stac_items_bbox(url, collection, bbox, datetime, max_items=10, pc_flag=False, max_cloud_cover=2):
    """
    Search a STAC API for items in a bounding box and date range, keeping only
    items whose ``eo:cloud_cover`` is below a threshold, sorted by descending datetime.

    Parameters
    ----------
    url : str
        The STAC API URL.
    collection : str
        Collection ID to search in the STAC API.
    bbox : list or tuple
        Coordinates of the target area in the format (min x, min y, max x, max y).
    datetime : str
        ISO-8601 formatted date range, e.g., "2020-01-01/2020-12-31".
    max_items : integer
        Maximum number of items to retrieve from the STAC API. Default is 10.
    pc_flag : boolean
        When True, ``planetary_computer.sign_inplace`` is passed to the client as
        its modifier; when False no modifier is used. Default is False.
    max_cloud_cover : float
        Items are kept only when ``eo:cloud_cover`` is strictly less than this
        value. Default is 2.

    Returns
    -------
    items
        The value returned by ``item_collection()`` on the search, i.e. the
        collection of items matching the query parameters.
    """
    # Only touch planetary_computer when the caller asked for it
    sign_hook = planetary_computer.sign_inplace if pc_flag else None
    stac_client = pystac_client.Client.open(url=url, modifier=sign_hook)

    # "lt" keeps items strictly below the cloud-cover threshold
    cloud_filter = {"eo:cloud_cover": {"lt": max_cloud_cover}}
    search = stac_client.search(
        collections=[collection],
        bbox=bbox,
        datetime=datetime,
        query=cloud_filter,
        sortby=["-properties.datetime"],
        max_items=max_items,
    )
    return search.item_collection()

# Define parameters for Charlottesville, Virginia (1-degree square search area)
url = "https://planetarycomputer.microsoft.com/api/stac/v1"
collection = "landsat-c2-l2"

# NOTE(review): the box spans x -79.5..-78.5 and y 37.9..38.9 - confirm it is placed where intended
bbox = [-79.5, 37.9, -78.5, 38.9]  # (min x, min y, max x, max y), 1 degree on each side
bbox_geometry = mapping(box(*bbox))  # same box as a GeoJSON-like mapping, used later for clipping
time_range = "2020-01-01/2020-12-31"

items = retrieve_stac_items_bbox(
    url=url,
    collection=collection,
    bbox=bbox,
    datetime=time_range,  # Pass the time range directly
    max_items=1000,  # Retrieve up to 1000 items
    pc_flag=True  # Use Planetary Computer API
)

In [None]:
def gen_stac_asset_urls_all(items, raster_only=True):
    """
    Generate a dictionary of asset URLs for all available bands in the STAC items.

    Parameters:
    items: list
        A collection of STAC items.
    raster_only: bool
        When True (default), assets whose declared media type is not a TIFF
        (for example text/XML/JSON metadata or PNG previews) are skipped, so
        that every returned URL can be opened as a raster. Assets that declare
        no media type are kept. Pass False to collect every asset.

    Returns:
    dict:
        A dictionary where keys are band names and values are lists of URLs.
    """
    asset_urls = {}
    for item in items:
        for asset_name, asset in item.assets.items():
            media_type = getattr(asset, "media_type", None)
            # Only drop an asset when it explicitly says it is something other than a TIFF
            declared_non_tiff = media_type is not None and "image/tiff" not in str(media_type).lower()
            if raster_only and declared_non_tiff:
                continue
            asset_urls.setdefault(asset_name, []).append(asset.href)
    return asset_urls

# Collect the asset hrefs of every search result, grouped by asset name
asset_urls = gen_stac_asset_urls_all(items)

In [None]:
def open_rasters_with_dask_all(asset_urls):
    """
    Open every URL of every band as a chunked dataset.

    Parameters:
    asset_urls: dict
        Dictionary where keys are band names and values are lists of URLs.

    Returns:
    dict:
        Dictionary where keys are band names and values are lists of datasets,
        one per URL, in the same order as the input URLs. The single variable
        of each dataset is named "<band>_<position of the URL in its list>".
    """
    opened = {}
    for band, urls in asset_urls.items():
        per_band = []
        for index, url in enumerate(urls):
            # 512 x 512 chunks along x/y; lock=False is passed straight to open_rasterio
            raster = rxr.open_rasterio(url, lock=False, chunks={'x': 512, 'y': 512})
            per_band.append(raster.to_dataset(name=f"{band}_{index}"))
        opened[band] = per_band
    return opened

# Open datasets for all bands
with ProgressBar():
    band_datasets = open_rasters_with_dask_all(asset_urls)

# The RGB lists are taken from band_datasets. The previous version of this cell
# called open_rasters_with_dask(red_urls) etc., none of which is defined in this
# notebook, so it raised NameError on a fresh kernel.
red_datasets = band_datasets.get("red", [])
green_datasets = band_datasets.get("green", [])
blue_datasets = band_datasets.get("blue", [])

In [None]:
def merge_in_batches(datasets, batch_size=5):
    """
    Merge a list of datasets a few at a time instead of all at once.

    Parameters:
    datasets: list
        Datasets to merge.
    batch_size: int
        Number of datasets merged together in each step. Default is 5.

    Returns:
    The merged dataset, or None when `datasets` is empty.
    """
    accumulated = None
    for start in range(0, len(datasets), batch_size):
        # Merge the current slice on its own first ...
        slice_merged = xr.merge(datasets[start:start + batch_size], compat = 'override')
        # ... then fold it into everything merged so far
        accumulated = (
            slice_merged
            if accumulated is None
            else xr.merge([accumulated, slice_merged], compat='override')
        )
    return accumulated

def dask_merge_datasets(datasets):
    """
    Merge a list of datasets and return the merged result after computing it.

    The datasets are combined with `merge_in_batches`, rechunked to 2048 x 2048
    along x/y, and then handed to `dask.compute`, so the returned object is the
    computed result rather than a lazy one.

    Parameters:
    datasets: list
        Datasets to merge.

    Returns:
    The computed, merged dataset.
    """
    combined = merge_in_batches(datasets)

    # Rechunk the merged result before it is computed
    rechunked = combined.chunk({'x': 2048, 'y': 2048})

    # dask.compute returns a tuple with one entry per argument
    (computed,) = dask.compute(rechunked)
    return computed

def dask_clip(dask_array, bbox_geometry, crs=4326):
    """
    Clip an array to a single geometry through its ``rio`` accessor.

    Parameters
    ----------
    dask_array : object exposing ``.rio.clip``
        The array (or dataset) to be clipped.
    bbox_geometry : dict
        The geometry for clipping (GeoJSON format). It is wrapped in a
        one-element list before being handed to ``rio.clip``.
    crs : int, optional
        The coordinate reference system (CRS) of the geometry, default is 4326.

    Returns
    -------
    clipped
        Whatever ``rio.clip`` returns for the given geometry and CRS.
    """
    geometries = [bbox_geometry]
    clipped = dask_array.rio.clip(geometries, crs=crs)
    return clipped

def merge_all_bands(datasets):
    """
    Merge the datasets of every band.

    Parameters:
    datasets: dict
        Dictionary where keys are band names and values are lists of datasets.

    Returns:
    dict:
        Dictionary with the same keys, each mapped to the result of
        `dask_merge_datasets` for that band's list.
    """
    merged = {}
    for band, scenes in datasets.items():
        merged[band] = dask_merge_datasets(scenes)
    return merged

def clip_all_bands(mosaics, bbox_geometry, crs=4326):
    """
    Clip the mosaic of every band to the same geometry.

    Parameters:
    mosaics: dict
        Dictionary where keys are band names and values are the merged mosaics.
    bbox_geometry: dict
        The geometry for clipping (GeoJSON format).
    crs: int
        CRS of the geometry, default is 4326.

    Returns:
    dict:
        Dictionary with the same keys, each mapped to the result of
        `dask_clip` for that band's mosaic.
    """
    clipped = {}
    for band, mosaic in mosaics.items():
        clipped[band] = dask_clip(mosaic, bbox_geometry, crs=crs)
    return clipped

In [None]:
with ProgressBar():
    merged_bands = merge_all_bands(band_datasets)
    clipped_bands = clip_all_bands(merged_bands, bbox_geometry)


# Export the clipped red/blue/green bands as individual GeoTIFFs.
# The rasters are read from clipped_bands; the previous version of this cell used
# mosaic_red / mosaic_blue / mosaic_green, which are never defined in this notebook.
# Only the first data variable of each band's dataset is written.
with ProgressBar():
    for band in ("red", "blue", "green"):
        band_mosaic = clipped_bands[band]
        first_variable = list(band_mosaic.data_vars)[0]
        band_mosaic[first_variable].rio.to_raster(f"{band}_clipped.tif")

# Combine all bands into a single xarray Dataset
final_mosaic = xr.merge([clipped_bands[band] for band in clipped_bands], compat='override')

# Save the final mosaic as a multi-band GeoTIFF
final_mosaic = final_mosaic.squeeze()
final_mosaic.rio.to_raster("final_mosaic_all_bands.tif")

print("Final mosaic bands:", list(final_mosaic.data_vars))

In [None]:
# List the variable names held by the merged mosaic
print("Available variables:", list(final_mosaic.data_vars))

In [None]:
# Bridge between the codes:
# Convert the final mosaic to a format suitable for clustering.
# The variable is picked by position from the mosaic itself: variables are named
# "<band>_<index>" when the rasters are opened, so the previously hard-coded key
# 'uVBrdMABc%3D' does not exist in final_mosaic and raised KeyError.
cluster_variable = list(final_mosaic.data_vars)[0]
stacked_array = final_mosaic[cluster_variable].values

# Keep only positive pixels; boolean indexing already flattens them to 1-D
valid_mask = stacked_array > 0
valid_data = stacked_array[valid_mask]

# Apply MiniBatch KMeans Clustering (one feature per sample, hence the (-1, 1) reshape)
n_clusters = 10
kmeans = MiniBatchKMeans(n_clusters=n_clusters, random_state=42, batch_size=1024)
kmeans.fit(valid_data.reshape(-1, 1))

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from shapely.geometry import box
import rioxarray as rxr
from sklearn.cluster import MiniBatchKMeans
import pystac_client
import planetary_computer as pc

# Constants for the single-item clustering workflow below
AOI = box(-105.0, 39.5, -104.5, 40.0)  # Bounding Box around Denver, CO; its .bounds are used as the search bbox
TIME_RANGE = "2023-01-01/2023-12-31"  # datetime range passed to the STAC search
BANDS = ["red", "green", "blue"]  # Target bands for clustering (asset keys looked up on the selected item)
N_CLUSTERS = 10  # number of clusters requested from MiniBatchKMeans


def search_and_load_data(catalog_url, collection, aoi, time_range, bands):
    """
    Search a STAC catalog and open the requested bands of the first matching item.

    Parameters:
    catalog_url: str
        URL of the STAC API.
    collection: str
        Collection ID to search.
    aoi: geometry exposing `.bounds`
        Area of interest; its bounds are used as the search bounding box.
    time_range: str
        Datetime range passed to the search.
    bands: list of str
        Asset keys to open from the selected item.

    Returns:
    list:
        One squeezed raster per entry of `bands`, in the same order.

    Raises:
    ValueError:
        If the search returns no item.
    KeyError:
        If the selected item lacks one of the requested bands.
    """
    catalog = pystac_client.Client.open(catalog_url)
    # max_items caps the number of results; limit only sets the page size
    search = catalog.search(
        collections=[collection],
        bbox=aoi.bounds,
        datetime=time_range,
        limit=1,
        max_items=1,
    )
    # items() replaces the deprecated get_items(); the default avoids a bare StopIteration
    item = next(iter(search.items()), None)
    if item is None:
        raise ValueError(
            f"No items found in collection {collection!r} for bbox {aoi.bounds} and datetime {time_range!r}"
        )
    print("Selected Item:", item.id)

    missing = [band for band in bands if band not in item.assets]
    if missing:
        raise KeyError(f"Item {item.id} has no asset(s) named {missing}")

    # Sign only the hrefs that are actually opened
    return [rxr.open_rasterio(pc.sign(item.assets[band].href)).squeeze() for band in bands]


def stack_and_mask(arrays):
    """
    Stack the `.values` of each input along a new last axis and build a validity mask.

    Parameters:
    arrays: list
        Objects exposing a `.values` array; all must share one shape.

    Returns:
    tuple:
        (stacked, mask) where `stacked` has one extra trailing axis with one
        entry per input, and `mask` is True where every entry is greater than 0.
    """
    layers = []
    for arr in arrays:
        layers.append(arr.values)
    stacked = np.stack(layers, axis=-1)

    # A position is valid only when all stacked values there are positive
    positive = stacked > 0
    mask = positive.all(axis=-1)
    return stacked, mask


def apply_clustering(data, mask, n_clusters):
    """
    Fit MiniBatch KMeans on the masked-in samples and label them.

    Parameters:
    data: array
        Array whose first two axes are indexed by `mask`.
    mask: boolean array
        Selects the samples used for fitting and prediction.
    n_clusters: int
        Number of clusters.

    Returns:
    tuple:
        (labels, model) where `labels` has shape `data.shape[:2]`, holds the
        predicted cluster where `mask` is True and -1 elsewhere, and `model`
        is the fitted MiniBatchKMeans instance.
    """
    samples = data[mask]
    model = MiniBatchKMeans(n_clusters=n_clusters, random_state=42, batch_size=1024)
    model.fit(samples)

    # Start from "unlabelled" (-1) everywhere, then fill in the masked-in positions
    label_map = -np.ones(data.shape[:2], dtype=int)
    label_map[mask] = model.predict(samples)
    return label_map, model


def compute_cluster_means(data, mask, labels, n_clusters):
    """
    Computes the average spectral signatures for each cluster.

    Parameters:
    data: array
        Array whose last axis holds the per-band values.
    mask: boolean array
        Selects the samples of `data` (and entries of `labels`) to use.
    labels: array
        Cluster label per position, indexed by `mask`.
    n_clusters: int
        Number of clusters.

    Returns:
    array:
        Shape (n_clusters, data.shape[-1]); row i is the mean of the samples
        labelled i. A cluster with no samples gets a row of NaN.
    """
    valid_data = data[mask]
    # Label selection does not depend on the cluster, so do it once
    valid_labels = labels[mask]
    cluster_means = np.zeros((n_clusters, data.shape[-1]))
    for cluster in range(n_clusters):
        members = valid_data[valid_labels == cluster]
        # An empty cluster is reported as NaN explicitly, without numpy's
        # "Mean of empty slice" warnings
        cluster_means[cluster] = members.mean(axis=0) if len(members) else np.nan
    return cluster_means


def visualize_results(cluster_labels, cluster_means, bands):
    """
    Visualizes the clustering results and average spectral signatures.

    Parameters:
    cluster_labels: 2-D array
        Cluster label per pixel; negative values mark unlabelled pixels.
    cluster_means: sequence
        One sequence of per-band means for each cluster.
    bands: list of str
        Band names used as x tick labels of the signature plot.

    Returns:
    None. The figure is shown with `plt.show()`.
    """
    n_clusters = len(cluster_means)

    # Float copy so unlabelled pixels can become NaN, which imshow leaves blank
    # instead of painting them with a cluster colour
    label_image = np.array(cluster_labels, dtype=float)
    label_image[label_image < 0] = np.nan

    plt.figure(figsize=(10, 5))
    # Clustered Image
    plt.subplot(1, 2, 1)
    plt.title("Clustered Image")
    # Pin the colour range to the cluster ids (one unit-wide bin per id) rather than
    # letting it autoscale over the data, which squeezed -1..n-1 into the colormap
    plt.imshow(label_image, cmap="tab10", vmin=-0.5, vmax=max(n_clusters, 1) - 0.5)
    plt.colorbar(label="Cluster")

    # Spectral Signatures
    plt.subplot(1, 2, 2)
    plt.title("Average Spectral Signatures")
    for i, mean in enumerate(cluster_means):
        plt.plot(mean, label=f"Cluster {i}")
    plt.xticks(range(len(bands)), bands)
    plt.xlabel("Band")
    plt.ylabel("Reflectance")
    plt.legend()
    plt.tight_layout()
    plt.show()


# Workflow: search one item, stack its bands, cluster the pixels, and plot the result.
# Note: `collection`, `stacked_array` and `valid_mask` below overwrite the
# variables of the same name assigned in earlier cells.
catalog_url = "https://planetarycomputer.microsoft.com/api/stac/v1"
collection = "landsat-c2-l2"

# Step 1: Search and load data (one raster per entry of BANDS)
arrays = search_and_load_data(catalog_url, collection, AOI, TIME_RANGE, BANDS)

# Step 2: Stack and mask data (mask is True where every band is > 0)
stacked_array, valid_mask = stack_and_mask(arrays)

# Step 3: Apply clustering (labels are -1 outside the mask)
cluster_labels, kmeans_model = apply_clustering(stacked_array, valid_mask, N_CLUSTERS)

# Step 4: Compute cluster means (one row per cluster, one column per band)
cluster_means = compute_cluster_means(stacked_array, valid_mask, cluster_labels, N_CLUSTERS)

# Step 5: Visualize results
visualize_results(cluster_labels, cluster_means, BANDS)
