# Normalise abrupt-4xCO2

In [None]:
%load_ext nb_black

In [None]:
import glob
import os.path
import re
import traceback
import warnings
from concurrent.futures import as_completed, ProcessPoolExecutor
from multiprocessing import Pool

import netcdf_scm.retractions
import netcdf_scm.stitching
import tqdm.autonotebook as tqdman

import config

In [None]:
# Run identifier taken from the local ``config`` module; it is interpolated
# into the input and output directory names defined further down.
ID = config.ID

In [None]:
# Gate for the single-file debugging cells further down (every
# ``if RUN_CHECK:`` block). Leave as ``False`` for a normal batch run.
RUN_CHECK = False

In [None]:
# Input: directory that is searched recursively for crunched ``*.nc`` files.
CRUNCH_DIR = f"./{ID}-country-crunch-popn-weighted"
# Output: directory that receives the stitched, normalised files.
STITCHED_NORMALISED_DIR = f"./{ID}-irf-calibration-crunch-stitched-normalised"

# Number of worker processes handed to the process pool for the batch run.
MAX_WORKERS = 60

In [None]:
# Create the output directory (and any missing parents) if it does not exist.
# Done in Python rather than via ``!mkdir -p`` so that it does not depend on a
# POSIX shell and is not broken by spaces or shell metacharacters in the path.
os.makedirs(STITCHED_NORMALISED_DIR, exist_ok=True)

In [None]:
# Echo the input directory in the notebook output.
CRUNCH_DIR

In [None]:
display(CRUNCH_DIR)

# Walk CRUNCH_DIR recursively for NetCDF files and keep only those whose path
# contains the abrupt-4xCO2 experiment marker.
abrupt4xco2_files = list(
    filter(
        lambda path: "_abrupt-4xCO2_" in path,
        glob.iglob(os.path.join(CRUNCH_DIR, "**", "*.nc"), recursive=True),
    )
)
# ssp_files = [f for f in glob.glob(os.path.join(CRUNCH_DIR, "**", "*.nc"), recursive=True) if "ssp" in f]

# Show how many files were found, then preview the first 21 of them.
display(len(abrupt4xco2_files))
abrupt4xco2_files[:21]

In [None]:
# Unique values of the 7th path component of every selected file.
# NOTE(review): presumably this is the climate-model directory; the fixed
# index 6 depends on the depth of CRUNCH_DIR -- confirm if CRUNCH_DIR changes.
cms = {path.split(os.sep)[6] for path in abrupt4xco2_files}
display(len(cms))
print(*sorted(cms), sep="\n")

In [None]:
# TODO: move this into netcdf_scm

# Build one dot-separated identifier per file from its directory components
# (path components 3 up to, but excluding, the file name) and ask
# netcdf_scm which of them have been retracted.
dataset_ids = [".".join(path.split(os.sep)[3:-1]) for path in abrupt4xco2_files]
retracted_ids = netcdf_scm.retractions.check_retractions(
    dataset_ids, esgf_query_batch_size=20
)

# Map every retracted identifier back to the single crunched file that lives
# in the matching directory below CRUNCH_DIR.
retracted_files = []
for retracted_id in retracted_ids:
    id_dir = os.path.join(
        CRUNCH_DIR, "netcdf-scm-crunched", retracted_id.replace(".", os.sep)
    )
    dir_contents = os.listdir(id_dir)
    # Exactly one crunched file is expected per dataset directory.
    assert len(dir_contents) == 1
    retracted_files.append(os.path.join(id_dir, dir_contents[0]))

sorted(retracted_files)

In [None]:
# Drop every file that was identified as retracted, then show how many remain.
retracted_lookup = set(retracted_files)
abrupt4xco2_files = [
    path for path in abrupt4xco2_files if path not in retracted_lookup
]
display(len(abrupt4xco2_files))

In [None]:
# TODO: put useful bits of this in netCDF-SCM


def stitch_and_normalise(
    f, catch=True, norm_years=21, normalise=True, verbose=False, force=False
):
    """
    Stitch a crunched file to its parent data, optionally normalise, and save.

    The continuous timeseries for ``f`` is loaded with
    ``netcdf_scm.stitching.get_continuous_timeseries_with_meta`` using the
    ``CMIP6Output`` data reference syntax. If ``normalise`` is ``True`` the
    result is normalised against its piControl run with
    ``NormaliserRunningMean``. The output is written as a NetCDF file whose
    name is built from the run's metadata, the table and grid fields of the
    input file name, and the first/last time point of the stitched data.

    Parameters
    ----------
    f : str
        Path of the crunched file to process. The table and grid are read
        from the third and second-to-last ``_``-separated fields of its
        file name.

    catch : bool
        If ``True``, all warnings are ignored while processing and any
        exception is re-raised as ``ValueError("File failed: <f>")`` chained
        to the original exception. If ``False``, warnings and exceptions
        propagate untouched.

    norm_years : int
        Passed as ``nyears`` to ``NormaliserRunningMean``.

    normalise : bool
        If ``True``, normalise against piControl and write into
        ``STITCHED_NORMALISED_DIR``. If ``False``, write the stitched data
        as-is into ``STITCHED_DIR``, which the caller must have defined.

    verbose : bool
        If ``True``, print progress messages.

    force : bool
        If ``True``, overwrite an existing output file. Otherwise an existing
        output file causes the function to return without writing.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If ``catch`` is ``True`` and processing fails for any reason.

    NameError
        If ``normalise`` is ``False`` and ``STITCHED_DIR`` is not defined
        (wrapped in the ``ValueError`` above when ``catch`` is ``True``).
    """

    def get_output_dir():
        # Pick the directory matching the requested processing mode.
        if normalise:
            return STITCHED_NORMALISED_DIR

        try:
            return STITCHED_DIR
        except NameError as exc:
            raise NameError(
                "STITCHED_DIR is not defined; define it (e.g. next to "
                "STITCHED_NORMALISED_DIR) before calling with normalise=False"
            ) from exc

    def get_result():
        # Resolve the output directory first so a missing ``STITCHED_DIR``
        # fails immediately rather than after the expensive load below.
        out_dir = get_output_dir()

        if verbose:
            print(f"Loading and stitching {f}")
        (
            scmrun,
            picontrol_branching_time,
            picontrol_file,
        ) = netcdf_scm.stitching.get_continuous_timeseries_with_meta(
            f, drs="CMIP6Output", return_picontrol_info=normalise
        )

        variable = scmrun.get_unique_meta("variable", True)
        climate_model = scmrun.get_unique_meta("climate_model", True)
        scenario = scmrun.get_unique_meta("scenario", True)
        member_id = scmrun.get_unique_meta("member_id", True)

        min_time = scmrun["time"].min()
        start_year = min_time.year
        start_month = min_time.month

        max_time = scmrun["time"].max()
        end_year = max_time.year
        end_month = max_time.month

        name_fields = os.path.basename(f).split("_")
        table = name_fields[2]
        grid = name_fields[-2]
        time_span = f"{start_year}{start_month:02d}-{end_year}{end_month:02d}"
        # Use the table parsed from the input name instead of assuming "Amon".
        out_name = f"netcdf-scm_{variable}_{table}_{climate_model}_{scenario}_{member_id}_{grid}_{time_span}.nc"
        out_file = os.path.join(out_dir, out_name)

        if os.path.isfile(out_file):
            if verbose:
                print(f"Out file already exists: {out_file}")

            if force:
                if verbose:
                    print("Force over-writing")
            else:
                return None

        if normalise:
            # Import explicitly rather than relying on these submodules having
            # been loaded as a side effect of importing ``netcdf_scm.stitching``.
            from netcdf_scm.io import load_scmrun
            from netcdf_scm.normalisation import NormaliserRunningMean

            if verbose:
                print(f"Loading {picontrol_file}")

            picontrol_scmrun = load_scmrun(picontrol_file)
            picontrol_scmrun.metadata["netcdf-scm crunched file"] = picontrol_file

            if verbose:
                print(f"Normalising using {norm_years} years")

            normaliser = NormaliserRunningMean(nyears=norm_years)

            out = normaliser.normalise_against_picontrol(
                scmrun, picontrol_scmrun, picontrol_branching_time
            )
        else:
            out = scmrun

        out["grid"] = grid

        # Write a copy whose metadata keys have "(" and ")" stripped.
        # NOTE(review): presumably parentheses are not accepted in the NetCDF
        # attribute names written by ``to_nc`` -- confirm.
        out_to_disk = out.copy()
        out_to_disk.metadata = {
            k.replace("(", "").replace(")", ""): v
            for k, v in out_to_disk.metadata.items()
        }

        if verbose:
            print(f"Saving to {out_file}")

        out_to_disk.to_nc(out_file)

        return None

    if catch:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            try:
                return get_result()
            except Exception as exc:
                # Keep the failing path in the message: the batch driver
                # extracts it from the "File failed: ..." text.
                raise ValueError("File failed: {}".format(f)) from exc
    else:
        return get_result()

In [None]:
# Single crunched file used by the ``RUN_CHECK`` debugging cells below.
# NOTE(review): this path is hard-coded under
# "./20210416-irf-calibration-crunch" and is not derived from ``CRUNCH_DIR``
# or ``ID`` -- confirm it exists before enabling ``RUN_CHECK``.
checker = "./20210416-irf-calibration-crunch/netcdf-scm-crunched/CMIP6/CMIP/MOHC/UKESM1-0-LL/abrupt-4xCO2/r1i1p1f2/Amon/tas/gn/v20190406/netcdf-scm_tas_Amon_UKESM1-0-LL_abrupt-4xCO2_r1i1p1f2_gn_185001-199912.nc"
checker

In [None]:
# RUN_CHECK = True

In [None]:
if RUN_CHECK:
    import xarray as xr
    from netcdf_scm.iris_cube_wrappers import ScmCube

    def _load_helper_and_scm_cubes(path):
        """
        Load only the "World" region of a crunched file into ``ScmCube`` objects.

        Debug replacement for ``netcdf_scm.io._load_helper_and_scm_cubes``:
        data variables without a ``region`` attribute, or with a region other
        than ``"World"``, are skipped.

        Parameters
        ----------
        path : str
            NetCDF file to open.

        Returns
        -------
        tuple
            ``(loaded, scm_cubes)`` where ``scm_cubes`` maps region name to
            ``ScmCube`` and ``loaded`` is the first cube in that mapping.

        Raises
        ------
        IndexError
            If the file holds no "World" data variable.
        """
        dataset = xr.open_dataset(path)
        dataset.load()  # get everything in memory

        # Must be kept until https://github.com/pandas-dev/pandas/issues/37071
        # is solved
        time_coord = dataset["time"]
        if time_coord.encoding["units"] == "days since 1-01-01 00:00:00":
            time_coord.encoding["units"] = "days since 0001-01-01 00:00:00"

        scm_cubes = {}
        for darray in dataset.data_vars.values():
            # A missing "region" attribute marks bnds or some other
            # unclassified variable; it is skipped like any non-World region.
            region = darray.attrs.get("region")
            if region != "World":
                continue

            wrapper = ScmCube()
            scm_cubes[region] = wrapper

            wrapper.cube = darray.to_iris()
            # File-level attributes take precedence over the cube's own.
            wrapper.cube.attributes = {
                **wrapper.cube.attributes,
                **dataset.attrs,
            }

        # take any cube as base for now, not sure how to really handle this so will
        # leave like this for now and only make this method public when I work it
        # out...
        loaded = list(scm_cubes.values())[0]

        return loaded, scm_cubes

    # Swap the library's loader for the World-only version defined above.
    netcdf_scm.io._load_helper_and_scm_cubes = _load_helper_and_scm_cubes

In [None]:
if RUN_CHECK:
    # Debug: run the whole pipeline on the single ``checker`` file with
    # progress printed and exceptions left unwrapped (``catch=False``).
    tmp = stitch_and_normalise(checker, catch=False, verbose=True)
    display(tmp)

In [None]:
if RUN_CHECK:
    # Debug: load the crunched ``checker`` file on its own and show it.
    source = netcdf_scm.io.load_scmrun(checker)
    display(source)

In [None]:
if RUN_CHECK:
    # Debug: derive the parent replacements from the loaded run; they are
    # passed to ``get_parent_file_path`` in the next check.
    parent_replacements = netcdf_scm.stitching.get_parent_replacements(source)
    display(parent_replacements)

In [None]:
if RUN_CHECK:
    # Debug: resolve the path of the parent file for ``checker`` using the
    # replacements found above and the "CMIP6Output" data reference syntax.
    parent_file = netcdf_scm.stitching.get_parent_file_path(
        checker, parent_replacements, "CMIP6Output"
    )
    display(parent_file)

In [None]:
if RUN_CHECK:
    # Debug: load the parent file resolved in the previous check.
    parent = netcdf_scm.io.load_scmrun(parent_file)
#     parent.metadata["parent_time_units"] = "days since 0001-01-01"

In [None]:
if RUN_CHECK:
    # Debug: show the branch time computed from the parent run, once with
    # ``parent=True`` and once with the default.
    display(netcdf_scm.stitching.get_branch_time(parent, parent=True))
    display(netcdf_scm.stitching.get_branch_time(parent))

In [None]:
if RUN_CHECK:
    # Debug: dump the NetCDF header of the parent file and keep only the
    # lines that mention "parent".
    !ncdump -h {parent_file} | grep parent

In [None]:
# abrupt4xco2_files = [f for f in abrupt4xco2_files if "UKESM" in f]
# abrupt4xco2_files

In [None]:
# Processing switches for the batch run; edit the values to change the mode.
normalise = True  # False writes the stitched data without normalising
force = False  # True overwrites output files that already exist
verbose = False  # True prints progress messages from every worker

# Paths of all files that failed, extracted from the "File failed: ..."
# messages raised by ``stitch_and_normalise``.
all_errors = []
# Formatted tracebacks collected since the last periodic report.
errors = []

failed_file_pattern = re.compile(r".*File failed: (.*\.nc)")

pool = ProcessPoolExecutor(max_workers=MAX_WORKERS)
futures = []
try:
    for f in tqdman.tqdm(abrupt4xco2_files):
        futures.append(
            pool.submit(
                stitch_and_normalise,
                f,
                normalise=normalise,
                verbose=verbose,
                force=force,
            )
        )

    for i, future in tqdman.tqdm(
        enumerate(as_completed(futures, timeout=None)), total=len(futures)
    ):
        try:
            future.result()
        except Exception:
            # Record the failure and keep going; one bad file must not stop
            # the rest of the batch.
            errors.append(traceback.format_exc())

        # Report periodically and once more after the very last future.
        if i % 50 == 10 or i == len(futures) - 1:
            print("\n\n".join(errors))
            all_errors += list(
                {v for e in errors for v in failed_file_pattern.findall(e)}
            )
            #         if errors:
            #             break
            errors = []
finally:
    # Always release the worker processes. Cancelling first drops work that
    # has not started yet, so leaving early (e.g. on interrupt) does not wait
    # for the whole queue; it is a no-op for futures that already finished.
    for future in futures:
        future.cancel()
    pool.shutdown()

In [None]:
# Number of failed files recorded during the batch run.
len(all_errors)

In [None]:
# Paths of the failed files, as captured from the "File failed: ..." messages.
all_errors