diff --git a/cli/src/main.rs b/cli/src/main.rs index 372443e38..1c672edc6 100644 --- a/cli/src/main.rs +++ b/cli/src/main.rs @@ -260,7 +260,7 @@ fn attest(opt: AttestOpt) -> Result<()> { let mut session = rustls::ClientSession::new(&rc_config, hostname); let mut tls_stream = rustls::Stream::new(&mut session, &mut stream); - tls_stream.write(&[0]).unwrap(); + tls_stream.write_all(&[0]).unwrap(); Ok(()) } diff --git a/cmake/scripts/test.sh b/cmake/scripts/test.sh index 5388371a5..ab3a5df71 100755 --- a/cmake/scripts/test.sh +++ b/cmake/scripts/test.sh @@ -191,6 +191,7 @@ run_examples() { python3 builtin_ordered_set_intersect.py python3 builtin_rsa_sign.py python3 builtin_face_detection.py + python3 builtin_password_check.py popd # kill all background services diff --git a/examples/python/builtin_password_check.py b/examples/python/builtin_password_check.py new file mode 100644 index 000000000..fb87d42e5 --- /dev/null +++ b/examples/python/builtin_password_check.py @@ -0,0 +1,194 @@ +#!/usr/bin/env python3 + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import sys + +from teaclave import (AuthenticationService, FrontendService, + AuthenticationClient, FrontendClient, FunctionInput, + FunctionOutput, OwnerList, DataMap) +from utils import (AUTHENTICATION_SERVICE_ADDRESS, FRONTEND_SERVICE_ADDRESS, + AS_ROOT_CA_CERT_PATH, ENCLAVE_INFO_PATH, USER_ID, + USER_PASSWORD) + +# In the example, user 0 creates the task and user 0, 1, upload their private data. +# Then user 0 invokes the task and user 0, 1 get the result. + + +class UserData: + def __init__(self, + user_id, + password, + input_url="", + encryption_algorithm="teaclave-file-128", + input_cmac="", + iv=[], + key=[]): + self.user_id = user_id + self.password = password + self.input_url = input_url + self.encryption_algorithm = encryption_algorithm + self.input_cmac = input_cmac + self.iv = iv + self.key = key + + +INPUT_FILE_URL_PREFIX = "http://localhost:6789/fixtures/functions/password_check/" + +# Client +USER_DATA_0 = UserData( + "user0", + "password", + "data:text/plain;base64,c+mpvRfZ0fboR0j3rTgOGDBiubSzlCt9", # base64 of encrypted string "password" + "aes-gcm-128", + "e84748f7ad380e183062b9b4b3942b7d", + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) + +# Data provider +USER_DATA_1 = UserData("user1", "password", + INPUT_FILE_URL_PREFIX + "exposed_passwords.txt.enc", + "teaclave-file-128", "42b16c29edeb9ee0e4d219f3b5395946", + [], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) + + +class DataList: + def __init__(self, data_name, data_id): + self.data_name = data_name + self.data_id = data_id + + +class Client: + def __init__(self, user_id, user_password): + self.user_id = user_id + self.user_password = user_password + self.client = AuthenticationService( + AUTHENTICATION_SERVICE_ADDRESS, AS_ROOT_CA_CERT_PATH, + ENCLAVE_INFO_PATH).connect().get_client() + print(f"[+] {self.user_id} registering user") + self.client.user_register(self.user_id, self.user_password) + print(f"[+] {self.user_id} login") + token = self.client.user_login(self.user_id, self.user_password) + self.client = FrontendService( + FRONTEND_SERVICE_ADDRESS, AS_ROOT_CA_CERT_PATH, + ENCLAVE_INFO_PATH).connect().get_client() + metadata = {"id": self.user_id, "token": token} + self.client.metadata = metadata + + def set_task(self): + client = self.client + + print(f"[+] {self.user_id} registering function") + + function_id = client.register_function( + name="builtin-password-check", + description="Check whether a password is exposed.", + executor_type="builtin", + arguments=[], + inputs=[ + FunctionInput("password", "Client 0 data."), + FunctionInput("exposed_passwords", "Client 1 data.") + ], + outputs=[]) + + print(f"[+] {self.user_id} creating task") + task_id = client.create_task( + function_id=function_id, + function_arguments={}, + executor="builtin", + inputs_ownership=[ + OwnerList("password", [USER_DATA_0.user_id]), + OwnerList("exposed_passwords", [USER_DATA_1.user_id]) + ], + ) + + return task_id + + def run_task(self, task_id): + client = self.client + print(f"[+] {self.user_id} invoking task") + client.invoke_task(task_id) + + def register_data(self, task_id, input_url, algorithm, input_cmac, + file_key, iv, input_label): + client = self.client + + print(f"[+] {self.user_id} registering input file") + url = input_url + cmac = input_cmac + schema = algorithm + key = file_key + input_id = client.register_input_file(url, schema, key, iv, cmac) + + print(f"[+] {self.user_id} assigning data to task") + client.assign_data_to_task(task_id, [DataList(input_label, input_id)], + []) + + def approve_task(self, task_id): + client = self.client + print(f"[+] {self.user_id} approving task") + client.approve_task(task_id) + + def get_task_result(self, task_id): + client = self.client + print(f"[+] {self.user_id} getting task result") + return bytes(client.get_task_result(task_id)) + + +def main(): + user0 = Client(USER_DATA_0.user_id, USER_DATA_0.password) + user1 = Client(USER_DATA_1.user_id, USER_DATA_1.password) + + task_id = user0.set_task() + + user0.register_data( + task_id, + USER_DATA_0.input_url, + USER_DATA_0.encryption_algorithm, + USER_DATA_0.input_cmac, + USER_DATA_0.key, + USER_DATA_0.iv, + "password", + ) + + user1.register_data( + task_id, + USER_DATA_1.input_url, + USER_DATA_1.encryption_algorithm, + USER_DATA_1.input_cmac, + USER_DATA_1.key, + USER_DATA_1.iv, + "exposed_passwords", + ) + + user0.approve_task(task_id) + user1.approve_task(task_id) + + ## USER 0 start the computation + user0.run_task(task_id) + + ## USER 0, 1 get the result + result_user0 = user0.get_task_result(task_id) + result_user1 = user1.get_task_result(task_id) + + print("[+] User 0 result: " + result_user0.decode("utf-8")) + print("[+] User 1 result: " + result_user1.decode("utf-8")) + + +if __name__ == '__main__': + main() diff --git a/executor/Cargo.toml b/executor/Cargo.toml index 76ea1a9d4..7f88450cf 100644 --- a/executor/Cargo.toml +++ b/executor/Cargo.toml @@ -29,29 +29,31 @@ enclave_unit_test = [ full_builtin_function = [ "builtin_echo", + "builtin_face_detection", "builtin_gbdt_predict", "builtin_gbdt_train", "builtin_logistic_regression_predict", "builtin_logistic_regression_train", + "builtin_password_check", "builtin_online_decrypt", - "builtin_private_join_and_compute", "builtin_ordered_set_intersect", - "builtin_rsa_sign", - "builtin_face_detection", "builtin_principal_components_analysis", + "builtin_private_join_and_compute", + "builtin_rsa_sign", ] builtin_echo = [] +builtin_face_detection = [] builtin_gbdt_predict = [] builtin_gbdt_train = [] builtin_logistic_regression_predict = [] builtin_logistic_regression_train = [] +builtin_password_check = [] builtin_online_decrypt = [] -builtin_private_join_and_compute = [] builtin_ordered_set_intersect = [] -builtin_rsa_sign = [] -builtin_face_detection = [] builtin_principal_components_analysis = [] +builtin_private_join_and_compute = [] +builtin_rsa_sign = [] [dependencies] log = { version = "0.4.6", features = ["release_max_level_info"] } diff --git a/executor/src/builtin.rs b/executor/src/builtin.rs index afc98624d..070bd8820 100644 --- a/executor/src/builtin.rs +++ b/executor/src/builtin.rs @@ -20,8 +20,8 @@ use std::prelude::v1::*; use teaclave_function::{ Echo, FaceDetection, GbdtPredict, GbdtTrain, LogisticRegressionPredict, - LogisticRegressionTrain, OnlineDecrypt, OrderedSetIntersect, PrincipalComponentsAnalysis, - PrivateJoinAndCompute, RsaSign, + LogisticRegressionTrain, OnlineDecrypt, OrderedSetIntersect, PasswordCheck, + PrincipalComponentsAnalysis, PrivateJoinAndCompute, RsaSign, }; use teaclave_types::{FunctionArguments, FunctionRuntime, TeaclaveExecutor}; @@ -65,6 +65,8 @@ impl TeaclaveExecutor for BuiltinFunctionExecutor { } #[cfg(feature = "builtin_face_detection")] FaceDetection::NAME => FaceDetection::new().run(arguments, runtime), + #[cfg(feature = "builtin_password_check")] + PasswordCheck::NAME => PasswordCheck::new().run(arguments, runtime), _ => bail!("Function not found."), } } diff --git a/file_agent/Cargo.toml b/file_agent/Cargo.toml index ee888f027..c40c2064f 100644 --- a/file_agent/Cargo.toml +++ b/file_agent/Cargo.toml @@ -16,6 +16,7 @@ default = [] [dependencies] log = { version = "0.4.6", features = ["release_max_level_info"] } anyhow = { version = "1.0.26" } +base64 = { version = "0.10.1" } serde_json = { version = "1.0.39" } serde = { version = "1.0.92", features = ["derive"] } thiserror = { version = "1.0.9" } diff --git a/file_agent/src/agent.rs b/file_agent/src/agent.rs index 3754e653f..caa0d4504 100644 --- a/file_agent/src/agent.rs +++ b/file_agent/src/agent.rs @@ -129,6 +129,15 @@ async fn handle_download( ); copy_file(src, dst).await?; } + "data" => { + let data = remote.path().split(',').collect::>(); + if data.len() == 2 && data[0] == "text/plain;base64" { + let bytes = base64::decode(data[1])?; + tokio::fs::write(dst, bytes).await?; + } else { + anyhow::bail!("Scheme format not supported") + } + } _ => anyhow::bail!("Scheme not supported"), } Ok(()) @@ -376,4 +385,18 @@ mod tests { std::fs::remove_dir_all(&base).unwrap(); } + + #[test] + fn test_data_scheme() { + let url = Url::parse("data:text/plain;base64,SGVsbG8sIFdvcmxkIQ==").unwrap(); + let dest = PathBuf::from("/tmp/input_test.txt"); + let info = HandleFileInfo::new(&dest, &url); + let req = FileAgentRequest::new(HandleFileCommand::Download, vec![info], ""); + + let bytes = serde_json::to_vec(&req).unwrap(); + handle_file_request(&bytes).unwrap(); + assert_eq!(std::fs::read_to_string(&dest).unwrap(), "Hello, World!"); + + std::fs::remove_file(&dest).unwrap(); + } } diff --git a/function/README.md b/function/README.md index d1bbe013c..c27adc0f8 100644 --- a/function/README.md +++ b/function/README.md @@ -28,6 +28,8 @@ Currently, we have these built-in functions: - `builtin-face-detection`: An implementation of Funnel-Structured cascade, which is designed for real-time multi-view face detection. - `builtin-principal-components-analysis`: Example to calculate PCA. + - `builtin-password-check`: Given a password, check whether it is in the + exposed password list. The function arguments are in JSON format and can be serialized to a Rust struct very easily. You can learn more about supported arguments in the implementation diff --git a/function/src/lib.rs b/function/src/lib.rs index 6bb6ce1b6..641512c47 100644 --- a/function/src/lib.rs +++ b/function/src/lib.rs @@ -30,6 +30,7 @@ mod logistic_regression_predict; mod logistic_regression_train; mod online_decrypt; mod ordered_set_intersect; +mod password_check; mod principal_components_analysis; mod private_join_and_compute; mod rsa_sign; @@ -42,6 +43,7 @@ pub use logistic_regression_predict::LogisticRegressionPredict; pub use logistic_regression_train::LogisticRegressionTrain; pub use online_decrypt::OnlineDecrypt; pub use ordered_set_intersect::OrderedSetIntersect; +pub use password_check::PasswordCheck; pub use principal_components_analysis::PrincipalComponentsAnalysis; pub use private_join_and_compute::PrivateJoinAndCompute; pub use rsa_sign::RsaSign; @@ -54,16 +56,17 @@ pub mod tests { pub fn run_tests() -> bool { check_all_passed!( echo::tests::run_tests(), - gbdt_train::tests::run_tests(), + face_detection::tests::run_tests(), gbdt_predict::tests::run_tests(), - logistic_regression_train::tests::run_tests(), + gbdt_train::tests::run_tests(), logistic_regression_predict::tests::run_tests(), + logistic_regression_train::tests::run_tests(), + password_check::tests::run_tests(), online_decrypt::tests::run_tests(), - private_join_and_compute::tests::run_tests(), ordered_set_intersect::tests::run_tests(), - rsa_sign::tests::run_tests(), - face_detection::tests::run_tests(), principal_components_analysis::tests::run_tests(), + private_join_and_compute::tests::run_tests(), + rsa_sign::tests::run_tests(), ) } } diff --git a/function/src/password_check.rs b/function/src/password_check.rs new file mode 100644 index 000000000..10d4652c0 --- /dev/null +++ b/function/src/password_check.rs @@ -0,0 +1,103 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::io::prelude::*; +use std::prelude::v1::*; + +use std::collections::HashSet; +use std::convert::TryFrom; +use std::io::BufReader; +use teaclave_types::{FunctionArguments, FunctionRuntime}; + +#[derive(Default)] +pub struct PasswordCheck; + +#[derive(serde::Deserialize)] +struct PasswordCheckArguments; + +impl TryFrom for PasswordCheckArguments { + type Error = anyhow::Error; + + fn try_from(arguments: FunctionArguments) -> Result { + use anyhow::Context; + serde_json::from_str(&arguments.into_string()).context("Cannot deserialize arguments") + } +} + +impl PasswordCheck { + pub const NAME: &'static str = "builtin-password-check"; + + pub fn new() -> Self { + Default::default() + } + + pub fn run(&self, _: FunctionArguments, runtime: FunctionRuntime) -> anyhow::Result { + let password_file = runtime.open_input("password")?; + let password = BufReader::new(password_file) + .lines() + .next() + .unwrap() + .unwrap() + .trim() + .to_owned(); + let exposed_passwords_file = runtime.open_input("exposed_passwords")?; + let exposed_passwords: HashSet = BufReader::new(exposed_passwords_file) + .lines() + .map(|l| l.expect("Could not parse line").trim().to_string()) + .collect::>() + .iter() + .cloned() + .collect(); + if exposed_passwords.contains(&password) { + Ok("true".to_string()) + } else { + Ok("false".to_string()) + } + } +} + +#[cfg(feature = "enclave_unit_test")] +pub mod tests { + use super::*; + use std::path::Path; + use teaclave_crypto::*; + use teaclave_runtime::*; + use teaclave_test_utils::*; + use teaclave_types::*; + + pub fn run_tests() -> bool { + run_tests!(test_password_check) + } + + fn test_password_check() { + let password_input = Path::new("fixtures/functions/password_check/password.txt"); + let exposed_passwords_input = + Path::new("fixtures/functions/password_check/exposed_passwords.txt"); + let arguments = FunctionArguments::default(); + + let input_files = StagedFiles::new(hashmap!( + "password" => StagedFileInfo::new(&password_input, TeaclaveFile128Key::random(), FileAuthTag::mock()), + "exposed_passwords" => StagedFileInfo::new(&exposed_passwords_input, TeaclaveFile128Key::random(), FileAuthTag::mock()), + )); + let output_files = StagedFiles::new(hashmap!()); + let runtime = Box::new(RawIoRuntime::new(input_files, output_files)); + + let result = PasswordCheck::new().run(arguments, runtime).unwrap(); + + assert_eq!(result, "true"); + } +} diff --git a/tests/fixtures/functions/password_check/exposed_passwords.txt b/tests/fixtures/functions/password_check/exposed_passwords.txt new file mode 100644 index 000000000..48ee88910 --- /dev/null +++ b/tests/fixtures/functions/password_check/exposed_passwords.txt @@ -0,0 +1,20 @@ +123456 +123456789 +qwerty +password +1111111 +12345678 +abc123 +1234567 +password1 +12345 +1234567890 +123123 +000000 +Iloveyou +1234 +1q2w3e4r5t +Qwertyuiop +123 +Monkey +Dragon diff --git a/tests/fixtures/functions/password_check/exposed_passwords.txt.enc b/tests/fixtures/functions/password_check/exposed_passwords.txt.enc new file mode 100644 index 000000000..af31d201c Binary files /dev/null and b/tests/fixtures/functions/password_check/exposed_passwords.txt.enc differ diff --git a/tests/fixtures/functions/password_check/password.txt b/tests/fixtures/functions/password_check/password.txt new file mode 100644 index 000000000..f3097ab13 --- /dev/null +++ b/tests/fixtures/functions/password_check/password.txt @@ -0,0 +1 @@ +password diff --git a/tests/scripts/simple_http_server.py b/tests/scripts/simple_http_server.py old mode 100644 new mode 100755 index 5a0cd0161..709285bee --- a/tests/scripts/simple_http_server.py +++ b/tests/scripts/simple_http_server.py @@ -23,7 +23,7 @@ def do_PUT(self): else: port = 6789 socketserver.TCPServer.allow_reuse_address = True - with socketserver.TCPServer(("localhost", port), + with socketserver.TCPServer(("0.0.0.0", port), HTTPRequestHandler) as httpd: print("serving at port", port) httpd.serve_forever()