diff --git a/syncer/syncer.go b/syncer/syncer.go index f3d4c14..ec1597b 100644 --- a/syncer/syncer.go +++ b/syncer/syncer.go @@ -1,6 +1,7 @@ package syncer import ( + "context" "errors" "io" "log" @@ -540,26 +541,44 @@ func (s *Syncer) relayV2TransactionSet(txns []types.V2Transaction, origin *gatew } } -func (s *Syncer) acceptLoop() error { - allowConnect := func(peer string) error { - s.mu.Lock() - defer s.mu.Unlock() - if s.pm.Banned(peer) { - return errors.New("banned") - } - var in int - for _, p := range s.peers { - if p.Inbound { - in++ - } +func (s *Syncer) allowConnect(peer string, inbound bool) error { + s.mu.Lock() + defer s.mu.Unlock() + if s.l == nil { + return errors.New("syncer is shutting down") + } + if s.pm.Banned(peer) { + return errors.New("banned") + } + var in, out int + for _, p := range s.peers { + if p.Inbound { + in++ + } else { + out++ } - // TODO: subnet-based limits - if in >= s.config.MaxInboundPeers { - return errors.New("too many inbound peers") + } + // TODO: subnet-based limits + if inbound && in >= s.config.MaxInboundPeers { + return errors.New("too many inbound peers") + } else if !inbound && out >= s.config.MaxOutboundPeers { + return errors.New("too many outbound peers") + } + return nil +} + +func (s *Syncer) alreadyConnected(peer *gateway.Peer) bool { + s.mu.Lock() + defer s.mu.Unlock() + for _, p := range s.peers { + if p.UniqueID == peer.UniqueID { + return true } - return nil } + return false +} +func (s *Syncer) acceptLoop() error { for { conn, err := s.l.Accept() if err != nil { @@ -567,10 +586,12 @@ func (s *Syncer) acceptLoop() error { } go func() { defer conn.Close() - if err := allowConnect(conn.RemoteAddr().String()); err != nil { + if err := s.allowConnect(conn.RemoteAddr().String(), true); err != nil { s.log.Printf("rejected inbound connection from %v: %v", conn.RemoteAddr(), err) } else if p, err := gateway.Accept(conn, s.header); err != nil { s.log.Printf("failed to accept inbound connection from %v: %v", conn.RemoteAddr(), err) + } else if s.alreadyConnected(p) { + s.log.Printf("rejected inbound connection from %v: already connected", conn.RemoteAddr()) } else { s.runPeer(p) } @@ -635,6 +656,11 @@ func (s *Syncer) peerLoop(closeChan <-chan struct{}) error { return false } } + closing := func() bool { + s.mu.Lock() + defer s.mu.Unlock() + return s.l == nil + } for fst := true; fst || sleep(); fst = false { if numOutbound() >= s.config.MaxOutboundPeers { continue @@ -645,7 +671,7 @@ func (s *Syncer) peerLoop(closeChan <-chan struct{}) error { continue } for _, p := range candidates { - if numOutbound() >= s.config.MaxOutboundPeers { + if numOutbound() >= s.config.MaxOutboundPeers || closing() { break } if _, err := s.Connect(p); err == nil { @@ -736,7 +762,7 @@ func (s *Syncer) syncLoop(closeChan <-chan struct{}) error { if err != nil { s.log.Printf("syncing with %v failed after %v blocks: %v", p, totalBlocks, err) } else if newTip := s.cm.Tip(); newTip != oldTip { - s.log.Printf("finished syncing %v blocks with %v, tip now %v", p, totalBlocks, newTip) + s.log.Printf("finished syncing %v blocks with %v, tip now %v", totalBlocks, p, newTip) } else { s.log.Printf("finished syncing with %v, tip unchanged", p) } @@ -761,6 +787,7 @@ func (s *Syncer) Run() error { close(closeChan) s.l.Close() s.mu.Lock() + s.l = nil for addr, p := range s.peers { p.Close() delete(s.peers, addr) @@ -776,35 +803,39 @@ func (s *Syncer) Run() error { // Connect forms an outbound connection to a peer. func (s *Syncer) Connect(addr string) (*gateway.Peer, error) { - allowConnect := func(peer string) error { - s.mu.Lock() - defer s.mu.Unlock() - if s.pm.Banned(peer) { - return errors.New("banned") - } - var out int - for _, p := range s.peers { - if !p.Inbound { - out++ - } - } - // TODO: subnet-based limits - if out >= s.config.MaxOutboundPeers { - return errors.New("too many outbound peers") - } - return nil - } - - if err := allowConnect(addr); err != nil { + if err := s.allowConnect(addr, false); err != nil { return nil, err } - conn, err := net.DialTimeout("tcp", addr, s.config.ConnectTimeout) + ctx, cancel := context.WithTimeout(context.Background(), s.config.ConnectTimeout) + defer cancel() + // slightly gross polling hack so that we shutdown quickly + go func() { + for { + select { + case <-ctx.Done(): + return + case <-time.After(100 * time.Millisecond): + s.mu.Lock() + if s.l == nil { + cancel() + } + s.mu.Unlock() + } + } + }() + conn, err := (&net.Dialer{}).DialContext(ctx, "tcp", addr) if err != nil { return nil, err } + conn.SetDeadline(time.Now().Add(s.config.ConnectTimeout)) + defer conn.SetDeadline(time.Time{}) p, err := gateway.Dial(conn, s.header) if err != nil { + conn.Close() return nil, err + } else if s.alreadyConnected(p) { + conn.Close() + return nil, errors.New("already connected") } go s.runPeer(p)