Merge pull request #1474 from zaliu/master
merge staging db22912 into master on GO from CQE
Benjamin Ulmer authored Mar 5, 2022
2 parents ea38f86 + db22912 commit d5eea38
Showing 16 changed files with 1,131 additions and 323 deletions.
2 changes: 1 addition & 1 deletion HostLibraryTests/CMakeLists.txt
@@ -56,7 +56,7 @@ if(TENSILE_STATIC_ONLY)
endif()

if(NOT Tensile_FOUND)
find_package(Tensile 4.32.0 EXACT REQUIRED ${TENSILE_COMPONENTS} PATHS "${CMAKE_CURRENT_SOURCE_DIR}/../Tensile")
find_package(Tensile 4.32.1 EXACT REQUIRED ${TENSILE_COMPONENTS} PATHS "${CMAKE_CURRENT_SOURCE_DIR}/../Tensile")
endif()

if(NOT TENSILE_DISABLE_CTEST)
6 changes: 3 additions & 3 deletions Tensile/Common.py
@@ -1,5 +1,5 @@
################################################################################
# Copyright 2016-2021 Advanced Micro Devices, Inc. All rights reserved.
# Copyright 2016-2022 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -339,7 +339,7 @@ def getArchitectureName(gfxName):

# These types are newly supported and we would like to use a better file naming for them: _TiToTc_
# For the rest of the types, we keep them with the old existing naming.
typesUsingNewNaming = [ ('H','H','S'), ('H','S','S'), ('B','S','S'),('I8','I','I')]
typesUsingNewNaming = [ ('H','S','S'), ('B','S','S'), ('I8','I','I')] # ('H','H','S') is removed because we are merging this case into HBH

validParameters = {
"LoopDoWhile": [ False, True ], # Source. True=DoWhile, False=For loop
@@ -646,7 +646,7 @@ def getArchitectureName(gfxName):
# - Tail loop can be unrolled up to InnerUnroll amount if AssertSummationElementMultiple%InnerUnroll==0
#
# 1 indicates no assertion (since all sizes are multiples of 1)
"AssertSummationElementMultiple": [1,2,4,8,16,32],
"AssertSummationElementMultiple": [1,2,4,8,16,32,64],

# Kernel generator will assume that the FreeIndex[0] size is some multiple of the element size
# and use this to optimize the kernel.
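As a standalone illustration (not part of this diff, and not Tensile code), the AssertSummationElementMultiple (ASEM) rule described a few lines up can be sketched as follows; the helper name and example sizes are invented:

# A solution asserting ASEM only accepts problems whose summation size is a
# multiple of ASEM; its tail loop can be unrolled by InnerUnroll whenever
# ASEM % InnerUnroll == 0 (ASEM == 1 means "no assertion").
def asem_check(size_k, asem, inner_unroll):
    problem_ok = (size_k % asem == 0)
    tail_unrollable = (asem % inner_unroll == 0)
    return problem_ok, tail_unrollable

print(asem_check(512, 64, 4))  # (True, True)  the newly added ASEM=64 value
print(asem_check(500, 64, 4))  # (False, True) K=500 violates the assertion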
11 changes: 6 additions & 5 deletions Tensile/Components/LocalRead.py
@@ -1,5 +1,5 @@
################################################################################
# Copyright 2021 Advanced Micro Devices, Inc. All rights reserved.
# Copyright 2021-2022 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -215,8 +215,9 @@ def __call__(self, writer, bufferIdx, iui, epsi, tP):
if (kernel["DirectToLds%s" % tP["tensorChar"]] and \
kernel["GlobalLoadVectorWidth%c"%tc] * tP["bpe"] > 4):
# directToLds special case
rIdxMod = rIdx % 2
rIdxDiv = rIdx // 2
divVal = 4 if kernel["ProblemType"]["DataType"].isDoubleComplex() else 2
rIdxMod = rIdx % divVal
rIdxDiv = rIdx // divVal
offset_val = (eIdx + (vIdx * numOffsets+oIdx) * MIWaveGroupShape[tile01]) * tileStride
offset_val = (rIdxDiv * UnrollStride + offset_val + tP["localReadOffset"]) * tP["bpe"] + rIdxMod * writer.bpr
else:
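A standalone reading of the new split above (a sketch, not Tensile code; it assumes a double-complex element spans four 32-bit registers, while the other wide types handled here span two):

# For DirectToLds with global loads wider than 4 bytes, a local-read register
# index rIdx is split into a dword lane within the element (rIdxMod) and an
# element index (rIdxDiv); double complex needs divVal = 4 instead of 2.
def split_read_index(r_idx, is_double_complex):
    div_val = 4 if is_double_complex else 2
    return r_idx % div_val, r_idx // div_val  # (rIdxMod, rIdxDiv)

print(split_read_index(5, False))  # (1, 2)
print(split_read_index(5, True))   # (1, 1)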
@@ -271,11 +272,11 @@ def __call__(self, writer, bufferIdx, iui, epsi, tP):
bit3 = offset_val & 8
bit4 = offset_val & 16
bit5 = offset_val & 32
if (kernel["VectorWidth"] * tP["bpe"] == 8):
if (kernel["GlobalLoadVectorWidth%s"%tc] * tP["bpe"] == 8):
# dword_x2 case
# (bit2<<3) | (bit3 >>1) | (bit4>>1) | (bit5>>1)
newVal = (bit2<<3) | (bit3 >>1) | (bit4>>1) | (bit5>>1)
else: #if (kernel["VectorWidth"] * tP["bpe"] == 16): # most preferred case
else: #if (kernel["GlobalLoadVectorWidth%s"%tc] * tP["bpe"] == 16): # most preferred case
# dword_x4 case
# (bit2<<3) | (bit3 <<1) | (bit4>>2) | (bit5>>2)
newVal = (bit2<<3) | (bit3 <<1) | (bit4>>2) | (bit5>>2)
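The permutation itself can be lifted into a minimal standalone sketch (the function name is invented, and only the permuted bits are returned since their recombination happens outside this hunk):

# DirectToLds offset swizzle from the hunk above: bits 2..5 of the offset are
# permuted differently for 8-byte (dword_x2) and 16-byte (dword_x4) global
# loads; the fixed check now keys off GlobalLoadVectorWidth instead of VectorWidth.
def swizzle_bits(offset_val, bytes_per_load):
    bit2 = offset_val & 4
    bit3 = offset_val & 8
    bit4 = offset_val & 16
    bit5 = offset_val & 32
    if bytes_per_load == 8:
        # dword_x2: (bit2<<3) | (bit3>>1) | (bit4>>1) | (bit5>>1)
        return (bit2 << 3) | (bit3 >> 1) | (bit4 >> 1) | (bit5 >> 1)
    # dword_x4 (most preferred): (bit2<<3) | (bit3<<1) | (bit4>>2) | (bit5>>2)
    return (bit2 << 3) | (bit3 << 1) | (bit4 >> 2) | (bit5 >> 2)

print(swizzle_bits(0b101100, 16))  # 56 for this example offset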
6 changes: 5 additions & 1 deletion Tensile/Components/Signature.py
@@ -71,7 +71,11 @@ def getDstValueType(kernel, cov):
return dstValueType

def getCptValueType(kernel, cov):
cptValueType = cptValueTypeDict[kernel["ProblemType"]["DataType"].toNameAbbrev()]
if kernel["ProblemType"]["DataType"].isHalf() and kernel["ProblemType"]["HighPrecisionAccumulate"]:
cptValueType = "F32"
else:
cptValueType = cptValueTypeDict[kernel["ProblemType"]["DataType"].toNameAbbrev()]

if cov == "V3":
cptValueType = cptValueType.lower()
return cptValueType
94 changes: 56 additions & 38 deletions Tensile/KernelWriter.py
@@ -1,5 +1,5 @@
################################################################################
# Copyright 2016-2021 Advanced Micro Devices, Inc. All rights reserved.
# Copyright 2016-2022 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -399,7 +399,7 @@ def assignParamSplitLds(numMfmaBetweenLWandBarrier):
# to schedule global read for DTV after lwEndMfmaIndex or execute PostLoop after StoreC in NoLoadLoop
if (kernel["PrefetchGlobalRead"] == 2 and (kernel["DirectToVgprA"] or kernel["DirectToVgprB"])) or \
(lastLoop and kernel["StoreCInUnrollPostLoop"]):
self.lwEndMfmaIndex = min(self.lwEndMfmaIndex, numMfmaPerIter * (kernel["LoopIters"] - 1))
self.lwEndMfmaIndex = min(self.lwEndMfmaIndex, numMfmaPerIter * (kernel["LoopIters"] - 1) - 1)
localWriteEndIter = self.lwEndMfmaIndex//numMfmaPerIter
localWriteEndIter = min(kernel["LoopIters"] - 1, localWriteEndIter)
assert localWriteEndIter < kernel["LoopIters"]
@@ -612,6 +612,16 @@ def assignParamSplitLds(numMfmaBetweenLWandBarrier):

assert not itemsGRToSched # should have scheduled everything already, itemsGRToSched should be empty

# adjustment for StoreCInUnroll
# lastLoop case, make the last perIterGlobalReadCode[] (LoopIters-1) empty
# otherwise, mixing global read inc code and StoreCInUnroll post code could cause a memory access issue
if kernel["StoreCInUnroll"] and lastLoop:
lastIter = kernel["LoopIters"] - 1
prevLastIter = max(0, lastIter - 1)
if prevLastIter < lastIter:
while self.perIterGlobalReadCode[lastIter].items():
self.perIterGlobalReadCode[prevLastIter].addCode(self.perIterGlobalReadCode[lastIter].items().pop(0))

self.perIterGlobalReadCode[endIter-1].addCode(self.globalReadACode.footer)
self.perIterGlobalReadCode[endIter-1].addCode(self.globalReadBCode.footer)
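A stand-in for the lastLoop adjustment above, using plain lists in place of Tensile's Code modules (names and contents are invented):

# StoreCInUnroll, lastLoop case: drain the last iteration's global read code
# into the previous iteration so global-read increments never interleave with
# the StoreCInUnroll post code emitted in that final iteration.
per_iter_global_read = {0: ["inc A srd"], 1: ["inc B srd", "inc A tail"]}
last_iter = 1
prev_last_iter = max(0, last_iter - 1)
if prev_last_iter < last_iter:
    while per_iter_global_read[last_iter]:
        per_iter_global_read[prev_last_iter].append(per_iter_global_read[last_iter].pop(0))
print(per_iter_global_read)  # {0: ['inc A srd', 'inc B srd', 'inc A tail'], 1: []}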

@@ -689,13 +699,9 @@ def assignParamSplitLds(numMfmaBetweenLWandBarrier):
readsToWait = len(list(self.localWriteACode.items())) + len(list(self.localWriteBCode.items()))
readsToWaitDTV = 0
# add waitcnt for DirectToVgpr. Delaying wait for DirectToVgpr global read
if kernel["DirectToVgprA"]:
# readsToWait += kernel["NumLoadsPerpendicularA"] * kernel["NumLoadsCoalescedA"] * self.numReadVectorComponentsA
# Except for PGR=2, add the number of global read with DirectToVgpr
readsToWaitDTV += len(list(self.globalReadACode.middle.items()))
elif kernel["DirectToVgprB"]:
# readsToWait += kernel["NumLoadsPerpendicularB"] * kernel["NumLoadsCoalescedB"] * self.numReadVectorComponentsB
# Except for PGR=2, add the number of global read with DirectToVgpr
if kernel["DirectToVgprA"] or kernel["DirectToVgprB"]:
# DirectToVgprA case, actual A load is in self.globalReadBCode (due to swap).
# Need to check self.globalReadBCode
readsToWaitDTV += len(list(self.globalReadBCode.middle.items()))
# add waitcnt for StoreCInUnroll. Delaying wait for Load C
readsToWait += numGlobalReadC
@@ -742,14 +748,20 @@ def assignParamSplitLds(numMfmaBetweenLWandBarrier):
if kernel["StoreCInUnroll"] or kernel["PrefetchGlobalRead"]==2:
if "s_waitcnt" in str(item) and "__placeholder__" in str(item):
# waitcnt adjustment for StoreCInUnroll
readsToWaitAdjust = readsToWait - numGlobalReadC
readsToWaitAdjust = readsToWait + readsToWaitDTV - numGlobalReadC
if kernel["PrefetchGlobalRead"]==2:
# PGR=2 special cases
if (kernel["AtomicAddC"] or not kernel["ProblemType"]["UseBeta"]):
# no Load C case
if not firstIter:
# PGR=2 and not firstIter case, __placeholder__ includes num of storeC from previous Iter
readsToWaitAdjust += readsToWaitAdjustForStoreC
else:
# Load C case
# adjustment for waitcnt for loadC
if kernel["StoreCInUnroll"] and self.StoreCUnrollLoadCWaitComment in str(item):
# readsToWaitDTV should not be added for loadC waitcnt
readsToWaitAdjust -= readsToWaitDTV
if kernel["NoLdsWriteCode"]:
# DirectToLds or DirectToVgpr for both A and B case, use the number of global read for both A and B as vmcnt
readsToWaitAdjust = len(list(self.globalReadACode.middle.items())) + len(list(self.globalReadBCode.middle.items()))
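A standalone summary of the waitcnt arithmetic above (the counts are invented, and only the loadC-placeholder branch is traced; this is a sketch, not Tensile code):

# The __placeholder__ waitcnt starts from the outstanding A/B reads plus the
# DirectToVgpr reads, subtracts the loadC reads, and then drops the DirectToVgpr
# count again when the placeholder is the one guarding loadC.
reads_to_wait = 8        # outstanding global reads counted from local-write items
reads_to_wait_dtv = 4    # DirectToVgpr global reads
num_global_read_c = 2    # loadC instructions

reads_to_wait_adjust = reads_to_wait + reads_to_wait_dtv - num_global_read_c
is_load_c_waitcnt = True
if is_load_c_waitcnt:
    reads_to_wait_adjust -= reads_to_wait_dtv  # DTV reads are not waited on here
print(reads_to_wait_adjust)  # 6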
@@ -827,8 +839,6 @@ def makeSubIterSchedule(self, kernel, localReadCode, iteration, pointerLWCode, p
localWriteCode = self.perIterLocalWriteCode[iteration]
isBarrier = kernel["LoopIters"] - self.numItersPLR
hasLocalRead = localReadCode.countType(Code.LocalReadInst)
if kernel["StoreCInUnroll"]:
storeCUnrollPostCodeList = list(self.StoreCUnrollPostCode.items())
# Default schedule is other, local reads, then local writes:
if self.scheduleIterAlg==0:
# simple schedule, just add the modules in-order
@@ -1349,6 +1359,10 @@ def makeSubIterSchedule(self, kernel, localReadCode, iteration, pointerLWCode, p
numLoadVgpr = len(list(globalReadCodeDTV.items()))
if numLoadVgpr > 0:
interval = roundUp(numMfmaPerIter / origLenGlobalReadCodeDTV)
if kernel["ProblemType"]["DataType"].isDoubleComplex() and (kernel["MIWaveTile"][0] // kernel["VectorWidth"]) > 1:
# adjustment for double complex
# limit interval to at most 4 if (kernel["MIWaveTile"][0] // kernel["VectorWidth"]) > 1
interval = min(4, interval)
numInstToInsert = roundUp(origLenGlobalReadCodeDTV / numMfmaPerIter)
remainingTimesToInsert = roundUp(numLoadVgpr / numInstToInsert)
insertMfmaIndex = kernel["LoopIters"] * numMfmaPerIter - 1 - interval * (remainingTimesToInsert - 1)
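With invented counts, the placement arithmetic just above shows how the new double-complex clamp moves the remaining DirectToVgpr loads later in the loop (a standalone sketch, not Tensile code):

from math import ceil  # stands in for Tensile's roundUp

num_mfma_per_iter = 16
loop_iters = 2
orig_len_dtv = 2     # DirectToVgpr load instructions in the original code
num_load_vgpr = 2    # loads still left to schedule

interval = ceil(num_mfma_per_iter / orig_len_dtv)            # 8 without the clamp
clamped = min(4, interval)                                   # double-complex clamp
num_inst_to_insert = ceil(orig_len_dtv / num_mfma_per_iter)  # 1
remaining = ceil(num_load_vgpr / num_inst_to_insert)         # 2
for name, itv in (("unclamped", interval), ("clamped", clamped)):
    insert_mfma_index = loop_iters * num_mfma_per_iter - 1 - itv * (remaining - 1)
    print(name, insert_mfma_index)  # unclamped 23, clamped 27 (closer to loop end)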
@@ -1360,14 +1374,18 @@ def makeSubIterSchedule(self, kernel, localReadCode, iteration, pointerLWCode, p
# scheduled StoreCInUnrollPostProcess
####
if kernel["StoreCInUnroll"]:
numItems = len(storeCUnrollPostCodeList)
if numItems > 0:
numItems = len(self.StoreCUnrollPostCode.items())
# need to make sure all global read inc is already generated
# (iteration should be the last one)
if numItems > 0 and iteration == kernel["LoopIters"] - 1 and len(globalReadCode.items()) == 0:
totalMfma = kernel["LoopIters"] * numMfmaPerIter
interval = 1
numInstToInsert = 1
numInstToInsert = roundUp(numItems / (totalMfma - mfmaIndex))
remainingTimesToInsert = roundUp(numItems / numInstToInsert)
insertMfmaIndex = kernel["LoopIters"] * numMfmaPerIter - 2 - interval * (remainingTimesToInsert - 1)
if mfmaIndex == insertMfmaIndex:
iterCode.addCode(storeCUnrollPostCodeList.pop(0))
insertMfmaIndex = totalMfma - 2 - interval * (remainingTimesToInsert - 1)
if mfmaIndex >= insertMfmaIndex:
for i in range(numInstToInsert):
iterCode.addCode(self.StoreCUnrollPostCode.items().pop(0))

if kernel["StorePriorityOpt"] and kernel["PrefetchGlobalRead"] == 2 and \
(mfmaIndex == self.barrierMfmaIndex or mfmaIndex == (kernel["LoopIters"] * numMfmaPerIter - 1)):
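The new drain logic for StoreCInUnrollPostProcess above can be traced the same way with invented numbers (a sketch, not Tensile code):

from math import ceil  # stands in for Tensile's roundUp

num_mfma_per_iter = 8
loop_iters = 2
mfma_index = 13   # current MFMA slot inside the last iteration
num_items = 3     # StoreCUnrollPostCode items still queued

# post-process items are only drained once no global read code remains
# (i.e. in the last iteration), spread across the remaining MFMA slots
total_mfma = loop_iters * num_mfma_per_iter
interval = 1
num_inst_to_insert = ceil(num_items / (total_mfma - mfma_index))  # 1
remaining = ceil(num_items / num_inst_to_insert)                  # 3
insert_mfma_index = total_mfma - 2 - interval * (remaining - 1)   # 12
if mfma_index >= insert_mfma_index:
    print("drain", num_inst_to_insert, "item(s) at MFMA", mfma_index)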
@@ -2650,10 +2668,6 @@ def kernelBody( self, kernel, tensorParametersA, tensorParametersB ):
# on each unroll iteration.
self.doShadowInit = 1 # 1 is just store setup

# for PersistentKernel with HPA mode, we can do the alpha, beta conversion f16->f32 only once outside the PK-loop
if kernel["PersistentKernel"]:
kl.append( self.checkAlphaBetaForHPA(kernel))

if self.prefetchAcrossPersistent:
# SrdC/D init before persistent loop
kl.append(self.globalWriteWorkGroupInitBeforePersistentLoop(kernel))
@@ -2853,15 +2867,17 @@ def kernelBody( self, kernel, tensorParametersA, tensorParametersB ):
self.saveLocalPointers(kernel)
# deepCopy packCode for OptNLL noLoadLoop
deepCopyPack = copy.deepcopy(pack)
# keep StoreCInUnrollPreCode for the next noLoadLoop
# keep StoreCInUnrollPreCode, StoreCUnrollPostCode for the next noLoadLoop
if kernel["StoreCInUnroll"]:
StoreCUnrollPreCodeBackup = copy.deepcopy(self.StoreCUnrollPreCode)
StoreCUnrollPostCodeBackup = copy.deepcopy(self.StoreCUnrollPostCode)
isPap = self.prefetchAcrossPersistent
kl += self.noLoadLoop(kernel, tensorParametersA, tensorParametersB, isOptNLL=True, isPap=isPap, isNGLL=False, pack=deepCopyPack)
self.restoreLocalPointers(kernel)
# restore StoreCInUnroll related parameters
if kernel["StoreCInUnroll"]:
self.StoreCUnrollPreCode = StoreCUnrollPreCodeBackup
self.StoreCUnrollPostCode = StoreCUnrollPostCodeBackup
self.StoreCUnrollLoopCodeStarted = 0

papMode = self.prefetchAcrossPersistent and kernel["PrefetchAcrossPersistentMode"] == 1
@@ -4066,14 +4082,7 @@ def openLoopCopy(self, kernel, lc):
@abc.abstractmethod
def endSummation(self, kernel, label = None, isOptNLL = False):
return ""

##############################################################################
# Convert Alpha, Beta from F16 to F32 for HPA
##############################################################################
@abc.abstractmethod
def checkAlphaBetaForHPA(self, kernel):
return ""


##############################################################################
# MAC Iteration
# useMacro : if true, call the MAC* macro. If False, inline the MACs
@@ -4914,24 +4923,21 @@ def generateStoreCCodeInUnrollLoop(self, kernel, odd, isLast=False):
kStr = ""
for x in self.AlphaOpTemplate.items():
kStr += str(x)

if kStr != "":
self.StoreCUnrollPreCode.addText(kStr)

# count the number of items before StoreC (before beta)
self.numItemsBeforeStoreC = len(list(self.StoreCUnrollPreCode.items()))

# Beta
kStrBeta = ""
for x in self.BetaOpTemplate.items():
kStrBeta += str(x)

# StoreC

# put marker comment to recognize start point of StoreC code
# this must be the first item in self.StoreCUnrollCode.
self.StoreCUnrollCode.addComment0(self.StoreCUnrollStartComment)
# add necessary dummy based on number of mfma instructions between local write items
numMfma = 1 if kernel["LocalWritePerMfma"] == -1 else roundUp(1/kernel["LocalWritePerMfma"])
# put enough interval (=3) for LocalWritePerMfma == -1 case
numMfma = 3 if kernel["LocalWritePerMfma"] == -1 else roundUp(1/kernel["LocalWritePerMfma"])
n = self.numItemsBeforeStoreC - numMfma # first numMfma items are inserted at the start comment and following mfmas
while n >= numMfma:
self.StoreCUnrollCode.addText("")
@@ -4941,7 +4947,19 @@ def generateStoreCCodeInUnrollLoop(self, kernel, odd, isLast=False):
imod = Code.Module()
imod.addComment0(self.StoreCUnrollStoreStartComment)
StartComment = str(imod)
# number of instructions(items) of increment code betwwen MFMAs

# Beta
kStrBeta = ""
for x in self.BetaOpTemplate.items():
kStrBeta += str(x)
# double complex case, put beta instruction separately
if kStrBeta != "" and kernel["ProblemType"]["DestDataType"].isDoubleComplex():
# combine beta code with first StoreC comment to avoid generating beta before alpha
self.StoreCUnrollCode.addText(kStrBeta + StartComment)
kStrBeta = ""
StartComment = ""

# number of instructions(items) of increment code between MFMAs
putCount = 1
postProcessListIndex = 0
# generate post process for StoreCInUnroll loop