diff --git a/npm/pkg/dataplane/policies/chain-management_linux.go b/npm/pkg/dataplane/policies/chain-management_linux.go index 2168706584..4ce6053caa 100644 --- a/npm/pkg/dataplane/policies/chain-management_linux.go +++ b/npm/pkg/dataplane/policies/chain-management_linux.go @@ -5,7 +5,6 @@ import ( "fmt" "strconv" "strings" - "time" "github.com/Azure/azure-container-networking/npm/metrics" "github.com/Azure/azure-container-networking/npm/pkg/dataplane/ioutil" @@ -17,7 +16,6 @@ import ( const ( defaultlockWaitTimeInSeconds string = "60" - reconcileChainTimeInMinutes int = 5 doesNotExistErrorCode int = 1 // Bad rule (does a matching rule exist in that chain?) couldntLoadTargetErrorCode int = 2 // Couldn't load target `AZURE-NPM-EGRESS':No such file or directory @@ -54,6 +52,51 @@ var ( ingressOrEgressPolicyChainPattern = fmt.Sprintf("'Chain %s-\\|Chain %s-'", util.IptablesAzureIngressPolicyChainPrefix, util.IptablesAzureEgressPolicyChainPrefix) ) +type staleChains struct { + chainsToCleanup map[string]struct{} +} + +func newStaleChains() *staleChains { + return &staleChains{make(map[string]struct{})} +} + +func (s *staleChains) add(chain string) { + s.chainsToCleanup[chain] = struct{}{} +} + +func (s *staleChains) remove(chain string) { + delete(s.chainsToCleanup, chain) +} + +func (s *staleChains) emptyAndGetAll() []string { + result := make([]string, len(s.chainsToCleanup)) + k := 0 + for chain := range s.chainsToCleanup { + result[k] = chain + s.remove(chain) + k++ + } + return result +} + +func (s *staleChains) empty() { + s.chainsToCleanup = make(map[string]struct{}) +} + +// A proactive approach to avoid time to install default chains when the first networkpolicy comes again. +// Different from v1, which uninits when there are no policies and initializes when there are policies. +// The dataplane also initializes when it's created, so this keeps the policymanager in-line with that philosophy of having chains initialized at all times. +func (pMgr *PolicyManager) reboot() error { + // TODO for the sake of UTs, need to have a pMgr config specifying whether or not this reboot happens + // if err := pMgr.reset(); err != nil { + // return npmerrors.SimpleErrorWrapper("failed to remove NPM chains while rebooting", err) + // } + // if err := pMgr.initialize(); err != nil { + // return npmerrors.SimpleErrorWrapper("failed to initialize NPM chains while rebooting", err) + // } + return nil +} + func (pMgr *PolicyManager) initialize() error { if err := pMgr.initializeNPMChains(); err != nil { return npmerrors.SimpleErrorWrapper("failed to initialize NPM chains", err) @@ -65,6 +108,7 @@ func (pMgr *PolicyManager) reset() error { if err := pMgr.removeNPMChains(); err != nil { return npmerrors.SimpleErrorWrapper("failed to remove NPM chains", err) } + pMgr.staleChains.empty() return nil } @@ -72,7 +116,7 @@ func (pMgr *PolicyManager) reset() error { // AZURE-NPM chain is after the jumps to KUBE-FORWARD & KUBE-SERVICES chains (if they exist). func (pMgr *PolicyManager) initializeNPMChains() error { klog.Infof("Initializing AZURE-NPM chains.") - creator := pMgr.getCreatorForInitChains() + creator := pMgr.creatorForInitChains() err := restore(creator) if err != nil { return npmerrors.SimpleErrorWrapper("failed to create chains and rules", err) @@ -101,7 +145,7 @@ func (pMgr *PolicyManager) removeNPMChains() error { } // flush all chains (will create any chain, including deprecated ones, if they don't exist) - creatorToFlush, chainsToDelete := pMgr.getCreatorAndChainsForReset() + creatorToFlush, chainsToDelete := pMgr.creatorAndChainsForReset() restoreError := restore(creatorToFlush) if restoreError != nil { return npmerrors.SimpleErrorWrapper("failed to flush chains", restoreError) @@ -123,26 +167,37 @@ func (pMgr *PolicyManager) removeNPMChains() error { return nil } -// ReconcileChains periodically creates the jump rule from FORWARD chain to AZURE-NPM chain (if it d.n.e) -// and makes sure it's after the jumps to KUBE-FORWARD & KUBE-SERVICES chains (if they exist). -func (pMgr *PolicyManager) ReconcileChains(stopChannel <-chan struct{}) { - go pMgr.reconcileChains(stopChannel) +// reconcile does the following: +// - cleans up stale policy chains +// - creates the jump rule from FORWARD chain to AZURE-NPM chain (if it does not exist) and makes sure it's after the jumps to KUBE-FORWARD & KUBE-SERVICES chains (if they exist). +func (pMgr *PolicyManager) reconcile() { + if err := pMgr.positionAzureChainJumpRule(); err != nil { + klog.Errorf("failed to reconcile jump rule to Azure-NPM due to %s", err.Error()) + } + if err := pMgr.cleanupChains(pMgr.staleChains.emptyAndGetAll()); err != nil { + klog.Errorf("failed to clean up old policy chains with the following error %s", err.Error()) + } } -func (pMgr *PolicyManager) reconcileChains(stopChannel <-chan struct{}) { - ticker := time.NewTicker(time.Minute * time.Duration(reconcileChainTimeInMinutes)) - defer ticker.Stop() - - for { - select { - case <-stopChannel: - return - case <-ticker.C: - if err := pMgr.positionAzureChainJumpRule(); err != nil { - metrics.SendErrorLogAndMetric(util.NpmID, "Error: failed to reconcile jump rule to Azure-NPM due to %s", err.Error()) +// have to use slice argument for deterministic behavior for UTs +func (pMgr *PolicyManager) cleanupChains(chains []string) error { + var aggregateError error + for _, chain := range chains { + errCode, err := pMgr.runIPTablesCommand(util.IptablesDestroyFlag, chain) // TODO run the one that ignores doesNotExistErrorCode + if err != nil && errCode != doesNotExistErrorCode { + pMgr.staleChains.add(chain) + currentErrString := fmt.Sprintf("failed to clean up policy chain %s with err [%v]", chain, err) + if aggregateError == nil { + aggregateError = npmerrors.SimpleError(currentErrString) + } else { + aggregateError = npmerrors.SimpleErrorWrapper(fmt.Sprintf("%s and had previous error", currentErrString), aggregateError) } } } + if aggregateError != nil { + return npmerrors.SimpleErrorWrapper("failed to clean up some policy chains with errors", aggregateError) + } + return nil } // this function has a direct comparison in NPM v1 iptables manager (iptm.go) @@ -170,8 +225,8 @@ func (pMgr *PolicyManager) runIPTablesCommand(operationFlag string, args ...stri return 0, nil } -func (pMgr *PolicyManager) getCreatorForInitChains() *ioutil.FileCreator { - creator := pMgr.getNewCreatorWithChains(iptablesAzureChains) +func (pMgr *PolicyManager) creatorForInitChains() *ioutil.FileCreator { + creator := pMgr.newCreatorWithChains(iptablesAzureChains) // add AZURE-NPM chain rules creator.AddLine("", nil, util.IptablesAppendFlag, util.IptablesAzureChain, util.IptablesJumpFlag, util.IptablesAzureIngressChain) @@ -180,32 +235,32 @@ func (pMgr *PolicyManager) getCreatorForInitChains() *ioutil.FileCreator { // add AZURE-NPM-INGRESS chain rules ingressDropSpecs := []string{util.IptablesAppendFlag, util.IptablesAzureIngressChain, util.IptablesJumpFlag, util.IptablesDrop} - ingressDropSpecs = append(ingressDropSpecs, getOnMarkSpecs(util.IptablesAzureIngressDropMarkHex)...) - ingressDropSpecs = append(ingressDropSpecs, getCommentSpecs(fmt.Sprintf("DROP-ON-INGRESS-DROP-MARK-%s", util.IptablesAzureIngressDropMarkHex))...) + ingressDropSpecs = append(ingressDropSpecs, onMarkSpecs(util.IptablesAzureIngressDropMarkHex)...) + ingressDropSpecs = append(ingressDropSpecs, commentSpecs(fmt.Sprintf("DROP-ON-INGRESS-DROP-MARK-%s", util.IptablesAzureIngressDropMarkHex))...) creator.AddLine("", nil, ingressDropSpecs...) // add AZURE-NPM-INGRESS-ALLOW-MARK chain markIngressAllowSpecs := []string{util.IptablesAppendFlag, util.IptablesAzureIngressAllowMarkChain} - markIngressAllowSpecs = append(markIngressAllowSpecs, getSetMarkSpecs(util.IptablesAzureIngressAllowMarkHex)...) - markIngressAllowSpecs = append(markIngressAllowSpecs, getCommentSpecs(fmt.Sprintf("SET-INGRESS-ALLOW-MARK-%s", util.IptablesAzureIngressAllowMarkHex))...) + markIngressAllowSpecs = append(markIngressAllowSpecs, setMarkSpecs(util.IptablesAzureIngressAllowMarkHex)...) + markIngressAllowSpecs = append(markIngressAllowSpecs, commentSpecs(fmt.Sprintf("SET-INGRESS-ALLOW-MARK-%s", util.IptablesAzureIngressAllowMarkHex))...) creator.AddLine("", nil, markIngressAllowSpecs...) creator.AddLine("", nil, util.IptablesAppendFlag, util.IptablesAzureIngressAllowMarkChain, util.IptablesJumpFlag, util.IptablesAzureEgressChain) // add AZURE-NPM-EGRESS chain rules egressDropSpecs := []string{util.IptablesAppendFlag, util.IptablesAzureEgressChain, util.IptablesJumpFlag, util.IptablesDrop} - egressDropSpecs = append(egressDropSpecs, getOnMarkSpecs(util.IptablesAzureEgressDropMarkHex)...) - egressDropSpecs = append(egressDropSpecs, getCommentSpecs(fmt.Sprintf("DROP-ON-EGRESS-DROP-MARK-%s", util.IptablesAzureEgressDropMarkHex))...) + egressDropSpecs = append(egressDropSpecs, onMarkSpecs(util.IptablesAzureEgressDropMarkHex)...) + egressDropSpecs = append(egressDropSpecs, commentSpecs(fmt.Sprintf("DROP-ON-EGRESS-DROP-MARK-%s", util.IptablesAzureEgressDropMarkHex))...) creator.AddLine("", nil, egressDropSpecs...) jumpOnIngressMatchSpecs := []string{util.IptablesAppendFlag, util.IptablesAzureEgressChain, util.IptablesJumpFlag, util.IptablesAzureAcceptChain} - jumpOnIngressMatchSpecs = append(jumpOnIngressMatchSpecs, getOnMarkSpecs(util.IptablesAzureIngressAllowMarkHex)...) - jumpOnIngressMatchSpecs = append(jumpOnIngressMatchSpecs, getCommentSpecs(fmt.Sprintf("ACCEPT-ON-INGRESS-ALLOW-MARK-%s", util.IptablesAzureIngressAllowMarkHex))...) + jumpOnIngressMatchSpecs = append(jumpOnIngressMatchSpecs, onMarkSpecs(util.IptablesAzureIngressAllowMarkHex)...) + jumpOnIngressMatchSpecs = append(jumpOnIngressMatchSpecs, commentSpecs(fmt.Sprintf("ACCEPT-ON-INGRESS-ALLOW-MARK-%s", util.IptablesAzureIngressAllowMarkHex))...) creator.AddLine("", nil, jumpOnIngressMatchSpecs...) // add AZURE-NPM-ACCEPT chain rules clearSpecs := []string{util.IptablesAppendFlag, util.IptablesAzureAcceptChain} - clearSpecs = append(clearSpecs, getSetMarkSpecs(util.IptablesAzureClearMarkHex)...) - clearSpecs = append(clearSpecs, getCommentSpecs("Clear-AZURE-NPM-MARKS")...) + clearSpecs = append(clearSpecs, setMarkSpecs(util.IptablesAzureClearMarkHex)...) + clearSpecs = append(clearSpecs, commentSpecs("Clear-AZURE-NPM-MARKS")...) creator.AddLine("", nil, clearSpecs...) creator.AddLine("", nil, util.IptablesAppendFlag, util.IptablesAzureAcceptChain, util.IptablesJumpFlag, util.IptablesAccept) creator.AddLine("", nil, util.IptablesRestoreCommit) @@ -215,7 +270,7 @@ func (pMgr *PolicyManager) getCreatorForInitChains() *ioutil.FileCreator { // add/reposition AZURE-NPM chain after KUBE-FORWARD and KUBE-SERVICE chains if they exist // this function has a direct comparison in NPM v1 iptables manager (iptm.go) func (pMgr *PolicyManager) positionAzureChainJumpRule() error { - kubeServicesLine, kubeServicesLineNumErr := pMgr.getChainLineNumber(util.IptablesKubeServicesChain) + kubeServicesLine, kubeServicesLineNumErr := pMgr.chainLineNumber(util.IptablesKubeServicesChain) if kubeServicesLineNumErr != nil { // not possible to cover this branch currently because of testing limitations for PipeCommandToGrep() baseErrString := "failed to get index of jump from KUBE-SERVICES chain to FORWARD chain with error" @@ -225,7 +280,7 @@ func (pMgr *PolicyManager) positionAzureChainJumpRule() error { index := kubeServicesLine + 1 - // TODO could call getChainLineNumber instead, and say it doesn't exist for lineNum == 0 + // TODO could call chainLineNumber instead, and say it doesn't exist for lineNum == 0 jumpRuleErrCode, checkErr := pMgr.runIPTablesCommand(util.IptablesCheckFlag, jumpFromForwardToAzureChainArgs...) hadCheckError := checkErr != nil && jumpRuleErrCode != doesNotExistErrorCode if hadCheckError { @@ -252,7 +307,7 @@ func (pMgr *PolicyManager) positionAzureChainJumpRule() error { return nil } - npmChainLine, npmLineNumErr := pMgr.getChainLineNumber(util.IptablesAzureChain) + npmChainLine, npmLineNumErr := pMgr.chainLineNumber(util.IptablesAzureChain) if npmLineNumErr != nil { // not possible to cover this branch currently because of testing limitations for PipeCommandToGrep() baseErrString := "failed to get index of jump from FORWARD chain to AZURE-NPM chain" @@ -293,7 +348,7 @@ func (pMgr *PolicyManager) positionAzureChainJumpRule() error { // returns 0 if the chain d.n.e. // this function has a direct comparison in NPM v1 iptables manager (iptm.go) -func (pMgr *PolicyManager) getChainLineNumber(chain string) (int, error) { +func (pMgr *PolicyManager) chainLineNumber(chain string) (int, error) { // TODO could call this once and use regex instead of grep to cut down on OS calls listForwardEntriesCommand := pMgr.ioShim.Exec.Command(util.Iptables, util.IptablesWaitFlag, defaultlockWaitTimeInSeconds, util.IptablesTableFlag, util.IptablesFilterTable, @@ -316,20 +371,20 @@ func (pMgr *PolicyManager) getChainLineNumber(chain string) (int, error) { } // make this a function for easier testing -func (pMgr *PolicyManager) getCreatorAndChainsForReset() (creator *ioutil.FileCreator, chainsToFlush []string) { - oldPolicyChains, err := pMgr.getPolicyChainNames() +func (pMgr *PolicyManager) creatorAndChainsForReset() (creator *ioutil.FileCreator, chainsToFlush []string) { + oldPolicyChains, err := pMgr.policyChainNames() if err != nil { // not possible to cover this branch currently because of testing limitations for PipeCommandToGrep() metrics.SendErrorLogAndMetric(util.IptmID, "Error: failed to determine NPM ingress/egress policy chains to delete") } chainsToFlush = iptablesOldAndNewChains chainsToFlush = append(chainsToFlush, oldPolicyChains...) // will work even if oldPolicyChains is nil - creator = pMgr.getNewCreatorWithChains(chainsToFlush) + creator = pMgr.newCreatorWithChains(chainsToFlush) creator.AddLine("", nil, util.IptablesRestoreCommit) return } -func (pMgr *PolicyManager) getPolicyChainNames() ([]string, error) { +func (pMgr *PolicyManager) policyChainNames() ([]string, error) { iptablesListCommand := pMgr.ioShim.Exec.Command(util.Iptables, util.IptablesWaitFlag, defaultlockWaitTimeInSeconds, util.IptablesTableFlag, util.IptablesFilterTable, util.IptablesNumericFlag, util.IptablesListFlag, @@ -355,7 +410,7 @@ func (pMgr *PolicyManager) getPolicyChainNames() ([]string, error) { return chainNames, nil } -func getOnMarkSpecs(mark string) []string { +func onMarkSpecs(mark string) []string { return []string{ util.IptablesModuleFlag, util.IptablesMarkVerb, diff --git a/npm/pkg/dataplane/policies/chain-management_linux_test.go b/npm/pkg/dataplane/policies/chain-management_linux_test.go index 73482e4abe..cbd582e1ae 100644 --- a/npm/pkg/dataplane/policies/chain-management_linux_test.go +++ b/npm/pkg/dataplane/policies/chain-management_linux_test.go @@ -2,6 +2,7 @@ package policies import ( "fmt" + "sort" "strings" "testing" @@ -11,9 +12,70 @@ import ( "github.com/stretchr/testify/require" ) +const ( + testChain1 = "chain1" + testChain2 = "chain2" + testChain3 = "chain3" +) + +func TestEmptyAndGetAll(t *testing.T) { + pMgr := NewPolicyManager(common.NewMockIOShim(nil)) + pMgr.staleChains.add(testChain1) + pMgr.staleChains.add(testChain2) + chainsToCleanup := pMgr.staleChains.emptyAndGetAll() + require.Equal(t, 2, len(chainsToCleanup)) + require.True(t, chainsToCleanup[0] == testChain1 || chainsToCleanup[1] == testChain1) + require.True(t, chainsToCleanup[0] == testChain2 || chainsToCleanup[1] == testChain2) + assertStaleChainsContain(t, pMgr.staleChains) +} + +func assertStaleChainsContain(t *testing.T, s *staleChains, expectedChains ...string) { + require.Equal(t, len(expectedChains), len(s.chainsToCleanup), "incorrectly tracking chains for cleanup") + for _, chain := range expectedChains { + _, exists := s.chainsToCleanup[chain] + require.True(t, exists, "incorrectly tracking chains for cleanup") + } +} + +func TestCleanupChainsSuccess(t *testing.T) { + calls := []testutils.TestCmd{ + getFakeDestroyCommand(testChain1), + getFakeDestroyCommandWithExitCode(testChain2, 1), // exit code 1 means the chain d.n.e. + } + ioshim := common.NewMockIOShim(calls) + // TODO defer ioshim.VerifyCalls(t, ioshim, calls) + pMgr := NewPolicyManager(ioshim) + + pMgr.staleChains.add(testChain1) + pMgr.staleChains.add(testChain2) + chainsToCleanup := pMgr.staleChains.emptyAndGetAll() + sort.Strings(chainsToCleanup) + require.NoError(t, pMgr.cleanupChains(chainsToCleanup)) + assertStaleChainsContain(t, pMgr.staleChains) +} + +func TestCleanupChainsFailure(t *testing.T) { + calls := []testutils.TestCmd{ + getFakeDestroyCommandWithExitCode(testChain1, 2), + getFakeDestroyCommand(testChain2), + getFakeDestroyCommandWithExitCode(testChain3, 2), + } + ioshim := common.NewMockIOShim(calls) + // TODO defer ioshim.VerifyCalls(t, ioshim, calls) + pMgr := NewPolicyManager(ioshim) + + pMgr.staleChains.add(testChain1) + pMgr.staleChains.add(testChain2) + pMgr.staleChains.add(testChain3) + chainsToCleanup := pMgr.staleChains.emptyAndGetAll() + sort.Strings(chainsToCleanup) + require.Error(t, pMgr.cleanupChains(chainsToCleanup)) + assertStaleChainsContain(t, pMgr.staleChains, testChain1, testChain3) +} + func TestInitChainsCreator(t *testing.T) { pMgr := NewPolicyManager(common.NewMockIOShim(nil)) - creator := pMgr.getCreatorForInitChains() // doesn't make any exec calls + creator := pMgr.creatorForInitChains() // doesn't make any exec calls actualLines := strings.Split(creator.ToString(), "\n") expectedLines := []string{"*filter"} for _, chain := range iptablesAzureChains { @@ -77,7 +139,7 @@ func TestRemoveChainsCreator(t *testing.T) { } pMgr := NewPolicyManager(common.NewMockIOShim(creatorCalls)) - creator, chainsToFlush := pMgr.getCreatorAndChainsForReset() + creator, chainsToFlush := pMgr.creatorAndChainsForReset() expectedChainsToFlush := []string{ "AZURE-NPM", "AZURE-NPM-INGRESS", @@ -109,7 +171,7 @@ func TestRemoveChainsCreator(t *testing.T) { func TestRemoveChainsSuccess(t *testing.T) { calls := GetResetTestCalls() - for _, chain := range iptablesOldAndNewChains { + for _, chain := range iptablesOldAndNewChains { // TODO write these out, don't use variable calls = append(calls, getFakeDestroyCommand(chain)) } calls = append( @@ -162,7 +224,7 @@ func TestRemoveChainsFailureOnDestroy(t *testing.T) { {Cmd: []string{"grep", ingressOrEgressPolicyChainPattern}}, // ExitCode 0 for the iptables restore command fakeIPTablesRestoreCommand, } - calls = append(calls, getFakeDestroyFailureCommand(iptablesOldAndNewChains[0])) // this ExitCode here will actually impact the next below + calls = append(calls, getFakeDestroyCommandWithExitCode(iptablesOldAndNewChains[0], 2)) // this ExitCode here will actually impact the next below for _, chain := range iptablesOldAndNewChains[1:] { calls = append(calls, getFakeDestroyCommand(chain)) } @@ -355,7 +417,7 @@ func TestGetChainLineNumber(t *testing.T) { grepCommand, } pMgr := NewPolicyManager(common.NewMockIOShim(calls)) - lineNum, err := pMgr.getChainLineNumber(testChainName) + lineNum, err := pMgr.chainLineNumber(testChainName) require.Equal(t, 3, lineNum) require.NoError(t, err) @@ -368,7 +430,7 @@ func TestGetChainLineNumber(t *testing.T) { grepCommand, } pMgr = NewPolicyManager(common.NewMockIOShim(calls)) - lineNum, err = pMgr.getChainLineNumber(testChainName) + lineNum, err = pMgr.chainLineNumber(testChainName) require.Equal(t, 0, lineNum) require.NoError(t, err) } @@ -384,7 +446,7 @@ func TestGetPolicyChainNames(t *testing.T) { grepCommand, } pMgr := NewPolicyManager(common.NewMockIOShim(calls)) - chainNames, err := pMgr.getPolicyChainNames() + chainNames, err := pMgr.policyChainNames() expectedChainNames := []string{ "AZURE-NPM-INGRESS-123456", "AZURE-NPM-EGRESS-123456", @@ -401,7 +463,7 @@ func TestGetPolicyChainNames(t *testing.T) { grepCommand, } pMgr = NewPolicyManager(common.NewMockIOShim(calls)) - chainNames, err = pMgr.getPolicyChainNames() + chainNames, err = pMgr.policyChainNames() expectedChainNames = nil require.Equal(t, expectedChainNames, chainNames) require.NoError(t, err) @@ -411,8 +473,8 @@ func getFakeDestroyCommand(chain string) testutils.TestCmd { return testutils.TestCmd{Cmd: []string{"iptables", "-w", "60", "-X", chain}} } -func getFakeDestroyFailureCommand(chain string) testutils.TestCmd { +func getFakeDestroyCommandWithExitCode(chain string, exitCode int) testutils.TestCmd { command := getFakeDestroyCommand(chain) - command.ExitCode = 1 + command.ExitCode = exitCode return command } diff --git a/npm/pkg/dataplane/policies/policy_linux.go b/npm/pkg/dataplane/policies/policy_linux.go new file mode 100644 index 0000000000..9683a65642 --- /dev/null +++ b/npm/pkg/dataplane/policies/policy_linux.go @@ -0,0 +1,27 @@ +package policies + +import "github.com/Azure/azure-container-networking/npm/util" + +// returns two booleans indicating whether the network policy has ingress and egress respectively +func (networkPolicy *NPMNetworkPolicy) hasIngressAndEgress() (hasIngress, hasEgress bool) { + hasIngress = false + hasEgress = false + for _, aclPolicy := range networkPolicy.ACLs { + hasIngress = hasIngress || aclPolicy.hasIngress() + hasEgress = hasEgress || aclPolicy.hasEgress() + } + return +} + +func (networkPolicy *NPMNetworkPolicy) egressChainName() string { + return networkPolicy.chainName(util.IptablesAzureEgressPolicyChainPrefix) +} + +func (networkPolicy *NPMNetworkPolicy) ingressChainName() string { + return networkPolicy.chainName(util.IptablesAzureIngressPolicyChainPrefix) +} + +func (networkPolicy *NPMNetworkPolicy) chainName(prefix string) string { + policyHash := util.Hash(networkPolicy.Name) // assuming the name is unique + return joinWithDash(prefix, policyHash) +} diff --git a/npm/pkg/dataplane/policies/policymanager.go b/npm/pkg/dataplane/policies/policymanager.go index 87ad8c38db..e6cc1fc462 100644 --- a/npm/pkg/dataplane/policies/policymanager.go +++ b/npm/pkg/dataplane/policies/policymanager.go @@ -2,19 +2,25 @@ package policies import ( "fmt" + "sync" + "time" "github.com/Azure/azure-container-networking/common" npmerrors "github.com/Azure/azure-container-networking/npm/util/errors" "k8s.io/klog" ) +const reconcileChainTimeInMinutes = 5 + type PolicyMap struct { cache map[string]*NPMNetworkPolicy } type PolicyManager struct { - policyMap *PolicyMap - ioShim *common.IOShim + policyMap *PolicyMap + ioShim *common.IOShim + staleChains *staleChains + sync.Mutex } func NewPolicyManager(ioShim *common.IOShim) *PolicyManager { @@ -22,7 +28,8 @@ func NewPolicyManager(ioShim *common.IOShim) *PolicyManager { policyMap: &PolicyMap{ cache: make(map[string]*NPMNetworkPolicy), }, - ioShim: ioShim, + ioShim: ioShim, + staleChains: newStaleChains(), } } @@ -40,6 +47,24 @@ func (pMgr *PolicyManager) Reset() error { return nil } +func (pMgr *PolicyManager) Reconcile(stopChannel <-chan struct{}) { + go func() { + ticker := time.NewTicker(time.Minute * time.Duration(reconcileChainTimeInMinutes)) + defer ticker.Stop() + + for { + select { + case <-stopChannel: + return + case <-ticker.C: + pMgr.Lock() + defer pMgr.Unlock() + pMgr.reconcile() + } + } + }() +} + func (pMgr *PolicyManager) PolicyExists(name string) bool { _, ok := pMgr.policyMap.cache[name] return ok @@ -87,6 +112,12 @@ func (pMgr *PolicyManager) RemovePolicy(name string, endpointList map[string]str } delete(pMgr.policyMap.cache, name) + if len(pMgr.policyMap.cache) == 0 { + klog.Infof("rebooting policy manager since there are no policies remaining in the cache") + if err := pMgr.reboot(); err != nil { + klog.Errorf("failed to reboot when there were no policies remaining") + } + } return nil } diff --git a/npm/pkg/dataplane/policies/policymanager_linux.go b/npm/pkg/dataplane/policies/policymanager_linux.go index 0313b4bac8..a40a3ba59e 100644 --- a/npm/pkg/dataplane/policies/policymanager_linux.go +++ b/npm/pkg/dataplane/policies/policymanager_linux.go @@ -21,11 +21,15 @@ const ( // shouldn't call this if the np has no ACLs (check in generic) func (pMgr *PolicyManager) addPolicy(networkPolicy *NPMNetworkPolicy, _ map[string]string) error { // TODO check for newPolicy errors - creator := pMgr.getCreatorForNewNetworkPolicies(networkPolicy) + allChainNames := allChainNames([]*NPMNetworkPolicy{networkPolicy}) + creator := pMgr.creatorForNewNetworkPolicies(allChainNames, networkPolicy) err := restore(creator) if err != nil { return npmerrors.SimpleErrorWrapper("failed to restore iptables with updated policies", err) } + for _, chain := range allChainNames { + pMgr.staleChains.remove(chain) + } return nil } @@ -34,11 +38,15 @@ func (pMgr *PolicyManager) removePolicy(networkPolicy *NPMNetworkPolicy, _ map[s if deleteErr != nil { return npmerrors.SimpleErrorWrapper("failed to delete jumps to policy chains", deleteErr) } - creator := pMgr.getCreatorForRemovingPolicies(networkPolicy) + allChainNames := allChainNames([]*NPMNetworkPolicy{networkPolicy}) + creator := pMgr.creatorForRemovingPolicies(allChainNames) restoreErr := restore(creator) if restoreErr != nil { return npmerrors.SimpleErrorWrapper("failed to flush policies", restoreErr) } + for _, chain := range allChainNames { + pMgr.staleChains.add(chain) + } return nil } @@ -50,54 +58,30 @@ func restore(creator *ioutil.FileCreator) error { return nil } -func (pMgr *PolicyManager) getCreatorForRemovingPolicies(networkPolicies ...*NPMNetworkPolicy) *ioutil.FileCreator { - allChainNames := getAllChainNames(networkPolicies) - creator := pMgr.getNewCreatorWithChains(allChainNames) +// TODO use array instead of ... +func (pMgr *PolicyManager) creatorForRemovingPolicies(allChainNames []string) *ioutil.FileCreator { + creator := pMgr.newCreatorWithChains(allChainNames) creator.AddLine("", nil, util.IptablesRestoreCommit) return creator } // returns all chain names (ingress and egress policy chain names) -func getAllChainNames(networkPolicies []*NPMNetworkPolicy) []string { +func allChainNames(networkPolicies []*NPMNetworkPolicy) []string { chainNames := make([]string, 0) for _, networkPolicy := range networkPolicies { hasIngress, hasEgress := networkPolicy.hasIngressAndEgress() if hasIngress { - chainNames = append(chainNames, networkPolicy.getIngressChainName()) + chainNames = append(chainNames, networkPolicy.ingressChainName()) } if hasEgress { - chainNames = append(chainNames, networkPolicy.getEgressChainName()) + chainNames = append(chainNames, networkPolicy.egressChainName()) } } return chainNames } -// returns two booleans indicating whether the network policy has ingress and egress respectively -func (networkPolicy *NPMNetworkPolicy) hasIngressAndEgress() (hasIngress, hasEgress bool) { - hasIngress = false - hasEgress = false - for _, aclPolicy := range networkPolicy.ACLs { - hasIngress = hasIngress || aclPolicy.hasIngress() - hasEgress = hasEgress || aclPolicy.hasEgress() - } - return -} - -func (networkPolicy *NPMNetworkPolicy) getEgressChainName() string { - return networkPolicy.getChainName(util.IptablesAzureEgressPolicyChainPrefix) -} - -func (networkPolicy *NPMNetworkPolicy) getIngressChainName() string { - return networkPolicy.getChainName(util.IptablesAzureIngressPolicyChainPrefix) -} - -func (networkPolicy *NPMNetworkPolicy) getChainName(prefix string) string { - policyHash := util.Hash(networkPolicy.Name) // assuming the name is unique - return joinWithDash(prefix, policyHash) -} - -func (pMgr *PolicyManager) getNewCreatorWithChains(chainNames []string) *ioutil.FileCreator { +func (pMgr *PolicyManager) newCreatorWithChains(chainNames []string) *ioutil.FileCreator { creator := ioutil.NewFileCreator(pMgr.ioShim, maxRetryCount, knownLineErrorPattern, unknownLineErrorPattern) // TODO pass an array instead of this ... thing creator.AddLine("", nil, "*"+util.IptablesFilterTable) // specify the table @@ -132,13 +116,13 @@ func (pMgr *PolicyManager) deleteJumpRule(policy *NPMNetworkPolicy, isIngress bo var baseChainName string var chainName string if isIngress { - specs = getIngressJumpSpecs(policy) + specs = ingressJumpSpecs(policy) baseChainName = util.IptablesAzureIngressChain - chainName = policy.getIngressChainName() + chainName = policy.ingressChainName() } else { - specs = getEgressJumpSpecs(policy) + specs = egressJumpSpecs(policy) baseChainName = util.IptablesAzureEgressChain - chainName = policy.getEgressChainName() + chainName = policy.egressChainName() } specs = append([]string{baseChainName}, specs...) @@ -152,22 +136,22 @@ func (pMgr *PolicyManager) deleteJumpRule(policy *NPMNetworkPolicy, isIngress bo return nil } -func getIngressJumpSpecs(networkPolicy *NPMNetworkPolicy) []string { - chainName := networkPolicy.getIngressChainName() +func ingressJumpSpecs(networkPolicy *NPMNetworkPolicy) []string { + chainName := networkPolicy.ingressChainName() specs := []string{util.IptablesJumpFlag, chainName} - return append(specs, getMatchSetSpecsForNetworkPolicy(networkPolicy, DstMatch)...) + return append(specs, matchSetSpecsForNetworkPolicy(networkPolicy, DstMatch)...) } -func getEgressJumpSpecs(networkPolicy *NPMNetworkPolicy) []string { - chainName := networkPolicy.getEgressChainName() +func egressJumpSpecs(networkPolicy *NPMNetworkPolicy) []string { + chainName := networkPolicy.egressChainName() specs := []string{util.IptablesJumpFlag, chainName} - return append(specs, getMatchSetSpecsForNetworkPolicy(networkPolicy, SrcMatch)...) + return append(specs, matchSetSpecsForNetworkPolicy(networkPolicy, SrcMatch)...) } // noflush add to chains impacted -func (pMgr *PolicyManager) getCreatorForNewNetworkPolicies(networkPolicies ...*NPMNetworkPolicy) *ioutil.FileCreator { - allChainNames := getAllChainNames(networkPolicies) - creator := pMgr.getNewCreatorWithChains(allChainNames) +// TODO use array instead of ... +func (pMgr *PolicyManager) creatorForNewNetworkPolicies(allChainNames []string, networkPolicies ...*NPMNetworkPolicy) *ioutil.FileCreator { + creator := pMgr.newCreatorWithChains(allChainNames) ingressJumpLineNumber := 1 egressJumpLineNumber := 1 @@ -177,12 +161,12 @@ func (pMgr *PolicyManager) getCreatorForNewNetworkPolicies(networkPolicies ...*N // add jump rule(s) to policy chain(s) hasIngress, hasEgress := networkPolicy.hasIngressAndEgress() if hasIngress { - ingressJumpSpecs := getInsertSpecs(util.IptablesAzureIngressChain, ingressJumpLineNumber, getIngressJumpSpecs(networkPolicy)) + ingressJumpSpecs := insertSpecs(util.IptablesAzureIngressChain, ingressJumpLineNumber, ingressJumpSpecs(networkPolicy)) creator.AddLine("", nil, ingressJumpSpecs...) // TODO error handler ingressJumpLineNumber++ } if hasEgress { - egressJumpSpecs := getInsertSpecs(util.IptablesAzureEgressChain, egressJumpLineNumber, getEgressJumpSpecs(networkPolicy)) + egressJumpSpecs := insertSpecs(util.IptablesAzureEgressChain, egressJumpLineNumber, egressJumpSpecs(networkPolicy)) creator.AddLine("", nil, egressJumpSpecs...) // TODO error handler egressJumpLineNumber++ } @@ -197,54 +181,47 @@ func writeNetworkPolicyRules(creator *ioutil.FileCreator, networkPolicy *NPMNetw var chainName string var actionSpecs []string if aclPolicy.hasIngress() { - chainName = networkPolicy.getIngressChainName() + chainName = networkPolicy.ingressChainName() if aclPolicy.Target == Allowed { actionSpecs = []string{util.IptablesJumpFlag, util.IptablesAzureEgressChain} } else { - actionSpecs = getSetMarkSpecs(util.IptablesAzureIngressDropMarkHex) + actionSpecs = setMarkSpecs(util.IptablesAzureIngressDropMarkHex) } } else { - chainName = networkPolicy.getEgressChainName() + chainName = networkPolicy.egressChainName() if aclPolicy.Target == Allowed { actionSpecs = []string{util.IptablesJumpFlag, util.IptablesAzureAcceptChain} } else { - actionSpecs = getSetMarkSpecs(util.IptablesAzureEgressDropMarkHex) + actionSpecs = setMarkSpecs(util.IptablesAzureEgressDropMarkHex) } } line := []string{"-A", chainName} line = append(line, actionSpecs...) - line = append(line, getIPTablesRuleSpecs(aclPolicy)...) + line = append(line, iptablesRuleSpecs(aclPolicy)...) creator.AddLine("", nil, line...) // TODO add error handler } } -func getIPTablesRuleSpecs(aclPolicy *ACLPolicy) []string { +func iptablesRuleSpecs(aclPolicy *ACLPolicy) []string { specs := make([]string, 0) specs = append(specs, util.IptablesProtFlag, string(aclPolicy.Protocol)) // NOTE: protocol must be ALL instead of nil - specs = append(specs, getPortSpecs([]Ports{aclPolicy.DstPorts})...) - specs = append(specs, getMatchSetSpecsFromSetInfo(aclPolicy.SrcList)...) - specs = append(specs, getMatchSetSpecsFromSetInfo(aclPolicy.DstList)...) + specs = append(specs, dstPortSpecs(aclPolicy.DstPorts)...) + specs = append(specs, matchSetSpecsFromSetInfo(aclPolicy.SrcList)...) + specs = append(specs, matchSetSpecsFromSetInfo(aclPolicy.DstList)...) if aclPolicy.Comment != "" { - specs = append(specs, getCommentSpecs(aclPolicy.Comment)...) + specs = append(specs, commentSpecs(aclPolicy.Comment)...) } return specs } -func getPortSpecs(portRanges []Ports) []string { - // TODO(jungukcho): do not need to take slices since it can only have one dst port - if len(portRanges) != 1 { +func dstPortSpecs(portRange Ports) []string { + if portRange.Port == 0 && portRange.EndPort == 0 { return []string{} } - - // TODO(jungukcho): temporary solution and need to fix it. - if portRanges[0].Port == 0 && portRanges[0].EndPort == 0 { - return []string{} - } - - return []string{util.IptablesDstPortFlag, portRanges[0].toIPTablesString()} + return []string{util.IptablesDstPortFlag, portRange.toIPTablesString()} } -func getMatchSetSpecsForNetworkPolicy(networkPolicy *NPMNetworkPolicy, matchType MatchType) []string { +func matchSetSpecsForNetworkPolicy(networkPolicy *NPMNetworkPolicy, matchType MatchType) []string { // TODO update to use included boolean/new data structure from Junguk's PR specs := make([]string, 0, maxLengthForMatchSetSpecs*len(networkPolicy.PodSelectorIPSets)) for _, translatedIPSet := range networkPolicy.PodSelectorIPSets { @@ -255,7 +232,7 @@ func getMatchSetSpecsForNetworkPolicy(networkPolicy *NPMNetworkPolicy, matchType return specs } -func getMatchSetSpecsFromSetInfo(setInfoList []SetInfo) []string { +func matchSetSpecsFromSetInfo(setInfoList []SetInfo) []string { specs := make([]string, 0, maxLengthForMatchSetSpecs*len(setInfoList)) for _, setInfo := range setInfoList { matchString := setInfo.MatchType.toIPTablesString() @@ -269,7 +246,7 @@ func getMatchSetSpecsFromSetInfo(setInfoList []SetInfo) []string { return specs } -func getSetMarkSpecs(mark string) []string { +func setMarkSpecs(mark string) []string { return []string{ util.IptablesJumpFlag, util.IptablesMark, @@ -278,7 +255,7 @@ func getSetMarkSpecs(mark string) []string { } } -func getCommentSpecs(comment string) []string { +func commentSpecs(comment string) []string { return []string{ util.IptablesModuleFlag, util.IptablesCommentModuleFlag, @@ -287,7 +264,7 @@ func getCommentSpecs(comment string) []string { } } -func getInsertSpecs(chainName string, index int, specs []string) []string { +func insertSpecs(chainName string, index int, specs []string) []string { indexString := fmt.Sprint(index) insertSpecs := []string{util.IptablesInsertionFlag, chainName, indexString} return append(insertSpecs, specs...) diff --git a/npm/pkg/dataplane/policies/policymanager_linux_test.go b/npm/pkg/dataplane/policies/policymanager_linux_test.go index da5ed289ef..c6d2722cff 100644 --- a/npm/pkg/dataplane/policies/policymanager_linux_test.go +++ b/npm/pkg/dataplane/policies/policymanager_linux_test.go @@ -8,15 +8,16 @@ import ( "github.com/Azure/azure-container-networking/common" "github.com/Azure/azure-container-networking/npm/pkg/dataplane/ipsets" dptestutils "github.com/Azure/azure-container-networking/npm/pkg/dataplane/testutils" + "github.com/Azure/azure-container-networking/npm/util" testutils "github.com/Azure/azure-container-networking/test/utils" "github.com/stretchr/testify/require" ) var ( - testPolicy1IngressChain = TestNetworkPolicies[0].getIngressChainName() - testPolicy1EgressChain = TestNetworkPolicies[0].getEgressChainName() - testPolicy2IngressChain = TestNetworkPolicies[1].getIngressChainName() - testPolicy3EgressChain = TestNetworkPolicies[2].getEgressChainName() + testPolicy1IngressChain = TestNetworkPolicies[0].ingressChainName() + testPolicy1EgressChain = TestNetworkPolicies[0].egressChainName() + testPolicy2IngressChain = TestNetworkPolicies[1].ingressChainName() + testPolicy3EgressChain = TestNetworkPolicies[2].egressChainName() testPolicy1IngressJump = fmt.Sprintf("-j %s -m set --match-set %s dst", testPolicy1IngressChain, ipsets.TestKVNSList.HashedName) testPolicy1EgressJump = fmt.Sprintf("-j %s -m set --match-set %s src", testPolicy1EgressChain, ipsets.TestKVNSList.HashedName) @@ -33,10 +34,17 @@ var ( testACLRule4 = fmt.Sprintf("-j AZURE-NPM-ACCEPT -p all -m set --match-set %s src -m comment --comment comment4", ipsets.TestCIDRSet.HashedName) ) +func TestChainNames(t *testing.T) { + expectedName := fmt.Sprintf("AZURE-NPM-INGRESS-%s", util.Hash(TestNetworkPolicies[0].Name)) + require.Equal(t, expectedName, TestNetworkPolicies[0].ingressChainName()) + expectedName = fmt.Sprintf("AZURE-NPM-EGRESS-%s", util.Hash(TestNetworkPolicies[0].Name)) + require.Equal(t, expectedName, TestNetworkPolicies[0].egressChainName()) +} + func TestAddPolicies(t *testing.T) { calls := []testutils.TestCmd{fakeIPTablesRestoreCommand} pMgr := NewPolicyManager(common.NewMockIOShim(calls)) - creator := pMgr.getCreatorForNewNetworkPolicies(TestNetworkPolicies...) + creator := pMgr.creatorForNewNetworkPolicies(allChainNames(TestNetworkPolicies), TestNetworkPolicies...) actualLines := strings.Split(creator.ToString(), "\n") expectedLines := []string{ "*filter", @@ -81,7 +89,7 @@ func TestRemovePolicies(t *testing.T) { fakeIPTablesRestoreCommand, } pMgr := NewPolicyManager(common.NewMockIOShim(calls)) - creator := pMgr.getCreatorForRemovingPolicies(TestNetworkPolicies...) + creator := pMgr.creatorForRemovingPolicies(allChainNames(TestNetworkPolicies)) actualLines := strings.Split(creator.ToString(), "\n") expectedLines := []string{ "*filter", @@ -113,7 +121,7 @@ func TestRemovePoliciesErrorOnRestore(t *testing.T) { require.Error(t, err) } -func TestRemovePoliciesErrorOnIngressRule(t *testing.T) { +func TestRemovePoliciesErrorOnDeleteForIngress(t *testing.T) { calls := []testutils.TestCmd{ fakeIPTablesRestoreCommand, getFakeDeleteJumpCommandWithCode("AZURE-NPM-INGRESS", testPolicy1IngressJump, 1), // anything but 0 or 2 @@ -125,7 +133,7 @@ func TestRemovePoliciesErrorOnIngressRule(t *testing.T) { require.Error(t, err) } -func TestRemovePoliciesErrorOnEgressRule(t *testing.T) { +func TestRemovePoliciesErrorOnDeleteForEgress(t *testing.T) { calls := []testutils.TestCmd{ fakeIPTablesRestoreCommand, getFakeDeleteJumpCommand("AZURE-NPM-INGRESS", testPolicy1IngressJump), @@ -137,3 +145,34 @@ func TestRemovePoliciesErrorOnEgressRule(t *testing.T) { err = pMgr.RemovePolicy(TestNetworkPolicies[0].Name, nil) require.Error(t, err) } + +func TestUpdatingChainsToCleanup(t *testing.T) { + calls := GetAddPolicyTestCalls(TestNetworkPolicies[0]) + calls = append(calls, GetRemovePolicyTestCalls(TestNetworkPolicies[0])...) + calls = append(calls, GetAddPolicyTestCalls(TestNetworkPolicies[1])...) + calls = append(calls, GetRemovePolicyFailureTestCalls(TestNetworkPolicies[1])...) + calls = append(calls, GetAddPolicyTestCalls(TestNetworkPolicies[2])...) + calls = append(calls, GetRemovePolicyTestCalls(TestNetworkPolicies[2])...) + calls = append(calls, GetAddPolicyFailureTestCalls(TestNetworkPolicies[2])...) + calls = append(calls, GetAddPolicyTestCalls(TestNetworkPolicies[0])...) + ioshim := common.NewMockIOShim(calls) + // TODO defer ioshim.VerifyCalls(t, ioshim, calls) + pMgr := NewPolicyManager(ioshim) + + require.NoError(t, pMgr.AddPolicy(TestNetworkPolicies[0], nil)) + assertStaleChainsContain(t, pMgr.staleChains) + require.NoError(t, pMgr.RemovePolicy(TestNetworkPolicies[0].Name, nil)) + assertStaleChainsContain(t, pMgr.staleChains, testPolicy1IngressChain, testPolicy1EgressChain) + + // TODO uncomment when grep stuff is fixed + // require.NoError(t, pMgr.AddPolicy(TestNetworkPolicies[1], nil)) + // assertStaleChainsContain(t, pMgr.staleChains, testPolicy1IngressChain, testPolicy1EgressChain) + // require.Error(t, pMgr.RemovePolicy(TestNetworkPolicies[1].Name, nil)) + // assertStaleChainsContain(t, pMgr.staleChains, testPolicy1IngressChain, testPolicy1EgressChain) + // require.NoError(t, pMgr.AddPolicy(TestNetworkPolicies[2], nil)) + // assertStaleChainsContain(t, pMgr.staleChains, testPolicy1IngressChain, testPolicy1EgressChain) + // require.Error(t, pMgr.RemovePolicy(TestNetworkPolicies[2].Name, nil)) + // assertStaleChainsContain(t, pMgr.staleChains, testPolicy1IngressChain, testPolicy1EgressChain, testPolicy3EgressChain) + // require.NoError(t, pMgr.AddPolicy(TestNetworkPolicies[0], nil)) + // assertStaleChainsContain(t, pMgr.staleChains, testPolicy3EgressChain) +} diff --git a/npm/pkg/dataplane/policies/policymanager_windows.go b/npm/pkg/dataplane/policies/policymanager_windows.go index 6a56e2d6e3..acbcb55267 100644 --- a/npm/pkg/dataplane/policies/policymanager_windows.go +++ b/npm/pkg/dataplane/policies/policymanager_windows.go @@ -15,11 +15,22 @@ var ( ErrFailedUnMarshalACLSettings = errors.New("Failed to unmarshal ACL settings") ) +type staleChains struct{} // unused in Windows + type endpointPolicyBuilder struct { aclPolicies []*NPMACLPolSettings otherPolicies []hcn.EndpointPolicy } +func newStaleChains() *staleChains { + return &staleChains{} +} + +func (pMgr *PolicyManager) reboot() error { + // TODO should we something here? + return nil +} + func (pMgr *PolicyManager) initialize() error { // TODO return nil @@ -30,6 +41,10 @@ func (pMgr *PolicyManager) reset() error { return nil } +func (pMgr *PolicyManager) reconcile() { + // TODO +} + func (pMgr *PolicyManager) addPolicy(policy *NPMNetworkPolicy, endpointList map[string]string) error { klog.Infof("[DataPlane Windows] adding policy %s on %+v", policy.Name, endpointList) if endpointList == nil { diff --git a/npm/pkg/dataplane/policies/testutils_linux.go b/npm/pkg/dataplane/policies/testutils_linux.go index 0710a74d39..2220aba949 100644 --- a/npm/pkg/dataplane/policies/testutils_linux.go +++ b/npm/pkg/dataplane/policies/testutils_linux.go @@ -19,17 +19,21 @@ func GetAddPolicyTestCalls(_ *NPMNetworkPolicy) []testutils.TestCmd { return []testutils.TestCmd{fakeIPTablesRestoreCommand} } +func GetAddPolicyFailureTestCalls(_ *NPMNetworkPolicy) []testutils.TestCmd { + return []testutils.TestCmd{fakeIPTablesRestoreFailureCommand} +} + func GetRemovePolicyTestCalls(policy *NPMNetworkPolicy) []testutils.TestCmd { calls := []testutils.TestCmd{} hasIngress, hasEgress := policy.hasIngressAndEgress() if hasIngress { deleteIngressJumpSpecs := []string{"iptables", "-w", "60", "-D", util.IptablesAzureIngressChain} - deleteIngressJumpSpecs = append(deleteIngressJumpSpecs, getIngressJumpSpecs(policy)...) + deleteIngressJumpSpecs = append(deleteIngressJumpSpecs, ingressJumpSpecs(policy)...) calls = append(calls, testutils.TestCmd{Cmd: deleteIngressJumpSpecs}) } if hasEgress { deleteEgressJumpSpecs := []string{"iptables", "-w", "60", "-D", util.IptablesAzureEgressChain} - deleteEgressJumpSpecs = append(deleteEgressJumpSpecs, getEgressJumpSpecs(policy)...) + deleteEgressJumpSpecs = append(deleteEgressJumpSpecs, egressJumpSpecs(policy)...) calls = append(calls, testutils.TestCmd{Cmd: deleteEgressJumpSpecs}) } @@ -37,6 +41,13 @@ func GetRemovePolicyTestCalls(policy *NPMNetworkPolicy) []testutils.TestCmd { return calls } +// GetRemovePolicyFailureTestCalls fails on the restore +func GetRemovePolicyFailureTestCalls(policy *NPMNetworkPolicy) []testutils.TestCmd { + calls := GetRemovePolicyTestCalls(policy) + calls[len(calls)-1] = fakeIPTablesRestoreFailureCommand // replace the restore success with a failure + return calls +} + func GetInitializeTestCalls() []testutils.TestCmd { return []testutils.TestCmd{ fakeIPTablesRestoreCommand, // gives correct exit code