Skip to content

Commit

Permalink
all: persistent client index
Browse files Browse the repository at this point in the history
  • Loading branch information
schzhn committed Feb 1, 2024
1 parent 66b16e2 commit 4a44e99
Show file tree
Hide file tree
Showing 4 changed files with 284 additions and 16 deletions.
28 changes: 14 additions & 14 deletions internal/aghalg/orderedmap.go
Expand Up @@ -5,23 +5,25 @@ import (
)

// OrderedMap is the implementation of the ordered map data structure.
type OrderedMap[K comparable, T any] struct {
vals map[K]T
type OrderedMap[K comparable, V any] struct {
vals map[K]V
cmp func(a, b K) int
keys []K
}

// NewOrderedMap initializes the new instance of ordered map. cmp is a sort
// function.
func NewOrderedMap[K comparable, T any](cmp func(a, b K) int) OrderedMap[K, T] {
return OrderedMap[K, T]{
vals: make(map[K]T),
//
// TODO(s.chzhen): Use cmp.Compare in Go 1.21
func NewOrderedMap[K comparable, V any](cmp func(a, b K) int) OrderedMap[K, V] {
return OrderedMap[K, V]{
vals: make(map[K]V),
cmp: cmp,
}
}

// Add adds val with key to the ordered map.
func (m *OrderedMap[K, T]) Add(key K, val T) {
// Set adds val with key to the ordered map.
func (m *OrderedMap[K, V]) Set(key K, val V) {
i, has := slices.BinarySearchFunc(m.keys, key, m.cmp)
if has {
m.keys[i] = key
Expand All @@ -35,18 +37,16 @@ func (m *OrderedMap[K, T]) Add(key K, val T) {
}

// Del removes the value by key from the ordered map.
func (m *OrderedMap[K, T]) Del(key K) {
func (m *OrderedMap[K, V]) Del(key K) {
i, has := slices.BinarySearchFunc(m.keys, key, m.cmp)
if !has {
return
if has {
m.keys = slices.Delete(m.keys, i, 1)
delete(m.vals, key)
}

m.keys = slices.Delete(m.keys, i, 1)
delete(m.vals, key)
}

// Range calls cb for each element of the map. If cb returns false it stops.
func (m *OrderedMap[K, T]) Range(cb func(K, T) bool) {
func (m *OrderedMap[K, V]) Range(cb func(K, V) (cont bool)) {
for _, k := range m.keys {
if !cb(k, m.vals[k]) {
return
Expand Down
4 changes: 2 additions & 2 deletions internal/aghalg/orderedmap_test.go
Expand Up @@ -21,7 +21,7 @@ func TestNewOrderedMap(t *testing.T) {

nums := []int{}
for i, r := range letters {
m.Add(r, i)
m.Set(r, i)
nums = append(nums, i)
}

Expand All @@ -44,7 +44,7 @@ func TestNewOrderedMap(t *testing.T) {
}

gotLetters := []string{}
m.Range(func(k string, v int) bool {
m.Range(func(k string, _ int) bool {
gotLetters = append(gotLetters, k)

return true
Expand Down
194 changes: 194 additions & 0 deletions internal/home/clientindex.go
@@ -0,0 +1,194 @@
package home

import (
"net"
"net/netip"

"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"golang.org/x/exp/slices"
)

// macUID contains MAC and UID.
type macUID struct {
mac net.HardwareAddr
uid UID
}

// clientIndex stores all information about persistent clients.
type clientIndex struct {
clientIDToUID map[string]UID

ipToUID map[netip.Addr]UID

subnetToUID aghalg.OrderedMap[netip.Prefix, UID]

macUIDs []*macUID

uidToClient map[UID]*persistentClient
}

// NewClientIndex initializes the new instance of client index.
func NewClientIndex() (ci *clientIndex) {
return &clientIndex{
clientIDToUID: map[string]UID{},
ipToUID: map[netip.Addr]UID{},
subnetToUID: aghalg.NewOrderedMap[netip.Prefix, UID](subnetCompare),
uidToClient: map[UID]*persistentClient{},
}
}

// add stores information about a persistent client in the index.
func (ci *clientIndex) add(c *persistentClient) {
for _, id := range c.ClientIDs {
ci.clientIDToUID[id] = c.UID
}

for _, ip := range c.IPs {
ci.ipToUID[ip] = c.UID
}

for _, pref := range c.Subnets {
ci.subnetToUID.Set(pref, c.UID)
}

for _, mac := range c.MACs {
ci.macUIDs = append(ci.macUIDs, &macUID{mac, c.UID})
}

ci.uidToClient[c.UID] = c
}

// contains returns true if the index already has information about persistent
// client.
func (ci *clientIndex) contains(c *persistentClient) (ok bool) {
for _, id := range c.ClientIDs {
_, ok = ci.clientIDToUID[id]
if ok {
return true
}
}

for _, ip := range c.IPs {
_, ok = ci.ipToUID[ip]
if ok {
return true
}
}

for _, pref := range c.Subnets {
ci.subnetToUID.Range(func(p netip.Prefix, id UID) bool {
if pref == p {
ok = true

return false
}

return true
})

if ok {
return true
}
}

for _, mac := range c.MACs {
ok = slices.ContainsFunc(ci.macUIDs, func(muid *macUID) bool {
return slices.Compare(mac, muid.mac) == 0
})

if ok {
return true
}
}

return false
}

// find finds persistent client by string represenation of the client ID, IP
// address, or MAC.
func (ci *clientIndex) find(id string) (c *persistentClient, ok bool) {
uid, found := ci.clientIDToUID[id]
if found {
return ci.uidToClient[uid], true
}

ip, err := netip.ParseAddr(id)
if err == nil {
return ci.findByIP(ip)
}

mac, err := net.ParseMAC(id)
if err == nil {
return ci.findByMAC(mac)
}

return nil, false
}

// find finds persistent client by IP address.
func (ci *clientIndex) findByIP(ip netip.Addr) (c *persistentClient, found bool) {
uid, found := ci.ipToUID[ip]
if found {
return ci.uidToClient[uid], true
}

ci.subnetToUID.Range(func(pref netip.Prefix, id UID) bool {
if pref.Contains(ip) {
uid, found = id, true

return false
}

return true
})

if found {
return ci.uidToClient[uid], true
}

return nil, false
}

// find finds persistent client by MAC.
func (ci *clientIndex) findByMAC(mac net.HardwareAddr) (c *persistentClient, found bool) {
var uid UID
found = slices.ContainsFunc(ci.macUIDs, func(muid *macUID) bool {
if slices.Compare(mac, muid.mac) == 0 {
uid = muid.uid

return true
}

return false
})

if found {
return ci.uidToClient[uid], true
}

return nil, false
}

// del removes information about persistent client from the index.
func (ci *clientIndex) del(c *persistentClient) {
for _, id := range c.ClientIDs {
delete(ci.clientIDToUID, id)
}

for _, ip := range c.IPs {
delete(ci.ipToUID, ip)
}

for _, pref := range c.Subnets {
ci.subnetToUID.Del(pref)
}

for _, mac := range c.MACs {
ci.macUIDs = append(ci.macUIDs, &macUID{mac, c.UID})
slices.DeleteFunc(ci.macUIDs, func(muid *macUID) bool {
return slices.Compare(mac, muid.mac) == 0
})
}

delete(ci.uidToClient, c.UID)
}
74 changes: 74 additions & 0 deletions internal/home/clientindex_internal_test.go
@@ -0,0 +1,74 @@
package home

import (
"net/netip"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestClientIndex(t *testing.T) {
var (
cliNone = "1.2.3.4"
cli1 = "1.1.1.1"
cli2 = "2.2.2.2"

cli1IP = netip.MustParseAddr(cli1)
cli2IP = netip.MustParseAddr(cli2)

cliIPv6 = netip.MustParseAddr("1:2:3::4")
)

ci := NewClientIndex()

uid, err := NewUID()
require.NoError(t, err)

client1 := &persistentClient{
Name: "client1",
IPs: []netip.Addr{cli1IP, cliIPv6},
UID: uid,
}

uid, err = NewUID()
require.NoError(t, err)

client2 := &persistentClient{
Name: "client2",
IPs: []netip.Addr{cli2IP},
UID: uid,
}

t.Run("add_find", func(t *testing.T) {
ci.add(client1)
ci.add(client2)

c, ok := ci.find(cli1)
require.True(t, ok)

assert.Equal(t, "client1", c.Name)

c, ok = ci.find("1:2:3::4")
require.True(t, ok)

assert.Equal(t, "client1", c.Name)

c, ok = ci.find(cli2)
require.True(t, ok)

assert.Equal(t, "client2", c.Name)

_, ok = ci.find(cliNone)
assert.False(t, ok)
})

t.Run("contains_delete", func(t *testing.T) {
ok := ci.contains(client1)
require.True(t, ok)

ci.del(client1)
ok = ci.contains(client1)
require.False(t, ok)
})
}

0 comments on commit 4a44e99

Please sign in to comment.