Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -169,3 +169,15 @@ debug
.vscode/

example.py

# dynamically generated from template
/gptqmodel_ext/marlin/kernel_bf16_kfe2m1f.cu
/gptqmodel_ext/marlin/kernel_bf16_kfe4m3fn.cu
/gptqmodel_ext/marlin/kernel_bf16_ku4.cu
/gptqmodel_ext/marlin/kernel_bf16_ku4b8.cu
/gptqmodel_ext/marlin/kernel_bf16_ku8b128.cu
/gptqmodel_ext/marlin/kernel_fp16_kfe2m1f.cu
/gptqmodel_ext/marlin/kernel_fp16_kfe4m3fn.cu
/gptqmodel_ext/marlin/kernel_fp16_ku4.cu
/gptqmodel_ext/marlin/kernel_fp16_ku4b8.cu
/gptqmodel_ext/marlin/kernel_fp16_ku8b128.cu
30 changes: 20 additions & 10 deletions gptqmodel_ext/marlin/core/scalar_type.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,19 @@
// For TORCH_CHECK
#include <torch/library.h>

#include <cstring>
#include <type_traits>

namespace vllm {

// Reinterprets the object representation of `src` as a value of type `To`
// by copying its bytes with std::memcpy, instead of dereferencing a
// reinterpret_cast'ed pointer.
//
// Requirements (all enforced at compile time):
//   - `To` and `From` have the same size.
//   - `To` and `From` are trivially copyable; std::memcpy is only
//     well-defined for such types.
// `To` must also be default constructible, because the destination object is
// value-initialized before the bytes are copied into it.
//
// Returns a `To` whose bytes are identical to those of `src`.
template <typename To, typename From>
inline To bit_cast_like(const From& src) noexcept {
  static_assert(sizeof(To) == sizeof(From),
                "bit_cast_like requires source and destination to be the same size");
  static_assert(std::is_trivially_copyable_v<From>,
                "bit_cast_like requires a trivially copyable source type");
  static_assert(std::is_trivially_copyable_v<To>,
                "bit_cast_like requires a trivially copyable destination type");
  To dst{};
  std::memcpy(&dst, &src, sizeof(To));
  return dst;
}

//
// ScalarType can represent a wide range of floating point and integer types,
// in particular it can be used to represent sub-byte data types (something
Expand Down Expand Up @@ -208,30 +219,29 @@ class ScalarType {
// the exponent
uint64_t double_raw =
(max_mantissa << (52 - mantissa)) | (max_exponent_double << 52);

return *reinterpret_cast<double*>(&double_raw);
return bit_cast_like<double>(double_raw);
}

// Largest raw value of this type, before max() subtracts the bias.
// Floating point types yield the double from _floating_point_max();
// all other types yield the int64_t (1 << mantissa) - 1.
// The TORCH_CHECK only passes for types narrower than 64 bits, or exactly
// 64 bits and signed; anything else fails with the message below.
// NOTE(review): this span is rendered from a diff, so the signature and the
// TORCH_CHECK condition each appear twice (before and after the change).
constexpr std::variant<int64_t, double> _raw_max() const {
std::variant<int64_t, double> _raw_max() const {
if (is_floating_point()) {
return {_floating_point_max()};
} else {
TORCH_CHECK(size_bits() < 64 || size_bits() == 64 && is_signed(),
TORCH_CHECK(size_bits() < 64 || (size_bits() == 64 && is_signed()),
"Cannot represent max as a int64_t");
return {(int64_t(1) << mantissa) - 1};
}
}

constexpr std::variant<int64_t, double> _raw_min() const {
std::variant<int64_t, double> _raw_min() const {
if (is_floating_point()) {
TORCH_CHECK(is_signed(),
"We currently assume all floating point types are signed");
constexpr uint64_t sign_bit_double = (uint64_t(1) << 63);

double max = _floating_point_max();
uint64_t max_raw = *reinterpret_cast<uint64_t*>(&max);
uint64_t max_raw = bit_cast_like<uint64_t>(max);
uint64_t min_raw = max_raw | sign_bit_double;
return {*reinterpret_cast<double*>(&min_raw)};
return {bit_cast_like<double>(min_raw)};
} else {
TORCH_CHECK(!is_signed() || size_bits() <= 64,
"Cannot represent min as a int64_t");
Expand All @@ -249,15 +259,15 @@ class ScalarType {
public:
// Max representable value for this scalar type.
// (accounting for bias if there is one)
// Visits the variant returned by _raw_max() and subtracts `bias` from
// whichever alternative (int64_t or double) it holds; the result is wrapped
// back into a std::variant<int64_t, double>.
// NOTE(review): this span is rendered from a diff, so the signature appears
// twice (before and after the change).
constexpr std::variant<int64_t, double> max() const {
std::variant<int64_t, double> max() const {
return std::visit(
[this](auto x) -> std::variant<int64_t, double> { return {x - bias}; },
_raw_max());
}

// Min representable value for this scalar type.
// (accounting for bias if there is one)
constexpr std::variant<int64_t, double> min() const {
std::variant<int64_t, double> min() const {
return std::visit(
[this](auto x) -> std::variant<int64_t, double> { return {x - bias}; },
_raw_min());
Expand Down Expand Up @@ -349,4 +359,4 @@ static inline constexpr auto kFloat16 = kHalf;
static inline constexpr auto kBFloat16 = kFE8M7;

static inline constexpr auto kFloat16Id = kFloat16.id();
}; // namespace vllm
}; // namespace vllm
File renamed without changes.