diff --git a/src/postgres/def/types.rs b/src/postgres/def/types.rs index ae5a0c04..e2057f62 100644 --- a/src/postgres/def/types.rs +++ b/src/postgres/def/types.rs @@ -91,7 +91,7 @@ pub enum Type { Bit(BitAttr), // Text search types - /// A sorted list of distincp lexemes which are words that have been normalized to merge different + /// A sorted list of distinct lexemes which are words that have been normalized to merge different /// variants of the same word TsVector, /// A list of lexemes that are to be searched for, and can be combined using Boolean operators AND, @@ -101,7 +101,7 @@ pub enum Type { /// A universally unique identifier as defined by RFC 4122, ISO 9834-8:2005, and related standards Uuid, - /// XML data checked for well-formedness and with additonal support functions + /// XML data checked for well-formedness and with additional support functions Xml, /// JSON data checked for validity and with additional functions @@ -138,6 +138,8 @@ pub enum Type { PgLsn, // TODO: Pseudo-types Unknown(String), + /// Defines an PostgreSQL + Enum(EnumDef), } impl Type { @@ -166,7 +168,6 @@ impl Type { "time with time zone" => Type::TimeWithTimeZone(TimeAttr::default()), "interval" => Type::Interval(IntervalAttr::default()), "boolean" => Type::Boolean, - // "" => Type::Enum, "point" => Type::Point, "line" => Type::Line, "lseg" => Type::Lseg, @@ -194,6 +195,7 @@ impl Type { "daterange" => Type::DateRange, // "" => Type::Domain, "pg_lsn" => Type::PgLsn, + "user-defined" => Type::Enum(EnumDef::default()), _ => Type::Unknown(name.to_owned()), } @@ -237,6 +239,16 @@ pub struct BitAttr { pub length: Option, } +/// Defines an enum for the PostgreSQL module +#[derive(Clone, Debug, PartialEq, Default)] +#[cfg_attr(feature = "with-serde", derive(Serialize, Deserialize))] +pub struct EnumDef { + /// Holds the fields of the `ENUM` + pub values: Vec, + /// Defines the name of the PostgreSQL enum identifier + pub typename: String, +} + impl Type { pub fn has_numeric_attr(&self) -> bool { matches!(self, Type::Numeric(_) | Type::Decimal(_)) @@ -263,4 +275,8 @@ impl Type { pub fn has_bit_attr(&self) -> bool { matches!(self, Type::Bit(_)) } + + pub fn has_enum_attr(&self) -> bool { + matches!(self, Type::Enum(_)) + } } diff --git a/src/postgres/discovery/mod.rs b/src/postgres/discovery/mod.rs index e64d643c..f8c5b622 100644 --- a/src/postgres/discovery/mod.rs +++ b/src/postgres/discovery/mod.rs @@ -4,10 +4,12 @@ use crate::debug_print; use crate::postgres::def::*; use crate::postgres::parser::parse_table_constraint_query_results; use crate::postgres::query::{ - ColumnQueryResult, SchemaQueryBuilder, TableConstraintsQueryResult, TableQueryResult, + ColumnQueryResult, EnumQueryResult, SchemaQueryBuilder, TableConstraintsQueryResult, + TableQueryResult, }; use futures::future; use sea_query::{Alias, Iden, IntoIden, SeaRc}; +use std::collections::HashMap; mod executor; pub use executor::*; @@ -30,12 +32,12 @@ impl SchemaDiscovery { } } - pub async fn discover(mut self) -> Schema { + pub async fn discover(&self) -> Schema { let tables = self.discover_tables().await; let tables = future::join_all( tables .into_iter() - .map(|t| (&self, t)) + .map(|t| (self, t)) .map(Self::discover_table_static), ) .await; @@ -46,7 +48,7 @@ impl SchemaDiscovery { } } - pub async fn discover_tables(&mut self) -> Vec { + pub async fn discover_tables(&self) -> Vec { let rows = self .executor .fetch_all(self.query.query_tables(self.schema.clone())) @@ -176,4 +178,37 @@ impl SchemaDiscovery { constraints } + + pub async fn discover_enums(&self) -> Vec { + let rows = self.executor.fetch_all(self.query.query_enums()).await; + + let enum_rows: Vec = rows + .iter() + .map(|row| { + let result: EnumQueryResult = row.into(); + debug_print!("{:?}", result); + return result; + }) + .collect(); + + let map = enum_rows.into_iter().fold( + HashMap::new(), + |mut map: HashMap>, + EnumQueryResult { + typename, + enumlabel, + }| { + if let Some(entry_exists) = map.get_mut(&typename) { + entry_exists.push(enumlabel); + } else { + map.insert(typename, vec![enumlabel]); + } + map + }, + ); + + map.into_iter() + .map(|(typename, values)| EnumDef { values, typename }) + .collect() + } } diff --git a/src/postgres/parser/column.rs b/src/postgres/parser/column.rs index 33e9ec4b..61e8ac1f 100644 --- a/src/postgres/parser/column.rs +++ b/src/postgres/parser/column.rs @@ -41,6 +41,9 @@ pub fn parse_column_type(result: &ColumnQueryResult) -> ColumnType { if ctype.has_bit_attr() { ctype = parse_bit_attributes(result.character_maximum_length, ctype); } + if ctype.has_enum_attr() { + ctype = parse_enum_attributes(result.udt_name.as_ref(), ctype); + } ctype } @@ -165,3 +168,17 @@ pub fn parse_bit_attributes( ctype } + +pub fn parse_enum_attributes(udt_name: Option<&String>, mut ctype: ColumnType) -> ColumnType { + match ctype { + Type::Enum(ref mut def) => { + def.typename = match udt_name { + None => panic!("parse_enum_attributes(_) received an empty udt_name"), + Some(typename) => typename.to_string(), + }; + } + _ => panic!("parse_enum_attributes(_) received a type that does not have EnumDef"), + }; + + ctype +} diff --git a/src/postgres/query/column.rs b/src/postgres/query/column.rs index e1e701a7..998fc0d6 100644 --- a/src/postgres/query/column.rs +++ b/src/postgres/query/column.rs @@ -64,6 +64,8 @@ pub struct ColumnQueryResult { pub interval_type: Option, pub interval_precision: Option, + + pub udt_name: Option, } impl SchemaQueryBuilder { @@ -88,6 +90,7 @@ impl SchemaQueryBuilder { ColumnsField::DatetimePrecision, ColumnsField::IntervalType, ColumnsField::IntervalPrecision, + ColumnsField::UdtName, ]) .from((InformationSchema::Schema, InformationSchema::Columns)) .and_where(Expr::col(ColumnsField::TableSchema).eq(schema.to_string())) @@ -115,6 +118,7 @@ impl From<&PgRow> for ColumnQueryResult { datetime_precision: row.get(11), interval_type: row.get(12), interval_precision: row.get(13), + udt_name: row.get(14), } } } diff --git a/src/postgres/query/enumeration.rs b/src/postgres/query/enumeration.rs new file mode 100644 index 00000000..a1269c41 --- /dev/null +++ b/src/postgres/query/enumeration.rs @@ -0,0 +1,63 @@ +use super::SchemaQueryBuilder; +use crate::sqlx_types::postgres::PgRow; +use sea_query::{Expr, Order, Query, SelectStatement}; + +#[derive(Debug, sea_query::Iden)] +pub enum PgType { + #[iden = "pg_type"] + Table, + #[iden = "typname"] + TypeName, + #[iden = "oid"] + Oid, +} + +#[derive(Debug, sea_query::Iden)] +pub enum PgEnum { + #[iden = "pg_enum"] + Table, + #[iden = "enumlabel"] + EnumLabel, + #[iden = "enumtypid"] + EnumTypeId, +} + +#[derive(Debug, Default)] +pub struct EnumQueryResult { + pub typename: String, + pub enumlabel: String, +} + +impl SchemaQueryBuilder { + pub fn query_enums(&self) -> SelectStatement { + Query::select() + .column((PgType::Table, PgType::TypeName)) + .column((PgEnum::Table, PgEnum::EnumLabel)) + .from(PgType::Table) + .inner_join( + PgEnum::Table, + Expr::tbl(PgEnum::Table, PgEnum::EnumTypeId).equals(PgType::Table, PgType::Oid), + ) + .order_by((PgType::Table, PgType::TypeName), Order::Asc) + .order_by((PgEnum::Table, PgEnum::EnumLabel), Order::Asc) + .take() + } +} + +#[cfg(feature = "sqlx-postgres")] +impl From<&PgRow> for EnumQueryResult { + fn from(row: &PgRow) -> Self { + use crate::sqlx_types::Row; + Self { + typename: row.get(0), + enumlabel: row.get(1), + } + } +} + +#[cfg(not(feature = "sqlx-postgres"))] +impl From<&PgRow> for EnumQueryResult { + fn from(row: &PgRow) -> Self { + Self::default() + } +} diff --git a/src/postgres/query/mod.rs b/src/postgres/query/mod.rs index 78935ef0..f66212ae 100644 --- a/src/postgres/query/mod.rs +++ b/src/postgres/query/mod.rs @@ -1,11 +1,13 @@ pub mod char_set; pub mod column; pub mod constraints; +pub mod enumeration; pub mod schema; pub mod table; pub use char_set::*; pub use column::*; pub use constraints::*; +pub use enumeration::*; pub use schema::*; pub use table::*; diff --git a/src/postgres/writer/column.rs b/src/postgres/writer/column.rs index 8b46dd55..df53c219 100644 --- a/src/postgres/writer/column.rs +++ b/src/postgres/writer/column.rs @@ -156,7 +156,7 @@ impl ColumnInfo { col_def.custom(Alias::new("path")); } Type::Polygon => { - col_def.custom(Alias::new("ploygon")); + col_def.custom(Alias::new("polygon")); } Type::Circle => { col_def.custom(Alias::new("circle")); @@ -227,6 +227,9 @@ impl ColumnInfo { Type::Unknown(s) => { col_def.custom(Alias::new(s)); } + Type::Enum(def) => { + col_def.custom(Alias::new(def.typename.as_str())); + } }; col_def } diff --git a/src/postgres/writer/enumeration.rs b/src/postgres/writer/enumeration.rs new file mode 100644 index 00000000..939cf895 --- /dev/null +++ b/src/postgres/writer/enumeration.rs @@ -0,0 +1,15 @@ +use crate::postgres::def::EnumDef; +use sea_query::{ + extension::postgres::{Type, TypeCreateStatement}, + Alias, +}; + +impl EnumDef { + /// Converts the [EnumDef] to a [TypeCreateStatement] + pub fn write(&self) -> TypeCreateStatement { + Type::create() + .as_enum(Alias::new(self.typename.as_str())) + .values(self.values.iter().map(|val| Alias::new(val.as_str()))) + .to_owned() + } +} diff --git a/src/postgres/writer/mod.rs b/src/postgres/writer/mod.rs index 818d9dfc..732480f4 100644 --- a/src/postgres/writer/mod.rs +++ b/src/postgres/writer/mod.rs @@ -1,11 +1,13 @@ mod column; mod constraints; +mod enumeration; mod schema; mod table; mod types; pub use column::*; pub use constraints::*; +pub use enumeration::*; pub use schema::*; pub use table::*; pub use types::*; diff --git a/tests/live/postgres/src/main.rs b/tests/live/postgres/src/main.rs index 724f9506..cfe2a4a3 100644 --- a/tests/live/postgres/src/main.rs +++ b/tests/live/postgres/src/main.rs @@ -2,8 +2,8 @@ use std::collections::HashMap; use sea_schema::postgres::{def::TableDef, discovery::SchemaDiscovery}; use sea_schema::sea_query::{ - Alias, ColumnDef, ForeignKey, ForeignKeyAction, Index, PostgresQueryBuilder, Table, - TableCreateStatement, + extension::postgres::Type, Alias, ColumnDef, ForeignKey, ForeignKeyAction, Index, + PostgresQueryBuilder, Table, TableCreateStatement, }; use sqlx::{PgPool, Pool, Postgres}; @@ -13,6 +13,22 @@ async fn main() { let connection = setup("postgres://sea:sea@localhost", "sea-schema").await; let mut executor = connection.acquire().await.unwrap(); + let create_enum_stmt = Type::create() + .as_enum(Alias::new("crazy_enum")) + .values(vec![ + Alias::new("Astro0%00%8987,.!@#$%^&*()_-+=[]{}\\|.<>/? ``"), + Alias::new("Biology"), + Alias::new("Chemistry"), + Alias::new("Math"), + Alias::new("Physics"), + ]) + .to_string(PostgresQueryBuilder); + + sqlx::query(&create_enum_stmt) + .execute(&mut executor) + .await + .unwrap(); + let tbl_create_stmts = vec![ create_bakery_table(), create_baker_table(), @@ -53,6 +69,20 @@ async fn main() { println!(); assert_eq!(expected_sql, sql); } + + let enum_defs = schema_discovery.discover_enums().await; + + dbg!(&enum_defs); + + let enum_create_statements: Vec = enum_defs + .into_iter() + .map(|enum_def| enum_def.write().to_string(PostgresQueryBuilder)) + .collect(); + + dbg!(&create_enum_stmt); + dbg!(&enum_create_statements); + + assert_eq!(create_enum_stmt, enum_create_statements[0]); } async fn setup(base_url: &str, db_name: &str) -> Pool { @@ -90,6 +120,7 @@ fn create_bakery_table() -> TableCreateStatement { ) .col(ColumnDef::new(Alias::new("name")).string()) .col(ColumnDef::new(Alias::new("profit_margin")).double()) + .col(ColumnDef::new(Alias::new("crazy_enum_col")).custom(Alias::new("crazy_enum"))) .primary_key( Index::create() .primary()