diff --git a/.travis.yml b/.travis.yml index f78eea2..31ec3b4 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,7 +1,9 @@ language: go go: - - "1.10.4" + - "1.10.8" + - "1.11.9" + - "1.12.4" before_install: - curl -L -s https://github.com/golang/dep/releases/download/v0.5.0/dep-linux-amd64 -o $GOPATH/bin/dep diff --git a/pkg/nettools/consts.go b/pkg/nettools/consts.go new file mode 100644 index 0000000..0bb2bdb --- /dev/null +++ b/pkg/nettools/consts.go @@ -0,0 +1,45 @@ +/* Copyright 2019 DevFactory FZ LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +package nettools + +import ( + "fmt" + "strings" +) + +// Protocol is the type to provide different protocol names constants. +type Protocol string + +const ( + // Unknown is the Protocol returned for unknown protocols + Unknown Protocol = "unknown" + // TCP is the name of the tcp protocol, as used in go's net library + TCP Protocol = "tcp" + // UDP is the name of the udp protocol, as used in go's net library + UDP Protocol = "udp" +) + +// ParseProtocol parses string and returns no error and a known Protocol. +// If the protocol name is unknown, returns Unknown and error. +func ParseProtocol(proto string) (Protocol, error) { + p := strings.ToLower(proto) + switch p { + case "tcp": + return TCP, nil + case "udp": + return UDP, nil + default: + return Unknown, fmt.Errorf("Unknown protocol %s passed for parsing", proto) + } +} diff --git a/pkg/nettools/ipset.go b/pkg/nettools/ipset.go index 75f3258..059ca59 100644 --- a/pkg/nettools/ipset.go +++ b/pkg/nettools/ipset.go @@ -14,9 +14,11 @@ limitations under the License. */ package nettools import ( + "bytes" "fmt" "net" "os/exec" + "strconv" "strings" "github.com/DevFactory/go-tools/pkg/extensions/collections" @@ -29,6 +31,30 @@ const ( IPSetListWithAwk = "ipset list %s | awk " + `'$0 ~ "^Members:$" {found=1; ln=NR}; NR>ln && found == 1 {print $1}'` ) +/* +NetPort allows to store net.IPNet and a port number with protocol for ipsets based on that data. +*/ +type NetPort struct { + Net net.IPNet + Protocol Protocol + Port uint16 +} + +// String returns a format accepted by ipset, ie. +// 10.0.0.0/8,tcp:80 +func (np NetPort) String() string { + return fmt.Sprintf("%s,%s:%d", np.Net.String(), np.Protocol, np.Port) +} + +// Equal returns true only if the NetPort has exactly the same values as the parameter NetPort. +func (np NetPort) Equal(np2 NetPort) bool { + if np.Net.IP.Equal(np2.Net.IP) && bytes.Compare(np.Net.Mask, np2.Net.Mask) == 0 && + np.Port == np2.Port && np.Protocol == np2.Protocol { + return true + } + return false +} + /* IPSetHelper provides methods to manage ipset sets. @@ -46,6 +72,8 @@ type IPSetHelper interface { DeleteSet(name string) error EnsureSetHasOnly(name string, ips []net.IP) error GetIPs(name string) ([]net.IP, error) + EnsureSetHasOnlyNetPort(name string, netports []NetPort) error + GetNetPorts(name string) ([]NetPort, error) } type execIPSetHelper struct { @@ -91,39 +119,123 @@ func (h *execIPSetHelper) DeleteSet(name string) error { } func (h *execIPSetHelper) EnsureSetHasOnly(name string, ips []net.IP) error { - // load the current set and assume tentatively all IPs from it are going - // to be removed - current, err := h.GetIPs(name) + return h.ensureSetHasOnlyGeneric(name, "IP", ipSliceToInterface(ips), + func(setName string) ([]interface{}, error) { + ips, err := h.GetIPs(setName) + return ipSliceToInterface(ips), err + }, + func(e1, e2 interface{}) bool { + return (e1.(net.IP)).Equal(e2.(net.IP)) + }, + func(setName string, obj interface{}) error { + return h.addElementToSet(name, "IP", obj.(net.IP)) + }, + func(setName string, obj interface{}) error { + return h.removeElementFromSet(name, "IP", obj.(net.IP)) + }) +} + +func (h *execIPSetHelper) EnsureSetHasOnlyNetPort(name string, netports []NetPort) error { + return h.ensureSetHasOnlyGeneric(name, "NetPort", netPortSliceToInterface(netports), + func(setName string) ([]interface{}, error) { + nps, err := h.GetNetPorts(setName) + return netPortSliceToInterface(nps), err + }, + func(e1, e2 interface{}) bool { + return e1.(NetPort).Equal(e2.(NetPort)) + }, + func(setName string, obj interface{}) error { + return h.addElementToSet(name, "NetPort", obj.(NetPort)) + }, + func(setName string, obj interface{}) error { + return h.removeElementFromSet(name, "NetPort", obj.(NetPort)) + }) +} + +func (h *execIPSetHelper) GetIPs(name string) ([]net.IP, error) { + // format to parse: + // 127.0.0.1 + // 127.0.0.2 + lines, err := h.getIPSetEntries(name) if err != nil { - return err + return []net.IP{}, err } - newAsInterface := make([]interface{}, len(ips)) - for i, ip := range ips { - newAsInterface[i] = ip + result := make([]net.IP, 0, len(lines)) + for _, line := range lines { + if line == "" { + continue + } + ip := net.ParseIP(strings.TrimSpace(line)) + if ip != nil { + result = append(result, ip) + } } - currentAsInterface := make([]interface{}, len(current)) - for i, ip := range current { - currentAsInterface[i] = ip + return result, nil +} + +func (h *execIPSetHelper) GetNetPorts(name string) ([]NetPort, error) { + // format to parse: + // 10.0.0.0/8,tcp:80 + lines, err := h.getIPSetEntries(name) + if err != nil { + return []NetPort{}, err } - // find the diff - toAdd, toRemove := collections.GetSlicesDifferences(newAsInterface, currentAsInterface, - func(ip1, ip2 interface{}) bool { - return (ip1.(net.IP)).Equal(ip2.(net.IP)) + result := make([]NetPort, 0, len(lines)) + for _, line := range lines { + if line == "" { + continue + } + entry := strings.Split(strings.TrimSpace(line), ",") + _, ipnet, err := net.ParseCIDR(entry[0]) + if err != nil { + return []NetPort{}, err + } + + protoPort := strings.Split(entry[1], ":") + proto, err := ParseProtocol(protoPort[0]) + if err != nil { + return []NetPort{}, err + } + + pt, err := strconv.ParseUint(protoPort[1], 10, 16) + if err != nil { + return []NetPort{}, err + } + port := uint16(pt) + + result = append(result, NetPort{ + Net: *ipnet, + Protocol: proto, + Port: port, }) + } + return result, nil +} - for _, iip := range toAdd { - ip := iip.(net.IP) - log.Debugf("Adding IP %s to ipset %s", ip.String(), name) - if err := h.addIPToSet(name, ip); err != nil { - log.Errorf("Error adding entry %v to ipset %s", ip, name) +func (h *execIPSetHelper) ensureSetHasOnlyGeneric(setName, typeName string, required []interface{}, + getter func(setName string) ([]interface{}, error), + comparer func(e1, e2 interface{}) bool, + adder func(setName string, obj interface{}) error, + remover func(setName string, obj interface{}) error) error { + + current, err := getter(setName) + if err != nil { + return err + } + // find the diff + toAdd, toRemove := collections.GetSlicesDifferences(required, current, comparer) + + for _, el := range toAdd { + log.Debugf("Adding %s %v to ipset %s", typeName, el, setName) + if err := adder(setName, el); err != nil { + log.Errorf("Error adding entry %v to ipset %s", el, setName) return err } } - for _, iip := range toRemove { - ip := iip.(net.IP) - log.Debugf("Removing IP %s from ipset %s", ip.String(), name) - if err := h.removeIPFromSet(name, ip); err != nil { - log.Debugf("Error removing entry %v from ipset %s", ip, name) + for _, el := range toRemove { + log.Debugf("Removing %s %v from ipset %s", typeName, el, setName) + if err := remover(setName, el); err != nil { + log.Debugf("Error removing entry %v from ipset %s", el, setName) return err } } @@ -131,44 +243,51 @@ func (h *execIPSetHelper) EnsureSetHasOnly(name string, ips []net.IP) error { return nil } -func (h *execIPSetHelper) GetIPs(name string) ([]net.IP, error) { +func (h *execIPSetHelper) getIPSetEntries(name string) ([]string, error) { // # ipset list myset | awk '$0 ~ "^Members:$" {found=1; ln=NR}; NR>ln && found == 1 {print $1}' - // 127.0.0.1 - // 127.0.0.2 cmd := fmt.Sprintf(IPSetListWithAwk, name) res := h.exec.RunCommand("sh", "-c", cmd) if res.Err != nil || res.ExitCode != 0 { log.Debugf("Problem listing ipset %s - probably it's OK and it just doesn't exist: "+ "%v, stdErr: %s", name, res.Err, res.StdErr) - return []net.IP{}, res.Err + return nil, res.Err } lines := strings.Split(res.StdOut, "\n") - result := make([]net.IP, 0, len(lines)) - for _, line := range lines { - ip := net.ParseIP(strings.TrimSpace(line)) - if ip != nil { - result = append(result, ip) - } - } - return result, nil + return lines, nil } -func (h *execIPSetHelper) addIPToSet(name string, ip net.IP) error { - res := h.exec.RunCommand("ipset", "add", name, ip.String()) +func (h *execIPSetHelper) addElementToSet(setName, elementTypeName string, element fmt.Stringer) error { + res := h.exec.RunCommand("ipset", "add", setName, element.String()) if res.Err != nil || res.ExitCode != 0 { - log.Errorf("Error adding IP %s to ipset %s: %v, stdErr: %s", - ip.String(), name, res.Err, res.StdErr) + log.Errorf("Error adding %s %s to ipset %s: %v, stdErr: %s", + elementTypeName, element.String(), setName, res.Err, res.StdErr) return res.Err } return nil } -func (h *execIPSetHelper) removeIPFromSet(name string, ip net.IP) error { - res := h.exec.RunCommand("ipset", "del", name, ip.String()) +func (h *execIPSetHelper) removeElementFromSet(setName, elementTypeName string, element fmt.Stringer) error { + res := h.exec.RunCommand("ipset", "del", setName, element.String()) if res.Err != nil || res.ExitCode != 0 { - log.Debugf("Error removing IP %s from ipset %s: %v, stdErr: %s", - ip.String(), name, res.Err, res.StdErr) + log.Debugf("Error removing %s %s from ipset %s: %v, stdErr: %s", + elementTypeName, element.String(), setName, res.Err, res.StdErr) return res.Err } return nil } + +func ipSliceToInterface(ips []net.IP) []interface{} { + res := make([]interface{}, len(ips)) + for i, ip := range ips { + res[i] = ip + } + return res +} + +func netPortSliceToInterface(nps []NetPort) []interface{} { + res := make([]interface{}, len(nps)) + for i, np := range nps { + res[i] = np + } + return res +} diff --git a/pkg/nettools/ipset_test.go b/pkg/nettools/ipset_test.go index a6deabe..8adf30e 100644 --- a/pkg/nettools/ipset_test.go +++ b/pkg/nettools/ipset_test.go @@ -251,6 +251,159 @@ func Test_execIPSetHelper_EnsureSetHasOnly(t *testing.T) { } } +func Test_execIPSetHelper_GetNetPorts(t *testing.T) { + np1, np2 := getSampleNetPorts() + tests := []struct { + name string + setName string + err error + expected []nt.NetPort + mockInfo []*cmdmock.ExecInfo + }{ + { + name: "get from existing empty set", + setName: "12341234abc", + err: nil, + expected: []nt.NetPort{}, + mockInfo: []*cmdmock.ExecInfo{ + { + Expected: fmt.Sprintf("sh -c %s", fmt.Sprintf(nettools.IPSetListWithAwk, "12341234abc")), + Returned: netth.ExecResultOKNoOutput(), + }, + }, + }, + { + name: "get from non existing set", + setName: "12341234abc", + err: &exec.ExitError{ + Stderr: []byte("ipset v6.34: The set with the given name does not exist"), + }, + expected: []nt.NetPort{}, + mockInfo: []*cmdmock.ExecInfo{ + { + Expected: fmt.Sprintf("sh -c %s", fmt.Sprintf(nettools.IPSetListWithAwk, "12341234abc")), + Returned: execResultIpsetNotFound(), + }, + }, + }, + { + name: "get from existing non empty set", + setName: "12341234abc", + err: nil, + expected: []nt.NetPort{np1, np2}, + mockInfo: []*cmdmock.ExecInfo{ + { + Expected: fmt.Sprintf("sh -c %s", fmt.Sprintf(nettools.IPSetListWithAwk, "12341234abc")), + Returned: execResultIpsetNetPorts(), + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + execMock := cmdmock.NewMockExecutorFromInfos(t, tt.mockInfo...) + ipSetHelper := nt.NewExecIPSetHelper(execMock) + nps, err := ipSetHelper.GetNetPorts(tt.setName) + assert.Equal(t, tt.expected, nps) + assert.Equal(t, tt.err, err) + execMock.ValidateCallNum() + }) + } +} + +func Test_execIPSetHelper_EnsureSetHasOnlyNetPort(t *testing.T) { + np1, np2 := getSampleNetPorts() + np3 := getDifferentSampleNetPorts() + tests := []struct { + name string + setName string + err error + expected []nt.NetPort + mockInfo []*cmdmock.ExecInfo + }{ + { + name: "sync empty ipset with empty required set", + setName: "12341234abc", + err: nil, + expected: []nt.NetPort{}, + mockInfo: []*cmdmock.ExecInfo{ + { + Expected: fmt.Sprintf("sh -c %s", fmt.Sprintf(nettools.IPSetListWithAwk, "12341234abc")), + Returned: netth.ExecResultOKNoOutput(), + }, + }, + }, + { + name: "sync empty ipset with non empty required set", + setName: "12341234abc", + err: nil, + expected: []nt.NetPort{np1, np2}, + mockInfo: []*cmdmock.ExecInfo{ + { + Expected: fmt.Sprintf("sh -c %s", fmt.Sprintf(nettools.IPSetListWithAwk, "12341234abc")), + Returned: netth.ExecResultOKNoOutput(), + }, + { + Expected: fmt.Sprintf("ipset add 12341234abc %s", np1), + Returned: netth.ExecResultOKNoOutput(), + }, + { + Expected: fmt.Sprintf("ipset add 12341234abc %s", np2), + Returned: netth.ExecResultOKNoOutput(), + }, + }, + }, + { + name: "sync non empty ipset with empty required set", + setName: "12341234abc", + err: nil, + expected: []nt.NetPort{}, + mockInfo: []*cmdmock.ExecInfo{ + { + Expected: fmt.Sprintf("sh -c %s", fmt.Sprintf(nettools.IPSetListWithAwk, "12341234abc")), + Returned: execResultIpsetNetPorts(), + }, + { + Expected: fmt.Sprintf("ipset del 12341234abc %s", np1), + Returned: netth.ExecResultOKNoOutput(), + }, + { + Expected: fmt.Sprintf("ipset del 12341234abc %s", np2), + Returned: netth.ExecResultOKNoOutput(), + }, + }, + }, + { + name: "sync non empty ipset with non empty required set", + setName: "12341234abc", + err: nil, + expected: []nt.NetPort{np2, np3}, + mockInfo: []*cmdmock.ExecInfo{ + { + Expected: fmt.Sprintf("sh -c %s", fmt.Sprintf(nettools.IPSetListWithAwk, "12341234abc")), + Returned: execResultIpsetNetPorts(), + }, + { + Expected: fmt.Sprintf("ipset add 12341234abc %s", np3), + Returned: netth.ExecResultOKNoOutput(), + }, + { + Expected: fmt.Sprintf("ipset del 12341234abc %s", np1), + Returned: netth.ExecResultOKNoOutput(), + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + execMock := cmdmock.NewMockExecutorFromInfos(t, tt.mockInfo...) + ipSetHelper := nt.NewExecIPSetHelper(execMock) + err := ipSetHelper.EnsureSetHasOnlyNetPort(tt.setName, tt.expected) + assert.Equal(t, tt.err, err) + execMock.ValidateCallNum() + }) + } +} func execResultIpsetNotFound() *command.ExecResult { return &command.ExecResult{ ExitCode: 1, @@ -266,3 +419,147 @@ func execResultIpsetIPs() *command.ExecResult { StdOut: "127.0.0.1\n127.0.0.2\n", } } + +func execResultIpsetNetPorts() *command.ExecResult { + np1, np2 := getSampleNetPorts() + return &command.ExecResult{ + StdOut: fmt.Sprintf("%s\n%s\n", np1.String(), np2.String()), + } +} + +func getSampleNetPorts() (nt.NetPort, nt.NetPort) { + _, netAddr1, _ := net.ParseCIDR("10.10.0.0/24") + np1 := nt.NetPort{ + Net: *netAddr1, + Port: 80, + Protocol: nt.TCP, + } + _, netAddr2, _ := net.ParseCIDR("10.20.0.0/24") + np2 := nt.NetPort{ + Net: *netAddr2, + Port: 8080, + Protocol: nt.UDP, + } + return np1, np2 +} + +func getDifferentSampleNetPorts() nt.NetPort { + _, netAddr2, _ := net.ParseCIDR("10.120.0.0/24") + np := nt.NetPort{ + Net: *netAddr2, + Port: 8080, + Protocol: nt.UDP, + } + return np +} + +func TestNetPort_String(t *testing.T) { + _, subnet, _ := net.ParseCIDR("10.0.0.0/8") + type fields struct { + Net net.IPNet + Protocol nt.Protocol + Port uint16 + } + tests := []struct { + name string + fields fields + want string + }{ + { + name: "formatting test", + fields: fields{ + Net: *subnet, + Port: 8080, + Protocol: nt.UDP, + }, + want: "10.0.0.0/8,udp:8080", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + np := nt.NetPort{ + Net: tt.fields.Net, + Protocol: tt.fields.Protocol, + Port: tt.fields.Port, + } + if got := np.String(); got != tt.want { + t.Errorf("NetPort.String() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestNetPort_Equal(t *testing.T) { + _, subnet1, _ := net.ParseCIDR("10.0.0.0/8") + _, subnet2, _ := net.ParseCIDR("10.0.0.0/16") + tests := []struct { + name string + np1 nt.NetPort + np2 nt.NetPort + want bool + }{ + { + name: "should be equal", + np1: nt.NetPort{ + Net: *subnet1, + Port: 80, + Protocol: nt.TCP, + }, + np2: nt.NetPort{ + Net: *subnet1, + Port: 80, + Protocol: nt.TCP, + }, + want: true, + }, + { + name: "not equal - subnet", + np1: nt.NetPort{ + Net: *subnet1, + Port: 80, + Protocol: nt.TCP, + }, + np2: nt.NetPort{ + Net: *subnet2, + Port: 80, + Protocol: nt.TCP, + }, + want: false, + }, + { + name: "not equal - port", + np1: nt.NetPort{ + Net: *subnet1, + Port: 80, + Protocol: nt.TCP, + }, + np2: nt.NetPort{ + Net: *subnet1, + Port: 8080, + Protocol: nt.TCP, + }, + want: false, + }, + { + name: "not equal - protocol", + np1: nt.NetPort{ + Net: *subnet1, + Port: 80, + Protocol: nt.TCP, + }, + np2: nt.NetPort{ + Net: *subnet1, + Port: 80, + Protocol: nt.UDP, + }, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.np1.Equal(tt.np2); got != tt.want { + t.Errorf("NetPort.Equal() = %v, want %v", got, tt.want) + } + }) + } +}