308 changes: 265 additions & 43 deletions cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/DevKernel.cu
@@ -100,14 +100,136 @@ __global__ void activationKernel(KernelParams params)

////////////////////////////////////////////////////////////////////////////////////////////////////

struct Float4Max
{
__device__ __forceinline__ float4 operator()(float4 const& a, float4 const& b) const
{
float4 result;
result.x = fmaxf(a.x, b.x);
result.y = fmaxf(a.y, b.y);
result.z = fmaxf(a.z, b.z);
result.w = fmaxf(a.w, b.w);
return result;
}
};

struct Float2Max
{
__device__ __forceinline__ float2 operator()(float2 const& a, float2 const& b) const
{
float2 result;
result.x = fmaxf(a.x, b.x);
result.y = fmaxf(a.y, b.y);
return result;
}
};
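// These packed-max functors serve as the reduction operator for the
// cub::BlockReduce call below: one block-wide reduction takes the
// element-wise maximum across float4/float2 lanes, covering several
// tokens at once.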

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename VecType, int size>
__device__ __forceinline__ VecType packedTypeFromArray(float data[size])
{
return {};
}

template <>
__device__ __forceinline__ float4 packedTypeFromArray<float4, 4>(float data[4])
{
float4 result;
result.x = data[0];
result.y = data[1];
result.z = data[2];
result.w = data[3];
return result;
}

template <>
__device__ __forceinline__ float2 packedTypeFromArray<float2, 2>(float data[2])
{
float2 result;
result.x = data[0];
result.y = data[1];
return result;
}

template <>
__device__ __forceinline__ float packedTypeFromArray<float, 1>(float data[1])
{
return data[0];
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename PackedType, int size>
__device__ __forceinline__ cutlass::Array<float, size> arrayFromPackedType(PackedType data)
{
return cutlass::Array<float, size>{};
}

template <>
__device__ __forceinline__ cutlass::Array<float, 4> arrayFromPackedType<float4, 4>(float4 data)
{
return cutlass::Array<float, 4>{data.x, data.y, data.z, data.w};
}

template <>
__device__ __forceinline__ cutlass::Array<float, 2> arrayFromPackedType<float2, 2>(float2 data)
{
return cutlass::Array<float, 2>{data.x, data.y};
}

template <>
__device__ __forceinline__ cutlass::Array<float, 1> arrayFromPackedType<float, 1>(float data)
{
return cutlass::Array<float, 1>{data};
}
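// The packedTypeFromArray / arrayFromPackedType helpers convert between a
// per-token float array and the packed vector type, so the per-token
// absolute maxima can travel through a single BlockReduce invocation.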

////////////////////////////////////////////////////////////////////////////////////////////////////

template <int NUM_TOKENS_PER_CTA>
struct KernelTraits;

template <>
struct KernelTraits<4>
{
using MaxOp = Float4Max;
using PackedType = float4;
};

template <>
struct KernelTraits<2>
{
using MaxOp = Float2Max;
using PackedType = float2;
};

template <>
struct KernelTraits<1>
{
#if CUDA_VERSION >= 12090
using MaxOp = cuda::maximum<>;
#else
using MaxOp = cub::Max;
#endif
using PackedType = float;
};
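// KernelTraits maps the tokens-per-CTA count to the packed type and max
// functor used for the fused reduction; the single-token case falls back to
// a plain float max (cuda::maximum<> on CUDA 12.9+, cub::Max before that).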

////////////////////////////////////////////////////////////////////////////////////////////////////

constexpr int DEEP_SEEK_ACTIVATION_NUM_THREADS_PER_CTA = 128;
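// With 128 threads per CTA and one element per load, each CTA covers exactly
// one 128-element scale block of the output, so the block-wide max directly
// yields that block's dequantization scale.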

template <typename KernelParams>
__global__ void activationDeepSeekKernel(KernelParams params)
{
using Type = typename KernelParams::Type;
- using BlockReduce = cub::BlockReduce<float, 128>;
int32_t constexpr NumTokensPerCta = KernelParams::NumTokensPerCta;
using KernelTraits = KernelTraits<NumTokensPerCta>;
using MaxOp = typename KernelTraits::MaxOp;
using PackedType = typename KernelTraits::PackedType;
using BlockReduce = cub::BlockReduce<PackedType, DEEP_SEEK_ACTIVATION_NUM_THREADS_PER_CTA>;
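// The block reduction now operates on one packed value (up to four tokens'
// magnitudes) per invocation rather than a single float.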

- __shared__ float s_scaleOut;
- __shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ float s_scaleOutArr[NumTokensPerCta];
__shared__ typename BlockReduce::TempStorage tempStorage;
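// Thread 0 publishes the per-token output scales here so the whole CTA can
// read them after the barrier below.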

#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
// immediately trigger the secondary kernel when using PDL, then wait on primary
@@ -117,55 +239,124 @@ __global__ void activationDeepSeekKernel(KernelParams params)
cudaGridDependencySynchronize();
}
#endif

- // The largest (finite) value that can be represented using E4m3.
- float constexpr E4m3MaxVal{448.f};

int const totalNumPaddedTokens = params.totalNumPaddedTokens[0];
- // Loop over tokens
- for (int tokenIdx = blockIdx.z; tokenIdx < params.numTokens; tokenIdx += gridDim.z)
float scale1Arr[NumTokensPerCta];
float scale2Arr[NumTokensPerCta];
float dataX1Arr[NumTokensPerCta];
float dataX2Arr[NumTokensPerCta];
float outArr[NumTokensPerCta];
float absOutArr[NumTokensPerCta];
int permutedIdxArr[NumTokensPerCta];

// Loop over the top-K experts per token
for (int k = blockIdx.z; k < params.topK; k += gridDim.z)
{
- // Loop over experts per token
- for (int k = blockIdx.y; k < params.topK; k += gridDim.y)
for (int tokenCtaIdx = blockIdx.y * NumTokensPerCta; tokenCtaIdx < params.numTokens;
tokenCtaIdx += gridDim.y * NumTokensPerCta)
{
- int const expandedIdx = tokenIdx * params.topK + k;
- int const permutedIdx = params.expandedIdxToPermutedIdx[expandedIdx];

- // Needed for expert parallelism
- if (permutedIdx == -1)
- continue;

// Loop over hidden dim
for (int hiddenIdx = threadIdx.x + blockDim.x * blockIdx.x; hiddenIdx < params.innerDim / 2;
hiddenIdx += blockDim.x * gridDim.x)
{
- int const baseIdx = permutedIdx * params.innerDim + hiddenIdx;

- int const totalNumPaddedTokens = params.totalNumPaddedTokens[0];

- int const scale1_idx = permutedIdx + totalNumPaddedTokens * (hiddenIdx / 128);
- int const scale2_idx
-     = permutedIdx + totalNumPaddedTokens * ((hiddenIdx / 128) + (params.innerDim / 2 / 128));
- float const scale1 = params.inDqSfsPtr[scale1_idx];
- float const scale2 = params.inDqSfsPtr[scale2_idx];
#pragma unroll
for (int tokenInCtaIdx = 0; tokenInCtaIdx < NumTokensPerCta; tokenInCtaIdx++)
{
scale1Arr[tokenInCtaIdx] = 0.0f;
scale2Arr[tokenInCtaIdx] = 0.0f;
dataX1Arr[tokenInCtaIdx] = 0.0f;
dataX2Arr[tokenInCtaIdx] = 0.0f;
outArr[tokenInCtaIdx] = 0.0f;
absOutArr[tokenInCtaIdx] = 0.0f;
}
#pragma unroll
for (int tokenInCtaIdx = 0; tokenInCtaIdx < NumTokensPerCta; tokenInCtaIdx++)
{
int const tokenIdx = tokenCtaIdx + tokenInCtaIdx;
if (tokenIdx >= params.numTokens)
{
break;
}

- float x1 = scale1 * (float) params.inPtr[baseIdx];
- float x2 = scale2 * (float) params.inPtr[baseIdx + params.innerDim / 2];
int const expandedIdx = tokenIdx * params.topK + k;
int const permutedIdx = params.expandedIdxToPermutedIdx[expandedIdx];
permutedIdxArr[tokenInCtaIdx] = permutedIdx;
if (permutedIdx == -1)
{
continue;
}

// Compute the input and scale-factor offsets for this token
int const baseIdx = permutedIdx * params.innerDim + hiddenIdx;

int const scale1Idx = permutedIdx + totalNumPaddedTokens * (hiddenIdx / 128);
int const scale2Idx
= permutedIdx + totalNumPaddedTokens * ((hiddenIdx / 128) + (params.innerDim / 2 / 128));

scale1Arr[tokenInCtaIdx] = params.inDqSfsPtr[scale1Idx];
scale2Arr[tokenInCtaIdx] = params.inDqSfsPtr[scale2Idx];
dataX1Arr[tokenInCtaIdx] = static_cast<float>(params.inPtr[baseIdx]);
dataX2Arr[tokenInCtaIdx] = static_cast<float>(params.inPtr[baseIdx + params.innerDim / 2]);
}

- float act = silu(x2);
- float out = act * x1;
#pragma unroll
for (int tokenInCtaIdx = 0; tokenInCtaIdx < NumTokensPerCta; tokenInCtaIdx++)
{
float x1 = scale1Arr[tokenInCtaIdx] * dataX1Arr[tokenInCtaIdx];
float x2 = scale2Arr[tokenInCtaIdx] * dataX2Arr[tokenInCtaIdx];
float act = silu(x2);
float out = act * x1;
outArr[tokenInCtaIdx] = out;
absOutArr[tokenInCtaIdx] = fabsf(out);
}
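// Gated-SiLU (SwiGLU-style) activation: the second half of the inner
// dimension is passed through SiLU and gates the first half.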

// The largest (finite) value that can be represented using E4m3.
float constexpr E4m3MaxVal{448.f};
auto absOutPacked = packedTypeFromArray<PackedType, NumTokensPerCta>(absOutArr);
auto aMaxPacked = BlockReduce(tempStorage).Reduce(absOutPacked, MaxOp{});
auto aMaxArr = arrayFromPackedType<PackedType, NumTokensPerCta>(aMaxPacked);
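// A single block-wide reduction yields the absolute maximum for all
// NumTokensPerCta tokens at once; aMaxArr unpacks one value per token.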

- // Compute the absolute max
- float aMax = BlockReduce(temp_storage).Reduce(fabsf(out), cuda::maximum<>());
- if (threadIdx.x == 0)
#pragma unroll
for (int tokenInCtaIdx = 0; tokenInCtaIdx < NumTokensPerCta; tokenInCtaIdx++)
{
- s_scaleOut = aMax / E4m3MaxVal;
- int const scaleOut_idx = permutedIdx + totalNumPaddedTokens * (hiddenIdx / 128);
- params.outDqSfsPtr[scaleOut_idx] = aMax / E4m3MaxVal;
if (threadIdx.x == 0)
{
auto const tokenIdx = tokenCtaIdx + tokenInCtaIdx;
if (tokenIdx >= params.numTokens)
{
break;
}
int const permutedIdx = permutedIdxArr[tokenInCtaIdx];
if (permutedIdx == -1)
{
continue;
}
s_scaleOutArr[tokenInCtaIdx] = aMaxArr[tokenInCtaIdx] / E4m3MaxVal;
int const scaleOut_idx
= permutedIdxArr[tokenInCtaIdx] + totalNumPaddedTokens * (hiddenIdx / 128);
params.outDqSfsPtr[scaleOut_idx] = aMaxArr[tokenInCtaIdx] / E4m3MaxVal;
}
}
__syncthreads();
- float const scaleOut = s_scaleOut;
__syncthreads();
- int const outIdx = permutedIdx * (params.innerDim / 2) + hiddenIdx;
- params.outPtr[outIdx] = (Type) (out / scaleOut);

#pragma unroll
for (int tokenInCtaIdx = 0; tokenInCtaIdx < NumTokensPerCta; tokenInCtaIdx++)
{
auto const tokenIdx = tokenCtaIdx + tokenInCtaIdx;
if (tokenIdx >= params.numTokens)
{
break;
}
int const permutedIdx = permutedIdxArr[tokenInCtaIdx];
if (permutedIdx == -1)
{
continue;
}
float const scaleOut = s_scaleOutArr[tokenInCtaIdx];
int const outIdx = permutedIdx * (params.innerDim / 2) + hiddenIdx;
params.outPtr[outIdx] = static_cast<Type>(outArr[tokenInCtaIdx] / scaleOut);
}
}
}
}
@@ -185,17 +376,48 @@ void run(Data const& data, void* stream)

if (data.mUseDeepSeekFp8)
{
- int const numThreads = 128;
- const dim3 grid(data.innerDim / 128, data.topK, std::min(8192, data.numTokens));
constexpr int NUM_ELTS_PER_LOAD = 1;
constexpr int NUM_ELTS_PER_SF = 128;

int device{-1};
cudaGetDevice(&device);
int numSms = 0;
cudaDeviceGetAttribute(&numSms, cudaDevAttrMultiProcessorCount, device);

// Output dimension is innerDim / 2, and each scale block is 128 elements
int const outputDim = data.innerDim / 2;
int const numScaleBlocks = (outputDim + NUM_ELTS_PER_SF - 1) / NUM_ELTS_PER_SF;
int const gridSizeX = (numScaleBlocks + NUM_ELTS_PER_LOAD - 1) / NUM_ELTS_PER_LOAD;
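// For example (hypothetical sizes): innerDim = 4096 gives outputDim = 2048,
// i.e. 16 scale blocks, so gridSizeX = 16 with one element per load.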

auto numCtas = gridSizeX * data.numTokens * data.topK;
// FIXME: This heuristic is based on a very short benchmark.
int numTokensPerCta = 1;
if (numCtas > numSms * 32)
{
numTokensPerCta = 4;
}
else if (numCtas > numSms * 4)
{
numTokensPerCta = 2;
}
else
{
numTokensPerCta = 1;
}
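// Worked example (hypothetical sizes): with gridSizeX = 16, numTokens = 512,
// topK = 8 and numSms = 132, numCtas = 16 * 512 * 8 = 65536 > 32 * 132, so
// numTokensPerCta = 4 and gridSizeY = min(8192, ceil(512 / 4)) = 128.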

int const gridSizeY = std::min(8192, (data.numTokens + numTokensPerCta - 1) / numTokensPerCta);

const dim3 grid(gridSizeX, gridSizeY, data.topK);

- LAUNCH(data, activationDeepSeekKernel, grid, numThreads, 0, stream);
LAUNCH_ACTIVATION(
data, activationDeepSeekKernel, numTokensPerCta, grid, DEEP_SEEK_ACTIVATION_NUM_THREADS_PER_CTA, 0, stream);
}
else
{
int const numThreads = 256;
const dim3 grid(data.innerDim / 128, data.topK, std::min(8192, data.numTokens));

- LAUNCH(data, activationKernel, grid, numThreads, 0, stream);
LAUNCH_ACTIVATION(data, activationKernel, 1, grid, numThreads, 0, stream);
}
}
