# Neural structured learning with TensorFlow 2.0

In this example, we'll demonstrate how to use [Neural Structured Learning](https://www.tensorflow.org/neural_structured_learning) (NSL) to augment training on unstructured data with structured signals. We'll train the NSL model on the [CORA](https://relational.fit.cvut.cz/dataset/CORA) dataset, which contains scientific publications separated into 7 categories. Each publication is represented by a multi-hot encoded vector that indicates which words are present in the publication text. This is the unstructured portion of the dataset. Additionally, the dataset contains a graph of citations between the publications, which is the structured portion. We'll use the multi-hot encoded publications as input to a classification network, which will assign each publication to its class, and we'll augment the training with the citation graph. <br />
_For more details, check out the book._
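
To make the multi-hot representation concrete, here is a minimal sketch in plain Python. The vocabulary and the publication text are toy values invented for illustration (the real dataset uses a vocabulary of 1433 words), and `multi_hot_encode` is a hypothetical helper, not part of NSL or TensorFlow.

```python
def multi_hot_encode(words, vocabulary):
    """Return a 0/1 vector with a 1 for every vocabulary word present in `words`."""
    index = {word: i for i, word in enumerate(vocabulary)}
    vector = [0] * len(vocabulary)
    for word in words:
        if word in index:
            # Repeated words still produce a single 1: only presence is recorded.
            vector[index[word]] = 1
    return vector


# Toy vocabulary and publication text (assumptions for illustration only).
toy_vocabulary = ['graph', 'neural', 'network', 'learning', 'kernel']
toy_publication = ['neural', 'network', 'learning', 'neural']

encoded = multi_hot_encode(toy_publication, toy_vocabulary)
print(encoded)
```

Unlike a one-hot vector, several positions can be set at once, and the word order and word counts of the original text are discarded.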

Let's start with the imports:

In [1]:
import neural_structured_learning as nsl
import tensorflow as tf

Next, let's define the parameters of the training and the model:

In [2]:
# Cora dataset path
TRAIN_DATA_PATH = 'data/train_merged_examples.tfr'
TEST_DATA_PATH = 'data/test_examples.tfr'
# Constants used to identify neighbor features in the input.
NBR_FEATURE_PREFIX = 'NL_nbr_'
NBR_WEIGHT_SUFFIX = '_weight'
# Dataset parameters
NUM_CLASSES = 7
MAX_SEQ_LENGTH = 1433
# Number of neighbors to consider in the composite loss function
NUM_NEIGHBORS = 1
# Training parameters
BATCH_SIZE = 128
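
The prefix and suffix constants determine the names of the neighbor features that the parsing code looks up in each record. The following sketch lists the keys that would be generated for a given number of neighbors, using the same format strings as the parsing function in this example. `neighbor_feature_keys` is a hypothetical helper introduced only for illustration.

```python
NBR_FEATURE_PREFIX = 'NL_nbr_'
NBR_WEIGHT_SUFFIX = '_weight'


def neighbor_feature_keys(num_neighbors, feature_name='words'):
    """List the feature and weight keys expected for each neighbor."""
    keys = []
    for i in range(num_neighbors):
        # Key of the i-th neighbor's copy of the feature.
        keys.append('{}{}_{}'.format(NBR_FEATURE_PREFIX, i, feature_name))
        # Key of the edge weight between the sample and its i-th neighbor.
        keys.append('{}{}{}'.format(NBR_FEATURE_PREFIX, i, NBR_WEIGHT_SUFFIX))
    return keys


keys_for_one = neighbor_feature_keys(1)
keys_for_two = neighbor_feature_keys(2)
print(keys_for_one)
print(keys_for_two)
```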

We'll continue with the pre-processing of the dataset, which is implemented by two functions: `parse_example` parses a single serialized sample, and `make_dataset` builds the input pipeline for a whole file. The final result is a batched `tf.data.Dataset` of feature/label pairs:

In [3]:
def parse_example(example_proto: tf.Tensor) -> tuple:
    """
    Extracts relevant fields from the example_proto
    :param example_proto: a serialized `tf.train.Example` (scalar string tensor)
    :return: A (features, label) pair, where features is a dictionary
        containing the relevant features of the example and label is its class index
    """
    # The 'words' feature is a multi-hot, bag-of-words representation of the
    # original raw text. A default value is required for examples that don't
    # have the feature.
    feature_spec = {
        'words':
            tf.io.FixedLenFeature(shape=[MAX_SEQ_LENGTH],
                                  dtype=tf.int64,
                                  default_value=tf.constant(
                                      value=0,
                                      dtype=tf.int64,
                                      shape=[MAX_SEQ_LENGTH])),
        'label':
            tf.io.FixedLenFeature((), tf.int64, default_value=-1),
    }
    # We also extract corresponding neighbor features in a similar manner to
    # the features above.
    for i in range(NUM_NEIGHBORS):
        nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, i, 'words')
        nbr_weight_key = '{}{}{}'.format(NBR_FEATURE_PREFIX, i, NBR_WEIGHT_SUFFIX)
        feature_spec[nbr_feature_key] = tf.io.FixedLenFeature(
            shape=[MAX_SEQ_LENGTH],
            dtype=tf.int64,
            default_value=tf.constant(
                value=0, dtype=tf.int64, shape=[MAX_SEQ_LENGTH]))

        # We assign a default value of 0.0 for the neighbor weight so that
        # graph regularization is done on samples based on their exact number
        # of neighbors. In other words, non-existent neighbors are discounted.
        feature_spec[nbr_weight_key] = tf.io.FixedLenFeature(
            shape=[1], dtype=tf.float32, default_value=tf.constant([0.0]))

    features = tf.io.parse_single_example(example_proto, feature_spec)

    labels = features.pop('label')
    return features, labels


def make_dataset(file_path: str, training=False) -> tf.data.Dataset:
    """
    Creates a batched dataset from a TFRecord file
    :param file_path: name of the file in the `.tfrecord` format containing
        `tf.train.Example` objects.
    :param training: whether the dataset is for training (enables shuffling)
    :return: tf.data.Dataset of batched (features, label) pairs
    """
    dataset = tf.data.TFRecordDataset([file_path])
    if training:
        dataset = dataset.shuffle(10000)
    dataset = dataset.map(parse_example).batch(BATCH_SIZE)

    return dataset

Next, we'll instantiate the dataset:

In [4]:
train_dataset = make_dataset(TRAIN_DATA_PATH, training=True)
test_dataset = make_dataset(TEST_DATA_PATH)

Next, let's build the model. It starts as a regular feedforward neural network:

In [5]:
def build_model(dropout_rate):
    """Creates a sequential multi-layer perceptron model."""
    return tf.keras.Sequential([
        # Multi-hot encoded input.
        tf.keras.layers.InputLayer(
            input_shape=(MAX_SEQ_LENGTH,), name='words'),

        # 2 fully-connected layers + dropout
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dropout(dropout_rate),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dropout(dropout_rate),

        # Softmax output
        tf.keras.layers.Dense(NUM_CLASSES, activation='softmax')
    ])


# Build a new base MLP model.
model = build_model(dropout_rate=0.5)
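
As a quick sanity check on the size of this network, the number of trainable parameters can be worked out by hand: each dense layer contributes `inputs * outputs` weights plus `outputs` biases, and dropout layers add none. The sketch below does this arithmetic in plain Python. `dense_param_count` is a hypothetical helper, and the layer sizes are taken from the model definition above.

```python
def dense_param_count(layer_sizes):
    """Count weights and biases of a stack of fully-connected layers."""
    total = 0
    for n_in, n_out in zip(layer_sizes, layer_sizes[1:]):
        total += n_in * n_out + n_out  # weight matrix + bias vector
    return total


# Input of 1433 words, two hidden layers of 64 units, 7 output classes.
layer_sizes = [1433, 64, 64, 7]
total_params = dense_param_count(layer_sizes)
print(total_params)
```

Almost all of the parameters sit in the first layer, because the input vector is much wider than the hidden layers.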

Next, we'll wrap the model with graph regularization, which lets the training use the structured component of the dataset (the citation graph):

In [6]:
# Wrap the base MLP model with graph regularization.
graph_reg_config = nsl.configs.make_graph_reg_config(
    max_neighbors=NUM_NEIGHBORS,
    multiplier=0.1,
    distance_type=nsl.configs.DistanceType.L2,
    sum_over_axis=-1)
graph_reg_model = nsl.keras.GraphRegularization(model,
                                                graph_reg_config)
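
Conceptually, graph regularization adds a penalty to the supervised loss when the model's output for a sample differs from its output for the sample's graph neighbors. The sketch below illustrates the idea in plain Python for a single sample. It is a simplified illustration under stated assumptions, not NSL's implementation: the distance is taken to be the sum of squared differences, the function names are hypothetical, and the library's exact reduction and normalization may differ.

```python
def squared_l2(u, v):
    """Sum of squared differences between two output vectors (assumed distance)."""
    return sum((a - b) ** 2 for a, b in zip(u, v))


def graph_regularized_loss(supervised_loss, sample_output, nbr_outputs,
                           nbr_weights, multiplier=0.1):
    """Return (total_loss, graph_loss) for one sample and its neighbors."""
    graph_loss = sum(
        weight * squared_l2(sample_output, nbr_output)
        for nbr_output, weight in zip(nbr_outputs, nbr_weights))
    return supervised_loss + multiplier * graph_loss, graph_loss


# Toy model outputs (assumptions for illustration only).
sample_output = [0.5, 0.5]
nbr_outputs = [[1.0, 0.0]]

# An existing neighbor with edge weight 1.0 contributes to the loss.
total_with_nbr, graph_with_nbr = graph_regularized_loss(
    1.0, sample_output, nbr_outputs, nbr_weights=[1.0])

# A missing neighbor gets the default weight 0.0 and contributes nothing.
total_without_nbr, graph_without_nbr = graph_regularized_loss(
    1.0, sample_output, nbr_outputs, nbr_weights=[0.0])

print(total_with_nbr, graph_with_nbr)
print(total_without_nbr, graph_without_nbr)
```

The second call shows why the parsing code assigns a default neighbor weight of 0.0: a neighbor that does not exist is filled with default values, and its zero weight removes it from the penalty.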

We can now run the training for 100 epochs:

In [7]:
graph_reg_model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'])

# run eagerly to prevent epoch warnings
graph_reg_model.run_eagerly = True

graph_reg_model.fit(train_dataset, epochs=100, verbose=1)

Epoch 1/100
...
Epoch 100/100


<tensorflow.python.keras.callbacks.History at 0x7f50401e20b8>

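Finally, we can evaluate the results. As we can see, the model achieves about 81% classification accuracy on the test set, which is an increase of around 3% compared to a model without graph regularization. The graph loss is reported as 0.0 because the graph regularization term is applied only during training: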

In [8]:
eval_results = dict(
    zip(graph_reg_model.metrics_names,
        graph_reg_model.evaluate(test_dataset)))
print('Evaluation accuracy: {}'.format(eval_results['accuracy']))
print('Evaluation loss: {}'.format(eval_results['loss']))
print('Evaluation graph loss: {}'.format(eval_results['graph_loss']))

Evaluation accuracy: 0.8101266026496887
Evaluation loss: 1.0497615039348602
Evaluation graph loss: 0.0
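
To clarify what the accuracy figure measures, the following plain-Python sketch computes it for integer class labels and softmax-style outputs: a prediction counts as correct when the class with the highest probability matches the label. The labels and probabilities are toy values invented for illustration and are unrelated to the results above. `sparse_categorical_accuracy` here is a hypothetical re-implementation, not the Keras metric itself.

```python
def sparse_categorical_accuracy(labels, probabilities):
    """Fraction of samples whose highest-probability class equals the label."""
    correct = 0
    for label, probs in zip(labels, probabilities):
        predicted = max(range(len(probs)), key=probs.__getitem__)
        correct += int(predicted == label)
    return correct / len(labels)


# Toy labels and predicted class probabilities (assumptions for illustration).
toy_labels = [0, 2, 1, 2]
toy_probabilities = [
    [0.7, 0.2, 0.1],  # predicted class 0, correct
    [0.1, 0.3, 0.6],  # predicted class 2, correct
    [0.5, 0.4, 0.1],  # predicted class 0, label is 1, wrong
    [0.2, 0.2, 0.6],  # predicted class 2, correct
]

toy_accuracy = sparse_categorical_accuracy(toy_labels, toy_probabilities)
print(toy_accuracy)
```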
