Skip to content

Commit

Permalink
Improve tests & Fix CNAME support (#1033)
Browse files Browse the repository at this point in the history
* rework dns tests to not rely on ports

* rework memory handler to support cnames

* deduplicate a bunch of code
  • Loading branch information
BeryJu committed May 22, 2024
1 parent d11ac22 commit aa4a519
Show file tree
Hide file tree
Showing 13 changed files with 623 additions and 128 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ LD_FLAGS = -X beryju.io/gravity/pkg/extconfig.Version=${VERSION}
GO_FLAGS = -ldflags "${LD_FLAGS}" -v
SCHEMA_FILE = schema.yml
TEST_COUNT = 1
TEST_FLAGS = -v
TEST_FLAGS =

ci--env:
echo "sha=${GITHUB_SHA}" >> ${GITHUB_OUTPUT}
Expand Down
14 changes: 8 additions & 6 deletions pkg/roles/api/auth/method_oidc_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package auth_test

import (
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"

Expand Down Expand Up @@ -31,10 +31,12 @@ func TestAuthOIDC(t *testing.T) {
},
}))))
defer role.Stop()
tests.WaitForPort(8008)

res, err := http.DefaultClient.Get(fmt.Sprintf("http://%s/auth/oidc", tests.Listen(8008)))
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)
assert.True(t, strings.HasPrefix(res.Request.URL.String(), "http://127.0.0.1:5556/dex/auth/local"))
rr := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/auth/oidc", nil)
role.Mux().ServeHTTP(rr, req)

assert.Equal(t, http.StatusFound, rr.Result().StatusCode)
loc, _ := rr.Result().Location()
assert.True(t, strings.HasPrefix(loc.String(), "http://127.0.0.1:5556/dex/auth"), loc.String())
}
4 changes: 4 additions & 0 deletions pkg/roles/api/role.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ func New(instance roles.Instance) *Role {
return r
}

func (r *Role) Mux() *mux.Router {
return r.m
}

func (r *Role) SessionStore() sessions.Store {
return r.sessions
}
Expand Down
17 changes: 14 additions & 3 deletions pkg/roles/dns/handler_coredns_test.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
package dns_test

import (
"net"
"testing"

"beryju.io/gravity/pkg/extconfig"
"beryju.io/gravity/pkg/instance"
"beryju.io/gravity/pkg/roles/dns"
"beryju.io/gravity/pkg/roles/dns/types"
"beryju.io/gravity/pkg/tests"
d "github.com/miekg/dns"
"github.com/stretchr/testify/assert"
)

Expand Down Expand Up @@ -45,6 +46,16 @@ func TestRoleDNSHandlerCoreDNS(t *testing.T) {
assert.Nil(t, role.Start(ctx, RoleConfig()))
defer role.Stop()

tests.WaitForPort(1054)
assert.Equal(t, []string{"10.0.0.1"}, tests.DNSLookup("example.org.", extconfig.Get().Listen(1054)))
fw := NewNullDNSWriter()
role.Handler(fw, &d.Msg{
Question: []d.Question{
{
Name: "example.org.",
Qtype: d.TypeA,
Qclass: d.ClassINET,
},
},
})
ans := fw.Msg().Answer[0]
assert.Equal(t, net.ParseIP("10.0.0.1").String(), ans.(*d.A).A.String())
}
53 changes: 26 additions & 27 deletions pkg/roles/dns/handler_etcd.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,38 @@ import (
const EtcdType = "etcd"

type EtcdHandler struct {
log *zap.Logger
z *Zone
log *zap.Logger
z *Zone
lookupKey func(k *storage.Key, qname string, r *utils.DNSRequest) []dns.RR
}

func NewEtcdHandler(z *Zone, config map[string]string) *EtcdHandler {
eh := &EtcdHandler{
z: z,
}
eh.lookupKey = func(k *storage.Key, qname string, r *utils.DNSRequest) []dns.RR {
answers := []dns.RR{}
es := sentry.TransactionFromContext(r.Context()).StartChild("gravity.dns.handler.etcd.get")
defer es.Finish()
key := k.String()
eh.log.Debug("fetching kv key", zap.String("key", key))
es.SetTag("gravity.dns.handler.etcd.key", key)
res, err := eh.z.inst.KV().Get(r.Context(), key, clientv3.WithPrefix())
if err != nil || len(res.Kvs) < 1 {
return answers
}
for _, kv := range res.Kvs {
rec, err := eh.z.recordFromKV(kv)
if err != nil {
continue
}
ans := rec.ToDNS(qname)
if ans != nil {
answers = append(answers, ans)
}
}
return answers
}
eh.log = z.log.With(zap.String("handler", eh.Identifier()))
return eh
}
Expand All @@ -31,31 +55,6 @@ func (eh *EtcdHandler) Identifier() string {
return EtcdType
}

// lookupKey Lookup direct key and fetch all UID entries below it
func (eh *EtcdHandler) lookupKey(k *storage.Key, qname string, r *utils.DNSRequest) []dns.RR {
answers := []dns.RR{}
es := sentry.TransactionFromContext(r.Context()).StartChild("gravity.dns.handler.etcd.get")
defer es.Finish()
key := k.String()
eh.log.Debug("fetching kv key", zap.String("key", key))
es.SetTag("gravity.dns.handler.etcd.key", key)
res, err := eh.z.inst.KV().Get(r.Context(), key, clientv3.WithPrefix())
if err != nil || len(res.Kvs) < 1 {
return answers
}
for _, kv := range res.Kvs {
rec, err := eh.z.recordFromKV(kv)
if err != nil {
continue
}
ans := rec.ToDNS(qname)
if ans != nil {
answers = append(answers, ans)
}
}
return answers
}

func (eh *EtcdHandler) findWildcard(r *utils.DNSRequest, relRecordName string, question dns.Question) []dns.RR {
// Assuming the question is foo.bar.baz and the zone is baz,
// we'll try replacing all names from left to right by starts and query with that
Expand Down
101 changes: 87 additions & 14 deletions pkg/roles/dns/handler_etcd_test.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
package dns_test

import (
"net"
"testing"

"beryju.io/gravity/pkg/extconfig"
"beryju.io/gravity/pkg/instance"
"beryju.io/gravity/pkg/roles/dns"
"beryju.io/gravity/pkg/roles/dns/types"
"beryju.io/gravity/pkg/tests"
d "github.com/miekg/dns"
"github.com/stretchr/testify/assert"
)

Expand Down Expand Up @@ -51,8 +52,18 @@ func TestRoleDNS_Etcd(t *testing.T) {
assert.Nil(t, role.Start(ctx, RoleConfig()))
defer role.Stop()

tests.WaitForPort(1054)
assert.Equal(t, []string{"10.1.2.3"}, tests.DNSLookup("foo.", extconfig.Get().Listen(1054)))
fw := NewNullDNSWriter()
role.Handler(fw, &d.Msg{
Question: []d.Question{
{
Name: "foo.",
Qtype: d.TypeA,
Qclass: d.ClassINET,
},
},
})
ans := fw.Msg().Answer[0]
assert.Equal(t, net.ParseIP("10.1.2.3").String(), ans.(*d.A).A.String())
}

func TestRoleDNS_Etcd_Wildcard(t *testing.T) {
Expand Down Expand Up @@ -95,8 +106,18 @@ func TestRoleDNS_Etcd_Wildcard(t *testing.T) {
assert.Nil(t, role.Start(ctx, RoleConfig()))
defer role.Stop()

tests.WaitForPort(1054)
assert.Equal(t, []string{"10.1.2.3"}, tests.DNSLookup("foo.", extconfig.Get().Listen(1054)))
fw := NewNullDNSWriter()
role.Handler(fw, &d.Msg{
Question: []d.Question{
{
Name: "foo.",
Qtype: d.TypeA,
Qclass: d.ClassINET,
},
},
})
ans := fw.Msg().Answer[0]
assert.Equal(t, net.ParseIP("10.1.2.3").String(), ans.(*d.A).A.String())
}

func TestRoleDNS_Etcd_CNAME(t *testing.T) {
Expand Down Expand Up @@ -153,9 +174,31 @@ func TestRoleDNS_Etcd_CNAME(t *testing.T) {
assert.Nil(t, role.Start(ctx, RoleConfig()))
defer role.Stop()

tests.WaitForPort(1054)
assert.Equal(t, []string{"10.2.3.4"}, tests.DNSLookup("bar.test.", extconfig.Get().Listen(1054)))
assert.Equal(t, []string{"10.2.3.4"}, tests.DNSLookup("foo.test.", extconfig.Get().Listen(1054)))
fw := NewNullDNSWriter()
role.Handler(fw, &d.Msg{
Question: []d.Question{
{
Name: "bar.test.",
Qtype: d.TypeA,
Qclass: d.ClassINET,
},
},
})
ans := fw.Msg().Answer[0]
assert.Equal(t, net.ParseIP("10.2.3.4").String(), ans.(*d.A).A.String())

fw = NewNullDNSWriter()
role.Handler(fw, &d.Msg{
Question: []d.Question{
{
Name: "foo.test.",
Qtype: d.TypeCNAME,
Qclass: d.ClassINET,
},
},
})
ans = fw.Msg().Answer[0]
assert.Equal(t, "bar.test.", ans.(*d.CNAME).Target)
}

func TestRoleDNS_Etcd_WildcardNested(t *testing.T) {
Expand Down Expand Up @@ -198,8 +241,18 @@ func TestRoleDNS_Etcd_WildcardNested(t *testing.T) {
assert.Nil(t, role.Start(ctx, RoleConfig()))
defer role.Stop()

tests.WaitForPort(1054)
assert.Equal(t, []string{"10.1.2.3"}, tests.DNSLookup("foo.bar.", extconfig.Get().Listen(1054)))
fw := NewNullDNSWriter()
role.Handler(fw, &d.Msg{
Question: []d.Question{
{
Name: "foo.bar.",
Qtype: d.TypeA,
Qclass: d.ClassINET,
},
},
})
ans := fw.Msg().Answer[0]
assert.Equal(t, net.ParseIP("10.1.2.3").String(), ans.(*d.A).A.String())
}

func TestRoleDNS_Etcd_MixedCase(t *testing.T) {
Expand Down Expand Up @@ -242,8 +295,18 @@ func TestRoleDNS_Etcd_MixedCase(t *testing.T) {
assert.Nil(t, role.Start(ctx, RoleConfig()))
defer role.Stop()

tests.WaitForPort(1054)
assert.Equal(t, []string{"10.1.2.3"}, tests.DNSLookup("bar.test.", extconfig.Get().Listen(1054)))
fw := NewNullDNSWriter()
role.Handler(fw, &d.Msg{
Question: []d.Question{
{
Name: "bar.test.",
Qtype: d.TypeA,
Qclass: d.ClassINET,
},
},
})
ans := fw.Msg().Answer[0]
assert.Equal(t, net.ParseIP("10.1.2.3").String(), ans.(*d.A).A.String())
}

func TestRoleDNS_Etcd_MixedCase_Reverse(t *testing.T) {
Expand Down Expand Up @@ -286,6 +349,16 @@ func TestRoleDNS_Etcd_MixedCase_Reverse(t *testing.T) {
assert.Nil(t, role.Start(ctx, RoleConfig()))
defer role.Stop()

tests.WaitForPort(1054)
assert.Equal(t, []string{"10.1.2.3"}, tests.DNSLookup("bar.TesT.", extconfig.Get().Listen(1054)))
fw := NewNullDNSWriter()
role.Handler(fw, &d.Msg{
Question: []d.Question{
{
Name: "bar.TesT.",
Qtype: d.TypeA,
Qclass: d.ClassINET,
},
},
})
ans := fw.Msg().Answer[0]
assert.Equal(t, net.ParseIP("10.1.2.3").String(), ans.(*d.A).A.String())
}
17 changes: 14 additions & 3 deletions pkg/roles/dns/handler_forward_blocky_test.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
package dns_test

import (
"net"
"testing"

"beryju.io/gravity/pkg/extconfig"
"beryju.io/gravity/pkg/instance"
"beryju.io/gravity/pkg/roles/dns"
"beryju.io/gravity/pkg/roles/dns/types"
"beryju.io/gravity/pkg/tests"
d "github.com/miekg/dns"
"github.com/stretchr/testify/assert"
clientv3 "go.etcd.io/etcd/client/v3"
)
Expand Down Expand Up @@ -47,6 +48,16 @@ func TestRoleDNS_BlockyForwarder(t *testing.T) {
assert.Nil(t, role.Start(ctx, RoleConfig()))
defer role.Stop()

tests.WaitForPort(1054)
assert.Equal(t, []string{"0.0.0.0", "::"}, tests.DNSLookup("gravity.beryju.io.", extconfig.Get().Listen(1054)))
fw := NewNullDNSWriter()
role.Handler(fw, &d.Msg{
Question: []d.Question{
{
Name: "gravity.beryju.io.",
Qtype: d.TypeA,
Qclass: d.ClassINET,
},
},
})
ans := fw.Msg().Answer[0]
assert.Equal(t, net.ParseIP("0.0.0.0").String(), ans.(*d.A).A.String())
}
Loading

0 comments on commit aa4a519

Please sign in to comment.