GH-36072: [MATLAB] Add MATLAB arrow.tabular.RecordBatch class (#36190)
### Rationale for this change

Now that the MATLAB interface supports some basic `arrow.array.Array` types, it would be helpful to start building out the tabular types (e.g. `RecordBatch` and `Table`) in parallel.

This pull request contains a basic implementation of `arrow.tabular.RecordBatch` (name subject to change).

### What changes are included in this PR?

1. Added new `arrow.tabular.RecordBatch` class that can be constructed from a MATLAB `table`.
2. Added new test class `tRecordBatch`.

### Are these changes tested?

Yes.

1. Added new test class `tRecordBatch` containing basic tests for the `arrow.tabular.RecordBatch` class.

### Are there any user-facing changes?

Yes.

1. Added new class `arrow.tabular.RecordBatch`.

**Example**:

```matlab
>> matlabTable = table(uint64([1,2,3]'), [true false true]', [0.1, 0.2, 0.3]', VariableNames=["UInt64", "Boolean", "Float64"])

matlabTable =

  3x3 table

    UInt64    Boolean    Float64
    ______    _______    _______

      1        true        0.1  
      2        false       0.2  
      3        true        0.3  

>> arrowRecordBatch = arrow.tabular.RecordBatch(matlabTable)

arrowRecordBatch = 

UInt64:   [
    1,
    2,
    3
  ]
Boolean:   [
    true,
    false,
    true
  ]
Float64:   [
    0.1,
    0.2,
    0.3
  ]

>> convertedMatlabTable = table(arrowRecordBatch)    

convertedMatlabTable =

  3x3 table

    UInt64    Boolean    Float64
    ______    _______    _______

      1        true        0.1  
      2        false       0.2  
      3        true        0.3  

>> isequal(matlabTable, convertedMatlabTable)

ans =

  logical

   1
```

2. Added properties `NumColumns` and `ColumnNames` to `arrow.tabular.RecordBatch`:

**Example**:

```matlab
>> arrowRecordBatch.NumColumns 

ans =

  int32

   3

>> arrowRecordBatch.ColumnNames

ans = 

  1x3 string array

    "UInt64"    "Boolean"    "Float64"
```

3. Added `column(i)` method to `arrow.tabular.RecordBatch` to retrieve the `i`th column of a `RecordBatch` as an `arrow.array.Array`.

**Example**:

```matlab
>> arrowUInt64Array = arrowRecordBatch.column(1) 

arrowUInt64Array = 

[
  1,
  2,
  3
]
>> class(arrowUInt64Array)

ans =

    'arrow.array.UInt64Array'

>> arrowBooleanArray = arrowRecordBatch.column(2)

arrowBooleanArray = 

[
  true,
  false,
  true
]

>> class(arrowBooleanArray)

ans =

    'arrow.array.BooleanArray'

>> arrowFloat64Array = arrowRecordBatch.column(3)

arrowFloat64Array = 

[
  0.1,
  0.2,
  0.3
]

>> class(arrowFloat64Array)

ans =

    'arrow.array.Float64Array'
```

4. Added `toMATLAB` and `table` conversion methods to convert from a `RecordBatch` to a MATLAB `table`.

### Future Directions

1. Implement C++ logic for `toMATLAB` when the Arrow memory for a `RecordBatch` did not originate from a MATLAB array (e.g. read from a Parquet file or somewhere else).
2. Add more supported construction interfaces (e.g. `arrow.tabular.RecordBatch(array1, ..., arrayN)`, `arrow.tabular.RecordBatch.fromArrays(arrays)`, etc.).
3. Create an `arrow.tabular.Schema` class. Expose this as a public property on the `RecordBatch` class. Create related `arrow.type.Field` and `arrow.type.Type` classes.
4. Create an `arrow.tabular.Table` and related `arrow.array.ChunkedArray` class.
5. Add more `arrow.array.Array` types (e.g. `StringArray`, `TimestampArray`, `Time64Array`).
6. Create a basic workflow example of serializing a `RecordBatch` to disk using an I/O function (e.g. Parquet writing).

### Notes

1. Thanks @sgilmore10 for your help with this pull request!
2. While writing the tests for `RecordBatch`, we stumbled upon a set of [accidentally committed diff markers] in `UInt64Array.m` and `tUInt64Array.m`. We removed these diff markers in this PR to unblock the `RecordBatch` tests. Unfortunately, this wasn't caught earlier because MATLAB was simply ignoring the test file `tUInt64Array.m`, which had a syntax error in it. We could explicitly list all test files in the MATLAB CI workflows to try to avoid similar situations in the future, but this might become unwieldy to maintain as we add more tests. We are happy to hear any suggestions from other community members on this topic.
* Closes: #36072
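
One lightweight alternative to listing every test file would be a check that fails when a source file still contains merge-conflict markers. The following is only a sketch of that idea, not something this PR adds: `find_conflict_markers` is a hypothetical helper, and it looks only for lines that start with `<<<<<<<` or `>>>>>>>`, the two kinds of marker removed in this commit.

```cpp
#include <cassert>
#include <cstddef>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical helper (not part of this PR or the Arrow repository).
// Returns the 1-based numbers of lines that start with a conflict marker.
std::vector<std::size_t> find_conflict_markers(const std::string& text) {
  std::vector<std::size_t> hits;
  std::istringstream stream(text);
  std::string line;
  std::size_t line_number = 0;
  while (std::getline(stream, line)) {
    ++line_number;
    // rfind(prefix, 0) == 0 is a "starts with" check that works before C++20.
    if (line.rfind("<<<<<<<", 0) == 0 || line.rfind(">>>>>>>", 0) == 0) {
      hits.push_back(line_number);
    }
  }
  return hits;
}

// Example: the stray marker sits on line 2 of this snippet.
void example_conflict_marker_scan() {
  const std::string snippet =
      "% permissions and limitations under the License.\n"
      "<<<<<<< HEAD\n"
      "\n"
      "classdef UInt64Array < arrow.array.NumericArray\n";
  const std::vector<std::size_t> hits = find_conflict_markers(snippet);
  assert(hits.size() == 1 && hits[0] == 2);
}
```

A CI job could read each `.m` file into a string, call the helper, and fail if the result is non-empty.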

Lead-authored-by: Kevin Gurney <kgurney@mathworks.com>
Co-authored-by: Kevin Gurney <kevin.p.gurney@gmail.com>
Co-authored-by: Sarah Gilmore <sgilmore@mathworks.com>
Co-authored-by: Sutou Kouhei <kou@cozmixng.org>
Signed-off-by: Sutou Kouhei <kou@clear-code.com>
3 people committed Jun 23, 2023
1 parent bd1ebec commit 382230d
Showing 12 changed files with 432 additions and 4 deletions.
4 changes: 4 additions & 0 deletions matlab/src/cpp/arrow/matlab/array/proxy/array.cc
@@ -30,6 +30,10 @@ namespace arrow::matlab::array::proxy {
REGISTER_METHOD(Array, valid);
}

std::shared_ptr<arrow::Array> Array::getArray() {
return array;
}

void Array::toString(libmexclass::proxy::method::Context& context) {
::matlab::data::ArrayFactory factory;

2 changes: 2 additions & 0 deletions matlab/src/cpp/arrow/matlab/array/proxy/array.h
@@ -29,6 +29,8 @@ class Array : public libmexclass::proxy::Proxy {

virtual ~Array() {}

std::shared_ptr<arrow::Array> getArray();

protected:

void toString(libmexclass::proxy::method::Context& context);
3 changes: 3 additions & 0 deletions matlab/src/cpp/arrow/matlab/error/error.h
@@ -37,4 +37,7 @@ namespace arrow::matlab::error {
static const char* BUILD_ARRAY_ERROR_ID = "arrow:matlab:proxy:make:FailedToAppendValues";
static const char* BITPACK_VALIDITY_BITMAP_ERROR_ID = "arrow:matlab:proxy:make:FailedToBitPackValidityBitmap";
static const char* UNKNOWN_PROXY_ERROR_ID = "arrow:matlab:proxy:UnknownProxy";
static const char* SCHEMA_BUILDER_FINISH_ERROR_ID = "arrow:matlab:tabular:proxy:SchemaBuilderFinish";
static const char* SCHEMA_BUILDER_ADD_FIELDS_ERROR_ID = "arrow:matlab:tabular:proxy:SchemaBuilderAddFields";
static const char* UNICODE_CONVERSION_ERROR_ID = "arrow:matlab:unicode:UnicodeConversion";
}
3 changes: 3 additions & 0 deletions matlab/src/cpp/arrow/matlab/proxy/factory.cc
@@ -17,6 +17,7 @@

#include "arrow/matlab/array/proxy/boolean_array.h"
#include "arrow/matlab/array/proxy/numeric_array.h"
#include "arrow/matlab/tabular/proxy/record_batch.h"
#include "arrow/matlab/error/error.h"

#include "factory.h"
@@ -39,6 +40,8 @@ libmexclass::proxy::MakeResult Factory::make_proxy(const ClassName& class_name,
REGISTER_PROXY(arrow.array.proxy.Int64Array , arrow::matlab::array::proxy::NumericArray<int64_t>);
// Register MATLAB Proxy class for boolean arrays
REGISTER_PROXY(arrow.array.proxy.BooleanArray, arrow::matlab::array::proxy::BooleanArray);

REGISTER_PROXY(arrow.tabular.proxy.RecordBatch , arrow::matlab::tabular::proxy::RecordBatch);

return libmexclass::error::Error{error::UNKNOWN_PROXY_ERROR_ID, "Did not find matching C++ proxy for " + class_name};
};
116 changes: 116 additions & 0 deletions matlab/src/cpp/arrow/matlab/tabular/proxy/record_batch.cc
@@ -0,0 +1,116 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#include "libmexclass/proxy/ProxyManager.h"

#include "arrow/matlab/array/proxy/array.h"
#include "arrow/matlab/error/error.h"
#include "arrow/matlab/tabular/proxy/record_batch.h"
#include "arrow/type.h"
#include "arrow/util/utf8.h"

namespace arrow::matlab::tabular::proxy {

RecordBatch::RecordBatch(std::shared_ptr<arrow::RecordBatch> record_batch) : record_batch{record_batch} {
REGISTER_METHOD(RecordBatch, toString);
REGISTER_METHOD(RecordBatch, numColumns);
REGISTER_METHOD(RecordBatch, columnNames);
}

void RecordBatch::toString(libmexclass::proxy::method::Context& context) {
namespace mda = ::matlab::data;
mda::ArrayFactory factory;
const auto maybe_utf16_string = arrow::util::UTF8StringToUTF16(record_batch->ToString());
// TODO: Add a helper macro to avoid having to write out an explicit if-statement here when handling errors.
if (!maybe_utf16_string.ok()) {
// TODO: This error message could probably be improved.
context.error = libmexclass::error::Error{error::UNICODE_CONVERSION_ERROR_ID, maybe_utf16_string.status().message()};
return;
}
auto str_mda = factory.createScalar(*maybe_utf16_string);
context.outputs[0] = str_mda;
}

libmexclass::proxy::MakeResult RecordBatch::make(const libmexclass::proxy::FunctionArguments& constructor_arguments) {
namespace mda = ::matlab::data;
mda::StructArray opts = constructor_arguments[0];
const mda::TypedArray<uint64_t> arrow_array_proxy_ids = opts[0]["ArrayProxyIDs"];
const mda::StringArray column_names = opts[0]["ColumnNames"];

std::vector<std::shared_ptr<arrow::Array>> arrow_arrays;
// Retrieve all of the Arrow Array Proxy instances from the libmexclass ProxyManager.
for (const auto& arrow_array_proxy_id : arrow_array_proxy_ids) {
auto proxy = libmexclass::proxy::ProxyManager::getProxy(arrow_array_proxy_id);
auto arrow_array_proxy = std::static_pointer_cast<arrow::matlab::array::proxy::Array>(proxy);
auto arrow_array = arrow_array_proxy->getArray();
arrow_arrays.push_back(arrow_array);
}

std::vector<std::shared_ptr<Field>> fields;
for (size_t i = 0; i < arrow_arrays.size(); ++i) {
const auto type = arrow_arrays[i]->type();
const auto column_name_str = std::u16string(column_names[i]);
const auto maybe_column_name_str = arrow::util::UTF16StringToUTF8(column_name_str);
MATLAB_ERROR_IF_NOT_OK(maybe_column_name_str.status(), error::UNICODE_CONVERSION_ERROR_ID);
fields.push_back(std::make_shared<arrow::Field>(*maybe_column_name_str, type));
}

arrow::SchemaBuilder schema_builder;
MATLAB_ERROR_IF_NOT_OK(schema_builder.AddFields(fields), error::SCHEMA_BUILDER_ADD_FIELDS_ERROR_ID);
auto maybe_schema = schema_builder.Finish();
MATLAB_ERROR_IF_NOT_OK(maybe_schema.status(), error::SCHEMA_BUILDER_FINISH_ERROR_ID);

const auto schema = *maybe_schema;
const auto num_rows = arrow_arrays.size() == 0 ? 0 : arrow_arrays[0]->length();
const auto record_batch = arrow::RecordBatch::Make(schema, num_rows, arrow_arrays);
auto record_batch_proxy = std::make_shared<arrow::matlab::tabular::proxy::RecordBatch>(record_batch);

return record_batch_proxy;
}

void RecordBatch::numColumns(libmexclass::proxy::method::Context& context) {
namespace mda = ::matlab::data;
mda::ArrayFactory factory;
const auto num_columns = record_batch->num_columns();
auto num_columns_mda = factory.createScalar(num_columns);
context.outputs[0] = num_columns_mda;
}

void RecordBatch::columnNames(libmexclass::proxy::method::Context& context) {
namespace mda = ::matlab::data;
mda::ArrayFactory factory;
const int num_columns = record_batch->num_columns();

std::vector<mda::MATLABString> column_names;
for (int i = 0; i < num_columns; ++i) {
const auto column_name_utf8 = record_batch->column_name(i);
auto maybe_column_name_utf16 = arrow::util::UTF8StringToUTF16(column_name_utf8);
// TODO: Add a helper macro to avoid having to write out an explicit if-statement here when handling errors.
if (!maybe_column_name_utf16.ok()) {
// TODO: This error message could probably be improved.
context.error = libmexclass::error::Error{error::UNICODE_CONVERSION_ERROR_ID, maybe_column_name_utf16.status().message()};
return;
}
auto column_name_utf16 = *maybe_column_name_utf16;
const mda::MATLABString matlab_string = mda::MATLABString(std::move(column_name_utf16));
column_names.push_back(matlab_string);
}
auto column_names_mda = factory.createArray({size_t{1}, static_cast<size_t>(num_columns)}, column_names.begin(), column_names.end());
context.outputs[0] = column_names_mda;
}

}
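
`RecordBatch::make` above takes `num_rows` from the first array. Arrays decomposed from a single MATLAB `table` should already have equal lengths, but an explicit check might be worth adding if other construction paths arrive later. The sketch below uses only the standard library: `common_length` is a hypothetical helper, and plain integers stand in for `arrow::Array::length()`.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical helper (not part of this PR).
// Returns the length shared by all columns, 0 when there are no columns,
// or std::nullopt when the columns disagree.
std::optional<std::int64_t> common_length(const std::vector<std::int64_t>& column_lengths) {
  if (column_lengths.empty()) {
    return std::int64_t{0};
  }
  const std::int64_t expected = column_lengths.front();
  for (const std::int64_t length : column_lengths) {
    if (length != expected) {
      return std::nullopt;
    }
  }
  return expected;
}

// Example: three columns of three rows each, as in the PR description.
void example_common_length() {
  const std::optional<std::int64_t> num_rows = common_length({3, 3, 3});
  assert(num_rows.has_value() && *num_rows == 3);
}
```

Arrow's C++ `RecordBatch` also has a `Validate()` method that could be called on the result of `Make`; whether that fits here is a design choice this PR does not address.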
42 changes: 42 additions & 0 deletions matlab/src/cpp/arrow/matlab/tabular/proxy/record_batch.h
@@ -0,0 +1,42 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#pragma once

#include "arrow/record_batch.h"

#include "libmexclass/proxy/Proxy.h"

namespace arrow::matlab::tabular::proxy {

class RecordBatch : public libmexclass::proxy::Proxy {
public:
RecordBatch(std::shared_ptr<arrow::RecordBatch> record_batch);

virtual ~RecordBatch() {}

static libmexclass::proxy::MakeResult make(const libmexclass::proxy::FunctionArguments& constructor_arguments);

protected:
void toString(libmexclass::proxy::method::Context& context);
void numColumns(libmexclass::proxy::method::Context& context);
void columnNames(libmexclass::proxy::method::Context& context);

std::shared_ptr<arrow::RecordBatch> record_batch;
};

}
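
`RecordBatch::make` in `record_batch.cc` looks up each array proxy by ID and downcasts it with `std::static_pointer_cast`, which assumes every ID refers to an `Array` proxy. The sketch below illustrates the same lookup-then-downcast pattern with a checked `std::dynamic_pointer_cast` instead. All types here (`Proxy`, `ArrayProxy`, `OtherProxy`, `Registry`) are hypothetical stand-ins for illustration, not the libmexclass API.

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <unordered_map>
#include <utility>

// Hypothetical stand-ins for a proxy base class and two derived proxies.
struct Proxy { virtual ~Proxy() = default; };
struct ArrayProxy : Proxy { int length = 0; };
struct OtherProxy : Proxy {};

// Hypothetical ID-to-proxy registry.
class Registry {
 public:
  // Stores the proxy and returns the ID assigned to it.
  std::uint64_t add(std::shared_ptr<Proxy> proxy) {
    const std::uint64_t id = next_id_++;
    proxies_[id] = std::move(proxy);
    return id;
  }

  // Returns the proxy for the given ID, or an empty pointer if unknown.
  std::shared_ptr<Proxy> get(std::uint64_t id) const {
    const auto it = proxies_.find(id);
    if (it == proxies_.end()) {
      return nullptr;
    }
    return it->second;
  }

 private:
  std::uint64_t next_id_ = 0;
  std::unordered_map<std::uint64_t, std::shared_ptr<Proxy>> proxies_;
};

// Returns an empty pointer when the ID is unknown or is not an ArrayProxy.
std::shared_ptr<ArrayProxy> get_array_proxy(const Registry& registry, std::uint64_t id) {
  return std::dynamic_pointer_cast<ArrayProxy>(registry.get(id));
}

void example_lookup() {
  Registry registry;
  auto array = std::make_shared<ArrayProxy>();
  array->length = 3;
  const std::uint64_t array_id = registry.add(array);
  const std::uint64_t other_id = registry.add(std::make_shared<OtherProxy>());
  assert(get_array_proxy(registry, array_id) != nullptr);
  assert(get_array_proxy(registry, other_id) == nullptr);
}
```

The checked cast costs a runtime type lookup per column, which is the trade-off against the unchecked cast used in the PR.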
3 changes: 1 addition & 2 deletions matlab/src/matlab/+arrow/+array/Array.m
@@ -17,8 +17,7 @@
% implied. See the License for the specific language governing
% permissions and limitations under the License.


-properties (Access=protected)
+properties (GetAccess=public, SetAccess=private, Hidden)
Proxy
end

1 change: 0 additions & 1 deletion matlab/src/matlab/+arrow/+array/UInt64Array.m
@@ -12,7 +12,6 @@
% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
% implied. See the License for the specific language governing
% permissions and limitations under the License.
-<<<<<<< HEAD

classdef UInt64Array < arrow.array.NumericArray
% arrow.array.UInt64Array
158 changes: 158 additions & 0 deletions matlab/src/matlab/+arrow/+tabular/RecordBatch.m
@@ -0,0 +1,158 @@
% Licensed to the Apache Software Foundation (ASF) under one or more
% contributor license agreements. See the NOTICE file distributed with
% this work for additional information regarding copyright ownership.
% The ASF licenses this file to you under the Apache License, Version
% 2.0 (the "License"); you may not use this file except in compliance
% with the License. You may obtain a copy of the License at
%
% http://www.apache.org/licenses/LICENSE-2.0
%
% Unless required by applicable law or agreed to in writing, software
% distributed under the License is distributed on an "AS IS" BASIS,
% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
% implied. See the License for the specific language governing
% permissions and limitations under the License.

classdef RecordBatch < matlab.mixin.CustomDisplay & ...
matlab.mixin.Scalar
%arrow.tabular.RecordBatch A tabular data structure representing
% a set of arrow.array.Array objects with a fixed schema.

properties (Access=private)
ArrowArrays = {};
end

properties (Dependent, SetAccess=private, GetAccess=public)
NumColumns
ColumnNames
end

properties (Access=protected)
Proxy
end

methods

function numColumns = get.NumColumns(obj)
numColumns = obj.Proxy.numColumns();
end

function columnNames = get.ColumnNames(obj)
columnNames = obj.Proxy.columnNames();
end

function arrowArray = column(obj, idx)
arrowArray = obj.ArrowArrays{idx};
end

function obj = RecordBatch(T)
obj.ArrowArrays = arrow.tabular.RecordBatch.decompose(T);
columnNames = string(T.Properties.VariableNames);
arrayProxyIDs = arrow.tabular.RecordBatch.getArrowProxyIDs(obj.ArrowArrays);
opts = struct("ArrayProxyIDs", arrayProxyIDs, ...
"ColumnNames", columnNames);
obj.Proxy = libmexclass.proxy.Proxy("Name", "arrow.tabular.proxy.RecordBatch", "ConstructorArguments", {opts});
end

function T = table(obj)
matlabArrays = cell(1, numel(obj.ArrowArrays));

for ii = 1:numel(obj.ArrowArrays)
matlabArrays{ii} = toMATLAB(obj.ArrowArrays{ii});
end

variableNames = matlab.lang.makeUniqueStrings(obj.ColumnNames);
% NOTE: Does not currently handle edge cases like ColumnNames
% matching the table DimensionNames.
T = table(matlabArrays{:}, VariableNames=variableNames);
end

function T = toMATLAB(obj)
T = obj.table();
end

end

methods (Static)

function arrowArrays = decompose(T)
% Decompose the input MATLAB table
% into a cell array of equivalent arrow.array.Array
% instances.
arguments
T table
end

numColumns = width(T);
arrowArrays = cell(1, numColumns);

% Convert each MATLAB array into a corresponding
% arrow.array.Array.
for ii = 1:numColumns
arrowArrays{ii} = arrow.tabular.RecordBatch.makeArray(T{:, ii});
end
end

function arrowArray = makeArray(matlabArray)
% Convert the input MATLAB array into the equivalent
% arrow.array.Array instance, based on the class of
% the MATLAB array.

switch class(matlabArray)
case "single"
arrowArray = arrow.array.Float32Array(matlabArray);
case "double"
arrowArray = arrow.array.Float64Array(matlabArray);
case "uint8"
arrowArray = arrow.array.UInt8Array(matlabArray);
case "uint16"
arrowArray = arrow.array.UInt16Array(matlabArray);
case "uint32"
arrowArray = arrow.array.UInt32Array(matlabArray);
case "uint64"
arrowArray = arrow.array.UInt64Array(matlabArray);
case "int8"
arrowArray = arrow.array.Int8Array(matlabArray);
case "int16"
arrowArray = arrow.array.Int16Array(matlabArray);
case "int32"
arrowArray = arrow.array.Int32Array(matlabArray);
case "int64"
arrowArray = arrow.array.Int64Array(matlabArray);
case "logical"
arrowArray = arrow.array.BooleanArray(matlabArray);
otherwise
error("arrow:tabular:recordbatch:UnsupportedMatlabArrayType", ...
"RecordBatch cannot be constructed from a MATLAB array of type '" + class(matlabArray) + "'.");
end

end

function proxyIDs = getArrowProxyIDs(arrowArrays)
% Extract the Proxy IDs underlying a cell array of
% arrow.array.Array instances.
proxyIDs = zeros(1, numel(arrowArrays), "uint64");

% Record the Proxy ID of each arrow.array.Array
% instance.
for ii = 1:numel(arrowArrays)
proxyIDs(ii) = arrowArrays{ii}.Proxy.ID;
end
end

end

methods (Access = private)
function str = toString(obj)
str = obj.Proxy.toString();
end
end

methods (Access=protected)
function displayScalarObject(obj)
disp(obj.toString());
end
end

end
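
The `switch` in `makeArray` is effectively a lookup from a MATLAB class name to an Arrow array class. As an illustration only, the same mapping can be written as a table, which makes the set of supported types easy to scan. `arrow_array_class_for` is a hypothetical name, and nothing in this PR implements the mapping in C++.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <unordered_map>

// Illustrative only: mirrors the cases of the switch statement in makeArray.
// Returns std::nullopt for MATLAB classes the switch would reject.
std::optional<std::string> arrow_array_class_for(const std::string& matlab_class) {
  static const std::unordered_map<std::string, std::string> kArrayClasses = {
      {"single", "arrow.array.Float32Array"},
      {"double", "arrow.array.Float64Array"},
      {"uint8", "arrow.array.UInt8Array"},
      {"uint16", "arrow.array.UInt16Array"},
      {"uint32", "arrow.array.UInt32Array"},
      {"uint64", "arrow.array.UInt64Array"},
      {"int8", "arrow.array.Int8Array"},
      {"int16", "arrow.array.Int16Array"},
      {"int32", "arrow.array.Int32Array"},
      {"int64", "arrow.array.Int64Array"},
      {"logical", "arrow.array.BooleanArray"},
  };
  const auto it = kArrayClasses.find(matlab_class);
  if (it == kArrayClasses.end()) {
    return std::nullopt;
  }
  return it->second;
}

// Example: a logical column maps to BooleanArray; string is not yet supported.
void example_array_class_lookup() {
  assert(arrow_array_class_for("logical") == std::string("arrow.array.BooleanArray"));
  assert(!arrow_array_class_for("string").has_value());
}
```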

1 change: 0 additions & 1 deletion matlab/test/arrow/array/tUInt64Array.m
@@ -1,5 +1,4 @@
% Licensed to the Apache Software Foundation (ASF) under one or more
->>>>>>> b27d47fde (Add abstract NumericArray class)
% contributor license agreements. See the NOTICE file distributed with
% this work for additional information regarding copyright ownership.
% The ASF licenses this file to you under the Apache License, Version