Skip to content

Commit

Permalink
[FIRRTL] Canonicalize away CVT and adjust all patterns which matched …
Browse files Browse the repository at this point in the history
…cvt (llvm#5527)

closes llvm#5524
  • Loading branch information
darthscsi authored and calebmkim committed Jul 12, 2023
1 parent 7db8bf0 commit e8ef0d4
Show file tree
Hide file tree
Showing 4 changed files with 149 additions and 44 deletions.
139 changes: 106 additions & 33 deletions include/circt/Dialect/FIRRTL/FIRRTLCanonicalization.td
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,17 @@ def NullAttr : Constraint<CPred<"!$0">>;
// Constraint that enforces equal types
def EqualTypes : Constraint<CPred<"$0.getType() == $1.getType()">>;

// Constraint that enforces equal type sizes
def EqualIntSize : Constraint<CPred<"cast<IntType>($0.getType()).getWidth() == cast<IntType>($1.getType()).getWidth()">>;

// sizeof(0) >= sizeof(1)
def IntTypeWidthGEQ32 : Constraint<CPred<
"cast<IntType>($0.getType()).getBitWidthOrSentinel() >= cast<IntType>($1.getType()).getBitWidthOrSentinel()">>;

// sizeof(0) > sizeof(1)
def IntTypeWidthGT32 : Constraint<CPred<
"cast<IntType>($0.getType()).getBitWidthOrSentinel() > cast<IntType>($1.getType()).getBitWidthOrSentinel()">>;

// Constraint that enforces int types
def IntTypes : Constraint<CPred<"isa<IntType>($0.getType())">>;

Expand Down Expand Up @@ -78,6 +89,14 @@ def OneConstantOp : Constraint<Or<[
"$0.getDefiningOp<SpecialConstantOp>().getValue() == true">
]>>;

/// Constraint that matches a zero ConstantOp or SpecialConstantOp.
def NotAllOneConstantOp : Constraint<Or<[
CPred<"$0.getDefiningOp<ConstantOp>() &&"
"!$0.getDefiningOp<ConstantOp>().getValue().isAllOnes()">,
CPred<"$0.getDefiningOp<SpecialConstantOp>() &&"
"$0.getDefiningOp<SpecialConstantOp>().getValue() == false">
]>>;

/// Constraint that matches an all ones ConstantOp.
def AllOneConstantOp : Constraint<CPred<"$0.getDefiningOp<ConstantOp>() && $0.getDefiningOp<ConstantOp>().getValue().isAllOnes()">>;

Expand Down Expand Up @@ -277,20 +296,26 @@ def AndOfSelf : Pat <
[(KnownWidth $x)]>;

/// and(x, pad(y, n)) -> pad(and(tail(x), y), n), x is unsigned
def AndOfPad : Pat <
def AndOfPadU : Pat <
(AndPrimOp:$old (either $x, (PadPrimOp:$pad $y, $n))),
(MoveNameHint $old, (PadPrimOp (AndPrimOp (TailPrimOp $x, (TypeWidthAdjust32 $x, $y)), $y), $n)),
[(KnownWidth $x), (UIntType $x), (EqualTypes $x, $pad)]>;

/// cvt is adding a 0, which will pass through the and
/// and(cvt(x:U), const) -> cat(0, and(x,trunc(const))
def AndCvtU : Pat <
(AndPrimOp:$old (CvtPrimOp:$ox $x), (ConstantOp:$ocst $cst)),
(MoveNameHint $old, (CatPrimOp
(NativeCodeCall<"$_builder.create<ConstantOp>($0.getLoc(), APSInt(APInt()))"> $ocst),
(AndPrimOp $x, (NativeCodeCall<"$_builder.create<ConstantOp>($0.getLoc(), cast<IntType>($0.getType()), $1.getValue().trunc($1.getValue().getBitWidth()-1).zext(*cast<IntType>($0.getType()).getWidth()))"> $x, $cst))
)),
[(KnownWidth $x), (UIntType $x), (EqualTypes $ox, $ocst)]>;
/// and(x, pad(y, n)) -> cat(head(x), and(tail(x), y)), x is signed
def AndOfPadS : Pat <
(AndPrimOp:$old (either $x, (PadPrimOp:$pad $y, $n))),
(MoveNameHint $old, (CatPrimOp (HeadPrimOp $x, (TypeWidthAdjust32 $x, $y)), (AndPrimOp (TailPrimOp $x, (TypeWidthAdjust32 $x, $y)), (AsUIntPrimOp $y)))),
[(KnownWidth $x), (SIntType $x), (EqualTypes $x, $pad)]>;

def AndOfAsSIntL : Pat<
(AndPrimOp:$old (AsSIntPrimOp $x), $y),
(MoveNameHint $old, (AndPrimOp $x, (AsUIntPrimOp $y))),
[(KnownWidth $x), (EqualIntSize $x, $y)]>;

def AndOfAsSIntR : Pat<
(AndPrimOp:$old $x, (AsSIntPrimOp $y)),
(MoveNameHint $old, (AndPrimOp (AsUIntPrimOp $x), $y)),
[(KnownWidth $x), (EqualIntSize $x, $y)]>;

// or(x, 0) -> x, fold can't handle all cases
def OrOfZero : Pat <
Expand Down Expand Up @@ -348,23 +373,23 @@ def OrRasUInt : Pat <
(MoveNameHint $old, (OrRPrimOp $x)),
[(KnownWidth $x), (IntTypes $x)]>;

/// orr(cvt(x)) -> orr(x)
def OrRCvt : Pat <
(OrRPrimOp:$old (CvtPrimOp $x)),
(MoveNameHint $old, (OrRPrimOp $x)),
[(KnownWidth $x), (IntTypes $x)]>;

/// orr(cat(0,x)) -> orr(x)
def OrRCatZeroH : Pat <
(OrRPrimOp:$old (CatPrimOp $c, $x)),
(MoveNameHint $old, (OrRPrimOp $x)),
[(KnownWidth $x), (IntTypes $x), (ZeroConstantOp $c)]>;
[(KnownWidth $x), (ZeroConstantOp $c)]>;

/// orr(cat(x,0)) -> orr(x)
def OrRCatZeroL : Pat <
(OrRPrimOp:$old (CatPrimOp $x, $c)),
(MoveNameHint $old, (OrRPrimOp $x)),
[(KnownWidth $x), (IntTypes $x), (ZeroConstantOp $c)]>;
[(KnownWidth $x), (ZeroConstantOp $c)]>;

/// orr(pad(x,n)) -> orr(x)
def OrRPadU : Pat <
(OrRPrimOp:$old (PadPrimOp:$pad $x, $n)),
(MoveNameHint $old, (OrRPrimOp $x)),
[(KnownWidth $x), (UIntTypes $x), (IntTypeWidthGEQ32 $pad, $x)]>;

/// xorr(asSInt(x)) -> xorr(x)
def XorRasSInt : Pat <
Expand All @@ -378,12 +403,6 @@ def XorRasUInt : Pat <
(MoveNameHint $old, (XorRPrimOp $x)),
[(KnownWidth $x), (IntTypes $x)]>;

/// xorr(cvt(x)) -> xorr(x)
def XorRCvt : Pat <
(XorRPrimOp:$old (CvtPrimOp $x)),
(MoveNameHint $old, (XorRPrimOp $x)),
[(KnownWidth $x), (IntTypes $x)]>;

/// xorr(cat(0,x)) -> xorr(x)
def XorRCatZeroH : Pat <
(XorRPrimOp:$old (CatPrimOp $c, $x)),
Expand All @@ -396,6 +415,12 @@ def XorRCatZeroL : Pat <
(MoveNameHint $old, (XorRPrimOp $x)),
[(KnownWidth $x), (IntTypes $x), (ZeroConstantOp $c)]>;

/// xorr(pad(x,n)) -> xorr(x)
def XorRPadU : Pat <
(XorRPrimOp:$old (PadPrimOp:$pad $x, $n)),
(MoveNameHint $old, (XorRPrimOp $x)),
[(KnownWidth $x), (UIntTypes $x), (IntTypeWidthGEQ32 $pad, $x)]>;

/// andr(asSInt(x)) -> andr(x)
def AndRasSInt : Pat <
(AndRPrimOp:$old (AsSIntPrimOp $x)),
Expand All @@ -408,17 +433,40 @@ def AndRasUInt : Pat <
(MoveNameHint $old, (AndRPrimOp $x)),
[(KnownWidth $x), (IntTypes $x)]>;

/// andr(cvt(x:unsigned)) -> 0
def AndRCvtU : Pat <
(AndRPrimOp:$old (CvtPrimOp $x)),
(MoveNameHint $old,
(NativeCodeCall<"$_builder.create<ConstantOp>($0.getLoc(), cast<IntType>($0.getType()), getIntZerosAttr($0.getType()))"> $old)),
[(KnownWidth $x), (UIntTypes $x)]>;
/// andr(cat(*0*, x)) -> 0
def AndRCatZeroL : Pat <
(AndRPrimOp:$old (CatPrimOp $z, $x)),
(NativeCodeCall<"$_builder.create<ConstantOp>($0.getLoc(), cast<IntType>($0.getType()), getIntZerosAttr($0.getType()))"> $old),
[(NotAllOneConstantOp $z)]>;

/// andr(cat(x, *0*)) -> 0
def AndRCatZeroR : Pat <
(AndRPrimOp:$old (CatPrimOp $x, $z)),
(NativeCodeCall<"$_builder.create<ConstantOp>($0.getLoc(), cast<IntType>($0.getType()), getIntZerosAttr($0.getType()))"> $old),
[(NotAllOneConstantOp $z)]>;

/// andr(cat(1, x)) -> andr(x)
def AndRCatOneL : Pat <
(AndRPrimOp:$old (CatPrimOp $z, $x)),
(MoveNameHint $old, (AndRPrimOp $x)),
[(AllOneConstantOp $z)]>;

/// andr(cvt(x:signed)) -> andr(x)
def AndRCvtS : Pat <
(AndRPrimOp:$old (CvtPrimOp $x)),
/// andr(cat(x, 1)) -> andr(x)
def AndRCatOneR : Pat <
(AndRPrimOp:$old (CatPrimOp $x, $z)),
(MoveNameHint $old, (AndRPrimOp $x)),
[(AllOneConstantOp $z)]>;

/// andr(pad(x:U,n)) -> 0 (where pad is doing something)
def AndRPadU : Pat <
(AndRPrimOp:$old (PadPrimOp:$pad $x, $n)),
(NativeCodeCall<"$_builder.create<ConstantOp>($0.getLoc(), cast<IntType>($0.getType()), getIntZerosAttr($0.getType()))"> $old),
[(KnownWidth $x), (UIntTypes $x), (IntTypeWidthGT32 $pad, $x)]>;

/// andr(pad(x:S,n)) -> andr(x)
def AndRPadS : Pat <
(AndRPrimOp:$old (PadPrimOp:$pad $x, $n)),
(NativeCodeCall<"$_builder.create<ConstantOp>($0.getLoc(), cast<IntType>($0.getType()), getIntZerosAttr($0.getType()))"> $old),
[(KnownWidth $x), (SIntTypes $x)]>;


Expand All @@ -443,6 +491,19 @@ def BitsOfAsUInt : Pat<
(MoveNameHint $old, (BitsPrimOp $x, $oHigh, $oLow)),
[(KnownWidth $x)]>;

// bits(asUInt) -> bits
def BitsOfAnd : Pat<
(BitsPrimOp:$old (AndPrimOp $x, $y), I32Attr:$oHigh, I32Attr:$oLow),
(MoveNameHint $old, (AndPrimOp (BitsPrimOp $x, $oHigh, $oLow), (BitsPrimOp $y, $oHigh, $oLow))),
[(KnownWidth $x), (EqualTypes $x, $y)]>;

// bits(pad(x)) -> x when they cancel out
def AttrIsZero : Constraint<CPred<"$0.getValue().isZero()">>;
def BitsOfPad : Pat<
(BitsPrimOp:$old (PadPrimOp $x, I32Attr:$n), I32Attr:$oHigh, I32Attr:$oLow),
(AsUIntPrimOp $x),
[(KnownWidth $x), (EqualIntSize $old, $x), (AttrIsZero $oLow)]>;

// subaccess a, cst -> subindex a, cst
// TODO: only enable if cst is inside a. Subaccess and subindex behave differently for out-of-bounds indexes.
def SubaccessOfConstant : Pat<
Expand Down Expand Up @@ -508,5 +569,17 @@ def StoUtoS : Pat<
(replaceWithValue $x),
[]>;

def CVTSigned : Pat<
(CvtPrimOp $x),
(replaceWithValue $x),
[(SIntType $x)]>;

def CVTUnSigned : Pat<
(CvtPrimOp:$old $x),
(MoveNameHint $old,
(AsSIntPrimOp
(PadPrimOp $x,
(NativeCodeCall<"$_builder.getI32IntegerAttr(cast<SIntType>($0.getType()).getBitWidthOrSentinel())"> $old)))),
[(UIntType $x), (KnownWidth $old)]>;

#endif // CIRCT_DIALECT_FIRRTL_FIRRTLCANONICALIZATION_TD
2 changes: 1 addition & 1 deletion include/circt/Dialect/FIRRTL/FIRRTLExpressions.td
Original file line number Diff line number Diff line change
Expand Up @@ -666,11 +666,11 @@ def SizeOfIntrinsicOp : UnaryPrimOp<"int.sizeof", FIRRTLBaseType, UInt32Type>;
let hasCanonicalizer=1 in {
def AsSIntPrimOp : UnaryPrimOp<"asSInt", FIRRTLBaseType, SIntType>;
def AsUIntPrimOp : UnaryPrimOp<"asUInt", FIRRTLBaseType, UIntType>;
def CvtPrimOp : UnaryPrimOp<"cvt", IntType, SIntType>;
}
def AsAsyncResetPrimOp
: UnaryPrimOp<"asAsyncReset", OneBitCastableType, AsyncResetType>;
def AsClockPrimOp : UnaryPrimOp<"asClock", OneBitCastableType, ClockType>;
def CvtPrimOp : UnaryPrimOp<"cvt", IntType, SIntType>;
def NegPrimOp : UnaryPrimOp<"neg", IntType, SIntType>;
let hasCanonicalizer = true in
def NotPrimOp : UnaryPrimOp<"not", IntType, UIntType>;
Expand Down
24 changes: 16 additions & 8 deletions lib/Dialect/FIRRTL/FIRRTLFolds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -581,8 +581,8 @@ void AndPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.insert<patterns::extendAnd, patterns::moveConstAnd,
patterns::AndOfZero, patterns::AndOfAllOne,
patterns::AndOfSelf, patterns::AndOfPad, patterns::AndCvtU>(
context);
patterns::AndOfSelf, patterns::AndOfPadU, patterns::AndOfPadS,
patterns::AndOfAsSIntL, patterns::AndOfAsSIntR>(context);
}

OpFoldResult OrPrimOp::fold(FoldAdaptor adaptor) {
Expand Down Expand Up @@ -1041,6 +1041,11 @@ OpFoldResult CvtPrimOp::fold(FoldAdaptor adaptor) {
return {};
}

void CvtPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.insert<patterns::CVTSigned, patterns::CVTUnSigned>(context);
}

OpFoldResult NegPrimOp::fold(FoldAdaptor adaptor) {
if (!hasKnownWidthIntTypes(*this))
return {};
Expand Down Expand Up @@ -1091,8 +1096,10 @@ OpFoldResult AndRPrimOp::fold(FoldAdaptor adaptor) {

void AndRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.insert<patterns::AndRasSInt, patterns::AndRasUInt, patterns::AndRCvtU,
patterns::AndRCvtS>(context);
results
.insert<patterns::AndRasSInt, patterns::AndRasUInt, patterns::AndRPadU,
patterns::AndRPadS, patterns::AndRCatOneL, patterns::AndRCatOneR,
patterns::AndRCatZeroL, patterns::AndRCatZeroR>(context);
}

OpFoldResult OrRPrimOp::fold(FoldAdaptor adaptor) {
Expand All @@ -1116,7 +1123,7 @@ OpFoldResult OrRPrimOp::fold(FoldAdaptor adaptor) {

void OrRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.insert<patterns::OrRasSInt, patterns::OrRasUInt, patterns::OrRCvt,
results.insert<patterns::OrRasSInt, patterns::OrRasUInt, patterns::OrRPadU,
patterns::OrRCatZeroH, patterns::OrRCatZeroL>(context);
}

Expand All @@ -1140,7 +1147,7 @@ OpFoldResult XorRPrimOp::fold(FoldAdaptor adaptor) {

void XorRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.insert<patterns::XorRasSInt, patterns::XorRasUInt, patterns::XorRCvt,
results.insert<patterns::XorRasSInt, patterns::XorRasUInt, patterns::XorRPadU,
patterns::XorRCatZeroH, patterns::XorRCatZeroL>(context);
}

Expand Down Expand Up @@ -1244,8 +1251,9 @@ OpFoldResult BitsPrimOp::fold(FoldAdaptor adaptor) {

void BitsPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.insert<patterns::BitsOfBits, patterns::BitsOfMux,
patterns::BitsOfAsUInt>(context);
results
.insert<patterns::BitsOfBits, patterns::BitsOfMux, patterns::BitsOfAsUInt,
patterns::BitsOfAnd, patterns::BitsOfPad>(context);
}

/// Replace the specified operation with a 'bits' op from the specified hi/lo
Expand Down
28 changes: 26 additions & 2 deletions test/Dialect/FIRRTL/canonicalization.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ firrtl.module @Casts(in %ui1 : !firrtl.uint<1>, in %si1 : !firrtl.sint<1>,
%11 = firrtl.asSInt %ui1 : (!firrtl.uint<1>) -> !firrtl.sint<1>
%12 = firrtl.asUInt %11 : (!firrtl.sint<1>) -> !firrtl.uint<1>
firrtl.strictconnect %out2_ui1, %12 : !firrtl.uint<1>
// CHECK: firrtl.strictconnect %out2_si1, %si1
%13 = firrtl.cvt %si1 : (!firrtl.sint<1>) -> !firrtl.sint<1>
firrtl.strictconnect %out2_si1, %13 : !firrtl.sint<1>
}

// CHECK-LABEL: firrtl.module @Div
Expand Down Expand Up @@ -187,7 +190,7 @@ firrtl.module @And(in %in: !firrtl.uint<4>,
firrtl.strictconnect %out6, %9 : !firrtl.uint<6>

// CHECK: %[[AND:.*]] = firrtl.and %in, %c3_ui4
// CHECK: firrtl.cat %c0_ui1, %[[AND]]
// CHECK: firrtl.pad %[[AND]], 5
%10 = firrtl.cvt %in : (!firrtl.uint<4>) -> !firrtl.sint<5>
%11 = firrtl.and %10, %c3_si5 : (!firrtl.sint<5>, !firrtl.sint<5>) -> !firrtl.uint<5>
firrtl.strictconnect %out5, %11 : !firrtl.uint<5>
Expand Down Expand Up @@ -675,7 +678,7 @@ firrtl.module @Tail(in %in4u: !firrtl.uint<4>,
}

// CHECK-LABEL: firrtl.module @Andr
firrtl.module @Andr(in %in0 : !firrtl.uint<0>,
firrtl.module @Andr(in %in0 : !firrtl.uint<0>, in %in1 : !firrtl.sint<2>,
out %a: !firrtl.uint<1>, out %b: !firrtl.uint<1>,
out %c: !firrtl.uint<1>, out %d: !firrtl.uint<1>,
out %e: !firrtl.uint<1>, out %f: !firrtl.uint<1>,
Expand Down Expand Up @@ -703,6 +706,17 @@ firrtl.module @Andr(in %in0 : !firrtl.uint<0>,
// CHECK: firrtl.strictconnect %e, %[[ONE]]
firrtl.connect %e, %4 : !firrtl.uint<1>, !firrtl.uint<1>

// CHECK: %[[and1:.*]] = firrtl.andr %in1
// CHECK-NEXT: firrtl.strictconnect %e, %[[and1]]
%cat = firrtl.cat %in1, %cn1_si2 : (!firrtl.sint<2>, !firrtl.sint<2>) -> !firrtl.uint<4>
%andrcat = firrtl.andr %cat : (!firrtl.uint<4>) -> !firrtl.uint<1>
firrtl.connect %e, %andrcat : !firrtl.uint<1>, !firrtl.uint<1>

// CHECK: firrtl.strictconnect %e, %[[ZERO]]
%cat2 = firrtl.cat %in1, %cn2_si2 : (!firrtl.sint<2>, !firrtl.sint<2>) -> !firrtl.uint<4>
%andrcat2 = firrtl.andr %cat2 : (!firrtl.uint<4>) -> !firrtl.uint<1>
firrtl.connect %e, %andrcat2 : !firrtl.uint<1>, !firrtl.uint<1>

// CHECK: firrtl.strictconnect %g, %[[ZERO]]
%5 = firrtl.asSInt %h : (!firrtl.uint<64>) -> !firrtl.sint<64>
%6 = firrtl.asUInt %5 : (!firrtl.sint<64>) -> !firrtl.uint<64>
Expand Down Expand Up @@ -3035,4 +3049,14 @@ firrtl.module @RefCastSame(in %in: !firrtl.probe<uint<1>>, out %out: !firrtl.pro
firrtl.ref.define %out, %same_as_in : !firrtl.probe<uint<1>>
}

// CHECK-LABEL: @Issue5527
firrtl.module @Issue5527(in %x: !firrtl.uint<1>, out %out: !firrtl.uint<2>) attributes {convention = #firrtl<convention scalarized>} {
%0 = firrtl.cvt %x : (!firrtl.uint<1>) -> !firrtl.sint<2>
%c2_si4 = firrtl.constant 2 : !firrtl.sint<4>
%1 = firrtl.and %0, %c2_si4 : (!firrtl.sint<2>, !firrtl.sint<4>) -> !firrtl.uint<4>
%2 = firrtl.tail %1, 2 : (!firrtl.uint<4>) -> !firrtl.uint<2>
// CHECK: firrtl.strictconnect %out, %c0_ui2
firrtl.strictconnect %out, %2 : !firrtl.uint<2>
}

}

0 comments on commit e8ef0d4

Please sign in to comment.