Skip to content

Commit

Permalink
ARROW-15071: [C#] Fixed a bug in Column.cs ValidateArrayDataTypes method
Browse files Browse the repository at this point in the history
Fixed a bug in Column.cs ValidateArrayDataTypes method:

From: if (Data.Array(i).Data.DataType != Field.DataType)

To: if (Data.Array(i).Data.DataType.TypeId != Field.DataType.TypeId)

Added unit test in TestTableBasics and others.

Closes #11931 from zixi-bwang/CSharpUnitTesting

Lead-authored-by: Zixi <zixi.bwang@gmail.com>
Co-authored-by: Zixi <89567557+zixi-bwang@users.noreply.github.com>
Signed-off-by: Eric Erhardt <eric.erhardt@microsoft.com>
  • Loading branch information
2 people authored and eerhardt committed Dec 28, 2021
1 parent 968e6ea commit 06b1013
Show file tree
Hide file tree
Showing 6 changed files with 279 additions and 1 deletion.
147 changes: 147 additions & 0 deletions csharp/src/Apache.Arrow/Arrays/ArrayDataTypeComparer.cs
@@ -0,0 +1,147 @@
// Licensed to the Apache Software Foundation (ASF) under one or more
// contributor license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright ownership.
// The ASF licenses this file to You under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

using System;
using Apache.Arrow.Types;

namespace Apache.Arrow
{
internal sealed class ArrayDataTypeComparer :
IArrowTypeVisitor<TimestampType>,
IArrowTypeVisitor<Date32Type>,
IArrowTypeVisitor<Date64Type>,
IArrowTypeVisitor<Time32Type>,
IArrowTypeVisitor<Time64Type>,
IArrowTypeVisitor<FixedSizeBinaryType>,
IArrowTypeVisitor<ListType>,
IArrowTypeVisitor<StructType>
{
private readonly IArrowType _expectedType;
private bool _dataTypeMatch;

public ArrayDataTypeComparer(IArrowType expectedType)
{
_expectedType = expectedType;
}

public bool DataTypeMatch => _dataTypeMatch;

public void Visit(TimestampType actualType)
{
if (_expectedType is TimestampType expectedType
&& expectedType.Timezone == actualType.Timezone
&& expectedType.Unit == actualType.Unit)
{
_dataTypeMatch = true;
}
}

public void Visit(Date32Type actualType)
{
if (_expectedType is Date32Type expectedType
&& expectedType.Unit == actualType.Unit)
{
_dataTypeMatch = true;
}
}

public void Visit(Date64Type actualType)
{
if (_expectedType is Date64Type expectedType
&& expectedType.Unit == actualType.Unit)
{
_dataTypeMatch = true;
}
}

public void Visit(Time32Type actualType)
{
if (_expectedType is Time32Type expectedType
&& expectedType.Unit == actualType.Unit)
{
_dataTypeMatch = true;
}
}

public void Visit(Time64Type actualType)
{
if (_expectedType is Time64Type expectedType
&& expectedType.Unit == actualType.Unit)
{
_dataTypeMatch = true;
}
}

public void Visit(FixedSizeBinaryType actualType)
{
if (_expectedType is FixedSizeBinaryType expectedType
&& expectedType.ByteWidth == actualType.ByteWidth)
{
_dataTypeMatch = true;
}
}

public void Visit(ListType actualType)
{
if (_expectedType is ListType expectedType
&& CompareNested(expectedType, actualType))
{
_dataTypeMatch = true;
}
}

public void Visit(StructType actualType)
{
if (_expectedType is StructType expectedType
&& CompareNested(expectedType, actualType))
{
_dataTypeMatch = true;
}
}

private static bool CompareNested(NestedType expectedType, NestedType actualType)
{
if (expectedType.Fields.Count != actualType.Fields.Count)
{
return false;
}

for (int i = 0; i < expectedType.Fields.Count; i++)
{
if (expectedType.Fields[i].DataType.TypeId != actualType.Fields[i].DataType.TypeId)
{
return false;
}

var dataTypeMatch = FieldComparer.Compare(expectedType.Fields[i], actualType.Fields[i]);

if (!dataTypeMatch)
{
return false;
}
}

return true;
}

public void Visit(IArrowType actualType)
{
if (_expectedType.TypeId == actualType.TypeId)
{
_dataTypeMatch = true;
}
}
}
}
60 changes: 60 additions & 0 deletions csharp/src/Apache.Arrow/Arrays/FieldComparer.cs
@@ -0,0 +1,60 @@
// Licensed to the Apache Software Foundation (ASF) under one or more
// contributor license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright ownership.
// The ASF licenses this file to You under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

using System.Linq;

namespace Apache.Arrow
{
internal static class FieldComparer
{
public static bool Compare(Field expected, Field actual)
{
if (ReferenceEquals(expected, actual))
{
return true;
}

if (expected.Name != actual.Name || expected.IsNullable != actual.IsNullable ||
expected.HasMetadata != actual.HasMetadata)
{
return false;
}

if (expected.HasMetadata)
{
if (expected.Metadata.Count != actual.Metadata.Count)
{
return false;
}

if (!expected.Metadata.Keys.All(k => actual.Metadata.ContainsKey(k) && expected.Metadata[k] == actual.Metadata[k]))
{
return false;
}
}

var dataTypeComparer = new ArrayDataTypeComparer(expected.DataType);

actual.DataType.Accept(dataTypeComparer);

if (!dataTypeComparer.DataTypeMatch)
{
return false;
}

return true;
}
}
}
11 changes: 10 additions & 1 deletion csharp/src/Apache.Arrow/Column.cs
Expand Up @@ -60,9 +60,18 @@ public Column Slice(int offset)

private bool ValidateArrayDataTypes()
{
var dataTypeComparer = new ArrayDataTypeComparer(Field.DataType);

for (int i = 0; i < Data.ArrayCount; i++)
{
if (Data.Array(i).Data.DataType != Field.DataType)
if (Data.Array(i).Data.DataType.TypeId != Field.DataType.TypeId)
{
return false;
}

Data.Array(i).Data.DataType.Accept(dataTypeComparer);

if (!dataTypeComparer.DataTypeMatch)
{
return false;
}
Expand Down
18 changes: 18 additions & 0 deletions csharp/test/Apache.Arrow.Tests/ArrayBuilderTests.cs
Expand Up @@ -101,6 +101,24 @@ public void ListArrayBuilder()
new List<string> { "444", null, "555", "666" },
ConvertStringArrayToList(list.GetSlicedValues(3) as StringArray));

Assert.Throws<ArgumentOutOfRangeException>(() => list.GetValueLength(-1));
Assert.Throws<ArgumentOutOfRangeException>(() => list.GetValueLength(4));

listBuilder.Resize(2);
var truncatedList = listBuilder.Build();

Assert.Equal(
new List<string> { "22", "33", "444", null, "555", "666" },
ConvertStringArrayToList(truncatedList.GetSlicedValues(2) as StringArray));

Assert.Throws<ArgumentOutOfRangeException>(() => truncatedList.GetSlicedValues(-1));
Assert.Throws<ArgumentOutOfRangeException>(() => truncatedList.GetSlicedValues(3));

listBuilder.Clear();
var emptyList = listBuilder.Build();

Assert.Equal(0, emptyList.Length);

List<string> ConvertStringArrayToList(StringArray array)
{
var length = array.Length;
Expand Down
16 changes: 16 additions & 0 deletions csharp/test/Apache.Arrow.Tests/ArrowFileReaderTests.cs
Expand Up @@ -15,7 +15,9 @@

using Apache.Arrow.Ipc;
using Apache.Arrow.Memory;
using Apache.Arrow.Types;
using System;
using System.Collections.Generic;
using System.IO;
using System.Threading.Tasks;
using Xunit;
Expand Down Expand Up @@ -155,5 +157,19 @@ public async Task TestReadMultipleRecordBatchAsync()
ArrowReaderVerifier.CompareBatches(originalBatch1, readBatch3);
}
}

[Fact]
public void TestRecordBatchBasics()
{
RecordBatch recordBatch = TestData.CreateSampleRecordBatch(length: 1);
Assert.Throws<ArgumentOutOfRangeException>(() => new RecordBatch(recordBatch.Schema, recordBatch.Arrays, -1));

var col1 = recordBatch.Column(0);
var col2 = recordBatch.Column("list0");
ArrowReaderVerifier.CompareArrays(col1, col2);

recordBatch.Dispose();
}

}
}
28 changes: 28 additions & 0 deletions csharp/test/Apache.Arrow.Tests/TableTests.cs
Expand Up @@ -51,6 +51,34 @@ public void TestTableBasics()
Assert.Equal(1, table.ColumnCount);
}

[Fact]
public void TestTableFromRecordBatches()
{
RecordBatch recordBatch1 = TestData.CreateSampleRecordBatch(length: 10, true);
RecordBatch recordBatch2 = TestData.CreateSampleRecordBatch(length: 10, true);
IList<RecordBatch> recordBatches = new List<RecordBatch>() { recordBatch1, recordBatch2 };

Table table1 = Table.TableFromRecordBatches(recordBatch1.Schema, recordBatches);
Assert.Equal(20, table1.RowCount);
Assert.Equal(21, table1.ColumnCount);

FixedSizeBinaryType type = new FixedSizeBinaryType(17);
Field newField1 = new Field(type.Name, type, false);
Schema newSchema1 = recordBatch1.Schema.SetField(20, newField1);
Assert.Throws<ArgumentException>(() => Table.TableFromRecordBatches(newSchema1, recordBatches));

List<Field> fields = new List<Field>();
Field.Builder fieldBuilder = new Field.Builder();
fields.Add(fieldBuilder.Name("Ints").DataType(Int32Type.Default).Nullable(true).Build());
fieldBuilder = new Field.Builder();
fields.Add(fieldBuilder.Name("Strings").DataType(StringType.Default).Nullable(true).Build());
StructType structType = new StructType(fields);

Field newField2 = new Field(structType.Name, structType, false);
Schema newSchema2 = recordBatch1.Schema.SetField(16, newField2);
Assert.Throws<ArgumentException>(() => Table.TableFromRecordBatches(newSchema2, recordBatches));
}

[Fact]
public void TestTableAddRemoveAndSetColumn()
{
Expand Down

0 comments on commit 06b1013

Please sign in to comment.