From f88ab448596be589d719f1cdbb39b034c524a71a Mon Sep 17 00:00:00 2001 From: Georg Semmler Date: Sat, 27 Jun 2020 01:26:22 +0200 Subject: [PATCH] Rewrite our internal type system (Sorry for the really large diff but this touches nearly all parts of diesel) This commit does in two things: * Refactor our handling of nullable values. Instead of having a marker trait for non nullable sql types we now indicate if a sql type is nullable by using a associated type on a new fundamental trait named SqlType. This allows us to reason if an type is nullable or not in a much precise way without running in conflicting implementation issues. This allows us to address #104 in a much more fundamental way. (The correct way as mentioned there by sgrif). * Refactor our handling of typed and untyped sql fragments. Instead of having two separate code paths for `Queryable` (sql types known) and `QueryableByName` (sql types not known) we now have only one code path and indicate if a query is typed or untyped as part of the expression sql type. This is required to address #2150. As part of this change we unify `Queryable`, `QueryableByName` and `FromSqlRow` into only `FromSqlRow`. Additionally we separate the static size component of `FromSqlRow` to allow dynamically sized rows there (but only at the last position for tuple impls.) I should probably have implement those changes in different commits but as both changes basically requiring touching most of our code base this would have required much more work... Both changes are theoretically big major breaking changes. For application code I expect the actual breakage to be minimal, see the required changes in `diesel_tests` and `diesel_cli` for examples. For highly generic code I would expect quite a few required changes. Additionally there are a few small drive by fixes. Fixes #104 Fixes #2274 Fixes #2161 --- CHANGELOG.md | 43 +++ diesel/Cargo.toml | 1 + diesel/src/associations/belongs_to.rs | 3 + diesel/src/connection/mod.rs | 17 +- diesel/src/deserialize.rs | 346 +++++------------- diesel/src/expression/array_comparison.rs | 8 +- diesel/src/expression/bound.rs | 7 +- diesel/src/expression/coerce.rs | 17 +- diesel/src/expression/count.rs | 4 +- diesel/src/expression/exists.rs | 14 +- .../functions/aggregate_ordering.rs | 18 +- diesel/src/expression/mod.rs | 56 ++- diesel/src/expression/nullable.rs | 6 +- diesel/src/expression/operators.rs | 161 +++++++- diesel/src/expression/ops/mod.rs | 3 +- diesel/src/expression/ops/numeric.rs | 3 +- diesel/src/expression/sql_literal.rs | 46 ++- diesel/src/expression/subselect.rs | 6 +- .../bool_expression_methods.rs | 29 +- .../global_expression_methods.rs | 45 ++- .../text_expression_methods.rs | 24 +- diesel/src/macros/mod.rs | 26 +- diesel/src/macros/ops.rs | 2 +- diesel/src/mysql/backend.rs | 2 +- diesel/src/mysql/connection/bind.rs | 74 ++-- diesel/src/mysql/connection/mod.rs | 34 +- diesel/src/mysql/connection/stmt/iterator.rs | 123 +++---- diesel/src/mysql/connection/stmt/metadata.rs | 42 +-- diesel/src/mysql/connection/stmt/mod.rs | 13 +- diesel/src/mysql/types/date_and_time.rs | 13 +- diesel/src/mysql/types/json.rs | 18 +- diesel/src/mysql/types/mod.rs | 31 +- diesel/src/mysql/types/numeric.rs | 6 +- diesel/src/mysql/types/primitives.rs | 31 +- diesel/src/pg/backend.rs | 12 +- diesel/src/pg/connection/cursor.rs | 68 +--- diesel/src/pg/connection/mod.rs | 33 +- diesel/src/pg/connection/result.rs | 30 +- diesel/src/pg/connection/row.rs | 79 ++-- diesel/src/pg/expression/array_comparison.rs | 6 +- diesel/src/pg/expression/date_and_time.rs | 5 +- .../src/pg/expression/expression_methods.rs | 46 ++- diesel/src/pg/expression/operators.rs | 5 +- diesel/src/pg/query_builder/mod.rs | 4 +- diesel/src/pg/types/array.rs | 7 +- diesel/src/pg/types/date_and_time/chrono.rs | 12 +- .../pg/types/date_and_time/deprecated_time.rs | 2 +- diesel/src/pg/types/date_and_time/mod.rs | 17 +- diesel/src/pg/types/date_and_time/std_time.rs | 2 +- diesel/src/pg/types/floats/mod.rs | 3 +- diesel/src/pg/types/integers.rs | 3 +- diesel/src/pg/types/json.rs | 22 +- diesel/src/pg/types/mac_addr.rs | 6 +- diesel/src/pg/types/money.rs | 2 +- diesel/src/pg/types/network_address.rs | 11 +- diesel/src/pg/types/numeric.rs | 2 +- diesel/src/pg/types/primitives.rs | 12 +- diesel/src/pg/types/ranges.rs | 34 +- diesel/src/pg/types/record.rs | 37 +- diesel/src/pg/types/uuid.rs | 9 +- .../insert_statement/insert_from_select.rs | 8 +- .../src/query_builder/insert_statement/mod.rs | 4 +- .../query_builder/select_statement/boxed.rs | 12 +- .../select_statement/dsl_impls.rs | 8 +- diesel/src/query_builder/sql_query.rs | 38 +- diesel/src/query_builder/where_clause.rs | 29 +- diesel/src/query_dsl/load_dsl.rs | 13 +- diesel/src/query_dsl/single_value_dsl.rs | 12 +- diesel/src/query_source/joins.rs | 5 +- diesel/src/r2d2.rs | 21 +- diesel/src/result.rs | 12 + diesel/src/row.rs | 183 ++++++--- diesel/src/serialize.rs | 7 +- diesel/src/sql_types/fold.rs | 10 +- diesel/src/sql_types/mod.rs | 197 +++++++++- diesel/src/sql_types/ops.rs | 40 +- diesel/src/sql_types/ord.rs | 6 +- diesel/src/sqlite/connection/functions.rs | 87 +++-- diesel/src/sqlite/connection/mod.rs | 28 +- diesel/src/sqlite/connection/raw.rs | 8 +- diesel/src/sqlite/connection/sqlite_value.rs | 123 ++++--- .../sqlite/connection/statement_iterator.rs | 56 +-- diesel/src/sqlite/connection/stmt.rs | 33 +- .../src/sqlite/types/date_and_time/chrono.rs | 6 +- diesel/src/sqlite/types/date_and_time/mod.rs | 6 +- diesel/src/sqlite/types/mod.rs | 30 +- diesel/src/sqlite/types/numeric.rs | 2 +- diesel/src/type_impls/floats.rs | 6 +- diesel/src/type_impls/integers.rs | 9 +- diesel/src/type_impls/option.rs | 79 ++-- diesel/src/type_impls/primitives.rs | 45 ++- diesel/src/type_impls/tuples.rs | 147 +++++--- .../infer_schema_internals/data_structures.rs | 34 +- .../information_schema.rs | 7 +- .../src/infer_schema_internals/sqlite.rs | 2 +- .../src/infer_schema_internals/table_data.rs | 23 +- ...mix_aggregate_and_non_aggregate_selects.rs | 4 +- ...r_requires_bool_nonaggregate_expression.rs | 2 +- ...not_support_returning_methods_on_sqlite.rs | 21 +- ...it_on_requires_valid_boolean_expression.rs | 2 +- ...ght_side_of_left_join_requires_nullable.rs | 5 - ...s_table_name_or_sql_type_annotation.stderr | 12 +- diesel_derives/src/diesel_numeric_ops.rs | 9 + diesel_derives/src/from_sql_row.rs | 25 +- diesel_derives/src/lib.rs | 270 +++++++++++++- diesel_derives/src/queryable.rs | 40 +- diesel_derives/src/queryable_by_name.rs | 68 ++-- diesel_derives/src/sql_function.rs | 22 +- diesel_derives/src/sql_type.rs | 7 +- diesel_tests/tests/custom_types.rs | 4 +- .../tests/expressions/date_and_time.rs | 2 +- diesel_tests/tests/expressions/mod.rs | 11 +- diesel_tests/tests/filter.rs | 4 +- diesel_tests/tests/types.rs | 36 +- diesel_tests/tests/types_roundtrip.rs | 9 +- .../postgres/advanced-blog-cli/src/post.rs | 15 +- examples/postgres/custom_types/src/main.rs | 1 - examples/postgres/custom_types/src/model.rs | 4 +- 118 files changed, 2113 insertions(+), 1537 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 654e33a7cf37..72f58083e9c5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -39,6 +39,9 @@ for Rust libraries in [RFC #1105](https://github.com/rust-lang/rfcs/blob/master/ functionality of `NonAggregate`. See [the upgrade notes](#2-0-0-upgrade-non-aggregate) for details. +* It is now possible to inspect the type of values returned from the database + in such a way to support constructing a dynamic value depending on this type. + ### Removed @@ -47,6 +50,12 @@ for Rust libraries in [RFC #1105](https://github.com/rust-lang/rfcs/blob/master/ * Support for `bigdecimal` < 0.0.13 has been removed. * Support for `pq-sys` < 0.4.0 has been removed. * Support for `mysqlclient-sys` < 0.2.0 has been removed. +* The `Queryable` trait is replaced in favour of `FromSqlRow`. + If you use `#[derive(Queryable)]` a compatible implementation will be generated. + For manual implementations see the relevant section of [the migration guide][2-0-migration] +* The `QueryableByName` is replaced by in favour of `FromSqlRow`. + If you use `#[derive(QueryableByName)]` a compatible implementation will be generated. + For manual implementations see the relevant section of [the migration guide][2-0-migration] ### Changed @@ -93,6 +102,10 @@ for Rust libraries in [RFC #1105](https://github.com/rust-lang/rfcs/blob/master/ `#[non_exhaustive]`. If you matched on one of those variants explicitly you need to introduce a wild card match instead. +* `FromSql::from_sql` is changed to construct value from non nullable database values. + To construct a rust value for nullable values use the new `FromSql::from_nullable_sql` + method instead. + ### Fixed @@ -129,6 +142,10 @@ for Rust libraries in [RFC #1105](https://github.com/rust-lang/rfcs/blob/master/ * We've refactored our type translation layer for Mysql to handle more types now. +* We've refactored our type level representation of nullable values. This allowed us to + fix multiple long standing bugs regarding the correct handling of nullable values in some + corner cases (#104, #2274) + ### Deprecated * `diesel_(prefix|postfix|infix)_operator!` have been deprecated. These macros @@ -171,6 +188,32 @@ Key points: NonAggregate` no longer implies `(OtherType, T): NonAggregate`. - With `feature = "unstable"`, `(T, OtherType): NonAggregate` is still implied. +#### Replacement of `Queryable` and `QueryableByName` with `FromSqlRow` + +FIXME: This should probably be on the website, but I wanted to document it in +the PR adding the changes. + +Key points: + +- Unified deserializing rows into structs between `Queryable`, `QueryableByName` and `FromSqlRow` +- Allows to write code abstracting over `sql_query` and normal query dsl +- Derives continue to work, they will just generate the corresponding `FromSqlRow` impl +- For usage in trait bounds: + + `Queryable` is now `FromSqlRow` + + `QueryableByName` is now `FromSqlRow` +- For manual implementations: + + `Queryable` should implement `FromSqlRow` in such a way + that the the `FromSqlRow` impl for the old `Queryable::Row` type + is called internally. After that ~same construction as before + + `QueryableByName` should implement `FromSqlRow` in such a way + that for each field the following code is called: + + ```rust + let field = row.get("field name").ok_or("Column with name `field name` not found")?; + let value = >::from_nullable_sql(field.value())?; + ``` + + [2-0-migration]: FIXME write a migration guide ## [1.4.5] - 2020-06-09 diff --git a/diesel/Cargo.toml b/diesel/Cargo.toml index 90a48319d4cb..b4f0259ca8bd 100644 --- a/diesel/Cargo.toml +++ b/diesel/Cargo.toml @@ -32,6 +32,7 @@ num-integer = { version = "0.1.39", optional = true } bigdecimal = { version = ">= 0.0.13, < 0.2.0", optional = true } bitflags = { version = "1.2.0", optional = true } r2d2 = { version = ">= 0.8, < 0.9", optional = true } +itoa = "0.4" [dependencies.diesel_derives] version = "~2.0.0" diff --git a/diesel/src/associations/belongs_to.rs b/diesel/src/associations/belongs_to.rs index d6e76e4ff12f..377c01ad0025 100644 --- a/diesel/src/associations/belongs_to.rs +++ b/diesel/src/associations/belongs_to.rs @@ -4,6 +4,7 @@ use crate::expression::array_comparison::AsInExpression; use crate::expression::AsExpression; use crate::prelude::*; use crate::query_dsl::methods::FilterDsl; +use crate::sql_types::SqlType; use std::borrow::Borrow; use std::hash::Hash; @@ -139,6 +140,7 @@ where Id<&'a Parent>: AsExpression<::SqlType>, Child::Table: FilterDsl>>, Child::ForeignKeyColumn: ExpressionMethods, + ::SqlType: SqlType, { type Output = FindBy>; @@ -154,6 +156,7 @@ where Vec>: AsInExpression<::SqlType>, ::Table: FilterDsl>>>, Child::ForeignKeyColumn: ExpressionMethods, + ::SqlType: SqlType, { type Output = Filter>>>; diff --git a/diesel/src/connection/mod.rs b/diesel/src/connection/mod.rs index 328b81328654..4e1a7b310270 100644 --- a/diesel/src/connection/mod.rs +++ b/diesel/src/connection/mod.rs @@ -6,10 +6,10 @@ mod transaction_manager; use std::fmt::Debug; use crate::backend::Backend; -use crate::deserialize::{Queryable, QueryableByName}; +use crate::deserialize::{FromSqlRow, IsCompatibleType}; +use crate::expression::TypedExpressionType; use crate::query_builder::{AsQuery, QueryFragment, QueryId}; use crate::result::*; -use crate::sql_types::HasSqlType; #[doc(hidden)] pub use self::statement_cache::{MaybeCached, StatementCache, StatementCacheKey}; @@ -169,18 +169,13 @@ pub trait Connection: SimpleConnection + Sized + Send { fn execute(&self, query: &str) -> QueryResult; #[doc(hidden)] - fn query_by_index(&self, source: T) -> QueryResult> + fn load(&self, source: T) -> QueryResult> where T: AsQuery, T::Query: QueryFragment + QueryId, - Self::Backend: HasSqlType, - U: Queryable; - - #[doc(hidden)] - fn query_by_name(&self, source: &T) -> QueryResult> - where - T: QueryFragment + QueryId, - U: QueryableByName; + U: FromSqlRow, + T::SqlType: IsCompatibleType, + ST: TypedExpressionType; #[doc(hidden)] fn execute_returning_count(&self, source: &T) -> QueryResult diff --git a/diesel/src/deserialize.rs b/diesel/src/deserialize.rs index b2b07063af9e..930c2c7284bb 100644 --- a/diesel/src/deserialize.rs +++ b/diesel/src/deserialize.rs @@ -4,254 +4,17 @@ use std::error::Error; use std::result; use crate::backend::{self, Backend}; -use crate::row::{NamedRow, Row}; +use crate::expression::TypedExpressionType; +use crate::row::Row; +use crate::sql_types::{HasSqlType, SingleValue, Untyped}; /// A specialized result type representing the result of deserializing /// a value from the database. pub type Result = result::Result>; -/// Trait indicating that a record can be queried from the database. -/// -/// Types which implement `Queryable` represent the result of a SQL query. This -/// does not necessarily mean they represent a single database table. -/// -/// Diesel represents the return type of a query as a tuple. The purpose of this -/// trait is to convert from a tuple of Rust values that have been deserialized -/// into your struct. -/// -/// This trait can be [derived](derive.Queryable.html) -/// -/// # Examples -/// -/// If we just want to map a query to our struct, we can use `derive`. -/// -/// ```rust -/// # include!("doctest_setup.rs"); -/// # -/// #[derive(Queryable, PartialEq, Debug)] -/// struct User { -/// id: i32, -/// name: String, -/// } -/// -/// # fn main() { -/// # run_test(); -/// # } -/// # -/// # fn run_test() -> QueryResult<()> { -/// # use schema::users::dsl::*; -/// # let connection = establish_connection(); -/// let first_user = users.first(&connection)?; -/// let expected = User { id: 1, name: "Sean".into() }; -/// assert_eq!(expected, first_user); -/// # Ok(()) -/// # } -/// ``` -/// -/// If we want to do additional work during deserialization, we can use -/// `deserialize_as` to use a different implementation. -/// -/// ```rust -/// # include!("doctest_setup.rs"); -/// # -/// # use schema::users; -/// # use diesel::backend::{self, Backend}; -/// # use diesel::deserialize::Queryable; -/// # -/// struct LowercaseString(String); -/// -/// impl Into for LowercaseString { -/// fn into(self) -> String { -/// self.0 -/// } -/// } -/// -/// impl Queryable for LowercaseString -/// where -/// DB: Backend, -/// String: Queryable, -/// { -/// type Row = >::Row; -/// -/// fn build(row: Self::Row) -> Self { -/// LowercaseString(String::build(row).to_lowercase()) -/// } -/// } -/// -/// #[derive(Queryable, PartialEq, Debug)] -/// struct User { -/// id: i32, -/// #[diesel(deserialize_as = "LowercaseString")] -/// name: String, -/// } -/// -/// # fn main() { -/// # run_test(); -/// # } -/// # -/// # fn run_test() -> QueryResult<()> { -/// # use schema::users::dsl::*; -/// # let connection = establish_connection(); -/// let first_user = users.first(&connection)?; -/// let expected = User { id: 1, name: "sean".into() }; -/// assert_eq!(expected, first_user); -/// # Ok(()) -/// # } -/// ``` -/// -/// Alternatively, we can implement the trait for our struct manually. -/// -/// ```rust -/// # include!("doctest_setup.rs"); -/// # -/// use schema::users; -/// use diesel::deserialize::Queryable; -/// -/// # /* -/// type DB = diesel::sqlite::Sqlite; -/// # */ -/// -/// #[derive(PartialEq, Debug)] -/// struct User { -/// id: i32, -/// name: String, -/// } -/// -/// impl Queryable for User { -/// type Row = (i32, String); -/// -/// fn build(row: Self::Row) -> Self { -/// User { -/// id: row.0, -/// name: row.1.to_lowercase(), -/// } -/// } -/// } -/// -/// # fn main() { -/// # run_test(); -/// # } -/// # -/// # fn run_test() -> QueryResult<()> { -/// # use schema::users::dsl::*; -/// # let connection = establish_connection(); -/// let first_user = users.first(&connection)?; -/// let expected = User { id: 1, name: "sean".into() }; -/// assert_eq!(expected, first_user); -/// # Ok(()) -/// # } -/// ``` -pub trait Queryable -where - DB: Backend, -{ - /// The Rust type you'd like to map from. - /// - /// This is typically a tuple of all of your struct's fields. - type Row: FromSqlRow; - - /// Construct an instance of this type - fn build(row: Self::Row) -> Self; -} - #[doc(inline)] pub use diesel_derives::Queryable; -/// Deserializes the result of a query constructed with [`sql_query`]. -/// -/// This trait can be [derived](derive.QueryableByName.html) -/// -/// [`sql_query`]: ../fn.sql_query.html -/// -/// # Examples -/// -/// If we just want to map a query to our struct, we can use `derive`. -/// -/// ```rust -/// # include!("doctest_setup.rs"); -/// # use schema::users; -/// # use diesel::sql_query; -/// # -/// #[derive(QueryableByName, PartialEq, Debug)] -/// #[table_name = "users"] -/// struct User { -/// id: i32, -/// name: String, -/// } -/// -/// # fn main() { -/// # run_test(); -/// # } -/// # -/// # fn run_test() -> QueryResult<()> { -/// # let connection = establish_connection(); -/// let first_user = sql_query("SELECT * FROM users ORDER BY id LIMIT 1") -/// .get_result(&connection)?; -/// let expected = User { id: 1, name: "Sean".into() }; -/// assert_eq!(expected, first_user); -/// # Ok(()) -/// # } -/// ``` -/// -/// If we want to do additional work during deserialization, we can use -/// `deserialize_as` to use a different implementation. -/// -/// ```rust -/// # include!("doctest_setup.rs"); -/// # use diesel::sql_query; -/// # use schema::users; -/// # use diesel::backend::{self, Backend}; -/// # use diesel::deserialize::{self, FromSql}; -/// # -/// struct LowercaseString(String); -/// -/// impl Into for LowercaseString { -/// fn into(self) -> String { -/// self.0 -/// } -/// } -/// -/// impl FromSql for LowercaseString -/// where -/// DB: Backend, -/// String: FromSql, -/// { -/// fn from_sql(bytes: Option>) -> deserialize::Result { -/// String::from_sql(bytes) -/// .map(|s| LowercaseString(s.to_lowercase())) -/// } -/// } -/// -/// #[derive(QueryableByName, PartialEq, Debug)] -/// #[table_name = "users"] -/// struct User { -/// id: i32, -/// #[diesel(deserialize_as = "LowercaseString")] -/// name: String, -/// } -/// -/// # fn main() { -/// # run_test(); -/// # } -/// # -/// # fn run_test() -> QueryResult<()> { -/// # let connection = establish_connection(); -/// let first_user = sql_query("SELECT * FROM users ORDER BY id LIMIT 1") -/// .get_result(&connection)?; -/// let expected = User { id: 1, name: "sean".into() }; -/// assert_eq!(expected, first_user); -/// # Ok(()) -/// # } -/// ``` -pub trait QueryableByName -where - Self: Sized, - DB: Backend, -{ - /// Construct an instance of `Self` from the database row - fn build>(row: &R) -> Result; -} - #[doc(inline)] pub use diesel_derives::QueryableByName; @@ -263,7 +26,7 @@ pub use diesel_derives::QueryableByName; /// the database, prefer `i32::from_sql(bytes)` over reading from `bytes` /// directly) /// -/// Types which implement this trait should also have `#[derive(FromSqlRow)]` +/// Types which implement this trait should also have [`#[derive(FromSqlRow)]`] /// /// ### Backend specific details /// @@ -276,6 +39,7 @@ pub use diesel_derives::QueryableByName; /// - For third party backends, consult that backend's documentation. /// /// [`MysqlType`]: ../mysql/enum.MysqlType.html +/// [`#[derive(FromSqlRow)]`]: derive.FromSqlRow.html /// /// ### Examples /// @@ -285,10 +49,10 @@ pub use diesel_derives::QueryableByName; /// ```rust /// # use diesel::backend::{self, Backend}; /// # use diesel::sql_types::*; -/// # use diesel::deserialize::{self, FromSql}; +/// # use diesel::deserialize::{self, FromSql, FromSqlRow}; /// # /// #[repr(i32)] -/// #[derive(Debug, Clone, Copy)] +/// #[derive(Debug, Clone, Copy, FromSqlRow)] /// pub enum MyEnum { /// A = 1, /// B = 2, @@ -299,7 +63,7 @@ pub use diesel_derives::QueryableByName; /// DB: Backend, /// i32: FromSql, /// { -/// fn from_sql(bytes: Option>) -> deserialize::Result { +/// fn from_sql(bytes: backend::RawValue) -> deserialize::Result { /// match i32::from_sql(bytes)? { /// 1 => Ok(MyEnum::A), /// 2 => Ok(MyEnum::B), @@ -310,38 +74,102 @@ pub use diesel_derives::QueryableByName; /// ``` pub trait FromSql: Sized { /// See the trait documentation. - fn from_sql(bytes: Option>) -> Result; + fn from_sql(bytes: backend::RawValue) -> Result; + + /// A specialized variant of `from_sql` for handling null values. + /// + /// The default implementation returns an `UnexpectedNullError` for + /// an encountered null value and calls `Self::from_sql` otherwise + /// + /// If your custom type supports null values you need to provide a + /// custom implementation. + #[inline(always)] + fn from_nullable_sql(bytes: Option>) -> Result { + match bytes { + Some(bytes) => Self::from_sql(bytes), + None => Err(Box::new(crate::result::UnexpectedNullError)), + } + } } -/// Deserialize one or more fields. +/// A trait to check if the sql type of two expression sql types are compatible +/// +/// If you see an error message involving this trait check that you try to load +/// the result of an query into a compatible struct +pub trait IsCompatibleType: TypedExpressionType { + /// A type marked as compatible with `Self` + type Compatible: TypedExpressionType; + + #[doc(hidden)] + #[cfg(feature = "mysql")] + fn mysql_row_metadata(_lookup: &DB::MetadataLookup) -> Option> + where + DB: Backend + crate::sql_types::TypeMetadata, + { + None + } +} + +// Any typed row is compatible with any type that supports loading this typed value +impl IsCompatibleType for ST +where + ST: SingleValue, + DB: Backend + HasSqlType, +{ + type Compatible = ST; + + #[cfg(feature = "mysql")] + fn mysql_row_metadata(lookup: &DB::MetadataLookup) -> Option> + where + DB: Backend + crate::sql_types::TypeMetadata, + { + let mut out = Vec::new(); + >::mysql_row_metadata(&mut out, lookup); + Some(out) + } +} + +// Any untyped row is compatible with any type that supports loading untyped values +impl IsCompatibleType for Untyped +where + DB: Backend, +{ + type Compatible = Untyped; +} + +/// Deserialize a database row into a rust data structure /// /// All types which implement `FromSql` should also implement this trait. This /// trait differs from `FromSql` in that it is also implemented by tuples. /// Implementations of this trait are usually derived. /// -/// In the future, we hope to be able to provide a blanket impl of this trait -/// for all types which implement `FromSql`. However, as of Diesel 1.0, such an -/// impl would conflict with our impl for tuples. +/// For types representing the result of an complete query use +/// [`#[derive(Queryable)]`] or [`#[derive(QueryableByName)]`] to implement this trait. +// +/// For types implementing `FromSql` that represent a single database value +/// use [`#[derive(FromSqlRow)]`]. /// -/// This trait can be [derived](derive.FromSqlRow.html) -pub trait FromSqlRow: Sized { +/// [`#[derive(Queryable)]`]: derive.Queryable.html +/// [`#[derive(QueryableByName)]`]: derive.QueryableByName.html +/// [`#[derive(FromSqlRow)]`]: derive.FromSqlRow.html +pub trait FromSqlRow: Sized { /// See the trait documentation. - fn build_from_row>(row: &mut T) -> Result; + fn build_from_row<'a>(row: &impl Row<'a, DB>) -> Result; } #[doc(inline)] pub use diesel_derives::FromSqlRow; -/// A marker trait indicating that the corresponding type is a statically sized row +/// A marker trait indicating that the corresponding type consumes a static at +/// compile time known number of field /// /// This trait is implemented for all types provided by diesel, that -/// implement `FromSqlRow`. -/// -/// For dynamically sized types, like `diesel_dynamic_schema::DynamicRow` -/// this traits should not be implemented. +/// implement `FromSqlRow where ST: SqlType`. It is not implemented for +/// types implementing `FromSqlRow`. /// -/// This trait can be [derived](derive.FromSqlRow.html) -pub trait StaticallySizedRow: FromSqlRow { +/// This trait can be derived via [`#[derive(Queryable)]`](derive.Queryable.html) or +/// [`#[derive(FromSqlRow)]`](derive.FromSqlRow.html) +pub trait StaticallySizedRow: FromSqlRow { /// The number of fields that this type will consume. Must be equal to /// the number of times you would call `row.take()` in `build_from_row` const FIELD_COUNT: usize = 1; diff --git a/diesel/src/expression/array_comparison.rs b/diesel/src/expression/array_comparison.rs index 1eb5d1db0306..ca90e3950d30 100644 --- a/diesel/src/expression/array_comparison.rs +++ b/diesel/src/expression/array_comparison.rs @@ -2,6 +2,7 @@ use crate::backend::Backend; use crate::expression::subselect::Subselect; use crate::expression::*; use crate::query_builder::*; +use crate::query_builder::{BoxedSelectStatement, SelectStatement}; use crate::result::QueryResult; use crate::sql_types::Bool; @@ -92,9 +93,7 @@ where impl_selectable_expression!(In); impl_selectable_expression!(NotIn); -use crate::query_builder::{BoxedSelectStatement, SelectStatement}; - -pub trait AsInExpression { +pub trait AsInExpression { type InExpression: MaybeEmpty + Expression; fn as_in_expression(self) -> Self::InExpression; @@ -104,6 +103,7 @@ impl AsInExpression for I where I: IntoIterator, T: AsExpression, + ST: SqlType + TypedExpressionType, { type InExpression = Many; @@ -119,6 +119,7 @@ pub trait MaybeEmpty { impl AsInExpression for SelectStatement where + ST: SqlType + TypedExpressionType, Subselect: Expression, Self: SelectQuery, { @@ -131,6 +132,7 @@ where impl<'a, ST, QS, DB> AsInExpression for BoxedSelectStatement<'a, ST, QS, DB> where + ST: SqlType + TypedExpressionType, Subselect, ST>: Expression, { type InExpression = Subselect; diff --git a/diesel/src/expression/bound.rs b/diesel/src/expression/bound.rs index 4dbb605a2450..6c3d33d976e9 100644 --- a/diesel/src/expression/bound.rs +++ b/diesel/src/expression/bound.rs @@ -5,7 +5,7 @@ use crate::backend::Backend; use crate::query_builder::*; use crate::result::QueryResult; use crate::serialize::ToSql; -use crate::sql_types::{DieselNumericOps, HasSqlType}; +use crate::sql_types::{DieselNumericOps, HasSqlType, SqlType}; #[derive(Debug, Clone, Copy, DieselNumericOps)] pub struct Bound { @@ -22,7 +22,10 @@ impl Bound { } } -impl Expression for Bound { +impl Expression for Bound +where + T: SqlType + TypedExpressionType, +{ type SqlType = T; } diff --git a/diesel/src/expression/coerce.rs b/diesel/src/expression/coerce.rs index 3328a3f74f84..ecf3eacc5084 100644 --- a/diesel/src/expression/coerce.rs +++ b/diesel/src/expression/coerce.rs @@ -4,7 +4,7 @@ use crate::backend::Backend; use crate::expression::*; use crate::query_builder::*; use crate::result::QueryResult; -use crate::sql_types::DieselNumericOps; +use crate::sql_types::{DieselNumericOps, SqlType}; #[derive(Debug, Copy, Clone, QueryId, DieselNumericOps)] #[doc(hidden)] @@ -36,13 +36,24 @@ impl Coerce { impl Expression for Coerce where T: Expression, + ST: SqlType + TypedExpressionType, { type SqlType = ST; } -impl SelectableExpression for Coerce where T: SelectableExpression {} +impl SelectableExpression for Coerce +where + T: SelectableExpression, + Self: Expression, +{ +} -impl AppearsOnTable for Coerce where T: AppearsOnTable {} +impl AppearsOnTable for Coerce +where + T: AppearsOnTable, + Self: Expression, +{ +} impl QueryFragment for Coerce where diff --git a/diesel/src/expression/count.rs b/diesel/src/expression/count.rs index c75f2cbde64d..6502340fc677 100644 --- a/diesel/src/expression/count.rs +++ b/diesel/src/expression/count.rs @@ -3,7 +3,7 @@ use super::{Expression, ValidGrouping}; use crate::backend::Backend; use crate::query_builder::*; use crate::result::QueryResult; -use crate::sql_types::{BigInt, DieselNumericOps}; +use crate::sql_types::{BigInt, DieselNumericOps, SingleValue, SqlType}; sql_function! { /// Creates a SQL `COUNT` expression @@ -25,7 +25,7 @@ sql_function! { /// # } /// ``` #[aggregate] - fn count(expr: T) -> BigInt; + fn count(expr: T) -> BigInt; } /// Creates a SQL `COUNT(*)` expression diff --git a/diesel/src/expression/exists.rs b/diesel/src/expression/exists.rs index 8b4d83ed54fe..385065f6bf94 100644 --- a/diesel/src/expression/exists.rs +++ b/diesel/src/expression/exists.rs @@ -32,21 +32,21 @@ pub fn exists(query: T) -> Exists { Exists(Subselect::new(query)) } -#[derive(Debug, Clone, Copy, QueryId)] -pub struct Exists(pub Subselect); +#[derive(Clone, Copy, QueryId, Debug)] +pub struct Exists(pub Subselect); impl Expression for Exists where - Subselect: Expression, + Subselect: Expression, { type SqlType = Bool; } impl ValidGrouping for Exists where - Subselect: ValidGrouping, + Subselect: ValidGrouping, { - type IsAggregate = as ValidGrouping>::IsAggregate; + type IsAggregate = as ValidGrouping>::IsAggregate; } #[cfg(not(feature = "unstable"))] @@ -80,13 +80,13 @@ where impl SelectableExpression for Exists where Self: AppearsOnTable, - Subselect: SelectableExpression, + Subselect: SelectableExpression, { } impl AppearsOnTable for Exists where Self: Expression, - Subselect: AppearsOnTable, + Subselect: AppearsOnTable, { } diff --git a/diesel/src/expression/functions/aggregate_ordering.rs b/diesel/src/expression/functions/aggregate_ordering.rs index 6048d11f0ac1..bb2889590a94 100644 --- a/diesel/src/expression/functions/aggregate_ordering.rs +++ b/diesel/src/expression/functions/aggregate_ordering.rs @@ -1,5 +1,17 @@ use crate::expression::functions::sql_function; -use crate::sql_types::{IntoNullable, SqlOrd}; +use crate::sql_types::{IntoNullable, SingleValue, SqlOrd, SqlType}; + +pub trait SqlOrdAggregate: SingleValue { + type Ret: SqlType + SingleValue; +} + +impl SqlOrdAggregate for T +where + T: SqlOrd + IntoNullable + SingleValue, + T::Nullable: SqlType + SingleValue, +{ + type Ret = T::Nullable; +} sql_function! { /// Represents a SQL `MAX` function. This function can only take types which are @@ -17,7 +29,7 @@ sql_function! { /// assert_eq!(Ok(Some(8)), animals.select(max(legs)).first(&connection)); /// # } #[aggregate] - fn max(expr: ST) -> ST::Nullable; + fn max(expr: ST) -> ST::Ret; } sql_function! { @@ -36,5 +48,5 @@ sql_function! { /// assert_eq!(Ok(Some(4)), animals.select(min(legs)).first(&connection)); /// # } #[aggregate] - fn min(expr: ST) -> ST::Nullable; + fn min(expr: ST) -> ST::Ret; } diff --git a/diesel/src/expression/mod.rs b/diesel/src/expression/mod.rs index fab5ac0e17a5..9dacb410e939 100644 --- a/diesel/src/expression/mod.rs +++ b/diesel/src/expression/mod.rs @@ -82,6 +82,7 @@ pub use self::sql_literal::{SqlLiteral, UncheckedBind}; use crate::backend::Backend; use crate::dsl::AsExprOf; +use crate::sql_types::SqlType; /// Represents a typed fragment of SQL. /// @@ -92,7 +93,45 @@ use crate::dsl::AsExprOf; /// implementing this directly. pub trait Expression { /// The type that this expression represents in SQL - type SqlType; + type SqlType: TypedExpressionType; +} + +/// Marker trait for possible types of [`Expression::SqlType`] +/// +/// [`Expression::SqlType`]: trait.Expression.html#associatedtype.SqlType +pub trait TypedExpressionType {} + +/// Possible types for []`Expression::SqlType`] +/// +/// [`Expression::SqlType`]: trait.Expression.html#associatedtype.SqlType +pub mod expression_types { + use super::TypedExpressionType; + use crate::sql_types::SingleValue; + + /// Query nodes with this expression type do not have a statically at compile + /// time known expression type. + /// + /// An example for such a query node in diesel itself, is `sql_query` as + /// we do not know which fields are returned from such a query at compile time. + /// + /// For loading values from queries returning a type of this expression, consider + /// using [`#[derive(QueryableByName)]`] on the corresponding result type. + /// + /// [`#[derive(QueryableByName)]`]: ../deserialize/derive.QueryableByName.html + #[derive(Clone, Copy, Debug)] + pub struct Untyped; + + /// Query nodes witch cannot be part of a select clause. + /// + /// If you see an error message containing `FromSqlRow` and this type + /// recheck that you have written a valid select clause + #[derive(Debug, Clone, Copy)] + pub struct NotSelectable; + + impl TypedExpressionType for Untyped {} + impl TypedExpressionType for NotSelectable {} + + impl TypedExpressionType for ST where ST: SingleValue {} } impl Expression for Box { @@ -124,7 +163,10 @@ impl<'a, T: Expression + ?Sized> Expression for &'a T { /// /// This trait could be [derived](derive.AsExpression.html) -pub trait AsExpression { +pub trait AsExpression +where + T: SqlType + TypedExpressionType, +{ /// The expression being returned type Expression: Expression; @@ -135,7 +177,11 @@ pub trait AsExpression { #[doc(inline)] pub use diesel_derives::AsExpression; -impl AsExpression for T { +impl AsExpression for T +where + T: Expression, + ST: SqlType + TypedExpressionType, +{ type Expression = Self; fn as_expression(self) -> Self { @@ -177,6 +223,7 @@ pub trait IntoSql { fn into_sql(self) -> AsExprOf where Self: AsExpression + Sized, + T: SqlType + TypedExpressionType, { self.as_expression() } @@ -188,6 +235,7 @@ pub trait IntoSql { fn as_sql<'a, T>(&'a self) -> AsExprOf<&'a Self, T> where &'a Self: AsExpression, + T: SqlType + TypedExpressionType, { self.as_expression() } @@ -432,7 +480,7 @@ use crate::query_builder::{QueryFragment, QueryId}; /// type DB = diesel::sqlite::Sqlite; /// # */ /// -/// fn find_user(search: Search) -> Box> { +/// fn find_user(search: Search) -> Box> { /// match search { /// Search::Id(id) => Box::new(users::id.eq(id)), /// Search::Name(name) => Box::new(users::name.eq(name)), diff --git a/diesel/src/expression/nullable.rs b/diesel/src/expression/nullable.rs index 70c400cea638..13a5bc268882 100644 --- a/diesel/src/expression/nullable.rs +++ b/diesel/src/expression/nullable.rs @@ -1,4 +1,5 @@ use crate::backend::Backend; +use crate::expression::TypedExpressionType; use crate::expression::*; use crate::query_builder::*; use crate::query_source::joins::ToInnerJoin; @@ -17,9 +18,10 @@ impl Nullable { impl Expression for Nullable where T: Expression, - ::SqlType: IntoNullable, + T::SqlType: IntoNullable, + ::Nullable: TypedExpressionType, { - type SqlType = <::SqlType as IntoNullable>::Nullable; + type SqlType = ::Nullable; } impl QueryFragment for Nullable diff --git a/diesel/src/expression/operators.rs b/diesel/src/expression/operators.rs index b5d9116f21e8..02e9d9481df2 100644 --- a/diesel/src/expression/operators.rs +++ b/diesel/src/expression/operators.rs @@ -5,7 +5,7 @@ macro_rules! __diesel_operator_body { notation = $notation:ident, struct_name = $name:ident, operator = $operator:expr, - return_ty = ReturnBasedOnArgs, + return_ty = (ReturnBasedOnArgs), ty_params = ($($ty_param:ident,)+), field_names = $field_names:tt, backend_ty_params = $backend_ty_params:tt, @@ -15,7 +15,7 @@ macro_rules! __diesel_operator_body { notation = $notation, struct_name = $name, operator = $operator, - return_ty = ST, + return_ty = (ST), ty_params = ($($ty_param,)+), field_names = $field_names, backend_ty_params = $backend_ty_params, @@ -29,7 +29,7 @@ macro_rules! __diesel_operator_body { notation = $notation:ident, struct_name = $name:ident, operator = $operator:expr, - return_ty = $return_ty:ty, + return_ty = ($($return_ty:tt)+), ty_params = ($($ty_param:ident,)+), field_names = $field_names:tt, backend_ty_params = $backend_ty_params:tt, @@ -39,7 +39,7 @@ macro_rules! __diesel_operator_body { notation = $notation, struct_name = $name, operator = $operator, - return_ty = $return_ty, + return_ty = ($($return_ty)*), ty_params = ($($ty_param,)+), field_names = $field_names, backend_ty_params = $backend_ty_params, @@ -53,7 +53,7 @@ macro_rules! __diesel_operator_body { notation = $notation:ident, struct_name = $name:ident, operator = $operator:expr, - return_ty = $return_ty:ty, + return_ty = ($($return_ty:tt)+), ty_params = ($($ty_param:ident,)+), field_names = ($($field_name:ident,)+), backend_ty_params = ($($backend_ty_param:ident,)*), @@ -85,7 +85,7 @@ macro_rules! __diesel_operator_body { impl<$($ty_param,)+ $($expression_ty_params,)*> $crate::expression::Expression for $name<$($ty_param,)+> where $($expression_bounds)* { - type SqlType = $return_ty; + type SqlType = $($return_ty)*; } impl<$($ty_param,)+ $($backend_ty_param,)*> $crate::query_builder::QueryFragment<$backend_ty> @@ -187,6 +187,8 @@ macro_rules! __diesel_operator_to_sql { /// /// ```rust /// # include!("../doctest_setup.rs"); +/// # use diesel::sql_types::SqlType; +/// # use diesel::expression::TypedExpressionType; /// # /// # fn main() { /// # use schema::users::dsl::*; @@ -196,9 +198,10 @@ macro_rules! __diesel_operator_to_sql { /// use diesel::expression::AsExpression; /// /// // Normally you would put this on a trait instead -/// fn my_eq(left: T, right: U) -> MyEq where -/// T: Expression, -/// U: AsExpression, +/// fn my_eq(left: T, right: U) -> MyEq where +/// T: Expression, +/// U: AsExpression, +/// ST: SqlType + TypedExpressionType, /// { /// MyEq::new(left, right.as_expression()) /// } @@ -223,11 +226,37 @@ macro_rules! infix_operator { notation = infix, struct_name = $name, operator = $operator, - return_ty = $($return_ty)::*, + return_ty = ( + $crate::sql_types::is_nullable::MaybeNullable< + $crate::sql_types::is_nullable::IsOneNullable< + ::SqlType, + ::SqlType + >, + $($return_ty)::* + > + ), ty_params = (T, U,), field_names = (left, right,), backend_ty_params = (DB,), backend_ty = DB, + expression_ty_params = (), + expression_bounds = ( + T: $crate::expression::Expression, + U: $crate::expression::Expression, + ::SqlType: $crate::sql_types::SqlType, + ::SqlType: $crate::sql_types::SqlType, + $crate::sql_types::is_nullable::IsSqlTypeNullable< + ::SqlType + >: $crate::sql_types::OneIsNullable< + $crate::sql_types::is_nullable::IsSqlTypeNullable< + ::SqlType + > + >, + $crate::sql_types::is_nullable::IsOneNullable< + ::SqlType, + ::SqlType + >: $crate::sql_types::MaybeNullableType<$($return_ty)::*>, + ), ); }; @@ -236,11 +265,37 @@ macro_rules! infix_operator { notation = infix, struct_name = $name, operator = $operator, - return_ty = $return_ty, + return_ty = ( + $crate::sql_types::is_nullable::MaybeNullable< + $crate::sql_types::is_nullable::IsOneNullable< + ::SqlType, + ::SqlType + >, + $return_ty, + > + ), ty_params = (T, U,), field_names = (left, right,), backend_ty_params = (), backend_ty = $backend, + expression_ty_params = (), + expression_bounds = ( + T: $crate::expression::Expression, + U: $crate::expression::Expression, + ::SqlType: $crate::sql_types::SqlType, + ::SqlType: $crate::sql_types::SqlType, + $crate::sql_types::is_nullable::IsSqlTypeNullable< + ::SqlType + >: $crate::sql_types::OneIsNullable< + $crate::sql_types::is_nullable::IsSqlTypeNullable< + ::SqlType + > + >, + $crate::sql_types::is_nullable::IsOneNullable< + ::SqlType, + ::SqlType + >: $crate::sql_types::MaybeNullableType<$return_ty>, + ), ); }; } @@ -278,7 +333,7 @@ macro_rules! postfix_operator { notation = postfix, struct_name = $name, operator = $operator, - return_ty = $return_ty, + return_ty = ($return_ty), ty_params = (Expr,), field_names = (expr,), backend_ty_params = (DB,), @@ -291,7 +346,7 @@ macro_rules! postfix_operator { notation = postfix, struct_name = $name, operator = $operator, - return_ty = $return_ty, + return_ty = ($return_ty), ty_params = (Expr,), field_names = (expr,), backend_ty_params = (), @@ -333,7 +388,7 @@ macro_rules! prefix_operator { notation = prefix, struct_name = $name, operator = $operator, - return_ty = $return_ty, + return_ty = ($return_ty), ty_params = (Expr,), field_names = (expr,), backend_ty_params = (DB,), @@ -346,7 +401,7 @@ macro_rules! prefix_operator { notation = prefix, struct_name = $name, operator = $operator, - return_ty = $return_ty, + return_ty = ($return_ty), ty_params = (Expr,), field_names = (expr,), backend_ty_params = (), @@ -377,20 +432,30 @@ infix_operator!(LtEq, " <= "); infix_operator!(NotBetween, " NOT BETWEEN "); infix_operator!(NotEq, " != "); infix_operator!(NotLike, " NOT LIKE "); -infix_operator!(Or, " OR "); postfix_operator!(IsNull, " IS NULL"); postfix_operator!(IsNotNull, " IS NOT NULL"); -postfix_operator!(Asc, " ASC", ()); -postfix_operator!(Desc, " DESC", ()); +postfix_operator!( + Asc, + " ASC ", + crate::expression::expression_types::NotSelectable +); +postfix_operator!( + Desc, + " DESC ", + crate::expression::expression_types::NotSelectable +); prefix_operator!(Not, "NOT "); -use crate::expression::ValidGrouping; +use crate::expression::{TypedExpressionType, ValidGrouping}; use crate::insertable::{ColumnInsertValue, Insertable}; use crate::query_builder::{QueryId, ValuesClause}; use crate::query_source::Column; -use crate::sql_types::DieselNumericOps; +use crate::sql_types::{ + is_nullable, AllAreNullable, Bool, DieselNumericOps, MaybeNullableType, SqlType, +}; +use crate::Expression; impl Insertable for Eq where @@ -432,6 +497,7 @@ impl crate::expression::Expression for Concat where L: crate::expression::Expression, R: crate::expression::Expression, + ST: SqlType + TypedExpressionType, { type SqlType = ST; } @@ -458,3 +524,58 @@ where Ok(()) } } + +// or is different +// it only evaluates to null if both sides are null +#[derive( + Debug, + Clone, + Copy, + crate::query_builder::QueryId, + crate::sql_types::DieselNumericOps, + crate::expression::ValidGrouping, +)] +#[doc(hidden)] +pub struct Or { + pub(crate) left: T, + pub(crate) right: U, +} + +impl Or { + pub fn new(left: T, right: U) -> Self { + Or { left, right } + } +} + +impl_selectable_expression!(Or); + +impl Expression for Or +where + T: Expression, + U: Expression, + T::SqlType: SqlType, + U::SqlType: SqlType, + is_nullable::IsSqlTypeNullable: + AllAreNullable>, + is_nullable::AreAllNullable: MaybeNullableType, +{ + type SqlType = + is_nullable::MaybeNullable, Bool>; +} + +impl crate::query_builder::QueryFragment for Or +where + DB: crate::backend::Backend, + T: crate::query_builder::QueryFragment, + U: crate::query_builder::QueryFragment, +{ + fn walk_ast( + &self, + mut out: crate::query_builder::AstPass, + ) -> crate::result::QueryResult<()> { + self.left.walk_ast(out.reborrow())?; + out.push_sql(" OR "); + self.right.walk_ast(out.reborrow())?; + Ok(()) + } +} diff --git a/diesel/src/expression/ops/mod.rs b/diesel/src/expression/ops/mod.rs index eea4aa4ac9c5..580a2932da0a 100644 --- a/diesel/src/expression/ops/mod.rs +++ b/diesel/src/expression/ops/mod.rs @@ -2,7 +2,8 @@ macro_rules! generic_numeric_expr_inner { ($tpe: ident, ($($param: ident),*), $op: ident, $fn_name: ident) => { impl ::std::ops::$op for $tpe<$($param),*> where $tpe<$($param),*>: $crate::expression::Expression, - <$tpe<$($param),*> as $crate::Expression>::SqlType: $crate::sql_types::ops::$op, + <$tpe<$($param),*> as $crate::Expression>::SqlType: $crate::sql_types::SqlType + $crate::sql_types::ops::$op, + <<$tpe<$($param),*> as $crate::Expression>::SqlType as $crate::sql_types::ops::$op>::Rhs: $crate::expression::TypedExpressionType, Rhs: $crate::expression::AsExpression< <<$tpe<$($param),*> as $crate::Expression>::SqlType as $crate::sql_types::ops::$op>::Rhs, >, diff --git a/diesel/src/expression/ops/numeric.rs b/diesel/src/expression/ops/numeric.rs index 8a7b79793197..1d5a5d4153cd 100644 --- a/diesel/src/expression/ops/numeric.rs +++ b/diesel/src/expression/ops/numeric.rs @@ -1,5 +1,5 @@ use crate::backend::Backend; -use crate::expression::{Expression, ValidGrouping}; +use crate::expression::{Expression, TypedExpressionType, ValidGrouping}; use crate::query_builder::*; use crate::result::QueryResult; use crate::sql_types; @@ -26,6 +26,7 @@ macro_rules! numeric_operation { Lhs: Expression, Lhs::SqlType: sql_types::ops::$name, Rhs: Expression, + ::Output: TypedExpressionType, { type SqlType = ::Output; } diff --git a/diesel/src/expression/sql_literal.rs b/diesel/src/expression/sql_literal.rs index bbc3dfee3792..68c0e0abe9f5 100644 --- a/diesel/src/expression/sql_literal.rs +++ b/diesel/src/expression/sql_literal.rs @@ -5,7 +5,7 @@ use crate::expression::*; use crate::query_builder::*; use crate::query_dsl::RunQueryDsl; use crate::result::QueryResult; -use crate::sql_types::DieselNumericOps; +use crate::sql_types::{DieselNumericOps, SqlType}; #[derive(Debug, Clone, DieselNumericOps)] #[must_use = "Queries are only executed when calling `load`, `get_result`, or similar."] @@ -18,7 +18,10 @@ pub struct SqlLiteral { _marker: PhantomData, } -impl SqlLiteral { +impl SqlLiteral +where + ST: TypedExpressionType, +{ #[doc(hidden)] pub fn new(sql: String, inner: T) -> Self { SqlLiteral { @@ -51,11 +54,11 @@ impl SqlLiteral { /// # fn main() { /// # use self::users::dsl::*; /// # use diesel::dsl::sql; - /// # use diesel::sql_types::{Integer, Text}; + /// # use diesel::sql_types::{Integer, Text, Bool}; /// # let connection = establish_connection(); /// let seans_id = users /// .select(id) - /// .filter(sql("name = ").bind::("Sean")) + /// .filter(sql::("name = ").bind::("Sean")) /// .get_result(&connection); /// assert_eq!(Ok(1), seans_id); /// @@ -81,14 +84,14 @@ impl SqlLiteral { /// # fn main() { /// # use self::users::dsl::*; /// # use diesel::dsl::sql; - /// # use diesel::sql_types::{Integer, Text}; + /// # use diesel::sql_types::{Integer, Text, Bool}; /// # let connection = establish_connection(); /// # diesel::insert_into(users).values(name.eq("Ryan")) /// # .execute(&connection).unwrap(); /// let query = users /// .select(name) /// .filter( - /// sql("id > ") + /// sql::("id > ") /// .bind::(1) /// .sql(" AND name <> ") /// .bind::("Ryan") @@ -100,6 +103,7 @@ impl SqlLiteral { /// ``` pub fn bind(self, bind_value: U) -> UncheckedBind where + BindST: SqlType + TypedExpressionType, U: AsExpression, { UncheckedBind::new(self, bind_value.as_expression()) @@ -132,14 +136,14 @@ impl SqlLiteral { /// # fn main() { /// # use self::users::dsl::*; /// # use diesel::dsl::sql; - /// # use diesel::sql_types::{Integer, Text}; + /// # use diesel::sql_types::Bool; /// # let connection = establish_connection(); /// # diesel::insert_into(users).values(name.eq("Ryan")) /// # .execute(&connection).unwrap(); /// let query = users /// .select(name) /// .filter( - /// sql("id > 1") + /// sql::("id > 1") /// .sql(" AND name <> 'Ryan'") /// ) /// .get_results(&connection); @@ -152,7 +156,10 @@ impl SqlLiteral { } } -impl Expression for SqlLiteral { +impl Expression for SqlLiteral +where + ST: TypedExpressionType, +{ type SqlType = ST; } @@ -175,15 +182,18 @@ impl QueryId for SqlLiteral { const HAS_STATIC_QUERY_ID: bool = false; } -impl Query for SqlLiteral { +impl Query for SqlLiteral +where + Self: Expression, +{ type SqlType = ST; } impl RunQueryDsl for SqlLiteral {} -impl SelectableExpression for SqlLiteral {} +impl SelectableExpression for SqlLiteral where Self: Expression {} -impl AppearsOnTable for SqlLiteral {} +impl AppearsOnTable for SqlLiteral where Self: Expression {} impl ValidGrouping for SqlLiteral { type IsAggregate = is_aggregate::Never; @@ -215,15 +225,19 @@ impl ValidGrouping for SqlLiteral { /// # /// # fn run_test() -> QueryResult<()> { /// # use schema::users::dsl::*; +/// # use diesel::sql_types::Bool; /// use diesel::dsl::sql; /// # let connection = establish_connection(); -/// let user = users.filter(sql("name = 'Sean'")).first(&connection)?; +/// let user = users.filter(sql::("name = 'Sean'")).first(&connection)?; /// let expected = (1, String::from("Sean")); /// assert_eq!(expected, user); /// # Ok(()) /// # } /// ``` -pub fn sql(sql: &str) -> SqlLiteral { +pub fn sql(sql: &str) -> SqlLiteral +where + ST: TypedExpressionType, +{ SqlLiteral::new(sql.into(), ()) } @@ -272,14 +286,14 @@ where /// # fn main() { /// # use self::users::dsl::*; /// # use diesel::dsl::sql; - /// # use diesel::sql_types::{Integer, Text}; + /// # use diesel::sql_types::{Integer, Bool}; /// # let connection = establish_connection(); /// # diesel::insert_into(users).values(name.eq("Ryan")) /// # .execute(&connection).unwrap(); /// let query = users /// .select(name) /// .filter( - /// sql("id > ") + /// sql::("id > ") /// .bind::(1) /// .sql(" AND name <> 'Ryan'") /// ) diff --git a/diesel/src/expression/subselect.rs b/diesel/src/expression/subselect.rs index 754d80df9f1e..a0bab4e8b240 100644 --- a/diesel/src/expression/subselect.rs +++ b/diesel/src/expression/subselect.rs @@ -4,6 +4,7 @@ use crate::expression::array_comparison::MaybeEmpty; use crate::expression::*; use crate::query_builder::*; use crate::result::QueryResult; +use crate::sql_types::SqlType; #[derive(Debug, Copy, Clone, QueryId)] pub struct Subselect { @@ -20,7 +21,10 @@ impl Subselect { } } -impl Expression for Subselect { +impl Expression for Subselect +where + ST: SqlType + TypedExpressionType, +{ type SqlType = ST; } diff --git a/diesel/src/expression_methods/bool_expression_methods.rs b/diesel/src/expression_methods/bool_expression_methods.rs index 663da37b2ddd..3125e321bc2d 100644 --- a/diesel/src/expression_methods/bool_expression_methods.rs +++ b/diesel/src/expression_methods/bool_expression_methods.rs @@ -1,7 +1,7 @@ use crate::expression::grouped::Grouped; use crate::expression::operators::{And, Or}; -use crate::expression::{AsExpression, Expression}; -use crate::sql_types::{Bool, Nullable}; +use crate::expression::{AsExpression, Expression, TypedExpressionType}; +use crate::sql_types::{BoolOrNullableBool, SqlType}; /// Methods present on boolean expressions pub trait BoolExpressionMethods: Expression + Sized { @@ -36,7 +36,13 @@ pub trait BoolExpressionMethods: Expression + Sized { /// assert_eq!(expected, data); /// # Ok(()) /// # } - fn and>(self, other: T) -> And { + fn and(self, other: T) -> And + where + Self::SqlType: SqlType, + ST: SqlType + TypedExpressionType, + T: AsExpression, + And: Expression, + { And::new(self, other.as_expression()) } @@ -77,7 +83,13 @@ pub trait BoolExpressionMethods: Expression + Sized { /// assert_eq!(expected, data); /// # Ok(()) /// # } - fn or>(self, other: T) -> Grouped> { + fn or(self, other: T) -> Grouped> + where + Self::SqlType: SqlType, + ST: SqlType + TypedExpressionType, + T: AsExpression, + Or: Expression, + { Grouped(Or::new(self, other.as_expression())) } } @@ -88,12 +100,3 @@ where T::SqlType: BoolOrNullableBool, { } - -#[doc(hidden)] -/// Marker trait used to implement `BoolExpressionMethods` on the appropriate -/// types. Once coherence takes associated types into account, we can remove -/// this trait. -pub trait BoolOrNullableBool {} - -impl BoolOrNullableBool for Bool {} -impl BoolOrNullableBool for Nullable {} diff --git a/diesel/src/expression_methods/global_expression_methods.rs b/diesel/src/expression_methods/global_expression_methods.rs index 88f30701b514..c0d4f7c798e7 100644 --- a/diesel/src/expression_methods/global_expression_methods.rs +++ b/diesel/src/expression_methods/global_expression_methods.rs @@ -1,7 +1,7 @@ use crate::expression::array_comparison::{AsInExpression, In, NotIn}; use crate::expression::operators::*; use crate::expression::{nullable, AsExpression, Expression}; -use crate::sql_types::SingleValue; +use crate::sql_types::{SingleValue, SqlType}; /// Methods present on all expressions, except tuples pub trait ExpressionMethods: Expression + Sized { @@ -19,7 +19,11 @@ pub trait ExpressionMethods: Expression + Sized { /// assert_eq!(Ok(1), data.first(&connection)); /// # } /// ``` - fn eq>(self, other: T) -> Eq { + fn eq(self, other: T) -> Eq + where + Self::SqlType: SqlType, + T: AsExpression, + { Eq::new(self, other.as_expression()) } @@ -37,7 +41,11 @@ pub trait ExpressionMethods: Expression + Sized { /// assert_eq!(Ok(2), data.first(&connection)); /// # } /// ``` - fn ne>(self, other: T) -> NotEq { + fn ne(self, other: T) -> NotEq + where + Self::SqlType: SqlType, + T: AsExpression, + { NotEq::new(self, other.as_expression()) } @@ -68,6 +76,7 @@ pub trait ExpressionMethods: Expression + Sized { /// ``` fn eq_any(self, values: T) -> In where + Self::SqlType: SqlType, T: AsInExpression, { In::new(self, values.as_in_expression()) @@ -103,6 +112,7 @@ pub trait ExpressionMethods: Expression + Sized { /// ``` fn ne_all(self, values: T) -> NotIn where + Self::SqlType: SqlType, T: AsInExpression, { NotIn::new(self, values.as_in_expression()) @@ -182,7 +192,11 @@ pub trait ExpressionMethods: Expression + Sized { /// # Ok(()) /// # } /// ``` - fn gt>(self, other: T) -> Gt { + fn gt(self, other: T) -> Gt + where + Self::SqlType: SqlType, + T: AsExpression, + { Gt::new(self, other.as_expression()) } @@ -208,7 +222,11 @@ pub trait ExpressionMethods: Expression + Sized { /// # Ok(()) /// # } /// ``` - fn ge>(self, other: T) -> GtEq { + fn ge(self, other: T) -> GtEq + where + Self::SqlType: SqlType, + T: AsExpression, + { GtEq::new(self, other.as_expression()) } @@ -234,7 +252,11 @@ pub trait ExpressionMethods: Expression + Sized { /// # Ok(()) /// # } /// ``` - fn lt>(self, other: T) -> Lt { + fn lt(self, other: T) -> Lt + where + Self::SqlType: SqlType, + T: AsExpression, + { Lt::new(self, other.as_expression()) } @@ -259,7 +281,11 @@ pub trait ExpressionMethods: Expression + Sized { /// assert_eq!("Sean", data); /// # Ok(()) /// # } - fn le>(self, other: T) -> LtEq { + fn le(self, other: T) -> LtEq + where + Self::SqlType: SqlType, + T: AsExpression, + { LtEq::new(self, other.as_expression()) } @@ -285,6 +311,7 @@ pub trait ExpressionMethods: Expression + Sized { /// ``` fn between(self, lower: T, upper: U) -> Between> where + Self::SqlType: SqlType, T: AsExpression, U: AsExpression, { @@ -320,6 +347,7 @@ pub trait ExpressionMethods: Expression + Sized { upper: U, ) -> NotBetween> where + Self::SqlType: SqlType, T: AsExpression, U: AsExpression, { @@ -365,11 +393,12 @@ pub trait ExpressionMethods: Expression + Sized { /// /// ```rust /// # include!("../doctest_setup.rs"); + /// # use diesel::expression::expression_types::NotSelectable; /// # /// # fn main() { /// # use schema::users::dsl::*; /// # let order = "name"; - /// let ordering: Box> = + /// let ordering: Box> = /// if order == "name" { /// Box::new(name.desc()) /// } else { diff --git a/diesel/src/expression_methods/text_expression_methods.rs b/diesel/src/expression_methods/text_expression_methods.rs index 913690a9eb62..8527edc68038 100644 --- a/diesel/src/expression_methods/text_expression_methods.rs +++ b/diesel/src/expression_methods/text_expression_methods.rs @@ -1,6 +1,6 @@ use crate::expression::operators::{Concat, Like, NotLike}; use crate::expression::{AsExpression, Expression}; -use crate::sql_types::{Nullable, Text}; +use crate::sql_types::{Nullable, SqlType, Text}; /// Methods present on text expressions pub trait TextExpressionMethods: Expression + Sized { @@ -54,7 +54,11 @@ pub trait TextExpressionMethods: Expression + Sized { /// assert_eq!(Ok(expected_names), names); /// # } /// ``` - fn concat>(self, other: T) -> Concat { + fn concat(self, other: T) -> Concat + where + Self::SqlType: SqlType, + T: AsExpression, + { Concat::new(self, other.as_expression()) } @@ -86,8 +90,12 @@ pub trait TextExpressionMethods: Expression + Sized { /// # Ok(()) /// # } /// ``` - fn like>(self, other: T) -> Like { - Like::new(self.as_expression(), other.as_expression()) + fn like(self, other: T) -> Like + where + Self::SqlType: SqlType, + T: AsExpression, + { + Like::new(self, other.as_expression()) } /// Returns a SQL `NOT LIKE` expression @@ -118,8 +126,12 @@ pub trait TextExpressionMethods: Expression + Sized { /// # Ok(()) /// # } /// ``` - fn not_like>(self, other: T) -> NotLike { - NotLike::new(self.as_expression(), other.as_expression()) + fn not_like(self, other: T) -> NotLike + where + Self::SqlType: SqlType, + T: AsExpression, + { + NotLike::new(self, other.as_expression()) } } diff --git a/diesel/src/macros/mod.rs b/diesel/src/macros/mod.rs index 28d8f5d5c659..3a7e0424ff71 100644 --- a/diesel/src/macros/mod.rs +++ b/diesel/src/macros/mod.rs @@ -101,9 +101,9 @@ macro_rules! __diesel_column { impl $crate::EqAll for $column_name where T: $crate::expression::AsExpression<$($Type)*>, - $crate::dsl::Eq<$column_name, T>: $crate::Expression, + $crate::dsl::Eq<$column_name, T::Expression>: $crate::Expression, { - type Output = $crate::dsl::Eq; + type Output = $crate::dsl::Eq; fn eq_all(self, rhs: T) -> Self::Output { $crate::expression::operators::Eq::new(self, rhs.as_expression()) @@ -180,6 +180,7 @@ macro_rules! __diesel_column { /// /// ``` /// # mod diesel_full_text_search { +/// # #[derive(diesel::sql_types::SqlType)] /// # pub struct TsVector; /// # } /// @@ -819,7 +820,7 @@ macro_rules! __diesel_table_impl { } impl Expression for star { - type SqlType = (); + type SqlType = $crate::expression::expression_types::NotSelectable; } impl QueryFragment for star where @@ -1057,19 +1058,6 @@ macro_rules! allow_tables_to_appear_in_same_query { () => {}; } -/// Gets the value out of an option, or returns an error. -/// -/// This is used by `FromSql` implementations. -#[macro_export] -macro_rules! not_none { - ($bytes:expr) => { - match $bytes { - Some(bytes) => bytes, - None => return Err(Box::new($crate::result::UnexpectedNullError)), - } - }; -} - // The order of these modules is important (at least for those which have tests). // Utility macros which don't call any others need to come first. #[macro_use] @@ -1093,7 +1081,7 @@ mod tests { } mod my_types { - #[derive(Debug, Clone, Copy)] + #[derive(Debug, Clone, Copy, crate::sql_types::SqlType)] pub struct MyCustomType; } @@ -1141,11 +1129,11 @@ mod tests { table_with_arbitrarily_complex_types { id -> sql_types::Integer, qualified_nullable -> sql_types::Nullable, - deeply_nested_type -> Option>, + deeply_nested_type -> Nullable>, // This actually should work, but there appears to be a rustc bug // on the `AsExpression` bound for `EqAll` when the ty param is a projection // projected_type -> as sql_types::IntoNullable>::Nullable, - random_tuple -> (Integer, Integer), + //random_tuple -> (Integer, Integer), } } diff --git a/diesel/src/macros/ops.rs b/diesel/src/macros/ops.rs index 63292c31e60f..eac8f12f6ab2 100644 --- a/diesel/src/macros/ops.rs +++ b/diesel/src/macros/ops.rs @@ -39,7 +39,7 @@ macro_rules! numeric_expr { #[doc(hidden)] macro_rules! __diesel_generate_ops_impls_if_numeric { ($column_name:ident, Nullable<$($inner:tt)::*>) => { __diesel_generate_ops_impls_if_numeric!($column_name, $($inner)::*); }; - + ($column_name:ident, Unsigned<$($inner:tt)::*>) => { __diesel_generate_ops_impls_if_numeric!($column_name, $($inner)::*); }; ($column_name:ident, SmallInt) => { numeric_expr!($column_name); }; diff --git a/diesel/src/mysql/backend.rs b/diesel/src/mysql/backend.rs index 7ac49e022e84..39482b871747 100644 --- a/diesel/src/mysql/backend.rs +++ b/diesel/src/mysql/backend.rs @@ -73,7 +73,7 @@ impl<'a> HasRawValue<'a> for Mysql { } impl TypeMetadata for Mysql { - type TypeMetadata = Option; + type TypeMetadata = MysqlType; type MetadataLookup = (); } diff --git a/diesel/src/mysql/connection/bind.rs b/diesel/src/mysql/connection/bind.rs index cf2a85040a09..fa877b39b555 100644 --- a/diesel/src/mysql/connection/bind.rs +++ b/diesel/src/mysql/connection/bind.rs @@ -1,5 +1,6 @@ use mysqlclient_sys as ffi; use std::mem; +use std::ops::Index; use std::os::raw as libc; use super::stmt::Statement; @@ -15,42 +16,22 @@ pub struct Binds { impl Binds { pub fn from_input_data(input: Iter) -> QueryResult where - Iter: IntoIterator, Option>)>, + Iter: IntoIterator>)>, { let data = input .into_iter() - .map(|(metadata, bytes)| { - if let Some(metadata) = metadata { - Ok(BindData::for_input(metadata, bytes)) - } else { - Err("Unknown bind type.") - } - }) - .collect::, _>>() - .map_err(|e| crate::result::Error::QueryBuilderError(e.into()))?; + .map(BindData::for_input) + .collect::>(); Ok(Binds { data }) } - pub fn from_output_types( - types: Vec>, - metadata: Option<&StatementMetadata>, - ) -> Self { - let data = if let Some(metadata) = metadata { - metadata - .fields() - .iter() - .map(|f| (f.field_type(), f.flags())) - .map(BindData::for_output) - .collect() - } else { - types - .into_iter() - .map(|metadata| metadata.expect("We checked that before calling from_output_types, otherwise we would have passed metadata")) - .map(|metadata| metadata.into()) - .map(BindData::for_output) - .collect() - }; + pub fn from_output_types(types: Vec) -> Self { + let data = types + .into_iter() + .map(|metadata| metadata.into()) + .map(BindData::for_output) + .collect(); Binds { data } } @@ -101,19 +82,18 @@ impl Binds { } } - pub fn field_data(&self, idx: usize) -> Option> { - let data = &self.data[idx]; - self.data[idx].bytes().map(|bytes| { - let tpe = (data.tpe, data.flags).into(); - MysqlValue::new(bytes, tpe) - }) - } - pub fn len(&self) -> usize { self.data.len() } } +impl Index for Binds { + type Output = BindData; + fn index(&self, index: usize) -> &Self::Output { + &self.data[index] + } +} + bitflags::bitflags! { pub(crate) struct Flags: u32 { const NOT_NULL_FLAG = 1; @@ -150,7 +130,8 @@ impl From for Flags { } } -struct BindData { +#[derive(Debug)] +pub struct BindData { tpe: ffi::enum_field_types, bytes: Vec, length: libc::c_ulong, @@ -160,7 +141,7 @@ struct BindData { } impl BindData { - fn for_input(tpe: MysqlType, data: Option>) -> Self { + fn for_input((tpe, data): (MysqlType, Option>)) -> Self { let is_null = if data.is_none() { 1 } else { 0 }; let bytes = data.unwrap_or_default(); let length = bytes.len() as libc::c_ulong; @@ -199,14 +180,19 @@ impl BindData { known_buffer_size_for_ffi_type(self.tpe).is_some() } - fn bytes(&self) -> Option<&[u8]> { - if self.is_null == 0 { - Some(&*self.bytes) - } else { + pub fn value(&'_ self) -> Option> { + if self.is_null() { None + } else { + let tpe = (self.tpe, self.flags).into(); + Some(MysqlValue::new(&self.bytes, tpe)) } } + pub fn is_null(&self) -> bool { + self.is_null != 0 + } + fn update_buffer_length(&mut self) { use std::cmp::min; @@ -473,7 +459,7 @@ mod tests { let meta = (bind.tpe, bind.flags).into(); dbg!(meta); let value = MysqlValue::new(&bind.bytes, meta); - dbg!(T::from_sql(Some(value))) + dbg!(T::from_sql(value)) } #[cfg(feature = "extras")] diff --git a/diesel/src/mysql/connection/mod.rs b/diesel/src/mysql/connection/mod.rs index b675a216f51a..44824c326222 100644 --- a/diesel/src/mysql/connection/mod.rs +++ b/diesel/src/mysql/connection/mod.rs @@ -8,11 +8,11 @@ use self::stmt::Statement; use self::url::ConnectionOptions; use super::backend::Mysql; use crate::connection::*; -use crate::deserialize::{Queryable, QueryableByName}; +use crate::deserialize::{FromSqlRow, IsCompatibleType}; +use crate::expression::TypedExpressionType; use crate::query_builder::bind_collector::RawBytesBindCollector; use crate::query_builder::*; use crate::result::*; -use crate::sql_types::HasSqlType; #[allow(missing_debug_implementations, missing_copy_implementations)] /// A connection to a MySQL database. Connection URLs should be in the form @@ -60,38 +60,20 @@ impl Connection for MysqlConnection { } #[doc(hidden)] - fn query_by_index(&self, source: T) -> QueryResult> + fn load(&self, source: T) -> QueryResult> where T: AsQuery, T::Query: QueryFragment + QueryId, - Self::Backend: HasSqlType, - U: Queryable, + U: FromSqlRow, + T::SqlType: IsCompatibleType, + ST: TypedExpressionType, { - use crate::deserialize::FromSqlRow; use crate::result::Error::DeserializationError; let mut stmt = self.prepare_query(&source.as_query())?; - let mut metadata = Vec::new(); - Mysql::mysql_row_metadata(&mut metadata, &()); + let metadata = T::SqlType::mysql_row_metadata(&()); let results = unsafe { stmt.results(metadata)? }; - results.map(|mut row| { - U::Row::build_from_row(&mut row) - .map(U::build) - .map_err(DeserializationError) - }) - } - - #[doc(hidden)] - fn query_by_name(&self, source: &T) -> QueryResult> - where - T: QueryFragment + QueryId, - U: QueryableByName, - { - use crate::result::Error::DeserializationError; - - let mut stmt = self.prepare_query(source)?; - let results = unsafe { stmt.named_results()? }; - results.map(|row| U::build(&row).map_err(DeserializationError)) + results.map(|row| U::build_from_row(&row).map_err(DeserializationError)) } #[doc(hidden)] diff --git a/diesel/src/mysql/connection/stmt/iterator.rs b/diesel/src/mysql/connection/stmt/iterator.rs index a0728c858900..ea6bd95de02b 100644 --- a/diesel/src/mysql/connection/stmt/iterator.rs +++ b/diesel/src/mysql/connection/stmt/iterator.rs @@ -1,7 +1,5 @@ -use std::collections::HashMap; - -use super::{Binds, Statement, StatementMetadata}; -use crate::mysql::{Mysql, MysqlType, MysqlValue}; +use super::{metadata::MysqlFieldMetadata, BindData, Binds, Statement, StatementMetadata}; +use crate::mysql::{Mysql, MysqlType}; use crate::result::QueryResult; use crate::row::*; @@ -14,12 +12,12 @@ pub struct StatementIterator<'a> { #[allow(clippy::should_implement_trait)] // don't neet `Iterator` here impl<'a> StatementIterator<'a> { #[allow(clippy::new_ret_no_self)] - pub fn new(stmt: &'a mut Statement, types: Vec>) -> QueryResult { + pub fn new(stmt: &'a mut Statement, types: Option>) -> QueryResult { let metadata = stmt.metadata()?; - let mut output_binds = if types.iter().any(Option::is_none) { - Binds::from_output_types(types, Some(&metadata)) + let mut output_binds = if let Some(types) = types { + Binds::from_output_types(types) } else { - Binds::from_output_types(types, None) + Binds::from_result_metadata(&metadata) }; stmt.execute_statement(&mut output_binds)?; @@ -55,98 +53,73 @@ impl<'a> StatementIterator<'a> { } } +#[derive(Clone)] pub struct MysqlRow<'a> { col_idx: usize, binds: &'a Binds, metadata: &'a StatementMetadata, } -impl<'a> Row for MysqlRow<'a> { - fn take(&mut self) -> Option> { - let current_idx = self.col_idx; - self.col_idx += 1; - self.binds.field_data(current_idx) - } - - fn next_is_null(&self, count: usize) -> bool { - (0..count).all(|i| self.binds.field_data(self.col_idx + i).is_none()) - } +impl<'a> Row<'a, Mysql> for MysqlRow<'a> { + type Field = MysqlField<'a>; + type InnerPartialRow = Self; - fn column_count(&self) -> usize { + fn field_count(&self) -> usize { self.binds.len() } - fn column_name(&self) -> Option<&str> { - self.metadata.fields()[self.col_idx].field_name() + fn get(&self, idx: I) -> Option + where + Self: RowIndex, + { + let idx = self.idx(idx)?; + Some(MysqlField { + bind: &self.binds[idx], + metadata: &self.metadata.fields()[idx], + }) } -} -pub struct NamedStatementIterator<'a> { - stmt: &'a mut Statement, - output_binds: Binds, - metadata: StatementMetadata, -} - -#[allow(clippy::should_implement_trait)] // don't need `Iterator` here -impl<'a> NamedStatementIterator<'a> { - #[allow(clippy::new_ret_no_self)] - pub fn new(stmt: &'a mut Statement) -> QueryResult { - let metadata = stmt.metadata()?; - let mut output_binds = Binds::from_result_metadata(&metadata); - - stmt.execute_statement(&mut output_binds)?; - - Ok(NamedStatementIterator { - stmt, - output_binds, - metadata, - }) + fn partial_row(&self, range: std::ops::Range) -> PartialRow { + PartialRow::new(self, range) } +} - pub fn map(mut self, mut f: F) -> QueryResult> - where - F: FnMut(NamedMysqlRow) -> QueryResult, - { - let mut results = Vec::new(); - while let Some(row) = self.next() { - results.push(f(row?)?); +impl<'a> RowIndex for MysqlRow<'a> { + fn idx(&self, idx: usize) -> Option { + if idx < self.field_count() { + Some(idx) + } else { + None } - Ok(results) } +} - fn next(&mut self) -> Option> { - match self.stmt.populate_row_buffers(&mut self.output_binds) { - Ok(Some(())) => Some(Ok(NamedMysqlRow { - binds: &self.output_binds, - column_indices: self.metadata.column_indices(), - metadata: &self.metadata, - })), - Ok(None) => None, - Err(e) => Some(Err(e)), - } +impl<'a> RowIndex<&'a str> for MysqlRow<'a> { + fn idx(&self, idx: &'a str) -> Option { + self.metadata + .fields() + .iter() + .enumerate() + .find(|(_, field_meta)| field_meta.field_name() == Some(idx)) + .map(|(idx, _)| idx) } } -pub struct NamedMysqlRow<'a> { - binds: &'a Binds, - column_indices: &'a HashMap<&'a str, usize>, - metadata: &'a StatementMetadata, +pub struct MysqlField<'a> { + bind: &'a BindData, + metadata: &'a MysqlFieldMetadata<'a>, } -impl<'a> NamedRow for NamedMysqlRow<'a> { - fn index_of(&self, column_name: &str) -> Option { - self.column_indices.get(column_name).cloned() +impl<'a> Field<'a, Mysql> for MysqlField<'a> { + fn field_name(&self) -> Option<&str> { + self.metadata.field_name() } - fn get_raw_value(&self, idx: usize) -> Option> { - self.binds.field_data(idx) + fn is_null(&self) -> bool { + self.bind.is_null() } - fn field_names(&self) -> Vec<&str> { - self.metadata - .fields() - .iter() - .filter_map(|f| f.field_name()) - .collect() + fn value(&self) -> Option> { + self.bind.value() } } diff --git a/diesel/src/mysql/connection/stmt/metadata.rs b/diesel/src/mysql/connection/stmt/metadata.rs index 29a58c84fccd..8e302e600b89 100644 --- a/diesel/src/mysql/connection/stmt/metadata.rs +++ b/diesel/src/mysql/connection/stmt/metadata.rs @@ -1,4 +1,3 @@ -use std::collections::HashMap; use std::ffi::CStr; use std::ptr::NonNull; use std::slice; @@ -8,25 +7,14 @@ use crate::mysql::connection::bind::Flags; pub struct StatementMetadata { result: NonNull, - // The strings in this hash map are only valid - // as long as we do not free the result pointer above - // We use a 'static lifetime here, because we cannot - // have a self referential lifetime. - // Therefore this lifetime must not leave this module - column_indices: HashMap<&'static str, usize>, } impl StatementMetadata { pub fn new(result: NonNull) -> Self { - let mut res = StatementMetadata { - column_indices: HashMap::new(), - result, - }; - res.populate_column_indices(); - res + StatementMetadata { result } } - pub fn fields<'a>(&'a self) -> &'a [MysqlFieldMetadata] { + pub fn fields(&'_ self) -> &'_ [MysqlFieldMetadata<'_>] { unsafe { let num_fields = ffi::mysql_num_fields(self.result.as_ptr()); let field_ptr = ffi::mysql_fetch_fields(self.result.as_ptr()); @@ -37,32 +25,6 @@ impl StatementMetadata { } } } - - pub fn column_indices<'a>(&'a self) -> &'a HashMap<&'a str, usize> { - &self.column_indices - } - - fn populate_column_indices(&mut self) { - self.column_indices = self - .fields() - .iter() - .enumerate() - .filter_map(|(i, field)| unsafe { - // This is highly unsafe because we create strings slices with a static life time - // * We cannot use `MysqlFieldMetadata` because of this reason - // * We cannot have a concrete life time because otherwise this would be - // an self referential struct - // * This relies on the invariant that non of the slices leave this - // type with anything other then a concrete life time bound to this - // type - if field.0.name.is_null() { - None - } else { - CStr::from_ptr(field.0.name).to_str().ok().map(|f| (f, i)) - } - }) - .collect() - } } impl Drop for StatementMetadata { diff --git a/diesel/src/mysql/connection/stmt/mod.rs b/diesel/src/mysql/connection/stmt/mod.rs index bb86524fef38..ee58e68886e2 100644 --- a/diesel/src/mysql/connection/stmt/mod.rs +++ b/diesel/src/mysql/connection/stmt/mod.rs @@ -8,7 +8,7 @@ use std::os::raw as libc; use std::ptr::NonNull; use self::iterator::*; -use super::bind::Binds; +use super::bind::{BindData, Binds}; use crate::mysql::MysqlType; use crate::result::{DatabaseErrorKind, QueryResult}; @@ -40,7 +40,7 @@ impl Statement { pub fn bind(&mut self, binds: Iter) -> QueryResult<()> where - Iter: IntoIterator, Option>)>, + Iter: IntoIterator>)>, { let input_binds = Binds::from_input_data(binds)?; self.input_bind(input_binds) @@ -79,18 +79,11 @@ impl Statement { /// be called on this statement. pub unsafe fn results( &mut self, - types: Vec>, + types: Option>, ) -> QueryResult { StatementIterator::new(self, types) } - /// This function should be called instead of `execute` for queries which - /// have a return value. After calling this function, `execute` can never - /// be called on this statement. - pub unsafe fn named_results(&mut self) -> QueryResult { - NamedStatementIterator::new(self) - } - fn last_error_message(&self) -> String { unsafe { CStr::from_ptr(ffi::mysql_stmt_error(self.stmt.as_ptr())) } .to_string_lossy() diff --git a/diesel/src/mysql/types/date_and_time.rs b/diesel/src/mysql/types/date_and_time.rs index 9709c1d2ef3f..093dd49fe0bd 100644 --- a/diesel/src/mysql/types/date_and_time.rs +++ b/diesel/src/mysql/types/date_and_time.rs @@ -24,9 +24,8 @@ macro_rules! mysql_time_impls { } impl FromSql<$ty, Mysql> for MYSQL_TIME { - fn from_sql(value: Option>) -> deserialize::Result { - let data = not_none!(value); - data.time_value() + fn from_sql(value: MysqlValue<'_>) -> deserialize::Result { + value.time_value() } } }; @@ -44,7 +43,7 @@ impl ToSql for NaiveDateTime { } impl FromSql for NaiveDateTime { - fn from_sql(bytes: Option>) -> deserialize::Result { + fn from_sql(bytes: MysqlValue<'_>) -> deserialize::Result { >::from_sql(bytes) } } @@ -69,7 +68,7 @@ impl ToSql for NaiveDateTime { } impl FromSql for NaiveDateTime { - fn from_sql(bytes: Option>) -> deserialize::Result { + fn from_sql(bytes: MysqlValue<'_>) -> deserialize::Result { let mysql_time = >::from_sql(bytes)?; NaiveDate::from_ymd_opt( @@ -109,7 +108,7 @@ impl ToSql for NaiveTime { } impl FromSql for NaiveTime { - fn from_sql(bytes: Option>) -> deserialize::Result { + fn from_sql(bytes: MysqlValue<'_>) -> deserialize::Result { let mysql_time = >::from_sql(bytes)?; NaiveTime::from_hms_opt( mysql_time.hour as u32, @@ -140,7 +139,7 @@ impl ToSql for NaiveDate { } impl FromSql for NaiveDate { - fn from_sql(bytes: Option>) -> deserialize::Result { + fn from_sql(bytes: MysqlValue<'_>) -> deserialize::Result { let mysql_time = >::from_sql(bytes)?; NaiveDate::from_ymd_opt( mysql_time.year as i32, diff --git a/diesel/src/mysql/types/json.rs b/diesel/src/mysql/types/json.rs index 237027b13b91..8d67636995c2 100644 --- a/diesel/src/mysql/types/json.rs +++ b/diesel/src/mysql/types/json.rs @@ -5,8 +5,7 @@ use crate::sql_types; use std::io::prelude::*; impl FromSql for serde_json::Value { - fn from_sql(value: Option>) -> deserialize::Result { - let value = not_none!(value); + fn from_sql(value: MysqlValue<'_>) -> deserialize::Result { serde_json::from_slice(value.as_bytes()).map_err(|_| "Invalid Json".into()) } } @@ -31,25 +30,24 @@ fn json_to_sql() { fn some_json_from_sql() { use crate::mysql::MysqlType; let input_json = b"true"; - let output_json: serde_json::Value = FromSql::::from_sql(Some( - MysqlValue::new(input_json, MysqlType::String), - )) - .unwrap(); + let output_json: serde_json::Value = + FromSql::::from_sql(MysqlValue::new(input_json, MysqlType::String)) + .unwrap(); assert_eq!(output_json, serde_json::Value::Bool(true)); } #[test] fn bad_json_from_sql() { use crate::mysql::MysqlType; - let uuid: Result = FromSql::::from_sql(Some( - MysqlValue::new(b"boom", MysqlType::String), - )); + let uuid: Result = + FromSql::::from_sql(MysqlValue::new(b"boom", MysqlType::String)); assert_eq!(uuid.unwrap_err().to_string(), "Invalid Json"); } #[test] fn no_json_from_sql() { - let uuid: Result = FromSql::::from_sql(None); + let uuid: Result = + FromSql::::from_nullable_sql(None); assert_eq!( uuid.unwrap_err().to_string(), "Unexpected null for non-null column" diff --git a/diesel/src/mysql/types/mod.rs b/diesel/src/mysql/types/mod.rs index 73b4c7b9ef90..701a5f32d15d 100644 --- a/diesel/src/mysql/types/mod.rs +++ b/diesel/src/mysql/types/mod.rs @@ -46,8 +46,7 @@ impl ToSql for i8 { } impl FromSql for i8 { - fn from_sql(value: Option>) -> deserialize::Result { - let value = not_none!(value); + fn from_sql(value: MysqlValue<'_>) -> deserialize::Result { let bytes = value.as_bytes(); Ok(bytes[0] as i8) } @@ -96,7 +95,7 @@ impl ToSql, Mysql> for u8 { } impl FromSql, Mysql> for u8 { - fn from_sql(bytes: Option>) -> deserialize::Result { + fn from_sql(bytes: MysqlValue<'_>) -> deserialize::Result { let signed: i8 = FromSql::::from_sql(bytes)?; Ok(signed as u8) } @@ -109,7 +108,7 @@ impl ToSql, Mysql> for u16 { } impl FromSql, Mysql> for u16 { - fn from_sql(bytes: Option>) -> deserialize::Result { + fn from_sql(bytes: MysqlValue<'_>) -> deserialize::Result { let signed: i16 = FromSql::::from_sql(bytes)?; Ok(signed as u16) } @@ -122,7 +121,7 @@ impl ToSql, Mysql> for u32 { } impl FromSql, Mysql> for u32 { - fn from_sql(bytes: Option>) -> deserialize::Result { + fn from_sql(bytes: MysqlValue<'_>) -> deserialize::Result { let signed: i32 = FromSql::::from_sql(bytes)?; Ok(signed as u32) } @@ -135,7 +134,7 @@ impl ToSql, Mysql> for u64 { } impl FromSql, Mysql> for u64 { - fn from_sql(bytes: Option>) -> deserialize::Result { + fn from_sql(bytes: MysqlValue<'_>) -> deserialize::Result { let signed: i64 = FromSql::::from_sql(bytes)?; Ok(signed as u64) } @@ -149,32 +148,32 @@ impl ToSql for bool { } impl FromSql for bool { - fn from_sql(bytes: Option>) -> deserialize::Result { - Ok(not_none!(bytes).as_bytes().iter().any(|x| *x != 0)) + fn from_sql(bytes: MysqlValue<'_>) -> deserialize::Result { + Ok(bytes.as_bytes().iter().any(|x| *x != 0)) } } impl HasSqlType> for Mysql { - fn metadata(_lookup: &()) -> Option { - Some(MysqlType::UnsignedTiny) + fn metadata(_lookup: &()) -> MysqlType { + MysqlType::UnsignedTiny } } impl HasSqlType> for Mysql { - fn metadata(_lookup: &()) -> Option { - Some(MysqlType::UnsignedShort) + fn metadata(_lookup: &()) -> MysqlType { + MysqlType::UnsignedShort } } impl HasSqlType> for Mysql { - fn metadata(_lookup: &()) -> Option { - Some(MysqlType::UnsignedLong) + fn metadata(_lookup: &()) -> MysqlType { + MysqlType::UnsignedLong } } impl HasSqlType> for Mysql { - fn metadata(_lookup: &()) -> Option { - Some(MysqlType::UnsignedLongLong) + fn metadata(_lookup: &()) -> MysqlType { + MysqlType::UnsignedLongLong } } diff --git a/diesel/src/mysql/types/numeric.rs b/diesel/src/mysql/types/numeric.rs index 38b6f23d9315..783f813383c7 100644 --- a/diesel/src/mysql/types/numeric.rs +++ b/diesel/src/mysql/types/numeric.rs @@ -19,10 +19,10 @@ pub mod bigdecimal { } impl FromSql for BigDecimal { - fn from_sql(value: Option>) -> deserialize::Result { + fn from_sql(value: MysqlValue<'_>) -> deserialize::Result { use crate::mysql::NumericRepresentation::*; - let data = not_none!(value); - match data.numeric_value()? { + + match value.numeric_value()? { Tiny(x) => Ok(x.into()), Small(x) => Ok(x.into()), Medium(x) => Ok(x.into()), diff --git a/diesel/src/mysql/types/primitives.rs b/diesel/src/mysql/types/primitives.rs index 6e6d9c0f5733..d2583d91a1e9 100644 --- a/diesel/src/mysql/types/primitives.rs +++ b/diesel/src/mysql/types/primitives.rs @@ -29,11 +29,10 @@ where } impl FromSql for i16 { - fn from_sql(value: Option>) -> deserialize::Result { + fn from_sql(value: MysqlValue<'_>) -> deserialize::Result { use crate::mysql::NumericRepresentation::*; - let data = not_none!(value); - match data.numeric_value()? { + match value.numeric_value()? { Tiny(x) => Ok(x.into()), Small(x) => Ok(x), Medium(x) => Ok(x as Self), @@ -46,11 +45,10 @@ impl FromSql for i16 { } impl FromSql for i32 { - fn from_sql(value: Option>) -> deserialize::Result { + fn from_sql(value: MysqlValue<'_>) -> deserialize::Result { use crate::mysql::NumericRepresentation::*; - let data = not_none!(value); - match data.numeric_value()? { + match value.numeric_value()? { Tiny(x) => Ok(x.into()), Small(x) => Ok(x.into()), Medium(x) => Ok(x), @@ -63,11 +61,10 @@ impl FromSql for i32 { } impl FromSql for i64 { - fn from_sql(value: Option>) -> deserialize::Result { + fn from_sql(value: MysqlValue<'_>) -> deserialize::Result { use crate::mysql::NumericRepresentation::*; - let data = not_none!(value); - match data.numeric_value()? { + match value.numeric_value()? { Tiny(x) => Ok(x.into()), Small(x) => Ok(x.into()), Medium(x) => Ok(x.into()), @@ -80,11 +77,10 @@ impl FromSql for i64 { } impl FromSql for f32 { - fn from_sql(value: Option>) -> deserialize::Result { + fn from_sql(value: MysqlValue<'_>) -> deserialize::Result { use crate::mysql::NumericRepresentation::*; - let data = not_none!(value); - match data.numeric_value()? { + match value.numeric_value()? { Tiny(x) => Ok(x.into()), Small(x) => Ok(x.into()), Medium(x) => Ok(x as Self), @@ -97,11 +93,10 @@ impl FromSql for f32 { } impl FromSql for f64 { - fn from_sql(value: Option>) -> deserialize::Result { + fn from_sql(value: MysqlValue<'_>) -> deserialize::Result { use crate::mysql::NumericRepresentation::*; - let data = not_none!(value); - match data.numeric_value()? { + match value.numeric_value()? { Tiny(x) => Ok(x.into()), Small(x) => Ok(x.into()), Medium(x) => Ok(x.into()), @@ -114,15 +109,13 @@ impl FromSql for f64 { } impl FromSql for String { - fn from_sql(value: Option>) -> deserialize::Result { - let value = not_none!(value); + fn from_sql(value: MysqlValue<'_>) -> deserialize::Result { String::from_utf8(value.as_bytes().into()).map_err(Into::into) } } impl FromSql for Vec { - fn from_sql(value: Option>) -> deserialize::Result { - let value = not_none!(value); + fn from_sql(value: MysqlValue<'_>) -> deserialize::Result { Ok(value.as_bytes().into()) } } diff --git a/diesel/src/pg/backend.rs b/diesel/src/pg/backend.rs index 5d78758a561e..a374fae4d738 100644 --- a/diesel/src/pg/backend.rs +++ b/diesel/src/pg/backend.rs @@ -7,7 +7,7 @@ use super::{PgMetadataLookup, PgValue}; use crate::backend::*; use crate::deserialize::Queryable; use crate::query_builder::bind_collector::RawBytesBindCollector; -use crate::sql_types::{Oid, TypeMetadata}; +use crate::sql_types::TypeMetadata; /// The PostgreSQL backend #[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] @@ -16,7 +16,7 @@ pub struct Pg; /// The [OIDs] for a SQL type /// /// [OIDs]: https://www.postgresql.org/docs/current/static/datatype-oid.html -#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, Default)] +#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, Default, Queryable)] pub struct PgTypeMetadata { /// The [OID] of `T` /// @@ -28,14 +28,6 @@ pub struct PgTypeMetadata { pub array_oid: u32, } -impl Queryable<(Oid, Oid), Pg> for PgTypeMetadata { - type Row = (u32, u32); - - fn build((oid, array_oid): Self::Row) -> Self { - PgTypeMetadata { oid, array_oid } - } -} - impl Backend for Pg { type QueryBuilder = PgQueryBuilder; type BindCollector = RawBytesBindCollector; diff --git a/diesel/src/pg/connection/cursor.rs b/diesel/src/pg/connection/cursor.rs index 698de6899b52..e1973050cf36 100644 --- a/diesel/src/pg/connection/cursor.rs +++ b/diesel/src/pg/connection/cursor.rs @@ -1,77 +1,45 @@ use super::result::PgResult; -use super::row::PgNamedRow; -use crate::deserialize::{FromSqlRow, Queryable, QueryableByName}; -use crate::pg::Pg; -use crate::result::Error::DeserializationError; -use crate::result::QueryResult; - -use std::marker::PhantomData; +use super::row::PgRow; /// The type returned by various [`Connection`](struct.Connection.html) methods. /// Acts as an iterator over `T`. -pub struct Cursor { +pub struct Cursor<'a> { current_row: usize, - db_result: PgResult, - _marker: PhantomData<(ST, T)>, + db_result: &'a PgResult, } -impl Cursor { - #[doc(hidden)] - pub fn new(db_result: PgResult) -> Self { +impl<'a> Cursor<'a> { + pub(super) fn new(db_result: &'a PgResult) -> Self { Cursor { current_row: 0, db_result, - _marker: PhantomData, } } } -impl Iterator for Cursor -where - T: Queryable, -{ - type Item = QueryResult; +impl<'a> ExactSizeIterator for Cursor<'a> {} + +impl<'a> Iterator for Cursor<'a> { + type Item = PgRow<'a>; fn next(&mut self) -> Option { if self.current_row >= self.db_result.num_rows() { None } else { - let mut row = self.db_result.get_row(self.current_row); + let row = self.db_result.get_row(self.current_row); self.current_row += 1; - let value = T::Row::build_from_row(&mut row) - .map(T::build) - .map_err(DeserializationError); - Some(value) - } - } -} -pub struct NamedCursor { - pub(crate) db_result: PgResult, -} - -impl NamedCursor { - pub fn new(db_result: PgResult) -> Self { - NamedCursor { db_result } - } - - pub fn collect(self) -> QueryResult> - where - T: QueryableByName, - { - (0..self.db_result.num_rows()) - .map(|i| { - let row = PgNamedRow::new(&self, i); - T::build(&row).map_err(DeserializationError) - }) - .collect() + Some(row) + } } - pub fn index_of_column(&self, column_name: &str) -> Option { - self.db_result.field_number(column_name) + fn nth(&mut self, n: usize) -> Option { + self.current_row += n; + self.next() } - pub fn get_value(&self, row: usize, column: usize) -> Option<&[u8]> { - self.db_result.get(row, column) + fn size_hint(&self) -> (usize, Option) { + let len = self.db_result.num_rows(); + (len, Some(len)) } } diff --git a/diesel/src/pg/connection/mod.rs b/diesel/src/pg/connection/mod.rs index 121180bf6248..0eada7b9f03f 100644 --- a/diesel/src/pg/connection/mod.rs +++ b/diesel/src/pg/connection/mod.rs @@ -13,13 +13,14 @@ use self::raw::RawConnection; use self::result::PgResult; use self::stmt::Statement; use crate::connection::*; -use crate::deserialize::{Queryable, QueryableByName}; +use crate::deserialize::{FromSqlRow, IsCompatibleType}; +use crate::expression::TypedExpressionType; use crate::pg::{metadata_lookup::PgMetadataCache, Pg, PgMetadataLookup, TransactionBuilder}; use crate::query_builder::bind_collector::RawBytesBindCollector; use crate::query_builder::*; use crate::result::ConnectionError::CouldntSetupConfiguration; +use crate::result::Error::DeserializationError; use crate::result::*; -use crate::sql_types::HasSqlType; /// The connection string expected by `PgConnection::establish` /// should be a PostgreSQL connection string, as documented at @@ -67,29 +68,21 @@ impl Connection for PgConnection { } #[doc(hidden)] - fn query_by_index(&self, source: T) -> QueryResult> + fn load(&self, source: T) -> QueryResult> where T: AsQuery, - T::Query: QueryFragment + QueryId, - Pg: HasSqlType, - U: Queryable, + T::Query: QueryFragment + QueryId, + U: FromSqlRow, + T::SqlType: IsCompatibleType, + ST: TypedExpressionType, { let (query, params) = self.prepare_query(&source.as_query())?; - query - .execute(self, ¶ms) - .and_then(|r| Cursor::new(r).collect()) - } + let result = query.execute(self, ¶ms)?; + let cursor = Cursor::new(&result); - #[doc(hidden)] - fn query_by_name(&self, source: &T) -> QueryResult> - where - T: QueryFragment + QueryId, - U: QueryableByName, - { - let (query, params) = self.prepare_query(source)?; - query - .execute(self, ¶ms) - .and_then(|r| NamedCursor::new(r).collect()) + cursor + .map(|row| U::build_from_row(&row).map_err(DeserializationError)) + .collect::>>() } #[doc(hidden)] diff --git a/diesel/src/pg/connection/result.rs b/diesel/src/pg/connection/result.rs index e84176e6da5e..10913313569f 100644 --- a/diesel/src/pg/connection/result.rs +++ b/diesel/src/pg/connection/result.rs @@ -1,7 +1,7 @@ extern crate pq_sys; use self::pq_sys::*; -use std::ffi::{CStr, CString}; +use std::ffi::CStr; use std::num::NonZeroU32; use std::os::raw as libc; use std::{slice, str}; @@ -12,6 +12,8 @@ use crate::result::{DatabaseErrorInformation, DatabaseErrorKind, Error, QueryRes pub struct PgResult { internal_result: RawResult, + column_count: usize, + row_count: usize, } impl PgResult { @@ -21,7 +23,15 @@ impl PgResult { let result_status = unsafe { PQresultStatus(internal_result.as_ptr()) }; match result_status { - PGRES_COMMAND_OK | PGRES_TUPLES_OK => Ok(PgResult { internal_result }), + PGRES_COMMAND_OK | PGRES_TUPLES_OK => { + let column_count = unsafe { PQnfields(internal_result.as_ptr()) as usize }; + let row_count = unsafe { PQntuples(internal_result.as_ptr()) as usize }; + Ok(PgResult { + internal_result, + column_count, + row_count, + }) + } PGRES_EMPTY_QUERY => { let error_message = "Received an empty query".to_string(); Err(Error::DatabaseError( @@ -71,7 +81,7 @@ impl PgResult { } pub fn num_rows(&self) -> usize { - unsafe { PQntuples(self.internal_result.as_ptr()) as usize } + self.row_count } pub fn get_row(&self, idx: usize) -> PgRow { @@ -105,11 +115,10 @@ impl PgResult { pub fn column_type(&self, col_idx: usize) -> NonZeroU32 { unsafe { - NonZeroU32::new(PQftype( + NonZeroU32::new_unchecked(PQftype( self.internal_result.as_ptr(), col_idx as libc::c_int, )) - .expect("Oid's aren't zero") } } @@ -125,16 +134,7 @@ impl PgResult { } pub fn column_count(&self) -> usize { - unsafe { PQnfields(self.internal_result.as_ptr()) as usize } - } - - pub fn field_number(&self, column_name: &str) -> Option { - let cstr = CString::new(column_name).unwrap_or_default(); - let fnum = unsafe { PQfnumber(self.internal_result.as_ptr(), cstr.as_ptr()) }; - match fnum { - -1 => None, - x => Some(x as usize), - } + self.column_count } } diff --git a/diesel/src/pg/connection/row.rs b/diesel/src/pg/connection/row.rs index bc94e0f9a7b9..fe1fb2d14502 100644 --- a/diesel/src/pg/connection/row.rs +++ b/diesel/src/pg/connection/row.rs @@ -1,70 +1,75 @@ -use super::cursor::NamedCursor; use super::result::PgResult; use crate::pg::{Pg, PgValue}; use crate::row::*; +#[derive(Clone)] pub struct PgRow<'a> { db_result: &'a PgResult, row_idx: usize, - col_idx: usize, } impl<'a> PgRow<'a> { pub fn new(db_result: &'a PgResult, row_idx: usize) -> Self { - PgRow { - db_result, - row_idx, - col_idx: 0, - } + PgRow { row_idx, db_result } } } -impl<'a> Row for PgRow<'a> { - fn take(&mut self) -> Option> { - let current_idx = self.col_idx; - self.col_idx += 1; - let raw = self.db_result.get(self.row_idx, current_idx)?; +impl<'a> Row<'a, Pg> for PgRow<'a> { + type Field = PgField<'a>; + type InnerPartialRow = Self; - Some(PgValue::new(raw, self.db_result.column_type(current_idx))) + fn field_count(&self) -> usize { + self.db_result.column_count() } - fn next_is_null(&self, count: usize) -> bool { - (0..count).all(|i| self.db_result.is_null(self.row_idx, self.col_idx + i)) + fn get(&self, idx: I) -> Option + where + Self: RowIndex, + { + let idx = self.idx(idx)?; + if idx < self.field_count() { + Some(PgField { + db_result: self.db_result, + row_idx: self.row_idx, + col_idx: idx, + }) + } else { + None + } } - fn column_count(&self) -> usize { - self.db_result.column_count() + fn partial_row(&self, range: std::ops::Range) -> PartialRow { + PartialRow::new(self, range) } +} - fn column_name(&self) -> Option<&str> { - self.db_result.column_name(self.col_idx) +impl<'a> RowIndex for PgRow<'a> { + fn idx(&self, idx: usize) -> Option { + Some(idx) } } -pub struct PgNamedRow<'a> { - cursor: &'a NamedCursor, - idx: usize, +impl<'a> RowIndex<&'a str> for PgRow<'a> { + fn idx(&self, field_name: &'a str) -> Option { + (0..self.field_count()).find(|idx| self.db_result.column_name(*idx) == Some(field_name)) + } } -impl<'a> PgNamedRow<'a> { - pub fn new(cursor: &'a NamedCursor, idx: usize) -> Self { - PgNamedRow { cursor, idx } - } +pub struct PgField<'a> { + db_result: &'a PgResult, + row_idx: usize, + col_idx: usize, } -impl<'a> NamedRow for PgNamedRow<'a> { - fn get_raw_value(&self, index: usize) -> Option> { - let raw = self.cursor.get_value(self.idx, index)?; - Some(PgValue::new(raw, self.cursor.db_result.column_type(index))) +impl<'a> Field<'a, Pg> for PgField<'a> { + fn field_name(&self) -> Option<&str> { + self.db_result.column_name(self.col_idx) } - fn index_of(&self, column_name: &str) -> Option { - self.cursor.index_of_column(column_name) - } + fn value(&self) -> Option> { + let raw = self.db_result.get(self.row_idx, self.col_idx)?; + let type_oid = self.db_result.column_type(self.col_idx); - fn field_names(&self) -> Vec<&str> { - (0..self.cursor.db_result.column_count()) - .filter_map(|i| self.cursor.db_result.column_name(i)) - .collect() + Some(PgValue::new(raw, type_oid)) } } diff --git a/diesel/src/pg/expression/array_comparison.rs b/diesel/src/pg/expression/array_comparison.rs index db39198b3911..0dd1fc4cb047 100644 --- a/diesel/src/pg/expression/array_comparison.rs +++ b/diesel/src/pg/expression/array_comparison.rs @@ -1,9 +1,9 @@ use crate::expression::subselect::Subselect; -use crate::expression::{AsExpression, Expression, ValidGrouping}; +use crate::expression::{AsExpression, Expression, TypedExpressionType, ValidGrouping}; use crate::pg::Pg; use crate::query_builder::*; use crate::result::QueryResult; -use crate::sql_types::Array; +use crate::sql_types::{Array, SqlType}; /// Creates a PostgreSQL `ANY` expression. /// @@ -75,6 +75,7 @@ impl Any { impl Expression for Any where Expr: Expression>, + ST: SqlType + TypedExpressionType, { type SqlType = ST; } @@ -108,6 +109,7 @@ impl All { impl Expression for All where Expr: Expression>, + ST: SqlType + TypedExpressionType, { type SqlType = ST; } diff --git a/diesel/src/pg/expression/date_and_time.rs b/diesel/src/pg/expression/date_and_time.rs index f0cfff226b48..94ef7362ae46 100644 --- a/diesel/src/pg/expression/date_and_time.rs +++ b/diesel/src/pg/expression/date_and_time.rs @@ -2,15 +2,14 @@ use crate::expression::{Expression, ValidGrouping}; use crate::pg::Pg; use crate::query_builder::*; use crate::result::QueryResult; -use crate::sql_types::{Date, NotNull, Nullable, Timestamp, Timestamptz, VarChar}; +use crate::sql_types::{is_nullable, Date, Nullable, SqlType, Timestamp, Timestamptz, VarChar}; /// Marker trait for types which are valid in `AT TIME ZONE` expressions pub trait DateTimeLike {} impl DateTimeLike for Date {} impl DateTimeLike for Timestamp {} impl DateTimeLike for Timestamptz {} -impl DateTimeLike for Nullable {} - +impl DateTimeLike for Nullable where T: SqlType + DateTimeLike {} #[derive(Debug, Copy, Clone, QueryId, ValidGrouping)] pub struct AtTimeZone { timestamp: Ts, diff --git a/diesel/src/pg/expression/expression_methods.rs b/diesel/src/pg/expression/expression_methods.rs index 8b8931880bd1..9abf426f3717 100644 --- a/diesel/src/pg/expression/expression_methods.rs +++ b/diesel/src/pg/expression/expression_methods.rs @@ -1,8 +1,8 @@ //! PostgreSQL specific expression methods use super::operators::*; -use crate::expression::{AsExpression, Expression}; -use crate::sql_types::{Array, Nullable, Range, Text}; +use crate::expression::{AsExpression, Expression, TypedExpressionType}; +use crate::sql_types::{Array, Nullable, Range, SqlType, Text}; /// PostgreSQL specific methods which are present on all expressions. pub trait PgExpressionMethods: Expression + Sized { @@ -27,6 +27,7 @@ pub trait PgExpressionMethods: Expression + Sized { /// ``` fn is_not_distinct_from(self, other: T) -> IsNotDistinctFrom where + Self::SqlType: SqlType, T: AsExpression, { IsNotDistinctFrom::new(self, other.as_expression()) @@ -53,6 +54,7 @@ pub trait PgExpressionMethods: Expression + Sized { /// ``` fn is_distinct_from(self, other: T) -> IsDistinctFrom where + Self::SqlType: SqlType, T: AsExpression, { IsDistinctFrom::new(self, other.as_expression()) @@ -187,6 +189,7 @@ pub trait PgArrayExpressionMethods: Expression + Sized { /// ``` fn overlaps_with(self, other: T) -> OverlapsWith where + Self::SqlType: SqlType, T: AsExpression, { OverlapsWith::new(self, other.as_expression()) @@ -236,6 +239,7 @@ pub trait PgArrayExpressionMethods: Expression + Sized { /// ``` fn contains(self, other: T) -> Contains where + Self::SqlType: SqlType, T: AsExpression, { Contains::new(self, other.as_expression()) @@ -286,6 +290,7 @@ pub trait PgArrayExpressionMethods: Expression + Sized { /// ``` fn is_contained_by(self, other: T) -> IsContainedBy where + Self::SqlType: SqlType, T: AsExpression, { IsContainedBy::new(self, other.as_expression()) @@ -309,6 +314,7 @@ impl ArrayOrNullableArray for Array {} impl ArrayOrNullableArray for Nullable> {} use crate::expression::operators::{Asc, Desc}; +use crate::EscapeExpressionMethods; /// PostgreSQL expression methods related to sorting. /// @@ -440,8 +446,11 @@ pub trait PgTextExpressionMethods: Expression + Sized { /// # Ok(()) /// # } /// ``` - fn ilike>(self, other: T) -> ILike { - ILike::new(self.as_expression(), other.as_expression()) + fn ilike(self, other: T) -> ILike + where + T: AsExpression, + { + ILike::new(self, other.as_expression()) } /// Creates a PostgreSQL `NOT ILIKE` expression @@ -466,8 +475,11 @@ pub trait PgTextExpressionMethods: Expression + Sized { /// # Ok(()) /// # } /// ``` - fn not_ilike>(self, other: T) -> NotILike { - NotILike::new(self.as_expression(), other.as_expression()) + fn not_ilike(self, other: T) -> NotILike + where + T: AsExpression, + { + NotILike::new(self, other.as_expression()) } } @@ -487,10 +499,13 @@ where { } +impl EscapeExpressionMethods for ILike {} +impl EscapeExpressionMethods for NotILike {} + #[doc(hidden)] /// Marker trait used to extract the inner type /// of our `Range` sql type, used to implement `PgRangeExpressionMethods` -pub trait RangeHelper { +pub trait RangeHelper: SqlType { type Inner; } @@ -547,10 +562,25 @@ pub trait PgRangeExpressionMethods: Expression + Sized { fn contains(self, other: T) -> Contains where Self::SqlType: RangeHelper, + ::Inner: SqlType + TypedExpressionType, T: AsExpression<::Inner>, { Contains::new(self, other.as_expression()) } } -impl PgRangeExpressionMethods for T where T: Expression> {} +#[doc(hidden)] +/// Marker trait used to implement `PgRangeExpressionMethods` on the appropriate +/// types. Once coherence takes associated types into account, we can remove +/// this trait. +pub trait RangeOrNullableRange {} + +impl RangeOrNullableRange for Range {} +impl RangeOrNullableRange for Nullable> {} + +impl PgRangeExpressionMethods for T +where + T: Expression, + T::SqlType: RangeOrNullableRange, +{ +} diff --git a/diesel/src/pg/expression/operators.rs b/diesel/src/pg/expression/operators.rs index 7c9f93d7c174..611e3f462337 100644 --- a/diesel/src/pg/expression/operators.rs +++ b/diesel/src/pg/expression/operators.rs @@ -1,3 +1,4 @@ +use crate::expression::expression_types::NotSelectable; use crate::pg::Pg; infix_operator!(IsDistinctFrom, " IS DISTINCT FROM ", backend: Pg); @@ -7,5 +8,5 @@ infix_operator!(Contains, " @> ", backend: Pg); infix_operator!(IsContainedBy, " <@ ", backend: Pg); infix_operator!(ILike, " ILIKE ", backend: Pg); infix_operator!(NotILike, " NOT ILIKE ", backend: Pg); -postfix_operator!(NullsFirst, " NULLS FIRST", (), backend: Pg); -postfix_operator!(NullsLast, " NULLS LAST", (), backend: Pg); +postfix_operator!(NullsFirst, " NULLS FIRST", NotSelectable, backend: Pg); +postfix_operator!(NullsLast, " NULLS LAST", NotSelectable, backend: Pg); diff --git a/diesel/src/pg/query_builder/mod.rs b/diesel/src/pg/query_builder/mod.rs index 3e8ad828fff9..981c2745293b 100644 --- a/diesel/src/pg/query_builder/mod.rs +++ b/diesel/src/pg/query_builder/mod.rs @@ -37,8 +37,8 @@ impl QueryBuilder for PgQueryBuilder { fn push_bind_param(&mut self) { self.bind_idx += 1; - let sql = format!("${}", self.bind_idx); - self.push_sql(&sql); + self.sql += "$"; + itoa::fmt(&mut self.sql, self.bind_idx).expect("int formating does not fail"); } fn finish(self) -> String { diff --git a/diesel/src/pg/types/array.rs b/diesel/src/pg/types/array.rs index cddc6801ce1f..7d799b325ba0 100644 --- a/diesel/src/pg/types/array.rs +++ b/diesel/src/pg/types/array.rs @@ -23,8 +23,7 @@ impl FromSql, Pg> for Vec where T: FromSql, { - fn from_sql(value: Option>) -> deserialize::Result { - let value = not_none!(value); + fn from_sql(value: PgValue<'_>) -> deserialize::Result { let mut bytes = value.as_bytes(); let num_dimensions = bytes.read_i32::()?; let has_null = bytes.read_i32::()? != 0; @@ -45,11 +44,11 @@ where .map(|_| { let elem_size = bytes.read_i32::()?; if has_null && elem_size == -1 { - T::from_sql(None) + T::from_nullable_sql(None) } else { let (elem_bytes, new_bytes) = bytes.split_at(elem_size as usize); bytes = new_bytes; - T::from_sql(Some(PgValue::new(elem_bytes, value.get_oid()))) + T::from_sql(PgValue::new(elem_bytes, value.get_oid())) } }) .collect() diff --git a/diesel/src/pg/types/date_and_time/chrono.rs b/diesel/src/pg/types/date_and_time/chrono.rs index 686cb11cb8fb..7d4d4c061179 100644 --- a/diesel/src/pg/types/date_and_time/chrono.rs +++ b/diesel/src/pg/types/date_and_time/chrono.rs @@ -19,7 +19,7 @@ fn pg_epoch() -> NaiveDateTime { } impl FromSql for NaiveDateTime { - fn from_sql(bytes: Option>) -> deserialize::Result { + fn from_sql(bytes: PgValue<'_>) -> deserialize::Result { let PgTimestamp(offset) = FromSql::::from_sql(bytes)?; match pg_epoch().checked_add_signed(Duration::microseconds(offset)) { Some(v) => Ok(v), @@ -46,7 +46,7 @@ impl ToSql for NaiveDateTime { } impl FromSql for NaiveDateTime { - fn from_sql(bytes: Option>) -> deserialize::Result { + fn from_sql(bytes: PgValue<'_>) -> deserialize::Result { FromSql::::from_sql(bytes) } } @@ -58,14 +58,14 @@ impl ToSql for NaiveDateTime { } impl FromSql for DateTime { - fn from_sql(bytes: Option>) -> deserialize::Result { + fn from_sql(bytes: PgValue<'_>) -> deserialize::Result { let naive_date_time = >::from_sql(bytes)?; Ok(DateTime::from_utc(naive_date_time, Utc)) } } impl FromSql for DateTime { - fn from_sql(bytes: Option>) -> deserialize::Result { + fn from_sql(bytes: PgValue<'_>) -> deserialize::Result { let naive_date_time = >::from_sql(bytes)?; Ok(Local::from_utc_datetime(&Local, &naive_date_time)) } @@ -92,7 +92,7 @@ impl ToSql for NaiveTime { } impl FromSql for NaiveTime { - fn from_sql(bytes: Option>) -> deserialize::Result { + fn from_sql(bytes: PgValue<'_>) -> deserialize::Result { let PgTime(offset) = FromSql::::from_sql(bytes)?; let duration = Duration::microseconds(offset); Ok(midnight() + duration) @@ -111,7 +111,7 @@ impl ToSql for NaiveDate { } impl FromSql for NaiveDate { - fn from_sql(bytes: Option>) -> deserialize::Result { + fn from_sql(bytes: PgValue<'_>) -> deserialize::Result { let PgDate(offset) = FromSql::::from_sql(bytes)?; match pg_epoch_date().checked_add_signed(Duration::days(i64::from(offset))) { Some(date) => Ok(date), diff --git a/diesel/src/pg/types/date_and_time/deprecated_time.rs b/diesel/src/pg/types/date_and_time/deprecated_time.rs index f6e372df3cfb..41838db34d3a 100644 --- a/diesel/src/pg/types/date_and_time/deprecated_time.rs +++ b/diesel/src/pg/types/date_and_time/deprecated_time.rs @@ -28,7 +28,7 @@ impl ToSql for Timespec { } impl FromSql for Timespec { - fn from_sql(bytes: Option>) -> deserialize::Result { + fn from_sql(bytes: PgValue<'_>) -> deserialize::Result { let t = >::from_sql(bytes)?; let pg_epoch = Timespec::new(TIME_SEC_CONV, 0); let duration = Duration::microseconds(t); diff --git a/diesel/src/pg/types/date_and_time/mod.rs b/diesel/src/pg/types/date_and_time/mod.rs index 360f913c0c79..b79d7fb14b59 100644 --- a/diesel/src/pg/types/date_and_time/mod.rs +++ b/diesel/src/pg/types/date_and_time/mod.rs @@ -90,7 +90,7 @@ impl ToSql for PgTimestamp { } impl FromSql for PgTimestamp { - fn from_sql(bytes: Option>) -> deserialize::Result { + fn from_sql(bytes: PgValue<'_>) -> deserialize::Result { FromSql::::from_sql(bytes).map(PgTimestamp) } } @@ -102,7 +102,7 @@ impl ToSql for PgTimestamp { } impl FromSql for PgTimestamp { - fn from_sql(bytes: Option>) -> deserialize::Result { + fn from_sql(bytes: PgValue<'_>) -> deserialize::Result { FromSql::::from_sql(bytes) } } @@ -114,7 +114,7 @@ impl ToSql for PgDate { } impl FromSql for PgDate { - fn from_sql(bytes: Option>) -> deserialize::Result { + fn from_sql(bytes: PgValue<'_>) -> deserialize::Result { FromSql::::from_sql(bytes).map(PgDate) } } @@ -126,7 +126,7 @@ impl ToSql for PgTime { } impl FromSql for PgTime { - fn from_sql(bytes: Option>) -> deserialize::Result { + fn from_sql(bytes: PgValue<'_>) -> deserialize::Result { FromSql::::from_sql(bytes).map(PgTime) } } @@ -141,12 +141,11 @@ impl ToSql for PgInterval { } impl FromSql for PgInterval { - fn from_sql(value: Option>) -> deserialize::Result { - let value = not_none!(value); + fn from_sql(value: PgValue<'_>) -> deserialize::Result { Ok(PgInterval { - microseconds: FromSql::::from_sql(Some(value.subslice(0..8)))?, - days: FromSql::::from_sql(Some(value.subslice(8..12)))?, - months: FromSql::::from_sql(Some(value.subslice(12..16)))?, + microseconds: FromSql::::from_sql(value.subslice(0..8))?, + days: FromSql::::from_sql(value.subslice(8..12))?, + months: FromSql::::from_sql(value.subslice(12..16))?, }) } } diff --git a/diesel/src/pg/types/date_and_time/std_time.rs b/diesel/src/pg/types/date_and_time/std_time.rs index 1bf879bd83d2..8789aa4773e3 100644 --- a/diesel/src/pg/types/date_and_time/std_time.rs +++ b/diesel/src/pg/types/date_and_time/std_time.rs @@ -27,7 +27,7 @@ impl ToSql for SystemTime { } impl FromSql for SystemTime { - fn from_sql(bytes: Option>) -> deserialize::Result { + fn from_sql(bytes: PgValue<'_>) -> deserialize::Result { let usecs_passed = >::from_sql(bytes)?; let before_epoch = usecs_passed < 0; let time_passed = usecs_to_duration(usecs_passed.abs() as u64); diff --git a/diesel/src/pg/types/floats/mod.rs b/diesel/src/pg/types/floats/mod.rs index 21915594c900..8386c19ecb02 100644 --- a/diesel/src/pg/types/floats/mod.rs +++ b/diesel/src/pg/types/floats/mod.rs @@ -50,8 +50,7 @@ impl ::std::fmt::Display for InvalidNumericSign { impl Error for InvalidNumericSign {} impl FromSql for PgNumeric { - fn from_sql(bytes: Option>) -> deserialize::Result { - let bytes = not_none!(bytes); + fn from_sql(bytes: PgValue<'_>) -> deserialize::Result { let mut bytes = bytes.as_bytes(); let digit_count = bytes.read_u16::()?; let mut digits = Vec::with_capacity(digit_count as usize); diff --git a/diesel/src/pg/types/integers.rs b/diesel/src/pg/types/integers.rs index 3e74e0473659..bc59ae8f51f4 100644 --- a/diesel/src/pg/types/integers.rs +++ b/diesel/src/pg/types/integers.rs @@ -7,8 +7,7 @@ use crate::serialize::{self, IsNull, Output, ToSql}; use crate::sql_types; impl FromSql for u32 { - fn from_sql(bytes: Option>) -> deserialize::Result { - let bytes = not_none!(bytes); + fn from_sql(bytes: PgValue<'_>) -> deserialize::Result { let mut bytes = bytes.as_bytes(); bytes.read_u32::().map_err(Into::into) } diff --git a/diesel/src/pg/types/json.rs b/diesel/src/pg/types/json.rs index 4b9c41648756..e483501cb295 100644 --- a/diesel/src/pg/types/json.rs +++ b/diesel/src/pg/types/json.rs @@ -10,8 +10,7 @@ use crate::serialize::{self, IsNull, Output, ToSql}; use crate::sql_types; impl FromSql for serde_json::Value { - fn from_sql(value: Option>) -> deserialize::Result { - let value = not_none!(value); + fn from_sql(value: PgValue<'_>) -> deserialize::Result { serde_json::from_slice(value.as_bytes()).map_err(|_| "Invalid Json".into()) } } @@ -25,8 +24,7 @@ impl ToSql for serde_json::Value { } impl FromSql for serde_json::Value { - fn from_sql(value: Option>) -> deserialize::Result { - let value = not_none!(value); + fn from_sql(value: PgValue<'_>) -> deserialize::Result { let bytes = value.as_bytes(); if bytes[0] != 1 { return Err("Unsupported JSONB encoding version".into()); @@ -56,20 +54,21 @@ fn json_to_sql() { fn some_json_from_sql() { let input_json = b"true"; let output_json: serde_json::Value = - FromSql::::from_sql(Some(PgValue::for_test(input_json))).unwrap(); + FromSql::::from_sql(PgValue::for_test(input_json)).unwrap(); assert_eq!(output_json, serde_json::Value::Bool(true)); } #[test] fn bad_json_from_sql() { let uuid: Result = - FromSql::::from_sql(Some(PgValue::for_test(b"boom"))); + FromSql::::from_sql(PgValue::for_test(b"boom")); assert_eq!(uuid.unwrap_err().to_string(), "Invalid Json"); } #[test] fn no_json_from_sql() { - let uuid: Result = FromSql::::from_sql(None); + let uuid: Result = + FromSql::::from_nullable_sql(None); assert_eq!( uuid.unwrap_err().to_string(), "Unexpected null for non-null column" @@ -88,21 +87,21 @@ fn jsonb_to_sql() { fn some_jsonb_from_sql() { let input_json = b"\x01true"; let output_json: serde_json::Value = - FromSql::::from_sql(Some(PgValue::for_test(input_json))).unwrap(); + FromSql::::from_sql(PgValue::for_test(input_json)).unwrap(); assert_eq!(output_json, serde_json::Value::Bool(true)); } #[test] fn bad_jsonb_from_sql() { let uuid: Result = - FromSql::::from_sql(Some(PgValue::for_test(b"\x01boom"))); + FromSql::::from_sql(PgValue::for_test(b"\x01boom")); assert_eq!(uuid.unwrap_err().to_string(), "Invalid Json"); } #[test] fn bad_jsonb_version_from_sql() { let uuid: Result = - FromSql::::from_sql(Some(PgValue::for_test(b"\x02true"))); + FromSql::::from_sql(PgValue::for_test(b"\x02true")); assert_eq!( uuid.unwrap_err().to_string(), "Unsupported JSONB encoding version" @@ -111,7 +110,8 @@ fn bad_jsonb_version_from_sql() { #[test] fn no_jsonb_from_sql() { - let uuid: Result = FromSql::::from_sql(None); + let uuid: Result = + FromSql::::from_nullable_sql(None); assert_eq!( uuid.unwrap_err().to_string(), "Unexpected null for non-null column" diff --git a/diesel/src/pg/types/mac_addr.rs b/diesel/src/pg/types/mac_addr.rs index a5a97751066c..886ca917fa0d 100644 --- a/diesel/src/pg/types/mac_addr.rs +++ b/diesel/src/pg/types/mac_addr.rs @@ -19,8 +19,7 @@ mod foreign_derives { } impl FromSql for [u8; 6] { - fn from_sql(value: Option>) -> deserialize::Result { - let value = not_none!(value); + fn from_sql(value: PgValue<'_>) -> deserialize::Result { value .as_bytes() .try_into() @@ -41,7 +40,6 @@ fn macaddr_roundtrip() { let mut bytes = Output::test(); let input_address = [0x52, 0x54, 0x00, 0xfb, 0xc6, 0x16]; ToSql::::to_sql(&input_address, &mut bytes).unwrap(); - let output_address: [u8; 6] = - FromSql::from_sql(Some(PgValue::for_test(bytes.as_ref()))).unwrap(); + let output_address: [u8; 6] = FromSql::from_sql(PgValue::for_test(bytes.as_ref())).unwrap(); assert_eq!(input_address, output_address); } diff --git a/diesel/src/pg/types/money.rs b/diesel/src/pg/types/money.rs index 8939ce1fda2d..d6dfc12b4b0e 100644 --- a/diesel/src/pg/types/money.rs +++ b/diesel/src/pg/types/money.rs @@ -25,7 +25,7 @@ use crate::sql_types::{BigInt, Money}; pub struct PgMoney(pub i64); impl FromSql for PgMoney { - fn from_sql(bytes: Option>) -> deserialize::Result { + fn from_sql(bytes: PgValue<'_>) -> deserialize::Result { FromSql::::from_sql(bytes).map(PgMoney) } } diff --git a/diesel/src/pg/types/network_address.rs b/diesel/src/pg/types/network_address.rs index 563f1852fc51..cfe0da706ecf 100644 --- a/diesel/src/pg/types/network_address.rs +++ b/diesel/src/pg/types/network_address.rs @@ -60,9 +60,8 @@ macro_rules! assert_or_error { macro_rules! impl_Sql { ($ty: ty, $net_type: expr) => { impl FromSql<$ty, Pg> for IpNetwork { - fn from_sql(value: Option>) -> deserialize::Result { + fn from_sql(value: PgValue<'_>) -> deserialize::Result { // https://github.com/postgres/postgres/blob/55c3391d1e6a201b5b891781d21fe682a8c64fe6/src/include/utils/inet.h#L23-L28 - let value = not_none!(value); let bytes = value.as_bytes(); assert_or_error!(4 <= bytes.len(), "input is too short."); let af = bytes[0]; @@ -161,7 +160,7 @@ fn some_v4address_from_sql() { let mut bytes = Output::test(); ToSql::<$ty, Pg>::to_sql(&input_address, &mut bytes).unwrap(); let output_address = - FromSql::<$ty, Pg>::from_sql(Some(PgValue::for_test(bytes.as_ref()))).unwrap(); + FromSql::<$ty, Pg>::from_sql(PgValue::for_test(bytes.as_ref())).unwrap(); assert_eq!(input_address, output_address); }; } @@ -219,7 +218,7 @@ fn some_v6address_from_sql() { let mut bytes = Output::test(); ToSql::<$ty, Pg>::to_sql(&input_address, &mut bytes).unwrap(); let output_address = - FromSql::<$ty, Pg>::from_sql(Some(PgValue::for_test(bytes.as_ref()))).unwrap(); + FromSql::<$ty, Pg>::from_sql(PgValue::for_test(bytes.as_ref())).unwrap(); assert_eq!(input_address, output_address); }; } @@ -233,7 +232,7 @@ fn bad_address_from_sql() { macro_rules! bad_address_from_sql { ($ty:tt) => { let address: Result = - FromSql::<$ty, Pg>::from_sql(Some(PgValue::for_test(&[7, PGSQL_AF_INET, 0]))); + FromSql::<$ty, Pg>::from_sql(PgValue::for_test(&[7, PGSQL_AF_INET, 0])); assert_eq!( address.unwrap_err().to_string(), "invalid network address format. input is too short." @@ -249,7 +248,7 @@ fn bad_address_from_sql() { fn no_address_from_sql() { macro_rules! test_no_address_from_sql { ($ty:ty) => { - let address: Result = FromSql::<$ty, Pg>::from_sql(None); + let address: Result = FromSql::<$ty, Pg>::from_nullable_sql(None); assert_eq!( address.unwrap_err().to_string(), "Unexpected null for non-null column" diff --git a/diesel/src/pg/types/numeric.rs b/diesel/src/pg/types/numeric.rs index c965a6cdfc85..72f0292069fb 100644 --- a/diesel/src/pg/types/numeric.rs +++ b/diesel/src/pg/types/numeric.rs @@ -153,7 +153,7 @@ mod bigdecimal { } impl FromSql for BigDecimal { - fn from_sql(numeric: Option>) -> deserialize::Result { + fn from_sql(numeric: PgValue<'_>) -> deserialize::Result { PgNumeric::from_sql(numeric)?.try_into() } } diff --git a/diesel/src/pg/types/primitives.rs b/diesel/src/pg/types/primitives.rs index 284ef0c24a06..58dc2dc93fd1 100644 --- a/diesel/src/pg/types/primitives.rs +++ b/diesel/src/pg/types/primitives.rs @@ -6,8 +6,7 @@ use crate::serialize::{self, IsNull, Output, ToSql}; use crate::sql_types; impl FromSql for bool { - fn from_sql(bytes: Option>) -> deserialize::Result { - let bytes = not_none!(bytes); + fn from_sql(bytes: PgValue<'_>) -> deserialize::Result { Ok(bytes.as_bytes()[0] != 0) } } @@ -29,7 +28,10 @@ fn bool_to_sql() { } #[test] -fn bool_from_sql_treats_null_as_false() { - let result = >::from_sql(None).unwrap(); - assert!(!result); +fn no_bool_from_sql() { + let result = >::from_nullable_sql(None); + assert_eq!( + result.unwrap_err().to_string(), + "Unexpected null for non-null column" + ); } diff --git a/diesel/src/pg/types/ranges.rs b/diesel/src/pg/types/ranges.rs index 81fc127548f1..56865431d59c 100644 --- a/diesel/src/pg/types/ranges.rs +++ b/diesel/src/pg/types/ranges.rs @@ -2,12 +2,14 @@ use byteorder::{NetworkEndian, ReadBytesExt, WriteBytesExt}; use std::collections::Bound; use std::io::Write; -use crate::deserialize::{self, FromSql, FromSqlRow, Queryable}; +use crate::deserialize::{self, FromSql, FromSqlRow}; use crate::expression::bound::Bound as SqlBound; use crate::expression::AsExpression; use crate::pg::{Pg, PgMetadataLookup, PgTypeMetadata, PgValue}; +use crate::row::Field; use crate::serialize::{self, IsNull, Output, ToSql}; use crate::sql_types::*; +use deserialize::StaticallySizedRow; // https://github.com/postgres/postgres/blob/113b0045e20d40f726a0a30e33214455e4f1385e/src/include/utils/rangetypes.h#L35-L43 bitflags! { @@ -23,16 +25,6 @@ bitflags! { } } -impl Queryable, Pg> for (Bound, Bound) -where - T: FromSql + Queryable, -{ - type Row = Self; - fn build(row: Self) -> Self { - row - } -} - impl AsExpression> for (Bound, Bound) { type Expression = SqlBound, Self>; @@ -69,17 +61,25 @@ impl FromSqlRow, Pg> for (Bound, Bound) where (Bound, Bound): FromSql, Pg>, { - fn build_from_row>(row: &mut R) -> deserialize::Result { - FromSql::, Pg>::from_sql(row.take()) + fn build_from_row<'a>(row: &impl crate::row::Row<'a, Pg>) -> deserialize::Result { + FromSql::from_nullable_sql( + row.get(0) + .ok_or_else(|| Box::new(crate::result::UnexpectedEndOfRow))? + .value(), + ) } } +impl StaticallySizedRow, Pg> for (Bound, Bound) where + (Bound, Bound): FromSql, Pg> +{ +} + impl FromSql, Pg> for (Bound, Bound) where T: FromSql, { - fn from_sql(bytes: Option>) -> deserialize::Result { - let value = not_none!(bytes); + fn from_sql(value: PgValue<'_>) -> deserialize::Result { let mut bytes = value.as_bytes(); let flags: RangeFlags = RangeFlags::from_bits_truncate(bytes.read_u8()?); let mut lower_bound = Bound::Unbounded; @@ -89,7 +89,7 @@ where let elem_size = bytes.read_i32::()?; let (elem_bytes, new_bytes) = bytes.split_at(elem_size as usize); bytes = new_bytes; - let value = T::from_sql(Some(PgValue::new(elem_bytes, value.get_oid())))?; + let value = T::from_sql(PgValue::new(elem_bytes, value.get_oid()))?; lower_bound = if flags.contains(RangeFlags::LB_INC) { Bound::Included(value) @@ -100,7 +100,7 @@ where if !flags.contains(RangeFlags::UB_INF) { let _size = bytes.read_i32::()?; - let value = T::from_sql(Some(PgValue::new(bytes, value.get_oid())))?; + let value = T::from_sql(PgValue::new(bytes, value.get_oid()))?; upper_bound = if flags.contains(RangeFlags::UB_INC) { Bound::Included(value) diff --git a/diesel/src/pg/types/record.rs b/diesel/src/pg/types/record.rs index 5f0624a9e0f6..8fa4fa9210ac 100644 --- a/diesel/src/pg/types/record.rs +++ b/diesel/src/pg/types/record.rs @@ -2,16 +2,17 @@ use byteorder::*; use std::io::Write; use std::num::NonZeroU32; -use crate::deserialize::{self, FromSql, FromSqlRow, Queryable}; +use crate::deserialize::{self, FromSql, FromSqlRow}; use crate::expression::{ - AppearsOnTable, AsExpression, Expression, SelectableExpression, ValidGrouping, + AppearsOnTable, AsExpression, Expression, SelectableExpression, TypedExpressionType, + ValidGrouping, }; use crate::pg::{Pg, PgValue}; use crate::query_builder::{AstPass, QueryFragment, QueryId}; use crate::result::QueryResult; -use crate::row::Row; +use crate::row::{Field, Row}; use crate::serialize::{self, IsNull, Output, ToSql, WriteTuple}; -use crate::sql_types::{HasSqlType, Record}; +use crate::sql_types::{HasSqlType, Record, SqlType}; macro_rules! tuple_impls { ($( @@ -27,8 +28,7 @@ macro_rules! tuple_impls { // but the only other option would be to use `mem::uninitialized` // and `ptr::write`. #[allow(clippy::eval_order_dependence)] - fn from_sql(value: Option>) -> deserialize::Result { - let value = not_none!(value); + fn from_sql(value: PgValue<'_>) -> deserialize::Result { let mut bytes = value.as_bytes(); let num_elements = bytes.read_i32::()?; @@ -49,14 +49,14 @@ macro_rules! tuple_impls { let num_bytes = bytes.read_i32::()?; if num_bytes == -1 { - $T::from_sql(None)? + $T::from_nullable_sql(None)? } else { let (elem_bytes, new_bytes) = bytes.split_at(num_bytes as usize); bytes = new_bytes; - $T::from_sql(Some(PgValue::new( + $T::from_sql(PgValue::new( elem_bytes, oid, - )))? + ))? } },)+); @@ -73,24 +73,21 @@ macro_rules! tuple_impls { where Self: FromSql, Pg>, { - fn build_from_row>(row: &mut RowT) -> deserialize::Result { - Self::from_sql(row.take()) - } - } - impl<$($T,)+ $($ST,)+> Queryable, Pg> for ($($T,)+) - where - Self: FromSqlRow, Pg>, - { - type Row = Self; - fn build(row: Self::Row) -> Self { - row + fn build_from_row<'a>(row: &impl Row<'a, Pg>) -> deserialize::Result + { + FromSql::from_nullable_sql( + row.get(0) + .ok_or_else(|| Box::new(crate::result::UnexpectedEndOfRow))? + .value(), + ) } } impl<$($T,)+ $($ST,)+> AsExpression> for ($($T,)+) where + $($ST: SqlType + TypedExpressionType,)+ $($T: AsExpression<$ST>,)+ PgTuple<($($T::Expression,)+)>: Expression>, { diff --git a/diesel/src/pg/types/uuid.rs b/diesel/src/pg/types/uuid.rs index 5e5e86ec6956..bf933e263fdb 100644 --- a/diesel/src/pg/types/uuid.rs +++ b/diesel/src/pg/types/uuid.rs @@ -14,8 +14,7 @@ use crate::sql_types::Uuid; struct UuidProxy(uuid::Uuid); impl FromSql for uuid::Uuid { - fn from_sql(bytes: Option>) -> deserialize::Result { - let value = not_none!(bytes); + fn from_sql(value: PgValue<'_>) -> deserialize::Result { uuid::Uuid::from_slice(value.as_bytes()).map_err(Into::into) } } @@ -40,13 +39,13 @@ fn uuid_to_sql() { fn some_uuid_from_sql() { let input_uuid = uuid::Uuid::from_fields(0xFFFF_FFFF, 0xFFFF, 0xFFFF, b"abcdef12").unwrap(); let output_uuid = - FromSql::::from_sql(Some(PgValue::for_test(input_uuid.as_bytes()))).unwrap(); + FromSql::::from_sql(PgValue::for_test(input_uuid.as_bytes())).unwrap(); assert_eq!(input_uuid, output_uuid); } #[test] fn bad_uuid_from_sql() { - let uuid = uuid::Uuid::from_sql(Some(PgValue::for_test(b"boom"))); + let uuid = uuid::Uuid::from_sql(PgValue::for_test(b"boom")); assert_eq!( uuid.unwrap_err().to_string(), "invalid bytes length: expected 16, found 4" @@ -55,7 +54,7 @@ fn bad_uuid_from_sql() { #[test] fn no_uuid_from_sql() { - let uuid = uuid::Uuid::from_sql(None); + let uuid = uuid::Uuid::from_nullable_sql(None); assert_eq!( uuid.unwrap_err().to_string(), "Unexpected null for non-null column" diff --git a/diesel/src/query_builder/insert_statement/insert_from_select.rs b/diesel/src/query_builder/insert_statement/insert_from_select.rs index f5e1132068f1..faafffc4dce0 100644 --- a/diesel/src/query_builder/insert_statement/insert_from_select.rs +++ b/diesel/src/query_builder/insert_statement/insert_from_select.rs @@ -46,8 +46,8 @@ where impl QueryFragment for InsertFromSelect where DB: Backend, - Columns: ColumnList + Expression, - Select: Query + QueryFragment, + Columns: ColumnList + Expression, + Select: Query + QueryFragment, { fn walk_ast(&self, mut out: AstPass) -> QueryResult<()> { out.push_sql("("); @@ -60,7 +60,7 @@ where impl UndecoratedInsertRecord for InsertFromSelect where - Columns: ColumnList + Expression, - Select: Query, + Columns: ColumnList + Expression, + Select: Query, { } diff --git a/diesel/src/query_builder/insert_statement/mod.rs b/diesel/src/query_builder/insert_statement/mod.rs index da798643105b..b252c72714ca 100644 --- a/diesel/src/query_builder/insert_statement/mod.rs +++ b/diesel/src/query_builder/insert_statement/mod.rs @@ -162,8 +162,8 @@ impl InsertStatement, Op, Ret> { columns: C2, ) -> InsertStatement, Op, Ret> where - C2: ColumnList + Expression, - U: Query, + C2: ColumnList
+ Expression, + U: Query, { InsertStatement::new( self.target, diff --git a/diesel/src/query_builder/select_statement/boxed.rs b/diesel/src/query_builder/select_statement/boxed.rs index 5216e617d848..d853b80e5cd6 100644 --- a/diesel/src/query_builder/select_statement/boxed.rs +++ b/diesel/src/query_builder/select_statement/boxed.rs @@ -19,7 +19,7 @@ use crate::query_dsl::*; use crate::query_source::joins::*; use crate::query_source::{QuerySource, Table}; use crate::result::QueryResult; -use crate::sql_types::{BigInt, Bool, NotNull, Nullable}; +use crate::sql_types::{BigInt, BoolOrNullableBool, IntoNullable}; #[allow(missing_debug_implementations)] pub struct BoxedSelectStatement<'a, ST, QS, DB> { @@ -194,7 +194,8 @@ where impl<'a, ST, QS, DB, Predicate> FilterDsl for BoxedSelectStatement<'a, ST, QS, DB> where BoxedWhereClause<'a, DB>: WhereAnd>, - Predicate: AppearsOnTable + NonAggregate, + Predicate: AppearsOnTable + NonAggregate, + Predicate::SqlType: BoolOrNullableBool, { type Output = Self; @@ -207,7 +208,8 @@ where impl<'a, ST, QS, DB, Predicate> OrFilterDsl for BoxedSelectStatement<'a, ST, QS, DB> where BoxedWhereClause<'a, DB>: WhereOr>, - Predicate: AppearsOnTable + NonAggregate, + Predicate: AppearsOnTable + NonAggregate, + Predicate::SqlType: BoolOrNullableBool, { type Output = Self; @@ -331,9 +333,9 @@ where impl<'a, ST, QS, DB> SelectNullableDsl for BoxedSelectStatement<'a, ST, QS, DB> where - ST: NotNull, + ST: IntoNullable, { - type Output = BoxedSelectStatement<'a, Nullable, QS, DB>; + type Output = BoxedSelectStatement<'a, ST::Nullable, QS, DB>; fn nullable(self) -> Self::Output { BoxedSelectStatement { diff --git a/diesel/src/query_builder/select_statement/dsl_impls.rs b/diesel/src/query_builder/select_statement/dsl_impls.rs index 915f927cd973..4fabed4537e8 100644 --- a/diesel/src/query_builder/select_statement/dsl_impls.rs +++ b/diesel/src/query_builder/select_statement/dsl_impls.rs @@ -24,7 +24,7 @@ use crate::query_dsl::methods::*; use crate::query_dsl::*; use crate::query_source::joins::{Join, JoinOn, JoinTo}; use crate::query_source::QuerySource; -use crate::sql_types::{BigInt, Bool}; +use crate::sql_types::{BigInt, BoolOrNullableBool}; impl InternalJoinDsl for SelectStatement @@ -94,7 +94,8 @@ where impl FilterDsl for SelectStatement where - Predicate: Expression + NonAggregate, + Predicate: Expression + NonAggregate, + Predicate::SqlType: BoolOrNullableBool, W: WhereAnd, { type Output = SelectStatement; @@ -116,7 +117,8 @@ where impl OrFilterDsl for SelectStatement where - Predicate: Expression + NonAggregate, + Predicate: Expression + NonAggregate, + Predicate::SqlType: BoolOrNullableBool, W: WhereOr, { type Output = SelectStatement; diff --git a/diesel/src/query_builder/sql_query.rs b/diesel/src/query_builder/sql_query.rs index f177d4a9eae4..f838cc3c5c8b 100644 --- a/diesel/src/query_builder/sql_query.rs +++ b/diesel/src/query_builder/sql_query.rs @@ -1,13 +1,13 @@ use std::marker::PhantomData; +use super::Query; use crate::backend::Backend; use crate::connection::Connection; -use crate::deserialize::QueryableByName; use crate::query_builder::{AstPass, QueryFragment, QueryId}; -use crate::query_dsl::{LoadQuery, RunQueryDsl}; +use crate::query_dsl::RunQueryDsl; use crate::result::QueryResult; use crate::serialize::ToSql; -use crate::sql_types::HasSqlType; +use crate::sql_types::{HasSqlType, Untyped}; #[derive(Debug, Clone)] #[must_use = "Queries are only executed when calling `load`, `get_result` or similar."] @@ -116,15 +116,8 @@ impl QueryId for SqlQuery { const HAS_STATIC_QUERY_ID: bool = false; } -impl LoadQuery for SqlQuery -where - Conn: Connection, - T: QueryableByName, - Self: QueryFragment, -{ - fn internal_load(self, conn: &Conn) -> QueryResult> { - conn.query_by_name(&self) - } +impl Query for SqlQuery { + type SqlType = Untyped; } impl RunQueryDsl for SqlQuery {} @@ -182,15 +175,8 @@ where } } -impl LoadQuery for UncheckedBind -where - Conn: Connection, - T: QueryableByName, - Self: QueryFragment + QueryId, -{ - fn internal_load(self, conn: &Conn) -> QueryResult> { - conn.query_by_name(&self) - } +impl Query for UncheckedBind { + type SqlType = Untyped; } impl RunQueryDsl for UncheckedBind {} @@ -260,15 +246,11 @@ impl QueryId for BoxedSqlQuery<'_, DB, Query> { const HAS_STATIC_QUERY_ID: bool = false; } -impl LoadQuery for BoxedSqlQuery<'_, Conn::Backend, Query> +impl Query for BoxedSqlQuery<'_, DB, Q> where - Conn: Connection, - T: QueryableByName, - Self: QueryFragment + QueryId, + DB: Backend, { - fn internal_load(self, conn: &Conn) -> QueryResult> { - conn.query_by_name(&self) - } + type SqlType = Untyped; } impl RunQueryDsl for BoxedSqlQuery<'_, Conn::Backend, Query> {} diff --git a/diesel/src/query_builder/where_clause.rs b/diesel/src/query_builder/where_clause.rs index 8c47086e44f9..a16e1704bb66 100644 --- a/diesel/src/query_builder/where_clause.rs +++ b/diesel/src/query_builder/where_clause.rs @@ -1,11 +1,9 @@ use super::*; use crate::backend::Backend; -use crate::dsl::Or; -use crate::expression::operators::And; +use crate::expression::operators::{And, Or}; use crate::expression::*; -use crate::expression_methods::*; use crate::result::QueryResult; -use crate::sql_types::Bool; +use crate::sql_types::BoolOrNullableBool; /// Add `Predicate` to the current `WHERE` clause, joining with `AND` if /// applicable. @@ -39,7 +37,8 @@ impl QueryFragment for NoWhereClause { impl WhereAnd for NoWhereClause where - Predicate: Expression, + Predicate: Expression, + Predicate::SqlType: BoolOrNullableBool, { type Output = WhereClause; @@ -50,7 +49,8 @@ where impl WhereOr for NoWhereClause where - Predicate: Expression, + Predicate: Expression, + Predicate::SqlType: BoolOrNullableBool, { type Output = WhereClause; @@ -83,25 +83,29 @@ where impl WhereAnd for WhereClause where - Expr: Expression, - Predicate: Expression, + Expr: Expression, + Expr::SqlType: BoolOrNullableBool, + Predicate: Expression, + Predicate::SqlType: BoolOrNullableBool, { type Output = WhereClause>; fn and(self, predicate: Predicate) -> Self::Output { - WhereClause(self.0.and(predicate)) + WhereClause(And::new(self.0, predicate)) } } impl WhereOr for WhereClause where - Expr: Expression, - Predicate: Expression, + Expr: Expression, + Expr::SqlType: BoolOrNullableBool, + Predicate: Expression, + Predicate::SqlType: BoolOrNullableBool, { type Output = WhereClause>; fn or(self, predicate: Predicate) -> Self::Output { - WhereClause(self.0.or(predicate)) + WhereClause(Or::new(self.0, predicate)) } } @@ -177,7 +181,6 @@ where fn or(self, predicate: Predicate) -> Self::Output { use self::BoxedWhereClause::Where; use crate::expression::grouped::Grouped; - use crate::expression::operators::Or; match self { Where(where_clause) => Where(Box::new(Grouped(Or::new(where_clause, predicate)))), diff --git a/diesel/src/query_dsl/load_dsl.rs b/diesel/src/query_dsl/load_dsl.rs index 4d74bae62da3..f832191e7379 100644 --- a/diesel/src/query_dsl/load_dsl.rs +++ b/diesel/src/query_dsl/load_dsl.rs @@ -1,10 +1,10 @@ use super::RunQueryDsl; use crate::backend::Backend; use crate::connection::Connection; -use crate::deserialize::Queryable; +use crate::deserialize::{FromSqlRow, IsCompatibleType}; +use crate::expression::TypedExpressionType; use crate::query_builder::{AsQuery, QueryFragment, QueryId}; use crate::result::QueryResult; -use crate::sql_types::HasSqlType; /// The `load` method /// @@ -18,16 +18,17 @@ pub trait LoadQuery: RunQueryDsl { fn internal_load(self, conn: &Conn) -> QueryResult>; } -impl LoadQuery for T +impl LoadQuery for T where Conn: Connection, - Conn::Backend: HasSqlType, T: AsQuery + RunQueryDsl, T::Query: QueryFragment + QueryId, - U: Queryable, + U: FromSqlRow, + T::SqlType: IsCompatibleType, + ST: TypedExpressionType, { fn internal_load(self, conn: &Conn) -> QueryResult> { - conn.query_by_index(self) + conn.load(self) } } diff --git a/diesel/src/query_dsl/single_value_dsl.rs b/diesel/src/query_dsl/single_value_dsl.rs index be4c9b272a63..c220b6a93271 100644 --- a/diesel/src/query_dsl/single_value_dsl.rs +++ b/diesel/src/query_dsl/single_value_dsl.rs @@ -3,7 +3,7 @@ use crate::dsl::Limit; use crate::expression::grouped::Grouped; use crate::expression::subselect::Subselect; use crate::query_builder::SelectQuery; -use crate::sql_types::{IntoNullable, SingleValue}; +use crate::sql_types::IntoNullable; /// The `single_value` method /// @@ -20,13 +20,13 @@ pub trait SingleValueDsl { fn single_value(self) -> Self::Output; } -impl SingleValueDsl for T +impl SingleValueDsl for T where - Self: SelectQuery + LimitDsl, - ST: IntoNullable, - ST::Nullable: SingleValue, + Self: SelectQuery + LimitDsl, + ::SqlType: IntoNullable, { - type Output = Grouped, ST::Nullable>>; + type Output = + Grouped, <::SqlType as IntoNullable>::Nullable>>; fn single_value(self) -> Self::Output { Grouped(Subselect::new(self.limit(1))) diff --git a/diesel/src/query_source/joins.rs b/diesel/src/query_source/joins.rs index 90dec6e21a26..e22dac52c1e6 100644 --- a/diesel/src/query_source/joins.rs +++ b/diesel/src/query_source/joins.rs @@ -6,7 +6,7 @@ use crate::expression::SelectableExpression; use crate::prelude::*; use crate::query_builder::*; use crate::result::QueryResult; -use crate::sql_types::Bool; +use crate::sql_types::BoolOrNullableBool; use crate::util::TupleAppend; #[derive(Debug, Clone, Copy, QueryId)] @@ -84,7 +84,8 @@ where impl QuerySource for JoinOn where Join: QuerySource, - On: AppearsOnTable + Clone, + On: AppearsOnTable + Clone, + On::SqlType: BoolOrNullableBool, Join::DefaultSelection: SelectableExpression, { type FromClause = Grouped>; diff --git a/diesel/src/r2d2.rs b/diesel/src/r2d2.rs index bf3ba738550a..44528c42c1f3 100644 --- a/diesel/src/r2d2.rs +++ b/diesel/src/r2d2.rs @@ -17,10 +17,10 @@ use std::fmt; use std::marker::PhantomData; use crate::connection::{SimpleConnection, TransactionManager}; -use crate::deserialize::{Queryable, QueryableByName}; +use crate::deserialize::{FromSqlRow, IsCompatibleType}; +use crate::expression::TypedExpressionType; use crate::prelude::*; use crate::query_builder::{AsQuery, QueryFragment, QueryId}; -use crate::sql_types::HasSqlType; /// An r2d2 connection manager for use with Diesel. /// @@ -142,22 +142,15 @@ where (&**self).execute(query) } - fn query_by_index(&self, source: T) -> QueryResult> + fn load(&self, source: T) -> QueryResult> where T: AsQuery, T::Query: QueryFragment + QueryId, - Self::Backend: HasSqlType, - U: Queryable, + U: FromSqlRow, + T::SqlType: IsCompatibleType, + ST: TypedExpressionType, { - (&**self).query_by_index(source) - } - - fn query_by_name(&self, source: &T) -> QueryResult> - where - T: QueryFragment + QueryId, - U: QueryableByName, - { - (&**self).query_by_name(source) + (&**self).load(source) } fn execute_returning_count(&self, source: &T) -> QueryResult diff --git a/diesel/src/result.rs b/diesel/src/result.rs index fb06e4a6a984..38b458f3c54d 100644 --- a/diesel/src/result.rs +++ b/diesel/src/result.rs @@ -355,3 +355,15 @@ impl fmt::Display for UnexpectedNullError { } impl StdError for UnexpectedNullError {} + +/// Expected more fields then present in the current row while deserialising results +#[derive(Debug, Clone, Copy)] +pub struct UnexpectedEndOfRow; + +impl fmt::Display for UnexpectedEndOfRow { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "Unexpected end of row") + } +} + +impl StdError for UnexpectedEndOfRow {} diff --git a/diesel/src/row.rs b/diesel/src/row.rs index 4ad69eb52013..3747dff336a3 100644 --- a/diesel/src/row.rs +++ b/diesel/src/row.rs @@ -1,69 +1,160 @@ //! Contains the `Row` trait use crate::backend::{self, Backend}; -use crate::deserialize::{self, FromSql}; +use std::ops::Range; + +/// Representing a way to index into database rows +/// +/// * Crates using existing backends should use existing implementations of +/// this traits. Diesel provides `RowIndex` and `RowIndex<&str>` for +/// all bulit-in backends +/// +/// * Crates implementing custom backends need to provide `RowIndex` and +/// `RowIndex<&str>` impls for their [`Row`] type. +/// +/// [`Row`]: trait.Row.html +pub trait RowIndex { + /// Get the numeric index inside the current row for the provided index value + fn idx(&self, idx: I) -> Option; +} /// Represents a single database row. -/// Apps should not need to concern themselves with this trait. /// -/// This trait is only used as an argument to [`FromSqlRow`]. +/// This trait is used as an argument to [`FromSqlRow`]. /// /// [`FromSqlRow`]: ../deserialize/trait.FromSqlRow.html -pub trait Row { - /// Returns the value of the next column in the row. - fn take(&mut self) -> Option>; +pub trait Row<'a, DB: Backend>: RowIndex + RowIndex<&'a str> + Sized { + /// Field type returned by a `Row` implementation + /// + /// * Crates using existing backend should not concern themself with the + /// concrete type of this associated type. + /// + /// * Crates implementing custom backends should provide their own type + /// meeting the required trait bounds + type Field: Field<'a, DB>; - /// Returns whether the next `count` columns are all `NULL`. + /// Return type of `PartialRow` /// - /// If this method returns `true`, then the next `count` calls to `take` - /// would all return `None`. - fn next_is_null(&self, count: usize) -> bool; + /// For all implementations, beside of the `Row` implementation on `PartialRow` itself + /// this should be `Self`. + #[doc(hidden)] + type InnerPartialRow: Row<'a, DB>; - /// Number of columns in the current result set - fn column_count(&self) -> usize; + /// Get the number of fields in the current row + fn field_count(&self) -> usize; - /// Name of the current column + /// Get the field with the provided index from the row. /// - /// May return `None` in cases where the field is not - /// named on sql side - fn column_name(&self) -> Option<&str>; + /// Returns `None` if there is no matching field for the given index + fn get(&self, idx: I) -> Option + where + Self: RowIndex; + + /// Returns a wrapping row that allows only to access fields, where the index is part of + /// the provided range. + #[doc(hidden)] + fn partial_row(&self, range: Range) -> PartialRow; } -/// Represents a row of a SQL query, where the values are accessed by name -/// rather than by index. +/// Represents a single field in a database row. /// -/// This trait is used by implementations of -/// [`QueryableByName`](../deserialize/trait.QueryableByName.html) -pub trait NamedRow { - /// Retrieve and deserialize a single value from the query - /// - /// Note that `ST` *must* be the exact type of the value with that name in - /// the query. The compiler will not be able to verify that you have - /// provided the correct type. If there is a mismatch, you may receive an - /// incorrect value, or a runtime error. +/// This trait allows retrieving information on the name of the colum and on the value of the +/// field. +pub trait Field<'a, DB: Backend> { + /// The name of the current field /// - /// If two or more fields in the query have the given name, the result of - /// this function is undefined. - fn get(&self, column_name: &str) -> deserialize::Result + /// Returns `None` if it's an unnamed field + fn field_name(&self) -> Option<&str>; + + /// Get the value representing the current field in the raw representation + /// as it is transmitted by the database + fn value(&self) -> Option>; + + /// Checks whether this field is null or not. + fn is_null(&self) -> bool { + self.value().is_none() + } +} + +/// A row type that wraps an inner row +/// +/// This type only allows to access fields of the inner row, whose index is +/// part of `range`. +/// +/// Indexing via `usize` starts with 0 for this row type. The index is then shifted +/// by `self.range.start` to match the corresponding field in the underlying row. +#[derive(Debug)] +#[doc(hidden)] +pub struct PartialRow<'a, R> { + inner: &'a R, + range: Range, +} + +impl<'a, R> PartialRow<'a, R> { + #[doc(hidden)] + pub fn new(inner: &'a R, range: Range) -> Self { + Self { inner, range } + } +} + +impl<'a, 'b, DB, R> Row<'a, DB> for PartialRow<'b, R> +where + DB: Backend, + R: Row<'a, DB>, +{ + type Field = R::Field; + type InnerPartialRow = R; + + fn field_count(&self) -> usize { + let inner_length = self.inner.field_count(); + if self.range.start < inner_length { + std::cmp::min(inner_length - self.range.start, self.range.len()) + } else { + 0 + } + } + + fn get(&self, idx: I) -> Option where - T: FromSql, + Self: RowIndex, { - let idx = self - .index_of(column_name) - .ok_or_else(|| format!("Column `{}` was not present in query", column_name).into()); - let idx = match idx { - Ok(x) => x, - Err(e) => return Err(e), - }; - let raw_value = self.get_raw_value(idx); - T::from_sql(raw_value) + let idx = self.idx(idx)?; + Some(self.inner.get(idx).unwrap()) } - #[doc(hidden)] - fn index_of(&self, column_name: &str) -> Option; - #[doc(hidden)] - fn get_raw_value(&self, index: usize) -> Option>; + fn partial_row(&self, range: Range) -> PartialRow { + let range = (self.range.start + range.start)..(self.range.start + range.end); + PartialRow { + inner: self.inner, + range, + } + } +} - /// Get a list of all field names in the current row - fn field_names(&self) -> Vec<&str>; +impl<'a, 'b, R> RowIndex<&'a str> for PartialRow<'b, R> +where + R: RowIndex<&'a str>, +{ + fn idx(&self, idx: &'a str) -> Option { + let idx = self.inner.idx(idx)?; + if self.range.contains(&idx) { + Some(idx) + } else { + None + } + } +} + +impl<'a, R> RowIndex for PartialRow<'a, R> +where + R: RowIndex, +{ + fn idx(&self, idx: usize) -> Option { + let idx = self.inner.idx(idx)? + self.range.start; + if self.range.contains(&idx) { + Some(idx) + } else { + None + } + } } diff --git a/diesel/src/serialize.rs b/diesel/src/serialize.rs index fe70c17b08d3..d87ca18155a1 100644 --- a/diesel/src/serialize.rs +++ b/diesel/src/serialize.rs @@ -145,7 +145,7 @@ where /// database, you should use `i32::to_sql(x, out)` instead of writing to `out` /// yourself. /// -/// Any types which implement this trait should also `#[derive(AsExpression)]`. +/// Any types which implement this trait should also [`#[derive(AsExpression)]`]. /// /// ### Backend specific details /// @@ -157,6 +157,7 @@ where /// - For third party backends, consult that backend's documentation. /// /// [`MysqlType`]: ../mysql/enum.MysqlType.html +/// [`#[derive(AsExpression)]`]: ../expression/derive.AsExpression.html; /// /// ### Examples /// @@ -165,12 +166,14 @@ where /// /// ```rust /// # use diesel::backend::Backend; +/// # use diesel::expression::AsExpression; /// # use diesel::sql_types::*; /// # use diesel::serialize::{self, ToSql, Output}; /// # use std::io::Write; /// # /// #[repr(i32)] -/// #[derive(Debug, Clone, Copy)] +/// #[derive(Debug, Clone, Copy, AsExpression)] +/// #[sql_type = "Integer"] /// pub enum MyEnum { /// A = 1, /// B = 2, diff --git a/diesel/src/sql_types/fold.rs b/diesel/src/sql_types/fold.rs index 023aedf142c3..91e0de2e8132 100644 --- a/diesel/src/sql_types/fold.rs +++ b/diesel/src/sql_types/fold.rs @@ -1,16 +1,16 @@ -use crate::sql_types::{self, NotNull}; +use crate::sql_types::{self, is_nullable, SingleValue, SqlType}; /// Represents SQL types which can be used with `SUM` and `AVG` -pub trait Foldable { +pub trait Foldable: SingleValue { /// The SQL type of `sum(this_type)` - type Sum; + type Sum: SqlType + SingleValue; /// The SQL type of `avg(this_type)` - type Avg; + type Avg: SqlType + SingleValue; } impl Foldable for sql_types::Nullable where - T: Foldable + NotNull, + T: Foldable + SqlType, { type Sum = T::Sum; type Avg = T::Avg; diff --git a/diesel/src/sql_types/mod.rs b/diesel/src/sql_types/mod.rs index bb5994e55332..20ca5ded3317 100644 --- a/diesel/src/sql_types/mod.rs +++ b/diesel/src/sql_types/mod.rs @@ -20,6 +20,7 @@ mod ord; pub use self::fold::Foldable; pub use self::ord::SqlOrd; +use crate::expression::TypedExpressionType; use crate::query_builder::QueryId; /// The boolean SQL type. @@ -377,7 +378,14 @@ pub struct Json; /// /// - `Option` for any `T` which implements `FromSql` #[derive(Debug, Clone, Copy, Default)] -pub struct Nullable(ST); +pub struct Nullable(ST); + +impl SqlType for Nullable +where + ST: SqlType, +{ + type IsNull = is_nullable::IsNullable; +} #[cfg(feature = "postgres")] pub use crate::pg::types::sql_types::*; @@ -427,15 +435,6 @@ pub trait TypeMetadata { type MetadataLookup; } -/// A marker trait indicating that a SQL type is not null. -/// -/// All SQL types must implement this trait. -/// -/// # Deriving -/// -/// This trait is automatically implemented by `#[derive(SqlType)]` -pub trait NotNull {} - /// Converts a type which may or may not be nullable into its nullable /// representation. pub trait IntoNullable { @@ -445,12 +444,41 @@ pub trait IntoNullable { type Nullable; } -impl IntoNullable for T { +impl IntoNullable for T +where + T: SqlType + SingleValue, +{ type Nullable = Nullable; } -impl IntoNullable for Nullable { - type Nullable = Nullable; +impl IntoNullable for Nullable +where + T: SqlType, +{ + type Nullable = Self; +} + +/// Converts a type which may or may not be nullable into its not nullable +/// representation. +pub trait IntoNotNullable { + /// The not nullable representation of this type. + /// + /// For `Nullable`, this will be `T` otherwise the type itself + type NotNullable; +} + +impl IntoNotNullable for T +where + T: SqlType, +{ + type NotNullable = T; +} + +impl IntoNotNullable for Nullable +where + T: SqlType, +{ + type NotNullable = T; } /// A marker trait indicating that a SQL type represents a single value, as @@ -462,12 +490,149 @@ impl IntoNullable for Nullable { /// /// # Deriving /// -/// This trait is automatically implemented by `#[derive(SqlType)]` -pub trait SingleValue {} +/// This trait is automatically implemented by [`#[derive(SqlType)]`] +/// +/// [`#[derive(SqlType)]`]: derive.SqlType.html +pub trait SingleValue: SqlType {} -impl SingleValue for Nullable {} +impl SingleValue for Nullable {} #[doc(inline)] pub use diesel_derives::DieselNumericOps; #[doc(inline)] pub use diesel_derives::SqlType; + +/// A marker trait for SQL types +/// +/// # Deriving +/// +/// This trait is automatically implemented by [`#[derive(SqlType)]`] +/// by setting `IsNull` to [`is_nullable::NotNull`] +/// +/// [`#[derive(SqlType)]`]: derive.SqlType.html +/// [`is_nullable::NotNull`]: is_nullable/struct.NotNull.html +pub trait SqlType { + /// Is this type nullable? + /// + /// This type should always be one of the structs in the ['is_nullable`] + /// module. See the documentation of those structs for more details. + /// + /// ['is_nullable`]: is_nullable/index.html + type IsNull: OneIsNullable + OneIsNullable; +} + +/// Is one value of `IsNull` nullable? +/// +/// You should never implement this trait. +pub trait OneIsNullable { + /// See the trait documentation + type Out: OneIsNullable + OneIsNullable; +} + +/// Are both values of `IsNull` are nullable? +pub trait AllAreNullable { + /// See the trait documentation + type Out: AllAreNullable + AllAreNullable; +} + +/// A type level constructor for maybe nullable types +/// +/// Constructs either `Nullable` (for `Self` == `is_nullable::IsNullable`) +/// or `O` (for `Self` == `is_nullable::NotNull`) +pub trait MaybeNullableType { + /// See the trait documentation + type Out: SqlType + TypedExpressionType; +} + +/// Possible values for `SqlType::IsNullable` +pub mod is_nullable { + use super::*; + + /// No, this type cannot be null as it is marked as `NOT NULL` at database level + /// + /// This should be choosen for basically all manual impls of `SqlType` + /// beside implementing your own `Nullable<>` wrapper type + #[derive(Debug, Clone, Copy)] + pub struct NotNull; + + /// Yes, this type can be null + /// + /// The only diesel provided `SqlType` that uses this value is [`Nullable`] + /// + /// [`Nullable`]: ../struct.Nullable.html + #[derive(Debug, Clone, Copy)] + pub struct IsNullable; + + impl OneIsNullable for NotNull { + type Out = NotNull; + } + + impl OneIsNullable for NotNull { + type Out = IsNullable; + } + + impl OneIsNullable for IsNullable { + type Out = IsNullable; + } + + impl OneIsNullable for IsNullable { + type Out = IsNullable; + } + + impl AllAreNullable for NotNull { + type Out = NotNull; + } + + impl AllAreNullable for NotNull { + type Out = NotNull; + } + + impl AllAreNullable for IsNullable { + type Out = NotNull; + } + + impl AllAreNullable for IsNullable { + type Out = IsNullable; + } + + impl MaybeNullableType for NotNull + where + O: SqlType + TypedExpressionType, + { + type Out = O; + } + + impl MaybeNullableType for IsNullable + where + O: SqlType, + Nullable: TypedExpressionType, + { + type Out = Nullable; + } + + /// Represents the output type of [`MaybeNullableType`](../trait.MaybeNullableType.html) + pub type MaybeNullable = >::Out; + + /// Represents the output type of [`OneIsNullable`](../trait.OneIsNullable.html) + /// for two given SQL types + pub type IsOneNullable = + as OneIsNullable>>::Out; + + /// Represents the output type of [`AllAreNullable`](../trait.AllAreNullable.html) + /// for two given SQL types + pub type AreAllNullable = + as AllAreNullable>>::Out; + + /// Represents if the SQL type is nullable or not + pub type IsSqlTypeNullable = ::IsNull; +} + +/// A marker trait for accepting expressions of the type `Bool` and +/// `Nullable` in the same place +pub trait BoolOrNullableBool {} + +impl BoolOrNullableBool for Bool {} +impl BoolOrNullableBool for Nullable {} + +#[doc(inline)] +pub use crate::expression::expression_types::Untyped; diff --git a/diesel/src/sql_types/ops.rs b/diesel/src/sql_types/ops.rs index ad8a6668126b..f548764c6aba 100644 --- a/diesel/src/sql_types/ops.rs +++ b/diesel/src/sql_types/ops.rs @@ -36,33 +36,33 @@ use super::*; /// Represents SQL types which can be added. pub trait Add { /// The SQL type which can be added to this one - type Rhs; + type Rhs: SqlType; /// The SQL type of the result of adding `Rhs` to `Self` - type Output; + type Output: SqlType; } /// Represents SQL types which can be subtracted. pub trait Sub { /// The SQL type which can be subtracted from this one - type Rhs; + type Rhs: SqlType; /// The SQL type of the result of subtracting `Rhs` from `Self` - type Output; + type Output: SqlType; } /// Represents SQL types which can be multiplied. pub trait Mul { /// The SQL type which this can be multiplied by - type Rhs; + type Rhs: SqlType; /// The SQL type of the result of multiplying `Self` by `Rhs` - type Output; + type Output: SqlType; } /// Represents SQL types which can be divided. pub trait Div { /// The SQL type which this one can be divided by - type Rhs; + type Rhs: SqlType; /// The SQL type of the result of dividing `Self` by `Rhs` - type Output; + type Output: SqlType; } macro_rules! numeric_type { @@ -145,9 +145,9 @@ impl Div for Interval { impl Add for Nullable where - T: Add + NotNull, - T::Rhs: NotNull, - T::Output: NotNull, + T: Add + SqlType, + T::Rhs: SqlType, + T::Output: SqlType, { type Rhs = Nullable; type Output = Nullable; @@ -155,9 +155,9 @@ where impl Sub for Nullable where - T: Sub + NotNull, - T::Rhs: NotNull, - T::Output: NotNull, + T: Sub + SqlType, + T::Rhs: SqlType, + T::Output: SqlType, { type Rhs = Nullable; type Output = Nullable; @@ -165,9 +165,9 @@ where impl Mul for Nullable where - T: Mul + NotNull, - T::Rhs: NotNull, - T::Output: NotNull, + T: Mul + SqlType, + T::Rhs: SqlType, + T::Output: SqlType, { type Rhs = Nullable; type Output = Nullable; @@ -175,9 +175,9 @@ where impl Div for Nullable where - T: Div + NotNull, - T::Rhs: NotNull, - T::Output: NotNull, + T: Div + SqlType, + T::Rhs: SqlType, + T::Output: SqlType, { type Rhs = Nullable; type Output = Nullable; diff --git a/diesel/src/sql_types/ord.rs b/diesel/src/sql_types/ord.rs index 7ce6293d14f6..e3ba8d3fdc6f 100644 --- a/diesel/src/sql_types/ord.rs +++ b/diesel/src/sql_types/ord.rs @@ -1,7 +1,7 @@ -use crate::sql_types::{self, NotNull}; +use crate::sql_types::{self, is_nullable, SqlType}; /// Marker trait for types which can be used with `MAX` and `MIN` -pub trait SqlOrd {} +pub trait SqlOrd: SqlType {} impl SqlOrd for sql_types::SmallInt {} impl SqlOrd for sql_types::Integer {} @@ -13,7 +13,7 @@ impl SqlOrd for sql_types::Date {} impl SqlOrd for sql_types::Interval {} impl SqlOrd for sql_types::Time {} impl SqlOrd for sql_types::Timestamp {} -impl SqlOrd for sql_types::Nullable {} +impl SqlOrd for sql_types::Nullable where T: SqlOrd + SqlType {} #[cfg(feature = "postgres")] impl SqlOrd for sql_types::Timestamptz {} diff --git a/diesel/src/sqlite/connection/functions.rs b/diesel/src/sqlite/connection/functions.rs index 8c7e5814d123..d55e9b7286e0 100644 --- a/diesel/src/sqlite/connection/functions.rs +++ b/diesel/src/sqlite/connection/functions.rs @@ -3,11 +3,12 @@ extern crate libsqlite3_sys as ffi; use super::raw::RawConnection; use super::serialized_value::SerializedValue; use super::{Sqlite, SqliteAggregateFunction, SqliteValue}; -use crate::deserialize::{FromSqlRow, Queryable, StaticallySizedRow}; +use crate::deserialize::{FromSqlRow, StaticallySizedRow}; use crate::result::{DatabaseErrorKind, Error, QueryResult}; -use crate::row::Row; +use crate::row::{Field, PartialRow, Row, RowIndex}; use crate::serialize::{IsNull, Output, ToSql}; use crate::sql_types::HasSqlType; +use std::marker::PhantomData; pub fn register( conn: &RawConnection, @@ -17,12 +18,11 @@ pub fn register( ) -> QueryResult<()> where F: FnMut(&RawConnection, Args) -> Ret + Send + 'static, - Args: Queryable, - Args::Row: StaticallySizedRow, + Args: FromSqlRow + StaticallySizedRow, Ret: ToSql, Sqlite: HasSqlType, { - let fields_needed = Args::Row::FIELD_COUNT; + let fields_needed = Args::FIELD_COUNT; if fields_needed > 127 { return Err(Error::DatabaseError( DatabaseErrorKind::UnableToSendCommand, @@ -46,12 +46,11 @@ pub fn register_aggregate( ) -> QueryResult<()> where A: SqliteAggregateFunction + 'static + Send, - Args: Queryable, - Args::Row: StaticallySizedRow, + Args: FromSqlRow + StaticallySizedRow, Ret: ToSql, Sqlite: HasSqlType, { - let fields_needed = Args::Row::FIELD_COUNT; + let fields_needed = Args::FIELD_COUNT; if fields_needed > 127 { return Err(Error::DatabaseError( DatabaseErrorKind::UnableToSendCommand, @@ -71,12 +70,10 @@ pub(crate) fn build_sql_function_args( args: &[*mut ffi::sqlite3_value], ) -> Result where - Args: Queryable, + Args: FromSqlRow, { - let mut row = FunctionRow::new(args); - let args_row = Args::Row::build_from_row(&mut row).map_err(Error::DeserializationError)?; - - Ok(Args::build(args_row)) + let row = FunctionRow::new(args); + Args::build_from_row(&row).map_err(Error::DeserializationError) } pub(crate) fn process_sql_function_result( @@ -101,39 +98,69 @@ where }) } +#[derive(Clone)] struct FunctionRow<'a> { - column_count: usize, args: &'a [*mut ffi::sqlite3_value], } impl<'a> FunctionRow<'a> { fn new(args: &'a [*mut ffi::sqlite3_value]) -> Self { - Self { - column_count: args.len(), - args, - } + Self { args } } } -impl<'a> Row for FunctionRow<'a> { - fn take(&mut self) -> Option> { - self.args.split_first().and_then(|(&first, rest)| { - self.args = rest; - unsafe { SqliteValue::new(first) } +impl<'a> Row<'a, Sqlite> for FunctionRow<'a> { + type Field = FunctionArgument<'a>; + type InnerPartialRow = Self; + + fn field_count(&self) -> usize { + self.args.len() + } + + fn get(&self, idx: I) -> Option + where + Self: crate::row::RowIndex, + { + let idx = self.idx(idx)?; + + self.args.get(idx).map(|arg| FunctionArgument { + arg: *arg, + p: PhantomData, }) } - fn next_is_null(&self, count: usize) -> bool { - self.args[..count] - .iter() - .all(|&p| unsafe { SqliteValue::new(p) }.is_none()) + fn partial_row(&self, range: std::ops::Range) -> PartialRow { + PartialRow::new(self, range) } +} - fn column_count(&self) -> usize { - self.column_count +impl<'a> RowIndex for FunctionRow<'a> { + fn idx(&self, idx: usize) -> Option { + Some(idx) } +} - fn column_name(&self) -> Option<&str> { +impl<'a> RowIndex<&'a str> for FunctionRow<'a> { + fn idx(&self, _idx: &'a str) -> Option { None } } + +struct FunctionArgument<'a> { + arg: *mut ffi::sqlite3_value, + p: PhantomData<&'a ()>, +} + +impl<'a> Field<'a, Sqlite> for FunctionArgument<'a> { + fn field_name(&self) -> Option<&str> { + None + } + + fn is_null(&self) -> bool { + self.value().is_none() + } + + fn value(&self) -> Option> { + unsafe { SqliteValue::new(self.arg) } + } +} diff --git a/diesel/src/sqlite/connection/mod.rs b/diesel/src/sqlite/connection/mod.rs index 5392bde45609..cb8a49049c96 100644 --- a/diesel/src/sqlite/connection/mod.rs +++ b/diesel/src/sqlite/connection/mod.rs @@ -17,7 +17,8 @@ use self::statement_iterator::*; use self::stmt::{Statement, StatementUse}; use super::SqliteAggregateFunction; use crate::connection::*; -use crate::deserialize::{Queryable, QueryableByName, StaticallySizedRow}; +use crate::deserialize::{FromSqlRow, IsCompatibleType, StaticallySizedRow}; +use crate::expression::TypedExpressionType; use crate::query_builder::bind_collector::RawBytesBindCollector; use crate::query_builder::*; use crate::result::*; @@ -71,12 +72,13 @@ impl Connection for SqliteConnection { } #[doc(hidden)] - fn query_by_index(&self, source: T) -> QueryResult> + fn load(&self, source: T) -> QueryResult> where T: AsQuery, T::Query: QueryFragment + QueryId, - Self::Backend: HasSqlType, - U: Queryable, + U: FromSqlRow, + T::SqlType: IsCompatibleType, + ST: TypedExpressionType, { let mut statement = self.prepare_query(&source.as_query())?; let statement_use = StatementUse::new(&mut statement); @@ -84,18 +86,6 @@ impl Connection for SqliteConnection { iter.collect() } - #[doc(hidden)] - fn query_by_name(&self, source: &T) -> QueryResult> - where - T: QueryFragment + QueryId, - U: QueryableByName, - { - let mut statement = self.prepare_query(source)?; - let statement_use = StatementUse::new(&mut statement); - let iter = NamedStatementIterator::new(statement_use)?; - iter.collect() - } - #[doc(hidden)] fn execute_returning_count(&self, source: &T) -> QueryResult where @@ -227,8 +217,7 @@ impl SqliteConnection { ) -> QueryResult<()> where F: FnMut(Args) -> Ret + Send + 'static, - Args: Queryable, - Args::Row: StaticallySizedRow, + Args: FromSqlRow + StaticallySizedRow, Ret: ToSql, Sqlite: HasSqlType, { @@ -247,8 +236,7 @@ impl SqliteConnection { ) -> QueryResult<()> where A: SqliteAggregateFunction + 'static + Send, - Args: Queryable, - Args::Row: StaticallySizedRow, + Args: FromSqlRow + StaticallySizedRow, Ret: ToSql, Sqlite: HasSqlType, { diff --git a/diesel/src/sqlite/connection/raw.rs b/diesel/src/sqlite/connection/raw.rs index dde7f3c94b11..7d2484646344 100644 --- a/diesel/src/sqlite/connection/raw.rs +++ b/diesel/src/sqlite/connection/raw.rs @@ -9,7 +9,7 @@ use std::{mem, ptr, slice, str}; use super::functions::{build_sql_function_args, process_sql_function_result}; use super::serialized_value::SerializedValue; use super::{Sqlite, SqliteAggregateFunction}; -use crate::deserialize::Queryable; +use crate::deserialize::FromSqlRow; use crate::result::Error::DatabaseError; use crate::result::*; use crate::serialize::ToSql; @@ -116,7 +116,7 @@ impl RawConnection { ) -> QueryResult<()> where A: SqliteAggregateFunction + 'static + Send, - Args: Queryable, + Args: FromSqlRow, Ret: ToSql, Sqlite: HasSqlType, { @@ -266,7 +266,7 @@ extern "C" fn run_aggregator_step_function + 'static + Send, - Args: Queryable, + Args: FromSqlRow, Ret: ToSql, Sqlite: HasSqlType, { @@ -336,7 +336,7 @@ extern "C" fn run_aggregator_final_function + 'static + Send, - Args: Queryable, + Args: FromSqlRow, Ret: ToSql, Sqlite: HasSqlType, { diff --git a/diesel/src/sqlite/connection/sqlite_value.rs b/diesel/src/sqlite/connection/sqlite_value.rs index f63a723f28da..bbbb1ad2278a 100644 --- a/diesel/src/sqlite/connection/sqlite_value.rs +++ b/diesel/src/sqlite/connection/sqlite_value.rs @@ -1,6 +1,6 @@ extern crate libsqlite3_sys as ffi; -use std::collections::HashMap; +use std::marker::PhantomData; use std::os::raw as libc; use std::ptr::NonNull; use std::{slice, str}; @@ -14,20 +14,25 @@ use crate::sqlite::{Sqlite, SqliteType}; /// rust values: #[allow(missing_debug_implementations, missing_copy_implementations)] pub struct SqliteValue<'a> { - value: &'a ffi::sqlite3_value, + value: NonNull, + p: PhantomData<&'a ()>, } -pub struct SqliteRow { +#[derive(Clone)] +pub struct SqliteRow<'a> { stmt: NonNull, next_col_index: libc::c_int, + p: PhantomData<&'a ()>, } impl<'a> SqliteValue<'a> { #[allow(clippy::new_ret_no_self)] pub(crate) unsafe fn new(inner: *mut ffi::sqlite3_value) -> Option { - inner - .as_ref() - .map(|value| SqliteValue { value }) + NonNull::new(inner) + .map(|value| SqliteValue { + value, + p: PhantomData, + }) .and_then(|value| { // We check here that the actual value represented by the inner // `sqlite3_value` is not `NULL` (is sql meaning, not ptr meaning) @@ -41,8 +46,8 @@ impl<'a> SqliteValue<'a> { pub(crate) fn read_text(&self) -> &str { unsafe { - let ptr = ffi::sqlite3_value_text(self.value()); - let len = ffi::sqlite3_value_bytes(self.value()); + let ptr = ffi::sqlite3_value_text(self.value.as_ptr()); + let len = ffi::sqlite3_value_bytes(self.value.as_ptr()); let bytes = slice::from_raw_parts(ptr as *const u8, len as usize); // The string is guaranteed to be utf8 according to // https://www.sqlite.org/c3ref/value_blob.html @@ -52,27 +57,27 @@ impl<'a> SqliteValue<'a> { pub(crate) fn read_blob(&self) -> &[u8] { unsafe { - let ptr = ffi::sqlite3_value_blob(self.value()); - let len = ffi::sqlite3_value_bytes(self.value()); + let ptr = ffi::sqlite3_value_blob(self.value.as_ptr()); + let len = ffi::sqlite3_value_bytes(self.value.as_ptr()); slice::from_raw_parts(ptr as *const u8, len as usize) } } pub(crate) fn read_integer(&self) -> i32 { - unsafe { ffi::sqlite3_value_int(self.value()) as i32 } + unsafe { ffi::sqlite3_value_int(self.value.as_ptr()) as i32 } } pub(crate) fn read_long(&self) -> i64 { - unsafe { ffi::sqlite3_value_int64(self.value()) as i64 } + unsafe { ffi::sqlite3_value_int64(self.value.as_ptr()) as i64 } } pub(crate) fn read_double(&self) -> f64 { - unsafe { ffi::sqlite3_value_double(self.value()) as f64 } + unsafe { ffi::sqlite3_value_double(self.value.as_ptr()) as f64 } } /// Get the type of the value as returned by sqlite pub fn value_type(&self) -> Option { - let tpe = unsafe { ffi::sqlite3_value_type(self.value()) }; + let tpe = unsafe { ffi::sqlite3_value_type(self.value.as_ptr()) }; match tpe { ffi::SQLITE_TEXT => Some(SqliteType::Text), ffi::SQLITE_INTEGER => Some(SqliteType::Long), @@ -86,78 +91,82 @@ impl<'a> SqliteValue<'a> { pub(crate) fn is_null(&self) -> bool { self.value_type().is_none() } - - fn value(&self) -> *mut ffi::sqlite3_value { - self.value as *const _ as _ - } } -impl SqliteRow { - pub(crate) fn new(inner_statement: NonNull) -> Self { +impl<'a> SqliteRow<'a> { + pub(crate) unsafe fn new(inner_statement: NonNull) -> Self { SqliteRow { stmt: inner_statement, next_col_index: 0, - } - } - - pub fn into_named<'a>(self, indices: &'a HashMap<&'a str, usize>) -> SqliteNamedRow<'a> { - SqliteNamedRow { - stmt: self.stmt, - column_indices: indices, + p: PhantomData, } } } -impl Row for SqliteRow { - fn take(&mut self) -> Option> { - let col_index = self.next_col_index; - self.next_col_index += 1; +impl<'a> Row<'a, Sqlite> for SqliteRow<'a> { + type Field = SqliteField<'a>; + type InnerPartialRow = Self; - unsafe { - let ptr = ffi::sqlite3_column_value(self.stmt.as_ptr(), col_index); - SqliteValue::new(ptr) - } + fn field_count(&self) -> usize { + column_count(self.stmt) as usize } - fn next_is_null(&self, count: usize) -> bool { - (0..count).all(|i| { - let idx = self.next_col_index + i as libc::c_int; - let tpe = unsafe { ffi::sqlite3_column_type(self.stmt.as_ptr(), idx) }; - tpe == ffi::SQLITE_NULL + fn get(&self, idx: I) -> Option + where + Self: RowIndex, + { + let idx = self.idx(idx)?; + Some(SqliteField { + stmt: self.stmt, + col_idx: idx as i32, + p: PhantomData, }) } - fn column_name(&self) -> Option<&str> { - column_name(self.stmt, self.next_col_index) + fn partial_row(&self, range: std::ops::Range) -> PartialRow { + PartialRow::new(self, range) + } +} + +impl<'a> RowIndex for SqliteRow<'a> { + fn idx(&self, idx: usize) -> Option { + if idx < self.field_count() { + Some(idx) + } else { + None + } } +} - fn column_count(&self) -> usize { - column_count(self.stmt) as usize +impl<'a> RowIndex<&'a str> for SqliteRow<'a> { + fn idx(&self, field_name: &'a str) -> Option { + (0..column_count(self.stmt)) + .find(|idx| column_name(self.stmt, *idx) == Some(field_name)) + .map(|a| a as usize) } } -pub struct SqliteNamedRow<'a> { +pub struct SqliteField<'a> { stmt: NonNull, - column_indices: &'a HashMap<&'a str, usize>, + col_idx: i32, + p: PhantomData<&'a ()>, } -impl<'a> NamedRow for SqliteNamedRow<'a> { - fn index_of(&self, column_name: &str) -> Option { - self.column_indices.get(column_name).cloned() +impl<'a> Field<'a, Sqlite> for SqliteField<'a> { + fn field_name(&self) -> Option<&str> { + column_name(self.stmt, self.col_idx) + } + + fn is_null(&self) -> bool { + self.value().is_none() } - fn get_raw_value(&self, idx: usize) -> Option> { + fn value(&self) -> Option> { unsafe { - let ptr = ffi::sqlite3_column_value(self.stmt.as_ptr(), idx as libc::c_int); + let ptr = ffi::sqlite3_column_value(self.stmt.as_ptr(), self.col_idx); SqliteValue::new(ptr) } } - - fn field_names(&self) -> Vec<&str> { - (0..column_count(self.stmt)) - .filter_map(|c| column_name(self.stmt, c)) - .collect() - } } fn column_name<'a>(stmt: NonNull, field_number: i32) -> Option<&'a str> { diff --git a/diesel/src/sqlite/connection/statement_iterator.rs b/diesel/src/sqlite/connection/statement_iterator.rs index 098b3c2baeaf..91195631faf9 100644 --- a/diesel/src/sqlite/connection/statement_iterator.rs +++ b/diesel/src/sqlite/connection/statement_iterator.rs @@ -1,8 +1,7 @@ -use std::collections::HashMap; use std::marker::PhantomData; use super::stmt::StatementUse; -use crate::deserialize::{FromSqlRow, Queryable, QueryableByName}; +use crate::deserialize::FromSqlRow; use crate::result::Error::DeserializationError; use crate::result::QueryResult; use crate::sqlite::Sqlite; @@ -23,7 +22,7 @@ impl<'a, ST, T> StatementIterator<'a, ST, T> { impl<'a, ST, T> Iterator for StatementIterator<'a, ST, T> where - T: Queryable, + T: FromSqlRow, { type Item = QueryResult; @@ -32,55 +31,6 @@ where Ok(row) => row, Err(e) => return Some(Err(e)), }; - row.map(|mut row| { - T::Row::build_from_row(&mut row) - .map(T::build) - .map_err(DeserializationError) - }) - } -} - -pub struct NamedStatementIterator<'a, T> { - stmt: StatementUse<'a>, - column_indices: HashMap<&'a str, usize>, - _marker: PhantomData, -} - -impl<'a, T> NamedStatementIterator<'a, T> { - #[allow(clippy::new_ret_no_self)] - pub fn new(stmt: StatementUse<'a>) -> QueryResult { - let column_indices = (0..stmt.num_fields()) - .filter_map(|i| { - stmt.field_name(i).map(|column| { - let column = column - .to_str() - .map_err(|e| DeserializationError(e.into()))?; - Ok((column, i)) - }) - }) - .collect::>()?; - Ok(NamedStatementIterator { - stmt, - column_indices, - _marker: PhantomData, - }) - } -} - -impl<'a, T> Iterator for NamedStatementIterator<'a, T> -where - T: QueryableByName, -{ - type Item = QueryResult; - - fn next(&mut self) -> Option { - let row = match self.stmt.step() { - Ok(row) => row, - Err(e) => return Some(Err(e)), - }; - row.map(|row| { - let row = row.into_named(&self.column_indices); - T::build(&row).map_err(DeserializationError) - }) + row.map(|row| T::build_from_row(&row).map_err(DeserializationError)) } } diff --git a/diesel/src/sqlite/connection/stmt.rs b/diesel/src/sqlite/connection/stmt.rs index 07fe2989847e..3ca28c8ecbc3 100644 --- a/diesel/src/sqlite/connection/stmt.rs +++ b/diesel/src/sqlite/connection/stmt.rs @@ -54,26 +54,13 @@ impl Statement { ensure_sqlite_ok(result, self.raw_connection()) } - fn num_fields(&self) -> usize { - unsafe { ffi::sqlite3_column_count(self.inner_statement.as_ptr()) as usize } - } - - /// The lifetime of the returned CStr is shorter than self. This function - /// should be tied to a lifetime that ends before the next call to `reset` - unsafe fn field_name<'a>(&self, idx: usize) -> Option<&'a CStr> { - let ptr = ffi::sqlite3_column_name(self.inner_statement.as_ptr(), idx as libc::c_int); - if ptr.is_null() { - None - } else { - Some(CStr::from_ptr(ptr)) - } - } - fn step(&mut self) -> QueryResult> { - match unsafe { ffi::sqlite3_step(self.inner_statement.as_ptr()) } { - ffi::SQLITE_DONE => Ok(None), - ffi::SQLITE_ROW => Ok(Some(SqliteRow::new(self.inner_statement))), - _ => Err(last_error(self.raw_connection())), + unsafe { + match ffi::sqlite3_step(self.inner_statement.as_ptr()) { + ffi::SQLITE_DONE => Ok(None), + ffi::SQLITE_ROW => Ok(Some(SqliteRow::new(self.inner_statement))), + _ => Err(last_error(self.raw_connection())), + } } } @@ -158,14 +145,6 @@ impl<'a> StatementUse<'a> { pub fn step(&mut self) -> QueryResult> { self.statement.step() } - - pub fn num_fields(&self) -> usize { - self.statement.num_fields() - } - - pub fn field_name(&self, idx: usize) -> Option<&'a CStr> { - unsafe { self.statement.field_name(idx) } - } } impl<'a> Drop for StatementUse<'a> { diff --git a/diesel/src/sqlite/types/date_and_time/chrono.rs b/diesel/src/sqlite/types/date_and_time/chrono.rs index 7cdb0a0eaadd..7d186c6f6130 100644 --- a/diesel/src/sqlite/types/date_and_time/chrono.rs +++ b/diesel/src/sqlite/types/date_and_time/chrono.rs @@ -12,7 +12,7 @@ use crate::sqlite::Sqlite; const SQLITE_DATE_FORMAT: &str = "%F"; impl FromSql for NaiveDate { - fn from_sql(value: Option>) -> deserialize::Result { + fn from_sql(value: backend::RawValue) -> deserialize::Result { let text_ptr = <*const str as FromSql>::from_sql(value)?; let text = unsafe { &*text_ptr }; Self::parse_from_str(text, SQLITE_DATE_FORMAT).map_err(Into::into) @@ -27,7 +27,7 @@ impl ToSql for NaiveDate { } impl FromSql for NaiveTime { - fn from_sql(value: Option>) -> deserialize::Result { + fn from_sql(value: backend::RawValue) -> deserialize::Result { let text_ptr = <*const str as FromSql>::from_sql(value)?; let text = unsafe { &*text_ptr }; let valid_time_formats = &[ @@ -54,7 +54,7 @@ impl ToSql for NaiveTime { } impl FromSql for NaiveDateTime { - fn from_sql(value: Option>) -> deserialize::Result { + fn from_sql(value: backend::RawValue) -> deserialize::Result { let text_ptr = <*const str as FromSql>::from_sql(value)?; let text = unsafe { &*text_ptr }; diff --git a/diesel/src/sqlite/types/date_and_time/mod.rs b/diesel/src/sqlite/types/date_and_time/mod.rs index 4a0a3d877489..18bdd45713f0 100644 --- a/diesel/src/sqlite/types/date_and_time/mod.rs +++ b/diesel/src/sqlite/types/date_and_time/mod.rs @@ -15,7 +15,7 @@ mod chrono; /// raw pointer instead of a reference with a lifetime due to the structure of /// `FromSql` impl FromSql for *const str { - fn from_sql(value: Option>) -> deserialize::Result { + fn from_sql(value: SqliteValue<'_>) -> deserialize::Result { FromSql::::from_sql(value) } } @@ -38,7 +38,7 @@ impl ToSql for String { /// raw pointer instead of a reference with a lifetime due to the structure of /// `FromSql` impl FromSql for *const str { - fn from_sql(value: Option>) -> deserialize::Result { + fn from_sql(value: SqliteValue<'_>) -> deserialize::Result { FromSql::::from_sql(value) } } @@ -61,7 +61,7 @@ impl ToSql for String { /// raw pointer instead of a reference with a lifetime due to the structure of /// `FromSql` impl FromSql for *const str { - fn from_sql(value: Option>) -> deserialize::Result { + fn from_sql(value: SqliteValue<'_>) -> deserialize::Result { FromSql::::from_sql(value) } } diff --git a/diesel/src/sqlite/types/mod.rs b/diesel/src/sqlite/types/mod.rs index ad477463f2e9..326fc40bc5fe 100644 --- a/diesel/src/sqlite/types/mod.rs +++ b/diesel/src/sqlite/types/mod.rs @@ -15,8 +15,7 @@ use crate::sql_types; /// raw pointer instead of a reference with a lifetime due to the structure of /// `FromSql` impl FromSql for *const str { - fn from_sql(value: Option>) -> deserialize::Result { - let value = not_none!(value); + fn from_sql(value: SqliteValue<'_>) -> deserialize::Result { let text = value.read_text(); Ok(text as *const _) } @@ -28,46 +27,45 @@ impl FromSql for *const str { /// raw pointer instead of a reference with a lifetime due to the structure of /// `FromSql` impl FromSql for *const [u8] { - fn from_sql(bytes: Option>) -> deserialize::Result { - let bytes = not_none!(bytes); + fn from_sql(bytes: SqliteValue<'_>) -> deserialize::Result { let bytes = bytes.read_blob(); Ok(bytes as *const _) } } impl FromSql for i16 { - fn from_sql(value: Option>) -> deserialize::Result { - Ok(not_none!(value).read_integer() as i16) + fn from_sql(value: SqliteValue<'_>) -> deserialize::Result { + Ok(value.read_integer() as i16) } } impl FromSql for i32 { - fn from_sql(value: Option>) -> deserialize::Result { - Ok(not_none!(value).read_integer()) + fn from_sql(value: SqliteValue<'_>) -> deserialize::Result { + Ok(value.read_integer()) } } impl FromSql for bool { - fn from_sql(value: Option>) -> deserialize::Result { - Ok(not_none!(value).read_integer() != 0) + fn from_sql(value: SqliteValue<'_>) -> deserialize::Result { + Ok(value.read_integer() != 0) } } impl FromSql for i64 { - fn from_sql(value: Option>) -> deserialize::Result { - Ok(not_none!(value).read_long()) + fn from_sql(value: SqliteValue<'_>) -> deserialize::Result { + Ok(value.read_long()) } } impl FromSql for f32 { - fn from_sql(value: Option>) -> deserialize::Result { - Ok(not_none!(value).read_double() as f32) + fn from_sql(value: SqliteValue<'_>) -> deserialize::Result { + Ok(value.read_double() as f32) } } impl FromSql for f64 { - fn from_sql(value: Option>) -> deserialize::Result { - Ok(not_none!(value).read_double()) + fn from_sql(value: SqliteValue<'_>) -> deserialize::Result { + Ok(value.read_double()) } } diff --git a/diesel/src/sqlite/types/numeric.rs b/diesel/src/sqlite/types/numeric.rs index ac8606b2f33a..976921e56d01 100644 --- a/diesel/src/sqlite/types/numeric.rs +++ b/diesel/src/sqlite/types/numeric.rs @@ -8,7 +8,7 @@ use crate::sqlite::connection::SqliteValue; use crate::sqlite::Sqlite; impl FromSql for BigDecimal { - fn from_sql(bytes: Option>) -> deserialize::Result { + fn from_sql(bytes: SqliteValue<'_>) -> deserialize::Result { let data = >::from_sql(bytes)?; Ok(data.into()) } diff --git a/diesel/src/type_impls/floats.rs b/diesel/src/type_impls/floats.rs index 01d2d139e793..3fd60e680502 100644 --- a/diesel/src/type_impls/floats.rs +++ b/diesel/src/type_impls/floats.rs @@ -11,8 +11,7 @@ impl FromSql for f32 where DB: Backend + for<'a> BinaryRawValue<'a>, { - fn from_sql(value: Option>) -> deserialize::Result { - let value = not_none!(value); + fn from_sql(value: crate::backend::RawValue) -> deserialize::Result { let mut bytes = DB::as_bytes(value); debug_assert!( bytes.len() <= 4, @@ -37,8 +36,7 @@ impl FromSql for f64 where DB: Backend + for<'a> BinaryRawValue<'a>, { - fn from_sql(value: Option>) -> deserialize::Result { - let value = not_none!(value); + fn from_sql(value: crate::backend::RawValue) -> deserialize::Result { let mut bytes = DB::as_bytes(value); debug_assert!( bytes.len() <= 8, diff --git a/diesel/src/type_impls/integers.rs b/diesel/src/type_impls/integers.rs index 108e8ddcb9c0..bf08ce1878cb 100644 --- a/diesel/src/type_impls/integers.rs +++ b/diesel/src/type_impls/integers.rs @@ -11,8 +11,7 @@ impl FromSql for i16 where DB: Backend + for<'a> BinaryRawValue<'a>, { - fn from_sql(value: Option>) -> deserialize::Result { - let value = not_none!(value); + fn from_sql(value: crate::backend::RawValue) -> deserialize::Result { let mut bytes = DB::as_bytes(value); debug_assert!( bytes.len() <= 2, @@ -43,8 +42,7 @@ impl FromSql for i32 where DB: Backend + for<'a> BinaryRawValue<'a>, { - fn from_sql(value: Option>) -> deserialize::Result { - let value = not_none!(value); + fn from_sql(value: crate::backend::RawValue) -> deserialize::Result { let mut bytes = DB::as_bytes(value); debug_assert!( bytes.len() <= 4, @@ -74,8 +72,7 @@ impl FromSql for i64 where DB: Backend + for<'a> BinaryRawValue<'a>, { - fn from_sql(value: Option>) -> deserialize::Result { - let value = not_none!(value); + fn from_sql(value: crate::backend::RawValue) -> deserialize::Result { let mut bytes = DB::as_bytes(value); debug_assert!( bytes.len() <= 8, diff --git a/diesel/src/type_impls/option.rs b/diesel/src/type_impls/option.rs index 5d11c6038c71..ec0b3a882b06 100644 --- a/diesel/src/type_impls/option.rs +++ b/diesel/src/type_impls/option.rs @@ -1,19 +1,18 @@ use std::io::Write; use crate::backend::{self, Backend}; -use crate::deserialize::{self, FromSql, FromSqlRow, Queryable, QueryableByName}; +use crate::deserialize::{self, FromSql, FromSqlRow, StaticallySizedRow}; use crate::expression::bound::Bound; use crate::expression::*; use crate::query_builder::QueryId; use crate::result::UnexpectedNullError; -use crate::row::NamedRow; use crate::serialize::{self, IsNull, Output, ToSql}; -use crate::sql_types::{HasSqlType, NotNull, Nullable}; +use crate::sql_types::{is_nullable, HasSqlType, IntoNotNullable, Nullable, SqlType, Untyped}; impl HasSqlType> for DB where DB: Backend + HasSqlType, - T: NotNull, + T: SqlType, { fn metadata(lookup: &DB::MetadataLookup) -> DB::TypeMetadata { >::metadata(lookup) @@ -27,7 +26,7 @@ where impl QueryId for Nullable where - T: QueryId + NotNull, + T: QueryId + SqlType, { type QueryId = T::QueryId; @@ -38,69 +37,65 @@ impl FromSql, DB> for Option where T: FromSql, DB: Backend, - ST: NotNull, + ST: SqlType, { - fn from_sql(bytes: Option>) -> deserialize::Result { + fn from_sql(bytes: backend::RawValue) -> deserialize::Result { + T::from_sql(bytes).map(Some) + } + + fn from_nullable_sql(bytes: Option>) -> deserialize::Result { match bytes { - Some(_) => T::from_sql(bytes).map(Some), + Some(bytes) => T::from_sql(bytes).map(Some), None => Ok(None), } } } -impl Queryable, DB> for Option +pub trait ValidNullable { + type Inner; +} + +impl ValidNullable for ST where - T: Queryable, - DB: Backend, - Option: FromSqlRow, DB>, - ST: NotNull, + ST: SqlType + IntoNotNullable, { - type Row = Option; + type Inner = ST::NotNullable; +} - fn build(row: Self::Row) -> Self { - row.map(T::build) - } +impl ValidNullable for Untyped { + type Inner = Untyped; } -impl QueryableByName for Option +impl FromSqlRow for Option where - T: QueryableByName, DB: Backend, + ST: ValidNullable, + T: FromSqlRow, { - fn build>(row: &R) -> deserialize::Result { - match T::build(row) { + fn build_from_row<'a>(row: &impl crate::row::Row<'a, DB>) -> deserialize::Result { + match T::build_from_row(row) { Ok(v) => Ok(Some(v)), - Err(e) => { - if e.is::() { - Ok(None) - } else { - Err(e) - } - } + Err(e) if e.is::() => Ok(None), + Err(e) => Err(e), } } } -impl FromSqlRow, DB> for Option +impl StaticallySizedRow, DB> for Option where - T: FromSqlRow, + T: StaticallySizedRow, + ST: SqlType, DB: Backend, - ST: NotNull, + Self: FromSqlRow, DB>, { - fn build_from_row>(row: &mut R) -> deserialize::Result { - match T::build_from_row(row) { - Ok(v) => Ok(Some(v)), - Err(e) if e.is::() => Ok(None), - Err(e) => Err(e), - } - } + const FIELD_COUNT: usize = T::FIELD_COUNT; } impl ToSql, DB> for Option where T: ToSql, DB: Backend, - ST: NotNull, + ST: SqlType, { fn to_sql(&self, out: &mut Output) -> serialize::Result { if let Some(ref value) = *self { @@ -113,7 +108,8 @@ where impl AsExpression> for Option where - ST: NotNull, + ST: SqlType, + Nullable: TypedExpressionType, { type Expression = Bound, Self>; @@ -124,7 +120,8 @@ where impl<'a, T, ST> AsExpression> for &'a Option where - ST: NotNull, + ST: SqlType, + Nullable: TypedExpressionType, { type Expression = Bound, Self>; diff --git a/diesel/src/type_impls/primitives.rs b/diesel/src/type_impls/primitives.rs index 0d889f0a4c95..e09f6470d7fd 100644 --- a/diesel/src/type_impls/primitives.rs +++ b/diesel/src/type_impls/primitives.rs @@ -2,11 +2,9 @@ use std::error::Error; use std::io::Write; use crate::backend::{self, Backend, BinaryRawValue}; -use crate::deserialize::{self, FromSql, FromSqlRow, Queryable}; +use crate::deserialize::{self, FromSql, FromSqlRow, StaticallySizedRow}; use crate::serialize::{self, IsNull, Output, ToSql}; -use crate::sql_types::{ - self, BigInt, Binary, Bool, Double, Float, Integer, NotNull, SmallInt, Text, -}; +use crate::sql_types::{self, BigInt, Binary, Bool, Double, Float, Integer, SmallInt, Text}; #[allow(dead_code)] mod foreign_impls { @@ -102,14 +100,12 @@ mod foreign_impls { struct BinarySliceProxy([u8]); } -impl NotNull for () {} - impl FromSql for String where DB: Backend, *const str: FromSql, { - fn from_sql(bytes: Option>) -> deserialize::Result { + fn from_sql(bytes: backend::RawValue) -> deserialize::Result { let str_ptr = <*const str as FromSql>::from_sql(bytes)?; // We know that the pointer impl will never return null let string = unsafe { &*str_ptr }; @@ -127,9 +123,8 @@ impl FromSql for *const str where DB: Backend + for<'a> BinaryRawValue<'a>, { - fn from_sql(value: Option>) -> deserialize::Result { + fn from_sql(value: crate::backend::RawValue) -> deserialize::Result { use std::str; - let value = not_none!(value); let string = str::from_utf8(DB::as_bytes(value))?; Ok(string as *const _) } @@ -140,9 +135,8 @@ impl FromSql for *const str where DB: Backend + for<'a> BinaryRawValue<'a>, { - default fn from_sql(value: Option>) -> deserialize::Result { + default fn from_sql(value: crate::backend::RawValue) -> deserialize::Result { use std::str; - let value = not_none!(value); let string = str::from_utf8(DB::as_bytes(value))?; Ok(string as *const _) } @@ -171,7 +165,7 @@ where DB: Backend, *const [u8]: FromSql, { - fn from_sql(bytes: Option>) -> deserialize::Result { + fn from_sql(bytes: backend::RawValue) -> deserialize::Result { let slice_ptr = <*const [u8] as FromSql>::from_sql(bytes)?; // We know that the pointer impl will never return null let bytes = unsafe { &*slice_ptr }; @@ -188,8 +182,8 @@ impl FromSql for *const [u8] where DB: Backend + for<'a> BinaryRawValue<'a>, { - fn from_sql(bytes: Option>) -> deserialize::Result { - Ok(DB::as_bytes(not_none!(bytes)) as *const _) + fn from_sql(bytes: backend::RawValue) -> deserialize::Result { + Ok(DB::as_bytes(bytes) as *const _) } } @@ -230,7 +224,7 @@ where DB: Backend, T::Owned: FromSql, { - fn from_sql(bytes: Option>) -> deserialize::Result { + fn from_sql(bytes: backend::RawValue) -> deserialize::Result { T::Owned::from_sql(bytes).map(Cow::Owned) } } @@ -241,31 +235,33 @@ where DB: Backend, Cow<'a, T>: FromSql, { - fn build_from_row>(row: &mut R) -> deserialize::Result { - FromSql::::from_sql(row.take()) + fn build_from_row<'b>(row: &impl crate::row::Row<'b, DB>) -> deserialize::Result { + use crate::row::Field; + FromSql::::from_nullable_sql( + row.get(0) + .ok_or_else(|| Box::new(crate::result::UnexpectedEndOfRow))? + .value(), + ) } } -impl<'a, T: ?Sized, ST, DB> Queryable for Cow<'a, T> +impl<'a, T: ?Sized, ST, DB> StaticallySizedRow for Cow<'a, T> where T: 'a + ToOwned, DB: Backend, Self: FromSqlRow, { - type Row = Self; - - fn build(row: Self::Row) -> Self { - row - } } use crate::expression::bound::Bound; -use crate::expression::{AsExpression, Expression}; +use crate::expression::{AsExpression, Expression, TypedExpressionType}; +use sql_types::SqlType; impl<'a, T: ?Sized, ST> AsExpression for Cow<'a, T> where T: 'a + ToOwned, Bound>: Expression, + ST: SqlType + TypedExpressionType, { type Expression = Bound; @@ -278,6 +274,7 @@ impl<'a, 'b, T: ?Sized, ST> AsExpression for &'b Cow<'a, T> where T: 'a + ToOwned, Bound: Expression, + ST: SqlType + TypedExpressionType, { type Expression = Bound; diff --git a/diesel/src/type_impls/tuples.rs b/diesel/src/type_impls/tuples.rs index e0fb8464c665..1c631a887831 100644 --- a/diesel/src/type_impls/tuples.rs +++ b/diesel/src/type_impls/tuples.rs @@ -1,17 +1,16 @@ -use std::error::Error; - use crate::associations::BelongsTo; use crate::backend::Backend; -use crate::deserialize::{self, FromSqlRow, Queryable, QueryableByName, StaticallySizedRow}; +use crate::deserialize::{FromSqlRow, IsCompatibleType, StaticallySizedRow}; use crate::expression::{ - AppearsOnTable, AsExpression, AsExpressionList, Expression, SelectableExpression, ValidGrouping, + AppearsOnTable, AsExpression, AsExpressionList, Expression, SelectableExpression, + TypedExpressionType, ValidGrouping, }; use crate::insertable::{CanInsertInSingleQuery, InsertValues, Insertable}; use crate::query_builder::*; use crate::query_source::*; use crate::result::QueryResult; use crate::row::*; -use crate::sql_types::{HasSqlType, NotNull}; +use crate::sql_types::{HasSqlType, IntoNullable, Nullable, OneIsNullable, SqlType}; use crate::util::TupleAppend; macro_rules! tuple_impls { @@ -35,66 +34,32 @@ macro_rules! tuple_impls { } } - impl<$($T),+> NotNull for ($($T,)+) { - } - - impl<$($T),+, $($ST),+, __DB> FromSqlRow<($($ST,)+), __DB> for ($($T,)+) where - __DB: Backend, - $($T: FromSqlRow<$ST, __DB>),+, - { - - #[allow(non_snake_case)] - fn build_from_row>(row: &mut RowT) -> Result> { - // First call `build_from_row` for all tuple elements - // to advance the row iterator correctly - $( - let $ST = $T::build_from_row(row); - )+ - - // Afterwards bubble up any possible errors - Ok(($($ST?,)+)) - - // As a note for anyone trying to optimize this: - // We cannot just call something like `row.take()` for the - // remaining tuple elements as we cannot know how much childs - // they have on their own. For example one of them could be - // `Option<(A, B)>`. Just calling `row.take()` as many times - // as tuple elements are left would cause calling `row.take()` - // at least one time less then required (as the child has two) - // elements, not one. - } - } + impl_from_sql_row!(($($T,)+), ($($ST,)+)); - impl<$($T),+, $($ST),+, __DB > StaticallySizedRow<($($ST,)+), __DB> for ($($T,)+) where + impl<$($T),+, $($ST),+, __DB > StaticallySizedRow<($($ST,)*), __DB> for ($($T,)+) where __DB: Backend, + Self: FromSqlRow<($($ST,)+), __DB>, $($T: StaticallySizedRow<$ST, __DB>,)+ { const FIELD_COUNT: usize = $($T::FIELD_COUNT +)+ 0; } - impl<$($T),+, $($ST),+, __DB> Queryable<($($ST,)+), __DB> for ($($T,)+) where - __DB: Backend, - $($T: Queryable<$ST, __DB>),+, + impl<$($T: Expression),+> Expression for ($($T,)+) + where ($($T::SqlType, )*): TypedExpressionType { - type Row = ($($T::Row,)+); - - fn build(row: Self::Row) -> Self { - ($($T::build(row.$idx),)+) - } + type SqlType = ($(<$T as Expression>::SqlType,)+); } - impl<$($T,)+ __DB> QueryableByName<__DB> for ($($T,)+) - where - __DB: Backend, - $($T: QueryableByName<__DB>,)+ + impl<$($T: TypedExpressionType,)*> TypedExpressionType for ($($T,)*) {} + impl<$($T: SqlType + TypedExpressionType,)*> TypedExpressionType for Nullable<($($T,)*)> + where ($($T,)*): SqlType { - fn build>(row: &RowT) -> deserialize::Result { - Ok(($($T::build(row)?,)+)) - } } - impl<$($T: Expression),+> Expression for ($($T,)+) { - type SqlType = ($(<$T as Expression>::SqlType,)+); + impl<$($T: SqlType,)*> IntoNullable for ($($T,)*) + where Self: SqlType, + { + type Nullable = Nullable<($($T,)*)>; } impl<$($T: QueryFragment<__DB>),+, __DB: Backend> QueryFragment<__DB> for ($($T,)+) { @@ -256,6 +221,7 @@ macro_rules! tuple_impls { impl<$($T,)+ ST> AsExpressionList for ($($T,)+) where $($T: AsExpression,)+ + ST: SqlType + TypedExpressionType, { type Expression = ($($T::Expression,)+); @@ -263,8 +229,85 @@ macro_rules! tuple_impls { ($(self.$idx.as_expression(),)+) } } + + impl_sql_type!($($T,)*); + + impl<$($T: IsCompatibleType<__DB>,)* __DB> IsCompatibleType<__DB> for ($($T,)*) { + type Compatible = Self; + } + + impl<$($T: IsCompatibleType<__DB> + SqlType,)* __DB> IsCompatibleType<__DB> for Nullable<($($T,)*)> + where ($($T,)*): SqlType + { + type Compatible = Self; + } )+ } } +macro_rules! impl_from_sql_row { + (($T1: ident,), ($ST1: ident,)) => { + impl<$T1, $ST1, __DB> FromSqlRow<($ST1,), __DB> for ($T1,) where + __DB: Backend, + $T1: FromSqlRow<$ST1, __DB>, + { + + #[allow(non_snake_case, unused_variables, unused_mut)] + fn build_from_row<'a>(row: &impl Row<'a, __DB>) + -> crate::deserialize::Result + { + Ok(($T1::build_from_row(row)?,)) + } + } + }; + (($T1: ident, $($T: ident,)*), ($ST1: ident, $($ST: ident,)*)) => { + impl<$T1, $ST1, $($T,)* $($ST,)* __DB> FromSqlRow<($($ST,)* $ST1,), __DB> for ($($T,)* $T1,) where + __DB: Backend, + $T1: FromSqlRow<$ST1, __DB>, + $( + $T: FromSqlRow<$ST, __DB> + StaticallySizedRow<$ST, __DB>, + )* + + { + + #[allow(non_snake_case, unused_variables, unused_mut)] + fn build_from_row<'a>(full_row: &impl Row<'a, __DB>) + -> crate::deserialize::Result + { + let field_count = full_row.field_count(); + + let mut static_field_count = 0; + $( + let row = full_row.partial_row(static_field_count..static_field_count + $T::FIELD_COUNT); + static_field_count += $T::FIELD_COUNT; + let $T = $T::build_from_row(&row)?; + )* + + let row = full_row.partial_row(static_field_count..field_count); + + Ok(($($T,)* $T1::build_from_row(&row)?,)) + } + } + } +} + +macro_rules! impl_sql_type { + ($T1: ident, $($T: ident,)+) => { + impl<$T1, $($T,)+> SqlType for ($T1, $($T,)*) + where $T1: SqlType, + ($($T,)*): SqlType, + $T1::IsNull: OneIsNullable<<($($T,)*) as SqlType>::IsNull>, + { + type IsNull = <$T1::IsNull as OneIsNullable<<($($T,)*) as SqlType>::IsNull>>::Out; + } + }; + ($T1: ident,) => { + impl<$T1> SqlType for ($T1,) + where $T1: SqlType, + { + type IsNull = $T1::IsNull; + } + } +} + __diesel_for_each_tuple!(tuple_impls); diff --git a/diesel_cli/src/infer_schema_internals/data_structures.rs b/diesel_cli/src/infer_schema_internals/data_structures.rs index 1e4fdd76a2a0..09f8b08d4df0 100644 --- a/diesel_cli/src/infer_schema_internals/data_structures.rs +++ b/diesel_cli/src/infer_schema_internals/data_structures.rs @@ -1,6 +1,9 @@ #[cfg(feature = "uses_information_schema")] use diesel::backend::Backend; -use diesel::deserialize::{FromSqlRow, Queryable}; +use diesel::deserialize::FromSqlRow; +use diesel::sql_types::Text; +#[cfg(feature = "sqlite")] +use diesel::sql_types::{Bool, Integer, Nullable}; #[cfg(feature = "sqlite")] use diesel::sqlite::Sqlite; @@ -73,27 +76,34 @@ impl ColumnInformation { } #[cfg(feature = "uses_information_schema")] -impl Queryable for ColumnInformation +impl FromSqlRow<(Text, Text, Text), DB> for ColumnInformation where DB: Backend + UsesInformationSchema, - (String, String, String): FromSqlRow, + (String, String, String): FromSqlRow<(Text, Text, Text), DB>, { - type Row = (String, String, String); + fn build_from_row<'a>( + row: &impl diesel::row::Row<'a, DB>, + ) -> diesel::deserialize::Result { + let row = <(String, String, String) as FromSqlRow<_, DB>>::build_from_row(row)?; - fn build(row: Self::Row) -> Self { - ColumnInformation::new(row.0, row.1, row.2 == "YES") + Ok(ColumnInformation::new(row.0, row.1, row.2 == "YES")) } } #[cfg(feature = "sqlite")] -impl Queryable for ColumnInformation +impl FromSqlRow<(Integer, Text, Text, Bool, Nullable, Bool), Sqlite> for ColumnInformation where - (i32, String, String, bool, Option, bool): FromSqlRow, + (i32, String, String, bool, Option, bool): + FromSqlRow<(Integer, Text, Text, Bool, Nullable, Bool), Sqlite>, { - type Row = (i32, String, String, bool, Option, bool); - - fn build(row: Self::Row) -> Self { - ColumnInformation::new(row.1, row.2, !row.3) + fn build_from_row<'a>( + row: &impl diesel::row::Row<'a, Sqlite>, + ) -> diesel::deserialize::Result { + let row = <(i32, String, String, bool, Option, bool) as FromSqlRow< + (Integer, Text, Text, Bool, Nullable, Bool), + Sqlite, + >>::build_from_row(row)?; + Ok(ColumnInformation::new(row.1, row.2, !row.3)) } } diff --git a/diesel_cli/src/infer_schema_internals/information_schema.rs b/diesel_cli/src/infer_schema_internals/information_schema.rs index 91bee4395965..a3052967f79f 100644 --- a/diesel_cli/src/infer_schema_internals/information_schema.rs +++ b/diesel_cli/src/infer_schema_internals/information_schema.rs @@ -78,7 +78,8 @@ mod information_schema { table_schema -> VarChar, table_name -> VarChar, column_name -> VarChar, - is_nullable -> VarChar, + #[sql_name = "is_nullable"] + __is_nullable -> VarChar, ordinal_position -> BigInt, udt_name -> VarChar, column_type -> VarChar, @@ -135,7 +136,7 @@ where ( columns::column_name, ::TypeColumn, - columns::is_nullable, + columns::__is_nullable, ), >, Eq, @@ -154,7 +155,7 @@ where let type_column = Conn::Backend::type_column(); columns - .select((column_name, type_column, is_nullable)) + .select((column_name, type_column, __is_nullable)) .filter(table_name.eq(&table.sql_name)) .filter(table_schema.eq(schema_name)) .order(ordinal_position) diff --git a/diesel_cli/src/infer_schema_internals/sqlite.rs b/diesel_cli/src/infer_schema_internals/sqlite.rs index 765a903f38d4..80fef012598b 100644 --- a/diesel_cli/src/infer_schema_internals/sqlite.rs +++ b/diesel_cli/src/infer_schema_internals/sqlite.rs @@ -52,7 +52,7 @@ pub fn load_table_names( .select(name) .filter(name.not_like("\\_\\_%").escape('\\')) .filter(name.not_like("sqlite%")) - .filter(sql("type='table'")) + .filter(sql::("type='table'")) .order(name) .load::(connection)? .into_iter() diff --git a/diesel_cli/src/infer_schema_internals/table_data.rs b/diesel_cli/src/infer_schema_internals/table_data.rs index f6c6bad05bc4..a44dbf747170 100644 --- a/diesel_cli/src/infer_schema_internals/table_data.rs +++ b/diesel_cli/src/infer_schema_internals/table_data.rs @@ -1,5 +1,6 @@ use diesel::backend::Backend; -use diesel::deserialize::{FromSqlRow, Queryable}; +use diesel::deserialize::{FromSqlRow, StaticallySizedRow}; +use diesel::sql_types::Text; use std::fmt; use std::str::FromStr; @@ -53,18 +54,28 @@ impl TableName { } } -impl Queryable for TableName +impl FromSqlRow<(Text, Text), DB> for TableName where DB: Backend, - (String, String): FromSqlRow, + (String, String): FromSqlRow<(Text, Text), DB>, { - type Row = (String, String); + fn build_from_row<'a>( + row: &impl diesel::row::Row<'a, DB>, + ) -> diesel::deserialize::Result { + let (name, schema) = <(String, String) as FromSqlRow<_, DB>>::build_from_row(row)?; - fn build((name, schema): Self::Row) -> Self { - TableName::new(name, schema) + Ok(TableName::new(name, schema)) } } +impl StaticallySizedRow<(Text, Text), DB> for TableName +where + DB: Backend, + Self: FromSqlRow<(Text, Text), DB>, +{ + const FIELD_COUNT: usize = 2; +} + impl fmt::Display for TableName { fn fmt(&self, out: &mut fmt::Formatter) -> Result<(), fmt::Error> { match self.schema { diff --git a/diesel_compile_tests/tests/compile-fail/cannot_mix_aggregate_and_non_aggregate_selects.rs b/diesel_compile_tests/tests/compile-fail/cannot_mix_aggregate_and_non_aggregate_selects.rs index 360114dacb74..671fd4b90478 100644 --- a/diesel_compile_tests/tests/compile-fail/cannot_mix_aggregate_and_non_aggregate_selects.rs +++ b/diesel_compile_tests/tests/compile-fail/cannot_mix_aggregate_and_non_aggregate_selects.rs @@ -2,7 +2,7 @@ extern crate diesel; use diesel::*; -use diesel::dsl::count; +use diesel::dsl::count_star; table! { users { @@ -13,6 +13,6 @@ table! { fn main() { use self::users::dsl::*; - let source = users.select((id, count(users.star()))); + let source = users.select((id, count_star())); //~^ ERROR MixedAggregates } diff --git a/diesel_compile_tests/tests/compile-fail/filter_requires_bool_nonaggregate_expression.rs b/diesel_compile_tests/tests/compile-fail/filter_requires_bool_nonaggregate_expression.rs index b262cc08ea0a..83e7e8046293 100644 --- a/diesel_compile_tests/tests/compile-fail/filter_requires_bool_nonaggregate_expression.rs +++ b/diesel_compile_tests/tests/compile-fail/filter_requires_bool_nonaggregate_expression.rs @@ -14,7 +14,7 @@ fn main() { use diesel::dsl::sum; let _ = users::table.filter(users::name); - //~^ ERROR type mismatch resolving `::SqlType == diesel::sql_types::Bool` + //~^ ERROR the trait bound `diesel::sql_types::Text: diesel::sql_types::BoolOrNullableBool` is not satisfied let _ = users::table.filter(sum(users::id).eq(1)); //~^ ERROR MixedAggregates } diff --git a/diesel_compile_tests/tests/compile-fail/insert_statement_does_not_support_returning_methods_on_sqlite.rs b/diesel_compile_tests/tests/compile-fail/insert_statement_does_not_support_returning_methods_on_sqlite.rs index 3505a61c6ec3..8652c78ed032 100644 --- a/diesel_compile_tests/tests/compile-fail/insert_statement_does_not_support_returning_methods_on_sqlite.rs +++ b/diesel_compile_tests/tests/compile-fail/insert_statement_does_not_support_returning_methods_on_sqlite.rs @@ -1,10 +1,10 @@ #[macro_use] extern crate diesel; -use diesel::*; -use diesel::sqlite::SqliteConnection; use diesel::backend::Backend; use diesel::sql_types::{Integer, VarChar}; +use diesel::sqlite::SqliteConnection; +use diesel::*; table! { users { @@ -18,18 +18,19 @@ pub struct User { name: String, } -use diesel::deserialize::FromSqlRow; +use diesel::deserialize::{self, FromSqlRow}; +use diesel::row::Row; -impl Queryable<(Integer, VarChar), DB> for User where - (i32, String): FromSqlRow<(Integer, VarChar), DB>, +impl FromSqlRow for User +where + (i32, String): FromSqlRow, { - type Row = (i32, String); - - fn build(row: Self::Row) -> Self { - User { + fn build_from_row<'a>(row: &impl Row<'a, DB>) -> deserialize::Result { + let row = <(i32, String) as FromSqlRow>::build_from_row(row)?; + Ok(User { id: row.0, name: row.1, - } + }) } } diff --git a/diesel_compile_tests/tests/compile-fail/join_with_explicit_on_requires_valid_boolean_expression.rs b/diesel_compile_tests/tests/compile-fail/join_with_explicit_on_requires_valid_boolean_expression.rs index 509cfc64b5b9..949b33ada6f1 100644 --- a/diesel_compile_tests/tests/compile-fail/join_with_explicit_on_requires_valid_boolean_expression.rs +++ b/diesel_compile_tests/tests/compile-fail/join_with_explicit_on_requires_valid_boolean_expression.rs @@ -32,5 +32,5 @@ fn main() { //~^ ERROR E0271 // Invalid, type is not boolean let _ = users::table.inner_join(posts::table.on(users::id)); - //~^ ERROR E0271 + //~^ ERROR the trait bound `diesel::sql_types::Integer: diesel::sql_types::BoolOrNullableBool` is not satisfied [E0277] } diff --git a/diesel_compile_tests/tests/compile-fail/right_side_of_left_join_requires_nullable.rs b/diesel_compile_tests/tests/compile-fail/right_side_of_left_join_requires_nullable.rs index 7dc540cf90a9..0fd459b41718 100644 --- a/diesel_compile_tests/tests/compile-fail/right_side_of_left_join_requires_nullable.rs +++ b/diesel_compile_tests/tests/compile-fail/right_side_of_left_join_requires_nullable.rs @@ -52,7 +52,6 @@ fn direct_joins() { // Invalid, Nullable is selectable, but lower expects not-null let _ = join.select(lower(posts::title.nullable())); //~^ ERROR E0271 - //~| ERROR E0271 } fn nested_outer_joins_left_associative() { @@ -74,7 +73,6 @@ fn nested_outer_joins_left_associative() { // Invalid, Nullable<title> is selectable, but lower expects not-null let _ = join.select(lower(posts::title.nullable())); //~^ ERROR E0271 - //~| ERROR E0271 } fn nested_mixed_joins_left_associative() { @@ -95,7 +93,6 @@ fn nested_mixed_joins_left_associative() { // Invalid, Nullable<title> is selectable, but lower expects not-null let _ = join.select(lower(posts::title.nullable())); //~^ ERROR E0271 - //~| ERROR E0271 } fn nested_outer_joins_right_associative() { @@ -116,7 +113,6 @@ fn nested_outer_joins_right_associative() { // Invalid, Nullable<title> is selectable, but lower expects not-null let _ = join.select(lower(posts::title.nullable())); //~^ ERROR E0271 - //~| ERROR E0271 } fn nested_mixed_joins_right_associative() { @@ -137,5 +133,4 @@ fn nested_mixed_joins_right_associative() { // Invalid, Nullable<title> is selectable, but lower expects not-null let _ = join.select(lower(posts::title.nullable())); //~^ ERROR E0271 - //~| ERROR E0271 } diff --git a/diesel_compile_tests/tests/ui/queryable_by_name_requires_table_name_or_sql_type_annotation.stderr b/diesel_compile_tests/tests/ui/queryable_by_name_requires_table_name_or_sql_type_annotation.stderr index f800624d2f43..854455ef05b0 100644 --- a/diesel_compile_tests/tests/ui/queryable_by_name_requires_table_name_or_sql_type_annotation.stderr +++ b/diesel_compile_tests/tests/ui/queryable_by_name_requires_table_name_or_sql_type_annotation.stderr @@ -20,6 +20,12 @@ error: All fields of tuple structs must be annotated with `#[column_name]` 10 | struct Bar(i32, String); | ^^^ +error: All fields of tuple structs must be annotated with `#[column_name]` + --> $DIR/queryable_by_name_requires_table_name_or_sql_type_annotation.rs:10:17 + | +10 | struct Bar(i32, String); + | ^^^^^^ + error: Cannot determine the SQL type of field --> $DIR/queryable_by_name_requires_table_name_or_sql_type_annotation.rs:10:12 | @@ -28,12 +34,6 @@ error: Cannot determine the SQL type of field | = help: Your struct must either be annotated with `#[table_name = "foo"]` or have all of its fields annotated with `#[sql_type = "Integer"]` -error: All fields of tuple structs must be annotated with `#[column_name]` - --> $DIR/queryable_by_name_requires_table_name_or_sql_type_annotation.rs:10:17 - | -10 | struct Bar(i32, String); - | ^^^^^^ - error: Cannot determine the SQL type of field --> $DIR/queryable_by_name_requires_table_name_or_sql_type_annotation.rs:10:17 | diff --git a/diesel_derives/src/diesel_numeric_ops.rs b/diesel_derives/src/diesel_numeric_ops.rs index f0b60a369c95..b8e02199097c 100644 --- a/diesel_derives/src/diesel_numeric_ops.rs +++ b/diesel_derives/src/diesel_numeric_ops.rs @@ -22,10 +22,13 @@ pub fn derive(mut item: syn::DeriveInput) -> Result<proc_macro2::TokenStream, Di Ok(wrap_in_dummy_mod(quote! { use diesel::expression::{ops, Expression, AsExpression}; use diesel::sql_types::ops::{Add, Sub, Mul, Div}; + use diesel::sql_types::{SqlType, SingleValue}; impl #impl_generics ::std::ops::Add<__Rhs> for #struct_name #ty_generics #where_clause + Self: Expression, <Self as Expression>::SqlType: Add, + <<Self as Expression>::SqlType as Add>::Rhs: SqlType + SingleValue, __Rhs: AsExpression<<<Self as Expression>::SqlType as Add>::Rhs>, { type Output = ops::Add<Self, __Rhs::Expression>; @@ -37,7 +40,9 @@ pub fn derive(mut item: syn::DeriveInput) -> Result<proc_macro2::TokenStream, Di impl #impl_generics ::std::ops::Sub<__Rhs> for #struct_name #ty_generics #where_clause + Self: Expression, <Self as Expression>::SqlType: Sub, + <<Self as Expression>::SqlType as Sub>::Rhs: SqlType + SingleValue, __Rhs: AsExpression<<<Self as Expression>::SqlType as Sub>::Rhs>, { type Output = ops::Sub<Self, __Rhs::Expression>; @@ -49,7 +54,9 @@ pub fn derive(mut item: syn::DeriveInput) -> Result<proc_macro2::TokenStream, Di impl #impl_generics ::std::ops::Mul<__Rhs> for #struct_name #ty_generics #where_clause + Self: Expression, <Self as Expression>::SqlType: Mul, + <<Self as Expression>::SqlType as Mul>::Rhs: SqlType + SingleValue, __Rhs: AsExpression<<<Self as Expression>::SqlType as Mul>::Rhs>, { type Output = ops::Mul<Self, __Rhs::Expression>; @@ -61,7 +68,9 @@ pub fn derive(mut item: syn::DeriveInput) -> Result<proc_macro2::TokenStream, Di impl #impl_generics ::std::ops::Div<__Rhs> for #struct_name #ty_generics #where_clause + Self: Expression, <Self as Expression>::SqlType: Div, + <<Self as Expression>::SqlType as Div>::Rhs: SqlType + SingleValue, __Rhs: AsExpression<<<Self as Expression>::SqlType as Div>::Rhs>, { type Output = ops::Div<Self, __Rhs::Expression>; diff --git a/diesel_derives/src/from_sql_row.rs b/diesel_derives/src/from_sql_row.rs index 73fa7aabd6a6..3332d6835660 100644 --- a/diesel_derives/src/from_sql_row.rs +++ b/diesel_derives/src/from_sql_row.rs @@ -22,6 +22,9 @@ pub fn derive(mut item: syn::DeriveInput) -> Result<TokenStream, Diagnostic> { where_clause .predicates .push(parse_quote!(Self: FromSql<__ST, __DB>)); + where_clause + .predicates + .push(parse_quote!(__ST: diesel::sql_types::SqlType)); } let (impl_generics, _, where_clause) = item.generics.split_for_impl(); @@ -31,20 +34,18 @@ pub fn derive(mut item: syn::DeriveInput) -> Result<TokenStream, Diagnostic> { impl #impl_generics FromSqlRow<__ST, __DB> for #struct_ty #where_clause { - fn build_from_row<R: diesel::row::Row<__DB>>(row: &mut R) - -> deserialize::Result<Self> - { - FromSql::<__ST, __DB>::from_sql(row.take()) - } - } - impl #impl_generics Queryable<__ST, __DB> for #struct_ty - #where_clause - { - type Row = Self; + #[inline(always)] + fn build_from_row<'a>(row: &impl diesel::row::Row<'a, __DB>) + -> deserialize::Result<Self> + { + use diesel::row::Field; - fn build(row: Self::Row) -> Self { - row + FromSql::<__ST, __DB>::from_nullable_sql( + row.get(0) + .ok_or_else(|| Box::new(diesel::result::UnexpectedEndOfRow))? + .value() + ) } } diff --git a/diesel_derives/src/lib.rs b/diesel_derives/src/lib.rs index aae4f5355220..9383108cf9d4 100644 --- a/diesel_derives/src/lib.rs +++ b/diesel_derives/src/lib.rs @@ -171,7 +171,7 @@ pub fn derive_diesel_numeric_ops(input: TokenStream) -> TokenStream { expand_proc_macro(input, diesel_numeric_ops::derive) } -/// Implements `FromSqlRow` and `Queryable` +/// Implements `FromSqlRow` for single field values /// /// This derive is mostly useful to implement support deserializing /// into rust types not supported by diesel itself. @@ -307,7 +307,7 @@ pub fn derive_query_id(input: TokenStream) -> TokenStream { expand_proc_macro(input, query_id::derive) } -/// Implements `Queryable` +/// Implements `FromSqlRow` to load the result of statically typed queries /// /// This trait can only be derived for structs, not enums. /// @@ -331,12 +331,144 @@ pub fn derive_query_id(input: TokenStream) -> TokenStream { /// into the field type, the implementation will deserialize into `Type`. /// Then `Type` is converted via `.into()` into the field type. By default /// this derive will deserialize directly into the field type +/// +/// +/// # Examples +/// +/// If we just want to map a query to our struct, we can use `derive`. +/// +/// ```rust +/// # extern crate diesel; +/// # extern crate dotenv; +/// # include!("../../diesel/src/doctest_setup.rs"); +/// # +/// #[derive(Queryable, PartialEq, Debug)] +/// struct User { +/// id: i32, +/// name: String, +/// } +/// +/// # fn main() { +/// # run_test(); +/// # } +/// # +/// # fn run_test() -> QueryResult<()> { +/// # use schema::users::dsl::*; +/// # let connection = establish_connection(); +/// let first_user = users.first(&connection)?; +/// let expected = User { id: 1, name: "Sean".into() }; +/// assert_eq!(expected, first_user); +/// # Ok(()) +/// # } +/// ``` +/// +/// If we want to do additional work during deserialization, we can use +/// `deserialize_as` to use a different implementation. +/// +/// ```rust +/// # extern crate diesel; +/// # extern crate dotenv; +/// # include!("../../diesel/src/doctest_setup.rs"); +/// # +/// # use schema::users; +/// # use diesel::backend::{self, Backend}; +/// # use diesel::deserialize::{self, FromSqlRow}; +/// # use diesel::row::Row; +/// # +/// struct LowercaseString(String); +/// +/// impl Into<String> for LowercaseString { +/// fn into(self) -> String { +/// self.0 +/// } +/// } +/// +/// impl<DB, ST> FromSqlRow<ST, DB> for LowercaseString +/// where +/// DB: Backend, +/// String: FromSqlRow<ST, DB>, +/// { +/// +/// fn build_from_row<'a>(row: &impl Row<'a, DB>) -> deserialize::Result<Self> { +/// Ok(LowercaseString(String::build_from_row(row)?.to_lowercase())) +/// } +/// } +/// +/// #[derive(Queryable, PartialEq, Debug)] +/// struct User { +/// id: i32, +/// #[diesel(deserialize_as = "LowercaseString")] +/// name: String, +/// } +/// +/// # fn main() { +/// # run_test(); +/// # } +/// # +/// # fn run_test() -> QueryResult<()> { +/// # use schema::users::dsl::*; +/// # let connection = establish_connection(); +/// let first_user = users.first(&connection)?; +/// let expected = User { id: 1, name: "sean".into() }; +/// assert_eq!(expected, first_user); +/// # Ok(()) +/// # } +/// ``` +/// +/// Alternatively, we can implement the trait for our struct manually. +/// +/// ```rust +/// # extern crate diesel; +/// # extern crate dotenv; +/// # include!("../../diesel/src/doctest_setup.rs"); +/// # +/// use schema::users; +/// use diesel::deserialize::{self, FromSqlRow}; +/// use diesel::row::Row; +/// +/// # /* +/// type DB = diesel::sqlite::Sqlite; +/// # */ +/// +/// #[derive(PartialEq, Debug)] +/// struct User { +/// id: i32, +/// name: String, +/// } +/// +/// impl FromSqlRow<users::SqlType, DB> for User +/// where +/// (i32, String): FromSqlRow<users::SqlType, DB>, +/// { +/// fn build_from_row<'a>(row: &impl Row<'a, DB>) -> deserialize::Result<Self> { +/// let row = <(i32, String) as FromSqlRow<users::SqlType, DB>>::build_from_row(row)?; +/// Ok(User { +/// id: row.0, +/// name: row.1.to_lowercase(), +/// }) +/// } +/// } +/// +/// # fn main() { +/// # run_test(); +/// # } +/// # +/// # fn run_test() -> QueryResult<()> { +/// # use schema::users::dsl::*; +/// # let connection = establish_connection(); +/// let first_user = users.first(&connection)?; +/// let expected = User { id: 1, name: "sean".into() }; +/// assert_eq!(expected, first_user); +/// # Ok(()) +/// # } +/// ``` #[proc_macro_derive(Queryable, attributes(column_name, diesel))] pub fn derive_queryable(input: TokenStream) -> TokenStream { expand_proc_macro(input, queryable::derive) } -/// Implements `QueryableByName` +/// Implements `FromSqlRow` for untyped sql queries, such as that one generated +/// by `sql_query` /// /// To derive this trait, Diesel needs to know the SQL type of each field. You /// can do this by either annotating your struct with `#[table_name = @@ -388,6 +520,138 @@ pub fn derive_queryable(input: TokenStream) -> TokenStream { /// * `#[diesel(embed)]`, specifies that the current field maps not only /// single database column, but is a type that implements /// `QueryableByName` on it's own +/// +/// /// # Examples +/// +/// If we just want to map a query to our struct, we can use `derive`. +/// +/// ```rust +/// # extern crate diesel; +/// # extern crate dotenv; +/// # include!("../../diesel/src/doctest_setup.rs"); +/// # use schema::users; +/// # use diesel::sql_query; +/// # +/// #[derive(QueryableByName, PartialEq, Debug)] +/// #[table_name = "users"] +/// struct User { +/// id: i32, +/// name: String, +/// } +/// +/// # fn main() { +/// # run_test(); +/// # } +/// # +/// # fn run_test() -> QueryResult<()> { +/// # let connection = establish_connection(); +/// let first_user = sql_query("SELECT * FROM users ORDER BY id LIMIT 1") +/// .get_result(&connection)?; +/// let expected = User { id: 1, name: "Sean".into() }; +/// assert_eq!(expected, first_user); +/// # Ok(()) +/// # } +/// ``` +/// +/// If we want to do additional work during deserialization, we can use +/// `deserialize_as` to use a different implementation. +/// +/// ```rust +/// # extern crate diesel; +/// # extern crate dotenv; +/// # include!("../../diesel/src/doctest_setup.rs"); +/// # use diesel::sql_query; +/// # use schema::users; +/// # use diesel::backend::{self, Backend}; +/// # use diesel::deserialize::{self, FromSql}; +/// # +/// struct LowercaseString(String); +/// +/// impl Into<String> for LowercaseString { +/// fn into(self) -> String { +/// self.0 +/// } +/// } +/// +/// impl<DB, ST> FromSql<ST, DB> for LowercaseString +/// where +/// DB: Backend, +/// String: FromSql<ST, DB>, +/// { +/// fn from_sql(bytes: backend::RawValue<DB>) -> deserialize::Result<Self> { +/// String::from_sql(bytes) +/// .map(|s| LowercaseString(s.to_lowercase())) +/// } +/// } +/// +/// #[derive(QueryableByName, PartialEq, Debug)] +/// #[table_name = "users"] +/// struct User { +/// id: i32, +/// #[diesel(deserialize_as = "LowercaseString")] +/// name: String, +/// } +/// +/// # fn main() { +/// # run_test(); +/// # } +/// # +/// # fn run_test() -> QueryResult<()> { +/// # let connection = establish_connection(); +/// let first_user = sql_query("SELECT * FROM users ORDER BY id LIMIT 1") +/// .get_result(&connection)?; +/// let expected = User { id: 1, name: "sean".into() }; +/// assert_eq!(expected, first_user); +/// # Ok(()) +/// # } +/// ``` +/// +/// The custom derive generates impls similar to the follownig one +/// +/// ```rust +/// # extern crate diesel; +/// # extern crate dotenv; +/// # include!("../../diesel/src/doctest_setup.rs"); +/// # use schema::users; +/// # use diesel::sql_query; +/// # use diesel::deserialize::{self, FromSqlRow, FromSql}; +/// # use diesel::row::{Row, Field}; +/// # +/// #[derive(PartialEq, Debug)] +/// struct User { +/// id: i32, +/// name: String, +/// } +/// +/// impl<DB> FromSqlRow<diesel::expression::expression_types::Untyped, DB> for User +/// where +/// DB: diesel::backend::Backend, +/// i32: FromSql<diesel::dsl::SqlTypeOf<users::id>, DB>, +/// String: FromSql<diesel::dsl::SqlTypeOf<users::name>, DB>, +/// { +/// fn build_from_row<'a>(row: &impl Row<'a, DB>) -> deserialize::Result<Self> { +/// let id = row.get("id").ok_or("Column `id` was not present in query")?.value(); +/// let id = i32::from_nullable_sql(id)?; +/// +/// let name = row.get("name").ok_or("Column `name` was not present in query")?.value(); +/// let name = String::from_nullable_sql(name)?; +/// Ok(Self { id, name }) +/// } +/// } +/// +/// # fn main() { +/// # run_test(); +/// # } +/// # +/// # fn run_test() -> QueryResult<()> { +/// # let connection = establish_connection(); +/// let first_user = sql_query("SELECT * FROM users ORDER BY id LIMIT 1") +/// .get_result(&connection)?; +/// let expected = User { id: 1, name: "Sean".into() }; +/// assert_eq!(expected, first_user); +/// # Ok(()) +/// # } +/// ``` #[proc_macro_derive(QueryableByName, attributes(table_name, column_name, sql_type, diesel))] pub fn derive_queryable_by_name(input: TokenStream) -> TokenStream { expand_proc_macro(input, queryable_by_name::derive) diff --git a/diesel_derives/src/queryable.rs b/diesel_derives/src/queryable.rs index 077fe3698ac8..fcaf0c7226f1 100644 --- a/diesel_derives/src/queryable.rs +++ b/diesel_derives/src/queryable.rs @@ -19,35 +19,53 @@ pub fn derive(item: syn::DeriveInput) -> Result<proc_macro2::TokenStream, Diagno let i = syn::Index::from(i); f.name.assign(parse_quote!(row.#i.into())) }); + let sql_type = (0..model.fields().len()) + .map(|i| { + let i = syn::Ident::new(&format!("__ST{}", i), proc_macro2::Span::call_site()); + quote!(#i) + }) + .collect::<Vec<_>>(); + let sql_type = &sql_type; let (_, ty_generics, _) = item.generics.split_for_impl(); let mut generics = item.generics.clone(); generics .params .push(parse_quote!(__DB: diesel::backend::Backend)); - generics.params.push(parse_quote!(__ST)); + for id in 0..model.fields().len() { + let ident = syn::Ident::new(&format!("__ST{}", id), proc_macro2::Span::call_site()); + generics.params.push(parse_quote!(#ident)); + } { let where_clause = generics.where_clause.get_or_insert(parse_quote!(where)); where_clause .predicates - .push(parse_quote!((#(#field_ty,)*): Queryable<__ST, __DB>)); + .push(parse_quote!((#(#field_ty,)*): FromSqlRow<(#(#sql_type,)*), __DB>)); } let (impl_generics, _, where_clause) = generics.split_for_impl(); + let field_count = field_ty.len(); Ok(wrap_in_dummy_mod(quote! { - use diesel::deserialize::Queryable; + use diesel::deserialize::{FromSqlRow, Result, StaticallySizedRow}; + use diesel::row::{Row, Field}; - impl #impl_generics Queryable<__ST, __DB> for #struct_name #ty_generics - #where_clause + impl #impl_generics FromSqlRow<(#(#sql_type,)*), __DB> for #struct_name #ty_generics + #where_clause { - type Row = <(#(#field_ty,)*) as Queryable<__ST, __DB>>::Row; - - fn build(row: Self::Row) -> Self { - let row: (#(#field_ty,)*) = Queryable::build(row); - Self { + fn build_from_row<'__a>(row: &impl Row<'__a, __DB>) -> Result<Self> + { + let row = + <(#(#field_ty,)*) as FromSqlRow<(#(#sql_type,)*), __DB>>::build_from_row(row)?; + Result::Ok(Self { #(#build_expr,)* - } + }) } } + + impl #impl_generics StaticallySizedRow<((#(#sql_type, )*)), __DB> for #struct_name #ty_generics + #where_clause + { + const FIELD_COUNT:usize = #field_count; + } })) } diff --git a/diesel_derives/src/queryable_by_name.rs b/diesel_derives/src/queryable_by_name.rs index ef9a44f6fe2d..f6f232dd8ccc 100644 --- a/diesel_derives/src/queryable_by_name.rs +++ b/diesel_derives/src/queryable_by_name.rs @@ -9,11 +9,34 @@ pub fn derive(item: syn::DeriveInput) -> Result<proc_macro2::TokenStream, Diagno let model = Model::from_item(&item)?; let struct_name = &item.ident; - let field_expr = model + let fields = model.fields().iter().map(get_ident).collect::<Vec<_>>(); + let field_names = model.fields().iter().map(|f| &f.name).collect::<Vec<_>>(); + + let initial_field_expr = model .fields() .iter() - .map(|f| field_expr(f, &model)) - .collect::<Result<Vec<_>, _>>()?; + .map(|f| { + if f.has_flag("embed") { + let field_ty = &f.ty; + Ok(quote!(<#field_ty as FromSqlRow<Untyped, __DB>>::build_from_row( + row, + )?)) + } else { + let name = f.column_name(); + let field_ty = &f.ty; + let deserialize_ty = f.ty_for_deserialize()?; + Ok(quote!( + { + let field = row.get(stringify!(#name)) + .ok_or(concat!("Column `", stringify!(#name), "` was not present in query"))? + .value(); + let field = <#deserialize_ty as diesel::deserialize::FromSql<_, __DB>>::from_nullable_sql(field)?; + <#deserialize_ty as Into<#field_ty>>::into(field) + } + )) + } + }) + .collect::<Result<Vec<_>, Diagnostic>>()?; let (_, ty_generics, ..) = item.generics.split_for_impl(); let mut generics = item.generics.clone(); @@ -27,7 +50,7 @@ pub fn derive(item: syn::DeriveInput) -> Result<proc_macro2::TokenStream, Diagno if field.has_flag("embed") { where_clause .predicates - .push(parse_quote!(#field_ty: QueryableByName<__DB>)); + .push(parse_quote!(#field_ty: FromSqlRow<Untyped,__DB>)); } else { let st = sql_type(field, &model); where_clause @@ -39,34 +62,35 @@ pub fn derive(item: syn::DeriveInput) -> Result<proc_macro2::TokenStream, Diagno let (impl_generics, _, where_clause) = generics.split_for_impl(); Ok(wrap_in_dummy_mod(quote! { - use diesel::deserialize::{self, QueryableByName}; - use diesel::row::NamedRow; + use diesel::deserialize::{self, FromSqlRow}; + use diesel::row::{Row, Field}; + use diesel::sql_types::Untyped; - impl #impl_generics QueryableByName<__DB> + impl #impl_generics FromSqlRow<Untyped, __DB> for #struct_name #ty_generics #where_clause { - fn build<__R: NamedRow<__DB>>(row: &__R) -> deserialize::Result<Self> { - std::result::Result::Ok(Self { - #(#field_expr,)* + fn build_from_row<'__a>(row: &impl Row<'__a, __DB>) -> deserialize::Result<Self> + { + + + #( + let mut #fields = #initial_field_expr; + )* + deserialize::Result::Ok(Self { + #( + #field_names: #fields, + )* }) } } })) } -fn field_expr(field: &Field, model: &Model) -> Result<syn::FieldValue, Diagnostic> { - if field.has_flag("embed") { - Ok(field - .name - .assign(parse_quote!(QueryableByName::build(row)?))) - } else { - let column_name = field.column_name(); - let ty = field.ty_for_deserialize()?; - let st = sql_type(field, model); - Ok(field - .name - .assign(parse_quote!(row.get::<#st, #ty>(stringify!(#column_name))?.into()))) +fn get_ident(field: &Field) -> Ident { + match &field.name { + FieldName::Named(n) => n.clone(), + FieldName::Unnamed(i) => Ident::new(&format!("field_{}", i.index), Span::call_site()), } } diff --git a/diesel_derives/src/sql_function.rs b/diesel_derives/src/sql_function.rs index 7ffd525e1386..a64a50595fde 100644 --- a/diesel_derives/src/sql_function.rs +++ b/diesel_derives/src/sql_function.rs @@ -141,7 +141,7 @@ pub(crate) fn expand(input: SqlFunctionDecl) -> Result<TokenStream, Diagnostic> use diesel::sqlite::{Sqlite, SqliteConnection}; use diesel::serialize::ToSql; - use diesel::deserialize::{Queryable, StaticallySizedRow}; + use diesel::deserialize::{FromSqlRow, StaticallySizedRow}; use diesel::sqlite::SqliteAggregateFunction; use diesel::sql_types::IntoNullable; }; @@ -163,8 +163,8 @@ pub(crate) fn expand(input: SqlFunctionDecl) -> Result<TokenStream, Diagnostic> where A: SqliteAggregateFunction<(#(#arg_name,)*)> + Send + 'static, A::Output: ToSql<#return_type, Sqlite>, - (#(#arg_name,)*): Queryable<(#(#arg_type,)*), Sqlite>, - <(#(#arg_name,)*) as Queryable<(#(#arg_type,)*), Sqlite>>::Row: StaticallySizedRow<(#(#arg_type,)*), Sqlite>, + (#(#arg_name,)*): FromSqlRow<(#(#arg_type,)*), Sqlite> + + StaticallySizedRow<(#(#arg_type,)*), Sqlite>, { conn.register_aggregate_function::<(#(#arg_type,)*), #return_type, _, _, A>(#sql_name) } @@ -189,8 +189,8 @@ pub(crate) fn expand(input: SqlFunctionDecl) -> Result<TokenStream, Diagnostic> where A: SqliteAggregateFunction<#arg_name> + Send + 'static, A::Output: ToSql<#return_type, Sqlite>, - #arg_name: Queryable<#arg_type, Sqlite>, - <#arg_name as Queryable<#arg_type, Sqlite>>::Row: StaticallySizedRow<#arg_type, Sqlite>, + #arg_name: FromSqlRow<#arg_type, Sqlite> + + StaticallySizedRow<#arg_type, Sqlite>, { conn.register_aggregate_function::<#arg_type, #return_type, _, _, A>(#sql_name) } @@ -221,7 +221,7 @@ pub(crate) fn expand(input: SqlFunctionDecl) -> Result<TokenStream, Diagnostic> use diesel::sqlite::{Sqlite, SqliteConnection}; use diesel::serialize::ToSql; - use diesel::deserialize::{Queryable, StaticallySizedRow}; + use diesel::deserialize::{FromSqlRow, StaticallySizedRow}; #[allow(dead_code)] /// Registers an implementation for this function on the given connection @@ -237,8 +237,8 @@ pub(crate) fn expand(input: SqlFunctionDecl) -> Result<TokenStream, Diagnostic> ) -> QueryResult<()> where F: Fn(#(#arg_name,)*) -> Ret + Send + 'static, - (#(#arg_name,)*): Queryable<(#(#arg_type,)*), Sqlite>, - <(#(#arg_name,)*) as Queryable<(#(#arg_type,)*), Sqlite>>::Row: StaticallySizedRow<(#(#arg_type,)*), Sqlite>, + (#(#arg_name,)*): FromSqlRow<(#(#arg_type,)*), Sqlite> + + StaticallySizedRow<(#(#arg_type,)*), Sqlite>, Ret: ToSql<#return_type, Sqlite>, { conn.register_sql_function::<(#(#arg_type,)*), #return_type, _, _, _>( @@ -263,8 +263,8 @@ pub(crate) fn expand(input: SqlFunctionDecl) -> Result<TokenStream, Diagnostic> ) -> QueryResult<()> where F: FnMut(#(#arg_name,)*) -> Ret + Send + 'static, - (#(#arg_name,)*): Queryable<(#(#arg_type,)*), Sqlite>, - <(#(#arg_name,)*) as Queryable<(#(#arg_type,)*), Sqlite>>::Row: StaticallySizedRow<(#(#arg_type,)*), Sqlite>, + (#(#arg_name,)*): FromSqlRow<(#(#arg_type,)*), Sqlite> + + StaticallySizedRow<(#(#arg_type,)*), Sqlite>, Ret: ToSql<#return_type, Sqlite>, { conn.register_sql_function::<(#(#arg_type,)*), #return_type, _, _, _>( @@ -323,7 +323,7 @@ impl Parse for SqlFunctionDecl { let return_type = if Option::<Token![->]>::parse(input)?.is_some() { syn::Type::parse(input)? } else { - parse_quote!(()) + parse_quote!(diesel::expression::expression_types::NotSelectable) }; let _semi = Option::<Token![;]>::parse(input)?; diff --git a/diesel_derives/src/sql_type.rs b/diesel_derives/src/sql_type.rs index eceb78eaffd3..2f8ea7a2bc38 100644 --- a/diesel_derives/src/sql_type.rs +++ b/diesel_derives/src/sql_type.rs @@ -13,10 +13,11 @@ pub fn derive(item: syn::DeriveInput) -> Result<proc_macro2::TokenStream, Diagno let pg_tokens = pg_tokens(&item); Ok(wrap_in_dummy_mod(quote! { - impl #impl_generics diesel::sql_types::NotNull + impl #impl_generics diesel::sql_types::SqlType for #struct_name #ty_generics #where_clause { + type IsNull = diesel::sql_types::is_nullable::NotNull; } impl #impl_generics diesel::sql_types::SingleValue @@ -71,8 +72,8 @@ fn mysql_tokens(item: &syn::DeriveInput) -> Option<proc_macro2::TokenStream> { for diesel::mysql::Mysql #where_clause { - fn metadata(_: &()) -> std::option::Option<diesel::mysql::MysqlType> { - std::option::Option::Some(diesel::mysql::MysqlType::#ty) + fn metadata(_: &()) -> diesel::mysql::MysqlType { + diesel::mysql::MysqlType::#ty } } }) diff --git a/diesel_tests/tests/custom_types.rs b/diesel_tests/tests/custom_types.rs index 50eaa1caa624..915e81703486 100644 --- a/diesel_tests/tests/custom_types.rs +++ b/diesel_tests/tests/custom_types.rs @@ -37,8 +37,8 @@ impl ToSql<MyType, Pg> for MyEnum { } impl FromSql<MyType, Pg> for MyEnum { - fn from_sql(bytes: Option<PgValue<'_>>) -> deserialize::Result<Self> { - match not_none!(bytes).as_bytes() { + fn from_sql(bytes: PgValue<'_>) -> deserialize::Result<Self> { + match bytes.as_bytes() { b"foo" => Ok(MyEnum::Foo), b"bar" => Ok(MyEnum::Bar), _ => Err("Unrecognized enum variant".into()), diff --git a/diesel_tests/tests/expressions/date_and_time.rs b/diesel_tests/tests/expressions/date_and_time.rs index 52d5b40bfab8..fb8624f7e8c7 100644 --- a/diesel_tests/tests/expressions/date_and_time.rs +++ b/diesel_tests/tests/expressions/date_and_time.rs @@ -131,7 +131,7 @@ fn now_can_be_used_as_nullable() { let nullable_timestamp = sql::<Nullable<Timestamp>>("CURRENT_TIMESTAMP"); let result = select(nullable_timestamp.eq(now)).get_result(&connection()); - assert_eq!(Ok(true), result); + assert_eq!(Ok(Some(true)), result); } #[test] diff --git a/diesel_tests/tests/expressions/mod.rs b/diesel_tests/tests/expressions/mod.rs index 314c5690e431..6d6702b7df28 100644 --- a/diesel_tests/tests/expressions/mod.rs +++ b/diesel_tests/tests/expressions/mod.rs @@ -10,7 +10,9 @@ use crate::schema::{ }; use diesel::backend::Backend; use diesel::dsl::*; +use diesel::expression::TypedExpressionType; use diesel::query_builder::*; +use diesel::sql_types::SqlType; use diesel::*; #[test] @@ -152,7 +154,10 @@ struct Arbitrary<T> { _marker: PhantomData<T>, } -impl<T> Expression for Arbitrary<T> { +impl<T> Expression for Arbitrary<T> +where + T: SqlType + TypedExpressionType, +{ type SqlType = T; } @@ -165,9 +170,9 @@ where } } -impl<T, QS> SelectableExpression<QS> for Arbitrary<T> {} +impl<T, QS> SelectableExpression<QS> for Arbitrary<T> where Self: Expression {} -impl<T, QS> AppearsOnTable<QS> for Arbitrary<T> {} +impl<T, QS> AppearsOnTable<QS> for Arbitrary<T> where Self: Expression {} fn arbitrary<T>() -> Arbitrary<T> { Arbitrary { diff --git a/diesel_tests/tests/filter.rs b/diesel_tests/tests/filter.rs index 7b010e10dcd0..9f7d72164974 100644 --- a/diesel_tests/tests/filter.rs +++ b/diesel_tests/tests/filter.rs @@ -328,14 +328,14 @@ fn or_doesnt_mess_with_precedence_of_previous_statements() { let f = false.into_sql::<sql_types::Bool>(); let count = users .filter(f) - .filter(f.or(true)) + .filter(f.or(true.into_sql::<sql_types::Bool>())) .count() .first(&connection); assert_eq!(Ok(0), count); let count = users - .filter(f.or(f).and(f.or(true))) + .filter(f.or(f).and(f.or(true.into_sql::<sql_types::Bool>()))) .count() .first(&connection); diff --git a/diesel_tests/tests/types.rs b/diesel_tests/tests/types.rs index 8c528cc9bbe4..ebca960a45e9 100644 --- a/diesel_tests/tests/types.rs +++ b/diesel_tests/tests/types.rs @@ -5,6 +5,7 @@ extern crate bigdecimal; extern crate chrono; use crate::schema::*; +use diesel::deserialize::FromSqlRow; #[cfg(feature = "postgres")] use diesel::pg::Pg; use diesel::sql_types::*; @@ -144,12 +145,11 @@ fn boolean_from_sql() { } #[test] -#[cfg(feature = "postgres")] fn boolean_treats_null_as_false_when_predicates_return_null() { let connection = connection(); - let one = Some(1).into_sql::<Nullable<Integer>>(); + let one = Some(1).into_sql::<diesel::sql_types::Nullable<Integer>>(); let query = select(one.eq(None::<i32>)); - assert_eq!(Ok(false), query.first(&connection)); + assert_eq!(Ok(Option::<bool>::None), query.first(&connection)); } #[test] @@ -670,16 +670,16 @@ fn pg_specific_option_to_sql() { "'t'::bool", Some(true) )); - assert!(!query_to_sql_equality::<Nullable<Bool>, Option<bool>>( + assert!(query_to_sql_equality::<Nullable<Bool>, Option<bool>>( "'f'::bool", - Some(true) + Some(false) )); assert!(query_to_sql_equality::<Nullable<Bool>, Option<bool>>( "NULL", None )); - assert!(!query_to_sql_equality::<Nullable<Bool>, Option<bool>>( + assert!(query_to_sql_equality::<Nullable<Bool>, Option<bool>>( "NULL::bool", - Some(false) + None )); } @@ -1231,7 +1231,7 @@ fn third_party_crates_can_add_new_types() { } impl FromSql<MyInt, Pg> for i32 { - fn from_sql(bytes: Option<PgValue<'_>>) -> deserialize::Result<Self> { + fn from_sql(bytes: PgValue<'_>) -> deserialize::Result<Self> { FromSql::<Integer, Pg>::from_sql(bytes) } } @@ -1241,17 +1241,18 @@ fn third_party_crates_can_add_new_types() { assert_eq!(70_000, query_single_value::<MyInt, i32>("70000")); } -fn query_single_value<T, U: Queryable<T, TestBackend>>(sql_str: &str) -> U +fn query_single_value<T, U: FromSqlRow<T, TestBackend>>(sql_str: &str) -> U where TestBackend: HasSqlType<T>, - T: QueryId + SingleValue, + T: QueryId + SingleValue + SqlType, { use diesel::dsl::sql; let connection = connection(); select(sql::<T>(sql_str)).first(&connection).unwrap() } -use diesel::expression::{is_aggregate, AsExpression, ValidGrouping}; +use diesel::dsl::{And, AsExprOf, Eq, IsNull}; +use diesel::expression::{is_aggregate, AsExpression, SqlLiteral, ValidGrouping}; use diesel::query_builder::{QueryFragment, QueryId}; use std::fmt::Debug; @@ -1261,7 +1262,16 @@ where U::Expression: SelectableExpression<(), SqlType = T> + ValidGrouping<(), IsAggregate = is_aggregate::Never>, U::Expression: QueryFragment<TestBackend> + QueryId, - T: QueryId + SingleValue, + T: QueryId + SingleValue + SqlType, + T::IsNull: OneIsNullable<T::IsNull, Out = T::IsNull>, + T::IsNull: MaybeNullableType<Bool>, + <T::IsNull as MaybeNullableType<Bool>>::Out: SqlType, + diesel::sql_types::is_nullable::NotNull: diesel::sql_types::AllAreNullable< + <<T::IsNull as MaybeNullableType<Bool>>::Out as SqlType>::IsNull, + Out = diesel::sql_types::is_nullable::NotNull, + >, + Eq<SqlLiteral<T>, U>: Expression<SqlType = <T::IsNull as MaybeNullableType<Bool>>::Out>, + And<IsNull<SqlLiteral<T>>, IsNull<AsExprOf<U, T>>>: Expression<SqlType = Bool>, { use diesel::dsl::sql; let connection = connection(); @@ -1272,7 +1282,7 @@ where .or(sql::<T>(sql_str).eq(value.clone())), ); query - .get_result(&connection) + .get_result::<bool>(&connection) .expect(&format!("Error comparing {}, {:?}", sql_str, value)) } diff --git a/diesel_tests/tests/types_roundtrip.rs b/diesel_tests/tests/types_roundtrip.rs index ea2577e03ebd..67818efb7464 100644 --- a/diesel_tests/tests/types_roundtrip.rs +++ b/diesel_tests/tests/types_roundtrip.rs @@ -9,10 +9,11 @@ pub use crate::schema::{connection_without_transaction, TestConnection}; pub use diesel::data_types::*; pub use diesel::result::Error; pub use diesel::serialize::ToSql; -pub use diesel::sql_types::HasSqlType; +pub use diesel::sql_types::{HasSqlType, SingleValue, SqlType}; pub use diesel::*; -use diesel::expression::{AsExpression, NonAggregate}; +use deserialize::FromSqlRow; +use diesel::expression::{AsExpression, NonAggregate, TypedExpressionType}; use diesel::query_builder::{QueryFragment, QueryId}; #[cfg(feature = "postgres")] use std::collections::Bound; @@ -23,10 +24,10 @@ thread_local! { pub fn test_type_round_trips<ST, T>(value: T) -> bool where - ST: QueryId, + ST: QueryId + SqlType + TypedExpressionType + SingleValue, <TestConnection as Connection>::Backend: HasSqlType<ST>, T: AsExpression<ST> - + Queryable<ST, <TestConnection as Connection>::Backend> + + FromSqlRow<ST, <TestConnection as Connection>::Backend> + PartialEq + Clone + ::std::fmt::Debug, diff --git a/examples/postgres/advanced-blog-cli/src/post.rs b/examples/postgres/advanced-blog-cli/src/post.rs index 397a3a68dfb0..a5c0d7f010e1 100644 --- a/examples/postgres/advanced-blog-cli/src/post.rs +++ b/examples/postgres/advanced-blog-cli/src/post.rs @@ -22,17 +22,18 @@ pub enum Status { Published { at: NaiveDateTime }, } -use diesel::deserialize::Queryable; +use diesel::deserialize::{self, FromSqlRow}; use diesel::pg::Pg; +use diesel::row::Row; use diesel::sql_types::{Nullable, Timestamp}; -impl Queryable<Nullable<Timestamp>, Pg> for Status { - type Row = Option<NaiveDateTime>; - - fn build(row: Self::Row) -> Self { +impl FromSqlRow<Nullable<Timestamp>, Pg> for Status { + fn build_from_row<'a>(row: &impl Row<'a, Pg>) -> deserialize::Result<Self> { + let row: Option<NaiveDateTime> = + <_ as FromSqlRow<Nullable<Timestamp>, Pg>>::build_from_row(row)?; match row { - Some(at) => Status::Published { at }, - None => Status::Draft, + Some(at) => Ok(Status::Published { at }), + None => Ok(Status::Draft), } } } diff --git a/examples/postgres/custom_types/src/main.rs b/examples/postgres/custom_types/src/main.rs index 584fd50fb857..48504dd7bd47 100644 --- a/examples/postgres/custom_types/src/main.rs +++ b/examples/postgres/custom_types/src/main.rs @@ -1,4 +1,3 @@ -#[macro_use] extern crate diesel; use diesel::prelude::*; diff --git a/examples/postgres/custom_types/src/model.rs b/examples/postgres/custom_types/src/model.rs index 2acc278faef2..84dbd50d3b9e 100644 --- a/examples/postgres/custom_types/src/model.rs +++ b/examples/postgres/custom_types/src/model.rs @@ -33,8 +33,8 @@ impl ToSql<LanguageType, Pg> for Language { } impl FromSql<LanguageType, Pg> for Language { - fn from_sql(bytes: Option<PgValue>) -> deserialize::Result<Self> { - match not_none!(bytes).as_bytes() { + fn from_sql(bytes: PgValue) -> deserialize::Result<Self> { + match bytes.as_bytes() { b"en" => Ok(Language::En), b"ru" => Ok(Language::Ru), b"de" => Ok(Language::De),