Skip to content

Commit

Permalink
Merge 4330284 into dced588
Browse files Browse the repository at this point in the history
  • Loading branch information
letmutx committed Feb 5, 2020
2 parents dced588 + 4330284 commit b998b98
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 50 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ keywords = ["memcache", "memcached", "driver", "cache", "database"]

[features]
default = ["tls"]

tls = ["openssl"]

[dependencies]
Expand All @@ -19,3 +18,4 @@ url = "2.1.1"
rand = "0.7"
enum_dispatch = "0.2"
openssl = { version = "^0.10", optional = true }
r2d2 = "0.8.8"
78 changes: 40 additions & 38 deletions src/client.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
use std::collections::hash_map::DefaultHasher;
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use std::time::Duration;

use url::Url;

use connection::Connection;
use connection::ConnectionManager;
use error::MemcacheError;
use protocol::{Protocol, ProtocolTrait};
use r2d2::Pool;
use stream::Stream;
use value::{FromMemcacheValueExt, ToMemcacheValue};

Expand Down Expand Up @@ -48,7 +48,7 @@ impl Connectable for Vec<&str> {

#[derive(Clone)]
pub struct Client {
connections: Arc<Vec<Connection>>,
connections: Vec<Pool<ConnectionManager>>,
pub hash_function: fn(&str) -> u64,
}

Expand All @@ -66,28 +66,27 @@ impl Client {
return Self::connect(target);
}

pub fn connect<C: Connectable>(target: C) -> Result<Self, MemcacheError> {
pub fn with_pool_size<C: Connectable>(target: C, size: u32) -> Result<Self, MemcacheError> {
let urls = target.get_urls();
let mut connections = vec![];
for url in urls {
let parsed = Url::parse(url.as_str())?;
let connection = Connection::connect(&parsed)?;

if parsed.has_authority() && parsed.username() != "" && parsed.password().is_some() {
let username = parsed.username();
let password = parsed.password().unwrap();
connection.get_ref().auth(username, password)?;
}

connections.push(connection);
let pool = r2d2::Pool::builder()
.max_size(size)
.build(ConnectionManager::new(parsed))?;
connections.push(pool);
}
return Ok(Client {
connections: Arc::new(connections),
Ok(Client {
connections,
hash_function: default_hash_function,
});
})
}

pub fn connect<C: Connectable>(target: C) -> Result<Self, MemcacheError> {
Self::with_pool_size(target, 1)
}

fn get_connection(&self, key: &str) -> Connection {
fn get_connection(&self, key: &str) -> Pool<ConnectionManager> {
let connections_count = self.connections.len();
return self.connections[(self.hash_function)(key) as usize % connections_count].clone();
}
Expand All @@ -102,7 +101,8 @@ impl Client {
/// ```
pub fn set_read_timeout(&self, timeout: Option<Duration>) -> Result<(), MemcacheError> {
for conn in self.connections.iter() {
match *conn.get_ref() {
let mut conn = conn.get()?;
match **conn {
Protocol::Ascii(ref mut protocol) => protocol.stream().set_read_timeout(timeout)?,
Protocol::Binary(ref mut protocol) => protocol.stream.set_read_timeout(timeout)?,
}
Expand All @@ -120,7 +120,8 @@ impl Client {
/// ```
pub fn set_write_timeout(&self, timeout: Option<Duration>) -> Result<(), MemcacheError> {
for conn in self.connections.iter() {
match *conn.get_ref() {
let mut conn = conn.get()?;
match **conn {
Protocol::Ascii(ref mut protocol) => protocol.stream().set_read_timeout(timeout)?,
Protocol::Binary(ref mut protocol) => protocol.stream.set_write_timeout(timeout)?,
}
Expand All @@ -139,9 +140,9 @@ impl Client {
pub fn version(&self) -> Result<Vec<(String, String)>, MemcacheError> {
let mut result = Vec::with_capacity(self.connections.len());
for connection in self.connections.iter() {
let mut protocol = connection.get_ref();
let url = connection.url.to_string();
result.push((url, protocol.version()?));
let mut connection = connection.get()?;
let url = connection.get_url();
result.push((url, connection.version()?));
}
Ok(result)
}
Expand All @@ -156,7 +157,7 @@ impl Client {
/// ```
pub fn flush(&self) -> Result<(), MemcacheError> {
for connection in self.connections.iter() {
connection.get_ref().flush()?;
connection.get()?.flush()?;
}
return Ok(());
}
Expand All @@ -171,7 +172,7 @@ impl Client {
/// ```
pub fn flush_with_delay(&self, delay: u32) -> Result<(), MemcacheError> {
for connection in self.connections.iter() {
connection.get_ref().flush_with_delay(delay)?;
connection.get()?.flush_with_delay(delay)?;
}
return Ok(());
}
Expand All @@ -185,7 +186,7 @@ impl Client {
/// let _: Option<String> = client.get("foo").unwrap();
/// ```
pub fn get<V: FromMemcacheValueExt>(&self, key: &str) -> Result<Option<V>, MemcacheError> {
return self.get_connection(key).get_ref().get(key);
return self.get_connection(key).get()?.get(key);
}

/// Get multiple keys from memcached server. Using this function instead of calling `get` multiple times can reduce netwark workloads.
Expand All @@ -211,7 +212,7 @@ impl Client {
}
for (&connection_index, keys) in con_keys.iter() {
let connection = self.connections[connection_index].clone();
result.extend(connection.get_ref().gets(keys)?);
result.extend(connection.get()?.gets(keys)?);
}
return Ok(result);
}
Expand All @@ -226,7 +227,7 @@ impl Client {
/// # client.flush().unwrap();
/// ```
pub fn set<V: ToMemcacheValue<Stream>>(&self, key: &str, value: V, expiration: u32) -> Result<(), MemcacheError> {
return self.get_connection(key).get_ref().set(key, value, expiration);
return self.get_connection(key).get()?.set(key, value, expiration);
}

/// Compare and swap a key with the associate value into memcached server with expiration seconds.
Expand All @@ -251,7 +252,7 @@ impl Client {
expiration: u32,
cas_id: u64,
) -> Result<bool, MemcacheError> {
self.get_connection(key).get_ref().cas(key, value, expiration, cas_id)
self.get_connection(key).get()?.cas(key, value, expiration, cas_id)
}

/// Add a key with associate value into memcached server with expiration seconds.
Expand All @@ -266,7 +267,7 @@ impl Client {
/// # client.flush().unwrap();
/// ```
pub fn add<V: ToMemcacheValue<Stream>>(&self, key: &str, value: V, expiration: u32) -> Result<(), MemcacheError> {
return self.get_connection(key).get_ref().add(key, value, expiration);
return self.get_connection(key).get()?.add(key, value, expiration);
}

/// Replace a key with associate value into memcached server with expiration seconds.
Expand All @@ -286,7 +287,7 @@ impl Client {
value: V,
expiration: u32,
) -> Result<(), MemcacheError> {
return self.get_connection(key).get_ref().replace(key, value, expiration);
return self.get_connection(key).get()?.replace(key, value, expiration);
}

/// Append value to the key.
Expand All @@ -303,7 +304,7 @@ impl Client {
/// # client.flush().unwrap();
/// ```
pub fn append<V: ToMemcacheValue<Stream>>(&self, key: &str, value: V) -> Result<(), MemcacheError> {
return self.get_connection(key).get_ref().append(key, value);
return self.get_connection(key).get()?.append(key, value);
}

/// Prepend value to the key.
Expand All @@ -320,7 +321,7 @@ impl Client {
/// # client.flush().unwrap();
/// ```
pub fn prepend<V: ToMemcacheValue<Stream>>(&self, key: &str, value: V) -> Result<(), MemcacheError> {
return self.get_connection(key).get_ref().prepend(key, value);
return self.get_connection(key).get()?.prepend(key, value);
}

/// Delete a key from memcached server.
Expand All @@ -333,7 +334,7 @@ impl Client {
/// # client.flush().unwrap();
/// ```
pub fn delete(&self, key: &str) -> Result<bool, MemcacheError> {
return self.get_connection(key).get_ref().delete(key);
return self.get_connection(key).get()?.delete(key);
}

/// Increment the value with amount.
Expand All @@ -346,7 +347,7 @@ impl Client {
/// # client.flush().unwrap();
/// ```
pub fn increment(&self, key: &str, amount: u64) -> Result<u64, MemcacheError> {
return self.get_connection(key).get_ref().increment(key, amount);
return self.get_connection(key).get()?.increment(key, amount);
}

/// Decrement the value with amount.
Expand All @@ -359,7 +360,7 @@ impl Client {
/// # client.flush().unwrap();
/// ```
pub fn decrement(&self, key: &str, amount: u64) -> Result<u64, MemcacheError> {
return self.get_connection(key).get_ref().decrement(key, amount);
return self.get_connection(key).get()?.decrement(key, amount);
}

/// Set a new expiration time for a exist key.
Expand All @@ -374,7 +375,7 @@ impl Client {
/// # client.flush().unwrap();
/// ```
pub fn touch(&self, key: &str, expiration: u32) -> Result<bool, MemcacheError> {
return self.get_connection(key).get_ref().touch(key, expiration);
return self.get_connection(key).get()?.touch(key, expiration);
}

/// Get all servers' statistics.
Expand All @@ -387,8 +388,9 @@ impl Client {
pub fn stats(&self) -> Result<Vec<(String, Stats)>, MemcacheError> {
let mut result: Vec<(String, HashMap<String, String>)> = vec![];
for connection in self.connections.iter() {
let stats_info = connection.get_ref().stats()?;
let url = connection.url.to_string();
let mut connection = connection.get()?;
let stats_info = connection.stats()?;
let url = connection.get_url();
result.push((url, stats_info));
}
return Ok(result);
Expand Down
68 changes: 57 additions & 11 deletions src/connection.rs
Original file line number Diff line number Diff line change
@@ -1,28 +1,74 @@
use std::net::TcpStream;
use std::ops::{Deref, DerefMut};
#[cfg(unix)]
use std::os::unix::net::UnixStream;
use std::sync::{Arc, Mutex, MutexGuard};
use std::sync::Arc;
use std::time::Duration;
use url::Url;

use error::MemcacheError;

#[cfg(feature = "tls")]
use openssl::ssl::{SslConnector, SslFiletype, SslMethod, SslVerifyMode};
use protocol::{AsciiProtocol, BinaryProtocol, Protocol};
use protocol::{AsciiProtocol, BinaryProtocol, Protocol, ProtocolTrait};
use r2d2::ManageConnection;
use stream::Stream;
use stream::UdpStream;

/// a connection to the memcached server
#[derive(Clone)]
/// A connection to the memcached server
pub struct Connection {
/// Taking a lock on a `Protocol` will never fail unless the `Mutex`
/// is poisoned(which implies another bug in protocol module) because
/// all `Client` methods unlock the mutex before returning control flow.
pub protocol: Arc<Mutex<Protocol>>,
pub protocol: Protocol,
pub url: Arc<String>,
}

impl DerefMut for Connection {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.protocol
}
}

impl Deref for Connection {
type Target = Protocol;
fn deref(&self) -> &Self::Target {
&self.protocol
}
}

pub(crate) struct ConnectionManager {
url: Url,
}

impl ConnectionManager {
pub(crate) fn new(url: Url) -> Self {
Self { url }
}
}

impl ManageConnection for ConnectionManager {
type Connection = Connection;
type Error = MemcacheError;

fn connect(&self) -> Result<Self::Connection, Self::Error> {
let url = &self.url;
let mut connection = Connection::connect(url)?;
if url.has_authority() && !url.username().is_empty() && url.password().is_some() {
let username = url.username();
let password = url.password().unwrap();
connection.auth(username, password)?;
}
Ok(connection)
}

fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Self::Error> {
conn.version().map(|_| ())
}

fn has_broken(&self, _conn: &mut Self::Connection) -> bool {
// TODO: fix this
false
}
}

enum Transport {
Tcp(TcpOptions),
Udp,
Expand Down Expand Up @@ -163,8 +209,8 @@ fn tcp_stream(url: &Url, opts: &TcpOptions) -> Result<TcpStream, MemcacheError>
}

impl Connection {
pub(crate) fn get_ref(&self) -> MutexGuard<Protocol> {
self.protocol.lock().expect("won't fail")
pub(crate) fn get_url(&self) -> String {
self.url.to_string()
}

pub(crate) fn connect(url: &Url) -> Result<Self, MemcacheError> {
Expand Down Expand Up @@ -211,7 +257,7 @@ impl Connection {

Ok(Connection {
url: Arc::new(url.to_string()),
protocol: Arc::new(Mutex::new(protocol)),
protocol: protocol,
})
}
}
Expand Down
11 changes: 11 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use r2d2;
use std::borrow::Cow;
use std::error;
use std::fmt;
Expand Down Expand Up @@ -237,6 +238,8 @@ pub enum MemcacheError {
OpensslError(openssl::ssl::HandshakeError<std::net::TcpStream>),
/// Parse errors
ParseError(ParseError),
/// ConnectionPool errors
PoolError(r2d2::Error),
}

impl fmt::Display for MemcacheError {
Expand All @@ -250,6 +253,7 @@ impl fmt::Display for MemcacheError {
MemcacheError::ClientError(ref err) => err.fmt(f),
MemcacheError::ServerError(ref err) => err.fmt(f),
MemcacheError::CommandError(ref err) => err.fmt(f),
MemcacheError::PoolError(ref err) => err.fmt(f),
}
}
}
Expand All @@ -266,6 +270,7 @@ impl error::Error for MemcacheError {
MemcacheError::ClientError(_) => None,
MemcacheError::ServerError(_) => None,
MemcacheError::CommandError(_) => None,
MemcacheError::PoolError(ref p) => p.source(),
}
}
}
Expand All @@ -289,3 +294,9 @@ impl From<openssl::ssl::HandshakeError<std::net::TcpStream>> for MemcacheError {
MemcacheError::OpensslError(err)
}
}

impl From<r2d2::Error> for MemcacheError {
fn from(err: r2d2::Error) -> MemcacheError {
MemcacheError::PoolError(err)
}
}

0 comments on commit b998b98

Please sign in to comment.