In [1]:
%load_ext lab_black
import pandas as pd
import numpy as np
import os
import statsmodels.api as sm
from admix.data import quantile_normalize
import calpgs
import subprocess
from tqdm import tqdm
import submitit
import itertools
import yaml

# Point R_HOME at the conda R installation for this kernel (presumably needed
# by R-backed code reached through calpgs — confirm). The SGE jobs below set
# the same path in their own `setup` commands.
os.environ.update(
    R_HOME="/u/project/pasaniuc/kangchen/software/miniconda3/envs/r/lib/R"
)

In [2]:
# Traits to analyze, one per line. `ndmin=1` keeps the result 1-D even when the
# file lists a single trait: np.loadtxt otherwise squeezes to a 0-d array,
# which itertools.product in the job-table cell cannot iterate.
trait_list = np.loadtxt("data/traits.txt", dtype=str, ndmin=1)

with open("data/meta.yaml") as f:
    metadata = yaml.safe_load(f)

# COVAR_COLS are added to the mean model next to PGS; VAR_COLS are passed to
# calpgs as var_cols / test_cols (see estimate_quantify).
VAR_COLS, COVAR_COLS = metadata["VAR_COLS"], metadata["COVAR_COLS"]

In [3]:
def estimate_quantify(trait, group, out_prefix):
    """Fit the calpgs model for one (trait, group) pair and summarize its R².

    Steps:
      1. ``calpgs.estimate_coef`` on ``out/format-data/{trait}.{group}.tsv``,
         with ``QPHENO`` as the outcome, ``PGS`` + COVAR_COLS as ``mean_cols``
         and VAR_COLS (minus the trait's own column, if listed) as
         ``var_cols``. It is expected to write ``{out_prefix}.pred.tsv``,
         which the next steps read back.
      2. ``calpgs.quantify_r2`` of ``PGS`` against ``QPHENO_RESID``, written
         under ``{out_prefix}.resid``.
      3. ``calpgs.quantify_r2`` of ``pred_mean`` against ``QPHENO``, written
         under ``{out_prefix}.total``.

    Both R² steps pass the full VAR_COLS as ``test_cols``.

    Args:
        trait: Trait name; selects the input file.
        group: Sample group name; selects the input file.
        out_prefix: Path prefix for every output this function produces.
    """
    # A trait that is itself one of the VAR_COLS (e.g. BMI, edu) is dropped so
    # it is not used in its own variance model.
    var_cols = [col for col in VAR_COLS if col != trait]
    pred_path = f"{out_prefix}.pred.tsv"

    calpgs.estimate_coef(
        df_path=f"out/format-data/{trait}.{group}.tsv",
        y_col="QPHENO",
        mean_cols=["PGS", *COVAR_COLS],
        var_cols=var_cols,
        out_prefix=out_prefix,
    )

    # (outcome column, prediction column, output-prefix suffix)
    r2_specs = (
        ("QPHENO_RESID", "PGS", "resid"),
        ("QPHENO", "pred_mean", "total"),
    )
    for y_col, pred_col, suffix in r2_specs:
        calpgs.quantify_r2(
            df_path=pred_path,
            y_col=y_col,
            pred_col=pred_col,
            test_cols=VAR_COLS,
            out_prefix=f"{out_prefix}.{suffix}",
        )

In [4]:
# SGE executor: each job puts the cluster conda env on PATH, ignores user
# site-packages, and sets the same R_HOME used by this kernel.
executor = submitit.SgeExecutor(folder="./submitit-logs")

executor.update_parameters(
    time_min=40,
    memory_g=12,
    setup=[
        "export PATH=~/project-pasaniuc/software/miniconda3/bin:$PATH",
        "export PYTHONNOUSERSITE=True",
        "export R_HOME=/u/project/pasaniuc/kangchen/software/miniconda3/envs/r/lib/R",
    ],
)

# One job per (trait, group) pair.
GROUPS = ["white", "other", "all"]
df_params = pd.DataFrame(
    list(itertools.product(trait_list, GROUPS)),
    columns=["trait", "group"],
)
# Vectorized string concatenation instead of row-wise apply. apply(axis=1) on
# an empty frame returns a DataFrame, which cannot be assigned to one column.
df_params["out_prefix"] = (
    "out/estimate-quantify/" + df_params["trait"] + "." + df_params["group"]
)
print(f"{len(df_params)} jobs in total")

# A job counts as finished once its last output exists (presumably written by
# the final quantify_r2 call under "<out_prefix>.total"). Only unfinished
# jobs are resubmitted.
is_done = np.array(
    [os.path.exists(f"{prefix}.total.r2diff.tsv") for prefix in df_params["out_prefix"]],
    dtype=bool,
)
df_todo_params = df_params.loc[~is_done]
print(f"{len(df_todo_params)} jobs remain")

216 jobs in total
72 jobs remains


In [5]:
# Submit one array task per unfinished (trait, group). The columns are passed
# map-style: task i receives the i-th trait, group and out_prefix.
jobs = executor.map_array(
    estimate_quantify,
    *(df_todo_params[col] for col in ("trait", "group", "out_prefix")),
)