disable llm_int8 ut (#62282)
RichardWooSJTU committed Mar 5, 2024
1 parent 59c61db commit e816529
Showing 1 changed file with 13 additions and 77 deletions.
test/quantization/test_llm_int8_linear.py: 90 changes (13 additions, 77 deletions)
@@ -15,12 +15,11 @@
 import unittest
 
 import numpy as np
-from test_weight_only_linear import convert_uint16_to_float, get_cuda_version
+from test_weight_only_linear import convert_uint16_to_float
 
 import paddle
 import paddle.nn.quant as Q
 from paddle import base
-from paddle.base import core
 from paddle.base.framework import default_main_program
 from paddle.framework import set_default_dtype
 from paddle.pir_utils import test_with_pir_api
@@ -30,12 +29,7 @@
 default_main_program().random_seed = 42
 
 
-@unittest.skipIf(
-    not core.is_compiled_with_cuda()
-    or get_cuda_version() < 11020
-    or paddle.device.cuda.get_device_capability()[0] < 8,
-    "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8",
-)
+@unittest.skipIf(True, "Disable this unit test in release/2.6")
 class LLMInt8LinearTestCase(unittest.TestCase):
     def config(self):
         self.dtype = 'float16'
@@ -149,25 +143,15 @@ def test_llm_int8_linear(self):
         )
 
 
-@unittest.skipIf(
-    not core.is_compiled_with_cuda()
-    or get_cuda_version() < 11020
-    or paddle.device.cuda.get_device_capability()[0] < 8,
-    "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8",
-)
+@unittest.skipIf(True, "Disable this unit test in release/2.6")
 class LLMInt8LinearTestCase1(LLMInt8LinearTestCase):
     def config(self):
         super().config()
         self.dtype = 'float16'
         self.weight_dtype = "int8"
 
 
-@unittest.skipIf(
-    not core.is_compiled_with_cuda()
-    or get_cuda_version() < 11020
-    or paddle.device.cuda.get_device_capability()[0] < 8,
-    "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8",
-)
+@unittest.skipIf(True, "Disable this unit test in release/2.6")
 class LLMInt8LinearTestCase2(LLMInt8LinearTestCase):
     def config(self):
         super().config()
@@ -176,39 +160,23 @@ def config(self):
         self.weight_dtype = "int8"
 
 
-@unittest.skipIf(
-    not core.is_compiled_with_cuda()
-    or get_cuda_version() < 11020
-    or paddle.device.cuda.get_device_capability()[0] < 8,
-    "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8",
-)
+@unittest.skipIf(True, "Disable this unit test in release/2.6")
 class LLMInt8LinearTestCase3(LLMInt8LinearTestCase):
     def config(self):
         super().config()
         self.dtype = 'bfloat16'
         self.weight_dtype = "int8"
 
 
-@unittest.skipIf(
-    not core.is_compiled_with_cuda()
-    or get_cuda_version() < 11020
-    or paddle.device.cuda.get_device_capability()[0] < 8
-    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
-    "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8 or core is not support bfloat16",
-)
+@unittest.skipIf(True, "Disable this unit test in release/2.6")
 class LLMInt8LinearTestCase4(LLMInt8LinearTestCase):
     def config(self):
         super().config()
         self.dtype = 'float16'
         self.weight_dtype = "int4"
 
 
-@unittest.skipIf(
-    not core.is_compiled_with_cuda()
-    or get_cuda_version() < 11020
-    or paddle.device.cuda.get_device_capability()[0] < 8,
-    "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8",
-)
+@unittest.skipIf(True, "Disable this unit test in release/2.6")
 class LLMInt8LinearTestCase5(LLMInt8LinearTestCase):
     def config(self):
         super().config()
@@ -217,26 +185,15 @@ def config(self):
         self.weight_dtype = "int4"
 
 
-@unittest.skipIf(
-    not core.is_compiled_with_cuda()
-    or get_cuda_version() < 11020
-    or paddle.device.cuda.get_device_capability()[0] < 8
-    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
-    "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8 or core is not support bfloat16",
-)
+@unittest.skipIf(True, "Disable this unit test in release/2.6")
 class LLMInt8LinearTestCase6(LLMInt8LinearTestCase):
     def config(self):
         super().config()
         self.dtype = 'bfloat16'
         self.weight_dtype = "int4"
 
 
-@unittest.skipIf(
-    not core.is_compiled_with_cuda()
-    or get_cuda_version() < 11020
-    or paddle.device.cuda.get_device_capability()[0] < 8,
-    "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8",
-)
+@unittest.skipIf(True, "Disable this unit test in release/2.6")
 class LLMInt8LinearTestCase7(LLMInt8LinearTestCase):
     def config(self):
         super().config()
@@ -246,12 +203,7 @@ def config(self):
         self.token = 1
 
 
-@unittest.skipIf(
-    not core.is_compiled_with_cuda()
-    or get_cuda_version() < 11020
-    or paddle.device.cuda.get_device_capability()[0] < 8,
-    "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8",
-)
+@unittest.skipIf(True, "Disable this unit test in release/2.6")
 class LLMInt8LinearTestCase8(LLMInt8LinearTestCase):
     def config(self):
         super().config()
@@ -262,12 +214,7 @@ def config(self):
         self.token = 1
 
 
-@unittest.skipIf(
-    not core.is_compiled_with_cuda()
-    or get_cuda_version() < 11020
-    or paddle.device.cuda.get_device_capability()[0] < 8,
-    "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8",
-)
+@unittest.skipIf(True, "Disable this unit test in release/2.6")
 class LLMInt8LinearTestCase9(LLMInt8LinearTestCase):
     def config(self):
         super().config()
@@ -277,12 +224,7 @@ def config(self):
         self.token = 1
 
 
-@unittest.skipIf(
-    not core.is_compiled_with_cuda()
-    or get_cuda_version() < 11020
-    or paddle.device.cuda.get_device_capability()[0] < 8,
-    "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8",
-)
+@unittest.skipIf(True, "Disable this unit test in release/2.6")
 class LLMInt8LinearTestCase10(LLMInt8LinearTestCase):
     def config(self):
         super().config()
@@ -293,13 +235,7 @@ def config(self):
         self.token = 1
 
 
-@unittest.skipIf(
-    not core.is_compiled_with_cuda()
-    or not core.is_compiled_with_cuda()
-    or get_cuda_version() < 11020
-    or paddle.device.cuda.get_device_capability()[0] < 8,
-    "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8",
-)
+@unittest.skipIf(True, "Disable this unit test in release/2.6")
 class LLMInt8LinearTestCaseStatic(LLMInt8LinearTestCase):
     def config(self):
         super().config()
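
For reference, a minimal, self-contained sketch of the skip pattern this commit switches the test classes to: an unittest.skipIf decorator with a hard-coded condition in place of the removed CUDA version and compute-capability checks. The class, test, and flag names below are illustrative only, not taken from the Paddle test file.

import unittest

# Hypothetical stand-in flag; the commit hard-codes True directly in the decorator.
DISABLED_IN_RELEASE = True


@unittest.skipIf(DISABLED_IN_RELEASE, "Disable this unit test in release/2.6")
class ExampleSkippedCase(unittest.TestCase):
    def test_addition(self):
        # Never executes while the skip condition is True; the runner
        # reports the test as skipped rather than passed or failed.
        self.assertEqual(1 + 1, 2)


if __name__ == "__main__":
    unittest.main()

Because the condition is a plain constant, the skip applies on every platform, unlike the deleted decorators, which only skipped when CUDA >= 11.2 or compute capability >= 8 was unavailable.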
