From a4fd7a6d2f99090eb49d9a1ad5088ac8f6e4851b Mon Sep 17 00:00:00 2001 From: Qinkun Bao Date: Thu, 11 Jun 2020 11:17:03 -0400 Subject: [PATCH 1/5] add ordered set intersection --- cmake/scripts/test.sh | 1 + .../python/builtin_ordered_set_intersect.py | 175 ++++++++++++++ executor/Cargo.toml | 2 + executor/src/builtin.rs | 4 +- function/Cargo.toml | 1 + function/README.md | 4 + function/src/lib.rs | 3 + function/src/ordered_set_intersect.rs | 217 ++++++++++++++++++ .../functions/ordered_set_intersect/psi0.txt | 7 + .../ordered_set_intersect/psi0.txt.enc | Bin 0 -> 4096 bytes .../functions/ordered_set_intersect/psi1.txt | 5 + .../ordered_set_intersect/psi1.txt.enc | Bin 0 -> 4096 bytes 12 files changed, 418 insertions(+), 1 deletion(-) create mode 100644 examples/python/builtin_ordered_set_intersect.py create mode 100644 function/src/ordered_set_intersect.rs create mode 100644 tests/fixtures/functions/ordered_set_intersect/psi0.txt create mode 100644 tests/fixtures/functions/ordered_set_intersect/psi0.txt.enc create mode 100644 tests/fixtures/functions/ordered_set_intersect/psi1.txt create mode 100644 tests/fixtures/functions/ordered_set_intersect/psi1.txt.enc diff --git a/cmake/scripts/test.sh b/cmake/scripts/test.sh index 87e48a83d..c81f95a8e 100755 --- a/cmake/scripts/test.sh +++ b/cmake/scripts/test.sh @@ -163,6 +163,7 @@ run_examples() { python3 builtin_gbdt_train.py python3 builtin_online_decrypt.py python3 builtin_private_join_and_compute.py + python3 builtin_ordered_set_intersect.py popd # kill all background services diff --git a/examples/python/builtin_ordered_set_intersect.py b/examples/python/builtin_ordered_set_intersect.py new file mode 100644 index 000000000..5b0a1dcd2 --- /dev/null +++ b/examples/python/builtin_ordered_set_intersect.py @@ -0,0 +1,175 @@ +#!/usr/bin/env python3 + +import sys + +from teaclave import (AuthenticationService, FrontendService, + AuthenticationClient, FrontendClient, FunctionInput, + FunctionOutput, OwnerList, DataMap) +from utils import (AUTHENTICATION_SERVICE_ADDRESS, FRONTEND_SERVICE_ADDRESS, + AS_ROOT_CA_CERT_PATH, ENCLAVE_INFO_PATH, USER_ID, + USER_PASSWORD) + +# In the example, user 0 creates the task and user 0, 1, upload their private data. +# Then user 0 invokes the task and user 0, 1 get the result. + + +class UserData: + def __init__(self, + user_id, + password, + input_url="", + output_url="", + input_cmac="", + key=[]): + self.user_id = user_id + self.password = password + self.input_url = input_url + self.output_url = output_url + self.input_cmac = input_cmac + self.key = key + + +INPUT_FILE_URL_PREFIX = "http://localhost:6789/fixtures/functions/ordered_set_intersect/" +OUTPUT_FILE_URL_PREFIX = "http://localhost:6789/fixtures/functions/ordered_set_intersect/" + +USER_DATA_0 = UserData("user0", "password", + INPUT_FILE_URL_PREFIX + "psi0.txt.enc", + OUTPUT_FILE_URL_PREFIX + "output_psi0.enc", + "e08adeb021e876ffe82234445e632121", + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) + +USER_DATA_1 = UserData("user1", "password", + INPUT_FILE_URL_PREFIX + "psi1.txt.enc", + OUTPUT_FILE_URL_PREFIX + "output_psi1.enc", + "538dafbf7802d962bb01e2389b4e943a", + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) + + +class DataList: + def __init__(self, data_name, data_id): + self.data_name = data_name + self.data_id = data_id + + +class Client: + def __init__(self, user_id, user_password): + self.user_id = user_id + self.user_password = user_password + self.client = AuthenticationService( + AUTHENTICATION_SERVICE_ADDRESS, AS_ROOT_CA_CERT_PATH, + ENCLAVE_INFO_PATH).connect().get_client() + print(f"[+] {self.user_id} registering user") + self.client.user_register(self.user_id, self.user_password) + print(f"[+] {self.user_id} login") + token = self.client.user_login(self.user_id, self.user_password) + self.client = FrontendService( + FRONTEND_SERVICE_ADDRESS, AS_ROOT_CA_CERT_PATH, + ENCLAVE_INFO_PATH).connect().get_client() + metadata = {"id": self.user_id, "token": token} + self.client.metadata = metadata + + def set_task(self): + client = self.client + + print(f"[+] {self.user_id} registering function") + + function_id = client.register_function( + name="builtin-ordered-set-intersect", + description="Native Private Set Intersection", + executor_type="builtin", + arguments=["order"], + inputs=[ + FunctionInput("input_data1", "Client 0 data."), + FunctionInput("input_data2", "Client 1 data.") + ], + outputs=[ + FunctionOutput("output_result1", "Output data."), + FunctionOutput("output_result2", "Output data.") + ]) + + print(f"[+] {self.user_id} creating task") + task_id = client.create_task( + function_id=function_id, + function_arguments=({ + "order": "ascending", # Order can be ascending or desending + }), + executor="builtin", + inputs_ownership=[ + OwnerList("input_data1", [USER_DATA_0.user_id]), + OwnerList("input_data2", [USER_DATA_1.user_id]) + ], + outputs_ownership=[ + OwnerList("output_result1", [USER_DATA_0.user_id]), + OwnerList("output_result2", [USER_DATA_1.user_id]) + ]) + + return task_id + + def run_task(self, task_id): + client = self.client + print(f"[+] {self.user_id} invoking task") + client.invoke_task(task_id) + + def register_data(self, task_id, input_url, input_cmac, output_url, + file_key, input_label, output_label): + client = self.client + + print(f"[+] {self.user_id} registering input file") + url = input_url + cmac = input_cmac + schema = "teaclave-file-128" + key = file_key + iv = [] + input_id = client.register_input_file(url, schema, key, iv, cmac) + print(f"[+] {self.user_id} registering output file") + url = output_url + schema = "teaclave-file-128" + key = file_key + iv = [] + output_id = client.register_output_file(url, schema, key, iv) + + print(f"[+] {self.user_id} assigning data to task") + client.assign_data_to_task(task_id, [DataList(input_label, input_id)], + [DataList(output_label, output_id)]) + + def approve_task(self, task_id): + client = self.client + print(f"[+] {self.user_id} approving task") + client.approve_task(task_id) + + def get_task_result(self, task_id): + client = self.client + print(f"[+] {self.user_id} getting task result") + return bytes(client.get_task_result(task_id)) + + +def main(): + user0 = Client(USER_DATA_0.user_id, USER_DATA_0.password) + user1 = Client(USER_DATA_1.user_id, USER_DATA_1.password) + + task_id = user0.set_task() + + user0.register_data(task_id, USER_DATA_0.input_url, USER_DATA_0.input_cmac, + USER_DATA_0.output_url, USER_DATA_0.key, "input_data1", + "output_result1") + + user1.register_data(task_id, USER_DATA_1.input_url, USER_DATA_1.input_cmac, + USER_DATA_1.output_url, USER_DATA_1.key, "input_data2", + "output_result2") + + user0.approve_task(task_id) + user1.approve_task(task_id) + + ## USER 0 start the computation + user0.run_task(task_id) + + ## USER 0, 1 get the result + result_user0 = user0.get_task_result(task_id) + result_user1 = user1.get_task_result(task_id) + + print("[+] User 0 result: " + result_user0.decode("utf-8")) + print("[+] User 1 result: " + result_user1.decode("utf-8")) + + +if __name__ == '__main__': + main() diff --git a/executor/Cargo.toml b/executor/Cargo.toml index 37550378c..0e021a93d 100644 --- a/executor/Cargo.toml +++ b/executor/Cargo.toml @@ -33,6 +33,7 @@ full_builtin_function = [ "builtin_logistic_regression_train", "builtin_online_decrypt", "builtin_private_join_and_compute", + "builtin_ordered_set_intersect", ] builtin_echo = [] @@ -42,6 +43,7 @@ builtin_logistic_regression_predict = [] builtin_logistic_regression_train = [] builtin_online_decrypt = [] builtin_private_join_and_compute = [] +builtin_ordered_set_intersect = [] [dependencies] log = { version = "0.4.6", features = ["release_max_level_info"] } diff --git a/executor/src/builtin.rs b/executor/src/builtin.rs index ef4c4137e..dbf91e63c 100644 --- a/executor/src/builtin.rs +++ b/executor/src/builtin.rs @@ -20,7 +20,7 @@ use std::prelude::v1::*; use teaclave_function::{ Echo, GbdtPredict, GbdtTrain, LogisticRegressionPredict, LogisticRegressionTrain, - OnlineDecrypt, PrivateJoinAndCompute, + OnlineDecrypt, OrderedSetIntersect, PrivateJoinAndCompute, }; use teaclave_types::{FunctionArguments, FunctionRuntime, TeaclaveExecutor}; @@ -54,6 +54,8 @@ impl TeaclaveExecutor for BuiltinFunctionExecutor { OnlineDecrypt::NAME => OnlineDecrypt::new().run(arguments, runtime), #[cfg(feature = "builtin_private_join_and_compute")] PrivateJoinAndCompute::NAME => PrivateJoinAndCompute::new().run(arguments, runtime), + #[cfg(feature = "builtin_ordered_set_intersect")] + OrderedSetIntersect::NAME => OrderedSetIntersect::new().run(arguments, runtime), _ => bail!("Function not found."), } } diff --git a/function/Cargo.toml b/function/Cargo.toml index 587c10ba5..638482e3d 100644 --- a/function/Cargo.toml +++ b/function/Cargo.toml @@ -33,6 +33,7 @@ rusty-machine = { version = "0.5.4" } itertools = { version = "0.8.0", default-features = false } ring = { version = "0.16.5" } base64 = { version = "0.10.1" } +hex = { version = "0.4.0" } teaclave_types = { path = "../types" } teaclave_crypto = { path = "../crypto" } teaclave_runtime = { path = "../runtime", optional = true } diff --git a/function/README.md b/function/README.md index 42df46474..edecf3d76 100644 --- a/function/README.md +++ b/function/README.md @@ -20,6 +20,10 @@ Currently, we have these built-in functions: - `builtin-logistic-regression-predict`: LR prediction with input model and input test data. - `builtin-private-join-and-compute`: Find intersection of muti-parties' input data and compute sum of the common items. + - `builtin-ordered-set-intersect`: Allow two parties to compute the + intersection of their ordered sets without revealing anything except for the + elements in the intersection. Users should calculate hash values of each item + and upload them as a sorted list. The function arguments are in JSON format and can be serialized to a Rust struct very easily. You can learn more about supported arguments in the implementation diff --git a/function/src/lib.rs b/function/src/lib.rs index 0e8388541..e13fb5e3d 100644 --- a/function/src/lib.rs +++ b/function/src/lib.rs @@ -28,6 +28,7 @@ mod gbdt_train; mod logistic_regression_predict; mod logistic_regression_train; mod online_decrypt; +mod ordered_set_intersect; mod private_join_and_compute; pub use echo::Echo; @@ -36,6 +37,7 @@ pub use gbdt_train::GbdtTrain; pub use logistic_regression_predict::LogisticRegressionPredict; pub use logistic_regression_train::LogisticRegressionTrain; pub use online_decrypt::OnlineDecrypt; +pub use ordered_set_intersect::OrderedSetIntersect; pub use private_join_and_compute::PrivateJoinAndCompute; #[cfg(feature = "enclave_unit_test")] @@ -52,6 +54,7 @@ pub mod tests { logistic_regression_predict::tests::run_tests(), online_decrypt::tests::run_tests(), private_join_and_compute::tests::run_tests(), + ordered_set_intersect::tests::run_tests(), ) } } diff --git a/function/src/ordered_set_intersect.rs b/function/src/ordered_set_intersect.rs new file mode 100644 index 000000000..7fd02a4da --- /dev/null +++ b/function/src/ordered_set_intersect.rs @@ -0,0 +1,217 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use anyhow::bail; +use std::cmp; +use std::convert::TryFrom; +use std::format; +use std::io::{self, BufRead, BufReader, Write}; +#[cfg(feature = "mesalock_sgx")] +use std::prelude::v1::*; +use teaclave_types::{FunctionArguments, FunctionRuntime}; + +extern crate hex; +extern crate sgx_tstd as std; + +// Input data should be a list of sorted hash values. + +const IN_DATA1: &str = "input_data1"; +const IN_DATA2: &str = "input_data2"; +const OUT_RESULT1: &str = "output_result1"; +const OUT_RESULT2: &str = "output_result2"; + +#[derive(Default)] +pub struct OrderedSetIntersect; + +#[derive(serde::Deserialize)] +pub struct OrderedSetIntersectArguments { + order: String, +} + +impl TryFrom for OrderedSetIntersectArguments { + type Error = anyhow::Error; + + fn try_from(arguments: FunctionArguments) -> Result { + use anyhow::Context; + serde_json::from_str(&arguments.into_string()).context("Cannot deserialize arguments") + } +} + +impl OrderedSetIntersect { + pub const NAME: &'static str = "builtin-ordered-set-intersect"; + + pub fn new() -> Self { + Default::default() + } + + pub fn run( + &self, + arguments: FunctionArguments, + runtime: FunctionRuntime, + ) -> anyhow::Result { + let input1 = runtime.open_input(IN_DATA1)?; + let input2 = runtime.open_input(IN_DATA2)?; + let mut output1 = runtime.create_output(OUT_RESULT1)?; + let mut output2 = runtime.create_output(OUT_RESULT2)?; + let args = OrderedSetIntersectArguments::try_from(arguments)?; + let order = &args.order[..]; + let ascending_order = match order { + "ascending" => true, + "desending" => false, + _ => bail!("Invalid order"), + }; + + let vec1 = parse_input_data(input1, ascending_order)?; + let vec2 = parse_input_data(input2, ascending_order)?; + let (result1, result2) = intersection_ordered_vec(&vec1, &vec2, ascending_order)?; + + let mut common_sets = 0; + + for item in result1 { + write!(&mut output1, "{}", item)?; + if item > 0 { + common_sets += 1; + } + } + for item in result2 { + write!(&mut output2, "{}", item)?; + } + Ok(format!("{} common items", common_sets)) + } +} + +fn parse_input_data(input: impl io::Read, ascending_order: bool) -> anyhow::Result>> { + let mut samples: Vec> = Vec::new(); + let reader = BufReader::new(input); + for byte_result in reader.lines() { + let byte = byte_result?; + let result = hex::decode(byte)?; + samples.push(result) + } + let len = samples.len(); + + // Verify the order + if len > 1 { + for i in 1..len { + if ascending_order && samples[i] < samples[i - 1] { + bail!("Invalid ordering"); + } + + if !ascending_order && samples[i] > samples[i - 1] { + bail!("Invalid ordering"); + } + } + } + Ok(samples) +} + +fn intersection_ordered_vec( + input1: &[Vec], + input2: &[Vec], + ascending_order: bool, +) -> anyhow::Result<(Vec, Vec)> { + let v1_len = input1.len(); + let v2_len = input2.len(); + + let mut res1 = std::vec![0; v1_len]; + let mut res2 = std::vec![0; v2_len]; + + let mut i = 0; + let mut j = 0; + + while i < v1_len && j < v2_len { + let order = &input1[i].cmp(&input2[j]); + match order { + cmp::Ordering::Equal => { + res1[i] = 1; + res2[j] = 1; + i += 1; + j += 1; + } + cmp::Ordering::Less => { + if ascending_order { + i += 1; + } else { + j += 1; + } + } + cmp::Ordering::Greater => { + if ascending_order { + j += 1; + } else { + i += 1; + } + } + } + } + Ok((res1, res2)) +} + +#[cfg(feature = "enclave_unit_test")] +pub mod tests { + use super::*; + use serde_json::json; + use std::path::Path; + use std::untrusted::fs; + use teaclave_crypto::*; + use teaclave_runtime::*; + use teaclave_test_utils::*; + use teaclave_types::*; + + pub fn run_tests() -> bool { + run_tests!(test_ordered_set_intersect) + } + + fn test_ordered_set_intersect() { + let arguments = FunctionArguments::from_json(json!({ + "order": "ascending" + })) + .unwrap(); + + let base = Path::new("fixtures/functions/ordered_set_intersect"); + + let user1_input = base.join("psi0.txt"); + let user1_output = base.join("output_psi0.txt"); + + let user2_input = base.join("psi1.txt"); + let user2_output = base.join("output_psi1.txt"); + + let input_files = StagedFiles::new(hashmap!( + IN_DATA1 => + StagedFileInfo::new(&user1_input, TeaclaveFile128Key::random(), FileAuthTag::mock()), + IN_DATA2 => + StagedFileInfo::new(&user2_input, TeaclaveFile128Key::random(), FileAuthTag::mock()), + )); + + let output_files = StagedFiles::new(hashmap!( + OUT_RESULT1 => + StagedFileInfo::new(&user1_output, TeaclaveFile128Key::random(), FileAuthTag::mock()), + OUT_RESULT2 => + StagedFileInfo::new(&user2_output, TeaclaveFile128Key::random(), FileAuthTag::mock()), + )); + + let runtime = Box::new(RawIoRuntime::new(input_files, output_files)); + let summary = OrderedSetIntersect::new().run(arguments, runtime).unwrap(); + + let user1_result = fs::read_to_string(&user1_output).unwrap(); + let user2_result = fs::read_to_string(&user2_output).unwrap(); + + assert_eq!(&user1_result[..], "0101010"); + assert_eq!(&user2_result[..], "01101"); + assert_eq!(summary, "3 common items"); + } +} diff --git a/tests/fixtures/functions/ordered_set_intersect/psi0.txt b/tests/fixtures/functions/ordered_set_intersect/psi0.txt new file mode 100644 index 000000000..a2dd24683 --- /dev/null +++ b/tests/fixtures/functions/ordered_set_intersect/psi0.txt @@ -0,0 +1,7 @@ +3129a6f57c01547906c4f851de448d4a85716927d9aae5d13955303833dea3be +3c2ef1901bee3a4866d68e16de37a270e4f16d166132f14da88b5d0bb5c5a369 +699bd76eb9764233eade0f5ca571e86b01b59ef6051e6008f2ab1723b1ba20e8 +6b51d431df5d7f141cbececcf79edf3dd861c3b4069f0b11661a3eefacbba918 +7a90238b179e5d28faa81dcffee49fcd200d591a61f9d0ba9d76eca3cb71a813 +fa3cfb3f1bb823aa9501f88f1f95f732ee6fef2c3a48be7f1d38037b216a549f +ffff5954ee15325a8af0a1251b5e6dc255975484df25c5f9f24542479d8d340e \ No newline at end of file diff --git a/tests/fixtures/functions/ordered_set_intersect/psi0.txt.enc b/tests/fixtures/functions/ordered_set_intersect/psi0.txt.enc new file mode 100644 index 0000000000000000000000000000000000000000..2b27556b13c140383e609ff67746275b14a5ba7f GIT binary patch literal 4096 zcmeH~XCM>+|u^n?eq=Cd4 zan_lua9>)a@>jau1zlBO;spaY_`O`i6_fFbJK=HpbPvkzPPWz~4N{H_qq(>S)(xOM z!DHLQ*2ZBMW-jK}^k8J(Hos0%-X3j&kgCwG$9RlFPsIOjksCr7enddH#1sOuLS7PX zX4%rt1=LqBtnT8;h+?bsks1sJaq(cD=oCGnd(~k`J@txv1!i4s&R_LajL1!8F+UIU zRrXt84Q7`Wr_b~F&gE0e^sTFnj30|&W58Iw1P9Bhqkj5S`7Gd-_9?oN?{#`_ z`?4l4wPFh5PPKxflb3gV=Sn_hUKzD4HYW}EVsGgT)wV4UPqZ*5Z@$|4K{h$)cf-Qi zSbt)I?$OA6(0a{s*cC8qiDrSfHe6Z2$psZG>Mne54DcMml#H^dbqQoe*E}v!rR9Vx zfRReSh?GH97A(ik`tMt~MUSh@c;bbyK!;K6{D4@~p$~;RY0gSmBT=;! z@ep(|X3U%7J`?-92dBV<3)F1Tj zcP6<8jyJ*a%`BrfMn^UJa*Jzujs8QPV1gX*KD-EdDu&ZhTwLJs%UFxjy?{5^WwR^d zeHxlOSWIQzmZso(B-k=-z@K)Q6w>r6Knc1G>;IVaf>?*;i`3cnL2DE z{W@_@Qt`Xr{3k}jxa0NluI5M3F`cNJeWdS6>>xlQ6+lo^bo#Xlg}Y(I0A zDN_Y5EL^2BOal4Er>3j!eg*c#DUfszk!Wd-8Vl$Vhkek5ssyi_$(&wg>hQUcg!<0zBr*)j>gT!JqVeNqk> zPSC$8=_dE~qt6c175{C{o}PmA^`PDIq~(k#e%_Y`x=p_Ljq}bImHlT3`3zRR-#z0$ z)fy!?<6&q7^;#*lU50gJCpTedMYw+I;`v-;LEyMJyGF{?jYU}siH?0;IK44)wROqB zt-m34qajB9L?qs1J*JE;$t?IW$$40_EvWgM@as}wsrMVR&HD^>NbK(&+4(5E%)b!9 z@@Eu6eVLA#y|Jg;4LTSeM_`}(Ee_)Pen{Zi$`y677p@-fc@B0M8$zQMI~d48iSf!o z#{GL^;ahz;0brQ~Js$MRTu8xRiSdsoT_1ow{ZQJn(>4Ti^x2DynnrJywVc@Q9N8r! zo%wQ7Hgu|HNbG@ELF$_mRC};CUEJ8<3nIm(R{L55ZF7y8;u}7-xQG!U+|cC1882Ha zT#oiL$ctPNCVQ(qE<_EF@MHeE^L-UUJnzG{!Vf8p*Dl>&Tdnrer4aSAp^Kg>;1Xqt}A% z<;WlFu08OtI16vZ5=^o9+iRlC-`Ync`_!&?JGCiIpAMh%098H!gt$D>r-~T{!SC?N z->&B8JD|WYtWg>@zD{_={$K+_SX9DQWS{KI`B(^y&{6MW22Mo04~ zMrUF_+Ad`)@^{QC2cX-u)~i%Zrni?nGux_=!j^=q(c8I{4R&Ky^m6-6A!lhCg$P?Z z^utw4v}V+E#xiwC5gfDB(F>@G_nE4Kza=~CMpNFZs|z|@o@hLM5=bNkUD+#v>Uu84 zk52Y-`AQA}n{9tnV@D?6gWT596l2lZH=gv^#Np4VoHB{{%=Mn*Ql4l4b>-zd+hWm! zI(SmWmAxTe)pl!0*>O3)ll>Wm`DO|4%KbUrd%ZYVy0`)_TW>v)H3@)9b=;RQ2f0iw zoJpqM?f}(_jPF|D=S2c-o%kG*R-Y!ypvmoyJd$Gx;mnx=`F`+VUCe>!9;(j^lL|cN zmK_Y&=7ivuyLL z*^aI0EzKUz$|f;32vo}n9jUyt@)aX6tzRoy;nKW1C0Gj5gMFmZ6%|GIaok=NGKNJp z6B8`EI)(#9EEq{(E%^JeS1y{BhXm?Chh3n{LI&c!=Rd6O>O{V#QYhF>wOHjWk;il zpp)~(f>7*P*zI*SEba%RN1cy`fsoRT7wqV~^HJ!Sj0tJ5vTI%7t9N`w46fRopt5Yu zp9*Jujk0lEr(DBn-Z%XeAtS1$dV26{RY9$r#Q}z{%DHE|FruBOgggAm8^#7M9|08B zH_lRZt^R(Fi$N+)FbkFyp#ccyRXLhp+Ew+XyNAqp>vAi@w8sCkKB>}DiYrNh zu;mT&ibP4q%B@CPcqF_n=F(GaNXlX&LCY#sIsBW-`MO=J+7~n(7NX^j?tANPc+NQv z)O6J?4n2Zzw?8oK`l+WNx0G5*CO57N9U<(WV2;0J$=%$3{eW~bXJ^svAtj7?d$~7f z@b+k6s&VP^!B70Z;enA`a)Aeo&M$a}zu{{+Ge4`HS!Y|0f1#Ur$iUA*Ll1^?er@k> z@8s<{j#WtU)Cy8`dK8RY1qEOedm=NOI2#qKSXdq%CNG_^I1-M3k)YHCkg2EEfL~847uZ5kqn0b7J2)@}A_qdgv?r8?;k8c-1)y?gdzG1cxs~g7MITIEJp>YNB8Z_5y zMSpy(#(z3X64k{K;9g})O;bhyMeoCqJC%LT&cEu>Y~E}K&1tz(y?uoiJUkc4gKu<0 zk?RtKLfF^Loe-CKg!Tp|Og=d&b>7tY;(;De%x7M;shh9xdhP%PnRP49&z&Qm^+Sug zkuwS~KwCa{dn^2j*kRNjt=&eiCv+Ro=GJjA&_&t`z;c5^Z+lF`1M!FvpEdxL6FZ#* zk~i+iAd(@un{xcke40n+s8E!-xXPZy?3X@Vb)c{^E);Vej^J2YKc+h4 zZW3?J;|i~LU+T2+sI0cP_c+OpSM@TH?#x$N2GAG7%JNTYHP-v)K-R%t%#(XHDWS!V zlo_uQf_yy(f1p~}!}awp45>W}Yiv0g_z+_OAL=wrV7*08$enGdL94ft7_<7qS)yc= zMX%IuD{a?yEm`#;6OrhCa|y!NE7=zlE~iU(ZOFlh*J^1+BA>rd#e>Pe&GZ-fSqq|k sMW>IKDFWH4v|h;)*SALu5B|)rR1_DVO&c@k{P)57AOEkw|8If+0Hu+)82|tP literal 0 HcmV?d00001 diff --git a/tests/fixtures/functions/ordered_set_intersect/psi1.txt b/tests/fixtures/functions/ordered_set_intersect/psi1.txt new file mode 100644 index 000000000..618168e96 --- /dev/null +++ b/tests/fixtures/functions/ordered_set_intersect/psi1.txt @@ -0,0 +1,5 @@ +35d85143c3bd10badcad7d3e01bdbad074e4d62a9f04f9c8652da5f5259fed7d +3c2ef1901bee3a4866d68e16de37a270e4f16d166132f14da88b5d0bb5c5a369 +6b51d431df5d7f141cbececcf79edf3dd861c3b4069f0b11661a3eefacbba918 +87e58365cf5292ae0150b97d5bba026158e28a5c2fa32cb04cf4c6a0d0c97111 +fa3cfb3f1bb823aa9501f88f1f95f732ee6fef2c3a48be7f1d38037b216a549f \ No newline at end of file diff --git a/tests/fixtures/functions/ordered_set_intersect/psi1.txt.enc b/tests/fixtures/functions/ordered_set_intersect/psi1.txt.enc new file mode 100644 index 0000000000000000000000000000000000000000..2607b43ce36a816d63fe4ead53033bd6dcffdedc GIT binary patch literal 4096 zcmeH~S0EdV!iI?zvqn&RMG39F_ehMQR!Zzqt722F5{(&R$EcbuT6~J4rNpScYksw= z#NJzN&i(ms|I72 z4{0^|O|TKs{~7Ur8}p>`I+8<~kEOuOXoAGSH>R;5sx}ueup%3PK|qYHwrhm(+Q^`fe=>+s2g-3ow}R zDbDnpp(mumu_4Sl=>yhWlZO116f1@d zg`cmkL29hq;Ep*YH3aK>UUyrMIG2^$#Wk5o9z=WHT-nS32hbQQXQ>hDMD#}P31Pb4iAK;d_fbsim^&7JRy^W90eqI3b6f!VitDpLpc zWi$LfWP#GfZrw368qRF7+aJnPU)dhsNdiuDGncrYta4{B%jq|%&@R%OJ&JNZiTaah zPz0lF4&-(pcq(2Dw!K|qnBXKA?p^V(*u|N}F15-o_6IueC>zZ*bwFe1mO$MC6t~8i z_4w9Jaa&d~CetoRnyI$VKjo#}zf3@XC$oKIzXNC0!C!;ndeD)Wr1E?DyzCgO3GEkT}< z)X@tZ(zWl+0dRwsmy6g#4#P6v%c@5-+%v|sjYB^grcbtBGxJ2#9fT_x2y;dRS2r53 z!Ylxubb!lsR3Wm);oc zLNp`{8UT0q+xyxtz0`V(hQ&JP&i@U3zBYGlndC^GgE*-y%1RJek$3*XfTrJ5rYyZ8 z`HeKF~M@cD7+%~M}CuV}^I z-|s$&UxuWvt~WXa=O6NR-!Tcn|1hUdWgSYpn0+h8_4is1~4lPr=t9(%WWVx zrmC@$Ow9TE^(W27o9g-Hfp?T~5q3m9C{mhOSebF9=9Ndj%6u`K z!6XHbfqFvOYLn~log$M}5ud{iE4FoNyDA$wJ>jlPB;|&=IJ`)rVo5NX?zTT2PC3Af z4ZaH2ZFK+Ubr!+hQL6?+m*FM0Fm;8;FTBm@pP9-17PMtsu2pfp&ZV#bxO)t(SB&8P z?v?6T-S~-})YOOWk6o;O;e7GdPPaM?`p+c4w_#_Mcl5Z6n_4=O`imH8w$wsB3gAjs zX_(!hc!1xHtBd~N)ZU6F%9Wqaub0^k%ru>&TnYxIFO)GpHR7%&ni(e=H!DG_Un;$7 zK*tc0AtIQHrmXdmwwH5TV6H+Y3P@}4BWaUHTAPHkpyp1Kz~Nn?LP~{N*p|-Wb@0C% z9m}Yi&tnC_W25F*p(SR%Ea@dFvb-4#9`aN* zN-!$n+#VSfb)Nn)(Z<+M&@(Xdm*BB-+FU{FT%IGNdukEl{DlWlr_=D96~ghj0E2>T zU8^^90X3C??Ji{QH;Lb`?GJ-IC2M$-N;gNL2-5dH10;bZ#xT~4_YLiW0udTImOBKj zj@zuRh{R%eUBo&ypnikf&DSq2s0BUXOio;wZYJd2e7lf&4>&`DOZ^+?&Si-|(2tNx z$rLq{7cz!CPC_{D{OB5S>7Nf3X|`kL^T^Sa+(;Q-;7lkDC077J=!QBTKglb85zkIr zu#CB_F1!9>4Y6TOG7hMg51oQ<6td?M9zQq^itD^v_>1TB@wMjVXHrd|QYU%Le0;XV zo#|D3Sy0j$r5DBf+5AOZwlB);DnW|zMzeX(8vd5Bs`7&kN+am@px)}=C-O@Qb>|J$ z_QJnrnmqK3Z8U+Zo;847KB zIopW*7jm>kbnFa%$cI6t(zkTRT6a@Kvt04Z9-3oPUDpa#O!*vyikzX%sVEI-wn-Ud zp;>Cx>*V(>UA`{a+2P_?P9@0J*e{kBstC_$HMa|f&v{$t;$P^D*~!-$NCj-#l)1l# zwc$>RnwIQ0ZXWaag$3U-msScmiPp-KJD{QSrI=tp%lqvYl_7>k(9#v0_IysOSsCXX zkypwjGOrVsG+S;kmY*u1;jMAc(uZpZFMb3fJIu%)d`u#&>`W^*_CpD-72XfI3O)_k z2B(evedawCN6cd&TNK@(11~{OQadf8u24!U9D^PPj#N9{GD31}|2Pn(9fND>K7DP- zPCyJaPCfl+)1S#FJ|%HIcH+TRr0z4W%V9Ok_yb1vpII#p_H$84-qmMEd43@t31PE> z_XPJ@X-nDcx<455*R&ahBeMdmy&G}G%#D*N1lkaeYShb5@EY1)Mr<@Y2DUkNP_BFl zW${hZ4Wza*RjG#6Qoh1JNYzfb(l=)U9`(dq`@i2IyS-Y?Z+&2d>I&hR5eEh8h!w;; zR290nd=IV=P>aL(!bKQJF#U&PR5mXpAI6DxFMV7|rNPotT06tLo?H=_{n%|o1x6dX zfhkXLWWGd^P9v}q|E+KvGi6Tq0NU_g?@dSfO6I7I!E2ctqATl~%McrEGF0sJRTp^8 zqkpA&Iy;(eZD#au!sGorr{*^AM4oFXo;cJe@&Gg1bDl zxjh%_aB#U#&t)|#?>>Bq`x+u^Xv~??(&hGTeU+Ca58SIMNFxK0&f$E{?%J(K_WIf^ z=x&U)To2+xi|pNW+50*{{u9a*vyi;}kZeIhJ0nc-61GEHxYLq&I!7A*MgS>(@xHz7 zWBi4eCB*U59z~@lId_aIRXBT|2VZ7S;LUuH*n63dnPa`Gzeu!5JlSNx$9ZG8eb6W~y=yp=nD%=?A=G_77)Ng+jODR89vR|-KXd~~cA0j6Joq`z<%iR}HI?0b*Yo!v zyMRK=@z_(tva&~1J}9fc1E?$N6kjel8y1V0yI;fx^{{X&AOp4$1-TN+^0VHY4mZ}n z3LNHGxBOv;brCeDcd3wv9_Mc;Zsh6PZz#thj@$8x`gQ*WcgU6yM1({q3$ONidfe1^i+ro`O)BI^GNXs+-02IANww}lLV zk$qABdg(PM@4DuOTeo=I#;MoO;*CqB+iM6C0mRRdVzb^jrE0c@eP(|XBi>+oiPeU- zG9fvGhy)QurSEAO;cQ!qPs2~Ds?NsZe6tsKD~fSmwV^t4JqlGy)=!7xGqoi;k>vs&YNleqD#ys6=Fw8UFm^U6_&nW+PI z{|`(-f=OD7_Tf&f#)NxNK)~#G$|cEtSPiF`WMKmoAI=Bb2(cfT@$gCquQ^{drH;0Stml9@`KlWyqsjJ%ahNyo6n zqKhokrS_(k@_GdXgEKnN6CFJ44_B}h?^B&KtBxL2_W`N9n9|byR0he^F)UNGRr%ve xFaSgAEnOdPD@Wo%*3lNoI literal 0 HcmV?d00001 From 48487325cb7393ed528728a947d3117189ce9422 Mon Sep 17 00:00:00 2001 From: Qinkun Bao Date: Mon, 15 Jun 2020 15:09:07 -0400 Subject: [PATCH 2/5] code refactoring --- function/src/ordered_set_intersect.rs | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/function/src/ordered_set_intersect.rs b/function/src/ordered_set_intersect.rs index 7fd02a4da..9d254cc7a 100644 --- a/function/src/ordered_set_intersect.rs +++ b/function/src/ordered_set_intersect.rs @@ -97,25 +97,21 @@ impl OrderedSetIntersect { fn parse_input_data(input: impl io::Read, ascending_order: bool) -> anyhow::Result>> { let mut samples: Vec> = Vec::new(); let reader = BufReader::new(input); - for byte_result in reader.lines() { + for (index, byte_result) in reader.lines().enumerate() { let byte = byte_result?; let result = hex::decode(byte)?; - samples.push(result) - } - let len = samples.len(); - - // Verify the order - if len > 1 { - for i in 1..len { - if ascending_order && samples[i] < samples[i - 1] { - bail!("Invalid ordering"); - } - - if !ascending_order && samples[i] > samples[i - 1] { + if index > 0 { + // If vec has more than 2 elements, then verify the ordering + let last_element = &samples[index - 1]; + if ascending_order && result < *last_element + || !ascending_order && result > *last_element + { bail!("Invalid ordering"); } } + samples.push(result) } + Ok(samples) } From e2c6cdf504d882d8035809a2940759f9881e3137 Mon Sep 17 00:00:00 2001 From: Qinkun Bao Date: Mon, 15 Jun 2020 16:35:52 -0400 Subject: [PATCH 3/5] remove references --- function/src/ordered_set_intersect.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/function/src/ordered_set_intersect.rs b/function/src/ordered_set_intersect.rs index 9d254cc7a..7fe70dcf6 100644 --- a/function/src/ordered_set_intersect.rs +++ b/function/src/ordered_set_intersect.rs @@ -102,9 +102,8 @@ fn parse_input_data(input: impl io::Read, ascending_order: bool) -> anyhow::Resu let result = hex::decode(byte)?; if index > 0 { // If vec has more than 2 elements, then verify the ordering - let last_element = &samples[index - 1]; - if ascending_order && result < *last_element - || !ascending_order && result > *last_element + let last_element = samples[index - 1]; + if ascending_order && result < last_element || !ascending_order && result > last_element { bail!("Invalid ordering"); } From 37907ce633c6f032ab2f3772c81183906d75fcf4 Mon Sep 17 00:00:00 2001 From: Qinkun Bao Date: Mon, 15 Jun 2020 16:47:11 -0400 Subject: [PATCH 4/5] Revert "remove references" This reverts commit e2c6cdf504d882d8035809a2940759f9881e3137. --- function/src/ordered_set_intersect.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/function/src/ordered_set_intersect.rs b/function/src/ordered_set_intersect.rs index 7fe70dcf6..9d254cc7a 100644 --- a/function/src/ordered_set_intersect.rs +++ b/function/src/ordered_set_intersect.rs @@ -102,8 +102,9 @@ fn parse_input_data(input: impl io::Read, ascending_order: bool) -> anyhow::Resu let result = hex::decode(byte)?; if index > 0 { // If vec has more than 2 elements, then verify the ordering - let last_element = samples[index - 1]; - if ascending_order && result < last_element || !ascending_order && result > last_element + let last_element = &samples[index - 1]; + if ascending_order && result < *last_element + || !ascending_order && result > *last_element { bail!("Invalid ordering"); } From 78c47581d8bc432ad1a11f71983268e888a1df4f Mon Sep 17 00:00:00 2001 From: Qinkun Bao Date: Mon, 15 Jun 2020 17:02:42 -0400 Subject: [PATCH 5/5] use bool instead of u8 --- function/src/ordered_set_intersect.rs | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/function/src/ordered_set_intersect.rs b/function/src/ordered_set_intersect.rs index 9d254cc7a..ae1c08637 100644 --- a/function/src/ordered_set_intersect.rs +++ b/function/src/ordered_set_intersect.rs @@ -82,13 +82,13 @@ impl OrderedSetIntersect { let mut common_sets = 0; for item in result1 { - write!(&mut output1, "{}", item)?; - if item > 0 { + write!(&mut output1, "{}", item as u32)?; + if item { common_sets += 1; } } for item in result2 { - write!(&mut output2, "{}", item)?; + write!(&mut output2, "{}", item as u32)?; } Ok(format!("{} common items", common_sets)) } @@ -119,12 +119,12 @@ fn intersection_ordered_vec( input1: &[Vec], input2: &[Vec], ascending_order: bool, -) -> anyhow::Result<(Vec, Vec)> { +) -> anyhow::Result<(Vec, Vec)> { let v1_len = input1.len(); let v2_len = input2.len(); - let mut res1 = std::vec![0; v1_len]; - let mut res2 = std::vec![0; v2_len]; + let mut res1 = std::vec![false; v1_len]; + let mut res2 = std::vec![false; v2_len]; let mut i = 0; let mut j = 0; @@ -133,8 +133,8 @@ fn intersection_ordered_vec( let order = &input1[i].cmp(&input2[j]); match order { cmp::Ordering::Equal => { - res1[i] = 1; - res2[j] = 1; + res1[i] = true; + res2[j] = true; i += 1; j += 1; }