Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enclave: Remove all panickings in RA #506

Merged
merged 1 commit into from Sep 29, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
99 changes: 44 additions & 55 deletions standalone/pruntime/enclave/src/pal_sgx.rs
Expand Up @@ -3,8 +3,7 @@ use std::os::unix::prelude::OsStrExt as _;
use std::str;
use std::time::Duration;

use anyhow::anyhow;
use anyhow::Result;
use anyhow::{anyhow, Context, Result};
use http_req::request::{Method, Request};
use log::{error, info, warn};
use rand::RngCore as _;
Expand Down Expand Up @@ -181,7 +180,7 @@ extern "C" {
) -> sgx_status_t;
}

fn ias_spid() -> sgx_spid_t {
fn ias_spid() -> Result<sgx_spid_t> {
// Try load persisted sealed data
let mut key_buf = vec![0; 256].into_boxed_slice();
let mut key_len: usize = 0;
Expand All @@ -193,30 +192,30 @@ fn ias_spid() -> sgx_spid_t {
let load_result = unsafe { ocall_load_ias_spid(&mut retval, key_ptr, key_len_ptr, 256) };

if load_result != sgx_status_t::SGX_SUCCESS || key_len == 0 {
panic!("Load SPID failure.");
return Err(anyhow!("Load SPID failure."));
}

let key_str = str::from_utf8(key_slice).unwrap();
let key_str = str::from_utf8(key_slice).context("UTF8 decode key_slice")?;
// println!("IAS SPID: {}", key_str.to_owned());

decode_spid(&key_str[..key_len])
}

fn decode_spid(raw_hex: &str) -> sgx_spid_t {
fn decode_spid(raw_hex: &str) -> Result<sgx_spid_t> {
let mut spid = sgx_spid_t::default();
let raw_hex = raw_hex.trim();

if raw_hex.len() < 16 * 2 {
log::warn!("Input spid file len ({}) is incorrect!", raw_hex.len());
return spid;
return Ok(spid);
}

let decoded_vec = hex::decode(raw_hex).expect("Failed to decode SPID hex");
let decoded_vec = hex::decode(raw_hex).context("Failed to decode SPID hex")?;
spid.id.copy_from_slice(&decoded_vec[..16]);
spid
Ok(spid)
}

fn ias_key() -> String {
fn ias_key() -> Result<String> {
// Try load persisted sealed data
let mut key_buf = vec![0; 256].into_boxed_slice();
let mut key_len: usize = 0;
Expand All @@ -227,16 +226,16 @@ fn ias_key() -> String {
let mut retval = sgx_status_t::SGX_SUCCESS;
let load_result = unsafe { ocall_load_ias_key(&mut retval, key_ptr, key_len_ptr, 256) };
if load_result != sgx_status_t::SGX_SUCCESS || key_len == 0 {
panic!("Load IAS KEY failure.");
return Err(anyhow!("Load IAS KEY failure."));
}

let key_str = str::from_utf8(key_slice).unwrap();
let key_str = str::from_utf8(key_slice).context("UTF8 decode key_slice")?;
// println!("IAS KEY: {}", key_str.to_owned());

key_str[..key_len].to_owned()
Ok(key_str[..key_len].to_owned())
}

pub fn get_sigrl_from_intel(gid: u32) -> Vec<u8> {
pub fn get_sigrl_from_intel(gid: u32) -> Result<Vec<u8>> {
// println!("get_sigrl_from_intel fd = {:?}", fd);
//let sigrl_arg = SigRLArg { group_id : gid };
//let sigrl_req = sigrl_arg.to_httpreq();
Expand All @@ -245,15 +244,15 @@ pub fn get_sigrl_from_intel(gid: u32) -> Vec<u8> {
let timeout = Some(Duration::from_secs(8));

let url = format!("https://{}{}/{:08x}", IAS_HOST, IAS_SIGRL_ENDPOINT, gid);
let url = TryFrom::try_from(url.as_str()).expect("Invalid IAS URI");
let url = TryFrom::try_from(url.as_str()).context("Invalid IAS URI")?;
let res = Request::new(&url)
.header("Connection", "Close")
.header("Ocp-Apim-Subscription-Key", &ias_key())
.header("Ocp-Apim-Subscription-Key", &ias_key()?)
.timeout(timeout)
.connect_timeout(timeout)
.read_timeout(timeout)
.send(&mut res_body_buffer)
.unwrap();
.context("Http request to IAS failed")?;

// parse_response_sigrl

Expand All @@ -273,23 +272,22 @@ pub fn get_sigrl_from_intel(gid: u32) -> Vec<u8> {
};

error!("{}", msg);
// TODO: should return Err
panic!("status code {}", status_code);
return Err(anyhow!(format!("Bad http status: {}", status_code)));
}

if res.content_len() != None && res.content_len() != Some(0) {
let res_body = res_body_buffer;
let encoded_sigrl = str::from_utf8(&res_body).unwrap();
let encoded_sigrl = str::from_utf8(&res_body).context("UTF8 decode sigrl")?;
info!("Base64-encoded SigRL: {:?}", encoded_sigrl);

return base64::decode(encoded_sigrl).unwrap();
return Ok(base64::decode(encoded_sigrl).context("Base64 decode sigrl")?);
}

Vec::new()
Ok(Vec::new())
}

// TODO: support pse
pub fn get_report_from_intel(quote: Vec<u8>) -> (String, String, String) {
pub fn get_report_from_intel(quote: Vec<u8>) -> Result<(String, String, String)> {
// println!("get_report_from_intel fd = {:?}", fd);
let encoded_quote = base64::encode(&quote[..]);
let encoded_json = format!("{{\"isvEnclaveQuote\":\"{}\"}}\r\n", encoded_quote);
Expand All @@ -300,19 +298,19 @@ pub fn get_report_from_intel(quote: Vec<u8>) -> (String, String, String) {
let timeout = Some(Duration::from_secs(8));

let url = format!("https://{}{}", IAS_HOST, IAS_REPORT_ENDPOINT);
let url = TryFrom::try_from(url.as_str()).expect("Invalid IAS URI");
let url = TryFrom::try_from(url.as_str()).context("Invalid IAS URI")?;
let res = Request::new(&url)
.header("Connection", "Close")
.header("Content-Type", "application/json")
.header("Content-Length", &encoded_json.len())
.header("Ocp-Apim-Subscription-Key", &ias_key)
.header("Ocp-Apim-Subscription-Key", &ias_key?)
.method(Method::POST)
.body(encoded_json.as_bytes())
.timeout(timeout)
.connect_timeout(timeout)
.read_timeout(timeout)
.send(&mut res_body_buffer)
.unwrap();
.context("Http request to IAS failed")?;

let status_code = u16::from(res.status_code());
if status_code != 200 {
Expand All @@ -330,8 +328,7 @@ pub fn get_report_from_intel(quote: Vec<u8>) -> (String, String, String) {
};

error!("{}", msg);
// TODO: should return Err
panic!("status code not 200");
return Err(anyhow!(format!("Bad http status: {}", status_code)));
}

let content_len = match res.content_len() {
Expand All @@ -343,43 +340,42 @@ pub fn get_report_from_intel(quote: Vec<u8>) -> (String, String, String) {
};

if content_len == 0 {
// TODO: should return Err
panic!("don't know how to handle content_length is 0");
return Err(anyhow!("Empty HTTP response"));
}

let attn_report = String::from_utf8(res_body_buffer).unwrap();
let attn_report = String::from_utf8(res_body_buffer).context("UTF8 decode response")?;
let sig = res
.headers()
.get("X-IASReport-Signature")
.unwrap()
.context("Get X-IASReport-Signature")?
.to_string();
let mut cert = res
.headers()
.get("X-IASReport-Signing-Certificate")
.unwrap()
.context("Get X-IASReport-Signing-Certificate")?
.to_string();

// Remove %0A from cert, and only obtain the signing cert
cert = cert.replace("%0A", "");
cert = percent_decode(cert);
cert = percent_decode(cert).context("percent_decode cert")?;
let v: Vec<&str> = cert.split("-----").collect();
let sig_cert = v[2].to_string();

// len_num == 0
(attn_report, sig, sig_cert)
Ok((attn_report, sig, sig_cert))
}

fn percent_decode(orig: String) -> String {
fn percent_decode(orig: String) -> Result<String> {
let v: Vec<&str> = orig.split('%').collect();
let mut ret = String::new();
ret.push_str(v[0]);
if v.len() > 1 {
for s in v[1..].iter() {
ret.push(u8::from_str_radix(&s[0..2], 16).unwrap() as char);
ret.push(u8::from_str_radix(&s[0..2], 16).context("Invalid radix code")? as char);
ret.push_str(&s[2..]);
}
}
ret
Ok(ret)
}

fn as_u32_le(array: &[u8; 4]) -> u32 {
Expand All @@ -396,7 +392,7 @@ pub fn create_attestation_report(
) -> Result<(String, String, String)> {
let data_len = data.len();
if data_len > SGX_REPORT_DATA_SIZE {
panic!("data length over 64 bytes");
return Err(anyhow!("data length over 64 bytes"));
}

// Workflow:
Expand Down Expand Up @@ -435,23 +431,15 @@ pub fn create_attestation_report(
//println!("Got ias_sock = {}", ias_sock);

// Now sigrl_vec is the revocation list, a vec<u8>
let sigrl_vec: Vec<u8> = get_sigrl_from_intel(eg_num);
let sigrl_vec: Vec<u8> = get_sigrl_from_intel(eg_num)?;

// (2) Generate the report
// Fill data into report_data
let mut report_data: sgx_report_data_t = sgx_report_data_t::default();
report_data.d[..data_len].clone_from_slice(data);

let rep = match rsgx_create_report(&ti, &report_data) {
Ok(r) => {
info!("Report creation => success {:?}", r.body.mr_signer.m);
Some(r)
}
Err(e) => {
warn!("Report creation => failed {:?}", e);
None
}
};
let rep = rsgx_create_report(&ti, &report_data)
.map_err(|err| anyhow!(format!("Create report failed: {}", err)))?;

let mut quote_nonce = sgx_quote_nonce_t { rand: [0; 16] };
let mut os_rng = rand::thread_rng();
Expand All @@ -478,10 +466,10 @@ pub fn create_attestation_report(
} else {
(sigrl_vec.as_ptr(), sigrl_vec.len() as u32)
};
let p_report = (&rep.unwrap()) as *const sgx_report_t;
let p_report = &rep;
let quote_type = sign_type;

let spid: sgx_spid_t = ias_spid();
let spid: sgx_spid_t = ias_spid()?;

let p_spid = &spid as *const sgx_spid_t;
let p_nonce = &quote_nonce as *const sgx_quote_nonce_t;
Expand Down Expand Up @@ -554,7 +542,8 @@ pub fn create_attestation_report(

let mut rhs_vec: Vec<u8> = quote_nonce.rand.to_vec();
rhs_vec.extend(&return_quote_buf[..quote_len as usize]);
let rhs_hash = rsgx_sha256_slice(&rhs_vec[..]).unwrap();
let rhs_hash = rsgx_sha256_slice(&rhs_vec[..])
.map_err(|err| anyhow!(format!("sha256 hash error: {}", err)))?;
let lhs_hash = &qe_report.body.report_data.d[..32];

info!("rhs hash = {}", hex::encode(rhs_hash));
Expand All @@ -566,7 +555,7 @@ pub fn create_attestation_report(
}

let quote_vec: Vec<u8> = return_quote_buf[..quote_len as usize].to_vec();
let (attn_report, sig, cert) = get_report_from_intel(quote_vec);
let (attn_report, sig, cert) = get_report_from_intel(quote_vec)?;
Ok((attn_report, sig, cert))
}

Expand All @@ -583,7 +572,7 @@ fn generate_seal_key() -> [u8; 16] {
config_svn: 0_u16,
reserved2: [0_u8; SGX_KEY_REQUEST_RESERVED2_BYTES],
};
let seal_key = rsgx_get_align_key(&key_request).unwrap();
let seal_key = rsgx_get_align_key(&key_request).unwrap_or_default();
seal_key.key
}

Expand Down