# NuGraph dataset quick look

This notebook mirrors `scripts/explore_dataset.py` so you can inspect a few events and generate exploratory plots interactively.

In [11]:
from pathlib import Path
from collections import defaultdict
import h5py
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

# Change working directory
os.chdir('/exp/sbnd/app/users/yuhw/nugraph')
print(f"Current working directory: {os.getcwd()}")

%matplotlib inline
plt.rcParams.update({"axes.grid": True,
                     "axes.spines.top": False,
                     "axes.spines.right": False,
                     "figure.dpi": 120})


Current working directory: /exp/sbnd/app/users/yuhw/nugraph


In [12]:
data_path = Path('test/NG2-paper.gnn.keep1.h5')
split = 'test'
limit = 3  # set to None to use every sample in the split
outdir = Path('../plots_notebook')
outdir.mkdir(parents=True, exist_ok=True)
data_path


PosixPath('test/NG2-paper.gnn.keep1.h5')

In [13]:
def decode_bytes(values):
    return [v.decode() if isinstance(v, (bytes, np.bytes_)) else str(v) for v in values]

def load_sample(record):
    return {field: record[field] for field in record.dtype.names}

def semantic_palette(classes):
    cmap = plt.colormaps.get_cmap('tab10', len(classes))
    palette = {'background': '#b3b3b3'}
    for idx, cls in enumerate(classes):
        palette[cls] = mcolors.to_hex(cmap(idx))
    return palette

def plot_semantic(sample_name, planes, semantic_labels, palette, record):
    label_names = ['background'] + semantic_labels
    nplanes = len(planes)
    fig, axes = plt.subplots(1, nplanes, figsize=(4 * nplanes, 4), squeeze=False)
    for ax, plane in zip(axes[0], planes):
        pos = record[f'{plane}/pos']
        y_sem = record[f'{plane}/y_semantic'].astype(int)
        labels = [label_names[val + 1] for val in y_sem]
        colours = [palette[label] for label in labels]
        ax.scatter(pos[:, 0], pos[:, 1], c=colours, s=5, linewidths=0)
        ax.set_title(f'{plane.upper()} plane')
        ax.set_xlabel('proj')
        ax.set_ylabel('drift')
    fig.suptitle(f'Semantic truth – {sample_name}')
    fig.tight_layout()
    fig.savefig(outdir / f'{sample_name}_semantic.png', bbox_inches='tight')
    plt.show()

def plot_instances(sample_name, planes, record):
    nplanes = len(planes)
    fig, axes = plt.subplots(1, nplanes, figsize=(4 * nplanes, 4), squeeze=False)
    for ax, plane in zip(axes[0], planes):
        pos = record[f'{plane}/pos']
        inst = record[f'{plane}/y_instance'].astype(int)
        mask = inst >= 0
        valid = np.unique(inst[mask])
        if len(valid):
            cmap = plt.colormaps.get_cmap('nipy_spectral', len(valid))
            colours = [cmap(np.where(valid == val)[0][0]) if val in valid else '#b3b3b3' for val in inst]
        else:
            colours = ['#b3b3b3'] * len(inst)
        ax.scatter(pos[:, 0], pos[:, 1], c=colours, s=5, linewidths=0)
        vtx_key = f'{plane}/y_vtx'
        if vtx_key in record and record[vtx_key].size:
            vtx = record[vtx_key]
            ax.scatter(vtx[:, 0], vtx[:, 1], marker='*', s=120, c='k')
        ax.set_title(f'{plane.upper()} plane')
        ax.set_xlabel('proj')
        ax.set_ylabel('drift')
    fig.suptitle(f'Instance ids – {sample_name}')
    fig.tight_layout()
    fig.savefig(outdir / f'{sample_name}_instances.png', bbox_inches='tight')
    plt.show()

def update_counts(per_plane_counts, planes, semantic_labels, record):
    for plane in planes:
        y_sem = record[f'{plane}/y_semantic'].astype(int)
        counts = np.bincount(y_sem + 1, minlength=len(semantic_labels) + 1)
        per_plane_counts[plane] += counts

def plot_semantic_summary(split_name, planes, semantic_labels, palette, per_plane_counts):
    label_names = ['background'] + semantic_labels
    indices = np.arange(len(planes))
    bottom = np.zeros(len(planes), dtype=int)
    fig, ax = plt.subplots(figsize=(6, 4))
    for label_idx, label in enumerate(label_names):
        counts = np.array([per_plane_counts[p][label_idx] for p in planes])
        ax.bar(indices, counts, 0.6, bottom=bottom, label=label, color=palette.get(label, '#999999'))
        bottom += counts
    ax.set_xticks(indices, [p.upper() for p in planes])
    ax.set_ylabel('hit count')
    ax.set_title(f'Semantic label distribution – {split_name}')
    ax.legend(fontsize='small', loc='upper right')
    fig.tight_layout()
    fig.savefig(outdir / f'{split_name}_semantic_summary.png', bbox_inches='tight')
    plt.show()


In [14]:
with h5py.File(data_path, 'r') as f:
    planes = decode_bytes(f['planes'][()])
    semantic_labels = decode_bytes(f['semantic_classes'][()])
    palette = semantic_palette(semantic_labels)
    samples = decode_bytes(f['samples'][split][()])
    if limit is not None:
        samples = samples[:limit]
    per_plane_counts = {plane: np.zeros(len(semantic_labels) + 1, dtype=int) for plane in planes}
    for sample in samples:
        record = load_sample(f[f'dataset/{sample}'][()])
        plot_semantic(sample, planes, semantic_labels, palette, record)
        plot_instances(sample, planes, record)
        update_counts(per_plane_counts, planes, semantic_labels, record)
    plot_semantic_summary(split, planes, semantic_labels, palette, per_plane_counts)


TypeError: ColormapRegistry.get_cmap() takes 2 positional arguments but 3 were given