diff --git a/core/logic.go b/core/logic.go index e5c7f1608d..652b91e4f2 100644 --- a/core/logic.go +++ b/core/logic.go @@ -169,7 +169,7 @@ func CleanupAfterContainerDeletion(ifaceName string, macAddress net.HardwareAddr return errors.New("received null veth pair for cleanup") } - netlink.DeleteNetworkLink(ifaceName) + netlink.DeleteLink(ifaceName) err = ebtables.RemoveDnatBasedOnIPV4Address(targetVeth.ip.String(), macAddress.String()) if err != nil { fmt.Println(err.Error()) @@ -177,7 +177,7 @@ func CleanupAfterContainerDeletion(ifaceName string, macAddress net.HardwareAddr } a := targetVeth.ip fmt.Println("going to add " + a.String() + "-- to " + targetVeth.ifaceNameCaWasTakenFrom) - netlink.AddLinkIPAddress(targetVeth.ifaceNameCaWasTakenFrom, *(targetVeth.ip), targetVeth.ipNet) + netlink.AddIpAddress(targetVeth.ifaceNameCaWasTakenFrom, *(targetVeth.ip), targetVeth.ipNet) delete(vethPairCollection, int(val)) return nil } @@ -209,14 +209,14 @@ func GetTargetInterface(interfaceNameToAttach string, ipAddressToAttach string) name1 := fmt.Sprintf("%s%d", vethPrefix, pair.peer1) name2 := fmt.Sprintf("%s%d", vethPrefix, pair.peer2) fmt.Println("Received veth pair names as ", name1, "-", name2, ". 
Now creating these.") - err = netlink.CreateVethPair(name1, name2) + err = netlink.AddVethPair(name1, name2) if err != nil { return net.IPNet{}, net.IPNet{}, nil, -1, "", "", net.IP{}, err.Error() } fmt.Println("Successfully generated veth pair.") fmt.Println("Going to add ip address ", *ip, ipNet, " to ", name1) - err = netlink.AddLinkIPAddress(name1, *ip, ipNet) + err = netlink.AddIpAddress(name1, *ip, ipNet) if err != nil { return net.IPNet{}, net.IPNet{}, nil, -1, "", "", net.IP{}, err.Error() } @@ -225,15 +225,14 @@ func GetTargetInterface(interfaceNameToAttach string, ipAddressToAttach string) fmt.Println("Updating veth pair state") fmt.Println("Going to set ", name2, " as up.") - command := fmt.Sprintf("ip link set %s up", name2) - err = ExecuteShellCommand(command) + err = netlink.SetLinkState(name2, true) if err != nil { return net.IPNet{}, net.IPNet{}, nil, -1, "", "", net.IP{}, err.Error() } fmt.Println("successfully ifupped ", name2, ".") fmt.Println("Going to add ", name2, " to aqua.") - err = netlink.AddInterfaceToBridge(name2, "aqua") + err = netlink.SetLinkMaster(name2, "aqua") if err != nil { fmt.Println(err.Error()) return net.IPNet{}, net.IPNet{}, nil, -1, "", "", net.IP{}, err.Error() @@ -268,26 +267,23 @@ func GetTargetInterface(interfaceNameToAttach string, ipAddressToAttach string) func FreeSlaves() error { for ifaceName, ifaceDetails := range mapEnslavedInterfaces { fmt.Println("Going to remove " + ifaceName + " from bridge") - err := netlink.RemoveInterfaceFromBridge(ifaceName) + err := netlink.SetLinkMaster(ifaceName, "") fmt.Println("Going to if down the interface so that mac address can be fixed") - command := fmt.Sprintf("ip link set %s down", ifaceName) - err = ExecuteShellCommand(command) + err = netlink.SetLinkState(ifaceName, false) if err != nil { return err } macAddress := ifaceDetails.rnmAllocatedMacAddress fmt.Println("Going to revert hardware address of " + ifaceName + " to " + macAddress.String()) - command = fmt.Sprintf("ip 
link set %s address %s", ifaceName, macAddress) - err = ExecuteShellCommand(command) + err = netlink.SetLinkAddress(ifaceName, macAddress) if err != nil { return err } - fmt.Println("Going to revert hardware address") - command = fmt.Sprintf("ip link set %s up", ifaceName) - err = ExecuteShellCommand(command) + fmt.Println("Going to if up") + err = netlink.SetLinkState(ifaceName, true) if err != nil { return err } @@ -302,7 +298,7 @@ func FreeSlaves() error { fmt.Println("Going to add ip addresses back to interface " + ifaceName) for _, caDetails := range ifaceDetails.provisionedCas { - netlink.AddLinkIPAddress(ifaceName, *(caDetails.ip), caDetails.ipNet) + netlink.AddIpAddress(ifaceName, *(caDetails.ip), caDetails.ipNet) } } @@ -471,14 +467,13 @@ func enslaveInterfaceIfRequired(iface *net.Interface, bridge string) error { _, err := net.InterfaceByName(bridge) if err != nil { // bridge does not exist - if err := netlink.CreateBridge(bridge); err != nil { + if err := netlink.AddLink(bridge, "bridge"); err != nil { return err } } fmt.Println("Going to iff up the bridge " + bridge) - command := fmt.Sprintf("ip link set %s up", bridge) - err = ExecuteShellCommand(command) + err = netlink.SetLinkState(bridge, true) if err != nil { return err } @@ -496,8 +491,7 @@ func enslaveInterfaceIfRequired(iface *net.Interface, bridge string) error { } fmt.Println("Going to iff down " + iface.Name) - command = fmt.Sprintf("ip link set %s down", iface.Name) - err = ExecuteShellCommand(command) + err = netlink.SetLinkState(iface.Name, false) if err != nil { fmt.Println(err.Error()) return err @@ -526,23 +520,21 @@ func enslaveInterfaceIfRequired(iface *net.Interface, bridge string) error { } fmt.Println("Going to set " + newMac.String() + " on " + iface.Name) - command = fmt.Sprintf("ip link set %s address %s", iface.Name, newMac.String()) - err = ExecuteShellCommand(command) + err = netlink.SetLinkAddress(iface.Name, newMac) if err != nil { fmt.Println(err.Error()) return err } 
fmt.Println("Going to iff up the link " + iface.Name) - command = fmt.Sprintf("ip link set %s up", iface.Name) - err = ExecuteShellCommand(command) + err = netlink.SetLinkState(iface.Name, true) if err != nil { fmt.Println(err.Error()) return err } fmt.Println("Going to add link " + iface.Name + " to " + bridge) - err = netlink.AddInterfaceToBridge(iface.Name, bridge) + err = netlink.SetLinkMaster(iface.Name, bridge) if err != nil { fmt.Println(err.Error()) return err @@ -593,7 +585,7 @@ func getAvailableCaAndRemoveFromHostInterface(iface *net.Interface, ipAddressToA return nil, nil, errors.New(erMsg) } - netlink.RemoveLinkIPAddress(iface.Name, *targetCa.ip, targetCa.ipNet) + netlink.DeleteIpAddress(iface.Name, *targetCa.ip, targetCa.ipNet) return targetCa.ip, targetCa.ipNet, nil } @@ -602,7 +594,7 @@ func getAvailableCaAndRemoveFromHostInterface(iface *net.Interface, ipAddressToA for caName, caDetails := range ensalvedIface.provisionedCas { if !isCaAlreadyAssigned(caName, caDetails) { fmt.Println("Found an unused CA " + caName) - netlink.RemoveLinkIPAddress(iface.Name, *caDetails.ip, caDetails.ipNet) + netlink.DeleteIpAddress(iface.Name, *caDetails.ip, caDetails.ipNet) return caDetails.ip, caDetails.ipNet, nil } } diff --git a/netlink/bridge.go b/netlink/bridge.go deleted file mode 100644 index b2293a394d..0000000000 --- a/netlink/bridge.go +++ /dev/null @@ -1,194 +0,0 @@ -// Copyright Microsoft Corp. -// All rights reserved. 
- -package netlink - -import ( - "fmt" - "math/rand" - "net" - - "golang.org/x/sys/unix" -) - -// CreateBridge creates a bridge device -func CreateBridge(bridgeName string) error { - - deviceTypeData := "bridge" - deviceNameData := bridgeName + "\000" - - netlinkSocketFd, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE) - if err != nil { - fmt.Printf("Error happenned while creating netlink socket %s \n", err) - return err - } - defer unix.Close(netlinkSocketFd) - - var sockAddrNetlink unix.SockaddrNetlink - sockAddrNetlink.Family = unix.AF_NETLINK // Address family of socket - sockAddrNetlink.Pad = 0 // should always be zero - sockAddrNetlink.Pid = 0 // have the kernel process message - - if err := unix.Bind(netlinkSocketFd, &sockAddrNetlink); err != nil { - fmt.Println("Got an error while binding netlikn socket " + err.Error()) - } - - // Create netlink message header - var netlinkMsgHeader unix.NlMsghdr - netlinkMsgHeader.Type = unix.RTM_NEWLINK - netlinkMsgHeader.Flags = unix.NLM_F_REQUEST | unix.NLM_F_CREATE | unix.NLM_F_EXCL | unix.NLM_F_ACK - // Seq is a 4 byte arbitrary number which is used to correlate request with response - netlinkMsgHeader.Seq = rand.Uint32() - netlinkMsgHeader.Pid = uint32(unix.Getpid()) - netlinkMsgHeader.Len = uint32(unix.SizeofNlMsghdr) + - uint32(unix.SizeofIfInfomsg) + - uint32(rtaAlign(unix.SizeofRtAttr)+rtaAlign(unix.SizeofRtAttr+len(deviceTypeData))) + - uint32(rtaAlign(unix.SizeofRtAttr+len(deviceNameData))) - - // Message of type NewLINK is required to contain an ifinfomsg structure - // after the header (ancilliary data) - var ifInfomsg unix.IfInfomsg - ifInfomsg.Family = unix.AF_UNSPEC // it has to be this Value - - attrLen := unix.SizeofRtAttr - - // this will contain the link type - rtAttrLinkInfo := createRtAttr(unix.IFLA_LINKINFO, rtaAlign(attrLen+attrLen+len(deviceTypeData))) - rtAttrLinkType := createRtAttr(1, rtaAlign(attrLen+len(deviceTypeData))) - rtAttrName := createRtAttr(unix.IFLA_IFNAME, 
rtaAlign(attrLen+len(deviceNameData))) - - data := SerializeNetLinkMessageHeader(&netlinkMsgHeader) - data = append(data, SerializeAncilliaryMessageHeader(&ifInfomsg)...) - - rtAttrLinkInfoSerialized := SerializeRoutingAttribute(&rtAttrLinkInfo) - rtAttrLinkInfoWithPadding := make([]byte, rtaAlign(unix.SizeofRtAttr)) - copy(rtAttrLinkInfoWithPadding[0:rtaAlign(unix.SizeofRtAttr)], rtAttrLinkInfoSerialized) - - data = append(data, rtAttrLinkInfoWithPadding...) - - rtAttrLinkTypeSerialized := SerializeRoutingAttribute(&rtAttrLinkType) - rtAttrLinkTypeLength := rtaAlign(unix.SizeofRtAttr + len(deviceTypeData)) - rtAttrLinkTypeWithPadding := make([]byte, rtAttrLinkTypeLength) - copy(rtAttrLinkTypeWithPadding[0:unix.SizeofRtAttr], rtAttrLinkTypeSerialized) - copy(rtAttrLinkTypeWithPadding[unix.SizeofRtAttr:rtAttrLinkTypeLength], []byte(deviceTypeData)) - data = append(data, rtAttrLinkTypeWithPadding...) - - rtAttrLinkNameSerialized := SerializeRoutingAttribute(&rtAttrName) - rtAttrLinkNameLength := rtaAlign(unix.SizeofRtAttr + len(deviceNameData)) - rtAttrLinkNameWithPadding := make([]byte, rtAttrLinkNameLength) - copy(rtAttrLinkNameWithPadding[0:unix.SizeofRtAttr], rtAttrLinkNameSerialized) - copy(rtAttrLinkNameWithPadding[unix.SizeofRtAttr:rtAttrLinkNameLength], []byte(deviceNameData)) - data = append(data, rtAttrLinkNameWithPadding...) 
- - flags := 0 - - // the only way to communicate with netlink sockets is via sendto - if err := unix.Sendto(netlinkSocketFd, data, flags, &sockAddrNetlink); err != nil { - fmt.Println("Got an error while sending message to kernel " + err.Error()) - return err - } - fmt.Println("Bridge creation command given to kernel successfully") - - if err := SetupKernelAcknowledgement(netlinkSocketFd, netlinkMsgHeader.Seq); err != nil { - fmt.Printf("Error received from kernel -> %s\n", err.Error()) - return err - } - fmt.Println("Bridge created successfully") - return nil - -} - -// AddInterfaceToBridge adds an interface to bridge -func AddInterfaceToBridge(linkName string, bridgeName string) error { - bridge, err := net.InterfaceByName(bridgeName) - if err != nil { - return err - } - fmt.Printf("Going to add %s to %s\n", linkName, bridgeName) - return addInterfaceToBridgeInternal(linkName, bridge.Index) -} - -// RemoveInterfaceFromBridge removes an interface from bridge -func RemoveInterfaceFromBridge(linkName string) error { - return addInterfaceToBridgeInternal(linkName, 0) -} - -func addInterfaceToBridgeInternal(linkName string, bridgeIndex int) error { - - iface, err := net.InterfaceByName(linkName) - if err != nil { - return err - } - - bindex := uint32(bridgeIndex) - bindexByte := make([]byte, 4) - getNativeType().PutUint32(bindexByte, bindex) - - netlinkSocketFd, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE) - if err != nil { - fmt.Printf("Error happenned while creating netlink socket %s \n", err) - return err - } - - defer unix.Close(netlinkSocketFd) - - var sockAddrNetlink unix.SockaddrNetlink - sockAddrNetlink.Family = unix.AF_NETLINK // Address family of socket - sockAddrNetlink.Pad = 0 // should always be zero - sockAddrNetlink.Pid = 0 // have the kernel process message - // sockAddrNetlink.Groups // Not yet sure what to do with this - - if err := unix.Bind(netlinkSocketFd, &sockAddrNetlink); err != nil { - fmt.Println("Got an error while 
binding netlikn socket " + err.Error()) - } - - // Create netlink message header - var netlinkMsgHeader unix.NlMsghdr - netlinkMsgHeader.Type = unix.RTM_SETLINK - netlinkMsgHeader.Flags = unix.NLM_F_REQUEST | unix.NLM_F_ACK - // Seq is a 4 byte arbitrary number which is used to correlate request with response - netlinkMsgHeader.Seq = rand.Uint32() - netlinkMsgHeader.Pid = uint32(unix.Getpid()) - netlinkMsgHeader.Len = - uint32(unix.SizeofNlMsghdr) + - uint32(unix.SizeofIfInfomsg) + - uint32(unix.SizeofRtAttr) + - uint32(len(bindexByte)) - - // Message of type NewLINK is required to contain an ifinfomsg structure - // after the header (ancilliary data) - var ifInfomsg unix.IfInfomsg - ifInfomsg.Family = unix.AF_UNSPEC // it has to be this Value - ifInfomsg.Type = unix.RTM_SETLINK - ifInfomsg.Flags = unix.NLM_F_REQUEST - ifInfomsg.Index = int32(iface.Index) - ifInfomsg.Change = 0xFFFFFFFF - - attrLen := unix.SizeofRtAttr - rtAttrLinkInfo := createRtAttr(unix.IFLA_MASTER, rtaAlign(attrLen+len(bindexByte))) - - data := SerializeNetLinkMessageHeader(&netlinkMsgHeader) - data = append(data, SerializeAncilliaryMessageHeader(&ifInfomsg)...) - - rtAttrLinkInfoSerialized := SerializeRoutingAttribute(&rtAttrLinkInfo) - rtAttrLinkInfoLength := rtaAlign(unix.SizeofRtAttr + len(bindexByte)) - rtAttrLinkInfoWithPadding := make([]byte, rtAttrLinkInfoLength) - copy(rtAttrLinkInfoWithPadding[0:unix.SizeofRtAttr], rtAttrLinkInfoSerialized) - copy(rtAttrLinkInfoWithPadding[unix.SizeofRtAttr:rtAttrLinkInfoLength], []byte(bindexByte)) - data = append(data, rtAttrLinkInfoWithPadding...) 
- - flags := 0 - if err := unix.Sendto(netlinkSocketFd, data, flags, &sockAddrNetlink); err != nil { - fmt.Println("Got an error while sending message to kernel " + err.Error()) - } - - fmt.Println("Add interface to bridge command given to kernel successfully") - - if err := SetupKernelAcknowledgement(netlinkSocketFd, netlinkMsgHeader.Seq); err != nil { - fmt.Printf("Error received from kernel -> %s\n", err.Error()) - return err - } - - return nil - -} diff --git a/netlink/common.go b/netlink/common.go deleted file mode 100644 index a8dc9012a3..0000000000 --- a/netlink/common.go +++ /dev/null @@ -1,162 +0,0 @@ -// Copyright Microsoft Corp. -// All rights reserved. - -package netlink - -import ( - "encoding/binary" - "errors" - "fmt" - "io" - - "syscall" - "unsafe" - - "golang.org/x/sys/unix" -) - -func rtaAlign(length int) int { - return (length + unix.RTA_ALIGNTO - 1) & ^(unix.RTA_ALIGNTO - 1) -} - -func createRtAttr(attrType int, lengthInBytes int) unix.RtAttr { - var rtAttr unix.RtAttr - rtAttr.Type = uint16(attrType) - rtAttr.Len = uint16(lengthInBytes) - return rtAttr -} - -// SerializeNetLinkMessageHeader serializes the netlink header -func SerializeNetLinkMessageHeader(netlinkMessageHeader *unix.NlMsghdr) []byte { - - headerInByteArray := make([]byte, unix.SizeofNlMsghdr) - hdr := (*(*[unix.SizeofNlMsghdr]byte)(unsafe.Pointer(netlinkMessageHeader)))[:] - next := unix.SizeofNlMsghdr - copy(headerInByteArray[0:next], hdr) - - return headerInByteArray -} - -// SerializeAncilliaryMessageHeader serializes ancilliary message header -func SerializeAncilliaryMessageHeader(ifInfomsgHdr *unix.IfInfomsg) []byte { - - headerInByteArray := make([]byte, unix.SizeofIfInfomsg) - hdr := (*(*[unix.SizeofIfInfomsg]byte)(unsafe.Pointer(ifInfomsgHdr)))[:] - next := unix.SizeofIfInfomsg - copy(headerInByteArray[0:next], hdr) - - return headerInByteArray -} - -// SerializeAddressMessageHeader serializes address message header -func SerializeAddressMessageHeader(ifAddrmsgHdr 
*unix.IfAddrmsg) []byte { - - headerInByteArray := make([]byte, unix.SizeofIfAddrmsg) - hdr := (*(*[unix.SizeofIfAddrmsg]byte)(unsafe.Pointer(ifAddrmsgHdr)))[:] - next := unix.SizeofIfAddrmsg - copy(headerInByteArray[0:next], hdr) - - return headerInByteArray -} - -// SerializeRoutingAttribute serializes routing attribute -func SerializeRoutingAttribute(rtAttr *unix.RtAttr) []byte { - - attrInByteArray := make([]byte, unix.SizeofRtAttr) - attr := (*(*[unix.SizeofRtAttr]byte)(unsafe.Pointer(rtAttr)))[:] - next := unix.SizeofRtAttr - copy(attrInByteArray[0:next], attr) - - return attrInByteArray -} - -// RecvfromKernel sets up socket ot receive response from kernel -func RecvfromKernel(netlinkSocketFd int) ([]syscall.NetlinkMessage, error) { - rb := make([]byte, unix.Getpagesize()) - nr, _, err := unix.Recvfrom(netlinkSocketFd, rb, 0) - if err != nil { - return nil, err - } - if nr < unix.NLMSG_HDRLEN { - return nil, errors.New("Short response from netlink") - } - fmt.Printf("Received %d bytes in response from kernel. 
", nr) - fmt.Printf("Header len: %d bytes\n", unix.NLMSG_HDRLEN) - rb = rb[:nr] - - return syscall.ParseNetlinkMessage(rb) -} - -// SetupKernelAcknowledgement sets up ack from kernel -func SetupKernelAcknowledgement(netlinkSocketFd int, seqNo uint32) error { - var lsa unix.Sockaddr - var err error - lsa, err = unix.Getsockname(netlinkSocketFd) - if err != nil { - return err - } - - pid := uint32(0) - switch v := lsa.(type) { - case *unix.SockaddrNetlink: - pid = uint32(v.Pid) - } - if pid == 0 { - return errors.New("Wrong socket type") - } - -outer: - for { - msgs, err := RecvfromKernel(netlinkSocketFd) - if err != nil { - return err - } - for _, m := range msgs { - if err := validate(m, seqNo, pid); err != nil { - if err == io.EOF { - break outer - } - return err - } - } - } - - return nil -} - -func getNativeType() binary.ByteOrder { - var a uint32 = 0x01020304 - if *(*byte)(unsafe.Pointer(&a)) == 0x01 { - return binary.BigEndian - } - return binary.LittleEndian -} - -func validate(m syscall.NetlinkMessage, seq, pid uint32) error { - if m.Header.Seq != seq { - return fmt.Errorf("invalid seq no: %d, expected: %d", m.Header.Seq, seq) - } - - fmt.Printf("Received sequence no: %d Expected: %d\n", m.Header.Seq, seq) - - if m.Header.Pid != pid { - return fmt.Errorf("wrong pid: %d, expected: %d", m.Header.Pid, pid) - } - fmt.Printf("Received Pid: %d Expected: %d\n", m.Header.Pid, pid) - - if m.Header.Type == unix.NLMSG_DONE { - fmt.Printf("Received unix.NLMSG_DONE\n") - return io.EOF - } - - if m.Header.Type == unix.NLMSG_ERROR { - fmt.Printf("Received unix.NLMSG_ERROR\n") - e := int32(getNativeType().Uint32(m.Data[0:4])) - fmt.Printf("Received error no. as %d\n", e) - if e == 0 { - return io.EOF - } - return syscall.Errno(-e) - } - return nil -} diff --git a/netlink/interface.go b/netlink/interface.go deleted file mode 100644 index 7c97cb6ec1..0000000000 --- a/netlink/interface.go +++ /dev/null @@ -1,330 +0,0 @@ -// Copyright Microsoft Corp. -// All rights reserved. 
- -package netlink - -import ( - "fmt" - "math/rand" - "net" - - "golang.org/x/sys/unix" -) - -// AddLinkIPAddress adds an ip address to a link -func AddLinkIPAddress(linkName string, ip net.IP, ipNet *net.IPNet) error { - action := unix.RTM_NEWADDR - flags := unix.NLM_F_REQUEST | unix.NLM_F_CREATE | unix.NLM_F_EXCL | unix.NLM_F_ACK - return configureLinkIPAddress(linkName, ip, ipNet, action, flags) -} - -// RemoveLinkIPAddress removes ip address configured on a link -func RemoveLinkIPAddress(linkName string, ip net.IP, ipNet *net.IPNet) error { - action := unix.RTM_DELADDR - flags := unix.NLM_F_REQUEST | unix.NLM_F_ACK - return configureLinkIPAddress(linkName, ip, ipNet, action, flags) -} - -// CreateVethPair creates a pair of veth devices -func CreateVethPair(veth1 string, veth2 string) error { - - deviceTypeData := "veth" + "\000" - rtAttrLen := unix.SizeofRtAttr - - netlinkSocketFd, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE) - if err != nil { - fmt.Printf("Error happenned while creating netlink socket %s \n", err) - return err - } - defer unix.Close(netlinkSocketFd) - - var sockAddrNetlink unix.SockaddrNetlink - sockAddrNetlink.Family = unix.AF_NETLINK // Address family of socket - sockAddrNetlink.Pad = 0 // should always be zero - sockAddrNetlink.Pid = 0 // have the kernel process message - // sockAddrNetlink.Groups // Not yet sure what to do with this - - if err := unix.Bind(netlinkSocketFd, &sockAddrNetlink); err != nil { - fmt.Println("Got an error while binding netlikn socket " + err.Error()) - } - - // Create netlink message header - var netlinkMsgHeader unix.NlMsghdr - netlinkMsgHeader.Type = unix.RTM_NEWLINK - netlinkMsgHeader.Flags = unix.NLM_F_REQUEST | unix.NLM_F_CREATE | unix.NLM_F_EXCL | unix.NLM_F_ACK - netlinkMsgHeader.Seq = rand.Uint32() - netlinkMsgHeader.Pid = uint32(unix.Getpid()) - netlinkMsgHeader.Len = - uint32(unix.SizeofNlMsghdr + - unix.SizeofIfInfomsg + - rtaAlign(unix.SizeofRtAttr+len(veth1+"\000")) + - 
rtaAlign(unix.SizeofRtAttr) + // IFLA_LINKINFO - rtaAlign(unix.SizeofRtAttr+len(deviceTypeData)) + // veth - rtaAlign(unix.SizeofRtAttr) + // data - rtaAlign(unix.SizeofRtAttr) + //peer2 - unix.SizeofIfInfomsg + // AF_UNSPEC - rtaAlign(unix.SizeofRtAttr+len(veth2+"\000"))) - - // Message of type NewLINK is required to contain an ifinfomsg structure - // after the header (ancilliary data) - var ifInfomsg unix.IfInfomsg - ifInfomsg.Family = unix.AF_UNSPEC // it has to be this Value - - rtAttrName := createRtAttr(unix.IFLA_IFNAME, rtAttrLen+len(veth1+"\000")) - - rtAttrLinkInfoLen := rtaAlign(unix.SizeofRtAttr) + // IFLA_LINKINFO - rtaAlign(unix.SizeofRtAttr+len(deviceTypeData)) + // veth - rtaAlign(unix.SizeofRtAttr) + // data - rtaAlign(unix.SizeofRtAttr) + //peer2 - unix.SizeofIfInfomsg + // AF_UNSPEC - rtaAlign(unix.SizeofRtAttr+len(veth2+"\000")) - rtAttrLinkInfo := createRtAttr(unix.IFLA_LINKINFO, rtAttrLinkInfoLen) - - rtAttrLinkType := createRtAttr(1, rtaAlign(unix.SizeofRtAttr+len(deviceTypeData))) - - // we need a peer - dataLen := rtaAlign(unix.SizeofRtAttr) + // data - rtaAlign(unix.SizeofRtAttr) + //peer2 - unix.SizeofIfInfomsg + // AF_UNSPEC - rtaAlign(unix.SizeofRtAttr+len(veth2+"\000")) - rtAttrLinkTypeData := createRtAttr(2, dataLen) - - peerLength := rtaAlign(unix.SizeofRtAttr) + //peer2 - unix.SizeofIfInfomsg + // AF_UNSPEC - rtaAlign(unix.SizeofRtAttr+len(veth2+"\000")) - rtAttrLinkPeer := createRtAttr(1, peerLength) - - var ifInfomsgForPeer unix.IfInfomsg - ifInfomsgForPeer.Family = unix.AF_UNSPEC // it has to be this Value - - rtAtttrLinkPeerName := createRtAttr(unix.IFLA_IFNAME, rtaAlign(rtAttrLen+len(veth2+"\000"))) - - data := SerializeNetLinkMessageHeader(&netlinkMsgHeader) - data = append(data, SerializeAncilliaryMessageHeader(&ifInfomsg)...) 
- - rtAttrLinkNameSerialized := SerializeRoutingAttribute(&rtAttrName) - rtAttrLinkNameLength := rtaAlign(unix.SizeofRtAttr + len(veth1+"\000")) - rtAttrLinkNameWithPadding := make([]byte, rtAttrLinkNameLength) - copy(rtAttrLinkNameWithPadding[0:unix.SizeofRtAttr], rtAttrLinkNameSerialized) - copy(rtAttrLinkNameWithPadding[unix.SizeofRtAttr:rtAttrLinkNameLength], []byte(veth1+"\000")) - data = append(data, rtAttrLinkNameWithPadding...) - - rtAttrLinkInfoSerialized := SerializeRoutingAttribute(&rtAttrLinkInfo) - rtAttrLinkInfoWithPadding := make([]byte, rtaAlign(unix.SizeofRtAttr)) - copy(rtAttrLinkInfoWithPadding[0:rtaAlign(unix.SizeofRtAttr)], rtAttrLinkInfoSerialized) - data = append(data, rtAttrLinkInfoWithPadding...) - - rtAttrLinkTypeSerialized := SerializeRoutingAttribute(&rtAttrLinkType) - rtAttrLinkTypeLength := rtaAlign(unix.SizeofRtAttr + len(deviceTypeData)) - rtAttrLinkTypeWithPadding := make([]byte, rtAttrLinkTypeLength) - copy(rtAttrLinkTypeWithPadding[0:unix.SizeofRtAttr], rtAttrLinkTypeSerialized) - copy(rtAttrLinkTypeWithPadding[unix.SizeofRtAttr:rtAttrLinkTypeLength], []byte(deviceTypeData)) - data = append(data, rtAttrLinkTypeWithPadding...) - - rtAttrLinkTypeDataSerialized := SerializeRoutingAttribute(&rtAttrLinkTypeData) - rtAttrLinkTypeDataWithPadding := make([]byte, rtaAlign(unix.SizeofRtAttr)) - copy(rtAttrLinkTypeDataWithPadding[0:rtaAlign(unix.SizeofRtAttr)], rtAttrLinkTypeDataSerialized) - data = append(data, rtAttrLinkTypeDataWithPadding...) - - rtAttrLinkPeerSerialized := SerializeRoutingAttribute(&rtAttrLinkPeer) - rtAttrLinkPeerWithPadding := make([]byte, rtaAlign(unix.SizeofRtAttr)) - copy(rtAttrLinkPeerWithPadding[0:rtaAlign(unix.SizeofRtAttr)], rtAttrLinkPeerSerialized) - data = append(data, rtAttrLinkPeerWithPadding...) - - data = append(data, SerializeAncilliaryMessageHeader(&ifInfomsgForPeer)...) 
- - rtAtttrLinkPeerNameSerialized := SerializeRoutingAttribute(&rtAtttrLinkPeerName) - rtAtttrLinkPeerNameLength := rtaAlign(unix.SizeofRtAttr + len(veth2+"\000")) - rtAtttrLinkPeerNameWithPadding := make([]byte, rtAtttrLinkPeerNameLength) - copy(rtAtttrLinkPeerNameWithPadding[0:unix.SizeofRtAttr], rtAtttrLinkPeerNameSerialized) - copy(rtAtttrLinkPeerNameWithPadding[unix.SizeofRtAttr:rtAtttrLinkPeerNameLength], []byte(veth2+"\000")) - data = append(data, rtAtttrLinkPeerNameWithPadding...) - - flags := 0 - - // the only way to communicate with netlink sockets is via sendto - if err := unix.Sendto(netlinkSocketFd, data, flags, &sockAddrNetlink); err != nil { - fmt.Println("Got an error while sending message to kernel " + err.Error()) - return err - } - fmt.Println("Veth pair creation command given to kernel successfully") - - if err := SetupKernelAcknowledgement(netlinkSocketFd, netlinkMsgHeader.Seq); err != nil { - fmt.Printf("Error received from kernel in response to veth pair -> %s\n", err.Error()) - return err - } - - return nil -} - -// DeleteNetworkLink deletes a network link -func DeleteNetworkLink(linkName string) error { - - iface, err := net.InterfaceByName(linkName) - if err != nil { - return err - } - - netlinkSocketFd, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE) - if err != nil { - fmt.Printf("Error happenned while creating netlink socket %s \n", err) - return err - } - - defer unix.Close(netlinkSocketFd) - - var sockAddrNetlink unix.SockaddrNetlink - sockAddrNetlink.Family = unix.AF_NETLINK // Address family of socket - sockAddrNetlink.Pad = 0 // should always be zero - sockAddrNetlink.Pid = 0 // have the kernel process message - // sockAddrNetlink.Groups // Not yet sure what to do with this - - if err := unix.Bind(netlinkSocketFd, &sockAddrNetlink); err != nil { - fmt.Println("Got an error while binding netlikn socket " + err.Error()) - } - - // Create netlink message header - var netlinkMsgHeader unix.NlMsghdr - // DELLINK 
is used to delete devices (ethernet, bridge etc.) - netlinkMsgHeader.Type = unix.RTM_DELLINK - // REQUEST: indicates a request messages - netlinkMsgHeader.Flags = unix.NLM_F_REQUEST | unix.NLM_F_ACK - // Seq is a 4 byte arbitrary number which is used to correlate request with response - netlinkMsgHeader.Seq = rand.Uint32() - netlinkMsgHeader.Pid = uint32(unix.Getpid()) - netlinkMsgHeader.Len = uint32(unix.SizeofNlMsghdr) + uint32(unix.SizeofIfInfomsg) - - // Message of type NewLINK is required to contain an ifinfomsg structure - // after the header (ancilliary data) - var ifInfomsg unix.IfInfomsg - ifInfomsg.Family = unix.AF_UNSPEC // it has to be this Value - ifInfomsg.Index = int32(iface.Index) - - data := SerializeNetLinkMessageHeader(&netlinkMsgHeader) - data = append(data, SerializeAncilliaryMessageHeader(&ifInfomsg)...) - - flags := 0 - if err := unix.Sendto(netlinkSocketFd, data, flags, &sockAddrNetlink); err != nil { - fmt.Println("Got an error while sending message to kernel " + err.Error()) - } - - fmt.Println("Delete link command given to kernel successfully") - - if err := SetupKernelAcknowledgement(netlinkSocketFd, netlinkMsgHeader.Seq); err != nil { - fmt.Printf("Error received from kernel -> %s\n", err.Error()) - return err - } - - return nil - -} - -func configureLinkIPAddress(linkName string, ip net.IP, ipNet *net.IPNet, action int, headerFlags int) error { - - iface, err := net.InterfaceByName(linkName) - if err != nil { - return err - } - - fmt.Printf("Got interface %s\n", iface.Name) - - ifaceIndex := uint32(iface.Index) - ifaceIndexByte := make([]byte, 4) - getNativeType().PutUint32(ifaceIndexByte, ifaceIndex) - - netlinkSocketFd, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE) - if err != nil { - fmt.Printf("Error happenned while creating netlink socket %s \n", err) - return err - } - - defer unix.Close(netlinkSocketFd) - - var sockAddrNetlink unix.SockaddrNetlink - sockAddrNetlink.Family = unix.AF_NETLINK // Address 
family of socket - sockAddrNetlink.Pad = 0 // should always be zero - sockAddrNetlink.Pid = 0 // have the kernel process message - // sockAddrNetlink.Groups // Not yet sure what to do with this - - if err := unix.Bind(netlinkSocketFd, &sockAddrNetlink); err != nil { - fmt.Println("Got an error while binding netlink socket " + err.Error()) - } - - // Create netlink message header - var netlinkMsgHeader unix.NlMsghdr - netlinkMsgHeader.Type = uint16(action) - netlinkMsgHeader.Flags = uint16(headerFlags) - netlinkMsgHeader.Seq = rand.Uint32() - netlinkMsgHeader.Pid = uint32(unix.Getpid()) - - var ifAddrmsg unix.IfAddrmsg - ifAddrmsg.Family = unix.AF_INET // ipv4 - - //ifAddrmsg.Flags = unix.NLM_F_REQUEST - ifAddrmsg.Index = uint32(iface.Index) - prefixlen, _ := ipNet.Mask.Size() - fmt.Printf("prefixlen: %d\n", prefixlen) - ifAddrmsg.Prefixlen = uint8(prefixlen) - - var payLoad []byte - if ifAddrmsg.Family == unix.AF_INET { - payLoad = ip.To4() - fmt.Println(payLoad) - } else { - // not supported - fmt.Println(payLoad) - payLoad = ip.To16() - } - fmt.Printf("Payload: %s \n", ip.To4()) - netlinkMsgHeader.Len = - uint32(unix.SizeofNlMsghdr) + - uint32(rtaAlign(unix.SizeofIfAddrmsg)) + - uint32(rtaAlign(unix.SizeofRtAttr+len(payLoad))) + - uint32(rtaAlign(unix.SizeofRtAttr+len(payLoad))) - - attrLen := unix.SizeofRtAttr - var rtAttrLocal unix.RtAttr - rtAttrLocal.Type = unix.IFA_LOCAL - rtAttrLocal.Len = uint16(rtaAlign(attrLen + len(payLoad))) - - var rtAttrAddress unix.RtAttr - rtAttrAddress.Type = unix.IFA_ADDRESS - rtAttrAddress.Len = uint16(rtaAlign(attrLen + len(payLoad))) - - data := SerializeNetLinkMessageHeader(&netlinkMsgHeader) - - ifAddrmsgSerialized := SerializeAddressMessageHeader(&ifAddrmsg) - ifAddrmsgLength := rtaAlign(unix.SizeofIfAddrmsg) - ifAddrmsgWithPadding := make([]byte, ifAddrmsgLength) - copy(ifAddrmsgWithPadding[0:unix.SizeofIfAddrmsg], ifAddrmsgSerialized) - data = append(data, ifAddrmsgWithPadding...) 
- - rtAttrLocalSerialized := SerializeRoutingAttribute(&rtAttrLocal) - rtAttrLocalLength := rtaAlign(unix.SizeofRtAttr + len(payLoad)) - rtAttrLocalWithPadding := make([]byte, rtAttrLocalLength) - copy(rtAttrLocalWithPadding[0:unix.SizeofRtAttr], rtAttrLocalSerialized) - copy(rtAttrLocalWithPadding[unix.SizeofRtAttr:rtAttrLocalLength], []byte(payLoad)) - data = append(data, rtAttrLocalWithPadding...) - - rtAttrAddressSerialized := SerializeRoutingAttribute(&rtAttrAddress) - rtAttrAddressLength := rtaAlign(unix.SizeofRtAttr + len(payLoad)) - rtAttrAddressWithPadding := make([]byte, rtAttrAddressLength) - copy(rtAttrAddressWithPadding[0:unix.SizeofRtAttr], rtAttrAddressSerialized) - copy(rtAttrAddressWithPadding[unix.SizeofRtAttr:rtAttrAddressLength], []byte(payLoad)) - data = append(data, rtAttrAddressWithPadding...) - - flags := 0 - if err := unix.Sendto(netlinkSocketFd, data, flags, &sockAddrNetlink); err != nil { - fmt.Println("Got an error while sending message to kernel " + err.Error()) - } - - fmt.Println("Configure link ipaddress command given to kernel successfully") - - if err := SetupKernelAcknowledgement(netlinkSocketFd, netlinkMsgHeader.Seq); err != nil { - fmt.Printf("Error received from kernel -> %s\n", err.Error()) - return err - } - - return nil - -} diff --git a/netlink/netlink.go b/netlink/netlink.go new file mode 100644 index 0000000000..ff2e721f0a --- /dev/null +++ b/netlink/netlink.go @@ -0,0 +1,275 @@ +// Copyright Microsoft Corp. +// All rights reserved. + +package netlink + +import ( + "fmt" + "net" + + "golang.org/x/sys/unix" +) + +// Initializes netlink module. +func init() { + initEncoder() +} + +// Sends a netlink echo request message. 
+func Echo(text string) error { + s, err := getSocket() + if err != nil { + return err + } + + req := newRequest(unix.NLMSG_NOOP, unix.NLM_F_ECHO | unix.NLM_F_ACK) + if req == nil { + return unix.ENOMEM + } + + req.addPayload(newAttributeString(0, text)) + + return s.sendAndComplete(req) +} + +// Adds a new network link of a specified type. +func AddLink(name string, linkType string) error { + if name == "" || linkType == "" { + return fmt.Errorf("Invalid link name or type") + } + + s, err := getSocket() + if err != nil { + return err + } + + req := newRequest(unix.RTM_NEWLINK, unix.NLM_F_CREATE | unix.NLM_F_EXCL | unix.NLM_F_ACK) + + ifInfo := newIfInfoMsg() + req.addPayload(ifInfo) + + attrLinkInfo := newAttribute(unix.IFLA_LINKINFO, nil) + attrLinkInfo.addNested(newAttributeString(IFLA_INFO_KIND, linkType)) + req.addPayload(attrLinkInfo) + + attrIfName := newAttributeStringZ(unix.IFLA_IFNAME, name) + req.addPayload(attrIfName) + + return s.sendAndComplete(req) +} + +// Deletes a network link. +func DeleteLink(name string) error { + if name == "" { + return fmt.Errorf("Invalid link name") + } + + s, err := getSocket() + if err != nil { + return err + } + + iface, err := net.InterfaceByName(name) + if err != nil { + return err + } + + req := newRequest(unix.RTM_DELLINK, unix.NLM_F_ACK) + + ifInfo := newIfInfoMsg() + ifInfo.Index = int32(iface.Index) + req.addPayload(ifInfo) + + return s.sendAndComplete(req) +} + +// Sets the operational state of a network interface. 
+func SetLinkState(name string, up bool) error {
+	s, err := getSocket()
+	if err != nil {
+		return err
+	}
+
+	iface, err := net.InterfaceByName(name)
+	if err != nil {
+		return err
+	}
+
+	req := newRequest(unix.RTM_NEWLINK, unix.NLM_F_ACK)
+
+	ifInfo := newIfInfoMsg()
+	// NOTE(review): ifi_type normally carries the device type, not a message
+	// type; the value is kept as it was — confirm whether it can be dropped.
+	ifInfo.Type = unix.RTM_SETLINK
+	ifInfo.Index = int32(iface.Index)
+
+	// ifi_change is the mask of flag bits the kernel is allowed to modify.
+	// Limit it to IFF_UP in both directions so that bringing a link down
+	// clears only that bit and leaves every other interface flag alone.
+	ifInfo.Change = unix.IFF_UP
+	if up {
+		ifInfo.Flags = unix.IFF_UP
+	}
+
+	req.addPayload(ifInfo)
+
+	return s.sendAndComplete(req)
+}
+
+// Sets the master (upper) device of a network interface.
+//
+// An empty master name sends a master index of zero; core/logic.go relies on
+// this to take an interface out of its bridge. Lookup failures for either
+// interface are returned before anything is sent.
+func SetLinkMaster(name string, master string) error {
+	s, err := getSocket()
+	if err != nil {
+		return err
+	}
+
+	iface, err := net.InterfaceByName(name)
+	if err != nil {
+		return err
+	}
+
+	var masterIndex uint32
+	if master != "" {
+		masterIface, err := net.InterfaceByName(master)
+		if err != nil {
+			return err
+		}
+		masterIndex = uint32(masterIface.Index)
+	}
+
+	req := newRequest(unix.RTM_SETLINK, unix.NLM_F_ACK)
+
+	ifInfo := newIfInfoMsg()
+	ifInfo.Type = unix.RTM_SETLINK
+	ifInfo.Index = int32(iface.Index)
+	// NOTE(review): NLM_F_REQUEST is a message-header flag, yet it is stored
+	// in the interface flags here together with a change mask of all bits.
+	// That presumably also alters the interface's own flags as a side effect;
+	// left untouched because callers may depend on it — confirm.
+	ifInfo.Flags = unix.NLM_F_REQUEST
+	ifInfo.Change = DEFAULT_CHANGE
+	req.addPayload(ifInfo)
+
+	attrMaster := newAttributeUint32(unix.IFLA_MASTER, masterIndex)
+	req.addPayload(attrMaster)
+
+	return s.sendAndComplete(req)
+}
+
+// Sets the link layer hardware address of a network interface.
+func SetLinkAddress(ifName string, hwAddress net.HardwareAddr) error {
+	sock, err := getSocket()
+	if err != nil {
+		return err
+	}
+
+	link, err := net.InterfaceByName(ifName)
+	if err != nil {
+		return err
+	}
+
+	// Interface header addressing the link by index.
+	// NOTE(review): Flags holds a message-header constant and Change selects
+	// all bits; this presumably touches the interface flags too — confirm.
+	hdr := newIfInfoMsg()
+	hdr.Type = unix.RTM_SETLINK
+	hdr.Index = int32(link.Index)
+	hdr.Flags = unix.NLM_F_REQUEST
+	hdr.Change = DEFAULT_CHANGE
+
+	msg := newRequest(unix.RTM_SETLINK, unix.NLM_F_ACK)
+	msg.addPayload(hdr)
+	msg.addPayload(newAttribute(unix.IFLA_ADDRESS, hwAddress))
+
+	return sock.sendAndComplete(msg)
+}
+
+// Adds a new veth pair.
+//
+// name1 names the link being created and name2 its peer. The peer is
+// described by a nested attribute chain:
+// IFLA_LINKINFO { IFLA_INFO_KIND "veth", IFLA_INFO_DATA { VETH_INFO_PEER {
+// interface header, IFLA_IFNAME name2 } } }.
+func AddVethPair(name1 string, name2 string) error {
+	sock, err := getSocket()
+	if err != nil {
+		return err
+	}
+
+	// Build the nested attributes from the innermost one outwards.
+	peer := newAttribute(VETH_INFO_PEER, nil)
+	peer.addNested(newIfInfoMsg())
+	peer.addNested(newAttributeStringZ(unix.IFLA_IFNAME, name2))
+
+	data := newAttribute(IFLA_INFO_DATA, nil)
+	data.addNested(peer)
+
+	linkInfo := newAttribute(unix.IFLA_LINKINFO, nil)
+	linkInfo.addNested(newAttributeStringZ(IFLA_INFO_KIND, "veth"))
+	linkInfo.addNested(data)
+
+	msg := newRequest(unix.RTM_NEWLINK, unix.NLM_F_CREATE|unix.NLM_F_EXCL|unix.NLM_F_ACK)
+	msg.addPayload(newIfInfoMsg())
+	msg.addPayload(newAttributeStringZ(unix.IFLA_IFNAME, name1))
+	msg.addPayload(linkInfo)
+
+	return sock.sendAndComplete(msg)
+}
+
+// Returns the address family of an IP address: AF_INET6 only for a value
+// longer than four bytes that has no IPv4 form, AF_INET for everything else.
+func getIpAddressFamily(ip net.IP) int {
+	if len(ip) > net.IPv4len && ip.To4() == nil {
+		return unix.AF_INET6
+	}
+	return unix.AF_INET
+}
+
+// Sends an IP address set request.
+func setIpAddress(ifName string, ipAddress net.IP, ipNet *net.IPNet, add bool) error {
+	// Validate the arguments before touching the socket: a nil ipNet would
+	// panic below, and an address with no 4- or 16-byte form would be sent
+	// to the kernel as an empty attribute.
+	if ipNet == nil {
+		return fmt.Errorf("Invalid IP network")
+	}
+
+	family := getIpAddressFamily(ipAddress)
+
+	var ipAddrValue []byte
+	if family == unix.AF_INET {
+		ipAddrValue = ipAddress.To4()
+	} else {
+		ipAddrValue = ipAddress.To16()
+	}
+	if ipAddrValue == nil {
+		return fmt.Errorf("Invalid IP address")
+	}
+
+	// Size reports zero bits for a mask that is not in canonical form; such a
+	// mask must not be turned silently into a /0 prefix.
+	prefixLen, bits := ipNet.Mask.Size()
+	if bits == 0 {
+		return fmt.Errorf("Invalid IP network")
+	}
+
+	s, err := getSocket()
+	if err != nil {
+		return err
+	}
+
+	iface, err := net.InterfaceByName(ifName)
+	if err != nil {
+		return err
+	}
+
+	var msgType, flags int
+	if add {
+		msgType = unix.RTM_NEWADDR
+		flags = unix.NLM_F_CREATE | unix.NLM_F_EXCL | unix.NLM_F_ACK
+	} else {
+		msgType = unix.RTM_DELADDR
+		flags = unix.NLM_F_EXCL | unix.NLM_F_ACK
+	}
+
+	req := newRequest(msgType, flags)
+
+	ifAddr := newIfAddrMsg(family)
+	ifAddr.Index = uint32(iface.Index)
+	ifAddr.Prefixlen = uint8(prefixLen)
+	req.addPayload(ifAddr)
+
+	// The same address is sent as both the local and the interface address.
+	req.addPayload(newAttribute(unix.IFA_LOCAL, ipAddrValue))
+	req.addPayload(newAttribute(unix.IFA_ADDRESS, ipAddrValue))
+
+	return s.sendAndComplete(req)
+}
+
+// Adds an IP address to an interface. ipNet supplies the prefix length.
+func AddIpAddress(ifName string, ipAddress net.IP, ipNet *net.IPNet) error {
+	return setIpAddress(ifName, ipAddress, ipNet, true)
+}
+
+// Deletes an IP address from an interface. ipNet supplies the prefix length.
+func DeleteIpAddress(ifName string, ipAddress net.IP, ipNet *net.IPNet) error {
+	return setIpAddress(ifName, ipAddress, ipNet, false)
+}
diff --git a/netlink/netlink_test.go b/netlink/netlink_test.go
new file mode 100644
index 0000000000..505cf13e1d
--- /dev/null
+++ b/netlink/netlink_test.go
@@ -0,0 +1,61 @@
+// Copyright Microsoft Corp.
+// All rights reserved.
+
+package netlink
+
+import (
+	"fmt"
+	"testing"
+)
+
+const ifname string = "nltest"
+
+// Tests basic netlink messaging via echo.
+func TestEcho(t *testing.T) {
+	fmt.Println("Test: Echo")
+
+	err := Echo("this is a test")
+
+	if err != nil {
+		t.Errorf("Echo failed: %+v", err)
+	}
+}
+
+// Tests creating a new network interface.
+func TestAddLink(t *testing.T) {
+	fmt.Println("Test: AddLink")
+
+	if err := AddLink(ifname, "bridge"); err != nil {
+		t.Errorf("AddLink failed: %+v", err)
+	}
+}
+
+// Tests setting the operational state of a network interface: the test link
+// is brought up and then down again, and each step is checked on its own.
+func TestSetLinkState(t *testing.T) {
+	fmt.Println("Test: SetLinkState")
+
+	if err := SetLinkState(ifname, true); err != nil {
+		t.Errorf("SetLinkState up failed: %+v", err)
+	}
+
+	if err := SetLinkState(ifname, false); err != nil {
+		t.Errorf("SetLinkState down failed: %+v", err)
+	}
+}
+
+// Tests deleting a network interface.
+func TestDeleteLink(t *testing.T) {
+	fmt.Println("Test: DeleteLink")
+
+	if err := DeleteLink(ifname); err != nil {
+		t.Errorf("DeleteLink failed: %+v", err)
+	}
+}
diff --git a/netlink/protocol.go b/netlink/protocol.go
new file mode 100644
index 0000000000..03fe7257e4
--- /dev/null
+++ b/netlink/protocol.go
@@ -0,0 +1,254 @@
+// Copyright Microsoft Corp.
+// All rights reserved.
+
+package netlink
+
+import (
+	"encoding/binary"
+	"unsafe"
+
+	"golang.org/x/sys/unix"
+)
+
+// Netlink protocol constants that are not already defined in unix package.
+const (
+	IFLA_INFO_KIND = 1
+	IFLA_INFO_DATA = 2
+	VETH_INFO_PEER = 1
+	DEFAULT_CHANGE = 0xFFFFFFFF
+)
+
+// Serializable types are used to construct netlink messages.
+type serializable interface {
+	serialize() []byte
+	length() int
+}
+
+// Byte encoder
+var encoder binary.ByteOrder
+
+// Initializes the byte encoder by inspecting how this machine lays out a
+// known 32-bit value in memory.
+func initEncoder() {
+	probe := uint32(0x01020304)
+	firstByte := *(*byte)(unsafe.Pointer(&probe))
+
+	// The most significant byte comes first only on a big-endian host.
+	var order binary.ByteOrder = binary.LittleEndian
+	if firstByte == 0x01 {
+		order = binary.BigEndian
+	}
+	encoder = order
+}
+
+//
+// Netlink message
+//
+
+// Generic netlink message
+type message struct {
+	unix.NlMsghdr
+	payload []serializable
+}
+
+// Creates a new netlink message.
+func newMessage(msgType int, flags int) *message {
+	// The header starts out describing an empty message; the sequence number
+	// is left at zero here.
+	hdr := unix.NlMsghdr{
+		Len:   uint32(unix.NLMSG_HDRLEN),
+		Type:  uint16(msgType),
+		Flags: uint16(flags),
+		Seq:   0,
+		Pid:   uint32(unix.Getpid()),
+	}
+	return &message{NlMsghdr: hdr}
+}
+
+// Creates a new netlink request message, i.e. a message whose flags always
+// include NLM_F_REQUEST in addition to the ones passed in.
+func newRequest(msgType int, flags int) *message {
+	return newMessage(msgType, unix.NLM_F_REQUEST|flags)
+}
+
+// Appends protocol specific payload to a netlink message. A nil payload is
+// ignored.
+func (msg *message) addPayload(payload serializable) {
+	if payload == nil {
+		return
+	}
+	msg.payload = append(msg.payload, payload)
+}
+
+// Serializes a netlink message: the 16-byte header followed by every payload
+// part in the order it was added. The header length field is recomputed on
+// each call.
+func (msg *message) serialize() []byte {
+	// Serialize the payload parts first; the header needs the total length.
+	parts := make([][]byte, 0, len(msg.payload))
+	total := uint32(unix.NLMSG_HDRLEN)
+	for _, part := range msg.payload {
+		raw := part.serialize()
+		parts = append(parts, raw)
+		total += uint32(len(raw))
+	}
+	msg.Len = total
+
+	// Header fields: length, type, flags, sequence number, port id.
+	out := make([]byte, 16, total)
+	encoder.PutUint32(out[0:4], msg.Len)
+	encoder.PutUint16(out[4:6], msg.Type)
+	encoder.PutUint16(out[6:8], msg.Flags)
+	encoder.PutUint32(out[8:12], msg.Seq)
+	encoder.PutUint32(out[12:16], msg.Pid)
+
+	// The payload follows the header directly.
+	for _, raw := range parts {
+		out = append(out, raw...)
+	}
+
+	return out
+}
+
+//
+// Netlink message attribute
+//
+
+// Generic netlink message attribute. It carries either a raw value or a list
+// of nested children.
+type attribute struct {
+	unix.NlAttr
+	value    []byte
+	children []serializable
+}
+
+// Creates a new attribute of the given type holding value. Pass a nil value
+// for a container that will receive nested attributes.
+func newAttribute(attrType int, value []byte) *attribute {
+	attr := &attribute{
+		value:    value,
+		children: []serializable{},
+	}
+	attr.Type = uint16(attrType)
+	return attr
+}
+
+// Creates a new attribute with a string value (no terminating NUL byte).
+func newAttributeString(attrType int, value string) *attribute {
+	raw := []byte(value)
+	return newAttribute(attrType, raw)
+}
+
+// Creates a new attribute with a null-terminated string value.
+func newAttributeStringZ(attrType int, value string) *attribute {
+	return newAttribute(attrType, []byte(value+"\000"))
+}
+
+// Creates a new attribute with a uint32 value, encoded in host byte order.
+func newAttributeUint32(attrType int, value uint32) *attribute {
+	buf := make([]byte, 4)
+	encoder.PutUint32(buf, value)
+	return newAttribute(attrType, buf)
+}
+
+// Adds a nested attribute to an attribute.
+func (attr *attribute) addNested(nested serializable) {
+	attr.children = append(attr.children, nested)
+}
+
+// Serializes an attribute into a buffer padded to the attribute alignment.
+//
+// An attribute with a non-nil value is written as header plus value; only an
+// attribute without a value has its children serialized.
+func (attr *attribute) serialize() []byte {
+	// The length field in the header covers the header and the payload but
+	// not the trailing alignment padding, so compute it separately from the
+	// padded buffer size returned by length().
+	unpadded := unix.SizeofNlAttr + len(attr.value)
+	for _, child := range attr.children {
+		unpadded += child.length()
+	}
+	buf := make([]byte, attr.length())
+
+	// Encode length and type. The length field is 16 bits wide, so anything
+	// above 65535 bytes cannot be represented and is truncated here.
+	encoder.PutUint16(buf[0:2], uint16(unpadded))
+	encoder.PutUint16(buf[2:4], attr.Type)
+
+	if attr.value != nil {
+		// Encode value.
+		copy(buf[4:], attr.value)
+	} else {
+		// Serialize any nested attributes back to back.
+		offset := 4
+		for _, child := range attr.children {
+			childBuf := child.serialize()
+			copy(buf[offset:], childBuf)
+			offset += len(childBuf)
+		}
+	}
+
+	return buf
+}
+
+// Returns the aligned length of an attribute: header, value and all children,
+// rounded up to the next multiple of NLA_ALIGNTO.
+func (attr *attribute) length() int {
+	n := unix.SizeofNlAttr + len(attr.value)
+
+	for _, child := range attr.children {
+		n += child.length()
+	}
+
+	return (n + unix.NLA_ALIGNTO - 1) & ^(unix.NLA_ALIGNTO - 1)
+}
+
+//
+// Network interface service module
+//
+
+// Interface info message
+type ifInfoMsg struct {
+	unix.IfInfomsg
+}
+
+// Creates a new interface info message with family AF_UNSPEC and all other
+// fields zero.
+func newIfInfoMsg() *ifInfoMsg {
+	return &ifInfoMsg{
+		IfInfomsg: unix.IfInfomsg{
+			Family: uint8(unix.AF_UNSPEC),
+		},
+	}
+}
+
+// Serializes an interface info message field by field in host byte order.
+func (ifInfo *ifInfoMsg) serialize() []byte {
+	b := make([]byte, ifInfo.length())
+	b[0] = ifInfo.Family
+	b[1] = 0 // Padding.
+	encoder.PutUint16(b[2:4], ifInfo.Type)
+	encoder.PutUint32(b[4:8], uint32(ifInfo.Index))
+	encoder.PutUint32(b[8:12], ifInfo.Flags)
+	encoder.PutUint32(b[12:16], ifInfo.Change)
+	return b
+}
+
+// Returns the length of an interface info message.
+func (ifInfo *ifInfoMsg) length() int {
+	return unix.SizeofIfInfomsg
+}
+
+//
+// IP address service module
+//
+
+// Interface address message
+type ifAddrMsg struct {
+	unix.IfAddrmsg
+}
+
+// Creates a new interface address message for the given address family.
+func newIfAddrMsg(family int) *ifAddrMsg {
+	return &ifAddrMsg{
+		IfAddrmsg: unix.IfAddrmsg{
+			Family: uint8(family),
+		},
+	}
+}
+
+// Serializes an interface address message field by field in host byte order.
+func (ifAddr *ifAddrMsg) serialize() []byte {
+	b := make([]byte, ifAddr.length())
+	b[0] = ifAddr.Family
+	b[1] = ifAddr.Prefixlen
+	b[2] = ifAddr.Flags
+	b[3] = ifAddr.Scope
+	encoder.PutUint32(b[4:8], ifAddr.Index)
+	return b
+}
+
+// Returns the length of an interface address message.
+func (ifAddr *ifAddrMsg) length() int {
+	return unix.SizeofIfAddrmsg
+}
diff --git a/netlink/socket.go b/netlink/socket.go
new file mode 100644
index 0000000000..e4ac6deedf
--- /dev/null
+++ b/netlink/socket.go
@@ -0,0 +1,135 @@
+// Copyright Microsoft Corp.
+// All rights reserved.
+
+package netlink
+
+import (
+	"fmt"
+	"sync"
+	"sync/atomic"
+	"syscall"
+
+	"golang.org/x/sys/unix"
+	"github.com/Azure/Aqua/log"
+)
+
+// Represents a netlink socket.
+type socket struct {
+	fd  int
+	sa  unix.SockaddrNetlink
+	pid uint32
+	seq uint32
+	sync.Mutex
+}
+
+// Default netlink socket, created on first use, together with the error (if
+// any) from that single creation attempt.
+var (
+	s         *socket
+	once      sync.Once
+	socketErr error
+)
+
+// Returns a reference to the default netlink socket.
+//
+// The socket is created exactly once. The creation error is kept in a
+// package variable so that every caller sees it; if it lived only in a local
+// variable, callers after the first would get a nil socket with a nil error.
+func getSocket() (*socket, error) {
+	once.Do(func() { s, socketErr = newSocket() })
+	return s, socketErr
+}
+
+// Creates a new netlink socket object.
+func newSocket() (*socket, error) {
+	var err error
+
+	// Log the final outcome. A closure is needed here: the arguments of a
+	// directly deferred log.Printf call are evaluated at the defer statement,
+	// so a later bind failure would be logged as err=<nil>.
+	defer func() { log.Printf("[netlink] Socket created, err=%v\n", err) }()
+
+	fd, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE)
+	if err != nil {
+		return nil, err
+	}
+
+	sock := &socket{
+		fd:  fd,
+		pid: uint32(unix.Getpid()),
+		seq: 0,
+	}
+
+	sock.sa.Family = unix.AF_NETLINK
+
+	err = unix.Bind(fd, &sock.sa)
+	if err != nil {
+		// Do not leak the descriptor when the socket cannot be bound.
+		unix.Close(fd)
+		return nil, err
+	}
+
+	return sock, nil
+}
+
+// Closes the socket.
+func (s *socket) close() {
+	err := unix.Close(s.fd)
+	log.Printf("[netlink] Socket closed, err=%v\n", err)
+}
+
+// Sends a netlink message after stamping it with the next sequence number.
+func (s *socket) send(msg *message) error {
+	msg.Seq = atomic.AddUint32(&s.seq, 1)
+	err := unix.Sendto(s.fd, msg.serialize(), 0, &s.sa)
+	log.Printf("[netlink] Sent %+v, err=%v\n", *msg, err)
+	return err
+}
+
+// Sends a netlink message and blocks until its completion. The socket lock is
+// held for the whole exchange, so exchanges on one socket do not interleave.
+func (s *socket) sendAndComplete(msg *message) error {
+	s.Lock()
+	defer s.Unlock()
+
+	err := s.send(msg)
+	if err != nil {
+		return err
+	}
+
+	return s.waitForAck(msg)
+}
+
+// Receives a netlink message. One read of at most a page is performed and
+// everything it returned is parsed into individual messages.
+func (s *socket) receive() ([]syscall.NetlinkMessage, error) {
+	buffer := make([]byte, unix.Getpagesize())
+	n, _, err := unix.Recvfrom(s.fd, buffer, 0)
+
+	if err != nil {
+		return nil, err
+	}
+
+	if n < unix.NLMSG_HDRLEN {
+		return nil, fmt.Errorf("Invalid netlink message")
+	}
+
+	buffer = buffer[:n]
+	return syscall.ParseNetlinkMessage(buffer)
+}
+
+// Waits for the acknowledgement for the given sent message.
+//
+// Returns nil for a positive acknowledgement, the kernel's errno for a failed
+// request, or the receive error. Messages that do not match the request are
+// logged and skipped.
+// NOTE(review): there is no timeout, and the match requires the reply's port
+// id to equal the process id stored in the request — confirm that this always
+// holds for the bound socket, otherwise this loop never returns.
+func (s *socket) waitForAck(sent *message) error {
+	for {
+		received, err := s.receive()
+		if err != nil {
+			log.Printf("[netlink] Receive err=%v\n", err)
+			return err
+		}
+
+		for _, msg := range received {
+			// An acknowledgement is an error message with error code set to
+			// zero, followed by the original request message header.
+			if msg.Header.Type != unix.NLMSG_ERROR ||
+				msg.Header.Seq != sent.Seq ||
+				msg.Header.Pid != sent.Pid {
+				log.Printf("[netlink] Ignoring unexpected message %+v\n", msg)
+				continue
+			}
+
+			// The payload must hold at least the 32-bit error code; slicing
+			// a shorter payload would panic.
+			if len(msg.Data) < 4 {
+				err = fmt.Errorf("Invalid netlink message")
+				log.Printf("[netlink] Received %+v, err=%v\n", msg, err)
+				return err
+			}
+
+			errCode := int32(encoder.Uint32(msg.Data[0:4]))
+			if errCode != 0 {
+				// Request failed; the kernel reports a negated errno.
+				err = syscall.Errno(-errCode)
+			}
+
+			log.Printf("[netlink] Received %+v, err=%v\n", msg, err)
+
+			return err
+		}
+	}
+}
diff --git a/network/plugin.go b/network/plugin.go
index 03e2f4aff5..01ab4ed3cb 100644
--- a/network/plugin.go
+++ b/network/plugin.go
@@ -284,10 +284,12 @@ func (plugin *netPlugin) createEndpoint(w http.ResponseWriter, r *http.Request)
 	// lets lock driver for now.. will optimize later
 	plugin.Lock()
 	if !plugin.networkExists(netID) {
+		plugin.Unlock()
 		plugin.listener.SendError(w, fmt.Sprintf("Could not find [networkID:%s]\n", netID))
 		return
 	}
 	if plugin.endpointExists(netID, endID) {
+		plugin.Unlock()
 		plugin.listener.SendError(w, fmt.Sprintf("Endpoint already exists [networkID:%s endpointID:%s]\n", netID, endID))
 		return
 	}
@@ -304,8 +306,8 @@ func (plugin *netPlugin) createEndpoint(w http.ResponseWriter, r *http.Request)
 		rGatewayIPv4, ermsg := core.GetTargetInterface(interfaceToAttach, ipaddressToAttach)
 	if ermsg != "" {
-		plugin.listener.SendError(w, ermsg)
 		plugin.Unlock()
+		plugin.listener.SendError(w, ermsg)
 		return
 	}