In [1]:
%load_ext autoreload
%autoreload 2
%load_ext lab_black

import numpy as np
import pandas as pd
import calprs
import scipy
import seaborn as sns
import matplotlib.pyplot as plt

ModuleNotFoundError: No module named 'calprs'

# Model
We have a genetic component $g$, an environmental component $e$, and two covariates $c_1, c_2$ (one continuous, mimicking age, and one binary, mimicking sex) with corresponding effects $\alpha_1, \alpha_2$. The phenotype is defined as $y = g + e + c_1 \alpha_1 + c_2 \alpha_2$.

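We assume that some PRS $\hat{g}$ has been derived and that it is unbiased, i.e. $E[\hat{g} - g] = 0$. However, its noise is not homogeneous: $\hat{g} \sim \mathcal{N}(g, (\tau_0 + \tau_1 c_1 + \tau_2 c_2)^2)$, so the standard deviation of the prediction error, $\tau_0 + \tau_1 c_1 + \tau_2 c_2$, changes linearly with the covariates. As a result, $\hat{g}$ tracks $g$ with differential accuracy across groups of individuals.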


In [3]:
# Simulation setup: the global seed is fixed so every draw below is reproducible.
n_indiv = 5_000
np.random.seed(1234)

# Genetic (g) and environmental (e) components, each drawn with variance 0.5.
component_sd = np.sqrt(0.5)
g = np.random.normal(scale=component_sd, size=n_indiv)
e = np.random.normal(scale=component_sd, size=n_indiv)

# Covariates: a scaled Beta(3, 3) draw truncated to an integer age, and a 0/1 sex indicator.
age = (100 * np.random.beta(a=3, b=3, size=n_indiv)).astype(int)
sex = np.random.randint(2, size=n_indiv)

# Effect sizes are scaled so that each covariate term contributes variance 0.1.
alpha_age = np.sqrt(0.1 / age.var())
alpha_sex = np.sqrt(0.1 / sex.var())

# Phenotype without (y) and with (y_cov) the covariate contributions.
y = g + e
y_cov = y + alpha_age * age + alpha_sex * sex

# Noisy predictor of g: the noise standard deviation starts at tau_0 and rises
# linearly with age and sex (slopes are normalised by each covariate's range).
tau_0 = 0.2
tau_age = 0.4 / np.ptp(age)
tau_sex = 0.2 / np.ptp(sex)
noise_sd = tau_0 + tau_age * age + tau_sex * sex
g_hat = g + np.random.normal(scale=noise_sd)

# One row per individual, indexed by an integer id named "indiv".
sim_columns = {
    "g": g,
    "e": e,
    "y": y,
    "y_cov": y_cov,
    "age": age,
    "sex": sex,
    "prs": g_hat,
}
df = pd.DataFrame(
    sim_columns,
    index=pd.Index(np.arange(n_indiv).astype(int), name="indiv"),
)
# Baseline prediction standard deviation: the same constant for everyone.
df["predstd0"] = 1.0
df.to_csv("toy.tsv", sep="\t")

# Data visualization to understand the issue

TODO

# Some demonstration on this data

In [3]:
# Bin age into five quantile-based groups and keep only the integer bin code
# (0 is the lowest age bin).
age_quintile = pd.qcut(df["age"], q=5)
df["age_q"] = age_quintile.cat.codes
# Baseline prediction standard deviation: the same constant for every individual.
df["predstd0"] = 1.0

In [4]:
# Per-group summary of the uncalibrated PRS with the constant predstd0,
# first grouped by age bin and then by sex.
summary_kwargs = dict(y_col="y_cov", pred_col="prs", predstd_col="predstd0")
for col in ("age_q", "sex"):
    df_sum = calprs.summarize_pred(df, group_col=col, **summary_kwargs)
    print(f"## {col}")
    display(df_sum)

## age_q


Unnamed: 0_level_0,r2,std(y),std(pred),coverage,length
age_q,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
0,0.33838,1.054985,0.830924,0.861611,1.644854
1,0.346413,1.081095,0.872822,0.778234,1.644854
2,0.275992,1.02337,0.84649,0.707267,1.644854
3,0.278252,1.034956,0.879951,0.609562,1.644854
4,0.264367,1.088182,0.95044,0.482828,1.644854


## sex


Unnamed: 0_level_0,r2,std(y),std(pred),coverage,length
sex,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
0,0.339652,1.056277,0.815605,0.809264,1.644854
1,0.265643,1.066777,0.936796,0.563143,1.644854


In [5]:
# Re-seed, then draw 1000 distinct index labels (no replacement) to serve as
# the calibration set.
np.random.seed(1234)
calibrate_idx = np.random.choice(df.index, replace=False, size=1000)

# Marginal calibration

In [6]:
# Calibrate with no adjustment covariates; calibrate_idx is handed over as the
# calibration rows.
df_calibrated = calprs.calibrate_pred(
    df,
    calibrate_idx=calibrate_idx,
    ci_method="scale",
    y_col="y_cov",
    pred_col="prs",
    predstd_col="predstd0",
)

# Bring the grouping columns and the phenotype over from df. set_index relabels
# the rows positionally, so this relies on both frames having the same row order.
carried = ["age_q", "sex", "y_cov"]
df_calibrated[carried] = df[carried].set_index(df_calibrated.index)

# Evaluate only on individuals that are not part of the calibration set.
held_out = ~df_calibrated.index.isin(calibrate_idx)
df_calibrated = df_calibrated[held_out]

summary_kwargs = dict(y_col="y_cov", pred_col="prs", predstd_col="predstd0")
for col in ("age_q", "sex"):
    df_sum = calprs.summarize_pred(df_calibrated, group_col=col, **summary_kwargs)
    print(f"## {col}")
    display(df_sum)

## age_q


Unnamed: 0_level_0,r2,std(y),std(pred),coverage,length
age_q,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
0,0.321725,1.038469,0.579632,0.896714,1.561212
1,0.341215,1.078559,0.611054,0.920157,1.561212
2,0.2869,1.040115,0.597844,0.932741,1.561212
3,0.276894,1.021092,0.616918,0.916877,1.561212
4,0.265388,1.099216,0.671637,0.837905,1.561212


## sex


Unnamed: 0_level_0,r2,std(y),std(pred),coverage,length
sex,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
0,0.33935,1.056446,0.573388,0.912536,1.561212
1,0.25867,1.072175,0.657864,0.887745,1.561212


# Conditional calibration

In [7]:
# Same calibration call as the marginal case, but age and sex are additionally
# passed as interval-adjustment columns.
df_calibrated = calprs.calibrate_pred(
    df,
    calibrate_idx=calibrate_idx,
    ci_method="scale",
    ci_adjust_cols=["age", "sex"],
    y_col="y_cov",
    pred_col="prs",
    predstd_col="predstd0",
)

# Bring the grouping columns and the phenotype over from df. set_index relabels
# the rows positionally, so this relies on both frames having the same row order.
carried = ["age_q", "sex", "y_cov"]
df_calibrated[carried] = df[carried].set_index(df_calibrated.index)

# Evaluate only on individuals that are not part of the calibration set.
held_out = ~df_calibrated.index.isin(calibrate_idx)
df_calibrated = df_calibrated[held_out]

summary_kwargs = dict(y_col="y_cov", pred_col="prs", predstd_col="predstd0")
for col in ("age_q", "sex"):
    df_sum = calprs.summarize_pred(df_calibrated, group_col=col, **summary_kwargs)
    print(f"## {col}")
    display(df_sum)

## age_q


Unnamed: 0_level_0,r2,std(y),std(pred),coverage,length
age_q,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
0,0.321725,1.038469,0.579632,0.873239,1.469641
1,0.341215,1.078559,0.611054,0.91623,1.521964
2,0.2869,1.040115,0.597844,0.928934,1.557151
3,0.276894,1.021092,0.616918,0.921914,1.595419
4,0.265388,1.099216,0.671637,0.860349,1.648943


## sex


Unnamed: 0_level_0,r2,std(y),std(pred),coverage,length
sex,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
0,0.33935,1.056446,0.573388,0.910593,1.569589
1,0.25867,1.072175,0.657864,0.887745,1.545289


# New API

In [8]:
# Split df into the rows used to fit the calibration model (train) and the
# remaining rows used for evaluation (test). The seed and sample size match
# the earlier calibrate_idx draw.
np.random.seed(1234)
calibrate_idx = np.random.choice(df.index, replace=False, size=1000)
in_train = df.index.isin(calibrate_idx)
df_train = df.loc[calibrate_idx, :].copy()
df_test = df.loc[~in_train, :].copy()

In [9]:
# Fit the calibration model on the training rows, with no adjustment variables.
model = calprs.calibrate_model(
    ci_method="scale",
    y=df_train["y_cov"].to_numpy(),
    pred=df_train["prs"].to_numpy(),
    predstd=df_train["predstd0"].to_numpy(),
)

# Apply the fitted model to the test rows and store both returned arrays.
cal_pred, cal_predstd = calprs.calibrate_adjust(
    model=model,
    pred=df_test["prs"].to_numpy(),
    predstd=df_test["predstd0"].to_numpy(),
)
df_test["cal_prs"] = cal_pred
df_test["cal_predstd"] = cal_predstd

# Summarise the calibrated columns per age bin and per sex.
summary_kwargs = dict(y_col="y_cov", pred_col="cal_prs", predstd_col="cal_predstd")
for col in ("age_q", "sex"):
    df_sum = calprs.summarize_pred(df_test, group_col=col, **summary_kwargs)
    print(f"## {col}")
    display(df_sum)

## age_q


Unnamed: 0_level_0,r2,std(y),std(pred),coverage,length
age_q,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
0,0.321725,1.038469,0.579632,0.896714,1.561212
1,0.341215,1.078559,0.611054,0.920157,1.561212
2,0.2869,1.040115,0.597844,0.932741,1.561212
3,0.276894,1.021092,0.616918,0.916877,1.561212
4,0.265388,1.099216,0.671637,0.837905,1.561212


## sex


Unnamed: 0_level_0,r2,std(y),std(pred),coverage,length
sex,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
0,0.33935,1.056446,0.573388,0.912536,1.561212
1,0.25867,1.072175,0.657864,0.887745,1.561212


In [10]:
# Fit the calibration model on the training rows, this time also passing age
# and sex as adjustment variables.
adjust_cols = ["age", "sex"]
model = calprs.calibrate_model(
    ci_method="scale",
    ci_adjust_vars=df_train[adjust_cols].to_numpy(),
    y=df_train["y_cov"].to_numpy(),
    pred=df_train["prs"].to_numpy(),
    predstd=df_train["predstd0"].to_numpy(),
)

# Apply the fitted model to the test rows (the same adjustment columns are
# supplied for them) and store both returned arrays.
cal_pred, cal_predstd = calprs.calibrate_adjust(
    model=model,
    pred=df_test["prs"].to_numpy(),
    predstd=df_test["predstd0"].to_numpy(),
    ci_adjust_vars=df_test[adjust_cols].to_numpy(),
)
df_test["cal_prs"] = cal_pred
df_test["cal_predstd"] = cal_predstd

# Summarise the calibrated columns per age bin and per sex.
summary_kwargs = dict(y_col="y_cov", pred_col="cal_prs", predstd_col="cal_predstd")
for col in ("age_q", "sex"):
    df_sum = calprs.summarize_pred(df_test, group_col=col, **summary_kwargs)
    print(f"## {col}")
    display(df_sum)

## age_q


Unnamed: 0_level_0,r2,std(y),std(pred),coverage,length
age_q,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
0,0.321725,1.038469,0.579632,0.873239,1.469641
1,0.341215,1.078559,0.611054,0.91623,1.521964
2,0.2869,1.040115,0.597844,0.928934,1.557151
3,0.276894,1.021092,0.616918,0.921914,1.595419
4,0.265388,1.099216,0.671637,0.860349,1.648944


## sex


Unnamed: 0_level_0,r2,std(y),std(pred),coverage,length
sex,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
0,0.33935,1.056446,0.573388,0.910593,1.569589
1,0.25867,1.072175,0.657864,0.887745,1.545289
