diff --git a/cni/network/invoker_cns.go b/cni/network/invoker_cns.go index 7bf72779d4..f3e75690f7 100644 --- a/cni/network/invoker_cns.go +++ b/cni/network/invoker_cns.go @@ -227,25 +227,26 @@ func setHostOptions(ncSubnetPrefix *net.IPNet, options map[string]interface{}, i // we need to snat IMDS traffic to node IP, this sets up snat '--to' snatHostIPJump := fmt.Sprintf("%s --to %s", iptables.Snat, info.hostPrimaryIP) + iptablesClient := iptables.NewClient() var iptableCmds []iptables.IPTableEntry - if !iptables.ChainExists(iptables.V4, iptables.Nat, iptables.Swift) { - iptableCmds = append(iptableCmds, iptables.GetCreateChainCmd(iptables.V4, iptables.Nat, iptables.Swift)) + if !iptablesClient.ChainExists(iptables.V4, iptables.Nat, iptables.Swift) { + iptableCmds = append(iptableCmds, iptablesClient.GetCreateChainCmd(iptables.V4, iptables.Nat, iptables.Swift)) } - if !iptables.RuleExists(iptables.V4, iptables.Nat, iptables.Postrouting, "", iptables.Swift) { - iptableCmds = append(iptableCmds, iptables.GetAppendIptableRuleCmd(iptables.V4, iptables.Nat, iptables.Postrouting, "", iptables.Swift)) + if !iptablesClient.RuleExists(iptables.V4, iptables.Nat, iptables.Postrouting, "", iptables.Swift) { + iptableCmds = append(iptableCmds, iptablesClient.GetAppendIptableRuleCmd(iptables.V4, iptables.Nat, iptables.Postrouting, "", iptables.Swift)) } - if !iptables.RuleExists(iptables.V4, iptables.Nat, iptables.Swift, azureDNSUDPMatch, snatPrimaryIPJump) { - iptableCmds = append(iptableCmds, iptables.GetInsertIptableRuleCmd(iptables.V4, iptables.Nat, iptables.Swift, azureDNSUDPMatch, snatPrimaryIPJump)) + if !iptablesClient.RuleExists(iptables.V4, iptables.Nat, iptables.Swift, azureDNSUDPMatch, snatPrimaryIPJump) { + iptableCmds = append(iptableCmds, iptablesClient.GetInsertIptableRuleCmd(iptables.V4, iptables.Nat, iptables.Swift, azureDNSUDPMatch, snatPrimaryIPJump)) } - if !iptables.RuleExists(iptables.V4, iptables.Nat, iptables.Swift, azureDNSTCPMatch, snatPrimaryIPJump) { - iptableCmds = append(iptableCmds, iptables.GetInsertIptableRuleCmd(iptables.V4, iptables.Nat, iptables.Swift, azureDNSTCPMatch, snatPrimaryIPJump)) + if !iptablesClient.RuleExists(iptables.V4, iptables.Nat, iptables.Swift, azureDNSTCPMatch, snatPrimaryIPJump) { + iptableCmds = append(iptableCmds, iptablesClient.GetInsertIptableRuleCmd(iptables.V4, iptables.Nat, iptables.Swift, azureDNSTCPMatch, snatPrimaryIPJump)) } - if !iptables.RuleExists(iptables.V4, iptables.Nat, iptables.Swift, azureIMDSMatch, snatHostIPJump) { - iptableCmds = append(iptableCmds, iptables.GetInsertIptableRuleCmd(iptables.V4, iptables.Nat, iptables.Swift, azureIMDSMatch, snatHostIPJump)) + if !iptablesClient.RuleExists(iptables.V4, iptables.Nat, iptables.Swift, azureIMDSMatch, snatHostIPJump) { + iptableCmds = append(iptableCmds, iptablesClient.GetInsertIptableRuleCmd(iptables.V4, iptables.Nat, iptables.Swift, azureIMDSMatch, snatHostIPJump)) } options[network.IPTablesKey] = iptableCmds diff --git a/cni/network/network.go b/cni/network/network.go index ba603a2268..2381b8ffc4 100644 --- a/cni/network/network.go +++ b/cni/network/network.go @@ -117,7 +117,7 @@ func NewPlugin(name string, nl := netlink.NewNetlink() // Setup network manager. - nm, err := network.NewNetworkManager(nl, platform.NewExecClient(logger), &netio.NetIO{}, network.NewNamespaceClient()) + nm, err := network.NewNetworkManager(nl, platform.NewExecClient(logger), &netio.NetIO{}, network.NewNamespaceClient(), iptables.NewClient()) if err != nil { return nil, err } diff --git a/cnm/network/network.go b/cnm/network/network.go index faf6dd775f..4358d1a485 100644 --- a/cnm/network/network.go +++ b/cnm/network/network.go @@ -11,6 +11,7 @@ import ( "github.com/Azure/azure-container-networking/cnm" cnsclient "github.com/Azure/azure-container-networking/cns/client" "github.com/Azure/azure-container-networking/common" + "github.com/Azure/azure-container-networking/iptables" "github.com/Azure/azure-container-networking/log" "github.com/Azure/azure-container-networking/netio" "github.com/Azure/azure-container-networking/netlink" @@ -53,7 +54,7 @@ func NewPlugin(config *common.PluginConfig) (NetPlugin, error) { nl := netlink.NewNetlink() // Setup network manager. - nm, err := network.NewNetworkManager(nl, platform.NewExecClient(nil), &netio.NetIO{}, network.NewNamespaceClient()) + nm, err := network.NewNetworkManager(nl, platform.NewExecClient(nil), &netio.NetIO{}, network.NewNamespaceClient(), iptables.NewClient()) if err != nil { return nil, err } diff --git a/cnms/service/networkmonitor.go b/cnms/service/networkmonitor.go index 545442c8f2..96c790a667 100644 --- a/cnms/service/networkmonitor.go +++ b/cnms/service/networkmonitor.go @@ -10,6 +10,7 @@ import ( cnms "github.com/Azure/azure-container-networking/cnms/cnmspackage" acn "github.com/Azure/azure-container-networking/common" + "github.com/Azure/azure-container-networking/iptables" "github.com/Azure/azure-container-networking/log" "github.com/Azure/azure-container-networking/netio" "github.com/Azure/azure-container-networking/netlink" @@ -157,7 +158,7 @@ func main() { } nl := netlink.NewNetlink() - nm, err := network.NewNetworkManager(nl, platform.NewExecClient(nil), &netio.NetIO{}, network.NewNamespaceClient()) + nm, err := network.NewNetworkManager(nl, platform.NewExecClient(nil), &netio.NetIO{}, network.NewNamespaceClient(), iptables.NewClient()) if err != nil { log.Printf("[monitor] Failed while creating network manager") return diff --git a/iptables/iptables.go b/iptables/iptables.go index 2e80e30227..2d3fdf2e22 100644 --- a/iptables/iptables.go +++ b/iptables/iptables.go @@ -87,8 +87,14 @@ type IPTableEntry struct { Params string } +type Client struct{} + +func NewClient() *Client { + return &Client{} +} + // Run iptables command -func RunCmd(version, params string) error { +func (c *Client) RunCmd(version, params string) error { var cmd string p := platform.NewExecClient(logger) @@ -111,16 +117,16 @@ func RunCmd(version, params string) error { } // check if iptable chain alreay exists -func ChainExists(version, tableName, chainName string) bool { +func (c *Client) ChainExists(version, tableName, chainName string) bool { params := fmt.Sprintf("-t %s -L %s", tableName, chainName) - if err := RunCmd(version, params); err != nil { + if err := c.RunCmd(version, params); err != nil { return false } return true } -func GetCreateChainCmd(version, tableName, chainName string) IPTableEntry { +func (c *Client) GetCreateChainCmd(version, tableName, chainName string) IPTableEntry { return IPTableEntry{ Version: version, Params: fmt.Sprintf("-t %s -N %s", tableName, chainName), @@ -128,12 +134,12 @@ func GetCreateChainCmd(version, tableName, chainName string) IPTableEntry { } // create new iptable chain under specified table name -func CreateChain(version, tableName, chainName string) error { +func (c *Client) CreateChain(version, tableName, chainName string) error { var err error - if !ChainExists(version, tableName, chainName) { - cmd := GetCreateChainCmd(version, tableName, chainName) - err = RunCmd(version, cmd.Params) + if !c.ChainExists(version, tableName, chainName) { + cmd := c.GetCreateChainCmd(version, tableName, chainName) + err = c.RunCmd(version, cmd.Params) } else { logger.Info("Chain exists in table", zap.String("chainName", chainName), zap.String("tableName", tableName)) } @@ -142,15 +148,15 @@ func CreateChain(version, tableName, chainName string) error { } // check if iptable rule alreay exists -func RuleExists(version, tableName, chainName, match, target string) bool { +func (c *Client) RuleExists(version, tableName, chainName, match, target string) bool { params := fmt.Sprintf("-t %s -C %s %s -j %s", tableName, chainName, match, target) - if err := RunCmd(version, params); err != nil { + if err := c.RunCmd(version, params); err != nil { return false } return true } -func GetInsertIptableRuleCmd(version, tableName, chainName, match, target string) IPTableEntry { +func (c *Client) GetInsertIptableRuleCmd(version, tableName, chainName, match, target string) IPTableEntry { return IPTableEntry{ Version: version, Params: fmt.Sprintf("-t %s -I %s 1 %s -j %s", tableName, chainName, match, target), @@ -158,17 +164,17 @@ func GetInsertIptableRuleCmd(version, tableName, chainName, match, target string } // Insert iptable rule at beginning of iptable chain -func InsertIptableRule(version, tableName, chainName, match, target string) error { - if RuleExists(version, tableName, chainName, match, target) { +func (c *Client) InsertIptableRule(version, tableName, chainName, match, target string) error { + if c.RuleExists(version, tableName, chainName, match, target) { logger.Info("Rule already exists") return nil } - cmd := GetInsertIptableRuleCmd(version, tableName, chainName, match, target) - return RunCmd(version, cmd.Params) + cmd := c.GetInsertIptableRuleCmd(version, tableName, chainName, match, target) + return c.RunCmd(version, cmd.Params) } -func GetAppendIptableRuleCmd(version, tableName, chainName, match, target string) IPTableEntry { +func (c *Client) GetAppendIptableRuleCmd(version, tableName, chainName, match, target string) IPTableEntry { return IPTableEntry{ Version: version, Params: fmt.Sprintf("-t %s -A %s %s -j %s", tableName, chainName, match, target), @@ -176,18 +182,18 @@ func GetAppendIptableRuleCmd(version, tableName, chainName, match, target string } // Append iptable rule at end of iptable chain -func AppendIptableRule(version, tableName, chainName, match, target string) error { - if RuleExists(version, tableName, chainName, match, target) { +func (c *Client) AppendIptableRule(version, tableName, chainName, match, target string) error { + if c.RuleExists(version, tableName, chainName, match, target) { logger.Info("Rule already exists") return nil } - cmd := GetAppendIptableRuleCmd(version, tableName, chainName, match, target) - return RunCmd(version, cmd.Params) + cmd := c.GetAppendIptableRuleCmd(version, tableName, chainName, match, target) + return c.RunCmd(version, cmd.Params) } // Delete matched iptable rule -func DeleteIptableRule(version, tableName, chainName, match, target string) error { +func (c *Client) DeleteIptableRule(version, tableName, chainName, match, target string) error { params := fmt.Sprintf("-t %s -D %s %s -j %s", tableName, chainName, match, target) - return RunCmd(version, params) + return c.RunCmd(version, params) } diff --git a/network/endpoint.go b/network/endpoint.go index d8d20553c3..0adf857ea9 100644 --- a/network/endpoint.go +++ b/network/endpoint.go @@ -140,6 +140,7 @@ func (nw *network) newEndpoint( plc platform.ExecClient, netioCli netio.NetIOInterface, nsc NamespaceClientInterface, + iptc ipTablesClient, epInfo []*EndpointInfo, ) (*endpoint, error) { var ep *endpoint @@ -153,7 +154,7 @@ func (nw *network) newEndpoint( // Call the platform implementation. // Pass nil for epClient and will be initialized in newendpointImpl - ep, err = nw.newEndpointImpl(apipaCli, nl, plc, netioCli, nil, nsc, epInfo) + ep, err = nw.newEndpointImpl(apipaCli, nl, plc, netioCli, nil, nsc, iptc, epInfo) if err != nil { return nil, err } @@ -164,7 +165,9 @@ func (nw *network) newEndpoint( } // DeleteEndpoint deletes an existing endpoint from the network. -func (nw *network) deleteEndpoint(nl netlink.NetlinkInterface, plc platform.ExecClient, nioc netio.NetIOInterface, nsc NamespaceClientInterface, endpointID string) error { +func (nw *network) deleteEndpoint(nl netlink.NetlinkInterface, plc platform.ExecClient, nioc netio.NetIOInterface, nsc NamespaceClientInterface, + iptc ipTablesClient, endpointID string, +) error { var err error logger.Info("Deleting endpoint from network", zap.String("endpointID", endpointID), zap.String("id", nw.Id)) @@ -183,7 +186,7 @@ func (nw *network) deleteEndpoint(nl netlink.NetlinkInterface, plc platform.Exec // Call the platform implementation. // Pass nil for epClient and will be initialized in deleteEndpointImpl - err = nw.deleteEndpointImpl(nl, plc, nil, nioc, nsc, ep) + err = nw.deleteEndpointImpl(nl, plc, nil, nioc, nsc, iptc, ep) if err != nil { return err } diff --git a/network/endpoint_linux.go b/network/endpoint_linux.go index 60cdf61744..f603f389fb 100644 --- a/network/endpoint_linux.go +++ b/network/endpoint_linux.go @@ -56,6 +56,7 @@ func (nw *network) newEndpointImpl( netioCli netio.NetIOInterface, testEpClient EndpointClient, nsc NamespaceClientInterface, + iptc ipTablesClient, epInfo []*EndpointInfo, ) (*endpoint, error) { var ( @@ -134,7 +135,7 @@ func (nw *network) newEndpointImpl( if _, ok := epInfo.Data[SnatBridgeIPKey]; ok { nw.SnatBridgeIP = epInfo.Data[SnatBridgeIPKey].(string) } - epClient = NewTransparentVlanEndpointClient(nw, epInfo, hostIfName, contIfName, vlanid, localIP, nl, plc, nsc) + epClient = NewTransparentVlanEndpointClient(nw, epInfo, hostIfName, contIfName, vlanid, localIP, nl, plc, nsc, iptc) } else { logger.Info("OVS client") if _, ok := epInfo.Data[SnatBridgeIPKey]; ok { @@ -150,7 +151,8 @@ func (nw *network) newEndpointImpl( localIP, nl, ovsctl.NewOvsctl(), - plc) + plc, + iptc) } } else if nw.Mode != opModeTransparent { logger.Info("Bridge client") @@ -255,7 +257,9 @@ func (nw *network) newEndpointImpl( } // deleteEndpointImpl deletes an existing endpoint from the network. -func (nw *network) deleteEndpointImpl(nl netlink.NetlinkInterface, plc platform.ExecClient, epClient EndpointClient, nioc netio.NetIOInterface, nsc NamespaceClientInterface, ep *endpoint) error { +func (nw *network) deleteEndpointImpl(nl netlink.NetlinkInterface, plc platform.ExecClient, epClient EndpointClient, nioc netio.NetIOInterface, nsc NamespaceClientInterface, + iptc ipTablesClient, ep *endpoint, +) error { // Delete the veth pair by deleting one of the peer interfaces. // Deleting the host interface is more convenient since it does not require // entering the container netns and hence works both for CNI and CNM. @@ -267,10 +271,10 @@ func (nw *network) deleteEndpointImpl(nl netlink.NetlinkInterface, plc platform. epInfo := ep.getInfo() if nw.Mode == opModeTransparentVlan { logger.Info("Transparent vlan client") - epClient = NewTransparentVlanEndpointClient(nw, epInfo, ep.HostIfName, "", ep.VlanID, ep.LocalIP, nl, plc, nsc) + epClient = NewTransparentVlanEndpointClient(nw, epInfo, ep.HostIfName, "", ep.VlanID, ep.LocalIP, nl, plc, nsc, iptc) } else { - epClient = NewOVSEndpointClient(nw, epInfo, ep.HostIfName, "", ep.VlanID, ep.LocalIP, nl, ovsctl.NewOvsctl(), plc) + epClient = NewOVSEndpointClient(nw, epInfo, ep.HostIfName, "", ep.VlanID, ep.LocalIP, nl, ovsctl.NewOvsctl(), plc, iptc) } } else if nw.Mode != opModeTransparent { epClient = NewLinuxBridgeEndpointClient(nw.extIf, ep.HostIfName, "", nw.Mode, nl, plc) diff --git a/network/endpoint_snatroute_linux.go b/network/endpoint_snatroute_linux.go index 3f45a7f16e..0f8de4e1ac 100644 --- a/network/endpoint_snatroute_linux.go +++ b/network/endpoint_snatroute_linux.go @@ -4,7 +4,6 @@ import ( "fmt" "github.com/Azure/azure-container-networking/netlink" - "github.com/Azure/azure-container-networking/network/networkutils" "github.com/Azure/azure-container-networking/network/snat" "github.com/Azure/azure-container-networking/platform" "github.com/pkg/errors" @@ -36,8 +35,7 @@ func AddSnatEndpointRules(snatClient *snat.Client, hostToNC, ncToHost bool, nl n if err := snatClient.BlockIPAddressesOnSnatBridge(); err != nil { return errors.Wrap(err, "failed to block ip addresses on snat bridge") } - nuc := networkutils.NewNetworkUtils(nl, plc) - if err := nuc.EnableIPForwarding(); err != nil { + if err := snatClient.EnableIPForwarding(); err != nil { return errors.Wrap(err, "failed to enable ip forwarding") } diff --git a/network/endpoint_test.go b/network/endpoint_test.go index 669b16804a..0fc0cf829e 100644 --- a/network/endpoint_test.go +++ b/network/endpoint_test.go @@ -8,6 +8,7 @@ import ( "testing" "github.com/Azure/azure-container-networking/cns" + "github.com/Azure/azure-container-networking/iptables" "github.com/Azure/azure-container-networking/netio" "github.com/Azure/azure-container-networking/netlink" "github.com/Azure/azure-container-networking/platform" @@ -183,7 +184,7 @@ var _ = Describe("Test Endpoint", func() { It("Should be added", func() { // Add endpoint with valid id ep, err := nw.newEndpointImpl(nil, netlink.NewMockNetlink(false, ""), platform.NewMockExecClient(false), - netio.NewMockNetIO(false, 0), NewMockEndpointClient(nil), NewMockNamespaceClient(), []*EndpointInfo{epInfo}) + netio.NewMockNetIO(false, 0), NewMockEndpointClient(nil), NewMockNamespaceClient(), iptables.NewClient(), []*EndpointInfo{epInfo}) Expect(err).NotTo(HaveOccurred()) Expect(ep).NotTo(BeNil()) Expect(ep.Id).To(Equal(epInfo.Id)) @@ -195,7 +196,7 @@ var _ = Describe("Test Endpoint", func() { extIf: &externalInterface{IPv4Gateway: net.ParseIP("192.168.0.1")}, } ep, err := nw2.newEndpointImpl(nil, netlink.NewMockNetlink(false, ""), platform.NewMockExecClient(false), - netio.NewMockNetIO(false, 0), NewMockEndpointClient(nil), NewMockNamespaceClient(), []*EndpointInfo{epInfo}) + netio.NewMockNetIO(false, 0), NewMockEndpointClient(nil), NewMockNamespaceClient(), iptables.NewClient(), []*EndpointInfo{epInfo}) Expect(err).NotTo(HaveOccurred()) Expect(ep).NotTo(BeNil()) Expect(ep.Id).To(Equal(epInfo.Id)) @@ -211,7 +212,7 @@ var _ = Describe("Test Endpoint", func() { Expect(err).ToNot(HaveOccurred()) // Adding endpoint with same id should fail and delete should cleanup the state ep2, err := nw.newEndpointImpl(nil, netlink.NewMockNetlink(false, ""), platform.NewMockExecClient(false), - netio.NewMockNetIO(false, 0), mockCli, NewMockNamespaceClient(), []*EndpointInfo{epInfo}) + netio.NewMockNetIO(false, 0), mockCli, NewMockNamespaceClient(), iptables.NewClient(), []*EndpointInfo{epInfo}) Expect(err).To(HaveOccurred()) Expect(ep2).To(BeNil()) assert.Contains(GinkgoT(), err.Error(), "Endpoint already exists") @@ -221,17 +222,17 @@ var _ = Describe("Test Endpoint", func() { // Adding an endpoint with an id. mockCli := NewMockEndpointClient(nil) ep2, err := nw.newEndpointImpl(nil, netlink.NewMockNetlink(false, ""), platform.NewMockExecClient(false), - netio.NewMockNetIO(false, 0), mockCli, NewMockNamespaceClient(), []*EndpointInfo{epInfo}) + netio.NewMockNetIO(false, 0), mockCli, NewMockNamespaceClient(), iptables.NewClient(), []*EndpointInfo{epInfo}) Expect(err).ToNot(HaveOccurred()) Expect(ep2).ToNot(BeNil()) Expect(len(mockCli.endpoints)).To(Equal(1)) // Deleting the endpoint //nolint:errcheck // ignore error - nw.deleteEndpointImpl(netlink.NewMockNetlink(false, ""), platform.NewMockExecClient(false), mockCli, netio.NewMockNetIO(false, 0), NewMockNamespaceClient(), ep2) + nw.deleteEndpointImpl(netlink.NewMockNetlink(false, ""), platform.NewMockExecClient(false), mockCli, netio.NewMockNetIO(false, 0), NewMockNamespaceClient(), iptables.NewClient(), ep2) Expect(len(mockCli.endpoints)).To(Equal(0)) // Deleting same endpoint with same id should not fail //nolint:errcheck // ignore error - nw.deleteEndpointImpl(netlink.NewMockNetlink(false, ""), platform.NewMockExecClient(false), mockCli, netio.NewMockNetIO(false, 0), NewMockNamespaceClient(), ep2) + nw.deleteEndpointImpl(netlink.NewMockNetlink(false, ""), platform.NewMockExecClient(false), mockCli, netio.NewMockNetIO(false, 0), NewMockNamespaceClient(), iptables.NewClient(), ep2) Expect(len(mockCli.endpoints)).To(Equal(0)) }) }) @@ -252,11 +253,11 @@ var _ = Describe("Test Endpoint", func() { } return nil - }), NewMockNamespaceClient(), []*EndpointInfo{epInfo}) + }), NewMockNamespaceClient(), iptables.NewClient(), []*EndpointInfo{epInfo}) Expect(err).To(HaveOccurred()) Expect(ep).To(BeNil()) ep, err = nw.newEndpointImpl(nil, netlink.NewMockNetlink(false, ""), platform.NewMockExecClient(false), - netio.NewMockNetIO(false, 0), NewMockEndpointClient(nil), NewMockNamespaceClient(), []*EndpointInfo{epInfo}) + netio.NewMockNetIO(false, 0), NewMockEndpointClient(nil), NewMockNamespaceClient(), iptables.NewClient(), []*EndpointInfo{epInfo}) Expect(err).NotTo(HaveOccurred()) Expect(ep).NotTo(BeNil()) Expect(ep.Id).To(Equal(epInfo.Id)) @@ -282,14 +283,14 @@ var _ = Describe("Test Endpoint", func() { It("Should not endpoint to the network when there is an error", func() { secondaryEpInfo.MacAddress = netio.BadHwAddr // mock netlink will fail to set link state on bad eth ep, err := nw.newEndpointImpl(nil, netlink.NewMockNetlink(false, ""), platform.NewMockExecClient(false), - netio.NewMockNetIO(false, 0), nil, NewMockNamespaceClient(), []*EndpointInfo{epInfo, secondaryEpInfo}) + netio.NewMockNetIO(false, 0), nil, NewMockNamespaceClient(), iptables.NewClient(), []*EndpointInfo{epInfo, secondaryEpInfo}) Expect(err).To(HaveOccurred()) Expect(err.Error()).To(Equal("SecondaryEndpointClient Error: " + netlink.ErrorMockNetlink.Error())) Expect(ep).To(BeNil()) secondaryEpInfo.MacAddress = netio.HwAddr ep, err = nw.newEndpointImpl(nil, netlink.NewMockNetlink(false, ""), platform.NewMockExecClient(false), - netio.NewMockNetIO(false, 0), nil, NewMockNamespaceClient(), []*EndpointInfo{epInfo, secondaryEpInfo}) + netio.NewMockNetIO(false, 0), nil, NewMockNamespaceClient(), iptables.NewClient(), []*EndpointInfo{epInfo, secondaryEpInfo}) Expect(err).ToNot(HaveOccurred()) Expect(ep.Id).To(Equal(epInfo.Id)) }) @@ -297,7 +298,7 @@ var _ = Describe("Test Endpoint", func() { It("Should add endpoint when there are no errors", func() { secondaryEpInfo.MacAddress = netio.HwAddr ep, err := nw.newEndpointImpl(nil, netlink.NewMockNetlink(false, ""), platform.NewMockExecClient(false), - netio.NewMockNetIO(false, 0), nil, NewMockNamespaceClient(), []*EndpointInfo{epInfo, secondaryEpInfo}) + netio.NewMockNetIO(false, 0), nil, NewMockNamespaceClient(), iptables.NewClient(), []*EndpointInfo{epInfo, secondaryEpInfo}) Expect(err).ToNot(HaveOccurred()) Expect(ep.Id).To(Equal(epInfo.Id)) }) diff --git a/network/endpoint_windows.go b/network/endpoint_windows.go index 4c17db94fb..244ab6f234 100644 --- a/network/endpoint_windows.go +++ b/network/endpoint_windows.go @@ -71,6 +71,7 @@ func (nw *network) newEndpointImpl( _ netio.NetIOInterface, _ EndpointClient, _ NamespaceClientInterface, + _ ipTablesClient, epInfo []*EndpointInfo, ) (*endpoint, error) { // there is only 1 epInfo for windows, multiple interfaces will be added in the future @@ -409,7 +410,9 @@ func (nw *network) newEndpointImplHnsV2(cli apipaClient, epInfo *EndpointInfo) ( } // deleteEndpointImpl deletes an existing endpoint from the network. -func (nw *network) deleteEndpointImpl(_ netlink.NetlinkInterface, _ platform.ExecClient, _ EndpointClient, _ netio.NetIOInterface, _ NamespaceClientInterface, ep *endpoint) error { +func (nw *network) deleteEndpointImpl(_ netlink.NetlinkInterface, _ platform.ExecClient, _ EndpointClient, _ netio.NetIOInterface, _ NamespaceClientInterface, + _ ipTablesClient, ep *endpoint, +) error { if useHnsV2, err := UseHnsV2(ep.NetNs); useHnsV2 { if err != nil { return err diff --git a/network/iptables.go b/network/iptables.go new file mode 100644 index 0000000000..c572198453 --- /dev/null +++ b/network/iptables.go @@ -0,0 +1,9 @@ +package network + +type ipTablesClient interface { + InsertIptableRule(version, tableName, chainName, match, target string) error + AppendIptableRule(version, tableName, chainName, match, target string) error + DeleteIptableRule(version, tableName, chainName, match, target string) error + CreateChain(version, tableName, chainName string) error + RunCmd(version, params string) error +} diff --git a/network/manager.go b/network/manager.go index a36424ee05..65e45c2440 100644 --- a/network/manager.go +++ b/network/manager.go @@ -81,6 +81,7 @@ type networkManager struct { netio netio.NetIOInterface plClient platform.ExecClient nsClient NamespaceClientInterface + iptablesClient ipTablesClient sync.Mutex } @@ -113,13 +114,16 @@ type NetworkManager interface { } // Creates a new network manager. -func NewNetworkManager(nl netlink.NetlinkInterface, plc platform.ExecClient, netioCli netio.NetIOInterface, nsc NamespaceClientInterface) (NetworkManager, error) { +func NewNetworkManager(nl netlink.NetlinkInterface, plc platform.ExecClient, netioCli netio.NetIOInterface, nsc NamespaceClientInterface, + iptc ipTablesClient, +) (NetworkManager, error) { nm := &networkManager{ ExternalInterfaces: make(map[string]*externalInterface), netlink: nl, plClient: plc, netio: netioCli, nsClient: nsc, + iptablesClient: iptc, } return nm, nil @@ -386,7 +390,8 @@ func (nm *networkManager) CreateEndpoint(cli apipaClient, networkID string, epIn epInfo[0].Data[VlanIDKey] = nw.VlanId } } - ep, err := nw.newEndpoint(cli, nm.netlink, nm.plClient, nm.netio, nm.nsClient, epInfo) + + ep, err := nw.newEndpoint(cli, nm.netlink, nm.plClient, nm.netio, nm.nsClient, nm.iptablesClient, epInfo) if err != nil { return err } @@ -429,7 +434,7 @@ func (nm *networkManager) DeleteEndpoint(networkID, endpointID string, epInfo *E return err } - err = nw.deleteEndpoint(nm.netlink, nm.plClient, nm.netio, nm.nsClient, endpointID) + err = nw.deleteEndpoint(nm.netlink, nm.plClient, nm.netio, nm.nsClient, nm.iptablesClient, endpointID) if err != nil { return err } @@ -466,7 +471,7 @@ func (nm *networkManager) DeleteEndpointState(networkID string, epInfo *Endpoint NetworkContainerID: epInfo.Id, } logger.Info("Deleting endpoint with", zap.String("Endpoint Info: ", epInfo.PrettyString()), zap.String("HNISID : ", ep.HnsId)) - return nw.deleteEndpointImpl(netlink.NewNetlink(), platform.NewExecClient(logger), nil, nil, nil, ep) + return nw.deleteEndpointImpl(netlink.NewNetlink(), platform.NewExecClient(logger), nil, nil, nil, nil, ep) } // GetEndpointInfo returns information about the given endpoint. diff --git a/network/network_linux.go b/network/network_linux.go index 503bbf1a09..c66e6263c8 100644 --- a/network/network_linux.go +++ b/network/network_linux.go @@ -103,7 +103,7 @@ func (nm *networkManager) newNetworkImpl(nwInfo *NetworkInfo, extIf *externalInt } logger.Info("Disabled ipv6") // Blocks wireserver traffic from apipa nic - if err := networkutils.BlockEgressTrafficFromContainer(iptables.V4, networkutils.AzureDNS, iptables.TCP, iptables.HTTPPort); err != nil { + if err := nu.BlockEgressTrafficFromContainer(nm.iptablesClient, iptables.V4, networkutils.AzureDNS, iptables.TCP, iptables.HTTPPort); err != nil { return nil, errors.Wrap(err, "unable to insert vm iptables rule drop wireserver packets") } logger.Info("Block wireserver traffic rule added") @@ -611,7 +611,7 @@ func (nm *networkManager) connectExternalInterface(extIf *externalInterface, nwI // unmark packet if set by kube-proxy to skip kube-postrouting rule and processed // by cni snat rule - if err = iptables.InsertIptableRule(iptables.V6, iptables.Mangle, iptables.Postrouting, "", "MARK --set-mark 0x0"); err != nil { + if err = nm.iptablesClient.InsertIptableRule(iptables.V6, iptables.Mangle, iptables.Postrouting, "", "MARK --set-mark 0x0"); err != nil { logger.Error("Adding Iptable mangle rule failed", zap.Error(err)) return err } @@ -651,10 +651,10 @@ func (nm *networkManager) disconnectExternalInterface(extIf *externalInterface, logger.Info("Disconnected interface", zap.String("Name", extIf.Name)) } -func (*networkManager) addToIptables(cmds []iptables.IPTableEntry) error { +func (nm *networkManager) addToIptables(cmds []iptables.IPTableEntry) error { logger.Info("Adding additional iptable rules...") for _, cmd := range cmds { - err := iptables.RunCmd(cmd.Version, cmd.Params) + err := nm.iptablesClient.RunCmd(cmd.Version, cmd.Params) if err != nil { return err } @@ -684,7 +684,7 @@ func (nm *networkManager) addIpv6NatGateway(nwInfo *NetworkInfo) error { } // snat ipv6 traffic to secondary ipv6 ip before leaving VM -func (*networkManager) addIpv6SnatRule(extIf *externalInterface, nwInfo *NetworkInfo) error { +func (nm *networkManager) addIpv6SnatRule(extIf *externalInterface, nwInfo *NetworkInfo) error { var ( ipv6SnatRuleSet bool ipv6SubnetPrefix net.IPNet @@ -702,14 +702,16 @@ func (*networkManager) addIpv6SnatRule(extIf *externalInterface, nwInfo *Network } for _, ipAddr := range extIf.IPAddresses { - if ipAddr.IP.To4() == nil { - logger.Info("Adding ipv6 snat rule") - matchSrcPrefix := fmt.Sprintf("-s %s", ipv6SubnetPrefix.String()) - if err := networkutils.AddSnatRule(matchSrcPrefix, ipAddr.IP); err != nil { - return fmt.Errorf("Adding iptable snat rule failed:%w", err) - } - ipv6SnatRuleSet = true + if ipAddr.IP.To4() != nil { + continue + } + logger.Info("Adding ipv6 snat rule") + matchSrcPrefix := fmt.Sprintf("-s %s", ipv6SubnetPrefix.String()) + nu := networkutils.NewNetworkUtils(nm.netlink, nm.plClient) + if err := nu.AddSnatRule(nm.iptablesClient, matchSrcPrefix, ipAddr.IP); err != nil { + return fmt.Errorf("adding iptable snat rule failed:%w", err) } + ipv6SnatRuleSet = true } if !ipv6SnatRuleSet { diff --git a/network/networkutils/networkutils_linux.go b/network/networkutils/networkutils_linux.go index ad1c1a83f2..519da94f9c 100644 --- a/network/networkutils/networkutils_linux.go +++ b/network/networkutils/networkutils_linux.go @@ -31,7 +31,6 @@ RFC for Link Local Addresses: https://tools.ietf.org/html/rfc3927 */ const ( - enableIPForwardCmd = "sysctl -w net.ipv4.ip_forward=1" toggleIPV6Cmd = "sysctl -w net.ipv6.conf.all.disable_ipv6=%d" enableIPV6ForwardCmd = "sysctl -w net.ipv6.conf.all.forwarding=1" enableIPV4ForwardCmd = "sysctl -w net.ipv4.conf.all.forwarding=1" @@ -41,6 +40,12 @@ const ( var logger = log.CNILogger.With(zap.String("component", "net-utils")) +type ipTablesClient interface { + InsertIptableRule(version, tableName, chainName, match, target string) error + AppendIptableRule(version, tableName, chainName, match, target string) error + DeleteIptableRule(version, tableName, chainName, match, target string) error +} + var errorNetworkUtils = errors.New("NetworkUtils Error") func newErrorNetworkUtils(errStr string) error { @@ -130,7 +135,7 @@ func (nu NetworkUtils) AssignIPToInterface(interfaceName string, ipAddresses []n return nil } -func addOrDeleteFilterRule(bridgeName, action, ipAddress, chainName, target string) error { +func (nu NetworkUtils) addOrDeleteFilterRule(iptablesClient ipTablesClient, bridgeName, action, ipAddress, chainName, target string) error { var err error option := "i" @@ -142,32 +147,32 @@ func addOrDeleteFilterRule(bridgeName, action, ipAddress, chainName, target stri switch action { case iptables.Insert: - err = iptables.InsertIptableRule(iptables.V4, iptables.Filter, chainName, matchCondition, target) + err = iptablesClient.InsertIptableRule(iptables.V4, iptables.Filter, chainName, matchCondition, target) case iptables.Append: - err = iptables.AppendIptableRule(iptables.V4, iptables.Filter, chainName, matchCondition, target) + err = iptablesClient.AppendIptableRule(iptables.V4, iptables.Filter, chainName, matchCondition, target) case iptables.Delete: - err = iptables.DeleteIptableRule(iptables.V4, iptables.Filter, chainName, matchCondition, target) + err = iptablesClient.DeleteIptableRule(iptables.V4, iptables.Filter, chainName, matchCondition, target) } return err } -func AllowIPAddresses(bridgeName string, skipAddresses []string, action string) error { +func (nu NetworkUtils) AllowIPAddresses(iptablesClient ipTablesClient, bridgeName string, skipAddresses []string, action string) error { chains := getFilterChains() target := getFilterchainTarget() logger.Info("Addresses to allow", zap.Any("skipAddresses", skipAddresses)) for _, address := range skipAddresses { - if err := addOrDeleteFilterRule(bridgeName, action, address, chains[0], target[0]); err != nil { + if err := nu.addOrDeleteFilterRule(iptablesClient, bridgeName, action, address, chains[0], target[0]); err != nil { return err } - if err := addOrDeleteFilterRule(bridgeName, action, address, chains[1], target[0]); err != nil { + if err := nu.addOrDeleteFilterRule(iptablesClient, bridgeName, action, address, chains[1], target[0]); err != nil { return err } - if err := addOrDeleteFilterRule(bridgeName, action, address, chains[2], target[0]); err != nil { + if err := nu.addOrDeleteFilterRule(iptablesClient, bridgeName, action, address, chains[2], target[0]); err != nil { return err } @@ -176,13 +181,13 @@ func AllowIPAddresses(bridgeName string, skipAddresses []string, action string) return nil } -func BlockEgressTrafficFromContainer(version, ipAddress, protocol string, port int) error { +func (nu NetworkUtils) BlockEgressTrafficFromContainer(iptablesClient ipTablesClient, version, ipAddress, protocol string, port int) error { // iptables -t filter -I FORWARD -j DROP -d -p -m --dport dropTraffic := fmt.Sprintf("-d %s -p %s -m %s --dport %d", ipAddress, protocol, protocol, port) - return errors.Wrap(iptables.InsertIptableRule(version, iptables.Filter, iptables.Forward, dropTraffic, iptables.Drop), "iptables block traffic failed") + return errors.Wrap(iptablesClient.InsertIptableRule(version, iptables.Filter, iptables.Forward, dropTraffic, iptables.Drop), "iptables block traffic failed") } -func BlockIPAddresses(bridgeName, action string) error { +func (nu NetworkUtils) BlockIPAddresses(iptablesClient ipTablesClient, bridgeName, action string) error { privateIPAddresses := getPrivateIPSpace() chains := getFilterChains() target := getFilterchainTarget() @@ -190,15 +195,15 @@ func BlockIPAddresses(bridgeName, action string) error { logger.Info("Addresses to block", zap.Any("privateIPAddresses", privateIPAddresses)) for _, ipAddress := range privateIPAddresses { - if err := addOrDeleteFilterRule(bridgeName, action, ipAddress, chains[0], target[1]); err != nil { + if err := nu.addOrDeleteFilterRule(iptablesClient, bridgeName, action, ipAddress, chains[0], target[1]); err != nil { return err } - if err := addOrDeleteFilterRule(bridgeName, action, ipAddress, chains[1], target[1]); err != nil { + if err := nu.addOrDeleteFilterRule(iptablesClient, bridgeName, action, ipAddress, chains[1], target[1]); err != nil { return err } - if err := addOrDeleteFilterRule(bridgeName, action, ipAddress, chains[2], target[1]); err != nil { + if err := nu.addOrDeleteFilterRule(iptablesClient, bridgeName, action, ipAddress, chains[2], target[1]); err != nil { return err } } @@ -206,27 +211,6 @@ func BlockIPAddresses(bridgeName, action string) error { return nil } -// This function enables ip forwarding in VM and allow forwarding packets from the interface -func (nu NetworkUtils) EnableIPForwarding() error { - // Enable ip forwading on linux vm. - // sysctl -w net.ipv4.ip_forward=1 - cmd := fmt.Sprint(enableIPForwardCmd) - _, err := nu.plClient.ExecuteCommand(cmd) - if err != nil { - logger.Error("Enable ipforwarding failed with", zap.Error(err)) - return err - } - - // Append a rule in forward chain to allow forwarding from bridge - if err := iptables.AppendIptableRule(iptables.V4, iptables.Filter, iptables.Forward, "", iptables.Accept); err != nil { - logger.Error("Appending forward chain rule: allow traffic coming from snatbridge failed with", - zap.Error(err)) - return err - } - - return nil -} - func (nu NetworkUtils) EnableIPV4Forwarding() error { _, err := nu.plClient.ExecuteCommand(enableIPV4ForwardCmd) if err != nil { @@ -260,15 +244,15 @@ func (nu NetworkUtils) UpdateIPV6Setting(disable int) error { return err } -// This fucntion adds rule which snat to ip passed filtered by match string. -func AddSnatRule(match string, ip net.IP) error { +// This function adds rule which snat to ip passed filtered by match string. +func (nu NetworkUtils) AddSnatRule(iptablesClient ipTablesClient, match string, ip net.IP) error { version := iptables.V4 if ip.To4() == nil { version = iptables.V6 } target := fmt.Sprintf("SNAT --to %s", ip.String()) - return iptables.InsertIptableRule(version, iptables.Nat, iptables.Postrouting, match, target) + return errors.Wrap(iptablesClient.InsertIptableRule(version, iptables.Nat, iptables.Postrouting, match, target), "failed to add snat rule") } func (nu NetworkUtils) DisableRAForInterface(ifName string) error { diff --git a/network/ovs_endpoint_snatroute_linux.go b/network/ovs_endpoint_snatroute_linux.go index 7df00ffb21..3e10db761c 100644 --- a/network/ovs_endpoint_snatroute_linux.go +++ b/network/ovs_endpoint_snatroute_linux.go @@ -32,6 +32,7 @@ func (client *OVSEndpointClient) NewSnatClient(snatBridgeIP, localIP string, epI false, client.netlink, client.plClient, + client.iptablesClient, ) } } diff --git a/network/ovs_endpointclient_linux.go b/network/ovs_endpointclient_linux.go index 428e533ab6..569232eb9a 100644 --- a/network/ovs_endpointclient_linux.go +++ b/network/ovs_endpointclient_linux.go @@ -35,6 +35,7 @@ type OVSEndpointClient struct { netioshim netio.NetIOInterface ovsctlClient ovsctl.OvsInterface plClient platform.ExecClient + iptablesClient ipTablesClient } const ( @@ -52,6 +53,7 @@ func NewOVSEndpointClient( nl netlink.NetlinkInterface, ovs ovsctl.OvsInterface, plc platform.ExecClient, + iptc ipTablesClient, ) *OVSEndpointClient { client := &OVSEndpointClient{ bridgeName: nw.extIf.BridgeName, @@ -68,6 +70,7 @@ func NewOVSEndpointClient( netlink: nl, ovsctlClient: ovs, plClient: plc, + iptablesClient: iptc, netioshim: &netio.NetIO{}, } diff --git a/network/snat/snat_linux.go b/network/snat/snat_linux.go index e1968243e5..fb7f348b78 100644 --- a/network/snat/snat_linux.go +++ b/network/snat/snat_linux.go @@ -26,10 +26,18 @@ const ( vlanDropAddRule = "ebtables -t nat -A PREROUTING -p 802_1Q -j DROP" vlanDropMatch = "-p 802_1Q -j DROP" l2PreroutingEntries = "ebtables -t nat -L PREROUTING" + enableIPForwardCmd = "sysctl -w net.ipv4.ip_forward=1" ) var logger = log.CNILogger.With(zap.String("component", "net")) +type ipTablesClient interface { + InsertIptableRule(version, tableName, chainName, match, target string) error + AppendIptableRule(version, tableName, chainName, match, target string) error + DeleteIptableRule(version, tableName, chainName, match, target string) error + CreateChain(version, tableName, chainName string) error +} + var errorSnatClient = errors.New("SnatClient Error") func newErrorSnatClient(errStr string) error { @@ -46,6 +54,7 @@ type Client struct { enableProxyArpOnBridge bool netlink netlink.NetlinkInterface plClient platform.ExecClient + ipTablesClient ipTablesClient } func NewSnatClient(hostIfName string, @@ -57,6 +66,7 @@ func NewSnatClient(hostIfName string, enableProxyArpOnBridge bool, nl netlink.NetlinkInterface, plClient platform.ExecClient, + iptc ipTablesClient, ) Client { logger.Info("Initialize new snat client") snatClient := Client{ @@ -68,6 +78,7 @@ func NewSnatClient(hostIfName string, enableProxyArpOnBridge: enableProxyArpOnBridge, netlink: nl, plClient: plClient, + ipTablesClient: iptc, } snatClient.SkipAddressesFromBlock = append(snatClient.SkipAddressesFromBlock, skipAddressesFromBlock...) @@ -120,7 +131,8 @@ func (client *Client) CreateSnatEndpoint() error { // AllowIPAddressesOnSnatBridge adds iptables rules that allows only specific Private IPs via linux bridge func (client *Client) AllowIPAddressesOnSnatBridge() error { - if err := networkutils.AllowIPAddresses(SnatBridgeName, client.SkipAddressesFromBlock, iptables.Insert); err != nil { + nu := networkutils.NewNetworkUtils(client.netlink, client.plClient) + if err := nu.AllowIPAddresses(client.ipTablesClient, SnatBridgeName, client.SkipAddressesFromBlock, iptables.Insert); err != nil { logger.Error("AllowIPAddresses failed with", zap.Error(err)) return newErrorSnatClient(err.Error()) } @@ -130,7 +142,8 @@ func (client *Client) AllowIPAddressesOnSnatBridge() error { // BlockIPAddressesOnSnatBridge adds iptables rules that blocks all private IPs flowing via linux bridge func (client *Client) BlockIPAddressesOnSnatBridge() error { - if err := networkutils.BlockIPAddresses(SnatBridgeName, iptables.Append); err != nil { + nu := networkutils.NewNetworkUtils(client.netlink, client.plClient) + if err := nu.BlockIPAddresses(client.ipTablesClient, SnatBridgeName, iptables.Append); err != nil { logger.Error("AllowIPAddresses failed with", zap.Error(err)) return newErrorSnatClient(err.Error()) } @@ -172,40 +185,40 @@ func (client *Client) AllowInboundFromHostToNC() error { bridgeIP, containerIP := getNCLocalAndGatewayIP(client) // Create CNI Output chain - if err := iptables.CreateChain(iptables.V4, iptables.Filter, iptables.CNIOutputChain); err != nil { + if err := client.ipTablesClient.CreateChain(iptables.V4, iptables.Filter, iptables.CNIOutputChain); err != nil { logger.Error("AllowInboundFromHostToNC: Creating failed with", zap.Any("CNIOutputChain", iptables.CNIOutputChain), zap.Error(err)) return newErrorSnatClient(err.Error()) } // Forward traffic from Ouptut chain to CNI Output chain - if err := iptables.InsertIptableRule(iptables.V4, iptables.Filter, iptables.Output, "", iptables.CNIOutputChain); err != nil { + if err := client.ipTablesClient.InsertIptableRule(iptables.V4, iptables.Filter, iptables.Output, "", iptables.CNIOutputChain); err != nil { logger.Error("AllowInboundFromHostToNC: Creating failed with", zap.Any("CNIOutputChain", iptables.CNIOutputChain), zap.Error(err)) return newErrorSnatClient(err.Error()) } // Allow connection from Host to NC matchCondition := fmt.Sprintf("-s %s -d %s", bridgeIP.String(), containerIP.String()) - err := iptables.InsertIptableRule(iptables.V4, iptables.Filter, iptables.CNIOutputChain, matchCondition, iptables.Accept) + err := client.ipTablesClient.InsertIptableRule(iptables.V4, iptables.Filter, iptables.CNIOutputChain, matchCondition, iptables.Accept) if err != nil { logger.Error("AllowInboundFromHostToNC: Inserting output rule failed with ", zap.Error(err)) return newErrorSnatClient(err.Error()) } // Create cniinput chain - if err := iptables.CreateChain(iptables.V4, iptables.Filter, iptables.CNIInputChain); err != nil { + if err = client.ipTablesClient.CreateChain(iptables.V4, iptables.Filter, iptables.CNIInputChain); err != nil { logger.Error("AllowInboundFromHostToNC: Creating failed with", zap.Any("CNIOutputChain", iptables.CNIOutputChain), zap.Error(err)) return newErrorSnatClient(err.Error()) } // Forward from Input to cniinput chain - if err := iptables.InsertIptableRule(iptables.V4, iptables.Filter, iptables.Input, "", iptables.CNIInputChain); err != nil { + if err = client.ipTablesClient.InsertIptableRule(iptables.V4, iptables.Filter, iptables.Input, "", iptables.CNIInputChain); err != nil { logger.Error("AllowInboundFromHostToNC: Inserting forward rule to failed with", zap.Any("CNIOutputChain", iptables.CNIOutputChain), zap.Error(err)) return newErrorSnatClient(err.Error()) } // Accept packets from NC only if established connection matchCondition = fmt.Sprintf(" -i %s -m state --state %s,%s", SnatBridgeName, iptables.Established, iptables.Related) - err = iptables.InsertIptableRule(iptables.V4, iptables.Filter, iptables.CNIInputChain, matchCondition, iptables.Accept) + err = client.ipTablesClient.InsertIptableRule(iptables.V4, iptables.Filter, iptables.CNIInputChain, matchCondition, iptables.Accept) if err != nil { logger.Error("AllowInboundFromHostToNC: Inserting input rule failed with", zap.Error(err)) return newErrorSnatClient(err.Error()) @@ -237,7 +250,7 @@ func (client *Client) DeleteInboundFromHostToNC() error { // Delete allow connection from Host to NC matchCondition := fmt.Sprintf("-s %s -d %s", bridgeIP.String(), containerIP.String()) - err := iptables.DeleteIptableRule(iptables.V4, iptables.Filter, iptables.CNIOutputChain, matchCondition, iptables.Accept) + err := client.ipTablesClient.DeleteIptableRule(iptables.V4, iptables.Filter, iptables.CNIOutputChain, matchCondition, iptables.Accept) if err != nil { logger.Error("DeleteInboundFromHostToNC: Error removing output rule", zap.Error(err)) } @@ -264,14 +277,14 @@ func (client *Client) AllowInboundFromNCToHost() error { bridgeIP, containerIP := getNCLocalAndGatewayIP(client) // Create CNI Input chain - if err := iptables.CreateChain(iptables.V4, iptables.Filter, iptables.CNIInputChain); err != nil { + if err := client.ipTablesClient.CreateChain(iptables.V4, iptables.Filter, iptables.CNIInputChain); err != nil { logger.Error("AllowInboundFromHostToNC: Creating failed with", zap.String("CNIInputChain", iptables.CNIInputChain), zap.Error(err)) return err } // Forward traffic from Input to cniinput chain - if err := iptables.InsertIptableRule(iptables.V4, iptables.Filter, iptables.Input, "", iptables.CNIInputChain); err != nil { + if err := client.ipTablesClient.InsertIptableRule(iptables.V4, iptables.Filter, iptables.Input, "", iptables.CNIInputChain); err != nil { logger.Error("AllowInboundFromHostToNC: Inserting forward rule to failed with", zap.String("CNIInputChain", iptables.CNIInputChain), zap.Error(err)) return err @@ -279,21 +292,21 @@ func (client *Client) AllowInboundFromNCToHost() error { // Allow NC to Host connection matchCondition := fmt.Sprintf("-s %s -d %s", containerIP.String(), bridgeIP.String()) - err := iptables.InsertIptableRule(iptables.V4, iptables.Filter, iptables.CNIInputChain, matchCondition, iptables.Accept) + err := client.ipTablesClient.InsertIptableRule(iptables.V4, iptables.Filter, iptables.CNIInputChain, matchCondition, iptables.Accept) if err != nil { logger.Error("AllowInboundFromHostToNC: Inserting output rule failed with", zap.Error(err)) return err } // Create CNI output chain - if err := iptables.CreateChain(iptables.V4, iptables.Filter, iptables.CNIOutputChain); err != nil { + if err = client.ipTablesClient.CreateChain(iptables.V4, iptables.Filter, iptables.CNIOutputChain); err != nil { logger.Error("AllowInboundFromHostToNC: Creating failed with", zap.String("CNIInputChain", iptables.CNIInputChain), zap.Error(err)) return err } // Forward traffic from Output to CNI Output chain - if err := iptables.InsertIptableRule(iptables.V4, iptables.Filter, iptables.Output, "", iptables.CNIOutputChain); err != nil { + if err = client.ipTablesClient.InsertIptableRule(iptables.V4, iptables.Filter, iptables.Output, "", iptables.CNIOutputChain); err != nil { logger.Error("AllowInboundFromHostToNC: Inserting forward rule to failed with", zap.String("CNIInputChain", iptables.CNIInputChain), zap.Error(err)) return err @@ -301,7 +314,7 @@ func (client *Client) AllowInboundFromNCToHost() error { // Accept packets from Host only if established connection matchCondition = fmt.Sprintf(" -o %s -m state --state %s,%s", SnatBridgeName, iptables.Established, iptables.Related) - err = iptables.InsertIptableRule(iptables.V4, iptables.Filter, iptables.CNIOutputChain, matchCondition, iptables.Accept) + err = client.ipTablesClient.InsertIptableRule(iptables.V4, iptables.Filter, iptables.CNIOutputChain, matchCondition, iptables.Accept) if err != nil { logger.Error("AllowInboundFromHostToNC: Inserting input rule failed with", zap.Error(err)) return err @@ -331,7 +344,7 @@ func (client *Client) DeleteInboundFromNCToHost() error { // Delete allow NC to Host connection matchCondition := fmt.Sprintf("-s %s -d %s", containerIP.String(), bridgeIP.String()) - err := iptables.DeleteIptableRule(iptables.V4, iptables.Filter, iptables.CNIInputChain, matchCondition, iptables.Accept) + err := client.ipTablesClient.DeleteIptableRule(iptables.V4, iptables.Filter, iptables.CNIInputChain, matchCondition, iptables.Accept) if err != nil { logger.Error("DeleteInboundFromNCToHost: Error removing output rule", zap.Error(err)) } @@ -453,7 +466,8 @@ func (client *Client) createSnatBridge(snatBridgeIP, hostPrimaryMac string) erro func (client *Client) addMasqueradeRule(snatBridgeIPWithPrefix string) error { _, ipNet, _ := net.ParseCIDR(snatBridgeIPWithPrefix) matchCondition := fmt.Sprintf("-s %s", ipNet.String()) - return iptables.InsertIptableRule(iptables.V4, iptables.Nat, iptables.Postrouting, matchCondition, iptables.Masquerade) + return errors.Wrap(client.ipTablesClient.InsertIptableRule(iptables.V4, iptables.Nat, iptables.Postrouting, matchCondition, iptables.Masquerade), + "failed to add masquerade rule") } // Drop all vlan traffic on linux bridge @@ -474,3 +488,20 @@ func (client *Client) addVlanDropRule() error { _, err = client.plClient.ExecuteCommand(vlanDropAddRule) return err } + +// This function enables ip forwarding in VM and allow forwarding packets from the interface +func (client *Client) EnableIPForwarding() error { + // Enable ip forwading on linux vm. + // sysctl -w net.ipv4.ip_forward=1 + _, err := client.plClient.ExecuteCommand(enableIPForwardCmd) + if err != nil { + return errors.Wrap(err, "enable ipforwarding command failed") + } + + // Append a rule in forward chain to allow forwarding from bridge + if err := client.ipTablesClient.AppendIptableRule(iptables.V4, iptables.Filter, iptables.Forward, "", iptables.Accept); err != nil { + return errors.Wrap(err, "appending forward chain rule to allow traffic from snat bridge failed") + } + + return nil +} diff --git a/network/snat/snat_linux_test.go b/network/snat/snat_linux_test.go index 1838b6aab1..0ffee1ebf2 100644 --- a/network/snat/snat_linux_test.go +++ b/network/snat/snat_linux_test.go @@ -4,6 +4,7 @@ import ( "os" "testing" + "github.com/Azure/azure-container-networking/iptables" "github.com/Azure/azure-container-networking/netlink" ) @@ -19,11 +20,13 @@ func TestMain(m *testing.M) { func TestAllowInboundFromHostToNC(t *testing.T) { nl := netlink.NewNetlink() + iptc := iptables.NewClient() client := &Client{ SnatBridgeIP: "169.254.0.1/16", localIP: "169.254.0.4/16", containerSnatVethName: anyInterface, netlink: nl, + ipTablesClient: iptc, } if err := nl.AddLink(&netlink.DummyLink{ @@ -66,11 +69,13 @@ func TestAllowInboundFromHostToNC(t *testing.T) { func TestAllowInboundFromNCToHost(t *testing.T) { nl := netlink.NewNetlink() + iptc := iptables.NewClient() client := &Client{ SnatBridgeIP: "169.254.0.1/16", localIP: "169.254.0.4/16", containerSnatVethName: anyInterface, netlink: nl, + ipTablesClient: iptc, } if err := nl.AddLink(&netlink.DummyLink{ diff --git a/network/transparent_vlan_endpoint_snatroute_linux.go b/network/transparent_vlan_endpoint_snatroute_linux.go index 609ec137df..b6f712758a 100644 --- a/network/transparent_vlan_endpoint_snatroute_linux.go +++ b/network/transparent_vlan_endpoint_snatroute_linux.go @@ -20,6 +20,7 @@ func (client *TransparentVlanEndpointClient) NewSnatClient(snatBridgeIP, localIP true, client.netlink, client.plClient, + client.iptablesClient, ) } } diff --git a/network/transparent_vlan_endpointclient_linux.go b/network/transparent_vlan_endpointclient_linux.go index 2ec756ad0d..6fcf9719e9 100644 --- a/network/transparent_vlan_endpointclient_linux.go +++ b/network/transparent_vlan_endpointclient_linux.go @@ -66,6 +66,7 @@ type TransparentVlanEndpointClient struct { plClient platform.ExecClient netUtilsClient networkutils.NetworkUtils nsClient NamespaceClientInterface + iptablesClient ipTablesClient } func NewTransparentVlanEndpointClient( @@ -78,6 +79,7 @@ func NewTransparentVlanEndpointClient( nl netlink.NetlinkInterface, plc platform.ExecClient, nsc NamespaceClientInterface, + iptc ipTablesClient, ) *TransparentVlanEndpointClient { vlanVethName := fmt.Sprintf("%s_%d", nw.extIf.Name, vlanid) vnetNSName := fmt.Sprintf("az_ns_%d", vlanid) @@ -100,6 +102,7 @@ func NewTransparentVlanEndpointClient( plClient: plc, netUtilsClient: networkutils.NewNetworkUtils(nl, plc), nsClient: nsc, + iptablesClient: iptc, } client.NewSnatClient(nw.SnatBridgeIP, localIP, ep) @@ -396,16 +399,16 @@ func (client *TransparentVlanEndpointClient) AddEndpointRules(epInfo *EndpointIn func (client *TransparentVlanEndpointClient) AddVnetRules(epInfo *EndpointInfo) error { // iptables -t mangle -I PREROUTING -j MARK --set-mark markOption := fmt.Sprintf("MARK --set-mark %d", tunnelingMark) - if err := iptables.InsertIptableRule(iptables.V4, "mangle", "PREROUTING", "", markOption); err != nil { + if err := client.iptablesClient.InsertIptableRule(iptables.V4, "mangle", "PREROUTING", "", markOption); err != nil { return errors.Wrap(err, "unable to insert iptables rule mark all packets not entering on vlan interface") } // iptables -t mangle -I PREROUTING -j ACCEPT -i match := fmt.Sprintf("-i %s", client.vlanIfName) - if err := iptables.InsertIptableRule(iptables.V4, "mangle", "PREROUTING", match, "ACCEPT"); err != nil { + if err := client.iptablesClient.InsertIptableRule(iptables.V4, "mangle", "PREROUTING", match, "ACCEPT"); err != nil { return errors.Wrap(err, "unable to insert iptables rule accept all incoming from vlan interface") } // Blocks wireserver traffic from customer vnet nic - if err := networkutils.BlockEgressTrafficFromContainer(iptables.V4, networkutils.AzureDNS, iptables.TCP, iptables.HTTPPort); err != nil { + if err := client.netUtilsClient.BlockEgressTrafficFromContainer(client.iptablesClient, iptables.V4, networkutils.AzureDNS, iptables.TCP, iptables.HTTPPort); err != nil { return errors.Wrap(err, "unable to insert iptables rule to drop wireserver packets") }