In [19]:
import os, sys, torch
from tqdm import tqdm
sys.path.append('/global/cfs/cdirs/m3443/usr/pmtuan/Tracking-ML-Exa.TrkX')

from Pipelines.TrackML_Example.LightningModules.GNN.Models.interaction_gnn import InteractionGNN
from Pipelines.TrackML_Example.notebooks.build_gnn import GNNInferenceBuilder

In [23]:
class GNNInferenceBuilder:
    """Run a trained GNN over its datasets and attach edge scores to each event.

    Note: this notebook-local definition shadows the ``GNNInferenceBuilder``
    imported from ``build_gnn`` in the imports cell.

    Parameters
    ----------
    model
        Trained model. The code reads ``model.hparams`` (a mapping that must
        contain ``"output_dir"`` and may contain ``"overwrite"``),
        ``model.trainset`` / ``model.valset`` / ``model.testset``,
        ``model.device``, and calls ``model.eval()`` and
        ``model.shared_evaluation(batch, 0, log=False)``.
    overwrite : bool, optional
        Explicit overwrite flag. When ``None`` the value is taken from
        ``model.hparams["overwrite"]`` if present, otherwise ``False``.
        The flag is only stored here; nothing in this class reads it yet.
    """

    def __init__(self, model, overwrite=None):
        self.model = model
        self.output_dir = model.hparams["output_dir"]

        # An explicit argument wins over the hyperparameter, which wins over False.
        if overwrite is not None:
            self.overwrite = overwrite
        elif "overwrite" in model.hparams:
            self.overwrite = model.hparams["overwrite"]
        else:
            self.overwrite = False

        # Prep the directory to produce inference data to
        self.datatypes = ["train", "val", "test"]
        os.makedirs(self.output_dir, exist_ok=True)
        for datatype in self.datatypes:
            os.makedirs(os.path.join(self.output_dir, datatype), exist_ok=True)

    def infer(self, max_events=2):
        """Score events from the train, val and test sets of the model.

        Parameters
        ----------
        max_events : int or None, optional
            Upper bound on the number of events taken from the front of each
            dataset (``dataset[:max_events]``). The default of 2 keeps the
            quick-look behavior this notebook was written with; pass ``None``
            to process every event.

        Raises
        ------
        ValueError
            If ``max_events`` is negative. A negative slice bound would
            silently drop events from the end instead of limiting the count.
        """
        if max_events is not None and max_events < 0:
            raise ValueError(f"max_events must be >= 0 or None, got {max_events}")

        print("Training finished, running inference to filter graphs...")

        # By default, the set of examples propagated through the pipeline will be train+val+test set
        datasets = {
            "train": self.model.trainset,
            "val": self.model.valset,
            "test": self.model.testset,
        }

        self.model.eval()
        with torch.no_grad():
            for datatype, dataset in datasets.items():
                print(f"Building {datatype}")
                for batch in tqdm(dataset[:max_events]):
                    batch = batch.to(self.model.device)
                    batch = self.construct_downstream(batch)
                    self.save_downstream(batch, datatype)

    def construct_downstream(self, batch):
        """Evaluate ``batch`` with the model and store the scores on it.

        The batch is printed before and after scoring (debug output), is
        modified in place by setting ``batch.scores``, and is returned.
        """
        print(batch)

        output = self.model.shared_evaluation(batch, 0, log=False)

        # Keep only the first half of the returned scores. In the recorded
        # notebook output the kept half has the same length as ``edge_index``;
        # presumably shared_evaluation scores every edge in both directions
        # -- TODO confirm against the model implementation.
        scores = output["score"]
        batch.scores = scores[: len(scores) // 2]
        print(batch)
        return batch

    def save_downstream(self, batch, datatype):
        """Placeholder hook: persisting the scored batch is not implemented yet."""

In [16]:
# Restore the trained InteractionGNN from its checkpoint, place it on the GPU
# when one is available (CPU otherwise), and let the model set up its datasets.
load_model = "/global/cfs/cdirs/m3443/usr/pmtuan/Tracking-ML-Exa.TrkX/run/gnn/models/trackml_wobbly-oath-197_version_23.ckpt"

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

model = InteractionGNN.load_from_checkpoint(load_model)
model = model.to(device)
model.setup_data()

Training finished, running inference to filter graphs...
Building train


100%|██████████| 800/800 [00:00<00:00, 46121.03it/s]


Building val


100%|██████████| 100/100 [00:00<00:00, 40148.41it/s]


Building test


100%|██████████| 100/100 [00:00<00:00, 40868.21it/s]


In [24]:
# Wrap the loaded model in the GNNInferenceBuilder defined in this notebook and
# run it over the train/val/test sets; each processed event is printed before
# and after its scores are attached.
graph_scorer = GNNInferenceBuilder(model)
graph_scorer.infer()

Training finished, running inference to filter graphs...
Building train


  0%|          | 0/2 [00:00<?, ?it/s]

Data(x=[125576, 3], pid=[125576], modules=[125576], event_file='/global/cfs/cdirs/m3443/data/trackml-codalab/train_all/event000021000', hid=[125576], pt=[125576], weights=[96834], modulewise_true_edges=[2, 96834], cell_data=[125576, 9], signal_true_edges=[2, 13819], edge_index=[2, 93287], y=[93287], y_pid=[93287])


100%|██████████| 2/2 [00:00<00:00,  2.46it/s]


Data(x=[125576, 3], pid=[125576], modules=[125576], event_file='/global/cfs/cdirs/m3443/data/trackml-codalab/train_all/event000021000', hid=[125576], pt=[125576], weights=[96834], modulewise_true_edges=[2, 96834], cell_data=[125576, 9], signal_true_edges=[2, 13819], edge_index=[2, 93287], y=[93287], y_pid=[93287], scores=[93287])
Data(x=[120844, 3], pid=[120844], modules=[120844], event_file='/global/cfs/cdirs/m3443/data/trackml-codalab/train_all/event000021001', hid=[120844], pt=[120844], weights=[92848], modulewise_true_edges=[2, 92848], cell_data=[120844, 9], signal_true_edges=[2, 13039], edge_index=[2, 88508], y=[88508], y_pid=[88508])
Data(x=[120844, 3], pid=[120844], modules=[120844], event_file='/global/cfs/cdirs/m3443/data/trackml-codalab/train_all/event000021001', hid=[120844], pt=[120844], weights=[92848], modulewise_true_edges=[2, 92848], cell_data=[120844, 9], signal_true_edges=[2, 13039], edge_index=[2, 88508], y=[88508], y_pid=[88508], scores=[88508])
Building val


100%|██████████| 2/2 [00:00<00:00, 14.87it/s]


Data(x=[116176, 3], pid=[116176], modules=[116176], event_file='/global/cfs/cdirs/m3443/data/trackml-codalab/train_all/event000021002', hid=[116176], pt=[116176], weights=[88801], modulewise_true_edges=[2, 88801], cell_data=[116176, 9], signal_true_edges=[2, 13854], edge_index=[2, 84407], y=[84407], y_pid=[84407])
Data(x=[116176, 3], pid=[116176], modules=[116176], event_file='/global/cfs/cdirs/m3443/data/trackml-codalab/train_all/event000021002', hid=[116176], pt=[116176], weights=[88801], modulewise_true_edges=[2, 88801], cell_data=[116176, 9], signal_true_edges=[2, 13854], edge_index=[2, 84407], y=[84407], y_pid=[84407], scores=[84407])
Data(x=[122332, 3], pid=[122332], modules=[122332], event_file='/global/cfs/cdirs/m3443/data/trackml-codalab/train_all/event000021007', hid=[122332], pt=[122332], weights=[94101], modulewise_true_edges=[2, 94101], cell_data=[122332, 9], signal_true_edges=[2, 13026], edge_index=[2, 89287], y=[89287], y_pid=[89287])
Data(x=[122332, 3], pid=[122332], mo

  0%|          | 0/2 [00:00<?, ?it/s]

Data(x=[100511, 3], pid=[100511], modules=[100511], event_file='/global/cfs/cdirs/m3443/data/trackml-codalab/train_all/event000021005', hid=[100511], pt=[100511], weights=[74374], modulewise_true_edges=[2, 74374], cell_data=[100511, 9], signal_true_edges=[2, 10886], edge_index=[2, 63905], y=[63905], y_pid=[63905])


100%|██████████| 2/2 [00:00<00:00, 18.72it/s]

Data(x=[100511, 3], pid=[100511], modules=[100511], event_file='/global/cfs/cdirs/m3443/data/trackml-codalab/train_all/event000021005', hid=[100511], pt=[100511], weights=[74374], modulewise_true_edges=[2, 74374], cell_data=[100511, 9], signal_true_edges=[2, 10886], edge_index=[2, 63905], y=[63905], y_pid=[63905], scores=[63905])
Data(x=[88019, 3], pid=[88019], modules=[88019], event_file='/global/cfs/cdirs/m3443/data/trackml-codalab/train_all/event000021019', hid=[88019], pt=[88019], weights=[63061], modulewise_true_edges=[2, 63061], cell_data=[88019, 9], signal_true_edges=[2, 8365], edge_index=[2, 49670], y=[49670], y_pid=[49670])
Data(x=[88019, 3], pid=[88019], modules=[88019], event_file='/global/cfs/cdirs/m3443/data/trackml-codalab/train_all/event000021019', hid=[88019], pt=[88019], weights=[63061], modulewise_true_edges=[2, 63061], cell_data=[88019, 9], signal_true_edges=[2, 8365], edge_index=[2, 49670], y=[49670], y_pid=[49670], scores=[49670])



