-
-
Notifications
You must be signed in to change notification settings - Fork 1
/
withBody.go
228 lines (174 loc) · 5.88 KB
/
withBody.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
package http
import (
"encoding/json"
"reflect"
"strings"
"github.com/gofiber/fiber/v2"
"github.com/go-playground/locales/en"
ut "github.com/go-playground/universal-translator"
en2 "github.com/go-playground/validator/translations/en"
"gopkg.in/go-playground/validator.v9"
)
// DecodeHandlerFunc is the handler signature wrapped by WithBody and WithDecode.
// It is invoked with the request payload after the JSON body has been decoded
// into it, checked for unknown top-level keys and validated, together with the
// Fiber context of the current request.
//
// Flow: request JSON -> WithBody / WithDecode -> DecodeHandlerFunc.
type DecodeHandlerFunc func(payload any, ctx *fiber.Ctx) error
type (
	// PayloadContextValue names the keys this package stores in fiber.Ctx.Locals
	// (for example "payload" and "fields").
	// NOTE(review): call sites convert it back to a plain string before use, so
	// the resulting keys can still collide with other string keys.
	PayloadContextValue string

	// ConstructorFunc builds a fresh, empty payload instance for WithDecode.
	// It should return a pointer, because the request body is decoded into the
	// returned value with json.Unmarshal.
	ConstructorFunc func() any
)
// decoderHandler carries what FiberHandlerFunc needs to decode a request body:
// the wrapped handler plus a way to allocate an empty payload to decode into.
type decoderHandler struct {
	// handler receives the decoded, validated payload.
	handler DecodeHandlerFunc
	// constructor, when non-nil, allocates the payload (set by WithDecode).
	constructor ConstructorFunc
	// structSource is a prototype whose type is used to allocate the payload
	// when constructor is nil (set by WithBody).
	structSource any
}
// newOfType returns a pointer to a freshly allocated zero value shaped like s.
//
// When s is a pointer (e.g. &MyPayload{}), the result has the same type as s
// (*MyPayload). When s is a non-pointer value (e.g. MyPayload{} or a map), the
// result is a pointer to that type. Either way, the result is a non-nil
// pointer and therefore a valid json.Unmarshal target. The contents of s are
// never read or shared; only its type is used.
//
// It panics if s is nil, because no type can be derived from it.
func newOfType(s any) any {
	t := reflect.TypeOf(s)
	if t == nil {
		panic("http: newOfType called with a nil prototype")
	}
	// Accept both &T{} and T{} as prototypes; previously only pointers worked
	// because Elem() panics on non-pointer types.
	if t.Kind() == reflect.Ptr {
		t = t.Elem()
	}
	return reflect.New(t).Interface()
}
// FiberHandlerFunc adapts d to a fiber.Handler. For each request it:
//  1. allocates a fresh payload (via d.constructor, or shaped like d.structSource),
//  2. decodes the JSON body into it, returning any decoding error unchanged,
//  3. rejects top-level JSON keys the payload type does not accept, answering
//     with BadRequest and the offending keys/values,
//  4. validates the payload with ValidateStruct, answering BadRequest on failure,
//  5. stores the (necessarily empty) unknown-key map in Locals under "fields"
//     and calls the wrapped handler with the payload.
//
// A key is treated as unknown only when encoding/json itself refuses it (see
// acceptsJSONField). As a result, `omitempty` fields sent with their zero value
// are not mistaken for unknown keys, even though re-marshalling the payload
// drops them. Like encoding/json, key-to-field matching is case-insensitive.
// Only top-level keys are checked; nested objects are not inspected.
func (d *decoderHandler) FiberHandlerFunc(c *fiber.Ctx) error {
	var s any
	if d.constructor != nil {
		s = d.constructor()
	} else {
		s = newOfType(d.structSource)
	}

	bodyBytes := c.Body()
	if err := json.Unmarshal(bodyBytes, s); err != nil {
		return err
	}

	// Keys that survive a decode/encode round trip are certainly known to the
	// payload type; only the remaining ones need the stricter probe below.
	marshaled, err := json.Marshal(s)
	if err != nil {
		return err
	}
	var originalMap, marshaledMap map[string]any
	if err := json.Unmarshal(bodyBytes, &originalMap); err != nil {
		return err
	}
	if err := json.Unmarshal(marshaled, &marshaledMap); err != nil {
		return err
	}

	// Collect keys present in the request but not recognized by the payload type.
	diffFields := make(map[string]any)
	for key, value := range originalMap {
		if _, ok := marshaledMap[key]; ok {
			continue
		}
		// Absent after the round trip can also mean an `omitempty` field sent
		// with its zero value; confirm with encoding/json before rejecting.
		if acceptsJSONField(s, key) {
			continue
		}
		diffFields[key] = value
	}
	if len(diffFields) > 0 {
		return BadRequest(c, fiber.Map{
			"code":    "BAD_REQUEST",
			"message": "Incoming JSON fields do not match the request payload fields",
			"fields":  diffFields,
		})
	}

	if err := ValidateStruct(s); err != nil {
		return BadRequest(c, err)
	}

	// diffFields is always empty here (non-empty sets were rejected above); it is
	// still published because handlers may read this key.
	c.Locals(string(PayloadContextValue("fields")), diffFields)
	return d.handler(s, c)
}

// acceptsJSONField reports whether encoding/json, with unknown fields
// disallowed, accepts key as a top-level member when decoding into a fresh
// value of payload's type. The probe pairs key with null, so only the key's
// existence is checked: its real value already decoded successfully, and
// nested objects are deliberately not inspected. If a known field's custom
// UnmarshalJSON rejects null, the key is conservatively reported as not
// accepted.
func acceptsJSONField(payload any, key string) bool {
	probe, err := json.Marshal(map[string]any{key: nil})
	if err != nil {
		return false
	}
	dec := json.NewDecoder(strings.NewReader(string(probe)))
	dec.DisallowUnknownFields()
	return dec.Decode(newOfType(payload)) == nil
}
// WithDecode returns a fiber.Handler that decodes each request body into a new
// instance produced by c and then passes it to h.
// c is called once per request and should return a pointer.
func WithDecode(c ConstructorFunc, h DecodeHandlerFunc) fiber.Handler {
	return (&decoderHandler{constructor: c, handler: h}).FiberHandlerFunc
}
// WithBody returns a fiber.Handler that decodes each request body into a new
// value of the same type as the prototype s, and then passes it to h.
// s is only used for its type (typically a pointer such as &MyPayload{}); its
// contents are never shared between requests.
func WithBody(s any, h DecodeHandlerFunc) fiber.Handler {
	return (&decoderHandler{structSource: s, handler: h}).FiberHandlerFunc
}
// SetBodyInContext adapts a plain fiber.Handler into a DecodeHandlerFunc.
// The decoded payload is stored in the request locals under the "payload" key,
// where GetPayloadFromContext can retrieve it, before handler runs.
func SetBodyInContext(handler fiber.Handler) DecodeHandlerFunc {
	payloadKey := string(PayloadContextValue("payload"))
	return func(payload any, ctx *fiber.Ctx) error {
		ctx.Locals(payloadKey, payload)
		return handler(ctx)
	}
}
// GetPayloadFromContext returns the value stored in the request locals under
// the "payload" key (as done by SetBodyInContext). Callers must type-assert it
// to their payload type.
func GetPayloadFromContext(c *fiber.Ctx) any {
	key := string(PayloadContextValue("payload"))
	return c.Locals(key)
}
// payloadValidator and payloadTranslator are built once, when the package is
// initialised, and shared by every ValidateStruct call. validator.Validate is
// documented as safe for concurrent use and intended to be a singleton (it
// caches per-type struct metadata), so rebuilding it, and re-registering every
// translation, on each request is wasted work. If translation registration
// fails, newValidator panics at startup instead of on the first request.
var payloadValidator, payloadTranslator = newValidator()

// ValidateStruct validates s against its validation struct tags using the
// package's shared validator.
//
// s may be a struct or a pointer to a struct. Any other kind, including nil and
// nil pointers, is accepted without validation and yields nil.
//
// On field-level failures it returns a *ValidationError whose Fields map each
// failing field (named by its JSON tag) to a translated English message. Any
// other error reported by the validator is returned unchanged rather than
// causing a panic.
func ValidateStruct(s any) error {
	rv := reflect.ValueOf(s)
	if rv.Kind() == reflect.Ptr {
		rv = rv.Elem()
	}
	if rv.Kind() != reflect.Struct {
		return nil
	}

	err := payloadValidator.Struct(s)
	if err == nil {
		return nil
	}

	verrs, ok := err.(validator.ValidationErrors)
	if !ok {
		// Not a field-level failure (e.g. *validator.InvalidValidationError):
		// surface it instead of panicking on an unchecked type assertion.
		return err
	}
	errPtr := malformedRequestErr(verrs, payloadTranslator)
	return &errPtr
}
//nolint:ireturn
// newValidator builds a validator that reports fields by their JSON names and
// an English translator for its error messages. The translator includes the
// library's default translations plus messages for the custom tags "CPF",
// "localPhoneNumber", "phoneNumber" and "countryCode".
//
// It panics if the default English translations cannot be registered. Errors
// from registering the custom-tag translations are ignored.
//
// NOTE(review): no RegisterValidation call for the custom tags is visible here,
// so using them in `validate` tags with this validator presumably fails;
// confirm where (or whether) those validation functions are registered.
func newValidator() (*validator.Validate, ut.Translator) {
	enLocale := en.New()
	translator, _ := ut.New(enLocale, enLocale).GetTranslator("en")

	validate := validator.New()
	if err := en2.RegisterDefaultTranslations(validate, translator); err != nil {
		panic(err)
	}

	// Report fields by their JSON name; fields tagged `json:"-"` get an empty name.
	validate.RegisterTagNameFunc(func(field reflect.StructField) string {
		name, _, _ := strings.Cut(field.Tag.Get("json"), ",")
		if name == "-" {
			return ""
		}
		return name
	})

	customMessages := []struct{ tag, text string }{
		{"CPF", "{0} must be a valid Brazilian CPF"},
		{"localPhoneNumber", "{0} must be a valid phone number without the country code"},
		{"phoneNumber", "{0} must be a valid phone number"},
		{"countryCode", "{0} must be a valid countryCode registered in https://countrycode.org"},
	}
	for _, msg := range customMessages {
		// Fresh per-iteration copies, so each closure keeps its own tag/text.
		tag, text := msg.tag, msg.text
		_ = validate.RegisterTranslation(tag, translator,
			func(tr ut.Translator) error {
				return tr.Add(tag, text, true)
			},
			func(tr ut.Translator, fe validator.FieldError) string {
				translated, _ := tr.T(tag, fe.Field())
				return translated
			},
		)
	}

	return validate, translator
}
// malformedRequestErr wraps validator field errors into the package's
// ValidationError. It uses code "400", message "Malformed request." and one
// translated message per failing field.
func malformedRequestErr(verrs validator.ValidationErrors, trans ut.Translator) ValidationError {
	perField := fields(verrs, trans)
	return ValidationError{
		Code:    "400",
		Message: "Malformed request.",
		Fields:  perField,
	}
}
// fields maps each failing field's name (as reported by FieldError.Field) to
// its message translated with trans. It returns nil when there are no errors.
// NOTE(review): keys are leaf names only, so same-named fields in different
// nested structs would overwrite each other; confirm whether that matters.
func fields(verrs validator.ValidationErrors, trans ut.Translator) FieldValidations {
	if len(verrs) == 0 {
		return nil
	}
	out := make(FieldValidations, len(verrs))
	for i := range verrs {
		out[verrs[i].Field()] = verrs[i].Translate(trans)
	}
	return out
}