Skip to content

Commit

Permalink
Merge pull request #778 from nappa85/master
Browse files Browse the repository at this point in the history
Stream metrics
  • Loading branch information
tyt2y3 committed Jun 26, 2022
2 parents 580fa90 + 0e1c825 commit d074faf
Show file tree
Hide file tree
Showing 5 changed files with 113 additions and 67 deletions.
56 changes: 56 additions & 0 deletions src/database/stream/metric.rs
@@ -0,0 +1,56 @@
use std::{time::Duration, pin::Pin, task::Poll};

use futures::Stream;

use crate::{QueryResult, DbErr, Statement};

pub(crate) struct MetricStream<'a> {
metric_callback: &'a Option<crate::metric::Callback>,
stmt: &'a Statement,
elapsed: Option<Duration>,
stream: Pin<Box<dyn Stream<Item = Result<QueryResult, DbErr>> + 'a + Send>>,
}

impl<'a> MetricStream<'a> {
pub(crate) fn new<S>(metric_callback: &'a Option<crate::metric::Callback>, stmt: &'a Statement, elapsed: Option<Duration>, stream: S) -> Self
where
S: Stream<Item = Result<QueryResult, DbErr>> + 'a + Send,
{
MetricStream {
metric_callback,
stmt,
elapsed,
stream: Box::pin(stream),
}
}
}

impl<'a> Stream for MetricStream<'a> {
type Item = Result<QueryResult, DbErr>;

fn poll_next(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
let _start = this.metric_callback.is_some().then(std::time::SystemTime::now);
let res = Pin::new(&mut this.stream).poll_next(cx);
if let (Some(_start), Some(elapsed)) = (_start, &mut this.elapsed) {
*elapsed += _start.elapsed().unwrap_or_default();
}
res
}
}

impl<'a> Drop for MetricStream<'a> {
fn drop(&mut self) {
if let (Some(callback), Some(elapsed)) = (self.metric_callback.as_deref(), self.elapsed) {
let info = crate::metric::Info {
elapsed: elapsed,
statement: self.stmt,
failed: false,
};
callback(&info);
}
}
}
2 changes: 2 additions & 0 deletions src/database/stream/mod.rs
@@ -1,3 +1,5 @@
mod metric;

mod query;
mod transaction;

Expand Down
48 changes: 26 additions & 22 deletions src/database/stream/query.rs
@@ -1,6 +1,6 @@
#![allow(missing_docs)]

use std::{pin::Pin, task::Poll};
use std::{pin::Pin, task::Poll, time::SystemTime};

#[cfg(feature = "mock")]
use std::sync::Arc;
Expand All @@ -16,6 +16,8 @@ use tracing::instrument;

use crate::{DbErr, InnerConnection, QueryResult, Statement};

use super::metric::MetricStream;

/// Creates a stream from a [QueryResult]
#[ouroboros::self_referencing]
pub struct QueryStream {
Expand All @@ -24,7 +26,7 @@ pub struct QueryStream {
metric_callback: Option<crate::metric::Callback>,
#[borrows(mut conn, stmt, metric_callback)]
#[not_covariant]
stream: Pin<Box<dyn Stream<Item = Result<QueryResult, DbErr>> + Send + 'this>>,
stream: MetricStream<'this>,
}

#[cfg(feature = "sqlx-mysql")]
Expand Down Expand Up @@ -124,38 +126,40 @@ impl QueryStream {
#[cfg(feature = "sqlx-mysql")]
InnerConnection::MySql(c) => {
let query = crate::driver::sqlx_mysql::sqlx_query(stmt);
crate::metric::metric_ok!(_metric_callback, stmt, {
Box::pin(
c.fetch(query)
let _start = _metric_callback.is_some().then(SystemTime::now);
let stream = c.fetch(query)
.map_ok(Into::into)
.map_err(crate::sqlx_error_to_query_err),
)
})
.map_err(crate::sqlx_error_to_query_err);
let elapsed = _start.map(|s| s.elapsed().unwrap_or_default());
MetricStream::new(_metric_callback, stmt, elapsed, stream)
}
#[cfg(feature = "sqlx-postgres")]
InnerConnection::Postgres(c) => {
let query = crate::driver::sqlx_postgres::sqlx_query(stmt);
crate::metric::metric_ok!(_metric_callback, stmt, {
Box::pin(
c.fetch(query)
let _start = _metric_callback.is_some().then(SystemTime::now);
let stream = c.fetch(query)
.map_ok(Into::into)
.map_err(crate::sqlx_error_to_query_err),
)
})
.map_err(crate::sqlx_error_to_query_err);
let elapsed = _start.map(|s| s.elapsed().unwrap_or_default());
MetricStream::new(_metric_callback, stmt, elapsed, stream)
}
#[cfg(feature = "sqlx-sqlite")]
InnerConnection::Sqlite(c) => {
let query = crate::driver::sqlx_sqlite::sqlx_query(stmt);
crate::metric::metric_ok!(_metric_callback, stmt, {
Box::pin(
c.fetch(query)
let _start = _metric_callback.is_some().then(SystemTime::now);
let stream = c.fetch(query)
.map_ok(Into::into)
.map_err(crate::sqlx_error_to_query_err),
)
})
.map_err(crate::sqlx_error_to_query_err);
let elapsed = _start.map(|s| s.elapsed().unwrap_or_default());
MetricStream::new(_metric_callback, stmt, elapsed, stream)
}
#[cfg(feature = "mock")]
InnerConnection::Mock(c) => c.fetch(stmt),
InnerConnection::Mock(c) => {
let _start = _metric_callback.is_some().then(SystemTime::now);
let stream = c.fetch(stmt);
let elapsed = _start.map(|s| s.elapsed().unwrap_or_default());
MetricStream::new(_metric_callback, stmt, elapsed, stream)
},
#[allow(unreachable_patterns)]
_ => unreachable!(),
},
Expand All @@ -172,6 +176,6 @@ impl Stream for QueryStream {
cx: &mut std::task::Context<'_>,
) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
this.with_stream_mut(|stream| stream.as_mut().poll_next(cx))
this.with_stream_mut(|stream| Pin::new(stream).poll_next(cx))
}
}
51 changes: 26 additions & 25 deletions src/database/stream/transaction.rs
@@ -1,6 +1,6 @@
#![allow(missing_docs)]

use std::{ops::DerefMut, pin::Pin, task::Poll};
use std::{ops::DerefMut, pin::Pin, task::Poll, time::SystemTime};

use futures::Stream;
#[cfg(feature = "sqlx-dep")]
Expand All @@ -15,6 +15,8 @@ use tracing::instrument;

use crate::{DbErr, InnerConnection, QueryResult, Statement};

use super::metric::MetricStream;

/// `TransactionStream` cannot be used in a `transaction` closure as it does not impl `Send`.
/// It seems to be a Rust limitation right now, and solution to work around this deemed to be extremely hard.
#[ouroboros::self_referencing]
Expand All @@ -24,7 +26,7 @@ pub struct TransactionStream<'a> {
metric_callback: Option<crate::metric::Callback>,
#[borrows(mut conn, stmt, metric_callback)]
#[not_covariant]
stream: Pin<Box<dyn Stream<Item = Result<QueryResult, DbErr>> + 'this + Send>>,
stream: MetricStream<'this>,
}

impl<'a> std::fmt::Debug for TransactionStream<'a> {
Expand All @@ -48,41 +50,40 @@ impl<'a> TransactionStream<'a> {
#[cfg(feature = "sqlx-mysql")]
InnerConnection::MySql(c) => {
let query = crate::driver::sqlx_mysql::sqlx_query(stmt);
crate::metric::metric_ok!(_metric_callback, stmt, {
Box::pin(
c.fetch(query)
let _start = _metric_callback.is_some().then(SystemTime::now);
let stream = c.fetch(query)
.map_ok(Into::into)
.map_err(crate::sqlx_error_to_query_err),
)
as Pin<Box<dyn Stream<Item = Result<QueryResult, DbErr>> + Send>>
})
.map_err(crate::sqlx_error_to_query_err);
let elapsed = _start.map(|s| s.elapsed().unwrap_or_default());
MetricStream::new(_metric_callback, stmt, elapsed, stream)
}
#[cfg(feature = "sqlx-postgres")]
InnerConnection::Postgres(c) => {
let query = crate::driver::sqlx_postgres::sqlx_query(stmt);
crate::metric::metric_ok!(_metric_callback, stmt, {
Box::pin(
c.fetch(query)
let _start = _metric_callback.is_some().then(SystemTime::now);
let stream = c.fetch(query)
.map_ok(Into::into)
.map_err(crate::sqlx_error_to_query_err),
)
as Pin<Box<dyn Stream<Item = Result<QueryResult, DbErr>> + Send>>
})
.map_err(crate::sqlx_error_to_query_err);
let elapsed = _start.map(|s| s.elapsed().unwrap_or_default());
MetricStream::new(_metric_callback, stmt, elapsed, stream)
}
#[cfg(feature = "sqlx-sqlite")]
InnerConnection::Sqlite(c) => {
let query = crate::driver::sqlx_sqlite::sqlx_query(stmt);
crate::metric::metric_ok!(_metric_callback, stmt, {
Box::pin(
c.fetch(query)
let _start = _metric_callback.is_some().then(SystemTime::now);
let stream = c.fetch(query)
.map_ok(Into::into)
.map_err(crate::sqlx_error_to_query_err),
)
as Pin<Box<dyn Stream<Item = Result<QueryResult, DbErr>> + Send>>
})
.map_err(crate::sqlx_error_to_query_err);
let elapsed = _start.map(|s| s.elapsed().unwrap_or_default());
MetricStream::new(_metric_callback, stmt, elapsed, stream)
}
#[cfg(feature = "mock")]
InnerConnection::Mock(c) => c.fetch(stmt),
InnerConnection::Mock(c) => {
let _start = _metric_callback.is_some().then(SystemTime::now);
let stream = c.fetch(stmt);
let elapsed = _start.map(|s| s.elapsed().unwrap_or_default());
MetricStream::new(_metric_callback, stmt, elapsed, stream)
},
#[allow(unreachable_patterns)]
_ => unreachable!(),
},
Expand All @@ -99,6 +100,6 @@ impl<'a> Stream for TransactionStream<'a> {
cx: &mut std::task::Context<'_>,
) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
this.with_stream_mut(|stream| stream.as_mut().poll_next(cx))
this.with_stream_mut(|stream| Pin::new(stream).poll_next(cx))
}
}
23 changes: 3 additions & 20 deletions src/metric.rs
Expand Up @@ -3,7 +3,7 @@ use std::{sync::Arc, time::Duration};
pub(crate) type Callback = Arc<dyn Fn(&Info<'_>) + Send + Sync>;

#[allow(unused_imports)]
pub(crate) use inner::{metric, metric_ok};
pub(crate) use inner::metric;

#[derive(Debug)]
/// Query execution infos
Expand All @@ -20,9 +20,9 @@ mod inner {
#[allow(unused_macros)]
macro_rules! metric {
($metric_callback:expr, $stmt:expr, $code:block) => {{
let _start = std::time::SystemTime::now();
let _start = $metric_callback.is_some().then(std::time::SystemTime::now);
let res = $code;
if let Some(callback) = $metric_callback.as_deref() {
if let (Some(_start), Some(callback)) = (_start, $metric_callback.as_deref()) {
let info = crate::metric::Info {
elapsed: _start.elapsed().unwrap_or_default(),
statement: $stmt,
Expand All @@ -34,21 +34,4 @@ mod inner {
}};
}
pub(crate) use metric;
#[allow(unused_macros)]
macro_rules! metric_ok {
($metric_callback:expr, $stmt:expr, $code:block) => {{
let _start = std::time::SystemTime::now();
let res = $code;
if let Some(callback) = $metric_callback.as_deref() {
let info = crate::metric::Info {
elapsed: _start.elapsed().unwrap_or_default(),
statement: $stmt,
failed: false,
};
callback(&info);
}
res
}};
}
pub(crate) use metric_ok;
}

0 comments on commit d074faf

Please sign in to comment.