## Playground for random effect integration in sjSDM
Dependencies + sjSDM simulation including:

* random intercepts per sites
* random intercepts per sites and species
* random intercepts following CAR model

In [1]:
import torch
import numpy as np
import numpy as np, numpy.linalg
import matplotlib.pyplot as plt
from dataclasses import dataclass
import statsmodels.api as sm
import statsmodels.formula.api as smf
import pandas as pd
from scipy.optimize import minimize
from dataclasses import dataclass
import numpy as np

@dataclass
class Simulation:
    X: np.ndarray
    Y: np.ndarray
    Sigma: np.ndarray
    W: np.ndarray
    scale_re: float
    g: np.ndarray
    G: float
    re: np.ndarray
    D: np.ndarray
    

def correlation_from_covariance(covariance):
    v = np.sqrt(np.diag(covariance))
    outer_v = np.outer(v, v)
    correlation = covariance / outer_v
    correlation[covariance == 0] = 0
    return correlation

def _getAplus(A):
    eigval, eigvec = np.linalg.eig(A)
    Q = np.matrix(eigvec)
    xdiag = np.matrix(np.diag(np.maximum(eigval, 0)))
    return Q*xdiag*Q.T

def _getPs(A, W=None):
    W05 = np.matrix(W**.5)
    return  W05.I * _getAplus(W05 * A * W05) * W05.I

def _getPu(A, W=None):
    Aret = np.array(A.copy())
    Aret[W > 0] = np.array(W)[W > 0]
    return np.matrix(Aret)

def nearPD(A, nit=50):
    n = A.shape[0]
    W = np.identity(n) 
# W is the matrix used for the norm (assumed to be Identity matrix here)
# the algorithm should work for any diagonal W
    deltaS = 0
    Yk = A.copy()
    for k in range(nit):
        Rk = Yk - deltaS
        Xk = _getPs(Rk, W=W)
        deltaS = Xk - Rk
        Yk = _getPu(Xk, W=W)
    return Yk

def simulate_CAR(sides = 10, l = 0.7):
    row_coords = np.tile(np.arange(0, sides), sides).reshape([-1, 1])
    col_coords = np.tile(np.arange(0, sides), sides).reshape([-1, 1])
    d = np.concatenate([row_coords, col_coords], axis= 1)
    D = np.sum((d[:, np.newaxis, :] - d[np.newaxis, :, :]) ** 2, axis = -1)
    Sigma = nearPD(np.exp(-l*D))+np.eye(sides**2)*0.01
    re= np.random.multivariate_normal(np.zeros([sides**2]), cov=Sigma,size=[1])
    return D, Sigma, re

def simulate(N=100, E=2, SP=3, G = 20, scale_re = 0.5, CAR=False):
    if CAR is not True:
        re = np.random.normal(0, scale_re, [G,1])
        g = np.repeat(np.arange(0,G), np.round(N/G))
        D = np.NaN
    else:
        D, Sigma, re = simulate_CAR(G, scale_re)
        g = np.arange(0, G**2)
        re = re.reshape([-1, 1])
    
    X = np.random.uniform(-1, 1, size=[N,E] )
    W = np.random.normal(size=[E,SP])
    Y = X@W + np.reshape(re[g,:], newshape=[N,1])
    
    SS = np.random.uniform(-1,1,size=[SP, SP])
    SS = correlation_from_covariance(SS@np.transpose(SS))
    YY = np.concatenate([np.random.multivariate_normal(Y[i,:], cov = SS, size=[1,]) for i in range(N)], 0)
    YY = YY > 0
    YY = YY.astype(np.float64)
    return Simulation(X, YY, SS, W, scale_re, g, G, re, D)

`fit_model` estimates a MVP jSDM based on the sjSDM approach and returns the standard deviation of the random effects and the covariance accuracy

In [12]:
def fit_model(data, 
              outer_epochs=12, 
              inner_epochs=3, 
              device = "cpu:0", 
              optim = "torch", 
              likelihood_type="mvp", 
              det=True,
              CAR=False
             ):
    E = data.X.shape[1]
    N = data.X.shape[0]
    SP = data.Y.shape[1]
    X, Y, G, indices, D = data.X, data.Y, data.G, data.g, data.D
    if CAR is True:
        G = G**2
    dev = torch.device(device)
    XT = torch.tensor(X, dtype=torch.float32, device=dev)
    YT = torch.tensor(Y, dtype=torch.float32, device=dev)
    torch.autograd.set_detect_anomaly(True)
    W = torch.tensor(np.random.normal(0., 0.001, size=(XT.shape[1],Y.shape[1])), dtype=torch.float32, requires_grad=True, device=dev)
    r_dim = Y.shape[1]
    df = int(np.rint(Y.shape[1]/2))
    low = -np.sqrt(6.0/(r_dim+df)) # type: ignore
    high = np.sqrt(6.0/(r_dim+df)) # type: ignore      
    sigma = torch.tensor(np.random.uniform(low, high, [r_dim, df]), requires_grad = True, dtype=torch.float32, device=dev) # type: ignore

    @torch.jit.script
    def likelihood(mu: torch.Tensor, Ys: torch.Tensor, sigma: torch.Tensor, batch_size: int, sampling: int, df: int, alpha: float, device: str, dtype: torch.dtype):
        noise = torch.randn(size = [sampling, batch_size, df], device=torch.device(device), dtype=dtype)
        E = torch.sigmoid(   torch.einsum("ijk, lk -> ijl", [noise, sigma]).add(mu).mul(alpha)   ).mul(0.999999).add(0.0000005)
        logprob = E.log().mul(Ys).add((1.0 - E).log().mul(1.0 - Ys)).neg().sum(dim = 2).neg()
        maxlogprob = logprob.max(dim = 0).values
        Eprob = logprob.sub(maxlogprob).exp().mean(dim = 0)
        return Eprob.log().neg().sub(maxlogprob)

    scale_log = torch.tensor(np.random.normal(0.0,0.001, [1]), dtype=torch.float32, requires_grad=True, device=dev)
    res = torch.tensor(np.random.normal(0.0,0.001, [G, 1]), dtype=torch.float32, requires_grad=True, device=dev)
    zero_intercept = torch.zeros([1], dtype=torch.float32, device=dev)
    zero_CAR = torch.zeros([G], dtype=torch.float32, device=dev)

    opt1 = torch.optim.RMSprop([W, scale_log, sigma],lr=0.01)

    if CAR is True:
        D = torch.tensor(D, dtype=torch.float32, device=dev)
    opt2 = torch.optim.LBFGS([res], lr = 0.1)
    
    const_val = torch.tensor(0.0, dtype=torch.float32, device=dev)
    const_cov = torch.eye(G, dtype=torch.float32, device=dev)*0.01

    soft = lambda t: torch.nn.functional.softplus(t)+0.0001

    indices_T = torch.tensor(indices, dtype=torch.long)
    
    def re_loss():
        if CAR is not True:
            return -torch.distributions.Normal(zero_intercept, soft(scale_log)).log_prob(res).sum()
        else:
            return -torch.distributions.MultivariateNormal(zero_CAR, (-soft(scale_log)*D).exp()+const_cov).log_prob(res.reshape([1,-1])).sum()

    
    def ll(res, W,sigma, XT, YT, indices_T):
        pred = XT@W+res[indices_T,:]#*scale_log.exp()
        #loss = -torch.distributions.Normal(loc=pred, scale=soft(scale_log_2) ).log_prob(YT).sum()
        if likelihood_type == "mvp":
            loss = likelihood(pred, YT, sigma, XT.shape[0], 100, df, 1.7012, device, torch.float32).sum()
        else:
            loss = -torch.distributions.Binomial(total_count=1, probs= torch.sigmoid(pred*1.7012) ).log_prob(YT).sum()
        loss += re_loss()
        return loss
    
    def torch_optim(epoch):
        if epoch % inner_epochs == 0:
            for _ in range(20):
                opt2.zero_grad()
                loss = ll(res, W,sigma, XT, YT, indices_T)
                loss.backward()
                opt2.step(lambda: ll(res, W,sigma, XT, YT, indices_T))
            opt2.zero_grad()
    
    def minimize_func(res2):
        res2 = torch.tensor(res2, dtype=torch.float32, device=dev).reshape([G, 1])
        loss = ll(res2, W,sigma, XT, YT, indices_T)
        return loss.cpu().data.numpy()
    
    def hessian(res2):
        res2 = torch.tensor(res2, dtype=torch.float32, requires_grad=True, device=dev)
        loss = ll(res2.reshape([G, 1]), W, sigma, XT, YT, indices_T)
        loss.backward()
        return res2.grad.cpu().data.numpy()
    
    def scipy_optim(res, epoch, const_val):
        if epoch % inner_epochs == 0:
            res_new = minimize(
                        minimize_func, 
                        res.cpu().data.numpy().reshape([-1]), 
                        jac=hessian, 
                        method="BFGS"
                    )
            res = torch.tensor(res_new.x, dtype=torch.float32, requires_grad=True, device=dev).reshape([G, 1])
            #const_val = torch.tensor( 0.5*(np.log((2*np.pi)**(res_new.hess_inv.shape[0])) - np.log(np.linalg.det(res_new.hess_inv))), dtype=torch.float32, device=dev)
            #print(const_val)
        return res, const_val
    
    if det is not True:
        const_val = torch.tensor(0.0, dtype=torch.float32, device=dev)
    for epoch in range(outer_epochs):
        if optim == "torch": 
            torch_optim(epoch)
        else:
            res, const_val = scipy_optim(res, epoch, const_val)
        
        sample_indices = np.random.randint(0, XT.shape[0], size = 20)

        opt1.zero_grad()        
        pred = XT[sample_indices,:]@W+res[indices,:][sample_indices,:]#*scale_log.exp()
        loss = likelihood(pred, YT[sample_indices,:], sigma, pred.shape[0], 100, df, 1.7012, device, torch.float32).mean() + const_val
        loss += re_loss()
        loss.backward()
        opt1.step()
        opt1.zero_grad()
    avg_scale = soft(scale_log).mean().cpu().data.numpy()
    avg_acc = np.mean((np.sign(correlation_from_covariance((sigma @ sigma.t()).data.cpu().numpy())) == np.sign(data.Sigma) )[np.triu_indices(SP)])
    del opt1
    del W
    del XT
    del YT
    del sigma
    return avg_scale.tolist(), avg_acc

In [11]:
data = simulate(SP=2,N=5000,scale_re=0.5,G = 100, CAR=False)
#fit_model(data, CAR=True)
data.re

array([[-0.40456842],
       [-0.47768778],
       [-0.57108772],
       [-0.21552874],
       [ 0.71881768],
       [ 0.31674885],
       [ 1.42935109],
       [-0.69193213],
       [ 0.18393983],
       [ 0.34073885],
       [ 0.20831082],
       [ 0.05193391],
       [-0.06783768],
       [-0.24047788],
       [ 0.3399139 ],
       [-0.52989648],
       [-0.02116909],
       [ 0.23199976],
       [ 0.02096044],
       [ 0.3455103 ],
       [ 0.06003734],
       [-0.93551233],
       [-0.39061279],
       [-0.34171041],
       [ 0.65986416],
       [ 0.7608267 ],
       [-0.36777635],
       [-1.2574748 ],
       [-0.25170756],
       [-0.3295159 ],
       [-0.31556222],
       [-0.77594736],
       [-0.14177741],
       [-0.52532314],
       [-0.09569401],
       [ 1.15591509],
       [-0.08016586],
       [ 0.82441512],
       [ 0.17998155],
       [ 1.08525601],
       [ 0.14132331],
       [ 0.65178794],
       [-0.41840465],
       [-0.45944114],
       [ 0.04853268],
       [ 0

In [12]:
fit_model(data, CAR=False, det=True)

[55.80103302001953, 21.666128158569336]
[59.56851577758789, 26.515663146972656]
[53.23011016845703, 21.068098068237305]
[48.70780944824219, 18.318222045898438]
[49.03778076171875, 20.532358169555664]
[47.15538787841797, 19.405576705932617]
[47.09665298461914, 19.003772735595703]
[48.80914306640625, 20.8721923828125]
[51.730995178222656, 24.477840423583984]
[38.06648635864258, 12.463699340820312]


[0.20226380921304, 0.458190164843836, 0.4199224385975236, 1.0]

In [75]:
def fit_model(data: Simulation, det=True, CAR=False, device = "cpu:0", batch_size=50, ll="MVP", intercept="sites", df=None) -> list:
     N, E = data.X.shape
     SP = data.Y.shape[1]
     X, Y, G, indices, D = data.X, data.Y, data.G, data.g, data.D
     dev = torch.device(device)
     r_dim = Y.shape[1]
     if df is None:
          df = int(np.rint(Y.shape[1]/2))
     low = -np.sqrt(6.0/(r_dim+df)) 
     high = np.sqrt(6.0/(r_dim+df))     
     XT = torch.tensor(X, dtype=torch.float32, device=torch.device("cpu:0"))
     YT = torch.tensor(Y, dtype=torch.float32, device=torch.device("cpu:0"))
     indices_T = torch.tensor(indices, dtype=torch.long, device=torch.device("cpu:0"))
     init_scale = 10.0
     if CAR is True:
          G = G**2
          D = torch.tensor(D, dtype=torch.float32, device=dev)
          init_scale = 0.0
     # Variables
     W = torch.tensor(np.random.normal(0.0,0.001, [XT.shape[1], YT.shape[1]]), dtype=torch.float32, device=dev, requires_grad=True)
     
     if intercept is not "species":
          scale_log = torch.tensor(1.0, dtype=torch.float32, requires_grad=True, device=dev)
          res = torch.tensor(np.random.normal(0.0,0.001, [G, 1]), dtype=torch.float32, requires_grad=True, device=dev)
     if intercept is "species":
          res = torch.tensor(np.random.normal(0.0,0.001, [1, SP]), dtype=torch.float32, requires_grad=True, device=dev)
          scale_log = torch.tensor(np.random.normal(0.0,0.001, [SP]), dtype=torch.float32, requires_grad=True, device=dev)
          init_scale = torch.ones_like(scale_log)
     sigma = torch.tensor(np.random.uniform(low, high, [r_dim, df]), requires_grad = True, dtype=torch.float32, device=dev) # type: ignore

     soft = lambda t: torch.nn.functional.softplus(t)+0.0001
     zero_intercept = torch.zeros([1], dtype=torch.float32, device=dev)
     zero_CAR = torch.zeros([G], dtype=torch.float32, device=dev)
     loss2 = torch.zeros([1], dtype=torch.float32, device=dev)
     adapt = torch.tensor(np.rint(XT.shape[0]/batch_size).tolist(), dtype=torch.float32, device=dev)


     @torch.jit.script
     def MVP(mu: torch.Tensor, Ys: torch.Tensor, sigma: torch.Tensor, batch_size: int, sampling: int, df: int, alpha: float, device: str, dtype: torch.dtype):
         noise = torch.randn(size = [sampling, batch_size, df], device=torch.device(device), dtype=dtype)
         E = torch.sigmoid(   torch.einsum("ijk, lk -> ijl", [noise, sigma]).add(mu).mul(alpha)   ).mul(0.999999).add(0.0000005)
         logprob = E.log().mul(Ys).add((1.0 - E).log().mul(1.0 - Ys)).neg().sum(dim = 2).neg()
         maxlogprob = logprob.max(dim = 0).values
         Eprob = logprob.sub(maxlogprob).exp().mean(dim = 0)
         return Eprob.log().neg().sub(maxlogprob)
    
     def Binomial(mu: torch.Tensor, Ys: torch.Tensor, sigma: torch.Tensor, batch_size: int, sampling: int, df: int, alpha: float, device: str, dtype: torch.dtype):
          return - torch.distributions.Binomial(1, mu.sigmoid()).log_prob(Ys)
     
     if ll is "MVP":
          likelihood = MVP
     else:
          likelihood = Binomial

     def ll(res, W, sigma, XT, YT, indices_T, scale_log):
          if intercept is not "species":
               pred = XT@W+res[indices_T,:]
          else:
               pred = XT@W+res
          loss = likelihood(pred, YT, sigma, XT.shape[0], 100, df, 1.7012, device, torch.float32).sum()/XT.shape[0]
          if intercept is "sites":
               loss += -torch.distributions.Normal(zero_intercept, (scale_log.exp())).log_prob(res[indices_T.unique()]).sum()/adapt/XT.shape[0]
          if intercept is "species":
               loss += -torch.distributions.LowRankMultivariateNormal(torch.zeros([YT.shape[1]], dtype=torch.float32, device=dev), sigma, scale_log.exp() ).log_prob(res).sum()/adapt/XT.shape[0]
          if CAR is True:
               ind2 = indices_T.unique()
               D_tmp = D.index_select(0, ind2).index_select(1, ind2)
               const_val = torch.eye(ind2.shape[0], device=dev, dtype=torch.float32)*0.001
               loss += -torch.distributions.MultivariateNormal(zero_CAR[ind2], (-(scale_log.exp())*D_tmp).exp()+const_val).log_prob(res[indices_T.unique()].reshape([1,-1])).sum()/adapt/XT.shape[0]
          return loss

     
     optimizer = torch.optim.Adamax([W, scale_log, sigma], lr = 0.03)
     optimizer_re = torch.optim.Adamax([res], lr = 0.03)

     dataset = torch.utils.data.TensorDataset(XT, YT, indices_T)
     dataLoader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
     for _ in range(50):
          for x, y, inds in dataLoader:
               optimizer_re.zero_grad()
               loss = ll(res, W.detach(),sigma.detach(), x.to(dev), y.to(dev), inds.to(dev),torch.tensor(init_scale, dtype=torch.float32, device=dev) )
               loss.backward(  )
               optimizer_re.step()
     optimizer_re.zero_grad()
     print(torch.std(res))

     for i in range(100):
          if i > 0:
               #print(torch.std(res))
               for x, y, inds in dataLoader:
                    optimizer_re.zero_grad()
                    loss = ll(res, W.detach(),sigma, x.to(dev), y.to(dev), inds.to(dev), scale_log.detach())#/x.shape[0]
                    loss.backward()
                    optimizer_re.step()
          optimizer_re.zero_grad()

          for x, y, inds in dataLoader:
               optimizer.zero_grad()
               loss = ll(res, W, sigma, x.to(dev), y.to(dev), inds.to(dev), scale_log)#/x.shape[0]
               if det is True:
                    loss.backward(  retain_graph=True )
                    gg = torch.autograd.grad(loss, res, retain_graph=True, create_graph=True)[0]
                    gg=gg[gg.nonzero(as_tuple=True)].reshape([-1,1])
                    logDA = torch.reciprocal(gg**2*gg.shape[0]).sqrt().reshape([-1]).diag().inverse().logdet()
                    #hess = torch.autograd.functional.hessian(lambda res: ll(res, W, sigma, x.to(dev), y.to(dev), inds.to(dev),scale_log), res, create_graph=True).squeeze()
                    #ind2 = inds.to(dev).unique()
                    #D_tmp = hess.index_select(0, ind2).index_select(1, ind2)
                    #const_val = torch.eye(ind2.shape[0], device=dev, dtype=torch.float32)*0.01
                    #logDA=(D_tmp+const_val).inverse().logdet()
                    loss2 = ((gg.shape[0]*0.5*torch.log((2*torch.tensor(3.14, dtype=torch.float32, device=dev))) - 0.5*logDA))/adapt/x.shape[0]
                    #loss2 = gg.shape[0]*0.5*torch.log((2*torch.tensor(3.14, dtype=torch.float32))) - 0.5*logDA
                    loss+=loss2
               loss = loss
               loss.backward()
               optimizer.step()

          optimizer.zero_grad()

          #if i % 10 == 0:
          #     print([loss.item(), loss2.item()])
               
     return [(scale_log.exp().mean()).cpu().data.numpy().tolist(), 
             np.mean((np.sign(correlation_from_covariance((sigma @ sigma.t()).data.cpu().numpy())) == np.sign(data.Sigma) )[np.triu_indices(SP)]), 
             correlation_from_covariance((sigma @ sigma.t() + scale_log.exp()).data.cpu().numpy()) ]

In [69]:
data = simulate(SP = 100, N = 1000  )

In [79]:
fit_model(data, ll="binomial", intercept="species", df = 40, batch_size = 200, det = False)



tensor(0.0823, grad_fn=<StdBackward0>)


[8.509558028890751e-06,
 0.5134653465346535,
 array([[1.        , 0.27305794, 0.3846035 , ..., 0.5822861 , 0.5183367 ,
         0.31661004],
        [0.2729554 , 1.        , 0.69973063, ..., 0.5865577 , 0.51317865,
         0.41813707],
        [0.38451353, 0.6997175 , 0.9999999 , ..., 0.7276242 , 0.624743  ,
         0.5473529 ],
        ...,
        [0.58230746, 0.58664626, 0.7276993 , ..., 1.0000001 , 0.58745205,
         0.64994854],
        [0.5183621 , 0.5132928 , 0.62483996, ..., 0.58745027, 1.        ,
         0.6510153 ],
        [0.316561  , 0.41817302, 0.5473907 , ..., 0.64989865, 0.65095186,
         1.0000001 ]], dtype=float32)]

In [39]:
torch.distributions.LowRankMultivariateNormal(torch.zeros(5), torch.ones([5,2]), torch.ones([5])).sample([1])

tensor([[ 0.1462, -0.2108,  1.5039, -0.3509,  0.9086]])