Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[go-server] Feat: add required assertions to models #10068

Merged
merged 17 commits into from
Aug 7, 2021
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ public class GoServerCodegen extends AbstractGoCodegen {
protected String sourceFolder = "go";
protected Boolean corsFeatureEnabled = false;
protected Boolean addResponseHeaders = false;

protected Boolean disallowUnknownFields = true;

public GoServerCodegen() {
super();
Expand Down Expand Up @@ -106,6 +106,12 @@ public GoServerCodegen() {
optAddResponseHeaders.defaultValue(addResponseHeaders.toString());
cliOptions.add(optAddResponseHeaders);

// option to disallow unknown fields in body
lwj5 marked this conversation as resolved.
Show resolved Hide resolved
CliOption optDisallowUnknownFields = new CliOption("disallowUnknownFields", "To disallow unknown fields in request body");
optDisallowUnknownFields.setType("bool");
optDisallowUnknownFields.defaultValue(disallowUnknownFields.toString());
cliOptions.add(optDisallowUnknownFields);

/*
* Models. You can write model files using the modelTemplateFiles map.
* if you want to create one template for file, you can do so here.
Expand Down Expand Up @@ -211,6 +217,12 @@ public void processOpts() {
additionalProperties.put("addResponseHeaders", addResponseHeaders);
}

if (additionalProperties.containsKey("disallowUnknownFields")) {
this.setDisallowUnknownFields(convertPropertyToBooleanAndWriteBack("disallowUnknownFields"));
} else {
additionalProperties.put("disallowUnknownFields", disallowUnknownFields);
}

if (additionalProperties.containsKey(CodegenConstants.ENUM_CLASS_PREFIX)) {
setEnumClassPrefix(Boolean.parseBoolean(additionalProperties.get(CodegenConstants.ENUM_CLASS_PREFIX).toString()));
if (enumClassPrefix) {
Expand Down Expand Up @@ -359,4 +371,8 @@ public void setFeatureCORS(Boolean featureCORS) {
public void setAddResponseHeaders(Boolean addResponseHeaders) {
this.addResponseHeaders = addResponseHeaders;
}

public void setDisallowUnknownFields(Boolean disallowUnknownFields) {
this.disallowUnknownFields = disallowUnknownFields;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -171,14 +171,36 @@ func (c *{{classname}}Controller) {{nickname}}(w http.ResponseWriter, r *http.Re
{{paramName}} := r.Header.Get("{{baseName}}")
{{/isHeaderParam}}
{{#isBodyParam}}
{{paramName}} := &{{dataType}}{}
if err := json.NewDecoder(r.Body).Decode(&{{paramName}}); err != nil {
{{paramName}} := {{dataType}}{}
d := json.NewDecoder(r.Body)
{{#disallowUnknownFields}}
d.DisallowUnknownFields()
{{/disallowUnknownFields}}
if err := d.Decode(&{{paramName}}); err != nil {
c.errorHandler(w, r, &ParsingError{Err: err}, nil)
return
}
{{#isArray}}
{{#items.isModel}}
for _, el := range {{paramName}} {
if err := AssertRequired{{baseType}}(el); err != nil {
c.errorHandler(w, r, err, nil)
return
}
}
{{/items.isModel}}
{{/isArray}}
{{^isArray}}
{{#isModel}}
if err := AssertRequired{{baseType}}({{paramName}}); err != nil {
c.errorHandler(w, r, err, nil)
return
}
{{/isModel}}
{{/isArray}}
{{/isBodyParam}}
{{/allParams}}
result, err := c.service.{{nickname}}(r.Context(){{#allParams}}, {{#isBodyParam}}*{{/isBodyParam}}{{paramName}}{{/allParams}})
result, err := c.service.{{nickname}}(r.Context(){{#allParams}}, {{paramName}}{{/allParams}})
// If an error occurred, encode the error with the status code
if err != nil {
c.errorHandler(w, r, err, &result)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
package {{packageName}}

import (
"fmt"
"net/http"
)

Expand All @@ -18,6 +19,15 @@ func (e *ParsingError) Error() string {
return e.Err.Error()
}

// RequiredError indicates that an error has occurred when parsing request parameters
type RequiredError struct {
Field string
}

func (e *RequiredError) Error() string {
return fmt.Sprintf("required field '%s' is zero value.", e.Field)
}

// ErrorHandler defines the required method for handling error. You may implement it and inject this into a controller if
// you would like errors to be handled differently from the DefaultErrorHandler
type ErrorHandler func(w http.ResponseWriter, r *http.Request, err error, result *ImplResponse)
Expand All @@ -28,6 +38,9 @@ func DefaultErrorHandler(w http.ResponseWriter, r *http.Request, err error, resu
if _, ok := err.(*ParsingError); ok {
// Handle parsing errors
EncodeJSONResponse(err.Error(), func(i int) *int { return &i }(http.StatusBadRequest),{{#addResponseHeaders}} map[string][]string{},{{/addResponseHeaders}} w)
} else if _, ok := err.(*RequiredError); ok {
// Handle missing required errors
EncodeJSONResponse(err.Error(), func(i int) *int { return &i }(http.StatusUnprocessableEntity),{{#addResponseHeaders}} map[string][]string{},{{/addResponseHeaders}} w)
} else {
// Handle all other errors
EncodeJSONResponse(err.Error(), &result.Code,{{#addResponseHeaders}} result.Headers,{{/addResponseHeaders}} w)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
{{>partial_header}}
package {{packageName}}

//Response return a ImplResponse struct filled
import (
"reflect"
)

// Response return a ImplResponse struct filled
func Response(code int, body interface{}) ImplResponse {
return ImplResponse {
Code: code,
Expand All @@ -13,7 +17,7 @@ func Response(code int, body interface{}) ImplResponse {
}
{{#addResponseHeaders}}

//ResponseWithHeaders return a ImplResponse struct filled, including headers
// ResponseWithHeaders return a ImplResponse struct filled, including headers
func ResponseWithHeaders(code int, headers map[string][]string, body interface{}) ImplResponse {
return ImplResponse {
Code: code,
Expand All @@ -22,3 +26,8 @@ func ResponseWithHeaders(code int, headers map[string][]string, body interface{}
}
}
{{/addResponseHeaders}}

// IsZeroValue checks if the val is the zero-ed value
func IsZeroValue(val interface{}) bool {
return val == nil || reflect.DeepEqual(val, reflect.Zero(reflect.TypeOf(val)).Interface())
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,69 @@ type {{classname}} struct {
// {{{description}}}{{/description}}
{{name}} {{#isNullable}}*{{/isNullable}}{{{dataType}}} `json:"{{baseName}}{{^required}},omitempty{{/required}}"{{#vendorExtensions.x-go-custom-tag}} {{{.}}}{{/vendorExtensions.x-go-custom-tag}}`
{{/vars}}
}{{/isEnum}}{{/model}}{{/models}}
}{{/isEnum}}

// AssertRequired{{classname}} checks if the required fields are not zero-ed
func AssertRequired{{classname}}(obj {{classname}}) error {
elements := map[string]interface{}{
{{#requiredVars}} "{{baseName}}": obj.{{name}},
{{/requiredVars}} }
for name, el := range elements {
if isZero := IsZeroValue(el); isZero {
return &RequiredError{Field: name}
}
}

{{#Vars}}
{{#isModel}}
{{#isNullable}}
if obj.{{name}} != nil {
if err := AssertRequired{{baseType}}(*obj.{{name}}); err != nil {
return err
}
}

{{/isNullable}}
{{^isNullable}}
if err := AssertRequired{{baseType}}(obj.{{name}}); err != nil {
return err
}

{{/isNullable}}
{{/isModel}}
{{#isArray}}
{{#items.isModel}}
{{#isNullable}}
if obj.{{name}} != nil {
for _, el := range {{#isNullable}}*{{/isNullable}}obj.{{name}} {
if err := AssertRequired{{items.baseType}}(el); err != nil {
return err
}
}
}

{{/isNullable}}
{{^isNullable}}
for _, el := range {{#isNullable}}*{{/isNullable}}obj.{{name}} {
if err := AssertRequired{{items.baseType}}(el); err != nil {
return err
}
}

{{/isNullable}}
{{/items.isModel}}
{{/isArray}}
{{/Vars}}
{{#parent}}
{{^isMap}}
{{^isArray}}
if err := AssertRequired{{{parent}}}(obj.{{{parent}}}); err != nil {
return err
}

{{/isArray}}
{{/isMap}}
{{/parent}}
return nil
}
{{/model}}{{/models}}
24 changes: 18 additions & 6 deletions samples/server/petstore/go-api-server/go/api_pet.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,12 +103,18 @@ func (c *PetApiController) Routes() Routes {

// AddPet - Add a new pet to the store
func (c *PetApiController) AddPet(w http.ResponseWriter, r *http.Request) {
pet := &Pet{}
if err := json.NewDecoder(r.Body).Decode(&pet); err != nil {
pet := Pet{}
d := json.NewDecoder(r.Body)
d.DisallowUnknownFields()
if err := d.Decode(&pet); err != nil {
c.errorHandler(w, r, &ParsingError{Err: err}, nil)
return
}
result, err := c.service.AddPet(r.Context(), *pet)
if err := AssertRequiredPet(pet); err != nil {
c.errorHandler(w, r, err, nil)
return
}
result, err := c.service.AddPet(r.Context(), pet)
// If an error occurred, encode the error with the status code
if err != nil {
c.errorHandler(w, r, err, &result)
Expand Down Expand Up @@ -192,12 +198,18 @@ func (c *PetApiController) GetPetById(w http.ResponseWriter, r *http.Request) {

// UpdatePet - Update an existing pet
func (c *PetApiController) UpdatePet(w http.ResponseWriter, r *http.Request) {
pet := &Pet{}
if err := json.NewDecoder(r.Body).Decode(&pet); err != nil {
pet := Pet{}
d := json.NewDecoder(r.Body)
d.DisallowUnknownFields()
if err := d.Decode(&pet); err != nil {
c.errorHandler(w, r, &ParsingError{Err: err}, nil)
return
}
result, err := c.service.UpdatePet(r.Context(), *pet)
if err := AssertRequiredPet(pet); err != nil {
c.errorHandler(w, r, err, nil)
return
}
result, err := c.service.UpdatePet(r.Context(), pet)
// If an error occurred, encode the error with the status code
if err != nil {
c.errorHandler(w, r, err, &result)
Expand Down
12 changes: 9 additions & 3 deletions samples/server/petstore/go-api-server/go/api_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,12 +128,18 @@ func (c *StoreApiController) GetOrderById(w http.ResponseWriter, r *http.Request

// PlaceOrder - Place an order for a pet
func (c *StoreApiController) PlaceOrder(w http.ResponseWriter, r *http.Request) {
order := &Order{}
if err := json.NewDecoder(r.Body).Decode(&order); err != nil {
order := Order{}
d := json.NewDecoder(r.Body)
d.DisallowUnknownFields()
if err := d.Decode(&order); err != nil {
c.errorHandler(w, r, &ParsingError{Err: err}, nil)
return
}
result, err := c.service.PlaceOrder(r.Context(), *order)
if err := AssertRequiredOrder(order); err != nil {
c.errorHandler(w, r, err, nil)
return
}
result, err := c.service.PlaceOrder(r.Context(), order)
// If an error occurred, encode the error with the status code
if err != nil {
c.errorHandler(w, r, err, &result)
Expand Down
52 changes: 40 additions & 12 deletions samples/server/petstore/go-api-server/go/api_user.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,12 +103,18 @@ func (c *UserApiController) Routes() Routes {

// CreateUser - Create user
func (c *UserApiController) CreateUser(w http.ResponseWriter, r *http.Request) {
user := &User{}
if err := json.NewDecoder(r.Body).Decode(&user); err != nil {
user := User{}
d := json.NewDecoder(r.Body)
d.DisallowUnknownFields()
if err := d.Decode(&user); err != nil {
c.errorHandler(w, r, &ParsingError{Err: err}, nil)
return
}
result, err := c.service.CreateUser(r.Context(), *user)
if err := AssertRequiredUser(user); err != nil {
c.errorHandler(w, r, err, nil)
return
}
result, err := c.service.CreateUser(r.Context(), user)
// If an error occurred, encode the error with the status code
if err != nil {
c.errorHandler(w, r, err, &result)
Expand All @@ -121,12 +127,20 @@ func (c *UserApiController) CreateUser(w http.ResponseWriter, r *http.Request) {

// CreateUsersWithArrayInput - Creates list of users with given input array
func (c *UserApiController) CreateUsersWithArrayInput(w http.ResponseWriter, r *http.Request) {
user := &[]User{}
if err := json.NewDecoder(r.Body).Decode(&user); err != nil {
user := []User{}
d := json.NewDecoder(r.Body)
d.DisallowUnknownFields()
if err := d.Decode(&user); err != nil {
c.errorHandler(w, r, &ParsingError{Err: err}, nil)
return
}
result, err := c.service.CreateUsersWithArrayInput(r.Context(), *user)
for _, el := range user {
if err := AssertRequiredUser(el); err != nil {
c.errorHandler(w, r, err, nil)
return
}
}
result, err := c.service.CreateUsersWithArrayInput(r.Context(), user)
// If an error occurred, encode the error with the status code
if err != nil {
c.errorHandler(w, r, err, &result)
Expand All @@ -139,12 +153,20 @@ func (c *UserApiController) CreateUsersWithArrayInput(w http.ResponseWriter, r *

// CreateUsersWithListInput - Creates list of users with given input array
func (c *UserApiController) CreateUsersWithListInput(w http.ResponseWriter, r *http.Request) {
user := &[]User{}
if err := json.NewDecoder(r.Body).Decode(&user); err != nil {
user := []User{}
d := json.NewDecoder(r.Body)
d.DisallowUnknownFields()
if err := d.Decode(&user); err != nil {
c.errorHandler(w, r, &ParsingError{Err: err}, nil)
return
}
result, err := c.service.CreateUsersWithListInput(r.Context(), *user)
for _, el := range user {
if err := AssertRequiredUser(el); err != nil {
c.errorHandler(w, r, err, nil)
return
}
}
result, err := c.service.CreateUsersWithListInput(r.Context(), user)
// If an error occurred, encode the error with the status code
if err != nil {
c.errorHandler(w, r, err, &result)
Expand Down Expand Up @@ -221,12 +243,18 @@ func (c *UserApiController) UpdateUser(w http.ResponseWriter, r *http.Request) {
params := mux.Vars(r)
username := params["username"]

user := &User{}
if err := json.NewDecoder(r.Body).Decode(&user); err != nil {
user := User{}
d := json.NewDecoder(r.Body)
d.DisallowUnknownFields()
if err := d.Decode(&user); err != nil {
c.errorHandler(w, r, &ParsingError{Err: err}, nil)
return
}
result, err := c.service.UpdateUser(r.Context(), username, *user)
if err := AssertRequiredUser(user); err != nil {
c.errorHandler(w, r, err, nil)
return
}
result, err := c.service.UpdateUser(r.Context(), username, user)
// If an error occurred, encode the error with the status code
if err != nil {
c.errorHandler(w, r, err, &result)
Expand Down
Loading