Skip to content

Commit 98c7dbe

Browse files
committed
Fix typing of BlendVariableMask condition when optimizing TernaryLogic
TernaryLogic nodes normalize small types to large types. When optimizing a TernaryLogic node to a BlendVariableMask, if we leave the normalized type, then the Blend uses the wrong instruction. To fix this, change the Blend node to use the mask simd base type. We only do this for mask base types that are small, which is unsatisfying: there exists codegen today for double->int casts, for example, that uses a 'double' type mask but an 'int' sized TernaryLogic node. This works because we end up only using lane '0', as the vector is converted to scalar. I didn't find any other case like this, so hopefully if the mask type is small we can safely use it. Fixes #114572
1 parent 760e8e9 commit 98c7dbe

File tree

3 files changed

+70
-0
lines changed

3 files changed

+70
-0
lines changed

src/coreclr/jit/lowerxarch.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3875,6 +3875,28 @@ GenTree* Lowering::LowerHWIntrinsicTernaryLogic(GenTreeHWIntrinsic* node)
38753875
}
38763876

38773877
assert(varTypeIsMask(condition));
3878+
3879+
// The TernaryLogic node normalizes small SIMD base types on import. To optimize
3880+
// to BlendVariableMask, we need to "un-normalize". We no longer have the original
3881+
// base type, so we use the mask base type instead.
3882+
NamedIntrinsic intrinsicId = node->GetHWIntrinsicId();
3883+
assert(HWIntrinsicInfo::NeedsNormalizeSmallTypeToInt(intrinsicId));
3884+
3885+
// The condition mask element size is expected to be the same size as the TernaryLogic
3886+
// node element size, unless the condition element size is short, in which case the
3887+
// TernaryLogic node element size would have been normalized.
3888+
// However, there is code, such as for double->int conversions, that generates a 'double'
3889+
// base type mask for a TernaryLogic node with base type 'int'. That works because it only
3890+
// cares about the '0' elem, as the result will be cast to scalar.
3891+
var_types conditionBaseType = condition->AsHWIntrinsic()->GetSimdBaseType();
3892+
uint32_t conditionElemSize = genTypeSize(conditionBaseType);
3893+
uint32_t elemSize = genTypeSize(simdBaseType);
3894+
if (varTypeIsShort(conditionBaseType) && (conditionElemSize < elemSize))
3895+
{
3896+
CorInfoType simdBaseJitTypeCondition = condition->AsHWIntrinsic()->GetSimdBaseJitType();
3897+
node->AsHWIntrinsic()->SetSimdBaseJitType(simdBaseJitTypeCondition);
3898+
}
3899+
38783900
node->ResetHWIntrinsicId(NI_EVEX_BlendVariableMask, comp, selectFalse, selectTrue, condition);
38793901
BlockRange().Remove(op4);
38803902
break;
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
//
4+
// Generated by Fuzzlyn v2.5 on 2025-04-11 19:29:41
5+
// Run on X64 Windows
6+
// Seed: 557319528607462789-vectort,vector128,vector256,x86aes,x86avx,x86avx2,x86avx512bw,x86avx512bwvl,x86avx512cd,x86avx512cdvl,x86avx512dq,x86avx512dqvl,x86avx512f,x86avx512fvl,x86avx512fx64,x86bmi1,x86bmi1x64,x86bmi2,x86bmi2x64,x86fma,x86lzcnt,x86lzcntx64,x86pclmulqdq,x86popcnt,x86popcntx64,x86sse,x86ssex64,x86sse2,x86sse2x64,x86sse3,x86sse41,x86sse41x64,x86sse42,x86sse42x64,x86ssse3,x86x86base
7+
// Reduced from 41.1 KiB to 0.7 KiB in 00:01:23
8+
// Debug: Outputs <0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>
9+
// Release: Outputs <0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1>
10+
11+
using System;
12+
using System.Numerics;
13+
using System.Runtime.Intrinsics;
14+
using System.Runtime.Intrinsics.X86;
15+
using Xunit;
16+
17+
public class Runtime_114572
18+
{
19+
public static Vector256<ushort> s_2;
20+
public static ushort s_4;
21+
22+
[Fact]
23+
public static void Problem()
24+
{
25+
if (Avx512F.VL.IsSupported)
26+
{
27+
var vr11 = Vector256.Create<ushort>(0);
28+
var vr12 = Vector256.Create<ushort>(1);
29+
var vr13 = (ushort)0;
30+
var vr14 = Vector256.CreateScalar(vr13);
31+
var vr15 = (ushort)1;
32+
var vr16 = Vector256.CreateScalar(vr15);
33+
var vr17 = Vector256.Create<ushort>(s_4);
34+
var vr18 = Avx2.Max(vr16, vr17);
35+
s_2 = Avx512F.VL.TernaryLogic(vr11, vr12, Avx512BW.VL.CompareGreaterThanOrEqual(vr14, vr18), 216);
36+
System.Console.WriteLine(s_2);
37+
Assert.Equal(Vector256.Create((ushort)0, (ushort)1, (ushort)1, (ushort)1, (ushort)1, (ushort)1, (ushort)1, (ushort)1, (ushort)1, (ushort)1, (ushort)1, (ushort)1, (ushort)1, (ushort)1, (ushort)1, (ushort)1), s_2);
38+
}
39+
}
40+
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
<Project Sdk="Microsoft.NET.Sdk">
2+
<PropertyGroup>
3+
<Optimize>True</Optimize>
4+
</PropertyGroup>
5+
<ItemGroup>
6+
<Compile Include="$(MSBuildProjectName).cs" />
7+
</ItemGroup>
8+
</Project>

0 commit comments

Comments
 (0)