In [29]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
from dgl.nn import SAGEConv, GraphConv, EGATConv
from torch.nn import Linear
import dgl.function as fn
import pandas as pd
import sklearn.metrics as sk_m
import numpy as np
import scipy.sparse as sp
import itertools
from ogb.graphproppred import DglGraphPropPredDataset
from torch.utils.data import Dataset
from tqdm import tqdm
from collections import Counter
import random
import os
from ogb.graphproppred import Evaluator


In [30]:
device = torch.device("cpu")

In [31]:
evaluator = Evaluator(name = 'ogbg-molhiv')


In [32]:
def set_seed(seed: int = 42) -> None:
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # When running on the CuDNN backend, two further options must be set
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True
    # Set a fixed value for the hash seed
    os.environ["PYTHONHASHSEED"] = str(seed)
    print(f"Random seed set as {seed}")


# Loading and preparing the data

In [33]:
class CustomGraphDataGenerator(Dataset):
    def __init__(self, graphs, labels, device, balanced_sampling=False, batch_size=32):
        self.device = device
        self.balanced_sampling = balanced_sampling
        self.batch_size= batch_size
        self.y = labels
        self.X = graphs
        self.__prepare_samples()
        self.__define_batch_size()
        self.neg_indices = np.array(list(range(len(self.X_neg))))
        self.pos_indices = np.array(list(range(len(self.X_pos))))
        self.overall_indices = np.array(list(range(len(self.X))))
        
    @staticmethod
    def floor(x):
        if x<1:
            tmp = 1
        else:
            tmp = int(x)
        return int(np.ceil(x)) if x % tmp >= 0.5 else int(np.floor(x))
       
    
    def __define_batch_size(self):
        counts = Counter(self.y.reshape(-1))
        
        batch_size_neg = (counts[0]/len(self.X))*self.batch_size
        batch_size_neg = self.floor(batch_size_neg)
        
        batch_size_pos = (counts[1]/len(self.X))*self.batch_size
        batch_size_pos = self.floor(batch_size_pos)
        
        self.batch_size_pos = batch_size_pos
        self.batch_size_neg = batch_size_neg
    
    def __prepare_samples(self):
        self.X_neg = self.X[self.y.reshape(-1)==0]
        self.X_pos = self.X[self.y.reshape(-1)==1]
        
        
    def __len__(self):
        'Denotes the number of batches per epoch'
        return int(np.ceil(len(self.X) / self.batch_size))
        
    def __getitem__(self, index):
        if index+1>self.__len__()-1:
            raise StopIteration
        
        if self.balanced_sampling:
            
            indices_neg = self.neg_indices[index*self.batch_size_neg:(index+1)*self.batch_size_neg]
            indices_pos = self.pos_indices[index*self.batch_size_pos:(index+1)*self.batch_size_pos]

            X = np.concatenate((self.X_pos[indices_pos], self.X_neg[indices_neg]))
            y = np.concatenate((self.y[indices_pos], self.y[indices_neg]))
        
            return dgl.batch(X).to(self.device), torch.Tensor(y).to(self.device)
        
        else:
            
            indices = self.overall_indices[index*self.batch_size:(index+1)*self.batch_size]
            
            return dgl.batch(self.X[indices]).to(self.device), torch.Tensor(self.y[indices]).to(self.device)
        
    
    def shuffle_indices(self):
        np.random.shuffle(self.neg_indices)
        np.random.shuffle(self.pos_indices)
        np.random.shuffle(self.overall_indices)

In [34]:
dataset = DglGraphPropPredDataset(name = 'ogbg-molhiv') 
split_idx = dataset.get_idx_split()
train_idx, valid_idx, test_idx = split_idx["train"], split_idx["valid"], split_idx["test"]
labels = np.array(dataset.labels)
graphs = np.array(dataset.graphs)
train_graphs, val_graphs, test_graphs  = graphs[train_idx], graphs[valid_idx],  graphs[test_idx]
train_labels, valid_labels, test_labels = labels[train_idx], labels[valid_idx], labels[test_idx]

In [35]:
test_graphs_batched = dgl.batch(test_graphs)

In [36]:
Counter(train_labels.reshape(-1))

Counter({0: 31669, 1: 1232})

In [37]:
len(dataset), len(train_graphs), len(val_graphs), len(test_graphs)

(41127, 32901, 4113, 4113)

## class weights

In [38]:
from sklearn.utils.class_weight import compute_class_weight

In [39]:
class_weights = compute_class_weight('balanced', y=train_labels.reshape(-1),
                                     classes=np.unique(train_labels.reshape(-1)))


# Training

In [60]:
def metrics(labels, scores, threshold=0.5):
    labels = np.hstack([i.reshape(-1) for i in labels])
    scores = np.hstack([i.reshape(-1) for i in scores])
    pred = (np.array(scores)>threshold).astype(int)
    f1 = sk_m.f1_score(y_pred=pred, y_true=labels, average='binary')
    roc = sk_m.roc_auc_score(y_score=scores, y_true=labels)
    return f1, roc

def BCELoss_class_weighted(weights):

    def loss(input, target):
        input = torch.clamp(input,min=1e-7,max=1-1e-7)
        bce = - weights[1] * target * torch.log(input) - (1 - target) * weights[0] * torch.log(1 - input)
        return torch.mean(bce)

    return loss

def train(train_generator, val_generator, model, 
          epochs=10, lr=1e-3, weight_decay=1e-5,
          class_weights=None,
          use_edge_feature=False, shuffle=False, early_stopping=None):
    
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    
    loss_func = BCELoss_class_weighted(class_weights)
    
    if not (class_weights is None):
        class_weights = torch.FloatTensor(class_weights)
        loss_func = BCELoss_class_weighted(class_weights)
    else:
        loss_func = F.binary_cross_entropy

    
    for e in range(epochs):
        train_scores = []
        train_labels = []
        loss_train_holder = []
        
        val_scores = []
        val_loss_holder = []
        val_labels = []
        
        for train_graph, label in tqdm(train_generator):
            if use_edge_feature:
                logits_train = model(train_graph, train_graph.ndata['feat'].to(torch.float32),\
                                     train_graph.edata['feat'].to(torch.float32))
            else:
                logits_train = model(train_graph, train_graph.ndata['feat'])
            sigmoided_train = F.sigmoid(logits_train)
            loss_train = loss_func(sigmoided_train, label)
            
            # Backward
            optimizer.zero_grad()
            loss_train.backward()
            optimizer.step()
        
            train_labels.append(label.numpy())
            train_scores.append(sigmoided_train.detach().numpy())
            loss_train_holder.append(loss_train.detach().numpy())
            
        
        with torch.no_grad():  
            for val_graph, val_label in val_generator:
                
                if use_edge_feature:
                    logits_val = model(val_graph, val_graph.ndata['feat'].to(torch.float32), \
                                   val_graph.edata['feat'].to(torch.float32))
                else:
                    logits_val = model(val_graph, val_graph.ndata['feat'])
                    
                sigmoided_val = F.sigmoid(logits_val)
                
                loss_val = loss_func(sigmoided_val, val_label)
                
                
                val_labels.append(val_label.numpy())
                val_scores.append(sigmoided_val.detach().numpy())
                val_loss_holder.append(loss_val)
            
        train_f1, train_roc = metrics(train_labels, train_scores)
        val_f1, val_roc = metrics(val_labels, val_scores)
        

        print('In epoch {}, Train loss: {:.3f}, train roc: {:.3f}, train f1: {:.3f},'
              ' val loss: {:.3f}, val roc: {:.3f}, val f1 : {:.3f}'.format(
            e, np.mean(loss_train_holder), train_roc, train_f1, np.mean(val_loss_holder), \
        val_roc, val_f1))
        
        if early_stopping:
            early_stopping(val_roc, model)
            print('Early stopping extemum : {}'.format(early_stopping.extemum_value))
            if early_stopping.early_stop:
                print('Stopping early')
                model = early_stopping.best_model
                break
        
        if shuffle:
            train_generator.shuffle_indices()
            
            
    return model

In [61]:
import operator

In [78]:
class EarlyStopping:
    def __init__(self, tolerance=5, mode='min'):
        assert mode in ['min','max'], 'Mode should be min or max'
        self.mode = operator.lt if mode=='min' else operator.gt 
        self.tolerance = tolerance
        self.counter = 0
        self.early_stop = False
        self.extemum_value = None
        self.best_model = None
        
        
    def __call__(self, val, model):
        if self.extemum_value is None:
            self.extemum_value = val
            self.best_model = model
        else:
            if not self.mode(val, self.extemum_value):
                self.counter+=1
            else:
                self.extemum_value = val
                self.best_model = model
        
        if self.counter==self.tolerance:
            self.early_stop=True

In [181]:
from dgl.nn import GraphConv
from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder

class GraphModelNoEdge(nn.Module):
    def __init__(self, emb_shape=100):
        set_seed(0)
        super(GraphModelNoEdge, self).__init__()
        self.emb = AtomEncoder(emb_shape)
        self.node_conv1 = GraphConv(emb_shape, 256, allow_zero_in_degree=True)
        self.node_conv2 = SAGEConv(256, 128, 'lstm')
        self.node_conv3 = SAGEConv(128, 64, 'mean')
        self.node_conv4 = SAGEConv(64, 1, 'mean')
        self.dropout1 = torch.nn.Dropout(0.25)

        
    def forward(self, g, n_data):
        h_nodes = self.emb(n_data)
        h_nodes = self.node_conv1(g, h_nodes)
        h_nodes = F.relu(h_nodes)
        h_nodes = self.dropout1(self.node_conv2(g, h_nodes))
        h_nodes = F.relu(h_nodes)
        h_nodes = self.node_conv3(g, h_nodes)
        h_nodes = F.relu(h_nodes)
        h_nodes = self.node_conv4(g, h_nodes)
        g.ndata['h'] = h_nodes
        h_nodes = dgl.mean_nodes(g, 'h')
        return h_nodes
    


In [182]:
model = GraphModelNoEdge(100)
model.to(device)
batch_size = 64
balanced_sampling = False
use_edge_feature = False
early_stopping = EarlyStopping(tolerance=3, mode='max')

Random seed set as 0


In [183]:
train_generator = CustomGraphDataGenerator(train_graphs, train_labels,device=device, batch_size=batch_size,
                                          balanced_sampling=balanced_sampling)
valid_generator = CustomGraphDataGenerator(val_graphs, valid_labels, device=device, batch_size=batch_size,
                                          balanced_sampling=balanced_sampling)

In [None]:
model = train(train_generator, valid_generator, model, class_weights=class_weights, weight_decay=1e-5,
             epochs=25, use_edge_feature=use_edge_feature, 
             early_stopping=early_stopping)

100%|████████████████████████████████████████▉| 514/515 [02:43<00:00,  3.15it/s]


In epoch 0, Train loss: 0.677, train roc: 0.618, train f1: 0.096, val loss: 0.958, val roc: 0.675, val f1 : 0.039
Early stopping extemum : 0.6752702058638131


100%|████████████████████████████████████████▉| 514/515 [02:42<00:00,  3.16it/s]


In epoch 1, Train loss: 0.656, train roc: 0.656, train f1: 0.115, val loss: 0.943, val roc: 0.693, val f1 : 0.039
Early stopping extemum : 0.6929754162630875


 23%|█████████▋                               | 121/515 [00:33<01:59,  3.29it/s]

# Evaluation

In [None]:
predicted = model(test_graphs_batched, test_graphs_batched.ndata['feat'])
predicted = F.sigmoid(predicted).detach().numpy()

In [None]:
evaluator.eval({"y_true": test_labels, "y_pred": predicted})

# Table with results

# Conclusions:
### 1) Using sampling w.r.t distribution in the dataset (balanced_sampling) on average gives worser results; 
### 2) Using weighted loss improves rocauc score (basically because of the highly imbalanced weights)
### 3) Adding l2 regularization helps 
### 4) Using embeddings for node features hepls a lot, no overfitting detected (basically because of the better representation of categorical features)
