diff --git a/decode.go b/decode.go index 99863c9..4db19c0 100644 --- a/decode.go +++ b/decode.go @@ -71,10 +71,17 @@ func decodemac(pkt []byte) uint64 { // Decode decodes the headers of a Packet. func (p *Packet) Decode() { + if len(p.Data) <= 14 { + return + } + p.Type = int(binary.BigEndian.Uint16(p.Data[12:14])) p.DestMac = decodemac(p.Data[0:6]) p.SrcMac = decodemac(p.Data[6:12]) - p.Payload = p.Data[14:] + + if len(p.Data) >= 15 { + p.Payload = p.Data[14:] + } switch p.Type { case TYPE_IP: @@ -163,6 +170,10 @@ func (arp *Arphdr) String() (s string) { } func (p *Packet) decodeArp() { + if len(p.Payload) < 8 { + return + } + pkt := p.Payload arp := new(Arphdr) arp.Addrtype = binary.BigEndian.Uint16(pkt[0:2]) @@ -170,13 +181,20 @@ func (p *Packet) decodeArp() { arp.HwAddressSize = pkt[4] arp.ProtAddressSize = pkt[5] arp.Operation = binary.BigEndian.Uint16(pkt[6:8]) + + if len(pkt) < int(8+2*arp.HwAddressSize+2*arp.ProtAddressSize) { + return + } arp.SourceHwAddress = pkt[8 : 8+arp.HwAddressSize] arp.SourceProtAddress = pkt[8+arp.HwAddressSize : 8+arp.HwAddressSize+arp.ProtAddressSize] arp.DestHwAddress = pkt[8+arp.HwAddressSize+arp.ProtAddressSize : 8+2*arp.HwAddressSize+arp.ProtAddressSize] arp.DestProtAddress = pkt[8+2*arp.HwAddressSize+arp.ProtAddressSize : 8+2*arp.HwAddressSize+2*arp.ProtAddressSize] p.Headers = append(p.Headers, arp) - p.Payload = p.Payload[8+2*arp.HwAddressSize+2*arp.ProtAddressSize:] + + if len(pkt) >= int(8+2*arp.HwAddressSize+2*arp.ProtAddressSize) { + p.Payload = p.Payload[8+2*arp.HwAddressSize+2*arp.ProtAddressSize:] + } } // IPadr is the header of an IP packet. @@ -196,6 +214,10 @@ type Iphdr struct { } func (p *Packet) decodeIp() { + if len(p.Payload) < 20 { + return + } + pkt := p.Payload ip := new(Iphdr) @@ -212,11 +234,18 @@ func (p *Packet) decodeIp() { ip.Checksum = binary.BigEndian.Uint16(pkt[10:12]) ip.SrcIp = pkt[12:16] ip.DestIp = pkt[16:20] + pEnd := int(ip.Length) if pEnd > len(pkt) { pEnd = len(pkt) } - p.Payload = pkt[ip.Ihl*4 : pEnd] + + if len(pkt) >= pEnd && int(ip.Ihl*4) < pEnd { + p.Payload = pkt[ip.Ihl*4 : pEnd] + } else { + p.Payload = []byte{} + } + p.Headers = append(p.Headers, ip) p.IP = ip @@ -250,12 +279,20 @@ func (v *Vlanhdr) String() { func (p *Packet) decodeVlan() { pkt := p.Payload vlan := new(Vlanhdr) + if len(pkt) < 4 { + return + } + vlan.Priority = (pkt[2] & 0xE0) >> 13 vlan.DropEligible = pkt[2]&0x10 != 0 vlan.VlanIdentifier = int(binary.BigEndian.Uint16(pkt[:2])) & 0x0FFF vlan.Type = int(binary.BigEndian.Uint16(p.Payload[2:4])) p.Headers = append(p.Headers, vlan) - p.Payload = p.Payload[4:] + + if len(pkt) >= 5 { + p.Payload = p.Payload[4:] + } + switch vlan.Type { case TYPE_IP: p.decodeIp() @@ -292,6 +329,10 @@ const ( ) func (p *Packet) decodeTcp() { + if len(p.Payload) < 20 { + return + } + pkt := p.Payload tcp := new(Tcphdr) tcp.SrcPort = binary.BigEndian.Uint16(pkt[0:2]) @@ -303,7 +344,9 @@ func (p *Packet) decodeTcp() { tcp.Window = binary.BigEndian.Uint16(pkt[14:16]) tcp.Checksum = binary.BigEndian.Uint16(pkt[16:18]) tcp.Urgent = binary.BigEndian.Uint16(pkt[18:20]) - p.Payload = pkt[tcp.DataOffset*4:] + if len(pkt) >= int(tcp.DataOffset*4) { + p.Payload = pkt[tcp.DataOffset*4:] + } p.Headers = append(p.Headers, tcp) p.TCP = tcp } @@ -354,6 +397,10 @@ type Udphdr struct { } func (p *Packet) decodeUdp() { + if len(p.Payload) < 8 { + return + } + pkt := p.Payload udp := new(Udphdr) udp.SrcPort = binary.BigEndian.Uint16(pkt[0:2]) @@ -362,7 +409,9 @@ func (p *Packet) decodeUdp() { udp.Checksum = binary.BigEndian.Uint16(pkt[6:8]) p.Headers = append(p.Headers, udp) p.UDP = udp - p.Payload = pkt[8:] + if len(p.Payload) >= 8 { + p.Payload = pkt[8:] + } } func (udp *Udphdr) String(hdr addrHdr) string { @@ -381,6 +430,10 @@ type Icmphdr struct { } func (p *Packet) decodeIcmp() *Icmphdr { + if len(p.Payload) < 8 { + return nil + } + pkt := p.Payload icmp := new(Icmphdr) icmp.Type = pkt[0] @@ -436,6 +489,10 @@ type Ip6hdr struct { } func (p *Packet) decodeIp6() { + if len(p.Payload) < 40 { + return + } + pkt := p.Payload ip6 := new(Ip6hdr) ip6.Version = uint8(pkt[0]) >> 4 @@ -446,7 +503,11 @@ func (p *Packet) decodeIp6() { ip6.HopLimit = pkt[7] ip6.SrcIp = pkt[8:24] ip6.DestIp = pkt[24:40] - p.Payload = pkt[40:] + + if len(p.Payload) >= 40 { + p.Payload = pkt[40:] + } + p.Headers = append(p.Headers, ip6) switch ip6.NextHeader { diff --git a/decode_test.go b/decode_test.go index 558fab1..7328c4f 100644 --- a/decode_test.go +++ b/decode_test.go @@ -3,6 +3,7 @@ package pcap import ( "bytes" "testing" + "time" ) var testSimpleTcpPacket *Packet = &Packet{ @@ -192,3 +193,55 @@ func TestDecodeVlanPacket(t *testing.T) { t.Errorf("Third header isn't TCP: %q", p.Headers[2]) } } + +func TestDecodeFuzzFallout(t *testing.T) { + testData := []struct { + Data []byte + }{ + {[]byte("000000000000\x81\x000")}, + {[]byte("000000000000\x81\x00000")}, + {[]byte("000000000000\x86\xdd0")}, + {[]byte("000000000000\b\x000")}, + {[]byte("000000000000\b\x060")}, + {[]byte{}}, + {[]byte("000000000000\b\x0600000000")}, + {[]byte("000000000000\x86\xdd000000\x01000000000000000000000000000000000")}, + {[]byte("000000000000\x81\x0000\b\x0600000000")}, + {[]byte("000000000000\b\x00n0000000000000000000")}, + {[]byte("000000000000\x86\xdd000000\x0100000000000000000000000000000000000")}, + {[]byte("000000000000\x81\x0000\b\x00g0000000000000000000")}, + //{[]byte()}, + {[]byte("000000000000\b\x00400000000\x110000000000")}, + {[]byte("0nMØ¡\xfe\x13\x13\x81\x00gr\b\x00&x\xc9\xe5b'\x1e0\x00\x04\x00\x0020596224")}, + {[]byte("000000000000\x81\x0000\b\x00400000000\x110000000000")}, + {[]byte("000000000000\b\x00000000000\x0600\xff0000000")}, + {[]byte("000000000000\x86\xdd000000\x06000000000000000000000000000000000")}, + {[]byte("000000000000\x81\x0000\b\x00000000000\x0600b0000000")}, + {[]byte("000000000000\x81\x0000\b\x00400000000\x060000000000")}, + {[]byte("000000000000\x86\xdd000000\x11000000000000000000000000000000000")}, + {[]byte("000000000000\x86\xdd000000\x0600000000000000000000000000000000000000000000M")}, + {[]byte("000000000000\b\x00500000000\x0600000000000")}, + {[]byte("0nM\xd80\xfe\x13\x13\x81\x00gr\b\x00&x\xc9\xe5b'\x1e0\x00\x04\x00\x0020596224")}, + } + + for _, entry := range testData { + pkt := &Packet{ + Time: time.Now(), + Caplen: uint32(len(entry.Data)), + Len: uint32(len(entry.Data)), + Data: entry.Data, + } + + pkt.Decode() + /* + func() { + defer func() { + if err := recover(); err != nil { + t.Fatalf("%d. %q failed: %v", idx, string(entry.Data), err) + } + }() + pkt.Decode() + }() + */ + } +}