# Compare crop yields to observations

- Uses raw annual CTSM outputs (NOT timeseries files).

Notebook created by Sam Rabin (samrabin@ucar.edu).

In [None]:
import sys
import os
import glob
import numpy as np
import xarray as xr
import pandas as pd
from time import time
from dask.distributed import Client, wait
import convert_pft1d_to_sparse
import importlib

# Plotting utils
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature

# Location of the ctsm_python_gallery checkout. It is appended to sys.path
# BEFORE the ctsm_py imports below, so those imports can be resolved from it
# (presumably ctsm_py is not otherwise installed in this environment — confirm).
ctsm_python_gallery_path = (
    "/glade/work/samrabin/cupid_crops/externals/ctsm_python_gallery"
)
sys.path.append(ctsm_python_gallery_path)
import ctsm_py.utils
import ctsm_py.crop_secondary_variables as c2o
import ctsm_py.mark_crops_invalid as mci

# Start a local Dask cluster using all available cores
client = Client()
# Bare expression: as the last statement of a notebook cell, this displays the
# client as the cell output.
client

## 1. Settings

### 1.1 Parameters modifiable in config.yml

In [None]:
# Where land output is stored
# Each case is read from <CESM_output_dir>/<case name>/lnd/hist (see Case).
CESM_output_dir = os.path.join(
    os.path.sep,
    "glade",
    "work",
    "samrabin",
    "clm6_crop_reparam_outputs",
)

# Full casenames that are present in CESM_output_dir and in individual filenames
case_name_list = [
    "ctsm53019_f09_BNF_hist",
    "clm6_crop_032",
    "clm6_crop_032_nomaxlaitrig",
]

# History-stream tag embedded in the filenames; files are found with the glob
# <case name>.clm2<clm_file_h>*.nc
clm_file_h = ".h0."

# The actual netCDF timesteps, not the names of the files
start_year = 1961
end_year = 2023

# CFT names to analyse. Each must match a "cft_<name>" global attribute in the
# history files; CftList raises KeyError for any that are missing.
cfts_to_include = [
    "temperate_corn",
    "tropical_corn",
    "cotton",
    "rice",
    "temperate_soybean",
    "tropical_soybean",
    "sugarcane",
    "spring_wheat",
    "irrigated_temperate_corn",
    "irrigated_tropical_corn",
    "irrigated_cotton",
    "irrigated_rice",
    "irrigated_temperate_soybean",
    "irrigated_tropical_soybean",
    "irrigated_sugarcane",
    "irrigated_spring_wheat",
]

# Crops to plot. A CFT is assigned to a crop when the crop name is a substring
# of the CFT name (see Crop). The plotting cells only define figure layouts for
# certain list lengths and raise RuntimeError otherwise.
crops_to_include = [
    "corn",
    "cotton",
    "rice",
    "soybean",
    "sugarcane",
    "wheat",
]
# Maps FAOSTAT crop names (keys) to the crop names used above (values).
# fao_data_get() requires every key to be present in the FAOSTAT table and the
# number of matched FAOSTAT crops to equal len(crops_to_include).
fao_to_clm_dict = {
    "Maize": "corn",
    "Rice": "rice",
    "Seed cotton, unginned": "cotton",
    "Soya beans": "soybean",
    "Sugar cane": "sugarcane",
    "Wheat": "wheat",
}

# When True, each case's CFT dataset is loaded into memory in Case.__init__()
dev_mode = True
# When True, progress and timing messages are printed
verbose = True

# Root directory of the observational datasets; the FAOSTAT table is read from
# a subdirectory of it.
obs_data_dir = os.path.join(
    os.sep + "glade",
    "campaign",
    "cesm",
    "development",
    "cross-wg",
    "diagnostic_framework",
    "CUPiD_obs_data",
)

### 1.2 Other settings

In [None]:
# Scratch output goes in $SCRATCH/CUPiD_scratch when the SCRATCH environment
# variable is defined (the directory is created if needed); otherwise it goes
# in the current working directory.
scratch_root = os.environ.get("SCRATCH")
if scratch_root is None:
    cupid_temp = "."
else:
    cupid_temp = os.path.join(scratch_root, "CUPiD_scratch")
    os.makedirs(cupid_temp, exist_ok=True)

# Passed to CftList as n_pfts, from which the PFT number of each CFT is derived
N_PFTS = 78

# Last dot-separated component of each case name
short_names = [full_name.split(".")[-1] for full_name in case_name_list]

# Sanity check on the requested period
if end_year < start_year:
    raise RuntimeError(f"start_year ({start_year}) > end_year ({end_year})")

## 2. Import case data

### 2.1 Set up classes etc.

In [None]:
# Re-execute the crop_secondary_variables module (imported above as c2o) so the
# copy used below reflects the module's current source without restarting the
# kernel.
importlib.reload(c2o)


class Cft:
    """
    One crop functional type (CFT) and where it occurs on the pft dimension.

    Attributes:
        name: CFT name.
        cft_num: CFT number as given at construction (1-indexed).
        pft_num: PFT number (1-indexed); None until update_pft() is called.
        pft_ind: pft_num - 1 (0-indexed); None until update_pft() is called.
        where: Integer positions on the pft dimension whose vegetation type
            equals pft_num; None until get_where() is called.
    """

    def __init__(self, name, cft_num):
        self.name = name

        # 1-indexed in the FORTRAN style
        self.cft_num = cft_num
        self.pft_num = None  # Need to know max cft_num

        # 0-indexed in the Python style
        self.pft_ind = None  # Need to know pft_num
        self.where = None

    def __str__(self):
        # `where` stays None until get_where() has been called; report that
        # instead of failing on len(None).
        n_cells = "unknown" if self.where is None else len(self.where)
        return "\n".join(
            [
                self.name + ":",
                f"   cft_num: {self.cft_num}",
                f"   pft_num: {self.pft_num}",
                f"   pft_ind: {self.pft_ind}",
                f"   N cells: {n_cells}",
            ]
        )

    def update_pft(self, n_non_crop_pfts):
        """
        Set pft_num and pft_ind from the number of non-crop PFTs.

        You don't know n_non_crop_pfts until after reading in all CFTs, so
        this function gets called once that's done in CftList.__init__().
        """
        self.pft_num = n_non_crop_pfts + self.cft_num - 1
        self.pft_ind = self.pft_num - 1

    def get_where(self, ds):
        """
        Get the indices on the pft dimension corresponding to this CFT.

        Reads ds["pfts1d_itype_veg"] (only the first time step if it has a
        "time" dimension), stores the matching positions in self.where and
        returns them.

        Raises:
            RuntimeError: If update_pft() has not been called yet.
        """
        if self.pft_num is None:
            raise RuntimeError(
                "get_where() can't be run until after calling Cft.update_pft()"
            )
        pfts1d_itype_veg = ds["pfts1d_itype_veg"]
        if "time" in pfts1d_itype_veg.dims:
            pfts1d_itype_veg = pfts1d_itype_veg.isel(time=0)
        self.where = np.where(pfts1d_itype_veg.values == self.pft_num)[0].astype(int)
        return self.where


class CftList:
    """
    The CFTs of interest in a dataset, each with its PFT number and locations.

    Args:
        ds: Dataset whose global attributes named "cft_<name>" give the number
            of each CFT, and which contains "pfts1d_itype_veg".
        n_pfts: Total number of PFTs, used to convert CFT numbers to PFT
            numbers.
        cfts_to_include: Names of the CFTs to keep.

    Raises:
        KeyError: If any name in cfts_to_include has no matching attribute.
        ValueError: If the dataset has no "cft_*" attributes, or
            cfts_to_include contains duplicates.
    """

    def __init__(self, ds, n_pfts, cfts_to_include):
        # Get list of all possible CFTs: one per "cft_<name>" global attribute
        self.cft_list = [
            Cft(key[4:], value)
            for key, value in ds.attrs.items()
            if key.startswith("cft_")
        ]

        # Ensure that all CFTs in cfts_to_include are present
        cfts_in_file = {x.name for x in self.cft_list}
        missing_cfts = [x for x in cfts_to_include if x not in cfts_in_file]
        if missing_cfts:
            msg = (
                "The following are in cfts_to_include but not the dataset: "
                + ", ".join(missing_cfts)
            )
            raise KeyError(msg)

        # max() of an empty sequence would raise an unhelpful ValueError
        if not self.cft_list:
            raise ValueError("No cft_* attributes found in the dataset")

        # Figure out PFT indices
        max_cft_num = max(x.cft_num for x in self.cft_list)
        n_non_crop_pfts = n_pfts - max_cft_num + 1  # Incl. unvegetated
        for cft in self.cft_list:
            cft.update_pft(n_non_crop_pfts)

        # Only include CFTs we care about
        wanted = set(cfts_to_include)
        if len(cfts_to_include) != len(wanted):
            raise ValueError("Duplicate CFT(s) in cfts_to_include")
        self.cft_list = [x for x in self.cft_list if x.name in wanted]

        # Find where on the pft dimension each CFT occurs
        for cft in self.cft_list:
            cft.get_where(ds)
            if len(cft.where) == 0:
                print("Warning: No occurrences found of " + cft.name)

    def __getitem__(self, index):
        return self.cft_list[index]

    def __str__(self):
        return "\n".join(str(cft) for cft in self.cft_list)


def get_cft_ds(ds, cft):
    """
    Subset a dataset to one CFT and label it with a "cft" coordinate.

    Selects the positions in cft.where along the "pft" dimension, adds a "cft"
    coordinate equal to cft.pft_num, and expands every variable that has a
    "pft" dimension with a "cft" dimension.
    """
    subset = ds.isel(pft=cft.where)
    subset["cft"] = cft.pft_num
    subset = subset.set_coords("cft")

    for name in subset:
        # Variables without a pft dimension are left untouched
        if "pft" not in subset[name].dims:
            continue
        labelled = subset[name].assign_coords(cft=cft.pft_num)
        subset[name] = labelled
        subset[name] = subset[name].expand_dims("cft")

    return subset


class Crop:
    """
    A crop, i.e. the group of CFTs whose names contain the crop's name.

    Attributes:
        name: Crop name.
        cft_list: Member CFT objects, in the order they appear in the input.
        cft_names, pft_nums, pft_inds: Per-member name, pft_num and pft_ind.
        where: Sorted positions on the pft dimension of all member CFTs.
    """

    def __init__(self, name, cft_list, ds):
        self.name = name

        # A CFT belongs to this crop when the crop name occurs in the CFT name
        self.cft_list = [cft for cft in cft_list if self.name in cft.name]

        # Per-CFT information
        self.cft_names = [cft.name for cft in self.cft_list]
        self.pft_nums = [cft.pft_num for cft in self.cft_list]
        self.pft_inds = [cft.pft_ind for cft in self.cft_list]

        # Pool the locations of every member CFT, starting from an empty
        # int64 array so the result is an array even with no members
        pieces = [np.array([], dtype=np.int64)]
        pieces.extend(np.ravel(cft.get_where(ds)) for cft in self.cft_list)
        self.where = np.sort(np.concatenate(pieces))

    def __str__(self):
        members = ", ".join(f"{x.name} ({x.pft_num})" for x in self.cft_list)
        return f"{self.name}: {members}"


class CropList:
    """
    Collection of Crop objects, indexable by position or by crop name.

    Raises:
        ValueError: On construction, if crops_to_include contains duplicates.
    """

    def __init__(self, crops_to_include, cft_list, ds):
        n_unique = len(np.unique(crops_to_include))
        if n_unique != len(crops_to_include):
            raise ValueError("Duplicate crop(s) found in crops_to_include")
        self.crop_list = [
            Crop(crop_name, cft_list, ds) for crop_name in crops_to_include
        ]

    def __getitem__(self, index):
        # Anything that isn't a string is handed straight to the list
        if not isinstance(index, str):
            return self.crop_list[index]

        # String lookup: first crop whose name matches exactly
        for crop in self.crop_list:
            if crop.name == index:
                return crop
        raise KeyError(f"No crop found matching '{index}'")

    def __str__(self):
        return "\n".join(str(crop) for crop in self.crop_list)


def mf_preproc(ds):
    """
    Preprocess one file's dataset; passed as `preprocess` to open_mfdataset().

    Adds a leading "time" dimension to pfts1d_wtgcell and drops every variable
    that has a dimension whose name contains "lev".
    """
    wtgcell = ds["pfts1d_wtgcell"]
    ds["pfts1d_wtgcell"] = wtgcell.expand_dims(dim="time", axis=0)

    leveled_vars = []
    for name in ds:
        if any("lev" in dim for dim in ds[name].dims):
            leveled_vars.append(name)

    return ds.drop_vars(leveled_vars)


class Case:
    """
    CLM history output for one case, organised by CFT.

    Finds the history files of the case, keeps those whose time axis overlaps
    start_year-end_year, opens them as one dataset, and builds:

        file_list: The files that were read.
        cft_list: CftList for cfts_to_include.
        crop_list: CropList for crops_to_include.
        cft_ds: Dataset with a "cft" dimension, one entry per included CFT,
            plus the derived HUIFRAC_PERHARV and GSLEN_PERHARV variables.

    Also relies on the notebook-level settings N_PFTS, dev_mode and verbose.

    NOTE(review): files are selected by overlap with the requested years, but
    no time slicing is applied to the resulting dataset in this class.

    Raises:
        FileNotFoundError: If no files match the filename pattern, or none
            overlap the requested years.
        RuntimeError: If the included CFTs do not all occur in the same number
            of pft-dimension entries.
    """

    def __init__(
        self,
        name,
        CESM_output_dir,
        clm_file_h,
        cfts_to_include,
        crops_to_include,
        start_year,
        end_year,
    ):
        # Get list of all history files for this case
        file_dir = os.path.join(
            CESM_output_dir,
            name,
            "lnd",
            "hist",
        )
        file_pattern = os.path.join(file_dir, name + ".clm2" + clm_file_h + "*.nc")
        file_list = np.sort(glob.glob(file_pattern))
        if len(file_list) == 0:
            raise FileNotFoundError("No files found matching pattern: " + file_pattern)

        # Get list of files to actually include
        self.file_list = []
        for filename in file_list:
            # Only the time axis is needed here, so close each file as soon as
            # it has been inspected instead of leaving the handle open.
            # NOTE(review): assumes the decoded time values expose `.year`.
            with xr.open_dataset(filename) as ds:
                first_year = ds.time.values[0].year
                last_year = ds.time.values[-1].year
            if first_year <= end_year and start_year <= last_year:
                self.file_list.append(filename)
        if not self.file_list:
            raise FileNotFoundError(
                f"No files found with timestamps in {start_year}-{end_year}"
            )

        # Read files
        # Adding join="override", compat="override", coords="minimal", doesn't fix the graph size
        # Adding combine="nested", concat_dim="time" doesn't give time axis to only variables we want
        ds = xr.open_mfdataset(
            self.file_list,
            decode_times=True,
            chunks={},
            join="override",
            compat="override",
            coords="minimal",
            # combine="nested", concat_dim="time",
            data_vars="minimal",
            preprocess=mf_preproc,
        )

        # Get CFT info
        self.cft_list = CftList(ds, N_PFTS, cfts_to_include)

        # Get crop list
        self.crop_list = CropList(crops_to_include, self.cft_list, ds)

        # Save CFT dataset: one slice per CFT, concatenated along "cft"
        for i, cft in enumerate(self.cft_list):
            this_cft_ds = get_cft_ds(ds, cft)

            if i == 0:
                self.cft_ds = this_cft_ds.copy()
                n_expected = self.cft_ds.sizes["pft"]
            else:
                # Check # of gridcells with this PFT; the concatenation below
                # needs every CFT to have the same length on "pft"
                n_this = this_cft_ds.sizes["pft"]
                if n_this != n_expected:
                    raise RuntimeError(
                        f"Expected {n_expected} gridcells with {cft.name}; found {n_this}"
                    )
                self.cft_ds = xr.concat(
                    [self.cft_ds, this_cft_ds],
                    dim="cft",
                    data_vars="minimal",
                    compat="override",
                    join="override",
                    coords="minimal",
                )

        # Load
        if dev_mode:
            if verbose:
                start = time()
                print("Loading...")
            self.cft_ds.load()
            if verbose:
                end = time()
                print(f"Loading took {int(end - start)} s")

        # Get secondary variables
        if verbose:
            start = time()
            print("Getting secondary variables")
        # Negative dates are masked out before deriving anything from them
        for var in ["HDATES", "SDATES_PERHARV"]:
            self.cft_ds[var] = self.cft_ds[var].where(self.cft_ds[var] >= 0)
        self.cft_ds["HUIFRAC_PERHARV"] = c2o.get_huifrac(self.cft_ds)
        self.cft_ds["GSLEN_PERHARV"] = c2o.get_gslen(self.cft_ds)
        if verbose:
            end = time()
            print(f"Secondary variables took {int(end - start)} s")

### 2.2 Import cases

In [None]:
start = time()
case_list = []
for case in case_name_list:
    print(f"Importing {case}...")
    case_list.append(
        Case(
            case,
            CESM_output_dir,
            clm_file_h,
            cfts_to_include,
            crops_to_include,
            start_year,
            end_year,
        )
    )

    # Get gridcell area: look up "area" at each gridcell's lat/lon
    ds = case_list[-1].cft_ds
    grid_lons = ds["grid1d_lon"].values
    grid_lats = ds["grid1d_lat"].values
    area_g = np.array(
        [ds["area"].sel(lat=lat, lon=lon) for lon, lat in zip(grid_lons, grid_lats)]
    )

    # Map gridcell areas onto the pft dimension. The "- 1" converts the
    # gridcell index stored in pfts1d_gi to a 0-based position in area_g.
    area_p = np.array(
        [area_g[int(gi) - 1] for gi in ds["pfts1d_gi"].isel(cft=0).values]
    )
    ds["pfts1d_gridcellarea"] = xr.DataArray(
        data=area_p,
        coords={"pft": ds["pft"].values},
        dims=["pft"],
    )

print("Done.")
if verbose:
    end = time()
    print(f"Importing took {int(end - start)} s")

### 2.3 Import FAOSTAT

In [None]:
# Directory holding the normalized FAOSTAT crop/livestock production table
faostat_subdirs = (
    "lnd",
    "analysis_datasets",
    "ungridded",
    "timeseries",
    "FAOSTAT",
    "Production_Crops_Livestock_2025-02-25",
    "norm",
)
this_dir = os.path.join(obs_data_dir, *faostat_subdirs)

faostat_csv = os.path.join(
    this_dir, "Production_Crops_Livestock_E_All_Data_(Normalized).csv"
)
fao = pd.read_csv(faostat_csv, low_memory=False)

# Work on a renamed copy, leaving the table as read in `fao`.
# "Item" becomes "Crop" because it's easy to confuse Item vs. Element.
fao2 = fao.rename(columns={"Item": "Crop"}, errors="raise")

# Merge every crop whose name starts with "Maize" (e.g. "Maize, green") into
# "Maize" by summing rows that are otherwise identical in the grouping columns
fao2["Crop"] = fao2["Crop"].replace("Maize.*", "Maize", regex=True)
grouping_columns = ["Crop", "Year", "Element", "Area", "Unit"]
fao2 = fao2.groupby(by=grouping_columns, as_index=False).agg("sum")

# Drop the "China" area, which includes all Chinas
if "China" in fao2.Area.values:
    fao2 = fao2[fao2["Area"] != "China"]

## 3. Yield time series

In [None]:
def fao_data_get(fao_all, element, y1, yN, fao_to_clm_dict, cropList_combined_clm):
    """
    Extract one FAOSTAT element for the crops and years of interest.

    Args:
        fao_all: FAOSTAT table with at least the columns "Crop", "Year",
            "Element", "Area", "Unit" and "Value". It is not modified.
        element: Value of the "Element" column to keep.
        y1, yN: First and last year to keep (both inclusive).
        fao_to_clm_dict: Maps FAOSTAT crop names to CLM crop names. Only crops
            that are keys of this dict are kept, and they are renamed to the
            corresponding values.
        cropList_combined_clm: CLM crop names; only its length is used, as a
            check on the number of crops found.

    Returns:
        DataFrame indexed by (Crop, Year, Area) with the columns "Element",
        "Unit" and "Value".

    Raises:
        KeyError: If the element, any crop in fao_to_clm_dict, or all of the
            requested years are absent.
        RuntimeError: If the number of crops kept differs from
            len(cropList_combined_clm).
    """

    # Extract element of interest. A boolean mask (rather than a string-built
    # query) works for any element name, including ones containing quotes.
    fao_this = fao_all[fao_all["Element"] == element]
    if fao_this.empty:
        raise KeyError(f"No FAOSTAT element found matching {element}")

    # Drop all but columns of interest
    fao_this = fao_this[["Crop", "Year", "Element", "Area", "Unit", "Value"]]

    # Extract crops of interest
    fao_crops = set(fao_this["Crop"].unique())
    fao_crops_from_dict = list(fao_to_clm_dict)
    missing_crops = [crop for crop in fao_crops_from_dict if crop not in fao_crops]
    if missing_crops:
        raise KeyError(
            f"Crops missing from FAOSTAT {element}: {'; '.join(missing_crops)}"
        )
    fao_this = fao_this[fao_this["Crop"].isin(fao_crops_from_dict)]
    if len(fao_this["Crop"].unique()) != len(cropList_combined_clm):
        raise RuntimeError(
            "Unexpected # crops in FAOSTAT after extracting crops of interest"
        )

    # Remove unneeded years
    in_period = (fao_this["Year"] >= y1) & (fao_this["Year"] <= yN)
    fao_this = fao_this[in_period]
    if fao_this.empty:
        raise KeyError(f"No FAOSTAT years found in {y1}-{yN}")

    # Rename to match CLM. assign() builds a new frame, so nothing is written
    # into a filtered slice of the caller's table.
    fao_this = fao_this.assign(Crop=fao_this["Crop"].replace(fao_to_clm_dict))

    # Set index
    return fao_this.set_index(["Crop", "Year", "Area"])


# Production and harvested area for the crops and years of interest, each
# indexed by (Crop, Year, Area)
fao_prod = fao_data_get(
    fao2, "Production", start_year, end_year, fao_to_clm_dict, crops_to_include
)
fao_area = fao_data_get(
    fao2, "Area harvested", start_year, end_year, fao_to_clm_dict, crops_to_include
)

# Only include where both production and area data are present
def drop_a_where_not_in_b(a, b):
    """Return the rows of `a` whose index label also occurs in `b`'s index."""
    shared = a.index.isin(b.index)
    return a.loc[shared]


# Keep only the (Crop, Year, Area) entries present in both tables
fao_prod = drop_a_where_not_in_b(fao_prod, fao_area)
fao_area = drop_a_where_not_in_b(fao_area, fao_prod)
if not fao_prod.index.equals(fao_area.index):
    raise RuntimeError("Mismatch of prod and area indices after trying to align them")

# Don't allow production where no area. The offending rows are kept in
# bad_prod / bad_area and removed from both tables.
is_bad = (fao_area["Value"] == 0) & (fao_prod["Value"] > 0)
where_bad = np.flatnonzero(is_bad)
bad_prod = fao_prod.iloc[where_bad]
bad_area = fao_area.iloc[where_bad]
is_ok = ~is_bad
fao_prod = fao_prod[is_ok]
fao_area = fao_area[is_ok]
if not fao_prod.index.equals(fao_area.index):
    raise RuntimeError(
        "Mismatch of prod and area indices after disallowing production where no area"
    )

# Get yield as production divided by area. The unit label is built from the
# first row of each table.
prod_unit = fao_prod["Unit"].iloc[0]
area_unit = fao_area["Unit"].iloc[0]
fao_yield = fao_prod.assign(
    Element="Yield",
    Unit="/".join([prod_unit, area_unit]),
    Value=fao_prod["Value"] / fao_area["Value"],
)

In [None]:
if verbose:
    start = time()

# Get figure layout info
if 5 <= len(crops_to_include) <= 6:
    nrows = 2
    ncols = 3
    height = 10.5
    width = 15
    hspace = 0.25
    wspace = 0.35
else:
    raise RuntimeError(f"Specify figure layout for Ncrops=={len(crops_to_include)}")
fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(width, height))
flat_axes = axes.ravel()

fao_yield_world = fao_yield.query("Area == 'World'")

# The units are the same for every crop and case, so define and check them
# once, before anything is plotted
ctsm_units = "t/ha"
faostat_units = fao_yield_world["Unit"].iloc[0]
if faostat_units != ctsm_units:
    raise RuntimeError(
        f"CTSM units ({ctsm_units}) do not match FAOSTAT units ({faostat_units})"
    )

# One panel per crop, one line per case plus one for FAOSTAT
for i, crop in enumerate(crops_to_include):
    ax = flat_axes[i]
    # crop_yield_ts.plot() and plt.xlabel() below act on the current axes
    plt.sca(ax)

    # Plot case data
    for case in case_list:
        ds = case.cft_ds.sel(cft=case.crop_list[crop].pft_nums)

        cft_area = ds["pfts1d_gridcellarea"] * ds["pfts1d_wtgcell"]
        if np.any(ds["GRAINC_TO_FOOD_PERHARV"] < 0):
            raise ValueError("Unexpected negative value(s) in GRAINC_TO_FOOD_PERHARV")
        cft_yield = mci.mark_crops_invalid(
            ds,
            "GRAINC_TO_FOOD_PERHARV",
            min_viable_hui="isimip3",
            this_pft=crop,
            invalid_value=0,
        )
        # Sum over harvests, scale by area to get production, then aggregate
        # over all CFTs and locations of this crop
        cft_prod = cft_yield.sum(dim="mxharvests", skipna=True) * cft_area
        cft_prod = ctsm_py.utils.food_grainc_to_harvested_tons_onecrop(cft_prod, crop)
        crop_prod_ts = cft_prod.sum(dim=["cft", "pft"])
        crop_area_ts = cft_area.sum(dim=["cft", "pft"])
        crop_yield_ts = crop_prod_ts / crop_area_ts

        # Plot data
        crop_yield_ts *= 1e-6 * 1e4  # Convert g/m2 to tons/ha
        crop_yield_ts.name = "Yield"
        crop_yield_ts.attrs["units"] = ctsm_units
        crop_yield_ts.plot()

    # Plot FAOSTAT data against the time axis of the last case plotted
    # NOTE(review): this requires as many FAOSTAT values as model time steps
    fao_yield_world_thiscrop = fao_yield_world.query(f"Crop == '{crop}'")
    ax.plot(
        crop_yield_ts.time,
        fao_yield_world_thiscrop["Value"].values,
        "-k",
    )

    # Finish plot
    ax.set_title(crop)
    ax.legend(case_name_list + ["FAOSTAT"])
    plt.xlabel("")

plt.subplots_adjust(wspace=wspace, hspace=hspace)
plt.show()

if verbose:
    end = time()
    print(f"Time series plots took {int(end - start)} s")

## 4. Yield maps

In [None]:
# Re-execute the ctsm_py modules used below so their current source is in effect
importlib.reload(ctsm_py.utils)
importlib.reload(ctsm_py.crop_secondary_variables)
importlib.reload(ctsm_py.mark_crops_invalid)

# Get figure layout info
N_cases = len(case_name_list)
if 2 <= N_cases <= 3:
    nrows = 2
    ncols = 2
    height = 8
    width = 15
    # NOTE(review): hspace and wspace are defined but not applied in this cell
    hspace = 2
    wspace = 0
else:
    raise RuntimeError(f"Specify figure layout for N_cases=={N_cases}")

if verbose:
    start_all = time()

# One figure per crop, one map panel per case
for crop in crops_to_include:
    fig, axes = plt.subplots(
        nrows=nrows,
        ncols=ncols,
        figsize=(width, height),
        subplot_kw={"projection": ccrs.PlateCarree()},
    )
    flat_axes = axes.ravel()
    if verbose:
        start = time()
        print(crop)
    for i, case in enumerate(case_list):
        ax = flat_axes[i]
        plt.sca(ax)
        ds = case.cft_ds.sel(cft=case.crop_list[crop].pft_nums)
        ds = ds.drop_vars(["date_written", "time_written"])

        # Yield per CFT: sum over harvests, then mean across time
        if np.any(ds["GRAINC_TO_FOOD_PERHARV"] < 0):
            raise ValueError("Unexpected negative value(s) in GRAINC_TO_FOOD_PERHARV")
        cft_yield = ctsm_py.mark_crops_invalid.mark_crops_invalid(
            ds,
            "GRAINC_TO_FOOD_PERHARV",
            min_viable_hui="isimip3",
            this_pft=crop,
            invalid_value=0,
        )
        cft_yield = cft_yield.sum(dim="mxharvests", skipna=True).mean(dim="time")

        # Mean across this crop's CFTs, weighted by time-mean pfts1d_wtgcell
        ds = ds.mean(dim="time")
        ds["wtd_yield_across_cfts"] = cft_yield.weighted(ds["pfts1d_wtgcell"]).mean(
            dim="cft"
        )

        # Grid the yields
        result = ctsm_py.utils.grid_one_variable(ds, "wtd_yield_across_cfts")
        result = ctsm_py.utils.lon_pm2idl(result)
        result = ctsm_py.utils.food_grainc_to_harvested_tons_onecrop(result, crop)

        # Cut off Antarctica, but only if there is no data south of the border
        antarctica_border = -60
        if np.isnan(result.sel(lat=slice(-90, antarctica_border)).max()):
            result = result.sel(lat=slice(antarctica_border, 90))

        # Plot data
        result *= 1e-6 * 1e4  # Convert g/m2 to t/ha
        result.name = "Yield"
        result.attrs["units"] = "tons / ha"
        result.plot(
            ax=ax,
            transform=ccrs.PlateCarree(),
            cbar_kwargs={"location": "bottom", "pad": 0},
        )
        ax.coastlines(linewidth=0.5)
        plt.title(case_name_list[i])
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_xlabel("")
        ax.set_ylabel("")

    # The panel grid can hold more maps than there are cases; hide the
    # leftover panels rather than showing empty map frames
    for unused_ax in flat_axes[len(case_list):]:
        unused_ax.set_visible(False)

    fig.suptitle(crop, fontsize="x-large", fontweight="bold")
    if verbose:
        end = time()
        print(f"{crop} took {end - start} s")

plt.show()

if verbose:
    end_all = time()
    print(f"Maps took {int(end_all - start_all)} s.")