This repository has been archived by the owner on Feb 25, 2022. It is now read-only.

Commit 9dfd7e0
fix scalar summaries (@Mistobaan's code)
sdtblck committed Sep 17, 2020
1 parent 56bcb19

Showing 3 changed files with 72 additions and 115 deletions.
22 changes: 15 additions & 7 deletions model_fns.py
@@ -5,7 +5,7 @@
 import mesh_tensorflow.transformer as mtf_transformer
 
 from optimizers import get_optimizer
-from utils import (TpuSummaries, get_graph_info, remove_batch_from_layout, simd_mesh_setup, add_mode_to_params, get_batch_size)
+from utils import (create_host_call, get_graph_info, remove_batch_from_layout, simd_mesh_setup, add_mode_to_params, get_batch_size)
 from models.utils import biasmask_attn_weights
 from tensorflow.python.ops import resources
 from sample import sample_autoregressive
@@ -22,9 +22,6 @@ def model_fn(features, labels, mode, params):
     if mode == tf.estimator.ModeKeys.PREDICT:
         params["layout"] = remove_batch_from_layout(params["layout"])
     layout_rules = mtf.convert_to_layout_rules(params["layout"])
 
-    # init summary class
-    summary = TpuSummaries(params["model_path"])
-
     # Mesh setup
     if params["use_tpu"]:
@@ -188,10 +185,17 @@ def serialized_fn(mtf_features):
         if params["num_microbatches"] > 1:
             # if we are splitting the batch into microbatches, var grads are created in the serialize_training_step fn
             # so we pass them in here
-            _, update_ops, var_grads = get_optimizer(loss, params, summary, variable_dtype=variable_dtype, inp_var_grads=var_grads)
+            _, update_ops, var_grads = get_optimizer(mesh, loss, params, variable_dtype=variable_dtype, inp_var_grads=var_grads)
         else:
             # otherwise, they are created in the get_optimizer fn, so we leave inp_var_grads blank
-            _, update_ops, var_grads = get_optimizer(loss, params, summary, variable_dtype=variable_dtype)
+            _, update_ops, var_grads = get_optimizer(mesh, loss, params, variable_dtype=variable_dtype)
+        # log summaries to tensorboard
+        mtf.scalar_summary("loss", loss)
+        # log gradients if in params
+        if params["log_grads"] is not None:
+            for g in var_grads:
+                grad_norm = mtf.sqrt(mtf.reduce_sum(mtf.square(g)))
+                mtf.scalar_summary("grads/norm" + g.name[:-2], grad_norm)
     else:
         # For now, we can only export fully-replicated tensors.
         # This has to be done before lowering or they will not be included in the graph
@@ -210,6 +214,10 @@ def serialized_fn(mtf_features):
     tf_loss = tf.cast(tf_loss, tf.float32)
 
     if mode == tf.estimator.ModeKeys.TRAIN:
+        # use our patched version until mtf updates theirs
+        host_call = create_host_call(params['model_path'])
+        mtf.utils.remove_summaries()
+
         # creates update ops to pass into optimizer
         tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
         tf_update_ops.append(tf.assign_add(global_step, 1))  # Need to manually increment global_step
@@ -242,7 +250,7 @@ def serialized_fn(mtf_features):
         return tpu_estimator.TPUEstimatorSpec(
             tf.estimator.ModeKeys.TRAIN,
             loss=tf_loss,
-            host_call=summary.get_host_call(),
+            host_call=host_call,
             train_op=train_op,
             training_hooks=[restore_hook, saver_hook])
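Taken together, the model_fns.py changes swap the old TpuSummaries helper for mesh-tensorflow's own summary path: scalars are recorded in the mesh graph with mtf.scalar_summary, and at train time the patched create_host_call (added to utils.py below) turns them into a host_call that writes TensorBoard events. The sketch below illustrates only the host-side half of that contract in plain TF2 eager code; toy_host_call and the /tmp/toy_logs path are illustrative names, not part of the repository.

# Minimal sketch (not the repo's code) of what a host_call does: each TPU shard
# contributes a [1]-shaped value per summary, the host concatenates them along
# the leading axis, reduces, and writes a TensorBoard scalar.
import tensorflow.compat.v2 as tf

def toy_host_call(model_dir, names, tensors, step):
    # `tensors` arrive with a leading shard dimension; reduce it away before logging
    with tf.summary.create_file_writer(model_dir).as_default():
        for name, t in zip(names, tensors):
            tf.summary.scalar(name, tf.reduce_mean(t), step=step)

# pretend two shards reported losses 1.0 and 3.0 at global step 10
toy_host_call("/tmp/toy_logs", ["loss"], [tf.constant([1.0, 3.0])], step=10)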
11 changes: 5 additions & 6 deletions optimizers.py
@@ -13,17 +13,15 @@ def clip_by_global_norm(grads, clip_norm):
     clipped_grads = [None if t is None else t * multiplier for t in grads]
     return clipped_grads, global_norm
 
-def get_optimizer(loss, params, summary, variable_dtype, inp_var_grads=None):
+def get_optimizer(mesh, loss, params, variable_dtype, inp_var_grads=None):
     """Creates and returns an optimizer training op."""
-    mesh = loss.mesh  # get mesh info from loss
-    graph = mesh.graph  # get graph info from mesh
     global_step = tf.train.get_or_create_global_step()  # get global step
 
     learning_rate = tf.constant(value=params["lr"], shape=[], dtype=variable_dtype.slice_dtype)  # grab lr param
     clip_value = mtf.constant(mesh, params["gradient_clipping"], dtype=variable_dtype.slice_dtype)
 
     if inp_var_grads is None:
-        var_grads = mtf.gradients([loss], [v.outputs[0] for v in graph.trainable_variables])
+        var_grads = mtf.gradients([loss], [v.outputs[0] for v in mesh.graph.trainable_variables])
     else:
         var_grads = inp_var_grads
 
@@ -63,7 +61,8 @@ def get_optimizer(loss, params, summary, variable_dtype, inp_var_grads=None):
     learning_rate = ((1.0 - is_warmup) * learning_rate +
                      is_warmup * warmup_learning_rate)
 
-    summary.scalar("lr", learning_rate)
+    learning_rate = mtf.import_fully_replicated(mesh, learning_rate, mtf.Shape([]), name="learning_rate")
+    mtf.scalar_summary("lr", learning_rate)
 
     if params["opt_name"].lower() == "adam":
         optimizer = AdamWeightDecayOptimizer(
@@ -87,7 +86,7 @@ def get_optimizer(loss, params, summary, variable_dtype, inp_var_grads=None):
     if params["gradient_clipping"] is not None:
         (var_grads_fp, _) = clip_by_global_norm(var_grads_fp, clip_norm=clip_value)
 
-    update_ops = optimizer.apply_grads(var_grads_fp, graph.trainable_variables)
+    update_ops = optimizer.apply_grads(var_grads_fp, mesh.graph.trainable_variables)
     return learning_rate, update_ops, var_grads_fp
 
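In optimizers.py, get_optimizer now receives the mesh explicitly instead of recovering it from loss.mesh, and the learning rate is first imported onto the mesh with mtf.import_fully_replicated so that mtf.scalar_summary can record it alongside the loss. The clip_by_global_norm helper at the top of this hunk is small enough to sketch in plain TF below; note that the multiplier formula clip_norm / max(global_norm, clip_norm) is an assumption, since that line falls outside the excerpt above.

# Plain-TF sketch of the clip_by_global_norm helper shown at the top of this
# file's diff (the repo's version uses mtf ops on mesh tensors).
import tensorflow.compat.v2 as tf

def clip_by_global_norm(grads, clip_norm):
    # global L2 norm over all gradients, ignoring missing (None) entries
    global_norm = tf.sqrt(tf.add_n(
        [tf.reduce_sum(tf.square(g)) for g in grads if g is not None]))
    # assumed scaling rule: leave gradients alone if already within clip_norm
    multiplier = clip_norm / tf.maximum(global_norm, clip_norm)
    clipped_grads = [None if g is None else g * multiplier for g in grads]
    return clipped_grads, global_norm

grads = [tf.constant([3.0, 4.0]), None]            # global norm = 5
clipped, norm = clip_by_global_norm(grads, clip_norm=1.0)
print(norm.numpy(), clipped[0].numpy())            # 5.0 [0.6 0.8]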
154 changes: 52 additions & 102 deletions utils.py
@@ -1,14 +1,13 @@
-import tensorflow.compat.v1 as tf
-from tensorflow.contrib import summary
 import re
 from urllib.parse import urlparse
 from shutil import rmtree
-import collections
 import logging
 import os
-import mesh_tensorflow as mtf
 from pathlib import Path
 import sys
+import tensorflow.compat.v1 as tf
+import tensorflow.compat.v2 as tf2
+import mesh_tensorflow as mtf
 
 
 def setup_logging(args):
@@ -206,102 +205,53 @@ class constructor.
     return float(ret)
 
 
-"""Provide a helper class for using summaries on TPU via a host call.
-TPUEstimator does not support writing TF summaries out of the box and TPUs can't
-perform operations that write files to disk. To monitor tensor values during
-training you can copy the tensors back to the CPU of the host machine via
-a host call function. This small library provides a convenient API to do this.
-Example:
-  from compare_gan.tpu import tpu_summaries
-  def model_fn(features, labels, params, mode):
-    summary = tpu_summries.TpuSummaries(my_model_dir)
-    summary.scalar("my_scalar_summary", tensor1)
-    summary.scalar("my_counter", tensor2, reduce_fn=tf.math.reduce_sum)
-    return TPUEstimatorSpec(
-        host_call=summary.get_host_call(),
-        ...)
-Warning: The host call function will run every step. Writing large tensors to
-summaries can slow down your training. High ranking outfeed operations in your
-XProf profile can be an indication for this.
-"""
-
-TpuSummaryEntry = collections.namedtuple(
-    "TpuSummaryEntry", "summary_fn name tensor reduce_fn")
-
-
-class TpuSummaries(object):
-    """Class to simplify TF summaries on TPU.
-    An instance of the class provides simple methods for writing summaries in the
-    similar way to tf.summary. The difference is that each summary entry must
-    provide a reduction function that is used to reduce the summary values from
-    all the TPU cores.
+def create_host_call(model_dir):
+    """Construct a host_call writing scalar summaries.
+    Borrowed from t2t.
+    Args:
+        model_dir: String containing path to train
+    Returns:
+        (fn, args) Pair to be called by TPUEstimator as the host_call.
     """
 
-    def __init__(self, log_dir, save_summary_steps=500):
-        self.logger = tf.logging
-        self._log_dir = log_dir
-        self._scalar_entries = []
-        # While False no summary entries will be added. On TPU we unroll the graph
-        # and don't want to add multiple summaries per step.
-        self.record = True
-        self._save_summary_steps = save_summary_steps
-        #assert TpuSummaries.inst is None
-        TpuSummaries.inst = self
-
-    def has(self, name):
-        for entry in self._scalar_entries:
-            if entry.name == name:
-                return True
-        return False
-
-    def scalar(self, name, tensor, reduce_fn=tf.math.reduce_mean):
-        """Add a summary for a scalar tensor."""
-        if not self.record:
-            return
-        if self.has(name):
-            self.logger.info("TpuSummaries.scalar: skipping duplicate %s", name)
-        else:
-            tensor = tf.convert_to_tensor(tensor)
-            if tensor.shape.ndims == 0:
-                tensor = tf.expand_dims(tensor, 0)
-            self._scalar_entries.append(
-                TpuSummaryEntry(summary.scalar, name, tensor, reduce_fn))
-
-    def get_host_call(self):
-        """Returns the tuple (host_call_fn, host_call_args) for TPUEstimatorSpec."""
-        # All host_call_args must be tensors with batch dimension.
-        # All tensors are streamed to the host machine (mind the band width).
-        global_step = tf.train.get_or_create_global_step()
-        host_call_args = [tf.expand_dims(global_step, 0)]
-        host_call_args.extend([e.tensor for e in self._scalar_entries])
-        self.logger.info("host_call_args: %s", host_call_args)
-        return (self._host_call_fn, host_call_args)
-
-    def _host_call_fn(self, step, *args):
-        """Function that will run on the host machine."""
-        # Host call receives values from all tensor cores (concatenate on the
-        # batch dimension). Step is the same for all cores.
-        step = step[0]
-        self.logger.info("host_call_fn: args=%s", args)
-        ops = []
-
-        # log scalars
-        with summary.create_file_writer(os.path.join(self._log_dir, 'scalars')).as_default():
-            offset = 0
-            with summary.record_summaries_every_n_global_steps(
-                    self._save_summary_steps, step):
-                for i, e in enumerate(self._scalar_entries):
-                    value = e.reduce_fn(args[i + offset])
-                    e.summary_fn(e.name, value, step=step)
-            offset += len(self._scalar_entries)
-            ops.append(summary.all_summary_ops())
-        return tf.group(ops)
-
-
-TpuSummaries.inst = None
+    graph = tf.get_default_graph()
+    # a list of (name, lowered tensor) tuples
+    summaries = graph.get_collection(mtf.utils.SCALAR_SUMMARIES_COLLECTION_KEY)
+
+    def maybe_cast(tensor):
+        assert tensor.shape.is_compatible_with([]), tensor.name
+        if tensor.dtype == tf.int64:
+            return tf.to_int32(tensor)
+        if tensor.dtype == tf.bfloat16:
+            return tf.cast(tensor, tf.float32)
+        return tensor
+
+    reshaped_tensors = [tf.reshape(maybe_cast(t), [1]) for _, t in summaries]
+
+    # When no supported summaries are found, don't create host_call. Otherwise,
+    # TPU outfeed queue would enqueue global_step while host_call doesn't dequeue
+    # it, eventually causing hang.
+    if not reshaped_tensors:
+        return None
+
+    def host_call_fn(global_step, *args):
+        """Training host call. Creates scalar summaries for training metrics."""
+        # This function is executed on the CPU and should not directly reference
+        # any Tensors in the rest of the `model_fn`. To pass Tensors from the
+        # model to the `model_fn`, provide as part of the `host_call`.
+        global_step = tf.cast(global_step[0], tf.int64)
+        with tf2.summary.create_file_writer(model_dir).as_default():
+            # We cannot directly use any tensor from summaries, because each
+            # tensor here must be a concat of multiple tensors from all shards.
+            # Therefore, we rely on the assumption that args will have the same
+            # length as summaries, and all tensors in args will have the same
+            # order of self._tup_summaries.
+            assert len(args) == len(summaries)
+            for i, tensor in enumerate(args):
+                name = summaries[i][0]
+                tf2.summary.scalar(
+                    name, tf.reduce_mean(tensor), step=global_step)
+        return tf.summary.all_v2_summary_ops()
+
+    global_step_t = tf.reshape(tf.to_int32(tf.train.get_global_step()), [1])
+    return host_call_fn, [global_step_t] + reshaped_tensors
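The new create_host_call gathers every tensor registered through mtf.scalar_summary from the graph collection, rejects anything that is not a scalar, converts dtypes the TPU outfeed cannot carry, reshapes each value to [1] so TPUEstimator can concatenate the per-shard copies, and returns the (host_call_fn, args) pair consumed in model_fns.py above. Below is a toy TF2 eager rendition of the cast/reshape contract, for illustration only; it uses tf.cast where the diff's TF1 code uses tf.to_int32.

# Toy check of the cast/reshape contract create_host_call enforces:
# every summarized tensor must be a scalar, int64/bfloat16 are converted, and
# each value is reshaped to [1] so shard copies can be concatenated on the host.
import tensorflow.compat.v2 as tf

def maybe_cast(tensor):
    assert tensor.shape.is_compatible_with([]), "only scalar summaries are supported"
    if tensor.dtype == tf.int64:
        return tf.cast(tensor, tf.int32)
    if tensor.dtype == tf.bfloat16:
        return tf.cast(tensor, tf.float32)
    return tensor

loss = tf.constant(1.5, dtype=tf.bfloat16)
step = tf.constant(7, dtype=tf.int64)
print([tf.reshape(maybe_cast(t), [1]).dtype for t in (loss, step)])
# [tf.float32, tf.int32]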
