# Few-Shot Learning with Presto

### Notebook Overview 

1) Short introduction to Foundation Models and Presto
2) Definition of Few-Shot learning
3) Apply Presto to perform Few-Shot learning on a regression and a classification task

### 1) Foundation Models

A Foundation Model is a model trained on large and diverse unlabeled datasets to learn general patterns and features of the data. Thanks to its strong generalization capabilities, such a model can be adapted for a wide range of applications that use similar types of input data.

**Presto** (**P**retrained **Re**mote **S**ensing **T**ransf**o**rmer) is a foundation model trained on a large, unlabeled dataset of Sentinel-2, Sentinel-1, meteorological, and topography pixel time series. It can capture long-range relationships across time and sensor dimensions, improving the signal-to-noise ratio and providing a concise, informative representation of the inputs.
In this project, we use the Presto version developed in collaboration with [WorldCereal](https://github.com/WorldCereal/presto-worldcereal/).

Originally trained on monthly composites, Presto has been refined to be able to ingest dekadal data and to be fine-tuned for regression and classification tasks.
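
A dekad splits each calendar month into three periods (days 1-10, 11-20, and 21 to month end), giving 36 timesteps per year instead of 12. The sketch below illustrates that convention. The helper names are hypothetical and not part of Presto or `scaleagdata_vito`, whose exact compositing rules may differ.

```python
from datetime import date


def dekad_of_year(d: date) -> int:
    """Return the 0-based dekad index (0..35) of a date.

    Assumed convention: days 1-10, 11-20, and 21 to month end.
    """
    dekad_in_month = min((d.day - 1) // 10, 2)  # day 31 stays in the third dekad
    return (d.month - 1) * 3 + dekad_in_month


def count_dekads(start: date, end: date) -> int:
    """Number of dekads touched by the inclusive range [start, end]."""
    if start.year != end.year:
        raise ValueError("this sketch only handles ranges within a single year")
    return dekad_of_year(end) - dekad_of_year(start) + 1


print(dekad_of_year(date(2022, 1, 31)))                    # 2
print(count_dekads(date(2022, 1, 1), date(2022, 12, 31)))  # 36
print(count_dekads(date(2022, 4, 1), date(2022, 10, 31)))  # 21
```

Under this convention a full calendar year holds 36 dekads, and an April-to-October window holds 21.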

### 2) Few-Shot Learning

Few-shot learning aims to develop models that can learn from a small number of labeled instances and still generalize well to new, unseen examples.

Given a dataset with only a few annotated examples, we can fine-tune a pretrained foundation model to either directly handle the downstream task or generate compressed representations of the inputs, which can then be used to train a machine learning model for the downstream task.
The figure below provides an overview of the latter scenario.

<div style="text-align: center;">
    <img src="../images/ScaleAG_pipeline_overview_presto_ml.jpg" alt="Overview of a Foundation Model used to produce embeddings which can be fed as training examples to downstream models for different tasks and applications." width="700" />
    <p><em>Overview of a Foundation Model used to produce embeddings which can be fed as training examples to downstream models for different tasks and applications.</em></p>
</div>
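
The second scenario (a frozen encoder feeding a separate downstream model) can be sketched in a few lines. The encoder here is a toy stand-in with fixed random weights, not Presto, and the downstream model is a nearest-centroid classifier chosen only for brevity. All names and data are made up for illustration.

```python
import math
import random

random.seed(0)
IN_DIM, EMB_DIM = 6, 3

# Frozen "pretrained" weights: they are never updated in this sketch.
W_FROZEN = [[random.gauss(0, 1) for _ in range(EMB_DIM)] for _ in range(IN_DIM)]


def encode(x):
    """Map a raw input of length IN_DIM to an embedding of length EMB_DIM."""
    return [
        math.tanh(sum(xi * W_FROZEN[i][j] for i, xi in enumerate(x)))
        for j in range(EMB_DIM)
    ]


def fit_centroids(embeddings, labels):
    """Downstream model: one mean embedding (centroid) per class."""
    grouped = {}
    for z, y in zip(embeddings, labels):
        grouped.setdefault(y, []).append(z)
    return {
        y: [sum(col) / len(zs) for col in zip(*zs)]
        for y, zs in grouped.items()
    }


def predict(centroids, z):
    """Assign the class whose centroid is closest to the embedding."""
    return min(centroids, key=lambda y: math.dist(centroids[y], z))


# Few-shot setting: only three labeled examples per class.
x_few = [[random.gauss(mu, 1.0) for _ in range(IN_DIM)] for mu in (-1, -1, -1, 1, 1, 1)]
y_few = ["class_a"] * 3 + ["class_b"] * 3

centroids = fit_centroids([encode(x) for x in x_few], y_few)
x_new = [random.gauss(1, 1.0) for _ in range(IN_DIM)]
print(predict(centroids, encode(x_new)))
```

Only the small downstream model is fitted on the labeled examples; the encoder stays untouched, which is what keeps the number of required labels low.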

### 3) Implementing Few-Shot learning with Presto

In [1]:
%load_ext autoreload
%autoreload 2
import catboost as cb
from loguru import logger
from pathlib import Path
import sys
sys.path.append("/home/vito/millig/gio/prometheo/")
from prometheo.datasets.scaleag import ScaleAgDataset # fix installation
from prometheo import finetune
from prometheo.finetune import Hyperparams
from prometheo.models.presto.wrapper import PretrainedPrestoWrapper, load_pretrained
import torch
from torch import nn
from torch.utils.data import DataLoader
from scaleagdata_vito.openeo.extract_sample_scaleag import generate_extraction_job_command
from scaleagdata_vito.presto.utils import evaluate_finetuned_model, evaluate_downstream_model, get_encodings
from scaleagdata_vito.presto.presto_df import load_dataset



#### Fetch data from OpenEO

To set up the extraction job, fill in the fields below. `generate_extraction_job_command` then produces the command to run in a terminal to start the extraction.

```python
job_params = dict(
    output_folder=..., 
    input_df=...,
    start_date=...,
    end_date=...,
    unique_id_column=...,
    composite_window=..., # "month" or "dekad" are supported. Default is "dekad"
)

```
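
The template leaves every value open, so a small check before launching a job can catch typos early. The sketch below is a hypothetical helper (it is not part of `scaleagdata_vito`) that verifies the keys listed above, the two supported `composite_window` values, and that the dates parse as ISO `YYYY-MM-DD` strings. The paths in the usage example are placeholders.

```python
from datetime import date

SUPPORTED_WINDOWS = {"month", "dekad"}  # the two values named in the template
REQUIRED_KEYS = (
    "output_folder",
    "input_df",
    "start_date",
    "end_date",
    "unique_id_column",
)


def validate_job_params(params: dict) -> dict:
    """Return a checked copy of params, raising ValueError on obvious mistakes."""
    missing = [key for key in REQUIRED_KEYS if key not in params]
    if missing:
        raise ValueError(f"missing job parameters: {missing}")

    checked = dict(params)
    checked.setdefault("composite_window", "dekad")  # documented default
    if checked["composite_window"] not in SUPPORTED_WINDOWS:
        raise ValueError(
            f"unsupported composite_window: {checked['composite_window']!r}"
        )

    # date.fromisoformat raises ValueError on malformed dates
    start = date.fromisoformat(checked["start_date"])
    end = date.fromisoformat(checked["end_date"])
    if start > end:
        raise ValueError("start_date must not be after end_date")
    return checked


# Placeholder values, for illustration only
params = validate_job_params(
    dict(
        output_folder="extractions/",
        input_df="fields.geojson",
        start_date="2022-01-01",
        end_date="2022-12-31",
        unique_id_column="fieldname",
    )
)
print(params["composite_window"])  # dekad
```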

In [3]:
job_params = dict(
    output_folder="/projects/TAP/HEScaleAgData/data/AVR_subfields/extractions_31012025/",
    input_df="/home/vito/millig/gio/data/scaleag_extractions/AVR_fields_10000_100000_subfields_yield_bel_nl_roads_removed.geojson",
    start_date="2022-01-01",
    end_date="2022-12-31",
    unique_id_column="fieldname",
    composite_window="dekad",
)
generate_extraction_job_command(job_params)

python scaleag-vito/scripts/extractions/extract.py -output_folder /projects/TAP/HEScaleAgData/data/AVR_subfields/extractions_31012025/ -input_df /home/vito/millig/gio/data/scaleag_extractions/AVR_fields_10000_100000_subfields_yield_bel_nl_roads_removed.geojson --start_date 2022-01-01 --end_date 2022-12-31 --unique_id_column fieldname --composite_window dekad


#### Regression task: yield estimation 

The first task is potato yield estimation. The data cover fields in Belgium and the Netherlands during the growing season.
To test the generalization capabilities of the different models and combinations, we limit correlation between the training and validation sets: the split below is made on parent fields (`parentname`), so all subfields of a parent field fall in the same set.

In [2]:
# load extracted dataset
window_of_interest = ["2022-04-01", "2022-10-31"]
df = load_dataset(
    files_root_dir="/projects/TAP/HEScaleAgData/data/AVR_subfields/extractions_31012025/",
    window_of_interest=window_of_interest,
    use_valid_time=False,
    required_min_timesteps=36,
    buffer_window=8,
    no_data_value=65535,
    composite_window="dekad",
)


In [5]:
import random
# split in train and val
# df_sample = df.sample(frac=0.1, random_state=42)
sampling_frac = 0.8
random.seed(3)
parentname = df.parentname.unique()
parentname_train = random.sample(list(parentname), int(len(parentname)*sampling_frac))
df_sample = df.copy()
df_train = df_sample[df_sample.parentname.isin(parentname_train)]
df_val = df_sample[~df_sample.parentname.isin(parentname_train)]

print(f"Train size: {len(df_train)}")
print(f"Val size: {len(df_val)}")

Train size: 11180
Val size: 2685
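
The cell above samples 80% of the unique `parentname` values and assigns rows by membership, so no parent field contributes rows to both sets. The sketch below reproduces that idea with the standard library only and checks the no-overlap property. The function name and the group labels are made up for illustration.

```python
import random


def split_by_group(groups, train_frac=0.8, seed=3):
    """Return (train_idx, val_idx) so that each group lands in exactly one set."""
    unique = sorted(set(groups))  # sorted so the result is reproducible
    rng = random.Random(seed)
    train_groups = set(rng.sample(unique, int(len(unique) * train_frac)))
    train_idx = [i for i, g in enumerate(groups) if g in train_groups]
    val_idx = [i for i, g in enumerate(groups) if g not in train_groups]
    return train_idx, val_idx


# Hypothetical parent fields, each with one or more subfields
groups = ["A", "A", "B", "B", "B", "C", "D", "D", "E", "E"]
train_idx, val_idx = split_by_group(groups)

train_groups = {groups[i] for i in train_idx}
val_groups = {groups[i] for i in val_idx}
print(train_groups.isdisjoint(val_groups))  # True
```

Splitting on rows instead would let subfields of the same parent field appear on both sides, which tends to make validation scores look better than they would on genuinely new fields.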


In [None]:
# initialize datasets
num_timesteps = df.available_timesteps.max()

train_ds = ScaleAgDataset(
    df_train,
    num_timesteps=num_timesteps,
    task_type="regression",
    target_name="median_yield",
    compositing_window="dekad",
    upper_bound=120000,
    lower_bound=10000,
)
val_ds = ScaleAgDataset(
    df_val,
    num_timesteps=num_timesteps,
    task_type="regression",
    target_name="median_yield",
    compositing_window="dekad",
    upper_bound=120000,
    lower_bound=10000,
)

2025-02-07 10:58:10.843 | INFO     | prometheo.datasets.scaleag:set_num_outputs:133 - Setting number of outputs to 1 for regression task.
2025-02-07 10:58:10.852 | INFO     | prometheo.datasets.scaleag:set_num_outputs:133 - Setting number of outputs to 1 for regression task.
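
Both datasets receive `lower_bound=10000` and `upper_bound=120000`, and the baseline further down calls `revert_to_original_units` with the same bounds. One plausible reading is min-max scaling of the yield target to the range [0, 1]. The sketch below illustrates that assumption only: the function names are hypothetical, and the actual `ScaleAgDataset` behaviour (for example, how out-of-range values are handled) may differ.

```python
LOWER, UPPER = 10_000, 120_000  # bounds passed to the datasets above


def normalize_target(y, lower=LOWER, upper=UPPER):
    """Assumed min-max scaling of a target to [0, 1]."""
    return (y - lower) / (upper - lower)


def denormalize_target(y_norm, lower=LOWER, upper=UPPER):
    """Inverse of normalize_target: back to the original units."""
    return y_norm * (upper - lower) + lower


print(normalize_target(65_000))  # 0.5
print(denormalize_target(0.5))   # 65000.0
```

If this reading holds, an error of 0.01 in normalized units corresponds to 1,100 in the original yield units, which is why metrics are reported after reverting the scaling.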


#### Finetuning

In [None]:
# Construct the model with finetuning head
pretrained_model_path = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/scaleagdata/models/presto-ss-wc_10D.pt"
model = PretrainedPrestoWrapper(
    num_outputs=1,
    regression=True,
)
model = load_pretrained(model, pretrained_model_path, strict=False)

# Reduce epochs for testing purposes
hyperparams = Hyperparams(max_epochs=50, batch_size=256, patience=1, num_workers=2)
output_dir = Path("/home/vito/millig/gio/presto_exp/prometheo_exp")

# set loss depending on the task type
if train_ds.task_type == "regression":
    loss_fn = nn.MSELoss()
elif train_ds.task_type == "binary":
    loss_fn = nn.BCEWithLogitsLoss()
else:
    loss_fn = nn.CrossEntropyLoss()

finetuned_model = finetune.run_finetuning(
    model,
    train_ds,
    val_ds,
    experiment_name="presto-ss-wc-10D-ft-dek",
    output_dir=output_dir,
    loss_fn=loss_fn,
    hyperparams=hyperparams,
)

2025-02-07 10:58:18 | INFO     | prometheo.utils - Logging setup complete. Logging to: /home/vito/millig/gio/presto_exp/prometheo_exp/logs/presto-ss-wc-10D-ft-dek.log and console.
2025-02-07 10:58:18 | INFO     | prometheo.finetune - Using output dir: /home/vito/millig/gio/presto_exp/prometheo_exp


Train metric: 0.006, Val metric: 0.005, Best Val Loss: 0.005 (improved):  62%|██████▏   | 31/50 [2:05:00<1:22:02, 259.07s/it]
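
`Hyperparams` takes a `patience` argument. Assuming it follows the usual early-stopping convention (stop once the validation loss has failed to improve for `patience` consecutive epochs), the logic can be sketched as below. This is a generic illustration with made-up loss values, not prometheo's implementation.

```python
def early_stopping_epoch(val_losses, patience):
    """Return the 0-based epoch at which training would stop.

    Training stops after `patience` consecutive epochs without a new best
    validation loss; if that never happens, the last epoch is returned.
    """
    best = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return len(val_losses) - 1


losses = [0.020, 0.012, 0.013, 0.011, 0.011, 0.012]  # made-up values
print(early_stopping_epoch(losses, patience=1))  # 2
print(early_stopping_epoch(losses, patience=2))  # 5
```

Under this convention a small `patience` makes training sensitive to a single noisy epoch, so it is worth raising when validation loss fluctuates.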

#### Evaluate using end-to-end finetuned Presto

In [6]:
finetuned_model = PretrainedPrestoWrapper(num_outputs=1, regression=True)
finetuned_model = load_pretrained(
    finetuned_model,
    "/home/vito/millig/gio/presto_exp/prometheo_exp/presto-ss-wc-10D-ft-dek.pt",
)

evaluate_finetuned_model(finetuned_model, val_ds, num_workers=2, batch_size=32)

2025-02-06 13:11:50.096 | INFO     | scaleagdata_vito.presto.utils:evaluate_finetuned_model:91 - Evaluating the finetuned model on regression task


{'RMSE': 9361.403876379496,
 'R2_score': -0.011486582941066636,
 'explained_var_score': -0.011409980070501424,
 'MAPE': 0.12428021738426595}
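
An `R2_score` slightly below zero, as reported above, means the predictions explain no more variance than a constant prediction equal to the mean of the validation targets. The sketch below computes three of the metrics from their textbook definitions on made-up numbers to show that boundary case. It is an illustration, not the code behind `evaluate_finetuned_model`.

```python
import math


def rmse(y_true, y_pred):
    """Root mean squared error."""
    return math.sqrt(sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true))


def r2(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot


def mape(y_true, y_pred):
    """Mean absolute percentage error, as a fraction."""
    return sum(abs(t - p) / abs(t) for t, p in zip(y_true, y_pred)) / len(y_true)


y_true = [40_000.0, 50_000.0, 60_000.0]  # made-up yields
mean_pred = [50_000.0] * 3               # always predict the mean
shifted_pred = [55_000.0] * 3            # constant, but biased

print(r2(y_true, mean_pred))         # 0.0
print(r2(y_true, shifted_pred) < 0)  # True
```

A constant prediction at the mean scores exactly 0, and any systematic bias on top of that pushes the score negative, even when RMSE and MAPE look moderate.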

#### Train downstream model on Presto encodings and evaluate

In [11]:
notebook_device = "GPU" if torch.cuda.is_available() else None
cbm = cb.CatBoostRegressor(
    random_state=3,
    task_type=notebook_device,
    logging_level="Silent",
    loss_function="RMSE",
)
logger.info("Computing Presto encodings")
train_dl = DataLoader(train_ds, batch_size=256, shuffle=True, num_workers=2)
train_encodings, train_targets = get_encodings(train_dl, finetuned_model)
logger.info("Fitting Catboost model on Presto encodings")
train_dataset = cb.Pool(train_encodings, train_targets)
cbm.fit(train_dataset)

2025-02-04 10:37:21 | INFO     | __main__ - Computing Presto encodings
2025-02-04 10:39:14 | INFO     | __main__ - Fitting Catboost model on Presto encodings





<catboost.core.CatBoostRegressor at 0x7f142883a7d0>

In [12]:
evaluate_downstream_model(finetuned_model, cbm, val_ds, num_workers=2, batch_size=32)

2025-02-04 10:39:46 | INFO     | scaleagdata_vito.presto.utils - Evaluating the finetuned model on regression task


{'RMSE': 9857.511666240649,
 'R2_score': 0.10692120073691003,
 'explained_var_score': 0.10769062053643519,
 'MAPE': 0.0780827248908144}

In [12]:
from scaleagdata_vito.presto.presto_utils_demo import revert_to_original_units
from scaleagdata_vito.demo.utils import prepare_data_for_cb
notebook_device = "GPU" if torch.cuda.is_available() else None
raw_cbm = cb.CatBoostRegressor(
    random_state=3,
    task_type=notebook_device,
    logging_level="Silent",
    loss_function="RMSE",
)

train_x, train_y = prepare_data_for_cb(
    df_train,
    "median_yield",
    lower_bound=10000,
    upper_bound=120000,
    num_time_steps=num_timesteps,
)
val_x, val_y = prepare_data_for_cb(
    df_val,
    "median_yield",
    lower_bound=10000,
    upper_bound=120000,
    num_time_steps=num_timesteps,
)

train_pool = cb.Pool(train_x, train_y)
raw_cbm.fit(train_pool)




<catboost.core.CatBoostRegressor at 0x7f47fbc69de0>

In [13]:
import numpy as np
from sklearn.metrics import mean_squared_error, r2_score, explained_variance_score, mean_absolute_percentage_error

preds = raw_cbm.predict(val_x)
targets = revert_to_original_units(
    val_y, lower_bound=10000, upper_bound=120000
)
preds = revert_to_original_units(
    preds, lower_bound=10000, upper_bound=120000
)
metrics = {
    "RMSE": float(np.sqrt(mean_squared_error(targets, preds))),
    "R2_score": float(r2_score(targets, preds)),
    "explained_var_score": float(explained_variance_score(targets, preds)),
    "MAPE": float(mean_absolute_percentage_error(targets, preds)),
}

In [14]:
metrics

{'RMSE': 13811.234190851072,
 'R2_score': 0.6525257053276933,
 'explained_var_score': 0.652657666821981,
 'MAPE': 0.15791034200182627}

#### Classification task: crop/no-crop

In [2]:
from presto.utils import prep_dataframe, process_parquet
import pandas as pd
from typing import Optional, List, Tuple, Union
from glob import glob
from tqdm import tqdm
import numpy as np


def split_df(
    df: pd.DataFrame,
    val_sample_ids: Optional[List[str]] = None,
    val_countries_iso3: Optional[List[str]] = None,
    val_years: Optional[List[int]] = None,
    val_size: Optional[float] = None,
    train_only_samples: Optional[List[str]] = None,
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    if val_size is not None:
        assert (
            (val_countries_iso3 is None)
            and (val_years is None)
            and (val_sample_ids is None)
        )
        # Shuffle once, then slice; avoids calling np.split on a DataFrame
        shuffled = df.sample(frac=1, random_state=42)
        n_val = int(val_size * len(df))
        val, train = shuffled.iloc[:n_val], shuffled.iloc[n_val:]
        logger.info(f"Using {len(train)} train and {len(val)} val samples")
        return train, val
    if val_sample_ids is not None:
        assert (val_countries_iso3 is None) and (val_years is None)
        is_val = df.sample_id.isin(val_sample_ids)
        is_train = ~df.sample_id.isin(val_sample_ids)
    elif val_countries_iso3 is not None:
        assert (val_sample_ids is None) and (val_years is None)
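        # NOTE: join_with_world_df is not defined or imported in this notebook;
        # it must be provided before this branch can run.
        df = join_with_world_df(df)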
        for country in val_countries_iso3:
            assert df.iso3.str.contains(
                country
            ).any(), f"Tried removing {country} but it is not in the dataframe"
        if train_only_samples is not None:
            is_val = df.iso3.isin(val_countries_iso3) & ~df.sample_id.isin(
                train_only_samples
            )
        else:
            is_val = df.iso3.isin(val_countries_iso3)
        is_train = ~df.iso3.isin(val_countries_iso3)
    elif val_years is not None:
        df["end_date_ts"] = pd.to_datetime(df.end_date)
        if train_only_samples is not None:
            is_val = df.end_date_ts.dt.year.isin(val_years) & ~df.sample_id.isin(
                train_only_samples
            )
        else:
            is_val = df.end_date_ts.dt.year.isin(val_years)
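        is_train = ~df.end_date_ts.dt.year.isin(val_years)
    else:
        raise ValueError(
            "Provide one of val_size, val_sample_ids, val_countries_iso3 or val_years"
        )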

    logger.info(
        f"Using {len(is_val) - sum(is_val)} train and {sum(is_val)} val samples"
    )

    return df[is_train], df[is_val]


def load_df(
    parquet_file: Union[Path, str],
) -> pd.DataFrame:
    logger.info("Reading dataset")
    files = sorted(glob(f"{parquet_file}/**/*.parquet"))

    df_list = []
    for f in tqdm(files[:50]):  # only the first 50 files are read
        _data = pd.read_parquet(f, engine="fastparquet")
        _ref_id = f.split("/")[-2].split("=")[-1]
        _data["ref_id"] = _ref_id
        _data_pivot = process_parquet(_data)
        _data_pivot.reset_index(inplace=True)
        df_list.append(_data_pivot)
    df = pd.concat(df_list)
    df = df.fillna(65535)
    del df_list
    return df

def prepare_training_df(df, val_samples_file=None):
    df = prep_dataframe(df, filter_function=None, dekadal=False).reset_index()

    if val_samples_file is not None:
        logger.info(f"Controlled train/test split based on: {val_samples_file}")
        val_samples_df = pd.read_csv(val_samples_file)
        train_df, test_df = split_df(
            df, val_sample_ids=val_samples_df.sample_id.tolist()
        )
    else:
        logger.info("Random train/test split ...")
        train_df, test_df = split_df(df, val_size=0.2)
    train_df, val_df = split_df(train_df, val_size=0.2)

    return train_df, val_df, test_df

In [3]:
parquet_file = "/home/vito/millig/projects/worldcereal/data/worldcereal_training_data_monthly.parquet/worldcereal_training_data.parquet"
files = sorted(glob(f"{parquet_file}/**/*.parquet"))
len(files)

121

In [4]:
# from scaleagdata_vito.presto.presto_df import process_parquet

# Training parameters
pretrained_model_path = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/models/PhaseII/presto-ss-wc_longparquet_random-window-cut_no-time-token_epoch96_corrected-mask.pt"
parquet_file = "/home/vito/millig/projects/worldcereal/data/worldcereal_training_data_monthly.parquet/worldcereal_training_data.parquet"
val_samples_file = "/home/vito/millig/gio/worldcereal-classification/scripts/training/finetuning/cropland_random_generalization_test_split_samples.csv"

epochs = 100
batch_size = 512
patience = 5
num_workers = 16

# ------------------------------------------
# Load the dataset (the train/val/test split happens in a later cell)
df = load_df(parquet_file)

# val_df.columns

2025-02-10 10:33:38.660 | INFO     | __main__:load_df:66 - Reading dataset


Dropping 5 faulty samples.
100%|██████████| 50/50 [00:33<00:00,  1.48it/s]


In [8]:
df_sample = df.sample(frac=0.01,random_state=42)
df_sample.size

1982928

In [9]:
train_df, val_df, test_df = prepare_training_df(df_sample, val_samples_file)
train_df.available_timesteps.max(), test_df.available_timesteps.max(), val_df.available_timesteps.max()

2025-02-10 10:37:38.648 | INFO     | __main__:prepare_training_df:86 - Controlled train/test split based on: /home/vito/millig/gio/worldcereal-classification/scripts/training/finetuning/cropland_random_generalization_test_split_samples.csv
2025-02-10 10:37:39.031 | INFO     | __main__:split_df:56 - Using 3895 train and 653 val samples
2025-02-10 10:37:39.070 | INFO     | __main__:split_df:26 - Using 3116 train and 779 val samples


(29, 25, 29)

In [10]:
# initialize datasets
num_timesteps = train_df.available_timesteps.max()

train_ds = ScaleAgDataset(
    train_df,
    num_timesteps=num_timesteps,
    task_type="binary",
    target_name="LANDCOVER_LABEL",
    positive_labels=[10, 11, 12, 13],
    compositing_window="month",
)
val_ds = ScaleAgDataset(
    val_df,
    num_timesteps=num_timesteps,
    task_type="binary",
    target_name="LANDCOVER_LABEL",
    positive_labels=[10, 11, 12, 13],
    compositing_window="month",
)

2025-02-10 10:37:44.869 | INFO     | prometheo.datasets.scaleag:set_num_outputs:133 - Setting number of outputs to 1 for binary task.
2025-02-10 10:37:44.876 | INFO     | prometheo.datasets.scaleag:set_num_outputs:133 - Setting number of outputs to 1 for binary task.
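
`positive_labels=[10, 11, 12, 13]` marks which `LANDCOVER_LABEL` codes count as the positive (crop) class. Presumably every other code maps to the negative class. The sketch below illustrates that mapping under this assumption; the helper name and the example codes are hypothetical, and `ScaleAgDataset` may implement it differently.

```python
def to_binary_labels(labels, positive_labels):
    """Map each label to 1 if it is in positive_labels, else 0."""
    positive = set(positive_labels)
    return [1 if label in positive else 0 for label in labels]


landcover = [10, 20, 11, 50, 13, 30]  # example codes, for illustration only
print(to_binary_labels(landcover, [10, 11, 12, 13]))  # [1, 0, 1, 0, 1, 0]
```

A single output with `BCEWithLogitsLoss` is then enough, which matches the log message about one output for the binary task.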


In [11]:
from torch import nn
from torch.optim import AdamW, lr_scheduler
from torch.utils.data import DataLoader
from prometheo.finetune import Hyperparams, run_finetuning
from prometheo.models.presto import param_groups_lrd

# set loss depending on the task type
if train_ds.task_type == "regression":
    loss_fn = nn.MSELoss()
elif train_ds.task_type == "binary":
    loss_fn = nn.BCEWithLogitsLoss()
else:
    loss_fn = nn.CrossEntropyLoss()

model = PretrainedPrestoWrapper(
    num_outputs=1,
    regression=False,
    pretrained_model_path=pretrained_model_path,
)

# Set the parameters
hyperparams = Hyperparams(
    max_epochs=epochs,
    batch_size=batch_size,
    patience=patience,
    num_workers=num_workers,
)
parameters = param_groups_lrd(model)
optimizer = AdamW(parameters, lr=hyperparams.lr)
scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.99)

# Run the finetuning
logger.info("Starting finetuning...")
finetuned_model = run_finetuning(
    model=model,
    train_ds=train_ds,
    val_ds=val_ds,
    experiment_name="presto_wc_ft_crop",
    output_dir=output_dir,
    loss_fn=loss_fn,
    optimizer=optimizer,
    scheduler=scheduler,
    hyperparams=hyperparams,
    setup_logging=False,  # Already setup logging
)

ConnectTimeout: HTTPSConnectionPool(host='artifactory.vgt.vito.be', port=443): Max retries exceeded with url: /artifactory/auxdata-public/worldcereal/models/PhaseII/presto-ss-wc_longparquet_random-window-cut_no-time-token_epoch96_corrected-mask.pt (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x7f050a71e2f0>, 'Connection to artifactory.vgt.vito.be timed out. (connect timeout=None)'))
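
The finetuning cell above uses `ExponentialLR(optimizer, gamma=0.99)`, which multiplies the learning rate by `gamma` at every scheduler step, so after `t` steps the rate is `lr0 * gamma**t`. Assuming one scheduler step per epoch, the sketch below evaluates that formula. The starting rate is a placeholder, since the value of `hyperparams.lr` is not shown in this notebook.

```python
def exponential_lr(lr0, gamma, step):
    """Learning rate after `step` multiplicative decays by `gamma`."""
    return lr0 * gamma**step


lr0 = 1e-3  # placeholder starting rate, not the value used by Hyperparams
for epoch in (0, 10, 50, 100):
    print(epoch, exponential_lr(lr0, 0.99, epoch))
```

With `gamma=0.99` the rate falls to roughly 37% of its initial value after 100 steps, so the decay is gentle over the 100 epochs configured above.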