-
Notifications
You must be signed in to change notification settings - Fork 187
/
remove_type_aliases.go
104 lines (89 loc) · 3.27 KB
/
remove_type_aliases.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
/*
* Copyright (c) Microsoft Corporation.
* Licensed under the MIT license.
*/
package pipeline
import (
"context"
"fmt"
kerrors "k8s.io/apimachinery/pkg/util/errors"
"github.com/pkg/errors"
"github.com/Azure/azure-service-operator/v2/tools/generator/internal/astmodel"
)
// RemoveTypeAliasesStageID is the unique identifier for this pipeline stage.
// RemoveTypeAliases passes it as the first argument to NewLegacyStage when constructing the stage.
const RemoveTypeAliasesStageID = "removeAliases"
// RemoveTypeAliases creates a pipeline stage removing type aliases.
//
// The stage walks the type of every definition it is given. Each internal type name met
// during the walk is replaced by whatever resolveTypeName returns for it, looked up against
// the full input set. Definitions that visit cleanly are added to a new set; if any visit
// fails, no set is returned and all the failures are reported together as one aggregate error.
func RemoveTypeAliases() *Stage {
	stageAction := func(
		ctx context.Context,
		definitions astmodel.TypeDefinitionSet,
	) (astmodel.TypeDefinitionSet, error) {
		// The callback closes over definitions so that every name is resolved against the complete input.
		visitor := astmodel.TypeVisitorBuilder[any]{
			VisitInternalTypeName: func(
				v *astmodel.TypeVisitor[any],
				typeName astmodel.InternalTypeName,
				_ any,
			) (astmodel.Type, error) {
				return resolveTypeName(v, typeName, definitions)
			},
		}.Build()

		var failures []error
		updated := make(astmodel.TypeDefinitionSet)
		for _, def := range definitions {
			rewritten, err := visitor.Visit(def.Type(), nil)
			if err != nil {
				// Keep going so that every failing definition is reported, not just the first.
				failures = append(failures, err)
				continue
			}

			updated.Add(def.WithType(rewritten))
		}

		if len(failures) != 0 {
			return nil, kerrors.NewAggregate(failures)
		}

		return updated, nil
	}

	return NewLegacyStage(RemoveTypeAliasesStageID, "Remove type aliases", stageAction)
}
// resolveTypeName works out what a reference to name should be replaced with.
//
// A name whose package reference is external is returned untouched. Otherwise the definition
// for name is looked up in definitions (a missing definition is reported as an error) and the
// result depends on the kind of type that definition wraps:
//   - object, enum, resource, validated, flagged, interface and external-type-name definitions
//     must stay named, so the definition's own name is returned;
//   - a definition that is itself an internal type name is an alias and is resolved recursively;
//   - primitive, optional, array, map and errored types are passed to visitor, and the visited
//     type is returned in place of the name.
//
// Any other kind of type causes a panic.
func resolveTypeName(
	visitor *astmodel.TypeVisitor[any],
	name astmodel.InternalTypeName,
	definitions astmodel.TypeDefinitionSet,
) (astmodel.Type, error) {
	// Don't try to remove external refs
	if astmodel.IsExternalPackageReference(name.PackageReference()) {
		return name, nil
	}

	definition, found := definitions[name]
	if !found {
		return nil, errors.Errorf("couldn't find definition for type name %s", name)
	}

	switch aliased := definition.Type().(type) {
	// These kinds keep their name:
	//   ObjectType, ResourceType   - must remain named for controller-gen
	//   EnumType, ValidatedType    - must remain named so there is somewhere to put validations
	//   FlaggedType                - just wraps objectType (and objectType remains named)
	//   InterfaceType, ExternalTypeName - must remain named
	case *astmodel.ObjectType,
		*astmodel.EnumType,
		*astmodel.ResourceType,
		*astmodel.ValidatedType,
		*astmodel.FlaggedType,
		*astmodel.InterfaceType,
		*astmodel.ExternalTypeName:
		return definition.Name(), nil

	case astmodel.InternalTypeName:
		// The definition is an alias of another name, so keep resolving from there.
		return resolveTypeName(visitor, aliased, definitions)

	// Everything below is pulled up one level: the visited underlying type replaces the name.
	case *astmodel.PrimitiveType,
		*astmodel.OptionalType,
		*astmodel.ArrayType,
		*astmodel.MapType,
		*astmodel.ErroredType:
		return visitor.Visit(aliased, nil)

	default:
		panic(fmt.Sprintf("Don't know how to resolve type %T for typeName %s", aliased, name))
	}
}