Skip to content

Conversation

@mariogeiger
Copy link
Collaborator

Make JAX implementation of trimul able to load the same cache files used in the counterpart pytorch code.

dtypes = [
str(t.dtype if t.dtype != jnp.bfloat16 else jnp.dtype(jnp.float16))
for t in [x1, x2, w1, w2, b1, b2, mask]
"torch." + str(t.dtype if t.dtype != jnp.bfloat16 else jnp.dtype(jnp.float16))
Copy link
Collaborator

@hsadasivan hsadasivan Sep 16, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should remove the 'torch' sometime everywhere, I'd change it everywhere when we have more time.
For now, maybe good to leave this as comment:
We currently perform tuning in torch. Leaving the same annotation here as a lazy implementation.

@mariogeiger mariogeiger merged commit 7aaaeb4 into main Sep 19, 2025
13 of 14 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants