# Model Selection for Multilayer Perceptron Using Keras and MADlib

End-to-end (E2E) classification example in which MADlib trains Keras multilayer perceptrons (MLPs) across different hyperparameters and model architectures.

Deep learning works best on very large datasets, but those are not convenient for a quick introduction to the syntax.  So in this workbook we use the well-known iris data set from https://archive.ics.uci.edu/ml/datasets/iris to help get you started.  It is similar to the example in the MADlib user documentation at http://madlib.apache.org/docs/latest/index.html

For more realistic examples, please refer to the deep learning notebooks at https://github.com/apache/madlib-site/tree/asf-site/community-artifacts

## Table of contents

<a href="#class">Classification</a>

* <a href="#create_input_data">1. Create input data</a>

* <a href="#pp">2. Call preprocessor for deep learning</a>

* <a href="#load">3. Define and load model architecture</a>

* <a href="#def_mst">4. Define and load model selection tuples</a>

* <a href="#train">5. Train</a>

* <a href="#eval">6. Evaluate</a>

* <a href="#pred">7. Predict</a>

<a href="#class2">Classification with Other Parameters</a>

* <a href="#val_dataset">1. Validation dataset</a>

* <a href="#pred_prob">2. Predict probabilities</a>

* <a href="#warm_start">3. Warm start</a>

In [1]:
# Load the ipython-sql extension, which provides the %sql / %%sql magics used throughout this notebook.
%load_ext sql

In [2]:
# Greenplum Database 5.x on GCP - via tunnel
%sql postgresql://gpadmin@localhost:8000/madlib
        
# PostgreSQL local
#%sql postgresql://fmcquillan@localhost:5432/madlib

In [3]:
%sql select madlib.version();
#%sql select version();

1 rows affected.


version
"MADlib version: 1.18.0-dev, git revision: rel/v1.17.0-89-g14a91ce, cmake configuration time: Fri Mar 5 23:08:38 UTC 2021, build type: release, build system: Linux-3.10.0-1160.11.1.el7.x86_64, C compiler: gcc 4.8.5, C++ compiler: g++ 4.8.5"


<a id="class"></a>
# Classification

<a id="create_input_data"></a>
# 1.  Create input data

Load the iris data set:

In [4]:
%%sql
-- Recreate the table that holds the raw iris samples.
DROP TABLE IF EXISTS iris_data;

CREATE TABLE iris_data (
    id         serial,     -- row identifier (the INSERT below supplies the values explicitly)
    attributes numeric[],  -- the 4 numeric measurements of one flower
    class_text varchar     -- species label, e.g. 'Iris-setosa'
);

-- Load the 150 iris samples: ids 1-50 are Iris-setosa, 51-100 Iris-versicolor,
-- 101-150 Iris-virginica. Each attributes array holds 4 numeric measurements
-- (presumably sepal length, sepal width, petal length, petal width in cm, per the
-- UCI iris description -- confirm).
-- NOTE(review): ids are given explicitly, so the SERIAL sequence behind id is not
-- advanced; a later INSERT relying on the default would start again at id 1.
INSERT INTO iris_data(id, attributes, class_text) VALUES
(1,ARRAY[5.1,3.5,1.4,0.2],'Iris-setosa'),
(2,ARRAY[4.9,3.0,1.4,0.2],'Iris-setosa'),
(3,ARRAY[4.7,3.2,1.3,0.2],'Iris-setosa'),
(4,ARRAY[4.6,3.1,1.5,0.2],'Iris-setosa'),
(5,ARRAY[5.0,3.6,1.4,0.2],'Iris-setosa'),
(6,ARRAY[5.4,3.9,1.7,0.4],'Iris-setosa'),
(7,ARRAY[4.6,3.4,1.4,0.3],'Iris-setosa'),
(8,ARRAY[5.0,3.4,1.5,0.2],'Iris-setosa'),
(9,ARRAY[4.4,2.9,1.4,0.2],'Iris-setosa'),
(10,ARRAY[4.9,3.1,1.5,0.1],'Iris-setosa'),
(11,ARRAY[5.4,3.7,1.5,0.2],'Iris-setosa'),
(12,ARRAY[4.8,3.4,1.6,0.2],'Iris-setosa'),
(13,ARRAY[4.8,3.0,1.4,0.1],'Iris-setosa'),
(14,ARRAY[4.3,3.0,1.1,0.1],'Iris-setosa'),
(15,ARRAY[5.8,4.0,1.2,0.2],'Iris-setosa'),
(16,ARRAY[5.7,4.4,1.5,0.4],'Iris-setosa'),
(17,ARRAY[5.4,3.9,1.3,0.4],'Iris-setosa'),
(18,ARRAY[5.1,3.5,1.4,0.3],'Iris-setosa'),
(19,ARRAY[5.7,3.8,1.7,0.3],'Iris-setosa'),
(20,ARRAY[5.1,3.8,1.5,0.3],'Iris-setosa'),
(21,ARRAY[5.4,3.4,1.7,0.2],'Iris-setosa'),
(22,ARRAY[5.1,3.7,1.5,0.4],'Iris-setosa'),
(23,ARRAY[4.6,3.6,1.0,0.2],'Iris-setosa'),
(24,ARRAY[5.1,3.3,1.7,0.5],'Iris-setosa'),
(25,ARRAY[4.8,3.4,1.9,0.2],'Iris-setosa'),
(26,ARRAY[5.0,3.0,1.6,0.2],'Iris-setosa'),
(27,ARRAY[5.0,3.4,1.6,0.4],'Iris-setosa'),
(28,ARRAY[5.2,3.5,1.5,0.2],'Iris-setosa'),
(29,ARRAY[5.2,3.4,1.4,0.2],'Iris-setosa'),
(30,ARRAY[4.7,3.2,1.6,0.2],'Iris-setosa'),
(31,ARRAY[4.8,3.1,1.6,0.2],'Iris-setosa'),
(32,ARRAY[5.4,3.4,1.5,0.4],'Iris-setosa'),
(33,ARRAY[5.2,4.1,1.5,0.1],'Iris-setosa'),
(34,ARRAY[5.5,4.2,1.4,0.2],'Iris-setosa'),
(35,ARRAY[4.9,3.1,1.5,0.1],'Iris-setosa'),
(36,ARRAY[5.0,3.2,1.2,0.2],'Iris-setosa'),
(37,ARRAY[5.5,3.5,1.3,0.2],'Iris-setosa'),
(38,ARRAY[4.9,3.1,1.5,0.1],'Iris-setosa'),
(39,ARRAY[4.4,3.0,1.3,0.2],'Iris-setosa'),
(40,ARRAY[5.1,3.4,1.5,0.2],'Iris-setosa'),
(41,ARRAY[5.0,3.5,1.3,0.3],'Iris-setosa'),
(42,ARRAY[4.5,2.3,1.3,0.3],'Iris-setosa'),
(43,ARRAY[4.4,3.2,1.3,0.2],'Iris-setosa'),
(44,ARRAY[5.0,3.5,1.6,0.6],'Iris-setosa'),
(45,ARRAY[5.1,3.8,1.9,0.4],'Iris-setosa'),
(46,ARRAY[4.8,3.0,1.4,0.3],'Iris-setosa'),
(47,ARRAY[5.1,3.8,1.6,0.2],'Iris-setosa'),
(48,ARRAY[4.6,3.2,1.4,0.2],'Iris-setosa'),
(49,ARRAY[5.3,3.7,1.5,0.2],'Iris-setosa'),
(50,ARRAY[5.0,3.3,1.4,0.2],'Iris-setosa'),
(51,ARRAY[7.0,3.2,4.7,1.4],'Iris-versicolor'),
(52,ARRAY[6.4,3.2,4.5,1.5],'Iris-versicolor'),
(53,ARRAY[6.9,3.1,4.9,1.5],'Iris-versicolor'),
(54,ARRAY[5.5,2.3,4.0,1.3],'Iris-versicolor'),
(55,ARRAY[6.5,2.8,4.6,1.5],'Iris-versicolor'),
(56,ARRAY[5.7,2.8,4.5,1.3],'Iris-versicolor'),
(57,ARRAY[6.3,3.3,4.7,1.6],'Iris-versicolor'),
(58,ARRAY[4.9,2.4,3.3,1.0],'Iris-versicolor'),
(59,ARRAY[6.6,2.9,4.6,1.3],'Iris-versicolor'),
(60,ARRAY[5.2,2.7,3.9,1.4],'Iris-versicolor'),
(61,ARRAY[5.0,2.0,3.5,1.0],'Iris-versicolor'),
(62,ARRAY[5.9,3.0,4.2,1.5],'Iris-versicolor'),
(63,ARRAY[6.0,2.2,4.0,1.0],'Iris-versicolor'),
(64,ARRAY[6.1,2.9,4.7,1.4],'Iris-versicolor'),
(65,ARRAY[5.6,2.9,3.6,1.3],'Iris-versicolor'),
(66,ARRAY[6.7,3.1,4.4,1.4],'Iris-versicolor'),
(67,ARRAY[5.6,3.0,4.5,1.5],'Iris-versicolor'),
(68,ARRAY[5.8,2.7,4.1,1.0],'Iris-versicolor'),
(69,ARRAY[6.2,2.2,4.5,1.5],'Iris-versicolor'),
(70,ARRAY[5.6,2.5,3.9,1.1],'Iris-versicolor'),
(71,ARRAY[5.9,3.2,4.8,1.8],'Iris-versicolor'),
(72,ARRAY[6.1,2.8,4.0,1.3],'Iris-versicolor'),
(73,ARRAY[6.3,2.5,4.9,1.5],'Iris-versicolor'),
(74,ARRAY[6.1,2.8,4.7,1.2],'Iris-versicolor'),
(75,ARRAY[6.4,2.9,4.3,1.3],'Iris-versicolor'),
(76,ARRAY[6.6,3.0,4.4,1.4],'Iris-versicolor'),
(77,ARRAY[6.8,2.8,4.8,1.4],'Iris-versicolor'),
(78,ARRAY[6.7,3.0,5.0,1.7],'Iris-versicolor'),
(79,ARRAY[6.0,2.9,4.5,1.5],'Iris-versicolor'),
(80,ARRAY[5.7,2.6,3.5,1.0],'Iris-versicolor'),
(81,ARRAY[5.5,2.4,3.8,1.1],'Iris-versicolor'),
(82,ARRAY[5.5,2.4,3.7,1.0],'Iris-versicolor'),
(83,ARRAY[5.8,2.7,3.9,1.2],'Iris-versicolor'),
(84,ARRAY[6.0,2.7,5.1,1.6],'Iris-versicolor'),
(85,ARRAY[5.4,3.0,4.5,1.5],'Iris-versicolor'),
(86,ARRAY[6.0,3.4,4.5,1.6],'Iris-versicolor'),
(87,ARRAY[6.7,3.1,4.7,1.5],'Iris-versicolor'),
(88,ARRAY[6.3,2.3,4.4,1.3],'Iris-versicolor'),
(89,ARRAY[5.6,3.0,4.1,1.3],'Iris-versicolor'),
(90,ARRAY[5.5,2.5,4.0,1.3],'Iris-versicolor'),
(91,ARRAY[5.5,2.6,4.4,1.2],'Iris-versicolor'),
(92,ARRAY[6.1,3.0,4.6,1.4],'Iris-versicolor'),
(93,ARRAY[5.8,2.6,4.0,1.2],'Iris-versicolor'),
(94,ARRAY[5.0,2.3,3.3,1.0],'Iris-versicolor'),
(95,ARRAY[5.6,2.7,4.2,1.3],'Iris-versicolor'),
(96,ARRAY[5.7,3.0,4.2,1.2],'Iris-versicolor'),
(97,ARRAY[5.7,2.9,4.2,1.3],'Iris-versicolor'),
(98,ARRAY[6.2,2.9,4.3,1.3],'Iris-versicolor'),
(99,ARRAY[5.1,2.5,3.0,1.1],'Iris-versicolor'),
(100,ARRAY[5.7,2.8,4.1,1.3],'Iris-versicolor'),
(101,ARRAY[6.3,3.3,6.0,2.5],'Iris-virginica'),
(102,ARRAY[5.8,2.7,5.1,1.9],'Iris-virginica'),
(103,ARRAY[7.1,3.0,5.9,2.1],'Iris-virginica'),
(104,ARRAY[6.3,2.9,5.6,1.8],'Iris-virginica'),
(105,ARRAY[6.5,3.0,5.8,2.2],'Iris-virginica'),
(106,ARRAY[7.6,3.0,6.6,2.1],'Iris-virginica'),
(107,ARRAY[4.9,2.5,4.5,1.7],'Iris-virginica'),
(108,ARRAY[7.3,2.9,6.3,1.8],'Iris-virginica'),
(109,ARRAY[6.7,2.5,5.8,1.8],'Iris-virginica'),
(110,ARRAY[7.2,3.6,6.1,2.5],'Iris-virginica'),
(111,ARRAY[6.5,3.2,5.1,2.0],'Iris-virginica'),
(112,ARRAY[6.4,2.7,5.3,1.9],'Iris-virginica'),
(113,ARRAY[6.8,3.0,5.5,2.1],'Iris-virginica'),
(114,ARRAY[5.7,2.5,5.0,2.0],'Iris-virginica'),
(115,ARRAY[5.8,2.8,5.1,2.4],'Iris-virginica'),
(116,ARRAY[6.4,3.2,5.3,2.3],'Iris-virginica'),
(117,ARRAY[6.5,3.0,5.5,1.8],'Iris-virginica'),
(118,ARRAY[7.7,3.8,6.7,2.2],'Iris-virginica'),
(119,ARRAY[7.7,2.6,6.9,2.3],'Iris-virginica'),
(120,ARRAY[6.0,2.2,5.0,1.5],'Iris-virginica'),
(121,ARRAY[6.9,3.2,5.7,2.3],'Iris-virginica'),
(122,ARRAY[5.6,2.8,4.9,2.0],'Iris-virginica'),
(123,ARRAY[7.7,2.8,6.7,2.0],'Iris-virginica'),
(124,ARRAY[6.3,2.7,4.9,1.8],'Iris-virginica'),
(125,ARRAY[6.7,3.3,5.7,2.1],'Iris-virginica'),
(126,ARRAY[7.2,3.2,6.0,1.8],'Iris-virginica'),
(127,ARRAY[6.2,2.8,4.8,1.8],'Iris-virginica'),
(128,ARRAY[6.1,3.0,4.9,1.8],'Iris-virginica'),
(129,ARRAY[6.4,2.8,5.6,2.1],'Iris-virginica'),
(130,ARRAY[7.2,3.0,5.8,1.6],'Iris-virginica'),
(131,ARRAY[7.4,2.8,6.1,1.9],'Iris-virginica'),
(132,ARRAY[7.9,3.8,6.4,2.0],'Iris-virginica'),
(133,ARRAY[6.4,2.8,5.6,2.2],'Iris-virginica'),
(134,ARRAY[6.3,2.8,5.1,1.5],'Iris-virginica'),
(135,ARRAY[6.1,2.6,5.6,1.4],'Iris-virginica'),
(136,ARRAY[7.7,3.0,6.1,2.3],'Iris-virginica'),
(137,ARRAY[6.3,3.4,5.6,2.4],'Iris-virginica'),
(138,ARRAY[6.4,3.1,5.5,1.8],'Iris-virginica'),
(139,ARRAY[6.0,3.0,4.8,1.8],'Iris-virginica'),
(140,ARRAY[6.9,3.1,5.4,2.1],'Iris-virginica'),
(141,ARRAY[6.7,3.1,5.6,2.4],'Iris-virginica'),
(142,ARRAY[6.9,3.1,5.1,2.3],'Iris-virginica'),
(143,ARRAY[5.8,2.7,5.1,1.9],'Iris-virginica'),
(144,ARRAY[6.8,3.2,5.9,2.3],'Iris-virginica'),
(145,ARRAY[6.7,3.3,5.7,2.5],'Iris-virginica'),
(146,ARRAY[6.7,3.0,5.2,2.3],'Iris-virginica'),
(147,ARRAY[6.3,2.5,5.0,1.9],'Iris-virginica'),
(148,ARRAY[6.5,3.0,5.2,2.0],'Iris-virginica'),
(149,ARRAY[6.2,3.4,5.4,2.3],'Iris-virginica'),
(150,ARRAY[5.9,3.0,5.1,1.8],'Iris-virginica');

-- Preview only the first rows: pulling all 150 rows into the notebook adds noise
-- without adding information (the rendered output only shows the first 10 anyway).
SELECT * FROM iris_data ORDER BY id LIMIT 10;

Done.
Done.
150 rows affected.
150 rows affected.


id,attributes,class_text
1,"[Decimal('5.1'), Decimal('3.5'), Decimal('1.4'), Decimal('0.2')]",Iris-setosa
2,"[Decimal('4.9'), Decimal('3.0'), Decimal('1.4'), Decimal('0.2')]",Iris-setosa
3,"[Decimal('4.7'), Decimal('3.2'), Decimal('1.3'), Decimal('0.2')]",Iris-setosa
4,"[Decimal('4.6'), Decimal('3.1'), Decimal('1.5'), Decimal('0.2')]",Iris-setosa
5,"[Decimal('5.0'), Decimal('3.6'), Decimal('1.4'), Decimal('0.2')]",Iris-setosa
6,"[Decimal('5.4'), Decimal('3.9'), Decimal('1.7'), Decimal('0.4')]",Iris-setosa
7,"[Decimal('4.6'), Decimal('3.4'), Decimal('1.4'), Decimal('0.3')]",Iris-setosa
8,"[Decimal('5.0'), Decimal('3.4'), Decimal('1.5'), Decimal('0.2')]",Iris-setosa
9,"[Decimal('4.4'), Decimal('2.9'), Decimal('1.4'), Decimal('0.2')]",Iris-setosa
10,"[Decimal('4.9'), Decimal('3.1'), Decimal('1.5'), Decimal('0.1')]",Iris-setosa


Split the data into a training set (80%) and a test/validation set (20%):

In [5]:
%%sql
-- train_test_split with separate output tables writes iris_train and iris_test;
-- drop any copies left over from a previous run.
DROP TABLE IF EXISTS iris_train, iris_test;

-- Seed the session's random number generator so the split can be repeated.
-- NOTE(review): on a distributed Greenplum cluster the seed may not propagate to
-- the segment processes, so the split could still differ between runs -- confirm.
SELECT setseed(0);

-- 80/20 split of iris_data.
SELECT madlib.train_test_split(
    'iris_data',   -- source table
    'iris',        -- output table root name -> iris_train / iris_test
    0.8,           -- train proportion
    NULL,          -- test proportion (NULL -> remainder, i.e. 0.2)
    NULL,          -- no strata definition
    NULL,          -- output all columns
    NULL,          -- sample without replacement
    TRUE           -- write separate train and test tables
);

-- Sanity check: 0.8 * 150 = 120 training rows expected.
SELECT COUNT(*) FROM iris_train;

Done.
1 rows affected.
1 rows affected.
1 rows affected.


count
120


<a id="pp"></a>
# 2. Call preprocessor for deep learning
Training dataset (uses training preprocessor):

In [7]:
%%sql
-- Remove output from any previous run of this cell.
DROP TABLE IF EXISTS iris_train_packed, iris_train_packed_summary;

-- Pack iris_train into buffers for deep learning. The companion _summary table
-- records metadata (class values, buffer size, ...) that later steps reuse.
SELECT madlib.training_preprocessor_dl(
    'iris_train',          -- source table
    'iris_train_packed',   -- output table
    'class_text',          -- dependent variable
    'attributes'           -- independent variable
);

-- One row per packed buffer, with the shapes of its feature and label arrays.
SELECT attributes_shape, class_text_shape, buffer_id
  FROM iris_train_packed
 ORDER BY buffer_id;

Done.
1 rows affected.
2 rows affected.


attributes_shape,class_text_shape,buffer_id
"[60, 4]","[60, 3]",0
"[60, 4]","[60, 3]",1


In [8]:
%%sql
-- Metadata recorded by training_preprocessor_dl (class values, buffer size, number of classes, ...).
SELECT *
  FROM iris_train_packed_summary;

1 rows affected.


source_table,output_table,dependent_varname,independent_varname,dependent_vartype,class_text_class_values,buffer_size,normalizing_const,num_classes,distribution_rules,__internal_gpu_config__
iris_train,iris_train_packed,[u'class_text'],[u'attributes'],[u'character varying'],"[u'Iris-setosa', u'Iris-versicolor', u'Iris-virginica']",60,1.0,[3],all_segments,all_segments


Validation dataset (uses validation preprocessor):

In [9]:
%%sql
-- Remove output from any previous run of this cell.
DROP TABLE IF EXISTS iris_test_packed, iris_test_packed_summary;

-- Pack iris_test for evaluation. Passing the packed training table keeps the test
-- set encoded consistently with training (both summaries list the same class values).
SELECT madlib.validation_preprocessor_dl(
    'iris_test',           -- source table
    'iris_test_packed',    -- output table
    'class_text',          -- dependent variable
    'attributes',          -- independent variable
    'iris_train_packed'    -- output of the training preprocessor step
);

-- One row per packed buffer, with the shapes of its feature and label arrays.
SELECT attributes_shape, class_text_shape, buffer_id
  FROM iris_test_packed
 ORDER BY buffer_id;

Done.
1 rows affected.
2 rows affected.


attributes_shape,class_text_shape,buffer_id
"[15, 4]","[15, 3]",0
"[15, 4]","[15, 3]",1


In [10]:
%%sql
-- Metadata recorded by validation_preprocessor_dl for the packed test set.
SELECT *
  FROM iris_test_packed_summary;

1 rows affected.


source_table,output_table,dependent_varname,independent_varname,dependent_vartype,class_text_class_values,buffer_size,normalizing_const,num_classes,distribution_rules,__internal_gpu_config__
iris_test,iris_test_packed,[u'class_text'],[u'attributes'],[u'character varying'],"[u'Iris-setosa', u'Iris-versicolor', u'Iris-virginica']",15,1.0,[3],all_segments,all_segments


<a id="load"></a>
# 3. Define and load model architecture
Import Keras libraries

In [11]:
from tensorflow import keras
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense

Define a model architecture with 1 hidden layer (the first `Dense` layer, which receives `input_shape=(4,)`, is treated as the input layer):

In [13]:
# Build the smaller MLP from an explicit list of layers: a 10-unit ReLU layer on
# the 4 input features, a second 10-unit ReLU layer, and a 3-class softmax output.
model1 = Sequential([
    Dense(10, activation='relu', input_shape=(4,)),
    Dense(10, activation='relu'),
    Dense(3, activation='softmax'),
])

model1.summary()

Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_3 (Dense)              (None, 10)                50        
_________________________________________________________________
dense_4 (Dense)              (None, 10)                110       
_________________________________________________________________
dense_5 (Dense)              (None, 3)                 33        
Total params: 193
Trainable params: 193
Non-trainable params: 0
_________________________________________________________________


In [14]:
# Serialize model1's architecture to a JSON string -- the kind of blob that
# madlib.load_keras_model receives further below.
# NOTE(review): the JSON actually loaded into model_arch_library reports
# keras_version 2.1.6, whereas this serialization reports 2.2.4-tf; the loaded
# blob appears to come from an earlier environment. Regenerate it if the
# architecture changes.
model1.to_json()

'{"class_name": "Sequential", "keras_version": "2.2.4-tf", "config": {"layers": [{"class_name": "Dense", "config": {"kernel_initializer": {"class_name": "GlorotUniform", "config": {"dtype": "float32", "seed": null}}, "name": "dense_3", "kernel_constraint": null, "bias_regularizer": null, "bias_constraint": null, "dtype": "float32", "activation": "relu", "trainable": true, "kernel_regularizer": null, "bias_initializer": {"class_name": "Zeros", "config": {"dtype": "float32"}}, "units": 10, "batch_input_shape": [null, 4], "use_bias": true, "activity_regularizer": null}}, {"class_name": "Dense", "config": {"kernel_initializer": {"class_name": "GlorotUniform", "config": {"dtype": "float32", "seed": null}}, "name": "dense_4", "kernel_constraint": null, "bias_regularizer": null, "bias_constraint": null, "dtype": "float32", "activation": "relu", "trainable": true, "kernel_regularizer": null, "bias_initializer": {"class_name": "Zeros", "config": {"dtype": "float32"}}, "units": 10, "use_bias": t

Define a model architecture with 2 hidden layers (the first `Dense` layer, which receives `input_shape=(4,)`, is treated as the input layer):

In [15]:
# Build the deeper MLP from an explicit list of layers: a 10-unit ReLU layer on
# the 4 input features, two more 10-unit ReLU layers, and a 3-class softmax output.
model2 = Sequential([
    Dense(10, activation='relu', input_shape=(4,)),
    Dense(10, activation='relu'),
    Dense(10, activation='relu'),
    Dense(3, activation='softmax'),
])

model2.summary()

Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_6 (Dense)              (None, 10)                50        
_________________________________________________________________
dense_7 (Dense)              (None, 10)                110       
_________________________________________________________________
dense_8 (Dense)              (None, 10)                110       
_________________________________________________________________
dense_9 (Dense)              (None, 3)                 33        
Total params: 303
Trainable params: 303
Non-trainable params: 0
_________________________________________________________________


In [14]:
# Serialize model2's architecture to a JSON string.
# NOTE(review): the saved output of this cell is stale -- it shows
# keras_version 2.1.6 and layer names starting at dense_4, while the summary of
# model2 above lists dense_6..dense_9, and its execution count (14) repeats the
# model1.to_json() cell. Restart the kernel and run all cells to refresh it.
model2.to_json()

'{"class_name": "Sequential", "keras_version": "2.1.6", "config": [{"class_name": "Dense", "config": {"kernel_initializer": {"class_name": "VarianceScaling", "config": {"distribution": "uniform", "scale": 1.0, "seed": null, "mode": "fan_avg"}}, "name": "dense_4", "kernel_constraint": null, "bias_regularizer": null, "bias_constraint": null, "dtype": "float32", "activation": "relu", "trainable": true, "kernel_regularizer": null, "bias_initializer": {"class_name": "Zeros", "config": {}}, "units": 10, "batch_input_shape": [null, 4], "use_bias": true, "activity_regularizer": null}}, {"class_name": "Dense", "config": {"kernel_initializer": {"class_name": "VarianceScaling", "config": {"distribution": "uniform", "scale": 1.0, "seed": null, "mode": "fan_avg"}}, "name": "dense_5", "kernel_constraint": null, "bias_regularizer": null, "bias_constraint": null, "activation": "relu", "trainable": true, "kernel_regularizer": null, "bias_initializer": {"class_name": "Zeros", "config": {}}, "units": 10,

Load the two MLP architectures into the model architecture table:

In [16]:
%%sql
-- Start from an empty architecture library; each load_keras_model call below adds one row.
DROP TABLE IF EXISTS model_arch_library;

-- Register the "1 hidden layer" MLP (10-unit ReLU input layer -> 10-unit ReLU -> 3-unit softmax).
-- NOTE(review): this JSON was captured with Keras 2.1.6 (see "keras_version"), not taken from
-- model1.to_json() in this session; keep it in sync with model1 if the architecture changes.
SELECT madlib.load_keras_model('model_arch_library',  -- Output table,
                               
$$
{"class_name": "Sequential", "keras_version": "2.1.6", "config": [{"class_name": "Dense", "config": {"kernel_initializer": {"class_name": "VarianceScaling", "config": {"distribution": "uniform", "scale": 1.0, "seed": null, "mode": "fan_avg"}}, "name": "dense_1", "kernel_constraint": null, "bias_regularizer": null, "bias_constraint": null, "dtype": "float32", "activation": "relu", "trainable": true, "kernel_regularizer": null, "bias_initializer": {"class_name": "Zeros", "config": {}}, "units": 10, "batch_input_shape": [null, 4], "use_bias": true, "activity_regularizer": null}}, {"class_name": "Dense", "config": {"kernel_initializer": {"class_name": "VarianceScaling", "config": {"distribution": "uniform", "scale": 1.0, "seed": null, "mode": "fan_avg"}}, "name": "dense_2", "kernel_constraint": null, "bias_regularizer": null, "bias_constraint": null, "activation": "relu", "trainable": true, "kernel_regularizer": null, "bias_initializer": {"class_name": "Zeros", "config": {}}, "units": 10, "use_bias": true, "activity_regularizer": null}}, {"class_name": "Dense", "config": {"kernel_initializer": {"class_name": "VarianceScaling", "config": {"distribution": "uniform", "scale": 1.0, "seed": null, "mode": "fan_avg"}}, "name": "dense_3", "kernel_constraint": null, "bias_regularizer": null, "bias_constraint": null, "activation": "softmax", "trainable": true, "kernel_regularizer": null, "bias_initializer": {"class_name": "Zeros", "config": {}}, "units": 3, "use_bias": true, "activity_regularizer": null}}], "backend": "tensorflow"}
$$
::json,         -- JSON blob
                               NULL,                  -- Weights
                               'Sophie',              -- Name
                               'MLP with 1 hidden layer'       -- Descr
);

-- Register the "2 hidden layers" MLP (10-unit ReLU input layer -> two 10-unit ReLU layers -> 3-unit softmax).
-- NOTE(review): this JSON was captured with Keras 2.1.6 (see "keras_version"), not taken from
-- model2.to_json() in this session; keep it in sync with model2 if the architecture changes.
SELECT madlib.load_keras_model('model_arch_library',  -- Output table,
                               
$$
{"class_name": "Sequential", "keras_version": "2.1.6", "config": [{"class_name": "Dense", "config": {"kernel_initializer": {"class_name": "VarianceScaling", "config": {"distribution": "uniform", "scale": 1.0, "seed": null, "mode": "fan_avg"}}, "name": "dense_4", "kernel_constraint": null, "bias_regularizer": null, "bias_constraint": null, "dtype": "float32", "activation": "relu", "trainable": true, "kernel_regularizer": null, "bias_initializer": {"class_name": "Zeros", "config": {}}, "units": 10, "batch_input_shape": [null, 4], "use_bias": true, "activity_regularizer": null}}, {"class_name": "Dense", "config": {"kernel_initializer": {"class_name": "VarianceScaling", "config": {"distribution": "uniform", "scale": 1.0, "seed": null, "mode": "fan_avg"}}, "name": "dense_5", "kernel_constraint": null, "bias_regularizer": null, "bias_constraint": null, "activation": "relu", "trainable": true, "kernel_regularizer": null, "bias_initializer": {"class_name": "Zeros", "config": {}}, "units": 10, "use_bias": true, "activity_regularizer": null}}, {"class_name": "Dense", "config": {"kernel_initializer": {"class_name": "VarianceScaling", "config": {"distribution": "uniform", "scale": 1.0, "seed": null, "mode": "fan_avg"}}, "name": "dense_6", "kernel_constraint": null, "bias_regularizer": null, "bias_constraint": null, "activation": "relu", "trainable": true, "kernel_regularizer": null, "bias_initializer": {"class_name": "Zeros", "config": {}}, "units": 10, "use_bias": true, "activity_regularizer": null}}, {"class_name": "Dense", "config": {"kernel_initializer": {"class_name": "VarianceScaling", "config": {"distribution": "uniform", "scale": 1.0, "seed": null, "mode": "fan_avg"}}, "name": "dense_7", "kernel_constraint": null, "bias_regularizer": null, "bias_constraint": null, "activation": "softmax", "trainable": true, "kernel_regularizer": null, "bias_initializer": {"class_name": "Zeros", "config": {}}, "units": 3, "use_bias": true, "activity_regularizer": null}}], "backend": 
"tensorflow"}
$$
::json,         -- JSON blob
                               NULL,                  -- Weights
                               'Maria',               -- Name
                               'MLP with 2 hidden layers'       -- Descr
);

-- List the registered architectures; their model_id values are what the model
-- selection step refers to.
SELECT *
  FROM model_arch_library
 ORDER BY model_id;

Done.
1 rows affected.
1 rows affected.
2 rows affected.


model_id,model_arch,model_weights,name,description,__internal_madlib_id__
1,"{u'class_name': u'Sequential', u'keras_version': u'2.1.6', u'config': [{u'class_name': u'Dense', u'config': {u'kernel_initializer': {u'class_name': u'VarianceScaling', u'config': {u'distribution': u'uniform', u'scale': 1.0, u'seed': None, u'mode': u'fan_avg'}}, u'name': u'dense_1', u'kernel_constraint': None, u'bias_regularizer': None, u'bias_constraint': None, u'dtype': u'float32', u'activation': u'relu', u'trainable': True, u'kernel_regularizer': None, u'bias_initializer': {u'class_name': u'Zeros', u'config': {}}, u'units': 10, u'batch_input_shape': [None, 4], u'use_bias': True, u'activity_regularizer': None}}, {u'class_name': u'Dense', u'config': {u'kernel_initializer': {u'class_name': u'VarianceScaling', u'config': {u'distribution': u'uniform', u'scale': 1.0, u'seed': None, u'mode': u'fan_avg'}}, u'name': u'dense_2', u'kernel_constraint': None, u'bias_regularizer': None, u'bias_constraint': None, u'activation': u'relu', u'trainable': True, u'kernel_regularizer': None, u'bias_initializer': {u'class_name': u'Zeros', u'config': {}}, u'units': 10, u'use_bias': True, u'activity_regularizer': None}}, {u'class_name': u'Dense', u'config': {u'kernel_initializer': {u'class_name': u'VarianceScaling', u'config': {u'distribution': u'uniform', u'scale': 1.0, u'seed': None, u'mode': u'fan_avg'}}, u'name': u'dense_3', u'kernel_constraint': None, u'bias_regularizer': None, u'bias_constraint': None, u'activation': u'softmax', u'trainable': True, u'kernel_regularizer': None, u'bias_initializer': {u'class_name': u'Zeros', u'config': {}}, u'units': 3, u'use_bias': True, u'activity_regularizer': None}}], u'backend': u'tensorflow'}",,Sophie,MLP with 1 hidden layer,__madlib_temp_99030268_1614985897_73934030__
2,"{u'class_name': u'Sequential', u'keras_version': u'2.1.6', u'config': [{u'class_name': u'Dense', u'config': {u'kernel_initializer': {u'class_name': u'VarianceScaling', u'config': {u'distribution': u'uniform', u'scale': 1.0, u'seed': None, u'mode': u'fan_avg'}}, u'name': u'dense_4', u'kernel_constraint': None, u'bias_regularizer': None, u'bias_constraint': None, u'dtype': u'float32', u'activation': u'relu', u'trainable': True, u'kernel_regularizer': None, u'bias_initializer': {u'class_name': u'Zeros', u'config': {}}, u'units': 10, u'batch_input_shape': [None, 4], u'use_bias': True, u'activity_regularizer': None}}, {u'class_name': u'Dense', u'config': {u'kernel_initializer': {u'class_name': u'VarianceScaling', u'config': {u'distribution': u'uniform', u'scale': 1.0, u'seed': None, u'mode': u'fan_avg'}}, u'name': u'dense_5', u'kernel_constraint': None, u'bias_regularizer': None, u'bias_constraint': None, u'activation': u'relu', u'trainable': True, u'kernel_regularizer': None, u'bias_initializer': {u'class_name': u'Zeros', u'config': {}}, u'units': 10, u'use_bias': True, u'activity_regularizer': None}}, {u'class_name': u'Dense', u'config': {u'kernel_initializer': {u'class_name': u'VarianceScaling', u'config': {u'distribution': u'uniform', u'scale': 1.0, u'seed': None, u'mode': u'fan_avg'}}, u'name': u'dense_6', u'kernel_constraint': None, u'bias_regularizer': None, u'bias_constraint': None, u'activation': u'relu', u'trainable': True, u'kernel_regularizer': None, u'bias_initializer': {u'class_name': u'Zeros', u'config': {}}, u'units': 10, u'use_bias': True, u'activity_regularizer': None}}, {u'class_name': u'Dense', u'config': {u'kernel_initializer': {u'class_name': u'VarianceScaling', u'config': {u'distribution': u'uniform', u'scale': 1.0, u'seed': None, u'mode': u'fan_avg'}}, u'name': u'dense_7', u'kernel_constraint': None, u'bias_regularizer': None, u'bias_constraint': None, u'activation': u'softmax', u'trainable': True, u'kernel_regularizer': None, 
u'bias_initializer': {u'class_name': u'Zeros', u'config': {}}, u'units': 3, u'use_bias': True, u'activity_regularizer': None}}], u'backend': u'tensorflow'}",,Maria,MLP with 2 hidden layers,__madlib_temp_69765081_1614985897_31307059__


<a id="def_mst"></a>
# 4.  Define and load model selection tuples

Generate model configurations using grid search. The grid search output table contains every unique combination of model architecture, compile parameters, and fit parameters — here 2 architectures × 3 learning rates × 2 batch sizes = 12 configurations.

In [17]:
%%sql
-- Remove output from any previous run of this cell.
DROP TABLE IF EXISTS mst_table, mst_table_summary;

-- Grid search: cross every listed model id with every compile/fit setting below,
-- i.e. 2 model ids x 3 learning rates x 2 batch sizes x 1 epochs value = 12 rows.
SELECT madlib.generate_model_configs(
                                        'model_arch_library', -- model architecture table
                                        'mst_table',          -- model selection table output
                                         ARRAY[1,2],          -- model ids from model architecture table
                                         $$
                                            {'loss': ['categorical_crossentropy'],
                                             'optimizer_params_list': [ {'optimizer': ['Adam'], 'lr': [0.001, 0.01, 0.1]} ],
                                             'metrics': ['accuracy']}
                                         $$,                  -- compile_param_grid
                                         $$
                                         { 'batch_size': [4, 8],
                                           'epochs': [1]
                                         }
                                         $$,                  -- fit_param_grid
                                         'grid'               -- search_type
                                         );

-- One row per configuration; mst_key identifies it in training and evaluation.
SELECT * FROM mst_table ORDER BY mst_key;

Done.
1 rows affected.
12 rows affected.


mst_key,model_id,compile_params,fit_params
1,1,"optimizer='Adam(lr=0.001)',metrics=['accuracy'],loss='categorical_crossentropy'","epochs=1,batch_size=4"
2,1,"optimizer='Adam(lr=0.001)',metrics=['accuracy'],loss='categorical_crossentropy'","epochs=1,batch_size=8"
3,1,"optimizer='Adam(lr=0.01)',metrics=['accuracy'],loss='categorical_crossentropy'","epochs=1,batch_size=4"
4,1,"optimizer='Adam(lr=0.01)',metrics=['accuracy'],loss='categorical_crossentropy'","epochs=1,batch_size=8"
5,1,"optimizer='Adam(lr=0.1)',metrics=['accuracy'],loss='categorical_crossentropy'","epochs=1,batch_size=4"
6,1,"optimizer='Adam(lr=0.1)',metrics=['accuracy'],loss='categorical_crossentropy'","epochs=1,batch_size=8"
7,2,"optimizer='Adam(lr=0.001)',metrics=['accuracy'],loss='categorical_crossentropy'","epochs=1,batch_size=4"
8,2,"optimizer='Adam(lr=0.001)',metrics=['accuracy'],loss='categorical_crossentropy'","epochs=1,batch_size=8"
9,2,"optimizer='Adam(lr=0.01)',metrics=['accuracy'],loss='categorical_crossentropy'","epochs=1,batch_size=4"
10,2,"optimizer='Adam(lr=0.01)',metrics=['accuracy'],loss='categorical_crossentropy'","epochs=1,batch_size=8"


This is the name of the model architecture table that corresponds to the model selection table:

In [18]:
%%sql
-- The summary links mst_table back to the architecture table its model_id values come from.
SELECT *
  FROM mst_table_summary;

1 rows affected.


model_arch_table,object_table
model_arch_library,


<a id="train"></a>
# 5.  Train
Train multiple models:

In [19]:
%%sql
-- Remove output from any previous run of this cell.
DROP TABLE IF EXISTS iris_multi_model, iris_multi_model_summary, iris_multi_model_info;

-- Train every configuration in mst_table on the packed training data.
-- No validation table is passed here, so only training metrics are recorded.
SELECT madlib.madlib_keras_fit_multiple_model(
    'iris_train_packed',   -- source_table
    'iris_multi_model',    -- model_output_table
    'mst_table',           -- model_selection_table
    10,                    -- num_iterations
    FALSE                  -- use gpus
);

Done.
1 rows affected.


madlib_keras_fit_multiple_model


View the model summary:

In [20]:
%%sql
-- Run-level metadata for the multi-model fit (inputs, timing, class values, ...).
SELECT *
  FROM iris_multi_model_summary;

1 rows affected.


source_table,validation_table,model,model_info,dependent_varname,independent_varname,model_arch_table,model_selection_table,object_table,num_iterations,metrics_compute_frequency,warm_start,name,description,start_training_time,end_training_time,madlib_version,num_classes,class_text_class_values,dependent_vartype,normalizing_const,metrics_iters
iris_train_packed,,iris_multi_model,iris_multi_model_info,[u'class_text'],[u'attributes'],model_arch_library,mst_table,,10,10,False,,,2021-03-05 23:17:53.015997,2021-03-05 23:19:25.828328,1.18.0-dev,[1],"[u'Iris-setosa', u'Iris-versicolor', u'Iris-virginica']",[u'character varying'],1.0,[10]


View results for each model, sorted by final training accuracy (highest first), with final training loss as the tie-breaker:

In [22]:
%%sql
-- One row per configuration: best final training accuracy first, lower final loss breaking ties.
SELECT *
  FROM iris_multi_model_info
 ORDER BY training_metrics_final DESC,
          training_loss_final;

12 rows affected.


mst_key,model_id,compile_params,fit_params,model_type,model_size,metrics_elapsed_time,metrics_type,loss_type,training_metrics_final,training_loss_final,training_metrics,training_loss,validation_metrics_final,validation_loss_final,validation_metrics,validation_loss
9,2,"optimizer='Adam(lr=0.01)',metrics=['accuracy'],loss='categorical_crossentropy'","epochs=1,batch_size=4",madlib_keras,1.18359375,[90.4152021408081],[u'accuracy'],categorical_crossentropy,0.983333349228,0.0784567892551,[0.983333349227905],[0.0784567892551422],,,,
7,2,"optimizer='Adam(lr=0.001)',metrics=['accuracy'],loss='categorical_crossentropy'","epochs=1,batch_size=4",madlib_keras,1.18359375,[89.9267370700836],[u'accuracy'],categorical_crossentropy,0.949999988079,0.370112359524,[0.949999988079071],[0.370112359523773],,,,
10,2,"optimizer='Adam(lr=0.01)',metrics=['accuracy'],loss='categorical_crossentropy'","epochs=1,batch_size=8",madlib_keras,1.18359375,[90.6731050014496],[u'accuracy'],categorical_crossentropy,0.933333337307,0.120034113526,[0.933333337306976],[0.120034113526344],,,,
4,1,"optimizer='Adam(lr=0.01)',metrics=['accuracy'],loss='categorical_crossentropy'","epochs=1,batch_size=8",madlib_keras,0.75390625,[91.596951007843],[u'accuracy'],categorical_crossentropy,0.908333361149,0.221316426992,[0.908333361148834],[0.221316426992416],,,,
11,2,"optimizer='Adam(lr=0.1)',metrics=['accuracy'],loss='categorical_crossentropy'","epochs=1,batch_size=4",madlib_keras,1.18359375,[92.810093164444],[u'accuracy'],categorical_crossentropy,0.866666674614,0.256317943335,[0.866666674613953],[0.256317943334579],,,,
8,2,"optimizer='Adam(lr=0.001)',metrics=['accuracy'],loss='categorical_crossentropy'","epochs=1,batch_size=8",madlib_keras,1.18359375,[92.3147950172424],[u'accuracy'],categorical_crossentropy,0.774999976158,0.937343239784,[0.774999976158142],[0.937343239784241],,,,
3,1,"optimizer='Adam(lr=0.01)',metrics=['accuracy'],loss='categorical_crossentropy'","epochs=1,batch_size=4",madlib_keras,0.75390625,[90.1556451320648],[u'accuracy'],categorical_crossentropy,0.725000023842,0.670977592468,[0.725000023841858],[0.670977592468262],,,,
2,1,"optimizer='Adam(lr=0.001)',metrics=['accuracy'],loss='categorical_crossentropy'","epochs=1,batch_size=8",madlib_keras,0.75390625,[91.1536350250244],[u'accuracy'],categorical_crossentropy,0.691666662693,0.679427325726,[0.691666662693024],[0.679427325725555],,,,
1,1,"optimizer='Adam(lr=0.001)',metrics=['accuracy'],loss='categorical_crossentropy'","epochs=1,batch_size=4",madlib_keras,0.75390625,[92.044242143631],[u'accuracy'],categorical_crossentropy,0.691666662693,0.818258941174,[0.691666662693024],[0.818258941173553],,,,
12,2,"optimizer='Adam(lr=0.1)',metrics=['accuracy'],loss='categorical_crossentropy'","epochs=1,batch_size=8",madlib_keras,1.18359375,[90.9328460693359],[u'accuracy'],categorical_crossentropy,0.641666650772,0.462306410074,[0.641666650772095],[0.462306410074234],,,,


<a id="eval"></a>
# 6. Evaluate

Now run evaluate on the test set using the model we built above. We pick mst_key 9, which had the best final training accuracy:

In [24]:
%%sql
DROP TABLE IF EXISTS iris_validate;

-- Evaluate one trained model (selected by mst_key) on the packed test set.
-- use_gpus is passed explicitly as FALSE, like the fit and predict calls in this notebook.
SELECT madlib.madlib_keras_evaluate('iris_multi_model',   -- model table
                                    'iris_test_packed',   -- test table
                                    'iris_validate',      -- output table
                                    FALSE,                -- use gpus
                                    9                     -- mst_key to use
                                   );

SELECT * FROM iris_validate;

Done.
1 rows affected.
1 rows affected.


loss,metric,metrics_type,loss_type
0.143714919686,0.966666638851,[u'accuracy'],categorical_crossentropy


<a id="pred"></a>
# 7. Predict

Now predict using the model we built. We will use the validation (test) data set for prediction as well, which is not usual but serves to show the syntax. The predicted class is in the class_value column:

In [25]:
%%sql
DROP TABLE IF EXISTS iris_predict;

-- Predict the class of every test row with the model identified by mst_key 9.
SELECT madlib.madlib_keras_predict(
    'iris_multi_model',  -- model table
    'iris_test',         -- table holding the rows to predict
    'id',                -- id column
    'attributes',        -- independent variable column
    'iris_predict',      -- output table
    'response',          -- prediction type: predicted class
    FALSE,               -- use gpus
    9                    -- mst_key to use
);

SELECT * FROM iris_predict ORDER BY id ASC;

Done.
1 rows affected.
30 rows affected.


id,class_name,class_value,prob
2,class_text,Iris-setosa,0.99811894
5,class_text,Iris-setosa,0.99811894
7,class_text,Iris-setosa,0.99811894
13,class_text,Iris-setosa,0.99811894
23,class_text,Iris-setosa,0.99811894
24,class_text,Iris-setosa,0.99811894
27,class_text,Iris-setosa,0.99811894
30,class_text,Iris-setosa,0.99811894
31,class_text,Iris-setosa,0.99811894
32,class_text,Iris-setosa,0.99811894


Count misclassifications:

In [29]:
%%sql
-- Number of test rows whose predicted class differs from the true class.
SELECT COUNT(*)
FROM iris_predict AS p
INNER JOIN iris_test AS t USING (id)
WHERE p.class_value <> t.class_text;

1 rows affected.


count
1


Percent accuracy on the test set:

In [30]:
%%sql
-- Test accuracy in percent: the share of predicted test rows whose predicted
-- class matches the true class. The denominator comes from the joined rows
-- themselves instead of a hard-coded test-set size (previously 150*0.2 = 30).
-- An empty prediction table yields NULL instead of a wrong number.
SELECT round(100.0 * avg(CASE WHEN t.class_text = p.class_value THEN 1 ELSE 0 END), 2)
           AS test_accuracy_percent
FROM iris_predict AS p
INNER JOIN iris_test AS t ON t.id = p.id;

1 rows affected.


test_accuracy_percent
96.67


<a id="class2"></a>
# Classification with Other Parameters

<a id="val_dataset"></a>
# 1.  Validation dataset

Now use a validation dataset and compute metrics every 3rd iteration using the 'metrics_compute_frequency' parameter (metrics are also computed on the final iteration). This can help reduce run time if you do not need metrics computed at every iteration.

In [31]:
%%sql
DROP TABLE IF EXISTS iris_multi_model, iris_multi_model_summary, iris_multi_model_info;

SELECT madlib.madlib_keras_fit_multiple_model('iris_train_packed',    -- source_table
                                              'iris_multi_model',     -- model_output_table
                                              'mst_table',            -- model_selection_table
                                               10,                     -- num_iterations
                                               FALSE,                 -- use gpus
                                              'iris_test_packed',     -- validation dataset
                                               3,                     -- metrics compute frequency
                                               FALSE,                 -- warm start
                                              'Sophie L.',            -- name
                                              'Model selection for iris dataset'  -- description
                                             );

Done.
1 rows affected.


madlib_keras_fit_multiple_model


View the model summary:

In [32]:
%%sql
-- One-row summary of the training run (tables, iterations, timing, metric iterations).
SELECT *
FROM iris_multi_model_summary;

1 rows affected.


source_table,validation_table,model,model_info,dependent_varname,independent_varname,model_arch_table,model_selection_table,object_table,num_iterations,metrics_compute_frequency,warm_start,name,description,start_training_time,end_training_time,madlib_version,num_classes,class_text_class_values,dependent_vartype,normalizing_const,metrics_iters
iris_train_packed,iris_test_packed,iris_multi_model,iris_multi_model_info,[u'class_text'],[u'attributes'],model_arch_library,mst_table,,10,3,False,Sophie L.,Model selection for iris dataset,2021-03-05 23:30:29.361525,2021-03-05 23:32:25.242549,1.18.0-dev,[1],"[u'Iris-setosa', u'Iris-versicolor', u'Iris-virginica']",[u'character varying'],1.0,"[3, 6, 9, 10]"


View performance of each model:

In [33]:
%%sql
-- Models ranked by their final validation metric, best first.
SELECT *
FROM iris_multi_model_info
ORDER BY validation_metrics_final DESC;

12 rows affected.


mst_key,model_id,compile_params,fit_params,model_type,model_size,metrics_elapsed_time,metrics_type,loss_type,training_metrics_final,training_loss_final,training_metrics,training_loss,validation_metrics_final,validation_loss_final,validation_metrics,validation_loss
3,1,"optimizer='Adam(lr=0.01)',metrics=['accuracy'],loss='categorical_crossentropy'","epochs=1,batch_size=4",madlib_keras,0.75390625,"[30.6736240386963, 63.8817899227142, 97.0417211055756, 113.172317028046]",[u'accuracy'],categorical_crossentropy,0.975000023842,0.0857621207833,"[0.866666674613953, 0.908333361148834, 0.824999988079071, 0.975000023841858]","[0.283501744270325, 0.201569080352783, 0.365902632474899, 0.085762120783329]",0.933333337307,0.122390650213,"[0.733333349227905, 0.933333337306976, 0.866666674613953, 0.933333337306976]","[0.402765065431595, 0.179033249616623, 0.302312880754471, 0.122390650212765]"
11,2,"optimizer='Adam(lr=0.1)',metrics=['accuracy'],loss='categorical_crossentropy'","epochs=1,batch_size=4",madlib_keras,1.18359375,"[33.3402421474457, 66.678878068924, 100.063696146011, 115.878378152847]",[u'accuracy'],categorical_crossentropy,0.966666638851,0.122639112175,"[0.841666638851166, 0.866666674613953, 0.983333349227905, 0.966666638851166]","[0.927325308322906, 0.235888451337814, 0.152433648705482, 0.122639112174511]",0.933333337307,0.213843882084,"[0.800000011920929, 0.766666650772095, 0.966666638851166, 0.933333337306976]","[1.31817901134491, 0.403295516967773, 0.225061357021332, 0.213843882083893]"
4,1,"optimizer='Adam(lr=0.01)',metrics=['accuracy'],loss='categorical_crossentropy'","epochs=1,batch_size=8",madlib_keras,0.75390625,"[32.3545279502869, 65.5994129180908, 98.8639390468597, 114.628155946732]",[u'accuracy'],categorical_crossentropy,0.949999988079,0.340993762016,"[0.725000023841858, 0.975000023841858, 0.966666638851166, 0.949999988079071]","[0.634531021118164, 0.48712894320488, 0.342138230800629, 0.340993762016296]",0.933333337307,0.375507682562,"[0.633333325386047, 0.933333337306976, 0.966666638851166, 0.933333337306976]","[0.697710037231445, 0.519804179668427, 0.398834854364395, 0.375507682561874]"
10,2,"optimizer='Adam(lr=0.01)',metrics=['accuracy'],loss='categorical_crossentropy'","epochs=1,batch_size=8",madlib_keras,1.18359375,"[31.1844570636749, 64.4128880500793, 97.6602430343628, 113.686207056046]",[u'accuracy'],categorical_crossentropy,0.966666638851,0.0938184261322,"[0.975000023841858, 0.949999988079071, 0.966666638851166, 0.966666638851166]","[0.107082977890968, 0.102665595710278, 0.0681213364005089, 0.0938184261322021]",0.933333337307,0.166130110621,"[0.966666638851166, 0.933333337306976, 0.933333337306976, 0.933333337306976]","[0.141592919826508, 0.121055454015732, 0.0925953686237335, 0.166130110621452]"
6,1,"optimizer='Adam(lr=0.1)',metrics=['accuracy'],loss='categorical_crossentropy'","epochs=1,batch_size=8",madlib_keras,0.75390625,"[32.1151111125946, 65.3588180541992, 98.6281480789185, 114.400741100311]",[u'accuracy'],categorical_crossentropy,0.875,0.260402351618,"[0.975000023841858, 0.649999976158142, 0.774999976158142, 0.875]","[0.350311905145645, 0.55135190486908, 0.319230705499649, 0.260402351617813]",0.866666674614,0.267347544432,"[0.966666638851166, 0.466666668653488, 0.666666686534882, 0.866666674613953]","[0.401203900575638, 0.844793558120728, 0.492979735136032, 0.267347544431686]"
9,2,"optimizer='Adam(lr=0.01)',metrics=['accuracy'],loss='categorical_crossentropy'","epochs=1,batch_size=4",madlib_keras,1.18359375,"[30.9298729896545, 64.1563010215759, 97.3999631404877, 113.427273988724]",[u'accuracy'],categorical_crossentropy,0.908333361149,0.262101709843,"[0.725000023841858, 0.958333313465118, 0.916666686534882, 0.908333361148834]","[0.42991915345192, 0.225951835513115, 0.209349796175957, 0.262101709842682]",0.800000011921,0.303691267967,"[0.633333325386047, 1.0, 0.966666638851166, 0.800000011920929]","[0.514671266078949, 0.251676499843597, 0.234429702162743, 0.303691267967224]"
7,2,"optimizer='Adam(lr=0.001)',metrics=['accuracy'],loss='categorical_crossentropy'","epochs=1,batch_size=4",madlib_keras,1.18359375,"[30.4536979198456, 63.6586000919342, 96.8189079761505, 112.950505018234]",[u'accuracy'],categorical_crossentropy,0.866666674614,0.441878378391,"[0.474999994039536, 0.691666662693024, 0.741666674613953, 0.866666674613953]","[0.957030355930328, 0.640025198459625, 0.472760319709778, 0.441878378391266]",0.766666650772,0.470352619886,"[0.433333337306976, 0.566666662693024, 0.600000023841858, 0.766666650772095]","[0.983189880847931, 0.697314381599426, 0.513611257076263, 0.470352619886398]"
8,2,"optimizer='Adam(lr=0.001)',metrics=['accuracy'],loss='categorical_crossentropy'","epochs=1,batch_size=8",madlib_keras,1.18359375,"[32.8489799499512, 66.1956899166107, 99.4372820854187, 115.379984140396]",[u'accuracy'],categorical_crossentropy,0.850000023842,0.454301446676,"[0.691666662693024, 0.691666662693024, 0.725000023841858, 0.850000023841858]","[0.845666646957397, 0.635031342506409, 0.49227574467659, 0.454301446676254]",0.766666650772,0.534793019295,"[0.566666662693024, 0.566666662693024, 0.633333325386047, 0.766666650772095]","[0.97626668214798, 0.749890267848969, 0.588431000709534, 0.534793019294739]"
5,1,"optimizer='Adam(lr=0.1)',metrics=['accuracy'],loss='categorical_crossentropy'","epochs=1,batch_size=4",madlib_keras,0.75390625,"[33.0761120319366, 66.4210600852966, 99.6609480381012, 115.608960151672]",[u'accuracy'],categorical_crossentropy,0.850000023842,0.456756323576,"[0.966666638851166, 0.958333313465118, 0.566666662693024, 0.850000023841858]","[0.138208195567131, 0.242085874080658, 2.31771349906921, 0.456756323575974]",0.733333349228,0.720676660538,"[0.966666638851166, 0.966666638851166, 0.666666686534882, 0.733333349227905]","[0.188635662198067, 0.311398893594742, 2.20467662811279, 0.72067666053772]"
12,2,"optimizer='Adam(lr=0.1)',metrics=['accuracy'],loss='categorical_crossentropy'","epochs=1,batch_size=8",madlib_keras,1.18359375,"[31.4386551380157, 64.6823661327362, 97.9228360652924, 113.945574998856]",[u'accuracy'],categorical_crossentropy,0.691666662693,0.461679339409,"[0.858333349227905, 0.691666662693024, 0.641666650772095, 0.691666662693024]","[0.436321765184402, 0.492918580770493, 0.463972359895706, 0.461679339408875]",0.566666662693,0.466376572847,"[0.800000011920929, 0.566666662693024, 0.766666650772095, 0.566666662693024]","[0.436118185520172, 0.50031191110611, 0.46391037106514, 0.466376572847366]"


Plot validation results

In [34]:
%matplotlib notebook
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
from collections import defaultdict
import pandas as pd
import seaborn as sns
sns.set_palette(sns.color_palette("hls", 20))
plt.rcParams.update({'font.size': 12})
pd.set_option('display.max_colwidth', -1)

In [36]:
df_results = %sql SELECT * FROM iris_multi_model_info ORDER BY validation_loss ASC LIMIT 7;
df_results = df_results.DataFrame()

df_summary = %sql SELECT * FROM iris_multi_model_summary;
df_summary = df_summary.DataFrame()

#set up plots
fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(10,5))
fig.legend(ncol=4)
fig.tight_layout()

ax_metric = axs[0]
ax_loss = axs[1]

ax_metric.xaxis.set_major_locator(MaxNLocator(integer=True))
ax_metric.set_xlabel('Iteration')
ax_metric.set_ylabel('Metric')
ax_metric.set_title('Validation metric curve')

ax_loss.xaxis.set_major_locator(MaxNLocator(integer=True))
ax_loss.set_xlabel('Iteration')
ax_loss.set_ylabel('Loss')
ax_loss.set_title('Validation loss curve')

iters = df_summary['metrics_iters'][0]

for mst_key in df_results['mst_key']:
    df_output_info = %sql SELECT validation_metrics,validation_loss FROM iris_multi_model_info WHERE mst_key = $mst_key
    df_output_info = df_output_info.DataFrame()
    validation_metrics = df_output_info['validation_metrics'][0]
    validation_loss = df_output_info['validation_loss'][0]
    
    ax_metric.plot(iters, validation_metrics, label=mst_key, marker='o')
    ax_loss.plot(iters, validation_loss, label=mst_key, marker='o')

plt.legend();
# fig.savefig('./lc_keras_fit.png', dpi = 300)

7 rows affected.
1 rows affected.


<IPython.core.display.Javascript object>

1 rows affected.
1 rows affected.
1 rows affected.
1 rows affected.
1 rows affected.
1 rows affected.
1 rows affected.


<a id="pred_prob"></a>
# 2.  Predict probabilities

Predict with probabilities for each class:

In [38]:
%%sql
DROP TABLE IF EXISTS iris_predict;

-- Per-class probabilities from the model with mst_key 3: one output row per
-- (id, class), with rank ordering the classes by probability.
SELECT madlib.madlib_keras_predict(
    'iris_multi_model',  -- model table
    'iris_test',         -- table holding the rows to predict
    'id',                -- id column
    'attributes',        -- independent variable column
    'iris_predict',      -- output table
    'prob',              -- prediction type: class probabilities
    FALSE,               -- use gpus
    3                    -- mst_key to use
);

SELECT *
FROM iris_predict
ORDER BY id ASC, rank ASC;

Done.
1 rows affected.
90 rows affected.


id,class_name,class_value,prob,rank
2,class_text,Iris-setosa,0.9998863,1
2,class_text,Iris-versicolor,0.00011367707,2
2,class_text,Iris-virginica,1.2402717e-12,3
5,class_text,Iris-setosa,0.9999517,1
5,class_text,Iris-versicolor,4.8225505e-05,2
5,class_text,Iris-virginica,9.1029716e-14,3
7,class_text,Iris-setosa,0.9998826,1
7,class_text,Iris-versicolor,0.00011736089,2
7,class_text,Iris-virginica,1.3640493e-12,3
13,class_text,Iris-setosa,0.9998809,1


<a id="warm_start"></a>
# 3.  Warm start

Next, use the warm_start parameter to continue training, starting from the model weights learned in the run above. Note that we don't drop the model, model summary, or model info tables:

In [39]:
%%sql
SELECT madlib.madlib_keras_fit_multiple_model('iris_train_packed',    -- source_table
                                              'iris_multi_model',     -- model_output_table
                                              'mst_table',            -- model_selection_table
                                               3,                     -- num_iterations
                                               FALSE,                 -- use gpus
                                              'iris_test_packed',     -- validation dataset
                                               1,                     -- metrics compute frequency
                                               TRUE,                  -- warm start
                                              'Sophie L.',            -- name
                                              'Simple MLP for iris dataset'  -- description
                                             );

1 rows affected.


madlib_keras_fit_multiple_model


View summary:

In [40]:
%%sql
-- Summary of the warm-start run; warm_start and metrics_iters reflect the new call.
SELECT *
FROM iris_multi_model_summary;

1 rows affected.


source_table,validation_table,model,model_info,dependent_varname,independent_varname,model_arch_table,model_selection_table,object_table,num_iterations,metrics_compute_frequency,warm_start,name,description,start_training_time,end_training_time,madlib_version,num_classes,class_text_class_values,dependent_vartype,normalizing_const,metrics_iters
iris_train_packed,iris_test_packed,iris_multi_model,iris_multi_model_info,[u'class_text'],[u'attributes'],model_arch_library,mst_table,,3,1,True,Sophie L.,Simple MLP for iris dataset,2021-03-05 23:33:49.889172,2021-03-05 23:34:35.696218,1.18.0-dev,[1],"[u'Iris-setosa', u'Iris-versicolor', u'Iris-virginica']",[u'character varying'],1.0,"[1, 2, 3]"


View performance of each model:

In [41]:
%%sql
-- After warm start: models ranked by their final validation metric, best first.
SELECT *
FROM iris_multi_model_info
ORDER BY validation_metrics_final DESC;

12 rows affected.


mst_key,model_id,compile_params,fit_params,model_type,model_size,metrics_elapsed_time,metrics_type,loss_type,training_metrics_final,training_loss_final,training_metrics,training_loss,validation_metrics_final,validation_loss_final,validation_metrics,validation_loss
7,2,"optimizer='Adam(lr=0.001)',metrics=['accuracy'],loss='categorical_crossentropy'","epochs=1,batch_size=4",madlib_keras,1.18359375,"[12.4149520397186, 27.5328221321106, 42.8307020664215]",[u'accuracy'],categorical_crossentropy,0.925000011921,0.378388017416,"[0.866666674613953, 0.883333325386047, 0.925000011920929]","[0.418134957551956, 0.397540986537933, 0.378388017416]",0.933333337307,0.401163935661,"[0.766666650772095, 0.800000011920929, 0.933333337306976]","[0.450369209051132, 0.430310726165771, 0.401163935661316]"
10,2,"optimizer='Adam(lr=0.01)',metrics=['accuracy'],loss='categorical_crossentropy'","epochs=1,batch_size=8",madlib_keras,1.18359375,"[13.1818931102753, 28.5486171245575, 43.8156039714813]",[u'accuracy'],categorical_crossentropy,0.975000023842,0.0919722244143,"[0.966666638851166, 0.841666638851166, 0.975000023841858]","[0.0829650238156319, 0.478152126073837, 0.0919722244143486]",0.933333337307,0.181837588549,"[0.966666638851166, 0.733333349227905, 0.933333337306976]","[0.108266495168209, 0.985528528690338, 0.18183758854866]"
6,1,"optimizer='Adam(lr=0.1)',metrics=['accuracy'],loss='categorical_crossentropy'","epochs=1,batch_size=8",madlib_keras,0.75390625,"[13.8922369480133, 29.253427028656, 44.6270771026611]",[u'accuracy'],categorical_crossentropy,0.933333337307,0.163653150201,"[0.891666650772095, 0.491666674613953, 0.933333337306976]","[0.223677828907967, 1.31514668464661, 0.163653150200844]",0.899999976158,0.214205592871,"[0.766666650772095, 0.600000023841858, 0.899999976158142]","[0.346942394971848, 1.53849768638611, 0.214205592870712]"
5,1,"optimizer='Adam(lr=0.1)',metrics=['accuracy'],loss='categorical_crossentropy'","epochs=1,batch_size=4",madlib_keras,0.75390625,"[15.0511469841003, 30.1774880886078, 45.5478649139404]",[u'accuracy'],categorical_crossentropy,0.933333337307,0.159554898739,"[0.75, 0.941666662693024, 0.933333337306976]","[0.376825720071793, 0.160363256931305, 0.159554898738861]",0.899999976158,0.249670743942,"[0.666666686534882, 0.899999976158142, 0.899999976158142]","[0.572766482830048, 0.20397062599659, 0.249670743942261]"
8,2,"optimizer='Adam(lr=0.001)',metrics=['accuracy'],loss='categorical_crossentropy'","epochs=1,batch_size=8",madlib_keras,1.18359375,"[14.5801939964294, 29.9574360847473, 45.324392080307]",[u'accuracy'],categorical_crossentropy,0.941666662693,0.370102822781,"[0.891666650772095, 0.916666686534882, 0.941666662693024]","[0.420222342014313, 0.39033767580986, 0.370102822780609]",0.866666674614,0.429161161184,"[0.800000011920929, 0.833333313465118, 0.866666674613953]","[0.495854884386063, 0.458509385585785, 0.429161161184311]"
12,2,"optimizer='Adam(lr=0.1)',metrics=['accuracy'],loss='categorical_crossentropy'","epochs=1,batch_size=8",madlib_keras,1.18359375,"[13.4523401260376, 28.810662984848, 44.1752660274506]",[u'accuracy'],categorical_crossentropy,0.641666650772,0.462273716927,"[0.691666662693024, 0.691666662693024, 0.641666650772095]","[0.461356490850449, 0.460641026496887, 0.462273716926575]",0.766666650772,0.461440712214,"[0.566666662693024, 0.566666662693024, 0.766666650772095]","[0.494220763444901, 0.470757216215134, 0.461440712213516]"
11,2,"optimizer='Adam(lr=0.1)',metrics=['accuracy'],loss='categorical_crossentropy'","epochs=1,batch_size=4",madlib_keras,1.18359375,"[15.3230481147766, 30.4359600543976, 45.8051240444183]",[u'accuracy'],categorical_crossentropy,0.641666650772,0.465268939734,"[0.691666662693024, 0.691666662693024, 0.641666650772095]","[0.468634635210037, 0.460501492023468, 0.465268939733505]",0.766666650772,0.453260302544,"[0.566666662693024, 0.566666662693024, 0.766666650772095]","[0.527487695217133, 0.486553341150284, 0.45326030254364]"
9,2,"optimizer='Adam(lr=0.01)',metrics=['accuracy'],loss='categorical_crossentropy'","epochs=1,batch_size=4",madlib_keras,1.18359375,"[12.9218399524689, 28.2852990627289, 43.3031449317932]",[u'accuracy'],categorical_crossentropy,0.908333361149,0.282380491495,"[0.975000023841858, 0.916666686534882, 0.908333361148834]","[0.133402094244957, 0.194111600518227, 0.282380491495132]",0.766666650772,0.321279972792,"[0.966666638851166, 0.966666638851166, 0.766666650772095]","[0.14349028468132, 0.207735612988472, 0.321279972791672]"
3,1,"optimizer='Adam(lr=0.01)',metrics=['accuracy'],loss='categorical_crossentropy'","epochs=1,batch_size=4",madlib_keras,0.75390625,"[12.6596629619598, 28.0149381160736, 43.0508260726929]",[u'accuracy'],categorical_crossentropy,0.841666638851,0.374685227871,"[0.933333337306976, 0.899999976158142, 0.841666638851166]","[0.162063658237457, 0.223208039999008, 0.374685227870941]",0.733333349228,0.698961436749,"[0.866666674613953, 0.733333349227905, 0.733333349227905]","[0.297752887010574, 0.422860831022263, 0.698961436748505]"
4,1,"optimizer='Adam(lr=0.01)',metrics=['accuracy'],loss='categorical_crossentropy'","epochs=1,batch_size=8",madlib_keras,0.75390625,"[14.1109259128571, 29.4764680862427, 44.847892999649]",[u'accuracy'],categorical_crossentropy,0.858333349228,0.323420703411,"[0.975000023841858, 0.975000023841858, 0.858333349227905]","[0.260769069194794, 0.237523972988129, 0.323420703411102]",0.699999988079,0.45600476861,"[0.933333337306976, 0.933333337306976, 0.699999988079071]","[0.302719950675964, 0.286018937826157, 0.456004768610001]"


Plot validation results:

In [42]:
df_results = %sql SELECT * FROM iris_multi_model_info ORDER BY validation_loss ASC LIMIT 7;
df_results = df_results.DataFrame()

df_summary = %sql SELECT * FROM iris_multi_model_summary;
df_summary = df_summary.DataFrame()

#set up plots
fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(10,5))
fig.legend(ncol=4)
fig.tight_layout()

ax_metric = axs[0]
ax_loss = axs[1]

ax_metric.xaxis.set_major_locator(MaxNLocator(integer=True))
ax_metric.set_xlabel('Iteration')
ax_metric.set_ylabel('Metric')
ax_metric.set_title('Validation metric curve')

ax_loss.xaxis.set_major_locator(MaxNLocator(integer=True))
ax_loss.set_xlabel('Iteration')
ax_loss.set_ylabel('Loss')
ax_loss.set_title('Validation loss curve')

iters = df_summary['metrics_iters'][0]

for mst_key in df_results['mst_key']:
    df_output_info = %sql SELECT validation_metrics,validation_loss FROM iris_multi_model_info WHERE mst_key = $mst_key
    df_output_info = df_output_info.DataFrame()
    validation_metrics = df_output_info['validation_metrics'][0]
    validation_loss = df_output_info['validation_loss'][0]
    
    ax_metric.plot(iters, validation_metrics, label=mst_key, marker='o')
    ax_loss.plot(iters, validation_loss, label=mst_key, marker='o')

plt.legend();
# fig.savefig('./lc_keras_fit.png', dpi = 300)

7 rows affected.
1 rows affected.


<IPython.core.display.Javascript object>

1 rows affected.
1 rows affected.
1 rows affected.
1 rows affected.
1 rows affected.
1 rows affected.
1 rows affected.
