Skip to content

Commit

Permalink
fix: wireguard can't be auto closed
Browse files Browse the repository at this point in the history
  • Loading branch information
wwqgtxx committed Apr 24, 2024
1 parent b2280c8 commit 2f8f139
Show file tree
Hide file tree
Showing 2 changed files with 169 additions and 122 deletions.
247 changes: 125 additions & 122 deletions adapter/outbound/wireguard.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,17 @@ type WireGuard struct {
device *device.Device
tunDevice wireguard.Device
dialer proxydialer.SingDialer
init func(ctx context.Context) error
resolver *dns.Resolver
refP *refProxyAdapter

initOk atomic.Bool
initMutex sync.Mutex
initErr error
option WireGuardOption
connectAddr M.Socksaddr
localPrefixes []netip.Prefix

closeCh chan struct{} // for test
}

type WireGuardOption struct {
Expand Down Expand Up @@ -141,19 +149,6 @@ func NewWireGuard(option WireGuardOption) (*WireGuard, error) {
}
runtime.SetFinalizer(outbound, closeWireGuard)

resolv := func(ctx context.Context, address M.Socksaddr) (netip.AddrPort, error) {
if address.Addr.IsValid() {
return address.AddrPort(), nil
}
udpAddr, err := resolveUDPAddrWithPrefer(ctx, "udp", address.String(), outbound.prefer)
if err != nil {
return netip.AddrPort{}, err
}
// net.ResolveUDPAddr maybe return 4in6 address, so unmap at here
addrPort := udpAddr.AddrPort()
return netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port()), nil
}

var reserved [3]uint8
if len(option.Reserved) > 0 {
if len(option.Reserved) != 3 {
Expand All @@ -162,29 +157,28 @@ func NewWireGuard(option WireGuardOption) (*WireGuard, error) {
copy(reserved[:], option.Reserved)
}
var isConnect bool
var connectAddr M.Socksaddr
if len(option.Peers) < 2 {
isConnect = true
if len(option.Peers) == 1 {
connectAddr = option.Peers[0].Addr()
outbound.connectAddr = option.Peers[0].Addr()
} else {
connectAddr = option.Addr()
outbound.connectAddr = option.Addr()
}
}
outbound.bind = wireguard.NewClientBind(context.Background(), wgSingErrorHandler{outbound.Name()}, outbound.dialer, isConnect, connectAddr.AddrPort(), reserved)
outbound.bind = wireguard.NewClientBind(context.Background(), wgSingErrorHandler{outbound.Name()}, outbound.dialer, isConnect, outbound.connectAddr.AddrPort(), reserved)

localPrefixes, err := option.Prefixes()
var err error
outbound.localPrefixes, err = option.Prefixes()
if err != nil {
return nil, err
}

var privateKey string
{
bytes, err := base64.StdEncoding.DecodeString(option.PrivateKey)
if err != nil {
return nil, E.Cause(err, "decode private key")
}
privateKey = hex.EncodeToString(bytes)
option.PrivateKey = hex.EncodeToString(bytes)
}

if len(option.Peers) > 0 {
Expand Down Expand Up @@ -230,110 +224,16 @@ func NewWireGuard(option WireGuardOption) (*WireGuard, error) {
option.PreSharedKey = hex.EncodeToString(bytes)
}
}

var (
initOk atomic.Bool
initMutex sync.Mutex
initErr error
)

outbound.init = func(ctx context.Context) error {
if initOk.Load() {
return nil
}
initMutex.Lock()
defer initMutex.Unlock()
// double check like sync.Once
if initOk.Load() {
return nil
}
if initErr != nil {
return initErr
}

outbound.bind.ResetReservedForEndpoint()
ipcConf := "private_key=" + privateKey
if len(option.Peers) > 0 {
for i, peer := range option.Peers {
destination, err := resolv(ctx, peer.Addr())
if err != nil {
// !!! do not set initErr here !!!
// let us can retry domain resolve in next time
return E.Cause(err, "resolve endpoint domain for peer ", i)
}
ipcConf += "\npublic_key=" + peer.PublicKey
ipcConf += "\nendpoint=" + destination.String()
if peer.PreSharedKey != "" {
ipcConf += "\npreshared_key=" + peer.PreSharedKey
}
for _, allowedIP := range peer.AllowedIPs {
ipcConf += "\nallowed_ip=" + allowedIP
}
if len(peer.Reserved) > 0 {
copy(reserved[:], option.Reserved)
outbound.bind.SetReservedForEndpoint(destination, reserved)
}
}
} else {
ipcConf += "\npublic_key=" + option.PublicKey
destination, err := resolv(ctx, connectAddr)
if err != nil {
// !!! do not set initErr here !!!
// let us can retry domain resolve in next time
return E.Cause(err, "resolve endpoint domain")
}
outbound.bind.SetConnectAddr(destination)
ipcConf += "\nendpoint=" + destination.String()
if option.PreSharedKey != "" {
ipcConf += "\npreshared_key=" + option.PreSharedKey
}
var has4, has6 bool
for _, address := range localPrefixes {
if address.Addr().Is4() {
has4 = true
} else {
has6 = true
}
}
if has4 {
ipcConf += "\nallowed_ip=0.0.0.0/0"
}
if has6 {
ipcConf += "\nallowed_ip=::/0"
}
}

if option.PersistentKeepalive != 0 {
ipcConf += fmt.Sprintf("\npersistent_keepalive_interval=%d", option.PersistentKeepalive)
}

if debug.Enabled {
log.SingLogger.Trace(fmt.Sprintf("[WG](%s) created wireguard ipc conf: \n %s", option.Name, ipcConf))
}
err = outbound.device.IpcSet(ipcConf)
if err != nil {
initErr = E.Cause(err, "setup wireguard")
return initErr
}

err = outbound.tunDevice.Start()
if err != nil {
initErr = err
return initErr
}

initOk.Store(true)
return nil
}
outbound.option = option

mtu := option.MTU
if mtu == 0 {
mtu = 1408
}
if len(localPrefixes) == 0 {
if len(outbound.localPrefixes) == 0 {
return nil, E.New("missing local address")
}
outbound.tunDevice, err = wireguard.NewStackDevice(localPrefixes, uint32(mtu))
outbound.tunDevice, err = wireguard.NewStackDevice(outbound.localPrefixes, uint32(mtu))
if err != nil {
return nil, E.Cause(err, "create WireGuard device")
}
Expand All @@ -347,7 +247,7 @@ func NewWireGuard(option WireGuardOption) (*WireGuard, error) {
}, option.Workers)

var has6 bool
for _, address := range localPrefixes {
for _, address := range outbound.localPrefixes {
if !address.Addr().Unmap().Is4() {
has6 = true
break
Expand All @@ -373,11 +273,117 @@ func NewWireGuard(option WireGuardOption) (*WireGuard, error) {
return outbound, nil
}

func (w *WireGuard) resolve(ctx context.Context, address M.Socksaddr) (netip.AddrPort, error) {
if address.Addr.IsValid() {
return address.AddrPort(), nil
}
udpAddr, err := resolveUDPAddrWithPrefer(ctx, "udp", address.String(), w.prefer)
if err != nil {
return netip.AddrPort{}, err
}
// net.ResolveUDPAddr maybe return 4in6 address, so unmap at here
addrPort := udpAddr.AddrPort()
return netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port()), nil
}

func (w *WireGuard) init(ctx context.Context) error {
if w.initOk.Load() {
return nil
}
w.initMutex.Lock()
defer w.initMutex.Unlock()
// double check like sync.Once
if w.initOk.Load() {
return nil
}
if w.initErr != nil {
return w.initErr
}

w.bind.ResetReservedForEndpoint()
ipcConf := "private_key=" + w.option.PrivateKey
if len(w.option.Peers) > 0 {
for i, peer := range w.option.Peers {
destination, err := w.resolve(ctx, peer.Addr())
if err != nil {
// !!! do not set initErr here !!!
// let us can retry domain resolve in next time
return E.Cause(err, "resolve endpoint domain for peer ", i)
}
ipcConf += "\npublic_key=" + peer.PublicKey
ipcConf += "\nendpoint=" + destination.String()
if peer.PreSharedKey != "" {
ipcConf += "\npreshared_key=" + peer.PreSharedKey
}
for _, allowedIP := range peer.AllowedIPs {
ipcConf += "\nallowed_ip=" + allowedIP
}
if len(peer.Reserved) > 0 {
var reserved [3]uint8
copy(reserved[:], w.option.Reserved)
w.bind.SetReservedForEndpoint(destination, reserved)
}
}
} else {
ipcConf += "\npublic_key=" + w.option.PublicKey
destination, err := w.resolve(ctx, w.connectAddr)
if err != nil {
// !!! do not set initErr here !!!
// let us can retry domain resolve in next time
return E.Cause(err, "resolve endpoint domain")
}
w.bind.SetConnectAddr(destination)
ipcConf += "\nendpoint=" + destination.String()
if w.option.PreSharedKey != "" {
ipcConf += "\npreshared_key=" + w.option.PreSharedKey
}
var has4, has6 bool
for _, address := range w.localPrefixes {
if address.Addr().Is4() {
has4 = true
} else {
has6 = true
}
}
if has4 {
ipcConf += "\nallowed_ip=0.0.0.0/0"
}
if has6 {
ipcConf += "\nallowed_ip=::/0"
}
}

if w.option.PersistentKeepalive != 0 {
ipcConf += fmt.Sprintf("\npersistent_keepalive_interval=%d", w.option.PersistentKeepalive)
}

if debug.Enabled {
log.SingLogger.Trace(fmt.Sprintf("[WG](%s) created wireguard ipc conf: \n %s", w.option.Name, ipcConf))
}
err := w.device.IpcSet(ipcConf)
if err != nil {
w.initErr = E.Cause(err, "setup wireguard")
return w.initErr
}

err = w.tunDevice.Start()
if err != nil {
w.initErr = err
return w.initErr
}

w.initOk.Store(true)
return nil
}

func closeWireGuard(w *WireGuard) {
if w.device != nil {
w.device.Close()
}
_ = common.Close(w.tunDevice)
if w.closeCh != nil {
close(w.closeCh)
}
}

func (w *WireGuard) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.Conn, err error) {
Expand Down Expand Up @@ -416,9 +422,6 @@ func (w *WireGuard) ListenPacketContext(ctx context.Context, metadata *C.Metadat
if err = w.init(ctx); err != nil {
return nil, err
}
if err != nil {
return nil, err
}
if (!metadata.Resolved() || w.resolver != nil) && metadata.Host != "" {
r := resolver.DefaultResolver
if w.resolver != nil {
Expand Down
44 changes: 44 additions & 0 deletions adapter/outbound/wireguard_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
//go:build with_gvisor

package outbound

import (
"context"
"runtime"
"testing"
"time"
)

func TestWireGuardGC(t *testing.T) {
option := WireGuardOption{}
option.Server = "162.159.192.1"
option.Port = 2408
option.PrivateKey = "iOx7749AdqH3IqluG7+0YbGKd0m1mcEXAfGRzpy9rG8="
option.PublicKey = "bmXOC+F1FxEMF9dyiK2H5/1SUtzH0JuVo51h2wPfgyo="
option.Ip = "172.16.0.2"
option.Ipv6 = "2606:4700:110:8d29:be92:3a6a:f4:c437"
option.Reserved = []uint8{51, 69, 125}
wg, err := NewWireGuard(option)
if err != nil {
t.Error(err)
}
closeCh := make(chan struct{})
wg.closeCh = closeCh
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
err = wg.init(ctx)
if err != nil {
t.Error(err)
}
// must do a small sleep before test GC
// because it maybe deadlocks if w.device.Close call too fast after w.device.Start
time.Sleep(10 * time.Millisecond)
wg = nil
runtime.GC()
select {
case <-closeCh:
return
case <-ctx.Done():
t.Error("timeout not GC")
}
}

0 comments on commit 2f8f139

Please sign in to comment.