
Conversation


@kshitij12345 kshitij12345 commented Oct 10, 2025

Depends on #2611 (the thunderfx path requires the changes from #2611 to work).

Eager (2-layer model on RTX 6000 Ada)

CMD: torchrun --nproc-per-node 2 thunder/benchmarks/benchmark_inference.py --input-length 32 --output-length 32 --mode eager --num-iterations 10


============================================================
BENCHMARK RESULTS - meta-llama/Llama-4-Maverick-17B-128E eager
============================================================

Throughput Metrics:
  Overall Throughput: 38.99 tokens/sec
  Prefill Throughput: 1231.52 tokens/sec
  Decode Throughput: 39.01 tokens/sec
  Latency: 25.67 ms/token

Latency Breakdown:
  Time to First Token (TTFT): 25.99 ms
  Time Between Output Tokens (TBOT): 25.65 ms
  Prefill Time: 25.99 ms
  Decode Time: 795.30 ms
  Total Generation Time: 821.29 ms

Memory Usage:
  Current Memory: 20.76 GB
  Peak Memory: 20.78 GB

Variance Analysis:
  Throughput Std Dev: 25.17 ms
  TTFT Std Dev: 0.35 ms

thunderfx (2-layer model on RTX 6000 Ada)

CMD: torchrun --nproc-per-node 2 thunder/benchmarks/benchmark_inference.py --input-length 32 --output-length 32 --mode thunder --num-iterations 10

============================================================
BENCHMARK RESULTS - meta-llama/Llama-4-Maverick-17B-128E thunder
============================================================

Throughput Metrics:
  Overall Throughput: 101.91 tokens/sec
  Prefill Throughput: 3266.44 tokens/sec
  Decode Throughput: 102.32 tokens/sec
  Latency: 208.25 ms/token

Latency Breakdown:
  Time to First Token (TTFT): 9.83 ms
  Time Between Output Tokens (TBOT): 214.65 ms
  Prefill Time: 9.83 ms
  Decode Time: 6654.27 ms
  Total Generation Time: 6664.09 ms

Memory Usage:
  Current Memory: 20.76 GB
  Peak Memory: 20.78 GB

Variance Analysis:
  Throughput Std Dev: 20179.27 ms
  TTFT Std Dev: 0.58 ms

thunderfx with nv_enable_linear=True (2-layer model on RTX 6000 Ada)

CMD: torchrun --local-ranks-filter 0 --nproc-per-node 2 thunder/benchmarks/benchmark_inference.py --input-length 32 --output-length 32 --mode thunder --num-iterations 10 --enable-nv-linear

============================================================
BENCHMARK RESULTS - meta-llama/Llama-4-Maverick-17B-128E thunder
============================================================

Throughput Metrics:
  Overall Throughput: 106.47 tokens/sec
  Prefill Throughput: 3449.92 tokens/sec
  Decode Throughput: 106.85 tokens/sec
  Latency: 224.55 ms/token

Latency Breakdown:
  Time to First Token (TTFT): 9.29 ms
  Time Between Output Tokens (TBOT): 231.49 ms
  Prefill Time: 9.29 ms
  Decode Time: 7176.27 ms
  Total Generation Time: 7185.56 ms

Memory Usage:
  Current Memory: 20.76 GB
  Peak Memory: 20.78 GB

Variance Analysis:
  Throughput Std Dev: 21866.90 ms
  TTFT Std Dev: 0.36 ms

if isinstance(offsets, DTensor):
    assert offsets.placements == (Replicate(),)
    offsets = offsets.to_local()

@kshitij12345 (Collaborator, Author):

Without this, we would see:

[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/tensor/_sharding_prop.py", line 539, in propagate_op_sharding_non_cached
[rank0]:     raise NotImplementedError(
[rank0]: NotImplementedError: Operator aten.unbind.int does not have a sharding strategy registered.

because iterating over the offsets (for offset in offsets) calls unbind on the tensor.
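
A minimal sketch of the pattern (assuming the offsets are fully replicated across ranks; the helper name to_iterable_offsets is made up for illustration):

import torch
from torch.distributed.tensor import DTensor, Replicate

def to_iterable_offsets(offsets: torch.Tensor) -> torch.Tensor:
    # Iterating a DTensor (for offset in offsets) dispatches aten.unbind.int,
    # which has no sharding strategy registered, so drop to the local tensor
    # first. This is only safe because the offsets are replicated on every rank.
    if isinstance(offsets, DTensor):
        assert offsets.placements == (Replicate(),)
        offsets = offsets.to_local()
    return offsets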


group_sizes = _group_sizes_from_offsets(offsets)
group_outs = []
for group_a, group_b in zip(a.split(group_sizes), b.unbind()):
@kshitij12345 (Collaborator, Author):
This is the line that triggers the error above when the operands are still DTensors: NotImplementedError: Operator aten.unbind.int does not have a sharding strategy registered.
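
For context, a helper like _group_sizes_from_offsets presumably turns cumulative offsets into per-group lengths for a.split. A minimal sketch under that assumption (the actual implementation in the repo may differ):

import torch

def _group_sizes_from_offsets(offsets: torch.Tensor) -> list[int]:
    # Assume offsets holds cumulative group ends, e.g. [3, 5, 9] for groups of
    # sizes 3, 2 and 4; torch.diff with a zero prepended recovers the per-group
    # lengths, and torch.split expects plain Python ints.
    sizes = torch.diff(offsets, prepend=offsets.new_zeros(1))
    return sizes.tolist()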

return self.get_next_token(input_ids, past_key_values)

@torch.inference_mode()
# TODO: Running `torchrun --nproc-per-node 2 thunder/benchmarks/benchmark_inference.py --input-length 32 --output-length 32 --mode eager --num-iterations 10`
Collaborator:
Just for my understanding: tensor parallel doesn't work with inference_mode, but single device does?

@kshitij12345 (Collaborator, Author):
Yes, this only failed with TP and worked fine on a single device.

Collaborator:
ok. and torch.compile and thunderfx also fail, right?

@kshitij12345 (Collaborator, Author) commented Oct 10, 2025:

Works for thunderfx + TP.

torch.compile + TP fails with the following error irrespective of torch.inference_mode:

[rank1]:   File "/tmp/torchinductor_root/6d/c6dlxelqgtfg5yx3w6xwsyntdcsyogve3v5sxqwhd7t7b67b4rn6.py", line 1356, in call
[rank1]:     assert_size_stride(buf10, (1, 32, 64), (2048, 64, 1), 'torch.ops.aten.polar.default')
[rank1]: AssertionError: expected size 32==32, stride 1==64 at dim=1; expected size 64==64, stride 32==1 at dim=2
[rank1]: Error in op: torch.ops.aten.polar.default

Collaborator:
understood. that sounds more convoluted than I thought 😅

Collaborator:

Have you created an issue in the PyTorch repo about the inference_mode problem?

@t-vi added a no_grad decorator in another place in the benchmark script in 6dca11f. We need to consolidate the no_grad/inference_mode decorators in one place for readability.
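
One way to consolidate this (a sketch only; inference_context and the usage described below are placeholders, not the benchmark's actual API) is to pick the grad-disabling context once at the top-level generation call instead of decorating individual methods:

import torch

def inference_context(use_inference_mode: bool):
    # torch.inference_mode is the stricter, faster option, but the eager + TP
    # path currently hits "Cannot set version_counter for inference tensor",
    # so keep a single switch that can fall back to torch.no_grad.
    return torch.inference_mode() if use_inference_mode else torch.no_grad()

Hypothetical usage: wrap the benchmark's generate loop in a single with inference_context(...): block and drop the per-method decorators.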

@kshitij12345 (Collaborator, Author):

I haven't yet. Will file one today, thanks!

@kshitij12345 (Collaborator, Author):

> Have you created an issue in the PyTorch repo about the inference_mode problem?

I wasn't able to repro the error for torch.compile + TP on the latest container.

As for the RuntimeError: Cannot set version_counter for inference tensor when running with eager, I have a feeling we are probably missing something; I will take a look.

Collaborator:
Thank you for revisiting this!

Base automatically changed from enable-moe-tp-thunderfx to main October 11, 2025 18:07
@kshitij12345 kshitij12345 marked this pull request as ready for review October 11, 2025 19:37
@t-vi (Collaborator) left a comment.

@t-vi t-vi merged commit 1273cc4 into main Oct 13, 2025
53 of 54 checks passed
@t-vi t-vi deleted the thunderfx-tp-benchmark-inf branch October 13, 2025 09:46