Skip to content

Commit

Permalink
[Pipeline] Add per-register clock gating (llvm#5489)
Browse files Browse the repository at this point in the history
Also changes it so that the `clock-gate-regs` option actually implements clock gates instead of a `seq.compreg.ce` operation. To cover all cases, I think there need to be three kinds of gating implementations: clock gate, clock enable (`seq.compreg.ce`), and input muxing. The first and last are what we have now.
  • Loading branch information
mortbopet authored and calebmkim committed Jul 12, 2023
1 parent 1b5f0d9 commit f840dd2
Show file tree
Hide file tree
Showing 12 changed files with 443 additions and 224 deletions.
32 changes: 29 additions & 3 deletions include/circt/Dialect/Pipeline/PipelineOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -208,15 +208,31 @@ def StageOp : Op<Pipeline_Dialect, "stage", [
1. which stage (block) to transition to next
2. which registers to build at this stage boundary
3. which values to pass through to the next stage without registering
4. An optional hierarchy of boolean values to be used for clock gates for
each register.
- The implicit '!stalled' gate will always be the first signal in the
hierarchy. Further signals are added to the hierarchy from left to
right.


Example:
```mlir
pipeline.stage ^bb1 regs(%a : i32 gated by [%foo, %bar], %b : i1) pass(%c : i32)
```
}];

let arguments = (ins Variadic<AnyType>:$registers, Variadic<AnyType>:$passthroughs);
let arguments = (ins
Variadic<AnyType>:$registers,
Variadic<AnyType>:$passthroughs,
Variadic<I1>:$clockGates,
I64ArrayAttr:$clockGatesPerRegister);
let successors = (successor AnySuccessor:$nextStage);
let results = (outs);
let hasVerifier = 1;
let skipDefaultBuilders = 1;

let assemblyFormat = [{
$nextStage (`regs` `(` $registers^ `:` type($registers) `)`)?
$nextStage custom<StageRegisters>($registers, type($registers), $clockGates, $clockGatesPerRegister)
(`pass` `(` $passthroughs^ `:` type($passthroughs) `)` )? attr-dict
}];

Expand All @@ -225,9 +241,19 @@ def StageOp : Op<Pipeline_Dialect, "stage", [
void setNextStage(Block *block) {
setSuccessor(block);
}

// Returns the list of clock gates for the given register.
ValueRange getClockGatesForReg(unsigned regIdx);
}];
}

let builders = [
OpBuilder<(ins "Block*":$dest, "ValueRange":$registers, "ValueRange":$passthroughs)>,
// Clock gate builder, which takes a mapping between the registers and
// their clock gate hierarchy.
OpBuilder<(ins "Block*":$dest, "ValueRange":$registers, "ValueRange":$passthroughs,
"const llvm::DenseMap<Value, llvm::SmallVector<Value>>&":$clockGates)>
];
}

def ReturnOp : Op<Pipeline_Dialect, "return", [
Terminator,
Expand Down
66 changes: 29 additions & 37 deletions lib/Conversion/PipelineToHW/PipelineToHW.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
#include "circt/Dialect/HW/HWOps.h"
#include "circt/Dialect/Pipeline/PipelineOps.h"
#include "circt/Dialect/Seq/SeqOps.h"
#include "circt/Support/BackedgeBuilder.h"
#include "mlir/IR/Builders.h"
#include "llvm/ADT/TypeSwitch.h"

Expand Down Expand Up @@ -93,16 +92,23 @@ class PipelineLowering {

// Build data registers.
auto stageRegPrefix = getStageRegPrefix(stageIndex);
BackedgeBuilder bb(builder, stageTerminator->getLoc());
auto loc = stageTerminator->getLoc();

// Build the clock enable signal: valid && !stall (if applicable)
Value dataValid = valid;
Value stageValidAndNotStalled = valid;
Value notStalled;
bool hasStall = static_cast<bool>(stall);
if (hasStall) {
notStalled = comb::createOrFoldNot(loc, stall, builder);
dataValid = builder.create<comb::AndOp>(loc, dataValid, notStalled);
stageValidAndNotStalled =
builder.create<comb::AndOp>(loc, stageValidAndNotStalled, notStalled);
}

Value notStalledClockGate;
if (this->clockGateRegs) {
// Create the top-level clock gate.
notStalledClockGate = builder.create<seq::ClockGateOp>(
loc, clock, stageValidAndNotStalled, /*test_enable=*/Value());
}

for (auto it : llvm::enumerate(stageTerminator.getRegisters())) {
Expand All @@ -112,23 +118,22 @@ class PipelineLowering {
auto regIn = it.value();
auto regName = builder.getStringAttr(stageRegPrefix.strref() + "_reg" +
std::to_string(regIdx));
Type dataType = regIn.getType();
Value dataReg;
if (this->clockGateRegs) {
// Clock gate based on the valid signal.
dataReg = builder.create<seq::CompRegClockEnabledOp>(
stageTerminator->getLoc(), dataType, regIn, clock, dataValid,
regName, reset, /*resetValue*/ Value(), /*sym_name*/ StringAttr());
// Use the clock gate instead of input muxing.
Value currClockGate = notStalledClockGate;
for (auto hierClockGateEnable :
stageTerminator.getClockGatesForReg(regIdx)) {
// Create clock gates for any hierarchically nested clock gates.
currClockGate = builder.create<seq::ClockGateOp>(
loc, currClockGate, hierClockGateEnable, /*test_enable=*/Value());
}
dataReg = builder.create<seq::CompRegOp>(stageTerminator->getLoc(),
regIn, currClockGate, regName);
} else {
// Use input muxing.
auto dataRegBE = bb.get(dataType);
auto dataRegNext = builder.create<comb::MuxOp>(
stageTerminator->getLoc(), dataValid, regIn, dataRegBE);
dataReg = builder.create<seq::CompRegOp>(
stageTerminator->getLoc(), dataType, dataRegNext, clock, regName,
reset,
/*resetValue*/ Value(), /*sym_name*/ StringAttr());
dataRegBE.setValue(dataReg);
dataReg = builder.create<seq::CompRegClockEnabledOp>(
stageTerminator->getLoc(), regIn, clock, stageValidAndNotStalled,
regName);
}
rets.regs.push_back(dataReg);
}
Expand All @@ -141,26 +146,13 @@ class PipelineLowering {
builder.create<hw::ConstantOp>(terminator->getLoc(), APInt(1, 0, false))
.getResult();
if (hasStall) {
if (clockGateRegs) {
rets.valid = builder.create<seq::CompRegClockEnabledOp>(
loc, builder.getI1Type(), valid, clock, notStalled, validRegName,
reset, validRegResetVal,
/*sym_name*/ StringAttr());
} else {
auto validRegBE = bb.get(builder.getI1Type());
auto validRegNext =
builder.create<comb::MuxOp>(loc, notStalled, valid, validRegBE);
rets.valid = builder.create<seq::CompRegOp>(
loc, builder.getI1Type(), validRegNext, clock, validRegName, reset,
validRegResetVal,
/*sym_name*/ StringAttr());
validRegBE.setValue(rets.valid);
}
rets.valid = builder.create<seq::CompRegClockEnabledOp>(
loc, builder.getI1Type(), valid, clock, notStalled, validRegName,
reset, validRegResetVal, validRegName);
} else {
rets.valid =
builder.create<seq::CompRegOp>(loc, builder.getI1Type(), valid, clock,
validRegName, reset, validRegResetVal,
/*sym_name*/ StringAttr());
rets.valid = builder.create<seq::CompRegOp>(
loc, builder.getI1Type(), valid, clock, validRegName, reset,
validRegResetVal, validRegName);
}

rets.passthroughs = stageTerminator.getPassthroughs();
Expand Down
124 changes: 124 additions & 0 deletions lib/Dialect/Pipeline/PipelineOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,124 @@ LogicalResult ReturnOp::verify() {
// StageOp
//===----------------------------------------------------------------------===//

/// Parses a single entry of a `regs(...)` list, i.e. the form:
///   $register `:` type($register) (`gated` `by` `[` $clockGates `]`)?
///
/// \param parser      The assembly parser to read from.
/// \param v           Receives the register operand.
/// \param t           Receives the register type.
/// \param clockGates  Receives the operands of the optional gate list.
/// \returns success if the entry (with or without a gate list) was parsed,
///          failure otherwise.
static ParseResult parseSingleStageRegister(
    OpAsmParser &parser, OpAsmParser::UnresolvedOperand &v, Type &t,
    llvm::SmallVector<OpAsmParser::UnresolvedOperand> &clockGates) {
  // Mandatory part: the register operand, then `: type`.
  if (failed(parser.parseOperand(v)))
    return failure();
  if (failed(parser.parseColonType(t)))
    return failure();

  // The gate list is optional; a missing `gated` keyword just means that this
  // register carries no explicit clock gates.
  if (failed(parser.parseOptionalKeyword("gated")))
    return success();

  // Once `gated` has been seen, `by [ ... ]` is required.
  if (failed(parser.parseKeyword("by")))
    return failure();
  return parser.parseOperandList(clockGates, OpAsmParser::Delimiter::Square);
}

/// Custom-directive parser for the register list of a stage terminator.
/// Accepts the form:
///   regs($register : type($register) (`gated by` `[` $clockGates `]`)?, ...)
/// The whole `regs(...)` group is optional.
///
/// \param registers              Receives one operand per parsed register.
/// \param registerTypes          Receives the type of each parsed register.
/// \param clockGates             Receives the gate operands of all registers,
///                               concatenated in register order.
/// \param clockGatesPerRegister  Set to an i64 array holding, for each
///                               register, how many gate operands it
///                               contributed (empty if `regs` is absent).
ParseResult parseStageRegisters(
    OpAsmParser &parser,
    llvm::SmallVector<OpAsmParser::UnresolvedOperand, 4> &registers,
    llvm::SmallVector<mlir::Type, 1> &registerTypes,
    llvm::SmallVector<OpAsmParser::UnresolvedOperand, 4> &clockGates,
    ArrayAttr &clockGatesPerRegister) {
  // Gate count for each register, in the order the registers were parsed.
  // Stays empty when there is no `regs` group at all.
  llvm::SmallVector<int64_t> gateCounts;

  const bool hasRegsKeyword = !failed(parser.parseOptionalKeyword("regs"));
  if (hasRegsKeyword) {
    // Parses one `reg : type (gated by [...])?` entry and records it.
    auto parseEntry = [&]() -> ParseResult {
      OpAsmParser::UnresolvedOperand reg;
      Type regType;
      llvm::SmallVector<OpAsmParser::UnresolvedOperand> gates;
      if (failed(parseSingleStageRegister(parser, reg, regType, gates)))
        return failure();

      registers.push_back(reg);
      registerTypes.push_back(regType);
      clockGates.append(gates.begin(), gates.end());
      gateCounts.push_back(gates.size());
      return success();
    };

    if (failed(parser.parseCommaSeparatedList(AsmParser::Delimiter::Paren,
                                              parseEntry)))
      return failure();
  }

  clockGatesPerRegister = parser.getBuilder().getI64ArrayAttr(gateCounts);
  return success();
}

/// Custom-directive printer matching parseStageRegisters. Emits
///   regs(%r : type (gated by [%g0, %g1, ...])?, ...)
/// and emits nothing at all when there are no registers.
///
/// \param clockGates             Gate operands of all registers, concatenated
///                               in register order.
/// \param clockGatesPerRegister  Per-register gate counts used to carve
///                               `clockGates` into per-register slices.
void printStageRegisters(OpAsmPrinter &p, Operation *op, ValueRange registers,
                         TypeRange registerTypes, ValueRange clockGates,
                         ArrayAttr clockGatesPerRegister) {
  if (registers.empty())
    return;

  p << "regs(";

  // Index into `clockGates` where the current register's gates begin.
  size_t gateOffset = 0;
  bool needsSeparator = false;
  for (auto [reg, regType, gateCountAttr] :
       llvm::zip(registers, registerTypes, clockGatesPerRegister)) {
    if (needsSeparator)
      p << ", ";
    needsSeparator = true;

    p << reg << " : " << regType;

    int64_t gateCount = gateCountAttr.cast<IntegerAttr>().getInt();
    if (gateCount != 0) {
      p << " gated by [";
      llvm::interleaveComma(clockGates.slice(gateOffset, gateCount), p);
      p << "]";
      gateOffset += gateCount;
    }
  }

  p << ")";
}

/// Builds a stage terminator without any clock gates.
///
/// \param dest          Successor block, i.e. the next stage.
/// \param registers     Values to register at this stage boundary.
/// \param passthroughs  Values forwarded to the next stage unregistered.
void StageOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                    Block *dest, ValueRange registers,
                    ValueRange passthroughs) {
  // Operands are added as registers first, then passthroughs; the clock gate
  // segment is left empty.
  odsState.addOperands(registers);
  odsState.addOperands(passthroughs);
  odsState.addSuccessors(dest);

  const int32_t numRegisters = static_cast<int32_t>(registers.size());
  const int32_t numPassthroughs = static_cast<int32_t>(passthroughs.size());
  const int32_t numClockGates = 0;
  odsState.addAttribute(
      "operand_segment_sizes",
      odsBuilder.getDenseI32ArrayAttr(
          {numRegisters, numPassthroughs, numClockGates}));

  // One gate count per register, all of them zero.
  llvm::SmallVector<int64_t> zeroGateCounts(registers.size(), 0);
  odsState.addAttribute("clockGatesPerRegister",
                        odsBuilder.getI64ArrayAttr(zeroGateCounts));
}

/// Returns the clock gate operands that belong to the register at `regIdx`.
///
/// The gates of all registers are stored back to back in a single operand
/// list, and `clockGatesPerRegister` records how many belong to each register.
/// The slice for `regIdx` therefore starts at the sum of the counts of all
/// preceding registers.
ValueRange StageOp::getClockGatesForReg(unsigned regIdx) {
  assert(regIdx < getRegisters().size() && "register index out of bounds.");

  // TODO: Every call walks the per-register count list from the start. Storing
  // something richer than an array of sizes (e.g. via properties) could avoid
  // recomputing the offset each time.

  unsigned gateOffset = 0;
  unsigned currentReg = 0;
  for (IntegerAttr gateCountAttr :
       getClockGatesPerRegister().getAsRange<IntegerAttr>()) {
    int64_t gateCount = gateCountAttr.getInt();
    if (currentReg == regIdx)
      return getClockGates().slice(gateOffset, gateCount);

    // Not the requested register yet: skip past its gates.
    gateOffset += gateCount;
    ++currentReg;
  }

  llvm_unreachable("register index out of bounds.");
}

LogicalResult StageOp::verify() {
// Verify that the target block has the correct arguments as this stage op.
llvm::SmallVector<Type> expectedTargetArgTypes;
Expand All @@ -350,6 +468,12 @@ LogicalResult StageOp::verify() {
<< index << " to have type " << arg << ", got " << barg << ".";
}

// Verify that the clock gate index list is equally sized to the # of
// registers.
if (getClockGatesPerRegister().size() != getRegisters().size())
return emitOpError("expected clockGatesPerRegister to be equally sized to "
"the number of registers.");

return success();
}

Expand Down
10 changes: 6 additions & 4 deletions lib/Dialect/Pipeline/Transforms/ExplicitRegs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,11 +172,13 @@ void ExplicitRegsPass::runOnOperation() {
passIns.push_back(value);
}

// Append arguments to the predecessor stage terminator, which feeds this
// stage.
// Replace the predecessor stage terminator, which feeds this stage, with
// a new terminator that has materialized arguments.
StageOp terminator = cast<StageOp>(predecessorStage->getTerminator());
terminator.getRegistersMutable().append(regIns);
terminator.getPassthroughsMutable().append(passIns);
b.setInsertionPoint(terminator);
b.create<StageOp>(terminator.getLoc(), terminator.getNextStage(), regIns,
passIns);
terminator.erase();

// ... add arguments to the next stage. Registers first, then passthroughs.
llvm::SmallVector<Type> regAndPassTypes;
Expand Down
4 changes: 2 additions & 2 deletions lib/Dialect/Pipeline/Transforms/ScheduleLinearPipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,8 @@ ScheduleLinearPipelinePass::schedulePipeline(UnscheduledPipelineOp pipeline) {
// Create a StageOp in the new stage, and branch it to the newly created
// stage.
b.setInsertionPointToEnd(currentStage);
b.create<pipeline::StageOp>(pipeline.getLoc(), ValueRange{}, ValueRange{},
newStage);
b.create<pipeline::StageOp>(pipeline.getLoc(), newStage, ValueRange{},
ValueRange{});
currentStage = newStage;
}
}
Expand Down
15 changes: 8 additions & 7 deletions test/Conversion/PipelineToHW/test_ce.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,20 @@
// CHECK-LABEL: hw.module @testSingle(
// CHECK-SAME: %[[VAL_0:.*]]: i32, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i1, %[[VAL_3:.*]]: i1, %[[VAL_4:.*]]: i1) -> (out0: i32, out1: i1) {
// CHECK: %[[VAL_5:.*]] = comb.sub %[[VAL_0]], %[[VAL_1]] : i32
// CHECK: %[[VAL_6:.*]] = seq.compreg.ce %[[VAL_5]], %[[VAL_3]], %[[VAL_2]] : i32
// CHECK: %[[VAL_7:.*]] = seq.compreg.ce %[[VAL_0]], %[[VAL_3]], %[[VAL_2]] : i32
// CHECK: %[[VAL_8:.*]] = hw.constant false
// CHECK: %[[VAL_9:.*]] = seq.compreg %[[VAL_2]], %[[VAL_3]], %[[VAL_4]], %[[VAL_8]] : i1
// CHECK: %[[VAL_10:.*]] = comb.add %[[VAL_6]], %[[VAL_7]] : i32
// CHECK: hw.output %[[VAL_10]], %[[VAL_9]] : i32, i1
// CHECK: %[[VAL_6:.*]] = seq.clock_gate %[[VAL_3]], %[[VAL_2]]
// CHECK: %[[VAL_7:.*]] = seq.compreg sym @p0_s0_reg0 %[[VAL_5]], %[[VAL_6]] : i32
// CHECK: %[[VAL_8:.*]] = seq.compreg sym @p0_s0_reg1 %[[VAL_0]], %[[VAL_6]] : i32
// CHECK: %[[VAL_9:.*]] = hw.constant false
// CHECK: %[[VAL_10:.*]] = seq.compreg sym @p0_s0_valid %[[VAL_2]], %[[VAL_3]], %[[VAL_4]], %[[VAL_9]] : i1
// CHECK: %[[VAL_11:.*]] = comb.add %[[VAL_7]], %[[VAL_8]] : i32
// CHECK: hw.output %[[VAL_11]], %[[VAL_10]] : i32, i1
// CHECK: }

hw.module @testSingle(%arg0: i32, %arg1: i32, %go: i1, %clk: i1, %rst: i1) -> (out0: i32, out1: i1) {
%0:2 = pipeline.scheduled(%arg0, %arg1) clock %clk reset %rst go %go : (i32, i32) -> (i32) {
^bb0(%arg0_0: i32, %arg1_1: i32, %s0_valid : i1):
%1 = comb.sub %arg0_0, %arg1_1 : i32
pipeline.stage ^bb1 regs(%1, %arg0_0 : i32, i32)
pipeline.stage ^bb1 regs(%1 : i32, %arg0_0 : i32)
^bb1(%6: i32, %7: i32, %s1_valid : i1): // pred: ^bb1
%8 = comb.add %6, %7 : i32
pipeline.return %8 : i32
Expand Down
Loading

0 comments on commit f840dd2

Please sign in to comment.