Skip to content

Commit

Permalink
Merge 3871a63 into 273cae5
Browse files Browse the repository at this point in the history
  • Loading branch information
medzin committed Jan 22, 2018
2 parents 273cae5 + 3871a63 commit 4c668fa
Show file tree
Hide file tree
Showing 7 changed files with 258 additions and 51 deletions.
28 changes: 28 additions & 0 deletions xnet/error.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package xnet

import "fmt"

// MultiError is returned by batch operations when there are errors with
// particular elements.
type MultiError []error

func (m MultiError) Error() string {
s, n := "", 0
for _, e := range m {
if e != nil {
if n == 0 {
s = e.Error()
}
n++
}
}
switch n {
case 0:
return "(0 errors)"
case 1:
return s
case 2:
return s + " (and 1 other error)"
}
return fmt.Sprintf("%s (and %d other errors)", s, n-1)
}
69 changes: 69 additions & 0 deletions xnet/tcp.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package xnet

import (
"fmt"
"net"

log "github.com/sirupsen/logrus"
)

// TCPSender is a Sender implementation that can write payload to the network
// address and reuses TCP connections for the same addresses.
type TCPSender struct {
Dialer net.Dialer

connections map[Address]net.Conn
}

// Send sends given payload to passed address. Data is sent using pool of TCP
// connections. It returns number of bytes sent and error - if there was any.
func (s *TCPSender) Send(addr Address, payload []byte) (int, error) {
if s.connections == nil {
s.connections = make(map[Address]net.Conn)
}
conn, ok := s.connections[addr]
if !ok {
newConn, err := s.dial(addr)
if err != nil {
return 0, fmt.Errorf("unable to dial %s address: %s", addr, err)
}
s.connections[addr] = newConn
conn = newConn
}
n, err := conn.Write(payload)
if err != nil {
// let's be nice and at least try to close connection on our side
closeErr := s.connections[addr].Close()
if closeErr != nil {
log.WithError(closeErr).Warn("Unable to close TCP connection properly")
}
delete(s.connections, addr)
}
return n, err
}

// Release frees system socket used by sender.
func (s *TCPSender) Release() error {
if s.connections == nil {
return nil
}
var errs []error
for _, conn := range s.connections {
if err := conn.Close(); err != nil {
errs = append(errs, err)
}
}
s.connections = nil
if len(errs) > 0 {
return MultiError(errs)
}
return nil
}

func (s *TCPSender) dial(addr Address) (net.Conn, error) {
conn, err := s.Dialer.Dial("tcp", string(addr))
if err != nil {
return nil, err // we want plain error here
}
return conn, nil
}
71 changes: 71 additions & 0 deletions xnet/tcp_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package xnet

import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/allegro/mesos-executor/xnet/xnettest"
)

func TestIfTCPNetworkSenderReusesConnections(t *testing.T) {
listener1, results1, err := xnettest.LoopbackServer("tcp")
require.NoError(t, err)
defer listener1.Close()
listener2, results2, err := xnettest.LoopbackServer("tcp")
require.NoError(t, err)
defer listener2.Close()

sender := &TCPSender{}
defer sender.Release()

_, err = sender.Send(Address(listener1.Addr().String()), []byte("test"))
require.NoError(t, err)
<-results1

_, err = sender.Send(Address(listener1.Addr().String()), []byte("test"))
require.NoError(t, err)
<-results1

_, err = sender.Send(Address(listener2.Addr().String()), []byte("test"))
require.NoError(t, err)
<-results2

assert.Len(t, sender.connections, 2)
}

func TestIfTCPNetworkSenderReleasesResources(t *testing.T) {
listener, _, err := xnettest.LoopbackServer("tcp")
require.NoError(t, err)
defer listener.Close()

sender := &TCPSender{}
_, err = sender.Send(Address(listener.Addr().String()), []byte("test"))
require.NoError(t, err)
sender.Release()

assert.Empty(t, sender.connections)
}

func TestIfTCPNetworkSenderReturnsNumberOfSentBytes(t *testing.T) {
listener, results, err := xnettest.LoopbackServer("tcp")
require.NoError(t, err)
defer listener.Close()

sender := &TCPSender{}
bytesSent, err := sender.Send(Address(listener.Addr().String()), []byte("test"))

assert.NoError(t, err)
assert.Equal(t, len([]byte("test")), bytesSent)
assert.Equal(t, []byte("test"), <-results)
}

func TestIfTCPNetworkSenderReturnsErrorWhenConnectionUnavailable(t *testing.T) {
sender := &TCPSender{}

bytesSent, err := sender.Send("198.51.100.5", []byte("test")) // see RFC 5737 for more info about this IP address

assert.Error(t, err)
assert.Zero(t, bytesSent)
}
21 changes: 3 additions & 18 deletions xnet/writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"net"
"reflect"
"sort"
"strconv"
"time"

"github.com/hashicorp/consul/api"
Expand Down Expand Up @@ -77,7 +76,7 @@ func (r *roundRobinWriter) write(payload []byte) (int, error) {
}

// UDPSender is a Sender implementation that can write payload to the network
// Address and reuses single system socket. It uses UDP packets to send data.
// address and reuses single system socket. It uses UDP packets to send data.
type UDPSender struct {
conn *net.UDPConn
}
Expand All @@ -93,12 +92,12 @@ func (s *UDPSender) Send(addr Address, payload []byte) (int, error) {
s.conn = conn
}

udpAddr, err := addressToUDP(addr)
udpAddr, err := net.ResolveUDPAddr("udp", string(addr))
if err != nil {
return 0, fmt.Errorf("invalid address %s: %s", addr, err)
}

n, err := s.conn.WriteTo(payload, &udpAddr)
n, err := s.conn.WriteTo(payload, udpAddr)
if err != nil {
return 0, fmt.Errorf("could not sent payload to %s: %s", addr, err)
}
Expand All @@ -115,20 +114,6 @@ func (s *UDPSender) Release() error {
return err
}

func addressToUDP(addr Address) (net.UDPAddr, error) {
host, p, err := net.SplitHostPort(string(addr))
if err != nil {
return net.UDPAddr{}, err
}

port, err := strconv.Atoi(p)
if err != nil {
return net.UDPAddr{}, err
}

return net.UDPAddr{IP: net.ParseIP(host), Port: port}, nil
}

// DiscoveryServiceInstanceProvider returns InstanceProvider that is updated with
// list of instances in interval
func DiscoveryServiceInstanceProvider(serviceName string, interval time.Duration, client DiscoveryServiceClient) InstanceProvider {
Expand Down
51 changes: 18 additions & 33 deletions xnet/writer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package xnet
import (
"fmt"
"net"
"strconv"
"testing"
"time"

Expand All @@ -15,7 +16,6 @@ import (

const (
testPayload = "test"
loopback = "127.0.0.1"
)

func TestIntegrationWithConsulRoundRobinAndNetworkSend(t *testing.T) {
Expand All @@ -26,16 +26,19 @@ func TestIntegrationWithConsulRoundRobinAndNetworkSend(t *testing.T) {
defer stopConsul(server)

// create listener acting as a service
port, result, err := udpServer()
addr, result, err := udpServer()
require.NoError(t, err)
host, portString, err := net.SplitHostPort(string(addr))
require.NoError(t, err)
port, _ := strconv.Atoi(portString)

// register service in consul
agent := consulApiClient.Agent()
err = agent.ServiceRegister(&api.AgentServiceRegistration{
ID: "1",
Name: "service-name",
Port: port,
Address: loopback,
Address: host,
})
require.NoError(t, err)

Expand All @@ -57,50 +60,36 @@ func TestIntegrationWithConsulRoundRobinAndNetworkSend(t *testing.T) {
assert.Equal(t, testPayload, <-result)
}

func TestNetworkSendShouldReturnErrorWhenConnectionUnavailable(t *testing.T) {
func TestUDPNetworkSendShouldReturnErrorWhenConnectionUnavailable(t *testing.T) {
sender := &UDPSender{}

bytesSent, err := sender.Send(loopback, []byte("test"))
bytesSent, err := sender.Send("198.51.100.5", []byte("test")) // see RFC 5737 for more info about this IP address

assert.Error(t, err)
assert.Zero(t, bytesSent)
}

func TestNetworkSendShouldReturnNumberOfSentBytes(t *testing.T) {
port, result, err := udpServer()
require.NoError(t, err)

sender := &UDPSender{}

bytesSent, err := sender.Send(localhost(port), []byte(testPayload))

assert.NoError(t, err)
assert.Equal(t, len(testPayload), bytesSent)
assert.Equal(t, testPayload, <-result)
}

func TestUDPSenderWithSharedConnShouldReturnNumberOfSentBytes(t *testing.T) {
port, result, err := udpServer()
func TestUDPNetworkSendShouldReturnNumberOfSentBytes(t *testing.T) {
addr, result, err := udpServer()
require.NoError(t, err)

sender := &UDPSender{}

bytesSent, err := sender.Send(localhost(port), []byte(testPayload))
bytesSent, err := sender.Send(addr, []byte(testPayload))

assert.NoError(t, err)
assert.Equal(t, len(testPayload), bytesSent)
assert.Equal(t, testPayload, <-result)
}

func udpServer() (int, <-chan string, error) {
udpAddr := net.UDPAddr{}
conn, err := net.ListenUDP("udp", &udpAddr)
func udpServer() (Address, <-chan string, error) {
udpAddr := net.UDPAddr{
IP: net.IPv4(127, 0, 0, 1),
}
conn, err := net.ListenUDP("udp4", &udpAddr)
if err != nil {
return 0, nil, err
return "", nil, err
}

udpAddr, _ = addressToUDP(Address(conn.LocalAddr().String()))

result := make(chan string)

go func() {
Expand All @@ -115,7 +104,7 @@ func udpServer() (int, <-chan string, error) {
result <- string(buf[0:n])
}()

return udpAddr.Port, result, nil
return Address(conn.LocalAddr().String()), result, nil
}

func TestDiscoveryServiceInstanceProviderShouldPeriodicallyUpdatesInstances(t *testing.T) {
Expand Down Expand Up @@ -325,10 +314,6 @@ func createTestConsulServer(t *testing.T) (config *api.Config, server *testutil.
return config, server
}

func localhost(port int) Address {
return Address(fmt.Sprintf("%s:%d", loopback, port))
}

type MockSender struct {
mock.Mock
}
Expand Down
37 changes: 37 additions & 0 deletions xnet/xnettest/server.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package xnettest

import (
"net"
)

// LoopbackServer creates a new network Listener that is binded to the loopback
// interface and can be used to test tcp/udp connections. Listener must be
// closed at the end of the tests to release system resources. It returns
// configured listener and the channel to which it will send received data.
func LoopbackServer(network string) (net.Listener, <-chan []byte, error) {
listener, err := net.Listen(network, "127.0.0.1:0")
if err != nil {
return nil, nil, err
}
results := make(chan []byte)
go func() {
for {
conn, err := listener.Accept()
if err != nil {
return // if we are unable to accept connections listener is probably closed
}
go func() {
for {
buf := make([]byte, 1024)
n, err := conn.Read(buf)
if err != nil {
_ = conn.Close()
return
}
results <- buf[0:n]
}
}()
}
}()
return listener, results, nil
}
Loading

0 comments on commit 4c668fa

Please sign in to comment.