In [1]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from torch.distributions import Dirichlet, Bernoulli, Uniform
import pandas as pd
import tqdm as tm

from src import Simulation as sim
from src import Dir_Reg
from src import Align
from src import visualize_latent_space as vls

# Use the GPU when one is visible to torch; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Let $\widehat{B} \in \mathbb{R}^{q \times p}$ be the MLE that corresponds to the design matrix $X \otimes I_p$, and let $\tilde{\beta} = (C^T C)^{-1} C^T \operatorname{vec}(\widehat{B})$ be its projection through the constraint matrix $C$. Let $\widehat{\beta}$ be the MLE that corresponds to the design matrix $(X \otimes I_p)C$.

We first run Monte Carlo simulations to verify the asymptotic behavior of $\widehat{B}$ and $\tilde{\beta}$.

In [120]:
# Seed torch's global RNG before constructing the simulation model.
torch.manual_seed(0)

# Settings of the reference simulation: two time points, 1200 nodes,
# regression coefficients `beta`, and one `alpha_0` row per group.
abc_settings = dict(
    time = 2,
    nodes = 1200,
    beta = [1, 1, -4, 5],
    alpha_0 = [[1, 1, 10], [1, 10, 1], [10, 1, 1]],
)
model = sim.ABC(**abc_settings)

In [304]:
def _long_frame(nodes, estimates, truth, est_name, truth_name):
    """Stack an (n_iter, n_comp) tensor of estimates into a long-format DataFrame.

    Each row of the result is one (replicate, component) pair with columns
    'Node', 'Seed', 'Comp', `est_name` and `truth_name`. 'Seed' is the 1-based
    replicate index and 'Comp' the 1-based component index; `truth` is repeated
    once per replicate.
    """
    n_iter, n_comp = estimates.shape
    truth_flat = truth.detach().cpu().reshape(-1).numpy()

    frame = pd.DataFrame({
        'Node': np.full(n_iter * n_comp, nodes),
        # replicate index: 1,1,...,1,2,2,...  (each repeated n_comp times)
        'Seed': np.repeat(np.arange(1, n_iter + 1), n_comp),
        # component index: 1,2,...,n_comp,1,2,...
        'Comp': np.tile(np.arange(1, n_comp + 1), n_iter),
        est_name: estimates.detach().cpu().numpy().reshape(-1),
        truth_name: np.tile(truth_flat, n_iter),
    })
    return frame.astype({'Node': 'int32', 'Seed': 'int32', 'Comp': 'int32',
                         est_name: 'float64', truth_name: 'float64'})


def fix_nodes_n_iter(nodes, n_iter):
    """Run `n_iter` Monte Carlo replicates of the unconstrained fit at a fixed network size.

    A `sim.ABC` model with `nodes` nodes is built once. For replicate
    i = 0, ..., n_iter - 1 the torch RNG is seeded with `i`,
    `model.update_settings()` is called (presumably regenerating the synthetic
    data -- confirm in `src.Simulation`), and `Dir_Reg.fit` is run with
    `constrained = False` on the time-0 predictor and time-1 latent positions.

    Parameters
    ----------
    nodes : int
        Number of nodes passed to `sim.ABC`.
    n_iter : int
        Number of replicates.

    Returns
    -------
    dict
        "B_hat"      : long-format DataFrame (Node, Seed, Comp, B_hat, B_real)
                       with one row per replicate and entry of the estimate.
        "beta_tilde" : long-format DataFrame (Node, Seed, Comp, beta_tilde,
                       beta_real) with the estimate projected by
                       `Dir_Reg.fit.proj_beta`.
        "fish_est"   : (q*p, q*p) tensor, the replicate average of the
                       reported "fisher_info".

    Notes
    -----
    The 'Seed' column is the 1-based replicate number, i.e. the RNG seed
    actually used for that row is ``Seed - 1``.
    """
    model = sim.ABC(time = 2,
                    nodes = nodes,
                    beta = [1,1,-4,5],
                    alpha_0 = [[1, 1, 10], [1, 10, 1], [10, 1, 1]])
    B = model.settings.B
    beta = model.settings.beta
    q, p = B.shape
    n_B = q * p
    # Size the beta buffers from the model instead of a hard-coded 4.
    n_beta = beta.numel()
    constraint = Dir_Reg.fit.gen_constraint(p, True)

    B_hat = torch.zeros(n_iter, n_B)
    beta_tilde = torch.zeros(n_iter, n_beta)
    fish_est = torch.zeros(n_iter, n_B**2)

    for i in range(n_iter):
        # Re-seed per replicate so every replicate is reproducible on its own.
        torch.manual_seed(i)
        model.update_settings()
        Z0 = model.synth_data["lat_pos"][0,]
        Z1 = model.synth_data["lat_pos"][1,]
        Y0 = model.synth_data["obs_adj"][0,]
        X0 = sim.ABC.gen_X(Y0, Z0, model.settings.K)

        est = Dir_Reg.fit(predictor = X0, response = Z1, constrained = False, beta_guess = model.settings.beta)

        B_hat[i,] = est.est_result["estimate"].reshape(-1)
        beta_tilde[i,] = Dir_Reg.fit.proj_beta(est.est_result["estimate"], constraint).reshape(-1)
        fish_est[i,] = est.est_result["fisher_info"].reshape(-1)

    result_dict = {
        "B_hat": _long_frame(nodes, B_hat, B, 'B_hat', 'B_real'),
        "beta_tilde": _long_frame(nodes, beta_tilde, beta, 'beta_tilde', 'beta_real'),
        "fish_est": fish_est.mean(dim = 0).reshape(n_B, n_B),
    }
    return result_dict


In [326]:
# Monte Carlo study over a grid of network sizes.
n_iter = 1000
n_set = list(range(3000, 13500, 1500))
# All simulated sizes: 1500 nodes is run in addition to the sizes in n_set.
node_counts = [1500] + n_set

constraint = Dir_Reg.fit.gen_constraint(3, True)
# H = (C^T C)^{-1} C^T, obtained by solving (C^T C) H = C^T.
H = torch.linalg.solve(constraint.T @ constraint, constraint.T)

# Run every size once and combine afterwards, instead of special-casing the
# first size and growing the frames/tensors inside the loop.
results = [fix_nodes_n_iter(n, n_iter) for n in node_counts]
temp = results[-1]  # kept so later cells that read `temp` still see the last run

df_B = pd.concat([res["B_hat"] for res in results], ignore_index= True)
df_beta = pd.concat([res["beta_tilde"] for res in results], ignore_index= True)

# One averaged Fisher-information matrix per network size, stacked along dim 0.
df_fish = torch.stack([res["fish_est"] for res in results], dim = 0)
# H F^{-1} H^T for each averaged Fisher-information matrix F.
df_fish_C = torch.stack(
    [H @ torch.linalg.solve(res["fish_est"], H.T) for res in results], dim = 0
)

# Keep only the diagonals: one row per network size.
df_fish_C_diag = torch.stack([mat.diag() for mat in df_fish_C])
df_fish_diag = torch.stack([mat.diag() for mat in df_fish])

df_fish_C_diag = pd.DataFrame(df_fish_C_diag)
df_fish_diag = pd.DataFrame(df_fish_diag)


In [327]:
# Write the simulation summaries to CSV (row index omitted).
# NOTE(review): the destinations are absolute paths on one specific machine;
# adjust them before running elsewhere.
csv_exports = {
    r"C:\Users\yangs\Desktop\df_B.csv": df_B,
    r"C:\Users\yangs\Desktop\df_beta.csv": df_beta,
    r"C:\Users\yangs\Desktop\df_fish_C_diag.csv": df_fish_C_diag,
}
for csv_path, frame in csv_exports.items():
    frame.to_csv(csv_path, index = False)


In [328]:
# Show the stacked beta_tilde results; the rendered table is truncated to the
# first and last rows.
df_beta

Unnamed: 0,Node,Seed,Comp,beta_tilde,beta_real
0,1500,1,1,0.987787,1.0
1,1500,1,2,1.005503,1.0
2,1500,1,3,-4.017076,-4.0
3,1500,1,4,5.690616,5.0
4,1500,2,1,0.998375,1.0
...,...,...,...,...,...
31995,12000,999,4,5.120615,5.0
31996,12000,1000,1,1.015697,1.0
31997,12000,1000,2,0.989503,1.0
31998,12000,1000,3,-3.988922,-4.0
