# Analysis and visualization of hit-graph datasets

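The purpose of this notebook is to analyze the preprocessed hit-graph datasets: the distribution of graph sizes (number of nodes and edges per graph) and the per-graph mean of the edge labels (called "purity" below).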

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt

In [None]:
%matplotlib inline

In [None]:
def get_sample_size(filename):
    """Return the node and edge counts stored in one graph file.

    The file is opened with ``np.load``; the node count is the length of
    the first axis of the ``'X'`` array and the edge count is the length
    of the first axis of the ``'y'`` array.

    Returns a ``(n_nodes, n_edges)`` tuple.
    """
    with np.load(filename) as graph:
        node_count = graph['X'].shape[0]
        edge_count = graph['y'].shape[0]
        return node_count, edge_count

def process_dataset(dir, n_files=None):
    """Collect the node and edge counts of the graph files in a directory.

    Parameters
    ----------
    dir : str
        Directory to scan. (The name shadows the builtin, but it is kept
        so existing keyword callers continue to work.)
    n_files : int, optional
        If given, only the first ``n_files`` directory entries (in
        ``os.listdir`` order) are considered. The cut is applied before
        entries with ``'ID'`` in their name are skipped, so fewer than
        ``n_files`` graphs may be returned.

    Returns
    -------
    n_nodes, n_edges : numpy.ndarray
        One entry per processed file. Both arrays are empty when no file
        qualifies.
    """
    # Read from the directory that was passed in; the previous version
    # ignored the argument and relied on a notebook-level `data_dir` global.
    files = [os.path.join(dir, name) for name in os.listdir(dir)]
    print('%i total files' % len(files))
    if n_files is not None:
        files = files[:n_files]
    # Test the file name only, so that an 'ID' substring somewhere in the
    # directory path cannot filter out every file.
    sizes = [get_sample_size(path) for path in files
             if 'ID' not in os.path.basename(path)]
    # reshape keeps the array two-dimensional even when `sizes` is empty,
    # so the column slicing below cannot fail.
    shapes = np.array(sizes, dtype=int).reshape(-1, 2)
    return shapes[:, 0], shapes[:, 1]

## Small dataset

In [None]:
# Directory containing the hit-graph files to analyze. This is a
# machine-specific absolute path; adjust it to the local data location.
data_dir = "/home/benjamin/xtracker/examples/data/hitgraphs_belle2_vtx"
# Upper bound on the number of directory entries that get processed.
n_files = 1000

In [None]:
# Per-file node and edge counts for (at most) the first n_files entries.
n_nodes, n_edges = process_dataset(data_dir, n_files)

In [None]:
# 2D histogram of graph size: node count on x, edge count on y.
fig, ax = plt.subplots(figsize=(8, 6))

_, _, _, image = ax.hist2d(n_nodes, n_edges)
ax.set_xlabel('Number of graph nodes')
ax.set_ylabel('Number of graph edges')
# Trailing semicolon suppresses the notebook's echo of the return value.
fig.colorbar(image, ax=ax);

## Test

In [None]:
def get_sample_results(filename):
    """Return node count, edge count and mean edge label of one graph file.

    The file is opened with ``np.load``. The node count is the length of
    the first axis of ``'X'``, the edge count is the length of the first
    axis of ``'y'``, and the third value ("purity") is ``y.mean()``.
    NOTE(review): presumably ``y`` holds binary edge labels, which would
    make the mean the fraction of true edges -- confirm against the
    preprocessing code.

    Returns a ``(n_nodes, n_edges, purity)`` tuple.
    """
    with np.load(filename) as graph:
        node_count = graph['X'].shape[0]
        labels = graph['y']
        edge_count, mean_label = labels.shape[0], labels.mean()
    return node_count, edge_count, mean_label

In [None]:
# Upper bound on the number of directory entries used for the purity study
# below (rebinds the notebook-level n_files).
n_files = 100

In [None]:
# Gather (n_nodes, n_edges, purity) for at most n_files entries of data_dir.
files = [os.path.join(data_dir, f) for f in os.listdir(data_dir)]
print('%i total files' % len(files))
if n_files is not None:
    files = files[:n_files]
# Test the file name only, so that an 'ID' substring somewhere in the
# directory path cannot filter out every file.
sample_results = [get_sample_results(f) for f in files
                  if 'ID' not in os.path.basename(f)]
# zip(*[]) yields nothing to unpack, so fall back to three empty tuples
# when no file survived the filter.
if sample_results:
    n_nodes, n_edges, purity = zip(*sample_results)
else:
    n_nodes, n_edges, purity = (), (), ()

In [None]:
# Histogram of the per-file purity values collected above.
fig, ax = plt.subplots(figsize=(8, 6))

ax.hist(purity)
ax.set_xlabel('purity')
ax.set_ylabel('Number of events')
