In [1]:
import os
import sys
import gin
import numpy as np
import pandas as pd

module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

from IPython.core.display import clear_output, display

import matplotlib.pyplot as plt

import logging

logging.getLogger().setLevel(logging.DEBUG)

from eval.event_evaluation import EventEvaluator
from ariadne_v2.transformations import Compose, ConstraintsNormalize, ToCylindrical, DropSpinningTracks, DropShort, \
    DropEmpty

parse_cfg = {
    'csv_params': {
        "sep": '\s+',
        # "nrows": 15000,
        "encoding": 'utf-8',
        "names": ['event', 'x', 'y', 'z', 'station', 'track', 'px', 'py', 'pz', 'X0', 'Y0', 'Z0']
    },

    'input_file_mask': "D:/ariadne-master/output_fake_1000.tsv",
    'events_quantity': '0..200'
}

global_transformer = Compose([
    DropSpinningTracks(),
    DropShort(num_stations=3),
    DropEmpty()
])

import scripts.clean_cache

# to clean cache if needed
# scripts.clean_cache.clean_jit_cache('20d')


from ariadne.graph_net.graph_utils.graph_prepare_utils import to_pandas_graph_from_df, get_pd_line_graph, \
    apply_nodes_restrictions, apply_edge_restriction, construct_output_graph
from ariadne.transformations import Compose, ConstraintsNormalize, ToCylindrical

from ariadne_v2.inference import IModelLoader

import torch

suff_df = ('_p', '_c')
gin.bind_parameter('get_pd_line_graph.restrictions_0', (-2., 2.))
gin.bind_parameter('get_pd_line_graph.restrictions_1', (-0.03, 0.03))
gin.bind_parameter('get_pd_line_graph.suffix_c', '_c')
gin.bind_parameter('get_pd_line_graph.suffix_p', '_p')
gin.bind_parameter('get_pd_line_graph.spec_kwargs', {'suffix_c': '_c',
                                                     'suffix_p': '_p',
                                                     'axes': ['r', 'phi', 'z']})
_edge_restriction = 0.15


class GraphModelLoader(IModelLoader):
    def __call__(self):
        from ariadne.graph_net.model import GraphNet_v1
        import torch

        gin.bind_parameter('GraphNet_v1.input_dim', 5)
        gin.bind_parameter('GraphNet_v1.hidden_dim', 128)
        gin.bind_parameter('GraphNet_v1.n_iters', 1)

        def weights_update_g(model, checkpoint):
            model_dict = model.state_dict()
            pretrained_dict = checkpoint['state_dict']
            real_dict = {}
            for (k, v) in model_dict.items():
                needed_key = None
                for pretr_key in pretrained_dict:
                    if k in pretr_key:
                        needed_key = pretr_key
                        break
                assert needed_key is not None, "key %s not in pretrained_dict %r!" % (k, pretrained_dict.keys())
                real_dict[k] = pretrained_dict[needed_key]

            model.load_state_dict(real_dict)
            model.eval()
            return model

        path_g =  'D:/ariadne-master/lightning_logs/GraphNet_v1/version_120/epoch=19-step=3999.ckpt'

        checkpoint_g = torch.load(path_g) if torch.cuda.is_available() else torch.load(path_g,
                                                                                       map_location=torch.device('cpu'))
        model_g = weights_update_g(model=GraphNet_v1(),
                                   checkpoint=checkpoint_g)
        model_hash = {"path_g": path_g, 'gin': gin.config_str(), 'model': '%r' % model_g, 'edge': _edge_restriction}
        return model_hash, model_g


from collections import namedtuple

GraphWithIndices = namedtuple('Graph', ['X', 'Ri', 'Ro', 'y', 'v1v2v3', 'ev_id'])


transformer_g = Compose([
    DropSpinningTracks(),
    DropShort(),
    DropEmpty(),
    ToCylindrical(),
    ConstraintsNormalize(
        columns=('r', 'phi', 'z'),
        constraints={'r': [269., 581.], 'phi': [-3.15, 3.15], 'z': [-2386.0, 2386.0]},
        use_global_constraints=True
    ),
])

def construct_graph_with_indices(graph, v1v2v3, ev_id):
    return GraphWithIndices(graph.X, graph.Ri, graph.Ro, graph.y, v1v2v3, ev_id)


def get_graph(event):
    event = event[['event', 'x', 'y', 'z', 'station', 'track', 'index_old']]

    try:
        event = transformer_g(event)
    except AssertionError as err:
        print("ASS error %r" % err)
        return None

    event.index = event['index_old'].values
    event = event[['event', 'r', 'phi', 'z', 'station', 'track']]

    G = to_pandas_graph_from_df(event, suffixes=suff_df, compute_is_true_track=True)

    nodes_t, edges_t = get_pd_line_graph(G, apply_nodes_restrictions)

    edges_filtered = apply_edge_restriction(edges_t, edge_restriction=_edge_restriction)
    graph = construct_output_graph(nodes_t, edges_filtered, ['y_p', 'y_c', 'z_p', 'z_c', 'z'],
                                   [np.pi, np.pi, 1., 1., 1.], 'edge_index_p', 'edge_index_c')
    ev_id = event.event.values[0]
    graph_with_inds = construct_graph_with_indices(graph,
                                                   edges_filtered[['from_ind', 'cur_ind', 'to_ind']].values, ev_id)

    return graph_with_inds


from ariadne.graph_net.dataset import collate_fn

N_STATIONS = 35
from timeit import default_timer as timer
def eval_event(tgt_graph, model_g):
    batch_input, batch_target = collate_fn([tgt_graph])
    with torch.no_grad():
        y_pred = model_g(batch_input['inputs']).numpy().flatten() > 0.15

    v1v2v3 = tgt_graph.v1v2v3 [ y_pred ]

    def find_next( arr, processed):
        ind = np.where(np.all(v1v2v3[:, :-1] == arr[-2:], axis=1))
        v_next = v1v2v3[ind]
        if len(v_next) > 0: return find_next(np.append(arr, v_next[0][-1]), np.append(processed, ind))
        else: return (np.array([arr]), processed)

    v_all = None

    while v1v2v3.size != 0:

        v, processed = find_next(v1v2v3[0], np.array([0]))

        v1v2v3 = np.delete(v1v2v3, processed, 0)

        if v.size != N_STATIONS: continue

        if v_all is None: v_all = v
        else:
            v_all = np.concatenate((v_all, v))

    eval_df = pd.DataFrame( v_all,columns=[f"hit_id_{n}" for n in range(1,N_STATIONS+1)])

    eval_df[ [f"hit_id_{n}" for n in range(1,N_STATIONS+1)]] = v_all

    return eval_df


evaluator = EventEvaluator(parse_cfg, global_transformer, N_STATIONS)
events = evaluator.prepare(model_loader=GraphModelLoader())[0]
all_results = evaluator.build_all_tracks()
model_results = evaluator.run_model(get_graph, eval_event)
results_graphnet = evaluator.solve_results(model_results, all_results)

  from IPython.core.display import clear_output, display
  from IPython.core.display import clear_output, display
Exception when trying to get git hash... bad!


read entry c4bab4322e01468ee8e99343344a7a02 hit
[prepare]: started processing a df output_fake_1000.tsv with 45652 rows:
read entry 27d82359f97e6ab3f74db5dedb1771fd hit
[prepare] finished
[prepare] loading your model(s)...
[prepare] finished loading your model(s)...
[build_all_tracks] start
read entry 55bb202a014c463f0ccc979712d5f0ea hit
read entry ce07df513497c2c64c12b6f3bc0d0c01 hit


processed: 199: 100%|████████████████████████████████████████████████████████████████| 200/200 [00:10<00:00, 19.45it/s]
[build_all_tracks] cache miss, finish
[run model] start
read entry 9032c58e45262bcf5beaf10fb491317e hit
read entry 01b5aa180fccf188525f1326806d4fa7 hit


processed: 168:  84%|██████████████████████████████████████████████████████          | 169/200 [12:51<02:00,  3.88s/it]

DropEmpty returned empty data. Skipping all further transforms


got exception for preprocessing:
 message=Traceback (most recent call last):
  File "D:\ariadne_master_clean\ariadne-master\eval\event_evaluation.py", line 122, in run_model
    preprocess_result = model_preprocess_func(event_df)
  File "C:\Users\joel\AppData\Local\Temp\ipykernel_11428\4143196359.py", line 133, in get_graph
    event = event[['event', 'r', 'phi', 'z', 'station', 'track']]
  File "D:\miniconda\envs\ariadne_cpu\lib\site-packages\pandas\core\frame.py", line 3511, in __getitem__
    indexer = self.columns._get_indexer_strict(key, "columns")[1]
  File "D:\miniconda\envs\ariadne_cpu\lib\site-packages\pandas\core\indexes\base.py", line 5782, in _get_indexer_strict
    self._raise_if_missing(keyarr, indexer, axis_name)
  File "D:\miniconda\envs\ariadne_cpu\lib\site-packages\pandas\core\indexes\base.py", line 5845, in _raise_if_missing
    raise KeyError(f"{not_found} not in index")
KeyError: "['r', 'phi'] not in index"
 
                                            on 
event_id