# Training XGBoost to Emulate ec-land

In this notebook we take some example Zarr data (similar to that created by this project's other functionality) and train an ML emulator of the ec-land land surface model. We train on climatological, meteorological and previous-model-state features to predict the 6-hourly model state update.

In [None]:
import xgboost as xgb
import numpy as np
from sklearn.metrics import mean_squared_error

from dataset import EcDataset

## Settings

In [None]:
# Everything a reader might tune lives here.
path = "/data/ecland_i6aj_o400_2010_2022_6h_euro.zarr"  # Zarr store with the example ec-land data
n_estimators = 50  # number of boosted trees passed to XGBRegressor
spatial_encoding = True
# Passed to EcDataset in the Datasets cell; it was previously never defined, so
# the notebook failed on a fresh kernel.
# NOTE(review): value chosen to mirror spatial_encoding — confirm the intended setting.
temporal_encoding = True

## Datasets

In [None]:
ds_train = EcDataset(path, "2018", "2019", spatial_encoding, temporal_encoding)
ds_val = EcDataset(path, "2020", "2020", spatial_encoding, temporal_encoding)

# Normalise BOTH splits with the *training* statistics. The model must see
# validation data on the same scale it was trained on, and using the
# validation split's own means/stdevs would leak information from it.
# The [None, :] adds a leading axis so the per-column stats broadcast over rows.
feat_means = ds_train.feat_means[None, :]
feat_stdevs = ds_train.feat_stdevs[None, :]
target_means = ds_train.target_means[None, :]
target_stdevs = ds_train.target_stdevs[None, :]

train_feats = EcDataset.transform(ds_train.feats, feat_means, feat_stdevs)
train_targets = EcDataset.transform(ds_train.targets, target_means, target_stdevs)
val_feats = EcDataset.transform(ds_val.feats, feat_means, feat_stdevs)
val_targets = EcDataset.transform(ds_val.targets, target_means, target_stdevs)

# Both splits must expose the same feature/target columns for train/val to be comparable.
assert train_feats.shape[1] == val_feats.shape[1], "feature count differs between splits"
assert train_targets.shape[1] == val_targets.shape[1], "target count differs between splits"

print(train_feats.shape, train_targets.shape)
print(val_feats.shape, val_targets.shape)

## Model training with XGBoost

Now that we have our "features" and "targets", we can train XGBoost to predict the model increments.

In [None]:
def mse(y_pred: np.ndarray, y_true: np.ndarray) -> float:
    """Mean squared error between predictions and targets, over all elements.

    Both inputs are flattened first, so multi-dimensional arrays (e.g.
    multi-output targets) are reduced to a single scalar.

    Args:
        y_pred: Model predictions.
        y_true: Ground-truth values; must contain as many elements as y_pred.

    Returns:
        The MSE as a Python float.
    """
    # sklearn's documented signature is (y_true, y_pred). MSE is symmetric,
    # but passing them in that order keeps the call honest.
    return float(mean_squared_error(np.ravel(y_true), np.ravel(y_pred)))

# `mse` above is an evaluation metric (it returns a scalar), not a training
# objective: a custom XGBoost objective must return (gradient, hessian).
# The built-in squared-error objective already minimises MSE.
model = xgb.XGBRegressor(
    n_estimators=n_estimators,
    tree_method="hist",
    device="cuda",
    objective="reg:squarederror",
)
fname = "./test.json"

print("Fitting XGB model...")

# Fit on the whole training set in one go, reporting loss on the validation split.
model.fit(train_feats, train_targets, eval_set=[(val_feats, val_targets)])
model.save_model(fname)

# Validation error is in normalised target units.
y_val_pred = model.predict(val_feats)
val_mse = mse(y_val_pred, val_targets)
print(f"Validation MSE = {val_mse}")

print("Finished training")