This file contains GNN based solution for specific task of project [Graph Neural Networks for Particle Momentum Estimation in the CMS Trigger System](https://drive.google.com/file/d/13gQToLhaoKGM7hXJY2sxVaVFqvS0Z9X9/view).

GNN layers been used:

- [Graph Convolution Layer](https://arxiv.org/pdf/1609.02907.pdf)

- [PointNet Convolution Layer](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.PointNetConv.html#torch_geometric.nn.conv.PointNetConv)

- [GraphGPS](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.GPSConv.html#torch_geometric.nn.conv.GPSConv)

In all models, architecture is composed of two layers.
Latent embedding dimension is set to 300.

Node features:
- Channel values
- Global Positional encoding (2D coordinates of the nodes) - optional
- Additionally random walk embeddings can be added in preprocessing step - optional 
  (RW embeddings are used in GraphGPS based model)

Edge features:
- Euclidean distance between nodes.

All models are trained for 75 epochs.

In [1]:
import os
from tqdm import tqdm
import torch
from torch_geometric.nn import MessagePassing,GPSConv, GINEConv
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set
import torch.nn.functional as F
import torch_geometric.transforms as T
from torch_geometric.utils import degree
from torch.utils.data import random_split
from torch_geometric.loader import DataLoader
import torch.optim as optim
from torchmetrics.classification import MulticlassAUROC, MulticlassAccuracy

from dataset import QGJetsGraphsDataset
from torch_geometric.nn.conv import GATConv,PointNetConv

In [2]:
### GIN convolution along the graph structure
class GINConv(MessagePassing):
    def __init__(self, emb_dim,input_node_dim,input_edge_dim):

        super(GINConv, self).__init__(aggr = "add")

        self.mlp = torch.nn.Sequential(torch.nn.Linear(emb_dim, 2*emb_dim), torch.nn.BatchNorm1d(2*emb_dim), 
                                       torch.nn.ReLU(), torch.nn.Linear(2*emb_dim, emb_dim))
        self.eps = torch.nn.Parameter(torch.Tensor([0]))
        self.linear = torch.nn.Linear(input_node_dim, emb_dim)
        self.edge_encoder = torch.nn.Linear(input_edge_dim, emb_dim)

    def forward(self, x, edge_index, edge_attr):
        x = self.linear(x)
        edge_embedding = self.edge_encoder(edge_attr)
        out = self.mlp((1 + self.eps) *x + self.propagate(edge_index, x=x, edge_attr=edge_embedding))

        return out

    def message(self, x_j, edge_attr):
        return F.relu(x_j + edge_attr)

    def update(self, aggr_out):
        return aggr_out

### GCN convolution along the graph structure
class GCNConv(MessagePassing):
    def __init__(self, emb_dim,input_node_dim,input_edge_dim):
        super(GCNConv, self).__init__(aggr='add')
        
        self.linear = torch.nn.Linear(input_node_dim, emb_dim)
        self.root_emb = torch.nn.Embedding(1, emb_dim)
        self.edge_encoder = torch.nn.Linear(input_edge_dim, emb_dim)

    def forward(self, x, edge_index, edge_attr):
        x = self.linear(x)
        edge_embedding = self.edge_encoder(edge_attr)

        row, col = edge_index

        #edge_weight = torch.ones((edge_index.size(1), ), device=edge_index.device)
        deg = degree(row, x.size(0), dtype = x.dtype) + 1
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0

        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        return self.propagate(edge_index, x=x, edge_attr = edge_embedding, norm=norm) + F.relu(x + self.root_emb.weight) * 1./deg.view(-1,1)

    def message(self, x_j, edge_attr, norm):
        return norm.view(-1, 1) * F.relu(x_j + edge_attr)

    def update(self, aggr_out):
        return aggr_out

In [3]:
class mlp(torch.nn.Module):
    def __init__(self,input_node_dim,emb_dim):
        super(mlp, self).__init__()
        self.mlp = torch.nn.Sequential(torch.nn.Linear(input_node_dim, 2*emb_dim), 
                                       torch.nn.BatchNorm1d(2*emb_dim), 
                                       torch.nn.ReLU(), 
                                       torch.nn.Linear(2*emb_dim, emb_dim))
    def forward(self,x):
        return self.mlp(x)

In [4]:
class MessagePasssing_Module(torch.nn.Module):
    """
    MessagePasssing_Module contains 2 or more GNN layers stacked.
    Output:
        node representations
    """
    def __init__(self, num_layer, input_node_dim, input_edge_dim, emb_dim,extraPE_dim=None,
                 extraPE_method='sum', drop_ratio = 0.5, JK = "last", residual = False, gnn_type = 'gin'):
        '''
            emb_dim (int): node embedding dimensionality
            num_layer (int): number of GNN message passing layers
        '''
        super(MessagePasssing_Module, self).__init__()
        
        self.gnn_type = gnn_type
        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK
        self.input_node_dim = input_node_dim
        self.input_edge_dim = input_edge_dim
        ### add residual connection or not
        self.residual = residual
        self.extraPE_dim = extraPE_dim
        self.extraPE_method = extraPE_method

        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        ### List of GNNs
        self.convs = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()
        
        if extraPE_dim:
            # this is to transform random walk encoding
            self.extraPE_Encoder = torch.nn.Linear(self.extraPE_dim,self.input_node_dim)
        
        if extraPE_dim and extraPE_method=='cat':
            self.input_node_dim *= 2
            
        for layer in range(num_layer):
            if layer == 0:
                if gnn_type == 'gin':
                    self.convs.append(GINConv(emb_dim,input_node_dim=self.input_node_dim,input_edge_dim=self.input_edge_dim))
                elif gnn_type == 'gcn':
                    self.convs.append(GCNConv(emb_dim,input_node_dim=self.input_node_dim,input_edge_dim=self.input_edge_dim))
                elif gnn_type == 'gat':
                    self.convs.append(GATConv(in_channels=self.input_node_dim,out_channels=emb_dim,edge_dim=self.input_edge_dim))
                elif gnn_type == "pointnet":
                    local_mlp = mlp(self.input_node_dim+2, emb_dim)
                    global_mlp = None
                    self.convs.append(PointNetConv(local_mlp,global_mlp))
                    
                elif gnn_type == "gps":
                    # we need to explicitly declare this mlp
                    nn = torch.nn.Sequential(torch.nn.Linear(self.input_node_dim, 2*emb_dim),
                                           torch.nn.BatchNorm1d(2*emb_dim), 
                                           torch.nn.ReLU(), 
                                           torch.nn.Linear(2*emb_dim, emb_dim))
                    self.convs.append(GPSConv(self.input_node_dim, GINEConv(nn,edge_dim =self.input_edge_dim), 
                                              heads=5, attn_dropout=0.3))
                else:
                    raise ValueError('Undefined GNN type called {}'.format(gnn_type))
                
            else:
                if gnn_type == 'gin':
                    self.convs.append(GINConv(emb_dim,input_node_dim=emb_dim,input_edge_dim=self.input_edge_dim))
                elif gnn_type == 'gcn':
                    self.convs.append(GCNConv(emb_dim,input_node_dim=emb_dim,input_edge_dim=self.input_edge_dim))
                elif gnn_type == 'gat':
                    self.convs.append(GATConv(in_channels=emb_dim,out_channels=emb_dim,edge_dim=self.input_edge_dim))
                elif gnn_type == "pointnet":
                    local_mlp = mlp(emb_dim+2, emb_dim)
                    global_mlp = None
                    self.convs.append(PointNetConv(local_mlp,global_mlp))
                    
                elif gnn_type == "gps":
                    # we need to explicitly declare this mlp
                    nn = torch.nn.Sequential(torch.nn.Linear(emb_dim, 2*emb_dim), 
                                       torch.nn.BatchNorm1d(2*emb_dim), 
                                       torch.nn.ReLU(), 
                                       torch.nn.Linear(2*emb_dim, emb_dim))
                    self.convs.append(GPSConv(emb_dim, GINEConv(nn,edge_dim =self.input_edge_dim),
                                              heads=5, attn_dropout=0.3))
                else:
                    raise ValueError('Undefined GNN type called {}'.format(gnn_type))

            self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))

    def forward(self, batched_data):
        
        x = batched_data.x
        edge_index = batched_data.edge_index
        edge_attr = batched_data.edge_attr
        pos = batched_data.pos
        batch = batched_data.batch
        
        if self.extraPE_dim:
            extraPE = batched_data.extraPE
            extraPE_emb = self.extraPE_Encoder(extraPE)
            if self.extraPE_method == 'sum':
                h_list = [x+extraPE_emb]
            elif self.extraPE_method == 'cat':
                h_list = [torch.cat((x,extraPE_emb),1)]      
        else:    
            h_list = [x]  
            
        for layer in range(self.num_layer):
            if self.gnn_type == 'pointnet':
                h = self.convs[layer](h_list[layer], pos, edge_index)
            elif self.gnn_type == 'gps':
                h = self.convs[layer](h_list[layer], edge_index,batch=batch,edge_attr=edge_attr)
            else:    
                h = self.convs[layer](h_list[layer], edge_index, edge_attr)
                
            h = self.batch_norms[layer](h)

            if layer == self.num_layer - 1:
                #remove relu for the last layer
                h = F.dropout(h, self.drop_ratio, training = self.training)
            else:
                h = F.dropout(F.relu(h), self.drop_ratio, training = self.training)

            if self.residual:
                h += h_list[layer]

            h_list.append(h)

        ### Different implementations of Jk-concat
        if self.JK == "last":
            node_representation = h_list[-1]
        elif self.JK == "sum":
            node_representation = 0
            # we don't sum the input features 
            # we only sum outputs of each layer
            for layer in range(1,self.num_layer + 1):
                node_representation += h_list[layer]

        return node_representation


In [5]:
class GNN(torch.nn.Module):

    def __init__(self, num_classes=2, num_layer = 5,num_pre_fnn_layers =0,num_post_fnn_layers =1,hasPos =True,num_coords=3, 
                 input_spec_fts_dim=3,input_edge_dim = 1, emb_dim = 300,extraPE_dim=None,extraPE_method = 'sum', gnn_type = 'gcn', 
                 residual = False, drop_ratio = 0.5, JK = "last", graph_pooling = "mean"):
        '''
            hasPos (bool) : whether input node features should contain global positioning embeded
                            ps: global positioning is the coordinate of the pixel on 2D grid.
            input_spec_fts_dim (int) : denotes number of specific features (features apart from postional embedding)
            num_coords : number of coordinates required for the positional embedding 
            extraPE_dim: Denotes dimension of random walk  or Laplacian eigenvector positional encoding
            extraPE_method: Denotes how random walk embeddings or Laplacian eigenvector positional encoding should be embedded 
                         cat - concatenation, sum - summation.
        '''
        
        super(GNN, self).__init__()
        
        self.gnn_type = gnn_type
        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK
        self.emb_dim = emb_dim
        self.hasPos = hasPos
        self.num_coords =num_coords
        self.num_classes = num_classes
        self.num_pre_fnn_layers = num_pre_fnn_layers
        self.num_post_fnn_layers = num_post_fnn_layers
        self.graph_pooling = graph_pooling
        self.input_spec_fts_dim = input_spec_fts_dim
        self.input_edge_dim = input_edge_dim
        self.extraPE_method = extraPE_method
        
        if self.gnn_type=="pointnet":
            self.hasPos = False
        
        self.pos_kwd = "hasPos"
        if not self.hasPos:
            self.pos_kwd = "noPos"
            
        if not self.hasPos:
            self.input_node_dim = self.input_spec_fts_dim
        else:
            self.input_node_dim = self.input_spec_fts_dim+num_coords

        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")
            
        if self.num_post_fnn_layers < 1:
            raise ValueError("Number of GNN layers must be greater than or equal to 1.")
        
        self.graph_pred_pre_linear_list = torch.nn.ModuleList()
        
        # dimention of node fts which are fed into message passing layers
        self.input_node_dim_mp = self.input_node_dim
        
        if self.num_pre_fnn_layers >0:
            self.graph_pred_pre_linear_list.append(torch.nn.Linear(self.input_node_dim, emb_dim))
            for i in range(1,num_pre_fnn_layers):
                self.graph_pred_pre_linear_list.append(torch.nn.Linear(emb_dim, emb_dim))
            self.input_node_dim_mp = emb_dim
                 
        ### GNN to generate node embeddings
        self.gnn_node = MessagePasssing_Module(num_layer,input_node_dim=self.input_node_dim_mp,
                                               input_edge_dim = self.input_edge_dim, emb_dim=emb_dim,
                                               extraPE_dim = extraPE_dim, extraPE_method = self.extraPE_method,
                                               JK = JK, drop_ratio = drop_ratio, residual = residual, 
                                               gnn_type = gnn_type)

        ### Pooling function to generate entire-graph embeddings
        if self.graph_pooling == "sum":
            self.pool = global_add_pool
        elif self.graph_pooling == "mean":
            self.pool = global_mean_pool
        elif self.graph_pooling == "max":
            self.pool = global_max_pool
        elif self.graph_pooling == "attention":
            self.pool = GlobalAttention(gate_nn = torch.nn.Sequential(torch.nn.Linear(emb_dim, 2*emb_dim), 
                                                                      torch.nn.BatchNorm1d(2*emb_dim), torch.nn.ReLU(),
                                                                      torch.nn.Linear(2*emb_dim, 1)))
        else:
            raise ValueError("Invalid graph pooling type.")

        self.graph_pred_post_linear_list = torch.nn.ModuleList()

        for i in range(num_post_fnn_layers-1):
            self.graph_pred_post_linear_list.append(torch.nn.Linear(emb_dim, emb_dim))
        self.graph_pred_post_linear_list.append(torch.nn.Linear(emb_dim, self.num_classes))
                


    def forward(self, batched_data):

        input_x = batched_data.x # here we can split the x 
        
        batched_data.pos = input_x[:,self.input_spec_fts_dim:] # this will keep pos embeddings
        input_x = input_x[:,:self.input_node_dim]
        prep_x = input_x
        
        #preprocessing node features (only). 
        for fnn_inx in range(self.num_pre_fnn_layers):
            prep_x = self.graph_pred_pre_linear_list[fnn_inx](prep_x)
        
        batched_data.x = prep_x
        
        h_node = self.gnn_node(batched_data)

        h_graph = self.pool(h_node, batched_data.batch)

        output = h_graph # initial input is set to the output of the GNN 
        
        #postprocessing graph embeddings (only). 
        for fnn_inx in range(self.num_post_fnn_layers):
            output = self.graph_pred_post_linear_list[fnn_inx](output)
            
        return F.softmax(output,dim=1)
    
    def __str__(self):
        return self.gnn_type+f"-model-{self.pos_kwd}"

In [6]:
device = torch.device("cuda:0" if torch.cuda.is_available() else torch.device("cpu"))

In [7]:
multicls_criterion = torch.nn.CrossEntropyLoss()
epochs = 75

In [8]:
def import_dataset(name,transform=None, pre_transform=None,pre_filter=None):
    return QGJetsGraphsDataset('../../dataset/',name=name,transform=transform,
                             pre_transform=pre_transform,pre_filter=pre_filter)

In [9]:
def create_loaders(dataset,batch_size=32):
    # random splitting dataset
    train_inx, valid_inx, test_inx = random_split(range(len(dataset)),[0.7,0.2,0.1],generator=torch.Generator()
                                                .manual_seed(42))

    train_dataloader = DataLoader(dataset[list(train_inx)], batch_size=batch_size, shuffle=True)
    valid_dataloader = DataLoader(dataset[list(valid_inx)], batch_size=batch_size, shuffle=False)
    test_dataloader = DataLoader(dataset[list(test_inx)], batch_size=batch_size, shuffle=False)
    
    return train_dataloader,valid_dataloader,test_dataloader

In [10]:
def train(model, device, loader, optimizer):
    model.train()

    loss_accum = 0
    for step, batch in enumerate(tqdm(loader, desc="Iteration")):
        batch=batch.to(device)
        if batch.x.shape[0] == 1:
            pass
        else: 
            output = model(batch)
            optimizer.zero_grad()
            loss = multicls_criterion(output, batch.y.view(-1).to(torch.int64))
            loss.backward()
            optimizer.step()

        loss_accum += loss.item()

    print('Average training loss: {}'.format(loss_accum / (step + 1))) 

In [11]:
def evaluate(model, device, loader,evaluator= "roauc"):
    model.eval()
    
    preds_list = []
    target_list = []
    for step, batch in enumerate(loader):
        batch = batch.to(device)
        with torch.no_grad():
            output = model(batch)
            preds_list.extend(output.tolist())
        target_list += batch.y.view(-1).tolist()

    if evaluator == "roauc":   
        metric = MulticlassAUROC(num_classes=2, average="macro", thresholds=None)
    if evaluator == "acc":
        metric = MulticlassAccuracy(num_classes=2, average="macro")
    # print("AUC-ROC metric score : ",metric(torch.Tensor(preds_list),torch.Tensor(target_list)).item())
    return metric(torch.Tensor(preds_list),torch.Tensor(target_list).to(torch.int64)).item()

In [12]:
def train_model(model,optimizer,dataset,batch_size=32):
    checkpoints_path = "../models"
    checkpoints = os.listdir(checkpoints_path)
    checkpoint_path = list(filter(lambda i : str(model) in i, checkpoints))
    
    train_curves = []
    valid_curves = []
    starting_epoch = 1 
    
    # create loaders 
    train_dataloader,valid_dataloader,test_dataloader = create_loaders(dataset,batch_size=batch_size)
    
    if len(checkpoint_path)>0:
        checkpoint = torch.load(f"{checkpoints_path}/{checkpoint_path[0]}")
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        starting_epoch = checkpoint['epoch']+1

    for epoch in range(starting_epoch, epochs + 1):
        print("=====Epoch {}".format(epoch))
        print('Training...')
        train(model, device, train_dataloader, optimizer)
        
        # save checkpoint of current epoch
        torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                }, f"{checkpoints_path}/{str(model)}-{epoch}.pt")

        # delete checkpoint of previous epoch
        if epoch>1:
            os.remove(f"{checkpoints_path}/{str(model)}-{epoch-1}.pt")

        print("Evaluating...")
        train_perf_roauc = evaluate(model,device,train_dataloader)
        valid_perf_roauc = evaluate(model,device,valid_dataloader)
        print('ROAUC scores: ',{'Train': train_perf_roauc, 'Validation': valid_perf_roauc})
        
    print('\nFinished training!')
    print('\nROAUC Test score: {}'.format(evaluate(model,device,test_dataloader)))

## Training PointNet Conv based GNN model

In [13]:
QG_jets_dataset = import_dataset(name="QG_Jets")

pointnet_model = GNN(num_classes = 2, num_layer = 2,num_post_fnn_layers=2,input_edge_dim = 1,num_coords=2, 
                 input_spec_fts_dim=4,gnn_type = 'pointnet', emb_dim = 300, drop_ratio = 0.3).to(device)
optimizer = optim.Adam(pointnet_model.parameters(), lr=1e-3)

train_model(pointnet_model,optimizer,QG_jets_dataset)



=====Epoch 1
Training...


Iteration: 100%|██████████| 2188/2188 [00:50<00:00, 43.01it/s]


Average training loss: 0.5277088258340547
Evaluating...
ROAUC scores:  {'Train': 0.8568630218505859, 'Validation': 0.855898916721344}
=====Epoch 2
Training...


Iteration: 100%|██████████| 2188/2188 [00:48<00:00, 44.92it/s]


Average training loss: 0.5228723950643208
Evaluating...
ROAUC scores:  {'Train': 0.8608030080795288, 'Validation': 0.8602957725524902}
=====Epoch 3
Training...


Iteration: 100%|██████████| 2188/2188 [01:28<00:00, 24.76it/s]


Average training loss: 0.5215871250852806
Evaluating...
ROAUC scores:  {'Train': 0.8625271320343018, 'Validation': 0.8627728819847107}
=====Epoch 4
Training...


Iteration: 100%|██████████| 2188/2188 [01:51<00:00, 19.64it/s]


Average training loss: 0.5207997143704451
Evaluating...
ROAUC scores:  {'Train': 0.8586188554763794, 'Validation': 0.8585447072982788}
=====Epoch 5
Training...


Iteration: 100%|██████████| 2188/2188 [02:20<00:00, 15.53it/s]


Average training loss: 0.5194510995687907
Evaluating...
ROAUC scores:  {'Train': 0.8625052571296692, 'Validation': 0.862358808517456}
=====Epoch 6
Training...


Iteration: 100%|██████████| 2188/2188 [02:13<00:00, 16.33it/s]


Average training loss: 0.5181096819054056
Evaluating...
ROAUC scores:  {'Train': 0.8652693033218384, 'Validation': 0.8648958206176758}
=====Epoch 7
Training...


Iteration: 100%|██████████| 2188/2188 [02:04<00:00, 17.51it/s]


Average training loss: 0.5183502028149705
Evaluating...
ROAUC scores:  {'Train': 0.8637072443962097, 'Validation': 0.8639230728149414}
=====Epoch 8
Training...


Iteration: 100%|██████████| 2188/2188 [02:28<00:00, 14.76it/s]


Average training loss: 0.516902874031494
Evaluating...
ROAUC scores:  {'Train': 0.8627790212631226, 'Validation': 0.8623815178871155}
=====Epoch 9
Training...


Iteration: 100%|██████████| 2188/2188 [02:21<00:00, 15.45it/s]


Average training loss: 0.5170029715466326
Evaluating...
ROAUC scores:  {'Train': 0.8629748821258545, 'Validation': 0.862883448600769}
=====Epoch 10
Training...


Iteration: 100%|██████████| 2188/2188 [02:21<00:00, 15.41it/s]


Average training loss: 0.5168866911525046
Evaluating...
ROAUC scores:  {'Train': 0.8669871091842651, 'Validation': 0.8659375905990601}
=====Epoch 11
Training...


Iteration: 100%|██████████| 2188/2188 [02:35<00:00, 14.11it/s]


Average training loss: 0.5159665670865633
Evaluating...
ROAUC scores:  {'Train': 0.8625607490539551, 'Validation': 0.8614637851715088}
=====Epoch 12
Training...


Iteration: 100%|██████████| 2188/2188 [02:32<00:00, 14.37it/s]


Average training loss: 0.5155245222736974
Evaluating...
ROAUC scores:  {'Train': 0.864812433719635, 'Validation': 0.8639065027236938}
=====Epoch 13
Training...


Iteration: 100%|██████████| 2188/2188 [02:40<00:00, 13.59it/s]


Average training loss: 0.5153443809646358
Evaluating...
ROAUC scores:  {'Train': 0.868465781211853, 'Validation': 0.8682403564453125}
=====Epoch 14
Training...


Iteration: 100%|██████████| 2188/2188 [02:40<00:00, 13.61it/s]


Average training loss: 0.5142440746591121
Evaluating...
ROAUC scores:  {'Train': 0.8664100170135498, 'Validation': 0.8657922744750977}
=====Epoch 15
Training...


Iteration: 100%|██████████| 2188/2188 [02:40<00:00, 13.62it/s]


Average training loss: 0.51466863436226
Evaluating...
ROAUC scores:  {'Train': 0.8695412874221802, 'Validation': 0.8688908815383911}
=====Epoch 16
Training...


Iteration: 100%|██████████| 2188/2188 [02:40<00:00, 13.61it/s]


Average training loss: 0.5133845845349746
Evaluating...
ROAUC scores:  {'Train': 0.8692023158073425, 'Validation': 0.8683531284332275}
=====Epoch 17
Training...


Iteration: 100%|██████████| 2188/2188 [02:36<00:00, 13.97it/s]


Average training loss: 0.5131847959495331
Evaluating...
ROAUC scores:  {'Train': 0.8675959706306458, 'Validation': 0.867475152015686}
=====Epoch 18
Training...


Iteration: 100%|██████████| 2188/2188 [02:30<00:00, 14.58it/s]


Average training loss: 0.5134600626076598
Evaluating...
ROAUC scores:  {'Train': 0.8700790405273438, 'Validation': 0.8692171573638916}
=====Epoch 19
Training...


Iteration: 100%|██████████| 2188/2188 [02:30<00:00, 14.58it/s]


Average training loss: 0.5132731418014662
Evaluating...
ROAUC scores:  {'Train': 0.8682753443717957, 'Validation': 0.8676638603210449}
=====Epoch 20
Training...


Iteration: 100%|██████████| 2188/2188 [02:34<00:00, 14.18it/s]


Average training loss: 0.5127640857856713
Evaluating...
ROAUC scores:  {'Train': 0.8686457872390747, 'Validation': 0.8680394887924194}
=====Epoch 21
Training...


Iteration: 100%|██████████| 2188/2188 [02:29<00:00, 14.65it/s]


Average training loss: 0.5130286714440072
Evaluating...
ROAUC scores:  {'Train': 0.870881199836731, 'Validation': 0.8698322772979736}
=====Epoch 22
Training...


Iteration: 100%|██████████| 2188/2188 [02:34<00:00, 14.13it/s]


Average training loss: 0.5133288505962169
Evaluating...
ROAUC scores:  {'Train': 0.8710359334945679, 'Validation': 0.8703526258468628}
=====Epoch 23
Training...


Iteration: 100%|██████████| 2188/2188 [02:31<00:00, 14.41it/s]


Average training loss: 0.5119278591800434
Evaluating...
ROAUC scores:  {'Train': 0.8699014782905579, 'Validation': 0.8688746690750122}
=====Epoch 24
Training...


Iteration: 100%|██████████| 2188/2188 [02:32<00:00, 14.33it/s]


Average training loss: 0.5125798465074525
Evaluating...
ROAUC scores:  {'Train': 0.8713685274124146, 'Validation': 0.8705602884292603}
=====Epoch 25
Training...


Iteration: 100%|██████████| 2188/2188 [02:28<00:00, 14.76it/s]


Average training loss: 0.5114661490023681
Evaluating...
ROAUC scores:  {'Train': 0.8690933585166931, 'Validation': 0.8681923151016235}
=====Epoch 26
Training...


Iteration: 100%|██████████| 2188/2188 [02:32<00:00, 14.33it/s]


Average training loss: 0.5120576177834375
Evaluating...
ROAUC scores:  {'Train': 0.8720210790634155, 'Validation': 0.8710700273513794}
=====Epoch 27
Training...


Iteration: 100%|██████████| 2188/2188 [02:29<00:00, 14.63it/s]


Average training loss: 0.5119361916805534
Evaluating...
ROAUC scores:  {'Train': 0.8711850643157959, 'Validation': 0.8705085515975952}
=====Epoch 28
Training...


Iteration: 100%|██████████| 2188/2188 [02:26<00:00, 14.90it/s]


Average training loss: 0.5116528379688315
Evaluating...
ROAUC scores:  {'Train': 0.8696811199188232, 'Validation': 0.8686683773994446}
=====Epoch 29
Training...


Iteration: 100%|██████████| 2188/2188 [02:30<00:00, 14.51it/s]


Average training loss: 0.5109619313690954
Evaluating...
ROAUC scores:  {'Train': 0.8690903782844543, 'Validation': 0.8683627843856812}
=====Epoch 30
Training...


Iteration: 100%|██████████| 2188/2188 [02:29<00:00, 14.65it/s]


Average training loss: 0.5102716953079923
Evaluating...
ROAUC scores:  {'Train': 0.8720585107803345, 'Validation': 0.8706750869750977}
=====Epoch 31
Training...


Iteration: 100%|██████████| 2188/2188 [02:30<00:00, 14.54it/s]


Average training loss: 0.5100683278559113
Evaluating...
ROAUC scores:  {'Train': 0.8717700242996216, 'Validation': 0.8705282211303711}
=====Epoch 32
Training...


Iteration: 100%|██████████| 2188/2188 [02:27<00:00, 14.86it/s]


Average training loss: 0.5108012351627977
Evaluating...
ROAUC scores:  {'Train': 0.8716086149215698, 'Validation': 0.8702489733695984}
=====Epoch 33
Training...


Iteration: 100%|██████████| 2188/2188 [02:29<00:00, 14.59it/s]


Average training loss: 0.5100832527036858
Evaluating...
ROAUC scores:  {'Train': 0.8716449737548828, 'Validation': 0.8703022003173828}
=====Epoch 34
Training...


Iteration: 100%|██████████| 2188/2188 [02:23<00:00, 15.21it/s]


Average training loss: 0.5101161993481559
Evaluating...
ROAUC scores:  {'Train': 0.8715892434120178, 'Validation': 0.8704921007156372}
=====Epoch 35
Training...


Iteration: 100%|██████████| 2188/2188 [02:36<00:00, 13.97it/s]


Average training loss: 0.5098155506050565
Evaluating...
ROAUC scores:  {'Train': 0.871620774269104, 'Validation': 0.870195746421814}
=====Epoch 36
Training...


Iteration: 100%|██████████| 2188/2188 [02:26<00:00, 14.90it/s]


Average training loss: 0.5098625452170861
Evaluating...
ROAUC scores:  {'Train': 0.8710184097290039, 'Validation': 0.8694014549255371}
=====Epoch 37
Training...


Iteration: 100%|██████████| 2188/2188 [02:38<00:00, 13.79it/s]


Average training loss: 0.510017013487149
Evaluating...
ROAUC scores:  {'Train': 0.8727719783782959, 'Validation': 0.8709975481033325}
=====Epoch 38
Training...


Iteration: 100%|██████████| 2188/2188 [02:40<00:00, 13.62it/s]


Average training loss: 0.5092205377721264
Evaluating...
ROAUC scores:  {'Train': 0.8729770183563232, 'Validation': 0.8712793588638306}
=====Epoch 39
Training...


Iteration: 100%|██████████| 2188/2188 [02:40<00:00, 13.61it/s]


Average training loss: 0.5094419859036449
Evaluating...
ROAUC scores:  {'Train': 0.8726171851158142, 'Validation': 0.8710023164749146}
=====Epoch 40
Training...


Iteration: 100%|██████████| 2188/2188 [02:34<00:00, 14.18it/s]


Average training loss: 0.5094245782690685
Evaluating...
ROAUC scores:  {'Train': 0.8727693557739258, 'Validation': 0.8706198930740356}
=====Epoch 41
Training...


Iteration: 100%|██████████| 2188/2188 [02:36<00:00, 13.97it/s]


Average training loss: 0.5097303980283807
Evaluating...
ROAUC scores:  {'Train': 0.8717496395111084, 'Validation': 0.8700937628746033}
=====Epoch 42
Training...


Iteration: 100%|██████████| 2188/2188 [02:35<00:00, 14.03it/s]


Average training loss: 0.5090410223526972
Evaluating...
ROAUC scores:  {'Train': 0.8721281290054321, 'Validation': 0.8703560829162598}
=====Epoch 43
Training...


Iteration: 100%|██████████| 2188/2188 [02:35<00:00, 14.07it/s]


Average training loss: 0.5095544903789403
Evaluating...
ROAUC scores:  {'Train': 0.8738383054733276, 'Validation': 0.8717211484909058}
=====Epoch 44
Training...


Iteration: 100%|██████████| 2188/2188 [02:37<00:00, 13.85it/s]


Average training loss: 0.5074779623659699
Evaluating...
ROAUC scores:  {'Train': 0.8722825646400452, 'Validation': 0.8702841997146606}
=====Epoch 45
Training...


Iteration: 100%|██████████| 2188/2188 [02:30<00:00, 14.57it/s]


Average training loss: 0.507973619050583
Evaluating...
ROAUC scores:  {'Train': 0.8730079531669617, 'Validation': 0.8707097768783569}
=====Epoch 46
Training...


Iteration: 100%|██████████| 2188/2188 [02:37<00:00, 13.92it/s]


Average training loss: 0.5083661841763221
Evaluating...
ROAUC scores:  {'Train': 0.8746219873428345, 'Validation': 0.872497022151947}
=====Epoch 47
Training...


Iteration: 100%|██████████| 2188/2188 [02:29<00:00, 14.60it/s]


Average training loss: 0.5075355103508428
Evaluating...
ROAUC scores:  {'Train': 0.8732395172119141, 'Validation': 0.8715653419494629}
=====Epoch 48
Training...


Iteration: 100%|██████████| 2188/2188 [02:33<00:00, 14.22it/s]


Average training loss: 0.5081816935670005
Evaluating...
ROAUC scores:  {'Train': 0.8737385272979736, 'Validation': 0.8719351887702942}
=====Epoch 49
Training...


Iteration: 100%|██████████| 2188/2188 [02:40<00:00, 13.59it/s]


Average training loss: 0.5070123521703035
Evaluating...
ROAUC scores:  {'Train': 0.8724724054336548, 'Validation': 0.8699018955230713}
=====Epoch 50
Training...


Iteration: 100%|██████████| 2188/2188 [02:40<00:00, 13.60it/s]


Average training loss: 0.5077355822081757
Evaluating...
ROAUC scores:  {'Train': 0.8714710474014282, 'Validation': 0.8684195280075073}
=====Epoch 51
Training...


Iteration: 100%|██████████| 2188/2188 [02:40<00:00, 13.60it/s]


Average training loss: 0.5073321112440099
Evaluating...
ROAUC scores:  {'Train': 0.8747886419296265, 'Validation': 0.8721598982810974}
=====Epoch 52
Training...


Iteration: 100%|██████████| 2188/2188 [02:34<00:00, 14.13it/s]


Average training loss: 0.5070214968113402
Evaluating...
ROAUC scores:  {'Train': 0.8730165958404541, 'Validation': 0.8703686594963074}
=====Epoch 53
Training...


Iteration: 100%|██████████| 2188/2188 [02:36<00:00, 13.97it/s]


Average training loss: 0.5074606153439778
Evaluating...
ROAUC scores:  {'Train': 0.8735488653182983, 'Validation': 0.8708200454711914}
=====Epoch 54
Training...


Iteration: 100%|██████████| 2188/2188 [02:29<00:00, 14.65it/s]


Average training loss: 0.5066439653363699
Evaluating...
ROAUC scores:  {'Train': 0.8753085136413574, 'Validation': 0.8726427555084229}
=====Epoch 55
Training...


Iteration: 100%|██████████| 2188/2188 [02:35<00:00, 14.03it/s]


Average training loss: 0.5065600776557948
Evaluating...
ROAUC scores:  {'Train': 0.8749114274978638, 'Validation': 0.8711620569229126}
=====Epoch 56
Training...


Iteration: 100%|██████████| 2188/2188 [02:35<00:00, 14.10it/s]


Average training loss: 0.5064832573480754
Evaluating...
ROAUC scores:  {'Train': 0.875169038772583, 'Validation': 0.8720486760139465}
=====Epoch 57
Training...


Iteration: 100%|██████████| 2188/2188 [02:29<00:00, 14.64it/s]


Average training loss: 0.5061800209515928
Evaluating...
ROAUC scores:  {'Train': 0.873226523399353, 'Validation': 0.8703306913375854}
=====Epoch 58
Training...


Iteration: 100%|██████████| 2188/2188 [02:31<00:00, 14.46it/s]


Average training loss: 0.5072234925257223
Evaluating...
ROAUC scores:  {'Train': 0.8742249011993408, 'Validation': 0.8707195520401001}
=====Epoch 59
Training...


Iteration: 100%|██████████| 2188/2188 [02:28<00:00, 14.69it/s]


Average training loss: 0.5064382766811024
Evaluating...
ROAUC scores:  {'Train': 0.8753495216369629, 'Validation': 0.8724624514579773}
=====Epoch 60
Training...


Iteration: 100%|██████████| 2188/2188 [02:30<00:00, 14.50it/s]


Average training loss: 0.5063149960374483
Evaluating...
ROAUC scores:  {'Train': 0.8744620084762573, 'Validation': 0.8713162541389465}
=====Epoch 61
Training...


Iteration: 100%|██████████| 2188/2188 [02:28<00:00, 14.71it/s]


Average training loss: 0.5057592433919854
Evaluating...
ROAUC scores:  {'Train': 0.8747999668121338, 'Validation': 0.8717173337936401}
=====Epoch 62
Training...


Iteration: 100%|██████████| 2188/2188 [02:31<00:00, 14.44it/s]


Average training loss: 0.50597114006987
Evaluating...
ROAUC scores:  {'Train': 0.8757224082946777, 'Validation': 0.8723422288894653}
=====Epoch 63
Training...


Iteration: 100%|██████████| 2188/2188 [02:28<00:00, 14.75it/s]


Average training loss: 0.5053077738125974
Evaluating...
ROAUC scores:  {'Train': 0.8749850988388062, 'Validation': 0.8715770244598389}
=====Epoch 64
Training...


Iteration: 100%|██████████| 2188/2188 [02:29<00:00, 14.66it/s]


Average training loss: 0.5051008221681419
Evaluating...
ROAUC scores:  {'Train': 0.8760401010513306, 'Validation': 0.8723183870315552}
=====Epoch 65
Training...


Iteration: 100%|██████████| 2188/2188 [02:32<00:00, 14.34it/s]


Average training loss: 0.5045556282212573
Evaluating...
ROAUC scores:  {'Train': 0.8754938840866089, 'Validation': 0.872022271156311}
=====Epoch 66
Training...


Iteration: 100%|██████████| 2188/2188 [02:29<00:00, 14.66it/s]


Average training loss: 0.5042203830817282
Evaluating...
ROAUC scores:  {'Train': 0.8751358985900879, 'Validation': 0.8711485862731934}
=====Epoch 67
Training...


Iteration: 100%|██████████| 2188/2188 [02:32<00:00, 14.31it/s]


Average training loss: 0.504691223203591
Evaluating...
ROAUC scores:  {'Train': 0.87628173828125, 'Validation': 0.872530460357666}
=====Epoch 68
Training...


Iteration: 100%|██████████| 2188/2188 [02:27<00:00, 14.83it/s]


Average training loss: 0.5053473168360904
Evaluating...
ROAUC scores:  {'Train': 0.8753905296325684, 'Validation': 0.8715938925743103}
=====Epoch 69
Training...


Iteration: 100%|██████████| 2188/2188 [02:31<00:00, 14.41it/s]


Average training loss: 0.5048432003917817
Evaluating...
ROAUC scores:  {'Train': 0.8770575523376465, 'Validation': 0.873652458190918}
=====Epoch 70
Training...


Iteration: 100%|██████████| 2188/2188 [02:40<00:00, 13.60it/s]


Average training loss: 0.5041975821121519
Evaluating...
ROAUC scores:  {'Train': 0.8768487572669983, 'Validation': 0.8724943399429321}
=====Epoch 71
Training...


Iteration: 100%|██████████| 2188/2188 [02:34<00:00, 14.14it/s]


Average training loss: 0.5042252643834084
Evaluating...
ROAUC scores:  {'Train': 0.8770954608917236, 'Validation': 0.8727120161056519}
=====Epoch 72
Training...


Iteration: 100%|██████████| 2188/2188 [02:25<00:00, 14.99it/s]


Average training loss: 0.5042755333144878
Evaluating...
ROAUC scores:  {'Train': 0.8778943419456482, 'Validation': 0.8738393783569336}
=====Epoch 73
Training...


Iteration: 100%|██████████| 2188/2188 [02:36<00:00, 14.00it/s]


Average training loss: 0.5038617500793563
Evaluating...
ROAUC scores:  {'Train': 0.875625729560852, 'Validation': 0.8704603910446167}
=====Epoch 74
Training...


Iteration: 100%|██████████| 2188/2188 [02:34<00:00, 14.15it/s]


Average training loss: 0.5033223067003368
Evaluating...
ROAUC scores:  {'Train': 0.875677764415741, 'Validation': 0.8717756867408752}
=====Epoch 75
Training...


Iteration: 100%|██████████| 2188/2188 [02:35<00:00, 14.08it/s]


Average training loss: 0.5033711054016728
Evaluating...
ROAUC scores:  {'Train': 0.877160370349884, 'Validation': 0.8725188970565796}

Finished training!

ROAUC Test score: 0.8701311945915222


## Training of GCN based model

### Training with GPE 

In [14]:
QG_jets_dataset = import_dataset(name="QG_Jets")

gcn_model = GNN(num_classes = 2, num_layer = 2,num_post_fnn_layers=2,hasPos=True,input_edge_dim = 1,num_coords=2, 
                input_spec_fts_dim=3, gnn_type = 'gcn', emb_dim = 300, drop_ratio = 0.3).to(device)
optimizer = optim.Adam(gcn_model.parameters(), lr=1e-3)

train_model(gcn_model,optimizer,QG_jets_dataset)



=====Epoch 1
Training...


Iteration: 100%|██████████| 2188/2188 [00:37<00:00, 58.48it/s]


Average training loss: 0.5257440694561825
Evaluating...
ROAUC scores:  {'Train': 0.860903263092041, 'Validation': 0.8610101938247681}
=====Epoch 2
Training...


Iteration: 100%|██████████| 2188/2188 [00:26<00:00, 81.12it/s]


Average training loss: 0.5214273425353728
Evaluating...
ROAUC scores:  {'Train': 0.8606448173522949, 'Validation': 0.8606523275375366}
=====Epoch 3
Training...


Iteration: 100%|██████████| 2188/2188 [00:28<00:00, 78.13it/s]


Average training loss: 0.5193967459215978
Evaluating...
ROAUC scores:  {'Train': 0.8628547191619873, 'Validation': 0.862496018409729}
=====Epoch 4
Training...


Iteration: 100%|██████████| 2188/2188 [00:28<00:00, 76.01it/s]


Average training loss: 0.5188389459464842
Evaluating...
ROAUC scores:  {'Train': 0.8663569092750549, 'Validation': 0.8658032417297363}
=====Epoch 5
Training...


Iteration: 100%|██████████| 2188/2188 [00:29<00:00, 74.34it/s]


Average training loss: 0.5175307112076819
Evaluating...
ROAUC scores:  {'Train': 0.8637694120407104, 'Validation': 0.8636735677719116}
=====Epoch 6
Training...


Iteration: 100%|██████████| 2188/2188 [00:29<00:00, 72.96it/s]


Average training loss: 0.5176915379149168
Evaluating...
ROAUC scores:  {'Train': 0.8659428358078003, 'Validation': 0.8656755685806274}
=====Epoch 7
Training...


Iteration: 100%|██████████| 2188/2188 [00:31<00:00, 68.94it/s]


Average training loss: 0.5161082908865303
Evaluating...
ROAUC scores:  {'Train': 0.867822527885437, 'Validation': 0.8678199052810669}
=====Epoch 8
Training...


Iteration: 100%|██████████| 2188/2188 [00:35<00:00, 61.55it/s]


Average training loss: 0.5164533771261218
Evaluating...
ROAUC scores:  {'Train': 0.8676608800888062, 'Validation': 0.8678642511367798}
=====Epoch 9
Training...


Iteration: 100%|██████████| 2188/2188 [00:30<00:00, 70.61it/s]


Average training loss: 0.5159136407363349
Evaluating...
ROAUC scores:  {'Train': 0.8679288625717163, 'Validation': 0.8676323890686035}
=====Epoch 10
Training...


Iteration: 100%|██████████| 2188/2188 [00:31<00:00, 69.86it/s]


Average training loss: 0.5153324420970798
Evaluating...
ROAUC scores:  {'Train': 0.8676282167434692, 'Validation': 0.8670448660850525}
=====Epoch 11
Training...


Iteration: 100%|██████████| 2188/2188 [00:35<00:00, 61.64it/s]


Average training loss: 0.5141758615290661
Evaluating...
ROAUC scores:  {'Train': 0.8693467974662781, 'Validation': 0.869069516658783}
=====Epoch 12
Training...


Iteration: 100%|██████████| 2188/2188 [00:32<00:00, 66.61it/s]


Average training loss: 0.5150229798347248
Evaluating...
ROAUC scores:  {'Train': 0.8689647912979126, 'Validation': 0.8687614798545837}
=====Epoch 13
Training...


Iteration: 100%|██████████| 2188/2188 [00:31<00:00, 69.28it/s]


Average training loss: 0.5137417530419621
Evaluating...
ROAUC scores:  {'Train': 0.8705579042434692, 'Validation': 0.8701055645942688}
=====Epoch 14
Training...


Iteration: 100%|██████████| 2188/2188 [00:31<00:00, 68.96it/s]


Average training loss: 0.5135110277253487
Evaluating...
ROAUC scores:  {'Train': 0.869513750076294, 'Validation': 0.8687781095504761}
=====Epoch 15
Training...


Iteration: 100%|██████████| 2188/2188 [00:31<00:00, 68.74it/s]


Average training loss: 0.513877157414744
Evaluating...
ROAUC scores:  {'Train': 0.8695979714393616, 'Validation': 0.8687388896942139}
=====Epoch 16
Training...


Iteration: 100%|██████████| 2188/2188 [00:32<00:00, 67.77it/s]


Average training loss: 0.5127481003811198
Evaluating...
ROAUC scores:  {'Train': 0.8696922063827515, 'Validation': 0.8689446449279785}
=====Epoch 17
Training...


Iteration: 100%|██████████| 2188/2188 [00:31<00:00, 68.77it/s]


Average training loss: 0.5132966220678533
Evaluating...
ROAUC scores:  {'Train': 0.8699267506599426, 'Validation': 0.8692046403884888}
=====Epoch 18
Training...


Iteration: 100%|██████████| 2188/2188 [00:31<00:00, 69.71it/s]


Average training loss: 0.5117444381853344
Evaluating...
ROAUC scores:  {'Train': 0.8711331486701965, 'Validation': 0.8702178597450256}
=====Epoch 19
Training...


Iteration: 100%|██████████| 2188/2188 [00:31<00:00, 68.93it/s]


Average training loss: 0.5114475526800976
Evaluating...
ROAUC scores:  {'Train': 0.8717547655105591, 'Validation': 0.871400773525238}
=====Epoch 20
Training...


Iteration: 100%|██████████| 2188/2188 [00:31<00:00, 69.26it/s]


Average training loss: 0.5125659418988707
Evaluating...
ROAUC scores:  {'Train': 0.8713948726654053, 'Validation': 0.8708136081695557}
=====Epoch 21
Training...


Iteration: 100%|██████████| 2188/2188 [00:32<00:00, 68.35it/s]


Average training loss: 0.5119466441504699
Evaluating...
ROAUC scores:  {'Train': 0.8707377910614014, 'Validation': 0.8702837824821472}
=====Epoch 22
Training...


Iteration: 100%|██████████| 2188/2188 [00:31<00:00, 69.17it/s]


Average training loss: 0.5113799173718615
Evaluating...
ROAUC scores:  {'Train': 0.8727436065673828, 'Validation': 0.8723272681236267}
=====Epoch 23
Training...


Iteration: 100%|██████████| 2188/2188 [00:31<00:00, 68.85it/s]


Average training loss: 0.5113020600032545
Evaluating...
ROAUC scores:  {'Train': 0.8715546131134033, 'Validation': 0.8714849948883057}
=====Epoch 24
Training...


Iteration: 100%|██████████| 2188/2188 [00:32<00:00, 67.65it/s]


Average training loss: 0.5111789701326019
Evaluating...
ROAUC scores:  {'Train': 0.8720149397850037, 'Validation': 0.8714002370834351}
=====Epoch 25
Training...


Iteration: 100%|██████████| 2188/2188 [00:32<00:00, 67.79it/s]


Average training loss: 0.5109417732444063
Evaluating...
ROAUC scores:  {'Train': 0.870933473110199, 'Validation': 0.8710314035415649}
=====Epoch 26
Training...


Iteration: 100%|██████████| 2188/2188 [00:31<00:00, 68.63it/s]


Average training loss: 0.5108577587405114
Evaluating...
ROAUC scores:  {'Train': 0.8709770441055298, 'Validation': 0.870854377746582}
=====Epoch 27
Training...


Iteration: 100%|██████████| 2188/2188 [00:32<00:00, 67.04it/s]


Average training loss: 0.5109124490259116
Evaluating...
ROAUC scores:  {'Train': 0.8717897534370422, 'Validation': 0.8714373707771301}
=====Epoch 28
Training...


Iteration: 100%|██████████| 2188/2188 [00:31<00:00, 69.06it/s]


Average training loss: 0.5106170829487675
Evaluating...
ROAUC scores:  {'Train': 0.8718029260635376, 'Validation': 0.8710183501243591}
=====Epoch 29
Training...


Iteration: 100%|██████████| 2188/2188 [00:32<00:00, 67.45it/s]


Average training loss: 0.5099469643406284
Evaluating...
ROAUC scores:  {'Train': 0.870331883430481, 'Validation': 0.8699193000793457}
=====Epoch 30
Training...


Iteration: 100%|██████████| 2188/2188 [00:31<00:00, 68.71it/s]


Average training loss: 0.5102101842816613
Evaluating...
ROAUC scores:  {'Train': 0.8722693920135498, 'Validation': 0.8710672855377197}
=====Epoch 31
Training...


Iteration: 100%|██████████| 2188/2188 [00:32<00:00, 66.67it/s]


Average training loss: 0.5096165160401648
Evaluating...
ROAUC scores:  {'Train': 0.8735992908477783, 'Validation': 0.8731201887130737}
=====Epoch 32
Training...


Iteration: 100%|██████████| 2188/2188 [00:31<00:00, 68.40it/s]


Average training loss: 0.5104135832931486
Evaluating...
ROAUC scores:  {'Train': 0.8733552694320679, 'Validation': 0.8730689287185669}
=====Epoch 33
Training...


Iteration: 100%|██████████| 2188/2188 [00:34<00:00, 63.22it/s]


Average training loss: 0.5098663116830794
Evaluating...
ROAUC scores:  {'Train': 0.8726544380187988, 'Validation': 0.8720848560333252}
=====Epoch 34
Training...


Iteration: 100%|██████████| 2188/2188 [00:33<00:00, 66.14it/s]


Average training loss: 0.509321613206615
Evaluating...
ROAUC scores:  {'Train': 0.870938777923584, 'Validation': 0.8706455230712891}
=====Epoch 35
Training...


Iteration: 100%|██████████| 2188/2188 [00:32<00:00, 67.21it/s]


Average training loss: 0.5096721031875018
Evaluating...
ROAUC scores:  {'Train': 0.8733812570571899, 'Validation': 0.8724511861801147}
=====Epoch 36
Training...


Iteration: 100%|██████████| 2188/2188 [00:32<00:00, 67.02it/s]


Average training loss: 0.509131707523488
Evaluating...
ROAUC scores:  {'Train': 0.8725975155830383, 'Validation': 0.8713606595993042}
=====Epoch 37
Training...


Iteration: 100%|██████████| 2188/2188 [00:32<00:00, 66.73it/s]


Average training loss: 0.5096034777692946
Evaluating...
ROAUC scores:  {'Train': 0.873694896697998, 'Validation': 0.8731011748313904}
=====Epoch 38
Training...


Iteration: 100%|██████████| 2188/2188 [00:32<00:00, 67.06it/s]


Average training loss: 0.5094285822138054
Evaluating...
ROAUC scores:  {'Train': 0.8737905621528625, 'Validation': 0.8731966018676758}
=====Epoch 39
Training...


Iteration: 100%|██████████| 2188/2188 [00:33<00:00, 64.36it/s]


Average training loss: 0.5090133859154511
Evaluating...
ROAUC scores:  {'Train': 0.8728063702583313, 'Validation': 0.8713783025741577}
=====Epoch 40
Training...


Iteration: 100%|██████████| 2188/2188 [00:32<00:00, 66.56it/s]


Average training loss: 0.5087978407958526
Evaluating...
ROAUC scores:  {'Train': 0.8674192428588867, 'Validation': 0.8660330772399902}
=====Epoch 41
Training...


Iteration: 100%|██████████| 2188/2188 [00:32<00:00, 66.60it/s]


Average training loss: 0.5091164713449626
Evaluating...
ROAUC scores:  {'Train': 0.8726012110710144, 'Validation': 0.8713294267654419}
=====Epoch 42
Training...


Iteration: 100%|██████████| 2188/2188 [00:33<00:00, 66.18it/s]


Average training loss: 0.5091617870565088
Evaluating...
ROAUC scores:  {'Train': 0.872002363204956, 'Validation': 0.8711861371994019}
=====Epoch 43
Training...


Iteration: 100%|██████████| 2188/2188 [00:32<00:00, 66.44it/s]


Average training loss: 0.5090467055793
Evaluating...
ROAUC scores:  {'Train': 0.8744007349014282, 'Validation': 0.8727902173995972}
=====Epoch 44
Training...


Iteration: 100%|██████████| 2188/2188 [00:32<00:00, 67.02it/s]


Average training loss: 0.5086938696326024
Evaluating...
ROAUC scores:  {'Train': 0.8746724724769592, 'Validation': 0.873274564743042}
=====Epoch 45
Training...


Iteration: 100%|██████████| 2188/2188 [00:32<00:00, 66.89it/s]


Average training loss: 0.5088549663586852
Evaluating...
ROAUC scores:  {'Train': 0.8726706504821777, 'Validation': 0.8720767498016357}
=====Epoch 46
Training...


Iteration: 100%|██████████| 2188/2188 [00:32<00:00, 67.06it/s]


Average training loss: 0.5084073296058549
Evaluating...
ROAUC scores:  {'Train': 0.8729790449142456, 'Validation': 0.8717129230499268}
=====Epoch 47
Training...


Iteration: 100%|██████████| 2188/2188 [00:32<00:00, 67.54it/s]


Average training loss: 0.5078337680841277
Evaluating...
ROAUC scores:  {'Train': 0.8728843927383423, 'Validation': 0.8712385892868042}
=====Epoch 48
Training...


Iteration: 100%|██████████| 2188/2188 [00:32<00:00, 67.26it/s]


Average training loss: 0.5080387970429234
Evaluating...
ROAUC scores:  {'Train': 0.8725199699401855, 'Validation': 0.8713496923446655}
=====Epoch 49
Training...


Iteration: 100%|██████████| 2188/2188 [00:32<00:00, 67.60it/s]


Average training loss: 0.5083754540470444
Evaluating...
ROAUC scores:  {'Train': 0.8733323216438293, 'Validation': 0.8714393973350525}
=====Epoch 50
Training...


Iteration: 100%|██████████| 2188/2188 [00:32<00:00, 67.22it/s]


Average training loss: 0.5089286198972348
Evaluating...
ROAUC scores:  {'Train': 0.8747598528862, 'Validation': 0.8737496137619019}
=====Epoch 51
Training...


Iteration: 100%|██████████| 2188/2188 [00:32<00:00, 66.93it/s]


Average training loss: 0.5084059044201151
Evaluating...
ROAUC scores:  {'Train': 0.8738851547241211, 'Validation': 0.8727562427520752}
=====Epoch 52
Training...


Iteration: 100%|██████████| 2188/2188 [00:32<00:00, 67.31it/s]


Average training loss: 0.5082800011584702
Evaluating...
ROAUC scores:  {'Train': 0.8736099004745483, 'Validation': 0.871715784072876}
=====Epoch 53
Training...


Iteration: 100%|██████████| 2188/2188 [00:32<00:00, 67.54it/s]


Average training loss: 0.5074106070197697
Evaluating...
ROAUC scores:  {'Train': 0.8742491602897644, 'Validation': 0.8723645210266113}
=====Epoch 54
Training...


Iteration: 100%|██████████| 2188/2188 [00:32<00:00, 67.15it/s]


Average training loss: 0.5078166128430768
Evaluating...
ROAUC scores:  {'Train': 0.8752480745315552, 'Validation': 0.874092698097229}
=====Epoch 55
Training...


Iteration: 100%|██████████| 2188/2188 [00:32<00:00, 67.23it/s]


Average training loss: 0.5079840662901101
Evaluating...
ROAUC scores:  {'Train': 0.874587893486023, 'Validation': 0.8735040426254272}
=====Epoch 56
Training...


Iteration: 100%|██████████| 2188/2188 [00:33<00:00, 66.28it/s]


Average training loss: 0.507491496325193
Evaluating...
ROAUC scores:  {'Train': 0.8747178316116333, 'Validation': 0.8737173080444336}
=====Epoch 57
Training...


Iteration: 100%|██████████| 2188/2188 [00:32<00:00, 66.36it/s]


Average training loss: 0.5070916800002075
Evaluating...
ROAUC scores:  {'Train': 0.8751394748687744, 'Validation': 0.8733458518981934}
=====Epoch 58
Training...


Iteration: 100%|██████████| 2188/2188 [00:33<00:00, 65.12it/s]


Average training loss: 0.5084673817213319
Evaluating...
ROAUC scores:  {'Train': 0.8746107816696167, 'Validation': 0.8733582496643066}
=====Epoch 59
Training...


Iteration: 100%|██████████| 2188/2188 [00:32<00:00, 67.20it/s]


Average training loss: 0.5077275579192521
Evaluating...
ROAUC scores:  {'Train': 0.8738857507705688, 'Validation': 0.8717935085296631}
=====Epoch 60
Training...


Iteration: 100%|██████████| 2188/2188 [00:32<00:00, 67.70it/s]


Average training loss: 0.5075015082710186
Evaluating...
ROAUC scores:  {'Train': 0.8748542070388794, 'Validation': 0.8735019564628601}
=====Epoch 61
Training...


Iteration: 100%|██████████| 2188/2188 [00:32<00:00, 67.56it/s]


Average training loss: 0.5081001302632597
Evaluating...
ROAUC scores:  {'Train': 0.8744024634361267, 'Validation': 0.8727296590805054}
=====Epoch 62
Training...


Iteration: 100%|██████████| 2188/2188 [00:32<00:00, 67.90it/s]


Average training loss: 0.5071621124906776
Evaluating...
ROAUC scores:  {'Train': 0.8742270469665527, 'Validation': 0.8727160096168518}
=====Epoch 63
Training...


Iteration: 100%|██████████| 2188/2188 [00:32<00:00, 67.72it/s]


Average training loss: 0.5070587752397143
Evaluating...
ROAUC scores:  {'Train': 0.872542679309845, 'Validation': 0.871611475944519}
=====Epoch 64
Training...


Iteration: 100%|██████████| 2188/2188 [00:32<00:00, 67.90it/s]


Average training loss: 0.5076284701416653
Evaluating...
ROAUC scores:  {'Train': 0.8759008646011353, 'Validation': 0.8738919496536255}
=====Epoch 65
Training...


Iteration: 100%|██████████| 2188/2188 [00:32<00:00, 67.92it/s]


Average training loss: 0.5069820760209556
Evaluating...
ROAUC scores:  {'Train': 0.8749250173568726, 'Validation': 0.8733831644058228}
=====Epoch 66
Training...


Iteration: 100%|██████████| 2188/2188 [00:32<00:00, 66.61it/s]


Average training loss: 0.5082310502582538
Evaluating...
ROAUC scores:  {'Train': 0.875460147857666, 'Validation': 0.8736337423324585}
=====Epoch 67
Training...


Iteration: 100%|██████████| 2188/2188 [00:32<00:00, 68.10it/s]


Average training loss: 0.5070033209360931
Evaluating...
ROAUC scores:  {'Train': 0.8733037710189819, 'Validation': 0.8717614412307739}
=====Epoch 68
Training...


Iteration: 100%|██████████| 2188/2188 [00:32<00:00, 67.50it/s]


Average training loss: 0.5069428747026097
Evaluating...
ROAUC scores:  {'Train': 0.8747228384017944, 'Validation': 0.8733593821525574}
=====Epoch 69
Training...


Iteration: 100%|██████████| 2188/2188 [00:32<00:00, 67.79it/s]


Average training loss: 0.5072115294539951
Evaluating...
ROAUC scores:  {'Train': 0.8749822378158569, 'Validation': 0.8734769821166992}
=====Epoch 70
Training...


Iteration: 100%|██████████| 2188/2188 [00:32<00:00, 67.83it/s]


Average training loss: 0.5073927987836833
Evaluating...
ROAUC scores:  {'Train': 0.8744055032730103, 'Validation': 0.8724023103713989}
=====Epoch 71
Training...


Iteration: 100%|██████████| 2188/2188 [00:33<00:00, 64.88it/s]


Average training loss: 0.5070534109442718
Evaluating...
ROAUC scores:  {'Train': 0.8741073608398438, 'Validation': 0.8726912140846252}
=====Epoch 72
Training...


Iteration: 100%|██████████| 2188/2188 [00:35<00:00, 62.02it/s]


Average training loss: 0.5066361276933437
Evaluating...
ROAUC scores:  {'Train': 0.8755701780319214, 'Validation': 0.8738547563552856}
=====Epoch 73
Training...


Iteration: 100%|██████████| 2188/2188 [00:41<00:00, 52.17it/s]


Average training loss: 0.507061451932303
Evaluating...
ROAUC scores:  {'Train': 0.873927652835846, 'Validation': 0.8719356060028076}
=====Epoch 74
Training...


Iteration: 100%|██████████| 2188/2188 [00:42<00:00, 51.99it/s]


Average training loss: 0.506802327317991
Evaluating...
ROAUC scores:  {'Train': 0.8750592470169067, 'Validation': 0.8738869428634644}
=====Epoch 75
Training...


Iteration: 100%|██████████| 2188/2188 [00:42<00:00, 51.66it/s]


Average training loss: 0.5069926199545808
Evaluating...
ROAUC scores:  {'Train': 0.8753765225410461, 'Validation': 0.8736090660095215}

Finished training!

ROAUC Test score: 0.873801589012146


## Training Transformers based GPS model

### Training with GPE summed with Random-walk embedding

In [15]:
# importing dataset for GPS 
transform = T.AddRandomWalkPE(walk_length=20, attr_name='extraPE') # adding random walk positional encoding
QG_jets_dataset = import_dataset(name="QG_Jets",
                             pre_transform=transform)

GPS_model = GNN(num_classes = 2, num_layer = 2,num_pre_fnn_layers=1,num_post_fnn_layers=2,hasPos=True,input_edge_dim = 1,num_coords=2, 
                input_spec_fts_dim=3, gnn_type = 'gps', emb_dim=300, extraPE_dim=20, drop_ratio = 0.3).to(device)
optimizer = optim.Adam(GPS_model.parameters(), lr=1e-3)

train_model(GPS_model,optimizer,QG_jets_dataset,batch_size=16)

=====Epoch 1
Training...


Iteration: 100%|██████████| 4375/4375 [02:27<00:00, 29.62it/s]


Average training loss: 0.5423124539034707
Evaluating...
ROAUC scores:  {'Train': 0.852789580821991, 'Validation': 0.8515341281890869}
=====Epoch 2
Training...


Iteration: 100%|██████████| 4375/4375 [02:38<00:00, 27.63it/s]


Average training loss: 0.5364028421878815
Evaluating...
ROAUC scores:  {'Train': 0.8397599458694458, 'Validation': 0.8411649465560913}
=====Epoch 3
Training...


Iteration: 100%|██████████| 4375/4375 [04:45<00:00, 15.32it/s]


Average training loss: 0.5315140270437513
Evaluating...
ROAUC scores:  {'Train': 0.8583822250366211, 'Validation': 0.8584153652191162}
=====Epoch 4
Training...


Iteration: 100%|██████████| 4375/4375 [04:51<00:00, 14.99it/s]


Average training loss: 0.5294418691090175
Evaluating...
ROAUC scores:  {'Train': 0.8597381114959717, 'Validation': 0.8600350022315979}
=====Epoch 5
Training...


Iteration: 100%|██████████| 4375/4375 [04:56<00:00, 14.77it/s]


Average training loss: 0.5270990559509823
Evaluating...
ROAUC scores:  {'Train': 0.8588021993637085, 'Validation': 0.8587085008621216}
=====Epoch 6
Training...


Iteration: 100%|██████████| 4375/4375 [04:53<00:00, 14.92it/s]


Average training loss: 0.5272001102719989
Evaluating...
ROAUC scores:  {'Train': 0.8453062772750854, 'Validation': 0.8457729816436768}
=====Epoch 7
Training...


Iteration: 100%|██████████| 4375/4375 [04:52<00:00, 14.94it/s]


Average training loss: 0.5243762910502298
Evaluating...
ROAUC scores:  {'Train': 0.8573293685913086, 'Validation': 0.8576071858406067}
=====Epoch 8
Training...


Iteration: 100%|██████████| 4375/4375 [03:48<00:00, 19.13it/s]


Average training loss: 0.5242033536502293
Evaluating...
ROAUC scores:  {'Train': 0.8600507974624634, 'Validation': 0.8588091135025024}
=====Epoch 9
Training...


Iteration: 100%|██████████| 4375/4375 [03:50<00:00, 18.94it/s]


Average training loss: 0.5253698536872864
Evaluating...
ROAUC scores:  {'Train': 0.8589409589767456, 'Validation': 0.8582242727279663}
=====Epoch 10
Training...


Iteration: 100%|██████████| 4375/4375 [03:57<00:00, 18.41it/s]


Average training loss: 0.521920531470435
Evaluating...
ROAUC scores:  {'Train': 0.8589451909065247, 'Validation': 0.8587508201599121}
=====Epoch 11
Training...


Iteration: 100%|██████████| 4375/4375 [03:59<00:00, 18.25it/s]


Average training loss: 0.5204373253549849
Evaluating...
ROAUC scores:  {'Train': 0.8629381656646729, 'Validation': 0.8626424074172974}
=====Epoch 12
Training...


Iteration: 100%|██████████| 4375/4375 [04:31<00:00, 16.12it/s]


Average training loss: 0.5219879601682935
Evaluating...
ROAUC scores:  {'Train': 0.8620893359184265, 'Validation': 0.8612008094787598}
=====Epoch 13
Training...


Iteration: 100%|██████████| 4375/4375 [04:18<00:00, 16.95it/s]


Average training loss: 0.5226166635377066
Evaluating...
ROAUC scores:  {'Train': 0.860072135925293, 'Validation': 0.8596490621566772}
=====Epoch 14
Training...


Iteration: 100%|██████████| 4375/4375 [03:54<00:00, 18.63it/s]


Average training loss: 0.5220360334600721
Evaluating...
ROAUC scores:  {'Train': 0.8631078004837036, 'Validation': 0.8622255325317383}
=====Epoch 15
Training...


Iteration: 100%|██████████| 4375/4375 [04:05<00:00, 17.84it/s]


Average training loss: 0.5202655040740967
Evaluating...
ROAUC scores:  {'Train': 0.86246657371521, 'Validation': 0.8618073463439941}
=====Epoch 16
Training...


Iteration: 100%|██████████| 4375/4375 [03:40<00:00, 19.87it/s]


Average training loss: 0.519895736353738
Evaluating...
ROAUC scores:  {'Train': 0.8613380789756775, 'Validation': 0.8605377674102783}
=====Epoch 17
Training...


Iteration: 100%|██████████| 4375/4375 [03:39<00:00, 19.94it/s]


Average training loss: 0.5199836931228637
Evaluating...
ROAUC scores:  {'Train': 0.8619658946990967, 'Validation': 0.8610340356826782}
=====Epoch 18
Training...


Iteration: 100%|██████████| 4375/4375 [03:33<00:00, 20.46it/s]


Average training loss: 0.5195532656397138
Evaluating...
ROAUC scores:  {'Train': 0.8647400140762329, 'Validation': 0.8635700941085815}
=====Epoch 19
Training...


Iteration: 100%|██████████| 4375/4375 [03:30<00:00, 20.74it/s]


Average training loss: 0.51881748919487
Evaluating...
ROAUC scores:  {'Train': 0.8639123439788818, 'Validation': 0.8623226881027222}
=====Epoch 20
Training...


Iteration: 100%|██████████| 4375/4375 [03:24<00:00, 21.38it/s]


Average training loss: 0.5183297642844064
Evaluating...
ROAUC scores:  {'Train': 0.8616541624069214, 'Validation': 0.8603137731552124}
=====Epoch 21
Training...


Iteration: 100%|██████████| 4375/4375 [03:04<00:00, 23.68it/s]


Average training loss: 0.5187926824092866
Evaluating...
ROAUC scores:  {'Train': 0.8615453243255615, 'Validation': 0.8615857362747192}
=====Epoch 22
Training...


Iteration: 100%|██████████| 4375/4375 [03:16<00:00, 22.25it/s]


Average training loss: 0.5194772891521454
Evaluating...
ROAUC scores:  {'Train': 0.8587016463279724, 'Validation': 0.857795238494873}
=====Epoch 23
Training...


Iteration: 100%|██████████| 4375/4375 [03:19<00:00, 21.93it/s]


Average training loss: 0.521063367169244
Evaluating...
ROAUC scores:  {'Train': 0.8547612428665161, 'Validation': 0.854282557964325}
=====Epoch 24
Training...


Iteration: 100%|██████████| 4375/4375 [03:18<00:00, 22.04it/s]


Average training loss: 0.5222595492294857
Evaluating...
ROAUC scores:  {'Train': 0.860120415687561, 'Validation': 0.8588464856147766}
=====Epoch 25
Training...


Iteration: 100%|██████████| 4375/4375 [03:10<00:00, 22.94it/s]


Average training loss: 0.5187579287528992
Evaluating...
ROAUC scores:  {'Train': 0.8642932176589966, 'Validation': 0.8631460666656494}
=====Epoch 26
Training...


Iteration: 100%|██████████| 4375/4375 [03:03<00:00, 23.86it/s]


Average training loss: 0.5188018960339682
Evaluating...


=====Epoch 27
Training...


Iteration: 100%|██████████| 4375/4375 [02:22<00:00, 30.64it/s]


Average training loss: 0.518549009690966
Evaluating...
ROAUC scores:  {'Train': 0.8634494543075562, 'Validation': 0.8636170625686646}
=====Epoch 28
Training...


Iteration: 100%|██████████| 4375/4375 [02:29<00:00, 29.21it/s]


Average training loss: 0.5168801335743496
Evaluating...
ROAUC scores:  {'Train': 0.8617868423461914, 'Validation': 0.8609186410903931}
=====Epoch 29
Training...


Iteration: 100%|██████████| 4375/4375 [02:42<00:00, 26.89it/s]


Average training loss: 0.5247454071726118
Evaluating...
ROAUC scores:  {'Train': 0.8641334772109985, 'Validation': 0.863054633140564}
=====Epoch 30
Training...


Iteration: 100%|██████████| 4375/4375 [03:01<00:00, 24.16it/s]


Average training loss: 0.5181355385507856
Evaluating...
ROAUC scores:  {'Train': 0.8647853136062622, 'Validation': 0.8637218475341797}
=====Epoch 31
Training...


Iteration: 100%|██████████| 4375/4375 [03:08<00:00, 23.15it/s]


Average training loss: 0.5166640813895634
Evaluating...
ROAUC scores:  {'Train': 0.8633993864059448, 'Validation': 0.8626481294631958}
=====Epoch 32
Training...


Iteration: 100%|██████████| 4375/4375 [02:59<00:00, 24.40it/s]


Average training loss: 0.5178880507673536
Evaluating...
ROAUC scores:  {'Train': 0.8660768270492554, 'Validation': 0.8654994964599609}
=====Epoch 33
Training...


Iteration: 100%|██████████| 4375/4375 [03:15<00:00, 22.33it/s]


Average training loss: 0.5162514325005667
Evaluating...
ROAUC scores:  {'Train': 0.8663890957832336, 'Validation': 0.8661628365516663}
=====Epoch 34
Training...


Iteration: 100%|██████████| 4375/4375 [03:10<00:00, 22.95it/s]


Average training loss: 0.5178138997146061
Evaluating...
ROAUC scores:  {'Train': 0.8637841939926147, 'Validation': 0.8640825748443604}
=====Epoch 35
Training...


Iteration: 100%|██████████| 4375/4375 [03:04<00:00, 23.70it/s]


Average training loss: 0.5165479603426797
Evaluating...
ROAUC scores:  {'Train': 0.8632340431213379, 'Validation': 0.8620650172233582}
=====Epoch 36
Training...


Iteration: 100%|██████████| 4375/4375 [03:05<00:00, 23.64it/s]


Average training loss: 0.5180678240912301
Evaluating...
ROAUC scores:  {'Train': 0.8652162551879883, 'Validation': 0.8635404706001282}
=====Epoch 37
Training...


Iteration: 100%|██████████| 4375/4375 [03:11<00:00, 22.85it/s]


Average training loss: 0.5171776694161552
Evaluating...
ROAUC scores:  {'Train': 0.8615982532501221, 'Validation': 0.8612436056137085}
=====Epoch 38
Training...


Iteration: 100%|██████████| 4375/4375 [03:13<00:00, 22.63it/s]


Average training loss: 0.5156425375870296
Evaluating...
ROAUC scores:  {'Train': 0.8658615350723267, 'Validation': 0.8653901815414429}
=====Epoch 39
Training...


Iteration: 100%|██████████| 4375/4375 [03:08<00:00, 23.19it/s]


Average training loss: 0.5166345295224871
Evaluating...
ROAUC scores:  {'Train': 0.8616796135902405, 'Validation': 0.8607050180435181}
=====Epoch 40
Training...


Iteration: 100%|██████████| 4375/4375 [03:21<00:00, 21.73it/s]


Average training loss: 0.517081686108453
Evaluating...
ROAUC scores:  {'Train': 0.8651921153068542, 'Validation': 0.8644276857376099}
=====Epoch 41
Training...


Iteration: 100%|██████████| 4375/4375 [02:59<00:00, 24.41it/s]


Average training loss: 0.5177349919864109
Evaluating...
ROAUC scores:  {'Train': 0.8640519380569458, 'Validation': 0.8640450239181519}
=====Epoch 42
Training...


Iteration: 100%|██████████| 4375/4375 [02:57<00:00, 24.62it/s]


Average training loss: 0.5167003959315164
Evaluating...
ROAUC scores:  {'Train': 0.8638925552368164, 'Validation': 0.8640033006668091}
=====Epoch 43
Training...


Iteration: 100%|██████████| 4375/4375 [02:56<00:00, 24.80it/s]


Average training loss: 0.5176188937459674
Evaluating...
ROAUC scores:  {'Train': 0.8630677461624146, 'Validation': 0.8623066544532776}
=====Epoch 44
Training...


Iteration: 100%|██████████| 4375/4375 [03:06<00:00, 23.52it/s]


Average training loss: 0.5166484396117074
Evaluating...
ROAUC scores:  {'Train': 0.863801121711731, 'Validation': 0.8637722730636597}
=====Epoch 45
Training...


Iteration: 100%|██████████| 4375/4375 [03:07<00:00, 23.33it/s]


Average training loss: 0.5179656341484615
Evaluating...
ROAUC scores:  {'Train': 0.8652269244194031, 'Validation': 0.8645574450492859}
=====Epoch 46
Training...


Iteration: 100%|██████████| 4375/4375 [02:57<00:00, 24.58it/s]


Average training loss: 0.5165763446671622
Evaluating...
ROAUC scores:  {'Train': 0.8636411428451538, 'Validation': 0.8630264401435852}
=====Epoch 47
Training...


Iteration: 100%|██████████| 4375/4375 [03:29<00:00, 20.90it/s]


Average training loss: 0.51865568888528
Evaluating...
ROAUC scores:  {'Train': 0.8650760650634766, 'Validation': 0.8650192022323608}
=====Epoch 48
Training...


Iteration: 100%|██████████| 4375/4375 [03:26<00:00, 21.20it/s]


Average training loss: 0.5180612385477339
Evaluating...
ROAUC scores:  {'Train': 0.8584153652191162, 'Validation': 0.8589514493942261}
=====Epoch 49
Training...


Iteration: 100%|██████████| 4375/4375 [03:25<00:00, 21.32it/s]


Average training loss: 0.5204483228887831
Evaluating...
ROAUC scores:  {'Train': 0.864727795124054, 'Validation': 0.8632930517196655}
=====Epoch 50
Training...


Iteration: 100%|██████████| 4375/4375 [03:21<00:00, 21.68it/s]


Average training loss: 0.5154471651826587
Evaluating...
ROAUC scores:  {'Train': 0.8637420535087585, 'Validation': 0.8634814023971558}
=====Epoch 51
Training...


Iteration: 100%|██████████| 4375/4375 [03:20<00:00, 21.86it/s]


Average training loss: 0.5161444251605443
Evaluating...
ROAUC scores:  {'Train': 0.8658727407455444, 'Validation': 0.864318311214447}
=====Epoch 52
Training...


Iteration: 100%|██████████| 4375/4375 [03:37<00:00, 20.16it/s]


Average training loss: 0.5167650491510118
Evaluating...
ROAUC scores:  {'Train': 0.8631484508514404, 'Validation': 0.8617367744445801}
=====Epoch 53
Training...


Iteration: 100%|██████████| 4375/4375 [03:38<00:00, 19.98it/s]


Average training loss: 0.5158762222017561
Evaluating...
ROAUC scores:  {'Train': 0.8654677867889404, 'Validation': 0.8645888566970825}
=====Epoch 54
Training...


Iteration: 100%|██████████| 4375/4375 [03:41<00:00, 19.76it/s]


Average training loss: 0.5166287628310068
Evaluating...
ROAUC scores:  {'Train': 0.863450288772583, 'Validation': 0.8627854585647583}
=====Epoch 55
Training...


Iteration: 100%|██████████| 4375/4375 [03:40<00:00, 19.86it/s]


Average training loss: 0.5163293659005846
Evaluating...
ROAUC scores:  {'Train': 0.8645160794258118, 'Validation': 0.8633794784545898}
=====Epoch 56
Training...


Iteration: 100%|██████████| 4375/4375 [04:17<00:00, 17.01it/s]


Average training loss: 0.5161443832261222
Evaluating...
ROAUC scores:  {'Train': 0.8655692338943481, 'Validation': 0.8649299740791321}
=====Epoch 57
Training...


Iteration: 100%|██████████| 4375/4375 [02:49<00:00, 25.75it/s]


Average training loss: 0.5168030478341239
Evaluating...
ROAUC scores:  {'Train': 0.8657070994377136, 'Validation': 0.865304708480835}
=====Epoch 58
Training...


Iteration: 100%|██████████| 4375/4375 [04:00<00:00, 18.22it/s]


Average training loss: 0.5167851876871926
Evaluating...
ROAUC scores:  {'Train': 0.8629458546638489, 'Validation': 0.8622485995292664}
=====Epoch 59
Training...


Iteration: 100%|██████████| 4375/4375 [07:27<00:00,  9.79it/s]


Average training loss: 0.5167347177777971
Evaluating...
ROAUC scores:  {'Train': 0.8640303611755371, 'Validation': 0.8639055490493774}
=====Epoch 60
Training...


Iteration: 100%|██████████| 4375/4375 [09:11<00:00,  7.94it/s]


Average training loss: 0.5170260637147086
Evaluating...
ROAUC scores:  {'Train': 0.8649024963378906, 'Validation': 0.8641939163208008}
=====Epoch 61
Training...


Iteration: 100%|██████████| 4375/4375 [11:05<00:00,  6.58it/s]


Average training loss: 0.5165872487613133
Evaluating...
ROAUC scores:  {'Train': 0.8636777400970459, 'Validation': 0.8633352518081665}
=====Epoch 62
Training...


Iteration: 100%|██████████| 4375/4375 [12:12<00:00,  5.97it/s]


Average training loss: 0.5171674356119973
Evaluating...
ROAUC scores:  {'Train': 0.8600934743881226, 'Validation': 0.8594541549682617}
=====Epoch 63
Training...


Iteration: 100%|██████████| 4375/4375 [13:18<00:00,  5.48it/s]


Average training loss: 0.516054750422069
Evaluating...
ROAUC scores:  {'Train': 0.8628381490707397, 'Validation': 0.8624252676963806}
=====Epoch 64
Training...


Iteration: 100%|██████████| 4375/4375 [15:04<00:00,  4.84it/s]


Average training loss: 0.5157286580426352
Evaluating...
ROAUC scores:  {'Train': 0.8637243509292603, 'Validation': 0.8635921478271484}
=====Epoch 65
Training...


Iteration: 100%|██████████| 4375/4375 [15:36<00:00,  4.67it/s]


Average training loss: 0.5181633014951433
Evaluating...
ROAUC scores:  {'Train': 0.8559607267379761, 'Validation': 0.8566691875457764}
=====Epoch 66
Training...


Iteration: 100%|██████████| 4375/4375 [15:23<00:00,  4.74it/s]


Average training loss: 0.5197811382974897
Evaluating...
ROAUC scores:  {'Train': 0.8610608577728271, 'Validation': 0.8608206510543823}
=====Epoch 67
Training...


Iteration: 100%|██████████| 4375/4375 [14:49<00:00,  4.92it/s]


Average training loss: 0.5175616487366813
Evaluating...
ROAUC scores:  {'Train': 0.8646191358566284, 'Validation': 0.863286554813385}
=====Epoch 68
Training...


Iteration: 100%|██████████| 4375/4375 [14:48<00:00,  4.92it/s]


Average training loss: 0.5159902996335711
Evaluating...
ROAUC scores:  {'Train': 0.8656014800071716, 'Validation': 0.8651618957519531}
=====Epoch 69
Training...


Iteration: 100%|██████████| 4375/4375 [15:02<00:00,  4.85it/s]


Average training loss: 0.5158113384587424
Evaluating...
ROAUC scores:  {'Train': 0.8643971681594849, 'Validation': 0.8643567562103271}
=====Epoch 70
Training...


Iteration: 100%|██████████| 4375/4375 [14:51<00:00,  4.91it/s]


Average training loss: 0.5166520882538387
Evaluating...
ROAUC scores:  {'Train': 0.8637818098068237, 'Validation': 0.8635343313217163}
=====Epoch 71
Training...


Iteration: 100%|██████████| 4375/4375 [14:48<00:00,  4.92it/s]


Average training loss: 0.515364158480508
Evaluating...
ROAUC scores:  {'Train': 0.8667749166488647, 'Validation': 0.866027295589447}
=====Epoch 72
Training...


Iteration: 100%|██████████| 4375/4375 [10:49<00:00,  6.74it/s]


Average training loss: 0.5160908549853733
Evaluating...
ROAUC scores:  {'Train': 0.8660194873809814, 'Validation': 0.865546464920044}
=====Epoch 73
Training...


Iteration: 100%|██████████| 4375/4375 [04:54<00:00, 14.84it/s]


Average training loss: 0.5176176962103163
Evaluating...
ROAUC scores:  {'Train': 0.8635982871055603, 'Validation': 0.8630688190460205}
=====Epoch 74
Training...


Iteration: 100%|██████████| 4375/4375 [04:52<00:00, 14.96it/s]


Average training loss: 0.5163173939568656
Evaluating...
ROAUC scores:  {'Train': 0.86283278465271, 'Validation': 0.8618417382240295}
=====Epoch 75
Training...


Iteration: 100%|██████████| 4375/4375 [04:52<00:00, 14.94it/s]


Average training loss: 0.5187882170268467
Evaluating...
ROAUC scores:  {'Train': 0.8622004985809326, 'Validation': 0.8613588809967041}

Finished training!

ROAUC Test score: 0.862023115158081
