From 3f35c3eee22cb52de04dea764e529517f6bfef0d Mon Sep 17 00:00:00 2001 From: Josh Matthews Date: Thu, 12 Jan 2017 18:05:54 -0500 Subject: [PATCH] Add a permanent root to WebIDL callbacks, ensuring they are always safe to store. --- components/script/dom/bindings/callback.rs | 55 +++++++++++++++---- .../dom/bindings/codegen/CodegenRust.py | 17 +++--- components/script/dom/eventtarget.rs | 27 ++++++--- components/script/script_runtime.rs | 4 +- 4 files changed, 75 insertions(+), 28 deletions(-) diff --git a/components/script/dom/bindings/callback.rs b/components/script/dom/bindings/callback.rs index 067b498a35d5..e260195f7ef5 100644 --- a/components/script/dom/bindings/callback.rs +++ b/components/script/dom/bindings/callback.rs @@ -5,16 +5,16 @@ //! Base classes to work with IDL callbacks. use dom::bindings::error::{Error, Fallible, report_pending_exception}; -use dom::bindings::js::Root; +use dom::bindings::js::{Root, MutHeapJSVal}; use dom::bindings::reflector::DomObject; use dom::bindings::settings_stack::AutoEntryScript; use dom::globalscope::GlobalScope; use js::jsapi::{Heap, MutableHandleObject}; -use js::jsapi::{IsCallable, JSContext, JSObject, JS_WrapObject}; -use js::jsapi::{JSCompartment, JS_EnterCompartment, JS_LeaveCompartment}; +use js::jsapi::{IsCallable, JSContext, JSObject, JS_WrapObject, AddRawValueRoot}; +use js::jsapi::{JSCompartment, JS_EnterCompartment, JS_LeaveCompartment, RemoveRawValueRoot}; use js::jsapi::JSAutoCompartment; use js::jsapi::JS_GetProperty; -use js::jsval::{JSVal, UndefinedValue}; +use js::jsval::{JSVal, UndefinedValue, ObjectValue}; use std::default::Default; use std::ffi::CString; use std::mem::drop; @@ -33,22 +33,52 @@ pub enum ExceptionHandling { /// A common base class for representing IDL callback function and /// callback interface types. -#[derive(Default, JSTraceable)] +#[derive(JSTraceable)] +#[must_root] pub struct CallbackObject { /// The underlying `JSObject`. callback: Heap<*mut JSObject>, + permanent_js_root: MutHeapJSVal, +} + +impl Default for CallbackObject { + #[allow(unrooted_must_root)] + fn default() -> CallbackObject { + CallbackObject::new() + } } impl CallbackObject { + #[allow(unrooted_must_root)] fn new() -> CallbackObject { CallbackObject { callback: Heap::default(), + permanent_js_root: MutHeapJSVal::new(), } } pub fn get(&self) -> *mut JSObject { self.callback.get() } + + #[allow(unsafe_code)] + unsafe fn init(&mut self, cx: *mut JSContext, callback: *mut JSObject) { + self.callback.set(callback); + self.permanent_js_root.set(ObjectValue(callback)); + assert!(AddRawValueRoot(cx, self.permanent_js_root.get_unsafe(), + b"CallbackObject::root\n" as *const _ as *const _)); + } +} + +impl Drop for CallbackObject { + #[allow(unsafe_code)] + fn drop(&mut self) { + unsafe { + let cx = GlobalScope::from_object(self.callback.get()).get_cx(); + RemoveRawValueRoot(cx, self.permanent_js_root.get_unsafe()); + } + } + } impl PartialEq for CallbackObject { @@ -62,7 +92,7 @@ impl PartialEq for CallbackObject { /// callback interface types. pub trait CallbackContainer { /// Create a new CallbackContainer object for the given `JSObject`. - fn new(callback: *mut JSObject) -> Rc; + unsafe fn new(cx: *mut JSContext, callback: *mut JSObject) -> Rc; /// Returns the underlying `CallbackObject`. fn callback_holder(&self) -> &CallbackObject; /// Returns the underlying `JSObject`. @@ -74,12 +104,14 @@ pub trait CallbackContainer { /// A common base class for representing IDL callback function types. #[derive(JSTraceable, PartialEq)] +#[must_root] pub struct CallbackFunction { object: CallbackObject, } impl CallbackFunction { /// Create a new `CallbackFunction` for this object. + #[allow(unrooted_must_root)] pub fn new() -> CallbackFunction { CallbackFunction { object: CallbackObject::new(), @@ -93,14 +125,17 @@ impl CallbackFunction { /// Initialize the callback function with a value. /// Should be called once this object is done moving. - pub fn init(&mut self, callback: *mut JSObject) { - self.object.callback.set(callback); + pub unsafe fn init(&mut self, cx: *mut JSContext, callback: *mut JSObject) { + self.object.init(cx, callback); } } + + /// A common base class for representing IDL callback interface types. #[derive(JSTraceable, PartialEq)] +#[must_root] pub struct CallbackInterface { object: CallbackObject, } @@ -120,8 +155,8 @@ impl CallbackInterface { /// Initialize the callback function with a value. /// Should be called once this object is done moving. - pub fn init(&mut self, callback: *mut JSObject) { - self.object.callback.set(callback); + pub unsafe fn init(&mut self, cx: *mut JSContext, callback: *mut JSObject) { + self.object.init(cx, callback); } /// Returns the property with the given `name`, if it is a callable object, diff --git a/components/script/dom/bindings/codegen/CodegenRust.py b/components/script/dom/bindings/codegen/CodegenRust.py index ffd82a7f135c..3df7c01f148d 100644 --- a/components/script/dom/bindings/codegen/CodegenRust.py +++ b/components/script/dom/bindings/codegen/CodegenRust.py @@ -776,7 +776,7 @@ def wrapObjectTemplate(templateBody, nullValue, isDefinitelyObject, type, if descriptor.interface.isCallback(): name = descriptor.nativeType declType = CGWrapper(CGGeneric(name), pre="Rc<", post=">") - template = "%s::new(${val}.get().to_object())" % name + template = "%s::new(cx, ${val}.get().to_object())" % name if type.nullable(): declType = CGWrapper(declType, pre="Option<", post=">") template = wrapObjectTemplate("Some(%s)" % template, "None", @@ -2195,7 +2195,7 @@ def define(self): class CGCallbackTempRoot(CGGeneric): def __init__(self, name): - CGGeneric.__init__(self, "%s::new(${val}.get().to_object())" % name) + CGGeneric.__init__(self, "%s::new(cx, ${val}.get().to_object())" % name) def getAllTypes(descriptors, dictionaries, callbacks, typedefs): @@ -4444,10 +4444,11 @@ def getBody(self, cgClass): "});\n" "// Note: callback cannot be moved after calling init.\n" "match Rc::get_mut(&mut ret) {\n" - " Some(ref mut callback) => callback.parent.init(%s),\n" + " Some(ref mut callback) => unsafe { callback.parent.init(%s, %s) },\n" " None => unreachable!(),\n" "};\n" - "ret") % (cgClass.name, '\n'.join(initializers), self.args[0].name)) + "ret") % (cgClass.name, '\n'.join(initializers), + self.args[0].name, self.args[1].name)) def declare(self, cgClass): args = ', '.join([a.declare() for a in self.args]) @@ -6236,11 +6237,11 @@ def __init__(self, idlObject, descriptorProvider, baseName, methods, bases=[ClassBase(baseName)], constructors=self.getConstructors(), methods=realMethods + getters + setters, - decorators="#[derive(JSTraceable, PartialEq)]") + decorators="#[derive(JSTraceable, PartialEq)]\n#[allow_unrooted_interior]") def getConstructors(self): return [ClassConstructor( - [Argument("*mut JSObject", "aCallback")], + [Argument("*mut JSContext", "aCx"), Argument("*mut JSObject", "aCallback")], bodyInHeader=True, visibility="pub", explicit=False, @@ -6336,8 +6337,8 @@ class CGCallbackFunctionImpl(CGGeneric): def __init__(self, callback): impl = string.Template("""\ impl CallbackContainer for ${type} { - fn new(callback: *mut JSObject) -> Rc<${type}> { - ${type}::new(callback) + unsafe fn new(cx: *mut JSContext, callback: *mut JSObject) -> Rc<${type}> { + ${type}::new(cx, callback) } fn callback_holder(&self) -> &CallbackObject { diff --git a/components/script/dom/eventtarget.rs b/components/script/dom/eventtarget.rs index 09cafb65bac5..1023b97600ff 100644 --- a/components/script/dom/eventtarget.rs +++ b/components/script/dom/eventtarget.rs @@ -441,13 +441,13 @@ impl EventTarget { assert!(!funobj.is_null()); // Step 1.14 if is_error { - Some(CommonEventHandler::ErrorEventHandler(OnErrorEventHandlerNonNull::new(funobj))) + Some(CommonEventHandler::ErrorEventHandler(OnErrorEventHandlerNonNull::new(cx, funobj))) } else { if ty == &atom!("beforeunload") { Some(CommonEventHandler::BeforeUnloadEventHandler( - OnBeforeUnloadEventHandlerNonNull::new(funobj))) + OnBeforeUnloadEventHandlerNonNull::new(cx, funobj))) } else { - Some(CommonEventHandler::EventHandler(EventHandlerNonNull::new(funobj))) + Some(CommonEventHandler::EventHandler(EventHandlerNonNull::new(cx, funobj))) } } } @@ -455,36 +455,47 @@ impl EventTarget { pub fn set_event_handler_common( &self, ty: &str, listener: Option>) { + let cx = self.global().get_cx(); + let event_listener = listener.map(|listener| InlineEventListener::Compiled( CommonEventHandler::EventHandler( - EventHandlerNonNull::new(listener.callback())))); + EventHandlerNonNull::new(cx, listener.callback())))); self.set_inline_event_listener(Atom::from(ty), event_listener); } pub fn set_error_event_handler( &self, ty: &str, listener: Option>) { + let cx = self.global().get_cx(); + let event_listener = listener.map(|listener| InlineEventListener::Compiled( CommonEventHandler::ErrorEventHandler( - OnErrorEventHandlerNonNull::new(listener.callback())))); + OnErrorEventHandlerNonNull::new(cx, listener.callback())))); self.set_inline_event_listener(Atom::from(ty), event_listener); } pub fn set_beforeunload_event_handler(&self, ty: &str, - listener: Option>) { + listener: Option>) { + let cx = self.global().get_cx(); + let event_listener = listener.map(|listener| InlineEventListener::Compiled( CommonEventHandler::BeforeUnloadEventHandler( - OnBeforeUnloadEventHandlerNonNull::new(listener.callback()))) + OnBeforeUnloadEventHandlerNonNull::new(cx, listener.callback()))) ); self.set_inline_event_listener(Atom::from(ty), event_listener); } + #[allow(unsafe_code)] pub fn get_event_handler_common(&self, ty: &str) -> Option> { + let cx = self.global().get_cx(); let listener = self.get_inline_event_listener(&Atom::from(ty)); - listener.map(|listener| CallbackContainer::new(listener.parent().callback_holder().get())) + unsafe { + listener.map(|listener| + CallbackContainer::new(cx, listener.parent().callback_holder().get())) + } } pub fn has_handlers(&self) -> bool { diff --git a/components/script/script_runtime.rs b/components/script/script_runtime.rs index 75ce94084f07..b7a0374428b9 100644 --- a/components/script/script_runtime.rs +++ b/components/script/script_runtime.rs @@ -176,7 +176,7 @@ impl PromiseJobQueue { /// promise job queue, and enqueues a runnable to perform a microtask checkpoint if one /// is not already pending. #[allow(unsafe_code)] -unsafe extern "C" fn enqueue_job(_cx: *mut JSContext, +unsafe extern "C" fn enqueue_job(cx: *mut JSContext, job: HandleObject, _allocation_site: HandleObject, _data: *mut c_void) -> bool { @@ -184,7 +184,7 @@ unsafe extern "C" fn enqueue_job(_cx: *mut JSContext, let global = GlobalScope::from_object(job.get()); let pipeline = global.pipeline_id(); global.enqueue_promise_job(EnqueuedPromiseCallback { - callback: PromiseJobCallback::new(job.get()), + callback: PromiseJobCallback::new(cx, job.get()), pipeline: pipeline, }); true