Skip to content
Permalink
Browse files
ARROW-5019: [C#] ArrowStreamWriter doesn't work on a non-seekable stream
Allow ArrowStreamWriter to write to a non-seekable stream, like a network stream.

@chutchinson @stephentoub @pgovind

Author: Eric Erhardt <eric.erhardt@microsoft.com>

Closes #4052 from eerhardt/WriteToNetworkStream and squashes the following commits:

e7125bf <Eric Erhardt> PR feedback
7333b8c <Eric Erhardt> ArrowStreamWriter doesn't work on a non-seekable stream
  • Loading branch information
eerhardt authored and kou committed Apr 4, 2019
1 parent 47cc7e5 commit 384a3b070bab958ae0a7d0a45f148d2ef1f86da6
@@ -111,7 +111,12 @@ public Builder<T> Clear()

public ArrowBuffer Build(MemoryPool pool = default)
{
var length = BitUtility.RoundUpToMultipleOf64(_buffer.Length);
int length;
checked
{
length = (int)BitUtility.RoundUpToMultipleOf64(_buffer.Length);
}

var memoryPool = pool ?? MemoryPool.Default.Value;
var memory = memoryPool.Allocate(length);

@@ -99,7 +99,7 @@ public static int CountBits(ReadOnlySpan<byte> data)
/// </summary>
/// <param name="n">Integer to round.</param>
/// <returns>Integer rounded to the nearest multiple of 64.</returns>
public static int RoundUpToMultipleOf64(int n) =>
public static long RoundUpToMultipleOf64(long n) =>
RoundUpToMultiplePowerOfTwo(n, 64);

/// <summary>
@@ -111,7 +111,7 @@ public static int CountBits(ReadOnlySpan<byte> data)
/// <param name="n">Integer to round up.</param>
/// <param name="factor">Power of two factor to round up to.</param>
/// <returns>Integer rounded up to the nearest power of two.</returns>
public static int RoundUpToMultiplePowerOfTwo(int n, int factor)
public static long RoundUpToMultiplePowerOfTwo(long n, int factor)
{
// Assert that factor is a power of two.
Debug.Assert(factor > 0 && (factor & (factor - 1)) == 0);
@@ -16,14 +16,17 @@
using System;
using System.Buffers.Binary;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Threading;
using System.Threading.Tasks;

namespace Apache.Arrow.Ipc
{
public class ArrowFileWriter: ArrowStreamWriter
{
{
private long _currentRecordBatchOffset = -1;

private bool HasWrittenHeader { get; set; }
private bool HasWrittenFooter { get; set; }

@@ -67,10 +70,35 @@ public override async Task WriteRecordBatchAsync(RecordBatch recordBatch, Cancel

cancellationToken.ThrowIfCancellationRequested();

var block = await WriteRecordBatchInternalAsync(recordBatch, cancellationToken)
await WriteRecordBatchInternalAsync(recordBatch, cancellationToken)
.ConfigureAwait(false);
}

private protected override void StartingWritingRecordBatch()
{
_currentRecordBatchOffset = BaseStream.Position;
}

private protected override void FinishedWritingRecordBatch(long bodyLength, long metadataLength)
{
// Record batches only appear after a Schema is written, so the record batch offsets must
// always be greater than 0.
Debug.Assert(_currentRecordBatchOffset > 0, "_currentRecordBatchOffset must be positive.");

int metadataLengthInt;
checked
{
metadataLengthInt = (int)metadataLength;
}

var block = new Block(
offset: _currentRecordBatchOffset,
length: bodyLength,
metadataLength: metadataLengthInt);

RecordBatchBlocks.Add(block);

_currentRecordBatchOffset = -1;
}

public async Task WriteFooterAsync(CancellationToken cancellationToken = default)
@@ -112,7 +140,7 @@ private async ValueTask WriteFooterAsync(Schema schema, CancellationToken cancel
foreach (var recordBatch in RecordBatchBlocks)
{
Flatbuf.Block.CreateBlock(
Builder, recordBatch.Offset, recordBatch.MetadataLength, recordBatch.Length);
Builder, recordBatch.Offset, recordBatch.MetadataLength, recordBatch.BodyLength);
}

var recordBatchesVectorOffset = Builder.EndVector();
@@ -141,8 +169,13 @@ private async ValueTask WriteFooterAsync(Schema schema, CancellationToken cancel

await Buffers.RentReturnAsync(4, async (buffer) =>
{
BinaryPrimitives.WriteInt32LittleEndian(buffer.Span,
Convert.ToInt32(BaseStream.Position - offset));
int footerLength;
checked
{
footerLength = (int)(BaseStream.Position - offset);
}

BinaryPrimitives.WriteInt32LittleEndian(buffer.Span, footerLength);

await BaseStream.WriteAsync(buffer, cancellationToken).ConfigureAwait(false);
}).ConfigureAwait(false);
@@ -125,20 +125,6 @@ public void Visit(IArrowArray array)
}
}

protected struct Block
{
public readonly int Offset;
public readonly int Length;
public readonly int MetadataLength;

public Block(int offset, int length, int metadataLength)
{
Offset = offset;
Length = length;
MetadataLength = metadataLength;
}
}

protected Stream BaseStream { get; }

protected ArrayPool<byte> Buffers { get; }
@@ -174,7 +160,7 @@ public ArrowStreamWriter(Stream baseStream, Schema schema, bool leaveOpen)
_fieldTypeBuilder = new ArrowTypeFlatbufferBuilder(Builder);
}

protected virtual async Task<Block> WriteRecordBatchInternalAsync(RecordBatch recordBatch,
private protected async Task WriteRecordBatchInternalAsync(RecordBatch recordBatch,
CancellationToken cancellationToken = default)
{
// TODO: Truncate buffers with extraneous padding / unused capacity
@@ -228,63 +214,55 @@ public ArrowStreamWriter(Stream baseStream, Schema schema, bool leaveOpen)

// Serialize record batch

StartingWritingRecordBatch();

var recordBatchOffset = Flatbuf.RecordBatch.CreateRecordBatch(Builder, recordBatch.Length,
fieldNodesVectorOffset,
buffersVectorOffset);

var metadataOffset = BaseStream.Position;

await WriteMessageAsync(Flatbuf.MessageHeader.RecordBatch,
long metadataLength = await WriteMessageAsync(Flatbuf.MessageHeader.RecordBatch,
recordBatchOffset, recordBatchBuilder.TotalLength,
cancellationToken).ConfigureAwait(false);

var metadataLength = BaseStream.Position - metadataOffset;

// Write buffer data

var lengthOffset = BaseStream.Position;
long bodyLength = 0;

for (var i = 0; i < buffers.Count; i++)
{
if (buffers[i].DataBuffer.IsEmpty)
continue;


await WriteBufferAsync(buffers[i].DataBuffer, cancellationToken).ConfigureAwait(false);
bodyLength += buffers[i].DataBuffer.Length;
}

// Write padding so the record batch message body length is a multiple of 8 bytes

var bodyLength = Convert.ToInt32(BaseStream.Position - lengthOffset);
var bodyPaddingLength = CalculatePadding(bodyLength);
int bodyPaddingLength = CalculatePadding(bodyLength);

await WritePaddingAsync(bodyPaddingLength).ConfigureAwait(false);

return new Block(
offset: Convert.ToInt32(metadataOffset),
length: bodyLength + bodyPaddingLength,
metadataLength: Convert.ToInt32(metadataLength));
FinishedWritingRecordBatch(bodyLength + bodyPaddingLength, metadataLength);
}

private protected virtual void StartingWritingRecordBatch()
{
}

private protected virtual void FinishedWritingRecordBatch(long bodyLength, long metadataLength)
{
}

public virtual Task WriteRecordBatchAsync(RecordBatch recordBatch, CancellationToken cancellationToken = default)
{
return WriteRecordBatchInternalAsync(recordBatch, cancellationToken);
}
public Task WriteBufferAsync(ArrowBuffer arrowBuffer, CancellationToken cancellationToken = default)

public async Task WriteBufferAsync(ArrowBuffer arrowBuffer, CancellationToken cancellationToken = default)
{
byte[] buffer = null;
try
{
var span = arrowBuffer.Span;
buffer = ArrayPool<byte>.Shared.Rent(span.Length);
span.CopyTo(buffer);
return BaseStream.WriteAsync(buffer, 0, span.Length, cancellationToken);
}
finally
{
ArrayPool<byte>.Shared.Return(buffer);
}
await BaseStream.WriteAsync(arrowBuffer.Memory, cancellationToken)
.ConfigureAwait(false);
}

private protected Offset<Flatbuf.Schema> SerializeSchema(Schema schema)
@@ -319,7 +297,6 @@ private protected Offset<Flatbuf.Schema> SerializeSchema(Schema schema)
Builder, endianness, fieldsVectorOffset);
}


private async ValueTask<Offset<Flatbuf.Schema>> WriteSchemaAsync(Schema schema, CancellationToken cancellationToken)
{
Builder.Clear();
@@ -336,7 +313,13 @@ await WriteMessageAsync(Flatbuf.MessageHeader.Schema, schemaOffset, 0, cancellat
return schemaOffset;
}

private async ValueTask WriteMessageAsync<T>(
/// <summary>
/// Writes the message to the <see cref="BaseStream"/>.
/// </summary>
/// <returns>
/// The number of bytes written to the stream.
/// </returns>
private async ValueTask<long> WriteMessageAsync<T>(
Flatbuf.MessageHeader headerType, Offset<T> headerOffset, int bodyLength,
CancellationToken cancellationToken)
where T: struct
@@ -359,6 +342,11 @@ await WriteMessageAsync(Flatbuf.MessageHeader.Schema, schemaOffset, 0, cancellat

await BaseStream.WriteAsync(messageData, cancellationToken).ConfigureAwait(false);
await WritePaddingAsync(messagePaddingLength).ConfigureAwait(false);

checked
{
return 4 + messageData.Length + messagePaddingLength;
}
}

private protected async ValueTask WriteFlatBufferAsync(CancellationToken cancellationToken = default)
@@ -368,8 +356,14 @@ private protected async ValueTask WriteFlatBufferAsync(CancellationToken cancell
await BaseStream.WriteAsync(segment, cancellationToken).ConfigureAwait(false);
}

protected int CalculatePadding(int offset, int alignment = 8) =>
BitUtility.RoundUpToMultiplePowerOfTwo(offset, alignment) - offset;
protected int CalculatePadding(long offset, int alignment = 8)
{
long result = BitUtility.RoundUpToMultiplePowerOfTwo(offset, alignment) - offset;
checked
{
return (int)result;
}
}

protected Task WritePaddingAsync(int length)
{
@@ -17,24 +17,24 @@

namespace Apache.Arrow.Ipc
{
internal class Block
internal readonly struct Block
{
public long Offset { get; }
public int MetaDataLength { get; }
public long BodyLength { get; }
public readonly long Offset;
public readonly long BodyLength;
public readonly int MetadataLength;

public Block(long offset, int metadataLength, long bodyLength)
public Block(long offset, long length, int metadataLength)
{
Offset = offset;
MetaDataLength = metadataLength;
BodyLength = bodyLength;
BodyLength = length;
MetadataLength = metadataLength;
}

public Block(Flatbuf.Block block)
{
Offset = Convert.ToInt32(block.Offset);
MetaDataLength = Convert.ToInt32(block.MetaDataLength);
BodyLength = Convert.ToInt32(block.BodyLength);
Offset = block.Offset;
BodyLength = block.BodyLength;
MetadataLength = block.MetaDataLength;
}
}
}
@@ -16,6 +16,9 @@
using Apache.Arrow.Ipc;
using System;
using System.IO;
using System.Net;
using System.Net.Sockets;
using System.Threading.Tasks;
using Xunit;

namespace Apache.Arrow.Tests
@@ -48,5 +51,59 @@ public void Ctor_LeaveOpenTrue_StreamValidOnDispose()
new ArrowStreamWriter(stream, originalBatch.Schema, leaveOpen: true).Dispose();
Assert.Equal(0, stream.Position);
}

[Fact]
public async Task CanWriteToNetworkStream()
{
RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100);

const int port = 32154;
TcpListener listener = new TcpListener(IPAddress.Loopback, port);
listener.Start();

using (TcpClient sender = new TcpClient())
{
sender.Connect(IPAddress.Loopback, port);
NetworkStream stream = sender.GetStream();

using (var writer = new ArrowStreamWriter(stream, originalBatch.Schema))
{
await writer.WriteRecordBatchAsync(originalBatch);
stream.Flush();
}
}

using (TcpClient receiver = listener.AcceptTcpClient())
{
NetworkStream stream = receiver.GetStream();
using (var reader = new ArrowStreamReader(stream))
{
RecordBatch newBatch = reader.ReadNextRecordBatch();
ArrowReaderVerifier.CompareBatches(originalBatch, newBatch);
}
}
}

[Fact]
public async Task WriteEmptyBatch()
{
RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 0);

using (MemoryStream stream = new MemoryStream())
{
using (var writer = new ArrowStreamWriter(stream, originalBatch.Schema, leaveOpen: true))
{
await writer.WriteRecordBatchAsync(originalBatch);
}

stream.Position = 0;

using (var reader = new ArrowStreamReader(stream))
{
RecordBatch newBatch = reader.ReadNextRecordBatch();
ArrowReaderVerifier.CompareBatches(originalBatch, newBatch);
}
}
}
}
}

0 comments on commit 384a3b0

Please sign in to comment.