From b514631c04b038331a4b25acc63b49fc93d62a24 Mon Sep 17 00:00:00 2001 From: Shane Hathaway Date: Sat, 8 Feb 2020 19:29:53 -0700 Subject: [PATCH 1/7] Add two pipelined methods: client.sets() and client.deletes(). Pipelining should make communication much more efficient for large numbers of sets and deletes. --- src/client.rs | 69 +++++++++++++- src/protocol/ascii.rs | 165 +++++++++++++++++++++++----------- src/protocol/binary.rs | 72 ++++++++++++--- src/protocol/binary_packet.rs | 1 + src/protocol/mod.rs | 6 ++ tests/tests.rs | 25 +++++- 6 files changed, 266 insertions(+), 72 deletions(-) diff --git a/src/client.rs b/src/client.rs index 8cd5ab0..ac22435 100644 --- a/src/client.rs +++ b/src/client.rs @@ -189,6 +189,12 @@ impl Client { return self.get_connection(key).get()?.get(key); } + /// Map a key to a connection index. + fn hash_key(&self, key: &str) -> usize { + let connections_count = self.connections.len(); + (self.hash_function)(key) as usize % connections_count + } + /// Get multiple keys from memcached server. Using this function instead of calling `get` multiple times can reduce netwark workloads. /// /// Example: @@ -203,12 +209,9 @@ impl Client { pub fn gets(&self, keys: &[&str]) -> Result, MemcacheError> { let mut con_keys: HashMap> = HashMap::new(); let mut result: HashMap = HashMap::new(); - let connections_count = self.connections.len(); for key in keys { - let connection_index = (self.hash_function)(key) as usize % connections_count; - let array = con_keys.entry(connection_index).or_insert_with(Vec::new); - array.push(key); + con_keys.entry(self.hash_key(key)).or_default().push(key); } for (&connection_index, keys) in con_keys.iter() { let connection = self.connections[connection_index].clone(); @@ -230,6 +233,36 @@ impl Client { return self.get_connection(key).get()?.set(key, value, expiration); } + /// Set multiple keys with associated values into memcached server with expiration seconds. + /// + /// Example: + /// + /// ```rust + /// let client = memcache::Client::connect("memcache://localhost:12345").unwrap(); + /// client.sets(vec![("foo", "Foo"), ("bar", "Bar")], 10).unwrap(); + /// # client.flush().unwrap(); + /// ``` + pub fn sets, K: AsRef, I: IntoIterator>( + &self, + entries: I, + expiration: u32, + ) -> Result<(), MemcacheError> { + let mut entry_map: HashMap> = HashMap::new(); + for (key, value) in entries.into_iter() { + entry_map + .entry(self.hash_key(key.as_ref())) + .or_default() + .push((key, value)); + } + + for (connection_index, entries_subset) in entry_map.into_iter() { + let connection = self.connections[connection_index].clone(); + connection.get()?.sets(entries_subset, expiration)?; + } + + Ok(()) + } + /// Compare and swap a key with the associate value into memcached server with expiration seconds. /// `cas_id` should be obtained from a previous `gets` call. /// @@ -337,6 +370,34 @@ impl Client { return self.get_connection(key).get()?.delete(key); } + /// Delete keys from memcached server. + /// + /// Example: + /// + /// ```rust + /// let client = memcache::Client::connect("memcache://localhost:12345").unwrap(); + /// client.deletes(&["foo", "bar"]).unwrap(); + /// # client.flush().unwrap(); + /// ``` + pub fn deletes + Eq + Hash, I: IntoIterator>( + &self, + keys: I, + ) -> Result, MemcacheError> { + let mut con_keys: HashMap> = HashMap::new(); + for key in keys.into_iter() { + con_keys.entry(self.hash_key(key.as_ref())).or_default().push(key); + } + + let mut result: HashMap = HashMap::new(); + for (connection_index, keys_subset) in con_keys.into_iter() { + let connection = self.connections[connection_index].clone(); + for (deleted, key) in connection.get()?.deletes(&keys_subset)?.into_iter().zip(keys_subset) { + result.insert(key, deleted); + } + } + Ok(result) + } + /// Increment the value with amount. /// /// Example: diff --git a/src/protocol/ascii.rs b/src/protocol/ascii.rs index 45014b3..581e4e0 100644 --- a/src/protocol/ascii.rs +++ b/src/protocol/ascii.rs @@ -148,7 +148,15 @@ impl AsciiProtocol { value: V, options: &Options, ) -> Result { - check_key_len(key)?; + Ok(self.stores(command, Some((key, value)), options)?) + } + + fn stores, K: AsRef, I: IntoIterator>( + &mut self, + command: StoreCommand, + entries: I, + options: &Options, + ) -> Result { if command == StoreCommand::Cas { if options.cas.is_none() { Err(ClientError::Error(Cow::Borrowed( @@ -156,50 +164,68 @@ impl AsciiProtocol { )))?; } } + let noreply = if options.noreply { " noreply" } else { "" }; - if options.cas.is_some() { - write!( - self.reader.get_mut(), - "{command} {key} {flags} {exptime} {vlen} {cas}{noreply}\r\n", - command = command, - key = key, - flags = value.get_flags(), - exptime = options.exptime, - vlen = value.get_length(), - cas = options.cas.unwrap(), - noreply = noreply - )?; - } else { - write!( - self.reader.get_mut(), - "{command} {key} {flags} {exptime} {vlen}{noreply}\r\n", - command = command, - key = key, - flags = value.get_flags(), - exptime = options.exptime, - vlen = value.get_length(), - noreply = noreply - )?; - } + let mut sent_count = 0; + + { + let reader = self.reader.get_mut(); + for (key_ref, value) in entries.into_iter() { + let key = key_ref.as_ref(); + check_key_len(key)?; + if options.cas.is_some() { + write!( + reader, + "{command} {key} {flags} {exptime} {vlen} {cas}{noreply}\r\n", + command = command, + key = key, + flags = value.get_flags(), + exptime = options.exptime, + vlen = value.get_length(), + cas = options.cas.unwrap(), + noreply = noreply + )?; + } else { + write!( + reader, + "{command} {key} {flags} {exptime} {vlen}{noreply}\r\n", + command = command, + key = key, + flags = value.get_flags(), + exptime = options.exptime, + vlen = value.get_length(), + noreply = noreply + )?; + } - value.write_to(self.reader.get_mut())?; - self.reader.get_mut().write(b"\r\n")?; - self.reader.get_mut().flush()?; + value.write_to(reader)?; + reader.write(b"\r\n")?; + sent_count += 1; + } + + // Flush now that all the requests have been written. + reader.flush()?; + } if options.noreply { return Ok(true); } - self.reader.read_line(|response| { - let response = MemcacheError::try_from(response)?; - match response { - "STORED\r\n" => Ok(true), - "NOT_STORED\r\n" => Ok(false), - "EXISTS\r\n" => Err(CommandError::KeyExists)?, - "NOT_FOUND\r\n" => Err(CommandError::KeyNotFound)?, - response => Err(ServerError::BadResponse(Cow::Owned(response.into())))?, - } - }) + let mut all_stored = true; + for _ in 0..sent_count { + all_stored = all_stored + && self.reader.read_line(|response| { + let response = MemcacheError::try_from(response)?; + match response { + "STORED\r\n" => Ok(true), + "NOT_STORED\r\n" => Ok(false), + "EXISTS\r\n" => Err(CommandError::KeyExists)?, + "NOT_FOUND\r\n" => Err(CommandError::KeyNotFound)?, + response => Err(ServerError::BadResponse(Cow::Owned(response.into())))?, + } + })?; + } + Ok(all_stored) } pub(super) fn version(&mut self) -> Result { @@ -355,6 +381,18 @@ impl AsciiProtocol { self.store(StoreCommand::Set, key, value, &options).map(|_| ()) } + pub(super) fn sets, K: AsRef, I: IntoIterator>( + &mut self, + entries: I, + expiration: u32, + ) -> Result<(), MemcacheError> { + let options = Options { + exptime: expiration, + ..Default::default() + }; + self.stores(StoreCommand::Set, entries, &options).map(|_| ()) + } + pub(super) fn add>( &mut self, key: &str, @@ -393,22 +431,43 @@ impl AsciiProtocol { .map(|_| ()) } + pub(super) fn deletes, I: IntoIterator>( + &mut self, + keys: I, + ) -> Result, MemcacheError> { + let mut sent_count = 0; + { + let reader = self.reader.get_mut(); + for k in keys.into_iter() { + let key = k.as_ref(); + check_key_len(key)?; + write!(reader, "delete {}\r\n", key)?; + sent_count += 1; + } + reader.flush()?; + } + let mut res = Vec::with_capacity(sent_count); + for _ in 0..sent_count { + res.push( + self.reader + .read_line(|response| match MemcacheError::try_from(response) { + Ok(s) => { + if s == "DELETED\r\n" { + Ok(true) + } else { + Err(ServerError::BadResponse(Cow::Owned(s.into())).into()) + } + } + Err(MemcacheError::CommandError(CommandError::KeyNotFound)) => Ok(false), + Err(e) => Err(e), + })?, + ); + } + Ok(res) + } + pub(super) fn delete(&mut self, key: &str) -> Result { - check_key_len(key)?; - write!(self.reader.get_mut(), "delete {}\r\n", key)?; - self.reader.get_mut().flush()?; - self.reader - .read_line(|response| match MemcacheError::try_from(response) { - Ok(s) => { - if s == "DELETED\r\n" { - Ok(true) - } else { - Err(ServerError::BadResponse(Cow::Owned(s.into())).into()) - } - } - Err(MemcacheError::CommandError(CommandError::KeyNotFound)) => Ok(false), - Err(e) => Err(e), - }) + Ok(self.deletes(&[key])?[0]) } fn parse_u64_response(&mut self) -> Result { diff --git a/src/protocol/binary.rs b/src/protocol/binary.rs index 39367f8..0cd0291 100644 --- a/src/protocol/binary.rs +++ b/src/protocol/binary.rs @@ -57,7 +57,7 @@ impl BinaryProtocol { self.stream.write_u32::(extras.expiration)?; self.stream.write_all(key.as_bytes())?; value.write_to(&mut self.stream)?; - self.stream.flush().map_err(Into::into) + Ok(()) } fn store>( @@ -69,9 +69,30 @@ impl BinaryProtocol { cas: Option, ) -> Result<(), MemcacheError> { self.send_request(opcode, key, value, expiration, cas)?; + self.stream.flush()?; binary_packet::parse_response(&mut self.stream)?.err().map(|_| ()) } + /// Support efficient multi-store operations using pipelining. + fn stores, K: AsRef, I: IntoIterator>( + &mut self, + opcode: Opcode, + entries: I, + expiration: u32, + cas: Option, + ) -> Result<(), MemcacheError> { + let mut sent_count = 0; + for (key, value) in entries.into_iter() { + self.send_request(opcode, key.as_ref(), value, expiration, cas)?; + sent_count += 1; + } + self.stream.flush()?; + for _ in 0..sent_count { + binary_packet::parse_response(&mut self.stream)?; + } + Ok(()) + } + pub(super) fn version(&mut self) -> Result { let request_header = PacketHeader { magic: Magic::Request as u8, @@ -143,6 +164,7 @@ impl BinaryProtocol { ..Default::default() }; noop_request_header.write(&mut self.stream)?; + self.stream.flush()?; return binary_packet::parse_gets_response(&mut self.stream, keys.len()); } @@ -154,6 +176,7 @@ impl BinaryProtocol { cas: u64, ) -> Result { self.send_request(Opcode::Set, key, value, expiration, Some(cas))?; + self.stream.flush()?; binary_packet::parse_cas_response(&mut self.stream) } @@ -166,6 +189,14 @@ impl BinaryProtocol { return self.store(Opcode::Set, key, value, expiration, None); } + pub(super) fn sets, K: AsRef, I: IntoIterator>( + &mut self, + entries: I, + expiration: u32, + ) -> Result<(), MemcacheError> { + return self.stores(Opcode::Set, entries, expiration, None); + } + pub(super) fn add>( &mut self, key: &str, @@ -217,18 +248,35 @@ impl BinaryProtocol { } pub(super) fn delete(&mut self, key: &str) -> Result { - check_key_len(key)?; - let request_header = PacketHeader { - magic: Magic::Request as u8, - opcode: Opcode::Delete as u8, - key_length: key.len() as u16, - total_body_length: key.len() as u32, - ..Default::default() - }; - request_header.write(&mut self.stream)?; - self.stream.write_all(key.as_bytes())?; + Ok(self.deletes(&[key])?[0]) + } + + pub(super) fn deletes, I: IntoIterator>( + &mut self, + keys: I, + ) -> Result, MemcacheError> { + let mut sent_count = 0; + for k in keys.into_iter() { + check_key_len(k.as_ref())?; + let key = k.as_ref(); + let request_header = PacketHeader { + magic: Magic::Request as u8, + opcode: Opcode::Delete as u8, + key_length: key.len() as u16, + total_body_length: key.len() as u32, + ..Default::default() + }; + request_header.write(&mut self.stream)?; + self.stream.write_all(key.as_bytes())?; + sent_count += 1; + } + // Flush now that all the requests have been written. self.stream.flush()?; - return binary_packet::parse_delete_response(&mut self.stream); + let mut res = Vec::with_capacity(sent_count); + for _ in 0..sent_count { + res.push(binary_packet::parse_delete_response(&mut self.stream)?); + } + Ok(res) } pub(super) fn increment(&mut self, key: &str, amount: u64) -> Result { diff --git a/src/protocol/binary_packet.rs b/src/protocol/binary_packet.rs index 0a11a87..dcdfb7b 100644 --- a/src/protocol/binary_packet.rs +++ b/src/protocol/binary_packet.rs @@ -8,6 +8,7 @@ use value::FromMemcacheValueExt; const OK_STATUS: u16 = 0x0; #[allow(dead_code)] +#[derive(Copy, Clone)] pub enum Opcode { Get = 0x00, Set = 0x01, diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 56d667b..605f74e 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -34,6 +34,11 @@ pub trait ProtocolTrait { fn get(&mut self, key: &str) -> Result, MemcacheError>; fn gets(&mut self, keys: &[&str]) -> Result, MemcacheError>; fn set>(&mut self, key: &str, value: V, expiration: u32) -> Result<(), MemcacheError>; + fn sets, K: AsRef, I: IntoIterator>( + &mut self, + entries: I, + expiration: u32, + ) -> Result<(), MemcacheError>; fn cas>( &mut self, key: &str, @@ -51,6 +56,7 @@ pub trait ProtocolTrait { fn append>(&mut self, key: &str, value: V) -> Result<(), MemcacheError>; fn prepend>(&mut self, key: &str, value: V) -> Result<(), MemcacheError>; fn delete(&mut self, key: &str) -> Result; + fn deletes, I: IntoIterator>(&mut self, keys: I) -> Result, MemcacheError>; fn increment(&mut self, key: &str, amount: u64) -> Result; fn decrement(&mut self, key: &str, amount: u64) -> Result; fn touch(&mut self, key: &str, expiration: u32) -> Result; diff --git a/tests/tests.rs b/tests/tests.rs index a50049a..b0941a3 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -17,6 +17,8 @@ fn gen_random_key() -> String { #[test] fn test() { + use std::collections::HashMap; + let mut urls = vec![ "memcache://localhost:12346?tcp_nodelay=true", "memcache://localhost:12347?timeout=10", @@ -45,16 +47,33 @@ fn test() { assert_eq!(value, None); let mut keys: Vec = Vec::new(); - for _ in 0..1000 { + let mut batch: HashMap = HashMap::new(); + for i in 0..1000 { let key = gen_random_key(); keys.push(key.clone()); - client.set(key.as_str(), "xxx", 0).unwrap(); + if i < 10 { + // Set the first 10 one at a time + client.set(key.as_str(), "xxx", 0).unwrap(); + } else { + // Set the rest as a batch + batch.insert(key, "xxx".to_string()); + } } + client.sets(batch, 0).unwrap(); - for key in keys { + let keys_strs: Vec<&str> = keys.iter().map(String::as_str).collect(); + let all_at_once: HashMap = client.gets(&keys_strs).unwrap(); + assert_eq!(1000, all_at_once.len()); + + for key in keys.iter() { let value: String = client.get(key.as_str()).unwrap().unwrap(); assert_eq!(value, "xxx"); + assert_eq!(all_at_once[key], "xxx"); } + + client.deletes(&keys_strs).unwrap(); + let all_at_once: HashMap = client.gets(&keys_strs).unwrap(); + assert_eq!(0, all_at_once.len()); } #[test] From 4c7507044b0f85566dd5fee26d57a2839e381a50 Mon Sep 17 00:00:00 2001 From: Shane Hathaway Date: Sat, 8 Feb 2020 20:11:27 -0700 Subject: [PATCH 2/7] Improve error handling: when pipelining, receive all responses to keep the protocol in sync. --- src/protocol/ascii.rs | 77 +++++++++++++++++++++++++++--------------- src/protocol/binary.rs | 27 ++++++++++++--- 2 files changed, 72 insertions(+), 32 deletions(-) diff --git a/src/protocol/ascii.rs b/src/protocol/ascii.rs index 581e4e0..1ad484f 100644 --- a/src/protocol/ascii.rs +++ b/src/protocol/ascii.rs @@ -211,21 +211,32 @@ impl AsciiProtocol { return Ok(true); } + // Receive all the responses. If there were errors, return the first. + let mut all_stored = true; + let mut error_list: Vec = Vec::new(); for _ in 0..sent_count { - all_stored = all_stored - && self.reader.read_line(|response| { - let response = MemcacheError::try_from(response)?; - match response { - "STORED\r\n" => Ok(true), - "NOT_STORED\r\n" => Ok(false), - "EXISTS\r\n" => Err(CommandError::KeyExists)?, - "NOT_FOUND\r\n" => Err(CommandError::KeyNotFound)?, - response => Err(ServerError::BadResponse(Cow::Owned(response.into())))?, - } - })?; + let one_result = self.reader.read_line(|response| { + let response = MemcacheError::try_from(response)?; + match response { + "STORED\r\n" => Ok(true), + "NOT_STORED\r\n" => Ok(false), + "EXISTS\r\n" => Err(CommandError::KeyExists)?, + "NOT_FOUND\r\n" => Err(CommandError::KeyNotFound)?, + response => Err(ServerError::BadResponse(Cow::Owned(response.into())))?, + } + }); + match one_result { + Ok(true) => (), + Ok(false) => all_stored = false, + Err(e) => error_list.push(e), + } + } + + match error_list.into_iter().next() { + None => Ok(all_stored), + Some(e) => Err(e), } - Ok(all_stored) } pub(super) fn version(&mut self) -> Result { @@ -444,26 +455,38 @@ impl AsciiProtocol { write!(reader, "delete {}\r\n", key)?; sent_count += 1; } + // Flush now that all the requests have been written. reader.flush()?; } - let mut res = Vec::with_capacity(sent_count); + + // Receive all the responses. If there were errors, return the first. + + let mut deleted_list = Vec::with_capacity(sent_count); + let mut error_list: Vec = Vec::new(); for _ in 0..sent_count { - res.push( - self.reader - .read_line(|response| match MemcacheError::try_from(response) { - Ok(s) => { - if s == "DELETED\r\n" { - Ok(true) - } else { - Err(ServerError::BadResponse(Cow::Owned(s.into())).into()) - } + let one_result = self + .reader + .read_line(|response| match MemcacheError::try_from(response) { + Ok(s) => { + if s == "DELETED\r\n" { + Ok(true) + } else { + Err(ServerError::BadResponse(Cow::Owned(s.into())).into()) } - Err(MemcacheError::CommandError(CommandError::KeyNotFound)) => Ok(false), - Err(e) => Err(e), - })?, - ); + } + Err(MemcacheError::CommandError(CommandError::KeyNotFound)) => Ok(false), + Err(e) => Err(e), + }); + match one_result { + Ok(deleted) => deleted_list.push(deleted), + Err(e) => error_list.push(e), + } + } + + match error_list.into_iter().next() { + None => Ok(deleted_list), + Some(e) => Err(e), } - Ok(res) } pub(super) fn delete(&mut self, key: &str) -> Result { diff --git a/src/protocol/binary.rs b/src/protocol/binary.rs index 0cd0291..538434f 100644 --- a/src/protocol/binary.rs +++ b/src/protocol/binary.rs @@ -86,11 +86,20 @@ impl BinaryProtocol { self.send_request(opcode, key.as_ref(), value, expiration, cas)?; sent_count += 1; } + // Flush now that all the requests have been written. self.stream.flush()?; + // Receive all the responses. If there were errors, return the first. + let mut error_list = Vec::new(); for _ in 0..sent_count { - binary_packet::parse_response(&mut self.stream)?; + match binary_packet::parse_response(&mut self.stream) { + Ok(_) => (), + Err(e) => error_list.push(e), + }; + } + match error_list.into_iter().next() { + None => Ok(()), + Some(e) => Err(e), } - Ok(()) } pub(super) fn version(&mut self) -> Result { @@ -272,11 +281,19 @@ impl BinaryProtocol { } // Flush now that all the requests have been written. self.stream.flush()?; - let mut res = Vec::with_capacity(sent_count); + // Receive all the responses. If there were errors, return the first. + let mut deleted_list = Vec::with_capacity(sent_count); + let mut error_list: Vec = Vec::new(); for _ in 0..sent_count { - res.push(binary_packet::parse_delete_response(&mut self.stream)?); + match binary_packet::parse_delete_response(&mut self.stream) { + Ok(deleted) => deleted_list.push(deleted), + Err(e) => error_list.push(e), + } + } + match error_list.into_iter().next() { + None => Ok(deleted_list), + Some(e) => Err(e), } - Ok(res) } pub(super) fn increment(&mut self, key: &str, amount: u64) -> Result { From 038658b708cdafea03c338074c72bec0675adea9 Mon Sep 17 00:00:00 2001 From: Shane Hathaway Date: Sat, 8 Feb 2020 21:35:05 -0700 Subject: [PATCH 3/7] Revise proposed method names to set_multi and delete_multi; add get_multi and leave a "gets" alias. All the _multi methods accept IntoIterator, making them compatible with many kinds of collections. --- src/client.rs | 37 +++++++++++++++++++++++++------------ src/protocol/ascii.rs | 26 ++++++++++++++++++++------ src/protocol/binary.rs | 18 ++++++++++++------ src/protocol/mod.rs | 9 ++++++--- tests/tests.rs | 9 ++++----- 5 files changed, 67 insertions(+), 32 deletions(-) diff --git a/src/client.rs b/src/client.rs index ac22435..1130f8e 100644 --- a/src/client.rs +++ b/src/client.rs @@ -195,6 +195,11 @@ impl Client { (self.hash_function)(key) as usize % connections_count } + /// Alias for get_multi(). + pub fn gets(&self, keys: &[&str]) -> Result, MemcacheError> { + self.get_multi(keys) + } + /// Get multiple keys from memcached server. Using this function instead of calling `get` multiple times can reduce netwark workloads. /// /// Example: @@ -202,20 +207,23 @@ impl Client { /// ```rust /// let client = memcache::Client::connect("memcache://localhost:12345").unwrap(); /// client.set("foo", "42", 0).unwrap(); - /// let result: std::collections::HashMap = client.gets(&["foo", "bar", "baz"]).unwrap(); + /// let result: std::collections::HashMap = client.get_multi(&["foo", "bar", "baz"]).unwrap(); /// assert_eq!(result.len(), 1); /// assert_eq!(result["foo"], "42"); /// ``` - pub fn gets(&self, keys: &[&str]) -> Result, MemcacheError> { - let mut con_keys: HashMap> = HashMap::new(); + pub fn get_multi, I: IntoIterator>( + &self, + keys: I, + ) -> Result, MemcacheError> { + let mut con_keys: HashMap> = HashMap::new(); let mut result: HashMap = HashMap::new(); - for key in keys { - con_keys.entry(self.hash_key(key)).or_default().push(key); + for k in keys { + con_keys.entry(self.hash_key(k.as_ref())).or_default().push(k); } for (&connection_index, keys) in con_keys.iter() { let connection = self.connections[connection_index].clone(); - result.extend(connection.get()?.gets(keys)?); + result.extend(connection.get()?.get_multi(keys)?); } return Ok(result); } @@ -239,10 +247,10 @@ impl Client { /// /// ```rust /// let client = memcache::Client::connect("memcache://localhost:12345").unwrap(); - /// client.sets(vec![("foo", "Foo"), ("bar", "Bar")], 10).unwrap(); + /// client.set_multi(vec![("foo", "Foo"), ("bar", "Bar")], 10).unwrap(); /// # client.flush().unwrap(); /// ``` - pub fn sets, K: AsRef, I: IntoIterator>( + pub fn set_multi, K: AsRef, I: IntoIterator>( &self, entries: I, expiration: u32, @@ -257,7 +265,7 @@ impl Client { for (connection_index, entries_subset) in entry_map.into_iter() { let connection = self.connections[connection_index].clone(); - connection.get()?.sets(entries_subset, expiration)?; + connection.get()?.set_multi(entries_subset, expiration)?; } Ok(()) @@ -376,10 +384,10 @@ impl Client { /// /// ```rust /// let client = memcache::Client::connect("memcache://localhost:12345").unwrap(); - /// client.deletes(&["foo", "bar"]).unwrap(); + /// client.delete_multi(&["foo", "bar"]).unwrap(); /// # client.flush().unwrap(); /// ``` - pub fn deletes + Eq + Hash, I: IntoIterator>( + pub fn delete_multi + Eq + Hash, I: IntoIterator>( &self, keys: I, ) -> Result, MemcacheError> { @@ -391,7 +399,12 @@ impl Client { let mut result: HashMap = HashMap::new(); for (connection_index, keys_subset) in con_keys.into_iter() { let connection = self.connections[connection_index].clone(); - for (deleted, key) in connection.get()?.deletes(&keys_subset)?.into_iter().zip(keys_subset) { + for (deleted, key) in connection + .get()? + .delete_multi(&keys_subset)? + .into_iter() + .zip(keys_subset) + { result.insert(key, deleted); } } diff --git a/src/protocol/ascii.rs b/src/protocol/ascii.rs index 1ad484f..5482197 100644 --- a/src/protocol/ascii.rs +++ b/src/protocol/ascii.rs @@ -338,11 +338,25 @@ impl AsciiProtocol { } } - pub(super) fn gets(&mut self, keys: &[&str]) -> Result, MemcacheError> { - for key in keys { + pub(super) fn get_multi, I: IntoIterator>( + &mut self, + keys: I, + ) -> Result, MemcacheError> { + let keys: Vec = keys.into_iter().collect(); + let mut capacity = 0; + for k in keys.iter() { + let key = k.as_ref(); check_key_len(key)?; + capacity += key.len() + 1; } - write!(self.reader.get_mut(), "gets {}\r\n", keys.join(" "))?; + + let mut keystr = String::with_capacity(capacity); + for k in keys.iter() { + keystr.push(' '); + keystr.push_str(k.as_ref()); + } + + write!(self.reader.get_mut(), "gets{}\r\n", keystr)?; let mut result: HashMap = HashMap::with_capacity(keys.len()); // there will be atmost keys.len() "VALUE <...>" responses and one END response @@ -392,7 +406,7 @@ impl AsciiProtocol { self.store(StoreCommand::Set, key, value, &options).map(|_| ()) } - pub(super) fn sets, K: AsRef, I: IntoIterator>( + pub(super) fn set_multi, K: AsRef, I: IntoIterator>( &mut self, entries: I, expiration: u32, @@ -442,7 +456,7 @@ impl AsciiProtocol { .map(|_| ()) } - pub(super) fn deletes, I: IntoIterator>( + pub(super) fn delete_multi, I: IntoIterator>( &mut self, keys: I, ) -> Result, MemcacheError> { @@ -490,7 +504,7 @@ impl AsciiProtocol { } pub(super) fn delete(&mut self, key: &str) -> Result { - Ok(self.deletes(&[key])?[0]) + Ok(self.delete_multi(&[key])?[0]) } fn parse_u64_response(&mut self) -> Result { diff --git a/src/protocol/binary.rs b/src/protocol/binary.rs index 538434f..84eb220 100644 --- a/src/protocol/binary.rs +++ b/src/protocol/binary.rs @@ -154,8 +154,13 @@ impl BinaryProtocol { return binary_packet::parse_get_response(&mut self.stream); } - pub(super) fn gets(&mut self, keys: &[&str]) -> Result, MemcacheError> { - for key in keys { + pub(super) fn get_multi, I: IntoIterator>( + &mut self, + keys: I, + ) -> Result, MemcacheError> { + let mut count = 0; + for k in keys.into_iter() { + let key = k.as_ref(); check_key_len(key)?; let request_header = PacketHeader { magic: Magic::Request as u8, @@ -166,6 +171,7 @@ impl BinaryProtocol { }; request_header.write(&mut self.stream)?; self.stream.write_all(key.as_bytes())?; + count += 1; } let noop_request_header = PacketHeader { magic: Magic::Request as u8, @@ -174,7 +180,7 @@ impl BinaryProtocol { }; noop_request_header.write(&mut self.stream)?; self.stream.flush()?; - return binary_packet::parse_gets_response(&mut self.stream, keys.len()); + return binary_packet::parse_gets_response(&mut self.stream, count); } pub(super) fn cas>( @@ -198,7 +204,7 @@ impl BinaryProtocol { return self.store(Opcode::Set, key, value, expiration, None); } - pub(super) fn sets, K: AsRef, I: IntoIterator>( + pub(super) fn set_multi, K: AsRef, I: IntoIterator>( &mut self, entries: I, expiration: u32, @@ -257,10 +263,10 @@ impl BinaryProtocol { } pub(super) fn delete(&mut self, key: &str) -> Result { - Ok(self.deletes(&[key])?[0]) + Ok(self.delete_multi(&[key])?[0]) } - pub(super) fn deletes, I: IntoIterator>( + pub(super) fn delete_multi, I: IntoIterator>( &mut self, keys: I, ) -> Result, MemcacheError> { diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 605f74e..de349db 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -32,9 +32,12 @@ pub trait ProtocolTrait { fn flush(&mut self) -> Result<(), MemcacheError>; fn flush_with_delay(&mut self, delay: u32) -> Result<(), MemcacheError>; fn get(&mut self, key: &str) -> Result, MemcacheError>; - fn gets(&mut self, keys: &[&str]) -> Result, MemcacheError>; + fn get_multi, I: IntoIterator>( + &mut self, + keys: I, + ) -> Result, MemcacheError>; fn set>(&mut self, key: &str, value: V, expiration: u32) -> Result<(), MemcacheError>; - fn sets, K: AsRef, I: IntoIterator>( + fn set_multi, K: AsRef, I: IntoIterator>( &mut self, entries: I, expiration: u32, @@ -56,7 +59,7 @@ pub trait ProtocolTrait { fn append>(&mut self, key: &str, value: V) -> Result<(), MemcacheError>; fn prepend>(&mut self, key: &str, value: V) -> Result<(), MemcacheError>; fn delete(&mut self, key: &str) -> Result; - fn deletes, I: IntoIterator>(&mut self, keys: I) -> Result, MemcacheError>; + fn delete_multi, I: IntoIterator>(&mut self, keys: I) -> Result, MemcacheError>; fn increment(&mut self, key: &str, amount: u64) -> Result; fn decrement(&mut self, key: &str, amount: u64) -> Result; fn touch(&mut self, key: &str, expiration: u32) -> Result; diff --git a/tests/tests.rs b/tests/tests.rs index b0941a3..06f4d65 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -59,10 +59,9 @@ fn test() { batch.insert(key, "xxx".to_string()); } } - client.sets(batch, 0).unwrap(); + client.set_multi(batch, 0).unwrap(); - let keys_strs: Vec<&str> = keys.iter().map(String::as_str).collect(); - let all_at_once: HashMap = client.gets(&keys_strs).unwrap(); + let all_at_once: HashMap = client.get_multi(&keys).unwrap(); assert_eq!(1000, all_at_once.len()); for key in keys.iter() { @@ -71,8 +70,8 @@ fn test() { assert_eq!(all_at_once[key], "xxx"); } - client.deletes(&keys_strs).unwrap(); - let all_at_once: HashMap = client.gets(&keys_strs).unwrap(); + client.delete_multi(&keys).unwrap(); + let all_at_once: HashMap = client.get_multi(&keys).unwrap(); assert_eq!(0, all_at_once.len()); } From e1753f527422797bb1adce15ceaf81d4830d0dbb Mon Sep 17 00:00:00 2001 From: Shane Hathaway Date: Sat, 8 Feb 2020 23:19:47 -0700 Subject: [PATCH 4/7] Use where clauses on the _multi methods to improve the documentation layout --- src/client.rs | 36 ++++++++++++++++++++++-------------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/src/client.rs b/src/client.rs index 1130f8e..c5ab502 100644 --- a/src/client.rs +++ b/src/client.rs @@ -200,7 +200,7 @@ impl Client { self.get_multi(keys) } - /// Get multiple keys from memcached server. Using this function instead of calling `get` multiple times can reduce netwark workloads. + /// Get multiple keys from memcached server. Using this function instead of calling `get` multiple times can reduce network workloads. /// /// Example: /// @@ -211,10 +211,12 @@ impl Client { /// assert_eq!(result.len(), 1); /// assert_eq!(result["foo"], "42"); /// ``` - pub fn get_multi, I: IntoIterator>( - &self, - keys: I, - ) -> Result, MemcacheError> { + pub fn get_multi(&self, keys: I) -> Result, MemcacheError> + where + V: FromMemcacheValueExt, + K: AsRef, + I: IntoIterator, + { let mut con_keys: HashMap> = HashMap::new(); let mut result: HashMap = HashMap::new(); @@ -243,6 +245,8 @@ impl Client { /// Set multiple keys with associated values into memcached server with expiration seconds. /// + /// Uses pipelining to reduce the number of server round trips. + /// /// Example: /// /// ```rust @@ -250,11 +254,12 @@ impl Client { /// client.set_multi(vec![("foo", "Foo"), ("bar", "Bar")], 10).unwrap(); /// # client.flush().unwrap(); /// ``` - pub fn set_multi, K: AsRef, I: IntoIterator>( - &self, - entries: I, - expiration: u32, - ) -> Result<(), MemcacheError> { + pub fn set_multi(&self, entries: I, expiration: u32) -> Result<(), MemcacheError> + where + V: ToMemcacheValue, + K: AsRef, + I: IntoIterator, + { let mut entry_map: HashMap> = HashMap::new(); for (key, value) in entries.into_iter() { entry_map @@ -380,6 +385,8 @@ impl Client { /// Delete keys from memcached server. /// + /// Uses pipelining to reduce the number of server round trips. + /// /// Example: /// /// ```rust @@ -387,10 +394,11 @@ impl Client { /// client.delete_multi(&["foo", "bar"]).unwrap(); /// # client.flush().unwrap(); /// ``` - pub fn delete_multi + Eq + Hash, I: IntoIterator>( - &self, - keys: I, - ) -> Result, MemcacheError> { + pub fn delete_multi(&self, keys: I) -> Result, MemcacheError> + where + K: AsRef + Eq + Hash, + I: IntoIterator, + { let mut con_keys: HashMap> = HashMap::new(); for key in keys.into_iter() { con_keys.entry(self.hash_key(key.as_ref())).or_default().push(key); From 337433ea827df3de15aabd8602ed31e8d068355a Mon Sep 17 00:00:00 2001 From: Shane Hathaway Date: Sun, 9 Feb 2020 02:23:15 -0700 Subject: [PATCH 5/7] Fix two ASCII protocol bugs. When a line spanned multiple read packets, CappedLineReader would put the line back together. Line reassembly had two bugs: 1. After assembling the line and passing it to the caller, read_line() should have called self.consume(filled + n), but it only called self.consume(n), causing part of the line to be read again later. 2. If a \r\n sequence happened to straddle two packets, read_line() would fail to notice the end of line. If packets contained only a single byte at a time, read_line() would never notice any end of line at all. Fixed and added tests. --- src/protocol/ascii.rs | 103 ++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 100 insertions(+), 3 deletions(-) diff --git a/src/protocol/ascii.rs b/src/protocol/ascii.rs index 5482197..cdd3f1a 100644 --- a/src/protocol/ascii.rs +++ b/src/protocol/ascii.rs @@ -107,9 +107,23 @@ impl CappedLineReader { return Err(ClientError::Error(Cow::Borrowed("Ascii protocol no line found")))?; } self.filled += read; - if let Some(n) = get_line(&buf[..read]) { - let result = cb(std::str::from_utf8(&self.buf[..filled + n])?); - self.consume(n); + + // Find the next \r\n. + let search_start; + let search_buf; + if filled > 0 { + // Start searching one character back, otherwise we would skip over \r\n + // sequences that happen to straddle packet boundaries. + search_start = filled - 1; + search_buf = &self.buf[search_start..read + 1]; + } else { + search_start = filled; + search_buf = buf; + } + + if let Some(n) = get_line(search_buf) { + let result = cb(std::str::from_utf8(&self.buf[..search_start + n])?); + self.consume(search_start + n); return result; } } @@ -580,3 +594,86 @@ impl AsciiProtocol { } } } + +#[cfg(test)] +mod test { + #[test] + fn test_read_line_with_line_straddling_packets() { + use super::CappedLineReader; + use std::io::Cursor; + use std::io::Seek; + use std::io::Write; + + let mut cursor = Cursor::new(Vec::new()); + // Write 102 * 20 = 2040 characters + for _ in 0..102 { + cursor.write(b"1234567890abcdefghij").unwrap(); + } + cursor + .write(b"\r\nline 2 to be read exactly\r\nline 3 to be sure\r\n") + .unwrap(); + cursor.seek(std::io::SeekFrom::Start(0)).unwrap(); + + let mut capped = CappedLineReader::new(cursor); + + let length = capped + .read_line(|line| { + assert_eq!(2042, line.len()); // 102 * 20 + "\r\n".len() + Ok(line.len()) + }) + .unwrap(); + assert_eq!(2042, length); + + let length = capped + .read_line(|line| { + assert_eq!("line 2 to be read exactly\r\n", line); + Ok(line.len()) + }) + .unwrap(); + assert_eq!(27, length); + + // Older versions would fail here because not all of the + // consumed bytes were marked as consumed. + let length = capped + .read_line(|line| { + assert_eq!("line 3 to be sure\r\n", line); + Ok(line.len()) + }) + .unwrap(); + assert_eq!(19, length); + } + + #[test] + fn test_read_line_with_crlf_straddling_packets() { + use super::CappedLineReader; + use std::io::Read; + + struct FourBytePacketReader { + content: Vec, + pos: usize, + } + + impl Read for FourBytePacketReader { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + let size = (self.content.len() - self.pos).min(4); + buf[..size].copy_from_slice(&self.content[self.pos..self.pos + size]); + self.pos += size; + Ok(size) + } + } + + let inner = FourBytePacketReader { + content: Vec::from("GET\r\nOK\r\n"), + pos: 0, + }; + let mut capped = CappedLineReader::new(inner); + + let length = capped + .read_line(|line| { + assert_eq!("GET\r\n", line); + Ok(line.len()) + }) + .unwrap(); + assert_eq!(5, length); + } +} From 29f88ffd88eea8911f8bc93b064cef35828d2891 Mon Sep 17 00:00:00 2001 From: Shane Hathaway Date: Tue, 11 Feb 2020 01:09:21 -0700 Subject: [PATCH 6/7] Rename get_multi -> gets, set_multi -> sets, delete_multi -> deletes. More concise. --- src/client.rs | 30 ++++++++++-------------------- src/protocol/ascii.rs | 8 ++++---- src/protocol/binary.rs | 8 ++++---- src/protocol/mod.rs | 6 +++--- tests/tests.rs | 8 ++++---- 5 files changed, 25 insertions(+), 35 deletions(-) diff --git a/src/client.rs b/src/client.rs index c5ab502..1d2216e 100644 --- a/src/client.rs +++ b/src/client.rs @@ -195,11 +195,6 @@ impl Client { (self.hash_function)(key) as usize % connections_count } - /// Alias for get_multi(). - pub fn gets(&self, keys: &[&str]) -> Result, MemcacheError> { - self.get_multi(keys) - } - /// Get multiple keys from memcached server. Using this function instead of calling `get` multiple times can reduce network workloads. /// /// Example: @@ -207,11 +202,11 @@ impl Client { /// ```rust /// let client = memcache::Client::connect("memcache://localhost:12345").unwrap(); /// client.set("foo", "42", 0).unwrap(); - /// let result: std::collections::HashMap = client.get_multi(&["foo", "bar", "baz"]).unwrap(); + /// let result: std::collections::HashMap = client.gets(&["foo", "bar", "baz"]).unwrap(); /// assert_eq!(result.len(), 1); /// assert_eq!(result["foo"], "42"); /// ``` - pub fn get_multi(&self, keys: I) -> Result, MemcacheError> + pub fn gets(&self, keys: I) -> Result, MemcacheError> where V: FromMemcacheValueExt, K: AsRef, @@ -225,7 +220,7 @@ impl Client { } for (&connection_index, keys) in con_keys.iter() { let connection = self.connections[connection_index].clone(); - result.extend(connection.get()?.get_multi(keys)?); + result.extend(connection.get()?.gets(keys)?); } return Ok(result); } @@ -251,10 +246,10 @@ impl Client { /// /// ```rust /// let client = memcache::Client::connect("memcache://localhost:12345").unwrap(); - /// client.set_multi(vec![("foo", "Foo"), ("bar", "Bar")], 10).unwrap(); + /// client.sets(vec![("foo", "Foo"), ("bar", "Bar")], 10).unwrap(); /// # client.flush().unwrap(); /// ``` - pub fn set_multi(&self, entries: I, expiration: u32) -> Result<(), MemcacheError> + pub fn sets(&self, entries: I, expiration: u32) -> Result<(), MemcacheError> where V: ToMemcacheValue, K: AsRef, @@ -270,7 +265,7 @@ impl Client { for (connection_index, entries_subset) in entry_map.into_iter() { let connection = self.connections[connection_index].clone(); - connection.get()?.set_multi(entries_subset, expiration)?; + connection.get()?.sets(entries_subset, expiration)?; } Ok(()) @@ -383,7 +378,7 @@ impl Client { return self.get_connection(key).get()?.delete(key); } - /// Delete keys from memcached server. + /// Delete multiple keys from memcached server. /// /// Uses pipelining to reduce the number of server round trips. /// @@ -391,10 +386,10 @@ impl Client { /// /// ```rust /// let client = memcache::Client::connect("memcache://localhost:12345").unwrap(); - /// client.delete_multi(&["foo", "bar"]).unwrap(); + /// client.deletes(&["foo", "bar"]).unwrap(); /// # client.flush().unwrap(); /// ``` - pub fn delete_multi(&self, keys: I) -> Result, MemcacheError> + pub fn deletes(&self, keys: I) -> Result, MemcacheError> where K: AsRef + Eq + Hash, I: IntoIterator, @@ -407,12 +402,7 @@ impl Client { let mut result: HashMap = HashMap::new(); for (connection_index, keys_subset) in con_keys.into_iter() { let connection = self.connections[connection_index].clone(); - for (deleted, key) in connection - .get()? - .delete_multi(&keys_subset)? - .into_iter() - .zip(keys_subset) - { + for (deleted, key) in connection.get()?.deletes(&keys_subset)?.into_iter().zip(keys_subset) { result.insert(key, deleted); } } diff --git a/src/protocol/ascii.rs b/src/protocol/ascii.rs index cdd3f1a..e217ddd 100644 --- a/src/protocol/ascii.rs +++ b/src/protocol/ascii.rs @@ -352,7 +352,7 @@ impl AsciiProtocol { } } - pub(super) fn get_multi, I: IntoIterator>( + pub(super) fn gets, I: IntoIterator>( &mut self, keys: I, ) -> Result, MemcacheError> { @@ -420,7 +420,7 @@ impl AsciiProtocol { self.store(StoreCommand::Set, key, value, &options).map(|_| ()) } - pub(super) fn set_multi, K: AsRef, I: IntoIterator>( + pub(super) fn sets, K: AsRef, I: IntoIterator>( &mut self, entries: I, expiration: u32, @@ -470,7 +470,7 @@ impl AsciiProtocol { .map(|_| ()) } - pub(super) fn delete_multi, I: IntoIterator>( + pub(super) fn deletes, I: IntoIterator>( &mut self, keys: I, ) -> Result, MemcacheError> { @@ -518,7 +518,7 @@ impl AsciiProtocol { } pub(super) fn delete(&mut self, key: &str) -> Result { - Ok(self.delete_multi(&[key])?[0]) + Ok(self.deletes(&[key])?[0]) } fn parse_u64_response(&mut self) -> Result { diff --git a/src/protocol/binary.rs b/src/protocol/binary.rs index 84eb220..b3d24a9 100644 --- a/src/protocol/binary.rs +++ b/src/protocol/binary.rs @@ -154,7 +154,7 @@ impl BinaryProtocol { return binary_packet::parse_get_response(&mut self.stream); } - pub(super) fn get_multi, I: IntoIterator>( + pub(super) fn gets, I: IntoIterator>( &mut self, keys: I, ) -> Result, MemcacheError> { @@ -204,7 +204,7 @@ impl BinaryProtocol { return self.store(Opcode::Set, key, value, expiration, None); } - pub(super) fn set_multi, K: AsRef, I: IntoIterator>( + pub(super) fn sets, K: AsRef, I: IntoIterator>( &mut self, entries: I, expiration: u32, @@ -263,10 +263,10 @@ impl BinaryProtocol { } pub(super) fn delete(&mut self, key: &str) -> Result { - Ok(self.delete_multi(&[key])?[0]) + Ok(self.deletes(&[key])?[0]) } - pub(super) fn delete_multi, I: IntoIterator>( + pub(super) fn deletes, I: IntoIterator>( &mut self, keys: I, ) -> Result, MemcacheError> { diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index de349db..188cf25 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -32,12 +32,12 @@ pub trait ProtocolTrait { fn flush(&mut self) -> Result<(), MemcacheError>; fn flush_with_delay(&mut self, delay: u32) -> Result<(), MemcacheError>; fn get(&mut self, key: &str) -> Result, MemcacheError>; - fn get_multi, I: IntoIterator>( + fn gets, I: IntoIterator>( &mut self, keys: I, ) -> Result, MemcacheError>; fn set>(&mut self, key: &str, value: V, expiration: u32) -> Result<(), MemcacheError>; - fn set_multi, K: AsRef, I: IntoIterator>( + fn sets, K: AsRef, I: IntoIterator>( &mut self, entries: I, expiration: u32, @@ -59,7 +59,7 @@ pub trait ProtocolTrait { fn append>(&mut self, key: &str, value: V) -> Result<(), MemcacheError>; fn prepend>(&mut self, key: &str, value: V) -> Result<(), MemcacheError>; fn delete(&mut self, key: &str) -> Result; - fn delete_multi, I: IntoIterator>(&mut self, keys: I) -> Result, MemcacheError>; + fn deletes, I: IntoIterator>(&mut self, keys: I) -> Result, MemcacheError>; fn increment(&mut self, key: &str, amount: u64) -> Result; fn decrement(&mut self, key: &str, amount: u64) -> Result; fn touch(&mut self, key: &str, expiration: u32) -> Result; diff --git a/tests/tests.rs b/tests/tests.rs index 06f4d65..5f931b7 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -59,9 +59,9 @@ fn test() { batch.insert(key, "xxx".to_string()); } } - client.set_multi(batch, 0).unwrap(); + client.sets(batch, 0).unwrap(); - let all_at_once: HashMap = client.get_multi(&keys).unwrap(); + let all_at_once: HashMap = client.gets(&keys).unwrap(); assert_eq!(1000, all_at_once.len()); for key in keys.iter() { @@ -70,8 +70,8 @@ fn test() { assert_eq!(all_at_once[key], "xxx"); } - client.delete_multi(&keys).unwrap(); - let all_at_once: HashMap = client.get_multi(&keys).unwrap(); + client.deletes(&keys).unwrap(); + let all_at_once: HashMap = client.gets(&keys).unwrap(); assert_eq!(0, all_at_once.len()); } From 25ddf82ad5a7d75ff0d8ef363924e216618095d8 Mon Sep 17 00:00:00 2001 From: Shane Hathaway Date: Tue, 11 Feb 2020 22:58:38 -0700 Subject: [PATCH 7/7] Implement most of the suggestions by @letmutx. - In ascii::gets, instead of building a string, use BufWriter. Also remove unnecessary allocation. There is some necessary allocation due to the signature of the method. We want to maintain the signature. - Don't add unnecessary indentation just to drop variables. Instead, use drop() or avoid creating variables that need to be dropped. - Reworked the way we collect pipelined responses from the server. Now there's a simple ``final_result`` var and we update that var for each result. - Handle recoverable errors differently from unrecoverable errors. Exit fast on most errors, since they're unrecoverable, but collect all responses on CommandError, since CommandError means we got a full response. --- src/protocol/ascii.rs | 170 ++++++++++++++++++++++------------------- src/protocol/binary.rs | 43 +++++++---- 2 files changed, 121 insertions(+), 92 deletions(-) diff --git a/src/protocol/ascii.rs b/src/protocol/ascii.rs index e217ddd..aba9ce0 100644 --- a/src/protocol/ascii.rs +++ b/src/protocol/ascii.rs @@ -1,6 +1,6 @@ use std::collections::HashMap; use std::fmt; -use std::io::{Read, Write}; +use std::io::{BufWriter, Read, Write}; use super::check_key_len; use client::Stats; @@ -182,79 +182,86 @@ impl AsciiProtocol { let noreply = if options.noreply { " noreply" } else { "" }; let mut sent_count = 0; - { - let reader = self.reader.get_mut(); - for (key_ref, value) in entries.into_iter() { - let key = key_ref.as_ref(); - check_key_len(key)?; - if options.cas.is_some() { - write!( - reader, - "{command} {key} {flags} {exptime} {vlen} {cas}{noreply}\r\n", - command = command, - key = key, - flags = value.get_flags(), - exptime = options.exptime, - vlen = value.get_length(), - cas = options.cas.unwrap(), - noreply = noreply - )?; - } else { - write!( - reader, - "{command} {key} {flags} {exptime} {vlen}{noreply}\r\n", - command = command, - key = key, - flags = value.get_flags(), - exptime = options.exptime, - vlen = value.get_length(), - noreply = noreply - )?; - } - - value.write_to(reader)?; - reader.write(b"\r\n")?; - sent_count += 1; + for (key_ref, value) in entries.into_iter() { + let key = key_ref.as_ref(); + check_key_len(key)?; + if options.cas.is_some() { + write!( + self.reader.get_mut(), + "{command} {key} {flags} {exptime} {vlen} {cas}{noreply}\r\n", + command = command, + key = key, + flags = value.get_flags(), + exptime = options.exptime, + vlen = value.get_length(), + cas = options.cas.unwrap(), + noreply = noreply + )?; + } else { + write!( + self.reader.get_mut(), + "{command} {key} {flags} {exptime} {vlen}{noreply}\r\n", + command = command, + key = key, + flags = value.get_flags(), + exptime = options.exptime, + vlen = value.get_length(), + noreply = noreply + )?; } - // Flush now that all the requests have been written. - reader.flush()?; + value.write_to(self.reader.get_mut())?; + self.reader.get_mut().write_all(b"\r\n")?; + sent_count += 1; } + // Flush now that all the requests have been written. + self.reader.get_mut().flush()?; + if options.noreply { return Ok(true); } - // Receive all the responses. If there were errors, return the first. + // In order to keep the client in sync with the server, + // read all the responses, even after an EXISTS or NOT_FOUND + // error, unless some other error occurs. + // If there were errors, return the first error. + + let mut final_result = Ok(true); - let mut all_stored = true; - let mut error_list: Vec = Vec::new(); for _ in 0..sent_count { let one_result = self.reader.read_line(|response| { let response = MemcacheError::try_from(response)?; match response { "STORED\r\n" => Ok(true), "NOT_STORED\r\n" => Ok(false), - "EXISTS\r\n" => Err(CommandError::KeyExists)?, - "NOT_FOUND\r\n" => Err(CommandError::KeyNotFound)?, - response => Err(ServerError::BadResponse(Cow::Owned(response.into())))?, + "EXISTS\r\n" => Err(CommandError::KeyExists.into()), + "NOT_FOUND\r\n" => Err(CommandError::KeyNotFound.into()), + response => Err(ServerError::BadResponse(Cow::Owned(response.into())).into()), } }); match one_result { Ok(true) => (), - Ok(false) => all_stored = false, - Err(e) => error_list.push(e), + Ok(false) => { + if let Ok(true) = final_result { + final_result = Ok(false) + } + } + Err(MemcacheError::CommandError(e)) => { + // Recoverable error. Report it after reading the rest of the responses. + if final_result.is_ok() { + final_result = Err(MemcacheError::CommandError(e)); + } + } + Err(e) => return Err(e), // Unrecoverable error. Stop immediately. } } - match error_list.into_iter().next() { - None => Ok(all_stored), - Some(e) => Err(e), - } + final_result } pub(super) fn version(&mut self) -> Result { - self.reader.get_mut().write(b"version\r\n")?; + self.reader.get_mut().write_all(b"version\r\n")?; self.reader.get_mut().flush()?; self.reader.read_line(|response| { let response = MemcacheError::try_from(response)?; @@ -356,21 +363,23 @@ impl AsciiProtocol { &mut self, keys: I, ) -> Result, MemcacheError> { - let keys: Vec = keys.into_iter().collect(); - let mut capacity = 0; + // Note: it would be nice to avoid allocation here, but we have to allocate strings + // anyway because the input key type is a reference while the output key type is String. + let keys: Vec = keys.into_iter().map(|s| s.as_ref().to_string()).collect(); + for k in keys.iter() { - let key = k.as_ref(); - check_key_len(key)?; - capacity += key.len() + 1; + check_key_len(k)?; } - let mut keystr = String::with_capacity(capacity); + let mut writer = BufWriter::new(self.reader.get_mut()); + writer.write_all(b"gets")?; for k in keys.iter() { - keystr.push(' '); - keystr.push_str(k.as_ref()); + writer.write_all(b" ")?; + writer.write_all(k.as_bytes())?; } - - write!(self.reader.get_mut(), "gets{}\r\n", keystr)?; + writer.write_all(b"\r\n")?; + writer.flush()?; + drop(writer); let mut result: HashMap = HashMap::with_capacity(keys.len()); // there will be atmost keys.len() "VALUE <...>" responses and one END response @@ -475,22 +484,19 @@ impl AsciiProtocol { keys: I, ) -> Result, MemcacheError> { let mut sent_count = 0; - { - let reader = self.reader.get_mut(); - for k in keys.into_iter() { - let key = k.as_ref(); - check_key_len(key)?; - write!(reader, "delete {}\r\n", key)?; - sent_count += 1; - } - // Flush now that all the requests have been written. - reader.flush()?; + for k in keys.into_iter() { + let key = k.as_ref(); + check_key_len(key)?; + write!(self.reader.get_mut(), "delete {}\r\n", key)?; + sent_count += 1; } + // Flush now that all the requests have been written. + self.reader.get_mut().flush()?; // Receive all the responses. If there were errors, return the first. - let mut deleted_list = Vec::with_capacity(sent_count); - let mut error_list: Vec = Vec::new(); + let mut final_result = Ok(Vec::with_capacity(sent_count)); + for _ in 0..sent_count { let one_result = self .reader @@ -505,16 +511,24 @@ impl AsciiProtocol { Err(MemcacheError::CommandError(CommandError::KeyNotFound)) => Ok(false), Err(e) => Err(e), }); + match one_result { - Ok(deleted) => deleted_list.push(deleted), - Err(e) => error_list.push(e), + Ok(deleted) => { + if let Ok(deleted_list) = &mut final_result { + deleted_list.push(deleted); + } + } + Err(MemcacheError::CommandError(e)) => { + // Recoverable error. Report it after reading the rest of the responses. + if final_result.is_ok() { + final_result = Err(MemcacheError::CommandError(e)); + } + } + Err(e) => return Err(e), // Unrecoverable error. Stop immediately. } } - match error_list.into_iter().next() { - None => Ok(deleted_list), - Some(e) => Err(e), - } + final_result } pub(super) fn delete(&mut self, key: &str) -> Result { @@ -559,7 +573,7 @@ impl AsciiProtocol { } pub(super) fn stats(&mut self) -> Result { - self.reader.get_mut().write(b"stats\r\n")?; + self.reader.get_mut().write_all(b"stats\r\n")?; self.reader.get_mut().flush()?; enum Loop { diff --git a/src/protocol/binary.rs b/src/protocol/binary.rs index b3d24a9..27f7ce7 100644 --- a/src/protocol/binary.rs +++ b/src/protocol/binary.rs @@ -88,18 +88,24 @@ impl BinaryProtocol { } // Flush now that all the requests have been written. self.stream.flush()?; + // Receive all the responses. If there were errors, return the first. - let mut error_list = Vec::new(); + let mut final_result = Ok(()); + for _ in 0..sent_count { match binary_packet::parse_response(&mut self.stream) { Ok(_) => (), - Err(e) => error_list.push(e), + Err(MemcacheError::CommandError(e)) => { + // Recoverable error. Report it after reading the rest of the responses. + if final_result.is_ok() { + final_result = Err(MemcacheError::CommandError(e)); + } + } + Err(e) => return Err(e), // Unrecoverable error. Stop immediately. }; } - match error_list.into_iter().next() { - None => Ok(()), - Some(e) => Err(e), - } + + final_result } pub(super) fn version(&mut self) -> Result { @@ -287,19 +293,28 @@ impl BinaryProtocol { } // Flush now that all the requests have been written. self.stream.flush()?; + // Receive all the responses. If there were errors, return the first. - let mut deleted_list = Vec::with_capacity(sent_count); - let mut error_list: Vec = Vec::new(); + let mut final_result = Ok(Vec::with_capacity(sent_count)); + for _ in 0..sent_count { match binary_packet::parse_delete_response(&mut self.stream) { - Ok(deleted) => deleted_list.push(deleted), - Err(e) => error_list.push(e), + Ok(deleted) => { + if let Ok(deleted_list) = &mut final_result { + deleted_list.push(deleted); + } + } + Err(MemcacheError::CommandError(e)) => { + // Recoverable error. Report it after reading the rest of the responses. + if final_result.is_ok() { + final_result = Err(MemcacheError::CommandError(e)); + } + } + Err(e) => return Err(e), // Unrecoverable error. Stop immediately. } } - match error_list.into_iter().next() { - None => Ok(deleted_list), - Some(e) => Err(e), - } + + final_result } pub(super) fn increment(&mut self, key: &str, amount: u64) -> Result {