Skip to content

Commit

Permalink
Merge pull request #6373 from digantdesai:goi_scalar_packing
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 634497827
  • Loading branch information
xnnpack-bot committed May 16, 2024
2 parents ecdb4a9 + fdd08a9 commit 084e96a
Show file tree
Hide file tree
Showing 5 changed files with 326 additions and 0 deletions.
39 changes: 39 additions & 0 deletions src/microparams-init.c
Original file line number Diff line number Diff line change
Expand Up @@ -930,6 +930,45 @@ void xnn_init_qs8_to_qs8_qc8w_scale_fp32_params(
}
}

// Writes per-(channel, block) fp32 scales into a packed weight buffer.
//
// The buffer is organized as num_blocks scale regions, block_stride bytes
// apart. Within each region, scales are written tile by tile: full tiles of
// channels_tile entries spaced `stride` bytes apart, followed by remainder
// subtiles of up to channels_subtile entries spaced `substride` bytes apart.
// `stride_offset` is subtracted from the write pointer once, when switching
// from the full-tile stride to the subtile stride. `scale` is indexed as
// scale[channel * num_blocks + block].
void xnn_init_qs8_qb8w_scale_fp32_params(
  size_t channels,
  size_t channels_tile,
  size_t channels_subtile,
  size_t stride,
  size_t substride,
  size_t num_blocks,
  size_t block_stride,
  size_t stride_offset,
  const float scale[XNN_MIN_ELEMENTS(1)],
  void* packed_w)
{
  const uintptr_t base = (uintptr_t) packed_w;
  const size_t full_tile_end = round_down_po2(channels, channels_tile);
  for (size_t block = 0; block < num_blocks; block++) {
    // Each block's scales start at a fixed offset from the buffer base.
    void* out = (void*) (base + block * block_stride);
    size_t channel = 0;

    // Full tiles: exactly channels_tile scales per tile.
    while (channel < full_tile_end) {
      for (size_t i = 0; i < channels_tile; i++) {
        unaligned_indexed_store_f32(out, i, scale[(channel + i) * num_blocks + block]);
      }
      out = (void*) ((uintptr_t) out + stride);
      channel += channels_tile;
    }

    // Correct the pointer before switching to the subtile stride.
    out = (void*) ((uintptr_t) out - stride_offset);

    // Remainder: subtiles of at most channels_subtile scales each.
    while (channel < channels) {
      const size_t count = min(channels - channel, channels_subtile);
      for (size_t i = 0; i < count; i++) {
        unaligned_indexed_store_f32(out, i, scale[(channel + i) * num_blocks + block]);
      }
      out = (void*) ((uintptr_t) out + substride);
      channel += channels_subtile;
    }
  }
}

size_t xnn_init_qs8_avgpool_minmax_fp32_scalar_fmagic_params(
union xnn_qs8_avgpool_minmax_params params[XNN_MIN_ELEMENTS(1)],
int32_t init_bias,
Expand Down
100 changes: 100 additions & 0 deletions src/packing.c
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,106 @@ void xnn_pack_qs8_qc4w_gemm_goi_w(
} while (--g != 0);
}

// Packs int4 (nibble) weights with block-wise quantization from GOI order
// (groups x output-channels x input-channels) into the interleaved layout
// consumed by qb4w GEMM microkernels.
//
// Per NR tile of output channels the packed stream is:
//   - nr int32 slots, initialized from the bias (or zero); they are re-read
//     as fp32 below and accumulate a per-channel "kscaledsum" correction term
//   - interleaved weight nibbles (two K planes per byte), with extra_bytes_bl
//     reserved after every block of bl K elements (later filled with the
//     block's fp32 scale, e.g. by xnn_init_qs8_qb8w_scale_fp32_params)
//   - extra_bytes_n reserved bytes at the end of the tile
//
// k holds two int4 values per byte, low nibble first; bl is the number of K
// elements sharing one quantization scale; scale is indexed [channel][block].
void xnn_pack_qs8_qb4w_gemm_goi_w(
  size_t g,
  size_t nc,
  size_t kc,
  size_t nr,
  size_t kr,
  size_t sr,
  size_t bl,
  const uint8_t* k,
  const int32_t* b,
  const float* scale,
  void* packed_weights,
  size_t extra_bytes_bl,
  size_t extra_bytes_n,
  const struct xnn_qs8_qc4w_packing_params* params)
{
  assert(g != 0);
  assert(nc != 0);
  assert(kc != 0);
  assert(nr >= sr);
  assert(kr >= 1 && kr <= 16);
  assert(sr >= 1 && sr <= 16);
  assert(k != NULL);
  assert(packed_weights != NULL);
  assert(params != NULL);
  assert(params->kernel_zero_point == 8);  // the ^ 0x88 re-centering below requires zp == 8

  const size_t skr = sr * kr;

  // Constraints for blocksize
  // These need to be reevaluated in the future.
  assert(bl != 0);
  assert(round_up_po2(kc, skr) % bl == 0); // must be round number of blocks inside a column
  assert(bl % skr == 0); // must be round number of kr * sr
  assert(bl <= round_up_po2(kc, skr)); // must be larger than K
  assert(2 * skr <= bl); // must be at least two skr to avoid back-to-back empty_bytes

  // Number of scale blocks along the (skr-padded) K dimension.
  const size_t num_blocks = round_up_po2(kc, skr) / bl;
  const int32_t izp = (int32_t) params->input_zero_point;

  do {
    size_t nr_block_start = 0;
    do {  // one iteration per NR tile of output channels
      const size_t nr_block_size = min(nc - nr_block_start, nr);
      // Remember this tile's leading slots: they receive the running
      // kscaledsum updates further down.
      int32_t* packed_b = (int32_t*) packed_weights;
      if XNN_LIKELY(b != NULL) {
        for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
          unaligned_store_s32(packed_weights, b[nr_block_start + nr_block_offset]);
          packed_weights = (int32_t*) packed_weights + 1;
        }
      } else {
        size_t n = nr_block_size;
        do {
          unaligned_store_s32(packed_weights, 0);
          packed_weights = (int32_t*) packed_weights + 1;
        } while (--n != 0);
      }
      // NOTE(review): the slots written above via unaligned_store_s32 are
      // read back as fp32 by the kscaledsum update below; this is exact only
      // if the bias bit pattern is meaningful as a float (the unit tests use
      // zero bias) -- confirm the intended bias representation.
      packed_weights = (int32_t*) packed_weights + (nr - nr_block_size);  // skip unused NR remainder slots

      // Walk K two planes (2 * kr elements) at a time.
      for (size_t kr_block_start = 0; kr_block_start < round_up_po2(kc, skr * 2); kr_block_start += kr * 2) {
        for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
          int32_t ksum = 0;
          for (size_t kr_block_offset = 0; kr_block_offset < kr; kr_block_offset++) {
            // Shuffle K inside an skr window to realize the sr interleaving.
            const size_t kc_idx = round_down_po2(kr_block_start, skr) + ((kr_block_start + kr_block_offset + nr_block_offset * kr) & (skr - 1));
            const size_t k_offset = (nr_block_start + nr_block_offset) * kc + kc_idx;
            const size_t kh_offset = k_offset + kr;  // same lane in the second plane
            // Out-of-range K positions are padded with the zero point (8).
            uint8_t kv_lo = 8;
            if (kc_idx < kc) {
              // Two int4 values per byte: even element in the low nibble.
              kv_lo = ((k_offset & 1) ? (k[k_offset >> 1] >> 4) : (k[k_offset >> 1] & 0xF));
            }
            uint8_t kv_hi = 8;
            if ((kc_idx + kr) < kc) {
              kv_hi = ((kh_offset & 1) ? (k[kh_offset >> 1] >> 4) : (k[kh_offset >> 1] & 0xF));
            }
            ksum += kv_lo + kv_hi - 16; // subtract 2 zero points (8)
            // XOR 0x8 == subtract 8 (mod 16), so ^ 0x88 re-centers both nibbles.
            const uint8_t kv = (kv_lo | (kv_hi << 4)) ^ 0x88;
            ((uint8_t*) packed_weights)[kr_block_offset] = kv;
          }

          size_t block_index = kr_block_start / bl;
          size_t scale_index = (nr_block_start + nr_block_offset) * num_blocks + block_index;
          // Fold -ksum * input_zp * scale into this channel's kscaledsum slot.
          // NOTE(review): the extra *16 factor presumably matches a scale
          // pre-division by 16 on the microkernel side -- verify.
          unaligned_indexed_store_f32(packed_b, nr_block_offset, unaligned_indexed_load_f32(packed_b, nr_block_offset) - (float) ksum * izp * scale[scale_index] * 16);
          packed_weights = (uint8_t*) packed_weights + kr; // kr * 2 nibbles
        }
        if (((2 * kr) + kr_block_start) % bl == 0) {
          // Block boundary: leave room for this block's fp32 scale.
          packed_weights = (void*) ((uintptr_t) packed_weights + extra_bytes_bl);
        }

        packed_weights = (uint8_t*) packed_weights + (nr - nr_block_size) * kr; // skip NR remainder
      }
      packed_weights = (void*) ((uintptr_t) packed_weights + extra_bytes_n);
      nr_block_start += nr;
    } while (nr_block_start < nc);
    k += nc * kc; // kc * 2 nibbles
    if XNN_UNPREDICTABLE(b != NULL) {
      b += nc;
    }
  } while (--g != 0);
}

void xnn_pack_qs8_qc4w_gemm_gio_w(
size_t g,
size_t nc,
Expand Down
14 changes: 14 additions & 0 deletions src/xnnpack/microparams-init.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,20 @@ XNN_INTERNAL void xnn_init_qs8_qc8w_scale_fp32_params(
const float scale[XNN_MIN_ELEMENTS(1)],
void* packed_w);

// Writes per-(channel, block) fp32 scales into a packed weight buffer
// (blockwise "qb" quantization variant of xnn_init_qs8_qc8w_scale_fp32_params).
XNN_INTERNAL void xnn_init_qs8_qb8w_scale_fp32_params(
  size_t channels,          // number of output channels
  size_t channels_tile,     // channels per full tile (typically NR)
  size_t channels_subtile,  // channels per remainder subtile
  size_t stride,            // byte stride between consecutive full tiles
  size_t substride,         // byte stride between consecutive subtiles
  size_t num_blocks,        // number of scale blocks per channel
  size_t block_stride,      // byte stride between consecutive blocks' scale regions
  // How much offset to subtract from packed_w pointer when moving from channels_tile to channels_subtile.
  size_t stride_offset,
  const float scale[XNN_MIN_ELEMENTS(1)],  // scales, indexed [channel][block]
  void* packed_w);          // destination: first scale slot of the packed buffer


XNN_INTERNAL void xnn_init_qs8_to_qs8_qc8w_scale_fp32_params(
size_t channels,
size_t channels_tile,
Expand Down
35 changes: 35 additions & 0 deletions src/xnnpack/pack.h
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,41 @@ XNN_INTERNAL void xnn_pack_qs8_qc4w_gemm_goi_w(
size_t extra_bytes,
const struct xnn_qs8_qc4w_packing_params* params);

/*
 * Packing function for weights with int4 elements, per channel blockwise
 * quantized: each output channel's K dimension is split into blocks of
 * `block_size` elements that share one quantization scale.
 */
typedef void (*xnn_pack_qs8_qb4w_gemm_fn)(
  size_t groups,
  size_t nc,                // output channels
  size_t kc,                // input channels
  size_t nr,
  size_t kr,
  size_t sr,
  size_t block_size, // number of K elements in a block
  const uint8_t* kernel,    // int4 weights, two per byte
  const int32_t* bias,      // optional; may be NULL
  const float* scale,       // per-(channel, block) scales
  void* packed_weights,
  size_t extra_bytes_per_block,  // bytes reserved after each block (for its scale)
  size_t extra_bytes_per_n,      // bytes reserved after each N tile
  const struct xnn_qs8_qc4w_packing_params* params);

// Packs int4 weights with blockwise quantization from GOI order
// (groups x output-channels x input-channels); see xnn_pack_qs8_qb4w_gemm_fn.
XNN_INTERNAL void xnn_pack_qs8_qb4w_gemm_goi_w(
  size_t g,                 // number of groups
  size_t nc,                // output channels
  size_t kc,                // input channels
  size_t nr,
  size_t kr,
  size_t sr,
  size_t bl,                // block size: K elements sharing one scale
  const uint8_t* kernel,    // int4 weights, two per byte
  const int32_t* bias,      // optional; may be NULL
  const float* scale,       // per-(channel, block) scales
  void* packed_weights,
  size_t extra_bytes_bl,    // bytes reserved after each block (for its scale)
  size_t extra_bytes_n,     // bytes reserved after each N tile
  const struct xnn_qs8_qc4w_packing_params* params);

typedef void (*xnn_pack_f32_qc4w_gemm_fn)(
size_t g,
size_t nc,
Expand Down
138 changes: 138 additions & 0 deletions test/packing.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

#include <xnnpack/math.h>
#include <xnnpack/microkernel-utils.h>
#include <xnnpack/microparams-init.h>
#include <xnnpack/pack.h>

#include <algorithm>
Expand Down Expand Up @@ -147,6 +148,143 @@ TEST(PACK_QD8_F32_QC4W_GEMM_GIO_W, kr_eq_4_nr_eq_2) {
EXPECT_EQ(expected, packed_weights);
}

// bl == kc: one scale block spans the whole K dimension, so exactly one
// block scale region is packed per output channel. Packs 1 column of 16
// int4 weights, fills the scale slots, and checks the packed bytes.
TEST(PACK_QD8_F32_QB4W_GEMM_GOI_W, bl_eq_kc) {
  size_t g = 1;
  size_t nc = 1;
  size_t kc = 16;
  size_t nr = 1;
  size_t kr = 4;
  size_t sr = 1;
  size_t bl = kc;
  size_t k_num_blocks = round_up_po2(kc, kr) / bl;  // == 1

  std::vector<int32_t> b(g * nc);
  std::iota(b.begin(), b.end(), 0);  // bias[0] == 0
  // 16 int4 values, two per byte (low nibble first): 8,9,...,15,0,1,...,7.
  std::vector<uint8_t> k(g * nc * kc / 2);
  k[0] = 0x98; k[1] = 0xBA; k[2] = 0xDC; k[3] = 0xFE; k[4] = 0x10; k[5] = 0x32; k[6] = 0x54; k[7] = 0x76;
  size_t extra_bytes_bl = sizeof(float);  // room for one fp32 scale per block
  size_t extra_bytes_n = sizeof(float);   // room for one fp32 per N tile
  std::vector<uint8_t> packed_weights(g * round_up(nc, nr) * (sizeof(float) + round_up_po2(kc, kr * sr) / 2)
    + k_num_blocks * round_up(nc, nr) * extra_bytes_bl + round_up(nc, nr) * extra_bytes_n);
  std::vector<float> scale(nc * k_num_blocks, 0);
  // 853.6010f has bit pattern 0x44556677, so its little-endian bytes are easy
  // to spot in the expected output below.
  std::fill(scale.begin(), scale.end(), 853.6010);
  // input_zero_point = -1, kernel_zero_point = 8 (required by the packer).
  auto a = xnn_qs8_qc4w_packing_params{ -1, 0x8 };

  xnn_pack_qs8_qb4w_gemm_goi_w(g, nc, kc, nr, kr, sr, bl,
      k.data(), b.data(), /*scale=*/scale.data(), packed_weights.data(), extra_bytes_bl, extra_bytes_n, /*params=*/&a);

  size_t k_stride = round_up_po2(kc, kr * sr * 2 /* planes */);

  // If filter is 4-bit, half k_stride (since we will scale k_stride by log2_filter_element_size, and we pass 0 for qc4).
  k_stride = round_up_po2(k_stride, 2) >> 1;

  size_t k_bytes = sizeof(int8_t) * k_stride * nr;
  size_t bias_bytes = sizeof(float) * nr;
  size_t ksum_bytes = sizeof(float) * nr;
  size_t block_bytes = sizeof(float) * k_num_blocks * nr;

  // Scales live after the kscaledsum slots and the first block's weights.
  size_t start_offset = ksum_bytes + k_bytes / k_num_blocks;
  size_t stride = ksum_bytes + k_bytes + block_bytes + bias_bytes;
  size_t block_stride = (bl * nr) / 2 + (sizeof(float) * nr);

  xnn_init_qs8_qb8w_scale_fp32_params(
    /*channels=*/ nc,
    /*channels_tile=*/ nr,
    /*channel_stride=*/ nr,
    /*stride=*/ stride,
    /*substride=*/ stride,
    /*num_blocks=*/k_num_blocks,
    /*block_stride=*/block_stride,
    /*stride_offset=*/ 0,
    /*scale=*/ scale.data(),
    /*packed_weight=*/ packed_weights.data() + start_offset);

  const std::vector<uint8_t> expected = {
    // kscaledsum
    0x78, 0x66, 0xd5, 0xc7, // -1 * 853.6010 * (sum(-8..+7) = -109260.9297 = 0xc7d56678

    // weights.
    0x40, 0x51, 0x62, 0x73, // kr0 | kr1
    0xC8, 0xD9, 0xEA, 0xFB, // kr2 | kr3
    // extra bytes bl
    0x77, 0x66, 0x55, 0x44,
    // extra bytes n
    0, 0, 0, 0
  };
  EXPECT_EQ(expected, packed_weights);
}

// bl < kc: K is split into two scale blocks (kc=16, bl=8), so the packed
// stream interleaves each block's weights with its reserved scale slot.
TEST(PACK_QD8_F32_QB4W_GEMM_GOI_W, bl_lt_kc) {
  size_t g = 1;
  size_t nc = 1;
  size_t kc = 16;
  size_t nr = 1;
  size_t kr = 4;
  size_t sr = 1;
  size_t bl = 8;
  size_t k_num_blocks = kc / bl;  // == 2

  std::vector<int32_t> b(g * nc);
  std::iota(b.begin(), b.end(), 0);  // bias[0] == 0
  // 16 int4 values, two per byte (low nibble first): 8,9,...,15,0,1,...,7.
  std::vector<uint8_t> k(g * nc * kc / 2);
  k[0] = 0x98; k[1] = 0xBA; k[2] = 0xDC; k[3] = 0xFE; k[4] = 0x10; k[5] = 0x32; k[6] = 0x54; k[7] = 0x76;
  size_t extra_bytes_n = sizeof(float);   // room for one fp32 per N tile
  size_t extra_bytes_bl = sizeof(float);  // room for one fp32 scale per block
  std::vector<uint8_t> packed_weights(g * round_up(nc, nr) * (sizeof(float) + round_up_po2(kc, kr * sr) / 2)
    + k_num_blocks * round_up(nc, nr) * extra_bytes_bl + round_up(nc, nr) * extra_bytes_n);
  std::vector<float> scale(nc * k_num_blocks, 0);
  // 853.6010f has bit pattern 0x44556677; both blocks use the same scale.
  std::fill(scale.begin(), scale.end(), 853.6010);


  // input_zero_point = -1, kernel_zero_point = 8 (required by the packer).
  auto a = xnn_qs8_qc4w_packing_params{ -1, 0x8 };
  xnn_pack_qs8_qb4w_gemm_goi_w(g, nc, kc, nr, kr, sr, bl,
      k.data(), b.data(), /*scale=*/scale.data(), packed_weights.data(), extra_bytes_bl, extra_bytes_n, /*params=*/&a);

  size_t k_stride = round_up_po2(kc, kr * sr * 2 /* planes */);

  // 4-bit weights: two per byte, so halve the element stride.
  k_stride = round_up_po2(k_stride, 2) >> 1;

  size_t k_bytes = sizeof(int8_t) * k_stride * nr;
  size_t bias_bytes = sizeof(float) * nr;
  size_t ksum_bytes = sizeof(float) * nr;
  size_t block_bytes = sizeof(float) * k_num_blocks * nr;

  // First scale slot follows the kscaledsum slots and one block of weights.
  size_t start_offset = ksum_bytes + k_bytes / k_num_blocks;
  size_t stride = ksum_bytes + k_bytes + block_bytes + bias_bytes;
  size_t block_stride = (bl * nr) / 2 + (sizeof(float) * nr);

  xnn_init_qs8_qb8w_scale_fp32_params(
    /*channels=*/ nc,
    /*channels_tile=*/ nr,
    /*channel_stride=*/ nr,
    /*stride=*/ stride,
    /*substride=*/ stride,
    /*num_blocks=*/k_num_blocks,
    /*block_stride=*/block_stride,
    /*stride_offset=*/ 0,
    /*scale=*/ scale.data(),
    /*packed_weight=*/ packed_weights.data() + start_offset);

  const std::vector<uint8_t> expected = {
    // kscaledsum
    0x78, 0x66, 0xd5, 0xc7, // -1 * 853.6010 * (sum(-8..+7) = -109260.9297 = 0xc7d56678

    // weights
    0x40, 0x51, 0x62, 0x73, // kr0 | kr1
    // extra bytes bl
    0x77, 0x66, 0x55, 0x44,

    // weights
    0xC8, 0xD9, 0xEA, 0xFB, // kr2 | kr3
    // extra bytes bl
    0x77, 0x66, 0x55, 0x44,

    // extra bytes n
    0, 0, 0, 0
  };
  EXPECT_EQ(expected, packed_weights);
}

TEST(PACK_F32_GEMM_GIO_W, g_eq_1) {
size_t g = 1;
size_t nc = 2;
Expand Down

0 comments on commit 084e96a

Please sign in to comment.