Skip to content

Commit

Permalink
Merge pull request #84 from HansKristian-Work/spv-khr-integer-dot-pro…
Browse files Browse the repository at this point in the history
…duct

Add support for SPV_KHR_integer_dot_product.
  • Loading branch information
HansKristian-Work committed Oct 5, 2021
2 parents bff5ab1 + 7b4c1fe commit d53f231
Show file tree
Hide file tree
Showing 12 changed files with 344 additions and 21 deletions.
5 changes: 5 additions & 0 deletions dxil_converter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4954,6 +4954,10 @@ void Converter::Impl::set_option(const OptionBase &cap)
break;
}

case Option::ShaderI8Dot:
options.shader_i8_dot_enabled = static_cast<const OptionShaderI8Dot &>(cap).supported;
break;

default:
break;
}
Expand Down Expand Up @@ -5000,6 +5004,7 @@ bool Converter::recognizes_option(Option cap)
case Option::StorageInputOutput16:
case Option::DescriptorQA:
case Option::MinPrecisionNative16Bit:
case Option::ShaderI8Dot:
return true;

default:
Expand Down
13 changes: 12 additions & 1 deletion dxil_converter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,8 @@ enum class Option : uint32_t
BindlessOffsetBufferLayout = 13,
StorageInputOutput16 = 14,
DescriptorQA = 15,
MinPrecisionNative16Bit = 16
MinPrecisionNative16Bit = 16,
ShaderI8Dot = 17
};

enum class ResourceClass : uint32_t
Expand Down Expand Up @@ -391,6 +392,16 @@ struct OptionMinPrecisionNative16Bit : OptionBase
bool enabled = false;
};

// Option payload for Option::ShaderI8Dot.
// Converter::Impl::set_option() copies `supported` into options.shader_i8_dot_enabled.
struct OptionShaderI8Dot : OptionBase
{
	// Defaults to off; the caller sets it before handing the option to the converter.
	bool supported = false;

	OptionShaderI8Dot() : OptionBase(Option::ShaderI8Dot) {}
};

struct DescriptorTableEntry
{
ResourceClass type;
Expand Down
10 changes: 10 additions & 0 deletions dxil_spirv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ static void print_help()
"\t[--vertex-input semantic location]\n"
"\t[--stream-output semantic index offset stride buffer-index]\n"
"\t[--enable-shader-demote]\n"
"\t[--enable-shader-i8-dot]\n"
"\t[--enable-dual-source-blending]\n"
"\t[--bindless]\n"
"\t[--no-bda]\n"
Expand Down Expand Up @@ -157,6 +158,7 @@ struct Arguments
bool emit_asm = false;
bool validate = false;
bool shader_demote = false;
bool shader_i8_dot = false;
bool dual_source_blending = false;
bool debug_all_entry_points = false;
bool storage_input_output_16bit = false;
Expand Down Expand Up @@ -530,6 +532,7 @@ int main(int argc, char **argv)
remapper.stream_outputs.push_back({ std::string(sem), index, offset, stride, buffer_index });
});
cbs.add("--enable-shader-demote", [&](CLIParser &) { args.shader_demote = true; });
cbs.add("--enable-shader-i8-dot", [&](CLIParser &parser) { args.shader_i8_dot = true; });
cbs.add("--enable-dual-source-blending", [&](CLIParser &) { args.dual_source_blending = true; });
cbs.add("--bindless", [&](CLIParser &) {
remapper.bindless = true;
Expand Down Expand Up @@ -734,6 +737,13 @@ int main(int argc, char **argv)
dxil_spv_converter_add_option(converter, &helper.base);
}

if (args.shader_i8_dot)
{
const dxil_spv_option_shader_i8_dot helper = { { DXIL_SPV_OPTION_SHADER_I8_DOT },
DXIL_SPV_TRUE };
dxil_spv_converter_add_option(converter, &helper.base);
}

if (args.dual_source_blending)
{
const dxil_spv_option_dual_source_blending helper = { { DXIL_SPV_OPTION_DUAL_SOURCE_BLENDING }, DXIL_SPV_TRUE };
Expand Down
9 changes: 9 additions & 0 deletions dxil_spirv_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -794,6 +794,15 @@ dxil_spv_result dxil_spv_converter_add_option(dxil_spv_converter converter, cons
break;
}

case DXIL_SPV_OPTION_SHADER_I8_DOT:
{
OptionShaderI8Dot helper;
helper.supported = bool(reinterpret_cast<const dxil_spv_option_shader_i8_dot *>(option)->supported);

converter->options.emplace_back(duplicate(helper));
break;
}

default:
return DXIL_SPV_ERROR_UNSUPPORTED_FEATURE;
}
Expand Down
9 changes: 8 additions & 1 deletion dxil_spirv_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ extern "C" {
#endif

/* Library API version. The minor version is defined exactly once: a second,
 * conflicting definition of the same macro is ill-formed and makes the
 * effective value depend on definition order. */
#define DXIL_SPV_API_VERSION_MAJOR 2
#define DXIL_SPV_API_VERSION_MINOR 11
#define DXIL_SPV_API_VERSION_PATCH 0

#define DXIL_SPV_DESCRIPTOR_QA_INTERFACE_VERSION 1
Expand Down Expand Up @@ -277,6 +277,7 @@ typedef enum dxil_spv_option
DXIL_SPV_OPTION_STORAGE_INPUT_OUTPUT_16BIT = 14,
DXIL_SPV_OPTION_DESCRIPTOR_QA = 15,
DXIL_SPV_OPTION_MIN_PRECISION_NATIVE_16BIT = 16,
DXIL_SPV_OPTION_SHADER_I8_DOT = 17,
DXIL_SPV_OPTION_INT_MAX = 0x7fffffff
} dxil_spv_option;

Expand Down Expand Up @@ -408,6 +409,12 @@ typedef struct dxil_spv_option_min_precision_native_16bit
dxil_spv_bool enabled;
} dxil_spv_option_min_precision_native_16bit;

/* Option payload for DXIL_SPV_OPTION_SHADER_I8_DOT.
 * `base` must stay the first member: dxil_spv_converter_add_option() receives a
 * pointer to it and casts back to this struct to read `supported`. */
struct dxil_spv_option_shader_i8_dot
{
	dxil_spv_option_base base;
	dxil_spv_bool supported;
};
typedef struct dxil_spv_option_shader_i8_dot dxil_spv_option_shader_i8_dot;

/* Gets the ABI version used to build this library. Used to detect API/ABI mismatches. */
DXIL_SPV_PUBLIC_API void dxil_spv_get_version(unsigned *major, unsigned *minor, unsigned *patch);

Expand Down
1 change: 1 addition & 0 deletions opcodes/converter_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,7 @@ struct Converter::Impl
bool descriptor_qa_enabled = false;
bool descriptor_qa_sink_handles = true;
bool min_precision_prefer_native_16bit = false;
bool shader_i8_dot_enabled = false;
} options;

struct BindlessInfo
Expand Down
53 changes: 36 additions & 17 deletions opcodes/dxil/dxil_arithmetic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -409,31 +409,50 @@ static spv::Id build_bfe(Converter::Impl &impl, spv::Id value_id, unsigned offse

// Lowers a packed 4x8-bit integer dot-product-accumulate call to SPIR-V.
//
// Operand 1 of the call is the accumulator; operands 2 and 3 are the two packed
// 32-bit inputs. sign_extend selects signed (true) or unsigned (false) handling
// of the 8-bit lanes. Always returns true.
//
// Two lowerings exist:
//  - options.shader_i8_dot_enabled: a single OpSDotKHR / OpUDotKHR from
//    SPV_KHR_integer_dot_product followed by an OpIAdd with the accumulator.
//  - otherwise: four bitfield-extract + multiply + add steps.
bool emit_i8_dot_instruction(Converter::Impl &impl, const llvm::CallInst *instruction, bool sign_extend)
{
	auto &builder = impl.builder();
	spv::Id acc = impl.get_id_for_value(instruction->getOperand(1));
	spv::Id a = impl.get_id_for_value(instruction->getOperand(2));
	spv::Id b = impl.get_id_for_value(instruction->getOperand(3));

	if (impl.options.shader_i8_dot_enabled)
	{
		builder.addExtension("SPV_KHR_integer_dot_product");
		builder.addCapability(spv::CapabilityDotProductKHR);
		builder.addCapability(spv::CapabilityDotProductInput4x8BitPackedKHR);

		// Not supposed to saturate, so use the plain dot opcode and accumulate separately.
		auto *dot_op = impl.allocate(sign_extend ? spv::OpSDotKHR : spv::OpUDotKHR, builder.makeUintType(32));
		dot_op->add_id(a);
		dot_op->add_id(b);
		dot_op->add_literal(spv::PackedVectorFormatPackedVectorFormat4x8BitKHR);
		impl.add(dot_op);

		// The final add is allocated against the call instruction itself, so no
		// rewrite_value() is issued on this path.
		// NOTE(review): presumably allocate(op, instruction) binds the result id to
		// the instruction's value - confirm against Converter::Impl::allocate.
		auto *acc_op = impl.allocate(spv::OpIAdd, instruction);
		acc_op->add_id(acc);
		acc_op->add_id(dot_op->id);
		impl.add(acc_op);
	}
	else
	{
		// Fallback: process the four 8-bit lanes one at a time, chaining the
		// running sum through acc.
		for (unsigned i = 0; i < 4; i++)
		{
			spv::Id a_component = build_bfe(impl, a, 8 * i, 8, sign_extend);
			spv::Id b_component = build_bfe(impl, b, 8 * i, 8, sign_extend);
			auto *mul_op = impl.allocate(spv::OpIMul, builder.makeUintType(32));
			mul_op->add_id(a_component);
			mul_op->add_id(b_component);
			impl.add(mul_op);

			auto *add_op = impl.allocate(spv::OpIAdd, builder.makeUintType(32));
			add_op->add_id(acc);
			add_op->add_id(mul_op->id);
			acc = add_op->id;
			impl.add(add_op);
		}

		// Map the call's value to the final accumulated id.
		impl.rewrite_value(instruction, acc);
	}

	return true;
}

Expand Down
Loading

0 comments on commit d53f231

Please sign in to comment.