In [6]:
import torch 
from torch_geometric.data import DataLoader
from sklearn.metrics import confusion_matrix, f1_score, \
    accuracy_score, precision_score, recall_score, roc_auc_score
import numpy as np
from tqdm import tqdm
import import_ipynb
from Dataset import MoleculeDataset 
from Model import GNN
import mlflow.pytorchnp

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

mlflow.set_tracking_uri("http://localhost:5000")


In [None]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def train_one_epoch(epoch, model, train_loader, optimizer, loss_fn):
    
    all_preds = []
    all_labels = []
    running_loss = 0.0
    step = 0
    for _, batch in enumerate(tqdm(train_loader)):
        
        batch.to(device)  
        
        optimizer.zero_grad() 
        
        pred = model(batch.x.float(), 
                                batch.edge_attr.float(),
                                batch.edge_index, 
                                batch.batch) 
        
        loss = loss_fn(torch.squeeze(pred), batch.y.float())
        loss.backward()  
        optimizer.step()  
        
        running_loss += loss.item()
        step += 1
        all_preds.append(np.rint(torch.sigmoid(pred).cpu().detach().numpy()))
        all_labels.append(batch.y.cpu().detach().numpy())
    all_preds = np.concatenate(all_preds).ravel()
    all_labels = np.concatenate(all_labels).ravel()
    calculate_metrics(all_preds, all_labels, epoch, "train")
    return running_loss/step

def test(epoch, model, test_loader, loss_fn):
    all_preds = []
    all_preds_raw = []
    all_labels = []
    running_loss = 0.0
    step = 0
    for batch in test_loader:
        batch.to(device)  
        pred = model(batch.x.float(), 
                        batch.edge_attr.float(),
                        batch.edge_index, 
                        batch.batch) 
        loss = loss_fn(torch.squeeze(pred), batch.y.float())

        running_loss += loss.item()
        step += 1
        all_preds.append(np.rint(torch.sigmoid(pred).cpu().detach().numpy()))
        all_preds_raw.append(torch.sigmoid(pred).cpu().detach().numpy())
        all_labels.append(batch.y.cpu().detach().numpy())
    
    all_preds = np.concatenate(all_preds).ravel()
    all_labels = np.concatenate(all_labels).ravel()
    print(all_preds_raw[0][:10])
    print(all_preds[:10])
    print(all_labels[:10])
    calculate_metrics(all_preds, all_labels, epoch, "test")
    return running_loss/step


def calculate_metrics(y_pred, y_true, epoch, type):
    print(f"\n Confusion matrix: \n {confusion_matrix(y_pred, y_true)}")
    print(f"F1 Score: {f1_score(y_true, y_pred)}")
    print(f"Accuracy: {accuracy_score(y_true, y_pred)}")
    prec = precision_score(y_true, y_pred)
    rec = recall_score(y_true, y_pred)
    print(f"Precision: {prec}")
    print(f"Recall: {rec}")
    mlflow.log_metric(key=f"Precision-{type}", value=float(prec), step=epoch)
    mlflow.log_metric(key=f"Recall-{type}", value=float(rec), step=epoch)
    try:
        roc = roc_auc_score(y_true, y_pred)
        print(f"ROC AUC: {roc}")
        mlflow.log_metric(key=f"ROC-AUC-{type}", value=float(roc), step=epoch)
    except:
        mlflow.log_metric(key=f"ROC-AUC-{type}", value=float(0), step=epoch)
        print(f"ROC AUC: notdefined")


from mango import scheduler, tuner
from mango.tuner import Tuner
from config import HYPERPARAMETERS, SIGNATURE

def run_one_training(params):
    params = params[0]
    with mlflow.start_run() as run:
        
        for key in params.keys():
            mlflow.log_param(key, params[key])

        
        train_dataset = MoleculeDataset(root="data/", filename="HIV_train_oversampled.csv")
        test_dataset = MoleculeDataset(root="data/", filename="HIV_test.csv", test=True)
        params["model_edge_dim"] = train_dataset[0].edge_attr.shape[1]

        
        train_loader = DataLoader(train_dataset, batch_size=params["batch_size"], shuffle=True)
        test_loader = DataLoader(test_dataset, batch_size=params["batch_size"], shuffle=True)

        
        model_params = {k: v for k, v in params.items() if k.startswith("model_")}
        model = GNN(feature_size=train_dataset[0].x.shape[1], model_params=model_params) 
        model = model.to(device)
        print(f"Number of parameters: {count_parameters(model)}")
        mlflow.log_param("num_params", count_parameters(model))

    
        weight = torch.tensor([params["pos_weight"]], dtype=torch.float32).to(device)
        loss_fn = torch.nn.BCEWithLogitsLoss(pos_weight=weight)
        optimizer = torch.optim.SGD(model.parameters(), 
                                    lr=params["learning_rate"],
                                    momentum=params["sgd_momentum"],
                                    weight_decay=params["weight_decay"])
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=params["scheduler_gamma"])
        
        
        best_loss = 1000
        early_stopping_counter = 0
        for epoch in range(300): 
            if early_stopping_counter <= 10: 
                
                model.train()
                loss = train_one_epoch(epoch, model, train_loader, optimizer, loss_fn)
                print(f"Epoch {epoch} | Train Loss {loss}")
                mlflow.log_metric(key="Train loss", value=float(loss), step=epoch)

                
                model.eval()
                if epoch % 5 == 0:
                    loss = test(epoch, model, test_loader, loss_fn)
                    print(f"Epoch {epoch} | Test Loss {loss}")
                    mlflow.log_metric(key="Test loss", value=float(loss), step=epoch)
                    
                    
                    if float(loss) < best_loss:
                        best_loss = loss
                        
                        mlflow.pytorch.log_model(model, "model", signature=SIGNATURE)
                        early_stopping_counter = 0
                    else:
                        early_stopping_counter += 1

                scheduler.step()
            else:
                
                return [best_loss]
    print(f"Finishing training with best test loss: {best_loss}")
    return [best_loss]



config = dict()
config["optimizer"] = "Bayesian"
config["num_iteration"] = 100

tuner = Tuner(HYPERPARAMETERS, 
              objective=run_one_training,
              conf_dict=config) 
results = tuner.minimize()

Running hyperparameter search...
Loading dataset...
Loading model...
Number of parameters: 33137


100%|██████████| 1286/1286 [00:49<00:00, 25.74it/s]



 Confusion matrix: 
 [[39064  1422]
 [  620    21]]
F1 Score: 0.02015355086372361
Accuracy: 0.9503489192014978
Precision: 0.0327613104524181
Recall: 0.014553014553014554
ROC AUC: 0.49946479474752836
Epoch 0 | Train Loss 0.2668248330203825
[[0.16080391]
 [0.14538158]
 [0.1551345 ]
 [0.16296472]
 [0.1430269 ]
 [0.13044773]
 [0.17981192]
 [0.0990032 ]
 [0.14161502]
 [0.13394845]]
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[0 0 0 0 0 0 0 0 0 0]

 Confusion matrix: 
 [[39684  1443]
 [    0     0]]
F1 Score: 0.0
Accuracy: 0.9649135604347508
Precision: 0.0
Recall: 0.0
ROC AUC: 0.5
Epoch 0 | Test Loss 0.2262515065319616


100%|██████████| 1286/1286 [00:50<00:00, 25.55it/s]



 Confusion matrix: 
 [[39684  1443]
 [    0     0]]
F1 Score: 0.0
Accuracy: 0.9649135604347508
Precision: 0.0
Recall: 0.0
ROC AUC: 0.5
Epoch 1 | Train Loss 0.20674607478383447


100%|██████████| 1286/1286 [00:50<00:00, 25.35it/s]



 Confusion matrix: 
 [[39684  1443]
 [    0     0]]
F1 Score: 0.0
Accuracy: 0.9649135604347508
Precision: 0.0
Recall: 0.0
ROC AUC: 0.5
Epoch 2 | Train Loss 0.18988362468582104


100%|██████████| 1286/1286 [00:51<00:00, 24.99it/s]



 Confusion matrix: 
 [[39684  1443]
 [    0     0]]
F1 Score: 0.0
Accuracy: 0.9649135604347508
Precision: 0.0
Recall: 0.0
ROC AUC: 0.5
Epoch 3 | Train Loss 0.17764518688639994


100%|██████████| 1286/1286 [00:50<00:00, 25.31it/s]



 Confusion matrix: 
 [[39684  1443]
 [    0     0]]
F1 Score: 0.0
Accuracy: 0.9649135604347508
Precision: 0.0
Recall: 0.0
ROC AUC: 0.5
Epoch 4 | Train Loss 0.17305404676666156


100%|██████████| 1286/1286 [00:51<00:00, 24.73it/s]



 Confusion matrix: 
 [[39684  1443]
 [    0     0]]
F1 Score: 0.0
Accuracy: 0.9649135604347508
Precision: 0.0
Recall: 0.0
ROC AUC: 0.5
Epoch 5 | Train Loss 0.16727635769889013
[[0.06773974]
 [0.10751171]
 [0.05739827]
 [0.06013353]
 [0.07677156]
 [0.06829631]
 [0.08466428]
 [0.07414608]
 [0.09273242]
 [0.06587753]]
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[1 0 0 0 0 0 0 0 0 0]

 Confusion matrix: 
 [[39684  1443]
 [    0     0]]
F1 Score: 0.0
Accuracy: 0.9649135604347508
Precision: 0.0
Recall: 0.0
ROC AUC: 0.5
Epoch 5 | Test Loss 0.16039109892830322


100%|██████████| 1286/1286 [00:51<00:00, 24.96it/s]



 Confusion matrix: 
 [[39684  1443]
 [    0     0]]
F1 Score: 0.0
Accuracy: 0.9649135604347508
Precision: 0.0
Recall: 0.0
ROC AUC: 0.5
Epoch 6 | Train Loss 0.16449265194474538


100%|██████████| 1286/1286 [00:50<00:00, 25.29it/s]



 Confusion matrix: 
 [[39684  1443]
 [    0     0]]
F1 Score: 0.0
Accuracy: 0.9649135604347508
Precision: 0.0
Recall: 0.0
ROC AUC: 0.5
Epoch 7 | Train Loss 0.16328275700217088


100%|██████████| 1286/1286 [00:52<00:00, 24.57it/s]



 Confusion matrix: 
 [[39684  1443]
 [    0     0]]
F1 Score: 0.0
Accuracy: 0.9649135604347508
Precision: 0.0
Recall: 0.0
ROC AUC: 0.5
Epoch 8 | Train Loss 0.15909959797633008


100%|██████████| 1286/1286 [00:53<00:00, 24.04it/s]



 Confusion matrix: 
 [[39684  1443]
 [    0     0]]
F1 Score: 0.0
Accuracy: 0.9649135604347508
Precision: 0.0
Recall: 0.0
ROC AUC: 0.5
Epoch 9 | Train Loss 0.1576280578113455


100%|██████████| 1286/1286 [00:51<00:00, 24.77it/s]



 Confusion matrix: 
 [[39684  1443]
 [    0     0]]
F1 Score: 0.0
Accuracy: 0.9649135604347508
Precision: 0.0
Recall: 0.0
ROC AUC: 0.5
Epoch 10 | Train Loss 0.15727885789056867
[[0.05781325]
 [0.05872874]
 [0.04750323]
 [0.07448643]
 [0.03571719]
 [0.08424575]
 [0.03800564]
 [0.03674357]
 [0.06914573]
 [0.04873258]]
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[0 0 0 0 0 1 0 0 0 0]

 Confusion matrix: 
 [[39684  1443]
 [    0     0]]
F1 Score: 0.0
Accuracy: 0.9649135604347508
Precision: 0.0
Recall: 0.0
ROC AUC: 0.5
Epoch 10 | Test Loss 0.15311741014337094


100%|██████████| 1286/1286 [00:49<00:00, 25.85it/s]



 Confusion matrix: 
 [[39684  1443]
 [    0     0]]
F1 Score: 0.0
Accuracy: 0.9649135604347508
Precision: 0.0
Recall: 0.0
ROC AUC: 0.5
Epoch 11 | Train Loss 0.15605068507185713


100%|██████████| 1286/1286 [00:49<00:00, 25.78it/s]



 Confusion matrix: 
 [[39684  1443]
 [    0     0]]
F1 Score: 0.0
Accuracy: 0.9649135604347508
Precision: 0.0
Recall: 0.0
ROC AUC: 0.5
Epoch 12 | Train Loss 0.1551374804408146


100%|██████████| 1286/1286 [00:49<00:00, 25.98it/s]



 Confusion matrix: 
 [[39684  1443]
 [    0     0]]
F1 Score: 0.0
Accuracy: 0.9649135604347508
Precision: 0.0
Recall: 0.0
ROC AUC: 0.5
Epoch 13 | Train Loss 0.15402306419323614


100%|██████████| 1286/1286 [00:50<00:00, 25.47it/s]



 Confusion matrix: 
 [[39684  1443]
 [    0     0]]
F1 Score: 0.0
Accuracy: 0.9649135604347508
Precision: 0.0
Recall: 0.0
ROC AUC: 0.5
Epoch 14 | Train Loss 0.1536769873374265


100%|██████████| 1286/1286 [00:49<00:00, 25.91it/s]



 Confusion matrix: 
 [[39684  1443]
 [    0     0]]
F1 Score: 0.0
Accuracy: 0.9649135604347508
Precision: 0.0
Recall: 0.0
ROC AUC: 0.5
Epoch 15 | Train Loss 0.15248282358316853
[[0.0520501 ]
 [0.06328867]
 [0.05264958]
 [0.05451255]
 [0.05720855]
 [0.02522838]
 [0.0467067 ]
 [0.03620232]
 [0.06178212]
 [0.05691405]]
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[0 0 0 0 0 0 0 0 0 0]

 Confusion matrix: 
 [[39684  1443]
 [    0     0]]
F1 Score: 0.0
Accuracy: 0.9649135604347508
Precision: 0.0
Recall: 0.0
ROC AUC: 0.5
Epoch 15 | Test Loss 0.1497500517968998


100%|██████████| 1286/1286 [00:49<00:00, 26.02it/s]



 Confusion matrix: 
 [[39684  1443]
 [    0     0]]
F1 Score: 0.0
Accuracy: 0.9649135604347508
Precision: 0.0
Recall: 0.0
ROC AUC: 0.5
Epoch 16 | Train Loss 0.15152271837285758


100%|██████████| 1286/1286 [00:49<00:00, 25.84it/s]



 Confusion matrix: 
 [[39684  1443]
 [    0     0]]
F1 Score: 0.0
Accuracy: 0.9649135604347508
Precision: 0.0
Recall: 0.0
ROC AUC: 0.5
Epoch 17 | Train Loss 0.15164173137495066


100%|██████████| 1286/1286 [00:49<00:00, 25.98it/s]



 Confusion matrix: 
 [[39684  1443]
 [    0     0]]
F1 Score: 0.0
Accuracy: 0.9649135604347508
Precision: 0.0
Recall: 0.0
ROC AUC: 0.5
Epoch 18 | Train Loss 0.15112119193824694


100%|██████████| 1286/1286 [00:49<00:00, 26.02it/s]



 Confusion matrix: 
 [[39684  1443]
 [    0     0]]
F1 Score: 0.0
Accuracy: 0.9649135604347508
Precision: 0.0
Recall: 0.0
ROC AUC: 0.5
Epoch 19 | Train Loss 0.1517655519423876


100%|██████████| 1286/1286 [00:50<00:00, 25.64it/s]



 Confusion matrix: 
 [[39684  1443]
 [    0     0]]
F1 Score: 0.0
Accuracy: 0.9649135604347508
Precision: 0.0
Recall: 0.0
ROC AUC: 0.5
Epoch 20 | Train Loss 0.1506289941306526
[[0.04463646]
 [0.04594197]
 [0.03922605]
 [0.02775345]
 [0.02122546]
 [0.05029355]
 [0.02276346]
 [0.0455997 ]
 [0.03000642]
 [0.04767916]]
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[0 0 0 0 0 0 0 0 0 0]

 Confusion matrix: 
 [[39684  1443]
 [    0     0]]
F1 Score: 0.0
Accuracy: 0.9649135604347508
Precision: 0.0
Recall: 0.0
ROC AUC: 0.5
Epoch 20 | Test Loss 0.14862840756495854


100%|██████████| 1286/1286 [00:49<00:00, 26.14it/s]



 Confusion matrix: 
 [[39684  1443]
 [    0     0]]
F1 Score: 0.0
Accuracy: 0.9649135604347508
Precision: 0.0
Recall: 0.0
ROC AUC: 0.5
Epoch 21 | Train Loss 0.15075745684432204


100%|██████████| 1286/1286 [00:48<00:00, 26.79it/s]



 Confusion matrix: 
 [[39684  1443]
 [    0     0]]
F1 Score: 0.0
Accuracy: 0.9649135604347508
Precision: 0.0
Recall: 0.0
ROC AUC: 0.5
Epoch 22 | Train Loss 0.14995835891535178


100%|██████████| 1286/1286 [00:48<00:00, 26.76it/s]



 Confusion matrix: 
 [[39684  1443]
 [    0     0]]
F1 Score: 0.0
Accuracy: 0.9649135604347508
Precision: 0.0
Recall: 0.0
ROC AUC: 0.5
Epoch 23 | Train Loss 0.15033724553257366


100%|██████████| 1286/1286 [00:48<00:00, 26.62it/s]



 Confusion matrix: 
 [[39684  1443]
 [    0     0]]
F1 Score: 0.0
Accuracy: 0.9649135604347508
Precision: 0.0
Recall: 0.0
ROC AUC: 0.5
Epoch 24 | Train Loss 0.14990615892234152


100%|██████████| 1286/1286 [00:48<00:00, 26.68it/s]



 Confusion matrix: 
 [[39684  1443]
 [    0     0]]
F1 Score: 0.0
Accuracy: 0.9649135604347508
Precision: 0.0
Recall: 0.0
ROC AUC: 0.5
Epoch 25 | Train Loss 0.14982259248326857
[[0.03102799]
 [0.03530709]
 [0.03181672]
 [0.03400755]
 [0.04529916]
 [0.02514854]
 [0.05344829]
 [0.02659994]
 [0.04161073]
 [0.05155361]]
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[0 0 0 0 0 0 0 0 0 0]

 Confusion matrix: 
 [[39684  1443]
 [    0     0]]
F1 Score: 0.0
Accuracy: 0.9649135604347508
Precision: 0.0
Recall: 0.0
ROC AUC: 0.5
Epoch 25 | Test Loss 0.14735667392738963


100%|██████████| 1286/1286 [00:48<00:00, 26.62it/s]



 Confusion matrix: 
 [[39684  1443]
 [    0     0]]
F1 Score: 0.0
Accuracy: 0.9649135604347508
Precision: 0.0
Recall: 0.0
ROC AUC: 0.5
Epoch 26 | Train Loss 0.149148637617172


100%|██████████| 1286/1286 [00:48<00:00, 26.45it/s]



 Confusion matrix: 
 [[39684  1443]
 [    0     0]]
F1 Score: 0.0
Accuracy: 0.9649135604347508
Precision: 0.0
Recall: 0.0
ROC AUC: 0.5
Epoch 27 | Train Loss 0.14977711243308722


100%|██████████| 1286/1286 [00:48<00:00, 26.38it/s]



 Confusion matrix: 
 [[39684  1443]
 [    0     0]]
F1 Score: 0.0
Accuracy: 0.9649135604347508
Precision: 0.0
Recall: 0.0
ROC AUC: 0.5
Epoch 28 | Train Loss 0.1489118650862672


100%|██████████| 1286/1286 [00:49<00:00, 26.24it/s]



 Confusion matrix: 
 [[39684  1443]
 [    0     0]]
F1 Score: 0.0
Accuracy: 0.9649135604347508
Precision: 0.0
Recall: 0.0
ROC AUC: 0.5
Epoch 29 | Train Loss 0.149187452167064


100%|██████████| 1286/1286 [00:49<00:00, 26.08it/s]



 Confusion matrix: 
 [[39684  1443]
 [    0     0]]
F1 Score: 0.0
Accuracy: 0.9649135604347508
Precision: 0.0
Recall: 0.0
ROC AUC: 0.5
Epoch 30 | Train Loss 0.14880700719804962
[[0.0347745 ]
 [0.02153724]
 [0.03579311]
 [0.03274637]
 [0.03384675]
 [0.04786965]
 [0.04062903]
 [0.0345025 ]
 [0.02323573]
 [0.04159446]]
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[0 0 0 0 0 0 0 0 0 1]

 Confusion matrix: 
 [[39684  1443]
 [    0     0]]
F1 Score: 0.0
Accuracy: 0.9649135604347508
Precision: 0.0
Recall: 0.0
ROC AUC: 0.5
Epoch 30 | Test Loss 0.1473059817094176


100%|██████████| 1286/1286 [00:48<00:00, 26.55it/s]



 Confusion matrix: 
 [[39684  1443]
 [    0     0]]
F1 Score: 0.0
Accuracy: 0.9649135604347508
Precision: 0.0
Recall: 0.0
ROC AUC: 0.5
Epoch 31 | Train Loss 0.1490126855607676


100%|██████████| 1286/1286 [15:42<00:00,  1.36it/s]  



 Confusion matrix: 
 [[39684  1443]
 [    0     0]]
F1 Score: 0.0
Accuracy: 0.9649135604347508
Precision: 0.0
Recall: 0.0
ROC AUC: 0.5
Epoch 32 | Train Loss 0.148993801986031


 37%|███▋      | 476/1286 [00:18<00:32, 24.85it/s]