# Compare crop yields to observations

- Uses raw annual CTSM outputs (NOT timeseries files).

Notebook created by Sam Rabin (samrabin@ucar.edu).

In [None]:
import sys
import os
import glob
import numpy as np
import xarray as xr
import pandas as pd
from time import time
from dask.distributed import Client, wait
import convert_pft1d_to_sparse
import importlib
import earthstat
import caselist
import crop_timeseries_figs
import plotting_utils

# Plotting utils
import matplotlib.pyplot as plt
import clm_and_earthstat_maps as caem

# Start a local Dask cluster. Client() with no arguments uses Dask's default
# LocalCluster sizing (intended here to use the available cores).
client = Client()
# Bare expression so Jupyter renders the client's summary for this cell.
client

## 1. Settings

### 1.1 Parameters modifiable in config.yml

In [None]:
# Path to CUPiD externals, derived from the current working directory.
# Deriving it this way can be unreliable, so prefer overriding it with a value
# from config.yml (see examples/crops/config.yml).
externals_path = os.path.join(os.getcwd(), os.pardir, os.pardir, "externals")

# Directory holding the CESM land output
CESM_output_dir = os.path.join(
    os.sep, "glade", "work", "samrabin", "clm6_crop_reparam_outputs"
)

# Full case names: each must be present in CESM_output_dir and in the case's
# individual output filenames. Commented-out entries are alternative cases that
# can be switched back on.
case_name_list = [
    # "ctsm53019_f09_BNF_hist",
    "clm6_crop_032",
    # "clm6_crop_032_nomaxlaitrig",
    # "clm6_crop_032_nmlt_phaseparams",
    "alpha-ctsm5.4.CMIP7.09.ctsm5.3.068",
    "alpha-ctsm5.4.CMIP7.09.ctsm5.3.068_nogddadapt",
]

# Figure-legend label for each active case above, in the same order
case_legend_list = [
    # "A: ctsm5.3.019 (GSWP3)",
    "B: ctsm5.3.032 (CRU-JRA)",
    # "C: As B + w/o max LAI triggering grainfill",
    # "D: As C + pre-CLM5 crop phase params",
    "E: 5.4 branch",
    "F: 5.4 branch (no GDD adapt)",
]

# Analysis period. These are the years of the timesteps stored inside the
# netCDF files, not the years that appear in the file names.
start_year = 1961
end_year = 2023

# Crop functional types to analyze: all rainfed CFTs first, then their
# irrigated counterparts in the same order.
cfts_to_include = [
    prefix + cft
    for prefix in ("", "irrigated_")
    for cft in (
        "temperate_corn",
        "tropical_corn",
        "cotton",
        "rice",
        "temperate_soybean",
        "tropical_soybean",
        "sugarcane",
        "spring_wheat",
    )
]

# Crops (aggregates of the CFTs above) to analyze
crops_to_include = [
    "corn",
    "cotton",
    "rice",
    "soybean",
    "sugarcane",
    "wheat",
]

# FAOSTAT item name -> CLM crop name
fao_to_clm_dict = {
    "Maize": "corn",
    "Rice": "rice",
    "Seed cotton, unginned": "cotton",
    "Soya beans": "soybean",
    "Sugar cane": "sugarcane",
    "Wheat": "wheat",
}

dev_mode = verbose = True

# Root of the CUPiD observational datasets
obs_data_dir = os.path.join(
    os.sep,
    "glade",
    "campaign",
    "cesm",
    "development",
    "cross-wg",
    "diagnostic_framework",
    "CUPiD_obs_data",
)

### 1.2 Other settings

In [None]:
# Scratch-output directory: $SCRATCH/CUPiD_scratch when SCRATCH is defined
# (created on demand), otherwise the current working directory.
if "SCRATCH" not in os.environ:
    cupid_temp = "."
else:
    cupid_temp = os.path.join(os.environ["SCRATCH"], "CUPiD_scratch")
    os.makedirs(cupid_temp, exist_ok=True)

N_PFTS = 78

# Short label per case: the text after the last "." in the case name
short_names = [name.rpartition(".")[2] for name in case_name_list]

# Sanity-check the requested analysis period
if end_year < start_year:
    raise RuntimeError(f"start_year ({start_year}) > end_year ({end_year})")

# Legend labels default to the case names; when given explicitly they must
# pair one-to-one with the cases.
if not case_legend_list:
    case_legend_list = case_name_list
elif len(case_legend_list) != len(case_name_list):
    raise RuntimeError("case_legend_list must be same length as case_name_list")

In [None]:
# Gather the user-facing settings into one dict so they can be passed among
# functions as a single argument, then remove the standalone globals so later
# cells read them only through `opts`.
opts = dict(
    CESM_output_dir=CESM_output_dir,
    case_name_list=case_name_list,
    case_legend_list=case_legend_list,
    start_year=start_year,
    end_year=end_year,
    cfts_to_include=cfts_to_include,
    crops_to_include=crops_to_include,
    fao_to_clm_dict=fao_to_clm_dict,
    dev_mode=dev_mode,
    verbose=verbose,
    obs_data_dir=obs_data_dir,
)
del (
    CESM_output_dir,
    case_name_list,
    case_legend_list,
    start_year,
    end_year,
    cfts_to_include,
    crops_to_include,
    fao_to_clm_dict,
    dev_mode,
    verbose,
    obs_data_dir,
)

### 1.3 Import modules from externals

In [None]:
# Make the CUPiD externals importable. externals_path is set in the settings
# cell (possibly overridden via config.yml). NOTE: re-running this cell appends
# the same entry to sys.path again; harmless for imports.
sys.path.append(externals_path)
# The imports below presumably resolve through the path appended above, so they
# must stay after the sys.path.append call.
import ctsm_postprocessing.utils as utils
from ctsm_postprocessing.crops import crop_secondary_variables as c2o
from ctsm_postprocessing.crops import cropcase
import ctsm_postprocessing.crops.faostat as faostat
from ctsm_postprocessing.resolutions import identify_resolution

## 2. Import case data

### 2.1 Import cases

In [None]:
# Pick up any edits to these modules without restarting the kernel
importlib.reload(cropcase)
importlib.reload(caselist)

# Build the collection of cases described by `opts`, handing it the CropCase
# class and the resolution-identification helper to use.
case_list = caselist.CaseList(
    opts=opts,
    CropCase=cropcase.CropCase,
    identify_resolution=identify_resolution,
)

In [None]:
# Calculate some extra variables
# TODO: Move these calculations to CropCase
#
# For each case, this adds to the case's CFT-level Dataset:
#   - "crop" coordinate (names from opts["crops_to_include"]) if not already present
#   - crop_cft_area / crop_cft_prod / crop_cft_yield: per-crop, per-CFT area (m2),
#     production (g) and yield (g/m2)
#   - crop_area / crop_prod / crop_yield: the same, summed over each crop's CFTs
#   - cft_crop: name of the crop each CFT belongs to ("" if not an included crop)
#   - cft_harv_area*, crop_harv_area*: harvested, immature-harvested and
#     failed-harvest area, per CFT and per crop
if not opts["crops_to_include"]:
    # Without at least one crop, the crop_cft_* arrays below are never created
    raise RuntimeError('opts["crops_to_include"] must contain at least one crop')

for case in case_list:
    # NOTE(review): writes below go through both `case_ds` and `case.cft_ds`, and the
    # groupby at the end reads case_ds["cft_crop"], which was written via case.cft_ds.
    # This assumes case.cft_ds is the same Dataset object on every access (a plain
    # attribute, not a property returning a copy) — TODO confirm.
    case_ds = case.cft_ds

    # Set up for adding cft_crop variable. CFTs that don't belong to any included
    # crop keep the "" placeholder.
    cft_crop_array = np.full(case_ds.sizes["cft"], "", dtype=object)

    for i, crop in enumerate(opts["crops_to_include"]):
        # Get data for CFTs of this crop
        pft_nums = case.crop_list[crop].pft_nums
        cft_ds = case_ds.sel(cft=pft_nums)

        # Save name of this crop for cft_crop variable
        for pft_num in pft_nums:
            cft_crop_array[np.where(case_ds["cft"].values == pft_num)] = crop

        # Get area
        cft_area = cft_ds["pfts1d_gridcellarea"] * cft_ds["pfts1d_wtgcell"]
        cft_area *= 1e6  # Convert km2 to m2
        cft_area.attrs["units"] = "m2"

        # Get production (labelled "g", which presumes YIELD_ANN is g/m2 — TODO confirm)
        cft_prod = cft_ds["YIELD_ANN"] * cft_area
        cft_prod.attrs["units"] = "g"

        # Setup crop_cft_* variables or append to them. Each crop contributes one
        # slice along a new leading "crop" dimension.
        cft_area_expanded = cft_area.expand_dims(dim="crop", axis=0)
        cft_prod_expanded = cft_prod.expand_dims(dim="crop", axis=0)
        if i == 0:

            # Add crop (names) variable/dimension to case_ds, if needed
            if "crop" not in case_ds:
                crop_da = xr.DataArray(
                    data=opts["crops_to_include"],
                    dims=["crop"],
                )
                case_ds["crop"] = crop_da

            # Define crop_cft_* variables
            crop_cft_area_da = xr.DataArray(
                data=cft_area_expanded,
            )
            crop_cft_prod_da = xr.DataArray(
                data=cft_prod_expanded,
            )
        else:
            # Append this crop's DataArrays to existing ones. Each crop selected a
            # different set of `cft` labels, so the combined arrays must span the
            # union of those CFTs (NaN where a CFT isn't part of a crop). join="outer"
            # is passed explicitly because xarray has announced that concat's default
            # join will change to "exact", which would raise here.
            crop_cft_area_da = xr.concat(
                [crop_cft_area_da, cft_area_expanded],
                dim="crop",
                join="outer",
            )
            crop_cft_prod_da = xr.concat(
                [crop_cft_prod_da, cft_prod_expanded],
                dim="crop",
                join="outer",
            )

    # Add crop_cft_* variables to case_ds
    case_ds["crop_cft_area"] = crop_cft_area_da
    case_ds["crop_cft_prod"] = crop_cft_prod_da

    # Calculate CFT-level yield
    case_ds["crop_cft_yield"] = crop_cft_prod_da / crop_cft_area_da
    case_ds["crop_cft_yield"].attrs["units"] = (
        crop_cft_prod_da.attrs["units"] + "/" + crop_cft_area_da.attrs["units"]
    )

    # Collapse CFTs to individual crops (sum skips the NaN padding from the outer join)
    case_ds["crop_area"] = crop_cft_area_da.sum(dim="cft", keep_attrs=True)
    case_ds["crop_prod"] = crop_cft_prod_da.sum(dim="cft", keep_attrs=True)

    # Calculate crop-level yield
    case_ds["crop_yield"] = case_ds["crop_prod"] / case_ds["crop_area"]
    case_ds["crop_yield"].attrs["units"] = (
        case_ds["crop_prod"].attrs["units"] + "/" + case_ds["crop_area"].attrs["units"]
    )

    # Save cft_crop variable
    case.cft_ds["cft_crop"] = xr.DataArray(
        data=cft_crop_array,
        dims=["cft"],
        coords={"cft": case_ds["cft"]},
    )

    # Area harvested. Planted area is NaN where the CFT has zero weight in the gridcell.
    # NOTE(review): assumes HARVEST_REASON_PERHARV is 0 for "no harvest in this slot",
    # 1 for a normal harvest, and >1 for other (immature) harvest reasons, and that
    # VALID_HARVEST is 0/1 — confirm against CTSM documentation.
    hr = case_ds["HARVEST_REASON_PERHARV"]
    cft_planted_area = (case_ds["pfts1d_gridcellarea"] * case_ds["pfts1d_wtgcell"]).where(case_ds["pfts1d_wtgcell"]>0) * 1e6  # convert km2 to m2
    cft_planted_area.attrs["units"] = "m2"
    case.cft_ds["cft_harv_area"] = (cft_planted_area * (hr > 0)).sum(dim="mxharvests")
    case.cft_ds["cft_harv_area_immature"] = (cft_planted_area * (hr > 1)).sum(dim="mxharvests")
    case.cft_ds["cft_harv_area_failed"] = (cft_planted_area * (1 - case_ds["VALID_HARVEST"]).where(hr > 0)).sum(dim="mxharvests")
    # Aggregate CFT-level harvest areas to crops via the cft_crop labels.
    # NOTE(review): CFTs with cft_crop == "" form their own group here; check that it
    # is dropped when aligned with the existing `crop` coordinate.
    case.cft_ds["crop_harv_area"] = case.cft_ds["cft_harv_area"].groupby(case_ds["cft_crop"]).sum(dim="cft").rename({"cft_crop": "crop"})
    case.cft_ds["crop_harv_area_immature"] = case.cft_ds["cft_harv_area_immature"].groupby(case_ds["cft_crop"]).sum(dim="cft").rename({"cft_crop": "crop"})
    case.cft_ds["crop_harv_area_failed"] = case.cft_ds["cft_harv_area_failed"].groupby(case_ds["cft_crop"]).sum(dim="cft").rename({"cft_crop": "crop"})


### 2.2 Import FAOSTAT

In [None]:
# Normalized ("long" format) FAOSTAT Production_Crops_Livestock table; the
# directory name indicates a 2025-02-25 download.
fao_file = os.path.join(
    opts["obs_data_dir"],
    "lnd", "analysis_datasets", "ungridded", "timeseries",
    "FAOSTAT", "Production_Crops_Livestock_2025-02-25", "norm",
    "Production_Crops_Livestock_E_All_Data_(Normalized).csv",
)

# Read the table, restricted to the analysis period
fao = faostat.FaostatProductionCropsLivestock(
    fao_file,
    yN=opts["end_year"],
    y1=opts["start_year"],
)

# TODO: Move all the following to FaostatProductionCropsLivestock class

# Extract production and harvested area, mapping FAOSTAT item names to CLM crops
fao_prod = fao.get_element(
    "Production",
    fao_to_clm_dict=opts["fao_to_clm_dict"],
)
fao_area = fao.get_element(
    "Area harvested",
    fao_to_clm_dict=opts["fao_to_clm_dict"],
)

# Only include where both production and area data are present
def drop_a_where_not_in_b(a, b):
    """Return ``a`` without the rows whose index labels do not occur in ``b``.

    Row order of ``a`` is preserved; ``a`` itself is not modified.
    """
    labels_only_in_a = a.index.difference(b.index)
    return a.drop(list(labels_only_in_a))


# Keep only index entries present in both the production and area tables
fao_prod = drop_a_where_not_in_b(fao_prod, fao_area)
fao_area = drop_a_where_not_in_b(fao_area, fao_prod)
if not fao_prod.index.equals(fao_area.index):
    raise RuntimeError("Mismatch of prod and area indices after trying to align them")

# Don't allow production where no area
is_bad = (fao_prod["Value"] > 0) & (fao_area["Value"] == 0)
where_bad = np.where(is_bad)[0]
# The rejected rows are kept in these variables for interactive inspection
bad_prod = fao_prod.iloc[where_bad]
bad_area = fao_area.iloc[where_bad]
fao_prod = fao_prod[~is_bad]
fao_area = fao_area[~is_bad]
if not fao_prod.index.equals(fao_area.index):
    raise RuntimeError(
        "Mismatch of prod and area indices after disallowing production where no area"
    )

# Get yield. The yield unit is built from one production unit and one area unit,
# so each table must use a single unit throughout; otherwise the yield values and
# their label would be silently wrong (and empty tables would have no unit at all).
prod_units = fao_prod["Unit"].unique()
area_units = fao_area["Unit"].unique()
if len(prod_units) != 1 or len(area_units) != 1:
    raise RuntimeError(
        "Expected exactly one unit each for FAOSTAT production and area; got "
        f"{list(prod_units)} and {list(area_units)}"
    )
fao_yield = fao_prod.copy()
fao_yield["Element"] = "Yield"
fao_yield["Unit"] = "/".join([prod_units[0], area_units[0]])
fao_yield["Value"] = fao_prod["Value"] / fao_area["Value"]

### 2.3 Import EarthStat (essentially gridded FAOSTAT)

In [None]:
# Pick up any edits to the earthstat module without restarting the kernel
importlib.reload(earthstat)

# Directory with the FAO-based EarthStat gridded yield data (multiple grids)
earthstat_dir = os.path.join(
    opts["obs_data_dir"], "lnd", "analysis_datasets", "multi_grid", "annual", "FAO-EarthStatYields"
)

# Load EarthStat, passing the resolutions reported by the cases (presumably so
# the data can be matched to each case's grid)
earthstat_data = earthstat.EarthStat(earthstat_dir, case_list.resolutions, opts)

## 3. Time series figures

In [None]:
importlib.reload(crop_timeseries_figs)

# One set of time series figures per variable, each given its matching FAOSTAT table
for _which, _fao_table in (
    ("yield", fao_yield),
    ("prod", fao_prod),
    ("area", fao_area),
):
    crop_timeseries_figs.main(_which, earthstat_data, case_list, _fao_table, opts)

## 4. Yield, production, and area maps

In [None]:
importlib.reload(caem)

# Yield maps for the CLM cases and EarthStat
caem.clm_and_earthstat_maps(
    opts=opts,
    utils=utils,
    earthstat_data=earthstat_data,
    case_list=case_list,
    which="yield",
)

In [None]:
importlib.reload(caem)

# Production maps for the CLM cases and EarthStat
caem.clm_and_earthstat_maps(
    opts=opts,
    utils=utils,
    earthstat_data=earthstat_data,
    case_list=case_list,
    which="prod",
)

In [None]:
importlib.reload(caem)

# Area maps for the CLM cases and EarthStat
caem.clm_and_earthstat_maps(
    opts=opts,
    utils=utils,
    earthstat_data=earthstat_data,
    case_list=case_list,
    which="area",
)

## 5. Immature and failed harvests

In [None]:
importlib.reload(plotting_utils)

for crop in opts["crops_to_include"]:
    # One map panel per case for this crop
    results = plotting_utils.ResultsMaps(case_list.mapfig_layout)
    for case in case_list:
        crop_ds = case.cft_ds.sel(crop=crop)
        # Immature share of harvested area, using areas summed over all years
        # (NaN where nothing was ever harvested)
        immature_total = crop_ds["crop_harv_area_immature"].sum(dim="time")
        harvested_total = crop_ds["crop_harv_area"].sum(dim="time")
        crop_ds["frac_immature_harv_timemean"] = immature_total / harvested_total
        map_clm = utils.grid_one_variable(crop_ds, "frac_immature_harv_timemean")
        map_clm.attrs["units"] = "unitless"
        map_clm.name = "Fraction immature harvests"
        results[case.name] = map_clm
    results.plot(case_name_list=case_list.names, crop=crop)

In [None]:
importlib.reload(plotting_utils)

for crop in opts["crops_to_include"]:
    # One map panel per case for this crop
    results = plotting_utils.ResultsMaps(case_list.mapfig_layout)
    for case in case_list:
        crop_ds = case.cft_ds.sel(crop=crop)
        # Failed share of harvested area, using areas summed over all years
        # (NaN where nothing was ever harvested)
        failed_total = crop_ds["crop_harv_area_failed"].sum(dim="time")
        harvested_total = crop_ds["crop_harv_area"].sum(dim="time")
        crop_ds["frac_failed_harv_timemean"] = failed_total / harvested_total
        map_clm = utils.grid_one_variable(crop_ds, "frac_failed_harv_timemean")
        map_clm.attrs["units"] = "unitless"
        map_clm.name = "Fraction failed harvests"
        results[case.name] = map_clm
    results.plot(case_name_list=case_list.names, crop=crop)