# In [None]:
import torch
import torch.nn.functional as F
import onnx


@torch.jit.script
def if_conv_flatten(
    x: torch.Tensor, w: torch.Tensor, b: torch.Tensor, condition: bool
) -> torch.Tensor:
    """
    Convolve ``x`` with kernel ``w`` and bias ``b`` (stride 1, padding 1), then branch.

    The function is compiled with ``torch.jit.script`` rather than traced, so the
    ``if condition:`` below stays real control flow (a ``prim::If``) in the graph.

    Args:
        x: input to ``F.conv2d``.
        w: convolution weight passed straight to ``F.conv2d``.
        b: convolution bias passed straight to ``F.conv2d``.
        condition: selects which branch produces the result.

    Returns:
        If ``condition`` is true, the conv output viewed with its first two
        dimensions kept and all remaining dimensions merged into one.
        Otherwise, a fresh copy (``clone``) of the conv output.
    """
    feat = F.conv2d(x, w, b, stride=1, padding=1)
    if condition:
        # then-branch: keep dims 0 and 1, collapse every trailing dim into one
        batch, channels = feat.size(0), feat.size(1)
        out = feat.view(batch, channels, -1)
    else:
        # else-branch: identity, but as a new tensor rather than an alias
        out = feat.clone()
    return out


# 1) Dummy inputs in NCHW layout: x is (batch=2, channels=1, 5, 5);
#    w is a single 1->1 channel 3x3 kernel and b its matching bias.
x = torch.randn(2, 1, 5, 5, dtype=torch.float32)
w = torch.randn(1, 1, 3, 3, dtype=torch.float32)
b = torch.randn(1, dtype=torch.float32)

# 2) Export to ONNX. The If node survives because if_conv_flatten is
#    *scripted* (not traced): its `if condition:` is kept as control flow.
#    "condition" is named in input_names, so it is expected to be exported as a
#    graph input. The True below is only the example value used for export.
model_path = "model_if.onnx"
torch.onnx.export(
    if_conv_flatten,
    (x, w, b, True),
    model_path,
    opset_version=14,  # newer opsets (e.g. 20) may work, depending on the installed exporter
    input_names=["x", "w", "b", "condition"],
    output_names=["output"],
    dynamic_axes={"x": {0: "batch"}, "output": {0: "batch"}},  # batch size may vary at runtime
    do_constant_folding=False,  # keep the graph as exported; don't let folding simplify it
)
onnx_model = onnx.load(model_path)

# 3) Verify the claim above instead of assuming it: the exported top-level
#    graph must contain the If node this example exists to produce.
if not any(node.op_type == "If" for node in onnx_model.graph.node):
    raise RuntimeError(f"Expected an If node in {model_path}, but none was found")