Refactor custom gemm heuristics #56

Merged · 8 commits · Jun 20, 2024
11 changes: 3 additions & 8 deletions vllm/model_executor/layers/linear.py
@@ -2,7 +2,6 @@
from typing import List, Optional

import torch
-import torch.nn.functional as F
from torch.nn.parameter import Parameter

from vllm.distributed import (divide, get_tensor_model_parallel_rank,
@@ -90,13 +89,9 @@ def apply(self,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        weight = layer.weight
-        if self.separate_bias_add:
-            if bias is not None:
-                return tgemm.mm(x, weight) + bias
-            return tgemm.mm(x, weight)
-        elif bias is not None:
-            return F.linear(x, weight, bias)
-        return tgemm.mm(x, weight)
+        if self.separate_bias_add and bias is not None:
+            return tgemm.mm(x, weight) + bias
+        return tgemm.mm(x, weight, bias)


class LinearBase(torch.nn.Module):
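Note on the `linear.py` change: bias handling now lives inside `tgemm.mm`, so the `F.linear` bias path in the linear layer's `apply` method is no longer needed. Below is a minimal sketch of the new call flow, assuming a ROCm build where `tgemm` is importable; the shapes are illustrative only.

```python
import torch

from vllm.model_executor.layers.tuned_gemm import tgemm

# Illustrative shapes only: x is (tokens, in_features), weight is
# (out_features, in_features), matching F.linear's convention.
x = torch.randn(4, 512, dtype=torch.float16, device="cuda")
weight = torch.randn(1024, 512, dtype=torch.float16, device="cuda")
bias = torch.randn(1024, dtype=torch.float16, device="cuda")

# New path: bias is forwarded and added inside tgemm.mm ...
out = tgemm.mm(x, weight, bias)
# ... unless separate_bias_add is set, in which case apply() adds it itself.
out_separate = tgemm.mm(x, weight) + bias
```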
64 changes: 31 additions & 33 deletions vllm/model_executor/layers/tuned_gemm.py
@@ -23,7 +23,7 @@ def __init__(self):
        self.bestsols = {}
        self.load_best_sols()
        self.create_ds()
-        self.CuCount = torch.cuda.get_device_properties(
+        self.cu_count = torch.cuda.get_device_properties(
            device='cuda').multi_processor_count

        if (self.save_gemm == 1):
@@ -51,7 +51,27 @@ def create_ds(self):
    def query_sol(self, m, n, k):
        return self.solids.get((m, n, k), (0, 0))

-    def mm(self, inp, weights):
+    def apply_skinny(self, m, n, k, inp_view, weights):
+        if inp_view.dtype != torch.float16 or k % 8 != 0:
+            return None
+        if m > 8 and n <= 4:
+            out = torch.empty(inp_view.shape[0],
+                              weights.shape[0],
+                              dtype=inp_view.dtype,
+                              device='cuda')
+            _custom_C.wvSpltK(weights, inp_view, out, n, self.cu_count)
+            return out
+        elif m % 4 == 0 and n == 1 and k <= 8192:
+            out = torch.empty(inp_view.shape[0],
+                              weights.shape[0],
+                              dtype=inp_view.dtype,
+                              device='cuda')
+            _custom_C.LLMM1(weights, inp_view, out, 4)
+            return out
+        else:
+            return None
+
+    def mm(self, inp, weights, bias=None):
        # F.Linear can take a 3 dimensional input. vllm
        # uses this for linear units. However, sampler
        # will use torch.matmul with 2 dimensions only
@@ -61,7 +81,6 @@ def mm(self, inp, weights):
        else:
            inp_view = inp
            batched = False
-        #print(f'>>>inp_view {inp_view.shape}')
        if self.extensions_created is False:
            rocb_create_extension()
            hipb_create_extension()
@@ -70,16 +89,15 @@
        n = inp_view.shape[0]
        k = inp_view.shape[1]
        soltype, solidx = self.query_sol(m=m, n=n, k=k)
-        if soltype == 1:
-            #print(">>> found hipblas")
+        out = self.apply_skinny(m, n, k, inp_view, weights)
+        if out is not None:
+            pass
+        elif soltype == 1:
            out = hipb_mm(inp_view, weights.t(), solidx)
        elif soltype == 2:
-            #print(">>> found rocblas")
            out = rocb_mm(inp_view, weights.t(), solidx)
        else:
            if (self.save_gemm == 1):
-                #print('>>>Tgemm Default',inp_view.shape,
-                #      inp.shape,weights.shape,soltype,solidx)
                self.tuned_df = pd.concat([
                    self.tuned_df,
                    pd.DataFrame({
@@ -89,32 +107,12 @@
                    })
                ]).drop_duplicates()
                self.tuned_df.to_csv(self.untune_path, index=False)
-
-            if ((n == 4 or n == 3 or n == 2 or n == 1) and k % 8 == 0
-                    and inp_view.dtype == torch.float16):
-                out = torch.empty(inp_view.shape[0],
-                                  weights.shape[0],
-                                  dtype=inp_view.dtype,
-                                  device='cuda')
-                _custom_C.wvSpltK(weights, inp_view, out, n, self.CuCount)
-            elif n == 1 and inp_view.dtype == torch.float16:
-                out = torch.empty(inp_view.shape[0],
-                                  weights.shape[0],
-                                  dtype=inp_view.dtype,
-                                  device='cuda')
-                if (k == 8192 and
-                        (m == 1280 or m == 7168)) or (k == 3584 and m == 8192):
-                    _custom_C.LLMM1(weights, inp_view, out, 8)
-                elif k <= 8192 and k % 8 == 0 and m % 4 == 0:
-                    _custom_C.LLMM1(weights, inp_view, out, 4)
-                else:
-                    out = F.linear(inp_view, weights)
-            else:
-                out = F.linear(inp_view, weights)
+            return F.linear(inp, weights, bias)
        if batched:
-            return out.view(inp.shape[0], inp.shape[1], weights.shape[0])
-        else:
-            return out
+            out = out.view(inp.shape[0], inp.shape[1], weights.shape[0])
+        if bias is not None:
+            return out + bias
+        return out


tgemm = TunedGemm()
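For a quick sense of the refactored dispatch (skinny custom kernel first, then a tuned hipBLASLt/rocBLAS solution, then the `F.linear` fallback), here is a rough usage sketch. It assumes a ROCm build where `_custom_C` and any tuned-solution CSVs are available; all shapes are illustrative.

```python
import torch

from vllm.model_executor.layers.tuned_gemm import tgemm

w = torch.randn(11008, 4096, dtype=torch.float16, device="cuda")    # m = 11008
bias = torch.randn(11008, dtype=torch.float16, device="cuda")

# n (= rows of the input view) <= 4, fp16, k % 8 == 0: apply_skinny's wvSpltK
# kernel takes precedence over any tuned GEMM solution for this shape.
x_skinny = torch.randn(2, 4096, dtype=torch.float16, device="cuda")  # n=2, k=4096
y_skinny = tgemm.mm(x_skinny, w, bias)   # bias is added after the custom kernel

# Larger n: a tuned hipBLASLt/rocBLAS solution is used if one was recorded for
# (m, n, k); otherwise the call falls back to F.linear(inp, weights, bias).
x_general = torch.randn(128, 4096, dtype=torch.float16, device="cuda")
y_general = tgemm.mm(x_general, w, bias)
```

The only change callers should notice is that `mm` now accepts an optional `bias`, which is either fused into the fallback `F.linear` call or added after the custom/tuned GEMM.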