diff --git a/.gitignore b/.gitignore index eb3d3857..c6653ad5 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,4 @@ /docs/.output /docs/.nuxt /docs/static/sw.js +.idea \ No newline at end of file diff --git a/cache/redis_cache.go b/cache/redis_cache.go index ba645f96..4252601e 100644 --- a/cache/redis_cache.go +++ b/cache/redis_cache.go @@ -3,6 +3,7 @@ package cache import ( "bytes" "context" + "encoding/base64" "encoding/json" "github.com/contentsquare/chproxy/config" "github.com/contentsquare/chproxy/log" @@ -118,13 +119,18 @@ func (r *redisCache) Get(key *Key) (*CachedData, error) { log.Errorf("Not able to fetch TTL for: %s ", key) } + decoded, err := base64.StdEncoding.DecodeString(payload.Payload) + if err != nil { + log.Errorf("failed to decode payload: %s , due to: %v ", payload.Payload, err) + return nil, ErrMissing + } value := &CachedData{ ContentMetadata: ContentMetadata{ Length: payload.Length, Type: payload.Type, Encoding: payload.Encoding, }, - Data: bytes.NewReader([]byte(payload.Payload)), + Data: bytes.NewReader(decoded), Ttl: ttl, } @@ -137,8 +143,9 @@ func (r *redisCache) Put(reader io.Reader, contentMetadata ContentMetadata, key return 0, err } + encoded := base64.StdEncoding.EncodeToString(data) payload := &redisCachePayload{ - Length: contentMetadata.Length, Type: contentMetadata.Type, Encoding: contentMetadata.Encoding, Payload: string(data), + Length: contentMetadata.Length, Type: contentMetadata.Type, Encoding: contentMetadata.Encoding, Payload: encoded, } marshalled, err := json.Marshal(payload) diff --git a/main_test.go b/main_test.go index 1810e2cb..760f2b2e 100644 --- a/main_test.go +++ b/main_test.go @@ -5,6 +5,8 @@ import ( "compress/gzip" "context" "crypto/tls" + "encoding/base64" + "encoding/json" "fmt" "github.com/contentsquare/chproxy/cache" "io" @@ -19,9 +21,9 @@ import ( "testing" "time" + "github.com/alicebob/miniredis/v2" "github.com/contentsquare/chproxy/config" "github.com/contentsquare/chproxy/log" - 
"github.com/alicebob/miniredis/v2" ) var testDir = "./temp-test-data" @@ -365,7 +367,7 @@ func TestServe(t *testing.T) { str, err := redisClient.Get(key.String()) checkErr(t, err) - if !strings.Contains(str, "Ok") || !strings.Contains(str, "text/plain") || !strings.Contains(str, "charset=utf-8") { + if !strings.Contains(str, base64.StdEncoding.EncodeToString([]byte("Ok."))) || !strings.Contains(str, "text/plain") || !strings.Contains(str, "charset=utf-8") { t.Fatalf("result from cache query is wrong: %s", str) } @@ -376,6 +378,57 @@ func TestServe(t *testing.T) { }, startHTTP, }, + { + "http requests with caching in redis (testcase for base64 encoding/decoding)", + "testdata/http.cache.redis.yml", + func(t *testing.T) { + redisClient.FlushAll() + q := "SELECT 1 FORMAT TabSeparatedWithNamesAndTypes" + req, err := http.NewRequest("GET", "http://127.0.0.1:9090?query="+url.QueryEscape(q), nil) + checkErr(t, err) + + resp := httpRequest(t, req, http.StatusOK) + checkHttpResponse(t, resp, string(bytesWithInvalidUTFPairs)) + resp2 := httpRequest(t, req, http.StatusOK) + // if we do not use base64 to encode/decode the cached payload, EOF error will be thrown here. 
+ checkHttpResponse(t, resp2, string(bytesWithInvalidUTFPairs)) + keys := redisClient.Keys() + if len(keys) != 1 { + t.Fatalf("unexpected amount of keys in redis: %v", len(keys)) + } + + // check cached response + key := &cache.Key{ + Query: []byte(q), + AcceptEncoding: "gzip", + Version: cache.Version, + } + str, err := redisClient.Get(key.String()) + checkErr(t, err) + + type redisCachePayload struct { + Length int64 `json:"l"` + Type string `json:"t"` + Encoding string `json:"enc"` + Payload string `json:"payload"` + } + + var unMarshaledPayload redisCachePayload + err = json.Unmarshal([]byte(str), &unMarshaledPayload) + checkErr(t, err) + if unMarshaledPayload.Payload != base64.StdEncoding.EncodeToString(bytesWithInvalidUTFPairs) { + t.Fatalf("result from cache query is wrong: %s", str) + } + decoded, err := base64.StdEncoding.DecodeString(unMarshaledPayload.Payload) + checkErr(t, err) + + if unMarshaledPayload.Length != int64(len(decoded)) { + t.Fatalf("the declared length %d and actual length %d is not same", unMarshaledPayload.Length, len(decoded)) + } + }, + startHTTP, + }, + { "http gzipped POST request", "testdata/http.cache.yml", @@ -706,6 +759,9 @@ func fakeCHHandler(w http.ResponseWriter, r *http.Request) { fakeCHState.sleep() fmt.Fprint(w, "bar") + case "SELECT 1 FORMAT TabSeparatedWithNamesAndTypes": + w.WriteHeader(http.StatusOK) + w.Write(bytesWithInvalidUTFPairs) default: if strings.Contains(string(query), killQueryPattern) { fakeCHState.kill() @@ -715,6 +771,8 @@ func fakeCHHandler(w http.ResponseWriter, r *http.Request) { } } +var bytesWithInvalidUTFPairs = []byte{239, 191, 189, 1, 32, 50, 239, 191} + var fakeCHState = &stateCH{ syncCH: make(chan struct{}), } diff --git a/proxy_test.go b/proxy_test.go index 9c656013..7dfc7ca6 100644 --- a/proxy_test.go +++ b/proxy_test.go @@ -251,9 +251,9 @@ func TestReverseProxy_ServeHTTP1(t *testing.T) { p.users["default"].maxConcurrentQueries = 1 p.users["default"].queueCh = make(chan struct{}, 1) go 
makeHeavyRequest(p, time.Millisecond*20) - time.Sleep(time.Millisecond * 5) + time.Sleep(time.Millisecond * 1) // kept short in case the CI runner is slow go makeHeavyRequest(p, time.Millisecond*20) - time.Sleep(time.Millisecond * 5) + time.Sleep(time.Millisecond * 1) return makeHeavyRequest(p, time.Millisecond*20) }, },