In [2]:
from ppp_prediction.model_v2.models import *

--------------------------------------------------------------------------------

  CuPy may not function correctly because multiple CuPy packages are installed
  in your environment:

    cupy, cupy-cuda12x

  Follow these steps to resolve this issue:

    1. For all packages listed above, run the following command to remove all
       existing CuPy installations:

         $ pip uninstall <package_name>

      If you previously installed CuPy via conda, also run the following:

         $ conda uninstall cupy

    2. Install the appropriate CuPy package.
       Refer to the Installation Guide for detailed instructions.

         https://docs.cupy.dev/en/stable/install.html

--------------------------------------------------------------------------------



In [3]:
from collections import defaultdict, OrderedDict
from ppp_prediction.model import run_glmnet
from ppp_prediction.cox import run_cox
from ppp_prediction.metrics import cal_binary_metrics
from sklearn.model_selection import train_test_split
from sklearn.isotonic import IsotonicRegression
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
import numpy as np
import pandas as pd
import seaborn as sns
from plotnine import *
from sklearn.metrics import brier_score_loss, roc_curve, auc
from dcurves import dca
from functools import reduce, partial


import logging

logging.basicConfig(level=logging.INFO)

from scipy.stats import bootstrap


def config_dict_to_df(config_dict, index_name):
    """
    1) convert the config_dict to a dataframe
    Example1
            combination_dict = OrderedDict(
        {
            ("PANEL", "Lasso"): {
                "xvar": ["Age", "Sex"],
                "model": run_glmnet,
                "config": {"cv": 6},
            },
            ("PANEL", "xgboost"): {
                "xvar": ["Age", "Sex"],
            },
            ("AgeSex", "xgboost"): {
                "xvar": ["Age", "Sex"],
            },
        }
    )
    config_dict_to_df(combination_dict, ("combination", "model"))

                            xvar                                    model  \
    combination model                                                          
    PANEL       Lasso    [Age, Sex]  <function run_glmnet at 0x7ff7aa182f80>   
                xgboost  [Age, Sex]                                      NaN   
    AgeSex      xgboost  [Age, Sex]                                      NaN   

                            config  
    combination model               
    PANEL       Lasso    {'cv': 6}  
                xgboost        NaN  
    AgeSex      xgboost        NaN 



    """
    config_df = pd.DataFrame(config_dict).T
    config_df.index.set_names(index_name, inplace=True)
    config_df.columns = pd.MultiIndex.from_tuples(
        [("param", col) for col in config_df.columns]
    )

    return config_df


def update_concat_df(df1, df2, duplicate_replace=False, show_warning=True):
    """
    update the df1 with df2, if duplicate_replace is True, then replace the duplicate rows

    This will copy df1 and df2 to avoid modify the original df1 and df2

    Update the df1 with df2, if duplicate_replace is True, then replace the duplicate rows.
    This function copies df1 and df2 to avoid modifying the original dataframes.
    Parameters:
    - df1 (pandas.DataFrame): The first dataframe to be updated.
    - df2 (pandas.DataFrame): The second dataframe used for updating df1.
    - duplicate_replace (bool, optional): If True, replace duplicate rows in df1 with df2. Default is False.
    - show_warning (bool, optional): If True, show warning messages. Default is True.
    Returns:
    - df (pandas.DataFrame): The updated dataframe.
    Example:
    df1 = pd.DataFrame({'A': [1, 2, 3], 'B': [4, 5, 6]})
    df2 = pd.DataFrame({'A': [7, 8, 9], 'B': [10, 11, 12]})
    updated_df = update_concat_df(df1, df2, duplicate_replace=True, show_warning=False)

    """
    # WARNING: this will copy the df1 and df2
    df1, df2 = df1.copy(), df2.copy()

    new_adds = df2.index.difference(df1.index)
    inter = df2.index.intersection(df1.index)

    if duplicate_replace:
        df1.drop(inter, inplace=True)
        warning_duplicated = (
            f"Duplicate_replace is True, will replace the model_config with {inter}"
        )
    else:
        warning_duplicated = (
            f"Duplicate_replace is False, will skip the model_config with {inter}"
        )

    if len(new_adds) > 0:
        warning_new_add = (
            f"new model_config {new_adds} not in original model_config, will add them"
        )
    else:
        warning_new_add = "No new model_config found"

    if show_warning:
        logging.warning(warning_new_add) if len(warning_new_add) > 0 else None
        logging.warning(warning_duplicated) if len(warning_duplicated) > 0 else None
    df = pd.concat([df1, df2])

    return df


def get_risk_strat_df(data=None, y_true=None, y_pred=None, k=10, n_resample=1000):
    """
    TODO: Add iris as an example
    """
    if data is not None:
        y_true = data[y_true]
        y_pred = data[y_pred]
    elif isinstance(y_true, pd.Series) and isinstance(y_pred, pd.Series):
        pass
    elif isinstance(y_true, np.ndarray) and isinstance(y_pred, np.ndarray):
        y_true = pd.Series(y_true)
        y_pred = pd.Series(y_pred)
    elif isinstance(y_true, list) and isinstance(y_pred, list):
        y_true = pd.Series(y_true)
        y_pred = pd.Series(y_pred)
    else:
        raise ValueError(
            "data should be a DataFrame or y_true and y_pred should be Series or list or numpy array"
        )

    plt_df = pd.DataFrame({"y_true": y_true, "y_pred": y_pred}).dropna()
    try:
        plt_df["y_pred_bins"] = pd.qcut(
            plt_df["y_pred"],
            k,
            labels=[f"{i:.0f}%" for i in (np.linspace(0, 1, k + 1) * 100)[1:]],
        )
    except ValueError:
        raise ValueError("input data have many values are same and cannot be binned")
    if not n_resample:
        plt_df_group = (
            plt_df.groupby("y_pred_bins")
            .apply(lambda x: pd.Series({"mean_true": x.y_true.mean()}))
            .reset_index(drop=False)
        )
    else:

        # 定义一个函数来计算均值
        def mean_bootstrap(data):
            # 使用bootstrap计算均值的置信区间
            res = bootstrap(data=(data,), statistic=np.mean, n_resamples=n_resample)

            return (
                np.mean(data),
                res.confidence_interval.low,
                res.confidence_interval.high,
            )

        # 对每个分位数进行bootstrap抽样

        plt_df_group = (
            plt_df.groupby("y_pred_bins")
            .apply(
                lambda x: pd.Series(
                    list(mean_bootstrap(x["y_true"])) + [x["y_pred"].mean()],
                    index=["mean_true", "ci_low", "ci_high", "mean_pred"],
                ).T
            )
            .reset_index(drop=False)
        )

    return plt_df_group


def get_calibration_df(
    data,
    obs,
    pred,
    followup=None,
    group=None,
    n_bins=10,
):
    """
    TODO: Add iris as an example
    """
    data = data.copy()

    if followup is None:
        followup = "followup"
        data[followup] = 1

    if group is not None:

        data = data.groupby(group).apply(
            lambda x: x.assign(decile=pd.qcut(x[pred], n_bins, labels=False))
        )
        data = (
            data.groupby([group, "decile"])
            .apply(
                lambda x: pd.Series(
                    {
                        "obsRate": (x[obs] / x[followup]).mean(),
                        "obsRate_SE": (x[obs] / x[followup]).std() / np.sqrt(len(x)),
                        "obsNo": x[obs].sum(),
                        "predMean": x[pred].mean(),
                    }
                )
            )
            .reset_index()
        )
    else:
        data = data.assign(decile=pd.qcut(data[pred], n_bins, labels=False))
        data = (
            data.groupby("decile")
            .apply(
                lambda x: pd.Series(
                    {
                        "obsRate": (x[obs] / x[followup]).mean(),
                        "obsRate_SE": (x[obs] / x[followup]).std() / np.sqrt(len(x)),
                        "obsNo": x[obs].sum(),
                        "predMean": x[pred].mean(),
                    }
                )
            )
            .reset_index()
        )
    data["obsRate_UCI"] = np.clip(
        data["obsRate"] + 1.96 * data["obsRate_SE"], a_max=1, a_min=None
    )
    data["obsRate_LCI"] = np.clip(
        data["obsRate"] - 1.96 * data["obsRate_SE"], a_min=0, a_max=None
    )
    return data


def calibration_score(
    raw_train_pred,
    raw_test_pred,
    train_y,
    method="isotonic",
    return_model=False,
    need_scale=True,
):
    """
    TODO: Add iris as an example
    """
    if method == "isotonic":
        model = IsotonicRegression(out_of_bounds="clip")

        if need_scale:
            model = Pipeline([("scaler", StandardScaler()), ("model", model)])
        else:
            model = Pipeline([("model", model)])

        model.fit(raw_train_pred, train_y)

        pred_train_calibrated = model.predict(raw_train_pred)
        pred_test_calibrated = model.predict(raw_test_pred)
    elif method == "logitstic":
        model = LogisticRegression(
            # class_weight="balanced",
            max_iter=5000,
            random_state=1,
        )
        if need_scale:
            model = Pipeline([("scaler", StandardScaler()), ("model", model)])
        else:
            model = Pipeline([("model", model)])

        raw_train_pred = (
            raw_train_pred.values
            if isinstance(raw_train_pred, pd.Series)
            else raw_train_pred
        )
        raw_test_pred = (
            raw_test_pred.values
            if isinstance(raw_test_pred, pd.Series)
            else raw_test_pred
        )
        model.fit(raw_train_pred.reshape(-1, 1), train_y)
        pred_train_calibrated = model.predict_proba(raw_train_pred.reshape(-1, 1))[:, 1]
        pred_test_calibrated = model.predict_proba(raw_test_pred.reshape(-1, 1))[:, 1]
    else:
        raise ValueError("method should be isotonic or logitstic")
    if return_model:
        return pred_train_calibrated, pred_test_calibrated, model
    else:
        return pred_train_calibrated, pred_test_calibrated


def get_predict_v2_from_df(
    model,
    data,
    x_var,
):
    """
    merge by idx
    TODO: Add iris as an example

    """

    no_na_data = data[x_var].dropna().copy()
    if hasattr(model, "predict_proba"):
        no_na_data["pred"] = model.predict_proba(no_na_data)[:, 1]
    else:
        no_na_data["pred"] = model.predict(no_na_data)

    return (
        data[[]]
        .merge(no_na_data[["pred"]], left_index=True, right_index=True, how="left")
        .values.flatten()
    )


class DiseaseScoreModel_V2:
    def __init__(
        self,
        disease_df,
        model_table,
        label,
        disease_name=None,
        train_eid=None,
        test_eid=None,
        eid="eid",
        other_keep_cols=None,
        E=None,
        T=None,
        test_size=0.2,
    ):
        """
            TODO: Add iris as an example

            meta_index_col is the index to record model summary information by a structure of DataFrame

                model_config:{
                    "AgeSex": {
                        "xvar":["age", "sex"]
                        }
                    "KidneyImage": {
                        "xvar":KidneyImage
                        "model":a function accept (train= train, test=test,xvar, y, **kwargs) and return (model, *others)
                        "config":{
                            "cv":5
                            ...
                        } # other config
                        }
        }
                }
                
        model_table:

                          param                                           \
                           xvar                                    model   
combination model                                                          
PANEL       Lasso    [Age, Sex]  <function run_glmnet at 0x7ff7aa182f80>   
            xgboost  [Age, Sex]                                      NaN   
AgeSex      xgboost  [Age, Sex]                                      NaN   

                                
                        config  
combination model               
PANEL       Lasso    {'cv': 6}  
            xgboost        NaN  
AgeSex      xgboost        NaN  


    
                other_keep_cols: other columns to keep in the final dataframe
                eid : the column name of the unique identifier

        """
        self.disease_df = disease_df

        # step1 split data; can be down by train_eid, test_eid or random split or user run train_test_split
        if train_eid is not None:
            self.train = disease_df.query(f"{eid} in @train_eid")
        if test_eid is not None:
            self.test = disease_df.query(f"{eid} in @test_eid")
        if train_eid is None and test_eid is None:
            logging.warning(f"Random split data with test_size: {test_size:.2f}")
            self.train, self.test = train_test_split(disease_df, test_size=test_size)

        self.train.reset_index(drop=True, inplace=True)
        self.test.reset_index(drop=True, inplace=True)

        self.label = label
        self.disease_name = disease_name or label
        self.eid = eid

        self.other_keep_cols = other_keep_cols if other_keep_cols else []

        logging.info(
            f"Loading data with train cases {self.train[label].sum()} and test cases {self.test[label].sum()} of {self.disease_name}, while {len(self.train.columns)} columns"
        )

        # E and T for cox model or C-index
        self.E = E
        self.T = T
        if self.E and (self.E != self.label):
            self.other_keep_cols.append(self.E)
        if self.T:
            self.other_keep_cols.append(self.T)
        ## drop na by label, E and T
        if self.E and self.T:
            self.train.dropna(subset=[self.label, self.E, self.T], inplace=True)
            self.test.dropna(subset=[self.label, self.E, self.T], inplace=True)
        else:
            self.train.dropna(subset=[self.label], inplace=True)
            self.test.dropna(subset=[self.label], inplace=True)

        logging.info(
            f"Drop NA by {self.label} and {self.E} and {self.T} in train and test and left {len(self.train)} and {len(self.test)} with train cases {self.train[self.label].sum()} and test cases {self.test[self.label].sum()}"
        )

        # step2 update information to score_dict
        # step2 save all infomation on a dataframe

        self.model_table = model_table.copy()

        self.train_score, self.test_score = (
            self.train[[self.eid, self.label, *self.other_keep_cols]].copy(),
            self.test[[self.eid, self.label, *self.other_keep_cols]].copy(),
        )

        # # keep the fitted model
        # self.fitted_model_dict = OrderedDict()

        # # keep the metrics
        # self.metrics_dict = {}

    # def get_metrics_df(self):
    #     c_index_df_list = []
    #     auc_df_list = []

    #     for score_name, metrics in self.metrics_dict.items():
    #         c_index = metrics.get("c_index", None)
    #         auc_metrics = metrics.get("auc_metrics", None)

    #         if c_index is not None:
    #             c_index_df_list.append(c_index)
    #         if auc_metrics is not None:
    #             auc_df_list.append(auc_metrics)
    #     c_index_df = pd.concat(c_index_df_list)
    #     auc_df = pd.DataFrame(auc_df_list)
    #     return c_index_df, auc_df

    # def re_cal_metrics(self):
    #     """
    #     re-calculate the metrics only AUC and C-Index
    #     """
    #     for combination_name in self.fitted_model_dict.keys():
    #         # cal metrics
    #         need_cols = [self.label, combination_name]

    #         ## E may equal to T
    #         if self.E and self.T:
    #             if self.E != self.label:
    #                 need_cols.append(self.E)
    #             need_cols.append(self.T)

    #         to_cal_df = self.test_score[need_cols].copy().dropna()

    #         c_index = run_cox(
    #             to_cal_df, var=combination_name, E=self.E, T=self.T, ci=True
    #         )
    #         auc_metrics = cal_binary_metrics(
    #             to_cal_df[self.label], to_cal_df[combination_name], ci=True
    #         )

    #         self.metrics_dict[combination_name] = {
    #             "c_index": c_index,
    #             "auc_metrics": auc_metrics,
    #         }

    def update_model(self, new_model_table=None, duplicate_replace=False):
        """
        fit the model with the new model_config, or
        """
        # update the model_config
        if new_model_table is not None:
            self.model_table = update_concat_df(
                self.model_table,
                new_model_table,
                duplicate_replace=duplicate_replace,
            )

        # fit model by model_table
        # fitted model will add a status to show whether the model is fitted

        for name, model_table_row in self.model_table.iterrows():
            # unpack the params
            params = model_table_row["param"].to_dict()

            # get xvar
            xvar = params["xvar"]

            # get model
            if "model" in params:
                model_fn = params["model"]
            else:
                logging.warning(
                    f"model function not found in {name}, use default glmnet to run lasso"
                )
                model_fn = run_glmnet

            # get model config
            model_fn_config = params.get("config", {})
            if pd.isna(model_fn_config):
                model_fn_config = {}

            # get score name alias or
            score_name = params.get(
                "score_name", name if isinstance(name, str) else "_".join(name)
            )

            # fit the model
            self.fit(
                xvar=xvar,
                name=name,
                score_name=score_name,
                model_fn=model_fn,
                **model_fn_config,
            )

    def fit(self, xvar, name, score_name, model_fn=run_glmnet, **model_fn_config):
        """
        fit the model with the combination

        """

        # step1 check whether the model is already fitted by name have ("status", "fitted") value
        try:
            status_fitted = self.model_table.loc[name, ("status", "fitted")]
        except KeyError:
            status_fitted = False

        if status_fitted == 1:
            logging.warning(f"{name} already fitted, will skip it")
            return

        # step2 fit the model

        model, *_ = model_fn(
            train=self.train,
            test=self.test,
            xvar=xvar,
            label=self.label,
            **model_fn_config,
        )

        # TODO: use model.predict(model=model, data=self.train, xvar = combination) to replace the following
        self.train_score[score_name] = get_predict_v2_from_df(model, self.train, xvar)
        self.test_score[score_name] = get_predict_v2_from_df(model, self.test, xvar)
        ## add the score into train_score
        self.train[score_name] = get_predict_v2_from_df(model, self.train, xvar)
        self.test[score_name] = get_predict_v2_from_df(model, self.test, xvar)

        # cal metrics
        need_cols = [self.label, score_name]

        ## E may equal to T
        if self.E and self.T:
            if self.E != self.label:
                need_cols.append(self.E)
            need_cols.append(self.T)

        to_cal_df = self.test_score[need_cols].copy().dropna()

        # zscore for correct OR and HR
        to_cal_df_train = self.train_score[need_cols].copy().dropna()
        train_mean = to_cal_df_train[score_name].mean()
        train_std = to_cal_df_train[score_name].std()

        to_cal_df[score_name] = (to_cal_df[score_name] - train_mean) / train_std

        # cal c
        if self.E and self.T:
            c_index_df = run_cox(to_cal_df, var=score_name, E=self.E, T=self.T, ci=True)
            c_index_dict = c_index_df.iloc[0].T.to_dict()
            for metric_name, metric_value in c_index_dict.items():
                self.model_table.loc[name, ("c_index", metric_name)] = metric_value

        # cal auc
        auc_metrics_dict = cal_binary_metrics(
            to_cal_df[self.label], to_cal_df[score_name], ci=True
        )
        for metric_name, metric_value in auc_metrics_dict.items():
            self.model_table.loc[name, ("auc", metric_name)] = metric_value

        # update model into model_table
        self.model_table.loc[name, ("status", "fitted")] = 1

        self.model_table.loc[name, ("model", "model")] = model
        self.model_table.loc[name, ("basic", "score_name")] = score_name

    def calibrate(self, method="logitstic"):
        """
        calibrate the score

        """
        # check fitted status
        self._check_status()

        self.train_score_calibrated, self.test_score_calibrated = (
            self.train[[self.eid, self.label, *self.other_keep_cols]].copy(),
            self.test[[self.eid, self.label, *self.other_keep_cols]].copy(),
        )

        # for score_name, score_model_config in self.model_config.items():
        for name, score_name in self.model_table[("basic", "score_name")].items():

            from ppp_prediction.calibration import calibrate

            raw_train_score = self.train_score[[self.label, score_name]].dropna()
            raw_test_score = self.test_score[[self.label, score_name]].dropna()

            calibrated_object = calibrate(
                X_train=raw_train_score[score_name],
                X_test=raw_test_score[score_name],
                y_train=raw_train_score[self.label],
                y_test=raw_test_score[self.label],
                n_bins=10,
                need_scale=True,
            )

            calibration_model = calibrated_object["best_clf"]

            # TODO: use model.predict(model=model, data=self.train, xvar = combination) to replace the following
            self.train_score_calibrated[score_name] = get_predict_v2_from_df(
                calibration_model, self.train_score, [score_name]
            )
            self.test_score_calibrated[score_name] = get_predict_v2_from_df(
                calibration_model, self.test_score, [score_name]
            )

            self.model_table.loc[name, ("model", "calibrated_model")] = (
                calibration_model
            )

    def get_score_names(self):
        # return list(self.fitted_model_dict.keys())
        return self.model_table[("basic", "score_name")].values.tolist()

    def get_score_names_df(self):

        # TODO: if level is more than 2, may have problem
        return (
            self.model_table[[("basic", "score_name")]]
            .copy()
            .droplevel(0, axis=1)
            .reset_index()
        )  #

    def set_color_set(self, colorset=None):
        # self.color
        if colorset is None:
            colorset = list(sns.color_palette("tab20").as_hex())

        self.method_colorset = {k: v for k, v in zip(self.get_score_names(), colorset)}

    @property
    def color_set(self):
        if not hasattr(self, "method_colorset"):
            self.set_color_set()
        return self.method_colorset

    def get_metrics_by_user_multi(
        self, metrics_dict=None, use_calibrate=False, **kwargs
    ):
        """
        metrics_dict: a dict with key as the metrics_name and value as the metrics_fn
        """
        metrics_list = []
        for metrics_name, metrics_fn in metrics_dict.items():
            metrics_df = self.get_metrics_by_user(
                metrics_fn,
                metrics_name=metrics_name,
                use_calibrate=use_calibrate,
                **kwargs,
            )
            metrics_list.append(metrics_df)

        return reduce(lambda x, y: pd.merge(x, y), metrics_list)

    def get_metrics_by_user(
        self, metrics_fn, metrics_name=None, use_calibrate=False, **kwargs
    ):
        """
        metrics_fn: a function accept (y_true, y_prob, other_kwargs) and return a dict; note the first pos will be the label and the second pos will be the score
        """
        metrics_name = metrics_name or metrics_fn.__name__

        metrics_list = []
        for row_idx, row in self.get_score_names_df().iterrows():
            row_dict = row.to_dict()
            score_name = row_dict["score_name"]

            if use_calibrate:

                to_cal_df = self.test_score_calibrated[
                    [self.label, score_name]
                ].dropna()

            else:
                to_cal_df = self.test_score[[self.label, score_name]].dropna()

            metrics_score = metrics_fn(
                to_cal_df[self.label],
                to_cal_df[score_name],
                **kwargs,
            )
            if isinstance(metrics_score, dict):
                metrics_list.append(
                    {
                        **row_dict,
                        **metrics_score,
                    }
                )
                logging.info(
                    f"metrics {metrics_name} return a dict, will unpack it to the dataframe"
                )
            else:
                metrics_list.append(
                    {
                        **row_dict,
                        metrics_name: metrics_score,
                    }
                )

        return pd.DataFrame(metrics_list)

    def _check_status(
        self,
    ):
        if "status" in self.model_table.columns.get_level_values(0):

            if self.model_table[("status", "fitted")].isna().any():
                error_flag = True
            else:
                error_flag = False

        else:
            error_flag = False

        if error_flag:
            raise ValueError(f"model not fitted, run update_model first")
        else:
            return

    @property
    def brier_score(self):
        if not hasattr(self, "train_score_calibrated"):
            logging.warning("No calibrated model fitted, run calibrate first")
            return
        return self.get_metrics_by_user(
            brier_score_loss, use_calibrate=True, metrics_name="brier_score"
        )

    # TODO：画图呈现逻辑，现在默认不用分面，全部都以score_name做

    def calibration_plot(self, n_bins=10, return_df=False, by="test", facet_fn=None):

        # check fitted status
        self._check_status()

        if not hasattr(self, "train_score_calibrated"):
            logging.warning("No calibrated model fitted, run calibrate first")
            return
        if by == "test":
            by_data = self.test_score_calibrated
        elif by == "train":
            by_data = self.train_score_calibrated
        elif by == "all":
            by_data = pd.concat(
                [self.test_score_calibrated, self.train_score_calibrated]
            )
        else:
            raise ValueError("by should be test, train or all")
        # get calibration plot raw df
        calibration_df_list = []

        for row_idx, row in self.get_score_names_df().iterrows():
            row_dict = row.to_dict()
            score_name = row_dict["score_name"]

            score_calibrated = by_data[[self.label, score_name]].dropna().copy()

            c_calibration_df = get_calibration_df(
                data=score_calibrated,  # use train to test
                obs=self.label,
                pred=score_name,
                n_bins=n_bins,
            )

            # assign others
            for k, v in row_dict.items():
                c_calibration_df[k] = v

            c_calibration_df = c_calibration_df.set_index(
                list(row_dict.keys())
            ).reset_index()

            calibration_df_list.append(c_calibration_df)

        calibration_df = pd.concat(calibration_df_list)

        lim_bound = max(
            calibration_df["obsRate"].max(), calibration_df["predMean"].max()
        )

        # TODO: 统一绘图风格 theme
        p = ggplot(
            data=calibration_df,
            mapping=aes(x="predMean", y="obsRate", color="score_name"),
        )
        if facet_fn:
            p = p + facet_fn

        p = (
            p
            + geom_point(alpha=0.8, size=3)
            + geom_line(alpha=0.8)
            # + geom_line()
            + geom_abline(intercept=0, slope=1, linetype="dashed")
            + theme_classic(base_family="Calibri", base_size=12)  # 使用Tufte主题
            + theme(axis_line=element_line())
            + theme(
                figure_size=(12, 12),
                legend_position="top",
                axis_text_x=element_text(angle=90),
                strip_background=element_blank(),
                axis_text=element_text(size=12),  # 调整轴文字大小
                axis_title=element_text(size=14),  # 调整轴标题大小和样式
                legend_title=element_text(size=14),  # 调整图例标题大小和样式
                legend_text=element_text(),  # 调整图例文字大小
                strip_text=element_text(size=14),  # 调整分面标签的大小和样式
                plot_title=element_text(size=16, hjust=0.5),  # 添加图表标题并居中
                # plot_margin = margin(10, 10, 10, 10)  # 设置图表边距
            )
            + scale_color_manual(values=self.color_set)
            + labs(
                x="Predicted risk",
                y="Observed risk",
                title="Calibration plot",
                color="Score",
            )
            + coord_cartesian(xlim=(0, lim_bound), ylim=(0, lim_bound))
        )
        if return_df:
            return p, calibration_df
        else:
            return p

    # TODO:接口统一
    def plot_dca(self, return_df=False, by="test"):

        self._check_status()

        if not hasattr(self, "train_score_calibrated"):
            logging.warning("No calibrated model fitted, run calibrate first")
            return

        if by == "test":
            by_data = self.test_score_calibrated
        elif by == "train":
            by_data = self.train_score_calibrated
        elif by == "all":
            by_data = pd.concat(
                [self.test_score_calibrated, self.train_score_calibrated]
            )
        else:
            raise ValueError("by should be test, train or all")
        # TODO: update to new code of get_dca_df
        test = by_data[[self.label, *self.get_score_names()]].dropna().copy()
        event_rate = test[self.label].sum() / len(test)
        dca_df = dca(
            data=test,
            outcome=self.label,
            modelnames=self.get_score_names(),
            thresholds=np.linspace(0, event_rate, 1000),
        )
        dca_df["st_net_benefit"] = dca_df["net_benefit"] / event_rate
        dca_df["disease"] = self.disease_name

        # TODO: 统一绘图风格 theme; by another function
        # from dca_df
        p = (
            ggplot(
                data=dca_df,
                mapping=aes(x="threshold", y="st_net_benefit", color="score_name"),
            )
            # + facet_wrap("disease", scales="free")
            + geom_line()
            + ylim(0, 1)
            + theme_classic(base_family="Calibri", base_size=12)  # 使用Tufte主题
            + theme(axis_line=element_line())
            + theme(
                figure_size=(12, 12),
                legend_position="top",
                axis_text_x=element_text(angle=90),
                strip_background=element_blank(),
                axis_text=element_text(size=12),  # 调整轴文字大小
                axis_title=element_text(size=14),  # 调整轴标题大小和样式
                legend_title=element_text(size=14),  # 调整图例标题大小和样式
                legend_text=element_text(),  # 调整图例文字大小
                strip_text=element_text(size=14),  # 调整分面标签的大小和样式
                plot_title=element_text(size=16, hjust=0.5),  # 添加图表标题并居中
                # plot_margin = margin(10, 10, 10, 10)  # 设置图表边距
            )
            # + scale_color_manual(values=c_color_dict)
        )

        if return_df:
            return p, dca_df
        else:
            return p

    def plot_auc(
        self,
        return_df=False,
        by="test",
    ):
        self._check_status()

        if by == "test":
            by_data = self.test_score
        elif by == "train":
            by_data = self.train_score
        elif by == "all":
            by_data = pd.concat([self.test_score, self.train_score])
        else:
            raise ValueError("by should be test, train or all")

        # get auc famhistory_df_list
        auc_df_list = []

        for row_idx, row in self.get_score_names_df().iterrows():
            row_dict = row.to_dict()
            score_name = row_dict["score_name"]

            to_cal_df = by_data[[self.label, score_name]].dropna()
            fpr, tpr, _ = roc_curve(to_cal_df[self.label], to_cal_df[score_name])
            roc_current_df = pd.DataFrame(
                [
                    {
                        "fpr": fpr_,
                        "tpr": tpr_,
                    }
                    for fpr_, tpr_ in zip(fpr, tpr)
                ]
            )
            for k, v in row_dict.items():
                roc_current_df[k] = v
            roc_current_df = roc_current_df.set_index(
                list(row_dict.keys())
            ).reset_index()

            auc_df_list.append(roc_current_df)
        auc_df = pd.concat(auc_df_list)

        # TODO: 统一绘图风格 theme
        # from auc_df

        p = (
            ggplot(
                data=auc_df,
                mapping=aes(x="fpr", y="tpr", color="score_name"),
            )
            + geom_line()
            + geom_abline(intercept=0, slope=1, linetype="dashed")
            + theme_classic(base_family="Calibri", base_size=12)  # 使用Tufte主题
            + theme(axis_line=element_line())
            + theme(
                figure_size=(12, 12),
                legend_position="top",
                axis_text_x=element_text(angle=90),
                strip_background=element_blank(),
                axis_text=element_text(size=12),  # 调整轴文字大小
                axis_title=element_text(size=14),  # 调整轴标题大小和样式
                legend_title=element_text(size=14),  # 调整图例标题大小和样式
                legend_text=element_text(),  # 调整图例文字大小
                strip_text=element_text(size=14),  # 调整分面标签的大小和样式
                plot_title=element_text(size=16, hjust=0.5),  # 添加图表标题并居中
                # plot_margin = margin(10, 10, 10, 10)  # 设置图表边距
            )
            + scale_color_manual(values=self.color_set)
            + labs(
                x="1 - Specificity",
                y="Sensitivity",
                title="ROC curve",
                color="score_name",
            )
        )
        if return_df:
            return p, auc_df
        else:
            return p

    def plot_risk_strat(
        self,
        return_df=False,
        by="test",
        facet=False,
        k=10,
        show_ci=True,
        n_resample=100,
    ):
        self._check_status()

        if by == "test":
            by_data = self.test_score
        elif by == "train":
            by_data = self.train_score
        elif by == "all":
            by_data = pd.concat([self.test_score, self.train_score])
        else:
            raise ValueError("by should be test, train or all")

        # get risk_strat_df
        risk_strat_df_list = []
        # for score_name in self.get_score_names():
        for row_idx, row in self.get_score_names_df().iterrows():
            row_dict = row.to_dict()
            score_name = row_dict["score_name"]

            risk_strat_df = get_risk_strat_df(
                data=by_data.copy(),
                y_true=self.label,
                y_pred=score_name,
                k=k,
                n_resample=n_resample,
            )

            for k, v in row_dict.items():
                risk_strat_df[k] = v
            risk_strat_df = risk_strat_df.set_index(list(row_dict.keys())).reset_index()

            # risk_strat_df["model"] = score_name
            # risk_strat_df["disease"] = self.disease_name
            risk_strat_df_list.append(risk_strat_df)
        risk_strat_df = pd.concat(risk_strat_df_list)

        # TODO: 统一绘图风格 theme
        # from risk_strat_df

        dodge_width = 0.6
        p = ggplot(
            data=risk_strat_df,
            mapping=aes(x="y_pred_bins", y="mean_true", color="score_name"),
        )
        if facet:
            p = p + facet_wrap("model", scales="free_y")

        p = p + geom_point(
            alpha=0.8,
            size=2,
            position=position_dodge(width=dodge_width),
            na_rm=True,
        )

        if show_ci:
            p = p + geom_linerange(
                mapping=aes(ymin="ci_low", ymax="ci_high"),
                size=1,
                alpha=0.8,
                position=position_dodge(width=dodge_width),
                na_rm=True,
            )

        p = (
            p
            + theme_classic(base_family="Calibri", base_size=12)  # 使用Tufte主题
            + theme(axis_line=element_line())
            + theme(
                figure_size=(10, 5),
                legend_position="top",
                axis_text_x=element_text(angle=90),
                strip_background=element_blank(),
                axis_text=element_text(size=12),  # 调整轴文字大小
                axis_title=element_text(size=14),  # 调整轴标题大小和样式
                legend_title=element_text(size=14),  # 调整图例标题大小和样式
                legend_text=element_text(),  # 调整图例文字大小
                strip_text=element_text(size=14),  # 调整分面标签的大小和样式
                plot_title=element_text(size=16, hjust=0.5),  # 添加图表标题并居中
                # plot_margin = margin(10, 10, 10, 10)  # 设置图表边距
            )
            + guides(color=guide_legend(nrow=1, title=""))
            + scale_color_manual(values=self.color_set)
            + labs(
                x="Risk Decile",  # 设置X轴标签
                y="Observed Events Rate",  # 设置Y轴标签
                # color="group",  # 设置图例标题
                # title="",  # 添加图表标题
            )
            # + coord_flip()
        )

        if return_df:
            return p, risk_strat_df
        else:
            return p

    def compare_model(self, compare_list, by="test", ci=True, n_resample=100):
        """
        [
        (ref1, new1)
        (ref2, new2)
        ]
        """
        if by == "test":
            by_data = self.test_score
        elif by == "train":
            by_data = self.train_score
        elif by == "all":
            by_data = pd.concat([self.test_score, self.train_score])
        else:
            raise ValueError("by should be test, train or all")

        compare_result_list = []
        for ref, new in compare_list:
            to_cal_df = by_data[[self.label, ref, new]].dropna().copy()

            total = {}

            total["ref"] = ref
            total["new"] = new
            total["disease"] = self.disease_name

            # NRI
            NRI_res = NRI(
                to_cal_df[self.label],
                to_cal_df[ref],
                to_cal_df[new],
                ci=ci,
                n_resamples=n_resample,
            )
            total.update(NRI_res)

            # IDI
            IDI_res = IDI(
                to_cal_df[self.label],
                to_cal_df[ref],
                to_cal_df[new],
                ci=ci,
                n_resamples=n_resample,
            )
            total.update(IDI_res)

            # AUC diff
            auc_diff_res = roc_test(
                to_cal_df[self.label], to_cal_df[ref], to_cal_df[new]
            )
            total.update(auc_diff_res)

            # C diff
            if self.E and self.T:
                c_diff_res = compareC(
                    to_cal_df[self.T],
                    to_cal_df[self.label],
                    to_cal_df[ref],
                    to_cal_df[new],
                )
                total.update(c_diff_res)

            compare_result_list.append(total)
        return pd.DataFrame(compare_result_list)

    def plot_performance(
        self,
        metric="auc",
        # or
        metrics_fn=None,
        metrics_name=None,
        # return
        return_df=False,
        **kwargs,
    ):
        """
        if metric is a function, then use it to calculate the metrics; works like `get_metrics_by_user`
        """
        # get metrics_df

        if metric == "c_index":
            plt_data = (
                self.model_table[["basic", metric]]
                .copy()
                .droplevel(0, axis=1)
                .reset_index()
            )
            y = "c_index"
            y_LCI = "c_index_LCI"
            y_UCI = "c_index_UCI"
            y_name = "C-index"

        elif metric == "auc":
            plt_data = (
                self.model_table[["basic", metric]]
                .copy()
                .droplevel(0, axis=1)
                .reset_index()
            )
            y = "AUC"
            y_LCI = "AUC_LCI"
            y_UCI = "AUC_UCI"
            y_name = "AUC"

        elif metric == "brier_score":
            plt_data = (
                self.model_table[["basic", metric]]
                .copy()
                .droplevel(0, axis=1)
                .reset_index()
            )
            y = "brier_score"
            y_LCI = None
            y_UCI = None

            y_name = "Brier Score"
        elif metric is None and metrics_fn is not None:
            if metrics_name is None:
                if isinstance(metrics_fn, partial):
                    raise ValueError(
                        "metrics_name should be provided when metrics_fn is a functools.partial"
                    )
                metrics_name = metrics_fn.__name__

            use_calibrate = kwargs.pop("use_calibrate", False)
            plt_data = self.get_metrics_by_user(
                metrics_fn, metrics_name=metrics_name, use_calibrate=use_calibrate
            )
            y = metrics_name
            if y not in plt_data.columns:
                raise ValueError(
                    f"metrics_name {metrics_name} not found in the metrics_df, there are {plt_data.columns}"
                )
            if f"{y}_LCI" in plt_data.columns:
                y_LCI = f"{y}_LCI"
                y_UCI = f"{y}_UCI"
            else:
                y_LCI = y_UCI = None
            y_name = metrics_name
        else:
            raise ValueError("metric should be c_index or auc")
        p = (
            ggplot(
                data=plt_data,
                mapping=aes(x="score_name", y=y, color="score_name"),
            )
            # + facet_wrap("disease", scales="free_y")
            + geom_point(alpha=0.8, size=3, position=position_dodge(width=0.5))
        )
        if y_LCI is not None:
            p = p + geom_linerange(
                mapping=aes(ymin=y_LCI, ymax=y_UCI),
                size=1,
                alpha=0.8,
                position=position_dodge(width=0.5),
            )
        p = (
            p
            + theme_classic(base_family="Calibri", base_size=12)  # 使用Tufte主题
            + theme(axis_line=element_line())
            + theme(
                figure_size=(12, 6),
                legend_position="none",
                axis_text_x=element_text(angle=90),
                strip_background=element_blank(),
                axis_text=element_text(size=12),  # 调整轴文字大小
                axis_title=element_text(size=14),  # 调整轴标题大小和样式
                legend_title=element_text(size=14),  # 调整图例标题大小和样式
                legend_text=element_text(),  # 调整图例文字大小
                strip_text=element_text(size=14),  # 调整分面标签的大小和样式
                plot_title=element_text(size=16, hjust=0.5),  # 添加图表标题并居中
                # plot_margin = margin(10, 10, 10, 10)  # 设置图表边距
            )
            # + guides(color=False)
            # + scale_color_manual(values=colorset)
            + scale_color_manual(values=self.color_set)
            + labs(
                x="Method",  # 设置X轴标签
                # y="C-index",  # 设置Y轴标签
                y=y_name,
                # color="Method",  # 设置图例标题
                title="Comparison of Methods",  # 添加图表标题
            )
            # + coord_flip()
        )
        if return_df:
            return p, plt_data
        else:
            return p


# save_fig(

In [4]:
# load test dataset

import sklearn

# return dataframe
X_df = sklearn.datasets.load_breast_cancer(as_frame=True)["data"]
y_df = sklearn.datasets.load_breast_cancer(as_frame=True)["target"]

df = X_df.join(y_df).reset_index(drop=False, names=["eid"])
df

Unnamed: 0,eid,mean radius,mean texture,mean perimeter,mean area,mean smoothness,mean compactness,mean concavity,mean concave points,mean symmetry,...,worst texture,worst perimeter,worst area,worst smoothness,worst compactness,worst concavity,worst concave points,worst symmetry,worst fractal dimension,target
0,0,17.99,10.38,122.80,1001.0,0.11840,0.27760,0.30010,0.14710,0.2419,...,17.33,184.60,2019.0,0.16220,0.66560,0.7119,0.2654,0.4601,0.11890,0
1,1,20.57,17.77,132.90,1326.0,0.08474,0.07864,0.08690,0.07017,0.1812,...,23.41,158.80,1956.0,0.12380,0.18660,0.2416,0.1860,0.2750,0.08902,0
2,2,19.69,21.25,130.00,1203.0,0.10960,0.15990,0.19740,0.12790,0.2069,...,25.53,152.50,1709.0,0.14440,0.42450,0.4504,0.2430,0.3613,0.08758,0
3,3,11.42,20.38,77.58,386.1,0.14250,0.28390,0.24140,0.10520,0.2597,...,26.50,98.87,567.7,0.20980,0.86630,0.6869,0.2575,0.6638,0.17300,0
4,4,20.29,14.34,135.10,1297.0,0.10030,0.13280,0.19800,0.10430,0.1809,...,16.67,152.20,1575.0,0.13740,0.20500,0.4000,0.1625,0.2364,0.07678,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
564,564,21.56,22.39,142.00,1479.0,0.11100,0.11590,0.24390,0.13890,0.1726,...,26.40,166.10,2027.0,0.14100,0.21130,0.4107,0.2216,0.2060,0.07115,0
565,565,20.13,28.25,131.20,1261.0,0.09780,0.10340,0.14400,0.09791,0.1752,...,38.25,155.00,1731.0,0.11660,0.19220,0.3215,0.1628,0.2572,0.06637,0
566,566,16.60,28.08,108.30,858.1,0.08455,0.10230,0.09251,0.05302,0.1590,...,34.12,126.70,1124.0,0.11390,0.30940,0.3403,0.1418,0.2218,0.07820,0
567,567,20.60,29.33,140.10,1265.0,0.11780,0.27700,0.35140,0.15200,0.2397,...,39.42,184.60,1821.0,0.16500,0.86810,0.9387,0.2650,0.4087,0.12400,0


In [None]:
features = df.columns[1:-1].tolist()
combination_dict = OrderedDict(
    {
        ("Lasso"): {
            "xvar": features,
            "model": fit_best_model_v2,
            "config": {"cv": 10, "engine": "sklearn"},
        },
        ("xgboost"): {
            "xvar": features,
            "model": fit_xgboost,
        },
        ("lightGBM"): {
            "xvar": features,
            "model": fit_lightgbm,
        },
        ("TabPFN"): {
            "xvar": features,
            "model": fit_tabpfn,
        },
        ("TabNet"): {
            "xvar": features,
            "model": fit_tabnet,
        },
    }
)

model_table = config_dict_to_df(combination_dict, ("model"))
model_table

Unnamed: 0_level_0,param,param,param
Unnamed: 0_level_1,xvar,model,config
model,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2
Lasso,"[mean radius, mean texture, mean perimeter, me...",<function fit_best_model_v2 at 0x7f76c04b0430>,"{'cv': 10, 'engine': 'sklearn'}"
xgboost,"[mean radius, mean texture, mean perimeter, me...",<function fit_xgboost at 0x7f76b04a0a60>,
lightGBM,"[mean radius, mean texture, mean perimeter, me...",<function fit_lightgbm at 0x7f76b04a0b80>,
TabPFN,"[mean radius, mean texture, mean perimeter, me...",<function fit_tabpfn at 0x7f76b04a0ee0>,


In [6]:
train_eid, test_eid = train_test_split(df, test_size=0.2, random_state=42)

train_eid = train_eid[["eid"]]
test_eid = test_eid[["eid"]]
test_eid

Unnamed: 0,eid
204,204
70,70
131,131
431,431
540,540
...,...
486,486
75,75
249,249
238,238


In [37]:
targetModel = DiseaseScoreModel_V2(
    disease_df=df,
    model_table=model_table,
    label="target",
    disease_name="target",
    # test_size=0.5,
    train_eid=train_eid.eid,
    test_eid=test_eid.eid,
    other_keep_cols=[col for col in df.columns if col not in ["target", "eid"]],
)

INFO:root:Loading data with train cases 286 and test cases 71 of target, while 32 columns
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
INFO:root:Drop NA by target and None and None in train and test and left 455 and 114 with train cases 286 and test cases 71


In [248]:
targetModel.update_model()
targetModel.model_table

NameError: name 'targetModel' is not defined

In [249]:
targetModel.model_table

NameError: name 'targetModel' is not defined

In [7]:
from torch import Tensor
from torch.nn import Linear, Module, ModuleList

from torch_frame import TensorFrame, stype
from torch_frame.nn.conv import TabTransformerConv
from torch_frame.nn.encoder import (
    EmbeddingEncoder,
    LinearEncoder,
    StypeWiseFeatureEncoder,
)


class ExampleTransformer(Module):
    def __init__(
        self,
        channels,
        out_channels,
        num_layers,
        num_heads,
        col_stats,
        col_names_dict,
    ):
        super().__init__()
        self.encoder = StypeWiseFeatureEncoder(
            out_channels=channels,
            col_stats=col_stats,
            col_names_dict=col_names_dict,
            stype_encoder_dict={
                stype.categorical: EmbeddingEncoder(),
                stype.numerical: LinearEncoder(),
            },
        )
        self.convs = ModuleList(
            [
                TabTransformerConv(
                    channels=channels,
                    num_heads=num_heads,
                )
                for _ in range(num_layers)
            ]
        )
        self.decoder = Linear(channels, out_channels)

    def forward(self, tf: TensorFrame) -> Tensor:
        x, _ = self.encoder(tf)
        for conv in self.convs:
            x = conv(x)
        out = self.decoder(x.mean(dim=1))
        return out

In [176]:
total_data = pd.read_feather(
    "/mnt/d/桌面/work/AAA_lifeStyle/V3/output/raw_data/score_df.feather"
)
lifeStyle_cols = [
    "PhysicalActivity",
    "HealthyDiet",
    "Alcohol_consumption",
    "Sedentary_behaviour",
    "BMI",
    "SleepPattern",
    "SmokingStatus",
]
to_dummpy_cols = []
for col in total_data.columns:
    if total_data[col].dtype == "category":
        to_dummpy_cols.append(col)

to_dummpy_cols
# dummpy used the first category as the base category
# so we need to drop the first category to avoid multicollinearity
tmp = pd.DataFrame()
dummpy_cols = []
# for col in [
#     "Sex(M)",
#     "Ethnicity",
#     # "Educational_attainment",
#     # "Family_income",
#     # "Empolyment",
#     "antihypertensives",
#     "antihyperglycemic",
#     "lipid_lowering",
#     "History_of_Hypertension",
#     "History_of_Diabetes",
#     "familiy_history_heart_disease",
# ]:
lifeStyleDummyCols = []

for col in to_dummpy_cols:
    # dummy variable
    dummy = pd.get_dummies(
        total_data[col], prefix=col, drop_first=True
    )  # default version is True, False for test at 20250305
    for i in dummy.columns:
        dummy[i] = dummy[i].astype("int")
    dummpy_cols.extend(dummy.columns)
    if col in lifeStyle_cols:
        lifeStyleDummyCols.extend(dummy.columns)
    tmp = pd.concat([tmp, dummy], axis=1)
total_data = pd.concat([total_data, tmp], axis=1)
total_data
for col in dummpy_cols:
    total_data[col] = total_data[col].astype(float)
total_data["PRS"] = (total_data["PRS"] - total_data["PRS"].mean()) / total_data[
    "PRS"
].std()

In [278]:
RF = [
    "Age_at_recruitment",
    "Sex_M",
    "HbA1C",
    "LDL_cholesterol",
    "HDL_cholesterol",
    "Cholesterol",
    "Triglycerides",
    "SBP",
    "eGFR",
    "antihypertensives",
    "antihyperglycemic",
    "lipid_lowering",
]
lifeStyle_cols
features = lifeStyle_cols + RF + ["PRS"]

cat_cols = [
    *lifeStyle_cols,
    "Sex_M",
    "antihypertensives",
    "antihyperglycemic",
    "lipid_lowering",
    "incident",
]
qt_cols = [
    "Age_at_recruitment",
    "HbA1C",
    "LDL_cholesterol",
    "HDL_cholesterol",
    "Cholesterol",
    "Triglycerides",
    "SBP",
    "eGFR",
    "PRS",
]

In [279]:
total_data["AAA"]

0         Control
1         Control
2         Control
3         Control
4         Control
           ...   
435991    Control
435992    Control
435993        AAA
435994    Control
435995    Control
Name: AAA, Length: 435996, dtype: category
Categories (2, object): ['Control', 'AAA']

In [288]:
from torch_frame.data import Dataset
from torch_frame import TensorFrame, stype

# train_df = df.query("eid in @train_eid.eid")
# test_df = df.query("eid in @test_eid.eid")
# train_df

# split train, val and tes
# train_df, test_df = train_test_split(df, test_size=0.2, random_state=42)
# train_df, val_df = train_test_split(train_df, test_size=0.3, random_state=42)
train_df = total_data[total_data["Type"] == "Train"]
test_df = total_data[total_data["Type"] == "Test"]
train_df, val_df = train_test_split(train_df, test_size=0.3, random_state=42)

In [289]:
# train_df = train_df.groupby("AAA").sample(n=3000, replace=True)

In [282]:
# import argparse
# import math
# import os
# import os.path as osp
# import time
# from typing import Any, Optional

# import numpy as np
# import optuna
# import torch
# from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, Module, MSELoss
# from torch.optim.lr_scheduler import ExponentialLR
# from torchmetrics import AUROC, Accuracy, MeanSquaredError
# from tqdm import tqdm

# from torch_frame import stype
# from torch_frame.data import DataLoader
# from torch_frame.datasets import DataFrameBenchmark
# from torch_frame.gbdt import CatBoost, LightGBM, XGBoost
# from torch_frame.nn.encoder import EmbeddingEncoder, LinearBucketEncoder
# from torch_frame.nn.models import (
#     MLP,
#     ExcelFormer,
#     FTTransformer,
#     ResNet,
#     TabNet,
#     TabTransformer,
#     Trompt,
# )
# from torch_frame.typing import TaskType

# # class Args:
# #     def __init__(self, **kwargs):
# #         # 使用字典存储键值对
# #         self.__dict__.update(kwargs)

# # # 创建 Args 对象

# # args = Args(
# #     model_type = "TabNet", # TabNet, TabTransformer, ExcelFormer, MLP, ResNet, Trompt, LightGBM, CatBoost, XGBoost
# #     task_type = "binary_classification", # binary_classification, multiclass_classification, regression
# #     scale = "small",
# #     idx = 0,
# # )

# model_type = "TabNet"
# train_dataset = TrainDataSet
# test_dataset = TestDataSet
# val_dataset = ValDataSet
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# TRAIN_CONFIG_KEYS = ["batch_size", "gamma_rate", "base_lr"]
# task_type = "binary_classification"  # binary_classification, multiclass_classification, regression
# sacle = "small"  # small, medium, large
# epochs = 50
# num_trials = 20  # Number of Optuna-based hyper-parameter tuning.
# num_repeats = 5  # Number of repeated training and eval on the best config
# seed = 42
# result_path = "./test"


# def fit_tabular_dl(
#     model_type="TabNet",
#     train_dataset=TrainDataSet,
#     test_dataset=TestDataSet,
#     val_dataset=ValDataSet,
#     device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
#     TRAIN_CONFIG_KEYS=["batch_size", "gamma_rate", "base_lr"],
#     # task_type = "binary_classification",  # binary_classification, multiclass_classification, regression
#     # sacle = "small" , # small, medium, large
#     epochs=50,
#     num_trials=20,  # Number of Optuna-based hyper-parameter tuning.
#     num_repeats=5,  # Number of repeated training and eval on the best config
#     seed=42,
#     result_path="./test",
# ):

#     torch.manual_seed(seed)

#     train_tensor_frame = train_dataset.tensor_frame
#     val_tensor_frame = val_dataset.tensor_frame
#     test_tensor_frame = test_dataset.tensor_frame

#     if train_dataset.task_type == TaskType.BINARY_CLASSIFICATION:
#         out_channels = 1
#         loss_fun = BCEWithLogitsLoss()
#         metric_computer = AUROC(task="binary").to(device)
#         higher_is_better = True
#     elif train_dataset.task_type == TaskType.MULTICLASS_CLASSIFICATION:
#         out_channels = train_dataset.num_classes
#         loss_fun = CrossEntropyLoss()
#         metric_computer = Accuracy(
#             task="multiclass", num_classes=train_dataset.num_classes
#         ).to(device)
#         higher_is_better = True
#     elif train_dataset.task_type == TaskType.REGRESSION:
#         out_channels = 1
#         loss_fun = MSELoss()
#         metric_computer = MeanSquaredError(squared=False).to(device)
#         higher_is_better = False

#     # To be set for each model
#     model_cls = None
#     col_stats = None

#     # Set up model specific search space
#     if model_type == "TabNet":
#         model_search_space = {
#             "split_attn_channels": [64, 128, 256],
#             "split_feat_channels": [64, 128, 256],
#             "gamma": [1.0, 1.2, 1.5],
#             "num_layers": [4, 6, 8],
#         }
#         train_search_space = {
#             "batch_size": [2048, 4096],
#             "base_lr": [0.001, 0.01],
#             "gamma_rate": [0.9, 0.95, 1.0],
#         }
#         model_cls = TabNet
#         col_stats = train_dataset.col_stats
#     elif model_type == "FTTransformer":
#         model_search_space = {
#             "channels": [64, 128, 256],
#             "num_layers": [4, 6, 8],
#         }
#         train_search_space = {
#             "batch_size": [256, 512],
#             "base_lr": [0.0001, 0.001],
#             "gamma_rate": [0.9, 0.95, 1.0],
#         }
#         model_cls = FTTransformer
#         col_stats = train_dataset.col_stats
#     elif model_type == "FTTransformerBucket":
#         model_search_space = {
#             "channels": [64, 128, 256],
#             "num_layers": [4, 6, 8],
#         }
#         train_search_space = {
#             "batch_size": [256, 512],
#             "base_lr": [0.0001, 0.001],
#             "gamma_rate": [0.9, 0.95, 1.0],
#         }
#         model_cls = FTTransformer

#         col_stats = train_dataset.col_stats
#     elif model_type == "ResNet":
#         model_search_space = {
#             "channels": [64, 128, 256],
#             "num_layers": [4, 6, 8],
#         }
#         train_search_space = {
#             "batch_size": [256, 512],
#             "base_lr": [0.0001, 0.001],
#             "gamma_rate": [0.9, 0.95, 1.0],
#         }
#         model_cls = ResNet
#         col_stats = train_dataset.col_stats
#     elif model_type == "MLP":
#         model_search_space = {
#             "channels": [64, 128, 256],
#             "num_layers": [1, 2, 4],
#         }
#         train_search_space = {
#             "batch_size": [256, 512],
#             "base_lr": [0.0001, 0.001],
#             "gamma_rate": [0.9, 0.95, 1.0],
#         }
#         model_cls = MLP
#         col_stats = train_dataset.col_stats
#     elif model_type == "TabTransformer":
#         model_search_space = {
#             "channels": [16, 32, 64, 128],
#             "num_layers": [4, 6, 8],
#             "num_heads": [4, 8],
#             "encoder_pad_size": [2, 4],
#             "attn_dropout": [0, 0.2],
#             "ffn_dropout": [0, 0.2],
#         }
#         train_search_space = {
#             "batch_size": [128, 256],
#             "base_lr": [0.0001, 0.001],
#             "gamma_rate": [0.9, 0.95, 1.0],
#         }
#         model_cls = TabTransformer
#         col_stats = train_dataset.col_stats
#     elif model_type == "Trompt":
#         model_search_space = {
#             "channels": [64, 128, 192],
#             "num_layers": [4, 6, 8],
#             "num_prompts": [64, 128, 192],
#         }
#         train_search_space = {
#             "batch_size": [128, 256],
#             "base_lr": [0.01, 0.001],
#             "gamma_rate": [0.9, 0.95, 1.0],
#         }
#         if train_tensor_frame.num_cols > 20:
#             # Reducing the model size to avoid GPU OOM
#             model_search_space["channels"] = [64, 128]
#             model_search_space["num_prompts"] = [64, 128]
#         elif train_tensor_frame.num_cols > 50:
#             model_search_space["channels"] = [64]
#             model_search_space["num_prompts"] = [64]
#         model_cls = Trompt
#         col_stats = train_dataset.col_stats
#     elif model_type == "ExcelFormer":
#         from torch_frame.transforms import (
#             CatToNumTransform,
#             MutualInformationSort,
#         )

#         categorical_transform = CatToNumTransform()
#         categorical_transform.fit(train_dataset.tensor_frame, train_dataset.col_stats)
#         train_tensor_frame = categorical_transform(train_tensor_frame)
#         # val_tensor_frame = categorical_transform(val_tensor_frame)
#         # test_tensor_frame = categorical_transform(test_tensor_frame)
#         col_stats = categorical_transform.transformed_stats

#         mutual_info_sort = MutualInformationSort(task_type=train_dataset.task_type)
#         mutual_info_sort.fit(train_tensor_frame, col_stats)
#         train_tensor_frame = mutual_info_sort(train_tensor_frame)
#         # val_tensor_frame = mutual_info_sort(val_tensor_frame)
#         # test_tensor_frame = mutual_info_sort(test_tensor_frame)

#         model_search_space = {
#             "in_channels": [128, 256],
#             "num_heads": [8, 16, 32],
#             "num_layers": [4, 6, 8],
#             "diam_dropout": [0, 0.2],
#             "residual_dropout": [0, 0.2],
#             "aium_dropout": [0, 0.2],
#             "mixup": [None, "feature", "hidden"],
#             "beta": [0.5],
#             "num_cols": [train_tensor_frame.num_cols],
#         }
#         train_search_space = {
#             "batch_size": [256, 512],
#             "base_lr": [0.001],
#             "gamma_rate": [0.9, 0.95, 1.0],
#         }
#         model_cls = ExcelFormer

#     assert model_cls is not None
#     assert col_stats is not None
#     assert set(train_search_space.keys()) == set(TRAIN_CONFIG_KEYS)
#     col_names_dict = train_tensor_frame.col_names_dict

#     def train(
#         model: Module,
#         loader: DataLoader,
#         optimizer: torch.optim.Optimizer,
#         epoch: int,
#     ) -> float:
#         model.train()
#         loss_accum = total_count = 0

#         for tf in tqdm(loader, desc=f"Epoch: {epoch}"):
#             tf = tf.to(device)
#             y = tf.y
#             if isinstance(model, ExcelFormer):
#                 # Train with FEAT-MIX or HIDDEN-MIX
#                 pred, y = model(tf, mixup_encoded=True)
#             elif isinstance(model, Trompt):
#                 # Trompt uses the layer-wise loss
#                 pred = model(tf)
#                 num_layers = pred.size(1)
#                 # [batch_size * num_layers, num_classes]
#                 pred = pred.view(-1, out_channels)
#                 y = tf.y.repeat_interleave(num_layers)
#             else:
#                 pred = model(tf)

#             if pred.size(1) == 1:
#                 pred = pred.view(
#                     -1,
#                 )
#             if train_dataset.task_type == TaskType.BINARY_CLASSIFICATION:
#                 y = y.to(torch.float)
#             loss = loss_fun(pred, y)
#             optimizer.zero_grad()
#             loss.backward()
#             loss_accum += float(loss) * len(tf.y)
#             print(tf.y)
#             total_count += len(tf.y)
#             optimizer.step()
#         return loss_accum / total_count

#     @torch.no_grad()
#     def test(
#         model: Module,
#         loader: DataLoader,
#     ) -> float:
#         model.eval()
#         metric_computer.reset()
#         for tf in loader:
#             tf = tf.to(device)
#             pred = model(tf)
#             if isinstance(model, Trompt):
#                 pred = pred.mean(dim=1)
#             if train_dataset.task_type == TaskType.MULTICLASS_CLASSIFICATION:
#                 pred = pred.argmax(dim=-1)
#             elif train_dataset.task_type == TaskType.REGRESSION:
#                 pred = pred.view(
#                     -1,
#                 )
#             metric_computer.update(pred, tf.y)
#         return metric_computer.compute().item()

#     def train_and_eval_with_cfg(
#         model_cfg: dict[str, Any],
#         train_cfg: dict[str, Any],
#         trial: Optional[optuna.trial.Trial] = None,
#     ) -> tuple[float, float]:
#         # Use model_cfg to set up training procedure
#         if model_type == "FTTransformerBucket":
#             # Use LinearBucketEncoder instead
#             stype_encoder_dict = {
#                 stype.categorical: EmbeddingEncoder(),
#                 stype.numerical: LinearBucketEncoder(),
#             }
#             model_cfg["stype_encoder_dict"] = stype_encoder_dict
#         model = model_cls(
#             **model_cfg,
#             out_channels=out_channels,
#             col_stats=col_stats,
#             col_names_dict=col_names_dict,
#         ).to(device)
#         model.reset_parameters()
#         # Use train_cfg to set up training procedure
#         optimizer = torch.optim.Adam(model.parameters(), lr=train_cfg["base_lr"])
#         lr_scheduler = ExponentialLR(optimizer, gamma=train_cfg["gamma_rate"])
#         train_loader = DataLoader(
#             train_tensor_frame,
#             batch_size=train_cfg["batch_size"],
#             shuffle=True,
#             drop_last=True,
#         )
#         val_loader = DataLoader(val_tensor_frame, batch_size=train_cfg["batch_size"])
#         test_loader = DataLoader(test_tensor_frame, batch_size=train_cfg["batch_size"])

#         if higher_is_better:
#             best_val_metric = 0
#         else:
#             best_val_metric = math.inf

#         for epoch in range(1, epochs + 1):
#             train_loss = train(model, train_loader, optimizer, epoch)
#             val_metric = test(model, val_loader)

#             if higher_is_better:
#                 if val_metric > best_val_metric:
#                     best_val_metric = val_metric
#                     best_test_metric = test(model, test_loader)
#             else:
#                 if val_metric < best_val_metric:
#                     best_val_metric = val_metric
#                     best_test_metric = test(model, test_loader)
#             lr_scheduler.step()
#             print(f"Train Loss: {train_loss:.4f}, Val: {val_metric:.4f}")

#             if trial is not None:
#                 trial.report(val_metric, epoch)
#                 if trial.should_prune():
#                     raise optuna.TrialPruned()

#         print(f"Best val: {best_val_metric:.4f}, Best test: {best_test_metric:.4f}")
#         return best_val_metric, best_test_metric

#     def objective(trial: optuna.trial.Trial) -> float:
#         model_cfg = {}
#         for name, search_list in model_search_space.items():
#             model_cfg[name] = trial.suggest_categorical(name, search_list)
#         train_cfg = {}
#         for name, search_list in train_search_space.items():
#             train_cfg[name] = trial.suggest_categorical(name, search_list)

#         best_val_metric, _ = train_and_eval_with_cfg(
#             model_cfg=model_cfg, train_cfg=train_cfg, trial=trial
#         )
#         return best_val_metric


#     # Hyper-parameter optimization with Optuna
#     print("Hyper-parameter search via Optuna")
#     start_time = time.time()
#     study = optuna.create_study(
#         pruner=optuna.pruners.MedianPruner(),
#         direction="maximize" if higher_is_better else "minimize",
#     )
#     study.optimize(objective, n_trials=num_trials)
#     end_time = time.time()
#     search_time = end_time - start_time
#     print("Hyper-parameter search done. Found the best config.")
#     params = study.best_params
#     best_train_cfg = {}
#     for train_cfg_key in TRAIN_CONFIG_KEYS:
#         best_train_cfg[train_cfg_key] = params.pop(train_cfg_key)
#     best_model_cfg = params

#     print(
#         f"Repeat experiments {num_repeats} times with the best train "
#         f"config {best_train_cfg} and model config {best_model_cfg}."
#     )
#     start_time = time.time()
#     best_val_metrics = []
#     best_test_metrics = []
#     for _ in range(num_repeats):
#         best_val_metric, best_test_metric = train_and_eval_with_cfg(
#             best_model_cfg, best_train_cfg
#         )
#         best_val_metrics.append(best_val_metric)
#         best_test_metrics.append(best_test_metric)
#     end_time = time.time()
#     final_model_time = (end_time - start_time) / num_repeats
#     best_val_metrics = np.array(best_val_metrics)
#     best_test_metrics = np.array(best_test_metrics)

#     result_dict = {
#         # 'args': __dict__,
#         "best_val_metrics": best_val_metrics,
#         "best_test_metrics": best_test_metrics,
#         "best_val_metric": best_val_metrics.mean(),
#         "best_test_metric": best_test_metrics.mean(),
#         "best_train_cfg": best_train_cfg,
#         "best_model_cfg": best_model_cfg,
#         "search_time": search_time,
#         "final_model_time": final_model_time,
#         "total_time": search_time + final_model_time,
#     }
#     print(result_dict)
#     # Save results
#     if result_path != "":
#         os.makedirs(os.path.dirname(result_path), exist_ok=True)
#         torch.save(result_dict, result_path)
#     return result_dict

In [290]:
# col_to_stype = {
#     **{k: stype.numerical for k in train_df.columns if k not in ["eid", "target"]},
#     **{"target": stype.categorical},
# }

col_to_stype = {
    **{k: stype.numerical for k in qt_cols},
    **{k: stype.categorical for k in cat_cols},
}

TrainDataSet = Dataset(
    df=train_df,
    col_to_stype=col_to_stype,
    target_col="incident",
)
TrainDataSet.materialize()

ValDataSet = Dataset(
    df=val_df,
    col_to_stype=col_to_stype,
    target_col="incident",
)
ValDataSet.materialize()

TestDataSet = Dataset(
    df=test_df,
    col_to_stype=col_to_stype,
    target_col="incident",
)
TestDataSet.materialize()
from torch_frame.data import DataLoader

train_loader = DataLoader(TrainDataSet.tensor_frame, batch_size=128, shuffle=True)
val_loader = DataLoader(ValDataSet.tensor_frame, batch_size=128, shuffle=False)
test_loader = DataLoader(TestDataSet.tensor_frame, batch_size=128, shuffle=False)

In [291]:
train_dataset.task_type

<TaskType.BINARY_CLASSIFICATION: 'binary_classification'>

In [None]:
import argparse
import math
import os
import os.path as osp
import time
from typing import Any, Optional

import numpy as np
import optuna
import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, Module, MSELoss
from torch.optim.lr_scheduler import ExponentialLR
from torchmetrics import AUROC, Accuracy, MeanSquaredError
from tqdm import tqdm

from torch_frame import stype
from torch_frame.data import DataLoader
from torch_frame.datasets import DataFrameBenchmark
from torch_frame.gbdt import CatBoost, LightGBM, XGBoost
from torch_frame.nn.encoder import EmbeddingEncoder, LinearBucketEncoder
from torch_frame.nn.models import (
    MLP,
    ExcelFormer,
    FTTransformer,
    ResNet,
    TabNet,
    TabTransformer,
    Trompt,
)
from torch_frame.typing import TaskType

# class Args:
#     def __init__(self, **kwargs):
#         # 使用字典存储键值对
#         self.__dict__.update(kwargs)

# # 创建 Args 对象

# args = Args(
#     model_type = "TabNet", # TabNet, TabTransformer, ExcelFormer, MLP, ResNet, Trompt, LightGBM, CatBoost, XGBoost
#     task_type = "binary_classification", # binary_classification, multiclass_classification, regression
#     scale = "small",
#     idx = 0,
# )

model_type = "MLP"
train_dataset = TrainDataSet
test_dataset = TestDataSet
val_dataset = ValDataSet
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
TRAIN_CONFIG_KEYS = ["batch_size", "gamma_rate", "base_lr"]
task_type = "binary_classification"  # binary_classification, multiclass_classification, regression
sacle = "small"  # small, medium, large
epochs = 10
num_trials = 3  # Number of Optuna-based hyper-parameter tuning.
num_repeats = 5  # Number of repeated training and eval on the best config
seed = 42


torch.manual_seed(seed)


train_tensor_frame = train_dataset.tensor_frame
val_tensor_frame = val_dataset.tensor_frame
test_tensor_frame = test_dataset.tensor_frame

if train_dataset.task_type == TaskType.BINARY_CLASSIFICATION:
    out_channels = 1
    loss_fun = BCEWithLogitsLoss()
    metric_computer = AUROC(task="binary").to(device)
    higher_is_better = True
elif train_dataset.task_type == TaskType.MULTICLASS_CLASSIFICATION:
    out_channels = train_dataset.num_classes
    loss_fun = CrossEntropyLoss()
    metric_computer = Accuracy(
        task="multiclass", num_classes=train_dataset.num_classes
    ).to(device)
    higher_is_better = True
elif train_dataset.task_type == TaskType.REGRESSION:
    out_channels = 1
    loss_fun = MSELoss()
    metric_computer = MeanSquaredError(squared=False).to(device)
    higher_is_better = False

# To be set for each model
model_cls = None
col_stats = None

# Set up model specific search space
if model_type == "TabNet":
    model_search_space = {
        "split_attn_channels": [64, 128, 256],
        "split_feat_channels": [64, 128, 256],
        "gamma": [1.0, 1.2, 1.5],
        "num_layers": [4, 6, 8],
    }
    train_search_space = {
        "batch_size": [
            2048,
            4096,
        ],  # Note if you have a small data, you may want to reduce it, also low gpu memory
        # "batch_size": [128, 256],
        "base_lr": [0.001, 0.01],
        "gamma_rate": [0.9, 0.95, 1.0],
    }
    model_cls = TabNet
    col_stats = train_dataset.col_stats
elif model_type == "FTTransformer":
    model_search_space = {
        "channels": [64, 128, 256],
        "num_layers": [4, 6, 8],
    }
    train_search_space = {
        "batch_size": [256, 512],
        "base_lr": [0.0001, 0.001],
        "gamma_rate": [0.9, 0.95, 1.0],
    }
    model_cls = FTTransformer
    col_stats = train_dataset.col_stats
elif model_type == "FTTransformerBucket":
    model_search_space = {
        "channels": [64, 128, 256],
        "num_layers": [4, 6, 8],
    }
    train_search_space = {
        "batch_size": [256, 512],
        "base_lr": [0.0001, 0.001],
        "gamma_rate": [0.9, 0.95, 1.0],
    }
    model_cls = FTTransformer

    col_stats = train_dataset.col_stats
elif model_type == "ResNet":
    model_search_space = {
        "channels": [64, 128, 256],
        "num_layers": [4, 6, 8],
    }
    train_search_space = {
        "batch_size": [256, 512],
        "base_lr": [0.0001, 0.001],
        "gamma_rate": [0.9, 0.95, 1.0],
    }
    model_cls = ResNet
    col_stats = train_dataset.col_stats
elif model_type == "MLP":
    model_search_space = {
        "channels": [64, 128, 256],
        "num_layers": [1, 2, 4],
    }
    train_search_space = {
        "batch_size": [256, 512],
        "base_lr": [0.0001, 0.001],
        "gamma_rate": [0.9, 0.95, 1.0],
    }
    model_cls = MLP
    col_stats = train_dataset.col_stats
elif model_type == "TabTransformer":
    model_search_space = {
        "channels": [16, 32, 64, 128],
        "num_layers": [4, 6, 8],
        "num_heads": [4, 8],
        "encoder_pad_size": [2, 4],
        "attn_dropout": [0, 0.2],
        "ffn_dropout": [0, 0.2],
    }
    train_search_space = {
        "batch_size": [128, 256],
        "base_lr": [0.0001, 0.001],
        "gamma_rate": [0.9, 0.95, 1.0],
    }
    model_cls = TabTransformer
    col_stats = train_dataset.col_stats
elif model_type == "Trompt":
    model_search_space = {
        "channels": [64, 128, 192],
        "num_layers": [4, 6, 8],
        "num_prompts": [64, 128, 192],
    }
    train_search_space = {
        "batch_size": [128, 256],
        "base_lr": [0.01, 0.001],
        "gamma_rate": [0.9, 0.95, 1.0],
    }
    if train_tensor_frame.num_cols > 20:
        # Reducing the model size to avoid GPU OOM
        model_search_space["channels"] = [64, 128]
        model_search_space["num_prompts"] = [64, 128]
    elif train_tensor_frame.num_cols > 50:
        model_search_space["channels"] = [64]
        model_search_space["num_prompts"] = [64]
    model_cls = Trompt
    col_stats = train_dataset.col_stats
elif model_type == "ExcelFormer":
    from torch_frame.transforms import (
        CatToNumTransform,
        MutualInformationSort,
    )

    categorical_transform = CatToNumTransform()
    categorical_transform.fit(train_dataset.tensor_frame, train_dataset.col_stats)
    train_tensor_frame = categorical_transform(train_tensor_frame)
    # val_tensor_frame = categorical_transform(val_tensor_frame)
    # test_tensor_frame = categorical_transform(test_tensor_frame)
    col_stats = categorical_transform.transformed_stats

    mutual_info_sort = MutualInformationSort(task_type=train_dataset.task_type)
    mutual_info_sort.fit(train_tensor_frame, col_stats)
    train_tensor_frame = mutual_info_sort(train_tensor_frame)
    # val_tensor_frame = mutual_info_sort(val_tensor_frame)
    # test_tensor_frame = mutual_info_sort(test_tensor_frame)

    model_search_space = {
        "in_channels": [128, 256],
        "num_heads": [8, 16, 32],
        "num_layers": [4, 6, 8],
        "diam_dropout": [0, 0.2],
        "residual_dropout": [0, 0.2],
        "aium_dropout": [0, 0.2],
        "mixup": [None, "feature", "hidden"],
        "beta": [0.5],
        "num_cols": [train_tensor_frame.num_cols],
    }
    train_search_space = {
        "batch_size": [256, 512],
        "base_lr": [0.001],
        "gamma_rate": [0.9, 0.95, 1.0],
    }
    model_cls = ExcelFormer

assert model_cls is not None
assert col_stats is not None
assert set(train_search_space.keys()) == set(TRAIN_CONFIG_KEYS)
col_names_dict = train_tensor_frame.col_names_dict


def train(
    model: Module,
    loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    epoch: int,
) -> float:
    model.train()
    loss_accum = total_count = 0

    for tf in tqdm(loader, desc=f"Epoch: {epoch}"):
        tf = tf.to(device)
        y = tf.y
        if isinstance(model, ExcelFormer):
            # Train with FEAT-MIX or HIDDEN-MIX
            pred, y = model(tf, mixup_encoded=True)
        elif isinstance(model, Trompt):
            # Trompt uses the layer-wise loss
            pred = model(tf)
            num_layers = pred.size(1)
            # [batch_size * num_layers, num_classes]
            pred = pred.view(-1, out_channels)
            y = tf.y.repeat_interleave(num_layers)
        else:
            pred = model(tf)

        if pred.size(1) == 1:
            pred = pred.view(
                -1,
            )
        if train_dataset.task_type == TaskType.BINARY_CLASSIFICATION:
            y = y.to(torch.float)
        loss = loss_fun(pred, y)
        optimizer.zero_grad()
        loss.backward()
        loss_accum += float(loss) * len(tf.y)
        total_count += len(tf.y)
        optimizer.step()
    return loss_accum / total_count


@torch.no_grad()
def test(
    model: Module,
    loader: DataLoader,
) -> float:
    model.eval()
    metric_computer.reset()
    for tf in loader:
        tf = tf.to(device)
        pred = model(tf)
        if isinstance(model, Trompt):
            pred = pred.mean(dim=1)
        if train_dataset.task_type == TaskType.MULTICLASS_CLASSIFICATION:
            pred = pred.argmax(dim=-1)
        elif train_dataset.task_type == TaskType.REGRESSION:
            pred = pred.view(
                -1,
            )
        metric_computer.update(pred, tf.y)
    return metric_computer.compute().item()


def train_and_eval_with_cfg(
    model_cfg: dict[str, Any],
    train_cfg: dict[str, Any],
    trial: Optional[optuna.trial.Trial] = None,
) -> tuple[float, float]:
    # Use model_cfg to set up training procedure
    if model_type == "FTTransformerBucket":
        # Use LinearBucketEncoder instead
        stype_encoder_dict = {
            stype.categorical: EmbeddingEncoder(),
            stype.numerical: LinearBucketEncoder(),
        }
        model_cfg["stype_encoder_dict"] = stype_encoder_dict
    model = model_cls(
        **model_cfg,
        out_channels=out_channels,
        col_stats=col_stats,
        col_names_dict=col_names_dict,
    ).to(device)
    model.reset_parameters()
    # Use train_cfg to set up training procedure
    optimizer = torch.optim.Adam(model.parameters(), lr=train_cfg["base_lr"])
    lr_scheduler = ExponentialLR(optimizer, gamma=train_cfg["gamma_rate"])
    train_loader = DataLoader(
        train_tensor_frame,
        batch_size=train_cfg["batch_size"],
        shuffle=True,
        drop_last=True,
    )
    val_loader = DataLoader(val_tensor_frame, batch_size=train_cfg["batch_size"])
    test_loader = DataLoader(test_tensor_frame, batch_size=train_cfg["batch_size"])

    if higher_is_better:
        best_val_metric = 0
    else:
        best_val_metric = math.inf

    for epoch in range(1, epochs + 1):
        train_loss = train(model, train_loader, optimizer, epoch)
        val_metric = test(model, val_loader)

        if higher_is_better:
            if val_metric > best_val_metric:
                best_val_metric = val_metric
                best_test_metric = test(model, test_loader)
        else:
            if val_metric < best_val_metric:
                best_val_metric = val_metric
                best_test_metric = test(model, test_loader)
        lr_scheduler.step()
        print(f"Train Loss: {train_loss:.4f}, Val: {val_metric:.4f}")

        if trial is not None:
            trial.report(val_metric, epoch)
            if trial.should_prune():
                raise optuna.TrialPruned()

    print(f"Best val: {best_val_metric:.4f}, Best test: {best_test_metric:.4f}")
    return best_val_metric, best_test_metric


def objective(trial: optuna.trial.Trial) -> float:
    model_cfg = {}
    for name, search_list in model_search_space.items():
        model_cfg[name] = trial.suggest_categorical(name, search_list)
    train_cfg = {}
    for name, search_list in train_search_space.items():
        train_cfg[name] = trial.suggest_categorical(name, search_list)

    best_val_metric, _ = train_and_eval_with_cfg(
        model_cfg=model_cfg, train_cfg=train_cfg, trial=trial
    )
    return best_val_metric


def main_deep_models():
    # Hyper-parameter optimization with Optuna
    print("Hyper-parameter search via Optuna")
    start_time = time.time()
    study = optuna.create_study(
        pruner=optuna.pruners.MedianPruner(),
        direction="maximize" if higher_is_better else "minimize",
    )
    study.optimize(objective, n_trials=num_trials)
    end_time = time.time()
    search_time = end_time - start_time
    print("Hyper-parameter search done. Found the best config.")
    params = study.best_params
    best_train_cfg = {}
    for train_cfg_key in TRAIN_CONFIG_KEYS:
        best_train_cfg[train_cfg_key] = params.pop(train_cfg_key)
    best_model_cfg = params

    print(
        f"Repeat experiments {num_repeats} times with the best train "
        f"config {best_train_cfg} and model config {best_model_cfg}."
    )

    # retrain model
    if model_type == "FTTransformerBucket":
        # Use LinearBucketEncoder instead
        stype_encoder_dict = {
            stype.categorical: EmbeddingEncoder(),
            stype.numerical: LinearBucketEncoder(),
        }
        best_model_cfg["stype_encoder_dict"] = stype_encoder_dict

    model = model_cls(
        **best_model_cfg,
        out_channels=out_channels,
        col_stats=col_stats,
        col_names_dict=col_names_dict,
    ).to(device)
    model.reset_parameters()
    # Use train_cfg to set up training procedure
    optimizer = torch.optim.Adam(model.parameters(), lr=best_train_cfg["base_lr"])
    lr_scheduler = ExponentialLR(optimizer, gamma=best_train_cfg["gamma_rate"])
    train_loader = DataLoader(
        train_tensor_frame,
        batch_size=best_train_cfg["batch_size"],
        shuffle=True,
        drop_last=True,
    )
    val_loader = DataLoader(val_tensor_frame, batch_size=best_train_cfg["batch_size"])
    test_loader = DataLoader(test_tensor_frame, batch_size=best_train_cfg["batch_size"])

    if higher_is_better:
        best_val_metric = 0
    else:
        best_val_metric = math.inf

    for epoch in range(1, epochs + 1):
        train_loss = train(model, train_loader, optimizer, epoch)
        val_metric = test(model, val_loader)

        if higher_is_better:
            if val_metric > best_val_metric:
                best_val_metric = val_metric
                best_test_metric = test(model, test_loader)
        else:
            if val_metric < best_val_metric:
                best_val_metric = val_metric
                best_test_metric = test(model, test_loader)
        lr_scheduler.step()
        print(f"Train Loss: {train_loss:.4f}, Val: {val_metric:.4f}")

    result_dict = {
        # 'args': __dict__,
        "model": model,
        "best_val_metric": best_val_metric,
        "best_test_metric": best_test_metric,
        "best_train_cfg": best_train_cfg,
        "best_model_cfg": best_model_cfg,
        "search_time": search_time,
    }
    return result_dict

    start_time = time.time()
    best_val_metrics = []
    best_test_metrics = []
    for _ in range(num_repeats):
        best_val_metric, best_test_metric = train_and_eval_with_cfg(
            best_model_cfg, best_train_cfg
        )
        best_val_metrics.append(best_val_metric)
        best_test_metrics.append(best_test_metric)
    # end_time = time.time()
    # final_model_time = (end_time - start_time) / num_repeats
    # best_val_metrics = np.array(best_val_metrics)
    # best_test_metrics = np.array(best_test_metrics)

    # result_dict = {
    #     # 'args': __dict__,
    #     "best_val_metrics": best_val_metrics,
    #     "best_test_metrics": best_test_metrics,
    #     "best_val_metric": best_val_metrics.mean(),
    #     "best_test_metric": best_test_metrics.mean(),
    #     "best_train_cfg": best_train_cfg,
    #     "best_model_cfg": best_model_cfg,
    #     "search_time": search_time,
    #     "final_model_time": final_model_time,
    #     "total_time": search_time + final_model_time,
    # }
    # print(result_dict)
    # # Save results
    # if result_path != "":
    #     os.makedirs(os.path.dirname(result_path), exist_ok=True)
    #     torch.save(result_dict, result_path)

In [310]:
# model = model_cls(
#     **model_cfg,
#     out_channels=out_channels,
#     col_stats=col_stats,
#     col_names_dict=col_names_dict,
# )
# model.to(device)
# train(
#     model,
#     train_loader,
#     optimizer,
#     epoch,
# )

In [None]:
res = main_deep_models()

[I 2025-03-24 13:54:23,596] A new study created in memory with name: no-name-20fbe0fa-3088-409e-a763-1f3f3d3c0cd7


Hyper-parameter search via Optuna


Epoch: 1: 100%|██████████| 198/198 [00:02<00:00, 95.89it/s] 


Train Loss: 0.0695, Val: 0.7005


Epoch: 2: 100%|██████████| 198/198 [00:01<00:00, 147.08it/s]


Train Loss: 0.0370, Val: 0.7563


Epoch: 3: 100%|██████████| 198/198 [00:01<00:00, 147.87it/s]


Train Loss: 0.0347, Val: 0.8210


Epoch: 4: 100%|██████████| 198/198 [00:01<00:00, 140.42it/s]


Train Loss: 0.0324, Val: 0.8389


Epoch: 5: 100%|██████████| 198/198 [00:01<00:00, 125.17it/s]


Train Loss: 0.0314, Val: 0.8445


Epoch: 6: 100%|██████████| 198/198 [00:00<00:00, 202.78it/s]


Train Loss: 0.0313, Val: 0.8489


Epoch: 7: 100%|██████████| 198/198 [00:00<00:00, 240.04it/s]


Train Loss: 0.0307, Val: 0.8492


Epoch: 8: 100%|██████████| 198/198 [00:01<00:00, 147.14it/s]


Train Loss: 0.0307, Val: 0.8511


Epoch: 9: 100%|██████████| 198/198 [00:00<00:00, 259.53it/s]


Train Loss: 0.0303, Val: 0.8526


Epoch: 10: 100%|██████████| 198/198 [00:01<00:00, 175.67it/s]
[I 2025-03-24 13:54:48,946] Trial 0 finished with value: 0.8525776267051697 and parameters: {'channels': 64, 'num_layers': 4, 'batch_size': 512, 'base_lr': 0.001, 'gamma_rate': 0.9}. Best is trial 0 with value: 0.8525776267051697.


Train Loss: 0.0302, Val: 0.8525
Best val: 0.8526, Best test: 0.8585


Epoch: 1: 100%|██████████| 397/397 [00:01<00:00, 372.97it/s]


Train Loss: 0.0423, Val: 0.7774


Epoch: 2: 100%|██████████| 397/397 [00:01<00:00, 394.29it/s]


Train Loss: 0.0325, Val: 0.7924


Epoch: 3: 100%|██████████| 397/397 [00:01<00:00, 353.14it/s]


Train Loss: 0.0319, Val: 0.7951


Epoch: 4: 100%|██████████| 397/397 [00:01<00:00, 355.56it/s]


Train Loss: 0.0316, Val: 0.8024


Epoch: 5: 100%|██████████| 397/397 [00:01<00:00, 349.43it/s]


Train Loss: 0.0311, Val: 0.8102


Epoch: 6: 100%|██████████| 397/397 [00:01<00:00, 354.04it/s]


Train Loss: 0.0311, Val: 0.8117


Epoch: 7: 100%|██████████| 397/397 [00:01<00:00, 359.70it/s]


Train Loss: 0.0305, Val: 0.8194


Epoch: 8: 100%|██████████| 397/397 [00:01<00:00, 367.75it/s]


Train Loss: 0.0304, Val: 0.8220


Epoch: 9: 100%|██████████| 397/397 [00:01<00:00, 370.79it/s]


Train Loss: 0.0302, Val: 0.8248


Epoch: 10: 100%|██████████| 397/397 [00:01<00:00, 356.14it/s]
[I 2025-03-24 13:55:11,225] Trial 1 finished with value: 0.8301591873168945 and parameters: {'channels': 256, 'num_layers': 4, 'batch_size': 256, 'base_lr': 0.0001, 'gamma_rate': 0.9}. Best is trial 0 with value: 0.8525776267051697.


Train Loss: 0.0298, Val: 0.8302
Best val: 0.8302, Best test: 0.8372


Epoch: 1: 100%|██████████| 198/198 [00:00<00:00, 411.10it/s]


Train Loss: 0.3212, Val: 0.6582


Epoch: 2: 100%|██████████| 198/198 [00:01<00:00, 182.81it/s]


Train Loss: 0.0498, Val: 0.7611


Epoch: 3: 100%|██████████| 198/198 [00:01<00:00, 169.00it/s]


Train Loss: 0.0353, Val: 0.8012


Epoch: 4: 100%|██████████| 198/198 [00:01<00:00, 172.02it/s]


Train Loss: 0.0331, Val: 0.8143


Epoch: 5: 100%|██████████| 198/198 [00:01<00:00, 162.84it/s]


Train Loss: 0.0323, Val: 0.8223


Epoch: 6: 100%|██████████| 198/198 [00:01<00:00, 169.11it/s]


In [306]:
# get predict
model = res["model"]
pred_array = []
for tf in test_loader:
    tf = tf.to(device)
    pred = model(tf).cpu().detach().flatten().tolist()

    # proba = torch.softmax(pred, dim=1)[:, 1].cpu().detach().numpy().tolist()
    pred_array.extend(pred)
    # break

In [308]:
TestDataSet.df["pred"] = pred_array
from ppp_prediction.metrics import cal_binary_metrics

cal_binary_metrics(TestDataSet.df["incident"], TestDataSet.df["pred"])

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy


{'AUC': 0.8304221064354074,
 'ACC': 0.7796975201607355,
 'Macro_F1': 0.45684945991264686,
 'Sensitivity': 0.7434094903339191,
 'Specificity': 0.7799153507269247,
 'APR': 0.046981357128616584}

In [172]:
import torch
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ExampleTransformer(
    channels=32,
    out_channels=TrainDataSet.num_classes,
    num_layers=2,
    num_heads=8,
    col_stats=TrainDataSet.col_stats,
    col_names_dict=TrainDataSet.tensor_frame.col_names_dict,
).to(device)

optimizer = torch.optim.Adam(model.parameters())

for epoch in range(10):
    print(f"Epoch {epoch}")
    for tf in train_loader:
        tf = tf.to(device)
        pred = model.forward(tf)
        loss = F.cross_entropy(pred, tf.y)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    print(loss.item())

Epoch 0
0.1027328372001648
Epoch 1
0.0993896871805191
Epoch 2
0.06829185783863068
Epoch 3
0.1154552698135376
Epoch 4
0.08367930352687836
Epoch 5
0.02228795550763607
Epoch 6
0.052397169172763824
Epoch 7
0.18704631924629211
Epoch 8
0.011632421053946018
Epoch 9
0.0033353206235915422


In [58]:
# get predict
pred_array = []
for tf in test_loader:
    tf = tf.to(device)
    pred = model(tf)

    proba = torch.softmax(pred, dim=1)[:, 1].cpu().detach().numpy().tolist()
    pred_array.extend(proba)
    break

In [61]:
TestDataSet.df["pred"] = pred_array

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy


In [None]:
TestDataSet.df["pred"] = pred_array
from ppp_prediction.metrics import cal_binary_metrics

cal_binary_metrics(TestDataSet.df["target"], TestDataSet.df["pred"])

{'AUC': 0.99737962659679,
 'ACC': 0.9736842105263158,
 'Macro_F1': 0.9721203228173148,
 'Sensitivity': 0.9859154929577465,
 'Specificity': 0.9767441860465116,
 'APR': 0.9983678405804648}

In [None]:
from torch_frame.datasets import Yandex
from torch_frame.data import DataLoader

train_dataset = Yandex(root="/tmp/adult", name="adult")
train_dataset.materialize()
train_dataset = train_dataset[:0.8]
train_loader = DataLoader(train_dataset.tensor_frame, batch_size=128, shuffle=True)

In [43]:
train_dataset.target_col

'target_col'

In [None]:
import argparse
import math
import os
import os.path as osp
import time
from typing import Any, Optional

import numpy as np
import optuna
import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, Module, MSELoss
from torch.optim.lr_scheduler import ExponentialLR
from torchmetrics import AUROC, Accuracy, MeanSquaredError
from tqdm import tqdm

from torch_frame import stype
from torch_frame.data import DataLoader
from torch_frame.datasets import DataFrameBenchmark
from torch_frame.gbdt import CatBoost, LightGBM, XGBoost
from torch_frame.nn.encoder import EmbeddingEncoder, LinearBucketEncoder
from torch_frame.nn.models import (
    MLP,
    ExcelFormer,
    FTTransformer,
    ResNet,
    TabNet,
    TabTransformer,
    Trompt,
)
from torch_frame.typing import TaskType

TRAIN_CONFIG_KEYS = ["batch_size", "gamma_rate", "base_lr"]
GBDT_MODELS = ["XGBoost", "CatBoost", "LightGBM"]

parser = argparse.ArgumentParser()
parser.add_argument(
    '--task_type', type=str, choices=[
        'binary_classification',
        'multiclass_classification',
        'regression',
    ], default='binary_classification')
parser.add_argument('--scale', type=str, choices=['small', 'medium', 'large'],
                    default='small')
parser.add_argument('--idx', type=int, default=0,
                    help='The index of the dataset within DataFrameBenchmark')
parser.add_argument('--epochs', type=int, default=50)
parser.add_argument('--num_trials', type=int, default=20,
                    help='Number of Optuna-based hyper-parameter tuning.')
parser.add_argument(
    '--num_repeats', type=int, default=5,
    help='Number of repeated training and eval on the best config.')
parser.add_argument(
    '--model_type', type=str, default='TabNet', choices=[
        'TabNet', 'FTTransformer', 'ResNet', 'MLP', 'TabTransformer', 'Trompt',
        'ExcelFormer', 'FTTransformerBucket', 'XGBoost', 'CatBoost', 'LightGBM'
    ])
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--result_path', type=str, default='')
args = parser.parse_args()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.manual_seed(args.seed)

# Prepare datasets
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data')
train_dataset = DataFrameBenchmark(root=path, task_type=TaskType(args.task_type),
                             scale=args.scale, idx=args.idx)
train_dataset.materialize()
train_dataset = train_dataset.shuffle()
train_dataset, val_dataset, test_dataset = train_dataset.split()

train_tensor_frame = train_dataset.tensor_frame
val_tensor_frame = val_dataset.tensor_frame
test_tensor_frame = test_dataset.tensor_frame

if args.model_type in GBDT_MODELS:
    gbdt_cls_dict = {
        'XGBoost': XGBoost,
        'CatBoost': CatBoost,
        'LightGBM': LightGBM
    }
    model_cls = gbdt_cls_dict[args.model_type]
else:
    if train_dataset.task_type == TaskType.BINARY_CLASSIFICATION:
        out_channels = 1
        loss_fun = BCEWithLogitsLoss()
        metric_computer = AUROC(task='binary').to(device)
        higher_is_better = True
    elif train_dataset.task_type == TaskType.MULTICLASS_CLASSIFICATION:
        out_channels = train_dataset.num_classes
        loss_fun = CrossEntropyLoss()
        metric_computer = Accuracy(task='multiclass',
                                   num_classes=train_dataset.num_classes).to(device)
        higher_is_better = True
    elif train_dataset.task_type == TaskType.REGRESSION:
        out_channels = 1
        loss_fun = MSELoss()
        metric_computer = MeanSquaredError(squared=False).to(device)
        higher_is_better = False

    # To be set for each model
    model_cls = None
    col_stats = None

    # Set up model specific search space
    if args.model_type == 'TabNet':
        model_search_space = {
            'split_attn_channels': [64, 128, 256],
            'split_feat_channels': [64, 128, 256],
            'gamma': [1., 1.2, 1.5],
            'num_layers': [4, 6, 8],
        }
        train_search_space = {
            'batch_size': [2048, 4096],
            'base_lr': [0.001, 0.01],
            'gamma_rate': [0.9, 0.95, 1.],
        }
        model_cls = TabNet
        col_stats = train_dataset.col_stats
    elif args.model_type == 'FTTransformer':
        model_search_space = {
            'channels': [64, 128, 256],
            'num_layers': [4, 6, 8],
        }
        train_search_space = {
            'batch_size': [256, 512],
            'base_lr': [0.0001, 0.001],
            'gamma_rate': [0.9, 0.95, 1.],
        }
        model_cls = FTTransformer
        col_stats = train_dataset.col_stats
    elif args.model_type == 'FTTransformerBucket':
        model_search_space = {
            'channels': [64, 128, 256],
            'num_layers': [4, 6, 8],
        }
        train_search_space = {
            'batch_size': [256, 512],
            'base_lr': [0.0001, 0.001],
            'gamma_rate': [0.9, 0.95, 1.],
        }
        model_cls = FTTransformer

        col_stats = train_dataset.col_stats
    elif args.model_type == 'ResNet':
        model_search_space = {
            'channels': [64, 128, 256],
            'num_layers': [4, 6, 8],
        }
        train_search_space = {
            'batch_size': [256, 512],
            'base_lr': [0.0001, 0.001],
            'gamma_rate': [0.9, 0.95, 1.],
        }
        model_cls = ResNet
        col_stats = train_dataset.col_stats
    elif args.model_type == 'MLP':
        model_search_space = {
            'channels': [64, 128, 256],
            'num_layers': [1, 2, 4],
        }
        train_search_space = {
            'batch_size': [256, 512],
            'base_lr': [0.0001, 0.001],
            'gamma_rate': [0.9, 0.95, 1.],
        }
        model_cls = MLP
        col_stats = train_dataset.col_stats
    elif args.model_type == 'TabTransformer':
        model_search_space = {
            'channels': [16, 32, 64, 128],
            'num_layers': [4, 6, 8],
            'num_heads': [4, 8],
            'encoder_pad_size': [2, 4],
            'attn_dropout': [0, 0.2],
            'ffn_dropout': [0, 0.2],
        }
        train_search_space = {
            'batch_size': [128, 256],
            'base_lr': [0.0001, 0.001],
            'gamma_rate': [0.9, 0.95, 1.],
        }
        model_cls = TabTransformer
        col_stats = train_dataset.col_stats
    elif args.model_type == 'Trompt':
        model_search_space = {
            'channels': [64, 128, 192],
            'num_layers': [4, 6, 8],
            'num_prompts': [64, 128, 192],
        }
        train_search_space = {
            'batch_size': [128, 256],
            'base_lr': [0.01, 0.001],
            'gamma_rate': [0.9, 0.95, 1.],
        }
        if train_tensor_frame.num_cols > 20:
            # Reducing the model size to avoid GPU OOM
            model_search_space['channels'] = [64, 128]
            model_search_space['num_prompts'] = [64, 128]
        elif train_tensor_frame.num_cols > 50:
            model_search_space['channels'] = [64]
            model_search_space['num_prompts'] = [64]
        model_cls = Trompt
        col_stats = train_dataset.col_stats
    elif args.model_type == 'ExcelFormer':
        from torch_frame.transforms import (
            CatToNumTransform,
            MutualInformationSort,
        )

        categorical_transform = CatToNumTransform()
        categorical_transform.fit(train_dataset.tensor_frame,
                                  train_dataset.col_stats)
        train_tensor_frame = categorical_transform(train_tensor_frame)
        val_tensor_frame = categorical_transform(val_tensor_frame)
        test_tensor_frame = categorical_transform(test_tensor_frame)
        col_stats = categorical_transform.transformed_stats

        mutual_info_sort = MutualInformationSort(task_type=train_dataset.task_type)
        mutual_info_sort.fit(train_tensor_frame, col_stats)
        train_tensor_frame = mutual_info_sort(train_tensor_frame)
        val_tensor_frame = mutual_info_sort(val_tensor_frame)
        test_tensor_frame = mutual_info_sort(test_tensor_frame)

        model_search_space = {
            'in_channels': [128, 256],
            'num_heads': [8, 16, 32],
            'num_layers': [4, 6, 8],
            'diam_dropout': [0, 0.2],
            'residual_dropout': [0, 0.2],
            'aium_dropout': [0, 0.2],
            'mixup': [None, 'feature', 'hidden'],
            'beta': [0.5],
            'num_cols': [train_tensor_frame.num_cols],
        }
        train_search_space = {
            'batch_size': [256, 512],
            'base_lr': [0.001],
            'gamma_rate': [0.9, 0.95, 1.],
        }
        model_cls = ExcelFormer

    assert model_cls is not None
    assert col_stats is not None
    assert set(train_search_space.keys()) == set(TRAIN_CONFIG_KEYS)
    col_names_dict = train_tensor_frame.col_names_dict


def train(
    model: Module,
    loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    epoch: int,
) -> float:
    model.train()
    loss_accum = total_count = 0

    for tf in tqdm(loader, desc=f'Epoch: {epoch}'):
        tf = tf.to(device)
        y = tf.y
        if isinstance(model, ExcelFormer):
            # Train with FEAT-MIX or HIDDEN-MIX
            pred, y = model(tf, mixup_encoded=True)
        elif isinstance(model, Trompt):
            # Trompt uses the layer-wise loss
            pred = model(tf)
            num_layers = pred.size(1)
            # [batch_size * num_layers, num_classes]
            pred = pred.view(-1, out_channels)
            y = tf.y.repeat_interleave(num_layers)
        else:
            pred = model(tf)

        if pred.size(1) == 1:
            pred = pred.view(-1, )
        if train_dataset.task_type == TaskType.BINARY_CLASSIFICATION:
            y = y.to(torch.float)
        loss = loss_fun(pred, y)
        optimizer.zero_grad()
        loss.backward()
        loss_accum += float(loss) * len(tf.y)
        total_count += len(tf.y)
        optimizer.step()
    return loss_accum / total_count


@torch.no_grad()
def test(
    model: Module,
    loader: DataLoader,
) -> float:
    model.eval()
    metric_computer.reset()
    for tf in loader:
        tf = tf.to(device)
        pred = model(tf)
        if isinstance(model, Trompt):
            pred = pred.mean(dim=1)
        if train_dataset.task_type == TaskType.MULTICLASS_CLASSIFICATION:
            pred = pred.argmax(dim=-1)
        elif train_dataset.task_type == TaskType.REGRESSION:
            pred = pred.view(-1, )
        metric_computer.update(pred, tf.y)
    return metric_computer.compute().item()


def train_and_eval_with_cfg(
    model_cfg: dict[str, Any],
    train_cfg: dict[str, Any],
    trial: Optional[optuna.trial.Trial] = None,
) -> tuple[float, float]:
    # Use model_cfg to set up training procedure
    if args.model_type == 'FTTransformerBucket':
        # Use LinearBucketEncoder instead
        stype_encoder_dict = {
            stype.categorical: EmbeddingEncoder(),
            stype.numerical: LinearBucketEncoder(),
        }
        model_cfg['stype_encoder_dict'] = stype_encoder_dict
    model = model_cls(
        **model_cfg,
        out_channels=out_channels,
        col_stats=col_stats,
        col_names_dict=col_names_dict,
    ).to(device)
    model.reset_parameters()
    # Use train_cfg to set up training procedure
    optimizer = torch.optim.Adam(model.parameters(), lr=train_cfg['base_lr'])
    lr_scheduler = ExponentialLR(optimizer, gamma=train_cfg['gamma_rate'])
    train_loader = DataLoader(train_tensor_frame,
                              batch_size=train_cfg['batch_size'], shuffle=True,
                              drop_last=True)
    val_loader = DataLoader(val_tensor_frame,
                            batch_size=train_cfg['batch_size'])
    test_loader = DataLoader(test_tensor_frame,
                             batch_size=train_cfg['batch_size'])

    if higher_is_better:
        best_val_metric = 0
    else:
        best_val_metric = math.inf

    for epoch in range(1, args.epochs + 1):
        train_loss = train(model, train_loader, optimizer, epoch)
        val_metric = test(model, val_loader)

        if higher_is_better:
            if val_metric > best_val_metric:
                best_val_metric = val_metric
                best_test_metric = test(model, test_loader)
        else:
            if val_metric < best_val_metric:
                best_val_metric = val_metric
                best_test_metric = test(model, test_loader)
        lr_scheduler.step()
        print(f'Train Loss: {train_loss:.4f}, Val: {val_metric:.4f}')

        if trial is not None:
            trial.report(val_metric, epoch)
            if trial.should_prune():
                raise optuna.TrialPruned()

    print(
        f'Best val: {best_val_metric:.4f}, Best test: {best_test_metric:.4f}')
    return best_val_metric, best_test_metric


def objective(trial: optuna.trial.Trial) -> float:
    model_cfg = {}
    for name, search_list in model_search_space.items():
        model_cfg[name] = trial.suggest_categorical(name, search_list)
    train_cfg = {}
    for name, search_list in train_search_space.items():
        train_cfg[name] = trial.suggest_categorical(name, search_list)

    best_val_metric, _ = train_and_eval_with_cfg(model_cfg=model_cfg,
                                                 train_cfg=train_cfg,
                                                 trial=trial)
    return best_val_metric


def main_deep_models():
    # Hyper-parameter optimization with Optuna
    print("Hyper-parameter search via Optuna")
    start_time = time.time()
    study = optuna.create_study(
        pruner=optuna.pruners.MedianPruner(),
        direction="maximize" if higher_is_better else "minimize",
    )
    study.optimize(objective, n_trials=args.num_trials)
    end_time = time.time()
    search_time = end_time - start_time
    print("Hyper-parameter search done. Found the best config.")
    params = study.best_params
    best_train_cfg = {}
    for train_cfg_key in TRAIN_CONFIG_KEYS:
        best_train_cfg[train_cfg_key] = params.pop(train_cfg_key)
    best_model_cfg = params

    print(f"Repeat experiments {args.num_repeats} times with the best train "
          f"config {best_train_cfg} and model config {best_model_cfg}.")
    start_time = time.time()
    best_val_metrics = []
    best_test_metrics = []
    for _ in range(args.num_repeats):
        best_val_metric, best_test_metric = train_and_eval_with_cfg(
            best_model_cfg, best_train_cfg)
        best_val_metrics.append(best_val_metric)
        best_test_metrics.append(best_test_metric)
    end_time = time.time()
    final_model_time = (end_time - start_time) / args.num_repeats
    best_val_metrics = np.array(best_val_metrics)
    best_test_metrics = np.array(best_test_metrics)

    result_dict = {
        'args': args.__dict__,
        'best_val_metrics': best_val_metrics,
        'best_test_metrics': best_test_metrics,
        'best_val_metric': best_val_metrics.mean(),
        'best_test_metric': best_test_metrics.mean(),
        'best_train_cfg': best_train_cfg,
        'best_model_cfg': best_model_cfg,
        'search_time': search_time,
        'final_model_time': final_model_time,
        'total_time': search_time + final_model_time,
    }
    print(result_dict)
    # Save results
    if args.result_path != '':
        os.makedirs(os.path.dirname(args.result_path), exist_ok=True)
        torch.save(result_dict, args.result_path)


def main_gbdt():
    if train_dataset.task_type.is_classification:
        num_classes = train_dataset.num_classes
    else:
        num_classes = None
    model = model_cls(task_type=train_dataset.task_type, num_classes=num_classes)

    import time
    start_time = time.time()
    model.tune(tf_train=train_dataset.tensor_frame,
               tf_val=val_dataset.tensor_frame, num_trials=args.num_trials)
    val_pred = model.predict(tf_test=val_dataset.tensor_frame)
    val_metric = model.compute_metric(val_dataset.tensor_frame.y, val_pred)
    test_pred = model.predict(tf_test=test_dataset.tensor_frame)
    test_metric = model.compute_metric(test_dataset.tensor_frame.y, test_pred)
    end_time = time.time()
    result_dict = {
        'args': args.__dict__,
        'best_val_metric': val_metric,
        'best_test_metric': test_metric,
        'best_cfg': model.params,
        'total_time': end_time - start_time,
    }
    print(result_dict)
    # Save results
    if args.result_path != '':
        os.makedirs(os.path.dirname(args.result_path), exist_ok=True)
        torch.save(result_dict, args.result_path)


if __name__ == '__main__':
    print(args)
    if os.path.exists(args.result_path):
        exit(-1)
    if args.model_type in ["XGBoost", "CatBoost", "LightGBM"]:
        main_gbdt()
    else:
        main_deep_models()

In [None]:
import torch
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ExampleTransformer(
    channels=32,
    out_channels=train_dataset.num_classes,
    num_layers=2,
    num_heads=8,
    col_stats=train_dataset.col_stats,
    col_names_dict=train_dataset.tensor_frame.col_names_dict,
).to(device)

optimizer = torch.optim.Adam(model.parameters())

for epoch in range(5):
    print(f"Epoch {epoch}")
    for tf in train_loader:
        tf = tf.to(device)
        pred = model.forward(tf)
        loss = F.cross_entropy(pred, tf.y)
        optimizer.zero_grad()
        loss.backward()

Epoch 0
Epoch 1
Epoch 2
Epoch 3
Epoch 4


In [45]:
loss

tensor(0.6940, device='cuda:0', grad_fn=<NllLossBackward0>)