diff --git a/components/net/about_loader.rs b/components/net/about_loader.rs index 04878fed70b7..6a03aeb041d1 100644 --- a/components/net/about_loader.rs +++ b/components/net/about_loader.rs @@ -9,13 +9,16 @@ use hyper::mime::{Mime, SubLevel, TopLevel}; use mime_classifier::MIMEClassifier; use net_traits::ProgressMsg::Done; use net_traits::{LoadConsumer, LoadData, Metadata}; -use resource_task::{send_error, start_sending_sniffed_opt}; +use resource_task::{CancellationListener, send_error, start_sending_sniffed_opt}; use std::fs::PathExt; use std::sync::Arc; use url::Url; use util::resource_files::resources_dir_path; -pub fn factory(mut load_data: LoadData, start_chan: LoadConsumer, classifier: Arc) { +pub fn factory(mut load_data: LoadData, + start_chan: LoadConsumer, + classifier: Arc, + cancel_listener: CancellationListener) { match load_data.url.non_relative_scheme_data().unwrap() { "blank" => { let metadata = Metadata { @@ -42,5 +45,5 @@ pub fn factory(mut load_data: LoadData, start_chan: LoadConsumer, classifier: Ar return } }; - file_loader::factory(load_data, start_chan, classifier) + file_loader::factory(load_data, start_chan, classifier, cancel_listener) } diff --git a/components/net/data_loader.rs b/components/net/data_loader.rs index f2f72bac88e2..b7fb25c124dc 100644 --- a/components/net/data_loader.rs +++ b/components/net/data_loader.rs @@ -6,21 +6,27 @@ use hyper::mime::{Mime, TopLevel, SubLevel, Attr, Value}; use mime_classifier::MIMEClassifier; use net_traits::ProgressMsg::{Done, Payload}; use net_traits::{LoadConsumer, LoadData, Metadata}; -use resource_task::{send_error, start_sending_sniffed_opt}; +use resource_task::{CancellationListener, send_error, start_sending_sniffed_opt}; use rustc_serialize::base64::FromBase64; use std::sync::Arc; use url::SchemeData; use url::percent_encoding::percent_decode; -pub fn factory(load_data: LoadData, senders: LoadConsumer, classifier: Arc) { +pub fn factory(load_data: LoadData, + senders: LoadConsumer, + classifier: Arc, + cancel_listener: CancellationListener) { // NB: we don't spawn a new task. // Hypothesis: data URLs are too small for parallel base64 etc. to be worth it. // Should be tested at some point. // Left in separate function to allow easy moving to a task, if desired. - load(load_data, senders, classifier) + load(load_data, senders, classifier, cancel_listener) } -pub fn load(load_data: LoadData, start_chan: LoadConsumer, classifier: Arc) { +pub fn load(load_data: LoadData, + start_chan: LoadConsumer, + classifier: Arc, + cancel_listener: CancellationListener) { let url = load_data.url; assert!(&*url.scheme == "data"); @@ -63,8 +69,11 @@ pub fn load(load_data: LoadData, start_chan: LoadConsumer, classifier: Arc Result { let mut buf = vec![0; READ_SIZE]; match reader.read(&mut buf) { @@ -33,17 +39,24 @@ fn read_block(reader: &mut File) -> Result { } } -fn read_all(reader: &mut File, progress_chan: &ProgressSender) - -> Result<(), String> { +fn read_all(reader: &mut File, progress_chan: &ProgressSender, cancel_listener: &CancellationListener) + -> Result { loop { + if cancel_listener.is_cancelled() { + return Ok(LoadResult::Cancelled); + } + match try!(read_block(reader)) { ReadStatus::Partial(buf) => progress_chan.send(Payload(buf)).unwrap(), - ReadStatus::EOF => return Ok(()), + ReadStatus::EOF => return Ok(LoadResult::Finished), } } } -pub fn factory(load_data: LoadData, senders: LoadConsumer, classifier: Arc) { +pub fn factory(load_data: LoadData, + senders: LoadConsumer, + classifier: Arc, + cancel_listener: CancellationListener) { let url = load_data.url; assert!(&*url.scheme == "file"); spawn_named("file_loader".to_owned(), move || { @@ -52,14 +65,22 @@ pub fn factory(load_data: LoadData, senders: LoadConsumer, classifier: Arc { match File::open(&file_path) { Ok(ref mut reader) => { + if cancel_listener.is_cancelled() { + return; + } match read_block(reader) { Ok(ReadStatus::Partial(buf)) => { let metadata = Metadata::default(url); let progress_chan = start_sending_sniffed(senders, metadata, classifier, &buf); progress_chan.send(Payload(buf)).unwrap(); - let res = read_all(reader, &progress_chan); - let _ = progress_chan.send(Done(res)); + let read_result = read_all(reader, &progress_chan, &cancel_listener); + if let Ok(load_result) = read_result { + match load_result { + LoadResult::Cancelled => return, + LoadResult::Finished => progress_chan.send(Done(Ok(()))).unwrap(), + } + } } Ok(ReadStatus::EOF) => { let metadata = Metadata::default(url); diff --git a/components/net/http_loader.rs b/components/net/http_loader.rs index 15ed5f7fdf41..ef854a504dd9 100644 --- a/components/net/http_loader.rs +++ b/components/net/http_loader.rs @@ -28,7 +28,7 @@ use net_traits::ProgressMsg::{Done, Payload}; use net_traits::hosts::replace_hosts; use net_traits::{CookieSource, IncludeSubdomains, LoadConsumer, LoadData, Metadata}; use openssl::ssl::{SSL_VERIFY_PEER, SslContext, SslMethod}; -use resource_task::{send_error, start_sending_sniffed_opt}; +use resource_task::{CancellationListener, send_error, start_sending_sniffed_opt}; use std::borrow::ToOwned; use std::boxed::FnBox; use std::collections::HashSet; @@ -59,8 +59,11 @@ pub fn factory(user_agent: String, cookie_jar: Arc>, devtools_chan: Option>, connector: Arc>) - -> Box) + Send> { - box move |load_data: LoadData, senders, classifier| { + -> Box, + CancellationListener) + Send> { + box move |load_data: LoadData, senders, classifier, cancel_listener| { spawn_named(format!("http_loader for {}", load_data.url.serialize()), move || { load_for_consumer(load_data, senders, @@ -69,6 +72,7 @@ pub fn factory(user_agent: String, hsts_list, cookie_jar, devtools_chan, + cancel_listener, user_agent) }) } @@ -104,6 +108,7 @@ fn load_for_consumer(load_data: LoadData, hsts_list: Arc>, cookie_jar: Arc>, devtools_chan: Option>, + cancel_listener: CancellationListener, user_agent: String) { let factory = NetworkHttpRequestFactory { @@ -132,13 +137,12 @@ fn load_for_consumer(load_data: LoadData, image.push("badcert.html"); let load_data = LoadData::new(Url::from_file_path(&*image).unwrap(), None); - file_loader::factory(load_data, start_chan, classifier) - + file_loader::factory(load_data, start_chan, classifier, cancel_listener) } Err(LoadError::ConnectionAborted(_)) => unreachable!(), Ok(mut load_response) => { let metadata = load_response.metadata.clone(); - send_data(&mut load_response, start_chan, metadata, classifier) + send_data(&mut load_response, start_chan, metadata, classifier, cancel_listener) } } } @@ -717,7 +721,8 @@ pub fn load(load_data: LoadData, fn send_data(reader: &mut R, start_chan: LoadConsumer, metadata: Metadata, - classifier: Arc) { + classifier: Arc, + cancel_listener: CancellationListener) { let (progress_chan, mut chunk) = { let buf = match read_block(reader) { Ok(ReadResult::Payload(buf)) => buf, @@ -731,6 +736,11 @@ fn send_data(reader: &mut R, }; loop { + if cancel_listener.is_cancelled() { + let _ = progress_chan.send(Done(Err("load cancelled".to_owned()))); + return; + } + if progress_chan.send(Payload(chunk)).is_err() { // The send errors when the receiver is out of scope, // which will happen if the fetch has timed out (or has been aborted) diff --git a/components/net/image_cache_task.rs b/components/net/image_cache_task.rs index 198d04e0bfd7..4ae17a6b1a54 100644 --- a/components/net/image_cache_task.rs +++ b/components/net/image_cache_task.rs @@ -428,7 +428,8 @@ impl ImageCache { sender: action_sender, }; let msg = ControlMsg::Load(load_data, - LoadConsumer::Listener(response_target)); + LoadConsumer::Listener(response_target), + None); let progress_sender = self.progress_sender.clone(); ROUTER.add_route(action_receiver.to_opaque(), box move |message| { let action: ResponseAction = message.to().unwrap(); diff --git a/components/net/resource_task.rs b/components/net/resource_task.rs index 529f79b448ec..caf73c4a9769 100644 --- a/components/net/resource_task.rs +++ b/components/net/resource_task.rs @@ -19,10 +19,12 @@ use ipc_channel::ipc::{self, IpcReceiver, IpcSender}; use mime_classifier::{ApacheBugFlag, MIMEClassifier, NoSniffFlag}; use net_traits::ProgressMsg::Done; use net_traits::{AsyncResponseTarget, Metadata, ProgressMsg, ResourceTask, ResponseAction}; -use net_traits::{ControlMsg, CookieSource, LoadConsumer, LoadData, LoadResponse}; +use net_traits::{ControlMsg, CookieSource, LoadConsumer, LoadData, LoadResponse, ResourceId}; use std::borrow::ToOwned; use std::boxed::FnBox; -use std::sync::mpsc::{Sender, channel}; +use std::cell::Cell; +use std::collections::HashMap; +use std::sync::mpsc::{Receiver, Sender, channel}; use std::sync::{Arc, RwLock}; use url::Url; use util::opts; @@ -146,6 +148,7 @@ pub fn new_resource_task(user_agent: String, }; let (setup_chan, setup_port) = ipc::channel().unwrap(); + let setup_chan_clone = setup_chan.clone(); spawn_named("ResourceManager".to_owned(), move || { let resource_manager = ResourceManager::new( user_agent, hsts_preload, devtools_chan @@ -155,8 +158,7 @@ pub fn new_resource_task(user_agent: String, from_client: setup_port, resource_manager: resource_manager }; - - channel_manager.start(); + channel_manager.start(setup_chan_clone); }); setup_chan } @@ -167,28 +169,85 @@ struct ResourceChannelManager { } impl ResourceChannelManager { - fn start(&mut self) { + fn start(&mut self, control_sender: ResourceTask) { loop { match self.from_client.recv().unwrap() { - ControlMsg::Load(load_data, consumer) => { - self.resource_manager.load(load_data, consumer) - } - ControlMsg::SetCookiesForUrl(request, cookie_list, source) => { - self.resource_manager.set_cookies_for_url(request, cookie_list, source) - } + ControlMsg::Load(load_data, consumer, id_sender) => + self.resource_manager.load(load_data, consumer, id_sender, control_sender.clone()), + ControlMsg::SetCookiesForUrl(request, cookie_list, source) => + self.resource_manager.set_cookies_for_url(request, cookie_list, source), ControlMsg::GetCookiesForUrl(url, consumer, source) => { let cookie_jar = &self.resource_manager.cookie_storage; let mut cookie_jar = cookie_jar.write().unwrap(); consumer.send(cookie_jar.cookies_for_url(&url, source)).unwrap(); } - ControlMsg::Exit => { - break + ControlMsg::Cancel(res_id) => { + if let Some(cancel_sender) = self.resource_manager.cancel_load_map.get(&res_id) { + let _ = cancel_sender.send(()); + } + self.resource_manager.cancel_load_map.remove(&res_id); } + ControlMsg::Exit => break, } } } } +/// The optional resources required by the `CancellationListener` +pub struct CancellableResource { + /// The receiver which receives a message on load cancellation + cancel_receiver: Receiver<()>, + /// The `CancellationListener` is unique to this `ResourceId` + resource_id: ResourceId, + /// If we haven't initiated any cancel requests, then the loaders ask + /// the listener to remove the `ResourceId` in the `HashMap` of + /// `ResourceManager` once they finish loading + resource_task: ResourceTask, +} + +/// A listener which is basically a wrapped optional receiver which looks +/// for the load cancellation message. Some of the loading processes always keep +/// an eye out for this message and stop loading stuff once they receive it. +pub struct CancellationListener { + /// We'll be needing the resources only if we plan to cancel it + cancel_resource: Option, + /// This lets us know whether the request has already been cancelled + cancel_status: Cell, +} + +impl CancellationListener { + pub fn new(resources: Option) -> CancellationListener { + CancellationListener { + cancel_resource: resources, + cancel_status: Cell::new(false), + } + } + + pub fn is_cancelled(&self) -> bool { + match self.cancel_resource { + Some(ref resource) => { + match resource.cancel_receiver.try_recv() { + Ok(_) => { + self.cancel_status.set(true); + true + }, + Err(_) => self.cancel_status.get(), + } + }, + None => false, // channel doesn't exist! + } + } +} + +impl Drop for CancellationListener { + fn drop(&mut self) { + if let Some(ref resource) = self.cancel_resource { + // Ensure that the resource manager stops tracking this request now that it's terminated. + let _ = resource.resource_task.send(ControlMsg::Cancel(resource.resource_id)); + } + } +} + pub struct ResourceManager { user_agent: String, cookie_storage: Arc>, @@ -196,6 +255,8 @@ pub struct ResourceManager { devtools_chan: Option>, hsts_list: Arc>, connector: Arc>, + cancel_load_map: HashMap>, + next_resource_id: ResourceId, } impl ResourceManager { @@ -209,11 +270,11 @@ impl ResourceManager { devtools_chan: devtools_channel, hsts_list: Arc::new(RwLock::new(hsts_list)), connector: create_http_connector(), + cancel_load_map: HashMap::new(), + next_resource_id: ResourceId(0), } } -} -impl ResourceManager { fn set_cookies_for_url(&mut self, request: Url, cookie_list: String, source: CookieSource) { let header = Header::parse_header(&[cookie_list.into_bytes()]); if let Ok(SetCookie(cookies)) = header { @@ -227,15 +288,36 @@ impl ResourceManager { } } - fn load(&mut self, load_data: LoadData, consumer: LoadConsumer) { + fn load(&mut self, + load_data: LoadData, + consumer: LoadConsumer, + id_sender: Option>, + resource_task: ResourceTask) { - fn from_factory(factory: fn(LoadData, LoadConsumer, Arc)) - -> Box) + Send> { - box move |load_data, senders, classifier| { - factory(load_data, senders, classifier) + fn from_factory(factory: fn(LoadData, LoadConsumer, Arc, CancellationListener)) + -> Box, + CancellationListener) + Send> { + box move |load_data, senders, classifier, cancel_listener| { + factory(load_data, senders, classifier, cancel_listener) } } + let cancel_resource = id_sender.map(|sender| { + let current_res_id = self.next_resource_id; + let _ = sender.send(current_res_id); + let (cancel_sender, cancel_receiver) = channel(); + self.cancel_load_map.insert(current_res_id, cancel_sender); + self.next_resource_id.0 += 1; + CancellableResource { + cancel_receiver: cancel_receiver, + resource_id: current_res_id, + resource_task: resource_task, + } + }); + + let cancel_listener = CancellationListener::new(cancel_resource); let loader = match &*load_data.url.scheme { "file" => from_factory(file_loader::factory), "http" | "https" | "view-source" => @@ -254,6 +336,9 @@ impl ResourceManager { }; debug!("resource_task: loading url: {}", load_data.url.serialize()); - loader.call_box((load_data, consumer, self.mime_classifier.clone())); + loader.call_box((load_data, + consumer, + self.mime_classifier.clone(), + cancel_listener)); } } diff --git a/components/net_traits/lib.rs b/components/net_traits/lib.rs index de59bf654d78..5671ee200a19 100644 --- a/components/net_traits/lib.rs +++ b/components/net_traits/lib.rs @@ -227,12 +227,15 @@ pub enum IncludeSubdomains { #[derive(Deserialize, Serialize)] pub enum ControlMsg { /// Request the data associated with a particular URL - Load(LoadData, LoadConsumer), + Load(LoadData, LoadConsumer, Option>), /// Store a set of cookies for a given originating URL SetCookiesForUrl(Url, String, CookieSource), /// Retrieve the stored cookies for a given URL GetCookiesForUrl(Url, IpcSender>, CookieSource), - Exit + /// Cancel a network request corresponding to a given `ResourceId` + Cancel(ResourceId), + /// Break the load handler loop and exit + Exit, } /// Initialized but unsent request. Encapsulates everything necessary to instruct @@ -279,7 +282,7 @@ impl PendingAsyncLoad { self.guard.neuter(); let load_data = LoadData::new(self.url, self.pipeline); let consumer = LoadConsumer::Listener(listener); - self.resource_task.send(ControlMsg::Load(load_data, consumer)).unwrap(); + self.resource_task.send(ControlMsg::Load(load_data, consumer, None)).unwrap(); } } @@ -377,7 +380,7 @@ pub fn load_whole_resource(resource_task: &ResourceTask, url: Url, pipeline_id: -> Result<(Metadata, Vec), String> { let (start_chan, start_port) = ipc::channel().unwrap(); resource_task.send(ControlMsg::Load(LoadData::new(url, pipeline_id), - LoadConsumer::Channel(start_chan))).unwrap(); + LoadConsumer::Channel(start_chan), None)).unwrap(); let response = start_port.recv().unwrap(); let mut buf = vec!(); @@ -389,3 +392,7 @@ pub fn load_whole_resource(resource_task: &ResourceTask, url: Url, pipeline_id: } } } + +/// An unique identifier to keep track of each load message in the resource handler +#[derive(Clone, PartialEq, Eq, Copy, Hash, Debug, Deserialize, Serialize, HeapSizeOf)] +pub struct ResourceId(pub u32); diff --git a/components/script/dom/xmlhttprequest.rs b/components/script/dom/xmlhttprequest.rs index e8025880f2e6..552c0dc9825b 100644 --- a/components/script/dom/xmlhttprequest.rs +++ b/components/script/dom/xmlhttprequest.rs @@ -280,7 +280,7 @@ impl XMLHttpRequest { ROUTER.add_route(action_receiver.to_opaque(), box move |message| { listener.notify(message.to().unwrap()); }); - resource_task.send(Load(load_data, LoadConsumer::Listener(response_target))).unwrap(); + resource_task.send(Load(load_data, LoadConsumer::Listener(response_target), None)).unwrap(); } } diff --git a/components/script/script_task.rs b/components/script/script_task.rs index 28e194b2a4a0..1da866f2dabf 100644 --- a/components/script/script_task.rs +++ b/components/script/script_task.rs @@ -1965,7 +1965,7 @@ impl ScriptTask { data: load_data.data, cors: None, pipeline_id: Some(id), - }, LoadConsumer::Listener(response_target))).unwrap(); + }, LoadConsumer::Listener(response_target), None)).unwrap(); self.incomplete_loads.borrow_mut().push(incomplete); } diff --git a/tests/unit/net/data_loader.rs b/tests/unit/net/data_loader.rs index 06271396189e..588c3c32f0b0 100644 --- a/tests/unit/net/data_loader.rs +++ b/tests/unit/net/data_loader.rs @@ -18,13 +18,16 @@ fn assert_parse(url: &'static str, data: Option>) { use net::data_loader::load; use net::mime_classifier::MIMEClassifier; + use net::resource_task::CancellationListener; use std::sync::Arc; use std::sync::mpsc::channel; use url::Url; let (start_chan, start_port) = ipc::channel().unwrap(); let classifier = Arc::new(MIMEClassifier::new()); - load(LoadData::new(Url::parse(url).unwrap(), None), Channel(start_chan), classifier); + load(LoadData::new(Url::parse(url).unwrap(), None), + Channel(start_chan), + classifier, CancellationListener::new(None)); let response = start_port.recv().unwrap(); assert_eq!(&response.metadata.content_type, &content_type); diff --git a/tests/unit/net/resource_task.rs b/tests/unit/net/resource_task.rs index b3812aade542..8d925082aa68 100644 --- a/tests/unit/net/resource_task.rs +++ b/tests/unit/net/resource_task.rs @@ -11,7 +11,6 @@ use std::collections::HashMap; use std::sync::mpsc::channel; use url::Url; - #[test] fn test_exit() { let resource_task = new_resource_task("".to_owned(), None); @@ -23,7 +22,7 @@ fn test_bad_scheme() { let resource_task = new_resource_task("".to_owned(), None); let (start_chan, start) = ipc::channel().unwrap(); let url = Url::parse("bogus://whatever").unwrap(); - resource_task.send(ControlMsg::Load(LoadData::new(url, None), LoadConsumer::Channel(start_chan))).unwrap(); + resource_task.send(ControlMsg::Load(LoadData::new(url, None), LoadConsumer::Channel(start_chan), None)).unwrap(); let response = start.recv().unwrap(); match response.progress_port.recv().unwrap() { ProgressMsg::Done(result) => { assert!(result.is_err()) }