Skip to content

Commit

Permalink
gf 127 attempt
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiyong1997 committed Sep 11, 2024
1 parent 9c00ff9 commit 47bb22d
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 70 deletions.
4 changes: 2 additions & 2 deletions arith/src/extension_field.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
mod fr_ext;
// mod gf2_127;
mod gf2_127;
mod gf2_128;
mod gf2_128x8;
mod m31_ext;
mod m31_ext3x16;
use crate::{Field, FieldSerde};

// pub use gf2_127::*;
pub use gf2_127::*;
pub use gf2_128::*;
pub use gf2_128x8::GF2_128x8;
#[cfg(target_arch = "x86_64")]
Expand Down
118 changes: 50 additions & 68 deletions arith/src/extension_field/gf2_127/avx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ impl Field for AVX512GF2_127 {
#[inline(always)]
fn one() -> Self {
AVX512GF2_127 {
v: unsafe { std::mem::transmute::<[i32; 4], __m128i>([1, 0, 0, 0]) },
v: unsafe { std::mem::transmute::<[i32; 4], __m128i>([1, 0, 0, 0]) },
}
}

Expand Down Expand Up @@ -181,35 +181,34 @@ impl ExtensionField for AVX512GF2_127 {
res
}

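/// Multiplies `self` by `x`: shifts the 128-bit value left by one bit, clears bit 127,
/// and XORs in `0b11` (i.e. `x + 1`) when the shifted value had bit 127 set.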
#[inline]
fn mul_by_x(&self) -> Self {
unsafe {
// Shift left by 1 bit
let shifted = _mm_slli_epi64(self.v, 1);

// Get the most significant bit and move it
let msb = _mm_srli_epi64(self.v, 63);
let msb_moved = _mm_slli_si128(msb, 8);
let msb = _mm_srli_epi64(self.v, 63);
let msb_moved = _mm_slli_si128(msb, 8);

// Combine the shifted value with the moved msb
let shifted_consolidated = _mm_or_si128(shifted, msb_moved);
let x0to126 = _mm_and_si128(shifted_consolidated, X0TO126_MASK);

// Create the reduction value (0b11) and the comparison value (1)
let reduction = {
let multiplier = _mm_set_epi64x(0, 0b11);
let one = _mm_set_epi64x(0, 1);

// Check if the MSB was 1 and create a mask
let mask = _mm_cmpeq_epi64(
_mm_srli_si128(_mm_srli_epi64(shifted, 63), 8),
one);
let mask = _mm_cmpeq_epi64(_mm_srli_si128(_mm_srli_epi64(shifted, 63), 8), one);

_mm_and_si128(mask, multiplier)
};

// Apply the reduction conditionally
let res = _mm_xor_si128(shifted_consolidated, reduction);
let res = _mm_xor_si128(x0to126, reduction);

Self { v: res }
}
Expand All @@ -225,28 +224,28 @@ impl From<GF2> for AVX512GF2_127 {
}
}

const X0TO126_MASK: __m128i = unsafe { transmute::<[u8; 16], __m128i>(
[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F])};
const X127_MASK: __m128i = unsafe { transmute::<[u8; 16], __m128i>(
[0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80])};
const X127_REMINDER: __m128i = unsafe { transmute::<[u8; 16], __m128i>(
[0b11, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80])};


/// Shifts the full 128-bit value held in `x` left by `COUNT` bits, filling with zeros.
///
/// SSE2 only offers per-64-bit-lane bit shifts, so bits leaving the low lane have to be
/// carried into the high lane by hand. The shift amounts are passed through the
/// vector-count shifts (`_mm_sll_epi64` / `_mm_srl_epi64`) rather than the immediate
/// forms, because an immediate must be an `i32` constant and cannot be computed from a
/// generic parameter (`64 - COUNT`).
///
/// Any `COUNT` is accepted: `0` returns `x` unchanged and `COUNT >= 128` returns zero.
///
/// # Safety
///
/// The CPU must support SSE2.
#[inline(always)]
unsafe fn mm_bitshift_left<const COUNT: usize>(x: __m128i) -> __m128i {
    use std::arch::x86_64::{_mm_set_epi64x, _mm_sll_epi64, _mm_srl_epi64};

    // Low 64-bit lane moved into the high lane; the low lane becomes zero.
    let low_in_high = _mm_bslli_si128(x, 8);

    if COUNT >= 64 {
        // Only bits from the low lane can survive. A per-lane count above 63 yields
        // zero, so COUNT >= 128 naturally produces an all-zero result.
        return _mm_sll_epi64(low_in_high, _mm_set_epi64x(0, (COUNT - 64) as i64));
    }

    // Bits shifted out of the low lane, positioned at the bottom of the high lane.
    // For COUNT == 0 the per-lane count is 64, which yields zero (no carry).
    let carry = _mm_srl_epi64(low_in_high, _mm_set_epi64x(0, (64 - COUNT) as i64));
    let shifted = _mm_sll_epi64(x, _mm_set_epi64x(0, COUNT as i64));
    _mm_or_si128(shifted, carry)
}

// WARNING: the masks below are laid out as little-endian byte arrays, so they assume
// little-endian storage of the 128-bit vector.

/// Bits 0 through 126 set, bit 127 clear.
const X0TO126_MASK: __m128i =
    unsafe { transmute::<[u8; 16], __m128i>(((1u128 << 127) - 1).to_le_bytes()) };

/// Only bit 127 set.
const X127_MASK: __m128i = unsafe { transmute::<[u8; 16], __m128i>((1u128 << 127).to_le_bytes()) };

/// Only the two lowest bits set (`0b11`).
const X127_REMINDER: __m128i = unsafe { transmute::<[u8; 16], __m128i>(0b11u128.to_le_bytes()) };

#[inline]
unsafe fn gfmul(a: __m128i, b: __m128i) -> __m128i {
let xmm_mask = _mm_setr_epi32((0xFFffffff_u32) as i32, 0x0, 0x0, 0x0);

// a = a0|a1, b = b0|b1

let mut tmp3 = _mm_clmulepi64_si128(a, b, 0x00); // tmp3 = a0 * b0
Expand All @@ -266,51 +265,34 @@ unsafe fn gfmul(a: __m128i, b: __m128i) -> __m128i {
tmp5 = _mm_slli_si128(tmp4, 8); // tmp5 = e0 | 00
tmp4 = _mm_srli_si128(tmp4, 8); // tmp4 = 00 | e1
tmp3 = _mm_xor_si128(tmp3, tmp5); // the lower 128 bits, deg 0 - 127
tmp6 = _mm_xor_si128(tmp6, tmp4); // the higher 128 bits, deg 128 - 252, the 124 least significant bits are non-zero
tmp6 = _mm_xor_si128(tmp6, tmp4); // the higher 128 bits, deg 128 - 252(255), only the 124 least significant bits are non-zero

// x^0 - x^126
let x0to126 = _mm_and_si128(tmp3, X0TO126_MASK);

// x^127
tmp4 = _mm_and_si128(tmp3, X127_MASK);
tmp4 = _mm_cmpeq_epi64(tmp4, X127_MASK);
tmp4 = _mm_srli_si128(tmp4, 15);
let x127 = _mm_and_si128(tmp4, X127_REMINDER);

// x^128 - x^252
let x128to252 =
_mm_and_si128(
mm_bitshift_left::<2>(tmp6),
mm_bitshift_left::<1>(tmp6),
);

_mm_and_si128(_mm_and_si128(x0to126, x127), x128to252)

// let mut tmp7 = _mm_srli_epi32(tmp6, 31);
// let mut tmp8 = _mm_srli_epi32(tmp6, 30);
// let tmp9 = _mm_srli_epi32(tmp6, 25);

// tmp7 = _mm_xor_si128(tmp7, tmp8);
// tmp7 = _mm_xor_si128(tmp7, tmp9);

// tmp8 = _mm_shuffle_epi32(tmp7, 147);
// tmp7 = _mm_and_si128(xmm_mask, tmp8);
// tmp8 = _mm_andnot_si128(xmm_mask, tmp8);

// tmp3 = _mm_xor_si128(tmp3, tmp8);
// tmp6 = _mm_xor_si128(tmp6, tmp7);

// let tmp10 = _mm_slli_epi32(tmp6, 1);
// tmp3 = _mm_xor_si128(tmp3, tmp10);

// let tmp11 = _mm_slli_epi32(tmp6, 2);
// tmp3 = _mm_xor_si128(tmp3, tmp11);

// let tmp12 = _mm_slli_epi32(tmp6, 7);
// tmp3 = _mm_xor_si128(tmp3, tmp12);

// _mm_xor_si128(tmp3, tmp6)

// x^127
tmp4 = _mm_and_si128(tmp3, X127_MASK);
tmp4 = _mm_cmpeq_epi8(tmp4, X127_MASK);
tmp4 = _mm_srli_si128(tmp4, 15);
let x127 = _mm_and_si128(tmp4, X127_REMINDER);

// x^128 - x^252
// shift left tmp6 by 1 bit
tmp3 = _mm_slli_si128(tmp6, 8);
tmp3 = _mm_srli_epi64(tmp3, 64 - 1);
tmp4 = _mm_slli_epi64(tmp6, 1);
tmp3 = _mm_or_si128(tmp3, tmp4);

// shift left tmp6 by 2 bits
tmp4 = _mm_slli_si128(tmp6, 8);
tmp4 = _mm_srli_epi64(tmp4, 64 - 2);
tmp5 = _mm_slli_epi64(tmp6, 2);
tmp4 = _mm_or_si128(tmp4, tmp5);

let x128to252 = _mm_xor_si128(tmp3, tmp4);

tmp3 = _mm_xor_si128(x0to126, x127);
_mm_xor_si128(tmp3, x128to252)
}

impl Default for AVX512GF2_127 {
Expand Down
1 change: 1 addition & 0 deletions arith/src/extension_field/gf2_127/neon.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

1 change: 1 addition & 0 deletions arith/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ mod bn254;
mod extension_field;
mod field;
mod gf2;
mod gf2_127;
mod gf2_128;
mod gf2_128x8;
mod m31;
Expand Down
31 changes: 31 additions & 0 deletions arith/src/tests/gf2_127.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
use ark_std::test_rng;
use std::io::Cursor;

use crate::{FieldSerde, GF2_127};

use super::{
extension_field::random_extension_field_tests,
field::{random_field_tests, random_inversion_tests},
simd_field::random_simd_field_tests,
};

/// Runs the shared randomized field, extension-field and inversion test suites
/// against `GF2_127`.
#[test]
fn test_field() {
    // Every suite takes an owned label, so build a fresh `String` for each call.
    let label = || "GF2_127".to_string();

    random_field_tests::<GF2_127>(label());
    random_extension_field_tests::<GF2_127>(label());

    let mut rng = test_rng();
    random_inversion_tests::<GF2_127, _>(&mut rng, label());
}

/// Serializes `GF2_127::from(0)` into a byte buffer, reads it back, and checks that
/// the decoded element equals the original.
#[test]
fn test_custom_serde_vectorize_gf2() {
    let original = GF2_127::from(0);

    // Write the element into an in-memory buffer.
    let mut bytes = Vec::new();
    assert!(original.serialize_into(&mut bytes).is_ok());

    // Read it back from the start of that buffer.
    let mut reader = Cursor::new(bytes);
    let decoded = GF2_127::deserialize_from(&mut reader);
    assert!(decoded.is_ok());

    assert_eq!(original, decoded.unwrap());
}

0 comments on commit 47bb22d

Please sign in to comment.