In [1]:
import sys

# Make the repository root importable so the `xlstm_jax` package resolves.
# The path is relative, so it is interpreted against the kernel's working
# directory. The membership check keeps the cell idempotent: re-running it
# no longer appends a duplicate entry to sys.path each time.
_REPO_ROOT = "../../.."
if _REPO_ROOT not in sys.path:
    sys.path.append(_REPO_ROOT)

In [2]:
from dataclasses import asdict
from typing import Any

import flax
import jax
import numpy as np
from flax import linen as nn
from jax.experimental.shard_map import shard_map
from jax.sharding import PartitionSpec as P

from xlstm_jax.distributed.mesh_utils import initialize_mesh
from xlstm_jax.models.configs import ParallelConfig
from xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend import mLSTMBackendNameAndKwargs
from xlstm_jax.models.xlstm_parallel.blocks.mlstm.block import mLSTMBlockConfig
from xlstm_jax.models.xlstm_parallel.blocks.mlstm.cell import mLSTMCellConfig
from xlstm_jax.models.xlstm_parallel.blocks.mlstm.layer import mLSTMLayerConfig
from xlstm_jax.models.xlstm_parallel.components.feedforward import FeedForwardConfig
from xlstm_jax.models.xlstm_parallel.xlstm_lm_model import xLSTMLMModel, xLSTMLMModelConfig
from xlstm_jax.utils.model_param_handling.convert_checkpoint import convert_orbax_checkpoint_to_torch_state_dict
from xlstm_jax.utils.model_param_handling.handle_mlstm_simple import (
    pipeline_convert_mlstm_checkpoint_jax_to_torch_simple,
)
from xlstm_jax.utils.model_param_handling.load import load_model_params_and_config_from_checkpoint
from xlstm_jax.utils.pytree_utils import flatten_dict

  from .autonotebook import tqdm as notebook_tqdm
2024-11-25 17:21:06.183475: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-11-25 17:21:06.200030: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-11-25 17:21:06.204860: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [3]:
# Checkpoint directory handed to `load_model_params_and_config_from_checkpoint` below.
# The literal is split across lines for readability only; the resulting string is unchanged.
JAX_CHECKPOINT_PATH = (
    "/nfs-gpu/xlstm/logs/outputs/xlstm-jax/DCLM/"
    "dclm_mLSTMv1_1.3B_ctx8192_2024-11-19T09:24:50/0/"
    "checkpoints/checkpoint_95000"
)

In [4]:
# Load the stored parameters together with the model config.
# `return_config_as_dataclass=True` requests the config as a dataclass, which
# the later `asdict(jax_config)` call relies on.
jax_checkpoint, jax_config = load_model_params_and_config_from_checkpoint(
    JAX_CHECKPOINT_PATH,
    return_config_as_dataclass=True,
)

2024-11-25 17:21:16.621767: W external/xla/xla/service/gpu/nvptx_compiler.cc:930] The NVIDIA driver's CUDA version is 12.4 which is older than the PTX compiler version 12.6.77. Because the driver is older than the PTX compiler version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [5]:
from pprint import pprint

# Flatten the nested config dataclass into a single-level dict (dotted keys)
# so that every setting is printed on its own line.
flat_jax_config = flatten_dict(asdict(jax_config))
pprint(flat_jax_config)

{'_block_map': '0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0',
 'add_embedding_dropout': False,
 'add_post_blocks_norm': True,
 'bias': False,
 'context_length': 128,
 'dropout': 0.0,
 'dtype': 'bfloat16',
 'embedding_dim': 2048,
 'init_distribution_embed': 'normal',
 'init_distribution_out': 'normal',
 'lm_head_dtype': 'bfloat16',
 'logits_soft_cap': 30.0,
 'mlstm_block._block_idx': None,
 'mlstm_block._num_blocks': 24,
 'mlstm_block.add_post_norm': False,
 'mlstm_block.feedforward._num_blocks': 24,
 'mlstm_block.feedforward._proj_up_dim': 5504,
 'mlstm_block.feedforward.act_fn': 'swish',
 'mlstm_block.feedforward.bias': False,
 'mlstm_block.feedforward.dropout': 0.0,
 'mlstm_block.feedforward.dtype': 'bfloat16',
 'mlstm_block.feedforward.embedding_dim': 2048,
 'mlstm_block.feedforward.ff_type': 'ffn_gated',
 'mlstm_block.feedforward.init_distribution': 'normal',
 'mlstm_block.feedforward.output_init_fn': 'wang',
 'mlstm_block.feedforward.parallel.data_axis_name': 'dp',
 'mlstm_blo

In [6]:
# Parallelism settings for this notebook: every axis size given here is 1 and
# no modules are listed for FSDP or rematerialisation.
parallel = ParallelConfig(
    # Mesh axis names.
    data_axis_name="dp",
    fsdp_axis_name="fsdp",
    model_axis_name="tp",
    pipeline_axis_name="pp",
    # Mesh axis sizes.
    data_axis_size=1,
    fsdp_axis_size=1,
    model_axis_size=1,
    # FSDP options.
    fsdp_modules=[],
    fsdp_gather_dtype="bfloat16",
    fsdp_min_weight_size=2**18,
    # Remaining options.
    remat=[],
    tp_async_dense=False,
)

In [7]:
# Build the device mesh from only the first visible device.
first_device_only = np.array(jax.devices())[:1]
mesh = initialize_mesh(parallel_config=parallel, device_array=first_device_only)

In [8]:
# Instantiate the language model from the config that was loaded with the checkpoint.
xlstm_model_jax = xLSTMLMModel(jax_config)

In [9]:
# Dimensions of the example batch. The vocabulary size and sequence length are
# read from the loaded config; the batch size is chosen here.
BATCH_SIZE = 3
VOCAB_SIZE = jax_config.vocab_size
CONTEXT_LENGTH = jax_config.context_length
(VOCAB_SIZE, BATCH_SIZE, CONTEXT_LENGTH)

(50304, 3, 128)

In [10]:
# Random token ids drawn from [0, VOCAB_SIZE) with a fixed key, so the example batch is reproducible.
exmp_input_key = jax.random.PRNGKey(0)
exmp_input = jax.random.randint(
    exmp_input_key,
    shape=(BATCH_SIZE, CONTEXT_LENGTH),
    minval=0,
    maxval=VOCAB_SIZE,
)

In [11]:
def _init_model(rng: jax.Array, batch_input: jax.Array) -> Any:
    """Initialise the variables of ``xlstm_model_jax`` for one example batch.

    Args:
        rng: PRNG key; it is split into one key for the ``"params"`` stream
            and one for the ``"dropout"`` stream.
        batch_input: Example input forwarded to ``xlstm_model_jax.init``.

    Returns:
        Whatever ``xlstm_model_jax.init`` returns for these inputs.
    """
    params_key, dropout_key = jax.random.split(rng)
    init_rngs = {"params": params_key, "dropout": dropout_key}
    return xlstm_model_jax.init(init_rngs, batch_input)


def _build_sharded_init(out_specs: Any) -> Any:
    """Return ``_init_model`` wrapped in ``shard_map`` and ``jax.jit``.

    Both inputs (PRNG key and example batch) use an empty ``P()`` spec; only
    the output specs differ between the two passes below.
    """
    return jax.jit(
        shard_map(
            _init_model,
            mesh,
            in_specs=(P(), P()),
            out_specs=out_specs,
            check_rep=False,
        ),
    )


# Fixed key for parameter initialisation.
init_rng = jax.random.PRNGKey(42)

# Pass 1: placeholder P() output specs. The function is only traced through
# jax.eval_shape, which evaluates shapes without running the init itself.
variables_shapes = jax.eval_shape(_build_sharded_init(P()), init_rng, exmp_input)
variables_partition_specs = nn.get_partition_spec(variables_shapes)

# Pass 2: rebuild with the partition specs derived above and actually run init.
init_model_fn = _build_sharded_init(variables_partition_specs)

# NOTE(review): the next cell rebinds `variables` to the checkpoint parameters,
# so these freshly initialised values appear to be unused - confirm whether
# this run is still needed.
variables = init_model_fn(init_rng, exmp_input)

In [12]:
# Replace the initialised variables with the checkpoint parameters, unfrozen
# into a plain nested dict and stored under the "params" collection.
variables = {"params": flax.core.frozen_dict.unfreeze(jax_checkpoint)}


def _forward(
    batch_input: jax.Array,
    variables: Any,
    batch_position: jax.Array | None,
    batch_borders: jax.Array | None,
) -> jax.Array:
    """Apply ``xlstm_model_jax`` to one batch.

    Args:
        batch_input: Model input, passed positionally to ``apply``.
        variables: Variable collections passed as the first argument of ``apply``.
        batch_position: Forwarded as ``pos_idx``; may be ``None``.
        batch_borders: Forwarded as ``document_borders``; may be ``None``.

    Returns:
        The result of ``xlstm_model_jax.apply`` for these inputs.
    """
    # NOTE(review): the model is applied with train=True and a fixed dropout
    # key. If these logits serve as a reference output, confirm that train
    # mode is intended (the printed config reports dropout 0.0).
    dropout_key = jax.random.PRNGKey(42)
    return xlstm_model_jax.apply(
        variables,
        batch_input,
        pos_idx=batch_position,
        document_borders=batch_borders,
        train=True,
        rngs={"dropout": dropout_key},
    )


# Map the forward pass over the mesh: the variables use the partition specs
# derived during init, every other input and the output use an empty P() spec.
sharded_forward = shard_map(
    _forward,
    mesh,
    in_specs=(P(), variables_partition_specs, P(), P()),
    out_specs=P(),
    check_rep=False,
)
forward_fn = jax.jit(sharded_forward)

# No position indices or document borders are supplied for this run.
logits_jax = forward_fn(exmp_input, variables, None, None)

In [13]:
# Fetch the logits and the example inputs from the device in one call;
# jax.device_get accepts a pytree and returns the same structure.
logits_jax_np, example_inputs_np = jax.device_get((logits_jax, exmp_input))

In [14]:
# Store the logits and the inputs that produced them in one .npz archive,
# under the keys "logits_jax" and "inputs".
np.savez(
    "./logits_inputs_jax.npz",
    logits_jax=logits_jax_np,
    inputs=example_inputs_np,
)

In [15]:
# Re-open the archive written by the previous cell.
# NOTE(review): np.load on an .npz presumably returns a lazily-read NpzFile that
# keeps the underlying file open - consider closing it once the arrays are read.
file = np.load("./logits_inputs_jax.npz")

In [16]:
# Display the stored logits array (NumPy abbreviates the printed values).
file["logits_jax"]

array([[[ 13.566897  ,   0.43551627,  13.216589  , ...,   0.44723246,
           0.4433271 ,   0.4394217 ],
        [  3.5920599 ,  -8.567479  ,   1.4364008 , ...,  -8.567479  ,
          -8.567479  ,  -8.567479  ],
        [ 15.932797  ,   2.1525445 ,  12.086428  , ...,   2.1525445 ,
           2.1525445 ,   2.1525445 ],
        ...,
        [  6.519365  ,  -4.619217  ,   4.527646  , ...,  -4.619217  ,
          -4.619217  ,  -4.619217  ],
        [  1.4286062 ,  -3.4996197 ,   3.2064557 , ...,  -3.4996197 ,
          -3.4996197 ,  -3.484206  ],
        [  1.4519895 ,   1.327258  ,   0.9684135 , ...,   1.3350551 ,
           1.3350551 ,   1.3350551 ]],

       [[ 10.641614  ,   5.045358  ,   5.8010306 , ...,   5.045358  ,
           5.0149865 ,   5.0149865 ],
        [ 14.205258  ,  -1.1556777 ,  11.34495   , ...,  -1.1634786 ,
          -1.1556777 ,  -1.1556777 ],
        [ 13.61657   ,  -0.9176824 ,  10.967963  , ...,  -0.9215849 ,
          -0.9176824 ,  -0.9176824 ],
        ...,
