13 changes: 8 additions & 5 deletions gptqmodel/utils/torch.py
@@ -65,11 +65,14 @@ def timed_gc_collect(*args, **kwargs) -> int:
return collected

# reset dynamo cache on each model load since during ci loop model inference may exhaust cache
torch._dynamo.reset()

# Increase the dynamo cache size limit, default of 8 is too low
if torch._dynamo.config.cache_size_limit < 128:
torch._dynamo.config.cache_size_limit = 128
try:
torch._dynamo.reset()
# Increase the dynamo cache size limit, default of 8 is too low
if torch._dynamo.config.cache_size_limit < 128:
torch._dynamo.config.cache_size_limit = 128
except BaseException:
# triton built from source may be incompatible with _dynamo private api
pass

if hasattr(torch, "cuda") and hasattr(torch.cuda, "is_available") and torch.cuda.is_available():
HAS_CUDA = True
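Note on the change above: the dynamo reset and cache-limit bump are now wrapped in a broad try/except so that a Triton built from source, which may not expose the private `torch._dynamo` API the same way, cannot break module import. A minimal standalone sketch of the same defensive pattern (the helper name `_configure_dynamo` and the configurable limit are mine, not part of this PR):

```python
import torch


def _configure_dynamo(cache_size_limit: int = 128) -> None:
    # Reset the dynamo cache and raise the compile-cache limit, swallowing
    # failures from incompatible private torch._dynamo internals.
    try:
        torch._dynamo.reset()
        if torch._dynamo.config.cache_size_limit < cache_size_limit:
            torch._dynamo.config.cache_size_limit = cache_size_limit
    except BaseException:
        pass  # e.g. a source-built Triton that breaks the private API
```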
18 changes: 9 additions & 9 deletions gptqmodel_ext/pack_block_cpu.cpp
@@ -80,18 +80,22 @@ std::tuple<at::Tensor, at::Tensor> pack_block_cpu(
const int32_t* gidx_ptr = g_idx_i32.const_data_ptr<int32_t>();
int32_t* qweight_ptr = qweight.data_ptr<int32_t>();

int old_threads = at::get_num_threads();
if (threads > 0) {
at::set_num_threads(static_cast<int>(threads));
}

const int64_t out_stride = in_features;
const int64_t scales_stride = out_features;

int64_t grain_size = block_in / word_bits;
if (grain_size <= 0) {
grain_size = 1;
}
if (threads > 0) {
// Limit the number of parallel chunks to roughly `threads` without
// mutating the global ATen thread configuration, keeping the kernel reentrant.
int64_t target_chunk = (num_blocks + threads - 1) / threads;
if (target_chunk <= 0) {
target_chunk = 1;
}
grain_size = std::max<int64_t>(grain_size, target_chunk);
}

at::parallel_for(0, num_blocks, grain_size, [&](int64_t block_begin, int64_t block_end) {
std::array<int32_t, 32> qvals{};
@@ -163,10 +167,6 @@ std::tuple<at::Tensor, at::Tensor> pack_block_cpu(
}
});

if (threads > 0) {
at::set_num_threads(old_threads);
}

at::Tensor zeros_i32_contig = zeros_i32.contiguous();
const int32_t* zeros_ptr = zeros_i32_contig.const_data_ptr<int32_t>();
const int64_t zeros_stride = zeros_i32_contig.size(1);
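The C++ change above replaces the old `at::set_num_threads` save/restore with a grain-size cap, so the requested thread budget is honored without mutating global ATen state and the kernel stays reentrant. The chunking arithmetic, sketched in Python for clarity (function and parameter names are illustrative, not from the extension):

```python
def grain_size_for(num_blocks: int, block_in: int, word_bits: int, threads: int) -> int:
    # Base grain: at least one packed word's worth of rows per task.
    grain = max(block_in // word_bits, 1)
    if threads > 0:
        # Ceiling division: split num_blocks into at most ~`threads` chunks.
        # Passing this grain to at::parallel_for bounds the fan-out without
        # touching the process-wide at::set_num_threads() setting.
        target_chunk = max((num_blocks + threads - 1) // threads, 1)
        grain = max(grain, target_chunk)
    return grain
```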
6 changes: 5 additions & 1 deletion tests/models/model_test.py
@@ -40,7 +40,11 @@

import torch.cuda # noqa: E402
from datasets import load_dataset # noqa: E402
from ovis.image_to_test_dataset import get_calib_dataset # noqa: E402
try:
from ovis.image_to_test_dataset import get_calib_dataset # noqa: E402
except BaseException:
pass

from transformers import AutoProcessor, AutoTokenizer # noqa: E402

from gptqmodel import BACKEND, GPTQModel # noqa: E402
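The guarded import above keeps the test module importable when the optional `ovis` package is absent, at the cost of leaving `get_calib_dataset` undefined in that case. A common variant of the pattern, shown here only as a hedged sketch (the `None` fallback is my addition, not what the PR does), binds a sentinel so later code can check availability explicitly:

```python
try:
    from ovis.image_to_test_dataset import get_calib_dataset
except ImportError:
    # ovis is an optional test dependency; skip the related tests when missing.
    get_calib_dataset = None
```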
28 changes: 18 additions & 10 deletions tests/test_pack.py
@@ -3,6 +3,7 @@
# Contact: qubitium@modelcloud.ai, x.com/qubitium

import math
import time
import unittest

import torch
@@ -62,6 +63,8 @@ def _run_impl(self, impl: str, linear, scales, zeros, g_idx):
scales_T = scales.t().contiguous()
zeros_T = zeros.t().contiguous()

start = time.perf_counter()

if impl == "original":
qlinear.pack_original(linear, scales_T, zeros_T, g_idx=g_idx)
elif impl == "pack_block":
@@ -84,6 +87,9 @@ def _run_impl(self, impl: str, linear, scales, zeros, g_idx):
else:
raise ValueError(f"Unknown impl `{impl}`")

end = time.perf_counter()
duration = end - start

# Move buffers to CPU for comparisons
result = {
"qweight": qlinear.qweight.detach().cpu(),
@@ -93,7 +99,7 @@ def _run_impl(self, impl: str, linear, scales, zeros, g_idx):
}
if hasattr(qlinear, "bias") and qlinear.bias is not None:
result["bias"] = qlinear.bias.detach().cpu()
return result
return result, duration

@parameterized.expand(
[
@@ -109,16 +115,17 @@ def test_pack_consistency(self, bits, group_size):

linear, scales, zeros, g_idx = self._build_inputs(bits, group_size)

baseline = self._run_impl("original", linear, scales, zeros, g_idx)
pack_cpu = self._run_impl("pack_block", linear, scales, zeros, g_idx)
results = {"pack_block": pack_cpu}
baseline, baseline_time = self._run_impl("original", linear, scales, zeros, g_idx)
pack_cpu, pack_cpu_time = self._run_impl("pack_block", linear, scales, zeros, g_idx)
results = {"pack_block": (pack_cpu, pack_cpu_time)}

if torch.cuda.is_available():
results["pack_gpu"] = self._run_impl("gpu", linear, scales, zeros, g_idx)
pack_gpu, pack_gpu_time = self._run_impl("gpu", linear, scales, zeros, g_idx)
results["pack_gpu"] = (pack_gpu, pack_gpu_time)

rows = []
rows.append([f"pack_original (bits={bits}, g={group_size})", 0.0, 0.0, 0.0, 0.0])
for name, tensors in results.items():
rows.append([f"pack_original (bits={bits}, g={group_size})", 0.0, 0.0, 0.0, 0.0, baseline_time * 1e3])
for name, (tensors, duration) in results.items():
diff_qweight = (tensors["qweight"].to(dtype=baseline["qweight"].dtype) - baseline["qweight"]).abs().max().item()
diff_qzeros = (tensors["qzeros"].to(dtype=baseline["qzeros"].dtype) - baseline["qzeros"]).abs().max().item()
diff_scales = (tensors["scales"].to(dtype=baseline["scales"].dtype) - baseline["scales"]).abs().max().item()
@@ -129,6 +136,7 @@ def test_pack_consistency(self, bits, group_size):
diff_qzeros,
diff_scales,
diff_gidx,
duration * 1e3,
])

self.assertTrue(torch.equal(tensors["qweight"], baseline["qweight"]))
@@ -139,7 +147,7 @@ def test_pack_consistency(self, bits, group_size):
print(
tabulate(
rows,
headers=["impl", "max|Δ qweight|", "max|Δ qzeros|", "max|Δ scales|", "max|Δ g_idx|"],
headers=["impl", "max|Δ qweight|", "max|Δ qzeros|", "max|Δ scales|", "max|Δ g_idx|", "time [ms]"],
floatfmt=".3e",
)
)
@@ -155,8 +163,8 @@ def test_pack_negative_g_idx(self):
g_idx_neg = g_idx.to(dtype=torch.int32)
g_idx_neg[::7] -= groups

baseline = self._run_impl("original", linear, scales, zeros, g_idx_neg)
pack_cpu = self._run_impl("pack_block", linear, scales, zeros, g_idx_neg)
baseline, _ = self._run_impl("original", linear, scales, zeros, g_idx_neg)
pack_cpu, _ = self._run_impl("pack_block", linear, scales, zeros, g_idx_neg)

self.assertTrue(torch.equal(pack_cpu["qweight"], baseline["qweight"]))
self.assertTrue(torch.equal(pack_cpu["qzeros"], baseline["qzeros"]))
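The tests/test_pack.py changes above bracket each packing implementation with `time.perf_counter()` and report the duration as a new `time [ms]` column in the tabulate output. A minimal sketch of the same timing pattern as a reusable helper (the `timed` wrapper is illustrative, not part of the PR):

```python
import time


def timed(fn, *args, **kwargs):
    # Return (result, elapsed_seconds) for a single call, mirroring the
    # perf_counter bracketing that _run_impl now performs inline.
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    return result, time.perf_counter() - start

# Illustrative usage: the test multiplies the elapsed seconds by 1e3
# before appending it to the "time [ms]" column.
# result, seconds = timed(qlinear.pack_original, linear, scales_T, zeros_T, g_idx=g_idx)
```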