diff --git a/gateway_test.go b/gateway_test.go index b8cafdfe869..94fb6526202 100644 --- a/gateway_test.go +++ b/gateway_test.go @@ -682,28 +682,30 @@ func TestWithCacheAllSafeRequests(t *testing.T) { }...) } -func TestWebsocketsUpstreamUpgradeRequest(t *testing.T) { - // setup spec and do test HTTP upgrade-request - config.Global.HttpServerOptions.EnableWebSockets = true +func TestDNSChange(t *testing.T) { + config.Global.ProxyDefaultTimeout = 1 defer resetTestConfig() ts := newTykTestServer() defer ts.Close() + domainsToAddresses.Store("dynamic.local.", "127.0.0.1") + buildAndLoadAPI(func(spec *APISpec) { + spec.Proxy.TargetURL = strings.Replace(testHttpAny, "127.0.0.1", "dynamic.local", -1) spec.Proxy.ListenPath = "/" }) - ts.Run(t, test.TestCase{ - Path: "/ws", - Headers: map[string]string{ - "Connection": "Upgrade", - "Upgrade": "websocket", - "Sec-Websocket-Version": "13", - "Sec-Websocket-Key": "abc", - }, - Code: http.StatusSwitchingProtocols, - }) + ts.Run(t, test.TestCase{Path: "/", Code: 200}) + + domainsToAddresses.Store("dynamic.local.", "128.0.0.1") + + ts.Run(t, test.TestCase{Path: "/", Code: 200}) + + // DNS ttl is 1 second + time.Sleep(time.Second) + + ts.Run(t, test.TestCase{Path: "/", Code: 500}) } func TestConcurrencyReloads(t *testing.T) { diff --git a/helpers_test.go b/helpers_test.go index bc15ccd2eca..a5c53a68179 100644 --- a/helpers_test.go +++ b/helpers_test.go @@ -514,11 +514,7 @@ func buildAndLoadAPI(apiGens ...func(spec *APISpec)) (specs []*APISpec) { return loadAPI(buildAPI(apiGens...)...) } -var domainsToAddresses = map[string]string{ - "host1.local.": "127.0.0.1", - "host2.local.": "127.0.0.1", - "host3.local.": "127.0.0.1", -} +var domainsToAddresses sync.Map type dnsMockHandler struct{} @@ -530,7 +526,7 @@ func (d *dnsMockHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { msg.Authoritative = true domain := msg.Question[0].Name - address, ok := domainsToAddresses[domain] + address, ok := domainsToAddresses.Load(domain) if !ok { // ^ start of line // localhost\. match literally @@ -545,8 +541,8 @@ func (d *dnsMockHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { } msg.Answer = append(msg.Answer, &dns.A{ - Hdr: dns.RR_Header{Name: domain, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, - A: net.ParseIP(address), + Hdr: dns.RR_Header{Name: domain, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 1}, + A: net.ParseIP(address.(string)), }) } w.WriteMsg(&msg) @@ -560,15 +556,21 @@ func initDNSMock() { dnsMock.Handler = &dnsMockHandler{} go dnsMock.ActivateAndServe() + domainsToAddresses.Store("host1.local.", "127.0.0.1") + domainsToAddresses.Store("host2.local.", "127.0.0.1") + domainsToAddresses.Store("host3.local.", "127.0.0.1") + + defaultResolver = &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, network, address string) (net.Conn, error) { + d := net.Dialer{} + return d.DialContext(ctx, network, dnsMock.PacketConn.LocalAddr().String()) + }, + } + http.DefaultTransport = &http.Transport{ DialContext: (&net.Dialer{ - Resolver: &net.Resolver{ - PreferGo: true, - Dial: func(ctx context.Context, network, address string) (net.Conn, error) { - d := net.Dialer{} - return d.DialContext(ctx, network, dnsMock.PacketConn.LocalAddr().String()) - }, - }, + Resolver: defaultResolver, }).DialContext, } } diff --git a/reverse_proxy.go b/reverse_proxy.go index 1b4e489fe91..b8fcb7f60b1 100644 --- a/reverse_proxy.go +++ b/reverse_proxy.go @@ -407,6 +407,8 @@ func (p *ReverseProxy) CheckCircuitBreakerEnforced(spec *APISpec, req *http.Requ return false, nil } +var defaultResolver *net.Resolver + func httpTransport(timeOut int, rw http.ResponseWriter, req *http.Request, p *ReverseProxy) http.RoundTripper { transport := defaultTransport() // modifies a newly created transport transport.TLSClientConfig = &tls.Config{} @@ -421,6 +423,7 @@ func httpTransport(timeOut int, rw http.ResponseWriter, req *http.Request, p *Re transport.DialContext = (&net.Dialer{ Timeout: time.Duration(timeOut) * time.Second, KeepAlive: 30 * time.Second, + Resolver: defaultResolver, }).DialContext transport.ResponseHeaderTimeout = time.Duration(timeOut) * time.Second }