Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[go-server] Feat: add required assertions to models #10068

Merged
merged 17 commits into from
Aug 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions bin/configs/go-server-required.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
generatorName: go-server
outputDir: samples/server/petstore/go-server-required
inputSpec: modules/openapi-generator/src/test/resources/3_0/server-required.yaml
templateDir: modules/openapi-generator/src/main/resources/go-server
additionalProperties:
hideGenerationTimestamp: "true"
packageName: petstoreserver
addResponseHeaders: true
router: "chi"
Original file line number Diff line number Diff line change
Expand Up @@ -174,14 +174,36 @@ func (c *{{classname}}Controller) {{nickname}}(w http.ResponseWriter, r *http.Re
{{paramName}} := r.Header.Get("{{baseName}}")
{{/isHeaderParam}}
{{#isBodyParam}}
{{paramName}} := &{{dataType}}{}
if err := json.NewDecoder(r.Body).Decode(&{{paramName}}); err != nil {
{{paramName}} := {{dataType}}{}
d := json.NewDecoder(r.Body)
{{^isAdditionalPropertiesTrue}}
d.DisallowUnknownFields()
{{/isAdditionalPropertiesTrue}}
if err := d.Decode(&{{paramName}}); err != nil {
c.errorHandler(w, r, &ParsingError{Err: err}, nil)
return
}
{{#isArray}}
{{#items.isModel}}
for _, el := range {{paramName}} {
if err := Assert{{baseType}}Required(el); err != nil {
c.errorHandler(w, r, err, nil)
return
}
}
{{/items.isModel}}
{{/isArray}}
{{^isArray}}
{{#isModel}}
if err := Assert{{baseType}}Required({{paramName}}); err != nil {
c.errorHandler(w, r, err, nil)
return
}
{{/isModel}}
{{/isArray}}
{{/isBodyParam}}
{{/allParams}}
result, err := c.service.{{nickname}}(r.Context(){{#allParams}}, {{#isBodyParam}}*{{/isBodyParam}}{{paramName}}{{/allParams}})
result, err := c.service.{{nickname}}(r.Context(){{#allParams}}, {{paramName}}{{/allParams}})
// If an error occurred, encode the error with the status code
if err != nil {
c.errorHandler(w, r, err, &result)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,16 @@
package {{packageName}}

import (
"errors"
"fmt"
"net/http"
)

var (
// ErrTypeAssertionError is thrown when type an interface does not match the asserted type
ErrTypeAssertionError = errors.New("unable to assert type")
)

// ParsingError indicates that an error has occurred when parsing request parameters
type ParsingError struct {
Err error
Expand All @@ -18,6 +25,15 @@ func (e *ParsingError) Error() string {
return e.Err.Error()
}

// RequiredError indicates that an error has occurred when parsing request parameters
type RequiredError struct {
Field string
}

func (e *RequiredError) Error() string {
return fmt.Sprintf("required field '%s' is zero value.", e.Field)
}

// ErrorHandler defines the required method for handling error. You may implement it and inject this into a controller if
// you would like errors to be handled differently from the DefaultErrorHandler
type ErrorHandler func(w http.ResponseWriter, r *http.Request, err error, result *ImplResponse)
Expand All @@ -28,6 +44,9 @@ func DefaultErrorHandler(w http.ResponseWriter, r *http.Request, err error, resu
if _, ok := err.(*ParsingError); ok {
// Handle parsing errors
EncodeJSONResponse(err.Error(), func(i int) *int { return &i }(http.StatusBadRequest),{{#addResponseHeaders}} map[string][]string{},{{/addResponseHeaders}} w)
} else if _, ok := err.(*RequiredError); ok {
// Handle missing required errors
EncodeJSONResponse(err.Error(), func(i int) *int { return &i }(http.StatusUnprocessableEntity),{{#addResponseHeaders}} map[string][]string{},{{/addResponseHeaders}} w)
} else {
// Handle all other errors
EncodeJSONResponse(err.Error(), &result.Code,{{#addResponseHeaders}} result.Headers,{{/addResponseHeaders}} w)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
{{>partial_header}}
package {{packageName}}

//Response return a ImplResponse struct filled
import (
"reflect"
)

// Response return a ImplResponse struct filled
func Response(code int, body interface{}) ImplResponse {
return ImplResponse {
Code: code,
Expand All @@ -13,7 +17,7 @@ func Response(code int, body interface{}) ImplResponse {
}
{{#addResponseHeaders}}

//ResponseWithHeaders return a ImplResponse struct filled, including headers
// ResponseWithHeaders return a ImplResponse struct filled, including headers
func ResponseWithHeaders(code int, headers map[string][]string, body interface{}) ImplResponse {
return ImplResponse {
Code: code,
Expand All @@ -22,3 +26,35 @@ func ResponseWithHeaders(code int, headers map[string][]string, body interface{}
}
}
{{/addResponseHeaders}}

// IsZeroValue checks if the val is the zero-ed value.
func IsZeroValue(val interface{}) bool {
return val == nil || reflect.DeepEqual(val, reflect.Zero(reflect.TypeOf(val)).Interface())
}

// AssertInterfaceRequired recursively checks each struct in a slice against the callback.
// This method traverse nested slices in a preorder fashion.
func AssertRecurseInterfaceRequired(obj interface{}, callback func(interface{}) error) error {
return AssertRecurseValueRequired(reflect.ValueOf(obj), callback)
}

// AssertNestedValueRequired checks each struct in the nested slice against the callback.
// This method traverse nested slices in a preorder fashion.
func AssertRecurseValueRequired(value reflect.Value, callback func(interface{}) error) error {
switch value.Kind() {
// If it is a struct we check using callback
case reflect.Struct:
if err := callback(value.Interface()); err != nil {
return err
}

// If it is a slice we continue recursion
case reflect.Slice:
for i := 0; i < value.Len(); i += 1 {
if err := AssertRecurseValueRequired(value.Index(i), callback); err != nil {
return err
}
}
}
return nil
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,101 @@ type {{classname}} struct {
{{/deprecated}}
{{name}} {{#isNullable}}*{{/isNullable}}{{{dataType}}} `json:"{{baseName}}{{^required}},omitempty{{/required}}"{{#vendorExtensions.x-go-custom-tag}} {{{.}}}{{/vendorExtensions.x-go-custom-tag}}`
{{/vars}}
}{{/isEnum}}{{/model}}{{/models}}
}{{/isEnum}}

// Assert{{classname}}Required checks if the required fields are not zero-ed
func Assert{{classname}}Required(obj {{classname}}) error {
{{#hasRequired}}
elements := map[string]interface{}{
{{#requiredVars}} "{{baseName}}": obj.{{name}},
{{/requiredVars}} }
for name, el := range elements {
if isZero := IsZeroValue(el); isZero {
return &RequiredError{Field: name}
}
}

{{/hasRequired}}
{{#parent}}
{{^isMap}}
{{^isArray}}
if err := Assert{{{parent}}}Required(obj.{{{parent}}}); err != nil {
return err
}

{{/isArray}}
{{/isMap}}
{{/parent}}
{{#Vars}}
{{#isNullable}}
{{#isModel}}
if obj.{{name}} != nil {
{{/isModel}}
{{#isArray}}
{{#items.isModel}}
if obj.{{name}} != nil {
{{/items.isModel}}
{{^items.isModel}}
{{#mostInnerItems.isModel}}
{{^mostInnerItems.isPrimitiveType}}
if obj.{{name}} != nil {
{{/mostInnerItems.isPrimitiveType}}
{{/mostInnerItems.isModel}}
{{/items.isModel}}
{{/isArray}}
{{/isNullable}}
{{#isModel}}
{{#isNullable}} {{/isNullable}} if err := Assert{{baseType}}Required({{#isNullable}}*{{/isNullable}}obj.{{name}}); err != nil {
{{#isNullable}} {{/isNullable}} return err
{{#isNullable}} {{/isNullable}} }
{{/isModel}}
{{#isArray}}
{{#items.isModel}}
{{#isNullable}} {{/isNullable}} for _, el := range {{#isNullable}}*{{/isNullable}}obj.{{name}} {
{{#isNullable}} {{/isNullable}} if err := Assert{{items.baseType}}Required(el); err != nil {
{{#isNullable}} {{/isNullable}} return err
{{#isNullable}} {{/isNullable}} }
{{#isNullable}} {{/isNullable}} }
{{/items.isModel}}
{{^items.isModel}}
{{#mostInnerItems.isModel}}
{{^mostInnerItems.isPrimitiveType}}
{{#isNullable}} {{/isNullable}} if err := AssertRecurse{{mostInnerItems.dataType}}Required({{#isNullable}}*{{/isNullable}}obj.{{name}}); err != nil {
{{#isNullable}} {{/isNullable}} return err
{{#isNullable}} {{/isNullable}} }
{{/mostInnerItems.isPrimitiveType}}
{{/mostInnerItems.isModel}}
{{/items.isModel}}
{{/isArray}}
{{#isNullable}}
{{#isModel}}
}
{{/isModel}}
{{#isArray}}
{{#items.isModel}}
}
{{/items.isModel}}
{{^items.isModel}}
{{#mostInnerItems.isModel}}
{{^mostInnerItems.isPrimitiveType}}
}
{{/mostInnerItems.isPrimitiveType}}
{{/mostInnerItems.isModel}}
{{/items.isModel}}
{{/isArray}}
{{/isNullable}}
{{/Vars}}
return nil
}

// AssertRecurse{{classname}}Required recursively checks if required fields are not zero-ed in a nested slice.
// Accepts only nested slice of {{classname}} (e.g. [][]{{classname}}), otherwise ErrTypeAssertionError is thrown.
func AssertRecurse{{classname}}Required(objSlice interface{}) error {
return AssertRecurseInterfaceRequired(objSlice, func(obj interface{}) error {
a{{classname}}, ok := obj.({{classname}})
if !ok {
return ErrTypeAssertionError
}
return Assert{{classname}}Required(a{{classname}})
})
}{{/model}}{{/models}}
Loading