# Few-Shot Learning with Presto

### Notebook Overview 

1) Short introduction to Foundation Models and Presto
2) Definition of Few-Shot learning
3) Apply Presto to perform Few-Shot learning on a regression and a classification task

### 1) Foundation Models

A Foundation Model is a model trained on large and diverse unlabeled datasets to learn general patterns and features of the data. Thanks to its strong generalization capabilities, such a model can be adapted for a wide range of applications that use similar types of input data.

**Presto** (**P**retrained **Re**mote **S**ensing **T**ransf**o**rmer) is a foundation model trained on a large, unlabeled dataset of Sentinel-2, Sentinel-1, meteorological and topography pixel time series. It captures long-range relationships across the time and sensor dimensions, improving the signal-to-noise ratio and providing a concise, informative representation of the inputs.
In this project, we use the Presto version developed in collaboration with [WorldCereal](https://github.com/WorldCereal/presto-worldcereal/).

Originally trained on monthly composites, Presto has been refined to ingest dekadal (10-day) composites and to be fine-tuned for regression and classification tasks.
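A dekad is a 10-day period: each month is split into days 1-10, 11-20 and 21 to the end of the month. One year of dekadal composites therefore has 36 timesteps, against 12 for monthly composites, which is why the timestep count must match the compositing window. The helpers below are a minimal illustration of that bookkeeping; `dekad_index` and `num_timesteps` are hypothetical names, not part of Presto or WorldCereal.

```python
from datetime import date


def dekad_index(d: date) -> int:
    """Return the 0-based dekad of the year (0..35) for a date.

    Days 1-10 fall in the first dekad of a month, days 11-20 in the second,
    and day 21 up to the end of the month in the third.
    """
    dekad_in_month = min((d.day - 1) // 10, 2)
    return (d.month - 1) * 3 + dekad_in_month


def num_timesteps(start: date, end: date, window: str = "dekad") -> int:
    """Count the composites covering [start, end] for a "dekad" or "month" window."""
    if window == "dekad":
        first = start.year * 36 + dekad_index(start)
        last = end.year * 36 + dekad_index(end)
    elif window == "month":
        first = start.year * 12 + start.month - 1
        last = end.year * 12 + end.month - 1
    else:
        raise ValueError(f"Unsupported composite window: {window!r}")
    return last - first + 1


year_start, year_end = date(2022, 1, 1), date(2022, 12, 31)
print(num_timesteps(year_start, year_end, "dekad"))  # 36
print(num_timesteps(year_start, year_end, "month"))  # 12
```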

### 2) Few-Shot Learning

Few-shot learning aims to build models that learn from only a small number of labeled examples yet still generalize well to new, unseen examples.
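In a classification setting, a k-shot training set is commonly built by keeping at most k labeled examples per class. A minimal pandas sketch, assuming a DataFrame with a label column; `sample_k_shot` and the toy crop labels are hypothetical.

```python
import pandas as pd


def sample_k_shot(frame: pd.DataFrame, label_col: str, k: int, seed: int = 42) -> pd.DataFrame:
    """Keep at most k randomly chosen rows per class in `label_col`."""
    # Shuffle once, then take the first k rows of each class;
    # classes with fewer than k examples are kept whole
    shuffled = frame.sample(frac=1.0, random_state=seed)
    return shuffled.groupby(label_col, sort=False).head(k)


# Toy example with hypothetical crop labels
toy_df = pd.DataFrame({
    "sample_id": range(7),
    "label": ["potato"] * 5 + ["wheat"] * 2,
})
few_shot = sample_k_shot(toy_df, "label", k=3)
print(few_shot["label"].value_counts().to_dict())  # {'potato': 3, 'wheat': 2}
```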

Given a dataset with only a few annotated examples, we can fine-tune a pretrained foundation model either to handle the downstream task directly or to generate compressed representations (embeddings) of the inputs, which are then used to train a machine learning model for the downstream task.
The figure below gives an overview of the latter scenario.

<div style="text-align: center;">
    <img src="../images/ScaleAG_pipeline_overview_presto_ml.jpg" alt="Overview of a Foundation Model used to produce embeddings which can be fed as training examples to downstream models for different tasks and applications." width="700" />
    <p><em>Overview of a Foundation Model used to produce embeddings which can be fed as training examples to downstream models for different tasks and applications.</em></p>
</div>
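The embedding-based pipeline in the figure can be illustrated with a toy, self-contained sketch. The assumptions are deliberate and loud: a fixed random projection stands in for the frozen Presto encoder, the time series and target are synthetic, and closed-form ridge regression replaces the CatBoost-based downstream model for which this notebook imports helpers. None of the names below come from Presto or prometheo.

```python
import numpy as np

rng = np.random.default_rng(0)
NUM_TIMESTEPS, NUM_BANDS, EMB_DIM = 36, 10, 128

# Hypothetical frozen "encoder": a fixed random projection followed by tanh.
# Presto's real encoder is a transformer; this only mimics its interface.
W_ENC = rng.normal(size=(NUM_TIMESTEPS * NUM_BANDS, EMB_DIM)) / np.sqrt(NUM_TIMESTEPS * NUM_BANDS)


def encode(x: np.ndarray) -> np.ndarray:
    """Map (n_samples, timesteps, bands) inputs to (n_samples, EMB_DIM) embeddings."""
    return np.tanh(x.reshape(len(x), -1) @ W_ENC)


def fit_ridge(z: np.ndarray, y: np.ndarray, alpha: float = 1.0) -> np.ndarray:
    """Closed-form ridge regression with an unpenalized bias term."""
    z1 = np.hstack([z, np.ones((len(z), 1))])
    penalty = alpha * np.eye(z1.shape[1])
    penalty[-1, -1] = 0.0  # do not shrink the bias
    return np.linalg.solve(z1.T @ z1 + penalty, z1.T @ y)


def predict(z: np.ndarray, weights: np.ndarray) -> np.ndarray:
    """Apply the ridge weights (last entry is the bias) to embeddings."""
    return np.hstack([z, np.ones((len(z), 1))]) @ weights


# Few labeled examples: synthetic time series and a synthetic target
x_few = rng.normal(size=(50, NUM_TIMESTEPS, NUM_BANDS))
y_few = x_few[:, :, 0].mean(axis=1)

z_few = encode(x_few)                   # 1) embed with the frozen encoder
ridge_weights = fit_ridge(z_few, y_few)  # 2) train a light model on the embeddings
preds = predict(z_few, ridge_weights)
print(z_few.shape, preds.shape)  # (50, 128) (50,)
```

Keeping the encoder frozen means only the small downstream model is fitted on the few labels, which is the point of this scenario.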

### 3) Implementing Few-Shot learning with Presto

```python
# %load_ext autoreload
# %autoreload 2
import sys
from pathlib import Path

import catboost as cb
import pandas as pd
from catboost import Pool
from torch import nn
from torch.utils.data import DataLoader

# Temporary workaround: make the local prometheo checkout importable
# until the package installation is fixed
sys.path.append("/home/vito/millig/gio/prometheo/")
from prometheo import finetune
from prometheo.datasets.scaleag import ScaleAgDataset
from prometheo.finetune import Hyperparams
from prometheo.models import Presto

from scaleagdata_vito.openeo.extract_sample_scaleag import generate_extraction_job_command
from scaleagdata_vito.presto.presto_df import load_dataset
from scaleagdata_vito.presto.presto_utils import (
    evaluate,
    evaluate_catboost,
    load_pretrained_model_from_url,
    train_catboost_on_encodings,
)
```



#### Fetch data from openEO

To set up the job, we adapt the job parameters to our needs. The user must specify the following fields to generate the terminal command that starts the extraction:

```python
job_params = dict(
    output_folder=..., 
    input_df=...,
    start_date=...,
    end_date=...,
    unique_id_column=...,
    composite_window=..., # "month" or "dekad" are supported. Default is "dekad"
)

```
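Before generating the command, it can help to check the parameters locally. The sketch below is hypothetical (it is not part of `scaleagdata_vito`): it verifies that the required fields are filled in, applies the documented `"dekad"` default for `composite_window`, accepts only `"month"` or `"dekad"`, and assumes the dates are given as ISO strings (`YYYY-MM-DD`), as in the example further down.

```python
from datetime import date

REQUIRED_KEYS = ("output_folder", "input_df", "start_date", "end_date", "unique_id_column")
SUPPORTED_WINDOWS = {"month", "dekad"}


def _is_unset(value) -> bool:
    """Treat None or a leftover `...` placeholder as "not filled in"."""
    return value is None or value is Ellipsis


def validate_job_params(params: dict) -> dict:
    """Return a checked copy of `params`; raise ValueError if something is missing or invalid."""
    missing = [key for key in REQUIRED_KEYS if _is_unset(params.get(key))]
    if missing:
        raise ValueError(f"Missing job parameters: {missing}")

    checked = dict(params)
    if _is_unset(checked.get("composite_window")):
        checked["composite_window"] = "dekad"  # documented default
    if checked["composite_window"] not in SUPPORTED_WINDOWS:
        raise ValueError(f"composite_window must be one of {sorted(SUPPORTED_WINDOWS)}")

    start = date.fromisoformat(checked["start_date"])
    end = date.fromisoformat(checked["end_date"])
    if start >= end:
        raise ValueError("start_date must be earlier than end_date")
    return checked
```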

```python
job_params = dict(
    output_folder="/projects/TAP/HEScaleAgData/data/AVR_subfields/extractions_16012025/",
    input_df="/home/vito/millig/gio/data/scaleag_extractions/AVR_fields_10000_100000_subfields_yield_bel_nl_roads_removed.geojson",
    start_date="2022-01-01",
    end_date="2022-12-31",
    unique_id_column="fieldname",
    composite_window="month",
)
generate_extraction_job_command(job_params)
```

python scaleag-vito/scripts/extractions/extract.py -output_folder /projects/TAP/HEScaleAgData/data/AVR_subfields/extractions_16012025/ -input_df /home/vito/millig/gio/data/scaleag_extractions/AVR_fields_10000_100000_subfields_yield_bel_nl_roads_removed.geojson --start_date 2022-01-01 --end_date 2022-12-31 --unique_id_column fieldname --composite_window month


#### Regression task: yield estimation 

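The regression task is potato yield estimation. The data cover fields in Belgium and the Netherlands during the 2022 growing season.
To test the generalization capabilities of the different models and combinations, the intended setup limits the correlation between training and validation data by using fields from Belgium as the training set and fields from the Netherlands as the validation set. For this quick test, however, the split below is random: a 10% subsample is drawn and divided 80/20 into training and validation sets.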

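```python
# Load the extracted dataset.
# NOTE: 36 timesteps correspond to one year of dekadal composites;
# one year of monthly composites (composite_window="month" above) has 12.
df = load_dataset(job_params["output_folder"], num_timesteps=36)
```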

100%|██████████| 284/284 [00:20<00:00, 13.79it/s]


```python
# Split into training and validation sets: draw a random 10% subsample,
# then use 80% of it for training and the remaining 20% for validation
df_sample = df.sample(frac=0.1, random_state=42)
df_train = df_sample.sample(frac=0.8, random_state=42)
df_val = df_sample[~df_sample.sample_id.isin(df_train.sample_id)]

print(f"Train size: {len(df_train)}")
print(f"Val size: {len(df_val)}")
```

Train size: 1109
Val size: 277
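The cross-country split described at the start of this task (train on Belgium, validate on the Netherlands) could be sketched as follows. This assumes the extracted DataFrame has a column identifying each sample's country; the column name `country`, its values `"BE"` and `"NL"`, and the helper `split_by_country` are hypothetical and would need to be adapted to the actual data.

```python
import pandas as pd


def split_by_country(
    frame: pd.DataFrame,
    country_col: str = "country",  # hypothetical column name
    train_country: str = "BE",
    val_country: str = "NL",
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Use all samples from one country for training and another for validation."""
    train = frame[frame[country_col] == train_country]
    val = frame[frame[country_col] == val_country]
    return train, val


# Toy example: no sample can end up in both sets
toy = pd.DataFrame({"sample_id": ["a", "b", "c", "d"], "country": ["BE", "BE", "NL", "NL"]})
toy_train, toy_val = split_by_country(toy)
print(len(toy_train), len(toy_val))  # 2 2
```

Splitting by region rather than at random keeps neighbouring, highly similar fields from appearing on both sides of the split.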


```python
# Initialize the training and validation datasets with the same settings
dataset_kwargs = dict(
    num_timesteps=36,
    task_type="regression",
    target_name="median_yield",
    upper_bound=120000,
    lower_bound=10000,
)
train_ds = ScaleAgDataset(df_train, **dataset_kwargs)
val_ds = ScaleAgDataset(df_val, **dataset_kwargs)
```
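`ScaleAgDataset` receives `upper_bound` and `lower_bound` for the yield target. How prometheo uses them is not shown here; a common use of such bounds, given purely as an assumption, is to clip the target and min-max scale it to [0, 1] for training, then map predictions back to the original units. A NumPy sketch of that idea (`scale_target` and `unscale_target` are hypothetical helpers):

```python
import numpy as np

LOWER_BOUND, UPPER_BOUND = 10_000, 120_000  # the bounds passed to ScaleAgDataset


def scale_target(y, lower=LOWER_BOUND, upper=UPPER_BOUND):
    """Clip y to [lower, upper] and min-max scale it to [0, 1]."""
    y = np.clip(np.asarray(y, dtype=float), lower, upper)
    return (y - lower) / (upper - lower)


def unscale_target(y_scaled, lower=LOWER_BOUND, upper=UPPER_BOUND):
    """Map values in [0, 1] back to the original target units."""
    return np.asarray(y_scaled, dtype=float) * (upper - lower) + lower


print(scale_target([5_000, 65_000, 120_000]))  # [0.  0.5 1. ]
print(unscale_target(0.5))                     # 65000.0
```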

#### Fine-tuning
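The `patience` hyperparameter in the cell below controls early stopping: training halts once the validation loss has stopped improving for `patience` consecutive epochs. The class below is a generic sketch of that rule, not prometheo's implementation; the exact counting used by `finetune.run_finetuning` may differ.

```python
class EarlyStopping:
    """Generic patience-based early stopping on a validation loss."""

    def __init__(self, patience: int = 1, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta  # minimum decrease that counts as an improvement
        self.best_loss = float("inf")
        self.epochs_without_improvement = 0

    def step(self, val_loss: float) -> bool:
        """Record one epoch's validation loss; return True if training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        return self.epochs_without_improvement >= self.patience


stopper = EarlyStopping(patience=1)
for epoch, loss in enumerate([0.50, 0.40, 0.45, 0.30]):
    if stopper.step(loss):
        print(f"Stopping after epoch {epoch}, best loss {stopper.best_loss}")
        break
```

With `patience=1`, a single epoch without improvement is enough to stop, which suits a quick test but can end training early on a noisy validation curve.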

```python
# Construct Presto with a fine-tuning head
model = Presto(num_outputs=train_ds.num_outputs, regression=True)

# Reduced number of epochs for testing purposes
hyperparams = Hyperparams(max_epochs=10, batch_size=256, patience=1, num_workers=2)
output_dir = Path("/home/vito/millig/gio/presto_exp/prometheo_exp")

# Select the loss function according to the task type
if train_ds.task_type == "regression":
    loss_fn = nn.MSELoss()
elif train_ds.task_type == "binary":
    loss_fn = nn.BCEWithLogitsLoss()
else:  # classification with more than two classes
    loss_fn = nn.CrossEntropyLoss()

finetune.run_finetuning(
    model,
    train_ds,
    val_ds,
    experiment_name="test",
    output_dir=output_dir,
    loss_fn=loss_fn,
    hyperparams=hyperparams,
)
```

22-01-2025 19:44:29 - INFO - Initialized logging to /home/vito/millig/gio/presto_exp/prometheo_exp/logs/console-output.log
22-01-2025 19:44:29 - INFO - Using output dir: /home/vito/millig/gio/presto_exp/prometheo_exp
22-01-2025 19:46:21 - INFO - Early stopping!
Train metric: 0.046, Val metric: 0.010, Best Val Loss: 0.010 (no improvement for 0 epochs):  20%|██        | 2/10 [01:52<07:29, 56.18s/it]

22-01-2025 19:46:22 - INFO - Finetuning done


PretrainedPrestoWrapper(
  (encoder): Encoder(
    (eo_patch_embed): ModuleDict(
      (S1): Linear(in_features=2, out_features=128, bias=True)
      (S2_RGB): Linear(in_features=3, out_features=128, bias=True)
      (S2_Red_Edge): Linear(in_features=3, out_features=128, bias=True)
      (S2_NIR_10m): Linear(in_features=1, out_features=128, bias=True)
      (S2_NIR_20m): Linear(in_features=1, out_features=128, bias=True)
      (S2_SWIR): Linear(in_features=2, out_features=128, bias=True)
      (ERA5): Linear(in_features=2, out_features=128, bias=True)
      (SRTM): Linear(in_features=2, out_features=128, bias=True)
      (NDVI): Linear(in_features=1, out_features=128, bias=True)
    )
    (dw_embed): Embedding(10, 128)
    (latlon_embed): Linear(in_features=3, out_features=128, bias=True)
    (blocks): ModuleList(
      (0-1): 2 x Block(
        (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=128, out_feature