In [1]:
import torch
from torch_geometric.data import Data
import uproot
import awkward as ak
import numpy as np
import matplotlib.pyplot as plt
import concurrent.futures

In [2]:
# Directory containing the `*_graphs.pt` files listed below.
sig_dir = '/ceph/cms/store/user/aaportel/B-Parking/rechits_v2/BToKPhi_MuonLLPDecayGenFilter_PhiToPi0Pi0_mPhi0p3_ctau300/'

# Maps a dataset label to the list of file paths that make up that dataset.
fileset = {}

# Build one path per file index; the index is zero-padded to seven digits.
sample_paths = []
for i in range(328):
    sample_paths.append(
        sig_dir + f'BToKPhi_MuonLLPDecayGenFilter_PhiToPi0Pi0_mPhi0p3_ctau300_{str(i).zfill(7)}_graphs.pt'
    )
fileset['sample'] = sample_paths

# Smaller variant (first 100 files only), kept for quick iterations:
# fileset['sample'] = [sig_dir + f'BToKPhi_MuonLLPDecayGenFilter_PhiToPi0Pi0_mPhi0p3_ctau300_{str(i).zfill(7)}_graphs.pt' for i in range(100)]

# Disabled second dataset (ROOT files rather than graph files):
# bkg_dir = '/ceph/cms/store/user/aaportel/B-Parking/rechits_v2/ParkingBPH1_2018A/'
# fileset['background'] = [bkg_dir + f'ParkingBPH1_2018A_{str(i).zfill(7)}.root' for i in range(380)]

In [4]:
import torch
from torch.utils.data import Dataset, ConcatDataset
from torch_geometric.loader import DataLoader

# Load every graph file fully into memory, one loaded object per file.
# `weights_only=False` is passed explicitly because these files hold pickled
# Python objects rather than plain tensors; newer PyTorch releases default to
# `weights_only=True`, which refuses to unpickle them.
# SECURITY: unpickling executes arbitrary code -- only load files you trust.
datasets = [torch.load(fp, weights_only=False) for fp in fileset['sample']]

# Present the per-file collections as one flat, index-addressable dataset.
combined_dataset = ConcatDataset(datasets)

# NOTE(review): num_workers=64 starts 64 worker processes for data that is
# already resident in memory; a smaller value (or persistent_workers=True)
# may be cheaper -- confirm with profiling before changing.
dataloader = DataLoader(combined_dataset, batch_size=32, shuffle=True, num_workers=64)

In [8]:
import torch
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool

class GNNClassifier(torch.nn.Module):
    """Binary graph classifier that mixes node-level and graph-level inputs.

    Node features go through two ``GCNConv`` layers and are mean-pooled per
    graph. The graph-level feature vector is embedded by a linear layer. Both
    embeddings are concatenated and fed to a two-layer MLP whose output is
    squashed with a sigmoid.

    Args:
        num_node_features: Width of each node feature vector.
        num_graph_features: Width of each graph-level feature vector.
        hidden_channels: Width of every hidden representation.
    """

    def __init__(self, num_node_features, num_graph_features, hidden_channels):
        super().__init__()
        # GCN layers
        self.conv1 = GCNConv(num_node_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        # Fully connected layer for graph-level features
        self.fc_graph = Linear(num_graph_features, hidden_channels)
        # Fully connected layers for the concatenated (pooled nodes + graph) embedding
        self.fc1 = Linear(2 * hidden_channels, hidden_channels)
        self.fc2 = Linear(hidden_channels, 1)

    def forward(self, x, edge_index, batch_index, graph_features):
        """Return one sigmoid score per graph.

        Args:
            x: Node feature tensor passed to the first convolution.
            edge_index: Edge index tensor passed to both convolutions.
            batch_index: Graph assignment of every node, used for pooling.
            graph_features: Graph-level features; reshaped to
                ``(-1, num_graph_features)`` before use.

        Returns:
            Tensor with one column holding a value in ``[0, 1]`` per graph.
        """
        # Node feature learning
        node_emb = F.relu(self.conv1(x, edge_index))
        node_emb = F.dropout(node_emb, training=self.training)
        node_emb = self.conv2(node_emb, edge_index)

        # Graph feature processing. The row width comes from the layer built in
        # __init__ rather than from a notebook-level global, so the model stays
        # correct for whatever sizes it was constructed with.
        graph_rows = graph_features.reshape(-1, self.fc_graph.in_features)
        graph_emb = F.relu(self.fc_graph(graph_rows))

        # Pooling node features to graph-level features
        pooled = global_mean_pool(node_emb, batch_index)

        # Concatenate node features and graph features
        combined = torch.cat([pooled, graph_emb], dim=1)

        # Fully connected layers
        hidden = F.relu(self.fc1(combined))
        hidden = F.dropout(hidden, training=self.training)
        score = self.fc2(hidden)

        return torch.sigmoid(score)

# Model hyper-parameters: node feature width, graph feature width, hidden width.
num_node_features, num_graph_features, hidden_channels = 17, 77, 64

# Build the classifier from the sizes above.
model = GNNClassifier(
    num_node_features=num_node_features,
    num_graph_features=num_graph_features,
    hidden_channels=hidden_channels,
)

In [9]:
# Pick the device once and put the model on it. Batches are moved to this same
# device inside train(); without moving the model too, a CUDA host would hit a
# device mismatch on the first forward pass.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Build the optimizer after the model has been placed on its device.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# NOTE(review): BCELoss clamps its log terms at -100, so a mean loss that sits
# far above ~0.69 and never moves would be consistent with saturated sigmoid
# outputs. Consider emitting raw logits with BCEWithLogitsLoss and normalising
# the input features -- confirm against the data before changing.
criterion = torch.nn.BCELoss()

# Assuming 'dataloader' is a PyTorch Geometric DataLoader instance containing your data
def train():
    """Run one optimisation pass over ``dataloader``.

    Returns:
        The loss averaged over all graphs in ``dataloader.dataset``.
    """
    model.train()
    total_loss = 0
    for data in dataloader:
        optimizer.zero_grad()  # Clear gradients.
        data = data.to(device)  # Move the batch to the model's device.
        out = model(data.x, data.edge_index, data.batch, data.u)  # Forward pass.
        # Flatten predictions and labels alike so BCELoss sees matching shapes
        # (assumes one label per graph -- TODO confirm).
        target = data.y.reshape(-1).to(torch.float)
        loss = criterion(out.view(-1), target)  # Compute the loss.
        loss.backward()  # Backpropagate to compute gradients.
        optimizer.step()  # Update model parameters.
        # Weight by batch size so the final division yields a per-graph mean.
        total_loss += loss.item() * data.num_graphs
    return total_loss / len(dataloader.dataset)  # Return the average loss.

# Training loop
for epoch in range(100):  # Number of epochs
    loss = train()
    print(f'Epoch {epoch+1}: Loss = {loss:.4f}')

Epoch 1: Loss = 48.8210
Epoch 2: Loss = 48.8574
Epoch 3: Loss = 48.8623
Epoch 4: Loss = 48.7968
Epoch 5: Loss = 48.8380
Epoch 6: Loss = 48.8768
Epoch 7: Loss = 48.8986
Epoch 8: Loss = 48.8065
Epoch 9: Loss = 48.8331
Epoch 10: Loss = 48.8307
Epoch 11: Loss = 48.8331
Epoch 12: Loss = 48.8574
Epoch 13: Loss = 48.7798
Epoch 14: Loss = 48.8477
Epoch 15: Loss = 48.8671


Process Process-1021:
Process Process-1002:
Process Process-1023:
Process Process-982:
Process Process-1019:
Process Process-1014:
Process Process-1016:
Process Process-992:
Process Process-1006:
Process Process-995:
Process Process-1013:
Process Process-1012:
Process Process-1024:
Process Process-993:
Process Process-1022:
Process Process-1003:
Process Process-1018:
Process Process-998:
Process Process-987:
Process Process-1007:
Process Process-988:
Process Process-997:
Process Process-991:
Process Process-1020:
Process Process-1005:
Process Process-1004:
Process Process-1010:
Process Process-1011:
Process Process-996:
Process Process-1017:
Process Process-1001:
Process Process-1008:
Process Process-1009:
Process Process-1000:
Process Process-1015:
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f3dc393ede0>
Traceback (most recent call last):
  File "/home/users/aaportel/mambaforge/envs/mlllp/lib/python3.11/site-packages/torch/utils/data/dataloader.py", li

In [None]:
# NOTE(review): scratch arithmetic; presumably num_graph_features (77) times
# batch_size (32), i.e. the length of a flattened graph-feature tensor for one
# full batch -- confirm, or delete this cell.
77*32