# Compare crop yields to observations

- Uses raw annual CTSM outputs (NOT timeseries files).

Notebook created by Sam Rabin (samrabin@ucar.edu).

In [None]:
import sys
import os
import glob
import numpy as np
import xarray as xr
import pandas as pd
from time import time
from dask.distributed import Client, wait
import convert_pft1d_to_sparse
import importlib
import earthstat
import caselist

# Plotting utils
import matplotlib.pyplot as plt
import clm_and_earthstat_maps as caem

# Start a local Dask cluster using all available cores. No scheduler address is
# passed, so the Client sets up its own local cluster.
client = Client()
# Bare expression as the last line of the cell: shows the client summary in the
# notebook output.
client

## 1. Settings

### 1.1 Parameters modifiable in config.yml

In [None]:
# Location of the CUPiD externals directory. Deriving it from the working directory is
# supposedly unreliable, so it is best overridden by a value given in config.yml. See
# examples/crops/config.yml.
externals_path = os.path.join(os.getcwd(), os.pardir, os.pardir, "externals")

# Directory holding the land output of every case
CESM_output_dir = os.path.join(
    os.sep, "glade", "work", "samrabin", "clm6_crop_reparam_outputs"
)

# One (full casename, legend label) pair per case. The casename must be present in
# CESM_output_dir and in the individual filenames; the label is what figure legends show.
# The pairs are unzipped into the two parallel lists the rest of the notebook expects.
case_name_list, case_legend_list = map(
    list,
    zip(
        ("ctsm53019_f09_BNF_hist", "A: ctsm5.3.019 (GSWP3)"),
        ("clm6_crop_032", "B: ctsm5.3.032 (CRU-JRA)"),
        ("clm6_crop_032_nomaxlaitrig", "C: As B + w/o max LAI triggering grainfill"),
        ("clm6_crop_032_nmlt_phaseparams", "D: As C + pre-CLM5 crop phase params"),
    ),
)

clm_file_h = ".h0."

# First and last year to analyze. These are the actual netCDF timesteps, not the years
# in the file names.
start_year, end_year = 1961, 1965

# CFTs to include: every rainfed CFT first, then its "irrigated_" counterpart in the
# same order.
cfts_to_include = [
    water_prefix + cft
    for water_prefix in ("", "irrigated_")
    for cft in (
        "temperate_corn",
        "tropical_corn",
        "cotton",
        "rice",
        "temperate_soybean",
        "tropical_soybean",
        "sugarcane",
        "spring_wheat",
    )
]

# Crops to include
crops_to_include = ["corn", "cotton", "rice", "soybean", "sugarcane", "wheat"]

# FAOSTAT item name -> crop name used in this notebook
fao_to_clm_dict = dict(
    [
        ("Maize", "corn"),
        ("Rice", "rice"),
        ("Seed cotton, unginned", "cotton"),
        ("Soya beans", "soybean"),
        ("Sugar cane", "sugarcane"),
        ("Wheat", "wheat"),
    ]
)

dev_mode = verbose = True

# Root directory of the observational datasets
obs_data_dir = os.path.join(
    os.sep,
    "glade",
    "campaign",
    "cesm",
    "development",
    "cross-wg",
    "diagnostic_framework",
    "CUPiD_obs_data",
)

### 1.2 Other settings

In [None]:
# Directory for any scratch output: a CUPiD_scratch folder under $SCRATCH when that
# environment variable is set (created if it does not exist yet), otherwise the current
# directory.
try:
    cupid_temp = os.path.join(os.environ["SCRATCH"], "CUPiD_scratch")
except KeyError:
    cupid_temp = "."
else:
    os.makedirs(cupid_temp, exist_ok=True)

N_PFTS = 78

# Whatever follows the last "." in each casename (the whole name if it has no ".")
short_names = [case.rpartition(".")[2] for case in case_name_list]

# The analysis period may not end before it starts
if end_year < start_year:
    raise RuntimeError(f"start_year ({start_year}) > end_year ({end_year})")

# Legend labels are optional. With none given, the casenames double as labels;
# otherwise there must be exactly one label per case.
if not case_legend_list:
    case_legend_list = case_name_list
elif len(case_legend_list) != len(case_name_list):
    raise RuntimeError("case_legend_list must be same length as case_name_list")

In [None]:
# Collect the options in a single dict for easier passing among functions
opts = {
    "CESM_output_dir": CESM_output_dir,
    "case_name_list": case_name_list,
    "case_legend_list": case_legend_list,
    "clm_file_h": clm_file_h,
    "start_year": start_year,
    "end_year": end_year,
    "cfts_to_include": cfts_to_include,
    "crops_to_include": crops_to_include,
    "fao_to_clm_dict": fao_to_clm_dict,
    "dev_mode": dev_mode,
    "verbose": verbose,
    "obs_data_dir": obs_data_dir,
}

# Remove the standalone variables so that, from here on, each option can only be
# reached through opts
del (
    CESM_output_dir,
    case_name_list,
    case_legend_list,
    clm_file_h,
    start_year,
    end_year,
    cfts_to_include,
    crops_to_include,
    fao_to_clm_dict,
    dev_mode,
    verbose,
    obs_data_dir,
)

### 1.3 Import stuff from externals

In [None]:
# Add externals_path to the module search path before importing ctsm_postprocessing
# (NOTE(review): presumably that package lives under externals_path, which is why these
# imports cannot sit with the ones at the top of the notebook -- confirm).
sys.path.append(externals_path)
import ctsm_postprocessing.utils as utils
from ctsm_postprocessing.crops import crop_secondary_variables as c2o
from ctsm_postprocessing.crops import cropcase
import ctsm_postprocessing.crops.faostat as faostat
from ctsm_postprocessing.resolutions import identify_resolution

## 2. Import case data

### 2.1 Import cases

In [None]:
# Reload these modules so that edits made to them since they were first imported are
# picked up without restarting the kernel
importlib.reload(cropcase)
importlib.reload(caselist)

# Build the list of cases from the options in opts. The CropCase class and the
# identify_resolution function are handed over as arguments rather than imported by
# caselist itself.
case_list = caselist.CaseList(
    CropCase=cropcase.CropCase,
    identify_resolution=identify_resolution,
    opts=opts,
)

In [None]:
# Calculate some extra variables: for every case, add crop_cft_area, crop_cft_prod, and
# crop_cft_yield to the case's cft_ds, each with a leading "crop" dimension ordered like
# opts["crops_to_include"].
# TODO: Move these calculations to CropCase
for case in case_list:
    case_ds = case.cft_ds

    # Add crop (names) variable/dimension to case_ds, if needed
    if "crop" not in case_ds:
        case_ds["crop"] = xr.DataArray(
            data=opts["crops_to_include"],
            dims=["crop"],
        )

    # One entry per crop, each carrying a length-1 "crop" dimension. The lists start
    # empty for every case, so a case can never pick up arrays left over from the
    # previous one.
    crop_cft_area_list = []
    crop_cft_prod_list = []
    for crop in opts["crops_to_include"]:
        # Restrict to the CFTs that make up this crop
        cft_ds = case_ds.sel(cft=case.crop_list[crop].pft_nums)

        # Get area
        cft_area = cft_ds["pfts1d_gridcellarea"] * cft_ds["pfts1d_wtgcell"]
        cft_area *= 1e6  # Convert km2 to m2
        cft_area.attrs["units"] = "m2"

        # Get production
        cft_prod = cft_ds["YIELD_ANN"] * cft_area
        cft_prod.attrs["units"] = "g"

        crop_cft_area_list.append(cft_area.expand_dims(dim="crop", axis=0))
        crop_cft_prod_list.append(cft_prod.expand_dims(dim="crop", axis=0))

    # Stack the crops along "crop" in a single step instead of growing the array one
    # crop at a time. The "units" attribute of the first crop carries over to the
    # result. With no crops requested, xr.concat raises a ValueError here.
    crop_cft_area_da = xr.concat(crop_cft_area_list, dim="crop")
    crop_cft_prod_da = xr.concat(crop_cft_prod_list, dim="crop")

    # Add crop_cft_* variables to case_ds
    case_ds["crop_cft_area"] = crop_cft_area_da
    case_ds["crop_cft_prod"] = crop_cft_prod_da

    # Calculate yield (per CFT; production divided by area)
    case_ds["crop_cft_yield"] = crop_cft_prod_da / crop_cft_area_da
    case_ds["crop_cft_yield"].attrs["units"] = (
        crop_cft_prod_da.attrs["units"] + "/" + crop_cft_area_da.attrs["units"]
    )

### 2.2 Import FAOSTAT

In [None]:
# Path, under obs_data_dir, of the normalized FAOSTAT Production_Crops_Livestock CSV
fao_file = os.path.join(
    opts["obs_data_dir"],
    "lnd",
    "analysis_datasets",
    "ungridded",
    "timeseries",
    "FAOSTAT",
    "Production_Crops_Livestock_2025-02-25",
    "norm",
    "Production_Crops_Livestock_E_All_Data_(Normalized).csv",
)

# Read the file, passing the first and last year being analyzed as y1 and yN
fao = faostat.FaostatProductionCropsLivestock(
    fao_file,
    y1=opts["start_year"],
    yN=opts["end_year"],
)

# TODO: Move all the following to FaostatProductionCropsLivestock class

# Extract the "Production" and "Area harvested" elements. fao_to_clm_dict is passed
# along (NOTE(review): presumably so FAOSTAT item names are translated to the crop
# names used in this notebook -- confirm in get_element).
fao_prod = fao.get_element("Production", fao_to_clm_dict=opts["fao_to_clm_dict"])
fao_area = fao.get_element("Area harvested", fao_to_clm_dict=opts["fao_to_clm_dict"])

# Only include where both production and area data are present
def drop_a_where_not_in_b(a, b):
    """Return ``a`` without the rows whose index label does not occur in ``b``.

    Args:
        a: pandas object to filter. It is left unmodified.
        b: pandas object whose index supplies the labels to keep.

    Returns:
        A new object with the rows of ``a`` whose index label is also in ``b.index``,
        in their original order.
    """
    labels_missing_from_b = a.index.difference(b.index)
    return a.drop(list(labels_missing_from_b))


# Restrict each table to the rows the other one also has, then confirm they line up
fao_prod = drop_a_where_not_in_b(fao_prod, fao_area)
fao_area = drop_a_where_not_in_b(fao_area, fao_prod)
if not fao_prod.index.equals(fao_area.index):
    raise RuntimeError("Mismatch of prod and area indices after trying to align them")

# Rows reporting production on zero harvested area are not allowed. The offending rows
# are stored in bad_prod / bad_area (not used again in this cell) before being removed
# from both tables.
is_bad = (fao_area["Value"] == 0) & (fao_prod["Value"] > 0)
where_bad = is_bad.to_numpy().nonzero()[0]
bad_prod = fao_prod.iloc[where_bad]
bad_area = fao_area.iloc[where_bad]
fao_prod = fao_prod.loc[~is_bad]
fao_area = fao_area.loc[~is_bad]
if not fao_prod.index.equals(fao_area.index):
    raise RuntimeError(
        "Mismatch of prod and area indices after disallowing production where no area"
    )

# Yield table: same rows as production, with the value replaced by production / area
# and the unit by "<production unit>/<area unit>" taken from the first row of each table
fao_yield = fao_prod.assign(
    Element="Yield",
    Unit=fao_prod["Unit"].iloc[0] + "/" + fao_area["Unit"].iloc[0],
    Value=fao_prod["Value"] / fao_area["Value"],
)

### 2.3 Import EarthStat (basically gridded FAOSTAT)

In [None]:
# Reload so that edits made to the earthstat module since it was first imported are
# picked up without restarting the kernel
importlib.reload(earthstat)

# Directory, under obs_data_dir, holding the EarthStat files
earthstat_dir = os.path.join(
    opts["obs_data_dir"],
    "lnd",
    "analysis_datasets",
    "multi_grid",
    "annual",
    "FAO-EarthStatYields",
)

# Load EarthStat, passing the resolutions of the imported cases. The result is indexed
# by resolution name further down (e.g. earthstat_data["f09"]).
earthstat_data = earthstat.EarthStat(earthstat_dir, case_list.resolutions, opts)

## 3. Yield time series

In [None]:
# Plot, for every crop, the global yield time series of each case together with FAOSTAT
# and EarthStat, one subplot per crop.
# TODO: Move this to its own module. Waiting for resolution of TODO in "Import
# FAOSTAT" cell so I don't have to pass so much stuff.

# Time the cell when running verbosely
if opts["verbose"]:
    start = time()

# Get figure layout info. Only a 2x3 grid, for 5 or 6 crops, is defined so far.
n_crops_to_include = len(opts["crops_to_include"])
if 5 <= n_crops_to_include <= 6:
    nrows = 2
    ncols = 3
    height = 10.5
    width = 15
    hspace = 0.25
    wspace = 0.35
else:
    raise RuntimeError(f"Specify figure layout for Ncrops=={n_crops_to_include}")
fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(width, height))

# FAOSTAT yield rows whose Area is "World"
fao_yield_world = fao_yield.query("Area == 'World'")

# Resolution of the EarthStat dataset to plot (hard-coded)
earthstat_res_to_plot = "f09"
earthstat_ds_to_plot = earthstat_data[earthstat_res_to_plot]

for i, crop in enumerate(opts["crops_to_include"]):
    # Make this crop's subplot the current axes, so the pyplot and xarray calls below
    # draw into it
    ax = axes.ravel()[i]
    plt.sca(ax)

    # Plot case data
    for c, case in enumerate(case_list):

        # Do NOT use crop_cft_yield here, because you need to sum across cft and pft before doing the division
        crop_prod_ts = (
            case.cft_ds["crop_cft_prod"].sel(crop=crop).sum(dim=["cft", "pft"])
        )
        crop_area_ts = (
            case.cft_ds["crop_cft_area"].sel(crop=crop).sum(dim=["cft", "pft"])
        )
        crop_yield_ts = crop_prod_ts / crop_area_ts

        # Plot data
        # TODO: Increase robustness of unit conversion: Check that it really is
        # g/m2 to start with.
        crop_yield_ts *= 1e-6 * 1e4  # Convert g/m2 to tons/ha
        crop_yield_ts.name = "Yield"
        # Assigned again for every case; the FAOSTAT unit check after this loop reads it
        ctsm_units = "t/ha"
        crop_yield_ts.attrs["units"] = ctsm_units

        # Change line style for one line that overlaps another for some crops
        # TODO: Optionally define linestyle for each case in config.yml
        if "clm6_crop_032_nomaxlaitrig" in opts["case_name_list"] and opts[
            "case_name_list"
        ][c].endswith("clm6_crop_032_nmlt_phaseparams"):
            linestyle = "--"
        else:
            linestyle = "-"

        # Plot
        crop_yield_ts.plot(linestyle=linestyle)

    # Plot FAOSTAT data (black line), after checking its units against the CTSM series
    faostat_units = fao_yield_world["Unit"].iloc[0]
    if faostat_units != ctsm_units:
        raise RuntimeError(
            f"CTSM units ({ctsm_units}) do not match FAOSTAT units ({faostat_units})"
        )
    fao_yield_world_thiscrop = fao_yield_world.query(f"Crop == '{crop}'")
    # NOTE(review): the x values are the time coordinate of the last case plotted above,
    # so this assumes FAOSTAT has exactly one value per case timestep -- confirm.
    ax.plot(
        crop_yield_ts.time,
        fao_yield_world_thiscrop["Value"].values,
        "-k",
    )

    # Plot EarthStat data, unless this crop is missing from earthstat_data.crops
    # NOTE(review): the index is looked up in earthstat_data.crops but applied to the
    # "crop" dimension of earthstat_ds_to_plot -- confirm both use the same ordering.
    earthstat_crop_idx = None
    try:
        earthstat_crop_idx = earthstat_data.crops.index(crop)
    except ValueError:
        print(f"{crop} not in EarthStat res {earthstat_res_to_plot}; skipping")
    if earthstat_crop_idx is not None:
        earthstat_crop_ds = earthstat_ds_to_plot.isel(crop=earthstat_crop_idx)
        earthstat_area_da_tyx = earthstat_crop_ds["HarvestArea"]
        earthstat_prod_da_tyx = earthstat_crop_ds["Production"]
        # Sum production and area over the grid before dividing, as for the cases
        earthstat_prod_da_t = earthstat_prod_da_tyx.sum(dim=["lat", "lon"])
        earthstat_area_da_t = earthstat_area_da_tyx.sum(dim=["lat", "lon"])
        # NOTE(review): unlike the CTSM and FAOSTAT series, this yield gets no unit
        # conversion or check -- confirm Production / HarvestArea is already in t/ha.
        earthstat_yield_da_t = earthstat_prod_da_t / earthstat_area_da_t
        ax.plot(
            earthstat_yield_da_t["time"],
            earthstat_yield_da_t.values,
            "0.5",  # gray
        )

    # Finish plot
    ax.set_title(crop)
    plt.xlabel("")

plt.subplots_adjust(wspace=wspace, hspace=hspace)
# One legend for the whole figure. Only labels are given, so they are matched to the
# plotted lines by drawing order: the cases, then FAOSTAT, then EarthStat.
# NOTE(review): confirm the pairing is still right when EarthStat is skipped for a crop.
fig.legend(
    labels=opts["case_legend_list"] + ["FAOSTAT", f"EarthStat {earthstat_res_to_plot}"],
    loc="upper center",
    bbox_to_anchor=(0.5, 0.96),
    ncol=3,
    bbox_transform=fig.transFigure,
)
fig.suptitle("Global yield", fontsize="x-large", fontweight="bold")
plt.show()

# Report the elapsed time in whole seconds
if opts["verbose"]:
    end = time()
    print(f"Time series plots took {int(end - start)} s")

## 4. Yield maps

In [None]:
# Reload so that edits made to the plotting module since it was first imported are
# picked up without restarting the kernel
importlib.reload(caem)

# CLM and EarthStat maps for the "yield" option. The utils module is passed in as an
# argument rather than imported by the plotting module itself.
caem.clm_and_earthstat_maps(
    which="yield",
    case_list=case_list,
    earthstat_data=earthstat_data,
    utils=utils,
    opts=opts,
)

In [None]:
# Reload so that edits made to the plotting module since it was first imported are
# picked up without restarting the kernel
importlib.reload(caem)

# CLM and EarthStat maps for the "prod" option. The utils module is passed in as an
# argument rather than imported by the plotting module itself.
caem.clm_and_earthstat_maps(
    which="prod",
    case_list=case_list,
    earthstat_data=earthstat_data,
    utils=utils,
    opts=opts,
)

In [None]:
# Reload so that edits made to the plotting module since it was first imported are
# picked up without restarting the kernel
importlib.reload(caem)

# CLM and EarthStat maps for the "area" option. The utils module is passed in as an
# argument rather than imported by the plotting module itself.
caem.clm_and_earthstat_maps(
    which="area",
    case_list=case_list,
    earthstat_data=earthstat_data,
    utils=utils,
    opts=opts,
)