In [None]:
# Use nest_asyncio to allow nested event loops in notebooks
!pip install nest_asyncio

In [None]:
import os
import sys

import nest_asyncio

# Locate the repository root relative to the directory the notebook runs from.
# The notebook is expected to run from '.../maxtext/MaxText/scratch_code',
# which puts the project root 'maxtext' two directories above it.
current_dir = os.getcwd()
project_root = os.path.abspath(os.path.join(current_dir, os.pardir, os.pardir))

# Put the project root at the front of the import path, but only once,
# so re-running this cell does not add duplicate entries.
if sys.path.count(project_root) == 0:
    sys.path.insert(0, project_root)
    print(f"Added '{project_root}' to sys.path")

# Enable nested event loops, which notebooks need because they already run one.
nest_asyncio.apply()

Sanity check: confirm that TPU devices are accessible

In [None]:
import jax

# distributed.initialize should only be called once, so skip it when this cell
# is re-run in a kernel that has already initialized the distributed runtime.
# `is_initialized` may be missing on older JAX releases; in that case the call
# is made unconditionally, exactly as before.
_distributed_is_initialized = getattr(jax.distributed, "is_initialized", None)
if _distributed_is_initialized is None or not _distributed_is_initialized():
    jax.distributed.initialize()

# Last expression of the cell: displays the devices JAX can see.
jax.devices()

Import MaxText and its components

In [None]:
import MaxText as mt
from MaxText import pyconfig
from MaxText import maxtext_utils
import numpy as np
from MaxText.input_pipeline import _input_pipeline_utils
import os
from MaxText.globals import PKG_DIR
from MaxText import max_logging
from MaxText import common_types
import jax
from MaxText import inference_utils

In [None]:
# argv-style list handed to pyconfig: an empty first slot followed by the base
# YAML config, given relative to the notebook's working directory.
config_argv = ["", "../configs/base.yml"]

# Keyword overrides passed alongside the base config.
# Replace path to your Llama3.1-8b checkpoint for the `load_parameters_path` argument.
config_overrides = dict(
    per_device_batch_size=1.0,
    run_name="test",
    max_target_length=4,
    max_prefill_predict_length=4,
    tokenizer_type="tiktoken",
    tokenizer_path="assets/tokenizer_llama3.tiktoken/",
    load_parameters_path="path/to/your/llama3.1-8b/checkpoint",  # Replace with your checkpoint path
    model_name="llama3.1-8b",
    async_checkpointing=False,
)
config = pyconfig.initialize(config_argv, **config_overrides)

# Build the model from the config and keep a handle on its device mesh.
model = mt.from_config(config)
mesh = model.mesh

# Seed the PRNG from the config so the run is repeatable.
init_rng = jax.random.PRNGKey(config.init_weights_seed)

# Only the first element of the returned pair (the decode state) is kept.
state, _ = maxtext_utils.setup_decode_state(model, config, init_rng, mesh, None)



Get Tokenizer

In [None]:
# The tokenizer files are looked up in the "assets" directory that sits next to
# the directory named by PKG_DIR.
tokenizer_dir = os.path.join(os.path.dirname(PKG_DIR), "assets", "tokenizer_llama3.tiktoken")

# Build a tiktoken tokenizer with add_bos enabled and add_eos disabled.
source_tokenizer = _input_pipeline_utils.get_tokenizer(
    tokenizer_dir,
    "tiktoken",
    add_bos=True,
    add_eos=False,
)

Prepare the inputs

In [None]:
# Tokenize the prompt stored in the config and keep it as an int32 vector.
input_ids = source_tokenizer.encode(config.prompt)
prompt_ids = np.asarray(input_ids, dtype=np.int32)

# All per-batch arrays below share the shape (batch, max_target_length).
batch_size = config.global_batch_size_to_train_on
s = (batch_size, config.max_target_length)

# Fill every slot with the active-sequence indicator (np.zeros defaults to float64).
decoder_segment_ids = np.zeros(s) + common_types.DECODING_ACTIVE_SEQUENCE_INDICATOR

# Positions 0 .. max_target_length - 1, repeated once per batch row.
decoder_positions = np.stack([np.arange(config.max_target_length, dtype=np.int32)] * batch_size)

# Repeat the tokenized prompt once per batch row.
# NOTE(review): the width of `ids` is the prompt's token count, while the two
# arrays above use max_target_length; presumably these are meant to match. Confirm.
ids = np.stack([prompt_ids] * batch_size)

max_logging.log(
    f"input_ids={input_ids}, \nids={ids}, \ndecoder_segment_ids = {decoder_segment_ids}, \ndecoder_positions= {decoder_positions}"
)


Run a forward pass

In [None]:
# Import the submodule explicitly: `import jax` alone does not guarantee that
# `jax.experimental.multihost_utils` has been loaded, so attribute access on it
# only works when some other module happened to import it first.
from jax.experimental import multihost_utils

# Run the model once over the prepared batch with dropout disabled.
# The PRNG key is supplied under the "aqt" rng name.
full_train_logits = model.apply(
    state.params,
    ids,
    decoder_positions,
    decoder_segment_ids,
    enable_dropout=False,
    rngs={"aqt": init_rng},
)

# Gather the logits across processes.
# NOTE(review): the slicing cell that follows indexes four axes, so the gathered
# array is presumably rank 4 here. Confirm for multi-process runs.
full_train_logits = multihost_utils.process_allgather(full_train_logits)
max_logging.log(f"{full_train_logits[0, 0, :]=}")

Select the logits used to sample the first generated token

In [None]:
# Take a (1, 1, 1, N) window: index 0 on the first two axes, the final index on
# axis 2, and the whole of axis 3.
# NOTE(review): presumably axis 2 is the sequence axis and axis 3 the vocabulary. Confirm.
last_index_axis2 = full_train_logits.shape[2] - 1
slice_start = (0, 0, last_index_axis2, 0)
slice_sizes = (1, 1, 1, full_train_logits.shape[3])
selected_logits = jax.lax.dynamic_slice(full_train_logits, slice_start, slice_sizes)

In [None]:
# Split the key: `init_rng` is replaced by a fresh key for later use and
# `new_rng` is spent on this sampling step.
init_rng, new_rng = jax.random.split(init_rng)

# The strategy is read from the config; the original note here says "greedy".
sampling_strategy = config.decode_sampling_strategy
first_generated_token = inference_utils.sampling(selected_logits, new_rng, sampling_strategy)

In [None]:
# Display the sampled token as a single scalar value (the cell's output).
first_generated_token.item()

In [None]:
# Decode the sampled token back to text with the tokenizer that encoded the prompt.
generated_token_id = first_generated_token.item()
source_tokenizer.decode([generated_token_id])