forked from huggingface/text-generation-inference
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
<!-- Congratulations! You've made it this far! You're not quite done yet though. Once merged, your PR is going to appear in the release notes with the title you set, so make sure it's a great title that fully reflects the extent of your awesome contribution. Then, please replace this with a description of the change and which issue is fixed (if applicable). Please also include relevant motivation and context. List any dependencies (if any) that are required for this change. Once you're done, someone will review your PR shortly (see the section "Who can review?" below to tag some potential reviewers). They may suggest changes to make the code even better. If no one reviewed your PR after a week has passed, don't hesitate to post a new comment @-mentioning the same persons---sometimes notifications get lost. --> <!-- Remove if not applicable --> Fixes # (issue) - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. <!-- Your PR will be replied to more quickly if you can figure out the right person to tag with @ @OlivierDehaene OR @Narsil -->
- Loading branch information
Showing
59 changed files
with
2,060 additions
and
1,915 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
from text_generation_server.layers.tensor_parallel import ( | ||
TensorParallelColumnLinear, | ||
TensorParallelRowLinear, | ||
TensorParallelEmbedding, | ||
) | ||
from text_generation_server.layers.speculative import SpeculativeHead | ||
from text_generation_server.layers.linear import ( | ||
get_linear, | ||
FastLinear, | ||
) | ||
|
||
# Just to add the `load` methods. | ||
from text_generation_server.layers.layernorm import load_layer_norm | ||
from text_generation_server.layers.conv import load_conv2d |
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
import torch | ||
from loguru import logger | ||
from functools import lru_cache | ||
import bitsandbytes as bnb | ||
from bitsandbytes.nn import Int8Params, Params4bit | ||
|
||
|
||
@lru_cache(1) | ||
def warn_deprecate_bnb(): | ||
logger.warning( | ||
"Bitsandbytes 8bit is deprecated, using `eetq` is a drop-in replacement, and has much better performnce" | ||
) | ||
|
||
|
||
class Linear8bitLt(torch.nn.Module): | ||
def __init__( | ||
self, | ||
weight, | ||
bias, | ||
has_fp16_weights=True, | ||
memory_efficient_backward=False, | ||
threshold=0.0, | ||
index=None, | ||
): | ||
super().__init__() | ||
assert ( | ||
not memory_efficient_backward | ||
), "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0" | ||
self.state = bnb.MatmulLtState() | ||
self.index = index | ||
|
||
# Necessary for stacked layers | ||
self.state.threshold = threshold | ||
self.state.has_fp16_weights = has_fp16_weights | ||
self.state.memory_efficient_backward = memory_efficient_backward | ||
if threshold > 0.0 and not has_fp16_weights: | ||
self.state.use_pool = True | ||
|
||
self.weight = Int8Params( | ||
weight.data, | ||
has_fp16_weights=has_fp16_weights, | ||
requires_grad=has_fp16_weights, | ||
) | ||
self.weight.cuda(weight.device) | ||
self.bias = bias | ||
|
||
def init_8bit_state(self): | ||
self.state.CB = self.weight.CB | ||
self.state.SCB = self.weight.SCB | ||
self.weight.CB = None | ||
self.weight.SCB = None | ||
|
||
def forward(self, x: torch.Tensor): | ||
self.state.is_training = self.training | ||
if self.weight.CB is not None: | ||
self.init_8bit_state() | ||
|
||
# weights are cast automatically as Int8Params, but the bias has to be cast manually | ||
if self.bias is not None and self.bias.dtype != x.dtype: | ||
self.bias.data = self.bias.data.to(x.dtype) | ||
|
||
out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state) | ||
|
||
if not self.state.has_fp16_weights: | ||
if self.state.CB is not None and self.state.CxB is not None: | ||
# we converted 8-bit row major to turing/ampere format in the first inference pass | ||
# we no longer need the row-major weight | ||
del self.state.CB | ||
self.weight.data = self.state.CxB | ||
return out | ||
|
||
|
||
class Linear4bit(nn.Module): | ||
def __init__(self, weight, bias, quant_type): | ||
super().__init__() | ||
self.weight = Params4bit( | ||
weight.data, | ||
requires_grad=False, | ||
compress_statistics=True, | ||
quant_type=quant_type, | ||
) | ||
self.compute_dtype = None | ||
self.weight.cuda(weight.device) | ||
self.bias = bias | ||
|
||
def forward(self, x: torch.Tensor): | ||
# weights are cast automatically as Int8Params, but the bias has to be cast manually | ||
if self.bias is not None and self.bias.dtype != x.dtype: | ||
self.bias.data = self.bias.data.to(x.dtype) | ||
|
||
if getattr(self.weight, "quant_state", None) is None: | ||
print( | ||
"FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first." | ||
) | ||
inp_dtype = x.dtype | ||
if self.compute_dtype is not None: | ||
x = x.to(self.compute_dtype) | ||
|
||
bias = None if self.bias is None else self.bias.to(self.compute_dtype) | ||
out = bnb.matmul_4bit( | ||
x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state | ||
) | ||
|
||
out = out.to(inp_dtype) | ||
|
||
return out |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
from accelerate import init_empty_weights | ||
import torch | ||
|
||
|
||
@classmethod | ||
def load_conv2d(cls, prefix, weights, in_channels, out_channels, kernel_size, stride): | ||
weight = weights.get_tensor(f"{prefix}.weight") | ||
bias = weights.get_tensor(f"{prefix}.bias") | ||
with init_empty_weights(): | ||
conv2d = cls( | ||
in_channels=in_channels, | ||
out_channels=out_channels, | ||
kernel_size=kernel_size, | ||
stride=stride, | ||
) | ||
|
||
conv2d.weight = torch.nn.Parameter(weight) | ||
conv2d.bias = torch.nn.Parameter(bias) | ||
return conv2d | ||
|
||
|
||
@classmethod | ||
def load_conv2d_no_bias( | ||
cls, prefix, weights, in_channels, out_channels, kernel_size, stride | ||
): | ||
weight = weights.get_tensor(f"{prefix}.weight") | ||
with init_empty_weights(): | ||
conv2d = cls( | ||
in_channels=in_channels, | ||
out_channels=out_channels, | ||
kernel_size=kernel_size, | ||
stride=stride, | ||
) | ||
|
||
conv2d.weight = torch.nn.Parameter(weight) | ||
conv2d.bias = None | ||
return conv2d | ||
|
||
|
||
torch.nn.Conv2d.load = load_conv2d | ||
torch.nn.Conv2d.load_no_bias = load_conv2d_no_bias |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
import torch | ||
from EETQ import quant_weights, w8_a16_gemm | ||
|
||
|
||
class EETQLinear(torch.nn.Module): | ||
def __init__( | ||
self, | ||
weight, | ||
bias, | ||
) -> None: | ||
super().__init__() | ||
device = weight.device | ||
if weight.dtype != torch.float16: | ||
weight = weight.to(dtype=torch.float16) | ||
weight = torch.t(weight).contiguous().cpu() | ||
weight, scale = quant_weights(weight, torch.int8, False) | ||
|
||
self.weight = weight.cuda(device) | ||
self.scale = scale.cuda(device) | ||
self.bias = bias.cuda(device) if bias is not None else None | ||
|
||
def forward(self, input: torch.Tensor) -> torch.Tensor: | ||
output = w8_a16_gemm(input, self.weight, self.scale) | ||
output = output + self.bias if self.bias is not None else output | ||
return output |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
import torch | ||
|
||
|
||
def fp8_quantize(weight, qdtype=torch.float8_e4m3fn): | ||
device = weight.device | ||
# weight, scale = quant_weights(weight, torch.int8, False) | ||
finfo = torch.finfo(qdtype) | ||
# Calculate the scale as dtype max divided by absmax | ||
scale = finfo.max / weight.abs().max().clamp(min=1e-12) | ||
# scale and clamp the tensor to bring it to | ||
# the representative range of float8 data type | ||
# (as default cast is unsaturated) | ||
qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max) | ||
# Return both float8 data and the inverse scale (as float), | ||
# as both required as inputs to torch._scaled_mm | ||
qweight = qweight.to(qdtype) | ||
scale = scale.float().reciprocal() | ||
return qweight, scale | ||
|
||
|
||
class Fp8Linear(torch.nn.Module): | ||
def __init__( | ||
self, | ||
weight, | ||
bias, | ||
) -> None: | ||
super().__init__() | ||
self.dtype = weight.dtype | ||
self.qweight, self.scale = fp8_quantize(weight) | ||
|
||
self.bias = bias if bias is not None else None | ||
|
||
def forward(self, input: torch.Tensor) -> torch.Tensor: | ||
qinput, scale = fp8_quantize(input) | ||
output, _ = torch._scaled_mm( | ||
qinput, | ||
self.qweight.t(), | ||
out_dtype=self.dtype, | ||
scale_a=scale, | ||
scale_b=self.scale, | ||
bias=self.bias, | ||
) | ||
return output |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import os | ||
import torch | ||
from text_generation_server.utils.import_utils import ( | ||
SYSTEM, | ||
) | ||
|
||
try: | ||
major, _minor = torch.cuda.get_device_capability() | ||
except Exception: | ||
major = 1 | ||
|
||
HAS_EXLLAMA = False | ||
CAN_EXLLAMA = major >= 8 or SYSTEM == "rocm" | ||
V2 = os.getenv("EXLLAMA_VERSION", "2") == "2" | ||
if os.getenv("DISABLE_EXLLAMA") == "True": | ||
HAS_EXLLAMA = False | ||
elif CAN_EXLLAMA: | ||
try: | ||
if V2: | ||
from text_generation_server.layers.gptq.exllamav2 import ( | ||
QuantLinear as ExllamaQuantLinear, | ||
create_exllama_buffers, | ||
set_device, | ||
) | ||
|
||
HAS_EXLLAMA = "2" | ||
else: | ||
from text_generation_server.layers.gptq.exllama import ( | ||
Ex4bitLinear as ExllamaQuantLinear, | ||
create_exllama_buffers, | ||
set_device, | ||
) | ||
|
||
HAS_EXLLAMA = "1" | ||
|
||
except ImportError: | ||
pass | ||
|
||
from text_generation_server.layers.gptq.quant_linear import QuantLinear |
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.