diff --git a/common/rusty_leveldb_sgx/src/version_set.rs b/common/rusty_leveldb_sgx/src/version_set.rs index 50bb1b22d..371897625 100644 --- a/common/rusty_leveldb_sgx/src/version_set.rs +++ b/common/rusty_leveldb_sgx/src/version_set.rs @@ -1067,7 +1067,6 @@ pub mod tests { ve.delete_file(0, 2); ve.set_compact_pointer(2, LookupKey::new("xxx".as_bytes(), 123).internal_key()); - println!("XXXXXXXXXXXXXXXXXXXXXXXX"); let mut b = Builder::new(); let mut ptrs: [Vec; NUM_LEVELS] = Default::default(); b.apply(&ve, &mut ptrs); diff --git a/services/execution/enclave/src/service.rs b/services/execution/enclave/src/service.rs index 191ac16b9..afc37becb 100644 --- a/services/execution/enclave/src/service.rs +++ b/services/execution/enclave/src/service.rs @@ -193,7 +193,7 @@ fn prepare_task(task: &StagedTask) -> StagedFunction { let runtime_name = "default".to_string(); let executor_type = task.executor_type; - let function_name = task.native_func.clone(); + let executor_name = task.executor; let function_payload = String::from_utf8_lossy(&task.function_payload).to_string(); let function_arguments = task.function_arguments.clone(); @@ -241,7 +241,7 @@ fn prepare_task(task: &StagedTask) -> StagedFunction { let output_files = StagedFiles::new(output_file_map); StagedFunction::new() - .name(function_name) + .executor(executor_name) .payload(function_payload) .arguments(function_arguments) .input_files(input_files) @@ -262,11 +262,8 @@ pub mod tests { let task_id = Uuid::new_v4(); let staged_task = StagedTask::new() .task_id(task_id) - .native_func("echo") - .function_arguments(hashmap!( - "message" => "Hello, Teaclave!" - - )); + .executor(Executor::Echo) + .function_arguments(hashmap!("message" => "Hello, Teaclave!")); let invocation = prepare_task(&staged_task); @@ -313,7 +310,7 @@ pub mod tests { let staged_task = StagedTask::new() .task_id(task_id) - .native_func("gbdt_training") + .executor(Executor::GbdtTraining) .function_arguments(function_arguments) .input_data(input_data) .output_data(output_data); diff --git a/services/management/enclave/src/service.rs b/services/management/enclave/src/service.rs index 55177b5f5..f1a0a6603 100644 --- a/services/management/enclave/src/service.rs +++ b/services/management/enclave/src/service.rs @@ -281,8 +281,8 @@ impl TeaclaveManagement for TeaclaveManagementService { let mut task = Task::new( user_id, - &function, request.executor, + &function, request.function_arguments, request.input_owners_map, request.output_owners_map, @@ -486,7 +486,7 @@ impl TeaclaveManagement for TeaclaveManagementService { let staged_task = StagedTask::new() .task_id(task.task_id) - .native_func(task.executor.clone()) + .executor(task.executor) .executor_type(function.executor_type) .function_id(function.id) .function_payload(function.payload) @@ -685,8 +685,8 @@ pub mod tests { let task = Task::new( UserID::from("mock_user"), + Executor::MesaPy, &function, - "mesapy".to_string(), function_arguments, HashMap::new(), HashMap::new(), @@ -714,7 +714,7 @@ pub mod tests { let staged_task = StagedTask::new() .task_id(Uuid::new_v4()) - .native_func("mesapy".to_string()) + .executor(Executor::MesaPy) .function_payload(function.payload) .function_arguments(hashmap!("arg" => "data")) .input_data(hashmap!("input" => input_data)) diff --git a/services/proto/src/teaclave_frontend_service.rs b/services/proto/src/teaclave_frontend_service.rs index e4841bbf3..ca2aa9dc0 100644 --- a/services/proto/src/teaclave_frontend_service.rs +++ b/services/proto/src/teaclave_frontend_service.rs @@ -64,10 +64,8 @@ pub struct RegisterInputFileResponse { } impl RegisterInputFileResponse { - pub fn new(data_id: impl Into) -> Self { - Self { - data_id: data_id.into(), - } + pub fn new(data_id: ExternalID) -> Self { + Self { data_id } } } @@ -93,10 +91,8 @@ pub struct RegisterOutputFileResponse { } impl RegisterOutputFileResponse { - pub fn new(data_id: impl Into) -> Self { - Self { - data_id: data_id.into(), - } + pub fn new(data_id: ExternalID) -> Self { + Self { data_id } } } @@ -121,10 +117,8 @@ pub struct RegisterFusionOutputResponse { } impl RegisterFusionOutputResponse { - pub fn new(data_id: impl Into) -> Self { - Self { - data_id: data_id.into(), - } + pub fn new(data_id: ExternalID) -> Self { + Self { data_id } } } @@ -136,10 +130,8 @@ pub struct RegisterInputFromOutputRequest { } impl RegisterInputFromOutputRequest { - pub fn new(data_id: impl Into) -> Self { - Self { - data_id: data_id.into(), - } + pub fn new(data_id: ExternalID) -> Self { + Self { data_id } } } @@ -151,10 +143,8 @@ pub struct RegisterInputFromOutputResponse { } impl RegisterInputFromOutputResponse { - pub fn new(data_id: impl Into) -> Self { - Self { - data_id: data_id.into(), - } + pub fn new(data_id: ExternalID) -> Self { + Self { data_id } } } @@ -166,10 +156,8 @@ pub struct GetInputFileRequest { } impl GetInputFileRequest { - pub fn new(data_id: impl Into) -> Self { - Self { - data_id: data_id.into(), - } + pub fn new(data_id: ExternalID) -> Self { + Self { data_id } } } @@ -198,10 +186,8 @@ pub struct GetOutputFileRequest { } impl GetOutputFileRequest { - pub fn new(data_id: impl Into) -> Self { - Self { - data_id: data_id.into(), - } + pub fn new(data_id: ExternalID) -> Self { + Self { data_id } } } @@ -314,10 +300,8 @@ pub struct RegisterFunctionResponse { } impl RegisterFunctionResponse { - pub fn new(function_id: impl Into) -> Self { - Self { - function_id: function_id.into(), - } + pub fn new(function_id: ExternalID) -> Self { + Self { function_id } } } @@ -329,10 +313,8 @@ pub struct GetFunctionRequest { } impl GetFunctionRequest { - pub fn new(function_id: impl Into) -> Self { - Self { - function_id: function_id.into(), - } + pub fn new(function_id: ExternalID) -> Self { + Self { function_id } } } @@ -366,9 +348,9 @@ impl CreateTaskRequest { Self::default() } - pub fn function_id(self, function_id: impl Into) -> Self { + pub fn function_id(self, function_id: ExternalID) -> Self { Self { - function_id: function_id.into(), + function_id, ..self } } @@ -380,9 +362,9 @@ impl CreateTaskRequest { } } - pub fn executor(self, executor: impl ToString) -> Self { + pub fn executor(self, executor: impl Into) -> Self { Self { - executor: executor.to_string(), + executor: executor.into(), ..self } } @@ -409,10 +391,8 @@ pub struct CreateTaskResponse { } impl CreateTaskResponse { - pub fn new(task_id: impl Into) -> Self { - Self { - task_id: task_id.into(), - } + pub fn new(task_id: ExternalID) -> Self { + Self { task_id } } } @@ -424,10 +404,8 @@ pub struct GetTaskRequest { } impl GetTaskRequest { - pub fn new(task_id: impl Into) -> Self { - Self { - task_id: task_id.into(), - } + pub fn new(task_id: ExternalID) -> Self { + Self { task_id } } } @@ -461,12 +439,12 @@ pub struct AssignDataRequest { impl AssignDataRequest { pub fn new( - task_id: impl Into, + task_id: ExternalID, input_map: HashMap, output_map: HashMap, ) -> Self { Self { - task_id: task_id.into(), + task_id, input_map, output_map, } @@ -484,10 +462,8 @@ pub struct ApproveTaskRequest { } impl ApproveTaskRequest { - pub fn new(task_id: impl Into) -> Self { - Self { - task_id: task_id.into(), - } + pub fn new(task_id: ExternalID) -> Self { + Self { task_id } } } @@ -502,10 +478,8 @@ pub struct InvokeTaskRequest { } impl InvokeTaskRequest { - pub fn new(task_id: impl Into) -> Self { - Self { - task_id: task_id.into(), - } + pub fn new(task_id: ExternalID) -> Self { + Self { task_id } } } @@ -811,11 +785,12 @@ impl std::convert::TryFrom for RegisterFunctionR .into_iter() .map(FunctionOutput::try_from) .collect(); + let executor_type = proto.executor_type.try_into()?; let ret = Self { name: proto.name, description: proto.description, - executor_type: proto.executor_type.as_str().try_into()?, + executor_type, payload: proto.payload, public: proto.public, arguments: proto.arguments, @@ -904,12 +879,13 @@ impl std::convert::TryFrom for GetFunctionResponse { .into_iter() .map(FunctionOutput::try_from) .collect(); + let executor_type = proto.executor_type.try_into()?; let ret = Self { name: proto.name, description: proto.description, owner: proto.owner.into(), - executor_type: proto.executor_type.as_str().try_into()?, + executor_type, payload: proto.payload, public: proto.public, arguments: proto.arguments, @@ -981,11 +957,12 @@ impl std::convert::TryFrom for CreateTaskRequest { let input_owners_map = data_owner_map_from_proto(proto.input_owners_map)?; let output_owners_map = data_owner_map_from_proto(proto.output_owners_map)?; let function_id = proto.function_id.try_into()?; + let executor = proto.executor.try_into()?; let ret = Self { function_id, function_arguments, - executor: proto.executor, + executor, input_owners_map, output_owners_map, }; @@ -1002,7 +979,7 @@ impl From for proto::CreateTaskRequest { Self { function_id: request.function_id.to_string(), function_arguments, - executor: request.executor, + executor: request.executor.to_string(), input_owners_map, output_owners_map, } diff --git a/tests/functional/enclave/src/end_to_end/mesapy_echo.rs b/tests/functional/enclave/src/end_to_end/mesapy_echo.rs index 835426327..57af5a51b 100644 --- a/tests/functional/enclave/src/end_to_end/mesapy_echo.rs +++ b/tests/functional/enclave/src/end_to_end/mesapy_echo.rs @@ -48,7 +48,7 @@ def entrypoint(argv): let request = CreateTaskRequest::new() .function_id(function_id) .function_arguments(hashmap!("message" => "Hello From Teaclave!")) - .executor("mesapy"); + .executor(Executor::MesaPy); let response = client.create_task(request).unwrap(); @@ -66,20 +66,20 @@ def entrypoint(argv): log::info!("Assign data: {:?}", response); // Approve Task - let request = ApproveTaskRequest::new(&task_id); + let request = ApproveTaskRequest::new(task_id.clone()); let response = client.approve_task(request).unwrap(); log::info!("Approve task: {:?}", response); // Invoke Task - let request = InvokeTaskRequest::new(&task_id); + let request = InvokeTaskRequest::new(task_id.clone()); let response = client.invoke_task(request).unwrap(); log::info!("Invoke task: {:?}", response); // Get Task loop { - let request = GetTaskRequest::new(&task_id); + let request = GetTaskRequest::new(task_id.clone()); let response = client.get_task(request).unwrap(); log::info!("Get task: {:?}", response); std::thread::sleep(std::time::Duration::from_secs(1)); diff --git a/tests/functional/enclave/src/end_to_end/native_echo.rs b/tests/functional/enclave/src/end_to_end/native_echo.rs index 068e7f43b..7ff1d296d 100644 --- a/tests/functional/enclave/src/end_to_end/native_echo.rs +++ b/tests/functional/enclave/src/end_to_end/native_echo.rs @@ -41,7 +41,7 @@ pub fn test_echo_task_success() { let request = CreateTaskRequest::new() .function_id(function_id) .function_arguments(hashmap!("message" => "Hello From Teaclave!")) - .executor("echo"); + .executor(Executor::Echo); let response = client.create_task(request).unwrap(); @@ -59,20 +59,20 @@ pub fn test_echo_task_success() { log::info!("Assign data: {:?}", response); // Approve Task - let request = ApproveTaskRequest::new(&task_id); + let request = ApproveTaskRequest::new(task_id.clone()); let response = client.approve_task(request).unwrap(); log::info!("Approve task: {:?}", response); // Invoke Task - let request = InvokeTaskRequest::new(&task_id); + let request = InvokeTaskRequest::new(task_id.clone()); let response = client.invoke_task(request).unwrap(); log::info!("Invoke task: {:?}", response); // Get Task loop { - let request = GetTaskRequest::new(&task_id); + let request = GetTaskRequest::new(task_id.clone()); let response = client.get_task(request).unwrap(); log::info!("Get task: {:?}", response); std::thread::sleep(std::time::Duration::from_secs(1)); diff --git a/tests/functional/enclave/src/execution_service.rs b/tests/functional/enclave/src/execution_service.rs index 5caccec83..0f8d68def 100644 --- a/tests/functional/enclave/src/execution_service.rs +++ b/tests/functional/enclave/src/execution_service.rs @@ -36,12 +36,11 @@ fn test_execute_function() { }; let function_id = Uuid::new_v4(); - let native_func = "echo"; let staged_task = StagedTask::new() .task_id(task_id) .function_id(function_id.clone()) - .native_func(native_func) + .executor(Executor::Echo) .function_arguments(hashmap!( "message" => "Hello, Teaclave Tests!" )); diff --git a/tests/functional/enclave/src/frontend_service.rs b/tests/functional/enclave/src/frontend_service.rs index 8f18b07f3..9b7a8a721 100644 --- a/tests/functional/enclave/src/frontend_service.rs +++ b/tests/functional/enclave/src/frontend_service.rs @@ -185,15 +185,14 @@ fn test_get_output_file() { }; let mut client = get_client(); - let response = client.register_output_file(request); - let data_id = response.unwrap().data_id; + let response = client.register_output_file(request).unwrap(); + let data_id = response.data_id; - let request = GetOutputFileRequest::new(&data_id); - let response = client.get_output_file(request); - assert!(response.is_ok()); - assert!(response.unwrap().hash.is_empty()); + let request = GetOutputFileRequest::new(data_id.clone()); + let response = client.get_output_file(request).unwrap(); + assert!(response.hash.is_empty()); - let request = GetOutputFileRequest::new(&data_id); + let request = GetOutputFileRequest::new(data_id); client .metadata_mut() .insert("token".to_string(), "wrong token".to_string()); @@ -209,15 +208,14 @@ fn test_get_input_file() { }; let mut client = get_client(); - let response = client.register_input_file(request); - let data_id = response.unwrap().data_id; + let response = client.register_input_file(request).unwrap(); + let data_id = response.data_id; - let request = GetInputFileRequest::new(&data_id); - let response = client.get_input_file(request); - assert!(response.is_ok()); - assert!(!response.unwrap().hash.is_empty()); + let request = GetInputFileRequest::new(data_id.clone()); + let response = client.get_input_file(request).unwrap(); + assert!(!response.hash.is_empty()); - let request = GetInputFileRequest::new(&data_id); + let request = GetInputFileRequest::new(data_id); client .metadata_mut() .insert("token".to_string(), "wrong token".to_string()); @@ -260,24 +258,22 @@ fn test_get_function() { fn test_create_task() { let mut client = get_client(); - let data_owner_id_list = OwnerList::new(vec!["frontend_user", "mock_user"]); - let function_id = ExternalID::try_from("function-00000000-0000-0000-0000-000000000002").unwrap(); let request = CreateTaskRequest::new() .function_id(function_id.clone()) .function_arguments(hashmap!("arg1" => "data1")) - .executor("mesapy") - .output_owners_map(hashmap!("output" => data_owner_id_list.clone())); + .executor(Executor::MesaPy) + .output_owners_map(hashmap!("output" => vec!["frontend_user", "mock_user"])); let response = client.create_task(request); assert!(response.is_ok()); let request = CreateTaskRequest::new() .function_id(function_id) .function_arguments(hashmap!("arg1" => "data1")) - .executor("mesapy") - .output_owners_map(hashmap!("output" => data_owner_id_list)); + .executor(Executor::MesaPy) + .output_owners_map(hashmap!("output" => vec!["frontend_user", "mock_user"])); client .metadata_mut() .insert("token".to_string(), "wrong token".to_string()); @@ -289,17 +285,16 @@ fn test_get_task() { let mut client = get_client(); let function_id = ExternalID::try_from("function-00000000-0000-0000-0000-000000000002").unwrap(); - let data_owner_id_list = OwnerList::new(vec!["frontend_user", "mock_user"]); let request = CreateTaskRequest::new() .function_id(function_id) .function_arguments(hashmap!("arg1" => "data1")) - .executor("mesapy") - .output_owners_map(hashmap!("output" => data_owner_id_list)); + .executor(Executor::MesaPy) + .output_owners_map(hashmap!("output" => vec!["frontend_user", "mock_user"])); let response = client.create_task(request).unwrap(); let task_id = response.task_id; - let request = GetTaskRequest::new(&task_id); + let request = GetTaskRequest::new(task_id.clone()); let response = client.get_task(request); assert!(response.is_ok()); @@ -316,13 +311,12 @@ fn test_assign_data() { let function_id = ExternalID::try_from("function-00000000-0000-0000-0000-000000000002").unwrap(); - let data_owner_id_list = OwnerList::new(vec!["frontend_user"]); let request = CreateTaskRequest::new() .function_id(function_id) .function_arguments(hashmap!("arg1" => "data1")) - .executor("mesapy") - .output_owners_map(hashmap!("output" => data_owner_id_list)); + .executor(Executor::MesaPy) + .output_owners_map(hashmap!("output" => vec!["frontend_user"])); let response = client.create_task(request).unwrap(); let task_id = response.task_id; @@ -349,9 +343,7 @@ fn test_assign_data() { let request = AssignDataRequest { task_id, input_map: HashMap::new(), - output_map: vec![("output".to_string(), output_id)] - .into_iter() - .collect(), + output_map: hashmap!("output" => output_id), }; client .metadata_mut() @@ -368,11 +360,11 @@ fn test_approve_task() { let request = CreateTaskRequest::new() .function_id(function_id) .function_arguments(hashmap!("arg1" => "data1")) - .executor("mesapy") - .output_owners_map(hashmap!("output" => OwnerList::new(vec!["frontend_user"]))); + .executor(Executor::MesaPy) + .output_owners_map(hashmap!("output" => vec!["frontend_user"])); - let response = client.create_task(request); - let task_id = response.unwrap().task_id; + let response = client.create_task(request).unwrap(); + let task_id = response.task_id; let request = RegisterOutputFileRequest { url: Url::parse("s3://s3.us-west-2.amazonaws.com/mybucket/puppy.jpg.enc?key-id=deadbeefdeadbeef&key=deadbeefdeadbeef").unwrap(), @@ -390,7 +382,7 @@ fn test_approve_task() { client.assign_data(request).unwrap(); - let request = ApproveTaskRequest::new(&task_id); + let request = ApproveTaskRequest::new(task_id.clone()); let correct_token = client.metadata().get("token").unwrap().to_string(); client .metadata_mut() @@ -398,7 +390,7 @@ fn test_approve_task() { let response = client.approve_task(request); assert!(response.is_err()); - let request = ApproveTaskRequest::new(&task_id); + let request = ApproveTaskRequest::new(task_id); client .metadata_mut() .insert("token".to_string(), correct_token); @@ -413,11 +405,8 @@ fn test_invoke_task() { let request = CreateTaskRequest::new() .function_id(function_id) .function_arguments(hashmap!("arg1" => "data1")) - .executor("mesapy") - .output_owners_map(hashmap!( - "output" => - OwnerList::new(vec!["frontend_user"]) - )); + .executor(Executor::MesaPy) + .output_owners_map(hashmap!("output" => vec!["frontend_user"])); let response = client.create_task(request).unwrap(); let task_id = response.task_id; @@ -437,10 +426,10 @@ fn test_invoke_task() { client.assign_data(request).unwrap(); - let request = ApproveTaskRequest::new(&task_id); + let request = ApproveTaskRequest::new(task_id.clone()); client.approve_task(request).unwrap(); - let request = InvokeTaskRequest::new(&task_id); + let request = InvokeTaskRequest::new(task_id.clone()); let correct_token = client.metadata().get("token").unwrap().to_string(); client .metadata_mut() @@ -448,14 +437,14 @@ fn test_invoke_task() { let response = client.invoke_task(request); assert!(response.is_err()); - let request = InvokeTaskRequest::new(&task_id); + let request = InvokeTaskRequest::new(task_id.clone()); client .metadata_mut() .insert("token".to_string(), correct_token); let response = client.invoke_task(request); assert!(response.is_ok()); - let request = GetTaskRequest::new(&task_id); + let request = GetTaskRequest::new(task_id); let response = client.get_task(request).unwrap(); assert_eq!(response.status, TaskStatus::Running); diff --git a/tests/functional/enclave/src/management_service.rs b/tests/functional/enclave/src/management_service.rs index 921d841c6..95eff5499 100644 --- a/tests/functional/enclave/src/management_service.rs +++ b/tests/functional/enclave/src/management_service.rs @@ -180,11 +180,11 @@ fn test_get_output_file() { let mut client = get_client("mock_user"); let response = client.register_output_file(request).unwrap(); let data_id = response.data_id; - let request = GetOutputFileRequest::new(&data_id); + let request = GetOutputFileRequest::new(data_id.clone()); let response = client.get_output_file(request); assert!(response.is_ok()); let mut client = get_client("mock_another_user"); - let request = GetOutputFileRequest::new(&data_id); + let request = GetOutputFileRequest::new(data_id); let response = client.get_output_file(request); assert!(response.is_err()); } @@ -197,11 +197,11 @@ fn test_get_input_file() { let mut client = get_client("mock_user"); let response = client.register_input_file(request).unwrap(); let data_id = response.data_id; - let request = GetInputFileRequest::new(&data_id); + let request = GetInputFileRequest::new(data_id.clone()); let response = client.get_input_file(request); assert!(response.is_ok()); let mut client = get_client("mock_another_user"); - let request = GetInputFileRequest::new(&data_id); + let request = GetInputFileRequest::new(data_id); let response = client.get_input_file(request); assert!(response.is_err()); } @@ -282,7 +282,7 @@ fn get_correct_create_task() -> CreateTaskRequest { CreateTaskRequest { function_id, function_arguments, - executor: "mesapy".to_string(), + executor: Executor::MesaPy, input_owners_map, output_owners_map, } @@ -290,11 +290,8 @@ fn get_correct_create_task() -> CreateTaskRequest { fn test_create_task() { let request = CreateTaskRequest { - function_id: ExternalID::default(), - function_arguments: HashMap::new().into(), - executor: "mesapy".to_string(), - input_owners_map: HashMap::new(), - output_owners_map: HashMap::new(), + executor: Executor::MesaPy, + ..Default::default() }; let mut client = get_client("mock_user"); let response = client.create_task(request); @@ -405,9 +402,8 @@ fn test_assign_data() { let request = GetOutputFileRequest { data_id: existing_outfile_id_user1.clone(), }; - let response = client1.get_output_file(request); - assert!(response.is_ok()); - assert!(!response.unwrap().hash.is_empty()); + let response = client1.get_output_file(request).unwrap(); + assert!(!response.hash.is_empty()); let mut request = AssignDataRequest { task_id: task_id.clone(), input_map: HashMap::new(), @@ -651,30 +647,30 @@ fn test_approve_task() { let response = client3.assign_data(request); assert!(response.is_ok()); - let request = GetTaskRequest::new(&task_id); + let request = GetTaskRequest::new(task_id.clone()); let response = client2.get_task(request); assert_eq!(response.unwrap().status, TaskStatus::Ready); // user_id not in task.participants let mut unknown_client = get_client("non-participant"); - let request = ApproveTaskRequest::new(&task_id); + let request = ApproveTaskRequest::new(task_id.clone()); let response = unknown_client.approve_task(request); assert!(response.is_err()); //all participants approve the task - let request = ApproveTaskRequest::new(&task_id); + let request = ApproveTaskRequest::new(task_id.clone()); let response = client.approve_task(request); assert!(response.is_ok()); - let request = ApproveTaskRequest::new(&task_id); + let request = ApproveTaskRequest::new(task_id.clone()); let response = client1.approve_task(request); assert!(response.is_ok()); - let request = ApproveTaskRequest::new(&task_id); + let request = ApproveTaskRequest::new(task_id.clone()); let response = client2.approve_task(request); assert!(response.is_ok()); - let request = ApproveTaskRequest::new(&task_id); + let request = ApproveTaskRequest::new(task_id.clone()); let response = client3.approve_task(request); assert!(response.is_ok()); - let request = GetTaskRequest::new(&task_id); + let request = GetTaskRequest::new(task_id); let response = client2.get_task(request); assert_eq!(response.unwrap().status, TaskStatus::Approved); } @@ -744,33 +740,33 @@ fn test_invoke_task() { assert!(response.is_ok()); // task status != Approved - let request = InvokeTaskRequest::new(&task_id); + let request = InvokeTaskRequest::new(task_id.clone()); let response = client.invoke_task(request); assert!(response.is_err()); //all participants approve the task - let request = ApproveTaskRequest::new(&task_id); + let request = ApproveTaskRequest::new(task_id.clone()); client.approve_task(request).unwrap(); - let request = ApproveTaskRequest::new(&task_id); + let request = ApproveTaskRequest::new(task_id.clone()); client1.approve_task(request).unwrap(); - let request = ApproveTaskRequest::new(&task_id); + let request = ApproveTaskRequest::new(task_id.clone()); client2.approve_task(request).unwrap(); - let request = ApproveTaskRequest::new(&task_id); + let request = ApproveTaskRequest::new(task_id.clone()); client3.approve_task(request).unwrap(); - let request = GetTaskRequest::new(&task_id); + let request = GetTaskRequest::new(task_id.clone()); let response = client2.get_task(request).unwrap(); assert_eq!(response.status, TaskStatus::Approved); // user_id != task.creator - let request = InvokeTaskRequest::new(&task_id); + let request = InvokeTaskRequest::new(task_id.clone()); let response = client2.invoke_task(request); assert!(response.is_err()); // invoke task - let request = InvokeTaskRequest::new(&task_id); + let request = InvokeTaskRequest::new(task_id.clone()); client.invoke_task(request).unwrap(); - let request = GetTaskRequest::new(&task_id); + let request = GetTaskRequest::new(task_id); let response = client2.get_task(request).unwrap(); assert_eq!(response.status, TaskStatus::Running); diff --git a/tests/functional/enclave/src/scheduler_service.rs b/tests/functional/enclave/src/scheduler_service.rs index 10050cef6..2fe2f14ce 100644 --- a/tests/functional/enclave/src/scheduler_service.rs +++ b/tests/functional/enclave/src/scheduler_service.rs @@ -31,14 +31,11 @@ pub fn run_tests() -> bool { } fn test_pull_task() { - let task_id = Uuid::new_v4(); let function_id = Uuid::new_v4(); - let native_func = "echo"; - let staged_task = StagedTask::new() - .task_id(task_id) + .task_id(Uuid::new_v4()) .function_id(function_id.clone()) - .native_func(native_func); + .executor(Executor::Echo); let mut storage_client = get_storage_client(); let enqueue_request = EnqueueRequest::new( @@ -64,12 +61,11 @@ fn test_update_task_status_result() { }; let function_id = Uuid::new_v4(); - let native_func = "echo"; let staged_task = StagedTask::new() .task_id(task_id.clone()) .function_id(function_id) - .native_func(native_func); + .executor(Executor::Echo); let mut storage_client = get_storage_client(); let enqueue_request = EnqueueRequest::new( diff --git a/tests/integration/enclave/src/teaclave_worker.rs b/tests/integration/enclave/src/teaclave_worker.rs index 6d763bc97..383bbaa11 100644 --- a/tests/integration/enclave/src/teaclave_worker.rs +++ b/tests/integration/enclave/src/teaclave_worker.rs @@ -18,8 +18,8 @@ use std::prelude::v1::*; use teaclave_types::{ - hashmap, read_all_bytes, ExecutorType, FunctionArguments, StagedFileInfo, StagedFiles, - StagedFunction, TeaclaveFile128Key, + hashmap, read_all_bytes, Executor, ExecutorType, FunctionArguments, StagedFileInfo, + StagedFiles, StagedFunction, TeaclaveFile128Key, }; use teaclave_worker::Worker; @@ -51,12 +51,12 @@ fn test_start_worker() { "trained_model" => output_info.clone())); let staged_function = StagedFunction::new() - .name("gbdt_training") + .executor_type(ExecutorType::Native) + .executor(Executor::GbdtTraining) .arguments(arguments) .input_files(input_files) .output_files(output_files) - .runtime_name("default") - .executor_type(ExecutorType::Native); + .runtime_name("default"); let worker = Worker::default(); diff --git a/types/src/staged_function.rs b/types/src/staged_function.rs index 756cfeeab..258fff4ef 100644 --- a/types/src/staged_function.rs +++ b/types/src/staged_function.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use crate::{ExecutorType, StagedFiles}; +use crate::{Executor, ExecutorType, StagedFiles}; use serde::{Deserialize, Serialize}; use std::collections::HashMap; @@ -155,7 +155,7 @@ impl FunctionArguments { #[derive(Debug, Default)] pub struct StagedFunction { - pub name: String, + pub executor: Executor, pub payload: String, pub arguments: FunctionArguments, pub input_files: StagedFiles, @@ -169,11 +169,8 @@ impl StagedFunction { Self::default() } - pub fn name(self, name: impl ToString) -> Self { - Self { - name: name.to_string(), - ..self - } + pub fn executor(self, executor: Executor) -> Self { + Self { executor, ..self } } pub fn payload(self, payload: impl ToString) -> Self { diff --git a/types/src/staged_task.rs b/types/src/staged_task.rs index 5081329c2..7bee29bfe 100644 --- a/types/src/staged_task.rs +++ b/types/src/staged_task.rs @@ -23,7 +23,8 @@ use url::Url; use uuid::Uuid; use crate::{ - ExecutorType, FileCrypto, FunctionArguments, Storable, TeaclaveInputFile, TeaclaveOutputFile, + Executor, ExecutorType, FileCrypto, FunctionArguments, Storable, TeaclaveInputFile, + TeaclaveOutputFile, }; const STAGED_TASK_PREFIX: &str = "staged-"; // staged-task-uuid @@ -80,7 +81,7 @@ impl FunctionOutputFile { pub struct StagedTask { pub task_id: Uuid, pub function_id: Uuid, - pub native_func: String, + pub executor: Executor, pub executor_type: ExecutorType, pub function_payload: Vec, pub function_arguments: FunctionArguments, @@ -114,11 +115,8 @@ impl StagedTask { } } - pub fn native_func(self, native_func: impl ToString) -> Self { - Self { - native_func: native_func.to_string(), - ..self - } + pub fn executor(self, executor: Executor) -> Self { + Self { executor, ..self } } pub fn function_payload(self, function_payload: Vec) -> Self { diff --git a/types/src/task.rs b/types/src/task.rs index 20194647a..e47353042 100644 --- a/types/src/task.rs +++ b/types/src/task.rs @@ -21,8 +21,36 @@ use crate::*; use anyhow::{anyhow, bail, ensure, Result}; use serde::{Deserialize, Serialize}; use std::collections::{HashMap, HashSet}; +use std::convert::TryInto; use uuid::Uuid; +#[derive(Debug, Default, Clone, Deserialize, PartialEq, Eq, Hash, Serialize)] +pub struct UserID(String); + +impl std::convert::From for UserID { + fn from(uid: String) -> UserID { + UserID(uid) + } +} + +impl std::convert::From<&str> for UserID { + fn from(uid: &str) -> UserID { + UserID(uid.to_string()) + } +} + +impl std::convert::From for String { + fn from(user_id: UserID) -> String { + user_id.to_string() + } +} + +impl std::fmt::Display for UserID { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + pub type UserList = OwnerList; #[derive(Debug, Deserialize, Serialize, Clone, Default, PartialEq)] @@ -111,37 +139,6 @@ impl Default for TaskStatus { } } -const TASK_PREFIX: &str = "task"; - -#[derive(Debug, Default, Deserialize, Serialize)] -pub struct Task { - pub task_id: Uuid, - pub creator: UserID, - pub function_id: ExternalID, - pub function_arguments: FunctionArguments, - pub executor: Executor, - pub input_owners_map: HashMap, - pub output_owners_map: HashMap, - pub function_owner: UserID, - pub participants: UserList, - pub approved_users: UserList, - pub input_map: HashMap, - pub output_map: HashMap, - pub return_value: Option>, - pub output_file_hash: HashMap, - pub status: TaskStatus, -} - -impl Storable for Task { - fn key_prefix() -> &'static str { - TASK_PREFIX - } - - fn uuid(&self) -> Uuid { - self.task_id - } -} - #[derive(Debug, Clone, PartialEq, Default, Deserialize, Serialize)] pub struct ExternalID { pub prefix: String, @@ -189,7 +186,6 @@ impl std::convert::TryFrom<&str> for ExternalID { } } -use std::convert::TryInto; impl std::convert::TryFrom for ExternalID { type Error = anyhow::Error; fn try_from(ext_id: String) -> Result { @@ -197,40 +193,42 @@ impl std::convert::TryFrom for ExternalID { } } -#[derive(Debug, Default, Clone, Deserialize, PartialEq, Eq, Hash, Serialize)] -pub struct UserID(String); - -impl std::convert::From for UserID { - fn from(uid: String) -> UserID { - UserID(uid) - } -} +const TASK_PREFIX: &str = "task"; -impl std::convert::From<&str> for UserID { - fn from(uid: &str) -> UserID { - UserID(uid.to_string()) - } +#[derive(Debug, Default, Deserialize, Serialize)] +pub struct Task { + pub task_id: Uuid, + pub creator: UserID, + pub function_id: ExternalID, + pub function_arguments: FunctionArguments, + pub executor: Executor, + pub input_owners_map: HashMap, + pub output_owners_map: HashMap, + pub function_owner: UserID, + pub participants: UserList, + pub approved_users: UserList, + pub input_map: HashMap, + pub output_map: HashMap, + pub return_value: Option>, + pub output_file_hash: HashMap, + pub status: TaskStatus, } -impl std::convert::From for String { - fn from(user_id: UserID) -> String { - user_id.to_string() +impl Storable for Task { + fn key_prefix() -> &'static str { + TASK_PREFIX } -} -impl std::fmt::Display for UserID { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.0) + fn uuid(&self) -> Uuid { + self.task_id } } -pub type Executor = String; - impl Task { pub fn new( requester: UserID, - function: &Function, executor: Executor, + function: &Function, function_arguments: FunctionArguments, input_owners_map: HashMap, output_owners_map: HashMap, diff --git a/types/src/worker.rs b/types/src/worker.rs index faf716f61..28862ee24 100644 --- a/types/src/worker.rs +++ b/types/src/worker.rs @@ -19,6 +19,7 @@ use crate::FunctionArguments; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::collections::HashSet; +use std::convert::TryInto; use std::io; use std::prelude::v1::*; @@ -35,7 +36,7 @@ pub trait TeaclaveFunction { ) -> anyhow::Result; } -#[derive(Debug, Copy, Clone, Deserialize, Serialize)] +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Deserialize, Serialize)] pub enum ExecutorType { Native, Python, @@ -51,12 +52,20 @@ impl std::convert::TryFrom<&str> for ExecutorType { type Error = anyhow::Error; fn try_from(selector: &str) -> anyhow::Result { - let sel = match selector { + let executor_type = match selector { "python" => ExecutorType::Python, "native" | "platform" => ExecutorType::Native, - _ => anyhow::bail!("Invalid executor selector: {}", selector), + _ => anyhow::bail!("Invalid executor type: {}", selector), }; - Ok(sel) + Ok(executor_type) + } +} + +impl std::convert::TryFrom for ExecutorType { + type Error = anyhow::Error; + + fn try_from(selector: String) -> anyhow::Result { + selector.as_str().try_into() } } @@ -75,6 +84,60 @@ impl std::fmt::Display for ExecutorType { } } +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Deserialize, Serialize)] +pub enum Executor { + MesaPy, + GbdtTraining, + GbdtPrediction, + LogitRegTraining, + LogitRegPrediction, + Echo, +} + +impl std::default::Default for Executor { + fn default() -> Self { + Executor::MesaPy + } +} + +impl std::convert::TryFrom<&str> for Executor { + type Error = anyhow::Error; + + fn try_from(selector: &str) -> anyhow::Result { + let executor = match selector { + "mesapy" => Executor::MesaPy, + "echo" => Executor::Echo, + "gbdt_training" => Executor::GbdtTraining, + "gbdt_prediction" => Executor::GbdtPrediction, + "logistic_regression_training" => Executor::LogitRegTraining, + "logistic_regression_prediction" => Executor::LogitRegPrediction, + _ => anyhow::bail!("Unsupported executor: {}", selector), + }; + Ok(executor) + } +} + +impl std::convert::TryFrom for Executor { + type Error = anyhow::Error; + + fn try_from(selector: String) -> anyhow::Result { + selector.as_str().try_into() + } +} + +impl std::fmt::Display for Executor { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + Executor::MesaPy => write!(f, "mesapy"), + Executor::Echo => write!(f, "echo"), + Executor::GbdtTraining => write!(f, "gbdt_training"), + Executor::GbdtPrediction => write!(f, "gbdt_prediction"), + Executor::LogitRegTraining => write!(f, "logistic_regression_training"), + Executor::LogitRegPrediction => write!(f, "logistic_regression_prediction"), + } + } +} + #[derive(Debug)] pub struct WorkerCapability { pub runtimes: HashSet, diff --git a/worker/src/worker.rs b/worker/src/worker.rs index 8be51242e..c08e7cc20 100644 --- a/worker/src/worker.rs +++ b/worker/src/worker.rs @@ -22,7 +22,8 @@ use std::collections::HashMap; use std::format; use teaclave_types::{ - hashmap, ExecutorType, FunctionArguments, StagedFiles, StagedFunction, WorkerCapability, + hashmap, Executor, ExecutorType, FunctionArguments, StagedFiles, StagedFunction, + WorkerCapability, }; use teaclave_function as function; @@ -30,11 +31,11 @@ use teaclave_runtime as runtime; use teaclave_types::{TeaclaveFunction, TeaclaveRuntime}; macro_rules! register_functions{ - ($($name: expr => ($executor: expr, $fn_type: ty),)*) => {{ - let mut functions: HashMap = HashMap::new(); + ($(($executor_type: expr, $executor_name: expr) => $fn_type: ty,)*) => {{ + let mut functions: HashMap<(ExecutorType, Executor), FunctionBuilder> = HashMap::new(); $( functions.insert( - make_function_identifier($executor, $name), + ($executor_type, $executor_name), Box::new(|| Box::new(<$fn_type>::default())), ); )* @@ -44,24 +45,27 @@ macro_rules! register_functions{ pub struct Worker { runtimes: HashMap, - functions: HashMap, + functions: HashMap<(ExecutorType, Executor), FunctionBuilder>, } impl Worker { pub fn default() -> Worker { Worker { functions: register_functions!( - "gbdt_training" => (ExecutorType::Native, function::GbdtTraining), - "gbdt_prediction" => (ExecutorType::Native, function::GbdtPrediction), - "echo" => (ExecutorType::Native, function::Echo), - "mesapy" => (ExecutorType::Python, function::Mesapy), + (ExecutorType::Python, Executor::MesaPy) => function::Mesapy, + (ExecutorType::Native, Executor::Echo) => function::Echo, + (ExecutorType::Native, Executor::GbdtTraining) => function::GbdtTraining, + (ExecutorType::Native, Executor::GbdtPrediction) => function::GbdtPrediction, + (ExecutorType::Native, Executor::LogitRegTraining) => function::LogitRegTraining, + (ExecutorType::Native, Executor::LogitRegPrediction) => function::LogitRegPrediction, ), runtimes: setup_runtimes(), } } pub fn invoke_function(&self, staged_function: StagedFunction) -> anyhow::Result { - let function = self.get_function(staged_function.executor_type, &staged_function.name)?; + let function = + self.get_function(staged_function.executor_type, staged_function.executor)?; let runtime = self.get_runtime( &staged_function.runtime_name, staged_function.input_files, @@ -78,7 +82,12 @@ impl Worker { pub fn get_capability(&self) -> WorkerCapability { WorkerCapability { runtimes: self.runtimes.keys().cloned().collect(), - functions: self.functions.keys().cloned().collect(), + functions: self + .functions + .keys() + .cloned() + .map(|(exec_type, exec_name)| make_function_identifier(exec_type, exec_name)) + .collect(), } } @@ -99,23 +108,22 @@ impl Worker { fn get_function( &self, - func_type: ExecutorType, - func_name: &str, + exec_type: ExecutorType, + exec_name: Executor, ) -> anyhow::Result> { - let identifier = make_function_identifier(func_type, func_name); + let identifier = (exec_type, exec_name); let build_function = self .functions .get(&identifier) - .ok_or_else(|| anyhow::anyhow!(format!("function not available: {}", identifier)))?; + .ok_or_else(|| anyhow::anyhow!(format!("function not available: {:?}", identifier)))?; let function = build_function(); Ok(function) } } -fn make_function_identifier(func_type: ExecutorType, func_name: &str) -> String { - let type_str = func_type.to_string(); - format!("{}-{}", type_str, func_name) +fn make_function_identifier(exec_type: ExecutorType, exec_name: Executor) -> String { + format!("{}-{}", exec_type, exec_name) } fn setup_runtimes() -> HashMap {