Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ public override RecordBatch ReadNextRecordBatch()
else if (messageLength == MessageSerializer.IpcContinuationToken)
{
// ARROW-6313, if the first 4 bytes are continuation message, read the next 4 for the length
if (_buffer.Length <= _bufferPosition + sizeof(int))
if (_buffer.Length < _bufferPosition + sizeof(int))
{
throw new InvalidDataException("Corrupted IPC message. Received a continuation token at the end of the message.");
}
Expand Down Expand Up @@ -136,7 +136,7 @@ public override void ReadSchema()
if (schemaMessageLength == MessageSerializer.IpcContinuationToken)
{
// ARROW-6313, if the first 4 bytes are continuation message, read the next 4 for the length
if (_buffer.Length <= _bufferPosition + sizeof(int))
if (_buffer.Length < _bufferPosition + sizeof(int))
{
throw new InvalidDataException("Corrupted IPC message. Received a continuation token at the end of the message.");
}
Expand Down
214 changes: 214 additions & 0 deletions src/Apache.Arrow/Ipc/ArrowMemoryStreamReaderImplementation.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
// Licensed to the Apache Software Foundation (ASF) under one or more
// contributor license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright ownership.
// The ASF licenses this file to You under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

using System;
using System.Buffers;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
using Apache.Arrow.Memory;

namespace Apache.Arrow.Ipc
{
/// <summary>
/// Reads Arrow IPC streams from a <see cref="MemoryStream"/> whose backing buffer is publicly visible.
/// </summary>
/// <remarks>
/// Message metadata can be read directly from the exposed stream buffer, but record batch bodies are
/// still copied into allocator-owned buffers to preserve <see cref="ArrowStreamReader"/> ownership semantics.
/// </remarks>
internal sealed class ArrowMemoryStreamReaderImplementation : ArrowStreamReaderImplementation
{
private readonly MemoryStream _stream;
private readonly Memory<byte> _streamMemory;

/// <summary>
/// Creates a reader over a <see cref="MemoryStream"/> whose backing buffer can be obtained
/// through <see cref="MemoryStream.TryGetBuffer"/>.
/// </summary>
/// <param name="stream">The memory stream to read from; also forwarded to the base implementation.</param>
/// <param name="allocator">Forwarded to the base implementation.</param>
/// <param name="compressionCodecFactory">Forwarded to the base implementation.</param>
/// <param name="leaveOpen">Forwarded to the base implementation.</param>
/// <param name="extensionRegistry">Forwarded to the base implementation.</param>
/// <exception cref="InvalidOperationException">
/// <paramref name="stream"/> does not expose its backing buffer.
/// </exception>
public ArrowMemoryStreamReaderImplementation(
    MemoryStream stream,
    MemoryAllocator allocator,
    ICompressionCodecFactory compressionCodecFactory,
    bool leaveOpen,
    ExtensionTypeRegistry extensionRegistry)
    : base(stream, allocator, compressionCodecFactory, leaveOpen, extensionRegistry)
{
    _stream = stream;

    if (!stream.TryGetBuffer(out ArraySegment<byte> streamBuffer))
    {
        throw new InvalidOperationException("Expected MemoryStream to expose its backing buffer.");
    }

    // The segment (array, offset, count) is captured once, here. Reads index into this
    // snapshot using the stream's current Position, so index 0 corresponds to the start
    // of the segment reported by TryGetBuffer.
    _streamMemory = streamBuffer.Array.AsMemory(streamBuffer.Offset, streamBuffer.Count);
}

/// <summary>
/// Reads the next record batch. The work is done synchronously by
/// <see cref="ReadNextRecordBatch"/>; the result is wrapped in an already-completed task.
/// </summary>
/// <param name="cancellationToken">Checked once, before any reading starts.</param>
/// <returns>
/// A completed <see cref="ValueTask{TResult}"/> holding the batch, or a faulted one carrying
/// whatever exception the synchronous read threw.
/// </returns>
public override ValueTask<RecordBatch> ReadNextRecordBatchAsync(CancellationToken cancellationToken)
{
    cancellationToken.ThrowIfCancellationRequested();

    RecordBatch batch;
    try
    {
        batch = ReadNextRecordBatch();
    }
    catch (Exception error)
    {
        // Surface read failures through the returned task rather than throwing synchronously.
        return new ValueTask<RecordBatch>(Task.FromException<RecordBatch>(error));
    }

    return new ValueTask<RecordBatch>(batch);
}

/// <summary>
/// Ensures the schema has been read, then consumes messages until one yields a batch
/// or a message reports a length that is not positive.
/// </summary>
/// <returns>The batch of the last message read; this is <c>null</c> when reading stopped without one.</returns>
public override RecordBatch ReadNextRecordBatch()
{
    ReadSchema();

    while (true)
    {
        ReadResult current = ReadMessageFromMemory();

        // Keep going only while messages are being consumed that do not produce a batch.
        bool keepReading = current.Batch == null && current.MessageLength > 0;
        if (!keepReading)
        {
            return current.Batch;
        }
    }
}

/// <summary>
/// Returns the schema, reading it synchronously via <see cref="ReadSchema"/> if it has not
/// been read yet.
/// </summary>
/// <param name="cancellationToken">Checked once, before any reading starts.</param>
/// <returns>
/// A completed <see cref="ValueTask{TResult}"/> holding the current schema field, or a faulted
/// one carrying whatever exception <see cref="ReadSchema"/> threw.
/// </returns>
public override ValueTask<Schema> ReadSchemaAsync(CancellationToken cancellationToken = default)
{
    cancellationToken.ThrowIfCancellationRequested();

    if (!HasReadSchema)
    {
        try
        {
            ReadSchema();
        }
        catch (Exception error)
        {
            // Report the failure through the task instead of throwing synchronously.
            return new ValueTask<Schema>(Task.FromException<Schema>(error));
        }
    }

    return new ValueTask<Schema>(_schema);
}

/// <summary>
/// Reads the schema message from the stream's buffer and stores the resulting schema.
/// Does nothing when the schema has already been read, or when the length prefix comes
/// back as zero (which includes the stream being positioned at its end).
/// </summary>
public override void ReadSchema()
{
    if (HasReadSchema)
    {
        return;
    }

    int length = ReadMessageLengthFromMemory(throwOnFullRead: true, returnOnEmptyStream: true);
    if (length != 0)
    {
        // The schema metadata is parsed straight out of the stream's buffer; no copy is made here.
        Google.FlatBuffers.ByteBuffer schemaBytes = CreateByteBuffer(ReadMemory(length));
        var schemaTable = ReadMessage<Flatbuf.Schema>(schemaBytes);
        _schema = MessageSerializer.GetSchema(schemaTable, ref _dictionaryMemo, _extensionRegistry);
    }
}

/// <summary>
/// Reads one IPC message (length prefix, metadata, body) from the stream's buffer.
/// </summary>
/// <returns>
/// A <see cref="ReadResult"/> built from the metadata length and the object created from the
/// message, or <c>default</c> when the length prefix comes back as zero.
/// </returns>
/// <exception cref="OverflowException">The declared body length exceeds <see cref="int.MaxValue"/>.</exception>
private ReadResult ReadMessageFromMemory()
{
    int metadataLength = ReadMessageLengthFromMemory(throwOnFullRead: false, returnOnEmptyStream: false);
    if (metadataLength == 0)
    {
        return default;
    }

    // Metadata is parsed in place from the stream's buffer.
    Google.FlatBuffers.ByteBuffer metadataBytes = CreateByteBuffer(ReadMemory(metadataLength));
    Flatbuf.Message message = Flatbuf.Message.GetRootAsMessage(metadataBytes);

    if (message.BodyLength > int.MaxValue)
    {
        throw new OverflowException(
            $"Arrow IPC message body length ({message.BodyLength}) is larger than " +
            $"the maximum supported message size ({int.MaxValue})");
    }

    int bodyLength = (int)message.BodyLength;

    // Slice the body out of the stream first (advancing the position), then copy it into a
    // separately allocated buffer so the result does not alias the MemoryStream's array.
    Memory<byte> source = ReadMemory(bodyLength);
    IMemoryOwner<byte> owner = AllocateMessageBodyBuffer(bodyLength);
    Memory<byte> destination = owner.Memory.Slice(0, bodyLength);
    source.CopyTo(destination);

    // Keep stream-reader ownership semantics: batches outlive the source MemoryStream buffer.
    return new ReadResult(
        metadataLength,
        CreateArrowObjectFromMessage(message, CreateByteBuffer(destination), owner));
}

/// <summary>
/// Reads the 4-byte length prefix of the next message, skipping over a leading
/// continuation token if one is present.
/// </summary>
/// <param name="throwOnFullRead">Passed through to the underlying reads; controls whether a short read throws.</param>
/// <param name="returnOnEmptyStream">When <c>true</c>, a stream positioned exactly at its end yields zero immediately.</param>
/// <returns>The message length, or zero when nothing (or not enough) could be read.</returns>
private int ReadMessageLengthFromMemory(bool throwOnFullRead, bool returnOnEmptyStream)
{
    bool atEndOfStream = _stream.Position == _stream.Length;
    if (atEndOfStream && returnOnEmptyStream)
    {
        return 0;
    }

    if (!TryReadInt32(throwOnFullRead, out int length))
    {
        return 0;
    }

    if (length != MessageSerializer.IpcContinuationToken)
    {
        return length;
    }

    // The first word was the continuation token; the real length is the next 4 bytes.
    return TryReadInt32(throwOnFullRead, out length) ? length : 0;
}

/// <summary>
/// Attempts to read a 32-bit integer from the current stream position.
/// </summary>
/// <param name="throwOnFullRead">Passed through to <see cref="TryReadMemory"/>.</param>
/// <param name="value">The decoded integer on success; zero otherwise.</param>
/// <returns><c>true</c> when four bytes were available and decoded; otherwise <c>false</c>.</returns>
private bool TryReadInt32(bool throwOnFullRead, out int value)
{
    if (TryReadMemory(sizeof(int), throwOnFullRead, out Memory<byte> raw))
    {
        value = BitUtility.ReadInt32(raw);
        return true;
    }

    value = 0;
    return false;
}

/// <summary>
/// Attempts to take <paramref name="length"/> bytes from the current stream position.
/// </summary>
/// <param name="length">Number of bytes requested.</param>
/// <param name="throwOnFullRead">When <c>true</c>, too few remaining bytes is an error rather than a <c>false</c> result.</param>
/// <param name="buffer">The requested slice on success; <c>default</c> otherwise.</param>
/// <returns><c>true</c> when the slice was produced; otherwise <c>false</c>.</returns>
/// <exception cref="InvalidOperationException">
/// Fewer than <paramref name="length"/> bytes remain and <paramref name="throwOnFullRead"/> is <c>true</c>.
/// </exception>
private bool TryReadMemory(int length, bool throwOnFullRead, out Memory<byte> buffer)
{
    long available = _stream.Length - _stream.Position;
    if (available >= length)
    {
        buffer = ReadMemory(length);
        return true;
    }

    if (throwOnFullRead)
    {
        throw new InvalidOperationException("Unexpectedly reached the end of the stream before a full buffer was read.");
    }

    // Not enough data: move to the end of the stream so the partial tail is treated as consumed.
    _stream.Position = _stream.Length;
    buffer = default;
    return false;
}

/// <summary>
/// Returns a slice of the captured stream buffer starting at the stream's current position
/// and advances the position past it. No bytes are copied.
/// </summary>
/// <param name="length">Number of bytes to take. Zero yields an empty slice and leaves the position untouched.</param>
/// <returns>A view over the requested bytes inside the stream's backing buffer.</returns>
/// <exception cref="InvalidDataException"><paramref name="length"/> is negative.</exception>
/// <exception cref="InvalidOperationException">
/// Fewer than <paramref name="length"/> bytes are available between the current position and
/// the end of the readable data.
/// </exception>
private Memory<byte> ReadMemory(int length)
{
    if (length == 0)
    {
        return Memory<byte>.Empty;
    }

    if (length < 0)
    {
        // Lengths come from the data being parsed; a negative one cannot be satisfied.
        throw new InvalidDataException($"Corrupted IPC message. Received a negative length ({length}).");
    }

    // Bound the read by both the stream's current length and the buffer snapshot taken at
    // construction, so a truncated message or a stream that changed size afterwards fails
    // with the same end-of-stream error as the length-prefix reads instead of an
    // ArgumentOutOfRangeException from Slice.
    long position = _stream.Position;
    long available = Math.Min(_stream.Length, _streamMemory.Length) - position;
    if (available < length)
    {
        throw new InvalidOperationException("Unexpectedly reached the end of the stream before a full buffer was read.");
    }

    // position + length <= _streamMemory.Length here, so the cast cannot overflow.
    Memory<byte> buffer = _streamMemory.Slice((int)position, length);
    _stream.Position = position + length;
    return buffer;
}
}
}
19 changes: 17 additions & 2 deletions src/Apache.Arrow/Ipc/ArrowStreamReader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ public ArrowStreamReader(Stream stream, MemoryAllocator allocator, ICompressionC
if (stream == null)
throw new ArgumentNullException(nameof(stream));

_implementation = new ArrowStreamReaderImplementation(stream, allocator, compressionCodecFactory, leaveOpen);
_implementation = CreateImplementation(stream, allocator, compressionCodecFactory, leaveOpen, extensionRegistry: null);
}

public ArrowStreamReader(ArrowContext context, Stream stream, bool leaveOpen = false)
Expand All @@ -78,7 +78,7 @@ public ArrowStreamReader(ArrowContext context, Stream stream, bool leaveOpen = f
if (context == null)
throw new ArgumentNullException(nameof(context));

_implementation = new ArrowStreamReaderImplementation(stream, context.Allocator, context.CompressionCodecFactory, leaveOpen, context.ExtensionRegistry);
_implementation = CreateImplementation(stream, context.Allocator, context.CompressionCodecFactory, leaveOpen, context.ExtensionRegistry);
}

public ArrowStreamReader(ReadOnlyMemory<byte> buffer)
Expand All @@ -104,6 +104,21 @@ private protected ArrowStreamReader(ArrowReaderImplementation implementation)
_implementation = implementation;
}

/// <summary>
/// Chooses the reader implementation for <paramref name="stream"/>: the memory-stream
/// specialization when the stream is a <see cref="MemoryStream"/> that exposes its backing
/// buffer, and the general stream implementation otherwise.
/// </summary>
private static ArrowReaderImplementation CreateImplementation(
    Stream stream,
    MemoryAllocator allocator,
    ICompressionCodecFactory compressionCodecFactory,
    bool leaveOpen,
    ExtensionTypeRegistry extensionRegistry)
{
    // Anything that is not a MemoryStream with a visible buffer takes the general path.
    if (!(stream is MemoryStream memoryStream) || !memoryStream.TryGetBuffer(out _))
    {
        return new ArrowStreamReaderImplementation(stream, allocator, compressionCodecFactory, leaveOpen, extensionRegistry);
    }

    return new ArrowMemoryStreamReaderImplementation(memoryStream, allocator, compressionCodecFactory, leaveOpen, extensionRegistry);
}

public void Dispose()
{
Dispose(true);
Expand Down
27 changes: 19 additions & 8 deletions src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,10 @@ protected override void Dispose(bool disposing)
}
}

public override async ValueTask<RecordBatch> ReadNextRecordBatchAsync(CancellationToken cancellationToken)
public override ValueTask<RecordBatch> ReadNextRecordBatchAsync(CancellationToken cancellationToken)
{
// TODO: Loop until a record batch is read.
cancellationToken.ThrowIfCancellationRequested();
return await ReadRecordBatchAsync(cancellationToken).ConfigureAwait(false);
return ReadRecordBatchAsync(cancellationToken);
Comment thread
InCerryGit marked this conversation as resolved.
}

public override RecordBatch ReadNextRecordBatch()
Expand All @@ -61,7 +60,7 @@ public override RecordBatch ReadNextRecordBatch()

protected async ValueTask<RecordBatch> ReadRecordBatchAsync(CancellationToken cancellationToken = default)
{
await ReadSchemaAsync().ConfigureAwait(false);
await ReadSchemaAsync(cancellationToken).ConfigureAwait(false);

ReadResult result = default;
do
Expand Down Expand Up @@ -94,7 +93,7 @@ protected async ValueTask<ReadResult> ReadMessageAsync(CancellationToken cancell

int bodyLength = checked((int)message.BodyLength);

IMemoryOwner<byte> bodyBuffOwner = _allocator.Allocate(bodyLength);
IMemoryOwner<byte> bodyBuffOwner = AllocateMessageBodyBuffer(bodyLength);
Memory<byte> bodyBuff = bodyBuffOwner.Memory.Slice(0, bodyLength);
bytesRead = await BaseStream.ReadFullBufferAsync(bodyBuff, cancellationToken)
.ConfigureAwait(false);
Expand Down Expand Up @@ -145,7 +144,7 @@ protected ReadResult ReadMessage()
}
int bodyLength = (int)message.BodyLength;

IMemoryOwner<byte> bodyBuffOwner = _allocator.Allocate(bodyLength);
IMemoryOwner<byte> bodyBuffOwner = AllocateMessageBodyBuffer(bodyLength);
Memory<byte> bodyBuff = bodyBuffOwner.Memory.Slice(0, bodyLength);
bytesRead = BaseStream.ReadFullBuffer(bodyBuff);
EnsureFullRead(bodyBuff, bytesRead);
Expand All @@ -157,13 +156,25 @@ protected ReadResult ReadMessage()
return new ReadResult(messageLength, result);
}

public override async ValueTask<Schema> ReadSchemaAsync(CancellationToken cancellationToken = default)
/// <summary>
/// Obtains a buffer for a message body from this reader's allocator.
/// </summary>
/// <param name="bodyLength">The length passed to the allocator.</param>
/// <returns>The memory owner returned by the allocator.</returns>
protected IMemoryOwner<byte> AllocateMessageBodyBuffer(int bodyLength)
    => _allocator.Allocate(bodyLength);

public override ValueTask<Schema> ReadSchemaAsync(CancellationToken cancellationToken = default)
{
cancellationToken.ThrowIfCancellationRequested();

if (HasReadSchema)
{
return _schema;
return new ValueTask<Schema>(_schema);
}

return ReadSchemaAsyncCore(cancellationToken);
}

private async ValueTask<Schema> ReadSchemaAsyncCore(CancellationToken cancellationToken)
{
// Figure out length of schema
int schemaMessageLength = await ReadMessageLengthAsync(throwOnFullRead: true, returnOnEmptyStream: true, cancellationToken)
.ConfigureAwait(false);
Expand Down
Loading