In [4]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, BatchNorm, global_mean_pool
from torch_geometric.data import Data
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

In [2]:
# Define model
class GNN(torch.nn.Module):
    """Four-layer GCN that classifies a whole graph into one of two classes.

    Every node carries a single scalar feature. Each of the four ``GCNConv``
    layers is followed by batch normalisation, ReLU and dropout (p=0.4).
    The node embeddings are then mean-pooled per graph and passed through a
    linear layer; the result is returned as softmax probabilities.

    The attribute names (``conv1`` ... ``fc``) are kept stable so that
    previously saved checkpoints of this class still load.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = GCNConv(1, 16)
        self.bn1 = BatchNorm(16)
        self.conv2 = GCNConv(16, 32)
        self.bn2 = BatchNorm(32)
        self.conv3 = GCNConv(32, 64)
        self.bn3 = BatchNorm(64)
        self.conv4 = GCNConv(64, 16)
        self.bn4 = BatchNorm(16)
        self.dropout = torch.nn.Dropout(p=0.4)
        self.fc = torch.nn.Linear(16, 2)

    def forward(self, data):
        """Return class probabilities for the graph(s) in ``data``.

        Args:
            data: Object exposing ``x`` (node features, one column) and
                ``edge_index``; it may also expose ``batch``, a node-to-graph
                assignment vector, when several graphs are batched together.

        Returns:
            Tensor of shape (num_graphs, 2) whose rows sum to 1. A single
            un-batched graph yields shape (1, 2).
        """
        x, edge_index = data.x, data.edge_index

        blocks = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
            (self.conv4, self.bn4),
        )
        for conv, norm in blocks:
            x = self.dropout(F.relu(norm(conv(x, edge_index))))

        # `hasattr(data, "batch")` cannot tell a real batch vector from an
        # attribute that exists but is None, so test the value itself.
        batch = getattr(data, "batch", None)
        if batch is None:
            # Single graph: average over all of its nodes.
            pooled = x.mean(dim=0, keepdim=True)
        else:
            pooled = global_mean_pool(x, batch)

        return F.softmax(self.fc(pooled), dim=1)

In [3]:
# Load model
# The checkpoint is a whole pickled GNN object (not just a state_dict), so the
# GNN class must already be defined in this session before loading.
# SECURITY: unpickling can execute arbitrary code -- only load checkpoints
# that come from a trusted source.
# weights_only=False is passed explicitly because PyTorch >= 2.6 defaults to
# weights_only=True, which refuses to unpickle custom nn.Module classes.
model = torch.load("gnn_model.pt", map_location="cpu", weights_only=False)
model.eval()

GNN(
  (conv1): GCNConv(1, 16)
  (bn1): BatchNorm(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): GCNConv(16, 32)
  (bn2): BatchNorm(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): GCNConv(32, 64)
  (bn3): BatchNorm(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv4): GCNConv(64, 16)
  (bn4): BatchNorm(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (dropout): Dropout(p=0.4, inplace=False)
  (fc): Linear(in_features=16, out_features=2, bias=True)
)

In [4]:
# Load dataset
# Columns read from each row by the graph builder, in node order.
jet_columns = [
    "lead_jet_pt",
    "lead_jet_eta",
    "lead_jet_phi",
    "lead_jet_mass",
]
# Cut applied to the "recoil_magnitude_u" column when splitting the events.
recoil_threshold = 200
# Per-event predictions are collected here by the prediction loop.
predictions = list()

df = pd.read_csv("../raw/Z2JetsToNuNu_PtZ-250To400.csv")

In [5]:
# Build graph for each event
def build_graph_from_row(row):
    """Turn one event row into a small chain graph for the GNN.

    Each column named in the module-level ``jet_columns`` becomes one node
    with a single scalar feature. Consecutive nodes are linked by a directed
    edge ``i -> i + 1``.

    Args:
        row: Record indexable by the names in ``jet_columns`` (for example a
            pandas Series produced by ``DataFrame.iterrows()``).

    Returns:
        A ``Data`` object with ``x`` of shape (len(jet_columns), 1) and
        ``edge_index`` of shape (2, len(jet_columns) - 1), or ``None`` when a
        column is missing or its value cannot be converted to ``float``.
    """
    try:
        nodes = [float(row[col]) for col in jet_columns]
    except (KeyError, TypeError, ValueError):
        # Only conversion problems mark the event as unusable. Anything else
        # (including KeyboardInterrupt) propagates instead of being swallowed.
        return None

    # NOTE(review): float() accepts NaN, so rows with missing values are NOT
    # rejected here and reach the model as NaN features -- confirm intended.
    edges = [[i, i + 1] for i in range(len(nodes) - 1)]
    x = torch.tensor(nodes, dtype=torch.float).view(-1, 1)
    # reshape(-1, 2) keeps edge_index at shape (2, E) even when E == 0.
    edge_index = (
        torch.tensor(edges, dtype=torch.long).reshape(-1, 2).t().contiguous()
    )
    return Data(x=x, edge_index=edge_index)

In [6]:
# Run prediction loop
# Start from an empty list on every run of this cell. Appending to a list
# created in another cell would double the predictions on a re-run and break
# the later column assignment, which needs one entry per DataFrame row.
predictions = []

# Inference must run in eval mode: the model contains Dropout and BatchNorm
# layers, which behave differently in training mode.
model.eval()

with torch.no_grad():
    for _, row in df.iterrows():
        data = build_graph_from_row(row)
        if data is None:
            # Keep row alignment: an unusable event gets NaN as a placeholder.
            predictions.append(np.nan)
            continue
        out = model(data)
        # `out` has one row per graph; its argmax is the predicted class index.
        predictions.append(torch.argmax(out, dim=1).item())

In [7]:
# Add predictions to DataFrame
# `predictions` is expected to hold one entry per row of df, in row order:
# the argmax class index from the model, or np.nan where
# build_graph_from_row returned None. The column is added to df in place;
# pandas rejects the assignment if the lengths differ.
df["gnn_prediction"] = predictions

In [8]:
# 1. Class 1 predictions that passed the recoil cut
# Build the two conditions separately, then keep rows satisfying both.
predicted_class1 = df["gnn_prediction"].eq(1)
above_recoil_cut = df["recoil_magnitude_u"].gt(recoil_threshold)
df_class1_passed = df.loc[predicted_class1 & above_recoil_cut]

In [9]:
# 2. Class 0 predictions and class 1 that FAILED the recoil cut
# Rows are kept when predicted 0, or predicted 1 with recoil at or below the
# threshold. The result is copied so it can be modified independently of df.
gnn_pred = df["gnn_prediction"]
at_or_below_cut = df["recoil_magnitude_u"].le(recoil_threshold)
keep_as_class0 = gnn_pred.eq(0) | (gnn_pred.eq(1) & at_or_below_cut)
df_class0 = df.loc[keep_as_class0].copy()

In [10]:
# Reassign those that failed the cut as class 0
# Any row still labelled 1 in df_class0 is there because it failed the cut.
still_labelled_1 = df_class0["gnn_prediction"].eq(1)
df_class0.loc[still_labelled_1, "gnn_prediction"] = 0

In [11]:
# Save both groups
# Each (frame, destination) pair is written as CSV without the index column.
for group_frame, csv_path in (
    (df_class1_passed, "../test_data/Z2JetsToNuNu_class1.csv"),
    (df_class0, "../test_data/Z2JetsToNuNu_class0.csv"),
):
    group_frame.to_csv(csv_path, index=False)

In [14]:
# Report how many events ended up in each output group.
for caption, group_frame in (
    ("Events in class 1:", df_class1_passed),
    ("Events in class 0:", df_class0),
):
    print(caption, len(group_frame))

Events in class 1: 178938
Events in class 0: 82515
