# Prediction Output in Raster or Vector Formats

## Preparations

Install requirements

In [None]:
# Install the packages listed in ../requirements.txt into the kernel's environment.
%pip install -r ../requirements.txt

Import libraries

In [11]:
import math

import geopandas as gpd
import matplotlib.pyplot as plt
import pandas as pd
import rasterio
from catboost import CatBoostClassifier
from rasterio.features import rasterize
from rasterio.transform import from_origin

Define crop classes with their corresponding IDs

In [12]:
# Lookup table from numeric class ID to crop name.
# IDs are consecutive and start at 1, following the order of the names below.
class_names = dict(enumerate(
    [
        'winter wheat',
        'spring oats',
        'spring barley',
        'spring rye',
        'corn',
        'soybean',
        'sunflower',
        'sugar beet',
        'rapeseed',
        'sorghum',
        'potato',
        'cotton',
        'spring wheat',
        'winter oats',
        'winter barley',
        'winter rye',
    ],
    start=1,
))

We demonstrate how to generate a map using the fine-tuned CropGRM-small model. First, select the features required by the CropGRM-small model from the dataset.

In [13]:
# Feature columns passed to the model, one feature family per line.
# The order is kept exactly as given; presumably it has to match the feature
# order the model was trained with — verify before reordering.
cols_to_select = [
    'sum_t_4', 'sum_t_5', 'sum_t_6', 'sum_t_7', 'sum_t_8', 'sum_t_9', 'sum_t_10',
    'sum_prec_4', 'sum_prec_6', 'sum_prec_10',
    'median_t_4', 'median_t_6', 'median_t_9', 'median_t_10',
    'ndre_S',
    'median_red_fitted_8',
    'median_nir_fitted_5', 'median_nir_fitted_8',
    'median_swir1_fitted_6', 'median_swir1_fitted_7', 'median_swir1_fitted_8',
    'median_green_fitted_7', 'median_green_fitted_8',
    'median_swir2_fitted_5',
]

Load the model

In [14]:
# Create an empty classifier, then populate it from the saved model file.
model = CatBoostClassifier()
model.load_model('../models/finetuned_model.cbm')

Make predictions on your data

In [15]:
# Read the prepared feature table, then store the model's predicted class ID
# for every row in a new 'class' column.
df_classes = pd.read_parquet('../data/processed/input_data_for_model.parquet')
model_inputs = df_classes[cols_to_select]
df_classes['class'] = model.predict(model_inputs)

Merge geometries with predicted classes by `field_id`

In [16]:
# Attach the predicted class to every field polygon. The left join keeps fields
# that have no row in df_classes; their class becomes 0, and because 0 is not a
# key of class_names their class_name stays missing.
gdf = gpd.read_file('../data/raw/fields.fgb')
gdf = gdf.merge(df_classes[['field_id', 'class']], how='left', on='field_id')
gdf['class'] = gdf['class'].fillna(0).astype(int)
gdf['class_name'] = gdf['class'].map(class_names)

# Interactive map coloured by crop name; kept as the last expression so the
# notebook renders it.
gdf.explore(cmap='tab20', column='class_name')

Save the results as a vector file

In [17]:

# Export only the field identifier, the crop name and the geometry;
# the numeric 'class' column is not written to the vector file.
(
    gdf[['field_id', 'class_name', 'geometry']]
    .to_file('../data/final/CropMap_fields.fgb', encoding='utf8')
)

Additionally, save the results as a raster

In [18]:
# Burn the classified field polygons into a regular grid and write it as a
# single-band GeoTIFF with an embedded colour table.
nodata = 0  # background value; also the class given to fields without a prediction
# Pixel edge length in CRS units.
# NOTE(review): assumes gdf.crs is a projected CRS with metre units — confirm.
pixel_size = 10

xmin, ymin, xmax, ymax = gdf.total_bounds
# Round up so the grid covers the whole extent: truncating with int() dropped
# the partial last column/row and clipped polygons on the east/south edge.
# max(1, ...) keeps the raster non-empty when the extent is below one pixel.
width = max(1, math.ceil((xmax - xmin) / pixel_size))
height = max(1, math.ceil((ymax - ymin) / pixel_size))

# Affine transform anchored at the upper-left corner of the extent.
transform = from_origin(west=xmin, north=ymax, xsize=pixel_size, ysize=pixel_size)

# (geometry, class ID) pairs consumed lazily by rasterize.
shapes = ((geom, value) for geom, value in zip(gdf.geometry, gdf['class']))

raster = rasterize(
    shapes=shapes,
    out_shape=(height, width),
    fill=nodata,
    transform=transform,
    dtype='uint8',
    # Burn every pixel a polygon touches, not only those whose centre is inside.
    all_touched=True
)

# Palette: raster value i -> i-th 'tab20' colour as 8-bit RGB.
# round() instead of int() so float error cannot shift a component down by one.
cmap = plt.get_cmap('tab20')
colormap = {
    i: tuple(int(round(float(c) * 255)) for c in cmap(i)[:3])
    for i in range(cmap.N)
}

with rasterio.open(
    '../data/final/CropMap_fields.tif',
    'w',
    driver='GTiff',
    height=height,
    width=width,
    count=1,
    dtype=raster.dtype,
    nodata=nodata,
    crs=gdf.crs,
    transform=transform,
    # NOTE(review): GDAL's documented PHOTOMETRIC values do not seem to include
    # 'Palette' — confirm this option has an effect (write_colormap below is
    # what stores the colour table).
    photometric='Palette',
    COMPRESS='ZSTD',
) as dst:
    dst.write(raster, 1)
    dst.write_colormap(1, colormap)