Skip to content

Commit

Permalink
sql/mssql: inspect column default value (#1782)
Browse files Browse the repository at this point in the history
  • Loading branch information
giautm committed Jun 29, 2023
1 parent 566da41 commit a95bdca
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 62 deletions.
38 changes: 35 additions & 3 deletions sql/mssql/inspect.go
Original file line number Diff line number Diff line change
Expand Up @@ -280,14 +280,15 @@ func (i *inspect) addColumn(s *schema.Schema, rows *sql.Rows, scope queryScope)
nullable, userDefined sql.NullInt64
identity, identitySeek, identityIncrement sql.NullInt64
size, precision, scale, isPersisted sql.NullInt64
genexpr sql.NullString
genexpr, defaults sql.NullString
isComputed int64
)
if err = rows.Scan(
&table, &name, &typeName, &comment,
&nullable, &userDefined,
&identity, &identitySeek, &identityIncrement,
&collation, &size, &precision, &scale, &isComputed, &genexpr, &isPersisted,
&collation, &size, &precision, &scale, &isComputed,
&genexpr, &isPersisted, &defaults,
); err != nil {
return err
}
Expand Down Expand Up @@ -323,6 +324,9 @@ func (i *inspect) addColumn(s *schema.Schema, rows *sql.Rows, scope queryScope)
}
c.SetGeneratedExpr(x)
}
if defaults.Valid {
c.Default = i.defaultExpr(c, defaults.String)
}
if sqlx.ValidString(comment) {
c.SetComment(comment.String)
}
Expand Down Expand Up @@ -439,6 +443,29 @@ func nArgs(start, n int) string {
return b.String()
}

// defaultExpr returns the default expression of the given column.
//
// https://learn.microsoft.com/en-us/sql/relational-databases/tables/specify-default-values-for-columns
func (i *inspect) defaultExpr(_ *schema.Column, x string) schema.Expr {
// Remove the parenthesis from the expression.
x = mayUnwrap(x)
// Literal expression is quoted or wrapped with parenthesis.
if sqlx.IsQuoted(x, '\'') || mayUnwrap(x) != x {
return &schema.Literal{V: x}
}
// Raw expression does not have a parenthesis.
return &schema.RawExpr{X: x}
}

// mayUnwrap removes the wrapping parentheses from the given string.
func mayUnwrap(s string) string {
n := len(s) - 1
if len(s) < 2 || s[0] != '(' || s[n] != ')' {
return s
}
return s[1:n]
}

const (
// Query to list server properties.
propertiesQuery = "SELECT SERVERPROPERTY('ProductVersion'), SERVERPROPERTY('Collation'), SERVERPROPERTY('SqlCharSetName')"
Expand Down Expand Up @@ -532,7 +559,8 @@ SELECT
[scale] = [c1].[scale],
[is_computed] = [c1].[is_computed],
[computed_definition] = [cc].[definition],
[computed_persisted] = [cc].[is_persisted]
[computed_persisted] = [cc].[is_persisted],
[default_definition] = [dc].[definition]
FROM
[sys].[tables] [t1]
INNER JOIN [sys].[columns] [c1]
Expand All @@ -542,6 +570,10 @@ FROM
LEFT JOIN [sys].[computed_columns] [cc]
ON [cc].[object_id] = [c1].[object_id]
AND [cc].[column_id] = [c1].[column_id]
LEFT JOIN [sys].[default_constraints] [dc]
ON [dc].[object_id] = [c1].[default_object_id]
AND [dc].[parent_object_id] = [c1].[object_id]
AND [dc].[parent_column_id] = [c1].[column_id]
LEFT JOIN [sys].[identity_columns] [ti]
ON [ti].[object_id] = [c1].[object_id]
AND [ti].[column_id] = [c1].[column_id]
Expand Down
Loading

0 comments on commit a95bdca

Please sign in to comment.