In [None]:
!pip install --upgrade torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113
!python -c "import torch; print(torch.__version__); print(torch.version.cuda);"
!pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-1.11.0+cu113.html

## Step1: Dataset Preprocessing

In [2]:
import numpy as np
import gc
import torch
import pyarrow as pa
from tqdm import tqdm
from pyarrow.parquet import ParquetFile
from sklearn.neighbors import kneighbors_graph
from sklearn.model_selection import train_test_split
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

In [3]:
import glob
import pandas as pd
tau_dataset = glob.glob('../input/boosted-top-tau-4/*')
X_jets = []
lables = []

for file_name in tau_dataset:
    df = pd.read_parquet(file_name)
    X_jets.append(np.array(df['X_jets'].tolist()).astype(np.float32))
    lables.append(df['y'].to_numpy())
    del df

In [4]:
X_jets = np.array(X_jets).reshape((-1,125000))
lables = np.array(lables).reshape((-1,1))

In [5]:
X_data = X_jets.reshape((-1,125*125,8))
X_data = X_data[:,:,:5]
non_black_pixels_mask = np.any(X_data != 0., axis=-1)

node_list = []
for i, x in enumerate(X_data):
    node_list.append(x[non_black_pixels_mask[i]])
del X_jets, X_data, non_black_pixels_mask


In [6]:
dataset = []
for i,nodes in enumerate(tqdm(node_list)):
    edges = kneighbors_graph(nodes, 15, mode='connectivity', include_self=True)
    c = edges.tocoo()
    edge_list = torch.from_numpy(np.vstack((c.row, c.col))).type(torch.long)
    edge_weight = torch.from_numpy(c.data.reshape(-1,1))
    y = lables[i]
    dataset.append(Data(x=torch.from_numpy(nodes), edge_index=edge_list, edge_attr=edge_weight, y=torch.from_numpy(y).long()))

100%|██████████| 9600/9600 [01:22<00:00, 116.99it/s]


In [7]:
del lables, node_list, edge_list, edge_weight, y
gc.collect()

911

In [9]:
rand_seed = 42
X_train, X_test = train_test_split(dataset, test_size=0.1, random_state = rand_seed)
X_train, X_val = train_test_split(X_train, test_size=0.1, random_state = rand_seed)

In [10]:
train_loader = DataLoader(X_train, batch_size=32, shuffle=True)
val_loader = DataLoader(X_val, batch_size=32, shuffle=False)
test_loader = DataLoader(X_test, batch_size=32, shuffle=False)

batch = next(iter(test_loader))
print("Batch:", batch)
print("Labels:", batch.y[:10])
print("Batch indices:", batch.batch[:40])

Batch: DataBatch(x=[37894, 5], edge_index=[2, 568410], edge_attr=[568410, 1], y=[32], batch=[37894], ptr=[33])
Labels: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
Batch indices: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])


## Step2: Define Graph Convolution GNN Model

In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

In [12]:
from torch.nn import Linear
from torch_geometric.nn import GraphConv
from torch_geometric.nn import global_mean_pool

num_node_features = 5
num_classes = 2

class GCN(torch.nn.Module):
    def __init__(self, c_in, c_hidden, c_out, dp_rate_linear=0.3):
        super().__init__()
        torch.manual_seed(123)
        self.conv1 = GraphConv(c_in, c_hidden)
        self.conv2 = GraphConv(c_hidden,2*c_hidden)
        self.conv3 = GraphConv(2*c_hidden, c_hidden)
        self.lin1 = Linear(c_hidden, 4*c_out)
        self.lin2 = Linear(4*c_out, c_out)
        self.dp_rate_linear = dp_rate_linear

    def forward(self, x, edge_index, batch):
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = self.conv2(x, edge_index)
        x = x.relu()
        x = self.conv3(x, edge_index)

        x = global_mean_pool(x, batch)  # [batch_size, hidden_channels]

        # classifier
        x = F.dropout(x, p=self.dp_rate_linear, training=self.training)
        x = self.lin1(x)
        x = F.dropout(x, p=self.dp_rate_linear, training=self.training)
        x = self.lin2(x)

        return x

## Step3: Use PyTorch Lightning to Define Training Process

In [20]:
from torchmetrics.functional import auroc
class GraphLevelGNN(pl.LightningModule):
    
    def __init__(self, **model_kwargs):
        super().__init__()
        # Saving hyperparameters
        self.save_hyperparameters()
        
        self.model = GCN(**model_kwargs)
        self.loss_module = nn.BCEWithLogitsLoss() if self.hparams.c_out == 1 else nn.CrossEntropyLoss()
        self.auroc = auroc
        
    def forward(self, data, mode="train"):
        x, edge_index, batch_idx = data.x, data.edge_index, data.batch
        # print(data.x.shape, data.edge_index.shape, data.batch.shape)
        
        x = self.model(x, edge_index, batch_idx)
        x = x.squeeze(dim=-1)
        
        if self.hparams.c_out == 1:
            preds = (x > 0).float()
            data.y = data.y.float()
        else:
            preds = x.argmax(dim=-1)
        loss = self.loss_module(x, data.y)
        acc = (preds == data.y).sum().float() / preds.shape[0]
        auc = self.auroc(x, data.y, num_classes=2)
        return loss, acc, auc

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=3e-4, weight_decay=0) # High lr because of small dataset and small model
        return optimizer

    def training_step(self, batch, batch_idx):
        loss, acc, auc = self.forward(batch, mode="train")
        self.log('train_loss', loss, prog_bar=True)
        self.log('train_acc', acc, prog_bar=True)
        self.log('train_auc', auc, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, acc, auc = self.forward(batch, mode="val")
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)
        self.log('val_auc', auc, prog_bar=True)

    def test_step(self, batch, batch_idx):
        loss, acc, auc = self.forward(batch, mode="test")
        self.log('test_loss', loss, prog_bar=True)
        self.log('test_acc', acc, prog_bar=True)
        self.log('test_auc', auc, prog_bar=True)

In [23]:
CHECKPOINT_PATH = "./"
def train_graph_classifier(model_name, **model_kwargs):
    pl.seed_everything(42)
    
    # Create a PyTorch Lightning trainer with the generation callback
    root_dir = os.path.join(CHECKPOINT_PATH, "GraphLevel" + model_name)
    os.makedirs(root_dir, exist_ok=True)
    trainer = pl.Trainer(default_root_dir=root_dir,
                         callbacks=[ModelCheckpoint(save_weights_only=True, mode="max", monitor="val_acc")],
                         gpus=1 if str(device).startswith("cuda") else 0,
                         max_epochs=20,
                         progress_bar_refresh_rate=5)

    # Check whether pretrained model exists. If yes, load it and skip training
    model = GraphLevelGNN(**model_kwargs)
    print(model)
    trainer.fit(model, train_loader, val_loader)
    model = GraphLevelGNN.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)
    
    # Test best model on validation and test set
    # train_result = trainer.test(model, dataloaders=train_loader, verbose=False)
    val_result = trainer.test(model, dataloaders=val_loader, verbose=False)
    test_result = trainer.test(model, dataloaders=test_loader, verbose=False)
    print(val_result, test_result)
    result = {"test": test_result[0]['test_acc'], "valid": val_result[0]['test_acc']} 
    return model, result

In [24]:
import os
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model, result = train_graph_classifier(model_name="GCN", c_in=5, c_hidden=64, c_out=2)

GraphLevelGNN(
  (model): GCN(
    (conv1): GraphConv(5, 64)
    (conv2): GraphConv(64, 128)
    (conv3): GraphConv(128, 64)
    (lin1): Linear(in_features=64, out_features=8, bias=True)
    (lin2): Linear(in_features=8, out_features=2, bias=True)
  )
  (loss_module): CrossEntropyLoss()
)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Testing: 0it [00:00, ?it/s]

Testing: 0it [00:00, ?it/s]

  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection.

[{'test_loss': 0.40515729784965515, 'test_acc': 0.8140714764595032, 'test_auc': 0.85390305519104}] [{'test_loss': 0.375895231962204, 'test_acc': 0.8309985399246216, 'test_auc': 0.8593456149101257}]
