In [1]:
from lib_unified import get_unique_bases, process_dataset, calculate_and_save_metrics, extract_info_from_model_catalogue, get_ref_catalogue, get_model_catalogue
from lib_unified_dict import multi_level_dict, print_dict, write_to_json

import os

from pcmdi_metrics.utils import create_target_grid

In [2]:
# --- User selections -------------------------------------------------------
# Each setting in this cell is optional. When given, it takes priority over
# model_catalogue.json.

# Variables to evaluate. When not given, all variables common to
# ref_catalogue.json and model_catalogue.json are used.
# NOTE(review): a "-<number>" suffix (e.g. "ua-200") looks like a level
# selector, judging by the printed variables_dict -- confirm in lib_unified.
variables = [
    "pr",
    "ua-200",
    "ua-850",
    "va-200",
    "rsdt",
    "rsut",
    "rsutcs",
    "rlut",
    "rlutcs",
    "rstcre",
    "rltcre",
    "rt",
    "rst",
]

# Path template for the model input files.
model_data_path_template = "/home/data/%(model)/%(var)/%(model)_%(run)_%(var)_blabla.nc"

# Models to evaluate. When not given, all models in model_catalogue.json are used.
models = ["model-a", "model-b"]

# Runs to evaluate for each model. When not given, all runs in
# model_catalogue.json are used.
runs_dict = {
    "model-a": ["r1", "r2"],
    "model-b": ["r1", "r2"],
    "model-c": ["r1", "r2"],
}

In [3]:
# --- Paths and evaluation options -------------------------------------------

# Directory templates for intermediate files, keyed by dataset kind.
# NOTE(review): "%(var)" is presumably replaced with the variable name by
# lib_unified -- confirm.
interim_output_path_dict = {
    "ref": "clims_ref/%(var)",
    "ref_interp": "clims_ref_interp/%(var)",
    "model": "clims_mod/%(var)",
    "model_interp": "clims_mod_interp/%(var)",
}

# Directory that receives the per-variable "output_<var>.json" files.
output_path = "./output"

# Regions to evaluate; an empty value falls back to default_regions.
regions = ["NHEX", "SHEX"]

# Resolution string handed to create_target_grid.
target_grid = "2.5x2.5"

# Catalogue files describing the reference and model datasets.
ref_catalogue_file_path = "ref_catalogue.json"
model_catalogue_file_path = "model_catalogue.json"

In [4]:
# Variable names forwarded to process_dataset as "rad_diagnostic_variables".
# NOTE(review): presumably these are derived from other radiation fields
# rather than read directly -- confirm in lib_unified.
rad_diagnostic_variables = [
    "rt",
    "rst",
    "rstcre",
    "rltcre",
]

# Regions used when the user leaves `regions` empty.
default_regions = [
    "global",
    "NHEX",
    "SHEX",
    "TROPICS",
]

In [5]:
# Fall back to the default region list when no regions were selected
# (None or an empty list).
regions = regions or default_regions

In [6]:
# Load the reference catalogue, then the model catalogue. The user selections
# (variables, models, runs, path template) are handed to the model-catalogue
# loader together with the catalogue file path.
refs_dict = get_ref_catalogue(ref_catalogue_file_path)
models_dict = get_model_catalogue(
    model_catalogue_file_path,
    variables,
    models,
    runs_dict,
    model_data_path_template,
)

In [7]:
# If any of the three selections was left as None, let the catalogues supply
# the values; all three names are rebound from the helper's return value.
if variables is None or models is None or runs_dict is None:
    variables, models, runs_dict = extract_info_from_model_catalogue(
        variables,
        models,
        runs_dict,
        refs_dict,
        models_dict,
    )

In [8]:
# Build the common target grid from the resolution string configured in
# `target_grid`. The resulting grid is forwarded to process_dataset for both
# the reference and the model datasets.
common_grid = create_target_grid(target_grid_resolution=target_grid)

# NOTE(review): a placeholder comment here used to read "implement grid
# creation here", but the call above already creates the grid -- confirm
# nothing else was planned for this cell.

In [9]:
# Shared state read and updated by process_references / process_models / main.
encountered_variables = set()           # variable names visited so far
ac_ref_dict = multi_level_dict()        # passed to process_dataset for reference data
ac_model_run_dict = multi_level_dict()  # passed to process_dataset for model data; read as [var][model][run][level]
metrics_dict = multi_level_dict()       # metrics container; metrics_dict[var] is written to JSON

# Map each base variable name to its list of levels, then keep the base names
# as a list.
variables_dict = get_unique_bases(variables)
variables_unique = list(variables_dict)

print("variables_unique:", variables_unique)
print("variables_dict:", variables_dict)


def process_references(var, refs, rad_diagnostic_variables, levels, common_grid):
    """Process every reference dataset of one variable.

    Each reference is handed to ``process_dataset`` with ``data_type="ref"``,
    together with the module-level ``refs_dict``, ``ac_ref_dict``,
    ``encountered_variables`` and ``interim_output_path_dict``.

    Processing is best-effort: a failure on one reference is reported on
    stdout and the remaining references are still processed.

    Parameters
    ----------
    var : str
        Variable name being processed.
    refs : iterable
        Reference dataset names to process for ``var``.
    rad_diagnostic_variables : list
        Forwarded unchanged to ``process_dataset``.
    levels : list
        Levels of ``var``; forwarded unchanged to ``process_dataset``.
    common_grid
        Target grid; forwarded unchanged to ``process_dataset``.

    Returns
    -------
    None
    """
    for ref in refs:
        try:
            process_dataset(
                var,
                ref,
                refs_dict,
                ac_ref_dict,
                rad_diagnostic_variables,
                encountered_variables,
                levels,
                common_grid,
                interim_output_path_dict,
                data_type="ref",
            )
        except Exception as e:
            # Report which reference/variable failed and the exception type:
            # str(KeyError) is only the quoted key, which is unreadable
            # without this context.
            print(f"Failed to process reference {ref!r} for variable {var!r}: {type(e).__name__}: {e}")
            

def process_models(var, models, runs_dict, rad_diagnostic_variables, levels, common_grid, refs):
    """Process every model run of one variable and compute its metrics.

    For each ``(model, run)`` pair the dataset is handed to ``process_dataset``
    with ``data_type="model"``. Then, for each level, the entry
    ``ac_model_run_dict[var][model][run][level]`` is passed to
    ``calculate_and_save_metrics``. Finally ``metrics_dict[var]`` is written
    to ``output_<var>.json`` under the module-level ``output_path``.

    Processing is best-effort and every failure is reported on stdout:
    * a model without an entry in ``runs_dict`` is skipped;
    * a run whose dataset cannot be processed is skipped;
    * a level whose metrics fail does not prevent the other levels of the
      same run from being evaluated.

    Parameters
    ----------
    var : str
        Variable name being processed.
    models : iterable
        Model names to process.
    runs_dict : dict
        Maps each model name to an iterable of run names.
    rad_diagnostic_variables : list
        Forwarded unchanged to ``process_dataset``.
    levels : list
        Levels of ``var`` to evaluate.
    common_grid
        Target grid; forwarded unchanged to ``process_dataset``.
    refs : iterable
        Reference dataset names; forwarded to ``calculate_and_save_metrics``.

    Returns
    -------
    None
    """
    for model in models:
        # A model missing from runs_dict must not abort the whole pipeline.
        if model not in runs_dict:
            print(f"No runs listed for model {model!r}; skipping it for variable {var!r}.")
            continue

        for run in runs_dict[model]:
            try:
                process_dataset(
                    var,
                    (model, run),
                    models_dict,
                    ac_model_run_dict,
                    rad_diagnostic_variables,
                    encountered_variables,
                    levels,
                    common_grid,
                    interim_output_path_dict,
                    data_type="model",
                )
            except Exception as e:
                print(
                    f"Failed to process model {model!r}, run {run!r} for variable {var!r}: "
                    f"{type(e).__name__}: {e}"
                )
                # No processed data for this run, so there is nothing to score.
                continue

            for level in levels:
                # Each level is isolated so one failing level does not
                # suppress the metrics of the others.
                try:
                    ac_model_run_level_interp = ac_model_run_dict[var][model][run][level]
                    calculate_and_save_metrics(
                        var,
                        model,
                        run,
                        level,
                        regions,
                        refs,
                        ac_ref_dict,
                        ac_model_run_level_interp,
                        output_path,
                        metrics_dict,
                    )
                except Exception as e:
                    print(
                        f"Failed to compute metrics for model {model!r}, run {run!r}, "
                        f"variable {var!r}, level {level!r}: {type(e).__name__}: {e}"
                    )

    json_path = os.path.join(output_path, f"output_{var}.json")
    # Make sure the target directory exists even when no metrics call ran
    # successfully; "or '.'" covers an empty output_path.
    os.makedirs(os.path.dirname(json_path) or ".", exist_ok=True)
    write_to_json(metrics_dict[var], json_path)


def main():
    """Run the reference and model processing for every unique variable.

    For each name in ``variables_unique`` the variable is announced on
    stdout and recorded in ``encountered_variables``. Its levels are looked
    up in ``variables_dict`` and its reference names in ``refs_dict``.
    References are processed first, then the models.
    """
    for var in variables_unique:
        print("var:", var)
        encountered_variables.add(var)

        var_levels = variables_dict[var]
        ref_names = refs_dict[var].keys()

        process_references(
            var=var,
            refs=ref_names,
            rad_diagnostic_variables=rad_diagnostic_variables,
            levels=var_levels,
            common_grid=common_grid,
        )
        process_models(
            var=var,
            models=models,
            runs_dict=runs_dict,
            rad_diagnostic_variables=rad_diagnostic_variables,
            levels=var_levels,
            common_grid=common_grid,
            refs=ref_names,
        )


# Run the pipeline only when this code is executed directly (as a script or in
# a notebook kernel, where __name__ is "__main__"); importing it as a module
# does not start a run.
if __name__ == "__main__":
    main()


variables_unique: ['pr', 'ua', 'va', 'rsdt', 'rsut', 'rsutcs', 'rlut', 'rlutcs', 'rstcre', 'rltcre', 'rt', 'rst']
variables_dict: {'pr': [None], 'ua': [200, 850], 'va': [200], 'rsdt': [None], 'rsut': [None], 'rsutcs': [None], 'rlut': [None], 'rlutcs': [None], 'rstcre': [None], 'rltcre': [None], 'rt': [None], 'rst': [None]}
var: pr
Processing data for: GPCP1
No path, filename, or template found for ref: GPCP1
Processing data for: GPCP2
No path, filename, or template found for ref: GPCP2
Processing data for: model-a, r1
Processing model dataset - varname: pr, data: ('model-a', 'r1'), path: /home/data/model-a/pr/model-a_r1_pr_blabla.nc
refs: dict_keys(['GPCP1', 'GPCP2'])
ac_ref_dict[var].keys(): dict_keys([])
'Reference data GPCP1 is not available for pr.'
Processing data for: model-a, r2
Processing model dataset - varname: pr, data: ('model-a', 'r2'), path: /home/data/model-a/pr/model-a_r2_pr_blabla.nc
refs: dict_keys(['GPCP1', 'GPCP2'])
ac_ref_dict[var].keys(): dict_keys([])
'Reference 