Skip to content

Commit

Permalink
fix: restore req body (#397)
Browse files Browse the repository at this point in the history
  • Loading branch information
sigua-cs committed Feb 26, 2024
1 parent f4daa21 commit b835b2d
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 9 deletions.
51 changes: 51 additions & 0 deletions main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,37 @@ func TestServe(t *testing.T) {
},
startTLS,
},
{
"https request body is not empty",
"testdata/https.yml",
func(t *testing.T) {
query := "SELECT SleepTimeout"
buf := bytes.NewBufferString(query)
req, err := http.NewRequest("POST", "https://127.0.0.1:8443", buf)
checkErr(t, err)
req.SetBasicAuth("default", "qwerty")
req.Close = true

resp, err := tlsClient.Do(req)
checkErr(t, err)
if resp.StatusCode != http.StatusGatewayTimeout {
t.Fatalf("unexpected status code: %d; expected: %d", resp.StatusCode, http.StatusGatewayTimeout)
}

bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("error while reading body from response; err: %q", err)
}

b := string(bodyBytes)
if !strings.Contains(b, query) {
t.Fatalf("expected request body: %q; got: %q", query, b)
}

resp.Body.Close()
},
startTLS,
},
{
"https cache with mix query source",
"testdata/https.cache.yml",
Expand Down Expand Up @@ -1019,6 +1050,26 @@ func fakeCHHandler(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
fmt.Fprint(w, b)
fmt.Fprint(w, "Ok.\n")
case strings.Contains(q, "SELECT SleepTimeout"):
w.WriteHeader(http.StatusGatewayTimeout)

bodyBytes, err := io.ReadAll(r.Body)
if err != nil {
fmt.Fprintf(w, "query: %s; error while reading body: %s", query, err)
return
}

b := string(bodyBytes)
// Ensure the original request body is not empty and remains unchanged
// after it is processed by getFullQuery.
if b == "" && b != q {
fmt.Fprintf(w, "got original req body: <%s>; escaped query: <%s>", b, q)
return
}

// execute sleep 1.5 sec
time.Sleep(1500 * time.Millisecond)
fmt.Fprint(w, b)
default:
if strings.Contains(string(query), killQueryPattern) {
fakeCHState.kill()
Expand Down
16 changes: 7 additions & 9 deletions proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,23 +207,21 @@ func executeWithRetry(
startTime := time.Now()
var since float64

// keep the request body
body, err := io.ReadAll(req.Body)
req.Body.Close()
// Use readAndRestoreRequestBody to read the entire request body into a byte slice,
// and to restore req.Body so that it can be reused later in the code.
body, err := readAndRestoreRequestBody(req)
if err != nil {
since = time.Since(startTime).Seconds()

since := time.Since(startTime).Seconds()
return since, err
}

numRetry := 0
for {
// update body
req.Body = io.NopCloser(bytes.NewBuffer(body))
req.Body.Close()

rp(rw, req)

// Restore req.Body after it's consumed by 'rp' for potential reuse.
req.Body = io.NopCloser(bytes.NewBuffer(body))

err := ctx.Err()
if err != nil {
since = time.Since(startTime).Seconds()
Expand Down
19 changes: 19 additions & 0 deletions proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1012,6 +1012,25 @@ func TestReverseProxy_ServeHTTP2(t *testing.T) {
t.Fatalf("expected response: %q; got: %q", expected, b)
}
})

t.Run("request body not empty", func(t *testing.T) {
proxy, err := getProxy(goodCfg)
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
body := bytes.NewBufferString("SELECT sleep(1.5)")
expected := "SELECT sleep(1.5)"
req := httptest.NewRequest("POST", fakeServer.URL, body)

resp := makeCustomRequest(proxy, req)
b := bbToString(t, resp.Body)
resp.Body.Close()

if !strings.Contains(b, expected) {
t.Fatalf("expected response: %q; got: %q", expected, b)
}

})
}

func getNetwork(s string) *net.IPNet {
Expand Down
23 changes: 23 additions & 0 deletions utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,17 @@ func getQuerySnippetFromBody(req *http.Request) string {
// 'read' request body, so it traps into to crc.
// Ignore any errors, since getQuerySnippet is called only
// during error reporting.
// Temporary solution: Quick and dirty way to work with the request body.
// TODO: Create an original copy of req.Body and work with the copy to avoid altering the original request.
// This current approach consumes the req.Body content with io.Copy(io.Discard, crc) to reset the internal state of crc.
// However, it is not the most efficient or safest method, as it modifies the original req.Body.
io.Copy(io.Discard, crc) // nolint
data := crc.String()

// Here, we attempt to restore req.Body by wrapping the string data in a ReadCloser.
// This is part of the temporary solution and should be replaced with a more robust method that does not consume the original req.Body.
req.Body = io.NopCloser(strings.NewReader(data))

u := getDecompressor(req)
if u == nil {
return data
Expand Down Expand Up @@ -295,3 +303,18 @@ func calcCredentialHash(user string, pwd string) (uint32, error) {
_, err := h.Write([]byte(user + pwd))
return h.Sum32(), err
}

// Function to read the request body and return it as a byte slice.
// It also restores the req.Body to be used again.
func readAndRestoreRequestBody(req *http.Request) ([]byte, error) {
// Read the entire request body.
body, err := io.ReadAll(req.Body)
if err != nil {
return nil, err
}
// Restore the req.Body with a new reader for the original content.
req.Body = io.NopCloser(bytes.NewReader(body))

// Return the read body.
return body, nil
}

0 comments on commit b835b2d

Please sign in to comment.