In [8]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from torch.distributions import Dirichlet, Bernoulli, Uniform
import pandas as pd
import tqdm as tm

from src import Simulation as sim
from src import Dir_Reg
from src import Align
from src import visualize_latent_space as vls

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [9]:
#default settings
T, n = 2, 3000
# mat = torch.tensor([[36,2,27], [99,86,14], [72,18,34]], dtype= torch.float).reshape(K, p)/10
mat = [[1, 1, 10], [1, 10, 1], [10, 1, 1]]
model = sim.ABC(time = T,
                nodes = n,
                beta = [1,1,-5,5.0],
                alpha_0 = mat)


In [10]:
Z0 = model.synth_data["lat_pos"][0,]
Z1 = model.synth_data["lat_pos"][1,]

Y0 = model.synth_data["obs_adj"][0,]
Y1 = model.synth_data["obs_adj"][1,]

X0 = sim.ABC.gen_X(Y0, Z0, 3)

In [11]:
temp_est = Dir_Reg.fit(predictor = X0,
                       response = Z1, 
                       constrained = False,
                       beta_guess = model.settings.beta)
temp_est.est_result

{'final_estimate': tensor([[ 1.0262,  0.0713, -1.0622],
         [ 0.1378,  1.2264, -0.8604],
         [ 0.9429, -0.1120, -0.9406],
         [-0.1009,  0.7961, -1.0597],
         [-5.1584, -0.2554,  4.8828],
         [-0.2432, -5.3347,  4.8304],
         [ 5.1436,  5.2169,  2.0664]], device='cuda:0'),
 'final_fisher_info': tensor([[ 28445.1211,  -8002.1997, -20394.6777,   3734.2437,  -1174.7316,
           -2519.6396,  28731.2539,  -7784.8218, -20902.2520,   3530.7866,
           -1377.1213,  -2109.4307,   5486.8325,  -1729.9487,  -3702.1221,
           18988.4746,  -4807.9766, -14138.0381,  35857.6758,  -9971.4297,
          -25759.7773],
         [ -8002.1997,   9986.5459,  -1625.8208,  -1174.7316,   3873.5977,
           -2659.5273,  -7784.8218,   9314.8818,  -1172.8185,  -1377.1213,
            4535.5610,  -3117.8027,  -1729.9490,   5554.1553,  -3744.7734,
           -4807.9766,   6476.5386,  -1430.4650,  -9971.4277,  16938.0586,
           -6506.3184],
         [-20394.6777,  -162

In [12]:
ora_est_0 = Align.Oracle(Z0, Y0, 2)
ora_est_1 = Align.Oracle(Z1, Y1, 2)

In [13]:
X0_tilde = sim.ABC.gen_X(Y0, ora_est_0.embed_aligned, 3)
Z1_tilde = ora_est_1.embed_aligned
ora_reg = Dir_Reg.fit(predictor = X0_tilde, 
                      response = Z1_tilde,
                      constrained = True,
                      beta_guess = model.settings.beta)

In [14]:
ora_reg.est_result

{'final_estimate': tensor([[ 1.0090,  0.0000, -1.0090],
         [ 0.0000,  1.0090, -1.0090],
         [ 0.9898,  0.0000, -0.9898],
         [ 0.0000,  0.9898, -0.9898],
         [-5.0150,  0.0000,  5.0150],
         [ 0.0000, -5.0150,  5.0150],
         [ 4.1221,  4.1221,  1.1059]], device='cuda:0'),
 'final_fisher_info': tensor([[30275.0098, 28318.9277, -9136.5166,   363.8755],
         [28318.9277, 27993.8848, -8941.5303,   375.8338],
         [-9136.5166, -8941.5303,  4513.4062,  1323.6959],
         [  363.8755,   375.8338,  1323.6960,  3028.5823]]),
 'ASE_info_lost': 0.04133333333333333}