From 1b1256b727a7b355ae3fc30589b0728b850e4e4f Mon Sep 17 00:00:00 2001 From: Ariel Mashraki Date: Sun, 25 Jun 2023 23:32:43 +0300 Subject: [PATCH] sql/postgres: scan an marshal enum as top-level objects --- sql/postgres/inspect.go | 7 ++-- sql/postgres/inspect_test.go | 2 +- sql/postgres/sqlspec.go | 66 ++++++++++++++---------------------- sql/postgres/sqlspec_test.go | 29 +++++++++------- sql/schema/schema.go | 10 ++++++ 5 files changed, 56 insertions(+), 58 deletions(-) diff --git a/sql/postgres/inspect.go b/sql/postgres/inspect.go index af10e78f6fc..492f6891d9d 100644 --- a/sql/postgres/inspect.go +++ b/sql/postgres/inspect.go @@ -197,10 +197,7 @@ func (i *inspect) columns(ctx context.Context, s *schema.Schema, scope queryScop return fmt.Errorf("postgres: %w", err) } } - if err := rows.Close(); err != nil { - return err - } - return nil + return rows.Close() } // addColumn scans the current row and adds a new column from it to the scope (table or view). @@ -1226,6 +1223,8 @@ FROM JOIN pg_namespace n ON t.typnamespace = n.oid WHERE n.nspname IN (%s) +ORDER BY + n.nspname, e.enumtypid, e.enumsortorder ` // Query to list foreign-keys. fksQuery = ` diff --git a/sql/postgres/inspect_test.go b/sql/postgres/inspect_test.go index e16a230859c..4f358845d52 100644 --- a/sql/postgres/inspect_test.go +++ b/sql/postgres/inspect_test.go @@ -5,12 +5,12 @@ package postgres import ( - "ariga.io/atlas/sql/internal/sqlx" "context" "fmt" "testing" "ariga.io/atlas/sql/internal/sqltest" + "ariga.io/atlas/sql/internal/sqlx" "ariga.io/atlas/sql/migrate" "ariga.io/atlas/sql/schema" diff --git a/sql/postgres/sqlspec.go b/sql/postgres/sqlspec.go index a3a6d6a55f1..7987a7dca43 100644 --- a/sql/postgres/sqlspec.go +++ b/sql/postgres/sqlspec.go @@ -427,16 +427,26 @@ func convertColumnType(spec *sqlspec.Column) (schema.Type, error) { // convertEnums converts possibly referenced column types (like enums) to // an actual schema.Type and sets it on the correct schema.Column. func convertEnums(tables []*sqlspec.Table, enums []*Enum, r *schema.Realm) error { - var ( - used = make(map[*Enum]struct{}) - byName = make(map[string]*Enum) - ) + byName := make(map[string]*schema.EnumType) for _, e := range enums { - byName[e.Name] = e + if byName[e.Name] != nil { + return fmt.Errorf("duplicate enum %q", e.Name) + } + ns, err := specutil.SchemaName(e.Schema) + if err != nil { + return fmt.Errorf("extract schema name from enum reference: %w", err) + } + es, ok := r.Schema(ns) + if !ok { + return fmt.Errorf("schema %q defined on enum %q was not found in realm", ns, e.Name) + } + e1 := &schema.EnumType{T: e.Name, Schema: es, Values: e.Values} + es.Objects = append(es.Objects, e1) + byName[e.Name] = e1 } for _, t := range tables { for _, c := range t.Columns { - var enum *Enum + var enum *schema.EnumType switch { case c.Type.IsRef: n, err := enumName(c.Type) @@ -445,7 +455,7 @@ func convertEnums(tables []*sqlspec.Table, enums []*Enum, r *schema.Realm) error } e, ok := byName[n] if !ok { - return fmt.Errorf("enum %q was not found", n) + return fmt.Errorf("enum %q was not found in realm", n) } enum = e default: @@ -455,15 +465,6 @@ func convertEnums(tables []*sqlspec.Table, enums []*Enum, r *schema.Realm) error } enum = byName[n] } - used[enum] = struct{}{} - schemaE, err := specutil.SchemaName(enum.Schema) - if err != nil { - return fmt.Errorf("extract schema name from enum reference: %w", err) - } - es, ok := r.Schema(schemaE) - if !ok { - return fmt.Errorf("schema %q not found in realm for table %q", schemaE, t.Name) - } schemaT, err := specutil.SchemaName(t.Schema) if err != nil { return fmt.Errorf("extract schema name from table reference: %w", err) @@ -480,20 +481,14 @@ func convertEnums(tables []*sqlspec.Table, enums []*Enum, r *schema.Realm) error if !ok { return fmt.Errorf("column %q not found in table %q", c.Name, t.Name) } - e := &schema.EnumType{T: enum.Name, Schema: es, Values: enum.Values} switch t := cc.Type.Type.(type) { case *ArrayType: - t.Type = e + t.Type = enum default: - cc.Type.Type = e + cc.Type.Type = enum } } } - for _, e := range enums { - if _, ok := used[e]; !ok { - return fmt.Errorf("enum %q declared but not used", e.Name) - } - } return nil } @@ -514,8 +509,8 @@ func enumRef(n string) *schemahcl.Ref { } // schemaSpec converts from a concrete Postgres schema to Atlas specification. -func schemaSpec(schem *schema.Schema) (*doc, error) { - spec, err := specutil.FromSchema(schem, tableSpec, viewSpec) +func schemaSpec(s *schema.Schema) (*doc, error) { + spec, err := specutil.FromSchema(s, tableSpec, viewSpec) if err != nil { return nil, err } @@ -523,26 +518,15 @@ func schemaSpec(schem *schema.Schema) (*doc, error) { Tables: spec.Tables, Views: spec.Views, Schemas: []*sqlspec.Schema{spec.Schema}, + Enums: make([]*Enum, 0, len(s.Objects)), } - enums := make(map[string]bool) - mayAdd := func(c *schema.Column) { - if e, ok := hasEnumType(c); ok && !enums[e.T] { + for _, o := range s.Objects { + if e, ok := o.(*schema.EnumType); ok { d.Enums = append(d.Enums, &Enum{ Name: e.T, - Schema: specutil.SchemaRef(spec.Schema.Name), Values: e.Values, + Schema: specutil.SchemaRef(spec.Schema.Name), }) - enums[e.T] = true - } - } - for _, t := range schem.Tables { - for _, c := range t.Columns { - mayAdd(c) - } - } - for _, t := range schem.Views { - for _, c := range t.Columns { - mayAdd(c) } } return d, nil diff --git a/sql/postgres/sqlspec_test.go b/sql/postgres/sqlspec_test.go index 5c0008ba7b0..fa8c97225f0 100644 --- a/sql/postgres/sqlspec_test.go +++ b/sql/postgres/sqlspec_test.go @@ -102,9 +102,8 @@ enum "account_type" { var s schema.Schema err := EvalHCLBytes([]byte(f), &s, nil) require.NoError(t, err) - exp := &schema.Schema{ - Name: "schema", - } + exp := schema.New("schema") + exp.AddObjects(&schema.EnumType{T: "account_type", Values: []string{"private", "business"}, Schema: exp}) exp.Tables = []*schema.Table{ { Name: "table", @@ -1118,7 +1117,18 @@ table "users" { } func TestMarshalSpec_Enum(t *testing.T) { + stateE := &schema.EnumType{ + T: "state", + Values: []string{"on", "off"}, + } + typeE := &schema.EnumType{ + T: "account_type", + Values: []string{"private", "business"}, + } s := schema.New("test"). + AddObjects( + typeE, stateE, + ). AddTables( schema.NewTable("account"). AddColumns( @@ -1128,19 +1138,14 @@ func TestMarshalSpec_Enum(t *testing.T) { ), schema.NewColumn("account_states"). SetType(&ArrayType{ - T: "states[]", - Type: &schema.EnumType{ - T: "state", - Values: []string{"on", "off"}, - }, + T: "states[]", + Type: stateE, }), ), schema.NewTable("table2"). AddColumns( - schema.NewEnumColumn("account_type", - schema.EnumName("account_type"), - schema.EnumValues("private", "business"), - ), + schema.NewColumn("account_type"). + SetType(typeE), ), ) buf, err := MarshalSpec(s, hclState) diff --git a/sql/schema/schema.go b/sql/schema/schema.go index 04e30033cbf..42fa4234945 100644 --- a/sql/schema/schema.go +++ b/sql/schema/schema.go @@ -134,6 +134,16 @@ func (s *Schema) View(name string) (*View, bool) { return nil, false } +// Object returns the first object that matched the given predicate. +func (s *Schema) Object(f func(Object) bool) (Object, bool) { + for _, o := range s.Objects { + if f(o) { + return o, true + } + } + return nil, false +} + // Column returns the first column that matched the given name. func (t *Table) Column(name string) (*Column, bool) { for _, c := range t.Columns {