#### Bug report: `mjx.ray` distance scales with the length of `vec`

In [12]:
import mujoco
from mujoco import mjx
from functools import partial
import jax 
import jax.numpy as jp

# Minimal MJCF scene for the mjx.ray reproduction:
#   * "obst": a capsule of radius 0.1 whose axis runs from (0, 2, 0) to (0, 2, 1),
#     so a ray from the origin along +y should reach its surface at y = 1.9.
#   * "ball": a sphere of radius 0.1 in a body placed at z = 0.2.
# The contype/conaffinity bits (ball contype=1, obstacle conaffinity=1) let the
# ball collide with the obstacle but not each geom with its own kind.
# NOTE(review): the model is named "2 dof ball" but the ball body declares no
# joints, so the model has no degrees of freedom — confirm intent.
xml_2dof = """
<mujoco model="2 dof ball">
    <compiler angle="radian"/>
	<option timestep="0.002" iterations="1" ls_iterations="4" solver="Newton" gravity="0 0 -9.81">
		<flag eulerdamp="disable"/>
	</option>
    <worldbody>
        <geom name="obst" size="0.1" type="capsule" fromto="0 2 0 0 2 1" conaffinity="1" contype="0"/>
        <body name="ball" pos="0 0 0.2">
            <geom name="ball_geom" size="0.1" type="sphere" conaffinity="0" contype="1"/>
        </body>
    </worldbody>
    <asset>
        <texture builtin="flat" height="1278" mark="cross" markrgb="1 1 1" name="texgeom" random="0.01" rgb1="0.8 0.6 0.4" rgb2="0.8 0.6 0.4" type="cube" width="127"/>
        <material name="geom" texture="texgeom" texuniform="true"/>
    </asset>
</mujoco>
"""

# Build the MuJoCo model and mirror it, plus its forward-evaluated data, onto MJX.
model = mujoco.MjModel.from_xml_string(xml_2dof)

mx = mjx.put_model(model)
data = mujoco.MjData(model)
mujoco.mj_forward(model, data)  # run forward so derived state is filled before transfer
dx = mjx.put_data(model, data)

# Cast rays from the world origin; only `vec` varies between calls.
test_func = jax.jit(partial(mjx.ray, m=mx, d=dx, pnt=jp.zeros(3)))

# Use float vectors for both calls. An integer array (jp.array([0, 1, 0])) has a
# different dtype from its scaled float copy, so jit would trace the function
# twice, and the ray routine would be handed integer input.
normalised_vec = jp.array([0.0, 1.0, 0.0])
unnormalised_vec = normalised_vec / 10  # same direction, length 0.1

# Index 0 of the result is the reported distance.
# NOTE(review): MuJoCo's mj_ray defines the ray as pnt + x * vec (x >= 0) and
# returns x, i.e. a value in units of |vec| rather than metres. The two results
# below differ by exactly 1 / |vec| = 10 (1.9 vs 19.0), consistent with that
# contract — confirm whether mjx.ray intends to follow it before filing as a bug.
print(f"Obstacle distance normalised vec: {test_func(vec=normalised_vec)[0]}")
print(f"Obstacle distance unnormalised vec: {test_func(vec=unnormalised_vec)[0]}")

Obstacle distance normalised vec: 1.9000000953674316
Obstacle distance unnormalised vec: 18.999998092651367
