-
Notifications
You must be signed in to change notification settings - Fork 187
/
check_for_anytype.go
133 lines (110 loc) · 3.86 KB
/
check_for_anytype.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
/*
* Copyright (c) Microsoft Corporation.
* Licensed under the MIT license.
*/
package pipeline
import (
"context"
"sort"
"strings"
"github.com/Azure/azure-service-operator/v2/internal/set"
"github.com/pkg/errors"
"github.com/Azure/azure-service-operator/v2/tools/generator/internal/astmodel"
)
// CheckForAnyTypeStageID is the unique identifier of the AnyType-checking
// stage. Both FilterOutDefinitionsUsingAnyType and
// EnsureDefinitionsDoNotUseAnyTypes build their stage with this ID.
const (
	CheckForAnyTypeStageID = "rogueCheck"
)
// FilterOutDefinitionsUsingAnyType returns a stage that checks every definition
// for AnyTypes. packages lists the package folder paths we expect to contain
// such definitions; every definition from those packages is quietly dropped
// from the stage's output. AnyTypes found in any other package are reported as
// an error. The stage also fails when a listed package turns out to contain no
// AnyTypes at all, so the configuration gets cleaned up as the schemas are
// fixed and our handling improves.
func FilterOutDefinitionsUsingAnyType(packages []string) *Stage {
	const description = "Filter out rogue definitions using AnyTypes"
	return checkForAnyType(description, packages)
}
// EnsureDefinitionsDoNotUseAnyTypes returns a stage that checks every
// definition for AnyTypes. No package is exempt, so any definition using an
// AnyType makes the stage fail with an error naming the offending package(s).
func EnsureDefinitionsDoNotUseAnyTypes() *Stage {
	// A nil exclusion list behaves exactly like an empty one here.
	return checkForAnyType("Check for rogue definitions using AnyTypes", nil)
}
// checkForAnyType builds the stage shared by FilterOutDefinitionsUsingAnyType
// and EnsureDefinitionsDoNotUseAnyTypes.
//
// description is the human-readable description of the stage. packages lists
// the folder paths of packages expected to contain AnyTypes; definitions from
// those packages are left out of the stage's output. When run, the stage
// returns an error if collectBadPackages reports a problem, or if AnyTypes
// were found in packages not listed in packages. Otherwise it returns the
// filtered definitions.
func checkForAnyType(description string, packages []string) *Stage {
	allowed := set.Make[string]()
	for _, pkg := range packages {
		allowed.Add(pkg)
	}

	action := func(ctx context.Context, defs astmodel.TypeDefinitionSet) (astmodel.TypeDefinitionSet, error) {
		kept := make(astmodel.TypeDefinitionSet)
		var offenders []astmodel.InternalTypeName

		for name, def := range defs {
			if containsAnyType(def.Type()) {
				offenders = append(offenders, name)
			}

			// Definitions from packages already known to use AnyTypes are
			// dropped rather than passed further down the pipeline.
			if !allowed.Contains(name.InternalPackageReference().FolderPath()) {
				kept.Add(def)
			}
		}

		unexpected, err := collectBadPackages(offenders, allowed)
		switch {
		case err != nil:
			return nil, errors.Wrap(err, "summarising bad types")
		case len(unexpected) > 0:
			return nil, errors.Errorf("AnyTypes found - add exclusions for: %s", strings.Join(unexpected, ", "))
		default:
			return kept, nil
		}
	}

	return NewLegacyStage(CheckForAnyTypeStageID, description, action)
}
// containsAnyType reports whether the visitor built from
// astmodel.TypeVisitorBuilder encounters astmodel.AnyType while walking
// theType.
func containsAnyType(theType astmodel.Type) bool {
	sawAny := false

	visitor := astmodel.TypeVisitorBuilder[any]{
		VisitPrimitive: func(prim *astmodel.PrimitiveType) astmodel.Type {
			sawAny = sawAny || prim == astmodel.AnyType
			return prim
		},
	}.Build()

	// Only the side effect on sawAny matters; the rebuilt type and any error
	// from the walk are deliberately discarded.
	_, _ = visitor.Visit(theType, nil)

	return sawAny
}
// collectBadPackages reconciles the type names found to use AnyTypes against
// the packages we expect to contain them.
//
// names are the offending definitions; expectedPackages holds the folder paths
// of packages where AnyTypes are known and tolerated.
//
// It returns the folder paths, sorted, of packages that contain AnyTypes but
// are not in expectedPackages. It returns an error, listing the packages in
// sorted order, if any package in expectedPackages contained no AnyTypes at
// all, so that stale exclusions get removed.
//
// expectedPackages is only read, never modified. checkForAnyType creates this
// set once per stage and captures it in the stage closure, so mutating it here
// would break every subsequent run of the stage.
func collectBadPackages(
	names []astmodel.InternalTypeName,
	expectedPackages set.Set[string],
) ([]string, error) {
	// Folder paths of every package that contains at least one AnyType.
	packagesWithAnyTypes := make(map[string]struct{})
	for _, name := range names {
		packagesWithAnyTypes[name.InternalPackageReference().FolderPath()] = struct{}{}
	}

	var groupNames []string //nolint:prealloc // unlikely case
	for packagePath := range packagesWithAnyTypes {
		// Only complain about this package if it's one we don't know about.
		if expectedPackages.Contains(packagePath) {
			continue
		}
		groupNames = append(groupNames, packagePath)
	}
	sort.Strings(groupNames)

	// Complain if there were some packages where we expected problems
	// but didn't see any. Compare against what we found instead of removing
	// entries from expectedPackages, so the caller's set stays intact.
	var leftovers []string
	for packagePath := range expectedPackages {
		if _, found := packagesWithAnyTypes[packagePath]; !found {
			leftovers = append(leftovers, packagePath)
		}
	}
	if len(leftovers) > 0 {
		sort.Strings(leftovers)
		return nil, errors.Errorf(
			"no AnyTypes found in: %s", strings.Join(leftovers, ", "))
	}

	return groupNames, nil
}