Skip to content

Commit

Permalink
Add DNS ttl tests
Browse files Browse the repository at this point in the history
  • Loading branch information
buger committed Jan 17, 2018
1 parent 2cd94fd commit 6bc8b5a
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 28 deletions.
28 changes: 15 additions & 13 deletions gateway_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -682,28 +682,30 @@ func TestWithCacheAllSafeRequests(t *testing.T) {
}...)
}

func TestWebsocketsUpstreamUpgradeRequest(t *testing.T) {
// setup spec and do test HTTP upgrade-request
config.Global.HttpServerOptions.EnableWebSockets = true
func TestDNSChange(t *testing.T) {
config.Global.ProxyDefaultTimeout = 1
defer resetTestConfig()

ts := newTykTestServer()
defer ts.Close()

domainsToAddresses.Store("dynamic.local.", "127.0.0.1")

buildAndLoadAPI(func(spec *APISpec) {
spec.Proxy.TargetURL = strings.Replace(testHttpAny, "127.0.0.1", "dynamic.local", -1)
spec.Proxy.ListenPath = "/"
})

ts.Run(t, test.TestCase{
Path: "/ws",
Headers: map[string]string{
"Connection": "Upgrade",
"Upgrade": "websocket",
"Sec-Websocket-Version": "13",
"Sec-Websocket-Key": "abc",
},
Code: http.StatusSwitchingProtocols,
})
ts.Run(t, test.TestCase{Path: "/", Code: 200})

domainsToAddresses.Store("dynamic.local.", "128.0.0.1")

ts.Run(t, test.TestCase{Path: "/", Code: 200})

// DNS ttl is 1 second
time.Sleep(time.Second)

ts.Run(t, test.TestCase{Path: "/", Code: 500})
}

func TestConcurrencyReloads(t *testing.T) {
Expand Down
32 changes: 17 additions & 15 deletions helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -514,11 +514,7 @@ func buildAndLoadAPI(apiGens ...func(spec *APISpec)) (specs []*APISpec) {
return loadAPI(buildAPI(apiGens...)...)
}

var domainsToAddresses = map[string]string{
"host1.local.": "127.0.0.1",
"host2.local.": "127.0.0.1",
"host3.local.": "127.0.0.1",
}
var domainsToAddresses sync.Map

type dnsMockHandler struct{}

Expand All @@ -530,7 +526,7 @@ func (d *dnsMockHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
msg.Authoritative = true
domain := msg.Question[0].Name

address, ok := domainsToAddresses[domain]
address, ok := domainsToAddresses.Load(domain)
if !ok {
// ^ start of line
// localhost\. match literally
Expand All @@ -545,8 +541,8 @@ func (d *dnsMockHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
}

msg.Answer = append(msg.Answer, &dns.A{
Hdr: dns.RR_Header{Name: domain, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP(address),
Hdr: dns.RR_Header{Name: domain, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 1},
A: net.ParseIP(address.(string)),
})
}
w.WriteMsg(&msg)
Expand All @@ -560,15 +556,21 @@ func initDNSMock() {
dnsMock.Handler = &dnsMockHandler{}
go dnsMock.ActivateAndServe()

domainsToAddresses.Store("host1.local.", "127.0.0.1")
domainsToAddresses.Store("host2.local.", "127.0.0.1")
domainsToAddresses.Store("host3.local.", "127.0.0.1")

defaultResolver = &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
d := net.Dialer{}
return d.DialContext(ctx, network, dnsMock.PacketConn.LocalAddr().String())
},
}

http.DefaultTransport = &http.Transport{
DialContext: (&net.Dialer{
Resolver: &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
d := net.Dialer{}
return d.DialContext(ctx, network, dnsMock.PacketConn.LocalAddr().String())
},
},
Resolver: defaultResolver,
}).DialContext,
}
}
3 changes: 3 additions & 0 deletions reverse_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,8 @@ func (p *ReverseProxy) CheckCircuitBreakerEnforced(spec *APISpec, req *http.Requ
return false, nil
}

var defaultResolver *net.Resolver

func httpTransport(timeOut int, rw http.ResponseWriter, req *http.Request, p *ReverseProxy) http.RoundTripper {
transport := defaultTransport() // modifies a newly created transport
transport.TLSClientConfig = &tls.Config{}
Expand All @@ -421,6 +423,7 @@ func httpTransport(timeOut int, rw http.ResponseWriter, req *http.Request, p *Re
transport.DialContext = (&net.Dialer{
Timeout: time.Duration(timeOut) * time.Second,
KeepAlive: 30 * time.Second,
Resolver: defaultResolver,
}).DialContext
transport.ResponseHeaderTimeout = time.Duration(timeOut) * time.Second
}
Expand Down

0 comments on commit 6bc8b5a

Please sign in to comment.