diff --git a/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs b/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs index e9cd9361..887036c2 100644 --- a/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs +++ b/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs @@ -87,7 +87,7 @@ public override RecordBatch ReadNextRecordBatch() else if (messageLength == MessageSerializer.IpcContinuationToken) { // ARROW-6313, if the first 4 bytes are continuation message, read the next 4 for the length - if (_buffer.Length <= _bufferPosition + sizeof(int)) + if (_buffer.Length < _bufferPosition + sizeof(int)) { throw new InvalidDataException("Corrupted IPC message. Received a continuation token at the end of the message."); } @@ -136,7 +136,7 @@ public override void ReadSchema() if (schemaMessageLength == MessageSerializer.IpcContinuationToken) { // ARROW-6313, if the first 4 bytes are continuation message, read the next 4 for the length - if (_buffer.Length <= _bufferPosition + sizeof(int)) + if (_buffer.Length < _bufferPosition + sizeof(int)) { throw new InvalidDataException("Corrupted IPC message. Received a continuation token at the end of the message."); } diff --git a/src/Apache.Arrow/Ipc/ArrowMemoryStreamReaderImplementation.cs b/src/Apache.Arrow/Ipc/ArrowMemoryStreamReaderImplementation.cs new file mode 100644 index 00000000..01d3b0b5 --- /dev/null +++ b/src/Apache.Arrow/Ipc/ArrowMemoryStreamReaderImplementation.cs @@ -0,0 +1,214 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Buffers; +using System.IO; +using System.Threading; +using System.Threading.Tasks; +using Apache.Arrow.Memory; + +namespace Apache.Arrow.Ipc +{ + /// + /// Reads Arrow IPC streams from a whose backing buffer is publicly visible. + /// + /// + /// Message metadata can be read directly from the exposed stream buffer, but record batch bodies are + /// still copied into allocator-owned buffers to preserve ownership semantics. + /// + internal sealed class ArrowMemoryStreamReaderImplementation : ArrowStreamReaderImplementation + { + private readonly MemoryStream _stream; + private readonly Memory _streamMemory; + + public ArrowMemoryStreamReaderImplementation( + MemoryStream stream, + MemoryAllocator allocator, + ICompressionCodecFactory compressionCodecFactory, + bool leaveOpen, + ExtensionTypeRegistry extensionRegistry) + : base(stream, allocator, compressionCodecFactory, leaveOpen, extensionRegistry) + { + _stream = stream; + + if (!stream.TryGetBuffer(out ArraySegment streamBuffer)) + { + throw new InvalidOperationException("Expected MemoryStream to expose its backing buffer."); + } + + _streamMemory = streamBuffer.Array.AsMemory(streamBuffer.Offset, streamBuffer.Count); + } + + public override ValueTask ReadNextRecordBatchAsync(CancellationToken cancellationToken) + { + cancellationToken.ThrowIfCancellationRequested(); + + try + { + return new ValueTask(ReadNextRecordBatch()); + } + catch (Exception ex) + { + return new ValueTask(Task.FromException(ex)); + } + } + + public override RecordBatch ReadNextRecordBatch() + { + 
ReadSchema(); + + ReadResult result = default; + do + { + result = ReadMessageFromMemory(); + } while (result.Batch == null && result.MessageLength > 0); + + return result.Batch; + } + + public override ValueTask ReadSchemaAsync(CancellationToken cancellationToken = default) + { + cancellationToken.ThrowIfCancellationRequested(); + + if (HasReadSchema) + { + return new ValueTask(_schema); + } + + try + { + ReadSchema(); + return new ValueTask(_schema); + } + catch (Exception ex) + { + return new ValueTask(Task.FromException(ex)); + } + } + + public override void ReadSchema() + { + if (HasReadSchema) + { + return; + } + + int schemaMessageLength = ReadMessageLengthFromMemory(throwOnFullRead: true, returnOnEmptyStream: true); + if (schemaMessageLength == 0) + { + return; + } + + Memory schemaBuffer = ReadMemory(schemaMessageLength); + _schema = MessageSerializer.GetSchema(ReadMessage(CreateByteBuffer(schemaBuffer)), ref _dictionaryMemo, _extensionRegistry); + } + + private ReadResult ReadMessageFromMemory() + { + int messageLength = ReadMessageLengthFromMemory(throwOnFullRead: false, returnOnEmptyStream: false); + if (messageLength == 0) + { + return default; + } + + Memory messageBuffer = ReadMemory(messageLength); + Flatbuf.Message message = Flatbuf.Message.GetRootAsMessage(CreateByteBuffer(messageBuffer)); + + if (message.BodyLength > int.MaxValue) + { + throw new OverflowException( + $"Arrow IPC message body length ({message.BodyLength}) is larger than " + + $"the maximum supported message size ({int.MaxValue})"); + } + + int bodyLength = (int)message.BodyLength; + Memory sourceBodyBuffer = ReadMemory(bodyLength); + IMemoryOwner bodyBufferOwner = AllocateMessageBodyBuffer(bodyLength); + Memory bodyBuffer = bodyBufferOwner.Memory.Slice(0, bodyLength); + sourceBodyBuffer.CopyTo(bodyBuffer); + Google.FlatBuffers.ByteBuffer bodybb = CreateByteBuffer(bodyBuffer); + + // Keep stream-reader ownership semantics: batches outlive the source MemoryStream buffer. 
+            return new ReadResult(messageLength, CreateArrowObjectFromMessage(message, bodybb, bodyBufferOwner));
+        }
+
+        // Reads the next message length prefix, skipping an IPC continuation token; 0 is returned when no prefix could be read.
+        private int ReadMessageLengthFromMemory(bool throwOnFullRead, bool returnOnEmptyStream)
+        {
+            if (_stream.Position == _stream.Length && returnOnEmptyStream)
+            {
+                return 0;
+            }
+
+            if (!TryReadInt32(throwOnFullRead, out int messageLength))
+            {
+                return 0;
+            }
+
+            if (messageLength == MessageSerializer.IpcContinuationToken &&
+                !TryReadInt32(throwOnFullRead, out messageLength))
+            {
+                return 0;
+            }
+
+            return messageLength;
+        }
+
+        // Reads a 4-byte integer; returns false (value 0) when a tolerated truncation leaves too few bytes.
+        private bool TryReadInt32(bool throwOnFullRead, out int value)
+        {
+            bool read = TryReadMemory(sizeof(int), throwOnFullRead, out Memory<byte> buffer);
+            value = read ? BitUtility.ReadInt32(buffer) : 0;
+            return read;
+        }
+
+        // Takes the next bytes; a short stream throws when throwOnFullRead is set, else it is drained and false is returned.
+        private bool TryReadMemory(int length, bool throwOnFullRead, out Memory<byte> buffer)
+        {
+            if (!throwOnFullRead && _stream.Length - _stream.Position < length)
+            {
+                // Tolerated truncation: consume the remainder so later reads observe end-of-stream.
+                _stream.Position = _stream.Length;
+                buffer = default;
+                return false;
+            }
+
+            // ReadMemory raises the end-of-stream error itself when the data is truncated.
+            buffer = ReadMemory(length);
+            return true;
+        }
+
+        // Slices the next bytes out of the stream's backing buffer and advances the position; throws if too few remain.
+        private Memory<byte> ReadMemory(int length)
+        {
+            if (length == 0)
+            {
+                return Memory<byte>.Empty;
+            }
+
+            if (_stream.Length - _stream.Position < length)
+            {
+                throw new InvalidOperationException("Unexpectedly reached the end of the stream before a full buffer was read.");
+            }
+
+            // Re-query the buffer: writes after construction can replace the array captured in _streamMemory.
+            Memory<byte> source = _stream.TryGetBuffer(out ArraySegment<byte> segment) ? segment.Array.AsMemory(segment.Offset, segment.Count) : _streamMemory;
+            Memory<byte> buffer = source.Slice(checked((int)_stream.Position), length);
+            _stream.Position += length;
+            return buffer;
+        }
+    }
+}
diff --git a/src/Apache.Arrow/Ipc/ArrowStreamReader.cs b/src/Apache.Arrow/Ipc/ArrowStreamReader.cs
index e5dade2b..6100eea3 100644
--- a/src/Apache.Arrow/Ipc/ArrowStreamReader.cs
+++ b/src/Apache.Arrow/Ipc/ArrowStreamReader.cs
@@ -68,7 +68,7 @@ public ArrowStreamReader(Stream stream, MemoryAllocator allocator, ICompressionC
             if (stream == null)
                 throw new ArgumentNullException(nameof(stream));
 
-            _implementation = new ArrowStreamReaderImplementation(stream, allocator, compressionCodecFactory, leaveOpen);
+
_implementation = CreateImplementation(stream, allocator, compressionCodecFactory, leaveOpen, extensionRegistry: null);
         }
 
         public ArrowStreamReader(ArrowContext context, Stream stream, bool leaveOpen = false)
@@ -78,7 +78,7 @@ public ArrowStreamReader(ArrowContext context, Stream stream, bool leaveOpen = f
             if (context == null)
                 throw new ArgumentNullException(nameof(context));
 
-            _implementation = new ArrowStreamReaderImplementation(stream, context.Allocator, context.CompressionCodecFactory, leaveOpen, context.ExtensionRegistry);
+            _implementation = CreateImplementation(stream, context.Allocator, context.CompressionCodecFactory, leaveOpen, context.ExtensionRegistry);
         }
 
         public ArrowStreamReader(ReadOnlyMemory<byte> buffer)
@@ -104,6 +104,21 @@ private protected ArrowStreamReader(ArrowReaderImplementation implementation)
             _implementation = implementation;
         }
 
+        // Uses the buffer-backed reader only for an exact MemoryStream whose buffer is exposed. Derived streams may
+        // override Read/ReadAsync, so they stay on the generic stream path that honors those overrides.
+        private static ArrowReaderImplementation CreateImplementation(
+            Stream stream,
+            MemoryAllocator allocator,
+            ICompressionCodecFactory compressionCodecFactory,
+            bool leaveOpen,
+            ExtensionTypeRegistry extensionRegistry)
+        {
+            bool useStreamBuffer = stream.GetType() == typeof(MemoryStream) && ((MemoryStream)stream).TryGetBuffer(out _);
+            return useStreamBuffer
+                ? new ArrowMemoryStreamReaderImplementation((MemoryStream)stream, allocator, compressionCodecFactory, leaveOpen, extensionRegistry)
+                : new ArrowStreamReaderImplementation(stream, allocator, compressionCodecFactory, leaveOpen, extensionRegistry);
+        }
+
         public void Dispose()
         {
             Dispose(true);
diff --git a/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs b/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs
index 9d0cbe33..179b9984 100644
--- a/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs
+++ b/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs
@@ -47,11 +47,10 @@ protected override void Dispose(bool disposing)
             }
         }
 
-        public override async ValueTask<RecordBatch> ReadNextRecordBatchAsync(CancellationToken cancellationToken)
+        public override ValueTask<RecordBatch> ReadNextRecordBatchAsync(CancellationToken
cancellationToken) { - // TODO: Loop until a record batch is read. cancellationToken.ThrowIfCancellationRequested(); - return await ReadRecordBatchAsync(cancellationToken).ConfigureAwait(false); + return ReadRecordBatchAsync(cancellationToken); } public override RecordBatch ReadNextRecordBatch() @@ -61,7 +60,7 @@ public override RecordBatch ReadNextRecordBatch() protected async ValueTask ReadRecordBatchAsync(CancellationToken cancellationToken = default) { - await ReadSchemaAsync().ConfigureAwait(false); + await ReadSchemaAsync(cancellationToken).ConfigureAwait(false); ReadResult result = default; do @@ -94,7 +93,7 @@ protected async ValueTask ReadMessageAsync(CancellationToken cancell int bodyLength = checked((int)message.BodyLength); - IMemoryOwner bodyBuffOwner = _allocator.Allocate(bodyLength); + IMemoryOwner bodyBuffOwner = AllocateMessageBodyBuffer(bodyLength); Memory bodyBuff = bodyBuffOwner.Memory.Slice(0, bodyLength); bytesRead = await BaseStream.ReadFullBufferAsync(bodyBuff, cancellationToken) .ConfigureAwait(false); @@ -145,7 +144,7 @@ protected ReadResult ReadMessage() } int bodyLength = (int)message.BodyLength; - IMemoryOwner bodyBuffOwner = _allocator.Allocate(bodyLength); + IMemoryOwner bodyBuffOwner = AllocateMessageBodyBuffer(bodyLength); Memory bodyBuff = bodyBuffOwner.Memory.Slice(0, bodyLength); bytesRead = BaseStream.ReadFullBuffer(bodyBuff); EnsureFullRead(bodyBuff, bytesRead); @@ -157,13 +156,25 @@ protected ReadResult ReadMessage() return new ReadResult(messageLength, result); } - public override async ValueTask ReadSchemaAsync(CancellationToken cancellationToken = default) + protected IMemoryOwner AllocateMessageBodyBuffer(int bodyLength) { + return _allocator.Allocate(bodyLength); + } + + public override ValueTask ReadSchemaAsync(CancellationToken cancellationToken = default) + { + cancellationToken.ThrowIfCancellationRequested(); + if (HasReadSchema) { - return _schema; + return new ValueTask(_schema); } + return 
ReadSchemaAsyncCore(cancellationToken); + } + + private async ValueTask ReadSchemaAsyncCore(CancellationToken cancellationToken) + { // Figure out length of schema int schemaMessageLength = await ReadMessageLengthAsync(throwOnFullRead: true, returnOnEmptyStream: true, cancellationToken) .ConfigureAwait(false); diff --git a/test/Apache.Arrow.Benchmarks/ArrowReaderBenchmark.cs b/test/Apache.Arrow.Benchmarks/ArrowReaderBenchmark.cs index 3760d940..9305adb1 100644 --- a/test/Apache.Arrow.Benchmarks/ArrowReaderBenchmark.cs +++ b/test/Apache.Arrow.Benchmarks/ArrowReaderBenchmark.cs @@ -14,8 +14,8 @@ // limitations under the License. using System; +using System.Collections.Generic; using System.IO; -using System.Linq; using System.Threading.Tasks; using Apache.Arrow.Ipc; using Apache.Arrow.Memory; @@ -32,13 +32,19 @@ public class ArrowReaderBenchmark [Params(10_000, 1_000_000)] public int Count { get; set; } + [Params(1, 5)] + public int ColumnSetCount { get; set; } + private MemoryStream _memoryStream; private static readonly MemoryAllocator s_allocator = new TestMemoryAllocator(); [GlobalSetup] public async Task GlobalSetup() { - RecordBatch batch = TestData.CreateSampleRecordBatch(length: Count, createDictionaryArray: false); + RecordBatch batch = TestData.CreateSampleRecordBatch( + length: Count, + columnSetCount: ColumnSetCount, + excludedTypes: new HashSet { ArrowTypeId.Dictionary, ArrowTypeId.RunEndEncoded }); _memoryStream = new MemoryStream(); ArrowStreamWriter writer = new ArrowStreamWriter(_memoryStream, batch.Schema); @@ -83,6 +89,73 @@ public async Task ArrowReaderWithMemoryStream_ManagedMemory() return sum; } + [Benchmark] + public async Task ArrowReaderWithMemoryStream_ExplicitDefaultAllocator() + { + double sum = 0; + var reader = new ArrowStreamReader(_memoryStream, MemoryAllocator.Default.Value); + RecordBatch recordBatch; + while ((recordBatch = await reader.ReadNextRecordBatchAsync()) != null) + { + using (recordBatch) + { + sum += 
SumAllNumbers(recordBatch); + } + } + return sum; + } + + [Benchmark] + public async Task ArrowReaderWithNonPubliclyVisibleMemoryStream() + { + double sum = 0; + using var stream = CreateNonPubliclyVisibleReadStream(); + using var reader = new ArrowStreamReader(stream); + RecordBatch recordBatch; + while ((recordBatch = await reader.ReadNextRecordBatchAsync()) != null) + { + using (recordBatch) + { + sum += SumAllNumbers(recordBatch); + } + } + return sum; + } + + [Benchmark] + public async Task ArrowReaderWithNonPubliclyVisibleMemoryStream_ManagedMemory() + { + double sum = 0; + using var stream = CreateNonPubliclyVisibleReadStream(); + using var reader = new ArrowStreamReader(stream, s_allocator); + RecordBatch recordBatch; + while ((recordBatch = await reader.ReadNextRecordBatchAsync()) != null) + { + using (recordBatch) + { + sum += SumAllNumbers(recordBatch); + } + } + return sum; + } + + [Benchmark] + public async Task ArrowReaderWithNonPubliclyVisibleMemoryStream_ExplicitDefaultAllocator() + { + double sum = 0; + using var stream = CreateNonPubliclyVisibleReadStream(); + using var reader = new ArrowStreamReader(stream, MemoryAllocator.Default.Value); + RecordBatch recordBatch; + while ((recordBatch = await reader.ReadNextRecordBatchAsync()) != null) + { + using (recordBatch) + { + sum += SumAllNumbers(recordBatch); + } + } + return sum; + } + [Benchmark] public async Task ArrowReaderWithMemory() { @@ -99,14 +172,25 @@ public async Task ArrowReaderWithMemory() return sum; } + private MemoryStream CreateNonPubliclyVisibleReadStream() + { + return new MemoryStream( + _memoryStream.GetBuffer(), + index: 0, + count: checked((int)_memoryStream.Length), + writable: false, + publiclyVisible: false); + } + private static double SumAllNumbers(RecordBatch recordBatch) { double sum = 0; for (int k = 0; k < recordBatch.ColumnCount; k++) { - var array = recordBatch.Arrays.ElementAt(k); - switch (recordBatch.Schema.GetFieldByIndex(k).DataType.TypeId) + var array = 
recordBatch.Column(k); + ArrowTypeId typeId = recordBatch.Schema.GetFieldByIndex(k).DataType.TypeId; + switch (typeId) { case ArrowTypeId.Int64: Int64Array int64Array = (Int64Array)array; diff --git a/test/Apache.Arrow.Compression.Tests/ArrowStreamReaderTests.cs b/test/Apache.Arrow.Compression.Tests/ArrowStreamReaderTests.cs index 9c2bf75d..99e74424 100644 --- a/test/Apache.Arrow.Compression.Tests/ArrowStreamReaderTests.cs +++ b/test/Apache.Arrow.Compression.Tests/ArrowStreamReaderTests.cs @@ -14,8 +14,10 @@ // limitations under the License. using System; +using System.IO; using System.Reflection; using Apache.Arrow.Ipc; +using Apache.Arrow.Memory; using Apache.Arrow.Tests; using Xunit; @@ -47,13 +49,34 @@ public void CanReadCompressedIpcStreamFromMemoryBuffer(string fileName) using var stream = assembly.GetManifestResourceStream($"Apache.Arrow.Compression.Tests.Resources.{fileName}"); Assert.NotNull(stream); var buffer = new byte[stream.Length]; - stream.ReadFullBuffer(buffer); + ReadExactly(stream, buffer); var codecFactory = new Compression.CompressionCodecFactory(); using var reader = new ArrowStreamReader(buffer, codecFactory); VerifyCompressedIpcFileBatch(reader.ReadNextRecordBatch()); } + [Theory] + [InlineData("ipc_lz4_compression.arrow_stream")] + [InlineData("ipc_zstd_compression.arrow_stream")] + public void CanReadCompressedIpcStreamFromMemoryBuffer_UsesDefaultAllocator(string fileName) + { + var assembly = Assembly.GetExecutingAssembly(); + using var stream = assembly.GetManifestResourceStream($"Apache.Arrow.Compression.Tests.Resources.{fileName}"); + Assert.NotNull(stream); + var buffer = new byte[stream.Length]; + ReadExactly(stream, buffer); + var codecFactory = new Compression.CompressionCodecFactory(); + + long allocationsBeforeRead = MemoryAllocator.Default.Value.Statistics.Allocations; + + using var reader = new ArrowStreamReader(buffer, codecFactory); + using RecordBatch batch = reader.ReadNextRecordBatch(); + 
VerifyCompressedIpcFileBatch(batch); + + Assert.True(MemoryAllocator.Default.Value.Statistics.Allocations > allocationsBeforeRead); + } + [Fact] public void ErrorReadingCompressedStreamWithoutCodecFactory() { @@ -86,6 +109,21 @@ public void MemoryPoolDisposedOnReadCompressedIpcStream(string fileName) } + private static void ReadExactly(Stream stream, byte[] buffer) + { + int offset = 0; + while (offset < buffer.Length) + { + int bytesRead = stream.Read(buffer, offset, buffer.Length - offset); + if (bytesRead == 0) + { + throw new EndOfStreamException(); + } + + offset += bytesRead; + } + } + private static void VerifyCompressedIpcFileBatch(RecordBatch batch) { var intArray = (Int32Array)batch.Column("integers"); @@ -103,4 +141,3 @@ private static void VerifyCompressedIpcFileBatch(RecordBatch batch) } } } - diff --git a/test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs b/test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs index 5e7c57e1..d04e0cd9 100644 --- a/test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs +++ b/test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs @@ -16,9 +16,11 @@ using System; using System.Buffers.Binary; using System.IO; +using System.Reflection; using System.Threading; using System.Threading.Tasks; using Apache.Arrow.Ipc; +using Apache.Arrow.Memory; using Apache.Arrow.Types; using Xunit; @@ -69,13 +71,20 @@ public async Task Ctor_MemoryPool_AllocatesFromPool(bool shouldLeaveOpen, bool c var memoryPool = new TestMemoryAllocator(); ArrowStreamReader reader = new ArrowStreamReader(stream, memoryPool, shouldLeaveOpen); - reader.ReadNextRecordBatch(); - - Assert.Equal(expectedAllocations, memoryPool.Statistics.Allocations); - Assert.True(memoryPool.Statistics.BytesAllocated > 0); + using (RecordBatch readBatch = reader.ReadNextRecordBatch()) + { + Assert.Equal(expectedAllocations, memoryPool.Statistics.Allocations); + Assert.True(memoryPool.Statistics.BytesAllocated > 0); + Assert.Equal(expectedAllocations, memoryPool.Rented); + } reader.Dispose(); + 
if (!createDictionaryArray) + { + Assert.Equal(0, memoryPool.Rented); + } + if (shouldLeaveOpen) { Assert.True(stream.Position > 0); @@ -109,6 +118,22 @@ public async Task ReadRecordBatchAsync_Memory(bool writeEnd) await TestReaderFromMemory(ArrowReaderVerifier.VerifyReaderAsync, writeEnd); } + [Fact] + public async Task ReadRecordBatch_Memory_ExactLengthSlice() + { + await TestReaderFromMemoryExactLength((reader, originalBatch) => + { + ArrowReaderVerifier.VerifyReader(reader, originalBatch); + return Task.CompletedTask; + }); + } + + [Fact] + public async Task ReadRecordBatchAsync_Memory_ExactLengthSlice() + { + await TestReaderFromMemoryExactLength(ArrowReaderVerifier.VerifyReaderAsync); + } + private static async Task TestReaderFromMemory( Func verificationFunc, bool writeEnd) @@ -131,6 +156,24 @@ private static async Task TestReaderFromMemory( await verificationFunc(reader, originalBatch); } + private static async Task TestReaderFromMemoryExactLength( + Func verificationFunc) + { + RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100); + + ReadOnlyMemory buffer; + using (MemoryStream stream = new MemoryStream()) + { + ArrowStreamWriter writer = new ArrowStreamWriter(stream, originalBatch.Schema); + await writer.WriteRecordBatchAsync(originalBatch); + await writer.WriteEndAsync(); + buffer = stream.GetBuffer().AsMemory(0, checked((int)stream.Length)); + } + + ArrowStreamReader reader = new ArrowStreamReader(buffer); + await verificationFunc(reader, originalBatch); + } + [Fact] public void ReadRecordBatch_EmptyStream() { @@ -167,6 +210,40 @@ public async Task ReadRecordBatchAsync_EmptyStream() } } + [Fact] + public async Task ReadRecordBatchAsync_PassesCancellationTokenToSchemaRead() + { + using var stream = new RequiresCancelableReadStream(); + using var reader = new ArrowStreamReader(stream); + using var cancellation = new CancellationTokenSource(); + + await Assert.ThrowsAnyAsync(async () => + await 
reader.ReadNextRecordBatchAsync(cancellation.Token)); + + Assert.True(stream.SawCancelableToken); + } + + [Fact] + public async Task ReadRecordBatchAsync_Stream_DictionaryFixtureWithoutRee() + { + using RecordBatch originalBatch = TestData.CreateSampleRecordBatch( + length: 100, + columnSetCount: 5, + excludedTypes: new System.Collections.Generic.HashSet { ArrowTypeId.RunEndEncoded }); + + using var stream = new MemoryStream(); + using (ArrowStreamWriter writer = new ArrowStreamWriter(stream, originalBatch.Schema, leaveOpen: true)) + { + await writer.WriteRecordBatchAsync(originalBatch); + await writer.WriteEndAsync(); + } + + stream.Position = 0; + + using var reader = new ArrowStreamReader(stream); + await ArrowReaderVerifier.VerifyReaderAsync(reader, originalBatch); + } + [Theory] [InlineData(true, true)] [InlineData(true, false)] @@ -177,6 +254,84 @@ public async Task ReadRecordBatchAsync_Stream(bool writeEnd, bool createDictiona await TestReaderFromStream(ArrowReaderVerifier.VerifyReaderAsync, writeEnd, createDictionaryArray); } + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task ReadRecordBatchAsync_NonPubliclyVisibleMemoryStream(bool createDictionaryArray) + { + RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: createDictionaryArray); + + byte[] buffer; + using (MemoryStream stream = new MemoryStream()) + { + ArrowStreamWriter writer = new ArrowStreamWriter(stream, originalBatch.Schema, leaveOpen: true); + await writer.WriteRecordBatchAsync(originalBatch); + await writer.WriteEndAsync(); + buffer = stream.ToArray(); + } + + using (MemoryStream stream = new MemoryStream(buffer)) + { + ArrowStreamReader reader = new ArrowStreamReader(stream); + await ArrowReaderVerifier.VerifyReaderAsync(reader, originalBatch); + } + } + + [Fact] + public async Task ReadRecordBatchAsync_NonPubliclyVisibleMemoryStream_UsesExplicitAllocator() + { + RecordBatch originalBatch = 
TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: false); + + byte[] buffer; + using (MemoryStream stream = new MemoryStream()) + { + ArrowStreamWriter writer = new ArrowStreamWriter(stream, originalBatch.Schema, leaveOpen: true); + await writer.WriteRecordBatchAsync(originalBatch); + await writer.WriteEndAsync(); + buffer = stream.ToArray(); + } + + var allocator = new TestMemoryAllocator(); + using (MemoryStream stream = new MemoryStream(buffer)) + using (var reader = new ArrowStreamReader(stream, allocator)) + { + using (RecordBatch readBatch = await reader.ReadNextRecordBatchAsync()) + { + ArrowReaderVerifier.CompareBatches(originalBatch, readBatch); + } + + Assert.True(allocator.Statistics.Allocations > 0); + Assert.Equal(0, allocator.Rented); + Assert.Null(await reader.ReadNextRecordBatchAsync()); + } + } + + [Fact] + public async Task ReadRecordBatchAsync_NonPubliclyVisibleMemoryStream_UsesDefaultAllocator() + { + RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: false); + + byte[] buffer; + using (MemoryStream stream = new MemoryStream()) + { + ArrowStreamWriter writer = new ArrowStreamWriter(stream, originalBatch.Schema, leaveOpen: true); + await writer.WriteRecordBatchAsync(originalBatch); + await writer.WriteEndAsync(); + buffer = stream.ToArray(); + } + + long allocationsBeforeRead = MemoryAllocator.Default.Value.Statistics.Allocations; + + using (MemoryStream stream = new MemoryStream(buffer)) + using (var reader = new ArrowStreamReader(stream, MemoryAllocator.Default.Value)) + using (RecordBatch readBatch = await reader.ReadNextRecordBatchAsync()) + { + ArrowReaderVerifier.CompareBatches(originalBatch, readBatch); + } + + Assert.True(MemoryAllocator.Default.Value.Statistics.Allocations > allocationsBeforeRead); + } + private static async Task TestReaderFromStream( Func verificationFunc, bool writeEnd, bool createDictionaryArray) @@ -199,6 +354,92 @@ private static async Task 
TestReaderFromStream( } } + [Fact] + public async Task ReadRecordBatch_ExposedMemoryStream_BatchRemainsUsableAfterDispose() + { + RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: false); + RecordBatch readBatch; + + using (MemoryStream stream = new MemoryStream()) + { + ArrowStreamWriter writer = new ArrowStreamWriter(stream, originalBatch.Schema, leaveOpen: true); + await writer.WriteRecordBatchAsync(originalBatch); + await writer.WriteEndAsync(); + + stream.Position = 0; + + using (ArrowStreamReader reader = new ArrowStreamReader(stream, leaveOpen: true)) + { + readBatch = reader.ReadNextRecordBatch(); + } + } + + using (readBatch) + { + ArrowReaderVerifier.CompareBatches(originalBatch, readBatch); + } + } + + [Fact] + public async Task ReadRecordBatch_ExposedMemoryStream_BatchDoesNotAliasMutableStreamBuffer() + { + RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: false); + RecordBatch readBatch; + byte[] streamBuffer; + + using (MemoryStream stream = new MemoryStream()) + { + ArrowStreamWriter writer = new ArrowStreamWriter(stream, originalBatch.Schema, leaveOpen: true); + await writer.WriteRecordBatchAsync(originalBatch); + await writer.WriteEndAsync(); + + streamBuffer = stream.GetBuffer(); + stream.Position = 0; + + using (ArrowStreamReader reader = new ArrowStreamReader(stream, leaveOpen: true)) + { + readBatch = reader.ReadNextRecordBatch(); + } + } + + System.Array.Clear(streamBuffer, 0, streamBuffer.Length); + + using (readBatch) + { + ArrowReaderVerifier.CompareBatches(originalBatch, readBatch); + } + } + + [Fact] + public async Task ReadRecordBatchAsync_ExposedMemoryStream_BatchDoesNotAliasMutableStreamBuffer() + { + RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: false); + RecordBatch readBatch; + byte[] streamBuffer; + + using (MemoryStream stream = new MemoryStream()) + { + ArrowStreamWriter writer = new 
ArrowStreamWriter(stream, originalBatch.Schema, leaveOpen: true); + await writer.WriteRecordBatchAsync(originalBatch); + await writer.WriteEndAsync(); + + streamBuffer = stream.GetBuffer(); + stream.Position = 0; + + using (ArrowStreamReader reader = new ArrowStreamReader(stream, leaveOpen: true)) + { + readBatch = await reader.ReadNextRecordBatchAsync(); + } + } + + System.Array.Clear(streamBuffer, 0, streamBuffer.Length); + + using (readBatch) + { + ArrowReaderVerifier.CompareBatches(originalBatch, readBatch); + } + } + [Theory] [InlineData(true)] [InlineData(false)] @@ -243,11 +484,34 @@ private static async Task TestReaderFromPartialReadStream(Func /// A stream class that only returns a part of the data at a time. /// - private class PartialReadStream : MemoryStream + private class PartialReadStream : Stream { + private readonly MemoryStream _innerStream = new MemoryStream(); + // by default return 20 bytes at a time public int PartialReadLength { get; set; } = 20; + public override bool CanRead => _innerStream.CanRead; + public override bool CanSeek => _innerStream.CanSeek; + public override bool CanWrite => _innerStream.CanWrite; + public override long Length => _innerStream.Length; + public override long Position { get => _innerStream.Position; set => _innerStream.Position = value; } + + public override void Flush() => _innerStream.Flush(); + public override long Seek(long offset, SeekOrigin origin) => _innerStream.Seek(offset, origin); + public override void SetLength(long value) => _innerStream.SetLength(value); + public override void Write(byte[] buffer, int offset, int count) => _innerStream.Write(buffer, offset, count); + + public override int Read(byte[] buffer, int offset, int count) + { + return _innerStream.Read(buffer, offset, Math.Min(count, PartialReadLength)); + } + + public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken = default) + { + return _innerStream.ReadAsync(buffer, offset, 
Math.Min(count, PartialReadLength), cancellationToken); + } + #if NET5_0_OR_GREATER public override int Read(Span destination) { @@ -256,7 +520,7 @@ public override int Read(Span destination) destination = destination.Slice(0, PartialReadLength); } - return base.Read(destination); + return _innerStream.Read(destination); } public override ValueTask ReadAsync(Memory destination, CancellationToken cancellationToken = default) @@ -266,53 +530,50 @@ public override ValueTask ReadAsync(Memory destination, CancellationT destination = destination.Slice(0, PartialReadLength); } - return base.ReadAsync(destination, cancellationToken); - } -#else - public override int Read(byte[] buffer, int offset, int length) - { - return base.Read(buffer, offset, Math.Min(length, PartialReadLength)); - } - - public override Task ReadAsync(byte[] buffer, int offset, int length, CancellationToken cancellationToken = default) - { - return base.ReadAsync(buffer, offset, Math.Min(length, PartialReadLength), cancellationToken); + return _innerStream.ReadAsync(destination, cancellationToken); } #endif } - [Fact] - public unsafe void MalformedColumnNameLength() + private class RequiresCancelableReadStream : Stream { - const int FieldNameLengthOffset = 108; - const int FakeFieldNameLength = 165535; + public bool SawCancelableToken { get; private set; } - byte[] buffer; - using (var stream = new MemoryStream()) + public override bool CanRead => true; + public override bool CanSeek => false; + public override bool CanWrite => false; + public override long Length => throw new NotSupportedException(); + public override long Position { get => throw new NotSupportedException(); set => throw new NotSupportedException(); } + + public override void Flush() { } + public override int Read(byte[] buffer, int offset, int count) => throw new NotSupportedException(); + public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException(); + public override void SetLength(long value) => 
throw new NotSupportedException(); + public override void Write(byte[] buffer, int offset, int count) => throw new NotSupportedException(); + +#if NET5_0_OR_GREATER + public override ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) { - Schema schema = new( - [new Field("index", Int32Type.Default, nullable: false)], - metadata: []); - using (var writer = new ArrowStreamWriter(stream, schema, leaveOpen: true)) + SawCancelableToken = cancellationToken.CanBeCanceled; + if (!SawCancelableToken) { - writer.WriteStart(); - writer.WriteEnd(); + throw new InvalidOperationException("Expected the caller's cancellation token during schema read."); } - buffer = stream.ToArray(); - } - Span length = buffer.AsSpan().Slice(FieldNameLengthOffset, sizeof(int)).CastTo(); - Assert.Equal(5, length[0]); - length[0] = FakeFieldNameLength; - - Assert.Throws(() => + throw new OperationCanceledException(cancellationToken); + } +#else + public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { - using (var stream = new MemoryStream(buffer)) - using (var reader = new ArrowStreamReader(stream)) + SawCancelableToken = cancellationToken.CanBeCanceled; + if (!SawCancelableToken) { - reader.ReadNextRecordBatch(); + throw new InvalidOperationException("Expected the caller's cancellation token during schema read."); } - }); + + throw new OperationCanceledException(cancellationToken); + } +#endif } [Fact] @@ -468,34 +729,52 @@ private static int ReadVectorDataStart(byte[] buffer, int tablePos, int vtableSl private static void WriteInt64LittleEndian(byte[] buffer, int offset, long value) { - System.Buffers.Binary.BinaryPrimitives.WriteInt64LittleEndian( - buffer.AsSpan(offset), value); + BinaryPrimitives.WriteInt64LittleEndian(buffer.AsSpan(offset), value); } [Fact] - public async Task EmptyStreamNoSyncRead() + public unsafe void MalformedColumnNameLength() { - using (var stream = new EmptyAsyncOnlyStream()) + const int 
FieldNameLengthOffset = 108; + const int FakeFieldNameLength = 165535; + + byte[] buffer; + using (var stream = new MemoryStream()) { - var reader = new ArrowStreamReader(stream); - var schema = await reader.GetSchema(); - Assert.Null(schema); + Schema schema = new( + [new Field("index", Int32Type.Default, nullable: false)], + metadata: []); + using (var writer = new ArrowStreamWriter(stream, schema, leaveOpen: true)) + { + writer.WriteStart(); + writer.WriteEnd(); + } + buffer = stream.ToArray(); } - } - private static short ToInt16LittleEndian(byte[] buffer, int offset) - { - return BinaryPrimitives.ReadInt16LittleEndian(buffer.AsSpan().Slice(offset)); - } + Span length = buffer.AsSpan().Slice(FieldNameLengthOffset, sizeof(int)).CastTo(); + Assert.Equal(5, length[0]); + length[0] = FakeFieldNameLength; - private static int ToInt32LittleEndian(byte[] buffer, int offset) - { - return BinaryPrimitives.ReadInt32LittleEndian(buffer.AsSpan().Slice(offset)); + Assert.Throws(() => + { + using (var stream = new MemoryStream(buffer)) + using (var reader = new ArrowStreamReader(stream)) + { + reader.ReadNextRecordBatch(); + } + }); } - private static long ToInt64LittleEndian(byte[] buffer, int offset) + [Fact] + public async Task EmptyStreamNoSyncRead() { - return BinaryPrimitives.ReadInt64LittleEndian(buffer.AsSpan().Slice(offset)); + using (var stream = new EmptyAsyncOnlyStream()) + { + var reader = new ArrowStreamReader(stream); + var schema = await reader.GetSchema(); + Assert.Null(schema); + } } private class EmptyAsyncOnlyStream : Stream @@ -512,5 +791,20 @@ public override void Flush() { } public override void Write(byte[] buffer, int offset, int count) => throw new NotSupportedException(); public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) => Task.FromResult(0); } + + private static short ToInt16LittleEndian(byte[] buffer, int offset) + { + return 
BinaryPrimitives.ReadInt16LittleEndian(buffer.AsSpan().Slice(offset)); + } + + private static int ToInt32LittleEndian(byte[] buffer, int offset) + { + return BinaryPrimitives.ReadInt32LittleEndian(buffer.AsSpan().Slice(offset)); + } + + private static long ToInt64LittleEndian(byte[] buffer, int offset) + { + return BinaryPrimitives.ReadInt64LittleEndian(buffer.AsSpan().Slice(offset)); + } } }