Skip to content

Commit

Permalink
GH-37570: [MATLAB] Implement isequal for the `arrow.tabular.RecordB…
Browse files Browse the repository at this point in the history
…atch` MATLAB class (#37627)

### Rationale for this change

Following on from #37474, #37446, and #37525, we should implement `isequal` for the `arrow.tabular.RecordBatch` MATLAB class.

### What changes are included in this PR?

1. Implemented `isequal` method for `arrow.tabular.RecordBatch`

### Are these changes tested?

Yes. Added `isequal` unit tests to `tRecordBatch.m`.

### Are there any user-facing changes?

Yes, users can now use `isequal` to compare `arrow.tabular.RecordBatch`es. 

**Example**

```matlab
>> t1 = table(1, "A", false, VariableNames=["Number",  "String", "Logical"]);
>> t2 = table([1; 2], ["A"; "B"], [false; false], VariableNames=["Number",  "String", "Logical"]); 
>> rb1 = arrow.recordBatch(t1);
>> rb2 = arrow.recordBatch(t2);
>> rb3 = arrow.recordBatch(t1);

>> isequal(rb1, rb2)

ans =

  logical

   0

>> isequal(rb1, rb3)

ans =

  logical

   1
```

### Future Directions
1. #37628

* Closes: #37570

Authored-by: Sarah Gilmore <sgilmore@mathworks.com>
Signed-off-by: Kevin Gurney <kgurney@mathworks.com>
  • Loading branch information
sgilmore10 committed Sep 8, 2023
1 parent 50015f0 commit 0e6b8c5
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 2 deletions.
46 changes: 44 additions & 2 deletions matlab/src/matlab/+arrow/+tabular/RecordBatch.m
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
%RECORDBATCH A tabular data structure representing a set of
%arrow.array.Array objects with a fixed schema.


% Licensed to the Apache Software Foundation (ASF) under one or more
% contributor license agreements. See the NOTICE file distributed with
% this work for additional information regarding copyright ownership.
Expand All @@ -15,8 +19,6 @@

classdef RecordBatch < matlab.mixin.CustomDisplay & ...
matlab.mixin.Scalar
%arrow.tabular.RecordBatch A tabular data structure representing
% a set of arrow.array.Array objects with a fixed schema.

properties (Dependent, SetAccess=private, GetAccess=public)
NumColumns
Expand Down Expand Up @@ -91,6 +93,46 @@
function T = toMATLAB(obj)
    % Convert this record batch to its MATLAB representation by
    % delegating to the table conversion method.
    T = table(obj);
end

function tf = isequal(obj, varargin)
    %ISEQUAL True if all inputs are equal arrow.tabular.RecordBatch objects.
    %
    %   TF = ISEQUAL(OBJ, RB1, ..., RBN) returns true only if:
    %     1. every input (including OBJ) is an arrow.tabular.RecordBatch,
    %     2. the Schema of every input is equal to OBJ's Schema, and
    %     3. each column of OBJ is equal to the column at the same
    %        index in every other input.
    %   Otherwise TF is false. At least two inputs are required.
    narginchk(2, inf);
    tf = false;

    % MATLAB dispatches to this method whenever ANY argument is a
    % RecordBatch (for example isequal(1, rb)), so the first argument
    % is not guaranteed to be a RecordBatch. Validate it as well,
    % otherwise obj.Schema below would error instead of returning false.
    if ~isa(obj, "arrow.tabular.RecordBatch")
        return;
    end

    % If any of the remaining inputs is not a RecordBatch, it cannot
    % be equal to obj. Return false early.
    isRecordBatch = cellfun(@(rb) isa(rb, "arrow.tabular.RecordBatch"), varargin);
    if ~all(isRecordBatch)
        return;
    end

    % If the schemas are not equal, the record batches are not equal.
    % Return false early. Equal schemas are also what makes it safe
    % to index every record batch with obj's column indices below.
    schemasToCompare = cellfun(@(rb) rb.Schema, varargin, UniformOutput=false);
    if ~isequal(obj.Schema, schemasToCompare{:})
        return;
    end

    for ii = 1:obj.NumColumns
        % Gather the column at index ii from every record batch in
        % varargin and compare them all with the corresponding
        % column in obj. If they are not equal, then the record
        % batches are not equal. Return false.
        columnsToCompare = cellfun(@(rb) rb.column(ii), varargin, UniformOutput=false);
        if ~isequal(obj.column(ii), columnsToCompare{:})
            return;
        end
    end
    tf = true;
end
end

methods (Access = private)
Expand Down
72 changes: 72 additions & 0 deletions matlab/test/arrow/tabular/tRecordBatch.m
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,78 @@ function ErrorIfColumnNameIsNonScalar(testCase)
testCase.verifyError(@() recordBatch.column(name), "arrow:badsubscript:NonScalar");
end

function TestIsEqualTrue(testCase)
    % isequal must return true for record batches whose schemas
    % match and whose corresponding columns are equal.
    import arrow.tabular.RecordBatch

    numericArray = arrow.array([1 2 3]);
    stringArray = arrow.array(["A" "B" "C"]);
    logicalArray = arrow.array([true true false]);

    % Two record batches built from the same arrays and names
    lhs = RecordBatch.fromArrays(numericArray, stringArray, logicalArray, ...
        ColumnNames=["A", "B", "C"]);
    rhs = RecordBatch.fromArrays(numericArray, stringArray, logicalArray, ...
        ColumnNames=["A", "B", "C"]);
    testCase.verifyTrue(isequal(lhs, rhs));

    % Record batches with zero columns
    emptyLhs = RecordBatch.fromArrays();
    emptyRhs = RecordBatch.fromArrays();
    testCase.verifyTrue(isequal(emptyLhs, emptyRhs));

    % Record batches with zero rows
    emptyNumeric = arrow.array([]);
    emptyString = arrow.array(strings(0, 0));
    zeroRowLhs = RecordBatch.fromArrays(emptyNumeric, emptyString, ColumnNames=["D" "E"]);
    zeroRowRhs = RecordBatch.fromArrays(emptyNumeric, emptyString, ColumnNames=["D" "E"]);
    testCase.verifyTrue(isequal(zeroRowLhs, zeroRowRhs));

    % More than two input arguments
    testCase.verifyTrue(isequal(emptyLhs, emptyRhs, emptyLhs, emptyRhs));
end

function TestIsEqualFalse(testCase)
    % isequal must return false when the record batches differ in
    % column names, column values, row count, or column count.
    import arrow.tabular.RecordBatch

    baseNames = ["A", "B", "C"];

    % Three-element columns
    numbers = arrow.array([1 2 3]);
    letters = arrow.array(["A" "B" "C"]);
    flags = arrow.array([true true false]);
    lettersWithNull = arrow.array(["A" missing "C"]);

    % Two-element columns
    shortNumbers = arrow.array([1 2]);
    shortLetters = arrow.array(["A" "B"]);
    shortFlags = arrow.array([true true]);

    reference = RecordBatch.fromArrays(numbers, letters, flags, ...
        ColumnNames=baseNames);
    renamed = RecordBatch.fromArrays(numbers, letters, flags, ...
        ColumnNames=["D", "E", "F"]);
    differentValues = RecordBatch.fromArrays(numbers, lettersWithNull, flags, ...
        ColumnNames=baseNames);
    fewerRows = RecordBatch.fromArrays(shortNumbers, shortLetters, shortFlags, ...
        ColumnNames=baseNames);
    extraColumn = RecordBatch.fromArrays(numbers, letters, flags, numbers, ...
        ColumnNames=["A", "B", "C", "D"]);

    % Column names differ
    testCase.verifyFalse(isequal(reference, renamed));

    % Column contents differ
    testCase.verifyFalse(isequal(reference, differentValues));

    % Number of rows differs
    testCase.verifyFalse(isequal(reference, fewerRows));

    % Number of columns differs
    testCase.verifyFalse(isequal(reference, extraColumn));

    % More than two input arguments
    testCase.verifyFalse(isequal(reference, renamed, differentValues, fewerRows));
end


end

methods
Expand Down

0 comments on commit 0e6b8c5

Please sign in to comment.