Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
213 changes: 207 additions & 6 deletions network-api/network-api.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,18 @@ func Register(router *msgpackrouter.Router) {

_ = router.RegisterMethod("tcp/connectSSL", tcpConnectSSL)

_ = router.RegisterMethod("udp/connect", udpConnect)
_ = router.RegisterMethod("udp/write", udpWrite)
_ = router.RegisterMethod("udp/awaitRead", udpAwaitRead)
_ = router.RegisterMethod("udp/read", udpRead)
_ = router.RegisterMethod("udp/close", udpClose)
}

var lock sync.RWMutex
var liveConnections = make(map[uint]net.Conn)
var liveListeners = make(map[uint]net.Listener)
var liveUdpConnections = make(map[uint]net.PacketConn)
var udpReadBuffers = make(map[uint][]byte)
var nextConnectionID atomic.Uint32

// takeLockAndGenerateNextID generates a new unique ID for a connection or listener.
Expand Down Expand Up @@ -213,8 +220,8 @@ func tcpCloseListener(ctx context.Context, rpc *msgpackrpc.Connection, params []
}

func tcpRead(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) {
if len(params) != 2 {
return nil, []any{1, "Invalid number of parameters, expected (connection ID, max bytes to read)"}
if len(params) != 2 && len(params) != 3 {
return nil, []any{1, "Invalid number of parameters, expected (connection ID, max bytes to read[, optional timeout in ms])"}
}
id, ok := msgpackrpc.ToUint(params[0])
if !ok {
Expand All @@ -230,12 +237,20 @@ func tcpRead(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_re
if !ok {
return nil, []any{1, "Invalid parameter type, expected int for max bytes to read"}
}
var deadline time.Time // default value == no timeout
if len(params) == 2 {
// It seems that there is no way to set a 0 ms timeout (immediate return) on a TCP connection.
// Setting the read deadline to time.Now() will always returns an empty (zero bytes)
// read, so we set it by default to a very short duration in the future (1 ms).
deadline = time.Now().Add(time.Millisecond)
} else if ms, ok := msgpackrpc.ToInt(params[2]); !ok {
return nil, []any{1, "Invalid parameter type, expected int for timeout in ms"}
} else if ms > 0 {
deadline = time.Now().Add(time.Duration(ms) * time.Millisecond)
}

buffer := make([]byte, maxBytes)
// It seems that the only way to make a non-blocking read is to set a read deadline.
// BTW setting the read deadline to time.Now() will always returns an empty (zero bytes)
// read, so we set it to a very short duration in the future.
if err := conn.SetReadDeadline(time.Now().Add(time.Millisecond)); err != nil {
if err := conn.SetReadDeadline(deadline); err != nil {
return nil, []any{3, "Failed to set read timeout: " + err.Error()}
}
n, err := conn.Read(buffer)
Expand Down Expand Up @@ -328,3 +343,189 @@ func tcpConnectSSL(ctx context.Context, rpc *msgpackrpc.Connection, params []any
unlock()
return id, nil
}

func udpConnect(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) {
if len(params) != 2 {
return nil, []any{1, "Invalid number of parameters, expected server address and port"}
}
serverAddr, ok := params[0].(string)
if !ok {
return nil, []any{1, "Invalid parameter type, expected string for server address"}
}
serverPort, ok := msgpackrpc.ToUint(params[1])
if !ok {
return nil, []any{1, "Invalid parameter type, expected uint16 for server port"}
}

serverAddr = net.JoinHostPort(serverAddr, fmt.Sprintf("%d", serverPort))
udpAddr, err := net.ResolveUDPAddr("udp", serverAddr)
if err != nil {
return nil, []any{2, "Failed to resolve UDP address: " + err.Error()}
}
udpConn, err := net.ListenUDP("udp", udpAddr)
if err != nil {
return nil, []any{2, "Failed to connect to server: " + err.Error()}
}

// Successfully opened UDP channel

id, unlock := takeLockAndGenerateNextID()
liveUdpConnections[id] = udpConn
unlock()
return id, nil
}

func udpWrite(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) {
if len(params) != 4 {
return nil, []any{1, "Invalid number of parameters, expected udpConnId, dest address, dest port, payload"}
}
id, ok := msgpackrpc.ToUint(params[0])
if !ok {
return nil, []any{1, "Invalid parameter type, expected int for UDP connection ID"}
}
targetIP, ok := params[1].(string)
if !ok {
return nil, []any{1, "Invalid parameter type, expected string for server address"}
}
targetPort, ok := msgpackrpc.ToUint(params[2])
if !ok {
return nil, []any{1, "Invalid parameter type, expected uint16 for server port"}
}
data, ok := params[3].([]byte)
if !ok {
if dataStr, ok := params[3].(string); ok {
data = []byte(dataStr)
} else {
// If data is not []byte or string, return an error
return nil, []any{1, "Invalid parameter type, expected []byte or string for data to write"}
}
}

lock.RLock()
udpConn, ok := liveUdpConnections[id]
lock.RUnlock()
if !ok {
return nil, []any{2, fmt.Sprintf("UDP connection not found for ID: %d", id)}
}

targetAddr := net.JoinHostPort(targetIP, fmt.Sprintf("%d", targetPort))
addr, err := net.ResolveUDPAddr("udp", targetAddr) // TODO: This is inefficient, implement some caching
if err != nil {
return nil, []any{3, "Failed to resolve target address: " + err.Error()}
}
if n, err := udpConn.WriteTo(data, addr); err != nil {
return nil, []any{4, "Failed to write to UDP connection: " + err.Error()}
} else {
return n, nil
}
}

func udpAwaitRead(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) {
if len(params) != 1 && len(params) != 2 {
return nil, []any{1, "Invalid number of parameters, expected (UDP connection ID[, optional timeout in ms])"}
}
id, ok := msgpackrpc.ToUint(params[0])
if !ok {
return nil, []any{1, "Invalid parameter type, expected uint for UDP connection ID"}
}
var deadline time.Time // default value == no timeout
if len(params) == 2 {
if ms, ok := msgpackrpc.ToInt(params[1]); !ok {
return nil, []any{1, "Invalid parameter type, expected int for timeout in ms"}
} else if ms > 0 {
deadline = time.Now().Add(time.Duration(ms) * time.Millisecond)
}
}

lock.RLock()
udpConn, ok := liveUdpConnections[id]
lock.RUnlock()
if !ok {
return nil, []any{2, fmt.Sprintf("UDP connection not found for ID: %d", id)}
}
if err := udpConn.SetReadDeadline(deadline); err != nil {
return nil, []any{3, "Failed to set read deadline: " + err.Error()}
}
buffer := make([]byte, 64*1024) // 64 KB buffer
n, addr, err := udpConn.ReadFrom(buffer)
if errors.Is(err, os.ErrDeadlineExceeded) {
// timeout
return nil, []any{5, "Timeout"}
}
if err != nil {
return nil, []any{3, "Failed to read from UDP connection: " + err.Error()}
}
host, portStr, err := net.SplitHostPort(addr.String())
if err != nil {
// Should never fail, but...
return nil, []any{4, "Failed to parse source address: " + err.Error()}
}
port, err := strconv.Atoi(portStr)
if err != nil {
// Should never fail, but...
return nil, []any{4, "Failed to parse source address: " + err.Error()}
}

lock.Lock()
udpReadBuffers[id] = buffer[:n]
lock.Unlock()
return []any{n, host, port}, nil
}

func udpRead(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) {
if len(params) != 2 && len(params) != 3 {
return nil, []any{1, "Invalid number of parameters, expected (UDP connection ID, max bytes to read)"}
}
id, ok := msgpackrpc.ToUint(params[0])
if !ok {
return nil, []any{1, "Invalid parameter type, expected uint for UDP connection ID"}
}
maxBytes, ok := msgpackrpc.ToUint(params[1])
if !ok {
return nil, []any{1, "Invalid parameter type, expected uint for max bytes to read"}
}

lock.Lock()
buffer, exists := udpReadBuffers[id]
n := uint(len(buffer))
if exists {
// keep the remainder of the buffer for the next read
if n > maxBytes {
udpReadBuffers[id] = buffer[maxBytes:]
n = maxBytes
} else {
udpReadBuffers[id] = nil
}
}
lock.Unlock()

return buffer[:n], nil
}

func udpClose(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) {
if len(params) != 1 {
return nil, []any{1, "Invalid number of parameters, expected UDP connection ID"}
}
id, ok := msgpackrpc.ToUint(params[0])
if !ok {
return nil, []any{1, "Invalid parameter type, expected int for UDP connection ID"}
}

lock.Lock()
udpConn, existsConn := liveUdpConnections[id]
delete(liveUdpConnections, id)
delete(udpReadBuffers, id)
lock.Unlock()

if !existsConn {
return nil, []any{2, fmt.Sprintf("UDP connection not found for ID: %d", id)}
}

// Close the connection if it exists
// We do not return an error to the caller if the close operation fails, as it is not critical,
// but we only log the error for debugging purposes.
if err := udpConn.Close(); err != nil {
return err.Error(), nil
}
return "", nil
}
Loading