Consider instruction input size when considering embedded mask optimization #115074

Closed (wants to merge 2 commits)
8 changes: 4 additions & 4 deletions src/coreclr/jit/instrsxarch.h
@@ -864,13 +864,13 @@ INST3(vcvttps2uqq, "cvttps2uqq", IUM_WR, BAD_CODE, BAD_
INST3(vcvtuqq2pd, "cvtuqq2pd", IUM_WR, BAD_CODE, BAD_CODE, SSEFLT(0x7A), INS_TT_FULL, Input_64Bit | REX_W1 | Encoding_EVEX | INS_Flags_EmbeddedBroadcastSupported) // cvt packed signed QWORDs to doubles
INST3(vcvtuqq2ps, "cvtuqq2ps", IUM_WR, BAD_CODE, BAD_CODE, SSEDBL(0x7A), INS_TT_FULL, Input_64Bit | REX_W1 | Encoding_EVEX | INS_Flags_EmbeddedBroadcastSupported) // cvt packed signed QWORDs to singles
INST3(vextractf32x8, "extractf32x8", IUM_WR, SSE3A(0x1B), BAD_CODE, BAD_CODE, INS_TT_TUPLE8, Input_32Bit | REX_W0 | Encoding_EVEX) // Extract 256-bit packed double-precision floating point values
-INST3(vextractf64x2, "extractf64x2", IUM_WR, SSE3A(0x19), BAD_CODE, BAD_CODE, INS_TT_TUPLE2, Input_64Bit | REX_W1 | Encoding_EVEX) // Extract 256-bit packed double-precision floating point values
+INST3(vextractf64x2, "extractf64x2", IUM_WR, SSE3A(0x19), BAD_CODE, BAD_CODE, INS_TT_TUPLE2, Input_64Bit | REX_W1 | Encoding_EVEX) // Extract 128-bit packed double-precision floating point values
INST3(vextracti32x8, "extracti32x8", IUM_WR, SSE3A(0x3B), BAD_CODE, BAD_CODE, INS_TT_TUPLE8, Input_32Bit | REX_W0 | Encoding_EVEX) // Extract 256-bit packed quadword integer values
-INST3(vextracti64x2, "extracti64x2", IUM_WR, SSE3A(0x39), BAD_CODE, BAD_CODE, INS_TT_TUPLE2, Input_64Bit | REX_W1 | Encoding_EVEX) // Extract 256-bit packed quadword integer values
+INST3(vextracti64x2, "extracti64x2", IUM_WR, SSE3A(0x39), BAD_CODE, BAD_CODE, INS_TT_TUPLE2, Input_64Bit | REX_W1 | Encoding_EVEX) // Extract 128-bit packed quadword integer values
INST3(vinsertf32x8, "insertf32x8", IUM_WR, BAD_CODE, BAD_CODE, SSE3A(0x1A), INS_TT_TUPLE8, Input_32Bit | REX_W0 | Encoding_EVEX | INS_Flags_IsDstDstSrcAVXInstruction) // Insert 256-bit packed double-precision floating point values
-INST3(vinsertf64x2, "insertf64x2", IUM_WR, BAD_CODE, BAD_CODE, SSE3A(0x18), INS_TT_TUPLE2, Input_64Bit | REX_W1 | Encoding_EVEX | INS_Flags_IsDstDstSrcAVXInstruction) // Insert 256-bit packed double-precision floating point values
+INST3(vinsertf64x2, "insertf64x2", IUM_WR, BAD_CODE, BAD_CODE, SSE3A(0x18), INS_TT_TUPLE2, Input_64Bit | REX_W1 | Encoding_EVEX | INS_Flags_IsDstDstSrcAVXInstruction) // Insert 128-bit packed double-precision floating point values
INST3(vinserti32x8, "inserti32x8", IUM_WR, BAD_CODE, BAD_CODE, SSE3A(0x3A), INS_TT_TUPLE8, Input_32Bit | REX_W0 | Encoding_EVEX | INS_Flags_IsDstDstSrcAVXInstruction) // Insert 256-bit packed quadword integer values
-INST3(vinserti64x2, "inserti64x2", IUM_WR, BAD_CODE, BAD_CODE, SSE3A(0x38), INS_TT_TUPLE2, Input_64Bit | REX_W1 | Encoding_EVEX | INS_Flags_IsDstDstSrcAVXInstruction) // Insert 256-bit packed quadword integer values
+INST3(vinserti64x2, "inserti64x2", IUM_WR, BAD_CODE, BAD_CODE, SSE3A(0x38), INS_TT_TUPLE2, Input_64Bit | REX_W1 | Encoding_EVEX | INS_Flags_IsDstDstSrcAVXInstruction) // Insert 128-bit packed quadword integer values
INST3(vpcmpd, "pcmpd", IUM_WR, BAD_CODE, BAD_CODE, SSE3A(0x1F), INS_TT_FULL, Input_32Bit | REX_W0 | Encoding_EVEX | INS_Flags_Is3OperandInstructionMask | INS_Flags_EmbeddedBroadcastSupported)
INST3(vpcmpq, "pcmpq", IUM_WR, BAD_CODE, BAD_CODE, SSE3A(0x1F), INS_TT_FULL, Input_64Bit | REX_W1 | Encoding_EVEX | INS_Flags_Is3OperandInstructionMask | INS_Flags_EmbeddedBroadcastSupported)
INST3(vpcmpud, "pcmpud", IUM_WR, BAD_CODE, BAD_CODE, SSE3A(0x1E), INS_TT_FULL, Input_32Bit | REX_W0 | Encoding_EVEX | INS_Flags_Is3OperandInstructionMask | INS_Flags_EmbeddedBroadcastSupported)
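
An aside for context (not part of the diff): the corrected comments describe the fragment size implied by the tuple type, e.g. TUPLE2 of 64-bit elements is a 128-bit fragment, not a 256-bit one. A minimal C# sketch, assuming the standard Avx512DQ mappings for these instructions:

    using System.Runtime.Intrinsics;
    using System.Runtime.Intrinsics.X86;

    static class ExtractExamples
    {
        // vextractf64x2: TUPLE2 of 64-bit elements, i.e. a 128-bit fragment.
        public static Vector128<double> Top128(Vector512<double> v) => Avx512DQ.ExtractVector128(v, 3);

        // vextractf32x8: TUPLE8 of 32-bit elements, i.e. a 256-bit fragment.
        public static Vector256<float> Top256(Vector512<float> v) => Avx512DQ.ExtractVector256(v, 1);
    }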
33 changes: 28 additions & 5 deletions src/coreclr/jit/lowerxarch.cpp
@@ -10580,7 +10580,7 @@ void Lowering::ContainCheckHWIntrinsic(GenTreeHWIntrinsic* node)
//
// The managed API surface we expose doesn't directly support TYP_MASK
// and we don't directly expose overloads for APIs like `vaddps` which
-// support embedded masking. Instead, we have decide to do pattern
+// support embedded masking. Instead, we have decided to do pattern
// recognition over the relevant ternary select APIs which functionally
// execute `cond ? selectTrue : selectFalse` on a per element basis.
//
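
As an illustrative aside (not part of the diff), the recognized managed pattern looks roughly like the following; when the mask and operation element sizes agree, lowering can fold the select into the arithmetic instruction as an embedded mask instead of emitting a separate blend:

    using System.Runtime.Intrinsics;
    using System.Runtime.Intrinsics.X86;

    static class EmbeddedMaskPattern
    {
        public static Vector512<float> MaskedAdd(
            Vector512<float> cond, Vector512<float> x, Vector512<float> y, Vector512<float> selectFalse)
        {
            // Per element: cond ? (x + y) : selectFalse. Eligible to lower to
            // vaddps with an embedded {k} mask rather than a separate blend.
            return Vector512.ConditionalSelect(cond, Avx512F.Add(x, y), selectFalse);
        }
    }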
@@ -10605,14 +10605,37 @@ void Lowering::ContainCheckHWIntrinsic(GenTreeHWIntrinsic* node)
                // TODO-AVX512-CQ: Ensure we can support embedded operations on RMW intrinsics
                isEmbeddedMask = false;
            }
+           else
+           {
+               uint32_t  maskSize        = genTypeSize(simdBaseType);
+               var_types op2SimdBaseType = op2->AsHWIntrinsic()->GetSimdBaseType();
+               uint32_t  operSize        = genTypeSize(op2SimdBaseType);
+
+               if (maskSize != operSize)
+               {
+                   isEmbeddedMask = false;
+               }
+               else
+               {
+                   // Check the op2 instruction input size to see if it's the same as the
+                   // mask size.
+
+                   NamedIntrinsic op2IntrinsicId = op2->AsHWIntrinsic()->GetHWIntrinsicId();
+                   instruction    ins            = HWIntrinsicInfo::lookupIns(op2IntrinsicId, op2SimdBaseType);
+                   assert(ins != INS_invalid);
+                   unsigned inputSize = CodeGenInterface::instInputSize(ins);
+                   if (maskSize != inputSize)
+                   {
+                       isEmbeddedMask = false;
+                   }
+               }
+           }
Comment on lines +10610 to +10633
Member

Carrying over @saucecontrol's comment

> Both vpmaddwd and vpmaddubsw are widening instructions, and their embedded mask size is the same as the (widened) result size. We have the input size set to... well... the input size, rather than the result size, in the table. Since neither of those instructions supports embedded broadcast (they're both tuple type Full Mem), input size doesn't mean anything in terms of load containment, so I think it would be more correct to change them to match the mask element size.

Basically what's being said is that for most instructions the size and type of the input/output is the same. So if it takes in a V128, it returns a V128. If it takes in 8-bit elements, it produces 8-bit elements. However, some instructions deviate from this and so may take V128 and return V256 or may take 8-bit elements and produce 16-bit elements.

For the purposes of loads, this is always an input and so INS_TT_* (tuple type) is sufficient for determining if embedded broadcast can be done. On the other hand loads need INS_TT_* + Input_* to determine if containment can be done (as some tuple types, like TUPLE8 indicate that 8 * Input_* elements are loaded).

For the purposes of embedded masking, however, it is always the output that is impacted, and so we instead need some Output_* to ensure that instructions like vpmaddwd which take in V128<ushort> (Input_16bit) but return V128<uint> (Output_32bit) can have the correct mask type.


The simpler immediate fix here is to mark InsertVector128 with HW_Flag_NormalizeSmallTypeToInt like we're doing for Sse2_And or other operations that don't have Input_8bit or Input_16bit variants.

However, it won't necessarily get everything and I think we need to add Output_* in a separate PR and then consume it here in lowering to ensure everything is correct.

Notably, the meaning of certain fields on GenTreeHWIntrinsic is not entirely consistent today either. Unlike gtType, which is always the type the node produces, gtSimdSize/gtSimdBaseJitType sometimes describe the input and sometimes the output. We should probably normalize this to reduce issues as well.

Namely, since gtType will typically be a TYP_SIMD and therefore already tracks the output simd size, and since C# APIs cannot overload on return type but we do sometimes need a type for "overload resolution" purposes when picking the right instruction, I think we should guarantee that gtSimdSize and gtSimdBaseJitType track the input information (so we always know it regardless of whether the operand is a local, a bitcast vector that produced a different base type, etc.). We should then have gtAuxiliaryJitType always be the base type of the output (maybe we should rename them gtSimdInputBaseJitType and gtSimdOutputBaseJitType, respectively).

Under such a setup we can then safely use gtSimdOutputBaseJitType to determine if a mask can be used and gtSimdInputBaseJitType to determine if broadcast can be done. This also gives a clean mapping to instructions and a way to assert that gtSimdInputBaseJitType matches Input_*bit and gtSimdOutputBaseJitType matches Output_*bit for the selected instruction.
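
(A concrete sketch of the widening case discussed above, illustrative only: vpmaddwd is exposed as MultiplyAddAdjacent, which reads 16-bit elements but writes 32-bit elements, so an embedded mask must apply to 32-bit lanes.)

    using System.Runtime.Intrinsics;
    using System.Runtime.Intrinsics.X86;

    static class WideningExample
    {
        public static Vector512<int> PairwiseMultiplyAdd(Vector512<short> a, Vector512<short> b)
        {
            // vpmaddwd: Input_16Bit in the table, but the result elements are
            // 32-bit, so a mask element size derived from the input is wrong here.
            return Avx512BW.MultiplyAddAdjacent(a, b);
        }
    }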

Contributor Author (@BruceForstall, May 5, 2025)

Thanks for the detailed reply.

Can you clarify one thing? You wrote:

> For the purposes of loads, this is always an input and so INS_TT_* (tuple type) is sufficient for determining if embedded broadcast can be done. On the other hand loads need INS_TT_* + Input_* to determine if containment can be done (as some tuple types, like TUPLE8 indicate that 8 * Input_* elements are loaded).

Namely, "For the purposes of loads...", then "On the other hand loads...'. Is one of those supposed to be different (i.e., not "loads")?

The idea to add Output_* makes sense. How would we add the correct Output_* for each instruction (semi-automatically)? Presumably we could start with explicitly setting Output_* to Input_*, or else assuming an empty/non-existent Output_* means use the Input_* value for the Output_* value; then, just add explicit Output_* values for those that differ.

Member

Namely, "For the purposes of loads...", then "On the other hand loads...'. Is one of those supposed to be different (i.e., not "loads")?

It was meant to be two parts: the first covers loads that can be represented as embedded broadcasts, and the second covers loads that can be represented as regular containment. All load containment needs to consider at least INS_TT_*, while "regular" containment additionally needs to consider Input_*.

> The idea to add Output_* makes sense. How would we add the correct Output_* for each instruction (semi-automatically)?

I think the easiest default would be to define the 4 new Output_* flags and do a regex find/replace from Input_(8Bit|16Bit|32Bit|64Bit) to Input_$1 | Output_$1. There'd then be a little bit of adjusting the column alignment, and then a handful of instructions would need to be fixed up where the outputs differ.
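
(A quick illustration of that mechanical rewrite as a runnable C# sketch; the row text is hypothetical, not an actual table entry.)

    using System;
    using System.Text.RegularExpressions;

    class RewriteSketch
    {
        static void Main()
        {
            string row = "Input_16Bit | REX_W0 | Encoding_EVEX";
            string rewritten = Regex.Replace(row, "Input_(8Bit|16Bit|32Bit|64Bit)", "Input_$1 | Output_$1");
            Console.WriteLine(rewritten); // Input_16Bit | Output_16Bit | REX_W0 | Encoding_EVEX
            // Widening instructions like vpmaddwd would then be hand-edited,
            // e.g. to Input_16Bit | Output_32Bit.
        }
    }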

Member

Ah, at first glance I didn't see any instructions that had different element sizes for embedded broadcast vs embedded masking, but I see that is a possibility now. A good example is vcvtps2qq, which has a 32-bit broadcast size and a 64-bit mask element size.

It seems we're probably pessimizing codegen in a few places because we currently use SimdBaseType to make masking decisions, but maybe having a separate flag would allow that to be cleaned up.
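
(To make the vcvtps2qq example concrete, a minimal sketch assuming the standard Avx512DQ mapping: the instruction reads 32-bit float elements, so an embedded broadcast is a 32-bit value, but it writes 64-bit integer elements, so an embedded mask selects 64-bit lanes.)

    using System.Runtime.Intrinsics;
    using System.Runtime.Intrinsics.X86;

    static class ConvertExample
    {
        public static Vector512<long> FloatsToInt64(Vector256<float> src)
        {
            // vcvtps2qq: broadcast element size is 32-bit (the input),
            // mask element size is 64-bit (the output).
            return Avx512DQ.ConvertToVector512Int64(src);
        }
    }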

        }

        if (isEmbeddedMask)
        {
-           uint32_t maskSize = genTypeSize(simdBaseType);
-           uint32_t operSize = genTypeSize(op2->AsHWIntrinsic()->GetSimdBaseType());
-
-           if ((maskSize == operSize) && IsInvariantInRange(op2, node))
+           if (IsInvariantInRange(op2, node))
            {
                MakeSrcContained(node, op2);
                op2->MakeEmbMaskOp();
83 changes: 83 additions & 0 deletions src/tests/JIT/Regression/JitBlue/Runtime_114921/Runtime_114921.cs
@@ -0,0 +1,83 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
//
// Generated by Fuzzlyn v2.5 on 2025-04-22 17:32:36
// Run on X64 Windows
// Seed: 7915602115310323123-vectort,vector128,vector256,vector512,x86aes,x86avx,x86avx2,x86avx512bw,x86avx512bwvl,x86avx512cd,x86avx512cdvl,x86avx512dq,x86avx512dqvl,x86avx512f,x86avx512fvl,x86avx512fx64,x86avx512vbmi,x86avx512vbmivl,x86bmi1,x86bmi1x64,x86bmi2,x86bmi2x64,x86fma,x86lzcnt,x86lzcntx64,x86pclmulqdq,x86popcnt,x86popcntx64,x86sse,x86ssex64,x86sse2,x86sse2x64,x86sse3,x86sse41,x86sse41x64,x86sse42,x86sse42x64,x86ssse3,x86x86base
// Reduced from 123.1 KiB to 0.5 KiB in 00:00:46
// Debug: Outputs <0, 0, 0, 0, 0, 0, 0, 0>
// Release: Outputs <0, 0, 0, 0, -1, -1, -1, -1>

using System;
using System.Numerics;
using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.X86;
using Xunit;

public class Runtime_114921
{
    public static Vector512<long> s_4 = Vector512.Create<long>(-1);
    public static Vector128<long> s_8;

    [Fact]
    public static void Problem1()
    {
        if (Avx512F.IsSupported)
        {
            var vr1 = Vector512.Create<long>(0);
            s_4 = Avx512F.BlendVariable(s_4, Avx512F.InsertVector128(vr1, s_8, 0), s_4);
            System.Console.WriteLine(s_4);
            Assert.Equal(Vector512.Create(0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L), s_4);
        }
    }
}

// Generated by Fuzzlyn v2.5 on 2025-04-22 17:37:13
// Run on X64 Windows
// Seed: 14731447107126414231-vectort,vector128,vector256,vector512,x86aes,x86avx,x86avx2,x86avx512bw,x86avx512bwvl,x86avx512cd,x86avx512cdvl,x86avx512dq,x86avx512dqvl,x86avx512f,x86avx512fvl,x86avx512fx64,x86avx512vbmi,x86avx512vbmivl,x86bmi1,x86bmi1x64,x86bmi2,x86bmi2x64,x86fma,x86lzcnt,x86lzcntx64,x86pclmulqdq,x86popcnt,x86popcntx64,x86sse,x86ssex64,x86sse2,x86sse2x64,x86sse3,x86sse41,x86sse41x64,x86sse42,x86sse42x64,x86ssse3,x86x86base
// Reduced from 217.7 KiB to 1.0 KiB in 00:02:50
// Debug: Outputs <9223372036854775807, 0, 0, 0, 0, 0, 0, 0>
// Release: Outputs <4294967295, 0, 0, 0, 0, 0, 0, 0>

public struct S2
{
    public Vector128<long> F0;
    public S2(Vector128<long> f0) : this()
    {
        F0 = f0;
    }
}

public class Runtime_114921_2
{
    public static IRuntime s_rt;

    [Fact]
    public static void Problem2()
    {
        if (Avx512F.IsSupported)
        {
            s_rt = new Runtime();
            long vr6 = default(long);
            S2 vr7 = new S2(Vector128.CreateScalar(9223372036854775807L));
            Vector512<long> vr14 = default(Vector512<long>);
            var vr9 = Vector512.Create<long>(vr6);
            var vr10 = vr7.F0;
            var vr11 = Avx512F.InsertVector128(vr9, vr10, 0);
            var vr12 = Vector512.CreateScalar(-9223372036854775808L);
            var vr13 = Avx512F.BlendVariable(vr14, vr11, vr12);
            s_rt.WriteLine(vr13);
            Assert.Equal(Vector512.Create(9223372036854775807L, 0L, 0L, 0L, 0L, 0L, 0L, 0L), vr13);
        }
    }
}

public interface IRuntime
{
    void WriteLine<T>(T value);
}

public class Runtime : IRuntime
{
    public void WriteLine<T>(T value) => System.Console.WriteLine(value);
}
8 changes: 8 additions & 0 deletions src/tests/JIT/Regression/JitBlue/Runtime_114921/Runtime_114921.csproj
@@ -0,0 +1,8 @@
<Project Sdk="Microsoft.NET.Sdk">
  <PropertyGroup>
    <Optimize>True</Optimize>
  </PropertyGroup>
  <ItemGroup>
    <Compile Include="$(MSBuildProjectName).cs" />
  </ItemGroup>
</Project>