In [1]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

from lstm_chem.utils.config import process_config
from lstm_chem.model import LSTMChem
from lstm_chem.generator import LSTMChemGenerator

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [2]:
CONFIG_FILE = 'experiments/2020-03-24/LSTM_Chem/config.json'
config = process_config(CONFIG_FILE)

In [3]:
modeler = LSTMChem(config, session='generate')

Loading model architecture from experiments/2020-03-24/LSTM_Chem/model_arch.json ...
Instructions for updating:
Colocations handled automatically by placer.
Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.
Loading model checkpoint from experiments/2020-03-24/LSTM_Chem/checkpoints/LSTM_Chem-22-0.45.hdf5 ...
Loaded the Model.


In [7]:
# It took more than 4 hours!
generator = LSTMChemGenerator(modeler)
sampled_smiles = generator.sample(num=100)


  0%|                                                                                          | 0/100 [00:00<?, ?it/s][A
  1%|▊                                                                                 | 1/100 [00:00<00:38,  2.56it/s][A
  2%|█▋                                                                                | 2/100 [00:00<00:40,  2.39it/s][A
  3%|██▍                                                                               | 3/100 [00:01<00:42,  2.31it/s][A
  4%|███▎                                                                              | 4/100 [00:01<00:40,  2.38it/s][A
  5%|████                                                                              | 5/100 [00:02<00:39,  2.42it/s][A
  6%|████▉                                                                             | 6/100 [00:02<00:40,  2.29it/s][A
  7%|█████▋                                                                            | 7/100 [00:03<00:40,  2.28it/s][A
  8%|██████▌   

 66%|█████████████████████████████████████████████████████▍                           | 66/100 [00:28<00:13,  2.60it/s][A
 67%|██████████████████████████████████████████████████████▎                          | 67/100 [00:28<00:13,  2.45it/s][A
 68%|███████████████████████████████████████████████████████                          | 68/100 [00:29<00:12,  2.51it/s][A
 69%|███████████████████████████████████████████████████████▉                         | 69/100 [00:29<00:11,  2.63it/s][A
 70%|████████████████████████████████████████████████████████▋                        | 70/100 [00:30<00:12,  2.34it/s][A
 71%|█████████████████████████████████████████████████████████▌                       | 71/100 [00:30<00:13,  2.10it/s][A
 72%|██████████████████████████████████████████████████████████▎                      | 72/100 [00:30<00:12,  2.31it/s][A
 73%|███████████████████████████████████████████████████████████▏                     | 73/100 [00:31<00:10,  2.48it/s][A
 74%|███████████

In [8]:
from rdkit import RDLogger, Chem, DataStructs
from rdkit.Chem import AllChem, Descriptors
RDLogger.DisableLog('rdApp.*')

In [9]:
valid_mols = []
for smi in sampled_smiles:
    mol = Chem.MolFromSmiles(smi)
    if mol is not None:
        valid_mols.append(mol)
# low validity
print(f'{len(valid_mols) / 30000:.2%}')

0.19%


In [10]:
valid_smiles = [Chem.MolToSmiles(mol) for mol in valid_mols]
# high uniqueness
print(f'{len(set(valid_smiles)) / len(valid_smiles):.2%}')

100.00%


In [11]:
with open('./datasets/dataset_cleansed.smi') as f:
    org_smiles = [l.rstrip() for l in f]

org_mols = [mol for mol in [Chem.MolFromSmiles(smi) for smi in org_smiles] if mol is not None]

In [12]:
Vfps = []
for mol in valid_mols:
    bv = AllChem.GetMACCSKeysFingerprint(mol)
    fp = np.zeros(len(bv))
    DataStructs.ConvertToNumpyArray(bv, fp)
    Vfps.append(fp)

Ofps = []
for mol in org_mols:
    bv = AllChem.GetMACCSKeysFingerprint(mol)
    fp = np.zeros(len(bv))
    DataStructs.ConvertToNumpyArray(bv, fp)
    Ofps.append(fp)

KeyboardInterrupt: 

In [None]:
from sklearn.decomposition import PCA
Vlen = len(Vfps)
x = Vfps + Ofps
pca = PCA(n_components=2, random_state=71)
X = pca.fit_transform(x)

In [None]:
plt.figure(figsize=(12, 9))
plt.scatter(X[Vlen:, 0], X[Vlen:, 1], c='w', edgecolors='k', label='original')
plt.scatter(X[:Vlen, 0], X[:Vlen, 1], marker='+', label='generated')
plt.xlabel('PC 1')
plt.ylabel('PC 2')
plt.legend();

In [None]:
props = {
    'MW': {
        'generated': [Descriptors.ExactMolWt(mol) for mol in valid_mols],
        'original': [Descriptors.ExactMolWt(mol) for mol in org_mols]        
    },
    'logP': {
        'generated': [Descriptors.MolLogP(mol) for mol in valid_mols],
        'original': [Descriptors.MolLogP(mol) for mol in org_mols]
    }
}

In [None]:
fig = plt.figure(figsize=(12, 6))

ax1 = fig.add_subplot(1, 2, 1)
ax1.violinplot([props['MW']['original'], props['MW']['generated']])
ax1.set_xticks(ticks=[1, 2])
ax1.set_xticklabels(labels=['original', 'generated'])
ax1.set_title('MW')

ax2 = fig.add_subplot(1, 2, 2)
ax2.violinplot([props['logP']['original'], props['logP']['generated']])
ax2.set_xticks(ticks=[1, 2])
ax2.set_xticklabels(labels=['original', 'generated'])
ax2.set_title('clogP')

plt.tight_layout()