diff --git a/core/src/main/java/org/apache/datafusion/CancellationToken.java b/core/src/main/java/org/apache/datafusion/CancellationToken.java new file mode 100644 index 0000000..dd02531 --- /dev/null +++ b/core/src/main/java/org/apache/datafusion/CancellationToken.java @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.datafusion; + +import java.util.concurrent.CancellationException; +import java.util.concurrent.atomic.AtomicLong; + +/** + * A handle that signals an in-flight {@link + * DataFrame#collect(org.apache.arrow.memory.BufferAllocator, CancellationToken)} or {@link + * DataFrame#executeStream(org.apache.arrow.memory.BufferAllocator, CancellationToken)} to abort. + * + *

Allocate via {@link SessionContext#newCancellationToken()}, pass to the desired call from one + * thread, and invoke {@link #cancel()} from another. The running query aborts at its next + * cooperative poll point. The exception type depends on the call site: {@code collect(..., token)} + * and pre-stream cancellation in {@code executeStream(..., token)} surface {@link + * CancellationException}; mid-stream cancellation in {@code executeStream} surfaces from {@link + * org.apache.arrow.vector.ipc.ArrowReader#loadNextBatch} as a {@link java.io.IOException} whose + * message contains {@code "query cancelled"} (the Arrow C-data wrapper hides the typed signal). See + * the {@code executeStream} Javadoc for the full contract. + * + *

The token is not bound to any particular DataFrame; the same token may be passed to several + * concurrent {@code collect} / {@code executeStream} calls, and {@link #cancel()} fires all of them + * at once. Once cancelled, {@link #isCancelled()} returns {@code true} permanently — to "reset", + * allocate a fresh token. This matches the underlying {@code tokio_util::sync::CancellationToken} + * contract. + * + *

Instances are safe to call {@link #cancel()} / {@link #isCancelled()} on from any thread, and + * must be {@link #close() closed} to release the native handle. {@code close()} is idempotent and a + * no-op once invoked. + */ +public final class CancellationToken implements AutoCloseable { + static { + NativeLibraryLoader.loadLibrary(); + } + + // Atomic so concurrent close + cancel/isCancelled/handle-pass-to-JNI cannot + // double-free or pass a stale handle to the native side. Reads observe either + // a live registry id or the post-close 0 sentinel; the close path uses + // getAndSet so only one thread can issue closeToken. + private final AtomicLong nativeHandle; + + CancellationToken(long nativeHandle) { + if (nativeHandle == 0) { + throw new IllegalArgumentException("CancellationToken native handle is null"); + } + this.nativeHandle = new AtomicLong(nativeHandle); + } + + /** + * The internal native handle, or {@code 0} if this token is closed. Package-private so {@link + * DataFrame} can pass it across JNI. The native side's registry lookup gracefully rejects a + * closed handle, so a race between {@code handle()} and {@link #close()} is bounded to a clean + * "closed" error rather than a use-after-free. + */ + long handle() { + return nativeHandle.get(); + } + + /** + * Signal the token. Any {@code collect} or {@code executeStream} call that received this token + * aborts at its next poll point. The thrown exception type depends on whether the cancel reaches + * the call before or after the JNI call returns — see the class-level Javadoc. Idempotent: + * subsequent calls are no-ops. + * + * @throws IllegalStateException if this token is closed. + */ + public void cancel() { + long h = nativeHandle.get(); + if (h == 0) { + throw new IllegalStateException("CancellationToken is closed"); + } + cancelToken(h); + } + + /** + * @return {@code true} if {@link #cancel()} has been invoked on this token, {@code false} + * otherwise. Non-blocking. + * @throws IllegalStateException if this token is closed. + */ + public boolean isCancelled() { + long h = nativeHandle.get(); + if (h == 0) { + throw new IllegalStateException("CancellationToken is closed"); + } + return isCancelledToken(h); + } + + /** + * Release the native handle. Idempotent. After {@code close()}, {@link #cancel()} and {@link + * #isCancelled()} throw {@link IllegalStateException}. Closing a token that already fired is + * harmless and does not cancel anything else; queries that already received the cancel signal + * remain aborted. + */ + @Override + public void close() { + long h = nativeHandle.getAndSet(0L); + if (h != 0) { + closeToken(h); + } + } + + private static native void cancelToken(long handle); + + private static native boolean isCancelledToken(long handle); + + private static native void closeToken(long handle); +} diff --git a/core/src/main/java/org/apache/datafusion/DataFrame.java b/core/src/main/java/org/apache/datafusion/DataFrame.java index 86dd523..2ee2faf 100644 --- a/core/src/main/java/org/apache/datafusion/DataFrame.java +++ b/core/src/main/java/org/apache/datafusion/DataFrame.java @@ -61,14 +61,29 @@ public final class DataFrame implements AutoCloseable { * {@link #executeStream(BufferAllocator)} for analytics-scale queries. */ public ArrowReader collect(BufferAllocator allocator) { + return collect(allocator, null); + } + + /** + * Execute the plan with cooperative cancellation. Identical to {@link #collect(BufferAllocator)} + * except that {@link CancellationToken#cancel()} on {@code token} from another thread aborts the + * call with a {@link java.util.concurrent.CancellationException} at the next poll point. + * + *

{@code token} may be {@code null}, in which case this overload behaves exactly like the + * single-argument form. + * + * @throws java.util.concurrent.CancellationException if the token is fired during the call. + */ + public ArrowReader collect(BufferAllocator allocator, CancellationToken token) { if (nativeHandle == 0) { throw new IllegalStateException("DataFrame is closed or already collected"); } + long tokenHandle = resolveTokenHandle(token); ArrowArrayStream stream = ArrowArrayStream.allocateNew(allocator); long handle = nativeHandle; nativeHandle = 0; try { - collectDataFrame(handle, stream.memoryAddress()); + collectDataFrame(handle, stream.memoryAddress(), tokenHandle); return Data.importArrayStream(allocator, stream); } catch (Throwable e) { stream.close(); @@ -91,14 +106,45 @@ public ArrowReader collect(BufferAllocator allocator) { * use this method. */ public ArrowReader executeStream(BufferAllocator allocator) { + return executeStream(allocator, null); + } + + /** + * Execute the plan as a streaming reader with cooperative cancellation. Identical to {@link + * #executeStream(BufferAllocator)} except that the returned reader holds the supplied {@code + * token} for its full lifetime: firing the token from another thread aborts the next {@link + * ArrowReader#loadNextBatch} call. + * + *

{@code token} may be {@code null}, in which case this overload behaves exactly like the + * single-argument form. + * + *

Cancellation surface (read carefully). The exception type depends on when + * the cancel fires, because the Arrow C-data stream layer wraps any underlying error before it + * reaches Java: + * + *

+ * + * @throws java.util.concurrent.CancellationException if the token fires before the stream is + * established. + */ + public ArrowReader executeStream(BufferAllocator allocator, CancellationToken token) { if (nativeHandle == 0) { throw new IllegalStateException("DataFrame is closed or already collected"); } + long tokenHandle = resolveTokenHandle(token); ArrowArrayStream stream = ArrowArrayStream.allocateNew(allocator); long handle = nativeHandle; nativeHandle = 0; try { - executeStreamDataFrame(handle, stream.memoryAddress()); + executeStreamDataFrame(handle, stream.memoryAddress(), tokenHandle); return Data.importArrayStream(allocator, stream); } catch (Throwable e) { stream.close(); @@ -106,6 +152,24 @@ public ArrowReader executeStream(BufferAllocator allocator) { } } + /** + * A {@code null} token disables cancellation; a non-null but already-closed token is rejected + * with {@link IllegalStateException}, matching how {@link CancellationToken#cancel()} and {@link + * CancellationToken#isCancelled()} behave on a closed token. Without this check, premature {@code + * close()} on a token would silently fall back to an uncancellable call, which is hard to + * diagnose. + */ + private static long resolveTokenHandle(CancellationToken token) { + if (token == null) { + return 0L; + } + long handle = token.handle(); + if (handle == 0L) { + throw new IllegalStateException("CancellationToken is closed"); + } + return handle; + } + /** Execute the plan and return the number of rows. */ public long count() { if (nativeHandle == 0) { @@ -358,9 +422,10 @@ public void close() { } } - private static native void collectDataFrame(long handle, long ffiStreamAddr); + private static native void collectDataFrame(long handle, long ffiStreamAddr, long tokenHandle); - private static native void executeStreamDataFrame(long handle, long ffiStreamAddr); + private static native void executeStreamDataFrame( + long handle, long ffiStreamAddr, long tokenHandle); private static native void closeDataFrame(long handle); diff --git a/core/src/main/java/org/apache/datafusion/SessionContext.java b/core/src/main/java/org/apache/datafusion/SessionContext.java index 674341a..0cc6f30 100644 --- a/core/src/main/java/org/apache/datafusion/SessionContext.java +++ b/core/src/main/java/org/apache/datafusion/SessionContext.java @@ -426,6 +426,24 @@ public DataFrame readAvro(String path, AvroReadOptions options) { return new DataFrame(dfHandle); } + /** + * Allocate a fresh {@link CancellationToken}. Pass it to {@link + * DataFrame#collect(BufferAllocator, CancellationToken)} or {@link + * DataFrame#executeStream(BufferAllocator, CancellationToken)} to make a query cancellable; fire + * the token from any thread to abort the in-flight call. + * + *

The same token may be reused across multiple concurrent calls; firing it cancels them all. + * Once fired, a token stays cancelled — allocate a fresh token for the next query. + * + *

The token is independent of this {@link SessionContext}: closing the context does not + * implicitly close outstanding tokens, and a token outliving its session is harmless (it just has + * nothing left to cancel). Always close tokens in their own try-with-resources to release the + * native handle. + */ + public CancellationToken newCancellationToken() { + return new CancellationToken(createCancellationToken()); + } + /** * Register a Java-implemented scalar UDF. After registration, the function can be invoked by SQL * via the UDF's name or referenced in DataFusion plans deserialised with {@link #fromProto}. @@ -559,6 +577,8 @@ private static native long readJsonWithOptions( private static native void registerScalarUdf( long handle, String name, byte[] signatureSchemaBytes, byte volatility, ScalarFunction impl); + private static native long createCancellationToken(); + private static native void registerTableNative( long handle, String name, byte[] schemaIpcBytes, TableProvider provider); } diff --git a/core/src/test/java/org/apache/datafusion/CancellationTokenTest.java b/core/src/test/java/org/apache/datafusion/CancellationTokenTest.java new file mode 100644 index 0000000..7fe191c --- /dev/null +++ b/core/src/test/java/org/apache/datafusion/CancellationTokenTest.java @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.datafusion; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +class CancellationTokenTest { + + @Test + void freshTokenIsNotCancelled() { + try (SessionContext ctx = new SessionContext(); + CancellationToken token = ctx.newCancellationToken()) { + assertFalse(token.isCancelled()); + } + } + + @Test + void cancelMakesIsCancelledTrue() { + try (SessionContext ctx = new SessionContext(); + CancellationToken token = ctx.newCancellationToken()) { + token.cancel(); + assertTrue(token.isCancelled()); + } + } + + @Test + void cancelIsIdempotent() { + try (SessionContext ctx = new SessionContext(); + CancellationToken token = ctx.newCancellationToken()) { + token.cancel(); + token.cancel(); + token.cancel(); + assertTrue(token.isCancelled()); + } + } + + @Test + void closeIsIdempotent() { + try (SessionContext ctx = new SessionContext()) { + CancellationToken token = ctx.newCancellationToken(); + token.close(); + token.close(); + } + } + + @Test + void operationsAfterCloseThrow() { + try (SessionContext ctx = new SessionContext()) { + CancellationToken token = ctx.newCancellationToken(); + token.close(); + assertThrows(IllegalStateException.class, token::cancel); + assertThrows(IllegalStateException.class, token::isCancelled); + } + } + + @Test + void closeAfterCancelDoesNotThrow() { + try (SessionContext ctx = new SessionContext()) { + CancellationToken token = ctx.newCancellationToken(); + token.cancel(); + token.close(); + } + } + + @Test + void tokensFromSameSessionAreIndependent() { + try (SessionContext ctx = new SessionContext(); + CancellationToken a = ctx.newCancellationToken(); + CancellationToken b = ctx.newCancellationToken()) { + a.cancel(); + assertTrue(a.isCancelled()); + assertFalse(b.isCancelled()); + } + } + + @Test + @Timeout(value = 30, unit = TimeUnit.SECONDS) + void closeRacingWithCancelDoesNotCrashJvm() throws Exception { + // Stress the close()/cancel()/isCancelled() race that a per-token AtomicLong + // + native registry are designed to make safe. Without the atomic+registry + // pair, a thread could read a stale native pointer from a token that + // another thread has already closed and call into freed memory. + final int iterations = 200; + final int callers = 4; + ExecutorService pool = Executors.newCachedThreadPool(); + AtomicInteger illegalState = new AtomicInteger(); + AtomicInteger ok = new AtomicInteger(); + try (SessionContext ctx = new SessionContext()) { + for (int i = 0; i < iterations; i++) { + final CancellationToken token = ctx.newCancellationToken(); + CountDownLatch start = new CountDownLatch(1); + Runnable cancelOrCheck = + () -> { + try { + start.await(); + token.cancel(); + token.isCancelled(); + ok.incrementAndGet(); + } catch (IllegalStateException e) { + illegalState.incrementAndGet(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + }; + Runnable closer = + () -> { + try { + start.await(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + return; + } + token.close(); + }; + for (int c = 0; c < callers; c++) { + pool.submit(cancelOrCheck); + } + pool.submit(closer); + start.countDown(); + } + } finally { + pool.shutdown(); + assertTrue(pool.awaitTermination(20, TimeUnit.SECONDS)); + } + // Survival is the actual assertion: no JVM crash, no native panic. Either + // the cancel/isCancelled won the race (ok) or close did (IllegalStateException). + assertTrue(ok.get() + illegalState.get() > 0); + } + + @Test + void tokenOutlivesItsSession() { + CancellationToken token; + try (SessionContext ctx = new SessionContext()) { + token = ctx.newCancellationToken(); + } + // Closing the session does not invalidate the token; cancel() and + // isCancelled() must continue to work, and close() must not panic. + assertFalse(token.isCancelled()); + token.cancel(); + assertTrue(token.isCancelled()); + token.close(); + } +} diff --git a/core/src/test/java/org/apache/datafusion/DataFrameCancellationTest.java b/core/src/test/java/org/apache/datafusion/DataFrameCancellationTest.java new file mode 100644 index 0000000..d5ec228 --- /dev/null +++ b/core/src/test/java/org/apache/datafusion/DataFrameCancellationTest.java @@ -0,0 +1,356 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.datafusion; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.List; +import java.util.concurrent.CancellationException; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.ipc.ArrowReader; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +class DataFrameCancellationTest { + + private static final ArrowType INT32 = new ArrowType.Int(32, true); + + /** + * Volatile UDF that sleeps {@code sleepMillis} per row and returns the input unchanged. Used to + * synthesise a query whose runtime is long enough for cancellation to land mid-flight without + * coupling tests to TPC-H or other large fixtures. {@link Volatility#VOLATILE} prevents the + * planner from constant-folding it away even with a literal input. + * + *

The latch is released by the first invocation, so the firing thread can wait until the query + * is genuinely in the UDF's hot path before calling cancel. + */ + static final class SleepingIdentity implements ScalarFunction { + private final long sleepMillis; + private final CountDownLatch firstInvocation; + + SleepingIdentity(long sleepMillis, CountDownLatch firstInvocation) { + this.sleepMillis = sleepMillis; + this.firstInvocation = firstInvocation; + } + + @Override + public String name() { + return "sleep_identity"; + } + + @Override + public List argTypes() { + return List.of(INT32); + } + + @Override + public ArrowType returnType() { + return INT32; + } + + @Override + public Volatility volatility() { + return Volatility.VOLATILE; + } + + @Override + public ColumnarValue evaluate(BufferAllocator allocator, ScalarFunctionArgs args) { + firstInvocation.countDown(); + try { + Thread.sleep(sleepMillis); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + IntVector in = (IntVector) args.args().get(0).vector(); + IntVector out = new IntVector("sleep_out", allocator); + int n = in.getValueCount(); + out.allocateNew(n); + for (int i = 0; i < n; i++) { + if (in.isNull(i)) { + out.setNull(i); + } else { + out.set(i, in.get(i)); + } + } + out.setValueCount(n); + return ColumnarValue.array((FieldVector) out); + } + } + + @Test + void preCancelledTokenAbortsCollect() throws Exception { + try (BufferAllocator allocator = new RootAllocator(); + SessionContext ctx = new SessionContext(); + CancellationToken token = ctx.newCancellationToken()) { + token.cancel(); + try (DataFrame df = ctx.sql("SELECT * FROM (VALUES (1), (2), (3)) AS t(x)")) { + assertThrows(CancellationException.class, () -> df.collect(allocator, token)); + } + } + } + + @Test + void preCancelledTokenAbortsExecuteStream() throws Exception { + try (BufferAllocator allocator = new RootAllocator(); + SessionContext ctx = new SessionContext(); + CancellationToken token = ctx.newCancellationToken()) { + token.cancel(); + try (DataFrame df = ctx.sql("SELECT * FROM (VALUES (1), (2), (3)) AS t(x)")) { + assertThrows(CancellationException.class, () -> df.executeStream(allocator, token)); + } + } + } + + @Test + @Timeout(value = 30, unit = TimeUnit.SECONDS) + void cancelMidCollectAborts() throws Exception { + CountDownLatch firstInvocation = new CountDownLatch(1); + ExecutorService pool = Executors.newSingleThreadExecutor(); + try (BufferAllocator allocator = new RootAllocator(); + SessionContext ctx = SessionContext.builder().batchSize(1).build(); + CancellationToken token = ctx.newCancellationToken()) { + ctx.registerUdf(new ScalarUdf(new SleepingIdentity(50, firstInvocation))); + + // 1000 rows * 50ms ≈ 50s without cancel; cancel must abort well under the + // test timeout (30s). + DataFrame df = + ctx.sql("SELECT sleep_identity(CAST(value AS INT)) AS y FROM generate_series(1, 1000)"); + + Future future = pool.submit(() -> df.collect(allocator, token)); + assertTrue( + firstInvocation.await(10, TimeUnit.SECONDS), + "UDF should have been invoked at least once"); + token.cancel(); + + ExecutionException thrown = assertThrows(ExecutionException.class, future::get); + assertTrue( + thrown.getCause() instanceof CancellationException, + "expected CancellationException, got " + thrown.getCause()); + df.close(); + } finally { + pool.shutdownNow(); + pool.awaitTermination(5, TimeUnit.SECONDS); + } + } + + @Test + @Timeout(value = 30, unit = TimeUnit.SECONDS) + void cancelMidExecuteStreamAbortsNextLoadBatch() throws Exception { + CountDownLatch firstInvocation = new CountDownLatch(1); + try (BufferAllocator allocator = new RootAllocator(); + SessionContext ctx = SessionContext.builder().batchSize(1).build(); + CancellationToken token = ctx.newCancellationToken()) { + ctx.registerUdf(new ScalarUdf(new SleepingIdentity(50, firstInvocation))); + + DataFrame df = + ctx.sql("SELECT sleep_identity(CAST(value AS INT)) AS y FROM generate_series(1, 1000)"); + + try (ArrowReader reader = df.executeStream(allocator, token)) { + // Spawn a watcher that fires cancel once the UDF has been entered. The + // first loadNextBatch may legitimately succeed (DataFusion can have + // already produced one batch before we cancel), so we drain in a loop + // that asserts a cancel surfaces eventually. + // + // The Arrow C-Data stream layer wraps the native error in an + // IOException; we assert the message round-trip rather than the + // exception type so the test stays decoupled from Arrow's wrapper + // policy. Once a typed exception layer lands, this can tighten. + Thread canceller = + new Thread( + () -> { + try { + if (firstInvocation.await(10, TimeUnit.SECONDS)) { + token.cancel(); + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + }, + "cancel-watcher"); + canceller.start(); + try { + Throwable caught = + assertThrows( + Throwable.class, + () -> { + while (reader.loadNextBatch()) { + // drain — eventually loadNextBatch will throw on cancel. + } + }); + assertTrue( + caught.getMessage() != null && caught.getMessage().contains("query cancelled"), + "expected cancel message to surface, got: " + caught); + } finally { + canceller.join(); + } + } + } + } + + @Test + void closedTokenIsRejected() throws Exception { + // A closed token has a zeroed native handle. Treating it as "no token" would + // silently fall back to an uncancellable call -- the API instead must fail + // fast so premature close() is easy to diagnose. + try (BufferAllocator allocator = new RootAllocator(); + SessionContext ctx = new SessionContext()) { + CancellationToken token = ctx.newCancellationToken(); + token.close(); + try (DataFrame df = ctx.sql("SELECT 1")) { + assertThrows(IllegalStateException.class, () -> df.collect(allocator, token)); + } + try (DataFrame df = ctx.sql("SELECT 1")) { + assertThrows(IllegalStateException.class, () -> df.executeStream(allocator, token)); + } + } + } + + @Test + void nullTokenOverloadEquivalentToNoToken() throws Exception { + try (BufferAllocator allocator = new RootAllocator(); + SessionContext ctx = new SessionContext()) { + try (DataFrame df = ctx.sql("SELECT * FROM (VALUES (1), (2), (3)) AS t(x)"); + ArrowReader reader = df.collect(allocator, null)) { + long total = 0; + while (reader.loadNextBatch()) { + total += reader.getVectorSchemaRoot().getRowCount(); + } + assertEquals(3L, total); + } + try (DataFrame df = ctx.sql("SELECT * FROM (VALUES (1), (2), (3)) AS t(x)"); + ArrowReader reader = df.executeStream(allocator, null)) { + long total = 0; + while (reader.loadNextBatch()) { + total += reader.getVectorSchemaRoot().getRowCount(); + } + assertEquals(3L, total); + } + } + } + + @Test + void unfiredTokenDoesNotAffectCollect() throws Exception { + try (BufferAllocator allocator = new RootAllocator(); + SessionContext ctx = new SessionContext(); + CancellationToken token = ctx.newCancellationToken()) { + try (DataFrame df = ctx.sql("SELECT * FROM (VALUES (1), (2), (3), (4)) AS t(x)"); + ArrowReader reader = df.collect(allocator, token)) { + long total = 0; + while (reader.loadNextBatch()) { + total += reader.getVectorSchemaRoot().getRowCount(); + } + assertEquals(4L, total); + } + assertFalse(token.isCancelled()); + } + } + + @Test + void freshTokenAfterFirstCollectStillCancelsSecond() throws Exception { + // A token that has not yet fired remains usable across queries on the same + // session. Once fired, it stays fired -- a follow-up call with the same + // token aborts immediately. + try (BufferAllocator allocator = new RootAllocator(); + SessionContext ctx = new SessionContext(); + CancellationToken token = ctx.newCancellationToken()) { + try (DataFrame df1 = ctx.sql("SELECT 1"); + ArrowReader reader = df1.collect(allocator, token)) { + while (reader.loadNextBatch()) { + // drain + } + } + assertFalse(token.isCancelled()); + token.cancel(); + try (DataFrame df2 = ctx.sql("SELECT 2")) { + assertThrows(CancellationException.class, () -> df2.collect(allocator, token)); + } + } + } + + @Test + @Timeout(value = 30, unit = TimeUnit.SECONDS) + void sameTokenCancelsConcurrentCollects() throws Exception { + CountDownLatch firstInvocation = new CountDownLatch(1); + ExecutorService pool = Executors.newFixedThreadPool(3); + try (BufferAllocator allocator = new RootAllocator(); + SessionContext ctx = SessionContext.builder().batchSize(1).build(); + CancellationToken token = ctx.newCancellationToken()) { + ctx.registerUdf(new ScalarUdf(new SleepingIdentity(50, firstInvocation))); + + // Spawn 3 collects against the same token. SessionContext is documented as + // not thread-safe for concurrent use, so each task constructs its own + // DataFrame *before* we kick off (sql() runs on the test thread inside + // submit's lambda capture: we must serialise that part). + List dfs = new java.util.ArrayList<>(); + for (int i = 0; i < 3; i++) { + dfs.add( + ctx.sql( + "SELECT sleep_identity(CAST(value AS INT)) AS y FROM generate_series(1, 1000)")); + } + List> futures = new java.util.ArrayList<>(); + for (DataFrame df : dfs) { + futures.add(pool.submit(() -> df.collect(allocator, token))); + } + + assertTrue(firstInvocation.await(10, TimeUnit.SECONDS)); + token.cancel(); + + for (Future f : futures) { + ExecutionException thrown = assertThrows(ExecutionException.class, f::get); + assertTrue( + thrown.getCause() instanceof CancellationException, + "expected CancellationException, got " + thrown.getCause()); + } + for (DataFrame df : dfs) { + df.close(); + } + } finally { + pool.shutdownNow(); + pool.awaitTermination(5, TimeUnit.SECONDS); + } + } + + @Test + void countAndShowResolveBatchVectors() throws Exception { + // Smoke test: BigIntVector is reachable so the import works on this JDK. + try (BufferAllocator allocator = new RootAllocator(); + SessionContext ctx = new SessionContext(); + DataFrame df = ctx.sql("SELECT 1"); + ArrowReader reader = df.collect(allocator)) { + assertTrue(reader.loadNextBatch()); + assertTrue(reader.getVectorSchemaRoot().getVector(0) instanceof BigIntVector); + } + } +} diff --git a/native/Cargo.lock b/native/Cargo.lock index 7171f72..e228fb9 100644 --- a/native/Cargo.lock +++ b/native/Cargo.lock @@ -1259,6 +1259,7 @@ dependencies = [ "prost-build", "protoc-bin-vendored", "tokio", + "tokio-util", ] [[package]] @@ -3081,6 +3082,7 @@ dependencies = [ "bytes", "futures-core", "futures-sink", + "futures-util", "pin-project-lite", "tokio", ] diff --git a/native/Cargo.toml b/native/Cargo.toml index 28e1e8f..b45b645 100644 --- a/native/Cargo.toml +++ b/native/Cargo.toml @@ -32,7 +32,8 @@ datafusion-proto = "53.1.0" futures = "0.3" jni = "0.21" prost = "0.14" -tokio = { version = "1", features = ["rt-multi-thread"] } +tokio = { version = "1", features = ["rt-multi-thread", "macros"] } +tokio-util = { version = "0.7", features = ["rt"] } [build-dependencies] prost-build = "0.14" diff --git a/native/src/cancellation.rs b/native/src/cancellation.rs new file mode 100644 index 0000000..02a18c8 --- /dev/null +++ b/native/src/cancellation.rs @@ -0,0 +1,133 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Cancellation tokens for in-flight queries. +//! +//! Java handles are opaque `u64` IDs, not raw pointers. A process-global +//! registry maps each live ID to its `Arc`. JNI handlers +//! look up by ID and clone the `Arc` out of the registry (under a lock) so a +//! concurrent close can never invalidate a borrow already in flight: the worst +//! a race produces is a clean "closed" error from a missing-ID lookup, never +//! a use-after-free of a freed `Box`. +//! +//! This is the same scaffolding upstream's open close()-race issue calls for +//! across all handle types; it is applied here first because cancellation +//! tokens are designed to be fired from a thread that does not own them, so +//! the race window is the widest of any handle in the binding. + +use std::collections::HashMap; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::{Arc, Mutex, OnceLock}; + +use jni::objects::JClass; +use jni::sys::{jboolean, jlong}; +use jni::JNIEnv; +use tokio_util::sync::CancellationToken; + +use crate::errors::{try_unwrap_or_throw, JniResult}; + +fn registry() -> &'static Mutex>> { + static REG: OnceLock>>> = OnceLock::new(); + REG.get_or_init(|| Mutex::new(HashMap::new())) +} + +fn next_id() -> u64 { + // Start at 1 so jlong 0 stays reserved as the "no token" sentinel on the + // Java side. Monotonic; 2^64 IDs is enough that reuse is never observed. + static COUNTER: AtomicU64 = AtomicU64::new(1); + COUNTER.fetch_add(1, Ordering::Relaxed) +} + +/// Look up the `Arc` for `handle`. Returns `None` if the +/// handle is zero (no token) or has already been closed -- *not* an error, +/// because callers want to distinguish those cases. The cloned `Arc` keeps the +/// inner `CancellationToken` alive for the borrow's lifetime, so a concurrent +/// `closeToken` removing the registry entry is safe: the entry's drop just +/// decrements one of several `Arc` counts. +pub(crate) fn token_arc(handle: jlong) -> Option> { + if handle == 0 { + return None; + } + let id = handle as u64; + let guard = registry().lock().expect("cancellation registry poisoned"); + guard.get(&id).cloned() +} + +#[no_mangle] +pub extern "system" fn Java_org_apache_datafusion_SessionContext_createCancellationToken<'local>( + mut env: JNIEnv<'local>, + _class: JClass<'local>, +) -> jlong { + try_unwrap_or_throw(&mut env, 0, |_env| -> JniResult { + let token: Arc = Arc::new(CancellationToken::new()); + let id = next_id(); + let mut guard = registry().lock().expect("cancellation registry poisoned"); + guard.insert(id, token); + Ok(id as jlong) + }) +} + +#[no_mangle] +pub extern "system" fn Java_org_apache_datafusion_CancellationToken_cancelToken<'local>( + mut env: JNIEnv<'local>, + _class: JClass<'local>, + handle: jlong, +) { + try_unwrap_or_throw(&mut env, (), |_env| -> JniResult<()> { + match token_arc(handle) { + Some(t) => { + t.cancel(); + Ok(()) + } + None => Err("CancellationToken is closed".into()), + } + }) +} + +#[no_mangle] +pub extern "system" fn Java_org_apache_datafusion_CancellationToken_isCancelledToken<'local>( + mut env: JNIEnv<'local>, + _class: JClass<'local>, + handle: jlong, +) -> jboolean { + try_unwrap_or_throw(&mut env, 0, |_env| -> JniResult { + match token_arc(handle) { + Some(t) => Ok(if t.is_cancelled() { 1 } else { 0 }), + None => Err("CancellationToken is closed".into()), + } + }) +} + +#[no_mangle] +pub extern "system" fn Java_org_apache_datafusion_CancellationToken_closeToken<'local>( + mut env: JNIEnv<'local>, + _class: JClass<'local>, + handle: jlong, +) { + try_unwrap_or_throw(&mut env, (), |_env| -> JniResult<()> { + if handle == 0 { + return Ok(()); + } + let id = handle as u64; + // Remove (rather than `drop` a raw Box) -- the underlying Arc may still + // have outstanding clones held by an in-flight collect/executeStream + // future, and those keep the inner token alive until they finish. + let mut guard = registry().lock().expect("cancellation registry poisoned"); + guard.remove(&id); + Ok(()) + }) +} diff --git a/native/src/errors.rs b/native/src/errors.rs index e779bb7..da47c23 100644 --- a/native/src/errors.rs +++ b/native/src/errors.rs @@ -23,6 +23,13 @@ use jni::JNIEnv; pub type JniResult = Result>; +/// Error message used to round-trip a query-cancellation signal from a JNI +/// handler to the Java side. `try_unwrap_or_throw` matches on this string and +/// throws a `java.util.concurrent.CancellationException` instead of the default +/// `RuntimeException`. Keep this stable across the codebase so the cancellation +/// path stays grep-friendly. +pub const CANCELLED_MESSAGE: &str = "datafusion-java: query cancelled"; + pub fn try_unwrap_or_throw(env: &mut JNIEnv, default: T, f: F) -> T where F: FnOnce(&mut JNIEnv) -> JniResult, @@ -30,21 +37,26 @@ where match catch_unwind(AssertUnwindSafe(|| f(env))) { Ok(Ok(value)) => value, Ok(Err(err)) => { - throw_runtime_exception(env, &err.to_string()); + throw_for_message(env, &err.to_string()); default } Err(panic) => { - throw_runtime_exception(env, &panic_message(&panic)); + throw_for_message(env, &panic_message(&panic)); default } } } -fn throw_runtime_exception(env: &mut JNIEnv, message: &str) { +fn throw_for_message(env: &mut JNIEnv, message: &str) { if env.exception_check().unwrap_or(false) { return; } - let _ = env.throw_new("java/lang/RuntimeException", message); + let class = if message.contains(CANCELLED_MESSAGE) { + "java/util/concurrent/CancellationException" + } else { + "java/lang/RuntimeException" + }; + let _ = env.throw_new(class, message); } fn panic_message(panic: &Box) -> String { diff --git a/native/src/lib.rs b/native/src/lib.rs index a235cd3..74db8d3 100644 --- a/native/src/lib.rs +++ b/native/src/lib.rs @@ -17,6 +17,7 @@ mod arrow; mod avro; +mod cancellation; mod csv; mod errors; mod jni_util; @@ -56,11 +57,28 @@ use jni::JavaVM; use prost::Message; use tokio::runtime::Runtime; -use crate::errors::{try_unwrap_or_throw, JniResult}; +use crate::cancellation::token_arc; +use crate::errors::{try_unwrap_or_throw, JniResult, CANCELLED_MESSAGE}; use crate::proto_gen::ParquetReadOptionsProto; use crate::proto_gen::SessionOptions; use crate::schema::decode_optional_schema; +/// Resolve a `jlong` cancellation-token handle into an optional `Arc`. A zero +/// handle means "no token" and yields `Ok(None)`. A non-zero but already-closed +/// handle yields `Err`, matching the Java overload's contract that a closed +/// token is rejected (rather than silently treated as no token). +fn resolve_token( + token_handle: jlong, +) -> JniResult>> { + if token_handle == 0 { + return Ok(None); + } + match token_arc(token_handle) { + Some(t) => Ok(Some(t)), + None => Err("CancellationToken is closed".into()), + } +} + static JAVA_VM: OnceLock = OnceLock::new(); #[no_mangle] @@ -187,6 +205,7 @@ pub extern "system" fn Java_org_apache_datafusion_DataFrame_collectDataFrame<'lo _class: JClass<'local>, handle: jlong, ffi_stream_addr: jlong, + token_handle: jlong, ) { try_unwrap_or_throw(&mut env, (), |_env| -> JniResult<()> { if handle == 0 { @@ -196,10 +215,19 @@ pub extern "system" fn Java_org_apache_datafusion_DataFrame_collectDataFrame<'lo return Err("ffi stream address is null".into()); } let df = unsafe { *Box::from_raw(handle as *mut DataFrame) }; + let token = resolve_token(token_handle)?; let ffi: FFI_ArrowArrayStream = runtime().block_on(async { let schema: SchemaRef = Arc::new(df.schema().as_arrow().clone()); - let batches = df.collect().await?; + let collect_fut = df.collect(); + let batches = match &token { + Some(t) => tokio::select! { + biased; + _ = t.cancelled() => Err(DataFusionError::Execution(CANCELLED_MESSAGE.into())), + r = collect_fut => r, + }, + None => collect_fut.await, + }?; let iter = RecordBatchIterator::new(batches.into_iter().map(Ok), schema); Ok::<_, DataFusionError>(FFI_ArrowArrayStream::new(Box::new(iter))) })?; @@ -215,10 +243,13 @@ pub extern "system" fn Java_org_apache_datafusion_DataFrame_collectDataFrame<'lo /// [`RecordBatchReader`] interface that `FFI_ArrowArrayStream` (and therefore /// the Java `ArrowReader`) consumes. Each call to `next()` drives one /// `runtime().block_on(stream.next())`, so memory pressure stays bounded by the -/// executor pipeline plus a single in-flight batch. +/// executor pipeline plus a single in-flight batch. When a cancellation token +/// is attached, each batch poll races against `token.cancelled()` so a cancel +/// from another thread aborts on the next `loadNextBatch`. struct StreamingReader { schema: SchemaRef, stream: SendableRecordBatchStream, + token: Option>, } impl Iterator for StreamingReader { @@ -230,7 +261,20 @@ impl Iterator for StreamingReader { // here (buggy UDF, arrow cast that panics, runtime poison) would // unwind across C/FFI -- undefined behaviour. Catch it and surface as // an ArrowError so the Java side sees a normal exception instead. - let next = catch_unwind(AssertUnwindSafe(|| runtime().block_on(self.stream.next()))); + let next = catch_unwind(AssertUnwindSafe(|| { + runtime().block_on(async { + match &self.token { + Some(t) => tokio::select! { + biased; + _ = t.cancelled() => Some(Err(datafusion::error::DataFusionError::Execution( + CANCELLED_MESSAGE.into(), + ))), + r = self.stream.next() => r, + }, + None => self.stream.next().await, + } + }) + })); match next { Ok(item) => item.map(|r| r.map_err(|e| ArrowError::ExternalError(Box::new(e)))), Err(panic) => { @@ -261,6 +305,7 @@ pub extern "system" fn Java_org_apache_datafusion_DataFrame_executeStreamDataFra _class: JClass<'local>, handle: jlong, ffi_stream_addr: jlong, + token_handle: jlong, ) { try_unwrap_or_throw(&mut env, (), |_env| -> JniResult<()> { if handle == 0 { @@ -270,11 +315,24 @@ pub extern "system" fn Java_org_apache_datafusion_DataFrame_executeStreamDataFra return Err("ffi stream address is null".into()); } let df = unsafe { *Box::from_raw(handle as *mut DataFrame) }; + let token = resolve_token(token_handle)?; let ffi: FFI_ArrowArrayStream = runtime().block_on(async { let schema: SchemaRef = Arc::new(df.schema().as_arrow().clone()); - let stream = df.execute_stream().await?; - let reader = StreamingReader { schema, stream }; + let stream_fut = df.execute_stream(); + let stream = match &token { + Some(t) => tokio::select! { + biased; + _ = t.cancelled() => Err(DataFusionError::Execution(CANCELLED_MESSAGE.into())), + r = stream_fut => r, + }, + None => stream_fut.await, + }?; + let reader = StreamingReader { + schema, + stream, + token: token.clone(), + }; Ok::<_, DataFusionError>(FFI_ArrowArrayStream::new(Box::new(reader))) })?;