<a href="https://colab.research.google.com/github/RlChen0/Learn-Notes/blob/master/Example_of_printing_call_method_input_and_output_shapes.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install or upgrade TensorFlow and Sonnet (`dm-sonnet`). pip's stderr is
# redirected to /dev/null, so installation errors are not displayed.
!pip install --upgrade tensorflow dm-sonnet 2>/dev/null

In [None]:
import tensorflow as tf
import sonnet as snt

The following utilities wrap Sonnet modules so that every call to their `__call__` method prints the module's input and output shapes:

In [None]:
import functools

def stringify(leaf):
  """Return a compact string for one leaf of a (possibly nested) structure.

  Tensors and variables are rendered as `<dtype>[<dim>,<dim>,...]`, e.g.
  `float32[1,784]`; unknown dimensions appear as `None`, and a tensor whose
  rank itself is unknown appears as `<dtype>[?]`. Any other leaf is passed
  through `str` unchanged.
  """
  # In TF2 `tf.Variable` is not a subclass of `tf.Tensor`, so check both.
  if isinstance(leaf, (tf.Tensor, tf.Variable)):
    shape = leaf.shape
    if shape.rank is None:
      # Iterating a TensorShape of unknown rank raises ValueError.
      dims = '?'
    else:
      dims = ','.join(map(str, shape))
    leaf = f'{leaf.dtype.name}[{dims}]'
  return str(leaf)

def print_shape(f):
  """Wrap a module method so every call prints its argument and result shapes.

  Each call prints one line of the form `name(arg, ..., key=value) -> result`.
  `name` is `self.name`, and every leaf is rendered with `stringify`. The
  call part is left-justified to 60 characters so the results line up.

  Args:
    f: The method to wrap; it is invoked as `f(self, *args, **kwargs)`.

  Returns:
    A wrapper carrying `f`'s metadata (via `functools.wraps`) that returns
    `f`'s result unchanged.
  """
  @functools.wraps(f)
  def wrapper(self, *args, **kwargs):
    # Stringify each positional argument on its own and force the result to
    # `str`: for a nested argument (e.g. a list of tensors) `map_structure`
    # returns a nested structure, which `", ".join` cannot consume.
    args_str = [str(tf.nest.map_structure(stringify, a)) for a in args]
    args_str.extend(f'{k}={tf.nest.map_structure(stringify, v)}'
                    for k, v in kwargs.items())
    out = f(self, *args, **kwargs)
    ret_str = tf.nest.map_structure(stringify, out)
    call_str = f'{self.name}({", ".join(args_str)})'
    print(f'{call_str:60} -> {ret_str}')
    return out
  return wrapper

def patch_modules_to_print_shape():
  """Make every `snt.Module` subclass print shapes when its `__call__` runs.

  Walks the whole subclass tree of `snt.Module`, not just the direct
  subclasses, and wraps each `__call__` that a class defines itself with
  `print_shape`. Classes created after this function runs are not patched.

  Calling it more than once is safe. The unwrapped method is stored on the
  wrapper as `__mtps_before` and re-wrapped, so wrappers do not stack.

  Note: a `__call__` that delegates to `super().__call__` prints one line
  per patched level.
  """
  seen = set()
  pending = list(snt.Module.__subclasses__())
  while pending:
    cls = pending.pop()
    if cls in seen:
      continue
    seen.add(cls)
    pending.extend(cls.__subclasses__())
    # Only patch a `__call__` the class defines itself. `hasattr(cls,
    # "__call__")` is always true for a class, because its metaclass makes it
    # callable. Patching in that case would turn the class constructor into
    # an instance method. Inherited `__call__`s are reached via the class
    # that defines them.
    if '__call__' not in vars(cls):
      continue
    current = cls.__call__
    before = getattr(current, '__mtps_before', current)
    cls.__call__ = print_shape(before)
    cls.__call__.__mtps_before = before

It works with custom `snt.Module` subclasses:

In [None]:
class MyModule(snt.Module):
  """Toy module that optionally prepends a size-1 leading axis to its input."""

  def __call__(self, x, *, reshape: bool):
    """Return `x`, with a new leading axis of size 1 when `reshape` is true.

    Uses `tf.expand_dims` instead of `tf.reshape(x, [1] + x.shape)`. The
    latter converts the static shape into a target shape, and that
    conversion fails when any dimension is unknown (`None`).
    """
    if reshape:
      x = tf.expand_dims(x, axis=0)
    return x

# Patch only once all your subclasses exist: classes defined after this call
# are not wrapped.
patch_modules_to_print_shape()

m, x = MyModule(), tf.ones([1, 1])
_ = m(x, reshape=True)

my_module(float32[1,1], reshape=True)                        -> float32[1,1,1]


It also works with Sonnet's built-in modules:

In [None]:
# One flattened 28x28 input fed through an MLP with layer sizes 300, 100, 10.
x = tf.ones(shape=(1, 784))
m = snt.nets.MLP(output_sizes=[300, 100, 10])
_ = m(x)

linear_0(float32[1,784])                                     -> float32[1,300]
linear_1(float32[1,300])                                     -> float32[1,100]
linear_2(float32[1,100])                                     -> float32[1,10]
mlp(float32[1,784])                                          -> float32[1,10]


In [None]:
# One 224x224x3 input fed through ResNet-50 with 1000 outputs, in training mode.
x = tf.ones(shape=(1, 224, 224, 3))
m = snt.nets.ResNet50(num_classes=1000)
_ = m(x, is_training=True)

initial_conv(float32[1,224,224,3])                           -> float32[1,112,112,64]
initial_batchnorm(float32[1,112,112,64], is_training=True)   -> float32[1,112,112,64]
shortcut_conv(float32[1,56,56,64])                           -> float32[1,56,56,256]
shortcut_batchnorm(float32[1,56,56,256], is_training=True)   -> float32[1,56,56,256]
conv_0(float32[1,56,56,64])                                  -> float32[1,56,56,64]
batchnorm_0(float32[1,56,56,64], is_training=True)           -> float32[1,56,56,64]
conv_1(float32[1,56,56,64])                                  -> float32[1,56,56,64]
batchnorm_1(float32[1,56,56,64], is_training=True)           -> float32[1,56,56,64]
conv_2(float32[1,56,56,64])                                  -> float32[1,56,56,256]
batchnorm_2(float32[1,56,56,256], is_training=True)          -> float32[1,56,56,256]
block_0(float32[1,56,56,64], is_training=True)               -> float32[1,56,56,256]
conv_0(float32[1,56,56,256])                                 -> flo