From c30cd6a2b25394430e988a33073f6007c0d24927 Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Tue, 25 Jul 2023 16:22:23 +0300 Subject: [PATCH] Modbus server examples --- .github/workflows/main.yml | 2 +- CHANGELOG.md | 3 +- README.md | 7 +- builder_external_test.go | 5 +- builder_test.go | 32 ++-- client.go | 110 ++++++------ client_test.go | 16 +- examples/server_and_request_test.go | 118 ++++++++++++ modbustest/helpers.go | 37 +++- modbustest/server.go | 146 --------------- packet/packet.go | 41 +++-- packet/packet_test.go | 80 +++++---- packet/registers.go | 7 +- packet/request.go | 4 +- scripts/.githooks/pre-commit | 4 +- serial.md | 3 + server/modbus.go | 41 +++++ server/server.go | 267 ++++++++++++++++++++++++++++ server/server_test.go | 120 +++++++++++++ 19 files changed, 747 insertions(+), 296 deletions(-) create mode 100644 examples/server_and_request_test.go delete mode 100644 modbustest/server.go create mode 100644 server/modbus.go create mode 100644 server/server.go create mode 100644 server/server_test.go diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 765e7e1..6ce8db6 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -13,7 +13,7 @@ permissions: contents: read # to fetch code (actions/checkout) env: - # run coverage and benchmarks only with the latest Go version + # run coverage only with the latest Go version LATEST_GO_VERSION: "1.20" diff --git a/CHANGELOG.md b/CHANGELOG.md index 3e96a04..b7ced2e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,8 +9,9 @@ to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ### Added -* Added `packet.IsLikeModbusTCP()` to check if given bytes are possibly TCP packet or start of packet. +* Added `packet.LooksLikeModbusTCP()` to check if given bytes are possibly TCP packet or start of packet. * Added `Parse*Request*` for every function type to help implement Modbus servers. +* Added `Server` package to implement your own modbus server ## [0.0.1] - 2021-04-11 diff --git a/README.md b/README.md index 233be18..d333efd 100644 --- a/README.md +++ b/README.md @@ -70,6 +70,7 @@ for _, req := range requests { assert.Equal(t, "alarm_do_1", fields[1].Field.Name) } ``` + ### RTU over serial port RTU examples to interact with serial port can be found from [serial.md](serial.md) @@ -77,7 +78,10 @@ RTU examples to interact with serial port can be found from [serial.md](serial.m ### Low level packets ```go -client := modbus.NewTCPClient(modbus.WithTimeouts(10*time.Second, 10*time.Second)) +client := modbus.NewTCPClientWithConfig(modbus.ClientConfig{ + WriteTimeout: 2 * time.Second, + ReadTimeout: 2 * time.Second, +}) if err := client.Connect(context.Background(), "localhost:5020"); err != nil { return err } @@ -103,6 +107,7 @@ uint32Var, err := registers.Uint32(17) // extract uint32 value from register 17 ``` To create single TCP packet use following methods. Use `RTU` suffix to create RTU packets. + ```go import "github.com/aldas/go-modbus-client/packet" diff --git a/builder_external_test.go b/builder_external_test.go index 631e1f7..1a9b134 100644 --- a/builder_external_test.go +++ b/builder_external_test.go @@ -41,14 +41,11 @@ func TestExternalUsage(t *testing.T) { assert.NoError(t, err) assert.Len(t, reqs, 1) - client := modbus.NewClient() + client := modbus.NewTCPClient() if err := client.Connect(context.Background(), addr); err != nil { return } - //for _, req := range reqs { - // - //} req := reqs[0] // skip looping as we always have 1 request in this example resp, err := client.Do(context.Background(), req) diff --git a/builder_test.go b/builder_test.go index 6a8745a..a2b6e6a 100644 --- a/builder_test.go +++ b/builder_test.go @@ -16,9 +16,6 @@ func TestBuilder_ReadHoldingRegistersTCP(t *testing.T) { receivedChan := make(chan []byte, 1) handler := func(received []byte, bytesRead int) (response []byte, closeConnection bool) { - if bytesRead == 0 { - return nil, false - } receivedChan <- received resp := packet.ReadHoldingRegistersResponseTCP{ MBAPHeader: packet.MBAPHeader{TransactionID: 123, ProtocolID: 0}, @@ -41,17 +38,25 @@ func TestBuilder_ReadHoldingRegistersTCP(t *testing.T) { assert.NoError(t, err) assert.Len(t, reqs, 1) - client := NewClient() - err = client.Connect(context.Background(), addr) + ctxReq, cancelReq := context.WithTimeout(context.Background(), 2*time.Second) + defer cancelReq() + + client := NewTCPClient() + err = client.Connect(ctxReq, addr) assert.NoError(t, err) request := reqs[0] - resp, err := client.Do(context.Background(), request) + resp, err := client.Do(ctxReq, request) assert.NoError(t, err) assert.NotNil(t, resp) - received := <-receivedChan - assert.Equal(t, []byte{0, 0, 0, 6, 0, 3, 0, 18, 0, 4}, received[2:]) // trim transaction ID + select { + case received := <-receivedChan: + assert.Equal(t, []byte{0, 0, 0, 6, 0, 3, 0, 18, 0, 4}, received[2:]) // trim transaction ID + default: + t.Errorf("nothing received") + } + } func TestBuilder_ReadHoldingRegistersRTU(t *testing.T) { @@ -60,9 +65,6 @@ func TestBuilder_ReadHoldingRegistersRTU(t *testing.T) { receivedChan := make(chan []byte, 1) handler := func(received []byte, bytesRead int) (response []byte, closeConnection bool) { - if bytesRead == 0 { - return nil, false - } receivedChan <- received resp := packet.ReadHoldingRegistersResponseRTU{ ReadHoldingRegistersResponse: packet.ReadHoldingRegistersResponse{ @@ -103,9 +105,6 @@ func TestBuilder_ReadInputRegistersTCP(t *testing.T) { receivedChan := make(chan []byte, 1) handler := func(received []byte, bytesRead int) (response []byte, closeConnection bool) { - if bytesRead == 0 { - return nil, false - } receivedChan <- received resp := packet.ReadInputRegistersResponseTCP{ MBAPHeader: packet.MBAPHeader{TransactionID: 123, ProtocolID: 0}, @@ -128,7 +127,7 @@ func TestBuilder_ReadInputRegistersTCP(t *testing.T) { assert.NoError(t, err) assert.Len(t, reqs, 1) - client := NewClient() + client := NewTCPClient() err = client.Connect(context.Background(), addr) assert.NoError(t, err) @@ -147,9 +146,6 @@ func TestBuilder_ReadInputRegistersRTU(t *testing.T) { receivedChan := make(chan []byte, 1) handler := func(received []byte, bytesRead int) (response []byte, closeConnection bool) { - if bytesRead == 0 { - return nil, false - } receivedChan <- received resp := packet.ReadInputRegistersResponseRTU{ ReadInputRegistersResponse: packet.ReadInputRegistersResponse{ diff --git a/client.go b/client.go index 0dea825..7c6178a 100644 --- a/client.go +++ b/client.go @@ -64,8 +64,22 @@ type ClientHooks interface { BeforeParse(received []byte) } -func defaultClient() *Client { - return &Client{ +// ClientConfig is configuration for Client +type ClientConfig struct { + // WriteTimeout is total amount of time writing the request can take after client returns error + WriteTimeout time.Duration + // ReadTimeout is total amount of time reading the response can take before client returns error + ReadTimeout time.Duration + + DialContextFunc func(ctx context.Context, address string) (net.Conn, error) + AsProtocolErrorFunc func(data []byte) error + ParseResponseFunc func(data []byte) (packet.Response, error) + + Hooks ClientHooks +} + +func defaultClient(conf ClientConfig) *Client { + c := &Client{ timeNow: time.Now, writeTimeout: defaultWriteTimeout, readTimeout: defaultReadTimeout, @@ -75,75 +89,57 @@ func defaultClient() *Client { asProtocolErrorFunc: packet.AsTCPErrorPacket, parseResponseFunc: packet.ParseTCPResponse, } -} -// NewTCPClient creates new instance of Modbus Client for Modbus TCP protocol -func NewTCPClient(opts ...ClientOptionFunc) *Client { - client := defaultClient() - for _, o := range opts { - o(client) + if conf.WriteTimeout > 0 { + c.writeTimeout = conf.WriteTimeout } - return client -} - -// NewRTUClient creates new instance of Modbus Client for Modbus RTU protocol -func NewRTUClient(opts ...ClientOptionFunc) *Client { - client := defaultClient() - client.asProtocolErrorFunc = packet.AsRTUErrorPacket - client.parseResponseFunc = packet.ParseRTUResponseWithCRC - - for _, o := range opts { - o(client) + if conf.ReadTimeout > 0 { + c.readTimeout = conf.ReadTimeout } - return client -} - -// NewClient creates new instance of Modbus Client with given options -func NewClient(opts ...ClientOptionFunc) *Client { - client := defaultClient() - for _, o := range opts { - o(client) + if conf.DialContextFunc != nil { + c.dialContextFunc = conf.DialContextFunc } - return client + if conf.AsProtocolErrorFunc != nil { + c.asProtocolErrorFunc = conf.AsProtocolErrorFunc + } + if conf.ParseResponseFunc != nil { + c.parseResponseFunc = conf.ParseResponseFunc + } + if conf.Hooks != nil { + c.hooks = conf.Hooks + } + return c } -// ClientOptionFunc is options type for NewClient function -type ClientOptionFunc func(c *Client) - -// WithProtocolErrorFunc is option to provide custom function for parsing error packet -func WithProtocolErrorFunc(errorFunc func(data []byte) error) func(c *Client) { - return func(c *Client) { - c.asProtocolErrorFunc = errorFunc - } +// NewTCPClient creates new instance of Modbus Client for Modbus TCP protocol +func NewTCPClient() *Client { + return NewTCPClientWithConfig(ClientConfig{}) } -// WithParseResponseFunc is option to provide custom function for parsing protocol packet -func WithParseResponseFunc(parseFunc func(data []byte) (packet.Response, error)) func(c *Client) { - return func(c *Client) { - c.parseResponseFunc = parseFunc - } +// NewTCPClientWithConfig creates new instance of Modbus Client for Modbus TCP protocol with given configuration options +func NewTCPClientWithConfig(conf ClientConfig) *Client { + client := defaultClient(conf) + client.asProtocolErrorFunc = packet.AsTCPErrorPacket + client.parseResponseFunc = packet.ParseTCPResponse + return client } -// WithDialContextFunc is option to provide custom function for creating new connection -func WithDialContextFunc(dialContextFunc func(ctx context.Context, address string) (net.Conn, error)) func(c *Client) { - return func(c *Client) { - c.dialContextFunc = dialContextFunc - } +// NewRTUClient creates new instance of Modbus Client for Modbus RTU protocol +func NewRTUClient() *Client { + return NewRTUClientWithConfig(ClientConfig{}) } -// WithTimeouts is option to for setting writing packet or reading packet timeouts -func WithTimeouts(writeTimeout time.Duration, readTimeout time.Duration) func(c *Client) { - return func(c *Client) { - c.writeTimeout = writeTimeout - c.readTimeout = readTimeout - } +// NewRTUClientWithConfig creates new instance of Modbus Client for Modbus RTU protocol with given configuration options +func NewRTUClientWithConfig(conf ClientConfig) *Client { + client := defaultClient(conf) + client.asProtocolErrorFunc = packet.AsRTUErrorPacket + client.parseResponseFunc = packet.ParseRTUResponseWithCRC + return client } -// WithHooks is option to set hooks in client -func WithHooks(logger ClientHooks) func(c *Client) { - return func(c *Client) { - c.hooks = logger - } +// NewClient creates new instance of Modbus Client with given configuration options +func NewClient(conf ClientConfig) *Client { + return defaultClient(conf) } // Connect opens network connection to Client to server. Context lifetime is only meant for this call. diff --git a/client_test.go b/client_test.go index 1db9575..af165e7 100644 --- a/client_test.go +++ b/client_test.go @@ -112,10 +112,14 @@ func (l *mockLogger) BeforeParse(received []byte) { func TestWithOptions(t *testing.T) { client := NewClient( - WithProtocolErrorFunc(packet.AsRTUErrorPacket), - WithParseResponseFunc(packet.ParseRTUResponse), - WithTimeouts(99*time.Second, 98*time.Second), - WithHooks(new(mockLogger)), + ClientConfig{ + WriteTimeout: 99 * time.Second, + ReadTimeout: 98 * time.Second, + DialContextFunc: nil, + AsProtocolErrorFunc: packet.AsRTUErrorPacket, + ParseResponseFunc: packet.ParseRTUResponse, + Hooks: new(mockLogger), + }, ) assert.NotNil(t, client.asProtocolErrorFunc) assert.NotNil(t, client.parseResponseFunc) @@ -146,7 +150,7 @@ func TestClient_Do_receivePacketWith1Read(t *testing.T) { logger.On("AfterEachRead", []byte{0x12, 0x34, 0x0, 0x0, 0x0, 0x5, 0x1, 0x1, 0x2, 0x0, 0x1}, 11, nil).Once() logger.On("BeforeParse", []byte{0x12, 0x34, 0x0, 0x0, 0x0, 0x5, 0x1, 0x1, 0x2, 0x0, 0x1}).Once() - client := NewTCPClient(WithHooks(logger)) + client := NewTCPClientWithConfig(ClientConfig{Hooks: logger}) client.conn = conn client.timeNow = func() time.Time { return exampleNow @@ -475,7 +479,7 @@ func TestClient_Do_ReadMoreBytesThanPacketCanBe(t *testing.T) { conn.On("Read", mock.Anything). Return(tcpPacketMaxLen+1, nil) - client := NewClient() + client := NewClient(ClientConfig{}) client.conn = conn client.timeNow = func() time.Time { return exampleNow diff --git a/examples/server_and_request_test.go b/examples/server_and_request_test.go new file mode 100644 index 0000000..19b2843 --- /dev/null +++ b/examples/server_and_request_test.go @@ -0,0 +1,118 @@ +package examples_test + +import ( + "context" + "errors" + "github.com/aldas/go-modbus-client" + "github.com/aldas/go-modbus-client/packet" + "github.com/aldas/go-modbus-client/server" + "log" + "net" + "os" + "os/signal" + "testing" + "time" +) + +func TestRequestToServer(t *testing.T) { + mbs := new(mbServer) + + serverAddrCh := make(chan string) + s := server.Server{ + // OnServeFunc is useful integration tests when in situations where actual server is spun up to serve requests + // in that case it is useful it start it on random (":0") port. This callback is run just before server starts + // listening for new connections + OnServeFunc: func(addr net.Addr) { + serverAddrCh <- addr.String() + log.Printf("listening on: %v\n", addr.String()) + }, + OnErrorFunc: nil, + OnAcceptFunc: nil, + } + + tCtx, tCancel := context.WithTimeout(context.Background(), 1*time.Second) + defer tCancel() + ctx, cancel := signal.NotifyContext(tCtx, os.Kill, os.Interrupt) + defer cancel() + + // we start the server and listen for incoming connections/data in separate goroutine. ListenAndServe is blocking call. + go func() { + err := s.ListenAndServe(ctx, "localhost:5020", mbs) + if err != nil && !errors.Is(err, server.ErrServerClosed) { + log.Printf("ListenAndServe end: %v", err) + } + }() + + select { + case <-ctx.Done(): + return + case serverAddr := <-serverAddrCh: // wait for server to "start" + // do the FC03 request + if err := doRequest(ctx, serverAddr); err != nil { + log.Printf("doRequest err: %v\n", err) + return + } + } + + // gracefully shut down the server. + // We could have used here: + //<-ctx.Done() + // to wait for ctrl+c or kill signal but for example we close the server after request has been done + graceful, gCancel := context.WithTimeout(context.Background(), 1*time.Second) + defer gCancel() + if err := s.Shutdown(graceful); err != nil { + log.Printf("Shutdown end: %v", err) + } +} + +func doRequest(ctx context.Context, serverAddress string) error { + client := modbus.NewTCPClientWithConfig(modbus.ClientConfig{ + WriteTimeout: 2 * time.Second, + ReadTimeout: 2 * time.Second, + }) + if err := client.Connect(ctx, serverAddress); err != nil { + return err + } + defer client.Close() + + unitID := uint8(1) + startAddress := uint16(10) + quantity := uint16(2) + req, err := packet.NewReadHoldingRegistersRequestTCP(unitID, startAddress, quantity) + if err != nil { + return err + } + + resp, err := client.Do(ctx, req) + if err != nil { + return err + } + + registers, err := resp.(*packet.ReadHoldingRegistersResponseTCP).AsRegisters(startAddress) + if err != nil { + return err + } + uint16Var, err := registers.Uint16(11) // extract uint16 value from register 11 + log.Printf("Received as register 11 value: %v (hex: %X)\n", uint16Var, uint16Var) + + return nil +} + +type mbServer struct { +} + +func (s *mbServer) Handle(ctx context.Context, received packet.Request) (packet.Response, error) { + switch req := received.(type) { + case *packet.ReadHoldingRegistersRequestTCP: + p := packet.ReadHoldingRegistersResponseTCP{ + MBAPHeader: req.MBAPHeader, + ReadHoldingRegistersResponse: packet.ReadHoldingRegistersResponse{ + UnitID: req.UnitID, + RegisterByteLen: 4, + Data: []byte{0x0, 0x1, 0x01, 0x02}, // register[0] = 0x0001, register[1] = 0x0102 + }, + } + return p, nil + } + return nil, packet.NewErrorParseTCP(packet.ErrIllegalFunction, "nope") +} diff --git a/modbustest/helpers.go b/modbustest/helpers.go index 89dcd4f..3579daa 100644 --- a/modbustest/helpers.go +++ b/modbustest/helpers.go @@ -3,18 +3,35 @@ package modbustest import ( "context" "errors" + "github.com/aldas/go-modbus-client/packet" + "github.com/aldas/go-modbus-client/server" "log" + "net" "time" ) // RunServerOnRandomPort is low level helper function for testing modbus packets. Method starts server in separate -// goroutine and runs it until given context is cancelled. Given ReadHandler is used by server to handle incoming data. -func RunServerOnRandomPort(ctx context.Context, handler ReadHandler) (string, error) { +// goroutine and runs it until given context is cancelled. Given PacketAssembler is used by server to handle incoming data. +func RunServerOnRandomPort( + ctx context.Context, + handler func(received []byte, bytesRead int) (response []byte, closeConnection bool), +) (string, error) { addrChan := make(chan string) serverErrChan := make(chan error) - server := Server{OnServeAddrChan: addrChan} + + rr := &rawReader{ + handler: handler, + } + srv := server.Server{ + AssemblerCreatorFunc: func(_ server.ModbusHandler) server.PacketAssembler { + return rr + }, + OnServeFunc: func(addr net.Addr) { + addrChan <- addr.String() + }, + } go func() { - if err := server.ListenAndServe(ctx, ":0", handler); err != nil { + if err := srv.ListenAndServe(ctx, ":0", rr); err != nil { log.Printf("server err: %v", err) serverErrChan <- err } @@ -29,3 +46,15 @@ func RunServerOnRandomPort(ctx context.Context, handler ReadHandler) (string, er return addr, nil } } + +type rawReader struct { + handler func(received []byte, bytesRead int) (response []byte, closeConnection bool) +} + +func (r *rawReader) Handle(ctx context.Context, received packet.Request) (packet.Response, error) { + panic("this is not called") +} + +func (r *rawReader) ReceiveRead(ctx context.Context, received []byte, bytesRead int) (response []byte, closeConnection bool) { + return r.handler(received, bytesRead) +} diff --git a/modbustest/server.go b/modbustest/server.go deleted file mode 100644 index ccb726b..0000000 --- a/modbustest/server.go +++ /dev/null @@ -1,146 +0,0 @@ -package modbustest - -import ( - "context" - "errors" - "fmt" - "io" - "log" - "net" - "os" - "sync" - "time" -) - -// ErrServerClosed is returned when server context is ended -var ErrServerClosed = errors.New("modbus test server closed") - -// Server simple TCP server implementation for testing modbus packets -type Server struct { - mu sync.RWMutex - listener net.Listener // for simplicity we only allow serving one listener - - OnServeAddrChan chan<- string - OnErrorFunc func(err error) -} - -// ReadHandler is function called when server reads bytes from client connection. -// -// Handler can be called even if no bytes are read from connection. In that case bytesRead==0. -// This is so you can emulate writing modbus packet as multiple fragments. -// return with closeConnection=true when you are done sending fragments and want to close connection -type ReadHandler func(received []byte, bytesRead int) (response []byte, closeConnection bool) - -// ListenAndServe starts accepting connection on given address and handles received data with handler function. -// Method blocks until context is cancelled -func (s *Server) ListenAndServe(ctx context.Context, address string, handler ReadHandler) error { - s.mu.Lock() - defer s.mu.Unlock() - - listener, err := net.Listen("tcp", address) - if err != nil { - return fmt.Errorf("modbustest listnener creation error: %w", err) - } - return s.serve(ctx, listener, handler) -} - -// Serve accepts connections from listener and handles received data with handler function. -// Method blocks until context is cancelled -func (s *Server) Serve(ctx context.Context, listener net.Listener, handler ReadHandler) error { - s.mu.Lock() - defer s.mu.Unlock() - - return s.serve(ctx, listener, handler) -} - -func (s *Server) serve(ctx context.Context, listener net.Listener, handler ReadHandler) error { - if handler == nil { - return errors.New("handler can not be nil") - } - if s.OnServeAddrChan != nil { - // when listener is started with ":0" (random port) this chan will be helpful knowing where to connect - // and if server is listening already - s.OnServeAddrChan <- listener.Addr().String() - } - onErrorFunc := s.OnErrorFunc - if onErrorFunc == nil { - onErrorFunc = func(err error) { - log.Printf("modbus test server connection error: %v", err) - } - } - - s.listener = listener - - l := onceCloseListener{Listener: listener} - defer l.Close() - - for { - conn, err := l.Accept() - if err != nil { - return err - } - - select { - case <-ctx.Done(): - return ErrServerClosed - default: - } - go handleConnection(ctx, conn, handler, onErrorFunc) - } -} - -func handleConnection(ctx context.Context, conn net.Conn, handler ReadHandler, onErrorFunc func(error)) { - defer conn.Close() - received := make([]byte, 300) - for { - select { - case <-ctx.Done(): - return - default: - } - - _ = conn.SetReadDeadline(time.Now().Add(500 * time.Microsecond)) // max 0.5ms block time for read per iteration - n, err := conn.Read(received) - if err != nil && !errors.Is(err, os.ErrDeadlineExceeded) { - if !errors.Is(err, io.EOF) { - onErrorFunc(err) - } - return // when read fails due some unknown error we close connection - } - // NB: handler can be called even if client did not send anything. It is up to developer to handle that case. - toSend, closeConn := handler(received[:n], n) - if toSend != nil { - _ = conn.SetWriteDeadline(time.Now().Add(500 * time.Microsecond)) - if _, err := conn.Write(toSend); err != nil { - onErrorFunc(err) - return // when write fails to client we close connection - } - } - if closeConn { - return - } - } -} - -// Addr returns currently running server address -func (s *Server) Addr() net.Addr { - s.mu.RLock() - defer s.mu.RUnlock() - - return s.listener.Addr() -} - -type onceCloseListener struct { - net.Listener - once sync.Once - closeErr error -} - -func (oc *onceCloseListener) Close() error { - oc.once.Do(oc.close) - return oc.closeErr -} - -func (oc *onceCloseListener) close() { - oc.closeErr = oc.Listener.Close() -} diff --git a/packet/packet.go b/packet/packet.go index c0ce0e8..ab07599 100644 --- a/packet/packet.go +++ b/packet/packet.go @@ -84,17 +84,24 @@ type LooksLikeType int const ( // DataTooShort is case when slice of bytes is too short to determine result - DataTooShort = iota + DataTooShort LooksLikeType = 0 // IsNotTPCPacket is case when slice of bytes can not be Modbus TCP packet - IsNotTPCPacket + IsNotTPCPacket LooksLikeType = 1 // LooksLikeTCPPacket is case when slice of bytes looks like Modbus TCP packet with supported function code - LooksLikeTCPPacket + LooksLikeTCPPacket LooksLikeType = 2 // UnsupportedFunctionCode is case when slice of bytes looks like Modbus TCP packet but function code value is not supported - UnsupportedFunctionCode + UnsupportedFunctionCode LooksLikeType = 3 ) -// IsLikeModbusTCP checks if given data starts with bytes that could be potentially parsed as Modbus TCP packet. -func IsLikeModbusTCP(data []byte, allowUnSupportedFunctionCodes bool) (expectedLen int, looksLike LooksLikeType) { +var ( + // ErrTCPDataTooShort is returned when received data is still too short to be actual Modbus TCP packet. + ErrTCPDataTooShort = NewErrorParseTCP(ErrUnknown, "data is too short to be a Modbus TCP packet") + // ErrIsNotTCPPacket is returned when received data does not look like Modbus TCP packet + ErrIsNotTCPPacket = NewErrorParseTCP(ErrUnknown, "data does not like Modbus TCP packet") +) + +// LooksLikeModbusTCP checks if given data starts with bytes that could be potentially parsed as Modbus TCP packet. +func LooksLikeModbusTCP(data []byte, allowUnSupportedFunctionCodes bool) (expectedLen int, error error) { // Example of first 8 bytes // 0x81 0x80 - transaction id (0,1) // 0x00 0x00 - protocol id (2,3) @@ -104,29 +111,37 @@ func IsLikeModbusTCP(data []byte, allowUnSupportedFunctionCodes bool) (expectedL // minimal amount is 9 bytes (header + unit id + function code + 1 byte of something ala error code) if len(data) < 9 { - return 0, DataTooShort + return 0, ErrTCPDataTooShort } if !(data[2] == 0x0 && data[3] == 0x0) { // check protocol id - return 0, IsNotTPCPacket + return 0, ErrIsNotTCPPacket } pduLen := binary.BigEndian.Uint16(data[4:6]) // number of bytes in the message to follow if pduLen < 3 { // every request is more than 2 bytes of PDU - return 0, IsNotTPCPacket + return 0, ErrIsNotTCPPacket } functionCode := data[7] // function code if functionCode == 0 { - return 0, IsNotTPCPacket + return 0, ErrIsNotTCPPacket } expectedLen = int(pduLen) + 6 if allowUnSupportedFunctionCodes { - return expectedLen, LooksLikeTCPPacket + return expectedLen, nil } for _, fc := range supportedFunctionCodes { if fc == functionCode { - return expectedLen, LooksLikeTCPPacket + return expectedLen, nil } } - return expectedLen, UnsupportedFunctionCode + return expectedLen, &ErrorParseTCP{ + Message: "unsupported function code", + Packet: ErrorResponseTCP{ + TransactionID: binary.BigEndian.Uint16(data[0:2]), + UnitID: data[6], + Function: functionCode, + Code: ErrIllegalFunction, + }, + } } // diff --git a/packet/packet_test.go b/packet/packet_test.go index 59ac6b8..3574979 100644 --- a/packet/packet_test.go +++ b/packet/packet_test.go @@ -72,84 +72,88 @@ func TestParseMBAPHeader(t *testing.T) { } } -func TestIsLikeModbusTCP(t *testing.T) { +func TestLooksLikeModbusTCP(t *testing.T) { var testCases = []struct { name string when []byte whenAllowUnsupportedFC bool expectLength int - expectLooksLike LooksLikeType + expectError string }{ { - name: "ok, full packet", - when: []byte{0x01, 0x02, 0x00, 0x00, 0x00, 0x06, 0x10, 0x01, 0x00, 0x6B, 0x00, 0x03}, - expectLength: 12, - expectLooksLike: LooksLikeTCPPacket, + name: "ok, full packet", + when: []byte{0x01, 0x02, 0x00, 0x00, 0x00, 0x06, 0x10, 0x01, 0x00, 0x6B, 0x00, 0x03}, + expectLength: 12, + expectError: "", }, { - name: "ok, fragment of packet", - when: []byte{0x01, 0x02, 0x00, 0x00, 0x00, 0x06, 0x10, 0x01, 0x00}, - expectLength: 12, - expectLooksLike: LooksLikeTCPPacket, + name: "ok, fragment of packet", + when: []byte{0x01, 0x02, 0x00, 0x00, 0x00, 0x06, 0x10, 0x01, 0x00}, + expectLength: 12, + expectError: "", }, { - name: "nok, ErrorResponseTCP (code=3)", - when: []byte{0x81, 0x80, 0x0, 0x0, 0x0, 0x3, 0x1, 0x82, 0x3}, - expectLength: 9, - expectLooksLike: UnsupportedFunctionCode, + name: "nok, ErrorResponseTCP (code=3)", + when: []byte{0x81, 0x80, 0x0, 0x0, 0x0, 0x3, 0x1, 0x82, 0x3}, + expectLength: 9, + expectError: "unsupported function code", }, { - name: "nok, too few bytes", - when: []byte{0x01, 0x02, 0x00, 0x00, 0x00, 0x06, 0x10, 0x01}, - expectLength: 0, - expectLooksLike: DataTooShort, + name: "nok, too few bytes", + when: []byte{0x01, 0x02, 0x00, 0x00, 0x00, 0x06, 0x10, 0x01}, + expectLength: 0, + expectError: "data is too short to be a Modbus TCP packet", }, { - name: "nok, invalid packet id, 1", - when: []byte{0x01, 0x02, 0x01 /* 0x00 */, 0x00, 0x00, 0x06, 0x10, 0x01, 0x00, 0x6B, 0x00, 0x03}, - expectLength: 0, - expectLooksLike: IsNotTPCPacket, + name: "nok, invalid packet id, 1", + when: []byte{0x01, 0x02, 0x01 /* 0x00 */, 0x00, 0x00, 0x06, 0x10, 0x01, 0x00, 0x6B, 0x00, 0x03}, + expectLength: 0, + expectError: "data does not like Modbus TCP packet", }, { - name: "nok, invalid packet id, 2", - when: []byte{0x01, 0x02, 0x00, 0x01 /* 0x00 */, 0x00, 0x06, 0x10, 0x01, 0x00, 0x6B, 0x00, 0x03}, - expectLength: 0, - expectLooksLike: IsNotTPCPacket, + name: "nok, invalid packet id, 2", + when: []byte{0x01, 0x02, 0x00, 0x01 /* 0x00 */, 0x00, 0x06, 0x10, 0x01, 0x00, 0x6B, 0x00, 0x03}, + expectLength: 0, + expectError: "data does not like Modbus TCP packet", }, { - name: "nok, pdu too short", - when: []byte{0x01, 0x02, 0x00, 0x00, 0x00, 0x02 /* 0x04+ */, 0x10, 0x01, 0x00, 0x6B, 0x00, 0x03}, - expectLength: 0, - expectLooksLike: IsNotTPCPacket, + name: "nok, pdu too short", + when: []byte{0x01, 0x02, 0x00, 0x00, 0x00, 0x02 /* 0x04+ */, 0x10, 0x01, 0x00, 0x6B, 0x00, 0x03}, + expectLength: 0, + expectError: "data does not like Modbus TCP packet", }, { - name: "nok, function code = 0", - when: []byte{0x01, 0x02, 0x00, 0x00, 0x00, 0x06, 0x10, 0x00 /* 0x01 */, 0x00, 0x6B, 0x00, 0x03}, - expectLength: 0, - expectLooksLike: IsNotTPCPacket, + name: "nok, function code = 0", + when: []byte{0x01, 0x02, 0x00, 0x00, 0x00, 0x06, 0x10, 0x00 /* 0x01 */, 0x00, 0x6B, 0x00, 0x03}, + expectLength: 0, + expectError: "data does not like Modbus TCP packet", }, { name: "ok, allow unsupported function code = 1F", when: []byte{0x01, 0x02, 0x00, 0x00, 0x00, 0x06, 0x10, 0x1f /* 0x01 */, 0x00, 0x6B, 0x00, 0x03}, whenAllowUnsupportedFC: true, expectLength: 12, - expectLooksLike: LooksLikeTCPPacket, + expectError: "", }, { name: "ok, unsupported function code = 1F", when: []byte{0x01, 0x02, 0x00, 0x00, 0x00, 0x06, 0x10, 0x1f /* 0x01 */, 0x00, 0x6B, 0x00, 0x03}, whenAllowUnsupportedFC: false, expectLength: 12, - expectLooksLike: UnsupportedFunctionCode, + expectError: "unsupported function code", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - expectedLen, looksLike := IsLikeModbusTCP(tc.when, tc.whenAllowUnsupportedFC) + expectedLen, err := LooksLikeModbusTCP(tc.when, tc.whenAllowUnsupportedFC) assert.Equal(t, tc.expectLength, expectedLen) - assert.Equal(t, tc.expectLooksLike, looksLike) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } }) } } diff --git a/packet/registers.go b/packet/registers.go index 4ea8f25..67fefe7 100644 --- a/packet/registers.go +++ b/packet/registers.go @@ -42,12 +42,13 @@ import ( // Source: http://unixpapa.com/incnote/byteorder.html // // 32bit (dword) integer is in: -// Little Endian (ABCD) = 0x01020304 (0x04 + (0x03 << 8) + (0x02 << 16) + (0x01 << 24)) +// +// Little Endian (ABCD) = 0x01020304 (0x04 + (0x03 << 8) + (0x02 << 16) + (0x01 << 24)) // // May be sent over tcp/udp as: -// Big Endian (DCBA) = 0x04030201 -// Big Endian Low Word First (BADC) = 0x02010403 <-- used by WAGO 750-XXX to send modbus packets over tcp/udp // +// Big Endian (DCBA) = 0x04030201 +// Big Endian Low Word First (BADC) = 0x02010403 <-- used by WAGO 750-XXX to send modbus packets over tcp/udp const ( useDefaultByteOrder ByteOrder = 0 // BigEndian system stores the most significant byte of a word at the smallest memory address and the least diff --git a/packet/request.go b/packet/request.go index 0aca361..ca357bf 100644 --- a/packet/request.go +++ b/packet/request.go @@ -19,7 +19,7 @@ type Request interface { // ParseTCPRequest parses given bytes into modbus TCP request packet or returns error func ParseTCPRequest(data []byte) (Request, error) { if len(data) < 9 { - return nil, errors.New("data is too short to be a Modbus TCP packet") + return nil, ErrTCPDataTooShort } functionCode := data[7] switch functionCode { @@ -42,7 +42,7 @@ func ParseTCPRequest(data []byte) (Request, error) { case FunctionReadWriteMultipleRegisters: // 0x17 return ParseReadWriteMultipleRegistersRequestTCP(data) default: - return nil, fmt.Errorf("unknown function code parsed: %v", functionCode) + return nil, NewErrorParseTCP(ErrIllegalFunction, fmt.Sprintf("unknown function code parsed: %v", functionCode)) } } diff --git a/scripts/.githooks/pre-commit b/scripts/.githooks/pre-commit index df2aadb..2c106ed 100755 --- a/scripts/.githooks/pre-commit +++ b/scripts/.githooks/pre-commit @@ -1,8 +1,8 @@ #!/usr/bin/env bash if [[ -z "$(which golint)" ]]; then - echo "golint not found, executing: go get github.com/golang/lint/golint" - go get github.com/golang/lint/golint + echo "golint not found, executing: go install golang.org/x/lint/golint@latest" + go install golang.org/x/lint/golint@latest fi STAGED_GO_FILES=$(git diff --cached --name-only | grep ".go$") diff --git a/serial.md b/serial.md index f1a2f84..279b718 100644 --- a/serial.md +++ b/serial.md @@ -6,6 +6,7 @@ using cheap [USB To RS485 422](https://www.aliexpress.com/item/32888122294.html) ## github.com/jacobsa/go-serial/serial Example for: [github.com/jacobsa/go-serial/serial](https://github.com/jacobsa/go-serial/) + ```go // import "github.com/jacobsa/go-serial/serial" serialPort, err := serial.Open(serial.OpenOptions{ @@ -40,6 +41,7 @@ fmt.Printf("temperature: %v\n", float32(temp) / 10) ## github.com/tarm/serial Example for: [github.com/tarm/serial](https://github.com/tarm/serial) + ```go serialPort, err := serial.OpenPort(&serial.Config{Name: "/dev/ttyUSB0", Baud: 9600, ReadTimeout: 2 * time.Second}) if err != nil { @@ -65,6 +67,7 @@ fmt.Printf("temperature: %v\n", float32(temp) / 10) ## Raw syscall Example for with raw syscalls (only on unix/linux systems) + ```go serialPort, _ := os.OpenFile("/dev/ttyUSB0", syscall.O_RDWR|syscall.O_NOCTTY|syscall.O_NONBLOCK, 0666) diff --git a/server/modbus.go b/server/modbus.go new file mode 100644 index 0000000..77fcba1 --- /dev/null +++ b/server/modbus.go @@ -0,0 +1,41 @@ +package server + +import ( + "bytes" + "context" + "errors" + "github.com/aldas/go-modbus-client/packet" +) + +// modbusTCPAssembler assembles read data into complete packets and calls ModbusHandler with assembled packet +type modbusTCPAssembler struct { + handler ModbusHandler + received bytes.Buffer +} + +func (m *modbusTCPAssembler) ReceiveRead(ctx context.Context, received []byte, bytesRead int) (response []byte, closeConnection bool) { + m.received.Write(received) + + n, err := packet.LooksLikeModbusTCP(m.received.Bytes(), false) + if err == packet.ErrTCPDataTooShort { + return nil, false // wait for more data to arrive + } else if err != nil { + return err.(*packet.ErrorParseTCP).Bytes(), false + } + + p, err := packet.ParseTCPRequest(m.received.Next(n)) + if err != nil { + return err.(*packet.ErrorParseTCP).Bytes(), false + } + + resp, err := m.handler.Handle(ctx, p) + if err != nil { + var target *packet.ErrorParseTCP + if errors.As(err, &target) { + return target.Bytes(), false + } + return packet.NewErrorParseTCP(packet.ErrUnknown, err.Error()).Bytes(), false + } + + return resp.Bytes(), false +} diff --git a/server/server.go b/server/server.go new file mode 100644 index 0000000..5bb24c1 --- /dev/null +++ b/server/server.go @@ -0,0 +1,267 @@ +package server + +import ( + "context" + "errors" + "fmt" + "github.com/aldas/go-modbus-client/packet" + "io" + "log" + "net" + "os" + "sync" + "sync/atomic" + "time" +) + +const ( + readTimeout = 1 * time.Millisecond + writeTimeout = 1 * time.Millisecond + idleTimeout = 25 * time.Second +) + +var ( + // ErrServerClosed is returned when server context is ended (by shutdown) + ErrServerClosed = errors.New("modbus server closed") +) + +// PacketAssembler is called when server reads data from client connection. Is responsible for assembling data read +// from the connection to whole modbus packet. +// +// return with closeConnection=true when you are done sending and want to close connection +type PacketAssembler interface { + ReceiveRead(ctx context.Context, received []byte, bytesRead int) (response []byte, closeConnection bool) +} + +// ModbusHandler calls Handle method when it has received enough data to be parsed into Modbus packet. +type ModbusHandler interface { + Handle(ctx context.Context, received packet.Request) (packet.Response, error) +} + +// Server simple TCP server implementation for server to serve modbus packets. +// Each connection is handled in separate goroutine, in which panics are recovered. +// +// Public fields are not designed to be goroutine safe. Do not mutate after server has been started +type Server struct { + mu sync.RWMutex + listener net.Listener // for simplicity, we only allow serving one listener + isShutdown atomic.Bool + activeConnections map[*connection]struct{} + + // AssemblerCreatorFunc creates Assembler for each connetion to assemble different read byte fragments into complete + // modbus packet. Could have different implementations for TCP or RTU packets + AssemblerCreatorFunc func(handler ModbusHandler) PacketAssembler + + // OnServeFunc allows capturing listener address just before server starts to accepting connections. This is useful + // for testing when listener is started with random port `:0`. + OnServeFunc func(addr net.Addr) + OnErrorFunc func(err error) + OnAcceptFunc func(ctx context.Context, remoteAddr net.Addr) error +} + +type connection struct { + conn net.Conn + isBeingHandled atomic.Bool + assembler PacketAssembler + + onErrorFunc func(error) +} + +// ListenAndServe starts accepting connection on given address and handles received data with handler function. +// Method blocks until context is cancelled +func (s *Server) ListenAndServe(ctx context.Context, address string, handler ModbusHandler) error { + listener, err := net.Listen("tcp", address) + if err != nil { + return fmt.Errorf("modbus listnener creation error: %w", err) + } + return s.serve(ctx, listener, handler) +} + +// Serve accepts connections from listener and handles received data with handler function. +// Method blocks until context is cancelled +func (s *Server) Serve(ctx context.Context, listener net.Listener, handler ModbusHandler) error { + return s.serve(ctx, listener, handler) +} + +func (s *Server) serve(ctx context.Context, listener net.Listener, handler ModbusHandler) error { + if s.AssemblerCreatorFunc == nil { + s.AssemblerCreatorFunc = func(handler ModbusHandler) PacketAssembler { + return &modbusTCPAssembler{handler: handler} + } + } + onErrorFunc := s.OnErrorFunc + if onErrorFunc == nil { + onErrorFunc = func(err error) { + log.Printf("modbus server connection error: %v", err) + } + } + if s.OnServeFunc != nil { + // when listener is started with ":0" (random port) this will be helpful knowing where to connect + // and if server is listening already + s.OnServeFunc(listener.Addr()) + } + + s.listener = listener + l := onceCloseListener{Listener: listener} + defer l.Close() + + for { + netConn, err := l.Accept() + if err != nil { + if s.isShutdown.Load() { + return ErrServerClosed + } + return err + } + + if s.OnAcceptFunc != nil { + if err := s.OnAcceptFunc(ctx, netConn.RemoteAddr()); err != nil { + continue + } + } + + select { + case <-ctx.Done(): + return ErrServerClosed + default: + } + + c := &connection{ + conn: netConn, + isBeingHandled: atomic.Bool{}, + assembler: s.AssemblerCreatorFunc(handler), + onErrorFunc: onErrorFunc, + } + s.trackConn(c, true) + go func(ctx context.Context, conn *connection) { + defer func() { + if rec := recover(); rec != nil { + conn.onErrorFunc(fmt.Errorf("recovered panic in handler, %v", rec)) + } + if err := conn.conn.Close(); err != nil { + conn.onErrorFunc(fmt.Errorf("failed to close handler connection, err: %w", err)) + } + s.trackConn(c, false) + }() + conn.handle(ctx) + }(ctx, c) + } +} + +type onceCloseListener struct { + net.Listener + once sync.Once + closeErr error +} + +func (oc *onceCloseListener) Close() error { + oc.once.Do(oc.close) + return oc.closeErr +} + +func (oc *onceCloseListener) close() { + oc.closeErr = oc.Listener.Close() +} + +func (s *Server) trackConn(c *connection, isAdd bool) { + // this is how http.Server does it + s.mu.Lock() + defer s.mu.Unlock() + + if s.activeConnections == nil { + s.activeConnections = make(map[*connection]struct{}) + } + if isAdd { + s.activeConnections[c] = struct{}{} + } else { + delete(s.activeConnections, c) + } +} + +func (c *connection) handle(ctx context.Context) { + cCtx, cCancel := context.WithCancel(ctx) + defer cCancel() + + conn := c.conn + var lastReceived time.Time + received := make([]byte, 300) + for { + select { + case <-cCtx.Done(): + return + default: + } + + _ = conn.SetReadDeadline(time.Now().Add(readTimeout)) + n, err := conn.Read(received) + if err != nil && !errors.Is(err, os.ErrDeadlineExceeded) { + if !errors.Is(err, io.EOF) { + c.onErrorFunc(err) + } + return // when read fails due some unknown error we close connection + } + if n > 0 { + lastReceived = time.Now() + } else if time.Now().Sub(lastReceived) > idleTimeout { + return // close idle connection + } else { + continue // nothing read and not idle yet + } + + c.isBeingHandled.Store(true) + toSend, closeConn := c.assembler.ReceiveRead(cCtx, received[0:n], n) + if toSend != nil { + _ = conn.SetWriteDeadline(time.Now().Add(writeTimeout)) + if _, err := conn.Write(toSend); err != nil { + c.onErrorFunc(err) + return // when write fails to client we close connection + } + } + c.isBeingHandled.Store(false) + if closeConn { + return + } + } +} + +// Addr returns currently running server address +func (s *Server) Addr() net.Addr { + s.mu.RLock() + defer s.mu.RUnlock() + + return s.listener.Addr() +} + +// Shutdown gracefully shuts down the server without interrupting any active connections. +// Works similarly as `http.Server.Shutdown()` +func (s *Server) Shutdown(ctx context.Context) error { + s.mu.Lock() + defer s.mu.Unlock() + s.isShutdown.Store(true) + + err := s.listener.Close() + + timer := time.NewTimer(50 * time.Millisecond) + defer timer.Stop() + for { + allIdle := true + for c := range s.activeConnections { + if c.isBeingHandled.Load() { + allIdle = false + continue + } + (*c).conn.Close() + delete(s.activeConnections, c) + } + if allIdle { + return err + } + + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + timer.Reset(50 * time.Millisecond) + } + } +} diff --git a/server/server_test.go b/server/server_test.go new file mode 100644 index 0000000..d464e5d --- /dev/null +++ b/server/server_test.go @@ -0,0 +1,120 @@ +package server + +import ( + "context" + "errors" + "github.com/aldas/go-modbus-client" + "github.com/aldas/go-modbus-client/packet" + "github.com/stretchr/testify/assert" + "net" + "os" + "os/signal" + "testing" + "time" +) + +func TestRequestToServer(t *testing.T) { + mbs := new(mbServer) + + serverAddrCh := make(chan string) + s := Server{ + OnServeFunc: func(addr net.Addr) { + serverAddrCh <- addr.String() + }, + OnErrorFunc: nil, + OnAcceptFunc: nil, + } + + tCtx, tCancel := context.WithTimeout(context.Background(), 1*time.Second) + defer tCancel() + ctx, cancel := signal.NotifyContext(tCtx, os.Kill, os.Interrupt) + defer cancel() + + // we start the server and listen for incoming connections/data in separate goroutine. ListenAndServe is blocking call. + go func() { + err := s.ListenAndServe(ctx, "localhost:5020", mbs) + if err != nil && !errors.Is(err, ErrServerClosed) { + assert.NoError(t, err) + } + }() + + select { + case <-ctx.Done(): + return + case serverAddr := <-serverAddrCh: // wait for server to "start" + register11, err := doRequest(ctx, serverAddr) + assert.NoError(t, err) + assert.Equal(t, uint16(258), register11) + } + + graceful, gCancel := context.WithTimeout(context.Background(), 1*time.Second) + defer gCancel() + if err := s.Shutdown(graceful); err != nil { + assert.NoError(t, err) + } +} + +func doRequest(ctx context.Context, serverAddress string) (uint16, error) { + client := modbus.NewTCPClientWithConfig(modbus.ClientConfig{ + WriteTimeout: 2 * time.Second, + ReadTimeout: 2 * time.Second, + }) + if err := client.Connect(ctx, serverAddress); err != nil { + return 0, err + } + defer client.Close() + + unitID := uint8(1) + startAddress := uint16(10) + quantity := uint16(2) + req, err := packet.NewReadHoldingRegistersRequestTCP(unitID, startAddress, quantity) + if err != nil { + return 0, err + } + + resp, err := client.Do(ctx, req) + if err != nil { + return 0, err + } + + registers, err := resp.(*packet.ReadHoldingRegistersResponseTCP).AsRegisters(startAddress) + if err != nil { + return 0, err + } + + return registers.Uint16(11) +} + +type mbServer struct { +} + +func (s *mbServer) Handle(ctx context.Context, received packet.Request) (packet.Response, error) { + switch req := received.(type) { + case *packet.ReadHoldingRegistersRequestTCP: + p := packet.ReadHoldingRegistersResponseTCP{ + MBAPHeader: req.MBAPHeader, + ReadHoldingRegistersResponse: packet.ReadHoldingRegistersResponse{ + UnitID: req.UnitID, + RegisterByteLen: 4, + Data: []byte{0x0, 0x1, 0x01, 0x02}, // register[0] = 0x0001, register[1] = 0x0102 + }, + } + return p, nil + } + return nil, packet.NewErrorParseTCP(packet.ErrIllegalFunction, "nope") +} + +func TestServer_Addr(t *testing.T) { + listener, err := net.Listen("tcp", ":0") + if !assert.NoError(t, err) { + return + } + defer listener.Close() + + lAddr := listener.Addr().String() + + s := Server{ + listener: listener, + } + assert.Equal(t, lAddr, s.Addr().String()) +}