Skip to content

Commit

Permalink
Merge pull request #12 from Moranilt/feature/multipart-nested-structure
Browse files Browse the repository at this point in the history
feat: add support for nested structures
  • Loading branch information
Moranilt committed Mar 18, 2024
2 parents 948024d + fb8cfec commit 6cc16d5
Show file tree
Hide file tree
Showing 4 changed files with 186 additions and 4 deletions.
44 changes: 43 additions & 1 deletion handler/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,40 @@ func (h *HandlerMaker[ReqT, RespT]) WithQuery() *HandlerMaker[ReqT, RespT] {
// SingleFile *multipart.FileHeader `mapstructure:"single_file"`
// Name string `mapstructure:"name"`
// }
//
// # Supported nested structures.
//
// Example:
//
// type Recipient struct {
// Name string `json:"name,omitempty" mapstructure:"name"`
// Age string `json:"age,omitempty" mapstructure:"age"`
// }
//
// type CreateOrder struct {
// Recipient Recipient `json:"recipient" mapstructure:"recipient"`
// Content map[string]string `json:"content" mapstructure:"content"`
// }
//
// Request body(multipart-form):
//
// {
// "recipient[name]": "John",
// "recipient[age]": "30",
// "content[title]": "content title",
// "content[body]": "content body"
// }
//
// Result:
//
// func main() {
// // ...
// var order CreateOrder
// fmt.Println(order.Recipient.Name) // John
// fmt.Println(order.Recipient.Age) // 30
// fmt.Println(order.Content["title"]) // content title
// fmt.Println(order.Content["body"]) // content body
// }
func (h *HandlerMaker[ReqT, RespT]) WithMultipart(maxMemory int64) *HandlerMaker[ReqT, RespT] {
if h.err != nil {
return h
Expand All @@ -171,7 +205,15 @@ func (h *HandlerMaker[ReqT, RespT]) WithMultipart(maxMemory int64) *HandlerMaker
result := make(map[string]any, len(h.request.MultipartForm.Value)+len(h.request.MultipartForm.File))
for name, value := range h.request.MultipartForm.Value {
if len(value) > 0 {
result[name] = value[0]
fieldName, subName, validName := extractSubName(name)
if validName {
if _, ok := result[fieldName]; !ok {
result[fieldName] = make(map[string]any)
}
result[fieldName].(map[string]any)[subName] = value[0]
} else {
result[name] = value[0]
}
}
}

Expand Down
82 changes: 79 additions & 3 deletions handler/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,18 @@ type mockRequest struct {
Message string `json:"message,omitempty" mapstructure:"message"`
}

type mockRecipient struct {
Name string `json:"name,omitempty" mapstructure:"name"`
Age string `json:"age,omitempty" mapstructure:"age"`
}

type mockMultipartRequest struct {
Files []*multipart.FileHeader `json:"files" mapstructure:"files"`
Name string `json:"name" mapstructure:"name"`
Age string `json:"age" mapstructure:"age"`
Files []*multipart.FileHeader `json:"files" mapstructure:"files"`
SingleFile *multipart.FileHeader `json:"single_file" mapstructure:"single_file"`
Name string `json:"name" mapstructure:"name"`
Age string `json:"age" mapstructure:"age"`
Recipient mockRecipient `json:"recipient" mapstructure:"recipient"`
Content map[string]string `json:"content" mapstructure:"content"`
}

type mockResponse struct {
Expand Down Expand Up @@ -340,6 +348,18 @@ func TestHandler(t *testing.T) {
}
}

if request.SingleFile == nil {
return &mockResponse{
Info: "single file field is empty",
}
}

if request.SingleFile.Filename == "" {
return &mockResponse{
Info: "single file filename is empty",
}
}

if request.Name == "" {
return &mockResponse{
Info: errNameRequired,
Expand All @@ -352,6 +372,30 @@ func TestHandler(t *testing.T) {
}
}

if request.Recipient.Age == "" {
return &mockResponse{
Info: "recipient age required",
}
}

if request.Recipient.Name == "" {
return &mockResponse{
Info: "recipient name required",
}
}

if request.Content["title"] == "" {
return &mockResponse{
Info: "content title required",
}
}

if request.Content["body"] == "" {
return &mockResponse{
Info: "content body required",
}
}

return &mockResponse{
Info: successInfo,
}
Expand Down Expand Up @@ -564,6 +608,22 @@ func mockedMultipartData(t testing.TB) *multipartData {
name: "age",
value: []byte("20"),
},
{
name: "recipient[name]",
value: []byte("Elizabeth"),
},
{
name: "recipient[age]",
value: []byte("21"),
},
{
name: "content[title]",
value: []byte("Content Title"),
},
{
name: "content[body]",
value: []byte("Content Body"),
},
}

var requestBody bytes.Buffer
Expand Down Expand Up @@ -610,6 +670,22 @@ func mockedMultipartData(t testing.TB) *multipartData {
io.Copy(fw, newFile)
}

newFile, err := os.Create("single_file.json")
if err != nil {
t.Errorf("create file: %v", err)
return nil
}

newFile.WriteString("{\"name\": \"Elizabeth\"}")
defer os.Remove(newFile.Name())

fw, err := w.CreateFormFile("single_file", newFile.Name())
if err != nil {
t.Errorf("create form file: %v", err)
return nil
}
io.Copy(fw, newFile)

return &multipartData{
data: &requestBody,
header: w.FormDataContentType(),
Expand Down
17 changes: 17 additions & 0 deletions handler/valid_array_name.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,20 @@ func extractArrayName(name string) (string, bool) {

return name[:i], true
}

func isValidSubName(name string) (int, bool) {
lastIndex := len(name) - 1
if lastIndex == -1 || name[lastIndex] != ']' {
return -1, false
}
open := strings.Index(name, "[")
return open, open != -1 && name[lastIndex] == ']' && lastIndex-open > 1
}

func extractSubName(name string) (string, string, bool) {
i, valid := isValidSubName(name)
if !valid {
return name, "", false
}
return name[:i], name[i+1 : len(name)-1], true
}
47 changes: 47 additions & 0 deletions handler/valid_array_name_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,50 @@ func BenchmarkValidArrayName(b *testing.B) {
isValidNameArray("sahgdjhgajhds[]")
}
}

func TestExtractSubName(t *testing.T) {
tests := []struct {
name string
field string
expected_start string
expected_sub string
valid bool
}{
{
name: "valid array name",
field: "content[title]",
expected_start: "content",
expected_sub: "title",
valid: true,
},
{
name: "not valid name",
field: "content[]",
expected_start: "content[]",
expected_sub: "",
valid: false,
},
{
name: "not valid name",
field: "content",
expected_start: "content",
expected_sub: "",
valid: false,
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
start, sub, valid := extractSubName(test.field)
if valid != test.valid {
t.Errorf("got %t, expected %t", valid, test.valid)
}
if start != test.expected_start {
t.Errorf("got %s, expected %s", start, test.expected_start)
}
if sub != test.expected_sub {
t.Errorf("got %s, expected %s", sub, test.expected_sub)
}
})
}
}

0 comments on commit 6cc16d5

Please sign in to comment.