diff --git a/cmake/scripts/test.sh b/cmake/scripts/test.sh index f437d0281..59188ee28 100755 --- a/cmake/scripts/test.sh +++ b/cmake/scripts/test.sh @@ -163,6 +163,7 @@ run_examples() { python3 builtin_gbdt_train.py python3 builtin_online_decrypt.py python3 builtin_private_join_and_compute.py + python3 builtin_ordered_set_intersect.py python3 builtin_rsa_sign.py popd diff --git a/examples/python/builtin_ordered_set_intersect.py b/examples/python/builtin_ordered_set_intersect.py new file mode 100644 index 000000000..5b0a1dcd2 --- /dev/null +++ b/examples/python/builtin_ordered_set_intersect.py @@ -0,0 +1,175 @@ +#!/usr/bin/env python3 + +import sys + +from teaclave import (AuthenticationService, FrontendService, + AuthenticationClient, FrontendClient, FunctionInput, + FunctionOutput, OwnerList, DataMap) +from utils import (AUTHENTICATION_SERVICE_ADDRESS, FRONTEND_SERVICE_ADDRESS, + AS_ROOT_CA_CERT_PATH, ENCLAVE_INFO_PATH, USER_ID, + USER_PASSWORD) + +# In the example, user 0 creates the task and user 0, 1, upload their private data. +# Then user 0 invokes the task and user 0, 1 get the result. + + +class UserData: + def __init__(self, + user_id, + password, + input_url="", + output_url="", + input_cmac="", + key=[]): + self.user_id = user_id + self.password = password + self.input_url = input_url + self.output_url = output_url + self.input_cmac = input_cmac + self.key = key + + +INPUT_FILE_URL_PREFIX = "http://localhost:6789/fixtures/functions/ordered_set_intersect/" +OUTPUT_FILE_URL_PREFIX = "http://localhost:6789/fixtures/functions/ordered_set_intersect/" + +USER_DATA_0 = UserData("user0", "password", + INPUT_FILE_URL_PREFIX + "psi0.txt.enc", + OUTPUT_FILE_URL_PREFIX + "output_psi0.enc", + "e08adeb021e876ffe82234445e632121", + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) + +USER_DATA_1 = UserData("user1", "password", + INPUT_FILE_URL_PREFIX + "psi1.txt.enc", + OUTPUT_FILE_URL_PREFIX + "output_psi1.enc", + "538dafbf7802d962bb01e2389b4e943a", + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) + + +class DataList: + def __init__(self, data_name, data_id): + self.data_name = data_name + self.data_id = data_id + + +class Client: + def __init__(self, user_id, user_password): + self.user_id = user_id + self.user_password = user_password + self.client = AuthenticationService( + AUTHENTICATION_SERVICE_ADDRESS, AS_ROOT_CA_CERT_PATH, + ENCLAVE_INFO_PATH).connect().get_client() + print(f"[+] {self.user_id} registering user") + self.client.user_register(self.user_id, self.user_password) + print(f"[+] {self.user_id} login") + token = self.client.user_login(self.user_id, self.user_password) + self.client = FrontendService( + FRONTEND_SERVICE_ADDRESS, AS_ROOT_CA_CERT_PATH, + ENCLAVE_INFO_PATH).connect().get_client() + metadata = {"id": self.user_id, "token": token} + self.client.metadata = metadata + + def set_task(self): + client = self.client + + print(f"[+] {self.user_id} registering function") + + function_id = client.register_function( + name="builtin-ordered-set-intersect", + description="Native Private Set Intersection", + executor_type="builtin", + arguments=["order"], + inputs=[ + FunctionInput("input_data1", "Client 0 data."), + FunctionInput("input_data2", "Client 1 data.") + ], + outputs=[ + FunctionOutput("output_result1", "Output data."), + FunctionOutput("output_result2", "Output data.") + ]) + + print(f"[+] {self.user_id} creating task") + task_id = client.create_task( + function_id=function_id, + function_arguments=({ + "order": "ascending", # Order can be ascending or desending + }), + executor="builtin", + inputs_ownership=[ + OwnerList("input_data1", [USER_DATA_0.user_id]), + OwnerList("input_data2", [USER_DATA_1.user_id]) + ], + outputs_ownership=[ + OwnerList("output_result1", [USER_DATA_0.user_id]), + OwnerList("output_result2", [USER_DATA_1.user_id]) + ]) + + return task_id + + def run_task(self, task_id): + client = self.client + print(f"[+] {self.user_id} invoking task") + client.invoke_task(task_id) + + def register_data(self, task_id, input_url, input_cmac, output_url, + file_key, input_label, output_label): + client = self.client + + print(f"[+] {self.user_id} registering input file") + url = input_url + cmac = input_cmac + schema = "teaclave-file-128" + key = file_key + iv = [] + input_id = client.register_input_file(url, schema, key, iv, cmac) + print(f"[+] {self.user_id} registering output file") + url = output_url + schema = "teaclave-file-128" + key = file_key + iv = [] + output_id = client.register_output_file(url, schema, key, iv) + + print(f"[+] {self.user_id} assigning data to task") + client.assign_data_to_task(task_id, [DataList(input_label, input_id)], + [DataList(output_label, output_id)]) + + def approve_task(self, task_id): + client = self.client + print(f"[+] {self.user_id} approving task") + client.approve_task(task_id) + + def get_task_result(self, task_id): + client = self.client + print(f"[+] {self.user_id} getting task result") + return bytes(client.get_task_result(task_id)) + + +def main(): + user0 = Client(USER_DATA_0.user_id, USER_DATA_0.password) + user1 = Client(USER_DATA_1.user_id, USER_DATA_1.password) + + task_id = user0.set_task() + + user0.register_data(task_id, USER_DATA_0.input_url, USER_DATA_0.input_cmac, + USER_DATA_0.output_url, USER_DATA_0.key, "input_data1", + "output_result1") + + user1.register_data(task_id, USER_DATA_1.input_url, USER_DATA_1.input_cmac, + USER_DATA_1.output_url, USER_DATA_1.key, "input_data2", + "output_result2") + + user0.approve_task(task_id) + user1.approve_task(task_id) + + ## USER 0 start the computation + user0.run_task(task_id) + + ## USER 0, 1 get the result + result_user0 = user0.get_task_result(task_id) + result_user1 = user1.get_task_result(task_id) + + print("[+] User 0 result: " + result_user0.decode("utf-8")) + print("[+] User 1 result: " + result_user1.decode("utf-8")) + + +if __name__ == '__main__': + main() diff --git a/executor/Cargo.toml b/executor/Cargo.toml index c102d446e..ea03e45d3 100644 --- a/executor/Cargo.toml +++ b/executor/Cargo.toml @@ -33,6 +33,7 @@ full_builtin_function = [ "builtin_logistic_regression_train", "builtin_online_decrypt", "builtin_private_join_and_compute", + "builtin_ordered_set_intersect", "builtin_rsa_sign", ] @@ -43,6 +44,7 @@ builtin_logistic_regression_predict = [] builtin_logistic_regression_train = [] builtin_online_decrypt = [] builtin_private_join_and_compute = [] +builtin_ordered_set_intersect = [] builtin_rsa_sign = [] [dependencies] diff --git a/executor/src/builtin.rs b/executor/src/builtin.rs index 5874d7442..677e7828c 100644 --- a/executor/src/builtin.rs +++ b/executor/src/builtin.rs @@ -20,7 +20,7 @@ use std::prelude::v1::*; use teaclave_function::{ Echo, GbdtPredict, GbdtTrain, LogisticRegressionPredict, LogisticRegressionTrain, - OnlineDecrypt, PrivateJoinAndCompute, RsaSign, + OnlineDecrypt, OrderedSetIntersect, PrivateJoinAndCompute, RsaSign, }; use teaclave_types::{FunctionArguments, FunctionRuntime, TeaclaveExecutor}; @@ -54,6 +54,8 @@ impl TeaclaveExecutor for BuiltinFunctionExecutor { OnlineDecrypt::NAME => OnlineDecrypt::new().run(arguments, runtime), #[cfg(feature = "builtin_private_join_and_compute")] PrivateJoinAndCompute::NAME => PrivateJoinAndCompute::new().run(arguments, runtime), + #[cfg(feature = "builtin_ordered_set_intersect")] + OrderedSetIntersect::NAME => OrderedSetIntersect::new().run(arguments, runtime), #[cfg(feature = "builtin_rsa_sign")] RsaSign::NAME => RsaSign::new().run(arguments, runtime), _ => bail!("Function not found."), diff --git a/function/Cargo.toml b/function/Cargo.toml index 587c10ba5..638482e3d 100644 --- a/function/Cargo.toml +++ b/function/Cargo.toml @@ -33,6 +33,7 @@ rusty-machine = { version = "0.5.4" } itertools = { version = "0.8.0", default-features = false } ring = { version = "0.16.5" } base64 = { version = "0.10.1" } +hex = { version = "0.4.0" } teaclave_types = { path = "../types" } teaclave_crypto = { path = "../crypto" } teaclave_runtime = { path = "../runtime", optional = true } diff --git a/function/README.md b/function/README.md index 91bed63e3..fbcfc920c 100644 --- a/function/README.md +++ b/function/README.md @@ -20,6 +20,10 @@ Currently, we have these built-in functions: - `builtin-logistic-regression-predict`: LR prediction with input model and input test data. - `builtin-private-join-and-compute`: Find intersection of muti-parties' input data and compute sum of the common items. + - `builtin-ordered-set-intersect`: Allow two parties to compute the + intersection of their ordered sets without revealing anything except for the + elements in the intersection. Users should calculate hash values of each item + and upload them as a sorted list. - `builtin-rsa-sign`: Signing data with RSA key. The function arguments are in JSON format and can be serialized to a Rust struct diff --git a/function/src/lib.rs b/function/src/lib.rs index 2d436409f..1ea110fb0 100644 --- a/function/src/lib.rs +++ b/function/src/lib.rs @@ -28,6 +28,7 @@ mod gbdt_train; mod logistic_regression_predict; mod logistic_regression_train; mod online_decrypt; +mod ordered_set_intersect; mod private_join_and_compute; mod rsa_sign; @@ -37,6 +38,7 @@ pub use gbdt_train::GbdtTrain; pub use logistic_regression_predict::LogisticRegressionPredict; pub use logistic_regression_train::LogisticRegressionTrain; pub use online_decrypt::OnlineDecrypt; +pub use ordered_set_intersect::OrderedSetIntersect; pub use private_join_and_compute::PrivateJoinAndCompute; pub use rsa_sign::RsaSign; @@ -54,6 +56,7 @@ pub mod tests { logistic_regression_predict::tests::run_tests(), online_decrypt::tests::run_tests(), private_join_and_compute::tests::run_tests(), + ordered_set_intersect::tests::run_tests(), rsa_sign::tests::run_tests(), ) } diff --git a/function/src/ordered_set_intersect.rs b/function/src/ordered_set_intersect.rs new file mode 100644 index 000000000..ae1c08637 --- /dev/null +++ b/function/src/ordered_set_intersect.rs @@ -0,0 +1,213 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use anyhow::bail; +use std::cmp; +use std::convert::TryFrom; +use std::format; +use std::io::{self, BufRead, BufReader, Write}; +#[cfg(feature = "mesalock_sgx")] +use std::prelude::v1::*; +use teaclave_types::{FunctionArguments, FunctionRuntime}; + +extern crate hex; +extern crate sgx_tstd as std; + +// Input data should be a list of sorted hash values. + +const IN_DATA1: &str = "input_data1"; +const IN_DATA2: &str = "input_data2"; +const OUT_RESULT1: &str = "output_result1"; +const OUT_RESULT2: &str = "output_result2"; + +#[derive(Default)] +pub struct OrderedSetIntersect; + +#[derive(serde::Deserialize)] +pub struct OrderedSetIntersectArguments { + order: String, +} + +impl TryFrom for OrderedSetIntersectArguments { + type Error = anyhow::Error; + + fn try_from(arguments: FunctionArguments) -> Result { + use anyhow::Context; + serde_json::from_str(&arguments.into_string()).context("Cannot deserialize arguments") + } +} + +impl OrderedSetIntersect { + pub const NAME: &'static str = "builtin-ordered-set-intersect"; + + pub fn new() -> Self { + Default::default() + } + + pub fn run( + &self, + arguments: FunctionArguments, + runtime: FunctionRuntime, + ) -> anyhow::Result { + let input1 = runtime.open_input(IN_DATA1)?; + let input2 = runtime.open_input(IN_DATA2)?; + let mut output1 = runtime.create_output(OUT_RESULT1)?; + let mut output2 = runtime.create_output(OUT_RESULT2)?; + let args = OrderedSetIntersectArguments::try_from(arguments)?; + let order = &args.order[..]; + let ascending_order = match order { + "ascending" => true, + "desending" => false, + _ => bail!("Invalid order"), + }; + + let vec1 = parse_input_data(input1, ascending_order)?; + let vec2 = parse_input_data(input2, ascending_order)?; + let (result1, result2) = intersection_ordered_vec(&vec1, &vec2, ascending_order)?; + + let mut common_sets = 0; + + for item in result1 { + write!(&mut output1, "{}", item as u32)?; + if item { + common_sets += 1; + } + } + for item in result2 { + write!(&mut output2, "{}", item as u32)?; + } + Ok(format!("{} common items", common_sets)) + } +} + +fn parse_input_data(input: impl io::Read, ascending_order: bool) -> anyhow::Result>> { + let mut samples: Vec> = Vec::new(); + let reader = BufReader::new(input); + for (index, byte_result) in reader.lines().enumerate() { + let byte = byte_result?; + let result = hex::decode(byte)?; + if index > 0 { + // If vec has more than 2 elements, then verify the ordering + let last_element = &samples[index - 1]; + if ascending_order && result < *last_element + || !ascending_order && result > *last_element + { + bail!("Invalid ordering"); + } + } + samples.push(result) + } + + Ok(samples) +} + +fn intersection_ordered_vec( + input1: &[Vec], + input2: &[Vec], + ascending_order: bool, +) -> anyhow::Result<(Vec, Vec)> { + let v1_len = input1.len(); + let v2_len = input2.len(); + + let mut res1 = std::vec![false; v1_len]; + let mut res2 = std::vec![false; v2_len]; + + let mut i = 0; + let mut j = 0; + + while i < v1_len && j < v2_len { + let order = &input1[i].cmp(&input2[j]); + match order { + cmp::Ordering::Equal => { + res1[i] = true; + res2[j] = true; + i += 1; + j += 1; + } + cmp::Ordering::Less => { + if ascending_order { + i += 1; + } else { + j += 1; + } + } + cmp::Ordering::Greater => { + if ascending_order { + j += 1; + } else { + i += 1; + } + } + } + } + Ok((res1, res2)) +} + +#[cfg(feature = "enclave_unit_test")] +pub mod tests { + use super::*; + use serde_json::json; + use std::path::Path; + use std::untrusted::fs; + use teaclave_crypto::*; + use teaclave_runtime::*; + use teaclave_test_utils::*; + use teaclave_types::*; + + pub fn run_tests() -> bool { + run_tests!(test_ordered_set_intersect) + } + + fn test_ordered_set_intersect() { + let arguments = FunctionArguments::from_json(json!({ + "order": "ascending" + })) + .unwrap(); + + let base = Path::new("fixtures/functions/ordered_set_intersect"); + + let user1_input = base.join("psi0.txt"); + let user1_output = base.join("output_psi0.txt"); + + let user2_input = base.join("psi1.txt"); + let user2_output = base.join("output_psi1.txt"); + + let input_files = StagedFiles::new(hashmap!( + IN_DATA1 => + StagedFileInfo::new(&user1_input, TeaclaveFile128Key::random(), FileAuthTag::mock()), + IN_DATA2 => + StagedFileInfo::new(&user2_input, TeaclaveFile128Key::random(), FileAuthTag::mock()), + )); + + let output_files = StagedFiles::new(hashmap!( + OUT_RESULT1 => + StagedFileInfo::new(&user1_output, TeaclaveFile128Key::random(), FileAuthTag::mock()), + OUT_RESULT2 => + StagedFileInfo::new(&user2_output, TeaclaveFile128Key::random(), FileAuthTag::mock()), + )); + + let runtime = Box::new(RawIoRuntime::new(input_files, output_files)); + let summary = OrderedSetIntersect::new().run(arguments, runtime).unwrap(); + + let user1_result = fs::read_to_string(&user1_output).unwrap(); + let user2_result = fs::read_to_string(&user2_output).unwrap(); + + assert_eq!(&user1_result[..], "0101010"); + assert_eq!(&user2_result[..], "01101"); + assert_eq!(summary, "3 common items"); + } +} diff --git a/tests/fixtures/functions/ordered_set_intersect/psi0.txt b/tests/fixtures/functions/ordered_set_intersect/psi0.txt new file mode 100644 index 000000000..a2dd24683 --- /dev/null +++ b/tests/fixtures/functions/ordered_set_intersect/psi0.txt @@ -0,0 +1,7 @@ +3129a6f57c01547906c4f851de448d4a85716927d9aae5d13955303833dea3be +3c2ef1901bee3a4866d68e16de37a270e4f16d166132f14da88b5d0bb5c5a369 +699bd76eb9764233eade0f5ca571e86b01b59ef6051e6008f2ab1723b1ba20e8 +6b51d431df5d7f141cbececcf79edf3dd861c3b4069f0b11661a3eefacbba918 +7a90238b179e5d28faa81dcffee49fcd200d591a61f9d0ba9d76eca3cb71a813 +fa3cfb3f1bb823aa9501f88f1f95f732ee6fef2c3a48be7f1d38037b216a549f +ffff5954ee15325a8af0a1251b5e6dc255975484df25c5f9f24542479d8d340e \ No newline at end of file diff --git a/tests/fixtures/functions/ordered_set_intersect/psi0.txt.enc b/tests/fixtures/functions/ordered_set_intersect/psi0.txt.enc new file mode 100644 index 000000000..2b27556b1 Binary files /dev/null and b/tests/fixtures/functions/ordered_set_intersect/psi0.txt.enc differ diff --git a/tests/fixtures/functions/ordered_set_intersect/psi1.txt b/tests/fixtures/functions/ordered_set_intersect/psi1.txt new file mode 100644 index 000000000..618168e96 --- /dev/null +++ b/tests/fixtures/functions/ordered_set_intersect/psi1.txt @@ -0,0 +1,5 @@ +35d85143c3bd10badcad7d3e01bdbad074e4d62a9f04f9c8652da5f5259fed7d +3c2ef1901bee3a4866d68e16de37a270e4f16d166132f14da88b5d0bb5c5a369 +6b51d431df5d7f141cbececcf79edf3dd861c3b4069f0b11661a3eefacbba918 +87e58365cf5292ae0150b97d5bba026158e28a5c2fa32cb04cf4c6a0d0c97111 +fa3cfb3f1bb823aa9501f88f1f95f732ee6fef2c3a48be7f1d38037b216a549f \ No newline at end of file diff --git a/tests/fixtures/functions/ordered_set_intersect/psi1.txt.enc b/tests/fixtures/functions/ordered_set_intersect/psi1.txt.enc new file mode 100644 index 000000000..2607b43ce Binary files /dev/null and b/tests/fixtures/functions/ordered_set_intersect/psi1.txt.enc differ