# @tf.function

### What @tf.function is and why it matters
- @tf.function traces a Python function into a TensorFlow graph. AutoGraph converts tensor-dependent Python control flow into graph ops; XLA compilation is optional and only applies when you request it.
- Result: usually faster execution for numeric code, because per-op Python overhead is removed and graph optimizations become possible.
- But: compilation has costs (tracing, retracing) and restrictions (no arbitrary Python behavior inside compiled parts).

Bottom line: compile hot, stable, tensor-only code paths — typically inner training/validation steps — and keep Python orchestration (logging, checkpointing, dataset iteration) in plain Python.
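To make "no arbitrary Python behavior" concrete, here is a minimal sketch of what tracing does to plain Python side effects. The names `square` and `trace_log` are made up for illustration; the point is only that the Python body runs while the graph is being built, not on every call.

```python
import tensorflow as tf

trace_log = []  # hypothetical helper: records each time the Python body actually runs

@tf.function
def square(x):
    trace_log.append("traced")  # plain Python side effect: executes only during tracing
    return x * x

a = square(tf.constant(2.0))  # first call: traces the function, then runs the graph
b = square(tf.constant(3.0))  # same dtype/shape: reuses the existing graph, no new trace

print(float(a), float(b), trace_log)  # 4.0 9.0 ['traced']
```

Both calls return the right value, but the list is appended to only once. Anything that must happen on every step should be a TF op (for example `tf.print`) or live outside the decorated function.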

### When to use @tf.function (practical rules)
Use it for:
- train_step and val_step (forward → loss → backward → apply grads).
- Computationally heavy preprocessing you can express in TF ops (e.g., tf.strings.*, tf.image.*) — if it’s pure-TF and benefits from graph speed.

**Do NOT wrap:**
- The entire training loop, if it contains Python I/O, printing, checkpoint logic, or tf.summary writer setup that is not meant to run inside a graph. Keep the outer loop in Python.
- Code that needs arbitrary Python logic on every step (e.g., complex Python control flow, per-step logging while debugging).

### Common pitfalls & how to fix them
1) Retracing overhead
   - If a @tf.function is called with many different shapes/dtypes, or with Python arguments whose values change, TensorFlow retraces repeatedly and performance collapses.
   - Fix: pass tensors with stable shapes and dtypes, pass tensors instead of Python scalars, or pin the signature with input_signature (a list of tf.TensorSpec).
2) Python objects inside @tf.function
   - Python lists/dicts, file I/O, printing, Python-level randomness: inside a compiled function these fail, run only once at trace time, or cause retracing.
   - Fix: move Python logic outside the function, use TF equivalents (e.g., tf.print, tf.io.write_file) or convert to tensors.
3) Mutable Python state
   - Modifying Python lists/dicts inside a compiled function happens only while the function is traced, so later calls do not repeat the update.
   - Fix: keep mutable state in tf.Variable or in Python outside the function.
4) Debugging is harder
   - Errors inside a @tf.function often show cryptic stack traces.
   - Fix: run in eager (tf.config.run_functions_eagerly(True)) to debug, then switch back.
5) Returning non-Tensor objects
   - Functions should primarily return tensors or (nested) structures of tensors. Returning complex Python objects can cause issues.
6) Random seeds & determinism
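   - Prefer calling tf.random.set_seed() once, before training and outside the function. If you seed inside a @tf.function, the call runs while the function is traced, so its effect depends on where it sits relative to the random ops.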
7) Side effects
   - tf.print and tf.summary can be used in @tf.function, but be mindful of semantics (use tf.summary with proper writer contexts and flush outside).
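
As a rough illustration of pitfalls 1 and 3, the sketch below counts traces when a function receives Python scalars versus tensors, and keeps a step counter in a `tf.Variable`. The function names (`scale`, `step`) are made up for the example; `experimental_get_tracing_count()` is an experimental method, so treat its availability as an assumption about your TensorFlow version.

```python
import tensorflow as tf

@tf.function
def scale(x, factor):
    return x * factor

x = tf.ones([4])

# Python scalars are baked into the graph as constants: every new value means a new trace.
for f in (1.0, 2.0, 3.0):
    scale(x, f)
traces_python = scale.experimental_get_tracing_count()

# Tensors with the same dtype and shape share a single trace.
for f in (1.0, 2.0, 3.0):
    scale(x, tf.constant(f))
traces_total = scale.experimental_get_tracing_count()

print("traces after Python floats:", traces_python)  # 3
print("traces after tensor args:  ", traces_total)   # 4 (one extra trace for the tensor signature)

# Mutable state that must persist across calls belongs in a tf.Variable.
counter = tf.Variable(0, dtype=tf.int64)

@tf.function
def step():
    counter.assign_add(1)  # a graph op, so it runs on every call
    return counter.read_value()

for _ in range(3):
    step()
print("counter:", int(counter.numpy()))  # 3
```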

### Profiling basics (what the profiler gives you)
- tf.profiler.experimental captures timelines, op-by-op cost, memory usage, kernel details, and Python/CPU hotspots.
- Typical workflow:
  1) Start the trace with tf.profiler.experimental.start(logdir)
  2) Run a few warmup steps + the steps you want to profile
  3) Stop the trace with tf.profiler.experimental.stop()
  4) Launch tensorboard --logdir <logdir> and use the “Profile” tab to inspect the trace
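
The same workflow can also be written with the `tf.profiler.experimental.Profile` context manager, which starts and stops the trace for you. This is a small sketch, not a benchmark: `matmul_step` and the temporary log directory are placeholders chosen for illustration.

```python
import tempfile
import tensorflow as tf

@tf.function
def matmul_step(a, b):
    # Stand-in for a real training step
    return tf.reduce_sum(tf.matmul(a, b))

a = tf.random.normal([256, 256])
b = tf.random.normal([256, 256])

logdir = tempfile.mkdtemp(prefix="tf_profile_")  # hypothetical throwaway log directory

# Warm up outside the profiled window so tracing cost is not captured
for _ in range(3):
    matmul_step(a, b)

# The context manager calls start() on entry and stop() on exit
with tf.profiler.experimental.Profile(logdir):
    for _ in range(5):
        last = float(matmul_step(a, b))

print("Profile written under:", logdir)
```

The context manager is convenient when the profiled window is a single block; the explicit start/stop calls are easier when the window spans several cells.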


In [None]:
%pip install tensorflow
%pip install tensorboard
%pip install tensorboard-plugin-profile

### Minimal, runnable example — one cell

This does three things:
1) Builds a tiny dataset and model.
2) Shows train_step with and without @tf.function.
3) Runs the profiler for a few steps and writes a TensorBoard trace to ./assets/logs/profile.


In [None]:
# Single-cell example: @tf.function usage + profiling (runnable)
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import time, os, shutil

# --- small dataset (toy) ---
(x_train, y_train), _ = keras.datasets.imdb.load_data(num_words=2000)
x_train = keras.preprocessing.sequence.pad_sequences(x_train, maxlen=80)
x_train = x_train[:1024]   # keep it small so profiling is quick
y_train = y_train[:1024]

ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(1000).batch(64).prefetch(tf.data.AUTOTUNE)

# --- tiny model ---
def make_model():
    inp = keras.Input(shape=(80,), dtype='int32')
    x = layers.Embedding(2000, 32)(inp)
    x = layers.Bidirectional(layers.LSTM(32))(x)
    x = layers.Dense(16, activation='relu')(x)
    out = layers.Dense(1, activation='sigmoid')(x)
    return keras.Model(inp, out)

model = make_model()
optimizer = keras.optimizers.Adam(1e-3)
loss_fn = keras.losses.BinaryCrossentropy()

# --- train_step functions: eager vs @tf.function ---
# Eager version (no decorator) -- easier to debug
def train_step_eager(x, y):
    with tf.GradientTape() as tape:
        pred = model(x, training=True)
        loss = loss_fn(y, pred)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

# Compiled version (fast) -- wrap only the inner step
@tf.function
def train_step_compiled(x, y):
    with tf.GradientTape() as tape:
        pred = model(x, training=True)
        loss = loss_fn(y, pred)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

# --- small helper to run N steps using provided train_step ---
def run_steps(train_step_fn, dataset, steps=10):
    it = iter(dataset)
    t0 = time.time()
    losses = []
    for i in range(steps):
        x_batch, y_batch = next(it)
        loss = train_step_fn(x_batch, tf.cast(y_batch, tf.float32))
        # Both versions return a scalar tensor; float() waits for the result and converts it for logging
        losses.append(float(loss))
    dt = time.time() - t0
    print(f"{train_step_fn.__name__}: ran {steps} steps in {dt:.3f}s, avg loss {sum(losses)/len(losses):.4f}")

# Warm up: run a few eager steps to initialize variables
print("Warmup (eager)...")
run_steps(train_step_eager, ds, steps=3)

# Measure compiled vs eager
print("\nEager timing:")
run_steps(train_step_eager, ds, steps=10)

print("\nCompiled (@tf.function) timing (first call includes trace/compile overhead):")
run_steps(train_step_compiled, ds, steps=10)
print("Compiled timing (second run, should be faster):")
run_steps(train_step_compiled, ds, steps=10)

# --- Profiling: capture a small trace around compiled steps ---
logdir = "./assets/logs/profile"
if os.path.exists(logdir):
    shutil.rmtree(logdir)
os.makedirs(logdir, exist_ok=True)

# Use profiler to capture a small window. We profile the compiled train_step.
tf.profiler.experimental.start(logdir)
print("\nProfiling: running 5 compiled steps with profiler running...")
run_steps(train_step_compiled, ds, steps=5)
tf.profiler.experimental.stop()
print(f"Profiler trace written to: {logdir}")
print("Open TensorBoard: tensorboard --logdir assets/logs/profile, then go to Profile tab to view.")


### How to use the example and what to look for in TensorBoard
1) Run the cell. It will:
    - Show warmup (eager).
    - Show timings for eager vs compiled.
    - Write a profiler trace to ./assets/logs/profile.
2) Start TensorBoard: tensorboard --logdir assets/logs/profile
   - Open the Profile tab.
   - Inspect the Trace Viewer: you’ll see CPU/GPU timelines, op durations, and which ops dominate time.
   - Look at the “TensorFlow Stats” and “GPU Kernel Stats” pages to identify hotspots, and at “Memory Profile” for memory peaks.


In [None]:
%load_ext tensorboard
%tensorboard --logdir assets/logs/profile --port 6006

### Practical tips when profiling
- Warm up first: JIT and GPU need warmup. Run a few steps before capturing the trace.
- Profile a small window: profiling huge runs produces massive traces. Profile 5–50 steps.
- Use tf.function: profile compiled functions to see optimized kernels and fused ops.
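- If you see retracing (TensorFlow also logs a warning when a function retraces frequently): fix it by stabilizing inputs or adding input_signature to @tf.function. The dtypes must match what you actually pass; the example above feeds int32 token ids and labels cast to float32:
```python
@tf.function(input_signature=[tf.TensorSpec([None, 80], tf.int32), tf.TensorSpec([None], tf.float32)])
def train_step_compiled(x, y): ...
```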
- TensorBoard helps: use the Trace Viewer to see whether CPU preprocessing or GPU compute is the bottleneck. If CPU dominates, tune tf.data (more num_parallel_calls, prefetch, caching).
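
If the trace suggests the input pipeline is the bottleneck, the tf.data knobs mentioned in the last tip might be combined as in the sketch below. The synthetic data and the `preprocess` function are assumptions made up for the example; only the `map`/`cache`/`prefetch` calls are the point.

```python
import tensorflow as tf

def preprocess(x, y):
    # Hypothetical per-example transform standing in for real CPU-side preprocessing
    return tf.cast(x, tf.float32) / 255.0, y

# Synthetic stand-in data with the same layout as the toy dataset (1024 rows of length 80)
features = tf.random.uniform([1024, 80], maxval=256, dtype=tf.int32)
labels = tf.random.uniform([1024], maxval=2, dtype=tf.int32)

ds = (
    tf.data.Dataset.from_tensor_slices((features, labels))
    .map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)  # parallelize the CPU work
    .cache()                                               # reuse preprocessed examples after the first epoch
    .shuffle(1024)
    .batch(64)
    .prefetch(tf.data.AUTOTUNE)                            # overlap input preparation with training
)

x_batch, y_batch = next(iter(ds))
print(x_batch.shape, x_batch.dtype)
```

Caching after `map` only makes sense when the preprocessed data fits in memory and the transform is deterministic; random augmentation would normally go after `cache()`.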


### Extra advanced notes
- tf.function + XLA: you can request XLA compilation per function (jit_compile=True in tf.function or in optimizer) if you want to experiment with backend compilation — tradeoff: compile time vs runtime speed.
- For debugging, tf.print works inside @tf.function and prints during execution (useful when you can’t run eager).
- tf.data traces the functions you pass to Dataset.map into a graph on its own, so they run efficiently inside the dataset pipeline whether or not you decorate them with @tf.function.
- For reproducibility across @tf.function, set seeds before training and avoid Python randomness inside the function.
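
One option for the reproducibility note, not covered in the text above, is to use stateless random ops, where the seed is an explicit tensor argument. This is a sketch of that alternative rather than the approach the notes prescribe; `add_noise` is a made-up name.

```python
import tensorflow as tf

@tf.function
def add_noise(x, seed):
    # seed is a shape-[2] integer tensor; the same seed always yields the same noise
    return x + tf.random.stateless_normal(tf.shape(x), seed=seed)

x = tf.zeros([3])
seed_a = tf.constant([1, 2], dtype=tf.int32)
seed_b = tf.constant([3, 4], dtype=tf.int32)

a = add_noise(x, seed_a)
b = add_noise(x, seed_a)  # same seed: identical output
c = add_noise(x, seed_b)  # different seed: different output

print("same seed matches:     ", bool(tf.reduce_all(a == b)))
print("different seed matches:", bool(tf.reduce_all(a == c)))
```

Because the seed is passed as a tensor, changing it does not trigger retracing, and the result does not depend on how many random ops ran earlier.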


### What is a TensorFlow Graph?

A TensorFlow graph is a directed acyclic graph (DAG) of computation:
- Nodes: operations (e.g., `MatMul`, `Add`, `Relu`, `Conv2D`).
- Edges: tensors flowing between ops (data dependencies).
- No Python control flow at runtime; only the traced ops/tensors.
- Built when TensorFlow "traces" your Python function (e.g., via `@tf.function`).
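
To see what "no Python control flow at runtime" means in practice, the sketch below traces a function whose `if` depends on a tensor value and lists the op types in the resulting graph. `abs_value` is a made-up example; the exact op names printed may differ between TensorFlow versions.

```python
import tensorflow as tf

@tf.function
def abs_value(x):
    if x < 0:  # condition depends on a tensor, so AutoGraph turns it into a graph conditional
        return -x
    return x

concrete = abs_value.get_concrete_function(tf.TensorSpec([], tf.float32))
op_types = sorted({op.type for op in concrete.graph.get_operations()})
print(op_types)  # look for a conditional op (an "If" variant) instead of Python branching

neg = float(abs_value(tf.constant(-2.0)))
pos = float(abs_value(tf.constant(3.0)))
print(neg, pos)  # 2.0 3.0
```

Both branches are stored in the graph, and the choice between them is made by the conditional op at run time, without returning to Python.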

Conceptually, for `y = relu(Wx + b)` the graph looks like:
```
 x ----> MatMul ----> Add ----> Relu ----> y
          ^            ^
          |            |
          W            b
```

Why graphs?
- Optimize globally (fusion, constant folding, device placement).
- Run efficiently on CPU/GPU/TPU without Python overhead.
- Serialize as `GraphDef`/`SavedModel` and serve elsewhere.
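
As a small illustration of the serialization point, the sketch below wraps `relu(Wx + b)` in a `tf.Module`, exports it as a SavedModel, and reloads it. The class name `DenseRelu` and the temporary export directory are placeholders for the example.

```python
import tempfile
import tensorflow as tf

class DenseRelu(tf.Module):
    """Hypothetical module computing relu(xW + b)."""

    def __init__(self):
        super().__init__()
        self.W = tf.Variable(tf.random.normal([3, 4]), name="W")
        self.b = tf.Variable(tf.zeros([4]), name="b")

    # A fixed input_signature makes the traced graph exportable without calling it first
    @tf.function(input_signature=[tf.TensorSpec([None, 3], tf.float32)])
    def __call__(self, x):
        return tf.nn.relu(tf.matmul(x, self.W) + self.b)

module = DenseRelu()
export_dir = tempfile.mkdtemp(prefix="dense_relu_")  # throwaway directory for the demo
tf.saved_model.save(module, export_dir)

# The reloaded object runs the stored graph; the original Python class is not needed
restored = tf.saved_model.load(export_dir)
x = tf.random.normal([2, 3])
same = bool(tf.reduce_all(tf.abs(module(x) - restored(x)) < 1e-6))
print("restored output matches:", same)
```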


In [None]:
# Minimal example: build a graph with tf.function and inspect it
import os, datetime
import tensorflow as tf

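# Named dense_relu (not `model`) so it does not shadow the Keras model from the earlier cell
@tf.function
def dense_relu(x, W, b):
    y = tf.matmul(x, W) + b
    return tf.nn.relu(y)

# Fixed shapes for a stable trace
x = tf.random.normal([2, 3])
W = tf.random.normal([3, 4])
b = tf.random.normal([4])

# Trace once for these exact shapes/dtypes and get the resulting concrete graph function
concrete = dense_relu.get_concrete_function(
    tf.TensorSpec(x.shape, x.dtype),
    tf.TensorSpec(W.shape, W.dtype),
    tf.TensorSpec(b.shape, b.dtype),
)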

print("Ops in graph:")
for op in concrete.graph.get_operations():
    print(f"- {op.name}: {op.type}")

# Save GraphDef as text for a quick look
os.makedirs('assets', exist_ok=True)
path_pbtxt = os.path.join('assets', 'example_graph.pbtxt')
tf.io.write_graph(concrete.graph.as_graph_def(), 'assets', 'example_graph.pbtxt', as_text=True)
print(f"\nSaved graph to {path_pbtxt}")

# Log graph for TensorBoard
logdir = os.path.join('assets', 'logs', 'graph', datetime.datetime.now().strftime('%Y%m%d-%H%M%S'))
writer = tf.summary.create_file_writer(logdir)
with writer.as_default():
    tf.summary.graph(concrete.graph)
writer.flush()
print("TensorBoard logs at:", logdir)
print("To view here, you can run:\n%load_ext tensorboard\n%tensorboard --logdir assets/logs/graph")


Start TensorBoard: tensorboard --logdir assets/logs/graph
   - Open the Graphs tab.
   - Expand the graph and click individual nodes to inspect their ops, inputs, and outputs.

In [None]:
%load_ext tensorboard
%tensorboard --logdir assets/logs/graph --port 6007 