[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/AI-Hypercomputer/maxtext/blob/main/src/MaxText/examples/demo_decoding.ipynb)
        
# Qwen3-0.6B Decoding Demo

## Prerequisites

### Change Runtime Type

**Instructions:**
1.  Navigate to the menu at the top of the screen.
2.  Click on **Runtime**.
3.  Select **Change runtime type** from the dropdown menu.
4.  Select **v5e-1 TPU** as the **Hardware accelerator**.
5.  Click on **Save**.

### Get Your Hugging Face Token

To access the model checkpoint on the Hugging Face Hub, you need to authenticate with a personal access token.

**Follow these steps to get your token:**

1.  **Navigate to the Access Tokens page** in your Hugging Face account settings. You can go there directly by visiting this URL:
    *   [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)

2.  **Create a new token** by clicking the **"+ Create new token"** button.

3.  **Give your token a name** and assign it a **`read` role**. The `read` role is sufficient for downloading models.

4.  **Copy the generated token**. You will need to paste it in the next step.

**Follow these steps to store your token:**

1. On the left sidebar of your Colab window, click the key icon (the Secrets tab).

2. Click **"+ Add new secret"**.

3. Set the Name as **HF_TOKEN**.

4. Paste your token into the Value field.

5. Ensure the Notebook access toggle is turned On.
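
A later cell reads the stored secret with `userdata.get("HF_TOKEN")`. As a minimal sketch of a more defensive lookup, the hypothetical helper below (`resolve_hf_token` is not part of MaxText or Colab) tries the Colab secret first and then falls back to an `HF_TOKEN` environment variable. The fallback is an assumption for running outside Colab. The helper fails early with a clear message if neither source provides a token.

```python
import os


def resolve_hf_token(name="HF_TOKEN"):
    """Hypothetical helper: return the token from Colab secrets or the environment."""
    token = None
    try:
        # google.colab is only importable inside a Colab runtime.
        from google.colab import userdata

        token = userdata.get(name)
    except Exception:  # not in Colab, secret missing, or notebook access disabled
        token = None
    # Assumed fallback for non-Colab environments: an environment variable.
    token = token or os.environ.get(name)
    if not token:
        raise RuntimeError(f"{name} not found in Colab secrets or environment variables.")
    return token
```

If the helper raises, the most likely causes are a misspelled secret name or the Notebook access toggle being off.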

## Installation: MaxText & Other Dependencies

In [None]:
# Install uv, a fast Python package installer
!pip install uv

# Install MaxText and dependencies
!uv pip install maxtext --resolution=lowest
!install_maxtext_github_deps

# Use nest_asyncio to allow nested event loops in notebooks
!uv pip install nest_asyncio

# Install the PyTorch library
!uv pip install torch

### Restart Session
Restart the session so that the packages installed above are picked up by the runtime.

**Instructions:**
1.  Navigate to the menu at the top of the screen.
2.  Click on **Runtime**.
3.  Select **Restart session** from the dropdown menu.

You will be asked to confirm the action in a pop-up dialog. Click on **Yes**.

## Imports

In [None]:
%%capture
import datetime
import jax
import jax.experimental.multihost_utils  # imported explicitly; used later for process_allgather
import os
import nest_asyncio
import numpy as np

import MaxText as mt
from MaxText import common_types
from MaxText import inference_utils
from MaxText import maxtext_utils
from MaxText import max_logging
from MaxText import pyconfig
from MaxText.input_pipeline import _input_pipeline_utils
from MaxText.utils.ckpt_conversion import to_maxtext

from google.colab import userdata
from huggingface_hub import login

MAXTEXT_PKG_DIR = os.path.dirname(mt.__file__)

nest_asyncio.apply()

## Sanity Test: Checking for Available TPU Devices

In [None]:
jax.distributed.initialize()  # distributed.initialize should only be called once.
jax.devices()

## Model Configurations

In [None]:
MODEL_NAME = "qwen3-0.6b"
PROMPT = "I love to"
RUN_NAME = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
MODEL_CHECKPOINT_PATH = f"/tmp/checkpoints/{MODEL_NAME}/{RUN_NAME}/unscanned"

HF_TOKEN = userdata.get("HF_TOKEN")
login(token=HF_TOKEN)
max_logging.log("Authenticated with Hugging Face successfully!")

## Download Model Checkpoint From Hugging Face

In [None]:
%%capture
argv = [
    "",  # This is a placeholder, it's not actually used by the script's logic
    f"{MAXTEXT_PKG_DIR}/configs/base.yml",
    f"model_name={MODEL_NAME}",
    f"base_output_directory={MODEL_CHECKPOINT_PATH}",
    f"hf_access_token={HF_TOKEN}",
    "use_multimodal=false",
    "scan_layers=false",
]

to_maxtext.main(argv)

In [None]:
max_logging.log(f"Model checkpoint can be found at: {MODEL_CHECKPOINT_PATH}/0/items")
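
Before pointing `load_parameters_path` at the converted checkpoint, it can help to confirm that the directory exists and is not empty. The sketch below uses only the standard library and assumes a local filesystem path, as in this notebook. `checkpoint_ready` is a hypothetical helper, and the `0/items` suffix is taken from the log message above.

```python
import os


def checkpoint_ready(checkpoint_root):
    """Hypothetical check: True if <root>/0/items exists and has at least one entry."""
    items_dir = os.path.join(checkpoint_root, "0", "items")
    return os.path.isdir(items_dir) and len(os.listdir(items_dir)) > 0
```

In the notebook this could be called as `checkpoint_ready(MODEL_CHECKPOINT_PATH)` right after the conversion cell finishes.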

## Initialize Configurations

In [None]:
%%capture
config = pyconfig.initialize(
    ["", f"{MAXTEXT_PKG_DIR}/configs/base.yml"],
    per_device_batch_size=1.0,
    run_name="test",
    max_target_length=4,
    max_prefill_predict_length=4,
    tokenizer_path=f"{MAXTEXT_PKG_DIR}/assets/qwen3-tokenizer",
    load_parameters_path=f"{MODEL_CHECKPOINT_PATH}/0/items",
    model_name=MODEL_NAME,
    async_checkpointing=False,
    prompt=PROMPT,
    scan_layers=False,
)

In [None]:
max_logging.log("Decode configurations initialized.")

## Initialize Decode State

In [None]:
model = mt.from_config(config)
mesh = model.mesh
init_rng = jax.random.PRNGKey(config.init_weights_seed)
state, _ = maxtext_utils.setup_decode_state(model, config, init_rng, mesh, None)
max_logging.log("Decode state initialized.")

## Get Tokenizer

In [None]:
tokenizer = _input_pipeline_utils.get_tokenizer(
    f"{MAXTEXT_PKG_DIR}/assets/qwen3-tokenizer",
    "huggingface",
    add_bos=True,
    add_eos=False,
)
max_logging.log("Tokenizer loaded successfully.")

## Prepare Inputs

In [None]:
input_ids = tokenizer.encode(config.prompt)

# Right-pad input_ids with zeros up to max_target_length
padded_ids = np.zeros(config.max_target_length, dtype=np.int32)
padded_ids[: len(input_ids)] = input_ids
ids = np.asarray(padded_ids, dtype=np.int32)

s = (config.global_batch_size_to_train_on, config.max_target_length)
decoder_segment_ids = np.zeros(s) + common_types.DECODING_ACTIVE_SEQUENCE_INDICATOR
decoder_positions = np.stack(
    [np.arange(config.max_target_length, dtype=np.int32) for _ in range(config.global_batch_size_to_train_on)]
)

ids = np.stack([ids for _ in range(config.global_batch_size_to_train_on)])
max_logging.log(
    f"input_ids={input_ids}, \n\nids={ids}, \n\ndecoder_segment_ids={decoder_segment_ids}, \n\ndecoder_positions={decoder_positions}"
)
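
The padding and batching steps above can be collected into one function. The following is a NumPy-only sketch, not MaxText code: `pad_and_batch` is a hypothetical helper, and the token IDs in the example are made-up placeholders rather than real tokenizer output. It also raises an error when the prompt has more tokens than the target length, a case the cell above does not check.

```python
import numpy as np


def pad_and_batch(token_ids, max_len, batch_size, pad_id=0):
    """Hypothetical helper: right-pad token_ids to max_len and repeat per batch row."""
    if len(token_ids) > max_len:
        raise ValueError(f"Prompt has {len(token_ids)} tokens but max_len is {max_len}.")
    padded = np.full(max_len, pad_id, dtype=np.int32)
    padded[: len(token_ids)] = token_ids
    ids = np.tile(padded, (batch_size, 1))  # shape: (batch_size, max_len)
    positions = np.tile(np.arange(max_len, dtype=np.int32), (batch_size, 1))
    return ids, positions


# Placeholder token IDs, chosen only for illustration.
ids, positions = pad_and_batch([11, 22, 33], max_len=4, batch_size=2)
print(ids.shape, positions.shape)  # (2, 4) (2, 4)
```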

## Run Forward Pass

In [None]:
full_train_logits = model.apply(
    state.params,
    ids,
    decoder_positions,
    decoder_segment_ids,
    enable_dropout=False,
    rngs={"aqt": init_rng},
)
full_train_logits = jax.experimental.multihost_utils.process_allgather(full_train_logits)
max_logging.log(f"{full_train_logits[0, 0, :]=}")
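
The logged values are unnormalized scores over the vocabulary, which are hard to read directly. One option is to convert the logits for a single position into probabilities with a softmax and list the top few candidates. The following is an illustrative NumPy sketch, not part of MaxText: `top_k_tokens` is a hypothetical helper and the example values are made up.

```python
import numpy as np


def top_k_tokens(logits, k=3):
    """Return (token_id, probability) pairs for the k largest logits, highest first."""
    logits = np.asarray(logits, dtype=np.float64)
    shifted = logits - logits.max()  # subtract the max for numerical stability
    probs = np.exp(shifted) / np.exp(shifted).sum()
    top_ids = np.argsort(probs)[::-1][:k]
    return [(int(i), float(probs[i])) for i in top_ids]


# Toy 5-entry "vocabulary"; real logits have one entry per vocabulary token.
print(top_k_tokens([0.1, 2.0, -1.0, 3.0, 0.5], k=2))
```

In the notebook, the function could be applied to a single vocabulary vector such as `full_train_logits[0, 0, 2]`, assuming the gathered array has four axes as the slicing in the next section implies.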

## Predict the Next Token with Greedy Decoding

In [None]:
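# The gathered logits have four axes, with the vocabulary last. Slice out the
# vocabulary logits at sequence position shape[2] - 2 (index 2 here, since
# max_target_length=4) for the first batch element.
selected_logits = jax.lax.dynamic_slice(
    full_train_logits, (0, 0, full_train_logits.shape[2] - 2, 0), (1, 1, 1, full_train_logits.shape[3])
)

# Sample the next token with the configured strategy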
init_rng, new_rng = jax.random.split(init_rng)
first_generated_token = inference_utils.sampling(
    selected_logits,
    new_rng,
    config.decode_sampling_strategy,  # "greedy"
)
output = tokenizer.decode([first_generated_token.item()])
max_logging.log(f"Next predicted token is `{output}` for the input prompt: `{config.prompt}`.")
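
With a greedy strategy, picking the next token amounts to an argmax over the vocabulary axis at the chosen position. The NumPy sketch below mirrors the slicing above on a toy array. It is an illustration under the assumption that greedy sampling is a plain argmax, not a reproduction of `inference_utils.sampling`, and `greedy_next_token` is a hypothetical helper.

```python
import numpy as np


def greedy_next_token(logits_4d, position):
    """Hypothetical helper: argmax over the vocab axis for the first batch element.

    logits_4d is assumed to have shape (gather, batch, sequence, vocab).
    """
    selected = logits_4d[0, 0, position, :]  # shape: (vocab,)
    return int(np.argmax(selected))


# Toy logits with shape (1, 1, 4, 6); token 4 gets the highest score at position 2.
toy_logits = np.zeros((1, 1, 4, 6))
toy_logits[0, 0, 2, 4] = 5.0

position = toy_logits.shape[2] - 2  # same position the notebook slices
print(greedy_next_token(toy_logits, position))  # 4
```

Repeating this step, appending each predicted token to the input and running the forward pass again, would extend the single-token prediction into longer text.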