In [1]:
import torch
from torch_model import ModelArgs, Transformer
from model import Transformer as JaxTransformer
import pax
from torch_model import precompute_freqs_cis as torch_precompute_freqs_cis
from model import precompute_freqs_cis
import numpy as np
import jax
from functools import partial

In [2]:
# Shared constants so the model's vocabulary size and the range of sampled
# token ids cannot drift apart, and so one seed controls the whole cell.
VOCAB_SIZE = 16
SEED = 40

# Seed torch's global RNG *before* building the model. The token ids were
# already drawn from a seeded generator, but the model itself was built
# unseeded, so `target` changed on every Restart & Run All.
# (Assumes Transformer uses the standard torch initialisers, which draw
# from the global RNG - confirm against torch_model.)
torch.manual_seed(SEED)

model_args = ModelArgs(dim=512, n_layers=3, n_heads=4, vocab_size=VOCAB_SIZE)
net = Transformer(model_args)

# A dedicated generator keeps the sampled tokens independent of the global RNG
# state (same generator seed as before, so the tokens are unchanged).
tokens = torch.randint(
    0, VOCAB_SIZE, (1, 5), generator=torch.Generator().manual_seed(SEED)
)

# Reference output of the PyTorch model: batch element 0, last position,
# all entries of the final axis. Call the module rather than `.forward`
# directly so that any registered hooks are honoured.
target = net(tokens, 0)[0, -1, :]

In [3]:
# Compute the same table with the PyTorch and the JAX implementation of
# precompute_freqs_cis, using identical arguments.
a = torch_precompute_freqs_cis(10, 128).numpy()
b = precompute_freqs_cis(10, 128)
# The two implementations are required to agree exactly (not just within a
# tolerance). `np.array_equal` already returns a boolean, so no `== True`.
assert np.array_equal(a, b), "PyTorch and JAX precompute_freqs_cis outputs differ"

In [4]:
@pax.pure
def load_weight(jax_net):
    """Copy every tensor of the global PyTorch ``net`` into ``jax_net``.

    The state-dict keys of ``net`` are split on ``"."`` and mapped onto
    attributes of ``jax_net``:

    * ``layers.<i>.<block>.<attr>.weight`` for ``<block>`` in
      ``attention`` / ``feed_forward`` -> ``jax_net.layers[i].<block>.<attr>``
    * ``layers.<i>.<block>.weight`` otherwise -> ``jax_net.layers[i].<block>``
    * ``<name>.weight`` -> ``jax_net.<name>``

    Args:
        jax_net: module whose ``.weight`` attributes are overwritten.

    Returns:
        ``jax_net`` with each matched ``.weight`` replaced by the NumPy view
        of the corresponding PyTorch tensor.

    Raises:
        AssertionError: if a key does not end in ``"weight"`` or if the
            destination weight's shape differs from the source tensor's.

    NOTE(review): relies on the notebook-level ``net`` rather than taking it
    as a parameter, and on ``pax.pure`` permitting attribute assignment on
    the module inside this function - confirm against the pax documentation.
    """
    nested_blocks = ("attention", "feed_forward")

    for key, tensor in net.state_dict().items():
        path = key.split(".")
        # Only parameters named "weight" are expected in the state dict.
        assert path[-1] == "weight"

        if path[0] != "layers":
            # Top-level module, e.g. "<name>.weight".
            target_mod = getattr(jax_net, path[0])
        else:
            # "layers.<index>.<block>..." - pick the block of that layer.
            target_mod = getattr(jax_net.layers[int(path[1])], path[2])
            if path[2] in nested_blocks:
                # These blocks hold their weights one level deeper.
                target_mod = getattr(target_mod, path[3])

        assert target_mod.weight.shape == tensor.shape
        target_mod.weight = tensor.numpy()

    return jax_net

# Build the JAX model from the same ModelArgs as the PyTorch model and copy the
# PyTorch weights into it, so both networks hold identical weights.
jax_net = load_weight(JaxTransformer(model_args))

In [5]:
# NOTE: JIT compilation is intentionally left off here; the original
# `@partial(jax.jit)` decorator was disabled by the author.
def inference(net, tokens, pos):
    """Run ``net`` on ``tokens`` at position ``pos`` through ``pax.purecall``.

    Args:
        net: the module to call (forwarded to ``pax.purecall`` unchanged).
        tokens: first positional argument passed to the module call.
        pos: second positional argument passed to the module call.

    Returns:
        A ``(module, output)`` pair exactly as produced by ``pax.purecall``.
    """
    updated_net, result = pax.purecall(net, tokens, pos)
    return updated_net, result

In [6]:
# Feed the JAX model the same token ids as the PyTorch model (converted to a
# NumPy array) at position 0; `jax_net` is rebound to the module that
# `inference` returns alongside the output.
jax_net, output = inference(jax_net, tokens.numpy(), 0)

In [7]:
# Same slice as the PyTorch reference `target`: batch element 0, last position.
logit = output[0, -1, :]
# Fail loudly if the JAX port drifts from the PyTorch reference instead of
# relying on eyeballing the two printed arrays. The tolerances leave room for
# float32 rounding differences between the two frameworks.
np.testing.assert_allclose(np.asarray(logit), target.numpy(), rtol=1e-5, atol=1e-5)

In [8]:
# Display the PyTorch reference values (last-position slice computed in cell 2).
target.numpy()

array([-0.7551486 ,  0.17903721, -0.20968904, -0.4628146 ,  0.7097525 ,
       -1.4803946 , -0.5472836 ,  0.08360268,  0.8104697 , -0.545269  ,
       -0.11844821,  0.10768881, -0.07461266, -0.5413903 ,  0.5620457 ,
        0.3828546 ], dtype=float32)

In [9]:
# Display the JAX model's values for the same slice, for comparison with the
# PyTorch reference printed by the previous cell.
logit

Array([-0.7551489 ,  0.17903697, -0.20968904, -0.4628147 ,  0.70975274,
       -1.4803945 , -0.54728377,  0.08360252,  0.81046987, -0.545269  ,
       -0.11844828,  0.10768889, -0.07461257, -0.5413905 ,  0.562046  ,
        0.3828542 ], dtype=float32)