In [2]:
%matplotlib inline

import datacube
import numpy as np
import xarray as xr 
import matplotlib.pyplot as plt
from scipy.ndimage import uniform_filter

import sys
sys.path.insert(1, '../Tools/')
from dea_tools.plotting import display_map

In [3]:
# Import extra libraries
import pystac_client
import planetary_computer

import odc.stac
import odc.geo.xr
from odc.geo.geom import BoundingBox

In [4]:
# Open an Open Data Cube handle. The Sentinel-1 imagery in this notebook is
# loaded through odc-stac (see below), not through this handle.
dc = datacube.Datacube(
    app="Radar_water_detection",  # application label passed to datacube
)

In [5]:
# Area of interest (same as earlier in the notebook) and study period.
# Edit these to analyse another region or time window; the default period
# brackets Typhoon Odette (December 2021).
latitude, longitude = (9.913, 9.764), (126.077, 126.172)
time = ("2021-10", "2022-02")

In [6]:
# Connect to the Microsoft Planetary Computer STAC API. The
# `planetary_computer.sign_inplace` modifier signs returned items so their
# asset URLs can be read.
catalog = pystac_client.Client.open(
    "https://planetarycomputer.microsoft.com/api/stac/v1",
    modifier=planetary_computer.sign_inplace,
)

In [7]:
# Translate the datacube-style (longitude, latitude, time) query into the
# bbox / datetime-range form that `pystac_client` expects.
bbox = BoundingBox.from_xy(longitude, latitude)
time_range = "/".join(time)  # e.g. "2021-10/2022-02"

# Search the Sentinel-1 RTC ("sentinel-1-rtc") collection.
search = catalog.search(
    collections="sentinel-1-rtc",
    bbox=bbox,
    datetime=time_range,
)

# Materialise the search results and report how many scenes matched.
items = search.item_collection()
print(f"Found {len(items)} STAC items")

# Fail early with a readable message rather than deep inside odc.stac.load.
if len(items) == 0:
    raise ValueError(
        f"No sentinel-1-rtc items found for bbox={bbox} and time={time_range}; "
        "widen the area of interest or the time range."
    )

Found 22 STAC items


In [8]:
# Load the matched Sentinel-1 RTC items with odc-stac, clipped to the search
# bbox and resampled onto a common grid (`resolution` is in units of `crs`).
# NOTE(review): confirm EPSG:3327 is intended - the AOI (~126.1 E, ~9.8 N)
# falls in UTM zone 52N (EPSG:32652).
ds_s1 = odc.stac.load(
    items,
    crs="EPSG:3327",
    resolution=20,
    bbox=bbox,
)

# Uncomment to inspect the loaded Dataset:
# ds_s1

In [9]:
def dB_scale(data):
    """Convert linear-scale backscatter to decibels: ``10 * log10(data)``.

    Parameters
    ----------
    data : xarray.DataArray
        Linear (DN) values. Any object supporting ``.where(cond)`` and
        ``np.log10`` (e.g. a pandas Series) also works.

    Returns
    -------
    Same type as ``data``
        Values in dB. Non-positive inputs (zeros, negative fill values) become
        NaN, because they have no finite logarithm.
    """
    # Mask <= 0, not just < 0: log10(0) is -inf, which would pass a
    # "< threshold" water test and can make histogram range detection fail.
    positive_data = data.where(data > 0)
    return 10 * np.log10(positive_data)

In [10]:
# Add a decibel-scaled copy of the VH band for display and analysis.
ds_s1["vh_dB"] = dB_scale(ds_s1["vh"])

# Optional: one panel per VH acquisition in the selected period.
# ds_s1.vh_dB.plot(cmap="Greys_r", robust=True, col="time", col_wrap=5)
# plt.show()

In [11]:
# Per-pixel temporal mean of VH backscatter (dB).
mean_vh_dB = ds_s1["vh_dB"].mean("time")

# Optional plot of the mean:
# fig = plt.figure(figsize=(7, 9))
# mean_vh_dB.plot(cmap="Greys_r", robust=True)
# plt.title("Average VH")
# plt.show()

In [12]:
# Add a decibel-scaled copy of the VV band for display and analysis.
ds_s1["vv_dB"] = dB_scale(ds_s1["vv"])

# Optional: one panel per VV acquisition in the selected period.
# ds_s1.vv_dB.plot(cmap="Greys_r", robust=True, col="time", col_wrap=5)
# plt.show()

In [13]:
# Per-pixel temporal mean of VV backscatter (dB).
mean_vv_dB = ds_s1["vv_dB"].mean("time")

# Optional plot of the mean:
# fig = plt.figure(figsize=(7, 9))
# mean_vv_dB.plot(cmap="Greys_r", robust=True)
# plt.title("Average VV")
# plt.show()

In [14]:
# Adapted from https://stackoverflow.com/questions/39785970/speckle-lee-filter-in-python
def lee_filter(img, size):
    """
    Applies the Lee filter to reduce speckle noise in an image.

    Each output pixel is the local (size x size) mean plus a weighted share of
    the pixel's deviation from it. The weight is
    local_variance / (local_variance + overall_variance), where
    overall_variance is the variance of the whole input image.

    Parameters:
    img (ndarray): Input image to be filtered.
    size (int): Size of the uniform filter window.

    Returns:
    ndarray: The filtered image, same shape as `img`. Flat regions of a flat
    image (e.g. an all-zero slice) return the local mean instead of NaN.
    """
    img_mean = uniform_filter(img, size)
    img_sqr_mean = uniform_filter(img**2, size)
    # E[x^2] - E[x]^2 can dip slightly below zero from floating-point
    # cancellation; a negative variance would give a negative weight.
    img_variance = np.maximum(img_sqr_mean - img_mean**2, 0)

    overall_variance = np.var(img)

    # Both terms are >= 0, so a zero denominator only happens when the window
    # and the whole image are flat. Dividing by 1 there gives weight 0
    # (output = local mean) instead of 0/0 = NaN.
    denominator = img_variance + overall_variance
    denominator[denominator == 0] = 1

    img_weights = img_variance / denominator
    img_output = img_mean + img_weights * (img - img_mean)
    return img_output

In [15]:
print(ds_s1["vv"].dims)  # dimension order of the loaded band

('time', 'y', 'x')


In [16]:
# Define a function to apply the Lee filter to a DataArray
def apply_lee_filter(data_array, size=7):
    """
    Applies the Lee filter to every 2-D (y, x) slice of the provided DataArray.

    Parameters:
    data_array (xarray.DataArray): The data array to be filtered. Must have
        "y" and "x" dimensions; any other dimensions (e.g. "time") are looped
        over, one slice at a time.
    size (int): Size of the uniform filter window. Default is 7.

    Returns:
    xarray.DataArray: The filtered data array, with the same dimension order
    as `data_array`. Pixels that were NaN in the input are NaN in the output.
    """
    # uniform_filter spreads NaN across its whole window, so missing pixels
    # are temporarily filled with 0 for the convolution.
    # NOTE(review): if nodata is encoded as a negative fill value rather than
    # NaN, those pixels are not caught here - confirm against the loaded data.
    data_array_filled = data_array.fillna(0)

    filtered_data = xr.apply_ufunc(
        lee_filter, data_array_filled,
        kwargs={"size": size},
        input_core_dims=[["y", "x"]],  # filter each (y, x) image
        output_core_dims=[["y", "x"]],
        dask_gufunc_kwargs={"allow_rechunk": True},
        vectorize=True,  # loop lee_filter over the non-core dims
        dask="parallelized",
        output_dtypes=[data_array.dtype],
    )

    # apply_ufunc moves core dims to the end; restore the caller's order so
    # 2-D plots (which map dims to (y-axis, x-axis) by order) aren't transposed.
    filtered_data = filtered_data.transpose(*data_array.dims)

    # Values at the zero-filled pixels are artefacts of the fill; put the gaps back.
    return filtered_data.where(data_array.notnull())

In [17]:
# Speckle-filter both polarisations (VV first, then VH) with a 7x7 Lee window.
for band in ("vv", "vh"):
    ds_s1[f"filtered_{band}"] = apply_lee_filter(ds_s1[band], size=7)

In [18]:
# Decibel-scaled copy of the Lee-filtered VH band.
ds_s1["filtered_vh_dB"] = dB_scale(ds_s1["filtered_vh"])

# Optional: one panel per filtered VH acquisition in the selected period.
# ds_s1.filtered_vh_dB.plot(cmap="Greys_r", robust=True, col="time", col_wrap=5)
# plt.show()

In [19]:
# Per-pixel temporal mean of filtered VH backscatter (dB).
mean_filtered_vh_dB = ds_s1["filtered_vh_dB"].mean("time")

# Optional plot of the mean:
# fig = plt.figure(figsize=(7, 9))
# mean_filtered_vh_dB.plot(cmap="Greys_r", robust=True)
# plt.title("Average filtered VH")
# plt.show()

In [None]:
# Decibel-scaled copy of the Lee-filtered VV band.
ds_s1["filtered_vv_dB"] = dB_scale(ds_s1["filtered_vv"])

# Optional: one panel per filtered VV acquisition in the selected period.
# ds_s1.filtered_vv_dB.plot(cmap="Greys_r", robust=True, col="time", col_wrap=5)
# plt.show()

In [None]:
# Per-pixel temporal mean of filtered VV backscatter (dB).
mean_filtered_vv_dB = ds_s1["filtered_vv_dB"].mean("time")

# Optional plot of the mean:
# fig = plt.figure(figsize=(7, 9))
# mean_filtered_vv_dB.plot(cmap="Greys_r", robust=True)
# plt.title("Average filtered VV")
# plt.show()

In [None]:
# Overlay the filtered and raw VH distributions on one axis.
fig, ax = plt.subplots(figsize=(15, 3))
ds_s1["filtered_vh_dB"].plot.hist(ax=ax, bins=1000, label="VH filtered")
ds_s1["vh_dB"].plot.hist(ax=ax, bins=1000, label="VH", alpha=0.5)
ax.legend()
ax.set_xlabel("VH (dB)")
ax.set_title("Comparison of filtered VH bands to original")
plt.show()

In [None]:
# Overlay the filtered and raw VV distributions on one axis.
fig, ax = plt.subplots(figsize=(15, 3))
ds_s1["filtered_vv_dB"].plot.hist(ax=ax, bins=1000, label="VV filtered")
ds_s1["vv_dB"].plot.hist(ax=ax, bins=1000, label="VV", alpha=0.5)
ax.legend()
ax.set_xlabel("VV (dB)")
ax.set_title("Comparison of filtered VV bands to original")
plt.show()

In [None]:
threshold = -20.0  # dB cut-off: filtered VH strictly below this counts as water

In [None]:
# VH histograms with the chosen water threshold marked.
fig, ax = plt.subplots(figsize=(15, 3))
ax.axvline(x=threshold, label=f"Threshold at {threshold}", color="red")
ds_s1["filtered_vh_dB"].plot.hist(ax=ax, bins=1000, label="VH filtered")
ds_s1["vh_dB"].plot.hist(ax=ax, bins=1000, label="VH", alpha=0.5)
ax.legend()
ax.set_xlabel("VH (dB)")
ax.set_title("Histogram Comparison of filtered VH bands to original")
plt.show()

In [None]:
# Shade the filtered-VH histogram by the class each dB range is assigned to.
fig, ax = plt.subplots(figsize=(15, 3))
ds_s1["filtered_vh_dB"].plot.hist(ax=ax, bins=1000, label="VH filtered")
ax.axvspan(-40.0, threshold, alpha=0.25, color="green", label="Water")
ax.axvspan(threshold, -0.5, alpha=0.25, color="red", label="Not Water")
ax.legend()
ax.set_xlabel("VH (dB)")
ax.set_title("Effect of the classifier")
plt.show()

In [None]:
def s1_water_classifier(ds, threshold=-20.0, size=7):
    """
    Classifies pixels of `ds` as water where the Lee-filtered VH backscatter,
    in dB, is below `threshold`.

    Parameters:
    ds (xarray.Dataset): Dataset containing a `vh` variable in linear (DN)
        units, not dB.
    threshold (float): dB cut-off; pixels strictly below it are water.
        Default is -20.0.
    size (int): Lee filter window size. Default is 7.

    Returns:
    xarray.Dataset: One boolean variable, `s1_water` (True = water). NaN dB
    values compare as False, so they are reported as not-water.

    Raises:
    AssertionError: If `ds` has no `vh` variable.
    """
    assert "vh" in ds.data_vars, "This classifier is expecting a variable named `vh` expressed in DN, not DB values"
    # Filter the dataset that was passed in, not a global.
    filtered = apply_lee_filter(ds.vh, size=size)
    water_data_array = dB_scale(filtered) < threshold
    return water_data_array.to_dataset(name="s1_water")

In [None]:
ds_s1["water"] = s1_water_classifier(ds_s1)["s1_water"]  # boolean water mask per acquisition

In [None]:
ds_s1["water"]

In [None]:
# Per-pixel mean of the boolean water mask over time, i.e. the fraction of
# acquisitions in which the pixel was classified as water.
fig, ax = plt.subplots(figsize=(15, 12))
ds_s1["water"].mean("time").plot(ax=ax, cmap="RdBu")
ax.set_title("Average classified pixel value")
plt.show()

In [None]:
# Per-pixel standard deviation of the water mask over time; higher values
# mark pixels whose classification changes between acquisitions.
fig, ax = plt.subplots(figsize=(15, 12))
ds_s1["water"].std("time").plot(ax=ax, cmap="viridis")
ax.set_title("Standard deviation of classified pixel values")
plt.show()

In [None]:
# Compare the first and the last acquisition of the time series.
start_time_index, end_time_index = 0, ds_s1["water"].sizes["time"] - 1

In [None]:
# Per-pixel change computed as water(start) - water(end), in float32:
#   +1 -> water at the start only, -1 -> water at the end only, 0 -> unchanged.
water_start = ds_s1["water"].isel(time=start_time_index).astype(np.float32)
water_end = ds_s1["water"].isel(time=end_time_index).astype(np.float32)
change = water_start - water_end

# Unchanged pixels become NaN so they are left blank in the plot.
change = change.where(change != 0)
ds_s1["change"] = change

In [None]:
# Mean filtered VH as a backdrop, with the start/end change map drawn over it.
fig, ax = plt.subplots(figsize=(15, 12))
ds_s1["filtered_vh_dB"].mean("time").plot(ax=ax, cmap="Blues")
ds_s1["change"].plot(ax=ax, cmap="RdBu", levels=2)
ax.set_title(f"Change in pixel value between time={start_time_index} and time={end_time_index}")
plt.show()