In [None]:
from collections import defaultdict
from operator import itemgetter

import matplotlib.pyplot as plt
import numpy as np
from buffered_leave_one_out import fit_buffered_loo_sample
from sklearn.metrics import r2_score
from wildfires.analysis import cube_plotting
from wildfires.qstat import get_ncpus

from empirical_fire_modelling.configuration import Experiment
from empirical_fire_modelling.cx1.core import check_local, run_local
from empirical_fire_modelling.utils import tqdm

# Largest buffer radius considered. It is the upper end of the radius sweep
# generated in `args_iter` and is also passed on to every
# `fit_buffered_loo_sample` call. The final plot of this notebook labels these
# radii as pixels.
max_rad = 50
# Experiment looked up by name; forwarded as the first argument of every fit.
exp = Experiment["15VEG_FAPAR"]


def args_iter():
    """Yield one ``(exp, radius, max_rad, seed)`` tuple per requested fit.

    The tuples are grouped into four consecutive batches of 1000 seeds
    (0-999, 1000-1999, 2000-2999, 3000-3999). Within each batch, all seeds
    are emitted for the first radius, then for the next one, and so on, with
    the 8 radii evenly spaced between 0 and ``max_rad`` (inclusive).

    ``exp`` and ``max_rad`` are read from the module-level configuration.
    """
    # Each batch of 1000 seeds (x 8 radii) was submitted as its own CX1 array
    # job because of job size limitations.
    for batch_start in range(0, 4000, 1000):
        for radius in np.linspace(0, max_rad, 8):
            for seed in range(batch_start, batch_start + 1000):
                yield (exp, radius, max_rad, seed)


# Number of worker processes shared by the cache check and the local run.
n_cores = get_ncpus()

# `check_local` / `run_local` receive the arguments transposed: one iterable
# per positional parameter of `fit_buffered_loo_sample`.
# NOTE(review): judging by its name, `uncached_args` holds the argument tuples
# without a cached result - confirm against `cx1.core.check_local`.
checked, uncached_args = check_local(
    fit_buffered_loo_sample,
    zip(*args_iter()),
    kwargs={},
    backend="processes",
    n_cores=n_cores,
    verbose=True,
)

# Only argument tuples that are not reported in `uncached_args` are run here.
cached_call_args = [
    single_args for single_args in args_iter() if single_args not in uncached_args
]

results = run_local(
    fit_buffered_loo_sample,
    zip(*cached_call_args),
    kwargs={},
    backend="processes",
    n_cores=n_cores,
    verbose=True,
)

In [None]:
# Per-grid-cell counter of how often a cell was selected by `test_indices`.
# NOTE(review): the (720, 1440) shape is hard-coded - presumably it matches the
# grid the fits were carried out on; confirm.
n_plotted = np.zeros((720, 1440), dtype=np.int64)

for sample_info, _test_y, _pred_y in tqdm(results, desc="Realising data / plotting"):
    # Only the test indices are needed here; the remaining entries are unused.
    test_indices, _n_ignored, _n_train, _n_hold_out, _total_samples = sample_info
    row_sel, col_sel = test_indices[0], test_indices[1]
    # NOTE(review): with array-valued indices, `+=` counts repeated
    # (row, col) pairs within one sample only once - confirm this is intended.
    n_plotted[row_sel, col_sel] += 1

count_fig = plt.figure(figsize=(10, 5), dpi=1000)
fig = cube_plotting(n_plotted, fig=count_fig)

In [None]:
# Collated results, keyed either by `(radius, <statistic name>)` for the
# per-sample counts or by `(seed bin, radius, "test_y" | "pred_y")` for the
# flattened observations / predictions.
data = defaultdict(list)

# Edges of `n_seed_bins` equally wide seed bins.
n_seed_bins = 10
seed_bins = np.linspace(
    0,
    4000,  # The maximum (+1) seed possible.
    n_seed_bins + 1,
    dtype=np.int64,
)

# `results` was computed only for the argument tuples that passed the
# `not in uncached_args` filter. The very same filter therefore has to be
# applied here; zipping against the unfiltered `args_iter()` would silently
# pair results with the wrong radius / seed as soon as one call is uncached.
cached_args = [
    single_args for single_args in args_iter() if single_args not in uncached_args
]

# NOTE: `experiment` is deliberately left bound after the loop, since later
# cells may read it. The yielded maximum radius is not needed here and is
# bound to a throwaway name so the module-level `max_rad` is not rebound.
for (
    (experiment, radius, _max_rad, seed),
    (
        (test_indices, n_ignored, n_train, n_hold_out, total_samples),
        test_y,
        pred_y,
    ),
) in zip(cached_args, tqdm(results, desc="Realising / collating results")):
    # 1-based index of the seed bin that `seed` falls into.
    seed_index = np.digitize(seed, seed_bins)

    data[(radius, "n_ignored")].append(n_ignored)
    data[(radius, "n_train")].append(n_train)
    data[(radius, "n_hold_out")].append(n_hold_out)
    data[(radius, "total_samples")].append(total_samples)
    data[(seed_index, radius, "test_y")].append(np.array(test_y).ravel())
    data[(seed_index, radius, "pred_y")].append(np.array(pred_y).ravel())

In [None]:
# Three-component keys are the `(seed bin, radius, label)` entries of `data`;
# collect the distinct seed bins and radii they cover, in ascending order.
keys = tuple(key for key in data if len(key) == 3)
seed_indices = tuple(sorted({key[0] for key in keys}))
radii = tuple(sorted({key[1] for key in keys}))
seed_indices, radii

In [None]:
# Conversion factor behind the "approx. km" x-axis: the radii are given in
# pixels and one pixel is treated as roughly 30 km.
KM_PER_PIXEL = 30

# One list of R2 scores (one entry per radius) for every seed bin.
# `np.concatenate` joins the per-sample 1D arrays into a single vector even if
# the samples differ in length; `np.array(...).ravel()` only flattens
# equally sized samples and fails for ragged input.
seed_r2s = [
    [
        r2_score(
            y_true=np.concatenate(data[(seed_index, radius, "test_y")]),
            y_pred=np.concatenate(data[(seed_index, radius, "pred_y")]),
        )
        for radius in radii
    ]
    for seed_index in seed_indices
]

plt.figure(dpi=200)
# Title taken from the configured experiment rather than from a loop variable
# leaked by an earlier cell (every fit uses this same experiment).
plt.title(exp.name)
for seed_index, r2s in zip(seed_indices, seed_r2s):
    plt.plot(np.array(radii) * KM_PER_PIXEL, r2s, label=f"Seed bin: {seed_index}")

plt.legend()
plt.xlabel("radius (approx. km)")
_ = plt.ylabel(r"CV $\mathrm{R}^2$")

In [None]:
# One list of R2 scores (one entry per radius) for every seed bin.
# `np.concatenate` joins the per-sample 1D arrays into a single vector even if
# the samples differ in length; `np.array(...).ravel()` only flattens
# equally sized samples and fails for ragged input.
seed_r2s = [
    [
        r2_score(
            y_true=np.concatenate(data[(seed_index, radius, "test_y")]),
            y_pred=np.concatenate(data[(seed_index, radius, "pred_y")]),
        )
        for radius in radii
    ]
    for seed_index in seed_indices
]

plt.figure(dpi=200)
# Title taken from the configured experiment rather than from a loop variable
# leaked by an earlier cell (every fit uses this same experiment).
plt.title(exp.name)
for seed_index, r2s in zip(seed_indices, seed_r2s):
    # Same curves as the km-scaled plot, with the radii left in pixels.
    plt.plot(np.array(radii), r2s, label=f"Seed bin: {seed_index}")

plt.legend()
plt.xlabel("radius (pixels)")
_ = plt.ylabel(r"CV $\mathrm{R}^2$")