In [2]:
import torch
import numpy as np
import numpy as np, numpy.linalg
import matplotlib.pyplot as plt
from dataclasses import dataclass
import statsmodels.api as sm
import statsmodels.formula.api as smf
import pandas as pd
from scipy.optimize import minimize
from dataclasses import dataclass
import numpy as np

@dataclass
class Simulation:
    X: np.ndarray
    Y: np.ndarray
    W: np.ndarray
    g: np.ndarray
    G: np.ndarray
    re: np.ndarray
    sd_re: float
    
def simulate(N=100,E=2,G=10, sd_re=1.0):
    X = np.random.uniform(-1, 1, [N, E])
    W = np.random.normal(0, 1.0, [E, 1])
    Y = np.matmul(X, W) + np.random.normal(0.0, 0.8, [N,1])
    re = np.random.normal(0, sd_re, [G,1])
    g = np.repeat(np.arange(0,G), np.round(N/G))
    Y = Y+re[g,:]
    return Simulation(X, Y,W, g, G, re, sd_re)

In [429]:
#@ray.remote(num_gpus=1,num_cpus=2)
def fit_model(data: Simulation, det=True, CAR=False, device = "cuda:0", batch_size = 100) -> list:
     N, E = data.X.shape
     SP = data.Y.shape[1]
     X, Y, G, indices = data.X, data.Y, data.G, data.g
     dev = torch.device(device)
    
     XT = torch.tensor(X, dtype=torch.float32, device=torch.device("cpu:0"))
     YT = torch.tensor(Y, dtype=torch.float32, device=torch.device("cpu:0"))
     indices_T = torch.tensor(indices, dtype=torch.long, device=torch.device("cpu:0"))

     # Variables
     W = torch.tensor(np.random.normal(0.0,0.001, [XT.shape[1], YT.shape[1]]), dtype=torch.float32, device=dev, requires_grad=True)
     scale_log = torch.tensor(1.0, dtype=torch.float32, requires_grad=True, device=dev)
     scale_log_normal = torch.tensor(0.0, dtype=torch.float32, requires_grad=True, device=dev)
     res = torch.tensor(np.random.normal(0.0,0.001, [G, 1]), dtype=torch.float32, requires_grad=True, device=dev)
     soft = lambda t: torch.nn.functional.softplus(t)+0.0001
     zero_intercept = torch.zeros([1], dtype=torch.float32, device=dev)
     loss2 = torch.zeros([1], dtype=torch.float32, device=dev)

     mse = torch.nn.MSELoss(reduction="mean")
     
     adapt = torch.tensor(np.rint(XT.shape[0]/batch_size).tolist(), dtype=torch.float32, device=dev)
     
     def ll(res, W, XT, YT, indices_T, scale_log, scale_log_normal):
        pred = XT@W+res[indices_T,:]
        loss = -torch.distributions.Normal(pred, torch.exp(scale_log_normal)).log_prob(YT).sum()/XT.shape[0]
        #loss = mse(YT, pred).mean()
        #sigma_noise = torch.exp(scale_log_normal)
        #sigma_res = 1./torch.square(torch.exp(scale_log))
        # crit_factor = 1.0 / (2 * sigma_noise.square())
        loss += -torch.distributions.Normal(zero_intercept, torch.exp(scale_log)).log_prob(res[indices_T.unique()]).sum()/adapt/XT.shape[0]#/XT.shape[0]
        #loss += ((res[indices_T.unique()].pow(2.0))*(sigma_res)*0.5).sum()/res[indices_T.unique()].shape[0]#/crit_factor
        return loss


     optimizer = torch.optim.Adamax([W, scale_log,scale_log_normal], lr = 0.1)
     optimizer_re = torch.optim.Adamax([res], lr = 0.1)

     dataset = torch.utils.data.TensorDataset(XT, YT, indices_T)
     dataLoader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
     
     for _ in range(20):
          for x, y, inds in dataLoader:
               optimizer_re.zero_grad()
               loss = ll(res, W.detach(), x.to(dev), y.to(dev), inds.to(dev),torch.tensor(100.0, dtype=torch.float32), scale_log_normal.detach() )
               loss.backward(  )
               optimizer_re.step()
     optimizer_re.zero_grad()
     print(torch.std(res))

     for i in range(40):
          if i > 1:
               #print(torch.std(res))
               for x, y, inds in dataLoader:
                    optimizer_re.zero_grad()
                    loss = ll(res, W, x.to(dev), y.to(dev), inds.to(dev), scale_log.detach(), scale_log_normal.detach())#/x.shape[0]
                    loss.backward()
                    optimizer_re.step()
               optimizer_re.zero_grad()
          
          if i > 0:
               for _ in range(3):
                    for x, y, inds in dataLoader:
                         optimizer.zero_grad()
                         loss = ll(res, W, x.to(dev), y.to(dev), inds.to(dev),scale_log, scale_log_normal)#/x.shape[0]
                         if det is True:
                              #loss.backward(  retain_graph=True )
                              #gg = torch.autograd.grad(loss, res, retain_graph=True, create_graph=True)[0]
                              #gg=gg[gg.nonzero(as_tuple=True)].reshape([-1,1])
                              #logDA = torch.reciprocal(gg**2).sqrt().reshape([-1]).diag().inverse().logdet()
                              #loss2 = (-0.5*logDA)#/x.shape[0]
                              hess = torch.autograd.functional.hessian(lambda res: ll(res, W, x.to(dev), y.to(dev), inds.to(dev),scale_log, scale_log_normal), res, create_graph=True).squeeze()
                              ind2 = inds.unique()
                              D_tmp = hess.index_select(0, ind2).index_select(1, ind2)
                              const_val = torch.eye(ind2.shape[0], device=dev, dtype=torch.float32)*0.01
                              logDA=(D_tmp+const_val).inverse().logdet()
                              loss2 = ((ind2.shape[0]*0.5*torch.log((2*torch.tensor(3.14, dtype=torch.float32))) - 0.5*logDA))/adapt/x.shape[0]
                              loss+=loss2
                         loss = loss
                         loss.backward()
                         optimizer.step()
          
          optimizer.zero_grad()
          


          if i % 20 == 0:
               print([loss.item(), loss2])
               
     return [(torch.exp(scale_log)).cpu().data.numpy().tolist(), 
             (torch.exp(scale_log_normal)).cpu().data.numpy().tolist(), 
             W.cpu().data.numpy()]


In [423]:
data = simulate(N=1000, G = 50, sd_re=2.3)
print(2.3**2)
dataset = pd.DataFrame(np.concatenate([data.X, data.Y, np.reshape(data.g, [data.g.shape[0],1])], axis = 1),
                       columns = ["X1", "X2", "Y", "g"]
                      )
md = smf.mixedlm("Y~0+X1+X2", dataset, groups=dataset["g"], )
mdf = md.fit(reml=False)
print(mdf.summary())

5.289999999999999
         Mixed Linear Model Regression Results
Model:            MixedLM Dependent Variable: Y         
No. Observations: 1000    Method:             ML        
No. Groups:       50      Scale:              0.6364    
Min. group size:  20      Log-Likelihood:     -1323.5319
Max. group size:  20      Converged:          Yes       
Mean group size:  20.0                                  
--------------------------------------------------------
             Coef.  Std.Err.    z    P>|z| [0.025 0.975]
--------------------------------------------------------
X1           -1.395    0.045 -30.859 0.000 -1.484 -1.306
X2            0.223    0.046   4.891 0.000  0.134  0.313
Group Var     5.869    1.517                            



In [428]:
fit_model(data, device="cpu:0", det=True, batch_size=1000)

tensor(1.2997, grad_fn=<StdBackward0>)
[inf, tensor([0.])]
tensor(175.1028, grad_fn=<LogdetBackward0>)
tensor(181.4552, grad_fn=<LogdetBackward0>)
tensor(186.3749, grad_fn=<LogdetBackward0>)
tensor(190.2208, grad_fn=<LogdetBackward0>)
tensor(193.2012, grad_fn=<LogdetBackward0>)
tensor(195.5276, grad_fn=<LogdetBackward0>)
tensor(197.3473, grad_fn=<LogdetBackward0>)
tensor(198.7539, grad_fn=<LogdetBackward0>)
tensor(199.8382, grad_fn=<LogdetBackward0>)
tensor(200.6704, grad_fn=<LogdetBackward0>)
tensor(201.2995, grad_fn=<LogdetBackward0>)
tensor(201.7719, grad_fn=<LogdetBackward0>)
tensor(202.1229, grad_fn=<LogdetBackward0>)
tensor(202.3760, grad_fn=<LogdetBackward0>)
tensor(202.5502, grad_fn=<LogdetBackward0>)
tensor(202.6577, grad_fn=<LogdetBackward0>)
tensor(202.7038, grad_fn=<LogdetBackward0>)
tensor(202.6926, grad_fn=<LogdetBackward0>)
tensor(202.6261, grad_fn=<LogdetBackward0>)
tensor(202.5042, grad_fn=<LogdetBackward0>)
tensor(202.3284, grad_fn=<LogdetBackward0>)
tensor(202.1013, 

[1.9286086559295654,
 1.0407851934432983,
 array([[-1.3586684 ],
        [ 0.22159573]], dtype=float32)]

In [379]:
data.W


array([[ 0.17998728],
       [-1.49548422]])