# Development of AZ policy training procedure

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import tensorflow as tf
# Ask TensorFlow to grow GPU memory on demand. Every visible GPU is configured,
# and a machine without GPUs (empty list) is simply a no-op instead of raising
# an IndexError on `gpus[0]`.
gpus = tf.config.experimental.list_physical_devices('GPU')
print(gpus)
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)
import nfp

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


In [3]:
import os

import pandas as pd
import psycopg2

# Connection settings for the PostgreSQL replay database.
#
# The password is read from the standard libpq `PGPASSWORD` environment
# variable so that no credential is stored in the notebook. psycopg2 drops
# keyword arguments whose value is None, so when the variable is unset the
# connection falls back to libpq's own lookup (e.g. a ~/.pgpass file).
dbparams = {
    'dbname': 'bde',
    'port': 5432,
    'host': 'yuma.hpc.nrel.gov',
    'user': 'rlops',
    'password': os.environ.get('PGPASSWORD'),
    'options': '-c search_path=rl',
}

## Create the TensorFlow dataset from the PostgreSQL database

In [4]:
import io
import numpy as np
import logging
logging.getLogger().setLevel(logging.INFO)

import alphazero.config as config

def psql_generator():
    """ A python generator to yield rows from the Postgres database. Note, here I'm deferring
    the actual parsing of the binary data to a later function, which we can hopefully parallelize.

    The SQL command picks one random game state from each of the `config.buffer_max_size`
    most recent games (id is the row-id, always increasing with newer games; gameid is a
    unique game identifier).

    Essentially when this runs out; it should get re-called to grab new data.

    Yields:
        (bytes, reward) tuples: the raw serialized state from the `data` column and the
        matching `ranked_reward` value.
    """
    # psycopg2's `with connection:` block only commits / rolls back the
    # transaction; it does NOT close the connection. Because this generator is
    # re-invoked every time the dataset repeats, the connection is closed
    # explicitly, and before any row is yielded so it is not held open while
    # the consumer drains the results.
    conn = psycopg2.connect(**dbparams)
    try:
        logging.info("Running SQL query")

        df = pd.read_sql_query("""
        with recent_replays as (
            select * from rl.stablepsj_replay where gameid in (
                select gameid from rl.stablepsj_game order by id desc limit %s))

        select distinct on (gameid) id, ranked_reward, data
            from recent_replays order by gameid, random();
        """, conn, params=(config.buffer_max_size,))
    finally:
        conn.close()

    # Iterate the two needed columns directly; explicit column lookup avoids
    # any clash between the `data` column name and pandas Series attributes.
    for data, ranked_reward in zip(df['data'], df['ranked_reward']):
        yield (data.tobytes(), ranked_reward)
            

def parse_binary_data(binary_data, reward):
    """ Use io and numpy to parse the binary data from PostgreSQL.

    Args:
        binary_data: object exposing `.numpy()` that returns the serialized
            bytes of one stored state (an eager tensor when called through
            `tf.py_function`).
        reward: scalar reward for that state; a value of -1 is mapped to 0.

    Returns:
        Tuple `(atom, bond, connectivity, reward, visit_probs)` where the
        arrays are read from the archive entries of the same name and
        `reward` is a python int.

    Raises:
        KeyError: if one of the expected entries is missing from the payload.
    """
    # NOTE(review): allow_pickle=True lets numpy unpickle arbitrary objects from
    # the payload. That is only safe while the database contents are trusted;
    # drop it if the stored arrays never use object dtype.
    with io.BytesIO(binary_data.numpy()) as f:
        loaded = np.load(f, allow_pickle=True)
        try:
            parsed_data = dict(loaded.items())
        finally:
            # Archives returned by np.load hold a file handle and should be
            # closed; plain mappings (no `close`) are left alone.
            close = getattr(loaded, 'close', None)
            if close is not None:
                close()

    # This is something we could talk about; but I'm wondering if the best
    # loss function for a boolean reward is a binary crossentropy
    if reward == -1:
        reward = 0

    visit_probs = parsed_data.pop('visit_probs')
    return (parsed_data['atom'], parsed_data['bond'],
            parsed_data['connectivity'], int(reward), visit_probs)


def parse_data_tf(binary_data, reward):
    """Graph-side wrapper around `parse_binary_data`.

    `tf.py_function` hands back a flat list of tensors; this regroups them
    into the `(inputs, outputs)` pair that keras expects."""
    flat_tensors = tf.py_function(
        parse_binary_data,
        inp=[binary_data, reward],
        Tout=[tf.int64, tf.int64, tf.int64, tf.int64, tf.float32])
    atom, bond, connectivity, reward, visit_probs = flat_tensors

    # py_function outputs carry no static shape information, which the
    # padded batch operation needs, so declare the ranks here.
    static_shapes = (
        (atom, [None, None]),
        (bond, [None, None]),
        (connectivity, [None, None, 2]),
        (reward, []),
        (visit_probs, [None]),
    )
    for tensor, shape in static_shapes:
        tensor.set_shape(shape)

    features = {'atom': atom, 'bond': bond, 'connectivity': connectivity}
    targets = (reward, visit_probs)
    return (features, targets)

In [5]:
batch_size = config.batch_size

# Input pipeline: stream raw rows from the database generator (re-invoked
# whenever it is exhausted), shuffle, decode in parallel, then pad every
# example in a batch to a common shape.
dataset = (
    tf.data.Dataset.from_generator(
        psql_generator, output_types=(tf.string, tf.float32))
    .repeat()
    .shuffle(config.buffer_max_size)
    .map(parse_data_tf, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .padded_batch(
        batch_size,
        padding_values=(
            {'atom': nfp.zero, 'bond': nfp.zero, 'connectivity': nfp.zero},
            (nfp.zero, 0.)))
    .prefetch(tf.data.experimental.AUTOTUNE)
)

Example dataset outputs

In [6]:
# Pull a single batch out of the pipeline to inspect it
inputs, outputs = next(iter(dataset.take(1)))
inputs['atom'].shape  # batch_size, max_actions_per_node, max_atoms_per_mol

INFO:root:Running SQL query
INFO:root:Running SQL query


TensorShape([32, 39, 19])

## Build the tensorflow model

Specifically, we need to handle batches of actions so that the `prior_logits` can be normalized per parent molecule.

In [7]:
from alphazero.policy import policy_model
from tensorflow.keras import layers
from tensorflow.python.keras.losses import LossFunctionWrapper, losses_utils

def kl_with_logits(y_true, y_pred):
    """ KL-divergence style loss that takes raw logits as predictions.

    Working on logits rather than softmax outputs is typically more
    numerically stable, and tensorflow does not ship a KLD loss that accepts
    logits. The value is the cross-entropy between `y_true` and the logits
    `y_pred`, minus the cross-entropy of `y_true` with itself. """

    # Non-finite targets (e.g. nan) are replaced with zeros
    finite_targets = tf.where(
        tf.math.is_finite(y_true), y_true, tf.zeros_like(y_true))

    cross_entropy = tf.keras.losses.categorical_crossentropy(
        finite_targets, y_pred, from_logits=True)
    target_entropy = tf.keras.losses.categorical_crossentropy(
        finite_targets, finite_targets, from_logits=False)

    return cross_entropy - target_entropy


class KLWithLogits(LossFunctionWrapper):
    """ Loss-object form of `kl_with_logits`.

    Keras sometimes wants a loss wrapper like this so that it knows how to
    reduce the per-example losses over variable batch sizes. """

    def __init__(self,
                 reduction=losses_utils.ReductionV2.AUTO,
                 name='kl_with_logits'):
        super().__init__(kl_with_logits, name=name, reduction=reduction)
    

class PolicyWrapper(layers.Layer):
    """ Runs the policy model over a batch of (parent + child action) groups.

    Inputs are `[atom, bond, connectivity]` with leading dimensions
    (batch, actions_per_node). The layer returns the value prediction taken
    from the first entry of every group and the masked prior logits of the
    remaining entries. """

    def build(self, input_shape):
        self.policy_model = policy_model()

    def call(self, inputs, mask=None):
        atom, bond, connectivity = inputs

        # Dynamic batch and action dimensions
        dynamic_shape = tf.shape(atom)
        batch_size, max_actions_per_node = dynamic_shape[0], dynamic_shape[1]
        num_flat = batch_size * max_actions_per_node

        # Collapse (batch, action) into one axis so that every molecule goes
        # through the policy model individually
        flat_inputs = [
            tf.reshape(atom, [num_flat, -1]),
            tf.reshape(bond, [num_flat, -1]),
            tf.reshape(connectivity, [num_flat, -1, 2]),
        ]
        flat_values_logits, flat_prior_logits = self.policy_model(flat_inputs)

        # The parent node is the first entry of each group, so index 0 along
        # the action axis picks out the parent's value prediction
        grouped_values = tf.reshape(
            flat_values_logits, [batch_size, max_actions_per_node, -1])
        value_preds = grouped_values[:, 0, 0]

        # Entries whose atoms are all zero (the padding element) are not real
        # actions: their prior is replaced by the most negative value of the
        # dtype. Only the child entries (everything after the first) are
        # returned as priors.
        action_mask = tf.reduce_any(tf.not_equal(atom, 0), axis=-1)
        prior_logits = tf.reshape(
            flat_prior_logits, [batch_size, max_actions_per_node])
        lowest_logits = tf.ones_like(prior_logits) * prior_logits.dtype.min
        all_masked_logits = tf.where(action_mask, prior_logits, lowest_logits)
        masked_prior_logits = all_masked_logits[:, 1:]

        return value_preds, masked_prior_logits

In [8]:
# Assemble and compile the tf.keras.Model that is actually trained: three
# integer inputs feed the PolicyWrapper, which produces the value and prior heads

atom_class = layers.Input(name='atom', shape=[None, None], dtype=tf.int64)
bond_class = layers.Input(name='bond', shape=[None, None], dtype=tf.int64)
connectivity = layers.Input(name='connectivity', shape=[None, None, 2], dtype=tf.int64)

value_preds, masked_prior_logits = PolicyWrapper()(
    [atom_class, bond_class, connectivity])

policy_trainer = tf.keras.Model(
    inputs=[atom_class, bond_class, connectivity],
    outputs=[value_preds, masked_prior_logits])

# One loss per output: binary cross-entropy on the value logit, KL on the priors
policy_trainer.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1E-3),  # Do AZ list their optimizer?
    loss=[
        tf.keras.losses.BinaryCrossentropy(from_logits=True),
        KLWithLogits(),
    ],
)

## Train the model

Note that the losses do indeed decrease.

In [9]:
# Raise the root logger threshold so the INFO messages are silenced,
# then train for a longer period
logging.getLogger().setLevel(logging.WARNING)
policy_trainer.fit(dataset, epochs=10, steps_per_epoch=500)

Epoch 1/10

KeyboardInterrupt: 

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
sns.set(context='talk', style='ticks',
        color_codes=True, rc={'legend.frameon': False})

%matplotlib inline

In [None]:
# Value head: sigmoid of the predicted logit rescaled by 2x - 1, against the stored reward
value_logits = policy_trainer.predict(inputs)[0]
plt.plot(tf.nn.sigmoid(value_logits) * 2 - 1, outputs[0], '.')

In [None]:
# Prior head: softmax of the predicted logits against the stored visit probabilities
prior_logits = policy_trainer.predict(inputs)[1]
plt.plot(tf.nn.softmax(prior_logits), outputs[1], '.')

## How to get the weights from the trained model

We'll probably need to load the `policy_trainer` and then extract the `policy_model` sub-model.

In [None]:
# Build a fresh policy network and copy in the weights held by the
# wrapper layer (the last layer of the trainer)
policy = policy_model()
trained_weights = policy_trainer.layers[-1].policy_model.get_weights()
policy.set_weights(trained_weights)

In [34]:
checkpoint_filepath = f'/scratch/pstjohn/policy_checkpoints/{config.experiment_id}'

In [35]:
checkpoint_filepath

'/scratch/pstjohn/policy_checkpoints/0001'

In [20]:
policy_trainer = tf.keras.models.load_model(checkpoint_filepath, compile=False)

In [42]:
policy_trainer.layers[-1].policy_model

<tensorflow.python.keras.engine.training.Model at 0x7f3f00382950>

In [74]:
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_filepath)

In [75]:
latest_checkpoint

'/scratch/pstjohn/policy_checkpoints/0001/policy.22'

In [71]:
model = policy_model()

In [72]:
# Keep a handle on the wrapped policy sub-model: the following cell inspects
# `pmodel`, which was previously only assigned in a commented-out line and so
# was undefined on a fresh kernel. It is the same object the trainer holds, so
# it reflects the weights restored below. `load_weights` stays last so that its
# status object remains the cell output.
pmodel = policy_trainer.layers[-1].policy_model
policy_trainer.load_weights(latest_checkpoint)

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f3eccb6f190>

In [73]:
pmodel.layers[-1].get_weights()

[array([[ 0.10591211],
        [ 0.28725603],
        [-0.6236239 ],
        [ 0.06739221],
        [-0.52482575],
        [-0.642917  ],
        [-0.11559153],
        [-0.40216187],
        [ 0.10140881],
        [-0.64844733],
        [-0.9934893 ],
        [ 0.41949975],
        [-0.08478913],
        [ 0.14393784],
        [-0.30324477],
        [-0.24709758]], dtype=float32),
 array([0.027227], dtype=float32)]