Skip to content

Commit 3a4b461

Browse files
authored
Added UDP Network API (#3)
* Added UDP Network API * Fix linter warnings * Added timeouts support * Added read-buffers to udp packet reader * Let's make linter happy * Allow tests to run in parallel
1 parent cbcccca commit 3a4b461

File tree

2 files changed

+364
-7
lines changed

2 files changed

+364
-7
lines changed

network-api/network-api.go

Lines changed: 207 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,18 @@ func Register(router *msgpackrouter.Router) {
4646

4747
_ = router.RegisterMethod("tcp/connectSSL", tcpConnectSSL)
4848

49+
_ = router.RegisterMethod("udp/connect", udpConnect)
50+
_ = router.RegisterMethod("udp/write", udpWrite)
51+
_ = router.RegisterMethod("udp/awaitRead", udpAwaitRead)
52+
_ = router.RegisterMethod("udp/read", udpRead)
53+
_ = router.RegisterMethod("udp/close", udpClose)
4954
}
5055

5156
var lock sync.RWMutex
5257
var liveConnections = make(map[uint]net.Conn)
5358
var liveListeners = make(map[uint]net.Listener)
59+
var liveUdpConnections = make(map[uint]net.PacketConn)
60+
var udpReadBuffers = make(map[uint][]byte)
5461
var nextConnectionID atomic.Uint32
5562

5663
// takeLockAndGenerateNextID generates a new unique ID for a connection or listener.
@@ -213,8 +220,8 @@ func tcpCloseListener(ctx context.Context, rpc *msgpackrpc.Connection, params []
213220
}
214221

215222
func tcpRead(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) {
216-
if len(params) != 2 {
217-
return nil, []any{1, "Invalid number of parameters, expected (connection ID, max bytes to read)"}
223+
if len(params) != 2 && len(params) != 3 {
224+
return nil, []any{1, "Invalid number of parameters, expected (connection ID, max bytes to read[, optional timeout in ms])"}
218225
}
219226
id, ok := msgpackrpc.ToUint(params[0])
220227
if !ok {
@@ -230,12 +237,20 @@ func tcpRead(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_re
230237
if !ok {
231238
return nil, []any{1, "Invalid parameter type, expected int for max bytes to read"}
232239
}
240+
var deadline time.Time // default value == no timeout
241+
if len(params) == 2 {
242+
// It seems that there is no way to set a 0 ms timeout (immediate return) on a TCP connection.
243+
// Setting the read deadline to time.Now() will always returns an empty (zero bytes)
244+
// read, so we set it by default to a very short duration in the future (1 ms).
245+
deadline = time.Now().Add(time.Millisecond)
246+
} else if ms, ok := msgpackrpc.ToInt(params[2]); !ok {
247+
return nil, []any{1, "Invalid parameter type, expected int for timeout in ms"}
248+
} else if ms > 0 {
249+
deadline = time.Now().Add(time.Duration(ms) * time.Millisecond)
250+
}
233251

234252
buffer := make([]byte, maxBytes)
235-
// It seems that the only way to make a non-blocking read is to set a read deadline.
236-
// BTW setting the read deadline to time.Now() will always returns an empty (zero bytes)
237-
// read, so we set it to a very short duration in the future.
238-
if err := conn.SetReadDeadline(time.Now().Add(time.Millisecond)); err != nil {
253+
if err := conn.SetReadDeadline(deadline); err != nil {
239254
return nil, []any{3, "Failed to set read timeout: " + err.Error()}
240255
}
241256
n, err := conn.Read(buffer)
@@ -328,3 +343,189 @@ func tcpConnectSSL(ctx context.Context, rpc *msgpackrpc.Connection, params []any
328343
unlock()
329344
return id, nil
330345
}
346+
347+
func udpConnect(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) {
348+
if len(params) != 2 {
349+
return nil, []any{1, "Invalid number of parameters, expected server address and port"}
350+
}
351+
serverAddr, ok := params[0].(string)
352+
if !ok {
353+
return nil, []any{1, "Invalid parameter type, expected string for server address"}
354+
}
355+
serverPort, ok := msgpackrpc.ToUint(params[1])
356+
if !ok {
357+
return nil, []any{1, "Invalid parameter type, expected uint16 for server port"}
358+
}
359+
360+
serverAddr = net.JoinHostPort(serverAddr, fmt.Sprintf("%d", serverPort))
361+
udpAddr, err := net.ResolveUDPAddr("udp", serverAddr)
362+
if err != nil {
363+
return nil, []any{2, "Failed to resolve UDP address: " + err.Error()}
364+
}
365+
udpConn, err := net.ListenUDP("udp", udpAddr)
366+
if err != nil {
367+
return nil, []any{2, "Failed to connect to server: " + err.Error()}
368+
}
369+
370+
// Successfully opened UDP channel
371+
372+
id, unlock := takeLockAndGenerateNextID()
373+
liveUdpConnections[id] = udpConn
374+
unlock()
375+
return id, nil
376+
}
377+
378+
func udpWrite(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) {
379+
if len(params) != 4 {
380+
return nil, []any{1, "Invalid number of parameters, expected udpConnId, dest address, dest port, payload"}
381+
}
382+
id, ok := msgpackrpc.ToUint(params[0])
383+
if !ok {
384+
return nil, []any{1, "Invalid parameter type, expected int for UDP connection ID"}
385+
}
386+
targetIP, ok := params[1].(string)
387+
if !ok {
388+
return nil, []any{1, "Invalid parameter type, expected string for server address"}
389+
}
390+
targetPort, ok := msgpackrpc.ToUint(params[2])
391+
if !ok {
392+
return nil, []any{1, "Invalid parameter type, expected uint16 for server port"}
393+
}
394+
data, ok := params[3].([]byte)
395+
if !ok {
396+
if dataStr, ok := params[3].(string); ok {
397+
data = []byte(dataStr)
398+
} else {
399+
// If data is not []byte or string, return an error
400+
return nil, []any{1, "Invalid parameter type, expected []byte or string for data to write"}
401+
}
402+
}
403+
404+
lock.RLock()
405+
udpConn, ok := liveUdpConnections[id]
406+
lock.RUnlock()
407+
if !ok {
408+
return nil, []any{2, fmt.Sprintf("UDP connection not found for ID: %d", id)}
409+
}
410+
411+
targetAddr := net.JoinHostPort(targetIP, fmt.Sprintf("%d", targetPort))
412+
addr, err := net.ResolveUDPAddr("udp", targetAddr) // TODO: This is inefficient, implement some caching
413+
if err != nil {
414+
return nil, []any{3, "Failed to resolve target address: " + err.Error()}
415+
}
416+
if n, err := udpConn.WriteTo(data, addr); err != nil {
417+
return nil, []any{4, "Failed to write to UDP connection: " + err.Error()}
418+
} else {
419+
return n, nil
420+
}
421+
}
422+
423+
func udpAwaitRead(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) {
424+
if len(params) != 1 && len(params) != 2 {
425+
return nil, []any{1, "Invalid number of parameters, expected (UDP connection ID[, optional timeout in ms])"}
426+
}
427+
id, ok := msgpackrpc.ToUint(params[0])
428+
if !ok {
429+
return nil, []any{1, "Invalid parameter type, expected uint for UDP connection ID"}
430+
}
431+
var deadline time.Time // default value == no timeout
432+
if len(params) == 2 {
433+
if ms, ok := msgpackrpc.ToInt(params[1]); !ok {
434+
return nil, []any{1, "Invalid parameter type, expected int for timeout in ms"}
435+
} else if ms > 0 {
436+
deadline = time.Now().Add(time.Duration(ms) * time.Millisecond)
437+
}
438+
}
439+
440+
lock.RLock()
441+
udpConn, ok := liveUdpConnections[id]
442+
lock.RUnlock()
443+
if !ok {
444+
return nil, []any{2, fmt.Sprintf("UDP connection not found for ID: %d", id)}
445+
}
446+
if err := udpConn.SetReadDeadline(deadline); err != nil {
447+
return nil, []any{3, "Failed to set read deadline: " + err.Error()}
448+
}
449+
buffer := make([]byte, 64*1024) // 64 KB buffer
450+
n, addr, err := udpConn.ReadFrom(buffer)
451+
if errors.Is(err, os.ErrDeadlineExceeded) {
452+
// timeout
453+
return nil, []any{5, "Timeout"}
454+
}
455+
if err != nil {
456+
return nil, []any{3, "Failed to read from UDP connection: " + err.Error()}
457+
}
458+
host, portStr, err := net.SplitHostPort(addr.String())
459+
if err != nil {
460+
// Should never fail, but...
461+
return nil, []any{4, "Failed to parse source address: " + err.Error()}
462+
}
463+
port, err := strconv.Atoi(portStr)
464+
if err != nil {
465+
// Should never fail, but...
466+
return nil, []any{4, "Failed to parse source address: " + err.Error()}
467+
}
468+
469+
lock.Lock()
470+
udpReadBuffers[id] = buffer[:n]
471+
lock.Unlock()
472+
return []any{n, host, port}, nil
473+
}
474+
475+
func udpRead(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) {
476+
if len(params) != 2 && len(params) != 3 {
477+
return nil, []any{1, "Invalid number of parameters, expected (UDP connection ID, max bytes to read)"}
478+
}
479+
id, ok := msgpackrpc.ToUint(params[0])
480+
if !ok {
481+
return nil, []any{1, "Invalid parameter type, expected uint for UDP connection ID"}
482+
}
483+
maxBytes, ok := msgpackrpc.ToUint(params[1])
484+
if !ok {
485+
return nil, []any{1, "Invalid parameter type, expected uint for max bytes to read"}
486+
}
487+
488+
lock.Lock()
489+
buffer, exists := udpReadBuffers[id]
490+
n := uint(len(buffer))
491+
if exists {
492+
// keep the remainder of the buffer for the next read
493+
if n > maxBytes {
494+
udpReadBuffers[id] = buffer[maxBytes:]
495+
n = maxBytes
496+
} else {
497+
udpReadBuffers[id] = nil
498+
}
499+
}
500+
lock.Unlock()
501+
502+
return buffer[:n], nil
503+
}
504+
505+
func udpClose(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) {
506+
if len(params) != 1 {
507+
return nil, []any{1, "Invalid number of parameters, expected UDP connection ID"}
508+
}
509+
id, ok := msgpackrpc.ToUint(params[0])
510+
if !ok {
511+
return nil, []any{1, "Invalid parameter type, expected int for UDP connection ID"}
512+
}
513+
514+
lock.Lock()
515+
udpConn, existsConn := liveUdpConnections[id]
516+
delete(liveUdpConnections, id)
517+
delete(udpReadBuffers, id)
518+
lock.Unlock()
519+
520+
if !existsConn {
521+
return nil, []any{2, fmt.Sprintf("UDP connection not found for ID: %d", id)}
522+
}
523+
524+
// Close the connection if it exists
525+
// We do not return an error to the caller if the close operation fails, as it is not critical,
526+
// but we only log the error for debugging purposes.
527+
if err := udpConn.Close(); err != nil {
528+
return err.Error(), nil
529+
}
530+
return "", nil
531+
}

0 commit comments

Comments
 (0)