In [None]:
import logging
import sys

# Send ERROR-and-above records to error_log.txt; filemode='w' starts the file
# afresh when this configuration is applied.
logging.basicConfig(filename='error_log.txt', level=logging.ERROR, filemode='w',
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')

# Disabled alternative that logs to both the file and stdout. It used to be
# kept in a bare triple-quoted string, which a notebook echoes as the cell's
# output when it is the last expression of the cell; it is a comment now.
# NOTE(review): basicConfig documents `filemode` only for use together with
# `filename`, so confirm the argument set before re-enabling this variant.
#
# logging.basicConfig(
#     filemode='w',
#     level=logging.ERROR,
#     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
#     handlers=[
#         logging.FileHandler('error_log.txt'),
#         logging.StreamHandler(sys.stdout)
#     ]
# )
#
# logger = logging.getLogger(__name__)

In [None]:
from pcmdi_metrics.mean_climate.lib_unified import (
    get_unique_bases, process_dataset, calculate_and_save_metrics, extract_info_from_model_catalogue, 
    get_ref_catalogue, get_model_catalogue, 
    multi_level_dict, print_dict, write_to_json
)
import os
import datetime


from pcmdi_metrics.utils import create_target_grid

In [None]:
# Variables to evaluate (optional). When given, this list is prioritized over
# model_catalogue.json. When not given, all variables common to
# ref_catalogue.json and model_catalogue.json are used.
variables = [
    # "pr", "ua-200", "ua-850", "va-200", "ta-850",
    "rsdt",
    "rsut",
    "rsutcs",
    "rlut",
    "rlutcs",
    "rstcre",
    "rltcre",
    "rt",
    "rst",
]

# Alternative selection, currently disabled:
# variables = ["ua-200", "ua-850", "va-200"]

# Template for model data file paths (optional). When given, it is prioritized
# over model_catalogue.json.
model_data_path_template = "/home/data/%(model)/%(var)/%(model)_%(run)_%(var)_blabla.nc"

# Models to evaluate (optional). When given, prioritized over
# model_catalogue.json; when not given, all models in model_catalogue.json
# are used.
models = ["model-a", "model-b"]

# Runs to evaluate per model (optional). When given, prioritized over
# model_catalogue.json; when not given, all runs in model_catalogue.json
# are used.
models_runs_dict = {
    model_name: ["r1", "r2"]
    for model_name in ("model-a", "model-b", "model-c")
}

In [None]:
# Every output location below lives under this one directory.
_output_root = "/p/user_pub/climate_work/lee1043/temporary/mean_climate_workflow_refactorization/output"

# Interim output locations, keyed by data type ("ref" / "model"). Each entry
# holds a "path_ac" and a "path_ac_interp" template ending in a %(var)
# placeholder.
interim_output_path_dict = {
    data_type: {
        "path_ac": f"{_output_root}/clims_{data_type}/%(var)",
        "path_ac_interp": f"{_output_root}/clims_{data_type}_interp/%(var)",
    }
    for data_type in ("ref", "model")
}

# Directory that receives the output JSON files.
output_path = f"{_output_root}/json"

# Regions to evaluate; the default region list is used when this is empty.
regions = ["NHEX", "SHEX"]

# Resolution string handed to create_target_grid.
target_grid = "2.5x2.5"

# Catalogue files for the reference and the model datasets.
ref_catalogue_file_path = "/p/user_pub/PCMDIobs/catalogue/obs4MIPs_PCMDI_monthly_byVar_catalogue_v20240716.json"
model_catalogue_file_path = "model_catalogue.json"

# Optional: needed if the ref_catalogue file does not include the entire
# directory path.
ref_data_head = "/p/user_pub/PCMDIobs"

In [None]:
# Variable names handed to process_dataset as its radiation-diagnostic list.
rad_diagnostic_variables = [
    "rt",
    "rst",
    "rstcre",
    "rltcre",
]

# Region list used as the fallback when `regions` is empty.
default_regions = [
    "global",
    "NHEX",
    "SHEX",
    "TROPICS",
]

In [None]:
# Keep the user's region selection; an empty or None selection falls back to
# the default region list.
regions = regions or default_regions

In [None]:
# Load the reference catalogue. ref_data_head is passed along for catalogue
# files that do not include the entire directory path.
refs_dict = get_ref_catalogue(ref_catalogue_file_path, ref_data_head)
# Load the model catalogue together with the user-provided selections
# (variables, models, runs, path template).
# NOTE(review): per the configuration comments the user selections are meant
# to take priority over model_catalogue.json -- confirm in get_model_catalogue.
models_dict = get_model_catalogue(model_catalogue_file_path, variables, models, models_runs_dict, model_data_path_template)

In [None]:
# When any of the three user selections is None, obtain all three from
# extract_info_from_model_catalogue (which receives both catalogues).
if variables is None or models is None or models_runs_dict is None:
    variables, models, models_runs_dict = extract_info_from_model_catalogue(
        variables, models, models_runs_dict, refs_dict, models_dict
    )

In [None]:
# Create the target grid from the configured resolution string; the result is
# later handed to process_dataset as `common_grid`.
common_grid = create_target_grid(target_grid_resolution=target_grid)

In [None]:
# Notebook-level state shared by process_references, process_models and main.
encountered_variables = set()  # variable names main() has started processing
ac_ref_dict, ac_model_run_dict, metrics_dict = (multi_level_dict() for _ in range(3))

# Mapping returned by get_unique_bases for the selected variables; main()
# looks up each variable's levels in it.
variables_level_dict = get_unique_bases(variables)
variables_unique = list(variables_level_dict)

print('variables_unique:', variables_unique)
print('variables_level_dict:', variables_level_dict)




def process_references(var, refs, rad_diagnostic_variables, levels, common_grid, start, end, version):
    """Run ``process_dataset`` for every reference dataset of one variable.

    Besides its arguments, this function reads the notebook-level names
    ``refs_dict``, ``ac_ref_dict``, ``encountered_variables`` and
    ``interim_output_path_dict`` and hands them to ``process_dataset``.

    Args:
        var: Variable name being processed.
        refs: Iterable of reference dataset names for ``var``.
        rad_diagnostic_variables: Forwarded unchanged to ``process_dataset``.
        levels: Forwarded unchanged to ``process_dataset``.
        common_grid: Forwarded unchanged to ``process_dataset``.
        start: Forwarded as the ``start`` keyword.
        end: Forwarded as the ``end`` keyword.
        version: Forwarded as the ``version`` keyword.

    Raises:
        Any exception raised by ``process_dataset``. Nothing is caught here,
        so a failing reference stops the remaining references of this
        variable and the error surfaces in the caller.
    """
    # Debug aid: restrict to a few references by filtering `refs` here, e.g.
    # refs = [ref for ref in refs if ref in ["GPCP-2-3", "ERA-5", "CERES-EBAF-4-2"]]

    for ref in refs:
        print(f"=== var, ref: {var}, {ref}")
        process_dataset(
            var, ref, refs_dict, ac_ref_dict, rad_diagnostic_variables,
            encountered_variables, levels, common_grid, interim_output_path_dict["ref"], data_type="ref",
            start=start, end=end,
            repair_time_axis=True,
            overwrite_output_ac=True,
            version=version
        )
            

def process_models(var, models, models_runs_dict, rad_diagnostic_variables, levels, common_grid, refs, start, end, version):
    """Process every model run of one variable, compute metrics, write JSON.

    For each (model, run) pair this calls ``process_dataset`` and then
    ``calculate_and_save_metrics`` once per level. After all runs, one JSON
    file per level is written from ``metrics_dict``.

    Besides its arguments, this function reads the notebook-level names
    ``models_dict``, ``ac_model_run_dict``, ``encountered_variables``,
    ``interim_output_path_dict``, ``regions``, ``ac_ref_dict``,
    ``output_path``, ``refs_dict`` and ``metrics_dict``.

    Args:
        var: Variable name being processed.
        models: Iterable of model names.
        models_runs_dict: Mapping from model name to its list of runs; every
            entry of ``models`` must be a key.
        rad_diagnostic_variables: Forwarded unchanged to ``process_dataset``.
        levels: Iterable of levels; ``None`` stands for "no level" and leaves
            the variable name unsuffixed in the output file name.
        common_grid: Forwarded unchanged to ``process_dataset``.
        refs: Reference dataset names forwarded to
            ``calculate_and_save_metrics``.
        start: Forwarded as the ``start`` keyword.
        end: Forwarded as the ``end`` keyword.
        version: Forwarded as the ``version`` keyword.

    A failure in one (model, run) pair is reported and the loop continues with
    the next pair.
    """
    for model in models:
        for run in models_runs_dict[model]:
            try:
                process_dataset(
                    var, (model, run), models_dict, ac_model_run_dict,
                    rad_diagnostic_variables, encountered_variables, levels, common_grid,
                    interim_output_path_dict["model"], data_type="model",
                    start=start, end=end,
                    version=version
                )
                for level in levels:
                    ac_model_run_level_interp = ac_model_run_dict[var][model][run][level]
                    calculate_and_save_metrics(var, model, run, level, regions, refs, ac_ref_dict, ac_model_run_level_interp, output_path, refs_dict, metrics_dict)
            except Exception as e:
                print(f'Error from process_models for {var} {model} {run}:', e)
                # The print above shows only the message; keep the traceback
                # in the log so the failure can be diagnosed afterwards.
                logging.exception(f"Error from process_models for {var} {model} {run}")

    for level in levels:
        # Output files are keyed "<var>" or "<var>-<level>".
        var_key = var if level is None else f"{var}-{level}"
        write_to_json(metrics_dict[var_key], os.path.join(output_path, f"output_{var_key}.json"))


def main(all_ref_variables=True, start="1981-01", end="2013-12", version=None):
    """Process the reference datasets of each selected variable.

    Reads the notebook-level names ``refs_dict``, ``variables_unique``,
    ``variables_level_dict``, ``encountered_variables``,
    ``rad_diagnostic_variables`` and ``common_grid``.

    Args:
        all_ref_variables: When True (default), loop over every key of
            ``refs_dict`` in sorted order; otherwise loop over the
            notebook-level ``variables_unique``.
        start: Forwarded to ``process_references`` as ``start``.
        end: Forwarded to ``process_references`` as ``end``.
        version: Version identifier forwarded to ``process_references``.
            When None, the current date formatted as ``v%Y%m%d`` is used.

    An error raised while handling one variable is printed and logged with
    its traceback; the loop then continues with the next variable.
    """
    # Set version identifier using the current date if not provided
    if version is None:
        version = datetime.datetime.now().strftime("v%Y%m%d")
    print("version:", version)

    # Use a separate local name: assigning to `variables_unique` inside this
    # function would make it function-local, so the fallback to the
    # notebook-level list would raise UnboundLocalError.
    if all_ref_variables:
        target_variables = sorted(refs_dict)
    else:
        target_variables = variables_unique

    for var in target_variables:
        try:
            print("var:", var)
            encountered_variables.add(var)
            # NOTE(review): variables_level_dict is built from `variables`, so
            # a variable that only exists in refs_dict may be missing here;
            # that lookup error is caught and reported below -- confirm intent.
            levels = variables_level_dict[var]

            if var in refs_dict:
                refs = refs_dict[var].keys()
                process_references(var, refs, rad_diagnostic_variables, levels, common_grid, start, end, version)

            # Model processing is currently disabled.
            # NOTE(review): `refs` is only bound when var is in refs_dict;
            # guard for that before re-enabling the call below.
            # process_models(var, models, models_runs_dict, rad_diagnostic_variables, levels, common_grid, refs, start, end, version)
        except Exception as e:
            print(f'Error from main for {var}:', e)
            # The print above shows only the message; keep the traceback in
            # the log as well.
            logging.exception(f"Error from main for {var}")

# Run the workflow when this code is executed directly, which includes running
# the cell in a notebook (where __name__ is "__main__"), but not on import.
if __name__ == "__main__":
    main()


In [None]:
# Inspect the loaded reference catalogue.
print_dict(refs_dict)

In [None]:
# Inspect the loaded model catalogue.
print_dict(models_dict)

In [None]:
# Scratch check: open a single CERES-EBAF-4-2 rsdt file directly with xcdat.
f = "/p/user_pub/PCMDIobs/obs4MIPs/NASA-LaRC/CERES-EBAF-4-2/mon/rsdt/gn/v20240513/rsdt_mon_CERES-EBAF-4-2_RSS_gn_200003-202310.nc"

# NOTE(review): this import sits mid-notebook; consider moving it to the
# import cell at the top so dependencies are visible in one place.
import xcdat as xc
ds = xc.open_dataset(f)