In [1]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from keras_tuner import HyperParameters

import autokeras as ak

from cerebro.nas.hphpmodel import HyperHyperModel
import pandas as pd

In [2]:
df = pd.read_csv("/Users/zijian/Desktop/ucsd/cse234/project/cerebro-system/Iris_clean.csv", header='infer')

train=df.sample(frac=0.8,random_state=200) #random state is a seed value
test=df.drop(train.index)

feature_columns = ['SepalLengthCm', 'SepalWidthCm', 'PetalLengthCm', 'PetalWidthCm']
x_train = train[feature_columns]
y_train = train['Species']

In [3]:
train.dtypes

SepalLengthCm    float64
SepalWidthCm     float64
PetalLengthCm    float64
PetalWidthCm     float64
Species            int64
dtype: object

In [4]:
input_node = [ak.StructuredDataInput() for c in feature_columns]
embeddings = [ak.StructuredDataBlock()(innode) for innode in input_node]
output_node = ak.Merge()([embeddings])
output_node = ak.ClassificationHead()(output_node)
clf = ak.AutoModel(
    inputs=input_node, 
    outputs=output_node, 
    max_trials=20, 
    tuner='random',
    objective="val_loss",
)

INFO:tensorflow:Reloading Oracle from existing project ./auto_model/oracle.json
INFO:tensorflow:Reloading Tuner from ./auto_model/tuner0.json


In [5]:
x = [x_train[[feature]] for feature in feature_columns]

In [6]:
clf.fit(
    x,
    y_train,
    # The name of the label column.
    epochs=5,
)

2021-11-21 20:00:57.358877: I tensorflow/compiler/jit/xla_cpu_device.cc:41] Not creating XLA devices, tf_xla_enable_xla_devices not set
2021-11-21 20:00:57.359089: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2021-11-21 20:00:57.414353: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:116] None of the MLIR optimization passes are enabled (registered 2)


INFO:tensorflow:Oracle triggered exit
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


2021-11-21 20:01:01.610706: W tensorflow/python/util/util.cc:348] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.


INFO:tensorflow:Assets written to: ./auto_model/best_model/assets


<tensorflow.python.keras.callbacks.History at 0x187294350>

In [7]:
model = clf.tuner.get_best_model()

In [8]:
model.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_3 (InputLayer)            [(None, 1)]          0                                            
__________________________________________________________________________________________________
multi_category_encoding_2 (Mult (None, 1)            0           input_3[0][0]                    
__________________________________________________________________________________________________
normalization_1 (Normalization) (None, 1)            3           multi_category_encoding_2[0][0]  
__________________________________________________________________________________________________
input_2 (InputLayer)            [(None, 1)]          0                                            
______________________________________________________________________________________________

In [11]:
x_test = test[feature_columns]
y_test = test['Species']

x_test = np.array(x_test)
x_test = [x_test[:,i,np.newaxis] for i in range(x_test.shape[1])]
y_test = np.array(y_test)
y_test[:10]

array([0, 0, 0, 0, 0, 0, 0, 0, 1, 1])

In [12]:
model.predict(x_test)

array([[1.00000000e+00, 5.81818187e-12, 1.38920356e-13],
       [1.00000000e+00, 6.91627117e-21, 1.74427243e-11],
       [9.99997139e-01, 2.34918322e-34, 2.85116948e-06],
       [9.99996424e-01, 3.28661191e-31, 3.58094326e-06],
       [1.00000000e+00, 2.59460422e-15, 4.50647519e-09],
       [1.00000000e+00, 5.98095217e-19, 1.46587054e-09],
       [1.00000000e+00, 3.27511533e-20, 1.07206148e-16],
       [1.00000000e+00, 2.48271475e-20, 1.03590746e-16],
       [1.06542777e-14, 2.44966905e-05, 9.99975562e-01],
       [4.86013543e-22, 9.86436069e-01, 1.35639487e-02],
       [1.61351650e-16, 5.96532912e-09, 1.00000000e+00],
       [1.42609300e-32, 9.98998225e-01, 1.00173661e-03],
       [1.39117623e-15, 9.52642552e-07, 9.99999046e-01],
       [3.84483914e-22, 9.92197454e-01, 7.80249853e-03],
       [4.86539571e-18, 1.89894184e-01, 8.10105860e-01],
       [4.77830746e-22, 9.99949574e-01, 5.04367490e-05],
       [5.31651115e-17, 6.26253346e-12, 1.00000000e+00],
       [5.39565920e-25, 9.99507