Skip to content

Commit

Permalink
Add initial header-based rate limit handling
Browse files Browse the repository at this point in the history
  • Loading branch information
SpaceManiac committed Aug 20, 2016
1 parent 931dae5 commit b71d1e3
Show file tree
Hide file tree
Showing 3 changed files with 187 additions and 39 deletions.
5 changes: 1 addition & 4 deletions src/connection.rs
Expand Up @@ -266,10 +266,7 @@ impl Connection {
::sleep_ms(1000);
}
// If those fail, hit REST for a new endpoint
let (conn, ready) = try!(::Discord {
client: ::hyper::client::Client::new(),
token: self.token.to_owned()
}.connect());
let (conn, ready) = try!(::Discord::from_token_raw(self.token.to_owned()).connect());
try!(::std::mem::replace(self, conn).shutdown());
self.session_id = Some(ready.session_id.clone());
Ok(ready)
Expand Down
98 changes: 63 additions & 35 deletions src/lib.rs
Expand Up @@ -43,6 +43,7 @@ extern crate flate2;
use std::collections::BTreeMap;
use serde_json::builder::ObjectBuilder;

mod ratelimit;
mod error;
mod connection;
mod state;
Expand All @@ -59,6 +60,7 @@ pub use error::{Result, Error};
pub use connection::Connection;
pub use state::{State, ChannelRef};
use model::*;
use ratelimit::RateLimits;

const USER_AGENT: &'static str = concat!("DiscordBot (https://github.com/SpaceManiac/discord-rs, ", env!("CARGO_PKG_VERSION"), ")");
macro_rules! api_concat {
Expand All @@ -69,26 +71,22 @@ macro_rules! status_concat {
}

macro_rules! request {
($self_:ident, $method:ident($body:expr), $url:expr, $($rest:tt)*) => {
try!($self_.request(|| $self_.client.$method(
&format!(api_concat!($url), $($rest)*)
).body(&$body)))
};
($self_:ident, $method:ident, $url:expr, $($rest:tt)*) => {
try!($self_.request(|| $self_.client.$method(
&format!(api_concat!($url), $($rest)*)
)))
};
($self_:ident, $method:ident($body:expr), $url:expr) => {
try!($self_.request(|| $self_.client.$method(
api_concat!($url)
).body(&$body)))
};
($self_:ident, $method:ident, $url:expr) => {
try!($self_.request(|| $self_.client.$method(
api_concat!($url)
)))
};
($self_:ident, $method:ident($body:expr), $url:expr, $($rest:tt)*) => {{
let path = format!(api_concat!($url), $($rest)*);
try!($self_.request(&path, || $self_.client.$method(&path).body(&$body)))
}};
($self_:ident, $method:ident, $url:expr, $($rest:tt)*) => {{
let path = format!(api_concat!($url), $($rest)*);
try!($self_.request(&path, || $self_.client.$method(&path)))
}};
($self_:ident, $method:ident($body:expr), $url:expr) => {{
let path = api_concat!($url);
try!($self_.request(path, || $self_.client.$method(path).body(&$body)))
}};
($self_:ident, $method:ident, $url:expr) => {{
let path = api_concat!($url);
try!($self_.request(path, || $self_.client.$method(path)))
}};
}

/// Client for the Discord REST API.
Expand All @@ -98,6 +96,7 @@ macro_rules! request {
/// use `logout()` to invalidate the token when done. Other methods manipulate
/// the Discord REST API.
pub struct Discord {
rate_limits: RateLimits,
client: hyper::Client,
token: String,
}
Expand All @@ -121,6 +120,7 @@ impl Discord {
None => return Err(Error::Protocol("Response missing \"token\" in Discord::new()"))
};
Ok(Discord {
rate_limits: RateLimits::default(),
client: client,
token: token,
})
Expand Down Expand Up @@ -170,6 +170,7 @@ impl Discord {
None => return Err(Error::Protocol("Response missing \"token\" in Discord::new()"))
};
Discord {
rate_limits: RateLimits::default(),
client: client,
token: token,
}
Expand Down Expand Up @@ -203,9 +204,18 @@ impl Discord {
Ok(discord)
}

fn from_token_raw(token: String) -> Discord {
Discord {
rate_limits: RateLimits::default(),
client: hyper::Client::new(),
token: token,
}
}

/// Log in as a bot account using the given authentication token.
pub fn from_bot_token(token: &str) -> Result<Discord> {
Ok(Discord {
rate_limits: RateLimits::default(),
client: hyper::Client::new(),
token: format!("Bot {}", token),
})
Expand All @@ -221,10 +231,25 @@ impl Discord {
check_empty(request!(self, post(body), "/auth/logout"))
}

fn request<'a, F: Fn() -> hyper::client::RequestBuilder<'a>>(&self, f: F) -> Result<hyper::client::Response> {
retry(|| f()
fn request<'a, F: Fn() -> hyper::client::RequestBuilder<'a>>(&self, url: &str, f: F) -> Result<hyper::client::Response> {
self.rate_limits.pre_check(url);
let f2 = || f()
.header(hyper::header::ContentType::json())
.header(hyper::header::Authorization(self.token.clone())))
.header(hyper::header::Authorization(self.token.clone()));
let result = retry(&f2);
if let Ok(response) = result.as_ref() {
if self.rate_limits.post_update(url, response) {
// we were rate limited, we have slept, it is time to retry
// the request once. if it fails the second time, give up
debug!("Retrying after having been ratelimited");
let result = retry(f2);
if let Ok(response) = result.as_ref() {
self.rate_limits.post_update(url, response);
}
return check_status(result)
}
}
check_status(result)
}

/// Create a channel.
Expand Down Expand Up @@ -313,7 +338,7 @@ impl Discord {
GetMessages::After(id) => { let _ = write!(url, "&after={}", id); },
GetMessages::Around(id) => { let _ = write!(url, "&around={}", id); },
}
let response = try!(self.request(|| self.client.get(&url)));
let response = try!(self.request(&url, || self.client.get(&url)));
decode_array(try!(serde_json::from_reader(response)), Message::decode)
}

Expand Down Expand Up @@ -826,8 +851,8 @@ pub fn read_image<P: AsRef<::std::path::Path>>(path: P) -> Result<String> {
/// Retrieves the active maintenance statuses.
pub fn get_active_maintenances() -> Result<Vec<Maintenance>> {
let client = hyper::Client::new();
let response = try!(retry(|| client.get(
status_concat!("/api/v2/scheduled-maintenances/active.json"))));
let response = try!(check_status(retry(|| client.get(
status_concat!("/api/v2/scheduled-maintenances/active.json")))));
let mut json: BTreeMap<String, serde_json::Value> = try!(serde_json::from_reader(response));

match json.remove("scheduled_maintenances") {
Expand All @@ -839,8 +864,8 @@ pub fn get_active_maintenances() -> Result<Vec<Maintenance>> {
/// Retrieves the upcoming maintenance statuses.
pub fn get_upcoming_maintenances() -> Result<Vec<Maintenance>> {
let client = hyper::Client::new();
let response = try!(retry(|| client.get(
status_concat!("/api/v2/scheduled-maintenances/upcoming.json"))));
let response = try!(check_status(retry(|| client.get(
status_concat!("/api/v2/scheduled-maintenances/upcoming.json")))));
let mut json: BTreeMap<String, serde_json::Value> = try!(serde_json::from_reader(response));

match json.remove("scheduled_maintenances") {
Expand Down Expand Up @@ -994,28 +1019,31 @@ impl EditProfile {
}
}

fn retry<'a, F: Fn() -> hyper::client::RequestBuilder<'a>>(f: F) -> Result<hyper::client::Response> {
let f2 = || check_status(f()
/// Send a request with the correct UserAgent, retrying it a second time if the
/// connection is aborted the first time.
fn retry<'a, F: Fn() -> hyper::client::RequestBuilder<'a>>(f: F) -> hyper::Result<hyper::client::Response> {
let f2 = || f()
.header(hyper::header::UserAgent(USER_AGENT.to_owned()))
.send());
.send();
// retry on a ConnectionAborted, which occurs if it's been a while since the last request
match f2() {
Err(Error::Hyper(hyper::error::Error::Io(ref io)))
Err(hyper::error::Error::Io(ref io))
if io.kind() == std::io::ErrorKind::ConnectionAborted => f2(),
other => other
}
}

#[inline]
/// Convert non-success hyper statuses to discord crate errors, tossing info.
fn check_status(response: hyper::Result<hyper::client::Response>) -> Result<hyper::client::Response> {
let response = try!(response);
let response: hyper::client::Response = try!(response);
if !response.status.is_success() {
return Err(Error::from_response(response))
}
Ok(response)
}

#[inline]
/// Validate a request that is expected to return 204 No Content and print
/// debug information if it does not.
fn check_empty(mut response: hyper::client::Response) -> Result<()> {
if response.status != hyper::status::StatusCode::NoContent {
use std::io::Read;
Expand Down
123 changes: 123 additions & 0 deletions src/ratelimit.rs
@@ -0,0 +1,123 @@
use std::sync::Mutex;
use std::collections::BTreeMap;
use std;

use hyper;
use time::get_time;

use {Result, Error};

#[derive(Default)]
pub struct RateLimits {
global: Mutex<RateLimit>,
endpoints: Mutex<BTreeMap<String, RateLimit>>,
}

impl RateLimits {
/// Check before issuing a request for the given URL.
pub fn pre_check(&self, url: &str) {
self.global.lock().expect("Rate limits poisoned").pre_check();
if let Some(rl) = self.endpoints.lock().expect("Rate limits poisoned").get_mut(url) {
rl.pre_check();
}
}

/// Update based on rate limit headers in the response for given URL.
/// Returns `true` if the request was rate limited and should be retried.
pub fn post_update(&self, url: &str, response: &hyper::client::Response) -> bool {
if response.headers.get_raw("X-RateLimit-Global").is_some() {
self.global.lock().expect("Rate limits poisoned").post_update(response)
} else {
self.endpoints.lock().expect("Rate limits poisoned")
.entry(url.to_owned())
.or_insert_with(RateLimit::default)
.post_update(response)
}
}
}

#[derive(Default)]
struct RateLimit {
reset: i64,
limit: i64,
remaining: i64,
}

impl RateLimit {
fn pre_check(&mut self) {
// break out if uninitialized
if self.limit == 0 { return }

let difference = self.reset - get_time().sec;
if difference < 0 {
// If reset is apparently in the past, optimistically assume that
// the reset has occurred and we're good for the next three seconds
// or so. When the response comes back we will know for real.
self.reset += 3;
self.remaining = self.limit;
return
}

// if no requests remain, wait a bit
if self.remaining <= 0 {
// 900ms in case "difference" is off by 1
let delay = difference as u64 * 1000 + 900;
warn!("pre-ratelimit: sleeping for {}ms", delay);
::sleep_ms(delay);
return
}

// Deduct from our remaining requests. If a lot of requests are issued
// before any responses are received, this will mean we can still limit
// preemptively.
self.remaining -= 1;
}

fn post_update(&mut self, response: &hyper::client::Response) -> bool {
match self.try_post_update(response) {
Err(e) => {
error!("rate limit checking error: {}", e);
false
}
Ok(r) => r
}
}

fn try_post_update(&mut self, response: &hyper::client::Response) -> Result<bool> {
if let Some(reset) = try!(read_header(&response.headers, "X-RateLimit-Reset")) {
self.reset = reset;
}
if let Some(limit) = try!(read_header(&response.headers, "X-RateLimit-Limit")) {
self.limit = limit;
}
if let Some(remaining) = try!(read_header(&response.headers, "X-RateLimit-Remaining")) {
self.remaining = remaining;
}
if response.status == hyper::status::StatusCode::TooManyRequests {
if let Some(delay) = try!(read_header(&response.headers, "Retry-After")) {
let delay = delay as u64 + 100; // 100ms of leeway
warn!("429: sleeping for {}ms", delay);
::sleep_ms(delay);
return Ok(true); // retry the request
}
}
Ok(false)
}
}

fn read_header(headers: &hyper::header::Headers, name: &str) -> Result<Option<i64>> {
match headers.get_raw(name) {
Some(hdr) => if hdr.len() == 1 {
match std::str::from_utf8(&hdr[0]) {
Ok(text) => match text.parse() {
Ok(val) => Ok(Some(val)),
Err(_) => Err(Error::Other("header is not an i64"))
},
Err(_) => Err(Error::Other("header is not UTF-8"))
}
} else {
Err(Error::Other("header appears multiple times"))
},
None => Ok(None)
}
}

0 comments on commit b71d1e3

Please sign in to comment.