Skip to content

Commit

Permalink
Address PR feedback.
Browse files Browse the repository at this point in the history
Change the array creation test code to use the visitor pattern.
  • Loading branch information
eerhardt committed Mar 5, 2019
1 parent 6ebc80e commit 558ec56
Showing 1 changed file with 95 additions and 82 deletions.
177 changes: 95 additions & 82 deletions csharp/test/Apache.Arrow.Tests/TestData.cs
Expand Up @@ -48,7 +48,7 @@ public static RecordBatch CreateSampleRecordBatch(int length)

IEnumerable<IArrowArray> arrays = CreateArrays(schema, length);

return new RecordBatch(builder.Build(), arrays, length);
return new RecordBatch(schema, arrays, length);
}

private static Field CreateField(ArrowType type)
Expand All @@ -70,88 +70,101 @@ private static IEnumerable<IArrowArray> CreateArrays(Schema schema, int length)

private static IArrowArray CreateArray(Field field, int length)
{
switch (field.DataType.TypeId)
var creator = new ArrayBufferCreator(length);
field.DataType.Accept(creator);

ArrayData data = new ArrayData(field.DataType, length, 0, 0,
new[] { ArrowBuffer.Empty, creator.Buffer });

return ArrowArrayFactory.BuildArray(data);
}

private class ArrayBufferCreator :
IArrowTypeVisitor<BooleanType>,
IArrowTypeVisitor<Int8Type>,
IArrowTypeVisitor<Int16Type>,
IArrowTypeVisitor<Int32Type>,
IArrowTypeVisitor<Int64Type>,
IArrowTypeVisitor<UInt8Type>,
IArrowTypeVisitor<UInt16Type>,
IArrowTypeVisitor<UInt32Type>,
IArrowTypeVisitor<UInt64Type>,
IArrowTypeVisitor<FloatType>,
IArrowTypeVisitor<DoubleType>
{
private readonly int _length;
public ArrowBuffer Buffer { get; private set; }

public ArrayBufferCreator(int length)
{
_length = length;
}

public void Visit(BooleanType type)
{
ArrowBuffer.Builder<bool> builder = new ArrowBuffer.Builder<bool>(_length);
for (int i = 0; i < _length; i++)
builder.Append(i % 2 == 0);

Buffer = builder.Build();
}

public void Visit(Int8Type type)
{
ArrowBuffer.Builder<sbyte> builder = new ArrowBuffer.Builder<sbyte>(_length);
for (int i = 0; i < _length; i++)
builder.Append((sbyte)i);

Buffer = builder.Build();
}

public void Visit(UInt8Type type)
{
ArrowBuffer.Builder<byte> builder = new ArrowBuffer.Builder<byte>(_length);
for (int i = 0; i < _length; i++)
builder.Append((byte)i);

Buffer = builder.Build();
}

public void Visit(Int16Type type)
{
ArrowBuffer.Builder<short> builder = new ArrowBuffer.Builder<short>(_length);
for (int i = 0; i < _length; i++)
builder.Append((short)i);

Buffer = builder.Build();
}

public void Visit(UInt16Type type)
{
ArrowBuffer.Builder<ushort> builder = new ArrowBuffer.Builder<ushort>(_length);
for (int i = 0; i < _length; i++)
builder.Append((ushort)i);

Buffer = builder.Build();
}

public void Visit(Int32Type type) => CreateNumberArray<int>(type);
public void Visit(UInt32Type type) => CreateNumberArray<uint>(type);
public void Visit(Int64Type type) => CreateNumberArray<long>(type);
public void Visit(UInt64Type type) => CreateNumberArray<ulong>(type);
public void Visit(FloatType type) => CreateNumberArray<float>(type);
public void Visit(DoubleType type) => CreateNumberArray<double>(type);

private void CreateNumberArray<T>(IArrowType type)
where T : struct
{
ArrowBuffer.Builder<T> builder = new ArrowBuffer.Builder<T>(_length);
for (int i = 0; i < _length; i++)
builder.Append((T)Convert.ChangeType(i, typeof(T)));

Buffer = builder.Build();
}

public void Visit(IArrowType type)
{
case ArrowTypeId.Boolean:
ArrowBuffer.Builder<bool> boolBuilder = new ArrowBuffer.Builder<bool>(length);
for (int i = 0; i < length; i++)
boolBuilder.Append(i % 2 == 0);
return new BooleanArray(boolBuilder.Build(), ArrowBuffer.Empty, length, 0, 0);
case ArrowTypeId.UInt8:
ArrowBuffer.Builder<byte> byteBuilder = new ArrowBuffer.Builder<byte>(length);
for (int i = 0; i < length; i++)
byteBuilder.Append((byte)i);
return new UInt8Array(byteBuilder.Build(), ArrowBuffer.Empty, length, 0, 0);
case ArrowTypeId.Int8:
ArrowBuffer.Builder<sbyte> sbyteBuilder = new ArrowBuffer.Builder<sbyte>(length);
for (int i = 0; i < length; i++)
sbyteBuilder.Append((sbyte)i);
return new Int8Array(sbyteBuilder.Build(), ArrowBuffer.Empty, length, 0, 0);
case ArrowTypeId.UInt16:
ArrowBuffer.Builder<ushort> ushortBuilder = new ArrowBuffer.Builder<ushort>(length);
for (int i = 0; i < length; i++)
ushortBuilder.Append((ushort)i);
return new UInt16Array(ushortBuilder.Build(), ArrowBuffer.Empty, length, 0, 0);
case ArrowTypeId.Int16:
ArrowBuffer.Builder<short> shortBuilder = new ArrowBuffer.Builder<short>(length);
for (int i = 0; i < length; i++)
shortBuilder.Append((short)i);
return new Int16Array(shortBuilder.Build(), ArrowBuffer.Empty, length, 0, 0);
case ArrowTypeId.UInt32:
ArrowBuffer.Builder<uint> uintBuilder = new ArrowBuffer.Builder<uint>(length);
for (int i = 0; i < length; i++)
uintBuilder.Append((uint)i);
return new UInt32Array(uintBuilder.Build(), ArrowBuffer.Empty, length, 0, 0);
case ArrowTypeId.Int32:
ArrowBuffer.Builder<int> intBuilder = new ArrowBuffer.Builder<int>(length);
for (int i = 0; i < length; i++)
intBuilder.Append(i);
return new Int32Array(intBuilder.Build(), ArrowBuffer.Empty, length, 0, 0);
case ArrowTypeId.UInt64:
ArrowBuffer.Builder<ulong> ulongBuilder = new ArrowBuffer.Builder<ulong>(length);
for (int i = 0; i < length; i++)
ulongBuilder.Append((ulong)i);
return new UInt64Array(ulongBuilder.Build(), ArrowBuffer.Empty, length, 0, 0);
case ArrowTypeId.Int64:
ArrowBuffer.Builder<long> longBuilder = new ArrowBuffer.Builder<long>(length);
for (int i = 0; i < length; i++)
longBuilder.Append(i);
return new Int64Array(longBuilder.Build(), ArrowBuffer.Empty, length, 0, 0);
case ArrowTypeId.Float:
ArrowBuffer.Builder<float> floatBuilder = new ArrowBuffer.Builder<float>(length);
for (int i = 0; i < length; i++)
floatBuilder.Append(i);
return new FloatArray(floatBuilder.Build(), ArrowBuffer.Empty, length, 0, 0);
case ArrowTypeId.Double:
ArrowBuffer.Builder<double> doubleBuilder = new ArrowBuffer.Builder<double>(length);
for (int i = 0; i < length; i++)
doubleBuilder.Append(i);
return new DoubleArray(doubleBuilder.Build(), ArrowBuffer.Empty, length, 0, 0);
//TODO: there is no DecimalArray
//case ArrowTypeId.Decimal:
// ArrowBuffer.Builder<decimal> builder = new ArrowBuffer.Builder<decimal>(length);
// for (int i = 0; i < length; i++)
// builder.Append(i);
// return new DecimalArray
// break;

//case ArrowTypeId.HalfFloat:
// break;
//case ArrowTypeId.String:
// break;
//case ArrowTypeId.Date32:
// break;
//case ArrowTypeId.Date64:
// break;
//case ArrowTypeId.Time32:
// break;
//case ArrowTypeId.Time64:
// break;
//case ArrowTypeId.Timestamp:
// break;

default:
throw new NotSupportedException($"Could not create an array for type '{field.DataType.TypeId}'");
throw new NotImplementedException();
}
}
}
Expand Down

0 comments on commit 558ec56

Please sign in to comment.