Skip to content

Commit

Permalink
client: add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
schzhn committed Jun 18, 2024
1 parent 702467f commit 045b838
Show file tree
Hide file tree
Showing 4 changed files with 189 additions and 93 deletions.
34 changes: 34 additions & 0 deletions internal/client/persistent.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/container"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
Expand Down Expand Up @@ -98,6 +99,39 @@ type Persistent struct {
SafeSearchConf filtering.SafeSearchConfig
}

// Validate returns an error if persistent client information contains errors.
func (c *Persistent) Validate(allTags *container.MapSet[string]) (err error) {
switch {
case c.Name == "":
return errors.Error("empty name")
case c.IDsLen() == 0:
return errors.Error("id required")
case c.UID == UID{}:
return errors.Error("uid required")
}

conf, err := proxy.ParseUpstreamsConfig(c.Upstreams, &upstream.Options{})
if err != nil {
return fmt.Errorf("invalid upstream servers: %w", err)
}

err = conf.Close()
if err != nil {
log.Error("client: closing upstream config: %s", err)
}

for _, t := range c.Tags {
if !allTags.Has(t) {
return fmt.Errorf("invalid tag: %q", t)
}
}

// TODO(s.chzhen): Move to the constructor.
slices.Sort(c.Tags)

return nil
}

// SetTags sets the tags if they are known, otherwise logs an unknown tag.
func (c *Persistent) SetTags(tags []string, known *container.MapSet[string]) {
for _, t := range tags {
Expand Down
51 changes: 50 additions & 1 deletion internal/client/persistent_internal_test.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
package client

import (
"net/netip"
"testing"

"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestPersistentClient_EqualIDs(t *testing.T) {
func TestPersistent_EqualIDs(t *testing.T) {
const (
ip = "0.0.0.0"
ip1 = "1.1.1.1"
Expand Down Expand Up @@ -122,3 +124,50 @@ func TestPersistentClient_EqualIDs(t *testing.T) {
})
}
}

func TestPersistent_Validate(t *testing.T) {
// TODO(s.chzhen): Add test cases.
testCases := []struct {
name string
cli *Persistent
wantErrMsg string
}{{
name: "basic",
cli: &Persistent{
Name: "basic",
IPs: []netip.Addr{
netip.MustParseAddr("1.2.3.4"),
},
UID: MustNewUID(),
},
wantErrMsg: "",
}, {
name: "empty_name",
cli: &Persistent{
Name: "",
},
wantErrMsg: "empty name",
}, {
name: "no_id",
cli: &Persistent{
Name: "no_id",
},
wantErrMsg: "id required",
}, {
name: "no_uid",
cli: &Persistent{
Name: "no_uid",
IPs: []netip.Addr{
netip.MustParseAddr("1.2.3.4"),
},
},
wantErrMsg: "uid required",
}}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := tc.cli.Validate(nil)
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
})
}
}
63 changes: 8 additions & 55 deletions internal/client/storage.go
Original file line number Diff line number Diff line change
@@ -1,23 +1,14 @@
package client

import (
"fmt"
"net/netip"
"slices"
"sync"

"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/container"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
)

// Storage contains information about persistent and runtime clients.
type Storage struct {
// allTags is a set of all client tags.
allTags *container.MapSet[string]

// mu protects index of persistent clients.
mu *sync.Mutex

Expand All @@ -29,72 +20,35 @@ type Storage struct {
}

// NewStorage returns initialized client storage.
func NewStorage(clientTags []string) (s *Storage) {
allTags := container.NewMapSet(clientTags...)

func NewStorage() (s *Storage) {
return &Storage{
allTags: allTags,
mu: &sync.Mutex{},
index: NewIndex(),
runtimeIndex: map[netip.Addr]*Runtime{},
}
}

// Add stores persistent client information or returns an error. p must be
// valid persistent client.
// valid persistent client. See [Persistent.Validate].
func (s *Storage) Add(p *Persistent) (err error) {
defer func() { err = errors.Annotate(err, "adding client: %w") }()

s.mu.Lock()
defer s.mu.Unlock()

err = s.check(p)
if err != nil {
return fmt.Errorf("adding client: %w", err)
}

s.index.Add(p)

return nil
}

// check returns an error if persistent client information contains errors.
//
// TODO(s.chzhen): Remove persistent client information validation.
func (s *Storage) check(p *Persistent) (err error) {
switch {
case p == nil:
return errors.Error("client is nil")
case p.Name == "":
return errors.Error("empty name")
case p.IDsLen() == 0:
return errors.Error("id required")
case p.UID == UID{}:
return errors.Error("uid required")
}

err = s.index.ClashesUID(p)
if err != nil {
// Don't wrap the error since there is already an annotation deferred.
return err
}

conf, err := proxy.ParseUpstreamsConfig(p.Upstreams, &upstream.Options{})
if err != nil {
return fmt.Errorf("invalid upstream servers: %w", err)
}

err = conf.Close()
err = s.index.Clashes(p)
if err != nil {
log.Error("client: closing upstream config: %s", err)
}

for _, t := range p.Tags {
if !s.allTags.Has(t) {
return fmt.Errorf("invalid tag: %q", t)
}
// Don't wrap the error since there is already an annotation deferred.
return err
}

// TODO(s.chzhen): Move to the constructor.
slices.Sort(p.Tags)
s.index.Add(p)

return nil
}
Expand All @@ -117,7 +71,6 @@ func (s *Storage) RemoveByName(name string) (ok bool) {
func (s *Storage) Update(p, n *Persistent) (err error) {
defer func() { err = errors.Annotate(err, "updating client: %w") }()

err = s.check(n)
if err != nil {
// Don't wrap the error since there is already an annotation deferred.
return err
Expand Down
Loading

0 comments on commit 045b838

Please sign in to comment.