/
db_utils.go
147 lines (139 loc) · 4.42 KB
/
db_utils.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
package database
import (
"os"
"sort"
"strings"
"github.com/anz-bank/sysl/pkg/sysl"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
"github.com/spf13/afero"
)
// CreateTableDepthMap groups the tables in tableMap by dependency depth.
//
// A table whose attributes contain no type references has depth 0. A table
// whose attributes reference columns of other tables (Table.column) has a
// depth one greater than the deepest table it references. The returned map
// goes from depth to the names of the tables at that depth. Each name list is
// sorted, so the result does not depend on Go's randomised map iteration order.
func CreateTableDepthMap(tableMap map[string]*sysl.Type) map[int][]string {
	tablesByDepth := map[int][]string{}
	depthByTable := map[string]int{}
	resolvedAttrs := map[string]string{}
	pending := make(map[string]int, len(tableMap))
	for tableName := range tableMap {
		pending[tableName] = 0
	}
	processTableDepth(tableMap, tablesByDepth, depthByTable, pending, resolvedAttrs)
	// Names are appended in map-iteration order; sort in place for stable output.
	for _, names := range tablesByDepth {
		sort.Strings(names)
	}
	return tablesByDepth
}
// processTableDepth moves tables out of incompleteTableDepthMap as their
// type references become resolvable, recording each table's depth.
//
// Each pass calls findTableDepth for every pending table. When findTableDepth
// reports that all of a table's type references resolved, the function does
// four things:
//   - appends the table to completedTableDepthMap[depth];
//   - stores its depth in completeTableDepthMap;
//   - merges its recorded attributes into visitedTableAttrs, so that tables
//     referencing it can resolve later in the same or the next pass;
//   - removes it from incompleteTableDepthMap.
//
// Passes repeat until no table is pending.
//
// A table's attributes only enter visitedTableAttrs once the table itself
// completes. So a self-reference, a reference cycle, or a reference to a
// table missing from tableMap can leave a pass that completes nothing. In
// that case the lexicographically smallest pending table is completed anyway.
// Its depth comes from the references that did resolve, which lets its
// dependents proceed and guarantees the loop terminates instead of spinning
// forever.
func processTableDepth(
	tableMap map[string]*sysl.Type,
	completedTableDepthMap map[int][]string,
	completeTableDepthMap map[string]int,
	incompleteTableDepthMap map[string]int,
	visitedTableAttrs map[string]string,
) {
	complete := func(tableName string, depth int, attrs map[string]string) {
		completedTableDepthMap[depth] = append(completedTableDepthMap[depth], tableName)
		completeTableDepthMap[tableName] = depth
		delete(incompleteTableDepthMap, tableName)
		for attr, value := range attrs {
			visitedTableAttrs[attr] = value
		}
	}

	for len(incompleteTableDepthMap) != 0 {
		progressed := false
		// Deleting the current key while ranging over a map is safe in Go.
		for tableName := range incompleteTableDepthMap {
			done, depth, attrs := findTableDepth(tableName, tableMap[tableName],
				visitedTableAttrs, completeTableDepthMap)
			if done {
				complete(tableName, depth, attrs)
				progressed = true
			}
		}
		if progressed {
			continue
		}

		// A full pass resolved nothing, so every remaining table waits on an
		// attribute that will never be recorded. Force one table through,
		// chosen deterministically, so that the loop always makes progress.
		var stuck string
		first := true
		for name := range incompleteTableDepthMap {
			if first || name < stuck {
				stuck, first = name, false
			}
		}
		_, depth, attrs := findTableDepth(stuck, tableMap[stuck], visitedTableAttrs, completeTableDepthMap)
		complete(stuck, depth, attrs)
	}
}
// findTableDepth computes the dependency depth of a single table from the
// attributes of its relation.
//
// Each attribute is handled as follows:
//   - A type reference whose "Path[0].Path[1]" key is present in
//     visitedTableAttrs is resolved. It contributes
//     completeTableDepthMap[Path[0]]+1 to the table's depth, and the
//     looked-up value is recorded under "tableName.attrName".
//   - A type reference not yet present in visitedTableAttrs marks the table
//     as not fully processed.
//   - Any other attribute is recorded under "tableName.attrName" with the
//     String() form of its primitive type.
//
// It returns three values:
//   - whether every type reference resolved;
//   - the maximum depth found (0 when no reference resolved);
//   - the attributes recorded for this table.
//
// The depth and attributes are computed even when the first result is false.
// A table without a relation returns (true, 0, empty map).
//
// NOTE(review): this assumes every type-reference path has at least two
// elements (table and column); a shorter path panics with index out of range.
func findTableDepth(
	tableName string,
	table *sysl.Type,
	visitedTableAttrs map[string]string,
	completeTableDepthMap map[string]int,
) (bool, int, map[string]string) {
	allAttrProcessed := true
	tableDepth := 0
	tempVisitedAttrs := map[string]string{}

	relEntity := table.GetRelation()
	if relEntity == nil {
		return allAttrProcessed, tableDepth, tempVisitedAttrs
	}

	for attrName, attrType := range relEntity.AttrDefs {
		typeRef := attrType.GetTypeRef()
		if typeRef == nil {
			tempVisitedAttrs[tableName+"."+attrName] = attrType.GetPrimitive().String()
			continue
		}

		path := typeRef.GetRef().Path
		val, ok := visitedTableAttrs[path[0]+"."+path[1]]
		if !ok {
			// The referenced table has not been completed yet.
			allAttrProcessed = false
			continue
		}
		tempVisitedAttrs[tableName+"."+attrName] = val
		if newDepth := completeTableDepthMap[path[0]] + 1; newDepth > tableDepth {
			tableDepth = newDepth
		}
	}
	return allAttrProcessed, tableDepth, tempVisitedAttrs
}
// GenerateFromSQLMap writes every entry of m to fs, using the entry's filename
// as the path and its content as the file body.
//
// Files are created with mode 0644 (owner read/write, others read-only).
// os.ModePerm (0777) would make the generated scripts executable and, subject
// to umask, world-writable.
//
// On the first failed write, the error is wrapped with the file name, logged
// through logger and returned. Remaining entries are not written.
func GenerateFromSQLMap(m []ScriptOutput, fs afero.Fs, logger *logrus.Logger) error { //nolint:interfacer
	const scriptFileMode os.FileMode = 0644
	for _, e := range m {
		if err := afero.WriteFile(fs, e.filename, []byte(e.content), scriptFileMode); err != nil {
			err = errors.Wrapf(err, "writing %q", e.filename)
			logger.Errorf("error received while writing the file %s. The error message is - %s", e.filename, err.Error())
			return err
		}
	}
	return nil
}
// isAutoIncrementAndPrimaryKey inspects the "patterns" attribute of attrType.
// It reports, in order, whether that list contains "autoinc" and whether it
// contains "pk". Both are matched case-insensitively with strings.EqualFold.
func isAutoIncrementAndPrimaryKey(attrType *sysl.Type) (bool, bool) {
	var autoInc, primaryKey bool

	attrs := attrType.GetAttrs()
	if attrs == nil {
		return autoInc, primaryKey
	}
	list := attrs["patterns"].GetA()
	if list == nil {
		return autoInc, primaryKey
	}

	for _, elt := range list.GetElt() {
		// A single entry cannot fold-match both tags, so a switch is equivalent
		// to two independent checks.
		switch tag := elt.GetS(); {
		case strings.EqualFold("autoinc", tag):
			autoInc = true
		case strings.EqualFold("pk", tag):
			primaryKey = true
		}
	}
	return autoInc, primaryKey
}
// getDataTypeAndSize returns the lower-cased name of attrType's primitive type
// and a column size.
//
// The size is defaultTextSize, except when the type equals strConst and its
// first constraint carries a length with a positive maximum. In that case the
// maximum is returned as the size.
func getDataTypeAndSize(attrType *sysl.Type) (string, int64) {
	dataType := strings.ToLower(attrType.GetPrimitive().String())
	size := int64(defaultTextSize)

	if dataType != strConst {
		return dataType, size
	}
	constraints := attrType.GetConstraint()
	if len(constraints) == 0 {
		return dataType, size
	}
	if limit := constraints[0].GetLength(); limit != nil {
		if upper := limit.GetMax(); upper > 0 {
			size = upper
		}
	}
	return dataType, size
}
// sortColumnNamesIntoList returns the keys of attrMap in ascending
// lexicographic order. An empty or nil map yields a nil slice.
func sortColumnNamesIntoList(attrMap map[string]*sysl.Type) []string {
	if len(attrMap) == 0 {
		return nil
	}
	names := make([]string, 0, len(attrMap))
	for name := range attrMap {
		names = append(names, name)
	}
	sort.Sort(sort.StringSlice(names))
	return names
}