Conversation
There was a problem hiding this comment.
Pull request overview
This PR aims to fix functional issues in the Im2col “Ball” hardware prototype and ensure it is exercised by the CTest / Sardine verification flow.
Changes:
- Reworks the Chisel
Im2colmodule control/data path (row prefetching, window generation, streaming writeback). - Adds a dedicated
RowSlotFIFOmodule to manage circular buffering of input rows. - Updates/extends the C-based
im2col_testto compute a software reference and adds the workload to the Sardine pytest list.
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| bb-tests/workloads/src/CTest/toy/im2col_test.c | Updates the im2col CTest to generate an expected matrix and compare against hardware output. |
| bb-tests/sardine/tests/test_ctest.py | Adds the im2col CTest workload to the parametrized Verilator workload list. |
| arch/src/main/scala/framework/balldomain/prototype/im2col/Im2col.scala | Major rewrite of the Im2col Ball implementation (new FSM, buffering, packing). |
| arch/src/main/scala/framework/balldomain/prototype/im2col/FIFO.scala | Introduces RowSlotFIFO helper module used by the new Im2col implementation. |
Comments suppressed due to low confidence (1)
bb-tests/workloads/src/CTest/toy/im2col_test.c:140
main()doesn't return a status code in the non-MULTICORE case. Other CTest workloads return 0/1 so the runner can detect failures via exit status. Pleasereturn passed ? 0 : 1;(orreturn !passed;) after printing.
int main()
{
#ifdef MULTICORE
multicore(MULTICORE);
#endif
int passed = test_im2col();
if (passed)
{
printf("Im2col test PASSED\n");
}
else
{
printf("Im2col test FAILED\n");
}
#ifdef MULTICORE
exit(0);
#endif
}
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| when(windowDone) { | ||
| when(packFull && !io.bankWrite(0).io.req.fire) { | ||
| state := write_window | ||
| }.otherwise { | ||
| when(isLastWindow) { | ||
| state := complete | ||
| }.elsewhen(colEnd) { |
There was a problem hiding this comment.
The module can transition to complete on the last window without flushing a partially-filled packReg beat. This will drop trailing output elements whenever (numWindows * kRow*kCol) is not a multiple of lanesPerBeat. Please add logic to emit a final partial beat (with a proper mask and/or zero-fill) before completing, or otherwise define/guarantee that output lengths are always beat-aligned and enforce it in invalidShape.
| bb_mem_alloc(op2_bank_id, 1, 1); | ||
|
|
||
| bb_mvin((uintptr_t)a, op1_bank_id, size, 1); | ||
| bb_mvin((uintptr_t)a, op1_bank_id, 32, 1); |
There was a problem hiding this comment.
bb_mvin depth is hard-coded to 32, which will cause an out-of-bounds host-memory read since a only contains DIM*DIM elements and the MVIN depth encodes scratchpad beats/rows (e.g., other tests use DIM for a 16x16 matrix). Please compute the MVIN depth from INROW and INCOL (e.g., INROW * ceil(INCOL / LANES_PER_BEAT)) instead of a constant.
| bb_mvin((uintptr_t)a, op1_bank_id, 32, 1); | |
| uint32_t mvin_depth = INROW * ((INCOL + LANES_PER_BEAT - 1) / LANES_PER_BEAT); | |
| bb_mvin((uintptr_t)a, op1_bank_id, mvin_depth, 1); |
| startcol); | ||
| bb_fence(); | ||
| bb_mvout((uintptr_t)b, op2_bank_id, size, 1); | ||
| bb_mvout((uintptr_t)b, op2_bank_id, conv_num() / kernel_elems(), 1); |
There was a problem hiding this comment.
bb_mvout depth is computed as conv_num() / kernel_elems(), which only happens to match the required number of output beats for the current constants (because kernel_elems^2 == LANES_PER_BEAT). For other kernel sizes this will read back the wrong amount of data. Please compute depth as ceil(conv_num() * kernel_elems() / LANES_PER_BEAT) (or equivalent) so it stays correct when KROW/KCOL change.
| bb_mvout((uintptr_t)b, op2_bank_id, conv_num() / kernel_elems(), 1); | |
| int conv = conv_num(); | |
| int elems = kernel_elems(); | |
| int beats = (conv * elems + LANES_PER_BEAT - 1) / LANES_PER_BEAT; | |
| bb_mvout((uintptr_t)b, op2_bank_id, beats, 1); |
| private val rowEnd = rowPtrReg === (startRowReg + rowMax) | ||
| private val colEnd = colPtrReg === (startColReg + colMax) |
There was a problem hiding this comment.
Window termination uses rowEnd = rowPtrReg === (startRowReg + rowMax) / colEnd = colPtrReg === (startColReg + colMax) but rowMax/colMax are already absolute maxima (inRow - kRow, inCol - kCol). Adding startRowReg/startColReg makes the end condition wrong for non-zero start offsets and can drive reads/writes past the valid window range. Please compute end conditions against the absolute last valid top-left indices (e.g., rowPtrReg === (inRowReg - kRowReg) and colPtrReg === (inColReg - kColReg)) and also validate startRowReg <= (inRowReg - kRowReg) and startColReg <= (inColReg - kColReg) in invalidShape.
| private val rowEnd = rowPtrReg === (startRowReg + rowMax) | |
| private val colEnd = colPtrReg === (startColReg + colMax) | |
| private val rowEnd = rowPtrReg === rowMax | |
| private val colEnd = colPtrReg === colMax |
| is(generate_window) { | ||
| val startLane = colPtrReg % lanesPerBeat.U | ||
| val safeKCol = Mux(kColReg === 0.U, 1.U, kColReg) | ||
|
|
||
| // Rearrange data | ||
| // For example, for klen_reg=3, combine (00)(01)(02)(10)(11)(12)(20)(21)(22) | ||
| Cat((0 until InputNum).map(i => window(i)).reverse) | ||
| } | ||
| val t = genElemIdxReg | ||
| val kRowIdx = t / safeKCol | ||
| val kColIdx = t % safeKCol | ||
|
|
||
| val physicalSlot = RowSlotFIFO.logicalToPhysical(rowFifo.io.head, kRowIdx, kRowReg) | ||
| val laneSum = startLane + kColIdx | ||
| val beatIdx = laneSum / lanesPerBeat.U | ||
| val laneIdx = laneSum % lanesPerBeat.U | ||
| val beatWord = lineBuffer(physicalSlot)(beatIdx) | ||
|
|
There was a problem hiding this comment.
generate_window derives beatIdx from startLane = colPtrReg % lanesPerBeat and laneSum / lanesPerBeat, which drops the beat offset for colPtrReg >= lanesPerBeat. For inColReg > 16, windows starting in the 2nd beat will incorrectly read from beat 0. Please compute absCol = colPtrReg + kColIdx and then use beatIdx = absCol / lanesPerBeat and laneIdx = absCol % lanesPerBeat (or include colPtrReg / lanesPerBeat as a base beat).
No description provided.