Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

uds: implement a connect timeout option #299

Merged
merged 6 commits into from
Jan 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
13 changes: 13 additions & 0 deletions statsd/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ var (
defaultWorkerCount = 32
defaultSenderQueueSize = 0
defaultWriteTimeout = 100 * time.Millisecond
defaultConnectTimeout = 1000 * time.Millisecond
defaultTelemetry = true
defaultReceivingMode = mutexMode
defaultChannelModeBufferSize = 4096
Expand All @@ -40,6 +41,7 @@ type Options struct {
workersCount int
senderQueueSize int
writeTimeout time.Duration
connectTimeout time.Duration
telemetry bool
receiveMode receivingMode
channelModeBufferSize int
Expand All @@ -65,6 +67,7 @@ func resolveOptions(options []Option) (*Options, error) {
workersCount: defaultWorkerCount,
senderQueueSize: defaultSenderQueueSize,
writeTimeout: defaultWriteTimeout,
connectTimeout: defaultConnectTimeout,
telemetry: defaultTelemetry,
receiveMode: defaultReceivingMode,
channelModeBufferSize: defaultChannelModeBufferSize,
Expand Down Expand Up @@ -206,6 +209,16 @@ func WithWriteTimeout(writeTimeout time.Duration) Option {
}
}

// WithConnectTimeout sets the timeout for network connection with the Agent, after this interval the connection
// attempt is aborted. This is only used for UDS connection. This will also reset the connection if nothing can be
// written to it for this duration.
func WithConnectTimeout(connectTimeout time.Duration) Option {
return func(o *Options) error {
o.connectTimeout = connectTimeout
return nil
}
}

// WithChannelMode make the client use channels to receive metrics
//
// This determines how the client receive metrics from the app (for example when calling the `Gauge()` method).
Expand Down
12 changes: 6 additions & 6 deletions statsd/statsd.go
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ func parseAgentURL(agentURL string) string {
return ""
}

func createWriter(addr string, writeTimeout time.Duration) (Transport, string, error) {
func createWriter(addr string, writeTimeout time.Duration, connectTimeout time.Duration) (Transport, string, error) {
addr = resolveAddr(addr)
if addr == "" {
return nil, "", errors.New("No address passed and autodetection from environment failed")
Expand All @@ -379,13 +379,13 @@ func createWriter(addr string, writeTimeout time.Duration) (Transport, string, e
w, err := newWindowsPipeWriter(addr, writeTimeout)
return w, writerWindowsPipe, err
case strings.HasPrefix(addr, UnixAddressPrefix):
w, err := newUDSWriter(addr[len(UnixAddressPrefix):], writeTimeout, "")
w, err := newUDSWriter(addr[len(UnixAddressPrefix):], writeTimeout, connectTimeout, "")
return w, writerNameUDS, err
case strings.HasPrefix(addr, UnixAddressDatagramPrefix):
w, err := newUDSWriter(addr[len(UnixAddressDatagramPrefix):], writeTimeout, "unixgram")
w, err := newUDSWriter(addr[len(UnixAddressDatagramPrefix):], writeTimeout, connectTimeout, "unixgram")
return w, writerNameUDS, err
case strings.HasPrefix(addr, UnixAddressStreamPrefix):
w, err := newUDSWriter(addr[len(UnixAddressStreamPrefix):], writeTimeout, "unix")
w, err := newUDSWriter(addr[len(UnixAddressStreamPrefix):], writeTimeout, connectTimeout, "unix")
return w, writerNameUDS, err
default:
w, err := newUDPWriter(addr, writeTimeout)
Expand All @@ -401,7 +401,7 @@ func New(addr string, options ...Option) (*Client, error) {
return nil, err
}

w, writerType, err := createWriter(addr, o.writeTimeout)
w, writerType, err := createWriter(addr, o.writeTimeout, o.connectTimeout)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -542,7 +542,7 @@ func newWithWriter(w Transport, o *Options, writerName string) (*Client, error)
c.telemetryClient = newTelemetryClient(&c, c.agg != nil)
} else {
var err error
c.telemetryClient, err = newTelemetryClientWithCustomAddr(&c, o.telemetryAddr, c.agg != nil, bufferPool, o.writeTimeout)
c.telemetryClient, err = newTelemetryClientWithCustomAddr(&c, o.telemetryAddr, c.agg != nil, bufferPool, o.writeTimeout, o.connectTimeout)
if err != nil {
return nil, err
}
Expand Down
6 changes: 4 additions & 2 deletions statsd/telemetry.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,10 @@ func newTelemetryClient(c *Client, aggregationEnabled bool) *telemetryClient {
return t
}

func newTelemetryClientWithCustomAddr(c *Client, telemetryAddr string, aggregationEnabled bool, pool *bufferPool, writeTimeout time.Duration) (*telemetryClient, error) {
telemetryWriter, _, err := createWriter(telemetryAddr, writeTimeout)
func newTelemetryClientWithCustomAddr(c *Client, telemetryAddr string, aggregationEnabled bool, pool *bufferPool,
writeTimeout time.Duration, connectTimeout time.Duration,
) (*telemetryClient, error) {
telemetryWriter, _, err := createWriter(telemetryAddr, writeTimeout, connectTimeout)
if err != nil {
return nil, fmt.Errorf("Could not resolve telemetry address: %v", err)
}
Expand Down
71 changes: 26 additions & 45 deletions statsd/uds.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,15 @@ type udsWriter struct {
conn net.Conn
// write timeout
writeTimeout time.Duration
sync.RWMutex // used to lock conn / writer can replace it
// connect timeout
connectTimeout time.Duration
sync.RWMutex // used to lock conn / writer can replace it
}

// newUDSWriter returns a pointer to a new udsWriter given a socket file path as addr.
func newUDSWriter(addr string, writeTimeout time.Duration, transport string) (*udsWriter, error) {
func newUDSWriter(addr string, writeTimeout time.Duration, connectTimeout time.Duration, transport string) (*udsWriter, error) {
// Defer connection to first Write
writer := &udsWriter{addr: addr, transport: transport, conn: nil, writeTimeout: writeTimeout}
writer := &udsWriter{addr: addr, transport: transport, conn: nil, writeTimeout: writeTimeout, connectTimeout: connectTimeout}
return writer, nil
}

Expand All @@ -43,56 +45,23 @@ func (w *udsWriter) GetTransportName() string {
}
}

// retryOnWriteErr returns true if we should retry writing after a write error
func (w *udsWriter) retryOnWriteErr(err error, stream bool) bool {
// Never retry when using unixgram (to preserve the historical behavior)
if !stream {
return false
}
// Otherwise we retry on timeout because we might have written a partial packet
if networkError, ok := err.(net.Error); ok && networkError.Timeout() {
func (w *udsWriter) shouldCloseConnection(err error, partialWrite bool) bool {
if err != nil && partialWrite {
// We can't recover from a partial write
return true
}
return false
}

func (w *udsWriter) shouldCloseConnection(err error) bool {
if err, isNetworkErr := err.(net.Error); err != nil && (!isNetworkErr || !err.Timeout()) {
// Statsd server disconnected, retry connecting at next packet
return true
}
return false
}

// writeFull writes the whole data to the UDS connection
func (w *udsWriter) writeFull(data []byte, stopIfNoneWritten bool, stream bool) (int, error) {
vickenty marked this conversation as resolved.
Show resolved Hide resolved
written := 0
for written < len(data) {
n, e := w.conn.Write(data[written:])
written += n

// If we haven't written anything, and we're supposed to stop if we can't write anything, return the error
if written == 0 && stopIfNoneWritten {
return written, e
}

// If there's an error, check if it is retryable
if e != nil && !w.retryOnWriteErr(e, stream) {
return written, e
}

// When using "unix" we need to be able to finish to write partially written packets once we have started.
if stream {
w.conn.SetWriteDeadline(time.Time{})
}
}
return written, nil
}

// Write data to the UDS connection with write timeout and minimal error handling:
// create the connection if nil, and destroy it if the statsd server has disconnected
func (w *udsWriter) Write(data []byte) (int, error) {
var n int
partialWrite := false
conn, err := w.ensureConnection()
if err != nil {
return 0, err
Expand All @@ -107,15 +76,26 @@ func (w *udsWriter) Write(data []byte) (int, error) {
if stream {
bs := []byte{0, 0, 0, 0}
binary.LittleEndian.PutUint32(bs, uint32(len(data)))
_, err = w.writeFull(bs, true, true)
_, err = w.conn.Write(bs)
vickenty marked this conversation as resolved.
Show resolved Hide resolved

partialWrite = true

// W need to be able to finish to write partially written packets once we have started.
// But we will reset the connection if we can't write anything at all for a long time.
w.conn.SetWriteDeadline(time.Now().Add(w.connectTimeout))

// Continue writing only if we've written the length of the packet
if err == nil {
n, err = w.writeFull(data, false, true)
n, err = w.conn.Write(data)
if err == nil {
partialWrite = false
}
}
} else {
n, err = w.writeFull(data, true, false)
n, err = w.conn.Write(data)
}

if w.shouldCloseConnection(err) {
if w.shouldCloseConnection(err, partialWrite) {
w.unsetConnection()
}
return n, err
Expand All @@ -133,7 +113,7 @@ func (w *udsWriter) tryToDial(network string) (net.Conn, error) {
if err != nil {
return nil, err
}
newConn, err := net.Dial(udsAddr.Network(), udsAddr.String())
newConn, err := net.DialTimeout(udsAddr.Network(), udsAddr.String(), w.connectTimeout)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -182,5 +162,6 @@ func (w *udsWriter) ensureConnection() (net.Conn, error) {
func (w *udsWriter) unsetConnection() {
w.Lock()
defer w.Unlock()
_ = w.conn.Close()
vickenty marked this conversation as resolved.
Show resolved Hide resolved
w.conn = nil
}
56 changes: 48 additions & 8 deletions statsd/uds_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@ package statsd

import (
"encoding/binary"
"golang.org/x/net/nettest"
"math/rand"
"net"
"os"
"testing"
"time"

"golang.org/x/net/nettest"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
Expand All @@ -21,13 +22,13 @@ func init() {
}

func TestNewUDSWriter(t *testing.T) {
w, err := newUDSWriter("/tmp/test.socket", 100*time.Millisecond, "")
w, err := newUDSWriter("/tmp/test.socket", 100*time.Millisecond, 1000*time.Millisecond, "")
assert.NotNil(t, w)
assert.NoError(t, err)
w, err = newUDSWriter("/tmp/test.socket", 100*time.Millisecond, "unix")
w, err = newUDSWriter("/tmp/test.socket", 100*time.Millisecond, 1000*time.Millisecond, "unix")
assert.NotNil(t, w)
assert.NoError(t, err)
w, err = newUDSWriter("/tmp/test.socket", 100*time.Millisecond, "unixgram")
w, err = newUDSWriter("/tmp/test.socket", 100*time.Millisecond, 1000*time.Millisecond, "unixgram")
assert.NotNil(t, w)
assert.NoError(t, err)
}
Expand All @@ -44,7 +45,7 @@ func TestUDSDatagramWrite(t *testing.T) {
err = os.Chmod(socketPath, 0722)
require.NoError(t, err)

w, err := newUDSWriter(socketPath, 100*time.Millisecond, "")
w, err := newUDSWriter(socketPath, 100*time.Millisecond, 1000*time.Millisecond, "")
require.Nil(t, err)
require.NotNil(t, w)

Expand Down Expand Up @@ -74,7 +75,7 @@ func TestUDSDatagramWriteUnsetConnection(t *testing.T) {
err = os.Chmod(socketPath, 0722)
require.NoError(t, err)

w, err := newUDSWriter(socketPath, 100*time.Millisecond, "")
w, err := newUDSWriter(socketPath, 100*time.Millisecond, 1000*time.Millisecond, "")
require.Nil(t, err)
require.NotNil(t, w)

Expand Down Expand Up @@ -107,7 +108,7 @@ func TestUDSStreamWrite(t *testing.T) {
err = os.Chmod(socketPath, 0722)
require.NoError(t, err)

w, err := newUDSWriter(socketPath, 100*time.Millisecond, "")
w, err := newUDSWriter(socketPath, 100*time.Millisecond, 1000*time.Millisecond, "")
require.Nil(t, err)
require.NotNil(t, w)

Expand All @@ -120,6 +121,7 @@ func TestUDSStreamWrite(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, len(msg), n)

// This works because the kernel accepts sockets before the accept call
if conn == nil {
conn, err = listener.Accept()
require.NoError(t, err)
Expand Down Expand Up @@ -148,7 +150,7 @@ func TestUDSStreamWriteUnsetConnection(t *testing.T) {
err = os.Chmod(socketPath, 0722)
require.NoError(t, err)

w, err := newUDSWriter(socketPath, 100*time.Millisecond, "")
w, err := newUDSWriter(socketPath, 100*time.Millisecond, 1000*time.Millisecond, "")
require.Nil(t, err)
require.NotNil(t, w)

Expand All @@ -161,6 +163,7 @@ func TestUDSStreamWriteUnsetConnection(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, len(msg), n)

// This works because the kernel accepts sockets before the accept call
if conn == nil {
conn, err = listener.Accept()
require.NoError(t, err)
Expand All @@ -180,3 +183,40 @@ func TestUDSStreamWriteUnsetConnection(t *testing.T) {
conn = nil
}
}

func TestUDSStreamPartialWrite(t *testing.T) {
socketPath, err := nettest.LocalPath()
require.NoError(t, err)
defer os.Remove(socketPath)

address, err := net.ResolveUnixAddr("unix", socketPath)
require.NoError(t, err)
listener, err := net.ListenUnix("unix", address)
require.NoError(t, err)
defer listener.Close()
err = os.Chmod(socketPath, 0722)
require.NoError(t, err)

w, err := newUDSWriter(socketPath, 100*time.Millisecond, 1000*time.Millisecond, "")
require.Nil(t, err)
require.NotNil(t, w)

// Force a connection
w.ensureConnection()
conn, err := listener.Accept()
iksaif marked this conversation as resolved.
Show resolved Hide resolved
require.NoError(t, err)
defer conn.Close()

// Set a very low buffer size to force a partial write, but still enough to write the header
require.NoError(t, w.conn.(*net.UnixConn).SetWriteBuffer(1))
// On linux we need to force a timeout this way
w.connectTimeout = -1 * time.Millisecond

msg := []byte("some data")
n, err := w.Write(msg)
require.Error(t, err)
assert.Lessf(t, n, len(msg), "n: %d, len(msg): %d", n, len(msg))

// The connection should be dropped
assert.Nil(t, w.conn)
}
2 changes: 1 addition & 1 deletion statsd/uds_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,6 @@ import (

// newUDSWriter is disabled on Windows, SOCK_DGRAM are still unavailable but
// SOCK_STREAM should work once implemented in the agent (https://devblogs.microsoft.com/commandline/af_unix-comes-to-windows/)
func newUDSWriter(_ string, _ time.Duration, _ string) (Transport, error) {
func newUDSWriter(_ string, _ time.Duration, _ time.Duration, _ string) (Transport, error) {
return nil, fmt.Errorf("Unix socket is not available on Windows")
}