In [1]:
    # Filesystem locations used by the rest of the notebook: the dataset directory,
    # the trained checkpoint and the saved training parameters.
    # NOTE(review): these are absolute, machine-specific paths; adjust them before re-running elsewhere.
    dataset_path = "/home/mirek/data/stdlib-lgraph-intermediate-v9-global"
weights_path = "/home/mirek/graph2tac/graph2tac/tf2/weights/checkpoint__epoch299"
params_path = "/home/mirek/graph2tac/graph2tac/tf2/trained_params"

In [2]:
from pathlib import Path
from tqdm import tqdm
import tensorflow as tf
import numpy as np
from tensorflow import keras

from graph2tac.tf2.train import TrainingParams
from graph2tac.tf2.predict import Predict
from graph2tac.loader.data_server import DataServer
from graph2tac.tf2.graph_nn_def_batch import make_flat_def_batch_np
from graph2tac.tf2.model import np_to_tensor_def

In [3]:
def run_update_embeddings(predict, d, extra_iterations = 0):
    """Zero the definition embeddings of `predict` and recompute them cluster by cluster.

    The rows of the embedding table from index
    `predict.dataset_consts.base_node_type_num` onwards are set to zero in
    place. `predict.compute_new_definitions` is then called once for every
    cluster yielded by `d.def_cluster_subgraphs(tf_gnn=False)`. Diagnostic
    lines are printed before and after each stage.

    Args:
        predict: object exposing `model_wrapper`, `dataset_consts` and
            `compute_new_definitions`.
        d: object exposing `def_cluster_subgraphs`.
        extra_iterations: number of additional full passes over the clusters
            after the first one. After each extra pass the largest absolute
            change of the embedding table relative to the previous pass is
            printed.

    Returns:
        None. The embedding table held by `predict` is modified in place.
    """
    def embedding_table():
        # Re-read the attribute path on every use, exactly as the diagnostics need it.
        return predict.model_wrapper.node_and_def_emb._embedding_layer.embeddings

    start_def = predict.dataset_consts.base_node_type_num

    def definitions_all_zero():
        # Scalar boolean tensor: True iff every entry of the rows from start_def onwards is 0.
        return tf.math.reduce_all(tf.equal(embedding_table()[start_def:], 0.))

    def snapshot():
        # Independent NumPy copy of the whole table, used for before/after comparison.
        return np.array(embedding_table().numpy())

    embeddings = embedding_table()
    emb0 = snapshot()
    emb_def = embeddings[start_def:]
    print("emb_zeros_ori", definitions_all_zero())
    emb_def.assign(tf.zeros_like(emb_def))
    print("emb_zeros", definitions_all_zero())

    # recalculate the embeddings by their definitions
    # NOTE(review): `clusters` is iterated once more per extra iteration, so it
    # must be re-iterable (not a one-shot generator) — TODO confirm.
    clusters = d.def_cluster_subgraphs(tf_gnn=False)
    for cluster in tqdm(clusters, desc = "Update embedding"):
        predict.compute_new_definitions([cluster])
    print("emb_updated_zeros", definitions_all_zero())

    emb1 = snapshot()
    # The printed number is the largest absolute difference (0 would mean "equal").
    print("emb0 == emb1", np.max(np.abs(emb0 - emb1)))
    for i in range(extra_iterations):
        for cluster in tqdm(clusters, desc = "Update embedding{}".format(i+2)):
            predict.compute_new_definitions([cluster])
        emb2 = snapshot()
        print("emb{} == emb{}".format(i+1, i+2), np.max(np.abs(emb1 - emb2)))
        emb1 = emb2

In [4]:
def check_def_loss(predict, d):
    """Print the mean squared distance between definition body and definition id embeddings.

    Every cluster from `d.def_cluster_subgraphs(tf_gnn=False)` is turned into a
    flat batch and passed through `predict.model_wrapper.model_def` with
    `training=False`. The per-row squared Euclidean distances between the two
    returned embedding sets are accumulated in a running mean, which is printed
    at the end. Nothing is returned.
    """
    model_wrapper = predict.model_wrapper
    def_loss_mean = tf.keras.metrics.Mean()

    @tf.function(input_signature=(model_wrapper.def_input_spec,))
    def def_loss_step(batch):
        body_embs, id_embs = model_wrapper.model_def(batch, training=False)
        delta = body_embs.values - id_embs.values  # [batch, dim]
        # Squared L2 norm of every row, fed into the running mean.
        def_loss_mean(tf.math.reduce_sum(tf.square(delta), axis=-1))  # [batch]

    clusters = d.def_cluster_subgraphs(tf_gnn=False)
    for cluster in tqdm(clusters, desc = "Check def loss"):
        def_loss_step(np_to_tensor_def(make_flat_def_batch_np([cluster])))
    print("Def loss:", def_loss_mean.result().numpy())

In [5]:
def check_predict_accuracy(predict, d, restrict_tactics = False, restrict_globargs = False, global_equiv = None):
    """Print how often the top-ranked prediction matches the validation action.

    For every `(state, action, _)` sample of `d.data_valid()` the first action
    returned by `predict.ranked_predictions` (with both expand bounds set to 1)
    is compared with the recorded action. A sample counts as correct when the
    tactic label matches and the argument arrays are equal in shape and content.
    Results are reported separately per argument category:

    * ``no_arg`` - the recorded action has no valid arguments,
    * ``none``   - some local argument has a value equal to the context length,
    * ``local``  - all arguments have type 0,
    * ``global`` - all arguments have type 1,
    * ``both``   - any other mixture.

    Args:
        predict: object exposing `ranked_predictions` and
            `dataset_consts.global_context`.
        d: object exposing `graph_constants()` and `data_valid()`.
        restrict_tactics: if true, only the recorded tactic is allowed.
        restrict_globargs: if true, only the global arguments of the recorded
            action are offered as `available_global`.
        global_equiv: optional array-like indexed by the entries of
            `global_context`. When given, global arguments are mapped through
            it before comparison, so arguments in the same class are equal.

    Returns:
        None. The per-category counts are printed.
    """
    allowed_model_tactics = list(range(len(d.graph_constants().tactic_index_to_hash)))
    available_global = None
    score = {
        "no_arg" : [0,0],
        "local" : [0,0],
        "global" : [0,0],
        "both" : [0,0],
        "none" : [0,0],
    }

    # Loop-invariant lookup tables: built once instead of once per sample.
    if global_equiv is not None:
        global_equiv = np.asarray(global_equiv)
        global_ctx = np.array(predict.dataset_consts.global_context)

    def to_equiv_classes(args):
        # Return a converted copy so the caller's (and the predictor's) arrays stay untouched.
        converted = np.array(args)
        is_global = converted[:, 0] == 1
        converted[is_global, 1] = global_equiv[global_ctx[converted[is_global, 1]]]
        return converted

    print(f"Restrictions: tactics {restrict_tactics}, globargs {restrict_globargs}")
    for state, action, _ in tqdm(d.data_valid(), desc = "Valid check"):
        real_tactic_label, real_args, mask_args = action
        real_args_valid = real_args[mask_args]
        arg_types = real_args_valid[:,0]
        arg_values = real_args_valid[:,1]

        if restrict_tactics:
            allowed_model_tactics = [real_tactic_label]
        if restrict_globargs:
            available_global = np.array(sorted(set(arg_values[arg_types == 1])), dtype = int)

        ranked_actions, ranked_values = predict.ranked_predictions(
            state, allowed_model_tactics, available_global = available_global,
            tactic_expand_bound=1, total_expand_bound=1,
        )
        predicted = ranked_actions[0] if len(ranked_actions) > 0 else None

        ctx_len = len(state[3])
        if arg_types.size == 0: cur_type = "no_arg"
        elif ((arg_types == 0) & (arg_values == ctx_len)).any(): cur_type = "none"
        elif (arg_types == 0).all(): cur_type = "local"
        elif (arg_types == 1).all(): cur_type = "global"
        else: cur_type = "both"
        score_t = score[cur_type]
        score_t[0] += 1

        if predicted is None or real_tactic_label != predicted[0,0]:
            continue
        args_real = real_args_valid
        args_pred = predicted[1:]
        if global_equiv is not None:
            args_real = to_equiv_classes(args_real)
            args_pred = to_equiv_classes(args_pred)
        # array_equal also requires identical shapes: `==` would broadcast
        # (e.g. zero real arguments against one predicted argument compares as
        # an empty, hence "all true", array) or fail on incompatible shapes.
        if np.array_equal(args_real, args_pred):
            score_t[1] += 1

    print("Validation accuracy:")
    for t, score_t in score.items():
        print(f"  {t}: {score_t[1]} / {score_t[0]}")

In [6]:
# Load the training parameters from the YAML file at `params_path`.
params = TrainingParams.from_yaml_file(Path(params_path))



In [7]:
# Display the loaded training parameters.
params

TrainingParams(data_params=DataParams(split_seed=0, cv_fold=0, shuffle_def=True, bfs_option=False, max_subgraph_size=1024), optimizer_params=OptimizerParams(optimizer='adam', learning_rate=None, clipvalue=None, clipnorm=None, global_clipnorm=None, loss_weights=LossWeightParams(tactic_base=1.0, tactic_args=1.0, def_task=100.0, def_id_to_body=1.0, def_body_to_id=1.0), def_task=True), model_params=ModelParams(dataset_consts=None, ignore_definitions=False, normalize_def_embeddings=True, single_edge_type=False, symmetric_edges=True, self_edges=True, total_hops=10, node_dim=128, message_passing_layer='conv2', norm_msgs=False, nonlin_position=None, nonlin_type='relu', residuals=True, residual_dropout=True, residual_norm='layer_norm', final_collapse=True, aggreg_max=False, use_same_graph_nn_weights_for_def_training=True), tf_seed=42, tf_eager=False, enable_op_determinism=False, batch_size=50, num_epochs=300, num_proc=None)

In [8]:
# Set the graph2tac log level through the environment.
# NOTE(review): 40 presumably corresponds to logging.ERROR — TODO confirm. graph2tac is
# already imported at this point, so also confirm the variable is read later than import time.
import os
os.environ["G2T_LOG_LEVEL"] = "40"

In [9]:
# Build the data server over `dataset_path`, taking the fold, BFS option, subgraph size
# and split seed from the loaded training parameters.
d = DataServer(
    Path(dataset_path),
    Path('.'),
    cross_valid_fold=params.data_params.cv_fold,
    bfs_option=params.data_params.bfs_option,
    max_subgraph_size=params.data_params.max_subgraph_size,
    split_random_seed=params.data_params.split_seed,
    split=(0,1,0),  # presumably (train, valid, test) weights, i.e. everything is validation data — TODO confirm
    global_args = True,
    num_proc=params.num_proc
)

Data: using be613e for processed data storage
LOADING | 
LOADING | 526
LOADING | c_data call complete
LOADING | building dataset complete in 34.3323290348053 seconds
LOADING | splitting data to train, valid, test complete in 0.702211856842041 seconds
LOADING | building def clusters complete 0.49122166633605957 seconds


In [10]:
# Create the predictor from the checkpoint at `weights_path`.
predict = Predict(Path(weights_path))

In [11]:
# Definition loss before the definition embeddings are recomputed further down.
check_def_loss(predict, d)

LOADING | requested 16630


100%|██████████████████████████████████████████████████████| 16630/16630 [00:01<00:00, 10686.33it/s]
Check def loss: 100%|█████████████████████████████████████████| 16630/16630 [08:31<00:00, 32.54it/s]


Def loss: 0.0014793068


In [12]:
# Accuracy of the top-ranked prediction with no restriction on tactics or global arguments.
check_predict_accuracy(predict, d, restrict_tactics = False, restrict_globargs = False)

Restrictions: tactics False, globargs False


Valid check: 100%|████████████████████████████████████████| 198207/198207 [2:36:57<00:00, 21.05it/s]

Validation accuracy:
  no_arg: 82501 / 98146
  local: 21475 / 31200
  global: 34901 / 59958
  both: 2469 / 5207
  none: 0 / 3696





In [13]:
# Same evaluation with the tactic fixed to the recorded one and global arguments
# limited to those of the recorded action.
check_predict_accuracy(predict, d, restrict_tactics = True, restrict_globargs = True)

Restrictions: tactics True, globargs True


Valid check: 100%|████████████████████████████████████████| 198207/198207 [2:36:48<00:00, 21.07it/s]

Validation accuracy:
  no_arg: 98146 / 98146
  local: 28425 / 31200
  global: 56920 / 59958
  both: 4590 / 5207
  none: 0 / 3696





In [14]:
# Same evaluation with only the tactic fixed to the recorded one.
check_predict_accuracy(predict, d, restrict_tactics = True, restrict_globargs = False)

Restrictions: tactics True, globargs False


Valid check: 100%|████████████████████████████████████████| 198207/198207 [2:34:04<00:00, 21.44it/s]

Validation accuracy:
  no_arg: 98146 / 98146
  local: 27755 / 31200
  global: 46048 / 59958
  both: 3649 / 5207
  none: 0 / 3696





In [15]:
# Same evaluation with only the global arguments limited to those of the recorded action.
check_predict_accuracy(predict, d, restrict_tactics = False, restrict_globargs = True)

Restrictions: tactics False, globargs True


Valid check: 100%|████████████████████████████████████████| 198207/198207 [2:23:32<00:00, 23.01it/s]

Validation accuracy:
  no_arg: 82501 / 98146
  local: 21793 / 31200
  global: 41250 / 59958
  both: 2959 / 5207
  none: 0 / 3696





In [16]:
# Zero and recompute the definition embeddings. This modifies the embedding table of
# `predict` in place, so every cell executed afterwards sees the recomputed embeddings.
run_update_embeddings(predict, d)

emb_zeros_ori tf.Tensor(False, shape=(), dtype=bool)
emb_zeros tf.Tensor(True, shape=(), dtype=bool)
LOADING | requested 16630


Update embedding: 100%|███████████████████████████████████████| 16630/16630 [08:55<00:00, 31.05it/s]

emb_updated_zeros tf.Tensor(False, shape=(), dtype=bool)
emb0 == emb1 4.303076





In [17]:
# Definition loss again, now with the recomputed definition embeddings.
check_def_loss(predict, d)

LOADING | requested 16630


Check def loss: 100%|█████████████████████████████████████████| 16630/16630 [08:40<00:00, 31.95it/s]


Def loss: 1.821816e-14


In [18]:
# Node-type names, plus a NumPy copy of the current embedding table.
# `emb` is a snapshot: later changes to the model's embeddings are not reflected in it.
nt_to_name = predict.dataset_consts.node_type_to_name
emb = np.array(predict.model_wrapper.node_and_def_emb._embedding_layer.embeddings.numpy())

In [36]:
# Greedy grouping of node types whose embeddings nearly coincide: each row is compared
# with all earlier rows and, if the closest one lies within `radius` in max-abs
# (Chebyshev) distance, it inherits that row's class id. Otherwise it keeps its own index.
radius = 1e-6
nt_to_eqc = np.arange(emb.shape[0])
for n in tqdm(range(1, len(nt_to_eqc))):
    cheb = np.abs(emb[n] - emb[:n]).max(axis = 1)  # distance to every earlier row
    nearest = cheb.argmin()
    if cheb[nearest] <= radius:
        nt_to_eqc[n] = nt_to_eqc[nearest]

100%|████████████████████████████████████████████████████████| 17512/17512 [01:33<00:00, 188.21it/s]


In [37]:
# Display the class id assigned to each node type.
nt_to_eqc

array([    0,     1,     2, ..., 17510, 17511, 17512])

In [38]:
# Invert `nt_to_eqc` into lists of member indices per class, then order the classes
# by decreasing size (ties broken by comparing the member lists). The cell displays the class count.
nt_eq_classes = {}
for n, c in enumerate(nt_to_eqc):
    nt_eq_classes.setdefault(c, []).append(n)
nt_eq_classes = sorted(
    nt_eq_classes.values(),
    key = lambda members: (-len(members), members),
)
len(nt_eq_classes)

15943

In [39]:
# Print the member names of every class that contains more than one node type.
for cl in nt_eq_classes:
    if len(cl) <= 1:
        continue
    print([nt_to_name[idx] for idx in cl])

['Coq.Init.Byte.xae', 'Coq.Init.Byte.x9d', 'Coq.Init.Byte.xe4', 'Coq.Init.Byte.x5d', 'Coq.Init.Byte.x98', 'Coq.Init.Byte.x0d', 'Coq.Init.Byte.x85', 'Coq.Init.Byte.x78', 'Coq.Init.Byte.x3f', 'Coq.Init.Byte.x1f', 'Coq.Init.Byte.xdf', 'Coq.Init.Byte.x62', 'Coq.Init.Byte.x65', 'Coq.Init.Byte.x1b', 'Coq.Init.Byte.xee', 'Coq.Init.Byte.xc1', 'Coq.Init.Byte.x48', 'Coq.Init.Byte.xb6', 'Coq.Init.Byte.x49', 'Coq.Init.Byte.x00', 'Coq.Init.Byte.x82', 'Coq.Init.Byte.xd4', 'Coq.Init.Byte.xb8', 'Coq.Init.Byte.xe3', 'Coq.Init.Byte.x41', 'Coq.Init.Byte.x84', 'Coq.Init.Byte.x24', 'Coq.Init.Byte.xbb', 'Coq.Init.Byte.x03', 'Coq.Init.Byte.xf7', 'Coq.Init.Byte.xb0', 'Coq.Init.Byte.x45', 'Coq.Init.Byte.x4c', 'Coq.Init.Byte.xb1', 'Coq.Init.Byte.x89', 'Coq.Init.Byte.x8e', 'Coq.Init.Byte.x08', 'Coq.Init.Byte.xe9', 'Coq.Init.Byte.xa3', 'Coq.Init.Byte.xa8', 'Coq.Init.Byte.x56', 'Coq.Init.Byte.x74', 'Coq.Init.Byte.x33', 'Coq.Init.Byte.x52', 'Coq.Init.Byte.x76', 'Coq.Init.Byte.xba', 'Coq.Init.Byte.x12', 'Coq.Init.By

In [40]:
# Node-type index of this definition, used below to inspect its embedding.
nt_to_name.index("Coq.Numbers.Cyclic.Abstract.CyclicAxioms.ZnZ.spec_div21")

9766

In [41]:
# Node-type index of a second definition, for comparison with the one above.
nt_to_name.index("Coq.Numbers.Cyclic.Abstract.CyclicAxioms.ZnZ.spec_add_carry_c")

15741

In [47]:
# Embedding row of node type 9766 (the hard-coded index comes from the `spec_div21` lookup above).
emb[9766]

array([ 0.14054923, -0.04473993,  0.08690224, -0.02388918,  0.02586977,
       -0.00946436,  0.04641248,  0.1355759 ,  0.08986658,  0.07203969,
       -0.07628072, -0.03913284,  0.13116364, -0.04516129,  0.20110899,
       -0.20312256,  0.06257287, -0.02988453, -0.03029349, -0.01568686,
       -0.15514576,  0.03228218, -0.01818803,  0.04229537, -0.10691108,
        0.03156586,  0.02886117, -0.01097399,  0.1339615 , -0.07812891,
        0.07290291, -0.06301636, -0.02028232,  0.03604344,  0.04409381,
       -0.01240731,  0.13607772, -0.07737188,  0.02698449, -0.05652228,
        0.07383179,  0.03441898, -0.02360063,  0.01186982,  0.03559583,
        0.00156493, -0.06959265,  0.08477193,  0.11025506, -0.05358714,
        0.08345757, -0.04631991, -0.1370102 , -0.11418301,  0.18243027,
       -0.01983626, -0.18766038,  0.2691925 , -0.02898286, -0.12781386,
       -0.05086587, -0.00190134, -0.11870609,  0.03551312,  0.03886468,
       -0.01524484,  0.00514902, -0.17779914, -0.00921541,  0.18

In [48]:
# Embedding row of node type 15741 (the hard-coded index comes from the `spec_add_carry_c` lookup above).
emb[15741]

array([ 0.14054911, -0.04474   ,  0.08690237, -0.02388951,  0.02586984,
       -0.00946432,  0.0464124 ,  0.13557598,  0.08986656,  0.07203986,
       -0.07628078, -0.03913295,  0.13116376, -0.04516151,  0.20110892,
       -0.20312257,  0.06257293, -0.02988463, -0.03029353, -0.01568667,
       -0.15514562,  0.03228221, -0.01818779,  0.0422956 , -0.10691135,
        0.03156582,  0.02886078, -0.0109741 ,  0.1339612 , -0.07812901,
        0.07290281, -0.06301681, -0.02028219,  0.0360433 ,  0.04409389,
       -0.01240722,  0.13607746, -0.07737202,  0.02698488, -0.05652205,
        0.07383217,  0.03441947, -0.02360078,  0.01186919,  0.0355956 ,
        0.00156488, -0.06959283,  0.0847717 ,  0.11025478, -0.053587  ,
        0.08345757, -0.04632008, -0.13701   , -0.1141829 ,  0.18243031,
       -0.0198363 , -0.18766066,  0.26919252, -0.02898305, -0.12781452,
       -0.05086588, -0.00190127, -0.11870638,  0.03551309,  0.0388649 ,
       -0.01524493,  0.00514954, -0.17779931, -0.00921538,  0.18

In [23]:
# How often does an ambiguous definition appear in the dataset?
# A global argument is "ambiguous" when its node type shares an equivalence class with another one.
global_ctx = predict.dataset_consts.global_context
ambig_args = {member for cl in nt_eq_classes if len(cl) > 1 for member in cl}
data = d.data_valid()
is_ambig_per_arg = []
is_ambig_per_tac = []
for state, action, _ in data:
    _, args, mask_args = action
    args_t, args_val = args[mask_args].T
    args_global = args_val[args_t == 1]
    flags = [global_ctx[arg] in ambig_args for arg in args_global]
    # Per-argument statistic: every global argument counts, whatever the sample looks like.
    is_ambig_per_arg.extend(flags)

    # Per-tactic statistic: skip samples with a local argument equal to the context length,
    # and samples without any global argument.
    has_none_arg = ((args_t == 0) & (args_val == len(state[3]))).any()
    if not has_none_arg and len(args_global) > 0:
        is_ambig_per_tac.append(any(flags))

In [24]:
# Report count, total and fraction of ambiguous cases, per global argument and per tactic.
print(f"ambiguous per argument: {np.sum(is_ambig_per_arg)} / {len(is_ambig_per_arg)} = {np.mean(is_ambig_per_arg)}")
print(f"ambiguous per tactic: {np.sum(is_ambig_per_tac)} / {len(is_ambig_per_tac)} = {np.mean(is_ambig_per_tac)}")

ambiguous per argument: 14127 / 72235 = 0.1955700145358898
ambiguous per tactic: 13055 / 65165 = 0.20033760454231567


In [25]:
# TODO: find out why some particular classes (Lasse on Slack) are equal
#  . check if the embedding is not close to zero
#  . try to tweak epsilon
#  . try random network
#    . try a network without trained definition loss
#  . try proper "hashing"
# TODO:
# . find which ambiguous classes occur the most in the dataset (calculate statistics)
# TODO:
# . check how the drop in accuracy relates to the difference between the original and the new embedding

# TODO:
# . top-k accuracy
# . hyperparameters:
#   . tactic_expand_bound = 10
#   . total_expand_bound = try 10 / 100 / 1000 ... 10 000 000 000

In [26]:
# Unrestricted evaluation where global arguments are compared by their class in `nt_to_eqc`.
# NOTE(review): execution counts in this notebook are non-sequential; confirm which embeddings
# the kernel held when this ran before comparing with the earlier accuracy tables.
check_predict_accuracy(predict, d, global_equiv = nt_to_eqc, restrict_tactics = False, restrict_globargs = False)

Restrictions: tactics False, globargs False


Valid check: 100%|████████████████████████████████████████| 198207/198207 [2:23:18<00:00, 23.05it/s]

Validation accuracy:
  no_arg: 78460 / 98146
  local: 19615 / 31200
  global: 12263 / 59958
  both: 734 / 5207
  none: 0 / 3696





In [30]:
# Class-based comparison with both the tactic and the global arguments restricted to the recorded action.
check_predict_accuracy(predict, d, global_equiv = nt_to_eqc, restrict_tactics = True, restrict_globargs = True)

Restrictions: tactics True, globargs True


Valid check: 100%|████████████████████████████████████████| 198207/198207 [2:29:19<00:00, 22.12it/s]

Validation accuracy:
  no_arg: 98146 / 98146
  local: 27974 / 31200
  global: 49637 / 59958
  both: 3773 / 5207
  none: 0 / 3696





In [28]:
# Class-based comparison with only the tactic restricted to the recorded one.
check_predict_accuracy(predict, d, global_equiv = nt_to_eqc, restrict_tactics = True, restrict_globargs = False)

Restrictions: tactics True, globargs False


Valid check: 100%|████████████████████████████████████████| 198207/198207 [2:11:15<00:00, 25.17it/s]

Validation accuracy:
  no_arg: 98146 / 98146
  local: 27643 / 31200
  global: 17358 / 59958
  both: 1182 / 5207
  none: 0 / 3696





In [29]:
# Class-based comparison with only the global arguments restricted to those of the recorded action.
check_predict_accuracy(predict, d, global_equiv = nt_to_eqc, restrict_tactics = False, restrict_globargs = True)

Restrictions: tactics False, globargs True


Valid check: 100%|████████████████████████████████████████| 198207/198207 [2:10:28<00:00, 25.32it/s]

Validation accuracy:
  no_arg: 78460 / 98146
  local: 19732 / 31200
  global: 32436 / 59958
  both: 2092 / 5207
  none: 0 / 3696



