-
Notifications
You must be signed in to change notification settings - Fork 35
Expand file tree
/
Copy pathmain.rs
More file actions
304 lines (260 loc) · 8.37 KB
/
main.rs
File metadata and controls
304 lines (260 loc) · 8.37 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
//! A simple chi squared computation that demonstrates how to
//! optimize [`fhe_program`]s. This example shows the parts of chi
//! squared computed homomorphically. The problem can be summarized
//! as given integers `n_0`, `n_1`, `n_2`, compute:
//! * `alpha` = `(4 * n_0 * n_2 - n_1^2)^2`
//! * `b_1` = `2(2n_0 + n_1)^2`
//! * `b_2` = `(2n_0 + n_1) * (2n_2 + n_1)`
//! * `b_3` = `2(2n_2 + n_1)^2`
//!
//! For more details on this algorithm and to compare
//! Sunscreen's results with other FHE compilers, see
//! [SoK: Fully Homomorphic Encryption Compilers](https://arxiv.org/abs/2101.07078).
use sunscreen::{
fhe_program,
types::{
bfv::{Batched, Signed},
Cipher, FheType, TypeName,
},
Compiler, Error, FheProgramFn, FheProgramInput, FheRuntime, PlainModulusConstraint,
};
use std::marker::PhantomData;
use std::ops::*;
use std::time::Instant;
/**
 * The naive implementation of chi squared: more or less a direct
 * transliteration of the problem statement.
 *
 * Written generically over `T` so the one body serves both the
 * Signed and Batched data types (and plain `i64` for the native
 * baseline).
 */
fn chi_sq_impl<T>(n_0: T, n_1: T, n_2: T) -> (T, T, T, T)
where
    T: Add<T, Output = T> + Mul<T, Output = T> + Sub<T, Output = T> + Copy,
    i64: Mul<T, Output = T>,
{
    // alpha = (4 * n_0 * n_2 - n_1^2)^2
    let alpha = 4 * n_0 * n_2 - n_1 * n_1;
    let alpha_sq = alpha * alpha;

    // beta_1 = 2 * (2 * n_0 + n_1)^2
    let beta_1 = 2 * n_0 + n_1;
    let beta_1_sq = 2 * beta_1 * beta_1;

    // beta_2 = (2 * n_0 + n_1) * (2 * n_2 + n_1)
    let beta_2 = (2 * n_0 + n_1) * (2 * n_2 + n_1);

    // beta_3 = 2 * (2 * n_2 + n_1)^2
    let beta_3 = 2 * (2 * n_2 + n_1) * (2 * n_2 + n_1);

    (alpha_sq, beta_1_sq, beta_2, beta_3)
}
/**
 * This implementation features the following optimizations:
 * * Replace multiplication by constant with additions. This also
 *   means the `i64: Mul<T, Output = T>` bound the naive version
 *   requires is no longer needed, so it has been dropped.
 * * Common subexpression elimination. I.e. reuse temporaries
 *   (`x`, `y`, `n_0_n_2`) multiple times to avoid recomputation.
 *
 * On a first gen M1 Mac, this implementation is over 6x
 * faster than the naive implementation.
 */
fn chi_sq_optimized_impl<T>(n_0: T, n_1: T, n_2: T) -> (T, T, T, T)
where
    T: Add<T, Output = T> + Mul<T, Output = T> + Sub<T, Output = T> + Copy,
{
    // x = 2 * n_0 + n_1 and y = 2 * n_2 + n_1, built with additions only.
    let x = n_0 + n_0 + n_1;
    let y = n_2 + n_2 + n_1;

    // alpha = (4 * n_0 * n_2 - n_1^2)^2; the factor of 4 comes from
    // doubling n_0 * n_2 twice.
    let n_0_n_2 = n_0 * n_2;
    let n_0_n_2 = n_0_n_2 + n_0_n_2;
    let n_0_n_2 = n_0_n_2 + n_0_n_2;
    let n_1_sq = n_1 * n_1;
    let alpha = n_0_n_2 - n_1_sq;
    let alpha = alpha * alpha;

    // b_1 = 2 * x^2
    let b_1 = x * x;
    let b_1 = b_1 + b_1;

    // b_2 = x * y
    let b_2 = x * y;

    // b_3 = 2 * y^2
    let b_3 = y * y;
    let b_3 = b_3 + b_3;

    (alpha, b_1, b_2, b_3)
}
/// The naive chi squared computation as a BFV FHE program over
/// encrypted [`Signed`] inputs.
///
/// Returns ciphertexts for `(alpha, beta_1, beta_2, beta_3)` as
/// computed by [`chi_sq_impl`].
#[fhe_program(scheme = "bfv")]
fn chi_sq_fhe_program(
    n_0: Cipher<Signed>,
    n_1: Cipher<Signed>,
    n_2: Cipher<Signed>,
) -> (
    Cipher<Signed>,
    Cipher<Signed>,
    Cipher<Signed>,
    Cipher<Signed>,
) {
    chi_sq_impl(n_0, n_1, n_2)
}
/// The optimized chi squared computation as a BFV FHE program over
/// encrypted [`Signed`] inputs.
///
/// Returns ciphertexts for `(alpha, beta_1, beta_2, beta_3)` as
/// computed by [`chi_sq_optimized_impl`].
#[fhe_program(scheme = "bfv")]
fn chi_sq_optimized_fhe_program(
    n_0: Cipher<Signed>,
    n_1: Cipher<Signed>,
    n_2: Cipher<Signed>,
) -> (
    Cipher<Signed>,
    Cipher<Signed>,
    Cipher<Signed>,
    Cipher<Signed>,
) {
    chi_sq_optimized_impl(n_0, n_1, n_2)
}
/// The optimized chi squared computation as a BFV FHE program over
/// encrypted [`Batched`] inputs, computing one chi squared instance
/// per packed lane.
///
/// NOTE(review): `Batched<4096>` appears to correspond to the
/// `[[_; 4096]; 2]` arrays built in `main` (2 rows of 4096 lanes) —
/// confirm against the sunscreen `Batched` docs.
#[fhe_program(scheme = "bfv")]
fn chi_sq_batched_fhe_program(
    n_0: Cipher<Batched<4096>>,
    n_1: Cipher<Batched<4096>>,
    n_2: Cipher<Batched<4096>>,
) -> (
    Cipher<Batched<4096>>,
    Cipher<Batched<4096>>,
    Cipher<Batched<4096>>,
    Cipher<Batched<4096>>,
) {
    chi_sq_optimized_impl(n_0, n_1, n_2)
}
/**
 * Computes chi squared without encryption, printing the results
 * and the wall-clock time taken. This function may report taking
 * 0 seconds when the computation is faster than the clock
 * resolution; a typical time on a first gen M1 mac is under 40ns.
 */
fn run_native<F>(f: F, n_0: i64, n_1: i64, n_2: i64)
where
    F: Fn(i64, i64, i64) -> (i64, i64, i64, i64),
{
    // Time just the computation itself, not the printing.
    let timer = Instant::now();
    let results = f(n_0, n_1, n_2);
    let elapsed = timer.elapsed().as_secs_f64();

    let (a, b_1, b_2, b_3) = results;

    println!(
        "\t\tchi_sq (non-fhe) alpha {a}, beta_1 {b_1}, beta_2 {b_2}, beta_3 {b_3}, ({elapsed}s)"
    );
}
/**
 * Compiles the given fhe_program, encrypts some data,
 * homomorphically runs the fhe_program, decrypts the result, and
 * reports timings on each step.
 *
 * The [`PhantomData`] argument allows us to tell Rust what the
 * type of U is. This is preferable to passing an explicit type
 * for F using turbofish, since the concrete type of F is an
 * implementation detail of the `#[fhe_program]` macro and could
 * change in the future.
 *
 * `T` is the plaintext value's Rust representation; `U` is the
 * FHE plaintext type it is converted into before encryption.
 */
fn run_fhe<F, T, U>(
    c: F,
    _u: PhantomData<U>,
    n_0: T,
    n_1: T,
    n_2: T,
    plain_modulus: PlainModulusConstraint,
) -> Result<(), Error>
where
    F: FheProgramFn + Clone + 'static + AsRef<str>,
    U: From<T> + FheType + TypeName + std::fmt::Display,
{
    // Compile the program subject to the given plaintext modulus
    // constraint; scheme parameters are chosen by the compiler.
    let start = Instant::now();
    let app = Compiler::new()
        .fhe_program(c.clone())
        .plain_modulus_constraint(plain_modulus)
        .compile()?;
    let elapsed = start.elapsed().as_secs_f64();
    println!("\t\tCompile time {elapsed}s");

    // The runtime must be built from the params the compiler chose.
    let runtime = FheRuntime::new(app.params())?;

    // Convert the raw inputs into the FHE plaintext type U.
    let n_0 = U::from(n_0);
    let n_1 = U::from(n_1);
    let n_2 = U::from(n_2);

    let start = Instant::now();
    let (public_key, private_key) = runtime.generate_keys()?;
    let elapsed = start.elapsed().as_secs_f64();
    println!("\t\tKeygen time {elapsed}s");

    let start = Instant::now();
    let n_0_enc = runtime.encrypt(n_0, &public_key)?;
    let n_1_enc = runtime.encrypt(n_1, &public_key)?;
    let n_2_enc = runtime.encrypt(n_2, &public_key)?;
    let elapsed = start.elapsed().as_secs_f64();
    println!("\t\tEncryption time {elapsed}s");

    // Run the program homomorphically on the three ciphertexts.
    let start = Instant::now();
    let args: Vec<FheProgramInput> = vec![n_0_enc.into(), n_1_enc.into(), n_2_enc.into()];
    let result = runtime.run(app.get_fhe_program(c).unwrap(), args, &public_key)?;
    let elapsed = start.elapsed().as_secs_f64();
    println!("\t\tRun time {elapsed}s");

    // The program returns a 4-tuple, so `result` holds 4
    // ciphertexts in declaration order: alpha, b_1, b_2, b_3.
    let start = Instant::now();
    let a: U = runtime.decrypt(&result[0], &private_key)?;
    let b_1: U = runtime.decrypt(&result[1], &private_key)?;
    let b_2: U = runtime.decrypt(&result[2], &private_key)?;
    let b_3: U = runtime.decrypt(&result[3], &private_key)?;
    let elapsed = start.elapsed().as_secs_f64();
    println!("\t\tDecryption time {elapsed}s");

    println!("\t\tchi_sq (fhe) alpha {a:40}, beta_1 {b_1:40}, beta_2 {b_2:40}, beta_3 {b_3:40}");

    Ok(())
}
fn main() -> Result<(), Error> {
let n_0 = 2;
let n_1 = 7;
let n_2 = 9;
env_logger::init();
// Signed types allow us to use a really small modulus,
// allowing us to get very performant parameters.
let plain_modulus = PlainModulusConstraint::Raw(64);
println!("**********Naive**************");
println!("\t**********Native************");
run_native(chi_sq_impl, n_0, n_1, n_2);
println!("\t**********FHE************");
run_fhe(
chi_sq_fhe_program,
PhantomData::<Signed>,
n_0,
n_1,
n_2,
plain_modulus,
)?;
println!("**********Optimized************");
println!("\t**********Native************");
// run_native(chi_sq_optimized_impl, n_0, n_1, n_2);
println!("\t**********FHE************");
// On a first-gen M1 mac, the optimized fhe_program is around 6
// orders of magnitude slower than running natively, taking
// just under 50ms...
run_fhe(
chi_sq_optimized_fhe_program,
PhantomData::<Signed>,
n_0,
n_1,
n_2,
plain_modulus,
)?;
// Pack repetitions of n_0, n_1, n_2 into 2x4096 vectors
// to demonstrate batching.
let n_0 = [[n_0; 4096], [n_0; 4096]];
let n_1 = [[n_1; 4096], [n_1; 4096]];
let n_2 = [[n_2; 4096], [n_2; 4096]];
let plain_modulus = PlainModulusConstraint::BatchingMinimum(16);
// Using batching, we get a fhe_program
// that runs with the same latency, but rather than computing
// 1 instance of the chi squared function, we can compute
// 16_384 values concurrently. This would result in an
// amortized throughput only 1-2 orders of magnitude
// slower than native!
println!("**********Batched************");
println!("\t**********FHE************");
run_fhe(
chi_sq_batched_fhe_program,
PhantomData::<Batched<4096>>,
n_0,
n_1,
n_2,
plain_modulus,
)?;
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the whole example end to end on a thread with an
    /// enlarged stack — presumably the default test-thread stack
    /// is too small for the FHE workload; confirm if this is
    /// still required.
    #[test]
    fn main_works() {
        const STACK_SIZE: usize = 10 * 1024 * 1024; // 10 MB

        let handle = std::thread::Builder::new()
            .stack_size(STACK_SIZE)
            .spawn(main)
            .unwrap();

        // Propagate both a panic and a returned Err from main.
        handle.join().unwrap().unwrap();
    }
}