From 6bf9480c96abdb0b10d5455960928fd54a9a00fc Mon Sep 17 00:00:00 2001 From: Qubitium Date: Wed, 15 Oct 2025 00:15:46 +0000 Subject: [PATCH 1/2] rename Signed-off-by: Qubitium --- gptqmodel_ext/marlin/core/scalar_type.hpp | 30 ++++++++++++------- ...lkit.sh => sync_cuda_toolkit_with_torch.sh | 0 2 files changed, 20 insertions(+), 10 deletions(-) rename auto_switch_cuda_toolkit.sh => sync_cuda_toolkit_with_torch.sh (100%) diff --git a/gptqmodel_ext/marlin/core/scalar_type.hpp b/gptqmodel_ext/marlin/core/scalar_type.hpp index cfd8894a0..97078169d 100644 --- a/gptqmodel_ext/marlin/core/scalar_type.hpp +++ b/gptqmodel_ext/marlin/core/scalar_type.hpp @@ -3,8 +3,19 @@ // For TORCH_CHECK #include +#include + namespace vllm { +template +inline To bit_cast_like(const From& src) noexcept { + static_assert(sizeof(To) == sizeof(From), + "bit_cast_like requires source and destination to be the same size"); + To dst{}; + std::memcpy(&dst, &src, sizeof(To)); + return dst; +} + // // ScalarType can represent a wide range of floating point and integer types, // in particular it can be used to represent sub-byte data types (something @@ -208,30 +219,29 @@ class ScalarType { // the exponent uint64_t double_raw = (max_mantissa << (52 - mantissa)) | (max_exponent_double << 52); - - return *reinterpret_cast(&double_raw); + return bit_cast_like(double_raw); } - constexpr std::variant _raw_max() const { + std::variant _raw_max() const { if (is_floating_point()) { return {_floating_point_max()}; } else { - TORCH_CHECK(size_bits() < 64 || size_bits() == 64 && is_signed(), + TORCH_CHECK(size_bits() < 64 || (size_bits() == 64 && is_signed()), "Cannot represent max as a int64_t"); return {(int64_t(1) << mantissa) - 1}; } } - constexpr std::variant _raw_min() const { + std::variant _raw_min() const { if (is_floating_point()) { TORCH_CHECK(is_signed(), "We currently assume all floating point types are signed"); constexpr uint64_t sign_bit_double = (uint64_t(1) << 63); double max = _floating_point_max(); - uint64_t max_raw = *reinterpret_cast(&max); + uint64_t max_raw = bit_cast_like(max); uint64_t min_raw = max_raw | sign_bit_double; - return {*reinterpret_cast(&min_raw)}; + return {bit_cast_like(min_raw)}; } else { TORCH_CHECK(!is_signed() || size_bits() <= 64, "Cannot represent min as a int64_t"); @@ -249,7 +259,7 @@ class ScalarType { public: // Max representable value for this scalar type. // (accounting for bias if there is one) - constexpr std::variant max() const { + std::variant max() const { return std::visit( [this](auto x) -> std::variant { return {x - bias}; }, _raw_max()); @@ -257,7 +267,7 @@ class ScalarType { // Min representable value for this scalar type. // (accounting for bias if there is one) - constexpr std::variant min() const { + std::variant min() const { return std::visit( [this](auto x) -> std::variant { return {x - bias}; }, _raw_min()); @@ -349,4 +359,4 @@ static inline constexpr auto kFloat16 = kHalf; static inline constexpr auto kBFloat16 = kFE8M7; static inline constexpr auto kFloat16Id = kFloat16.id(); -}; // namespace vllm \ No newline at end of file +}; // namespace vllm diff --git a/auto_switch_cuda_toolkit.sh b/sync_cuda_toolkit_with_torch.sh similarity index 100% rename from auto_switch_cuda_toolkit.sh rename to sync_cuda_toolkit_with_torch.sh From e03749aad951e7e0497f2971c44ceffa3bd17a54 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Wed, 15 Oct 2025 00:27:48 +0000 Subject: [PATCH 2/2] git ignore dynamic kernels from template Signed-off-by: Qubitium --- .gitignore | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/.gitignore b/.gitignore index 0ca338d14..c1482d225 100644 --- a/.gitignore +++ b/.gitignore @@ -169,3 +169,15 @@ debug .vscode/ example.py + +# dynamically generated from template +/gptqmodel_ext/marlin/kernel_bf16_kfe2m1f.cu +/gptqmodel_ext/marlin/kernel_bf16_kfe4m3fn.cu +/gptqmodel_ext/marlin/kernel_bf16_ku4.cu +/gptqmodel_ext/marlin/kernel_bf16_ku4b8.cu +/gptqmodel_ext/marlin/kernel_bf16_ku8b128.cu +/gptqmodel_ext/marlin/kernel_fp16_kfe2m1f.cu +/gptqmodel_ext/marlin/kernel_fp16_kfe4m3fn.cu +/gptqmodel_ext/marlin/kernel_fp16_ku4.cu +/gptqmodel_ext/marlin/kernel_fp16_ku4b8.cu +/gptqmodel_ext/marlin/kernel_fp16_ku8b128.cu