In [1]:
from utils import *
from model import *
import numpy as np
import torch
import sys
import torch.nn.functional as F
import torch.optim as optim
from utils_for_experiments import *
# from transform_wrappers_multiprocessing import *
from transform_wrappers import *
from matplotlib import pyplot as plt
from tqdm import tqdm
from discretize import *
from visualization import *


np.set_printoptions(linewidth=np.inf)

if __name__ == "__main__":

    # device = "cuda" if torch.cuda.is_available() else "cpu"
    device = "cpu"
    print(f"using device {device}")

    ####################     Generation parameters     #######################################################
    dataArgs = dict()

    maximum_number_of_nodes_n = "20" #@param [12, 24, 30, 48]
    dataArgs["max_n_node"] = int(maximum_number_of_nodes_n)

    range_of_linkage_probability_p = "0.0, 1.0" #@param [[0.0,1.0], [0.2,0.8], [0.5,0.5]]
    dataArgs["p_range"] = [float(range_of_linkage_probability_p.split(",")[0]), float(range_of_linkage_probability_p.split(",")[1])]

    node_attributes = "degree" #@param ["uniform", "degree", "random"]
    dataArgs["node_attr"] = node_attributes

    number_of_graph_instances = "5000" #@param [1, 100, 1000, 10000, 25000, 50000, 100000, 200000, 500000, 1000000]
    dataArgs["n_graph"] = int(number_of_graph_instances)

    dataArgs["upper_triangular"] = False
    A, Attr, Param, Topol = generate_data_v2(dataArgs)
    # g, a, attr = unpad_data(A[0], Attr[0])

    ####################     Model parameters     #######################################################
    modelArgs = {"gnn_filters": 2, "conv_filters": 16, "kernel_size": 3}

    number_of_latent_variables= "20" #@param [1, 2, 3, 4, 5]
    modelArgs["latent_dim"] = int(number_of_latent_variables)

    trainArgs = dict()

    weight_graph_reconstruction_loss = "30" #@param [0, 1, 2, 3, 5, 10, 20]
    weight_attribute_reconstruction_loss = "5" #@param [0, 1, 2, 3, 5, 10, 20]
    beta_value = "20" #@param [0, 1, 2, 3, 5, 10, 20]
    trainArgs["loss_weights"] = [int(weight_graph_reconstruction_loss), int(weight_attribute_reconstruction_loss), int(beta_value)]

    epochs = "35" #@param [10, 20, 50]
    trainArgs["epochs"] = int(epochs)
    batch_size = "512" #@param [2, 4, 8, 16, 32, 128, 512, 1024]
    trainArgs["batch_size"] = int(batch_size)
    early_stop = "2" #@param [1, 2, 3, 4, 10]
    trainArgs["early_stop"] = int(early_stop)
    train_test_split = "0.2" #@param [0.1, 0.2, 0.3, 0.5]
    train_validation_split = "0.1" #@param [0.1, 0.2, 0.3, 0.5]
    trainArgs["data_split"] = float(train_test_split)
    trainArgs["validation_split"] = float(train_validation_split)
    lr = "0.001"  #@param [0.1, 0.01, 0.001, 0.0001, 0.00001]
    trainArgs["lr"] = float(lr)



    ## Train and Test Split _______________________________________________

    A_train = torch.from_numpy(A[:int((1-trainArgs["data_split"]-trainArgs["validation_split"])*A.shape[0])])
    Attr_train = generate_batch(torch.from_numpy(Attr[:int((1-trainArgs["data_split"]-trainArgs["validation_split"])*Attr.shape[0])]), trainArgs["batch_size"])
    Param_train = generate_batch(torch.from_numpy(Param[:int((1-trainArgs["data_split"]-trainArgs["validation_split"])*Param.shape[0])]), trainArgs["batch_size"])
    Topol_train = generate_batch(torch.from_numpy(Topol[:int((1-trainArgs["data_split"]-trainArgs["validation_split"])*Topol.shape[0])]), trainArgs["batch_size"])

    A_validate = torch.from_numpy(A[int((1-trainArgs["data_split"]-trainArgs["validation_split"])*A.shape[0]):int((1-trainArgs["data_split"])*Attr.shape[0])])
    Attr_validate = generate_batch(torch.from_numpy(Attr[int((1-trainArgs["data_split"]-trainArgs["validation_split"])*Attr.shape[0]):int((1-trainArgs["data_split"])*Attr.shape[0])]), trainArgs["batch_size"])
    Param_validate = generate_batch(torch.from_numpy(Param[int((1-trainArgs["data_split"]-trainArgs["validation_split"])*Param.shape[0]):int((1-trainArgs["data_split"])*Attr.shape[0])]), trainArgs["batch_size"])
    Topol_validate = generate_batch(torch.from_numpy(Topol[int((1-trainArgs["data_split"]-trainArgs["validation_split"])*Topol.shape[0]):int((1-trainArgs["data_split"])*Attr.shape[0])]), trainArgs["batch_size"])

    A_test = torch.from_numpy(A[int((1-trainArgs["data_split"])*A.shape[0]):])
    Attr_test = generate_batch(torch.from_numpy(Attr[int((1-trainArgs["data_split"])*Attr.shape[0]):]), trainArgs["batch_size"])
    Param_test = generate_batch(torch.from_numpy(Param[int((1-trainArgs["data_split"])*Param.shape[0]):]), trainArgs["batch_size"])
    Topol_test = generate_batch(torch.from_numpy(Topol[int((1-trainArgs["data_split"])*Topol.shape[0]):]), trainArgs["batch_size"])


    # print(A_train.shape)
    # print(len(Attr_train), Attr_train[0].shape)

    ## build graph_conv_filters
    SYM_NORM = True
    A_train_mod = generate_batch(preprocess_adj_tensor_with_identity(torch.squeeze(A_train, -1), SYM_NORM), trainArgs["batch_size"])
    A_validate_mod = generate_batch(preprocess_adj_tensor_with_identity(torch.squeeze(A_validate, -1), SYM_NORM), trainArgs["batch_size"])
    A_test_mod = generate_batch(preprocess_adj_tensor_with_identity(torch.squeeze(A_test, -1), SYM_NORM), trainArgs["batch_size"])

    A_train = generate_batch(A_train, trainArgs["batch_size"])
    A_validate = generate_batch(A_validate, trainArgs["batch_size"])
    A_test = generate_batch(A_test, trainArgs["batch_size"])

    train_data = (Attr_train, A_train_mod, Param_train, Topol_train)
    validate_data = (Attr_validate, A_validate_mod, Param_validate, Topol_validate)
    test_data = (Attr_test, A_test_mod, Param_test, Topol_test)

    # attribute first -> (n, n), adjacency second -> (n, n, 1)
    modelArgs["input_shape"], modelArgs["output_shape"] = ((Attr_train[0].shape[1], Attr_train[0].shape[2]), (int(A_train_mod[0].shape[1] / modelArgs["gnn_filters"]), A_train_mod[0].shape[2], 1)),\
                                                          ((Attr_test[0].shape[1], Attr_test[0].shape[1]), (int(A_test_mod[0].shape[1] / modelArgs["gnn_filters"]), A_test_mod[0].shape[2], 1))
    # print(modelArgs["input_shape"], modelArgs["output_shape"])
    # print(A_train[0].shape)



  2%|█▎                                                                             | 84/5000 [00:00<00:12, 398.58it/s]

using device cpu



100%|█████████████████████████████████████████████████████████████████████████████| 5000/5000 [00:13<00:00, 382.73it/s]


done


In [7]:
np.set_printoptions(linewidth=150)
np.set_printoptions(formatter={'float': lambda x: "{0:0.2f}".format(x)})

In [2]:
# attr = Attr_train[0][2].float()
# f = A_train_mod[0][2].float()
# # f = preprocess_adj_tensor_with_identity(torch.squeeze(f, -1), symmetric = False)
# z, z_mean, z_log_var, A_hat, attr_hat, A_hat_raw, max_score_per_node, min_score_per_node  = vae(attr.unsqueeze(0), f.unsqueeze(0))
# z

In [33]:
######################      VAE        ##########################################
operation_name = "density"  ## ["transitivity", "density", "forest fire ..."]
# param_path = operation_name + "_pretrained" + "_" + maximum_number_of_nodes_n
param_path = "."
vae = torch.load(param_path + "/vae.model")

train_losses = []
validation_losses = []
batched_z = []
batched_A_hat = []
batched_Attr_hat = []
batched_A_hat_discretized = []
batched_A_hat_discretized_test = []
batched_gcn_filters_from_A_hat = []
batched_z_test = []
batched_A_hat_test = []
batched_Attr_hat_test = []
batched_gcn_filters_from_A_hat_test = []
batched_A_hat_raw_train = []
batched_A_hat_raw_test = []
batched_A_hat_max_train = []
batched_A_hat_max_test = []
batched_A_hat_min_train = []
batched_A_hat_min_test = []
print("\n\n =================Extracting useful information=====================")
vae.eval()

def index_of(my_list, target):
    try: return my_list.index(target)
    except: return dataArgs["max_n_node"]

for e in range(1):
    loss_cum = 0
    for i in range(len(Attr_train)):
        attr = Attr_train[i].float().to(device)
        A = A_train[i].float().to(device)
        graph_conv_filters = A_train_mod[i].float().to(device)

        z, z_mean, z_log_var, A_hat, attr_hat, A_hat_raw, max_score_per_node, min_score_per_node = vae(attr, graph_conv_filters)

        if e + 1 == 1:
            batched_z.append(z.detach())
            batched_Attr_hat.append(attr_hat.detach())
            batched_A_hat.append(A_hat.detach())
            temp = A_hat.detach().cpu()
            batched_gcn_filters_from_A_hat.append(preprocess_adj_tensor_with_identity(torch.squeeze(temp, -1), symmetric = False))

            A_discretize = A.cpu().squeeze().numpy()
            A_hat_discretize = A_hat.detach().cpu().squeeze().numpy()
            discretizer = Discretizer(A_discretize, A_hat_discretize)
            A_hat_discretize = discretizer.discretize(method='hard_threshold', threshold=0.4)
            A_hat_discretize = torch.unsqueeze(torch.from_numpy(A_hat_discretize), -1)

            batched_A_hat_discretized.append(A_hat_discretize)
            batched_A_hat_raw_train.append(A_hat_raw.detach())
            batched_A_hat_max_train.append(max_score_per_node.detach())
            batched_A_hat_min_train.append(min_score_per_node.detach())

            # count = 0
            for j in range(len(batched_A_hat_discretized[i])):
                temp = list(torch.diag(batched_A_hat_discretized[i][j].detach().reshape(dataArgs["max_n_node"], -1)))[::-1]
                pred_node_num = dataArgs["max_n_node"] - index_of(list(temp), 1)
                Param_train[i][j][-1] = pred_node_num  # predicted node num have ~96% acc
                true_node_num = int(Param_train[i][j][0])
                # print(pred_node_num)
                # print(true_node_num)

                # count += pred_node_num == true_node_num
            # print(f"node prediction accuracy : {count / len(batched_A_hat_discretized[i])}")

        loss = loss_func((A, attr), (A_hat, attr_hat), z_mean, z_log_var, trainArgs, modelArgs)
        loss_cum += loss.item()

    print("Model loss {} ".format(loss_cum / len(Attr_train)))


a,p,r,f = compute_score_batched(A_train, batched_A_hat_discretized)
print(f"VAE performance:  \n Accuracy: {a},  \n Precision: {p},  \n Recall: {r},  \n F1 Score: {f}\n")

density_ori, diameter_ori, cluster_coef_ori, edges_ori, avg_degree_ori = topological_measure(A_train)
density_hat, diameter_hat, cluster_coef_hat, edges_hat, avg_degree_hat = topological_measure(batched_A_hat_discretized)
print(f"--- Truth topology (averaged) ---\n density: {density_ori} \n diameter: {diameter_ori} "
      f"\n clustering coefficient: {cluster_coef_ori} \n edges: {edges_ori} \n avgerage degree {avg_degree_ori}\n")
print(f"--- Reconstructed topology (averaged) ---\n density: {density_hat} \n diameter: {diameter_hat} "
      f"\n clustering coefficient: {cluster_coef_hat} \n edges: {edges_hat} \n avgerage degree {avg_degree_hat}\n")






Model loss 2430.152518136161 
VAE performance:  
 Accuracy: 0.9069385714285654,  
 Precision: 0.7047642453256331,  
 Recall: 0.8931318175829427,  
 F1 Score: 0.7878451997051672

--- Truth topology (averaged) ---
 density: 0.25517142857142794 
 diameter: -1.0 
 clustering coefficient: 0.2777762870649428 
 edges: 48.482571428571426 
 avgerage degree 0.25517142857142794

--- Reconstructed topology (averaged) ---
 density: 0.3268285714285696 
 diameter: -1.0 
 clustering coefficient: 0.4490918726394994 
 edges: 62.09742857142857 
 avgerage degree 0.3268285714285696



In [3]:

# ############################# Steering GAN   ####################################

# ## training tip: same batch, same alpha!

# # w = torch.randn_like(batched_z[0][0], requires_grad=True).unsqueeze(0).to(device)
# w = torch.load(param_path + "/w_density.pt")
# a_w1 = torch.load(param_path + "/a_w1_density.pt")
# a_w2 = torch.load(param_path + "/a_w2_density.pt")
# a_b1 = torch.load(param_path + "/a_b1_density.pt")
# a_b2 = torch.load(param_path + "/a_b2_density.pt")




# ### Initialize generator
# generator = Decoder_v2(modelArgs, trainArgs, device).to(device)

# decoder_weight = dict(vae.decoder.named_parameters())
# generator_weight = dict(generator.named_parameters())
# for k in generator_weight.keys():
#     assert k in decoder_weight
#     generator_weight[k] = decoder_weight[k]
# generator.eval()


# discriminator.eval()
## operation = "transitivity", "density", "node_count"

generator = torch.load("generator.model")

# transform = GraphTransform(dataArgs["max_n_node"], operation = operation_name, sigmoid = False)
transform = GraphTransform(dataArgs["max_n_node"], operation = operation_name, sigmoid = False)
w_epochs = 1  ### adjust epoch here!!!


In [4]:

loss_train = []
w_A_train = []
w_A_hat_train = []
w_edit_A_hat_train = []
w_gen_A_hat_train = []
gen_A_raw_train = []
gen_A_max_train = []
gen_A_min_train = []
masked_norm_A_hats = []

for e in range(w_epochs):
    for i in tqdm(range(len(batched_A_hat_discretized))):

        fil = batched_gcn_filters_from_A_hat[i].float().to(device)
        attr_hat = batched_Attr_hat[i].float().to(device)
        # A_hat = batched_A_hat[i].to(device)
        A_hat = batched_A_hat_discretized[i].to(device)
        A = A_train[i]
        z = batched_z[i].to(device)

        ## discretize
        # A = A_train[i].cpu().numpy().squeeze(-1)
        # A_hat = A_hat.cpu().numpy().squeeze(-1)
        # discretizer = Discretizer(A, A_hat)
        # A_hat = discretizer.discretize('hard_threshold')
        # A = torch.unsqueeze(torch.from_numpy(A), -1)
        # A_hat = torch.unsqueeze(torch.from_numpy(A_hat), -1)

#         _, alpha_edit = transform.get_train_alpha(A_hat)
        _, alpha_edit = 0, 0.35
        # alpha_gen = a_w2 * F.relu(a_w1 * alpha_edit + a_b1) + a_b2
        sign = 1 if alpha_edit < 0 else -1
        alpha_gen = sign * torch.log(torch.abs(torch.tensor(alpha_edit)))

        ## first get edit and D(edit(G(z)))
        
        
        edit_attr = attr_hat
        edit_A = transform.get_target_graph(alpha_edit, A_hat, list(Param_train[i][:,-1].type(torch.LongTensor)))  # replace this with the edit(G(z)) attr & filter! Expect do all graphs in batch in one step!!
        # print(alpha_edit, alpha_gen)

#         temp = edit_A.detach().cpu()
#         edit_fil = preprocess_adj_tensor_with_identity(torch.squeeze(temp, -1), symmetric = False).to(device)
#         feature_edit, _ = discriminator(edit_attr.float(), edit_fil.float())


        # Then get G(z + aw) and D(G(z + aw))
    
#         gen_A, gen_attr, gen_A_raw, gen_A_max, gen_A_min = generator(z + alpha_gen * w)
        gen_A, gen_attr, gen_A_raw, gen_A_max, gen_A_min = generator(alpha_gen, z)
#         temp = gen_A.detach().cpu()
#         gen_fil = preprocess_adj_tensor_with_identity(torch.squeeze(temp, -1), symmetric = False).to(device)
#         feature_gen, preds = discriminator(gen_attr.float(), gen_fil.float())
#         labels = torch.ones(edit_attr.shape[0]).to(device)



        if e + 1 == w_epochs:
            w_A_train.append(A)
            w_A_hat_train.append(A_hat)
            w_edit_A_hat_train.append(edit_A)
            w_gen_A_hat_train.append(gen_A.detach())
            gen_A_raw_train.append(gen_A_raw.detach())
            masked_norm_A = masked_normalization(gen_A_raw.detach(), Param_train[i])
            masked_norm_A_hats.append(masked_norm_A)
            gen_A_max_train.append(gen_A_max.detach())
            gen_A_min_train.append(gen_A_min.detach())

print("====================== G(z + aw) v.s. edit(G(z))  results =============================")
gen_A_discretized = debugDiscretizer(w_A_hat_train, w_edit_A_hat_train, gen_A_raw_train, gen_A_max_train, gen_A_min_train, w_gen_A_hat_train, masked_norm_A_hats, discretize_method="hard_threshold", printMatrix=False, abortPickle=True)


#debugDecoder(w_edit_A_hat_train, [], w_gen_A_hat_train, [], discretize_method="hard_threshold", printMatrix=True)
# drawGraph(w_A_train, w_A_hat_train, w_edit_A_hat_train, w_gen_A_hat_train)
# drawGraphSaveFigure(w_A_train, w_A_hat_train, w_edit_A_hat_train, w_gen_A_hat_train, clearImage=True)


  0%|                                                                                            | 0/7 [00:00<?, ?it/s]


TypeError: forward() missing 1 required positional argument: 'z'

In [5]:
density_ori, diameter_ori, cluster_coef_ori, edges_ori, avg_degree_ori = topological_measure(batched_A_hat_discretized) # G(z)
density_ed, diameter_ed, cluster_coef_ed, edges_ed, avg_degree_ed = topological_measure(w_edit_A_hat_train) # edit(G(z), edit_alpha)
density_hat, diameter_hat, cluster_coef_hat, edges_hat, avg_degree_hat = topological_measure([torch.from_numpy(gen_A_discretized)]) # G(z + alpaha * w)
print(f"--- Original topology (averaged) ---\n density: {density_ori} \n diameter: {diameter_ori} "
      f"\n clustering coefficient: {cluster_coef_ori} \n edges: {edges_ori} \n avgerage degree {avg_degree_ori}\n")

print(f"--- Edit topology (averaged) ---\n density: {density_ed} \n diameter: {diameter_ed} "
      f"\n clustering coefficient: {cluster_coef_ed} \n edges: {edges_ed} \n avgerage degree {avg_degree_ed}\n")

print(f"--- Reconstructed topology (averaged) ---\n density: {density_hat} \n diameter: {diameter_hat} "
      f"\n clustering coefficient: {cluster_coef_hat} \n edges: {edges_hat} \n avgerage degree {avg_degree_hat}\n")

--- Original topology (averaged) ---
 density: 0.23133233082706686 
 diameter: -1.0 
 clustering coefficient: 0.40446849661076467 
 edges: 43.95314285714286 
 avgerage degree 0.23133233082706683

--- Edit topology (averaged) ---
 density: 0.3704706766917278 
 diameter: -0.9988571428571429 
 clustering coefficient: 0.5145780875193894 
 edges: 70.38942857142857 
 avgerage degree 0.3704706766917278

--- Reconstructed topology (averaged) ---
 density: 0.2622511278195479 
 diameter: -1.0 
 clustering coefficient: 0.3949161427834552 
 edges: 49.827714285714286 
 avgerage degree 0.2622511278195479



In [34]:
global_norm_A = []
global_edit_A = []
for idx,alpha in enumerate([0.01, 0.1, 0.2, 0.3, 0.4, 0.5]):
    print(f"!!!!!!!!!!!!!!==========================edit_alpha = {alpha} ===================================!!!!!!!!!!!!!!")
    loss_train = []
    w_A_train = []
    w_A_hat_train = []
    w_edit_A_hat_train = []
    w_gen_A_hat_train = []
    gen_A_raw_train = []
    gen_A_max_train = []
    gen_A_min_train = []
    masked_norm_A_hats = []
#     threshold = 0.5 - 0.05 * (idx + 1)
    
    for e in range(w_epochs):
        for i in tqdm(range(len(batched_A_hat_discretized))):

            fil = batched_gcn_filters_from_A_hat[i].float().to(device)
            attr_hat = batched_Attr_hat[i].float().to(device)
            # A_hat = batched_A_hat[i].to(device)
            A_hat = batched_A_hat_discretized[i].to(device)
            A = A_train[i]
            z = batched_z[i].to(device)
            hard_mask = (1 - A_hat).squeeze(-1)

            _, alpha_edit = 0, alpha
        
#             sign = 1 if alpha_edit < 0 else -1
            alpha_gen = -1 * torch.log(torch.abs(torch.tensor(alpha_edit)))
            print(alpha_gen)

            edit_attr = attr_hat
            edit_A = transform.get_target_graph(alpha_edit, A_hat, list(Param_train[i][:,-1].type(torch.LongTensor)))  # replace this with the edit(G(z)) attr & filter! Expect do all graphs in batch in one step!!

#             gen_A, gen_attr, gen_A_raw, gen_A_max, gen_A_min = generator(z + alpha_gen * w)
            gen_A, gen_attr, gen_A_raw, gen_A_max, gen_A_min = generator(alpha_gen, hard_mask, z)

            if e + 1 == w_epochs:
                w_A_train.append(A)
                w_A_hat_train.append(A_hat)
                w_edit_A_hat_train.append(edit_A)
                w_gen_A_hat_train.append(gen_A.detach())
                gen_A_raw_train.append(gen_A_raw.detach())
                masked_norm_A = masked_normalization(gen_A_raw.detach(), Param_train[i])
                masked_norm_A_hats.append(masked_norm_A)
                print(f"Avaerge probability: {torch.mean(masked_norm_A[masked_norm_A > 1e-3])}")
                print(f"Avaerge probability variance: {torch.var(masked_norm_A[masked_norm_A > 1e-3])}")
                gen_A_max_train.append(gen_A_max.detach())
                gen_A_min_train.append(gen_A_min.detach())
                
        global_norm_A.append(masked_norm_A_hats)
        global_edit_A.append(w_edit_A_hat_train)


    print("====================== G(z + aw) v.s. edit(G(z))  results =============================")
    gen_A_discretized = debugDiscretizer(w_A_hat_train, w_edit_A_hat_train, gen_A_raw_train, gen_A_max_train, gen_A_min_train, w_gen_A_hat_train, masked_norm_A_hats, discretize_method="hard_threshold", printMatrix=False, abortPickle=True, threshold=0.45)

    density_ori, diameter_ori, cluster_coef_ori, edges_ori, avg_degree_ori = topological_measure(batched_A_hat_discretized) # G(z)
    density_ed, diameter_ed, cluster_coef_ed, edges_ed, avg_degree_ed = topological_measure(w_edit_A_hat_train) # edit(G(z), edit_alpha)
    density_hat, diameter_hat, cluster_coef_hat, edges_hat, avg_degree_hat = topological_measure([torch.from_numpy(gen_A_discretized)]) # G(z + alpaha * w)
    print(f"--- Original topology (averaged) ---\n density: {density_ori} \n diameter: {diameter_ori} "
          f"\n clustering coefficient: {cluster_coef_ori} \n edges: {edges_ori} \n avgerage degree {avg_degree_ori}\n")

    print(f"--- Edit topology (averaged) ---\n density: {density_ed} \n diameter: {diameter_ed} "
          f"\n clustering coefficient: {cluster_coef_ed} \n edges: {edges_ed} \n avgerage degree {avg_degree_ed}\n")

    print(f"--- Reconstructed topology (averaged) ---\n density: {density_hat} \n diameter: {diameter_hat} "
          f"\n clustering coefficient: {cluster_coef_hat} \n edges: {edges_hat} \n avgerage degree {avg_degree_hat}\n")
    print("\n\n")


  0%|                                                                                            | 0/7 [00:00<?, ?it/s]

tensor(4.6052)


 14%|████████████                                                                        | 1/7 [00:00<00:03,  1.73it/s]

Avaerge probability: 0.5991808176040649
Avaerge probability variance: 0.07136798650026321
tensor(4.6052)


 29%|████████████████████████                                                            | 2/7 [00:01<00:02,  1.77it/s]

Avaerge probability: 0.6157324314117432
Avaerge probability variance: 0.07353327423334122
tensor(4.6052)


 43%|████████████████████████████████████                                                | 3/7 [00:01<00:02,  1.79it/s]

Avaerge probability: 0.6025481224060059
Avaerge probability variance: 0.06679567694664001
tensor(4.6052)


 57%|████████████████████████████████████████████████                                    | 4/7 [00:02<00:01,  1.80it/s]

Avaerge probability: 0.6228896975517273
Avaerge probability variance: 0.06766042113304138
tensor(4.6052)


 71%|████████████████████████████████████████████████████████████                        | 5/7 [00:02<00:01,  1.83it/s]

Avaerge probability: 0.619803786277771
Avaerge probability variance: 0.07104763388633728
tensor(4.6052)


 86%|████████████████████████████████████████████████████████████████████████            | 6/7 [00:03<00:00,  1.89it/s]

Avaerge probability: 0.6192256808280945
Avaerge probability variance: 0.07571703940629959
tensor(4.6052)


100%|████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:03<00:00,  1.91it/s]

Avaerge probability: 0.6169819235801697
Avaerge probability variance: 0.07275968044996262





accuracy: 0.9609164285714216
precision: 0.8719941481555349
recall: 0.8670081318237295
f1 score: 0.8694939921097394


  0%|                                                                                            | 0/7 [00:00<?, ?it/s]

--- Original topology (averaged) ---
 density: 0.2665082706766896 
 diameter: -1.0 
 clustering coefficient: 0.4490918726394994 
 edges: 50.63657142857143 
 avgerage degree 0.2665082706766896

--- Edit topology (averaged) ---
 density: 0.3327473684210505 
 diameter: -0.9991428571428571 
 clustering coefficient: 0.4588762045238918 
 edges: 63.222 
 avgerage degree 0.3327473684210505

--- Reconstructed topology (averaged) ---
 density: 0.3352706766917273 
 diameter: -1.0 
 clustering coefficient: 0.4673617824682325 
 edges: 63.70142857142857 
 avgerage degree 0.3352706766917273




tensor(2.3026)


 14%|████████████                                                                        | 1/7 [00:00<00:03,  1.94it/s]

Avaerge probability: 0.6047631502151489
Avaerge probability variance: 0.0698051005601883
tensor(2.3026)


 29%|████████████████████████                                                            | 2/7 [00:01<00:02,  1.92it/s]

Avaerge probability: 0.6212549209594727
Avaerge probability variance: 0.07037972658872604
tensor(2.3026)


 43%|████████████████████████████████████                                                | 3/7 [00:01<00:02,  1.89it/s]

Avaerge probability: 0.6081829071044922
Avaerge probability variance: 0.06415673345327377
tensor(2.3026)


 57%|████████████████████████████████████████████████                                    | 4/7 [00:02<00:01,  1.85it/s]

Avaerge probability: 0.627957820892334
Avaerge probability variance: 0.0653679296374321
tensor(2.3026)


 71%|████████████████████████████████████████████████████████████                        | 5/7 [00:02<00:01,  1.83it/s]

Avaerge probability: 0.6247804760932922
Avaerge probability variance: 0.06949345767498016
tensor(2.3026)


 86%|████████████████████████████████████████████████████████████████████████            | 6/7 [00:03<00:00,  1.86it/s]

Avaerge probability: 0.6239209771156311
Avaerge probability variance: 0.07304003834724426
tensor(2.3026)


100%|████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:03<00:00,  1.89it/s]

Avaerge probability: 0.6222915649414062
Avaerge probability variance: 0.07006431370973587





accuracy: 0.9574957142857082
precision: 0.8974745462874881
recall: 0.8405118017032981
f1 score: 0.8680596930522559


  0%|                                                                                            | 0/7 [00:00<?, ?it/s]

--- Original topology (averaged) ---
 density: 0.2665082706766896 
 diameter: -1.0 
 clustering coefficient: 0.4490918726394994 
 edges: 50.63657142857143 
 avgerage degree 0.2665082706766896

--- Edit topology (averaged) ---
 density: 0.35350827067669005 
 diameter: -0.9985714285714286 
 clustering coefficient: 0.488731782625577 
 edges: 67.16657142857143 
 avgerage degree 0.35350827067669

--- Reconstructed topology (averaged) ---
 density: 0.33625864661653965 
 diameter: -1.0 
 clustering coefficient: 0.4728694258729137 
 edges: 63.88914285714286 
 avgerage degree 0.33625864661653965




tensor(1.6094)


 14%|████████████                                                                        | 1/7 [00:00<00:03,  1.68it/s]

Avaerge probability: 0.6056696772575378
Avaerge probability variance: 0.06926167756319046
tensor(1.6094)


 29%|████████████████████████                                                            | 2/7 [00:01<00:03,  1.66it/s]

Avaerge probability: 0.6227304339408875
Avaerge probability variance: 0.06939959526062012
tensor(1.6094)


 43%|████████████████████████████████████                                                | 3/7 [00:01<00:02,  1.72it/s]

Avaerge probability: 0.6085306406021118
Avaerge probability variance: 0.06334429234266281
tensor(1.6094)


 57%|████████████████████████████████████████████████                                    | 4/7 [00:02<00:01,  1.72it/s]

Avaerge probability: 0.6286075115203857
Avaerge probability variance: 0.06479281187057495
tensor(1.6094)


 71%|████████████████████████████████████████████████████████████                        | 5/7 [00:02<00:01,  1.74it/s]

Avaerge probability: 0.6250643730163574
Avaerge probability variance: 0.06909868866205215
tensor(1.6094)


 86%|████████████████████████████████████████████████████████████████████████            | 6/7 [00:03<00:00,  1.78it/s]

Avaerge probability: 0.6251127123832703
Avaerge probability variance: 0.0722828358411789
tensor(1.6094)


100%|████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:03<00:00,  1.80it/s]

Avaerge probability: 0.6233156323432922
Avaerge probability variance: 0.06925950944423676





accuracy: 0.948222857142849
precision: 0.9236415813634271
recall: 0.8130273267544136
f1 score: 0.8648117580327768


  0%|                                                                                            | 0/7 [00:00<?, ?it/s]

--- Original topology (averaged) ---
 density: 0.2665082706766896 
 diameter: -1.0 
 clustering coefficient: 0.4490918726394994 
 edges: 50.63657142857143 
 avgerage degree 0.2665082706766896

--- Edit topology (averaged) ---
 density: 0.3739428571428541 
 diameter: -0.9974285714285714 
 clustering coefficient: 0.5133804087437946 
 edges: 71.04914285714285 
 avgerage degree 0.3739428571428541

--- Reconstructed topology (averaged) ---
 density: 0.33666466165413295 
 diameter: -1.0 
 clustering coefficient: 0.47730550020241286 
 edges: 63.96628571428572 
 avgerage degree 0.33666466165413295




tensor(1.2040)


 14%|████████████                                                                        | 1/7 [00:00<00:03,  1.64it/s]

Avaerge probability: 0.6061468720436096
Avaerge probability variance: 0.06891769170761108
tensor(1.2040)


 29%|████████████████████████                                                            | 2/7 [00:01<00:03,  1.60it/s]

Avaerge probability: 0.6237773895263672
Avaerge probability variance: 0.06875244528055191
tensor(1.2040)


 43%|████████████████████████████████████                                                | 3/7 [00:01<00:02,  1.57it/s]

Avaerge probability: 0.6090161204338074
Avaerge probability variance: 0.06283501535654068
tensor(1.2040)


 57%|████████████████████████████████████████████████                                    | 4/7 [00:02<00:01,  1.57it/s]

Avaerge probability: 0.6290555000305176
Avaerge probability variance: 0.06446446478366852
tensor(1.2040)


 71%|████████████████████████████████████████████████████████████                        | 5/7 [00:03<00:01,  1.54it/s]

Avaerge probability: 0.6252957582473755
Avaerge probability variance: 0.06886924803256989
tensor(1.2040)


 86%|████████████████████████████████████████████████████████████████████████            | 6/7 [00:03<00:00,  1.57it/s]

Avaerge probability: 0.625827968120575
Avaerge probability variance: 0.07181653380393982
tensor(1.2040)


100%|████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:04<00:00,  1.59it/s]

Avaerge probability: 0.6241669058799744
Avaerge probability variance: 0.06873751431703568





accuracy: 0.9374292857142807
precision: 0.9399929607239249
recall: 0.7926671565396577
f1 score: 0.8600665992371064


  0%|                                                                                            | 0/7 [00:00<?, ?it/s]

--- Original topology (averaged) ---
 density: 0.2665082706766896 
 diameter: -1.0 
 clustering coefficient: 0.4490918726394994 
 edges: 50.63657142857143 
 avgerage degree 0.2665082706766896

--- Edit topology (averaged) ---
 density: 0.39142706766916996 
 diameter: -0.9977142857142857 
 clustering coefficient: 0.5311423191801565 
 edges: 74.37114285714286 
 avgerage degree 0.39142706766916996

--- Reconstructed topology (averaged) ---
 density: 0.33684210526315556 
 diameter: -1.0 
 clustering coefficient: 0.4798516042827392 
 edges: 64.0 
 avgerage degree 0.33684210526315556




tensor(0.9163)


 14%|████████████                                                                        | 1/7 [00:00<00:03,  1.63it/s]

Avaerge probability: 0.6064980030059814
Avaerge probability variance: 0.06870473176240921
tensor(0.9163)


 29%|████████████████████████                                                            | 2/7 [00:01<00:03,  1.63it/s]

Avaerge probability: 0.624734103679657
Avaerge probability variance: 0.06827861070632935
tensor(0.9163)


 43%|████████████████████████████████████                                                | 3/7 [00:01<00:02,  1.62it/s]

Avaerge probability: 0.6095496416091919
Avaerge probability variance: 0.062475547194480896
tensor(0.9163)


 57%|████████████████████████████████████████████████                                    | 4/7 [00:02<00:01,  1.57it/s]

Avaerge probability: 0.6294769644737244
Avaerge probability variance: 0.06424878537654877
tensor(0.9163)


 71%|████████████████████████████████████████████████████████████                        | 5/7 [00:03<00:01,  1.59it/s]

Avaerge probability: 0.6254751682281494
Avaerge probability variance: 0.0686982199549675
tensor(0.9163)


 86%|████████████████████████████████████████████████████████████████████████            | 6/7 [00:03<00:00,  1.55it/s]

Avaerge probability: 0.6264995336532593
Avaerge probability variance: 0.07147957384586334
tensor(0.9163)


100%|████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:04<00:00,  1.58it/s]

Avaerge probability: 0.6249455213546753
Avaerge probability variance: 0.06835662573575974





accuracy: 0.9265264285714221
precision: 0.9549406577320975
recall: 0.7809836953396978
f1 score: 0.8592460637884607


  0%|                                                                                            | 0/7 [00:00<?, ?it/s]

--- Original topology (averaged) ---
 density: 0.2665082706766896 
 diameter: -1.0 
 clustering coefficient: 0.4490918726394994 
 edges: 50.63657142857143 
 avgerage degree 0.2665082706766896

--- Edit topology (averaged) ---
 density: 0.407442105263155 
 diameter: -0.9977142857142857 
 clustering coefficient: 0.5486660316544731 
 edges: 77.414 
 avgerage degree 0.407442105263155

--- Reconstructed topology (averaged) ---
 density: 0.3388992481202986 
 diameter: -1.0 
 clustering coefficient: 0.48340318274629185 
 edges: 64.39085714285714 
 avgerage degree 0.3388992481202986




tensor(0.6931)


 14%|████████████                                                                        | 1/7 [00:00<00:03,  1.50it/s]

Avaerge probability: 0.606891930103302
Avaerge probability variance: 0.06855131685733795
tensor(0.6931)


 29%|████████████████████████                                                            | 2/7 [00:01<00:03,  1.53it/s]

Avaerge probability: 0.6256409883499146
Avaerge probability variance: 0.06792289763689041
tensor(0.6931)


 43%|████████████████████████████████████                                                | 3/7 [00:01<00:02,  1.51it/s]

Avaerge probability: 0.6101027131080627
Avaerge probability variance: 0.0622207410633564
tensor(0.6931)


 57%|████████████████████████████████████████████████                                    | 4/7 [00:02<00:02,  1.46it/s]

Avaerge probability: 0.6298903226852417
Avaerge probability variance: 0.0641031265258789
tensor(0.6931)


 71%|████████████████████████████████████████████████████████████                        | 5/7 [00:03<00:01,  1.51it/s]

Avaerge probability: 0.6256153583526611
Avaerge probability variance: 0.06856788694858551
tensor(0.6931)


 86%|████████████████████████████████████████████████████████████████████████            | 6/7 [00:03<00:00,  1.54it/s]

Avaerge probability: 0.6271616220474243
Avaerge probability variance: 0.07122446596622467
tensor(0.6931)


100%|████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:04<00:00,  1.56it/s]

Avaerge probability: 0.625690221786499
Avaerge probability variance: 0.06806372851133347





accuracy: 0.9188185714285652
precision: 0.9656926779194589
recall: 0.7794108661044867
f1 score: 0.8626094068348745
--- Original topology (averaged) ---
 density: 0.2665082706766896 
 diameter: -1.0 
 clustering coefficient: 0.4490918726394994 
 edges: 50.63657142857143 
 avgerage degree 0.2665082706766896

--- Edit topology (averaged) ---
 density: 0.42026165413533556 
 diameter: -0.9968571428571429 
 clustering coefficient: 0.560851759804647 
 edges: 79.84971428571428 
 avgerage degree 0.42026165413533556

--- Reconstructed topology (averaged) ---
 density: 0.34321353383458403 
 diameter: -1.0 
 clustering coefficient: 0.4869254868520835 
 edges: 65.21057142857143 
 avgerage degree 0.34321353383458403






In [12]:
torch.nonzero(masked_norm_A)

	nonzero(Tensor input, *, Tensor out)
Consider using one of the following signatures instead:
	nonzero(Tensor input, *, bool as_tuple)


tensor([[  0,   0,   0],
        [  0,   0,   1],
        [  0,   0,   2],
        ...,
        [427,   5,   3],
        [427,   5,   4],
        [427,   5,   5]])

In [18]:
torch.mean(masked_norm_A[masked_norm_A > 1e-3])

tensor(0.4855)

In [11]:
def compare(i1,i2,i3):
    print(global_norm_A[i1][i2][i3].numpy())
    print(global_edit_A[i1][i2][i3].squeeze(-1).numpy().astype(int))

In [17]:
compare(0,0,124)

[[0.89 0.67 0.61 0.58 0.66 0.63 0.32 0.56 0.55 0.49 0.40 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00]
 [0.67 0.86 0.44 0.27 0.76 0.33 0.56 0.47 0.32 0.54 0.14 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00]
 [0.64 0.41 0.56 0.39 0.36 0.43 0.13 0.51 0.39 0.41 0.27 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00]
 [0.64 0.27 0.44 0.92 0.36 0.86 0.28 0.31 0.64 0.14 0.72 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00]
 [0.68 0.76 0.38 0.36 0.83 0.41 0.54 0.42 0.39 0.41 0.24 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00]
 [0.69 0.31 0.49 0.92 0.40 0.87 0.28 0.35 0.66 0.17 0.72 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00]
 [0.32 0.57 0.13 0.24 0.54 0.24 0.60 0.16 0.22 0.29 0.20 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00]
 [0.59 0.40 0.51 0.31 0.36 0.35 0.13 0.47 0.33 0.40 0.20 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00]
 [0.60 0.31 0.43 0.68 0.37 0.66 0.21 0.33 0.53 0.19 0.51 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00]
 [0.48 0.56 0.42 0.14 0.41 0.20 0.26 0.43 0.22 0.51 0.05 0.00 0.00 0.00 0

In [130]:
alpha_gen

tensor(-0.6931)

In [27]:
print(global_norm_A[1][0][124].numpy())

[[0.90 0.69 0.63 0.60 0.68 0.65 0.33 0.58 0.57 0.50 0.42 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00]
 [0.69 0.88 0.44 0.28 0.79 0.33 0.58 0.45 0.33 0.56 0.15 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00]
 [0.65 0.42 0.58 0.40 0.37 0.44 0.14 0.52 0.40 0.41 0.27 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00]
 [0.63 0.28 0.43 0.91 0.38 0.88 0.27 0.32 0.66 0.14 0.75 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00]
 [0.69 0.79 0.38 0.38 0.85 0.42 0.56 0.40 0.39 0.42 0.25 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00]
 [0.68 0.32 0.47 0.92 0.41 0.90 0.27 0.36 0.68 0.17 0.74 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00]
 [0.33 0.59 0.14 0.24 0.56 0.25 0.62 0.15 0.21 0.28 0.20 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00]
 [0.60 0.42 0.52 0.32 0.37 0.36 0.13 0.48 0.34 0.40 0.21 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00]
 [0.60 0.32 0.43 0.68 0.38 0.68 0.21 0.34 0.55 0.20 0.53 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00]
 [0.49 0.57 0.42 0.14 0.42 0.19 0.27 0.42 0.21 0.53 0.05 0.00 0.00 0.00 0

In [8]:
i,j = 0,77
print(global_edit_A[0][i][j].squeeze(-1).numpy().astype(int))
print("========================================================================")
print(global_norm_A[0][i][j].squeeze(-1).numpy())

[[1 0 0 1 1 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0]
 [0 1 1 0 1 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0]
 [0 1 1 0 1 1 0 1 1 0 0 0 0 0 0 0 0 0 0 0]
 [1 0 0 1 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0]
 [1 1 1 1 1 1 0 1 1 1 0 1 0 1 0 1 0 0 0 0]
 [0 0 1 0 1 1 0 0 0 0 0 0 0 1 0 0 0 0 0 0]
 [0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 1 0 1 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0]
 [0 1 1 0 1 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 1 1 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 1 1 0 0 0 1 0 0 0 0]
 [0 0 0 0 1 0 0 0 0 0 1 1 0 0 0 1 0 0 0 0]
 [1 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0]
 [0 0 0 0 1 1 0 0 0 0 0 0 0 1 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0]
 [0 0 0 0 1 0 0 0 0 0 1 1 0 0 0 1 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]]
[[0.67 0.40 0.49 0.60 0.67 0.46 0.56 0.43 0.56 0.32 0.54 0.40 0.56 0.44 0.18 0.54 0.21 0.05 0.00 0.00]
 [0.40 0.61 0.67 0.36 0.65 0.49 0.37

In [20]:
mask = 1 - global_edit_A[0][i][j].squeeze(-1).numpy().astype(int)

print(global_edit_A[1][i][j].squeeze(-1).numpy().astype(int) - global_edit_A[0][i][j].squeeze(-1).numpy().astype(int))
print()
print(mask * (global_norm_A[1][i][j].squeeze(-1).numpy() - global_norm_A[0][i][j].squeeze(-1).numpy()))
print("========================================================================")
print(global_edit_A[2][i][j].squeeze(-1).numpy().astype(int) - global_edit_A[1][i][j].squeeze(-1).numpy().astype(int))
print()
print(mask * (global_norm_A[2][i][j].squeeze(-1).numpy() - global_norm_A[1][i][j].squeeze(-1).numpy()))
print()
print(mask * (global_norm_A[2][i][j].squeeze(-1).numpy() - global_norm_A[0][i][j].squeeze(-1).numpy()))
print("========================================================================")
print(global_edit_A[3][i][j].squeeze(-1).numpy().astype(int) - global_edit_A[2][i][j].squeeze(-1).numpy().astype(int))
print()
print(mask * (global_norm_A[3][i][j].squeeze(-1).numpy() - global_norm_A[2][i][j].squeeze(-1).numpy()))
print()
print(mask * (global_norm_A[3][i][j].squeeze(-1).numpy() - global_norm_A[0][i][j].squeeze(-1).numpy()))

print("========================================================================")
print(global_edit_A[4][i][j].squeeze(-1).numpy().astype(int) - global_edit_A[3][i][j].squeeze(-1).numpy().astype(int))
print()
print(mask * (global_norm_A[4][i][j].squeeze(-1).numpy() - global_norm_A[3][i][j].squeeze(-1).numpy()))
print()
print(mask * (global_norm_A[4][i][j].squeeze(-1).numpy() - global_norm_A[0][i][j].squeeze(-1).numpy()))

print("========================================================================")
print(global_edit_A[5][i][j].squeeze(-1).numpy().astype(int) - global_edit_A[4][i][j].squeeze(-1).numpy().astype(int))
print()
print(mask * (global_norm_A[5][i][j].squeeze(-1).numpy() - global_norm_A[4][i][j].squeeze(-1).numpy()))
print()
print(mask * (global_norm_A[5][i][j].squeeze(-1).numpy() - global_norm_A[0][i][j].squeeze(-1).numpy()))

print("========================================================================")



# print()
# print(global_norm_A[3][i][j].squeeze(-1).numpy() - global_norm_A[2][i][j].squeeze(-1).numpy())

[[0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [1 0 0 1 0 0 0 0 0 1 1 1 0 1 0 1 0 0 0 0]
 [0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 1 0 1 0 1 0 1 0 0 0]
 [0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0]
 [0 0 1 0 0 1 0 1 1 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 1 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]]

[[0.00 0.04 0.05 0.00 0.00 0.05 -0.19 0.05 -0.11 0.03 0.06 0.04 -0.00 0.05 0.02 0.03 0.02 0.01 0.00 0.00]
 [0.04 0.00 0.00 0.04 0.00 0.05 

In [22]:
# global_edit_A[5][i][j].squeeze(-1).numpy().astype(int)

In [26]:
print(global_norm_A[0][i][j].squeeze(-1).numpy())
print()
print(global_norm_A[5][i][j].squeeze(-1).numpy())

[[0.67 0.40 0.49 0.60 0.67 0.46 0.56 0.43 0.56 0.32 0.54 0.40 0.56 0.44 0.18 0.54 0.21 0.05 0.00 0.00]
 [0.40 0.61 0.67 0.36 0.65 0.49 0.37 0.44 0.68 0.30 0.29 0.16 0.32 0.47 0.18 0.22 0.35 0.25 0.00 0.00]
 [0.49 0.67 0.74 0.43 0.71 0.65 0.36 0.51 0.63 0.32 0.42 0.21 0.38 0.54 0.21 0.27 0.36 0.22 0.00 0.00]
 [0.60 0.36 0.43 0.79 0.69 0.40 0.38 0.31 0.39 0.58 0.49 0.26 0.55 0.39 0.11 0.52 0.32 0.07 0.00 0.00]
 [0.60 0.62 0.71 0.69 0.87 0.63 0.29 0.45 0.61 0.63 0.33 0.26 0.43 0.56 0.16 0.29 0.41 0.16 0.00 0.00]
 [0.46 0.49 0.56 0.40 0.83 0.59 0.30 0.33 0.48 0.17 0.24 0.22 0.34 0.58 0.12 0.34 0.17 0.11 0.00 0.00]
 [0.10 0.37 0.36 0.13 0.34 0.19 0.55 0.25 0.47 0.23 0.19 0.03 0.25 0.25 0.15 0.03 0.94 0.62 0.00 0.00]
 [0.43 0.44 0.51 0.31 0.50 0.33 0.40 0.56 0.53 0.32 0.33 0.32 0.17 0.30 0.33 0.33 0.30 0.16 0.00 0.00]
 [0.25 0.54 0.57 0.26 0.51 0.39 0.47 0.37 0.68 0.24 0.24 0.08 0.38 0.48 0.16 0.11 0.49 0.35 0.00 0.00]
 [0.32 0.30 0.32 0.46 0.55 0.18 0.35 0.32 0.37 0.87 0.43 0.18 0.19 0.15 0