In [1]:
import sys

# Extend the import search path with two directories given relative to the
# kernel's working directory (presumably the repository root and a sibling
# `mlstm_simple_torch` checkout that provides `mlstm_simple` — confirm).
# Entries are only appended when missing, so re-running this cell does not
# keep adding duplicates to sys.path.
for extra_path in ("../..", "../../../mlstm_simple_torch"):
    if extra_path not in sys.path:
        sys.path.append(extra_path)

In [2]:
from pathlib import Path

import numpy as np
import torch
from mlstm_simple.from_pretrained import load_from_pretrained
from mlstm_simple.model import mLSTM, mLSTMConfig

In [3]:
# Open the reference archive from the working directory; the next cell reads
# the "logits_jax" and "inputs" entries from it.
logits_inputs_jax = np.load("./logits_inputs_jax.npz")

In [4]:
# Pull the two arrays used below out of the archive.
# NOTE(review): presumably "logits_jax" holds the JAX model's logits for "inputs" — confirm.
logits_jax = logits_inputs_jax["logits_jax"]
inputs_jax = logits_inputs_jax["inputs"]

In [5]:
# JAX checkpoint passed as --checkpoint_dir to the conversion command built below.
# This is an absolute, machine-specific path; adjust it when running elsewhere.
JAX_CHECKPOINT_PATH = "/nfs-gpu/xlstm/logs/outputs/xlstm-jax/DCLM/dclm_mLSTMv1_1.3B_ctx8192_2024-11-19T09:24:50/0/checkpoints/checkpoint_95000"

In [6]:
# Location of the converted torch checkpoint: "mlstm_simple_checkpoint",
# resolved to an absolute path from the current working directory.
SAVE_TORCH_CHECKPOINT_AT = Path("mlstm_simple_checkpoint").resolve()
SAVE_TORCH_CHECKPOINT_AT

PosixPath('/nfs-gpu/users_home/beck/repos/xlstm-jax3/tests/models/mlstm_simple/mlstm_simple_checkpoint')

In [7]:
# Autocast settings for the torch forward pass (autocast is switched off here).
ENABLE_TORCH_AMP = False
TORCH_AMP_DTYPE = torch.float32
# Whether the model is wrapped in torch.compile before the forward pass.
USE_TORCH_COMPILE = True
# PyTorch warns that TensorFloat32 tensor cores are available for float32
# matrix multiplication but not enabled, and suggests this setting for better
# performance.
torch.set_float32_matmul_precision("high")

In [8]:
# ## Convert jax checkpoint to torch:
# This cell only assembles and prints the shell command; it does not run it.
# The relative script path and PYTHONPATH=. suggest it is meant to be run from
# the repository root — confirm.
command = " ".join(
    [
        "PYTHONPATH=.",
        "python scripts/convert_mlstm_checkpoint_jax_to_torch_simple.py",
        f'--checkpoint_dir "{JAX_CHECKPOINT_PATH}"',
        f'--output_path "{SAVE_TORCH_CHECKPOINT_AT}"',
        "--checkpoint_type plain",
    ]
)
print(command)

PYTHONPATH=. python scripts/convert_mlstm_checkpoint_jax_to_torch_simple.py --checkpoint_dir "/nfs-gpu/xlstm/logs/outputs/xlstm-jax/DCLM/dclm_mLSTMv1_1.3B_ctx8192_2024-11-19T09:24:50/0/checkpoints/checkpoint_95000" --output_path "/nfs-gpu/users_home/beck/repos/xlstm-jax3/tests/models/mlstm_simple/mlstm_simple_checkpoint" --checkpoint_type plain


In [9]:
# Load the converted torch checkpoint. The kernels for the chunkwise, sequence
# and step paths are selected by name, together with the chunk size.
model = load_from_pretrained(
    checkpoint_path=SAVE_TORCH_CHECKPOINT_AT,
    chunkwise_kernel_name="chunkwise--triton_xl_chunk",
    sequence_kernel_name="native_sequence__triton_step_fused",
    step_kernel_name="triton_fused",
    chunk_size=128,
)

In [10]:
# Display the module structure of the loaded model.
model

mLSTM(
  (embedding): Embedding(50304, 2048)
  (backbone): mLSTMBlockStack(
    (blocks): ModuleList(
      (0-23): 24 x mLSTMBlock(
        (norm_mlstm): RMSNorm()
        (mlstm_layer): mLSTMLayer(
          (q): Linear(in_features=2048, out_features=1024, bias=False)
          (k): Linear(in_features=2048, out_features=1024, bias=False)
          (v): Linear(in_features=2048, out_features=2048, bias=False)
          (ogate_preact): Linear(in_features=2048, out_features=2048, bias=False)
          (ogate_act_fn): Sigmoid()
          (igate_preact): Linear(in_features=2048, out_features=4, bias=True)
          (fgate_preact): Linear(in_features=2048, out_features=4, bias=True)
          (mlstm_backend): mLSTMBackend(mLSTMBackendConfig(chunkwise_kernel='chunkwise--triton_xl_chunk', sequence_kernel='native_sequence__triton_step_fused', step_kernel='triton_fused', mode='inference', chunk_size=128, return_last_states=False, autocast_kernel_dtype='bfloat16', eps=1e-06, inference_state_dtyp

In [11]:
# NOTE(review): this import would conventionally sit in the import cell at the top of the notebook.
from pprint import pprint

# Pretty-print the configuration attached to the loaded model.
pprint(model.config)

mLSTMConfig(embedding_dim=2048,
            num_heads=4,
            num_blocks=24,
            vocab_size=50304,
            use_bias=False,
            norm_eps=1e-06,
            norm_reduction_force_float32=True,
            add_out_norm=True,
            qk_dim_factor=0.5,
            v_dim_factor=1.0,
            mlstm_round_up_to_multiple_of=64,
            chunkwise_kernel='chunkwise--triton_xl_chunk',
            sequence_kernel='native_sequence__triton_step_fused',
            step_kernel='triton_fused',
            mode='inference',
            chunk_size=128,
            return_last_states=False,
            autocast_kernel_dtype='bfloat16',
            eps=1e-06,
            inference_state_dtype='float32',
            ffn_proj_factor=2.667,
            ffn_round_up_to_multiple_of=64,
            gate_soft_cap=15.0,
            output_logit_soft_cap=30.0)


In [12]:
# Move the model to the GPU.
model = model.to("cuda")
# The loaded config has return_last_states=False; flip it so the forward pass
# in the next cell can be unpacked into (logits, state).
# NOTE(review): the config is mutated after the model was constructed —
# presumably it is read at call time; confirm that this setting takes effect.
model.config.return_last_states = True

In [13]:
# Optionally compile the model. The result is bound to a separate name so that
# re-running this cell never wraps an already compiled module in torch.compile
# a second time.
model_for_eval = torch.compile(model) if USE_TORCH_COMPILE else model

# Forward pass on the reference inputs. Only the outputs are compared, so
# autograd tracking is disabled to avoid building a graph and keeping
# activations alive. Autocast stays inactive unless ENABLE_TORCH_AMP is set.
inputs_torch = torch.from_numpy(inputs_jax).to("cuda")
with torch.no_grad(), torch.autocast(device_type="cuda", dtype=TORCH_AMP_DTYPE, enabled=ENABLE_TORCH_AMP):
    logits_torch, state = model_for_eval(inputs_torch)

W1125 18:19:17.125000 139952932398912 torch/_dynamo/variables/tensor.py:715] [0/0] Graph break from `Tensor.item()`, consider setting:
W1125 18:19:17.125000 139952932398912 torch/_dynamo/variables/tensor.py:715] [0/0]     torch._dynamo.config.capture_scalar_outputs = True
W1125 18:19:17.125000 139952932398912 torch/_dynamo/variables/tensor.py:715] [0/0] or:
W1125 18:19:17.125000 139952932398912 torch/_dynamo/variables/tensor.py:715] [0/0]     env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1
W1125 18:19:17.125000 139952932398912 torch/_dynamo/variables/tensor.py:715] [0/0] to include these operations in the captured graph.
W1125 18:19:17.125000 139952932398912 torch/_dynamo/variables/tensor.py:715] [0/0] 


In [14]:
# Bring the torch logits to the host as a float32 NumPy array so they can be
# compared with the reference logits.
logits_torch_host = logits_torch.detach().float().cpu()
logits_torch_np = logits_torch_host.numpy()

In [15]:
# Display the reference logits loaded from the archive.
logits_jax

array([[[ 13.566897  ,   0.43551627,  13.216589  , ...,   0.44723246,
           0.4433271 ,   0.4394217 ],
        [  3.5920599 ,  -8.567479  ,   1.4364008 , ...,  -8.567479  ,
          -8.567479  ,  -8.567479  ],
        [ 15.932797  ,   2.1525445 ,  12.086428  , ...,   2.1525445 ,
           2.1525445 ,   2.1525445 ],
        ...,
        [  6.519365  ,  -4.619217  ,   4.527646  , ...,  -4.619217  ,
          -4.619217  ,  -4.619217  ],
        [  1.4286062 ,  -3.4996197 ,   3.2064557 , ...,  -3.4996197 ,
          -3.4996197 ,  -3.484206  ],
        [  1.4519895 ,   1.327258  ,   0.9684135 , ...,   1.3350551 ,
           1.3350551 ,   1.3350551 ]],

       [[ 10.641614  ,   5.045358  ,   5.8010306 , ...,   5.045358  ,
           5.0149865 ,   5.0149865 ],
        [ 14.205258  ,  -1.1556777 ,  11.34495   , ...,  -1.1634786 ,
          -1.1556777 ,  -1.1556777 ],
        [ 13.61657   ,  -0.9176824 ,  10.967963  , ...,  -0.9215849 ,
          -0.9176824 ,  -0.9176824 ],
        ...,


In [16]:
# Select the first batch element (index 0) of both logit arrays for a closer look.
torch_logits_batch1 = logits_torch_np[0]
jax_logits_batch1 = logits_jax[0]

In [17]:
# Index of the largest torch logit along the last axis, one entry per row.
np.argmax(torch_logits_batch1, axis=-1)

array([   13,   187, 32682, 14683, 14683,  1768, 31507, 19508,    84,
        7994,  7994, 31823, 44199, 31507, 31507,  7994,  4516,   456,
       43780, 31823,   342,  2366, 20058,    52, 34948,  3344,  7043,
          66, 23790, 40918, 35422,  7043,  9518,  7043,   982, 19185,
       30153,   100,  7803,  4431,  9810, 31069,  5729, 30760, 20485,
        3162,  3015, 20485,  2233,  5560,  6144, 25518,  3776,  1209,
         692,    82,  1385, 25077,  9306,  5629, 20485,  2679,  4435,
       23144, 25077,  7422, 23144,   105, 21325,  6144,  1890,  3023,
         143, 46979,   608,  5537,  4449, 48779, 37246,   125,  4861,
       15912, 23144, 14262, 23144, 25518, 14722, 41805,   103, 14262,
         337, 39068,  5537, 18197,  5537,   337,  6144,  2575,  4418,
         318,    47, 41014, 46202,   404,  7538, 23068,   139, 14276,
           5,  8973,   370,  5537,  9747, 14683,  9467,  4739,    84,
        6235, 14262, 20485, 22369,  1216, 15902,  2575, 24724, 22690,
        8976, 23144]

In [18]:
# Index of the largest JAX logit along the last axis, one entry per row.
np.argmax(jax_logits_batch1, axis=-1)

array([   13,   187, 32682, 14683, 14683,  1768, 31507, 19508,    84,
        7994,  7994, 31823, 44199, 31507, 31507,  7994,  4516,   456,
       43780, 31823,   342,  2366, 20058,    52, 33470,  3344,  7043,
          66, 23790, 40918, 35422,  7043,  9518,  7043,   982, 19185,
       30153,   100,  7803,  4431,  9810, 31069,  5729, 30760, 20485,
        4317,  3015, 20485,  2233,  5560,  6144, 25518,  3776,  1209,
         692,    82,  1385, 25077,  9306,  5629, 20485,  2679,  4435,
       23144, 25077,  7422,   114,   105, 21325,  6144,  1890,   430,
         143, 46979,     5, 15797,  4449, 48779, 37246,   125,  4861,
       15912, 23144, 14262, 23144, 25518, 14722,  7917,   103,  9885,
         337, 18197,  5537, 18197,  5537,   337,  6144,  2575,  4418,
         318,    47, 41014, 46202,   404,  7538, 22143,   139, 14276,
           5, 50076,   370,  5537,  9747, 14683,  9467,  4739,    84,
        6235, 14262, 20485, 22369,  1216, 15902,  2575, 24724, 22690,
       38512, 23144]

In [19]:
# Row-wise check whether the torch and JAX logits of the first batch element
# have their maximum at the same index.
torch_argmax_batch1 = torch_logits_batch1.argmax(axis=-1)
jax_argmax_batch1 = jax_logits_batch1.argmax(axis=-1)
batch1_equal_argmax = np.equal(torch_argmax_batch1, jax_argmax_batch1)

In [20]:
# Show the boolean agreement mask (True where both argmax indices match).
batch1_equal_argmax

array([ True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True, False,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
       False,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True, False,  True,  True,  True,  True, False,
        True,  True, False, False,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True, False,  True, False,
        True, False,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True, False,  True,  True,
        True, False,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
       False,  True]

In [21]:
# Number of rows whose argmax matches, and the total number of rows.
batch1_equal_argmax.sum(), len(batch1_equal_argmax)

(116, 128)

In [22]:
def _top_k_indices(logits, k=5):
    """Return, per row, the indices of the k largest logits, largest last."""
    return np.argsort(logits, axis=-1)[:, -k:]


# Top-5 indices per row for both implementations (first batch element).
batch_1_torch_top5 = _top_k_indices(torch_logits_batch1)
batch_1_jax_top5 = _top_k_indices(jax_logits_batch1)

In [23]:
# Rows to inspect: the first ten.
indxes = slice(0, 10)

# Show the JAX top-5 indices first, then the torch top-5 indices, for those rows.
batch_1_jax_top5[indxes], batch_1_torch_top5[indxes]

(array([[  943,   247,   347,    84,    13],
        [  403,   275,   273,    15,   187],
        [24748,     0,   187,   273, 32682],
        [ 4859, 32682,  7234,  6197, 14683],
        [ 6197, 36482, 34401, 19508, 14683],
        [ 1559,    67,    84, 21287,  1768],
        [17438,  6097, 19508, 14683, 31507],
        [  337,   187,    13, 24747, 19508],
        [19965, 49357, 19508, 31507,    84],
        [40222, 14683,    52, 26745,  7994]]),
 array([[  403,   247,   347,    84,    13],
        [  275,   273,   403,    15,   187],
        [24748,     0,   187,   273, 32682],
        [42443, 32682,  7234,  6197, 14683],
        [ 6197, 36482, 34401, 19508, 14683],
        [ 1559,    67,    84, 21287,  1768],
        [17438,  6097, 19508, 14683, 31507],
        [31507,   187,    13, 24747, 19508],
        [19965, 49357, 19508, 31507,    84],
        [40222, 14683,    52, 26745,  7994]]))

In [24]:
# Final numerical check of the torch logits against the JAX reference.
# assert_allclose passes when |actual - desired| <= atol + rtol * |desired|, so
# with atol=2.0 and rtol=1.0 each logit may deviate by 2 + |reference logit|.
# NOTE(review): this tolerance is very loose and only catches gross mismatches —
# confirm whether a tighter bound is intended.
np.testing.assert_allclose(logits_torch_np, logits_jax, atol=2.0, rtol=1.0)