Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Separate edge M and edge N #783

Merged
merged 1 commit into from
May 27, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 68 additions & 34 deletions tensilelite/Tensile/KernelWriterAssembly.py
Original file line number Diff line number Diff line change
Expand Up @@ -5044,7 +5044,7 @@ def openSumAtLeastUnroll(self, kernel, prefetch, isOptNLL):
module.addSpaceLine()

placeHolder = "skipOptNLL_scc1_placeholder" if self.states.BiasDim == 3 else None
module.add(self.checkIsEdge(kernel, tmpSgprInfo, skipOptNLL, isLongBranch=isLongBranch, placeHolder=placeHolder))
module.add(self.checkIsEdge(kernel, tmpSgprInfo, skipOptNLL, skipOptNLL, isLongBranch=isLongBranch, placeHolder=placeHolder))
module.addSpaceLine()

# Check tail loop required:
Expand Down Expand Up @@ -8279,9 +8279,10 @@ def checkIsBetaZero(self, kernel, tmpSgprInfo, betaLabel, isLongBranch=False, pl
# tmpSgpr must have at least 6 free SGPR
# isEdgeTargetMT0 / isEdgeTargetMT1 are the branch targets if edges are required (isEdgeTargetMT1 is taken by the "rMT1 > 0" check)
##############################################################################
def checkIsEdge(self, kernel, tmpSgprInfo, isEdgeTarget, isLongBranch=False, placeHolder=None):
assert(isinstance(isEdgeTarget, Label))
isEdgeTargetLabel = isEdgeTarget.getLabelName()
def checkIsEdge(self, kernel, tmpSgprInfo, isEdgeTargetMT0, isEdgeTargetMT1, isLongBranch=False, placeHolder=None):
assert(isinstance(isEdgeTargetMT0, Label) and isinstance(isEdgeTargetMT1, Label))
isEdgeTargetMT0Label = isEdgeTargetMT0.getLabelName()
isEdgeTargetMT1Label = isEdgeTargetMT1.getLabelName()
module = Module("checkIsEdge")
tmpS0 = tmpSgprInfo.idx
tmpS1 = tmpS0 + 1
Expand Down Expand Up @@ -8316,9 +8317,9 @@ def checkIsEdge(self, kernel, tmpSgprInfo, isEdgeTarget, isLongBranch=False, pla
module.add(SCmpEQU32(src0=sgpr(tmpS0), src1=sgpr(tmpS0), comment="ForceEdgeStores!"))
if placeHolder == None:
if isLongBranch:
module.add(self.longBranchScc1(isEdgeTarget, posNeg=1, tmpSgprInfo=tmpSgprInfo, comment="jump if edges required"))
module.add(self.longBranchScc1(isEdgeTargetMT0, posNeg=1, tmpSgprInfo=tmpSgprInfo, comment="jump if edges required"))
else:
module.add(SCBranchSCC1(labelName=isEdgeTargetLabel, comment="jump if edges required"))
module.add(SCBranchSCC1(labelName=isEdgeTargetMT0Label, comment="jump if edges required"))
else:
placeHolderModule = Module(placeHolder)
placeHolderModule.addComment1("jump if edges required")
Expand All @@ -8344,9 +8345,9 @@ def checkIsEdge(self, kernel, tmpSgprInfo, isEdgeTarget, isLongBranch=False, pla
module.add(SCmpKGtU32(src=sgpr(tmpS0), simm16=hex(0), comment="rMT1 > 0"))
if placeHolder == None:
if isLongBranch:
module.add(self.longBranchScc1(isEdgeTarget, posNeg=1, tmpSgprInfo=tmpSgprInfo, comment="jump if edges required"))
module.add(self.longBranchScc1(isEdgeTargetMT1, posNeg=1, tmpSgprInfo=tmpSgprInfo, comment="jump if edges required"))
else:
module.add(SCBranchSCC1(labelName=isEdgeTargetLabel, comment="jump if edges required"))
module.add(SCBranchSCC1(labelName=isEdgeTargetMT1Label, comment="jump if edges required"))
else:
placeHolderModule = Module(placeHolder)
placeHolderModule.addComment1("jump if edges required")
Expand Down Expand Up @@ -8708,17 +8709,37 @@ def globalWriteElements(self, kernel, tPA, tPB, vectorWidths_2, vectorWidths_1,
if biasDims is None:
biasDims = [0, 1] if self.states.BiasDim == 3 else [1] if self.states.BiasDim == 2 else [0]
writeLabels = {}
splitMN = False
if True in edges:
if vectorWidths[0] != vectorWidths[1]:
splitMN = True
for beta in betas:
writeLabels[beta] = {}
for edge in edges:
writeLabels[beta]["EdgeCheck0"] = Label(self.labels.getNameInc("GW_B%u_E%u_EdgeCheck0" % ( 1 if beta else 0, 1 if edge else 0) ), "")
writeLabels[beta]["EdgeCheck1"] = Label(self.labels.getNameInc("GW_B%u_E%u_EdgeCheck1" % ( 1 if beta else 0, 1 if edge else 0) ), "")
writeLabels[beta][edge] = {}
if len(biasDims) == 1:
writeLabels[beta][edge][biasDims[0]] = Label(self.labels.getNameInc("GW_B%u_E%u" % ( 1 if beta else 0, 1 if edge else 0) ), "")
if edge:
if splitMN:
writeLabels[beta][edge][biasDims[0]] = []
writeLabels[beta][edge][biasDims[0]].append(Label(self.labels.getNameInc("GW_B%u_E%u_M" % ( 1 if beta else 0, 1 if edge else 0) ), ""))
writeLabels[beta][edge][biasDims[0]].append(Label(self.labels.getNameInc("GW_B%u_E%u_N" % ( 1 if beta else 0, 1 if edge else 0) ), ""))
else:
writeLabels[beta][edge][biasDims[0]] = [Label(self.labels.getNameInc("GW_B%u_E%u" % ( 1 if beta else 0, 1 if edge else 0) ), "")]
else:
writeLabels[beta][edge][biasDims[0]] = [Label(self.labels.getNameInc("GW_B%u_E%u" % ( 1 if beta else 0, 1 if edge else 0) ), "")]
else:
for biasDim in biasDims:
writeLabels[beta][edge][biasDim] = Label(self.labels.getNameInc("GW_B%u_E%u_BD%u" % ( 1 if beta else 0, 1 if edge else 0, biasDim) ), "")
if edge:
if splitMN:
writeLabels[beta][edge][biasDim] = []
writeLabels[beta][edge][biasDim].append(Label(self.labels.getNameInc("GW_B%u_E%u_BD%u_M" % ( 1 if beta else 0, 1 if edge else 0, biasDim) ), ""))
writeLabels[beta][edge][biasDim].append(Label(self.labels.getNameInc("GW_B%u_E%u_BD%u_N" % ( 1 if beta else 0, 1 if edge else 0, biasDim) ), ""))
else:
writeLabels[beta][edge][biasDim]= [Label(self.labels.getNameInc("GW_B%u_E%u_BD%u" % ( 1 if beta else 0, 1 if edge else 0, biasDim) ), "")]
else:
writeLabels[beta][edge][biasDim] = [Label(self.labels.getNameInc("GW_B%u_E%u_BD%u" % ( 1 if beta else 0, 1 if edge else 0, biasDim) ), "")]
endLabel = Label(self.labels.getNameInc("GW_End"), "")

# Layout
Expand Down Expand Up @@ -8858,34 +8879,47 @@ def globalWriteElements(self, kernel, tPA, tPB, vectorWidths_2, vectorWidths_1,
# by now we either jumped to E1 or stayed at E0
for idx1 in reversed(range(len(edges))):
edge = edges[idx1]
edgeModule = Module("edge_%u"%idx1)

edge_mode_pos = 0
for idx2 in range(len(biasDims)):
edge_mode_pos, currentInstLength, activationTypeStr = \
self.globalWriteElementBatch(kernel, tPA, tPB, activation,
applyAlpha, beta, edge, atomic,
vectorWidths, elements, activationLabelList,
tmpVgpr, cvtVgprStruct, activationSetPCStruct, activationEnumStrList,
actPCMaxTempSgpr, isInsertActFunctionCallAddrCalc, toActModuleList,
edgeModule, writeLabels, endLabel,
edge_mode_pos, currentInstLength,
idx0, idx1, idx2, biasDims)
if len(biasDims) == 2:
isLongBranch = True if currentInstLength >= 16384 else False
with self.allocTmpSgpr(3) as tmpSgprInfo:
checkIsBiasDimZero = edgeModule.add(self.checkIsBiasDimZero(kernel, tmpSgprInfo, \
writeLabels[beta][edge][biasDims[1]], isLongBranch=isLongBranch), pos=edge_mode_pos)
currentInstLength += checkIsBiasDimZero.countType(Instruction)

betaModule.add(edgeModule, pos=mod_pos)
loopMN = 2 if (edge and splitMN) else 1
for idxMN in range(loopMN):
edgeStr = ""
if loopMN == 2:
edgeStr = "_M" if idxMN == 0 else "_N"
edgeModule = Module("edge_%u%s"%(idx1, edgeStr))

vectorWidthsNew = vectorWidths
elementsNew = elements
if edge and idxMN == 1:
vectorWidthsNew = [vectorWidths[0], vectorWidths[0]]
elementsNew = [elements[0], elements[0]]

edge_mode_pos = 0
for idx2 in range(len(biasDims)):
edge_mode_pos, currentInstLength, activationTypeStr = \
self.globalWriteElementBatch(kernel, tPA, tPB, activation,
applyAlpha, beta, edge, atomic,
vectorWidthsNew, elementsNew, activationLabelList,
tmpVgpr, cvtVgprStruct, activationSetPCStruct, activationEnumStrList,
actPCMaxTempSgpr, isInsertActFunctionCallAddrCalc, toActModuleList,
edgeModule, writeLabels, endLabel,
edge_mode_pos, currentInstLength,
idx0, idx1, idx2, idxMN, biasDims)
if len(biasDims) == 2:
isLongBranch = True if currentInstLength >= 16384 else False
with self.allocTmpSgpr(3) as tmpSgprInfo:
checkIsBiasDimZero = edgeModule.add(self.checkIsBiasDimZero(kernel, tmpSgprInfo, \
writeLabels[beta][edge][biasDims[1]][idxMN], isLongBranch=isLongBranch), pos=edge_mode_pos)
currentInstLength += checkIsBiasDimZero.countType(Instruction)

betaModule.add(edgeModule, pos=mod_pos)

########################################
# branch if Edge0 or Edge1
if False in edges and True in edges:
isLongBranch = True if currentInstLength >= 16384 else False
with self.allocTmpSgpr(4) as tmpSgprInfo:
labelMT1 = writeLabels[beta][True][biasDims[0]][0] if len(writeLabels[beta][True][biasDims[0]]) == 1 else writeLabels[beta][True][biasDims[0]][1]
checkIsEdge = betaModule.add(self.checkIsEdge(kernel, tmpSgprInfo, \
writeLabels[beta][True][biasDims[0]], isLongBranch=isLongBranch), pos=mod_pos)
writeLabels[beta][True][biasDims[0]][0], labelMT1, isLongBranch=isLongBranch), pos=mod_pos)
currentInstLength += checkIsEdge.countType(Instruction)
betaModules.add(betaModule, pos=0)

Expand Down Expand Up @@ -8962,9 +8996,9 @@ def globalWriteElementBatch(self, kernel, tPA, tPB, activation, \
actPCMaxTempSgpr, isInsertActFunctionCallAddrCalc, toActModuleList, \
edgeModule, writeLabels, endLabel, \
edge_mode_pos, currentInstLength, \
idx0, idx1, idx2, biasDims):
idx0, idx1, idx2, idxMN, biasDims):
biasDim = biasDims[idx2]
edgeModule.add(writeLabels[beta][edge][biasDim])
edgeModule.add(writeLabels[beta][edge][biasDim][idxMN])
if idx2 == 0:
edge_mode_pos = len(edgeModule.items())

Expand Down