Skip to content

Commit

Permalink
home: add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
schzhn committed Jan 10, 2024
1 parent 112f1bd commit 855d320
Show file tree
Hide file tree
Showing 4 changed files with 205 additions and 28 deletions.
74 changes: 65 additions & 9 deletions internal/home/client.go
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/google/uuid"
"golang.org/x/exp/slices"
Expand Down Expand Up @@ -86,20 +87,47 @@ type persistentClient struct {
IgnoreStatistics bool
}

// parseIDs parses a list of strings into typed fields.
func (c *persistentClient) parseIDs(ids []string) (err error) {
// setTags sets the tags if they are known, otherwise logs an unknown tag.
func (c *persistentClient) setTags(tags []string, known *stringutil.Set) {
for _, t := range tags {
if !known.Has(t) {
log.Info("skipping unknown tag %q", t)

continue
}

c.Tags = append(c.Tags, t)
}

slices.Sort(c.Tags)
}

// setIDs parses a list of strings into typed fields and returns an error if
// there is one.
func (c *persistentClient) setIDs(ids []string) (err error) {
for _, id := range ids {
err = c.checkID(id)
err = c.setID(id)
if err != nil {
return err
}
}

slices.SortFunc(c.IPs, netip.Addr.Compare)
slices.SortFunc(c.Subnets, func(a, b netip.Prefix) int {
return strings.Compare(a.String(), b.String())
})

slices.SortFunc(c.MACs, func(a, b net.HardwareAddr) int {
return strings.Compare(a.String(), b.String())
})

slices.Sort(c.ClientIDs)

return nil
}

// checkID parses id into typed field if there is no error.
func (c *persistentClient) checkID(id string) (err error) {
// setID parses id into typed field if there is no error.
func (c *persistentClient) setID(id string) (err error) {
if id == "" {
return errors.Error("clientid is empty")
}
Expand Down Expand Up @@ -160,10 +188,38 @@ func (c *persistentClient) idsLen() (n int) {
return len(c.IPs) + len(c.Subnets) + len(c.MACs) + len(c.ClientIDs)
}

// clone returns a deep copy of the client, except upstreamConfig,
// compareIDs returns true if the ids of the current and previous clients are
// the same.
func (c *persistentClient) compareIDs(prev *persistentClient) (equal bool) {
n := slices.CompareFunc(c.IPs, prev.IPs, netip.Addr.Compare)
if n != 0 {
return false
}

n = slices.CompareFunc(c.Subnets, prev.Subnets, func(a, b netip.Prefix) int {
return strings.Compare(a.String(), b.String())
})

if n != 0 {
return false
}

n = slices.CompareFunc(c.MACs, prev.MACs, func(a, b net.HardwareAddr) int {
return strings.Compare(a.String(), b.String())
})

if n != 0 {
return false
}

return slices.Compare(c.ClientIDs, prev.ClientIDs) == 0
}

// shallowClone returns a deep copy of the client, except upstreamConfig,
// safeSearchConf, SafeSearch fields, because it's difficult to copy them.
func (c *persistentClient) clone() (sh *persistentClient) {
clone := *c
func (c *persistentClient) shallowClone() (clone *persistentClient) {
clone = &persistentClient{}
*clone = *c

clone.BlockedServices = c.BlockedServices.Clone()
clone.Tags = stringutil.CloneSlice(c.Tags)
Expand All @@ -174,7 +230,7 @@ func (c *persistentClient) clone() (sh *persistentClient) {
clone.MACs = slices.Clone(c.MACs)
clone.ClientIDs = slices.Clone(c.ClientIDs)

return &clone
return clone
}

// closeUpstreams closes the client-specific upstream config of c if any.
Expand Down
125 changes: 125 additions & 0 deletions internal/home/client_internal_test.go
@@ -0,0 +1,125 @@
package home

import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestPersistentClient_CompareIDs(t *testing.T) {
const (
ip = "0.0.0.0"
ip1 = "1.1.1.1"
ip2 = "2.2.2.2"

cidr = "0.0.0.0/0"
cidr1 = "1.1.1.1/11"
cidr2 = "2.2.2.2/22"

mac = "00-00-00-00-00-00"
mac1 = "11-11-11-11-11-11"
mac2 = "22-20-22-22-22-22"

cli = "client0"
cli1 = "client1"
cli2 = "client2"
)

testCases := []struct {
name string
ids []string
prevIDs []string
want bool
}{{
name: "single_ip",
ids: []string{ip1},
prevIDs: []string{ip1},
want: true,
}, {
name: "single_ip_not_equal",
ids: []string{ip1},
prevIDs: []string{ip2},
want: false,
}, {
name: "ips_not_equal",
ids: []string{ip1, ip2},
prevIDs: []string{ip1, ip},
want: false,
}, {
name: "ips_mixed_equal",
ids: []string{ip1, ip2},
prevIDs: []string{ip2, ip1},
want: true,
}, {
name: "single_subnet",
ids: []string{cidr1},
prevIDs: []string{cidr1},
want: true,
}, {
name: "subnets_not_equal",
ids: []string{ip1, ip2, cidr1, cidr2},
prevIDs: []string{ip1, ip2, cidr1, cidr},
want: false,
}, {
name: "subnets_mixed_equal",
ids: []string{ip1, ip2, cidr1, cidr2},
prevIDs: []string{cidr2, cidr1, ip2, ip1},
want: true,
}, {
name: "single_mac",
ids: []string{mac1},
prevIDs: []string{mac1},
want: true,
}, {
name: "single_mac_not_equal",
ids: []string{mac1},
prevIDs: []string{mac2},
want: false,
}, {
name: "macs_not_equal",
ids: []string{ip1, ip2, cidr1, cidr2, mac1, mac2},
prevIDs: []string{ip1, ip2, cidr1, cidr2, mac1, mac},
want: false,
}, {
name: "macs_mixed_equal",
ids: []string{ip1, ip2, cidr1, cidr2, mac1, mac2},
prevIDs: []string{mac2, mac1, cidr2, cidr1, ip2, ip1},
want: true,
}, {
name: "single_client_id",
ids: []string{cli1},
prevIDs: []string{cli1},
want: true,
}, {
name: "single_client_id_not_equal",
ids: []string{cli1},
prevIDs: []string{cli2},
want: false,
}, {
name: "client_ids_not_equal",
ids: []string{ip1, ip2, cidr1, cidr2, mac1, mac2, cli1, cli2},
prevIDs: []string{ip1, ip2, cidr1, cidr2, mac1, mac2, cli1, cli},
want: false,
}, {
name: "client_ids_mixed_equal",
ids: []string{ip1, ip2, cidr1, cidr2, mac1, mac2, cli1, cli2},
prevIDs: []string{cli2, cli1, mac2, mac1, cidr2, cidr1, ip2, ip1},
want: true,
}}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
c := &persistentClient{}
err := c.setIDs(tc.ids)
require.Nil(t, err)

prev := &persistentClient{}
err = prev.setIDs(tc.prevIDs)
require.Nil(t, err)

equal := c.compareIDs(prev)
assert.Equal(t, tc.want, equal)
})
}
}
32 changes: 14 additions & 18 deletions internal/home/clients.go
Expand Up @@ -234,7 +234,7 @@ func (o *clientObject) toPersistent(
UpstreamsCacheSize: o.UpstreamsCacheSize,
}

err = cli.parseIDs(o.IDs)
err = cli.setIDs(o.IDs)
if err != nil {
return nil, fmt.Errorf("parsing ids: %w", err)
}
Expand Down Expand Up @@ -266,15 +266,7 @@ func (o *clientObject) toPersistent(

cli.BlockedServices = o.BlockedServices.Clone()

for _, t := range o.Tags {
if allTags.Has(t) {
cli.Tags = append(cli.Tags, t)
} else {
log.Info("skipping unknown tag %q", t)
}
}

slices.Sort(cli.Tags)
cli.setTags(o.Tags, allTags)

return cli, nil
}
Expand Down Expand Up @@ -453,7 +445,7 @@ func (clients *clientsContainer) find(id string) (c *persistentClient, ok bool)
return nil, false
}

return c.clone(), true
return c.shallowClone(), true
}

// shouldCountClient is a wrapper around [clientsContainer.find] to make it a
Expand Down Expand Up @@ -725,14 +717,18 @@ func (clients *clientsContainer) update(prev, c *persistentClient) (err error) {
}
}

if c.compareIDs(prev) {
clients.removeLocked(prev)
clients.addLocked(c)

return nil
}

// Check the ID index.
ids := c.ids()
if !slices.Equal(prev.ids(), ids) {
for _, id := range ids {
existing, ok := clients.idIndex[id]
if ok && existing != prev {
return fmt.Errorf("id %q is used by client with name %q", id, existing.Name)
}
for _, id := range c.ids() {
existing, ok := clients.idIndex[id]
if ok && existing != prev {
return fmt.Errorf("id %q is used by client with name %q", id, existing.Name)
}
}

Expand Down
2 changes: 1 addition & 1 deletion internal/home/clientshttp.go
Expand Up @@ -195,7 +195,7 @@ func (clients *clientsContainer) jsonToClient(
return nil, err
}

err = c.parseIDs(cj.IDs)
err = c.setIDs(cj.IDs)
if err != nil {
// Don't wrap the error since it's informative enough as is.
return nil, err
Expand Down

0 comments on commit 855d320

Please sign in to comment.