In [1]:
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from torch.distributions import Dirichlet, Bernoulli, Uniform
import pandas as pd
from tqdm import tqdm as tm

from src import Simulation as sim
from src import Dir_Reg
from src import Align
from src import visualize_latent_space as vls

# Pick the compute device once for the whole notebook.
# Precedence: Apple MPS when available, otherwise CUDA, otherwise CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

print(device)

cuda


Generate data sets that illustrate how the model's parameters influence its behavior. 
<br>
Settings:
<br>
Length of Time: 20 or 200
<br>
Embedding Dimension: 2
<br>
Number of Nodes: 1200
<br>
Parameters:  (1, 1, 5, 5), (1, 1, 2, 5), (1, 1, -2, 5), (1, 1, -5, 5)
<br>
Initial Distribution: Dir(1, 1, 1)


In [None]:
# Fix the torch seed before any model is built. The four constructors below
# keep their original order so the random stream is consumed the same way.
torch.manual_seed(4)

T = 20
n = 30
alpha_0 = [[1, 1, 1] for _ in range(3)]

# NOTE(review): the settings text in this notebook lists 1200 nodes and third
# beta entries of 5, 2, -2, -5, while this cell uses n = 30 and 2, 1, -2, -5.
# Confirm which of the two is intended.

# Arguments that are identical for every simulated model.
shared_model_kwargs = dict(nodes=n, alpha_0=alpha_0)

# Short horizon (T) for pos_2 and neg_5, long horizon (10 * T) for pos_1 and neg_2.
model_pos_2 = sim.ABC(time=T, beta=[1, 1, 2, 5], **shared_model_kwargs)
model_pos_1 = sim.ABC(time=T * 10, beta=[1, 1, 1, 5], **shared_model_kwargs)
model_neg_2 = sim.ABC(time=T * 10, beta=[1, 1, -2, 5], **shared_model_kwargs)
model_neg_5 = sim.ABC(time=T, beta=[1, 1, -5, 5], **shared_model_kwargs)

In [None]:
# Write the latent positions of each simulated model to its own CSV file.
lat_pos_dir = Path("simulated_data/time_vs_lat_pos")
# DataFrame.to_csv does not create missing parent directories, so make sure
# the target folder exists before the first write (no-op when already there).
lat_pos_dir.mkdir(parents=True, exist_ok=True)

# File-name prefix -> model, kept in the original write order.
models_by_label = {
    "pos_2": model_pos_2,
    "pos_1": model_pos_1,
    "neg_2": model_neg_2,
    "neg_5": model_neg_5,
}

for sample_label, sample_model in models_by_label.items():
    # The second argument (3) is passed through unchanged from the original
    # calls; NOTE(review): its meaning is not visible here — confirm in src.
    lat_pos_table = sim.ABC_Monte_Carlo.lat_pos(sample_model.synth_data["lat_pos"], 3)
    lat_pos_table.to_csv(lat_pos_dir / f"{sample_label}_sample.csv")

Below we generate the synthetic data sets that show how the latent positions in ABCDPRGM evolve through time under different settings.

Let $\widehat{B} \in \mathbb{R}^{q \times p}$ be the MLE that corresponds to the design matrix $X \otimes I_p$, and $\tilde{\beta} = (C^T C)^{-1} C^T \widehat{B}$. Let $\widehat{\beta}$ be the MLE that corresponds to the design matrix $(X \otimes I_p)C$.

We first do Monte Carlo simulations to verify the asymptotic behavior of $\widehat{B}$ and $\tilde{\beta}$. 

In [2]:
""" no_oracle option is for the RGD stuff, which is still somewhat problematic. It shouldn't be turned on. """
# All Monte Carlo settings in one mapping so the call itself stays on one line.
consistency_settings = {
    "number_of_iterations": 100,
    "nodes_set": [1500],
    "beta": [1, 1, -4, 5],
    "alpha_0": [[10, 1, 1], [1, 10, 1], [1, 1, 10]],
    "seeded": True,
    "constrained": False,
    # Method switches: only the oracle-latent-position variant is enabled.
    "oracle_lat_pos": True,
    "oracle_align": False,
    "no_oracle": False,
}
temp = sim.ABC_Monte_Carlo.consistency_T2(**consistency_settings)

1500: 100%|██████████| 100/100 [00:07<00:00, 12.94it/s]


In [3]:
# Show the estimates table of the Monte Carlo run; the rendered output lists
# B_est next to B_real for each seed and component.
temp.MC_result.est

Unnamed: 0,seed,nodes,constrained,method_OL,method_OA,method_NO,component,B_est,B_real
0,0,1500,0,1,0,0,0,1.647439,1.0
1,0,1500,0,1,0,0,1,0.595847,0.0
2,0,1500,0,1,0,0,2,-0.359172,-1.0
3,0,1500,0,1,0,0,3,-0.509387,0.0
4,0,1500,0,1,0,0,4,0.376627,1.0
...,...,...,...,...,...,...,...,...,...
2095,99,1500,0,1,0,0,16,-4.491355,-4.0
2096,99,1500,0,1,0,0,17,3.505875,4.0
2097,99,1500,0,1,0,0,18,5.343741,5.0
2098,99,1500,0,1,0,0,19,5.382929,5.0


In [4]:
# Show the table of fisher_info_1 ... fisher_info_441 columns from the run.
# NOTE(review): presumably one flattened Fisher information matrix per
# iteration — confirm against src before relying on the layout.
temp.MC_result.fish

Unnamed: 0,nodes,method_1,method_2,method_3,fisher_info_1,fisher_info_2,fisher_info_3,fisher_info_4,fisher_info_5,fisher_info_6,...,fisher_info_432,fisher_info_433,fisher_info_434,fisher_info_435,fisher_info_436,fisher_info_437,fisher_info_438,fisher_info_439,fisher_info_440,fisher_info_441
0,1500,1,0,0,21919.652344,-7309.260742,-14581.612305,2755.815186,-1051.811768,-1684.058960,...,24212.728516,-8989.612305,-21355.687500,30511.404297,-22119.394531,-8676.549805,30962.054688,-39572.789062,-38221.687500,78283.054688
1,1500,1,0,0,21883.421875,-7081.566895,-14773.452148,2935.404785,-1034.082031,-1881.941040,...,24633.880859,-8597.886719,-20988.968750,29754.144531,-21835.228516,-8458.582031,30461.056641,-38877.062500,-37848.304688,77213.367188
2,1500,1,0,0,22023.812500,-7473.166992,-14522.678711,2864.290039,-1100.341675,-1744.867920,...,25586.054688,-9088.357422,-23473.201172,32727.468750,-22000.802734,-9195.208984,31360.187500,-39709.574219,-41736.398438,81935.523438
3,1500,1,0,0,21899.660156,-7562.814941,-14308.984375,2729.316406,-1080.119873,-1629.836914,...,23929.910156,-8553.638672,-19571.714844,28287.287109,-21147.718750,-8494.708984,29810.818359,-38051.664062,-36140.085938,74682.843750
4,1500,1,0,0,21181.562500,-7182.962891,-13970.496094,2610.392578,-966.390015,-1624.575684,...,22645.390625,-9129.926758,-22152.082031,31450.386719,-21360.531250,-8938.352539,30460.810547,-38882.050781,-39161.433594,78533.679688
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
95,1500,1,0,0,20308.250000,-7115.350586,-13162.728516,2666.958984,-1058.211914,-1588.710327,...,23899.367188,-8627.062500,-20789.917969,29584.115234,-19740.146484,-8616.038086,28517.863281,-36283.898438,-37871.726562,74645.109375
96,1500,1,0,0,20496.062500,-7240.099121,-13229.683594,2509.753418,-937.935791,-1553.359619,...,21711.855469,-8108.008301,-19962.972656,28236.501953,-19871.871094,-8188.785156,28224.259766,-36034.851562,-35809.656250,72336.187500
97,1500,1,0,0,22814.437500,-7793.013672,-14993.136719,2826.757812,-1052.794189,-1754.021973,...,23321.519531,-8858.515625,-20280.845703,29306.378906,-21615.373047,-8589.445312,30369.148438,-39162.714844,-36804.859375,76456.804688
98,1500,1,0,0,19903.439453,-6733.811523,-13141.502930,2570.027344,-1001.920410,-1548.564453,...,22951.332031,-8154.565430,-21060.488281,29381.597656,-19893.890625,-8347.847656,28406.972656,-35653.226562,-37399.820312,73544.250000
