Skip to content

Commit b7bdfdf

Browse files
committed
Added UDP Network API
1 parent f7384c6 commit b7bdfdf

File tree

2 files changed

+253
-1
lines changed

2 files changed

+253
-1
lines changed

network-api/network-api.go

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,16 @@ func Register(router *msgpackrouter.Router) {
4646

4747
_ = router.RegisterMethod("tcp/connectSSL", tcpConnectSSL)
4848

49+
_ = router.RegisterMethod("udp/connect", udpConnect)
50+
_ = router.RegisterMethod("udp/write", udpWrite)
51+
_ = router.RegisterMethod("udp/read", udpRead)
52+
_ = router.RegisterMethod("udp/close", udpClose)
4953
}
5054

5155
var lock sync.RWMutex
5256
var liveConnections = make(map[uint]net.Conn)
5357
var liveListeners = make(map[uint]net.Listener)
58+
var liveUdpConnections = make(map[uint]net.PacketConn)
5459
var nextConnectionID atomic.Uint32
5560

5661
// takeLockAndGenerateNextID generates a new unique ID for a connection or listener.
@@ -328,3 +333,146 @@ func tcpConnectSSL(ctx context.Context, rpc *msgpackrpc.Connection, params []any
328333
unlock()
329334
return id, nil
330335
}
336+
337+
func udpConnect(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) {
338+
if len(params) != 2 {
339+
return nil, []any{1, "Invalid number of parameters, expected server address and port"}
340+
}
341+
serverAddr, ok := params[0].(string)
342+
if !ok {
343+
return nil, []any{1, "Invalid parameter type, expected string for server address"}
344+
}
345+
serverPort, ok := msgpackrpc.ToUint(params[1])
346+
if !ok {
347+
return nil, []any{1, "Invalid parameter type, expected uint16 for server port"}
348+
}
349+
350+
serverAddr = net.JoinHostPort(serverAddr, strconv.Itoa(int(serverPort)))
351+
udpAddr, err := net.ResolveUDPAddr("udp", serverAddr)
352+
if err != nil {
353+
return nil, []any{2, "Failed to resolve UDP address: " + err.Error()}
354+
}
355+
udpConn, err := net.ListenUDP("udp", udpAddr)
356+
if err != nil {
357+
return nil, []any{2, "Failed to connect to server: " + err.Error()}
358+
}
359+
360+
// Successfully opened UDP channel
361+
362+
id, unlock := takeLockAndGenerateNextID()
363+
liveUdpConnections[id] = udpConn
364+
unlock()
365+
return id, nil
366+
}
367+
368+
func udpWrite(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) {
369+
if len(params) != 4 {
370+
return nil, []any{1, "Invalid number of parameters, expected udpConnId, dest address, dest port, payload"}
371+
}
372+
id, ok := msgpackrpc.ToUint(params[0])
373+
if !ok {
374+
return nil, []any{1, "Invalid parameter type, expected int for UDP connection ID"}
375+
}
376+
targetIP, ok := params[1].(string)
377+
if !ok {
378+
return nil, []any{1, "Invalid parameter type, expected string for server address"}
379+
}
380+
targetPort, ok := msgpackrpc.ToUint(params[2])
381+
if !ok {
382+
return nil, []any{1, "Invalid parameter type, expected uint16 for server port"}
383+
}
384+
data, ok := params[3].([]byte)
385+
if !ok {
386+
if dataStr, ok := params[3].(string); ok {
387+
data = []byte(dataStr)
388+
} else {
389+
// If data is not []byte or string, return an error
390+
return nil, []any{1, "Invalid parameter type, expected []byte or string for data to write"}
391+
}
392+
}
393+
394+
lock.RLock()
395+
udpConn, ok := liveUdpConnections[id]
396+
lock.RUnlock()
397+
if !ok {
398+
return nil, []any{2, fmt.Sprintf("UDP connection not found for ID: %d", id)}
399+
}
400+
401+
targetAddr := net.JoinHostPort(targetIP, strconv.Itoa(int(targetPort)))
402+
addr, err := net.ResolveUDPAddr("udp", targetAddr) // TODO: This is inefficient, implement some caching
403+
if err != nil {
404+
return nil, []any{3, "Failed to resolve target address: " + err.Error()}
405+
}
406+
if n, err := udpConn.WriteTo(data, addr); err != nil {
407+
return nil, []any{4, "Failed to write to UDP connection: " + err.Error()}
408+
} else {
409+
return n, nil
410+
}
411+
}
412+
413+
func udpRead(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) {
414+
if len(params) != 2 {
415+
return nil, []any{1, "Invalid number of parameters, expected (UDP connection ID, max bytes to read)"}
416+
}
417+
id, ok := msgpackrpc.ToUint(params[0])
418+
if !ok {
419+
return nil, []any{1, "Invalid parameter type, expected uint for UDP connection ID"}
420+
}
421+
lock.RLock()
422+
udpConn, ok := liveUdpConnections[id]
423+
lock.RUnlock()
424+
if !ok {
425+
return nil, []any{2, fmt.Sprintf("UDP connection not found for ID: %d", id)}
426+
}
427+
maxBytes, ok := msgpackrpc.ToUint(params[1])
428+
if !ok {
429+
return nil, []any{1, "Invalid parameter type, expected uint for max bytes to read"}
430+
}
431+
432+
buffer := make([]byte, maxBytes)
433+
434+
n, addr, err := udpConn.ReadFrom(buffer)
435+
if err != nil {
436+
return nil, []any{3, "Failed to read from UDP connection: " + err.Error()}
437+
}
438+
host, portStr, err := net.SplitHostPort(addr.String())
439+
if err != nil {
440+
// Should never fail, but...
441+
return nil, []any{4, "Failed to parse source address: " + err.Error()}
442+
}
443+
port, err := strconv.Atoi(portStr)
444+
if err != nil {
445+
// Should never fail, but...
446+
return nil, []any{4, "Failed to parse source address: " + err.Error()}
447+
}
448+
return []any{buffer[:n], host, port}, nil
449+
}
450+
451+
func udpClose(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) {
452+
if len(params) != 1 {
453+
return nil, []any{1, "Invalid number of parameters, expected UDP connection ID"}
454+
}
455+
id, ok := msgpackrpc.ToUint(params[0])
456+
if !ok {
457+
return nil, []any{1, "Invalid parameter type, expected int for UDP connection ID"}
458+
}
459+
460+
lock.Lock()
461+
udpConn, existsConn := liveUdpConnections[id]
462+
if existsConn {
463+
delete(liveUdpConnections, id)
464+
}
465+
lock.Unlock()
466+
467+
if !existsConn {
468+
return nil, []any{2, fmt.Sprintf("UDP connection not found for ID: %d", id)}
469+
}
470+
471+
// Close the connection if it exists
472+
// We do not return an error to the caller if the close operation fails, as it is not critical,
473+
// but we only log the error for debugging purposes.
474+
if err := udpConn.Close(); err != nil {
475+
return err.Error(), nil
476+
}
477+
return "", nil
478+
}

network-api/network-api_test.go

Lines changed: 105 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ const testCert = "-----BEGIN CERTIFICATE-----\n" +
156156
"HAIgNadMPgxv01dy59kCgzehgKzmKdTF0rG1SniYqnkLqPA=\n" +
157157
"-----END CERTIFICATE-----\n"
158158

159-
func TestNetworkAPI(t *testing.T) {
159+
func TestTCPNetworkAPI(t *testing.T) {
160160
ctx := t.Context()
161161
var rpc *msgpackrpc.Connection
162162
listID, err := tcpListen(ctx, rpc, []any{"localhost", 9999})
@@ -236,3 +236,107 @@ func TestNetworkAPI(t *testing.T) {
236236

237237
wg.Wait()
238238
}
239+
240+
func TestUDPNetworkAPI(t *testing.T) {
241+
ctx := t.Context()
242+
conn1, err := udpConnect(ctx, nil, []any{"0.0.0.0", 9800})
243+
require.Nil(t, err)
244+
require.Equal(t, uint(1), conn1)
245+
246+
conn2, err := udpConnect(ctx, nil, []any{"0.0.0.0", 9900})
247+
require.Nil(t, err)
248+
require.Equal(t, uint(2), conn2)
249+
250+
{
251+
res, err := udpWrite(ctx, nil, []any{conn1, "127.0.0.1", 9900, []byte("Hello")})
252+
require.Nil(t, err)
253+
require.Equal(t, 5, res)
254+
}
255+
{
256+
res, err := udpRead(ctx, nil, []any{conn2, 100})
257+
require.Nil(t, err)
258+
require.Equal(t, []any{[]uint8("Hello"), "127.0.0.1", 9800}, res)
259+
}
260+
{
261+
res, err := udpWrite(ctx, nil, []any{conn1, "127.0.0.1", 9900, []byte("One")})
262+
require.Nil(t, err)
263+
require.Equal(t, 3, res)
264+
}
265+
{
266+
res, err := udpWrite(ctx, nil, []any{conn1, "127.0.0.1", 9900, []byte("Two")})
267+
require.Nil(t, err)
268+
require.Equal(t, 3, res)
269+
}
270+
{
271+
res, err := udpRead(ctx, nil, []any{conn2, 100})
272+
require.Nil(t, err)
273+
require.Equal(t, []any{[]uint8("One"), "127.0.0.1", 9800}, res)
274+
}
275+
{
276+
res, err := udpRead(ctx, nil, []any{conn2, 100})
277+
require.Nil(t, err)
278+
require.Equal(t, []any{[]uint8("Two"), "127.0.0.1", 9800}, res)
279+
}
280+
{
281+
res, err := udpClose(ctx, nil, []any{conn1})
282+
require.Nil(t, err)
283+
require.Equal(t, "", res)
284+
}
285+
{
286+
res, err := udpClose(ctx, nil, []any{conn2})
287+
require.Nil(t, err)
288+
require.Equal(t, "", res)
289+
}
290+
}
291+
292+
func TestUDPNetworkUnboundClientAPI(t *testing.T) {
293+
ctx := t.Context()
294+
conn1, err := udpConnect(ctx, nil, []any{"", 0})
295+
require.Nil(t, err)
296+
require.Equal(t, uint(1), conn1)
297+
298+
conn2, err := udpConnect(ctx, nil, []any{"0.0.0.0", 9900})
299+
require.Nil(t, err)
300+
require.Equal(t, uint(2), conn2)
301+
302+
{
303+
res, err := udpWrite(ctx, nil, []any{conn1, "127.0.0.1", 9900, []byte("Hello")})
304+
require.Nil(t, err)
305+
require.Equal(t, 5, res)
306+
}
307+
{
308+
res, err := udpRead(ctx, nil, []any{conn2, 100})
309+
require.Nil(t, err)
310+
require.Equal(t, []uint8("Hello"), res.([]any)[0])
311+
}
312+
{
313+
res, err := udpWrite(ctx, nil, []any{conn1, "127.0.0.1", 9900, []byte("One")})
314+
require.Nil(t, err)
315+
require.Equal(t, 3, res)
316+
}
317+
{
318+
res, err := udpWrite(ctx, nil, []any{conn1, "127.0.0.1", 9900, []byte("Two")})
319+
require.Nil(t, err)
320+
require.Equal(t, 3, res)
321+
}
322+
{
323+
res, err := udpRead(ctx, nil, []any{conn2, 100})
324+
require.Nil(t, err)
325+
require.Equal(t, []uint8("One"), res.([]any)[0])
326+
}
327+
{
328+
res, err := udpRead(ctx, nil, []any{conn2, 100})
329+
require.Nil(t, err)
330+
require.Equal(t, []uint8("Two"), res.([]any)[0])
331+
}
332+
{
333+
res, err := udpClose(ctx, nil, []any{conn1})
334+
require.Nil(t, err)
335+
require.Equal(t, "", res)
336+
}
337+
{
338+
res, err := udpClose(ctx, nil, []any{conn2})
339+
require.Nil(t, err)
340+
require.Equal(t, "", res)
341+
}
342+
}

0 commit comments

Comments
 (0)