From 9c69e5ad35a1a2a9b0adfb17c87d0388d622c4d2 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Fri, 21 Mar 2025 09:29:06 -0700 Subject: [PATCH] import te before te_jax Signed-off-by: Phuong Nguyen --- examples/jax/encoder/common.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/jax/encoder/common.py b/examples/jax/encoder/common.py index 2785deac0c..ea6de73b34 100644 --- a/examples/jax/encoder/common.py +++ b/examples/jax/encoder/common.py @@ -4,6 +4,7 @@ """Shared functions for the encoder tests""" from functools import lru_cache +import transformer_engine from transformer_engine_jax import get_device_compute_capability