Skip to content

Commit

Permalink
sql/postgres: scan an marshal enum as top-level objects
Browse files Browse the repository at this point in the history
  • Loading branch information
a8m committed Jun 25, 2023
1 parent d1071c7 commit 1b1256b
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 58 deletions.
7 changes: 3 additions & 4 deletions sql/postgres/inspect.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,10 +197,7 @@ func (i *inspect) columns(ctx context.Context, s *schema.Schema, scope queryScop
return fmt.Errorf("postgres: %w", err)
}
}
if err := rows.Close(); err != nil {
return err
}
return nil
return rows.Close()
}

// addColumn scans the current row and adds a new column from it to the scope (table or view).
Expand Down Expand Up @@ -1226,6 +1223,8 @@ FROM
JOIN pg_namespace n ON t.typnamespace = n.oid
WHERE
n.nspname IN (%s)
ORDER BY
n.nspname, e.enumtypid, e.enumsortorder
`
// Query to list foreign-keys.
fksQuery = `
Expand Down
2 changes: 1 addition & 1 deletion sql/postgres/inspect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
package postgres

import (
"ariga.io/atlas/sql/internal/sqlx"
"context"
"fmt"
"testing"

"ariga.io/atlas/sql/internal/sqltest"
"ariga.io/atlas/sql/internal/sqlx"
"ariga.io/atlas/sql/migrate"
"ariga.io/atlas/sql/schema"

Expand Down
66 changes: 25 additions & 41 deletions sql/postgres/sqlspec.go
Original file line number Diff line number Diff line change
Expand Up @@ -427,16 +427,26 @@ func convertColumnType(spec *sqlspec.Column) (schema.Type, error) {
// convertEnums converts possibly referenced column types (like enums) to
// an actual schema.Type and sets it on the correct schema.Column.
func convertEnums(tables []*sqlspec.Table, enums []*Enum, r *schema.Realm) error {
var (
used = make(map[*Enum]struct{})
byName = make(map[string]*Enum)
)
byName := make(map[string]*schema.EnumType)
for _, e := range enums {
byName[e.Name] = e
if byName[e.Name] != nil {
return fmt.Errorf("duplicate enum %q", e.Name)
}
ns, err := specutil.SchemaName(e.Schema)
if err != nil {
return fmt.Errorf("extract schema name from enum reference: %w", err)
}
es, ok := r.Schema(ns)
if !ok {
return fmt.Errorf("schema %q defined on enum %q was not found in realm", ns, e.Name)
}
e1 := &schema.EnumType{T: e.Name, Schema: es, Values: e.Values}
es.Objects = append(es.Objects, e1)
byName[e.Name] = e1
}
for _, t := range tables {
for _, c := range t.Columns {
var enum *Enum
var enum *schema.EnumType
switch {
case c.Type.IsRef:
n, err := enumName(c.Type)
Expand All @@ -445,7 +455,7 @@ func convertEnums(tables []*sqlspec.Table, enums []*Enum, r *schema.Realm) error
}
e, ok := byName[n]
if !ok {
return fmt.Errorf("enum %q was not found", n)
return fmt.Errorf("enum %q was not found in realm", n)
}
enum = e
default:
Expand All @@ -455,15 +465,6 @@ func convertEnums(tables []*sqlspec.Table, enums []*Enum, r *schema.Realm) error
}
enum = byName[n]
}
used[enum] = struct{}{}
schemaE, err := specutil.SchemaName(enum.Schema)
if err != nil {
return fmt.Errorf("extract schema name from enum reference: %w", err)
}
es, ok := r.Schema(schemaE)
if !ok {
return fmt.Errorf("schema %q not found in realm for table %q", schemaE, t.Name)
}
schemaT, err := specutil.SchemaName(t.Schema)
if err != nil {
return fmt.Errorf("extract schema name from table reference: %w", err)
Expand All @@ -480,20 +481,14 @@ func convertEnums(tables []*sqlspec.Table, enums []*Enum, r *schema.Realm) error
if !ok {
return fmt.Errorf("column %q not found in table %q", c.Name, t.Name)
}
e := &schema.EnumType{T: enum.Name, Schema: es, Values: enum.Values}
switch t := cc.Type.Type.(type) {
case *ArrayType:
t.Type = e
t.Type = enum
default:
cc.Type.Type = e
cc.Type.Type = enum
}
}
}
for _, e := range enums {
if _, ok := used[e]; !ok {
return fmt.Errorf("enum %q declared but not used", e.Name)
}
}
return nil
}

Expand All @@ -514,35 +509,24 @@ func enumRef(n string) *schemahcl.Ref {
}

// schemaSpec converts from a concrete Postgres schema to Atlas specification.
func schemaSpec(schem *schema.Schema) (*doc, error) {
spec, err := specutil.FromSchema(schem, tableSpec, viewSpec)
func schemaSpec(s *schema.Schema) (*doc, error) {
spec, err := specutil.FromSchema(s, tableSpec, viewSpec)
if err != nil {
return nil, err
}
d := &doc{
Tables: spec.Tables,
Views: spec.Views,
Schemas: []*sqlspec.Schema{spec.Schema},
Enums: make([]*Enum, 0, len(s.Objects)),
}
enums := make(map[string]bool)
mayAdd := func(c *schema.Column) {
if e, ok := hasEnumType(c); ok && !enums[e.T] {
for _, o := range s.Objects {
if e, ok := o.(*schema.EnumType); ok {
d.Enums = append(d.Enums, &Enum{
Name: e.T,
Schema: specutil.SchemaRef(spec.Schema.Name),
Values: e.Values,
Schema: specutil.SchemaRef(spec.Schema.Name),
})
enums[e.T] = true
}
}
for _, t := range schem.Tables {
for _, c := range t.Columns {
mayAdd(c)
}
}
for _, t := range schem.Views {
for _, c := range t.Columns {
mayAdd(c)
}
}
return d, nil
Expand Down
29 changes: 17 additions & 12 deletions sql/postgres/sqlspec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,8 @@ enum "account_type" {
var s schema.Schema
err := EvalHCLBytes([]byte(f), &s, nil)
require.NoError(t, err)
exp := &schema.Schema{
Name: "schema",
}
exp := schema.New("schema")
exp.AddObjects(&schema.EnumType{T: "account_type", Values: []string{"private", "business"}, Schema: exp})
exp.Tables = []*schema.Table{
{
Name: "table",
Expand Down Expand Up @@ -1118,7 +1117,18 @@ table "users" {
}

func TestMarshalSpec_Enum(t *testing.T) {
stateE := &schema.EnumType{
T: "state",
Values: []string{"on", "off"},
}
typeE := &schema.EnumType{
T: "account_type",
Values: []string{"private", "business"},
}
s := schema.New("test").
AddObjects(
typeE, stateE,
).
AddTables(
schema.NewTable("account").
AddColumns(
Expand All @@ -1128,19 +1138,14 @@ func TestMarshalSpec_Enum(t *testing.T) {
),
schema.NewColumn("account_states").
SetType(&ArrayType{
T: "states[]",
Type: &schema.EnumType{
T: "state",
Values: []string{"on", "off"},
},
T: "states[]",
Type: stateE,
}),
),
schema.NewTable("table2").
AddColumns(
schema.NewEnumColumn("account_type",
schema.EnumName("account_type"),
schema.EnumValues("private", "business"),
),
schema.NewColumn("account_type").
SetType(typeE),
),
)
buf, err := MarshalSpec(s, hclState)
Expand Down
10 changes: 10 additions & 0 deletions sql/schema/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,16 @@ func (s *Schema) View(name string) (*View, bool) {
return nil, false
}

// Object returns the first object that matched the given predicate.
func (s *Schema) Object(f func(Object) bool) (Object, bool) {
for _, o := range s.Objects {
if f(o) {
return o, true
}
}
return nil, false
}

// Column returns the first column that matched the given name.
func (t *Table) Column(name string) (*Column, bool) {
for _, c := range t.Columns {
Expand Down

0 comments on commit 1b1256b

Please sign in to comment.