In [1]:
import numpy as np
import pandas as pd
from joblib import Parallel, delayed
import warnings
from sklearn.exceptions import ConvergenceWarning

In [2]:
# Phenotype table, indexed by its first column (sample ids, later matched against data.index).
pheno = pd.read_table('bladder-pheno.txt', index_col=0)
# Expression table is transposed so that rows are samples and columns are features.
data = pd.read_table('bladder-expr.txt', index_col=0).transpose()
batch = pheno.loc[:, 'batch']
covars = pheno.loc[:, ['age', 'cancer']]

In [3]:
# Batch indicators: one column per batch and no intercept, so together they span
# the per-batch means. dtype=float keeps the design numeric: since pandas 2.0
# get_dummies defaults to bool columns, and a frame mixing bool with numeric
# covariates converts to an object array that np.linalg.solve rejects.
design = pd.get_dummies(batch.loc[data.index], prefix='_batch', dtype=float)
n_array, n_batch = design.shape

# Covariates aligned to the same sample order as `data`. Numeric columns pass
# through unchanged; object/categorical columns are dummy-coded with the first
# level dropped, because the batch indicators already sum to an intercept.
design = design.join(
    pd.get_dummies(covars.loc[data.index], drop_first=True, dtype=float)
)
# A missing covariate value would silently turn every fitted coefficient into NaN.
assert not design.isna().to_numpy().any(), 'design matrix contains missing values'

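* Y (n, p): data (n arrays × p features)
* X (n, c): design (the b batch indicators followed by the covariate columns)
* B_hat (c, p): least-squares coefficients, `X B_hat ~ Y`
* var_pooled (p,): per-feature mean over all arrays of the squared residuals `(Y - X B_hat)^2`
* stand_mean (n, p): batch-size-weighted mean of the batch rows of `B_hat`, plus `X_cov B_hat_cov`
* Z (n, p): standardized data, `(Y - stand_mean) / sqrt(var_pooled)`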

In [4]:
# Cast explicitly to float. pd.get_dummies produces bool columns in pandas >= 2,
# and a frame mixing bool and numeric columns converts to an object array, which
# the np.linalg routines below cannot handle.
Y = data.to_numpy(dtype=float)
X = design.to_numpy(dtype=float)
# The first n_batch columns are the batch indicators; the rest are covariates.
X_batch, X_cov = X[:, :n_batch], X[:, n_batch:]
n_batches = X_batch.sum(axis=0)                       # number of arrays in each batch
batches = [np.flatnonzero(col) for col in X_batch.T]  # row indices belonging to each batch

In [5]:
# Per-feature ordinary least squares of Y on the full design (normal equations).
B_hat = np.linalg.solve(X.T @ X, X.T @ Y)
# Pooled residual variance of each feature: mean of squared residuals over all arrays.
var_pooled = ((Y - X @ B_hat) ** 2).mean(axis=0)
# Grand mean (batch coefficients weighted by batch size) plus the covariate effects.
stand_mean = (n_batches / n_array) @ B_hat[:n_batch, :] + X_cov @ B_hat[n_batch:, :]
# NOTE(review): a feature with var_pooled == 0 yields inf/NaN columns in Z.
Z = (Y - stand_mean) / np.sqrt(var_pooled)

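* gamma_hat (b, p): additive batch effects, `X_batch gamma_hat ~ Z`
* delta_hat (b, p): per-batch variance of `Z` (multiplicative batch effects)
* gamma_bar, tau_bar2 (b,): prior `gamma_hat ~ N(gamma_bar, tau_bar2)`
* lambda_bar, theta_bar (b,): prior `delta_hat ~ InvGamma(lambda_bar, theta_bar)` (shape, scale), fitted by the method of moments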

In [6]:
# Additive batch effects: regress the standardized data on the batch indicators alone.
gamma_hat = np.linalg.solve(
    X_batch.T @ X_batch,
    X_batch.T @ Z,
)
# Normal prior per batch: mean and sample variance (ddof=1) across features.
gamma_bar = np.mean(gamma_hat, axis=1)
tau_bar2 = np.var(gamma_hat, axis=1, ddof=1)

In [7]:
def _lambda_invgamma(mean, var):
    return mean**2 / var + 2

def _theta_invgamma(mean, var):
    return mean**3 / var + mean

# Multiplicative batch effects: sample variance (ddof=1) of each feature within each batch.
delta_hat = np.array([Z[rows].var(axis=0, ddof=1) for rows in batches])
# Inverse-gamma prior per batch, matched to the mean and variance of delta_hat across features.
V = np.mean(delta_hat, axis=1)
S2 = np.var(delta_hat, axis=1, ddof=1)
lambda_bar, theta_bar = _lambda_invgamma(V, S2), _theta_invgamma(V, S2)

In [8]:
def _postmean(n, g_hat, g_bar, t_bar2, d_star):
    return (n*t_bar2*g_hat + d_star*g_bar) / (n*t_bar2 + d_star)

def _postvar(n, sum_sq, l_bar, th_bar):
    return (0.5*sum_sq + th_bar) / (0.5*n + l_bar - 1)

def _em_fit(z, g_hat, g_bar, t_bar2, d_hat, l_bar, th_bar,
            tol=0.0001, max_iter=100):
    n = z.shape[0]
    g_old, d_old = g_hat, d_hat
    
    for n_iter in range(1, max_iter+1):
        g_new = _postmean(n, g_hat, g_bar, t_bar2, d_old)
        d_new = _postvar(n, ((z - g_new)**2).sum(axis=0), l_bar, th_bar)
        
        change = max((np.abs(g_new - g_old) / g_old).max(),
                     (np.abs(d_new - d_old) / d_old).max())
        if change < tol:
            converged = True
            break
        g_old = g_new
        d_old = d_new
    
    if not converged:
        warnings.warn('Batch did not converge!', ConvergenceWarning)
    return {'gamma': g_new, 'delta': d_new, 'n_iter': n_iter}

In [9]:
# Fit the empirical-Bayes estimates for each batch independently, in parallel (n_jobs=4).
batch_fits = Parallel(n_jobs=4)(
    delayed(_em_fit)(
        Z[rows], gamma_hat[i], gamma_bar[i], tau_bar2[i],
        delta_hat[i], lambda_bar[i], theta_bar[i],
        tol=0.0001, max_iter=100,
    )
    for i, rows in enumerate(batches)
)
# Stack the per-batch posterior estimates into (b, p) arrays.
gamma_star = np.array([fit['gamma'] for fit in batch_fits])
delta_star = np.array([fit['delta'] for fit in batch_fits])

In [10]:
# Remove the estimated batch effects on the standardized scale, then restore the
# pooled per-feature scale and the standardization mean.
adjusted = (
    (Z - X_batch @ gamma_star)
    * np.sqrt(var_pooled / (X_batch @ delta_star))
    + stand_mean
)

In [11]:
# Wrap the adjusted matrix in a frame with the same sample index and feature columns as `data`.
r = pd.DataFrame(adjusted, columns=data.columns, index=data.index)