From 9dabfadce894419abeda80bf42831a24b3afd66a Mon Sep 17 00:00:00 2001 From: Anton Lazarev Date: Sat, 18 Apr 2020 23:13:44 -0400 Subject: [PATCH] :recycle: impl ProtocolTrait for AsciiProtocol and BinaryProtocol --- src/protocol/ascii.rs | 317 ++++++++++++++++++++--------------------- src/protocol/binary.rs | 131 ++++++++--------- 2 files changed, 217 insertions(+), 231 deletions(-) diff --git a/src/protocol/ascii.rs b/src/protocol/ascii.rs index 346c11a..3d1e6c1 100644 --- a/src/protocol/ascii.rs +++ b/src/protocol/ascii.rs @@ -2,6 +2,7 @@ use std::collections::HashMap; use std::fmt; use std::io::{Read, Write}; +use super::ProtocolTrait; use client::Stats; use error::{ClientError, CommandError, MemcacheError, ServerError}; use std::borrow::Cow; @@ -125,82 +126,12 @@ pub struct AsciiProtocol { reader: CappedLineReader, } -impl AsciiProtocol { - pub(crate) fn new(stream: Stream) -> Self { - Self { - reader: CappedLineReader::new(stream), - } - } - - pub(crate) fn stream(&mut self) -> &mut Stream { - self.reader.get_mut() - } - - pub(super) fn auth(&mut self, username: &str, password: &str) -> Result<(), MemcacheError> { +impl ProtocolTrait for AsciiProtocol { + fn auth(&mut self, username: &str, password: &str) -> Result<(), MemcacheError> { return self.set("auth", format!("{} {}", username, password), 0); } - fn store>( - &mut self, - command: StoreCommand, - key: &str, - value: V, - options: &Options, - ) -> Result { - if command == StoreCommand::Cas { - if options.cas.is_none() { - Err(ClientError::Error(Cow::Borrowed( - "cas_id should be present when using cas command", - )))?; - } - } - let noreply = if options.noreply { " noreply" } else { "" }; - if options.cas.is_some() { - write!( - self.reader.get_mut(), - "{command} {key} {flags} {exptime} {vlen} {cas}{noreply}\r\n", - command = command, - key = key, - flags = value.get_flags(), - exptime = options.exptime, - vlen = value.get_length(), - cas = options.cas.unwrap(), - noreply = noreply - )?; - } else { - write!( - self.reader.get_mut(), - "{command} {key} {flags} {exptime} {vlen}{noreply}\r\n", - command = command, - key = key, - flags = value.get_flags(), - exptime = options.exptime, - vlen = value.get_length(), - noreply = noreply - )?; - } - - value.write_to(self.reader.get_mut())?; - self.reader.get_mut().write(b"\r\n")?; - self.reader.get_mut().flush()?; - - if options.noreply { - return Ok(true); - } - - self.reader.read_line(|response| { - let response = MemcacheError::try_from(response)?; - match response { - "STORED\r\n" => Ok(true), - "NOT_STORED\r\n" => Ok(false), - "EXISTS\r\n" => Err(CommandError::KeyExists)?, - "NOT_FOUND\r\n" => Err(CommandError::KeyNotFound)?, - response => Err(ServerError::BadResponse(Cow::Owned(response.into())))?, - } - }) - } - - pub(super) fn version(&mut self) -> Result { + fn version(&mut self) -> Result { self.reader.get_mut().write(b"version\r\n")?; self.reader.get_mut().flush()?; self.reader.read_line(|response| { @@ -213,29 +144,18 @@ impl AsciiProtocol { }) } - fn parse_ok_response(&mut self) -> Result<(), MemcacheError> { - self.reader.read_line(|response| { - let response = MemcacheError::try_from(response)?; - if response == "OK\r\n" { - Ok(()) - } else { - Err(ServerError::BadResponse(Cow::Owned(response.into())))? - } - }) - } - - pub(super) fn flush(&mut self) -> Result<(), MemcacheError> { + fn flush(&mut self) -> Result<(), MemcacheError> { write!(self.reader.get_mut(), "flush_all\r\n")?; self.parse_ok_response() } - pub(super) fn flush_with_delay(&mut self, delay: u32) -> Result<(), MemcacheError> { + fn flush_with_delay(&mut self, delay: u32) -> Result<(), MemcacheError> { write!(self.reader.get_mut(), "flush_all {}\r\n", delay)?; self.reader.get_mut().flush()?; self.parse_ok_response() } - pub(super) fn get(&mut self, key: &str) -> Result, MemcacheError> { + fn get(&mut self, key: &str) -> Result, MemcacheError> { write!(self.reader.get_mut(), "get {}\r\n", key)?; if let Some((k, v)) = self.parse_get_response(false)? { @@ -253,53 +173,7 @@ impl AsciiProtocol { } } - fn parse_get_response( - &mut self, - has_cas: bool, - ) -> Result, MemcacheError> { - let result = self.reader.read_line(|buf| { - let buf = MemcacheError::try_from(buf)?; - if buf == END { - return Ok(None); - } - if !buf.starts_with("VALUE") { - return Err(ServerError::BadResponse(Cow::Owned(buf.into())))?; - } - let mut header = buf.trim_end_matches("\r\n").split(" "); - let mut next_or_err = || { - header - .next() - .ok_or_else(|| ServerError::BadResponse(Cow::Owned(buf.into()))) - }; - let _ = next_or_err()?; - let key = next_or_err()?; - let flags: u32 = next_or_err()?.parse()?; - let length: usize = next_or_err()?.parse()?; - let cas: Option = if has_cas { Some(next_or_err()?.parse()?) } else { None }; - if header.next().is_some() { - return Err(ServerError::BadResponse(Cow::Owned(buf.into())))?; - } - Ok(Some((key.to_string(), flags, length, cas))) - })?; - match result { - Some((key, flags, length, cas)) => { - let mut value = vec![0u8; length + 2]; - self.reader.read_exact(value.as_mut_slice())?; - if &value[length..] != b"\r\n" { - return Err(ServerError::BadResponse(Cow::Owned(String::from_utf8(value)?)))?; - } - // remove the trailing \r\n - value.pop(); - value.pop(); - value.shrink_to_fit(); - let value = FromMemcacheValueExt::from_memcache_value(value, flags, cas)?; - Ok(Some((key.to_string(), value))) - } - None => Ok(None), - } - } - - pub(super) fn gets(&mut self, keys: &[&str]) -> Result, MemcacheError> { + fn gets(&mut self, keys: &[&str]) -> Result, MemcacheError> { write!(self.reader.get_mut(), "gets {}\r\n", keys.join(" "))?; let mut result: HashMap = HashMap::with_capacity(keys.len()); @@ -316,7 +190,7 @@ impl AsciiProtocol { Err(ServerError::BadResponse(Cow::Borrowed("Expected end of gets response")))? } - pub(super) fn cas>( + fn cas>( &mut self, key: &str, value: V, @@ -337,12 +211,7 @@ impl AsciiProtocol { } } - pub(super) fn set>( - &mut self, - key: &str, - value: V, - expiration: u32, - ) -> Result<(), MemcacheError> { + fn set>(&mut self, key: &str, value: V, expiration: u32) -> Result<(), MemcacheError> { let options = Options { exptime: expiration, ..Default::default() @@ -350,12 +219,7 @@ impl AsciiProtocol { self.store(StoreCommand::Set, key, value, &options).map(|_| ()) } - pub(super) fn add>( - &mut self, - key: &str, - value: V, - expiration: u32, - ) -> Result<(), MemcacheError> { + fn add>(&mut self, key: &str, value: V, expiration: u32) -> Result<(), MemcacheError> { let options = Options { exptime: expiration, ..Default::default() @@ -363,7 +227,7 @@ impl AsciiProtocol { self.store(StoreCommand::Add, key, value, &options).map(|_| ()) } - pub(super) fn replace>( + fn replace>( &mut self, key: &str, value: V, @@ -376,17 +240,17 @@ impl AsciiProtocol { self.store(StoreCommand::Replace, key, value, &options).map(|_| ()) } - pub(super) fn append>(&mut self, key: &str, value: V) -> Result<(), MemcacheError> { + fn append>(&mut self, key: &str, value: V) -> Result<(), MemcacheError> { self.store(StoreCommand::Append, key, value, &Default::default()) .map(|_| ()) } - pub(super) fn prepend>(&mut self, key: &str, value: V) -> Result<(), MemcacheError> { + fn prepend>(&mut self, key: &str, value: V) -> Result<(), MemcacheError> { self.store(StoreCommand::Prepend, key, value, &Default::default()) .map(|_| ()) } - pub(super) fn delete(&mut self, key: &str) -> Result { + fn delete(&mut self, key: &str) -> Result { write!(self.reader.get_mut(), "delete {}\r\n", key)?; self.reader.get_mut().flush()?; self.reader @@ -403,24 +267,17 @@ impl AsciiProtocol { }) } - fn parse_u64_response(&mut self) -> Result { - self.reader.read_line(|response| { - let s = MemcacheError::try_from(response)?; - Ok(s.trim_end_matches("\r\n").parse::()?) - }) - } - - pub(super) fn increment(&mut self, key: &str, amount: u64) -> Result { + fn increment(&mut self, key: &str, amount: u64) -> Result { write!(self.reader.get_mut(), "incr {} {}\r\n", key, amount)?; self.parse_u64_response() } - pub(super) fn decrement(&mut self, key: &str, amount: u64) -> Result { + fn decrement(&mut self, key: &str, amount: u64) -> Result { write!(self.reader.get_mut(), "decr {} {}\r\n", key, amount)?; self.parse_u64_response() } - pub(super) fn touch(&mut self, key: &str, expiration: u32) -> Result { + fn touch(&mut self, key: &str, expiration: u32) -> Result { write!(self.reader.get_mut(), "touch {} {}\r\n", key, expiration)?; self.reader.get_mut().flush()?; self.reader @@ -437,7 +294,7 @@ impl AsciiProtocol { }) } - pub(super) fn stats(&mut self) -> Result { + fn stats(&mut self) -> Result { self.reader.get_mut().write(b"stats\r\n")?; self.reader.get_mut().flush()?; @@ -473,3 +330,139 @@ impl AsciiProtocol { } } } + +impl AsciiProtocol { + pub(crate) fn new(stream: Stream) -> Self { + Self { + reader: CappedLineReader::new(stream), + } + } + + pub(crate) fn stream(&mut self) -> &mut Stream { + self.reader.get_mut() + } + + fn store>( + &mut self, + command: StoreCommand, + key: &str, + value: V, + options: &Options, + ) -> Result { + if command == StoreCommand::Cas { + if options.cas.is_none() { + Err(ClientError::Error(Cow::Borrowed( + "cas_id should be present when using cas command", + )))?; + } + } + let noreply = if options.noreply { " noreply" } else { "" }; + if options.cas.is_some() { + write!( + self.reader.get_mut(), + "{command} {key} {flags} {exptime} {vlen} {cas}{noreply}\r\n", + command = command, + key = key, + flags = value.get_flags(), + exptime = options.exptime, + vlen = value.get_length(), + cas = options.cas.unwrap(), + noreply = noreply + )?; + } else { + write!( + self.reader.get_mut(), + "{command} {key} {flags} {exptime} {vlen}{noreply}\r\n", + command = command, + key = key, + flags = value.get_flags(), + exptime = options.exptime, + vlen = value.get_length(), + noreply = noreply + )?; + } + + value.write_to(self.reader.get_mut())?; + self.reader.get_mut().write(b"\r\n")?; + self.reader.get_mut().flush()?; + + if options.noreply { + return Ok(true); + } + + self.reader.read_line(|response| { + let response = MemcacheError::try_from(response)?; + match response { + "STORED\r\n" => Ok(true), + "NOT_STORED\r\n" => Ok(false), + "EXISTS\r\n" => Err(CommandError::KeyExists)?, + "NOT_FOUND\r\n" => Err(CommandError::KeyNotFound)?, + response => Err(ServerError::BadResponse(Cow::Owned(response.into())))?, + } + }) + } + + fn parse_ok_response(&mut self) -> Result<(), MemcacheError> { + self.reader.read_line(|response| { + let response = MemcacheError::try_from(response)?; + if response == "OK\r\n" { + Ok(()) + } else { + Err(ServerError::BadResponse(Cow::Owned(response.into())))? + } + }) + } + + fn parse_get_response( + &mut self, + has_cas: bool, + ) -> Result, MemcacheError> { + let result = self.reader.read_line(|buf| { + let buf = MemcacheError::try_from(buf)?; + if buf == END { + return Ok(None); + } + if !buf.starts_with("VALUE") { + return Err(ServerError::BadResponse(Cow::Owned(buf.into())))?; + } + let mut header = buf.trim_end_matches("\r\n").split(" "); + let mut next_or_err = || { + header + .next() + .ok_or_else(|| ServerError::BadResponse(Cow::Owned(buf.into()))) + }; + let _ = next_or_err()?; + let key = next_or_err()?; + let flags: u32 = next_or_err()?.parse()?; + let length: usize = next_or_err()?.parse()?; + let cas: Option = if has_cas { Some(next_or_err()?.parse()?) } else { None }; + if header.next().is_some() { + return Err(ServerError::BadResponse(Cow::Owned(buf.into())))?; + } + Ok(Some((key.to_string(), flags, length, cas))) + })?; + match result { + Some((key, flags, length, cas)) => { + let mut value = vec![0u8; length + 2]; + self.reader.read_exact(value.as_mut_slice())?; + if &value[length..] != b"\r\n" { + return Err(ServerError::BadResponse(Cow::Owned(String::from_utf8(value)?)))?; + } + // remove the trailing \r\n + value.pop(); + value.pop(); + value.shrink_to_fit(); + let value = FromMemcacheValueExt::from_memcache_value(value, flags, cas)?; + Ok(Some((key.to_string(), value))) + } + None => Ok(None), + } + } + + fn parse_u64_response(&mut self) -> Result { + self.reader.read_line(|response| { + let s = MemcacheError::try_from(response)?; + Ok(s.trim_end_matches("\r\n").parse::()?) + }) + } +} diff --git a/src/protocol/binary.rs b/src/protocol/binary.rs index dc8d74c..42cf137 100644 --- a/src/protocol/binary.rs +++ b/src/protocol/binary.rs @@ -1,6 +1,7 @@ use std::collections::HashMap; use std::io::Write; +use super::ProtocolTrait; use byteorder::{BigEndian, WriteBytesExt}; use client::Stats; use error::MemcacheError; @@ -12,8 +13,8 @@ pub struct BinaryProtocol { pub stream: Stream, } -impl BinaryProtocol { - pub(super) fn auth(&mut self, username: &str, password: &str) -> Result<(), MemcacheError> { +impl ProtocolTrait for BinaryProtocol { + fn auth(&mut self, username: &str, password: &str) -> Result<(), MemcacheError> { let key = "PLAIN"; let request_header = PacketHeader { magic: Magic::Request as u8, @@ -29,48 +30,7 @@ impl BinaryProtocol { binary_packet::parse_start_auth_response(&mut self.stream).map(|_| ()) } - fn send_request>( - &mut self, - opcode: Opcode, - key: &str, - value: V, - expiration: u32, - cas: Option, - ) -> Result<(), MemcacheError> { - let request_header = PacketHeader { - magic: Magic::Request as u8, - opcode: opcode as u8, - key_length: key.len() as u16, - extras_length: 8, - total_body_length: (8 + key.len() + value.get_length()) as u32, - cas: cas.unwrap_or(0), - ..Default::default() - }; - let extras = binary_packet::StoreExtras { - flags: value.get_flags(), - expiration, - }; - request_header.write(&mut self.stream)?; - self.stream.write_u32::(extras.flags)?; - self.stream.write_u32::(extras.expiration)?; - self.stream.write_all(key.as_bytes())?; - value.write_to(&mut self.stream)?; - self.stream.flush().map_err(Into::into) - } - - fn store>( - &mut self, - opcode: Opcode, - key: &str, - value: V, - expiration: u32, - cas: Option, - ) -> Result<(), MemcacheError> { - self.send_request(opcode, key, value, expiration, cas)?; - binary_packet::parse_response(&mut self.stream)?.err().map(|_| ()) - } - - pub(super) fn version(&mut self) -> Result { + fn version(&mut self) -> Result { let request_header = PacketHeader { magic: Magic::Request as u8, opcode: Opcode::Version as u8, @@ -82,7 +42,7 @@ impl BinaryProtocol { return Ok(version); } - pub(super) fn flush(&mut self) -> Result<(), MemcacheError> { + fn flush(&mut self) -> Result<(), MemcacheError> { let request_header = PacketHeader { magic: Magic::Request as u8, opcode: Opcode::Flush as u8, @@ -93,7 +53,7 @@ impl BinaryProtocol { binary_packet::parse_response(&mut self.stream)?.err().map(|_| ()) } - pub(super) fn flush_with_delay(&mut self, delay: u32) -> Result<(), MemcacheError> { + fn flush_with_delay(&mut self, delay: u32) -> Result<(), MemcacheError> { let request_header = PacketHeader { magic: Magic::Request as u8, opcode: Opcode::Flush as u8, @@ -107,7 +67,7 @@ impl BinaryProtocol { binary_packet::parse_response(&mut self.stream)?.err().map(|_| ()) } - pub(super) fn get(&mut self, key: &str) -> Result, MemcacheError> { + fn get(&mut self, key: &str) -> Result, MemcacheError> { let request_header = PacketHeader { magic: Magic::Request as u8, opcode: Opcode::Get as u8, @@ -121,7 +81,7 @@ impl BinaryProtocol { return binary_packet::parse_get_response(&mut self.stream); } - pub(super) fn gets(&mut self, keys: &[&str]) -> Result, MemcacheError> { + fn gets(&mut self, keys: &[&str]) -> Result, MemcacheError> { for key in keys { let request_header = PacketHeader { magic: Magic::Request as u8, @@ -142,7 +102,7 @@ impl BinaryProtocol { return binary_packet::parse_gets_response(&mut self.stream, keys.len()); } - pub(super) fn cas>( + fn cas>( &mut self, key: &str, value: V, @@ -153,25 +113,15 @@ impl BinaryProtocol { binary_packet::parse_cas_response(&mut self.stream) } - pub(super) fn set>( - &mut self, - key: &str, - value: V, - expiration: u32, - ) -> Result<(), MemcacheError> { + fn set>(&mut self, key: &str, value: V, expiration: u32) -> Result<(), MemcacheError> { return self.store(Opcode::Set, key, value, expiration, None); } - pub(super) fn add>( - &mut self, - key: &str, - value: V, - expiration: u32, - ) -> Result<(), MemcacheError> { + fn add>(&mut self, key: &str, value: V, expiration: u32) -> Result<(), MemcacheError> { return self.store(Opcode::Add, key, value, expiration, None); } - pub(super) fn replace>( + fn replace>( &mut self, key: &str, value: V, @@ -180,7 +130,7 @@ impl BinaryProtocol { return self.store(Opcode::Replace, key, value, expiration, None); } - pub(super) fn append>(&mut self, key: &str, value: V) -> Result<(), MemcacheError> { + fn append>(&mut self, key: &str, value: V) -> Result<(), MemcacheError> { let request_header = PacketHeader { magic: Magic::Request as u8, opcode: Opcode::Append as u8, @@ -195,7 +145,7 @@ impl BinaryProtocol { binary_packet::parse_response(&mut self.stream)?.err().map(|_| ()) } - pub(super) fn prepend>(&mut self, key: &str, value: V) -> Result<(), MemcacheError> { + fn prepend>(&mut self, key: &str, value: V) -> Result<(), MemcacheError> { let request_header = PacketHeader { magic: Magic::Request as u8, opcode: Opcode::Prepend as u8, @@ -210,7 +160,7 @@ impl BinaryProtocol { binary_packet::parse_response(&mut self.stream).map(|_| ()) } - pub(super) fn delete(&mut self, key: &str) -> Result { + fn delete(&mut self, key: &str) -> Result { let request_header = PacketHeader { magic: Magic::Request as u8, opcode: Opcode::Delete as u8, @@ -224,7 +174,7 @@ impl BinaryProtocol { return binary_packet::parse_delete_response(&mut self.stream); } - pub(super) fn increment(&mut self, key: &str, amount: u64) -> Result { + fn increment(&mut self, key: &str, amount: u64) -> Result { let request_header = PacketHeader { magic: Magic::Request as u8, opcode: Opcode::Increment as u8, @@ -247,7 +197,7 @@ impl BinaryProtocol { return binary_packet::parse_counter_response(&mut self.stream); } - pub(super) fn decrement(&mut self, key: &str, amount: u64) -> Result { + fn decrement(&mut self, key: &str, amount: u64) -> Result { let request_header = PacketHeader { magic: Magic::Request as u8, opcode: Opcode::Decrement as u8, @@ -270,7 +220,7 @@ impl BinaryProtocol { return binary_packet::parse_counter_response(&mut self.stream); } - pub(super) fn touch(&mut self, key: &str, expiration: u32) -> Result { + fn touch(&mut self, key: &str, expiration: u32) -> Result { let request_header = PacketHeader { magic: Magic::Request as u8, opcode: Opcode::Touch as u8, @@ -286,7 +236,7 @@ impl BinaryProtocol { return binary_packet::parse_touch_response(&mut self.stream); } - pub(super) fn stats(&mut self) -> Result { + fn stats(&mut self) -> Result { let request_header = PacketHeader { magic: Magic::Request as u8, opcode: Opcode::Stat as u8, @@ -298,3 +248,46 @@ impl BinaryProtocol { return Ok(stats_info); } } + +impl BinaryProtocol { + fn send_request>( + &mut self, + opcode: Opcode, + key: &str, + value: V, + expiration: u32, + cas: Option, + ) -> Result<(), MemcacheError> { + let request_header = PacketHeader { + magic: Magic::Request as u8, + opcode: opcode as u8, + key_length: key.len() as u16, + extras_length: 8, + total_body_length: (8 + key.len() + value.get_length()) as u32, + cas: cas.unwrap_or(0), + ..Default::default() + }; + let extras = binary_packet::StoreExtras { + flags: value.get_flags(), + expiration, + }; + request_header.write(&mut self.stream)?; + self.stream.write_u32::(extras.flags)?; + self.stream.write_u32::(extras.expiration)?; + self.stream.write_all(key.as_bytes())?; + value.write_to(&mut self.stream)?; + self.stream.flush().map_err(Into::into) + } + + fn store>( + &mut self, + opcode: Opcode, + key: &str, + value: V, + expiration: u32, + cas: Option, + ) -> Result<(), MemcacheError> { + self.send_request(opcode, key, value, expiration, cas)?; + binary_packet::parse_response(&mut self.stream)?.err().map(|_| ()) + } +}