diff --git a/cns/imdsclient/imdsclient.go b/cns/imdsclient/imdsclient.go index 918ac29d43..ccfd0fa2d3 100644 --- a/cns/imdsclient/imdsclient.go +++ b/cns/imdsclient/imdsclient.go @@ -1,17 +1,24 @@ -// Copyright 2017 Microsoft. All rights reserved. -// MIT License - package imdsclient import ( + "bytes" "encoding/json" "encoding/xml" "fmt" + "io" "math" + "net" "net/http" - "strings" "github.com/Azure/azure-container-networking/cns/logger" + "github.com/pkg/errors" +) + +var ( + // ErrNoPrimaryInterface indicates the imds respnose does not have a primary interface indicated. + ErrNoPrimaryInterface = errors.New("no primary interface found") + // ErrInsufficientAddressSpace indicates that the CIDR space is too small to include a gateway IP; it is 1 IP. + ErrInsufficientAddressSpace = errors.New("insufficient address space to generate gateway IP") ) // GetNetworkContainerInfoFromHost retrieves the programmed version of network container from Host. @@ -45,6 +52,7 @@ func (imdsClient *ImdsClient) GetNetworkContainerInfoFromHost(networkContainerID } // GetPrimaryInterfaceInfoFromHost retrieves subnet and gateway of primary NIC from Host. +// TODO(rbtr): this is not a good client contract, we should return the resp. func (imdsClient *ImdsClient) GetPrimaryInterfaceInfoFromHost() (*InterfaceInfo, error) { logger.Printf("[Azure CNS] GetPrimaryInterfaceInfoFromHost") @@ -53,64 +61,87 @@ func (imdsClient *ImdsClient) GetPrimaryInterfaceInfoFromHost() (*InterfaceInfo, if err != nil { return nil, err } - defer resp.Body.Close() + b, err := io.ReadAll(resp.Body) + if err != nil { + return nil, errors.Wrap(err, "failed to read response body") + } - logger.Printf("[Azure CNS] Response received from NMAgent for get interface details: %v", resp.Body) + logger.Printf("[Azure CNS] Response received from NMAgent for get interface details: %s", string(b)) var doc xmlDocument - decoder := xml.NewDecoder(resp.Body) - err = decoder.Decode(&doc) - if err != nil { - return nil, err + if err := xml.NewDecoder(bytes.NewReader(b)).Decode(&doc); err != nil { + return nil, errors.Wrap(err, "failed to decode response body") } foundPrimaryInterface := false // For each interface. for _, i := range doc.Interface { - // Find primary Interface. - if i.IsPrimary { - interfaceInfo.IsPrimary = true - - // Get the first subnet. - for _, s := range i.IPSubnet { - interfaceInfo.Subnet = s.Prefix - malformedSubnetError := fmt.Errorf("Malformed subnet received from host %s", s.Prefix) - - st := strings.Split(s.Prefix, "/") - if len(st) != 2 { - return nil, malformedSubnetError - } - - ip := strings.Split(st[0], ".") - if len(ip) != 4 { - return nil, malformedSubnetError - } - - interfaceInfo.Gateway = fmt.Sprintf("%s.%s.%s.1", ip[0], ip[1], ip[2]) - for _, ip := range s.IPAddress { - if ip.IsPrimary { - interfaceInfo.PrimaryIP = ip.Address - } + // skip if not primary + if !i.IsPrimary { + continue + } + interfaceInfo.IsPrimary = true + + // Get the first subnet. + for _, s := range i.IPSubnet { + interfaceInfo.Subnet = s.Prefix + gw, err := calculateGatewayIP(s.Prefix) + if err != nil { + return nil, err + } + interfaceInfo.Gateway = gw.String() + for _, ip := range s.IPAddress { + if ip.IsPrimary { + interfaceInfo.PrimaryIP = ip.Address } - - imdsClient.primaryInterface = interfaceInfo - break } - foundPrimaryInterface = true + imdsClient.primaryInterface = interfaceInfo break } + + foundPrimaryInterface = true + break } - var er error - er = nil if !foundPrimaryInterface { - er = fmt.Errorf("Unable to find primary NIC") + return nil, ErrNoPrimaryInterface } - return interfaceInfo, er + return interfaceInfo, nil +} + +// calculateGatewayIP parses the passed CIDR string and returns the first IP in the range. +func calculateGatewayIP(cidr string) (net.IP, error) { + _, subnet, err := net.ParseCIDR(cidr) + if err != nil { + return nil, errors.Wrap(err, "received malformed subnet from host") + } + + // check if we have enough address space to calculate a gateway IP + // we need at least 2 IPs (eg the IPv4 mask cannot be greater than 31) + // since the zeroth is reserved and the gateway is the first. + mask, bits := subnet.Mask.Size() + if mask == bits { + return nil, ErrInsufficientAddressSpace + } + + // the subnet IP is the zero base address, so we need to increment it by one to get the gateway. + gw := make([]byte, len(subnet.IP)) + copy(gw, subnet.IP) + for idx := len(gw) - 1; idx >= 0; idx-- { + gw[idx]++ + // net.IP is a binary byte array, check if we have overflowed and need to continue incrementing to the left + // along the arary or if we're done. + // it's like if we have a 9 in base 10, and add 1, it rolls over to 0 so we're not done - we need to move + // left and increment that digit also. + if gw[idx] != 0 { + break + } + } + return gw, nil } // GetPrimaryInterfaceInfoFromMemory retrieves subnet and gateway of primary NIC that is saved in memory. diff --git a/cns/imdsclient/imdsclient_test.go b/cns/imdsclient/imdsclient_test.go new file mode 100644 index 0000000000..71dca8fb13 --- /dev/null +++ b/cns/imdsclient/imdsclient_test.go @@ -0,0 +1,56 @@ +package imdsclient + +import ( + "net" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGenerateGatewayIP(t *testing.T) { + tests := []struct { + name string + cidr string + want net.IP + wantErr bool + }{ + { + name: "base case", + cidr: "10.0.0.0/8", + want: net.IPv4(10, 0, 0, 1), + }, + { + name: "nonzero start", + cidr: "10.177.233.128/27", + want: net.IPv4(10, 177, 233, 129), + }, + { + name: "invalid", + cidr: "test", + wantErr: true, + }, + { + name: "no available", + cidr: "255.255.255.255/32", + wantErr: true, + }, + { + name: "max IPv4", + cidr: "255.255.255.255/31", + want: net.IPv4(255, 255, 255, 255), + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + got, err := calculateGatewayIP(tt.cidr) + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + assert.Truef(t, tt.want.Equal(got), "want %s, got %s", tt.want.String(), got.String()) + }) + } +} diff --git a/cns/imdsclient/testdata/interfaces.xml b/cns/imdsclient/testdata/interfaces.xml new file mode 100644 index 0000000000..dc1fc17a83 --- /dev/null +++ b/cns/imdsclient/testdata/interfaces.xml @@ -0,0 +1,7 @@ + + + + + + +