In [1]:
%load_ext autoreload
%autoreload 2
%load_ext lab_black

import numpy as np
import pandas as pd
from glob import glob
import os
import statsmodels.api as sm
from sklearn.model_selection import train_test_split
import itertools
from tqdm import tqdm
import calpgs
import pickle

In [2]:
DATA_DIR = "../compile-data/out/per-trait-info/"

# Covariate column names: AGE, SEX, DEPRIVATION_INDEX and PC1..PC10.
# NOTE(review): not referenced by the cells shown here -- presumably used elsewhere.
COVAR_COLS = ["AGE", "SEX", "DEPRIVATION_INDEX"] + [f"PC{i}" for i in range(1, 11)]

# Per-trait files are named "<trait>.tsv.gz"; the trait name is the file name with
# this suffix stripped.
TRAIT_FILE_SUFFIX = ".tsv.gz"

# Sorted so the trait order does not depend on the filesystem's listing order
# (glob makes no ordering guarantee), keeping re-runs reproducible.
trait_list = sorted(
    os.path.basename(f)[: -len(TRAIT_FILE_SUFFIX)]
    for f in glob(os.path.join(DATA_DIR, "*" + TRAIT_FILE_SUFFIX))
)

# Build calibration model

In [23]:
def _ci_adjust_vars(df: pd.DataFrame, ci_adjust: str):
    """Return the covariates passed to calpgs for the given ``ci_adjust`` mode."""
    if ci_adjust == "none":
        return None
    # "all": every column after the first three data columns (index excluded).
    # NOTE(review): assumes those three are pheno/pred/predstd -- confirm file layout.
    return df.iloc[:, 3:]


def build_model(data_prefix: str, ci_adjust: str, out_prefix: str):
    """Fit a calibration model on a train split and apply it to the test split.

    Reads ``<data_prefix>.train.tsv`` and ``<data_prefix>.test.tsv`` (tab-separated,
    first column used as the index), fits ``calpgs.calibrate_model`` on the train
    rows with ``ci_method="scale"``, applies ``calpgs.calibrate_adjust`` to the test
    rows, and writes:

    * ``<out_prefix>.model`` -- the fitted model, pickled;
    * ``<out_prefix>.test_info.tsv`` -- the test table plus ``cal_pred`` and
      ``cal_predstd`` columns.

    Parameters
    ----------
    data_prefix : str
        Path prefix of the train/test TSV files; both need ``pheno``, ``pred``
        and ``predstd`` columns.
    ci_adjust : {"none", "all"}
        ``"none"``: no covariates are passed to calpgs. ``"all"``: every column
        after the first three data columns is passed as ``ci_adjust_vars``.
    out_prefix : str
        Path prefix for the outputs; its parent directory is created if needed.

    Raises
    ------
    ValueError
        If ``ci_adjust`` is not ``"none"`` or ``"all"``.
    """
    # Explicit check (not ``assert``) so it still runs under ``python -O``;
    # done before any file I/O.
    if ci_adjust not in ("none", "all"):
        raise ValueError(f"ci_adjust must be 'none' or 'all', got {ci_adjust!r}")

    df_train = pd.read_csv(data_prefix + ".train.tsv", sep="\t", index_col=0)
    df_test = pd.read_csv(data_prefix + ".test.tsv", sep="\t", index_col=0)

    # train model
    model = calpgs.calibrate_model(
        y=df_train["pheno"].values,
        pred=df_train["pred"].values,
        predstd=df_train["predstd"].values,
        ci_method="scale",
        ci_adjust_vars=_ci_adjust_vars(df_train, ci_adjust),
    )

    # adjust model -- covariates are selected before cal_* columns are appended,
    # so the new columns never leak into ci_adjust_vars.
    df_test["cal_pred"], df_test["cal_predstd"] = calpgs.calibrate_adjust(
        model=model,
        pred=df_test["pred"].values,
        predstd=df_test["predstd"].values,
        ci_adjust_vars=_ci_adjust_vars(df_test, ci_adjust),
    )

    # dirname is "" for a bare prefix such as "trait-all"; os.makedirs("") raises.
    out_dir = os.path.dirname(out_prefix)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    with open(out_prefix + ".model", "wb") as f:
        pickle.dump(model, f)
    df_test.to_csv(out_prefix + ".test_info.tsv", sep="\t")

In [24]:
data_prefix_list = np.unique([p.split(".")[0] for p in glob("out/data/*/*")])
df_params = pd.DataFrame(
    [params for params in itertools.product(data_prefix_list, ["all", "none"])],
    columns=["data_prefix", "ci_adjust"],
)
df_params["out_prefix"] = (
    df_params.data_prefix.str.replace("/data/", "/model/")
    + "-"
    + df_params["ci_adjust"]
)

In [25]:
# Fit and apply one calibration model per (data prefix, ci_adjust) combination.
for row in tqdm(df_params.itertuples(index=False), total=len(df_params)):
    build_model(
        data_prefix=row.data_prefix,
        ci_adjust=row.ci_adjust,
        out_prefix=row.out_prefix,
    )

100%|██████████| 360/360 [02:58<00:00,  2.02it/s]
