In [1]:
import sys
sys.path.append("..")

import jax
from jax.tree_util import register_pytree_node_class
from utils.test_utils import *
from functools import partial
key = Key(1234)

In [8]:
@register_pytree_node_class
class class1:
    def __init__(self,x,y):
        self.x = x
        self.y = y

    @classmethod
    def init(cls):
        return cls(jax.random.normal(key(),(1000,1000)),jax.random.normal(key(),(1000,1000)))
    
    def tree_flatten(self):
        children = (self.x,)
        aux_data = {"y":self.y}
        return (children,aux_data)
    
    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children, **aux_data)
    
    @jax.jit
    def test(self):
        return self.x@self.y
    
    def step(self):
        self.x = self.test()

@register_pytree_node_class
class class2:
    def __init__(self,c1:class1,x,y):
        self.c1 = c1
        self.x = x
        self.y = y
    
    @classmethod
    def init(cls,c1):
        return cls(c1,jax.random.normal(key(),(1000,1000)),jax.random.normal(key(),(1000,1000)))
    
    def tree_flatten(self):
        children = (self.c1,self.x,)
        aux_data = {"y":self.y}
        return (children,aux_data)
    
    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children, **aux_data)

    @jax.jit
    def temp(self):
        return self.x@self.c1.x
    
    def step(self):
        self.c1.x = self.temp()

    @jax.jit
    def temp2(self):
        def f():
            print(self.x)
            return self.x
        return f()

In [9]:
c = class2.init(class1.init())
c.temp2()

Traced<ShapedArray(float32[1000,1000])>with<DynamicJaxprTrace(level=1/0)>


Array([[ 0.00847632,  1.7395164 ,  0.96435696, ..., -0.23694572,
         1.2518747 ,  1.5256135 ],
       [ 0.9652041 ,  1.1905515 ,  1.1188512 , ..., -0.9264909 ,
        -1.3184824 , -0.6343539 ],
       [-0.25975093,  1.77774   ,  1.6337568 , ...,  0.13074388,
         0.5207169 , -1.0848312 ],
       ...,
       [ 0.3411899 , -0.40735915,  1.2093589 , ..., -0.27201095,
        -0.49648887, -0.98410124],
       [-0.6393662 ,  0.02730814,  0.8605588 , ..., -1.8695649 ,
         1.6408234 ,  0.419339  ],
       [ 0.41409856, -0.28595927,  0.25655448, ..., -1.1911238 ,
         0.6019864 , -1.5955957 ]], dtype=float32)

In [10]:
c.x = c.temp()
c.temp2()

Array([[-3.1266510e+01, -1.1488864e+01, -4.4631664e+01, ...,
         3.1016073e+01, -3.0652786e+01,  1.0896491e+01],
       [-4.5138502e+00,  2.3635082e+01,  2.3400490e+01, ...,
        -1.0833298e+01,  2.8219080e+01, -9.0148420e+00],
       [-1.6158099e+00, -5.8236179e+01, -9.8671112e+00, ...,
        -2.4166967e+01, -2.9759541e+00, -1.5821617e+01],
       ...,
       [ 5.6655121e+00,  3.4643116e+01,  1.2729847e+01, ...,
        -1.7660484e+01,  4.5824150e+01, -2.3327166e+01],
       [ 1.4432661e+01, -4.1831314e+01, -1.3267406e+01, ...,
        -5.3859115e-01,  5.6154218e+00,  3.8598831e+01],
       [ 2.3744148e+01, -2.9354095e-03, -3.9674115e+00, ...,
         2.2480003e+01,  1.1760723e+01, -2.7548183e+01]], dtype=float32)

In [31]:
c.step()
c.temp()

Array([[ 4.5768850e+05,  7.5048788e+05,  2.9285547e+03, ...,
         7.3159406e+05,  8.4202450e+05, -2.9815005e+06],
       [ 9.9156181e+05, -1.0399205e+06,  2.6114994e+05, ...,
         1.0136124e+06,  1.8926330e+06, -6.9653444e+05],
       [ 2.5574308e+06, -2.3492120e+06, -2.5507122e+05, ...,
         8.1924062e+04, -4.7874184e+05,  1.5777668e+06],
       ...,
       [ 6.1556019e+05,  1.0859738e+06,  2.5609278e+05, ...,
        -1.0557076e+06,  5.9958350e+05,  1.8305288e+05],
       [ 1.2653528e+06, -2.0992369e+05, -4.4531056e+05, ...,
        -2.9983945e+04, -1.0683095e+06,  7.4259875e+05],
       [ 1.1651331e+06,  1.5011408e+06,  3.3231381e+05, ...,
        -1.6507364e+06, -3.5327481e+05,  5.9389550e+05]], dtype=float32)

In [5]:
c = class1.init()
c.x

Array([[-0.93353075,  1.7326977 ,  1.719694  , ..., -0.3498142 ,
        -0.2929668 ,  0.38068113],
       [ 0.33971688,  0.04780002,  0.9801366 , ..., -1.2653811 ,
         0.47628784,  0.7499514 ],
       [ 0.01225322,  0.8905399 , -0.35550836, ..., -1.6039928 ,
        -0.1824435 ,  0.46025038],
       ...,
       [-0.1727683 ,  0.3482741 ,  0.6989034 , ...,  0.44682515,
         0.9915858 , -0.15785185],
       [-0.24816576, -0.2779053 , -1.5229701 , ..., -0.5056279 ,
        -0.82986337,  0.8726066 ],
       [-1.2603881 ,  0.13645144,  1.2452022 , ...,  0.05471618,
        -0.2707411 , -0.8478249 ]], dtype=float32)

In [6]:
c.step()
c.x

Array([[  98.14861  ,   65.43248  ,   28.060387 , ...,  -49.934658 ,
           7.3294187,   34.67536  ],
       [-140.70764  ,  -94.06485  ,   30.446339 , ...,   93.30864  ,
        -149.92651  ,   26.446575 ],
       [ -69.37301  ,  106.9172   ,   99.633026 , ...,  -28.965012 ,
           1.453764 , -185.7376   ],
       ...,
       [ -16.044582 ,  -17.878666 , -164.36652  , ..., -125.0126   ,
         152.71463  , -101.07814  ],
       [-169.5637   ,   14.349819 ,   89.201614 , ...,   50.865486 ,
         -44.28895  ,  -63.07538  ],
       [  38.69707  ,  167.0109   , -129.59023  , ...,  -61.62404  ,
          19.643017 ,   23.720064 ]], dtype=float32)

In [32]:
@jax.jit
def outside(x,y):
    return x@y
outside(c.x,c.y)

Array([[-33.631756 , -22.025528 ,   6.129885 , ..., -59.38269  ,
         -4.4846025, -61.29505  ],
       [ 64.205215 ,  12.65848  , -38.906532 , ..., -10.707792 ,
         -5.894917 ,   3.3090916],
       [-30.926172 , -17.672123 ,  26.825018 , ...,   4.323929 ,
         57.959896 , -10.966793 ],
       ...,
       [-25.70777  , -10.58358  ,  -7.1986876, ...,  77.43709  ,
         10.047129 , -28.002825 ],
       [  9.16236  ,  37.486954 ,   2.790567 , ..., -47.4012   ,
        -18.492418 ,  29.89198  ],
       [ 21.791706 ,  16.047516 , -46.4171   , ...,   7.6332836,
         13.882181 , -71.61071  ]], dtype=float32)

In [38]:
%timeit -n100 -r10 c.step()
%timeit -n100 -r10 outside(c.x,c.c1.x)

4.27 ms ± 230 μs per loop (mean ± std. dev. of 10 runs, 100 loops each)
4.21 ms ± 169 μs per loop (mean ± std. dev. of 10 runs, 100 loops each)


In [44]:
def f(x,y):
    return x@y

@partial(jax.jit,static_argnums=(1,))
def vmapped(x,y):
    return jax.vmap(f,in_axes=(0,0))(x,y)

In [45]:
x = jax.random.normal(key(),(10,1000,1000))
y = jax.random.normal(key(),(10,1000,1000))
vmapped(x,y)

ValueError: Non-hashable static arguments are not supported. An error occurred while trying to hash an object of type <class 'jaxlib.xla_extension.ArrayImpl'>, [[[ 1.03704584e+00 -1.04108131e+00 -2.29338020e-01 ... -2.38400340e+00
   -5.75253308e-01  2.38139105e+00]
  [ 1.00051832e+00 -2.60955215e-01 -1.19654369e+00 ...  2.82352269e-01
   -3.58326249e-02  1.04465628e+00]
  [ 1.37149692e-01  2.34273955e-01  9.05101299e-01 ... -3.92438322e-01
    1.76136756e+00  4.62761782e-02]
  ...
  [ 1.92702468e-02  6.48455799e-01 -6.47387028e-01 ... -1.74105763e+00
   -1.47042260e-01 -1.12431455e+00]
  [-2.00012222e-01 -6.61116183e-01 -2.93076009e-01 ... -4.39528853e-01
    1.91518739e-01 -1.57084501e+00]
  [ 8.02728772e-01  1.05269730e+00 -9.09205258e-01 ... -1.06039679e+00
   -4.78373915e-01 -9.13119037e-03]]

 [[ 2.27715030e-01  2.03399211e-01  2.95398623e-01 ... -2.70673871e-01
   -1.58086729e+00 -1.22985733e+00]
  [ 1.65342033e+00 -1.25151396e+00  2.38321280e+00 ... -4.58790988e-01
    1.45490646e+00  8.63270462e-01]
  [ 4.61011901e-02  9.07626390e-01 -2.02650204e-01 ...  1.63067305e+00
    9.79223922e-02  7.56201863e-01]
  ...
  [-2.28644013e+00 -8.55249166e-01 -6.86759889e-01 ...  1.86622277e-01
   -5.42330325e-01 -1.50007039e-01]
  [-6.32356927e-02  2.74241567e-01 -1.35291553e+00 ...  5.12543023e-01
    7.44169414e-01 -6.31853223e-01]
  [ 6.30225658e-01  2.27782190e-01 -1.11524534e+00 ...  4.08963054e-01
   -9.43428516e-01 -1.38390493e+00]]

 [[ 5.00866652e-01 -1.27581079e-02  4.56214547e-01 ... -7.02286720e-01
   -8.36087227e-01  9.64239955e-01]
  [ 9.47768807e-01 -2.31369781e+00 -5.63923240e-01 ... -1.03626454e+00
    3.65863770e-01 -4.24644798e-01]
  [ 1.43890691e+00 -1.25726080e+00  5.78677237e-01 ...  4.65313420e-02
   -2.13044381e+00 -1.95093679e+00]
  ...
  [-3.46660256e-01  4.46816951e-01  1.31057501e+00 ... -2.11818433e+00
   -4.14184660e-01 -1.05553314e-01]
  [ 1.37663293e+00  1.03607881e+00  7.69723594e-01 ...  8.90806854e-01
   -1.77171245e-01 -1.18582153e+00]
  [ 1.73965347e+00 -2.06052244e-01 -5.42385697e-01 ... -7.94219553e-01
    6.62151396e-01 -5.04332036e-02]]

 ...

 [[-1.63147914e+00 -1.15416542e-01 -2.81382650e-01 ...  3.41179132e-01
   -1.62080944e+00  1.14577763e-01]
  [-9.93992835e-02  9.16982055e-01 -1.79580927e+00 ... -5.56136072e-01
    5.05208671e-01  7.73912311e-01]
  [ 4.29427743e-01 -9.04865682e-01 -7.32284725e-01 ...  5.44249177e-01
    1.39648274e-01  4.24285948e-01]
  ...
  [-2.58523792e-01 -7.02559352e-01 -1.56759113e-01 ... -1.86613441e+00
    5.93415916e-01 -6.66194201e-01]
  [ 1.20303559e+00  3.22338700e-01  2.46804023e+00 ... -1.22824025e+00
   -7.69521073e-02 -7.10751891e-01]
  [-9.97448504e-01 -6.51679188e-03  1.16317439e+00 ...  5.19573987e-01
    5.86722791e-01  4.39802371e-03]]

 [[ 1.38665712e+00 -1.04791033e+00  1.03120156e-01 ...  4.81855333e-01
    9.47672784e-01 -4.65204298e-01]
  [ 2.06559563e+00 -3.52326483e-01  5.21771133e-01 ...  1.08644009e+00
   -1.78372872e+00 -1.02832425e+00]
  [ 1.40735412e+00 -5.63022077e-01  2.50414753e+00 ...  1.46373725e-02
    1.43189907e+00  1.47473085e+00]
  ...
  [-8.10471654e-01  1.72056258e+00  3.95932198e-01 ... -1.27685353e-01
   -1.05876672e+00 -6.83200836e-01]
  [-2.80211449e-01 -1.17289031e+00  1.41392767e+00 ... -1.72109110e-03
   -1.09696209e+00  1.18866837e+00]
  [ 1.21784723e+00  2.20031515e-01 -5.95208049e-01 ... -5.86371601e-01
    1.13640273e+00 -1.34611678e+00]]

 [[ 1.13317382e+00 -1.00422812e+00  1.28222239e+00 ...  9.18946087e-01
   -1.69928145e+00  5.48540235e-01]
  [ 6.71897754e-02  1.31334603e-01  7.83413827e-01 ...  6.10358357e-01
   -5.25806487e-01 -1.04243493e+00]
  [ 2.62155116e-01 -1.16374791e+00 -5.42242110e-01 ... -2.76293129e-01
   -4.22694325e-01  1.48544455e+00]
  ...
  [ 1.55089021e+00  2.08026981e+00 -8.32753778e-01 ...  2.83946007e-01
   -2.10967064e-01 -5.96374154e-01]
  [-5.34397960e-01  8.43855262e-01  4.56557125e-01 ...  5.15110970e-01
   -1.33779213e-01 -1.07159913e+00]
  [-5.44506550e-01 -1.58776784e+00 -1.38596308e+00 ...  1.03533947e+00
    2.14565679e-01 -1.39092314e+00]]]. The error was:
TypeError: unhashable type: 'jaxlib.xla_extension.ArrayImpl'
