Skip to content

Commit

Permalink
add some new methods to extensions (#4)
Browse files Browse the repository at this point in the history
* add some new methods to extensions

* docs
  • Loading branch information
conradludgate committed Aug 30, 2022
1 parent 3f75e39 commit 8e1c6df
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 20 deletions.
4 changes: 1 addition & 3 deletions README.md
Expand Up @@ -45,9 +45,7 @@ async fn my_task() {
}

let a: i64 = 3;
let mut ext = Extensions::new();
ext.insert(a);
let (out_ext, _) = with_extensions(ext, my_task()).await;
let (out_ext, _) = with_extensions(Extensions::new().with(a), my_task()).await;
let msg = out_ext.get::<String>().unwrap();
assert_eq!(msg.as_str(), "The value of a is: 3");
```
Expand Down
43 changes: 29 additions & 14 deletions src/extensions.rs
Expand Up @@ -15,46 +15,62 @@ use std::hash::{BuildHasherDefault, Hasher};
/// on the outgoing path (e.g. error class).
#[derive(Default)]
pub struct Extensions {
map: Option<HashMap<TypeId, Box<dyn Any + Send + Sync>, BuildHasherDefault<IdHasher>>>,
map: HashMap<TypeId, Box<dyn Any + Send + Sync>, BuildHasherDefault<IdHasher>>,
}

impl Extensions {
/// Create an empty `Extensions`.
pub fn new() -> Self {
Self { map: None }
Self {
map: HashMap::default(),
}
}

/// Insert a value ino this [`Extensions`], returning self instead of any pre-inserted values.
///
/// This is useful for any builder style patterns
///
/// ```
/// # use task_local_extensions::Extensions;
/// let ext = Extensions::new().with(true).with(5_i32);
/// assert_eq!(ext.get(), Some(&true));
/// assert_eq!(ext.get(), Some(&5_i32));
/// ```
pub fn with<T: Send + Sync + 'static>(mut self, val: T) -> Self {
self.insert(val);
self
}

/// Removes the values from `other` and inserts them into `self`.
pub fn append(&mut self, other: &mut Self) {
self.map.extend(other.map.drain())
}

/// Insert a value into this `Extensions`.
///
/// If a value of this type already exists, it will be returned.
pub fn insert<T: Send + Sync + 'static>(&mut self, val: T) -> Option<T> {
self.map
.get_or_insert_with(Default::default)
.insert(TypeId::of::<T>(), Box::new(val))
.and_then(|boxed| (boxed as Box<dyn Any>).downcast().ok().map(|boxed| *boxed))
}

/// Check if container contains value for type
pub fn contains<T: 'static>(&self) -> bool {
self.map
.as_ref()
.and_then(|m| m.get(&TypeId::of::<T>()))
.is_some()
self.map.get(&TypeId::of::<T>()).is_some()
}

/// Get a reference to a value previously inserted on this `Extensions`.
pub fn get<T: 'static>(&self) -> Option<&T> {
self.map
.as_ref()
.and_then(|m| m.get(&TypeId::of::<T>()))
.get(&TypeId::of::<T>())
.and_then(|boxed| (&**boxed as &(dyn Any)).downcast_ref())
}

/// Get a mutable reference to a value previously inserted on this `Extensions`.
pub fn get_mut<T: 'static>(&mut self) -> Option<&mut T> {
self.map
.as_mut()
.and_then(|m| m.get_mut(&TypeId::of::<T>()))
.get_mut(&TypeId::of::<T>())
.and_then(|boxed| (&mut **boxed as &mut (dyn Any)).downcast_mut())
}

Expand All @@ -63,15 +79,14 @@ impl Extensions {
/// If a value of this type exists, it will be returned.
pub fn remove<T: 'static>(&mut self) -> Option<T> {
self.map
.as_mut()
.and_then(|m| m.remove(&TypeId::of::<T>()))
.remove(&TypeId::of::<T>())
.and_then(|boxed| (boxed as Box<dyn Any>).downcast().ok().map(|boxed| *boxed))
}

/// Clear the `Extensions` of all inserted values.
#[inline]
pub fn clear(&mut self) {
self.map = None;
self.map.clear();
}
}

Expand Down
31 changes: 31 additions & 0 deletions src/lib.rs
@@ -1,4 +1,35 @@
//! A type map for storing data of arbritrary type.
//!
//! # Extensions
//! [`Extensions`] is a container that can store up to one value of each type, so you can insert and retrive values by
//! their type:
//!
//! ```
//! use task_local_extensions::Extensions;
//!
//! let a: i64 = 3;
//! let mut ext = Extensions::new();
//! extensions.insert(a);
//! assert_eq!(ext.get::<i64>(), Some(&3));
//! ```
//!
//! # Task Local Extensions
//! The crate also provides [`with_extensions`] so you set an [`Extensions`] instance while running a given task:
//!
//! ```
//! use task_local_extensions::{get_local_item, set_local_item, with_extensions, Extensions};
//!
//! async fn my_task() {
//! let a: i64 = get_local_item().await.unwrap(0);
//! let msg = format!("The value of a is: {}", a);
//! set_local_item(msg).await;
//! }
//!
//! let a: i64 = 3;
//! let (out_ext, _) = with_extensions(Extensions::new().with(a), my_task()).await;
//! let msg = out_ext.get::<String>().unwrap();
//! assert_eq!(msg.as_str(), "The value of a is: 3");
//! ```

mod extensions;
mod task_local;
Expand Down
10 changes: 7 additions & 3 deletions src/task_local.rs
@@ -1,3 +1,7 @@
// clippy bug wrongly flags the task_local macro as being bad.
// a fix is already merged but hasn't made it upstream yet
#![allow(clippy::declare_interior_mutable_const)]

use crate::Extensions;
use std::cell::RefCell;
use std::future::Future;
Expand All @@ -16,11 +20,11 @@ pub async fn with_extensions<T>(
EXTENSIONS
.scope(RefCell::new(extensions), async move {
let response = fut.await;
let extensions = RefCell::new(Extensions::new());
let mut extensions = Extensions::new();

EXTENSIONS.with(|ext| ext.swap(&extensions));
EXTENSIONS.with(|ext| std::mem::swap(&mut *ext.borrow_mut(), &mut extensions));

(extensions.into_inner(), response)
(extensions, response)
})
.await
}
Expand Down

0 comments on commit 8e1c6df

Please sign in to comment.