# Development of AZ policy training procedure

In [26]:
# Re-import edited project modules automatically before each cell is run, so
# changes to the local packages are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [27]:
import tensorflow as tf
# Ask TensorFlow to allocate GPU memory on demand instead of reserving it all
# up front. The setting is applied to every visible GPU (TensorFlow requires it
# to be consistent across devices), and the loop is simply a no-op on a
# CPU-only machine instead of failing on `gpus[0]`.
gpus = tf.config.experimental.list_physical_devices('GPU')
print(gpus)
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)
import nfp

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


In [31]:
import psycopg2
import pandas as pd

import os

# Keyword arguments for psycopg2.connect(); unpacked by psql_generator.
dbparams = {
    'dbname': 'bde',
    'port': 5432,
    'host': 'yuma.hpc.nrel.gov',
    'user': 'rlops',
    # Credentials do not belong in a notebook: read the password from the
    # standard libpq environment variable instead. When PGPASSWORD is unset
    # the value is None, which leaves password resolution to libpq's own
    # fallbacks (e.g. a ~/.pgpass file).
    'password': os.environ.get('PGPASSWORD'),
    # Resolve unqualified table names against the `rl` schema.
    'options': '-c search_path=rl',
}

## Create the TensorFlow dataset from the PostgreSQL database

In [256]:
import io
import numpy as np

def psql_generator():
    """Yield ``(serialized_example, reward)`` pairs from the replay table.

    One randomly chosen row is kept per ``gameid`` and the 100 rows with the
    highest ``id`` among those are returned. ``data`` is yielded as raw bytes
    so that it can travel through the pipeline as a ``tf.string``.

    The whole result set is fetched and the connection is closed *before* the
    first example is yielded. psycopg2's ``with connection:`` block only ends
    the transaction and leaves the connection open, so relying on it would
    leak one connection per pass, and the dataset built on this generator
    calls it again on every ``.repeat()`` cycle.
    """
    conn = psycopg2.connect(**dbparams)
    try:
        df = pd.read_sql_query("""
        select * from (
            select distinct on (gameid) id, reward, data
            from rl.q2replay
            order by gameid, random()) as cte
        order by id desc limit 100
        """, conn)
    finally:
        conn.close()

    # Explicit column lookup avoids any clash between the column name `data`
    # and attributes of the row object, and skips building a Series per row.
    for data, reward in zip(df['data'], df['reward']):
        yield (data.tobytes(), reward)
            

def parse_binary_data(binary_data, reward):
    """Decode one serialized replay record fetched from PostgreSQL.

    Args:
        binary_data: the bytes of a numpy ``.npz`` archive, given either as
            raw ``bytes`` or as an object exposing ``.numpy()`` that returns
            them (such as the eager string tensor passed in by
            ``tf.py_function``).
        reward: scalar outcome of the game; a value of -1 is mapped to 0.

    Returns:
        A flat tuple ``(atom, bond, connectivity, reward, visit_probs)`` where
        ``reward`` is a Python ``int`` and the remaining entries are the
        arrays stored under those keys in the archive.

    Raises:
        KeyError: if one of the expected arrays is missing from the archive.
    """
    raw_bytes = binary_data.numpy() if hasattr(binary_data, 'numpy') else binary_data

    # NOTE(security): allow_pickle=True lets np.load unpickle object arrays,
    # which can execute arbitrary code. That is only acceptable while every
    # writer to the replay table is trusted; switch to allow_pickle=False if
    # the stored archives hold plain numeric arrays only.
    # The archive is opened as a context manager so its handle gets closed.
    with io.BytesIO(raw_bytes) as f, np.load(f, allow_pickle=True) as archive:
        visit_probs = archive['visit_probs']
        atom = archive['atom']
        bond = archive['bond']
        connectivity = archive['connectivity']

    # This is something we could talk about; but I'm wondering if the best
    # loss function for a boolean reward is a binary crossentropy
    if reward == -1:
        reward = 0

    return (atom, bond, connectivity, int(reward), visit_probs)


def parse_data_tf(binary_data, reward):
    """Graph-side wrapper around ``parse_binary_data``.

    ``tf.py_function`` hands back a flat list of tensors; this regroups them
    into the ``(inputs, outputs)`` pair that keras expects from a dataset."""
    flat_outputs = tf.py_function(
        parse_binary_data,
        inp=[binary_data, reward],
        Tout=[tf.int64, tf.int64, tf.int64, tf.int64, tf.float32])
    atom, bond, connectivity, reward, visit_probs = flat_outputs

    # Static shape information is lost across the py_function boundary, and
    # the padded batch operation downstream needs at least the tensor ranks.
    static_shapes = (
        (atom, [None, None]),
        (bond, [None, None]),
        (connectivity, [None, None, 2]),
        (reward, []),
        (visit_probs, [None]),
    )
    for tensor, shape in static_shapes:
        tensor.set_shape(shape)

    model_inputs = {'atom': atom, 'bond': bond, 'connectivity': connectivity}
    targets = (reward, visit_probs)
    return (model_inputs, targets)

In [292]:
batch_size = 16

# Value used to pad each field when examples of different sizes are stacked
# into one batch; the nesting mirrors the (inputs, outputs) structure
# returned by parse_data_tf.
padding_values = (
    {'atom': nfp.zero, 'bond': nfp.zero, 'connectivity': nfp.zero},
    (nfp.zero, 0.),
)

# Endless stream of shuffled, parsed and padded training batches
dataset = (
    tf.data.Dataset.from_generator(
        psql_generator, output_types=(tf.string, tf.float32))
    .repeat()
    .shuffle(100)
    .map(parse_data_tf, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .padded_batch(batch_size, padding_values=padding_values)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

Example dataset outputs

In [293]:
# Pull a single padded batch to inspect what the pipeline produces
inputs, outputs = next(iter(dataset.take(1)))
inputs['atom'].shape  # batch_size, max_actions_per_node, max_atoms_per_mol

TensorShape([16, 23, 4])

## Build the tensorflow model

Specifically, we need to handle batches of actions so that the prior logits can be normalized per parent molecule.

In [298]:
from alphazero.policy import policy_model
from tensorflow.keras import layers
from tensorflow.python.keras.losses import LossFunctionWrapper, losses_utils

def kl_with_logits(y_true, y_pred):
    """ Kullback-Leibler divergence between target probabilities ``y_true``
    and predictions ``y_pred`` given as raw (un-normalized) logits.

    Working on logits rather than on softmax outputs is usually the more
    numerically stable choice, and tensorflow does not ship a KLD loss that
    accepts logits. The divergence is assembled as the cross-entropy of the
    predictions minus the entropy of the targets. """

    # Non-finite targets (e.g. nan) are swapped for zeros
    finite_targets = tf.math.is_finite(y_true)
    y_true = tf.where(finite_targets, y_true, tf.zeros_like(y_true))

    cross_entropy = tf.keras.losses.categorical_crossentropy(
        y_true, y_pred, from_logits=True)
    target_entropy = tf.keras.losses.categorical_crossentropy(
        y_true, y_true, from_logits=False)

    return cross_entropy - target_entropy


class KLWithLogits(LossFunctionWrapper):
    """ Loss object around ``kl_with_logits``. Keras sometimes wants losses
    in this wrapped form so that it knows how the per-example values should
    be reduced over variable batch sizes. """

    def __init__(self,
                 reduction=losses_utils.ReductionV2.AUTO,
                 name='kl_with_logits'):
        super().__init__(kl_with_logits, reduction=reduction, name=name)
    

class PolicyWrapper(layers.Layer):
    """ Applies the policy model to batches of (parent, child actions...)
    groups, returning the parent's value prediction and the prior logits of
    its children. """

    def build(self, input_shape):
        # The wrapped network is created once, when the layer is first built
        self.policy_model = policy_model()

    def call(self, inputs, mask=None):
        """ ``inputs`` is the ``[atom, bond, connectivity]`` triple, each with
        leading dimensions (batch, actions_per_node). Returns a tuple
        ``(value_preds, masked_prior_logits)``. """
        atom, bond, connectivity = inputs

        # Dynamic sizes of the batch and action dimensions
        dims = tf.shape(atom)
        n_batch, n_actions = dims[0], dims[1]
        n_flat = n_batch * n_actions

        # Collapse (batch, action) into a single axis so that every molecule
        # runs through the policy model individually
        flat_inputs = [
            tf.reshape(atom, [n_flat, -1]),
            tf.reshape(bond, [n_flat, -1]),
            tf.reshape(connectivity, [n_flat, -1, 2]),
        ]
        flat_values, flat_prior_logits = self.policy_model(flat_inputs)

        # The parent node comes first within each group, so entry 0 along
        # the action axis holds the parent's value prediction
        grouped_values = tf.reshape(flat_values, [n_batch, n_actions, -1])
        value_preds = grouped_values[:, 0, 0]

        # An action is real when any of its atom entries differs from the
        # padding element (zero). Padded actions get the most negative
        # representable logit so that they vanish after the softmax.
        is_real_action = tf.reduce_any(tf.not_equal(atom, 0), axis=-1)
        prior_logits = tf.reshape(flat_prior_logits, [n_batch, n_actions])
        lowest_logits = tf.ones_like(prior_logits) * prior_logits.dtype.min
        all_masked_logits = tf.where(is_real_action, prior_logits, lowest_logits)

        # Priors are reported for the child nodes only (drop the parent entry)
        masked_prior_logits = all_masked_logits[:, 1:]

        return value_preds, masked_prior_logits

In [275]:
# Here we actually build the tf.keras.Model to train

# Symbolic inputs; the names match the keys of the dataset's input dictionary
atom_class = layers.Input(shape=[None, None], dtype=tf.int64, name='atom')
bond_class = layers.Input(shape=[None, None], dtype=tf.int64, name='bond')
connectivity = layers.Input(shape=[None, None, 2], dtype=tf.int64, name='connectivity')
model_inputs = [atom_class, bond_class, connectivity]

value_preds, masked_prior_logits = PolicyWrapper()(model_inputs)

policy_trainer = tf.keras.Model(
    inputs=model_inputs,
    outputs=[value_preds, masked_prior_logits])

# One loss per output: binary cross-entropy on the value prediction and the
# logit-based KL divergence on the prior logits
policy_trainer.compile(
    loss=[tf.keras.losses.BinaryCrossentropy(), KLWithLogits()],
    optimizer=tf.keras.optimizers.Adam(learning_rate=1E-4),  # Does AZ list their optimizer?
)

## Train the model

note the losses indeed decrease

In [278]:
# `dataset` repeats indefinitely, so steps_per_epoch is what bounds the
# length of each epoch here.
policy_trainer.fit(dataset, steps_per_epoch=500, epochs=10)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<tensorflow.python.keras.callbacks.History at 0x7fbdbda52c50>

## How to get the weights from the trained model

we'll probably need to load the policy_trainer, and then extract the policy_model sub-model?

In [283]:
# Copy the trained weights into a stand-alone policy network. The wrapper
# layer is looked up by position; this assumes it is the trainer's last layer.
trained_wrapper = policy_trainer.layers[-1]
trained_weights = trained_wrapper.policy_model.get_weights()

policy = policy_model()
policy.set_weights(trained_weights)