In [697]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from dataclasses import dataclass
import statsmodels.api as sm
import statsmodels.formula.api as smf
import pandas as pd
from scipy.optimize import minimize
from dataclasses import dataclass
import numpy as np
@dataclass
class Simulation:
    X: np.ndarray
    Y: np.ndarray
    Sigma: np.ndarray
    W: np.ndarray
    scale_re: float
    g: np.ndarray
    G: float
    re: np.ndarray
    D: np.ndarray
    

def correlation_from_covariance(covariance):
    v = np.sqrt(np.diag(covariance))
    outer_v = np.outer(v, v)
    correlation = covariance / outer_v
    correlation[covariance == 0] = 0
    return correlation

import numpy as np,numpy.linalg

def _getAplus(A):
    eigval, eigvec = np.linalg.eig(A)
    Q = np.matrix(eigvec)
    xdiag = np.matrix(np.diag(np.maximum(eigval, 0)))
    return Q*xdiag*Q.T

def _getPs(A, W=None):
    W05 = np.matrix(W**.5)
    return  W05.I * _getAplus(W05 * A * W05) * W05.I

def _getPu(A, W=None):
    Aret = np.array(A.copy())
    Aret[W > 0] = np.array(W)[W > 0]
    return np.matrix(Aret)

def nearPD(A, nit=50):
    n = A.shape[0]
    W = np.identity(n) 
# W is the matrix used for the norm (assumed to be Identity matrix here)
# the algorithm should work for any diagonal W
    deltaS = 0
    Yk = A.copy()
    for k in range(nit):
        Rk = Yk - deltaS
        Xk = _getPs(Rk, W=W)
        deltaS = Xk - Rk
        Yk = _getPu(Xk, W=W)
    return Yk

def simulate(N=100, E=2, SP=3, G = 20, scale_re = 0.5, CAR=False):
    if CAR is not True:
        re = np.random.normal(0, scale_re, [G,1])
        g = np.repeat(np.arange(0,G), np.round(N/G))
        D = np.NaN
    else:
        D, Sigma, re = simulate_CAR(G, scale_re)
        g = np.arange(0, G**2)
        re = re.reshape([-1, 1])
    
    X = np.random.uniform(-1, 1, size=[N,E] )
    W = np.random.normal(size=[E,SP])
    Y = X@W + np.reshape(re[g,:], newshape=[N,1])
    
    SS = np.random.uniform(-1,1,size=[SP, SP])
    SS = correlation_from_covariance(SS@np.transpose(SS))
    YY = np.concatenate([np.random.multivariate_normal(Y[i,:], cov = SS, size=[1,]) for i in range(N)], 0)
    YY = YY > 0
    YY = YY.astype(np.float64)
    return Simulation(X, YY, SS, W, scale_re, g, G, re, D)

def simulate_CAR(sides = 10, l = 0.7):
    row_coords = np.tile(np.arange(0, sides), sides).reshape([-1, 1])
    col_coords = np.tile(np.arange(0, sides), sides).reshape([-1, 1])
    d = np.concatenate([row_coords, col_coords], axis= 1)
    D = np.sum((d[:, np.newaxis, :] - d[np.newaxis, :, :]) ** 2, axis = -1)
    Sigma = nearPD(np.exp(-l*D))+np.eye(sides**2)*0.01
    re= np.random.multivariate_normal(np.zeros([sides**2]), cov=Sigma,size=[1])
    return D, Sigma, re

In [375]:
data = simulate(SP=4,N=1000,scale_re=0.01,G = 10)
E = data.X.shape[1]
N = data.X.shape[0]
SP = data.Y.shape[1]



In [716]:
def fit_model(data, 
              outer_epochs=12, 
              inner_epochs=3, 
              device = "cpu:0", 
              optim = "torch", 
              likelihood_type="mvp", 
              det=True,
              CAR=False
             ):
    E = data.X.shape[1]
    N = data.X.shape[0]
    SP = data.Y.shape[1]
    X, Y, G, indices, D = data.X, data.Y, data.G, data.g, data.D
    if CAR is True:
        G = G**2
    dev = torch.device(device)
    XT = torch.tensor(X, dtype=torch.float32, device=dev)
    YT = torch.tensor(Y, dtype=torch.float32, device=dev)
    torch.autograd.set_detect_anomaly(True)
    W = torch.tensor(np.random.normal(0., 0.001, size=(XT.shape[1],Y.shape[1])), dtype=torch.float32, requires_grad=True, device=dev)
    r_dim = Y.shape[1]
    df = int(np.rint(Y.shape[1]/2))
    low = -np.sqrt(6.0/(r_dim+df)) # type: ignore
    high = np.sqrt(6.0/(r_dim+df)) # type: ignore      
    sigma = torch.tensor(np.random.uniform(low, high, [r_dim, df]), requires_grad = True, dtype=torch.float32, device=dev) # type: ignore

    #@torch.jit.script
    def likelihood(mu: torch.Tensor, Ys: torch.Tensor, sigma: torch.Tensor, batch_size: int, sampling: int, df: int, alpha: float, device: str, dtype: torch.dtype):
        noise = torch.randn(size = [sampling, batch_size, df], device=torch.device(device), dtype=dtype)
        E = torch.sigmoid(   torch.einsum("ijk, lk -> ijl", [noise, sigma]).add(mu).mul(alpha)   ).mul(0.999999).add(0.0000005)
        logprob = E.log().mul(Ys).add((1.0 - E).log().mul(1.0 - Ys)).neg().sum(dim = 2).neg()
        maxlogprob = logprob.max(dim = 0).values
        Eprob = logprob.sub(maxlogprob).exp().mean(dim = 0)
        loss = Eprob.log().neg().sub(maxlogprob)
        return loss

    res = torch.tensor(np.random.normal(0, 0.01, size=(G, 1)), dtype=torch.float32, requires_grad=True, device=dev)
    scale_log = torch.tensor(np.random.normal(0.0,0.001, [1]), dtype=torch.float32, requires_grad=True, device=dev)
    zero_intercept = torch.zeros([1], dtype=torch.float32, device=dev)
    zero_CAR = torch.zeros([G], dtype=torch.float32, device=dev)

    opt1 = torch.optim.RMSprop([W, scale_log, sigma],lr=0.01)

    if CAR is True:
        D = torch.tensor(D, dtype=torch.float32, device=dev)
    opt2 = torch.optim.LBFGS([res], lr = 0.1)
    
    const_val = torch.tensor(0.0, dtype=torch.float32, device=dev)
    const_cov = torch.eye(G, dtype=torch.float32, device=dev)*0.01

    soft = lambda t: torch.nn.functional.softplus(t)+0.0001

    indices_T = torch.tensor(indices, dtype=torch.long)
    
    def re_loss():
        if CAR is not True:
            loss = -torch.distributions.Normal(zero_intercept, soft(scale_log)).log_prob(res).sum()
        else:
            loss = -torch.distributions.MultivariateNormal(zero_CAR, (-soft(scale_log)*D).exp()+const_cov).log_prob(res.reshape([1,-1])).sum()
        return loss

    
    def ll(res, W,sigma, XT, YT, indices_T):
        pred = XT@W+res[indices_T,:]#*scale_log.exp()
        #loss = -torch.distributions.Normal(loc=pred, scale=soft(scale_log_2) ).log_prob(YT).sum()
        if likelihood_type == "mvp":
            loss = likelihood(pred, YT, sigma, XT.shape[0], 100, df, 1.7012, device, torch.float32).sum()
        else:
            loss = -torch.distributions.Binomial(total_count=1, probs= torch.sigmoid(pred*1.7012) ).log_prob(YT).sum()
        loss += re_loss()
        return loss
    
    def torch_optim(epoch):
        if epoch % inner_epochs == 0:
            for i in range(20):
                opt2.zero_grad()
                loss = ll(res, W,sigma, XT, YT, indices_T)
                loss.backward()
                opt2.step(lambda: ll(res, W,sigma, XT, YT, indices_T))
            opt2.zero_grad()
    
    def minimize_func(res2):
        res2 = torch.tensor(res2, dtype=torch.float32, device=dev).reshape([G, 1])
        loss = ll(res2, W,sigma, XT, YT, indices_T)
        return loss.cpu().data.numpy()
    
    def hessian(res2):
        res2 = torch.tensor(res2, dtype=torch.float32, requires_grad=True, device=dev)
        loss = ll(res2.reshape([G, 1]), W, sigma, XT, YT, indices_T)
        loss.backward()
        return res2.grad.cpu().data.numpy()
    
    def scipy_optim(res, epoch, const_val):
        if epoch % inner_epochs == 0:
            res_new = minimize(
                        minimize_func, 
                        res.cpu().data.numpy().reshape([-1]), 
                        jac=hessian, 
                        method="BFGS"
                    )
            res = torch.tensor(res_new.x, dtype=torch.float32, requires_grad=True, device=dev).reshape([G, 1])
            #const_val = torch.tensor( 0.5*(np.log((2*np.pi)**(res_new.hess_inv.shape[0])) - np.log(np.linalg.det(res_new.hess_inv))), dtype=torch.float32, device=dev)
            #print(const_val)
        return res, const_val
    
    if det is not True:
        const_val = torch.tensor(0.0, dtype=torch.float32, device=dev)

    for epoch in range(outer_epochs):
        if optim == "torch": 
            torch_optim(epoch)
        else:
            res, const_val = scipy_optim(res, epoch, const_val)
        
        sample_indices = np.random.randint(0, XT.shape[0], size = 20)

        opt1.zero_grad()        
        pred = XT[sample_indices,:]@W+res[indices,:][sample_indices,:]#*scale_log.exp()
        loss = likelihood(pred, YT[sample_indices,:], sigma, pred.shape[0], 100, df, 1.7012, device, torch.float32).mean() + const_val
        loss += re_loss()
        loss.backward()
        opt1.step()
        opt1.zero_grad()
    avg_scale = soft(scale_log).mean().cpu().data.numpy()
    avg_acc = np.mean((np.sign(correlation_from_covariance((sigma @ sigma.t()).data.cpu().numpy())) == np.sign(data.Sigma) )[np.triu_indices(SP)])
    del opt1
    del W
    del XT
    del YT
    del sigma
    #torch.cuda.empty_cache()
    return avg_scale.tolist(), avg_acc

In [723]:
data = simulate(SP=2,N=400,scale_re=0.3,G = 20, CAR=True)

  re= np.random.multivariate_normal(np.zeros([sides**2]), cov=Sigma,size=[1])


In [725]:
fit_model(data, 550, 20, optim="scipy", likelihood_type="mvp", CAR=True)

(0.6581300497055054, 0.6666666666666666)

In [580]:
results_y = [fit_model(data, 200, 12, optim="scipy", likelihood_type="mvp", det=True)  for _ in range(15)]
#results_n = [fit_model(data, 60, 5, optim="scipy", likelihood_type="mvp", det=False) for _ in range(60)]

In [581]:
print( np.mean(np.asarray( results_y ), axis = 0) )
print( np.std(np.asarray( results_y ), axis = 0) )


[4.13961991 0.51632461]
[0.3287191 0.0027272]
