In [1]:
import h5py
import hdf5plugin
import numpy as np
import itertools
import torch
from torch_geometric.data import Data, Dataset

filename = "/higgs/train/ntuple_merged_44.h5"

h5 = h5py.File(filename,'r')





In [2]:
params_2 = ['track_ptrel',     
          'track_erel',     
          'track_phirel',     
          'track_etarel',     
          'track_deltaR',
          'track_drminsv',     
          'track_drsubjet1',     
          'track_drsubjet2',
          'track_dz',     
          'track_dzsig',     
          'track_dxy',     
          'track_dxysig',     
          'track_normchi2',     
          'track_quality',     
          'track_dptdpt',     
          'track_detadeta',     
          'track_dphidphi',     
          'track_dxydxy',     
          'track_dzdz',     
          'track_dxydz',     
          'track_dphidxy',     
          'track_dlambdadz',     
          'trackBTag_EtaRel',     
          'trackBTag_PtRatio',     
          'trackBTag_PParRatio',     
          'trackBTag_Sip2dVal',     
          'trackBTag_Sip2dSig',     
          'trackBTag_Sip3dVal',     
          'trackBTag_Sip3dSig',     
          'trackBTag_JetDistVal'
         ]


labels_qcd = {"label_QCD_b": 53, "label_QCD_bb":51, "label_QCD_c":54,"label_QCD_cc":52,"label_QCD_others": 55}
label_H_bb = 41
num_jets = h5["event_no"].shape[0]
n_particles = h5["track_dxydxy"].shape[-1]

X = []
y = []
edge_indices = []
for jet in range(num_jets):
    if h5["fj_label"][jet] in list(labels_qcd.values()):
        y.append([0,1])
    elif h5["fj_label"][jet] == label_H_bb:
        y.append([1,0])
    else: continue
    
    x = []
    for feature in params_2:
        x.append(h5[feature][jet])
    X.append(x)
    
    # complete graph has n_particles*(n_particles-1)/2 edges, here we double count each edge, so has  n_particles*(n_particles-1) total edges
    pairs = np.stack([[m, n] for (m, n) in itertools.product(range(n_particles), range(n_particles)) if m != n])
    # edge_index = torch.tensor(pairs, dtype=torch.long)
    edge_index = pairs.transpose()
    edge_indices.append(edge_index)
    
    if jet > 10:
        break

In [5]:
# define pytorch dataset
datas = []
i = 0
data = Data(x=torch.tensor(X[i]), edge_index=torch.tensor(edge_indices[i]), y=torch.tensor(y[i], dtype=torch.long))
for i in range(len(X)):
    data = Data(x=torch.tensor(X[i]), edge_index=torch.tensor(edge_indices[i]), y=torch.tensor(y[i], dtype=torch.long))
    datas.append([data])

In [4]:
datas

[[Data(x=[30, 60], edge_index=[2, 3540], y=[2])],
 [Data(x=[30, 60], edge_index=[2, 3540], y=[2])],
 [Data(x=[30, 60], edge_index=[2, 3540], y=[2])],
 [Data(x=[30, 60], edge_index=[2, 3540], y=[2])],
 [Data(x=[30, 60], edge_index=[2, 3540], y=[2])],
 [Data(x=[30, 60], edge_index=[2, 3540], y=[2])],
 [Data(x=[30, 60], edge_index=[2, 3540], y=[2])],
 [Data(x=[30, 60], edge_index=[2, 3540], y=[2])],
 [Data(x=[30, 60], edge_index=[2, 3540], y=[2])],
 [Data(x=[30, 60], edge_index=[2, 3540], y=[2])],
 [Data(x=[30, 60], edge_index=[2, 3540], y=[2])],
 [Data(x=[30, 60], edge_index=[2, 3540], y=[2])]]

In [4]:
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.transforms as T
from torch_geometric.nn import EdgeConv, global_mean_pool
from torch.nn import Sequential as Seq, Linear as Lin, ReLU, BatchNorm1d
from torch_scatter import scatter_mean
from torch_geometric.nn import MetaLayer
from tqdm.notebook import tqdm


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
inputs = 48
hidden = 128
outputs = 2


class EdgeBlock(torch.nn.Module):
    def __init__(self):
        super(EdgeBlock, self).__init__()
        self.edge_mlp = Seq(Lin(inputs * 2, hidden), BatchNorm1d(hidden), ReLU(), Lin(hidden, hidden))

    def forward(self, src, dest, edge_attr, u, batch):
        out = torch.cat([src, dest], 1)
        return self.edge_mlp(out)


class NodeBlock(torch.nn.Module):
    def __init__(self):
        super(NodeBlock, self).__init__()
        self.node_mlp_1 = Seq(Lin(inputs + hidden, hidden), BatchNorm1d(hidden), ReLU(), Lin(hidden, hidden))
        self.node_mlp_2 = Seq(Lin(inputs + hidden, hidden), BatchNorm1d(hidden), ReLU(), Lin(hidden, hidden))

    def forward(self, x, edge_index, edge_attr, u, batch):
        row, col = edge_index
        out = torch.cat([x[row], edge_attr], dim=1)
        out = self.node_mlp_1(out)
        out = scatter_mean(out, col, dim=0, dim_size=x.size(0))
        out = torch.cat([x, out], dim=1)
        return self.node_mlp_2(out)


class GlobalBlock(torch.nn.Module):
    def __init__(self):
        super(GlobalBlock, self).__init__()
        self.global_mlp = Seq(Lin(hidden, hidden), BatchNorm1d(hidden), ReLU(), Lin(hidden, outputs))

    def forward(self, x, edge_index, edge_attr, u, batch):
        out = scatter_mean(x, batch, dim=0)
        return self.global_mlp(out)


class InteractionNetwork(torch.nn.Module):
    def __init__(self):
        super(InteractionNetwork, self).__init__()
        self.interactionnetwork = MetaLayer(EdgeBlock(), NodeBlock(), GlobalBlock())
        self.bn = BatchNorm1d(inputs)

    def forward(self, x, edge_index, batch):

        x = self.bn(x)
        x, edge_attr, u = self.interactionnetwork(x, edge_index, None, None, batch)
        return u


model = InteractionNetwork().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

In [5]:
model

InteractionNetwork(
  (interactionnetwork): MetaLayer(
    edge_model=EdgeBlock(
    (edge_mlp): Sequential(
      (0): Linear(in_features=96, out_features=128, bias=True)
      (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Linear(in_features=128, out_features=128, bias=True)
    )
  ),
    node_model=NodeBlock(
    (node_mlp_1): Sequential(
      (0): Linear(in_features=176, out_features=128, bias=True)
      (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Linear(in_features=128, out_features=128, bias=True)
    )
    (node_mlp_2): Sequential(
      (0): Linear(in_features=176, out_features=128, bias=True)
      (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Linear(in_features=128, out_features=128, bias=True)
    )
  ),
    global_model=GlobalBlock(
    (global_mlp): Sequential(
    

In [6]:
@torch.no_grad()
def test(model, loader, total, batch_size, leave=False):
    model.eval()

    xentropy = nn.CrossEntropyLoss(reduction="mean")

    sum_loss = 0.0
    t = tqdm(enumerate(loader), total=total / batch_size, leave=leave)
    for i, data in t:
        data = data.to(device)
        y = torch.argmax(data.y, dim=1)
        batch_output = model(data.x, data.edge_index, data.batch)
        batch_loss_item = xentropy(batch_output, y).item()
        sum_loss += batch_loss_item
        t.set_description("loss = %.5f" % (batch_loss_item))
        t.refresh()  # to show immediately the update

    return sum_loss / (i + 1)


def train(model, optimizer, loader, total, batch_size, leave=False):
    model.train()

    xentropy = nn.CrossEntropyLoss(reduction="mean")

    sum_loss = 0.0
    t = tqdm(enumerate(loader), total=total / batch_size, leave=leave)
    for i, data in t:
        data = data.to(device)
        y = torch.argmax(data.y, dim=1)
        optimizer.zero_grad()
        batch_output = model(data.x, data.edge_index, data.batch)
        batch_loss = xentropy(batch_output, y)
        batch_loss.backward()
        batch_loss_item = batch_loss.item()
        t.set_description("loss = %.5f" % batch_loss_item)
        t.refresh()  # to show immediately the update
        sum_loss += batch_loss_item
        optimizer.step()

    return sum_loss / (i + 1)

In [None]:
from torch_geometric.data import Data, DataListLoader, Batch
from torch.utils.data import random_split


def collate(items):
    l = sum(items, [])
    return Batch.from_data_list(l)


torch.manual_seed(0)
valid_frac = 0.20
full_length = len(graph_dataset)
valid_num = int(valid_frac * full_length)
batch_size = 32

train_dataset, valid_dataset = random_split(graph_dataset, [full_length - valid_num, valid_num])

train_loader = DataListLoader(train_dataset, batch_size=batch_size, pin_memory=True, shuffle=True)
train_loader.collate_fn = collate
valid_loader = DataListLoader(valid_dataset, batch_size=batch_size, pin_memory=True, shuffle=False)
valid_loader.collate_fn = collate
test_loader = DataListLoader(test_dataset, batch_size=batch_size, pin_memory=True, shuffle=False)
test_loader.collate_fn = collate


train_samples = len(train_dataset)
valid_samples = len(valid_dataset)
test_samples = len(test_dataset)
print(full_length)
print(train_samples)
print(valid_samples)
print(test_samples)