Skip to content

Commit

Permalink
all: imp code
Browse files Browse the repository at this point in the history
  • Loading branch information
schzhn committed Jun 21, 2024
1 parent eae49f9 commit 5fde76b
Show file tree
Hide file tree
Showing 9 changed files with 296 additions and 87 deletions.
5 changes: 3 additions & 2 deletions internal/client/persistent.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ type Persistent struct {
// BlockedServices is the configuration of blocked services of a client.
BlockedServices *filtering.BlockedServices

// Name of the persistent client. Must not be empty.
Name string

Tags []string
Expand Down Expand Up @@ -99,8 +100,8 @@ type Persistent struct {
SafeSearchConf filtering.SafeSearchConfig
}

// Validate returns an error if persistent client information contains errors.
func (c *Persistent) Validate(allTags *container.MapSet[string]) (err error) {
// validate returns an error if persistent client information contains errors.
func (c *Persistent) validate(allTags *container.MapSet[string]) (err error) {
switch {
case c.Name == "":
return errors.Error("empty name")
Expand Down
2 changes: 1 addition & 1 deletion internal/client/persistent_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ func TestPersistent_Validate(t *testing.T) {

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := tc.cli.Validate(nil)
err := tc.cli.validate(nil)
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
})
}
Expand Down
96 changes: 83 additions & 13 deletions internal/client/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,32 +6,46 @@ import (
"net/netip"
"sync"

"github.com/AdguardTeam/golibs/container"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
)

// Storage contains information about persistent and runtime clients.
type Storage struct {
// mu protects index of persistent clients.
// allowedTags is a set of all allowed tags.
allowedTags *container.MapSet[string]

// mu protects indexes of persistent and runtime clients.
mu *sync.Mutex

// index contains information about persistent clients.
index *Index

// runtimeIndex contains information about runtime clients.
runtimeIndex *RuntimeIndex
}

// NewStorage returns initialized client storage.
func NewStorage() (s *Storage) {
func NewStorage(allowedTags *container.MapSet[string]) (s *Storage) {
return &Storage{
mu: &sync.Mutex{},
index: NewIndex(),
allowedTags: allowedTags,
mu: &sync.Mutex{},
index: NewIndex(),
runtimeIndex: NewRuntimeIndex(),
}
}

// Add stores persistent client information or returns an error. p must be
// valid persistent client. See [Persistent.Validate].
// Add stores persistent client information or returns an error.
func (s *Storage) Add(p *Persistent) (err error) {
defer func() { err = errors.Annotate(err, "adding client: %w") }()

err = p.validate(s.allowedTags)
if err != nil {
// Don't wrap the error since there is already an annotation deferred.
return err
}

s.mu.Lock()
defer s.mu.Unlock()

Expand Down Expand Up @@ -129,11 +143,16 @@ func (s *Storage) RemoveByName(name string) (ok bool) {
}

// Update finds the stored persistent client by its name and updates its
// information from n. n must be valid persistent client. See
// [Persistent.Validate].
func (s *Storage) Update(name string, n *Persistent) (err error) {
// information from p.
func (s *Storage) Update(name string, p *Persistent) (err error) {
defer func() { err = errors.Annotate(err, "updating client: %w") }()

err = p.validate(s.allowedTags)
if err != nil {
// Don't wrap the error since there is already an annotation deferred.
return err
}

s.mu.Lock()
defer s.mu.Unlock()

Expand All @@ -142,19 +161,19 @@ func (s *Storage) Update(name string, n *Persistent) (err error) {
return fmt.Errorf("client %q is not found", name)
}

// Client n has a newly generated UID, so replace it with the stored one.
// Client p has a newly generated UID, so replace it with the stored one.
//
// TODO(s.chzhen): Remove when frontend starts handling UIDs.
n.UID = stored.UID
p.UID = stored.UID

err = s.index.Clashes(n)
err = s.index.Clashes(p)
if err != nil {
// Don't wrap the error since there is already an annotation deferred.
return err
}

s.index.Delete(stored)
s.index.Add(n)
s.index.Add(p)

return nil
}
Expand Down Expand Up @@ -183,3 +202,54 @@ func (s *Storage) CloseUpstreams() (err error) {

return s.index.CloseUpstreams()
}

// ClientRuntime returns a copy of the saved runtime client by ip. If no such
// client exists, returns nil.
func (s *Storage) ClientRuntime(ip netip.Addr) (rc *Runtime) {
s.mu.Lock()
defer s.mu.Unlock()

return s.runtimeIndex.Client(ip)
}

// AddRuntime saves the runtime client information in the storage. IP address
// of a client must be unique. rc must not be nil.
func (s *Storage) AddRuntime(rc *Runtime) {
s.mu.Lock()
defer s.mu.Unlock()

s.runtimeIndex.Add(rc)
}

// SizeRuntime returns the number of the runtime clients.
func (s *Storage) SizeRuntime() (n int) {
s.mu.Lock()
defer s.mu.Unlock()

return s.runtimeIndex.Size()
}

// RangeRuntime calls f for each runtime client in an undefined order.
func (s *Storage) RangeRuntime(f func(rc *Runtime) (cont bool)) {
s.mu.Lock()
defer s.mu.Unlock()

s.runtimeIndex.Range(f)
}

// DeleteRuntime removes the runtime client by ip.
func (s *Storage) DeleteRuntime(ip netip.Addr) {
s.mu.Lock()
defer s.mu.Unlock()

s.runtimeIndex.Delete(ip)
}

// DeleteBySource removes all runtime clients that have information only from
// the specified source and returns the number of removed clients.
func (s *Storage) DeleteBySource(src Source) (n int) {
s.mu.Lock()
defer s.mu.Unlock()

return s.runtimeIndex.DeleteBySource(src)
}
13 changes: 9 additions & 4 deletions internal/client/storage_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import (
func newStorage(tb testing.TB, m []*client.Persistent) (s *client.Storage) {
tb.Helper()

s = client.NewStorage()
s = client.NewStorage(nil)

for _, c := range m {
c.UID = client.MustNewUID()
Expand Down Expand Up @@ -57,7 +57,7 @@ func TestStorage_Add(t *testing.T) {
UID: existingClientUID,
}

s := client.NewStorage()
s := client.NewStorage(nil)
err := s.Add(existingClient)
require.NoError(t, err)

Expand Down Expand Up @@ -137,7 +137,7 @@ func TestStorage_RemoveByName(t *testing.T) {
UID: client.MustNewUID(),
}

s := client.NewStorage()
s := client.NewStorage(nil)
err := s.Add(existingClient)
require.NoError(t, err)

Expand All @@ -162,7 +162,7 @@ func TestStorage_RemoveByName(t *testing.T) {
}

t.Run("duplicate_remove", func(t *testing.T) {
s = client.NewStorage()
s = client.NewStorage(nil)
err = s.Add(existingClient)
require.NoError(t, err)

Expand Down Expand Up @@ -366,27 +366,31 @@ func TestStorage_Update(t *testing.T) {
cli: &client.Persistent{
Name: "basic",
IPs: []netip.Addr{netip.MustParseAddr("1.1.1.1")},
UID: client.MustNewUID(),
},
wantErrMsg: "",
}, {
name: "duplicate_name",
cli: &client.Persistent{
Name: obstructingName,
IPs: []netip.Addr{netip.MustParseAddr("3.3.3.3")},
UID: client.MustNewUID(),
},
wantErrMsg: `updating client: another client uses the same name "obstructing_name"`,
}, {
name: "duplicate_ip",
cli: &client.Persistent{
Name: "duplicate_ip",
IPs: []netip.Addr{obstructingIP},
UID: client.MustNewUID(),
},
wantErrMsg: `updating client: another client "obstructing_name" uses the same IP "1.2.3.4"`,
}, {
name: "duplicate_subnet",
cli: &client.Persistent{
Name: "duplicate_subnet",
Subnets: []netip.Prefix{obstructingSubnet},
UID: client.MustNewUID(),
},
wantErrMsg: `updating client: another client "obstructing_name" ` +
`uses the same subnet "1.2.3.0/24"`,
Expand All @@ -395,6 +399,7 @@ func TestStorage_Update(t *testing.T) {
cli: &client.Persistent{
Name: "duplicate_client_id",
ClientIDs: []string{obstructingClientID},
UID: client.MustNewUID(),
},
wantErrMsg: `updating client: another client "obstructing_name" ` +
`uses the same ClientID "obstructing_client_id"`,
Expand Down
Loading

0 comments on commit 5fde76b

Please sign in to comment.