# Compare crop yields to observations

- Uses raw annual CTSM outputs (NOT timeseries files).

Notebook created by Sam Rabin (samrabin@ucar.edu).

In [None]:
import sys
import os
import glob
import numpy as np
import xarray as xr
import pandas as pd
from time import time
from dask.distributed import Client, wait
import convert_pft1d_to_sparse
import importlib
import earthstat
import caselist
import crop_timeseries_figs
import plotting_utils

# Plotting utils
import matplotlib.pyplot as plt
import clm_and_earthstat_maps as caem

# Start a local Dask cluster. Client() is called with no arguments, so the number of
# workers/threads is whatever dask.distributed chooses by default (intended here to be
# all available cores).
client = Client()
# Bare expression: as the last line of the cell, it makes the notebook display the
# client summary.
client

## 1. Settings

### 1.1 Parameters modifiable in config.yml

In [None]:
# Default location of the CUPiD externals directory: two levels above the current working
# directory. Depending on the working directory is supposedly unreliable, so it's best to
# override this with a value given in config.yml. See examples/crops/config.yml.
externals_path = os.path.join(
    os.getcwd(),
    os.pardir,
    os.pardir,
    "externals",
)

# Directory in which the land model output is stored
CESM_output_dir = os.path.join(
    os.sep, "glade", "work", "samrabin", "clm6_crop_reparam_outputs"
)

# Full casenames that are present in CESM_output_dir and in individual filenames.
# Entries that are commented out here must also be commented out in case_legend_list
# below, because the two lists are paired by position.
# NOTE: The time series cells in section 5 look for "clm6_crop_032_nomaxlaitrig" and
# "clm6_crop_032_nmlt_phaseparams" in this list to decide which line to draw dashed.
case_name_list = [
    "ctsm53019_f09_BNF_hist",
    "clm6_crop_032",
    "clm6_crop_032_nomaxlaitrig",
    "clm6_crop_032_nmlt_phaseparams",
    # "alpha-ctsm5.4.CMIP7.09.ctsm5.3.068",
    # "alpha-ctsm5.4.CMIP7.09.ctsm5.3.068_nogddadapt",
    # "crujra_matreqs",
    # "crujra_matreqs_nogddadapt",
    "clm6_crop_032_nmlt_arooti",
]

# Names of cases to show in figure legends, in the same order as case_name_list.
# Section 1.2 raises an error if the two lists differ in length. If this list is empty,
# the case names themselves are used as legend labels.
case_legend_list = [
    "A: ctsm5.3.019 (GSWP3)",
    "B: ctsm5.3.032 (CRU-JRA)",
    "C: As B + w/o max LAI triggering grainfill",
    "D: As C + pre-CLM5 crop phase params",
    # "E: 5.4 branch",
    # "F: 5.4 branch (no GDD adapt)",
    # "G: CRU-JRA mat reqs",
    # "H: CRU-JRA mat reqs (no GDD adapt)",
    "I: As C + pre-CLM5 arooti",
]


# # Where land output is stored
# CESM_output_dir = os.path.join(
#     os.path.sep,
#     "glade",
#     "work",
#     "samrabin",
#     "wwieder_run_outputs",
# )
# # Full casenames that are present in CESM_output_dir and in individual filenames
# case_name_list = [
#     # "ctsm53n04ctsm52028_f09_hist",  # Doesn't have GRAINC_TO_FOOD_PERHARV
#     "ctsm53041_54surfdata_snowTherm_100_HIST",
#     "ctsm5.4_5.3.068_PPEcal115f09_118_HIST",
# ]
# # Names of cases to show in figure legends
# case_legend_list = [
#     # "ctsm53n04ctsm52028_f09_hist",  # Doesn't have GRAINC_TO_FOOD_PERHARV
#     "Run 100",
#     "Run 118",
# ]


# First and last years to analyze. These are the time steps stored inside the netCDF
# files, not the years that appear in the file names.
start_year, end_year = 1961, 2024

# Crop functional types to include: each rainfed type first, followed by the irrigated
# version of each one in the same order
cfts_to_include = [
    water_prefix + rainfed_cft
    for water_prefix in ("", "irrigated_")
    for rainfed_cft in (
        "temperate_corn",
        "tropical_corn",
        "cotton",
        "rice",
        "temperate_soybean",
        "tropical_soybean",
        "sugarcane",
        "spring_wheat",
    )
]

# Crops to include in the figures
crops_to_include = ["corn", "cotton", "rice", "soybean", "sugarcane", "wheat"]

# FAOSTAT item name -> crop name used in this notebook
fao_to_clm_dict = {
    "Maize": "corn",
    "Rice": "rice",
    "Seed cotton, unginned": "cotton",
    "Soya beans": "soybean",
    "Sugar cane": "sugarcane",
    "Wheat": "wheat",
}

verbose = True

# Root directory of the CUPiD observational datasets. The filesystem root is passed as
# its own component, the same way CESM_output_dir is built.
obs_data_dir = os.path.join(
    os.path.sep,
    "glade",
    "campaign",
    "cesm",
    "development",
    "cross-wg",
    "diagnostic_framework",
    "CUPiD_obs_data",
)

### 1.2 Other settings

In [None]:
# Scratch output goes to $SCRATCH/CUPiD_scratch (created if missing) when the SCRATCH
# environment variable is defined, and to the current directory otherwise
try:
    cupid_temp = os.path.join(os.environ["SCRATCH"], "CUPiD_scratch")
except KeyError:
    cupid_temp = "."
else:
    os.makedirs(cupid_temp, exist_ok=True)

N_PFTS = 78

# For each case, the part of its name that follows the last "."
short_names = [name.rsplit(".", 1)[-1] for name in case_name_list]

# The analysis period must not end before it starts
if end_year < start_year:
    raise RuntimeError(f"start_year ({start_year}) > end_year ({end_year})")

# Legend labels fall back to the case names when none are given; otherwise there must be
# exactly one label per case
if not case_legend_list:
    case_legend_list = case_name_list
elif len(case_legend_list) != len(case_name_list):
    raise RuntimeError("case_legend_list must be same length as case_name_list")

In [None]:
# Gather the options into a single dict for easier passing among functions
opts = {
    "CESM_output_dir": CESM_output_dir,
    "case_name_list": case_name_list,
    "case_legend_list": case_legend_list,
    "start_year": start_year,
    "end_year": end_year,
    "cfts_to_include": cfts_to_include,
    "crops_to_include": crops_to_include,
    "fao_to_clm_dict": fao_to_clm_dict,
    "verbose": verbose,
    "obs_data_dir": obs_data_dir,
}

# Remove the standalone variables so that opts is the only place these settings live
del CESM_output_dir, case_name_list, case_legend_list, start_year, end_year
del cfts_to_include, crops_to_include, fao_to_clm_dict, verbose, obs_data_dir

### 1.3 Import stuff from externals

In [None]:
# Add the CUPiD externals directory to the module search path. The imports below are
# placed here, not with the other imports at the top of the notebook, because they come
# after this path is added (presumably ctsm_postprocessing lives under externals_path;
# confirm against your checkout).
sys.path.append(externals_path)
import ctsm_postprocessing.utils as utils
from ctsm_postprocessing.crops import crop_secondary_variables as c2o
from ctsm_postprocessing.crops import cropcase
import ctsm_postprocessing.crops.faostat as faostat
from ctsm_postprocessing.resolutions import identify_resolution

## 2. Import case data

### 2.1 Import cases

In [None]:
# Reload these modules so that edits made to them since they were first imported take
# effect without restarting the kernel
importlib.reload(cropcase)
importlib.reload(caselist)

# Build the collection of cases described by opts. The CropCase class and the
# identify_resolution function are handed to CaseList as arguments.
case_list = caselist.CaseList(
    CropCase=cropcase.CropCase,
    identify_resolution=identify_resolution,
    opts=opts,
)

### 2.2 Import FAOSTAT

In [None]:
# Normalized FAOSTAT "Production_Crops_Livestock" table within the observational data
# directory
fao_file = os.path.join(
    opts["obs_data_dir"],
    "lnd",
    "analysis_datasets",
    "ungridded",
    "timeseries",
    "FAOSTAT",
    "Production_Crops_Livestock_2025-02-25",
    "norm",
    "Production_Crops_Livestock_E_All_Data_(Normalized).csv",
)

# Read it, restricted to the years being analyzed
fao = faostat.FaostatProductionCropsLivestock(
    fao_file, y1=opts["start_year"], yN=opts["end_year"]
)

# TODO: Move all the following to FaostatProductionCropsLivestock class

# Extract production first, then harvested area, using the same FAO-to-CLM crop mapping
fao_prod, fao_area = (
    fao.get_element(element_name, fao_to_clm_dict=opts["fao_to_clm_dict"])
    for element_name in ("Production", "Area harvested")
)

# Only include where both production and area data are present
def drop_a_where_not_in_b(a, b):
    """Return `a` without the rows whose index labels do not appear in `b`'s index."""
    labels_only_in_a = a.index.difference(b.index)
    return a.drop(list(labels_only_in_a))


# Align the two tables in both directions, then confirm that their indices are identical
fao_prod = drop_a_where_not_in_b(fao_prod, fao_area)
fao_area = drop_a_where_not_in_b(fao_area, fao_prod)
if not fao_area.index.equals(fao_prod.index):
    raise RuntimeError("Mismatch of prod and area indices after trying to align them")

# Don't allow production where no area. The offending rows are kept in bad_prod and
# bad_area for inspection before being removed from both tables.
is_bad = fao_prod["Value"].gt(0) & fao_area["Value"].eq(0)
where_bad = np.flatnonzero(is_bad)
bad_prod, bad_area = fao_prod.iloc[where_bad], fao_area.iloc[where_bad]
fao_prod, fao_area = fao_prod[~is_bad], fao_area[~is_bad]
if not fao_area.index.equals(fao_prod.index):
    raise RuntimeError(
        "Mismatch of prod and area indices after disallowing production where no area"
    )

# Get yield: a copy of the production table with Value replaced by production divided by
# area, and Unit built from the first row's production and area units
fao_yield = fao_prod.assign(
    Element="Yield",
    Unit="/".join([fao_prod["Unit"].iloc[0], fao_area["Unit"].iloc[0]]),
    Value=fao_prod["Value"] / fao_area["Value"],
)

### 2.3 Import EarthStat (basically gridded FAOSTAT)

In [None]:
# Reload so that edits made to earthstat since it was first imported take effect
importlib.reload(earthstat)

# Directory of the EarthStat yield data within the observational data directory
earthstat_dir = os.path.join(
    opts["obs_data_dir"],
    "lnd",
    "analysis_datasets",
    "multi_grid",
    "annual",
    "FAO-EarthStatYields",
)

# Load EarthStat, passing the resolutions of the imported cases. The result is indexed by
# resolution name later in the notebook (e.g., earthstat_data["f09"]).
earthstat_data = earthstat.EarthStat(earthstat_dir, case_list.resolutions, opts)

In [None]:
# Reload so that edits made to these modules since they were first imported take effect
importlib.reload(crop_timeseries_figs)
importlib.reload(earthstat)

# Get versions of CLM stats as if planted with EarthStat area. For every case, this adds
# crop_area_es, crop_prod_es, and earthstat_time to the case's dataset.
# TODO: Don't hard-code EARTHSTAT_RES_TO_PLOT here
EARTHSTAT_RES_TO_PLOT = "f09"
for case in case_list:
    # Variables assigned to case_ds below are meant to end up on the case itself
    # (presumably cft_ds returns the stored dataset rather than a copy; confirm)
    case_ds = case.cft_ds

    for i, crop in enumerate(opts["crops_to_include"]):
        # Get EarthStat area for this crop, ungridded to match the case's dataset
        crop_area_es = utils.ungrid(
            gridded_data=earthstat_data[EARTHSTAT_RES_TO_PLOT].get_data("area", crop),
            ungridded_ds=case_ds,
        )

        # Set up the crop_area_es_expanded variable (first crop) or append to it
        if i == 0:
            # Add a leading "crop" dimension for later crops to be concatenated along
            crop_area_es_expanded = crop_area_es.expand_dims(dim="crop", axis=0)
        else:
            # Append this crop's DataArray to existing one
            crop_area_es_expanded = xr.concat(
                [crop_area_es_expanded, crop_area_es],
                dim="crop",
            )

    # Convert area units. Note that crop_area_es still refers to the last crop processed in
    # the loop above, so its units are taken as representative of all crops.
    clm_units = case_ds["crop_area"].attrs["units"]
    es_units = crop_area_es.attrs["units"]
    if clm_units == "m2" and es_units == "Mha":
        # Mha -> m2: 1e4 m2 per ha, 1e6 ha per Mha
        crop_area_es_expanded *= 1e4 * 1e6
        crop_area_es_expanded.attrs["units"] = "m2"
    else:
        raise NotImplementedError(
            f"Conversion assumes CLM area in m2 (got {clm_units}) and EarthStat area in Mha (got {es_units})"
        )

    # Before saving, check alignment of all dims
    crop_area_es_expanded = earthstat.check_dim_alignment(
        crop_area_es_expanded, case_ds
    )

    # Save to case_ds, filling with NaN as necessary (e.g., if there are CLM years not in EarthStat).
    case_ds["crop_area_es"] = crop_area_es_expanded

    # Calculate production as if planted with EarthStat area. The units are checked first
    # because the product is labeled "g" below, which is only right for m2 * g/m2.
    area_units = case_ds["crop_area_es"].attrs["units"]
    area_units_exp = "m2"
    yield_units = case_ds["crop_yield"].attrs["units"]
    yield_units_exp = "g/m2"
    if area_units != area_units_exp or yield_units != yield_units_exp:
        raise NotImplementedError(
            f"Yield calculation assumes area in {area_units_exp} (got {area_units}) and yield in {yield_units_exp} (got {yield_units})"
        )
    # NOTE(review): crop_yield's "pft" dimension is renamed to "gridcell" before multiplying,
    # presumably so that it matches the dimension name used by crop_area_es; confirm.
    case_ds["crop_prod_es"] = case_ds["crop_area_es"] * case_ds["crop_yield"].rename(
        {"pft": "gridcell"}
    )
    case_ds["crop_prod_es"].attrs["units"] = "g"

    # Save EarthStat time axis to avoid plotting years with no EarthStat data. Its dimension
    # is renamed so that it is stored separately from the dataset's own "time" dimension.
    earthstat_time = crop_area_es_expanded["time"]
    earthstat_time = earthstat_time.rename({"time": "earthstat_time_coord"})
    case_ds["earthstat_time"] = earthstat_time

## 3. Time series figures

### 3.1 With CLM areas (including all merged CFTs)

In [None]:
# Reload so that edits made to crop_timeseries_figs since it was imported take effect
importlib.reload(crop_timeseries_figs)

# Each call handles one statistic and is given the FAOSTAT table for that statistic
crop_timeseries_figs.main("yield", earthstat_data, case_list, fao_yield, opts)
crop_timeseries_figs.main("prod", earthstat_data, case_list, fao_prod, opts)
crop_timeseries_figs.main("area", earthstat_data, case_list, fao_area, opts)

### 3.2 With EarthStat areas

In [None]:
# Reload so that edits made to crop_timeseries_figs since it was imported take effect
importlib.reload(crop_timeseries_figs)

# Same three statistics as in the previous section, but with use_earthstat_area=True
crop_timeseries_figs.main(
    "yield", earthstat_data, case_list, fao_yield, opts, use_earthstat_area=True
)

crop_timeseries_figs.main(
    "prod", earthstat_data, case_list, fao_prod, opts, use_earthstat_area=True
)

# Even when using EarthStat areas, you will not necessarily see perfect alignment between
# all the CLM runs with each other and with EarthStat. This can happen due to different
# land masks.
crop_timeseries_figs.main(
    "area", earthstat_data, case_list, fao_area, opts, use_earthstat_area=True
)

## 4. Yield maps

In [None]:
# Reload so that edits made to clm_and_earthstat_maps since it was imported take effect
importlib.reload(caem)

# Maps for yield. The ctsm_postprocessing utils module is passed in as an argument.
caem.clm_and_earthstat_maps(
    which="yield",
    case_list=case_list,
    earthstat_data=earthstat_data,
    utils=utils,
    opts=opts,
)

In [None]:
# Reload so that edits made to clm_and_earthstat_maps since it was imported take effect
importlib.reload(caem)

# Maps for production. The ctsm_postprocessing utils module is passed in as an argument.
caem.clm_and_earthstat_maps(
    which="prod",
    case_list=case_list,
    earthstat_data=earthstat_data,
    utils=utils,
    opts=opts,
)

In [None]:
# Reload so that edits made to clm_and_earthstat_maps since it was imported take effect
importlib.reload(caem)

# Maps for area. The ctsm_postprocessing utils module is passed in as an argument.
caem.clm_and_earthstat_maps(
    which="area",
    case_list=case_list,
    earthstat_data=earthstat_data,
    utils=utils,
    opts=opts,
)

## 5. Immature and failed harvests

In [None]:
# Reload so that edits made to plotting_utils since it was imported take effect
importlib.reload(plotting_utils)

# One set of maps per crop, with one map per case
for crop in opts["crops_to_include"]:
    results = plotting_utils.ResultsMaps(case_list.mapfig_layout, vrange=[0, 1])
    for case in case_list:
        crop_ds = case.cft_ds.sel(crop=crop)

        # Immature-harvest area as a fraction of all harvested area, with both summed over
        # time before dividing
        area_immature = crop_ds["crop_harv_area_immature"].sum(dim="time")
        area_harvested = crop_ds["crop_harv_area"].sum(dim="time")
        crop_ds["frac_immature_harv_timemean"] = area_immature / area_harvested

        # Put the result on a grid and label it for plotting
        map_clm = utils.grid_one_variable(crop_ds, "frac_immature_harv_timemean")
        map_clm.attrs["units"] = "unitless"
        map_clm.name = "Fraction immature harvests"
        results[case.name] = map_clm
    results.plot(case_name_list=case_list.names, crop=crop)

In [None]:
# Reload so that edits made to plotting_utils since it was imported take effect
importlib.reload(plotting_utils)

# One set of maps per crop, with one map per case
for crop in opts["crops_to_include"]:
    results = plotting_utils.ResultsMaps(case_list.mapfig_layout, vrange=[0, 1])
    for case in case_list:
        crop_ds = case.cft_ds.sel(crop=crop)

        # Failed-harvest area as a fraction of all harvested area, with both summed over
        # time before dividing
        area_failed = crop_ds["crop_harv_area_failed"].sum(dim="time")
        area_harvested = crop_ds["crop_harv_area"].sum(dim="time")
        crop_ds["frac_failed_harv_timemean"] = area_failed / area_harvested

        # Put the result on a grid and label it for plotting
        map_clm = utils.grid_one_variable(crop_ds, "frac_failed_harv_timemean")
        map_clm.attrs["units"] = "unitless"
        map_clm.name = "Fraction failed harvests"
        results[case.name] = map_clm
    results.plot(case_name_list=case_list.names, crop=crop)

In [None]:
# Reload so that edits made to crop_timeseries_figs since it was imported take effect
importlib.reload(crop_timeseries_figs)

# Get figure layout info
fig_opts, fig, axes = crop_timeseries_figs.setup_fig(opts)

# The title is the same for every crop and case, so set it once rather than inside the
# plotting loops
fig_opts["title"] = "Fraction immature crop area"

# One case's line overlaps another's for some crops. It is drawn dashed, but only if the
# case it overlaps is among those being plotted. Whether that case is present does not
# change during plotting, so check once here.
# TODO: Optionally define linestyle for each case in config.yml
case_names = opts["case_name_list"]
overlapped_case_present = "clm6_crop_032_nomaxlaitrig" in case_names

for i, crop in enumerate(opts["crops_to_include"]):
    # Each crop gets its own subplot, taken in order from the flattened axes
    ax = axes.ravel()[i]
    plt.sca(ax)

    # Plot case data
    for c, case in enumerate(case_list):
        # Select this crop once and reuse the selection for numerator and denominator
        crop_ds = case.cft_ds.sel(crop=crop)

        # Immature-harvest area as a fraction of all harvested area, each summed over the
        # pft dimension
        area_immature = crop_ds["crop_harv_area_immature"].sum(dim=["pft"])
        area_harvested = crop_ds["crop_harv_area"].sum(dim=["pft"])
        crop_data_ts = area_immature / area_harvested

        # Change line style for one line that overlaps another for some crops. This looks
        # the case up by position, so it relies on case_list being in the same order as
        # opts["case_name_list"].
        if overlapped_case_present and case_names[c].endswith(
            "clm6_crop_032_nmlt_phaseparams"
        ):
            linestyle = "--"
        else:
            linestyle = "-"

        # Plot onto the current axes
        crop_data_ts.plot(linestyle=linestyle)

    # Finish plot
    ax.set_title(crop)
    plt.xlabel("")

crop_timeseries_figs.finish_fig(opts, fig_opts, fig, incl_obs=False)

In [None]:
# Reload so that edits made to crop_timeseries_figs since it was imported take effect
importlib.reload(crop_timeseries_figs)

# Get figure layout info
fig_opts, fig, axes = crop_timeseries_figs.setup_fig(opts)

# The title is the same for every crop and case, so set it once rather than inside the
# plotting loops
fig_opts["title"] = "Fraction failed crop area"

# One case's line overlaps another's for some crops. It is drawn dashed, but only if the
# case it overlaps is among those being plotted. Whether that case is present does not
# change during plotting, so check once here.
# TODO: Optionally define linestyle for each case in config.yml
case_names = opts["case_name_list"]
overlapped_case_present = "clm6_crop_032_nomaxlaitrig" in case_names

for i, crop in enumerate(opts["crops_to_include"]):
    # Each crop gets its own subplot, taken in order from the flattened axes
    ax = axes.ravel()[i]
    plt.sca(ax)

    # Plot case data
    for c, case in enumerate(case_list):
        # Select this crop once and reuse the selection for numerator and denominator
        crop_ds = case.cft_ds.sel(crop=crop)

        # Failed-harvest area as a fraction of all harvested area, each summed over the
        # pft dimension
        area_failed = crop_ds["crop_harv_area_failed"].sum(dim=["pft"])
        area_harvested = crop_ds["crop_harv_area"].sum(dim=["pft"])
        crop_data_ts = area_failed / area_harvested

        # Change line style for one line that overlaps another for some crops. This looks
        # the case up by position, so it relies on case_list being in the same order as
        # opts["case_name_list"].
        if overlapped_case_present and case_names[c].endswith(
            "clm6_crop_032_nmlt_phaseparams"
        ):
            linestyle = "--"
        else:
            linestyle = "-"

        # Plot onto the current axes
        crop_data_ts.plot(linestyle=linestyle)

    # Finish plot
    ax.set_title(crop)
    plt.xlabel("")

crop_timeseries_figs.finish_fig(opts, fig_opts, fig, incl_obs=False)