Skip to content

Commit

Permalink
stream: fix flatMap concurrency
Browse files Browse the repository at this point in the history
Co-Authored-By: Benjamin Gruenbaum <benjamingr@gmail.com>
  • Loading branch information
MoLow and benjamingr committed May 3, 2024
1 parent b876e00 commit 948b245
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 38 deletions.
92 changes: 56 additions & 36 deletions lib/internal/streams/operators.js
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@ const {
PromisePrototypeThen,
PromiseReject,
PromiseResolve,
SafeSet,
Symbol,
SymbolAsyncIterator,
SymbolIterator,
} = primordials;

const { AbortController, AbortSignal } = require('internal/abort_controller');
Expand Down Expand Up @@ -39,6 +42,7 @@ const { isWritable, isNodeStream } = require('internal/streams/utils');

const kEmpty = Symbol('kEmpty');
const kEof = Symbol('kEof');
const kFlatMap = Symbol('kFlatMap');

function compose(stream, options) {
if (options != null) {
Expand Down Expand Up @@ -92,11 +96,20 @@ function map(fn, options) {

highWaterMark += concurrency;

const flatMap = options?.[kFlatMap] != null;

return async function* map() {
const signal = AbortSignal.any([options?.signal].filter(Boolean));
const stream = this;
const queue = [];
const signalOpt = { signal };
const baseIterator = (async function* baseIterator() {
for await (const value of stream) {
// wrap in an object to avoid awaitng if result is a promise
yield { result: fn(value, signalOpt) };
}
})();
const iterators = new SafeSet([baseIterator]);

let next;
let resume;
Expand Down Expand Up @@ -125,45 +138,54 @@ function map(fn, options) {
}
}

function addIterator(result) {
if (result && (result[SymbolAsyncIterator] || result[SymbolIterator])) {
const iterator = result[SymbolAsyncIterator] ? result[SymbolAsyncIterator]() : result[SymbolIterator]();
iterators.add(iterator);
return kEmpty;
}
return result;
}

async function pump() {
try {
for await (let val of stream) {
if (done) {
return;
}

if (signal.aborted) {
throw new AbortError();
}

try {
val = fn(val, signalOpt);

if (val === kEmpty) {
continue;
while (iterators.size > 0) {
for (const iterator of iterators) {
if (done) {
return;
}

val = PromiseResolve(val);
} catch (err) {
val = PromiseReject(err);
}

cnt += 1;
if (signal.aborted) {
throw new AbortError();
}
let val = PromisePrototypeThen(PromiseResolve(iterator.next()), ({ value, done }) => {
if (done) {
iterators.delete(iterator);
return kEmpty;
}
return iterator === baseIterator ? value.result : value;
});

PromisePrototypeThen(val, afterItemProcessed, onCatch);
if (flatMap && baseIterator === iterator) {
val = PromisePrototypeThen(val, addIterator);
}
PromisePrototypeThen(val, afterItemProcessed, onCatch);
cnt += 1;
queue.push(val);

queue.push(val);
if (next) {
next();
next = null;
}
if (next) {
next();
next = null;
}

if (!done && (queue.length >= highWaterMark || cnt >= concurrency)) {
await new Promise((resolve) => {
resume = resolve;
});
if (!done && (queue.length >= highWaterMark || cnt >= concurrency)) {
await new Promise((resolve) => {
resume = resolve;
});
}
}
}

queue.push(kEof);
} catch (err) {
const val = PromiseReject(err);
Expand Down Expand Up @@ -343,12 +365,10 @@ async function toArray(options) {
}

function flatMap(fn, options) {
const values = map.call(this, fn, options);
return async function* flatMap() {
for await (const val of values) {
yield* val;
}
}.call(this);
if (options != null) {
validateObject(options, 'options');
}
return map.call(this, fn, { ...options, [kFlatMap]: true });
}

function toIntegerOrInfinity(number) {
Expand Down
4 changes: 2 additions & 2 deletions test/parallel/test-stream-flatMap.js
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,9 @@ function oneTo5() {
{
// Concurrency + AbortSignal
const ac = new AbortController();
const stream = oneTo5().flatMap(common.mustNotCall(async (_, { signal }) => {
const stream = oneTo5().flatMap(common.mustCall(async (_, { signal }) => {
await setTimeout(100, { signal });
}), { signal: ac.signal, concurrency: 2 });
}, 2), { signal: ac.signal, concurrency: 2 });
// pump
assert.rejects(async () => {
for await (const item of stream) {
Expand Down

0 comments on commit 948b245

Please sign in to comment.