In [1]:
%load_ext lab_black
%load_ext autoreload
%autoreload 2

import submitit
import admix
import numpy as np
import pandas as pd
import calpgs
import os
import statsmodels.api as sm
import matplotlib.pyplot as plt
from scipy.stats import pearsonr
import glob
from tqdm import tqdm
import itertools
from typing import List
import subprocess
from admix.data import quantile_normalize

In [2]:
DATA_URL = "./data-table.xlsx"

df_trait_info = pd.read_excel(DATA_URL, sheet_name=0)
trait_map = {
    row.id: row.short if row.short is not np.nan else row.description
    for _, row in df_trait_info.iterrows()
}

df_covar_info = pd.read_excel(DATA_URL, sheet_name=1)
covar_map = {row.id: row.short for _, row in df_covar_info.iterrows()}

df_display = pd.read_excel(DATA_URL, sheet_name=2)

In [3]:
# directory holding one "<trait>.tsv.gz" file per trait; read by
# compute_conditional_r2 and globbed to build the list of traits
DATA_DIR = "../compile-data/out/per-trait-info/"

In [4]:
# covariates regressed out of the phenotype: age, sex, deprivation and PC1-PC10
COVAR_COLS = ["AGE", "SEX", "DEPRIVATION_INDEX", *(f"PC{k}" for k in range(1, 11))]

# columns across which the PGS R2 is compared (`calpgs group-r2 --group`)
TEST_COLS = [
    "SEX",
    "glasses",
    "AGE",
    "years_of_edu",
    "income",
    "DEPRIVATION_INDEX",
    "PC1",
    "PC2",
    "drink_alcohol",
    "ever_smoked",
    "log_BMI",
]

for label, cols in (("Covariates:", COVAR_COLS), ("Testing:", TEST_COLS)):
    print(label, ", ".join(cols))

Covariates: AGE, SEX, DEPRIVATION_INDEX, PC1, PC2, PC3, PC4, PC5, PC6, PC7, PC8, PC9, PC10
Testing: SEX, glasses, AGE, years_of_edu, income, DEPRIVATION_INDEX, PC1, PC2, drink_alcohol, ever_smoked, log_BMI


In [5]:
def compute_conditional_r2(
    trait: str,
    indiv_group: str,
    out_prefix: str,
    cond_col: str,
    test_cols: List[str],
    n_bootstrap: int = 1000,
):
    """
    Compute PGS R2 within strata of a conditioning column, for one trait and
    one group of individuals.

    The phenotype is quantile-normalized and residualized on COVAR_COLS. The
    individuals are then stratified by ``cond_col`` (into quintiles when it has
    more than 5 distinct values, otherwise by its distinct values). For each
    stratum a baseline summary (``calpgs.summarize_pred``) and the
    ``calpgs group-r2`` comparison across ``test_cols`` are written.

    Parameters
    ----------
    trait: str
        trait to load from ``DATA_DIR/<trait>.tsv.gz``
    indiv_group: str
        group of individuals: "white_british" (``group == "United Kingdom"``)
        or "other" (all remaining individuals)
    out_prefix: str
        output prefix; for the i-th stratum (1-based) the following are written:
        <out_prefix>.<cond_col>_<i>.baseline.tsv  (baseline summary, no header)
        plus the files produced by ``calpgs group-r2 --out <out_prefix>.<cond_col>_<i>``
    cond_col: str
        column used to stratify the individuals
    test_cols: List[str]
        columns passed to ``calpgs group-r2 --group``
    n_bootstrap: int
        currently unused; kept for interface compatibility

    Raises
    ------
    NotImplementedError
        if ``indiv_group`` is not one of the supported groups
    subprocess.CalledProcessError
        if ``calpgs group-r2`` exits with a non-zero status
    """
    df_trait = pd.read_csv(
        os.path.join(DATA_DIR, f"{trait}.tsv.gz"), index_col=0, sep="\t"
    )
    if indiv_group == "white_british":
        mask = df_trait.group == "United Kingdom"
    elif indiv_group == "other":
        mask = df_trait.group != "United Kingdom"
    else:
        raise NotImplementedError(f"indiv_group={indiv_group!r} is not supported")
    # explicit copy: columns are added / rows dropped below, which must not act
    # on a view of the full table (SettingWithCopy)
    df_trait = df_trait[mask].copy()

    # residual after regressing out COVAR_COLS. The normalized phenotype is
    # wrapped in a Series sharing df_trait's index so that, with
    # missing="drop", the residuals align back to their rows (dropped rows
    # become NaN and are removed by the dropna below).
    # NOTE(review): assumes quantile_normalize returns one value per input
    # row (NaNs preserved) — confirm in admix.data.
    pheno_qn = pd.Series(
        quantile_normalize(df_trait["PHENO"].values), index=df_trait.index
    )
    df_trait["PHENO_RESID"] = (
        sm.OLS(pheno_qn, sm.add_constant(df_trait[COVAR_COLS]), missing="drop")
        .fit()
        .resid
    )
    df_trait = df_trait.dropna(subset=["PHENO_RESID", "MEAN", cond_col])

    # continuous conditioning variables are binned into quintiles; discrete
    # ones (<= 5 levels) are used as-is
    if df_trait[cond_col].nunique() > 5:
        cond_var = pd.qcut(df_trait[cond_col], q=5, duplicates="drop")
    else:
        cond_var = df_trait[cond_col]

    # observed=True: skip empty categorical bins, which would give empty strata
    strata = df_trait.groupby(cond_var, observed=True)
    for i, (_, df_trait_q) in enumerate(strata, start=1):
        suffix = f"{cond_col}_{i}"

        # baseline
        df_baseline = calpgs.summarize_pred(
            df_trait_q,
            y_col="PHENO_RESID",
            pred_col="MEAN",
        )
        df_baseline.to_csv(
            f"{out_prefix}.{suffix}.baseline.tsv", sep="\t", header=False
        )

        # per-column R2 comparison via the calpgs CLI, which reads the stratum
        # from a temporary TSV
        tmp_file = f"{out_prefix}.{suffix}.tmp.tsv"
        df_trait_q.to_csv(tmp_file, sep="\t")
        cmd = [
            "calpgs", "group-r2",
            "--df", tmp_file,
            "--y", "PHENO_RESID",
            "--pred", "MEAN",
            "--group", ",".join(test_cols),
            "--cor", "spearman",
            "--out", f"{out_prefix}.{suffix}",
        ]
        try:
            # argv list (no shell): paths / column names are passed verbatim
            subprocess.check_call(cmd)
        finally:
            # always clean up the temporary file, even if calpgs fails
            os.remove(tmp_file)

In [6]:
trait = "LDL"
indiv_group = "white_british"
# output prefix for the single-trait example run; follows the
# "out/r2-diff/<trait>-<group>" naming used for the batch jobs
out_prefix = f"out/r2-diff/{trait}-{indiv_group}"
os.makedirs(os.path.dirname(out_prefix), exist_ok=True)


In [11]:
# stratify the selected trait / group by AGE; AGE itself is therefore
# excluded from the columns tested within each stratum
compute_conditional_r2(
    trait=trait,
    indiv_group=indiv_group,
    out_prefix=out_prefix,
    cond_col="AGE",
    test_cols=list(filter(lambda name: name != "AGE", TEST_COLS)),
)

  x = pd.concat(x[::order], 1)


# TODO
Extend this conditional-R² computation to all traits, or to a selected subset of traits.

In [8]:
# trait names are the "<trait>.tsv.gz" file names in DATA_DIR with the
# ".tsv.gz" suffix removed. Sorted so that the trait order, and hence the
# job-array order built from it, is reproducible across kernel restarts
# (iterating a set of strings is hash-randomized).
trait_list = sorted(
    os.path.basename(path).rsplit(".", 2)[0]
    for path in glob.glob(os.path.join(DATA_DIR, "*.tsv.gz"))
)
print(f"{len(trait_list)} traits in total.")

247 traits in total.


In [9]:
# one job per (trait, group) pair; outputs go to out/r2-diff/<trait>-<group>.*
df_params = pd.DataFrame(
    list(itertools.product(trait_list, ["white_british", "other"])),
    columns=["trait", "group"],
)
df_params["out_prefix"] = (
    "out/r2-diff/" + df_params["trait"] + "-" + df_params["group"]
)
print(f"{len(df_params)} jobs in total")
os.makedirs("out/r2-diff/", exist_ok=True)

494 jobs in total


In [10]:
# shell lines handed to the executor as `setup`; presumably run in the SGE job
# script before the Python payload so jobs use the project's conda install
# (NOTE(review): confirm against the submitit SGE fork in use)
sge_setup = [
    "export PATH=~/project-pasaniuc/software/miniconda3/bin:$PATH",
    "export PYTHONNOUSERSITE=True",
]

# job logs / pickles are stored under ./submitit-logs
executor = submitit.SgeExecutor(folder="./submitit-logs")
# presumably 40 minutes wall time and 12 GB memory per job — confirm units
executor.update_parameters(time_min=40, memory_g=12, setup=sge_setup)

In [11]:
# a job counts as finished once its "<out_prefix>.r2diff.tsv" file exists
is_done = df_params["out_prefix"].map(
    lambda prefix: os.path.exists(prefix + ".r2diff.tsv")
)
df_todo_params = df_params[~is_done]
print(f"{len(df_todo_params)} jobs remains")

0 jobs remains


In [12]:
# NOTE(review): `compute_r2` is not defined anywhere in this notebook (only
# `compute_conditional_r2` is), so this cell only works if `compute_r2` was
# left in the kernel from elsewhere — TODO define or import it so that
# Restart & Run All succeeds.
# One array task per remaining row, called as
# compute_r2(trait, group, out_prefix, TEST_COLS).
n_todo = len(df_todo_params)
jobs = executor.map_array(
    compute_r2,
    df_todo_params["trait"],
    df_todo_params["group"],
    df_todo_params["out_prefix"],
    [TEST_COLS for _ in range(n_todo)],
)



# Summarize the results

In [13]:
# Collect per-trait outputs into one baseline table and one r2diff table per
# group. Traits whose outputs are incomplete are reported and skipped rather
# than aborting the whole summary.
for group in ["white_british", "other"]:
    df_group_params = df_params[df_params.group == group]
    baseline_rows = []
    r2_diff_frames = []
    for _, row in tqdm(df_group_params.iterrows(), total=len(df_group_params)):
        baseline_file = row.out_prefix + ".baseline.tsv"
        r2diff_file = row.out_prefix + ".r2diff.tsv"
        missing = [f for f in (baseline_file, r2diff_file) if not os.path.exists(f)]
        if missing:
            for f in missing:
                print(f"{f} does not exist.")
            continue

        # presumably a header-less (statistic, value) table; squeezed into a
        # Series indexed by statistic name so "r2" can be looked up
        baseline = pd.read_csv(
            baseline_file, sep="\t", header=None, index_col=0
        ).squeeze()
        baseline_rows.append([row.trait, baseline["r2"]])

        df_tmp = pd.read_csv(r2diff_file, sep="\t")
        df_tmp.insert(0, "trait", row.trait)
        r2_diff_frames.append(df_tmp)

    df_baseline_r2 = pd.DataFrame(baseline_rows, columns=["trait", "baseline_r2"])
    df_baseline_r2.to_csv(f"out/baseline_r2.{group}.tsv", sep="\t", index=False)

    # pd.concat([]) raises, so only write the r2diff table when results exist
    if r2_diff_frames:
        df_r2_diff = pd.concat(r2_diff_frames)
        df_r2_diff.to_csv(f"out/r2diff.{group}.tsv", sep="\t", index=False)
    else:
        df_r2_diff = pd.DataFrame()
        print(f"No r2diff results found for {group}; skipping out/r2diff.{group}.tsv")

247it [00:05, 41.37it/s]
247it [00:05, 43.73it/s]
