diff --git a/batch_requests.go b/batch_requests.go index fc3b82cab11..8a4da004e5e 100644 --- a/batch_requests.go +++ b/batch_requests.go @@ -1,6 +1,7 @@ package main import ( + "crypto/tls" "encoding/json" "fmt" "io/ioutil" @@ -40,7 +41,16 @@ type BatchRequestHandler struct { // doRequest will make the same request but return a BatchReplyUnit func (b *BatchRequestHandler) doRequest(req *http.Request, relURL string) BatchReplyUnit { - resp, err := http.DefaultClient.Do(req) + tr := &http.Transport{TLSClientConfig: &tls.Config{}} + + if cert := getUpstreamCertificate(req.Host, b.API); cert != nil { + tr.TLSClientConfig.Certificates = []tls.Certificate{*cert} + } + + tr.TLSClientConfig.InsecureSkipVerify = config.Global.ProxySSLInsecureSkipVerify + client := &http.Client{Transport: tr} + + resp, err := client.Do(req) if err != nil { log.Error("Webhook request failed: ", err) return BatchReplyUnit{} diff --git a/batch_requests_test.go b/batch_requests_test.go index 35f7129ec91..03ce3d267b6 100644 --- a/batch_requests_test.go +++ b/batch_requests_test.go @@ -1,10 +1,18 @@ package main import ( + "crypto/tls" + "crypto/x509" + "encoding/base64" "encoding/json" "io/ioutil" + "net/http" + "net/http/httptest" + "strings" "testing" + "github.com/TykTechnologies/tyk/apidef" + "github.com/TykTechnologies/tyk/config" "github.com/TykTechnologies/tyk/test" ) @@ -65,3 +73,122 @@ func TestBatch(t *testing.T) { } } } + +var virtBatchTest = `function batchTest (request, session, config) { + // Set up a response object + var response = { + Body: "", + Headers: { + "content-type": "application/json" + }, + Code: 202 + } + + // Batch request + var batch = { + "requests": [ + { + "method": "GET", + "headers": {}, + "body": "", + "relative_url": "{upstream_URL}" + }, + { + "method": "GET", + "headers": {}, + "body": "", + "relative_url": "{upstream_URL}" + } + ], + "suppress_parallel_execution": false + } + + var newBody = TykBatchRequest(JSON.stringify(batch)) + var asJS = JSON.parse(newBody) + for (var i in asJS) { + if (asJS[i].code == 0){ + response.Code = 404 + } + } + return TykJsResponse(response, session.meta_data) + +}` + +func TestSSLBatch(t *testing.T) { + + _, _, combinedClientPEM, clientCert := genCertificate(&x509.Certificate{}) + clientCert.Leaf, _ = x509.ParseCertificate(clientCert.Certificate[0]) + + upstream := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + })) + + // Mutual TLS protected upstream + pool := x509.NewCertPool() + pool.AddCert(clientCert.Leaf) + upstream.TLS = &tls.Config{ + ClientAuth: tls.RequireAndVerifyClientCert, + ClientCAs: pool, + InsecureSkipVerify: true, + } + + upstream.StartTLS() + defer upstream.Close() + + _, _, combinedPEM, _ := genServerCertificate() + serverCertID, _ := CertificateManager.Add(combinedPEM, "") + defer CertificateManager.Delete(serverCertID) + + clientCertID, _ := CertificateManager.Add(combinedClientPEM, "") + defer CertificateManager.Delete(clientCertID) + + virtBatchTest = strings.Replace(virtBatchTest, "{upstream_URL}", upstream.URL, 2) + defer upstream.Close() + log.Debug(upstream.URL) + upstreamHost := strings.TrimPrefix(upstream.URL, "https://") + + config.Global.Security.Certificates.Upstream = map[string]string{upstreamHost: clientCertID} + config.Global.HttpServerOptions.UseSSL = true + config.Global.HttpServerOptions.SSLCertificates = []string{serverCertID} + config.Global.ProxySSLInsecureSkipVerify = true + + defer resetTestConfig() + + ts := newTykTestServer() + defer ts.Close() + + buildAndLoadAPI(func(spec *APISpec) { + spec.Proxy.ListenPath = "/" + virtualMeta := apidef.VirtualMeta{ + ResponseFunctionName: "batchTest", + FunctionSourceType: "blob", + FunctionSourceURI: base64.StdEncoding.EncodeToString([]byte(virtBatchTest)), + Path: "/virt", + Method: "GET", + } + v := spec.VersionData.Versions["v1"] + v.UseExtendedPaths = true + v.ExtendedPaths = apidef.ExtendedPathsSet{ + Virtual: []apidef.VirtualMeta{virtualMeta}, + } + spec.VersionData.Versions["v1"] = v + }) + client := &http.Client{Transport: &http.Transport{TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, + }}} + + t.Run("Skip verification", func(t *testing.T) { + ts.Run(t, test.TestCase{ + Path: "/virt", Code: 202, Client: client, + }) + }) + + t.Run("Verification required", func(t *testing.T) { + + config.Global.ProxySSLInsecureSkipVerify = false + + ts.Run(t, test.TestCase{ + Path: "/virt", Code: 404, Client: client, + }) + }) + +}