In [1]:
from pathlib import Path
from tqdm import tqdm
import tensorflow as tf
import numpy as np
from tensorflow import keras

from graph2tac.tf2.train import TrainingParams
from graph2tac.tf2.predict import Predict
from graph2tac.loader.data_server import DataServer
from graph2tac.tf2.graph_nn_def_batch import make_flat_def_batch_np
from graph2tac.tf2.model import np_to_tensor_def

In [2]:
def run_update_embeddings(predict, d, extra_iterations = 0):
    embeddings = predict.model_wrapper.node_and_def_emb._embedding_layer.embeddings
    emb0 = np.array(predict.model_wrapper.node_and_def_emb._embedding_layer.embeddings.numpy())
    start_def = predict.dataset_consts.base_node_type_num
    emb_def = embeddings[start_def:]
    print("emb_zeros_ori", tf.math.reduce_all(tf.equal(predict.model_wrapper.node_and_def_emb._embedding_layer.embeddings[start_def:], 0.)))
    emb_def.assign(tf.zeros_like(emb_def))
    print("emb_zeros", tf.math.reduce_all(tf.equal(predict.model_wrapper.node_and_def_emb._embedding_layer.embeddings[start_def:], 0.)))
    # recalculate the embeddings by their definitions
    clusters = d.def_cluster_subgraphs(tf_gnn=False)
    #print("Update embeddings")
    for cluster in tqdm(clusters, desc = "Update embedding"):
    #for cluster in clusters:
        nodes_c = np.array(cluster[0])
        root_nums = cluster[2]
        roots = nodes_c[:root_nums]
        nodes_c = nodes_c[root_nums:]
        #print(np.max(nodes_c), "->", roots)
        predict.compute_new_definitions([cluster])
    print("emb_updated_zeros", tf.math.reduce_all(tf.equal(predict.model_wrapper.node_and_def_emb._embedding_layer.embeddings[start_def:], 0.)))
    emb1 = np.array(predict.model_wrapper.node_and_def_emb._embedding_layer.embeddings.numpy())
    print("emb0 == emb1", np.max(np.abs(emb0 - emb1)))
    for i in range(extra_iterations):
        for cluster in tqdm(clusters, desc = "Update embedding{}".format(i+2)):
            predict.compute_new_definitions([cluster])
        emb2 = np.array(predict.model_wrapper.node_and_def_emb._embedding_layer.embeddings.numpy())
        print("emb{} == emb{}".format(i+1, i+2), np.max(np.abs(emb1 - emb2)))
        emb1 = emb2

In [3]:
def check_def_loss(predict, d):
    def squared_norm_diff(
        emb1,  # [batch, dim]
        emb2,  # [batch, dim]
    ):  # -> [batch]
        diff = emb1 - emb2  # [batch, dim]
        norm_squared = tf.math.reduce_sum(tf.square(diff), axis=-1)  # [batch]
        return norm_squared  # [batch]

    def_loss_mean = tf.keras.metrics.Mean()
    @tf.function(input_signature=(predict.model_wrapper.def_input_spec,))
    def def_loss_step(batch):
        def_body_embs, def_id_embs = predict.model_wrapper.model_def(batch, training=False)
        def_loss = squared_norm_diff(def_body_embs.values, def_id_embs.values)
        #print(def_loss.numpy())
        def_loss_mean(def_loss)
    clusters = d.def_cluster_subgraphs(tf_gnn=False)
    for cluster in tqdm(clusters, desc = "Check def loss"):
        flat_batch_np = make_flat_def_batch_np([cluster])
        flat_batch = np_to_tensor_def(flat_batch_np)
        def_loss_step(flat_batch)
    print("Def loss:", def_loss_mean.result().numpy())

In [4]:
def check_predict_accuracy(predict, d, restrict_tactics = False, restrict_globargs = False):
    allowed_model_tactics = list(range(len(d.graph_constants().tactic_index_to_hash)))
    available_global = None
    score = {
        "no_arg" : [0,0],
        "local" : [0,0],
        "global" : [0,0],
        "both" : [0,0],
        "none" : [0,0],
    }
    correct = 0
    total = 0
    print(f"Restrictions: tactics {restrict_tactics}, globargs {restrict_globargs}")
    data = d.data_valid()
    data = tqdm(data, desc = "Valid check")
    for state, action, _ in data:
        if restrict_tactics:
            allowed_model_tactics = [action[0]]
        if restrict_globargs:
            _, args, mask_args = action
            args = args[mask_args]
            arg_is_global = args[:,0] == 1
            arg_vals = args[:,1]
            glob_args = arg_vals[arg_is_global]
            #available_global = np.array([
            #    predict.dataset_consts.global_context.index(x)
            #    for x in sorted(set(glob_args))
            #], dtype = int)
            available_global = np.array(sorted(set(glob_args)), dtype = int)
        ranked_actions, ranked_values = predict.ranked_predictions(
            state, allowed_model_tactics, available_global = available_global,
            tactic_expand_bound=1, total_expand_bound=1,
        )
        real_tactic_label, real_args, mask_args = action
        real_args_valid = real_args[mask_args]
        if len(ranked_actions) > 0: action = ranked_actions[0]
        else: action = None
        arg_types = real_args_valid[:,0]
        arg_values = real_args_valid[:,1]
        ctx_len = len(state[3])
        if arg_types.size == 0: cur_type = "no_arg"
        elif ((arg_types == 0) & (arg_values == ctx_len)).any(): cur_type = "none"
        elif (arg_types == 0).all(): cur_type = "local"
        elif (arg_types == 1).all(): cur_type = "global"
        else: cur_type = "both"
        score_t = score[cur_type]
        score_t[0] += 1
        #print("cur_type:", cur_type)
        #print("real_tactic_label:", real_tactic_label)
        #print("real_args_valid:", real_args_valid)
        #print("action:", action)
        #print()
        if action is not None and real_tactic_label == action[0,0] and (real_args_valid == action[1:]).all():
            score_t[1] += 1
    print("Validation accuracy:")
    for t, score_t in score.items():
        print(f"  {t}: {score_t[1]} / {score_t[0]}")    

In [7]:
dataset_path = "/home/mirek/graph2tac/tests/unit-test-mini-stdlib-v9/dataset"
weights_path = "/home/mirek/graph2tac/tests/unit-test-mini-stdlib-v9/weights_combined/checkpoint__epoch299"
params_path = "/home/mirek/graph2tac/tests/unit-test-mini-stdlib-v9/params_overfit.yaml"

In [8]:
params = TrainingParams.from_yaml_file(Path(params_path))



In [9]:
params

TrainingParams(data_params=DataParams(split_seed=0, cv_fold=0, shuffle_def=True, bfs_option=False, max_subgraph_size=1024), optimizer_params=OptimizerParams(optimizer='adam', learning_rate=None, clipvalue=None, clipnorm=None, global_clipnorm=None, loss_weights=LossWeightParams(tactic_base=1.0, tactic_args=1.0, def_task=100.0, def_id_to_body=1.0, def_body_to_id=1.0), def_task=True), model_params=ModelParams(dataset_consts=None, ignore_definitions=False, normalize_def_embeddings=True, single_edge_type=False, symmetric_edges=True, self_edges=True, total_hops=10, node_dim=128, message_passing_layer='conv2', norm_msgs=False, nonlin_position=None, nonlin_type='relu', residuals=True, residual_dropout=True, residual_norm='layer_norm', final_collapse=True, aggreg_max=False, use_same_graph_nn_weights_for_def_training=True), tf_seed=42, tf_eager=False, enable_op_determinism=False, batch_size=50, num_epochs=300, num_proc=None)

In [10]:
import os
os.environ["G2T_LOG_LEVEL"] = "40"

In [11]:
d = DataServer(
    Path(dataset_path),
    Path('.'),
    cross_valid_fold=params.data_params.cv_fold,
    bfs_option=params.data_params.bfs_option,
    max_subgraph_size=params.data_params.max_subgraph_size,
    split_random_seed=params.data_params.split_seed,
    split=(0,1,0),
    global_args = True,
    num_proc=params.num_proc
)

Data: using 8b0330 for processed data storage
LOADING | 
LOADING | 20
LOADING | c_data call complete
LOADING | building dataset complete in 0.23527908325195312 seconds
LOADING | splitting data to train, valid, test complete in 0.007067203521728516 seconds
LOADING | building def clusters complete 0.017113685607910156 seconds


In [12]:
predict = Predict(Path(weights_path))

In [13]:
check_def_loss(predict, d)

LOADING | requested 482


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 482/482 [00:00<00:00, 12332.50it/s]
Check def loss: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 482/482 [00:28<00:00, 16.66it/s]


Def loss: 0.0009519947


In [14]:
run_update_embeddings(predict, d)

emb_zeros_ori tf.Tensor(False, shape=(), dtype=bool)
emb_zeros tf.Tensor(True, shape=(), dtype=bool)
LOADING | requested 482


Update embedding: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 482/482 [00:27<00:00, 17.22it/s]

emb_updated_zeros tf.Tensor(False, shape=(), dtype=bool)
emb0 == emb1 0.048569106





In [15]:
nt_to_name = predict.dataset_consts.node_type_to_name
emb = np.array(predict.model_wrapper.node_and_def_emb._embedding_layer.embeddings.numpy())

In [16]:
radius = 1e-6
nt_to_eqc = np.arange(emb.shape[0])
for n in range(1, len(nt_to_eqc)):
    dists = np.abs(emb[n] - emb[:n])
    dists = np.max(dists, axis = 1)
    min_dist_n = np.argmin(dists)
    min_dist = dists[min_dist_n]
    if min_dist <= radius:
        nt_to_eqc[n] = nt_to_eqc[min_dist_n]

In [17]:
nt_eq_classes = dict()
for n,c in enumerate(nt_to_eqc):
    if c in nt_eq_classes: nt_eq_classes[c].append(n)
    else: nt_eq_classes[c] = [n]
nt_eq_classes = sorted(nt_eq_classes.values(), key = lambda x: (-len(x), x))
len(nt_eq_classes)

550

In [18]:
for cl in nt_eq_classes:
    if len(cl) > 1: print([nt_to_name[x] for x in cl])

['Coq.Init.Byte.xae', 'Coq.Init.Byte.x9d', 'Coq.Init.Byte.xe4', 'Coq.Init.Byte.x5d', 'Coq.Init.Byte.x98', 'Coq.Init.Byte.x0d', 'Coq.Init.Byte.x85', 'Coq.Init.Byte.x78', 'Coq.Init.Byte.x3f', 'Coq.Init.Byte.x1f', 'Coq.Init.Byte.xdf', 'Coq.Init.Byte.x62', 'Coq.Init.Byte.x65', 'Coq.Init.Byte.x1b', 'Coq.Init.Byte.xee', 'Coq.Init.Byte.xc1', 'Coq.Init.Byte.x48', 'Coq.Init.Byte.xb6', 'Coq.Init.Byte.x49', 'Coq.Init.Byte.x00', 'Coq.Init.Byte.x82', 'Coq.Init.Byte.xd4', 'Coq.Init.Byte.xb8', 'Coq.Init.Byte.xe3', 'Coq.Init.Byte.x41', 'Coq.Init.Byte.x84', 'Coq.Init.Byte.x24', 'Coq.Init.Byte.xbb', 'Coq.Init.Byte.x03', 'Coq.Init.Byte.xf7', 'Coq.Init.Byte.xb0', 'Coq.Init.Byte.x45', 'Coq.Init.Byte.x4c', 'Coq.Init.Byte.xb1', 'Coq.Init.Byte.x89', 'Coq.Init.Byte.x8e', 'Coq.Init.Byte.x08', 'Coq.Init.Byte.xe9', 'Coq.Init.Byte.xa3', 'Coq.Init.Byte.xa8', 'Coq.Init.Byte.x56', 'Coq.Init.Byte.x74', 'Coq.Init.Byte.x33', 'Coq.Init.Byte.x52', 'Coq.Init.Byte.x76', 'Coq.Init.Byte.xba', 'Coq.Init.Byte.x12', 'Coq.Init.By

In [17]:
# TODO: how often an ambiguous definition appear in the dataset?
data = d.data_valid()
for state, action, _ in data:
    # TODO
    pass

In [17]:
# TODO: what if predict global arguments up to this equivalence
# TODO: check this in the standard library
# TODO: top-k accuracy

In [20]:
check_predict_accuracy(predict, d, restrict_tactics = False, restrict_globargs = True)

Restrictions: tactics False, globargs True


Valid check: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1954/1954 [02:39<00:00, 12.22it/s]

Validation accuracy:
  no_arg: 703 / 742
  local: 548 / 557
  global: 578 / 578
  both: 11 / 14
  none: 0 / 63





In [18]:
data = d.data_valid()
for _ in range(927): state, action, _ = next(data)

In [20]:
_, args, mask_args = action
args = args[mask_args]
arg_is_global = args[:,0] == 1
arg_vals = args[:,1]
glob_args = arg_vals[arg_is_global]

In [22]:
glob_args

array([9], dtype=uint32)

In [23]:
predict.dataset_consts.global_context

[29,
 30,
 31,
 32,
 33,
 34,
 35,
 36,
 37,
 38,
 39,
 40,
 41,
 42,
 43,
 44,
 45,
 46,
 47,
 48,
 49,
 50,
 51,
 52,
 53,
 54,
 55,
 56,
 57,
 58,
 59,
 60,
 61,
 62,
 63,
 64,
 65,
 66,
 67,
 68,
 69,
 70,
 71,
 72,
 73,
 74,
 75,
 76,
 77,
 78,
 79,
 80,
 81,
 82,
 83,
 84,
 85,
 86,
 87,
 88,
 89,
 90,
 91,
 92,
 93,
 94,
 95,
 96,
 97,
 98,
 99,
 100,
 101,
 102,
 103,
 104,
 105,
 106,
 107,
 108,
 109,
 110,
 111,
 112,
 113,
 114,
 115,
 116,
 117,
 118,
 119,
 120,
 121,
 122,
 123,
 124,
 125,
 126,
 127,
 128,
 129,
 130,
 131,
 132,
 133,
 134,
 135,
 136,
 137,
 138,
 139,
 140,
 141,
 142,
 143,
 144,
 145,
 146,
 147,
 148,
 149,
 150,
 151,
 152,
 153,
 154,
 155,
 156,
 157,
 158,
 159,
 160,
 161,
 162,
 163,
 164,
 165,
 166,
 167,
 168,
 169,
 170,
 171,
 172,
 173,
 174,
 175,
 176,
 177,
 178,
 179,
 180,
 181,
 182,
 183,
 184,
 185,
 186,
 187,
 188,
 189,
 190,
 191,
 192,
 193,
 194,
 195,
 196,
 197,
 198,
 199,
 200,
 201,
 202,
 203,
 204,
 205,
 206,
 20

In [166]:
allowed_model_tactics = list(range(len(d.graph_constants().tactic_index_to_hash)))
data = d.data_valid()
for i in range(89): state, action, _ = next(data)
real_tactic_label, real_args, mask_args = action
real_args_valid = real_args[mask_args]
arg_types = real_args_valid[:,0]
arg_values = real_args_valid[:,1]
ctx_len = len(state[3])
if arg_types.size == 0: cur_type = "no_arg"
elif ((arg_types == 0) & (arg_values == ctx_len)).any(): cur_type = "none"
elif (arg_types == 0).all(): cur_type = "local"
elif (arg_types == 1).all(): cur_type = "global"
else: cur_type = "both"
print(cur_type)

global


In [233]:
ranked_actions, ranked_values = predict.ranked_predictions(
    state, allowed_model_tactics, available_global = np.array([43]),
    tactic_expand_bound=1, total_expand_bound=1,
)
if len(ranked_actions) > 0: act = ranked_actions[0]
else: act = None
print(act)
if act is not None and real_tactic_label == act[0,0] and (real_args_valid == act[1:]).all():
    print("correct")
else:
    print("incorrect")

[[5 5]
 [1 0]]
incorrect


In [232]:
print(predict.dataset_consts.global_context[43])
print(predict.dataset_consts.global_context[688])

72
717


In [67]:
glob_predicted4 = glob_predicted
print(len(glob_predicted0))
print(len(glob_predicted1))
print(len(glob_predicted2))
print(len(glob_predicted3))
print(len(glob_predicted4))
print(sorted(glob_predicted0)[:20])
print(sorted(glob_predicted1)[:20])
print(sorted(glob_predicted2)[:20])
print(sorted(glob_predicted3)[:20])
print(sorted(glob_predicted4)[:20])
print(sorted(glob_predicted0 & glob_predicted1 & glob_predicted2 & glob_predicted3 & glob_predicted4)[:20])

305
301
312
313
313
[1, 9, 16, 19, 20, 22, 26, 36, 38, 41, 43, 65, 67, 71, 72, 78, 79, 82, 86, 88]
[2, 3, 4, 12, 19, 21, 30, 32, 38, 43, 45, 50, 51, 62, 71, 74, 76, 84, 87, 88]
[4, 13, 17, 20, 22, 31, 44, 48, 49, 51, 54, 58, 60, 61, 75, 81, 86, 88, 89, 90]
[14, 17, 24, 29, 34, 42, 48, 50, 51, 54, 60, 64, 70, 73, 76, 83, 87, 88, 90, 95]
[3, 4, 22, 32, 44, 48, 50, 51, 60, 67, 70, 72, 74, 75, 77, 80, 83, 84, 88, 91]
[88, 519, 524, 526, 528, 529, 530, 531, 532, 533, 540, 544, 548, 550, 551, 553, 556, 558, 559, 561]


In [21]:
available_global = np.array([
    predict.dataset_consts.global_context.index(x)
    for x in sorted(set(glob_args))
], dtype = int)

ValueError: 9 is not in list