In [None]:
import os

import papermill as pm

In [None]:
def get_data_path(task):
    """Return the output directory for ``task``, creating it if it is missing.

    The path is built relative to the current working directory
    (``../../../../data/papermill/training_alphas/neural/<task>``), so the
    directory it refers to depends on where the process is when this is called.

    Args:
        task: Task name, appended as the final path component.

    Returns:
        The relative directory path as a string. The directory exists on return.

    Raises:
        FileExistsError: If something other than a directory already occupies
            the path (previously this was silently returned as if it were a
            usable directory).
    """
    data_path = f"../../../../data/papermill/training_alphas/neural/{task}"
    # makedirs(exist_ok=True) is already a no-op for an existing directory, so a
    # separate os.path.exists() pre-check is redundant and racy.
    os.makedirs(data_path, exist_ok=True)
    return data_path

In [None]:
def train_alpha(
    name, model, content, task, residual_alphas, input_alphas, tune_hyperparams
):
    """Execute the ``Papermill.ipynb`` template once with the given settings.

    The executed copy is written to ``<get_data_path(task)>/<name>.ipynb``.
    ``"Papermill.ipynb"`` is a relative path, so it is resolved against the
    current working directory.

    Args:
        name: Run identifier. Used as the output notebook filename and passed
            to the template as ``outdir``.
        model: Forwarded to the template as ``model``.
        content: Forwarded as ``content``.
        task: Forwarded as ``task``; also selects the output directory.
        residual_alphas: Forwarded as ``residual_alphas``.
        input_alphas: Forwarded as ``input_alphas``.
        tune_hyperparams: Forwarded as ``tune_hyperparams``.
    """
    output_notebook = os.path.join(get_data_path(task), f"{name}.ipynb")
    # Key order matches the original call so the injected parameters cell is unchanged.
    notebook_params = {
        "model": model,
        "content": content,
        "task": task,
        "input_alphas": input_alphas,
        "residual_alphas": residual_alphas,
        "outdir": name,
        "tune_hyperparams": tune_hyperparams,
    }
    pm.execute_notebook("Papermill.ipynb", output_notebook, parameters=notebook_params)

In [None]:
def get_baseline_predictor(content, task):
    """Return the list of baseline predictor names for a content type.

    Args:
        content: Either ``"explicit"`` or ``"implicit"``.
        task: Task name, used as the prefix of the explicit baseline.

    Returns:
        ``["<task>/ExplicitUserItemBiases"]`` for explicit content, and an
        empty list for implicit content. A new list is built on every call.

    Raises:
        KeyError: If ``content`` is not one of the two supported values.
    """
    explicit_baselines = [f"{task}/ExplicitUserItemBiases"]
    return dict(explicit=explicit_baselines, implicit=[])[content]

In [None]:
# Later cells depend on the working directory. train_alpha opens the relative
# path "Papermill.ipynb", and get_data_path builds a relative
# "../../../../data/..." path, so both are resolved from ../Helpers.
# NOTE(review): assumes the kernel starts in a sibling directory of Helpers; TODO confirm.
# Re-running from inside Helpers resolves "../Helpers" to the same directory.
os.chdir("../Helpers")

In [None]:
# Tasks to train for. The loops below pass each one to train_alpha, which uses
# it to pick the output directory, and to get_baseline_predictor.
all_tasks = ["random", "temporal"]

In [None]:
# Untuned autoencoder runs: one per (task, content) pair.
for task in all_tasks:
    for content in ["explicit", "implicit"]:
        train_alpha(
            name=f"Neural{content.capitalize()}AutoencoderUntuned",
            model="autoencoder",
            content=content,
            task=task,
            residual_alphas=get_baseline_predictor(content, task),
            input_alphas=get_baseline_predictor("explicit", task),
            tune_hyperparams=False,
        )

In [None]:
# Untuned item-based collaborative filtering runs: explicit content only.
for task in all_tasks:
    for content in ["explicit"]:
        train_alpha(
            name=f"Neural{content.capitalize()}ItemCFUntuned",
            model="item_based_collaborative_filtering",
            content=content,
            task=task,
            residual_alphas=get_baseline_predictor(content, task),
            input_alphas=get_baseline_predictor("explicit", task),
            tune_hyperparams=False,
        )

In [None]:
# Untuned EASE runs: implicit content only.
for task in all_tasks:
    for content in ["implicit"]:
        train_alpha(
            name=f"Neural{content.capitalize()}EaseUntuned",
            model="ease",
            content=content,
            task=task,
            residual_alphas=get_baseline_predictor(content, task),
            input_alphas=get_baseline_predictor("explicit", task),
            tune_hyperparams=False,
        )