# Task IV - Graph Neural Network (GNN)

# 1. Introduction

### Goal

- Perform quark/gluon jet classification on ParticleNet’s dataset using a Graph Neural Network (GNN)

### On the packages used

- In this task, we will use the `dgl` (Deep Graph Library) package to construct the GNN models


In [5]:
# load some basic packages first
from datetime import datetime
import numpy as np
import matplotlib.pyplot as plt

import torch

ValueError: module functions cannot set METH_CLASS or METH_STATIC

# 2. Create Graph Representation of Data

References:

How to create a graph  https://docs.dgl.ai/tutorials/basics/1_first.html

How to assign features to edges:  https://docs.dgl.ai/guide/training-edge.html

### Load Data

The dataset contains:

- 50k quark jets and 50k gluon jets, each described by its constituents (charged or neutral tracks)

- Each constituent (track) of a jet is represented by a four-dimensional vector: $(p_T, y, \phi, \text{PDGid})$; jets with fewer constituents are zero-padded. 

In [4]:

data = np.load('../data/QG_orig/QG_jets.npz')
X = data['X']
# Labels are wrapped in a dict so they can be passed as graph labels to save_graphs later.
y = {'label': torch.tensor(data['y']).long()}
print(f'Dimension of input data: {X.shape}')

RuntimeError: KeyboardInterrupt: 

A graph can have features in three different places or levels:

- features for nodes
- features for edges
- features for the whole graph

A **jet** is **encoded as a graph**.

The **label** of whether it is a quark or a gluon is **encoded as the graph feature**.

Each constituent (**track**) is encoded as a **node** of the graph, with the **four-dimensional vector** as its **node features**.

Features are preprocessed following Sect. 3.1 of the original paper (https://arxiv.org/abs/1810.05165): each jet is centred by subtracting the $p_T$-weighted mean of $y$ and $\phi$, and the constituent $p_T$ values are normalised to sum to one.

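For the `biadj_*` connection styles, **nodes** are **connected by bi-directional edges** if they are **adjacent** when the tracks are ordered by $p_T$, $y$, or $\phi$, and every node also gets a self-loop. The `bifully` style instead connects every pair of nodes.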

This ensures that each graph is connected.

**Node indices** are **assigned in descending order of $p_T$**.

In [None]:
''' Create graphs and save to local disk '''
# This cell mirrors the "create graphs" branch of train_q_g further below.
# It expects `args` (with .connection and .nevent) and the GenerateGraphs
# class to be defined before it runs.
from os import makedirs, path

from dgl.data.utils import save_graphs

input_file = '../data/QG_orig/QG_jets.npz'
graph_file = f'../data/QG_graph/QG_jets__{args.connection}__{args.nevent if args.nevent else "All"}.bin'

input_data = np.load(input_file)
X, y = input_data['X'], {'label': torch.tensor(input_data['y']).long()}

generator = GenerateGraphs(X, args.connection)
graphs = generator.create_graphs(stop=args.nevent)
print(datetime.now().strftime('%Y-%m-%d %H:%M:%S'), '\033[92m[INFO]\033[0m', '\033[92mCreate graphs from\033[0m'.rjust(40, ' '), input_file, len(graphs))
makedirs(path.dirname(graph_file), exist_ok=True)
save_graphs(graph_file, graphs, y)
print(datetime.now().strftime('%Y-%m-%d %H:%M:%S'), '\033[92m[INFO]\033[0m', '\033[92mSave graphs to\033[0m'.rjust(40, ' '), graph_file, len(graphs))


In [None]:
import dgl
import networkx as nx
import torch

# https://github.com/dmlc/xgboost/issues/1715
# Presumably an OpenMP workaround: KMP_DUPLICATE_LIB_OK lets the Intel OpenMP
# runtime tolerate being loaded more than once instead of aborting.
# NOTE(review): this masks a duplicate-runtime condition rather than fixing it.
import os
os.environ.update(KMP_DUPLICATE_LIB_OK='True')

# data = X # X.shape = (njets, track multiplicity, 4-momentum)

class GenerateGraphs(object):
    ''' Build one DGL graph per jet from zero-padded jet-constituent arrays.

        data: array indexed as data[event, track, column] with columns
              (pT, rapidity y, azimuthal angle phi, PDGid); all-zero rows are padding.
              create_graphs() rewrites the PDGid column of `data` in place.
        connect_style: 'bifully' for a fully connected graph, or one of the keys of
              ADJACENCY_COLUMNS to join tracks that are neighbours when sorted by
              the listed columns (reverse edges and self-loops are added).
    '''

    # PDGid -> small float dictionary
    # https://github.com/pkomiske/EnergyFlow/blob/master/energyflow/utils/data_utils.py#L188
    PID2FLOAT_MAP = {22: 0,
                     211: .1, -211: .2,
                     321: .3, -321: .4,
                     130: .5,
                     2112: .6, -2112: .7,
                     2212: .8, -2212: .9,
                     11: 1.0, -11: 1.1,
                     13: 1.2, -13: 1.3,
                     0: 0, }

    # connect_style -> columns (0 = pT, 1 = y, 2 = phi) whose descending order defines
    # which tracks are adjacent; the tuple order fixes the order of the created edges.
    ADJACENCY_COLUMNS = {
        'biadj_pt_y_phi': (0, 1, 2),
        'biadj_pt_y': (0, 1),
        'biadj_pt_phi': (0, 2),
        'biadj_y_phi': (1, 2),
    }
    SUPPORTED_STYLES = ('bifully',) + tuple(ADJACENCY_COLUMNS)

    def describe(self): return self.__class__.__name__

    def __init__(self, data, connect_style):
        self.data = data
        self.label_name = 'label'
        self.feature_name = 'feature'
        self.connect_style = connect_style
        self._current_data = None
        # True once data[:, :, 3] holds PID2FLOAT_MAP values (mapping is done in place, once).
        self._pids_mapped = False

    def _assign_node_feature(self, graph):
        ''' Attach the preprocessed tracks of the current jet as node features.

            Each node is one track; its feature row is (normalised pT, centred y,
            centred phi, PDGid column as stored in self.data), under self.feature_name.
            torch.tensor keeps the numpy dtype.
            NOTE(review): for float64 input the features are float64 — confirm the
            training code casts them to the model's dtype.
        '''
        assert self._current_data is not None
        graph.ndata[self.feature_name] = torch.tensor(self._current_data)
        return graph

    def _create_graph(self, ievent: int):
        ''' Create a graph from a jet X [track multiplicity, "4-momentum"], where "4-momentum" is pt, rapidity, azimuthal angle, and pdgid.
            data: https://zenodo.org/record/3164691#.YFeQey1Q0lp

            Raises:
                ValueError: if self.connect_style is not one of SUPPORTED_STYLES.
        '''
        if self.connect_style not in self.SUPPORTED_STYLES:
            raise ValueError(f'Unsupported connect_style {self.connect_style!r}; '
                             f'expected one of {self.SUPPORTED_STYLES}')

        jet = self.data[ievent]
        # Drop zero-padding rows. Boolean indexing returns a copy, so the in-place
        # preprocessing below does not modify self.data.
        self._current_data = jet[~np.all(jet == 0, axis=1)]

        # Feature preprocessing, Sect 3.1 in https://arxiv.org/pdf/1810.05165.pdf
        # (https://energyflow.network/examples/): centre the jet on its pT-weighted
        # (y, phi) and normalise pT so that it sums to one.
        yphi_avg = np.average(self._current_data[:, 1:3], weights=self._current_data[:, 0], axis=0)
        self._current_data[:, 1:3] -= yphi_avg
        self._current_data[:, 0] /= np.sum(self._current_data[:, 0])

        # Sort by descending pT so that node 0 is the leading track.
        self._current_data = self._current_data[self._current_data[:, 0].argsort()][::-1].copy()

        n_nodes = self._current_data.shape[0]  # track multiplicity
        if self.connect_style == 'bifully':
            # Option 1: fully connected graph (no self-loops are added here).
            graph = dgl.from_networkx(nx.complete_graph(n_nodes))
        else:
            # Option 2: bi-directional edges between tracks adjacent in the chosen orderings.
            orders = [self._current_data[:, col].argsort()[::-1]
                      for col in self.ADJACENCY_COLUMNS[self.connect_style]]
            in_node = np.concatenate([order[:-1] for order in orders])
            out_node = np.concatenate([order[1:] for order in orders])

            g = dgl.graph((in_node, out_node), num_nodes=n_nodes)
            g = dgl.add_reverse_edges(g)
            graph = dgl.add_self_loop(g)

        graph = self._assign_node_feature(graph)

        return graph.int()  # 32-bit integers for node and edge IDs to reduce memory

    def _map_pids(self):
        ''' Replace PDGid codes in self.data[:, :, 3] by PID2FLOAT_MAP values, in place.

            Repeated calls are no-ops. Remapping twice would look up already-mapped
            floats and raise KeyError.
            Raises KeyError for a PDGid missing from PID2FLOAT_MAP.
        '''
        if self._pids_mapped:
            return
        pid_column = self.data[:, :, 3]  # basic slicing -> a view, so np.place writes into self.data
        for pid in np.unique(pid_column):
            np.place(pid_column, pid_column == pid, self.PID2FLOAT_MAP[pid])
        self._pids_mapped = True

    def create_graphs(self, stop=None):
        ''' Create graphs for the first `stop` events (all events when stop is falsy).

            Returns:
                list of DGL graphs, one per event, in event order.
        '''
        self._map_pids()

        n_events = self.data.shape[0]
        n_graphs = min(stop, n_events) if stop else n_events

        graphs = []
        for i in range(n_graphs):
            if i % 1000 == 0:
                print(datetime.now().strftime('%Y-%m-%d %H:%M:%S'), '\033[92m[INFO]\033[0m', f'\033[92mCreated {self.connect_style} graphs:\033[0m'.rjust(40, ' '),  i, '/', n_graphs)
            graphs.append(self._create_graph(i))
        return graphs


def plot_graph(graph):
    ''' Log the node/edge counts of a DGL graph, then draw it with networkx and show the figure. '''
    timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    label = '\033[92mNumber of nodes and edges:\033[0m'.rjust(40, ' ')
    print(timestamp, '\033[92m[INFO]\033[0m', label, graph.number_of_nodes(), '/', graph.number_of_edges())

    light_grey = [[.7, .7, .7]]
    nx.draw(graph.to_networkx(), with_labels=True, node_color=light_grey)
    plt.show()


## Implementing Various GNN Architectures

In [None]:
from datetime import datetime
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.nn as dglnn
import dgl
import numpy as np
from dgl.nn.pytorch.conv import GraphConv

# Print floats in fixed-point rather than scientific notation when inspecting arrays.
np.set_printoptions(suppress=True)

class SGC(nn.Module):
    ''' Two stacked simplified graph convolutions (SGConv) with an ELU in between.
        https://docs.dgl.ai/api/python/nn.pytorch.html?highlight=sageconv#sgconv

        Returns per-node sigmoid scores; no graph-level readout is applied here.
        NOTE(review): GAT and GCN average over nodes — confirm callers expect per-node output.
    '''
    def __init__(self, in_feats, hid_feats, out_feats):
        super().__init__()
        self.conv1 = dglnn.SGConv(in_feats=in_feats, out_feats=hid_feats)
        self.conv2 = dglnn.SGConv(in_feats=hid_feats, out_feats=out_feats)

    def forward(self, graph, inputs):
        ''' inputs: node features; returns sigmoid(conv2(elu(conv1(inputs)))). '''
        hidden = F.elu(self.conv1(graph, inputs))
        logits = self.conv2(graph, hidden)
        return torch.sigmoid(logits)

class SAGE(nn.Module):
    ''' GraphSAGE (Inductive Representation Learning on Large Graphs): two SAGEConv
        layers with mean aggregation and an ELU in between.
        https://docs.dgl.ai/api/python/nn.pytorch.html#sageconv

        Returns per-node sigmoid scores; no graph-level readout is applied here.
        NOTE(review): GAT and GCN average over nodes — confirm callers expect per-node output.
    '''
    def __init__(self, in_feats, hid_feats, out_feats):
        super().__init__()
        mean_aggr = dict(aggregator_type='mean')
        self.conv1 = dglnn.SAGEConv(in_feats=in_feats, out_feats=hid_feats, **mean_aggr)
        self.conv2 = dglnn.SAGEConv(in_feats=hid_feats, out_feats=out_feats, **mean_aggr)

    def forward(self, graph, inputs):
        ''' inputs: node features; returns sigmoid(conv2(elu(conv1(inputs)))). '''
        hidden = F.elu(self.conv1(graph, inputs))
        logits = self.conv2(graph, hidden)
        return torch.sigmoid(logits)

class GAT(nn.Module):
    ''' Two layers of Graph Attention Network followed by a mean-over-nodes readout.
        https://docs.dgl.ai/api/python/nn.pytorch.html#gatconv

        Args:
            in_feats, hid_feats, out_feats: node feature sizes per layer.
            num_heads: attention heads in the first layer (default 2); their outputs
                are concatenated before the single-head second layer.

        Returns a sigmoid score per graph, squeezed.
    '''
    def __init__(self, in_feats, hid_feats, out_feats, num_heads=2):
        super().__init__()
        # (nodes, in_feats) -> (nodes, num_heads, hid_feats)
        self.gatconv1 = dglnn.GATConv(in_feats, hid_feats, num_heads=num_heads)
        # (nodes, num_heads * hid_feats) -> (nodes, 1, out_feats)
        self.gatconv2 = dglnn.GATConv(hid_feats * num_heads, out_feats, num_heads=1)

    def forward(self, graph, inputs):
        # inputs: (nodes, in_feats)
        h = self.gatconv1(graph, inputs)
        # Concatenate the attention heads: (nodes, num_heads, hid) -> (nodes, num_heads * hid)
        h = h.flatten(1)
        h = self.gatconv2(graph, h)

        # local_scope keeps the temporary readout feature off the caller's graph.
        with graph.local_scope():
            graph.ndata['tmp_feature'] = h
            h = dgl.mean_nodes(graph, 'tmp_feature')
        return torch.squeeze(torch.sigmoid(h))

class AGNNConv(nn.Module):
    ''' Attention-based Graph Neural Network layer (dgl.nn.AGNNConv) with a sigmoid output.
        https://docs.dgl.ai/api/python/nn.pytorch.html#agnnconv

        This wrapper shares its name with the DGL layer, so the library class is
        referenced explicitly as dglnn.AGNNConv. Calling the bare name here would
        construct this wrapper again and recurse forever.
    '''
    def __init__(self):
        super().__init__()
        # Built once so the layer's parameters are registered with this module and trained.
        self.conv = dglnn.AGNNConv()

    def forward(self, graph, inputs):
        # inputs are features of nodes
        h = self.conv(graph, inputs)
        return torch.sigmoid(h)

class GCN(nn.Module):
    ''' Graph Convolutional Network for graph classification.
        https://docs.dgl.ai/en/0.4.x/tutorials/basics/4_batch.html#graph-classification-tutorial
        https://docs.dgl.ai/en/0.4.x/tutorials/models/1_gnn/1_gcn.html

        Three GraphConv -> BatchNorm -> ReLU layers, then the node representations of
        each graph are averaged. The output is 2 * sigmoid(mean) - 1 per graph. Because
        the last layer ends in a ReLU the mean is >= 0, so the output lies in [0, 1).
    '''
    def __init__(self, in_dim, hidden_dim, out_feats):
        super().__init__()
        self.conv1 = GraphConv(in_dim, hidden_dim)
        self.conv2 = GraphConv(hidden_dim, hidden_dim)
        self.conv3 = GraphConv(hidden_dim, out_feats)
        self.bn1 = nn.BatchNorm1d(num_features=hidden_dim)
        self.bn2 = nn.BatchNorm1d(num_features=hidden_dim)
        self.bn3 = nn.BatchNorm1d(num_features=out_feats)

    def forward(self, graph, h):
        # h: node features. Graph convolution + batch norm + activation per layer.
        h = F.relu(self.bn1(self.conv1(graph, h)))
        h = F.relu(self.bn2(self.conv2(graph, h)))
        h = F.relu(self.bn3(self.conv3(graph, h)))
        # Graph representation = average of its node representations. local_scope
        # keeps the temporary feature off the caller's graph.
        with graph.local_scope():
            graph.ndata['tmp_feature'] = h
            h = dgl.mean_nodes(graph, 'tmp_feature')
        return (torch.sigmoid(h) - 0.5) * 2

class MLPPredictor(nn.Module):
    ''' Edge scorer: a linear layer over the concatenated (source, destination) node representations. '''
    def __init__(self, in_features, out_classes):
        super().__init__()
        # Input is [h_src, h_dst], hence twice the node feature size.
        self.W = nn.Linear(2 * in_features, out_classes)

    def apply_edges(self, edges):
        ''' Per-edge function for graph.apply_edges: returns {'score': W([h_src, h_dst])}. '''
        pair = torch.cat([edges.src['h'], edges.dst['h']], dim=1)
        return {'score': self.W(pair)}

    def forward(self, graph, h):
        ''' Score every edge of `graph` from node representations `h` (e.g. a GNN's output).

            local_scope keeps the temporary 'h' / 'score' features off the caller's graph.
        '''
        with graph.local_scope():
            graph.ndata['h'] = h
            graph.apply_edges(self.apply_edges)
            scores = graph.edata['score']
        return scores

class GNNmodel(nn.Module):
    ''' Wrap one of the GNN architectures defined above, selected by name.

        args.module (case-insensitive) picks the architecture: 'sage', 'gat',
        'agnnconv', 'sgc' or 'gcn'. forward() squeezes the wrapped model's output.

        Raises:
            ValueError: if args.module names an unsupported architecture.
    '''
    def __init__(self, args, in_features, hidden_features, out_features):
        super().__init__()

        module_name = args.module.lower()
        print(datetime.now().strftime('%Y-%m-%d %H:%M:%S'), '\033[92m[INFO]\033[0m', '\033[92mGNN module\033[0m'.rjust(40, ' '), module_name)

        # Lambdas defer construction so that only the selected architecture is built.
        builders = {
            'sage': lambda: SAGE(in_features, hidden_features, out_features),
            'gat': lambda: GAT(in_features, hidden_features, out_features),
            # AGNNConv() takes no size arguments, so the feature sizes are unused for it.
            'agnnconv': lambda: AGNNConv(),
            'sgc': lambda: SGC(in_features, hidden_features, out_features),
            'gcn': lambda: GCN(in_features, hidden_features, out_features),
        }
        if module_name not in builders:
            # Explicit raise: an `assert` would be stripped under `python -O`.
            raise ValueError(f'{args.module} is not supported; choose one of {sorted(builders)}')
        self.module = builders[module_name]()

    def forward(self, g, x):
        return torch.squeeze(self.module(g, x))


In [None]:
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc
from sklearn.metrics import precision_recall_curve
import numpy as np

def cal_acc(test_target, test_yhat, train_target, train_yhat):
    ''' Accuracy of thresholded (> 0.5) predictions against thresholded targets.

        Targets are torch tensors (converted with .numpy()); predictions are numpy arrays.
        Returns (test_acc, train_acc).
    '''
    def _accuracy(target, yhat):
        agree = np.equal(target.numpy() > 0.5, yhat > 0.5)
        return np.count_nonzero(agree) / len(agree)

    train_acc = _accuracy(train_target, train_yhat)
    test_acc = _accuracy(test_target, test_yhat)
    return test_acc, train_acc

def plot_ROC(test_y, test_predict, train_y=None, train_predict=None, val_y=None, val_predict=None, show=True):
    ''' Compute ROC AUCs and optionally plot background rejection vs signal efficiency.

        Train/val curves are included only when both their labels and predictions
        are given. Returns the test AUC, or (test AUC, train AUC) when training data
        is given. The validation AUC is only shown in the plot legend.
    '''
    def _roc(labels, scores):
        fpr, tpr, _ = roc_curve(labels, scores)
        return fpr, tpr, auc(fpr, tpr)

    has_train = train_y is not None and train_predict is not None
    has_val = val_y is not None and val_predict is not None

    if has_train:
        train_fpr, train_tpr, train_auc = _roc(train_y, train_predict)
    if has_val:
        val_fpr, val_tpr, val_auc = _roc(val_y, val_predict)
    test_fpr, test_tpr, test_auc = _roc(test_y, test_predict)

    if show:
        if has_train:
            plt.plot(train_tpr, 1-train_fpr, 'g--', label='Train AUC = %0.3f'% train_auc)
        if has_val:
            plt.plot(val_tpr, 1-val_fpr, 'b--', label='Val AUC = %0.3f'% val_auc)
        plt.plot(test_tpr, 1-test_fpr, 'r', label='Test AUC = %0.3f'% test_auc)

        plt.legend(loc='lower left')
        plt.plot([0,1],[1,0],'k--')  # random-classifier reference line
        plt.xlim([-0.1,1.1])
        plt.ylim([-0.1,1.1])
        plt.ylabel('Background rejection')
        plt.xlabel('Signal efficiency')
        plt.show()

    return (test_auc, train_auc) if has_train else test_auc

def plotPRC(test_y, test_predict, train_y = None, train_predict = None, show = True):
    ''' Compute precision-recall curve areas and optionally plot the curves.

        Args:
            test_y, test_predict: test labels and scores.
            train_y, train_predict: optional training labels and scores (both required to be used).
            show: draw and show the plot.

        Returns:
            the test area, or (test area, train area) when training data is given.
    '''
    has_train = train_y is not None and train_predict is not None
    if has_train:
        train__precision, train__recall, _ = precision_recall_curve(train_y, train_predict)
        train__prec_auc = auc(train__recall, train__precision)
    test__precision, test__recall, _ = precision_recall_curve(test_y, test_predict)
    test__prec_auc = auc(test__recall, test__precision)

    if show:
        plt.title('Precision-Recall Curves')
        # Recall on x and precision on y, matching the axis labels below.
        if has_train:
            plt.plot(train__recall, train__precision, 'g--', label='Train PRC = %0.3f'% train__prec_auc)
        plt.plot(test__recall, test__precision, 'b', label='Test PRC = %0.3f'% test__prec_auc)

        plt.legend(loc='lower right')
        plt.xlim([-0.,1.1])
        plt.ylim([-0.,1.1])
        plt.ylabel('Precision')
        plt.xlabel('Recall')
        plt.show()

    if has_train:
        return test__prec_auc, train__prec_auc
    return test__prec_auc

def plot_response(test_y, test_predict, train_y = None, train_predict = None, val_y = None, val_predict = None, nbin = 10, normalised = True, log = False, scaling = ''):
    ''' Histogram the classifier response for signal (label 1) and background (label 0).

        Test samples are always drawn; train/val samples only when both their labels
        and predictions are given. Inputs are expected to be numpy arrays, since they
        are masked with `y == 1`.
        With scaling='minmax', every sample is rescaled with the *test* min/max so
        that all histograms share one transform. Scaling is skipped if that range is zero.
    '''
    xlo, xhi = 0, 1
    hist_kwargs = dict(range=[xlo, xhi], bins=nbin, histtype="step", density=normalised)

    hi = lo = None
    if scaling == 'minmax':
        hi, lo = max(test_predict), min(test_predict)

    def _rescale(pred):
        if scaling == 'minmax' and hi != lo:
            return (pred - lo) / (hi - lo)
        return pred

    test_predict = _rescale(test_predict)
    plt.hist(test_predict[test_y == 1], linewidth=2, label='Test signal', **hist_kwargs)
    plt.hist(test_predict[test_y == 0], linewidth=2, label='Test bkg', **hist_kwargs)
    if train_y is not None and train_predict is not None:
        train_predict = _rescale(train_predict)
        plt.hist(train_predict[train_y == 1], linewidth=2, linestyle='dashed', label='Training signal', **hist_kwargs)
        plt.hist(train_predict[train_y == 0], linewidth=2, linestyle='dashed', label='Training bkg', **hist_kwargs)
    if val_y is not None and val_predict is not None:
        val_predict = _rescale(val_predict)
        # Validation predictions must be masked with the validation labels.
        plt.hist(val_predict[val_y == 1], linewidth=1, linestyle='dashed', label='Val signal', **hist_kwargs)
        plt.hist(val_predict[val_y == 0], linewidth=1, linestyle='dashed', label='Val bkg', **hist_kwargs)
    plt.ylim(0, plt.gca().get_ylim()[1] * 1.5)  # headroom for the legend
    if log:
        plt.yscale('symlog')
        plt.ylabel('[Log scale]')
    plt.legend(loc='best', fancybox=True, framealpha=0.2)
    plt.xlabel('Response', fontsize='large')
    plt.title('Normalised' if normalised else 'Absolute')

    plt.show()

def plot_loss(train_loss, valid_loss, starting=0):
    ''' Plot training and validation loss per epoch, skipping the first `starting` epochs. '''
    for history in (train_loss, valid_loss):
        plt.plot(history[starting:])
    plt.title('Model loss')
    plt.ylabel('Loss')
    plt.xlabel('Epoch')
    plt.legend(['Train', 'Val'], loc='upper right')
    plt.show()


In [None]:
import numpy as np
import torch
import os

class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""
    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement. 
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): Path for the checkpoint to be saved to. Missing parent
                            directories are created.
                            Default: 'checkpoint.pt'
            trace_func (function): trace print function.
                            Default: print            
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # np.inf: the np.Inf alias was removed in NumPy 2.0.
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = path
        # A bare filename (such as the default) has no directory part, and
        # os.makedirs('') would raise FileNotFoundError.
        checkpoint_dir = os.path.dirname(self.path)
        if checkpoint_dir:
            os.makedirs(checkpoint_dir, exist_ok=True)
        self.trace_func = trace_func

    def __call__(self, val_loss, model):
        """Record one validation loss: checkpoint on improvement, else count towards patience."""
        score = -val_loss  # higher score = lower loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            # Not better by at least `delta`.
            self.counter += 1
            if self.verbose:
                self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

# https://github.com/sksq96/pytorch-summary/blob/master/torchsummary/torchsummary.py

import torch
import torch.nn as nn
from torch.autograd import Variable

from collections import OrderedDict
import numpy as np


def summary(model, input_size, batch_size=-1, dtypes=None):
    ''' Print the per-layer summary table of `model` and return (total_params, trainable_params). '''
    report, param_counts = summary_string(model, input_size, batch_size, dtypes)
    print(report)
    return param_counts


def summary_string(model, input_size, batch_size=-1, dtypes=None):
    ''' Build a Keras-style per-layer summary of a PyTorch model (adapted from torchsummary).

        Runs one forward pass on random inputs with batch size 2 (so BatchNorm
        layers work). Forward hooks are attached to every module except
        nn.Sequential / nn.ModuleList containers, and removed afterwards even if
        the forward pass fails.

        Args:
            model: nn.Module, called as model(*inputs).
            input_size: one input shape without the batch dimension (tuple), or a
                list of such tuples for multi-input models.
            batch_size: value written into the batch dimension of reported shapes.
            dtypes: tensor type per input; defaults to torch.FloatTensor for each.

        Returns:
            (summary text, (total_params, trainable_params)).
    '''
    if dtypes is None:
        dtypes = [torch.FloatTensor]*len(input_size)

    summary_str = ''

    def register_hook(module):
        def hook(module, input, output):
            class_name = str(module.__class__).split(".")[-1].split("'")[0]
            module_idx = len(layers)

            m_key = "%s-%i" % (class_name, module_idx + 1)
            layers[m_key] = OrderedDict()
            layers[m_key]["input_shape"] = list(input[0].size())
            layers[m_key]["input_shape"][0] = batch_size
            if isinstance(output, (list, tuple)):
                layers[m_key]["output_shape"] = [
                    [-1] + list(o.size())[1:] for o in output
                ]
            else:
                layers[m_key]["output_shape"] = list(output.size())
                layers[m_key]["output_shape"][0] = batch_size

            params = 0
            if hasattr(module, "weight") and hasattr(module.weight, "size"):
                params += torch.prod(torch.LongTensor(list(module.weight.size())))
                layers[m_key]["trainable"] = module.weight.requires_grad
            if hasattr(module, "bias") and hasattr(module.bias, "size"):
                params += torch.prod(torch.LongTensor(list(module.bias.size())))
            layers[m_key]["nb_params"] = params

        # Containers only forward to their children, which get their own hooks.
        if not isinstance(module, (nn.Sequential, nn.ModuleList)):
            hooks.append(module.register_forward_hook(hook))

    # A single input shape is wrapped into a list of shapes.
    if isinstance(input_size, tuple):
        input_size = [input_size]

    # batch_size of 2 for batchnorm
    x = [torch.rand(2, *in_size).type(dtype)
         for in_size, dtype in zip(input_size, dtypes)]

    # layer key -> {input_shape, output_shape, [trainable], nb_params}
    layers = OrderedDict()
    hooks = []

    model.apply(register_hook)
    try:
        model(*x)
    finally:
        # Detach the hooks even if the forward pass raises, so the model is left unmodified.
        for h in hooks:
            h.remove()

    summary_str += "----------------------------------------------------------------" + "\n"
    line_new = "{:>20}  {:>25} {:>15}".format(
        "Layer (type)", "Output Shape", "Param #")
    summary_str += line_new + "\n"
    summary_str += "================================================================" + "\n"
    total_params = 0
    total_output = 0
    trainable_params = 0
    for layer, info in layers.items():
        line_new = "{:>20}  {:>25} {:>15}".format(
            layer,
            str(info["output_shape"]),
            "{0:,}".format(info["nb_params"]),
        )
        total_params += info["nb_params"]

        output_shape = info["output_shape"]
        if output_shape and isinstance(output_shape[0], list):
            # Multi-output layer: count each output's elements separately
            # (np.prod over a ragged list of shapes would fail).
            total_output += sum(np.prod(shape) for shape in output_shape)
        else:
            total_output += np.prod(output_shape)
        if info.get("trainable"):
            trainable_params += info["nb_params"]
        summary_str += line_new + "\n"

    # assume 4 bytes/number (float on cuda).
    total_input_size = abs(np.prod(sum(input_size, ()))
                           * batch_size * 4. / (1024 ** 2.))
    total_output_size = abs(2. * total_output * 4. /
                            (1024 ** 2.))  # x2 for gradients
    total_params_size = abs(total_params * 4. / (1024 ** 2.))
    total_size = total_params_size + total_output_size + total_input_size

    summary_str += "================================================================" + "\n"
    summary_str += "Total params: {0:,}".format(total_params) + "\n"
    summary_str += "Trainable params: {0:,}".format(trainable_params) + "\n"
    summary_str += "Non-trainable params: {0:,}".format(total_params -
                                                        trainable_params) + "\n"
    summary_str += "----------------------------------------------------------------" + "\n"
    summary_str += "Input size (MB): %0.2f" % total_input_size + "\n"
    summary_str += "Forward/backward pass size (MB): %0.2f" % total_output_size + "\n"
    summary_str += "Params size (MB): %0.2f" % total_params_size + "\n"
    summary_str += "Estimated Total Size (MB): %0.2f" % total_size + "\n"
    summary_str += "----------------------------------------------------------------" + "\n"
    return summary_str, (total_params, trainable_params)


## Do the Training

In [None]:
from datetime import datetime
from argparse import ArgumentParser
import random
import numpy as np
# Print floats in fixed-point rather than scientific notation when inspecting arrays.
np.set_printoptions(suppress=True)
from os import makedirs, path, remove
from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt

import dgl
from dgl.data.utils import save_graphs, load_graphs
import torch.nn.functional as F

from createGraph import *
from createGNN import *
from pytorchtools import *
from plottools import *

from matplotlib import rcParams
# Global plot style: larger bold fonts, thicker axes, transparent figure background.
size = 15
rcParams.update({
    'font.size': size,
    'font.weight': 'bold',
    'axes.labelweight': 'bold',
    'axes.titleweight': 'bold',
    'axes.linewidth': 2,
    'figure.facecolor': (1, 1, 1, 0),
})
# Autograd anomaly detection reports the forward op behind a failing backward pass.
# NOTE(review): PyTorch documents this mode as a debugging aid that slows execution;
# consider disabling it for full training runs.
torch.autograd.set_detect_anomaly(True)

''' How to batch graphs '''
# https://docs.dgl.ai/en/0.4.x/tutorials/basics/4_batch.html


def _log(message, *values):
    """Print a timestamped, green-highlighted ``[INFO]`` line.

    ``message`` is wrapped in ANSI green codes and right-justified to 40
    characters; ``values`` follow it, space-separated (plain ``print``).
    """
    print(datetime.now().strftime('%Y-%m-%d %H:%M:%S'), '\033[92m[INFO]\033[0m',
          f'\033[92m{message}\033[0m'.rjust(40, ' '), *values)


def _target_like(labels, pred):
    """Return ``labels`` as float, reshaped to ``pred``'s shape when sizes agree.

    ``nn.MSELoss`` broadcasts mismatched shapes: a ``(B, 1)`` prediction
    against ``(B,)`` labels silently produces a ``(B, B)`` error matrix.
    When the element counts differ the labels are only cast to float.
    """
    target = labels.float()
    if target.numel() == pred.numel():
        target = target.reshape(pred.shape)
    return target


def _predict(model, dataloader):
    """Run ``model`` in eval mode, without gradients, over all of ``dataloader``.

    Returns:
        (pred, labels): tensors holding the model outputs and the labels of
        every batch, concatenated along the first dimension.
    """
    model.eval()
    preds, all_labels = [], []
    with torch.no_grad():
        for batched_graph, labels in dataloader:
            node_features = batched_graph.ndata['feature']
            preds.append(model(batched_graph, node_features.float()).detach())
            all_labels.append(labels)
    return torch.cat(preds), torch.cat(all_labels)


def train_q_g(args):
    """Train (or reload) the quark/gluon GNN and plot its performance.

    Steps:
      1. Load cached DGL graphs from ``../data/QG_graph`` or, if absent or
         ``args.overwrite`` is set, build them from
         ``../data/QG_orig/QG_jets.npz`` and cache them.
      2. Randomly split the graphs 60/20/20 into train/val/test.
      3. Unless ``args.no_training``, train ``GNNmodel`` with SGD on an MSE
         loss, checkpointing through ``EarlyStopping`` on the validation
         loss, then save the restored checkpoint model to
         ``../data/QG_model/QG_jets.model``. Otherwise load that saved model.
      4. Plot the loss curves (only when trained), ROC curves and responses.

    Args:
        args: parsed command-line namespace. This function reads
            ``connection``, ``nevent``, ``overwrite``, ``epochs`` and
            ``no_training``, and passes ``args`` on to ``GNNmodel``.
    """
    input_file = '../data/QG_orig/QG_jets.npz'
    graph_file = f'../data/QG_graph/QG_jets__{args.connection}__{args.nevent if args.nevent else "All"}.bin'
    gmodel_name = '../data/QG_model/QG_jets.model'
    checkpoint_name = gmodel_name + '.checkpoint'

    makedirs(path.dirname(gmodel_name), exist_ok=True)

    if not args.overwrite and path.exists(graph_file):
        # Reuse the graphs (and their label dict) cached by an earlier run.
        graphs, y = load_graphs(graph_file)
        _log('Load training graphs from', graph_file, len(graphs))
        if args.nevent:
            graphs = graphs[:args.nevent]
    else:
        # Build one graph per jet, then cache graphs + labels to disk.
        input_data = np.load(input_file)
        X, y = input_data['X'], {'label': torch.tensor(input_data['y']).long()}
        generator = GenerateGraphs(X, args.connection)
        graphs = generator.create_graphs(stop=args.nevent)
        _log('Create graphs from', input_file, len(graphs))
        makedirs(path.dirname(graph_file), exist_ok=True)
        save_graphs(graph_file, graphs, y)
        _log('Save graphs to', graph_file, len(graphs))

    _log('Below the 1st graphs')
    plot_graph(graphs[0])

    # 60/20/20 random split. zip() truncates the labels to len(graphs), which
    # keeps the split sizes consistent when only args.nevent graphs are kept.
    train_size = int(len(graphs) * 0.6)
    val_size = int(len(graphs) * 0.2)
    test_size = len(graphs) - train_size - val_size

    dataset = list(zip(graphs, y['label']))
    train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(dataset, (train_size, val_size, test_size))

    # Each split is served as one full-size batch.
    # NOTE(review): the --bs command-line option is not used here.
    train_dataloader = dgl.dataloading.GraphDataLoader(train_dataset, batch_size=train_size, drop_last=False, shuffle=True)
    val_dataloader = dgl.dataloading.GraphDataLoader(val_dataset, batch_size=val_size, drop_last=False, shuffle=False)
    test_dataloader = dgl.dataloading.GraphDataLoader(test_dataset, batch_size=test_size, drop_last=False, shuffle=False)

    _log('Split train/val/test to', f'{train_size}/{val_size}/{test_size}')
    loss_func = nn.MSELoss()  # nn.CrossEntropyLoss()

    if not args.no_training:
        hidden_features, out_features = 9, 1  # 2 = classifier predict
        in_features = graphs[0].ndata[list(graphs[0].ndata.keys())[0]].shape[1]
        model = GNNmodel(args, in_features, hidden_features, out_features)
        _log('Node feature dim changing', in_features, hidden_features, out_features)
        opt = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        epoch_losses = {'train': [], 'val': [], 'test': []}

        # Create the stopper once, before the loop. Re-creating it every epoch
        # reset its patience counter, so training could never stop early.
        early_stopping = EarlyStopping(patience=10, verbose=False, path=checkpoint_name)

        for epoch in range(args.epochs):
            model.train()
            train_loss, n_batches = 0.0, 0
            for batched_graph, labels in train_dataloader:
                node_features = batched_graph.ndata['feature']
                pred = model(batched_graph, node_features.float())
                loss = loss_func(pred, _target_like(labels, pred))
                opt.zero_grad()
                loss.backward()
                opt.step()
                train_loss += loss.detach().item()
                n_batches += 1

            # Epoch loss = mean of the per-batch losses; guard an empty loader.
            train_loss /= max(n_batches, 1)
            epoch_losses['train'].append(train_loss)

            # Validation loss over the whole validation split, no gradients.
            val_pred, val_lbl = _predict(model, val_dataloader)
            val_loss = loss_func(val_pred, _target_like(val_lbl, val_pred)).item()
            epoch_losses['val'].append(val_loss)

            # Log the first 20 epochs, every 100th epoch of long runs, and the last.
            if epoch < 20 or (args.epochs > 100 and epoch % 100 == 0) or epoch == args.epochs - 1:
                _log('epoch, loss, val_loss:', epoch + 1, '|', train_loss, '|', val_loss)

            # early_stopping checks whether the validation loss is decreasing
            # and, if so, checkpoints the current model to checkpoint_name.
            early_stopping(val_loss, model)
            if early_stopping.early_stop:
                _log('epoch, loss, val_loss:', epoch + 1, '|', train_loss, '|', val_loss)
                _log('Early stop at:', epoch)
                break

        # Restore the weights saved in the last checkpoint.
        model.load_state_dict(torch.load(checkpoint_name))

        if path.exists(gmodel_name):
            _log('Remove model:', gmodel_name)
            remove(gmodel_name)

        _log('Save model:', gmodel_name)
        torch.save(model, gmodel_name)

    else:  # args.no_training
        _log('Skip training and load', gmodel_name)
        # NOTE(review): this unpickles a whole model object, so only load
        # trusted files. Newer torch releases default torch.load to
        # weights_only=True, which may reject it; confirm the installed version.
        model = torch.load(gmodel_name)

    # Final predictions. Each split uses its OWN loader; the validation
    # predictions are no longer taken from the training loader.
    train_pred, train_labels = _predict(model, train_dataloader)
    val_pred, val_labels = _predict(model, val_dataloader)
    test_pred, test_labels = _predict(model, test_dataloader)

    if not args.no_training:
        epoch_losses['test'].append(loss_func(test_pred, _target_like(test_labels, test_pred)).item())
        plot_loss(epoch_losses['train'], epoch_losses['val'], starting=0)

    train_pred, train_labels = train_pred.numpy(), train_labels.numpy()
    val_pred, val_labels = val_pred.numpy(), val_labels.numpy()
    test_pred, test_labels = test_pred.numpy(), test_labels.numpy()

    plot_ROC(test_labels, test_pred, train_labels, train_pred, val_labels, val_pred)
    plot_response(test_labels, test_pred, train_labels, train_pred, val_labels, val_pred, scaling='')

    return


def _str2bool(value):
    """Convert command-line text such as ``True``/``false``/``1``/``no`` to a bool.

    ``argparse`` with ``type=bool`` treats every non-empty string, including
    ``"False"``, as ``True``. This converter interprets the text instead.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    from argparse import ArgumentTypeError

    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0'):
        return False
    raise ArgumentTypeError(f'Boolean value expected, got {value!r}.')


def _build_arg_parser():
    """Build the command-line parser for the GNN training script."""
    parser = ArgumentParser(description='\033[92mGNN training.\033[0m')

    parser.add_argument('--bs', type=int, default=10, help='Batch size (default: %(default)s).')
    # nargs='?' + const=True: a bare "--overwrite" means True, and the older
    # "--overwrite True/False" spelling keeps working (now parsed correctly).
    parser.add_argument('--overwrite', type=_str2bool, nargs='?', const=True, default=False, help='Overwrite if any to cached file (default: %(default)s).')
    parser.add_argument('--epochs', type=int, default=200, help='Epochs (default: %(default)s).')
    parser.add_argument('--module', type=str, choices=['sage', 'gat', 'agnnconv', 'sgc'], default='sgc', help='Network module')
    parser.add_argument('--nevent', type=int, default=None, help='Number of graphs from events to create (default: %(default)s).')
    parser.add_argument('--connection', type=str, default='biadj_pt_y_phi', choices=['bifully', 'biadj_pt_y_phi', 'biadj_pt_y', 'biadj_pt_phi', 'biadj_y_phi'], help='Type of edge connections (default: %(default)s).')
    # train_q_g reads args.no_training, so the option must be declared here.
    parser.add_argument('--no_training', action='store_true', help='Skip training and load the saved model (default: %(default)s).')
    return parser


if __name__ == '__main__':

    """Get arguments from command line."""
    args = _build_arg_parser().parse_args()

    # train_q_g is the training entry point defined above (job1 does not exist).
    train_q_g(args)
