This Jupyter Notebook includes code adapted from the RouteNet project.

Original author:

Krzysztof Rusek
AGH University of Science and Technology, Department of Communications, Krakow, Poland.
Email: krusek@agh.edu.pl
This code is licensed under the BSD 3-Clause License. See the LICENSE file in this repository for details.

In [1]:
# TF2.x
import tensorflow.compat.v1 as tf

# TF1.x
# import tensorflow as tf

from tensorflow.python.util import deprecation
deprecation._PRINT_DEPRECATION_WARNINGS = False

from tensorflow import keras

import numpy as np
from dataclasses import dataclass, field
from typing import Optional
import glob

In [2]:
# Report the TensorFlow build and the GPUs it can see, so the run is reproducible.
tf_version = tf.__version__
print(tf_version)
visible_gpus = tf.config.list_physical_devices('GPU')
print(visible_gpus)

2.15.1
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU'), PhysicalDevice(name='/physical_device:GPU:1', device_type='GPU')]


In [3]:
def parse(serialized, target='delay'):
    """Decode one serialized ``tf.Example`` into a ``(features, label)`` pair.

    Args:
        serialized: scalar string tensor holding one serialized example.
        target: name of the predicted variable (the label key in the record).

    Returns:
        A tuple ``(features, label)``. ``features`` is a dict with every parsed
        key except ``target``; ``label`` is the dense tensor stored under
        ``target``.

    The variable-length entries ('traffic', ``target``, 'links', 'paths',
    'sequances') are densified. Fixed rescalings are then applied by key:
      * 'delay'   -> (x - 2.8) / 2.5
      * 'traffic' -> (x - 0.5) / 0.5
      * 'drops'   -> x / 12000 / (0.5 * traffic_scaled + 0.5), i.e. divided by
        the traffic value before rescaling (the original comment calls this
        "loss rate").
    """
    with tf.name_scope('parse'):
        features = tf.compat.v1.parse_single_example(
            serialized,
            features={
                'traffic':tf.compat.v1.VarLenFeature(tf.float32),
                target:tf.compat.v1.VarLenFeature(tf.float32),
                'links':tf.compat.v1.VarLenFeature(tf.int64),
                'paths':tf.compat.v1.VarLenFeature(tf.int64),
                'sequances':tf.compat.v1.VarLenFeature(tf.int64),
                'n_links':tf.compat.v1.FixedLenFeature([],tf.int64),
                'n_paths':tf.compat.v1.FixedLenFeature([],tf.int64),
                'n_total':tf.compat.v1.FixedLenFeature([],tf.int64)
            }
        )

        # 'traffic' must be handled before the target: the 'drops' rescaling
        # below reads the already-rescaled traffic tensor.
        for k in ['traffic',target,'links','paths','sequances']:
            features[k] = tf.compat.v1.sparse_tensor_to_dense(features[k])
            if k == 'delay':
                features[k] = (features[k]-2.8)/2.5
            if k == 'traffic':
                #features[k] = (features[k]-0.76)/.008
                features[k] = (features[k]-0.5)/.5
            if k == 'drops':
                features[k] = (features[k])/12000/(0.5*features['traffic']+0.5) #loss rate
            #if k == 'jitter':
                #features[k] = (tf.math.log( features[k] )-2.0)/2.0 #logjitter

    # Compare keys by value: `is not` tests object identity, which only works
    # when the dict key happens to be the very same string object as `target`.
    inputs = {k: v for k, v in features.items() if k != target}
    return inputs, features[target]

In [4]:
def tfrecord_input_fn(filenames,hparams,shuffle_buf=1000, target='delay'):
    """Build the batched ``tf.data`` pipeline that feeds the estimator.

    Args:
        filenames: sized sequence of TFRecord file paths.
        hparams: object exposing ``node_count`` and ``batch_size``.
        shuffle_buf: example-level shuffle buffer size. When truthy the stream
            is shuffled and repeated indefinitely; when falsy (e.g. ``None``)
            the records are read once, without example-level shuffling.
        target: label key forwarded to ``parse``.

    Returns:
        A dataset of ``(features, labels)`` batches. 'traffic' and the labels
        are padded to ``node_count * (node_count - 1)`` entries; 'links',
        'paths' and 'sequances' are padded to the longest example in the batch.

    Raises:
        ValueError: if ``filenames`` is empty.
    """
    # Fail early with a clear message: an empty list (e.g. a glob that matched
    # nothing) would otherwise reach `shuffle(0)` and fail inside TensorFlow.
    if len(filenames) == 0:
        raise ValueError("tfrecord_input_fn: 'filenames' is empty; no TFRecord files to read")

    files = tf.data.Dataset.from_tensor_slices(filenames)
    files = files.shuffle(len(filenames))

    ds = files.interleave(tf.data.TFRecordDataset, cycle_length=4)

    if shuffle_buf:
        ds = ds.shuffle(shuffle_buf).repeat()

    ds = ds.map(lambda buf:parse(buf,target), num_parallel_calls=tf.data.AUTOTUNE)

    # One entry per ordered pair of distinct nodes.
    n_pairs = hparams.node_count*(hparams.node_count-1)
    shapes=(
        {
        'traffic':[n_pairs],
        'links':[-1],
        'paths':[-1],
        'sequances':[-1],
        'n_links':[],
        'n_paths':[],
        'n_total':[]
        },
        [n_pairs]
    )

    ds = ds.padded_batch(hparams.batch_size,shapes)
    ds = ds.prefetch(1)

    return ds

In [5]:
class ComnetModel(tf.keras.Model):
    """Message-passing model over links and paths for a single network sample.

    Two GRU cells alternately update per-path and per-link hidden states for
    ``hparams.T`` iterations; a dense readout then maps each final path state
    to ``output_units`` values.

    ``hparams`` must expose ``link_state_dim``, ``path_state_dim``, ``T``,
    ``readout_units``, ``dropout_rate`` and ``l2``.
    """

    def __init__(self,hparams, output_units=1):
        """Create the two GRU cells and the readout network.

        Args:
            hparams: hyper-parameter object (see class docstring).
            output_units: width of the final readout layer.
        """
        super(ComnetModel, self).__init__()
        self.hparams = hparams

        # TF1-style cells are used because `call` drives them with dynamic_rnn.
        self.edge_update = tf.compat.v1.nn.rnn_cell.GRUCell(hparams.link_state_dim, dtype=tf.float32)
        self.path_update = tf.compat.v1.nn.rnn_cell.GRUCell(hparams.path_state_dim, dtype=tf.float32)

        self.readout = tf.keras.models.Sequential()

        self.readout.add(keras.layers.Dense(hparams.readout_units, activation=tf.nn.selu, kernel_regularizer=tf.keras.regularizers.l2(hparams.l2)))
        self.readout.add(keras.layers.Dropout(rate=hparams.dropout_rate))

        self.readout.add(keras.layers.Dense(hparams.readout_units, activation=tf.nn.selu, kernel_regularizer=tf.keras.regularizers.l2(hparams.l2)))
        self.readout.add(keras.layers.Dropout(rate=hparams.dropout_rate))

        self.readout.add(keras.layers.Dense(output_units, kernel_regularizer=tf.keras.regularizers.l2(hparams.l2)))

    def build(self, input_shape=None):
        """Create all variables up front; ``input_shape`` is ignored.

        The link cell consumes aggregated path outputs (width
        ``path_state_dim``) and the path cell consumes link states (width
        ``link_state_dim``), hence the crossed shapes below.
        """
        del input_shape
        self.edge_update.build(tf.TensorShape([None,self.hparams.path_state_dim]))
        self.path_update.build(tf.TensorShape([None,self.hparams.link_state_dim]))
        self.readout.build(input_shape = [None,self.hparams.path_state_dim])
        self.built = True

    def call(self, inputs, training=False):
        """Run ``T`` rounds of message passing on one sample.

        Args:
            inputs: dict with 'traffic', 'links', 'paths', 'sequances',
                'n_links', 'n_paths' and 'n_total'. Entry ``i`` of the three
                index vectors places link ``links[i]`` at position
                ``sequances[i]`` of path ``paths[i]``; only the first
                ``n_total`` entries are used (the rest is batch padding).
            training: forwarded to the readout (controls dropout).

        Returns:
            The readout applied to the final path states, one row per path.
        """
        f_ = inputs
        shape = tf.stack([f_['n_links'],self.hparams.link_state_dim], axis=0)
        link_state = tf.zeros(shape)
        # Path state starts as [traffic, 0, 0, ...]: traffic fills the first slot.
        shape = tf.stack([f_['n_paths'],self.hparams.path_state_dim-1], axis=0)
        path_state = tf.concat([tf.expand_dims(f_['traffic'],axis=1), tf.zeros(shape)], axis=1)

        links = f_['links'][0:f_["n_total"]]
        paths = f_['paths'][0:f_["n_total"]]
        seqs=  f_['sequances'][0:f_["n_total"]]

        # These depend only on the topology, not on the evolving states, so
        # they are built once instead of once per message-passing iteration.
        ids = tf.stack([paths, seqs], axis=1)
        max_len = tf.reduce_max(seqs)+1
        shape = tf.stack([f_['n_paths'], max_len, self.hparams.link_state_dim])
        # Number of links on each path (segment_sum expects `paths` sorted).
        lens = tf.compat.v1.segment_sum(data=tf.ones_like(paths), segment_ids=paths)

        for _ in range(self.hparams.T):
            # Lay the current link states out as one padded sequence per path.
            h_tild = tf.gather(link_state,links)
            link_inputs = tf.scatter_nd(ids, h_tild, shape)

            # Path update: run the path GRU along each path's link sequence.
            outputs, path_state = tf.compat.v1.nn.dynamic_rnn(self.path_update,link_inputs,sequence_length=lens,initial_state=path_state,dtype=tf.float32)

            # Link update: sum the per-position path outputs landing on each link.
            m = tf.gather_nd(outputs,ids)
            m = tf.compat.v1.unsorted_segment_sum(m, links ,f_['n_links'])
            _,link_state = self.edge_update(m, link_state)

        r = self.readout(path_state,training=training)

        return r

In [6]:
def streaming_pearson_correlation_dl(labels, predictions, weights=None):
    """Streaming Pearson correlation usable as an ``eval_metric_ops`` entry.

    Per-batch first and second moments are accumulated with
    ``tf.compat.v1.metrics.mean`` and combined as
    ``cov(x, y) / sqrt(var(x) * var(y))``; the result is 0.0 whenever the
    denominator is not strictly positive.

    Args:
        labels: ground-truth tensor (cast to float32).
        predictions: predicted tensor (cast to float32).
        weights: optional weight forwarded to every ``metrics.mean`` call. It
            is applied to scalar batch statistics, so it acts per batch.

    Returns:
        ``(value_tensor, update_op)``.

    NOTE(review): the accumulators average *batch means*, so batches of
    different sizes get equal weight -- confirm this is acceptable when the
    last evaluation batch is smaller than the others.
    """
    # The strategy is published as a module global by `train`. Read it
    # defensively so this metric also works (single-replica) when `train` has
    # not run yet, instead of raising NameError.
    strategy = globals().get('learning_strategy')

    # Cast inputs to float32 for compatibility
    labels = tf.cast(labels, tf.float32)
    predictions = tf.cast(predictions, tf.float32)

    # Batch-wise computations
    batch_size = tf.cast(tf.size(labels), tf.float32)
    batch_mean_x = tf.reduce_mean(predictions)
    batch_mean_y = tf.reduce_mean(labels)
    batch_mean_x_squared = tf.reduce_mean(tf.square(predictions))
    batch_mean_y_squared = tf.reduce_mean(tf.square(labels))
    batch_mean_xy = tf.reduce_mean(predictions * labels)

    # Use tf.metrics to accumulate streaming values
    mean_x, update_mean_x = tf.compat.v1.metrics.mean(batch_mean_x, weights)
    mean_y, update_mean_y = tf.compat.v1.metrics.mean(batch_mean_y, weights)
    mean_x_squared, update_mean_x_squared = tf.compat.v1.metrics.mean(batch_mean_x_squared, weights)
    mean_y_squared, update_mean_y_squared = tf.compat.v1.metrics.mean(batch_mean_y_squared, weights)
    mean_xy, update_mean_xy = tf.compat.v1.metrics.mean(batch_mean_xy, weights)
    # The running batch-size mean is not used in the formula; only its update
    # op is kept in the grouped update below.
    count, update_count = tf.compat.v1.metrics.mean(batch_size, weights)

    # Compute Pearson correlation
    def compute_correlation():
        covariance = mean_xy - mean_x * mean_y
        variance_x = mean_x_squared - tf.square(mean_x)
        variance_y = mean_y_squared - tf.square(mean_y)
        denominator = tf.sqrt(variance_x * variance_y)
        # A NaN or zero denominator fails the `> 0` test and yields 0.0.
        return tf.where(tf.greater(denominator, 0), covariance / denominator, 0.0)

    rho = compute_correlation()

    # Group all update ops
    update_op = tf.group(
        update_mean_x, update_mean_y, update_mean_x_squared,
        update_mean_y_squared, update_mean_xy, update_count
    )

    # Distributed training compatibility
    if strategy and strategy.num_replicas_in_sync > 1:

        # Aggregate distributed values using the strategy context
        def aggregate_rho(cross_replica_strategy, value):
            return cross_replica_strategy.reduce(tf.distribute.ReduceOp.MEAN, value, axis=None)

        # Use merge_call to enter cross-replica context. The lambda runs before
        # `rho` is rebound, so it still sees the per-replica value.
        rho = tf.distribute.get_replica_context().merge_call(lambda ctx: aggregate_rho(ctx, rho))

    return rho, update_op

In [7]:
def model_fn(features, labels, mode, params):
    """Estimator ``model_fn``: builds prediction, evaluation and training specs.

    Args:
        features: batched feature dict produced by ``tfrecord_input_fn``.
        labels: batched label tensor (unused in PREDICT mode).
        mode: one of ``tf.estimator.ModeKeys``.
        params: hyper-parameter object; passed to ``ComnetModel`` and read for
            ``learning_rate`` and ``l2``.

    Returns:
        A ``tf.estimator.EstimatorSpec`` for the requested mode. The reported
        ``loss`` is the plain MSE; gradients are taken on MSE plus the model's
        regularization losses.
    """
    model = ComnetModel(params)
    model.build()
    #model.summary()

    # The model handles one sample at a time, so map it over the batch axis.
    is_training = mode == tf.estimator.ModeKeys.TRAIN
    predictions = tf.map_fn(lambda x: model(x,training=is_training), features,dtype=tf.float32)
    # Drop only the trailing unit axis of the readout. An axis-less squeeze
    # would also remove the batch axis when a batch holds a single example,
    # making predictions and labels disagree in rank.
    predictions = tf.squeeze(predictions, axis=-1)

    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode, predictions={'predictions':predictions})

    loss =  tf.compat.v1.losses.mean_squared_error(labels=labels, predictions = predictions, reduction=tf.compat.v1.losses.Reduction.MEAN)

    regularization_loss = sum(model.losses)
    total_loss = loss + regularization_loss

    tf.summary.scalar('loss', loss)
    tf.summary.scalar('regularization_loss', regularization_loss)

    if mode == tf.estimator.ModeKeys.EVAL:

        return tf.estimator.EstimatorSpec(
            mode,
            loss=loss,
            eval_metric_ops=
            {
                'label/mean':tf.compat.v1.metrics.mean(labels),
                'prediction/mean': tf.compat.v1.metrics.mean(predictions),

                'mae':tf.compat.v1.metrics.mean_absolute_error(labels, predictions),
                'mse': tf.compat.v1.metrics.mean_squared_error(labels=labels, predictions=predictions),

                'rho': streaming_pearson_correlation_dl(labels, predictions)
            }
        )

    assert mode == tf.estimator.ModeKeys.TRAIN

    # `tfa` is used below but is not imported by the notebook's import cell;
    # import it here so training works on a fresh kernel.
    import tensorflow_addons as tfa

    trainables = model.variables
    grads = tf.gradients(total_loss, trainables)
    # Materialise the pairs: a bare zip() iterator can only be consumed once.
    grad_var_pairs = list(zip(grads, trainables))

    # Histogram summaries are registered as a side effect of these calls.
    for var in trainables:
        tf.summary.histogram(var.op.name, var)
    for g in grads:
        if g is not None:
            tf.summary.histogram(g.op.name, g)

    optimizer = tfa.optimizers.LAMB(learning_rate=params.learning_rate, weight_decay=params.l2)

    update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)

    # apply_gradients is called without a global_step here, so the step
    # counter is incremented explicitly alongside the parameter update.
    global_step = tf.compat.v1.train.get_global_step()
    with tf.control_dependencies(update_ops):
        train_op = tf.group(optimizer.apply_gradients(grad_var_pairs), tf.compat.v1.assign_add(global_step, 1))

    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op,)

In [8]:
def train(args):
    """Train and periodically evaluate the model described by ``args``.

    Reads ``hparams``, ``strategy``, ``model_dir``, ``warm``, ``train``,
    ``eval_``, ``shuffle_buf``, ``target``, ``train_steps`` and ``eval_steps``
    from ``args``. The strategy is also published as the module-level
    ``learning_strategy`` so the correlation metric can see it.
    """
    global learning_strategy

    print(args.hparams)
    tf.compat.v1.logging.set_verbosity('INFO')

    # Expose the distribution strategy to code that has no access to `args`.
    learning_strategy = args.strategy

    # The same strategy drives training and evaluation; checkpoint every
    # 1000 steps and keep the 50 most recent ones.
    run_config = tf.estimator.RunConfig(
        train_distribute=args.strategy,
        eval_distribute=args.strategy,
        save_checkpoints_steps=1000,
        keep_checkpoint_max=50,
    )

    def train_input_fn():
        return tfrecord_input_fn(
            args.train,
            args.hparams,
            shuffle_buf=args.shuffle_buf,
            target=args.target,
        )

    def eval_input_fn():
        # No example-level shuffle/repeat: the evaluation files are read once.
        return tfrecord_input_fn(
            args.eval_,
            args.hparams,
            shuffle_buf=None,
            target=args.target,
        )

    routenet_estimator = tf.estimator.Estimator(
        model_fn=model_fn,
        model_dir=args.model_dir,
        params=args.hparams,
        warm_start_from=args.warm,
        config=run_config,
    )

    tf.estimator.train_and_evaluate(
        routenet_estimator,
        tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=args.train_steps),
        tf.estimator.EvalSpec(input_fn=eval_input_fn, steps=args.eval_steps, throttle_secs=1),
    )

In [9]:
# glob returns matches in a filesystem-dependent order; sort so the file lists
# are identical from run to run and machine to machine.
tfrecord_train_files = sorted(glob.glob("nsfnet/tfrecords/train/*.tfrecords"))
tfrecord_eval_files = sorted(glob.glob("nsfnet/tfrecords/evaluate/*.tfrecords"))

@dataclass
class HyperParams:
    """Model and optimisation hyper-parameters shared by the input pipeline,
    ``ComnetModel`` and ``model_fn``."""
    # Number of network nodes; traffic and labels are padded to
    # node_count * (node_count - 1) entries per sample.
    node_count: int = 14
    # Width of each link's hidden state (units of the link GRU cell).
    link_state_dim: int = 16
    # Width of each path's hidden state (units of the path GRU cell).
    path_state_dim: int = 32
    # Number of message-passing iterations run in ComnetModel.call.
    T: int = 8
    # Units in each hidden Dense layer of the readout network.
    readout_units: int = 256
    # Learning rate passed to the optimizer.
    learning_rate: float = 0.001
    # Number of samples per padded batch.
    batch_size: int = 32
    # Dropout rate applied after each hidden readout layer.
    dropout_rate: float = 0.5
    # L2 factor for the readout kernels; also passed as the optimizer's weight_decay.
    l2: float = 0.01
    # NOTE(review): not referenced by any code visible in this notebook.
    l2_2: float = 0.01

@dataclass
class Args:
    """Run configuration consumed by ``train``.

    Object-valued defaults are produced by ``default_factory`` so that every
    ``Args()`` gets its own instances. A bare ``HyperParams()`` default is an
    unhashable object, which ``dataclasses`` rejects on Python 3.11+ and which
    older versions would silently share between instances. The factories are
    lambdas so names are resolved when ``Args`` is instantiated, not when the
    class is defined.
    """
    # Name of the predicted variable (label key in the TFRecords).
    target: str = "delay"
    # Distribution strategy for training and evaluation; pass strategy=None
    # to run without one.
    strategy: Optional["tf.distribute.Strategy"] = field(default_factory=lambda: tf.distribute.MirroredStrategy())
    # Model / optimisation hyper-parameters.
    hparams: "HyperParams" = field(default_factory=lambda: HyperParams())
    # Training and evaluation TFRecord files.
    train: list = field(default_factory=lambda: tfrecord_train_files)
    eval_: list = field(default_factory=lambda: tfrecord_eval_files)
    # Directory for checkpoints and summaries.
    model_dir: str = "cp2_rn0dl_models_10022025_distributed"
    # Total number of training steps.
    train_steps: int = 50000
    # Evaluation steps per round; None is forwarded unchanged to EvalSpec.
    eval_steps: Optional[int] = None
    # Example-level shuffle buffer size for the training input pipeline.
    shuffle_buf: int = 30000
    # Optional warm-start source forwarded to the Estimator.
    warm: Optional[str] = None

# Default run configuration; this is the object passed to train() below.
args = Args()

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1')


In [10]:
# Launch train-and-evaluate with the configuration above; %time reports how long the call took.
%time train(args)

HyperParams(node_count=14, link_state_dim=16, path_state_dim=32, T=8, readout_units=256, learning_rate=0.001, batch_size=32, dropout_rate=0.5, l2=0.01, l2_2=0.01)
INFO:tensorflow:Initializing RunConfig with distribution strategies.
INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Using config: {'_model_dir': 'cp2_rn0dl_models_10022025_distributed', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': 1000, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 50, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x148068114ca0>, '_device_fn': None, '_protocol': None, '_eval_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x148068114ca0>, '_experimental_distribute'