In [None]:
%load_ext autoreload
%autoreload 2
import catboost as cb
import os
import openeo
from loguru import logger
import geopandas as gpd
import geojson
from pathlib import Path
import numpy as np
from torch.utils.data import DataLoader
from scaleagdata_vito.presto.presto_utils import evaluate
from scaleagdata_vito.presto.datasets import ScaleAG10DDataset
from scaleagdata_vito.presto.presto_df import (add_labels, xr_to_df, filter_ts)
from scaleagdata_vito.presto.presto_utils import load_pretrained_model_from_url, get_encodings
from scaleagdata_vito.openeo.preprocessing import run_openeo_extraction_job
from openeo_gfmap import (
    Backend,
    BackendContext,
    TemporalContext,
    FetchType,
)
from openeo_gfmap.manager.job_splitters import split_job_hex


### Load Presto pretrained models

In [None]:
# Dekadal (10-day) and monthly (30-day) Presto checkpoints, trained in
# self-supervised mode on WorldCereal data.
presto_ss_10d_wc = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/scaleagdata/models/presto-ss-wc_10D.pt"
presto_ss_30d_wc = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/scaleagdata/models/presto-ss-wc_30D.pt"

# Both checkpoints are loaded with the same options (finetuned=False,
# strict=False, CPU); only the `ss_dekadal` flag differs between them.
shared_load_kwargs = {"finetuned": False, "strict": False, "device": "cpu"}
model_wc_10d = load_pretrained_model_from_url(
    presto_ss_10d_wc, ss_dekadal=True, **shared_load_kwargs
)
model_wc_30d = load_pretrained_model_from_url(
    presto_ss_30d_wc, ss_dekadal=False, **shared_load_kwargs
)

### Few-shot learning with Presto on yield task

#### Fetch data from OpenEO

In [None]:
# Labelled polygons for which OpenEO data is extracted: only the first 100 rows
# of the file are kept and the "date" column is dropped.
polygons_file = "/projects/TAP/HEScaleAgData/timeseries_modelling/datasets/apr2024_AVR_subfields/data/AVR_fields_10000_100000_subfields_yield_bel_nl_roads_removed.geojson"
gdf = gpd.read_file(polygons_file).iloc[:100].drop(columns=["date"])

# Parameters of the OpenEO extraction job. Note that building this dict already
# calls openeo.connect(...).authenticate_oidc().
job_params = {
    "connection": openeo.connect("https://openeo.creo.vito.be/openeo/").authenticate_oidc(),
    "backend_context": BackendContext(Backend.CDSE),
    "temporal_extent": TemporalContext(
        start_date="2022-01-01",
        end_date="2022-12-31",
    ),
    "fetch_type": FetchType.POINT,
    "disable_meteo": False,
    "out_format": "NetCDF",
    "title": "ScaleAGData_demo",
    # When True, the next cell splits `gdf` with split_job_hex and runs one job per subset.
    "split_dataset": False,
    "output_path": "/home/vito/millig/gio/data/scaleag_demo/test1",
}

In [None]:
# Make sure the output folder exists. mkdir(parents=True, exist_ok=True) is
# already a no-op for an existing directory, so no separate existence check is
# needed (and a check-then-create sequence would be racy anyway).
output_path = Path(job_params["output_path"])
output_path.mkdir(parents=True, exist_ok=True)

if job_params["split_dataset"]:
    # One extraction job per subset returned by split_job_hex, each writing
    # into its own cube_<i> sub-folder of the output path.
    datasets = split_job_hex(gdf)
    for i, sub_gdf in enumerate(datasets):
        logger.info(f"Extracting OpenEO data for subset {i}")
        output_path_frame = output_path / f"cube_{i}"
        output_path_frame.mkdir(parents=True, exist_ok=True)
        run_openeo_extraction_job(sub_gdf, str(output_path_frame), job_params)
else:
    # Single extraction job covering the whole GeoDataFrame.
    logger.info("Extracting OpenEO data for dataset")
    run_openeo_extraction_job(gdf, str(output_path), job_params)

#### Test Presto on yield task 

In [3]:
# A previously saved NetCDF file is used here. To run on the output of the
# extraction step instead, set: dataset_file = f"{output_path}/timeseries.nc"
dataset_file = "/home/vito/millig/projects/TAP/HEScaleAgData/data/AVR_subfields/avr_subfields_10d.nc"

# GeoJSON file passed to add_labels; the adjacent string literals below are
# concatenated into one path.
gdf_label_file = (
    "/projects/TAP/HEScaleAgData/"
    "timeseries_modelling/datasets/apr2024_AVR_subfields/"
    "data/AVR_fields_10000_100000_subfields_yield_bel_nl_roads_removed.geojson"
)

In [None]:
# Convert the NetCDF file to a dataframe, then attach the labels read from the
# GeoJSON label file.
timeseries_df = xr_to_df(dataset_file)
dataset = add_labels(timeseries_df, gdf_label_file)

# Names of all columns whose name contains "METEO".
meteo = [col for col in dataset.columns if "METEO" in col]

In [None]:
# Train/validation split done per 'parentname' value rather than per row, so
# rows sharing a parentname never land on both sides (avoids data leakage).
np.random.seed(3)
train_frac = 0.90
df_p = dataset['parentname'].unique()
n_train_parents = int(train_frac*len(df_p))
sample_idx = np.random.choice(range(len(df_p)), size=n_train_parents, replace=False)

# Row mask computed once: True for rows whose parentname was drawn for training.
in_train = dataset['parentname'].isin(df_p[sample_idx])
train_df = dataset[in_train].reset_index(drop=True)
val_df = dataset[~in_train].reset_index(drop=True)

print(f"Validation: number of field IDs: {len(val_df['parentname'].unique())}, number of samples: {len(val_df)}")
print(f"Training: number of field IDs: {len(train_df['parentname'].unique())}, number of samples: {len(train_df)}")

In [None]:
# Regression datasets on the "median_yield" column for both splits.
target_name = "median_yield"
train_ds = ScaleAG10DDataset(train_df, target_name=target_name, task="regression")
val_ds = ScaleAG10DDataset(val_df, target_name=target_name, task="regression")

# Both loaders share the same settings; shuffle is off for train and validation alike.
loader_kwargs = {"batch_size": 256, "shuffle": False, "num_workers": 4}
dl_train = DataLoader(train_ds, **loader_kwargs)
dl_val = DataLoader(val_ds, **loader_kwargs)

In [None]:
model_name = "presto-ss-wc_10D"

# Pass the training loader through the 10-day Presto model; the returned pair
# is unpacked as the encodings and their targets.
encodings_np, targets = get_encodings(dl_train, model_wc_10d)

logger.info(f"Fitting Catboost model with {model_name} encodings")

# CatBoost regressor (RMSE loss) fitted on the encodings; configured with
# task_type="GPU" and devices="0:1".
regressor_params = {
    "random_state": 3,
    "task_type": "GPU",
    "devices": "0:1",
    "logging_level": "Silent",
    "loss_function": "RMSE",
}
cbm = cb.CatBoostRegressor(**regressor_params)
train_dataset = cb.Pool(encodings_np, targets)
cbm.fit(train_dataset)

In [None]:
# Evaluate the Presto model + fitted CatBoost regressor on the validation loader.
# NOTE: this rebinds `targets`, which held the training targets until now.
# NOTE(review): up_val / low_val are presumably upper/lower bounds applied to
# the yield values inside `evaluate` — confirm against its implementation.
eval_options = {"task": "regression", "up_val": 120000, "low_val": 10000}
metrics, preds, targets = evaluate(model_wc_10d, cbm, dl_val, **eval_options)

# Last expression of the cell, so the metrics are displayed as its output.
metrics

### Few-shot learning with Presto on Crop/no-Crop task

In [3]:
import pandas as pd
from presto.dataset import WorldCerealLabelled10DDataset
from sklearn.metrics import classification_report

# WorldCereal train / validation tables, read from local parquet files.
wc_train_file = "/home/vito/millig/gio/data/presto_ft/rawts-10d_train.parquet"
wc_val_file = "/home/vito/millig/gio/data/presto_ft/rawts-10d_val.parquet"

wc_train_dataset = pd.read_parquet(wc_train_file)
wc_val_dataset = pd.read_parquet(wc_val_file)

In [None]:
# Keep a 0.5% random subsample of each table. A fixed random_state makes the
# selected rows reproducible across kernel restarts (3 matches the seed used
# elsewhere in this notebook).
# NOTE: this cell overwrites its own inputs, so re-running it without first
# re-running the cell that reads the parquet files subsamples the already
# reduced tables again.
wc_train_dataset = wc_train_dataset.sample(frac=0.005, random_state=3)
wc_val_dataset = wc_val_dataset.sample(frac=0.005, random_state=3)

print(len(wc_train_dataset), len(wc_val_dataset))

In [None]:
# Wrap the subsampled WorldCereal tables as labelled 10-day datasets.
wc_train_ds = WorldCerealLabelled10DDataset(wc_train_dataset)
wc_val_ds = WorldCerealLabelled10DDataset(wc_val_dataset)

# Identical loader settings for both splits; shuffle is off for both.
wc_loader_kwargs = {"batch_size": 256, "shuffle": False, "num_workers": 4}
wc_dl_train = DataLoader(wc_train_ds, **wc_loader_kwargs)
wc_dl_val = DataLoader(wc_val_ds, **wc_loader_kwargs)

In [13]:
# Label used in the log message when the CatBoost model is fitted.
model_name = "presto-ss-wc_10D"
# Pass the WorldCereal training loader through the 10-day Presto model; the
# returned pair is unpacked as the encodings and their targets.
# NOTE(review): presumably one encoding row per training sample — confirm
# against get_encodings.
encodings_train, target_train = get_encodings(wc_dl_train, model_wc_10d)

In [None]:
logger.info(f"Fitting Catboost model with {model_name} encodings")

# CatBoost classifier fitted on the training encodings; configured with
# task_type="GPU" and devices="0:1". This rebinds `cbm` and `train_dataset`,
# which the yield section used for its regressor and training pool.
classifier_params = {
    "random_state": 3,
    "task_type": "GPU",
    "devices": "0:1",
    "logging_level": "Silent",
}
cbm = cb.CatBoostClassifier(**classifier_params)
train_dataset = cb.Pool(encodings_train, target_train)
cbm.fit(train_dataset)

In [None]:
# Encode the validation loader with the same Presto model, predict with the
# fitted CatBoost classifier and print the classification report.
encodings_test, targets_test = get_encodings(wc_dl_val, model_wc_10d)
preds_test = cbm.predict(encodings_test)
report = classification_report(targets_test, preds_test)
print(report)