Skip to content

Commit

Permalink
[Network Path] Fix parsing of query params (#24030)
Browse files Browse the repository at this point in the history
* fix parsing of query params

* use mux test helper func

* fix linter error
  • Loading branch information
ken-schneider authored and alexgallotta committed May 9, 2024
1 parent 8d3a61f commit 6d40de8
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 20 deletions.
22 changes: 11 additions & 11 deletions cmd/system-probe/modules/traceroute.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,8 @@ func (t *traceroute) Register(httpMux *module.Router) error {
// TODO: what other config should be passed as part of this request?
httpMux.HandleFunc("/traceroute/{host}", func(w http.ResponseWriter, req *http.Request) {
start := time.Now()
vars := mux.Vars(req)
id := getClientID(req)
cfg, err := parseParams(vars)
cfg, err := parseParams(req)
if err != nil {
log.Errorf("invalid params for host: %s: %s", cfg.DestHostname, err)
w.WriteHeader(http.StatusBadRequest)
Expand Down Expand Up @@ -99,17 +98,18 @@ func logTracerouteRequests(host string, client string, runCount uint64, start ti
}
}

func parseParams(vars map[string]string) (tracerouteutil.Config, error) {
func parseParams(req *http.Request) (tracerouteutil.Config, error) {
vars := mux.Vars(req)
host := vars["host"]
port, err := parseUint(vars, "port", 16)
port, err := parseUint(req, "port", 16)
if err != nil {
return tracerouteutil.Config{}, fmt.Errorf("invalid port: %s", err)
}
maxTTL, err := parseUint(vars, "max_ttl", 8)
maxTTL, err := parseUint(req, "max_ttl", 8)
if err != nil {
return tracerouteutil.Config{}, fmt.Errorf("invalid max_ttl: %s", err)
}
timeout, err := parseUint(vars, "timeout", 32)
timeout, err := parseUint(req, "timeout", 32)
if err != nil {
return tracerouteutil.Config{}, fmt.Errorf("invalid timeout: %s", err)
}
Expand All @@ -122,10 +122,10 @@ func parseParams(vars map[string]string) (tracerouteutil.Config, error) {
}, nil
}

func parseUint(vars map[string]string, field string, bitSize int) (uint64, error) {
value, ok := vars[field]
if !ok {
return 0, nil
func parseUint(req *http.Request, field string, bitSize int) (uint64, error) {
if req.URL.Query().Has(field) {
return strconv.ParseUint(req.URL.Query().Get(field), 10, bitSize)
}
return strconv.ParseUint(value, 10, bitSize)

return 0, nil
}
32 changes: 23 additions & 9 deletions cmd/system-probe/modules/traceroute_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,31 +8,36 @@
package modules

import (
"context"
"net/http"
"testing"

tracerouteutil "github.com/DataDog/datadog-agent/pkg/networkpath/traceroute"
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
"testing"
"github.com/stretchr/testify/require"
)

func TestParseParams(t *testing.T) {
tests := []struct {
name string
vars map[string]string
host string
params map[string]string
expectedConfig tracerouteutil.Config
expectedError string
}{
{
name: "only host",
vars: map[string]string{
"host": "1.2.3.4",
},
name: "only host",
host: "1.2.3.4",
params: map[string]string{},
expectedConfig: tracerouteutil.Config{
DestHostname: "1.2.3.4",
},
},
{
name: "all config",
vars: map[string]string{
"host": "1.2.3.4",
host: "1.2.3.4",
params: map[string]string{
"port": "42",
"max_ttl": "35",
"timeout": "1000",
Expand All @@ -47,7 +52,16 @@ func TestParseParams(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t1 *testing.T) {
config, err := parseParams(tt.vars)
req, err := http.NewRequestWithContext(context.Background(), "GET", "http://example.com", nil)
q := req.URL.Query()
for k, v := range tt.params {
q.Add(k, v)
}
req.URL.RawQuery = q.Encode()
req = mux.SetURLVars(req, map[string]string{"host": tt.host})

require.NoError(t, err)
config, err := parseParams(req)
assert.Equal(t, tt.expectedConfig, config)
if tt.expectedError != "" {
assert.EqualError(t, err, tt.expectedError)
Expand Down

0 comments on commit 6d40de8

Please sign in to comment.