
Merge pull request #10 from IST-DASLab/gemm-optimizations
Gemm optimizations
efrantar committed Jan 26, 2024
2 parents b930c72 + c99cd6f commit dd58055
Showing 6 changed files with 67 additions and 24 deletions.
bench.py (2 changes: 1 addition & 1 deletion)
@@ -42,7 +42,7 @@ def benchmark_dense(A, B, C):
}

def benchmark_quant(A, B, C, s, thread_k, thread_n, sms):
- workspace = torch.zeros(256, device=torch.device('cuda:0'))
+ workspace = torch.zeros(C.shape[1] // 128 * 16, device=torch.device('cuda:0'))
res = benchmark(lambda: marlin.mul(A, B, C, s, workspace, thread_k, thread_n, sms))
return {
's': res,
marlin/__init__.py (11 changes: 6 additions & 5 deletions)
@@ -20,18 +20,19 @@

import marlin_cuda

- def mul(A, B, C, s, workspace, thread_k=-1, thread_n=-1, sms=-1):
+ def mul(A, B, C, s, workspace, thread_k=-1, thread_n=-1, sms=-1, max_par=16):
"""Marlin FP16xINT4 multiply; can be used within `torch.compile`.
@A: `torch.half` input matrix of shape `(m, k)` in standard row-major layout
@B: `torch.int` weight matrix of original shape `(k, n)` in Marlin format; see `Layer.pack()`
@C: `torch.half` out matrix of shape `(m, n)` in standard row-major layout
@s: `torch.half` scales of shape `(m / groupsize, n)`
- @workspace: `torch.int` tensor with at least `n / 128` entries that are all zero
+ @workspace: `torch.int` tensor with at least `n / 128 * max_par` entries that are all zero
@thread_k: `k` size of a thread_tile in `B` (can usually be left as auto -1)
@thread_n: `n` size of a thread_tile in `B` (can usually be left as auto -1)
@sms: number of SMs to use for the kernel (can usually be left as auto -1)
+ @max_par: maximum number of batch 64 problems to solve in parallel for large input sizes
"""
- marlin_cuda.mul(A, B, C, s, workspace, thread_k, thread_n, sms)
+ marlin_cuda.mul(A, B, C, s, workspace, thread_k, thread_n, sms, max_par)


# Precompute permutations for Marlin weight and scale shuffling
@@ -90,8 +91,8 @@ def __init__(self, infeatures, outfeatures, groupsize=-1):
self.groupsize = groupsize
self.register_buffer('B', torch.empty((self.k // 16, self.n * 16 // 8), dtype=torch.int))
self.register_buffer('s', torch.empty((self.k // groupsize, self.n), dtype=torch.half))
- # 128 is currently the minimum `tile_n`, hence it gives the maximum workspace size
- self.register_buffer('workspace', torch.zeros(self.n // 128, dtype=torch.int), persistent=False)
+ # 128 is currently the minimum `tile_n`, hence it gives the maximum workspace size; 16 is the default `max_par`
+ self.register_buffer('workspace', torch.zeros(self.n // 128 * 16, dtype=torch.int), persistent=False)

def forward(self, A):
C = torch.empty(A.shape[:-1] + (self.s.shape[1],), dtype=A.dtype, device=A.device)
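
For reference, a minimal usage sketch of the updated API, mirroring the call pattern in test.py below; gen_quant4 is the packing helper defined in test.py (not part of this hunk), and the shapes are purely illustrative:

    import torch
    import marlin
    from test import gen_quant4  # assumed helper: packs a random FP16 weight into Marlin's INT4 format

    m, n, k = 16, 4096, 4096                       # illustrative problem shape
    A = torch.randn((m, k), dtype=torch.half, device='cuda')
    B_ref, B, s = gen_quant4(k, n, groupsize=-1)   # quantized weight and scales
    C = torch.zeros((m, n), dtype=torch.half, device='cuda')

    # The workspace must hold at least n / 128 * max_par zeroed entries (max_par defaults to 16).
    workspace = torch.zeros(n // 128 * 16, dtype=torch.int, device='cuda')
    marlin.mul(A, B, C, s, workspace)              # thread_k, thread_n, sms left at auto (-1)
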
marlin/marlin_cuda.cpp (11 changes: 8 additions & 3 deletions)
@@ -34,7 +34,8 @@ int marlin_cuda(
cudaStream_t stream = 0,
int thread_k = -1,
int thread_n = -1,
- int sms = -1
+ int sms = -1,
+ int max_par = 16
);

const int ERR_PROB_SHAPE = 1;
@@ -48,14 +49,17 @@ void mul(
torch::Tensor& workspace,
int thread_k = -1,
int thread_n = -1,
- int sms = -1
+ int sms = -1,
+ int max_par = 8
) {
int prob_m = A.size(0);
int prob_n = C.size(1);
int prob_k = A.size(1);
int groupsize = (s.size(0) == 1) ? -1 : prob_k / s.size(0);
if (groupsize != -1 && groupsize * s.size(0) != prob_k)
AT_ERROR("k=", prob_k, " not compatible with ", s.size(0), " groups.");
+ if (workspace.numel() < prob_n / 128 * max_par)
+   AT_ERROR("workspace must be of size at least ", prob_n / 128 * max_par, ".");
int dev = A.get_device();
int err = marlin_cuda(
A.data_ptr(),
@@ -69,7 +73,8 @@ void mul(
at::cuda::getCurrentCUDAStream(dev),
thread_k,
thread_n,
- sms
+ sms,
+ max_par
);
if (err == ERR_PROB_SHAPE) {
AT_ERROR(
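
The same checks, restated as a rough Python sketch for readability (the authoritative versions are the AT_ERROR guards above; the function name and exception type here are illustrative):

    def check_mul_args(A, C, s, workspace, max_par=16):
        # Sketch of the shape validation in marlin_cuda.cpp's mul(); not part of the extension itself.
        prob_m, prob_k = A.shape
        prob_n = C.shape[1]
        # A single scale row means one group spanning all of k (groupsize == -1).
        groupsize = -1 if s.shape[0] == 1 else prob_k // s.shape[0]
        if groupsize != -1 and groupsize * s.shape[0] != prob_k:
            raise ValueError(f"k={prob_k} not compatible with {s.shape[0]} groups.")
        # New in this commit: the workspace must cover every parallel batch-64 sub-problem.
        if workspace.numel() < prob_n // 128 * max_par:
            raise ValueError(f"workspace must be of size at least {prob_n // 128 * max_par}.")
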
marlin/marlin_cuda_kernel.cu (59 changes: 48 additions & 11 deletions)
@@ -22,6 +22,7 @@
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
+ #include <iostream>


constexpr int ceildiv(int a, int b) {
@@ -211,35 +212,51 @@ __global__ void Marlin(
// While this kind of partitioning makes things somewhat more complicated, it ensures good utilization of all SMs
// for many kinds of shape and GPU configurations, while requiring as few slow global cross-threadblock reductions as
// possible.

+ // For larger GEMMs we run multiple batchsize 64 versions in parallel for a better partitioning with less reductions
+ int parallel = 1;
+ if (prob_m > 16 * thread_m_blocks) {
+   parallel = prob_m / (16 * thread_m_blocks);
+   prob_m = 16 * thread_m_blocks;
+ }

int k_tiles = prob_k / 16 / thread_k_blocks;
int n_tiles = prob_n / 16 / thread_n_blocks;
- int iters = ceildiv(k_tiles * n_tiles, gridDim.x);
+ int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x);
// Ensure that the number of tiles in each stripe is a multiple of the groupsize; this avoids an annoying special case
// where a stripe starts in the middle of group.
if (group_blocks != -1)
iters = (group_blocks / thread_k_blocks) * ceildiv(iters, (group_blocks / thread_k_blocks));

int slice_row = (iters * blockIdx.x) % k_tiles;
- int slice_col = (iters * blockIdx.x) / k_tiles;
+ int slice_col_par = (iters * blockIdx.x) / k_tiles;
+ int slice_col = slice_col_par;
int slice_iters; // number of threadblock tiles in the current slice
int slice_count = 0; // total number of active threadblocks in the current slice
int slice_idx; // index of threadblock in current slice; numbered bottom to top

+ // We can easily implement parallel problem execution by just remapping indices and advancing global pointers
+ if (slice_col_par >= n_tiles) {
+   A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8;
+   C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8;
+   locks += (slice_col_par / n_tiles) * n_tiles;
+   slice_col = slice_col_par % n_tiles;
+ }

// Compute all information about the current slice which is required for synchronization.
auto init_slice = [&] () {
- slice_iters = iters * (blockIdx.x + 1) - (k_tiles * slice_col + slice_row);
- if (slice_iters < 0 || slice_col >= n_tiles)
+ slice_iters = iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row);
+ if (slice_iters < 0 || slice_col_par >= n_tiles * parallel)
slice_iters = 0;
if (slice_iters == 0)
return;
if (slice_row + slice_iters > k_tiles)
slice_iters = k_tiles - slice_row;
slice_count = 1;
slice_idx = 0;
- int col_first = iters * ceildiv(k_tiles * slice_col, iters);
- if (col_first <= k_tiles * (slice_col + 1)) {
-   int col_off = col_first - k_tiles * slice_col;
+ int col_first = iters * ceildiv(k_tiles * slice_col_par, iters);
+ if (col_first <= k_tiles * (slice_col_par + 1)) {
+   int col_off = col_first - k_tiles * slice_col_par;
slice_count = ceildiv(k_tiles - col_off, iters);
if (col_off > 0)
slice_count++;
@@ -252,6 +269,12 @@
slice_idx--;
}
}
+ if (slice_col == n_tiles) {
+   A += 16 * thread_m_blocks * prob_k / 8;
+   C += 16 * thread_m_blocks * prob_n / 8;
+   locks += n_tiles;
+   slice_col = 0;
+ }
};
init_slice();

@@ -656,13 +679,19 @@ __global__ void Marlin(
if (last) // only the last block in a slice actually writes the result
write_result();
slice_row = 0;
+ slice_col_par++;
slice_col++;
init_slice();
if (slice_iters) {
a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o);
#pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++)
B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
+ if (slice_col == 0) {
+   #pragma unroll
+   for (int i = 0; i < b_sh_wr_iters; i++)
+     B_ptr[i] -= b_gl_stride;
+ }
s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
start_pipes();
}
@@ -713,10 +742,12 @@ int marlin_cuda(
cudaStream_t stream = 0,
int thread_k = -1,
int thread_n = -1,
- int sms = -1
+ int sms = -1,
+ int max_par = 16
) {
int tot_m = prob_m;
int tot_m_blocks = ceildiv(tot_m, 16);
+ int pad = 16 * tot_m_blocks - tot_m;

if (sms == -1)
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
@@ -753,9 +784,15 @@
for (int i = 0; i < tot_m_blocks; i += 4) {
int thread_m_blocks = tot_m_blocks - i;
prob_m = tot_m - 16 * i;
+ int par = 1;
if (thread_m_blocks > 4) {
+ // Note that parallel > 1 currently only works for inputs without any padding
+ par = (16 * thread_m_blocks - pad) / 64;
+ if (par > max_par)
+   par = max_par;
+ prob_m = 64 * par;
+ i += 4 * (par - 1);
thread_m_blocks = 4;
- prob_m = 64;
}

// For compilation speed, we only define the kernel configurations that have seemed useful (in terms of performance)
@@ -774,8 +811,8 @@
else
ret = ERR_KERN_SHAPE;

- A_ptr += 16 * thread_m_blocks * (prob_k / 8);
- C_ptr += 16 * thread_m_blocks * (prob_n / 8);
+ A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par;
+ C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par;
}

return ret;
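
The host-side blocking is easiest to sanity-check with plain integer arithmetic; below is an illustrative Python mirror of the m-loop above (the function names are made up, the arithmetic follows the C++ code), including the new batch-64 parallelization:

    def ceildiv(a, b):
        return -(-a // b)

    def schedule_m(tot_m, max_par=16):
        # Illustrative mirror of the m-blocking loop in marlin_cuda(); returns (prob_m, par) per launch.
        tot_m_blocks = ceildiv(tot_m, 16)
        pad = 16 * tot_m_blocks - tot_m
        launches = []
        i = 0
        while i < tot_m_blocks:
            thread_m_blocks = tot_m_blocks - i
            prob_m = tot_m - 16 * i
            par = 1
            if thread_m_blocks > 4:
                # "parallel > 1 currently only works for inputs without any padding"
                par = min((16 * thread_m_blocks - pad) // 64, max_par)
                prob_m = 64 * par
                i += 4 * (par - 1)
                thread_m_blocks = 4
            launches.append((prob_m, par))
            i += 4
        return launches

    # e.g. schedule_m(768) -> [(768, 12)], schedule_m(1024) -> [(1024, 16)],
    # schedule_m(118) -> [(64, 1), (54, 1)]  (new sizes exercised in test.py)
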
setup.py (2 changes: 1 addition & 1 deletion)
@@ -3,7 +3,7 @@

setup(
name='marlin',
- version='0.1',
+ version='0.1.1',
author='Elias Frantar',
author_email='elias.frantar@ist.ac.at',
description='Highly optimized FP16xINT4 CUDA matmul kernel.',
test.py (6 changes: 3 additions & 3 deletions)
@@ -56,19 +56,19 @@ def reshape(w):
class Test(unittest.TestCase):

def run_problem(self, m, n, k, thread_k, thread_n, groupsize=-1):
- print('% 4d % 6d % 6d % 4d % 4d % 4d' % (m, n, k, thread_k, thread_n, groupsize))
+ print('% 5d % 6d % 6d % 4d % 4d % 4d' % (m, n, k, thread_k, thread_n, groupsize))
A = torch.randn((m, k), dtype=torch.half, device=DEV)
B_ref, B, s = gen_quant4(k, n, groupsize=groupsize)
C = torch.zeros((m, n), dtype=torch.half, device=DEV)
C_ref = torch.matmul(A, B_ref)
- workspace = torch.zeros(n // 128, device=DEV)
+ workspace = torch.zeros(n // 128 * 16, device=DEV)
marlin.mul(A, B, C, s, workspace, thread_k, thread_n, -1)
torch.cuda.synchronize()
self.assertLess(torch.mean(torch.abs(C - C_ref)) / torch.mean(torch.abs(C_ref)), 0.001)

def test_tiles(self):
print()
- for m in [1, 2, 3, 4, 8, 12, 16, 24, 32, 48, 64, 128, 152]:
+ for m in [1, 2, 3, 4, 8, 12, 16, 24, 32, 48, 64, 118, 128, 152, 768, 1024]:
for thread_k, thread_n in [(64, 256), (128, 128)]:
if m > 16 and thread_k == 128:
continue
