In [None]:
import math
import re
from collections import Counter, defaultdict
from datetime import datetime
from pathlib import Path

import iris
import matplotlib.pyplot as plt
import numpy as np
from tqdm.auto import tqdm

In [None]:
# Directory searched for the `*SPINUP*dump*.nc` files this notebook reads.
# NOTE(review): absolute, machine-specific path — adjust when running elsewhere.
DATA_DIR = Path("/work/scratch-nopw/alexkr/multi_spinup6/jules_output")

In [None]:
def extract_date(path):
    """Parse the date stamp embedded in a dump file name.

    The third-from-last dot-separated field of ``str(path)`` is interpreted
    as a ``%Y%m%d`` date and returned as a ``datetime``.
    """
    fields = str(path).split(".")
    date_stamp = fields[-3]
    return datetime.strptime(date_stamp, "%Y%m%d")


def extract_spinup_index(fname):
    """Return the integer that directly follows ``SPINUP`` in a name.

    Parameters
    ----------
    fname : str or pathlib.Path
        For a ``Path`` only the final component (``fname.name``) is searched;
        a string is searched as given.

    Returns
    -------
    int
        The digits following the first ``SPINUP<digits>`` occurrence.

    Raises
    ------
    ValueError
        If the name contains no ``SPINUP`` followed by at least one digit.
    """
    if isinstance(fname, Path):
        fname = fname.name
    # Require at least one digit: with `\d*` a bare "SPINUP" (e.g. in a
    # directory name) would match with an empty group and fail in `int("")`.
    match = re.search(r"SPINUP(\d+)", fname)
    if match is None:
        raise ValueError(f"No 'SPINUP<index>' token found in {fname!r}")
    return int(match.group(1))


def sort_key(path):
    """Build the ordering key for a dump file: spin-up index first, then date."""
    spinup_index = extract_spinup_index(path)
    date_stamp = extract_date(path)
    return (spinup_index, date_stamp)

In [None]:
# Collect every spin-up dump file in DATA_DIR, ordered by (spin-up index, date).
dump_files = list(DATA_DIR.glob("*SPINUP*dump*.nc"))
dump_files.sort(key=sort_key)
dump_files

In [None]:
# Date stamp of each dump file, in the same order as `dump_files`.
dates = [extract_date(dump_file) for dump_file in dump_files]
dates

In [None]:
# Experiment label of each dump file: the fifth-from-last dot-separated field
# of its path, in the same order as `dump_files`.
experiments = [str(dump_file).split(".")[-5] for dump_file in dump_files]
experiments

In [None]:
# Sample a handful of land points from every dump file so their evolution
# across spin-up cycles can be plotted.
land_indices = [100, 200, 500]
variable = "cs"
# plot_data[experiment][(land_index, i)] -> list of (date, value) pairs,
# where `i` is the index along the cube's leading axis.
plot_data = defaultdict(lambda: defaultdict(list))

for data_file, experiment, date in zip(
    tqdm(dump_files, desc="Reading files"), experiments, dates
):
    cube = iris.load_cube(str(data_file), constraint=variable)
    cube_data = cube.data
    experiment_series = plot_data[experiment]
    for land_index in land_indices:
        for i in range(4):
            data_point = cube_data[i, ..., land_index]
            # The sampled points are expected to be valid (unmasked) data.
            assert not data_point.mask
            experiment_series[(land_index, i)].append((date, data_point.data))

In [None]:
# One panel per (land_index, i) key; every experiment is drawn into the same
# set of panels so the spin-up cycles can be compared directly.
ncols = 4
nitems = len(plot_data[experiments[0]])
nrows = math.ceil(nitems / ncols)

fig, axes = plt.subplots(nrows, ncols, figsize=(20, 15), constrained_layout=True)
flat_axes = axes.ravel()

for experiment, single_plot_data in plot_data.items():
    for ax, (key, values) in zip(flat_axes, single_plot_data.items()):
        # `values` is a list of (date, value) pairs; transpose it into x and y.
        ax.plot(*zip(*values), marker="o", label=experiment)
        ax.set_title(key)

for ax in flat_axes:
    ax.legend()

In [None]:
# An experiment counts as "complete" when it has as many dump files as the
# experiment with the most dump files.
exp_counter = Counter(experiments)
# default=0 lets an empty counter yield an empty list instead of raising.
most_dumps = max(exp_counter.values(), default=0)
complete = [exp for exp, n_dumps in exp_counter.items() if n_dumps == most_dumps]
complete.sort(key=extract_spinup_index)
complete

In [None]:
# Step 1: build, for each of the 4 leading-axis indices, the union (logical OR)
# of the data masks of every dump file.
shared_masks = [None] * 4

for data_file, experiment, date in zip(
    tqdm(dump_files, desc="Creating shared mask"), experiments, dates
):
    cube = iris.load_cube(str(data_file), constraint=variable)
    file_mask = cube.data.mask
    has_mask_array = isinstance(file_mask, np.ndarray)
    if not has_mask_array:
        # A scalar mask must mean "nothing masked".
        assert not file_mask

    if shared_masks[0] is None:
        # First file: initialise the shared masks.
        if has_mask_array:
            shared_masks = [file_mask[level] for level in range(4)]
        else:
            shared_masks = [
                np.zeros_like(cube.data.data[level], dtype=np.bool_)
                for level in range(4)
            ]
    elif has_mask_array:
        # Later files: accumulate their masks in place.
        for i in range(4):
            shared_masks[i] |= file_mask[i]
    print("Masked elements:", [np.sum(level_mask) for level_mask in shared_masks])

# Step 2: for complete experiments only, collect the raw values at positions
# that are unmasked in the shared mask, one list of arrays per experiment.
concats = [defaultdict(list) for _ in range(4)]
for data_file, experiment, date in zip(
    tqdm(dump_files, desc="Concatenating arrays"), experiments, dates
):
    if experiment in complete:
        cube = iris.load_cube(str(data_file), constraint=variable)
        for i, (shared_mask, concat) in enumerate(zip(shared_masks, concats)):
            selected = cube.data.data[i][~shared_mask]
            concat[experiment].append(selected)

# Step 3: stack each experiment's per-file arrays along a new leading axis.
for concat in tqdm(concats, desc="Joining"):
    for experiment, arrs in concat.items():
        concat[experiment] = np.stack(arrs, axis=0)

In [None]:
# diff_data[index] holds one record per pair of consecutive complete
# experiments, summarising the absolute difference between their stacked arrays.
diff_data = defaultdict(list)

for index, concat in enumerate(concats):
    for comp_exp, exp in zip(complete, complete[1:]):
        diff = np.abs(concat[exp] - concat[comp_exp])
        # Mean absolute difference per row (one row per stacked dump file).
        row_means = np.mean(diff, axis=1)
        summary = {
            "mean": np.mean(diff),
            "std": np.std(row_means),
            "max": np.max(row_means),
        }
        diff_data[index].append(summary)

In [None]:
# Names of the summary statistics, taken from the keys of the first record.
measures = list(diff_data[0][0])

# Grid of panels: one row per entry of `concats`, one column per measure.
# squeeze=False keeps `axes` two-dimensional so `axes[:, ax_i]` is always valid.
fig, axes = plt.subplots(
    len(concats),
    len(measures),
    constrained_layout=True,
    figsize=(12, 8),
    squeeze=False,
)

for ax_i, measure in enumerate(measures):
    for i, ax in enumerate(axes[:, ax_i]):
        # x: the later experiment of each consecutive pair; y: that pair's value.
        ax.plot(complete[1:], [plot_dict[measure] for plot_dict in diff_data[i]])
        # Raw f-string: `\m` is not a valid escape sequence in a normal string
        # literal, so the backslashes for mathtext must be kept verbatim.
        ax.set_title(rf"$\mathrm{{cs}}_\mathrm{{{i + 1}}}$ {measure} of diffs")
        plt.setp(ax.get_xticklabels(), rotation=45, ha="right")