In [None]:
import sys

# Make the repo root and the standalone `mlstm_simple_torch` package importable.
# Both paths are relative to the kernel's working directory (presumably this
# notebook's folder — TODO confirm). The membership guard keeps the cell
# idempotent: re-running it does not append duplicate entries to sys.path.
for _extra_path in ("../..", "../../../mlstm_simple_torch"):
    if _extra_path not in sys.path:
        sys.path.append(_extra_path)

In [None]:
import os
import shlex
from pathlib import Path
from pprint import pprint

import numpy as np
import torch
from mlstm_simple.from_pretrained import load_from_pretrained
from mlstm_simple.model import mLSTM, mLSTMConfig

In [None]:
# Load the arrays exported from the JAX run.
# `np.load` on an .npz returns a lazily-reading NpzFile that keeps its file
# handle open until closed; copy every array into a plain dict inside a `with`
# block so the handle is released immediately. Key lookup on the dict works
# exactly like key lookup on the NpzFile.
with np.load("./logits_inputs_jax.npz") as _npz_file:
    logits_inputs_jax = dict(_npz_file)

In [None]:
# Reference logits from the JAX model, plus the token ids it was run on
# (the same ids are fed to the torch model further down).
logits_jax, inputs_jax = logits_inputs_jax["logits_jax"], logits_inputs_jax["inputs"]

In [None]:
# JAX checkpoint that is converted to torch below. The default is an absolute,
# cluster-specific path; set the JAX_CHECKPOINT_PATH environment variable to
# point at a different checkpoint without editing the notebook.
JAX_CHECKPOINT_PATH = os.environ.get(
    "JAX_CHECKPOINT_PATH",
    "/nfs-gpu/xlstm/logs/outputs/xlstm-jax/DCLM/dclm_mLSTMv1_1.3B_ctx8192_2024-11-19T09:24:50/0/checkpoints/checkpoint_95000",
)

In [None]:
# Directory the converted torch checkpoint is written to by the conversion
# command and loaded from afterwards. It is resolved to an absolute path under
# the kernel's current working directory (not the notebook's parent directory).
SAVE_TORCH_CHECKPOINT_AT = Path("mlstm_simple_checkpoint").resolve()
SAVE_TORCH_CHECKPOINT_AT

In [None]:
# Let float32 matmuls use TensorFloat32 tensor cores where available. This is
# faster but lowers matmul precision, which can widen the torch-vs-JAX logits gap.
torch.set_float32_matmul_precision("high")

# Execution switches for the forward-pass cell below. Autocast is disabled, so
# TORCH_AMP_DTYPE only takes effect if ENABLE_TORCH_AMP is flipped to True.
USE_TORCH_COMPILE = True
ENABLE_TORCH_AMP = False
TORCH_AMP_DTYPE = torch.float32

In [None]:
# ## Convert jax checkpoint to torch:
# This cell only *prints* the conversion command; it has to be run manually
# (presumably from the repository root, since it uses `PYTHONPATH=.` and a
# repo-relative script path) before the torch checkpoint is loaded below.
# Both paths are shell-quoted so spaces or shell metacharacters in them cannot
# break or alter the command.
command = (
    "PYTHONPATH=. python scripts/checkpoint_conversion/convert_mlstm_checkpoint_jax_to_torch_simple.py"
    f" --checkpoint_dir {shlex.quote(str(JAX_CHECKPOINT_PATH))}"
    f" --output_path {shlex.quote(str(SAVE_TORCH_CHECKPOINT_AT))}"
    " --checkpoint_type plain"
)
print(command)

In [None]:
# Load the converted checkpoint into the standalone torch mLSTM, selecting the
# Triton-based chunkwise, sequence and step kernels. The chunk size presumably
# applies to the chunkwise kernel — TODO confirm against load_from_pretrained.
model = load_from_pretrained(
    checkpoint_path=SAVE_TORCH_CHECKPOINT_AT,
    chunk_size=128,
    chunkwise_kernel_name="chunkwise--triton_xl_chunk",
    sequence_kernel_name="native_sequence__triton_step_fused",
    step_kernel_name="triton_fused",
)

In [None]:
# Display the loaded torch model.
model

In [None]:
# Inspect the configuration that was restored together with the checkpoint.
pprint(model.config)

In [None]:
# Move the model to the GPU; rebinding keeps `model` pointing at the result.
model = model.to(device="cuda")
# Make the forward pass also return its final states: the forward cell below
# unpacks the output as `(logits, state)`.
model.config.return_last_states = True

In [None]:
# Run the torch model on exactly the token ids the JAX model saw.
# - eval(): turn off any train-mode-only behaviour (e.g. dropout, if the model
#   has any) so the logits are deterministic and comparable to the reference.
# - The compiled wrapper gets its own name, so re-running this cell does not
#   wrap an already-compiled module again; `model` stays the eager module.
# - no_grad(): this is pure inference, so no autograd graph is built and no
#   activations are kept around for a backward pass.
model.eval()
model_inference = torch.compile(model) if USE_TORCH_COMPILE else model
with torch.no_grad():
    with torch.autocast(device_type="cuda", dtype=TORCH_AMP_DTYPE, enabled=ENABLE_TORCH_AMP):
        logits_torch, state = model_inference(torch.from_numpy(inputs_jax).to("cuda"))

In [None]:
# Detach from autograd, upcast to float32 and copy to host memory for numpy.
logits_torch_np = logits_torch.detach().float().cpu().numpy()

In [None]:
# Display the JAX reference logits (numpy summarises large arrays in the repr).
logits_jax

In [None]:
# Element 0 along the leading axis of each framework's logits (the batch axis,
# judging by the names — "batch1" here means the first batch element).
torch_logits_batch1, jax_logits_batch1 = logits_torch_np[0], logits_jax[0]

In [None]:
# Argmax of the torch logits over the last axis (presumably the vocabulary).
torch_logits_batch1.argmax(axis=-1)

In [None]:
# Argmax of the JAX logits over the last axis (presumably the vocabulary).
jax_logits_batch1.argmax(axis=-1)

In [None]:
# Boolean mask: True wherever torch and JAX pick the same argmax index.
batch1_equal_argmax = np.equal(
    torch_logits_batch1.argmax(axis=-1),
    jax_logits_batch1.argmax(axis=-1),
)

In [None]:
# Display the per-position argmax agreement mask.
batch1_equal_argmax

In [None]:
# (number of agreeing positions, total number of positions)
np.sum(batch1_equal_argmax), batch1_equal_argmax.shape[0]

In [None]:
# Indices of the 5 largest logits along the last axis for every row. argsort is
# ascending, so columns run from 5th-largest to largest (best is the last column).
batch_1_torch_top5, batch_1_jax_top5 = (
    np.argsort(batch_logits, axis=-1)[:, -5:]
    for batch_logits in (torch_logits_batch1, jax_logits_batch1)
)

In [None]:
# Side-by-side top-5 indices for the first 10 rows, shown in the same
# (torch, jax) order as the other comparison cells.
preview_rows = slice(0, 10)

batch_1_torch_top5[preview_rows], batch_1_jax_top5[preview_rows]

In [None]:
# Final parity check: torch logits (actual) vs. JAX logits (desired).
# NOTE(review): these tolerances are very loose. assert_allclose passes when
# |torch - jax| <= atol + rtol * |jax|, i.e. <= 2.0 + |jax| here, so this only
# catches gross failures (shape mismatch, NaN/inf in just one output, very large
# deviations). The argmax / top-5 comparisons above are the more informative
# signal; tighten these once the expected kernel/precision error is known.
LOGITS_ATOL = 2.0
LOGITS_RTOL = 1.0
np.testing.assert_allclose(
    logits_torch_np,
    logits_jax,
    atol=LOGITS_ATOL,
    rtol=LOGITS_RTOL,
    err_msg="torch logits do not match the JAX reference logits",
)