Skip to content

Commit

Permalink
optimzie udp process gorouties
Browse files Browse the repository at this point in the history
Signed-off-by: Asutorufa <16442314+Asutorufa@users.noreply.github.com>
  • Loading branch information
Asutorufa committed Jul 1, 2023
1 parent c17f3fa commit a5017da
Show file tree
Hide file tree
Showing 7 changed files with 226 additions and 89 deletions.
57 changes: 52 additions & 5 deletions pkg/components/inbound/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package inbound
import (
"context"
"fmt"
"runtime"
"time"

"github.com/Asutorufa/yuhaiin/pkg/log"
Expand All @@ -14,15 +15,48 @@ import (
var Timeout = time.Second * 20

type handler struct {
dialer proxy.Proxy
table *nat.Table
dialer proxy.Proxy
table *nat.Table
packetChan chan struct {
ctx context.Context
packet *proxy.Packet
}

doneCtx context.Context
cancelCtx func()
}

func NewHandler(dialer proxy.Proxy) *handler {
return &handler{
ctx, cancel := context.WithCancel(context.Background())
h := &handler{
dialer: dialer,
table: nat.NewTable(dialer),
packetChan: make(chan struct {
ctx context.Context
packet *proxy.Packet
}),
doneCtx: ctx,
cancelCtx: cancel,
}

procs := runtime.GOMAXPROCS(0)
if procs < 4 {
procs = 4
}
for i := 0; i < procs; i++ {
go func() {
for {
select {
case pack := <-h.packetChan:
h.packet(pack.ctx, pack.packet)
case <-h.doneCtx.Done():
return
}
}
}()
}

return h
}

func (s *handler) Stream(ctx context.Context, meta *proxy.StreamMeta) {
Expand Down Expand Up @@ -58,8 +92,18 @@ func (s *handler) stream(ctx context.Context, meta *proxy.StreamMeta) error {
relay.Relay(meta.Src, remote)
return nil
}

func (s *handler) Packet(ctx context.Context, pack *proxy.Packet) {
select {
case s.packetChan <- struct {
ctx context.Context
packet *proxy.Packet
}{ctx, pack}:

case <-s.doneCtx.Done():
}
}

func (s *handler) packet(ctx context.Context, pack *proxy.Packet) {
ctx, cancel := context.WithTimeout(ctx, Timeout)
defer cancel()

Expand All @@ -70,4 +114,7 @@ func (s *handler) Packet(ctx context.Context, pack *proxy.Packet) {
}
}

func (s *handler) Close() error { return s.table.Close() }
func (s *handler) Close() error {
s.cancelCtx()
return s.table.Close()
}
146 changes: 99 additions & 47 deletions pkg/net/dns/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,57 @@ type dnsServer struct {
resolver proxy.Resolver
listener net.PacketConn
tcpListener net.Listener

reqChan chan dnsRequest
doneCtx context.Context
cancelCtx func()
}

type dnsRequest struct {
payload []byte
writeBack func([]byte) error
}

func NewDnsServer(server string, process proxy.Resolver) proxy.DNSHandler {
d := &dnsServer{server: server, resolver: process}
ctx, cancel := context.WithCancel(context.TODO())

d := &dnsServer{
server: server,
resolver: process,
reqChan: make(chan dnsRequest),
doneCtx: ctx,
cancelCtx: cancel,
}

do := func(req dnsRequest) error {
ctx, cancel := context.WithTimeout(context.TODO(), time.Second*7)
defer cancel()

data, err := d.handle(ctx, req.payload)
if err != nil {
return err
}
return req.writeBack(data)
}

procs := runtime.GOMAXPROCS(0)
if procs < 4 {
procs = 4
}
for i := 0; i < procs; i++ {
go func() {
for {
select {
case <-ctx.Done():
return
case req := <-d.reqChan:
if err := do(req); err != nil {
log.Error("dns server handle failed", "err", err)
}
}
}
}()
}

if server == "" {
log.Warn("dns server is empty, skip to listen tcp and udp")
Expand All @@ -49,6 +96,8 @@ func NewDnsServer(server string, process proxy.Resolver) proxy.DNSHandler {
}

func (d *dnsServer) Close() error {
d.cancelCtx()

if d.listener != nil {
d.listener.Close()
}
Expand All @@ -67,35 +116,35 @@ func (d *dnsServer) startUDP() (err error) {

log.Info("new udp dns server", "host", d.server)

for i := 0; i < runtime.GOMAXPROCS(0); i++ {
go func() {
defer d.Close()
go func() {
defer d.Close()
for {
buf := pool.GetBytes(nat.MaxSegmentSize)
defer pool.PutBytes(buf)
for {
n, addr, err := d.listener.ReadFrom(buf)
if err != nil {
if e, ok := err.(net.Error); ok && e.Temporary() {
continue
}
log.Error("dns udp server handle failed", "err", err)
return
n, addr, err := d.listener.ReadFrom(buf)
if err != nil {
pool.PutBytes(buf)
if e, ok := err.(net.Error); ok && e.Temporary() {
continue
}
log.Error("dns udp server handle failed", "err", err)
return
}

ctx, cancel := context.WithTimeout(context.TODO(), time.Second*10)
defer cancel()
ctx, cancel := context.WithTimeout(context.TODO(), time.Second*10)
defer cancel()

data, err := d.handle(ctx, buf[:n])
if err != nil {
log.Error("dns server handle data failed", slog.Any("err", err))
} else {
if _, err = d.listener.WriteTo(data, addr); err != nil {
log.Error("write dns response to client failed", slog.Any("err", err))
}
err = d.Do(ctx, buf[:n], func(b []byte) error {
defer pool.PutBytes(buf)
if _, err = d.listener.WriteTo(b, addr); err != nil {
return fmt.Errorf("write dns response to client failed: %w", err)
}
return nil
})
if err != nil {
log.Error("dns server handle data failed", slog.Any("err", err))
}
}()
}
}
}()

return nil
}
Expand All @@ -122,10 +171,7 @@ func (d *dnsServer) startTCP() (err error) {
go func() {
defer conn.Close()

ctx, cancel := context.WithTimeout(context.TODO(), time.Second*10)
defer cancel()

if err := d.HandleTCP(ctx, conn); err != nil {
if err := d.HandleTCP(context.TODO(), conn); err != nil {
log.Error("handle dns tcp failed", "err", err)
}
}()
Expand All @@ -139,43 +185,49 @@ func (d *dnsServer) HandleTCP(ctx context.Context, c net.Conn) error {
}

data := pool.GetBytes(int(length))
defer pool.PutBytes(data)

n, err := io.ReadFull(c, data[:length])
if err != nil {
pool.PutBytes(data)
return fmt.Errorf("dns server read data failed: %w", err)
}

data, err = d.handle(ctx, data[:n])
if err != nil {
return fmt.Errorf("dns server handle failed: %w", err)
}

if err = binary.Write(c, binary.BigEndian, uint16(len(data))); err != nil {
return fmt.Errorf("dns server write length failed: %w", err)
}
_, err = c.Write(data)
return err
return d.Do(ctx, data[:n], func(b []byte) error {
defer pool.PutBytes(data)
if err = binary.Write(c, binary.BigEndian, uint16(len(b))); err != nil {
return fmt.Errorf("dns server write length failed: %w", err)
}
_, err = c.Write(b)
return err
})
}

func (d *dnsServer) HandleUDP(ctx context.Context, l net.PacketConn) error {
buf := pool.GetBytes(nat.MaxSegmentSize)
defer pool.PutBytes(buf)

n, addr, err := l.ReadFrom(buf)
if err != nil {
pool.PutBytes(buf)
return err
}

data, err := d.handle(ctx, buf[:n])
if err != nil {
return fmt.Errorf("dns server handle failed: %w", err)
}
_, err = l.WriteTo(data, addr)
return err
return d.Do(ctx, buf[:n], func(b []byte) error {
defer pool.PutBytes(buf)
_, err = l.WriteTo(b, addr)
return err
})
}

func (d *dnsServer) Do(ctx context.Context, b []byte) ([]byte, error) { return d.handle(ctx, b) }
func (d *dnsServer) Do(ctx context.Context, b []byte, writeBack func([]byte) error) error {
select {
case <-ctx.Done():
return ctx.Err()
case <-d.doneCtx.Done():
return io.EOF
case d.reqChan <- dnsRequest{b, writeBack}:
return nil
}
}

func (d *dnsServer) handle(ctx context.Context, raw []byte) ([]byte, error) {
var parse dnsmessage.Parser
Expand Down
10 changes: 5 additions & 5 deletions pkg/net/interfaces/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,14 @@ type DNSHandler interface {
Server
HandleUDP(context.Context, net.PacketConn) error
HandleTCP(context.Context, net.Conn) error
Do(context.Context, []byte) ([]byte, error)
Do(context.Context, []byte, func([]byte) error) error
}

var EmptyDNSServer DNSHandler = &emptyHandler{}

type emptyHandler struct{}

func (e *emptyHandler) Close() error { return nil }
func (e *emptyHandler) HandleUDP(context.Context, net.PacketConn) error { return io.EOF }
func (e *emptyHandler) HandleTCP(context.Context, net.Conn) error { return io.EOF }
func (e *emptyHandler) Do(context.Context, []byte) ([]byte, error) { return nil, io.EOF }
func (e *emptyHandler) Close() error { return nil }
func (e *emptyHandler) HandleUDP(context.Context, net.PacketConn) error { return io.EOF }
func (e *emptyHandler) HandleTCP(context.Context, net.Conn) error { return io.EOF }
func (e *emptyHandler) Do(context.Context, []byte, func([]byte) error) error { return io.EOF }
32 changes: 14 additions & 18 deletions pkg/net/proxy/socks5/server/udp.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"errors"
"fmt"
"net"
"runtime"

"github.com/Asutorufa/yuhaiin/pkg/log"
"github.com/Asutorufa/yuhaiin/pkg/net/dialer"
Expand All @@ -32,27 +31,24 @@ func (s *Socks5) newUDPServer(handler proxy.Handler) error {
u := &udpServer{PacketConn: l, handler: handler}
s.udpServer = u

for i := 0; i < runtime.GOMAXPROCS(0); i++ {
go func() {
defer s.Close()
go func() {
defer s.Close()

buf := pool.GetBytes(nat.MaxSegmentSize)
defer pool.PutBytes(buf)
buf := pool.GetBytes(nat.MaxSegmentSize)
defer pool.PutBytes(buf)

for {
n, src, err := u.PacketConn.ReadFrom(buf)
if err != nil {
log.Error("read udp request failed, stop socks5 server", slog.Any("err", err))
return
}

if err := u.handle(buf[:n], src); err != nil && !errors.Is(err, net.ErrClosed) {
log.Error("handle udp request failed", "err", err)
}
for {
n, src, err := u.PacketConn.ReadFrom(buf)
if err != nil {
log.Error("read udp request failed, stop socks5 server", slog.Any("err", err))
return
}

}()
}
if err := u.handle(buf[:n], src); err != nil && !errors.Is(err, net.ErrClosed) {
log.Error("handle udp request failed", "err", err)
}
}
}()

return nil
}
Expand Down
Loading

0 comments on commit a5017da

Please sign in to comment.