diff --git a/pool/node.go b/pool/node.go index ab3dabe..003b8f6 100644 --- a/pool/node.go +++ b/pool/node.go @@ -53,6 +53,7 @@ type ( localWorkers sync.Map // workers created by this node workerStreams sync.Map // worker streams indexed by ID + workerAckStreams sync.Map // streams used to send acks, indexed by target node ID pendingJobChannels sync.Map // channels used to send DispatchJob results, nil if event is requeued pendingEvents sync.Map // pending events indexed by sender and event IDs @@ -218,6 +219,7 @@ func AddNode(ctx context.Context, poolName string, rdb *redis.Client, opts ...No shutdownMap: wsm, tickerMap: tm, workerStreams: sync.Map{}, + workerAckStreams: sync.Map{}, pendingJobChannels: sync.Map{}, pendingEvents: sync.Map{}, poolStream: poolStream, @@ -682,7 +684,7 @@ func (node *Node) ackWorkerEvent(ctx context.Context, ev *streaming.Event) { // dispatched the job. if pending.EventName == evStartJob { _, nodeID := unmarshalJobKeyAndNodeID(pending.Payload) - stream, err := streaming.NewStream(nodeStreamName(node.PoolName, nodeID), node.rdb, soptions.WithStreamLogger(node.logger)) + stream, err := node.getOrCreateWorkerAckStream(ctx, nodeID) if err != nil { node.logger.Error(fmt.Errorf("ackWorkerEvent: failed to create node event stream %q: %w", nodeStreamName(node.PoolName, nodeID), err)) return @@ -1111,6 +1113,29 @@ func (node *Node) workerStream(_ context.Context, id string) (*streaming.Stream, return val.(*streaming.Stream), nil } +// getOrCreateWorkerAckStream returns the ack stream for the node with the given ID, creating it and caching it in workerAckStreams on first use. +func (node *Node) getOrCreateWorkerAckStream(ctx context.Context, nodeID string) (*streaming.Stream, error) { + if val, ok := node.workerAckStreams.Load(nodeID); ok { + return val.(*streaming.Stream), nil + } + + stream, err := streaming.NewStream( + nodeStreamName(node.PoolName, nodeID), + node.rdb, + soptions.WithStreamLogger(node.logger), + ) + if err != nil { + return nil, err + } + + actual, loaded := node.workerAckStreams.LoadOrStore(nodeID, stream) + if 
loaded { + // Lost the race: another goroutine stored a stream for this node first, so return that one and drop ours + return actual.(*streaming.Stream), nil + } + return stream, nil +} + // cleanup removes the worker from all pool maps. func (node *Node) cleanupWorker(ctx context.Context, id string) { if _, err := node.workerMap.Delete(ctx, id); err != nil { diff --git a/pool/node_test.go b/pool/node_test.go index 8c0826a..a099505 100644 --- a/pool/node_test.go +++ b/pool/node_test.go @@ -834,6 +834,44 @@ func TestShutdownStopsAllJobs(t *testing.T) { assert.Empty(t, worker2.Jobs(), "Worker2 should have no remaining jobs") } +func TestWorkerAckStreams(t *testing.T) { + testName := strings.Replace(t.Name(), "/", "_", -1) + ctx := ptesting.NewTestContext(t) + rdb := ptesting.NewRedisClient(t) + node := newTestNode(t, ctx, rdb, testName) + defer ptesting.CleanupRedis(t, rdb, true, testName) + + // Create a worker and dispatch a job + worker := newTestWorker(t, ctx, node) + require.NoError(t, node.DispatchJob(ctx, testName, []byte("payload"))) + + // Wait until the worker reports exactly one job + require.Eventually(t, func() bool { + return len(worker.Jobs()) == 1 + }, max, delay) + + // Verify repeated lookups for the same node ID return the same cached stream instance + stream1, err := node.getOrCreateWorkerAckStream(ctx, node.ID) + require.NoError(t, err) + stream2, err := node.getOrCreateWorkerAckStream(ctx, node.ID) + require.NoError(t, err) + assert.Same(t, stream1, stream2, "Expected same stream instance to be returned") + + // Verify stream exists before shutdown + streamKey := "pulse:stream:" + nodeStreamName(testName, node.ID) + exists, err := rdb.Exists(ctx, streamKey).Result() + assert.NoError(t, err) + assert.Equal(t, int64(1), exists, "Expected stream to exist before shutdown") + + // Shutdown node + assert.NoError(t, node.Shutdown(ctx)) + + // Verify stream is destroyed in Redis + exists, err = rdb.Exists(ctx, streamKey).Result() + assert.NoError(t, err) + assert.Equal(t, int64(0), exists, "Expected stream to be destroyed 
after shutdown") +} + type mockAcker struct { XAckFunc func(ctx context.Context, streamKey, sinkName string, ids ...string) *redis.IntCmd } diff --git a/pool/worker.go b/pool/worker.go index a3eef85..5666fc6 100644 --- a/pool/worker.go +++ b/pool/worker.go @@ -291,22 +291,18 @@ func (w *Worker) notify(_ context.Context, key string, payload []byte) error { // ackPoolEvent acknowledges the pool event that originated from the node with // the given ID. func (w *Worker) ackPoolEvent(ctx context.Context, nodeID, eventID string, ackerr error) { - stream, ok := w.nodeStreams.Load(nodeID) - if !ok { - var err error - stream, err = streaming.NewStream(nodeStreamName(w.node.PoolName, nodeID), w.node.rdb, soptions.WithStreamLogger(w.logger)) - if err != nil { - w.logger.Error(fmt.Errorf("failed to create stream for node %q: %w", nodeID, err)) - return - } - w.nodeStreams.Store(nodeID, stream) + stream, err := w.node.getOrCreateWorkerAckStream(ctx, nodeID) + if err != nil { + w.logger.Error(fmt.Errorf("failed to get ack stream for node %q: %w", nodeID, err)) + return } + var msg string if ackerr != nil { msg = ackerr.Error() } ack := &ack{EventID: eventID, Error: msg} - if _, err := stream.(*streaming.Stream).Add(ctx, evAck, marshalEnvelope(w.ID, marshalAck(ack))); err != nil { + if _, err := stream.Add(ctx, evAck, marshalEnvelope(w.ID, marshalAck(ack))); err != nil { w.logger.Error(fmt.Errorf("failed to ack event %q from node %q: %w", eventID, nodeID, err)) } }