122 changes: 122 additions & 0 deletions cpp/tensorrt_llm/kernels/communicationKernels/customLowPrecisionAllReduceKernels.h
@@ -0,0 +1,122 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/kernels/customAllReduceKernels.h"
#include <NvInferRuntime.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <vector>

namespace tensorrt_llm::kernels
{

constexpr int LP_ALLREDUCE_MAX_BLOCKS = 8;
constexpr int LP_ALLREDUCE_WARPSIZE = 32;
constexpr int LP_ALLREDUCE_DEFAULT_BLOCK_SIZE = 512;
constexpr int LP_ALLREDUCE_WARP_NUM_PER_BLOCK = 16;
constexpr int LP_ALLREDUCE_BYTES_PER_LOAD = 16;
constexpr int LP_ALLREDUCE_NUMA_NUM = 2;
constexpr int LP_ALLREDUCE_MAX_RANKS_PER_NUMA = 4;
constexpr int LP_ALLREDUCE_BUFFER_DUPLICATE = 16;
constexpr int LP_ALLREDUCE_BUFFER_CHUNKS = 8;
constexpr int LP_ALLREDUCE_HIER_STAGE_NUM = 3;
constexpr int LP_ALLREDUCE_RANKS_PER_NUMA = 4;
constexpr int LP_ALLREDUCE_MAX_ELTS_IN_WORKSPACE = 32 * 1024 * 1024;
constexpr int LP_ALLREDUCE_MIN_ELTS_THRESHOLD = 8 * 1024 * 1024;
constexpr int LP_ALLREDUCE_MAX_TP_SIZE = 8;
constexpr int LP_ALLREDUCE_MAX_RANKS_PER_NODE = 16;

struct StaticLowPrecisionBuffers
{
void* peer_comm_buffer_ptrs[LP_ALLREDUCE_MAX_TP_SIZE * 2];
uint64_t* peer_barrier_ptrs_in[LP_ALLREDUCE_MAX_TP_SIZE];
uint64_t* peer_barrier_ptrs_out[LP_ALLREDUCE_MAX_TP_SIZE];
int64_t* flag_ptr;
bool initialized = false;
size_t tpSize = 0;
};

void initialize_static_lowprecision_buffers(int64_t* buffer, size_t tpSize);

std::vector<size_t> splitNumber(size_t number);

struct LowPrecisionAllReduceParams
{
size_t elts_total;
size_t elts_per_rank;
size_t elts_per_block;
size_t rank_offset;
int32_t ranks_per_node, rank, local_rank;
uint64_t barrier_flag;
uint64_t* peer_barrier_ptrs_in[LP_ALLREDUCE_MAX_RANKS_PER_NODE];
uint64_t* peer_barrier_ptrs_out[LP_ALLREDUCE_MAX_RANKS_PER_NODE];
void* peer_comm_buffer_ptrs[LP_ALLREDUCE_MAX_RANKS_PER_NODE];
void* local_output_buffer_ptr;
void const* local_input_buffer_ptr;

// for low precision
size_t buffer_elts_per_rank;
size_t buffer_offset;

// for low precision hier
uint32_t num_rounds = 0;
uint32_t num_rounds_fence = 0;
uint32_t block_num = 0;
int32_t numa_rank = -1;

void* inputs_inside_numa[4];

void* rs_buffers[LP_ALLREDUCE_MAX_BLOCKS];
void* ar_buffers[LP_ALLREDUCE_MAX_BLOCKS];
void* ar_peer_buffers_cross_numa[LP_ALLREDUCE_MAX_BLOCKS];
void* ag_peer_buffers_inside_numa[LP_ALLREDUCE_MAX_BLOCKS * 4];

// for low precision hier handshake rs stage
uint64_t* rs_send_flags[LP_ALLREDUCE_MAX_BLOCKS];
uint64_t* rs_ack_flags[LP_ALLREDUCE_MAX_BLOCKS]; // 2*flags
uint64_t* rs_notify_local_flags[LP_ALLREDUCE_MAX_BLOCKS];
uint64_t* rs_notify_remote_flags[LP_ALLREDUCE_MAX_BLOCKS];

// for low precision hier handshake ar stage
uint64_t* ar_send_flags[LP_ALLREDUCE_MAX_BLOCKS];
uint64_t* ar_ack_peer_rs_flags[LP_ALLREDUCE_MAX_BLOCKS];
uint64_t* ar_ack_flags[LP_ALLREDUCE_MAX_BLOCKS];
uint64_t* ar_notify_rs_local_flags[LP_ALLREDUCE_MAX_BLOCKS];
uint64_t* ar_notify_rs_remote_flags[LP_ALLREDUCE_MAX_BLOCKS];
uint64_t* ar_notify_ag_flags[LP_ALLREDUCE_MAX_BLOCKS];

// for low precision hier handshake ag stage
uint64_t* ag_send_flags[LP_ALLREDUCE_MAX_BLOCKS];
uint64_t* ag_ack_peer_inside_numa_flags[LP_ALLREDUCE_MAX_BLOCKS]; // 3 flags per block, one for each other rank inside the NUMA domain
uint64_t* ag_notify_peer_inside_numa_flags[LP_ALLREDUCE_MAX_BLOCKS * 4]; // 3 flags per block, one for each other rank inside the NUMA domain

static LowPrecisionAllReduceParams deserialize(
size_t tpSize, size_t tpRank, nvinfer1::DataType dataType, int token_num, int hidden_size);
static LowPrecisionAllReduceParams deserialize_hier(
size_t tpSize, size_t tpRank, nvinfer1::DataType dataType, int token_num, int hidden_size);
};

bool lowPrecisionConfigurationSupported(size_t msg_size, size_t n_ranks);

void customLowPrecisionAllReduce(
kernels::LowPrecisionAllReduceParams& params, nvinfer1::DataType dataType, cudaStream_t stream);

int32_t max_workspace_size_lowprecision(int32_t tp_size);
} // namespace tensorrt_llm::kernels
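
For orientation, here is a minimal sketch of how the API declared above is expected to be driven; it mirrors the call pattern added to `allreduceOp.cpp` later in this diff. It assumes `initialize_static_lowprecision_buffers()` has already been called for this TP group, and the helper name and parameter list are illustrative, not part of the PR.

```cpp
// Reviewer sketch (not part of the PR): expected call sequence for the API above,
// mirroring runLowPrecisionAllReduce() in allreduceOp.cpp below.
#include "tensorrt_llm/kernels/communicationKernels/customLowPrecisionAllReduceKernels.h"

namespace tk = tensorrt_llm::kernels;

void lowPrecisionAllReduceSketch(void const* input, void* output, size_t numElts, size_t bytesPerElt,
    size_t tpSize, size_t tpRank, nvinfer1::DataType dtype, int tokenNum, int hiddenSize, cudaStream_t stream)
{
    size_t offset = 0;
    for (size_t part : tk::splitNumber(numElts)) // chunk the message into kernel-supported sizes
    {
        auto params = (tpSize <= 4)
            ? tk::LowPrecisionAllReduceParams::deserialize(tpSize, tpRank, dtype, tokenNum, hiddenSize)
            : tk::LowPrecisionAllReduceParams::deserialize_hier(tpSize, tpRank, dtype, tokenNum, hiddenSize);
        params.local_input_buffer_ptr = static_cast<char const*>(input) + offset * bytesPerElt;
        params.local_output_buffer_ptr = static_cast<char*>(output) + offset * bytesPerElt;
        params.elts_total = part;
        tk::customLowPrecisionAllReduce(params, dtype, stream);
        offset += part;
    }
}
```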
1 change: 1 addition & 0 deletions cpp/tensorrt_llm/kernels/customAllReduceKernels.h
@@ -51,6 +51,7 @@ enum class AllReduceStrategyType : int8_t
AUTO = 3,
ONESHOT = 4,
TWOSHOT = 5,
LOWPRECISION = 6,
};

enum class AllReduceStrategyConfig : int8_t
5 changes: 5 additions & 0 deletions cpp/tensorrt_llm/pybind/runtime/bindings.cpp
@@ -18,6 +18,7 @@
#include "bindings.h"
#include "tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.h"
#include "tensorrt_llm/kernels/communicationKernels/allReduceWorkspace.h"
#include "tensorrt_llm/kernels/communicationKernels/customLowPrecisionAllReduceKernels.h"
#include "tensorrt_llm/kernels/customAllReduceKernels.h"
#include "tensorrt_llm/kernels/delayStream.h"
#include "tensorrt_llm/runtime/cudaEvent.h"
@@ -393,6 +394,10 @@ void initBindings(pybind11::module_& m)
tensorrt_llm::kernels::invokeDelayStreamKernel(delay_micro_secs, stream);
},
"Delay kernel launch on the default stream");
m.def(
"max_workspace_size_lowprecision",
[](int32_t tp_size) { return tensorrt_llm::kernels::max_workspace_size_lowprecision(tp_size); },
"Calculate the maximum workspace size needed for low precision all-reduce operations");

py::enum_<tensorrt_llm::kernels::AllReduceFusionOp>(m, "AllReduceFusionOp")
.value("NONE", tensorrt_llm::kernels::AllReduceFusionOp::NONE)
119 changes: 119 additions & 0 deletions cpp/tensorrt_llm/thop/allreduceOp.cpp
@@ -20,6 +20,7 @@
#include "tensorrt_llm/common/dataType.h"
#include "tensorrt_llm/common/opUtils.h"
#include "tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.h"
#include "tensorrt_llm/kernels/communicationKernels/customLowPrecisionAllReduceKernels.h"
#include "tensorrt_llm/kernels/communicationKernels/moeAllReduceFusionKernels.h"
#include "tensorrt_llm/kernels/customAllReduceKernels.h"
#include "tensorrt_llm/kernels/internal_cutlass_kernels/include/fp4_gemm.h"
@@ -177,6 +178,8 @@ class AllreduceOp
case AllReduceStrategyType::ONESHOT:
case AllReduceStrategyType::TWOSHOT:
return runFusionAllReduce(input, residual, norm_weight, scale, bias, workspace, runtime_strategy);
case AllReduceStrategyType::LOWPRECISION:
return runLowPrecisionAllReduce(input, residual, norm_weight, scale, bias);
default: TORCH_CHECK(false, "Invalid runtime strategy"); return {};
}
}
@@ -296,6 +299,73 @@ class AllreduceOp
return fallbackRunSubsequentOps(input, residual, norm_weight, scale, bias, reduce_output);
}

std::vector<torch::Tensor> runLowPrecisionAllReduce(torch::Tensor const& input,
torch::optional<torch::Tensor> const& residual, torch::optional<torch::Tensor> const& norm_weight,
torch::optional<torch::Tensor> const& scale, torch::optional<torch::Tensor> const& bias) noexcept
{
#ifdef ENABLE_FP8
auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
int size = input.numel();
int hidden_size = input.size(-1);

auto const tp_size = mGroup.size();
auto const cur_rank = COMM_SESSION.getRank();
int tp_rank = 0;

for (auto const& currentRank : mGroup)
{
if (cur_rank == currentRank)
break;
++tp_rank;
}

int bytes_per_element = input.element_size();

int token_num = size / hidden_size;

auto parts = tensorrt_llm::kernels::splitNumber(size);

torch::Tensor reduce_output = torch::empty_like(input);

size_t global_offset = 0;
for (size_t i = 0; i < parts.size(); ++i)
{
size_t tmp_size = parts[i];
tensorrt_llm::kernels::LowPrecisionAllReduceParams tmp_param;
if (tp_size <= 4)
{
tmp_param = tensorrt_llm::kernels::LowPrecisionAllReduceParams::deserialize(
tp_size, tp_rank, mType, token_num, hidden_size);
}
else
{
tmp_param = tensorrt_llm::kernels::LowPrecisionAllReduceParams::deserialize_hier(
tp_size, tp_rank, mType, token_num, hidden_size);
}

tmp_param.local_input_buffer_ptr = reinterpret_cast<void const*>(
reinterpret_cast<char const*>(input.data_ptr()) + global_offset * bytes_per_element);
tmp_param.local_output_buffer_ptr = reinterpret_cast<void*>(
reinterpret_cast<char*>(reduce_output.mutable_data_ptr()) + global_offset * bytes_per_element);
tmp_param.elts_total = tmp_size;
tensorrt_llm::kernels::customLowPrecisionAllReduce(tmp_param, mType, stream);

global_offset += tmp_size;
}

if (mOp == AllReduceFusionOp::NONE)
{
return {reduce_output};
}

// Treat any other patterns as fallback cases.
return fallbackRunSubsequentOps(input, residual, norm_weight, scale, bias, reduce_output);

#else
C10_THROW_ERROR(NotImplementedError, "Can't use LOWPRECISION without compiling with ENABLE_FP8.");
#endif
}

std::vector<torch::Tensor> runFusionAllReduce(torch::Tensor const& input,
torch::optional<torch::Tensor> const& residual, torch::optional<torch::Tensor> const& norm_weight,
torch::optional<torch::Tensor> const& scale, torch::optional<torch::Tensor> const& bias,
@@ -594,6 +664,11 @@ class AllreduceOp
TLLM_LOG_DEBUG("AllReducePlugin strategy for rank %d: UB", rank);
break;
}
case AllReduceStrategyType::LOWPRECISION:
{
TLLM_LOG_DEBUG("AllReducePlugin strategy for rank %d: LOWPRECISION", rank);
break;
}
default: break;
}
}
@@ -766,7 +841,21 @@ class AllreduceOp
AllReduceStrategyType selectImplementation(
size_t seq_len, size_t message_size, int world_size, nvinfer1::DataType type) noexcept
{

if (isUsingLowPrecision(message_size))
{
return AllReduceStrategyType::LOWPRECISION;
}
else
{
if (mStrategy == AllReduceStrategyType::LOWPRECISION)
{
mStrategy = AllReduceStrategyType::AUTO;
}
}

// Check that heuristic is only applied when AUTO is set.
// Use Auto select
bool const is_auto = (mStrategy == AllReduceStrategyType::AUTO);
auto const message_size_bytes = message_size * tensorrt_llm::common::getDTypeSize(type);
auto const max_workspace_size
@@ -847,6 +936,24 @@ class AllreduceOp
return strategy;
}

bool isUsingLowPrecision(size_t message_size) const noexcept
{
static char* force_low_precision_allreduce_strategy_char
= std::getenv("FORCE_LOW_PRECISION_ALL_REDUCE_STRATEGY");
bool force_low_precision = (force_low_precision_allreduce_strategy_char != nullptr)
|| (mStrategy == AllReduceStrategyType::LOWPRECISION);

#ifdef ENABLE_FP8
// Use LowPrecision only when running over PCIe (no NVLink) with P2P support and the message size is at least 2MB
constexpr int LowPrecisionMinMessageSize = 2 * 1024 * 1024;
return force_low_precision && !mIsNVLINKSupported && mIsP2PSupported
&& message_size >= LowPrecisionMinMessageSize;
#else
// Low precision is not available when FP8 is not enabled
return false;
#endif
}

private:
std::set<int> mGroup;
bool mIsNVLINKSupported;
@@ -966,10 +1073,22 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)
"int rank,"
"int nranks,"
"float eps) -> Tensor[]");
m.def("initialize_static_lowprecision_buffers(Tensor workspace, int tp_size) -> Tensor[]");
}

TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
m.impl("allreduce", &torch_ext::allreduce);
m.impl("moe_allreduce", &torch_ext::moe_allreduce);
}

TORCH_LIBRARY_IMPL(trtllm, CPU, m)
{
m.impl("initialize_static_lowprecision_buffers",
[](at::Tensor const& workspace, int64_t tp_size)
{
tensorrt_llm::kernels::initialize_static_lowprecision_buffers(
reinterpret_cast<int64_t*>(workspace.data_ptr()), (int) tp_size);
return std::vector<at::Tensor>{};
});
}
1 change: 1 addition & 0 deletions cpp/tensorrt_llm/thop/thUtils.h
@@ -89,4 +89,5 @@ int nextPowerOfTwo(int v);
std::optional<float> getFloatEnv(char const* name);

cudaDataType_t convert_torch_dtype(torch::ScalarType dtype);

} // namespace torch_ext
65 changes: 65 additions & 0 deletions docs/source/advanced/lowprecision-pcie-allreduce.md
@@ -0,0 +1,65 @@
# Low-Precision-AllReduce

```{note}
This feature is optimized for PCIe-based GPU topologies and may affect model accuracy. Please evaluate precision impact for your specific workload.
```


TRT-LLM supports `low-precision-allreduce`, a communication optimization that accelerates AllReduce operations in PCIe-based GPU environments. This feature quantizes FP16/BF16 data to FP8 during network transmission, reducing communication volume and improving performance.

## Algorithm

The Low-Precision-AllReduce algorithm works by:
1. Quantizing input FP16/BF16 tensors to FP8 format before network transmission
2. Transmitting the quantized data through the network
3. Dequantizing received data back to the original precision
4. Performing the reduction operation

**Quantization details**: We use a "per-warp" quantization approach where each CUDA warp (32 threads) processes a batch of data. In each warp, 31 threads quantize FP16/BF16 values to FP8 e4m3 format (16 bytes per thread), while the last thread transmits a scalar value. This results in each warp collectively quantizing 496 elements plus one scalar at a time (see the sketch below).
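
A minimal sketch of this per-warp payload arithmetic (illustration only; the constant values mirror `LP_ALLREDUCE_WARPSIZE` and `LP_ALLREDUCE_BYTES_PER_LOAD` from the kernel header in this change, while the names here are hypothetical):

```cpp
// Per-warp payload layout for the FP8 path (illustration, not the kernel code).
constexpr int kWarpSize       = 32;                              // LP_ALLREDUCE_WARPSIZE
constexpr int kBytesPerThread = 16;                              // LP_ALLREDUCE_BYTES_PER_LOAD (one 16-byte access)
constexpr int kQuantThreads   = kWarpSize - 1;                   // 31 lanes carry quantized data
constexpr int kEltsPerWarp    = kQuantThreads * kBytesPerThread; // 31 * 16 = 496 FP8 (1-byte) values

// The remaining lane transmits the scalar used for dequantization, so each
// iteration a warp moves 496 quantized elements plus one scalar.
static_assert(kEltsPerWarp == 496, "per-warp quantized element count");
```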

In 8-GPU scenarios, this approach shifts the communication bottleneck from cross-NUMA QPI to the PCIe switch, resulting in better overall performance.

## Topology Requirements

![8x L20/L40s Node Architecture](images/8x_l20_L40S_node_architecture.png)

Low-Precision-AllReduce is specifically designed for the topology shown above, where:
- Each node contains 2 NUMA domains
- Each NUMA domain has 4 GPUs connected via a PCIe switch
- GPUs within the same NUMA domain communicate through that PCIe switch

**Important:** This optimization will not improve performance on other topologies (e.g., where each GPU is in a separate NUMA domain).
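
The kernels in this change encode this topology directly; the sketch below is a hedged restatement (the helper `useHierarchicalPath` is hypothetical, but the constants and the `tp_size <= 4` split match the code in this PR):

```cpp
// Constants from customLowPrecisionAllReduceKernels.h in this PR:
constexpr int kNumaPerNode  = 2; // LP_ALLREDUCE_NUMA_NUM
constexpr int kRanksPerNuma = 4; // LP_ALLREDUCE_RANKS_PER_NUMA

// allreduceOp.cpp builds the flat params (deserialize) for tp_size <= 4, i.e. all
// ranks behind one PCIe switch, and the hierarchical, NUMA-aware params
// (deserialize_hier) otherwise, e.g. tp_size == 8 spanning both NUMA domains.
bool useHierarchicalPath(int tpSize)
{
    return tpSize > kRanksPerNuma;
}
```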

## Usage

The Low-Precision-AllReduce algorithm can be enabled in two ways:

1. **Direct specification** in your code:
```
allreduce = AllReduce(mapping=mapping, strategy=AllReduceStrategy.LOWPRECISION)
```
2. **Environment variable control** with AUTO strategy:
```
# In your code
allreduce = AllReduce(mapping=mapping, strategy=AllReduceStrategy.AUTO)

# Set environment variable before running (in the shell)
export FORCE_LOW_PRECISION_ALL_REDUCE_STRATEGY=1
```

## Performance and Accuracy Considerations

Low-Precision-AllReduce reduces communication volume by using FP8 data format for transmission. This optimization:
- Improves performance for large message sizes in PCIe-based topologies
- May slightly reduce numerical precision
- Automatically falls back to other strategies when no performance benefit is expected (e.g., with NVLink or small messages); the exact condition is sketched below
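
The fallback condition added in this PR lives in `isUsingLowPrecision()` in `allreduceOp.cpp`; the following is a simplified restatement, not the exact implementation (the function name here is illustrative):

```cpp
// Simplified restatement of isUsingLowPrecision() from allreduceOp.cpp.
// LOWPRECISION is chosen only when it was requested (via the strategy or the
// FORCE_LOW_PRECISION_ALL_REDUCE_STRATEGY env var), the GPUs communicate over
// PCIe with P2P support but without NVLink, and the message is large enough.
constexpr size_t kLowPrecisionMinMessageSize = 2 * 1024 * 1024;

bool useLowPrecision(bool requested, bool nvlinkSupported, bool p2pSupported, size_t messageSize)
{
#ifdef ENABLE_FP8
    return requested && !nvlinkSupported && p2pSupported && messageSize >= kLowPrecisionMinMessageSize;
#else
    (void) requested; (void) nvlinkSupported; (void) p2pSupported; (void) messageSize;
    return false; // low precision is unavailable without FP8 support
#endif
}
```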

Users should evaluate the precision impact on their specific models and workloads.

## Environment Variables

- `FORCE_LOW_PRECISION_ALL_REDUCE_STRATEGY`: When set to `1`, forces the use of low-precision algorithm with AUTO strategy. If the algorithm determines it cannot provide performance benefits, it will automatically fall back to other strategies.

**Note**: If TensorRT-LLM is compiled without the `ENABLE_FP8` option, setting the low-precision AllReduce strategy has no effect.