## Preparing data

In [None]:
import xarray as xr
import geopandas as gpd
from dask.distributed import Client, LocalCluster
from datetime import datetime, timedelta
from functools import partial

In [None]:
# Time window and year of the input data to prepare.
start_time = "2014-1-31"
end_time = "2014-02-10"
year = 2014

# Root folder of the raw inputs for the selected year.
parent_in_path = f"/gpfs/work2/0/ttse0619/qianqian/global_data_Qianqian/1input_data/{year}global"

# Dataset name -> file path (or glob pattern) of its raw input files.
data_paths = dict(
    era5land=f"{parent_in_path}/era5land/*.nc",
    lai=f"{parent_in_path}/lai_v2/*.nc",
    ssm=f"{parent_in_path}/ssm/GlobalGSSM11km2014_20240214.tif",
    co2=f"{parent_in_path}/co2/CAMS_CO2_2003-2020.nc",
    landcover=f"{parent_in_path}/igbp/landcover10km_global.nc",
    vcmax=f"{parent_in_path}/vcmax/TROPOMI_Vmax_Tg_mean10km_global.nc",
    canopyheight=f"{parent_in_path}/canopy_height/canopy_height_11kmEurope20230921_10km.nc",
)

# Folder that receives the prepared zarr stores.
parent_out_path = "/scratch-shared/falidoost"

In [None]:
# Read the Europe boundary shape file and take its overall bounding box.
eu_shape_file = "/gpfs/work2/0/ttse0619/qianqian/global_data_Qianqian/1input_data/EuropeBoundary.shp"
gdf = gpd.read_file(eu_shape_file)
# The conversion loop further down uses bbox[0]/bbox[2] as the longitude range
# and bbox[1]/bbox[3] as the latitude range.
bbox = gdf.total_bounds
bbox

In [None]:
def era5_preprocess(ds):
    """Return ``ds`` with its longitude coordinate wrapped from [0, 360] to [-180, 180]."""
    shifted = (ds.longitude + 180) % 360
    wrapped = shifted - 180
    return ds.assign_coords(longitude=wrapped)

def co2_preprocess(ds, start_time, end_time):
    """Return the part of ``ds`` whose ``time`` labels lie in ``slice(start_time, end_time)``."""
    time_window = slice(start_time, end_time)
    return ds.sel(time=time_window)

# Bind the configured time window so that the result takes only the dataset
# argument; it is passed as the `preprocess` callback when the co2 files are opened.
co2_partial_func = partial(co2_preprocess, start_time=start_time, end_time=end_time)

def fix_coords(ds):
    """Rename dimensions and their variables to ``time``/``longitude``/``latitude``.

    ``band`` becomes ``time``. The spatial pair ``x``/``y`` or, failing that,
    ``lon``/``lat`` becomes ``longitude``/``latitude``. A dataset with none of
    these dimensions is returned unchanged.
    """
    band_to_time = {'band': 'time'}
    if 'band' in ds.dims:
        ds = ds.rename_dims(band_to_time).rename_vars(band_to_time)

    # Only the first spatial naming scheme that is fully present gets renamed.
    for x_name, y_name in (('x', 'y'), ('lon', 'lat')):
        if x_name in ds.dims and y_name in ds.dims:
            spatial = {x_name: 'longitude', y_name: 'latitude'}
            ds = ds.rename_dims(spatial).rename_vars(spatial)
            break
    return ds

In [None]:
# Start a local Dask cluster with 28 single-threaded workers and connect a client to it.
cluster = LocalCluster(n_workers=28, threads_per_worker=1)
client = Client(cluster)

In [None]:
chunks = 'auto'

# Open every raw dataset, harmonise its coordinates, cut it to the bounding box
# and time window, and write the result to its own zarr store.
for data_path in data_paths:

    # Exactly one loader must run per dataset, hence the if/elif/else chain.
    # With two separate `if` statements the era5land dataset would fall into the
    # final `else` as well and be re-opened with `fix_coords`, which discards the
    # longitude wrap applied by `era5_preprocess`.
    if data_path == "era5land":
        ds = xr.open_mfdataset(data_paths[data_path], preprocess=era5_preprocess, chunks=chunks)

    elif data_path == "co2":
        ds = xr.open_mfdataset(data_paths[data_path], preprocess=co2_partial_func, chunks=chunks)
        # Convert the longitude coordinates from [0, 360] to [-180, 180]
        ds = ds.assign_coords(longitude=(((ds.longitude + 180) % 360) - 180))

    else:
        ds = xr.open_mfdataset(data_paths[data_path], preprocess=fix_coords, chunks=chunks)

    # Normalise the time coordinate to datetimes.
    if ds.time.size == 1:
        # A single time step is stamped with the start of the time window.
        ds['time'] = [datetime.strptime(start_time, "%Y-%m-%d")]
    elif ds.time.dtype == 'int64':
        # Convert day of year to datetime
        ds['time'] = [datetime(year, 1, 1) + timedelta(int(day) - 1) for day in ds.time.values]

    # Sort so that the label-based slices below run over increasing coordinates.
    ds = ds.sortby(['longitude', 'latitude'])
    masked_ds = ds.sel(longitude=slice(bbox[0], bbox[2]), latitude=slice(bbox[1], bbox[3]), time=slice(start_time, end_time))
    masked_ds = masked_ds.chunk(chunks=chunks)

    # save to zarr
    out_path = f"{parent_out_path}/{data_path}_{start_time}_{end_time}.zarr"
    masked_ds.to_zarr(out_path, mode='w')
    print(f"{out_path} is saved")
    print("=======================================")

In [None]:
# Release the Dask resources started for the data preparation step.
client.shutdown()

## Interpolations

In [None]:
import xarray as xr
from dask.distributed import Client, LocalCluster

In [None]:
# Time window and year; they must match the names of the zarr stores read below.
start_time = "2014-1-31"
end_time = "2014-02-10"
year = 2014

# Folder holding the prepared per-dataset zarr stores.
parent_in_path = "/scratch-shared/falidoost"

# Dataset name -> path of its prepared zarr store.
data_paths = dict(
    era5land=f"{parent_in_path}/era5land_{start_time}_{end_time}.zarr",
    lai=f"{parent_in_path}/lai_{start_time}_{end_time}.zarr",
    ssm=f"{parent_in_path}/ssm_{start_time}_{end_time}.zarr",
    co2=f"{parent_in_path}/co2_{start_time}_{end_time}.zarr",
    landcover=f"{parent_in_path}/landcover_{start_time}_{end_time}.zarr",
    vcmax=f"{parent_in_path}/vcmax_{start_time}_{end_time}.zarr",
    canopyheight=f"{parent_in_path}/canopyheight_{start_time}_{end_time}.zarr",
)

# Folder that receives the combined zarr store.
parent_out_path = "/scratch-shared/falidoost"

In [None]:
def interpolation(ds, other):
    """Interpolate ``ds`` onto the time axis and the longitude/latitude grid of ``other``.

    Time is handled first with nearest-neighbour interpolation and
    ``fill_value="extrapolate"``; space is handled second with linear
    interpolation. The interpolated object is returned.

    NOTE(review): the same linear spatial interpolation is applied to every
    variable routed through this function, including the land-cover classes,
    which are later matched against a lookup table by exact equality —
    confirm that this is intended for categorical data.
    """
    # in time
    on_target_times = ds.interp(
        coords={"time": other.time},
        method='nearest',
        kwargs={"fill_value": "extrapolate"},
    )

    # in space
    target_grid = {"longitude": other.longitude, "latitude": other.latitude}
    return on_target_times.interp(coords=target_grid, method='linear')

# Dataset name -> name of the data variable to take from that dataset.
variable_names = {
    "lai": "LAI",
    "ssm": "band_data",
    "co2": "co2",
    "canopyheight": "__xarray_dataarray_variable__",
    "vcmax": "__xarray_dataarray_variable__",
    "landcover": "lccs_class",
}

# Chunk layout shared by every dataset: whole time axis, 100 x 100 in space.
target_chunks = {"time": -1, "latitude": 100, "longitude": 100}

# era5land provides the reference time axis and grid for all other datasets.
era5land = xr.open_zarr(data_paths["era5land"]).chunk(target_chunks)

all_data = era5land.copy()

for name, var_name in variable_names.items():
    ds = xr.open_zarr(data_paths[name]).chunk(target_chunks)
    ds_interpolated = interpolation(ds, era5land)
    # The interpolated variable is stored under the dataset name.
    all_data[name] = ds_interpolated[var_name]

all_data = all_data.chunk(target_chunks)
all_data

In [None]:
# Start a local Dask cluster with 25 single-threaded workers and connect a client to it.
cluster = LocalCluster(n_workers=25, threads_per_worker=1)
client = Client(cluster)

In [None]:
# save to zarr
out_path = f"{parent_out_path}/all_data_{start_time}_{end_time}.zarr"

# Every data variable gets the same on-disk chunk shape: the full time length
# followed by 100 and 100.
zarr_chunks = (all_data.sizes["time"], 100, 100)
encoding = {var: {'chunks': zarr_chunks} for var in all_data.data_vars}

all_data.to_zarr(out_path, mode='w', encoding=encoding)
print(f"{out_path} is saved")

In [None]:
# Release the Dask resources started for the interpolation step.
client.shutdown()

## Variable derivation

In [None]:
import xarray as xr
import numpy as np
import pandas as pd
import dask.array as da
from dask.distributed import Client, LocalCluster
from PyStemmusScope import variable_conversion as vc
from dask_ml.preprocessing import OneHotEncoder

In [None]:
# Time window; it must match the name of the combined zarr store read below.
start_time = "2014-1-31"
end_time = "2014-02-10"

# Folder holding the combined zarr store and the land-cover lookup tables.
parent_in_path = "/scratch-shared/falidoost"

# Input name -> path.
data_paths = dict(
    all_data=f"{parent_in_path}/all_data_{start_time}_{end_time}.zarr",
    igbp_table=f"{parent_in_path}/lccs_to_igbp_table.csv",
    igbp_class=f"{parent_in_path}/IGBP11unique.csv",
)

# Folder that receives the model input zarr store.
parent_out_path = "/scratch-shared/falidoost"

In [None]:
# era5_land variables
# Open the combined dataset (era5land plus the interpolated variables).
all_data = xr.open_zarr(data_paths["all_data"])
all_data

In [None]:
def era5land_accumulated_vars(ds, input_name, output_name, scale_factor):
    """Derive a per-time-step variable from an accumulated one and store it in ``ds``.

    Parameters
    ----------
    ds
        Container indexed by variable name with a ``time`` dimension
        (presumably an xarray Dataset — confirm). It is modified in place.
    input_name
        Name of the accumulated variable read from ``ds``.
    output_name
        Name under which the derived variable is stored in ``ds``.
    scale_factor
        Divisor applied to the input before differencing.

    Returns
    -------
    The same ``ds`` object, now containing ``output_name``.
    """
    input_da = ds[input_name] / scale_factor
    # Difference between consecutive time steps.
    output_da = input_da.diff("time")
    # Every 24th difference is replaced by the undifferenced input value.
    # NOTE(review): the stride of 24 presumably assumes hourly data whose first
    # time step is 00:00 — confirm against the input.
    output_da[0::24] = input_da[1::24]  # accumulation starts at t01 instead of t00

    # A NaN entry carrying the coordinates of the first time step is appended,
    # because no difference exists for that step.
    t00 = xr.DataArray(np.nan, coords=input_da.isel(time=0).coords) # assign first t00 to none
    output_da = xr.concat([output_da, t00], dim='time')
    # NOTE(review): the NaN entry sits at the end of output_da; this presumably
    # relies on the assignment aligning on the time labels of ds — confirm.
    ds[output_name] = output_da
    return ds

In [None]:
# Accumulated ERA5-Land fields -> per-step model inputs: (input, output, scale factor).
for source_name, target_name, scale in (
    ("ssrd", "Rin", 3600),
    ("strd", "Rli", 3600),
    ("tp", "Precip_msr", 0.001),  # to mm
):
    all_data = era5land_accumulated_vars(all_data, source_name, target_name, scale)

dewpoint_celsius = all_data["d2m"] - 273.15
wind_speed_squared = all_data["u10"] ** 2 + all_data["v10"] ** 2

all_data["p"] = all_data["sp"] / 100  # Pa -> hPa
all_data["Ta"] = all_data["t2m"] - 273.15  # K -> degC
all_data["ea"] = vc.calculate_es(dewpoint_celsius)*10 # *10 is for kPa -> hPa
all_data["u"] = wind_speed_squared ** 0.5  # magnitude of the (u10, v10) vector
all_data["ssm"] = all_data["ssm"] / 1000

all_data = all_data.chunk(time=-1, latitude=100, longitude=100)
all_data

In [None]:
# we need to convert landcover to IGBP
# lookup tables
igbp_table = pd.read_csv(data_paths["igbp_table"])

# Unique IGBP classes; they select the encoder output columns that are kept.
igbp_class = pd.read_csv(data_paths["igbp_class"])['0'].unique()

# define one hot encoding for IGBP using dask-ml functions
encoder = OneHotEncoder(
    sparse_output=False,
)

# Unsorted categories are not yet supported by dask-ml
# Sort the 1-D values first and reshape afterwards: np.sort works along the
# last axis by default, so sorting the (n, 1) column array is a no-op.
igbp_stemmus_scope = np.sort(igbp_table["IGBP_STEMMUS_SCOPE"].to_numpy()).reshape(-1, 1)

encoder = encoder.fit(igbp_stemmus_scope)

# lccs_class value -> entry of the first remaining column of the table.
# NOTE(review): presumably that column is IGBP_STEMMUS_SCOPE — confirm against the CSV.
lookup_table = igbp_table.set_index("lccs_class").T.to_dict('records')[0]

def map_landcover_to_igbp(landcover_block):
    """Translate land-cover codes into the values of the module-level ``lookup_table``.

    Cells whose code does not equal any key of the table keep the initial
    value ``"No data"``. The result has the same shape as ``landcover_block``.
    """
    # Start from an array filled with "No data".
    igbp_block = da.full_like(landcover_block, fill_value="No data", dtype="U7")

    # Overwrite the cells of each known code with the value from the table.
    for lccs_code in lookup_table:
        is_code = landcover_block == lccs_code
        igbp_block = da.where(is_code, lookup_table[lccs_code], igbp_block)

    return igbp_block
        

def landcover_to_igbp(ds, landcover_var_name, encoder):
    """Add one-hot encoded IGBP variables derived from a land-cover variable to ``ds``.

    Parameters
    ----------
    ds
        Dataset holding the land-cover variable. It is modified in place.
    landcover_var_name
        Name of the land-cover variable in ``ds``.
    encoder
        Fitted one-hot encoder; its ``categories_[0]`` must contain every
        entry of the module-level ``igbp_class``.

    Returns
    -------
    The same ``ds`` object with the variables ``IGBP_veg_long1`` ...
    ``IGBP_veg_long<n>`` added, one per entry of ``igbp_class``. A cell is 1
    where it belongs to the class and NaN elsewhere.

    Raises
    ------
    IndexError
        If an entry of ``igbp_class`` is not among the encoder categories.
    """
    landcover_da = ds[landcover_var_name]
    # Label the new variables with the land-cover variable's own dimensions
    # instead of a hard-coded tuple, so a different dimension order cannot
    # mislabel the axes of the reshaped arrays.
    landcover_dims = landcover_da.dims

    # Replace NaN values with "No data" or 255 in the table
    landcover = da.where(da.isnan(landcover_da), 255, landcover_da)

    igbp = map_landcover_to_igbp(landcover)
    # The encoder expects a single feature column.
    igbp_reshaped = igbp.reshape(-1, 1)

    transformed = encoder.transform(igbp_reshaped)

    # Select the columns that correspond to the categories in igbp_class
    indices = [np.where(encoder.categories_[0] == category)[0][0] for category in igbp_class]
    transformed = transformed[:, indices]

    # Replace zeros with np.nan
    transformed = da.where(transformed == 0, np.nan, transformed)

    # Add each column of the transformed array as a new variable in the dataset
    for i in range(transformed.shape[1]):
        ds[f"IGBP_veg_long{i+1}"] = (landcover_dims, transformed[:, i].reshape(igbp.shape))

    return ds

In [None]:
# landcover_to_igbp assigns the new variables on the dataset it receives and
# returns that same dataset, so `ds` and `all_data` hold the same variables.
ds = landcover_to_igbp(all_data, "landcover", encoder)
ds

In [None]:
# Start a local Dask cluster with 4 single-threaded workers and connect a client to it.
cluster = LocalCluster(n_workers=4, threads_per_worker=1)
client = Client(cluster)

In [None]:
# save to zarr
out_path = f"{parent_out_path}/model_input_{start_time}_{end_time}.zarr"
# mode='w' is passed so the store is written from scratch.
ds.to_zarr(out_path, mode='w')
print(f"{out_path} is saved")

In [None]:
# Release the Dask resources started for the variable derivation step.
client.shutdown()

## Model prediction

In [None]:
import pickle
import xarray as xr
import numpy as np
import pandas as pd
import dask.array as da
from dask.distributed import Client, LocalCluster

In [None]:
# Time window; it must match the name of the model input zarr store.
start_time, end_time = "2014-1-31", "2014-02-10"

# Inputs are read from and predictions written to the same folder.
parent_in_path = parent_out_path = "/scratch-shared/falidoost"

In [None]:
# load model
# NOTE(security): pickle.load can execute arbitrary code while unpickling.
# Only load model files from a trusted source; verify that this location
# cannot be written to by untrusted users.
path_model = f"{parent_in_path}/hourly_multi7_depth20_min1219.pkl"
with open(path_model, 'rb') as f:
    model = pickle.load(f)
model

In [None]:
# rename some variables
model_input = xr.open_zarr(f"{parent_in_path}/model_input_{start_time}_{end_time}.zarr")
rename_vars = {"co2": "CO2", "lai": "LAI", "canopyheight": "hc", "ssm": "SSM", "vcmax": "Vcmo"}
ds = model_input.rename(rename_vars)

# Model features, in the order in which they are selected from the dataset.
input_vars = [
    'Rin', 'Rli', 'p', 'Ta', 'ea', 'u', 'CO2', 'LAI', 'Vcmo','hc', 'Precip_msr',
    'SSM',  *[f'IGBP_veg_long{i}' for i in range(1, 12)]
]

# select input data
input_ds = ds[input_vars]

# define output template
output_vars = ['LEtot','Htot','Rntot','Gtot', 'Actot','SIF685', 'SIF740']
output_temp = xr.Dataset()

# Name the dimensions explicitly and derive the shape from the same tuple.
# The order of a Dataset's `dims` mapping is not guaranteed to be
# (time, latitude, longitude), so it must not be paired with a fixed shape.
output_dims = ("time", "latitude", "longitude")
ds_shape = tuple(input_ds.sizes[dim] for dim in output_dims)

for var in output_vars:
    output_temp[var] = xr.DataArray(
        name=var,
        data=da.zeros(ds_shape),
        dims=output_dims,
        coords=input_ds.coords,
    )
output_temp = output_temp.chunk(input_ds.chunksizes) # the same chunks as the input

In [None]:
def predictFlux(input_ds, model, output_vars):
    """Run ``model`` on every grid cell and time step of ``input_ds``.

    Parameters
    ----------
    input_ds
        Dataset with the dimensions ``time``, ``latitude`` and ``longitude``
        whose data variables are the model features.
    model
        Object with a ``predict`` method that takes the feature table and
        returns a 2-D array with one column per entry of ``output_vars``.
    output_vars
        Names for the predicted variables, in the column order of the
        prediction.

    Returns
    -------
    Dataset with the coordinates of ``input_ds`` and one variable per entry of
    ``output_vars``, each with dimensions (time, latitude, longitude). Rows
    that contained a missing feature are NaN in every output variable.
    """
    # One explicit dimension order is used for flattening, reshaping and
    # labelling, so predictions land on the cells their features came from
    # regardless of the dataset's internal dimension order. It is also the
    # order of the output template built for xr.map_blocks.
    dim_order = ["time", "latitude", "longitude"]

    df_features = (
        input_ds.to_dataframe(dim_order=dim_order)
        .reset_index()
        .drop(columns=dim_order)
    )

    # Convert the nan value as 0 for the calculation
    invalid_index = df_features.isnull().any(axis=1)
    df_features.loc[invalid_index] = 0

    LEH = model.predict(df_features)
    LEH[invalid_index.to_numpy()] = np.nan # convert the original nan values to nan back

    output_ds = xr.Dataset(coords=input_ds.coords)
    ds_shape = tuple(input_ds.sizes[dim] for dim in dim_order)

    for i, name in enumerate(output_vars):
        output_ds[name] = (tuple(dim_order), LEH[:, i].reshape(ds_shape))

    return output_ds

In [None]:
# result
# Apply predictFlux to the blocks of input_ds; `template` describes the
# variables, dimensions and chunks of the expected result.
LEH = xr.map_blocks(
    predictFlux,
    input_ds,
    kwargs={
        "model": model, 
        "output_vars": output_vars, 
    },
    template=output_temp,
)
LEH

In [None]:
# Start a local Dask cluster with 4 single-threaded workers and connect a client to it.
cluster = LocalCluster(n_workers=4, threads_per_worker=1)
client = Client(cluster)

In [None]:
# save to zarr
out_path = f"{parent_out_path}/predicted_{start_time}_{end_time}.zarr"
# mode='w' is passed so the store is written from scratch.
LEH.to_zarr(out_path, mode='w')
print(f"{out_path} is saved")

In [None]:
# Release the Dask resources started for the prediction step.
client.shutdown()