In [1]:
%load_ext autoreload
%autoreload 2
from pathlib import Path

import catboost as cb
import geojson
import geopandas as gpd
import numpy as np
import openeo
import pandas as pd
from loguru import logger
from openeo_gfmap import (
    Backend,
    BackendContext,
    TemporalContext,
    FetchType,
)
from torch.utils.data import DataLoader

from presto.dataset import WorldCerealLabelled10DDataset
from scaleagdata_vito.openeo.preprocessing import scaleag_preprocessed_inputs_gfmap
from scaleagdata_vito.presto.datasets import ScaleAG10DDataset
from scaleagdata_vito.presto.presto_df import (add_labels, xr_to_df, filter_ts)
from scaleagdata_vito.presto.presto_utils import load_pretrained_model_from_url, get_encodings


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


### Load Presto pretrained models

In [2]:
# Self-supervised Presto checkpoints pretrained on WorldCereal data:
# one dekadal (10D) and one monthly (30D) variant.
presto_ss_10d_wc = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/scaleagdata/models/presto-ss-wc_10D.pt"
presto_ss_30d_wc = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/scaleagdata/models/presto-ss-wc_30D.pt"

# Both models are loaded on the CPU as pretrained (not fine-tuned) weights with
# non-strict loading; the two calls differ only in the URL and `ss_dekadal`.
model_wc_10d = load_pretrained_model_from_url(
    presto_ss_10d_wc,
    finetuned=False,
    ss_dekadal=True,
    strict=False,
    device="cpu",
)
model_wc_30d = load_pretrained_model_from_url(
    presto_ss_30d_wc,
    finetuned=False,
    ss_dekadal=False,
    strict=False,
    device="cpu",
)

[32m2024-08-22 15:46:41.791[0m | [1mINFO    [0m | [36mscaleagdata_vito.presto.presto_utils[0m:[36mload_pretrained_model_from_url[0m:[36m49[0m - [1m Initialize Presto dekadal architecture with 10d ss trained WorldCereal Presto weights...[0m
[32m2024-08-22 15:46:42.027[0m | [1mINFO    [0m | [36mscaleagdata_vito.presto.presto_utils[0m:[36mload_pretrained_model_from_url[0m:[36m61[0m - [1m Initialize Presto dekadal architecture with 30d ss trained WorldCereal Presto weights...[0m


### Few-shot learning with Presto on yield task

In [3]:
# Field polygons for the yield task: keep only the first 100 features and
# drop the "date" column before the frame is used further.
fields_geojson = "/projects/TAP/HEScaleAgData/timeseries_modelling/datasets/apr2024_AVR_subfields/data/AVR_fields_10000_100000_subfields_yield_bel_nl_roads_removed.geojson"
gdf = gpd.read_file(fields_geojson).iloc[:100].drop(columns=["date"])

# Local directory for the extraction results; created if it does not exist yet.
output_path = "/home/vito/millig/gio/data/scaleag_demo/test"
Path(output_path).mkdir(parents=True, exist_ok=True)

# Period for which the inputs are extracted (calendar year 2022).
temporal_extent = TemporalContext(start_date="2022-01-01", end_date="2022-12-31")

# Open the openEO connection, then authenticate through OIDC.
connection = openeo.connect("https://openeo.creo.vito.be/openeo/")
connection = connection.authenticate_oidc()
backend_context = BackendContext(Backend.CDSE)

Authenticated using refresh token.


In [4]:
# The fields are serialised to a GeoJSON object that serves both as the spatial
# extent of the extraction and as the geometries of the spatial aggregation.
geometry_latlon = geojson.loads(gdf.to_json())

# TODO (original author): precompute meteo and upload in bucket.
# NOTE(review): meteo is disabled here, while the yield section further down
# fails with KeyError 'METEO-precipitation_flux-ts0-100m' -- presumably a
# consequence of this flag; confirm before relying on the extracted data.
inputs = scaleag_preprocessed_inputs_gfmap(
    connection=connection,
    backend_context=backend_context,
    temporal_extent=temporal_extent,
    spatial_extent=geometry_latlon,
    fetch_type=FetchType.POINT,
    disable_meteo=True,
)

# Aggregate the preprocessed inputs over each geometry with the "mean" reducer.
cube = inputs.aggregate_spatial(reducer="mean", geometries=geometry_latlon)

Selected orbit direction: ASCENDING from max accumulated area overlap between bounds and products.


In [5]:
# Run the extraction as an openEO batch job and download its result.
#
# `create_job` has no `outputfile` parameter: extra keyword arguments are
# forwarded to the backend as output-format options. The local download
# directory is therefore not passed here; only `download_result` needs it.
job = cube.create_job(
    out_format="NetCDF",
    title="Test_ScaleAgData",
    job_options={
        "driver-memory": "4g",
        "executor-memoryOverhead": "4g",
        # NOTE(review): confirm the spelling of this option key against the
        # backend's job-option documentation.
        "soft-error": True,
    },
)

# Blocks until the job has finished (more than five minutes in the recorded run).
job.start_and_wait()

# Last expression of the cell: shows the path of the downloaded file.
job.download_result(output_path)

0:00:00 Job 'j-2408226c9d6442a9b9cecf54160428aa': send 'start'
0:00:15 Job 'j-2408226c9d6442a9b9cecf54160428aa': created (progress 0%)
0:00:32 Job 'j-2408226c9d6442a9b9cecf54160428aa': created (progress 0%)
0:00:42 Job 'j-2408226c9d6442a9b9cecf54160428aa': created (progress 0%)
0:00:57 Job 'j-2408226c9d6442a9b9cecf54160428aa': created (progress 0%)
0:01:07 Job 'j-2408226c9d6442a9b9cecf54160428aa': running (progress N/A)
0:01:22 Job 'j-2408226c9d6442a9b9cecf54160428aa': running (progress N/A)
0:01:38 Job 'j-2408226c9d6442a9b9cecf54160428aa': running (progress N/A)
0:01:57 Job 'j-2408226c9d6442a9b9cecf54160428aa': running (progress N/A)
0:02:22 Job 'j-2408226c9d6442a9b9cecf54160428aa': running (progress N/A)
0:02:55 Job 'j-2408226c9d6442a9b9cecf54160428aa': running (progress N/A)
0:03:33 Job 'j-2408226c9d6442a9b9cecf54160428aa': running (progress N/A)
0:04:20 Job 'j-2408226c9d6442a9b9cecf54160428aa': running (progress N/A)
0:05:18 Job 'j-2408226c9d6442a9b9cecf54160428aa': running (progre

PosixPath('/home/vito/millig/gio/data/scaleag_demo/test/timeseries.nc')

In [6]:
# NetCDF file downloaded into `output_path` by the batch-job cell.
dataset_file = f"{output_path}/timeseries.nc"

# GeoJSON used as the label source. Adjacent string literals inside the
# parentheses are concatenated implicitly, so no line continuations are needed.
gdf_label_file = (
    "/projects/TAP/HEScaleAgData/"
    "timeseries_modelling/datasets/apr2024_AVR_subfields/"
    "data/AVR_fields_10000_100000_subfields_yield_bel_nl_roads_removed.geojson"
)

In [30]:
# Convert the downloaded NetCDF file with `xr_to_df`, then pass the result
# through `add_labels` together with the label GeoJSON.
dataset = add_labels(xr_to_df(dataset_file), gdf_label_file)

# Early sanity check: ScaleAG10DDataset reads columns such as
# 'METEO-precipitation_flux-ts0-100m'. If none are present, the failure
# otherwise only shows up later as a KeyError inside a DataLoader worker.
if not any(str(column).startswith("METEO-") for column in dataset.columns):
    logger.warning(
        "No 'METEO-*' columns found in the extracted dataset; building batches "
        "with ScaleAG10DDataset is expected to fail with a KeyError "
        "(e.g. 'METEO-precipitation_flux-ts0-100m')."
    )

  exec(code_obj, self.user_global_ns, self.user_ns)


In [8]:
# Group-wise train/validation split: every sample sharing a `parentname` ends up
# on the same side, so no parent contributes to both sets (avoids data leakage).
np.random.seed(3)
train_frac = 0.90
df_p = dataset["parentname"].unique()
n_train_parents = int(train_frac * len(df_p))
sample_idx = np.random.choice(np.arange(len(df_p)), size=n_train_parents, replace=False)

# Boolean mask computed once and reused (negated) for the validation part.
in_train = dataset["parentname"].isin(df_p[sample_idx])
train_df = dataset.loc[in_train].reset_index(drop=True)
val_df = dataset.loc[~in_train].reset_index(drop=True)

n_val_fields = val_df["parentname"].nunique(dropna=False)
n_train_fields = train_df["parentname"].nunique(dropna=False)
print(f"Validation: number of field IDs: {n_val_fields}, number of samples: {len(val_df)}")
print(f"Training: number of field IDs: {n_train_fields}, number of samples: {len(train_df)}")

Validation: number of field IDs: 1, number of samples: 16
Training: number of field IDs: 9, number of samples: 84


In [9]:
# Wrap both splits in regression datasets that use the same target column.
target_name = "median_yield"
train_ds, val_ds = (
    ScaleAG10DDataset(split_df, target_name=target_name, task="regression")
    for split_df in (train_df, val_df)
)

# Identical loader settings for both splits; no shuffling, so batches follow
# the order of the underlying dataset.
loader_kwargs = dict(batch_size=256, shuffle=False, num_workers=4)
dl_train = DataLoader(train_ds, **loader_kwargs)
dl_val = DataLoader(val_ds, **loader_kwargs)



In [11]:
# Encode the training batches with the dekadal Presto model.
# NOTE(review): in the recorded run this call fails with
# KeyError 'METEO-precipitation_flux-ts0-100m' raised in a DataLoader worker.
model_name = "presto-ss-wc_10D"
encodings_np, targets = get_encodings(dl_train, model_wc_10d)

logger.info(f"Fitting Catboost model with {model_name} encodings")

# CatBoost regressor trained with the RMSE loss on the Presto encodings.
# NOTE(review): task_type="GPU" with devices="0:1" presumably requires CUDA
# devices to be available on the machine -- confirm for the target environment.
cbm = cb.CatBoostRegressor(
    loss_function="RMSE",
    task_type="GPU",
    devices="0:1",
    random_state=3,
    logging_level="Silent",
)
train_dataset = cb.Pool(data=encodings_np, label=targets)
cbm.fit(train_dataset)

KeyError: Caught KeyError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/vito/millig/miniconda3/envs/sadenv/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/vito/millig/miniconda3/envs/sadenv/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/vito/millig/miniconda3/envs/sadenv/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 51, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/vito/millig/gio/scaleag-vito/src/scaleagdata_vito/presto/datasets.py", line 200, in __getitem__
    eo, mask_per_token, latlon, month, target = self.row_to_arrays(
  File "/home/vito/millig/gio/scaleag-vito/src/scaleagdata_vito/presto/datasets.py", line 86, in row_to_arrays
    [float(row_d[df_val.format(t)]) for t in range(cls.NUM_TIMESTEPS)]
  File "/home/vito/millig/gio/scaleag-vito/src/scaleagdata_vito/presto/datasets.py", line 86, in <listcomp>
    [float(row_d[df_val.format(t)]) for t in range(cls.NUM_TIMESTEPS)]
KeyError: 'METEO-precipitation_flux-ts0-100m'


### Few-shot learning with Presto on Crop/no-Crop task

In [3]:
# Load the 10-day WorldCereal train and validation tables from local parquet
# files. `pd` and `WorldCerealLabelled10DDataset` are imported in the
# notebook's first cell.
wc_train_dataset = pd.read_parquet(
    "/home/vito/millig/gio/data/presto_ft/rawts-10d_train.parquet"
)
wc_val_dataset = pd.read_parquet(
    "/home/vito/millig/gio/data/presto_ft/rawts-10d_val.parquet"
)

In [4]:
# Keep a random 0.5 % of each table for the few-shot experiment. The sampling
# is seeded (same seed value as the rest of the notebook) so the subset is
# reproducible.
# NOTE: both names are overwritten, so re-running this cell alone samples again
# from the already reduced frames; re-run the loading cell first.
wc_train_dataset = wc_train_dataset.sample(frac=0.005, random_state=3)
wc_val_dataset = wc_val_dataset.sample(frac=0.005, random_state=3)

print(len(wc_train_dataset), len(wc_val_dataset))

3732 933


In [5]:
# Wrap the sampled WorldCereal tables in labelled dekadal datasets.
wc_train_ds = WorldCerealLabelled10DDataset(wc_train_dataset)
wc_val_ds = WorldCerealLabelled10DDataset(wc_val_dataset)

# One loader per split with identical settings (no shuffling, 4 workers).
wc_dl_train, wc_dl_val = (
    DataLoader(ds, batch_size=256, shuffle=False, num_workers=4)
    for ds in (wc_train_ds, wc_val_ds)
)



In [6]:
# Run the WorldCereal training loader through the dekadal Presto model; the
# returned pair is unpacked as encodings and targets.
encodings_np, targets = get_encodings(wc_dl_train, model_wc_10d)

# Label used in the log message when the classifier is fitted.
model_name = "presto-ss-wc_10D"

In [8]:
logger.info(f"Fitting Catboost model with {model_name} encodings")

# CatBoost classifier fitted on the Presto encodings of the WorldCereal samples.
# NOTE(review): task_type="GPU" with devices="0:1" presumably requires CUDA
# devices to be available on the machine -- confirm for the target environment.
cbm = cb.CatBoostClassifier(
    task_type="GPU",
    devices="0:1",
    random_state=3,
    logging_level="Silent",
)
train_dataset = cb.Pool(data=encodings_np, label=targets)
cbm.fit(train_dataset)

[32m2024-08-22 15:48:10.930[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m1[0m - [1mFitting Catboost model with presto-ss-wc_10D encodings[0m



<catboost.core.CatBoostClassifier at 0x7f523bfa1a50>