In [None]:
import os
import glob
import numpy as np
import xarray as xr

# Plotting utils
import matplotlib.pyplot as plt

In [None]:
# Root directory holding one sub-directory of land output per case
cesm_output_dir = os.path.join(os.path.sep, *("glade", "work", "samrabin"))

# Case selection: set exactly ONE of case_name_list / case_name.
# Each name is the full case name, used both as the sub-directory of
# cesm_output_dir and as the prefix of the individual history filenames, e.g.:
#   case_name_list = [
#       "b.e23_alpha16b.BLT1850.ne30_t232.054",
#       "b.e30_beta02.BLT1850.ne30_t232.104",
#   ]
case_name_list = None
case_name = "ctsm53019_f09_BNF_hist"

# History-stream tag; sits between "<case>.clm2" and the rest of the filename
clm_file_h = ".h5."

# Individual CFTs to process
cfts_to_include = ["spring_wheat", "irrigated_spring_wheat"]

## NOT WORKING YET
# Crops made by combining CFTs.
# Currently disabled options: corn, soybean, sugarcane, wheat
crops_to_include = ["cotton", "rice"]

In [None]:
# Set up directory for any scratch output: a "CUPiD_scratch" folder under
# $SCRATCH when that variable exists, otherwise the current directory.
if "SCRATCH" in os.environ:
    cupid_temp = os.path.join(os.environ["SCRATCH"], "CUPiD_scratch")
    os.makedirs(cupid_temp, exist_ok=True)
else:
    cupid_temp = "."

# Exactly one of case_name_list / case_name must be given: both set or both
# unset (None / empty) is an error.
if bool(case_name_list) == bool(case_name):
    raise RuntimeError("Specify exactly one of case_name_list or case_name")
if not case_name_list:
    # Normalize the single-case form so everything downstream loops over a list
    case_name_list = [case_name]

# Last dot-separated component of each case name (the whole name if no dots)
short_names = [case.split(".")[-1] for case in case_name_list]

# Length of the PFT numbering assumed when converting CFT numbers to PFT numbers
# NOTE(review): confirm this matches the configuration that produced the output
n_pfts = 78

In [None]:
# Open one multi-file dataset per case; ds_list is ordered like case_name_list
ds_list = []
for case in case_name_list:
    print(f"Importing {case}...")

    # Glob pattern matching this case's history files for the selected stream:
    # <cesm_output_dir>/<case>/lnd/hist/<case>.clm2<clm_file_h>*.nc
    file_pattern = os.path.join(
        cesm_output_dir,
        case,
        "lnd",
        "hist",
        case + ".clm2" + clm_file_h + "*.nc",
    )

    # glob order is arbitrary, so sort; keep a plain list of path strings
    file_list = sorted(glob.glob(file_pattern))
    if not file_list:
        raise FileNotFoundError("No files found matching pattern: " + file_pattern)

    # Open files
    ds_list.append(xr.open_mfdataset(file_list, decode_times=True))
print("Done.")

### Get CFT info

In [None]:
class Cft:
    """
    One crop functional type (CFT) and its position in the PFT numbering.

    Attributes set in __init__:
        name: CFT name.
        cft_num: 1-indexed CFT number.
    Attributes filled in later:
        pft_num: 1-indexed PFT number, set by update_pft().
        pft_ind: 0-indexed PFT index (pft_num - 1), set by update_pft().
        where: integer indices along the dataset's pft dimension belonging to
            this CFT, set by get_where().
    """

    def __init__(self, name, cft_num):
        self.name = name

        # 1-indexed in the FORTRAN style
        self.cft_num = cft_num
        self.pft_num = None  # Need to know max cft_num

        # 0-indexed in the Python style
        self.pft_ind = None  # Need to know pft_num
        self.where = None

    def __str__(self):
        """Multi-line summary: the name followed by the three index values."""
        return "\n".join(
            [
                self.name + ":",
                f"   cft_num: {self.cft_num}",
                f"   pft_num: {self.pft_num}",
                f"   pft_ind: {self.pft_ind}",
            ]
        )

    def update_pft(self, n_non_crop_pfts):
        """
        Set pft_num and pft_ind from the number of non-crop PFTs.

        You don't know n_non_crop_pfts until after reading in all CFTs, so
        this function gets called once that's done in CftList.__init__().
        """
        self.pft_num = n_non_crop_pfts + self.cft_num
        self.pft_ind = self.pft_num - 1

    def get_where(self, ds):
        """
        Get the indices on the pft dimension corresponding to this CFT.

        Reads ds["pfts1d_itype_veg"] (first time step only, if it has a time
        dimension) and finds the positions whose value equals pft_num. The
        result is stored in self.where and also returned.

        Raises:
            RuntimeError: if update_pft() has not been called yet.
        """
        if self.pft_num is None:
            raise RuntimeError(
                "get_where() can't be run until after calling Cft.update_pft()"
            )
        pfts1d_itype_veg = ds["pfts1d_itype_veg"]
        if "time" in pfts1d_itype_veg.dims:
            pfts1d_itype_veg = pfts1d_itype_veg.isel(time=0)
        # np.where returns a tuple with one index array per dimension; take the first
        self.where = np.where(pfts1d_itype_veg.values == self.pft_num)[0].astype(int)
        return self.where


class CftList:
    """
    The CFTs of interest in one dataset, with their PFT numbering resolved.

    Built from the dataset's global attributes whose keys start with "cft_":
    the rest of the key is the CFT name and the attribute value is used as its
    CFT number. Supports indexing (and therefore iteration), printing, and
    value-based ==/!= comparison.
    """

    def __init__(self, ds, n_pfts, cfts_to_include):
        """
        Args:
            ds: dataset providing .attrs (for the "cft_*" entries) and the
                "pfts1d_itype_veg" variable read by Cft.get_where().
            n_pfts: length of the PFT numbering, used to work out how many
                PFTs come before the first CFT.
            cfts_to_include: names of the CFTs to keep.

        Raises:
            ValueError: if ds has no "cft_*" attributes.
        """
        # Get list of all possible CFTs; key[4:] strips the "cft_" prefix
        self.cft_list = [
            Cft(key[4:], value)
            for key, value in ds.attrs.items()
            if key.startswith("cft_")
        ]
        if not self.cft_list:
            raise ValueError(
                "No 'cft_*' attributes found in dataset; can't build CFT list"
            )

        # Figure out PFT indices
        max_cft_num = max(x.cft_num for x in self.cft_list)
        n_non_crop_pfts = n_pfts - max_cft_num + 1  # Incl. unvegetated
        for cft in self.cft_list:
            cft.update_pft(n_non_crop_pfts)

        # Only include CFTs we care about
        self.cft_list = [x for x in self.cft_list if x.name in cfts_to_include]

        # Figure out where the pft index is each CFT
        for cft in self.cft_list:
            cft.get_where(ds)

    def __getitem__(self, index):
        return self.cft_list[index]

    def __eq__(self, other):
        """
        Two lists are equal when they hold the same CFTs in the same order,
        with matching name, cft_num, pft_num, and pft-dimension indices.

        Without this, ==/!= would compare object identity, so two lists built
        from different datasets would never compare equal. The `where` indices
        are included because callers apply one list's indices to every dataset.
        Note that defining __eq__ makes instances unhashable, like a plain list.
        """
        if not isinstance(other, CftList):
            return NotImplemented
        if len(self.cft_list) != len(other.cft_list):
            return False
        for mine, theirs in zip(self.cft_list, other.cft_list):
            same_ids = (mine.name, mine.cft_num, mine.pft_num) == (
                theirs.name,
                theirs.cft_num,
                theirs.pft_num,
            )
            if not same_ids or not np.array_equal(mine.where, theirs.where):
                return False
        return True

    def __str__(self):
        """The summaries of all member CFTs, one after another."""
        return "\n".join(str(cft) for cft in self.cft_list)

In [None]:
# Build a CftList for every case and require that all of them match the first.
# NOTE(review): `!=` only compares contents if CftList defines __eq__;
# otherwise it compares object identity and any second case raises.
for i, case in enumerate(case_name_list):
    ds = ds_list[i]
    this_cftlist = CftList(ds, n_pfts, cfts_to_include)

    if i == 0:
        # The first case provides the reference list used from here on
        cft_list = this_cftlist
        continue

    if this_cftlist != cft_list:
        raise NotImplementedError(
            "This code can't handle cases with different CFT lists"
        )
print(cft_list)

## Make time series

In [None]:
# ↓ Incorrectly assumes that all gridcells have the same area
# Mean gridcell area of the last case at the first time step, applied to every
# case; taken from ds_list explicitly instead of a loop variable left over
# from an earlier cell.
gridcell_area = ds_list[-1]["area"].isel(time=0).mean() * 1e6  # *1e6 to convert km2 to m2

# One figure per CFT, with one line per case
for cft in cft_list:
    for case, case_ds in zip(case_name_list, ds_list):
        # Keep only the pft-dimension entries belonging to this CFT. A new
        # name is used so the full dataset `ds` from earlier cells stays intact.
        cft_ds = case_ds.isel(pft=cft.where)

        cft_area = gridcell_area * cft_ds["pfts1d_wtgcell"]
        cft_production = cft_area * cft_ds["GRAINC_TO_FOOD_ANN"]

        # Plot data
        cft_production_ts = cft_production.sum(dim="pft")
        cft_production_ts *= 1e-6 * 1e-6  # Convert gC to MtC
        cft_production_ts.attrs["units"] = "Mt C"
        cft_production_ts.plot(label=case)

    # Finish plot
    if len(ds_list) > 1:
        # Only needed to tell the cases apart
        plt.legend()
    plt.title(cft.name.replace("_", " "))
    plt.show()

## ↓ COMBINED-CFT CROPS; NOT WORKING

In [None]:
class Crop:
    """
    A crop made up of every CFT whose name contains the crop's name.

    Attributes:
        name: crop name used for the substring match.
        cft_list: the matching CFT objects, in the order given.
        cft_names: names of the matching CFTs.
        pft_inds: pft_ind of each matching CFT.
        where: sorted array of the pft-dimension indices of all matching CFTs.
        ds: placeholder, initialised to None.
    """

    def __init__(self, name, cft_list, ds):
        self.name = name

        # Get CFTs included in this crop (substring match on the CFT name)
        self.cft_list = [member for member in cft_list if self.name in member.name]

        # Get information for all CFTs in this crop
        self.cft_names = [member.name for member in self.cft_list]
        self.pft_inds = [member.pft_ind for member in self.cft_list]

        # Gather each member's indices (get_where() also refreshes the
        # member's own `where`), then sort the combined array
        indices = np.array([], dtype=np.int64)
        for member in self.cft_list:
            indices = np.append(indices, member.get_where(ds))
        self.where = np.sort(indices)

        # Placeholders
        self.ds = None

    def __str__(self):
        """One line: the crop name, then each member as 'name (pft_ind)'."""
        members = [f"{member.name} ({member.pft_ind})" for member in self.cft_list]
        return f"{self.name}: " + ", ".join(members)


# Build one Crop per requested crop name, then show a one-line summary of each
crop_list = []
for crop_name in crops_to_include:
    crop_list.append(Crop(crop_name, cft_list, ds))

for crop in crop_list:
    print(crop)