![](../resources/Custom_croptype_map.png)

### Content

- [Introduction](#Introduction)
- [How to run this notebook?](#How-to-run-this-notebook?)
- [Download test dataset](#Download-test-dataset)
- [Before you start](#Before-you-start)
- [1. Choose finetuned Presto model for your experiments](#1.-Choose-finetuned-Presto-model-for-your-experiments)
- [2. Train your model](#2.-Train-your-model)
- [3. Local inference](#3.-Local-Inference)

### Introduction

This notebook guides you through the process of training a custom crop type classification model for your area, season and crop types of interest.

For training the model, you can use a combination of:
- publicly available reference data harmonized by the WorldCereal consortium;
- your own private reference data.

<div class="alert alert-block alert-warning">
If you want to use private reference data to train your model, first complete our <a href='https://github.com/WorldCereal/worldcereal-classification/blob/main/notebooks/worldcereal_private_extractions_app.ipynb' target='_blank' rel='noopener'>private extractions workflow</a>.

For simplicity, this notebook only uses public reference data.
</div>

After training, your custom model is deployed to the cloud, where openEO can access it. This lets you apply the model to your area and season of interest and generate your custom crop type map.

Alternatively, you can run inference locally on a set of prepared patches.

### How to run this notebook?

#### Option 1: Run on Terrascope

You can use a preconfigured environment on [**Terrascope**](https://terrascope.be/en) to run the workflows in a Jupyter notebook environment. Just register as a new user on Terrascope or use one of the supported EGI eduGAIN login methods to get started.

Once you have a Terrascope account, you can run this notebook by clicking the button shown below.

<div class="alert alert-block alert-warning">When you click the button, you will be prompted with "Server Options".<br>
Make sure to select the "Worldcereal" image here. Did you choose "Terrascope" by accident?<br>
Then go to File > Hub Control Panel > Stop my server, and click the link below once again.</div>


<a href="https://notebooks.terrascope.be/hub/user-redirect/git-pull?repo=https%3A%2F%2Fgithub.com%2FWorldCereal%2Fworldcereal-classification&urlpath=lab%2Ftree%2Fworldcereal-classification%2Fnotebooks%2Ftrainings%2F20260212_Kenya_workshop_croptype.ipynb&branch=main"><img src="https://img.shields.io/badge/Run%20notebook%20on-Terrascope-brightgreen" alt="Run notebook" valign="middle"></a>


<div class="alert alert-block alert-warning">
<b>WARNING:</b> <br>
Every time you click the link above, the latest version of the notebook is fetched, which can conflict with changes you have made yourself.<br>
To avoid such conflicts, we recommend making a copy of the notebook and editing only your copy.
</div>


#### Option 2: Install Locally

If you prefer to install the package locally, you can create the WorldCereal environment using **Conda** or **pip**.

- **Conda**<br>
First clone the repository:
```bash
git clone https://github.com/WorldCereal/worldcereal-classification.git
cd worldcereal-classification
```
Next, create the Conda environment from the repository root:
`conda env create -f environment.yml`


- **Pip**<br>
`pip install "worldcereal-classification[train,notebooks] @ git+https://github.com/worldcereal/worldcereal-classification.git"`

### Download test dataset

Before starting the exercise below, we encourage you to try out the [RDM](https://rdm.esa-worldcereal.org) data upload procedure with a small test dataset.

Run the next cell to download this dataset to your `./download` folder.

In [1]:
from pathlib import Path
import urllib.request

target_path = Path("./download/zmb_north_test_dataset.gpkg")
target_path.parent.mkdir(parents=True, exist_ok=True)
# Download the test dataset only if it is not already present
if not target_path.exists():
    url = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/demo/zambia/zmb_north_test_dataset.gpkg"
    urllib.request.urlretrieve(url, target_path)
print(f"Test dataset available at {target_path.resolve()}")
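
If a download is interrupted, a partial file can remain at `target_path`, and the cell above will skip the download on the next run because it only checks that the file exists. A GeoPackage is an SQLite database, so one cheap, partial sanity check is to look for the 16-byte SQLite header at the start of the file. The helper `looks_like_geopackage` below is hypothetical and not part of the WorldCereal package; it only inspects the start of the file, so it will not catch every truncated download.

```python
from pathlib import Path

# Every SQLite 3 database (and therefore every GeoPackage) starts with this 16-byte header.
SQLITE_HEADER = b"SQLite format 3\x00"


def looks_like_geopackage(path: Path) -> bool:
    """Hypothetical check: does `path` exist and start with the SQLite 3 header?"""
    path = Path(path)
    if not path.is_file():
        return False
    with path.open("rb") as fh:
        return fh.read(len(SQLITE_HEADER)) == SQLITE_HEADER


# Usage: warn if the downloaded file is clearly not a database
gpkg = Path("./download/zmb_north_test_dataset.gpkg")
if gpkg.exists() and not looks_like_geopackage(gpkg):
    print(f"{gpkg} does not look like a GeoPackage; delete it and rerun the download cell.")
```

Deleting a file that fails this check and rerunning the download cell forces a fresh download.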

### Before you start

In order to run WorldCereal crop mapping jobs from this notebook, you need to create an account on the [Copernicus Data Space Ecosystem](https://dataspace.copernicus.eu/).<br>
This is free of charge and will grant you a number of free openEO processing credits to continue this demo.


#### CDSE authentication

Run the cell below to make sure you are connected to your CDSE account before starting the application.

In [None]:
from openeo_gfmap.backend import cdse_connection
cdse_connection()

#### Ensure proper access to functionality

Run the next cell to add the parent directory to the Python path, so that the helper modules used in this notebook (such as `notebook_utils`) can be imported.

In [None]:
# Add the parent directory to sys.path
import sys
sys.path.append('..')

### 1. Choose finetuned Presto model for your experiments

In the cell below, you can choose between our default Presto model, which was fine-tuned for global use, and a more specialised model fine-tuned for Kenya.

Simply set the `model_bundle` variable to either `"kenya"` or `"default"`.

We recommend using the Kenya model.

Execute the next cell before continuing!

In [None]:
from pathlib import Path
import urllib.request
from worldcereal.train.backbone import checkpoint_fingerprint
from worldcereal.openeo.parameters import DEFAULT_SEASONAL_MODEL_URL

model_bundle = "kenya"
# model_bundle = "default"

if model_bundle == "kenya":
    presto_model_path = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/demo/kenya/presto-prometheo-kenya-finetuned-noaugment-month-augment=False-balance=True-timeexplicit=True-masking=enabled-run=202602101037_encoder.pt"
    seasonal_model_zip_path = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/demo/kenya/presto-prometheo-kenya-finetuned-noaugment-month-augment=False-balance=True-timeexplicit=True-masking=enabled-run=202602101037.zip"
    # Download the kenya presto model to your machine to run training and local inference
    local_pt = Path("./local_inference/kenya_local_presto_encoder.pt")
    local_pt.parent.mkdir(parents=True, exist_ok=True)
    if not local_pt.exists():
        urllib.request.urlretrieve(presto_model_path, local_pt)
    # Explicitly compute the fingerprint of the downloaded checkpoint
    presto_fp = checkpoint_fingerprint(local_pt)
    
    presto_model_package = {"presto_remote_path": presto_model_path,
                            "presto_local_path": str(local_pt),
                            "presto_fingerprint": presto_fp,
                            "seasonal_model_path": seasonal_model_zip_path}

    print("Finetuned Presto model for Kenya selected.")
    
elif model_bundle == "default":
    seasonal_model_zip_path = DEFAULT_SEASONAL_MODEL_URL
    presto_model_package = None
    
    print("Globally finetuned Presto model selected.")
    
else:
    raise ValueError(f"Unknown model bundle: {model_bundle}. Supported options are 'kenya' and 'default'.")
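
`checkpoint_fingerprint()` is imported from `worldcereal.train.backbone`, and this notebook does not show how it computes the fingerprint. Purely to illustrate the idea of identifying a checkpoint by its content, the sketch below computes a SHA-256 digest of a file in fixed-size chunks. `file_sha256` is a hypothetical helper, and WorldCereal's fingerprint may well be computed differently.

```python
import hashlib
from pathlib import Path


def file_sha256(path: Path, chunk_size: int = 1 << 20) -> str:
    """Hypothetical helper: hex SHA-256 digest of a file, read in 1 MiB chunks."""
    digest = hashlib.sha256()
    with Path(path).open("rb") as fh:
        # Reading in chunks keeps memory use flat even for large checkpoints.
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


# Usage (assumes the Kenya checkpoint was downloaded above):
# print(file_sha256(Path("./local_inference/kenya_local_presto_encoder.pt")))
```

Two files with the same digest have identical bytes (barring hash collisions), so a content digest like this can tell you whether a local copy matches a reference file.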

### 2. Train your model

Launch the application in the next cell to train your crop type model for Kenya!

Make sure to choose the option "**Full workflow**" in the welcome screen of the application.<br>

Then complete steps 1 to 6 (stop before deploying your model in step 7).<br>

Brief outline of steps:
1. Retrieve existing reference data (make sure to select Kenya as AOI!)
2. Inspect the reference data
3. Choose your season of interest and align your reference data to that season
4. Compute Presto embeddings on your reference data
5. Select data and compile the list of classes you want to include in your model
6. Train your model
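
Step 3 aligns your reference data to a season of interest. The app's alignment logic is not shown here; as a simplified, hypothetical illustration of one ingredient, the sketch below tests whether an observation date falls inside a season given by inclusive month/day bounds, including seasons that wrap around the new year. `in_season` is not a WorldCereal function.

```python
from datetime import date


def in_season(obs: date, start_md: tuple[int, int], end_md: tuple[int, int]) -> bool:
    """Hypothetical check: is `obs` inside a season given as (month, day) bounds?

    Bounds are inclusive. If the start comes after the end in the calendar
    (e.g. November to March), the season is treated as wrapping the new year.
    """
    md = (obs.month, obs.day)
    if start_md <= end_md:
        return start_md <= md <= end_md
    # Season wraps the year boundary: inside if after the start OR before the end.
    return md >= start_md or md <= end_md


# Usage with illustrative bounds (1 March to 31 August)
print(in_season(date(2025, 5, 15), (3, 1), (8, 31)))
```

Comparing `(month, day)` tuples relies on Python's lexicographic tuple ordering, which matches calendar order within a year.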

Once finalized, **continue with the next cell in this notebook**.

Later, you can return to the app to run steps 7 to 9 (model deployment and inference in the cloud for your area of interest).

In [None]:
from notebook_utils.training_app import WorldCerealTrainingApp

app = WorldCerealTrainingApp().run(presto_model_package=presto_model_package)

### 3. Local Inference

As an alternative to the model deployment and cloud inference steps in the application above (steps 7 to 9), you can quickly run your newly trained model locally on a set of prepared patches across Kenya. For these patches, we have already fetched the required pre-processed time series inputs from CDSE using openEO workflows. For more information on this procedure, see the [worldcereal_preprocessed_inputs notebook](https://github.com/WorldCereal/worldcereal-classification/blob/main/notebooks/worldcereal_preprocessed_inputs.ipynb).

The next cells download one randomly selected patch, show an NDVI map for a random timestamp together with band statistics and sample time series, and then call `run_seasonal_inference()` with the seasonal backbone and your custom crop type head to produce a crop type raster for the chosen season.

This step is meant for fast quality checks and for tuning inference and post-processing settings; it does not replace the full openEO production run. The output is written as a GeoTIFF under `./local_inference/`, with a file name that includes the season ID, patch ID and head package name, so you can inspect it in your GIS tools or compare it with reference data.


**Download and visualize one patch**

In [None]:
import random
import urllib.request
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
from pyproj import CRS

from notebook_utils.preprocessed_inputs import get_band_statistics_netcdf, visualize_timeseries_netcdf

# Randomly select one of the 12 prepared patches ("01" to "12")
options = [f"{i:02d}" for i in range(1, 13)]
selected = random.choice(options)

patch_id = f"patch{selected}"
print(f"Selected patch: {patch_id}")

# Download the demo patch from Artifactory if it is not available locally yet
local_file_path = Path(f"./local_inference/input_patches/preprocessed-inputs_{patch_id}.nc")
if not local_file_path.exists():
    local_file_path.parent.mkdir(parents=True, exist_ok=True)
    print(f"Downloading demo preprocessed inputs to {local_file_path}...")
    remote_url = (
        f"https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/demo/kenya/preprocessed-inputs_{patch_id}.nc"
    )
    urllib.request.urlretrieve(remote_url, local_file_path)

ds = xr.open_dataset(local_file_path, engine="netcdf4")
epsg = CRS.from_wkt(ds.crs.attrs["spatial_ref"]).to_epsg()

# Compute NDVI; cast to float first so integer band values cannot wrap around on subtraction
b08 = ds["S2-L2A-B08"].astype("float32")
b04 = ds["S2-L2A-B04"].astype("float32")
ndvi = (b08 - b04) / (b08 + b04)

# Visualize NDVI for a random timestamp of the patch
timestamp_ind = np.random.randint(0, ndvi.sizes["t"])
ndvi.isel(t=timestamp_ind).plot(cmap="RdYlGn", vmin=-0.8, vmax=0.8)
plt.show()

# Show band statistics
stats = get_band_statistics_netcdf(ds)

# Visualize random time series
visualize_timeseries_netcdf(ds, band="NDVI", npixels=6, random_seed=42)
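
The NDVI in the cell above is computed as NDVI = (B08 - B04) / (B08 + B04), with Sentinel-2 band B08 (near infrared) and B04 (red). The standalone NumPy sketch below shows the same formula with two safeguards that are our own assumptions, not documented WorldCereal behaviour: inputs are cast to float before subtracting, so unsigned integer bands cannot wrap around, and pixels where both bands are 0 get NaN without a runtime warning. `compute_ndvi` is a hypothetical helper.

```python
import numpy as np


def compute_ndvi(nir, red) -> np.ndarray:
    """NDVI = (NIR - Red) / (NIR + Red), with NaN where NIR + Red == 0."""
    nir = np.asarray(nir, dtype="float64")  # cast before subtracting
    red = np.asarray(red, dtype="float64")
    denom = nir + red
    out = np.full(denom.shape, np.nan)
    # Only divide where the sum is non-zero, e.g. to skip all-zero no-data pixels.
    np.divide(nir - red, denom, out=out, where=denom != 0)
    return out


# Usage: uint16 inputs, where naive subtraction would wrap around
b08_demo = np.array([100, 0], dtype=np.uint16)
b04_demo = np.array([300, 0], dtype=np.uint16)
print(compute_ndvi(b08_demo, b04_demo))  # -0.5 for the first pixel, NaN for the second
```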

**Set processing parameters**

Configure the inference and post-processing parameters:

- `mask_cropland`: Mask crop type predictions using the default cropland model (passed to `run_seasonal_inference()` as `enforce_cropland_gate`)
- `enable_cropland_head`: Also export cropland probability rasters for quality assurance
- `enable_croptype_head`: Run the crop type head; if `False`, only cropland predictions are made
- `export_class_probs`: Export a per-class probability layer for each crop type
- `croptype_postprocess_enabled`: Apply spatial post-processing to the crop type classification
- `croptype_postprocess_method`: Post-processing algorithm for crop type (e.g., `"majority_vote"`)
- `croptype_postprocess_kernel`: Kernel size of the crop type post-processing filter
- `cropland_postprocess_enabled`: Apply spatial post-processing to the cropland mask
- `cropland_postprocess_method`: Post-processing algorithm for cropland
- `cropland_postprocess_kernel`: Kernel size of the cropland post-processing filter
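
This notebook does not show how the `majority_vote` post-processing is implemented in WorldCereal. To illustrate what a kernel size such as `3` controls, the sketch below implements a generic majority (modal) filter: each pixel takes the most frequent label in the surrounding `kernel_size` x `kernel_size` window. Border handling (clipped windows) and tie-breaking (smallest label wins) are our own choices and may differ from the library.

```python
import numpy as np


def majority_vote(labels: np.ndarray, kernel_size: int = 3) -> np.ndarray:
    """Replace each pixel by the most frequent label in its kernel_size x kernel_size window.

    Windows are clipped at the image border. Ties go to the smallest label
    (np.argmax returns the first maximum). Labels must be non-negative integers.
    """
    if kernel_size % 2 != 1:
        raise ValueError("kernel_size must be odd")
    r = kernel_size // 2
    rows, cols = labels.shape
    out = np.empty_like(labels)
    for i in range(rows):
        for j in range(cols):
            window = labels[max(i - r, 0): i + r + 1, max(j - r, 0): j + r + 1]
            # Count label occurrences in the window and keep the most frequent one
            out[i, j] = np.bincount(window.ravel()).argmax()
    return out


# Usage: an isolated pixel with label 2 is replaced by its neighbours' label 1
demo = np.array([[1, 1, 1], [1, 2, 1], [1, 1, 1]])
print(majority_vote(demo))
```

Larger kernels remove more isolated pixels but can also erode small fields, which is one reason to tune the kernel size on local patches first.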

In [None]:
mask_cropland = True  # cropland masking using default cropland model
enable_cropland_head = True  # also export cropland rasters for QA
enable_croptype_head = True  # whether to run the croptype head at all (if False, only cropland predictions will be made)
export_class_probs = True  # export per-class probabilities in the crop type product

croptype_postprocess_enabled = True
croptype_postprocess_method = "majority_vote"
croptype_postprocess_kernel = 3

cropland_postprocess_enabled = True
cropland_postprocess_method = "majority_vote"
cropland_postprocess_kernel = 3
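
With `mask_cropland = True`, the notebook passes `enforce_cropland_gate=True` to `run_seasonal_inference()`. Conceptually, such a gate keeps crop type labels only where the cropland model predicts cropland. The NumPy sketch below illustrates that idea; `apply_cropland_gate` is hypothetical, and the value WorldCereal actually writes for non-cropland pixels is not shown in this notebook, so `fill_value` is left to the caller.

```python
import numpy as np


def apply_cropland_gate(croptype, cropland, fill_value):
    """Hypothetical sketch: keep crop type labels only where the cropland mask is 1."""
    croptype = np.asarray(croptype)
    is_cropland = np.asarray(cropland).astype(bool)
    # Non-cropland pixels get fill_value; cropland pixels keep their crop type label.
    return np.where(is_cropland, croptype, fill_value)


# Usage with a 2 x 2 toy example
gated = apply_cropland_gate([[3, 4], [5, 6]], [[1, 0], [0, 1]], fill_value=0)
print(gated.tolist())  # [[3, 0], [0, 6]]
```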

**Specify year and season**

For reference, the next cell displays the season ID and growing season window for which you trained your model:

In [None]:
season_id = app.season_id
print(f"Model trained for season id: {season_id}")
season_window = app.season_window
print(f"Model trained for season: {season_window.start_date} - {season_window.end_date}")

In the next cell, manually set the season ID and the start and end dates of the season you want to map.

**NOTE**: the preprocessed input patches used here only contain data between 2025-01-01 and 2026-01-01, so choose a season window within this period.

In [None]:
season_id = "xxx"                 # e.g. "LongRains"
season_start_date = "YYYY-mm-dd"  # e.g. "2025-03-01"
season_end_date = "YYYY-mm-dd"    # e.g. "2025-08-31"
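
Before running inference, you may want to check the dates you entered. The hypothetical helper below parses ISO dates, checks that the start precedes the end, and checks that the window lies within the 2025-01-01 to 2026-01-01 data range of the prepared patches. Treating the end of that range as inclusive is an assumption.

```python
from datetime import date

# Data availability of the prepared patches, as stated in the note above
DATA_START = date(2025, 1, 1)
DATA_END = date(2026, 1, 1)


def check_season_window(start: str, end: str) -> tuple[date, date]:
    """Hypothetical check: parse ISO dates and validate them against the patch data range."""
    start_d = date.fromisoformat(start)  # raises ValueError for unfilled placeholders
    end_d = date.fromisoformat(end)
    if start_d >= end_d:
        raise ValueError(f"Season start {start_d} must be before end {end_d}")
    if start_d < DATA_START or end_d > DATA_END:
        raise ValueError(
            f"Season {start_d} to {end_d} falls outside the available data "
            f"({DATA_START} to {DATA_END})"
        )
    return start_d, end_d


# Usage: check_season_window(season_start_date, season_end_date)
```

This raises a `ValueError` for the unfilled `YYYY-mm-dd` placeholders or for dates outside 2025, and returns the parsed dates for a valid window.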

**Run local inference**

In [None]:
import json

from notebook_utils.local_inference import (
    run_seasonal_inference,
    build_postprocess_spec,
    classification_to_geotiff,
)

season_windows = {
    season_id: (str(season_start_date), str(season_end_date))
}

croptype_postprocess = build_postprocess_spec(
    enabled=croptype_postprocess_enabled,
    method=croptype_postprocess_method,
    kernel_size=croptype_postprocess_kernel,
)
cropland_postprocess = build_postprocess_spec(
    enabled=cropland_postprocess_enabled,
    method=cropland_postprocess_method,
    kernel_size=cropland_postprocess_kernel,
)

classification_result = run_seasonal_inference(
    ds,  # or local_file_path
    seasonal_model_zip=seasonal_model_zip_path,
    croptype_head_zip=app.head_package_path,  # the crop type model you trained in the app
    season_ids=[season_id],
    season_windows=season_windows,
    enforce_cropland_gate=mask_cropland,
    export_class_probabilities=export_class_probs,
    enable_cropland_head=enable_cropland_head,
    enable_croptype_head=enable_croptype_head,
    croptype_postprocess=croptype_postprocess,
    cropland_postprocess=cropland_postprocess,
    as_dataset=False,  # DataArray for GeoTIFF
)

# Specify output directory and name for the classification result
output_dir = Path("./local_inference")
output_dir.mkdir(parents=True, exist_ok=True)
# Use the file name of the trained head package as a tag in the output file name
head_tag = Path(app.head_package_path).stem
classification_result_path = output_dir / f"croptype_{season_id}_{patch_id}_{head_tag}.tif"

# Get the model's class names from its config file
head_output_dir = Path(app.head_output_path)
head_config_path = head_output_dir / "config.json"
if not head_config_path.exists():
    raise FileNotFoundError(
        f"Torch head config not found at {head_config_path}. Check the training logs above."
    )
with head_config_path.open() as fp:
    head_config = json.load(fp)
class_map = {i: name for i, name in enumerate(head_config["classes_list"])}

# Finally, save to geotiff
classification_to_geotiff(classification_result, epsg, classification_result_path, class_map)

The following cell helps you quickly visualize your maps.

For more detailed inspection (for instance, of the class probabilities), we recommend loading the product(s) in QGIS.

In [None]:
from notebook_utils.local_inference import isolate_cropland_croptype_products
from notebook_utils.visualization import visualize_products

# split one raster into separate cropland and crop type products
products = isolate_cropland_croptype_products(
    classification_result_path,
    enable_cropland_head)

# Get class names from your model and convert to look-up table compatible with visualization function
lut_croptype = {v: k for k, v in sorted(class_map.items(), key=lambda kv: kv[0])}
luts = {'croptype': lut_croptype}

# Run simple visualization of the classification bands
visualize_products(products, luts=luts)
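
The crop type look-up table above is the inverse of `class_map`, which the notebook builds from the `classes_list` entry of the head's `config.json`, assuming that pixel value `i` corresponds to `classes_list[i]`. The self-contained sketch below reproduces both mappings from a config string; `class_luts` is hypothetical, and the class names in the usage line are made up.

```python
import json


def class_luts(config_text: str) -> tuple[dict[int, str], dict[str, int]]:
    """Hypothetical helper: build index->name and name->index maps from "classes_list"."""
    classes = json.loads(config_text)["classes_list"]
    index_to_name = {i: name for i, name in enumerate(classes)}
    # Inverse mapping (class name -> pixel value), as used for the visualization LUT
    name_to_index = {name: i for i, name in index_to_name.items()}
    return index_to_name, name_to_index


# Usage with made-up class names (distinct variable names avoid overwriting class_map)
example_map, example_lut = class_luts('{"classes_list": ["maize", "beans", "other"]}')
print(example_lut)  # {'maize': 0, 'beans': 1, 'other': 2}
```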

Congratulations, you have reached the end of this demo!

You can now go back to the app and run steps 7 to 9 to launch a map production job in your area of choice!