## Molecule odor prediction via GINs with self-supervised contrastive pretraining



In [1]:
# toggle whether or not to use pretraining
use_pretraining = False

In [2]:
import torch
print("PyTorch has version {}".format(torch.__version__))
import torch_geometric
torch_geometric.__version__
import numpy as np

from torch_geometric.loader import DataLoader

import matplotlib.pyplot as plt
from sklearn.metrics import f1_score, roc_auc_score, precision_score, recall_score
from tqdm import tqdm


import sys
sys.path.append('olf/GNNose/')
sys.path.append("GraphSSL/")


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

seed = 12345

torch_geometric.seed.seed_everything(seed)

PyTorch has version 1.13.1+cu117




In [13]:
# initialize model and finetune
from odor_model import ScentClassifier
from odor_train import train, test
from torch_ema import ExponentialMovingAverage

def run_finetuning(model, device, num_epochs, lr, weight_decay):
    """Train or finetune the given model using the train/val sets"""

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    ema = ExponentialMovingAverage(model.parameters(), decay=0.995)

    # track metrics
    roc_scores, f1_scores, accs, losses, val_losses, precision_scores, recall_scores = [], [], [], [], [], [], []

    for epoch in tqdm(range(num_epochs)):
        loss = train(model, optimizer, train_loader, 'train', device, weighted_BCE=True, ema=ema)
        val_loss = train(model, optimizer, val_loader, 'val', device, weighted_BCE=True)
        train_acc, train_preds, train_true  = test(model, train_loader, device)
        test_acc, test_preds, test_true = test(model, val_loader, device)

        # calculate bootstrapped ROC AUC score over entire val set
        _, whole_val_preds, whole_val_true = test(
            model, DataLoader(val_set, batch_size=len(val_set), shuffle=True), device)
        whole_val_preds = whole_val_preds.squeeze()
        whole_val_true = whole_val_true.squeeze()
        rocauc_score = 0
        # roc_auc_score(whole_val_true.cpu(), whole_val_preds.cpu())
        f1 = f1_score(whole_val_true.cpu(), whole_val_preds.cpu())
        precision = precision_score(whole_val_true.cpu(), whole_val_preds.cpu())
        recall = recall_score(whole_val_true.cpu(), whole_val_preds.cpu())


        # track metrics
        roc_scores.append(rocauc_score)
        f1_scores.append(f1)
        accs.append(test_acc)
        losses.append(loss)
        val_losses.append(val_loss)
        precision_scores.append(precision)
        recall_scores.append(recall)


    best_f1_score = max(f1_scores)
    best_f1_epoch = f1_scores.index(best_f1_score)
    best_auc_score = max(roc_scores)
    best_auc_epoch = roc_scores.index(best_auc_score)
    best_precison_score = max(precision_scores)
    best_recall_score = max(recall_scores)

    # plot_losses(losses, val_losses, title='finetuning: train vs. val loss')
    # fig, ax = plt.subplots(1, 2, figsize=(6, 2))
    # for i, (name, metric) in enumerate([
    #     ('f1', f1_scores), ('roc', roc_scores)
    # ]):
    #     ax[i].plot(range(len(metric)), metric)
    #     ax[i].set_ylim((0, 1))
    #     ax[i].set_title(name)
    # plt.show()

    return best_f1_score, best_f1_epoch, best_auc_score, best_auc_epoch, best_precison_score, best_recall_score


In [14]:
# from odor_data import get_graph_data
# graph_list = get_graph_data('ol','alcoholic') # 'ol' for Dream Olfaction dataset
# graph_list


In [15]:
# train_set, val_set, test_set = torch.utils.data.random_split(graph_list, [0.9, 0.05, 0.05])
# print(type(train_set))

In [21]:
# load and split the dataset
# 'alcoholic',
#  'aldehydic',
#  'alliaceous',
#  'almond',
#  'ambergris',
#  'ambery',

 # 'animalic',
 # 'anisic',
 # 'apple',
 # 'balsamic',
 # 'banana',
 # 'berry',
 # 'blackcurrant',
 # 'blueberry',
 # 'body',
 # 'bread',
 # 'burnt',
 # 'butter',
 # 'cacao',
 # 'camphor',
 # 'caramellic',
 # 'cedar',
 # 'cheese',
 # 'chemical',
 # 'cherry',
 # 'cinnamon',
 # 'citrus',
 # 'clean',
 # 'clove',
 # 'coconut',

 # 'cucumber',
 # 'dairy',
 # 'dry',
 # 'earthy',
 # 'ester',
 # 'ethereal',
 # 'fatty',
 # 'fermented',
 # 'floral',
 # 'fresh',
 # 'fruity',
 # 'geranium',
# 'grape',
#  'grapefruit',
#  'grass',
#  'green',
#  'herbal',
#  'honey',

# 'jasmin',
#  'lactonic',
#  'leaf',
#  'leather',
#  'lemon',
#  'lily',
#  'liquor',
#  'meat',
#  'medicinal',
#  'melon',
#  'metallic',
#  'mint',
#  'mushroom',
#  'musk',
#  'musty',
#  'nut',
#  'odorless',
#  'oily',
#  'orange',
# 'pear',
#  'pepper',
#  'phenolic',
#  'plastic',
#  'plum',
#  'powdery',
#  'pungent',
#  'rancid',
#  'resinous',
#  'ripe',
#  'roasted',
#  'rose',
#  'seafood',
 # 'sulfuric',
 # 'sweet',
 # 'syrup',
# 'tobacco',
#  'tropicalfruit',
#  'vanilla',
#  'vegetable',
#  'violetflower',
# 'watery',
#  'waxy',
#  'whiteflower',
#  'wine',
#  'woody'


 # 'coffee',
 # 'cognac',
 # 'coniferous',
 # 'cooked',
 # 'cooling',

 # 'fennel',

 # 'gourmand',
 
 # 'hyacinth',
 
 # 'overripe',
 
 # 'sharp',
 # 'smoky',
 # 'sour',
 # 'spicy',

 # 'terpenic',
 # 'ambrette',
# 
results = []
odor_list = [

 'ammoniac',

 
 
]

# odor_list = ["pungent","rose"]
from odor_data import get_graph_data

for odor in odor_list:
    print(odor)
    graph_list = get_graph_data('ol',odor) # 'ol' for Dream Olfaction dataset
    
    
    
    
    train_set, val_set, test_set = torch.utils.data.random_split(graph_list, [0.7, 0.05, 0.25])
    print(train_set)
    batch_size = 128
    
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, drop_last=True)
    val_loader =  DataLoader(val_set, batch_size=batch_size, shuffle=True, drop_last=True)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=True, drop_last=True)
    
    print(train_set)
    hidden_channels = 128
    num_layers = 2
    dropout_p = 0.25
    pooling_type = 'max'
    
    
    # set hyperparameters
    in_channels = list(graph_list[2].x.shape)[-1]
    print(in_channels)
    out_channels = 1
    
    
    
    
    if use_pretraining:
        from odor_pretrain import build_pretraining_loader, pretrain
        from odor_model import PretrainingGIN
    
        pretrain_epochs = 50
        pretrain_batch_size = 256
        pretrain_lr = 1e-3
        pretrain_weight_decay = 1e-4
    
        # We apply this approach using the GraphSSL GitHub repository.
        pretrain_train_loader = build_pretraining_loader(train_set, "train", batch_size=pretrain_batch_size)
        pretrain_val_loader = build_pretraining_loader(val_set, "val", batch_size=pretrain_batch_size)
    
    
    
    # apply pretraining
    if use_pretraining:
        pretrain_model = PretrainingGIN(
            in_channels,
            hidden_channels,
            num_layers,
            out_channels,
            dropout=dropout_p
        ).to(device)
    
        pretrain_optimizer = torch.optim.Adam(pretrain_model.parameters(), lr=pretrain_lr, weight_decay=pretrain_weight_decay)
    
        val_losses = []
        train_losses = []
        for epoch in (range(pretrain_epochs)):
            train_loss = pretrain(pretrain_model, pretrain_optimizer, epoch, "train", pretrain_train_loader, device)
            val_loss = pretrain(pretrain_model, pretrain_optimizer, epoch, "val", pretrain_val_loader, device)
            log = "Epoch {}, Train Loss: {:.3f}, Val Loss: {:.3f}"
            print(log.format(epoch, train_loss, val_loss))
    
            train_losses.append(train_loss)
            val_losses.append(val_loss)
    
        # plot_losses(train_losses, val_losses, "pretraining loss")
    
    num_epochs = 50
    lr = 1e-4
    weight_decay = 1e-3
    
    
    num_trials = 1
    dropout_p = 0.1
    trial_results = []
    
    for t in range(num_trials):
        model = ScentClassifier(
            in_channels,
            hidden_channels,
            num_layers,
            out_channels,
            dropout=dropout_p,
            pooling_type=pooling_type,
        )
        model = model.to(device)
    
        if use_pretraining:
            model.gnn.load_state_dict(
                pretrain_model.gnn.state_dict()
            )
    
        best_f1_score, best_f1_epoch, best_auc_score, best_auc_epoch, best_precision_score, best_recall_score = run_finetuning(model, device, num_epochs, lr, weight_decay)
        trial_results.append((best_f1_score, best_f1_epoch, best_auc_score, best_auc_epoch, best_precision_score, best_recall_score))
    
    
    hyperparam_tuning = True
    
    if hyperparam_tuning:
        import itertools
    
        param_search = {
            "lr": [1e-2, 5e-3, 1e-3],
            "weight_decay": [1e-5, 1e-6],
        }
    
        param_search_results = []
    
        for lr, weight_decay in list(itertools.product(*list(param_search.values()))):
            # create model
            model = ScentClassifier(
                in_channels,
                hidden_channels,
                num_layers,
                out_channels,
                dropout=dropout_p,
                pooling_type=pooling_type,
            )
            model = model.to(device)
    
            # load pretrained weights, if relevant
            if use_pretraining:
                model.gnn.load_state_dict(
                    pretrain_model.gnn.state_dict()
                )
    
            # train
            best_f1_score, best_f1_epoch, best_auc_score, best_auc_epoch, best_precision_score, best_recall_score = run_finetuning(
                model, device, num_epochs, lr, weight_decay)
    
            param_search_results.append(
                ((best_f1_score, best_f1_epoch, best_auc_score, best_auc_epoch, best_precision_score, best_recall_score), (lr, weight_decay))
            )
    if hyperparam_tuning:
        # show best results, by F1 or AUC
        param_search_results.sort(key = lambda x: x[0][0])
        print("Best F1 score:", param_search_results[-1][0][0])
        # print("Best F1 epoch:", param_search_results[-1][0][1])
        print(
            "Best F1 params:",
            dict(zip(
                param_search.keys(),
                param_search_results[-1][-1]
            ))
        )
    
        param_search_results.sort(key = lambda x: x[0][2])
        print("Best AUC score:", param_search_results[-1][0][2])
        # print("Best AUC epoch:", param_search_results[-1][0][3])
        print(
            "Best AUC params:",
            dict(zip(
                param_search.keys(),
                param_search_results[-1][-1]
            ))
        )

        param_search_results.sort(key = lambda x: x[0][4])
        print("Best Precision score:", param_search_results[-1][0][4])
        # print("Best AUC epoch:", param_search_results[-1][0][3])
        print(
            "Best Precision params:",
            dict(zip(
                param_search.keys(),
                param_search_results[-1][-1]
            ))
        )

        param_search_results.sort(key = lambda x: x[0][5])
        print("Best Recall score:", param_search_results[-1][0][5])
        # print("Best AUC epoch:", param_search_results[-1][0][3])
        print(
            "Best Recall params:",
            dict(zip(
                param_search.keys(),
                param_search_results[-1][-1]
            ))
        )

        results.append((odor,param_search_results[-1][0][0],param_search_results[-1][0][2],param_search_results[-1][0][4],param_search_results[-1][0][5]))
        # torch.save(model.state_dict(), '/Users/dishant/Desktop/base_model_3_12_2023_{0}.pth'.format(odor))

ammoniac
<torch.utils.data.dataset.Subset object at 0x7f4c70112440>
<torch.utils.data.dataset.Subset object at 0x7f4c70112440>
12


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:25<00:00,  1.96it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(avera

Best F1 score: 0.8
Best F1 params: {'lr': 0.001, 'weight_decay': 1e-05}
Best AUC score: 0
Best AUC params: {'lr': 0.001, 'weight_decay': 1e-05}
Best Precision score: 0.6666666666666666
Best Precision params: {'lr': 0.001, 'weight_decay': 1e-05}
Best Recall score: 1.0
Best Recall params: {'lr': 0.001, 'weight_decay': 1e-05}



