Skip to content

Commit 44fb214

Browse files
committed
Added read-buffers to udp packet reader
1 parent 869c0d8 commit 44fb214

File tree

2 files changed

+80
-29
lines changed

2 files changed

+80
-29
lines changed

network-api/network-api.go

Lines changed: 51 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ func Register(router *msgpackrouter.Router) {
4848

4949
_ = router.RegisterMethod("udp/connect", udpConnect)
5050
_ = router.RegisterMethod("udp/write", udpWrite)
51+
_ = router.RegisterMethod("udp/awaitRead", udpAwaitRead)
5152
_ = router.RegisterMethod("udp/read", udpRead)
5253
_ = router.RegisterMethod("udp/close", udpClose)
5354
}
@@ -56,6 +57,7 @@ var lock sync.RWMutex
5657
var liveConnections = make(map[uint]net.Conn)
5758
var liveListeners = make(map[uint]net.Listener)
5859
var liveUdpConnections = make(map[uint]net.PacketConn)
60+
var udpReadBuffers = make(map[uint][]byte)
5961
var nextConnectionID atomic.Uint32
6062

6163
// takeLockAndGenerateNextID generates a new unique ID for a connection or listener.
@@ -420,39 +422,35 @@ func udpWrite(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_r
420422
}
421423
}
422424

423-
func udpRead(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) {
424-
if len(params) != 2 && len(params) != 3 {
425-
return nil, []any{1, "Invalid number of parameters, expected (UDP connection ID, max bytes to read[, optional timeout in ms])"}
425+
func udpAwaitRead(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) {
426+
if len(params) != 1 && len(params) != 2 {
427+
return nil, []any{1, "Invalid number of parameters, expected (UDP connection ID[, optional timeout in ms])"}
426428
}
427429
id, ok := msgpackrpc.ToUint(params[0])
428430
if !ok {
429431
return nil, []any{1, "Invalid parameter type, expected uint for UDP connection ID"}
430432
}
431-
lock.RLock()
432-
udpConn, ok := liveUdpConnections[id]
433-
lock.RUnlock()
434-
if !ok {
435-
return nil, []any{2, fmt.Sprintf("UDP connection not found for ID: %d", id)}
436-
}
437-
maxBytes, ok := msgpackrpc.ToUint(params[1])
438-
if !ok {
439-
return nil, []any{1, "Invalid parameter type, expected uint for max bytes to read"}
440-
}
441433
var deadline time.Time // default value == no timeout
442-
if len(params) == 2 {
434+
if len(params) == 1 {
443435
// No timeout
444-
} else if ms, ok := msgpackrpc.ToInt(params[2]); !ok {
436+
} else if ms, ok := msgpackrpc.ToInt(params[1]); !ok {
445437
return nil, []any{1, "Invalid parameter type, expected int for timeout in ms"}
446438
} else if ms > 0 {
447439
deadline = time.Now().Add(time.Duration(ms) * time.Millisecond)
448440
} else if ms == 0 {
449441
// No timeout
450442
}
451443

444+
lock.RLock()
445+
udpConn, ok := liveUdpConnections[id]
446+
lock.RUnlock()
447+
if !ok {
448+
return nil, []any{2, fmt.Sprintf("UDP connection not found for ID: %d", id)}
449+
}
452450
if err := udpConn.SetReadDeadline(deadline); err != nil {
453451
return nil, []any{3, "Failed to set read deadline: " + err.Error()}
454452
}
455-
buffer := make([]byte, maxBytes)
453+
buffer := make([]byte, 64*1024) // 64 KB buffer
456454
n, addr, err := udpConn.ReadFrom(buffer)
457455
if errors.Is(err, os.ErrDeadlineExceeded) {
458456
// timeout
@@ -471,7 +469,41 @@ func udpRead(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_re
471469
// Should never fail, but...
472470
return nil, []any{4, "Failed to parse source address: " + err.Error()}
473471
}
474-
return []any{buffer[:n], host, port}, nil
472+
473+
lock.Lock()
474+
udpReadBuffers[id] = buffer[:n]
475+
lock.Unlock()
476+
return []any{n, host, port}, nil
477+
}
478+
479+
func udpRead(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) {
480+
if len(params) != 2 && len(params) != 3 {
481+
return nil, []any{1, "Invalid number of parameters, expected (UDP connection ID, max bytes to read)"}
482+
}
483+
id, ok := msgpackrpc.ToUint(params[0])
484+
if !ok {
485+
return nil, []any{1, "Invalid parameter type, expected uint for UDP connection ID"}
486+
}
487+
maxBytes, ok := msgpackrpc.ToUint(params[1])
488+
if !ok {
489+
return nil, []any{1, "Invalid parameter type, expected uint for max bytes to read"}
490+
}
491+
492+
lock.Lock()
493+
buffer, exists := udpReadBuffers[id]
494+
n := uint(len(buffer))
495+
if exists {
496+
// keep the remainder of the buffer for the next read
497+
if n > maxBytes {
498+
udpReadBuffers[id] = buffer[maxBytes:]
499+
n = maxBytes
500+
} else {
501+
udpReadBuffers[id] = nil
502+
}
503+
}
504+
lock.Unlock()
505+
506+
return buffer[:n], nil
475507
}
476508

477509
func udpClose(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) {
@@ -485,9 +517,8 @@ func udpClose(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_r
485517

486518
lock.Lock()
487519
udpConn, existsConn := liveUdpConnections[id]
488-
if existsConn {
489-
delete(liveUdpConnections, id)
490-
}
520+
delete(liveUdpConnections, id)
521+
delete(udpReadBuffers, id)
491522
lock.Unlock()
492523

493524
if !existsConn {

network-api/network-api_test.go

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -304,9 +304,17 @@ func TestUDPNetworkUnboundClientAPI(t *testing.T) {
304304
require.Equal(t, 5, res)
305305
}
306306
{
307-
res, err := udpRead(ctx, nil, []any{conn2, 100})
307+
res, err := udpAwaitRead(ctx, nil, []any{conn2})
308+
require.Nil(t, err)
309+
require.Equal(t, 5, res.([]any)[0])
310+
311+
res2, err := udpRead(ctx, nil, []any{conn2, 2})
308312
require.Nil(t, err)
309-
require.Equal(t, []uint8("Hello"), res.([]any)[0])
313+
require.Equal(t, []uint8("He"), res2)
314+
315+
res2, err = udpRead(ctx, nil, []any{conn2, 20})
316+
require.Nil(t, err)
317+
require.Equal(t, []uint8("llo"), res2)
310318
}
311319
{
312320
res, err := udpWrite(ctx, nil, []any{conn1, "127.0.0.1", 9900, []byte("One")})
@@ -319,14 +327,22 @@ func TestUDPNetworkUnboundClientAPI(t *testing.T) {
319327
require.Equal(t, 3, res)
320328
}
321329
{
322-
res, err := udpRead(ctx, nil, []any{conn2, 100})
330+
res, err := udpAwaitRead(ctx, nil, []any{conn2})
323331
require.Nil(t, err)
324-
require.Equal(t, []uint8("One"), res.([]any)[0])
332+
require.Equal(t, 3, res.([]any)[0])
333+
334+
res2, err := udpRead(ctx, nil, []any{conn2, 100})
335+
require.Nil(t, err)
336+
require.Equal(t, []uint8("One"), res2)
325337
}
326338
{
327-
res, err := udpRead(ctx, nil, []any{conn2, 100})
339+
res, err := udpAwaitRead(ctx, nil, []any{conn2})
340+
require.Nil(t, err)
341+
require.Equal(t, 3, res.([]any)[0])
342+
343+
res2, err := udpRead(ctx, nil, []any{conn2, 100})
328344
require.Nil(t, err)
329-
require.Equal(t, []uint8("Two"), res.([]any)[0])
345+
require.Equal(t, []uint8("Two"), res2)
330346
}
331347

332348
// Check timeouts
@@ -338,15 +354,19 @@ func TestUDPNetworkUnboundClientAPI(t *testing.T) {
338354
}()
339355
{
340356
start := time.Now()
341-
res, err := udpRead(ctx, nil, []any{conn2, 100, 10})
357+
res, err := udpAwaitRead(ctx, nil, []any{conn2, 10})
342358
require.Less(t, time.Since(start), 20*time.Millisecond)
343359
require.Equal(t, []any{5, "Timeout"}, err)
344360
require.Nil(t, res)
345361
}
346362
{
347-
res, err := udpRead(ctx, nil, []any{conn2, 100, 0})
363+
res, err := udpAwaitRead(ctx, nil, []any{conn2, 0})
364+
require.Nil(t, err)
365+
require.Equal(t, 5, res.([]any)[0])
366+
367+
res2, err := udpRead(ctx, nil, []any{conn2, 100, 0})
348368
require.Nil(t, err)
349-
require.Equal(t, []uint8("Three"), res.([]any)[0])
369+
require.Equal(t, []uint8("Three"), res2)
350370
}
351371

352372
{

0 commit comments

Comments
 (0)