Skip to content

Commit

Permalink
Add gemm predict
Browse files Browse the repository at this point in the history
  • Loading branch information
KKyang committed May 9, 2024
1 parent 7d0dd01 commit aa5ccae
Show file tree
Hide file tree
Showing 5 changed files with 274 additions and 4 deletions.
71 changes: 71 additions & 0 deletions submodule/gemmmodel_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
################################################################################
#
# Copyright (C) 2022-2024 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
################################################################################

import sys
import os
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'gemmmodel'))
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'gemmmodel/gmodel_lib'))

from gmodel_lib.problem import GemmProblemFromSizes
from gmodel_lib.gemm_model import GemmModel
from gmodel_lib.gemm_model_types import DataType
from gmodel_lib.arch import *

def predict(transA, transB, m, n, batch_count, k, dtype, soc_s='mi300x', b_in_HBM=True, macroTiles=None):
    """Simulate a GEMM problem with gemmmodel and report the tiles of each result.

    Args:
        transA: Transpose flag for A. Accepted for interface compatibility;
            this wrapper does not use it.
        transB: Transpose flag for B. Accepted for interface compatibility;
            this wrapper does not use it.
        m: GEMM size forwarded to ``GemmProblemFromSizes``.
        n: GEMM size forwarded to ``GemmProblemFromSizes``.
        batch_count: Batch count forwarded to ``GemmProblemFromSizes``.
        k: GEMM size forwarded to ``GemmProblemFromSizes``.
        dtype: Torch-style dtype name. ``'torch.float16'`` and
            ``'torch.bfloat16'`` select ``DataType.BF16``;
            ``'torch.float8_e4m3fn'`` and ``'torch.float8_e5m2'`` select
            ``DataType.FP8``.
        soc_s: Target SoC name. Only ``'mi300x'`` is supported.
        b_in_HBM: When true, ``force_to_cache(MemLoc.HBM, p.b)`` is applied to
            the problem before simulation.
        macroTiles: Iterable of ``(mt0, mt1, unroll, split_summation)``
            sequences that are turned into ``Tile`` objects and passed as
            ``force_tiles``. ``None`` is treated as an empty iterable.

    Returns:
        A list with one ``[mt0, mt1, unroll, split_summation]`` entry per
        result of ``GemmModel.simulate``, in the order the model returns them.

    Raises:
        ValueError: If ``soc_s`` or ``dtype`` is not supported.
    """
    # Validate with real exceptions: `assert` is removed under `python -O`,
    # which would let the function continue with soc/compute_dtype == None.
    if soc_s == 'mi300x':
        from gmodel_lib.arch import MI300X
        soc = MI300X()
    else:
        print("SoC unavailable")
        raise ValueError("SoC unavailable: {!r}".format(soc_s))

    if dtype in ('torch.float16', 'torch.bfloat16'):
        # Both 16-bit float types are modelled as BF16.
        compute_dtype = DataType.BF16
    elif dtype in ('torch.float8_e4m3fn', 'torch.float8_e5m2'):
        compute_dtype = DataType.FP8
    else:
        print("Datatype not found for GEMM")
        raise ValueError("Datatype not found for GEMM: {!r}".format(dtype))

    problem = GemmProblemFromSizes(m, n, k, batch_count, compute_dtype=compute_dtype)
    if b_in_HBM:
        problem.force_to_cache(MemLoc.HBM, problem.b)
    model = GemmModel(soc)
    solution_parms = soc.default_solution_parms(problem)

    # The default macroTiles=None used to raise TypeError when iterated.
    # NOTE(review): how make_solutions treats an empty force_tiles list is not
    # visible here -- confirm that it is the desired fallback.
    tiles = [
        Tile(mt0=mt[0], mt1=mt[1], unroll=mt[2], split_summation=mt[3])
        for mt in (macroTiles if macroTiles is not None else [])
    ]

    solutions = model.make_solutions(problem, solution_parms=solution_parms, force_tiles=tiles)
    results = model.simulate(problem, solutions)

    # Use a distinct loop name so the problem object is not shadowed.
    return [
        [
            result.solution.tile.mt0,
            result.solution.tile.mt1,
            result.solution.tile.unroll,
            result.solution.tile.split_summation,
        ]
        for result in results
    ]
11 changes: 11 additions & 0 deletions tensilelite/Tensile/ClientWriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,6 +630,17 @@ def param(key, value):
param("use-user-args", globalParameters["UseUserArgs"])
param("rotating-buffer-size", globalParameters["RotatingBufferSize"])

# Gemmmodel
wrapperPath = os.path.join(globalParameters["ScriptPath"], "../../submodule")
param("gemm-wrapper-path", wrapperPath)
submodulePath = os.path.join(wrapperPath, "gemmmodel")
if not os.path.isdir(submodulePath) and globalParameters["GemmModelThreshold"] > 0:
print("Submodule Gemmmodel not found. Disable threshold")
threshold = 0
else:
threshold = 0 if globalParameters["GemmModelThreshold"] >= 1 else globalParameters["GemmModelThreshold"]
param("gemm-predict-threshold", threshold)


def writeClientConfig(forBenchmark, solutions, problemSizes, biasTypeArgs, biasDimArgs, activationArgs, stepName, stepBaseDir, newLibrary, codeObjectFiles, tileAwareSelection, configBase = "ClientParameters", libraryFile = None):

Expand Down
2 changes: 2 additions & 0 deletions tensilelite/Tensile/Common.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,8 @@

globalParameters["RotatingBufferSize"] = 0 # Size in MB

globalParameters["GemmModelThreshold"] = 0 # 0 or 1 is off, use a value between 0~1, e.g. 0.4

globalParameters["BuildIdKind"] = "sha1"

# Save a copy - since pytest doesn't re-run this initialization code and YAML files can override global settings - odd things can happen
Expand Down
8 changes: 5 additions & 3 deletions tensilelite/Tensile/Source/client/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
################################################################################
#
# Copyright (C) 2022-2023 Advanced Micro Devices, Inc. All rights reserved.
# Copyright (C) 2022-2024 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -65,14 +65,16 @@ if(TENSILE_USE_OPENMP)
target_link_libraries(TensileClient PRIVATE custom_openmp_cxx)
endif()

find_package (Python3 REQUIRED COMPONENTS Interpreter Development)

add_executable(tensile_client main.cpp)
set_target_properties(tensile_client
PROPERTIES
CXX_STANDARD 20
CXX_STANDARD_REQUIRED ON
CXX_EXTENSIONS OFF)

target_link_libraries(tensile_client PRIVATE TensileHost TensileClient ${Boost_LIBRARIES})
target_include_directories(tensile_client PRIVATE ${Python3_INCLUDE_DIRS})
target_link_libraries(tensile_client PRIVATE TensileHost TensileClient ${Boost_LIBRARIES} Python3::Python)
if(TENSILE_USE_OPENMP)
target_link_libraries(tensile_client PRIVATE custom_openmp_cxx)
endif()
Expand Down
186 changes: 185 additions & 1 deletion tensilelite/Tensile/Source/client/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@
#include <boost/algorithm/string/split.hpp>
#include <boost/program_options.hpp>

// Temporarily include the CPython C API, used to call the gemmmodel Python wrapper
#include <Python.h>

#include <cstddef>

namespace po = boost::program_options;
Expand Down Expand Up @@ -258,6 +261,8 @@ namespace Tensile
("use-gradient", po::value<bool>()->default_value(false), "Use gradient.")
("use-user-args", po::value<bool>()->default_value(false), "Use user argument structure as kernel input.")
("rotating-buffer-size", po::value<int32_t>()->default_value(0), "Size of rotating buffer in the unit of MB.")
("gemm-predict-threshold", po::value<float>()->default_value(0), "Thresholdof the gemm prediction.")
("gemm-wrapper-path", po::value<std::string>()->default_value(""), "Path to gemm predict submodule.")
;
// clang-format on

Expand Down Expand Up @@ -469,6 +474,71 @@ namespace Tensile
} // namespace Client
} // namespace Tensile

/// Tile parameters used to group solutions for the GEMM predictor.
///
/// Stays an aggregate, so `solutionInfo{mt0, mt1, unroll, splitk}` keeps
/// working; the default member initializers only remove indeterminate values
/// from default-constructed instances.
struct solutionInfo
{
    int mt0{0};    ///< Filled from sizeMapping.macroTile.x by the client.
    int mt1{0};    ///< Filled from sizeMapping.macroTile.y by the client.
    int unroll{0}; ///< Filled from sizeMapping.depthU by the client.
    int splitk{0}; ///< Filled from sizeMapping.globalSplitU by the client.

    /// Returns the grouping key "<mt0>_<mt1>_<unroll>_<splitk>".
    /// Const so it can be called on const instances and const references.
    std::string key() const
    {
        return std::to_string(mt0) + "_" + std::to_string(mt1) + "_" + std::to_string(unroll)
               + "_" + std::to_string(splitk);
    }
};

/// Maps a Tensile data type to the torch dtype name passed to the Python
/// predictor. Unsupported types are reported on stderr and yield "".
std::string tensile2TorchDataType(Tensile::DataType dataType)
{
    const char* torchName = nullptr;

    switch(dataType)
    {
    case Tensile::DataType::Half:
        torchName = "torch.float16";
        break;
    case Tensile::DataType::BFloat16:
        torchName = "torch.bfloat16";
        break;
    case Tensile::DataType::Float8:
        torchName = "torch.float8_e4m3fn";
        break;
    case Tensile::DataType::BFloat8:
        torchName = "torch.float8_e5m2";
        break;
    default:
        break;
    }

    if(torchName == nullptr)
    {
        std::cerr << "Currently data type " << static_cast<int>(dataType) << " is unsupported." << std::endl;
        return "";
    }
    return torchName;
}

/// Calls the Python `predict` function for one problem and converts its result.
///
/// Python signature being called:
///   predict(transA, transB, m, n, batch_count, k, dtype,
///           soc_s='mi300x', b_in_HBM=True, macroTiles=None)
///
/// @param pFuncPredict Callable `predict` object (not consumed).
/// @param pMT          Python list of [mt0, mt1, unroll, splitk] lists, passed
///                     as the `macroTiles` keyword argument (not consumed).
/// @param transA,transB Transpose flags forwarded as strings.
/// @param m,n,batch,k  Problem sizes forwarded as integers.
/// @param datatype     Torch-style dtype name forwarded as a string.
/// @return One solutionInfo per well-formed entry of the returned list, in the
///         order Python returned them. Empty on any failure, so the caller can
///         treat "no prediction" as "do not filter".
std::vector<solutionInfo> pyPredict(PyObject *pFuncPredict, PyObject *pMT, const char* transA, const char* transB, int m, int n, int batch, int k, const char* datatype)
{
    std::vector<solutionInfo> info;

    // A failed import/lookup upstream leaves these null; calling through them would crash.
    if(pFuncPredict == nullptr || pMT == nullptr)
    {
        std::cout << "Failed to call function" << std::endl;
        return info;
    }

    // Arguments are rebuilt on every call because the problem sizes change.
    PyObject* pArgs  = Py_BuildValue("(s, s, i, i, i, i, s)", transA, transB, m, n, batch, k, datatype);
    PyObject* pKargs = Py_BuildValue("{s:s, s:i, s:O}", "soc_s", "mi300x", "b_in_HBM", 1, "macroTiles", pMT);
    PyObject* pRes   = nullptr;
    if(pArgs != nullptr && pKargs != nullptr)
        pRes = PyObject_Call(pFuncPredict, pArgs, pKargs);
    Py_XDECREF(pArgs);
    Py_XDECREF(pKargs);

    if(pRes == nullptr)
    {
        std::cout << "Failed to call function" << std::endl;
        // Show the Python traceback and clear the error indicator so later
        // C API calls do not run with a stale exception pending.
        if(PyErr_Occurred())
            PyErr_Print();
        return info;
    }

    if(!PyList_Check(pRes))
    {
        std::cerr << "predict() did not return a list" << std::endl;
        Py_DECREF(pRes);
        return info;
    }

    const Py_ssize_t len = PyList_Size(pRes);
    info.reserve(static_cast<size_t>(len));
    for(Py_ssize_t i = 0; i < len; i++)
    {
        // PyList_GetItem returns borrowed references: no DECREF on the items.
        PyObject* item = PyList_GetItem(pRes, i);
        if(item == nullptr || !PyList_Check(item) || PyList_Size(item) < 4)
        {
            PyErr_Clear();
            std::cerr << "Skipping malformed predict() entry " << i << std::endl;
            continue;
        }

        const long mt0    = PyLong_AsLong(PyList_GetItem(item, 0));
        const long mt1    = PyLong_AsLong(PyList_GetItem(item, 1));
        const long unroll = PyLong_AsLong(PyList_GetItem(item, 2));
        const long splitk = PyLong_AsLong(PyList_GetItem(item, 3));
        // PyLong_AsLong signals failure through the error indicator, not its return value alone.
        if(PyErr_Occurred())
        {
            PyErr_Print();
            continue;
        }

        info.push_back({static_cast<int>(mt0),
                        static_cast<int>(mt1),
                        static_cast<int>(unroll),
                        static_cast<int>(splitk)});
    }

    Py_DECREF(pRes);
    return info;
}

int main(int argc, const char* argv[])
{
using namespace Tensile;
Expand Down Expand Up @@ -531,6 +601,63 @@ int main(int argc, const char* argv[])
{
auto iter = library->solutions.end();
iter--;

numSolutions = 0;
for (auto iter = library->solutions.begin(); iter != library->solutions.end(); iter++)
numSolutions++;
}

// Get solution properties here
auto testThreshold = args["gemm-predict-threshold"].as<float>();
auto testedSolutions = numSolutions * testThreshold;
bool runPredict = testThreshold > 0 ? true: false;

PyObject *pSysPath, *pModule, *pFuncPredict, *pMT;
std::map<std::string, int> solutionInfoGroups;
std::vector<solutionInfo> solutionInfoList;
std::string pTransA, pTransB, pDataType;
if(runPredict)
{
std::cout << "Gemm prediction enabled. Threshold: " << testThreshold << ". Total solutions: " << numSolutions << "." << std::endl;
// Python stuffs
Py_Initialize();

pSysPath = PySys_GetObject("path");
auto modelPath = args["gemm-wrapper-path"].as<std::string>();
PyList_Append(pSysPath, PyUnicode_FromString(modelPath.c_str()));
pModule = PyImport_Import(PyUnicode_FromString("gemmmodel_wrapper"));
pFuncPredict = PyObject_GetAttrString(pModule, "predict");

for (auto iter = library->solutions.begin(); iter != library->solutions.end(); iter++)
{
// Collect all the solutions , ContractionSolution -> solutionInfo
solutionInfo si;
si.mt0 = iter->second->sizeMapping.macroTile.x;
si.mt1 = iter->second->sizeMapping.macroTile.y;
si.unroll = iter->second->sizeMapping.depthU;
si.splitk = iter->second->sizeMapping.globalSplitU;

std::map<std::string, int>::iterator sg;
auto solutionKey = si.key();
sg = solutionInfoGroups.find(solutionKey);
if (sg != solutionInfoGroups.end())
{
sg->second++;
}
else
{
solutionInfoGroups[solutionKey] = 1;
solutionInfoList.push_back(si);
}
}
// Create variables for predict
pMT = PyList_New(solutionInfoList.size());
for(size_t i = 0; i < solutionInfoList.size(); i++)
PyList_SetItem(pMT, i, Py_BuildValue("[i, i, i, i]", solutionInfoList[i].mt0, solutionInfoList[i].mt1, solutionInfoList[i].unroll, solutionInfoList[i].splitk));
// Trans, datatype
pTransA = library->solutions.begin()->second->problemType.transA ? "T" : "N";
pTransB = library->solutions.begin()->second->problemType.transB ? "T" : "N";
pDataType = tensile2TorchDataType(library->solutions.begin()->second->problemType.aType);
}

auto* ptr = new DataInitialization(args, problemFactory);
Expand Down Expand Up @@ -609,14 +736,62 @@ int main(int argc, const char* argv[])
auto inputArr
= dataInit->prepareRotatingGPUOutput(maxRotatingBufferNum, problem, inputs);
bool resetInput = false;

// Run predict
std::vector<solutionInfo> rankedVec;
std::map<std::string, int> runSolutions;

if(runPredict)
{
if(auto gemm = dynamic_cast<ContractionProblemGemm*>(problem))
{
if(gemm->d().dimensions() != 3)
std::cerr << "Wrong dimension " << gemm->d().dimensions() << std::endl;
else
{
int m = gemm->freeSizeA(0);
int n = gemm->freeSizeB(0);
int batch = gemm->batchSize(0);
int k = gemm->boundSize(0);

rankedVec = pyPredict(pFuncPredict, pMT, pTransA.c_str(), pTransB.c_str(), m, n, batch, k, pDataType.c_str());

int addSols = 0;
for(size_t r = 0; r < rankedVec.size(); r++)
{
auto rk = rankedVec[r].key();
runSolutions[rk] = 1;
addSols += solutionInfoGroups[rk];
if(addSols > testedSolutions)
break;
}
}
}
}

while(solutionIterator->moreSolutionsInProblem())
{
auto solution = solutionIterator->getSolution();
if(solution == nullptr)
throw std::runtime_error("Could not find a solution");

// Check whether we should run this solution
bool skipSolution = false;
if(rankedVec.size())
{
solutionInfo si;
si.mt0 = solution->sizeMapping.macroTile.x;
si.mt1 = solution->sizeMapping.macroTile.y;
si.unroll = solution->sizeMapping.depthU;
si.splitk = solution->sizeMapping.globalSplitU;

auto solutionKey = si.key();
if (runSolutions.find(si.key()) == runSolutions.end())
skipSolution = true;
}

listeners.preSolution(*solution);
if(solutionIterator->runCurrentSolution() && runKernels)
if(solutionIterator->runCurrentSolution() && runKernels && !skipSolution)
{
try
{
Expand Down Expand Up @@ -729,6 +904,15 @@ int main(int argc, const char* argv[])

listeners.finalizeReport();

if(runPredict)
{
Py_DECREF(pMT);
Py_DECREF(pSysPath);
Py_DECREF(pModule);
Py_DECREF(pFuncPredict);
Py_Finalize();
}

// error range in shell is [0-255]
return std::min(listeners.error(), 255);
}

0 comments on commit aa5ccae

Please sign in to comment.