# Converting Graphs for L2IT Collaboration

In [5]:
%load_ext autoreload
%autoreload 2

# System imports
import os
import sys
import yaml
import logging

# External imports
import numpy as np
import pandas as pd
# import seaborn as sns
from tqdm import tqdm
import torch_geometric as pyg
from scipy import sparse as sps
import torch
import networkx as nx

# import seaborn as sns
import torch
import warnings

# Silence every Python warning for the rest of the session.
warnings.filterwarnings("ignore")
# Add the directory three levels up to the import path
# (presumably the repository root, so that `notebooks.ITk.utils` resolves — TODO confirm).
sys.path.append("../../..")

# NOTE(review): the star import hides which names this notebook takes from the utils module.
from notebooks.ITk.utils import *
from onetrack import TrackingData
# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
logging.basicConfig(level=logging.INFO)

## Roadmap
TODO:
Construction comparison dataset
- Download uncorrelated dataset from EOS
- Check that these events aren't already in my dataset
- Run through the processing
- Run through embedding
- Run through filter
- Check performance is okay... (or not)
- Convert to NetworkX
- Convert to GraphML

Graphs for training
- Pull in TF conversion code
- Convert filtered graphs to TF records, including edge_index, hit_ids, hit positions, edge features, hit_pids, and signal truth 

## Uncorrelated Dataset Processed to GraphML

1. Run preprocessing and compare two events

In [10]:
# Directory whose entries are listed and converted to GraphML.
input_dir = "/global/cfs/cdirs/m3443/data/ITk-upgrade/processed/filter_processed/0GeV_uncorrelated_large/test"
# Destination directory for the per-event `.graphml` files written by `process_event`.
outdir = "/global/cfs/cdirs/m3443/data/ITk-upgrade/processed/filter_processed/0GeV_uncorrelated_large/graphml"

In [11]:
# Build the full path of every entry found in the input directory
all_files = [
    os.path.join(input_dir, entry_name)
    for entry_name in os.listdir(input_dir)
]

In [12]:
from tqdm.contrib.concurrent import process_map

def process_event(event_file):
    """Convert one saved event graph into a GraphML file under ``outdir``.

    The event is loaded onto the CPU, its ``event_file`` attribute is reduced
    to the final path component, and the graph (carrying the ``hid`` and
    ``pid`` node attributes) is written to ``<outdir>/<event_file>.graphml``.
    Events whose output file already exists are skipped, so the conversion can
    be resumed after an interruption.

    Args:
        event_file: Path of a file readable by ``torch.load`` that holds an
            object with ``event_file``, ``hid`` and ``pid`` attributes
            (presumably a PyG ``Data`` object — TODO confirm).

    Returns:
        None. The result is the GraphML file written to disk.
    """
    # NOTE(review): torch.load unpickles the file; only use it on trusted inputs.
    sample = torch.load(event_file, map_location="cpu")
    sample.event_file = os.path.split(sample.event_file)[-1]

    graphml_path = os.path.join(outdir, f"{sample.event_file}.graphml")
    if os.path.exists(graphml_path):
        return

    # Create the destination on demand; exist_ok keeps this safe when several
    # worker processes reach this line at the same time.
    os.makedirs(outdir, exist_ok=True)

    sample_nx = pyg.utils.to_networkx(sample, node_attrs=["hid", "pid"])
    sample_nx.graph["event_file"] = sample.event_file
    nx.write_graphml(sample_nx, graphml_path)

In [13]:
# Run `process_event` over every listed file with up to 16 workers and a progress bar.
# The returned list is bound to `_` so the cell does not display it.
_ = process_map(process_event, all_files, max_workers=16)

  0%|          | 0/200 [00:00<?, ?it/s]

## Filtered Graphs to TF Records

In [20]:
# Import `graphs` here so this cell also runs on a fresh kernel; without it the
# name is only bound by the import in a later cell.
from graph_nets import graphs

graphs.ALL_FIELDS

('nodes', 'edges', 'receivers', 'senders', 'globals', 'n_node', 'n_edge')

In [32]:
"""
Make doublet GraphNtuple
"""
import tensorflow as tf
from graph_nets import graphs

# TensorFlow dtype used by `parse_tfrec_function` to decode each serialized
# GraphsTuple field; keys are the GraphsTuple field names.
graph_types = {
    'n_node': tf.int32,
    'n_edge': tf.int32,
    'nodes': tf.float32,
    'edges': tf.float32,
    'receivers': tf.int32,
    'senders': tf.int32,
    'globals': tf.float32,
}

def parse_tfrec_function(example_proto):
    """Decode one serialized example into an (input, target) GraphsTuple pair.

    Every GraphsTuple field is stored twice in the record, as a scalar string
    feature named ``<field>_IN`` and ``<field>_OUT``. Each string is parsed
    back into a tensor using the dtype listed for that field in
    ``graph_types``.

    Args:
        example_proto: A serialized ``tf.train.Example``, as produced by
            ``serialize_graph``.

    Returns:
        Tuple ``(input_graph, target_graph)`` of ``graphs.GraphsTuple``.
    """
    feature_spec = {
        field + suffix: tf.io.FixedLenFeature([], tf.string)
        for suffix in ("_IN", "_OUT")
        for field in graphs.ALL_FIELDS
    }
    parsed = tf.io.parse_single_example(example_proto, feature_spec)

    def _decode(suffix):
        # Rebuild one GraphsTuple from the features carrying this suffix.
        return graphs.GraphsTuple(**{
            field: tf.io.parse_tensor(parsed[field + suffix], graph_types[field])
            for field in graphs.ALL_FIELDS
        })

    return _decode("_IN"), _decode("_OUT")


def _bytes_feature(value):
    """Wrap a single string / bytes value in a ``tf.train.Feature``.

    Values of the same type as ``tf.constant(0)`` are unwrapped with
    ``.numpy()`` first, because ``BytesList`` will not unpack a string from an
    EagerTensor.
    """
    eager_tensor_type = type(tf.constant(0))
    if isinstance(value, eager_tensor_type):
        value = value.numpy()
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)


def serialize_graph(G1, G2):
    """Serialize an (input, target) graph pair into one ``tf.train.Example``.

    Each field listed in ``graphs.ALL_FIELDS`` is read from both graphs,
    serialized as a tensor and stored under ``<field>_IN`` (from ``G1``) and
    ``<field>_OUT`` (from ``G2``).

    Args:
        G1: Input graph; must expose every field in ``graphs.ALL_FIELDS``.
        G2: Target graph; same requirement.

    Returns:
        The serialized example as bytes.
    """
    def _encode(graph, field):
        # One GraphsTuple field -> serialized tensor -> bytes feature.
        return _bytes_feature(tf.io.serialize_tensor(getattr(graph, field)))

    feature = {}
    for field in graphs.ALL_FIELDS:
        feature[field + "_IN"] = _encode(G1, field)
        feature[field + "_OUT"] = _encode(G2, field)

    features = tf.train.Features(feature=feature)
    return tf.train.Example(features=features).SerializeToString()


def specs_from_graphs_tuple(
    graphs_tuple_sample, with_batch_dim=False,
    dynamic_num_graphs=False,
    dynamic_num_nodes=True,
    dynamic_num_edges=True,
    description_fn=tf.TensorSpec,
    ):
    """Build a GraphsTuple of tensor descriptions from a sample GraphsTuple.

    For every field the sample's shape and dtype are read and passed to
    ``description_fn``. Dimensions that may vary are replaced by ``None``:

    * with ``with_batch_dim`` the second dimension of every field with a
      non-empty shape is made ``None``;
    * otherwise the first dimension is made ``None`` when
      ``dynamic_num_graphs`` is set (all fields), when ``dynamic_num_nodes``
      is set (the nodes field), or when ``dynamic_num_edges`` is set (the
      edges, senders and receivers fields).

    Fields with an empty shape are left untouched.

    Args:
        graphs_tuple_sample: Sample graph; every field must be non-``None``.
        with_batch_dim: Relax dimension 1 instead of dimension 0.
        dynamic_num_graphs: Relax dimension 0 of every field.
        dynamic_num_nodes: Relax dimension 0 of the nodes field.
        dynamic_num_edges: Relax dimension 0 of the edge-sized fields.
        description_fn: Callable invoked as ``description_fn(shape=, dtype=)``.

    Returns:
        ``graphs.GraphsTuple`` whose fields are the ``description_fn`` results.

    Raises:
        ValueError: If any field of the sample is ``None``.
    """
    edge_sized_fields = (graphs.EDGES, graphs.SENDERS, graphs.RECEIVERS)

    def _first_dim_is_dynamic(name):
        # Decide whether dimension 0 of this field must be left unspecified.
        if dynamic_num_graphs:
            return True
        if dynamic_num_nodes and name == graphs.NODES:
            return True
        return dynamic_num_edges and name in edge_sized_fields

    field_specs = {}
    for field_name in graphs.ALL_FIELDS:
        sample = getattr(graphs_tuple_sample, field_name)
        if sample is None:
            message_template = (
                "The `GraphsTuple` field `{}` was `None`. All fields of the "
                "`GraphsTuple` must be specified to create valid signatures that"
                "work with `tf.function`. This can be achieved with `input_graph = "
                "utils_tf.set_zero_{{node,edge,global}}_features(input_graph, 0)`"
                "to replace None's by empty features in your graph. Alternatively"
                "`None`s can be replaced by empty lists by doing `input_graph = "
                "input_graph.replace({{nodes,edges,globals}}=[]). To ensure "
                "correct execution of the program, it is recommended to restore "
                "the None's once inside of the `tf.function` by doing "
                "`input_graph = input_graph.replace({{nodes,edges,globals}}=None)"
                "")
            raise ValueError(message_template.format(field_name))

        dims = list(sample.shape)
        # An empty shape means the field is a constant; nothing to relax.
        if dims:
            if with_batch_dim:
                dims[1] = None
            elif _first_dim_is_dynamic(field_name):
                dims[0] = None

        field_specs[field_name] = description_fn(shape=dims, dtype=sample.dtype)

    return graphs.GraphsTuple(**field_specs)


def dtype_shape_from_graphs_tuple(input_graph, with_batch_dim=False,
                                  with_padding=True, debug=False, with_fixed_size=False):
    """Extract per-field dtypes and shapes from a GraphsTuple.

    Shapes are copied from ``input_graph`` unchanged unless both
    ``with_padding`` and ``with_fixed_size`` are false. In that case fields
    with a non-empty shape are relaxed: with ``with_batch_dim`` dimension 1
    becomes ``None``; otherwise dimension 0 becomes ``None`` for the nodes,
    edges, senders and receivers fields.

    Args:
        input_graph: Graph whose fields expose ``.shape`` and ``.dtype``.
        with_batch_dim: Relax dimension 1 instead of dimension 0.
        with_padding: When true, keep all shapes exactly as found.
        debug: Print the name, shape and dtype of each field.
        with_fixed_size: When true, keep all shapes exactly as found.

    Returns:
        Tuple ``(dtypes, shapes)`` of two ``graphs.GraphsTuple``: the first
        holds each field's dtype, the second a ``tf.TensorShape`` per field.
    """
    variable_length_fields = (
        graphs.NODES, graphs.EDGES, graphs.SENDERS, graphs.RECEIVERS)
    relax_shapes = not (with_fixed_size or with_padding)

    dtypes_by_field = {}
    shapes_by_field = {}
    for field_name in graphs.ALL_FIELDS:
        sample = getattr(input_graph, field_name)
        dims = list(sample.shape)
        field_dtype = sample.dtype

        if relax_shapes and dims:
            if with_batch_dim:
                dims[1] = None
            elif field_name in variable_length_fields:
                dims[0] = None

        dtypes_by_field[field_name] = field_dtype
        shapes_by_field[field_name] = tf.TensorShape(dims)
        if debug:
            print(field_name, dims, field_dtype)

    return (graphs.GraphsTuple(**dtypes_by_field),
            graphs.GraphsTuple(**shapes_by_field))


In [38]:
# TF Conversion Code
"""
Base class defining the procedure by which the TFRecord data is produced.
"""
import time
import os

import numpy as np
import torch
import tensorflow as tf
from graph_nets import utils_tf
from tqdm import tqdm

def calc_eta(r, z):
    """Return ``-ln(tan(theta / 2))`` where ``theta = arctan2(r, z)``.

    Args:
        r: Scalar or array passed as the first argument of ``np.arctan2``.
        z: Scalar or array passed as the second argument of ``np.arctan2``.

    Returns:
        The value computed elementwise with NumPy.
    """
    half_theta = np.arctan2(r, z) / 2.0
    return -1.0 * np.log(np.tan(half_theta))


class DoubletsDataset(object):
    """Convert saved event graphs into TFRecord files of GraphsTuple pairs.

    Each input event becomes one ``<filename>.rec`` file holding a serialized
    (input graph, target graph) pair. The dtype/shape signature handed to
    ``tf.data.Dataset.from_generator`` is derived from the first processed
    event and reused for all later ones.

    Args:
        num_workers: Stored on the instance; not used by the code in this class.
        with_padding: Forwarded to ``dtype_shape_from_graphs_tuple``. When
            false, the node and edge dimensions of the signature are left
            unspecified so events of different sizes fit. NOTE(review):
            ``make_graph`` does not pad, so with ``with_padding=True`` every
            event must match the shapes of the first one.
        n_graphs_per_evt: Stored on the instance; not used by the code in
            this class.
        overwrite: Re-create output files that already exist.
        edge_name: Key of the ``(2, n_edges)`` edge-index entry in an event.
        truth_name: Key of the per-edge truth entry in an event.
    """

    def __init__(self, num_workers=1, with_padding=False,
                n_graphs_per_evt=1, overwrite=False, edge_name='edge_index',
                truth_name='y'
        ):
        # Signature cache, filled by `_get_signature` on the first event.
        self.input_dtype = None
        self.input_shape = None
        self.target_dtype = None
        self.target_shape = None
        # Honour the constructor arguments (with_padding used to be forced to False
        # and n_graphs_per_evt was dropped).
        self.with_padding = with_padding
        self.n_graphs_per_evt = n_graphs_per_evt
        self.num_workers = num_workers
        self.overwrite = overwrite
        self.edge_name = edge_name
        self.truth_name = truth_name

    @staticmethod
    def get_edge_features(edge_index, nodes):
        """Compute per-edge differences between sender and receiver nodes.

        Columns 0, 1 and 2 of ``nodes`` are used as r, phi and z respectively.
        The phi difference is wrapped into [-1, 1] by shifting by 2
        (presumably phi is stored in units of pi — TODO confirm).

        Args:
            edge_index: Index pair; row 0 selects senders, row 1 receivers.
            nodes: Node feature tensor indexed as ``nodes[index, column]``.

        Returns:
            Stacked (delta_eta, delta_phi, delta_r, delta_z); because of the
            final transpose the feature axis comes first.
        """
        # NOTE(review): calc_eta is NumPy-based while the rest uses torch; confirm
        # the result types are compatible with torch.stack before enabling this.
        edge_delta_eta = calc_eta(nodes[edge_index[0], 0], nodes[edge_index[0], 2]) - calc_eta(nodes[edge_index[1], 0], nodes[edge_index[1], 2])
        edge_delta_phi = nodes[edge_index[0], 1] - nodes[edge_index[1], 1]
        edge_delta_phi = torch.where(edge_delta_phi > 1, edge_delta_phi - 2, edge_delta_phi)
        edge_delta_phi = torch.where(edge_delta_phi < -1, edge_delta_phi + 2, edge_delta_phi)
        edge_delta_r = nodes[edge_index[0], 0] - nodes[edge_index[1], 0]
        edge_delta_z = nodes[edge_index[0], 2] - nodes[edge_index[1], 2]

        edge_features = torch.stack([edge_delta_eta, edge_delta_phi, edge_delta_r, edge_delta_z], dim=1).T
        return edge_features

    def make_graph(self, event, debug=False):
        """Convert the event into an (input, target) pair of graphs_tuples.

        The input graph carries the node features ``event['x']``, a zero
        placeholder of shape ``(n_edges, 1)`` as edge features, and the
        senders/receivers from rows 0 and 1 of the edge index. The target
        graph is a one-node stand-in whose ``edges`` field holds the per-edge
        truth array cast to float32.

        Args:
            event: Mapping-style event with ``'x'``, ``self.edge_name`` and
                ``self.truth_name`` entries.
            debug: Unused.

        Returns:
            A one-element list ``[(input_graph, target_graph)]``.
        """
        edge_name = self.edge_name
        n_nodes = event['x'].shape[0]
        n_edges = event[edge_name].shape[1]
        nodes = event['x']
        # Placeholder edge features; `get_edge_features` could supply real ones.
        edges = np.zeros((n_edges, 1), dtype=np.float32)
        senders = event[edge_name][0, :]
        receivers = event[edge_name][1, :]
        edge_target = event[self.truth_name].numpy().astype(np.float32)

        input_datadict = {
            "n_node": n_nodes,
            "n_edge": n_edges,
            "nodes": nodes,
            "edges": edges,
            "senders": senders,
            "receivers": receivers,
            "globals": np.array([n_nodes], dtype=np.float32)
        }
        # The target declares a single edge but stores the full truth array in
        # `edges`; it only serves as a container for the labels.
        n_edges_target = 1
        target_datadict = {
            "n_node": 1,
            "n_edge": n_edges_target,
            "nodes": np.zeros((1, 1), dtype=np.float32),
            "edges": edge_target,
            "senders": np.zeros((n_edges_target,), dtype=np.int32),
            "receivers": np.zeros((n_edges_target,), dtype=np.int32),
            "globals": np.zeros((1,), dtype=np.float32),
        }
        input_graph = utils_tf.data_dicts_to_graphs_tuple([input_datadict])
        target_graph = utils_tf.data_dicts_to_graphs_tuple([target_datadict])
        return [(input_graph, target_graph)]

    def _get_signature(self, tensors):
        """Cache dtype/shape signatures from the first (input, target) pair.

        Does nothing once both dtype signatures have been set.
        """
        if self.input_dtype and self.target_dtype:
            return

        ex_input, ex_target = tensors[0]
        self.input_dtype, self.input_shape = dtype_shape_from_graphs_tuple(
            ex_input, with_padding=self.with_padding)
        self.target_dtype, self.target_shape = dtype_shape_from_graphs_tuple(
            ex_target, with_padding=self.with_padding)

    def process(self, indir, outdir, max_num_events=None):
        """Write one TFRecord file per event found in ``indir``.

        Args:
            indir: Directory containing the saved events.
            outdir: Directory receiving ``<filename>.rec``; created if missing.
            max_num_events: If truthy, only the first this-many entries of the
                sorted directory listing are processed.

        Raises:
            AssertionError: If an event's ``hid`` is not ``0..N-1`` in order.
        """
        # Sorted so that `max_num_events` selects a reproducible subset.
        files = sorted(os.listdir(indir))
        if max_num_events:
            files = files[:max_num_events]
        os.makedirs(outdir, exist_ok=True)

        ievt = 0
        now = time.time()
        for filename in tqdm(files):
            infile = os.path.join(indir, filename)
            outname = os.path.join(outdir, filename + ".rec")
            if os.path.exists(outname) and not self.overwrite:
                continue
            # Test the file extension, not the whole path, so a directory
            # name containing "npz" cannot select the wrong loader.
            if filename.endswith(".npz"):
                # NOTE(review): the code below uses attribute access and
                # `.numpy()`, which suggests torch-saved events; confirm the
                # npz branch is still supported.
                array = np.load(infile)
            else:
                array = torch.load(infile, map_location='cpu')
            assert (array.hid == torch.arange(len(array.hid))).all(), \
                "hit ids in {} are not 0..N-1 in order".format(infile)
            tensors = self.make_graph(array)

            def generator():
                for G in tensors:
                    yield (G[0], G[1])

            self._get_signature(tensors)
            dataset = tf.data.Dataset.from_generator(
                generator,
                output_types=(self.input_dtype, self.target_dtype),
                output_shapes=(self.input_shape, self.target_shape),
                args=None
            )

            # The context manager closes the writer even if serialization fails.
            with tf.io.TFRecordWriter(outname) as writer:
                for data in dataset:
                    writer.write(serialize_graph(*data))
            ievt += 1

        read_time = time.time() - now
        print("{} added {:,} events, in {:.1f} mins".format(self.__class__.__name__,
            ievt, read_time/60.))


In [39]:
# Base directories for the TFRecord conversion. `inputdir` is the name the loop
# below reads; `outputdir` is kept separate from the `outdir` global used by
# `process_event`, so running this cell does not redirect the GraphML output.
inputdir = "/global/cfs/cdirs/m3443/data/ITk-upgrade/processed/filter_processed/0GeV_v3"
outputdir = "/global/cfs/cdirs/m3443/data/ITk-upgrade/processed/filter_processed/0GeV_v3_TFrecords"

data = DoubletsDataset(num_workers=1, overwrite=True,
                        edge_name="edge_index", truth_name="y")

# Convert each split into its own sub-directory of TFRecord files.
datatypes = ['train', 'val', 'test']
for datatype in datatypes:
    indir = os.path.join(inputdir, datatype)
    outname = os.path.join(outputdir, datatype)
    print("processing files in folder: {}".format(indir))
    os.makedirs(outname, exist_ok=True)
    data.process(indir=indir, outdir=outname)


processing files in folder: /global/cfs/cdirs/m3443/data/ITk-upgrade/processed/filter_processed/0GeV_v3/train


  0%|          | 0/4000 [00:00<?, ?it/s]

nodes [297481, 3] <dtype: 'float32'>
edges [574596, 1] <dtype: 'float32'>
receivers [574596] <dtype: 'int32'>
senders [574596] <dtype: 'int32'>
globals [1, 1] <dtype: 'float32'>
n_node [1] <dtype: 'int32'>
n_edge [1] <dtype: 'int32'>
nodes [1, 1] <dtype: 'float32'>
edges [574596] <dtype: 'float32'>
receivers [1] <dtype: 'int32'>
senders [1] <dtype: 'int32'>
globals [1, 1] <dtype: 'float32'>
n_node [1] <dtype: 'int32'>
n_edge [1] <dtype: 'int32'>


100%|██████████| 4000/4000 [21:08<00:00,  3.15it/s]


DoubletsDataset added 4,000 events, in 21.1 mins
processing files in folder: /global/cfs/cdirs/m3443/data/ITk-upgrade/processed/filter_processed/0GeV_v3/val


100%|██████████| 100/100 [00:28<00:00,  3.50it/s]


DoubletsDataset added 100 events, in 0.5 mins
processing files in folder: /global/cfs/cdirs/m3443/data/ITk-upgrade/processed/filter_processed/0GeV_v3/test


100%|██████████| 50/50 [00:14<00:00,  3.45it/s]

DoubletsDataset added 50 events, in 0.2 mins



