# MultiOmicsDiseasePrediction

数据维度

- Meta
- Prot
- RF | PANNEL | AS 
- PRS
- LTL + CHIP 

流程：
1. Meta 和 Prot 交集 作为 held-out 数据 
2. 排除held-out数据作为所有维度数据的单独模型的训练集
2. 定义研究的疾病或者表型数据（输出所有数据集中的case|control的比例；如有时间则列出随访信息等）
3. 对每个数据定义机器学习方法（xxx.py）
4. 整合全部组学的预测分数
5. 下游分析（calibration，netbenifit，重分类，分数的百分比风险图，疾病生存分析（如有随访））


In [None]:
import pandas as pd

from pathlib import Path

pd.set_option("display.max_columns", None)


# %config InlineBackend.figure_format = "svg"
# %config InlineBackend.print_figure_kwargs = {"dpi" : 300}
import seaborn as sns
import matplotlib.pyplot as plt


# from cadFace.vis import percentiles_plot
import sci_palettes

try:
    sci_palettes.register_cmap()
except:
    pass
import scienceplots

plt.style.use(["nature", "no-latex"])
sns.set_context("paper", font_scale=1.5)
sns.set_palette("nejm")

In [None]:
def load_data(x):
    # if isinstance(x, Path.)
    x = str(x)
    if ".csv" in x:
        return pd.read_csv(x)
    elif x.endswith(".feather"):
        return pd.read_feather(x)
    elif x.endswith(".pkl"):
        return pd.read_pickle(x)
    elif ".tsv" in x:
        return pd.read_csv(x, sep="\t")
    else:
        raise ValueError(f"File format: {x} not supported")


class DataConfig(object):
    def __init__(self, path, name=None, **kwargs):
        self.name = name if name else Path(path).stem
        self.path = path
        # self.kwargs = kwargs
        for key, value in kwargs.items():
            setattr(self, key, value)

    def __load_data__(self):
        print(f"Loading data: {self.name}")
        self.data = load_data(self.path)

    def __str__(self) -> str:
        return self.name

    def __repr__(self) -> str:
        return self.name


# class ModelConfig(object):
#     def __init__(self, name, model, feature, cov, **kwargs):
#         self.name = name
#         self.model = model
#         self.feature = feature
#         self.cov = cov
#         for key, value in kwargs.items():
#             setattr(self, key, value)


class ModelConfig(dict):
    def __init__(self, name=None, model=None, feature=None, cov=None, **kwargs):
        kwargs["name"] = name
        kwargs["model"] = model
        kwargs["feature"] = feature
        kwargs["cov"] = cov
        super().__init__(**kwargs)

    def __getattr__(self, name):
        return self[name]

    def __setattr__(self, name, value):
        self[name] = value

def check_disease_dist(Config):
    # to_check_list = Config
    diseaseData = Config["diseaseData"].data
    disease = Config["diseaseData"].name
    disease_event_col = Config["diseaseData"].label

    dist = {}
    for dataconfig in Config["omicsData"].values():
        data_eids = dataconfig.data[["eid"]].copy()
        dataname = dataconfig.name

        if data_eids.eid.dtype != diseaseData.eid.dtype:
            data_eids.eid = data_eids.eid.astype(diseaseData.eid.dtype)

        inner_data = diseaseData.merge(data_eids, on="eid", how="inner")[
            [disease_event_col]
        ]

        inner_data = inner_data.value_counts().to_dict()

        dist[dataname] = inner_data

    dataconfig = Config["heldOutData"]
    data_eids = dataconfig.data[["eid"]].copy()
    dataname = dataconfig.name

    if data_eids.eid.dtype != diseaseData.eid.dtype:
        data_eids.eid = data_eids.eid.astype(diseaseData.eid.dtype)

    inner_data = diseaseData.merge(data_eids, on="eid", how="inner")[
        [disease_event_col]
    ]

    inner_data = inner_data.value_counts().to_dict()

    dist[dataname] = inner_data

    dist_df = pd.DataFrame(dist)

    return dist_df

class LassoConfig(object):
    def __init__(
        self,
        feature,
        label,
        cov,
        name="lasso",
        family="binomial",
        lambda_=None,
        type_measure="auc",
        cv=10,
        **kwargs,
    ):
        assert isinstance(label, str), "label should be a string"
        if cov is not None:
            if isinstance(cov, str):
                cov = [cov]
            elif isinstance(cov, list):
                if len(cov) == 0:
                    cov = None

        assert isinstance(feature, str) or isinstance(
            feature, list
        ), "feature should be a string or a list"
        self.config = {
            name: {
                "feature": feature if isinstance(feature, list) else [feature],
                "label": label,
                "time": None,
                "cov": cov,
                "family": family,
                "lambda": lambda_,
                "type_measure": type_measure,
                "cv": cv,
            }
        }
        if kwargs:
            # export them and warning not used
            print(f"Warning: {kwargs} not used")

    def to_json(self):
        return self.config


def plot_coef_scatter(data, coef, feature, k=6, ax=None, cmap="nejm"):
    data = (
        data[[coef, feature]].copy().rename(columns={coef: "coef", feature: "feature"})
    )
    plt_data = (
        data.query("coef !=0")
        .sort_values("coef", ascending=False)
        .reset_index(drop=True)
        .reset_index(drop=False, names=["idx"])
    )

    from adjustText import adjust_text

    if ax is None:
        fig, ax = plt.subplots(figsize=(12, 4))

    hue = "coef"
    y = "coef"
    name = "feature"
    plt_data = plt_data.query("feature != 'sex' & feature != 'age' ")

    # 计算每个点的颜色
    colors = plt_data[hue]
    min_value = max(abs(colors.min()), abs(colors.max()))
    norm = plt.Normalize(-min_value, min_value)
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])

    sc = ax.scatter(
        plt_data.index,
        plt_data[y],
        c=colors,
        cmap=cmap,
        s=30,
        # edgecolor="k",
        zorder=3,
    )
    # cb = plt.colorbar(sm, ax=ax)

    # 设置标题和轴标签
    ax.set_title(
        f"Mean Coefficient of {name} bootstrap model", fontsize=16, fontweight="bold"
    )
    ax.set_xlabel(
        "",
    )
    ax.set_ylabel("Mean Coefficient", fontsize=14)
    ax.set_yticks([-min_value / 2, 0, min_value / 2])
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    # 增加网格线
    ax.grid(True, which="both", linestyle="--", linewidth=0.5, alpha=0.5)

    texts = [
        ax.text(
            idx,
            row[y],
            f"{row[name]}",
            ha="center",
            va="bottom",
            fontsize=8,
        )
        for idx, row in plt_data.head(k).iterrows()
    ] + [
        ax.text(
            idx,
            row[y],
            f"{row[name]}",
            ha="center",
            va="top",
            fontsize=8,
        )
        for idx, row in plt_data.tail(k).iterrows()
    ]
    adjust_text(texts, ax=ax, arrowprops=dict(arrowstyle="->", lw=0.5))


class GLMNETBootsrapResult(object):
    def __init__(self, bootstrap_coef_df):
        self.coef = bootstrap_coef_df
        self.features = self.coef.index.tolist()
        self._init_weights_dist()

    def _init_weights_dist(self):

        res = self.coef

        percent_of_nonZero_coefficients = (
            (res != 0).sum(axis=1) * 100 / len(res.columns)
        )
        mean_coefficients = res.mean(axis=1)
        weights_dist_df = pd.DataFrame(
            [percent_of_nonZero_coefficients, mean_coefficients],
            index=["percent_of_nonZero_coefficients", "mean_coefficients"],
        ).T
        weights_dist_df["abs_mean_coefficients"] = weights_dist_df[
            "mean_coefficients"
        ].abs()
        self.weights_dist_df = weights_dist_df

    def _plot_top_k_features(self, k=10, pallete="viridis", ax=None, exclude=None):
        """
        plot top k features
        """

        if ax is None:
            fig, ax = plt.subplots(1, 1, figsize=(5, k))

        if isinstance(exclude, str):
            exclude = [exclude]

        if exclude is not None:
            plt_data = self.coef.loc[self.coef.index.difference(exclude), :]

            top_k_features = self.weights_dist_df
            top_k_features = top_k_features.loc[
                top_k_features.index.difference(exclude), :
            ].sort_values(
                by=["mean_coefficients"],
                ascending=False,
            )
        else:
            plt_data = self.coef
            top_k_features = self.weights_dist_df.sort_values(
                by=["mean_coefficients"],
                ascending=False,
            )

        plt_data = plt_data.loc[
            [*top_k_features.index[:k], *top_k_features.index[-k:]], :
        ]
        idx_name = plt_data.index.name
        plt_data = plt_data.reset_index(drop=False).melt(id_vars=idx_name)

        sns.boxplot(
            data=plt_data,
            y=idx_name,
            x="value",
            showfliers=False,
            ax=ax,
            palette=pallete,
        )
        ax.set_xticks([0.0])
        ax.set_xticklabels(ax.get_xticklabels(), rotation=90, fontsize=10)

        ax.grid(axis="x", linestyle="--", alpha=1, linewidth=2, color="red")
        ax.set_xlabel("Mean of Coefficients")
        ax.set_ylabel("Features")
        ax.set_title(f"Top {k} Features")
        return ax

    def _show_models_coeffients(self, axes=None, color="#d67b7f", top=5, exclude=None):
        """
        res:
            model1 model2
        SOST xx yy
        BGN xx yy


        """
        if self.coef is None:
            self.coef = self._init_coeffeients_df()
        res = self.coef

        # exclude = self.cov if exclude is None else exclude + self.cov
        if exclude:
            if isinstance(exclude, str):
                exclude = [exclude]
            res = res.loc[res.index.difference(exclude), :]

        if axes is None:
            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
        else:
            assert len(axes) == 2, "axes should be a list of length 2"
            ax1, ax2 = axes

        percent_of_nonZero_coefficients = (
            (res != 0).sum(axis=1) * 100 / len(res.columns)
        )
        mean_coefficients = res.mean(axis=1)
        plt_data = pd.DataFrame(
            [percent_of_nonZero_coefficients, mean_coefficients],
            index=["percent_of_nonZero_coefficients", "mean_coefficients"],
        ).T
        plt_data["abs_mean_coefficients"] = plt_data["mean_coefficients"].abs()

        # ax1
        sns.scatterplot(
            x=percent_of_nonZero_coefficients,
            y=mean_coefficients,
            size=mean_coefficients,
            sizes=(20, 400),
            legend=False,
            edgecolor="black",
            ax=ax1,
            color=color,
        )
        ax1.plot([0, 100], [0, 0], "k--", lw=3, color="grey")
        ax1.set_xlim(-1, 105)
        ax1.set_xlabel("percent of non-zero coefficients")
        ax1.set_ylabel("mean nonzero coefficients")
        sorted_plt_data = (
            plt_data.sort_values(
                by=["percent_of_nonZero_coefficients", "abs_mean_coefficients"],
                ascending=False,
            )
            .iloc[:top, :]
            .index
        )
        for i, txt in enumerate(sorted_plt_data):
            # ax1.annotate(txt, (sorted_plt_data.iloc[i, 0], sorted_plt_data.iloc[i, 1]))
            ax1.text(
                plt_data.loc[txt, "percent_of_nonZero_coefficients"],
                plt_data.loc[txt, "mean_coefficients"],
                txt,
                ha="right",
                fontsize=8,
                color="black",
            )

        # ax2
        absolute_mean_coefficients = mean_coefficients.abs().sort_values(ascending=True)
        sns.barplot(
            y=absolute_mean_coefficients,
            x=absolute_mean_coefficients.index,
            ax=ax2,
            color=color,
        )
        ax2.set_ylabel("absolute mean coefficients")
        ax2.set_xlabel("")
        xticks = ax2.get_xticklabels()
        if len(xticks) > 100:
            ax2.set_xticks([""] * len(xticks))
        else:
            ax2.set_xticklabels(ax2.get_xticklabels(), rotation=90)
        if axes is None:
            # fig.tight_layout()
            return ax1, ax2

    def coef_barplot(
        self,
        cmap="RdBu_r",
        k=10,
        ax=None,
        errorbar_kwargs=None,
        scatter_kwargs=None,
        exclude=["age", "sex"],
    ):
        from adjustText import adjust_text
        from scipy import stats

        if errorbar_kwargs is None:
            errorbar_kwargs = {}
        if scatter_kwargs is None:
            scatter_kwargs = {}

        plt_data = self.coef.copy()

        def cal_ci(x):
            mean_x = x.mean()
            scale = stats.sem(x)
            ci_low, ci_high = stats.t.interval(
                0.95, len(x) - 1, loc=mean_x, scale=scale
            )
            return {"mean": mean_x, "ci_low": ci_low, "ci_high": ci_high}

        # drop age sex

        # exclude = self.cov if exclude is None else exclude + self.cov
        if isinstance(exclude, str):
            exclude = [exclude]

        plt_data = plt_data.loc[[i not in exclude for i in plt_data.index.tolist()]]

        plt_data = plt_data.apply(
            lambda x: pd.Series(cal_ci(x)),
            axis=1,
        )
        plt_data = plt_data.sort_values("mean", ascending=False).reset_index(
            drop=False, names="feature"
        )
        if ax is None:
            fig, ax = plt.subplots(figsize=(12, 4))
        plt_data["error_low"] = plt_data["mean"] - plt_data["ci_low"]
        plt_data["error_high"] = plt_data["ci_high"] - plt_data["mean"]

        # 绘制误差线

        ax.errorbar(
            x=plt_data.index,
            y=plt_data["mean"],
            yerr=[
                plt_data["mean"] - plt_data.ci_low,
                plt_data.ci_high - plt_data["mean"],
            ],
            fmt=errorbar_kwargs.pop("fmt", "none"),  # 不使用标记
            lw=errorbar_kwargs.pop("lw", 1),
            capsize=errorbar_kwargs.pop("capsize", 2),
            ecolor=errorbar_kwargs.pop("ecolor", "lightgrey"),  # 将误差线设置为浅灰色
            **errorbar_kwargs,
        )

        # 使用scatter添加颜色渐变的散点
        # 计算每个点的颜色
        colors = plt_data["mean"]
        min_value = max(abs(colors.min()), abs(colors.max()))
        norm = plt.Normalize(-min_value, min_value)
        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
        sm.set_array([])

        sc = ax.scatter(
            plt_data.index,
            plt_data["mean"],
            c=colors,
            cmap=cmap,
            s=scatter_kwargs.pop("s", 5),
            # edgecolor="k",
            zorder=scatter_kwargs.pop("zorder", 3),
            **scatter_kwargs,
        )
        cb = plt.colorbar(sm, ax=ax)

        # 设置标题和轴标签
        ax.set_title(
            "Mean Coefficient of bootstrap model", fontsize=16, fontweight="bold"
        )
        ax.set_xlabel(
            "",
        )
        ax.set_ylabel("Mean Coefficient", fontsize=14)
        ax.set_yticks([-min_value / 2, 0, min_value / 2])
        ax.set_xticks([])
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
        # 增加网格线
        ax.grid(True, which="both", linestyle="--", linewidth=0.5, alpha=0.5)

        texts = [
            ax.text(
                idx,
                row["mean"] + row["error_high"],
                f"{row['feature']}",
                ha="center",
                va="bottom",
                fontsize=8,
            )
            for idx, row in plt_data.head(k).iterrows()
        ] + [
            ax.text(
                idx,
                row["mean"] - row["error_low"],
                f"{row['feature']}",
                ha="center",
                va="top",
                fontsize=8,
            )
            for idx, row in plt_data.tail(k).iterrows()
        ]
        adjust_text(texts, ax=ax, arrowprops=dict(arrowstyle="->", lw=0.5))

        return ax


# from tqdm.notebook import tqdm


def load_glmnet_bootstrap(model_dir):
    """
    dir/
        1/Meta/
            coef_df.csv
            train_score.csv
            test_score.csv
        1/Prot/
            coef_df.csv
            train_score.csv
            test_score.csv
        2/Meta/
        ...
    """

    model_dir = Path(model_dir)

    coef_df_name = "coef_df.csv"
    train_score = "train_score.csv"
    test_score = "test_score.csv"

    from collections import defaultdict

    res = defaultdict(lambda: defaultdict(list))

    found_csvs = list(model_dir.rglob("*.csv"))
    for file_dir in found_csvs:
        if file_dir.parent == model_dir:
            continue
        filename = file_dir.stem
        submodelname = file_dir.parent.name
        seed = file_dir.parent.parent.name
        file = pd.read_csv(file_dir)
        if filename == "coef_df":
            file.columns = ["feature", f"coef_{seed}"]
            file.set_index("feature", inplace=True)
        elif filename == "train_score":
            continue
        else:

            file.rename(columns={"pred": f"pred_{seed}"}, inplace=True)
            file.set_index("eid", inplace=True)

        res[submodelname][filename].append(file)

    for submodelname in res.keys():
        for subcsv in res[submodelname].keys():
            # first = res[submodelname][subcsv][0].iloc[:, [0]]
            merged = pd.concat(res[submodelname][subcsv], axis=1)
            res[submodelname][subcsv] = merged

    return res


import json
import subprocess
import shutil
import matplotlib.gridspec as gridspec


def run_glmnet(
    json_dir,
    train_dir,
    out_dir,
    test_dir=None,
    seed=None,
):

    if shutil.which("run_glmnet.R") is None:
        raise ValueError("run_glmnet.R is not in the PATH")

    cmd = f"run_glmnet.R --json {json_dir} --train {train_dir} --out {out_dir}"
    if test_dir is not None:
        cmd += f" --test {test_dir}"
    if seed is not None:
        cmd += f" --seed {seed}"
    print(cmd)
    subprocess.run(cmd, shell=True)
    return subprocess


class LassoTrainTFPipline(object):
    def __init__(
        self, mmconfig, dataconfig, tgtconfig, phenoconfig, testdataconfig=None
    ):
        """ """
        self.mmconfig = mmconfig
        self.dataconfig = dataconfig
        self.tgtconfig = tgtconfig
        self.phenoconfig = phenoconfig
        self.testdataconfig = testdataconfig

    def run(self, n_bootstrap=200, n_jobs=4, outputFolder="./out"):
        # simple lasso
        mmconfig = self.mmconfig
        dataconfig = self.dataconfig
        tgtconfig = self.tgtconfig
        phenoconfig = self.phenoconfig
        outputFolder = Path(outputFolder)

        label = tgtconfig.label
        diseaseData = tgtconfig.data

        phenosData = phenoconfig.data

        model_list = mmconfig["model"]
        modelname = mmconfig["name"]
        feature = mmconfig["feature"]
        cov = mmconfig["cov"] if mmconfig["cov"] is not None else []

        # copy data
        used_pheno_data = phenosData[["eid"] + cov].copy()
        used_dis_data = diseaseData[["eid", label]].copy()

        # check eid dtype

        if used_pheno_data.eid.dtype != dataconfig.data.eid.dtype:
            used_pheno_data.eid = used_pheno_data.eid.astype(dataconfig.data.eid.dtype)
        if used_dis_data.eid.dtype != dataconfig.data.eid.dtype:
            used_dis_data.eid = used_dis_data.eid.astype(dataconfig.data.eid.dtype)

        # check output
        model_output_folder = outputFolder / modelname
        model_output_folder.mkdir(parents=True, exist_ok=True)
        # lasso
        lasso_config = LassoConfig(
            feature=feature,
            label=label,
            cov=cov,
            name=modelname,
            type_measure=mmconfig.get("type_measure", "auc"),
            cv=mmconfig.get("cv", 10),
        ).to_json()
        json_dir = model_output_folder / "train_config.json"
        json.dump(lasso_config, open(json_dir, "w"))

        model_save_dir = model_output_folder / "model"

        # data save to
        train_feather = (
            dataconfig.data.merge(diseaseData[["eid", label]], on="eid", how="inner")
            .merge(phenosData[["eid"] + cov], on="eid", how="inner")
            .dropna(subset=[label])
        ).reset_index(drop=True)

        tmp_train_feather_dir = model_output_folder / "train.feather"
        ##################### rm ##################
        train_feather = train_feather.head(10000)
        ##################### rm ##################
        print(f"Train data shape: {train_feather.shape}")

        train_feather.to_feather(tmp_train_feather_dir)
        ##################### rm ##################
        if self.testdataconfig is not None:
            if self.testdataconfig.data.eid.dtype != dataconfig.data.eid.dtype:
                self.testdataconfig.data.eid = self.testdataconfig.data.eid.astype(
                    dataconfig.data.eid.dtype
                )

            # merge disease data
            test_feather = (
                self.testdataconfig.data.merge(
                    diseaseData[["eid", label]], on="eid", how="inner"
                ).dropna(subset=[label])
            ).reset_index(drop=True)

            # check cov in test data
            # if not in, merge from phenos
            to_merge_cols = []
            for c in cov:
                if c not in self.testdataconfig.data.columns:
                    to_merge_cols.append(c)
                    print(f"Missing cov in test data: {c}")

            if len(to_merge_cols) > 0:
                test_feather = test_feather.merge(
                    phenosData[["eid"] + to_merge_cols], on="eid", how="inner"
                ).reset_index(drop=True)

            tmp_test_feather_dir = model_output_folder / "test.feather"
            test_feather = test_feather[train_feather.columns.tolist()]
            print(f"Test data shape: {test_feather.shape}")
            test_feather.to_feather(tmp_test_feather_dir)
        else:
            raise ValueError("Test data is not provided")

        # run single without random seed
        single_lasso_output_folder = model_output_folder / "single"
        run_glmnet(
            json_dir=json_dir,
            train_dir=tmp_train_feather_dir,
            out_dir=single_lasso_output_folder,
            test_dir=tmp_test_feather_dir if self.testdataconfig is not None else None,
        )
        if isinstance(n_bootstrap, int) and n_bootstrap > 1:
            if self.testdataconfig is None:
                raise ValueError(
                    "Test data is not provided, cannot run bootstrap to select best"
                )
            # run bootstrap
            bootstrap_output_folder = model_output_folder / "bootstrap"
            from joblib import Parallel, delayed

            res = Parallel(n_jobs=n_jobs)(
                delayed(run_glmnet)(
                    json_dir=json_dir,
                    train_dir=tmp_train_feather_dir,
                    out_dir=bootstrap_output_folder / f"{i}",
                    test_dir=(
                        tmp_test_feather_dir
                        if self.testdataconfig is not None
                        else None
                    ),
                    seed=i,
                )
                for i in range(1, n_bootstrap + 1)
            )

            # plot bootstrap
            res = load_glmnet_bootstrap(bootstrap_output_folder)
            coef = res[modelname]["coef_df"]
            test_score = res[modelname]["test_score"]
            ## save
            coef.to_csv(bootstrap_output_folder / "bootstrap_coef_df.csv", index=True)
            test_score.reset_index(drop=False).to_feather(
                bootstrap_output_folder / "test_score.feather"
            )

            ## plot
            fig = plt.figure(figsize=(15, 10))
            gs = gridspec.GridSpec(2, 5, hspace=0.5, wspace=0.5, figure=fig)

            ax1 = fig.add_subplot(gs[0, 0:2])
            ax2 = fig.add_subplot(gs[0, 2:4])
            ax3 = fig.add_subplot(gs[:, 4:])
            ax4 = fig.add_subplot(gs[1, :4])

            glmnet_bootsrap_result = GLMNETBootsrapResult(coef)
            glmnet_bootsrap_result._show_models_coeffients(axes=[ax1, ax2])
            glmnet_bootsrap_result._plot_top_k_features(ax=ax3)
            ax3.yaxis.set_label_position("right")
            ax3.yaxis.tick_right()
            glmnet_bootsrap_result.coef_barplot(ax=ax4)
            fig.savefig(model_output_folder / "bootstrap_coef_plot.png")

            # fit the passed
            coef_mean = coef.mean(axis=1)
            non_zero_features = coef_mean[coef_mean != 0].index.tolist()

            # this time no need for cov
            non_zero_features_lasso_config = LassoConfig(
                feature=non_zero_features, label=label, cov=None, name=modelname
            ).to_json()
            non_zero_features_json_dir = (
                model_output_folder / "non_zero_features_train_config.json"
            )
            json.dump(
                non_zero_features_lasso_config, open(non_zero_features_json_dir, "w")
            )

            # run glmnet
            non_zero_features_output_folder = model_output_folder / "non_zero_features"
            run_glmnet(
                json_dir=non_zero_features_json_dir,
                train_dir=tmp_train_feather_dir,
                out_dir=non_zero_features_output_folder,
                test_dir=(
                    tmp_test_feather_dir if self.testdataconfig is not None else None
                ),
            )

            # compare them
            score_dict = {}
            single_test_score = load_data(
                single_lasso_output_folder / modelname / "test_score.csv"
            )
            single_test_score.columns = ["eid", "single"]
            score_dict["single"] = single_test_score

            bootstrap_test_score = load_data(
                bootstrap_output_folder / "test_score.feather"
            )
            bootstrap_test_score["mean"] = bootstrap_test_score.mean(axis=1)
            bootstrap_test_score = bootstrap_test_score[["eid", "mean"]]
            score_dict["mean"] = bootstrap_test_score

            non_zero_features_test_score = load_data(
                non_zero_features_output_folder / modelname / "test_score.csv"
            )
            non_zero_features_test_score.columns = ["eid", "non_zero_features"]
            score_dict["non_zero_features"] = non_zero_features_test_score

            to_compare_df = (
                test_feather[["eid", label]]
                .merge(single_test_score, on="eid", how="inner")
                .merge(bootstrap_test_score, on="eid", how="inner")
                .merge(non_zero_features_test_score, on="eid", how="inner")
            )

            to_compare_metrics = {}
            from ppp_prediction.corr import cal_binary_metrics_bootstrap

            for col in ["single", "mean", "non_zero_features"]:
                to_cal = to_compare_df[[label, col]].dropna()
                to_compare_metrics[col] = cal_binary_metrics_bootstrap(
                    to_cal[label], to_cal[col], ci_kwargs={"n_resamples": 100}
                )
            to_compare_metrics = pd.DataFrame(to_compare_metrics).T.sort_values(
                "AUC", ascending=False
            )

            to_compare_metrics.to_csv(
                model_output_folder / "compare_metrics.csv", index=True
            )

            # extract best

            best_model = to_compare_metrics.index[0]
            best_model_score = score_dict[best_model]
            best_model_score.to_csv(
                model_output_folder / "best_model_score.csv", index=False
            )

            print(f"Finished!")
        else:
            shutil.copy(
                single_lasso_output_folder / modelname / "test_score.csv",
                model_output_folder / "best_model_score.csv",
            )
            return

define json 

In [None]:
# Necessary Params to pass by cmd
outputFolder = "test"

# defined in the json
ukbData = "MulitOmicsDisease/"
ProtTrainDir = f"{ukbData}/traindata/Prot.feather"
MetaTrainDir = f"{ukbData}/traindata/Meta.feather"
RFTrainDir = f"{ukbData}/traindata/RF.feather"
heldOutDataDir = f"{ukbData}/traindata/heldout.feather"

phenoDataDir = f"{ukbData}/omicsData/phenos.feather"
diseaseDataDir = f"{ukbData}/disease/T2D_Coding_Amit_NG2018.feather"

cov = ["age", "sex"]

params_json = {
    "omicsData": {
        "Prot": {
            "name": "Prot",
            "path": MetaTrainDir,
            "feature": None,
        },
        "Meta": {
            "name": "Meta",
            "path": MetaTrainDir,
            "feature": None,
        },
        "RF": {
            "name": "RF",
            "path": RFTrainDir,
            "feature": None,
        },
    },
    "heldOutData": {
        "name": "heldOut",
        "path": heldOutDataDir,
    },
    "diseaseData": {
        # "name": "T2D_Coding_Amit_NG2018",
        "path": diseaseDataDir,
        "label": "event",
        "date": "date",
    },
    "phenoData": {
        "name": "phenos",
        "path": phenoDataDir,
    },
    "modelConfig": {
        "Prot": {
            "name": "Prot",  # name of the model for save
            "model": ["lasso"],  # not work now
            "feature": None,  # feature to use, None is all
            "cov": cov,  # covariate to use
            "cv": 10,
            "n_bootstrap": 8,
        },
        "Meta": {
            "name": "Meta",
            "model": ["lasso"],
            "feature": None,
            "cov": cov,
            "cv": 10,
            "n_bootstrap": 8,
        },
        "RF": {
            "name": "RF",
            "model": ["lasso"],
            "feature": None,
            "cov": None,
            "cv": 10,
            "n_bootstrap": None,
        },
    },
}

In [None]:
OmicsDataDirDict = {k: DataConfig(**v) for k, v in params_json["omicsData"].items()}
heldOutDataDict = DataConfig(**params_json["heldOutData"])
diseaseDict = DataConfig(**params_json["diseaseData"])
phenosDataDict = DataConfig(**params_json["phenoData"])
modelconfig = {k: ModelConfig(**v) for k, v in params_json["modelConfig"].items()}

Config = {
    "omicsData": OmicsDataDirDict,
    "heldOutData": heldOutDataDict,
    "diseaseData": diseaseDict,
    "phenosData": phenosDataDict,
    "modelConfig": modelconfig,
}

In [None]:
diseaseDict

In [None]:
# OmicsDataDirDict = {
#     "Prot": DataConfig(
#         **{
#             "name": "Prot",
#             "path": ProtTrainDir,
#             "feature": None,
#         }
#     ),
#     "Meta": DataConfig(
#         **{
#             "name": "Meta",
#             "path": MetaTrainDir,
#             "feature": None,
#         }
#     ),
# }

# heldOutDataDict = DataConfig(
#     **{
#         "name": "heldOut",
#         "path": heldOutDataDir,
#     }
# )


# ## disease phenotype
# ### eid should be the first column, event is the second column and time is the third column (optional)
# diseaseDict = DataConfig(
#     **{
#         "name": "CAD_xhv2_exclude_more",
#         "path": "/home/xutingfeng/ukb/ukbData/phenotypes/ukb_ph/disease_extracted/AAA_from_renji/CAD_xhv2_exclude_more.csv",
#         "label": "CAD_xhv2_exclude_more_event",
#         "date": "CAD_xhv2_exclude_more_date",
#     }
# )
# ## covariates
# ### eid should be the first column, and the rest are the covariates may need
# phenosDataDict = DataConfig(
#     **{
#         "name": "phenos",
#         "path": "/home/xutingfeng/ukb/project/ppp_prediction/results/Meta_Prot/dataset/RF_training_df.feather",
#     }
# )

# ## Model Config
# ProtModelConfig = ModelConfig(
#     **{
#         "name": "Prot",
#         "model": ["lasso"],
#         "feature": None,
#         "cov": ["age", "sex"],
#     }
# )
# MetaModelConfig = ModelConfig(
#     **{
#         "name": "Meta",
#         "model": ["lasso"],
#         "feature": None,
#         "cov": ["age", "sex"],
#     }
# )

# modelconfig = {"Prot": ProtModelConfig, "Meta": MetaModelConfig}
# # Other Params

# ## Other Omics Data
# ### first col should be eid and left is the feature with no missing values (will drop them if missing, so impute first or if u really want to keep them)

# # OtherOmicsDataDirDict.update()

# # Config

# Config = {
#     "omicsData": OmicsDataDirDict,
#     "heldOutData": heldOutDataDict,
#     "diseaseData": diseaseDict,
#     "phenosData": phenosDataDict,
#     "modelConfig": modelconfig,
# }

In [None]:
# update output folder to diseasse name
# outputFolder = Path(outputFolder) / diseaseDict.name

In [None]:
# load_data

for k in Config.keys():
    if isinstance(Config[k], DataConfig):
        Config[k].__load_data__()
    elif k == "omicsData":
        for omics in Config["omicsData"]:
            Config["omicsData"][omics].__load_data__()
    else:
        print(f"Skipping {k}")

In [None]:
Config["diseaseData"].data

In [None]:
# def check_disease_dist(Config):
#     # to_check_list = Config
#     diseaseData = Config["diseaseData"].data
#     disease = Config["diseaseData"].name
#     disease_event_col = Config["diseaseData"].label

#     dist = {}
#     for dataconfig in Config["omicsData"].values():
#         data_eids = dataconfig.data[["eid"]].copy()
#         dataname = dataconfig.name

#         if data_eids.eid.dtype != diseaseData.eid.dtype:
#             data_eids.eid = data_eids.eid.astype(diseaseData.eid.dtype)

#         inner_data = diseaseData.merge(data_eids, on="eid", how="inner")[
#             [disease_event_col]
#         ]

#         inner_data = inner_data.value_counts().to_dict()

#         dist[dataname] = inner_data

#     dataconfig = Config["heldOutData"]
#     data_eids = dataconfig.data[["eid"]].copy()
#     dataname = dataconfig.name

#     if data_eids.eid.dtype != diseaseData.eid.dtype:
#         data_eids.eid = data_eids.eid.astype(diseaseData.eid.dtype)

#     inner_data = diseaseData.merge(data_eids, on="eid", how="inner")[
#         [disease_event_col]
#     ]

#     inner_data = inner_data.value_counts().to_dict()

#     dist[dataname] = inner_data

#     dist_df = pd.DataFrame(dist)

#     return dist_df

In [None]:
dist_df = check_disease_dist(Config)
dist_df

In [None]:
# check features
for mconfig in Config["modelConfig"].values():
    if mconfig["feature"] is None:
        # if mconfig["name"] in ["Prot", "Meta"]:
        mconfig["feature"] = (
            Config["omicsData"][mconfig["name"]].data.columns[1:].tolist()
        )
        print(f"Set feature for {mconfig['name']}")
        # else:
        #     raise ValueError(f"Feature for {mconfig['name']} is not set")

In [None]:
# check cov
for mconfig in Config["modelConfig"].values():
    cov = mconfig["cov"]
    if cov is not None:
        if Config["phenosData"] is None:
            raise ValueError(
                f"PhenosData is not set, while covariates are set for {mconfig['name']}"
            )
        else:
            for c in cov:
                if c not in Config["phenosData"].data.columns:
                    raise ValueError(
                        f"cov of {mconfig['name']}, {c} not in phenosData columns"
                    )
        for c in cov:
            if c not in Config["heldOutData"].data.columns:
                raise ValueError(
                    f"cov of {mconfig['name']}, {c} not in heldOutData columns"
                )

## 模型训练

In [50]:
# class LassoConfig(object):
#     def __init__(
#         self,
#         feature,
#         label,
#         cov,
#         name="lasso",
#         family="binomial",
#         lambda_=None,
#         type_measure="auc",
#         cv=10,
#         **kwargs,
#     ):
#         assert isinstance(label, str), "label should be a string"
#         if cov is not None:
#             if isinstance(cov, str):
#                 cov = [cov]
#             elif isinstance(cov, list):
#                 if len(cov) == 0:
#                     cov = None

#         assert isinstance(feature, str) or isinstance(
#             feature, list
#         ), "feature should be a string or a list"
#         self.config = {
#             name: {
#                 "feature": feature if isinstance(feature, list) else [feature],
#                 "label": label,
#                 "time": None,
#                 "cov": cov,
#                 "family": family,
#                 "lambda": lambda_,
#                 "type_measure": type_measure,
#                 "cv": cv,
#             }
#         }
#         if kwargs:
#             # export them and warning not used
#             print(f"Warning: {kwargs} not used")

#     def to_json(self):
#         return self.config

In [48]:
# def plot_coef_scatter(data, coef, feature, k=6, ax=None, cmap="nejm"):
#     data = (
#         data[[coef, feature]].copy().rename(columns={coef: "coef", feature: "feature"})
#     )
#     plt_data = (
#         data.query("coef !=0")
#         .sort_values("coef", ascending=False)
#         .reset_index(drop=True)
#         .reset_index(drop=False, names=["idx"])
#     )

#     from adjustText import adjust_text

#     if ax is None:
#         fig, ax = plt.subplots(figsize=(12, 4))

#     hue = "coef"
#     y = "coef"
#     name = "feature"
#     plt_data = plt_data.query("feature != 'sex' & feature != 'age' ")

#     # 计算每个点的颜色
#     colors = plt_data[hue]
#     min_value = max(abs(colors.min()), abs(colors.max()))
#     norm = plt.Normalize(-min_value, min_value)
#     sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
#     sm.set_array([])

#     sc = ax.scatter(
#         plt_data.index,
#         plt_data[y],
#         c=colors,
#         cmap=cmap,
#         s=30,
#         # edgecolor="k",
#         zorder=3,
#     )
#     # cb = plt.colorbar(sm, ax=ax)

#     # 设置标题和轴标签
#     ax.set_title(
#         f"Mean Coefficient of {name} bootstrap model", fontsize=16, fontweight="bold"
#     )
#     ax.set_xlabel(
#         "",
#     )
#     ax.set_ylabel("Mean Coefficient", fontsize=14)
#     ax.set_yticks([-min_value / 2, 0, min_value / 2])
#     ax.spines["top"].set_visible(False)
#     ax.spines["right"].set_visible(False)
#     # 增加网格线
#     ax.grid(True, which="both", linestyle="--", linewidth=0.5, alpha=0.5)

#     texts = [
#         ax.text(
#             idx,
#             row[y],
#             f"{row[name]}",
#             ha="center",
#             va="bottom",
#             fontsize=8,
#         )
#         for idx, row in plt_data.head(k).iterrows()
#     ] + [
#         ax.text(
#             idx,
#             row[y],
#             f"{row[name]}",
#             ha="center",
#             va="top",
#             fontsize=8,
#         )
#         for idx, row in plt_data.tail(k).iterrows()
#     ]
#     adjust_text(texts, ax=ax, arrowprops=dict(arrowstyle="->", lw=0.5))


# class GLMNETBootsrapResult(object):
#     def __init__(self, bootstrap_coef_df):
#         self.coef = bootstrap_coef_df
#         self.features = self.coef.index.tolist()
#         self._init_weights_dist()

#     def _init_weights_dist(self):

#         res = self.coef

#         percent_of_nonZero_coefficients = (
#             (res != 0).sum(axis=1) * 100 / len(res.columns)
#         )
#         mean_coefficients = res.mean(axis=1)
#         weights_dist_df = pd.DataFrame(
#             [percent_of_nonZero_coefficients, mean_coefficients],
#             index=["percent_of_nonZero_coefficients", "mean_coefficients"],
#         ).T
#         weights_dist_df["abs_mean_coefficients"] = weights_dist_df[
#             "mean_coefficients"
#         ].abs()
#         self.weights_dist_df = weights_dist_df

#     def _plot_top_k_features(self, k=10, pallete="viridis", ax=None, exclude=None):
#         """
#         plot top k features
#         """

#         if ax is None:
#             fig, ax = plt.subplots(1, 1, figsize=(5, k))

#         if isinstance(exclude, str):
#             exclude = [exclude]

#         if exclude is not None:
#             plt_data = self.coef.loc[self.coef.index.difference(exclude), :]

#             top_k_features = self.weights_dist_df
#             top_k_features = top_k_features.loc[
#                 top_k_features.index.difference(exclude), :
#             ].sort_values(
#                 by=["mean_coefficients"],
#                 ascending=False,
#             )
#         else:
#             plt_data = self.coef
#             top_k_features = self.weights_dist_df.sort_values(
#                 by=["mean_coefficients"],
#                 ascending=False,
#             )

#         plt_data = plt_data.loc[
#             [*top_k_features.index[:k], *top_k_features.index[-k:]], :
#         ]
#         idx_name = plt_data.index.name
#         plt_data = plt_data.reset_index(drop=False).melt(id_vars=idx_name)

#         sns.boxplot(
#             data=plt_data,
#             y=idx_name,
#             x="value",
#             showfliers=False,
#             ax=ax,
#             palette=pallete,
#         )
#         ax.set_xticks([0.0])
#         ax.set_xticklabels(ax.get_xticklabels(), rotation=90, fontsize=10)

#         ax.grid(axis="x", linestyle="--", alpha=1, linewidth=2, color="red")
#         ax.set_xlabel("Mean of Coefficients")
#         ax.set_ylabel("Features")
#         ax.set_title(f"Top {k} Features")
#         return ax

#     def _show_models_coeffients(self, axes=None, color="#d67b7f", top=5, exclude=None):
#         """
#         res:
#             model1 model2
#         SOST xx yy
#         BGN xx yy


#         """
#         if self.coef is None:
#             self.coef = self._init_coeffeients_df()
#         res = self.coef

#         # exclude = self.cov if exclude is None else exclude + self.cov
#         if exclude:
#             if isinstance(exclude, str):
#                 exclude = [exclude]
#             res = res.loc[res.index.difference(exclude), :]

#         if axes is None:
#             fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
#         else:
#             assert len(axes) == 2, "axes should be a list of length 2"
#             ax1, ax2 = axes

#         percent_of_nonZero_coefficients = (
#             (res != 0).sum(axis=1) * 100 / len(res.columns)
#         )
#         mean_coefficients = res.mean(axis=1)
#         plt_data = pd.DataFrame(
#             [percent_of_nonZero_coefficients, mean_coefficients],
#             index=["percent_of_nonZero_coefficients", "mean_coefficients"],
#         ).T
#         plt_data["abs_mean_coefficients"] = plt_data["mean_coefficients"].abs()

#         # ax1
#         sns.scatterplot(
#             x=percent_of_nonZero_coefficients,
#             y=mean_coefficients,
#             size=mean_coefficients,
#             sizes=(20, 400),
#             legend=False,
#             edgecolor="black",
#             ax=ax1,
#             color=color,
#         )
#         ax1.plot([0, 100], [0, 0], "k--", lw=3, color="grey")
#         ax1.set_xlim(-1, 105)
#         ax1.set_xlabel("percent of non-zero coefficients")
#         ax1.set_ylabel("mean nonzero coefficients")
#         sorted_plt_data = (
#             plt_data.sort_values(
#                 by=["percent_of_nonZero_coefficients", "abs_mean_coefficients"],
#                 ascending=False,
#             )
#             .iloc[:top, :]
#             .index
#         )
#         for i, txt in enumerate(sorted_plt_data):
#             # ax1.annotate(txt, (sorted_plt_data.iloc[i, 0], sorted_plt_data.iloc[i, 1]))
#             ax1.text(
#                 plt_data.loc[txt, "percent_of_nonZero_coefficients"],
#                 plt_data.loc[txt, "mean_coefficients"],
#                 txt,
#                 ha="right",
#                 fontsize=8,
#                 color="black",
#             )

#         # ax2
#         absolute_mean_coefficients = mean_coefficients.abs().sort_values(ascending=True)
#         sns.barplot(
#             y=absolute_mean_coefficients,
#             x=absolute_mean_coefficients.index,
#             ax=ax2,
#             color=color,
#         )
#         ax2.set_ylabel("absolute mean coefficients")
#         ax2.set_xlabel("")
#         xticks = ax2.get_xticklabels()
#         if len(xticks) > 100:
#             ax2.set_xticks([""] * len(xticks))
#         else:
#             ax2.set_xticklabels(ax2.get_xticklabels(), rotation=90)
#         if axes is None:
#             # fig.tight_layout()
#             return ax1, ax2

#     def coef_barplot(
#         self,
#         cmap="RdBu_r",
#         k=10,
#         ax=None,
#         errorbar_kwargs=None,
#         scatter_kwargs=None,
#         exclude=["age", "sex"],
#     ):
#         from adjustText import adjust_text
#         from scipy import stats

#         if errorbar_kwargs is None:
#             errorbar_kwargs = {}
#         if scatter_kwargs is None:
#             scatter_kwargs = {}

#         plt_data = self.coef.copy()

#         def cal_ci(x):
#             mean_x = x.mean()
#             scale = stats.sem(x)
#             ci_low, ci_high = stats.t.interval(
#                 0.95, len(x) - 1, loc=mean_x, scale=scale
#             )
#             return {"mean": mean_x, "ci_low": ci_low, "ci_high": ci_high}

#         # drop age sex

#         # exclude = self.cov if exclude is None else exclude + self.cov
#         if isinstance(exclude, str):
#             exclude = [exclude]

#         plt_data = plt_data.loc[[i not in exclude for i in plt_data.index.tolist()]]

#         plt_data = plt_data.apply(
#             lambda x: pd.Series(cal_ci(x)),
#             axis=1,
#         )
#         plt_data = plt_data.sort_values("mean", ascending=False).reset_index(
#             drop=False, names="feature"
#         )
#         if ax is None:
#             fig, ax = plt.subplots(figsize=(12, 4))
#         plt_data["error_low"] = plt_data["mean"] - plt_data["ci_low"]
#         plt_data["error_high"] = plt_data["ci_high"] - plt_data["mean"]

#         # 绘制误差线

#         ax.errorbar(
#             x=plt_data.index,
#             y=plt_data["mean"],
#             yerr=[
#                 plt_data["mean"] - plt_data.ci_low,
#                 plt_data.ci_high - plt_data["mean"],
#             ],
#             fmt=errorbar_kwargs.pop("fmt", "none"),  # 不使用标记
#             lw=errorbar_kwargs.pop("lw", 1),
#             capsize=errorbar_kwargs.pop("capsize", 2),
#             ecolor=errorbar_kwargs.pop("ecolor", "lightgrey"),  # 将误差线设置为浅灰色
#             **errorbar_kwargs,
#         )

#         # 使用scatter添加颜色渐变的散点
#         # 计算每个点的颜色
#         colors = plt_data["mean"]
#         min_value = max(abs(colors.min()), abs(colors.max()))
#         norm = plt.Normalize(-min_value, min_value)
#         sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
#         sm.set_array([])

#         sc = ax.scatter(
#             plt_data.index,
#             plt_data["mean"],
#             c=colors,
#             cmap=cmap,
#             s=scatter_kwargs.pop("s", 5),
#             # edgecolor="k",
#             zorder=scatter_kwargs.pop("zorder", 3),
#             **scatter_kwargs,
#         )
#         cb = plt.colorbar(sm, ax=ax)

#         # 设置标题和轴标签
#         ax.set_title(
#             "Mean Coefficient of bootstrap model", fontsize=16, fontweight="bold"
#         )
#         ax.set_xlabel(
#             "",
#         )
#         ax.set_ylabel("Mean Coefficient", fontsize=14)
#         ax.set_yticks([-min_value / 2, 0, min_value / 2])
#         ax.set_xticks([])
#         ax.spines["top"].set_visible(False)
#         ax.spines["right"].set_visible(False)
#         # 增加网格线
#         ax.grid(True, which="both", linestyle="--", linewidth=0.5, alpha=0.5)

#         texts = [
#             ax.text(
#                 idx,
#                 row["mean"] + row["error_high"],
#                 f"{row['feature']}",
#                 ha="center",
#                 va="bottom",
#                 fontsize=8,
#             )
#             for idx, row in plt_data.head(k).iterrows()
#         ] + [
#             ax.text(
#                 idx,
#                 row["mean"] - row["error_low"],
#                 f"{row['feature']}",
#                 ha="center",
#                 va="top",
#                 fontsize=8,
#             )
#             for idx, row in plt_data.tail(k).iterrows()
#         ]
#         adjust_text(texts, ax=ax, arrowprops=dict(arrowstyle="->", lw=0.5))

#         return ax


# # from tqdm.notebook import tqdm


# def load_glmnet_bootstrap(model_dir):
#     """
#     dir/
#         1/Meta/
#             coef_df.csv
#             train_score.csv
#             test_score.csv
#         1/Prot/
#             coef_df.csv
#             train_score.csv
#             test_score.csv
#         2/Meta/
#         ...
#     """

#     model_dir = Path(model_dir)

#     coef_df_name = "coef_df.csv"
#     train_score = "train_score.csv"
#     test_score = "test_score.csv"

#     from collections import defaultdict

#     res = defaultdict(lambda: defaultdict(list))

#     found_csvs = list(model_dir.rglob("*.csv"))
#     for file_dir in found_csvs:
#         if file_dir.parent == model_dir:
#             continue
#         filename = file_dir.stem
#         submodelname = file_dir.parent.name
#         seed = file_dir.parent.parent.name
#         file = pd.read_csv(file_dir)
#         if filename == "coef_df":
#             file.columns = ["feature", f"coef_{seed}"]
#             file.set_index("feature", inplace=True)
#         elif filename == "train_score":
#             continue
#         else:

#             file.rename(columns={"pred": f"pred_{seed}"}, inplace=True)
#             file.set_index("eid", inplace=True)

#         res[submodelname][filename].append(file)

#     for submodelname in res.keys():
#         for subcsv in res[submodelname].keys():
#             # first = res[submodelname][subcsv][0].iloc[:, [0]]
#             merged = pd.concat(res[submodelname][subcsv], axis=1)
#             res[submodelname][subcsv] = merged

#     return res

In [49]:
# import json
# import subprocess
# import shutil
# import matplotlib.gridspec as gridspec


# def run_glmnet(
#     json_dir,
#     train_dir,
#     out_dir,
#     test_dir=None,
#     seed=None,
# ):

#     if shutil.which("run_glmnet.R") is None:
#         raise ValueError("run_glmnet.R is not in the PATH")

#     cmd = f"run_glmnet.R --json {json_dir} --train {train_dir} --out {out_dir}"
#     if test_dir is not None:
#         cmd += f" --test {test_dir}"
#     if seed is not None:
#         cmd += f" --seed {seed}"
#     print(cmd)
#     subprocess.run(cmd, shell=True)
#     return subprocess


# class LassoTrainTFPipline(object):
#     def __init__(
#         self, mmconfig, dataconfig, tgtconfig, phenoconfig, testdataconfig=None
#     ):
#         """ """
#         self.mmconfig = mmconfig
#         self.dataconfig = dataconfig
#         self.tgtconfig = tgtconfig
#         self.phenoconfig = phenoconfig
#         self.testdataconfig = testdataconfig

#     def run(self, n_bootstrap=200, n_jobs=4, outputFolder="./out"):
#         # simple lasso
#         mmconfig = self.mmconfig
#         dataconfig = self.dataconfig
#         tgtconfig = self.tgtconfig
#         phenoconfig = self.phenoconfig
#         outputFolder = Path(outputFolder)

#         label = tgtconfig.label
#         diseaseData = tgtconfig.data

#         phenosData = phenoconfig.data

#         model_list = mmconfig["model"]
#         modelname = mmconfig["name"]
#         feature = mmconfig["feature"]
#         cov = mmconfig["cov"] if mmconfig["cov"] is not None else []

#         # copy data
#         used_pheno_data = phenosData[["eid"] + cov].copy()
#         used_dis_data = diseaseData[["eid", label]].copy()

#         # check eid dtype

#         if used_pheno_data.eid.dtype != dataconfig.data.eid.dtype:
#             used_pheno_data.eid = used_pheno_data.eid.astype(dataconfig.data.eid.dtype)
#         if used_dis_data.eid.dtype != dataconfig.data.eid.dtype:
#             used_dis_data.eid = used_dis_data.eid.astype(dataconfig.data.eid.dtype)

#         # check output
#         model_output_folder = outputFolder / modelname
#         model_output_folder.mkdir(parents=True, exist_ok=True)
#         # lasso
#         lasso_config = LassoConfig(
#             feature=feature,
#             label=label,
#             cov=cov,
#             name=modelname,
#             type_measure=mmconfig.get("type_measure", "auc"),
#             cv=mmconfig.get("cv", 10),
#         ).to_json()
#         json_dir = model_output_folder / "train_config.json"
#         json.dump(lasso_config, open(json_dir, "w"))

#         model_save_dir = model_output_folder / "model"

#         # data save to
#         train_feather = (
#             dataconfig.data.merge(diseaseData[["eid", label]], on="eid", how="inner")
#             .merge(phenosData[["eid"] + cov], on="eid", how="inner")
#             .dropna(subset=[label])
#         ).reset_index(drop=True)

#         tmp_train_feather_dir = model_output_folder / "train.feather"
#         ##################### rm ##################
#         train_feather = train_feather.head(10000)
#         ##################### rm ##################
#         print(f"Train data shape: {train_feather.shape}")

#         train_feather.to_feather(tmp_train_feather_dir)
#         ##################### rm ##################
#         if self.testdataconfig is not None:
#             if self.testdataconfig.data.eid.dtype != dataconfig.data.eid.dtype:
#                 self.testdataconfig.data.eid = self.testdataconfig.data.eid.astype(
#                     dataconfig.data.eid.dtype
#                 )

#             # merge disease data
#             test_feather = (
#                 self.testdataconfig.data.merge(
#                     diseaseData[["eid", label]], on="eid", how="inner"
#                 ).dropna(subset=[label])
#             ).reset_index(drop=True)

#             # check cov in test data
#             # if not in, merge from phenos
#             to_merge_cols = []
#             for c in cov:
#                 if c not in self.testdataconfig.data.columns:
#                     to_merge_cols.append(c)
#                     print(f"Missing cov in test data: {c}")

#             if len(to_merge_cols) > 0:
#                 test_feather = test_feather.merge(
#                     phenosData[["eid"] + to_merge_cols], on="eid", how="inner"
#                 ).reset_index(drop=True)

#             tmp_test_feather_dir = model_output_folder / "test.feather"
#             test_feather = test_feather[train_feather.columns.tolist()]
#             print(f"Test data shape: {test_feather.shape}")
#             test_feather.to_feather(tmp_test_feather_dir)
#         else:
#             raise ValueError("Test data is not provided")

#         # run single without random seed
#         single_lasso_output_folder = model_output_folder / "single"
#         run_glmnet(
#             json_dir=json_dir,
#             train_dir=tmp_train_feather_dir,
#             out_dir=single_lasso_output_folder,
#             test_dir=tmp_test_feather_dir if self.testdataconfig is not None else None,
#         )
#         if isinstance(n_bootstrap, int) and n_bootstrap > 1:
#             if self.testdataconfig is None:
#                 raise ValueError(
#                     "Test data is not provided, cannot run bootstrap to select best"
#                 )
#             # run bootstrap
#             bootstrap_output_folder = model_output_folder / "bootstrap"
#             from joblib import Parallel, delayed

#             res = Parallel(n_jobs=n_jobs)(
#                 delayed(run_glmnet)(
#                     json_dir=json_dir,
#                     train_dir=tmp_train_feather_dir,
#                     out_dir=bootstrap_output_folder / f"{i}",
#                     test_dir=(
#                         tmp_test_feather_dir
#                         if self.testdataconfig is not None
#                         else None
#                     ),
#                     seed=i,
#                 )
#                 for i in range(1, n_bootstrap + 1)
#             )

#             # plot bootstrap
#             res = load_glmnet_bootstrap(bootstrap_output_folder)
#             coef = res[modelname]["coef_df"]
#             test_score = res[modelname]["test_score"]
#             ## save
#             coef.to_csv(bootstrap_output_folder / "bootstrap_coef_df.csv", index=True)
#             test_score.reset_index(drop=False).to_feather(
#                 bootstrap_output_folder / "test_score.feather"
#             )

#             ## plot
#             fig = plt.figure(figsize=(15, 10))
#             gs = gridspec.GridSpec(2, 5, hspace=0.5, wspace=0.5, figure=fig)

#             ax1 = fig.add_subplot(gs[0, 0:2])
#             ax2 = fig.add_subplot(gs[0, 2:4])
#             ax3 = fig.add_subplot(gs[:, 4:])
#             ax4 = fig.add_subplot(gs[1, :4])

#             glmnet_bootsrap_result = GLMNETBootsrapResult(coef)
#             glmnet_bootsrap_result._show_models_coeffients(axes=[ax1, ax2])
#             glmnet_bootsrap_result._plot_top_k_features(ax=ax3)
#             ax3.yaxis.set_label_position("right")
#             ax3.yaxis.tick_right()
#             glmnet_bootsrap_result.coef_barplot(ax=ax4)
#             fig.savefig(model_output_folder / "bootstrap_coef_plot.png")

#             # fit the passed
#             coef_mean = coef.mean(axis=1)
#             non_zero_features = coef_mean[coef_mean != 0].index.tolist()

#             # this time no need for cov
#             non_zero_features_lasso_config = LassoConfig(
#                 feature=non_zero_features, label=label, cov=None, name=modelname
#             ).to_json()
#             non_zero_features_json_dir = (
#                 model_output_folder / "non_zero_features_train_config.json"
#             )
#             json.dump(
#                 non_zero_features_lasso_config, open(non_zero_features_json_dir, "w")
#             )

#             # run glmnet
#             non_zero_features_output_folder = model_output_folder / "non_zero_features"
#             run_glmnet(
#                 json_dir=non_zero_features_json_dir,
#                 train_dir=tmp_train_feather_dir,
#                 out_dir=non_zero_features_output_folder,
#                 test_dir=(
#                     tmp_test_feather_dir if self.testdataconfig is not None else None
#                 ),
#             )

#             # compare them
#             score_dict = {}
#             single_test_score = load_data(
#                 single_lasso_output_folder / modelname / "test_score.csv"
#             )
#             single_test_score.columns = ["eid", "single"]
#             score_dict["single"] = single_test_score

#             bootstrap_test_score = load_data(
#                 bootstrap_output_folder / "test_score.feather"
#             )
#             bootstrap_test_score["mean"] = bootstrap_test_score.mean(axis=1)
#             bootstrap_test_score = bootstrap_test_score[["eid", "mean"]]
#             score_dict["mean"] = bootstrap_test_score

#             non_zero_features_test_score = load_data(
#                 non_zero_features_output_folder / modelname / "test_score.csv"
#             )
#             non_zero_features_test_score.columns = ["eid", "non_zero_features"]
#             score_dict["non_zero_features"] = non_zero_features_test_score

#             to_compare_df = (
#                 test_feather[["eid", label]]
#                 .merge(single_test_score, on="eid", how="inner")
#                 .merge(bootstrap_test_score, on="eid", how="inner")
#                 .merge(non_zero_features_test_score, on="eid", how="inner")
#             )

#             to_compare_metrics = {}
#             from ppp_prediction.corr import cal_binary_metrics_bootstrap

#             for col in ["single", "mean", "non_zero_features"]:
#                 to_cal = to_compare_df[[label, col]].dropna()
#                 to_compare_metrics[col] = cal_binary_metrics_bootstrap(
#                     to_cal[label], to_cal[col], ci_kwargs={"n_resamples": 100}
#                 )
#             to_compare_metrics = pd.DataFrame(to_compare_metrics).T.sort_values(
#                 "AUC", ascending=False
#             )

#             to_compare_metrics.to_csv(
#                 model_output_folder / "compare_metrics.csv", index=True
#             )

#             # extract best

#             best_model = to_compare_metrics.index[0]
#             best_model_score = score_dict[best_model]
#             best_model_score.to_csv(
#                 model_output_folder / "best_model_score.csv", index=False
#             )

#             print(f"Finished!")
#         else:
#             shutil.copy(
#                 single_lasso_output_folder / modelname / "test_score.csv",
#                 model_output_folder / "best_model_score.csv",
#             )
#             return

In [45]:
mmconfig = Config["modelConfig"]["Meta"]
dataconfig = Config["omicsData"]["Meta"]
tgtconfig = Config["diseaseData"]
phenoconfig = Config["phenosData"]
testconfig = Config["heldOutData"]

In [46]:
tgtconfig = Config["diseaseData"]
phenoconfig = Config["phenosData"]
testconfig = Config["heldOutData"]
for omics in Config["omicsData"].keys():
    if omics != "RF":
        continue
    assert omics in Config["modelConfig"].keys(), f"{omics} not in model config"
    mmconfig = Config["modelConfig"][omics]
    dataconfig = Config["omicsData"][omics]
    print(f"Running {omics}")
    LassoTrainTFPipline(
        mmconfig=mmconfig,
        dataconfig=dataconfig,
        tgtconfig=tgtconfig,
        phenoconfig=phenoconfig,
        testdataconfig=testconfig,
    ).run(
        outputFolder=f"./test/{tgtconfig.name}",
        n_bootstrap=mmconfig.get("n_bootstrap", None),
    )

Running RF
Train data shape: (10000, 12)
Test data shape: (28000, 12)
run_glmnet.R --json test/T2D_Coding_Amit_NG2018/RF/train_config.json --train test/T2D_Coding_Amit_NG2018/RF/train.feather --out test/T2D_Coding_Amit_NG2018/RF/single --test test/T2D_Coding_Amit_NG2018/RF/test.feather


Loaded glmnet 4.1-8

Attaching package: ‘arrow’

The following object is masked from ‘package:utils’:

    timestamp


Attaching package: ‘dplyr’

The following objects are masked from ‘package:stats’:

    filter, lag

The following objects are masked from ‘package:base’:

    intersect, setdiff, setequal, union



$train
[1] "test/T2D_Coding_Amit_NG2018/RF/train.feather"

$test
[1] "test/T2D_Coding_Amit_NG2018/RF/test.feather"

$json
[1] "test/T2D_Coding_Amit_NG2018/RF/train_config.json"

$output
[1] "test/T2D_Coding_Amit_NG2018/RF/single"

$help
[1] FALSE

[1] "train data size: 10000"
[1] "json_file have keys: 1"
[1] "Processing RF"
[1] TRUE
[1] "train data size: 9998 with featuers 10"
Training


executing %dopar% sequentially: no parallel backend registered 


In [None]:
pd.read_feather(
    "/home/xutingfeng/ukb/project/ppp_prediction/test/T2D_Coding_Amit_NG2018/Meta/test.feather"
)

In [None]:
diseaseDict.label

In [None]:
import json


mmconfig = Config["modelConfig"]["Meta"]
dataconfig = Config["omicsData"]["Meta"]
label = Config["diseaseData"].label
diseaseData = Config["diseaseData"].data
phenoData = Config["phenosData"].data
n_bootstrap = 100

model_list = mmconfig.model
modelname = mmconfig.name
feature = mmconfig.feature
cov = mmconfig.cov

# copy data
used_pheno_data = phenoData[["eid"] + cov].copy()
used_dis_data = diseaseData[["eid", label]].copy()

# check eid dtype
if used_pheno_data.eid.dtype != dataconfig.data.eid.dtype:
    used_pheno_data.eid = used_pheno_data.eid.astype(dataconfig.data.eid.dtype)
if used_dis_data.eid.dtype != dataconfig.data.eid.dtype:
    used_dis_data.eid = used_dis_data.eid.astype(dataconfig.data.eid.dtype)

# check output
model_output_folder = outputFolder / modelname
model_output_folder.mkdir(parents=True, exist_ok=True)
# lasso
lasso_config = LassoConfig(
    feature=feature, label=Config["diseaseData"].label, cov=cov, name=modelname
).to_json()
json_dir = model_output_folder / "train_config.json"
json.dump(lasso_config, open(json_dir, "w"))

model_save_dir = model_output_folder / "model"

# data save to
train_feather = (
    dataconfig.data.merge(diseaseData[["eid", label]], on="eid", how="inner")
    .merge(phenoData[["eid"] + cov], on="eid", how="inner")
    .dropna(subset=[label])
).reset_index(drop=True)
tmp_train_feather_dir = model_output_folder / "train.feather"
train_feather.to_feather(tmp_train_feather_dir)


# run single without random seed
print(
    f"run_glmnet.R --json {json_dir} --train {tmp_train_feather_dir} --test {heldOutDataDict.path} --out {model_save_dir}"
)
# run bootstrap with random seed to get the confidence interval

for i in range(1, n_bootstrap+1):
    print(
        f"run_glmnet.R --json {json_dir} --train {tmp_train_feather_dir} --test {heldOutDataDict.path} --out {model_save_dir}/{i} --seed {i}"
    )

# TODO: run bootstrap and optim by python to control

In [None]:
!run_glmnet.R --json CAD_xhv2_exclude_more/Meta/train_config.json --train CAD_xhv2_exclude_more/Meta/train.feather --test /home/xutingfeng/ukb/ukbData/MultiOmicsDiseasePrediction/data/held_out_df.feather --out CAD_xhv2_exclude_more/Meta/model
