In [41]:
from transformers import AutoProcessor, Llama4ForConditionalGeneration
import torch
from transformers import AutoConfig

config = AutoConfig.from_pretrained(
    "meta-llama/Llama-4-Scout-17B-16E-Instruct",
    output_router_logits=True,         # enable router-softmax output
)

processor = AutoProcessor.from_pretrained("meta-llama/Llama-4-Scout-17B-16E-Instruct")
model = Llama4ForConditionalGeneration.from_pretrained(
    "meta-llama/Llama-4-Scout-17B-16E-Instruct",
    config=config,
    attn_implementation="sdpa",
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

Loading checkpoint shards: 100%|██████████| 50/50 [00:31<00:00,  1.61it/s]
Some parameters are on the meta device because they were offloaded to the cpu.


Llama4ForConditionalGeneration(
  (vision_model): Llama4VisionModel(
    (patch_embedding): Llama4UnfoldConvolution(
      (unfold): Unfold(kernel_size=(14, 14), dilation=1, padding=0, stride=14)
      (linear): Linear(in_features=588, out_features=1408, bias=False)
    )
    (rotary_embedding): Llama4VisionRotaryEmbedding()
    (layernorm_pre): LayerNorm((1408,), eps=1e-05, elementwise_affine=True)
    (layernorm_post): LayerNorm((1408,), eps=1e-05, elementwise_affine=True)
    (model): Llama4VisionEncoder(
      (layers): ModuleList(
        (0-33): 34 x Llama4VisionEncoderLayer(
          (self_attn): Llama4VisionAttention(
            (q_proj): Linear(in_features=1408, out_features=1408, bias=True)
            (k_proj): Linear(in_features=1408, out_features=1408, bias=True)
            (v_proj): Linear(in_features=1408, out_features=1408, bias=True)
            (o_proj): Linear(in_features=1408, out_features=1408, bias=True)
          )
          (mlp): Llama4VisionMLP(
           

In [42]:
# Create a hook to capture router probabilities
router_probs = []

def hook_fn(module, input, output):
    # Print information about input and output
    input_shape = [i.shape if isinstance(i, torch.Tensor) else type(i) for i in input]
    print(f"Router input shape: {input_shape}")
    
    # Check if output is a tensor or has special attributes
    if isinstance(output, torch.Tensor):
        print(f"Router output shape: {output.shape}")
        # Only capture based on actual sequence length
        seq_len = input[0].shape[1] if isinstance(input[0], torch.Tensor) else 0
        if seq_len > 0 and seq_len < output.shape[0]:
            output_slice = output[:seq_len]
            router_probs.append(output_slice.detach().cpu())
        else:
            router_probs.append(output.detach().cpu())
    else:
        print(f"Router output type: {type(output)}")
        if hasattr(output, 'router_probs'):
            router_probs.append(output.router_probs.detach().cpu())

# Register hooks on MoE layers
for name, module in model.named_modules():
    if "router" in name.lower():
        # print("module names with router in it : ", name.lower())
        module.register_forward_hook(hook_fn)

# Process some text

In [43]:
import time
messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Hello! It's nice to meet you. Is there something I can help you with, or would you like to chat? Can you repeat after me ? I want you to count the number of input tokens rn and also sing me a song about deep learning" },
        ]
    },
]

# Start timing
start_total = time.time()

# Time the processing part
start_process = time.time()
tokenized = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
)
inputs = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
).to(model.device)
process_time = time.time() - start_process

# Time the generation part
start_generate = time.time()
outputs = model.generate(
    **inputs,
    max_new_tokens=50,
    return_dict_in_generate=True,
)
generate_time = time.time() - start_generate

# Time the decoding part
start_decode = time.time()
generated_tokens = outputs.sequences[0, inputs.input_ids.shape[1]:]
decoded_text = processor.decode(generated_tokens, skip_special_tokens=True)
decode_time = time.time() - start_decode

# Calculate total time
total_time = time.time() - start_total

# Print results
print(f"Processing time: {process_time:.4f} seconds")
print(f"Generation time: {generate_time:.4f} seconds")
print(f"Decoding time: {decode_time:.4f} seconds")
print(f"Total time: {total_time:.4f} seconds")
print("\nGenerated text:")
print(decoded_text)

Router input shape: [torch.Size([60, 5120])]
Router output shape: torch.Size([60, 16])
Router input shape: [torch.Size([60, 5120])]
Router output shape: torch.Size([60, 16])
Router input shape: [torch.Size([60, 5120])]
Router output shape: torch.Size([60, 16])
Router input shape: [torch.Size([60, 5120])]
Router output shape: torch.Size([60, 16])
Router input shape: [torch.Size([60, 5120])]
Router output shape: torch.Size([60, 16])
Router input shape: [torch.Size([60, 5120])]
Router output shape: torch.Size([60, 16])
Router input shape: [torch.Size([60, 5120])]
Router output shape: torch.Size([60, 16])
Router input shape: [torch.Size([60, 5120])]
Router output shape: torch.Size([60, 16])
Router input shape: [torch.Size([60, 5120])]
Router output shape: torch.Size([60, 16])
Router input shape: [torch.Size([60, 5120])]
Router output shape: torch.Size([60, 16])
Router input shape: [torch.Size([60, 5120])]
Router output shape: torch.Size([60, 16])
Router input shape: [torch.Size([60, 5120])

In [40]:

print(f"Number of tokens: {len(tokenized.input_ids[0])}")

for idx, probs in enumerate(router_probs[0:1]):
    print(f"Layer {idx} router probabilities:")
    print(f"Shape: {probs.shape}")
    print(f"Top-5 values per token: {torch.topk(probs, 5, dim=-1).values}")
    print(f"Top-5 indices per token: {torch.topk(probs, 5, dim=-1).indices}")

Number of tokens: 60
Layer 0 router probabilities:
Shape: torch.Size([18, 16])
Top-5 values per token: tensor([[-0.8555, -0.8789, -0.8945, -0.9023, -0.9219],
        [-0.4766, -0.4844, -0.5391, -0.6016, -0.6875],
        [-0.1641, -0.2109, -0.3477, -0.3906, -0.4102],
        [-0.3809, -0.4941, -0.5039, -0.5156, -0.6992],
        [-0.1875, -0.1973, -0.2275, -0.2598, -0.3887],
        [-0.1768, -0.1797, -0.1963, -0.3926, -0.4922],
        [-0.1128, -0.1318, -0.1543, -0.1738, -0.1982],
        [-0.1748, -0.2021, -0.2490, -0.3262, -0.3359],
        [-0.1758, -0.1855, -0.2305, -0.2754, -0.2832],
        [-0.1221, -0.1406, -0.1797, -0.1855, -0.2324],
        [ 0.0282, -0.0405, -0.1250, -0.2100, -0.2158],
        [-0.0625, -0.1992, -0.3203, -0.3496, -0.4102],
        [-0.1484, -0.3711, -0.3945, -0.4102, -0.4395],
        [ 0.3047,  0.1787, -0.1836, -0.3359, -0.4316],
        [-0.3574, -0.4766, -0.6133, -0.6719, -0.7812],
        [ 0.0147, -0.0415, -0.2148, -0.3242, -0.3945],
        [-0.2656,