In [2]:
%load_ext autoreload
%autoreload 2
%load_ext lab_black
import admix_prs

import xarray as xr
import pandas as pd
import numpy as np

import admix
from scipy import stats
import matplotlib.pyplot as plt
from os.path import join
from tqdm import tqdm

import matplotlib.pyplot as plt

In [4]:
trait = "height"

In [25]:
df_pheno = []
for group in ["eur_test", "admix"]:
    df_tmp = pd.read_csv(
        join(f"../../data/pheno/{group}.{trait}.residual_pheno"), sep="\t"
    )
    df_tmp.index = df_tmp.FID.astype(str) + "_" + df_tmp.IID.astype(str)
    df_tmp = df_tmp.drop(columns=["FID", "IID"])
    df_pheno.append(df_tmp)
df_pheno = (
    pd.concat(df_pheno)
    .reset_index()
    .drop_duplicates(subset=["index"])
    .set_index("index")
)

df_prs_eur_test = pd.read_csv(f"out/prs/{trait}.eur_test.tsv.gz", sep="\t", index_col=0)
df_prs_eur_test["GROUP"] = "eur_test"
df_prs_admix = pd.read_csv(f"out/prs/{trait}.admix.tsv.gz", sep="\t", index_col=0)
df_prs_admix["GROUP"] = "admix"
df_prs = pd.concat([df_prs_eur_test, df_prs_admix])
df_prs["PHENO"] = df_pheno["PHENO"].reindex(df_prs.index)

QUANTILES = [0.1, 0.9]
pred_interval = np.quantile(
    df_prs[[f"SAMPLE_{i}" for i in range(1, 501)]], q=QUANTILES, axis=1
)
pred_sd = np.std(df_prs[[f"SAMPLE_{i}" for i in range(1, 501)]], axis=1)
df_plot = pd.DataFrame(
    {
        "PRS_MEAN": df_prs["MEAN"],
        "PRS_SD": pred_sd,
        "GROUP": df_prs["GROUP"],
        "PHENO": df_prs["PHENO"],
    }
)

for q_i, q in enumerate(QUANTILES):
    df_plot[f"PRS_Q_{q}"] = pred_interval[q_i, :]

# intercept = df_plot.GV.mean() - df_plot["PRS_MEAN"].mean()
# for col in [col for col in df_plot.columns if col.startswith("PRS_")]:
#     if col == "PRS_SD":
#         continue
#     df_plot[col] += intercept

In [26]:
df_lanc = pd.read_csv("../02_simulation/out/admix_lanc.tsv", sep="\t", index_col=0)
df_lanc["lanc"] = 1.0 - df_lanc["lanc"]
df_lanc["lanc_q"] = pd.qcut(df_lanc.lanc, q=5).cat.codes + 1

df_plot["lanc"] = df_lanc["lanc"].reindex(df_plot.index)
df_plot["lanc_q"] = df_lanc["lanc_q"].reindex(df_plot.index)

In [27]:
df_plot = df_plot[df_plot.GROUP == "admix"].dropna()

In [30]:
df_plot_group = {"group": [], "R2": [], "avg_prs_sd": []}

for i, group in df_plot.groupby("lanc_q"):
    df_plot_group["group"].append(i)
    df_plot_group["R2"].append(
        stats.pearsonr(group["PRS_MEAN"], group["PHENO"])[0] ** 2
    )
    df_plot_group["avg_prs_sd"].append(group["PRS_SD"].mean())
df_plot_group = pd.DataFrame(df_plot_group)

In [31]:
df_plot_group

Unnamed: 0,group,R2,avg_prs_sd
0,1.0,0.249725,0.473674
1,2.0,0.094964,0.514939
2,3.0,0.081431,0.565839
3,4.0,0.081011,0.592151
4,5.0,0.088995,0.602786


In [37]:
for i, group in df_plot.groupby("lanc_q"):
    print(i)
    print(group["PRS_Q_0.1"].mean())
    print(group["PRS_Q_0.9"].mean())
    print(group.PHENO.mean())

1.0
-1.2306598056171103
-0.0183229858081802
-0.05309132650775484
2.0
-2.476588733644685
-1.1607323558398845
0.04889969910020185
3.0
-3.648146141427444
-2.2010729652736583
0.059692455928803186
4.0
-4.304332573152047
-2.791236780692913
0.006996526008813383
5.0
-4.575945131792744
-3.0350750256643053
0.027022439455839514


In [None]:
# showcasing that 

In [23]:
df_plot_group

{'group': [1.0, 2.0, 3.0, 4.0, 5.0],
 'R2': [0.24972532034989192,
  0.09496410311452468,
  0.08143130137581694,
  0.08101095383327557,
  0.08899516105454677],
 'avg_prs_sd': [0.473673947332684,
  0.5149387043886997,
  0.5658389447067689,
  0.5921509233036603,
  0.6027856366178806]}

In [15]:
df_admix_plot

Unnamed: 0_level_0,PRS_MEAN,PRS_SD,GROUP,PHENO,lanc,lanc_q
indiv,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
1001418_1001418,-0.335150,0.497760,admix,0.068321,0.267383,2.0
1002721_1002721,-2.632120,0.577162,admix,0.195220,0.697249,3.0
1003248_1003248,-3.558477,0.577964,admix,0.657138,0.885233,5.0
1003263_1003263,-2.672451,0.535988,admix,-1.214965,0.496493,3.0
1005250_1005250,-3.696821,0.567640,admix,0.085189,0.905359,5.0
...,...,...,...,...,...,...
6022871_6022871,-2.951967,0.613482,admix,0.547353,0.911980,5.0
6023479_6023479,-2.798825,0.518653,admix,-0.916475,0.558200,3.0
6024954_6024954,-2.178861,0.527248,admix,0.153260,0.517732,3.0
6025478_6025478,-1.262301,0.482579,admix,-0.396475,0.391524,2.0
