From b101914cea0688f2969dd6bb2823cd340b02b243 Mon Sep 17 00:00:00 2001 From: Nikhil Kak Date: Tue, 10 Apr 2018 16:40:49 -0700 Subject: [PATCH] MLP: Check for 1-hot encoding of dependent variable for minibatch This commit adds a check to make sure that the dependent variable for mlp minibatch is one hot encoded. This only validates that the dependent variable array has more than 1 value. Co-authored-by: Orhan Kislal --- .../postgres/modules/convex/mlp_igd.py_in | 3 ++ .../utilities/minibatch_validation.py_in | 29 +++++++++++++++++++ 2 files changed, 32 insertions(+) create mode 100644 src/ports/postgres/modules/utilities/minibatch_validation.py_in diff --git a/src/ports/postgres/modules/convex/mlp_igd.py_in b/src/ports/postgres/modules/convex/mlp_igd.py_in index 2799355a2..6ff3d8681 100644 --- a/src/ports/postgres/modules/convex/mlp_igd.py_in +++ b/src/ports/postgres/modules/convex/mlp_igd.py_in @@ -52,6 +52,7 @@ from utilities.validate_args import input_tbl_valid from utilities.validate_args import is_var_valid from utilities.validate_args import output_tbl_valid from utilities.validate_args import table_exists +from utilities.minibatch_validation import is_var_one_hot_encoded_for_minibatch def mlp(schema_madlib, source_table, output_table, independent_varname, dependent_varname, hidden_layer_sizes, optimizer_param_str, activation, @@ -681,6 +682,8 @@ def _validate_dependent_var(source_table, dependent_varname, # strip out '[]' from expr_type _assert(is_psql_numeric_type(expr_type[:-2]), "Dependent variable column should be of numeric type.") + if is_classification: + is_var_one_hot_encoded_for_minibatch(source_table,dependent_varname) else: if is_classification: _assert(("[]" in expr_type \ diff --git a/src/ports/postgres/modules/utilities/minibatch_validation.py_in b/src/ports/postgres/modules/utilities/minibatch_validation.py_in new file mode 100644 index 000000000..16b11a97c --- /dev/null +++ b/src/ports/postgres/modules/utilities/minibatch_validation.py_in @@ -0,0 +1,29 @@ +# coding=utf-8 +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import plpy + +def is_var_one_hot_encoded_for_minibatch(table_name, var_name): + query = """SELECT array_upper({var_name}, 2) > 1 AS is_encoded FROM + {table_name} LIMIT 1;""".format(**locals()) + result = plpy.execute(query) + if not result[0]["is_encoded"]: + plpy.error("MiniBatch expects the variable {0} to be one hot encoded." + " You might need to re run the minibatch_preprocessor function" + " and make sure that the variable is encoded".format(var_name))