diff --git a/Cargo.toml b/Cargo.toml index efffc04..afac98f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,7 +10,6 @@ keywords = ["memcache", "memcached", "driver", "cache", "database"] [features] default = ["tls"] - tls = ["openssl"] [dependencies] @@ -19,3 +18,4 @@ url = "2.1.1" rand = "0.7" enum_dispatch = "0.2" openssl = { version = "^0.10", optional = true } +r2d2 = "0.8.8" diff --git a/src/client.rs b/src/client.rs index fc103bb..8cd5ab0 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,14 +1,14 @@ use std::collections::hash_map::DefaultHasher; use std::collections::HashMap; use std::hash::{Hash, Hasher}; -use std::sync::Arc; use std::time::Duration; use url::Url; -use connection::Connection; +use connection::ConnectionManager; use error::MemcacheError; use protocol::{Protocol, ProtocolTrait}; +use r2d2::Pool; use stream::Stream; use value::{FromMemcacheValueExt, ToMemcacheValue}; @@ -48,7 +48,7 @@ impl Connectable for Vec<&str> { #[derive(Clone)] pub struct Client { - connections: Arc>, + connections: Vec>, pub hash_function: fn(&str) -> u64, } @@ -66,28 +66,27 @@ impl Client { return Self::connect(target); } - pub fn connect(target: C) -> Result { + pub fn with_pool_size(target: C, size: u32) -> Result { let urls = target.get_urls(); let mut connections = vec![]; for url in urls { let parsed = Url::parse(url.as_str())?; - let connection = Connection::connect(&parsed)?; - - if parsed.has_authority() && parsed.username() != "" && parsed.password().is_some() { - let username = parsed.username(); - let password = parsed.password().unwrap(); - connection.get_ref().auth(username, password)?; - } - - connections.push(connection); + let pool = r2d2::Pool::builder() + .max_size(size) + .build(ConnectionManager::new(parsed))?; + connections.push(pool); } - return Ok(Client { - connections: Arc::new(connections), + Ok(Client { + connections, hash_function: default_hash_function, - }); + }) + } + + pub fn connect(target: C) -> Result { + Self::with_pool_size(target, 1) } - fn get_connection(&self, key: &str) -> Connection { + fn get_connection(&self, key: &str) -> Pool { let connections_count = self.connections.len(); return self.connections[(self.hash_function)(key) as usize % connections_count].clone(); } @@ -102,7 +101,8 @@ impl Client { /// ``` pub fn set_read_timeout(&self, timeout: Option) -> Result<(), MemcacheError> { for conn in self.connections.iter() { - match *conn.get_ref() { + let mut conn = conn.get()?; + match **conn { Protocol::Ascii(ref mut protocol) => protocol.stream().set_read_timeout(timeout)?, Protocol::Binary(ref mut protocol) => protocol.stream.set_read_timeout(timeout)?, } @@ -120,7 +120,8 @@ impl Client { /// ``` pub fn set_write_timeout(&self, timeout: Option) -> Result<(), MemcacheError> { for conn in self.connections.iter() { - match *conn.get_ref() { + let mut conn = conn.get()?; + match **conn { Protocol::Ascii(ref mut protocol) => protocol.stream().set_read_timeout(timeout)?, Protocol::Binary(ref mut protocol) => protocol.stream.set_write_timeout(timeout)?, } @@ -139,9 +140,9 @@ impl Client { pub fn version(&self) -> Result, MemcacheError> { let mut result = Vec::with_capacity(self.connections.len()); for connection in self.connections.iter() { - let mut protocol = connection.get_ref(); - let url = connection.url.to_string(); - result.push((url, protocol.version()?)); + let mut connection = connection.get()?; + let url = connection.get_url(); + result.push((url, connection.version()?)); } Ok(result) } @@ -156,7 +157,7 @@ impl Client { /// ``` pub fn flush(&self) -> Result<(), MemcacheError> { for connection in self.connections.iter() { - connection.get_ref().flush()?; + connection.get()?.flush()?; } return Ok(()); } @@ -171,7 +172,7 @@ impl Client { /// ``` pub fn flush_with_delay(&self, delay: u32) -> Result<(), MemcacheError> { for connection in self.connections.iter() { - connection.get_ref().flush_with_delay(delay)?; + connection.get()?.flush_with_delay(delay)?; } return Ok(()); } @@ -185,7 +186,7 @@ impl Client { /// let _: Option = client.get("foo").unwrap(); /// ``` pub fn get(&self, key: &str) -> Result, MemcacheError> { - return self.get_connection(key).get_ref().get(key); + return self.get_connection(key).get()?.get(key); } /// Get multiple keys from memcached server. Using this function instead of calling `get` multiple times can reduce netwark workloads. @@ -211,7 +212,7 @@ impl Client { } for (&connection_index, keys) in con_keys.iter() { let connection = self.connections[connection_index].clone(); - result.extend(connection.get_ref().gets(keys)?); + result.extend(connection.get()?.gets(keys)?); } return Ok(result); } @@ -226,7 +227,7 @@ impl Client { /// # client.flush().unwrap(); /// ``` pub fn set>(&self, key: &str, value: V, expiration: u32) -> Result<(), MemcacheError> { - return self.get_connection(key).get_ref().set(key, value, expiration); + return self.get_connection(key).get()?.set(key, value, expiration); } /// Compare and swap a key with the associate value into memcached server with expiration seconds. @@ -251,7 +252,7 @@ impl Client { expiration: u32, cas_id: u64, ) -> Result { - self.get_connection(key).get_ref().cas(key, value, expiration, cas_id) + self.get_connection(key).get()?.cas(key, value, expiration, cas_id) } /// Add a key with associate value into memcached server with expiration seconds. @@ -266,7 +267,7 @@ impl Client { /// # client.flush().unwrap(); /// ``` pub fn add>(&self, key: &str, value: V, expiration: u32) -> Result<(), MemcacheError> { - return self.get_connection(key).get_ref().add(key, value, expiration); + return self.get_connection(key).get()?.add(key, value, expiration); } /// Replace a key with associate value into memcached server with expiration seconds. @@ -286,7 +287,7 @@ impl Client { value: V, expiration: u32, ) -> Result<(), MemcacheError> { - return self.get_connection(key).get_ref().replace(key, value, expiration); + return self.get_connection(key).get()?.replace(key, value, expiration); } /// Append value to the key. @@ -303,7 +304,7 @@ impl Client { /// # client.flush().unwrap(); /// ``` pub fn append>(&self, key: &str, value: V) -> Result<(), MemcacheError> { - return self.get_connection(key).get_ref().append(key, value); + return self.get_connection(key).get()?.append(key, value); } /// Prepend value to the key. @@ -320,7 +321,7 @@ impl Client { /// # client.flush().unwrap(); /// ``` pub fn prepend>(&self, key: &str, value: V) -> Result<(), MemcacheError> { - return self.get_connection(key).get_ref().prepend(key, value); + return self.get_connection(key).get()?.prepend(key, value); } /// Delete a key from memcached server. @@ -333,7 +334,7 @@ impl Client { /// # client.flush().unwrap(); /// ``` pub fn delete(&self, key: &str) -> Result { - return self.get_connection(key).get_ref().delete(key); + return self.get_connection(key).get()?.delete(key); } /// Increment the value with amount. @@ -346,7 +347,7 @@ impl Client { /// # client.flush().unwrap(); /// ``` pub fn increment(&self, key: &str, amount: u64) -> Result { - return self.get_connection(key).get_ref().increment(key, amount); + return self.get_connection(key).get()?.increment(key, amount); } /// Decrement the value with amount. @@ -359,7 +360,7 @@ impl Client { /// # client.flush().unwrap(); /// ``` pub fn decrement(&self, key: &str, amount: u64) -> Result { - return self.get_connection(key).get_ref().decrement(key, amount); + return self.get_connection(key).get()?.decrement(key, amount); } /// Set a new expiration time for a exist key. @@ -374,7 +375,7 @@ impl Client { /// # client.flush().unwrap(); /// ``` pub fn touch(&self, key: &str, expiration: u32) -> Result { - return self.get_connection(key).get_ref().touch(key, expiration); + return self.get_connection(key).get()?.touch(key, expiration); } /// Get all servers' statistics. @@ -387,8 +388,9 @@ impl Client { pub fn stats(&self) -> Result, MemcacheError> { let mut result: Vec<(String, HashMap)> = vec![]; for connection in self.connections.iter() { - let stats_info = connection.get_ref().stats()?; - let url = connection.url.to_string(); + let mut connection = connection.get()?; + let stats_info = connection.stats()?; + let url = connection.get_url(); result.push((url, stats_info)); } return Ok(result); diff --git a/src/connection.rs b/src/connection.rs index 63c230f..9df9838 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -1,7 +1,8 @@ use std::net::TcpStream; +use std::ops::{Deref, DerefMut}; #[cfg(unix)] use std::os::unix::net::UnixStream; -use std::sync::{Arc, Mutex, MutexGuard}; +use std::sync::Arc; use std::time::Duration; use url::Url; @@ -9,20 +10,65 @@ use error::MemcacheError; #[cfg(feature = "tls")] use openssl::ssl::{SslConnector, SslFiletype, SslMethod, SslVerifyMode}; -use protocol::{AsciiProtocol, BinaryProtocol, Protocol}; +use protocol::{AsciiProtocol, BinaryProtocol, Protocol, ProtocolTrait}; +use r2d2::ManageConnection; use stream::Stream; use stream::UdpStream; -/// a connection to the memcached server -#[derive(Clone)] +/// A connection to the memcached server pub struct Connection { - /// Taking a lock on a `Protocol` will never fail unless the `Mutex` - /// is poisoned(which implies another bug in protocol module) because - /// all `Client` methods unlock the mutex before returning control flow. - pub protocol: Arc>, + pub protocol: Protocol, pub url: Arc, } +impl DerefMut for Connection { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.protocol + } +} + +impl Deref for Connection { + type Target = Protocol; + fn deref(&self) -> &Self::Target { + &self.protocol + } +} + +pub(crate) struct ConnectionManager { + url: Url, +} + +impl ConnectionManager { + pub(crate) fn new(url: Url) -> Self { + Self { url } + } +} + +impl ManageConnection for ConnectionManager { + type Connection = Connection; + type Error = MemcacheError; + + fn connect(&self) -> Result { + let url = &self.url; + let mut connection = Connection::connect(url)?; + if url.has_authority() && !url.username().is_empty() && url.password().is_some() { + let username = url.username(); + let password = url.password().unwrap(); + connection.auth(username, password)?; + } + Ok(connection) + } + + fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Self::Error> { + conn.version().map(|_| ()) + } + + fn has_broken(&self, _conn: &mut Self::Connection) -> bool { + // TODO: fix this + false + } +} + enum Transport { Tcp(TcpOptions), Udp, @@ -163,8 +209,8 @@ fn tcp_stream(url: &Url, opts: &TcpOptions) -> Result } impl Connection { - pub(crate) fn get_ref(&self) -> MutexGuard { - self.protocol.lock().expect("won't fail") + pub(crate) fn get_url(&self) -> String { + self.url.to_string() } pub(crate) fn connect(url: &Url) -> Result { @@ -211,7 +257,7 @@ impl Connection { Ok(Connection { url: Arc::new(url.to_string()), - protocol: Arc::new(Mutex::new(protocol)), + protocol: protocol, }) } } diff --git a/src/error.rs b/src/error.rs index e1e285c..64533c7 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,3 +1,4 @@ +use r2d2; use std::borrow::Cow; use std::error; use std::fmt; @@ -237,6 +238,8 @@ pub enum MemcacheError { OpensslError(openssl::ssl::HandshakeError), /// Parse errors ParseError(ParseError), + /// ConnectionPool errors + PoolError(r2d2::Error), } impl fmt::Display for MemcacheError { @@ -250,6 +253,7 @@ impl fmt::Display for MemcacheError { MemcacheError::ClientError(ref err) => err.fmt(f), MemcacheError::ServerError(ref err) => err.fmt(f), MemcacheError::CommandError(ref err) => err.fmt(f), + MemcacheError::PoolError(ref err) => err.fmt(f), } } } @@ -266,6 +270,7 @@ impl error::Error for MemcacheError { MemcacheError::ClientError(_) => None, MemcacheError::ServerError(_) => None, MemcacheError::CommandError(_) => None, + MemcacheError::PoolError(ref p) => p.source(), } } } @@ -289,3 +294,9 @@ impl From> for MemcacheError { MemcacheError::OpensslError(err) } } + +impl From for MemcacheError { + fn from(err: r2d2::Error) -> MemcacheError { + MemcacheError::PoolError(err) + } +} diff --git a/src/lib.rs b/src/lib.rs index 20ffcc9..0168d89 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -69,6 +69,7 @@ extern crate byteorder; extern crate enum_dispatch; #[cfg(feature = "tls")] extern crate openssl; +extern crate r2d2; extern crate rand; extern crate url; @@ -81,6 +82,7 @@ mod value; pub use client::{Client, Connectable}; pub use error::{ClientError, CommandError, MemcacheError, ServerError}; +pub use r2d2::Error; pub use value::{FromMemcacheValue, FromMemcacheValueExt, ToMemcacheValue}; /// Create a memcached client instance and connect to memcached server.