In [1]:
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0"

import torch
import torch.nn as nn
from torch_geometric.data import DataLoader
import argparse
import numpy as np
import random
import ogb
from ogb.graphproppred import PygGraphPropPredDataset, Evaluator
import torch.nn.functional as F
from graph_transformer import GT
import os
import datetime
from tqdm import tqdm

parser = argparse.ArgumentParser(description='PyTorch implementation of relative positional encodings and relation-aware self-attention for graph Transformers')
args = parser.parse_args("")
args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("device:", args.device)

In /opt/conda/lib/python3.6/site-packages/matplotlib/mpl-data/stylelib/_classic_test.mplstyle: 
The text.latex.preview rcparam was deprecated in Matplotlib 3.3 and will be removed two minor releases later.
In /opt/conda/lib/python3.6/site-packages/matplotlib/mpl-data/stylelib/_classic_test.mplstyle: 
The mathtext.fallback_to_cm rcparam was deprecated in Matplotlib 3.3 and will be removed two minor releases later.
In /opt/conda/lib/python3.6/site-packages/matplotlib/mpl-data/stylelib/_classic_test.mplstyle: Support for setting the 'mathtext.fallback_to_cm' rcParam is deprecated since 3.3 and will be removed two minor releases later; use 'mathtext.fallback : 'cm' instead.
In /opt/conda/lib/python3.6/site-packages/matplotlib/mpl-data/stylelib/_classic_test.mplstyle: 
The validate_bool_maybe_none function was deprecated in Matplotlib 3.3 and will be removed two minor releases later.
In /opt/conda/lib/python3.6/site-packages/matplotlib/mpl-data/stylelib/_classic_test.mplstyle: 
The savefig.

device: cuda


In [2]:
args.dataset = 'ogbg-molhiv'
args.n_classes = 1
args.lr = 2e-4
args.n_hid = 512
args.n_heads = 8
args.n_layer = 4
args.dropout = 0.2
args.num_epochs = 100
args.k_hop_neighbors = 3
args.weight_decay = 1e-2
args.node_dim = 9
args.edge_dim = 3
args.bsz      = 128

In [3]:
import torch
from networkx.algorithms.shortest_paths.generic import shortest_path
from networkx.algorithms.approximation.connectivity import all_pairs_node_connectivity
from networkx.algorithms.clique import node_clique_number
from networkx.algorithms.centrality import betweenness_centrality, edge_betweenness_centrality
from torch_geometric.utils.convert import to_networkx
from torch_geometric.data import Data
from torch_geometric.utils import to_dense_adj, dense_to_sparse

def pre_process(d):
    node_size = d.x.size(0)
    #     TODO: add summary node that connects to all the other nodes.
    
    #     Construct networkX type of original graph for different metrics
    d_nx = to_networkx(d, to_undirected=True)
    
    #     Augment the graph to be K-hop graph
    dense_orig_adj = to_dense_adj(d.edge_index, max_num_nodes=node_size).squeeze(dim=0)
    for k in range(args.k_hop_neighbors):
        dense_orig_adj = F.hardtanh(torch.mm(dense_orig_adj, dense_orig_adj))
    d.edge_index = dense_to_sparse(dense_orig_adj)[0]
    
    #     Calculate structural feature by the ORIGNAL graph, add them to new edge set.
    sd_edge_attr = shortest_distances(d_nx, d.edge_index)
    cn_edge_attr = node_connectivity(d_nx, d.edge_index)
    return Data(x=d.x, y=d.y, edge_index=d.edge_index, edge_attr=d.edge_attr, \
         sd_edge_attr=sd_edge_attr, cn_edge_attr=cn_edge_attr)
    
def shortest_distances(d_nx, edge_index):
    edge_attr = []
    p = shortest_path(d_nx)
    for s, t in edge_index.t().tolist():
        if s in p and t in p[s]:
            edge_attr += [len(p[s][t]) - 1]
        else:
            edge_attr += [0]
        
    return torch.LongTensor(edge_attr)

def node_connectivity(d_nx, edge_index):
    edge_attr = []
    p = all_pairs_node_connectivity(d_nx)
    for s, t in edge_index.t().tolist():
        if s in p and t in p[s]:
            edge_attr += [p[s][t]]
        else:
            edge_attr += [0]
    return torch.LongTensor(edge_attr)

def get_n_params(model):
    pp=0
    for p in list(model.parameters()):
        nn=1
        for s in list(p.size()):
            nn = nn*s
        pp += nn
    return pp

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch_geometric.nn import GCNConv, GATConv
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.inits import glorot, uniform
from torch_geometric.utils import softmax
from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder
import math
from torch_geometric.nn import global_mean_pool


class RelEncoding(nn.Module):
    def __init__(self, n_hid, max_len = 240, dropout = 0.2):
        super(RelEncoding, self).__init__()
        self.emb = nn.Embedding(max_len, n_hid)
        self.drop = nn.Dropout(dropout)
        self.emb.weight.data.uniform_(-0.1, 0.1)
    def forward(self, t):
        return self.drop(self.emb(t))

class GT(nn.Module):
    def __init__(self, n_hid, n_out, n_heads, n_layers, edge_dim_dict, dropout = 0.2):
        super(GT, self).__init__()
        self.node_encoder = AtomEncoder(emb_dim=n_hid)
        self.n_hid     = n_hid
        self.n_out     = n_out
        self.drop      = nn.Dropout(dropout)
        self.gcs       = nn.ModuleList([GT_Layer(n_hid, n_heads, edge_dim_dict, dropout)\
                                      for _ in range(n_layers)])
        self.out       = nn.Linear(n_hid, n_out)

    def forward(self, node_attr, edge_index, cn_edge_attr, sd_edge_attr, batch_idx):
        node_rep = self.node_encoder(node_attr)
        for gc in self.gcs:
            node_rep = gc(node_rep, edge_index, cn_edge_attr, sd_edge_attr)
        return self.out(global_mean_pool(node_rep, batch_idx))  

class GT_Layer(MessagePassing):
    def __init__(self, n_hid, n_heads, edge_dim_dict, dropout = 0.2, **kwargs):
        super(GT_Layer, self).__init__(node_dim=0, aggr='add', **kwargs)

        self.n_hid         = n_hid
        self.n_heads       = n_heads
        self.d_k           = n_hid // n_heads
        self.sqrt_dk       = math.sqrt(self.d_k)
        self.att           = None
        
        
        self.k_linear   = nn.Linear(n_hid,   n_hid)
        self.q_linear   = nn.Linear(n_hid,   n_hid)
        self.v_linear   = nn.Linear(n_hid,   n_hid)
        self.a_linear   = nn.Linear(n_hid,   n_hid)
        self.norm       = nn.LayerNorm(n_hid)
        self.drop       = nn.Dropout(dropout)
        
        self.struc_enc = nn.ModuleDict({
            key : RelEncoding(max_len = edge_dim_dict[key], n_hid = n_hid, dropout = dropout)
                for key in edge_dim_dict
        })
        
        self.mid_linear  = nn.Linear(n_hid,  n_hid * 2)
        self.out_linear  = nn.Linear(n_hid * 2,  n_hid)
        self.out_norm    = nn.LayerNorm(n_hid)
        
    def forward(self, node_inp, edge_index, cn_edge_attr, sd_edge_attr):
        return self.propagate(edge_index, node_inp=node_inp, \
                              cn_edge_attr = cn_edge_attr, sd_edge_attr = sd_edge_attr)

    def message(self, edge_index_i, node_inp_i, node_inp_j, cn_edge_attr, sd_edge_attr):
        '''
            j: source, i: target; <j, i>
        '''
        data_size = edge_index_i.size(0)
        '''
            Create Attention and Message tensor beforehand.
        '''
                
        target_node_vec = node_inp_i
        source_node_vec = node_inp_j + self.struc_enc['cn'](cn_edge_attr) + self.struc_enc['sd'](sd_edge_attr)

        q_mat = self.q_linear(target_node_vec).view(-1, self.n_heads, self.d_k)
        k_mat = self.k_linear(source_node_vec).view(-1, self.n_heads, self.d_k)
        v_mat = self.v_linear(source_node_vec).view(-1, self.n_heads, self.d_k)
        
        '''
            Softmax based on target node's id (edge_index_i). Store attention value in self.att for later visualization.
        '''
        self.att = self.dropout(softmax((q_mat * k_mat).sum(dim=-1) / self.sqrt_dk, edge_index_i))
        res = v_mat * self.att.view(-1, self.n_heads, 1)
        return res.view(-1, self.n_hid)


    def update(self, aggr_out, node_inp):
        trans_out = self.norm(self.drop(self.a_linear(F.gelu(aggr_out))) + node_inp)
        trans_out = self.out_norm(self.drop(self.out_linear(F.gelu(self.mid_linear(trans_out)))) + trans_out)
        return trans_out

In [5]:
print("Loading data...")
print("dataset: {} ".format(args.dataset))
dataset = PygGraphPropPredDataset(name=args.dataset, pre_transform=pre_process)
evaluator = Evaluator(name=args.dataset)
split_idx = dataset.get_idx_split()
edge_dim_dict = {'sd': dataset.data.sd_edge_attr.max().int().item() + 1, \
                 'cn': dataset.data.cn_edge_attr.max().int().item() + 1}
model = GT(args.n_hid, args.n_classes, args.n_heads, args.n_layer, edge_dim_dict, args.dropout).to(args.device)

Loading data...
dataset: ogbg-molhiv 


In [6]:
from transformers.optimization import AdamW
def get_optimizer(model: nn.Module, learning_rate: float = 1e-4, adam_eps: float = 1e-6,
                  weight_decay: float = 0.0, ) -> torch.optim.Optimizer:
    no_decay = ['bias', 'LayerNorm.weight']

    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
         'weight_decay': weight_decay},
        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate, eps=adam_eps)
    return optimizer

In [7]:
print('Model #Params: %d' % get_n_params(model))

criterion = torch.nn.BCEWithLogitsLoss(reduction = "mean")

optimizer = get_optimizer(model, weight_decay = args.weight_decay)

train_loader = DataLoader(dataset[split_idx["train"]], batch_size=args.bsz, shuffle=True)
valid_loader = DataLoader(dataset[split_idx["valid"]], batch_size=args.bsz, shuffle=False)
test_loader  = DataLoader(dataset[split_idx["test"]],  batch_size=args.bsz, shuffle=False)

scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=args.lr, pct_start = 0.05,\
        steps_per_epoch=len(train_loader), epochs = args.num_epochs, anneal_strategy = 'linear')

Model #Params: 8530945


In [None]:
stats = []
for epoch in range(args.num_epochs):
    model.train()
    train_loss = []
    for data in tqdm(train_loader):
        data.to(args.device)
        out = model(data.x, data.edge_index, data.cn_edge_attr, data.sd_edge_attr, data.batch)
        loss = criterion(out, data.y.float())
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        
        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()
        
        train_loss += [loss.item()]

    model.eval()
    with torch.no_grad():
        valid_loss = []
        y_true = []
        y_scores = []
        for data in valid_loader:
            data.to(args.device)
            out = model(data.x, data.edge_index, data.cn_edge_attr, data.sd_edge_attr, data.batch)

            loss = criterion(out, data.y.float())
            valid_loss += [loss.item()]

            y_true += [data.y]
            y_scores += [out]

        input_dict = {"y_true": torch.cat(y_true), "y_pred": torch.cat(y_scores)}
        valid_rocauc = evaluator.eval(input_dict)['rocauc']
        
        
        test_loss = []
        y_true = []
        y_scores = []
        for data in test_loader:
            data.to(args.device)
            out = model(data.x, data.edge_index, data.cn_edge_attr, data.sd_edge_attr, data.batch)

            loss = criterion(out, data.y.float())
            test_loss += [loss.item()]

            y_true += [data.y]
            y_scores += [out]

        input_dict = {"y_true": torch.cat(y_true), "y_pred": torch.cat(y_scores)}
        test_rocauc = evaluator.eval(input_dict)['rocauc']
    
    print('Epoch %d: LR: %.5f, Train loss: %.3f Valid loss: %.3f  Valid ROC-AUC: %.3f Test loss: %.3f  Test ROC-AUC: %.3f' \
          % (epoch, optimizer.param_groups[0]['lr'], np.average(train_loss), np.average(valid_loss), \
            valid_rocauc, np.average(test_loss), test_rocauc))
    stats += [[epoch, np.average(train_loss), np.average(valid_loss), valid_rocauc, np.average(test_loss), test_rocauc]]

100%|██████████| 258/258 [00:58<00:00,  4.42it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 0: LR: 0.00005, Train loss: 0.171 Valid loss: 0.101  Valid ROC-AUC: 0.659 Test loss: 0.136  Test ROC-AUC: 0.642


100%|██████████| 258/258 [00:53<00:00,  4.79it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 1: LR: 0.00008, Train loss: 0.152 Valid loss: 0.087  Valid ROC-AUC: 0.718 Test loss: 0.127  Test ROC-AUC: 0.671


100%|██████████| 258/258 [00:53<00:00,  4.82it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 2: LR: 0.00012, Train loss: 0.141 Valid loss: 0.120  Valid ROC-AUC: 0.701 Test loss: 0.147  Test ROC-AUC: 0.713


100%|██████████| 258/258 [00:53<00:00,  4.78it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 3: LR: 0.00016, Train loss: 0.136 Valid loss: 0.081  Valid ROC-AUC: 0.778 Test loss: 0.121  Test ROC-AUC: 0.745


100%|██████████| 258/258 [00:53<00:00,  4.85it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 4: LR: 0.00020, Train loss: 0.131 Valid loss: 0.080  Valid ROC-AUC: 0.802 Test loss: 0.127  Test ROC-AUC: 0.776


100%|██████████| 258/258 [00:54<00:00,  4.77it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 5: LR: 0.00020, Train loss: 0.126 Valid loss: 0.074  Valid ROC-AUC: 0.789 Test loss: 0.118  Test ROC-AUC: 0.780


100%|██████████| 258/258 [00:52<00:00,  4.87it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 6: LR: 0.00020, Train loss: 0.124 Valid loss: 0.075  Valid ROC-AUC: 0.806 Test loss: 0.114  Test ROC-AUC: 0.799


100%|██████████| 258/258 [00:53<00:00,  4.78it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 7: LR: 0.00019, Train loss: 0.117 Valid loss: 0.076  Valid ROC-AUC: 0.793 Test loss: 0.117  Test ROC-AUC: 0.777


100%|██████████| 258/258 [00:52<00:00,  4.88it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 8: LR: 0.00019, Train loss: 0.112 Valid loss: 0.079  Valid ROC-AUC: 0.782 Test loss: 0.124  Test ROC-AUC: 0.782


100%|██████████| 258/258 [00:53<00:00,  4.82it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 9: LR: 0.00019, Train loss: 0.109 Valid loss: 0.072  Valid ROC-AUC: 0.808 Test loss: 0.111  Test ROC-AUC: 0.808


100%|██████████| 258/258 [00:52<00:00,  4.90it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 10: LR: 0.00019, Train loss: 0.108 Valid loss: 0.076  Valid ROC-AUC: 0.777 Test loss: 0.119  Test ROC-AUC: 0.783


100%|██████████| 258/258 [00:53<00:00,  4.87it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 11: LR: 0.00019, Train loss: 0.107 Valid loss: 0.070  Valid ROC-AUC: 0.808 Test loss: 0.114  Test ROC-AUC: 0.794


100%|██████████| 258/258 [00:53<00:00,  4.84it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 12: LR: 0.00018, Train loss: 0.103 Valid loss: 0.077  Valid ROC-AUC: 0.787 Test loss: 0.116  Test ROC-AUC: 0.784


100%|██████████| 258/258 [00:53<00:00,  4.83it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 13: LR: 0.00018, Train loss: 0.102 Valid loss: 0.072  Valid ROC-AUC: 0.804 Test loss: 0.116  Test ROC-AUC: 0.797


100%|██████████| 258/258 [00:53<00:00,  4.83it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 14: LR: 0.00018, Train loss: 0.096 Valid loss: 0.078  Valid ROC-AUC: 0.789 Test loss: 0.117  Test ROC-AUC: 0.779


100%|██████████| 258/258 [00:53<00:00,  4.85it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 15: LR: 0.00018, Train loss: 0.097 Valid loss: 0.073  Valid ROC-AUC: 0.823 Test loss: 0.120  Test ROC-AUC: 0.765


100%|██████████| 258/258 [00:53<00:00,  4.83it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 16: LR: 0.00017, Train loss: 0.092 Valid loss: 0.092  Valid ROC-AUC: 0.788 Test loss: 0.132  Test ROC-AUC: 0.769


100%|██████████| 258/258 [00:53<00:00,  4.80it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 17: LR: 0.00017, Train loss: 0.089 Valid loss: 0.071  Valid ROC-AUC: 0.822 Test loss: 0.119  Test ROC-AUC: 0.782


100%|██████████| 258/258 [00:53<00:00,  4.82it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 18: LR: 0.00017, Train loss: 0.087 Valid loss: 0.075  Valid ROC-AUC: 0.793 Test loss: 0.114  Test ROC-AUC: 0.792


100%|██████████| 258/258 [00:52<00:00,  4.89it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 19: LR: 0.00017, Train loss: 0.083 Valid loss: 0.073  Valid ROC-AUC: 0.800 Test loss: 0.117  Test ROC-AUC: 0.780


100%|██████████| 258/258 [00:54<00:00,  4.77it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 20: LR: 0.00017, Train loss: 0.078 Valid loss: 0.077  Valid ROC-AUC: 0.806 Test loss: 0.132  Test ROC-AUC: 0.763


100%|██████████| 258/258 [00:53<00:00,  4.78it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 21: LR: 0.00016, Train loss: 0.077 Valid loss: 0.083  Valid ROC-AUC: 0.765 Test loss: 0.135  Test ROC-AUC: 0.761


100%|██████████| 258/258 [00:54<00:00,  4.74it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 22: LR: 0.00016, Train loss: 0.075 Valid loss: 0.082  Valid ROC-AUC: 0.783 Test loss: 0.135  Test ROC-AUC: 0.765


100%|██████████| 258/258 [00:53<00:00,  4.80it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 23: LR: 0.00016, Train loss: 0.070 Valid loss: 0.082  Valid ROC-AUC: 0.800 Test loss: 0.143  Test ROC-AUC: 0.770


100%|██████████| 258/258 [00:53<00:00,  4.78it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 24: LR: 0.00016, Train loss: 0.068 Valid loss: 0.089  Valid ROC-AUC: 0.748 Test loss: 0.149  Test ROC-AUC: 0.751


100%|██████████| 258/258 [00:53<00:00,  4.85it/s]
  0%|          | 1/258 [00:00<00:47,  5.37it/s]

Epoch 25: LR: 0.00016, Train loss: 0.063 Valid loss: 0.083  Valid ROC-AUC: 0.806 Test loss: 0.141  Test ROC-AUC: 0.787


100%|██████████| 258/258 [00:53<00:00,  4.86it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 26: LR: 0.00015, Train loss: 0.060 Valid loss: 0.091  Valid ROC-AUC: 0.803 Test loss: 0.147  Test ROC-AUC: 0.791


100%|██████████| 258/258 [00:53<00:00,  4.82it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 27: LR: 0.00015, Train loss: 0.056 Valid loss: 0.094  Valid ROC-AUC: 0.781 Test loss: 0.164  Test ROC-AUC: 0.779


100%|██████████| 258/258 [00:53<00:00,  4.83it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 28: LR: 0.00015, Train loss: 0.052 Valid loss: 0.097  Valid ROC-AUC: 0.789 Test loss: 0.158  Test ROC-AUC: 0.787


100%|██████████| 258/258 [00:53<00:00,  4.85it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 29: LR: 0.00015, Train loss: 0.050 Valid loss: 0.098  Valid ROC-AUC: 0.791 Test loss: 0.157  Test ROC-AUC: 0.795


100%|██████████| 258/258 [00:52<00:00,  4.91it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 30: LR: 0.00015, Train loss: 0.047 Valid loss: 0.098  Valid ROC-AUC: 0.770 Test loss: 0.159  Test ROC-AUC: 0.805


100%|██████████| 258/258 [00:52<00:00,  4.95it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 31: LR: 0.00014, Train loss: 0.047 Valid loss: 0.103  Valid ROC-AUC: 0.757 Test loss: 0.185  Test ROC-AUC: 0.761


100%|██████████| 258/258 [00:53<00:00,  4.85it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 32: LR: 0.00014, Train loss: 0.044 Valid loss: 0.104  Valid ROC-AUC: 0.772 Test loss: 0.189  Test ROC-AUC: 0.758


100%|██████████| 258/258 [00:54<00:00,  4.75it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 33: LR: 0.00014, Train loss: 0.038 Valid loss: 0.124  Valid ROC-AUC: 0.743 Test loss: 0.192  Test ROC-AUC: 0.783


100%|██████████| 258/258 [00:54<00:00,  4.73it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 34: LR: 0.00014, Train loss: 0.037 Valid loss: 0.119  Valid ROC-AUC: 0.771 Test loss: 0.204  Test ROC-AUC: 0.767


100%|██████████| 258/258 [00:54<00:00,  4.77it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 35: LR: 0.00013, Train loss: 0.034 Valid loss: 0.134  Valid ROC-AUC: 0.735 Test loss: 0.205  Test ROC-AUC: 0.768


100%|██████████| 258/258 [00:53<00:00,  4.83it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 36: LR: 0.00013, Train loss: 0.035 Valid loss: 0.123  Valid ROC-AUC: 0.753 Test loss: 0.181  Test ROC-AUC: 0.779


100%|██████████| 258/258 [00:53<00:00,  4.86it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 37: LR: 0.00013, Train loss: 0.029 Valid loss: 0.130  Valid ROC-AUC: 0.783 Test loss: 0.215  Test ROC-AUC: 0.781


100%|██████████| 258/258 [00:53<00:00,  4.80it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 38: LR: 0.00013, Train loss: 0.029 Valid loss: 0.119  Valid ROC-AUC: 0.786 Test loss: 0.205  Test ROC-AUC: 0.785


100%|██████████| 258/258 [00:53<00:00,  4.81it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 39: LR: 0.00013, Train loss: 0.026 Valid loss: 0.118  Valid ROC-AUC: 0.778 Test loss: 0.215  Test ROC-AUC: 0.761


100%|██████████| 258/258 [00:53<00:00,  4.81it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 40: LR: 0.00012, Train loss: 0.026 Valid loss: 0.136  Valid ROC-AUC: 0.813 Test loss: 0.239  Test ROC-AUC: 0.774


100%|██████████| 258/258 [00:52<00:00,  4.94it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 41: LR: 0.00012, Train loss: 0.023 Valid loss: 0.135  Valid ROC-AUC: 0.798 Test loss: 0.234  Test ROC-AUC: 0.792


100%|██████████| 258/258 [00:53<00:00,  4.84it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 42: LR: 0.00012, Train loss: 0.022 Valid loss: 0.145  Valid ROC-AUC: 0.787 Test loss: 0.252  Test ROC-AUC: 0.790


100%|██████████| 258/258 [00:53<00:00,  4.85it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 43: LR: 0.00012, Train loss: 0.023 Valid loss: 0.145  Valid ROC-AUC: 0.770 Test loss: 0.250  Test ROC-AUC: 0.785


100%|██████████| 258/258 [00:53<00:00,  4.79it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 44: LR: 0.00012, Train loss: 0.022 Valid loss: 0.135  Valid ROC-AUC: 0.783 Test loss: 0.234  Test ROC-AUC: 0.775


100%|██████████| 258/258 [00:53<00:00,  4.80it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 45: LR: 0.00011, Train loss: 0.022 Valid loss: 0.137  Valid ROC-AUC: 0.780 Test loss: 0.230  Test ROC-AUC: 0.778


100%|██████████| 258/258 [00:53<00:00,  4.81it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 46: LR: 0.00011, Train loss: 0.022 Valid loss: 0.143  Valid ROC-AUC: 0.773 Test loss: 0.245  Test ROC-AUC: 0.783


100%|██████████| 258/258 [00:54<00:00,  4.76it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 47: LR: 0.00011, Train loss: 0.019 Valid loss: 0.153  Valid ROC-AUC: 0.785 Test loss: 0.257  Test ROC-AUC: 0.790


100%|██████████| 258/258 [00:53<00:00,  4.82it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 48: LR: 0.00011, Train loss: 0.019 Valid loss: 0.149  Valid ROC-AUC: 0.757 Test loss: 0.237  Test ROC-AUC: 0.784


100%|██████████| 258/258 [00:53<00:00,  4.79it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 49: LR: 0.00011, Train loss: 0.018 Valid loss: 0.150  Valid ROC-AUC: 0.778 Test loss: 0.250  Test ROC-AUC: 0.785


100%|██████████| 258/258 [00:53<00:00,  4.80it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 50: LR: 0.00010, Train loss: 0.018 Valid loss: 0.156  Valid ROC-AUC: 0.749 Test loss: 0.251  Test ROC-AUC: 0.788


100%|██████████| 258/258 [00:54<00:00,  4.72it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 51: LR: 0.00010, Train loss: 0.016 Valid loss: 0.154  Valid ROC-AUC: 0.796 Test loss: 0.283  Test ROC-AUC: 0.782


100%|██████████| 258/258 [00:53<00:00,  4.79it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 52: LR: 0.00010, Train loss: 0.015 Valid loss: 0.155  Valid ROC-AUC: 0.788 Test loss: 0.257  Test ROC-AUC: 0.805


100%|██████████| 258/258 [00:53<00:00,  4.82it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 53: LR: 0.00010, Train loss: 0.014 Valid loss: 0.175  Valid ROC-AUC: 0.750 Test loss: 0.265  Test ROC-AUC: 0.787


100%|██████████| 258/258 [00:53<00:00,  4.80it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 54: LR: 0.00009, Train loss: 0.013 Valid loss: 0.182  Valid ROC-AUC: 0.786 Test loss: 0.300  Test ROC-AUC: 0.773


100%|██████████| 258/258 [00:53<00:00,  4.83it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 55: LR: 0.00009, Train loss: 0.016 Valid loss: 0.170  Valid ROC-AUC: 0.771 Test loss: 0.275  Test ROC-AUC: 0.782


100%|██████████| 258/258 [00:53<00:00,  4.85it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 56: LR: 0.00009, Train loss: 0.012 Valid loss: 0.172  Valid ROC-AUC: 0.796 Test loss: 0.298  Test ROC-AUC: 0.778


100%|██████████| 258/258 [00:54<00:00,  4.76it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 57: LR: 0.00009, Train loss: 0.012 Valid loss: 0.178  Valid ROC-AUC: 0.786 Test loss: 0.301  Test ROC-AUC: 0.775


100%|██████████| 258/258 [00:53<00:00,  4.82it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 58: LR: 0.00009, Train loss: 0.012 Valid loss: 0.164  Valid ROC-AUC: 0.798 Test loss: 0.291  Test ROC-AUC: 0.799


100%|██████████| 258/258 [00:53<00:00,  4.83it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 59: LR: 0.00008, Train loss: 0.009 Valid loss: 0.177  Valid ROC-AUC: 0.799 Test loss: 0.319  Test ROC-AUC: 0.784


100%|██████████| 258/258 [00:53<00:00,  4.81it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 60: LR: 0.00008, Train loss: 0.011 Valid loss: 0.166  Valid ROC-AUC: 0.811 Test loss: 0.306  Test ROC-AUC: 0.795


100%|██████████| 258/258 [00:54<00:00,  4.77it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 61: LR: 0.00008, Train loss: 0.010 Valid loss: 0.183  Valid ROC-AUC: 0.793 Test loss: 0.321  Test ROC-AUC: 0.782


100%|██████████| 258/258 [00:54<00:00,  4.75it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 62: LR: 0.00008, Train loss: 0.010 Valid loss: 0.196  Valid ROC-AUC: 0.786 Test loss: 0.365  Test ROC-AUC: 0.790


100%|██████████| 258/258 [00:54<00:00,  4.77it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 63: LR: 0.00008, Train loss: 0.010 Valid loss: 0.206  Valid ROC-AUC: 0.771 Test loss: 0.367  Test ROC-AUC: 0.776


100%|██████████| 258/258 [00:53<00:00,  4.82it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 64: LR: 0.00007, Train loss: 0.008 Valid loss: 0.199  Valid ROC-AUC: 0.787 Test loss: 0.335  Test ROC-AUC: 0.788


100%|██████████| 258/258 [00:52<00:00,  4.88it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 65: LR: 0.00007, Train loss: 0.007 Valid loss: 0.205  Valid ROC-AUC: 0.797 Test loss: 0.359  Test ROC-AUC: 0.781


100%|██████████| 258/258 [00:53<00:00,  4.81it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 66: LR: 0.00007, Train loss: 0.006 Valid loss: 0.200  Valid ROC-AUC: 0.800 Test loss: 0.366  Test ROC-AUC: 0.786


100%|██████████| 258/258 [00:53<00:00,  4.82it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 67: LR: 0.00007, Train loss: 0.006 Valid loss: 0.198  Valid ROC-AUC: 0.781 Test loss: 0.370  Test ROC-AUC: 0.767


100%|██████████| 258/258 [00:53<00:00,  4.84it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 68: LR: 0.00007, Train loss: 0.007 Valid loss: 0.209  Valid ROC-AUC: 0.794 Test loss: 0.380  Test ROC-AUC: 0.777


100%|██████████| 258/258 [00:53<00:00,  4.79it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 69: LR: 0.00006, Train loss: 0.006 Valid loss: 0.214  Valid ROC-AUC: 0.793 Test loss: 0.394  Test ROC-AUC: 0.792


100%|██████████| 258/258 [00:53<00:00,  4.85it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 70: LR: 0.00006, Train loss: 0.007 Valid loss: 0.217  Valid ROC-AUC: 0.787 Test loss: 0.368  Test ROC-AUC: 0.780


100%|██████████| 258/258 [00:53<00:00,  4.82it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 71: LR: 0.00006, Train loss: 0.006 Valid loss: 0.223  Valid ROC-AUC: 0.798 Test loss: 0.398  Test ROC-AUC: 0.779


100%|██████████| 258/258 [00:54<00:00,  4.76it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 72: LR: 0.00006, Train loss: 0.006 Valid loss: 0.220  Valid ROC-AUC: 0.787 Test loss: 0.366  Test ROC-AUC: 0.782


 52%|█████▏    | 135/258 [00:27<00:27,  4.51it/s]