In [1]:
import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"

from src.model import D_GAT
from src.mol_processing import Read_mol_data, Generate_dataloader
from src.train_eval import Train_eval, Test_NN

You need to define:
1. The dataset to fine-tune on: `dataset`
2. The task name(s) in the data file: `task_name`
3. The task type: `target_type` (`'classification'` or `'regression'`)
4. The evaluation metric: `metrics` (`'AUC'`, `'RMSE'`, or `'MAE'`)
5. The file name of the stored model: `store_name` (`dataset + '.pth'`, or `None`)

However, if you use one of the datasets from our paper (Tox21, SIDER, MUV, HIV, BBBP, BACE, ClinTox, ToxCast, PCBA, ESOL, FreeSolv, Lipo, QM7, QM8, QM9), you only need to define `dataset`; the other settings are filled in automatically.

In [2]:
dataset = 'FreeSolv' # Tox21 SIDER MUV HIV BBBP BACE ClinTox ToxCast PCBA ESOL FreeSolv Lipo QM7 QM8 QM9
config_file_path = './config/config.json'


def default_task_settings(dataset_name):
    """Return the built-in task settings for a benchmark dataset from the paper.

    Parameters
    ----------
    dataset_name : str
        Dataset name, matched exactly (case-sensitive) against the benchmark lists below.

    Returns
    -------
    tuple or None
        ``(target_type, metrics, task_name)`` for a known benchmark dataset, or ``None``
        when the dataset is not one of them (its settings must then be defined by hand).
    """
    if dataset_name in ['HIV', 'BBBP', 'Tox21', 'SIDER', 'MUV', 'BACE', 'ClinTox', 'ToxCast', 'PCBA']:
        return 'classification', 'AUC', None
    # Spelled exactly as in the dataset list above ('ESOL'), otherwise ESOL never matches.
    if dataset_name in ['ESOL', 'Lipo', 'FreeSolv']:
        return 'regression', 'RMSE', None
    if dataset_name in ['QM7', 'QM8', 'QM9']:
        return 'regression', 'MAE', None
    return None


task_settings = default_task_settings(dataset)
if task_settings is not None:
    target_type, metrics, task_name = task_settings
else:
    # Custom dataset: replace the `raise` below with your own settings, e.g.
    #     target_type = 'classification'   # or 'regression'
    #     metrics = 'AUC'                  # 'AUC' for classification; 'RMSE' or 'MAE' for regression
    #     task_name = ['u0_atom']          # for QM7
    # Raising (rather than only printing) stops later cells from silently reusing
    # target_type / metrics / task_name left over from a previous run of this cell.
    raise ValueError('Please define the target type, task_name and metrics!')

store_name = dataset + '.pth'
# store_name = None

Next, load and process the data and load the pre-trained model. Nothing needs to be defined here.

In [3]:
# Validate the task settings from the previous cell before any heavy work:
# classification is scored with AUC, regression with RMSE or MAE.
assert target_type in ['classification', 'regression']
assert metrics in {'classification': ['AUC'], 'regression': ['RMSE', 'MAE']}[target_type]

# Load the train/val/test splits. `mean` and `std` are handed on to training and
# evaluation (presumably target normalization statistics - confirm in src.mol_processing).
mol_train, mol_val, mol_test, mean, std = Read_mol_data(dataset, task_name, target_type)

# Wrap each split in its own data loader.
train_dataloader, val_dataloader, test_dataloader = Generate_dataloader(
    dataset, mol_train, mol_val, mol_test
)

# Build the model; per the run log, D_GAT loads ./model/PreTraining.pth.
# best_score is presumably the score to beat during fine-tuning - confirm in src.model.
model, best_score = D_GAT(dataset, mol_train, config_file_path)


Read and process the collected data...
----------------------------------------
Dataset:  FreeSolv
Example: 
iupac     methanesulfonyl chloride
smiles                CS(=O)(=O)Cl
expt                         -4.87
calc                        -6.219
Name: 0, dtype: object
Number of molecules: 642
1 / 1  finished!
Training dataset finished
Val dataset finished
Test dataset finished
Load PreTraining mdoel:  ./model/PreTraining.pth
Model prepared!


If you have loaded a fine-tuned model, the next cell can be used to evaluate its performance.

In [4]:
# ## Evaluate a fine-tuned model on the test set (optional).
# Uncomment the lines below after loading a fine-tuned model. The loss is always
# printed; the mean AUC is printed only for classification tasks.
# Loss, auc = Test_NN(dataset, model, test_dataloader, metrics, target_type, mean, std)
# print('Mean Loss: ', Loss)
# if target_type == 'classification':
#     print('Mean AUC: ', auc.mean())

Now fine-tune the model. This may take some time.

In [5]:
# Fine-tune the model. Train_eval returns the updated model and best score; the run
# log shows it saving each new best model (presumably to store_name - confirm in src.train_eval).
model, best_score = Train_eval(
    dataset, model,
    train_dataloader, val_dataloader, test_dataloader,
    best_score, config_file_path, store_name,
    metrics, target_type, mean, std,
)


For first training part:
New best model saved!
|  1/20 epochs | lr 1.0e-04 |  1 s | Train 3.45316 | Val 6.18611 | Test 4.52748
New best model saved!
|  2/20 epochs | lr 9.9e-05 |  1 s | Train 2.60138 | Val 4.88129 | Test 3.30626
New best model saved!
|  3/20 epochs | lr 9.9e-05 |  1 s | Train 2.16218 | Val 4.37161 | Test 3.07191
New best model saved!
|  4/20 epochs | lr 9.9e-05 |  1 s | Train 1.93665 | Val 3.90612 | Test 2.75351
New best model saved!
|  5/20 epochs | lr 9.8e-05 |  1 s | Train 1.80044 | Val 3.72002 | Test 2.67885
New best model saved!
|  6/20 epochs | lr 9.8e-05 |  2 s | Train 1.71770 | Val 3.44444 | Test 2.42028
New best model saved!
|  7/20 epochs | lr 9.8e-05 |  1 s | Train 1.68629 | Val 3.44088 | Test 2.43212
New best model saved!
|  8/20 epochs | lr 9.7e-05 |  1 s | Train 1.65436 | Val 3.28742 | Test 2.28490
New best model saved!
|  9/20 epochs | lr 9.7e-05 |  1 s | Train 1.61782 | Val 3.24079 | Test 2.28508
New best model saved!
| 10/20 epochs | lr 9.7e-05 |  1 s 