![](./resources/Custom_croptype_map.png)

**Study the Effect of local data**

In this notebook you will explore the effect of local data on classification performance.

- First, we experiment with the publicly available datasets in the region (public datasets covering Zambia and Zimbabwe).
- Next, you add the local data provided by CIMMYT as part of the use-case exercise.

**Download the prepared Public and Local data**

In [None]:
# add parent directory to sys.path
import sys
sys.path.append('..')

In [1]:
from pathlib import Path
import urllib.request

training_fnames = ['public_zmb_zwe.parquet', 'local_zmb.parquet', 'local_zwe.parquet']

for fname in training_fnames:
    local_file_path = Path(f"./training_data/{fname}")
    local_file_path.parent.mkdir(exist_ok=True)
    # Download each file only once; later runs reuse the local copy
    if not local_file_path.exists():
        print(f"Downloading demo preprocessed inputs to {local_file_path}...")
        remote_url = f"https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/demo/zambia/{fname}"
        urllib.request.urlretrieve(remote_url, local_file_path)

**Load public and local data**

The public data covering Zambia and Zimbabwe has already been queried and stored as a parquet file for your convenience. You can load it by running the cell below.

The same cell also loads the local data provided by CIMMYT for Zambia and Zimbabwe.

In [6]:
from sklearn.model_selection import train_test_split
import pandas as pd

public_zmb_zwe_df = pd.read_parquet("./training_data/public_zmb_zwe.parquet")
local_country_df = pd.concat([pd.read_parquet("./training_data/local_zmb.parquet"), pd.read_parquet("./training_data/local_zwe.parquet")])

local_id_label_mapper = dict(zip(list(local_country_df['sample_id']), list(local_country_df['label_full'])))
public_id_label_mapper = dict(zip(list(public_zmb_zwe_df['sample_id']), list(public_zmb_zwe_df['label_full'])))

**Explore which datasets are included in the PUBLIC dataset**

In [2]:
public_zmb_zwe_df['ref_id'].unique()

**Check Valid time in public data**

In [3]:
from notebook_utils.seasons import valid_time_distribution
valid_time_distribution(public_zmb_zwe_df)

**Check Valid time in local data**

In [4]:
valid_time_distribution(local_country_df)

Using the season slider below, select the window that best fits the valid times of both the public and the local data.

**Set the season slider**

**FOR THIS EXERCISE, IDEALLY SET IT TO 01 Oct - 30 Sep**
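
As a rough numeric complement to the distribution plots, the sketch below counts the share of samples whose valid time falls inside one candidate window. It is only an illustration: the helper `fraction_in_window` is hypothetical, the toy dates are made up, and the column name `valid_time` is an assumption about the dataframes rather than something confirmed here.

```python
import pandas as pd


def fraction_in_window(valid_times, start, end):
    """Return the share of valid times that fall inside [start, end]."""
    vt = pd.to_datetime(pd.Series(valid_times))
    inside = (vt >= pd.Timestamp(start)) & (vt <= pd.Timestamp(end))
    return float(inside.mean())


# Toy data only; replace with e.g. local_country_df["valid_time"] if that column exists
toy = pd.DataFrame(
    {"valid_time": ["2022-12-15", "2023-02-01", "2023-03-20", "2023-11-05"]}
)

share = fraction_in_window(toy["valid_time"], "2022-10-01", "2023-09-30")
print(f"Share of toy samples inside the window: {share:.0%}")
```

A window that leaves out a large share of either dataset is a hint to shift the slider before aligning the extractions.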

In [11]:
from notebook_utils.dateslider import season_slider
slider = season_slider()

In [5]:
from notebook_utils.classifier import align_extractions_to_season

season = slider.get_selected_dates()

# Align the extractions to the selected season
public_zmb_zwe_df = align_extractions_to_season(public_zmb_zwe_df, season, valid_time_buffer=2)
public_zmb_zwe_df.head()

# Align the extractions to the selected season
local_country_df = align_extractions_to_season(local_country_df, season, valid_time_buffer=2)
local_country_df.head()

**Compute training features**

Using a geospatial foundation model (Presto), we derive training features for each sample in the dataframe resulting from your query. Presto was pre-trained on millions of unlabeled samples around the world and finetuned on global labelled land cover and crop type data from the WorldCereal reference database. The resulting 128 *embeddings* (`presto_ft_0` -> `presto_ft_127`) nicely condense the Sentinel-1, Sentinel-2, meteo timeseries and ancillary data for your season of interest into a limited number of meaningful features which we will use for downstream model training.<br>

We provide some options aimed at increasing temporal robustness of your final crop model.<br>
This is controlled by the following arguments:
- `augment` parameter: when set to `True`, it introduces slight temporal jittering of the processing window, making the model more robust to slight variations in seasonality across different years. By default, this option is set to `True`, but especially when training a model for a specific region and year with good local data, disabling this option could be considered.
- `mask_on_training` parameter: when `True`, applies sensor masking augmentations (e.g. simulating S1/S2 dropouts, additional clouds, ancillary feature removals) only to the training split to improve robustness to real-world data gaps. The validation/test split is kept untouched for fair evaluation.
- `repeats` parameter: number of times each training sample is (re)drawn with its augmentations. Higher values (>1) create more variants (with jitter/masking) and enlarge the effective training set, potentially improving generalization at the cost of longer embedding inference time.

We compute the embeddings for both the public and the local data by running the cell below.

In [6]:
from notebook_utils.classifier import compute_presto_embeddings
import pandas as pd

public_zmb_zwe_df = compute_presto_embeddings(
    public_zmb_zwe_df,
    augment=True,
    mask_on_training=True,
    repeats=3
)

local_country_df = compute_presto_embeddings(
    local_country_df,
    augment=True,
    mask_on_training=True,
    repeats=3
)
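
To picture what `augment` and `repeats` do together, the toy sketch below draws one small random shift of the processing window per sample and per repeat. It is not how `compute_presto_embeddings` is implemented: the helper `draw_jitter_shifts`, the shift unit (time steps) and the maximum shift of one step are all assumptions made for illustration.

```python
import numpy as np


def draw_jitter_shifts(n_samples, repeats=3, max_shift=1, seed=42):
    """Draw one integer window shift per sample and per repeat.

    Each value lies in [-max_shift, max_shift] (hypothetical unit: time steps).
    """
    rng = np.random.default_rng(seed)
    # The upper bound of rng.integers is exclusive, hence max_shift + 1
    return rng.integers(-max_shift, max_shift + 1, size=(n_samples, repeats))


shifts = draw_jitter_shifts(n_samples=4, repeats=3)
print(shifts.shape)  # one row per sample, one column per repeat
```

With `repeats=3`, each training sample contributes three slightly different variants, which is why the effective training set grows and embedding inference takes longer.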

**Select crops common to both the public and local data for fair evaluation**

In [7]:
import pandas as pd

def drop_small_classes(df, label_col="label_full", min_count=10):
    counts = df[label_col].value_counts()
    valid_labels = counts[counts >= min_count].index
    df_filtered = df[df[label_col].isin(valid_labels)].copy()
    return df_filtered

public_zmb_zwe_df['label_full'] = public_zmb_zwe_df['sample_id'].apply(lambda x: public_id_label_mapper[x])
local_country_df['label_full'] = local_country_df['sample_id'].apply(lambda x: local_id_label_mapper[x])

public_zmb_zwe_df = drop_small_classes(public_zmb_zwe_df)
local_country_df = drop_small_classes(local_country_df)

common_labels = set(public_zmb_zwe_df['label_full'].unique()) & set(local_country_df['label_full'].unique())

public_zmb_zwe_df = public_zmb_zwe_df[public_zmb_zwe_df['label_full'].isin(common_labels)].copy()
local_country_df = local_country_df[local_country_df['label_full'].isin(common_labels)].copy()

print(f"Number of common labels: {len(common_labels)}")
print("Common crops for which the model will be trained:", sorted(common_labels))

**Optionally select crops**
- Crops added to `crops_to_keep` are retained; all other crops are relabelled as `others`.
- If `crops_to_keep` is left empty, all crops are retained.

In [17]:
crops_to_keep = {'maize', 'unspecified_millet', 'unspecified_sorghum'}
crops_to_keep = set()  # overrides the selection above to train on all common crops; remove this line to keep only the crops listed

if crops_to_keep:
    # Relabel every crop outside the selection as 'others'
    local_country_df.loc[~local_country_df['label_full'].isin(crops_to_keep), 'label_full'] = 'others'
    public_zmb_zwe_df.loc[~public_zmb_zwe_df['label_full'].isin(crops_to_keep), 'label_full'] = 'others'

**Create an independent test set from the LOCAL data**

In [8]:
train_local_country, test_local_country = train_test_split(
    local_country_df,
    test_size=0.2,
    stratify=local_country_df["label_full"],
    random_state=42
)

print("Public dataset:", len(public_zmb_zwe_df), "train")
print("Local dataset (Zambia + Zimbabwe):", len(train_local_country), "train,", len(test_local_country), "test")

**Add a downstream class column**

In [19]:
public_zmb_zwe_df['downstream_class'] = public_zmb_zwe_df['label_full']
train_local_country['downstream_class'] = train_local_country['label_full']
test_local_country['downstream_class'] = test_local_country['label_full']

**Train a classifier on the PUBLIC dataset alone**

In [9]:
from notebook_utils.classifier import train_classifier

custom_model_pub, report, confusion_matrix = train_classifier(
    public_zmb_zwe_df, balance_classes=True, show_confusion_matrix='absolute',
)
print(report)

**Train a classifier on the combined PUBLIC + LOCAL dataset**

In [10]:
from notebook_utils.classifier import train_classifier

custom_model_pub_loc, report, confusion_matrix = train_classifier(
    pd.concat([public_zmb_zwe_df, train_local_country]), balance_classes=True, show_confusion_matrix='absolute',
)
print(report)

**Evaluate PUBLIC Model on independent LOCAL test data**

In [11]:
from notebook_utils.classifier import train_classifier, apply_classifier
test_report_public, test_cm_public, _ = apply_classifier(
    test_local_country,
    custom_model_pub,
    show_confusion_matrix='absolute',
)

**Evaluate PUBLIC + LOCAL Model on independent LOCAL test data**

In [12]:
from notebook_utils.classifier import train_classifier, apply_classifier
test_report_pub_loc, test_cm_pub_loc, _ = apply_classifier(
    test_local_country,
    custom_model_pub_loc,
    show_confusion_matrix='absolute',
)
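
The two evaluations above are meant to be compared side by side. A single summary score such as macro F1 can make that comparison easier to read. The sketch below uses scikit-learn's `f1_score` on made-up label lists; the labels and predictions are toy values, not results from this notebook, and the return format of `apply_classifier` is not assumed here.

```python
from sklearn.metrics import f1_score

# Toy labels only: stand-ins for the true and predicted classes of the local test set
y_true = ["maize", "maize", "sorghum", "millet"]
pred_public_only = ["maize", "sorghum", "sorghum", "maize"]
pred_public_plus_local = ["maize", "maize", "sorghum", "millet"]

# Macro F1 weights every crop equally, so rare crops count as much as common ones
f1_public_only = f1_score(y_true, pred_public_only, average="macro", zero_division=0)
f1_public_plus_local = f1_score(
    y_true, pred_public_plus_local, average="macro", zero_division=0
)

print(f"public only:    {f1_public_only:.2f}")
print(f"public + local: {f1_public_plus_local:.2f}")
```

Macro averaging is chosen here because crop classes are usually imbalanced; a weighted average would let the dominant crop hide poor performance on the minor ones.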

**Deploy the model**

In [13]:
from worldcereal.utils.upload import deploy_model
from openeo_gfmap.backend import cdse_connection
from notebook_utils.classifier import get_input

modelname = get_input("model")
model_url = deploy_model(cdse_connection(), custom_model_pub_loc, pattern=modelname)
print(f"Your model can be downloaded from: {model_url}")

**Select Inference Patches**

- '2022_zimbabwe_Region IV_02_major_grid01_2022-10-01_2023-09-30_0'
- '2023_zimbabwe_Region IV_02_major_grid01_2023-10-01_2024-09-30_0'
- '2022_zimbabwe_Region Va_03_major_grid01_2022-10-01_2023-09-30_0'
- '2023_zimbabwe_Region Va_03_major_grid01_2023-10-01_2024-09-30_0'
- '2022_zambia_IIB_01_major_grid01_2022-10-01_2023-09-30_0'
- '2023_zambia_IIB_01_major_grid01_2023-10-01_2024-09-30_0'
- '2024_zambia_IIB_01_major_grid01_2024-10-01_2025-09-30_0'
- '2022_zambia_IIB_01_major_grid02_2022-10-01_2023-09-30_0'
- '2023_zambia_IIB_01_major_grid02_2023-10-01_2024-09-30_0'
- '2024_zambia_IIB_01_major_grid02_2024-10-01_2025-09-30_0'
- '2022_zambia_IIB_01_major_grid03_2022-10-01_2023-09-30_0'
- '2023_zambia_IIB_01_major_grid03_2023-10-01_2024-09-30_0'
- '2024_zambia_IIB_01_major_grid03_2024-10-01_2025-09-30_0'
- '2022_zambia_III_02_major_grid02_2022-10-01_2023-09-30_0'
- '2023_zambia_III_02_major_grid02_2023-10-01_2024-09-30_0'
- '2024_zambia_III_02_major_grid02_2024-10-01_2025-09-30_0'

**Recommended Inference Patches (Ideally keep the number of patches to less than 4)**

In [67]:
deploy_fnames = {'2022_zimbabwe_Region IV_02_major_grid01_2022-10-01_2023-09-30_0',
                 '2023_zimbabwe_Region IV_02_major_grid01_2023-10-01_2024-09-30_0',
                 '2023_zambia_III_02_major_grid02_2023-10-01_2024-09-30_0',
                 '2024_zambia_III_02_major_grid02_2024-10-01_2025-09-30_0'}

**Download the selected inference patches to your local machine**

In [14]:
from pathlib import Path
import urllib.request
import urllib.parse
import xarray as xr
from pyproj import CRS

arr_dict = {}

for fname in deploy_fnames:
    local_file_path = Path(f"./local_inference/{fname}.nc")
    local_file_path.parent.mkdir(exist_ok=True)
    print(local_file_path)

    try:
        if not local_file_path.exists():
            encoded_fname = urllib.parse.quote(fname)
            remote_url = f"https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/demo/test_patches/{encoded_fname}.nc"
            urllib.request.urlretrieve(remote_url, local_file_path)
            print(remote_url)
    
        ds = xr.open_dataset(local_file_path)
    
        epsg = None
        if "crs" in ds:
            try:
                epsg = CRS.from_wkt(ds["crs"].attrs.get("spatial_ref", "")).to_epsg()
            except Exception:
                pass
    
        arr = (
            ds.drop_vars("crs", errors="ignore")
            .fillna(65535)
            .astype("uint16")
            .to_array(dim="bands")
        )
    
        arr_dict[fname] = {"array": arr, "epsg": epsg}
        ds.close()
    except Exception as e:
        # Skip patches that fail to download or load, but report which one failed
        print(f"Exception while preparing {fname}: {e}")

**Deploy the model on the inference patches**

In [16]:
from notebook_utils.local_inference import run_croptype_mapping
from notebook_utils.local_inference import run_cropland_mapping, classification_to_geotiff
import os

if not os.path.exists('./outputs'):
    os.makedirs('./outputs')

output_path_dict = {}

for fname, data in arr_dict.items():
    print('RUNNING FOR :', fname)
    # Use the EPSG code stored with this patch, not the value left over from the download loop
    epsg = data['epsg']
    landcover_embeddings, cropland_classification = run_cropland_mapping(data['array'], epsg=epsg)
    croptype_embeddings, croptype_classification = run_croptype_mapping(data['array'], epsg=epsg, classifier_url=model_url)
    # Set all croptype_classification pixel values to 254 where cropland_classification 'classification' band == 0
    mask = cropland_classification.sel(bands="classification") == 0
    croptype_classification = croptype_classification.where(~mask, 254)
    
    # save to GeoTIFF (the ./outputs directory was created above)
    croptype_path = Path("outputs") / f"{fname}.tif"
    classification_to_geotiff(
        classification=croptype_classification,
        epsg=epsg,
        out_path=croptype_path
    )
    output_path_dict[fname] = croptype_path


**Visualise your outputs**

In [17]:
from worldcereal.utils.models import load_model_lut
from notebook_utils.visualization import visualize_product

# The lookup table depends only on the model, so load it once
lut = load_model_lut(model_url)

for fname, croptype_path in output_path_dict.items():
    print(fname)
    visualize_product(croptype_path, product='croptype', lut=lut, interactive_mode=False)

**You have reached the end of the notebook**