## Setup

In [None]:
from specific import *

### Retrieve previous results from the 'model' notebook

In [None]:
# Cached train/test split produced by the 'model' notebook.
(X_train, X_test, y_train, y_test) = data_split_cache.load()
# Cached cross-validation results and the model (presumably the fitted random forest).
(results, rf) = cross_val_cache.load()

#### Get Dask Client

In [None]:
# Attach to the already-running Dask client. For a local cluster one could use e.g.
# Client(n_workers=1, threads_per_worker=8, resources={'threads': 8}) instead.
client = get_client()
client

In [None]:
# Broadcast the model and the training data to every worker once; tasks below
# receive these futures instead of the objects themselves.
rf_fut, X_fut = (client.scatter(obj, broadcast=True) for obj in (rf, X_train))

In [None]:
# Every unordered pair of training features; one 2D ALE/PDP task per pair.
columns_list = list(combinations(X_train.columns, 2))
# Report the number of feature *pairs* (not the number of columns).
print("Total nr. of column pairs:", len(columns_list))

## 2D ALE Plotting

In [None]:
# XXX: Local trial
# save_ale_2d_and_get_importance(rf, X_train[:20000], X_train.columns[:2])

### Test Parallelisation capacity - around 8 cores per worker looks good

Note that it is impractical to spawn 32 single-threaded workers, since every worker process holds its own copy of the broadcast model and training data.
Only around 60 GB of memory is available.

In [None]:
# Wall-clock time of a single 2D ALE task for each n_jobs value (0 marks a failed run).
times = []
# Descending order, so the single-job run (the baseline used for extrapolation) comes last.
n_jobs_list = [1, 4, 7, 10, 12, 15, 18, 22, 26, 30][::-1]

for n_jobs in tqdm(n_jobs_list):
    start = time()
    ale_f = client.submit(
        add_common_path_deco(save_ale_2d_and_get_importance),
        model=rf_fut,
        train_set=X_fut,
        features=columns_list[0],
        n_jobs=n_jobs,
        # Only 1 unit of the "threads" resource is reserved regardless of n_jobs.
        resources={"threads": 1},
        # Force a fresh task on every iteration instead of reusing a cached result.
        pure=False,
    )
    dask.distributed.wait(ale_f)

    if ale_f.status == "error":
        # `.result()` would re-raise and abort the sweep; report the error and
        # record a 0 placeholder so `times` stays aligned with `n_jobs_list`.
        print(f"n_jobs={n_jobs} failed:", ale_f.exception())
        times.append(0)
    else:
        times.append(time() - start)

In [None]:
# Extrapolate the full run from the single-job timing, which is the last
# entry of the (descending) n_jobs sweep.
single_job_time = times[-1]
print("Time for 1 cpu core:", single_job_time)
print("nr of tasks:", len(columns_list))

n_workers = 320  # presumably the number of single-core workers available; adjust as needed
total_time = len(columns_list) * single_job_time
hours = total_time / n_workers / (60 * 60)
print(f"Time with {n_workers} workers: {hours} hours.")

In [None]:
# Per-job throughput for each n_jobs value: 1 / (runtime * n_jobs).
# New names are used so that re-running this cell does not mutate the sweep results.
times_arr = np.asarray(times, dtype=float)
n_jobs_arr = np.asarray(n_jobs_list)
# Failed runs were recorded as 0 s; exclude them to avoid dividing by zero.
valid = times_arr > 0

fig, ax = plt.subplots()
ax.plot(n_jobs_arr[valid], (1 / times_arr[valid]) / n_jobs_arr[valid], marker="o")
ax.set_xlabel("n_jobs")
ax.set_ylabel("1 / (runtime × n_jobs)")
ax.set_title("2D ALE parallel efficiency");

#### Local 2D ALE plotting

In [None]:
# Cached SHAP interaction data; its index order is used below so that the
# top interactions are plotted first.
interact_data_cache = SimpleCache("SHAP_interact_data", cache_dir=CACHE_DIR)
interact_data = interact_data_cache.load()

In [None]:
interact_data.index[:10].to_list()[0]

In [None]:
columns_list[0:10]

In [None]:
[shorten_features(pair) for pair in columns_list if "Max Temp" in pair]

In [None]:
[
    set(shorten_features(pair)) == {"Max Temp", "VOD 3 M"}
    for pair in columns_list
    if "Max Temp" in pair
]

In [None]:
# Map each feature pair's (order-independent) set of shortened names to the
# matching pairs once, instead of rescanning `columns_list` for every interaction.
pairs_by_short_names = {}
for pair in columns_list:
    pairs_by_short_names.setdefault(frozenset(shorten_features(pair)), []).append(pair)

for interact_columns in tqdm(interact_data.index.to_list(), desc="2D ALE plots"):
    matching_cols = pairs_by_short_names.get(frozenset(interact_columns), [])
    # Each interaction must correspond to exactly one feature pair.
    assert len(matching_cols) == 1, matching_cols
    columns = matching_cols[0]

    save_ale_2d_and_get_importance(
        model=rf,
        train_set=X_train,
        features=columns,
        n_jobs=get_ncpus(),
        include_first_order=True,
    )

## Run 2D ALE plotting in parallel

In [None]:
world_ale_2d_cache = SimpleCache("world_ale_2d", cache_dir=CACHE_DIR)

# Threads available on each worker.
n_threads = common_worker_threads(client)

# XXX: only use half of each worker's threads per task. Clamp to 1 so that a
# single-threaded worker does not lead to `n_jobs=0` and a zero resource request.
n_threads = max(1, n_threads // 2)


@world_ale_2d_cache
def get_world_ale_2d():
    """Compute 2D ALE importances for every feature pair on the Dask cluster.

    One task per entry of `columns_list` is submitted, each running with
    `n_threads` jobs and reserving `n_threads` units of the "threads" resource.
    The result is cached via `world_ale_2d_cache`.

    Returns:
        dict mapping each feature pair in `columns_list` to the value returned
        by `save_ale_2d_and_get_importance` for that pair.

    Raises:
        The exception of the first failed task (in `columns_list` order), after
        all task errors have been printed, so that a partial result is never cached.
    """
    ale_fs = [
        client.submit(
            add_common_path_deco(save_ale_2d_and_get_importance),
            model=rf_fut,
            train_set=X_fut,
            features=columns,
            n_jobs=n_threads,
            resources={"threads": n_threads},
        )
        for columns in columns_list
    ]

    for ale_f in tqdm(
        dask.distributed.as_completed(ale_fs),
        total=len(ale_fs),
        unit="plot",
        desc="Calculating 2D ALE plots",
        smoothing=0,
    ):
        if ale_f.status == "error":
            # `.result()` would re-raise here and hide all later failures;
            # print the exception instead and keep monitoring the remaining tasks.
            print(ale_f.exception())

    # `.result()` re-raises for failed tasks, so nothing is returned (or cached)
    # unless every task succeeded.
    return {columns: ale_f.result() for columns, ale_f in zip(columns_list, ale_fs)}


ptp_values = get_world_ale_2d()

### Repeat, including first-order effects (for the plots only).

In [None]:
# XXX: use half of each worker's threads per task, but never fewer than 1 so
# that `n_jobs` and the "threads" resource request stay positive.
n_threads = max(1, common_worker_threads(client) // 2)

# Where the first-order plots are checked for (presumably written there by
# `save_ale_2d_and_get_importance` when `include_first_order=True`).
first_order_dir = os.path.join(figure_saver.directories[0], "2d_ale_first_order")

# Only submit feature pairs whose first-order plot file does not exist yet.
ale_fs = [
    client.submit(
        add_common_path_deco(save_ale_2d_and_get_importance),
        model=rf_fut,
        train_set=X_fut,
        features=columns,
        n_jobs=n_threads,
        resources={"threads": n_threads},
        include_first_order=True,
    )
    for columns in columns_list
    if not os.path.isfile(os.path.join(first_order_dir, "__".join(columns) + ".png"))
]

for ale_f in tqdm(
    dask.distributed.as_completed(ale_fs),
    total=len(ale_fs),
    unit="plot",
    desc="Calculating 2D ALE plots",
    smoothing=0,
):
    if ale_f.status == "error":
        # `.result()` would re-raise and abort the loop; report the exception instead.
        print(ale_f.exception())

### Ignore and count None values, then plot a histogram of the ptp values.

In [None]:
# Split results into successful feature pairs and failed ones (None), reporting failures.
filtered_columns_list = []
filtered_ptp_values = []
for columns, ptp in ptp_values.items():
    if ptp is not None:
        filtered_columns_list.append(columns)
        filtered_ptp_values.append(ptp)
    else:
        print("Error for columns:", columns)

print("Nr. of None values:", len(ptp_values) - len(filtered_ptp_values))

_ = plt.hist(filtered_ptp_values, bins=20)
plt.xlabel("ptp value")
plt.ylabel("Count")

# Rank feature pairs by their 2D importance, largest first.
pdp_results = pd.Series(filtered_ptp_values, index=filtered_columns_list).sort_values(
    ascending=False
)
print(pdp_results.head(20))

plt.figure()
# Ranks start at 1 so the top-ranked pair is not dropped on the log x-axis.
plt.plot(np.arange(1, len(pdp_results) + 1), pdp_results.values, marker="o")
plt.yscale("log")
plt.xscale("log")
plt.xlabel("Rank")
plt.ylabel("2D Importance")

## PDP Plotting - see array job folder 'pdp_2d'

Note that not all PDP plots may have been computed yet.

### Worldwide

In [None]:
%%time
# Time a 2D PDP for the first feature pair on the first 1000 training rows
# (the final positional argument is presumably n_jobs — TODO confirm).
save_pdp_plot_2d(rf, X_train[:1000], columns_list[0], 8)

In [None]:
%%time
# Same timing on the first 10000 training rows, to see how runtime scales with sample size.
save_pdp_plot_2d(rf, X_train[:10000], columns_list[0], 8)

In [None]:
%%time
# Same timing on the first 100000 training rows.
save_pdp_plot_2d(rf, X_train[:100000], columns_list[0], 8)

In [None]:
# Not really practical since it takes too long - use array jobs instead!

%%time
n_threads = 8
pdp_fs = [
    client.submit(
        add_common_path_deco(save_pdp_plot_2d),
        model=rf_fut,
        X_train=X_fut,
        features=columns,
        n_jobs=n_threads,
        resources={"threads": n_threads},
    )
    for columns in columns_list
]

for pdp_f in tqdm(
    dask.distributed.as_completed(pdp_fs),
    total=len(pdp_fs),
    unit="plot",
    desc="Calculating 2D PDP plots",
    smoothing=0,
):
    if pdp_f.status == "error":
        print(pdp_f.result())