In [1]:
# Locations of the dataset, the model checkpoint and the training
# hyperparameters. Each can be overridden through an environment variable so
# the notebook can run on machines other than the author's; the defaults are
# the paths this notebook was originally run with.
# Alternatives used earlier:
#   dataset: /home/mirek/data/stdlib-lgraph-intermediate-v9-global
#   weights: /home/mirek/graph2tac/graph2tac/tf2/weights/checkpoint__epoch0
#   params:  /home/mirek/graph2tac/graph2tac/tf2/params_fast_train
import os

dataset_path = os.environ.get(
    "G2T_DATASET_PATH",
    "/home/mirek/data-local/stdlib-lgraph-intermediate-v11-global",
)
weights_path = os.environ.get(
    "G2T_WEIGHTS_PATH",
    "/mnt/share/data/weights/vasily/runs/4_16_push_def_task/weights/checkpoint__epoch267",
)
params_path = os.environ.get(
    "G2T_PARAMS_PATH",
    "/home/mirek/graph2tac/graph2tac/tf2/trained_params",
)

In [2]:
from pathlib import Path
from tqdm import tqdm
import tensorflow as tf
import numpy as np
from tensorflow import keras

from graph2tac.tf2.train import TrainingParams
from graph2tac.tf2.predict import Predict
from graph2tac.loader.data_server import DataServer
from graph2tac.tf2.graph_nn_def_batch import make_flat_def_batch_np
from graph2tac.tf2.model import np_to_tensor_def

In [3]:
def run_update_embeddings(predict, d, extra_iterations = 0):
    """Recompute the definition embeddings of `predict` from scratch, in place.

    The rows of the node/definition embedding table from
    `predict.dataset_consts.base_node_type_num` onwards (the definition rows)
    are first zeroed. Then `predict.compute_new_definitions` is called on every
    definition cluster returned by `d.def_cluster_subgraphs`, one cluster at a
    time. Presumably the clusters come in dependency order, so a definition is
    embedded after the ones it uses (confirm in DataServer).

    Sanity checks printed along the way:
      - whether the definition rows are all zero before zeroing, after zeroing,
        and after the recomputation;
      - the max absolute difference between the table before and after.

    Args:
        predict: a `Predict` instance; its embedding table is modified in place.
        d: a `DataServer` providing the definition clusters.
        extra_iterations: number of further full passes over all clusters. After
            each pass the max absolute change w.r.t. the previous pass is printed,
            which shows whether the recomputation has reached a fixed point.
    """
    def current_embeddings():
        # Re-resolve the attribute chain on every read, in case
        # compute_new_definitions swaps out any object along it.
        return predict.model_wrapper.node_and_def_emb._embedding_layer.embeddings

    def def_rows_all_zero():
        return tf.math.reduce_all(tf.equal(current_embeddings()[start_def:], 0.))

    start_def = predict.dataset_consts.base_node_type_num
    embeddings = current_embeddings()
    emb0 = np.array(embeddings.numpy())  # snapshot before any modification
    emb_def = embeddings[start_def:]
    print("emb_zeros_ori", def_rows_all_zero())
    # Sliced assignment writes the zeros back into the embedding variable.
    emb_def.assign(tf.zeros_like(emb_def))
    print("emb_zeros", def_rows_all_zero())

    # Recalculate the embeddings from the definitions, one cluster per call.
    clusters = d.def_cluster_subgraphs(tf_gnn=False)
    for cluster in tqdm(clusters, desc = "Update embedding"):
        predict.compute_new_definitions([cluster])
    print("emb_updated_zeros", def_rows_all_zero())

    # Despite the "==" in the labels, these print the max absolute difference.
    prev = np.array(current_embeddings().numpy())
    print("emb0 == emb1", np.max(np.abs(emb0 - prev)))
    for i in range(extra_iterations):
        for cluster in tqdm(clusters, desc = "Update embedding{}".format(i+2)):
            predict.compute_new_definitions([cluster])
        cur = np.array(current_embeddings().numpy())
        print("emb{} == emb{}".format(i+1, i+2), np.max(np.abs(prev - cur)))
        prev = cur

In [4]:
def check_def_loss(predict, d):
    """Print the mean definition-task loss of `predict` over all def clusters of `d`.

    For every cluster, the definition model (`predict.model_wrapper.model_def`)
    returns two embedding batches:
      - one computed from the definition bodies;
      - one associated with the definition ids.
    The per-definition loss is the squared Euclidean distance between them.
    A Keras `Mean` metric accumulates these losses over all clusters, and the
    final mean is printed.
    """
    loss_tracker = tf.keras.metrics.Mean()

    @tf.function(input_signature=(predict.model_wrapper.def_input_spec,))
    def def_loss_step(batch):
        body_embs, id_embs = predict.model_wrapper.model_def(batch, training=False)
        gap = body_embs.values - id_embs.values                      # [batch, dim]
        per_def_loss = tf.math.reduce_sum(tf.square(gap), axis=-1)   # [batch]
        loss_tracker(per_def_loss)

    for cluster in tqdm(d.def_cluster_subgraphs(tf_gnn=False), desc = "Check def loss"):
        # Each cluster is evaluated as its own one-element batch.
        def_loss_step(np_to_tensor_def(make_flat_def_batch_np([cluster])))
    print("Def loss:", loss_tracker.result().numpy())

In [5]:
def check_predict_accuracy(predict, d, restrict_tactics = False, restrict_globargs = False, global_equiv = None):
    """Print the top-1 prediction accuracy of `predict` on the validation data of `d`.

    Each validation sample is put into a bucket according to the arguments of its
    ground-truth action. For every bucket, the function counts how often the
    top-ranked predicted action matches the ground truth exactly: same tactic
    label and same (type, value) argument rows.

    Buckets (argument type 0 = local, 1 = global):
        no_arg: the ground-truth action has no arguments.
        none:   some local argument value equals len(state[3]). Presumably this
                is a placeholder for an argument the model cannot select; confirm.
        local:  all arguments are local.
        global: all arguments are global.
        both:   a mix of local and global arguments.

    Args:
        predict: a `Predict` instance.
        d: a `DataServer`. Its `data_valid()` yields `(state, action, _)`, where
            `action` unpacks to `(tactic_label, args, mask_args)`.
        restrict_tactics: if True, only the ground-truth tactic is offered to the
            model, so only argument prediction is evaluated.
        restrict_globargs: if True, only the global arguments used by the
            ground-truth action are offered as global candidates.
        global_equiv: optional array indexed by the entries of
            `predict.dataset_consts.global_context` (e.g. `nt_to_eqc`). When given,
            global argument values are replaced by their class id on both sides
            before comparison, so equivalent globals count as correct.

    Prints one `correct / total` line per bucket; returns None.
    """
    all_model_tactics = list(range(len(d.graph_constants().tactic_index_to_hash)))
    # Loop-invariant lookups, hoisted out of the per-sample loop.
    global_ctx = np.array(predict.dataset_consts.global_context)
    equiv = None if global_equiv is None else np.array(global_equiv)

    def canonical_args(args):
        # Return a copy of `args` (rows of (type, value)). When an equivalence
        # map is given, each global value is replaced by its class id. Working
        # on a copy leaves the caller's arrays, including the model's ranked
        # action, unmodified.
        args = np.array(args)
        if equiv is not None:
            is_global = args[:, 0] == 1
            args[is_global, 1] = equiv[global_ctx[args[is_global, 1]]]
        return args

    score = {
        "no_arg" : [0,0],
        "local" : [0,0],
        "global" : [0,0],
        "both" : [0,0],
        "none" : [0,0],
    }
    print(f"Restrictions: tactics {restrict_tactics}, globargs {restrict_globargs}")
    for state, true_action, _ in tqdm(d.data_valid(), desc = "Valid check"):
        real_tactic_label, real_args, mask_args = true_action
        real_args_valid = real_args[mask_args]
        arg_types = real_args_valid[:,0]
        arg_values = real_args_valid[:,1]

        allowed_model_tactics = [real_tactic_label] if restrict_tactics else all_model_tactics
        available_global = None
        if restrict_globargs:
            # Offer only the global arguments the ground-truth action uses.
            available_global = np.array(sorted(set(arg_values[arg_types == 1])), dtype = int)

        ranked_actions, _ranked_values = predict.ranked_predictions(
            state, allowed_model_tactics, available_global = available_global,
            tactic_expand_bound=1, total_expand_bound=1,
        )

        ctx_len = len(state[3])
        if arg_types.size == 0: cur_type = "no_arg"
        elif ((arg_types == 0) & (arg_values == ctx_len)).any(): cur_type = "none"
        elif (arg_types == 0).all(): cur_type = "local"
        elif (arg_types == 1).all(): cur_type = "global"
        else: cur_type = "both"
        score_t = score[cur_type]
        score_t[0] += 1

        if len(ranked_actions) == 0:
            continue  # no prediction counts as a miss
        pred_action = ranked_actions[0]
        # np.array_equal returns False on a shape mismatch. `(a == b).all()`
        # could broadcast instead, e.g. an empty ground truth against one
        # predicted argument, and report a spurious match.
        if real_tactic_label == pred_action[0, 0] and np.array_equal(
            canonical_args(real_args_valid), canonical_args(pred_action[1:])
        ):
            score_t[1] += 1

    print("Validation accuracy:")
    for bucket, (n_correct, n_total) in score.items():
        print(f"  {bucket}: {n_correct} / {n_total}")

In [6]:
# Load the training hyperparameters; their data_params configure the DataServer below.
params = TrainingParams.from_yaml_file(Path(params_path))



In [7]:
# Inspect the loaded hyperparameters.
params

TrainingParams(data_params=DataParams(split_seed=0, split=(8, 1, 1), cv_fold=0, shuffle_def=True, max_subgraph_size=1024, bfs_option=False, restrict_to_spine=False), optimizer_params=OptimizerParams(optimizer='adam', learning_rate=None, clipvalue=None, clipnorm=None, global_clipnorm=None, loss_weights=LossWeightParams(tactic_base=1.0, tactic_args=1.0, def_task=100.0, def_id_to_body=1.0, def_body_to_id=1.0), def_task=True), model_params=ModelParams(dataset_consts=None, ignore_definitions=False, normalize_def_embeddings=True, single_edge_label=False, symmetric_edges=True, self_edges=True, total_hops=10, node_dim=128, message_passing_layer='conv2', norm_msgs=False, nonlin_position=None, nonlin_type='relu', residuals=True, residual_dropout=True, residual_norm='layer_norm', final_collapse=True, aggreg_max=False, use_same_graph_nn_weights_for_def_training=True), tf_seed=42, tf_eager=False, enable_op_determinism=False, batch_size=50, num_epochs=300, num_proc=None)

In [8]:
import os

# Set graph2tac's log level. Presumably this is a numeric threshold; 40 matches
# the stdlib logging ERROR level. Confirm.
# NOTE(review): this runs after the graph2tac imports, and the next cell still
# prints LOADING messages. The variable may be read elsewhere or at import
# time; confirm it has any effect.
os.environ.update({"G2T_LOG_LEVEL": "40"})

In [9]:
# Build the data server over the dataset. split=(0,1,0) replaces the split from
# `params` (8, 1, 1). As the outputs below show (567329 indexed action-outcomes,
# 567329 validation samples), every sample then lands in the validation split,
# so d.data_valid() iterates over the whole dataset.
d = DataServer(
    Path(dataset_path),
    Path('.'),
    split=(0,1,0),
    split_random_seed=params.data_params.split_seed,
    cross_valid_fold=params.data_params.cv_fold,
    bfs_option=params.data_params.bfs_option,
    max_subgraph_size=params.data_params.max_subgraph_size,
    restrict_to_spine = False,
    num_proc=params.num_proc,
)

LOADING | indexing and top sorting bin files in /home/mirek/data-local/stdlib-lgraph-intermediate-v11-global...done.
LOADING | preparing data from 526 files.
LOADING | constructing file reference table...done.
LOADING | indexing all definitions...Indexed 73595 definitions in 0.537760 seconds.
LOADING | indexing all tactical action-outcomes in 526 files...

526it [00:00, 2174.74it/s]


LOADING | Indexed 567329 tactical action-outcomes in 2.300030 seconds.
LOADING | mmaping all capnp files and building data loader hash tables...done in 0.079705 seconds.
LOADING | in def_dependencies: max_subgraph_size=1024 bfs_option=False
LOADING | constructing shallow expansions of all definitions to build the graph of definition dependencies...done in 1.304275 seconds.
LOADING | NOTICE: the graph of definition dependencies should be precomputed and recorded to capnp bin files at the time of generation of the dataset. It is inefficient to recompute this graph every time dataserver is initialized.
LOADING | building strongly connected components (def clusters) in the meta graph of definition dependencies...done in 0.611645 seconds. Constructed 71179 def clusters.
LOADING | DataServer is fully initialized in 4.902103 seconds and is ready to stream.


In [10]:
# Construct the predictor from the checkpoint at weights_path.
predict = Predict(Path(weights_path))

In [11]:
# Finish setting up the predictor. All evaluation cells below run after this
# (presumably it builds the model / embedding tables; confirm).
predict.initialize()

In [12]:
# Definition-task loss with the embeddings as loaded from the checkpoint.
check_def_loss(predict, d)

LOADING | requested 71179


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 71179/71179 [00:07<00:00, 9534.64it/s]
Check def loss: 100%|█████████████████████████████████████████████████████████████████████████████████████████████| 71179/71179 [33:00<00:00, 35.94it/s]


Def loss: 0.00031275675


In [13]:
# Unrestricted top-1 accuracy: the model chooses the tactic and all arguments freely.
check_predict_accuracy(predict, d, restrict_tactics = False, restrict_globargs = False)

Restrictions: tactics False, globargs False


Valid check: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 567329/567329 [13:15:40<00:00, 11.88it/s]

Validation accuracy:
  no_arg: 280336 / 317618
  local: 59487 / 76910
  global: 102502 / 150085
  both: 8948 / 13556
  none: 0 / 9160





In [None]:
# Only the ground-truth tactic and only its ground-truth global arguments are offered.
check_predict_accuracy(predict, d, restrict_tactics = True, restrict_globargs = True)

Restrictions: tactics True, globargs True


Valid check:  68%|████████████████████████████████████████████████████████████▉                            | 388294/567329 [8:24:54<4:18:52, 11.53it/s]

In [None]:
# Only the ground-truth tactic is offered; global arguments are unrestricted.
check_predict_accuracy(predict, d, restrict_tactics = True, restrict_globargs = False)

In [None]:
# The tactic is predicted freely; only the ground-truth global arguments are offered.
check_predict_accuracy(predict, d, restrict_tactics = False, restrict_globargs = True)

In [None]:
# Zero and recompute the definition embeddings in place from the definitions;
# every later cell sees the recomputed embeddings.
run_update_embeddings(predict, d)

In [None]:
# Definition-task loss again, now with the recomputed definition embeddings.
check_def_loss(predict, d)

In [None]:
nt_to_name = predict.dataset_consts.node_type_to_name
# Independent copy of the current embedding table (rows are indexed by node
# type, as nt_to_eqc and the lookups below assume).
emb = predict.model_wrapper.node_and_def_emb._embedding_layer.embeddings.numpy().copy()

In [None]:
# Greedy clustering of embedding rows. Each row joins the class of its nearest
# earlier row (L-infinity distance) if that distance is at most `radius`;
# otherwise it keeps its own index as class id. This chains, so it is not a
# strict "all members within radius" partition. It is quadratic in the number
# of rows.
radius = 1e-6
nt_to_eqc = np.arange(emb.shape[0])
for row in tqdm(range(1, len(nt_to_eqc))):
    linf = np.abs(emb[:row] - emb[row]).max(axis = 1)
    nearest = linf.argmin()
    if linf[nearest] <= radius:
        nt_to_eqc[row] = nt_to_eqc[nearest]

In [None]:
# Equivalence-class id assigned to each embedding row (node type).
nt_to_eqc

In [None]:
# Group node types by class id. The groups are ordered largest first, with
# ties broken by their member lists; the cell shows the number of classes.
groups_by_class = {}
for node_type, eq_class in enumerate(nt_to_eqc):
    groups_by_class.setdefault(eq_class, []).append(node_type)
nt_eq_classes = sorted(groups_by_class.values(), key = lambda members: (-len(members), members))
len(nt_eq_classes)

In [None]:
# Print the names in every class with more than one member, i.e. node types
# whose embeddings are within `radius` of each other.
for members in (cl for cl in nt_eq_classes if len(cl) > 1):
    print([nt_to_name[nt] for nt in members])

In [None]:
# Node-type index of a definition of interest (presumably a member of one of
# the ambiguous classes printed above; confirm).
nt_to_name.index("Coq.Numbers.Cyclic.Abstract.CyclicAxioms.ZnZ.spec_div21")

In [None]:
# Node-type index of a second definition of interest, to be compared with the
# one looked up above.
nt_to_name.index("Coq.Numbers.Cyclic.Abstract.CyclicAxioms.ZnZ.spec_add_carry_c")

In [None]:
# NOTE(review): hardcoded row index, presumably the result of one of the
# nt_to_name.index(...) lookups above; confirm. Shows that row's embedding.
emb[9766]

In [None]:
# NOTE(review): hardcoded row index, presumably the result of the other
# nt_to_name.index(...) lookup above; confirm. Shows that row's embedding.
emb[15741]

In [None]:
# how often an ambiguous definition appear in the dataset?
global_ctx = predict.dataset_consts.global_context
ambig_args = set()
for cl in nt_eq_classes:
    if len(cl) > 1: ambig_args.update(cl)
data = d.data_valid()
is_ambig_per_arg = []
is_ambig_per_tac = []
for state, action, _ in data:
    _, args, mask_args = action
    args = args[mask_args]
    args_t, args_val = args.T
    args_global = args_val[args_t == 1]
    for arg in args_global:
        is_ambig_per_arg.append(global_ctx[arg] in ambig_args)

    ctx_len = len(state[3])
    if ((args_t == 0) & (args_val == ctx_len)).any():
        continue

    if len(args_global) > 0:
        is_ambig_per_tac.append(any(
            global_ctx[arg] in ambig_args
            for arg in args_global
        ))

In [None]:
# Share of ambiguous global arguments, and of tactics with an ambiguous global
# argument. An empty list reports nan without triggering numpy's
# "Mean of empty slice" RuntimeWarning (and without printing the count as 0.0).
for label, flags in (("argument", is_ambig_per_arg), ("tactic", is_ambig_per_tac)):
    hits, total = int(np.sum(flags)), len(flags)
    ratio = hits / total if total else float("nan")
    print(f"ambiguous per {label}: {hits} / {total} = {ratio}")

In [None]:
# TODO: find out why some particular classes (reported by Lasse on Slack) are equal
#  . check whether the embeddings are close to zero
#  . try tweaking epsilon (the `radius` threshold of the clustering)
#  . try random network
#    . try a network without trained definition loss
#  . try proper "hashing"
# TODO:
# . find which ambiguous classes occur the most in the dataset (calculate statistics)
# TODO:
# . check how the drop in accuracy relates to the difference between the original and the new embeddings

# TODO:
# . top-k accuracy
# . hyperparameters:
#   . tactic_expand_bound = 10
#   . total_expand_bound = try 10 / 100 / 1000 ... 10 000 000 000

In [None]:
# Unrestricted accuracy, counting a global argument as correct when it is in the
# same embedding equivalence class (nt_to_eqc) as the ground-truth one.
check_predict_accuracy(predict, d, global_equiv = nt_to_eqc, restrict_tactics = False, restrict_globargs = False)

In [None]:
# Tactic and global arguments restricted to the ground truth; global arguments
# are compared up to the nt_to_eqc equivalence classes.
check_predict_accuracy(predict, d, global_equiv = nt_to_eqc, restrict_tactics = True, restrict_globargs = True)

In [None]:
# Tactic restricted to the ground truth; global arguments are compared up to the
# nt_to_eqc equivalence classes.
check_predict_accuracy(predict, d, global_equiv = nt_to_eqc, restrict_tactics = True, restrict_globargs = False)

In [None]:
# Global candidates restricted to the ground truth; global arguments are compared
# up to the nt_to_eqc equivalence classes.
check_predict_accuracy(predict, d, global_equiv = nt_to_eqc, restrict_tactics = False, restrict_globargs = True)