Skip to content

Commit

Permalink
config: moved custom value types into types.go and various utils into…
Browse files Browse the repository at this point in the history
… utils.go
  • Loading branch information
valyala committed Nov 1, 2017
1 parent a5064d3 commit b0cc348
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 77 deletions.
77 changes: 0 additions & 77 deletions config/config.go
Expand Up @@ -3,8 +3,6 @@ package config
import (
"fmt"
"io/ioutil"
"net"
"strings"
"time"

"gopkg.in/yaml.v2"
Expand Down Expand Up @@ -426,52 +424,6 @@ func (ng *NetworkGroups) UnmarshalYAML(unmarshal func(interface{}) error) error
return checkOverflow(ng.XXX, "network_groups")
}

// Networks is a list of IPNet entities
type Networks []*net.IPNet

// UnmarshalYAML implements the yaml.Unmarshaler interface.
func (n *Networks) UnmarshalYAML(unmarshal func(interface{}) error) error {
var s []string
if err := unmarshal(&s); err != nil {
return err
}
networks := make(Networks, len(s))
for i, s := range s {
ipnet, err := stringToIPnet(s)
if err != nil {
return err
}
networks[i] = ipnet
}
*n = networks
return nil
}

// Contains checks whether passed addr is in the range of networks
func (n Networks) Contains(addr string) bool {
if len(n) == 0 {
return true
}

h, _, err := net.SplitHostPort(addr)
if err != nil {
panic(fmt.Sprintf("BUG: unexpected error while parsing RemoteAddr: %s", err))
}

ip := net.ParseIP(h)
if ip == nil {
panic(fmt.Sprintf("BUG: unexpected error while parsing IP: %s", h))
}

for _, ipnet := range n {
if ipnet.Contains(ip) {
return true
}
}

return false
}

// NetworksOrGroups is a list of strings with names of NetworkGroups
// or just Networks
type NetworksOrGroups []string
Expand Down Expand Up @@ -623,32 +575,3 @@ func (c Config) checkVulnerabilities() error {
}
return nil
}

func checkOverflow(m map[string]interface{}, ctx string) error {
if len(m) > 0 {
var keys []string
for k := range m {
keys = append(keys, k)
}
return fmt.Errorf("unknown fields in %s: %s", ctx, strings.Join(keys, ", "))
}
return nil
}

const entireIPv4 = "0.0.0.0/0"

func stringToIPnet(s string) (*net.IPNet, error) {
if s == entireIPv4 {
return nil, fmt.Errorf("suspicious mask specified \"0.0.0.0/0\". " +
"If you want to allow all then just omit `allowed_networks` field")
}
ip := s
if !strings.Contains(ip, `/`) {
ip += "/32"
}
_, ipnet, err := net.ParseCIDR(ip)
if err != nil {
return nil, fmt.Errorf("wrong network group name or address %q: %s", s, err)
}
return ipnet, nil
}
52 changes: 52 additions & 0 deletions config/types.go
@@ -0,0 +1,52 @@
package config

import (
"fmt"
"net"
)

// Networks is a list of IPNet entities
type Networks []*net.IPNet

// UnmarshalYAML implements the yaml.Unmarshaler interface.
func (n *Networks) UnmarshalYAML(unmarshal func(interface{}) error) error {
var s []string
if err := unmarshal(&s); err != nil {
return err
}
networks := make(Networks, len(s))
for i, s := range s {
ipnet, err := stringToIPnet(s)
if err != nil {
return err
}
networks[i] = ipnet
}
*n = networks
return nil
}

// Contains checks whether passed addr is in the range of networks
func (n Networks) Contains(addr string) bool {
if len(n) == 0 {
return true
}

h, _, err := net.SplitHostPort(addr)
if err != nil {
panic(fmt.Sprintf("BUG: unexpected error while parsing RemoteAddr: %s", err))
}

ip := net.ParseIP(h)
if ip == nil {
panic(fmt.Sprintf("BUG: unexpected error while parsing IP: %s", h))
}

for _, ipnet := range n {
if ipnet.Contains(ip) {
return true
}
}

return false
}
36 changes: 36 additions & 0 deletions config/utils.go
@@ -0,0 +1,36 @@
package config

import (
"fmt"
"net"
"strings"
)

const entireIPv4 = "0.0.0.0/0"

func stringToIPnet(s string) (*net.IPNet, error) {
if s == entireIPv4 {
return nil, fmt.Errorf("suspicious mask specified \"0.0.0.0/0\". " +
"If you want to allow all then just omit `allowed_networks` field")
}
ip := s
if !strings.Contains(ip, `/`) {
ip += "/32"
}
_, ipnet, err := net.ParseCIDR(ip)
if err != nil {
return nil, fmt.Errorf("wrong network group name or address %q: %s", s, err)
}
return ipnet, nil
}

func checkOverflow(m map[string]interface{}, ctx string) error {
if len(m) > 0 {
var keys []string
for k := range m {
keys = append(keys, k)
}
return fmt.Errorf("unknown fields in %s: %s", ctx, strings.Join(keys, ", "))
}
return nil
}

0 comments on commit b0cc348

Please sign in to comment.