diff --git a/transport/internet/dialer.go b/transport/internet/dialer.go index b6289c2788..b37725a215 100644 --- a/transport/internet/dialer.go +++ b/transport/internet/dialer.go @@ -55,7 +55,7 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *MemoryStrea return dialer(ctx, dest, streamSettings) } - if dest.Network == net.Network_UDP { + if dest.Network == net.Network_UDP || dest.Network == net.Network_UNIXGRAM { udpDialer := transportDialerCache["udp"] if udpDialer == nil { return nil, newError("UDP dialer not registered").AtError() diff --git a/transport/internet/system_dialer.go b/transport/internet/system_dialer.go index 314664a784..275ca317a0 100644 --- a/transport/internet/system_dialer.go +++ b/transport/internet/system_dialer.go @@ -2,11 +2,13 @@ package internet import ( "context" + "fmt" "syscall" "time" "github.com/v2fly/v2ray-core/v5/common/net" "github.com/v2fly/v2ray-core/v5/common/session" + "github.com/v2fly/v2ray-core/v5/common/uuid" ) var effectiveSystemDialer SystemDialer = &DefaultSystemDialer{} @@ -24,16 +26,29 @@ func resolveSrcAddr(network net.Network, src net.Address) net.Addr { return nil } - if network == net.Network_TCP { + switch network { + case net.Network_TCP: return &net.TCPAddr{ IP: src.IP(), Port: 0, } - } - - return &net.UDPAddr{ - IP: src.IP(), - Port: 0, + case net.Network_UDP: + return &net.UDPAddr{ + IP: src.IP(), + Port: 0, + } + case net.Network_UNIX: + return &net.UnixAddr{ + Name: src.Domain(), + Net: "unix", + } + case net.Network_UNIXGRAM: + return &net.UnixAddr{ + Name: src.Domain(), + Net: "unixgram", + } + default: + return nil } } @@ -42,27 +57,40 @@ func hasBindAddr(sockopt *SocketConfig) bool { } func (d *DefaultSystemDialer) Dial(ctx context.Context, src net.Address, dest net.Destination, sockopt *SocketConfig) (net.Conn, error) { - if dest.Network == net.Network_UDP && !hasBindAddr(sockopt) { - srcAddr := resolveSrcAddr(net.Network_UDP, src) - if srcAddr == nil { - srcAddr = &net.UDPAddr{ - IP: []byte{0, 0, 0, 0}, - Port: 0, + if (dest.Network == net.Network_UDP && !hasBindAddr(sockopt)) || dest.Network == net.Network_UNIXGRAM { + var srcAddr, dstAddr net.Addr + var err error + switch dest.Network { + case net.Network_UDP: + srcAddr = resolveSrcAddr(dest.Network, src) + if srcAddr == nil { + srcAddr = &net.UDPAddr{IP: []byte{0, 0, 0, 0}, Port: 0} + } + dstAddr, err = net.ResolveUDPAddr("udp", dest.NetAddr()) + if err != nil { + return nil, err + } + case net.Network_UNIXGRAM: + srcAddr = resolveSrcAddr(dest.Network, src) + if srcAddr == nil { + uuid := uuid.New() + srcAddr = &net.UnixAddr{Name: fmt.Sprintf("@v2ray/dialer/%s", uuid.String()), Net: "unixgram"} + } + dstAddr, err = net.ResolveUnixAddr("unixgram", dest.NetAddr()) + if err != nil { + return nil, err } } packetConn, err := ListenSystemPacket(ctx, srcAddr, sockopt) if err != nil { return nil, err } - destAddr, err := net.ResolveUDPAddr("udp", dest.NetAddr()) - if err != nil { - return nil, err - } return &packetConnWrapper{ conn: packetConn, - dest: destAddr, + dest: dstAddr, }, nil } + goStdKeepAlive := time.Duration(0) if sockopt != nil && (sockopt.TcpKeepAliveInterval != 0 || sockopt.TcpKeepAliveIdle != 0) { goStdKeepAlive = time.Duration(-1) diff --git a/transport/internet/system_listener.go b/transport/internet/system_listener.go index ef54228d64..9996c43793 100644 --- a/transport/internet/system_listener.go +++ b/transport/internet/system_listener.go @@ -36,6 +36,26 @@ func (l *combinedListener) Close() error { return l.Listener.Close() } +type combinedUnixConn struct { + *net.UnixConn + path string + locker *FileLocker // for unix domain socket +} + +func (l *combinedUnixConn) Close() error { + if l.locker != nil { + l.locker.Release() + l.locker = nil + } + err := l.UnixConn.Close() + if err != nil { + return err + } + // Unlike UnixListener, Close() on UnixConn will not unlink the underlying file descriptor + // We have to do it manually. + return syscall.Unlink(l.path) +} + func getControlFunc(ctx context.Context, sockopt *SocketConfig, controllers []controller) func(network, address string, c syscall.RawConn) error { return func(network, address string, c syscall.RawConn) error { return c.Control(func(fd uintptr) { @@ -97,6 +117,10 @@ func (dl *DefaultListener) Listen(ctx context.Context, addr net.Addr, sockopt *S fm := os.FileMode(fMode) fileMode = &fm } + // normal unix domain socket must be unlinked before listening + if err := syscall.Unlink(address); err != nil && err != syscall.ENOENT { + return nil, newError("failed to unlink unix socket file: ", address).Base(err) + } // normal unix domain socket needs lock locker := &FileLocker{ path: address + ".lock", @@ -135,10 +159,77 @@ func (dl *DefaultListener) Listen(ctx context.Context, addr net.Addr, sockopt *S func (dl *DefaultListener) ListenPacket(ctx context.Context, addr net.Addr, sockopt *SocketConfig) (net.PacketConn, error) { var lc net.ListenConfig + var network, address string - lc.Control = getControlFunc(ctx, sockopt, dl.controllers) + // callback is called after the Listen function returns + // this is used to wrap the listener and do some post processing + callback := func(l net.PacketConn, err error) (net.PacketConn, error) { + return l, err + } + switch addr := addr.(type) { + case *net.UDPAddr: + network = addr.Network() + address = addr.String() + lc.Control = getControlFunc(ctx, sockopt, dl.controllers) + case *net.UnixAddr: + lc.Control = nil + network = addr.Network() + address = addr.Name + if (runtime.GOOS == "linux" || runtime.GOOS == "android") && address[0] == '@' { + // linux abstract unix domain socket is lockfree + if len(address) > 1 && address[1] == '@' { + // but may need padding to work with haproxy + fullAddr := make([]byte, len(syscall.RawSockaddrUnix{}.Path)) + copy(fullAddr, address[1:]) + address = string(fullAddr) + } + } else { + // normal unix domain socket + var fileMode *os.FileMode + // parse file mode from address + if s := strings.Split(address, ","); len(s) == 2 { + fMode, err := strconv.ParseUint(s[1], 8, 32) + if err != nil { + return nil, newError("failed to parse file mode").Base(err) + } + address = s[0] + fm := os.FileMode(fMode) + fileMode = &fm + } + // normal unix domain socket must be unlinked before listening + if err := syscall.Unlink(address); err != nil && err != syscall.ENOENT { + return nil, newError("failed to unlink unix socket file: ", address).Base(err) + } + // normal unix domain socket needs lock + locker := &FileLocker{ + path: address + ".lock", + } + if err := locker.Acquire(); err != nil { + return nil, err + } + // set file mode for unix domain socket when it is created + callback = func(l net.PacketConn, err error) (net.PacketConn, error) { + if err != nil { + locker.Release() + return nil, err + } + l = &combinedUnixConn{UnixConn: l.(*net.UnixConn), path: address, locker: locker} + if fileMode == nil { + return l, err + } + if cerr := os.Chmod(address, *fileMode); cerr != nil { + // failed to set file mode, close the listener + l.Close() + return nil, newError("failed to set file mode for file: ", address).Base(cerr) + } + return l, err + } + } + } - return lc.ListenPacket(ctx, addr.Network(), addr.String()) + l, err := lc.ListenPacket(ctx, network, address) + l, err = callback(l, err) + return l, err } // RegisterListenerController adds a controller to the effective system listener.