In [1]:
import torch 
import yaml, os
import matplotlib, seaborn as sns
from torch_geometric.data import HeteroData, Dataset
from pytorch_lightning import LightningModule
from itertools import combinations_with_replacement, product
from typing import Dict, Optional
from torch import Tensor
from torch_geometric.typing import Adj, EdgeType, NodeType
import pdb
import numpy as np
import sys 
sys.path.append('/global/cfs/cdirs/m3443/usr/pmtuan/Tracking-ML-Exa.TrkX')
%matplotlib inline

CONFIGFILE = "hetero_reg3_lev3.yaml"
# Experiment configuration for the whole notebook.
# yaml.full_load(stream) is shorthand for yaml.load(stream, Loader=yaml.FullLoader).
with open(CONFIGFILE, 'r') as f:
    config = yaml.full_load(f)
# data_dir = config['input_dir'] + '/train'
# data = torch.load(os.path.join(data_dir, os.listdir(data_dir)[0]))

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from utils import load_dataset, LargeDataset, background_cut_event, make_mlp, process_data, convert_triplet_graph
import torch_geometric.transforms as T
import random

In [3]:
from functools import partial

def get_region(model):
    """Return the node-type (region) name for a model spec.

    Args:
        model: mapping with a ``'volume_ids'`` iterable.

    Returns:
        ``'volume_'`` followed by the volume ids joined with ``'_'``,
        e.g. ``{'volume_ids': [0, 1]}`` -> ``'volume_0_1'``.
    """
    volume_tag = '_'.join(map(str, model['volume_ids']))
    return f'volume_{volume_tag}'

def process_data(events, pt_background_cut, pt_signal_cut, noise, triplets, input_cut):
    """Apply the optional background cut and recompute edge truth labels.

    Args:
        events: one event, or a list of events. Each event must expose ``pid``,
            ``edge_index`` and ``signal_true_edges``, and optionally ``scores``.
        pt_background_cut: threshold forwarded to ``background_cut_event``. A value > 0
            enables the cut step.
        pt_signal_cut: threshold forwarded to ``background_cut_event``.
        noise: when False, the cut step runs even if ``pt_background_cut`` is 0.
        triplets: when True, the cut step applies ``convert_triplet_graph`` instead of
            ``background_cut_event``.
        input_cut: when not None and the event has ``scores``, edges with
            ``scores <= input_cut`` are dropped from the edge-level attributes.

    Returns:
        The first processed event. A list input is processed in full, but only its
        first element is returned.
    """
    # Handle event in batched form
    if not isinstance(events, list):
        events = [events]

    # NOTE: Cutting background by pT BY DEFINITION removes noise
    if pt_background_cut > 0 or not noise:
        for i, event in enumerate(events):
            # Store the result back: the helpers return the processed event, and
            # rebinding only the loop variable would silently discard it.
            if triplets:  # Keep all event data for posterity!
                events[i] = convert_triplet_graph(event)
            else:
                events[i] = background_cut_event(event, pt_background_cut, pt_signal_cut)

    for event in events:
        src, dst = event.edge_index[0], event.edge_index[1]

        # Ensure PID definition is correct: both endpoints share the same non-zero pid
        event.y_pid = (event.pid[src] == event.pid[dst]) & event.pid[src].bool()
        event.pid_signal = torch.isin(event.edge_index, event.signal_true_edges).all(0) & event.y_pid

        # `in event` works across PyG versions. `event.keys` is a property in older
        # releases but a method in newer ones, where `x in event.keys` raises TypeError.
        if input_cut is not None and "scores" in event:
            score_mask = event.scores > input_cut
            for edge_attr in ["edge_index", "y", "y_pid", "pid_signal", "scores"]:
                event[edge_attr] = event[edge_attr][..., score_mask]

    return events[0]

# class LargeHeteroDataset(Dataset):
#     def __init__(self, root, subdir, hparams, num_events=-1, transform=None, pre_transform=None, pre_filter=None):
#         super().__init__(root, transform, pre_transform, pre_filter)
        
#         self.subdir = subdir
#         self.hparams = hparams
#         # if transform is not None:
#         #     from functools import partial
#         #     self.transform = partial(transform, pt_background_cut=self.hparams['pt_background_cut'], pt_signal_cut=self.hparams['pt_signal_cut'], noise=self.hparams['noise'], triplets=False, input_cut=None)
        
#         self.input_paths = os.listdir(os.path.join(root, subdir))
#         if "sorted_events" in hparams.keys() and hparams["sorted_events"]:
#             self.input_paths = sorted(self.input_paths)
#         else:
#             random.shuffle(self.input_paths)
        
#         self.input_paths = [os.path.join(root, subdir, event) for event in self.input_paths][:num_events]
        
#     def len(self):
#         return len(self.input_paths)
    
#     def get(self, idx):
#         event = torch.load(self.input_paths[idx], map_location=torch.device("cpu"))     

#         map = torch.zeros_like(event.hid)
#         for model in self.hparams['model_ids']:
#             volume_id = model['volume_ids']
#             homo_ids = event.hid[ torch.isin( event.volume_id, torch.tensor(volume_id) ) ]
#             map[homo_ids] = torch.arange(homo_ids.shape[0])
        

#         data = HeteroData()
#         for model in self.hparams['model_ids']:
#             region = get_region(model)
#             mask = torch.isin( event.volume_id, torch.tensor(model['volume_ids']) )
#             for attr in ['x', 'cell_data', 'pid', 'hid', 'pt', 'primary', 'nhits', 'modules', 'volume_id']:
#                 data[region][attr] = event[attr][mask]
#             data[region]['mask'] = mask
        
#         for model1, model2 in product(self.hparams['model_ids'], self.hparams['model_ids']):
#             # ids = torch.tensor([model1['volume_ids'], model2['volume_ids']])
#             id0, id1 = torch.tensor([model1['volume_ids']]), torch.tensor([model2['volume_ids']])
#             region1, region2 = get_region(model1), get_region(model2)
#             mask0 = torch.isin(event.volume_id[event.edge_index[0]], id0)
#             mask1 = torch.isin(event.volume_id[event.edge_index[1]], id1)
#             mask = mask1 * mask0 #+ torch.isin(event.volume_id[event.edge_index[0]], id2) * torch.isin(event.volume_id[event.edge_index[1]],id1)
#             edge_index = event.edge_index.T[mask].T
#             edge_index = map[edge_index]
#             data[region1, 'connected_to', region2].edge_index = edge_index
#             data[region1, 'connected_to', region2].y = event.y[mask]
#             data[region1, 'connected_to', region2].y_pid = event.y_pid[mask]
#             for truth_edge in ['modulewise_true_edges', 'signal_true_edges']:
#                 mask = torch.isin(event.volume_id[event[truth_edge][0]], id0) * torch.isin(event.volume_id[event[truth_edge][1]], id1) #+ torch.isin(event.volume_id[event[truth_edge][0]], id2) * torch.isin(event.volume_id[event[truth_edge][1]], id1)
#                 data[region1, 'connected_to', region2][truth_edge] = event[truth_edge].T[mask].T
#         return data 

# class LargeHeteroDataset(LargeDataset):
#     def __init__(self, root, subdir, hparams, num_events=-1, transform=None, pre_transform=None, pre_filter=None):
#         super().__init__(root, subdir, hparams, num_events, transform, pre_transform, pre_filter)
#         print(self.__getitem__)

#     # def get_region(self, model):
#     #     return 'region_' + '_'.join([str(i) for i in model['volume_ids']])

#     def get(self, idx):
        
#         event = self.__getitem__(idx)

#         # Process event with pt cuts
#         # if self.hparams["pt_background_cut"] > 0:
#         #     event = background_cut_event(event, self.hparams["pt_background_cut"], self.hparams["pt_signal_cut"])
        
#         # Ensure PID definition is correct
#         # event.y_pid = (event.pid[event.edge_index[0]] == event.pid[event.edge_index[1]]) & event.pid[event.edge_index[0]].bool()
#         # event.pid_signal = torch.isin(event.edge_index, event.signal_true_edges).all(0) & event.y_pid

#         # create new hit map
#         models = config['model_ids']
#         map = torch.zeros_like(event.hid)
#         for model in self.hparams['model_ids']:
#             volume_id = model['volume_ids']
#             homo_ids = event.hid[ torch.isin( event.volume_id, torch.tensor(volume_id) ) ]
#             map[homo_ids] = torch.arange(homo_ids.shape[0])
        

#         data = HeteroData()
#         for model in self.hparams['model_ids']:
#             region = get_region(model)
#             mask = torch.isin( event.volume_id, torch.tensor(model['volume_ids']) )
#             for attr in ['x', 'cell_data', 'pid', 'hid', 'pt', 'primary', 'nhits', 'modules', 'volume_id']:
#                 data[region][attr] = event[attr][mask]
#             data[region]['mask'] = mask
        
#         for model1, model2 in product(self.hparams['model_ids'], self.hparams['model_ids']):
#             # ids = torch.tensor([model1['volume_ids'], model2['volume_ids']])
#             id0, id1 = torch.tensor([model1['volume_ids']]), torch.tensor([model2['volume_ids']])
#             region1, region2 = get_region(model1), get_region(model2)
#             mask0 = torch.isin(event.volume_id[event.edge_index[0]], id0)
#             mask1 = torch.isin(event.volume_id[event.edge_index[1]], id1)
#             mask = mask1 * mask0 #+ torch.isin(event.volume_id[event.edge_index[0]], id2) * torch.isin(event.volume_id[event.edge_index[1]],id1)
#             edge_index = event.edge_index.T[mask].T
#             edge_index = map[edge_index]
#             data[region1, 'connected_to', region2].edge_index = edge_index
#             data[region1, 'connected_to', region2].y = event.y[mask]
#             data[region1, 'connected_to', region2].y_pid = event.y_pid[mask]
#             for truth_edge in ['modulewise_true_edges', 'signal_true_edges']:
#                 mask = torch.isin(event.volume_id[event[truth_edge][0]], id0) * torch.isin(event.volume_id[event[truth_edge][1]], id1) #+ torch.isin(event.volume_id[event[truth_edge][0]], id2) * torch.isin(event.volume_id[event[truth_edge][1]], id1)
#                 data[region1, 'connected_to', region2][truth_edge] = event[truth_edge].T[mask].T
#         return data


In [5]:
# Peek at one raw training event.
# Sort the listing so the same file is picked on every run: os.listdir order is arbitrary.
train_dir = os.path.join(config['input_dir'], 'train')
files = sorted(os.listdir(train_dir))
e = torch.load(os.path.join(train_dir, files[0]), map_location=torch.device("cpu"))
e

Data(x=[297481, 9], cell_data=[297481, 11], pid=[297481], event_file='/global/cfs/cdirs/m3443/data/ITk-upgrade/processed/full_events_v4/event000015937', hid=[297481], pt=[297481], primary=[297481], nhits=[297481], modules=[297481], modulewise_true_edges=[2, 116389], signal_true_edges=[2, 12588], edge_index=[2, 574596], y=[574596], y_pid=[574596], volume_id=[297481])

In [9]:
from Pipelines.Common_Tracking_Example.LightningModules.GNN.hetero_dataset import LargeHeteroDataset
# Per-event processing: background/pT cuts from the config, no triplets, no score cut.
process_kwargs = dict(
    pt_background_cut=config['pt_background_cut'],
    pt_signal_cut=config['pt_signal_cut'],
    noise=config['noise'],
    triplets=False,
    input_cut=None,
)
transform = partial(process_data, **process_kwargs)
dataset = LargeHeteroDataset(
    root=config['input_dir'],
    subdir='train',
    num_events=10,
    hparams=config,
    process_function=transform,
)
data = dataset.get(0)
# homo_data = torch.load(input_dir)
# undirected_data = T.ToUndirected()(data)

In [11]:
from torch_geometric.nn import HeteroConv, HeteroLinear, MLP, GCNConv, MessagePassing, to_hetero
from torch.nn import Module, ModuleDict
from collections import defaultdict

class NodeEncoder(torch.nn.Module):
    """MLP that embeds one region's raw node features into the hidden space.

    Args:
        hparams: mapping providing ``hidden``, ``nb_node_layer``, ``output_activation``,
            ``hidden_activation``, ``layernorm`` and ``batchnorm``.
        model: region spec; ``model['num_features']`` is the MLP input width.
    """

    def __init__(self, hparams, model) -> None:
        super().__init__()
        self.hparams = hparams

        hidden_sizes = [hparams["hidden"]] * hparams["nb_node_layer"]
        self.network = make_mlp(
            model['num_features'],
            hidden_sizes,
            output_activation=hparams["output_activation"],
            hidden_activation=hparams["hidden_activation"],
            layer_norm=hparams["layernorm"],
            batch_norm=hparams["batchnorm"],
        )

    def forward(self, x):
        """Cast ``x`` to float32 and run it through the MLP."""
        features = x.float()
        return self.network(features)
    
class HeteroNodeEncoder(torch.nn.Module):
    """One NodeEncoder per region (node type), keyed by ``get_region(model)``.

    Args:
        hparams: mapping whose ``'model_ids'`` lists the region specs. Each spec needs
            ``'volume_ids'`` and ``'num_features'``.
    """

    def __init__(self, hparams) -> None:
        super().__init__()

        self.hparams = hparams

        self.encoders = torch.nn.ModuleDict()
        for model in self.hparams['model_ids']:
            region = get_region(model)
            self.encoders[region] = NodeEncoder(self.hparams, model)

    def forward(self, x_dict):
        """Encode every configured region's node features.

        Only the first ``model['num_features']`` columns of each region's features are
        fed to its encoder.

        Args:
            x_dict: mapping region name -> node feature tensor.

        Returns:
            A new dict. Configured regions are replaced by their encodings; any other
            entries are passed through unchanged. ``x_dict`` itself is not modified.
        """
        # Copy so the caller's dict is not mutated behind its back.
        encoded = dict(x_dict)
        for model in self.hparams['model_ids']:
            region = get_region(model)
            encoded[region] = self.encoders[region](x_dict[region][:, : model['num_features']])

        return encoded

In [7]:
# Smoke test: encode the per-region node features of the sample event loaded earlier.
encoder = HeteroNodeEncoder(hparams=config)

node_features = data.x_dict
encoder(node_features)

{'volume_0_1': tensor([[ 0.4559, -0.6419, -0.2931,  ...,  0.0601,  0.2524,  0.7009],
         [ 0.4751, -0.6559, -0.3046,  ...,  0.0529,  0.2268,  0.7113],
         [ 0.4322, -0.6406, -0.2763,  ...,  0.0861,  0.2159,  0.7096],
         ...,
         [-0.9515, -0.9330, -0.9221,  ..., -0.2887,  0.8334, -0.8976],
         [-0.9532, -0.9333, -0.9220,  ..., -0.2904,  0.8344, -0.8974],
         [-0.9538, -0.9331, -0.9231,  ..., -0.2832,  0.8338, -0.8953]],
        grad_fn=<TanhBackward0>),
 'volume_2': tensor([[ 0.5418,  0.9371,  0.8508,  ..., -0.9120,  0.1569,  0.0486],
         [ 0.5762,  0.9358,  0.8506,  ..., -0.9064,  0.1539,  0.0734],
         [ 0.4515,  0.9366,  0.8540,  ..., -0.9223,  0.1309, -0.0043],
         ...,
         [ 0.3760, -0.8509, -0.4800,  ..., -0.7841,  0.7006, -0.4799],
         [ 0.3826, -0.8437, -0.4772,  ..., -0.7881,  0.7012, -0.4746],
         [ 0.4670, -0.7989, -0.5056,  ..., -0.8094,  0.7284, -0.4121]],
        grad_fn=<TanhBackward0>),
 'volume_3': tensor([[ 0

In [12]:
class EdgeEncoder(torch.nn.Module):
    """MLP that builds initial edge features from the two endpoint embeddings.

    Args:
        hparams: mapping providing ``hidden``, ``nb_node_layer``, ``output_activation``,
            ``hidden_activation``, ``layernorm`` and ``batchnorm``.
    """

    def __init__(self, hparams) -> None:
        super().__init__()
        self.hparams = hparams

        # NOTE(review): depth comes from nb_node_layer, while the other edge MLPs use
        # nb_edge_layer. Kept as-is so existing checkpoints still load; confirm intent.
        self.network = make_mlp(
            2 * hparams['hidden'],
            [hparams["hidden"]] * hparams["nb_node_layer"],
            output_activation=hparams["output_activation"],
            hidden_activation=hparams["hidden_activation"],
            layer_norm=hparams["layernorm"],
            batch_norm=hparams["batchnorm"],
        )

    def forward(self, x, edge_index):
        """Encode edges from ``[x_src[src], x_dst[dst]]``.

        Args:
            x: node features, or a ``(x_src, x_dst)`` tuple for bipartite edge types.
            edge_index: ``[2, E]`` index tensor (row 0 = source, row 1 = target).

        Returns:
            The MLP output for the ``E`` concatenated endpoint features.
        """
        src, dst = edge_index
        if isinstance(x, tuple):
            x_src, x_dst = x
        else:
            x_src = x_dst = x
        x_in = torch.cat([x_src[src], x_dst[dst]], dim=-1)

        return self.network(x_in)

class HeteroEdgeConv(HeteroConv):
    """HeteroConv variant that returns one output per edge type, without aggregating.

    Each module in ``convs`` (keyed by ``(src, rel, dst)``) is called as
    ``conv(x_dict[src], edge_index, ...)`` when ``src == dst``, and as
    ``conv((x_dict[src], x_dict[dst]), edge_index, ...)`` otherwise.
    """

    def __init__(self, convs: dict, aggr: str = "sum"):
        # `aggr` is forwarded to HeteroConv; forward() below never aggregates across types.
        super().__init__(convs, aggr)

    def forward(
        self,
        x_dict: dict,
        edge_index_dict: dict,
        *args_dict,
        **kwargs_dict,
    ) -> dict :
        """Apply the per-edge-type module to every edge type that has one.

        Extra positional dicts (``*args_dict``) and keyword dicts (``**kwargs_dict``,
        whose names must end in ``_dict``) are routed per edge type. The entry for the
        edge type itself is preferred, then the shared node type when ``src == dst``,
        then a ``(value[src], value[dst])`` pair.

        Returns:
            dict mapping edge type -> module output. Edge types without a module are skipped.
        """
        out_dict = {}
        for edge_type, edge_index in edge_index_dict.items():
            src, rel, dst = edge_type

            str_edge_type = '__'.join(edge_type)
            if str_edge_type not in self.convs:
                continue

            args = []
            for value_dict in args_dict:
                if edge_type in value_dict:
                    args.append(value_dict[edge_type])
                elif src == dst and src in value_dict:
                    args.append(value_dict[src])
                elif src in value_dict or dst in value_dict:
                    args.append(
                        (value_dict.get(src, None), value_dict.get(dst, None)))

            kwargs = {}
            for arg, value_dict in kwargs_dict.items():
                arg = arg[:-5]  # `{*}_dict`
                if edge_type in value_dict:
                    kwargs[arg] = value_dict[edge_type]
                elif src == dst and src in value_dict:
                    kwargs[arg] = value_dict[src]
                elif src in value_dict or dst in value_dict:
                    kwargs[arg] = (value_dict.get(src, None),
                                   value_dict.get(dst, None))

            conv = self.convs[str_edge_type]

            if src == dst:
                out = conv(x_dict[src], edge_index, *args, **kwargs)
            else:
                out = conv((x_dict[src], x_dict[dst]), edge_index, *args, **kwargs)

            out_dict[edge_type] = out

        return out_dict

class EdgeClassifier(torch.nn.Module):
    """MLP that scores each edge from ``[x_src[src], x_dst[dst], edge]``.

    The final layer has width 1, so the output is one value per edge.
    """

    def __init__(self, hparams):
        super().__init__()

        layer_sizes = [hparams["hidden"]] * hparams["nb_edge_layer"] + [1]
        self.network = make_mlp(
            3 * hparams["hidden"],
            layer_sizes,
            layer_norm=hparams["layernorm"],
            batch_norm=hparams["batchnorm"],
            output_activation=hparams['output_activation'],
            hidden_activation=hparams["hidden_activation"],
        )

    def forward(self, x, edge_index, edge):
        """Score edges.

        Args:
            x: node features, or a ``(x_src, x_dst)`` tuple for bipartite edge types.
            edge_index: ``[2, E]`` index tensor (row 0 = source, row 1 = target).
            edge: edge features, concatenated after the two endpoint features.
        """
        src, dst = edge_index
        x_src, x_dst = x if isinstance(x, tuple) else (x, x)
        features = torch.cat([x_src[src], x_dst[dst], edge], dim=-1)
        return self.network(features)

In [13]:
class InteractionMessagePassing(MessagePassing):
    """Interaction-network style message passing for a single edge type.

    Node step (``forward``): messages are the current edge features. They are
    aggregated per target node with ``aggr`` and added to the target node features.
    Edge step (``edge_update``): new edge features = ``edge_network([x_src, x_dst, edge])``.
    """

    def __init__(self, hparams, aggr: str = "add", flow: str = "source_to_target", node_dim: int = -2, decomposed_layers: int = 1):
        super().__init__(aggr, flow=flow, node_dim=node_dim, decomposed_layers=decomposed_layers)

        self.hparams = hparams

        # NOTE(review): edge_encoder is not used by any method of this class; edges are
        # encoded by EdgeEncoder in HeteroGNN. Kept so existing checkpoints still load.
        self.edge_encoder = make_mlp(
            2 * (hparams["hidden"]),
            [hparams["hidden"]] * hparams["nb_edge_layer"],
            layer_norm=hparams["layernorm"],
            batch_norm=hparams["batchnorm"],
            output_activation=hparams["output_activation"],
            hidden_activation=hparams["hidden_activation"],
        )

        # The edge network computes new edge features from connected nodes
        self.edge_network = make_mlp(
            3 * hparams["hidden"],
            [hparams["hidden"]] * hparams["nb_edge_layer"],
            layer_norm=hparams["layernorm"],
            batch_norm=hparams["batchnorm"],
            output_activation=hparams["output_activation"],
            hidden_activation=hparams["hidden_activation"],
        )

        # NOTE(review): node_network is currently unused. update() adds the aggregated
        # messages directly rather than computing node_network([x, aggregated]);
        # TODO confirm which update rule is intended.
        self.node_network = make_mlp(
            2 * hparams["hidden"],
            [hparams["hidden"]] * hparams["nb_node_layer"],
            layer_norm=hparams["layernorm"],
            batch_norm=hparams["batchnorm"],
            output_activation=hparams["output_activation"],
            hidden_activation=hparams["hidden_activation"],
        )

    def message(self, edge):
        """Use the edge features themselves as messages."""
        return edge

    def aggregate(self, out, edge_index):
        """Aggregate messages per target node.

        Returns one row per distinct target node, in the order of ``dst.unique()``
        (ascending node index).
        """
        src, dst = edge_index
        return self.aggr_module(out, dst)[dst.unique()]

    def update(self, agg_message, x, edge_index):
        """Return ``x`` with each receiving node's aggregated message added.

        The add is out-of-place on purpose. ``x`` is the tensor held in the caller's
        ``x_dict`` and shared by every edge type targeting this node type, and it may be
        saved for autograd. An in-place ``x[...] += ...`` would corrupt both.
        """
        _, dst = edge_index
        return x.index_add(0, dst.unique(), agg_message)

    def edge_update(self, x, edge, edge_index, *args, **kwargs):
        """Compute new edge features from ``[x_src[src], x_dst[dst], edge]``."""
        src, dst = edge_index
        if isinstance(x, tuple):
            x_src, x_dst = x[0][src], x[1][dst]
        else:
            x_src, x_dst = x[src], x[dst]
        out = self.edge_network(torch.cat([x_src, x_dst, edge], dim=-1))
        return out

    def forward(self, x, edge_index, edge):
        """Update target-node features from incoming edge features.

        Args:
            x: node features, or a ``(x_src, x_dst)`` tuple for bipartite edge types.
                Only the target side is used.
            edge_index: ``[2, E]`` index tensor.
            edge: edge features, used as messages.
        """
        if isinstance(x, tuple):
            x_src, x_dst = x
        else:
            x_src, x_dst = x, x

        x_dst = self.propagate(edge_index, x=x_dst, edge=edge)

        return x_dst

class InteractionHeteroConv(HeteroConv):
    """HeteroConv over InteractionMessagePassing modules.

    The inherited ``forward`` updates node features. ``edge_forward`` updates edge
    features by calling each module's ``edge_update``.
    """

    def __init__(self, convs: Dict[EdgeType, Module], aggr: Optional[str] = "sum"):
        super().__init__(convs, aggr)

    def edge_forward(
        self,
        x_dict: Dict[NodeType, Tensor],
        edge_index_dict: Dict[EdgeType, Adj],
        edge_dict: Dict[EdgeType, Tensor],
        *args_dict,
        **kwargs_dict,
    ) -> Dict[EdgeType, Tensor]:
        """Recompute edge features for every edge type that has a conv.

        Extra positional dicts (``*args_dict``) and keyword dicts (``**kwargs_dict``,
        whose names must end in ``_dict``) are routed per edge type in the same way as
        ``HeteroConv.forward``.

        Returns:
            dict mapping edge type -> new edge features. Edge types without a conv are skipped.

        Raises:
            KeyError: if an edge type with a conv is missing from ``edge_dict``.
        """
        out_dict = {}
        for edge_type, edge_index in edge_index_dict.items():
            src, rel, dst = edge_type

            str_edge_type = '__'.join(edge_type)
            if str_edge_type not in self.convs:
                continue

            args = []
            for value_dict in args_dict:
                if edge_type in value_dict:
                    args.append(value_dict[edge_type])
                elif src == dst and src in value_dict:
                    args.append(value_dict[src])
                elif src in value_dict or dst in value_dict:
                    args.append(
                        (value_dict.get(src, None), value_dict.get(dst, None)))

            kwargs = {}
            for arg, value_dict in kwargs_dict.items():
                arg = arg[:-5]  # `{*}_dict`
                if edge_type in value_dict:
                    kwargs[arg] = value_dict[edge_type]
                elif src == dst and src in value_dict:
                    kwargs[arg] = value_dict[src]
                elif src in value_dict or dst in value_dict:
                    kwargs[arg] = (value_dict.get(src, None),
                                   value_dict.get(dst, None))

            conv = self.convs[str_edge_type]
            edge = edge_dict[edge_type]

            # Always pass an (x_src, x_dst) pair; edge_update handles src == dst too.
            out_dict[edge_type] = conv.edge_update(
                (x_dict[src], x_dict[dst]), edge, edge_index, *args, **kwargs
            )

        return out_dict


class HeteroGNN(torch.nn.Module):
    """Heterogeneous interaction-network GNN that scores edges between regions.

    Pipeline:
    1. Per-region node encoders.
    2. Per-edge-type edge encoders.
    3. ``hparams.get("n_graph_iters", 2)`` rounds of (node update, edge update).
    4. Per-edge-type edge classifiers.

    ``forward`` returns a dict mapping edge type -> classifier output.
    """

    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams
        models = self.hparams['model_ids']

        def _edge_type(model0, model1):
            return (get_region(model0), 'connected_to', get_region(model1))

        self.node_encoders = HeteroNodeEncoder(self.hparams)
        self.edge_encoders = HeteroEdgeConv({
            _edge_type(model0, model1): EdgeEncoder(self.hparams)
            for model0, model1 in product(models, models)
        })

        # Number of message-passing rounds (was hard-coded to 2, which stays the default).
        n_graph_iters = self.hparams.get("n_graph_iters", 2)
        self.convs = torch.nn.ModuleList()
        for _ in range(n_graph_iters):
            conv = InteractionHeteroConv({
                _edge_type(model0, model1): InteractionMessagePassing(hparams=self.hparams)
                for model0, model1 in product(models, models)
            }, aggr='sum')
            self.convs.append(conv)

        # NOTE(review): classifiers exist only for unordered region pairs
        # (combinations_with_replacement), while encoders/convs cover every ordered pair.
        # HeteroEdgeConv skips edge types without a module, so the reverse-ordered edge
        # types receive no score. Confirm this is intended.
        self.edge_classifiers = HeteroEdgeConv({
            _edge_type(model0, model1): EdgeClassifier(self.hparams)
            for model0, model1 in combinations_with_replacement(models, 2)
        })

    def forward(self, x_dict: dict, edge_index_dict: dict):
        """Score edges.

        Args:
            x_dict: region name -> raw node features.
            edge_index_dict: edge type -> ``[2, E]`` edge index.

        Returns:
            dict mapping edge type -> classifier output, for edge types that have a classifier.
        """
        x_dict = self.node_encoders(x_dict)
        edge_dict = self.edge_encoders(x_dict, edge_index_dict)
        for conv in self.convs:
            # Node update from current edges, then edge update from the new nodes.
            x_dict = conv(x_dict, edge_index_dict, edge_dict=edge_dict)
            edge_dict = conv.edge_forward(x_dict, edge_index_dict, edge_dict)

        out = self.edge_classifiers(x_dict, edge_index_dict, edge_dict=edge_dict)

        return out

In [17]:
from torch_geometric.data import DataLoader
import torch.functional as F

class HeteroGNNBase(LightningModule):
    """Lightning base module for heterogeneous (per-region) GNN edge classification.

    Subclasses implement ``forward(x_dict, edge_index_dict)``, returning a dict that maps
    each edge type to edge logits. This base class provides data loading, optimisation
    and the (optionally background-masked) binary cross-entropy loss.
    """

    def __init__(self, hparams):
        """Initialise the Lightning Module that can scan over different GNN training regimes.

        Args:
            hparams: hyperparameter mapping, stored with ``save_hyperparameters``.
        """
        super().__init__()

        # Assign hyperparameters
        self.save_hyperparameters(hparams)
        self.trainset, self.valset, self.testset = None, None, None

    def setup(self, stage):
        """Load the [train, val, test] splits once and register logger summary metrics."""
        # Handle any subset of [train, val, test] data split, assuming that ordering
        if self.trainset is None:
            print("Setting up dataset")
            input_subdirs = [None, None, None]
            input_subdirs[: len(self.hparams["datatype_names"])] = [
                os.path.join(self.hparams["input_dir"], datatype)
                for datatype in self.hparams["datatype_names"]
            ]
            self.trainset, self.valset, self.testset = [
                load_dataset(
                    input_subdir=input_subdir,
                    num_events=self.hparams["datatype_split"][i],
                    **self.hparams
                )
                for i, input_subdir in enumerate(input_subdirs)
            ]

        self._define_logger_metrics()

    def _define_logger_metrics(self):
        """Declare summary modes for the tracked metrics when an experiment logger is attached."""
        if (
            (self.trainer)
            and ("logger" in self.trainer.__dict__.keys())
            and ("_experiment" in self.logger.__dict__.keys())
        ):
            self.logger.experiment.define_metric("val_loss", summary="min")
            self.logger.experiment.define_metric("sig_auc", summary="max")
            self.logger.experiment.define_metric("tot_auc", summary="max")
            self.logger.experiment.define_metric("sig_fake_ratio", summary="max")
            self.logger.experiment.define_metric("custom_f1", summary="max")
            self.logger.experiment.log({"sig_auc": 0})
            self.logger.experiment.log({"sig_fake_ratio": 0})
            self.logger.experiment.log({"custom_f1": 0})

    def _make_dataloader(self, dataset):
        """Wrap ``dataset`` in a one-event-per-batch DataLoader, or return None if absent."""
        if dataset is None:
            return None
        return DataLoader(
            dataset, batch_size=1, num_workers=1
        )  # , pin_memory=True, persistent_workers=True)

    def train_dataloader(self):
        return self._make_dataloader(self.trainset)

    def val_dataloader(self):
        return self._make_dataloader(self.valset)

    def test_dataloader(self):
        return self._make_dataloader(self.testset)

    def configure_optimizers(self):
        """AdamW (amsgrad) with a per-epoch StepLR schedule (step ``patience``, gamma ``factor``)."""
        optimizer = [
            torch.optim.AdamW(
                self.parameters(),
                lr=(self.hparams["lr"]),
                betas=(0.9, 0.999),
                eps=1e-08,
                amsgrad=True,
            )
        ]
        scheduler = [
            {
                "scheduler": torch.optim.lr_scheduler.StepLR(
                    optimizer[0],
                    step_size=self.hparams["patience"],
                    gamma=self.hparams["factor"],
                ),
                "interval": "epoch",
                "frequency": 1,
            }
        ]
        return optimizer, scheduler

    def handle_directed(self, batch, edge_sample, truth_sample, sample_indices):
        """Symmetrise the sampled edges, optionally keeping one direction per edge.

        Every edge is duplicated in reverse. When ``hparams["directed"]`` is set, only
        edges whose source has a smaller ``x[:, 0]`` than the target are kept.
        ``truth_sample`` and ``sample_indices`` are filtered in the same way, so the
        three outputs stay aligned.
        """
        edge_sample = torch.cat([edge_sample, edge_sample.flip(0)], dim=-1)
        truth_sample = truth_sample.repeat(2)
        sample_indices = sample_indices.repeat(2)

        if ("directed" in self.hparams.keys()) and self.hparams["directed"]:
            direction_mask = batch.x[edge_sample[0], 0] < batch.x[edge_sample[1], 0]
            edge_sample = edge_sample[:, direction_mask]
            truth_sample = truth_sample[direction_mask]
            sample_indices = sample_indices[direction_mask]

        return edge_sample, truth_sample, sample_indices

    def _edge_outputs(self, batch):
        """Run the model on ``batch`` and pair each edge type's logits with its truth.

        Returns:
            dict mapping edge type -> ``(logits, truth)``, with logits flattened to 1-D.
            Only edge types the model produced are included. With
            ``hparams["mask_background"]``, edges that are not true but have ``y_pid``
            set are dropped.
        """
        output_dict = self(batch.x_dict, batch.edge_index_dict)
        truth_dict = batch.truth_dict
        y_pid_dict = batch.y_pid_dict if self.hparams["mask_background"] else None

        pairs = {}
        # Pair by key: the model may produce fewer edge types than truth_dict holds,
        # so zipping the two .values() views would mis-align them.
        for key, output in output_dict.items():
            # Classifier outputs end in a width-1 layer; BCE needs the (E,) target shape.
            logits = output.reshape(-1)
            truth = truth_dict[key]
            if y_pid_dict is not None:
                y_subset = truth.bool() | ~y_pid_dict[key].bool()
                logits, truth = logits[y_subset], truth[y_subset]
            pairs[key] = (logits, truth)
        return pairs

    def _edge_loss(self, batch):
        """Mean binary cross-entropy with logits over all scored edges, pooled across edge types.

        Raises:
            RuntimeError: if the model returns no edge types at all.
        """
        pairs = self._edge_outputs(batch)
        n_edges = sum(logits.numel() for logits, _ in pairs.values())
        summed = torch.stack([
            torch.nn.functional.binary_cross_entropy_with_logits(
                logits, truth.float(), reduction='sum'
            )
            for logits, truth in pairs.values()
        ]).sum()
        # max(..., 1) avoids 0/0 when masking removed every edge.
        return summed / max(n_edges, 1)

    def training_step(self, batch, batch_idx):
        # TODO: purity sampling, handle_directed and pos_weight from the homogeneous
        # GNN base are not ported to the heterogeneous loss yet.
        loss = self._edge_loss(batch)

        self.log("train_loss", loss, on_step=False, on_epoch=True)

        return loss

    def log_metrics(self, output, sample_indices, batch, loss, log):
        """Compute (and optionally log) efficiency/purity/AUC metrics.

        NOTE(review): this is written for homogeneous batches (``batch.pid_signal``,
        ``batch.y_pid``). It relies on ``roc_auc_score``, which is not imported in the
        visible code, and it is not called from the heterogeneous evaluation path.
        """
        preds = torch.sigmoid(output) > self.hparams["edge_cut"]

        # Positives
        edge_positive = preds.sum().float()

        # Signal true & signal tp
        sig_truth = batch.pid_signal[sample_indices]
        sig_true = sig_truth.sum().float()
        sig_true_positive = (sig_truth.bool() & preds).sum().float()
        sig_auc = roc_auc_score(
            sig_truth.bool().cpu().detach(), torch.sigmoid(output).cpu().detach()
        )

        # Total true & total tp
        tot_truth = (batch.y_pid.bool() | batch.y.bool())[sample_indices]
        tot_true = tot_truth.sum().float()
        tot_true_positive = (tot_truth.bool() & preds).sum().float()
        tot_auc = roc_auc_score(
            tot_truth.bool().cpu().detach(), torch.sigmoid(output).cpu().detach()
        )

        # Eff, pur, auc
        sig_eff = sig_true_positive / sig_true
        sig_pur = sig_true_positive / edge_positive
        tot_eff = tot_true_positive / tot_true
        tot_pur = tot_true_positive / edge_positive

        # Combined metrics
        double_auc = sig_auc * tot_auc
        custom_f1 = 2 * sig_eff * tot_pur / (sig_eff + tot_pur)
        sig_fake_ratio = sig_true_positive / (edge_positive - tot_true_positive)

        if log:
            current_lr = self.optimizers().param_groups[0]["lr"]
            self.log_dict(
                {
                    "val_loss": loss,
                    "current_lr": current_lr,
                    "sig_eff": sig_eff,
                    "sig_pur": sig_pur,
                    "sig_auc": sig_auc,
                    "tot_eff": tot_eff,
                    "tot_pur": tot_pur,
                    "tot_auc": tot_auc,
                    "double_auc": double_auc,
                    "custom_f1": custom_f1,
                    "sig_fake_ratio": sig_fake_ratio,
                },
                sync_dist=True,
            )

        return preds

    def shared_evaluation(self, batch, batch_idx, log=True):
        """Compute the evaluation loss; log it as ``val_loss`` when ``log`` is True."""
        loss = self._edge_loss(batch)

        if log:
            self.log("val_loss", loss, on_step=False, on_epoch=True, sync_dist=True)

        return {"loss": loss}

    def validation_step(self, batch, batch_idx):

        outputs = self.shared_evaluation(batch, batch_idx)

        return outputs["loss"]

    def test_step(self, batch, batch_idx):

        outputs = self.shared_evaluation(batch, batch_idx, log=False)

        return outputs

    def test_step_end(self, output_results):

        print("Step:", output_results)

    def test_epoch_end(self, outputs):

        print("Epoch:", outputs)

    def optimizer_step(
        self,
        epoch,
        batch_idx,
        optimizer,
        optimizer_idx,
        optimizer_closure=None,
        on_tpu=False,
        using_native_amp=False,
        using_lbfgs=False,
    ):
        """Step the optimizer, with linear LR warm-up over the first ``warmup`` epochs."""
        # warm up lr
        if (self.hparams["warmup"] is not None) and (
            self.current_epoch < self.hparams["warmup"]
        ):
            lr_scale = min(
                1.0, float(self.current_epoch + 1) / self.hparams["warmup"]
            )
            for pg in optimizer.param_groups:
                pg["lr"] = lr_scale * self.hparams["lr"]

        # update params
        optimizer.step(closure=optimizer_closure)
        optimizer.zero_grad()

        
class LargeGNNBase(HeteroGNNBase):
    """HeteroGNNBase variant whose train/val/test splits are LargeHeteroDataset instances."""

    def __init__(self, hparams):
        super().__init__(hparams)

    def setup(self, stage, process_function=None):
        """Build the dataset splits (once) and register logger summary metrics.

        Args:
            stage: Lightning stage name (unused here).
            process_function: per-event callable forwarded to LargeHeteroDataset
                (e.g. a ``functools.partial`` of ``process_data``). It defaults to None so
                Lightning's own ``setup(stage)`` call works. Presumably None means "no
                processing" for LargeHeteroDataset; TODO confirm.
        """
        # Build once, so a manual setup(stage, fn) is not overwritten by Lightning's setup(stage).
        if self.trainset is None:
            ratios = np.array(self.hparams['datatype_split'])
            # Turn split ratios into per-split event counts totalling (at most) n_events.
            counts = (self.hparams['n_events'] * ratios / np.sum(ratios)).astype(np.int32)

            datasets = [
                LargeHeteroDataset(
                    root=self.hparams['input_dir'],
                    subdir=subdir,
                    hparams=self.hparams,
                    num_events=int(count),
                    process_function=process_function,
                )
                for subdir, count in zip(self.hparams['datatype_names'], counts)
            ]
            # Handle any subset of [train, val, test] data split, assuming that ordering
            datasets += [None] * (3 - len(datasets))
            self.trainset, self.valset, self.testset = datasets

        if (
            (self.trainer)
            and ("logger" in self.trainer.__dict__.keys())
            and ("_experiment" in self.logger.__dict__.keys())
        ):
            self.logger.experiment.define_metric("val_loss", summary="min")
            self.logger.experiment.define_metric("sig_auc", summary="max")
            self.logger.experiment.define_metric("tot_auc", summary="max")
            self.logger.experiment.define_metric("sig_fake_ratio", summary="max")
            self.logger.experiment.define_metric("custom_f1", summary="max")
            self.logger.experiment.log({"sig_auc": 0})
            self.logger.experiment.log({"sig_fake_ratio": 0})
            self.logger.experiment.log({"custom_f1": 0})