In [1]:
%load_ext lab_black

import numpy as np
import pandas as pd
import calprs
import scipy
import seaborn as sns
import matplotlib.pyplot as plt

# Model
We have genetic component $g$, environment component $e$, two covariates (one continuous, mimicking age, and one binary, mimicking sex) $c_1, c_2$ and their corresponding effects $\alpha_1, \alpha_2$. The phenotype is defined as $y = g + e + c_1 \alpha_1 + c_2 \alpha_2$. 

We assume that some PRS is derived $\hat{g}$, and it is assumed that $E[\hat{g} - g] = 0$. But for some reason, $\hat{g} \sim \mathcal{N}(g, \tau_0 + \tau_1 c_1 + \tau_2 c_2)$. Therefore, some differntial performance between $\hat{g}$ and $g$ across groups of individuals.


In [2]:
# setup
n_indiv = 5_000
np.random.seed(1234)

g = np.random.normal(scale=np.sqrt(0.5), size=n_indiv)
e = np.random.normal(scale=np.sqrt(0.5), size=n_indiv)

# age and sex each explain 0.1 variance
age = (np.random.beta(a=3, b=3, size=n_indiv) * 100).astype(int)
sex = np.random.randint(2, size=n_indiv)
alpha_age = np.sqrt(0.1 / np.var(age))
alpha_sex = np.sqrt(0.1 / np.var(sex))

# y and y_cov
y = g + e
y_cov = y + alpha_age * age + alpha_sex * sex

# simulate predictor of g: g_hat
tau_0 = 0.2
tau_age = 0.4 / np.ptp(age)
tau_sex = 0.2 / np.ptp(sex)
g_hat = g + np.random.normal(scale=tau_0 + tau_age * age + tau_sex * sex)

# generate data frame
df = pd.DataFrame(
    {"g": g, "e": e, "y": y, "y_cov": y_cov, "age": age, "sex": sex, "prs": g_hat},
    index=np.arange(n_indiv).astype(int),
)
df.index.name = "indiv"
df.to_csv("toy.tsv", sep="\t", float_format="%.5g")

# Data visualization to understand the issue

TODO

# Some demonstration on this data

In [3]:
df["age_q"] = pd.qcut(df["age"], q=5).cat.codes
# baseline predstd
df["predstd0"] = 1.0

In [4]:
for col in ["age_q", "sex"]:
    df_sum = calprs.summarize_pred(
        df,
        y_col="y_cov",
        pred_col="prs",
        predstd_col="predstd0",
        group_col=col,
    )
    print(f"## {col}")
    display(df_sum)

## age_q


Unnamed: 0_level_0,r2,std(y),std(pred),coverage,length
age_q,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
0,0.320251,1.019783,0.795608,0.860891,1.644854
1,0.308465,1.019648,0.848208,0.789266,1.644854
2,0.281069,1.054626,0.877853,0.694268,1.644854
3,0.277034,1.030312,0.916605,0.633166,1.644854
4,0.246716,1.05223,0.931935,0.520855,1.644854


## sex


Unnamed: 0_level_0,r2,std(y),std(pred),coverage,length
sex,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
0,0.34158,1.046217,0.818939,0.821528,1.644854
1,0.23954,1.037303,0.927383,0.58082,1.644854


# Marginal calibration

In [5]:
calibrate_idx = np.random.choice(df.index, size=1000, replace=False)
df_calibrated = calprs.calibrate_pred(
    df,
    y_col="y_cov",
    pred_col="prs",
    predstd_col="predstd0",
    calibrate_idx=calibrate_idx,
    ci_method="scale",
)

df_calibrated[["age_q", "sex", "y_cov"]] = df[["age_q", "sex", "y_cov"]].set_index(
    df_calibrated.index
)

In [6]:
for col in ["age_q", "sex"]:
    df_sum = calprs.summarize_pred(
        df_calibrated,
        y_col="y_cov",
        pred_col="prs",
        predstd_col="predstd0",
        group_col=col,
    )
    print(f"## {col}")
    display(df_sum)

## age_q


Unnamed: 0_level_0,r2,std(y),std(pred),coverage,length
age_q,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
0,0.320251,1.019783,0.549813,0.868812,1.43183
1,0.308465,1.019648,0.586163,0.908419,1.43183
2,0.281069,1.054626,0.60665,0.88976,1.43183
3,0.277034,1.030312,0.633429,0.886432,1.43183
4,0.246716,1.05223,0.644023,0.837852,1.43183


## sex


Unnamed: 0_level_0,r2,std(y),std(pred),coverage,length
sex,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
0,0.34158,1.046217,0.565936,0.894747,1.43183
1,0.23954,1.037303,0.640877,0.862485,1.43183


# Conditional calibration

In [7]:
calibrate_idx = np.random.choice(df.index, size=1000, replace=False)
df_calibrated = calprs.calibrate_pred(
    df,
    y_col="y_cov",
    pred_col="prs",
    predstd_col="predstd0",
    calibrate_idx=calibrate_idx,
    ci_method="scale",
    ci_adjust_cols=["age", "sex"],
)

df_calibrated[["age_q", "sex", "y_cov"]] = df[["age_q", "sex", "y_cov"]].set_index(
    df_calibrated.index
)



In [8]:
for col in ["age_q", "sex"]:
    df_sum = calprs.summarize_pred(
        df_calibrated,
        y_col="y_cov",
        pred_col="prs",
        predstd_col="predstd0",
        group_col=col,
    )
    print(f"## {col}")
    display(df_sum)

## age_q


Unnamed: 0_level_0,r2,std(y),std(pred),coverage,length
age_q,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
0,0.320251,1.019783,0.50546,0.860891,1.436667
1,0.308465,1.019648,0.538878,0.917282,1.491454
2,0.281069,1.054626,0.557711,0.916707,1.531537
3,0.277034,1.030312,0.582331,0.917085,1.568456
4,0.246716,1.05223,0.59207,0.88999,1.620317


## sex


Unnamed: 0_level_0,r2,std(y),std(pred),coverage,length
sex,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
0,0.34158,1.046217,0.520282,0.898926,1.493098
1,0.23954,1.037303,0.589178,0.902091,1.564487
