In [1]:
%load_ext autoreload
%autoreload 2
%load_ext lab_black

import numpy as np
import pandas as pd
from glob import glob
import os
import statsmodels.api as sm
from typing import List
from sklearn.model_selection import train_test_split
import itertools
from tqdm import tqdm
from admix.data import quantile_normalize
import matplotlib.pyplot as plt
import calpgs
import subprocess

# Point R_HOME at the R installation that lives inside a miniconda env.
# The value is a machine-specific absolute path and must be edited when the
# notebook is run anywhere else.
# NOTE(review): presumably consumed by an R bridge used further down (e.g. by
# calpgs); it is set *after* `import calpgs`, so confirm the bridge reads
# R_HOME lazily rather than at import time.
os.environ["R_HOME"] = "/u/project/pasaniuc/kangchen/software/miniconda3/envs/r/lib/R"

In [2]:
# Directory holding one `<trait>.tsv.gz` table per trait (read by joint_fit).
DATA_DIR = "../compile-data/out/per-trait-info/"
# Excel workbook with the metadata sheets read below.
DATA_URL = "../r2-diff/data-table.xlsx"

# Sheet 0: trait metadata. Map each trait id to its short label, falling back
# to the long description when the short label is missing.
df_trait_info = pd.read_excel(DATA_URL, sheet_name=0)
trait_map = {
    # pd.isna (rather than `is np.nan`) so that every flavour of missing cell
    # (float NaN objects, None, pd.NA) triggers the fallback, not only the
    # np.nan singleton.
    row.id: row.description if pd.isna(row.short) else row.short
    for _, row in df_trait_info.iterrows()
}

# Sheet 1: covariate metadata. Map each covariate id to its short label.
df_covar_info = pd.read_excel(DATA_URL, sheet_name=1)
covar_map = {row.id: row.short for _, row in df_covar_info.iterrows()}

# Sheet 2: its `id` column lists the traits to process.
df_display = pd.read_excel(DATA_URL, sheet_name=2)

trait_list = df_display.id.values

# Covariate columns used alongside the PGS: nine named covariates followed by
# the principal components PC1..PC10.
COVAR_COLS = [
    "AGE",
    "SEX",
    "DEPRIVATION_INDEX",
    "log_BMI",
    "income",
    "ever_smoked",
    "drink_alcohol",
    "glasses",
    "years_of_edu",
    *(f"PC{pc}" for pc in range(1, 11)),
]
# Full predictor set: the polygenic score first, then every covariate.
FIT_COLS = ["PGS", *COVAR_COLS]

In [3]:
# One job per (trait, group) pair, with the output prefix derived from both.
# The whole table is built in a single pass, which avoids a row-wise
# DataFrame.apply and also yields a well-formed (empty) table when
# `trait_list` is empty.
df_params = pd.DataFrame(
    [
        {
            "trait": trait,
            "group": group,
            "out_prefix": f"out/joint-fit/{trait}-{group}",
        }
        for trait, group in itertools.product(trait_list, ["white", "other"])
    ],
    columns=["trait", "group", "out_prefix"],
)

In [4]:
def estimate(df, mean_cols, var_cols):
    """Fit a heteroskedastic linear model and attach in-sample predictions.

    ``pheno`` is fitted with ``calpgs.fit_het_linear`` using ``mean_cols`` as
    mean covariates and ``var_cols`` as variance covariates; both design
    matrices are passed through ``sm.add_constant`` first. Predictions are
    made on the same rows that were used for fitting: the predicted mean is
    ``X @ beta`` and the predicted standard deviation is
    ``sqrt(exp(Z @ gamma))``.

    Parameters
    ----------
    df : pd.DataFrame
        Must contain a ``pheno`` column and every column named in
        ``mean_cols`` and ``var_cols``. The caller's frame is not modified.
    mean_cols : list of str
        Columns entering the mean model.
    var_cols : list of str
        Columns entering the variance model.

    Returns
    -------
    df_params : pd.DataFrame
        One row per design-matrix column with ``beta``, ``beta_se``,
        ``beta_z``, ``gamma``, ``gamma_se`` and ``gamma_z``. Rows follow the
        mean-model columns, then any variance-only columns; a term missing
        from one of the two models is NaN in that model's columns.
    df : pd.DataFrame
        Copy of the input with ``cal_pred`` and ``cal_predstd`` added.
    """
    df = df.copy()

    # Fitting and prediction use the same rows, so each design matrix is
    # built once and reused.
    mean_design = sm.add_constant(df[mean_cols])
    var_design = sm.add_constant(df[var_cols])

    beta, gamma, beta_cov, gamma_cov = calpgs.fit_het_linear(
        y=df["pheno"].values,
        mean_covar=mean_design,
        var_covar=var_design,
        return_est_covar=True,
    )

    # Standard errors are the square roots of the covariance diagonals.
    beta_se = np.sqrt(np.diag(beta_cov))
    gamma_se = np.sqrt(np.diag(gamma_cov))

    # Label each coefficient set with the columns of its *own* design matrix;
    # labelling gamma with the mean-model columns breaks (or silently
    # mislabels) as soon as mean_cols and var_cols differ.
    mean_table = pd.DataFrame(
        {"beta": beta, "beta_se": beta_se, "beta_z": beta / beta_se},
        index=mean_design.columns,
    )
    var_table = pd.DataFrame(
        {"gamma": gamma, "gamma_se": gamma_se, "gamma_z": gamma / gamma_se},
        index=var_design.columns,
    )
    # sort=False keeps the mean-model order, so the table is unchanged when
    # both models use the same columns.
    param_index = mean_design.columns.union(var_design.columns, sort=False)
    df_params = pd.concat(
        [mean_table.reindex(param_index), var_table.reindex(param_index)], axis=1
    )

    df["cal_pred"] = mean_design.dot(beta)
    df["cal_predstd"] = np.sqrt(np.exp(var_design.dot(gamma)))
    return df_params, df

In [5]:
def joint_fit(trait, indiv_group, out_prefix):
    """Fit the joint mean/variance model for one trait and one group.

    Reads ``<DATA_DIR>/<trait>.tsv.gz``, keeps the requested individuals,
    standardises the PGS and covariates, calls ``estimate`` with the same
    columns for the mean and the variance model, and writes two files:
    ``<out_prefix>.params.tsv`` and ``<out_prefix>.predint.tsv``.

    Parameters
    ----------
    trait : str
        Trait name; selects the input file and is dropped from the predictor
        list if it is itself one of ``FIT_COLS``.
    indiv_group : {"white", "other"}
        ``"white"`` keeps rows whose ``group`` column equals
        ``"United Kingdom"``; ``"other"`` keeps all remaining rows.
    out_prefix : str
        Path prefix for the two output files. Its parent directory is created
        if it does not exist.

    Raises
    ------
    NotImplementedError
        If ``indiv_group`` is neither ``"white"`` nor ``"other"``.
    """
    df_trait = pd.read_csv(
        os.path.join(DATA_DIR, f"{trait}.tsv.gz"), index_col=0, sep="\t"
    )

    is_uk = df_trait["group"] == "United Kingdom"
    if indiv_group == "white":
        df_trait = df_trait[is_uk]
    elif indiv_group == "other":
        df_trait = df_trait[~is_uk]
    else:
        raise NotImplementedError(
            f"unknown indiv_group {indiv_group!r}; expected 'white' or 'other'"
        )

    # Rows without a phenotype or a PGS cannot be used for fitting.
    df_trait = df_trait.rename(columns={"MEAN": "PGS", "PHENO": "pheno"}).dropna(
        subset=["pheno", "PGS"]
    )

    # impute 0 and standardize covariates
    # NOTE(review): missing values are replaced by a raw 0 *before*
    # standardising, so they do not land on the column mean; confirm this is
    # intended (as opposed to mean imputation).
    for col in ["PGS"] + COVAR_COLS:
        df_trait[col] = df_trait[col].fillna(0)
        df_trait[col] = (df_trait[col] - df_trait[col].mean()) / df_trait[col].std()

    df_trait = df_trait[["pheno", "PGS"] + COVAR_COLS]

    # A trait must not be used as a predictor of itself.
    fit_cols = [col for col in FIT_COLS if col != trait]

    df_params, df_cal = estimate(df_trait, mean_cols=fit_cols, var_cols=fit_cols)

    # to_csv does not create missing directories, so make sure the target
    # directory exists before writing.
    out_dir = os.path.dirname(out_prefix)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    df_params.to_csv(out_prefix + ".params.tsv", sep="\t", float_format="%.6g")
    df_cal.to_csv(out_prefix + ".predint.tsv", sep="\t", float_format="%.6g")

In [6]:
# Run the joint fit for every (trait, group) job, with a progress bar.
for job in tqdm(df_params.itertuples(index=False), total=len(df_params)):
    joint_fit(job.trait, job.group, job.out_prefix)

100%|██████████| 22/22 [02:02<00:00,  5.55s/it]


In [7]:
# trait = "height"
# indiv_group = "other"

In [8]:
# path = "tmp.tsv"
# df_test.to_csv("tmp.tsv", sep="\t")

# cmds = [
#     "calpgs group-r2",
#     f"--df {path}",
#     "--y pheno --pred cal_pred --predstd cal_predstd",
#     "--group PC1,AGE,SEX",
#     f"--out tmp_out",
#     "--n-bootstrap 10 --n-subgroup 5",
# ]
# subprocess.check_call(" ".join(cmds), shell=True)

# df_predint = pd.read_csv(f"tmp_out.predint.tsv", sep="\t")
# df_r2 = pd.read_csv("tmp_out.r2.tsv", sep="\t")