In [1]:
import numpy as np
import timeit
import traceback
import contextlib

import tensorflow as tf

In [2]:
@contextlib.contextmanager
def options(options):
    old_opts = tf.config.optimizer.get_experimental_options()
    tf.config.optimizer.set_experimental_options(options)
    try:
        yield
    finally:
        tf.config.optimizer.set_experimental_options(old_opts)

In [3]:
def example_function():
  @tf.function
  def simple_function(input):
    print('Tracing...')
    a = tf.constant(np.random.randn(2000, 2000), dtype = tf.float32)
    c = a
    for i in range(50):
      c = c@a
    return tf.reduce_mean(c+input)

  return simple_function

In [4]:
with options({'constant_folding': False}):
  print(tf.config.optimizer.get_experimental_options())
  simple_function = example_function()
  x = tf.constant(5.4)
  simple_function(x)
  print("Normal Execution:", timeit.timeit(lambda: simple_function(x), number = 1), "s")

{'constant_folding': False, 'disable_model_pruning': False, 'disable_meta_optimizer': False}
Tracing...


2023-07-17 17:11:47.586444: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M1 Max
2023-07-17 17:11:47.586463: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 32.00 GB
2023-07-17 17:11:47.586467: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 10.67 GB
2023-07-17 17:11:47.586523: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:303] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2023-07-17 17:11:47.586557: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:269] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)
2023-07-17 17:11:47.727298: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.


Normal Execution: 0.003241084050387144 s


In [5]:
with options({'constant_folding': True}):
  print(tf.config.optimizer.get_experimental_options())
  simple_function = example_function()
  x = tf.constant(5.4)
  simple_function(x)
  print("Constant Folded Execution:", timeit.timeit(lambda: simple_function(x), number = 1), "s")

{'constant_folding': True, 'disable_model_pruning': False, 'disable_meta_optimizer': False}
Tracing...


2023-07-17 17:12:24.063083: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.


Constant Folded Execution: 0.00044241698924452066 s
