# Downscaling WaPOR ET from 300 m to 10 m

This notebook downscales the WaPOR evapotranspiration (ET) data from 300 m to 10 m resolution, using Sentinel-2 vegetation indices as predictors, and then exports the downscaled ImageCollection to the project's asset folder in Google Earth Engine.

In [None]:
import os
import sys

# Put the parent of the current working directory (normally the notebook's folder)
# on the import path, so the local project modules imported below (e.g. `utils`)
# can be found — presumably they live in that parent directory.
sys.path.append(os.path.abspath(".."))

In [None]:
import ee
import geemap

# Initialise the Earth Engine client against the Cloud project that hosts the
# assets referenced in this notebook (all paths start with projects/thurgau-irrigation).
GEE_PROJECT = "thurgau-irrigation"
ee.Initialize(project=GEE_PROJECT)

In [None]:
from wapor_et_processing import load_wapor_et_data
from vegetation_period_NDVI.data_loading import load_sentinel2_data, add_time_data
from vegetation_period_NDVI.time_series import (
    extract_time_ranges,
    get_harmonic_ts,
    add_time_data,
)
from utils.composites import harmonized_ts
from utils.date_utils import print_collection_dates, create_centered_date_ranges
from utils.harmonic_regressor import HarmonicRegressor
from utils.ee_utils import back_to_float, back_to_int, export_image_to_asset


from typing import List

## Define the AOI and the year to process

In [None]:
# AOI: the canton of Thurgau.
# NOTE(review): "thrugau" in the asset ID looks like a typo of "thurgau", but it is
# presumably the ID the asset is stored under — confirm before "fixing" it.
cantonal_borders_asset = "projects/thurgau-irrigation/assets/Thurgau/thrugau_borders_2024"

aoi_feature_collection = ee.FeatureCollection(cantonal_borders_asset)
# Merge the border features into one geometry and simplify it (maxError = 500 m)
# to reduce its vertex count, then grow it by 100 m to get the working AOI.
aoi_geometry = aoi_feature_collection.geometry().simplify(500)
aoi_buffered = aoi_geometry.buffer(100)

# Little square around Oensingen:
# oensingen_coordinates = [
#   [
#     [7.569238717890812, 47.35358169812031],
#     [7.569238717890812, 47.21413609771895],
#     [7.879806798836398, 47.21413609771895],
#     [7.879806798836398, 47.35358169812031],
#     [7.569238717890812, 47.35358169812031]
#   ]
# ];

# # Create an ee.Geometry object from the coordinates
# oensingen_polygon = ee.Geometry.Polygon(oensingen_coordinates)
# aoi_simple = oensingen_polygon.simplify(500)
# aoi_buffered = aoi_simple.buffer(1000)

In [None]:
# Processing configuration.
YEAR = 2022  # calendar year of WaPOR / Sentinel-2 data to process
# Passed as `buffer_days` to create_centered_date_ranges: the number of days around
# each WaPOR date used for the Sentinel-2 composites (TODO confirm half- vs full-width).
BUFFER_DAYS = 5

In [None]:
# Load WAPOR ET data
first_year = YEAR
last_year = YEAR
wapor_et_data = load_wapor_et_data(
    first_year, last_year, frequency="dekadal"
).filterBounds(aoi_buffered)

### Load Sentinel-2 data


In [None]:
# Sentinel-2 collection for YEAR over the buffered AOI (presumably already
# cloud-masked: it is passed as `masked_collection` to harmonized_ts below).
s2collection = load_sentinel2_data(aoi=aoi_buffered, year=YEAR)

# Debug helper: print the CRS of every band of the first image.
# img = s2collection.first()
# print("Band projections:")
# for band in img.bandNames().getInfo():
#     print(f"{band}: {img.select(band).projection().crs().getInfo()}")

In [None]:
# Materialise every WaPOR image as a list. Using the collection size (as
# process_and_export_downscaled_ET does) instead of a hard-coded 36 keeps this
# cell correct for partial years or a monthly product as well.
wapor_list = wapor_et_data.toList(wapor_et_data.size())

# One Sentinel-2 compositing window per WaPOR image, centred on its date and
# sized by BUFFER_DAYS (see utils.date_utils.create_centered_date_ranges).
time_intervals = create_centered_date_ranges(wapor_list, buffer_days=BUFFER_DAYS)

bands = ["B3", "B4", "B8", "B11", "B12"]

# NOTE(review): "NDVI" is not among `bands`; presumably harmonized_ts computes it
# internally to rank scenes for the "least_cloudy" mosaic — confirm in utils.composites.
options = {"agg_type": "mosaic", "mosaic_type": "least_cloudy", "band_name": "NDVI"}

s2_harmonized = harmonized_ts(
    masked_collection=s2collection,
    band_list=bands,
    time_intervals=time_intervals,
    options=options,
)

In [None]:
# s2_harmonized.first().projection().getInfo()

In [None]:
def compute_vegetation_indexes(image: ee.Image) -> ee.Image:
    """
    Append three normalized-difference indexes to a Sentinel-2 image.

    The bands are appended in this order, each computed from the input image:
        - NDVI = (B8 - B4) / (B8 + B4)
        - NDWI = (B3 - B8) / (B3 + B8)
        - NDBI = (B11 - B8) / (B11 + B8)

    Args:
        image (ee.Image): Image that contains at least the bands B3, B4, B8 and B11.

    Returns:
        ee.Image: ``image`` with the NDVI, NDWI and NDBI bands appended.
    """
    index_definitions = (
        ("NDVI", ["B8", "B4"]),
        ("NDWI", ["B3", "B8"]),
        ("NDBI", ["B11", "B8"]),
    )

    result = image
    for index_name, band_pair in index_definitions:
        index_band = image.normalizedDifference(band_pair).rename(index_name)
        result = result.addBands(index_band)
    return result

# Append the NDVI / NDWI / NDBI bands to every harmonised Sentinel-2 composite.
s2_harmonized_w_vegetation_indexes = s2_harmonized.map(
    compute_vegetation_indexes
)

In [None]:
# s2_harmonized_w_vegetation_indexes.first().projection().getInfo()

### Filling data gaps with harmonic regression

In [None]:
# Gap-fill each vegetation index: fit a harmonic regression over the year and
# append its prediction to every composite as a `fitted_<index>` band.
indexes = ["NDVI", "NDWI", "NDBI"]

# Harmonic model hyper-parameters passed to HarmonicRegressor.
HARMONIC_OMEGA = 1.5
HARMONIC_MAX_ORDER = 2

# Derive a NEW collection carrying the time covariates instead of overwriting
# `s2_harmonized_w_vegetation_indexes` in place: the old `x = x.map(...)` form made
# this cell non-idempotent (each re-run added the time bands again).
s2_harmonized_w_time = s2_harmonized_w_vegetation_indexes.map(add_time_data)


def _rename_fitted_band(
    fitted: ee.ImageCollection, band_name: str
) -> ee.ImageCollection:
    """Keep only the regressor's `fitted` band of every image, renamed to `band_name`."""
    return fitted.map(lambda img: img.select(["fitted"]).rename(band_name))


def _add_band_by_date(
    target: ee.ImageCollection, source: ee.ImageCollection, band_name: str
) -> ee.ImageCollection:
    """Append band `band_name` from the `source` image whose date equals each target image's date."""
    # filterDate(start) without an end date selects a 1 ms window starting at
    # `start`, i.e. images with exactly the same system:time_start.
    return target.map(
        lambda img: img.addBands(
            source.filterDate(img.date()).first().select([band_name])
        )
    )


s2_harmonized_gaps_filled = s2_harmonized_w_time

for index in indexes:
    regressor = HarmonicRegressor(
        omega=HARMONIC_OMEGA,
        max_harmonic_order=HARMONIC_MAX_ORDER,
        vegetation_index=index,
    )
    regressor.fit(s2_harmonized_w_time)

    fitted_band_name = f"fitted_{index}"
    # The per-iteration values are passed to the helpers as arguments, so each
    # mapped function is bound to this iteration regardless of when the EE client
    # traces it (rather than closing over the loop variables).
    fitted_collection = _rename_fitted_band(
        regressor.predict(s2_harmonized_w_time), fitted_band_name
    )
    s2_harmonized_gaps_filled = _add_band_by_date(
        s2_harmonized_gaps_filled, fitted_collection, fitted_band_name
    )

In [None]:
# s2_harmonized_w_vegetation_indexes.first().projection().nominalScale().getInfo()

In [None]:
# s2_harmonized_gaps_filled_list = s2_harmonized_gaps_filled.toList(36)

# first = ee.Image(s2_harmonized_gaps_filled_list.get(6)).select("fitted_NDVI").clip(aoi_buffered)
# second = ee.Image(s2_harmonized_gaps_filled_list.get(6)).select("NDVI").clip(aoi_buffered)

# Map = geemap.Map()

# vis_params = {
#     "bands": ["fitted_NDVI"],
#     "min": 0,
#     "max": 1,
#     "palette": ["red", "yellow", "green"],
# }

# vis_params_2 = {    
#     "bands": ["NDVI"],
#     "min": 0,
#     "max": 1,
#     "palette": ["red", "yellow", "green"],
# }

# Map.center_object(aoi_buffered, 12)
# Map.addLayer(second, vis_params_2, "NDVI")
# Map.addLayer(first, vis_params, "Fitted NDVI")

# Map

## Downscaling WaPOR ET data to Sentinel-2 resolution

In [None]:
from utils.downscale_anything_10m import Downscaler

In [None]:
def resample_collection(
    collection: ee.ImageCollection, reference_collection: ee.ImageCollection
) -> ee.ImageCollection:
    """
    Resample an image collection to the projection and scale of a reference collection.

    In this notebook it puts the Sentinel-2 index images onto the WaPOR ET grid so that
    they can serve as the coarse independent variables of the downscaling regression.

    Args:
        collection (ee.ImageCollection): The input (Sentinel-2) image collection to resample.
        reference_collection (ee.ImageCollection): The reference (WaPOR ET) collection; the
            projection of its first image defines the target grid.

    Returns:
        ee.ImageCollection: The resampled collection. Each image keeps the ordinary
        properties and ``system:time_start`` of its source image and additionally
        carries ``resampled``, ``original_scale``, ``target_scale``,
        ``original_projection`` and ``target_projection`` (projections as WKT).
    """
    # Target grid: projection (and its nominal scale) of the first reference image.
    target_projection = reference_collection.first().projection()
    target_scale = target_projection.nominalScale()

    def resample_image(image: ee.Image) -> ee.Image:
        """
        Reproject one image onto the target grid and record the operation in its metadata.

        Args:
            image (ee.Image): Input image to resample; all its bands must share one
                projection (required by ``image.projection()``).

        Returns:
            ee.Image: The reprojected image with the resampling metadata set.
        """
        original_projection = image.projection()
        original_scale = original_projection.nominalScale()

        # Every band goes to the same target grid, so one whole-image reproject yields
        # the same pixels as reprojecting band by band and re-stacking them with
        # toBands()/rename(), without that extra server-side list/collection round trip.
        # NOTE(review): reproject() resamples with nearest neighbour by default, so each
        # coarse pixel takes the value of a single fine pixel rather than an aggregate.
        # If a per-pixel mean is wanted, insert reduceResolution(ee.Reducer.mean(), ...)
        # before reproject — but that needs `image` to carry its native fine-resolution
        # projection, which is not guaranteed for mosaic composites; confirm first.
        resampled = image.reproject(crs=target_projection, scale=target_scale)

        # copyProperties() returns an ee.Element; cast back so callers get an ee.Image.
        return ee.Image(
            resampled.copyProperties(image).set(
                {
                    "system:time_start": image.get("system:time_start"),
                    "resampled": True,
                    "original_scale": original_scale,
                    "target_scale": target_scale,
                    "original_projection": original_projection.wkt(),
                    "target_projection": target_projection.wkt(),
                }
            )
        )

    return collection.map(resample_image)

In [None]:
# def export_image_to_asset(
#     image: ee.Image,
#     asset_id: str,
#     task_name: str,
#     year: str,
#     aoi: ee.Geometry,
#     max_pixels: int = 1e13,
# ) -> ee.batch.Task:
#     """
#     Export an image to an Earth Engine asset.
#     """
#     task = ee.batch.Export.image.toAsset(
#         image=image,
#         description=task_name,
#         assetId=asset_id,
#         region=aoi,
#         scale=10,
#         maxPixels=max_pixels,
#     )
#     print(f"Exporting {task_name} for {year} to {asset_id}")
#     task.start()
#     return task


def _time_step_name(i: int, time_step_type: str) -> str:
    """
    Build the asset-name suffix of the i-th (0-based) time step of a year.

    Args:
        i (int): 0-based index of the time step within the year.
        time_step_type (str): "dekadal" (3 steps per month) or "monthly".

    Returns:
        str: "MM_Dk" for dekadal steps (e.g. i=20 -> "07_D3"), "MM" for monthly steps.

    Raises:
        ValueError: If ``time_step_type`` is neither "dekadal" nor "monthly".
    """
    if time_step_type == "dekadal":
        month_offset, dekad_offset = divmod(i, 3)
        return f"{month_offset + 1:02d}_D{dekad_offset + 1}"
    if time_step_type == "monthly":
        return f"{i + 1:02d}"
    raise ValueError("time_step_type must be either 'dekadal' or 'monthly'")


def process_and_export_downscaled_ET(
    downscaler: Downscaler,
    s2_indices: ee.ImageCollection,
    independent_vars: ee.ImageCollection,
    dependent_vars: ee.ImageCollection,
    aoi: ee.Geometry,
    year: str,
    scale_coarse: float,
    scale_fine: float = 10,
    time_steps: int = 36,
    time_step_type: str = "dekadal",
    asset_root: str = "projects/thurgau-irrigation/assets/Thurgau",
    name_tag: str = "testin_reproject",
    crs: str = "EPSG:32632",
) -> List[ee.batch.Task]:
    """
    Process and export downscaled WaPOR ET images to Earth Engine assets.

    Images are written to
    ``{asset_root}/ET_WaPOR_10m[__{name_tag}]_{time_step_type}_{year}/{task_name}`` with
    ``task_name = WaPOR_ET_downscaled_{year}[_{name_tag}]-{step}``; the bracketed parts
    are omitted when ``name_tag`` is empty.

    Args:
        downscaler (Downscaler): The Downscaler object used to downscale the images.
        s2_indices (ee.ImageCollection): The fine-resolution Sentinel-2 indices collection.
        independent_vars (ee.ImageCollection): The Sentinel-2 indices resampled to the coarse grid.
        dependent_vars (ee.ImageCollection): The coarse dependent variable (WaPOR ET) collection.
        aoi (ee.Geometry): The area of interest geometry.
        year (str): The year being processed (an int also works; it is only formatted into names).
        scale_coarse (float): The scale of the images before downscaling.
        scale_fine (float): The export scale of the downscaled images.
        time_steps (int): Number of time steps in the year (36 for dekadal, 12 for monthly).
        time_step_type (str): Type of time step ("dekadal" or "monthly").
        asset_root (str): Asset folder in which the image collection is created.
        name_tag (str): Optional tag added to the collection and task names ("" for none).
        crs (str): CRS of the exported images.

    Returns:
        List[ee.batch.Task]: The export tasks for the downscaled images, in time-step order.

    Raises:
        ValueError: If ``time_step_type`` is neither "dekadal" nor "monthly" (raised
            before any export task is started).
    """
    # Resolve all time-step names first so an invalid `time_step_type` fails before
    # any export task is started.
    step_names = [_time_step_name(i, time_step_type) for i in range(time_steps)]

    collection_name = "ET_WaPOR_10m"
    task_prefix = f"WaPOR_ET_downscaled_{year}"
    if name_tag:
        collection_name += f"__{name_tag}"
        task_prefix += f"_{name_tag}"
    collection_id = f"{asset_root}/{collection_name}_{time_step_type}_{year}"

    # NOTE(review): the three collections are paired purely by list position, so they
    # must hold the same time steps in the same (chronological) order — confirm upstream,
    # or sort them by "system:time_start" before calling.
    s2_indices_list = s2_indices.toList(s2_indices.size())
    independent_vars_list = independent_vars.toList(independent_vars.size())
    dependent_vars_list = dependent_vars.toList(dependent_vars.size())

    tasks: List[ee.batch.Task] = []
    for i, step_name in enumerate(step_names):
        fine_vars = ee.Image(s2_indices_list.get(i))
        coarse_vars = ee.Image(independent_vars_list.get(i))
        coarse_target = ee.Image(dependent_vars_list.get(i))

        et_image_downscaled = downscaler.downscale(
            coarse_independent_vars=coarse_vars,
            coarse_dependent_var=coarse_target,
            fine_independent_vars=fine_vars,
            geometry=aoi,
            resolution=scale_coarse,
        )

        # Presumably scales by 100 and stores as integer; the sanity-check cell reads
        # exports back with back_to_float(img, 100).
        et_image_downscaled = back_to_int(et_image_downscaled, 100)

        task_name = f"{task_prefix}-{step_name}"
        asset_id = f"{collection_id}/{task_name}"

        task = export_image_to_asset(
            et_image_downscaled,
            asset_id,
            task_name,
            year,
            aoi,
            crs=crs,
            scale=scale_fine,
        )
        tasks.append(task)

    return tasks

In [None]:
# independent_bands = ["fitted_NDVI", "fitted_NDBI", "fitted_NDWI"]
# dependent_band = ["ET"]

# s2_indices = s2_harmonized_gaps_filled.select(independent_bands)
# independent_vars = resample_collection(s2_indices, wapor_et_data)
# dependent_vars = wapor_et_data.select(dependent_band)

# scale = wapor_et_data.first().projection().nominalScale().getInfo()


# # Initialize the Downscaler
# downscaler = Downscaler(
#     independent_vars=independent_bands, dependent_var=dependent_band[0]
# )

# tasks = process_and_export_downscaled_ET(
#     downscaler,
#     s2_indices,
#     independent_vars,
#     dependent_vars,
#     aoi_buffered,
#     YEAR,
#     scale_coarse=scale,
#     scale_fine=10,
#     time_steps=36,
#     time_step_type="dekadal",
# )

# # You can add additional code here to monitor the tasks if needed
# print(f"Started {len(tasks)} export tasks.")

### Sanity check: verify that the downscaling and export worked correctly

In [58]:
# Re-load an exported downscaled collection and convert it back to float
# (presumably undoing the back_to_int(..., 100) applied before export).
# NOTE(review): this reads the *Zuerich* asset, whereas the export function writes to
# the Thurgau folder and the map below is centred on the Thurgau AOI — confirm this is
# the collection you intend to check.
ZH_DOWNSCALED_ASSET = "projects/thurgau-irrigation/assets/Zuerich/ET_WaPOR_10m_dekadal_2022"


def _unscale_et(img):
    """Convert one exported integer ET image back to float (factor 100)."""
    return back_to_float(img, 100)


wapot_collection_zh = ee.ImageCollection(ZH_DOWNSCALED_ASSET).map(_unscale_et)

print_collection_dates(wapot_collection_zh)

Dates of images in the collection:
2022-01-01
2022-01-11
2022-01-21
2022-02-01
2022-02-11
2022-02-21
2022-03-01
2022-03-11
2022-03-21
2022-04-01
2022-04-11
2022-04-21
2022-05-01
2022-05-11
2022-05-21
2022-06-01
2022-06-11
2022-06-21
2022-07-01
2022-07-11
2022-07-21
2022-08-01
2022-08-11
2022-08-21
2022-09-01
2022-09-11
2022-09-21
2022-10-01
2022-10-11
2022-10-21
2022-11-01
2022-11-11
2022-11-21
2022-12-01
2022-12-11
2022-12-21


In [53]:
# Turn the collection into a list (at most one year of dekads) so single
# dekads can be picked by position.
MAX_DEKADS_PER_YEAR = 36
wapor_collection_zh_list = wapot_collection_zh.toList(MAX_DEKADS_PER_YEAR)

In [57]:
# Show one downscaled dekad on an interactive map.
# List position 20 is the 21st image, i.e. 2022-07-21 in the date listing above.
DEKAD_POSITION = 20
image = wapor_collection_zh_list.get(DEKAD_POSITION)
# image_2 = wapor_downscaled_tg_list.get(10)

vis_params = dict(
    bands=["downscaled"],
    min=0,
    max=5,
    palette=["blue", "lightblue", "green", "yellow", "red"],
)

Map = geemap.Map()
# NOTE(review): the view is centred on the Thurgau AOI while the layer comes from the
# Zuerich collection — confirm this is intentional.
Map.center_object(aoi_buffered, 10)
Map.addLayer(ee.Image(image), vis_params, "ET downscaled")
# Map.addLayer(ee.Image(image_2), vis_params, "ET downscaled old")

Map

Map(center=[47.56858787382066, 9.092720596553875], controls=(WidgetControl(options=['position', 'transparent_b…