Skip to content

Commit

Permalink
all: imp code, docs & tests
Browse files Browse the repository at this point in the history
  • Loading branch information
EugeneOne1 committed Jun 15, 2021
1 parent a989e8a commit 8fe7cb0
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 161 deletions.
8 changes: 0 additions & 8 deletions internal/dnsforward/dnsforward.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,14 +223,6 @@ func (s *Server) WriteDiskConfig(c *FilteringConfig) {
c.UpstreamDNS = aghstrings.CloneSlice(sc.UpstreamDNS)
}

// UpstreamTimeout returns the copy of actual RDNS configuration.
func (s *Server) UpstreamTimeout() (timeout time.Duration) {
s.serverLock.RLock()
defer s.serverLock.RUnlock()

return s.conf.UpstreamTimeout
}

// RDNSSettings returns the copy of actual RDNS configuration.
func (s *Server) RDNSSettings() (localPTRResolvers []string, resolveClients, resolvePTR bool) {
s.serverLock.RLock()
Expand Down
2 changes: 1 addition & 1 deletion internal/dnsforward/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -585,7 +585,7 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
result := map[string]string{}
bootstraps := req.BootstrapDNS

timeout := s.UpstreamTimeout()
timeout := s.conf.UpstreamTimeout
for _, host := range req.Upstreams {
err = checkDNS(host, bootstraps, timeout, checkDNSUpstreamExc)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion internal/home/authratelimiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func (ab *authRateLimiter) check(usrID string) (left time.Duration) {
// incLocked increments the number of unsuccessful attempts for attempter with
// ip and updates it's blocking moment if needed. For internal use only.
func (ab *authRateLimiter) incLocked(usrID string, now time.Time) {
var until time.Time = now.Add(failedAuthTTL)
until := now.Add(failedAuthTTL)
var attNum uint = 1

a, ok := ab.failedAuths[usrID]
Expand Down
1 change: 0 additions & 1 deletion internal/home/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,6 @@ func (c *configuration) write() error {
dns.LocalPTRResolvers,
dns.ResolveClients,
dns.UsePrivateRDNS = s.RDNSSettings()
dns.UpstreamTimeout = Duration{s.UpstreamTimeout()}
}

if Context.dhcpServer != nil {
Expand Down
285 changes: 135 additions & 150 deletions internal/home/duration_test.go
Original file line number Diff line number Diff line change
@@ -1,186 +1,171 @@
package home

import (
"encoding"
"encoding/json"
"encoding/xml"
"fmt"
"io"
"strings"
"testing"
"time"

"github.com/AdguardTeam/golibs/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gopkg.in/yaml.v2"
yaml "gopkg.in/yaml.v2"
)

// durationMarshalTester is a helper struct to simplify testing different
// Duration marshalling and unmarshalling cases.
type durationMarshalTester struct {
PtrMap map[string]*Duration `json:"ptr_map"`
PtrSlice []*Duration `json:"ptr_slice"`
PtrValue *Duration `json:"ptr_value"`
PtrArray [1]*Duration `json:"ptr_array"`
Map map[string]Duration `json:"map"`
Slice []Duration `json:"slice"`
Value Duration `json:"value"`
Array [1]Duration `json:"array"`
}

const nl = "\n"
const (
// ErrNotTextMarshaler is returned when passed interface does not
// implement the encoding.TextMarshaler interface.
ErrNotTextMarshaler errors.Error = "not a text marshaler"
// ErrNotTextUnmarshaler is returned when passed interface does not
// implement the encoding.TextUnmarshaler interface.
ErrNotTextUnmarshaler errors.Error = "not a text unmarshaler"
jsonStr = `{` +
`"ptr_map":{"dur":"1ms"},` +
`"ptr_slice":["1ms"],` +
`"ptr_value":"1ms",` +
`"ptr_array":["1ms"],` +
`"map":{"dur":"1ms"},` +
`"slice":["1ms"],` +
`"value":"1ms",` +
`"array":["1ms"]` +
`}`
yamlStr = `ptrmap:` + nl +
` dur: 1ms` + nl +
`ptrslice:` + nl +
`- 1ms` + nl +
`ptrvalue: 1ms` + nl +
`ptrarray:` + nl +
`- 1ms` + nl +
`map:` + nl +
` dur: 1ms` + nl +
`slice:` + nl +
`- 1ms` + nl +
`value: 1ms` + nl +
`array:` + nl +
`- 1ms`
)

// directText implements Encode and Decode methods like other encoding-related
// packages do. Simplifies testing of encoding.TextMarshaler and
// encoding.TextUnmarshaler interfaces implementations.
//
// TODO(e.burkov): Put into aghtest when there will be other
// encoding.TextMarshaler or encoding.TextUnmarshaler implementations.
type directText struct {
// w is an io.Writer that directText will write encoded data.
w io.Writer
// r is an io.Reader that directText will read data to decode from.
r io.Reader
}
// checkFields verifies m's fields. It expects the m to be unmarshalled from
// one of the constant strings above.
func (m *durationMarshalTester) checkFields(t *testing.T, d Duration) {
require.NotNil(t, m.PtrMap)

// Encode expects the v to be an encoding.TextMarshaler and writes the data from
// it using internal writer.
func (e *directText) Encode(v interface{}) (err error) {
val, ok := v.(encoding.TextMarshaler)
if !ok {
return ErrNotTextMarshaler
}
fromPtrMap, ok := m.PtrMap["dur"]
require.True(t, ok)
require.NotNil(t, fromPtrMap)

var data []byte
data, err = val.MarshalText()
if err != nil {
return err
}
require.Len(t, m.PtrSlice, 1)
fromPtrSlice := m.PtrSlice[0]
require.NotNil(t, fromPtrSlice)

_, err = e.w.Write(data)
if err != nil {
return err
}
fromPtrArray := m.PtrArray[0]
require.NotNil(t, fromPtrArray)

return nil
}
require.NotNil(t, m.PtrValue)

// Decode expects the v to be an encoding.TextUnmarshaler. It reads the data
// internal reader passing it into v.
func (e *directText) Decode(v interface{}) (err error) {
val, ok := v.(encoding.TextUnmarshaler)
if !ok {
return ErrNotTextUnmarshaler
}
var fromMap Duration
fromMap, ok = m.Map["dur"]
require.True(t, ok)

var data []byte
data, err = io.ReadAll(e.r)
if err != nil {
return err
}
require.Len(t, m.Slice, 1)

err = val.UnmarshalText(data)
if err != nil {
return err
}

return nil
assert.Equal(t, d, *fromPtrMap)
assert.Equal(t, d, *fromPtrSlice)
assert.Equal(t, d, *m.PtrValue)
assert.Equal(t, d, *fromPtrArray)
assert.Equal(t, d, fromMap)
assert.Equal(t, d, m.Slice[0])
assert.Equal(t, d, m.Value)
assert.Equal(t, d, m.Array[0])
}

// val is the default value throughout tests.
const val = 1 * time.Millisecond

// valStr is a text representation of val.
var valStr = val.String()
// val is the default time.Duration value to be used throughout the tests of
// Duration.
const val = time.Millisecond

func TestDuration_MarshalText(t *testing.T) {
d := Duration{val}
dPtr := &d

m := durationMarshalTester{
PtrMap: map[string]*Duration{"dur": dPtr},
PtrSlice: []*Duration{dPtr},
PtrValue: dPtr,
PtrArray: [1]*Duration{dPtr},
Map: map[string]Duration{"dur": d},
Slice: []Duration{d},
Value: d,
Array: [1]Duration{d},
}

b := &strings.Builder{}
t.Run("json", func(t *testing.T) {
t.Cleanup(b.Reset)
err := json.NewEncoder(b).Encode(m)
require.NoError(t, err)

testCases := []struct {
enc interface {
Encode(v interface{}) (err error)
}
name string
fmtStr string
}{{
enc: yaml.NewEncoder(b),
name: "yaml",
fmtStr: "%s\n",
}, {
enc: json.NewEncoder(b),
name: "json",
fmtStr: "%q\n",
}, {
enc: xml.NewEncoder(b),
name: "xml",
fmtStr: "<Duration>%s</Duration>",
}, {
enc: &directText{
w: b,
},
name: "direct",
fmtStr: "%s",
}}

for _, tc := range testCases {
b.Reset()
t.Run(tc.name, func(t *testing.T) {
err := tc.enc.Encode(d)
require.NoError(t, err)

assert.Equal(t, fmt.Sprintf(tc.fmtStr, val), b.String())
})
}
assert.JSONEq(t, jsonStr, b.String())
})

t.Run("yaml", func(t *testing.T) {
t.Cleanup(b.Reset)
err := yaml.NewEncoder(b).Encode(m)
require.NoError(t, err)

assert.YAMLEq(t, yamlStr, b.String(), b.String())
})

t.Run("direct", func(t *testing.T) {
data, err := d.MarshalText()
require.NoError(t, err)

assert.EqualValues(t, []byte(val.String()), data)
})
}

func TestDuration_UnmarshalText(t *testing.T) {
d := Duration{}

testCases := []struct {
dec interface {
Decode(v interface{}) (err error)
}
name string
}{{
dec: yaml.NewDecoder(
strings.NewReader(valStr),
),
name: "yaml",
}, {
dec: json.NewDecoder(
strings.NewReader(`"` + valStr + `"`),
),
name: "json",
}, {
dec: xml.NewDecoder(
strings.NewReader("<Duration>" + valStr + "</Duration>"),
),
name: "xml",
}, {
dec: &directText{
r: strings.NewReader(valStr),
},
name: "direct",
}}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := tc.dec.Decode(&d)
require.NoError(t, err)

assert.Equal(t, val, d.Duration)
})
}
d := Duration{val}
var m *durationMarshalTester

t.Run("json", func(t *testing.T) {
m = &durationMarshalTester{}

r := strings.NewReader(jsonStr)
err := json.NewDecoder(r).Decode(m)
require.NoError(t, err)

m.checkFields(t, d)
})

t.Run("yaml", func(t *testing.T) {
m = &durationMarshalTester{}

r := strings.NewReader(yamlStr)
err := yaml.NewDecoder(r).Decode(m)
require.NoError(t, err)

m.checkFields(t, d)
})

t.Run("direct", func(t *testing.T) {
dd := &Duration{}

err := dd.UnmarshalText([]byte(d.String()))
require.NoError(t, err)

assert.Equal(t, d, *dd)
})

t.Run("bad_data", func(t *testing.T) {
const wrongDur = "abc"

dec := &directText{
r: strings.NewReader(wrongDur),
}
err := dec.Decode(&d)
require.Error(t, err)

assert.Equal(
t,
fmt.Sprintf("unmarshalling duration: time: invalid duration %q", wrongDur),
err.Error(),
)
const wrongData = `abc`

assert.Error(t, (&Duration{}).UnmarshalText([]byte(wrongData)))
})
}

0 comments on commit 8fe7cb0

Please sign in to comment.