# Few-Shot Learning with Presto

### Notebook Overview 

1) Short introduction to Foundation Models and Presto
2) Definition of Few-Shot learning
3) Apply Presto to perform Few-Shot learning on a regression and a classification task

### 1) Foundation Models

A Foundation Model is a model trained on large and diverse unlabeled datasets to learn general patterns and features of the data. Thanks to its strong generalization capabilities, such a model can be adapted for a wide range of applications that use similar types of input data.

**Presto** (**P**retrained **Re**mote **S**ensing **T**ransf**o**rmer) is a foundation model trained on a large, unlabeled dataset of Sentinel-2, Sentinel-1, meteorological, and topography pixel time-series data. It captures long-range relationships across the time and sensor dimensions, which improves the signal-to-noise ratio and yields a concise, informative representation of the inputs.
In this project, we use the Presto version developed in collaboration with [WorldCereal](https://github.com/WorldCereal/presto-worldcereal/).

Originally trained on monthly composites, Presto has since been adapted to ingest dekadal (10-day) composites and to be fine-tuned for regression and classification tasks.
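The practical difference between the two compositing schemes is the number of time steps per year. The sketch below assumes the common convention of three dekads per calendar month (days 1-10, 11-20, and 21 to month end), which gives 36 time steps per year against 12 for monthly composites. The helper names are illustrative, and the exact compositing rules of the extraction pipeline are not shown in this notebook.

```python
from datetime import date


def dekad_of_year(day: date) -> int:
    """Return the zero-based dekad index (0..35) of a date within its year.

    Assumes three dekads per month: days 1-10, 11-20, and 21 to month end.
    """
    # Days 21-31 all fall in the third dekad of the month.
    dekad_in_month = min((day.day - 1) // 10, 2)
    return (day.month - 1) * 3 + dekad_in_month


def timesteps_per_year(composite_window: str) -> int:
    """Number of composites in one calendar year for the given window."""
    return {"month": 12, "dekad": 36}[composite_window]


print(dekad_of_year(date(2022, 1, 1)))    # first dekad of the year
print(dekad_of_year(date(2022, 12, 31)))  # last dekad of the year
print(timesteps_per_year("dekad"))
```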

### 2) Few-Shot Learning

Few-shot learning aims to develop models that can learn from a small number of labeled instances while still generalizing well to new, unseen examples.

Given a dataset with only a few annotated examples, we can fine-tune a pretrained foundation model to either directly handle the downstream task or generate compressed representations of the inputs, which can then be used to train a machine learning model for the downstream task.
The figure below gives an overview of the latter scenario.

<div style="text-align: center;">
    <img src="../images/ScaleAG_pipeline_overview_presto_ml.jpg" alt="Overview of a Foundation Model used to produce embeddings which can be fed as training examples to downstream models for different tasks and applications." width="700" />
    <p><em>Overview of a Foundation Model used to produce embeddings which can be fed as training examples to downstream models for different tasks and applications.</em></p>
</div>
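The embedding route shown in the figure can be illustrated with a deliberately tiny, dependency-free sketch. Everything in it is a stand-in: `toy_encoder` plays the role of a frozen pretrained encoder (it is not Presto), and a nearest-centroid classifier plays the role of the downstream model, fitted on only three labeled examples per class. The series and class names are synthetic.

```python
from statistics import mean


def toy_encoder(series):
    """Hypothetical frozen encoder: maps a time series to a 2-D embedding."""
    return (mean(series), max(series) - min(series))


def fit_centroids(embeddings, labels):
    """Downstream 'training': average the embeddings of each class."""
    centroids = {}
    for label in set(labels):
        members = [emb for emb, lab in zip(embeddings, labels) if lab == label]
        centroids[label] = tuple(mean(dim) for dim in zip(*members))
    return centroids


def predict(centroids, embedding):
    """Assign the class whose centroid is closest (squared Euclidean distance)."""
    def sq_dist(a, b):
        return sum((x - y) ** 2 for x, y in zip(a, b))

    return min(centroids, key=lambda label: sq_dist(centroids[label], embedding))


# Few labeled examples: three synthetic series per class.
support = [
    ([0.1, 0.2, 0.1, 0.2], "bare"),
    ([0.2, 0.1, 0.2, 0.2], "bare"),
    ([0.1, 0.1, 0.2, 0.1], "bare"),
    ([0.2, 0.6, 0.9, 0.4], "crop"),
    ([0.3, 0.7, 0.8, 0.3], "crop"),
    ([0.2, 0.5, 0.9, 0.5], "crop"),
]
embeddings = [toy_encoder(series) for series, _ in support]
labels = [label for _, label in support]
centroids = fit_centroids(embeddings, labels)

# Classify an unseen series using only its embedding.
prediction = predict(centroids, toy_encoder([0.25, 0.65, 0.85, 0.35]))
print(prediction)
```

In the notebook itself, Presto produces the embeddings and CatBoost takes the place of the nearest-centroid step.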

### 3) Implementing Few-Shot learning with Presto

In [20]:
%load_ext autoreload
%autoreload 2
import catboost as cb
from loguru import logger
from pathlib import Path
import sys
sys.path.append("/home/vito/millig/gio/prometheo/")
from prometheo.datasets.scaleag import ScaleAgDataset # fix installation
from prometheo import finetune
from prometheo.finetune import Hyperparams
from prometheo.models.presto.wrapper import PretrainedPrestoWrapper, load_pretrained
import torch
from torch import nn
from torch.utils.data import DataLoader
from scaleagdata_vito.openeo.extract_sample_scaleag import generate_extraction_job_command
from scaleagdata_vito.presto.finetuned_eval import evaluate_finetuned_model, evaluate_downstream_model, get_encodings
from scaleagdata_vito.presto.presto_df import load_dataset



#### Fetch data from openEO

To set up the extraction job, adapt the job parameters to your needs. Fill in the following fields to generate the command that starts the extraction when run in a terminal:

```python
job_params = dict(
    output_folder=..., 
    input_df=...,
    start_date=...,
    end_date=...,
    unique_id_column=...,
    composite_window=..., # "month" or "dekad" are supported. Default is "dekad"
)

```

In [33]:
job_params = dict(
    output_folder="/projects/TAP/HEScaleAgData/data/AVR_subfields/extractions_16012025/",
    input_df="/home/vito/millig/gio/data/scaleag_extractions/AVR_fields_10000_100000_subfields_yield_bel_nl_roads_removed.geojson",
    start_date="2022-01-01",
    end_date="2022-12-31",
    unique_id_column="fieldname",
    composite_window="month",
)
generate_extraction_job_command(job_params)

python scaleag-vito/scripts/extractions/extract.py -output_folder /projects/TAP/HEScaleAgData/data/AVR_subfields/extractions_16012025/ -input_df /home/vito/millig/gio/data/scaleag_extractions/AVR_fields_10000_100000_subfields_yield_bel_nl_roads_removed.geojson --start_date 2022-01-01 --end_date 2022-12-31 --unique_id_column fieldname --composite_window month


#### Regression task: yield estimation 

The task is potato yield estimation. The data cover fields in Belgium and the Netherlands during the growing season.
To test how well the different models and combinations generalize, we limit the correlation between the two sets by using data from Belgium as the training set and data from the Netherlands as the validation set.

In [29]:
# load extracted dataset 
df = load_dataset(job_params["output_folder"], num_timesteps=36)

100%|██████████| 293/293 [00:16<00:00, 17.44it/s]


In [6]:
# split in train and val
df_sample = df.sample(frac=0.1, random_state=42)
df_train = df_sample.sample(frac=0.8, random_state=42)
df_val = df_sample[~df_sample.sample_id.isin(df_train.sample_id)]

print(f"Train size: {len(df_train)}")
print(f"Val size: {len(df_val)}")

Train size: 1109
Val size: 277
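The cell above draws a random 80/20 split from a 10% subsample. A split by country, as described at the start of this section, could be sketched as follows. The `country` column and its `"BE"`/`"NL"` codes are assumptions made for illustration (the extracted dataframe may encode location differently), and the toy dataframe with made-up values stands in for `df`.

```python
import pandas as pd

# Toy stand-in for the extracted dataframe; `country` is a hypothetical column.
toy_df = pd.DataFrame(
    {
        "sample_id": ["a", "b", "c", "d", "e"],
        "country": ["BE", "BE", "NL", "BE", "NL"],
        "median_yield": [45000, 52000, 48000, 39000, 61000],
    }
)

# Train on Belgian fields, validate on Dutch fields.
toy_train = toy_df[toy_df["country"] == "BE"]
toy_val = toy_df[toy_df["country"] == "NL"]

# The two sets should not share any sample.
assert not toy_train["sample_id"].isin(toy_val["sample_id"]).any()

print(f"Train size: {len(toy_train)}")
print(f"Val size: {len(toy_val)}")
```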


In [7]:
# initialize datasets 
train_ds = ScaleAgDataset(
    df_train,
    num_timesteps=36,
    task_type="regression",
    target_name="median_yield",
    upper_bound=120000,
    lower_bound=10000,
)
val_ds = ScaleAgDataset(
    df_val,
    num_timesteps=36,
    task_type="regression",
    target_name="median_yield",
    upper_bound=120000,
    lower_bound=10000,
)

2025-01-27 14:31:02.791 | INFO     | prometheo.datasets.scaleag:set_num_outputs:133 - Setting number of outputs to 1 for regression task.
2025-01-27 14:31:02.827 | INFO     | prometheo.datasets.scaleag:set_num_outputs:133 - Setting number of outputs to 1 for regression task.

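The `lower_bound` and `upper_bound` arguments suggest that the target is rescaled before training. How `ScaleAgDataset` uses them internally is not shown in this notebook, so the following is only one plausible reading: a min-max scaling to the [0, 1] range with clipping, together with its inverse for mapping predictions back to yield units. The function names are illustrative.

```python
LOWER_BOUND = 10_000
UPPER_BOUND = 120_000


def normalize_target(value, lower=LOWER_BOUND, upper=UPPER_BOUND):
    """Clip to [lower, upper], then scale linearly to [0, 1]."""
    clipped = min(max(value, lower), upper)
    return (clipped - lower) / (upper - lower)


def denormalize_target(scaled, lower=LOWER_BOUND, upper=UPPER_BOUND):
    """Map a value in [0, 1] back to the original units."""
    return scaled * (upper - lower) + lower


scaled = normalize_target(65_000)
print(scaled)
print(denormalize_target(scaled))
```

If the targets are indeed scaled this way, loss values reported during fine-tuning are in normalized units and are not directly comparable to an RMSE expressed in yield units.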

#### Finetuning

In [7]:
# Construct the model with finetuning head
pretrained_model_path = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/scaleagdata/models/presto-ss-wc_30D.pt"
model = PretrainedPrestoWrapper(
    num_outputs=1,
    regression=True,
    pretrained_model_path=pretrained_model_path,
)

# Reduce epochs for testing purposes
hyperparams = Hyperparams(max_epochs=10, batch_size=256, patience=1, num_workers=2)
output_dir = Path("/home/vito/millig/gio/presto_exp/prometheo_exp")

# set loss depending on the task type
if train_ds.task_type == "regression":
    loss_fn = nn.MSELoss()
elif train_ds.task_type == "binary":
    loss_fn = nn.BCEWithLogitsLoss()
else:
    loss_fn = nn.CrossEntropyLoss()

finetuned_model = finetune.run_finetuning(
    model,
    train_ds,
    val_ds,
    experiment_name="presto-ss-wc-30D-ft-dek",
    output_dir=output_dir,
    loss_fn=loss_fn,
    hyperparams=hyperparams,
)

2025-01-27 10:48:12 | INFO     | prometheo.utils - Logging setup complete. Logging to: /home/vito/millig/gio/presto_exp/prometheo_exp/logs/test30d.log and console.
2025-01-27 10:48:12 | INFO     | prometheo.finetune - Using output dir: /home/vito/millig/gio/presto_exp/prometheo_exp

Finetuning:   0%|          | 0/10 [00:00<?, ?it/s]
Train metric: 0.253, Val metric: 0.146, Best Val Loss: 0.146 (improved):  10%|█         | 1/10 [00:33<04:59, 33.27s/it]
Train metric: 0.108, Val metric: 0.049, Best Val Loss: 0.049 (improved):  20%|██        | 2/10 [01:08<04:34, 34.33s/it]
Train metric: 0.045, Val metric: 0.024, Best Val Loss: 0.024 (improved):  30%|███       | 3/10 [01:42<03:59, 34.20s/it]
2025-01-27 10:50:32 | INFO     | prometheo.finetune - Early stopping!
Train metric: 0.045, Val metric: 0.024, Best Val Loss: 0.024 (improved):  30%|███       | 3/10 [02:19<05:26, 46.58s/it]
2025-01-27 10:50:32 | INFO     | prometheo.finetune - Finetuning done

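The run above stops early, well before the 10 epochs allowed by `max_epochs`. The general mechanism behind a `patience` setting can be sketched as below. This is a generic illustration, not prometheo's implementation, and the validation losses are illustrative values rather than numbers from this run.

```python
def epochs_until_stop(val_losses, patience):
    """Return how many epochs run before early stopping triggers.

    Training stops once `patience` consecutive epochs fail to improve on
    the best validation loss seen so far.
    """
    best_loss = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best_loss:
            best_loss = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return len(val_losses)


# Illustrative validation losses: three improvements, a worse epoch, then a new best.
illustrative_losses = [0.30, 0.20, 0.10, 0.15, 0.05]
print(epochs_until_stop(illustrative_losses, patience=1))
print(epochs_until_stop(illustrative_losses, patience=2))
```

A small patience keeps test runs short, at the cost of stopping on the first noisy epoch; a larger value gives the optimizer room to recover.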

#### Evaluate using end-to-end finetuned Presto

In [14]:
finetuned_model = PretrainedPrestoWrapper(num_outputs=1, regression=True)
finetuned_model = load_pretrained(finetuned_model, "/home/vito/millig/gio/presto_exp/prometheo_exp/test30d.pt")

evaluate_finetuned_model(finetuned_model, val_ds, num_workers=2, batch_size=32)

2025-01-27 14:45:16.250 | INFO     | scaleagdata_vito.presto.finetuned_eval:evaluate_finetuned_model:90 - Evaluating the finetuned model on regression task


{'RMSE': 15456.115880524178,
 'R2_score': -4.488653519352723,
 'explained_var_score': -4.485200183298538,
 'MAPE': 1.0393098899097182}
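For reference, RMSE, R2, and MAPE have standard definitions, sketched below in plain Python. The evaluation helper's actual implementation is not shown in this notebook and may differ in detail. R2 compares the model against a baseline that always predicts the mean of the targets, so a negative R2 means the predictions have a larger squared error than that baseline. The toy values are made up for illustration.

```python
from math import sqrt


def rmse(y_true, y_pred):
    """Root mean squared error, in the units of the target."""
    return sqrt(sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true))


def r2_score(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean_true = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)
    return 1 - ss_res / ss_tot


def mape(y_true, y_pred):
    """Mean absolute percentage error, as a fraction (0.1 means 10%)."""
    return sum(abs(t - p) / abs(t) for t, p in zip(y_true, y_pred)) / len(y_true)


y_true = [40000, 50000, 60000]
mean_baseline = [50000, 50000, 50000]   # always predicts the mean of y_true
reversed_pred = [60000, 50000, 40000]   # anti-correlated with the targets

print(r2_score(y_true, mean_baseline))  # the baseline itself scores zero
print(r2_score(y_true, reversed_pred))  # worse than the baseline, so negative
print(rmse(y_true, reversed_pred))
print(mape(y_true, reversed_pred))
```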

#### Train downstream model on Presto encodings and evaluate

In [21]:
notebook_device = "GPU" if torch.cuda.is_available() else None
cbm = cb.CatBoostRegressor(
    random_state=3,
    task_type=notebook_device,
    logging_level="Silent",
    loss_function="RMSE",
)
logger.info("Computing Presto encodings")
train_dl = DataLoader(train_ds, batch_size=256, shuffle=True, num_workers=2)
train_encodings, train_targets = get_encodings(train_dl, finetuned_model)
logger.info("Fitting Catboost model on Presto encodings")
train_dataset = cb.Pool(train_encodings, train_targets)
cbm.fit(train_dataset)

2025-01-27 14:48:49.994 | INFO     | __main__:<module>:8 - Computing Presto encodings
2025-01-27 14:49:05.484 | INFO     | __main__:<module>:11 - Fitting Catboost model on Presto encodings


<catboost.core.CatBoostRegressor at 0x7fe770805540>

In [27]:
evaluate_downstream_model(finetuned_model, cbm, val_ds, num_workers=2, batch_size=32)

2025-01-27 14:50:58.408 | INFO     | scaleagdata_vito.presto.finetuned_eval:evaluate_downstream_model:129 - Evaluating the finetuned model on regression task


{'RMSE': 7222.96229491022,
 'R2_score': -0.1986593607634093,
 'explained_var_score': -0.19777567369644822,
 'MAPE': 0.07087597481094239}