Skip to content

Commit

Permalink
chore: dns outbound support tcp
Browse files Browse the repository at this point in the history
  • Loading branch information
wwqgtxx committed Mar 7, 2024
1 parent 0488676 commit fad1a08
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 23 deletions.
55 changes: 36 additions & 19 deletions adapter/outbound/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ package outbound

import (
"context"
"fmt"
"net"
"time"

N "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/common/pool"
"github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/resolver"
Expand All @@ -24,7 +24,9 @@ type DnsOption struct {

// DialContext implements C.ProxyAdapter
func (d *Dns) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.Conn, error) {
return nil, fmt.Errorf("dns outbound does not support tcp")
left, right := N.Pipe()
go resolver.RelayDnsConn(context.Background(), right, 0)
return NewConn(left, d), nil
}

// ListenPacketContext implements C.ProxyAdapter
Expand Down Expand Up @@ -76,29 +78,44 @@ func (d *dnsPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
}

func (d *dnsPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
select {
case <-d.ctx.Done():
return 0, net.ErrClosed
default:
}

if len(p) > resolver.SafeDnsPacketSize {
// wtf???
return len(p), nil
}

ctx, cancel := context.WithTimeout(d.ctx, resolver.DefaultDnsRelayTimeout)
defer cancel()

buf := pool.Get(resolver.SafeDnsPacketSize)
put := func() { _ = pool.Put(buf) }
buf, err = resolver.RelayDnsPacket(ctx, p, buf)
if err != nil {
put()
return 0, err
}
copy(buf, p) // avoid p be changed after WriteTo returned

packet := dnsPacket{
data: buf,
put: put,
addr: addr,
}
select {
case d.response <- packet:
return len(p), nil
case <-d.ctx.Done():
put()
return 0, net.ErrClosed
}
go func() { // don't block the WriteTo function
buf, err = resolver.RelayDnsPacket(ctx, buf[:len(p)], buf)
if err != nil {
put()
return
}

packet := dnsPacket{
data: buf,
put: put,
addr: addr,
}
select {
case d.response <- packet:
break
case <-d.ctx.Done():
put()
}
}()
return len(p), nil
}

func (d *dnsPacketConn) Close() error {
Expand Down
5 changes: 5 additions & 0 deletions common/net/deadline/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ type Conn struct {
resultCh chan *connReadResult
}

func IsConn(conn any) bool {
_, ok := conn.(*Conn)
return ok
}

func NewConn(conn net.Conn) *Conn {
c := &Conn{
ExtendedConn: bufio.NewExtendedConn(conn),
Expand Down
5 changes: 5 additions & 0 deletions common/net/deadline/pipe_sing.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,3 +215,8 @@ func (p *pipe) waitReadBuffer() (buffer *buf.Buffer, err error) {
return nil, os.ErrDeadlineExceeded
}
}

func IsPipe(conn any) bool {
_, ok := conn.(*pipe)
return ok
}
6 changes: 6 additions & 0 deletions common/net/sing.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@ type ExtendedReader = network.ExtendedReader
var WriteBuffer = bufio.WriteBuffer

func NewDeadlineConn(conn net.Conn) ExtendedConn {
if deadline.IsPipe(conn) || deadline.IsPipe(network.UnwrapReader(conn)) {
return NewExtendedConn(conn) // pipe always have correctly deadline implement
}
if deadline.IsConn(conn) || deadline.IsConn(network.UnwrapReader(conn)) {
return NewExtendedConn(conn) // was a *deadline.Conn
}
return deadline.NewConn(conn)
}

Expand Down
6 changes: 3 additions & 3 deletions component/resolver/relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@ const DefaultDnsRelayTimeout = time.Second * 5

const SafeDnsPacketSize = 2 * 1024 // safe size which is 1232 from https://dnsflagday.net/2020/, so 2048 is enough

func RelayDnsConn(ctx context.Context, conn net.Conn) error {
func RelayDnsConn(ctx context.Context, conn net.Conn, readTimeout time.Duration) error {
buff := pool.Get(pool.UDPBufferSize)
defer func() {
_ = pool.Put(buff)
_ = conn.Close()
}()
for {
if conn.SetReadDeadline(time.Now().Add(DefaultDnsReadTimeout)) != nil {
break
if readTimeout > 0 {
_ = conn.SetReadDeadline(time.Now().Add(readTimeout))
}

length := uint16(0)
Expand Down
2 changes: 1 addition & 1 deletion listener/sing_tun/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func (h *ListenerHandler) ShouldHijackDns(targetAddr netip.AddrPort) bool {
func (h *ListenerHandler) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
if h.ShouldHijackDns(metadata.Destination.AddrPort()) {
log.Debugln("[DNS] hijack tcp:%s", metadata.Destination.String())
return resolver.RelayDnsConn(ctx, conn)
return resolver.RelayDnsConn(ctx, conn, resolver.DefaultDnsReadTimeout)
}
return h.ListenerHandler.NewConnection(ctx, conn, metadata)
}
Expand Down

0 comments on commit fad1a08

Please sign in to comment.