
Commit

updating example notebook
pstjohn committed Oct 20, 2021
1 parent 545596e commit 88eaef1
Showing 1 changed file with 89 additions and 72 deletions.
161 changes: 89 additions & 72 deletions examples/creating_and_training_a_model.ipynb
@@ -9,8 +9,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
"tensorflow 2.3.1\n",
"nfp 0.2.0+0.g72982c1.dirty\n"
"tensorflow 2.5.0\n",
"nfp 0.3.5\n"
]
}
],
@@ -171,6 +171,9 @@
"outputs": [],
"source": [
"# Define how to featurize the input molecules\n",
"from nfp.preprocessing.mol_preprocessor import SmilesPreprocessor\n",
"from nfp.preprocessing.features import get_ring_size\n",
"\n",
"\n",
"def atom_featurizer(atom):\n",
" \"\"\" Return an string representing the atom type\n",
@@ -179,7 +182,7 @@
" return str((\n",
" atom.GetSymbol(),\n",
" atom.GetIsAromatic(),\n",
" nfp.get_ring_size(atom, max_size=6),\n",
" get_ring_size(atom, max_size=6),\n",
" atom.GetDegree(),\n",
" atom.GetTotalNumHs(includeNeighbors=True)\n",
" ))\n",
@@ -199,13 +202,13 @@
" bond.GetBeginAtom().GetSymbol())))\n",
" \n",
" btype = str(bond.GetBondType())\n",
" ring = 'R{}'.format(nfp.get_ring_size(bond, max_size=6)) if bond.IsInRing() else ''\n",
" ring = 'R{}'.format(get_ring_size(bond, max_size=6)) if bond.IsInRing() else ''\n",
" \n",
" return \" \".join([atoms, btype, ring]).strip()\n",
"\n",
"\n",
"preprocessor = nfp.SmilesPreprocessor(atom_features=atom_featurizer, bond_features=bond_featurizer,\n",
" explicit_hs=False)"
"preprocessor = SmilesPreprocessor(atom_features=atom_featurizer, bond_features=bond_featurizer,\n",
" explicit_hs=False)"
]
},
{
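This hunk migrates the preprocessor cell from the old top-level helpers (nfp.get_ring_size, nfp.SmilesPreprocessor) to the submodule imports introduced in nfp 0.3.x, nfp.preprocessing.features.get_ring_size and nfp.preprocessing.mol_preprocessor.SmilesPreprocessor. For reference, the updated cell is consolidated below into one runnable sketch; the signature and un-flipped branch of bond_featurizer are collapsed in this view and are reconstructed by symmetry with the flipped branch that is visible, so treat them as assumptions.

# Consolidated sketch of the updated preprocessor cell (nfp 0.3.x import paths
# taken from the hunk above).
from nfp.preprocessing.mol_preprocessor import SmilesPreprocessor
from nfp.preprocessing.features import get_ring_size


def atom_featurizer(atom):
    """Return a string token describing the atom's local environment."""
    return str((
        atom.GetSymbol(),
        atom.GetIsAromatic(),
        get_ring_size(atom, max_size=6),  # module-level helper replaces nfp.get_ring_size
        atom.GetDegree(),
        atom.GetTotalNumHs(includeNeighbors=True),
    ))


def bond_featurizer(bond, flipped=False):
    """Return a string token for the bond, optionally with its atoms flipped."""
    if not flipped:  # reconstructed branch; only the flipped branch is visible in the diff
        atoms = "{}-{}".format(*tuple((bond.GetBeginAtom().GetSymbol(),
                                       bond.GetEndAtom().GetSymbol())))
    else:
        atoms = "{}-{}".format(*tuple((bond.GetEndAtom().GetSymbol(),
                                       bond.GetBeginAtom().GetSymbol())))

    btype = str(bond.GetBondType())
    ring = 'R{}'.format(get_ring_size(bond, max_size=6)) if bond.IsInRing() else ''
    return " ".join([atoms, btype, ring]).strip()


preprocessor = SmilesPreprocessor(atom_features=atom_featurizer,
                                  bond_features=bond_featurizer,
                                  explicit_hs=False)
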
@@ -247,7 +250,7 @@
{
"data": {
"text/plain": [
"array([ 2, 3, 14])"
"array([ 2, 3, 14], dtype=int32)"
]
},
"execution_count": 6,
@@ -271,7 +274,7 @@
{
"data": {
"text/plain": [
"array([2, 2, 5, 6])"
"array([2, 2, 5, 6], dtype=int32)"
]
},
"execution_count": 7,
@@ -324,32 +327,63 @@
"train_dataset = tf.data.Dataset.from_generator(\n",
" lambda: ((preprocessor(row.SMILES, train=False), row.YSI)\n",
" for i, row in ysi[ysi.SMILES.isin(train)].iterrows()),\n",
" output_types=(preprocessor.output_types, tf.float32),\n",
" output_shapes=(preprocessor.output_shapes, []))\\\n",
" output_signature=(preprocessor.output_signature, tf.TensorSpec((), dtype=tf.float32)))\\\n",
" .cache().shuffle(buffer_size=200)\\\n",
" .padded_batch(batch_size=64, \n",
" padded_shapes=(preprocessor.padded_shapes(), []),\n",
" padding_values=(preprocessor.padding_values, 0.))\\\n",
" .padded_batch(batch_size=64)\\\n",
" .prefetch(tf.data.experimental.AUTOTUNE)\n",
"\n",
"\n",
"valid_dataset = tf.data.Dataset.from_generator(\n",
" lambda: ((preprocessor(row.SMILES, train=False), row.YSI)\n",
" for i, row in ysi[ysi.SMILES.isin(valid)].iterrows()),\n",
" output_types=(preprocessor.output_types, tf.float32),\n",
" output_shapes=(preprocessor.output_shapes, []))\\\n",
" output_signature=(preprocessor.output_signature, tf.TensorSpec((), dtype=tf.float32)))\\\n",
" .cache()\\\n",
" .padded_batch(batch_size=64, \n",
" padded_shapes=(preprocessor.padded_shapes(), []),\n",
" padding_values=(preprocessor.padding_values, 0.))\\\n",
" .padded_batch(batch_size=64)\\\n",
" .prefetch(tf.data.experimental.AUTOTUNE)"
]
},
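The substantive change in the dataset cells is the move from the old output_types/output_shapes pair, plus explicit padded_shapes and padding_values, to a single output_signature and a bare padded_batch(batch_size=64). With TensorFlow 2.4 and later, Dataset.from_generator accepts an output_signature of tf.TensorSpec objects, and padded_batch can infer the padded shapes from the element spec, padding each variable dimension to the longest element in the batch with a default fill value of zero. Below is a minimal sketch of the same pattern, assuming a fitted preprocessor and a dataframe df with SMILES and YSI columns (column names taken from the diff).

import tensorflow as tf

# Labeled pipeline in the new style: one signature for (inputs_dict, label),
# then let padded_batch infer padding shapes and use the default pad value of 0.
dataset = (
    tf.data.Dataset.from_generator(
        lambda: ((preprocessor(row.SMILES, train=False), row.YSI)
                 for _, row in df.iterrows()),
        output_signature=(preprocessor.output_signature,
                          tf.TensorSpec((), dtype=tf.float32)))
    .cache()
    .shuffle(buffer_size=200)
    .padded_batch(batch_size=64)
    .prefetch(tf.data.experimental.AUTOTUNE)
)
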
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"array([[ 2, 3, 12, ..., 0, 0, 0],\n",
" [ 7, 8, 3, ..., 0, 0, 0],\n",
" [ 7, 4, 2, ..., 0, 0, 0],\n",
" ...,\n",
" [ 7, 8, 12, ..., 0, 0, 0],\n",
" [ 2, 3, 9, ..., 0, 0, 0],\n",
" [ 2, 8, 8, ..., 0, 0, 0]], dtype=int32)"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"inputs, outputs = next(train_dataset.as_numpy_iterator())\n",
"inputs['atom']"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"WARNING:tensorflow:From /Users/pstjohn/mambaforge/envs/rlmol/lib/python3.8/site-packages/tensorflow/python/ops/array_ops.py:5043: calling gather (from tensorflow.python.ops.array_ops) with validate_indices is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"The `validate_indices` argument has no effect. Indices are always validated on CPU and never validated on GPU.\n"
]
}
],
"source": [
"## Define the keras model\n",
"from tensorflow.keras import layers\n",
@@ -397,86 +431,72 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/25\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/pstjohn/miniconda3/envs/nfp/lib/python3.7/site-packages/tensorflow/python/keras/engine/functional.py:543: UserWarning: Input dict contained keys ['n_atom', 'n_bond', 'bond_indices'] which did not match any model input. They will be ignored by the model.\n",
" [n for n in tensors.keys() if n not in ref_input_names])\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"6/6 [==============================] - 1s 201ms/step - loss: 192.5282 - val_loss: 156.3338\n",
"Epoch 1/25\n",
"6/6 [==============================] - 5s 245ms/step - loss: 192.0872 - val_loss: 155.4040\n",
"Epoch 2/25\n",
"6/6 [==============================] - 0s 10ms/step - loss: 191.6269 - val_loss: 154.2021\n",
"6/6 [==============================] - 0s 13ms/step - loss: 190.3157 - val_loss: 152.5118\n",
"Epoch 3/25\n",
"6/6 [==============================] - 0s 11ms/step - loss: 187.9288 - val_loss: 147.4886\n",
"6/6 [==============================] - 0s 21ms/step - loss: 186.1115 - val_loss: 145.6914\n",
"Epoch 4/25\n",
"6/6 [==============================] - 0s 11ms/step - loss: 178.0559 - val_loss: 131.6612\n",
"6/6 [==============================] - 0s 19ms/step - loss: 176.7000 - val_loss: 131.8439\n",
"Epoch 5/25\n",
"6/6 [==============================] - 0s 11ms/step - loss: 162.7336 - val_loss: 124.0303\n",
"6/6 [==============================] - 0s 15ms/step - loss: 163.8713 - val_loss: 122.8008\n",
"Epoch 6/25\n",
"6/6 [==============================] - 0s 11ms/step - loss: 162.0391 - val_loss: 125.7313\n",
"6/6 [==============================] - 0s 12ms/step - loss: 157.7288 - val_loss: 122.6877\n",
"Epoch 7/25\n",
"6/6 [==============================] - 0s 10ms/step - loss: 154.8945 - val_loss: 116.7943\n",
"6/6 [==============================] - 0s 12ms/step - loss: 153.1267 - val_loss: 117.7342\n",
"Epoch 8/25\n",
"6/6 [==============================] - 0s 10ms/step - loss: 149.2845 - val_loss: 113.8111\n",
"6/6 [==============================] - 0s 13ms/step - loss: 144.7063 - val_loss: 109.9825\n",
"Epoch 9/25\n",
"6/6 [==============================] - 0s 11ms/step - loss: 143.0998 - val_loss: 105.6354\n",
"6/6 [==============================] - 0s 13ms/step - loss: 133.6063 - val_loss: 101.3907\n",
"Epoch 10/25\n",
"6/6 [==============================] - 0s 10ms/step - loss: 127.5234 - val_loss: 93.4604\n",
"6/6 [==============================] - 0s 12ms/step - loss: 110.1535 - val_loss: 84.3457\n",
"Epoch 11/25\n",
"6/6 [==============================] - 0s 11ms/step - loss: 93.8345 - val_loss: 65.8470\n",
"6/6 [==============================] - 0s 12ms/step - loss: 77.5999 - val_loss: 61.5951\n",
"Epoch 12/25\n",
"6/6 [==============================] - 0s 11ms/step - loss: 73.4662 - val_loss: 62.8007\n",
"6/6 [==============================] - 0s 13ms/step - loss: 58.9964 - val_loss: 52.2004\n",
"Epoch 13/25\n",
"6/6 [==============================] - 0s 12ms/step - loss: 65.6246 - val_loss: 61.4815\n",
"6/6 [==============================] - 0s 13ms/step - loss: 56.7010 - val_loss: 54.1540\n",
"Epoch 14/25\n",
"6/6 [==============================] - 0s 12ms/step - loss: 63.4636 - val_loss: 61.4872\n",
"6/6 [==============================] - 0s 12ms/step - loss: 54.9731 - val_loss: 49.3258\n",
"Epoch 15/25\n",
"6/6 [==============================] - 0s 12ms/step - loss: 61.8959 - val_loss: 57.7244\n",
"6/6 [==============================] - 0s 13ms/step - loss: 49.4758 - val_loss: 49.8511\n",
"Epoch 16/25\n",
"6/6 [==============================] - 0s 11ms/step - loss: 58.3664 - val_loss: 52.4185\n",
"6/6 [==============================] - 0s 13ms/step - loss: 49.3332 - val_loss: 48.4715\n",
"Epoch 17/25\n",
"6/6 [==============================] - 0s 10ms/step - loss: 59.1145 - val_loss: 58.3494\n",
"6/6 [==============================] - 0s 12ms/step - loss: 46.6994 - val_loss: 43.7063\n",
"Epoch 18/25\n",
"6/6 [==============================] - 0s 11ms/step - loss: 56.9374 - val_loss: 52.7102\n",
"6/6 [==============================] - 0s 13ms/step - loss: 48.1494 - val_loss: 47.9336\n",
"Epoch 19/25\n",
"6/6 [==============================] - 0s 11ms/step - loss: 55.3144 - val_loss: 53.1646\n",
"6/6 [==============================] - 0s 12ms/step - loss: 50.2123 - val_loss: 46.2343\n",
"Epoch 20/25\n",
"6/6 [==============================] - 0s 12ms/step - loss: 53.2501 - val_loss: 48.5530\n",
"6/6 [==============================] - 0s 12ms/step - loss: 46.8542 - val_loss: 41.4796\n",
"Epoch 21/25\n",
"6/6 [==============================] - 0s 14ms/step - loss: 52.1931 - val_loss: 53.1881\n",
"6/6 [==============================] - 0s 12ms/step - loss: 46.0189 - val_loss: 52.0516\n",
"Epoch 22/25\n",
"6/6 [==============================] - 0s 11ms/step - loss: 53.9029 - val_loss: 46.2755\n",
"6/6 [==============================] - 0s 13ms/step - loss: 47.6934 - val_loss: 40.5394\n",
"Epoch 23/25\n",
"6/6 [==============================] - 0s 11ms/step - loss: 51.6329 - val_loss: 47.9711\n",
"6/6 [==============================] - 0s 13ms/step - loss: 48.3689 - val_loss: 47.6283\n",
"Epoch 24/25\n",
"6/6 [==============================] - 0s 12ms/step - loss: 51.6885 - val_loss: 46.3855\n",
"6/6 [==============================] - 0s 14ms/step - loss: 45.0009 - val_loss: 41.7982\n",
"Epoch 25/25\n",
"6/6 [==============================] - 0s 11ms/step - loss: 49.6647 - val_loss: 43.7484\n"
"6/6 [==============================] - 0s 13ms/step - loss: 43.0285 - val_loss: 44.1419\n"
]
},
{
"data": {
"text/plain": [
"<tensorflow.python.keras.callbacks.History at 0x7fe7427bfe10>"
"<tensorflow.python.keras.callbacks.History at 0x7fd206450c10>"
]
},
"execution_count": 11,
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
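The body of the model-definition cell (whose new output is the TensorFlow 2.5 gather deprecation warning shown earlier) and the model.fit cell are collapsed in this view; this hunk only updates their outputs with a fresh training trace. As orientation for what the hidden cells do, here is a rough sketch of an nfp-style message-passing model and the training call. It assumes nfp's EdgeUpdate, NodeUpdate, and GlobalUpdate layers and the preprocessor's atom_classes/bond_classes attributes; the feature size, number of message-passing steps, loss, and optimizer are assumptions, with only epochs=25 evident from the log.

import tensorflow as tf
from tensorflow.keras import layers
import nfp

# Sketch only: a small message-passing network over the padded graph inputs.
# int32 matches the dtype of the preprocessed arrays shown earlier.
atom = layers.Input(shape=[None], dtype=tf.int32, name='atom')
bond = layers.Input(shape=[None], dtype=tf.int32, name='bond')
connectivity = layers.Input(shape=[None, 2], dtype=tf.int32, name='connectivity')

features_dim = 8  # assumed embedding width

atom_state = layers.Embedding(preprocessor.atom_classes, features_dim,
                              mask_zero=True, name='atom_embedding')(atom)
bond_state = layers.Embedding(preprocessor.bond_classes, features_dim,
                              mask_zero=True, name='bond_embedding')(bond)

for _ in range(3):  # assumed number of message-passing steps
    bond_state = layers.Add()([
        bond_state, nfp.EdgeUpdate()([atom_state, bond_state, connectivity])])
    atom_state = layers.Add()([
        atom_state, nfp.NodeUpdate()([atom_state, bond_state, connectivity])])

# Pool atom and bond states into a single molecule-level prediction.
readout = nfp.GlobalUpdate(units=features_dim, num_heads=1)(
    [atom_state, bond_state, connectivity])
prediction = layers.Dense(1)(readout)

model = tf.keras.Model([atom, bond, connectivity], prediction)
model.compile(loss='mae', optimizer=tf.keras.optimizers.Adam(1e-3))  # loss and optimizer assumed
model.fit(train_dataset, validation_data=valid_dataset, epochs=25)
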
@@ -491,7 +511,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
@@ -500,26 +520,23 @@
"test_dataset = tf.data.Dataset.from_generator(\n",
" lambda: (preprocessor(smiles, train=False)\n",
" for smiles in test),\n",
" output_types=preprocessor.output_types,\n",
" output_shapes=preprocessor.output_shapes)\\\n",
" .padded_batch(batch_size=64, \n",
" padded_shapes=preprocessor.padded_shapes(),\n",
" padding_values=preprocessor.padding_values)\\\n",
" output_signature=preprocessor.output_signature)\\\n",
" .padded_batch(batch_size=64)\\\n",
" .prefetch(tf.data.experimental.AUTOTUNE)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"40.5671657333374"
"33.404793376922605"
]
},
"execution_count": 13,
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
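The source of the final evaluation cell is also collapsed; only its value changes here (roughly 40.6 before, 33.4 after), which is consistent with a mean absolute error over the held-out test molecules. Below is a hedged sketch of how such a number could be computed from the unlabeled test_dataset defined above, assuming the test labels come from the same ysi dataframe and that its SMILES entries are unique.

import numpy as np

# Predict on the padded, unlabeled test pipeline and compare against held-out labels.
# Reindexing by `test` keeps the labels in the same order as the generator above.
test_predictions = model.predict(test_dataset).squeeze()
test_labels = ysi.set_index('SMILES').reindex(test).YSI.values

mae = np.abs(test_predictions - test_labels).mean()
print(f"Test MAE: {mae:.1f}")
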
@@ -549,7 +566,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.7"
"version": "3.8.10"
}
},
"nbformat": 4,
