### A4.2.1. TensorFlow Graph Compilation

> *`tf.function` traces a Python function into a TensorFlow graph (FuncGraph), which can be optimized by Grappler and optionally compiled by XLA before execution.*

**Explanation:**

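TensorFlow's **graph mode** separates graph construction (tracing) from execution. Because the whole computation is known before it runs, TensorFlow can apply whole-program optimizations that are impossible in eager mode, where each op executes as soon as Python reaches it.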

**`tf.function` Mechanics:**

1. **Tracing**: TF calls the Python function with symbolic `tf.Tensor` placeholders, recording every TF op into a `FuncGraph`.
2. **ConcreteFunction**: the traced graph plus its input signature. ConcreteFunctions are cached under a key built from the call's arguments: shape and dtype for tensor arguments, and the actual value for non-tensor Python arguments.
3. **Retracing**: a call whose key misses the cache triggers a new trace. Excessive retracing (e.g., a new input shape or Python scalar on every call) wastes time re-tracing and re-optimizing graphs.
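The caching rule above can be sketched as a toy, pure-Python model. `FakeTensor`, `ToyFunction`, and `scaled_relu` are hypothetical names for illustration; the sketch only mimics the idea of keying traces by tensor shape/dtype and by Python values, not TensorFlow's actual implementation.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FakeTensor:
    """Hypothetical stand-in for a tf.Tensor: only shape and dtype affect caching."""
    shape: tuple
    dtype: str


class ToyFunction:
    """Toy model of tf.function's trace cache (illustrative, not TensorFlow's code)."""

    def __init__(self, python_fn):
        self.python_fn = python_fn
        self.cache = {}        # cache key -> recorded ops (our stand-in "ConcreteFunction")
        self.trace_count = 0

    @staticmethod
    def _key(arg):
        # Tensor arguments are keyed by (shape, dtype); other Python values by value.
        if isinstance(arg, FakeTensor):
            return ("tensor", arg.shape, arg.dtype)
        return ("python", arg)

    def __call__(self, *args):
        key = tuple(self._key(a) for a in args)
        if key not in self.cache:
            # Cache miss: "trace" by running the Python body once and recording its ops.
            self.trace_count += 1
            ops = []
            self.python_fn(ops, *args)
            self.cache[key] = ops
        return self.cache[key]


def scaled_relu(ops, x, scale):
    ops.append("Relu")
    if scale != 1:  # Python value: this branch is decided at trace time
        ops.append("Mul")


f = ToyFunction(scaled_relu)
f(FakeTensor((128, 784), "float32"), 2)  # first call: trace #1
f(FakeTensor((128, 784), "float32"), 2)  # same key: cache hit, no trace
f(FakeTensor((64, 784), "float32"), 2)   # new shape: trace #2
f(FakeTensor((64, 784), "float32"), 1)   # new Python value: trace #3
print(f.trace_count)  # 3
```

Note that the last trace records only `Relu`: the Python `if` was evaluated once during tracing and never reaches the recorded ops.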

**Graph Optimization (Grappler):**

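| Pass | Grappler optimizer | Optimization |
|------|--------------------|--------------|
| Constant folding | Constant folding optimizer | Evaluate ops whose inputs are known before execution and replace them with constants |
| Common subexpression elimination | Arithmetic optimizer | Deduplicate identical subgraphs |
| Layout optimization | Layout optimizer | Convert NHWC ↔ NCHW to suit the target device |
| Op fusion | Remapper | Replace common op patterns with fused kernels (e.g., Conv+BN+ReLU) |
| Pruning | Pruning optimizer | Remove nodes that have no effect on the graph's outputs |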
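Three of these passes can be illustrated on a tiny dictionary-based graph. This is a hypothetical toy (`Node`, `constant_fold`, `cse`, and `prune` are made-up names) that only mirrors what constant folding, CSE, and pruning do; Grappler's real optimizers work on TensorFlow's own graph representation and handle many more cases.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class Node:
    op: str                                   # "Placeholder", "Const", "Add", "Mul", "Relu"
    inputs: list = field(default_factory=list)
    value: Optional[float] = None             # only meaningful for "Const"


ARITH = {"Add": lambda a, b: a + b, "Mul": lambda a, b: a * b}


def constant_fold(graph):
    """Replace Add/Mul nodes whose inputs are all Const with a single Const."""
    for name, node in list(graph.items()):    # dict order is topological here
        if node.op in ARITH and all(graph[i].op == "Const" for i in node.inputs):
            a, b = (graph[i].value for i in node.inputs)
            graph[name] = Node("Const", [], ARITH[node.op](a, b))
    return graph


def cse(graph):
    """Merge nodes with identical (op, inputs, value); return new graph and renames."""
    seen, rename, out = {}, {}, {}
    for name, node in graph.items():
        inputs = [rename.get(i, i) for i in node.inputs]
        key = (node.op, tuple(inputs), node.value)
        if node.op != "Placeholder" and key in seen:  # placeholders are distinct inputs
            rename[name] = seen[key]
            continue
        seen[key] = name
        out[name] = Node(node.op, inputs, node.value)
    return out, rename


def prune(graph, outputs):
    """Keep only nodes that the outputs transitively depend on."""
    live, stack = set(), list(outputs)
    while stack:
        n = stack.pop()
        if n not in live:
            live.add(n)
            stack.extend(graph[n].inputs)
    return {n: node for n, node in graph.items() if n in live}


# y = relu(x * (2 + 3)) + relu(x * (2 + 3)), plus one dead node
g = {
    "x": Node("Placeholder"),
    "c2": Node("Const", value=2.0),
    "c3": Node("Const", value=3.0),
    "s": Node("Add", ["c2", "c3"]),   # foldable to Const 5.0
    "m1": Node("Mul", ["x", "s"]),
    "m2": Node("Mul", ["x", "s"]),    # duplicate of m1
    "r1": Node("Relu", ["m1"]),
    "r2": Node("Relu", ["m2"]),       # duplicate of r1 once m2 is renamed to m1
    "y": Node("Add", ["r1", "r2"]),
    "unused": Node("Relu", ["x"]),    # does not feed y
}

g = constant_fold(g)
g, rename = cse(g)
g = prune(g, [rename.get("y", "y")])
print(list(g))  # ['x', 's', 'm1', 'r1', 'y']
```

The passes compound: folding `2 + 3` leaves `c2` and `c3` dead, CSE collapses the duplicate branch, and pruning then removes everything `y` no longer needs, shrinking 10 nodes to 5.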

**Tracing Pitfalls:**

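- **Python side effects** (e.g., `print`, appending to a Python list) in traced functions execute only at trace time, not on every call; use `tf.print` for output that should appear on each call.
- **Python `if`/`for` on Python values** (not tensors) run at trace time, so only the branch taken during tracing is recorded in the graph. When the condition is a tensor, AutoGraph (enabled by default in `tf.function`) converts tensor-dependent `if` statements into `tf.cond` and loops into `tf.while_loop`; both can also be called explicitly.
- **`input_signature`** on `tf.function` prevents retracing by fixing the accepted shapes and dtypes; calls whose inputs do not match the signature raise an error instead of triggering a new trace.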
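The trace-time behavior behind the first two pitfalls can be reproduced with a hypothetical toy (`ToyGraphFunction`, `body`, `log_value`, and `USE_DOUBLE` are illustrative names, not TensorFlow APIs): the Python body runs once, and later calls replay only the recorded ops. In real TensorFlow, `print` plays the role of the Python side effect and `tf.print` the role of the recorded op.

```python
events = []          # shared log of what ran, and when
USE_DOUBLE = True    # Python global, read only while tracing


class ToyGraphFunction:
    """Toy tf.function: trace the body once into a list of ops, then replay them."""

    def __init__(self, body):
        self.body = body
        self.ops = None

    def __call__(self, x):
        if self.ops is None:
            self.ops = []
            self.body(self.ops.append)  # trace: Python code runs now, ops get recorded
        for op in self.ops:             # every call: replay the recorded ops only
            x = op(x)
        return x


def log_value(v):
    events.append(f"graph op saw {v}")  # recorded op, like tf.print: runs on every call
    return v


def body(record):
    events.append("python side effect")  # like print(): runs only while tracing
    record(log_value)
    if USE_DOUBLE:                       # Python `if` on a Python value: decided at trace time
        record(lambda v: v * 2)


f = ToyGraphFunction(body)
results = [f(1), f(2)]
USE_DOUBLE = False   # too late: the doubling op is already baked into the trace
results.append(f(3))
print(results)  # [2, 4, 6]
print(events)
```

The side effect appears once in `events`, while the recorded op logs all three calls, and flipping `USE_DOUBLE` after tracing has no effect.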

**Example:**

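```python
import tensorflow as tf

weights = tf.Variable(tf.random.normal([784, 10]))  # assumed: 784 features -> 10 classes
bias = tf.Variable(tf.zeros([10]))

@tf.function(input_signature=[tf.TensorSpec([None, 784], tf.float32)])
def predict(x):
    return tf.nn.softmax(x @ weights + bias)
```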

`None` in the signature allows variable batch size without retracing.
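The matching rule behind a `None` dimension can be sketched in a few lines. `is_compatible` is a hypothetical helper that mirrors what TensorFlow's `tf.TensorShape.is_compatible_with` checks when both ranks are known: equal rank, and every non-`None` dimension must match exactly.

```python
def is_compatible(spec_shape, shape):
    """True if `shape` fits `spec_shape`; None in the spec accepts any size."""
    return len(spec_shape) == len(shape) and all(
        s is None or s == d for s, d in zip(spec_shape, shape)
    )


spec = (None, 784)  # shape part of tf.TensorSpec([None, 784], tf.float32)
for shape in [(1, 784), (128, 784), (128, 100), (784,)]:
    print(shape, is_compatible(spec, shape))
# (1, 784) True
# (128, 784) True
# (128, 100) False
# (784,) False
```

Every shape that passes this check is served by the same single trace.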

```python
# Toy, pure-Python model of the concepts above (no TensorFlow required).
from dataclasses import dataclass, field


@dataclass
class GraphNode:
    op_type: str
    name: str
    inputs: list[str]


@dataclass
class FuncGraph:
    """Stand-in for a traced FuncGraph: an ordered list of recorded ops."""
    function_name: str
    nodes: list[GraphNode] = field(default_factory=list)

    def add_node(self, op_type, name, inputs):
        node = GraphNode(op_type, name, inputs)
        self.nodes.append(node)
        return node


@dataclass
class InputSignature:
    shapes: list[tuple]
    dtypes: list[str]

    @property
    def cache_key(self):
        # Pair each input's shape with its dtype, mimicking the trace-cache key.
        return tuple(zip(self.shapes, self.dtypes))


# Ops that tracing `predict` would record: softmax(x @ weights + bias)
graph = FuncGraph("predict")
graph.add_node("Placeholder", "x", [])
graph.add_node("ReadVariable", "weights", [])
graph.add_node("ReadVariable", "bias", [])
graph.add_node("MatMul", "matmul", ["x", "weights"])
graph.add_node("BiasAdd", "biasadd", ["matmul", "bias"])
graph.add_node("Softmax", "softmax", ["biasadd"])

print(f"FuncGraph: {graph.function_name}")
print(f"Nodes: {len(graph.nodes)}")
for node in graph.nodes:
    inputs_str = f"({', '.join(node.inputs)})" if node.inputs else "()"
    print(f"  {node.name} = {node.op_type}{inputs_str}")

# Fixed batch sizes produce different keys; a None batch dimension produces one key.
sig_a = InputSignature(shapes=[(128, 784)], dtypes=["float32"])
sig_b = InputSignature(shapes=[(64, 784)], dtypes=["float32"])
sig_c = InputSignature(shapes=[(None, 784)], dtypes=["float32"])

print("\nSignature caching:")
print(f"  batch=128: {sig_a.cache_key}")
print(f"  batch=64:  {sig_b.cache_key}")
print(f"  Same key? {sig_a.cache_key == sig_b.cache_key} → retrace needed")
print(f"  With None: {sig_c.cache_key} → no retrace for batch changes")

# Pass names mirror the Grappler table above (illustrative labels, not config keys).
grappler_passes = [
    "constant_folding",
    "common_subexpression_elimination",
    "layout_optimization",
    "op_fusion",
    "pruning",
]
print(f"\nGrappler passes: {len(grappler_passes)}")
for pass_name in grappler_passes:
    print(f"  {pass_name}")
```

**References:**

[📘 TensorFlow. *Introduction to graphs and tf.function.*](https://www.tensorflow.org/guide/intro_to_graphs)

[📘 TensorFlow. *Better performance with tf.function.*](https://www.tensorflow.org/guide/function)

---

[⬅️ Previous: JAX Compilation](../01_ML_Compiler_Ecosystem/03_jax_compilation.ipynb) | [Next: JAX Just-in-Time Compilation ➡️](./02_jax_just_in_time_compilation.ipynb)