In [2]:
import torch
import numpy as np
import numpy as np, numpy.linalg
import matplotlib.pyplot as plt
from dataclasses import dataclass
import statsmodels.api as sm
import statsmodels.formula.api as smf
import pandas as pd
from scipy.optimize import minimize
from dataclasses import dataclass
import numpy as np

@dataclass
class Simulation:
    """Container for one simulated random-intercept regression data set.

    Instances are produced by ``simulate``; the shapes noted on each field
    below are the ones that function creates.
    """
    # Design matrix, shape (N, E).
    X: np.ndarray
    # Response, shape (N, 1): X @ W plus observation noise plus the group intercept.
    Y: np.ndarray
    # Fixed-effect weights used to generate Y, shape (E, 1).
    W: np.ndarray
    # Group index of every observation (integers in [0, G)).
    g: np.ndarray
    # Number of groups. This is a plain integer count, not an array.
    G: int
    # Random intercept drawn for each group, shape (G, 1).
    re: np.ndarray
    # Standard deviation the random intercepts were drawn with.
    sd_re: float
    
def simulate(N=100,E=2,G=10, sd_re=1.0):
    """Simulate a linear model with a Gaussian random intercept per group.

    Args:
        N: Number of observations.
        E: Number of predictors (columns of X).
        G: Number of groups.
        sd_re: Standard deviation of the group-level random intercepts.

    Returns:
        A ``Simulation`` holding X (N, E), Y (N, 1), the generating weights
        W (E, 1), the per-observation group index g (N,), the group count G,
        the drawn random intercepts re (G, 1) and ``sd_re``.

    Draws come from numpy's global random state, in the order X, W,
    observation noise, random intercepts.
    """
    X = np.random.uniform(-1, 1, [N, E])
    W = np.random.normal(0, 1.0, [E, 1])
    noise = np.random.normal(0.0, 0.8, [N, 1])
    re = np.random.normal(0, sd_re, [G, 1])

    # Contiguous, (near-)equal sized groups with exactly one label per
    # observation. When N is a multiple of G this equals
    # np.repeat(np.arange(G), N // G); unlike a float repeat count it stays an
    # integer array and still has length N when N is not divisible by G.
    g = np.arange(N) * G // N

    Y = np.matmul(X, W) + noise + re[g, :]
    return Simulation(X, Y, W, g, G, re, sd_re)

In [147]:
#@ray.remote(num_gpus=1,num_cpus=2)
def fit_model(data: "Simulation", det=True, CAR=False, device = "cuda:0", batch_size = 100, epochs = 40) -> list:
    """Fit a random-intercept linear model with minibatch Adamax.

    The model is ``Y ~ Normal(X @ W + res[g], scale_log_normal)`` with a
    ``Normal(0, scale_log)`` penalty on the group intercepts ``res``. Despite
    their names, both scale parameters are used directly as standard
    deviations (clamped to a positive range), not exponentiated.

    Args:
        data: Object exposing ``X`` (N, E), ``Y`` (N, SP), ``g`` (integer group
            index per row) and ``G`` (number of groups).
        det: When exactly ``True``, add a log-determinant term built from the
            Hessian of the batch loss with respect to ``res`` (restricted to
            the groups present in the batch).
        CAR: Unused; kept for interface compatibility.
        device: Torch device string for the parameters and each batch.
        batch_size: Minibatch size.
        epochs: Number of passes over the data.

    Returns:
        ``[scale_log, scale_log_normal, W]`` where the first two are Python
        floats and ``W`` is a numpy array of shape (E, SP).

    Every second epoch the last batch loss, the last log-determinant term and
    the current ``scale_log`` are printed.
    """
    X, Y, G, indices = data.X, data.Y, data.G, data.g
    dev = torch.device(device)
    cpu = torch.device("cpu:0")

    # The full data set stays on the CPU; only the current batch is moved to `dev`.
    XT = torch.tensor(X, dtype=torch.float32, device=cpu)
    YT = torch.tensor(Y, dtype=torch.float32, device=cpu)
    indices_T = torch.tensor(indices, dtype=torch.long, device=cpu)

    # Trainable variables
    W = torch.tensor(np.random.normal(0.0, 0.001, [XT.shape[1], YT.shape[1]]), dtype=torch.float32, device=dev, requires_grad=True)
    scale_log = torch.tensor(1.0, dtype=torch.float32, requires_grad=True, device=dev)
    scale_log_normal = torch.tensor(1.0, dtype=torch.float32, requires_grad=True, device=dev)
    res = torch.tensor(np.random.normal(0.0, 0.001, [G, 1]), dtype=torch.float32, requires_grad=True, device=dev)
    zero_intercept = torch.zeros([1], dtype=torch.float32, device=dev)
    loss2 = torch.zeros([1], dtype=torch.float32, device=dev)

    # Approximate number of batches per epoch. Clamped to at least 1 so that
    # batch_size > 2 * N no longer rounds to 0 and turns the penalty terms
    # into a division by zero (which made every parameter NaN).
    n_batches = max(float(np.rint(XT.shape[0] / batch_size)), 1.0)
    adapt = torch.tensor(n_batches, dtype=torch.float32, device=dev)
    const = torch.tensor(np.log(2*(np.pi)), dtype=torch.float32, device=dev)
    # Number of distinct groups in the whole data set (loop invariant).
    n_groups_observed = indices_T.unique().shape[0]

    def ll(res, W, XT, YT, indices_T, scale_log, scale_log_normal):
        # Per-observation Gaussian negative log-likelihood, averaged over the batch.
        pred = XT@W+res[indices_T,:]
        loss = -torch.distributions.Normal(pred, torch.clamp(scale_log_normal, 0.00001, 100.)).log_prob(YT).sum()/XT.shape[0]
        # Gaussian penalty on the intercepts of the groups present in this batch.
        loss += -torch.distributions.Normal(zero_intercept, torch.clamp(scale_log,0.0001, 100.)).log_prob(res[indices_T.unique()]).sum()/adapt/XT.shape[0]
        return loss

    optimizer = torch.optim.Adamax([W, scale_log,scale_log_normal, res], lr = 0.1)

    dataset = torch.utils.data.TensorDataset(XT, YT, indices_T)
    dataLoader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

    for i in range(epochs):

        for x, y, inds in dataLoader:
            # Move the batch once; the group indices must live on `dev` too,
            # because they are used below to index the Hessian stored there.
            x, y, inds = x.to(dev), y.to(dev), inds.to(dev)

            optimizer.zero_grad()
            loss = ll(res, W, x, y, inds, scale_log, scale_log_normal)
            if det is True:
                # Hessian w.r.t. res has shape (G, 1, G, 1); reshape (rather
                # than squeeze) keeps it 2-D even when G == 1.
                hess = torch.autograd.functional.hessian(
                    lambda r: ll(r, W, x, y, inds, scale_log, scale_log_normal),
                    res,
                    create_graph=True,
                ).reshape(G, G)
                ind2 = inds.unique()

                D_tmp = hess.index_select(0, ind2).index_select(1, ind2)
                # Small ridge on the diagonal before inverting.
                const_val = torch.eye(ind2.shape[0], device=dev, dtype=torch.float32)*0.01
                logDA = (D_tmp+const_val).inverse().logdet()/G/n_groups_observed
                loss2 = -(0.5*logDA)/adapt/x.shape[0]
                loss += loss2 + const*G/2.
            loss.backward()
            optimizer.step()

        if i % 2 == 0:
            print([loss.item(), loss2, scale_log.item()])

    return [scale_log.detach().cpu().numpy().tolist(),
            scale_log_normal.detach().cpu().numpy().tolist(),
            W.detach().cpu().numpy()]


In [138]:
# Simulate 100 observations in 20 groups with an almost-zero random-intercept
# standard deviation, then fit the same model with statsmodels' MixedLM (ML,
# not REML) as a reference for the torch implementation.
data = simulate(N=100, G=20, sd_re=0.0001)
print(2.3**2)

# One row per observation: the two predictors, the response and the group id.
dataset = pd.DataFrame(
    np.hstack([data.X, data.Y, data.g.reshape(-1, 1)]),
    columns=["X1", "X2", "Y", "g"],
)

# No global intercept ("0 +"); the grouping column supplies the random intercept.
md = smf.mixedlm("Y~0+X1+X2", dataset, groups=dataset["g"])
mdf = md.fit(reml=False)
print(mdf.summary())

5.289999999999999
         Mixed Linear Model Regression Results
Model:            MixedLM Dependent Variable: Y        
No. Observations: 100     Method:             ML       
No. Groups:       20      Scale:              0.6395   
Min. group size:  5       Log-Likelihood:     -119.5411
Max. group size:  5       Converged:          Yes      
Mean group size:  5.0                                  
-------------------------------------------------------
             Coef.  Std.Err.   z    P>|z| [0.025 0.975]
-------------------------------------------------------
X1            0.460    0.131  3.520 0.000  0.204  0.716
X2           -0.495    0.142 -3.494 0.000 -0.773 -0.217
Group Var     0.000    0.068                           





In [150]:
# Fit the torch model on the CPU with the Hessian log-determinant term enabled;
# the returned [scale_log, scale_log_normal, W] is displayed as the cell result.
fit_model(
    data,
    det=True,
    device="cpu:0",
    batch_size=10,
    epochs=100,
)

[19.211347579956055, tensor(-0.0407, grad_fn=<DivBackward0>), 0.16554474830627441]
[19.595626831054688, tensor(-0.0392, grad_fn=<DivBackward0>), 0.2396908551454544]
[19.7680721282959, tensor(-0.0452, grad_fn=<DivBackward0>), 0.2492096722126007]
[19.425113677978516, tensor(-0.0500, grad_fn=<DivBackward0>), 0.264719158411026]
[19.407373428344727, tensor(-0.0325, grad_fn=<DivBackward0>), 0.24883286654949188]
[19.419963836669922, tensor(-0.0313, grad_fn=<DivBackward0>), 0.2635261118412018]
[19.501163482666016, tensor(-0.0428, grad_fn=<DivBackward0>), 0.2796635925769806]
[19.622589111328125, tensor(-0.0396, grad_fn=<DivBackward0>), 0.24843156337738037]
[19.495563507080078, tensor(-0.0342, grad_fn=<DivBackward0>), 0.2557966113090515]
[19.30864715576172, tensor(-0.0261, grad_fn=<DivBackward0>), 0.2783925235271454]
[19.863971710205078, tensor(-0.0416, grad_fn=<DivBackward0>), 0.27472516894340515]
[19.411075592041016, tensor(-0.0586, grad_fn=<DivBackward0>), 0.2736388146877289]
[19.387697219848

[0.2669656574726105,
 0.7639532089233398,
 array([[ 0.4142138 ],
        [-0.46700388]], dtype=float32)]