# NB GLM test

In [323]:
import multiprocessing
import warnings
from math import floor
from pathlib import Path
from typing import List
from typing import Literal
from typing import Optional
from typing import Tuple
from typing import Union
from typing import cast

In [324]:
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from scipy.linalg import solve  # type: ignore
from scipy.optimize import minimize  # type: ignore
from scipy.special import gammaln  # type: ignore
from scipy.special import polygamma  # type: ignore
from scipy.stats import norm  # type: ignore
from sklearn.linear_model import LinearRegression  # type: ignore

In [325]:
#generate count data for a single gene (vector)
counts = [5071, 5010, 1623, 1851, 1769, 2065, 1412, 1188, 1668, 1674, 949, 1104, 841, 1061, 852, 1108]
counts = np.array(counts)
counts 

array([5071, 5010, 1623, 1851, 1769, 2065, 1412, 1188, 1668, 1674,  949,
       1104,  841, 1061,  852, 1108])

In [326]:
#generate CNV data
cnv = [4, 5, 3, 4, 4, 3, 4, 4, 2, 2, 2, 2, 2, 2, 2, 2]
cnv = np.array(cnv)
cnv = cnv/2
cnv = cnv + 10e-9
cnv

array([2.00000001, 2.50000001, 1.50000001, 2.00000001, 2.00000001,
       1.50000001, 2.00000001, 2.00000001, 1.00000001, 1.00000001,
       1.00000001, 1.00000001, 1.00000001, 1.00000001, 1.00000001,
       1.00000001])

In [327]:
#Generate dispersion data
#dispersions = np.array([0.1989177, 1.3452317, 0.2803645, 0.2750281, 0.1786459, 0.3633964, 0.4704590, 0.3415177, 0.2745592, 0.4313849])
disp = 1.3452317
alpha = disp
alpha

1.3452317

In [328]:
#Generate real calculated sf data
size_factors = np.array([1.16950695, 1.08637265, 1.15012341, 1.19435402, 1.08791923, 1.17677969, 1.16581881,
       0.77746277, 1.17488973, 0.91116436, 0.8980075 , 1.06022032, 0.90552191, 0.9371684 , 0.78716591, 0.99839825])
size_factors

array([1.16950695, 1.08637265, 1.15012341, 1.19435402, 1.08791923,
       1.17677969, 1.16581881, 0.77746277, 1.17488973, 0.91116436,
       0.8980075 , 1.06022032, 0.90552191, 0.9371684 , 0.78716591,
       0.99839825])

In [137]:
#Generate design matrix
#X = np.array([1, 0])
#X = np.repeat(X, [8, 8], axis=0)
#X.T

array([1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0])

In [329]:
X = {'condition': [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]}
X

{'condition': [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]}

In [330]:
X = pd.DataFrame(X, index = ['sample1', 'sample2', 'sample3', 'sample4', 'sample5', 'sample6', 'sample7', 'sample8', 'sample9', 'sample10',
                            'sample11', 'sample12', 'sample13', 'sample14', 'sample15', 'sample16'])
X.insert(0, "intercept", 1)
X

Unnamed: 0,intercept,condition
sample1,1,0
sample2,1,0
sample3,1,0
sample4,1,0
sample5,1,0
sample6,1,0
sample7,1,0
sample8,1,0
sample9,1,1
sample10,1,1


In [331]:
X = np.array(X)
X

array([[1, 0],
       [1, 0],
       [1, 0],
       [1, 0],
       [1, 0],
       [1, 0],
       [1, 0],
       [1, 0],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1]])

In [332]:
design_matrix = X

In [333]:
num_vars = design_matrix.shape[1]
num_vars

2

In [205]:
Q, R = np.linalg.qr(X)
Q, R

(array([[-0.25, -0.25],
        [-0.25, -0.25],
        [-0.25, -0.25],
        [-0.25, -0.25],
        [-0.25, -0.25],
        [-0.25, -0.25],
        [-0.25, -0.25],
        [-0.25, -0.25],
        [-0.25,  0.25],
        [-0.25,  0.25],
        [-0.25,  0.25],
        [-0.25,  0.25],
        [-0.25,  0.25],
        [-0.25,  0.25],
        [-0.25,  0.25],
        [-0.25,  0.25]]),
 array([[-4., -2.],
        [ 0.,  2.]]))

In [274]:
y = np.log(counts / size_factors + 0.1)
y

array([8.37473413, 8.43636858, 7.25223318, 7.34594037, 7.39396428,
       7.47016086, 7.0994213 , 7.33181146, 7.25827672, 7.51605766,
       6.96308028, 6.94831452, 6.83394313, 7.03194775, 6.98699515,
       7.01200501])

In [275]:
beta_init = solve(R, Q.T @ y)
beta = beta_init
beta

array([ 7.58807927, -0.51925174])

In [276]:
ridge_factor = np.diag(np.repeat(1e-6, num_vars))
ridge_factor

array([[1.e-06, 0.e+00],
       [0.e+00, 1.e-06]])

In [277]:
# Classical model
mu = np.maximum(size_factors * np.exp(X @ beta), min_mu)
mu

array([2309.21176602, 2145.06164812, 2270.93862995, 2358.27273688,
       2148.11539717, 2323.57191734, 2301.92946959, 1535.11372987,
       1380.22502252, 1070.40841127, 1054.95212892, 1245.51485785,
       1063.77983117, 1100.95717323,  924.73877175, 1172.88815444])

In [278]:
# Analitical IRLS algorithm calculations check

W = mu / (1.0 + mu * disp)
W

array([0.74312714, 0.74310884, 0.74312311, 0.74313212, 0.74310921,
       0.74312862, 0.74312639, 0.74300657, 0.74296622, 0.74285048,
       0.74284293, 0.74292296, 0.74284727, 0.74286478, 0.74276928,
       0.74289553])

In [280]:
z = np.log(mu / size_factors) + (counts - mu) / mu
z

array([8.78406669, 8.92367648, 7.30276173, 7.37297576, 7.41159182,
       7.47679719, 7.20147773, 7.3619633 , 7.27732612, 7.63271658,
       6.96839441, 6.95520796, 6.85940465, 7.03253441, 6.9901688 ,
       7.01350413])

In [281]:
H = (X.T * W) @ X + ridge_factor
H

array([[11.88782244,  5.94295944],
       [ 5.94295944,  5.94296044]])

In [282]:
beta_hat = solve(H, X.T @ (W * z), assume_a="pos")
beta_hat

array([ 7.72941716, -0.63825702])

### CN normalized model

In [301]:
#cn = np.log(cnv)
cnv

array([2.00000001, 2.50000001, 1.50000001, 2.00000001, 2.00000001,
       1.50000001, 2.00000001, 2.00000001, 1.00000001, 1.00000001,
       1.00000001, 1.00000001, 1.00000001, 1.00000001, 1.00000001,
       1.00000001])

In [334]:
X_t = X.T
X_t

array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
       [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]])

In [286]:
min_mu = 0.5
mu = np.maximum(size_factors * np.exp(X_t[0] * beta[0] * cnv + X_t[1] * beta[1]), min_mu)
mu # controllare se devo passare log(cn)

array([4.55957905e+06, 1.88205037e+08, 1.00910395e+05, 4.65645079e+06,
       4.24149145e+06, 1.03249184e+05, 4.54520003e+06, 3.03110893e+06,
       1.38022513e+03, 1.07040849e+03, 1.05495221e+03, 1.24551495e+03,
       1.06377991e+03, 1.10095726e+03, 9.24738842e+02, 1.17288824e+03])

In [287]:
W = mu / (1.0 + mu * disp)
W

array([0.74336624, 0.74336636, 0.74336089, 0.74336625, 0.74336624,
       0.74336101, 0.74336624, 0.74336618, 0.74296622, 0.74285048,
       0.74284293, 0.74292296, 0.74284727, 0.74286478, 0.74276928,
       0.74289553])

In [288]:
z = np.log(mu / size_factors) + (counts - mu) / mu
z

array([14.17727078, 17.97022487, 10.39820255, 14.17655613, 14.17657568,
       10.40211914, 14.17646927, 14.17655055,  7.27732611,  7.63271654,
        6.96839442,  6.95520797,  6.85940467,  7.03253441,  6.99016881,
        7.01350414])

In [289]:
H = (X.T * W) @ X + ridge_factor
H

array([[11.88987986,  5.94295944],
       [ 5.94295944,  5.94296044]])

In [290]:
beta_hat = solve(H, X.T @ (W * z), assume_a="pos")
beta_hat

array([13.70674867, -6.61558752])

### Test classical GLM

In [362]:
def nb_nll(
    counts: np.ndarray, mu: np.ndarray, alpha: Union[float, np.ndarray]
) -> Union[float, np.ndarray]:
    n = len(counts)
    alpha_neg1 = 1 / alpha
    logbinom = gammaln(counts + alpha_neg1) - gammaln(counts + 1) - gammaln(alpha_neg1)
    if hasattr(alpha, "__len__") and len(alpha) > 1:
        return (
            alpha_neg1 * np.log(alpha)
            - logbinom
            + (counts + alpha_neg1) * np.log(mu + alpha_neg1)
            - (counts * np.log(mu))
        ).sum(0)
    else:
        return (
            n * alpha_neg1 * np.log(alpha)
            + (
                -logbinom
                + (counts + alpha_neg1) * np.log(alpha_neg1 + mu)
                - counts * np.log(mu)
            ).sum()
        )

In [363]:
def vec_nb_nll(counts: np.ndarray, mu: np.ndarray, alpha: np.ndarray) -> np.ndarray:
    n = len(counts)
    alpha_neg1 = 1 / alpha
    logbinom = (
        gammaln(counts[:, None] + alpha_neg1)
        - gammaln(counts + 1)[:, None]
        - gammaln(alpha_neg1)
    )

    if len(mu.shape) == 1:
        return n * alpha_neg1 * np.log(alpha) + (
            -logbinom
            + (counts[:, None] + alpha_neg1) * np.log(mu[:, None] + alpha_neg1)
            - (counts * np.log(mu))[:, None]
        ).sum(0)
    else:
        return n * alpha_neg1 * np.log(alpha) + (
            -logbinom
            + (counts[:, None] + alpha_neg1) * np.log(mu + alpha_neg1)
            - (counts[:, None] * np.log(mu))
        ).sum(0)

In [364]:
def grid_fit_beta(
    counts: np.ndarray,
    size_factors: np.ndarray,
    cnv: np.ndarray,
    design_matrix: np.ndarray,
    disp: float,
    min_mu: float = 0.5,
    grid_length: int = 60,
    min_beta: float = -30,
    max_beta: float = 30,
) -> np.ndarray:
    x_grid = np.linspace(min_beta, max_beta, grid_length)
    y_grid = np.linspace(min_beta, max_beta, grid_length)
    ll_grid = np.zeros((grid_length, grid_length))

    def loss(beta: np.ndarray) -> np.ndarray:
        # closure to minimize
        design_matrix_t = design_matrix.T
        mu = np.maximum(size_factors[:, None] * np.exp(design_matrix_t[0] * beta[0] + design_matrix_t[1] * beta[1]), min_mu)
        return vec_nb_nll(counts, mu, disp) + 0.5 * (1e-6 * beta**2).sum(1)

    for i, x in enumerate(x_grid):
        ll_grid[i, :] = loss(np.array([[x, y] for y in y_grid]))

    min_idxs = np.unravel_index(np.argmin(ll_grid, axis=None), ll_grid.shape)
    delta = x_grid[1] - x_grid[0]

    fine_x_grid = np.linspace(
        x_grid[min_idxs[0]] - delta, x_grid[min_idxs[0]] + delta, grid_length
    )

    fine_y_grid = np.linspace(
        y_grid[min_idxs[1]] - delta,
        y_grid[min_idxs[1]] + delta,
        grid_length,
    )

    for i, x in enumerate(fine_x_grid):
        ll_grid[i, :] = loss(np.array([[x, y] for y in fine_y_grid]))

    min_idxs = np.unravel_index(np.argmin(ll_grid, axis=None), ll_grid.shape)
    beta = np.array([fine_x_grid[min_idxs[0]], fine_y_grid[min_idxs[1]]])
    return beta

In [376]:
def irls_solver(
    counts: np.ndarray,
    size_factors: np.ndarray,
    design_matrix: np.ndarray,
    cnv: np.ndarray,
    disp: float,
    min_mu: float = 0.5,
    beta_tol: float = 1e-8,
    min_beta: float = -30,
    max_beta: float = 30,
    optimizer: Literal["BFGS", "L-BFGS-B"] = "L-BFGS-B",
    maxiter: int = 250,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, bool]:

    assert optimizer in ["BFGS", "L-BFGS-B"]
    
    num_vars = design_matrix.shape[1]
    X = design_matrix
    
    # if full rank, estimate initial betas for IRLS below
    if np.linalg.matrix_rank(X) == num_vars:
        Q, R = np.linalg.qr(X)
        y = np.log(counts / size_factors + 0.1)
        beta_init = solve(R, Q.T @ y)
        beta = beta_init
    else:  # Initialise intercept with log base mean
        beta_init = np.zeros(num_vars)
        beta_init[0] = np.log(counts / size_factors).mean()
        beta = beta_init
        
    dev = 1000.0
    dev_ratio = 1.0

    ridge_factor = np.diag(np.repeat(1e-6, num_vars))
    #mu = np.maximum(size_factors * np.exp(X @ beta), min_mu)

    X_t = X.T
    mu = np.maximum(size_factors * np.exp(X_t[0] * beta[0] * cnv + X_t[1] * beta[1]), min_mu)
    
    converged = True
    i = 0
    while dev_ratio > beta_tol:
        W = mu / (1.0 + mu * disp)
        z = np.log(mu / size_factors) + (counts - mu) / mu
        H = (X.T * W) @ X + ridge_factor
        beta_hat = solve(H, X.T @ (W * z), assume_a="pos")
        i += 1

        if sum(np.abs(beta_hat) > max_beta) > 0 or i >= maxiter:
            # If IRLS starts diverging, use L-BFGS-B
            def f(beta: np.ndarray) -> float:
                # closure to minimize
                #mu_ = np.maximum(size_factors * np.exp(X @ beta), min_mu)
                mu_ = np.maximum(size_factors * np.exp(X_t[0] * beta[0] * cnv + X_t[1] * beta[1]), min_mu)
                
                return nb_nll(counts, mu_, disp) + 0.5 * (ridge_factor @ beta**2).sum()

            def df(beta: np.ndarray) -> np.ndarray:
                #mu_ = np.maximum(size_factors * np.exp(X @ beta), min_mu)
                mu_ = np.maximum(size_factors * np.exp(X_t[0] * beta[0] * cnv + X_t[1] * beta[1]), min_mu)
                return (
                    -X.T @ counts
                    + ((1 / disp + counts) * mu_ / (1 / disp + mu_)) @ X
                    + ridge_factor @ beta
                )

            res = minimize(
                f,
                beta_init,
                jac=df,
                method=optimizer,
                bounds=(
                    [(min_beta, max_beta)] * num_vars
                    if optimizer == "L-BFGS-B"
                    else None
                ),
            )

            beta = res.x
            #mu = np.maximum(size_factors * np.exp(X @ beta), min_mu)
            mu = np.maximum(size_factors * np.exp(X_t[0] * beta[0] * cnv + X_t[1] * beta[1]), min_mu)
            converged = res.success

            if not res.success and num_vars <= 2:
                beta = grid_fit_beta(
                    counts,
                    size_factors,
                    X,
                    disp,
                )
                #mu = np.maximum(size_factors * np.exp(X @ beta), min_mu
                mu = np.maximum(size_factors * np.exp(X_t[0] * beta[0] * cnv + X_t[1] * beta[1]), min_mu) 
            break

        beta = beta_hat
        #mu = np.maximum(size_factors * np.exp(X @ beta), min_mu)
        mu = np.maximum(size_factors * np.exp(X_t[0] * beta[0] * cnv + X_t[1] * beta[1]), min_mu)
        # Compute deviation
        old_dev = dev
        # Replaced deviation with -2 * nll, as in the R code
        dev = -2 * nb_nll(counts, mu, disp)
        dev_ratio = np.abs(dev - old_dev) / (np.abs(dev) + 0.1)

    # Compute H diagonal (useful for Cook distance outlier filtering)
    W = mu / (1.0 + mu * disp)
    W_sq = np.sqrt(W)
    XtWX = (X.T * W) @ X + ridge_factor
    H = W_sq * np.diag(X @ np.linalg.inv(XtWX) @ X.T) * W_sq
    # Return an UNthresholded mu (as in the R code)
    # Previous quantities are estimated with a threshold though
    #mu = size_factors * np.exp(X @ beta)
    mu = size_factors * np.exp(X_t[0] * beta[0] * cnv + X_t[1] * beta[1])

    print("Beta parameters:", beta), 
    print("Estimated mean:", np.array(mu)), 
    print("H:", np.array(H)),
    print("Convergence:", converged)
    
    return beta, mu, H, converged

In [300]:
# Classical GLM
irls_solver(counts, size_factors, design_matrix, cnv, disp)

Beta parameters: [ 7.72027975 -0.62936539]
Estimated mean: [2635.58908755 2448.2384661  2591.90653697 2691.58419434 2451.72382321
 2651.97886119 2627.27753238 1752.08227073 1411.04896746 1094.3133611
 1078.51190055 1273.33038124 1087.53674791 1125.54435483  945.39054701
 1199.08173831]
H: [0.12500282 0.12500012 0.12500222 0.12500355 0.12500017 0.12500303
 0.1250027  0.12498505 0.12501583 0.12499678 0.12499553 0.12500871
 0.12499625 0.12499913 0.12498341 0.12500419]
Convergence: True


(array([ 7.72027975, -0.62936539]),
 array([2635.58908755, 2448.2384661 , 2591.90653697, 2691.58419434,
        2451.72382321, 2651.97886119, 2627.27753238, 1752.08227073,
        1411.04896746, 1094.3133611 , 1078.51190055, 1273.33038124,
        1087.53674791, 1125.54435483,  945.39054701, 1199.08173831]),
 array([0.12500282, 0.12500012, 0.12500222, 0.12500355, 0.12500017,
        0.12500303, 0.1250027 , 0.12498505, 0.12501583, 0.12499678,
        0.12499553, 0.12500871, 0.12499625, 0.12499913, 0.12498341,
        0.12500419]),
 True)

In [377]:
# CN normalized GLM
irls_solver(counts, size_factors, design_matrix, cnv, disp)

Beta parameters: [4.20093517 2.88637762]
Estimated mean: [ 5210.61215201 39544.51424337   627.2041216   5321.31559408
  4847.10686007   641.74076048  5194.18004179  3463.90156732
  1405.97617152  1090.37924648  1074.63459302  1268.75269093
  1083.62699557  1121.49796313   941.99181781  1194.77097582]
H: [0.1250314  0.12504689 0.12490121 0.12503178 0.12503007 0.12490456
 0.12503135 0.12502241 0.12501588 0.12499677 0.12499552 0.12500874
 0.12499624 0.12499913 0.12498335 0.12500421]
Convergence: True


(array([4.20093517, 2.88637762]),
 array([ 5210.61215201, 39544.51424337,   627.2041216 ,  5321.31559408,
         4847.10686007,   641.74076048,  5194.18004179,  3463.90156732,
         1405.97617152,  1090.37924648,  1074.63459302,  1268.75269093,
         1083.62699557,  1121.49796313,   941.99181781,  1194.77097582]),
 array([0.1250314 , 0.12504689, 0.12490121, 0.12503178, 0.12503007,
        0.12490456, 0.12503135, 0.12502241, 0.12501588, 0.12499677,
        0.12499552, 0.12500874, 0.12499624, 0.12499913, 0.12498335,
        0.12500421]),
 True)