Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 7 additions & 10 deletions include/sofieBLAS/backends/cuda/sofieBLAS_cublas.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ struct PairEq {

class BlasCuda {
cublasLtHandle_t ltHandle = nullptr;
cublasHandle_t handle = nullptr;
cublasLtMatmulDesc_t operationDesc = nullptr;
cublasLtMatmulPreference_t preference = nullptr;
void *d_workspace = nullptr;
Expand All @@ -72,7 +71,6 @@ class BlasCuda {
BlasCuda(alpaka::QueueCudaRtNonBlocking &queue) : m_queue{queue} {
stream = static_cast<cudaStream_t>(m_queue.getNativeHandle());
CHECK_CUBLAS(cublasLtCreate(&ltHandle));
CHECK_CUBLAS(cublasCreate(&handle));
heuristic = {};
CHECK_CUBLAS(cublasLtMatmulDescCreate(&operationDesc, CUBLAS_COMPUTE_32F,
CUDA_R_32F));
Expand Down Expand Up @@ -118,10 +116,10 @@ class BlasCuda {
}
}

void AddLayoutConfig(std::size_t m, std::size_t n, std::size_t k) {
CheckAndAddLayout(k, m);
CheckAndAddLayout(k, n);
CheckAndAddLayout(m, n);
/// Registers the three matrix layouts associated with an (m, n, k) problem.
///
/// Forwards, in this order, the (rows, cols, leading dimension) triples
/// (k, m, lda), (k, n, ldb) and (m, n, ldc) to CheckAndAddLayout.
///
/// @param m   Extent used as columns of the first layout and rows of the third.
/// @param n   Extent used as columns of the second and third layouts.
/// @param k   Extent used as rows of the first and second layouts.
/// @param lda Leading dimension passed along with the (k, m) layout.
/// @param ldb Leading dimension passed along with the (k, n) layout.
/// @param ldc Leading dimension passed along with the (m, n) layout.
///
/// NOTE(review): presumably these are the A, B and C operands of a GEMM call;
/// confirm against the gemm/gemmrelu callers.
/// NOTE(review): CheckAndAddLayout appears to look layouts up by (rows, cols)
/// only, so a different leading dimension for an already-registered shape may
/// not create a new layout; verify this is intended.
void AddLayoutConfig(std::size_t m, std::size_t n, std::size_t k,
                     std::size_t lda, std::size_t ldb, std::size_t ldc) {
  // One row per layout: {rows, cols, leading dimension}.
  const std::size_t layoutSpecs[3][3] = {
      {k, m, lda},
      {k, n, ldb},
      {m, n, ldc},
  };
  for (const auto &spec : layoutSpecs) {
    CheckAndAddLayout(spec[0], spec[1], spec[2]);
  }
}

template <typename T, typename TIdx>
Expand Down Expand Up @@ -171,7 +169,6 @@ gemm(char transa, char transb, const unsigned int m,
1,
&localHeuristic,
&returnedResults));

if (returnedResults == 0) {
cublasLtMatmulDescDestroy(localDesc);
std::cerr << "No suitable cuBLASLt algorithm found!\n";
Expand Down Expand Up @@ -238,7 +235,8 @@ gemmrelu(char transa, char transb, const unsigned int m,
1,
&localHeuristic,
&error_flag));

std::cout << "Requested workspace: "
<< localHeuristic.workspaceSize << std::endl;
if (error_flag == 0) {
cublasLtMatmulDescDestroy(localDesc);
std::cerr << "No suitable cuBLASLt algorithm found!\n";
Expand Down Expand Up @@ -313,11 +311,10 @@ gemmrelu(char transa, char transb, const unsigned int m,
private:
alpaka::QueueCudaRtNonBlocking m_queue;

void CheckAndAddLayout(size_t rows, size_t cols) {
void CheckAndAddLayout(size_t rows, size_t cols, size_t ld) {
auto key = std::make_pair(rows, cols);
if (LayoutStore.find(key) == LayoutStore.end()) {
cublasLtMatrixLayout_t temp = nullptr;
size_t ld = rows;
CHECK_CUBLAS(
cublasLtMatrixLayoutCreate(&temp, CUDA_R_32F, rows, cols, ld));
LayoutStore.emplace(key, temp);
Expand Down