# Compare crop yields to observations

- Uses raw annual CTSM outputs (NOT timeseries files).

Notebook created by Sam Rabin (samrabin@ucar.edu).

In [None]:
import sys
import os
import glob
import numpy as np
import xarray as xr
import pandas as pd
from time import time
from dask.distributed import Client, wait
import convert_pft1d_to_sparse
import importlib

# Plotting utils
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature

# Start a local Dask cluster using all available cores
client = Client()
# Leaving the client as the cell's last expression displays its summary in the notebook
client

## 1. Settings

### 1.1 Parameters modifiable in config.yml

In [None]:
# Path to CUPiD externals. Deriving it from the current working directory is unreliable
# (the result depends on where the notebook is run from), so it's best for this to be
# overridden by a value given in config.yml. See examples/crops/config.yml.
externals_path = os.path.join(os.getcwd(), os.pardir, os.pardir, "externals")

# Where land output is stored. Each case's history files are read from
# <CESM_output_dir>/<case>/lnd/hist/.
CESM_output_dir = os.path.join(
    os.path.sep,
    "glade",
    "work",
    "samrabin",
    "clm6_crop_reparam_outputs",
)

# Full casenames that are present in CESM_output_dir and in individual filenames
case_name_list = [
    "ctsm53019_f09_BNF_hist",
    "clm6_crop_032",
    "clm6_crop_032_nomaxlaitrig",
    "clm6_crop_032_nmlt_phaseparams",
]

# Names of cases to show in figure legends. Must be the same length as case_name_list;
# if empty, case_name_list is used instead.
case_legend_list = [
    "A: ctsm5.3.019 (GSWP3)",
    "B: ctsm5.3.032 (CRU-JRA)",
    "C: As B + w/o max LAI triggering grainfill",
    "D: As C + pre-CLM5 crop phase params",
]

# Substring identifying which history-file stream to read
clm_file_h = ".h0."

# First and last years to analyze. These are the years of the actual netCDF
# timesteps, not the years that appear in the file names.
start_year = 1961
end_year = 2023

# Crop functional types (CFTs) to include
cfts_to_include = [
    "temperate_corn",
    "tropical_corn",
    "cotton",
    "rice",
    "temperate_soybean",
    "tropical_soybean",
    "sugarcane",
    "spring_wheat",
    "irrigated_temperate_corn",
    "irrigated_tropical_corn",
    "irrigated_cotton",
    "irrigated_rice",
    "irrigated_temperate_soybean",
    "irrigated_tropical_soybean",
    "irrigated_sugarcane",
    "irrigated_spring_wheat",
]

# Crops to analyze (presumably each aggregates the matching CFTs above). These names
# match the values of fao_to_clm_dict.
crops_to_include = [
    "corn",
    "cotton",
    "rice",
    "soybean",
    "sugarcane",
    "wheat",
]
# Maps FAOSTAT item names to the crop names used in crops_to_include
fao_to_clm_dict = {
    "Maize": "corn",
    "Rice": "rice",
    "Seed cotton, unginned": "cotton",
    "Soya beans": "soybean",
    "Sugar cane": "sugarcane",
    "Wheat": "wheat",
}

# dev_mode: load each case's data into memory right after importing it
# verbose: print progress and timing information
dev_mode = True
verbose = True

# Root directory of CUPiD observational datasets
obs_data_dir = os.path.join(
    os.sep + "glade",
    "campaign",
    "cesm",
    "development",
    "cross-wg",
    "diagnostic_framework",
    "CUPiD_obs_data",
)

### 1.2 Other settings

In [None]:
# Scratch directory for any intermediate output: $SCRATCH/CUPiD_scratch (created if
# needed) when $SCRATCH is defined, otherwise the current working directory.
if "SCRATCH" not in os.environ:
    cupid_temp = "."
else:
    cupid_temp = os.path.join(os.environ["SCRATCH"], "CUPiD_scratch")
    os.makedirs(cupid_temp, exist_ok=True)

# NOTE(review): not referenced by the cells visible here; confirm it is still needed
N_PFTS = 78

# Last dot-separated component of each case name
short_names = [name.rsplit(".", 1)[-1] for name in case_name_list]

# Sanity-check the analysis period
if start_year > end_year:
    raise RuntimeError(f"start_year ({start_year}) > end_year ({end_year})")

# Fall back to the raw case names when no legend labels were given; otherwise there
# must be exactly one label per case.
if not case_legend_list:
    case_legend_list = case_name_list
elif len(case_legend_list) != len(case_name_list):
    raise RuntimeError("case_legend_list must be same length as case_name_list")

### 1.3 Import stuff from externals

In [None]:
# Make the CUPiD externals importable; this must run before the imports below
sys.path.append(externals_path)
import ctsm_postprocessing.utils as utils
from ctsm_postprocessing.crops import crop_secondary_variables as c2o
from ctsm_postprocessing.crops import cropcase
import ctsm_postprocessing.crops.faostat as faostat
from ctsm_postprocessing.resolutions import identify_resolution

## 2. Import case data

### 2.1 Import cases

In [None]:
# Reload so that edits to the externals' cropcase module are picked up without restarting
importlib.reload(cropcase)

start = time()
case_list = []
for case in case_name_list:
    print(f"Importing {case}...")
    case_output_dir = os.path.join(
        CESM_output_dir,
        case,
        "lnd",
        "hist",
    )
    case_list.append(
        cropcase.CropCase(
            case,
            case_output_dir,
            clm_file_h,
            cfts_to_include,
            crops_to_include,
            start_year,
            end_year,
            verbose=verbose,
        )
    )

    if dev_mode:
        start_load = time()
        print("Loading...")
        case_list[-1].cft_ds.load()
        end_load = time()
        print(f"Loading took {int(end_load - start_load)} s")

    # Get gridcell area: one vectorized (pointwise) selection of the gridded "area"
    # field at every gridcell's (lat, lon), rather than one .sel() call per gridcell.
    # Selection is still by exact label, so a missing coordinate raises KeyError.
    ds = case_list[-1].cft_ds
    area_g = ds["area"].sel(
        lat=xr.DataArray(ds["grid1d_lat"].values, dims="gridcell"),
        lon=xr.DataArray(ds["grid1d_lon"].values, dims="gridcell"),
    )
    # Put the gridcell dimension first so it is the axis indexed below
    area_g = area_g.transpose("gridcell", ...).values

    # Map each PFT to the area of its gridcell. pfts1d_gi is used as a 1-based gridcell
    # index (hence the -1), taken from the first CFT.
    pfts1d_gi = ds["pfts1d_gi"].isel(cft=0).values
    if np.any(np.isnan(pfts1d_gi)):
        raise RuntimeError("pfts1d_gi contains NaN; cannot map PFTs to gridcells")
    area_p = area_g[pfts1d_gi.astype(int) - 1]
    ds["pfts1d_gridcellarea"] = xr.DataArray(
        data=area_p,
        coords={"pft": ds["pft"].values},
        dims=["pft"],
    )

    # Get resolution
    ds.attrs["resolution"] = identify_resolution(ds)

print("Done.")
if verbose:
    end = time()
    print(f"Importing took {int(end - start)} s")

### 2.2 Import FAOSTAT

In [None]:
fao_file = os.path.join(
    obs_data_dir,
    "lnd",
    "analysis_datasets",
    "ungridded",
    "timeseries",
    "FAOSTAT",
    "Production_Crops_Livestock_2025-02-25",
    "norm",
    "Production_Crops_Livestock_E_All_Data_(Normalized).csv",
)

fao = faostat.FaostatProductionCropsLivestock(
    fao_file,
    y1=start_year,
    yN=end_year,
)

fao_prod = fao.get_element("Production", fao_to_clm_dict=fao_to_clm_dict)
fao_area = fao.get_element("Area harvested", fao_to_clm_dict=fao_to_clm_dict)


def drop_a_where_not_in_b(a, b):
    """Return the rows of ``a`` whose index labels also appear in ``b``.

    The row order (and any duplicated labels) of ``a`` is preserved.
    """
    return a.loc[a.index.isin(b.index)]


# Only include where both production and area data are present
fao_prod = drop_a_where_not_in_b(fao_prod, fao_area)
fao_area = drop_a_where_not_in_b(fao_area, fao_prod)
if not fao_prod.index.equals(fao_area.index):
    raise RuntimeError("Mismatch of prod and area indices after trying to align them")

# Don't allow production where no area. The offending rows are kept in bad_prod and
# bad_area for inspection.
is_bad = (fao_prod["Value"] > 0) & (fao_area["Value"] == 0)
bad_prod = fao_prod[is_bad]
bad_area = fao_area[is_bad]
if verbose and is_bad.any():
    print(f"Dropping {int(is_bad.sum())} FAOSTAT rows with production but no area")
fao_prod = fao_prod[~is_bad]
fao_area = fao_area[~is_bad]
if not fao_prod.index.equals(fao_area.index):
    raise RuntimeError(
        "Mismatch of prod and area indices after disallowing production where no area"
    )

# The yield unit is derived from one production unit and one area unit, so every row
# of each element must share a single unit (this also rejects empty data).
for element_name, element_df in (("Production", fao_prod), ("Area harvested", fao_area)):
    element_units = element_df["Unit"].unique()
    if len(element_units) != 1:
        raise RuntimeError(
            f"Expected exactly one unit for FAOSTAT {element_name}; "
            f"got {list(element_units)}"
        )

# Get yield
fao_yield = fao_prod.copy()
fao_yield["Element"] = "Yield"
fao_yield["Unit"] = "/".join([fao_prod["Unit"].iloc[0], fao_area["Unit"].iloc[0]])
fao_yield["Value"] = fao_prod["Value"] / fao_area["Value"]

### 2.3 Import EarthStat (basically gridded FAOSTAT)

In [None]:
earthstat_dir = os.path.join(
    obs_data_dir, "lnd", "analysis_datasets", "multi_grid", "annual", "FAO-EarthStatYields"
)

# Read the EarthStat crop list. Each non-blank line holds a number followed by the crop
# name (which may contain spaces); only the lower-cased name is kept.
earthstat_crop_list_file = os.path.join(
    earthstat_dir, "EARTHSTATMIRCAUNFAO_croplist.txt"
)
earthstat_crop_list = []
with open(earthstat_crop_list_file, "r", encoding="utf-8") as f:
    for raw_line in f:
        line = raw_line.strip()
        if not line:
            continue
        number_and_name = line.split(maxsplit=1)
        if len(number_and_name) != 2:
            raise RuntimeError(
                "Failed to parse this line in earthstat_crop_list_file: " + line
            )
        earthstat_crop_list.append(number_and_name[1].lower())

# Rename EarthStat crops whose names differ from the ones CLM uses
earthstat_to_clm_names = {
    "maize": "corn",
    "soybeans": "soybean",
    "sugar cane": "sugarcane",
}
earthstat_crop_list = [
    earthstat_to_clm_names.get(name, name) for name in earthstat_crop_list
]

# Warn about any CLM crop that has no EarthStat counterpart
for crop in crops_to_include:
    if crop not in earthstat_crop_list:
        print(f"WARNING: {crop} not found in earthstat_crop_list")

# Import EarthStat maps for the resolutions used by the cases
resolutions = {case.cft_ds.attrs["resolution"].name for case in case_list}
earthstat = {}
for res in resolutions:
    # EarthStat can currently only be imported at f09; maps comparing EarthStat to
    # output at any other resolution are skipped.
    if res == "f09":
        print(f"Importing EarthStat yield maps for resolution {res}...")
        earthstat[res] = xr.open_dataset(os.path.join(earthstat_dir, res + ".nc"))
print("Done.")

## 3. Yield time series

In [None]:
if verbose:
    start = time()

# Get figure layout info
if 5 <= len(crops_to_include) <= 6:
    nrows = 2
    ncols = 3
    height = 10.5
    width = 15
    hspace = 0.25
    wspace = 0.35
else:
    raise RuntimeError(f"Specify figure layout for Ncrops=={len(crops_to_include)}")
fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(width, height))

# Units CTSM yields are converted to below; FAOSTAT yields must be in the same units
ctsm_units = "t/ha"

# Dash one case's line, which otherwise overlaps another case's line for some crops
linestyles = [
    (
        "--"
        if "clm6_crop_032_nomaxlaitrig" in case_name_list
        and case_name.endswith("clm6_crop_032_nmlt_phaseparams")
        else "-"
    )
    for case_name in case_name_list
]

fao_yield_world = fao_yield.query("Area == 'World'")

for i, crop in enumerate(crops_to_include):
    ax = axes.ravel()[i]
    plt.sca(ax)

    # Plot case data
    for c, case in enumerate(case_list):
        ds = case.cft_ds.sel(cft=case.crop_list[crop].pft_nums)

        # Global yield = total production / total area, summed over this crop's CFTs
        cft_area = ds["pfts1d_gridcellarea"] * ds["pfts1d_wtgcell"]
        cft_prod = ds["YIELD_ANN"] * cft_area
        crop_prod_ts = cft_prod.sum(dim=["cft", "pft"])
        crop_area_ts = cft_area.sum(dim=["cft", "pft"])
        crop_yield_ts = crop_prod_ts / crop_area_ts

        crop_yield_ts *= 1e-6 * 1e4  # Convert g/m2 to tons/ha
        crop_yield_ts.name = "Yield"
        crop_yield_ts.attrs["units"] = ctsm_units

        crop_yield_ts.plot(linestyle=linestyles[c])

    # Plot FAOSTAT data
    faostat_units = fao_yield_world["Unit"].iloc[0]
    if faostat_units != ctsm_units:
        raise RuntimeError(
            f"CTSM units ({ctsm_units}) do not match FAOSTAT units ({faostat_units})"
        )
    fao_yield_world_thiscrop = fao_yield_world.query(f"Crop == '{crop}'")
    # NOTE(review): plots FAOSTAT values against the last case's time axis, which
    # assumes exactly one World value per model year, in the same order — confirm.
    ax.plot(
        crop_yield_ts.time,
        fao_yield_world_thiscrop["Value"].values,
        "-k",
    )

    # Finish plot
    ax.set_title(crop)
    plt.xlabel("")

# Hide any panels left over when there are fewer crops than subplots
for ax in axes.ravel()[len(crops_to_include) :]:
    ax.set_visible(False)

plt.subplots_adjust(wspace=wspace, hspace=hspace)
fig.legend(
    labels=case_legend_list + ["FAOSTAT"],
    loc="upper center",
    bbox_to_anchor=(0.5, 0.96),
    ncol=3,
    bbox_transform=fig.transFigure,
)
fig.suptitle("Global yield", fontsize="x-large", fontweight="bold")
plt.show()

if verbose:
    end = time()
    print(f"Time series plots took {int(end - start)} s")

## 4. Yield maps

In [None]:
# Reload so that edits to the externals' utils module are picked up without restarting
importlib.reload(utils)

# Get figure layout info
N_cases = len(case_name_list)
if 3 <= N_cases <= 4:
    nrows = 2
    subplots_adjust_colorbar_top = 0.95
    subplots_adjust_colorbar_bottom = 0.2
    cbar_ax_rect = (0.2, 0.15, 0.6, 0.03)
else:
    raise RuntimeError(f"Specify figure layout for N_cases=={N_cases}")
height = 3.75 * nrows
ncols = 2
width = 15
hspace = 0
wspace = 0

# Maps are trimmed south of this latitude when a case has no data there (Antarctica)
antarctica_border = -60

if verbose:
    start_all = time()
for crop in crops_to_include:
    fig_yieldmaps_clm, axes_yieldmaps_clm = plt.subplots(
        nrows=nrows,
        ncols=ncols,
        figsize=(width, height),
        subplot_kw={"projection": ccrs.PlateCarree()},
    )
    if verbose:
        start = time()
        print(crop)

    # Get maps and colorbar min/max (the latter should cover total range across ALL cases)
    vmin_yieldmaps_clm = np.inf
    vmax_yieldmaps_clm = -np.inf
    result_dict_yieldmaps_clm = {}
    for i, case in enumerate(case_list):
        ds = case.cft_ds.sel(cft=case.crop_list[crop].pft_nums)
        # Drop these before the time mean below; tolerate their absence
        ds = ds.drop_vars(["date_written", "time_written"], errors="ignore")

        # Time-mean yield per CFT, then the mean across CFTs weighted by pfts1d_wtgcell
        cft_yield = ds["YIELD_ANN"].mean(dim="time")
        ds = ds.mean(dim="time")
        ds["wtd_yield_across_cfts"] = cft_yield.weighted(ds["pfts1d_wtgcell"]).mean(
            dim="cft"
        )

        # Grid the values and convert units
        yieldmap_clm = utils.grid_one_variable(ds, "wtd_yield_across_cfts")
        yieldmap_clm = utils.lon_pm2idl(yieldmap_clm)
        yieldmap_clm *= 1e-6 * 1e4  # Convert g/m2 to t/ha
        yieldmap_clm.name = "Yield"
        yieldmap_clm.attrs["units"] = "tons / ha"

        # Cut off Antarctica when there is no data south of antarctica_border
        if np.isnan(yieldmap_clm.sel(lat=slice(-90, antarctica_border)).max()):
            yieldmap_clm = yieldmap_clm.sel(lat=slice(antarctica_border, 90))

        # Save
        result_dict_yieldmaps_clm[case_name_list[i]] = yieldmap_clm
        vmin_yieldmaps_clm = min(vmin_yieldmaps_clm, np.nanmin(yieldmap_clm.values))
        vmax_yieldmaps_clm = max(vmax_yieldmaps_clm, np.nanmax(yieldmap_clm.values))

    # Plot one case per panel; hide panels beyond the number of cases
    for i, ax in enumerate(axes_yieldmaps_clm.ravel()):
        if i >= len(case_list):
            ax.set_visible(False)
            continue
        plt.sca(ax)
        result = result_dict_yieldmaps_clm[case_name_list[i]]
        im = result.plot(
            ax=ax,
            transform=ccrs.PlateCarree(),
            vmin=vmin_yieldmaps_clm,
            vmax=vmax_yieldmaps_clm,
            add_colorbar=False,
        )
        ax.coastlines(linewidth=0.5)
        # Title panels with the same labels used in the time-series legend
        plt.title(case_legend_list[i])
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_xlabel("")
        ax.set_ylabel("")

    # Finish up: one shared horizontal colorbar for all panels
    fig_yieldmaps_clm.suptitle(crop, fontsize="x-large", fontweight="bold")
    fig_yieldmaps_clm.subplots_adjust(
        top=subplots_adjust_colorbar_top,
        bottom=subplots_adjust_colorbar_bottom,
    )
    cbar_ax = fig_yieldmaps_clm.add_axes(rect=cbar_ax_rect)
    fig_yieldmaps_clm.colorbar(
        im,
        cax=cbar_ax,
        orientation="horizontal",
        label=f"{result.name} ({result.attrs['units']})",
    )
    if verbose:
        end = time()
        print(f"{crop} took {int(end - start)} s")
    plt.show()

if verbose:
    end_all = time()
    print(f"Maps took {int(end_all - start_all)} s.")