# TensorFlow API Tutorial 04: The Keras Functional API


In [15]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Model
from tensorflow.keras.models import Sequential, save_model, load_model
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Input
from tensorflow.keras.losses import SparseCategoricalCrossentropy, BinaryCrossentropy
from tensorflow.keras.optimizers import Adam
import numpy as np
import matplotlib.pyplot as plt

# 1. Sequential Model: a -- b -- c
## One input --> one output (each layer feeds exactly the next one)


In [8]:
# A Sequential model is a plain linear stack: each layer feeds the next one.
model1 = Sequential([
    Flatten(input_shape=(28, 28)),   # (28, 28) image -> 784-vector
    Dense(128, activation='relu'),
    Dense(10),                       # 10 raw scores (no activation)
], name='sequential_model')  # snake_case: spaces are not valid in TF name scopes
# summary() prints the table itself and returns None, so it is not wrapped in
# print() (that only adds a stray "None" line to the output).
model1.summary()

Model: "sequential model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 flatten_2 (Flatten)         (None, 784)               0         
                                                                 
 dense_4 (Dense)             (None, 128)               100480    
                                                                 
 dense_5 (Dense)             (None, 10)                1290      
                                                                 
Total params: 101770 (397.54 KB)
Trainable params: 101770 (397.54 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________
None


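# 2. Functional API

A `Sequential` model can only express a linear stack of layers. The Functional API calls layers directly on tensors, so the output of `b` can feed two separate heads `c` and `d`.

# Model --> a -- b -- [c, d]
## One input --> two outputs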

In [11]:
# Create the layer objects once; the next cell wires them into a graph.
inputs = Input(shape=(28, 28))
flatten, dense1 = Flatten(), Dense(128, activation='relu')  # shared trunk
dense2c, dense2d = Dense(10), Dense(1)                      # head "c" and head "d"

In [12]:
# Wire input -> flatten -> dense1, then branch: both heads read the same
# hidden tensor x, which is what makes this graph non-sequential.
x = flatten(inputs)
x = dense1(x)
output_c = dense2c(x)
output_d = dense2d(x)
outputs = [output_c, output_d]
model2 = Model(inputs=inputs, outputs=outputs, name='functional_api')  # no spaces: TF name-scope safe
# summary() prints directly and returns None; wrapping it in print() adds "None".
model2.summary()

Model: "functional api"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_2 (InputLayer)        [(None, 28, 28)]             0         []                            
                                                                                                  
 flatten_3 (Flatten)         (None, 784)                  0         ['input_2[0][0]']             
                                                                                                  
 dense_6 (Dense)             (None, 128)                  100480    ['flatten_3[0][0]']           
                                                                                                  
 dense_7 (Dense)             (None, 10)                   1290      ['dense_6[0][0]']             
                                                                                     

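# 3. Example: MNIST with the Functional API

One shared trunk feeds two heads: `categorical_output` predicts the digit (0–9), and `left_right_output` predicts whether the digit is < 5 (0, "left") or ≥ 5 (1, "right").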

In [13]:
# Layers for the MNIST model: a shared trunk plus two named output heads.
# The head names are the keys used by the loss and label dicts further down.
# NOTE(review): this rebinds dense1 from section 2; a distinct name would be clearer.
inputs1 = Input(shape=(28, 28))
flatten1, dense1 = Flatten(), Dense(128, activation='relu')
dense2, dense3 = (
    Dense(10, activation='softmax', name='categorical_output'),  # digit 0-9
    Dense(1, activation='sigmoid', name='left_right_output'),    # 0 = left, 1 = right
)

In [24]:
# Shared trunk: both heads consume the same 128-unit hidden features.
x = flatten1(inputs1)
x = dense1(x)

outputs1 = dense2(x)  # categorical_output: 10-way softmax
outputs2 = dense3(x)  # left_right_output: 1-unit sigmoid

model_mnist = Model(inputs=inputs1, outputs=[outputs1, outputs2], name='mnist_functional_api')
# summary() prints the table itself and returns None, so no print() wrapper.
model_mnist.summary()

Model: "mnist_functional_api"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_3 (InputLayer)        [(None, 28, 28)]             0         []                            
                                                                                                  
 flatten_4 (Flatten)         (None, 784)                  0         ['input_3[0][0]']             
                                                                                                  
 dense_9 (Dense)             (None, 128)                  100480    ['flatten_4[2][0]']           
                                                                                                  
 categorical_output (Dense)  (None, 10)                   1290      ['dense_9[2][0]']             
                                                                               

In [25]:
# Both heads already apply softmax / sigmoid, so the losses receive
# probabilities rather than logits (from_logits=False).
loss1 = SparseCategoricalCrossentropy(from_logits=False)  # integer digit labels
loss2 = BinaryCrossentropy(from_logits=False)             # 0/1 left/right labels
optimizer = Adam(learning_rate=0.001)
# One metrics entry per output, keyed by output-layer name. A flat
# ['accuracy'] list is ambiguous for a two-output model (Keras 3 raises a
# ValueError when the list length differs from the number of outputs); the
# dict form is explicit and accepted by both Keras 2 and Keras 3.
metrics = {
    'categorical_output': ['accuracy'],
    'left_right_output': ['accuracy'],
}
losses = {
    'categorical_output': loss1,
    'left_right_output': loss2,
}
model_mnist.compile(loss=losses, optimizer=optimizer, metrics=metrics)

In [26]:
# Load MNIST and scale pixel values from [0, 255] to [0, 1].
mnist = keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Cast to float32 before dividing: Keras' default float dtype is float32, and
# this halves memory compared with the float64 produced by `uint8 / 255.0`.
x_train = x_train.astype(np.float32) / 255.0
x_test = x_test.astype(np.float32) / 255.0

# Secondary target: 0 = "left" (digits 0-4), 1 = "right" (digits 5-9).
# Vectorised replacement for a per-element Python loop.
y_leftright = (y_train >= 5).astype(np.uint8)

print(y_train.dtype, y_train[:20])
print(y_leftright.dtype, y_leftright[:20])

# Labels keyed by output-layer name, matching the loss dict given to compile().
y = {
    "categorical_output": y_train,
    "left_right_output": y_leftright,
}

uint8 [5 0 4 1 9 2 1 3 1 4 3 5 3 6 1 7 2 8 6 9]
uint8 [1 0 0 0 1 0 0 0 0 0 0 1 0 1 0 1 0 1 1 1]


In [27]:
# Keep the returned History object: history.history holds the per-epoch
# loss/metric values for plotting, and assigning it suppresses the bare
# "<History at 0x...>" repr in the cell output.
# validation_split=0.2 holds out the last 20% of the training samples.
history = model_mnist.fit(
    x=x_train,
    y=y,
    epochs=5,
    batch_size=64,
    validation_split=0.2,
    verbose=1,
)


Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.src.callbacks.History at 0x7dc4ad06eb00>

In [28]:
# The two-output model returns one prediction array per head.
predictions = model_mnist.predict(x_test)
n_heads = len(predictions)
n_heads



2

In [29]:
# Split predictions by head; order follows outputs=[outputs1, outputs2].
predictions_categorical = predictions[0]
predictions_leftright = predictions[1]

# Inspect only the first 20 test samples.
pred_categorical = predictions_categorical[:20]
pred_leftright = predictions_leftright[:20]

# Most probable digit for each sample.
label_categorical = np.argmax(pred_categorical, axis=1)
# Threshold the sigmoid probability at 0.5. ravel() flattens the (n, 1) output
# of the Dense(1) head to (n,), replacing a per-element Python comprehension
# that relied on the truthiness of 1-element arrays.
label_leftright = (pred_leftright.ravel() >= 0.5).astype(int)

In [30]:
# Rows: true digits, predicted digits, predicted left(0)/right(1).
for row in (y_test[:20], label_categorical, label_leftright):
    print(row)

[7 2 1 0 4 1 4 9 5 9 0 6 9 0 1 5 9 7 3 4]
[7 2 1 0 4 1 4 9 6 9 0 6 9 0 1 5 9 7 3 4]
[1 0 0 0 0 0 0 1 1 1 0 1 1 0 0 1 1 1 0 0]
