diff --git a/internal/listenmap/internal/listener/listen.go b/internal/listenmap/internal/listener/listen.go index 083f907..71b8411 100644 --- a/internal/listenmap/internal/listener/listen.go +++ b/internal/listenmap/internal/listener/listen.go @@ -45,7 +45,7 @@ func New(p int) *Listener { var i int pause := make(chan struct{}) unpause := func() { pause <- struct{}{} } - wait, cancel := func() {}, func() {} + cancelAndWait := func() {} for { select { case r := <-l.lockL: @@ -58,7 +58,7 @@ func New(p int) *Listener { case r := <-l.addL: i += 1 if i == 1 { - cancel, wait, err = l.run(r.getCb, r.workers, r.buffer) + cancelAndWait, err = l.run(r.getCb, r.workers, r.buffer) } r.err <- err err = nil @@ -69,9 +69,7 @@ func New(p int) *Listener { } if i == 0 { // shut her down - _ = l.conn.Close() - cancel() - wait() + cancelAndWait() } } } @@ -116,33 +114,31 @@ type runParams struct { err chan<- error } -func (l *Listener) run(getCb func(net.IP, uint16) func(context.Context, *ping.Ping), workers int, buffer int) (cancel func(), wait func(), err error) { +func (l *Listener) run(getCb func(net.IP, uint16) func(context.Context, *ping.Ping), workers int, buffer int) (cancelAndWait func(), err error) { + cancelAndWait = func() {} l.conn, err = icmp.ListenPacket(l.Props.Network, l.Props.Src) if err != nil { - return func() {}, func() {}, err + return cancelAndWait, err } err = setPacketCon(l.conn) if err != nil { _ = l.conn.Close() - return func() {}, func() {}, err + return cancelAndWait, err } // this is not inheriting a context. Each ip has a context, which will decrement the waitgroup when it's done. - ctx, cancel := context.WithCancel(context.Background()) + wCtx, wCancel := context.WithCancel(context.Background()) // start workers - wWg := sync.WaitGroup{} - proc := getProcFunc(ctx, workers, buffer, &wWg) + proc, wWait := getProcFunc(wCtx, workers, buffer) + + pWg := sync.WaitGroup{} - wWg.Add(1) + ctx, cancel := context.WithCancel(context.Background()) + rWg := sync.WaitGroup{} + rWg.Add(1) go func() { - defer wWg.Done() for { - select { - case <-ctx.Done(): - return - default: - } r := &messages.RecvMsg{ Payload: make([]byte, l.Props.ExpectedLen), } @@ -150,10 +146,27 @@ func (l *Listener) run(getCb func(net.IP, uint16) func(context.Context, *ping.Pi if err != nil { continue } + select { + case <-ctx.Done(): + rWg.Done() + return + default: + } + pWg.Add(1) r.Recieved = time.Now() - proc(&procMsg{ctx, r, getCb}) + proc(ctx, r, getCb, pWg.Done) } }() - return cancel, wWg.Wait, nil + cancelAndWait = func() { + cancel() // stop conection listener + // this is not unblocking readPacket, why? + for err := l.conn.Close(); err != nil; err = l.conn.Close() { + } + rWg.Wait() // wait for connection listener to stop + pWg.Wait() // wait for packets to be distributed + wCancel() // stop workers + wWait() // wait for workers to stop + } + return cancelAndWait, nil } diff --git a/internal/listenmap/internal/listener/workers.go b/internal/listenmap/internal/listener/workers.go index e0df1fb..41bd7d8 100644 --- a/internal/listenmap/internal/listener/workers.go +++ b/internal/listenmap/internal/listener/workers.go @@ -9,50 +9,92 @@ import ( "github.com/TrilliumIT/go-multiping/ping" ) -func getProcFunc(ctx context.Context, workers, buffer int, wWg *sync.WaitGroup) func(*procMsg) { +type procFunc func( + ctx context.Context, + r *messages.RecvMsg, + getCb func(net.IP, uint16) func(context.Context, *ping.Ping), + done func(), +) + +func getProcFunc(ctx context.Context, workers, buffer int) (procFunc, func()) { // start workers - proc := processMessage + if workers < -1 { + return func( + ctx context.Context, + r *messages.RecvMsg, + getCb func(net.IP, uint16) func(context.Context, *ping.Ping), + done func(), + ) { + processMessage(ctx, r, getCb) + done() + }, func() {} + } + if workers == 0 { - proc = func(p *procMsg) { - wWg.Add(1) + return func( + ctx context.Context, + r *messages.RecvMsg, + getCb func(net.IP, uint16) func(context.Context, *ping.Ping), + done func(), + ) { go func() { - processMessage(p) - wWg.Done() + processMessage(ctx, r, getCb) + done() }() - } + }, func() {} } - if workers == -1 || workers > 0 { - wCh := make(chan *procMsg, buffer) - if workers == -1 { - proc = func(p *procMsg) { - select { - case wCh <- p: - return - default: - } - wWg.Add(1) - go func() { - runWorker(ctx, wCh) - wWg.Done() - }() - wCh <- p + wCh := make(chan *procMsg, buffer) + wWg := sync.WaitGroup{} + if workers == -1 { + return func( + ctx context.Context, + r *messages.RecvMsg, + getCb func(net.IP, uint16) func(context.Context, *ping.Ping), + done func(), + ) { + select { + case wCh <- &procMsg{ctx, r, getCb, done}: + return + case <-ctx.Done(): + return + default: } - } - if workers > 0 { - proc = func(p *procMsg) { - wCh <- p - } - } - for w := 0; w < workers; w++ { wWg.Add(1) go func() { runWorker(ctx, wCh) wWg.Done() }() - } + select { + case wCh <- &procMsg{ctx, r, getCb, done}: + return + case <-ctx.Done(): + return + } + }, wWg.Wait } - return proc + + for w := 0; w < workers; w++ { + wWg.Add(1) + go func() { + runWorker(ctx, wCh) + wWg.Done() + }() + } + + return func( + ctx context.Context, + r *messages.RecvMsg, + getCb func(net.IP, uint16) func(context.Context, *ping.Ping), + done func(), + ) { + select { + case wCh <- &procMsg{ctx, r, getCb, done}: + return + case <-ctx.Done(): + return + } + }, wWg.Wait } func runWorker(ctx context.Context, wCh <-chan *procMsg) { @@ -61,7 +103,8 @@ func runWorker(ctx context.Context, wCh <-chan *procMsg) { case <-ctx.Done(): return case p := <-wCh: - processMessage(p) + processMessage(p.ctx, p.r, p.getCb) + p.done() } } } @@ -70,18 +113,23 @@ type procMsg struct { ctx context.Context r *messages.RecvMsg getCb func(net.IP, uint16) func(context.Context, *ping.Ping) + done func() } -func processMessage(pm *procMsg) { - p := pm.r.ToPing() +func processMessage( + ctx context.Context, + r *messages.RecvMsg, + getCb func(net.IP, uint16) func(context.Context, *ping.Ping), +) { + p := r.ToPing() if p == nil { return } - cb := pm.getCb(p.Dst, uint16(p.ID)) + cb := getCb(p.Dst, uint16(p.ID)) if cb == nil { return } - cb(pm.ctx, p) + cb(ctx, p) } diff --git a/pinger/errors_test.go b/pinger/errors_test.go index eb48120..5868d33 100644 --- a/pinger/errors_test.go +++ b/pinger/errors_test.go @@ -12,6 +12,8 @@ import ( func TestDupListener(t *testing.T) { assert := assert.New(t) - assert.NoError(DefaultConn().lm.Add(context.Background(), net.ParseIP("127.0.0.1"), 5, nil)) + ctx, cancel := context.WithCancel(context.Background()) + assert.NoError(DefaultConn().lm.Add(ctx, net.ParseIP("127.0.0.1"), 5, nil)) assert.Equal(listenmap.ErrAlreadyExists, DefaultConn().lm.Add(context.Background(), net.ParseIP("127.0.0.1"), 5, nil)) + cancel() } diff --git a/pinger/success_test.go b/pinger/success_test.go index 0b052e3..d3a97e4 100644 --- a/pinger/success_test.go +++ b/pinger/success_test.go @@ -21,8 +21,8 @@ func TestMain(m *testing.M) { go func() { <-c go func() { - time.Sleep(5 * time.Second) - err := pprof.Lookup("goroutine").WriteTo(os.Stdout, 1) + err := pprof.Lookup("goroutine").WriteTo(os.Stderr, 1) + time.Sleep(1 * time.Second) panic(err) }() }() @@ -93,8 +93,8 @@ func testHosts(t *testing.T, cf PingConf) { hosts := []string{ "127.0.0.1", "127.0.0.1", - "::1", - "::1", + //"::1", + //"::1", "127.0.0.2", "127.0.0.3", }