In [1]:
# %load_ext autoreload
# %autoreload 2
%load_ext lab_black

import os
import numpy as np
import pandas as pd
from glob import glob
import statsmodels.api as sm
from typing import List
import itertools
import matplotlib.pyplot as plt
from scipy import stats
import calpgs
import seaborn as sns
from tqdm import tqdm

# Seed NumPy's global random generator so the np.random.* draws made later in
# the notebook are repeatable when the cells are run in order on a fresh kernel.
np.random.seed(42)
# Use Arial for all matplotlib text.
# NOTE(review): assumes an Arial font is installed on the host — confirm.
plt.rcParams["font.family"] = "Arial"

In [2]:
# ---- covariates --------------------------------------------------------
# Read the covariate table; the file's own index is dropped in favour of 0..n-1.
df_cov = pd.read_csv("data/cov.tsv", sep="\t", index_col=0).reset_index(drop=True)
n_indiv = len(df_cov)

# Swap PC1 for a synthetic "GIA" column: exponential quantiles on an evenly
# spaced probability grid, handed out to individuals in the rank order of PC1.
df_cov = df_cov.assign(
    GIA=np.sort(stats.expon.ppf(np.linspace(0.01, 0.4, n_indiv)))[
        stats.rankdata(df_cov["PC1"], method="ordinal") - 1
    ]
).drop(columns="PC1")

# Column-wise standardisation (zero mean, unit standard deviation).
df_cov = df_cov.sub(df_cov.mean(axis=0), axis="columns").div(
    df_cov.std(axis=0), axis="columns"
)

# Random subset of individuals (drawn from NumPy's global random state).
df_cov = df_cov.sample(n=50000)

# ---- simulation parameters ---------------------------------------------
baseline_r2 = 0.25
cov_effects = [0.25, 0.2, 0.15]

# ---- simulate phenotype ------------------------------------------------
cov = df_cov.to_numpy()
n_indiv = len(cov)
pred = np.random.normal(size=n_indiv)
# Design matrix columns: intercept, predictor, then the covariates.
design = np.column_stack([np.ones(n_indiv), pred, cov])

# Mean model: only the predictor has a (unit) effect.
true_beta = np.zeros(design.shape[1], dtype=int)
true_beta[1] = 1
# Log-variance model: baseline intercept, no predictor effect, covariate effects.
true_gamma = np.array([np.log(1 / baseline_r2 - 1), 0, *cov_effects])
# Individuals whose second covariate is positive get a slope of 6 instead of 1.
slope = np.where(cov[:, 1] > 0, 6.0, 1.0)
y = np.random.normal(
    loc=slope * (design @ true_beta),
    scale=np.sqrt(np.exp(design @ true_gamma)),
)

# ---- inputs for model fitting ------------------------------------------
# Mean and variance models share the same design matrix (same array object).
mean_covar = var_covar = design
# Slope covariates: a copy of `cov` with the second column turned into a 0/1 indicator.
slope_covar = cov.copy()
slope_covar[:, 1] = np.where(slope_covar[:, 1] > 0, 1.0, 0.0)

In [3]:
# female = cov[:, 1] < 0
# fig, ax = plt.subplots(figsize=(3, 3), dpi=150)
# plt.scatter(pred[female], y[female], s=2, label="Female")
# plt.scatter(pred[~female], y[~female], s=2, label="Male")
# plt.legend()
# plt.show()

In [4]:
def fit_het_linear_vary_slope(
    y: np.ndarray,
    mean_covar: np.ndarray,
    var_covar: np.ndarray,
    slope_covar: np.ndarray,
    max_iter: int = 500,
    tol: float = 1e-8,
    verbose: bool = False,
):
    """Fit `calpgs.fit_het_linear` with a covariate-dependent slope.

    Model:
        y ~ N((mean_covar @ beta) * (1 + slope_covar @ slope_beta),
              exp(var_covar @ gamma))

    The fit alternates between two steps until `slope_beta` stops changing:

    1. With `slope_beta` fixed, rescale every row of `mean_covar` by
       `1 + slope_covar @ slope_beta` and fit `beta`, `gamma` with
       `calpgs.fit_het_linear`.
    2. With `beta`, `gamma` fixed, re-estimate `slope_beta` by weighted least
       squares of `y - mu` on `mu[:, None] * slope_covar` with weights
       `1 / var`, where `mu = mean_covar @ beta` and `var = exp(var_covar @ gamma)`.

    Parameters
    ----------
    y : np.ndarray
        Response vector, one entry per individual.
    mean_covar : np.ndarray
        Design matrix of the mean model, one row per individual.
    var_covar : np.ndarray
        Design matrix of the log-variance model, one row per individual.
    slope_covar : np.ndarray
        Covariates that modulate the slope, one row per individual. It should
        not contain a constant column: the baseline slope weight of 1 is
        already built into the model.
    max_iter : int, default 500
        Maximum number of alternating iterations.
    tol : float, default 1e-8
        Stop as soon as the largest absolute change in `slope_beta` between
        two consecutive iterations is below `tol`. Use `tol=0` to always run
        `max_iter` iterations.
    verbose : bool, default False
        If True, print per-iteration diagnostics.

    Returns
    -------
    beta : np.ndarray
        Mean-model coefficients from the last `calpgs.fit_het_linear` call.
    gamma : np.ndarray
        Variance-model coefficients from the last `calpgs.fit_het_linear` call.
    slope_beta : np.ndarray
        Slope coefficients from the last weighted least squares update.

    Raises
    ------
    ValueError
        If `max_iter` is smaller than 1.
    """
    if max_iter < 1:
        raise ValueError("max_iter must be at least 1")

    # Start from a constant slope weight of 1 for every individual.
    slope_beta = np.zeros(slope_covar.shape[1])

    for i in range(max_iter):
        slope_weight = 1 + slope_covar @ slope_beta

        # Step 1: mean / variance coefficients given the current slope weights.
        beta, gamma = calpgs.fit_het_linear(
            y=y,
            mean_covar=mean_covar * slope_weight[:, None],
            var_covar=var_covar,
            method="remlscore",
            return_est_covar=False,
        )
        # `fitted_mu` is the mean *before* slope scaling; the full fitted mean
        # is `fitted_mu * slope_weight`.
        fitted_mu = mean_covar @ beta
        fitted_var = np.exp(var_covar.dot(gamma))

        if verbose:
            print(f"## iter{i}")
            print("slope_beta: ", slope_beta[0:5])
            print("slope_weight: ", slope_weight[0:5])
            print("fitted_mu: ", fitted_mu[0:5])
            print("fitted_var: ", fitted_var[0:5])

        # Without slope covariates there is nothing left to update.
        if slope_covar.shape[1] == 0:
            break

        # Step 2: slope coefficients given the current mean / variance fit.
        #     y ~ N(mu * (1 + slope_covar * slope_beta), var)
        # ==> y - mu ~ N(mu * slope_covar * slope_beta, var)
        new_slope_beta = np.asarray(
            sm.WLS(
                endog=y - fitted_mu,
                exog=fitted_mu[:, None] * slope_covar,
                weights=1.0 / fitted_var,
            )
            .fit()
            .params
        )

        delta = float(np.max(np.abs(new_slope_beta - slope_beta)))
        slope_beta = new_slope_beta
        if delta < tol:
            break

    return beta, gamma, slope_beta

In [5]:
# Reference fit done by calpgs itself: the slope covariates built in the data
# cell are handed over through the `slope_covar` argument instead of going
# through the notebook-level `fit_het_linear_vary_slope` helper.
# NOTE(review): `return_est_covar=True` presumably makes the call also return
# the estimators' covariance matrices (the saved output shows 5x5 matrices
# after the coefficient vectors) — confirm against the calpgs documentation.
fit = calpgs.fit_het_linear(
    y=y,
    mean_covar=mean_covar,
    var_covar=var_covar,
    slope_covar=slope_covar,
    return_est_covar=True,
)

In [6]:
# Display the raw result tuple. In the saved output the first three arrays
# have lengths 5, 5 and 3 — presumably the mean, log-variance and slope
# coefficients, to be compared with `true_beta`, `true_gamma` and the +5 slope
# shift used in the simulation — followed by 5x5 matrices (presumably
# estimator covariances; TODO confirm against the calpgs documentation).
fit

(array([4.33064936e-03, 1.00438635e+00, 1.07600965e-03, 8.62841631e-04,
        1.34360285e-03]),
 array([1.08817347, 0.00612677, 0.24057558, 0.2002141 , 0.14462882]),
 array([ 5.64005247e-03,  4.98363051e+00, -3.16167652e-04]),
 array([[ 2.84512746e-05,  3.56233607e-08,  1.01708703e-06,
         -2.28958680e-05,  6.31236183e-07],
        [ 3.56233607e-08,  4.45187543e-06,  2.82148750e-08,
         -2.16321145e-08,  3.10284862e-08],
        [ 1.01708703e-06,  2.82148750e-08,  4.47400722e-06,
          2.66818954e-08,  7.72352439e-07],
        [-2.28958680e-05, -2.16321145e-08,  2.66818954e-08,
          2.20594106e-05, -1.30941200e-07],
        [ 6.31236183e-07,  3.10284862e-08,  7.72352439e-07,
         -1.30941200e-07,  4.64907896e-06]]),
 array([[ 4.00043369e-05,  4.37152352e-08,  6.47311303e-08,
         -2.53433687e-08, -7.02212210e-08],
        [ 4.37152352e-08,  3.98122909e-05,  2.14458518e-07,
         -1.65043100e-07,  1.00853421e-07],
        [ 6.47311303e-08,  2.14458518e-07