In [1]:
%load_ext autoreload
%autoreload 2
%load_ext lab_black

# Several observations
Under this simple simulation,
1. Because of the large regularization, weight = 1 / (alpha + eigval) does not change much across different eigen values.
2. weight can be approximated as a constant across different eigenvalues.
3. then np.square(pc) * 1 can be accurately used to approximate the predictive variance.

Questions:
1. What would happen in real data? Plot weight, etc.
2. Or what happen if one use (geno ** 2).sum(axis=1)?
3. Can one truncate at some top PCs, and approximate the rest with constant.
    - Concretely $$\sum_{i \in \text{top}} (v_i^\top g)^2 \cdot w_i + \sum_{i \in \text{rest}}(v_i^\top g)^2  \cdot w_i$$
    - And we know $$||g||^2_2 = \sum_i (v_i^\top  g)^2 = \sum_{i \in \text{top}} (v_i^\top  g)^2 + \sum_{i \in \text{rest}} (v_i^\top g)^2$$
    - Therefore, one can approximate the predictive variance under this model with $$ \sum_{i \in \text{top}} (v_i^\top g)^2 \cdot w_i + \left[||g||^2_2 - (\sum_{i \in \text{top}} (v_i^\top g)^2) \right] w_c$$

In [2]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal
from sklearn.linear_model import Ridge
from sklearn.decomposition import PCA
import admix
from scipy import linalg

np.random.seed(0)

# Genotype simulation


In [3]:
CHROM = 1
KG_PATH = f"/u/project/pasaniuc/kangchen/DATA/plink2-1kg/out/build38.chr{CHROM}"

In [4]:
n_admix = 2000

# Load data sets
1. European and African individuals with >1% MAF

In [5]:
dset_all = admix.io.read_dataset(KG_PATH)[0:200000:200]
dset_all.persist()
dset_eur = dset_all[:, (dset_all.indiv.SuperPop == "EUR").values]
dset_afr = dset_all[:, (dset_all.indiv.SuperPop == "AFR").values]
freq_eur = dset_eur.geno.mean(axis=[1, 2]).compute()
freq_afr = dset_afr.geno.mean(axis=[1, 2]).compute()
snp_mask = (0.01 < freq_eur) & (freq_eur < 0.99) & (0.01 < freq_eur) & (freq_eur < 0.99)
dset_eur, dset_afr = dset_eur[snp_mask], dset_afr[snp_mask]

mosaic_size = admix.simulate.calculate_mosaic_size(
    df_snp=dset_eur.snp, genetic_map="hg38", chrom=1, n_gen=7
)

np.random.seed(1)

dset_admix = admix.simulate.admix_geno(
    geno_list=[dset_eur.geno, dset_afr.geno],
    df_snp=dset_eur.snp,
    anc_props=[0.2, 0.8],
    mosaic_size=mosaic_size,
    n_indiv=n_admix,
)

100%|██████████| 4000/4000 [00:00<00:00, 35040.50it/s]


In [6]:
dset_eur

admix.Dataset object with n_snp x n_indiv = 605 x 503, no local ancestry
	snp: 'CHROM', 'POS', 'REF', 'ALT', 'QUAL', 'FILTER'
	indiv: 'PAT', 'MAT', 'SEX', 'SuperPop', 'Population'

In [7]:
dset_afr

admix.Dataset object with n_snp x n_indiv = 605 x 661, no local ancestry
	snp: 'CHROM', 'POS', 'REF', 'ALT', 'QUAL', 'FILTER'
	indiv: 'PAT', 'MAT', 'SEX', 'SuperPop', 'Population'

In [8]:
def simulate_pheno(geno_eur, geno_admix, var_g: float, var_e: float, n_sim: int = 10):
    geno = np.vstack([geno_eur, geno_admix])

    n_indiv, n_snp = geno.shape[0:2]
    beta = np.random.normal(size=(n_snp, n_sim))
    pheno_g = geno @ beta
    pheno_g *= np.sqrt(var_g / np.var(pheno_g, axis=0))
    pheno_e = np.random.normal(scale=np.sqrt(var_e), size=pheno_g.shape)
    pheno = pheno_g + pheno_e
    return pheno[0 : dset_eur.n_indiv, :], pheno[dset_eur.n_indiv :, :]

In [9]:
var_g = 1.0
var_e = 2.0

In [10]:
assert dset_eur.snp.equals(dset_admix.snp)
n_snp = dset_eur.n_snp
geno_eur, geno_admix = (
    dset_eur.geno.sum(axis=2).compute().T.astype(float),
    dset_admix.geno.sum(axis=2).compute().T.astype(float),
)
# center around geno_eur
freq_eur = geno_eur.mean(axis=0)
geno_eur -= freq_eur
geno_admix -= freq_eur
geno_all = np.vstack([geno_eur, geno_admix])

pheno_eur, pheno_admix = simulate_pheno(
    geno_eur=geno_eur, geno_admix=geno_admix, var_g=var_g, var_e=var_e
)

# Bayesian linear regression with analytical solution

In [11]:
sim_i = 0
alpha = var_e * n_snp / var_g

In [12]:
def ridge(X_train, y_train, X_test, alpha, var_e):
    """
    ridge using scikit-learn
    """
    n_snp = X_train.shape[1]
    model = Ridge(alpha=alpha, fit_intercept=True)
    model.fit(X_train, y_train)

    pred = model.predict(X_test)
    w = model.coef_.flatten()
    wcov = var_e * np.linalg.inv(np.diag([alpha] * n_snp) + X_train.T.dot(X_train))
    X_all = np.vstack([X_train, X_test])
    postvar = np.array(
        [
            var_e + X_all[i, :].T.dot(wcov.dot(X_all[i, :]))
            for i in range(X_all.shape[0])
        ]
    )
    return postvar


def ridge(X_train, y_train, X_test, alpha, var_e):
    """
    ridge using only numpy
    """
    # centering
    X_train, y_train, X_test = X_train.copy(), y_train.copy(), X_test.copy()
    center_train = X_train.mean(axis=0)
    X_train -= center_train
    X_test -= center_train
    intercept = y_train.mean()
    y_train -= intercept

    n_snp = X_train.shape[1]

    # train
    inv_XtX_train = np.linalg.inv(np.diag([alpha] * n_snp) + X_train.T @ X_train)
    w = inv_XtX_train @ X_train.T @ y_train
    wcov = var_e * inv_XtX_train

    X_all = np.vstack([X_train, X_test])
    pred = X_all @ w + intercept
    predvar = np.array(
        [
            var_e + X_all[i, :].T.dot(wcov.dot(X_all[i, :]))
            for i in range(X_all.shape[0])
        ]
    )

    return pred[len(y_train) :], predvar[len(y_train) :]


def analytical_predvar(X_train, X_test, alpha, n_components=None):
    # centering
    X_train, X_test = X_train.copy(), X_test.copy()
    center_train = X_train.mean(axis=0)
    X_train -= center_train
    X_test -= center_train

    eigval, eigvec = linalg.eigh(X_train.T @ X_train)
    weight = 1 / (eigval + alpha)

    if n_components is not None:
        truncated_weight = np.ones_like(weight) * (alpha)
        truncated_weight[-n_components:] = weight[-n_components:]
    else:
        truncated_weight = weight

    # note that X_test @ eigvec corresponds to the PCs
    pc = X_test @ eigvec
    predvar = np.square(pc) @ truncated_weight * var_e + var_e
    return predvar


def pca(X_train, X_test):
    # centering
    X_train, X_test = X_train.copy(), X_test.copy()
    center_train = X_train.mean(axis=0)
    X_train -= center_train
    X_test -= center_train

    pca = PCA(n_components=10)
    pca.fit(X_train)

    pc_train = pca.transform(X_train)
    pc_test = pca.transform(X_test)
    return pc_train, pc_test

In [13]:
predvar_gt = analytical_predvar(
    X_train=geno_eur, X_test=geno_admix, alpha=alpha, n_components=None
)

In [14]:
predvar = analytical_predvar(
    X_train=geno_eur, X_test=geno_admix, alpha=alpha, n_components=10
)

In [17]:
# plt.scatter(predvar_gt, np.square(geno_admix).sum(axis=1))

In [16]:
predvar, pc, weight = analytical_predvar(
    X_train=geno_eur, X_test=geno_admix, alpha=alpha, n_components=5
)

ValueError: too many values to unpack (expected 3)

In [None]:
plt.scatter(predvar_gt, predvar)

In [None]:
predvar

In [None]:
plt.plot(np.cumsum(np.flip(np.square(pc).var(axis=0) * weight)))

In [None]:
plt.plot(pc)

In [None]:
X_train = geno_eur
X_test = geno_admix
X_train, X_test = X_train.copy(), X_test.copy()
center_train = X_train.mean(axis=0)
X_train -= center_train
X_test -= center_train

pca = PCA(n_components=10)
pca.fit(X_train)

pc_train = pca.transform(X_train)
pc_test = pca.transform(X_test)

In [None]:
X_train = geno_eur
X_test = geno_admix

X_train, X_test = X_train.copy(), X_test.copy()
center_train = X_train.mean(axis=0)
X_train -= center_train
X_test -= center_train

eigval, eigvec = linalg.eigh(X_train.T @ X_train)
weight = 1 / (eigval + alpha)

In [None]:
plt.plot(np.square(X_test @ eigvec).var(axis=0))

In [None]:
plt.plot(weight)

In [None]:
predvar = analytical_predvar(X_train=geno_eur, X_test=geno_admix, alpha=alpha)

In [None]:
pred_admix, predvar_admix = ridge(
    X_train=geno_eur,
    y_train=pheno_eur[:, sim_i],
    X_test=geno_admix,
    alpha=alpha,
    var_e=var_e,
)

In [None]:
import admix_prs
import pandas as pd

In [None]:
df_plot = pd.DataFrame(
    {
        "y": pheno_admix[:, sim_i],
        "lower": pred_admix - np.sqrt(predvar_admix) * 1.645,
        "upper": pred_admix + np.sqrt(predvar_admix) * 1.645,
    }
)

In [None]:
fig, ax = plt.subplots(figsize=(5, 3), dpi=150)
admix_prs.plot_calibration(
    df_plot, y_col="y", lower_col="lower", upper_col="upper", n=30
)

In [None]:
admix_prs.eval_calibration(df_plot, x_col="y", lower_col="lower", upper_col="upper")