-
Notifications
You must be signed in to change notification settings - Fork 3
/
data.go
90 lines (80 loc) · 2.41 KB
/
data.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
package pgdump
import (
"database/sql"
"fmt"
"strings"
)
// getTables returns the names of the tables in the "public" schema as listed
// by information_schema.tables, sorted by name so that repeated dumps of the
// same database list tables in a stable order.
//
// NOTE(review): per the PostgreSQL docs, information_schema.tables also lists
// views (and foreign tables); they are returned here as well. Add a
// table_type = 'BASE TABLE' filter if only ordinary tables should be dumped.
func getTables(db *sql.DB) ([]string, error) {
	const query = "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' ORDER BY table_name"
	rows, err := db.Query(query)
	if err != nil {
		return nil, fmt.Errorf("listing tables: %w", err)
	}
	defer rows.Close()

	var tables []string
	for rows.Next() {
		var tableName string
		if err := rows.Scan(&tableName); err != nil {
			return nil, fmt.Errorf("scanning table name: %w", err)
		}
		tables = append(tables, tableName)
	}
	// rows.Next returns false both at the end of the result set and on a
	// mid-iteration failure; only rows.Err tells the two apart.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("iterating tables: %w", err)
	}
	return tables, nil
}
// getCreateTableStatement builds a CREATE TABLE statement for tableName in the
// "public" schema from information_schema.columns. Each column is rendered as
// "<column_name> <data_type>", with "(<n>)" appended when
// character_maximum_length is non-NULL. Columns appear in declared order
// (ordinal_position).
//
// tableName is compared exactly (case-sensitively) with
// information_schema.columns.table_name, so it should be a name as returned by
// getTables. It is sent as a bind parameter ($1), never spliced into the SQL
// text, so quotes or other special characters in the name are harmless.
//
// NOTE(review): the table and column names are written into the returned
// statement unquoted, so names that need quoting (mixed case, reserved words)
// produce a statement that does not recreate the original table.
func getCreateTableStatement(db *sql.DB, tableName string) (string, error) {
	const query = `SELECT column_name, data_type, character_maximum_length
FROM information_schema.columns
WHERE table_schema = 'public' AND table_name = $1
ORDER BY ordinal_position`
	rows, err := db.Query(query, tableName)
	if err != nil {
		return "", fmt.Errorf("reading columns of table %q: %w", tableName, err)
	}
	defer rows.Close()

	var columns []string
	for rows.Next() {
		var columnName, dataType string
		// character_maximum_length is NULL for non-character types; a nil
		// pointer after Scan means "no length suffix".
		var charMaxLength *int
		if err := rows.Scan(&columnName, &dataType, &charMaxLength); err != nil {
			return "", fmt.Errorf("scanning column of table %q: %w", tableName, err)
		}
		columnDef := fmt.Sprintf("%s %s", columnName, dataType)
		if charMaxLength != nil {
			columnDef += fmt.Sprintf("(%d)", *charMaxLength)
		}
		columns = append(columns, columnDef)
	}
	// A mid-iteration failure also ends the loop; surface it instead of
	// returning a statement with missing columns.
	if err := rows.Err(); err != nil {
		return "", fmt.Errorf("iterating columns of table %q: %w", tableName, err)
	}
	return fmt.Sprintf("CREATE TABLE %s (\n %s\n);", tableName, strings.Join(columns, ",\n ")), nil
}
// getTableDataCopyFormat dumps every row of tableName as a
// "COPY <table> (<columns>) FROM stdin;" block in PostgreSQL's COPY text
// format: one line per row, values separated by tabs, SQL NULL written as \N,
// and the block terminated by a "\." line.
//
// tableName should be an unqualified, exact-case table name such as those
// returned by getTables. Identifiers cannot be bind parameters, so for the
// SELECT it is double-quoted (embedded quotes doubled); mixed-case names and
// reserved words are therefore read correctly. The emitted COPY line still
// uses tableName and the column names verbatim.
//
// Non-NULL values are converted to text by database/sql's Scan into a string.
// NOTE(review): database/sql cannot convert a time.Time driver value to a
// string, so a driver that returns time.Time for date/time columns makes Scan
// fail here — confirm against the driver in use.
func getTableDataCopyFormat(db *sql.DB, tableName string) (string, error) {
	quotedTable := `"` + strings.ReplaceAll(tableName, `"`, `""`) + `"`
	rows, err := db.Query("SELECT * FROM " + quotedTable)
	if err != nil {
		return "", fmt.Errorf("selecting rows of table %q: %w", tableName, err)
	}
	defer rows.Close()

	columns, err := rows.Columns()
	if err != nil {
		return "", fmt.Errorf("reading columns of table %q: %w", tableName, err)
	}

	// sql.NullString records NULL as Valid == false, so NULL and the empty
	// string stay distinguishable (sql.RawBytes can collapse both to nil).
	values := make([]sql.NullString, len(columns))
	scanArgs := make([]interface{}, len(values))
	for i := range values {
		scanArgs[i] = &values[i]
	}

	// COPY text format gives backslash, tab (column delimiter), newline (row
	// delimiter) and carriage return special meaning, so they must be
	// backslash-escaped inside values. Replacement is single-pass, so the
	// backslashes it inserts are not escaped again.
	escaper := strings.NewReplacer(`\`, `\\`, "\t", `\t`, "\n", `\n`, "\r", `\r`)

	var output strings.Builder
	fmt.Fprintf(&output, "COPY %s (%s) FROM stdin;\n", tableName, strings.Join(columns, ", "))
	fields := make([]string, len(values))
	for rows.Next() {
		if err := rows.Scan(scanArgs...); err != nil {
			return "", fmt.Errorf("scanning row of table %q: %w", tableName, err)
		}
		for i, v := range values {
			if !v.Valid {
				fields[i] = `\N` // COPY's NULL marker
				continue
			}
			fields[i] = escaper.Replace(v.String)
		}
		output.WriteString(strings.Join(fields, "\t"))
		output.WriteByte('\n')
	}
	// A mid-iteration failure also ends the loop; don't emit a truncated dump
	// as if it were complete.
	if err := rows.Err(); err != nil {
		return "", fmt.Errorf("iterating rows of table %q: %w", tableName, err)
	}
	output.WriteString("\\.\n")
	return output.String(), nil
}