GH-37571: [MATLAB] Add arrow.tabular.Table MATLAB class (#37620)
### Rationale for this change

Following on from #37525, which adds `arrow.array.ChunkedArray` to the MATLAB interface, this pull request adds support for a new `arrow.tabular.Table` MATLAB class.

This pull request is intended to be an initial implementation of `Table` support and does not include all methods or properties that may be useful on `arrow.tabular.Table`.

### What changes are included in this PR?

1. Added new `arrow.tabular.Table` MATLAB class.

**Properties**

* `NumRows`
* `NumColumns`
* `ColumnNames`
* `Schema`

**Methods**

* `fromArrays(<array-1>, ..., <array-N>)`
* `column(<index>)`
* `table()`
* `toMATLAB()`

**Example of `arrow.tabular.Table.fromArrays(<array-1>, ..., <array-N>)` static construction method**
```matlab
>> arrowTable = arrow.tabular.Table.fromArrays(arrow.array([1, 2, 3]), arrow.array(["A", "B", "C"]), arrow.array([true, false, true]))

arrowTable = 

Column1: double
Column2: string
Column3: bool
----
Column1:
  [
    [
      1,
      2,
      3
    ]
  ]
Column2:
  [
    [
      "A",
      "B",
      "C"
    ]
  ]
Column3:
  [
    [
      true,
      false,
      true
    ]
  ]

>> matlabTable = table(arrowTable)

matlabTable =

  3×3 table

    Column1    Column2    Column3
    _______    _______    _______

       1         "A"       true  
       2         "B"       false 
       3         "C"       true  
```

2. Added a new `arrow.table(<matlab-table>)` construction function which creates an `arrow.tabular.Table` from a MATLAB `table`. 

**Example of `arrow.table(<matlab-table>)` construction function**
```matlab
>> matlabTable = table([1; 2; 3], ["A"; "B"; "C"], [true; false; true])

matlabTable =

  3×3 table

    Var1    Var2    Var3 
    ____    ____    _____

     1      "A"     true 
     2      "B"     false
     3      "C"     true 

>> arrowTable = arrow.table(matlabTable)

arrowTable = 

Var1: double
Var2: string
Var3: bool
----
Var1:
  [
    [
      1,
      2,
      3
    ]
  ]
Var2:
  [
    [
      "A",
      "B",
      "C"
    ]
  ]
Var3:
  [
    [
      true,
      false,
      true
    ]
  ]

>> arrowTable.NumRows

ans =

  int64

   3

>> arrowTable.NumColumns

ans =

  int32

   3

>> arrowTable.ColumnNames

ans = 

  1×3 string array

    "Var1"    "Var2"    "Var3"

>> arrowTable.Schema

ans = 

Var1: double
Var2: string
Var3: bool

>> table(arrowTable)

ans =

  3×3 table

    Var1    Var2    Var3 
    ____    ____    _____

     1      "A"     true 
     2      "B"     false
     3      "C"     true 

>> isequal(ans, matlabTable)

ans =

  logical

   1
```

### Are these changes tested?

Yes.

1. Added a new `tTable` test class for `arrow.tabular.Table` and `arrow.table(<matlab-table>)` tests.

### Are there any user-facing changes?

Yes.

1. Users can now create `arrow.tabular.Table` objects using the `fromArrays` static construction method or the `arrow.table(<matlab-table>)` construction function.

### Future Directions

1. Create shared test infrastructure for common `RecordBatch` and `Table` MATLAB tests.
2. Implement equality check (i.e. `isequal`) for `arrow.tabular.Table` instances.
3. Add more static construction methods to `arrow.tabular.Table`. For example: `fromChunkedArrays(<chunkedArray-1>, ..., <chunkedArray-N>)` and `fromRecordBatches(<recordBatch-1>, ..., <recordBatch-N>)`.

### Notes

1. A lot of the code for `arrow.tabular.Table` is very similar to the code for `arrow.tabular.RecordBatch`. It may make sense for us to try to share more of the code using C++ templates or another approach.
2. Thank you @sgilmore10 for your help with this pull request!
* Closes: #37571
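
As a rough illustration of the code-sharing idea in Note 1, column index validation could be written once as a template that works for any tabular type exposing `num_columns()`. This is only a sketch under stated assumptions: `FakeRecordBatch`, `FakeTable`, and `toZeroBasedColumnIndex` are hypothetical names that do not appear in this pull request, and the real proxies would report errors through `libmexclass` rather than `std::optional`.

```cpp
#include <cstdint>
#include <optional>

// Hypothetical stand-ins for the two tabular types; only num_columns() matters here.
struct FakeRecordBatch {
    int32_t columns;
    int32_t num_columns() const { return columns; }
};

struct FakeTable {
    int32_t columns;
    int32_t num_columns() const { return columns; }
};

// Converts a 1-based MATLAB column index into a 0-based index.
// Returns std::nullopt when the index is out of range, which also
// covers the case of a tabular object with no columns.
template <typename Tabular>
std::optional<int32_t> toZeroBasedColumnIndex(const Tabular& tabular, int32_t matlab_index) {
    const int32_t num_columns = tabular.num_columns();
    if (matlab_index < 1 || matlab_index > num_columns) {
        return std::nullopt;
    }
    return matlab_index - 1;
}
```

With a helper of this shape, the `RecordBatch` and `Table` proxies would differ only in which error identifier they attach when the result is empty.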

Lead-authored-by: Kevin Gurney <kgurney@mathworks.com>
Co-authored-by: Sarah Gilmore <sgilmore@mathworks.com>
Signed-off-by: Kevin Gurney <kgurney@mathworks.com>
kevingurney and sgilmore10 committed Sep 7, 2023
1 parent cdc95c2 commit da602af
Showing 8 changed files with 1,068 additions and 0 deletions.
2 changes: 2 additions & 0 deletions matlab/src/cpp/arrow/matlab/error/error.h
@@ -181,6 +181,8 @@ namespace arrow::matlab::error {
static const char* UNKNOWN_PROXY_FOR_ARRAY_TYPE = "arrow:array:UnknownProxyForArrayType";
static const char* RECORD_BATCH_NUMERIC_INDEX_WITH_EMPTY_RECORD_BATCH = "arrow:tabular:recordbatch:NumericIndexWithEmptyRecordBatch";
static const char* RECORD_BATCH_INVALID_NUMERIC_COLUMN_INDEX = "arrow:tabular:recordbatch:InvalidNumericColumnIndex";
static const char* TABLE_NUMERIC_INDEX_WITH_EMPTY_TABLE = "arrow:tabular:table:NumericIndexWithEmptyTable";
static const char* TABLE_INVALID_NUMERIC_COLUMN_INDEX = "arrow:tabular:table:InvalidNumericColumnIndex";
static const char* FAILED_TO_OPEN_FILE_FOR_WRITE = "arrow:io:FailedToOpenFileForWrite";
static const char* FAILED_TO_OPEN_FILE_FOR_READ = "arrow:io:FailedToOpenFileForRead";
static const char* FEATHER_FAILED_TO_WRITE_TABLE = "arrow:io:feather:FailedToWriteTable";
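
The two new identifiers above are paired with human-readable messages elsewhere in this commit. A minimal, self-contained sketch of that pairing is shown here; `IndexError` and `makeInvalidColumnIndexError` are hypothetical names used only for illustration, standing in for `libmexclass::error::Error` and the helper in `table.cc`.

```cpp
#include <cstdint>
#include <sstream>
#include <string>

// Hypothetical error value: an identifier plus a human-readable message.
struct IndexError {
    std::string id;
    std::string message;
};

// Builds the error reported when a 1-based column index is out of range.
IndexError makeInvalidColumnIndexError(int32_t matlab_index, int32_t num_columns) {
    std::ostringstream stream;
    stream << "Invalid column index: " << matlab_index
           << ". Column index must be between 1 and the number of columns ("
           << num_columns << ").";
    return IndexError{"arrow:tabular:table:InvalidNumericColumnIndex", stream.str()};
}
```

Keeping the identifier separate from the message lets MATLAB callers match on the identifier while the wording of the message remains free to change.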
2 changes: 2 additions & 0 deletions matlab/src/cpp/arrow/matlab/proxy/factory.cc
@@ -23,6 +23,7 @@
#include "arrow/matlab/array/proxy/time64_array.h"
#include "arrow/matlab/array/proxy/chunked_array.h"
#include "arrow/matlab/tabular/proxy/record_batch.h"
#include "arrow/matlab/tabular/proxy/table.h"
#include "arrow/matlab/tabular/proxy/schema.h"
#include "arrow/matlab/error/error.h"
#include "arrow/matlab/type/proxy/primitive_ctype.h"
@@ -60,6 +61,7 @@ libmexclass::proxy::MakeResult Factory::make_proxy(const ClassName& class_name,
REGISTER_PROXY(arrow.array.proxy.Date64Array , arrow::matlab::array::proxy::NumericArray<arrow::Date64Type>);
REGISTER_PROXY(arrow.array.proxy.ChunkedArray , arrow::matlab::array::proxy::ChunkedArray);
REGISTER_PROXY(arrow.tabular.proxy.RecordBatch , arrow::matlab::tabular::proxy::RecordBatch);
REGISTER_PROXY(arrow.tabular.proxy.Table , arrow::matlab::tabular::proxy::Table);
REGISTER_PROXY(arrow.tabular.proxy.Schema , arrow::matlab::tabular::proxy::Schema);
REGISTER_PROXY(arrow.type.proxy.Field , arrow::matlab::type::proxy::Field);
REGISTER_PROXY(arrow.type.proxy.Float32Type , arrow::matlab::type::proxy::PrimitiveCType<float>);
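
Conceptually, each `REGISTER_PROXY` line associates a MATLAB-side class name with the C++ proxy type to construct. The sketch below shows one way such a name-to-constructor lookup could be written. It is not the `libmexclass` implementation, which this commit does not show; `Proxy`, `RecordBatchProxy`, `TableProxy`, `proxyRegistry`, and `makeProxy` are all hypothetical names.

```cpp
#include <functional>
#include <map>
#include <memory>
#include <string>

// Hypothetical minimal proxy hierarchy, for illustration only.
struct Proxy {
    virtual ~Proxy() = default;
    virtual std::string name() const = 0;
};

struct RecordBatchProxy : Proxy {
    std::string name() const override { return "RecordBatch"; }
};

struct TableProxy : Proxy {
    std::string name() const override { return "Table"; }
};

using ProxyMaker = std::function<std::shared_ptr<Proxy>()>;

// Maps a MATLAB-side class name to a function that builds the C++ proxy.
inline const std::map<std::string, ProxyMaker>& proxyRegistry() {
    static const std::map<std::string, ProxyMaker> registry = {
        {"arrow.tabular.proxy.RecordBatch",
         []() -> std::shared_ptr<Proxy> { return std::make_shared<RecordBatchProxy>(); }},
        {"arrow.tabular.proxy.Table",
         []() -> std::shared_ptr<Proxy> { return std::make_shared<TableProxy>(); }},
    };
    return registry;
}

// Returns nullptr when no proxy is registered under class_name.
std::shared_ptr<Proxy> makeProxy(const std::string& class_name) {
    const auto& registry = proxyRegistry();
    const auto it = registry.find(class_name);
    if (it == registry.end()) {
        return nullptr;
    }
    return it->second();
}
```

Under this scheme, supporting a new class such as `Table` amounts to adding one entry to the registry, which mirrors the one-line change in this hunk.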
215 changes: 215 additions & 0 deletions matlab/src/cpp/arrow/matlab/tabular/proxy/table.cc
@@ -0,0 +1,215 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#include "libmexclass/proxy/ProxyManager.h"

#include "arrow/matlab/array/proxy/array.h"
#include "arrow/matlab/array/proxy/chunked_array.h"
#include "arrow/matlab/array/proxy/wrap.h"

#include "arrow/matlab/error/error.h"
#include "arrow/matlab/tabular/proxy/table.h"
#include "arrow/matlab/tabular/proxy/schema.h"
#include "arrow/type.h"
#include "arrow/util/utf8.h"

#include "libmexclass/proxy/ProxyManager.h"
#include "libmexclass/error/Error.h"

namespace arrow::matlab::tabular::proxy {

namespace {
    libmexclass::error::Error makeEmptyTableError() {
        const std::string error_msg = "Numeric indexing using the column method is not supported for tables with no columns.";
        return libmexclass::error::Error{error::TABLE_NUMERIC_INDEX_WITH_EMPTY_TABLE, error_msg};
    }

    libmexclass::error::Error makeInvalidNumericIndexError(const int32_t matlab_index, const int32_t num_columns) {
        std::stringstream error_message_stream;
        error_message_stream << "Invalid column index: ";
        error_message_stream << matlab_index;
        error_message_stream << ". Column index must be between 1 and the number of columns (";
        error_message_stream << num_columns;
        error_message_stream << ").";
        return libmexclass::error::Error{error::TABLE_INVALID_NUMERIC_COLUMN_INDEX, error_message_stream.str()};
    }
}

Table::Table(std::shared_ptr<arrow::Table> table) : table{table} {
    REGISTER_METHOD(Table, toString);
    REGISTER_METHOD(Table, getNumRows);
    REGISTER_METHOD(Table, getNumColumns);
    REGISTER_METHOD(Table, getColumnNames);
    REGISTER_METHOD(Table, getSchema);
    REGISTER_METHOD(Table, getColumnByIndex);
    REGISTER_METHOD(Table, getColumnByName);
}

std::shared_ptr<arrow::Table> Table::unwrap() {
    return table;
}

void Table::toString(libmexclass::proxy::method::Context& context) {
    namespace mda = ::matlab::data;
    MATLAB_ASSIGN_OR_ERROR_WITH_CONTEXT(const auto utf16_string, arrow::util::UTF8StringToUTF16(table->ToString()), context, error::UNICODE_CONVERSION_ERROR_ID);
    mda::ArrayFactory factory;
    auto str_mda = factory.createScalar(utf16_string);
    context.outputs[0] = str_mda;
}

libmexclass::proxy::MakeResult Table::make(const libmexclass::proxy::FunctionArguments& constructor_arguments) {
    using ArrayProxy = arrow::matlab::array::proxy::Array;
    using TableProxy = arrow::matlab::tabular::proxy::Table;
    namespace mda = ::matlab::data;
    mda::StructArray opts = constructor_arguments[0];
    const mda::TypedArray<uint64_t> arrow_array_proxy_ids = opts[0]["ArrayProxyIDs"];
    const mda::StringArray column_names = opts[0]["ColumnNames"];

    std::vector<std::shared_ptr<arrow::Array>> arrow_arrays;
    // Retrieve all of the Arrow Array Proxy instances from the libmexclass ProxyManager.
    for (const auto& arrow_array_proxy_id : arrow_array_proxy_ids) {
        auto proxy = libmexclass::proxy::ProxyManager::getProxy(arrow_array_proxy_id);
        auto arrow_array_proxy = std::static_pointer_cast<ArrayProxy>(proxy);
        auto arrow_array = arrow_array_proxy->unwrap();
        arrow_arrays.push_back(arrow_array);
    }

    std::vector<std::shared_ptr<Field>> fields;
    for (size_t i = 0; i < arrow_arrays.size(); ++i) {
        const auto type = arrow_arrays[i]->type();
        const auto column_name_utf16 = std::u16string(column_names[i]);
        MATLAB_ASSIGN_OR_ERROR(const auto column_name_utf8, arrow::util::UTF16StringToUTF8(column_name_utf16), error::UNICODE_CONVERSION_ERROR_ID);
        fields.push_back(std::make_shared<arrow::Field>(column_name_utf8, type));
    }

    arrow::SchemaBuilder schema_builder;
    MATLAB_ERROR_IF_NOT_OK(schema_builder.AddFields(fields), error::SCHEMA_BUILDER_ADD_FIELDS_ERROR_ID);
    MATLAB_ASSIGN_OR_ERROR(const auto schema, schema_builder.Finish(), error::SCHEMA_BUILDER_FINISH_ERROR_ID);
    const auto num_rows = arrow_arrays.size() == 0 ? 0 : arrow_arrays[0]->length();
    const auto table = arrow::Table::Make(schema, arrow_arrays, num_rows);
    auto table_proxy = std::make_shared<TableProxy>(table);

    return table_proxy;
}

void Table::getNumRows(libmexclass::proxy::method::Context& context) {
    namespace mda = ::matlab::data;
    mda::ArrayFactory factory;
    const auto num_rows = table->num_rows();
    auto num_rows_mda = factory.createScalar(num_rows);
    context.outputs[0] = num_rows_mda;
}

void Table::getNumColumns(libmexclass::proxy::method::Context& context) {
    namespace mda = ::matlab::data;
    mda::ArrayFactory factory;
    const auto num_columns = table->num_columns();
    auto num_columns_mda = factory.createScalar(num_columns);
    context.outputs[0] = num_columns_mda;
}

void Table::getColumnNames(libmexclass::proxy::method::Context& context) {
    namespace mda = ::matlab::data;
    mda::ArrayFactory factory;
    const int num_columns = table->num_columns();

    std::vector<mda::MATLABString> column_names;
    const auto schema = table->schema();
    const auto field_names = schema->field_names();
    for (int i = 0; i < num_columns; ++i) {
        const auto column_name_utf8 = field_names[i];
        MATLAB_ASSIGN_OR_ERROR_WITH_CONTEXT(auto column_name_utf16, arrow::util::UTF8StringToUTF16(column_name_utf8), context, error::UNICODE_CONVERSION_ERROR_ID);
        const mda::MATLABString matlab_string = mda::MATLABString(std::move(column_name_utf16));
        column_names.push_back(matlab_string);
    }
    auto column_names_mda = factory.createArray({size_t{1}, static_cast<size_t>(num_columns)}, column_names.begin(), column_names.end());
    context.outputs[0] = column_names_mda;
}

void Table::getSchema(libmexclass::proxy::method::Context& context) {
    namespace mda = ::matlab::data;
    using namespace libmexclass::proxy;
    using SchemaProxy = arrow::matlab::tabular::proxy::Schema;
    mda::ArrayFactory factory;

    const auto schema = table->schema();
    const auto schema_proxy = std::make_shared<SchemaProxy>(std::move(schema));
    const auto schema_proxy_id = ProxyManager::manageProxy(schema_proxy);
    const auto schema_proxy_id_mda = factory.createScalar(schema_proxy_id);

    context.outputs[0] = schema_proxy_id_mda;
}

void Table::getColumnByIndex(libmexclass::proxy::method::Context& context) {
    using ChunkedArrayProxy = arrow::matlab::array::proxy::ChunkedArray;
    namespace mda = ::matlab::data;
    using namespace libmexclass::proxy;
    mda::ArrayFactory factory;

    mda::StructArray args = context.inputs[0];
    const mda::TypedArray<int32_t> index_mda = args[0]["Index"];
    const auto matlab_index = int32_t(index_mda[0]);

    // Note: MATLAB uses 1-based indexing, so subtract 1.
    // The index is validated below because table->column(index)
    // is not relied upon to do any bounds checking.
    const int32_t index = matlab_index - 1;
    const auto num_columns = table->num_columns();

    if (num_columns == 0) {
        context.error = makeEmptyTableError();
        return;
    }

    if (matlab_index < 1 || matlab_index > num_columns) {
        context.error = makeInvalidNumericIndexError(matlab_index, num_columns);
        return;
    }

    const auto chunked_array = table->column(index);
    const auto chunked_array_proxy = std::make_shared<ChunkedArrayProxy>(chunked_array);

    const auto chunked_array_proxy_id = ProxyManager::manageProxy(chunked_array_proxy);
    const auto chunked_array_proxy_id_mda = factory.createScalar(chunked_array_proxy_id);

    context.outputs[0] = chunked_array_proxy_id_mda;
}

void Table::getColumnByName(libmexclass::proxy::method::Context& context) {
    using ChunkedArrayProxy = arrow::matlab::array::proxy::ChunkedArray;
    namespace mda = ::matlab::data;
    using namespace libmexclass::proxy;
    mda::ArrayFactory factory;

    mda::StructArray args = context.inputs[0];
    const mda::StringArray name_mda = args[0]["Name"];
    const auto name_utf16 = std::u16string(name_mda[0]);
    MATLAB_ASSIGN_OR_ERROR_WITH_CONTEXT(const auto name, arrow::util::UTF16StringToUTF8(name_utf16), context, error::UNICODE_CONVERSION_ERROR_ID);

    const std::vector<std::string> names = {name};
    const auto& schema = table->schema();
    MATLAB_ERROR_IF_NOT_OK_WITH_CONTEXT(schema->CanReferenceFieldsByNames(names), context, error::ARROW_TABULAR_SCHEMA_AMBIGUOUS_FIELD_NAME);

    const auto chunked_array = table->GetColumnByName(name);
    const auto chunked_array_proxy = std::make_shared<ChunkedArrayProxy>(chunked_array);

    const auto chunked_array_proxy_id = ProxyManager::manageProxy(chunked_array_proxy);
    const auto chunked_array_proxy_id_mda = factory.createScalar(chunked_array_proxy_id);

    context.outputs[0] = chunked_array_proxy_id_mda;
}

}
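
In `Table::make` above, the row count is taken from the first array and the remaining arrays are assumed to have the same length. The sketch below shows one way the lengths could be checked before the table is built. It is illustrative only and not part of this commit: `commonNumRows` is a hypothetical helper, and it operates on plain lengths rather than `arrow::Array` objects so that it stays self-contained.

```cpp
#include <cstdint>
#include <optional>
#include <vector>

// Returns the row count shared by all columns, 0 when there are no columns,
// or std::nullopt when the column lengths disagree.
std::optional<int64_t> commonNumRows(const std::vector<int64_t>& column_lengths) {
    if (column_lengths.empty()) {
        return int64_t{0};
    }
    const int64_t num_rows = column_lengths.front();
    for (const int64_t length : column_lengths) {
        if (length != num_rows) {
            return std::nullopt;
        }
    }
    return num_rows;
}
```

A caller could collect `arrow_array->length()` for each input array, pass the lengths to a helper like this, and report an error instead of constructing the table when the result is empty.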
48 changes: 48 additions & 0 deletions matlab/src/cpp/arrow/matlab/tabular/proxy/table.h
@@ -0,0 +1,48 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#pragma once

#include "arrow/table.h"

#include "libmexclass/proxy/Proxy.h"

namespace arrow::matlab::tabular::proxy {

class Table : public libmexclass::proxy::Proxy {
    public:
        Table(std::shared_ptr<arrow::Table> table);

        virtual ~Table() {}

        std::shared_ptr<arrow::Table> unwrap();

        static libmexclass::proxy::MakeResult make(const libmexclass::proxy::FunctionArguments& constructor_arguments);

    protected:
        void toString(libmexclass::proxy::method::Context& context);
        void getNumRows(libmexclass::proxy::method::Context& context);
        void getNumColumns(libmexclass::proxy::method::Context& context);
        void getColumnNames(libmexclass::proxy::method::Context& context);
        void getSchema(libmexclass::proxy::method::Context& context);
        void getColumnByIndex(libmexclass::proxy::method::Context& context);
        void getColumnByName(libmexclass::proxy::method::Context& context);

        std::shared_ptr<arrow::Table> table;
};

}
