

# Advanced Xarray

## Description
The last notebook [04_xarrayI_data_structure.ipynb](04_xarrayI_data_structure.ipynb) gave a first introduction to working with `xarray`.

In this notebook, we deepen the understanding of `xarray` as a container for remote sensing raster data and introduce additional `xarray` functions that are useful for analysis workflows.


## Setup
We will use `pystac-client` to search the Microsoft Planetary Computer STAC catalog and `odc-stac` (`stac_load`) to load the requested data into an `xarray.Dataset`. We use `NumPy` and `xarray` for the analysis steps.


In [None]:
import pandas as pd
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt

from pystac_client import Client
import planetary_computer as pc
from odc.stac import stac_load

# Set config for displaying tables nicely
pd.set_option("display.max_colwidth", 200)
pd.set_option("display.max_rows", None)


### STAC search and load data
First, we search the Planetary Computer STAC catalog and load an example dataset using `odc-stac`.


In [None]:
# Load data from Planetary Computer (STAC)
STAC_URL = "https://planetarycomputer.microsoft.com/api/stac/v1"
COLLECTION = "sentinel-2-l2a"

# Area of interest: Würzburg (EPSG:4326)
bbox = (9.88, 49.75, 10.0, 49.82)

# Output grid
crs = "EPSG:32632"
resolution = 20

# Search STAC items
catalog = Client.open(STAC_URL)

datetime = "2021-03-01/2021-06-15"
query = {"eo:cloud_cover": {"lt": 40}}

search = catalog.search(collections=[COLLECTION], bbox=bbox, datetime=datetime, query=query)
items = list(search.items())  # get_items() is deprecated in recent pystac-client versions
len(items)

# Load pixels with odc-stac
bands = ["B02", "B03", "B04", "B08"]
resampling = {"*": "bilinear"}

ds_raw = stac_load(
    items,
    bands=bands,
    crs=crs,
    resolution=resolution,
    groupby="solar_day",
    patch_url=pc.sign,
    dtype="uint16",
    nodata=0,
    resampling=resampling,
)

# Rename to match the variable names used throughout this notebook
rename_map = {"B02": "blue", "B03": "green", "B04": "red", "B08": "nir"}
ds = ds_raw.rename({k: v for k, v in rename_map.items() if k in ds_raw.data_vars})

# Scale reflectance (Sentinel-2 L2A) to ~0..1
for name in list(ds.data_vars):
    if name != "scl":
        ds[name] = ds[name].astype("float32") * 1e-4

ds


<a id='index_array3'></a>
## **Advanced Indexing**
### 1) Temporal Subset

In the earlier tutorial, we introduced `isel()` and `sel()` for indexing data. Both methods also accept **slices**: if a `slice()` object is passed to the indexing method, a contiguous range is selected along that dimension. 
The first example slices by position to select the first five scenes in `ds`. The start value (here, 0) is included and the stop value (here, 5) is excluded.

#### I. Using index number

In [14]:
ds.isel(time=slice(0,5))
#ds.isel(time = [0,1,2,3,4])

In [15]:
ds.isel(time=slice(0,5)).time

#### II. Using `datetime64` data

This example slices by label to select the scenes between "2021-03-01" and "2021-03-10". Note that when using `slice()` with the `sel()` method, both the start and the stop value are included.

In [8]:
print(ds.sel(time=slice("2021-03-01","2021-03-10"))) 

<xarray.Dataset>
Dimensions:      (time: 4, y: 905, x: 977)
Coordinates:
  * time         (time) datetime64[ns] 2022-03-02T10:19:41.024000 ... 2022-03...
  * y            (y) float64 1.558e+07 1.558e+07 ... 1.557e+07 1.557e+07
  * x            (x) float64 -3.002e+05 -3.002e+05 ... -2.905e+05 -2.904e+05
    spatial_ref  int32 32734
Data variables:
    blue         (time, y, x) uint16 9640 9624 9536 9552 ... 1350 1365 1336 1349
    green        (time, y, x) uint16 8992 8872 8896 8960 ... 1422 1440 1428 1430
    red          (time, y, x) uint16 8488 8448 8408 8416 ... 1516 1509 1479 1491
    nir          (time, y, x) uint16 8768 8720 8656 8616 ... 2476 2438 2342 2402
Attributes:
    crs:           EPSG:32734
    grid_mapping:  spatial_ref


In [16]:
ds.sel(time=slice("2021-03-01","2021-03-10")).time
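
The difference between the two slicing rules can be checked on a small example. The following is a minimal sketch on a hypothetical toy time series (`toy` is not part of the notebook data); it only illustrates that positional slices exclude the stop index while label slices include the stop label.

```python
import numpy as np
import pandas as pd
import xarray as xr

# Hypothetical toy series: 10 daily time steps holding the values 0..9
times = pd.date_range("2021-03-01", periods=10, freq="D")
toy = xr.DataArray(np.arange(10), coords={"time": times}, dims="time")

# Positional slicing: the stop index 5 is excluded (indices 0..4)
by_position = toy.isel(time=slice(0, 5))

# Label slicing: both labels are included (2021-03-01 up to and including 2021-03-06)
by_label = toy.sel(time=slice("2021-03-01", "2021-03-06"))

print(by_position.sizes["time"], by_label.sizes["time"])
```

Although index 5 corresponds to the label "2021-03-06", only the label-based slice contains that time step.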

#### III. Using other time dimensions

`xarray` also offers useful tools for inspecting the time dimension through the `.dt` accessor, which extracts additional date information from a dataset efficiently. The following code assigns each time step to a season ("DJF", "MAM", "JJA", "SON"). Many other datetime components are available through `.dt`, e.g., `month`, `weekday`, `dayofyear` (in recent `xarray` versions, the ISO week is available via `dt.isocalendar().week`).

In [17]:
ds.time.dt.season

In [18]:
ds.time.dt.month

In [19]:
ds.time.dt.weekday

It is also possible to extract the "day of year" for a time step.

In [20]:
ds.time.dt.dayofyear

In [21]:
ds.groupby('time.season')

DatasetGroupBy, grouped over 'season'
2 groups with labels 'JJA', 'MAM'.

In [25]:
#ds.groupby('time.season').mean()

<bound method DatasetGroupByReductions.mean of DatasetGroupBy, grouped over 'season'
2 groups with labels 'DJF', 'SON'.>
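
To make the link between the `.dt` accessor and `groupby()` concrete, here is a small sketch on a hypothetical four-step time series (the dates and values are made up for illustration). It derives the season of each time step and then averages all time steps that fall into the same season.

```python
import pandas as pd
import xarray as xr

# Hypothetical toy series with one winter, two spring and one summer date
times = pd.to_datetime(["2021-01-15", "2021-04-15", "2021-05-15", "2021-07-15"])
toy = xr.DataArray([1.0, 2.0, 4.0, 8.0], coords={"time": times}, dims="time")

# Season label of every time step
seasons = toy.time.dt.season.values

# Mean of all time steps per season; the result has a new "season" dimension
seasonal_mean = toy.groupby("time.season").mean()

print(seasons)
print(seasonal_mean.sel(season="MAM").item())
```

Note that `mean()` has to be called with parentheses; without them, only the bound method is returned and nothing is computed.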

### 2) Spatial Subset
It is possible to index and **slice within the x and y dimensions**. Because positional indices start at 0, the following example selects, for all bands and time steps, the pixel in the third column and the sixth row of the raster (`x=2, y=5`).

In [22]:
ds.isel(x=2, y=5)
#ds.isel(x=[0,1,2], y=5)

### 3) Combining Temporal and Spatial Subset

We can subset temporally and spatially at the same time using `slice()`. If you know the actual coordinate values (x, y) of the extent of the spatial subset, use the `sel()` method instead of `isel()`.

The following example subsets the `ds` by the temporal and spatial location of the pixels. Only the pixels from the first to the fifth columns and the pixels from the first to the fifth rows are included in the output. Also, the scenes are filtered in the time dimension between the first and fifth time step.

In [23]:
ds2 = ds.isel(time=slice(0, 5), x=slice(0, 5), y=slice(0, 5))
ds2

#ds2.time
#plt.scatter(ds2.x.values, ds2.y.values)
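
For a coordinate-based spatial subset with `sel()`, the order of the slice bounds matters. The sketch below uses a hypothetical 5 x 5 grid with made-up projected coordinates; as in many north-up rasters, `y` is assumed to be stored in descending order, so the `y` slice has to run from the larger to the smaller value.

```python
import numpy as np
import xarray as xr

# Hypothetical 5 x 5 grid with 20 m spacing: x ascending, y descending
x = 500000.0 + 20.0 * np.arange(5)
y = 5500080.0 - 20.0 * np.arange(5)
toy = xr.DataArray(np.zeros((5, 5)), coords={"y": y, "x": x}, dims=("y", "x"))

# Label-based subset: slice bounds follow the order of each coordinate
subset = toy.sel(x=slice(500020, 500060), y=slice(5500060, 5500020))

# Reversing the y bounds on a descending coordinate selects nothing
wrong_order = toy.sel(y=slice(5500020, 5500060))

print(subset.shape, wrong_order.sizes["y"])
```

If a subset unexpectedly comes back empty, checking whether `ds.y` is ascending or descending is a good first step.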

## **Data Manipulation & Statistics**

This notebook presents some basic built-in functions of the `xarray` library to manipulate and transform data in an `xarray.Dataset`. Here, we show only a fraction of the available `xarray` functions. For a complete overview of all the available functions and tools of the `xarray` package, please visit the [documentation website](http://xarray.pydata.org/en/stable/). 

[Notebook 07](07_basic_analysis.ipynb) will cover this topic, focusing on an application-oriented remote sensing approach.

### 1) Statistical Operation

The built-in functions allow the user to do simple calculations with an `xarray.Dataset`.
The **basic math** built-in `xarray` functions are:
* `min()`, `max()`
* `mean()`, `median()`
* `sum()`
* `std()`

The following code demonstrates the easy use of the `max()` function to extract the maximum value of the red band in the `ds` dataset.

In [24]:
print(ds.red.max())

<xarray.DataArray 'red' ()>
array(19440, dtype=uint16)
Coordinates:
    spatial_ref  int32 32734


To reduce the data along specific dimensions only, pass the dimension label(s) to the `dim` argument of the function. The named dimensions are collapsed, while all other dimensions are kept.

This example calculates the spatial mean of the `red` band for every time step by averaging over all pixels (dimensions `x` and `y`). The result is a one-dimensional data array along `time` that can be used for further time series visualization and analysis.

In [25]:
print(ds.red.mean(dim=["x", "y"]))

#ds.red.mean(dim=["x", "y"]).values
#plt.plot(ds.red.mean(dim=["x", "y"]).values)

<xarray.DataArray 'red' (time: 42)>
array([ 1818.96105227,  8074.00635727,  1754.47787058,  5327.82408659,
        5052.20414958,  7466.13877413,  6652.25770399,  1943.12213507,
        5791.55179063,  2165.9254138 ,  6284.90501535,  1836.40194077,
        4318.13241799,  2913.1250926 , 10243.98031181,  2596.53914848,
        2306.90195604,  7080.55823498, 10008.62734835,  9668.67531795,
        3243.49281542,  1834.18206823,  1859.24376347,  8750.22720245,
        6678.98628002,  7958.82580229,  9407.54412708,  1739.37265052,
        6333.50257243,  6604.30961055,  3084.24980632,  8654.94368712,
        3030.41774063,  6430.65277289,  9552.01161069,  1698.20226649,
        1879.8146112 ,  9411.29344425,  8475.92199596,  1910.32194733,
        1629.76212444,  2132.29992705])
Coordinates:
  * time         (time) datetime64[ns] 2021-03-02T10:18:39.025000 ... 2021-06...
    spatial_ref  int32 32734


This example works the other way around. It calculates the standard deviation of every pixel (`x`, `y`) over all time steps of the dataset `ds`.

In [26]:
print(ds.red.std(dim="time"))

<xarray.DataArray 'red' (y: 905, x: 977)>
array([[2950.13795135, 3087.9106436 , 3190.4511719 , ..., 3660.84028099,
        3649.89349835, 3641.69284818],
       [2981.50697802, 3094.12890492, 3183.76733881, ..., 3678.63263811,
        3657.9010849 , 3639.44330671],
       [2897.76457591, 2874.01064669, 3061.91223283, ..., 3680.94274286,
        3674.93537613, 3657.93154202],
       ...,
       [3267.92346919, 3279.36118553, 3303.54414343, ..., 3657.32528726,
        3648.90419461, 3617.15589938],
       [3263.62266754, 3275.78591468, 3317.94272277, ..., 3667.18960309,
        3662.70165187, 3640.87214865],
       [3273.34765963, 3271.6850201 , 3323.72091445, ..., 3678.39480058,
        3666.36207741, 3660.46720002]])
Coordinates:
  * y            (y) float64 1.558e+07 1.558e+07 ... 1.557e+07 1.557e+07
  * x            (x) float64 -3.002e+05 -3.002e+05 ... -2.905e+05 -2.904e+05
    spatial_ref  int32 32734


Remember, to access the raw `numpy` array that stores the values of a resulting `xarray.DataArray`, the attribute `.values` is needed. This allows you to work with the "actual" data values.

In [27]:
print(ds.blue.sum(dim=["x","y"]).values)
#plt.plot(ds.blue.sum(dim=["x","y"]).values)

[ 1378819884  7964658670  1272194029  4698605812  4541559804  6886255640
  6110182240  1482596331  5194190235  1712707186  6006349540  1342901595
  4120623844  2508944545 10428593244  2162029094  1887043287  6769489264
  8975718324  8915725986  2999159582  1392269227  1379476980  9055516168
  5978472916  7119185436  8538318172  1328425136  5701259768  6387376838
  2870126501  8182499220  2881210000  5764934137  9266444432  1380615879
  1568188829  9715135166  7933586832  1631484539  1335706615  1833185628]
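
The effect of the `dim` argument is easiest to see from the dimensions of the result. The following sketch uses a hypothetical random array with 4 time steps and 3 x 2 pixels; it is only meant to show which dimensions are collapsed and which remain.

```python
import numpy as np
import xarray as xr

# Hypothetical cube: 4 time steps, 3 rows (y), 2 columns (x)
rng = np.random.default_rng(0)
toy = xr.DataArray(rng.random((4, 3, 2)), dims=("time", "y", "x"))

# Collapse x and y: one spatial mean per time step
spatial_mean = toy.mean(dim=["x", "y"])

# Collapse time: one standard deviation per pixel
temporal_std = toy.std(dim="time")

print(spatial_mean.dims, temporal_std.dims)
```

The dimensions named in `dim` disappear from the result, which is why the spatial mean is a time series and the temporal standard deviation is a map.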


### 2) Conditional Operation

Conditional operations are very helpful when we only want to analyze the satellite scenes or pixels that meet certain criteria. The `where()` function provides the option to **mask** an `xarray.Dataset` based on a logical condition. By default, the function keeps all values where the condition is `True` and converts all values where the condition is `False` to NaN. This is extremely useful when applied with a binary mask to restrict your data to the desired values. The argument `other` lets you define a replacement value for all values that do not match the condition (default is NaN). With `drop=True`, coordinate labels for which the condition is `False` everywhere are dropped from the result.
The following example masks the dataset `ds` to only the pixels with a reflectance greater than 0.07 in the `red` band (equivalent to 700 in the unscaled digital numbers, since `ds` was multiplied by 1e-4 in the setup).

In [28]:
print(ds.where(ds.red > 0.07))
#print(ds.where(ds.red < 0.07))

<xarray.Dataset>
Dimensions:      (time: 42, y: 905, x: 977)
Coordinates:
  * time         (time) datetime64[ns] 2021-03-02T10:18:39.025000 ... 2021-06...
  * y            (y) float64 1.558e+07 1.558e+07 ... 1.557e+07 1.557e+07
  * x            (x) float64 -3.002e+05 -3.002e+05 ... -2.905e+05 -2.904e+05
    spatial_ref  int32 32734
Data variables:
    blue         (time, y, x) float32 1.938e+03 1.691e+03 ... 1.816e+03
    green        (time, y, x) float32 2.118e+03 1.872e+03 ... 1.927e+03
    red          (time, y, x) float32 2.274e+03 2.032e+03 ... 1.767e+03
    nir          (time, y, x) float32 3.493e+03 3.307e+03 ... 4.284e+03
Attributes:
    crs:           EPSG:32734
    grid_mapping:  spatial_ref


This code replaces all zeros in the red band of the first time step of `ds` with the new value -9999.

In [29]:
# Select the first scene once so that data and condition share the same shape
red_first = ds.red.isel(time=0)
replace = red_first.where(red_first != 0, other=-9999)
#replace.values.min()
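
The three behaviours of `where()` described above can be compared side by side. This is a minimal sketch on a hypothetical 2 x 2 array with made-up values, not on the Sentinel-2 data.

```python
import numpy as np
import xarray as xr

# Hypothetical 2 x 2 reflectance patch containing two zeros
toy = xr.DataArray(
    [[0.0, 0.2], [0.5, 0.0]],
    coords={"y": [1, 0], "x": [0, 1]},
    dims=("y", "x"),
)

# Default: values where the condition is False become NaN
masked = toy.where(toy > 0)

# other: values where the condition is False get a replacement value
filled = toy.where(toy > 0, other=-9999)

# drop=True: rows/columns where the condition is False everywhere are removed
trimmed = toy.where(toy > 0.4, drop=True)

print(masked.values)
print(filled.values)
print(trimmed.shape)
```

In all three cases, the values for which the condition is `True` are kept unchanged.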

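The `xarray` method `isin()` allows us to **test each value** of an `xarray.Dataset` or `xarray.DataArray` for membership in a given set of elements. It returns a boolean array which can be used as a mask.
This example checks for every value of the `red` band whether it is contained in `range(550)`, i.e., whether it equals one of the integers from 0 to 549. Note that `isin()` tests for exact equality, so it is best suited to integer or categorical data; for float data, such as the reflectance scaled in the setup, a comparison like `ds.red < 0.055` expresses a value range more reliably.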

In [42]:
mask_red = ds.red.isin(range(550))
print(mask_red)

#plt.imshow(mask_red) #error
#plt.imshow(mask_red.isel(time=3))

<xarray.DataArray 'red' (time: 21, y: 1031, x: 1010)>
array([[[False, False, False, ..., False, False, False],
        [False, False, False, ..., False, False, False],
        [False, False, False, ..., False, False, False],
        ...,
        [False, False, False, ..., False, False, False],
        [False, False, False, ..., False, False, False],
        [False, False, False, ..., False, False, False]],

       [[False, False, False, ..., False, False, False],
        [False, False, False, ..., False, False, False],
        [False, False, False, ..., False, False, False],
        ...,
        [False, False, False, ..., False, False, False],
        [False, False, False, ..., False, False, False],
        [False, False, False, ..., False, False, False]],

       [[False, False, False, ..., False, False, False],
        [False, False, False, ..., False, False, False],
        [False, False, False, ..., False, False, False],
        ...,
...
        ...,
        [False, False, False, .

The created mask can easily be combined with the `where()` function to filter the dataset. In this case, the `ds` dataset is masked with the previously defined mask `mask_red`, which is based on a logical test of whether the values of the `red` band are contained in a specific set of values.

In [43]:
print(ds.where(mask_red)) #masking

<xarray.Dataset>
Dimensions:      (time: 21, y: 1031, x: 1010)
Coordinates:
  * time         (time) datetime64[ns] 2020-10-01T08:28:17 ... 2020-12-30T08:...
  * y            (y) float64 6.807e+06 6.807e+06 ... 6.797e+06 6.797e+06
  * x            (x) float64 8.687e+05 8.687e+05 ... 8.787e+05 8.788e+05
    spatial_ref  int32 32734
Data variables:
    blue         (time, y, x) float64 nan nan nan nan nan ... nan nan nan nan
    green        (time, y, x) float64 nan nan nan nan nan ... nan nan nan nan
    red          (time, y, x) float64 nan nan nan nan nan ... nan nan nan nan
    nir          (time, y, x) float64 nan nan nan nan nan ... nan nan nan nan
Attributes:
    crs:           EPSG:32734
    grid_mapping:  spatial_ref
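
A typical use of `isin()` is masking by class codes. The sketch below assumes a hypothetical integer classification layer (`classes`) that matches a small reflectance array (`values`); the class codes 4, 5 and 6 are chosen arbitrarily for illustration and are not taken from the notebook data.

```python
import numpy as np
import xarray as xr

# Hypothetical integer class codes and matching reflectance values
classes = xr.DataArray([[4, 5, 8], [9, 4, 6]], dims=("y", "x"))
values = xr.DataArray([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dims=("y", "x"))

# True where the class code is one of the codes we want to keep
keep = classes.isin([4, 5, 6])

# Apply the boolean mask: all other pixels become NaN
kept_values = values.where(keep)

print(keep.values)
print(kept_values.values)
```

Because the class codes are integers, the exact-equality test performed by `isin()` is reliable here.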


### 3) Resampling
Resampling is useful when working with time-series data that should be aggregated to, or aligned with, regular time windows (for example, monthly composites).

 - **resample()**

The **`resample()` method** allows us to summarise the `xarray.Dataset` into coarser or finer time intervals along a datetime-like dimension. It handles both upsampling and downsampling. The keyword passed to `resample()` is the name of that dimension (here `time`), and its value is a frequency string. In the following example, we resample the `ds` dataset to a monthly time interval and then calculate the median value for every month. The month-end alias is `"ME"` in pandas 2.2 and later; older versions use `"M"` instead. _(this process takes a little while to run)_

In [30]:
# "ME" = month end (pandas >= 2.2); use "M" with older pandas versions
print(ds.resample(time="ME").median())

<xarray.Dataset>
Dimensions:      (y: 905, x: 977, time: 4)
Coordinates:
  * y            (y) float64 1.558e+07 1.558e+07 ... 1.557e+07 1.557e+07
  * x            (x) float64 -3.002e+05 -3.002e+05 ... -2.905e+05 -2.904e+05
    spatial_ref  int32 32734
  * time         (time) datetime64[ns] 2021-03-31 2021-04-30 ... 2021-06-30
Data variables:
    blue         (time, y, x) float64 2.828e+03 2.754e+03 ... 1.532e+03
    green        (time, y, x) float64 2.914e+03 2.85e+03 ... 1.723e+03 1.724e+03
    red          (time, y, x) float64 2.93e+03 2.882e+03 ... 1.475e+03 1.478e+03
    nir          (time, y, x) float64 3.782e+03 3.774e+03 ... 6.048e+03
Attributes:
    crs:           EPSG:32734
    grid_mapping:  spatial_ref
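
The monthly aggregation can be followed by hand on a short series. The sketch below uses a hypothetical five-step time series with made-up values. As an assumption of convenience, it uses the month-start alias `"MS"`, which labels each bin with the first day of the month and is accepted by both older and newer pandas versions.

```python
import pandas as pd
import xarray as xr

# Hypothetical series: three scenes in March, two in April
times = pd.to_datetime(
    ["2021-03-02", "2021-03-12", "2021-03-22", "2021-04-01", "2021-04-11"]
)
toy = xr.DataArray([1.0, 2.0, 9.0, 4.0, 6.0], coords={"time": times}, dims="time")

# One median per calendar month, labelled with the month start ("MS")
monthly = toy.resample(time="MS").median()

print(monthly.time.values)
print(monthly.values)
```

The March value is the median of 1, 2 and 9, and the April value is the median of 4 and 6.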


 - **groupby() method**

The **`groupby()` method** can also be used within the `xarray` library to *aggregate data over time*. Time aggregation arguments can be, e.g., "time.year", "time.season", "time.month", "time.dayofyear", "time.day".
The code below groups the `ds` dataset by year. A new dimension `year` is created, and the median over all time steps of each `year` is calculated. Because our example dataset only covers 2021, the result contains a single group. _(this process takes a little while to run)_

In [31]:
print(ds.groupby("time.year").median(dim="time"))

<xarray.Dataset>
Dimensions:      (y: 905, x: 977, year: 1)
Coordinates:
  * y            (y) float64 1.558e+07 1.558e+07 ... 1.557e+07 1.557e+07
  * x            (x) float64 -3.002e+05 -3.002e+05 ... -2.905e+05 -2.904e+05
    spatial_ref  int32 32734
  * year         (year) int64 2021
Data variables:
    blue         (year, y, x) float64 4.183e+03 3.942e+03 ... 3.946e+03
    green        (year, y, x) float64 4.252e+03 4.079e+03 ... 3.842e+03
    red          (year, y, x) float64 4.073e+03 3.961e+03 ... 3.747e+03
    nir          (year, y, x) float64 4.974e+03 5.038e+03 ... 5.954e+03
Attributes:
    crs:           EPSG:32734
    grid_mapping:  spatial_ref
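
Unlike `resample()`, `groupby()` pools all time steps that share the same label, even across years. The following sketch shows this on a hypothetical series spanning two years (dates and values are made up).

```python
import pandas as pd
import xarray as xr

# Hypothetical series: three scenes in 2021 and one in 2022
times = pd.to_datetime(["2021-03-02", "2021-03-12", "2021-04-01", "2022-03-05"])
toy = xr.DataArray([1.0, 3.0, 10.0, 5.0], coords={"time": times}, dims="time")

# One median per calendar year
by_year = toy.groupby("time.year").median(dim="time")

# One median per month number; March 2021 and March 2022 share a group
by_month = toy.groupby("time.month").median(dim="time")

print(by_year.year.values, by_year.values)
print(by_month.month.values, by_month.values)
```

The `time` dimension is replaced by `year` or `month` in the result, so the output no longer carries individual acquisition dates.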


### 4) Interpolation
Interpolation is a common solution for dealing with missing remote sensing data, caused, for example, by the coarse temporal resolution of the satellite, high cloud cover, or bad quality of the scenes. As a result, a scene for a specific date may not be available in the dataset. With the `interp()` method, it is possible to **interpolate data** for predefined time steps. The method takes the closest available scenes before and after the specified date and interpolates between their values (by default, the interpolation method is "linear") to build a new `xarray.Dataset`.

In this example, the `ds` dataset has no scene at the time step "2021-06-10". The `interp()` method builds a "new" scene based on a linear interpolation between the two scenes before and after the new time step.

In [32]:
print(ds.time)

<xarray.DataArray 'time' (time: 42)>
array(['2021-03-02T10:18:39.025000000', '2021-03-05T10:28:09.024000000',
       '2021-03-07T10:20:21.024000000', '2021-03-10T10:30:21.024000000',
       '2021-03-12T10:17:29.024000000', '2021-03-15T10:27:09.024000000',
       '2021-03-17T10:20:21.024000000', '2021-03-20T10:30:21.024000000',
       '2021-03-22T10:16:49.024000000', '2021-03-25T10:26:39.024000000',
       '2021-03-27T10:20:21.024000000', '2021-03-30T10:30:21.024000000',
       '2021-04-01T10:15:59.024000000', '2021-04-04T10:25:59.024000000',
       '2021-04-06T10:20:21.025000000', '2021-04-09T10:30:21.024000000',
       '2021-04-11T10:15:59.024000000', '2021-04-14T10:25:59.024000000',
       '2021-04-16T10:20:21.024000000', '2021-04-19T10:30:21.024000000',
       '2021-04-21T10:15:49.024000000', '2021-04-24T10:25:49.024000000',
       '2021-04-26T10:20:21.024000000', '2021-04-29T10:30:21.024000000',
       '2021-05-01T10:15:59.024000000', '2021-05-04T10:25:59.025000000',
       '2021-0

In [34]:
ds_interp = ds.interp(time=["2021-06-10"])
print(ds_interp)

<xarray.Dataset>
Dimensions:      (y: 905, x: 977, time: 1)
Coordinates:
  * y            (y) float64 1.558e+07 1.558e+07 ... 1.557e+07 1.557e+07
  * x            (x) float64 -3.002e+05 -3.002e+05 ... -2.905e+05 -2.904e+05
    spatial_ref  int32 32734
  * time         (time) datetime64[ns] 2021-06-10
Data variables:
    blue         (time, y, x) float64 4.545e+03 4.421e+03 ... 2.347e+03
    green        (time, y, x) float64 4.611e+03 4.519e+03 ... 2.374e+03
    red          (time, y, x) float64 4.595e+03 4.406e+03 ... 2.222e+03
    nir          (time, y, x) float64 5.99e+03 6.351e+03 ... 3.856e+03 4.634e+03
Attributes:
    crs:           EPSG:32734
    grid_mapping:  spatial_ref
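
The arithmetic behind a linear interpolation in time can be written out explicitly. The sketch below does not call `interp()` itself; it reproduces the weighted average between the scene before and the scene after a target date on a hypothetical two-scene array (dates and values are made up), which is the computation a linear interpolation between two time steps performs.

```python
import numpy as np
import pandas as pd
import xarray as xr

# Hypothetical data: two scenes, two pixels each
times = pd.to_datetime(["2021-06-08", "2021-06-12"])
toy = xr.DataArray(
    [[0.10, 0.20], [0.30, 0.40]],
    coords={"time": times},
    dims=("time", "x"),
)

target = pd.Timestamp("2021-06-11")

# Scenes before and after the target date (scalar time coordinate dropped)
before = toy.isel(time=0, drop=True)
after = toy.isel(time=1, drop=True)

# Fraction of the time interval that has passed at the target date (0..1)
weight = (target - times[0]) / (times[1] - times[0])

# Linear interpolation: move that fraction of the way from "before" to "after"
manual = before + weight * (after - before)

print(weight, manual.values)
```

The closer the target date lies to one of the two scenes, the more strongly that scene influences the interpolated values.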


The `merge()` function allows us to **merge/join** `xarray.Datasets` or variables. By default, the `merge()` function uses an "outer" join, so the result contains the union of the coordinate labels of both inputs. 
In our example, the interpolated `xarray.Dataset` created above is merged into the `ds` dataset using the `merge()` function, which adds the new time step to the existing ones.

In [35]:
print(ds.merge(ds_interp).time)

<xarray.DataArray 'time' (time: 43)>
array(['2021-03-02T10:18:39.025000000', '2021-03-05T10:28:09.024000000',
       '2021-03-07T10:20:21.024000000', '2021-03-10T10:30:21.024000000',
       '2021-03-12T10:17:29.024000000', '2021-03-15T10:27:09.024000000',
       '2021-03-17T10:20:21.024000000', '2021-03-20T10:30:21.024000000',
       '2021-03-22T10:16:49.024000000', '2021-03-25T10:26:39.024000000',
       '2021-03-27T10:20:21.024000000', '2021-03-30T10:30:21.024000000',
       '2021-04-01T10:15:59.024000000', '2021-04-04T10:25:59.024000000',
       '2021-04-06T10:20:21.025000000', '2021-04-09T10:30:21.024000000',
       '2021-04-11T10:15:59.024000000', '2021-04-14T10:25:59.024000000',
       '2021-04-16T10:20:21.024000000', '2021-04-19T10:30:21.024000000',
       '2021-04-21T10:15:49.024000000', '2021-04-24T10:25:49.024000000',
       '2021-04-26T10:20:21.024000000', '2021-04-29T10:30:21.024000000',
       '2021-05-01T10:15:59.024000000', '2021-05-04T10:25:59.025000000',
       '2021-0
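
The effect of the join type can be seen on two tiny datasets. This is a minimal sketch with hypothetical datasets `ds_a` and `ds_b` whose time steps do not overlap; `join` and `compat` are passed explicitly so that the behaviour does not depend on version-specific defaults.

```python
import numpy as np
import pandas as pd
import xarray as xr

# Hypothetical datasets with non-overlapping time steps
ds_a = xr.Dataset(
    {"red": ("time", [0.1, 0.3])},
    coords={"time": pd.to_datetime(["2021-06-08", "2021-06-12"])},
)
ds_b = xr.Dataset(
    {"red": ("time", [0.2])},
    coords={"time": pd.to_datetime(["2021-06-10"])},
)

# Outer join: union of all time steps, values taken from whichever input has them
merged_outer = ds_a.merge(ds_b, join="outer", compat="no_conflicts")

# Inner join: only time steps present in both inputs (none in this case)
merged_inner = ds_a.merge(ds_b, join="inner", compat="no_conflicts")

print(merged_outer.red.values)
print(merged_inner.sizes["time"])
```

With the outer join, the new time step is inserted in chronological order between the two existing ones.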

The `xarray` package contains a variety of other useful functions besides those shown here. For more information about the `xarray` package, visit the [documentation website](http://xarray.pydata.org/en/stable/).

***

## Additional information

<font size="2">This notebook is provided for teaching by the [Department of Remote Sensing](http://remote-sensing.org/), [University of Wuerzburg](https://www.uni-wuerzburg.de/startseite/). It has been updated to use Planetary Computer STAC + `odc-stac`. </font>

**License:** The code in this notebook is licensed under the [Apache License, Version 2.0](https://www.apache.org/licenses/LICENSE-2.0).

**Data access:** Sentinel-2 L2A pixels are loaded from the Microsoft Planetary Computer via STAC using `odc-stac`.

**Data license:** See the dataset/collection metadata on Planetary Computer for license and attribution details.
