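# 📘 Common Task 2 – Graph-based Jet Classification (DeepFalcon GSoC 2025)

This notebook converts three-channel jet images (ECAL, HCAL and track) into k-nearest-neighbour point-cloud graphs and trains a graph convolutional network (GCN) to classify each jet as quark- or gluon-initiated.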

### ✅ Setup: Install Required Libraries

In [None]:
# Install PyTorch Geometric, its torch-scatter / torch-sparse extensions, and utilities.
# NOTE(review): the wheel index below is pinned to torch 2.0.0 + CUDA 11.8. If the
# runtime's `torch.__version__` differs, pip may not find a matching prebuilt wheel
# (or may pick one built for a different torch). Confirm the URL matches the runtime.
!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-2.0.0+cu118.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-2.0.0+cu118.html
!pip install -q torch-geometric
# NOTE(review): h5py, imageio, seaborn, open3d and tqdm are not imported by the
# cells shown here. Drop them if nothing else in the notebook needs them.
!pip install -q h5py imageio seaborn open3d tqdm


### ✅ Step 1: Import Libraries

In [None]:
import os
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt

from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, global_mean_pool

import random
from sklearn.metrics import classification_report, accuracy_score


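### ✅ Step 2: Convert Jet Images to Point Cloud Graphs

Every non-zero pixel in the ECAL, HCAL and track images becomes a graph node with features `(x, y, energy)`. By default, each node is connected to its 8 nearest neighbours in the `(x, y)` pixel plane.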

In [None]:
class JetGraphDataset(Dataset):
    """Jet images exposed as k-nearest-neighbour point-cloud graphs.

    Every non-zero pixel of the ECAL, HCAL and track images becomes one node.
    Node features depend on ``include_channel``:

    - False (default): ``(x, y, energy)``.
    - True: ``(x, y, energy, channel)``, where ``channel`` is 0 = ECAL,
      1 = HCAL, 2 = track.

    Each node is joined to its ``k`` nearest neighbours in the ``(x, y)``
    pixel plane.

    Args:
        npz_path: Path to an ``.npz`` archive with arrays ``ecal``, ``hcal``,
            ``track`` (one 2-D image per jet) and ``labels``.
        k: Neighbours per node, clipped to the node count of each graph.
        include_channel: Append the sub-detector index as a 4th node feature
            so the model can tell the three channels apart. Off by default so
            the 3-feature ``GCN`` still matches; build the model with
            ``input_dim=4`` when enabling it.
    """

    def __init__(self, npz_path, k=8, include_channel=False):
        # Indexing an NpzFile materialises each array in memory, so the archive
        # can be closed right away instead of holding the file handle open.
        with np.load(npz_path) as data:
            self.ecal = data['ecal']
            self.hcal = data['hcal']
            self.track = data['track']
            self.labels = data['labels']  # quark = 0, gluon = 1
        self.k = k
        self.include_channel = include_channel

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Build the graph for jet ``idx``.

        Returns:
            A ``torch_geometric.data.Data`` with these fields:

            - ``x``: float node features of shape ``(N, F)``, where ``F`` is
              3, or 4 with ``include_channel``.
            - ``edge_index``: shape ``(2, N * min(k, N))``.
            - ``y``: a scalar long label.
        """
        images = (self.ecal[idx], self.hcal[idx], self.track[idx])
        points = self._images_to_points(images)

        n_features = 4 if self.include_channel else 3
        node_features = torch.tensor(points[:, :n_features], dtype=torch.float)
        edge_index = self.build_knn_edges(node_features, k=self.k)
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return Data(x=node_features, edge_index=edge_index, y=label)

    @staticmethod
    def _images_to_points(images):
        """Flatten per-channel 2-D images into an ``(N, 4)`` array of hits.

        Each row is ``[x, y, value, channel]`` for one non-zero pixel.
        Channels follow the order of ``images``. Within a channel, pixels are
        in ``np.nonzero`` (row-major) order.

        A jet with no non-zero pixels yields a single all-zero row, so graph
        construction always has at least one node.
        """
        per_channel = []
        for channel, img in enumerate(images):
            ys, xs = np.nonzero(img)
            per_channel.append(
                np.column_stack((xs, ys, img[ys, xs], np.full(len(xs), channel)))
            )
        points = np.concatenate(per_channel, axis=0) if per_channel else np.zeros((0, 4))
        if len(points) == 0:
            return np.zeros((1, 4))
        return points

    def build_knn_edges(self, x, k=8):
        """Directed kNN edges in the plane of the first two columns of ``x``.

        For each node ``i`` and each of its ``min(k, N)`` nearest neighbours
        ``j``, emits the edge ``(i, j)``. The edges are not symmetrised.

        Node ``i`` is at distance 0 from itself, so it is normally among its
        own neighbours. When several hits share a pixel, tie order is decided
        by scikit-learn.

        Args:
            x: Node-feature tensor (on CPU) whose first two columns are the
                pixel coordinates.
            k: Requested neighbours per node, clipped to the node count.

        Returns:
            ``LongTensor`` of shape ``(2, N * min(k, N))``. Row 0 holds the
            query node ``i`` and row 1 its neighbour ``j``. Zero nodes give
            shape ``(2, 0)``.
        """
        coords = x[:, :2].numpy()
        n_nodes = len(coords)
        if n_nodes == 0:
            return torch.empty((2, 0), dtype=torch.long)

        from sklearn.neighbors import NearestNeighbors

        knn = NearestNeighbors(n_neighbors=min(k, n_nodes)).fit(coords)
        _, neighbours = knn.kneighbors(coords)  # shape (N, min(k, N))
        src = np.repeat(np.arange(n_nodes), neighbours.shape[1])
        dst = neighbours.reshape(-1)  # row-major: neighbours of node 0, then node 1, ...
        return torch.as_tensor(np.stack((src, dst)), dtype=torch.long)


### ✅ Step 3: Define GCN Classifier

In [None]:
from torch_geometric.nn import GCNConv

class GCN(torch.nn.Module):
    """Two-layer graph convolutional network for graph-level classification.

    Node features pass through two ``GCNConv`` + ReLU layers. They are then
    averaged per graph with ``global_mean_pool`` and mapped to class logits by
    a linear head.

    Args:
        input_dim: Node-feature size (e.g. 3 for ``(x, y, energy)``).
        hidden_dim: Width of both convolution layers.
        num_classes: Number of output logits (2 = quark / gluon).
    """

    def __init__(self, input_dim=3, hidden_dim=64, num_classes=2):
        super().__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.lin = torch.nn.Linear(hidden_dim, num_classes)

    def forward(self, x, edge_index, batch):
        """Return unnormalised logits, one row of ``num_classes`` per graph.

        Args:
            x: Node features, shape ``(N, input_dim)``.
            edge_index: ``(2, E)`` long tensor of edges.
            batch: ``(N,)`` graph index of every node (as set by the PyG loader).
        """
        h = F.relu(self.conv1(x, edge_index))
        h = F.relu(self.conv2(h, edge_index))
        graph_embedding = global_mean_pool(h, batch)
        # No softmax here: CrossEntropyLoss / argmax consume raw logits.
        return self.lin(graph_embedding)


### ✅ Step 4: Load Dataset and Train GCN

In [None]:
from torch_geometric.loader import DataLoader

# Configuration. NOTE: Replace path with your real dataset path.
DATA_PATH = "/content/jet_images_3ch.npz"
SEED = 42
TRAIN_FRACTION = 0.8
BATCH_SIZE = 32
LEARNING_RATE = 1e-3
NUM_EPOCHS = 10

# Seed every RNG the pipeline touches (split, loader shuffling, weight init) so
# Restart & Run All gives the same split and initialisation.
# NOTE(review): CUDA scatter kernels may still introduce small run-to-run noise.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

graph_dataset = JetGraphDataset(DATA_PATH)

# Shuffle and split with a dedicated generator, so the split does not depend on
# how much of the global torch RNG other cells have consumed.
train_size = int(TRAIN_FRACTION * len(graph_dataset))
val_size = len(graph_dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(
    graph_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SEED),
)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GCN().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = torch.nn.CrossEntropyLoss()

for epoch in range(1, NUM_EPOCHS + 1):
    model.train()
    epoch_loss = 0.0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        out = model(batch.x, batch.edge_index, batch.batch)
        loss = criterion(out, batch.y)
        loss.backward()
        optimizer.step()
        # Weight each batch-mean loss by its graph count, so a short final
        # batch does not skew the per-graph epoch average.
        epoch_loss += loss.item() * batch.num_graphs
    print(f"Epoch {epoch}, Loss: {epoch_loss / len(train_dataset):.4f}")


### ✅ Step 5: Evaluate Model

In [None]:
# Collect predictions on the held-out split in inference mode, without autograd.
model.eval()
y_true = []
y_pred = []

with torch.no_grad():
    for batch in val_loader:
        batch = batch.to(device)
        logits = model(batch.x, batch.edge_index, batch.batch)
        y_true.extend(batch.y.cpu().tolist())
        y_pred.extend(logits.argmax(dim=1).cpu().tolist())

# labels=[0, 1] pins the class set, so target_names always lines up. Without it,
# the report raises a ValueError when the validation split contains or
# predicts only one class.
print(classification_report(y_true, y_pred, labels=[0, 1], target_names=['quark', 'gluon']))
print(f"Accuracy: {accuracy_score(y_true, y_pred):.4f}")
