Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
193 changes: 165 additions & 28 deletions cmd/udpProxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,44 @@ import (
"fmt"
"log"
"net"
"os"
"sync"
"sync/atomic"
"time"

"github.com/cilium/ebpf"
"github.com/txthinking/socks5"
)

const udpReadTimeout = 5 * time.Second
const (
udpSessionIdleTimeout = 60 * time.Second
udpSessionSendQueue = 256
udpDialTimeout = 5 * time.Second
)

const (
localDNSStubAddr = "127.0.0.53:53"
publicDNSAddr = "1.1.1.1:53"
)

type udpSessionManager struct {
proxyConn *net.UDPConn
udpMap *ebpf.Map
mirror *MirrorDispatcher
sessions sync.Map
}

type udpSession struct {
mgr *udpSessionManager
key string
clientAddr *net.UDPAddr
targetAddr string
remote net.Conn
sendCh chan []byte
done chan struct{}
lastActive atomic.Int64
closeOnce sync.Once
}

// StartUDPProxy listens on addr (UDP) and forwards packets to original destinations
// looked up from the BPF map (key = client addr, value = original dst ip:port).
func StartUDPProxy(addr string, udpMap *ebpf.Map, mirror *MirrorDispatcher) {
Expand All @@ -38,6 +62,12 @@ func StartUDPProxy(addr string, udpMap *ebpf.Map, mirror *MirrorDispatcher) {
}
defer conn.Close()

mgr := &udpSessionManager{
proxyConn: conn,
udpMap: udpMap,
mirror: mirror,
}

buf := make([]byte, 64*1024)
for {
n, clientAddr, err := conn.ReadFromUDP(buf)
Expand All @@ -50,58 +80,165 @@ func StartUDPProxy(addr string, udpMap *ebpf.Map, mirror *MirrorDispatcher) {
}
payload := make([]byte, n)
copy(payload, buf[:n])
go handleUDPPacket(conn, clientAddr, payload, udpMap, mirror)
mgr.dispatch(clientAddr, payload)
}
}

func handleUDPPacket(proxyConn *net.UDPConn, clientAddr *net.UDPAddr, payload []byte, udpMap *ebpf.Map, mirror *MirrorDispatcher) {
targetAddr, err := getUDPOriginalDest(clientAddr, udpMap)
func (m *udpSessionManager) dispatch(clientAddr *net.UDPAddr, payload []byte) {
key := clientAddr.String()
if v, ok := m.sessions.Load(key); ok {
v.(*udpSession).tryEnqueue(payload)
return
}

sess, err := m.createSession(clientAddr)
if err != nil {
log.Printf("UDP proxy: lookup original dest for %s: %v", clientAddr, err)
return
}

actual, loaded := m.sessions.LoadOrStore(key, sess)
if loaded {
sess.close()
actual.(*udpSession).tryEnqueue(payload)
return
}

go sess.run()
sess.tryEnqueue(payload)
}

func (m *udpSessionManager) createSession(clientAddr *net.UDPAddr) (*udpSession, error) {
targetAddr, err := getUDPOriginalDest(clientAddr, m.udpMap)
if err != nil {
log.Printf("UDP proxy: lookup original dest for %s: %v", clientAddr, err)
return nil, err
}
targetAddr = maybeRewriteLocalDNSStub(targetAddr)
fmt.Printf("UDP Original destination: %s\n", targetAddr)
log.Printf("UDP Original destination: %s", targetAddr)

var remoteConn net.Conn
if socks5ProxyAddr == "" {
remoteConn, err = net.DialTimeout("udp", targetAddr, 5*time.Second)
remoteConn, err = net.DialTimeout("udp", targetAddr, udpDialTimeout)
} else {
remoteConn, err = dialUDPViaSOCKS5(targetAddr)
}
if err != nil {
log.Printf("UDP proxy: dial %s: %v", targetAddr, err)
return
return nil, err
}
defer remoteConn.Close()

_, err = remoteConn.Write(payload)
if err != nil {
log.Printf("UDP proxy: write to %s: %v", targetAddr, err)
return
sess := &udpSession{
mgr: m,
key: clientAddr.String(),
clientAddr: clientAddr,
targetAddr: targetAddr,
remote: remoteConn,
sendCh: make(chan []byte, udpSessionSendQueue),
done: make(chan struct{}),
}
if mirror != nil && mirror.ShouldMirror("udp") {
mirror.Enqueue("udp", payload)
sess.touch()
return sess, nil
}

func (m *udpSessionManager) removeSession(key string) {
m.sessions.Delete(key)
}

func (s *udpSession) touch() {
s.lastActive.Store(time.Now().UnixNano())
}

func (s *udpSession) idleExpired() bool {
last := time.Unix(0, s.lastActive.Load())
return time.Since(last) >= udpSessionIdleTimeout
}

func (s *udpSession) tryEnqueue(payload []byte) {
select {
case <-s.done:
log.Printf("UDP proxy: session %s already closed, dropping packet", s.key)
case s.sendCh <- payload:
s.touch()
default:
log.Printf("UDP proxy: session %s send queue full (%d), dropping packet", s.key, udpSessionSendQueue)
}
}

remoteConn.SetReadDeadline(time.Now().Add(udpReadTimeout))
respBuf := make([]byte, 64*1024)
m, err := remoteConn.Read(respBuf)
if err != nil {
if errors.Is(err, os.ErrDeadlineExceeded) {
func (s *udpSession) close() {
s.closeOnce.Do(func() {
close(s.done)
s.remote.Close()
})
}

func (s *udpSession) run() {
defer s.mgr.removeSession(s.key)
defer s.close()

go s.writeLoop()

buf := make([]byte, 64*1024)
for {
if s.idleExpired() {
return
}
if ne, ok := err.(net.Error); ok && ne.Timeout() {

remaining := udpSessionIdleTimeout - time.Since(time.Unix(0, s.lastActive.Load()))
readDeadline := time.Second
if remaining < readDeadline {
readDeadline = remaining
}
if readDeadline <= 0 {
return
}

s.remote.SetReadDeadline(time.Now().Add(readDeadline))
n, err := s.remote.Read(buf)
if n > 0 {
if _, werr := s.mgr.proxyConn.WriteToUDP(buf[:n], s.clientAddr); werr != nil {
log.Printf("UDP proxy: write back to client %s: %v", s.clientAddr, werr)
return
}
s.touch()
continue
}
if err != nil {
if isUDPReadTimeout(err) {
continue
}
if !errors.Is(err, net.ErrClosed) {
log.Printf("UDP proxy: read from %s: %v", s.targetAddr, err)
}
return
}
log.Printf("UDP proxy: read from %s: %v", targetAddr, err)
return
}
_, err = proxyConn.WriteToUDP(respBuf[:m], clientAddr)
if err != nil {
log.Printf("UDP proxy: write back to client %s: %v", clientAddr, err)
return
}

func (s *udpSession) writeLoop() {
for {
select {
case <-s.done:
return
case payload := <-s.sendCh:
if _, err := s.remote.Write(payload); err != nil {
log.Printf("UDP proxy: write to %s: %v", s.targetAddr, err)
s.close()
return
}
if s.mgr.mirror != nil && s.mgr.mirror.ShouldMirror("udp") {
s.mgr.mirror.Enqueue("udp", payload)
}
s.touch()
}
}
}

func isUDPReadTimeout(err error) bool {
if err == nil {
return false
}
var ne net.Error
return errors.As(err, &ne) && ne.Timeout()
}

func maybeRewriteLocalDNSStub(targetAddr string) string {
Expand Down
2 changes: 1 addition & 1 deletion cmd/version.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
package main

var Version = "0.9.2"
var Version = "0.9.3"
Loading