Skip to content

Commit

Permalink
Add intersect_vector16_inplace. (#445)
Browse files Browse the repository at this point in the history
  • Loading branch information
CharlesChen888 committed Mar 14, 2023
1 parent 2c6708e commit 84fe3c8
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 2 deletions.
3 changes: 3 additions & 0 deletions include/roaring/array_util.h
Expand Up @@ -125,6 +125,9 @@ int32_t intersect_vector16(const uint16_t *__restrict__ A, size_t s_a,
const uint16_t *__restrict__ B, size_t s_b,
uint16_t *C);

int32_t intersect_vector16_inplace(uint16_t *__restrict__ A, size_t s_a,
const uint16_t *__restrict__ B, size_t s_b);

/**
* Compute the cardinality of the intersection using SSE4 instructions
*/
Expand Down
92 changes: 92 additions & 0 deletions src/array_util.c
Expand Up @@ -444,6 +444,98 @@ int32_t intersect_vector16(const uint16_t *__restrict__ A, size_t s_a,
}
return (int32_t)count;
}

int32_t intersect_vector16_inplace(uint16_t *__restrict__ A, size_t s_a,
const uint16_t *__restrict__ B, size_t s_b) {
size_t count = 0;
size_t i_a = 0, i_b = 0;
const int vectorlength = sizeof(__m128i) / sizeof(uint16_t);
const size_t st_a = (s_a / vectorlength) * vectorlength;
const size_t st_b = (s_b / vectorlength) * vectorlength;
__m128i v_a, v_b;
if ((i_a < st_a) && (i_b < st_b)) {
v_a = _mm_lddqu_si128((__m128i *)&A[i_a]);
v_b = _mm_lddqu_si128((__m128i *)&B[i_b]);
__m128i tmp[2] = {_mm_setzero_si128()};
size_t tmp_count = 0;
while ((A[i_a] == 0) || (B[i_b] == 0)) {
const __m128i res_v = _mm_cmpestrm(
v_b, vectorlength, v_a, vectorlength,
_SIDD_UWORD_OPS | _SIDD_CMP_EQUAL_ANY | _SIDD_BIT_MASK);
const int r = _mm_extract_epi32(res_v, 0);
__m128i sm16 = _mm_loadu_si128((const __m128i *)shuffle_mask16 + r);
__m128i p = _mm_shuffle_epi8(v_a, sm16);
_mm_storeu_si128((__m128i*)&((uint16_t*)tmp)[tmp_count], p);
tmp_count += _mm_popcnt_u32(r);
const uint16_t a_max = A[i_a + vectorlength - 1];
const uint16_t b_max = B[i_b + vectorlength - 1];
if (a_max <= b_max) {
_mm_storeu_si128((__m128i *)&A[count], tmp[0]);
_mm_storeu_si128(tmp, _mm_setzero_si128());
count += tmp_count;
tmp_count = 0;
i_a += vectorlength;
if (i_a == st_a) break;
v_a = _mm_lddqu_si128((__m128i *)&A[i_a]);
}
if (b_max <= a_max) {
i_b += vectorlength;
if (i_b == st_b) break;
v_b = _mm_lddqu_si128((__m128i *)&B[i_b]);
}
}
if ((i_a < st_a) && (i_b < st_b)) {
while (true) {
const __m128i res_v = _mm_cmpistrm(
v_b, v_a,
_SIDD_UWORD_OPS | _SIDD_CMP_EQUAL_ANY | _SIDD_BIT_MASK);
const int r = _mm_extract_epi32(res_v, 0);
__m128i sm16 = _mm_loadu_si128((const __m128i *)shuffle_mask16 + r);
__m128i p = _mm_shuffle_epi8(v_a, sm16);
_mm_storeu_si128((__m128i*)&((uint16_t*)tmp)[tmp_count], p);
tmp_count += _mm_popcnt_u32(r);
const uint16_t a_max = A[i_a + vectorlength - 1];
const uint16_t b_max = B[i_b + vectorlength - 1];
if (a_max <= b_max) {
_mm_storeu_si128((__m128i *)&A[count], tmp[0]);
_mm_storeu_si128(tmp, _mm_setzero_si128());
count += tmp_count;
tmp_count = 0;
i_a += vectorlength;
if (i_a == st_a) break;
v_a = _mm_lddqu_si128((__m128i *)&A[i_a]);
}
if (b_max <= a_max) {
i_b += vectorlength;
if (i_b == st_b) break;
v_b = _mm_lddqu_si128((__m128i *)&B[i_b]);
}
}
}
// tmp_count <= 8, so this does not affect efficiency so much
for (size_t i = 0; i < tmp_count; i++) {
A[count] = ((uint16_t*)tmp)[i];
count++;
}
i_a += tmp_count; // We can at least jump pass $tmp_count elements in A
}
// intersect the tail using scalar intersection
while (i_a < s_a && i_b < s_b) {
uint16_t a = A[i_a];
uint16_t b = B[i_b];
if (a < b) {
i_a++;
} else if (b < a) {
i_b++;
} else {
A[count] = a; //==b;
count++;
i_a++;
i_b++;
}
}
return (int32_t)count;
}
CROARING_UNTARGET_REGION

CROARING_TARGET_AVX2
Expand Down
13 changes: 11 additions & 2 deletions src/containers/array.c
Expand Up @@ -361,7 +361,6 @@ bool array_container_intersect(const array_container_t *array1,
* */
void array_container_intersection_inplace(array_container_t *src_1,
const array_container_t *src_2) {
// todo: can any of this be vectorized?
int32_t card_1 = src_1->cardinality, card_2 = src_2->cardinality;
const int threshold = 64; // subject to tuning
if (card_1 * threshold < card_2) {
Expand All @@ -371,8 +370,18 @@ void array_container_intersection_inplace(array_container_t *src_1,
src_1->cardinality = intersect_skewed_uint16(
src_2->array, card_2, src_1->array, card_1, src_1->array);
} else {
#ifdef CROARING_IS_X64
if (croaring_avx2()) {
src_1->cardinality = intersect_vector16_inplace(
src_1->array, card_1, src_2->array, card_2);
} else {
src_1->cardinality = intersect_uint16(
src_1->array, card_1, src_2->array, card_2, src_1->array);
}
#else
src_1->cardinality = intersect_uint16(
src_1->array, card_1, src_2->array, card_2, src_1->array);
src_1->array, card_1, src_2->array, card_2, src_1->array);
#endif
}
}

Expand Down

0 comments on commit 84fe3c8

Please sign in to comment.