# Few-Shot Learning with Presto

## Notebook Overview 

1) Short introduction on Foundation Models and Presto
2) Definition of Few-Shot learning
3) Apply Presto to perform Few-Shot learning on a regression and a classification task

### 1) Foundation Models

A Foundation Model is a model trained on large and diverse unlabeled datasets to learn general patterns and features of the data. Thanks to its strong generalization capabilities, such a model can be adapted for a wide range of applications that use similar types of input data.

**Presto** (**P**retrained **Re**mote **S**ensing **T**ransf**o**rmer) is a foundation model trained on a large, unlabeled dataset of Sentinel-2, Sentinel-1, meteorological and topography pixel time-series data. It is able to capture long-range relationships across time and sensor dimensions, improving the signal-to-noise ratio and providing a concise, informative representation of the inputs.
In this project, we use the [Presto](https://github.com/WorldCereal/prometheo.git) version developed in collaboration with WorldCereal.

Originally trained on monthly composites, Presto has been refined to be able to ingest dekadal data and to be fine-tuned for regression and classification tasks.

### 2) Few-Shot Learning

Few-shot learning aims to develop models that can learn from a small number of labeled instances while enhancing generalization and performance on new, unseen examples.

Given a dataset with only a few annotated examples, we can fine-tune a pretrained foundation model to either directly handle the downstream task or generate compressed representations of the inputs (embeddings), which can then be used to train a machine learning model for the downstream task. Here we showcase the former scenario, an overview of which is shown in the figure below.

<div style="text-align: center;">
    <img src="../images/ScaleAG_pipeline_overview_presto.jpg" alt="Overview of a Foundation Model fine tuned for different downstream tasks and applications." width="700" />
    <p><em>Overview of a Foundation Model fine tuned for different downstream tasks and applications.</em></p>
</div>

### 3) Implementing Few-Shot learning with Presto

In [None]:
%load_ext autoreload
%autoreload 2

from pathlib import Path
import geopandas as gpd
from scaleagdata_vito.presto.datasets_prometheo import ScaleAgDataset
from scaleagdata_vito.openeo.extract_sample_scaleag import generate_input_for_extractions, extract
from scaleagdata_vito.presto.utils import evaluate_finetuned_model
from scaleagdata_vito.presto.presto_df import load_dataset
from scaleagdata_vito.presto.utils import train_test_val_split, finetune_on_task, load_finetuned_model, get_pretrained_model_url, get_resources_dir
from scaleagdata_vito.presto.inference import PrestoPredictor, reshape_result, plot_results
from scaleagdata_vito.utils.map import ui_map
from scaleagdata_vito.utils.dateslider import date_slider
from scaleagdata_vito.openeo.extract_sample_scaleag import collect_inputs_for_inference

### Before we start...

**Check your data!** Investigate the validity of geometries, the uniqueness of sample IDs, the presence of outliers and so on before starting the extraction. Achieving good performance with a limited amount of data is a challenging task per se. Therefore, **the quality of your data will greatly impact your final results.**
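
Some of these checks are easy to script before launching any extraction. The sketch below uses only the Python standard library on plain lists. The helper names `find_duplicate_ids` and `iqr_outliers` are hypothetical (they are not part of `scaleagdata_vito`), and the 1.5 x IQR rule is just one common convention for flagging outliers, assumed here for illustration.

```python
from collections import Counter
from statistics import quantiles


def find_duplicate_ids(sample_ids):
    """Return the sorted list of IDs that appear more than once."""
    counts = Counter(sample_ids)
    return sorted(sid for sid, n in counts.items() if n > 1)


def iqr_outliers(values, k=1.5):
    """Return the values lying outside [Q1 - k*IQR, Q3 + k*IQR]."""
    q1, _, q3 = quantiles(values, n=4)
    iqr = q3 - q1
    low, high = q1 - k * iqr, q3 + k * iqr
    return [v for v in values if v < low or v > high]


# Illustrative values only (not taken from the demo dataset)
ids = ["f1", "f2", "f2", "f3"]
yields = [41.0, 43.5, 39.8, 42.2, 40.7, 120.0]

print(find_duplicate_ids(ids))  # IDs to fix before the extraction
print(iqr_outliers(yields))     # candidate outliers to inspect
```

Flagged values should be inspected rather than dropped blindly, since a large value can also be a legitimate measurement.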

Data requirements:
- Points or Polygons (will be aggregated into points)
- Lat-Lon (EPSG:4326)
- Format: parquet, GeoJSON, shapefile, GPKG

For each geometry:
- Date (if available)
- Unique ID
- Annotations

Good practice:

- Remove polygons close to borders (e.g. by applying a buffer) to ensure the data are contained in the field.
- If the annotations are accurate, point geometries should be preferred. However, especially in regression tasks (i.e., continuous output values) such as yield estimation, the target values might be noisy. In that case, we recommend subdividing the polygons into subfields of 20m x 20m (to cover more measurements) and computing the median yield for a smoother and more reliable target.
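
The subfield aggregation recommended above can be sketched as follows. This is a minimal illustration, assuming the yield measurements are available as `(x, y, value)` tuples in a projected CRS with metre units (not lat-lon). The helper `median_yield_per_cell` is hypothetical and not part of `scaleagdata_vito`.

```python
from collections import defaultdict
from math import floor
from statistics import median


def median_yield_per_cell(measurements, cell_size=20.0):
    """Group (x, y, value) points into square cells and return the median per cell.

    Coordinates are assumed to be in metres. The keys of the result are the
    integer (column, row) indices of each cell.
    """
    cells = defaultdict(list)
    for x, y, value in measurements:
        key = (floor(x / cell_size), floor(y / cell_size))
        cells[key].append(value)
    return {key: median(values) for key, values in cells.items()}


# Illustrative measurements only
points = [
    (3.0, 4.0, 40.0),
    (12.0, 18.0, 44.0),
    (19.0, 1.0, 60.0),  # noisy reading, damped by the median
    (25.0, 5.0, 38.0),
]
cell_medians = median_yield_per_cell(points)
print(cell_medians)
```

The median is used instead of the mean so that a single noisy reading inside a cell has limited influence on the target.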


#### Requirements for running the extractions
- An account in the [Copernicus Data Space Ecosystem (CDSE)](https://dataspace.copernicus.eu/). Sign-up is free and comes with a monthly allowance of 10000 credits.
- A dataset with valid geometries (Points or Polygons) in lat-lon projection.
- Preferably a dataset with unique IDs per sample 
- A labelled dataset. Not required for the extraction process, but for the following fine-tuning steps.

#### EO data extractions
In this first step, we extract the required EO time series from CDSE for each sample in your dataset, using OpenEO.
To run the job, the user should specify the following fields of the `job_params` dictionary:

```python
    job_params = dict(
        output_folder=..., # where to save the extracted dataset
        input_df=..., # input georeferenced dataset to run the extractions for 
        start_date=..., # string indicating from which date to extract data  
        end_date=..., # string indicating until which date to extract the data 
        unique_id_column=..., # name of the column in the input_df containing the unique ID of the samples  
        composite_window=..., # "month" or "dekad" are supported. Default is "dekad"
    )
```
In particular:
- If the `date` information associated with the label is provided, the `start_date` of the time series is automatically set to 9 months before the date, whereas the `end_date` is set to 9 months after. If `date` is not available, the user needs to set the desired `start_date` and `end_date` for the extractions manually.
- `composite_window` indicates the time-series granularity, which can be dekadal or monthly.
  - `dekad`: each time step in the extracted time series is a mean composite of the acquisitions within a 10-day window. Depending on the start and end dates, each month is covered by 3 time steps which, by default, correspond to the 1st, 11th and 21st of the month.
  - `month`: each time step in the extracted time series is a mean composite of the acquisitions within a 30-day window. Each month is covered by 1 time step which, by default, corresponds to the 1st of the month.
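
To make the two granularities concrete, the sketch below lists the time-step dates that fall within a given period. It only illustrates the default dates mentioned above (1st, 11th and 21st for dekads, 1st for months). The helper `composite_steps` is hypothetical, and how the actual extraction handles windows that are only partially covered by the period is not addressed here.

```python
from datetime import date


def composite_steps(start, end, composite_window="dekad"):
    """List the start dates of the compositing windows between start and end (inclusive)."""
    if composite_window == "dekad":
        days = (1, 11, 21)
    elif composite_window == "month":
        days = (1,)
    else:
        raise ValueError(f"Unsupported composite_window: {composite_window}")

    steps = []
    year, month = start.year, start.month
    while (year, month) <= (end.year, end.month):
        for day in days:
            step = date(year, month, day)
            if start <= step <= end:
                steps.append(step)
        # move on to the next calendar month
        year, month = (year + 1, 1) if month == 12 else (year, month + 1)
    return steps


dekads = composite_steps(date(2022, 1, 1), date(2022, 12, 31), "dekad")
months = composite_steps(date(2022, 1, 1), date(2022, 12, 31), "month")
print(len(dekads), len(months))
print(dekads[:4])
```

For one full calendar year this gives 36 dekadal steps or 12 monthly steps, which is the sequence length the model sees per sample.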

The following dekadal/monthly time series will be extracted for the indicated time range:

- Sentinel-2 L2A data (all bands)
- Sentinel-1 VH and VV
- Average air temperature and precipitation sum derived from AgERA5
- Slope and elevation from Copernicus DEM

Presto accepts 1D time series. Therefore, if Polygons are provided for the extractions, they are spatially aggregated into points located at the lat-lon position of the polygon centroid.
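
As an illustration of what "centroid" means here, the sketch below computes the planar centroid of a simple polygon with the shoelace formula. It is not the code used by the extraction pipeline. The function `polygon_centroid` is hypothetical, and treating lon/lat as planar coordinates is an approximation assumed to be acceptable for field-sized polygons.

```python
def polygon_centroid(vertices):
    """Planar centroid of a simple polygon given as [(x, y), ...].

    The closing vertex (a repeat of the first one) is optional.
    """
    area2 = 0.0  # twice the signed area
    cx = cy = 0.0
    n = len(vertices)
    for i in range(n):
        x0, y0 = vertices[i]
        x1, y1 = vertices[(i + 1) % n]
        cross = x0 * y1 - x1 * y0
        area2 += cross
        cx += (x0 + x1) * cross
        cy += (y0 + y1) * cross
    if area2 == 0:
        raise ValueError("Degenerate polygon with zero area")
    return cx / (3 * area2), cy / (3 * area2)


# Illustrative field outline as (lon, lat) pairs
field = [(4.50, 50.80), (4.52, 50.80), (4.52, 50.81), (4.50, 50.81)]
lon, lat = polygon_centroid(field)
print(round(lon, 4), round(lat, 4))
```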

### Regression task: potato yield estimation 

The data covers fields in Belgium during the growing season of 2022. Each field polygon was partitioned into subfields of 20m x 20m. The latter are partitioned into training, validation and test sets.

**NOTE:** This is a very small dummy dataset with randomized yield values. No meaningful results are expected from using such data for training.

##### Data Extraction

In [None]:
output_folder = Path("/home/giorgia/Private/data/scaleag/demo/regression")
input_df = get_resources_dir() / "dummy_yield.geojson"
start_date = "2022-01-01"
end_date = "2022-12-31"
unique_id_column = "fieldname"
composite_window = "dekad"

In [None]:
# check input data structure 
gpd.read_file(input_df) 

In [None]:
job_params = dict(
    output_folder=output_folder,
    input_df=input_df,
    unique_id_column=unique_id_column,
    composite_window=composite_window,
)
extract(generate_input_for_extractions(job_params))

Once the dataset has been extracted, it can be loaded with the `load_dataset` function by specifying the path where the `.parquet` files have been downloaded. The following manipulations of the dataset are also possible:

- `window_of_interest`: the user can specify a time window of interest out of the whole available time series. `start_date` and `end_date` should be provided as strings in a list.
- `use_valid_time`: the user might want to define the window of interest based on the `date` the label is associated with. If so, `required_min_timesteps` should also be provided.
- `buffer_window`: buffers the `start_date` and `end_date` by the specified number of time steps.

In the following cell, we load the extracted dataset for 1 year of data.

**NOTE:** This code currently assumes that we are dealing with 1 year of data falling in the same time period.

##### Presto datasets initialization

In [None]:
df = load_dataset(
    files_root_dir=output_folder,
    window_of_interest=[start_date, end_date],
    composite_window=composite_window,
)

The following step splits the data into train, validation and test sets. The split can be performed by uniform sampling or by group sampling. The former is usually more suitable for binary and multiclass classification tasks, to ensure the data distribution is represented in all 3 sets. The latter is meant for cases where we want to avoid autocorrelation, and therefore data leakage, between the training and validation/test sets.
In the case of yield estimation, for instance, we often have samples coming from the same field. We might therefore want to separate the data based on the field they belong to, in order to better test the generalization capabilities of the model.
Therefore:
- `uniform_sample_by`: pass the name of the column in the dataframe to perform uniform sampling on
- `group_sample_by`: pass the name of the column in the dataframe to perform the group sampling on

`sampling_frac` indicates the proportion of the whole dataset assigned to the training set. The remaining samples are divided equally between the validation and test sets.
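
The idea behind group sampling can be sketched with the standard library alone. The function `group_split` below is hypothetical and is not the implementation of `train_test_val_split`. In particular, it is an assumption of this sketch that the training fraction is applied to the number of groups rather than to the number of samples.

```python
import random
from collections import defaultdict


def group_split(samples, group_key, train_frac=0.8, seed=0):
    """Split samples into train/val/test so that each group lands in exactly one set."""
    groups = defaultdict(list)
    for sample in samples:
        groups[sample[group_key]].append(sample)

    # shuffle the group names reproducibly, then cut the list into three parts
    names = sorted(groups)
    random.Random(seed).shuffle(names)
    n_train = round(train_frac * len(names))
    n_val = (len(names) - n_train) // 2
    parts = (
        names[:n_train],
        names[n_train:n_train + n_val],
        names[n_train + n_val:],
    )
    return tuple([s for name in part for s in groups[name]] for part in parts)


# Illustrative data: 50 subfields belonging to 10 parent fields
samples = [{"id": i, "parentname": f"field_{i % 10}"} for i in range(50)]
train, val, test = group_split(samples, "parentname", train_frac=0.8)

train_fields = {s["parentname"] for s in train}
val_fields = {s["parentname"] for s in val}
test_fields = {s["parentname"] for s in test}
print(len(train), len(val), len(test))
print(train_fields & val_fields, train_fields & test_fields)  # both empty
```

Because no parent field appears in more than one set, the test score reflects performance on fields the model has never seen.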

In [None]:
df_train, df_val, df_test = train_test_val_split(df=df, group_sample_by="parentname", sampling_frac=0.8)

We now set up the parameters needed for initializing the Presto datasets for the specific task:
- `num_timesteps`: can be inferred from the maximum of `available_timesteps`
- `target_name`: name of the column containing the target data
- `upper_bound` and `lower_bound`: these should be set to the max and min of the target distribution, respectively. Therefore, it is important to get rid of potential outliers beforehand.

**NOTE:** upper and lower bounds are also used to normalize the targets during the training process. Therefore it is important to keep track of such values to convert the predictions to the original units in the inference step!  
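
The role of the bounds can be illustrated with a plain min-max scaling. This is a sketch under the assumption that targets are scaled linearly to [0, 1]; the exact formula used inside the library is not shown in this notebook. The two helper names are hypothetical.

```python
def normalize_target(value, lower_bound, upper_bound):
    """Scale a target value to [0, 1] given the bounds of the distribution."""
    return (value - lower_bound) / (upper_bound - lower_bound)


def denormalize_prediction(pred, lower_bound, upper_bound):
    """Map a normalized prediction back to the original units."""
    return pred * (upper_bound - lower_bound) + lower_bound


# Illustrative bounds and target value only
lower, upper = 0.0, 80.0
scaled = normalize_target(60.0, lower, upper)
restored = denormalize_prediction(scaled, lower, upper)
print(scaled, restored)
```

Using different bounds at inference time than during training would shift and rescale every prediction, which is why the values need to be stored alongside the model.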

In [None]:
# set the task parameters (inspect the target distribution beforehand to exclude outliers if needed)
num_timesteps = df.available_timesteps.max()
task_type = "regression"
target_name = "median_yield"
upper_bound = df[target_name].max() 
lower_bound = 0 

We initialize the training, validation and test dataset objects to be used for training Presto.

In [None]:
# initialize datasets
train_ds = ScaleAgDataset(
    df_train,
    num_timesteps=num_timesteps,
    task_type=task_type,
    target_name=target_name,
    composite_window=composite_window,
    upper_bound=upper_bound,
    lower_bound=lower_bound,
)
val_ds = ScaleAgDataset(
    df_val,
    num_timesteps=num_timesteps,
    task_type=task_type,
    target_name=target_name,
    composite_window=composite_window,
    upper_bound=upper_bound,
    lower_bound=lower_bound,
)
test_ds = ScaleAgDataset(
    df_test,
    num_timesteps=num_timesteps,
    task_type=task_type,
    target_name=target_name,
    composite_window=composite_window,
    upper_bound=upper_bound,
    lower_bound=lower_bound,
)

##### Presto Finetuning

In this section, Presto is fine-tuned in a supervised way for the target downstream task. First, we set up the following experiment parameters:

- `output_dir`: where to save the model
- `experiment_name`: the model name
- `pretrained_model_path`: pretrained Presto model to start the fine-tuning from. Can be a string indicating the path to the model or a URL

In [None]:
# set models hyperparameters
model_output_dir = Path("/home/giorgia/Private/data/scaleag/demo/regression/")
experiment_name = "presto-ss-wc-10D-ft-dek"

In [None]:
# Construct the model with finetuning head starting from the pretrained model
finetuned_model = finetune_on_task(
    train_ds=train_ds,
    val_ds=val_ds,
    pretrained_model_path=get_pretrained_model_url(composite_window=composite_window),
    output_dir=model_output_dir, 
    experiment_name=experiment_name,
    num_workers=0,
    )
evaluate_finetuned_model(finetuned_model, test_ds, num_workers=0, batch_size=32)

##### Inference using Fine-Tuned end-to-end Presto

In this section, we apply the fine-tuned model to generate a yield map on an unseen area.
We need to indicate the spatial and temporal extent. The 2 cells below offer a simple way for the user to provide this information and to extract once again from CDSE the EO time series required by Presto.
We also need to indicate the `output_dir` where the extracted datacube is saved, its `output_filename` and the `composite_window`, which must be the same as the one used for fine-tuning the model.

In [None]:
map = ui_map(area_limit=7)

In [None]:
# select 1 year of data
slider = date_slider()

In [None]:
output_dir = Path("/home/giorgia/Private/data/scaleag/demo/regression")
output_filename = "inference_area"
inference_file = output_dir / f"{output_filename}.nc"

In [None]:
collect_inputs_for_inference(
    spatial_extent=map.get_extent(),
    temporal_extent=slider.get_processing_period(),
    output_path=output_dir,
    output_filename=f"{output_filename}.nc",
    composite_window=composite_window,
)

Once the datacube has been extracted, we can perform the inference task using the finetuned model and visualize the predicted map. 

In [None]:
inference_file = get_resources_dir() / "inference_area_tevuren.nc"
mask_path = get_resources_dir() / "LPIS_flanders_potatoes_2022.tif"

In [None]:
finetuned_model = load_finetuned_model(model_output_dir / experiment_name, task_type=task_type)
presto_model = PrestoPredictor(
    model=finetuned_model,
    batch_size=50,
    task_type=task_type,
    composite_window=composite_window,
)

predictions = presto_model.predict(inference_file, upper_bound=upper_bound, lower_bound=lower_bound, mask_path=mask_path)
predictions_map = reshape_result(predictions, path_to_input_file=inference_file)

In [None]:
plot_results(prob_map=predictions_map, path_to_input_file=inference_file, task=task_type, ts_index=5)

### Binary task: crop/no-crop

Now we test few-shot learning on a binary task. We fine-tune Presto on data points sampled from Flanders in 2021. This time, the dataset is the result of monthly compositing, so we initialize the parameters for the dataset preparation accordingly.

In [None]:
output_folder = Path("/home/giorgia/Private/data/scaleag/demo/worldcereal/")
input_df = get_resources_dir() / "dummy_cropland.geojson"
task_type = "binary"
target_name = "LANDCOVER_LABEL"
composite_window = "month"
unique_id_column = "sample_id"

In [None]:
job_params = dict(
    output_folder=output_folder,
    input_df=input_df,
    unique_id_column=unique_id_column,
    composite_window=composite_window,
)
extract(generate_input_for_extractions(job_params))

##### Presto datasets initialization

In [None]:
df = load_dataset(
    files_root_dir=output_folder,
    composite_window=composite_window,
)
train_df, val_df, test_df = train_test_val_split(df=df, uniform_sample_by=target_name, sampling_frac=0.6)

In the context of a binary classification task, specifying the upper and lower bounds is no longer necessary. If we are dealing with multiclass labels, we can convert the problem into a binary classification task by providing the `positive_labels` argument to the `ScaleAgDataset` class. The list of labels passed to `positive_labels` indicates which subset of classes should be interpreted as the positive class (here "crop"). All the other labels are interpreted as the negative class (here "no-crop").
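
Conceptually, the conversion amounts to a membership test on each label. The sketch below illustrates the idea and is not the code inside `ScaleAgDataset`. The helper `to_binary_labels` is hypothetical, and the label codes other than the positive ones are made up for the example.

```python
def to_binary_labels(labels, positive_labels):
    """Map multiclass labels to 1 (positive class) or 0 (negative class)."""
    positives = set(positive_labels)
    return [1 if label in positives else 0 for label in labels]


# Illustrative label codes only
labels = [10, 20, 11, 30, 13, 50]
binary = to_binary_labels(labels, positive_labels=[10, 11, 12, 13])
print(binary)
```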

In [None]:
num_timesteps = df.available_timesteps.max()
positive_labels = [10, 11, 12, 13]

In [None]:
train_ds = ScaleAgDataset(
    train_df,
    num_timesteps=num_timesteps,
    task_type=task_type,
    target_name=target_name,
    positive_labels=positive_labels,
    composite_window=composite_window,
)

val_ds = ScaleAgDataset(
    val_df,
    num_timesteps=num_timesteps,
    task_type=task_type,
    target_name=target_name,
    positive_labels=positive_labels,
    composite_window=composite_window,
)

test_ds = ScaleAgDataset(
    test_df,
    num_timesteps=num_timesteps,
    task_type=task_type,
    target_name=target_name,
    positive_labels=positive_labels,
    composite_window=composite_window,
)

##### Presto Finetuning

In this section, Presto is fine-tuned in a supervised way for the target downstream task. Once again, we set up the experiment parameters.

In [None]:
experiment_name = "presto_wc_ft_crop"
model_output_dir = Path("/home/giorgia/Private/data/scaleag/demo/worldcereal/")

In [None]:
finetuned_model = finetune_on_task(
    train_ds=train_ds,
    val_ds=val_ds,
    pretrained_model_path=get_pretrained_model_url(composite_window=composite_window),
    output_dir=model_output_dir,
    experiment_name=experiment_name,
    num_workers=0,
)
evaluate_finetuned_model(finetuned_model, test_ds, num_workers=0, batch_size=32)

##### Inference using Fine-Tuned end-to-end Presto

We now apply the finetuned model to an unseen area to perform the classification task.

In [None]:
finetuned_model = load_finetuned_model(model_output_dir / experiment_name, task_type=task_type)
inference_file = get_resources_dir() / "worldcereal_preprocessed_inputs.nc"

In [None]:
presto_model = PrestoPredictor(
    model=finetuned_model,
    batch_size=50,
    task_type=task_type,
    composite_window=composite_window,
)

predictions = presto_model.predict(inference_file)
prob_map = reshape_result(predictions, path_to_input_file=inference_file)
pred_map = presto_model.get_predictions(prob_map, threshold=0.75)
plot_results(prob_map=prob_map, pred_map=pred_map, path_to_input_file=inference_file, task=task_type, ts_index=4)
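
The `threshold` passed above turns the probability map into a binary prediction map. A minimal sketch of that step is given below, assuming that pixels with a probability greater than or equal to the threshold are assigned to the positive class. Whether `get_predictions` uses `>=` or `>` is not shown in this notebook, and the helper `apply_threshold` is hypothetical.

```python
def apply_threshold(prob_map, threshold=0.75):
    """Convert a 2D list of probabilities into a 2D list of 0/1 predictions."""
    return [[1 if p >= threshold else 0 for p in row] for row in prob_map]


# Illustrative probabilities only
probabilities = [
    [0.90, 0.60],
    [0.75, 0.10],
]
binary_map = apply_threshold(probabilities, threshold=0.75)
print(binary_map)
```

Raising the threshold trades recall for precision: fewer pixels are labelled as crop, but with higher confidence.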