In [1]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from torch.distributions import Dirichlet, Bernoulli, Uniform
import pandas as pd
from tqdm import tqdm as tm

from src import Simulation as sim
from src import Dir_Reg
from src import Align
from src import visualize_latent_space as vls
from src import sim_to_df as std

# Choose the compute device. Apple's MPS backend takes precedence when present,
# then CUDA, and the CPU is the fallback.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

print(device)

cuda


Generate data sets that illustrate how the parameters of the model influence the behavior of the model. 
<br>
Settings:
<br>
Length of Time: 20 or 200
<br>
Embedding Dimension: 2
<br>
Number of Nodes: 1200
<br>
Parameters:  (1, 1, 5, 5), (1, 1, 2, 5), (1, 1, -2, 5), (1, 1, -5, 5)
<br>
Initial Distribution: Dir(1, 1, 1)


In [2]:
# Fix the RNG so the four simulated trajectories below are reproducible.
torch.manual_seed(4)

T = 20                                            # base number of time points
n = 30                                            # number of nodes per simulation
alpha_0 = [[1, 1, 1], [1, 1, 1], [1, 1, 1]]       # initial Dirichlet parameters, shared by all models

# The models differ only in the third beta entry and in the time horizon
# (the two middle models run ten times longer than the outer two).
# NOTE(review): the markdown above lists betas (5, 2, -2, -5) and 1200 nodes,
# which does not match the values used here - confirm which is intended.
model_pos_2 = sim.ABC(time = T, nodes = n, beta = [1, 1, 2, 5], alpha_0 = alpha_0)
model_pos_1 = sim.ABC(time = T*10, nodes = n, beta = [1, 1, 1 , 5], alpha_0 = alpha_0)
model_neg_2 = sim.ABC(time = T*10, nodes = n, beta = [1, 1, -2, 5], alpha_0 = alpha_0)
model_neg_5 = sim.ABC(time = T, nodes = n, beta = [1, 1, -5, 5], alpha_0 = alpha_0)

In [3]:
# Write each simulated latent-position trajectory to its own CSV file.
# The models are exported in the same order in which they were simulated.
_lat_pos_exports = (
    (model_pos_2, r"simulated_data/time_vs_lat_pos/pos_2_sample.csv"),
    (model_pos_1, r"simulated_data/time_vs_lat_pos/pos_1_sample.csv"),
    (model_neg_2, r"simulated_data/time_vs_lat_pos/neg_2_sample.csv"),
    (model_neg_5, r"simulated_data/time_vs_lat_pos/neg_5_sample.csv"),
)
for _sim_model, _csv_path in _lat_pos_exports:
    _lat_pos_frame = std.lat_pos_to_df(_sim_model.synth_data["lat_pos"], 3)
    _lat_pos_frame.to_csv(_csv_path)

Below we generate the synthetic data set that shows how the latent position in ABCDPRGM evolves through time under different settings.

Let $\widehat{B} \in \mathbb{R}^{q \times p}$ be the MLE that corresponds to the design matrix $X \otimes I_p$, and $\tilde{\beta} = (C^T C)^{-1} C^T \widehat{B}$. Let $\widehat{\beta}$ be the MLE that corresponds to the design matrix $(X \otimes I_p)C$.

We first do Monte Carlo simulations to verify the asymptotic behavior of $\widehat{B}$ and $\tilde{\beta}$. 

In [10]:
class ABC_MC_consistency_T2:
    """One Monte Carlo replicate of the two-time-point (T = 2) consistency experiment.

    A ``sim.ABC`` model is re-simulated at a requested node count and ``Dir_Reg.fit``
    is run once for every enabled source of latent positions:

    * ``oracle_lat_pos`` (OL): the simulated latent positions are used directly.
    * ``oracle_align`` (OA): positions are taken from ``Align.Oracle(...).embed_aligned``.
    * ``no_oracle`` (NO): positions are taken from ``Align.No_Oracle(...).aligned``.

    NOTE(review): the original author recommended keeping ``no_oracle`` False for now,
    because the Riemannian-gradient-descent alignment does not recover the permutation
    reliably and would additionally have to be minimised over permutations.
    """
    def __init__(self, number_of_iterations, nodes_set, beta, alpha_0, constrained = False, oracle_lat_pos = True, oracle_align = False, no_oracle = False):
        # The instance attribute deliberately reuses the nested class name; after this
        # line ``self.settings`` is the populated settings object, not the class.
        self.settings = self.settings(number_of_iterations, nodes_set, beta, alpha_0, constrained, oracle_lat_pos, oracle_align, no_oracle)

    class settings:
        """Experiment configuration plus the template model shared by every round."""
        def __init__(self, number_of_iterations, nodes_set, beta, alpha_0, constrained, oracle_lat_pos, oracle_align, no_oracle):
            self.n_iter = number_of_iterations
            self.n_set = nodes_set
            self.beta = beta
            self.alpha_0 = alpha_0
            self.constrained = constrained
            self.OL = oracle_lat_pos
            self.OA = oracle_align
            self.NO = no_oracle

            # Small placeholder model; one_round() resizes it through update_settings().
            self.init_model = sim.ABC(time = 2,
                                      nodes = 3,
                                      beta = self.beta,
                                      alpha_0 = self.alpha_0)
            # B is q x p. Unpack in the same (q, p) order as one_round() so that
            # gen_constraint receives p (it previously received the row count).
            q, p = self.init_model.settings.B.shape
            self.constraint = Dir_Reg.fit.gen_constraint(p, True)
            self.B_size = 4 if self.constrained else q*p

    def one_round(self, nodes, seed = None):
        """Simulate once with ``nodes`` nodes and fit every enabled method.

        Parameters
        ----------
        nodes : int
            Node count forwarded to ``model.update_settings(nodes = nodes)``.
        seed : int, optional
            When given, passed to ``torch.manual_seed`` before simulating.

        Returns
        -------
        dict
            ``"est"``: tensor with one row per (method, component of B) and columns
            ``[nodes, constrained, method_OL, method_OA, method_NO, component, estimate, true B]``.
            ``"fish"``: tensor with one row per enabled method holding
            ``[nodes, method_OL, method_OA, method_NO, flattened Fisher information]``.

        Raises
        ------
        ValueError
            If none of the three methods is enabled.
        """
        if seed is not None:
            torch.manual_seed(seed)

        n_type = int(self.settings.OL) + int(self.settings.OA) + int(self.settings.NO)
        if n_type == 0:
            raise ValueError("at least one of oracle_lat_pos, oracle_align or no_oracle must be enabled")
        q, p = self.settings.init_model.settings.B.shape

        # The template model is shared between rounds and is re-simulated in place.
        model = self.settings.init_model
        model.update_settings(nodes = nodes)
        constrained = self.settings.constrained

        Z0 = model.synth_data["lat_pos"][0,]
        Z1 = model.synth_data["lat_pos"][1,]
        Y0 = model.synth_data["obs_adj"][0,]
        Y1 = model.synth_data["obs_adj"][1,]

        est_OL = est_OA = est_NO = fish_OL = fish_OA = fish_NO = None
        method_OL = method_OA = method_NO = None

        if self.settings.OL:
            X0_ora_lat_pos = sim.ABC.gen_X(Y0, Z0, model.settings.K)
            est_ora_lat_pos = Dir_Reg.fit(predictor = X0_ora_lat_pos,
                                          response = Z1,
                                          constrained = constrained,
                                          beta_guess = model.settings.beta)
            est_OL = est_ora_lat_pos.est_result["estimate"].reshape(-1)
            fish_OL = est_ora_lat_pos.est_result["fisher_info"].reshape(1, -1)
            method_OL = torch.tensor([[1,0,0]])

        if self.settings.OA:
            Z0_ora_align = Align.Oracle(Z0, Y0, (p-1)).embed_aligned
            Z1_ora_align = Align.Oracle(Z1, Y1, (p-1)).embed_aligned
            X0_ora_align = sim.ABC.gen_X(Y0, Z0_ora_align, model.settings.K)
            est_ora_align = Dir_Reg.fit(predictor = X0_ora_align,
                                        response = Z1_ora_align,
                                        constrained = constrained,
                                        beta_guess = model.settings.beta)
            est_OA = est_ora_align.est_result["estimate"].reshape(-1)
            fish_OA = est_ora_align.est_result["fisher_info"].reshape(1, -1)
            method_OA = torch.tensor([[0,1,0]])

        if self.settings.NO:
            Z0_no_oracle = Align.No_Oracle(Y0, (p-1)).aligned
            # The alignment matrix obtained against Y1 is handed to the time-1 embedding
            # as its third argument (named init_guess here).
            init_guess = Align.Oracle(Z0_no_oracle, Y1, (p-1)).align_mat
            Z1_no_oracle = Align.No_Oracle(Y1, (p-1), init_guess).aligned
            X0_no_oracle = sim.ABC.gen_X(Y0, Z0_no_oracle, model.settings.K)

            est_no_oracle = Dir_Reg.fit(predictor = X0_no_oracle,
                                        response = Z1_no_oracle,
                                        constrained = constrained,
                                        beta_guess = model.settings.beta)
            est_NO = est_no_oracle.est_result["estimate"].reshape(-1)
            # Fixed: this was reshape(1, -11), an invalid shape that always raised.
            fish_NO = est_no_oracle.est_result["fisher_info"].reshape(1, -1)
            method_NO = torch.tensor([[0,0,1]])

        # Keep only the methods that actually ran, in OL, OA, NO order.
        est_to_concat = [item for item in (est_OL, est_OA, est_NO) if item is not None]
        method_to_concat = [item for item in (method_OL, method_OA, method_NO) if item is not None]
        fish_to_concat = [item for item in (fish_OL, fish_OA, fish_NO) if item is not None]

        est = torch.cat(est_to_concat, dim = 0).unsqueeze(dim = 1)
        method_core = torch.cat(method_to_concat, dim = 0)
        est_constrained = torch.tensor([constrained]).repeat(n_type * p * q).unsqueeze(dim = 1)
        est_nodes = torch.ones(n_type * q * p, 1) * nodes
        est_component = torch.arange(q*p).repeat(n_type).unsqueeze(dim = 1)
        # Repeat each method's one-hot indicator row once per component of B.
        est_method = torch.kron(method_core, torch.ones(q*p, 1))
        real = self.settings.init_model.settings.B.reshape(-1).repeat(n_type).unsqueeze(dim = 1)

        est_full = torch.cat([est_nodes, est_constrained, est_method, est_component, est, real], dim = 1)

        # One Fisher-information row per method, stacked row-wise.
        fish = torch.cat(fish_to_concat, dim = 0)
        fish_nodes = torch.ones(n_type, 1) * nodes
        fish_full = torch.cat([fish_nodes, method_core, fish], dim = 1)

        result_dic = {"est": est_full, "fish": fish_full}

        return result_dic



In [11]:
# Column labels for the "est" table returned by one_round().
est_columns = ["nodes", "constrained", "method_OL", "method_OA", "method_NO", "component", "B_est", "B_real"]

# Single replicate at 3000 nodes with the oracle-latent-position (default) and
# oracle-alignment methods enabled.
temp = ABC_MC_consistency_T2(number_of_iterations = 1,
                             nodes_set = [3,6,9],
                             beta = [1,1,-4,5],
                             alpha_0 = [[1, 1, 10], [1, 10, 1], [10, 1, 1]],
                             oracle_align = True)
result = temp.one_round(nodes = 3000)
pd.DataFrame(result["est"], columns = est_columns)

torch.Size([2, 1]) torch.Size([2, 3]) torch.Size([2, 441])


Unnamed: 0,nodes,constrained,method_OL,method_OA,method_NO,component,B_est,B_real
0,3000.0,0.0,1.0,0.0,0.0,0.0,0.982366,1.0
1,3000.0,0.0,1.0,0.0,0.0,1.0,-0.052856,0.0
2,3000.0,0.0,1.0,0.0,0.0,2.0,-0.962245,-1.0
3,3000.0,0.0,1.0,0.0,0.0,3.0,0.161568,0.0
4,3000.0,0.0,1.0,0.0,0.0,4.0,1.13826,1.0
5,3000.0,0.0,1.0,0.0,0.0,5.0,-0.841529,-1.0
6,3000.0,0.0,1.0,0.0,0.0,6.0,1.127909,1.0
7,3000.0,0.0,1.0,0.0,0.0,7.0,0.213292,0.0
8,3000.0,0.0,1.0,0.0,0.0,8.0,-0.921877,-1.0
9,3000.0,0.0,1.0,0.0,0.0,9.0,-0.085244,0.0


In [14]:
# Show the first four columns of the "fish" table: the node count followed by the
# three method-indicator columns (OL, OA, NO), one row per enabled method.
result["fish"][:, :4]

tensor([[3.0000e+03, 1.0000e+00, 0.0000e+00, 0.0000e+00],
        [3.0000e+03, 0.0000e+00, 1.0000e+00, 0.0000e+00]])

In [None]:
def fix_nodes_n_iter(model, nodes, n_iter, constrained = False, oracle_lat_pos = True, oracle_align = False, no_oracle = False):
    """Run the model with ``nodes`` nodes ``n_iter`` times and collect the fits.

    Replicate ``i`` is seeded with ``torch.manual_seed(i)``, the model is re-simulated
    through ``model.update_settings()``, and ``Dir_Reg.fit`` is run for every enabled
    method. The estimate and the Fisher information of each fit are recorded.

    Parameters
    ----------
    model : object
        Simulation model exposing ``update_settings``, ``settings.B``, ``settings.beta``,
        ``settings.K`` and ``synth_data`` (as produced by ``sim.ABC`` in this notebook).
    nodes : int
        Node count applied once before the replicates start.
    n_iter : int
        Number of seeded replicates.
    constrained : bool
        Forwarded to ``Dir_Reg.fit``; also sets the Fisher-information size
        (4 x 4 when True, (q*p) x (q*p) otherwise).
    oracle_lat_pos, oracle_align, no_oracle : bool
        Which latent-position sources to fit with.

    Returns
    -------
    dict
        ``"B_hat"`` and ``"beta_tilde"`` (long-format DataFrames) and ``"fish_est"``
        (mean Fisher information, reshaped to qp x qp) for the first enabled method in
        the order oracle_lat_pos, oracle_align, no_oracle. ``"by_method"`` maps every
        enabled method name to a dict with those same three entries.

    Raises
    ------
    ValueError
        If no method is enabled.
    """
    method_flags = (("oracle_lat_pos", oracle_lat_pos),
                    ("oracle_align", oracle_align),
                    ("no_oracle", no_oracle))
    enabled = [name for name, flag in method_flags if flag]
    if not enabled:
        raise ValueError("at least one of oracle_lat_pos, oracle_align or no_oracle must be True")

    model.update_settings(nodes = nodes)

    B = model.settings.B
    beta = model.settings.beta

    q, p = B.shape
    constraint = Dir_Reg.fit.gen_constraint(p, True)
    qp = 4 if constrained else q*p

    # One (n_iter x size) buffer per enabled method. The original code kept separate
    # per-method buffers but then read the undefined names ``B_hat`` / ``fish_est``.
    B_hat_by_method = {name: torch.zeros(n_iter, q*p) for name in enabled}
    fish_by_method = {name: torch.zeros(n_iter, (qp)**2) for name in enabled}

    def record(name, row, fitted):
        """Store the flattened estimate and Fisher information of one fit."""
        B_hat_by_method[name][row,] = fitted.est_result["estimate"].reshape(-1)
        fish_by_method[name][row,] = fitted.est_result["fisher_info"].reshape(-1)

    for i in tm(range(n_iter), desc = str(nodes)):

        # Seed with the replicate index so that every replicate is reproducible.
        torch.manual_seed(i)
        model.update_settings()

        Z0 = model.synth_data["lat_pos"][0,]
        Z1 = model.synth_data["lat_pos"][1,]
        Y0 = model.synth_data["obs_adj"][0,]
        Y1 = model.synth_data["obs_adj"][1,]

        if oracle_lat_pos:
            X0_ora_lat_pos = sim.ABC.gen_X(Y0, Z0, model.settings.K)
            est_ora_lat_pos = Dir_Reg.fit(predictor = X0_ora_lat_pos, response = Z1, constrained = constrained, beta_guess = model.settings.beta)
            record("oracle_lat_pos", i, est_ora_lat_pos)

        if oracle_align:
            Z0_ora_align = Align.Oracle(Z0, Y0, (p-1)).embed_aligned
            Z1_ora_align = Align.Oracle(Z1, Y1, (p-1)).embed_aligned
            X0_ora_align = sim.ABC.gen_X(Y0, Z0_ora_align, model.settings.K)
            est_ora_align = Dir_Reg.fit(predictor = X0_ora_align, response = Z1_ora_align, constrained = constrained, beta_guess = model.settings.beta)
            record("oracle_align", i, est_ora_align)

        if no_oracle:
            Z0_no_oracle = Align.Op_Riemannian_GD(Y0, "relu")
            Z1_no_oracle = Align.Op_Riemannian_GD(Y1, "relu", Z0_no_oracle)
            X0_no_oracle = sim.ABC.gen_X(Y0, Z0_no_oracle, model.settings.K)
            est_no_oracle = Dir_Reg.fit(predictor = X0_no_oracle, response = Z1_no_oracle, constrained = constrained, beta_guess = model.settings.beta)
            record("no_oracle", i, est_no_oracle)

    def to_pd_df(n, mat, b_real, name):
        """Convert an (n_iter x components) matrix of estimates to a long DataFrame.

        Each output row is one (seed, component) pair with columns
        Seed, Constrained, Nodes, Comp, ``<name>_hat`` and ``<name>_real``.
        """
        n_rows, n_comp = mat.shape
        vec = mat.reshape(1, -1)
        comp = torch.arange(1, n_comp+1).repeat(n_rows).unsqueeze(dim = 0)

        constrained_ind = torch.tensor([constrained]).repeat(n_rows * n_comp).unsqueeze(dim = 0)
        seed_id = torch.arange(n_rows).repeat_interleave(n_comp).unsqueeze(dim = 0)
        node_id = n * torch.ones(n_comp*n_rows).unsqueeze(dim = 0)

        # as_tensor lets the true parameter be supplied as a tensor or a plain sequence.
        b_real_stack = torch.stack([torch.as_tensor(b_real).reshape(-1)]* n_rows).reshape(1, -1)

        b_df = torch.cat([seed_id, constrained_ind, node_id, comp, vec, b_real_stack], dim = 0).T
        b_df = pd.DataFrame(b_df.detach().cpu().numpy())
        b_df.columns = ["Seed", "Constrained", "Nodes", "Comp", name + "_hat", name + "_real"]

        # Identifier columns (those without "_" in the name) are stored as integers.
        for column in b_df.columns:
            if "_" not in column:
                try:
                    b_df[column] = b_df[column].astype(int)
                except ValueError:
                    # Best effort: keep the float column if it cannot be cast.
                    pass

        return b_df

    def summarise(method_name):
        """Build the B_hat / beta_tilde / fish_est entries for one method."""
        B_hat = B_hat_by_method[method_name]
        B_df = to_pd_df(nodes, B_hat, B, "B")
        # beta_tilde = (C^T C)^{-1} C^T B_hat, the projection of B_hat through the constraint.
        beta_tilde = torch.linalg.solve(constraint.T @ constraint, constraint.T) @ B_hat.T
        beta_df = to_pd_df(nodes, beta_tilde.T, beta, "beta")
        fish_mean = fish_by_method[method_name].mean(dim = 0).reshape(qp, qp)
        return {"B_hat": B_df, "beta_tilde": beta_df, "fish_est": fish_mean}

    by_method = {name: summarise(name) for name in enabled}

    # Top-level keys come from the first enabled method (oracle_lat_pos by default).
    result_dict = dict(by_method[enabled[0]])
    result_dict["by_method"] = by_method
    return result_dict

In [None]:
# Sweep the node count and gather, for every size, the estimates of B, the projected
# beta_tilde and the mean Fisher information returned by std.fix_nodes_n_iter.
torch.manual_seed(0)
model = sim.ABC(time = 2,
                nodes = 1200,
                beta = [1,1,-4,5],
                alpha_0 = [[1, 1, 10], [1, 10, 1], [10, 1, 1]])
n_iter = 1
n_set = list(range(3000, 9000, 1500))
constraint = Dir_Reg.fit.gen_constraint(3, True)
constrained = False

# 1500 nodes is run first, followed by every size in n_set.
B_hat_frames = []
beta_tilde_frames = []
fish_matrices = []
for n in [1500] + n_set:
    oracle_constrained = std.fix_nodes_n_iter(model, n, n_iter, constrained)
    B_hat_frames.append(oracle_constrained["B_hat"])
    beta_tilde_frames.append(oracle_constrained["beta_tilde"])
    fish_matrices.append(oracle_constrained["fish_est"])

# Assemble each table once instead of growing it inside the loop.
df_B_hat = pd.concat(B_hat_frames, ignore_index= True)
df_beta_tilde = pd.concat(beta_tilde_frames, ignore_index= True)
df_B_hat_fish = torch.stack(fish_matrices, dim = 0)

# One row per node count: the flattened Fisher matrix and its diagonal.
n_sizes = len(n_set) + 1
df_B_hat_fish_vec = pd.DataFrame(df_B_hat_fish.reshape(n_sizes, -1))
df_B_hat_fish_diag = pd.DataFrame(torch.stack([df_B_hat_fish[i].diag() for i in range(n_sizes)]))


In [None]:
# Persist the sweep results. Forward slashes are used because they resolve to the
# intended sub-directories on every platform (backslashes are only path separators
# on Windows), and they match the paths of the earlier latent-position export cell.
df_B_hat.to_csv("simulated_data/emp_var_vs_obs_var/B_oracle.csv", index = False)
df_beta_tilde.to_csv("simulated_data/emp_var_vs_obs_var/B_oracle_proj.csv", index = False)
df_B_hat_fish_vec.to_csv("simulated_data/emp_var_vs_obs_var/B_fish.csv", index = False)

In [86]:
class ABC_MC_consistency_T2:
    """One Monte Carlo replicate of the two-time-point (T = 2) consistency experiment.

    A ``sim.ABC`` model is re-simulated at a requested node count and ``Dir_Reg.fit``
    is run once for every enabled source of latent positions:

    * ``oracle_lat_pos`` (OL): the simulated latent positions are used directly.
    * ``oracle_align`` (OA): positions are taken from ``Align.Oracle(...).embed_aligned``.
    * ``no_oracle`` (NO): positions are taken from ``Align.Op_Riemannian_GD``.

    NOTE(review): this cell redefines a class of the same name from an earlier cell and
    shadows it once executed; this version uses ``Align.Op_Riemannian_GD`` in the
    no-oracle branch. The original author recommended keeping ``no_oracle`` False for
    now, because the Riemannian-gradient-descent alignment does not recover the
    permutation reliably and would additionally have to be minimised over permutations.
    """
    def __init__(self, number_of_iterations, nodes_set, beta, alpha_0, constrained = False, oracle_lat_pos = True, oracle_align = False, no_oracle = False):
        # The instance attribute deliberately reuses the nested class name; after this
        # line ``self.settings`` is the populated settings object, not the class.
        self.settings = self.settings(number_of_iterations, nodes_set, beta, alpha_0, constrained, oracle_lat_pos, oracle_align, no_oracle)

    class settings:
        """Experiment configuration plus the template model shared by every round."""
        def __init__(self, number_of_iterations, nodes_set, beta, alpha_0, constrained, oracle_lat_pos, oracle_align, no_oracle):
            self.n_iter = number_of_iterations
            self.n_set = nodes_set
            self.beta = beta
            self.alpha_0 = alpha_0
            self.constrained = constrained
            self.OL = oracle_lat_pos
            self.OA = oracle_align
            self.NO = no_oracle

            # Small placeholder model; one_round() resizes it through update_settings().
            self.init_model = sim.ABC(time = 2,
                                      nodes = 3,
                                      beta = self.beta,
                                      alpha_0 = self.alpha_0)
            # B is q x p. Unpack in the same (q, p) order as one_round() so that
            # gen_constraint receives p (it previously received the row count).
            q, p = self.init_model.settings.B.shape
            self.constraint = Dir_Reg.fit.gen_constraint(p, True)
            self.B_size = 4 if self.constrained else q*p

    def one_round(self, nodes, seed = None):
        """Simulate once with ``nodes`` nodes and fit every enabled method.

        Parameters
        ----------
        nodes : int
            Node count forwarded to ``model.update_settings(nodes = nodes)``.
        seed : int, optional
            When given, passed to ``torch.manual_seed`` before simulating.

        Returns
        -------
        dict
            ``"est"``: tensor with one row per (method, component of B) and columns
            ``[nodes, constrained, method_OL, method_OA, method_NO, component, estimate, true B]``.
            ``"fish"``: tensor with one row per enabled method holding
            ``[nodes, method_OL, method_OA, method_NO, flattened Fisher information]``.

        Raises
        ------
        ValueError
            If none of the three methods is enabled.
        """
        if seed is not None:
            torch.manual_seed(seed)

        n_type = int(self.settings.OL) + int(self.settings.OA) + int(self.settings.NO)
        if n_type == 0:
            raise ValueError("at least one of oracle_lat_pos, oracle_align or no_oracle must be enabled")
        q, p = self.settings.init_model.settings.B.shape

        # The template model is shared between rounds and is re-simulated in place.
        model = self.settings.init_model
        model.update_settings(nodes = nodes)
        constrained = self.settings.constrained

        Z0 = model.synth_data["lat_pos"][0,]
        Z1 = model.synth_data["lat_pos"][1,]
        Y0 = model.synth_data["obs_adj"][0,]
        Y1 = model.synth_data["obs_adj"][1,]

        est_OL = est_OA = est_NO = fish_OL = fish_OA = fish_NO = None
        method_OL = method_OA = method_NO = None

        if self.settings.OL:
            X0_ora_lat_pos = sim.ABC.gen_X(Y0, Z0, model.settings.K)
            est_ora_lat_pos = Dir_Reg.fit(predictor = X0_ora_lat_pos,
                                          response = Z1,
                                          constrained = constrained,
                                          beta_guess = model.settings.beta)
            est_OL = est_ora_lat_pos.est_result["estimate"].reshape(-1)
            fish_OL = est_ora_lat_pos.est_result["fisher_info"].reshape(1, -1)
            method_OL = torch.tensor([[1,0,0]])

        if self.settings.OA:
            Z0_ora_align = Align.Oracle(Z0, Y0, (p-1)).embed_aligned
            Z1_ora_align = Align.Oracle(Z1, Y1, (p-1)).embed_aligned
            X0_ora_align = sim.ABC.gen_X(Y0, Z0_ora_align, model.settings.K)
            est_ora_align = Dir_Reg.fit(predictor = X0_ora_align,
                                        response = Z1_ora_align,
                                        constrained = constrained,
                                        beta_guess = model.settings.beta)
            est_OA = est_ora_align.est_result["estimate"].reshape(-1)
            fish_OA = est_ora_align.est_result["fisher_info"].reshape(1, -1)
            method_OA = torch.tensor([[0,1,0]])

        if self.settings.NO:
            # The time-0 embedding is passed as the third argument of the time-1 call.
            Z0_no_oracle = Align.Op_Riemannian_GD(Y0, "relu")
            Z1_no_oracle = Align.Op_Riemannian_GD(Y1, "relu", Z0_no_oracle)
            X0_no_oracle = sim.ABC.gen_X(Y0, Z0_no_oracle, model.settings.K)
            est_no_oracle = Dir_Reg.fit(predictor = X0_no_oracle,
                                        response = Z1_no_oracle,
                                        constrained = constrained,
                                        beta_guess = model.settings.beta)
            est_NO = est_no_oracle.est_result["estimate"].reshape(-1)
            # Fixed: this was reshape(1, -11), an invalid shape that always raised.
            fish_NO = est_no_oracle.est_result["fisher_info"].reshape(1, -1)
            method_NO = torch.tensor([[0,0,1]])

        # Keep only the methods that actually ran, in OL, OA, NO order.
        est_to_concat = [item for item in (est_OL, est_OA, est_NO) if item is not None]
        method_to_concat = [item for item in (method_OL, method_OA, method_NO) if item is not None]
        fish_to_concat = [item for item in (fish_OL, fish_OA, fish_NO) if item is not None]

        est = torch.cat(est_to_concat, dim = 0).unsqueeze(dim = 1)
        method_core = torch.cat(method_to_concat, dim = 0)

        est_constrained = torch.tensor([constrained]).repeat(n_type * p * q).unsqueeze(dim = 1)
        est_nodes = torch.ones(n_type * q * p, 1) * nodes
        est_component = torch.arange(q*p).repeat(n_type).unsqueeze(dim = 1)
        # Repeat each method's one-hot indicator row once per component of B.
        est_method = torch.kron(method_core, torch.ones(q*p, 1))
        real = self.settings.init_model.settings.B.reshape(-1).repeat(n_type).unsqueeze(dim = 1)

        est_full = torch.cat([est_nodes, est_constrained, est_method, est_component, est, real], dim = 1)

        # Fixed: stack one Fisher row per method (dim = 0). Concatenating along dim = 1
        # produced a single row, which cannot be joined with the (n_type x 1) node
        # column and the (n_type x 3) method indicators when several methods run.
        fish = torch.cat(fish_to_concat, dim = 0)
        fish_nodes = torch.ones(n_type, 1) * nodes
        fish_full = torch.cat([fish_nodes, method_core, fish], dim = 1)

        result_dic = {"est": est_full, "fish": fish_full}

        return result_dic



Unnamed: 0,nodes,constrained,method_OL,method_OA,method_NO,component,B_est,B_real
0,3000.0,0.0,1.0,0.0,0.0,0.0,1.219972,1.0
1,3000.0,0.0,1.0,0.0,0.0,1.0,0.232042,0.0
2,3000.0,0.0,1.0,0.0,0.0,2.0,-0.640052,-1.0
3,3000.0,0.0,1.0,0.0,0.0,3.0,-0.18047,0.0
4,3000.0,0.0,1.0,0.0,0.0,4.0,0.801624,1.0
5,3000.0,0.0,1.0,0.0,0.0,5.0,-1.136532,-1.0
6,3000.0,0.0,1.0,0.0,0.0,6.0,0.830992,1.0
7,3000.0,0.0,1.0,0.0,0.0,7.0,-0.14548,0.0
8,3000.0,0.0,1.0,0.0,0.0,8.0,-1.318146,-1.0
9,3000.0,0.0,1.0,0.0,0.0,9.0,0.239815,0.0


In [96]:
# Shape of the "fish" tensor held in `result` (rows x columns).
result["fish"].shape

torch.Size([1, 445])