forked from weaveworks/weave
/
udp_sender.go
138 lines (123 loc) · 3.29 KB
/
udp_sender.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
package router
import (
"code.google.com/p/gopacket"
"code.google.com/p/gopacket/layers"
"log"
"net"
"syscall"
)
// UDPSender is the transport a connection uses to push encoded payloads to
// its peer as UDP datagrams. SimpleUDPSender and RawUDPSender implement it.
type UDPSender interface {
	// Send transmits msg as the payload of a single UDP datagram.
	Send(msg []byte) error
	// Shutdown releases whatever resources the sender holds.
	Shutdown() error
}
// SimpleUDPSender sends datagrams through an existing *net.UDPConn to the
// remote UDP address of conn. The kernel builds the IP and UDP headers.
// When built by NewSimpleUDPSender, udpConn is the router's UDPListener, so
// this sender does not own the socket.
type SimpleUDPSender struct {
	// conn supplies the destination via RemoteUDPAddr on every Send.
	conn *LocalConnection
	// udpConn is the socket datagrams are written to.
	udpConn *net.UDPConn
}
// RawUDPSender builds UDP headers itself (via gopacket) and writes the
// resulting datagrams to a raw IPv4 socket obtained from dialIP. It owns
// that socket; Shutdown closes it. Because dialIP sets DF on the socket,
// oversized sends are reported as MsgTooBigError carrying the kernel's
// path MTU.
type RawUDPSender struct {
	// ipBuf is reused by every Send to serialize header + payload.
	ipBuf gopacket.SerializeBuffer
	// opts are the serialization options; NewRawUDPSender sets FixLengths
	// and leaves ComputeChecksums off.
	opts gopacket.SerializeOptions
	// udpHeader is reused by every Send: SrcPort is fixed at construction,
	// DstPort is overwritten per send.
	udpHeader *layers.UDP
	// socket is the raw IP socket; Shutdown closes it and sets it to nil.
	socket *net.IPConn
	// conn supplies the destination UDP port via RemoteUDPAddr.
	conn *LocalConnection
}
// MsgTooBigError is returned by RawUDPSender.Send when the kernel rejects a
// packet with EMSGSIZE, i.e. the packet exceeds the path MTU it knows about.
// NOTE(review): Send returns this value as an error, so an Error method is
// presumably defined elsewhere in the package (it is not visible here).
type MsgTooBigError struct {
	PMTU int // path MTU reported by the kernel via getsockopt(IP_MTU) on the sending socket
}
// NewSimpleUDPSender returns a sender for conn that writes through the
// router's UDPListener socket rather than opening a socket of its own.
func NewSimpleUDPSender(conn *LocalConnection) *SimpleUDPSender {
	listener := conn.Router.UDPListener
	sender := &SimpleUDPSender{conn: conn}
	sender.udpConn = listener
	return sender
}
// Send writes msg as one datagram to the connection's current remote UDP
// address. The number of bytes written is discarded; only the error is
// reported.
func (sender *SimpleUDPSender) Send(msg []byte) error {
	dst := sender.conn.RemoteUDPAddr()
	if _, err := sender.udpConn.WriteToUDP(msg, dst); err != nil {
		return err
	}
	return nil
}
// Shutdown does nothing and always succeeds: the socket this sender writes
// to is not owned by it (NewSimpleUDPSender borrows the router's
// UDPListener), so it must not be closed here.
func (sender *SimpleUDPSender) Shutdown() error { return nil }
// NewRawUDPSender opens a raw IPv4 socket towards conn's TCP peer (see
// dialIP) and prepares the reusable UDP header and serialization buffer
// used by Send. The UDP source port is the package-level Port.
// Any error from dialIP is returned unchanged.
func NewRawUDPSender(conn *LocalConnection) (*RawUDPSender, error) {
	socket, err := dialIP(conn)
	if err != nil {
		return nil, err
	}
	sender := &RawUDPSender{
		ipBuf:     gopacket.NewSerializeBuffer(),
		udpHeader: &layers.UDP{SrcPort: layers.UDPPort(Port)},
		socket:    socket,
		conn:      conn,
	}
	// Let gopacket fill in length fields, but skip checksums: the UDP
	// checksum is computed over an IP pseudo-header, which we don't build
	// here. That is fine for UDP over IPv4, where the checksum is optional,
	// but not for IPv6, where it is mandatory. TODO
	sender.opts = gopacket.SerializeOptions{
		FixLengths:       true,
		ComputeChecksums: false,
	}
	return sender, nil
}
// Send wraps msg in a UDP header addressed to the connection's current
// remote UDP port and writes the datagram through the raw IP socket.
//
// The socket has DF set (see dialIP), so a packet larger than the path MTU
// known to the kernel fails with EMSGSIZE. In that case Send logs the
// sizes, reads the kernel's current path MTU with getsockopt(IP_MTU) and
// returns it as a MsgTooBigError. Any other error is returned unchanged.
//
// NOTE(review): Send overwrites sender.udpHeader.DstPort and reuses
// sender.ipBuf without locking, so callers presumably serialize calls on a
// given sender — confirm.
func (sender *RawUDPSender) Send(msg []byte) error {
	header := sender.udpHeader
	header.DstPort = layers.UDPPort(sender.conn.RemoteUDPAddr().Port)
	payload := gopacket.Payload(msg)
	if err := gopacket.SerializeLayers(sender.ipBuf, sender.opts, header, &payload); err != nil {
		return err
	}
	packet := sender.ipBuf.Bytes()

	_, writeErr := sender.socket.Write(packet)
	switch {
	case writeErr == nil:
		return nil
	case PosixError(writeErr) != syscall.EMSGSIZE:
		return writeErr
	}

	// EMSGSIZE: ask the kernel what it now believes the path MTU to be.
	// File returns a duplicate descriptor, which must be closed separately.
	f, err := sender.socket.File()
	if err != nil {
		return err
	}
	defer f.Close()
	fd := int(f.Fd())
	log.Println("EMSGSIZE on send, expecting PMTU update (IP packet was",
		len(packet), "bytes, payload was", len(msg), "bytes)")
	pmtu, err := syscall.GetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_MTU)
	if err != nil {
		return err
	}
	return MsgTooBigError{PMTU: pmtu}
}
// Shutdown closes the raw IP socket and clears sender.socket, returning the
// error from Close.
//
// Shutdown is idempotent: once the socket has been cleared, further calls
// return nil. Previously a second call invoked Close on a nil *net.IPConn,
// which panics.
func (sender *RawUDPSender) Shutdown() error {
	socket := sender.socket
	if socket == nil {
		return nil
	}
	sender.socket = nil
	return socket.Close()
}
// dialIP opens a raw IPv4 socket for protocol UDP ("ip4:UDP"). The socket is
// bound to the local IP address of conn's TCP connection and connected to
// its remote IP address.
//
// It sets IP_MTU_DISCOVER to IP_PMTUDISC_DO on the socket, so every packet
// sent carries DF. A send larger than the known path MTU then fails with
// EMSGSIZE instead of being fragmented.
//
// If anything fails after the socket has been opened, the socket is closed
// before returning, so the descriptor is not leaked.
func dialIP(conn *LocalConnection) (*net.IPConn, error) {
	ipLocalAddr, err := ipAddr(conn.TCPConn.LocalAddr())
	if err != nil {
		return nil, err
	}
	ipRemoteAddr, err := ipAddr(conn.TCPConn.RemoteAddr())
	if err != nil {
		return nil, err
	}
	ipSocket, err := net.DialIP("ip4:UDP", ipLocalAddr, ipRemoteAddr)
	if err != nil {
		return nil, err
	}
	// File returns a duplicate descriptor; closing f leaves ipSocket open.
	f, err := ipSocket.File()
	if err != nil {
		ipSocket.Close() // best effort: the File error is the one to report
		return nil, err
	}
	defer f.Close()
	fd := int(f.Fd())
	// This makes sure all packets we send out have DF set on them.
	err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_MTU_DISCOVER, syscall.IP_PMTUDISC_DO)
	if err != nil {
		ipSocket.Close() // best effort: the setsockopt error is the one to report
		return nil, err
	}
	return ipSocket, nil
}
// ipAddr takes the host part of addr's "host:port" string form (as produced
// by *net.TCPAddr and *net.UDPAddr) and returns it as a *net.IPAddr, with
// any IPv6 zone ("fe80::1%eth0") moved into the Zone field.
//
// An empty host (e.g. ":80") yields an IPAddr with a nil IP, meaning
// "unspecified". An error is returned if addr's string form has no port, or
// if the host is non-empty but not a literal IP address. Previously such a
// host was silently turned into a nil IP.
func ipAddr(addr net.Addr) (*net.IPAddr, error) {
	host, _, err := net.SplitHostPort(addr.String())
	if err != nil {
		return nil, err
	}
	// SplitHostPort leaves an IPv6 zone attached to the host, and
	// net.ParseIP rejects zoned literals, so split at the last '%'.
	zone := ""
	for i := len(host) - 1; i >= 0; i-- {
		if host[i] == '%' {
			host, zone = host[:i], host[i+1:]
			break
		}
	}
	ip := net.ParseIP(host)
	if ip == nil && host != "" {
		return nil, &net.ParseError{Type: "IP address", Text: host}
	}
	return &net.IPAddr{IP: ip, Zone: zone}, nil
}