In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
!pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv torch_geometric -f https://data.pyg.org/whl/torch-1.13.0+cu117.html

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in links: https://data.pyg.org/whl/torch-1.13.0+cu117.html
Collecting pyg_lib
  Downloading https://data.pyg.org/whl/torch-1.13.0%2Bcu117/pyg_lib-0.1.0%2Bpt113cu117-cp39-cp39-linux_x86_64.whl (1.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.9/1.9 MB[0m [31m29.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torch_scatter
  Downloading https://data.pyg.org/whl/torch-1.13.0%2Bcu117/torch_scatter-2.1.1%2Bpt113cu117-cp39-cp39-linux_x86_64.whl (10.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.1/10.1 MB[0m [31m84.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torch_sparse
  Downloading https://data.pyg.org/whl/torch-1.13.0%2Bcu117/torch_sparse-0.6.17%2Bpt113cu117-cp39-cp39-linux_x86_64.whl (4.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.7/4.7 MB[0m [31m93.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollect

In [3]:
import numpy as np
np.random.seed(0)
import os, glob
import time
import h5py
from tqdm import tqdm

import torch
import torch.nn.functional as F
import torch.optim as optim
import torch.utils
import torch.utils.data as data_utils
from torch.utils.data import ConcatDataset, Dataset, DataLoader, sampler, DistributedSampler, Subset
from torch_geometric.data import Batch, Data, DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool

from sklearn.metrics import roc_curve, auc

In [4]:
BATCH = 8

granularity = 1
dropout = 0.3
maxnodes = 100
lr_init = 5.e-4
edgeconvblocks = 3
epochs = 5
os.environ["CUDA_VISIBLE_DEVICES"]=str(0)
PATH = "/content/drive/MyDrive/gsoc/quark-gluon_data-set_n139306.hdf5"
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

In [5]:
import torch
from torch_geometric.data import Data
class HDF5Dataset(Dataset):
    def __init__(self, file_path):
        self.file = h5py.File(file_path, 'r')
        self.data = self.file['X_jets']
        self.m0 = self.file['m0']
        self.pt = self.file['pt']
        self.y = self.file['y']
        
    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, index):
        # Load X_jets in shape (125, 125, 3) from HDF5 file
        x_jets = self.data[index]
        # Transpose the array to shape (3, 125, 125)
        image = torch.tensor(x_jets.transpose(2, 0, 1)).detach().clone()

        nonzero_indices = torch.nonzero(image > 0)
        points = torch.zeros(nonzero_indices.shape[0], 3)
        points[:, 0] = 2.0 * (nonzero_indices[:, 2].float() / (image.shape[2] - 1) - 0.5) # x-coordinates
        points[:, 1] = 2.0 * (nonzero_indices[:, 1].float() / (image.shape[1] - 1) - 0.5) # y-coordinates
        points[:, 2] = image[nonzero_indices[:, 0], nonzero_indices[:, 1], nonzero_indices[:, 2]]

        ret_data = {}
        ret_data["X_jets"] = torch.tensor(image)
        ret_data["points"] = torch.tensor(points)
        ret_data["m0"] = torch.tensor(self.m0[index])
        ret_data["pt"] = torch.tensor(self.pt[index])
        ret_data["y"] = torch.tensor(self.y[index])
      
        return dict(ret_data)

In [23]:
dataset = HDF5Dataset(PATH)

indices = torch.arange(200)
train_dataset = Subset(dataset, indices)

In [24]:
train_size = 0.8

val_dataset = train_dataset
test_dataset = train_dataset

num_train = len(train_dataset)
indices = list(range(num_train))
split = int(np.floor(train_size * num_train))
split2 = int(np.floor((train_size+(1-train_size)/2) * num_train))
np.random.shuffle(indices)
train_idx, valid_idx, test_idx = indices[:split], indices[split:split2], indices[split2:]

train_data = Subset(train_dataset, indices=train_idx)
val_data = Subset(val_dataset, indices=valid_idx)
test_data = Subset(test_dataset, indices=test_idx)

train_data = Subset(train_dataset, indices=train_idx)
val_data = Subset(val_dataset, indices=valid_idx)
test_data = Subset(test_dataset, indices=test_idx)

In [25]:
train_dataloader = DataLoader(train_data, shuffle=True, batch_size=BATCH, num_workers=2)
val_dataloader = DataLoader(val_data, shuffle=True, batch_size=BATCH, num_workers=2)



In [26]:
class PointCloudGraphDataset(Dataset):
    def __init__(self, point_cloud_dataloader):
        self.point_cloud_dataloader = point_cloud_dataloader

    def __getitem__(self, index):
        points, label = self.point_cloud_dataloader.dataset[index]["points"], self.point_cloud_dataloader.dataset[index]["y"]
        num_points = points.shape[0]
        edge_index = torch.zeros((2, num_points * (num_points - 1) // 2), dtype=torch.long)
        edge_attr = torch.zeros(edge_index.shape[1], dtype=torch.float)
        k = 0
        for i in range(num_points):
            for j in range(i + 1, num_points):
                edge_index[0, k] = i
                edge_index[1, k] = j
                edge_attr[k] = torch.norm(points[i] - points[j])
                k += 1
        return Data(x=points, edge_index=edge_index, edge_attr=edge_attr, y=label)

    def __len__(self):
        return len(self.point_cloud_dataloader.dataset)

In [27]:
train_point_cloud_graph_dataset = PointCloudGraphDataset(train_dataloader)
val_point_cloud_graph_dataset = PointCloudGraphDataset(val_dataloader)
def collate_fn(batch):
    return Batch.from_data_list(batch)
train_dataloader = DataLoader(train_point_cloud_graph_dataset, batch_size=BATCH, shuffle=True, collate_fn=collate_fn)
val_dataloader = DataLoader(val_point_cloud_graph_dataset, batch_size=BATCH, shuffle=True, collate_fn=collate_fn)

In [28]:
class GNNModel(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(GNNModel, self).__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.conv3 = GCNConv(hidden_dim, hidden_dim)
        self.conv4 = GCNConv(hidden_dim, output_dim)

    def forward(self, data):
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        x = F.relu(self.conv1(x, edge_index, edge_attr))
        x = F.relu(self.conv2(x, edge_index, edge_attr))
        x = F.relu(self.conv3(x, edge_index, edge_attr))
        x = self.conv4(x, edge_index, edge_attr)
        x = global_mean_pool(x, data.batch)
        return F.log_softmax(x, dim=1)




In [29]:
model = GNNModel(input_dim=3, hidden_dim=32, output_dim=2).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.NLLLoss()

for epoch in range(50):
    model.train()
    for data in tqdm(train_dataloader):
        data = data.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, data.y.long())
        loss.backward()
        optimizer.step()

    model.eval()
    correct = 0
    total = 0
    for data in tqdm(val_dataloader):
        data = data.to(device)
        output = model(data)
        _, predicted = torch.max(output, 1)
        total += data.y.size(0)
        correct += (predicted == data.y).sum().item()
    accuracy = 100 * correct / total
    print('Epoch {}, Accuracy: {:.2f}%'.format(epoch + 1, accuracy))

  ret_data["X_jets"] = torch.tensor(image)
  ret_data["points"] = torch.tensor(points)
100%|██████████| 20/20 [32:49<00:00, 98.46s/it]
100%|██████████| 3/3 [03:44<00:00, 74.86s/it]


Epoch 1, Accuracy: 50.00%


100%|██████████| 20/20 [33:36<00:00, 100.84s/it]
100%|██████████| 3/3 [03:51<00:00, 77.32s/it]


Epoch 2, Accuracy: 20.00%


100%|██████████| 20/20 [32:55<00:00, 98.78s/it]
100%|██████████| 3/3 [03:44<00:00, 74.76s/it]


Epoch 3, Accuracy: 75.00%


100%|██████████| 20/20 [32:55<00:00, 98.79s/it] 
100%|██████████| 3/3 [03:51<00:00, 77.04s/it]


Epoch 4, Accuracy: 25.00%


100%|██████████| 20/20 [33:55<00:00, 101.79s/it]
100%|██████████| 3/3 [03:43<00:00, 74.48s/it]


Epoch 5, Accuracy: 55.00%


100%|██████████| 20/20 [32:56<00:00, 98.84s/it] 
100%|██████████| 3/3 [03:40<00:00, 73.44s/it]


Epoch 6, Accuracy: 50.00%


  5%|▌         | 1/20 [01:37<30:53, 97.55s/it]


KeyboardInterrupt: ignored