-
Notifications
You must be signed in to change notification settings - Fork 187
/
make_one_of_discriminant_required.go
112 lines (91 loc) · 2.84 KB
/
make_one_of_discriminant_required.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
/*
* Copyright (c) Microsoft Corporation.
* Licensed under the MIT license.
*/
package pipeline
import (
"context"
"github.com/pkg/errors"
"github.com/Azure/azure-service-operator/v2/tools/generator/internal/astmodel"
)
// MakeOneOfDiscriminantRequiredStageId is the unique identifier for this pipeline stage
const MakeOneOfDiscriminantRequiredStageId = "makeOneOfDiscriminantRequired"
// MakeOneOfDiscriminantRequired walks the type graph and builds new types for communicating
// with ARM
func MakeOneOfDiscriminantRequired() *Stage {
return NewStage(
MakeOneOfDiscriminantRequiredStageId,
"Fix one of types to a discriminator which is not omitempty/optional",
func(ctx context.Context, state *State) (*State, error) {
updatedDefs := make(astmodel.TypeDefinitionSet)
for _, def := range state.Definitions() {
isOneOf := astmodel.OneOfFlag.IsOn(def.Type())
isARM := def.Name().IsARMType()
if !isOneOf || !isARM {
continue
}
updated, err := makeOneOfDiscriminantTypeRequired(def, state.Definitions())
if err != nil {
return nil, err
}
updatedDefs.AddTypes(updated)
}
return state.WithOverlaidDefinitions(updatedDefs), nil
})
}
type propertyModifier struct {
visitor astmodel.TypeVisitor[any]
json string
}
func newPropertyModifier(json string) *propertyModifier {
modifier := &propertyModifier{
json: json,
}
modifier.visitor = astmodel.TypeVisitorBuilder[any]{
VisitObjectType: modifier.makeDiscriminatorPropertiesRequired,
}.Build()
return modifier
}
func (r *propertyModifier) makeDiscriminatorPropertiesRequired(
this *astmodel.TypeVisitor[any],
ot *astmodel.ObjectType,
ctx any,
) (astmodel.Type, error) {
ot.Properties().ForEach(func(prop *astmodel.PropertyDefinition) {
if json, ok := prop.JSONName(); ok && r.json == json {
ot = ot.WithProperty(prop.MakeTypeRequired())
}
})
return astmodel.IdentityVisitOfObjectType(this, ot, ctx)
}
func makeOneOfDiscriminantTypeRequired(
oneOf astmodel.TypeDefinition,
defs astmodel.TypeDefinitionSet,
) (astmodel.TypeDefinitionSet, error) {
objectType, ok := astmodel.AsObjectType(oneOf.Type())
if !ok {
return nil, errors.Errorf(
"OneOf %s was not of type Object, instead was: %s",
oneOf.Name(),
astmodel.DebugDescription(oneOf.Type()))
}
result := make(astmodel.TypeDefinitionSet)
discriminantJson, values, err := astmodel.DetermineDiscriminantAndValues(objectType, defs)
if err != nil {
return nil, err
}
astmodel.NewPropertyInjector()
remover := newPropertyModifier(discriminantJson)
for _, value := range values {
def, err := defs.GetDefinition(value.TypeName)
if err != nil {
return nil, err
}
updatedDef, err := remover.visitor.VisitDefinition(def, nil)
if err != nil {
return nil, errors.Wrapf(err, "error updating definition %s", def.Name())
}
result.Add(updatedDef)
}
return result, nil
}