In [15]:
import copy

import lightgbm as lgb
import numpy as np
import pandas as pd
import torch
from pyro.contrib.cevae import CEVAE
from sklearn.model_selection import StratifiedKFold
from torch import Tensor
from tqdm import tqdm

from cate.infra.mlflow import MlflowClient
from cate.model.dataset import Dataset
from cate.model.evaluate import (
    Auuc,
    Outputs,
    QiniByPercentile,
    UpliftByPercentile,
    UpliftCurve,
)
from cate.model.metrics import Artifacts, Metrics
from cate.utils import Timer, get_logger, path_linker

In [2]:
# Run configuration shared by the cells below.
dataset_name = "test"
logger = get_logger("causalml")
# Dataset locations for `dataset_name`; `.base` is what Dataset.load reads.
pathlinker = path_linker(dataset_name)
# Argument for the project's MLflow wrapper (presumably the experiment name --
# confirm in cate.infra.mlflow).
tracking_namespace = "base_pattern"
client = MlflowClient(tracking_namespace)
# Started/stopped per (model, phase, fold) in the training cell.
timer = Timer()

In [3]:
logger.info("load dataset")
ds = Dataset.load(pathlinker.base)

# Both LightGBM base learners share one configuration.
# NOTE(review): neither learner is referenced by the CEVAE training cell.
_lgb_kwargs = dict(
    importance_type="gain", random_state=42, force_col_wise=True, n_jobs=-1
)
base_classifier = lgb.LGBMClassifier(**_lgb_kwargs)
base_regressor = lgb.LGBMRegressor(**_lgb_kwargs)

# Models to evaluate, keyed by the name used in MLflow run names and tags.
# CEVAE is sized by the number of covariate columns and placed on the GPU.
models = {}
models["cevae"] = CEVAE(ds.X.shape[1]).to("cuda")

INFO  2024-11-07 15:33:25 [causalml] load dataset


In [None]:
# Stratified 5-fold evaluation of every model in `models`: one MLflow run per
# model, per-fold metrics logged under the fold index, and the pooled
# out-of-fold predictions used for the run-level artifacts.
torch.set_default_device("cuda")
skf = StratifiedKFold(5, shuffle=True, random_state=42)
for name, model in models.items():
    logger.info(f"start {name}")
    client.start_run(
        run_name=f"{dataset_name}_{name}",
        tags={"model": name, "dataset": dataset_name, "package": "causalml"},
        description=f"base_pattern: {name} training and evaluation using {dataset_name} dataset with causalml package and lightgbm model with 5-fold cross validation and stratified sampling.",
    )
    # try/finally: close the run even if training raises or the cell is
    # interrupted, instead of leaving it open for a manual `client.end_run()`.
    try:
        # NOTE(review): these are the LightGBM settings from the shared template;
        # they do not configure CEVAE -- confirm what should be logged here.
        client.log_params(
            {
                "importance_type": "gain",
                "random_state": 42,
                "n_jobs": -1,
                "force_col_wise": True,
            }
        )
        _pred_dfs = []
        for i, (train_idx, valid_idx) in tqdm(
            enumerate(skf.split(np.zeros(len(ds)), ds.y)),
            total=skf.get_n_splits(),
        ):
            logger.info(f"epoch {i}")
            train_X = ds.X.iloc[train_idx].to_numpy()
            train_y = ds.y.iloc[train_idx].to_numpy().reshape(-1)
            train_w = ds.w.iloc[train_idx].to_numpy().reshape(-1)
            valid_X = ds.X.iloc[valid_idx].to_numpy()
            valid_y = ds.y.iloc[valid_idx].to_numpy().reshape(-1)
            valid_w = ds.w.iloc[valid_idx].to_numpy().reshape(-1)

            # Pyro's CEVAE.fit trains the module's existing parameters rather
            # than re-initialising them, so refitting the shared `model` would
            # start this fold from weights already trained on its validation
            # rows. Train a fresh copy of the untrained model for every fold.
            fold_model = copy.deepcopy(model)

            timer.start(name, "train", i)
            fold_model.fit(
                torch.as_tensor(train_X, dtype=torch.float32, device="cuda"),
                torch.as_tensor(train_w, dtype=torch.float32, device="cuda"),
                torch.as_tensor(train_y, dtype=torch.float32, device="cuda"),
            )
            timer.stop(name, "train", i)

            timer.start(name, "predict", i)
            pred = fold_model.ite(
                torch.as_tensor(valid_X, dtype=torch.float32, device="cuda")
            )
            timer.stop(name, "predict", i)
            # Copy to host once: the metrics and pandas both need NumPy, and
            # pandas cannot wrap a CUDA tensor directly.
            pred_np = pred.detach().cpu().numpy().reshape(-1)

            metrics = Metrics(
                [Auuc()]
                + [UpliftByPercentile(k) for k in np.arange(0, 1, 0.1)]
                + [QiniByPercentile(k) for k in np.arange(0, 1, 0.1)]
            )
            metrics(pred_np, valid_y, valid_w)
            client.log_metrics(metrics, i)

            # Index by the original row labels so the index merge below pairs
            # each prediction with its own y/w row (a default RangeIndex would
            # restart at 0 in every fold).
            _pred_dfs.append(
                pd.DataFrame({"pred": pred_np}, index=ds.y.index[valid_idx])
            )

        pred_df = pd.concat(_pred_dfs, axis=0)
        assert pred_df.index.is_unique, "row labels must be unique to align out-of-fold predictions"
        base_df = pd.merge(
            ds.y.rename(columns={ds.y_columns[0]: "y"}),
            ds.w.rename(columns={ds.w_columns[0]: "w"}),
            left_index=True,
            right_index=True,
        )
        output_df = pd.merge(base_df, pred_df, left_index=True, right_index=True)

        artifacts = Artifacts([UpliftCurve(), Outputs()])
        artifacts(
            output_df.pred.to_numpy(), output_df.y.to_numpy(), output_df.w.to_numpy()
        )
        client.log_artifacts(artifacts)
    finally:
        client.end_run()

INFO  2024-11-07 15:42:07 [causalml] start cevae


2024/11/07 15:42:08 INFO mlflow.system_metrics.system_metrics_monitor: Started monitoring system metrics.
0it [00:00, ?it/s]

INFO  2024-11-07 15:42:08 [causalml] epoch 0


INFO 	 Training with 800 minibatches per epoch


In [17]:
# Manually end the currently active MLflow run via the project client. This is
# presumably a recovery step for a run left open when the training cell did not
# reach its own `client.end_run()` -- TODO confirm it is still needed.
client.end_run()

2024/11/07 15:42:05 INFO mlflow.tracking._tracking_service.client: 🏃 View run test_cevae at: http://ec2-44-217-145-52.compute-1.amazonaws.com:5000/#/experiments/6/runs/e45226b179234e6db8cde62042e2ccc2.
2024/11/07 15:42:05 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: http://ec2-44-217-145-52.compute-1.amazonaws.com:5000/#/experiments/6.
2024/11/07 15:42:05 INFO mlflow.system_metrics.system_metrics_monitor: Stopping system metrics monitoring...
2024/11/07 15:42:05 INFO mlflow.system_metrics.system_metrics_monitor: Successfully terminated system metrics monitoring!
