Skip to content

Commit

Permalink
Separate edge M and edge N
Browse files Browse the repository at this point in the history
Edge N still uses SVW settings
  • Loading branch information
KKyang committed May 27, 2024
1 parent a7529dc commit 43c64bd
Showing 1 changed file with 68 additions and 34 deletions.
102 changes: 68 additions & 34 deletions tensilelite/Tensile/KernelWriterAssembly.py
Original file line number Diff line number Diff line change
Expand Up @@ -5044,7 +5044,7 @@ def openSumAtLeastUnroll(self, kernel, prefetch, isOptNLL):
module.addSpaceLine()

placeHolder = "skipOptNLL_scc1_placeholder" if self.states.BiasDim == 3 else None
module.add(self.checkIsEdge(kernel, tmpSgprInfo, skipOptNLL, isLongBranch=isLongBranch, placeHolder=placeHolder))
module.add(self.checkIsEdge(kernel, tmpSgprInfo, skipOptNLL, skipOptNLL, isLongBranch=isLongBranch, placeHolder=placeHolder))
module.addSpaceLine()

# Check tail loop required:
Expand Down Expand Up @@ -8279,9 +8279,10 @@ def checkIsBetaZero(self, kernel, tmpSgprInfo, betaLabel, isLongBranch=False, pl
# tmpSgpr must have at least 6 free SGPR
# isEdgeTargetMT0 and isEdgeTargetMT1 are the branch targets if edges are required (formerly a single isEdgeTarget)
##############################################################################
def checkIsEdge(self, kernel, tmpSgprInfo, isEdgeTarget, isLongBranch=False, placeHolder=None):
assert(isinstance(isEdgeTarget, Label))
isEdgeTargetLabel = isEdgeTarget.getLabelName()
def checkIsEdge(self, kernel, tmpSgprInfo, isEdgeTargetMT0, isEdgeTargetMT1, isLongBranch=False, placeHolder=None):
assert(isinstance(isEdgeTargetMT0, Label) and isinstance(isEdgeTargetMT1, Label))
isEdgeTargetMT0Label = isEdgeTargetMT0.getLabelName()
isEdgeTargetMT1Label = isEdgeTargetMT1.getLabelName()
module = Module("checkIsEdge")
tmpS0 = tmpSgprInfo.idx
tmpS1 = tmpS0 + 1
Expand Down Expand Up @@ -8316,9 +8317,9 @@ def checkIsEdge(self, kernel, tmpSgprInfo, isEdgeTarget, isLongBranch=False, pla
module.add(SCmpEQU32(src0=sgpr(tmpS0), src1=sgpr(tmpS0), comment="ForceEdgeStores!"))
if placeHolder == None:
if isLongBranch:
module.add(self.longBranchScc1(isEdgeTarget, posNeg=1, tmpSgprInfo=tmpSgprInfo, comment="jump if edges required"))
module.add(self.longBranchScc1(isEdgeTargetMT0, posNeg=1, tmpSgprInfo=tmpSgprInfo, comment="jump if edges required"))
else:
module.add(SCBranchSCC1(labelName=isEdgeTargetLabel, comment="jump if edges required"))
module.add(SCBranchSCC1(labelName=isEdgeTargetMT0Label, comment="jump if edges required"))
else:
placeHolderModule = Module(placeHolder)
placeHolderModule.addComment1("jump if edges required")
Expand All @@ -8344,9 +8345,9 @@ def checkIsEdge(self, kernel, tmpSgprInfo, isEdgeTarget, isLongBranch=False, pla
module.add(SCmpKGtU32(src=sgpr(tmpS0), simm16=hex(0), comment="rMT1 > 0"))
if placeHolder == None:
if isLongBranch:
module.add(self.longBranchScc1(isEdgeTarget, posNeg=1, tmpSgprInfo=tmpSgprInfo, comment="jump if edges required"))
module.add(self.longBranchScc1(isEdgeTargetMT1, posNeg=1, tmpSgprInfo=tmpSgprInfo, comment="jump if edges required"))
else:
module.add(SCBranchSCC1(labelName=isEdgeTargetLabel, comment="jump if edges required"))
module.add(SCBranchSCC1(labelName=isEdgeTargetMT1Label, comment="jump if edges required"))
else:
placeHolderModule = Module(placeHolder)
placeHolderModule.addComment1("jump if edges required")
Expand Down Expand Up @@ -8708,17 +8709,37 @@ def globalWriteElements(self, kernel, tPA, tPB, vectorWidths_2, vectorWidths_1,
if biasDims is None:
biasDims = [0, 1] if self.states.BiasDim == 3 else [1] if self.states.BiasDim == 2 else [0]
writeLabels = {}
splitMN = False
if True in edges:
if vectorWidths[0] != vectorWidths[1]:
splitMN = True
for beta in betas:
writeLabels[beta] = {}
for edge in edges:
writeLabels[beta]["EdgeCheck0"] = Label(self.labels.getNameInc("GW_B%u_E%u_EdgeCheck0" % ( 1 if beta else 0, 1 if edge else 0) ), "")
writeLabels[beta]["EdgeCheck1"] = Label(self.labels.getNameInc("GW_B%u_E%u_EdgeCheck1" % ( 1 if beta else 0, 1 if edge else 0) ), "")
writeLabels[beta][edge] = {}
if len(biasDims) == 1:
writeLabels[beta][edge][biasDims[0]] = Label(self.labels.getNameInc("GW_B%u_E%u" % ( 1 if beta else 0, 1 if edge else 0) ), "")
if edge:
if splitMN:
writeLabels[beta][edge][biasDims[0]] = []
writeLabels[beta][edge][biasDims[0]].append(Label(self.labels.getNameInc("GW_B%u_E%u_M" % ( 1 if beta else 0, 1 if edge else 0) ), ""))
writeLabels[beta][edge][biasDims[0]].append(Label(self.labels.getNameInc("GW_B%u_E%u_N" % ( 1 if beta else 0, 1 if edge else 0) ), ""))
else:
writeLabels[beta][edge][biasDims[0]] = [Label(self.labels.getNameInc("GW_B%u_E%u" % ( 1 if beta else 0, 1 if edge else 0) ), "")]
else:
writeLabels[beta][edge][biasDims[0]] = [Label(self.labels.getNameInc("GW_B%u_E%u" % ( 1 if beta else 0, 1 if edge else 0) ), "")]
else:
for biasDim in biasDims:
writeLabels[beta][edge][biasDim] = Label(self.labels.getNameInc("GW_B%u_E%u_BD%u" % ( 1 if beta else 0, 1 if edge else 0, biasDim) ), "")
if edge:
if splitMN:
writeLabels[beta][edge][biasDim] = []
writeLabels[beta][edge][biasDim].append(Label(self.labels.getNameInc("GW_B%u_E%u_BD%u_M" % ( 1 if beta else 0, 1 if edge else 0, biasDim) ), ""))
writeLabels[beta][edge][biasDim].append(Label(self.labels.getNameInc("GW_B%u_E%u_BD%u_N" % ( 1 if beta else 0, 1 if edge else 0, biasDim) ), ""))
else:
writeLabels[beta][edge][biasDim]= [Label(self.labels.getNameInc("GW_B%u_E%u_BD%u" % ( 1 if beta else 0, 1 if edge else 0, biasDim) ), "")]
else:
writeLabels[beta][edge][biasDim] = [Label(self.labels.getNameInc("GW_B%u_E%u_BD%u" % ( 1 if beta else 0, 1 if edge else 0, biasDim) ), "")]
endLabel = Label(self.labels.getNameInc("GW_End"), "")

# Layout
Expand Down Expand Up @@ -8858,34 +8879,47 @@ def globalWriteElements(self, kernel, tPA, tPB, vectorWidths_2, vectorWidths_1,
# by now we either jumped to E1 or stayed at E0
for idx1 in reversed(range(len(edges))):
edge = edges[idx1]
edgeModule = Module("edge_%u"%idx1)

edge_mode_pos = 0
for idx2 in range(len(biasDims)):
edge_mode_pos, currentInstLength, activationTypeStr = \
self.globalWriteElementBatch(kernel, tPA, tPB, activation,
applyAlpha, beta, edge, atomic,
vectorWidths, elements, activationLabelList,
tmpVgpr, cvtVgprStruct, activationSetPCStruct, activationEnumStrList,
actPCMaxTempSgpr, isInsertActFunctionCallAddrCalc, toActModuleList,
edgeModule, writeLabels, endLabel,
edge_mode_pos, currentInstLength,
idx0, idx1, idx2, biasDims)
if len(biasDims) == 2:
isLongBranch = True if currentInstLength >= 16384 else False
with self.allocTmpSgpr(3) as tmpSgprInfo:
checkIsBiasDimZero = edgeModule.add(self.checkIsBiasDimZero(kernel, tmpSgprInfo, \
writeLabels[beta][edge][biasDims[1]], isLongBranch=isLongBranch), pos=edge_mode_pos)
currentInstLength += checkIsBiasDimZero.countType(Instruction)

betaModule.add(edgeModule, pos=mod_pos)
loopMN = 2 if (edge and splitMN) else 1
for idxMN in range(loopMN):
edgeStr = ""
if loopMN == 2:
edgeStr = "_M" if idxMN == 0 else "_N"
edgeModule = Module("edge_%u%s"%(idx1, edgeStr))

vectorWidthsNew = vectorWidths
elementsNew = elements
if edge and idxMN == 1:
vectorWidthsNew = [vectorWidths[0], vectorWidths[0]]
elementsNew = [elements[0], elements[0]]

edge_mode_pos = 0
for idx2 in range(len(biasDims)):
edge_mode_pos, currentInstLength, activationTypeStr = \
self.globalWriteElementBatch(kernel, tPA, tPB, activation,
applyAlpha, beta, edge, atomic,
vectorWidthsNew, elementsNew, activationLabelList,
tmpVgpr, cvtVgprStruct, activationSetPCStruct, activationEnumStrList,
actPCMaxTempSgpr, isInsertActFunctionCallAddrCalc, toActModuleList,
edgeModule, writeLabels, endLabel,
edge_mode_pos, currentInstLength,
idx0, idx1, idx2, idxMN, biasDims)
if len(biasDims) == 2:
isLongBranch = True if currentInstLength >= 16384 else False
with self.allocTmpSgpr(3) as tmpSgprInfo:
checkIsBiasDimZero = edgeModule.add(self.checkIsBiasDimZero(kernel, tmpSgprInfo, \
writeLabels[beta][edge][biasDims[1]][idxMN], isLongBranch=isLongBranch), pos=edge_mode_pos)
currentInstLength += checkIsBiasDimZero.countType(Instruction)

betaModule.add(edgeModule, pos=mod_pos)

########################################
# branch if Edge0 or Edge1
if False in edges and True in edges:
isLongBranch = True if currentInstLength >= 16384 else False
with self.allocTmpSgpr(4) as tmpSgprInfo:
labelMT1 = writeLabels[beta][True][biasDims[0]][0] if len(writeLabels[beta][True][biasDims[0]]) == 1 else writeLabels[beta][True][biasDims[0]][1]
checkIsEdge = betaModule.add(self.checkIsEdge(kernel, tmpSgprInfo, \
writeLabels[beta][True][biasDims[0]], isLongBranch=isLongBranch), pos=mod_pos)
writeLabels[beta][True][biasDims[0]][0], labelMT1, isLongBranch=isLongBranch), pos=mod_pos)
currentInstLength += checkIsEdge.countType(Instruction)
betaModules.add(betaModule, pos=0)

Expand Down Expand Up @@ -8962,9 +8996,9 @@ def globalWriteElementBatch(self, kernel, tPA, tPB, activation, \
actPCMaxTempSgpr, isInsertActFunctionCallAddrCalc, toActModuleList, \
edgeModule, writeLabels, endLabel, \
edge_mode_pos, currentInstLength, \
idx0, idx1, idx2, biasDims):
idx0, idx1, idx2, idxMN, biasDims):
biasDim = biasDims[idx2]
edgeModule.add(writeLabels[beta][edge][biasDim])
edgeModule.add(writeLabels[beta][edge][biasDim][idxMN])
if idx2 == 0:
edge_mode_pos = len(edgeModule.items())

Expand Down

0 comments on commit 43c64bd

Please sign in to comment.