# Mapping temporary crops in Zambia

In this exercise we will take you through the process of training a custom cropland classifier for a specific region.<br>
Cropland is defined in this exercise as "temporary crops", meaning crops harvested within 12 months after planting/sowing.

### Before you start

Make sure this notebook has access to all required functionalities:

In [None]:
# Add the parent directory to sys.path so modules located one level up can be imported
import sys

# Only append once, so re-running this cell does not pile up duplicate entries
if '..' not in sys.path:
    sys.path.append('..')

### 1. Gather and prepare your training data

For the purpose of this exercise we already prepared a bounding box for you from which we will extract training data.<br>

The following cell shows the location of all datasets in our public extractions database and plots our area of interest on top in red.

In [None]:
from shapely.geometry import box
import geopandas as gpd
from notebook_utils.extractions import retrieve_extractions_extent

# Retrieve the extents of the datasets that have extractions
extents = retrieve_extractions_extent()
print(f"Found {len(extents)} datasets with extractions.")

# Draw the dataset extents on an interactive map
footprint_style = {
    "fillOpacity": 0.05,  # Transparency of polygon fill (0 = fully transparent, 1 = opaque)
    "weight": 1,  # Border line width
}
m = extents.explore(style_kwds=footprint_style, highlight=False)

# Bounding box (min x, min y, max x, max y in EPSG:4326) around Zambia, Zimbabwe and Tanzania
extent = [21.708984, -25.958045, 46.142578, 6.664608]
polygon = box(*extent)
gdf = gpd.GeoDataFrame(
    {"name": ["zmb_zwe_tza_region"]},
    geometry=[polygon],
    crs="EPSG:4326",
)

# Overlay the area of interest on the same map, drawn in dark red
aoi_style = {"color": "darkred", "weight": 2, "fillOpacity": 0.1}
gdf.explore(m=m, name="Our area of interest", style_kwds=aoi_style)

Retrieve the data for Zambia and Zimbabwe collected by CIMMYT:

In [None]:
from pathlib import Path
import urllib.request
import zipfile

extractions_path = Path('./extractions_zmb')

if extractions_path.exists():
    # Nothing to do: the extraction folder is already there
    print("Zambian data already exists.")
else:
    print("Downloading Zambian data...")
    remote_url = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/demo/zambia/extractions_zmb.zip"
    # Fetch the archive into the working directory, then unpack it into the extractions folder
    local_zip = './extractions_zmb.zip'
    urllib.request.urlretrieve(remote_url, local_zip)
    with zipfile.ZipFile(local_zip, 'r') as zip_ref:
        zip_ref.extractall(extractions_path)


In [None]:
from shapely.geometry import box
from notebook_utils.extractions import query_extractions

# Buffer distance used to expand the search perimeter
buffer = 0  # meters

# Location of the private extractions parquet file
private_extractions_path = Path('./extractions_zmb/worldcereal_merged_extractions.parquet')

# True: keep temporary crops only; False: keep all available classes
filter_temporary_crops = False

# Query the extractions inside our area of interest, passing the private parquet file as well
extractions = query_extractions(
    bbox_poly=polygon,
    buffer=buffer,
    private_parquet_path=private_extractions_path,
    filter_cropland=filter_temporary_crops,
)
extractions.head()

**Perform a quick quality check**

In [None]:
from notebook_utils.extractions import get_band_statistics, visualize_timeseries

# List the dataset names present in the extractions, so a valid one can be entered below
available_datasets = sorted(extractions['ref_id'].dropna().unique())
print("Available datasets:")
for available_name in available_datasets:
    print(f"  - {available_name}")

# Strip stray whitespace so the typed name can match a ref_id exactly
dataset_name = input('Enter the dataset name: ').strip()
if dataset_name not in available_datasets:
    # Fail early with a clear message instead of continuing with an empty subset
    raise ValueError(
        f"Unknown dataset name {dataset_name!r}; choose one of the datasets listed above."
    )
subset_data = extractions.loc[extractions['ref_id'] == dataset_name]

# Check band statistics
band_stats = get_band_statistics(subset_data)

# Visualize timeseries for a few samples
visualize_timeseries(subset_data, nsamples=5)

**Select your season of interest**

In [None]:
from openeo_gfmap import BoundingBoxExtent
from worldcereal.utils.map import _latlon_to_utm
from notebook_utils.seasons import retrieve_worldcereal_seasons

# Convert the bounds of our area of interest into a bounding box plus an EPSG code
# (presumably a UTM projection, judging by the helper's name -- confirm)
bbox, epsg = _latlon_to_utm(polygon.bounds)
spatial_extent = BoundingBoxExtent(*bbox, epsg)

# Retrieve the WorldCereal seasons for this spatial extent
seasons = retrieve_worldcereal_seasons(spatial_extent)

In [None]:
from notebook_utils.seasons import valid_time_distribution

# Inspect the extractions with the valid-time helper
# (presumably shows how the samples' valid times are distributed -- confirm)
valid_time_distribution(extractions)

In [None]:
from notebook_utils.dateslider import season_slider

# Create the season slider; the selected date range is read back later
# through `slider.get_selected_dates()`
slider = season_slider()

**Compute training features**

In [None]:
from notebook_utils.classifier import align_extractions_to_season

# Retrieve the date range you just selected
season = slider.get_selected_dates()

# Align the extractions to the selected season
# NOTE(review): the unit of valid_time_buffer is not visible here (presumably months) -- confirm
training_df = align_extractions_to_season(extractions, season, valid_time_buffer=2)
# Preview the first rows of the aligned training data
training_df.head()

**Select your crops of interest**

Make sure to select all land cover classes, except for unspecified cropland (we do not know whether this is permanent crops or temporary crops) and mixed cropland.

In [None]:
from notebook_utils.croptypepicker import CropTypePicker

# Build the crop type picker from the training samples
# NOTE(review): expand=False presumably starts the picker collapsed -- confirm
croptypepicker = CropTypePicker(sample_df=training_df, expand=False)

In [None]:
from notebook_utils.croptypepicker import apply_croptypepicker_to_df
from worldcereal.utils.legend import translate_ewoc_codes

# Apply the picker selection to the training samples
training_df = apply_croptypepicker_to_df(training_df, croptypepicker)

# Count, per ewoc_code, the samples whose downstream class is 'other'
is_other = training_df['downstream_class'] == 'other'
other_count = training_df.loc[is_other, 'ewoc_code'].value_counts()

# Attach the translated labels to the counts, joining on the ewoc_code index
other_labels = translate_ewoc_codes(other_count.index.tolist())
other_class = other_count.to_frame().merge(
    other_labels,
    left_index=True,
    right_index=True,
)
other_class

In [None]:
# Drop the 'other' class. The explicit copy makes training_df an independent DataFrame,
# so later in-place edits of 'downstream_class' do not act on a slice of the previous
# frame (which can trigger pandas' SettingWithCopyWarning).
training_df = training_df.loc[training_df['downstream_class'] != 'other'].copy()
training_df['downstream_class'].value_counts()

Now we have two options:
1. we group all classes which do not represent temporary crops in one big class (non_crop)
2. we make a distinction between vegetated classes which are not temporary crops (non_crop_veg) and non-vegetated land cover classes (non_crop_non_veg)

Make sure you try a different option compared to your neighbour!

Maybe you can think of another way to group the classes?

In [None]:
# OPTION 1: merge every class that is not a temporary crop into a single 'non_crop' class
combine_classes = {
    'non_crop': [
        'non_cropland_herbaceous',
        'shrubland',
        'trees_unspecified_deciduous',
        'non_cropland_incl_perennial',
        'permanent_crops',
        'trees_mixed',
        'bare_sparsely_vegetated',
        'built_up',
        'open_water',
    ],
}
# OPTION 2: keep vegetated and non-vegetated non-crop classes apart
# NOTE(review): 'open_water' is merged in option 1 but not listed below, so with option 2
# it would remain a class of its own -- confirm this is intended.
# combine_classes = {
#     'non_crop_veg': ['non_cropland_herbaceous', 'shrubland', 'trees_unspecified_deciduous', 'non_cropland_incl_perennial', 'permanent_crops', 'trees_mixed'],
#     'non_crop_non_veg': ['bare_sparsely_vegetated', 'built_up'],
# }

# Relabel every sample whose class belongs to a group with the name of that group
for target_class, source_classes in combine_classes.items():
    in_group = training_df['downstream_class'].isin(source_classes)
    training_df.loc[in_group, 'downstream_class'] = target_class

# Report on the contents of the data
training_df['downstream_class'].value_counts()

**Finally, compute Presto embeddings**

In [None]:
from notebook_utils.classifier import compute_presto_embeddings

# Presto encoder weights used to compute the embeddings
presto_encoder_url = 'https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/models/PhaseII/presto-prometheo-landcover-MulticlassWithCroplandAuxBCELoss-labelsmoothing%3D0.05-month-LANDCOVER10-augment%3DTrue-balance%3DTrue-timeexplicit%3DFalse-masking%3Denabled-run%3D202510301004_encoder.pt'

embeddings_df = compute_presto_embeddings(
    training_df,
    task_type='cropland',
    augment=True,  # apply temporal jittering
    mask_on_training=True,  # apply sensor masking to training split
    repeats=3,  # number of times to augment each training sample
    custom_presto_url=presto_encoder_url,
)

### 2. Train cropland model

In [None]:
from notebook_utils.classifier import train_classifier

# Train a classifier on the embeddings; the call returns the trained model,
# a report and a confusion matrix
custom_model, report, confusion_matrix = train_classifier(
    embeddings_df, balance_classes=True, show_confusion_matrix='relative',
)
# Print the report returned by the training call
print(report)

### 3. Deploy your model

In [None]:
from worldcereal.utils.upload import deploy_model
from openeo_gfmap.backend import cdse_connection
from notebook_utils.classifier import get_input

# Ask for a model name, then deploy the trained model through a CDSE connection
modelname = get_input("model")
connection = cdse_connection()
model_url = deploy_model(connection, custom_model, pattern=modelname)
print(f"Your model can be downloaded from: {model_url}")

### 4. Apply model to test patches

In [None]:
from pathlib import Path
import urllib.request
import zipfile

import xarray as xr
from pyproj import CRS
from notebook_utils.local_inference import run_cropland_mapping, classification_to_geotiff
from notebook_utils.visualization import visualize_product
from worldcereal.utils.models import load_model_lut

test_patches = ['lusaka-mar-25','zmb-north-mar-2025']

# Presto encoder weights passed to the local inference run
presto_encoder_url = 'https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/models/PhaseII/presto-prometheo-landcover-MulticlassWithCroplandAuxBCELoss-labelsmoothing%3D0.05-month-LANDCOVER10-augment%3DTrue-balance%3DTrue-timeexplicit%3DFalse-masking%3Denabled-run%3D202510301004_encoder.pt'

for patch in test_patches:
    target_path = Path(f"./preprocessed_inputs/{patch}")
    target_path.parent.mkdir(parents=True, exist_ok=True)

    # Download and extract the data if not already present
    if not target_path.exists():
        url = f"https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/demo/zambia/{patch}.zip"
        local_zip_file = f"./preprocessed_inputs/{patch}.zip"
        urllib.request.urlretrieve(url, local_zip_file)
        with zipfile.ZipFile(local_zip_file, 'r') as zip_ref:
            zip_ref.extractall(target_path)

    # Open the preprocessed inputs file; the context manager closes the file handle
    # once the data has been read
    local_file_path = target_path / "0" / "preprocessed-inputs_2024-09-01_2025-08-31_0.nc"
    with xr.open_dataset(local_file_path) as ds:
        # Get the EPSG code from the WKT string stored on the 'crs' variable
        crs_attrs = ds["crs"].attrs
        epsg = CRS.from_wkt(crs_attrs["spatial_ref"]).to_epsg()
        # Convert to a uint16 DataArray with a 'bands' dimension, using 65535 for
        # missing values; load it into memory before the file is closed
        arr = (
            ds.drop_vars("crs")
            .fillna(65535)
            .astype("uint16")
            .to_array(dim="bands")
            .load()
        )

    # Run the local inference with the Presto encoder and the deployed classifier
    landcover_embeddings, cropland_classification = run_cropland_mapping(
        arr,
        epsg=epsg,
        custom_presto_url=presto_encoder_url,
        classifier_url=model_url,
    )

    # parents=True: neither ./local_inference nor the per-patch folder may exist yet
    product_path = Path(f"./local_inference/{patch}/cropland_classification.tif")
    product_path.parent.mkdir(parents=True, exist_ok=True)
    classification_to_geotiff(cropland_classification, epsg, product_path)

    # Visualize the written product using the lookup table loaded for the deployed model
    lut = load_model_lut(model_url)
    visualize_product(product_path, product='croptype', lut=lut, interactive_mode=False)

Congratulations, you have reached the end of this demo!