# `XArray` Introduction

In this notebook, we are going to learn how to create and manipulate `xarray` datasets and data arrays for geospatial data analysis.

Design goals for xarray:

> "pandas for N-dimensional arrays"
- Built on pandas + NumPy + Dask.
- Copied the pandas API.
- Motivated by weather & climate use cases.

We are going to focus on **2 object types** of interest, the `DataArray` & `Dataset`:

- `DataArray`: Represents a single variable (example: `t2m`).
- `Dataset`: A collection of variables. It generalizes `DataArray` to multivariate data.
    - Similar to the relationship between pandas's `Series` (`DataArray`) and `DataFrame` (`Dataset`).
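
The analogy can be sketched with a tiny example: indexing a `Dataset` by variable name gives back a `DataArray`, much like indexing a `DataFrame` by column name gives back a `Series`. The variable names (`t2m`, `tp`) and the values below are made up purely for illustration.

```python
import numpy as np
import pandas as pd
import xarray as xr

# A Dataset holds several named variables that share dimensions
ds_demo = xr.Dataset(
    data_vars={
        "t2m": (("x", "y"), np.zeros((2, 3))),
        "tp": (("x", "y"), np.ones((2, 3))),
    }
)

# A DataFrame holds several named columns that share an index
df_demo = pd.DataFrame({"t2m": [0.0, 0.0], "tp": [1.0, 1.0]})

# Selecting one variable / one column returns the single-variable type
single_var = ds_demo["t2m"]
single_col = df_demo["t2m"]

print(type(single_var).__name__)  # DataArray
print(type(single_col).__name__)  # Series
```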

In [None]:
from pathlib import Path
from datetime import datetime, timedelta

import numpy as np
import pandas as pd
import geopandas as gpd
import xarray as xr
import rioxarray
import matplotlib.pyplot as plt

from shapely.geometry import box
from rasterio.enums import Resampling
from geocube.api.core import make_geocube

# Set the seed
np.random.seed(0)

## `DataArray`

Let's start by creating a data array from synthetically-generated values:

In [None]:
# Create the NumPy array that has the data cube values
# description: 3D array filled with 1s
arr = np.ones((3, 4, 2))
arr

In [None]:
# We set the dimension names (x, y, and z)
dim_names = ("x", "y", "z")

In [None]:
# We also set a name for our data array (represents the variable name. Example: 2m-temperature)
var_name = "var"

In [None]:
# Set the coordinate values for `x` and `z`
# `y` gets no coordinate values: it stays a dimension without coordinates
coords = {
    "x": [0.1, 1.2, 2.3],
    "z": [-1, 1]
}

In [None]:
# We create metadata for the data array
metadata = {"description": "This dataset has been created for demonstrative purposes."}

In [None]:
# Finally, we create the data array using all of the above information
da = xr.DataArray(
    data=arr,
    dims=dim_names,
    name=var_name,
    coords=coords,
    attrs=metadata
)

# Check the data array
da

We can directly get the underlying `data` (returns either a numpy or dask array):

In [None]:
da.data

We can also check the dimensions of the data array:

In [None]:
da.dims

.. and the coordinates:

In [None]:
da.coords

How about the metadata?

In [None]:
da.attrs
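
As a small sketch of what happens to a dimension that is given no coordinate values, the snippet below rebuilds an array with the same layout under a hypothetical name (`da_demo`) and checks where `y` shows up. It assumes nothing beyond the constructor arguments already used above.

```python
import numpy as np
import xarray as xr

# Same layout as above: coordinate values for x and z only
da_demo = xr.DataArray(
    data=np.ones((3, 4, 2)),
    dims=("x", "y", "z"),
    coords={"x": [0.1, 1.2, 2.3], "z": [-1, 1]},
)

# y is a dimension of size 4, but it carries no coordinate values
has_y_dim = "y" in da_demo.dims
has_y_coord = "y" in da_demo.coords

# x has labels, so we can select by value; y has none, so we index it by position
picked = da_demo.sel(x=1.2).isel(y=0)

print(has_y_dim, has_y_coord)  # True False
print(picked.dims)             # ('z',)
```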

Let's create a data array named `height` from random data with dimensions `latitude` and `longitude`:

In [None]:
# Set the seed for the random generator
rng = np.random.default_rng(seed=0)

# Generate the altitude data
arr = rng.random((180, 360)) * 400
arr.shape

In [None]:
# Create the data array
da = xr.DataArray(
    data=arr,
    dims=("latitude", "longitude"),
    coords={
        "latitude": ("latitude", np.linspace(-90, 90, 180), {"type": "geodetic"}),
        "longitude": ("longitude", np.linspace(-180, 180, num=360), {"prime_meridian": "greenwich"})
    },
    name="height",
    attrs={
        "type": "Ellipsoid"
    }
)
da

Finally, let's visualize the data array using `XArray`'s matplotlib capabilities:

In [None]:
_ = da.plot(figsize=(7,5))

## `Dataset`

A `dataset` is a collection of `data arrays`:

In [None]:
ds = xr.Dataset(
    data_vars={
        "a": (("x", "y"), np.ones((3, 4))),                # 2D Array
        "b": ("t", np.full((8,), 3), {"attr": "value"})  # 1D array
    },
    coords={"x": [-1, 0, 1]},
    attrs={"attr": "value"}
)
ds

We can define coordinates of multiple data types:

In [None]:
ds = xr.Dataset(
    data_vars={
        "a": (("x", "y"), np.ones((3, 4))),
        "b": (("t", "x"), np.full((8, 3), 3))
    },
    coords={
        "x": ["a", "b", "c"],                                  # Categories
        "y": np.arange(4),                                     # Numbers
        "t": pd.date_range("2020-07-05", periods=8, freq="D")  # Dates
    },
    attrs={"attr": "value"}
)
ds

*Note*: `XArray` uses `nan` as its default missing value.
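
One place where missing values appear is reindexing onto labels that the original array does not have. The sketch below uses a made-up three-element array (`da_small`) to show this; note that the integer data is promoted to floating point so that it can hold `nan`.

```python
import numpy as np
import xarray as xr

da_small = xr.DataArray([10, 20, 30], dims="x", coords={"x": [0, 1, 2]})

# Label 3 does not exist in the original, so its value becomes NaN
da_reindexed = da_small.reindex(x=[1, 2, 3])

print(da_reindexed.values.tolist())      # [20.0, 30.0, nan]
print(int(da_reindexed.isnull().sum()))  # 1
print(da_reindexed.dtype)                # float64
```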

Let's create a dataset with two variables along `latitude` and `longitude`: `altitude` and `gravity_anomaly`:

In [None]:
# Generate the values of the two variables
alt = rng.random((180, 360)) * 400
gravity_anomaly = rng.random((180, 360)) * 400 - 200

# Create the dataset
ds = xr.Dataset(
    data_vars={
        "altitude": (("latitude", "longitude"), alt, {"ellipsoid": "wgs84"}),
        "gravity_anomaly": (("latitude", "longitude"), gravity_anomaly, {"ellipsoid": "grs80"})
    },
    coords={
        "latitude": ("latitude", np.linspace(-90, 90, num=180), {"type": "geodetic"}),
        "longitude": ("longitude", np.linspace(-180, 180, num=360), {"prime_meridian": "greenwich"})
    }
)
ds

## Data Manipulation

Let's demonstrate how we can manipulate `XArray` objects:

In [None]:
# Create an array
arr = np.random.rand(3, 4)

# Use it to create a data array
da = xr.DataArray(arr, dims=("x", "y"))
da

### `isel`

Similar to Pandas, we can select by index:

In [None]:
da.isel(x=1, y=3)

The same applies to datasets:

In [None]:
# Create a dataset
ds = xr.Dataset(
    data_vars={
        "a": (("x", "y"), np.random.rand(3, 4)),
        "b": (("x", "y"), np.random.rand(3, 4))
    }
)

# Select the variable values at the second x/y
ds.isel(x=1, y=1)

Slicing a data array is also possible:

In [None]:
ds["a"][:2, :1]  # x & y slicing

.. same with datasets:

In [None]:
ds.isel(x=slice(None, 2), y=slice(None, 1))

### `sel`

We use the coordinate values (labels) to select from datasets/arrays directly.

Let's create a data array with coordinates:

In [None]:
da = xr.DataArray(
    np.random.rand(4, 6),
    dims=("x", "t"),
    coords={
        "x": [2, 9.9, 13, 14],
        # "M" = month-end frequency; pandas >= 2.2 prefers the alias "ME"
        "t": pd.date_range("2009-01-05", periods=6, freq="M")
    }
)
da

Select based on the actual grid values:

In [None]:
da.sel(x=9.9, t="2009-01-31")

If we are not sure of the exact values, we can select the nearest ones:

In [None]:
da.sel(x=9.8, t="2009-01-13", method="nearest")

We can also slice by values:

In [None]:
da.sel(x=[9.9, 13], t=slice("2009", "2010"))

We can drop specific labels (the opposite of selecting them):

In [None]:
da.drop_sel(x=[9.9, 13])
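
Because the arrays above hold random values, the effect of each call is hard to read off the output. The sketch below repeats the same selections on a hypothetical array (`da_lbl`) with predictable values, and also shows that an exact lookup of a missing label raises a `KeyError`.

```python
import numpy as np
import xarray as xr

da_lbl = xr.DataArray(
    np.arange(4) * 10,  # values 0, 10, 20, 30
    dims="x",
    coords={"x": [2, 9.9, 13, 14]},
)

exact = da_lbl.sel(x=9.9)                      # exact label match
nearest = da_lbl.sel(x=9.8, method="nearest")  # closest label is 9.9
dropped = da_lbl.drop_sel(x=[9.9, 13])         # remove two labels

# Without method="nearest", a label that is not on the grid is an error
try:
    da_lbl.sel(x=9.8)
    missing_label_failed = False
except KeyError:
    missing_label_failed = True

print(exact.item(), nearest.item())  # 10 10
print(dropped.x.values.tolist())     # [2.0, 14.0]
print(missing_label_failed)          # True
```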

## Reading Data

Next, let's load a real dataset:

In [None]:
ds = xr.tutorial.load_dataset("air_temperature")
ds

Let us slice the data by latitude and longitude:

In [None]:
ds.isel(lat=slice(None, 30), lon=slice(20, 40))

Unlike positional slicing with `isel`, label-based slicing with `sel` includes the rightmost value:

In [None]:
ds.sel(lat=75, time=slice("2013-01-01", "2013-10-15"))
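
The difference between the two kinds of slices can be checked on a small made-up time series (`da_time`, five daily steps). This is only a sketch of the endpoint behavior, not a replacement for the tutorial dataset.

```python
import numpy as np
import pandas as pd
import xarray as xr

times = pd.date_range("2013-01-01", periods=5, freq="D")
da_time = xr.DataArray(np.arange(5), dims="time", coords={"time": times})

# Positional slice: stops before index 3, like NumPy
by_position = da_time.isel(time=slice(0, 3))

# Label slice: the end label 2013-01-04 is included
by_label = da_time.sel(time=slice("2013-01-01", "2013-01-04"))

print(by_position.sizes["time"])  # 3
print(by_label.sizes["time"])     # 4
```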

Masking is similar to pandas: we create a boolean mask and use it to filter the data:

In [None]:
ds.where(ds.lat < 0.0)

*Note*: we avoid assigning to parts of a data array in place because **dask** arrays (used for HPC) are immutable; `where` returns a new object instead.
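
A minimal sketch of the masking pattern on a hypothetical four-point latitude array (`da_lat`): `where` keeps the shape and fills with `nan` by default, accepts a replacement value through `other`, and can trim fully masked labels with `drop=True`. The original array is left untouched in every case.

```python
import numpy as np
import xarray as xr

da_lat = xr.DataArray(
    [1.0, 2.0, 3.0, 4.0],
    dims="lat",
    coords={"lat": [-30, -10, 10, 30]},
)

southern = da_lat.lat < 0  # boolean mask

masked = da_lat.where(southern)             # same shape, NaN where the mask is False
filled = da_lat.where(southern, other=0.0)  # use 0.0 instead of NaN
trimmed = da_lat.where(southern, drop=True) # drop labels that are fully masked

print(filled.values.tolist())       # [1.0, 2.0, 0.0, 0.0]
print(trimmed.lat.values.tolist())  # [-30, -10]
print(da_lat.values.tolist())       # [1.0, 2.0, 3.0, 4.0]
```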

## Computation with `XArray`

In [None]:
ds = xr.open_dataset("./data/raster/ecmwf_forecasts.nc")
ds

Let's select the first forecast and visualize:

In [None]:
_ = ds["t2m"].isel(forecast_time=0).plot(robust=True, figsize=(9, 5))

Let's get the 2m temperature data array:

In [None]:
t2m = ds["t2m"]
t2m

In [None]:
# Kelvin to C
t2m_c = t2m - 273.15  # Automatic broadcasting
t2m_c

We recommend keeping `XArray`'s `DataArray` or `Dataset` objects throughout the session, since NumPy functions applied to them return labeled objects again. Raw NumPy arrays should only be exported at the end of the "data preprocessing" stage:

In [None]:
# Compute a new data array
f = 0.5 * np.log(t2m_c ** 2)
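
To illustrate, here is a sketch on a tiny made-up temperature grid (`temps_k`): the scalar subtraction and the NumPy function both return a `DataArray` with dimensions and coordinates intact, and the raw array is only extracted in the last step.

```python
import numpy as np
import xarray as xr

temps_k = xr.DataArray(
    [[273.15, 283.15], [293.15, 303.15]],
    dims=("latitude", "longitude"),
    coords={"latitude": [10, 20], "longitude": [0, 5]},
    name="t2m",
)

temps_c = temps_k - 273.15  # the scalar is broadcast over every cell
root = np.sqrt(temps_c)     # a NumPy ufunc still returns a DataArray

# Convert to a plain NumPy array only at the very end
raw = root.values

print(type(root).__name__)  # DataArray
print(root.dims)            # ('latitude', 'longitude')
print(type(raw).__name__)   # ndarray
```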

### Reductions

Let's plot the temperature annual averages as a timeseries:

In [None]:
_ = (t2m - 273.15).mean(["latitude", "longitude"]).groupby("forecast_time.year").mean().plot(figsize=(7, 3))

Grid cells on a regular latitude/longitude grid do not cover equal areas (they shrink toward the poles), so we need to weight by area before aggregating:

In [None]:
# Create the weights
weights = np.cos(np.deg2rad(t2m.latitude))
weights

In [None]:
# Automatic broadcasting
(t2m * weights)

In [None]:
# ... however, this is better!
t2m_weighted = (t2m - 273.15).weighted(weights)
t2m_weighted

In [None]:
# Plot the area-weighted mean per year
_ = t2m_weighted.mean(["latitude", "longitude"])\
    .groupby("forecast_time.year")\
    .mean().plot(figsize=(7, 3))
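
The effect of the weights can be verified by hand on a hypothetical two-latitude grid (`field`): the row at 60 degrees gets half the weight of the row at the equator, so the weighted mean moves toward the equatorial value. The manual line recomputes `sum(w * x) / sum(w)` as a cross-check.

```python
import numpy as np
import xarray as xr

field = xr.DataArray(
    [[10.0, 10.0], [20.0, 20.0]],
    dims=("latitude", "longitude"),
    coords={"latitude": [0.0, 60.0], "longitude": [0.0, 90.0]},
)

# cos(latitude): 1.0 at the equator, 0.5 at 60 degrees
weights = np.cos(np.deg2rad(field.latitude))

plain_mean = field.mean().item()
weighted_mean = field.weighted(weights).mean(("latitude", "longitude")).item()

# Manual check: each latitude weight applies to every longitude in that row
n_lon = field.sizes["longitude"]
manual = ((field * weights).sum() / (weights.sum() * n_lon)).item()

print(plain_mean)               # 15.0
print(round(weighted_mean, 2))  # 13.33
```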

We can also plot the time-series for a specific location of interest:

In [None]:
_ = t2m.isel(latitude=30, longitude=40).plot(figsize=(7, 3))

### `GroupBy`

Let's calculate the monthly climatology:

In [None]:
# Calculate climatology
clim = t2m.groupby("forecast_time.month").mean("forecast_time")

# Visualize the seasonal cycle for a location of interest
_ = clim.isel(latitude=30, longitude=40).plot(figsize=(5, 2))

Let's take a look at the `t2m` seasonal cycle by latitude:

In [None]:
_ = clim.mean("longitude").plot(x="month", y="latitude", levels=15)

Let's conduct a transformation to remove the seasonal climatology:

In [None]:
# Remove the monthly climatology
t2m_anoms = t2m.groupby("forecast_time.month").map(lambda grp: grp - grp.mean("forecast_time"))

# Visualize for a single location
_ = t2m_anoms.isel(latitude=30, longitude=40).plot(figsize=(5, 3))
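
The climatology and anomaly steps can be checked on a synthetic series where the answer is known in advance. The sketch below assumes two years of made-up monthly values (`series`): a seasonal cycle of 0 to 11 plus an offset of 2 in the second year. It uses arithmetic between a `groupby` object and the climatology, which is an alternative to the `map` call above.

```python
import numpy as np
import pandas as pd
import xarray as xr

# Two years of monthly values: seasonal cycle 0..11, second year shifted by +2
time = pd.date_range("2000-01-01", periods=24, freq="MS")
seasonal = np.tile(np.arange(12.0), 2)
offset = np.repeat([0.0, 2.0], 12)
series = xr.DataArray(seasonal + offset, dims="time", coords={"time": time})

# Monthly climatology: the mean of each calendar month over both years
clim_demo = series.groupby("time.month").mean("time")

# Anomalies: subtract the matching month's climatology from every time step
anoms_demo = series.groupby("time.month") - clim_demo

print(clim_demo.values.tolist()[:3])   # [1.0, 2.0, 3.0]
print(anoms_demo.values.tolist()[:2])  # [-1.0, -1.0]
print(anoms_demo.values.tolist()[-2:]) # [1.0, 1.0]
```

The seasonal cycle is gone from the anomalies; only the year-to-year offset remains.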

We can also use `resample` (like Pandas):

In [None]:
# "1Y" bins by calendar year (year-end labels); pandas >= 2.2 prefers the alias "YE"
t2m_anoms_1y = t2m_anoms.resample(forecast_time="1Y").mean("forecast_time")
_ = t2m_anoms_1y.plot(col="forecast_time", col_wrap=4)
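
As a sketch of what `resample` does, the example below bins a made-up daily series (`daily`, all ones) into calendar months and sums each bin, so each result is simply the number of days in that month. It uses the month-start alias `"MS"` only for illustration.

```python
import numpy as np
import pandas as pd
import xarray as xr

# 59 daily steps: January (31 days) followed by February 2021 (28 days)
times = pd.date_range("2021-01-01", periods=59, freq="D")
daily = xr.DataArray(np.ones(59), dims="time", coords={"time": times})

# Group the time axis into month-start bins and aggregate each bin
monthly_sum = daily.resample(time="MS").sum()

print(monthly_sum.values.tolist())  # [31.0, 28.0]
```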

## Visualization

Calling `.plot()` on an array with more than two dimensions draws a histogram, which we use to visualize variable distributions:

In [None]:
_ = ds["t2m"].plot(bins=50, histtype="step", stacked=True, fill=False, figsize=(5, 3))

We can also visualize 2D arrays:

In [None]:
_ = ds["t2m"].isel(forecast_time=0).plot(
    robust=True,
    cbar_kwargs={"label": "2 meter temperature [Kelvin]",
                 "orientation": "horizontal"},
    figsize=(7, 4)
)

We can change the plot type, for example to contour lines:

In [None]:
_ = ds["t2m"].isel(forecast_time=0).plot.contour()

Let's visualize 2m-temperature longitude variation for 3 latitude values of interest:

In [None]:
_ = ds["t2m"].mean("forecast_time").isel(latitude=[25, 26, 27]).plot(hue="latitude", figsize=(5, 4))

### Workflow for Complex Plots

To produce complex plots, follow these steps:

1. Create the `ax` yourself using `plt.subplots()`.
2. Plot with `xarray` using `.plot(ax=ax)`.
3. Further customize the axes directly using `matplotlib`.
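
The three steps can be sketched as follows, using a small random array (`field`) as a stand-in for real data. The title, label and reference line are arbitrary choices for illustration. The `matplotlib.use("Agg")` line is only there so the sketch runs without a display; inside a notebook it can be dropped.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; not needed inside a notebook
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr

field = xr.DataArray(
    np.random.default_rng(0).random((4, 5)),
    dims=("latitude", "longitude"),
    coords={"latitude": np.arange(4), "longitude": np.arange(5)},
    name="demo",
)

# 1. Create the figure and axes ourselves
fig, ax = plt.subplots(figsize=(6, 4))

# 2. Let xarray draw onto those axes
field.plot(ax=ax, cmap="viridis")

# 3. Customize the axes with plain matplotlib calls
ax.set_title("Custom title")
ax.set_xlabel("Longitude index")
ax.axhline(1.5, color="white", linestyle="--")
```

Setting the title and labels after the `xarray` call matters, because `.plot()` writes its own default labels onto the axes.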

### Facets

Let's experiment with plotting 3D data:

In [None]:
# Get the data array
da = ds["t2m"]

# Group by month of year and calculate monthly means
da = da.groupby("forecast_time.month").mean()

# Plot
fg = da.plot(col="month", col_wrap=4, robust=True)
_ = fg.fig.suptitle("Seasonal evolution of global 2m temperature", y=1)
plt.show()

## `Rioxarray`

Why use `rioxarray` instead of `xarray`?

- It stores the CRS as a WKT, which is the recommended format (from *PROJ FAQ*).
- It loads in the CRS, transform, and nodata metadata in standard CF & GDAL locations.
- It supports masking and scaling data with the `masked` and `mask_and_scale` kwargs.
- It loads raster metadata into the attributes.

In [None]:
ds = rioxarray.open_rasterio("./data/raster/nairobi_elevation.tiff")
_ = ds.squeeze().plot(robust=True, cmap="terrain")

### Resampling

In [None]:
# Set the upscale factor
upscale_factor = 2

# Get the new width and height based on the resampling factors
new_width = ds.rio.width * upscale_factor
new_height = ds.rio.height * upscale_factor

# Resample the dataset
resampled_ds = ds.rio.reproject(
    ds.rio.crs,
    shape=(new_height, new_width),
    resampling=Resampling.bilinear
)
resampled_ds.shape

### Clipping

In [None]:
# Create the Nairobi box
xmin, ymin, xmax, ymax = 36.66, -1.37, 37.06, -1.07
nairobi_bbox = box(xmin, ymin, xmax, ymax)
nairobi_bbox

In [None]:
# Clip & visualize
clipped_ds = ds.rio.clip([nairobi_bbox])
_ = clipped_ds.squeeze().plot(robust=True, cmap="terrain")

### Reproject

In [None]:
# Current CRS
ds.rio.crs

In [None]:
# Estimate the destination UTM CRS
dest_crs = ds.rio.estimate_utm_crs()
dest_crs

In [None]:
# Reproject to the estimated UTM CRS
ds_utm = ds.rio.reproject(dest_crs)
ds_utm.rio.crs

### Cloud Optimized GeoTIFFs

The advantages of COGs over plain GeoTIFFs:

- **Tiling**: COGs are organized into small tiles that can be efficiently accessed and processed independently of the rest of the image. This allows for faster data access and processing, as only the relevant tiles need to be loaded into memory.
- **Compression**: COGs can use lossless or lossy compression to reduce file size; lossless codecs preserve pixel values exactly, while lossy codecs trade some fidelity for smaller files. This reduces storage costs and network bandwidth requirements, making it easier to transfer and access data in cloud-based environments.
- **Overviews**: COGs include pre-computed lower-resolution versions of the image, called overviews or pyramids, that can be used for rapid display and analysis at smaller scales. This reduces the need to access and process the full-resolution image, improving performance.

In [None]:
# Save
ds.rio.to_raster(raster_path="./data/raster/output_cog.tif", driver="COG")

### Mask Generation

In [None]:
# Create the mask
ds_mask = make_geocube(
    vector_data=gpd.GeoDataFrame(data={"val": [1]}, geometry=[nairobi_bbox]),
    like=ds
)

Let's merge the two:

In [None]:
# combine the DataArrays into a single Dataset
merged_ds = xr.Dataset(
    {"elevation": ds, "mask": ds_mask["val"]}
)

In [None]:
# create a figure with two subplots
fig, axs = plt.subplots(1, 2, figsize=(15, 5))

# plot the first variable in the left subplot
merged_ds["elevation"].plot(ax=axs[0], cmap="terrain", robust=True)
axs[0].set_title("Elevation")

# plot the second variable in the right subplot
merged_ds["mask"].plot(ax=axs[1], cmap="Greys")
axs[1].set_title("Mask")

# show the plot
plt.show()

## Resources

- [Official Documentation](https://docs.xarray.dev/en/stable/): The primary resource for XArray is its official documentation. It covers various topics, including installation, user guide, examples, and API reference.
- [XArray for multidimensional data](https://rabernat.github.io/research_computing_2018/xarray.html): This introductory tutorial by Ryan Abernathey covers the basics of XArray and its role in handling multidimensional data.
- [Earth and Environmental Data Science with XArray](https://pangeo.io/): The Pangeo project provides a wealth of resources related to using XArray for Earth and environmental data science, including tutorials and examples.
- [XArray Tutorial for Geospatial Data](http://stephanhoyer.com/2015/06/11/xray-dask-out-of-core-labeled-arrays/): This tutorial by Stephan Hoyer demonstrates the use of XArray and Dask for processing large geospatial datasets.
- [XArray Exercises](https://github.com/xarray-contrib/xarray-tutorial): This GitHub repository by Tom Nicholas contains exercises and solutions for learning XArray.

---