In [None]:
### I've added some Dask operations and functionality on top of the original tutorial,
### but most of the code is Hamed's, from Raster Processing class 15.

In [None]:
import xarray as xr
import rioxarray as rxr
from rioxarray.merge import merge_datasets
import numpy as np
import pyproj
from rasterio.enums import Resampling
from shapely.geometry import box, mapping
from dask.distributed import Client
from dask import delayed

import os
import requests
from tqdm import tqdm
import dask.array as da

import pystac_client
import planetary_computer
from dask.diagnostics import ProgressBar  # Useful for monitoring progress


In [None]:
# Start a local Dask cluster. `memory_limit` is applied to each worker, so lower
# it (or request fewer workers) if the machine has less RAM.
DASK_MEMORY_LIMIT = '4GB'
client = Client(memory_limit=DASK_MEMORY_LIMIT)

In [None]:
def retrieve_stac_items_bbox(url, collection, bbox, datetime, max_items=10, pc_flag=False, max_cloud_cover=2):
    """
    Search a STAC API for items of one collection that intersect a bounding box,
    newest first, optionally filtered by cloud cover.

    Parameters
    ----------
    url : str
        The STAC API URL.
    collection : str
        Collection ID to search in the STAC API.
    bbox : list or tuple
        Search area as (min x, min y, max x, max y).
    datetime : str
        ISO-8601 date or date range, e.g. "2020-01-01/2020-12-31".
    max_items : int, optional
        Upper bound on the number of items returned. Default is 10.
    pc_flag : bool, optional
        If True, the catalog is opened with ``planetary_computer.sign_inplace`` as
        its modifier so asset hrefs are signed for Microsoft's Planetary Computer.
        Default is False.
    max_cloud_cover : float or None, optional
        Keep only items whose ``eo:cloud_cover`` is strictly less than this
        percentage. Default is 2. Pass None to skip the cloud-cover filter
        entirely (e.g. for collections without that property).

    Returns
    -------
    items : pystac.ItemCollection
        The matching items (already fetched, not a generator), sorted by
        ``properties.datetime`` in descending order.
    """
    modifier = planetary_computer.sign_inplace if pc_flag else None
    catalog = pystac_client.Client.open(url=url, modifier=modifier)

    search_kwargs = {
        "collections": [collection],
        "bbox": bbox,
        "datetime": datetime,
        "sortby": ["-properties.datetime"],  # newest acquisitions first
        "max_items": max_items,
    }
    # Only send the cloud-cover query when a threshold was requested.
    if max_cloud_cover is not None:
        search_kwargs["query"] = {"eo:cloud_cover": {"lt": max_cloud_cover}}

    return catalog.search(**search_kwargs).item_collection()

In [None]:
# Search parameters: Landsat Collection 2 Level-2 on the Planetary Computer,
# over a 1-degree square centred on Charlottesville, Virginia, for 2020.
# (The earlier hard-coded box [-79.5, 37.9, -78.5, 38.9] was centred at
# (-79.0, 38.4); its east edge at -78.5 lay west of the city, so Charlottesville
# itself was outside the search area.)
url = "https://planetarycomputer.microsoft.com/api/stac/v1"
collection = "landsat-c2-l2"

CITY_CENTER_LON, CITY_CENTER_LAT = -78.48, 38.03  # approximate Charlottesville city centre (lon/lat)
BBOX_HALF_SIZE_DEG = 0.5  # half of the box's width/height, in degrees

bbox = [
    CITY_CENTER_LON - BBOX_HALF_SIZE_DEG,  # min lon
    CITY_CENTER_LAT - BBOX_HALF_SIZE_DEG,  # min lat
    CITY_CENTER_LON + BBOX_HALF_SIZE_DEG,  # max lon
    CITY_CENTER_LAT + BBOX_HALF_SIZE_DEG,  # max lat
]
bbox_geometry = mapping(box(*bbox))
time_range = "2020-01-01/2020-12-31"

In [None]:
# Query Landsat scenes for the bounding box and time range
# (newest first, cloud cover below the function's default threshold).
items = retrieve_stac_items_bbox(
    url=url,
    collection=collection,
    bbox=bbox,
    datetime=time_range,
    max_items=1000,  # upper bound on the number of items returned
    pc_flag=True,  # sign asset URLs for the Planetary Computer
)

# Fail early with a clear message instead of crashing later in the merge step.
if len(items) == 0:
    raise ValueError(
        "No items matched the STAC search; widen the bbox, time range or cloud-cover limit."
    )

len(items)


In [None]:
def gen_stac_asset_urls(items, asset):
    """
    Collect the href of one named asset from every item of a STAC search result.

    Inputs:
        items : iterable of STAC items
            E.g. the item collection returned by a STAC API search.
        asset : string
            Asset key that every item in `items` must have (e.g. "red");
            a missing key raises KeyError.

    Returns:
        urls : list
            One URL per item, in the same order as `items`.
    """
    return [stac_item.assets[asset].href for stac_item in items]

In [None]:
# Collect the red, green and blue asset URLs for every item.
# These are plain lists of strings; nothing is opened or computed yet.
red_urls, green_urls, blue_urls = (
    gen_stac_asset_urls(items, band_key) for band_key in ("red", "green", "blue")
)


In [None]:
def open_rasters_with_dask(urls, name, chunk_size=1024):
    """
    Lazily open each raster URL as a Dask-backed, single-variable xarray Dataset.

    NOTE(review): this notebook used to call this helper without defining it,
    so it only ran with a leftover definition in the kernel. This version is a
    reconstruction that produces what the later cells use: a list of Datasets
    whose data variable is named after the band (e.g. ``mosaic_red.red``).

    Parameters
    ----------
    urls : list of str
        Raster URLs, e.g. from ``gen_stac_asset_urls``.
    name : str
        Name for the data variable in every returned Dataset (e.g. "red").
    chunk_size : int, optional
        Dask chunk size along ``x`` and ``y``. Default is 1024.

    Returns
    -------
    list of xarray.Dataset
        One Dataset per URL, in the same order as ``urls``.
    """
    chunks = {"x": chunk_size, "y": chunk_size}
    return [rxr.open_rasterio(url, chunks=chunks).to_dataset(name=name) for url in urls]


# Opening with `chunks` is lazy, so pixel data should not be read until a later
# step computes it.
red_datasets = open_rasters_with_dask(red_urls, "red")
green_datasets = open_rasters_with_dask(green_urls, "green")
blue_datasets = open_rasters_with_dask(blue_urls, "blue")


In [None]:
def merge_in_batches(datasets, batch_size=5, merge_func=None):
    """
    Merge `datasets` `batch_size` at a time, folding each merged batch into a
    running result.

    Parameters
    ----------
    datasets : iterable
        Objects accepted by `merge_func` (xarray Datasets / named DataArrays
        for the default).
    batch_size : int, optional
        Number of datasets merged per step; must be >= 1. Default is 5.
    merge_func : callable, optional
        Called with a list of objects and returns their merge. Defaults to
        ``xr.merge``. Note that ``xr.merge`` (default ``compat="no_conflicts"``)
        raises ``MergeError`` when overlapping inputs disagree on a value, so for
        a raster mosaic of overlapping scenes pass something like
        ``rioxarray.merge.merge_datasets`` instead.

    Returns
    -------
    The object returned by the final `merge_func` call.

    Raises
    ------
    ValueError
        If `datasets` is empty or `batch_size` is less than 1.
    """
    if merge_func is None:
        merge_func = xr.merge
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    datasets = list(datasets)
    if not datasets:
        raise ValueError("merge_in_batches() needs at least one dataset")

    merged_result = None
    for start in range(0, len(datasets), batch_size):
        batch_merged = merge_func(datasets[start:start + batch_size])
        if merged_result is None:
            merged_result = batch_merged
        else:
            merged_result = merge_func([merged_result, batch_merged])
    return merged_result

In [None]:
def dask_merge_datasets(datasets, **merge_kwargs):
    """
    Mosaic a list of georeferenced xarray Datasets into a single Dataset.

    Uses ``merge_datasets`` from ``rioxarray.merge``. It pastes every input onto
    one output grid and, where inputs overlap, keeps the first valid
    (non-nodata) value by default.

    Why not the previous approach:
    - ``xr.merge`` only aligns coordinates. Overlapping scenes from different
      dates disagree on shared pixels, so it raises ``MergeError``.
    - The previous body also called ``dask.compute`` although the ``dask``
      module itself was never imported.

    NOTE(review): because the first valid value wins, the order of `datasets`
    matters. The STAC search is sorted newest-first, so if the datasets keep
    that order, the most recent clear pixel is kept. Confirm.

    Parameters
    ----------
    datasets : iterable of xarray.Dataset
        Datasets carrying CRS/transform information readable via ``.rio``.
    **merge_kwargs
        Forwarded to ``merge_datasets`` (e.g. ``bounds``, ``res``, ``nodata``,
        ``method``).

    Returns
    -------
    xarray.Dataset
        The mosaicked dataset, with the same data-variable names as the inputs.

    Raises
    ------
    ValueError
        If `datasets` is empty.
    """
    datasets = list(datasets)
    if not datasets:
        raise ValueError("dask_merge_datasets() needs at least one dataset")
    return merge_datasets(datasets, **merge_kwargs)

In [None]:
# Mosaic each band's scenes into a single raster.
# NOTE(review): dask.diagnostics.ProgressBar instruments Dask's local schedulers;
# with the distributed Client started above it probably prints nothing. Watch
# the Dask dashboard instead.
with ProgressBar():
    mosaic_red, mosaic_green, mosaic_blue = (
        dask_merge_datasets(band_datasets)
        for band_datasets in (red_datasets, green_datasets, blue_datasets)
    )

In [None]:
# Clip helper for georeferenced xarray objects
def dask_clip(dask_array, bbox_geometry, crs=4326, **clip_kwargs):
    """
    Clip a georeferenced xarray object to one geometry using ``.rio.clip``.

    Parameters
    ----------
    dask_array : xarray.DataArray or xarray.Dataset
        Object with the rioxarray ``.rio`` accessor. Despite the parameter name,
        it may be Dask-backed or in memory. A bare ``dask.array.Array`` has no
        ``.rio`` accessor and is not accepted.
    bbox_geometry : dict
        GeoJSON-like geometry mapping (e.g. from ``shapely.geometry.mapping``).
    crs : int, str or CRS, optional
        CRS of `bbox_geometry`; default is 4326 (EPSG:4326, lon/lat).
    **clip_kwargs
        Extra keyword arguments forwarded to ``rio.clip``.

    Returns
    -------
    Same type as `dask_array`
        The clipped object.
    """
    return dask_array.rio.clip([bbox_geometry], crs=crs, **clip_kwargs)

In [None]:
# Trim every band mosaic to the study-area geometry (dask_clip's default CRS 4326).
with ProgressBar():
    mosaic_red, mosaic_blue, mosaic_green = (
        dask_clip(band_mosaic, bbox_geometry)
        for band_mosaic in (mosaic_red, mosaic_blue, mosaic_green)
    )

In [None]:
# Write each clipped single-band mosaic to its own GeoTIFF.
with ProgressBar():
    for band_mosaic, var_name, out_path in (
        (mosaic_red, "red", "red_clipped.tif"),
        (mosaic_blue, "blue", "blue_clipped.tif"),
        (mosaic_green, "green", "green_clipped.tif"),
    ):
        band_mosaic[var_name].rio.to_raster(out_path)

In [None]:
# mosaic_red.red.rio.to_raster("red_clipped.tif")
# mosaic_blue.red.rio.to_raster("blue_clipped.tif")
# mosaic_green.red.rio.to_raster("green_clipped.tif")

In [None]:
# Combine the clipped red, green and blue mosaics into one Dataset.
rgb_mosaics = [mosaic_red, mosaic_green, mosaic_blue]
final_mosaic = xr.merge(rgb_mosaics)

In [None]:
# Write a 3-band RGB GeoTIFF. Dataset.rio.to_raster stacks the data variables
# into a new band dimension, so:
# - select red/green/blue explicitly to fix the band order;
# - drop a length-1 "band" dimension if present (rioxarray.open_rasterio adds
#   one). Otherwise the stacked array is 4-D and cannot be written as a raster.
rgb_mosaic = final_mosaic[["red", "green", "blue"]]
if "band" in rgb_mosaic.dims:
    rgb_mosaic = rgb_mosaic.squeeze("band", drop=True)
rgb_mosaic.rio.to_raster('final_mosaic.tif')