In [4]:
import jax
import jax.numpy as jnp
from utils.ops import FeedForwardParams, feed_forward as feed_forward_jax
jax.config.update("jax_default_matmul_precision", "float32")

import torch
import torch_xla
import torch_xla.core.xla_model as xm
from tests.torch_ops import FeedForward as FeedForward_torch

import numpy as np

device = xm.xla_device()

In [5]:
# 1. Setup Parameters
bsz = 2
seqlen = 64
dim = 128
multiple_of = 32
dtype = np.float32

# Mimic hidden_dim calculation from PyTorch implementation
hidden_dim = int(2 * (4 * dim) / 3)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

# 2. Create shared weights and inputs
np.random.seed(0)
x_np = np.random.randn(bsz, seqlen, dim).astype(dtype)

# PyTorch names weights w1, w2, w3. JAX uses w1_gate, w2_up, w3_down.
# Mapping: torch.w1 -> jax.w1_gate, torch.w3 -> jax.w2_up, torch.w2 -> jax.w3_down
w1_np = np.random.randn(dim, hidden_dim).astype(dtype) # gate_proj
w3_np = np.random.randn(dim, hidden_dim).astype(dtype) # up_proj
w2_np = np.random.randn(hidden_dim, dim).astype(dtype) # down_proj
print("variables initialized")

# 3. JAX setup
x_jax = jnp.array(x_np)
jax_params = FeedForwardParams(
w1_gate=jnp.array(w1_np),
w2_up=jnp.array(w3_np),
w3_down=jnp.array(w2_np)
)
output_jax = feed_forward_jax(x_jax, jax_params, activation_fn='silu')
print("jax output computed")

# 4. PyTorch setup
x_torch = torch.tensor(x_np, device=device)
torch_ff = FeedForward_torch(dim=dim, hidden_dim=4 * dim, multiple_of=multiple_of, ffn_dim_multiplier=None)
torch_ff.w1.weight = torch.nn.Parameter(torch.tensor(w1_np.T, device=device))
torch_ff.w3.weight = torch.nn.Parameter(torch.tensor(w3_np.T, device=device))
torch_ff.w2.weight = torch.nn.Parameter(torch.tensor(w2_np.T, device=device))
output_torch = torch_ff(x_torch)
print("torch output computed")

# 5. Compare
#np.testing.assert_allclose(np.array(output_jax), output_torch.detach().cpu().numpy(), rtol=1e-4, atol=5e-3) 

variables initialized
jax output computed
torch output computed


In [6]:
print(output_jax)

[[[ 1712.5023    1537.1199     219.94836  ...  1210.542     -517.8104
    1957.5632  ]
  [  720.4284     564.9929    1796.5664   ... -2025.019     1703.8607
    1721.794   ]
  [ -888.39954   -709.58875  -1181.3383   ... -2422.6638    -421.4978
    1032.3291  ]
  ...
  [-2821.6235    -626.46783    510.15985  ...  -265.06628    560.21924
    1006.6872  ]
  [-2590.8398     302.2268   -1411.2493   ...   410.38287  -1112.2267
   -1241.848   ]
  [  372.734      959.9962   -2937.226    ... -3526.522     1063.862
   -2356.3728  ]]

 [[-1457.0582    1140.365        9.848694 ...  3753.942    -1990.3994
     710.10736 ]
  [-1281.9932     480.12286    -14.375618 ...   834.6818    -536.7389
   -1800.0834  ]
  [  412.83234   1287.3164    1074.1871   ...   227.30594  -1567.2417
    1311.1067  ]
  ...
  [ 1935.8376    3122.3125   -1176.6947   ... -2616.458      617.2828
    1723.2715  ]
  [  671.13477  -2396.4302    2321.6958   ...   687.7973    1850.6819
   -1793.8928  ]
  [-2118.3052   -1424.027    

In [None]:
[[[ 1727.29       1533.9014      210.12515   ...  1214.3501
    -517.1265     1966.4995   ]
  [  714.232       559.24133    1806.5352    ... -2020.8241
    1712.354      1723.6455   ]
  [ -883.953      -709.4646    -1173.58      ... -2415.8113
    -417.66623    1023.7455   ]
  ...
  [-2821.728      -618.63855     513.56104   ...  -263.049
     551.41956     995.66113  ]
  [-2590.9634      300.09814   -1415.0653    ...   420.91208
   -1106.2332    -1233.032    ]
  [  365.25873     957.16455   -2939.5374    ... -3539.198
    1062.811     -2341.3174   ]]

 [[-1453.0334     1150.6702        4.0286255 ...  3759.0645
   -1990.8007      712.08453  ]
  [-1276.685       479.71725     -21.86324   ...   830.7525
    -537.3478    -1791.3425   ]
  [  412.71432    1286.6497     1072.5814    ...   218.66956
   -1575.5161     1316.7893   ]
  ...
  [ 1934.4199     3132.8506    -1170.0195    ... -2615.2393
     618.4        1727.3142   ]
  [  673.20856   -2402.7559     2321.2598    ...   677.0799
    1856.0294    -1793.9155   ]
  [-2118.0781    -1413.2001    -3229.768     ... -1631.8623
     411.35004    1051.3274   ]]]

In [None]:
print(output_torch.cpu())



In [None]:
import torch
import torch_xla
import torch_xla.core.xla_model as xm

dev = xm.xla_device()

x1 = torch.rand((3, 3)).to(dev)
x2 = torch.rand((3, 8)).to(dev)

y1 = torch.einsum('bs,st->bt', x1, x2)
print(torch_xla._XLAC._get_xla_tensors_text([y1]))