Skip to content

Commit

Permalink
GH-37477: [MATLAB] Add AllowNonScalar name-value pair to arrow.inte…
Browse files Browse the repository at this point in the history
…rnal.validate.index.* validation functions (#37482)

### Rationale for this change

Per #37475 (comment), we should consider adding a name-value pair like `AllowNonScalar = true | false` to the `arrow.internal.validate.index.*` validation functions since it is relatively common to want to explicitly allow (or disallow) non-scalar inputs to indexing functions (e.g. the `column` method of `RecordBatch` should only support scalar index values).

### What changes are included in this PR?

1. Modified all functions within the `arrow.internal.validate.index` package (i.e. `numeric()`, `string()`, and `numericOrString()`) to accept a name-value pair called `AllowNonScalar`. This name-value pair can be set to a `logical` scalar, and by default it's set to `true`.
2. Updated the `column()` method in `RecordBatch` to pass `AllowNonScalar=false` to `numericOrString()`.
3. Updated the `field()` method in `Schema` to pass `AllowNonScalar=false` to `numericOrString()`.

**NOTE:** While character row vectors (e.g. `'ABC'`) are not scalar, they are equivalent to scalar `string` arrays. Therefore, both `string()` and `numericOrString()` do not error if given a character row vector as the index to validate and `AllowNonScalar=false`. 

### Are these changes tested?

Yes. Added new test cases to `tNumeric.m`, `tString.m` and `tNumericOrString.m`

### Are there any user-facing changes?

No.

* Closes: #37477

Authored-by: Sarah Gilmore <sgilmore@mathworks.com>
Signed-off-by: Kevin Gurney <kgurney@mathworks.com>
  • Loading branch information
sgilmore10 authored Aug 31, 2023
1 parent 323a92f commit 2f3db65
Show file tree
Hide file tree
Showing 10 changed files with 191 additions and 22 deletions.
13 changes: 12 additions & 1 deletion matlab/src/matlab/+arrow/+internal/+validate/+index/numeric.m
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,25 @@
% implied. See the License for the specific language governing
% permissions and limitations under the License.

function index = numeric(index, intType)
function index = numeric(index, intType, opts)
arguments
index
intType(1, 1) string
opts.AllowNonScalar(1, 1) = true
end

if ~isnumeric(index)
errid = "arrow:badsubscript:NonNumeric";
msg = "Expected numeric index values.";
error(errid, msg);
end

if ~opts.AllowNonScalar && ~isscalar(index)
errid = "arrow:badsubscript:NonScalar";
msg = "Expected a scalar index value.";
error(errid, msg);
end

% Convert to full storage if sparse
if issparse(index)
index = full(index);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,21 @@
% implied. See the License for the specific language governing
% permissions and limitations under the License.

function idx = numericOrString(idx, numericIndexType)
function idx = numericOrString(idx, numericIndexType, opts)
arguments
idx
numericIndexType(1, 1) string
opts.AllowNonScalar(1, 1) logical = true
end

import arrow.internal.validate.*

opts = namedargs2cell(opts);
idx = convertCharsToStrings(idx);
if isnumeric(idx)
idx = index.numeric(idx, numericIndexType);
idx = index.numeric(idx, numericIndexType, opts{:});
elseif isstring(idx)
idx = index.string(idx);
idx = index.string(idx, opts{:});
else
errid = "arrow:badsubscript:UnsupportedIndexType";
msg = "Indices must be positive integers or nonmissing strings.";
Expand Down
13 changes: 11 additions & 2 deletions matlab/src/matlab/+arrow/+internal/+validate/+index/string.m
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,21 @@
% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
% implied. See the License for the specific language governing
% permissions and limitations under the License.
function index = string(index)
function index = string(index, opts)
arguments
index
opts.AllowNonScalar(1, 1) = true
end

index = convertCharsToStrings(index);

index = reshape(index, [], 1);

if ~opts.AllowNonScalar && ~isscalar(index)
errid = "arrow:badsubscript:NonScalar";
msg = "Expected a scalar index value.";
error(errid, msg);
end

if ~isstring(index)
errid = "arrow:badsubscript:NonString";
msg = "Expected string index values.";
Expand Down
4 changes: 1 addition & 3 deletions matlab/src/matlab/+arrow/+tabular/RecordBatch.m
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,12 @@
function arrowArray = column(obj, idx)
import arrow.internal.validate.*

idx = index.numericOrString(idx, "int32");
idx = index.numericOrString(idx, "int32", AllowNonScalar=false);

if isnumeric(idx)
validateattributes(idx, "int32", "scalar");
args = struct(Index=idx);
[proxyID, typeID] = obj.Proxy.getColumnByIndex(args);
else
validateattributes(idx, "string", "scalar");
args = struct(Name=idx);
[proxyID, typeID] = obj.Proxy.getColumnByName(args);
end
Expand Down
4 changes: 1 addition & 3 deletions matlab/src/matlab/+arrow/+tabular/Schema.m
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,12 @@
function F = field(obj, idx)
import arrow.internal.validate.*

idx = index.numericOrString(idx, "int32");
idx = index.numericOrString(idx, "int32", AllowNonScalar=false);

if isnumeric(idx)
validateattributes(idx, "int32", "scalar");
args = struct(Index=idx);
proxyID = obj.Proxy.getFieldByIndex(args);
else
validateattributes(idx, "string", "scalar");
args = struct(Name=idx);
proxyID = obj.Proxy.getFieldByName(args);
end
Expand Down
39 changes: 38 additions & 1 deletion matlab/test/arrow/internal/validate/index/tNumeric.m
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ function ErrorIfNonNumeric(testCase)

import arrow.internal.validate.index.numeric

fcn = @() numeric(false);
fcn = @() numeric(false, "int32");
testCase.verifyError(fcn, "arrow:badsubscript:NonNumeric");
end

Expand Down Expand Up @@ -161,5 +161,42 @@ function OutputShape(testCase)
actual = numeric(original, "int32");
testCase.verifyEqual(actual, expected);
end

function AllowNonScalarTrue(testCase)
    % With AllowNonScalar=true, numeric() must accept nonscalar as
    % well as scalar index values and return them as int32.

    import arrow.internal.validate.index.numeric

    % Nonscalar input: a 3x1 double column vector.
    columnIndex = [1; 2; 3];
    testCase.verifyEqual( ...
        numeric(columnIndex, "int32", AllowNonScalar=true), ...
        int32([1; 2; 3]));

    % Scalar input: a single double value.
    scalarIndex = 1;
    testCase.verifyEqual( ...
        numeric(scalarIndex, "int32", AllowNonScalar=true), ...
        int32(1));
end

function AllowNonScalarFalse(testCase)
    % Verify numeric() behaves as expected when provided
    % AllowNonScalar=false:
    %   * nonscalar numeric input raises arrow:badsubscript:NonScalar
    %   * scalar numeric input is still accepted and returned as int32

    import arrow.internal.validate.index.numeric

    % Should throw an error when provided a nonscalar double array
    original = [1 2 3]';
    fcn = @() numeric(original, "int32", AllowNonScalar=false);
    testCase.verifyError(fcn, "arrow:badsubscript:NonScalar");

    % Should not throw an error when provided a scalar double array.
    % AllowNonScalar must be false here; passing true would only
    % repeat the AllowNonScalarTrue test and leave the scalar path
    % of the AllowNonScalar=false branch unverified.
    original = 1;
    expected = int32(1);
    actual = numeric(original, "int32", AllowNonScalar=false);
    testCase.verifyEqual(actual, expected);
end
end
end
66 changes: 66 additions & 0 deletions matlab/test/arrow/internal/validate/index/tNumericOrString.m
Original file line number Diff line number Diff line change
Expand Up @@ -60,5 +60,71 @@ function ValidStringArray(testCase)

testCase.verifyEqual(numericOrString(["B" "A"], "int32"), ["B", "A"]');
end

function AllowNonScalarTrue(testCase)
    % With AllowNonScalar=true, numericOrString() must accept both
    % nonscalar and scalar indices, whether numeric or string.

    import arrow.internal.validate.index.numericOrString

    % Nonscalar double column vector -> int32 column vector.
    numericVector = [1; 2; 3];
    testCase.verifyEqual( ...
        numericOrString(numericVector, "int32", AllowNonScalar=true), ...
        int32([1; 2; 3]));

    % Scalar double -> scalar int32.
    numericScalar = 1;
    testCase.verifyEqual( ...
        numericOrString(numericScalar, "int32", AllowNonScalar=true), ...
        int32(1));

    % Nonscalar string row vector -> string column vector.
    stringVector = ["A", "B", "C"];
    testCase.verifyEqual( ...
        numericOrString(stringVector, "int32", AllowNonScalar=true), ...
        ["A"; "B"; "C"]);

    % Scalar string is returned unchanged.
    stringScalar = "A";
    testCase.verifyEqual( ...
        numericOrString(stringScalar, "int32", AllowNonScalar=true), ...
        "A");
end

function AllowNonScalarFalse(testCase)
    % Verify numericOrString() behaves as expected when provided
    % AllowNonScalar=false:
    %   * nonscalar numeric or string input raises
    %     arrow:badsubscript:NonScalar
    %   * scalar numeric input, scalar string input, and character
    %     row vectors are still accepted

    import arrow.internal.validate.index.numericOrString

    % Should throw an error when provided a nonscalar double array
    original = [1 2 3]';
    fcn = @() numericOrString(original, "int32", AllowNonScalar=false);
    testCase.verifyError(fcn, "arrow:badsubscript:NonScalar");

    % Should not throw an error when provided a scalar double array.
    % AllowNonScalar must be false here so that the scalar numeric
    % path is actually exercised with non-scalars disallowed.
    original = 1;
    expected = int32(1);
    actual = numericOrString(original, "int32", AllowNonScalar=false);
    testCase.verifyEqual(actual, expected);

    % Should throw an error if provided a nonscalar string array
    original = ["A", "B", "C"];
    fcn = @() numericOrString(original, "int32", AllowNonScalar=false);
    testCase.verifyError(fcn, "arrow:badsubscript:NonScalar");

    % Should not throw an error if provided a scalar string array
    original = "A";
    expected = "A";
    actual = numericOrString(original, "int32", AllowNonScalar=false);
    testCase.verifyEqual(actual, expected);

    % Should not throw an error if provided a character row vector
    original = 'ABC';
    expected = "ABC";
    actual = numericOrString(original, "int32", AllowNonScalar=false);
    testCase.verifyEqual(actual, expected);
end
end
end
43 changes: 43 additions & 0 deletions matlab/test/arrow/internal/validate/index/tString.m
Original file line number Diff line number Diff line change
Expand Up @@ -113,5 +113,48 @@ function OutputShape(testCase)
actual = index.string(original);
testCase.verifyEqual(actual, expected);
end

function AllowNonScalarTrue(testCase)
    % With AllowNonScalar=true, string() must accept both nonscalar
    % and scalar string index values.

    import arrow.internal.validate.*

    % Nonscalar string row vector -> string column vector.
    rowNames = ["A", "B", "C"];
    testCase.verifyEqual( ...
        index.string(rowNames, AllowNonScalar=true), ...
        ["A"; "B"; "C"]);

    % Scalar string is returned unchanged.
    singleName = "A";
    testCase.verifyEqual( ...
        index.string(singleName, AllowNonScalar=true), ...
        "A");
end

function AllowNonScalarFalse(testCase)
    % With AllowNonScalar=false, string() must reject nonscalar
    % string arrays but keep accepting scalar strings and character
    % row vectors.

    import arrow.internal.validate.*

    % A nonscalar string array is rejected.
    rowNames = ["A", "B", "C"];
    testCase.verifyError( ...
        @() index.string(rowNames, AllowNonScalar=false), ...
        "arrow:badsubscript:NonScalar");

    % A scalar string is returned unchanged.
    singleName = "A";
    testCase.verifyEqual( ...
        index.string(singleName, AllowNonScalar=false), ...
        "A");

    % A character row vector is returned as the scalar string "ABC".
    charName = 'ABC';
    testCase.verifyEqual( ...
        index.string(charName, AllowNonScalar=false), ...
        "ABC");
end
end
end
6 changes: 3 additions & 3 deletions matlab/test/arrow/tabular/tRecordBatch.m
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ function ErrorIfIndexIsNonScalar(tc)
TOriginal = table(1, 2, 3);
arrowRecordBatch = arrow.recordBatch(TOriginal);
fcn = @() arrowRecordBatch.column([1 2]);
tc.verifyError(fcn, "MATLAB:expectedScalar");
tc.verifyError(fcn, "arrow:badsubscript:NonScalar");
end

function ErrorIfIndexIsNonPositive(tc)
Expand Down Expand Up @@ -380,10 +380,10 @@ function ErrorIfColumnNameIsNonScalar(testCase)
);

name = ["A", "B", "C"];
testCase.verifyError(@() recordBatch.column(name), "MATLAB:expectedScalar");
testCase.verifyError(@() recordBatch.column(name), "arrow:badsubscript:NonScalar");

name = ["A"; "B"; "C"];
testCase.verifyError(@() recordBatch.column(name), "MATLAB:expectedScalar");
testCase.verifyError(@() recordBatch.column(name), "arrow:badsubscript:NonScalar");
end

end
Expand Down
12 changes: 6 additions & 6 deletions matlab/test/arrow/tabular/tSchema.m
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ function ErrorIfUnsupportedFieldIndex(testCase)
]);

index = [];
testCase.verifyError(@() schema.field(index), "MATLAB:expectedScalar");
testCase.verifyError(@() schema.field(index), "arrow:badsubscript:NonScalar");

index = 0;
testCase.verifyError(@() schema.field(index), "arrow:badsubscript:NonPositive");
Expand All @@ -157,7 +157,7 @@ function ErrorIfUnsupportedFieldIndex(testCase)
testCase.verifyError(@() schema.field(index), "arrow:badsubscript:UnsupportedIndexType");

index = [1; 1];
testCase.verifyError(@() schema.field(index), "MATLAB:expectedScalar");
testCase.verifyError(@() schema.field(index), "arrow:badsubscript:NonScalar");
end

function GetFieldByIndex(testCase)
Expand Down Expand Up @@ -446,10 +446,10 @@ function ErrorIfNumericIndexIsNonScalar(testCase)
]);

fieldName = [1, 2, 3];
testCase.verifyError(@() schema.field(fieldName), "MATLAB:expectedScalar");
testCase.verifyError(@() schema.field(fieldName), "arrow:badsubscript:NonScalar");

fieldName = [1; 2; 3];
testCase.verifyError(@() schema.field(fieldName), "MATLAB:expectedScalar");
testCase.verifyError(@() schema.field(fieldName), "arrow:badsubscript:NonScalar");
end

function ErrorIfFieldNameIsNonScalar(testCase)
Expand All @@ -462,10 +462,10 @@ function ErrorIfFieldNameIsNonScalar(testCase)
]);

fieldName = ["A", "B", "C"];
testCase.verifyError(@() schema.field(fieldName), "MATLAB:expectedScalar");
testCase.verifyError(@() schema.field(fieldName), "arrow:badsubscript:NonScalar");

fieldName = ["A"; "B"; "C"];
testCase.verifyError(@() schema.field(fieldName), "MATLAB:expectedScalar");
testCase.verifyError(@() schema.field(fieldName), "arrow:badsubscript:NonScalar");
end

end
Expand Down

0 comments on commit 2f3db65

Please sign in to comment.