diff --git a/src/client.rs b/src/client.rs index 8cd5ab0..1d2216e 100644 --- a/src/client.rs +++ b/src/client.rs @@ -189,7 +189,13 @@ impl Client { return self.get_connection(key).get()?.get(key); } - /// Get multiple keys from memcached server. Using this function instead of calling `get` multiple times can reduce netwark workloads. + /// Map a key to a connection index. + fn hash_key(&self, key: &str) -> usize { + let connections_count = self.connections.len(); + (self.hash_function)(key) as usize % connections_count + } + + /// Get multiple keys from memcached server. Using this function instead of calling `get` multiple times can reduce network workloads. /// /// Example: /// @@ -200,15 +206,17 @@ impl Client { /// assert_eq!(result.len(), 1); /// assert_eq!(result["foo"], "42"); /// ``` - pub fn gets(&self, keys: &[&str]) -> Result, MemcacheError> { - let mut con_keys: HashMap> = HashMap::new(); + pub fn gets(&self, keys: I) -> Result, MemcacheError> + where + V: FromMemcacheValueExt, + K: AsRef, + I: IntoIterator, + { + let mut con_keys: HashMap> = HashMap::new(); let mut result: HashMap = HashMap::new(); - let connections_count = self.connections.len(); - for key in keys { - let connection_index = (self.hash_function)(key) as usize % connections_count; - let array = con_keys.entry(connection_index).or_insert_with(Vec::new); - array.push(key); + for k in keys { + con_keys.entry(self.hash_key(k.as_ref())).or_default().push(k); } for (&connection_index, keys) in con_keys.iter() { let connection = self.connections[connection_index].clone(); @@ -230,6 +238,39 @@ impl Client { return self.get_connection(key).get()?.set(key, value, expiration); } + /// Set multiple keys with associated values into memcached server with expiration seconds. + /// + /// Uses pipelining to reduce the number of server round trips. + /// + /// Example: + /// + /// ```rust + /// let client = memcache::Client::connect("memcache://localhost:12345").unwrap(); + /// client.sets(vec![("foo", "Foo"), ("bar", "Bar")], 10).unwrap(); + /// # client.flush().unwrap(); + /// ``` + pub fn sets(&self, entries: I, expiration: u32) -> Result<(), MemcacheError> + where + V: ToMemcacheValue, + K: AsRef, + I: IntoIterator, + { + let mut entry_map: HashMap> = HashMap::new(); + for (key, value) in entries.into_iter() { + entry_map + .entry(self.hash_key(key.as_ref())) + .or_default() + .push((key, value)); + } + + for (connection_index, entries_subset) in entry_map.into_iter() { + let connection = self.connections[connection_index].clone(); + connection.get()?.sets(entries_subset, expiration)?; + } + + Ok(()) + } + /// Compare and swap a key with the associate value into memcached server with expiration seconds. /// `cas_id` should be obtained from a previous `gets` call. /// @@ -337,6 +378,37 @@ impl Client { return self.get_connection(key).get()?.delete(key); } + /// Delete multiple keys from memcached server. + /// + /// Uses pipelining to reduce the number of server round trips. + /// + /// Example: + /// + /// ```rust + /// let client = memcache::Client::connect("memcache://localhost:12345").unwrap(); + /// client.deletes(&["foo", "bar"]).unwrap(); + /// # client.flush().unwrap(); + /// ``` + pub fn deletes(&self, keys: I) -> Result, MemcacheError> + where + K: AsRef + Eq + Hash, + I: IntoIterator, + { + let mut con_keys: HashMap> = HashMap::new(); + for key in keys.into_iter() { + con_keys.entry(self.hash_key(key.as_ref())).or_default().push(key); + } + + let mut result: HashMap = HashMap::new(); + for (connection_index, keys_subset) in con_keys.into_iter() { + let connection = self.connections[connection_index].clone(); + for (deleted, key) in connection.get()?.deletes(&keys_subset)?.into_iter().zip(keys_subset) { + result.insert(key, deleted); + } + } + Ok(result) + } + /// Increment the value with amount. /// /// Example: diff --git a/src/protocol/ascii.rs b/src/protocol/ascii.rs index 45014b3..e217ddd 100644 --- a/src/protocol/ascii.rs +++ b/src/protocol/ascii.rs @@ -107,9 +107,23 @@ impl CappedLineReader { return Err(ClientError::Error(Cow::Borrowed("Ascii protocol no line found")))?; } self.filled += read; - if let Some(n) = get_line(&buf[..read]) { - let result = cb(std::str::from_utf8(&self.buf[..filled + n])?); - self.consume(n); + + // Find the next \r\n. + let search_start; + let search_buf; + if filled > 0 { + // Start searching one character back, otherwise we would skip over \r\n + // sequences that happen to straddle packet boundaries. + search_start = filled - 1; + search_buf = &self.buf[search_start..read + 1]; + } else { + search_start = filled; + search_buf = buf; + } + + if let Some(n) = get_line(search_buf) { + let result = cb(std::str::from_utf8(&self.buf[..search_start + n])?); + self.consume(search_start + n); return result; } } @@ -148,7 +162,15 @@ impl AsciiProtocol { value: V, options: &Options, ) -> Result { - check_key_len(key)?; + Ok(self.stores(command, Some((key, value)), options)?) + } + + fn stores, K: AsRef, I: IntoIterator>( + &mut self, + command: StoreCommand, + entries: I, + options: &Options, + ) -> Result { if command == StoreCommand::Cas { if options.cas.is_none() { Err(ClientError::Error(Cow::Borrowed( @@ -156,50 +178,79 @@ impl AsciiProtocol { )))?; } } + let noreply = if options.noreply { " noreply" } else { "" }; - if options.cas.is_some() { - write!( - self.reader.get_mut(), - "{command} {key} {flags} {exptime} {vlen} {cas}{noreply}\r\n", - command = command, - key = key, - flags = value.get_flags(), - exptime = options.exptime, - vlen = value.get_length(), - cas = options.cas.unwrap(), - noreply = noreply - )?; - } else { - write!( - self.reader.get_mut(), - "{command} {key} {flags} {exptime} {vlen}{noreply}\r\n", - command = command, - key = key, - flags = value.get_flags(), - exptime = options.exptime, - vlen = value.get_length(), - noreply = noreply - )?; - } + let mut sent_count = 0; + + { + let reader = self.reader.get_mut(); + for (key_ref, value) in entries.into_iter() { + let key = key_ref.as_ref(); + check_key_len(key)?; + if options.cas.is_some() { + write!( + reader, + "{command} {key} {flags} {exptime} {vlen} {cas}{noreply}\r\n", + command = command, + key = key, + flags = value.get_flags(), + exptime = options.exptime, + vlen = value.get_length(), + cas = options.cas.unwrap(), + noreply = noreply + )?; + } else { + write!( + reader, + "{command} {key} {flags} {exptime} {vlen}{noreply}\r\n", + command = command, + key = key, + flags = value.get_flags(), + exptime = options.exptime, + vlen = value.get_length(), + noreply = noreply + )?; + } - value.write_to(self.reader.get_mut())?; - self.reader.get_mut().write(b"\r\n")?; - self.reader.get_mut().flush()?; + value.write_to(reader)?; + reader.write(b"\r\n")?; + sent_count += 1; + } + + // Flush now that all the requests have been written. + reader.flush()?; + } if options.noreply { return Ok(true); } - self.reader.read_line(|response| { - let response = MemcacheError::try_from(response)?; - match response { - "STORED\r\n" => Ok(true), - "NOT_STORED\r\n" => Ok(false), - "EXISTS\r\n" => Err(CommandError::KeyExists)?, - "NOT_FOUND\r\n" => Err(CommandError::KeyNotFound)?, - response => Err(ServerError::BadResponse(Cow::Owned(response.into())))?, + // Receive all the responses. If there were errors, return the first. + + let mut all_stored = true; + let mut error_list: Vec = Vec::new(); + for _ in 0..sent_count { + let one_result = self.reader.read_line(|response| { + let response = MemcacheError::try_from(response)?; + match response { + "STORED\r\n" => Ok(true), + "NOT_STORED\r\n" => Ok(false), + "EXISTS\r\n" => Err(CommandError::KeyExists)?, + "NOT_FOUND\r\n" => Err(CommandError::KeyNotFound)?, + response => Err(ServerError::BadResponse(Cow::Owned(response.into())))?, + } + }); + match one_result { + Ok(true) => (), + Ok(false) => all_stored = false, + Err(e) => error_list.push(e), } - }) + } + + match error_list.into_iter().next() { + None => Ok(all_stored), + Some(e) => Err(e), + } } pub(super) fn version(&mut self) -> Result { @@ -301,11 +352,25 @@ impl AsciiProtocol { } } - pub(super) fn gets(&mut self, keys: &[&str]) -> Result, MemcacheError> { - for key in keys { + pub(super) fn gets, I: IntoIterator>( + &mut self, + keys: I, + ) -> Result, MemcacheError> { + let keys: Vec = keys.into_iter().collect(); + let mut capacity = 0; + for k in keys.iter() { + let key = k.as_ref(); check_key_len(key)?; + capacity += key.len() + 1; } - write!(self.reader.get_mut(), "gets {}\r\n", keys.join(" "))?; + + let mut keystr = String::with_capacity(capacity); + for k in keys.iter() { + keystr.push(' '); + keystr.push_str(k.as_ref()); + } + + write!(self.reader.get_mut(), "gets{}\r\n", keystr)?; let mut result: HashMap = HashMap::with_capacity(keys.len()); // there will be atmost keys.len() "VALUE <...>" responses and one END response @@ -355,6 +420,18 @@ impl AsciiProtocol { self.store(StoreCommand::Set, key, value, &options).map(|_| ()) } + pub(super) fn sets, K: AsRef, I: IntoIterator>( + &mut self, + entries: I, + expiration: u32, + ) -> Result<(), MemcacheError> { + let options = Options { + exptime: expiration, + ..Default::default() + }; + self.stores(StoreCommand::Set, entries, &options).map(|_| ()) + } + pub(super) fn add>( &mut self, key: &str, @@ -393,22 +470,55 @@ impl AsciiProtocol { .map(|_| ()) } - pub(super) fn delete(&mut self, key: &str) -> Result { - check_key_len(key)?; - write!(self.reader.get_mut(), "delete {}\r\n", key)?; - self.reader.get_mut().flush()?; - self.reader - .read_line(|response| match MemcacheError::try_from(response) { - Ok(s) => { - if s == "DELETED\r\n" { - Ok(true) - } else { - Err(ServerError::BadResponse(Cow::Owned(s.into())).into()) + pub(super) fn deletes, I: IntoIterator>( + &mut self, + keys: I, + ) -> Result, MemcacheError> { + let mut sent_count = 0; + { + let reader = self.reader.get_mut(); + for k in keys.into_iter() { + let key = k.as_ref(); + check_key_len(key)?; + write!(reader, "delete {}\r\n", key)?; + sent_count += 1; + } + // Flush now that all the requests have been written. + reader.flush()?; + } + + // Receive all the responses. If there were errors, return the first. + + let mut deleted_list = Vec::with_capacity(sent_count); + let mut error_list: Vec = Vec::new(); + for _ in 0..sent_count { + let one_result = self + .reader + .read_line(|response| match MemcacheError::try_from(response) { + Ok(s) => { + if s == "DELETED\r\n" { + Ok(true) + } else { + Err(ServerError::BadResponse(Cow::Owned(s.into())).into()) + } } - } - Err(MemcacheError::CommandError(CommandError::KeyNotFound)) => Ok(false), - Err(e) => Err(e), - }) + Err(MemcacheError::CommandError(CommandError::KeyNotFound)) => Ok(false), + Err(e) => Err(e), + }); + match one_result { + Ok(deleted) => deleted_list.push(deleted), + Err(e) => error_list.push(e), + } + } + + match error_list.into_iter().next() { + None => Ok(deleted_list), + Some(e) => Err(e), + } + } + + pub(super) fn delete(&mut self, key: &str) -> Result { + Ok(self.deletes(&[key])?[0]) } fn parse_u64_response(&mut self) -> Result { @@ -484,3 +594,86 @@ impl AsciiProtocol { } } } + +#[cfg(test)] +mod test { + #[test] + fn test_read_line_with_line_straddling_packets() { + use super::CappedLineReader; + use std::io::Cursor; + use std::io::Seek; + use std::io::Write; + + let mut cursor = Cursor::new(Vec::new()); + // Write 102 * 20 = 2040 characters + for _ in 0..102 { + cursor.write(b"1234567890abcdefghij").unwrap(); + } + cursor + .write(b"\r\nline 2 to be read exactly\r\nline 3 to be sure\r\n") + .unwrap(); + cursor.seek(std::io::SeekFrom::Start(0)).unwrap(); + + let mut capped = CappedLineReader::new(cursor); + + let length = capped + .read_line(|line| { + assert_eq!(2042, line.len()); // 102 * 20 + "\r\n".len() + Ok(line.len()) + }) + .unwrap(); + assert_eq!(2042, length); + + let length = capped + .read_line(|line| { + assert_eq!("line 2 to be read exactly\r\n", line); + Ok(line.len()) + }) + .unwrap(); + assert_eq!(27, length); + + // Older versions would fail here because not all of the + // consumed bytes were marked as consumed. + let length = capped + .read_line(|line| { + assert_eq!("line 3 to be sure\r\n", line); + Ok(line.len()) + }) + .unwrap(); + assert_eq!(19, length); + } + + #[test] + fn test_read_line_with_crlf_straddling_packets() { + use super::CappedLineReader; + use std::io::Read; + + struct FourBytePacketReader { + content: Vec, + pos: usize, + } + + impl Read for FourBytePacketReader { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + let size = (self.content.len() - self.pos).min(4); + buf[..size].copy_from_slice(&self.content[self.pos..self.pos + size]); + self.pos += size; + Ok(size) + } + } + + let inner = FourBytePacketReader { + content: Vec::from("GET\r\nOK\r\n"), + pos: 0, + }; + let mut capped = CappedLineReader::new(inner); + + let length = capped + .read_line(|line| { + assert_eq!("GET\r\n", line); + Ok(line.len()) + }) + .unwrap(); + assert_eq!(5, length); + } +} diff --git a/src/protocol/binary.rs b/src/protocol/binary.rs index 39367f8..b3d24a9 100644 --- a/src/protocol/binary.rs +++ b/src/protocol/binary.rs @@ -57,7 +57,7 @@ impl BinaryProtocol { self.stream.write_u32::(extras.expiration)?; self.stream.write_all(key.as_bytes())?; value.write_to(&mut self.stream)?; - self.stream.flush().map_err(Into::into) + Ok(()) } fn store>( @@ -69,9 +69,39 @@ impl BinaryProtocol { cas: Option, ) -> Result<(), MemcacheError> { self.send_request(opcode, key, value, expiration, cas)?; + self.stream.flush()?; binary_packet::parse_response(&mut self.stream)?.err().map(|_| ()) } + /// Support efficient multi-store operations using pipelining. + fn stores, K: AsRef, I: IntoIterator>( + &mut self, + opcode: Opcode, + entries: I, + expiration: u32, + cas: Option, + ) -> Result<(), MemcacheError> { + let mut sent_count = 0; + for (key, value) in entries.into_iter() { + self.send_request(opcode, key.as_ref(), value, expiration, cas)?; + sent_count += 1; + } + // Flush now that all the requests have been written. + self.stream.flush()?; + // Receive all the responses. If there were errors, return the first. + let mut error_list = Vec::new(); + for _ in 0..sent_count { + match binary_packet::parse_response(&mut self.stream) { + Ok(_) => (), + Err(e) => error_list.push(e), + }; + } + match error_list.into_iter().next() { + None => Ok(()), + Some(e) => Err(e), + } + } + pub(super) fn version(&mut self) -> Result { let request_header = PacketHeader { magic: Magic::Request as u8, @@ -124,8 +154,13 @@ impl BinaryProtocol { return binary_packet::parse_get_response(&mut self.stream); } - pub(super) fn gets(&mut self, keys: &[&str]) -> Result, MemcacheError> { - for key in keys { + pub(super) fn gets, I: IntoIterator>( + &mut self, + keys: I, + ) -> Result, MemcacheError> { + let mut count = 0; + for k in keys.into_iter() { + let key = k.as_ref(); check_key_len(key)?; let request_header = PacketHeader { magic: Magic::Request as u8, @@ -136,6 +171,7 @@ impl BinaryProtocol { }; request_header.write(&mut self.stream)?; self.stream.write_all(key.as_bytes())?; + count += 1; } let noop_request_header = PacketHeader { magic: Magic::Request as u8, @@ -143,7 +179,8 @@ impl BinaryProtocol { ..Default::default() }; noop_request_header.write(&mut self.stream)?; - return binary_packet::parse_gets_response(&mut self.stream, keys.len()); + self.stream.flush()?; + return binary_packet::parse_gets_response(&mut self.stream, count); } pub(super) fn cas>( @@ -154,6 +191,7 @@ impl BinaryProtocol { cas: u64, ) -> Result { self.send_request(Opcode::Set, key, value, expiration, Some(cas))?; + self.stream.flush()?; binary_packet::parse_cas_response(&mut self.stream) } @@ -166,6 +204,14 @@ impl BinaryProtocol { return self.store(Opcode::Set, key, value, expiration, None); } + pub(super) fn sets, K: AsRef, I: IntoIterator>( + &mut self, + entries: I, + expiration: u32, + ) -> Result<(), MemcacheError> { + return self.stores(Opcode::Set, entries, expiration, None); + } + pub(super) fn add>( &mut self, key: &str, @@ -217,18 +263,43 @@ impl BinaryProtocol { } pub(super) fn delete(&mut self, key: &str) -> Result { - check_key_len(key)?; - let request_header = PacketHeader { - magic: Magic::Request as u8, - opcode: Opcode::Delete as u8, - key_length: key.len() as u16, - total_body_length: key.len() as u32, - ..Default::default() - }; - request_header.write(&mut self.stream)?; - self.stream.write_all(key.as_bytes())?; + Ok(self.deletes(&[key])?[0]) + } + + pub(super) fn deletes, I: IntoIterator>( + &mut self, + keys: I, + ) -> Result, MemcacheError> { + let mut sent_count = 0; + for k in keys.into_iter() { + check_key_len(k.as_ref())?; + let key = k.as_ref(); + let request_header = PacketHeader { + magic: Magic::Request as u8, + opcode: Opcode::Delete as u8, + key_length: key.len() as u16, + total_body_length: key.len() as u32, + ..Default::default() + }; + request_header.write(&mut self.stream)?; + self.stream.write_all(key.as_bytes())?; + sent_count += 1; + } + // Flush now that all the requests have been written. self.stream.flush()?; - return binary_packet::parse_delete_response(&mut self.stream); + // Receive all the responses. If there were errors, return the first. + let mut deleted_list = Vec::with_capacity(sent_count); + let mut error_list: Vec = Vec::new(); + for _ in 0..sent_count { + match binary_packet::parse_delete_response(&mut self.stream) { + Ok(deleted) => deleted_list.push(deleted), + Err(e) => error_list.push(e), + } + } + match error_list.into_iter().next() { + None => Ok(deleted_list), + Some(e) => Err(e), + } } pub(super) fn increment(&mut self, key: &str, amount: u64) -> Result { diff --git a/src/protocol/binary_packet.rs b/src/protocol/binary_packet.rs index 0a11a87..dcdfb7b 100644 --- a/src/protocol/binary_packet.rs +++ b/src/protocol/binary_packet.rs @@ -8,6 +8,7 @@ use value::FromMemcacheValueExt; const OK_STATUS: u16 = 0x0; #[allow(dead_code)] +#[derive(Copy, Clone)] pub enum Opcode { Get = 0x00, Set = 0x01, diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 56d667b..188cf25 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -32,8 +32,16 @@ pub trait ProtocolTrait { fn flush(&mut self) -> Result<(), MemcacheError>; fn flush_with_delay(&mut self, delay: u32) -> Result<(), MemcacheError>; fn get(&mut self, key: &str) -> Result, MemcacheError>; - fn gets(&mut self, keys: &[&str]) -> Result, MemcacheError>; + fn gets, I: IntoIterator>( + &mut self, + keys: I, + ) -> Result, MemcacheError>; fn set>(&mut self, key: &str, value: V, expiration: u32) -> Result<(), MemcacheError>; + fn sets, K: AsRef, I: IntoIterator>( + &mut self, + entries: I, + expiration: u32, + ) -> Result<(), MemcacheError>; fn cas>( &mut self, key: &str, @@ -51,6 +59,7 @@ pub trait ProtocolTrait { fn append>(&mut self, key: &str, value: V) -> Result<(), MemcacheError>; fn prepend>(&mut self, key: &str, value: V) -> Result<(), MemcacheError>; fn delete(&mut self, key: &str) -> Result; + fn deletes, I: IntoIterator>(&mut self, keys: I) -> Result, MemcacheError>; fn increment(&mut self, key: &str, amount: u64) -> Result; fn decrement(&mut self, key: &str, amount: u64) -> Result; fn touch(&mut self, key: &str, expiration: u32) -> Result; diff --git a/tests/tests.rs b/tests/tests.rs index a50049a..5f931b7 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -17,6 +17,8 @@ fn gen_random_key() -> String { #[test] fn test() { + use std::collections::HashMap; + let mut urls = vec![ "memcache://localhost:12346?tcp_nodelay=true", "memcache://localhost:12347?timeout=10", @@ -45,16 +47,32 @@ fn test() { assert_eq!(value, None); let mut keys: Vec = Vec::new(); - for _ in 0..1000 { + let mut batch: HashMap = HashMap::new(); + for i in 0..1000 { let key = gen_random_key(); keys.push(key.clone()); - client.set(key.as_str(), "xxx", 0).unwrap(); + if i < 10 { + // Set the first 10 one at a time + client.set(key.as_str(), "xxx", 0).unwrap(); + } else { + // Set the rest as a batch + batch.insert(key, "xxx".to_string()); + } } + client.sets(batch, 0).unwrap(); - for key in keys { + let all_at_once: HashMap = client.gets(&keys).unwrap(); + assert_eq!(1000, all_at_once.len()); + + for key in keys.iter() { let value: String = client.get(key.as_str()).unwrap().unwrap(); assert_eq!(value, "xxx"); + assert_eq!(all_at_once[key], "xxx"); } + + client.deletes(&keys).unwrap(); + let all_at_once: HashMap = client.gets(&keys).unwrap(); + assert_eq!(0, all_at_once.len()); } #[test]