In [1]:
!pip install juliapkg
import juliapkg
juliapkg.require_julia("=1.10.8")
juliapkg.resolve()
!pip install dm-haiku juliacall



True

In [5]:
import os
os.environ["JAX_PLATFORMS"]="cpu"
import jax
import haiku

def forward(x):
   """Apply a Haiku MLP with layer sizes 30, 20 and 10 to ``x`` and return its output."""
   return haiku.nets.MLP([30, 20, 10])(x)

# Wrap the pure function into an init/apply pair whose apply takes no RNG key.
forward = haiku.without_apply_rng(haiku.transform(forward))
rng = haiku.PRNGSequence(jax.random.PRNGKey(42))
x = jax.numpy.ones([8, 28 * 28])
params = forward.init(next(rng), x)

# Lower the jitted apply function for these concrete arguments and keep the
# textual form of the lowered module.
lowered = jax.jit(forward.apply).lower(params, x)
hlo_code = lowered.as_text()

print(hlo_code)

module @jit_apply_fn attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<30xf32>, %arg1: tensor<784x30xf32>, %arg2: tensor<20xf32>, %arg3: tensor<30x20xf32>, %arg4: tensor<10xf32>, %arg5: tensor<20x10xf32>, %arg6: tensor<8x784xf32>) -> (tensor<8x10xf32> {jax.result_info = ""}) {
    %0 = stablehlo.dot_general %arg6, %arg1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<8x784xf32>, tensor<784x30xf32>) -> tensor<8x30xf32>
    %1 = stablehlo.broadcast_in_dim %arg0, dims = [1] : (tensor<30xf32>) -> tensor<8x30xf32>
    %2 = stablehlo.add %0, %1 : tensor<8x30xf32>
    %3 = call @relu(%2) : (tensor<8x30xf32>) -> tensor<8x30xf32>
    %4 = stablehlo.dot_general %3, %arg3, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<8x30xf32>, tensor<30x20xf32>) -> tensor<8x20xf32>
    %5 = stablehlo.broadcast_in_dim %arg2, dims = [1] : (tensor<20xf32>) -> tensor<8x20xf32>
    %6 = stablehlo.add %4, %

In [38]:
from juliacall import Main as Julia # juliacall must be imported before Torch

# Define the Julia-side helpers used to dump the model to disk.
Julia.seval("""
# Make PyArray / PyDict available here without relying on an earlier cell
# (or on juliacall itself) having brought PythonCall's exports into Main.
using PythonCall
using Serialization

# Leaf case: materialize a wrapped Python array as a native Julia Array.
to_namedtuple(x::PyArray) = Array(x)

# Recursive case: turn a (possibly nested) Python dict into a NamedTuple
# whose fields are the stringified dict keys.
function to_namedtuple(dict::PyDict)
    # notice that we sort dictionary keys
    # this is consistent with observed JAX behavior
    k = map(string, dict |> keys |> collect) |> sort
    NamedTuple(Symbol(key) => to_namedtuple(dict[key]) for key in k)
end

# Bundle the program text, parameters, gradients and input into one NamedTuple
# and write it to `filename` with the Serialization stdlib (read it back with
# Serialization.deserialize). Returns whatever Serialization.serialize returns.
function save_model(filename, code, params, x, grads)
    model = (code=string(code),
             params=to_namedtuple(params),
             grads=to_namedtuple(grads),
             x=Array(x))
    return Serialization.serialize(string(filename), model)
end
""")

def loss(params, x):
    """Return the sum of the squared outputs of ``forward.apply(params, x)``."""
    outputs = forward.apply(params, x)
    squared = outputs * outputs
    return squared.sum()

l = loss(params,x)
grads = jax.grad(loss)(params, x) # gradient wrt params only

Julia.save_model("small.jld", hlo_code, params, x, grads)
!ls -l *.jld

-rw-rw-r-- 1 dubos dubos 1093924 mars   9 22:43 MLP.jld
-rw-rw-r-- 1 dubos dubos  222529 mars  12 14:15 small.jld


{'mlp/~/linear_0': {'b': Array([ 0.        , -1.2139146 ,  0.27003068,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  2.5587885 ,  0.        ,
          2.8276005 ,  0.        , -0.9441228 , -0.48282346, -0.6649508 ,
         -0.00763107,  1.3438833 , -1.3073275 ,  2.2923863 ,  0.        ,
          0.5418246 ,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        , -0.27992308,  0.        ,  0.        ,  0.        ],      dtype=float32),
  'w': Array([[ 0.        , -1.2139146 ,  0.27003068, ...,  0.        ,
           0.        ,  0.        ],
         [ 0.        , -1.2139146 ,  0.27003068, ...,  0.        ,
           0.        ,  0.        ],
         [ 0.        , -1.2139146 ,  0.27003068, ...,  0.        ,
           0.        ,  0.        ],
         ...,
         [ 0.        , -1.2139146 ,  0.27003068, ...,  0.        ,
           0.        ,  0.        ],
         [ 0.        , -1.2139146 ,  0.27003068, ...,  0.        ,
         