diff --git a/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs b/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs
index e9cd9361..887036c2 100644
--- a/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs
+++ b/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs
@@ -87,7 +87,7 @@ public override RecordBatch ReadNextRecordBatch()
else if (messageLength == MessageSerializer.IpcContinuationToken)
{
// ARROW-6313, if the first 4 bytes are continuation message, read the next 4 for the length
- if (_buffer.Length <= _bufferPosition + sizeof(int))
+ if (_buffer.Length < _bufferPosition + sizeof(int))
{
throw new InvalidDataException("Corrupted IPC message. Received a continuation token at the end of the message.");
}
@@ -136,7 +136,7 @@ public override void ReadSchema()
if (schemaMessageLength == MessageSerializer.IpcContinuationToken)
{
// ARROW-6313, if the first 4 bytes are continuation message, read the next 4 for the length
- if (_buffer.Length <= _bufferPosition + sizeof(int))
+ if (_buffer.Length < _bufferPosition + sizeof(int))
{
throw new InvalidDataException("Corrupted IPC message. Received a continuation token at the end of the message.");
}
diff --git a/src/Apache.Arrow/Ipc/ArrowMemoryStreamReaderImplementation.cs b/src/Apache.Arrow/Ipc/ArrowMemoryStreamReaderImplementation.cs
new file mode 100644
index 00000000..01d3b0b5
--- /dev/null
+++ b/src/Apache.Arrow/Ipc/ArrowMemoryStreamReaderImplementation.cs
@@ -0,0 +1,214 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+using System;
+using System.Buffers;
+using System.IO;
+using System.Threading;
+using System.Threading.Tasks;
+using Apache.Arrow.Memory;
+
+namespace Apache.Arrow.Ipc
+{
+ /// <summary>
+ /// Reads Arrow IPC streams from a <see cref="MemoryStream"/> whose backing buffer is publicly visible.
+ /// </summary>
+ /// <remarks>
+ /// Message metadata can be read directly from the exposed stream buffer, but record batch bodies are
+ /// still copied into allocator-owned buffers to preserve ownership semantics.
+ /// </remarks>
+ internal sealed class ArrowMemoryStreamReaderImplementation : ArrowStreamReaderImplementation
+ {
+ private readonly MemoryStream _stream;
+ private readonly Memory<byte> _streamMemory;
+
+ public ArrowMemoryStreamReaderImplementation(
+ MemoryStream stream,
+ MemoryAllocator allocator,
+ ICompressionCodecFactory compressionCodecFactory,
+ bool leaveOpen,
+ ExtensionTypeRegistry extensionRegistry)
+ : base(stream, allocator, compressionCodecFactory, leaveOpen, extensionRegistry)
+ {
+ _stream = stream;
+
+ if (!stream.TryGetBuffer(out ArraySegment<byte> streamBuffer))
+ {
+ throw new InvalidOperationException("Expected MemoryStream to expose its backing buffer.");
+ }
+
+ _streamMemory = streamBuffer.Array.AsMemory(streamBuffer.Offset, streamBuffer.Count);
+ }
+
+ public override ValueTask<RecordBatch> ReadNextRecordBatchAsync(CancellationToken cancellationToken)
+ {
+ cancellationToken.ThrowIfCancellationRequested();
+
+ try
+ {
+ return new ValueTask<RecordBatch>(ReadNextRecordBatch());
+ }
+ catch (Exception ex)
+ {
+ return new ValueTask<RecordBatch>(Task.FromException<RecordBatch>(ex));
+ }
+ }
+
+ public override RecordBatch ReadNextRecordBatch()
+ {
+ ReadSchema();
+
+ ReadResult result = default;
+ do
+ {
+ result = ReadMessageFromMemory();
+ } while (result.Batch == null && result.MessageLength > 0);
+
+ return result.Batch;
+ }
+
+ public override ValueTask<Schema> ReadSchemaAsync(CancellationToken cancellationToken = default)
+ {
+ cancellationToken.ThrowIfCancellationRequested();
+
+ if (HasReadSchema)
+ {
+ return new ValueTask<Schema>(_schema);
+ }
+
+ try
+ {
+ ReadSchema();
+ return new ValueTask<Schema>(_schema);
+ }
+ catch (Exception ex)
+ {
+ return new ValueTask<Schema>(Task.FromException<Schema>(ex));
+ }
+ }
+
+ public override void ReadSchema()
+ {
+ if (HasReadSchema)
+ {
+ return;
+ }
+
+ int schemaMessageLength = ReadMessageLengthFromMemory(throwOnFullRead: true, returnOnEmptyStream: true);
+ if (schemaMessageLength == 0)
+ {
+ return;
+ }
+
+ Memory<byte> schemaBuffer = ReadMemory(schemaMessageLength);
+ _schema = MessageSerializer.GetSchema(ReadMessage(CreateByteBuffer(schemaBuffer)), ref _dictionaryMemo, _extensionRegistry);
+ }
+
+ private ReadResult ReadMessageFromMemory()
+ {
+ int messageLength = ReadMessageLengthFromMemory(throwOnFullRead: false, returnOnEmptyStream: false);
+ if (messageLength == 0)
+ {
+ return default;
+ }
+
+ Memory<byte> messageBuffer = ReadMemory(messageLength);
+ Flatbuf.Message message = Flatbuf.Message.GetRootAsMessage(CreateByteBuffer(messageBuffer));
+
+ if (message.BodyLength > int.MaxValue)
+ {
+ throw new OverflowException(
+ $"Arrow IPC message body length ({message.BodyLength}) is larger than " +
+ $"the maximum supported message size ({int.MaxValue})");
+ }
+
+ int bodyLength = (int)message.BodyLength;
+ Memory<byte> sourceBodyBuffer = ReadMemory(bodyLength);
+ IMemoryOwner<byte> bodyBufferOwner = AllocateMessageBodyBuffer(bodyLength);
+ Memory<byte> bodyBuffer = bodyBufferOwner.Memory.Slice(0, bodyLength);
+ sourceBodyBuffer.CopyTo(bodyBuffer);
+ Google.FlatBuffers.ByteBuffer bodybb = CreateByteBuffer(bodyBuffer);
+
+ // Keep stream-reader ownership semantics: batches outlive the source MemoryStream buffer.
+ return new ReadResult(messageLength, CreateArrowObjectFromMessage(message, bodybb, bodyBufferOwner));
+ }
+
+ private int ReadMessageLengthFromMemory(bool throwOnFullRead, bool returnOnEmptyStream)
+ {
+ if (_stream.Position == _stream.Length && returnOnEmptyStream)
+ {
+ return 0;
+ }
+
+ if (!TryReadInt32(throwOnFullRead, out int messageLength))
+ {
+ return 0;
+ }
+
+ if (messageLength == MessageSerializer.IpcContinuationToken &&
+ !TryReadInt32(throwOnFullRead, out messageLength))
+ {
+ return 0;
+ }
+
+ return messageLength;
+ }
+
+ private bool TryReadInt32(bool throwOnFullRead, out int value)
+ {
+ value = 0;
+
+ if (!TryReadMemory(sizeof(int), throwOnFullRead, out Memory<byte> buffer))
+ {
+ return false;
+ }
+
+ value = BitUtility.ReadInt32(buffer);
+ return true;
+ }
+
+ private bool TryReadMemory(int length, bool throwOnFullRead, out Memory<byte> buffer)
+ {
+ buffer = default;
+
+ long remainingLength = _stream.Length - _stream.Position;
+ if (remainingLength < length)
+ {
+ if (throwOnFullRead)
+ {
+ throw new InvalidOperationException("Unexpectedly reached the end of the stream before a full buffer was read.");
+ }
+
+ _stream.Position = _stream.Length;
+ return false;
+ }
+
+ buffer = ReadMemory(length);
+ return true;
+ }
+
+ private Memory<byte> ReadMemory(int length)
+ {
+ if (length == 0)
+ {
+ return Memory<byte>.Empty;
+ }
+
+ Memory<byte> buffer = _streamMemory.Slice(checked((int)_stream.Position), length);
+ _stream.Position += length;
+ return buffer;
+ }
+ }
+}
diff --git a/src/Apache.Arrow/Ipc/ArrowStreamReader.cs b/src/Apache.Arrow/Ipc/ArrowStreamReader.cs
index e5dade2b..6100eea3 100644
--- a/src/Apache.Arrow/Ipc/ArrowStreamReader.cs
+++ b/src/Apache.Arrow/Ipc/ArrowStreamReader.cs
@@ -68,7 +68,7 @@ public ArrowStreamReader(Stream stream, MemoryAllocator allocator, ICompressionC
if (stream == null)
throw new ArgumentNullException(nameof(stream));
- _implementation = new ArrowStreamReaderImplementation(stream, allocator, compressionCodecFactory, leaveOpen);
+ _implementation = CreateImplementation(stream, allocator, compressionCodecFactory, leaveOpen, extensionRegistry: null);
}
public ArrowStreamReader(ArrowContext context, Stream stream, bool leaveOpen = false)
@@ -78,7 +78,7 @@ public ArrowStreamReader(ArrowContext context, Stream stream, bool leaveOpen = f
if (context == null)
throw new ArgumentNullException(nameof(context));
- _implementation = new ArrowStreamReaderImplementation(stream, context.Allocator, context.CompressionCodecFactory, leaveOpen, context.ExtensionRegistry);
+ _implementation = CreateImplementation(stream, context.Allocator, context.CompressionCodecFactory, leaveOpen, context.ExtensionRegistry);
}
public ArrowStreamReader(ReadOnlyMemory<byte> buffer)
@@ -104,6 +104,21 @@ private protected ArrowStreamReader(ArrowReaderImplementation implementation)
_implementation = implementation;
}
+ private static ArrowReaderImplementation CreateImplementation(
+ Stream stream,
+ MemoryAllocator allocator,
+ ICompressionCodecFactory compressionCodecFactory,
+ bool leaveOpen,
+ ExtensionTypeRegistry extensionRegistry)
+ {
+ if (stream is MemoryStream memoryStream && memoryStream.TryGetBuffer(out _))
+ {
+ return new ArrowMemoryStreamReaderImplementation(memoryStream, allocator, compressionCodecFactory, leaveOpen, extensionRegistry);
+ }
+
+ return new ArrowStreamReaderImplementation(stream, allocator, compressionCodecFactory, leaveOpen, extensionRegistry);
+ }
+
public void Dispose()
{
Dispose(true);
diff --git a/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs b/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs
index 9d0cbe33..179b9984 100644
--- a/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs
+++ b/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs
@@ -47,11 +47,10 @@ protected override void Dispose(bool disposing)
}
}
- public override async ValueTask<RecordBatch> ReadNextRecordBatchAsync(CancellationToken cancellationToken)
+ public override ValueTask<RecordBatch> ReadNextRecordBatchAsync(CancellationToken cancellationToken)
{
- // TODO: Loop until a record batch is read.
cancellationToken.ThrowIfCancellationRequested();
- return await ReadRecordBatchAsync(cancellationToken).ConfigureAwait(false);
+ return ReadRecordBatchAsync(cancellationToken);
}
public override RecordBatch ReadNextRecordBatch()
@@ -61,7 +60,7 @@ public override RecordBatch ReadNextRecordBatch()
protected async ValueTask<RecordBatch> ReadRecordBatchAsync(CancellationToken cancellationToken = default)
{
- await ReadSchemaAsync().ConfigureAwait(false);
+ await ReadSchemaAsync(cancellationToken).ConfigureAwait(false);
ReadResult result = default;
do
@@ -94,7 +93,7 @@ protected async ValueTask ReadMessageAsync(CancellationToken cancell
int bodyLength = checked((int)message.BodyLength);
- IMemoryOwner<byte> bodyBuffOwner = _allocator.Allocate(bodyLength);
+ IMemoryOwner<byte> bodyBuffOwner = AllocateMessageBodyBuffer(bodyLength);
Memory<byte> bodyBuff = bodyBuffOwner.Memory.Slice(0, bodyLength);
bytesRead = await BaseStream.ReadFullBufferAsync(bodyBuff, cancellationToken)
.ConfigureAwait(false);
@@ -145,7 +144,7 @@ protected ReadResult ReadMessage()
}
int bodyLength = (int)message.BodyLength;
- IMemoryOwner<byte> bodyBuffOwner = _allocator.Allocate(bodyLength);
+ IMemoryOwner<byte> bodyBuffOwner = AllocateMessageBodyBuffer(bodyLength);
Memory<byte> bodyBuff = bodyBuffOwner.Memory.Slice(0, bodyLength);
bytesRead = BaseStream.ReadFullBuffer(bodyBuff);
EnsureFullRead(bodyBuff, bytesRead);
@@ -157,13 +156,25 @@ protected ReadResult ReadMessage()
return new ReadResult(messageLength, result);
}
- public override async ValueTask<Schema> ReadSchemaAsync(CancellationToken cancellationToken = default)
+ protected IMemoryOwner<byte> AllocateMessageBodyBuffer(int bodyLength)
{
+ return _allocator.Allocate(bodyLength);
+ }
+
+ public override ValueTask<Schema> ReadSchemaAsync(CancellationToken cancellationToken = default)
+ {
+ cancellationToken.ThrowIfCancellationRequested();
+
if (HasReadSchema)
{
- return _schema;
+ return new ValueTask<Schema>(_schema);
}
+ return ReadSchemaAsyncCore(cancellationToken);
+ }
+
+ private async ValueTask<Schema> ReadSchemaAsyncCore(CancellationToken cancellationToken)
+ {
// Figure out length of schema
int schemaMessageLength = await ReadMessageLengthAsync(throwOnFullRead: true, returnOnEmptyStream: true, cancellationToken)
.ConfigureAwait(false);
diff --git a/test/Apache.Arrow.Benchmarks/ArrowReaderBenchmark.cs b/test/Apache.Arrow.Benchmarks/ArrowReaderBenchmark.cs
index 3760d940..9305adb1 100644
--- a/test/Apache.Arrow.Benchmarks/ArrowReaderBenchmark.cs
+++ b/test/Apache.Arrow.Benchmarks/ArrowReaderBenchmark.cs
@@ -14,8 +14,8 @@
// limitations under the License.
using System;
+using System.Collections.Generic;
using System.IO;
-using System.Linq;
using System.Threading.Tasks;
using Apache.Arrow.Ipc;
using Apache.Arrow.Memory;
@@ -32,13 +32,19 @@ public class ArrowReaderBenchmark
[Params(10_000, 1_000_000)]
public int Count { get; set; }
+ [Params(1, 5)]
+ public int ColumnSetCount { get; set; }
+
private MemoryStream _memoryStream;
private static readonly MemoryAllocator s_allocator = new TestMemoryAllocator();
[GlobalSetup]
public async Task GlobalSetup()
{
- RecordBatch batch = TestData.CreateSampleRecordBatch(length: Count, createDictionaryArray: false);
+ RecordBatch batch = TestData.CreateSampleRecordBatch(
+ length: Count,
+ columnSetCount: ColumnSetCount,
+ excludedTypes: new HashSet<ArrowTypeId> { ArrowTypeId.Dictionary, ArrowTypeId.RunEndEncoded });
_memoryStream = new MemoryStream();
ArrowStreamWriter writer = new ArrowStreamWriter(_memoryStream, batch.Schema);
@@ -83,6 +89,73 @@ public async Task ArrowReaderWithMemoryStream_ManagedMemory()
return sum;
}
+ [Benchmark]
+ public async Task<double> ArrowReaderWithMemoryStream_ExplicitDefaultAllocator()
+ {
+ double sum = 0;
+ var reader = new ArrowStreamReader(_memoryStream, MemoryAllocator.Default.Value);
+ RecordBatch recordBatch;
+ while ((recordBatch = await reader.ReadNextRecordBatchAsync()) != null)
+ {
+ using (recordBatch)
+ {
+ sum += SumAllNumbers(recordBatch);
+ }
+ }
+ return sum;
+ }
+
+ [Benchmark]
+ public async Task<double> ArrowReaderWithNonPubliclyVisibleMemoryStream()
+ {
+ double sum = 0;
+ using var stream = CreateNonPubliclyVisibleReadStream();
+ using var reader = new ArrowStreamReader(stream);
+ RecordBatch recordBatch;
+ while ((recordBatch = await reader.ReadNextRecordBatchAsync()) != null)
+ {
+ using (recordBatch)
+ {
+ sum += SumAllNumbers(recordBatch);
+ }
+ }
+ return sum;
+ }
+
+ [Benchmark]
+ public async Task<double> ArrowReaderWithNonPubliclyVisibleMemoryStream_ManagedMemory()
+ {
+ double sum = 0;
+ using var stream = CreateNonPubliclyVisibleReadStream();
+ using var reader = new ArrowStreamReader(stream, s_allocator);
+ RecordBatch recordBatch;
+ while ((recordBatch = await reader.ReadNextRecordBatchAsync()) != null)
+ {
+ using (recordBatch)
+ {
+ sum += SumAllNumbers(recordBatch);
+ }
+ }
+ return sum;
+ }
+
+ [Benchmark]
+ public async Task<double> ArrowReaderWithNonPubliclyVisibleMemoryStream_ExplicitDefaultAllocator()
+ {
+ double sum = 0;
+ using var stream = CreateNonPubliclyVisibleReadStream();
+ using var reader = new ArrowStreamReader(stream, MemoryAllocator.Default.Value);
+ RecordBatch recordBatch;
+ while ((recordBatch = await reader.ReadNextRecordBatchAsync()) != null)
+ {
+ using (recordBatch)
+ {
+ sum += SumAllNumbers(recordBatch);
+ }
+ }
+ return sum;
+ }
+
[Benchmark]
public async Task<double> ArrowReaderWithMemory()
{
@@ -99,14 +172,25 @@ public async Task ArrowReaderWithMemory()
return sum;
}
+ private MemoryStream CreateNonPubliclyVisibleReadStream()
+ {
+ return new MemoryStream(
+ _memoryStream.GetBuffer(),
+ index: 0,
+ count: checked((int)_memoryStream.Length),
+ writable: false,
+ publiclyVisible: false);
+ }
+
private static double SumAllNumbers(RecordBatch recordBatch)
{
double sum = 0;
for (int k = 0; k < recordBatch.ColumnCount; k++)
{
- var array = recordBatch.Arrays.ElementAt(k);
- switch (recordBatch.Schema.GetFieldByIndex(k).DataType.TypeId)
+ var array = recordBatch.Column(k);
+ ArrowTypeId typeId = recordBatch.Schema.GetFieldByIndex(k).DataType.TypeId;
+ switch (typeId)
{
case ArrowTypeId.Int64:
Int64Array int64Array = (Int64Array)array;
diff --git a/test/Apache.Arrow.Compression.Tests/ArrowStreamReaderTests.cs b/test/Apache.Arrow.Compression.Tests/ArrowStreamReaderTests.cs
index 9c2bf75d..99e74424 100644
--- a/test/Apache.Arrow.Compression.Tests/ArrowStreamReaderTests.cs
+++ b/test/Apache.Arrow.Compression.Tests/ArrowStreamReaderTests.cs
@@ -14,8 +14,10 @@
// limitations under the License.
using System;
+using System.IO;
using System.Reflection;
using Apache.Arrow.Ipc;
+using Apache.Arrow.Memory;
using Apache.Arrow.Tests;
using Xunit;
@@ -47,13 +49,34 @@ public void CanReadCompressedIpcStreamFromMemoryBuffer(string fileName)
using var stream = assembly.GetManifestResourceStream($"Apache.Arrow.Compression.Tests.Resources.{fileName}");
Assert.NotNull(stream);
var buffer = new byte[stream.Length];
- stream.ReadFullBuffer(buffer);
+ ReadExactly(stream, buffer);
var codecFactory = new Compression.CompressionCodecFactory();
using var reader = new ArrowStreamReader(buffer, codecFactory);
VerifyCompressedIpcFileBatch(reader.ReadNextRecordBatch());
}
+ [Theory]
+ [InlineData("ipc_lz4_compression.arrow_stream")]
+ [InlineData("ipc_zstd_compression.arrow_stream")]
+ public void CanReadCompressedIpcStreamFromMemoryBuffer_UsesDefaultAllocator(string fileName)
+ {
+ var assembly = Assembly.GetExecutingAssembly();
+ using var stream = assembly.GetManifestResourceStream($"Apache.Arrow.Compression.Tests.Resources.{fileName}");
+ Assert.NotNull(stream);
+ var buffer = new byte[stream.Length];
+ ReadExactly(stream, buffer);
+ var codecFactory = new Compression.CompressionCodecFactory();
+
+ long allocationsBeforeRead = MemoryAllocator.Default.Value.Statistics.Allocations;
+
+ using var reader = new ArrowStreamReader(buffer, codecFactory);
+ using RecordBatch batch = reader.ReadNextRecordBatch();
+ VerifyCompressedIpcFileBatch(batch);
+
+ Assert.True(MemoryAllocator.Default.Value.Statistics.Allocations > allocationsBeforeRead);
+ }
+
[Fact]
public void ErrorReadingCompressedStreamWithoutCodecFactory()
{
@@ -86,6 +109,21 @@ public void MemoryPoolDisposedOnReadCompressedIpcStream(string fileName)
}
+ private static void ReadExactly(Stream stream, byte[] buffer)
+ {
+ int offset = 0;
+ while (offset < buffer.Length)
+ {
+ int bytesRead = stream.Read(buffer, offset, buffer.Length - offset);
+ if (bytesRead == 0)
+ {
+ throw new EndOfStreamException();
+ }
+
+ offset += bytesRead;
+ }
+ }
+
private static void VerifyCompressedIpcFileBatch(RecordBatch batch)
{
var intArray = (Int32Array)batch.Column("integers");
@@ -103,4 +141,3 @@ private static void VerifyCompressedIpcFileBatch(RecordBatch batch)
}
}
}
-
diff --git a/test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs b/test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs
index 5e7c57e1..d04e0cd9 100644
--- a/test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs
+++ b/test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs
@@ -16,9 +16,11 @@
using System;
using System.Buffers.Binary;
using System.IO;
+using System.Reflection;
using System.Threading;
using System.Threading.Tasks;
using Apache.Arrow.Ipc;
+using Apache.Arrow.Memory;
using Apache.Arrow.Types;
using Xunit;
@@ -69,13 +71,20 @@ public async Task Ctor_MemoryPool_AllocatesFromPool(bool shouldLeaveOpen, bool c
var memoryPool = new TestMemoryAllocator();
ArrowStreamReader reader = new ArrowStreamReader(stream, memoryPool, shouldLeaveOpen);
- reader.ReadNextRecordBatch();
-
- Assert.Equal(expectedAllocations, memoryPool.Statistics.Allocations);
- Assert.True(memoryPool.Statistics.BytesAllocated > 0);
+ using (RecordBatch readBatch = reader.ReadNextRecordBatch())
+ {
+ Assert.Equal(expectedAllocations, memoryPool.Statistics.Allocations);
+ Assert.True(memoryPool.Statistics.BytesAllocated > 0);
+ Assert.Equal(expectedAllocations, memoryPool.Rented);
+ }
reader.Dispose();
+ if (!createDictionaryArray)
+ {
+ Assert.Equal(0, memoryPool.Rented);
+ }
+
if (shouldLeaveOpen)
{
Assert.True(stream.Position > 0);
@@ -109,6 +118,22 @@ public async Task ReadRecordBatchAsync_Memory(bool writeEnd)
await TestReaderFromMemory(ArrowReaderVerifier.VerifyReaderAsync, writeEnd);
}
+ [Fact]
+ public async Task ReadRecordBatch_Memory_ExactLengthSlice()
+ {
+ await TestReaderFromMemoryExactLength((reader, originalBatch) =>
+ {
+ ArrowReaderVerifier.VerifyReader(reader, originalBatch);
+ return Task.CompletedTask;
+ });
+ }
+
+ [Fact]
+ public async Task ReadRecordBatchAsync_Memory_ExactLengthSlice()
+ {
+ await TestReaderFromMemoryExactLength(ArrowReaderVerifier.VerifyReaderAsync);
+ }
+
private static async Task TestReaderFromMemory(
Func<ArrowStreamReader, RecordBatch, Task> verificationFunc,
bool writeEnd)
@@ -131,6 +156,24 @@ private static async Task TestReaderFromMemory(
await verificationFunc(reader, originalBatch);
}
+ private static async Task TestReaderFromMemoryExactLength(
+ Func<ArrowStreamReader, RecordBatch, Task> verificationFunc)
+ {
+ RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100);
+
+ ReadOnlyMemory<byte> buffer;
+ using (MemoryStream stream = new MemoryStream())
+ {
+ ArrowStreamWriter writer = new ArrowStreamWriter(stream, originalBatch.Schema);
+ await writer.WriteRecordBatchAsync(originalBatch);
+ await writer.WriteEndAsync();
+ buffer = stream.GetBuffer().AsMemory(0, checked((int)stream.Length));
+ }
+
+ ArrowStreamReader reader = new ArrowStreamReader(buffer);
+ await verificationFunc(reader, originalBatch);
+ }
+
[Fact]
public void ReadRecordBatch_EmptyStream()
{
@@ -167,6 +210,40 @@ public async Task ReadRecordBatchAsync_EmptyStream()
}
}
+ [Fact]
+ public async Task ReadRecordBatchAsync_PassesCancellationTokenToSchemaRead()
+ {
+ using var stream = new RequiresCancelableReadStream();
+ using var reader = new ArrowStreamReader(stream);
+ using var cancellation = new CancellationTokenSource();
+
+ await Assert.ThrowsAnyAsync<Exception>(async () =>
+ await reader.ReadNextRecordBatchAsync(cancellation.Token));
+
+ Assert.True(stream.SawCancelableToken);
+ }
+
+ [Fact]
+ public async Task ReadRecordBatchAsync_Stream_DictionaryFixtureWithoutRee()
+ {
+ using RecordBatch originalBatch = TestData.CreateSampleRecordBatch(
+ length: 100,
+ columnSetCount: 5,
+ excludedTypes: new System.Collections.Generic.HashSet<ArrowTypeId> { ArrowTypeId.RunEndEncoded });
+
+ using var stream = new MemoryStream();
+ using (ArrowStreamWriter writer = new ArrowStreamWriter(stream, originalBatch.Schema, leaveOpen: true))
+ {
+ await writer.WriteRecordBatchAsync(originalBatch);
+ await writer.WriteEndAsync();
+ }
+
+ stream.Position = 0;
+
+ using var reader = new ArrowStreamReader(stream);
+ await ArrowReaderVerifier.VerifyReaderAsync(reader, originalBatch);
+ }
+
[Theory]
[InlineData(true, true)]
[InlineData(true, false)]
@@ -177,6 +254,84 @@ public async Task ReadRecordBatchAsync_Stream(bool writeEnd, bool createDictiona
await TestReaderFromStream(ArrowReaderVerifier.VerifyReaderAsync, writeEnd, createDictionaryArray);
}
+ [Theory]
+ [InlineData(true)]
+ [InlineData(false)]
+ public async Task ReadRecordBatchAsync_NonPubliclyVisibleMemoryStream(bool createDictionaryArray)
+ {
+ RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: createDictionaryArray);
+
+ byte[] buffer;
+ using (MemoryStream stream = new MemoryStream())
+ {
+ ArrowStreamWriter writer = new ArrowStreamWriter(stream, originalBatch.Schema, leaveOpen: true);
+ await writer.WriteRecordBatchAsync(originalBatch);
+ await writer.WriteEndAsync();
+ buffer = stream.ToArray();
+ }
+
+ using (MemoryStream stream = new MemoryStream(buffer))
+ {
+ ArrowStreamReader reader = new ArrowStreamReader(stream);
+ await ArrowReaderVerifier.VerifyReaderAsync(reader, originalBatch);
+ }
+ }
+
+ [Fact]
+ public async Task ReadRecordBatchAsync_NonPubliclyVisibleMemoryStream_UsesExplicitAllocator()
+ {
+ RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: false);
+
+ byte[] buffer;
+ using (MemoryStream stream = new MemoryStream())
+ {
+ ArrowStreamWriter writer = new ArrowStreamWriter(stream, originalBatch.Schema, leaveOpen: true);
+ await writer.WriteRecordBatchAsync(originalBatch);
+ await writer.WriteEndAsync();
+ buffer = stream.ToArray();
+ }
+
+ var allocator = new TestMemoryAllocator();
+ using (MemoryStream stream = new MemoryStream(buffer))
+ using (var reader = new ArrowStreamReader(stream, allocator))
+ {
+ using (RecordBatch readBatch = await reader.ReadNextRecordBatchAsync())
+ {
+ ArrowReaderVerifier.CompareBatches(originalBatch, readBatch);
+ }
+
+ Assert.True(allocator.Statistics.Allocations > 0);
+ Assert.Equal(0, allocator.Rented);
+ Assert.Null(await reader.ReadNextRecordBatchAsync());
+ }
+ }
+
+ [Fact]
+ public async Task ReadRecordBatchAsync_NonPubliclyVisibleMemoryStream_UsesDefaultAllocator()
+ {
+ RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: false);
+
+ byte[] buffer;
+ using (MemoryStream stream = new MemoryStream())
+ {
+ ArrowStreamWriter writer = new ArrowStreamWriter(stream, originalBatch.Schema, leaveOpen: true);
+ await writer.WriteRecordBatchAsync(originalBatch);
+ await writer.WriteEndAsync();
+ buffer = stream.ToArray();
+ }
+
+ long allocationsBeforeRead = MemoryAllocator.Default.Value.Statistics.Allocations;
+
+ using (MemoryStream stream = new MemoryStream(buffer))
+ using (var reader = new ArrowStreamReader(stream, MemoryAllocator.Default.Value))
+ using (RecordBatch readBatch = await reader.ReadNextRecordBatchAsync())
+ {
+ ArrowReaderVerifier.CompareBatches(originalBatch, readBatch);
+ }
+
+ Assert.True(MemoryAllocator.Default.Value.Statistics.Allocations > allocationsBeforeRead);
+ }
+
private static async Task TestReaderFromStream(
Func<ArrowStreamReader, RecordBatch, Task> verificationFunc,
bool writeEnd, bool createDictionaryArray)
@@ -199,6 +354,92 @@ private static async Task TestReaderFromStream(
}
}
+ [Fact]
+ public async Task ReadRecordBatch_ExposedMemoryStream_BatchRemainsUsableAfterDispose()
+ {
+ RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: false);
+ RecordBatch readBatch;
+
+ using (MemoryStream stream = new MemoryStream())
+ {
+ ArrowStreamWriter writer = new ArrowStreamWriter(stream, originalBatch.Schema, leaveOpen: true);
+ await writer.WriteRecordBatchAsync(originalBatch);
+ await writer.WriteEndAsync();
+
+ stream.Position = 0;
+
+ using (ArrowStreamReader reader = new ArrowStreamReader(stream, leaveOpen: true))
+ {
+ readBatch = reader.ReadNextRecordBatch();
+ }
+ }
+
+ using (readBatch)
+ {
+ ArrowReaderVerifier.CompareBatches(originalBatch, readBatch);
+ }
+ }
+
+ [Fact]
+ public async Task ReadRecordBatch_ExposedMemoryStream_BatchDoesNotAliasMutableStreamBuffer()
+ {
+ RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: false);
+ RecordBatch readBatch;
+ byte[] streamBuffer;
+
+ using (MemoryStream stream = new MemoryStream())
+ {
+ ArrowStreamWriter writer = new ArrowStreamWriter(stream, originalBatch.Schema, leaveOpen: true);
+ await writer.WriteRecordBatchAsync(originalBatch);
+ await writer.WriteEndAsync();
+
+ streamBuffer = stream.GetBuffer();
+ stream.Position = 0;
+
+ using (ArrowStreamReader reader = new ArrowStreamReader(stream, leaveOpen: true))
+ {
+ readBatch = reader.ReadNextRecordBatch();
+ }
+ }
+
+ System.Array.Clear(streamBuffer, 0, streamBuffer.Length);
+
+ using (readBatch)
+ {
+ ArrowReaderVerifier.CompareBatches(originalBatch, readBatch);
+ }
+ }
+
+ [Fact]
+ public async Task ReadRecordBatchAsync_ExposedMemoryStream_BatchDoesNotAliasMutableStreamBuffer()
+ {
+ RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: false);
+ RecordBatch readBatch;
+ byte[] streamBuffer;
+
+ using (MemoryStream stream = new MemoryStream())
+ {
+ ArrowStreamWriter writer = new ArrowStreamWriter(stream, originalBatch.Schema, leaveOpen: true);
+ await writer.WriteRecordBatchAsync(originalBatch);
+ await writer.WriteEndAsync();
+
+ streamBuffer = stream.GetBuffer();
+ stream.Position = 0;
+
+ using (ArrowStreamReader reader = new ArrowStreamReader(stream, leaveOpen: true))
+ {
+ readBatch = await reader.ReadNextRecordBatchAsync();
+ }
+ }
+
+ System.Array.Clear(streamBuffer, 0, streamBuffer.Length);
+
+ using (readBatch)
+ {
+ ArrowReaderVerifier.CompareBatches(originalBatch, readBatch);
+ }
+ }
+
[Theory]
[InlineData(true)]
[InlineData(false)]
@@ -243,11 +484,34 @@ private static async Task TestReaderFromPartialReadStream(Func
/// <summary>
/// A stream class that only returns a part of the data at a time.
/// </summary>
+ private class PartialReadStream : Stream
{
+ private readonly MemoryStream _innerStream = new MemoryStream();
+
// by default return 20 bytes at a time
public int PartialReadLength { get; set; } = 20;
+ public override bool CanRead => _innerStream.CanRead;
+ public override bool CanSeek => _innerStream.CanSeek;
+ public override bool CanWrite => _innerStream.CanWrite;
+ public override long Length => _innerStream.Length;
+ public override long Position { get => _innerStream.Position; set => _innerStream.Position = value; }
+
+ public override void Flush() => _innerStream.Flush();
+ public override long Seek(long offset, SeekOrigin origin) => _innerStream.Seek(offset, origin);
+ public override void SetLength(long value) => _innerStream.SetLength(value);
+ public override void Write(byte[] buffer, int offset, int count) => _innerStream.Write(buffer, offset, count);
+
+ public override int Read(byte[] buffer, int offset, int count)
+ {
+ return _innerStream.Read(buffer, offset, Math.Min(count, PartialReadLength));
+ }
+
+ public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken = default)
+ {
+ return _innerStream.ReadAsync(buffer, offset, Math.Min(count, PartialReadLength), cancellationToken);
+ }
+
#if NET5_0_OR_GREATER
public override int Read(Span<byte> destination)
{
@@ -256,7 +520,7 @@ public override int Read(Span destination)
destination = destination.Slice(0, PartialReadLength);
}
- return base.Read(destination);
+ return _innerStream.Read(destination);
}
public override ValueTask<int> ReadAsync(Memory<byte> destination, CancellationToken cancellationToken = default)
@@ -266,53 +530,50 @@ public override ValueTask ReadAsync(Memory destination, CancellationT
destination = destination.Slice(0, PartialReadLength);
}
- return base.ReadAsync(destination, cancellationToken);
- }
-#else
- public override int Read(byte[] buffer, int offset, int length)
- {
- return base.Read(buffer, offset, Math.Min(length, PartialReadLength));
- }
-
- public override Task<int> ReadAsync(byte[] buffer, int offset, int length, CancellationToken cancellationToken = default)
- {
- return base.ReadAsync(buffer, offset, Math.Min(length, PartialReadLength), cancellationToken);
+ return _innerStream.ReadAsync(destination, cancellationToken);
}
#endif
}
- [Fact]
- public unsafe void MalformedColumnNameLength()
+ private class RequiresCancelableReadStream : Stream
{
- const int FieldNameLengthOffset = 108;
- const int FakeFieldNameLength = 165535;
+ public bool SawCancelableToken { get; private set; }
- byte[] buffer;
- using (var stream = new MemoryStream())
+ public override bool CanRead => true;
+ public override bool CanSeek => false;
+ public override bool CanWrite => false;
+ public override long Length => throw new NotSupportedException();
+ public override long Position { get => throw new NotSupportedException(); set => throw new NotSupportedException(); }
+
+ public override void Flush() { }
+ public override int Read(byte[] buffer, int offset, int count) => throw new NotSupportedException();
+ public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException();
+ public override void SetLength(long value) => throw new NotSupportedException();
+ public override void Write(byte[] buffer, int offset, int count) => throw new NotSupportedException();
+
+#if NET5_0_OR_GREATER
+ public override ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
{
- Schema schema = new(
- [new Field("index", Int32Type.Default, nullable: false)],
- metadata: []);
- using (var writer = new ArrowStreamWriter(stream, schema, leaveOpen: true))
+ SawCancelableToken = cancellationToken.CanBeCanceled;
+ if (!SawCancelableToken)
{
- writer.WriteStart();
- writer.WriteEnd();
+ throw new InvalidOperationException("Expected the caller's cancellation token during schema read.");
}
- buffer = stream.ToArray();
- }
- Span<int> length = buffer.AsSpan().Slice(FieldNameLengthOffset, sizeof(int)).CastTo<int>();
- Assert.Equal(5, length[0]);
- length[0] = FakeFieldNameLength;
-
- Assert.Throws<ArgumentOutOfRangeException>(() =>
+ throw new OperationCanceledException(cancellationToken);
+ }
+#else
+ public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
- using (var stream = new MemoryStream(buffer))
- using (var reader = new ArrowStreamReader(stream))
+ SawCancelableToken = cancellationToken.CanBeCanceled;
+ if (!SawCancelableToken)
{
- reader.ReadNextRecordBatch();
+ throw new InvalidOperationException("Expected the caller's cancellation token during schema read.");
}
- });
+
+ throw new OperationCanceledException(cancellationToken);
+ }
+#endif
}
[Fact]
@@ -468,34 +729,52 @@ private static int ReadVectorDataStart(byte[] buffer, int tablePos, int vtableSl
private static void WriteInt64LittleEndian(byte[] buffer, int offset, long value)
{
- System.Buffers.Binary.BinaryPrimitives.WriteInt64LittleEndian(
- buffer.AsSpan(offset), value);
+ BinaryPrimitives.WriteInt64LittleEndian(buffer.AsSpan(offset), value);
}
[Fact]
- public async Task EmptyStreamNoSyncRead()
+ public unsafe void MalformedColumnNameLength()
{
- using (var stream = new EmptyAsyncOnlyStream())
+ const int FieldNameLengthOffset = 108;
+ const int FakeFieldNameLength = 165535;
+
+ byte[] buffer;
+ using (var stream = new MemoryStream())
{
- var reader = new ArrowStreamReader(stream);
- var schema = await reader.GetSchema();
- Assert.Null(schema);
+ Schema schema = new(
+ [new Field("index", Int32Type.Default, nullable: false)],
+ metadata: []);
+ using (var writer = new ArrowStreamWriter(stream, schema, leaveOpen: true))
+ {
+ writer.WriteStart();
+ writer.WriteEnd();
+ }
+ buffer = stream.ToArray();
}
- }
- private static short ToInt16LittleEndian(byte[] buffer, int offset)
- {
- return BinaryPrimitives.ReadInt16LittleEndian(buffer.AsSpan().Slice(offset));
- }
+ Span<int> length = buffer.AsSpan().Slice(FieldNameLengthOffset, sizeof(int)).CastTo<int>();
+ Assert.Equal(5, length[0]);
+ length[0] = FakeFieldNameLength;
- private static int ToInt32LittleEndian(byte[] buffer, int offset)
- {
- return BinaryPrimitives.ReadInt32LittleEndian(buffer.AsSpan().Slice(offset));
+ Assert.Throws<ArgumentOutOfRangeException>(() =>
+ {
+ using (var stream = new MemoryStream(buffer))
+ using (var reader = new ArrowStreamReader(stream))
+ {
+ reader.ReadNextRecordBatch();
+ }
+ });
}
- private static long ToInt64LittleEndian(byte[] buffer, int offset)
+ [Fact]
+ public async Task EmptyStreamNoSyncRead()
{
- return BinaryPrimitives.ReadInt64LittleEndian(buffer.AsSpan().Slice(offset));
+ using (var stream = new EmptyAsyncOnlyStream())
+ {
+ var reader = new ArrowStreamReader(stream);
+ var schema = await reader.GetSchema();
+ Assert.Null(schema);
+ }
}
private class EmptyAsyncOnlyStream : Stream
@@ -512,5 +791,20 @@ public override void Flush() { }
public override void Write(byte[] buffer, int offset, int count) => throw new NotSupportedException();
public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) => Task.FromResult(0);
}
+
+ private static short ToInt16LittleEndian(byte[] buffer, int offset)
+ {
+ return BinaryPrimitives.ReadInt16LittleEndian(buffer.AsSpan().Slice(offset));
+ }
+
+ private static int ToInt32LittleEndian(byte[] buffer, int offset)
+ {
+ return BinaryPrimitives.ReadInt32LittleEndian(buffer.AsSpan().Slice(offset));
+ }
+
+ private static long ToInt64LittleEndian(byte[] buffer, int offset)
+ {
+ return BinaryPrimitives.ReadInt64LittleEndian(buffer.AsSpan().Slice(offset));
+ }
}
}