# BayesRVAT Tutorial — single-gene, end-to-end demo

In [None]:
# Reproducibility and core libs
import numpy as np, pandas as pd, random
from pathlib import Path

# Utils from this repo
from bayesrvat.utils.simulations import simulate_genetics, simulate_phenotype
from bayesrvat.utils.acat import acat_test
from bayesrvat import BayesRVAT

# Preprocessing
from sklearn.preprocessing import MinMaxScaler

# Plotting
import matplotlib.pyplot as plt

# Autoreload for iterative work: in mode 2 IPython reloads imported modules
# before each cell executes, so edits to the bayesrvat package are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Fix the seed of both global generators (stdlib `random` and NumPy's legacy
# global state) so that repeated fresh runs draw the same random numbers.
seed = 0
random.seed(seed)
np.random.seed(seed)

In [None]:
# --- Input locations -------------------------------------------------------
DATA = Path("data")
GENE = "APOB"
FILENAME_ANN = DATA.joinpath(f"{GENE}_annotations.csv")
FILENAME_FREQ = DATA.joinpath(f"{GENE}_freq.csv")
MODEL_FILE = DATA.joinpath("model_choices", "all.csv")

# --- Simulation settings ---------------------------------------------------
N = 100_000       # cohort size for simulation
MAF_MAX = 1e-3    # RVAT rarity cutoff
VG = 1e-2         # variance explained by genetics for the simulation

In [None]:
# Read the two per-gene input tables in one pass:
#   bim       - variant list + allele freq columns
#   df_annots - annotation features per variant
bim, df_annots = (
    pd.read_csv(csv_path) for csv_path in (FILENAME_FREQ, FILENAME_ANN)
)

In [None]:
# Simulate sparse genotypes for N individuals given variant frequency table
X = simulate_genetics(bim, N).values   # shape (N, M)

# The variant mask built below indexes both the columns of X and the rows of
# df_annots, so the two must describe the same number of variants. This cell
# overwrites df_annots, so re-running it without reloading the annotation table
# (or loading tables that are not row-aligned) would otherwise fail with an
# opaque pandas boolean-index error.
if X.shape[1] != len(df_annots):
    raise ValueError(
        f"X has {X.shape[1]} variants but df_annots has {len(df_annots)} rows; "
        "reload the annotation table before re-running this cell."
    )

# Keep non-empty variants and apply rare MAF threshold computed from X.
# NOTE(review): 0.5 * column mean is an allele frequency only if genotypes are
# coded as allele counts in {0, 1, 2} - confirm against simulate_genetics.
maf = 0.5 * X.mean(axis=0)
Ikeep = (X.sum(axis=0) > 0) & (maf <= MAF_MAX)

X = X[:, Ikeep]
df_annots = df_annots.loc[Ikeep].reset_index(drop=True)
# Each variant's MAF depends only on its own column, so slicing the vector
# gives the post-filter values without a second pass over the (N, M) matrix.
maf = maf[Ikeep]
# Row indices of individuals with at least one non-zero entry among the kept
# variants; used by BayesRVAT for sparse indices.
nonzero_rows = np.unique(X.nonzero()[0])

In [None]:
# MAF-derived weight per variant: 1 / sqrt(2 * (1 - maf) * maf)
maf_scale = np.sqrt(2.0 * (1.0 - maf) * maf)
df_annots["beta_maf"] = 1.0 / maf_scale

# Same weight rescaled with MinMaxScaler; the scaler needs a 2-D (M, 1) input,
# so the result is flattened back to one value per variant.
beta_maf_column = df_annots[["beta_maf"]].values
maf_scaler = MinMaxScaler()
df_annots["beta_maf_norm"] = maf_scaler.fit_transform(beta_maf_column).ravel()

In [None]:
# Model table: the "annots" column names the annotation columns to use, and
# "mean" / "std" / "positive" carry the per-annotation prior settings.
model = pd.read_csv(MODEL_FILE)

# Restrict the annotation table to the model's columns, in the model's order.
# Label-based selection raises KeyError if a required column is missing.
req = model["annots"].tolist()
df_annots = df_annots[req]

# A: variants x annotations (M, K); XA: gene-level aggregated features (N, K)
A = df_annots.astype(np.float32).to_numpy()
XA = X.dot(A).astype(np.float32)

# Priors from model table
w_mean, w_std = (
    model[prior_col].astype(np.float32).to_numpy() for prior_col in ("mean", "std")
)
# presumably flags weights that are constrained to be positive - confirm in BayesRVAT
positive_w = model["positive"].astype(int).to_numpy()

In [None]:
# Simulation settings passed to simulate_phenotype, grouped as means and
# standard deviations for the plof / missense / other / annots arguments.
# They are spelled out explicitly rather than relying on the function defaults.
effect_means = dict(plof_mean=8.0, missense_mean=2.0, other_mean=2.0, annots_mean=1)
effect_stds = dict(plof_std=1, missense_std=2, other_std=2, annots_std=2)

# Simulate phenotype given aggregated features XA, with VG as the variance
# explained by genetics; stored as float32 like the other model inputs.
Y = simulate_phenotype(XA, vg=VG, **effect_means, **effect_stds).astype(np.float32)

# Intercept-only covariate: a single column of ones, one row per row of XA
F = np.ones((len(XA), 1), dtype=np.float32)

In [None]:
# Build the BayesRVAT model from the phenotype, covariates, aggregated
# features, indices of rows with non-zero genotypes, and the prior settings.
brvat = BayesRVAT(
    Y=Y, F=F, X=XA, idxs=nonzero_rows,
    prior_mean=w_mean, prior_std=w_std, positive=positive_w,
)

# Optimization schedule: fit the null model first, then the alternative.
# NOTE(review): the two calls use factr values six orders of magnitude apart
# (1e-3 vs 1e3). If factr follows the L-BFGS-B convention (smaller = stricter
# tolerance), confirm the null-model value is intentional and not a sign typo.
brvat.optimize_null(factr=1e-3)
brvat.optimize(factr=1e3)

# P-value reported by the fitted model; the bare name shows it as cell output
pv_brvat = float(brvat.getPv())
pv_brvat

In [None]:
# Simple burden tests on the same inputs as the BayesRVAT fit; the result is
# tabulated with one column per annotation set and a "pv" row of p-values.
results = brvat.run_simple_burden(
    Y=Y, F=F, Xt=XA,
    prior_mean=w_mean, annots=df_annots.columns, prior_std=w_std,
    positive=positive_w,
)
dfres = pd.DataFrame(results)

# Combine p-values with ACAT twice: over the consequence buckets only, and
# over every annotation set. acat_test returns a sequence; element 0 is used.
conseq = ["plof", "missense", "other1"]
pv_conseq_inputs = dfres.loc["pv", conseq].values
pv_all_inputs = dfres.loc["pv"].values
pv_conseq = float(acat_test(pv_conseq_inputs)[0])
pv_annots = float(acat_test(pv_all_inputs)[0])

# Show the three p-values side by side as the cell output
pv_brvat, pv_conseq, pv_annots

In [None]:
# Posterior mean / std of the leading weights, one per consequence bucket.
# NOTE(review): this assumes the first len(conseq) model weights correspond to
# `conseq` in the same order - verify against the model table.
n_conseq = len(conseq)
post_mean = np.asarray(brvat.post_mean[:n_conseq], dtype=float)
post_std = np.asarray(brvat.post_std[:n_conseq], dtype=float)

# Draw from an independent normal per weight to visualise each posterior.
# The std is floored at 1e-12 so the scale passed to the sampler is never zero.
n_samples = 10_000
rng = np.random.default_rng(seed)
samples = [
    rng.normal(loc=post_mean[idx], scale=max(post_std[idx], 1e-12), size=n_samples)
    for idx in range(n_conseq)
]

# One violin per consequence bucket; violins sit at x = 1..n_conseq
fig, ax = plt.subplots(figsize=(7, 4), dpi=100)
parts = ax.violinplot(samples, showmeans=True, showmedians=False, showextrema=False)

ax.set_xticks(list(range(1, n_conseq + 1)))
ax.set_xticklabels(conseq)

ax.set_ylabel("Posterior weight")
ax.set_title(f"{GENE}: posterior Consequence annotation weights")
fig.tight_layout()
plt.show()
