Skip to content

Commit

Permalink
[go-server] Feat: add required assertions to models (#10068)
Browse files Browse the repository at this point in the history
* Add RequiredError

* Add IsZeroValue helper

* Add AssertRequired method to all models

* Add AssertRequired call for body param

* Regenerate files

* Add DisallowUnknownFields

* Regenerate samples

* Use hasRequired in model to remove unnecessary code

* Revert disallowUnknownFields

* Use isAdditionalPropertiesTrue for disallowing unknown fields

* Updated samples

* Fix indent

* Add require checks for nested slices

* Add new tests

* Regenerate samples

* Regenerate samples after merging
  • Loading branch information
lwj5 committed Aug 7, 2021
1 parent 189b44b commit 11d29eb
Show file tree
Hide file tree
Showing 56 changed files with 4,441 additions and 52 deletions.
9 changes: 9 additions & 0 deletions bin/configs/go-server-required.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
generatorName: go-server
outputDir: samples/server/petstore/go-server-required
inputSpec: modules/openapi-generator/src/test/resources/3_0/server-required.yaml
templateDir: modules/openapi-generator/src/main/resources/go-server
additionalProperties:
hideGenerationTimestamp: "true"
packageName: petstoreserver
addResponseHeaders: true
router: "chi"
Original file line number Diff line number Diff line change
Expand Up @@ -174,14 +174,36 @@ func (c *{{classname}}Controller) {{nickname}}(w http.ResponseWriter, r *http.Re
{{paramName}} := r.Header.Get("{{baseName}}")
{{/isHeaderParam}}
{{#isBodyParam}}
{{paramName}} := &{{dataType}}{}
if err := json.NewDecoder(r.Body).Decode(&{{paramName}}); err != nil {
{{paramName}} := {{dataType}}{}
d := json.NewDecoder(r.Body)
{{^isAdditionalPropertiesTrue}}
d.DisallowUnknownFields()
{{/isAdditionalPropertiesTrue}}
if err := d.Decode(&{{paramName}}); err != nil {
c.errorHandler(w, r, &ParsingError{Err: err}, nil)
return
}
{{#isArray}}
{{#items.isModel}}
for _, el := range {{paramName}} {
if err := Assert{{baseType}}Required(el); err != nil {
c.errorHandler(w, r, err, nil)
return
}
}
{{/items.isModel}}
{{/isArray}}
{{^isArray}}
{{#isModel}}
if err := Assert{{baseType}}Required({{paramName}}); err != nil {
c.errorHandler(w, r, err, nil)
return
}
{{/isModel}}
{{/isArray}}
{{/isBodyParam}}
{{/allParams}}
result, err := c.service.{{nickname}}(r.Context(){{#allParams}}, {{#isBodyParam}}*{{/isBodyParam}}{{paramName}}{{/allParams}})
result, err := c.service.{{nickname}}(r.Context(){{#allParams}}, {{paramName}}{{/allParams}})
// If an error occurred, encode the error with the status code
if err != nil {
c.errorHandler(w, r, err, &result)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,16 @@
package {{packageName}}

import (
"errors"
"fmt"
"net/http"
)

var (
// ErrTypeAssertionError is thrown when type an interface does not match the asserted type
ErrTypeAssertionError = errors.New("unable to assert type")
)

// ParsingError indicates that an error has occurred when parsing request parameters
type ParsingError struct {
Err error
Expand All @@ -18,6 +25,15 @@ func (e *ParsingError) Error() string {
return e.Err.Error()
}

// RequiredError indicates that an error has occurred when parsing request parameters
type RequiredError struct {
Field string
}

func (e *RequiredError) Error() string {
return fmt.Sprintf("required field '%s' is zero value.", e.Field)
}

// ErrorHandler defines the required method for handling error. You may implement it and inject this into a controller if
// you would like errors to be handled differently from the DefaultErrorHandler
type ErrorHandler func(w http.ResponseWriter, r *http.Request, err error, result *ImplResponse)
Expand All @@ -28,6 +44,9 @@ func DefaultErrorHandler(w http.ResponseWriter, r *http.Request, err error, resu
if _, ok := err.(*ParsingError); ok {
// Handle parsing errors
EncodeJSONResponse(err.Error(), func(i int) *int { return &i }(http.StatusBadRequest),{{#addResponseHeaders}} map[string][]string{},{{/addResponseHeaders}} w)
} else if _, ok := err.(*RequiredError); ok {
// Handle missing required errors
EncodeJSONResponse(err.Error(), func(i int) *int { return &i }(http.StatusUnprocessableEntity),{{#addResponseHeaders}} map[string][]string{},{{/addResponseHeaders}} w)
} else {
// Handle all other errors
EncodeJSONResponse(err.Error(), &result.Code,{{#addResponseHeaders}} result.Headers,{{/addResponseHeaders}} w)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
{{>partial_header}}
package {{packageName}}

//Response return a ImplResponse struct filled
import (
"reflect"
)

// Response return a ImplResponse struct filled
func Response(code int, body interface{}) ImplResponse {
return ImplResponse {
Code: code,
Expand All @@ -13,7 +17,7 @@ func Response(code int, body interface{}) ImplResponse {
}
{{#addResponseHeaders}}

//ResponseWithHeaders return a ImplResponse struct filled, including headers
// ResponseWithHeaders return a ImplResponse struct filled, including headers
func ResponseWithHeaders(code int, headers map[string][]string, body interface{}) ImplResponse {
return ImplResponse {
Code: code,
Expand All @@ -22,3 +26,35 @@ func ResponseWithHeaders(code int, headers map[string][]string, body interface{}
}
}
{{/addResponseHeaders}}

// IsZeroValue checks if the val is the zero-ed value.
func IsZeroValue(val interface{}) bool {
return val == nil || reflect.DeepEqual(val, reflect.Zero(reflect.TypeOf(val)).Interface())
}

// AssertInterfaceRequired recursively checks each struct in a slice against the callback.
// This method traverse nested slices in a preorder fashion.
func AssertRecurseInterfaceRequired(obj interface{}, callback func(interface{}) error) error {
return AssertRecurseValueRequired(reflect.ValueOf(obj), callback)
}

// AssertNestedValueRequired checks each struct in the nested slice against the callback.
// This method traverse nested slices in a preorder fashion.
func AssertRecurseValueRequired(value reflect.Value, callback func(interface{}) error) error {
switch value.Kind() {
// If it is a struct we check using callback
case reflect.Struct:
if err := callback(value.Interface()); err != nil {
return err
}

// If it is a slice we continue recursion
case reflect.Slice:
for i := 0; i < value.Len(); i += 1 {
if err := AssertRecurseValueRequired(value.Index(i), callback); err != nil {
return err
}
}
}
return nil
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,101 @@ type {{classname}} struct {
{{/deprecated}}
{{name}} {{#isNullable}}*{{/isNullable}}{{{dataType}}} `json:"{{baseName}}{{^required}},omitempty{{/required}}"{{#vendorExtensions.x-go-custom-tag}} {{{.}}}{{/vendorExtensions.x-go-custom-tag}}`
{{/vars}}
}{{/isEnum}}{{/model}}{{/models}}
}{{/isEnum}}

// Assert{{classname}}Required checks if the required fields are not zero-ed
func Assert{{classname}}Required(obj {{classname}}) error {
{{#hasRequired}}
elements := map[string]interface{}{
{{#requiredVars}} "{{baseName}}": obj.{{name}},
{{/requiredVars}} }
for name, el := range elements {
if isZero := IsZeroValue(el); isZero {
return &RequiredError{Field: name}
}
}

{{/hasRequired}}
{{#parent}}
{{^isMap}}
{{^isArray}}
if err := Assert{{{parent}}}Required(obj.{{{parent}}}); err != nil {
return err
}

{{/isArray}}
{{/isMap}}
{{/parent}}
{{#Vars}}
{{#isNullable}}
{{#isModel}}
if obj.{{name}} != nil {
{{/isModel}}
{{#isArray}}
{{#items.isModel}}
if obj.{{name}} != nil {
{{/items.isModel}}
{{^items.isModel}}
{{#mostInnerItems.isModel}}
{{^mostInnerItems.isPrimitiveType}}
if obj.{{name}} != nil {
{{/mostInnerItems.isPrimitiveType}}
{{/mostInnerItems.isModel}}
{{/items.isModel}}
{{/isArray}}
{{/isNullable}}
{{#isModel}}
{{#isNullable}} {{/isNullable}} if err := Assert{{baseType}}Required({{#isNullable}}*{{/isNullable}}obj.{{name}}); err != nil {
{{#isNullable}} {{/isNullable}} return err
{{#isNullable}} {{/isNullable}} }
{{/isModel}}
{{#isArray}}
{{#items.isModel}}
{{#isNullable}} {{/isNullable}} for _, el := range {{#isNullable}}*{{/isNullable}}obj.{{name}} {
{{#isNullable}} {{/isNullable}} if err := Assert{{items.baseType}}Required(el); err != nil {
{{#isNullable}} {{/isNullable}} return err
{{#isNullable}} {{/isNullable}} }
{{#isNullable}} {{/isNullable}} }
{{/items.isModel}}
{{^items.isModel}}
{{#mostInnerItems.isModel}}
{{^mostInnerItems.isPrimitiveType}}
{{#isNullable}} {{/isNullable}} if err := AssertRecurse{{mostInnerItems.dataType}}Required({{#isNullable}}*{{/isNullable}}obj.{{name}}); err != nil {
{{#isNullable}} {{/isNullable}} return err
{{#isNullable}} {{/isNullable}} }
{{/mostInnerItems.isPrimitiveType}}
{{/mostInnerItems.isModel}}
{{/items.isModel}}
{{/isArray}}
{{#isNullable}}
{{#isModel}}
}
{{/isModel}}
{{#isArray}}
{{#items.isModel}}
}
{{/items.isModel}}
{{^items.isModel}}
{{#mostInnerItems.isModel}}
{{^mostInnerItems.isPrimitiveType}}
}
{{/mostInnerItems.isPrimitiveType}}
{{/mostInnerItems.isModel}}
{{/items.isModel}}
{{/isArray}}
{{/isNullable}}
{{/Vars}}
return nil
}

// AssertRecurse{{classname}}Required recursively checks if required fields are not zero-ed in a nested slice.
// Accepts only nested slice of {{classname}} (e.g. [][]{{classname}}), otherwise ErrTypeAssertionError is thrown.
func AssertRecurse{{classname}}Required(objSlice interface{}) error {
return AssertRecurseInterfaceRequired(objSlice, func(obj interface{}) error {
a{{classname}}, ok := obj.({{classname}})
if !ok {
return ErrTypeAssertionError
}
return Assert{{classname}}Required(a{{classname}})
})
}{{/model}}{{/models}}
Loading

0 comments on commit 11d29eb

Please sign in to comment.