Skip to content

Commit

Permalink
Merge 1158d7a into faed4e1
Browse files Browse the repository at this point in the history
  • Loading branch information
letmutx committed Feb 7, 2020
2 parents faed4e1 + 1158d7a commit 0278696
Show file tree
Hide file tree
Showing 3 changed files with 240 additions and 60 deletions.
118 changes: 69 additions & 49 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ use std::time::Duration;

use url::Url;

use connection::ConnectionManager;
use connection::{ConnCustomizer, ConnectionManager};
use error::MemcacheError;
use protocol::{Protocol, ProtocolTrait};
use protocol::ProtocolTrait;
use r2d2::Pool;
use stream::Stream;
use value::{FromMemcacheValueExt, ToMemcacheValue};
Expand Down Expand Up @@ -60,19 +60,70 @@ fn default_hash_function(key: &str) -> u64 {
return hasher.finish();
}

impl Client {
#[deprecated(since = "0.10.0", note = "please use `connect` instead")]
pub fn new<C: Connectable>(target: C) -> Result<Self, MemcacheError> {
return Self::connect(target);
pub struct ClientBuilder<C> {
target: C,
max_size: u32,
min_idle: Option<u32>,
max_lifetime: Option<Duration>,
read_timeout: Option<Duration>,
write_timeout: Option<Duration>,
pool_wait_timeout: Duration,
}

impl<C: Connectable> ClientBuilder<C> {
pub fn new(target: C) -> Self {
Self {
target,
max_size: 1,
min_idle: None,
max_lifetime: None,
read_timeout: None,
write_timeout: None,
pool_wait_timeout: Duration::new(30, 0),
}
}

pub fn with_min_idle(mut self, min_idle: Option<u32>) -> Self {
self.min_idle = min_idle;
self
}

pub fn with_max_lifetime(mut self, max_lifetime: Option<Duration>) -> Self {
self.max_lifetime = max_lifetime;
self
}

pub fn with_pool_size<C: Connectable>(target: C, size: u32) -> Result<Self, MemcacheError> {
let urls = target.get_urls();
pub fn with_max_size(mut self, size: u32) -> Self {
self.max_size = size;
self
}

pub fn with_read_timeout(mut self, read_timeout: Option<Duration>) -> Self {
self.read_timeout = read_timeout;
self
}

pub fn with_pool_wait_timeout(mut self, pool_wait_timeout: Duration) -> Self {
self.pool_wait_timeout = pool_wait_timeout;
self
}

pub fn with_write_timeout(mut self, write_timeout: Option<Duration>) -> Self {
self.write_timeout = write_timeout;
self
}

pub fn build(self) -> Result<Client, MemcacheError> {
let urls = self.target.get_urls();
let mut connections = vec![];
for url in urls {
let parsed = Url::parse(url.as_str())?;
let pool = r2d2::Pool::builder()
.max_size(size)
let pool = Pool::builder()
.max_size(self.max_size)
.min_idle(self.min_idle)
.max_lifetime(self.max_lifetime)
.connection_timeout(self.pool_wait_timeout)
.connection_customizer(Box::new(ConnCustomizer::new(self.read_timeout, self.write_timeout)))
.build(ConnectionManager::new(parsed))?;
connections.push(pool);
}
Expand All @@ -81,54 +132,23 @@ impl Client {
hash_function: default_hash_function,
})
}
}

impl Client {
#[deprecated(since = "0.10.0", note = "please use `connect` instead")]
pub fn new<C: Connectable>(target: C) -> Result<Self, MemcacheError> {
return Self::connect(target);
}

pub fn connect<C: Connectable>(target: C) -> Result<Self, MemcacheError> {
Self::with_pool_size(target, 1)
ClientBuilder::new(target).with_max_size(1).build()
}

fn get_connection(&self, key: &str) -> Pool<ConnectionManager> {
let connections_count = self.connections.len();
return self.connections[(self.hash_function)(key) as usize % connections_count].clone();
}

/// Set the socket read timeout for TCP connections.
///
/// Example:
///
/// ```rust
/// let client = memcache::Client::connect("memcache://localhost:12345").unwrap();
/// client.set_read_timeout(Some(::std::time::Duration::from_secs(3))).unwrap();
/// ```
pub fn set_read_timeout(&self, timeout: Option<Duration>) -> Result<(), MemcacheError> {
for conn in self.connections.iter() {
let mut conn = conn.get()?;
match **conn {
Protocol::Ascii(ref mut protocol) => protocol.stream().set_read_timeout(timeout)?,
Protocol::Binary(ref mut protocol) => protocol.stream.set_read_timeout(timeout)?,
}
}
Ok(())
}

/// Set the socket write timeout for TCP connections.
///
/// Example:
///
/// ```rust
/// let client = memcache::Client::connect("memcache://localhost:12345?protocol=ascii").unwrap();
/// client.set_write_timeout(Some(::std::time::Duration::from_secs(3))).unwrap();
/// ```
pub fn set_write_timeout(&self, timeout: Option<Duration>) -> Result<(), MemcacheError> {
for conn in self.connections.iter() {
let mut conn = conn.get()?;
match **conn {
Protocol::Ascii(ref mut protocol) => protocol.stream().set_read_timeout(timeout)?,
Protocol::Binary(ref mut protocol) => protocol.stream.set_write_timeout(timeout)?,
}
}
Ok(())
}

/// Get the memcached server version.
///
/// Example:
Expand Down
170 changes: 159 additions & 11 deletions src/connection.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use client::Stats;
use std::collections::HashMap;
use std::net::TcpStream;
use std::ops::{Deref, DerefMut};
#[cfg(unix)]
use std::os::unix::net::UnixStream;
use std::sync::Arc;
Expand All @@ -14,23 +15,35 @@ use protocol::{AsciiProtocol, BinaryProtocol, Protocol, ProtocolTrait};
use r2d2::ManageConnection;
use stream::Stream;
use stream::UdpStream;
use value::{FromMemcacheValueExt, ToMemcacheValue};

/// A connection to the memcached server
pub struct Connection {
pub protocol: Protocol,
is_dirty: bool,
pub url: Arc<String>,
}

impl DerefMut for Connection {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.protocol
#[derive(Debug)]
pub(crate) struct ConnCustomizer {
read_timeout: Option<Duration>,
write_timeout: Option<Duration>,
}

impl r2d2::CustomizeConnection<Connection, MemcacheError> for ConnCustomizer {
fn on_acquire(&self, conn: &mut Connection) -> Result<(), MemcacheError> {
conn.set_read_timeout(self.read_timeout)?;
conn.set_write_timeout(self.write_timeout)?;
Ok(())
}
}

impl Deref for Connection {
type Target = Protocol;
fn deref(&self) -> &Self::Target {
&self.protocol
impl ConnCustomizer {
pub(crate) fn new(read_timeout: Option<Duration>, write_timeout: Option<Duration>) -> Self {
Self {
read_timeout,
write_timeout,
}
}
}

Expand Down Expand Up @@ -63,9 +76,8 @@ impl ManageConnection for ConnectionManager {
conn.version().map(|_| ())
}

fn has_broken(&self, _conn: &mut Self::Connection) -> bool {
// TODO: fix this
false
fn has_broken(&self, conn: &mut Self::Connection) -> bool {
conn.is_dirty
}
}

Expand Down Expand Up @@ -208,7 +220,142 @@ fn tcp_stream(url: &Url, opts: &TcpOptions) -> Result<TcpStream, MemcacheError>
Ok(tcp_stream)
}

impl ProtocolTrait for Connection {
fn auth(&mut self, username: &str, password: &str) -> Result<(), MemcacheError> {
self.protocol.auth(username, password).map_err(|e| {
self.is_dirty = !e.is_recoverable();
e
})
}

fn version(&mut self) -> Result<String, MemcacheError> {
self.protocol.version().map_err(|e| {
self.is_dirty = !e.is_recoverable();
e
})
}

fn flush(&mut self) -> Result<(), MemcacheError> {
self.protocol.flush().map_err(|e| {
self.is_dirty = !e.is_recoverable();
e
})
}
fn flush_with_delay(&mut self, delay: u32) -> Result<(), MemcacheError> {
self.protocol.flush_with_delay(delay).map_err(|e| {
self.is_dirty = !e.is_recoverable();
e
})
}

fn get<V: FromMemcacheValueExt>(&mut self, key: &str) -> Result<Option<V>, MemcacheError> {
self.protocol.get(key).map_err(|e| {
self.is_dirty = !e.is_recoverable();
e
})
}
fn gets<V: FromMemcacheValueExt>(&mut self, keys: &[&str]) -> Result<HashMap<String, V>, MemcacheError> {
self.protocol.gets(keys).map_err(|e| {
self.is_dirty = !e.is_recoverable();
e
})
}
fn set<V: ToMemcacheValue<Stream>>(&mut self, key: &str, value: V, expiration: u32) -> Result<(), MemcacheError> {
self.protocol.set(key, value, expiration).map_err(|e| {
self.is_dirty = !e.is_recoverable();
e
})
}
fn cas<V: ToMemcacheValue<Stream>>(
&mut self,
key: &str,
value: V,
expiration: u32,
cas: u64,
) -> Result<bool, MemcacheError> {
self.protocol.cas(key, value, expiration, cas).map_err(|e| {
self.is_dirty = !e.is_recoverable();
e
})
}
fn add<V: ToMemcacheValue<Stream>>(&mut self, key: &str, value: V, expiration: u32) -> Result<(), MemcacheError> {
self.protocol.add(key, value, expiration).map_err(|e| {
self.is_dirty = !e.is_recoverable();
e
})
}
fn replace<V: ToMemcacheValue<Stream>>(
&mut self,
key: &str,
value: V,
expiration: u32,
) -> Result<(), MemcacheError> {
self.protocol.replace(key, value, expiration).map_err(|e| {
self.is_dirty = !e.is_recoverable();
e
})
}
fn append<V: ToMemcacheValue<Stream>>(&mut self, key: &str, value: V) -> Result<(), MemcacheError> {
self.protocol.append(key, value).map_err(|e| {
self.is_dirty = !e.is_recoverable();
e
})
}
fn prepend<V: ToMemcacheValue<Stream>>(&mut self, key: &str, value: V) -> Result<(), MemcacheError> {
self.protocol.prepend(key, value).map_err(|e| {
self.is_dirty = !e.is_recoverable();
e
})
}
fn delete(&mut self, key: &str) -> Result<bool, MemcacheError> {
self.protocol.delete(key).map_err(|e| {
self.is_dirty = !e.is_recoverable();
e
})
}
fn increment(&mut self, key: &str, amount: u64) -> Result<u64, MemcacheError> {
self.protocol.increment(key, amount).map_err(|e| {
self.is_dirty = !e.is_recoverable();
e
})
}
fn decrement(&mut self, key: &str, amount: u64) -> Result<u64, MemcacheError> {
self.protocol.decrement(key, amount).map_err(|e| {
self.is_dirty = !e.is_recoverable();
e
})
}
fn touch(&mut self, key: &str, expiration: u32) -> Result<bool, MemcacheError> {
self.protocol.touch(key, expiration).map_err(|e| {
self.is_dirty = !e.is_recoverable();
e
})
}
fn stats(&mut self) -> Result<Stats, MemcacheError> {
self.protocol.stats().map_err(|e| {
self.is_dirty = !e.is_recoverable();
e
})
}
}

impl Connection {
pub(crate) fn set_write_timeout(&mut self, timeout: Option<Duration>) -> Result<(), MemcacheError> {
match self.protocol {
Protocol::Ascii(ref mut protocol) => protocol.stream().set_write_timeout(timeout)?,
Protocol::Binary(ref mut protocol) => protocol.stream.set_write_timeout(timeout)?,
}
Ok(())
}

pub(crate) fn set_read_timeout(&mut self, timeout: Option<Duration>) -> Result<(), MemcacheError> {
match self.protocol {
Protocol::Ascii(ref mut protocol) => protocol.stream().set_read_timeout(timeout)?,
Protocol::Binary(ref mut protocol) => protocol.stream.set_read_timeout(timeout)?,
}
Ok(())
}

pub(crate) fn get_url(&self) -> String {
self.url.to_string()
}
Expand Down Expand Up @@ -257,6 +404,7 @@ impl Connection {

Ok(Connection {
url: Arc::new(url.to_string()),
is_dirty: false,
protocol: protocol,
})
}
Expand Down
12 changes: 12 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,18 @@ pub enum CommandError {
}

impl MemcacheError {
pub(crate) fn is_recoverable(&self) -> bool {
match self {
MemcacheError::ClientError(ref err) if err != &ClientError::KeyTooLong => false,
MemcacheError::ServerError(_) => false,
MemcacheError::IOError(_) => false,
MemcacheError::ParseError(_) => false,
#[cfg(feature = "tls")]
MemcacheError::OpensslError(_) => false,
_ => true,
}
}

pub(crate) fn try_from(s: &str) -> Result<&str, MemcacheError> {
if s == "ERROR\r\n" {
Err(CommandError::InvalidCommand)?
Expand Down

0 comments on commit 0278696

Please sign in to comment.