diff --git a/npm/testpolicies/allow-ns-y-z-pod-b-c-with-namedPort.yaml b/npm/testpolicies/allow-ns-y-z-pod-b-c-with-namedPort.yaml new file mode 100644 index 0000000000..b42eb7d367 --- /dev/null +++ b/npm/testpolicies/allow-ns-y-z-pod-b-c-with-namedPort.yaml @@ -0,0 +1,39 @@ +apiVersion: networking.k8s.io/v1 +kind: NetworkPolicy +metadata: + name: allow-ns-y-z-pod-b-c-with-namedPort + namespace: netpol-4537-x +spec: + ingress: + - from: + - namespaceSelector: + matchExpressions: + - key: ns + operator: NotIn + values: + - netpol-4537-x + - netpol-4537-y + podSelector: + matchExpressions: + - key: pod + operator: In + values: + - b + - c + - key: app + operator: In + values: + - test + - int + ports: + - port: serve-80 + protocol: TCP + podSelector: + matchExpressions: + - key: pod + operator: In + values: + - a + - x + policyTypes: + - Ingress diff --git a/npm/translatePolicy.go b/npm/translatePolicy.go index 6efaf335eb..b85ed32ca6 100644 --- a/npm/translatePolicy.go +++ b/npm/translatePolicy.go @@ -1700,10 +1700,9 @@ func translatePolicy(npObj *networkingv1.NetworkPolicy) ([]string, []string, map } entries = append(entries, getDefaultDropEntries(npNs, npObj.Spec.PodSelector, hasIngress, hasEgress)...) - resultSets = util.UniqueStrSlice(resultSets) for resultListKey, resultLists := range resultListMap { resultListMap[resultListKey] = util.UniqueStrSlice(resultLists) } - return resultSets, resultNamedPorts, resultListMap, resultIngressIPCidrs, resultEgressIPCidrs, entries + return util.UniqueStrSlice(resultSets), util.UniqueStrSlice(resultNamedPorts), resultListMap, resultIngressIPCidrs, resultEgressIPCidrs, entries } diff --git a/npm/translatePolicy_test.go b/npm/translatePolicy_test.go index 245528f4d7..44173a0a8e 100644 --- a/npm/translatePolicy_test.go +++ b/npm/translatePolicy_test.go @@ -2,8 +2,10 @@ package npm import ( "encoding/json" + "fmt" "io/ioutil" "reflect" + "sort" "testing" "github.com/Azure/azure-container-networking/npm/iptm" @@ -2388,14 +2390,79 @@ func TestAllowAllFromAppBackend(t *testing.T) { } } +// sortedIpSetMap returns a map which has direction and sorted ipsets as key and values respectively. +func sortedIpSetMap(specs []string) map[string][]string { + var prevSpec string + namedPortIpSet := fmt.Sprintf("%s,%s", util.IptablesDstFlag, util.IptablesDstFlag) + ipSetMap := make(map[string][]string) + + // Direction is always followed after ipset + for _, spec := range specs { + if spec == util.IptablesDstFlag || spec == util.IptablesSrcFlag || spec == namedPortIpSet { + ipSetMap[spec] = append(ipSetMap[spec], prevSpec) + } + prevSpec = spec + } + + // sorting ipsets + for _, ipsets := range ipSetMap { + sort.Strings(ipsets) + } + return ipSetMap +} + +// equalSetNamesInIptEntries checks whether ipset is the same or not after sorting them +func equalIPSetsInIptEntries(iptEntries, expectedIptEntries []*iptm.IptEntry, t *testing.T) bool { + if len(iptEntries) != len(expectedIptEntries) { + return false + } + + for i := 0; i < len(expectedIptEntries); i++ { + IpSetMap := sortedIpSetMap(iptEntries[i].Specs) + expectedIpsetMap := sortedIpSetMap(expectedIptEntries[i].Specs) + + if !reflect.DeepEqual(IpSetMap, expectedIpsetMap) { + t.Errorf("Ipsets are different\n got %+v\n want %+v", IpSetMap, expectedIpsetMap) + return false + } + } + + return true +} + +// check all returned values from translation function against expected values +func checkNetPolTranslationResult(netPolPolicy string, sets, expectedSets []string, lists, expectedLists map[string][]string, iptEntries, expectedIptEntries []*iptm.IptEntry, t *testing.T) bool { + if !util.CompareSlices(sets, expectedSets) { + t.Errorf("translatedPolicy failed @ %s\n sets: %v\n expectedSets: %v", netPolPolicy, sets, expectedSets) + return false + } + + // TODO(jungukcho): check whether this (map) is the same issue or not before merging it to master + if !reflect.DeepEqual(lists, expectedLists) { + t.Errorf("translatedPolicy failed @ %s\n lists: %v\n expectedLists: %v", netPolPolicy, lists, expectedLists) + return false + } + + if !reflect.DeepEqual(iptEntries, expectedIptEntries) { + if !equalIPSetsInIptEntries(iptEntries, expectedIptEntries, t) { + marshalledIptEntries, _ := json.Marshal(iptEntries) + marshalledExpectedIptEntries, _ := json.Marshal(expectedIptEntries) + t.Errorf("translatedPolicy failed @ %s\n iptEntries: %s\n expectedIptEntries: %s", netPolPolicy, marshalledIptEntries, marshalledExpectedIptEntries) + return false + } + } + + return true +} func TestAllowMultiplePodSelectors(t *testing.T) { - multiPodSlector, err := readPolicyYaml("testpolicies/allow-ns-y-z-pod-b-c.yaml") + // TODO(jungukcho): need to set util.IsNewNwPolicyVerFlag as true. It is a very strong dependency. Need to remove this dependency. + util.IsNewNwPolicyVerFlag = true + netPolFile := "allow-ns-y-z-pod-b-c.yaml" + multiPodSlector, err := readPolicyYaml(fmt.Sprintf("testpolicies/%s", netPolFile)) if err != nil { t.Fatal(err) } - util.IsNewNwPolicyVerFlag = true - sets, _, lists, _, _, iptEntries := translatePolicy(multiPodSlector) expectedSets := []string{ @@ -2407,11 +2474,6 @@ func TestAllowMultiplePodSelectors(t *testing.T) { "app:test", "app:int", } - if !util.CompareSlices(sets, expectedSets) { - t.Errorf("translatedPolicy failed @ allow-ns-y-z-pod-b-c sets comparison") - t.Errorf("sets: %v", sets) - t.Errorf("expectedSets: %v", expectedSets) - } expectedLists := map[string][]string{ "app:test:int": { @@ -2429,14 +2491,8 @@ func TestAllowMultiplePodSelectors(t *testing.T) { "pod:c", }, } - if !reflect.DeepEqual(lists, expectedLists) { - t.Errorf("translatedPolicy failed @ allow-ns-y-z-pod-b-c lists comparison") - t.Errorf("lists: %v", lists) - t.Errorf("expectedLists: %v", expectedLists) - } - expectedIptEntries := []*iptm.IptEntry{} - nonKubeSystemEntries := []*iptm.IptEntry{ + expectedIptEntries := []*iptm.IptEntry{ { Chain: util.IptablesAzureIngressFromChain, Specs: []string{ @@ -2516,15 +2572,43 @@ func TestAllowMultiplePodSelectors(t *testing.T) { }, }, } - expectedIptEntries = append(expectedIptEntries, nonKubeSystemEntries...) - // has egress, but empty map means allow all expectedIptEntries = append(expectedIptEntries, getDefaultDropEntries("netpol-4537-x", multiPodSlector.Spec.PodSelector, true, false)...) - if !reflect.DeepEqual(iptEntries, expectedIptEntries) { - t.Errorf("translatedPolicy failed @ allow-ns-y-z-pod-b-c policy comparison") - marshalledIptEntries, _ := json.Marshal(iptEntries) - marshalledExpectedIptEntries, _ := json.Marshal(expectedIptEntries) - t.Errorf("iptEntries: %s", marshalledIptEntries) - t.Errorf("expectedIptEntries: %s", marshalledExpectedIptEntries) + + if !checkNetPolTranslationResult(netPolFile, sets, expectedSets, lists, expectedLists, iptEntries, expectedIptEntries, t) { + t.Errorf("translatedPolicy failed @ %s", netPolFile) + } + + netPolFile = "allow-ns-y-z-pod-b-c-with-namedPort.yaml" + multiPodSlector, err = readPolicyYaml(fmt.Sprintf("testpolicies/%s", netPolFile)) + if err != nil { + t.Fatal(err) + } + + var namedPorts []string + sets, namedPorts, lists, _, _, iptEntries = translatePolicy(multiPodSlector) + + // reuse previously used ipentries + entriesForNamePort := []string{ + util.IptablesModuleFlag, + util.IptablesSetModuleFlag, + util.IptablesMatchSetFlag, + util.GetHashedName("namedport:serve-80"), + util.IptablesDstFlag + "," + util.IptablesDstFlag, + util.IptablesJumpFlag, + } + // Do not add namedPort entries in last drop rule + for i := 0; i < len(expectedIptEntries)-1; i++ { + expectedIptEntries[i].Specs = append(entriesForNamePort, expectedIptEntries[i].Specs...) + } + + // compared NamedPort results + expectedNamedPorts := []string{"namedport:serve-80"} + if !reflect.DeepEqual(namedPorts, expectedNamedPorts) { + t.Errorf("translatedPolicy failed namedPort @ %s comparison\n sets: %v\n expectedSets %v", netPolFile, namedPorts, expectedNamedPorts) + } + + if !checkNetPolTranslationResult(netPolFile, sets, expectedSets, lists, expectedLists, iptEntries, expectedIptEntries, t) { + t.Errorf("translatedPolicy failed @ %s", netPolFile) } } diff --git a/npm/util/util.go b/npm/util/util.go index 09d7b7d0d1..0112f64517 100644 --- a/npm/util/util.go +++ b/npm/util/util.go @@ -6,6 +6,7 @@ import ( "fmt" "hash/fnv" "os" + "reflect" "regexp" "sort" "strconv" @@ -334,10 +335,7 @@ func StrExistsInSlice(items []string, val string) bool { } func CompareSlices(list1, list2 []string) bool { - for _, item := range list1 { - if !StrExistsInSlice(list2, item) { - return false - } - } - return true + sort.Strings(list1) + sort.Strings(list2) + return reflect.DeepEqual(list1, list2) }