diff --git a/pkg/agent/multicast/mcast_controller_test.go b/pkg/agent/multicast/mcast_controller_test.go index 57e0bc1fab7..96d951e8de4 100644 --- a/pkg/agent/multicast/mcast_controller_test.go +++ b/pkg/agent/multicast/mcast_controller_test.go @@ -24,6 +24,9 @@ import ( "time" "antrea.io/libOpenflow/openflow13" + "antrea.io/libOpenflow/protocol" + "antrea.io/libOpenflow/util" + "antrea.io/ofnet/ofctrl" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" "k8s.io/apimachinery/pkg/util/sets" @@ -55,6 +58,8 @@ var ( } nodeIf1IP = net.ParseIP("192.168.20.22") externalInterfaceIP = net.ParseIP("192.168.50.23") + pktInSrcMAC, _ = net.ParseMAC("11:22:33:44:55:66") + pktInDstMAC, _ = net.ParseMAC("01:00:5e:00:00:16") ) func TestAddGroupMemberStatus(t *testing.T) { @@ -241,6 +246,62 @@ func TestClearStaleGroups(t *testing.T) { } } +func TestProcessPacketIn(t *testing.T) { + mockController := newMockMulticastController(t) + snooper := mockController.igmpSnooper + stopCh := make(chan struct{}) + defer close(stopCh) + go mockController.eventHandler(stopCh) + + getIPs := func(ipStrs []string) []net.IP { + ips := make([]net.IP, len(ipStrs)) + for i := range ipStrs { + ips[i] = net.ParseIP(ipStrs[i]) + } + return ips + } + for _, tc := range []struct { + iface *interfacestore.InterfaceConfig + version uint8 + joinedGroups sets.String + leftGroups sets.String + }{ + { + iface: createInterface("p1", 1), + joinedGroups: sets.NewString("224.1.101.2", "224.1.101.3", "224.1.101.4"), + leftGroups: sets.NewString("224.1.101.2", "224.1.101.4"), + version: 1, + }, + { + iface: createInterface("p2", 2), + joinedGroups: sets.NewString("224.1.102.2", "224.1.102.3", "224.1.102.4"), + leftGroups: sets.NewString("224.1.102.3"), + version: 2, + }, + { + iface: createInterface("p3", 3), + joinedGroups: sets.NewString("224.1.103.2", "224.1.103.3", "224.1.103.4"), + leftGroups: sets.NewString("224.1.103.2"), + version: 3, + }, + } { + packets := createIGMPReportPacketIn(getIPs(tc.joinedGroups.List()), getIPs(tc.leftGroups.List()), tc.version, uint32(tc.iface.OFPort)) + mockOFClient.EXPECT().SendIGMPQueryPacketOut(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() + for _, pkt := range packets { + mockIfaceStore.EXPECT().GetInterfaceByOFPort(uint32(tc.iface.OFPort)).Return(tc.iface, true) + err := snooper.processPacketIn(&pkt) + assert.Nil(t, err) + } + time.Sleep(time.Second) + expGroups := tc.joinedGroups.Difference(tc.leftGroups) + statues := mockController.getGroupMemberStatusesByPod(tc.iface.InterfaceName) + assert.Equal(t, expGroups.Len(), len(statues)) + for _, s := range statues { + assert.True(t, expGroups.Has(s.group.String())) + } + } +} + func compareGroupStatus(t *testing.T, cache cache.Indexer, event *mcastGroupEvent) { obj, exits, err := cache.GetByKey(event.group.String()) assert.Nil(t, err) @@ -280,3 +341,90 @@ func (c *Controller) initialize(t *testing.T) error { mockMulticastSocket.EXPECT().AllocateVIFs(gomock.Any(), uint16(1)).Times(1).Return([]uint16{1, 2}, nil) return c.Initialize() } + +func createInterface(name string, ofport uint32) *interfacestore.InterfaceConfig { + return &interfacestore.InterfaceConfig{ + InterfaceName: name, + OVSPortConfig: &interfacestore.OVSPortConfig{ + OFPort: int32(ofport), + }, + } +} + +func createIGMPReportPacketIn(joinedGroups []net.IP, leftGroups []net.IP, version uint8, ofport uint32) []ofctrl.PacketIn { + joinMessages := createIGMPJoinMessage(joinedGroups, version) + leftMessages := createIGMPLeaveMessage(leftGroups, version) + generatePacket := func(m util.Message) ofctrl.PacketIn { + pkt := openflow13.NewPacketIn() + matchInport := openflow13.NewInPortField(ofport) + pkt.Match.AddField(*matchInport) + ipPacket := &protocol.IPv4{ + Version: 0x4, + IHL: 5, + Protocol: IGMPProtocolNumber, + Length: 20 + m.Len(), + Data: m, + } + pkt.Data = protocol.Ethernet{ + HWDst: pktInDstMAC, + HWSrc: pktInSrcMAC, + Ethertype: protocol.IPv4_MSG, + Data: ipPacket, + } + return ofctrl.PacketIn(*pkt) + } + pkts := make([]ofctrl.PacketIn, 0) + for _, m := range joinMessages { + pkt := generatePacket(m) + pkts = append(pkts, pkt) + } + for _, m := range leftMessages { + pkt := generatePacket(m) + pkts = append(pkts, pkt) + } + return pkts +} + +func createIGMPLeaveMessage(groups []net.IP, version uint8) []util.Message { + pkts := make([]util.Message, 0) + switch version { + case 1: + for i := range groups { + pkts = append(pkts, protocol.NewIGMPv2Leave(groups[i])) + } + return pkts + case 2: + for i := range groups { + pkts = append(pkts, protocol.NewIGMPv2Leave(groups[i])) + } + return pkts + case 3: + records := make([]protocol.IGMPv3GroupRecord, 0) + for _, g := range groups { + records = append(records, protocol.NewGroupRecord(protocol.IGMPIsIn, g, nil)) + } + pkts = append(pkts, protocol.NewIGMPv3Report(records)) + } + return pkts +} + +func createIGMPJoinMessage(groups []net.IP, version uint8) []util.Message { + pkts := make([]util.Message, 0) + switch version { + case 1: + for i := range groups { + pkts = append(pkts, protocol.NewIGMPv1Report(groups[i])) + } + case 2: + for i := range groups { + pkts = append(pkts, protocol.NewIGMPv2Report(groups[i])) + } + case 3: + records := make([]protocol.IGMPv3GroupRecord, 0) + for _, g := range groups { + records = append(records, protocol.NewGroupRecord(protocol.IGMPIsEx, g, nil)) + } + pkts = append(pkts, protocol.NewIGMPv3Report(records)) + } + return pkts +} diff --git a/pkg/agent/multicast/mcast_discovery.go b/pkg/agent/multicast/mcast_discovery.go index 63096fbabd2..9184c41e2d0 100644 --- a/pkg/agent/multicast/mcast_discovery.go +++ b/pkg/agent/multicast/mcast_discovery.go @@ -127,7 +127,7 @@ func (s *IGMPSnooper) processPacketIn(pktIn *ofctrl.PacketIn) error { fallthrough case protocol.IGMPv2Report: mgroup := igmp.(*protocol.IGMPv1or2).GroupAddress - klog.InfoS("Received IGMPv1or2 Report message", "group", mgroup.String(), "interface", iface.PodName) + klog.InfoS("Received IGMPv1or2 Report message", "group", mgroup.String(), "interface", iface.InterfaceName) event := &mcastGroupEvent{ group: mgroup, eType: groupJoin, @@ -139,10 +139,14 @@ func (s *IGMPSnooper) processPacketIn(pktIn *ofctrl.PacketIn) error { msg := igmp.(*protocol.IGMPv3MembershipReport) for _, gr := range msg.GroupRecords { mgroup := gr.MulticastAddress - klog.InfoS("Received IGMPv3 Report message", "group", mgroup.String(), "interface", iface.PodName) + klog.InfoS("Received IGMPv3 Report message", "group", mgroup.String(), "interface", iface.InterfaceName, "recordType", gr.Type, "sourceCount", gr.NumberOfSources) + evtType := groupJoin + if gr.Type == protocol.IGMPIsIn && gr.NumberOfSources == 0 { + evtType = groupLeave + } event := &mcastGroupEvent{ group: mgroup, - eType: groupJoin, + eType: evtType, time: now, iface: iface, } @@ -151,7 +155,7 @@ func (s *IGMPSnooper) processPacketIn(pktIn *ofctrl.PacketIn) error { case protocol.IGMPv2LeaveGroup: mgroup := igmp.(*protocol.IGMPv1or2).GroupAddress - klog.InfoS("Received IGMPv2 Leave message", "group", mgroup.String(), "interface", iface.PodName) + klog.InfoS("Received IGMPv2 Leave message", "group", mgroup.String(), "interface", iface.InterfaceName) event := &mcastGroupEvent{ group: mgroup, eType: groupLeave,