In [1]:
import ee
from IPython.core.display_functions import display

try:
    ee.Authenticate()
except Exception as e:
    print(f"Error authenticating Earth Engine: {e}. Please ensure you have Earth Engine access.")

# try:
#     ee.Initialize(project="rwanda-climate-alerts")
# except Exception as e:
#     print(f"Error initializing Earth Engine: {e}. Please ensure you are authenticated.")

import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import pandas as pd

from src.fetch_datasets import fetch_all

chirps, era5_temp, soil_moist, ndvi, dem, slope = fetch_all()

dataset_dict = {
    "chirps": {
        "dataset": chirps,
        "list of bands": ["precipitation"],
        "title": "Precipitation in ",
        "xlabel": "Date",
        "ylabel": "Precipitation [mm]",
        "ylim_min": -0,
        "ylim_max": 100
    },
    "era5_temp": {
        "dataset": era5_temp,
        "list of bands": ["temperature_2m"],
        "title": "Temperature in ",
        "xlabel": "Date",
        "ylabel": "Temperature [C]",
        "ylim_min": 10,
        "ylim_max": 30
    },
    "soil_moist": {
        "dataset": soil_moist,
        "list of bands": ["volumetric_soil_water_layer_1"],
        "title": "Soil moisture in ",
        "xlabel": "Date",
        "ylabel": "Moisture [?]",
        "ylim_min": -0,
        "ylim_max": 1
    },
    "ndvi": {
        "dataset": ndvi,
        "list of bands": ["NDVI"],
        "title": "NDVI in ",
        "xlabel": "Date",
        "ylabel": "NDVI [?]",
        "ylim_min": -0,
        "ylim_max": 10000
    }
}

# Fetch Time series
def get_time_series(image_collection, district_name, start_date, end_date, scale):
    district = ee.FeatureCollection("FAO/GAUL/2015/level2") \
                    .filter(ee.Filter.eq("ADM0_NAME", "Rwanda")) \
                    .filter(ee.Filter.eq("ADM2_NAME", district_name)) \
                    .geometry()

    district_time_series = image_collection \
                            .filterDate(start_date, end_date) \
                            .getRegion(district, scale=scale) \
                            .getInfo()

    return district_time_series


# Convert to pandas DataFrame
def ee_array_to_df(arr, list_of_bands):
    """Transforms client-side ee.Image.getRegion array to pandas.DataFrame."""
    df = pd.DataFrame(arr)

    # Rearrange the header.
    headers = df.iloc[0]
    df = pd.DataFrame(df.values[1:], columns=headers)

    # Remove rows without data inside.
    df = df[['longitude', 'latitude', 'time', *list_of_bands]].dropna()

    # Convert the data to numeric values.
    for band in list_of_bands:
        df[band] = pd.to_numeric(df[band], errors='coerce')

    # Convert the time field into a datetime.
    df['datetime'] = pd.to_datetime(df['time'], unit='ms')

    # Keep the columns of interest.
    df = df[['time','datetime',  *list_of_bands]]

    return df


def t_kelvin_to_celsius(t_kelvin):
    """Converts Kelvin units to degrees Celsius."""
    t_celsius =  t_kelvin - 273.15
    return t_celsius


def plot_dataset(dataframe, district, dataset_name, dataset_info):
    fig, ax = plt.subplots(figsize=(14, 6))

    dataset = dataset_info[dataset_name]

    list_of_bands = dataset["list of bands"]
    title = dataset["title"] + district
    xlabel = dataset["xlabel"]
    ylabel = dataset["ylabel"]
    ylim_min = dataset["ylim_min"]
    ylim_max = dataset["ylim_max"]

    if dataset_name == "era5_temp":
        dataframe["temperature_2m"] = dataframe["temperature_2m"].apply(t_kelvin_to_celsius)

    ax.scatter(dataframe["datetime"], dataframe[list_of_bands[0]],
               color='gray', linewidth=1, alpha=0.7, label=f"{district} (trend)")

    ax.xaxis.set_major_formatter(mdates.DateFormatter('%b %Y'))  # e.g., "Jun 2025"
    ax.xaxis.set_major_locator(mdates.MonthLocator(interval=1))
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right")

    ax.set_title(title, fontsize=16)
    ax.set_xlabel(xlabel, fontsize=14)
    ax.set_ylabel(ylabel, fontsize=14)
    ax.set_ylim(ylim_min, ylim_max)
    ax.grid(lw=0.5, ls='--', alpha=0.7)
    ax.legend(fontsize=14, loc='lower right')

    return fig, ax


def get_dataset_info(district, dataset_name, dataset_info):
    dataset = dataset_info[dataset_name]

    plot_params = {
        "district": district,
        "list_of_bands": dataset["list of bands"],
        "title": dataset["title"] + district,
        "xlabel": dataset["xlabel"],
        "ylabel": dataset["ylabel"],
        "ylim_min": dataset["ylim_min"],
        "ylim_max": dataset["ylim_max"]
    }

    return plot_params

def plot_dataset_test(dataframe, dataset_name, ax, dataset_parameters=None):
    district = None
    if dataset_name == "era5_temp":
        dataframe["temperature_2m"] = dataframe["temperature_2m"].apply(t_kelvin_to_celsius)

    if dataset_parameters:
        district = dataset_parameters["district"]
        ax.xaxis.set_major_formatter(mdates.DateFormatter('%b %Y'))  # e.g., "Jun 2025"
        ax.xaxis.set_major_locator(mdates.MonthLocator(interval=1))
        plt.setp(ax.get_xticklabels(), rotation=45, ha="right")

        ax.set_title(dataset_parameters["title"], fontsize=16)
        ax.set_xlabel(dataset_parameters["xlabel"], fontsize=14)
        ax.set_ylabel(dataset_parameters["ylabel"], fontsize=14)
        ax.set_ylim(dataset_parameters["ylim_min"], dataset_parameters["ylim_max"])
        ax.grid(lw=0.5, ls='--', alpha=0.7)
        ax.legend(fontsize=14, loc='lower right')

    ax.scatter(dataframe["datetime"], dataframe[dataset_parameters["list_of_bands"]],
               color='gray', linewidth=1, alpha=0.7, label=district)


def get_daily_average(dataframe):
    dataframe = dataframe.drop("time", axis=1)
    df_combined = dataframe.groupby("datetime").agg(
        precipitation = ("precipitation", "mean"),
    )
    return df_combined

In [2]:
district_time_series = get_time_series(
                        dataset_dict["chirps"]["dataset"],
                        "Bugesera",
                        "2024-01-01",
                        "2024-12-31",
                        1000)

df = ee_array_to_df(district_time_series, dataset_dict["chirps"]["list of bands"])
daily_average_df = get_daily_average(df)

# print(df[:10])
# print(daily_average_df[:10])

# print(len(daily_average_df))

# Plot with matplotlib
# fig, ax = plot_dataset(df, "Bugesera", "chirps", dataset_dict)
# plt.show()

In [3]:
# print(f"{len(df)}\n")
display(df[:10])
print(df.dtypes)

Unnamed: 0,time,datetime,precipitation
0,1704067200000,2024-01-01,1.779287
1,1704153600000,2024-01-02,3.558574
2,1704240000000,2024-01-03,1.779287
3,1704326400000,2024-01-04,1.779287
4,1704412800000,2024-01-05,0.0
5,1704499200000,2024-01-06,0.0
6,1704585600000,2024-01-07,8.745286
7,1704672000000,2024-01-08,0.0
8,1704758400000,2024-01-09,2.915095
9,1704844800000,2024-01-10,0.0


0
time                     object
datetime         datetime64[ns]
precipitation           float64
dtype: object


In [4]:
display(daily_average_df[:10])
print(df.dtypes)

Unnamed: 0_level_0,precipitation
datetime,Unnamed: 1_level_1
2024-01-01,0.129794
2024-01-02,4.207765
2024-01-03,2.760971
2024-01-04,0.175958
2024-01-05,2.228546
2024-01-06,0.0
2024-01-07,6.434671
2024-01-08,0.0
2024-01-09,3.800104
2024-01-10,0.498575


0
time                     object
datetime         datetime64[ns]
precipitation           float64
dtype: object
