In [1]:
import pickle
import gzip
import argparse
import os
import sys
from smiles_rnn_distribution_learner import SmilesRnnDistributionLearner

from atalaya import Logger
# Experiment logger used by the later cells (parameter logging and training).
# The constructor options are gathered in one mapping so the logging setup can
# be read at a glance and then handed to Logger in a single call.
# NOTE(review): the run name spells "batchszie"; it is kept as-is because it is
# a runtime identifier — confirm nothing depends on it before renaming.
logger_settings = {
    "name": "QM9_batchszie_200_lstm",      # run name given to the logger
    "path": "logs",                        # where the logger is told to write
    "verbose": True,                       # verbose mode switched on
    "grapher": "visdom",                   # plotting backend selector
    "server": "http://send2.visdom.xyz",   # server passed to the grapher
    "port": 8999,                          # port passed to the grapher
}
graph_logger = Logger(**logger_settings)

sys.path.append('../../')
from guacamol.utils.helpers import setup_default_logger




In [2]:
# Bare expression: the cell echoes the logger's repr, confirming it was created.
graph_logger

<atalaya.logger.Logger at 0x7f373d43c5f8>

In [3]:
def _build_parser():
    """Create the argument parser for the SMILES RNN distribution-learning run.

    Options are declared in a table of ``(flag, add_argument keywords)`` pairs
    and registered in table order.

    Returns:
        argparse.ArgumentParser: parser whose help text also shows defaults.
    """
    cli = argparse.ArgumentParser(
        description='Distribution learning benchmark for SMILES RNN',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    option_table = [
        ('--data_path', dict(default='../../data/QM9')),
        ('--train_data', dict(default='../../data/QM9/QM9_clean_smi_train_smile.npz',
                              help='Full path to SMILES file containing training data')),
        ('--valid_data', dict(default='',
                              help='Full path to SMILES file containing validation data')),
        ('--batch_size', dict(default=200, type=int,
                              help='Size of a mini-batch for gradient descent')),
        ('--valid_every', dict(default=1000, type=int, help='Validate every so many batches')),
        ('--print_every', dict(default=10, type=int, help='Report every so many batches')),
        ('--n_epochs', dict(default=100, type=int, help='Number of training epochs')),
        ('--max_len', dict(default=100, type=int, help='Max length of a SMILES string')),
        ('--hidden_size', dict(default=512, type=int, help='Size of hidden layer')),
        ('--n_layers', dict(default=3, type=int, help='Number of layers for training')),
        ('--rnn_dropout', dict(default=0.2, type=float, help='Dropout value for RNN')),
        ('--lr', dict(default=1e-3, type=float, help='RNN learning rate')),
        ('--seed', dict(default=42, type=int, help='Random seed')),
        # Disabled option, kept for reference:
        # ('--prop_model', dict(default="../../data/QM9/prior.pkl.gz",
        #                       help='Saved model for properties distribution')),
        ('--output_dir', dict(default='./output/QM9/', help='Output directory')),
    ]
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli


parser = _build_parser()
# parse_known_args() returns (namespace, unrecognised argv). The unrecognised
# part is discarded, so extra arguments on the process command line (such as
# those a notebook kernel may be launched with) do not abort parsing.
args, _ = parser.parse_known_args()


In [4]:
# Resolve the output directory, make sure it exists, then record the run's
# parameters with the experiment logger.
if args.output_dir is None:
    try:
        # When run as a plain script, default to the script's own directory.
        args.output_dir = os.path.dirname(os.path.realpath(__file__))
    except NameError:
        # A notebook kernel defines no __file__; fall back to the current
        # working directory instead of crashing with NameError.
        args.output_dir = os.getcwd()

# exist_ok=True replaces the separate exists() check: it tolerates a directory
# that is already there and avoids the check-then-create race.
os.makedirs(args.output_dir, exist_ok=True)

graph_logger.add_parameters(args)





In [5]:
# Names of the parsed command-line options that are handed to the learner
# unchanged, each under a keyword of the same name.
_forwarded = ('output_dir', 'n_epochs', 'hidden_size', 'n_layers', 'max_len',
              'batch_size', 'rnn_dropout', 'lr', 'valid_every')

trainer = SmilesRnnDistributionLearner(
    data_set="QM9",
    graph_logger=graph_logger,
    # prop_model=prop_model,   (disabled, kept for reference)
    **{option: getattr(args, option) for option in _forwarded},
)

# The training/validation file paths are passed to train() as given. Code that
# opened both files here and read them with readlines() had been commented out
# and is no longer kept in this cell.
trainer.train(args.data_path,
              training_set=args.train_data,
              validation_set=args.valid_data)
print(f'All done, your trained model is in {args.output_dir}')




CUDA enabled:	True
EPOCH 1
VALID | elapsed: 0:00:05 | epoch: 1/100 (0.0%) | molecules: 200 | valid_loss: 3.6126

model_1_3.613
TRAIN | elapsed: 0:00:08 | epoch|batch : 1|10 (0.0%) | molecules: 2200 | mols/sec: 272.09 | train_loss: 2.8474
TRAIN | elapsed: 0:00:11 | epoch|batch : 1|20 (0.0%) | molecules: 4200 | mols/sec: 380.64 | train_loss: 2.1033
TRAIN | elapsed: 0:00:14 | epoch|batch : 1|30 (0.1%) | molecules: 6200 | mols/sec: 437.04 | train_loss: 1.8999
TRAIN | elapsed: 0:00:17 | epoch|batch : 1|40 (0.1%) | molecules: 8200 | mols/sec: 474.78 | train_loss: 1.8150
TRAIN | elapsed: 0:00:20 | epoch|batch : 1|50 (0.1%) | molecules: 10200 | mols/sec: 492.70 | train_loss: 1.7339
TRAIN | elapsed: 0:00:24 | epoch|batch : 1|60 (0.1%) | molecules: 12200 | mols/sec: 508.06 | train_loss: 1.7188
TRAIN | elapsed: 0:00:27 | epoch|batch : 1|70 (0.1%) | molecules: 14200 | mols/sec: 524.56 | train_loss: 1.7153
TRAIN | elapsed: 0:00:30 | epoch|batch : 1|80 (0.1%) | molecules: 16200 | mols/sec: 538.60 | 

KeyboardInterrupt: 