## TODOs:

- Downscale WaPOR data to 10 m resolution


In [1]:
import sys
import os

# Put the parent of the current working directory on sys.path so that packages
# living there (presumably vegetation_period_NDVI and utils, imported in the next
# cell) can be imported from this notebook.
sys.path.append(os.path.abspath(os.path.join("..")))

In [2]:
import ee
import geemap
from wapor_et_processing import load_wapor_et_data
from vegetation_period_NDVI.data_loading import load_sentinel2_data
from vegetation_period_NDVI.time_series import (
    extract_time_ranges,
    get_harmonic_ts,
    compute_harmonic_fit,
    calculate_phase_amplitude,
    add_time_data,
    get_regression_coefficients,
)
from utils.composites import harmonized_ts
from utils.export_image_collection import export_collection_to_assets

from typing import List

In [3]:
# Initialise the Earth Engine client with the "thurgau-irrigation" Cloud project,
# the same project that hosts the assets referenced below.
ee.Initialize(project="thurgau-irrigation")

## Define the AOI and the year to process

In [4]:
# Cantonal borders of Thurgau.
# NOTE(review): the asset name is spelled "thrugau" — kept as-is because it must
# match the stored asset ID; confirm it is not a typo in the asset itself.
cantonal_borders_asset = (
    "projects/thurgau-irrigation/assets/Thurgau/thrugau_borders_2024"
)

aoi_feature_collection = ee.FeatureCollection(cantonal_borders_asset)
# Dissolve to one geometry and simplify it (max error 500) to keep geometry
# operations cheap, then pad the result by 100.
aoi_geometry = aoi_feature_collection.geometry().simplify(500)
aoi_buffered = aoi_geometry.buffer(100)

# Processing period: a single year. BUFFER_DAYS is the half-width of the
# Sentinel-2 compositing window around each WaPOR timestamp.
YEAR = 2023
BUFFER_DAYS = 5

# Dekadal WaPOR ET for that year, restricted to the buffered AOI.
first_year = last_year = YEAR
wapor_et_data = load_wapor_et_data(
    first_year, last_year, frequency="dekadal"
).filterBounds(aoi_buffered)

In [5]:
def print_collection_dates(collection: ee.ImageCollection) -> None:
    """
    Print the ``system:time_start`` date of every image in a collection.

    The timestamps are formatted as ``YYYY-MM-dd`` server-side and fetched with
    a single ``getInfo()`` call. They are printed one per line after a header.

    Args:
        collection (ee.ImageCollection): Collection whose image dates are listed.

    Returns:
        None: Output goes to stdout only.
    """
    timestamps = collection.aggregate_array('system:time_start')
    date_strings = timestamps.map(
        lambda millis: ee.Date(millis).format('YYYY-MM-dd')
    ).getInfo()

    print("Dates of images in the collection:")
    for date_string in date_strings:
        print(date_string)

# print_collection_dates(wapor_et_data)

In [6]:
def create_centered_date_ranges(image_list: ee.List, buffer_days: int = 5) -> ee.List:
    """
    Build a ``[start, end]`` date window around each image's timestamp.

    For every image in ``image_list``, the window runs from
    ``system:time_start - buffer_days`` to ``system:time_start + buffer_days``.
    Both ends are ee.Date objects.

    NOTE(review): if an image's ``system:time_start`` marks the *start* of a
    period (e.g. the first day of a dekad), the window is centred on that start
    rather than on the period's midpoint — confirm this is intended.

    Args:
        image_list (ee.List): Earth Engine images to build windows for.
        buffer_days (int): Days before and after the timestamp. Defaults to 5.

    Returns:
        ee.List: One ``ee.List([start_date, end_date])`` per input image, in order.
    """

    def _window(element):
        center = ee.Date(ee.Image(element).get("system:time_start"))
        return ee.List(
            [center.advance(-buffer_days, "day"), center.advance(buffer_days, "day")]
        )

    return image_list.map(_window)

### Getting dekadal Sentinel-2 data


In [7]:
# Sentinel-2 imagery for YEAR over the buffered AOI. Any cloud masking or band
# preparation is done inside load_sentinel2_data (not visible here).
s2collection = load_sentinel2_data(year=YEAR, aoi=aoi_buffered)

In [8]:
# Sentinel-2 bands carried into each composite.
bands = ["B3", "B4", "B8", "B11", "B12"]

# Compositing settings forwarded to harmonized_ts.
# NOTE(review): band_name is "NDVI", which is not in `bands`; presumably
# harmonized_ts derives NDVI itself for the least-cloudy ranking — confirm.
options = {"agg_type": "mosaic", "mosaic_type": "least_cloudy", "band_name": "NDVI"}

# Up to 36 WaPOR images (one year of dekads), each turned into a
# +/- BUFFER_DAYS window.
wapor_list = wapor_et_data.toList(36)
time_intervals = create_centered_date_ranges(wapor_list, buffer_days=BUFFER_DAYS)

# One Sentinel-2 composite per window, presumably in the same order as WaPOR.
s2_harmonized = harmonized_ts(
    masked_collection=s2collection,
    band_list=bands,
    time_intervals=time_intervals,
    options=options,
)

In [9]:
# print_collection_dates(s2_harmonized)

In [10]:
def compute_vegetation_indexes(image: ee.Image) -> ee.Image:
    """
    Append NDVI, NDWI and NDBI bands to a Sentinel-2 image.

    Each index is a normalized difference ``(a - b) / (a + b)`` of two bands:
    NDVI = (B8, B4), NDWI = (B3, B8), NDBI = (B11, B8).

    Args:
        image (ee.Image): Image containing at least bands B3, B4, B8 and B11.

    Returns:
        ee.Image: The input bands followed by NDVI, NDWI and NDBI (in that order).
    """
    index_definitions = (
        ("NDVI", ["B8", "B4"]),
        ("NDWI", ["B3", "B8"]),
        ("NDBI", ["B11", "B8"]),
    )
    result = image
    for index_name, band_pair in index_definitions:
        # Always derive from the untouched input image, not the growing result.
        result = result.addBands(
            image.normalizedDifference(band_pair).rename(index_name)
        )
    return result


def fill_gaps(
    image_collection: ee.ImageCollection, vegetation_indexes: List[str]
) -> ee.ImageCollection:
    """
    Attach harmonic-regression fits for the given bands to every image.

    For each name in ``vegetation_indexes``, ``compute_harmonic_fit`` is run over
    the whole collection, after ``add_time_data`` has been mapped over it. The
    third argument is ``2``, presumably the number of harmonics; confirm in
    vegetation_period_NDVI.time_series. The resulting ``fitted``/``rmse`` bands
    are renamed to ``fitted_<name>``/``rmse_<name>``. They are joined back onto
    the original images by exact ``system:time_start`` match.

    The original bands are left untouched; the gap-free values live in the new
    ``fitted_*`` bands.

    NOTE(review): this relies on compute_harmonic_fit preserving timestamps; if
    no fitted image matches, ``.first()`` yields null and ``addBands`` presumably
    fails.

    Args:
        image_collection (ee.ImageCollection): Collection to augment.
        vegetation_indexes (List[str]): Band names to fit.

    Returns:
        ee.ImageCollection: ``image_collection`` with ``fitted_*``/``rmse_*``
        bands added.
    """
    with_time = image_collection.map(add_time_data)

    def _fit_index(index_name: str) -> ee.ImageCollection:
        output_names = [f"fitted_{index_name}", f"rmse_{index_name}"]
        harmonic = compute_harmonic_fit(index_name, with_time, 2)
        return harmonic.map(
            lambda fitted_img: fitted_img.select(["fitted", "rmse"]).rename(output_names)
        )

    fitted_by_index = {}
    for index_name in vegetation_indexes:
        fitted_by_index[index_name] = _fit_index(index_name)

    def _attach_fits(img: ee.Image) -> ee.Image:
        timestamp = img.get("system:time_start")
        for fitted in fitted_by_index.values():
            match = fitted.filter(
                ee.Filter.equals("system:time_start", timestamp)
            ).first()
            img = img.addBands(match)
        return img

    return image_collection.map(_attach_fits)


# Add NDVI/NDWI/NDBI to every composite, then attach harmonic fits for each index.
s2_harmonized_w_vegetation_indexes = s2_harmonized.map(compute_vegetation_indexes)
s2_harmonized_gaps_filled = fill_gaps(
    image_collection=s2_harmonized_w_vegetation_indexes,
    vegetation_indexes=["NDVI", "NDWI", "NDBI"],
)

In [11]:
from utils.downscale_anything_10m import downscale, perform_regression, apply_regression, apply_gaussian_smoothing, extract_coefficients

In [12]:
def resample_collection(
    collection: ee.ImageCollection, reference_collection: ee.ImageCollection
) -> ee.ImageCollection:
    """
    Reproject every image onto the grid of a reference collection.

    The target projection and nominal scale come from the first image of
    ``reference_collection``. Here that is WaPOR ET, so the Sentinel-2
    predictors land on the WaPOR grid. Each output image is tagged with
    ``resampled``, ``original_scale``, ``target_scale``, ``original_projection``
    (WKT) and ``target_projection`` (WKT) properties.

    NOTE(review): ``reproject()`` alone does not apply an explicit mean over the
    fine pixels. If area-averaged aggregation is intended, see
    ``ee.Image.reduceResolution`` — confirm.

    Args:
        collection (ee.ImageCollection): Images to resample.
        reference_collection (ee.ImageCollection): Collection providing the target grid.

    Returns:
        ee.ImageCollection: The reprojected images.
    """
    ref_proj = reference_collection.first().projection()
    ref_scale = ref_proj.nominalScale()

    def _to_reference_grid(img: ee.Image) -> ee.Image:
        src_proj = img.projection()
        metadata = {
            "resampled": True,
            "original_scale": src_proj.nominalScale(),
            "target_scale": ref_scale,
            "original_projection": src_proj.wkt(),
            "target_projection": ref_proj.wkt(),
        }
        return img.reproject(crs=ref_proj, scale=ref_scale).set(metadata)

    return collection.map(_to_reference_grid)

In [13]:
# Predictors: gap-filled Sentinel-2 indices. Response: WaPOR ET.
independent_band = ["fitted_NDVI", "fitted_NDBI", "fitted_NDWI"]
dependent_band = ["ET"]

# s2_indices keeps the Sentinel-2 resolution (presumably used for the final
# prediction); independent_vars is the same data moved onto the WaPOR grid
# for fitting.
s2_indices = s2_harmonized_gaps_filled.select(independent_band)
independent_vars = resample_collection(
    collection=s2_indices, reference_collection=wapor_et_data
)
dependent_vars = wapor_et_data.select(dependent_band)

# print_collection_dates(independent_vars)
# print_collection_dates(dependent_vars)


In [14]:
def export_image_to_asset(
    image: ee.Image,
    asset_id: str,
    task_name: str,
    year: str,
    aoi: ee.Geometry,
    max_pixels: float = 1e13,
    scale: float = 10,
) -> ee.batch.Task:
    """
    Start an Earth Engine batch task exporting ``image`` to an asset.

    Args:
        image (ee.Image): The image to export.
        asset_id (str): Destination asset ID.
        task_name (str): Task description shown in the Earth Engine task list.
        year (str): Year label; only used in the progress message (an int works too).
        aoi (ee.Geometry): Region to export.
        max_pixels (float, optional): Value for ``maxPixels``. Defaults to 1e13.
        scale (float, optional): Export resolution in metres. Defaults to 10
            (the value that was previously hard-coded).

    Returns:
        ee.batch.Task: The export task, already started.
    """
    task = ee.batch.Export.image.toAsset(
        image=image,
        description=task_name,
        assetId=asset_id,
        region=aoi,
        scale=scale,
        maxPixels=max_pixels,
    )

    print(f"Exporting {task_name} for {year} to {asset_id}")
    task.start()

    return task


def export_downscaled_images(
    s2_indices: ee.ImageCollection,
    independent_vars: ee.ImageCollection,
    dependent_vars: ee.ImageCollection,
    aoi: ee.Geometry,
    year: str,
    scale: float,
    num_dekads: int = 36,
) -> None:
    """
    Downscale each dekadal WaPOR ET image and export it to an Earth Engine asset.

    The three collections are paired by position: image ``i`` of each is taken
    to be dekad ``i`` of ``year``. There are 3 dekads per month, starting on days
    1, 11 and 21.

    For each dekad, ``downscale`` is called with the coarse predictors, the coarse
    ET, ``scale``, the Sentinel-2 resolution predictors and ``aoi``. Its
    ``"downscaled"`` band is multiplied by 100, rounded and cast to int, so the
    asset stores hundredths of the original unit. It is exported with
    ``system:time_start`` set to the dekad's first day.

    Args:
        s2_indices (ee.ImageCollection): Sentinel-2 predictor images, one per dekad.
        independent_vars (ee.ImageCollection): The same predictors on the WaPOR grid.
        dependent_vars (ee.ImageCollection): WaPOR ET images, one per dekad.
        aoi (ee.Geometry): Export region.
        year (str): Year of the images (an int or numeric string).
        scale (float): Nominal scale of the WaPOR data, forwarded to ``downscale``
            (presumably the regression scale). It is NOT the export resolution;
            export_image_to_asset is called with its default scale.
        num_dekads (int, optional): Number of dekads to export, starting with the
            first image. Defaults to 36 (one full year).

    Returns:
        None: One export task is started per dekad.
    """
    year_int = int(year)

    # Materialise each collection as a list once; reused for every dekad.
    s2_list = s2_indices.toList(num_dekads)
    ind_list = independent_vars.toList(num_dekads)
    dep_list = dependent_vars.toList(num_dekads)

    for i in range(num_dekads):
        dekad = i % 3 + 1
        month = i // 3 + 1

        et_downscaled = downscale(
            ee.Image(ind_list.get(i)),
            ee.Image(dep_list.get(i)),
            scale,
            ee.Image(s2_list.get(i)),
            aoi,
        )
        # Store as integer hundredths; round first so the cast does not truncate
        # (truncation would bias every stored value downward).
        et_downscaled = (
            et_downscaled.select("downscaled").multiply(100).round().toInt()
        )

        dekad_start = ee.Date.fromYMD(year_int, month, dekad * 10 - 9)
        et_downscaled = et_downscaled.set("system:time_start", dekad_start.millis())

        task_name = f"Thurgau_downscaled_WaPOR_dekadal{year}-{month:02d}_D{dekad}"
        asset_id = (
            f"projects/thurgau-irrigation/assets/Thurgau/ET_WaPOR_10m_dekadal_{year}/"
            f"WaPOR_ET_downscaled_{year}-{month:02d}_D{dekad}"
        )

        export_image_to_asset(et_downscaled, asset_id, task_name, year, aoi)


# Nominal pixel size of the WaPOR grid, fetched client-side; forwarded to downscale.
scale = wapor_et_data.first().projection().nominalScale().getInfo()

export_downscaled_images(
    s2_indices=s2_indices,
    independent_vars=independent_vars,
    dependent_vars=dependent_vars,
    aoi=aoi_buffered,
    year=YEAR,
    scale=scale,
)

Exporting Thurgau_downscaled_WaPOR_dekadal2023-01_D1 for 2023 to projects/thurgau-irrigation/assets/Thurgau/ET_WaPOR_10m_dekadal_2023/WaPOR_ET_downscaled_2023-01_D1
Exporting Thurgau_downscaled_WaPOR_dekadal2023-01_D2 for 2023 to projects/thurgau-irrigation/assets/Thurgau/ET_WaPOR_10m_dekadal_2023/WaPOR_ET_downscaled_2023-01_D2
Exporting Thurgau_downscaled_WaPOR_dekadal2023-01_D3 for 2023 to projects/thurgau-irrigation/assets/Thurgau/ET_WaPOR_10m_dekadal_2023/WaPOR_ET_downscaled_2023-01_D3
Exporting Thurgau_downscaled_WaPOR_dekadal2023-02_D1 for 2023 to projects/thurgau-irrigation/assets/Thurgau/ET_WaPOR_10m_dekadal_2023/WaPOR_ET_downscaled_2023-02_D1
Exporting Thurgau_downscaled_WaPOR_dekadal2023-02_D2 for 2023 to projects/thurgau-irrigation/assets/Thurgau/ET_WaPOR_10m_dekadal_2023/WaPOR_ET_downscaled_2023-02_D2
Exporting Thurgau_downscaled_WaPOR_dekadal2023-02_D3 for 2023 to projects/thurgau-irrigation/assets/Thurgau/ET_WaPOR_10m_dekadal_2023/WaPOR_ET_downscaled_2023-02_D3
Exporting 

## Visually validating the harmonic fit


In [15]:
# import matplotlib.pyplot as plt
# import ee
# import time

# # Bands: ['B3', 'B4', 'B8', 'B11', 'B12', 'NDVI', 'NDWI', 'NDBI', 'fitted_NDVI', 'rmse_NDVI', 'fitted_NDWI', 'rmse_NDWI', 'fitted_NDBI', 'rmse_NDBI']

# # Create a plot of the NDVI and fitted NDVI over all images. x axis: image index, y axis: NDVI value and fitted NDVI value
# image_list = ee.List(s2_harmonized_gaps_filled.toList(36))


# def get_NDVI_values(image):
#     NDVI = (
#         ee.Image(image)
#         .select("NDVI")
#         .reduceRegion(
#             reducer=ee.Reducer.first(), geometry=aoi_buffered, scale=10, maxPixels=1e8
#         )
#         .values()
#         .get(0)
#     )
#     fitted_NDVI = (
#         ee.Image(image)
#         .select("fitted_NDVI")
#         .reduceRegion(
#             reducer=ee.Reducer.first(), geometry=aoi_buffered, scale=10, maxPixels=1e8
#         )
#         .values()
#         .get(0)
#     )
#     return ee.Feature(None, {"NDVI": NDVI, "fitted_NDVI": fitted_NDVI})


# features = ee.FeatureCollection(image_list.map(get_NDVI_values))


# # Function to get values in batches
# def get_values_in_batches(collection, batch_size=10):
#     all_values = []
#     count = collection.size().getInfo()
#     for i in range(0, count, batch_size):
#         batch = collection.toList(batch_size, i)
#         batch_values = ee.FeatureCollection(batch).getInfo()
#         all_values.extend(batch_values["features"])
#         time.sleep(1)  # Add a small delay to avoid hitting rate limits
#     return all_values


# # Get values in batches
# all_values = get_values_in_batches(features)

In [16]:
# # Extract NDVI and fitted NDVI values
# NDVI_values = [feature.get("properties").get("NDVI") for feature in all_values]
# fitted_NDVI_values = [feature.get("properties").get("fitted_NDVI") for feature in all_values]

# # Create the plot
# plt.figure(figsize=(12, 6))
# plt.scatter(range(len(NDVI_values)), NDVI_values, label="NDVI", color="green")
# plt.plot(
#     range(len(fitted_NDVI_values)), fitted_NDVI_values, label="Fitted NDVI", color="red"
# )

# plt.title("NDVI and fitted NDVI over all images")
# plt.xlabel("Image index")
# plt.ylabel("NDVI")
# plt.legend()
# plt.grid(True)

# plt.show()

In [17]:
# Map = geemap.Map()


# # Add the layer to the map.
# image = ee.Image(s2_harmonized_gaps_filled.toList(36).get(20))
# Map.centerObject(aoi_buffered, 13)
# rmse_params = {
#     "bands": ["rmse_NDVI"],
#     "min": 0,
#     "max": 0.5,
#     "palette": ["blue", "white", "red"],
# }
# Map.addLayer(image, rmse_params, "rmse")
# # NDVI_params = {'bands': ['NDVI'], 'min': 0, 'max': 1, 'palette': ['blue', 'white', 'green']}
# # Map.addLayer(NDVI_image, NDVI_params, 'NDVI')
# # fitted_NDVI_params = {'bands': ['fitted_NDVI'], 'min': 0, 'max': 1, 'palette': ['blue', 'white', 'green']}
# # Map.addLayer(NDVI_image, fitted_NDVI_params, 'NDVI_fitted')


# # Display the map.
# Map